PyTorch自定义损失函数实现

news/2024/7/5 2:32:08

在机器学习中,损失函数是衡量预测输出与实际输出之间差异的关键组成部分。 它在模型训练中起着至关重要的作用,因为它通过指示模型应该改进的方向来指导优化过程。 损失函数的选择取决于具体的任务和数据类型。 在本文中,我们将以用于手写数字分类的 MNIST 数据集为例,深入研究 PyTorch 中自定义损失函数的理论和实现。
在这里插入图片描述

推荐:使用 NSDT场景设计器 快速搭建 3D场景。

1、概述

MNIST 数据集是广泛用于图像分类任务的数据集,它包含 70,000 张手写数字图像,每张图像的分辨率为 28x28 像素。 任务是将这些图像分类为 10 个数字之一 (0–9)。 此任务旨在根据 MNIST 数据集中提供的训练示例训练一个模型,该模型可以准确地对手写数字的新图像进行分类。

此任务的典型方法是使用多类逻辑回归模型,它是一个 softmax 分类器。 softmax 函数将模型的输出映射到 10 个类别的概率分布。 交叉熵损失通常用作此类模型的损失函数。 交叉熵损失计算预测概率分布与实际概率分布之间的差异。

然而,在某些情况下,交叉熵损失可能不是特定任务的最佳选择。 例如,考虑一个场景,其中错误分类某些类的成本比其他类高得多。 在这种情况下,有必要使用考虑到每个类的相对重要性的自定义损失函数。

在本文中,我将向你展示如何为 MNIST 数据集实现自定义损失函数,其中误分类数字 9 的成本远高于其他数字。 我们将使用 Pytorch 作为框架,首先讨论自定义损失函数背后的理论,然后我们将展示使用 Pytorch 实现自定义损失函数。 最后,我们将使用自定义损失函数在 MNIST 数据集上训练线性模型,并评估模型的性能。

2、自定义损失函数:为什么

出于以下几个原因,实现自定义损失函数很重要:

  • Problem-specific:损失函数的选择取决于具体任务和数据类型。 可以设计自定义损失函数以更好地适应手头问题的特征,从而提高模型性能。
  • 类不平衡:在许多现实世界的数据集中,每个类中的样本数量可能非常不同。 可以设计一个自定义损失函数来考虑类别不平衡,并为不同的类别分配不同的成本。
  • 成本敏感:在某些任务中,错误分类某些类别的成本可能比其他类别高得多。 可以设计自定义损失函数以考虑每个类的相对重要性,从而产生更稳健的模型。
  • 多任务学习:可以设计自定义损失函数来同时处理多个任务。 这在需要单个模型来执行多个相关任务的情况下非常有用。
  • 正则化:自定义损失函数也可以用于正则化,有助于防止过拟合,提高模型的泛化能力。
  • 对抗性训练:自定义损失函数也可用于训练模型以抵抗对抗性攻击。
    总之,自定义损失函数可以提供一种更好地针对特定问题优化模型的方法,并且可以提供更好的性能和泛化能力。

3、PyTorch 中的自定义损失函数

MNIST 数据集包含 70,000 张手写数字图像,每张图像的分辨率为 28x28 像素。 任务是将这些图像分类为 10 个数字之一 (0–9)。 此任务的典型方法是使用多类逻辑回归模型,它是一个 softmax 分类器。 softmax 函数将模型的输出映射到 10 个类别的概率分布。 交叉熵损失通常用作此类模型的损失函数。

交叉熵损失计算预测概率分布与实际概率分布之间的差异。 通过将 softmax 函数应用于模型的输出来获得预测的概率分布。 实际的概率分布是一个one-hot vector,其中正确类别对应的元素值为1,其他元素值为0。交叉熵损失定义为:

    L = -∑(y_i * log(p_i))

其中 y_i 是类别 i 的实际概率,p_i 是类别 i 的预测概率。

然而,在某些情况下,交叉熵损失可能不是特定任务的最佳选择。 例如,考虑一个场景,其中错误分类某些类的成本比其他类高得多。 在这种情况下,有必要使用考虑到每个类的相对重要性的自定义损失函数。

