论文学习——U-Net: Convolutional Networks for Biomedical Image Segmentation

news/2024/7/5 1:37:42

UNet的特点

  • 采用端到端的结构,通过FCN(最后一层仍然是通过卷积完成),最后输出图像。
  • 通过编码(下采样)-解码(上采样)形成一个“U”型结构。每次下采样时,先进行两次卷积(通道数不变),然后通过一次池化层(也可以通过卷积)处理(长宽减半,通道数加倍);在每次上采样时,同样先进行两次卷积操作,再通过反卷积函数进行上采样(长宽加倍,通道不变),然后与编码过程中对应层进行拼接(通道加倍)。到最后一层时,通过1x1的卷积核修改通道数,最后输出目标图像。编码操作逐层提取图像特征,解码操作则逐层恢复图像信息。
  • 通过跳跃连接,将编码器结构中的底层信息与解码器结构中的高层信息融合,从而提高了分割精度。

网络结构如图所示:
在这里插入图片描述
代码实现(基于pytorch):
相关包的引入:

from math import sqrt
import torch
from torch import nn
import torch.nn.functional as F

定义卷积块:
定义了两个卷积操作,分别使用大小为3x3的卷积核进行卷积,步长为1,并且对卷积后的输出进行批量归一化(批量归一化的作用),激活函数采用ReLU。使用卷积模块时,需要指明输入通道数(in_channel)和输出通道数(out_channel)。

class Conv_Block(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Conv_Block, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True)
        )

    def forward(self, input):
        outputs = self.conv1(input)
        outputs = self.conv2(outputs)
        return outputs

编码操作(下采样):将卷积模块的输出进行池化处理。

class UnetDown(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(UnetDown, self).__init__()
        self.conv = Conv_Block(in_channel, out_channel)
        self.down = nn.MaxPool2d(2, 2, ceil_mode=True)

    def forward(self, inputs):
        outputs = self.conv(inputs)
        outputs = self.down(outputs)
        return outputs

解码操作(上采样):这里的上采样操作提出了两种——ConvTranspose2d和UpsamplingBilinear2d,两者的区别见这里。另外,由于要进行拼接操作,所以在拼接前对上采样的输出进行填充,避免拼接出错。
(ps:代码里面的解码操作是先进行上采样,然后拼接数据,最后进行卷积的,但是在UnetModel中的最后一个编码操作后,单独进行了一次卷积操作,最后的网络结构还是没有变的。)

class UnetUp(nn.Module):
    def __init__(self, in_channel, out_channel, is_deconv=True):
        super(UnetUp, self).__init__()
        self.conv = Conv_Block(in_channel, out_channel)

        if is_deconv:
            self.up = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2)
        else:
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, inputs1, inputs2):
        outputs2 = self.up(inputs2)
        offset1 = (outputs2.size()[2] - inputs1.size()[2])
        offset2 = (outputs2.size()[3] - inputs1.size()[1])
        # pad传入四个元素时,指的是左填充,右填充,上填充,下填充;前两个元素作用在第一四维,后两个元素作用在第三维
        padding = [offset2 // 2, (offset2 + 1) // 2, offset1 // 2, (offset1 + 1) // 2]
        # Skip and concatenate
        outputs1 = F.pad(inputs1, padding)
        return self.conv(torch.cat([outputs1, outputs2], 1))

最后定义整个UNet模块:将代码和网络结构的图结合起来看就很容易理解了。

class UnetModel(nn.Module):
    def __init__(self, n_classes, in_channels, is_deconv):
        super(UnetModel, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.n_classes = n_classes

        filters = [64, 128, 256, 512, 1024]

        self.down1 = UnetDown(self.in_channels, filters[0])
        self.down2 = UnetDown(filters[0], filters[1])
        self.down3 = UnetDown(filters[1], filters[2])
        self.down4 = UnetDown(filters[2], filters[3])
        self.center = Conv_Block(filters[3], filters[4])
        self.up4 = UnetUp(filters[4], filters[3], self.is_deconv)
        self.up3 = UnetUp(filters[3], filters[2], self.is_deconv)
        self.up2 = UnetUp(filters[2], filters[1], self.is_deconv)
        self.up1 = UnetUp(filters[1], filters[0], self.is_deconv)
        self.final = nn.Conv2d(filters[0], self.n_classes, 1)

    def forward(self, inputs, label_dsp_dim):
        down1 = self.down1(inputs)
        down2 = self.down2(down1)
        down3 = self.down3(down2)
        down4 = self.down4(down3)
        center = self.center(down4)
        up4 = self.up1(down4, center)
        up3 = self.up2(down3, up4)
        up2 = self.up3(down2, up3)
        up1 = self.up4(down1, up2)
        up1 = up1[:, :, 1:1 + label_dsp_dim[0], 1:1 + label_dsp_dim[1]].contiguous()
        return self.final(up1)

    # Initialization of parameters
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.ConvTranspose2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()

总结:UNet 是一种经典的图像分割网络,它通过编码器-解码器结构、跳跃连接和多尺度特征融合等设计,能够在图像分割任务中取得优秀的性能。基于UNet还衍生出了很多网络,例如 U-Net++, ResUNet, Dense U-Net等,接下来就学习它的衍生网络吧,学习大佬是怎么魔改网络的~另外,刚开始写深度学习的代码时,我不知道从何下手,通过学习大佬实现代码的过程,我发现结合两点就能轻松实现代码:1)写代码时结合网络结构的图片,2)百度相关操作的函数。


