【模型复现】Network in Network,将1*1卷积引入网络设计,运用全局平均池化替代全连接层。模块化设计网络

news/2024/7/5 7:04:19
  • 《Network In Network》是一篇比较老的文章了(2014年ICLR的一篇paper),是当时比较厉害的一篇论文,同时在现在看来也是一篇非常经典并且影响深远的论文,后续很多创新都有这篇文章的影子。[1312.4400] Network In Network (arxiv.org)这篇文章采用较少参数就取得了Alexnet的效果,Alexnet参数大小为230M,而Network In Network仅为29M。

  • 卷积网络通常由卷积和池化交替堆叠,最后接全连接完成模型构建,卷积通过线性滤波器对应特征图位置相乘并求和,然后进行非线性激活得到特征图。线性模型足以抽象线性可分的隐含特征,但是实际上这些特征通常是高度非线性的,常规的卷积网络则可以通过采用一组超完备滤波器(尽可能多)提取统一潜在特征各种变体(宁可错杀一千不可放过一个),但是同一潜在特征使用太多的滤波器会给下一层带来额外的负担,需要考虑来自前一层的所有变化的组合,来自更高层的滤波器会映射到原始输入的更大区域,它通过结合下层的较低级概念生成较高级的特征,因此作者认为网络局部模块做出更好的特征抽象会更好,顺势引入Network in Network则能达到这个目标,在每个卷积层内引入一个微型网络,来计算和抽象每个局部块的特征

  • 论文Network in Network的网络结构中有由两处新的结构(当时),MLP Convolution Layers和Global Average Pooling。所谓MLPConv其实就是在常规卷积(感受野大于1的)后接若干1x1卷积,每个特征图视为一个神经元,特征图通过1x1卷积就类似多个神经元线性组合,这样就像是MLP(多层感知机)了,这是文章最大的创新点,也就是Network in Network(网络中内嵌微型网络)。径向基(Radial basis network)和 从多层感知机(multilayer perceptron)是两种通用的函数逼近器,作者选择了多层感知机,因为多层感知器与卷积神经网络的结构一样,都是通过反向传播训练。其次多层感知器本身就是一个深度模型,符合特征再利用的原则。NIN(Network in Network)学习笔记_nin函数_SyGoing的博客-CSDN博客

  • 普通卷积层(感受野大于1)及文中提到的GLM(generalized linear model)相当于单层网络,抽象能力有限。为了提高特征的抽象表达能力,作者用MLPConv代替了GLM。 n为网络层数,第一层为线性卷积层(卷积核尺寸大于1),后面的为1x1卷积。

    • 在这里插入图片描述

    • ( a ) f i , j , k = m a x ( w k T x i , j , 0 ) ( b ) f i , j , k 1 1 = m a x ( w k 1 1 T x i , j + b k 1 , 0 ) . . . f i , j , k n n = m a x ( w k n n T f i , j n − 1 + b k n , 0 ) (a)f_{i,j,k}=max(w^T_kx_{i,j},0)\\ (b)f_{i,j,k_1}^1=max({w_{k_1}^1}^Tx_{i,j}+b_{k_1},0)\\ ...\\ f_{i,j,k_n}^n=max({w_{k_n}^n}^Tf_{i,j}^{n-1}+b_{k_n},0)\\ (a)fi,j,k=max(wkTxi,j,0)(b)fi,j,k11=max(wk11Txi,j+bk1,0)...fi,j,knn=max(wknnTfi,jn1+bkn,0)

    • 1x1卷积作为NIN函数逼近器基本单元,除了增强了网络局部模块的抽象表达能力外,在现在看来还可以实现跨通道特征融合和通道升维降维。

  • 当时作者应该是第一个使用1x1卷积的,具有划时代的意义,之后的Googlenet借鉴了1*1卷积,还专门致谢过这篇论文,现在很多优秀的网络结构都离不开1x1卷积,ResNet、ResNext、SqueezeNet、MobileNetv1-3、ShuffleNetv1-2等等。

  • 传统卷积神经网络在网络的浅层进行卷积运算。对于分类任务,最后一个卷积层得到的特征图被向量化(flatten)然后送入全连接层,接一个softmax逻辑回归层。这种结构将卷积结构与传统神经网络分类器连接起来,卷积层作为特征提取器,得到的特征用传统神经网络进行分类。全连接层参数量是非常庞大的,模型通常会容易过拟合,针对这个问题,Hinton提出Dropout方法来提高泛化能力,但是全连接的计算量依旧很大。

  • 基于此,论文提出用全局平均池化代替全连接层,具体做法是对最后一层的特征图进行平均池化,得到的结果向量直接输入softmax层。这样做好处之一是使得特征图与分类任务直接关联,另一个优点是全局平均池化不需要优化额外的模型参数,因此模型大小和计算量较全连接大大减少,并且可以避免过拟合。

  • 知道NIN的基本单元,整体网络结构为Input+MLPConv+GAP+softmax,网络结构示意图如下:

    • 在这里插入图片描述
  • Network in Network对常规卷积网络的特征提取抽象表示进行改进,提出MLPconv,其实就是在常规卷积后接1x1卷积(首次使用1x1卷积),首次采用全局平均池化降低网络复杂度,避免过拟合,在之后的很多经典论文中都有用到,具有开创性意义;深度学习发展迅猛,论文很多,但是经典的还是少数,所以很值得学习,以前的ResNet,MobileNetv1-3,ShuffleNetv1-2等等。