在 PyTorch 中,可以通过创建 nn.Module 类的子类并覆盖 forward 方法来实现自定义损失函数。 forward 方法将预测输出和实际输出作为输入,并返回损失值。

下面是 MNIST 分类任务的自定义损失函数示例,其中错误分类数字 9 的成本远高于其他数字:

class CustomLoss(nn.Module):
    def __init__(self):
        super(CustomLoss, self).__init__()

    def forward(self, output, target):
        target = torch.LongTensor(target)
        criterion = nn.CrossEntropyLoss()
        loss = criterion(output, target)
        mask = target == 9
        high_cost = (loss * mask.float()).mean()
        return loss + high_cost

在这个例子中,我们首先使用 nn.CrossEntropyLoss() 函数计算交叉熵损失。 接下来,我们为属于类别 9 的样本创建掩码 1,为其他样本创建掩码 0。 然后我们计算属于类别 9 的样本的平均损失。最后,我们将这个高成本损失添加到原始损失中以获得最终损失。

要使用自定义损失函数,我们需要将其实例化并将其作为参数传递给训练循环中优化器的标准参数。 以下是如何使用自定义损失函数在 MNIST 数据集上训练模型的示例:

import torch.nn as nn
import torch
from torchvision import datasets, transforms
from torch import nn, optim
import torch.nn.functional as F
import torchvision
import os

class CustomLoss(nn.Module):
    def __init__(self):
        super(CustomLoss, self).__init__()

    def forward(self, output, target):
        target = torch.LongTensor(target)
        criterion = nn.CrossEntropyLoss()
        loss = criterion(output, target)
        mask = target == 9
        high_cost = (loss * mask.float()).mean()
        return loss + high_cost




# Load the MNIST dataset
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=32, shuffle=True)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=32, shuffle=True)


# Define the model, loss function and optimizer
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)

network = Net()
optimizer = optim.SGD(network.parameters(), lr=0.01,
                      momentum=0.5)
criterion = CustomLoss()

# Training loop
n_epochs = 10

train_losses = []
train_counter = []
test_losses = []
test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]

if os.path.exists('results'):
  os.system('rm -r results')

os.mkdir('results')

def train(epoch):
  network.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = network(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % 1000 == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))
      train_losses.append(loss.item())
      train_counter.append(
        (batch_idx*64) + ((epoch-1)*len(train_loader.dataset)))
      torch.save(network.state_dict(), 'results/model.pth')
      torch.save(optimizer.state_dict(), 'results/optimizer.pth')

def test():
  network.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      output = network(data)
      test_loss += criterion(output, target).item()
      pred = output.data.max(1, keepdim=True)[1]
      correct += pred.eq(target.data.view_as(pred)).sum()
  test_loss /= len(test_loader.dataset)
  test_losses.append(test_loss)
  print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))


test()
for epoch in range(1, n_epochs + 1):
  train(epoch)
  test()

此代码是 PyTorch 中 MNIST 数据集的自定义损失函数的实现。 MNIST 数据集包含 70,000 张手写数字图像,每张图像的分辨率为 28x28 像素。 任务是将这些图像分类为 10 个数字之一 (0–9)。

第一个代码块通过继承 PyTorch nn.Module 创建一个名为“CustomLoss”的自定义损失函数。 它有一个前向方法,接受两个输入; 模型的输出和目标标签。 forward 方法首先将目标标签转换为长整数张量。 然后它创建一个内置 PyTorch 交叉熵损失函数的实例,并使用它来计算模型输出和目标标签之间的损失。 接下来,它创建一个标识等于 9 的目标标签的掩码,然后将损失乘以该掩码并计算所得张量的平均值。 最后,它返回原始损失和高成本损失的均值之和。

下一个代码块使用 PyTorch 的内置数据加载实用程序加载 MNIST 数据集。 train_loader 加载训练数据集并对图像应用指定的变换,例如将图像转换为张量并归一化像素值。 test_loader 加载测试数据集并应用相同的转换。

