torch.autograd.Function的使用

news/2024/7/7 20:01:51

(个人理解仅供参考)

1 什么情况下使用

自己定义的网络结构,没有现成的,就得手写forward和backward

2 怎么使用

2.1 forward

前向传播的表达式

2.2 backward

求导结果

2.3 举例

前向传播表达式:y = w * x + b
假设f()是我们关于y的loss函数,那么z = f(y)即为loss值
现在要求loss对w、x、b的偏导(假设只有一层):
dz/dx = dz/dy * dy/dx = dz/dy * w
dz/dw = dz/dy * dy/dw = dz/dy * x
dz/db = dz/dy * dy/db = dz/dy * 1
好在dz/dy不用我们再求了,它就是 backward 的参数grad_output。那么grad_output是从哪来的呢?其实就是 forward 会 return output 给 backward ,至于 backward 怎么把 output 变为 grad_output 就不用细究了。
所以:
dz/dx = grad_output * w
dz/dw = grad_output * x
dz/db = grad_output * 1

因此,对于y = w * x + b,我们的代码为:

import torch
from torch.autograd import Function

class MultiplyAdd(Function):

    @staticmethod
    def forward(ctx, w, x, b):
        ctx.save_for_backward(w, x)	 # 保存参数
        output = w * x + b
        return output	# 传给backward

    @staticmethod
    def backward(ctx, grad_output):
        w, x = ctx.saved_tensors
        grad_w = grad_output * x
        grad_x = grad_output * w
        grad_b = grad_output * 1
        return grad_w, grad_x, grad_b	# 传给forward


Linear = MultiplyAdd.apply

2.4 模板

"""
# 使用autograd.Function进行扩展的一个模板
class My_Function(Function):
    def forward(self, inputs, parameters):
        self.saved_for_backward = [inputs, parameters]
        # output = [对输入和参数进行的操作,其实就是前向运算的函数表达式]
        return output
 
    def backward(self, grad_output):
        inputs, parameters = self.saved_tensors # 或self.saved_variables
        # grad_input = [求函数forward(input)关于 parameters 的导数,其实就是反向运算的导数表达式] * grad_output
        return grad_input
"""

2.5 验证

验证的话需要使用torch.autograd.gradcheck,给上我的完整代码,验证部分在最后:

import torch
from torch.autograd import Function, gradcheck

"""
# 使用autograd.Function进行扩展的一个模板
class My_Function(Function):
    def forward(self, inputs, parameters):
        self.saved_for_backward = [inputs, parameters]
        # output = [对输入和参数进行的操作,其实就是前向运算的函数表达式]
        return output
 
    def backward(self, grad_output):
        inputs, parameters = self.saved_tensors # 或者是self.saved_variables
        # grad_input = [求函数forward(input)关于 parameters 的导数,其实就是反向运算的导数表达式] * grad_output
        return grad_input
"""


class MultiplyAdd(Function):

    @staticmethod
    def forward(ctx, w, x, b):
        ctx.save_for_backward(w, x)
        output = w * x + b
        return output

    @staticmethod
    def backward(ctx, grad_output):
        w, x = ctx.saved_tensors
        grad_w = grad_output * x
        grad_x = grad_output * w
        grad_b = grad_output * 1
        return grad_w, grad_x, grad_b


Linear = MultiplyAdd.apply

x = torch.ones(1, requires_grad=True, dtype=torch.float64)
w = torch.rand(1, requires_grad=True, dtype=torch.float64)
b = torch.rand(1, requires_grad=True, dtype=torch.float64)

# print("start forward...")
# z = MultiplyAdd.apply(w, x, b)
# print("start backward...")
# z.backward()
#
# print(x.grad, w.grad, b.grad)

test = gradcheck(Linear, (x, w, b), eps=1e-6)
print(test)

3 存疑

现在我只是会用了这个,但是如果是两层的全连接层,这段代码是怎么工作的?这个问题我还没想明白,留个坑


http://lihuaxi.xjx100.cn/news/1027326.html

相关文章

MIPI D-PHYv2.5笔记(16) -- Preamble Sequence、HS-Idle State、Sync Patterns

声明:作者是做嵌入式软件开发的,并非专业的硬件设计人员,笔记内容根据自己的经验和对协议的理解输出,肯定存在有些理解和翻译不到位的地方,有疑问请参考原始规范看 Preamble Sequence 前导码序列(Preamble …

Isolate microTask event Isolate.spawn() compute

我们的flutter应用启动的时候就会开辟一个独立的ioslate,这里面包含了一个独立的内存空间和一个携带 event loops的单一线程和 microTask queue(微任务队列),这个单一线程只处理事件循环。 使用Isolate.spawn()或Flutters compute()函数新建独立的ioslat…

把多列的迭代次数问题化简为单列问题

前已有实验表明,当训练集只有一列的时候,收敛迭代次数与训练集分布的标准差成反比。分布越均匀迭代次数越大。如果可以把多列问题化简为单列问题,比较迭代次数的大小顺序就会变得很简单。 ( A, B )---3*30*2---( 1, 0 )( 0, 1 ) 做一个网络来…

今年又是一年自媒体高峰期你是否抓住了

影视剪辑容易遇到哪些问题: 1、视频格式格式不对,剪辑软件不支持; 2、视频封面不会做; 3、PR导出视频时,没办法做其他事,效率不高; 4、自己配音不好听,配音软件又不好找&#xf…

合创视觉科技平面设计师的职业路线

平面设计师很有必要在一个时期中找到自己的目标,为自己规划一条适合自己的职业规划   确定志向   设定职业生涯目标   制定行动计划与措施   求知欲、设计习惯与坚持   平面设计是一个创意工作,它包含海报、书籍画册、LOGO等设计&am…

React 的源码与原理解读(六):reconcileChildren 与 DIFF 算法

写在专栏开头(叠甲) 作者并不是前端技术专家,也只是一名喜欢学习新东西的前端技术小白,想要学习源码只是为了应付急转直下的前端行情和找工作的需要,这篇专栏是作者学习的过程中自己的思考和体会,也有很多参…

adb环境变量配置

adb环境变量配置Android一. 简介二. 环境变量配置1.JDK安装2.SDK安装3. 资源共享4. 配置环境变量4.1 方式一:4.2 方式二:5. adb常用命令的使用6. 结果Android List of ADB Commands and Fastboot Commands for Android 如果你是一个android用户&#xf…

单片机addr2line的使用说明

1,单片机程序挂死了,无法用jlink调试时,我们一般怎么定位呢,我们一般借助外来工具addr2line工具来调式。 当程序挂死时,我们首先编译时选择c99, 编译后烧写相应的bin文件/csf文件到单片机,烧写后 发现程序挂死&#x…