numpy和torch数据操作对比

news/2024/7/5 1:29:15

对numpy和torch数据操作进行对比,避免遗忘。

ndarray和tensor

import torch
import numpy as npnp_data = np.arange(6).reshape((2, 3))
torch_data = torch.arange(6) # 张量
tensor2array = torch_data.numpy()print('\nnumpy array:\n', np_data,'\ntorch tensor\n', torch_data,'\ntensor to array\n', tensor2array
)
"""
numpy array:[[0 1 2][3 4 5]] 
torch tensortensor([0, 1, 2, 3, 4, 5]) 
tensor to array[0 1 2 3 4 5]
"""

numpy和tensor的维度

import torch
import numpy as npnp_data = np.array([[i for i in range(r * 4, (r + 1) * 4)] for r in range(5)], dtype=np.float32)
np_data = np_data.reshape((2, 5, 2))  # 5x4 reshape一下变 2x5x2
print('numpy shape', np_data.shape)  # 数组的维度
print('numpy ndim', np_data.ndim)  # 数组的轴,也就是rank,也就是shape的的大小
print('numpy size', np_data.size)  # 数组元素的个数,也就是shape乘起来
print('numpy dtype', np_data.dtype)  # 数组元素数据类型
print('numpy itemsize', np_data.itemsize)  # 数组有元素数据类型的大小,float32, 32/8=4np_data = np_data[:, :, np.newaxis, :]  # 增加维度
print('numpy add dim, np.newaxis', np_data.shape)np_data = np_data.transpose((2, 0, 1, 3))  # 转置 transpose
print('numpy transpose, np.newaxis', np_data.shape)np_data = np_data.squeeze(0) # 移除0维度,只能移除维数为1的维度
print('numpy remove dim, np.squeeze', np_data.shape)#######################
print()
#######################torch_data = torch.Tensor([[i for i in range(r * 4, (r + 1) * 4)] for r in range(5)]
)
torch_data = torch_data.view(2, 5, 2)  # 方法变成view, 5x4 view一下变 2x5x2
print('torch shape', torch_data.shape)
print('torch ndim', torch_data.ndim)
print('torch size', torch_data.size())  # 要调用函数
print('torch dtype', torch_data.dtype)  # 没有itemsizetorch_data = torch_data.unsqueeze(2)  # 增加维度
print('torch add dim, torch.unsqueeze', torch_data.shape)torch_data = torch_data.permute((2, 0, 1, 3))  # 转置 permute
print('torch permute', torch_data.shape)torch_data = torch_data.squeeze(0)  # 移除0维度,只能移除维数是1的维度
print('torch remove dim, torch.squeeze', torch_data.shape)"""
numpy shape (2, 5, 2)
numpy ndim 3
numpy size 20
numpy dtype float32
numpy itemsize 4
numpy add dim, np.newaxis (2, 5, 1, 2)
numpy transpose, np.newaxis (1, 2, 5, 2)
numpy remove dim, np.squeeze (2, 5, 2)torch shape torch.Size([2, 5, 2])
torch ndim 3
torch size torch.Size([2, 5, 2])
torch dtype torch.float32
torch add dim, torch.unsqueeze torch.Size([2, 5, 1, 2])
torch permute torch.Size([1, 2, 5, 2])
torch remove dim, torch.squeeze torch.Size([2, 5, 2])
"""

计算对比

import torch
import numpy as npnp_data = np.array([[1, 2], [1, 2]])
torch_data = torch.tensor(np_data)np_t = np.array([0, np.pi / 4., np.pi / 2.])
torch_t = torch.tensor(np_t)# 平均值
np_x = np.arange(5)
torch_x = torch.FloatTensor([i for i in range(5)])
# torch.mean 只能计算float的平均值,不能计算int的平均值,所有必须用FloatTensorprint('\nnumpy.abs:\n', np.abs(np_data),'\ntorch.abs:\n', torch.abs(torch_data),'\nnumpy.sin:\n', np.sin(np_t),'\ntorch.sin:\n', torch.sin(torch_t),'\nnumpy.mean\n', np.mean(np_x),'\ntorch.mean\n', torch.mean(torch_x)
)
"""
numpy.abs:[[1 2][1 2]] 
torch.abs:tensor([[1, 2],[1, 2]]) 
numpy.sin:[0.         0.70710678 1.        ] 
torch.sin:tensor([0.0000, 0.7071, 1.0000], dtype=torch.float64) 
numpy.mean2.0 
torch.meantensor(2.)
"""

矩阵的乘法对比

import torch
import numpy as np# numpy矩阵相乘
na = np.array([[1, 2], [3, 4]])
nb = np.array([[1, 1], [0, 1]])
nc = np.matmul(na, nb)
nd = na * nb
print(nc, nd, sep='\n')# torch矩阵相乘
ta = torch.FloatTensor([[1, 2], [3, 4]])
tb = torch.FloatTensor([[1, 1], [0, 1]])
tc = torch.mm(ta, tb)
td = ta * tb
print(tc, td, sep='\n')"""
[[1 3][3 7]]
[[1 2][0 4]]
tensor([[1., 3.],[3., 7.]])
tensor([[1., 2.],[0., 4.]])
"""