以下代码块通过对 PyTorch nn.Module 进行子类化来定义一个名为“Net”的卷积神经网络 (CNN)。 CNN 由 2 个卷积层、2 个线性层和一些用于正则化的 dropout 层组成。 Net 类的 forward 方法依次应用卷积层和线性层,将输出传递给 ReLU 激活函数和最大池化层。 它还将 dropout 层应用于输出并返回最终输出的 log-softmax。

下一个代码块创建 Net 类的一个实例、一个优化器(随机梯度下降)和一个自定义损失函数的实例。

最后的代码块是训练循环,其中模型训练了 10 个时期。 在每个时期,模型迭代训练数据集,通过网络传递图像,使用自定义损失函数计算损失并反向传播梯度。 然后它使用优化器更新模型的参数。 它还跟踪训练损失和测试损失,并定期将当前损失打印到控制台。 此外,它会创建一个名为“results”的新目录来存储训练过程的结果和输出。

import matplotlib.pyplot as plt

fig = plt.figure()
plt.plot(train_counter, train_losses, color='blue')
plt.scatter(test_counter, test_losses, color='red')
plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
plt.xlabel('number of training examples seen')
plt.ylabel('negative log likelihood loss')
plt.show()

在这里插入图片描述

此代码在训练过程中为 MNIST 数据集创建自定义损失函数图。 该图将显示训练集和测试集的自定义损失。

它首先导入 Matplotlib 库,这是一个用于 Python 的绘图库。 然后,它使用 plt.figure() 函数创建一个具有指定大小的图形对象。

下一行代码使用 plt.plot() 函数绘制训练集的自定义损失。 它使用 train_counter 和 train_losses 变量分别作为 x 和 y 轴值。 使用 color 参数将图的颜色设置为蓝色。

然后,它使用 plt.scatter() 函数绘制测试集的自定义损失。 它使用 test_counter 和 test_losses 变量分别作为 x 和 y 轴值。 使用 color 参数将图的颜色设置为红色。

plt.legend() 函数为绘图添加图例,指示哪个绘图对应于训练损失,哪个对应于测试损失。 loc 参数设置为“右上角”,这意味着图例将位于绘图的右上角。

plt.xlabel() 和 plt.ylabel() 函数分别向绘图的 x 轴和 y 轴添加标签。 x 轴标签设置为“看到的训练示例数”,y 轴标签设置为“自定义损失”。

最后,plt.show() 函数用于显示绘图。

此代码将显示一个图,显示所见训练示例的自定义损失函数。 蓝线代表训练集的自定义损失,红点代表测试集的自定义损失。 该图将允许你查看自定义损失函数在训练过程中的表现,并评估模型的性能。

examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
with torch.no_grad():
  output = network(example_data)
fig = plt.figure()
for i in range(6):
  plt.subplot(2,3,i+1)
  plt.tight_layout()
  plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
  plt.title("Prediction: {}".format(
    output.data.max(1, keepdim=True)[1][i].item()))
  plt.xticks([])
  plt.yticks([])

plt.show()

在这里插入图片描述

此代码显示一个图形,其中包含来自测试集的 6 个图像以及训练网络做出的相应预测。

它首先使用 enumerate() 函数循环遍历 test_loader,这是一个批量加载测试数据集的迭代器。 next() 函数用于从测试集中获取第一批示例。

example_data 变量包含图像,example_targets 变量包含相应的标签。

然后它使用 Pytorch 的 torch.no_grad() 函数,它用于临时将 requires_grad 标志设置为 false。 它将减少内存使用并加快计算速度,但也不会跟踪操作。

下一个代码块使用 plt.figure() 函数创建一个新的图形对象。 然后,它使用 for 循环迭代测试集中的前 6 个示例。 对于每个示例,它使用 plt.subplot() 函数在当前图窗中创建一个子图。 plt.tight_layout() 函数用于调整子图之间的间距。

然后它使用 plt.imshow() 函数在当前子图中显示图像。 cmap 参数设置为“灰色”以灰度显示图像,插值参数设置为“无”以显示图像而不进行任何插值。

plt.title() 函数用于为当前子图添加标题。 标题显示了网络对当前示例所做的预测。 网络的输出通过 output.data.max(1, keepdim=True)[1] 传递,它返回预测类的索引。 [i].item() 提取预测类的整数值。

