学习pytorch20 pytorch完整的模型验证套路

news/2024/7/5 7:13:23

pytorch完整的模型验证套路

  • 使用非数据集的测试数据,测试训练好模型的效果
  • 代码
  • 预测结果
  • 解决报错

B站小土堆pytorch学习视频 https://www.bilibili.com/video/BV1hE411t7RN/?p=32&spm_id_from=pageDriver&vd_source=9607a6d9d829b667f8f0ccaaaa142fcb

在这里插入图片描述

使用非数据集的测试数据,测试训练好模型的效果

 测试:训练好的模型,提供对外真实数据的一个实际应用

从网上下载两张图片,整理图片的输入格式,输入模型测试模型效果
请添加图片描述
请添加图片描述

代码

import torch
from torch import nn
from torchvision import transforms
from PIL import Image
import cv2

dog_path = './images/dog.jpg'
airplane_path = './images/airplane.jpg'
model_path = './images/net_epoch9_gpu.pth'

dog_pil = Image.open(dog_path)
airp_pil = Image.open(airplane_path)
print(dog_pil)  # <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=258x174 at 0x237CC2CBE50>
# RGB 3通道 匹配模型输入的通道数
dog_pil = dog_pil.convert('RGB')  # def convert(self, mode=None, matrix=None, dither=None, palette=Palette.WEB, colors=256):
airp_pil = airp_pil.convert('RGB')
# dog_cv = cv2.imread(dog_path)  # numpy.array
# # print(dog_cv)
# img_trans = torchvision.transforms.ToTensor()  # 实例化转tensor的类
# dog_tensor = img_trans(dog_pil)
# dog_cv_tensor = img_trans(dog_cv)
# print(dog_tensor)
# print(dog_tensor.shape)
# print(dog_cv_tensor)
# 输入模型shape 需要是32*32大小的
transform = transforms.Compose([transforms.Resize((32,32)),
                                transforms.ToTensor()])
dog_tensor = transform(dog_pil)
airp_tensor = transform(airp_pil)
# print(dog_tensor)
print(dog_tensor.shape, airp_tensor.shape)


class Cifar10Net(nn.Module):
    def __init__(self):
        super(Cifar10Net, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(64*4*4, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.net(x)
        return x

# 加载模型要考虑是以哪种形式保存的模型  模型保存方式1:保存模型结构和参数 方式二:只保存模型参数
model = torch.load(model_path, map_location=torch.device('cpu'))
print(model)
dog_tensor = dog_tensor.reshape((1, 3, 32, 32))
airp_tensor = airp_tensor.reshape((1, 3, 32, 32))
model.eval() # 设置模型为测试状态 网络层有dropout batchNormal层不加eval函数会有问题
with torch.no_grad(): # 测试不做梯度计算 节省算力
    dog_output = model(dog_tensor)
    airp_output = model(airp_tensor)
print(dog_output)
print(dog_output.argmax())
print(dog_output.argmax(1))
print(airp_output)
print(airp_output.argmax(1))  # 概率值不便于解读 使用argmax 可以很方便的读出模型预测的是哪个类别

预测结果

在这里插入图片描述

tensor([[ 1.1317, -4.3441,  3.2116,  2.8930,  2.6749,  4.6079, -3.2860,  3.1357,
         -3.0432, -4.1703]])
tensor(5)
tensor([5])
tensor([[ 5.5993, -0.6140,  4.4758,  0.8463,  1.6311, -1.0217, -3.9990, -2.8343,
          1.1050, -1.6423]])
tensor([0])

预测结果和训练数据的标注一直,预测正确

解决报错

RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x16 and 1024x64)
解决: dog_tensor = dog_tensor.reshape((1, 3, 32, 32)) 转换输入是4维的, 模型输入有一个batch-size维度

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device(‘cpu’) to map your storages to the CPU.
解决:model = torch.load(model_path, map_location=torch.device(‘cpu’))
在gpu上训练的模型,要在cpu上测试,模型加载的时候指定cpu设备


http://lihuaxi.xjx100.cn/news/1886167.html

相关文章

Python数据科学视频讲解:Python保留字与标识符

2.6 Python保留字与标识符 视频为《Python数据科学应用从入门到精通》张甜 杨维忠 清华大学出版社一书的随书赠送视频讲解2.6节内容。本书已正式出版上市&#xff0c;当当、京东、淘宝等平台热销中&#xff0c;搜索书名即可。内容涵盖数据科学应用的全流程&#xff0c;包括数据…

Web漏洞分析-文件解析及上传(上)

随着互联网的迅速发展&#xff0c;网络安全问题变得日益复杂&#xff0c;而文件解析及上传漏洞成为攻击者们频繁攻击的热点之一。本文将深入研究文件解析及上传漏洞&#xff0c;通过对文件上传、Web容器IIS、命令执行、Nginx文件解析漏洞以及公猫任意文件上传等方面的细致分析&…

lwIP 细节之三:errf 回调函数是何时调用的

使用 lwIP 协议栈进行 TCP 裸机编程&#xff0c;其本质就是编写协议栈指定的各种回调函数。将你的应用逻辑封装成函数&#xff0c;注册到协议栈&#xff0c;在适当的时候&#xff0c;由协议栈自动调用&#xff0c;所以称为回调。 注&#xff1a;除非特别说明&#xff0c;以下内…

业内首份!工业领域数据安全政策汇编发布(附下载)

在工业领域&#xff0c;数据是贯穿工业互联网的“血液”&#xff0c;是提质降本增效的关键。工业数据增长迅速、种类繁多、体量庞大&#xff0c;数据安全已成为保障工业发展&#xff0c;保障社会稳定的要点。 随着工业企业信息化的普及以及工业互联网的快速发展&#xff0c;工…

索引的使用

索引是一种数据结构&#xff0c;用于快速查找数据库中的数据。索引可以加快查询的速度&#xff0c;并减少数据库的负载和响应时间。以下是使用索引的一些方法&#xff1a; 1.创建索引&#xff1a;可以通过CREATE INDEX语句创建索引。在创建索引时&#xff0c;需要指定要创建索…

关于代码质量度量和分析的一些总结

最近团队做CMMI3认证&#xff0c;这期间涉及到了代码质量度量。花了点时间做了总结&#xff0c;分享给大家。 先看一张整体的图&#xff0c;然后逐个指标展开说明。 一、单元测试覆盖率 单元测试覆盖率&#xff08;Coverage&#xff09;是一个度量单元测试覆盖了多少代码的指标…

java面试题-线程、线程池的了解及工作原理、拒绝策略

远离八股文&#xff0c;面试大白话&#xff0c;通俗且易懂 看完后试着用自己的话复述出来。有问题请指出&#xff0c;有需要帮助理解的或者遇到的真实面试题不知道怎么总结的也请评论中写出来&#xff0c;大家一起解决。 java面试题汇总-目录-持续更新中 这篇还蛮好理解的&…

2023年最新prometheus + grafana搭建和使用+gmail邮箱告警配置

一、安装prometheus 1.1 安装 prometheus官网下载地址 sudo -i mkdir -p /opt/prometheus #移动解压后的文件名到/opt/,并改名prometheus mv prometheus-2.45 /opt/prometheus/ #创建一个专门的prometheus用户&#xff1a; -M 不创建家目录&#xff0c; -s 不让登录 useradd…