tensor求导数

import torch
from torch.autograd import Variable# 变量
tensor = torch.FloatTensor([[1, 2], [3, 4]])
variable = Variable(tensor, requires_grad=True)
print(tensor)
print(variable)
t_mean = torch.mean(tensor * tensor)
v_mean = torch.mean(variable * variable)
print(t_mean)
print(v_mean)# 反向传播求导数
v_mean.backward()
print(variable.grad)print(variable)  # Variable形式
print(variable.data)  # tensor形式
print(variable.data.numpy())  # numpy形式,随后输出结果一般用numpy形式

tensor的池化操作

import torch
import torch.nn as nn
from torch.autograd import Variable# 最大池化与反池化
pool = nn.MaxPool2d(kernel_size=2,stride=2,return_indices=True,ceil_mode=True
)
unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
in_data = Variable(torch.Tensor([[[[1, 2, 3, 4],[5, 6, 7, 8],[9, 10, 11, 12],[13, 14, 15, 16]]]])
)
print('\nin_data\n', in_data)
print(in_data.shape)out_data, indices = pool(in_data)
print('\nout_data\n', out_data)
print('\nindices\n', indices)un_data = unpool(out_data, indices)
print('\nun_data\n', un_data)"""
in_datatensor([[[[ 1.,  2.,  3.,  4.],[ 5.,  6.,  7.,  8.],[ 9., 10., 11., 12.],[13., 14., 15., 16.]]]])
torch.Size([1, 1, 4, 4])out_datatensor([[[[ 6.,  8.],[14., 16.]]]])indicestensor([[[[ 5,  7],[13, 15]]]])un_datatensor([[[[ 0.,  0.,  0.,  0.],[ 0.,  6.,  0.,  8.],[ 0.,  0.,  0.,  0.],[ 0., 14.,  0., 16.]]]])
"""

 


http://lihuaxi.xjx100.cn/news/236028.html

相关文章

微软宣布 Win10 设备数突破8亿,距离10亿还远吗?

开发四年只会写业务代码,分布式高并发都不会还做程序员? >>> 微软高管 Yusuf Mehdi 昨天在推特发布了一条推文,宣布运行 Windows 10 的设备数已突破 8 亿,比半年前增加了 1 亿。 根据之前的报道,两个月前 W…

【学习——字符串】字符串之一网打尽quq

学弟lyh上午讲课,喜闻乐见的制胡窜 一上午讲惹KMP, manachar, trie树, AC自动机 orz 例题都是洛咕咕上的, 贴一下(督促自己不要咕 AC自动机不会qaq(并且没有学的意向 manachar 没写过 P4555 […

CV02-FCN笔记

目录 一、Convolutionalization 卷积化 二、Upsample 上采样 2.1 Unpool反池化 2.2 Interpolation差值 2.3 Transposed Convolution转置卷积 三、Skip Architecture 3.1 特征融合 3.2 裁剪 FCN原理及实践,记录一些自己认为重要的要点,以免日后遗…

对Python课的看法

学习Python已经有两周的时间了,我是计算机专业的学生,我抱着可以多了解一种语言的想法报了Python的选修课,从第一次听肖老师的课开始,我便感受到一种好久没有感受到的课堂氛围,感觉十分舒服,不再是那种高中…

CV04-UNet笔记

目录 一、UNet模型 二、Encoder & Decoder 2.1 Encoder 2.2 Decoder 2.3 classifier 学习U-Net: Convolutional Networks for Biomedical Image Segmentation,记录一些自己认为重要的要点,以免日后遗忘。 代码:https://github.com/…

Git安装配置(Linux)

使用yum安装Git   yum install git -y 编译安装 # 安装依赖关系 yum install curl-devel expat-devel gettext-devel openssl-devel zlib-devel # 编译安装 tar -zxf git-2.0.0.tar.gz cd git-2.0.0 make configure ./configure --prefix/usr make make install 配置Git…

[Python_7] Python Socket 编程

0. 说明 Python Socket 编程 1. TCP 协议 [TCP Server] 通过 netstat -ano 查看端口是否开启 # -*-coding:utf-8-*-"""TCP 协议的 Socket 编程,Server 端Server 端绑定到指定地址,监听特定的端口,接受发来的连接请求 "&q…

CV05-ResNet笔记

目录 一、为什么是ResNet 二、Residual Learning细节 2.1 shortcut计算 2.2 11卷积调整channel维度大小 2.3 ResNet层数 2.4 ResNet里的Basic Block 和 Bottleneck Block 2.5 Global Average Pooling 全局平均池化 2.6 Batch Normalization 学习ResNet,记录…