plt.xticks() 和 plt.yticks() 函数分别用于从当前子图中删除 x 轴和 y 轴刻度。

最后,plt.show() 函数用于显示图形。 此代码将显示一个图形,其中包含来自测试集的 6 张图像以及经过训练的网络对其做出的相应预测。 图像以灰度显示且没有任何插值,预测类别显示为每张图像上方的标题。 这可能是一个有用的工具,可用于可视化模型在测试集上的性能并识别任何潜在问题或错误分类。

4、结束语

在本文中,我们以用于数字分类的 MNIST 数据集为例,讨论了 PyTorch 中自定义损失函数的理论和实现。 我们已经展示了如何通过继承 nn.Module 类并覆盖 forward 方法来创建自定义损失函数。 我们还提供了一个示例,说明如何使用自定义损失函数在 MNIST 数据集上训练模型。 在错误分类某些类的成本远高于其他类的情况下,自定义损失函数可能很有用。 重要的是要注意,在实现自定义损失函数时应格外小心,因为它们会对模型的性能产生重大影响。

— ‌
原文链接:Pytorch自定义损失函数 — BimAnt


http://lihuaxi.xjx100.cn/news/863849.html

相关文章

LwIP系列--线程通信消息结构

一、目的如果有小伙伴移植过LwIP,那么你肯定知道在LwIP源码中tcp/ip协议栈是作为一个单独的线程运行的,那么就有这样一个问题,我们从mac外设上收到的以太网数据包是如何交给tcp/ip线程进行处理的,用户发送的数据又是如何经过协议栈…

XXL-JOB 任务调度平台实践

XXL-JOB 任务调度平台实践一、调度中心(服务端)1、从gitbub 获取项目源码:[https://github.com/xuxueli/xxl-job](https://github.com/xuxueli/xxl-job)2、从源码中得到SQL脚本创建和初始化数据库3、Maven 编译打包 xxl-job-admin 并部署为调度中心4、启动运行 xxl-…

【云原生】解读Kubernetes三层网络方案

在上一篇文章中,我以网桥类型的 Flannel 插件为例,为你讲解了 Kubernetes 里容器网络和 CNI 插件的主要工作原理。不过,除了这种模式之外,还有一种纯三层(Pure Layer 3)网络方案非常值得你注意。其中的典型…

实战打靶集锦-004-My-Cmsms

**写在前面:**记录一次艰难曲折的打靶经历。 目录1. 主机发现2. 端口扫描3. 服务枚举4. 服务探查4.1 WEB服务探查4.1.1 浏览器访问4.1.2 目录枚举4.1.3 控制台探查4.1.4 其他目录探查4.2 阶段小结5. 公共EXP搜索5.1 CMS搜索5.2 Apache搜索5.3 PHP搜索5.4 MySQL搜索5…

vue项目第二天

项目中使用element-ui库中文网https://element.eleme.cn/#/zh-CN安装命令npm install element-ui安装按需加载babel插件npm install babel-plugin-component -Dnpm i //可以通过npm i 的指令让配置刷新重新配置一下项目中使用element-ui组件抽离文件中按需使用element ui &…

低噪声与功放选型购买

低噪声与功率放大器的区别?购买时怎么区分? 低噪放 低噪放,低噪声射频放大器。作用就是要求噪声系数很低,放大电压信号。一般放在系统第一级,因为噪声系数低,接收放大的信号有很好的的信噪比。如天线的接…

Java的异常处理

异常 异常就是程序非正常运行时的报错,不正常就是异常。 异常分类 通常分为两类: Error:错误。通常是Java虚拟机无法解决的严重问题。如:JVM系统内部错误、资源耗尽等严重情况。比如:StackOverflowError和OOM-->…

2023-02-10 - 5 文本搜索

与其他需要精确匹配的数据不同,文本数据在前期的索引构建和搜索环节都需要进行额外的处理,并且在匹配环节还要进行相关性分数计算。本章将详细介绍文本搜索的相关知识。 本章首先从总体上介绍文本的索引建立过程和搜索过程,然后介绍分析器的…