pytorch复现NIN

  • 导包,查看配置信息

  • import time
    import torch
    import torchvision
    from torch import nn, optim
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(torch.__version__)
    print(device)
    
  • 1.13.1
    cpu
    
  • NIN模块及模型构建

  • def nin_block(in_channels, out_channels, kernel_size, stride, padding):
        blk = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
                            nn.ReLU(),
                            nn.Conv2d(out_channels, out_channels, kernel_size=1),
                            nn.ReLU(),
                            nn.Conv2d(out_channels, out_channels, kernel_size=1),
                            nn.ReLU())
        return blk  # convnext模块和它好像
    class FlattenLayer(torch.nn.Module):
        def __init__(self):
            super(FlattenLayer, self).__init__()
        def forward(self, x): # x shape: (batch, *, *, ...)
            return x.view(x.shape[0], -1)
    net = nn.Sequential(
        nin_block(1, 96, kernel_size=11, stride=4, padding=0),
        nn.MaxPool2d(kernel_size=3, stride=2),
        nin_block(96, 256, kernel_size=5, stride=1, padding=2),
        nn.MaxPool2d(kernel_size=3, stride=2),
        nin_block(256, 384, kernel_size=3, stride=1, padding=1),
        nn.MaxPool2d(kernel_size=3, stride=2), 
        nn.Dropout(0.5),
        # 标签类别数是10
        nin_block(384, 10, kernel_size=3, stride=1, padding=1),
        # 全局平均池化层可通过将窗口形状设置成输入的高和宽实现
        nn.AvgPool2d(kernel_size=5),
        # 将四维的输出转成二维的输出,其形状为(批量大小, 10)
        FlattenLayer())
    X = torch.rand(1, 1, 224, 224)
    for name, blk in net.named_children(): 
        X = blk(X)
        print(name, 'output shape: ', X.shape)
    
  • 0 output shape:  torch.Size([1, 96, 54, 54])
    1 output shape:  torch.Size([1, 96, 26, 26])
    2 output shape:  torch.Size([1, 256, 26, 26])
    3 output shape:  torch.Size([1, 256, 12, 12])
    4 output shape:  torch.Size([1, 384, 12, 12])
    5 output shape:  torch.Size([1, 384, 5, 5])
    6 output shape:  torch.Size([1, 384, 5, 5])
    7 output shape:  torch.Size([1, 10, 5, 5])
    8 output shape:  torch.Size([1, 10, 1, 1])
    9 output shape:  torch.Size([1, 10])
    
  • 获取数据和训练模型

  • import sys
    batch_size = 32
    # 如出现“out of memory”的报错信息,可减小batch_size或resize
    def load_data_fashion_mnist(batch_size, resize=None, root='~/Datasets/FashionMNIST'):
        """Download the fashion mnist dataset and then load into memory."""
        trans = []
        if resize:
            trans.append(torchvision.transforms.Resize(size=resize))
        trans.append(torchvision.transforms.ToTensor())
        transform = torchvision.transforms.Compose(trans)
        mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)
        mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)
        if sys.platform.startswith('win'):
            num_workers = 0  # 0表示不用额外的进程来加速读取数据
        else:
            num_workers = 4
        train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        return train_iter, test_iter
    def evaluate_accuracy(data_iter, net, device=None):
        if device is None and isinstance(net, torch.nn.Module):
            # 如果没指定device就使用net的device
            device = list(net.parameters())[0].device 
        acc_sum, n = 0.0, 0
        with torch.no_grad():
            for X, y in data_iter:
                if isinstance(net, torch.nn.Module):
                    net.eval() # 评估模式, 这会关闭dropout
                    acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()
                    net.train() # 改回训练模式
                else: # 
                    if('is_training' in net.__code__.co_varnames): # 如果有is_training这个参数
                        # 将is_training设置成False
                        acc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item() 
                    else:
                        acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() 
                n += y.shape[0]
        return acc_sum / n
    def mytrain(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):
        net = net.to(device)
        print("training on ", device)
        loss = torch.nn.CrossEntropyLoss()
        for epoch in range(num_epochs):
            train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
            for X, y in train_iter:
                X = X.to(device)
                y = y.to(device)
                y_hat = net(X)
                l = loss(y_hat, y)
                optimizer.zero_grad()
                l.backward()
                optimizer.step()
                train_l_sum += l.cpu().item()
                train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
                n += y.shape[0]
                batch_count += 1
            test_acc = evaluate_accuracy(test_iter, net)
            print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
                  % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
    train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=224)
    lr, num_epochs = 0.002, 5
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    mytrain(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)
    
  • training on  cpu
    epoch 1, loss 2.2931, train acc 0.106, test acc 0.100, time 2587.8 sec
    epoch 2, loss 2.3026, train acc 0.100, test acc 0.100, time 2413.8 sec
    epoch 3, loss 2.3026, train acc 0.100, test acc 0.100, time 2336.9 sec
    epoch 4, loss 2.3026, train acc 0.100, test acc 0.100, time 2333.4 sec
    epoch 5, loss 2.3026, train acc 0.100, test acc 0.100, time 2333.9 sec
    
  • 针对分类任务提出了一个新的深度网络NIN。这种新网络包括mlpconv层(使用MLP来进行卷积)以及全局平均池化层(取代FC层)。mlpconv层对局部块特征提取更好,全局平均池化可以作为正则化器来防止全局的过拟合。我们使用这种结构在几种数据集上取得了目前最好的效果。通过特征图的可视化,我们验证了最后一层mlpconv输出的特征图是类别的信度图,同时也提升了使用NIN进行目标检测的可能性。


