Pytorch提取参数及自定义初始化

news/2024/7/7 18:53:49

点击上方“视学算法”,选择加"星标"或“置顶

重磅干货,第一时间送达69a3059ffd38cbfe590f9f81a0db3251.png

作者丨李元芳@知乎(已授权)

来源丨https://zhuanlan.zhihu.com/p/52297770

编辑丨极市平台

导读

 

有时候提取出的层结构并不够,还需要对里面的参数进行初始化,那么如何提取出网络的参数并对其初始化呢?本文对其进行简单的介绍。 

首先 nn.Module 里面有两个特别重要的关于参数的属性,分别是 named_parameters()和 parameters()。named_parameters() 是给出网络层的名字和参数的迭代器,parameters()会给出一个网络的全部参数的选代器。

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.nn.init as init
import argparse
import torch.autograd.variable as variableclass SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN,self).__init__()  #b,3,32,32layer1=nn.Sequential()layer1.add_module('conv1',nn.Conv2d(in_channels=3,out_channels=32,kernel_size=3,stride=1,padding=1))#b,32,32,32layer1.add_module('relu1',nn.ReLU(True))layer1.add_module('pool1',nn.MaxPool2d(2,2))#b,32,16,16self.layer1=layer1layer2=nn.Sequential()layer1.add_module('conv2',nn.Conv2d(in_channels=32,out_channels=64,kernel_size=3,stride=1,padding=1))#b,64,16,16layer2.add_module('relu2',nn.ReLU(True))layer2.add_module('pool2',nn.MaxPool2d(2,2))#b,64,8,8self.layer2=layer2layer3=nn.Sequential()layer3.add_module('conv3', nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3 ,stride=1, padding = 1)) #b,128,8,8layer3.add_module('relu3', nn.ReLU(True))layer3.add_module('poo13', nn.MaxPool2d(2, 2))#b,128,4,4self.layer3=layer3layer4 =nn.Sequential()layer4.add_module('fc1',nn.Linear(in_features=2048, out_features=512 ))layer4.add_module('fc_relu1', nn.ReLU(True))layer4.add_module('fc2 ', nn.Linear(in_features=512, out_features=64 ))layer4.add_module('fc_relu2', nn.ReLU(True))layer4.add_module('fc3', nn.Linear(64, 10))self.layer4 = layer4def forward(self,x):conv1=self.layer1(x)conv2=self.layer2(conv1)conv3=self.layer3(conv2)fc_input=conv3.view(conv3.size(0),-1)fc_output=self.layer4(fc_input)return fc_output
model=SimpleCNN()
for param in model.named_parameters():print(param[0])

可以得到每一层参数的名字,输出为

4d294dad3c0d1deb7d886088802ef0dc.png

如何对权重做初始化呢 ? 非常简单,因为权重是一个 Variable ,所以只需要取出其中的 data 属性,然后对它进行所需要的处理就可以了。

for m in model.modules():if isinstance(m,nn.Conv2d):init.normal(m.weight.data) #通过正态分布填充张量init.xavier_normal(m.weight.data) 
#xavier均匀分布的方法来init,来自2010年的论文“Understanding the difficulty of training deep feedforward neural networks”init.kaiming_normal(m.weight.data) 
#来自2015年何凯明的论文“Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification”m.bias.data.fill_(0)elif isinstance(m,nn.Linear):m.weight.data.normal_()

通过上面的操作,对将卷积层中使用 PyTorch 里面提供的方法的权重进行初始化,这样就能够使用任意我们想使用的初始化,甚至我们可以自己定义初始化方法并对权重进行初始化 。

更多初始化方法参考:https://pytorch-cn.readthedocs.io/zh/latest/package_references/nn_init/

cb76beca17310b84baf33905024a05c1.png

outside_default.png

点个在看 paper不断!


http://lihuaxi.xjx100.cn/news/260084.html

相关文章

jdk1.8

最好不要改变安装路径转载于:https://www.cnblogs.com/jcfxl/p/7552080.html

Java 8 开发的 4 大技巧

欢迎关注方志朋的博客,回复”666“获面试宝典Java 开发过程经常需要编写有固定格式的代码,例如说声明一个私有变量,logger或者bean等等。对于这种小范围的代码生成,我们可以利用 IDEA 提供的 Live Templates功能。刚开始觉得它只是…

程序的编译、链接和执行

参考:程序的编译、链接和执行 - 知乎 处理C语言程序: 预处理、编译、汇编、链接、加载 预处理(Preprocessing) 翻译一段 C 语言程序的第一步是预处理。这一步主要处理所有以“#”号开头的行。比如当我们遇到 #include "he…

重磅!深度学习知识总结和调参技巧开放下载了

近年来,人工智能正在进入一个蓬勃发展的新时期,这主要得益于深度学习和CV领域近年来的发展和成就。在这其中,卷积神经网络的成功也带动了更多学术和商业应用的发展和进步。为了避免“内卷”,更多人选择学习进阶,但是仍…

强化学习,路在何方?

↑↑↑关注后"星标"Datawhale每日干货 & 每月组队学习,不错过Datawhale干货 来源:DeepRL实验室,转自:睿慕课▌一、深度强化学习的泡沫2015年,DeepMind的Volodymyr Mnih等研究员在《自然》杂志上发表论文…

ElasticSearch中结构化查询(term、terms、range、exists、match、bool)

term查询 term 主要用于精确匹配哪些值,比如数字,日期,布尔值或 not_analyzed 的字符串(未经分析的文本数据类型): { "term": { "age": 26 }} { "term": { "date": "2014-09-01&q…

Quartz定时任务学习(四)调度器

org.quartz.Scheduler 类层次 作为一个 Quartz 用户,你要与实现了 org.quartz.Scheduler 接口的类交互。在你调用它的任何 API 之前,你需要知道如何创建一个 Scheduler 的实例。取而代之的是用了某个工厂方法来确保了构造出 Sheduler 实例并正确的得到初…

强大的 IDEA 代码生成

欢迎关注方志朋的博客,回复”666“获面试宝典Java 开发过程经常需要编写有固定格式的代码,例如说声明一个私有变量,logger或者bean等等。对于这种小范围的代码生成,我们可以利用 IDEA 提供的 Live Templates功能。刚开始觉得它只是…