onnx模型修改:将均值和方差放到模型中

news/2024/7/5 3:41:46

训练模型时,一般都会对原始数据进行归一化再送入网络,即减均值和除方差。在部署时,我们也要进行同样的操作。有些推理框架会提供对应的接口,我们只需要设置均值和方差即可,如MNN.也有一些框架不提供这样的功能,如Tensorrt,这时,我们就需要自己去逐像素进行这个操作,不仅繁琐,还可能比较耗时。还有一种方式是将这个操作放到模型中,一个方法是在我们的原始pytorch模型中增加一个固定参数的Batchnorm层,另一种方式就是本文要讲的在导出的onnx模型中插入Sub和Div节点来完成。

插入节点

主要步骤:
1.创建2个常量节点,分别是均值和方差向量
2.分别插入一个Sub和Div节点,Sub节点输入是模型的输入和均值节点,Div节点输入是Sub节点输出和方差向量
3.将输入层后的第一层的输入修改为Div节点的输出。
代码如下:

import onnx
from onnx import numpy_helper
import numpy as np

# 加载ONNX模型
model_path = "xx.onnx"
model = onnx.load(model_path)
# 创建均值和方差张量
mean_value = [0.485*255, 0.456*255, 0.406*255]
variance_value = [0.229*255, 0.224*255, 0.225*255]

mean_tensor = numpy_helper.from_array(np.array(mean_value, dtype=np.float32).reshape(1,3,1,1), "mean")
variance_tensor = numpy_helper.from_array(np.array(variance_value, dtype=np.float32).reshape(1,3,1,1), "variance")

# 插入均值和方差节点
mean_node = onnx.helper.make_node("Constant", [], ["mean"], value=mean_tensor)
variance_node = onnx.helper.make_node("Constant", [], ["variance"], value=variance_tensor)

model.graph.node.insert(0, mean_node)
model.graph.node.insert(1, variance_node)

# 插入归一化节点
input_name = model.graph.input[0].name
normalize_node = onnx.helper.make_node("Sub", [input_name, "mean"], ["sub_output"])
scale_node = onnx.helper.make_node("Div", ["sub_output", "variance"], ["input_norm"])

# 插入节点到模型中
model.graph.node.insert(2, normalize_node)
model.graph.node.insert(3, scale_node)

# 更新模型
model.graph.node[4].input[0] = "input_norm"

# 保存修改后的ONNX模型
modified_model_path = "xx_norm.onnx"
# shape inference
model = onnx.shape_inference.infer_shapes(model)
# check model
onnx.checker.check_model(model)
onnx.save(model, modified_model_path)

在这里插入图片描述

验证模型结果

比对修改前和修改后的模型输出

import onnxruntime as ort
import numpy as np

# 加载原始模型
origin_ort = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
origin_input_name = origin_ort.get_inputs()[0].name
origin_output_name = origin_ort.get_outputs()[0].name
input = np.random.randn(1, 3, 128, 128).astype(np.float32)
input_norm = (input - np.array(mean_value,np.float32).reshape(1, 3, 1, 1)) / np.array(variance_value,np.float32).reshape(1, 3, 1, 1)
origin_output = origin_ort.run([origin_output_name], {origin_input_name: input_norm})[0]

# 加载修改后的模型
modified_ort = ort.InferenceSession(modified_model_path, providers=["CPUExecutionProvider"])
modified_input_name = modified_ort.get_inputs()[0].name
modified_output_name = modified_ort.get_outputs()[0].name
modified_output = modified_ort.run([modified_output_name], {modified_input_name: input})[0]

# 比较两个模型的输出
print(np.allclose(origin_output, modified_output, atol=1e-3))

输出为True证明添加节点后的模型正确,后续使用使,不再需要在外部进行均值和方差操作。


http://lihuaxi.xjx100.cn/news/1504462.html

相关文章

Matlab怎么引入外部的latex包?Matlab怎么使用特殊字符?

Matlab怎么引入外部的latex包?Matlab怎么使用特殊字符? Matlab怎么使用特殊字符?一种是使用latex方式,Matlab支持基本的Latex字符【这里】,但一些字符需要依赖外部的包,例如“𝔼”,需…

MySQL8.0.28数据库在windows版本下运行宕机问题解决

问题描述 MySQL8.0.28数据库在windows2022服务器上运行,使用Navicat执行数据备份或者是在并发量较高的情况下会自行宕机,查看日志后发现报错问题如下: mysqld got exception 0x16 ; Most likely, you have hit a bug, but this error can a…

Pipeline Stages

Use of Pipeline Stages in the Compactor Pipeline stages有时能够通过提高扫描移位频率来提高数据通过compactor中逻辑的整体速率。pipeline stages是通过logic level保持中间值输出的寄存器,所以进入logic level中的值可能在一个时钟周期中更早地更新。因为EDT逻辑逻辑级数…

C++进阶之多态

多态 多态的概念多态的定义及实现1.多态的构成条件2.虚函数3.虚函数的重写4.虚函数重写的两个例外5.C11 override 和 final6.重载、覆盖(重写)、隐藏(重定义)的对比 抽象类1.概念2.接口继承和实现继承 多态的原理1.虚函数表2.多态的原理3.动态绑定与静态绑定 单继承和多继承关系…

《来往拜访》的隐私政策

本应用尊重并保护所有使用服务用户的个人隐私权。为了给您提供更准确、更有个性化的服务,本应用会按照本隐私权政策的规定使用和披露您的个人信息。但本应用将以高度的勤勉、审慎义务对待这些信息。除本隐私权政策另有规定外,在未征得您事先许可的情况下…

权限管理 ACL、RBAC、ABAC的学习

ACL(Access Control List:访问控制列表) 最简单的一种方式,将权限直接与用户或用户组相关联,管理员直接给用户授予某些权限即可。 这种模型适用于小型和简单系统,权限一块较为简单,并且角色和权限的变化较少。 RBAC(R…

【C++】C++11新特性 lambda表达式

C11新特性 lambda表达式1、引入2、lambda表达式语法3、 捕获列表说明4、 lambda表达式的原理5、 lambda对象的大小 lambda表达式 1、引入 在C98中,如果想要对一个数据集合中的元素进行排序,可以使用std::sort方法,如果待排序元素为自定义类…

JVM的故事——垃圾收集器

垃圾收集器 文章目录 垃圾收集器一、serial收集器二、parnew收集器三、parallel scavenge收集器四、serial old收集器五、parallel old收集器六、CMS收集器七、Garbage First收集器八、收集器的权衡 一、serial收集器 新生代收集器,最基础的收集器,单线…