http://lihuaxi.xjx100.cn/news/1319495.html

相关文章

VUE项目打包成apk

在我们的开发需求中,可能会遇到需要将vue项目中的H5代码打包成一个安卓的app,那么我为大家介绍一套保姆级的解决方案,看完你就会。 VUE HBuilder 1.准备工作: 需要下载一个HBuilder X编辑器,不过我相信大家身为前端…

代码随想录算法训练营第59天 | 503.下一个更大元素 II + 42.接雨水

今日任务 目录 503.下一个更大元素 II - Medium 42.接雨水 - Hard 503.下一个更大元素 II - Medium 题目链接:力扣-503. 下一个更大元素 II 给定一个循环数组 nums ( nums[nums.length - 1] 的下一个元素是 nums[0] ),返回 nu…

【C语言基础】函数(2)

在函数(1)中我们已经讲过了函数的定义,形参与实参,函数的调用,局部变量与栈内存 接下来还有几个要强调的函数相关知识。 一、静态函数 静态函数是在函数声明前加上关键字 static 的函数。静态函数在C语言中具有以…

MySQL的高可用性方案有哪些?MySQL的字段类型如何选择和优化?MySQL的并发控制机制是怎样的?MySQL的全文搜索如何实现?

1、MySQL的高可用性方案有哪些? MySQL的高可用性方案有以下几种: 主从复制(Master-Slave Replication):这是MySQL最常用的高可用性方案之一。在主从复制中,一个主数据库(Master)接收…

2023-07-07 LeetCode每日一题(过桥的时间)

2023-07-07每日一题 一、题目编号 2532. 过桥的时间二、题目链接 点击跳转到题目位置 三、题目描述 共有 k 位工人计划将 n 个箱子从旧仓库移动到新仓库。给你两个整数 n 和 k,以及一个二维整数数组 time ,数组的大小为 k x 4 ,其中 tim…

Vue组件库Element-常见组件-分页

常见组件-Pagination 分页 Pagination 分页&#xff1a;当数据过多时&#xff0c;会使用分页分解数据 具体关键代码如下&#xff1a;&#xff08;重视注释&#xff09; <template><div><!-- Pagination 分页 --><el-pagination background layout"…

如何使用ChatGPT的API(四)思维链推理

在回答一个具体问题之前&#xff0c;模型对问题进行详细的推理是很重要的。有时&#xff0c;模型可能会因为急于得出结论而犯推理错误&#xff0c;所以我们可以仔细设计prompt&#xff0c;要求在模型提供最终答案之前进行一系列相关的推理步骤&#xff0c;这样它就可以更长时间…

“量贩零食”热潮袭来:真风口还是假繁荣?

以前只听过量贩式KTV&#xff0c;现在“量贩零食店”也出现在了大街小巷。 高考结束后&#xff0c;家住武汉的花花频繁逛起了量贩零食店。这类店把各种零食集合在一起销售&#xff0c;用低价来换取高销量&#xff0c;主打一个性价比。店里的散装零食即便按斤售卖&#xff0c;也…