Pytorch构建ResNet-50V2

news/2024/7/3 5:31:29
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客

  • 🍦 参考文章地址: 365天深度学习训练营-第J2周:ResNet-50V2算法实战与解析

  • 🍖 作者:K同学啊

一、ResNetV2与ResNet结构对比

改进点

(a)original 表示原始的 ResNet 的残差结构,(b)proposed 表示新的 ResNet 的残差结构。主要差别就是(a)结构先卷积后进行 BN 和激活函数计算,最后执行 addition 后再进行ReLU 计算; (b)结构先进行 BN 和激活函数计算后卷积,把 addition 后的 ReLU 计算放到了残差结构内部。

改进结果

作者使用这两种不同的结构在 CIFAR-10 数据集上做测试,模型用的是 1001层的 ResNet 模型。从图中结果我们可以看出,(b)proposed 的测试集错误率明显更低一些,达到了 4.92%的错误率,(a)original 的测试集错误率是 7.61%

二、模型实现

2.1 残差结构

''' Residual Block '''
class Block2(nn.Module):
    def __init__(self, in_channel, filters, kernel_size=3, stride=1, conv_shortcut=False):
        super(Block2, self).__init__()
        self.preact = nn.Sequential(
            nn.BatchNorm2d(in_channel),
            nn.ReLU(True)
        )
        
        self.shortcut = conv_shortcut
        if self.shortcut:
            self.short = nn.Conv2d(in_channel, 4*filters, 1, stride=stride, padding=0, bias=False)
        elif stride>1:
            self.short = nn.MaxPool2d(kernel_size=1, stride=stride, padding=0)
        else:
            self.short = nn.Identity()
        
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, filters, 1, stride=1, bias=False),
            nn.BatchNorm2d(filters),
            nn.ReLU(True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(filters, filters, kernel_size, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(filters),
            nn.ReLU(True)
        )
        self.conv3 = nn.Conv2d(filters, 4*filters, 1, stride=1, bias=False)
    
    def forward(self, x):
        x1 = self.preact(x)
        if self.shortcut:
            x2 = self.short(x1)
        else:
            x2 = self.short(x)
        x1 = self.conv1(x1)
        x1 = self.conv2(x1)
        x1 = self.conv3(x1)
        x = x1 + x2
        return x

2.2 模块构建

class Stack2(nn.Module):
    def __init__(self, in_channel, filters, blocks, stride=2):
        super(Stack2, self).__init__()
        self.conv = nn.Sequential()
        self.conv.add_module(str(0), Block2(in_channel, filters, conv_shortcut=True))
        for i in range(1, blocks-1):
            self.conv.add_module(str(i), Block2(4*filters, filters))
        self.conv.add_module(str(blocks-1), Block2(4*filters, filters, stride=stride))
    
    def forward(self, x):
        x = self.conv(x)
        return x

2.3 网络构建

''' 构建ResNet50V2 '''
class ResNet50V2(nn.Module):
    def __init__(self,
                 include_top=True,  # 是否包含位于网络顶部的全链接层
                 preact=True,  # 是否使用预激活
                 use_bias=True,  # 是否对卷积层使用偏置
                 input_shape=[224, 224, 3],
                 classes=1000,
                 pooling=None):  # 用于分类图像的可选类数
        super(ResNet50V2, self).__init__()
        
        self.conv1 = nn.Sequential()
        self.conv1.add_module('conv', nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=use_bias, padding_mode='zeros'))
        if not preact:
            self.conv1.add_module('bn', nn.BatchNorm2d(64))
            self.conv1.add_module('relu', nn.ReLU())
        self.conv1.add_module('max_pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        
        self.conv2 = Stack2(64, 64, 3)
        self.conv3 = Stack2(256, 128, 4)
        self.conv4 = Stack2(512, 256, 6)
        self.conv5 = Stack2(1024, 512, 3, stride=1)
        
        self.post = nn.Sequential()
        if preact:
            self.post.add_module('bn', nn.BatchNorm2d(2048))
            self.post.add_module('relu', nn.ReLU())
        if include_top:
            self.post.add_module('avg_pool', nn.AdaptiveAvgPool2d((1, 1)))
            self.post.add_module('flatten', nn.Flatten())
            self.post.add_module('fc', nn.Linear(2048, classes))
        else:
            if pooling=='avg':
                self.post.add_module('avg_pool', nn.AdaptiveAvgPool2d((1, 1)))
            elif pooling=='max':
                self.post.add_module('max_pool', nn.AdaptiveMaxPool2d((1, 1)))
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.post(x)
        return x

三、鸟类数据集效果

数据集可视化:

 后三个epoch:

Epoch:18, Train_acc:92.9%, Train_loss:0.210, Test_acc:84.1%,Test_loss:0.538
Epoch:19, Train_acc:94.9%, Train_loss:0.160, Test_acc:89.4%,Test_loss:0.484
Epoch:20, Train_acc:92.7%, Train_loss:0.270, Test_acc:82.3%,Test_loss:0.700
Done
best_acc: 0.9491150442477876

Loss与Accuracy图:

指定图片预测:

 


http://lihuaxi.xjx100.cn/news/1007421.html

相关文章

windows服务器自带IIS搭建网站并发布公网访问【内网穿透】

文章目录1.前言2.Windows网页设置2.1 Windows IIS功能设置2.2 IIS网页访问测试3. Cpolar内网穿透3.1 下载安装Cpolar3.2 Cpolar云端设置3.3 Cpolar本地设置4.公网访问测试5.结语转载自远程源码文章:【IIS搭建网站】本地电脑做服务器搭建web站点并公网访问「内网穿透…

C++ Primer Plus编程作业

编写一个C程序&#xff0c;如下述输出示例所示的那样请求并显示信息&#xff1a; What is your first name? Betty Sue What is your last name? Yewe What letter grade do you deserve? B What is your age? 22 Name: Yewe, Betty Sue Grade: Age: 22 #include <iost…

消防基础知识——燃烧与火灾

燃烧 燃烧的本质与条件 发光、发热的剧烈的氧化反应 可燃物、助燃剂&#xff0c;火源燃烧可从着火方式&#xff0c;持续燃烧形式&#xff0c;燃烧物形态&#xff0c;燃烧现象等不同角度作不同分类。按照燃烧形成的条件和发生瞬间的特点&#xff0c;燃烧可分为着火和爆炸。液体…

nacos本地启动单节点

1.官网下载 Releases alibaba/nacos GitHub 解压文件 unzip nacos-server-2.2.1.zip cd /Users/xiaosa/dev_tools/nacos/bin sh startup.sh -m standalone 启动不成功&#xff0c;报错入如下 原因是下面的配置为空。位置在nacos/config目录下的application.properties文件…

牛客网算法八股刷题系列(八)K-Means真题描述

牛客网算法八股刷题系列——K-Means真题描述题目描述正确答案&#xff1a;A\mathcal AA题目解析题目描述 两个种子点A(−1,1),B(2,1)A(-1,1),B(2,1)A(−1,1),B(2,1)&#xff0c;其余样本点为(0,0),(0,2),(1,1),(3,2),(6,0),(6,2)(0,0),(0,2),(1,1),(3,2),(6,0),(6,2)(0,0),(0,…

人工智能前沿——「全域全知全能」人类新宇宙ChatGPT

&#x1f680;&#x1f680;&#x1f680;OpenAI聊天机器人ChatGPT——「全域全知全能」人类全宇宙大爆炸&#xff01;&#xff01;&#x1f525;&#x1f525;&#x1f525; 一、什么是ChatGPT?&#x1f340;&#x1f340; ChatGPT是生成型预训练变换模型&#xff08;Chat G…

李宏毅2021春季机器学习课程视频笔记11-卷积伸进网络(CNN)

卷积神经网络架构 图像识别 将图片拉直放入神经网络中进行训练。 网络通过对图像中的存在的特征进行分析&#xff0c;判断当前属于何种类别。 神经网络其实不需要对整个图片进行分析&#xff0c;只需要对一些特殊的信息进行分析就可以得知当前图片所属的类别&#xff0c;基于此…

JSON 元素的添加删除

javasscript删除数组的3种方法 1&#xff0c;用shift()方法 shift&#xff1a;删除原数组第一项&#xff0c;并返回删除元素的值&#xff1b;如果数组为空则返回undefined var chaomao[1,2,3,4,5] var chaomao.shift()//得到1 alert(chaomao)//[2,3,4,5] 2&#xff0c;用pop()…