http://lihuaxi.xjx100.cn/news/1010981.html

相关文章

ambari源码分析 -----ambari-server启动流程

一、启动脚本分析 1、ambari的启动脚本为:service ambari-server start 或者 ambari-server start。分别对应脚本文件/etc/init.d/ambari-server 和 /usr/sbin/ambari-server,其中/usr/sbin/ambari-server文件是一个快捷方式,指向/etc/init.d…

Ubuntu常用环境配置

配置软件源 切换清华源 sudo sed -i "shttp://.*archive.ubuntu.comhttps://mirrors.tuna.tsinghua.edu.cng" /etc/apt/sources.list sudo sed -i "shttp://.*security.ubuntu.comhttps://mirrors.tuna.tsinghua.edu.cng" /etc/apt/sources.list sudo ap…

Scrapy-爬虫多开技能

我们知道,现在运行Scrapy项目中的爬虫文件,需要一个一个地运行,那么是否可以将对应的爬虫文件批量运行呢?如果可以,又该怎么实现呢?在Scrapy中,如果想批量运行爬虫文件,常见的有两种…

IPv4 和 IPv6 的组成结构和对比

IPv4 和 IPv6 的组成结构和对比IPv4IPv6互联网协议 (IP) 是互联网通信的基础,IP 地址是互联网上每个设备的唯一标识符。目前最常用的 IP 协议是 IPv4,它已经有近 30 年的历史了。然而,IPv4 存在一些问题,例如: 地址空间不足:IPv4 …

向量的内积外积哈达玛积

1.向量的内积 1.1 定义 从代数角度看&#xff0c;先对两个数字序列中的每组对应元素求积&#xff0c;再对所有积求和&#xff0c;结果即为点积。从几何角度看&#xff0c;点积则是两个向量的长度与它们夹角余弦的积。 表示形式&#xff1a;ATBA^TBATB、<A,B><A,B&g…

[学习笔记]金融风控实战

参考资料&#xff1a; 零基础入门金融风控-贷款违约预测 导包 import pandas as pd import matplotlib.pyplot as plt# 读取数据 train pd.read_csv(train.csv) testA pd.read_csv(testA.csv) print(Train data shape:, train.shape) print(testA data shape:, testA.shape…

我的面试八股(JAVA并发)

程序计数器为什么是线程私有的? 程序计数器主要有下面两个作用&#xff1a; 字节码解释器通过改变程序计数器来依次读取指令&#xff0c;从而实现代码的流程控制&#xff0c;如&#xff1a;顺序执行、选择、循环、异常处理。在多线程的情况下&#xff0c;程序计数器用于记录…

Redis 客户端连接服务器失败

公司项目开发环境需要使用到 Redis&#xff0c;申请基础技术支撑平台的 Redis 中间件比较麻烦&#xff0c;项目组也不知道具体流程&#xff0c;而且时间可能比较长。 现在的情况是&#xff0c;项目因为 Redis 启动报错。 这种情况下&#xff0c;我们项目组就自行在虚拟机上临…