CV03-双线性差值pytorch实现

news/2024/7/3 0:59:37

一、双线性差值

1.1 公式

在理解双线性差值(Bilinear Interpolation)的含义基础上,参考pytorch差值的官方实现注释,自己实现了一遍。

差值就是利用已知点来估计未知点的值。一维上,可以用两点求出斜率,再根据位置关系来求插入点的值。

同理,在二维平面上也可以用类似的办法来估计插入点的值。如图,已知四点Q_{00}Q_{01}Q_{10}Q_{11}四点的值与坐标值(h_{0},w_{0})(h_{0},w_{1})(h_{1},w_{0})(h_{1},w_{1}),求位于(h,w)的点P的值。思路是

  1. 先用w方向一维的线性差值,根据Q_{00}Q_{01}求出点R_{0},根据Q_{10}Q_{11}求出点R_{1}
  2. 再用h方向一维线性差值,根据R_{0}R_{1}求出点P

那么就有如下公式

\begin{aligned} R_{0} &= \frac{w_{1}-w}{w_{1}-w_{0}}Q_{00}+\frac{w-w_{0}}{w_{1}-w_{0}}Q_{01}\\ R_{1} &= \frac{w_{1}-w}{w_{1}-w_{0}}Q_{10}+\frac{w-w_{0}}{w_{1}-w_{0}}Q_{11}\\ P &= \frac{h_{1}-h}{h_{1}-h_{0}}R_{0}+\frac{h-h_{0}}{h_{1}-h_{0}}R_{1}\\ &= \frac{h_{1}-h}{(h_{1}-h_{0})(w_{1}-w_{0})}((w_{1}-w)Q_{00}+(w-w_{0})Q_{01}) + \frac{h-h_{0}}{(h_{1}-h_{0})(w_{1}-w_{0})}((w_{1}-w_{0})Q_{10}+(w-w_{0})Q_{11}) \end{aligned}

具体到图像的双线性差值问题,我们可以理解成将图片进行了放大,但不使图像变成大块的斑点状,而是增大了图像的分辨率,多出来的像素就是双线性差值的结果。图像上(h,w)周边4点一定是临近的,也就是说

\begin{aligned} &h_{0}=\left \lfloor h \right \rfloor,\quad &h_{1}=h_{0}+1 ,\quad &h_{1}-h_{0}=1\\ &w_{0}=\left \lfloor w \right \rfloor,\quad &w_{1}=w_{0}+1 ,\quad &w_{1}-w_{0}=1 \end{aligned}

上面的公式简化为

\begin{aligned} P= &(h_{1}-h)(w_{1}-w)Q_{00} + (h_{1}-h)(w-w_{0})Q_{01} + \\ &(h-h_{0})(w_{1}-w_{0})Q_{10} + (h-h_{0})(w-w_{0})Q_{11} \end{aligned}

这样我们就面临将目标图像的坐标(hd,wd)映射到原图像上求出(h,w)的问题。

1.2 坐标变换

对于第一个问题,目标图像的坐标(hd,wd)映射到原图像上求出(h,w),有两种思路。

第一种是把像素点看成是1×1大小的方块,像素点位于方块的中心,坐标转换时,HW方向的坐标都要加0.5才能对应起来。pytorch里面叫做torch.nn.functional.interpolate(align_corners=False)。

举例,如图原图像是一个3×3的图像,放大到5×5,每个像素点都是位于方形内的黑色小点。设h_{src},w_{src}是原图像的大小,本例是3×3,h_{dst},w_{dst}是目标图像的大小,本例是5×5。换算公式为

\begin{aligned} \frac{h+0.5}{h_{src}}=\frac{hd+0.5}{h_{dst}} \quad&\Rightarrow\quad h=\frac{h_{src}}{h_{dst}}(hd+0.5)-0.5\\ \frac{w+0.5}{w_{src}}=\frac{wd+0.5}{w_{dst}} \quad&\Rightarrow\quad w=\frac{w_{src}}{w_{dst}}(wd+0.5)-0.5\\ \end{aligned}

第二种是上下左右相邻的像素点之间连线,像素点都位于交点上,坐标转换时,HW方向的总长度都要减少1才能对应起来g。pytorch里面叫做torch.nn.functional.interpolate(align_corners=True)。

举例,一个3×3的图像放大到5×5,每个像素点都是位于交点的黑色小点。设h_{src},w_{src}是原图像的大小,本例是3×3,h_{dst},w_{dst}是目标图像的大小,本例是5×5。换算时,我们取边的长度,也就是HW方向各减1,也就是从2×2变成4×4。这样就有个结论就是变换以后目标图像四个顶点的像素值一定和原图像四个顶点像素值一样。换算公式为

\begin{aligned} \frac{h}{h_{src}-1}=\frac{hd}{h_{dst}-1} \quad&\Rightarrow\quad h=\frac{h_{src}-1}{h_{dst}-1}hd\\ \frac{w}{w_{src}-1}=\frac{wd}{w_{dst}-1} \quad&\Rightarrow\quad w=\frac{w_{src}-1}{w_{dst}-1}wd\\ \end{aligned}

 

二、for循环实现双线性差值(naive实现)

是对一张图像的,维度HWC,采用for循环遍历H、W计算差值点的像素值。这个实现too young,too simple,简直naive,效率低但易于理解;这里只实现了第一种坐标变换。

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import osdef bilinear_interpolation_naive(src, dst_size):"""双线性差值的naive实现:param src: 源图像:param dst_size: 目标图像大小H*W:return: 双线性差值后的图像"""(src_h, src_w, src_c) = src.shape  # 原图像大小 H*W*C(dst_h, dst_w), dst_c = dst_size, src_c  # 目标图像大小H*W*Cif src_h == dst_h and src_w == dst_w:  # 如果大小不变,直接返回copyreturn src.copy()scale_h = float(src_h) / dst_h  # 计算H方向缩放比scale_w = float(src_w) / dst_w  # 计算W方向缩放比dst = np.zeros((dst_h, dst_w, dst_c), dtype=src.dtype)  # 目标图像初始化for h_d, row in enumerate(dst):  # 遍历目标图像H方向for w_d, col in enumerate(row):  # 遍历目标图像所有W方向h = scale_h * (h_d + 0.5) - 0.5  # 将目标图像H坐标映射到源图像上w = scale_w * (w_d + 0.5) - 0.5  # 将目标图像W坐标映射到源图像上h0 = int(np.floor(h))  # 最近4个点坐标h0w0 = int(np.floor(w))  # 最近4个点坐标w0h1 = min(h0 + 1, src_h - 1)  # h0 + 1就是h1,但是不能越界w1 = min(w0 + 1, src_w - 1)  # w0 + 1就是w1,但是不能越界r0 = (w1 - w) * src[h0, w0, ...] + (w - w0) * src[h0, w1, ...]  # 双线性差值R0r1 = (w1 - w) * src[h1, w0, ...] + (w - w0) * src[h1, w1, ...]  # 双线性插值R1p = (h1 - h) * r0 + (h - h0) * r1  # 双线性插值Pdst[h_d, w_d, ...] = p.astype(np.uint8) # 插值结果放进目标像素点return dstif __name__ == '__main__':def unit_test():image_file = os.path.join(os.getcwd(), 'test.jpg')image = mpimg.imread(image_file)image_scale = bilinear_interpolation_naive(image, (256, 256))fig, axes = plt.subplots(1, 2, figsize=(8, 10))axes = axes.flatten()axes[0].imshow(image)axes[1].imshow(image_scale)axes[0].axis([0, image.shape[1], image.shape[0], 0])axes[1].axis([0, image_scale.shape[1], image_scale.shape[0], 0])fig.tight_layout()plt.show()passunit_test()

三、用numpy矩阵实现

是对一张图像的,维度HWC;采用numpy矩阵实现,速度快;

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import os
import torchdef bilinear_interpolation(src, dst_size, align_corners=False):"""双线性插值高效实现:param src: 源图像H*W*C:param dst_size: 目标图像大小H*W:return: 双线性插值后的图像"""(src_h, src_w, src_c) = src.shape  # 原图像大小 H*W*C(dst_h, dst_w), dst_c = dst_size, src_c  # 目标图像大小H*W*Cif src_h == dst_h and src_w == dst_w:  # 如果大小不变,直接返回copyreturn src.copy()# 矩阵方式实现h_d = np.arange(dst_h)  # 目标图像H方向坐标w_d = np.arange(dst_w)  # 目标图像W方向坐标if align_corners:h = float(src_h - 1) / (dst_h - 1) * h_dw = float(src_w - 1) / (dst_w - 1) * w_delse:h = float(src_h) / dst_h * (h_d + 0.5) - 0.5  # 将目标图像H坐标映射到源图像上w = float(src_w) / dst_w * (w_d + 0.5) - 0.5  # 将目标图像W坐标映射到源图像上h = np.clip(h, 0, src_h - 1)  # 防止越界,最上一行映射后是负数,置为0w = np.clip(w, 0, src_w - 1)  # 防止越界,最左一行映射后是负数,置为0h = np.repeat(h.reshape(dst_h, 1), dst_w, axis=1)  # 同一行映射的h值都相等w = np.repeat(w.reshape(dst_w, 1), dst_h, axis=1).T  # 同一列映射的w值都相等h0 = np.floor(h).astype(np.int)  # 同一行的h0值都相等w0 = np.floor(w).astype(np.int)  # 同一列的w0值都相等h0 = np.clip(h0, 0, src_h - 2)  # 最下一行上不大于src_h - 2,相当于paddingw0 = np.clip(w0, 0, src_w - 2)  # 最右一列左不大于src_w - 2,相当于paddingh1 = np.clip(h0 + 1, 0, src_h - 1)  # 同一行的h1值都相等,防止越界w1 = np.clip(w0 + 1, 0, src_w - 1)  # 同一列的w1值都相等,防止越界q00 = src[h0, w0]  # 取每一个像素对应的q00q01 = src[h0, w1]  # 取每一个像素对应的q01q10 = src[h1, w0]  # 取每一个像素对应的q10q11 = src[h1, w1]  # 取每一个像素对应的q11h = np.repeat(h[..., np.newaxis], dst_c, axis=2)  # 图像有通道C,所有的计算都增加通道Cw = np.repeat(w[..., np.newaxis], dst_c, axis=2)h0 = np.repeat(h0[..., np.newaxis], dst_c, axis=2)w0 = np.repeat(w0[..., np.newaxis], dst_c, axis=2)h1 = np.repeat(h1[..., np.newaxis], dst_c, axis=2)w1 = np.repeat(w1[..., np.newaxis], dst_c, axis=2)r0 = (w1 - w) * q00 + (w - w0) * q01  # 双线性插值的r0r1 = (w1 - w) * q10 + (w - w0) * q11  # 双线性差值的r1q = (h1 - h) * r0 + (h - h0) * r1  # 双线性差值的qdst = q.astype(src.dtype)  # 图像的数据类型return dstif __name__ == "__main__":def unit_test2():image_file = os.path.join(os.getcwd(), 'test.jpg')image = mpimg.imread(image_file)image_scale = bilinear_interpolation(image, (256, 256))fig, axes = plt.subplots(1, 2, figsize=(8, 10))axes = axes.flatten()axes[0].imshow(image)axes[1].imshow(image_scale)axes[0].axis([0, image.shape[1], image.shape[0], 0])axes[1].axis([0, image_scale.shape[1], image_scale.shape[0], 0])fig.tight_layout()plt.show()passunit_test2()def unit_test3():src = np.array([[1, 2], [3, 4]])print(src)src = src.reshape((2, 2, 1))dst_size = (4, 4)dst = bilinear_interpolation(src, dst_size)dst = dst.reshape(dst_size)print(dst)tsrc = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)print(tsrc)tdst = F.interpolate(tsrc,size=(4, 4),mode='bilinear')print(tdst)# unit_test3()

四、用torch张量实现

是对tensor的,维度NCHW;和第二段一样,但是采用了张量,可以批量处理。

import torch
import torch.nn.functional as F
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimgdef bilinear_interpolate(src, dst_size, align_corners=False):"""双线性差值:param src: 原图像张量 NCHW:param dst_size: 目标图像spatial大小(H,W):param align_corners: 换算坐标的不同方式:return: 目标图像张量NCHW"""src_n, src_c, src_h, src_w = src.shapedst_n, dst_c, (dst_h, dst_w) = src_n, src_c, dst_sizeif src_h == dst_h and src_w == dst_w:return src.copy()"""将dst的H和W坐标映射到src的H和W坐标"""hd = torch.arange(0, dst_h)wd = torch.arange(0, dst_w)if align_corners:h = float(src_h - 1) / (dst_h - 1) * hdw = float(src_w - 1) / (dst_w - 1) * wdelse:h = float(src_h) / dst_h * (hd + 0.5) - 0.5w = float(src_w) / dst_w * (wd + 0.5) - 0.5h = torch.clamp(h, 0, src_h - 1)  # 防止越界,0相当于上边界paddingw = torch.clamp(w, 0, src_w - 1)  # 防止越界,0相当于左边界paddingh = h.view(dst_h, 1)  # 1维dst_h个,变2维dst_h*1个w = w.view(1, dst_w)  # 1维dst_w个,变2维1*dst_w个h = h.repeat(1, dst_w)  # H方向重复1次,W方向重复dst_w次w = w.repeat(dst_h, 1)  # H方向重复dsth次,W方向重复1次"""求出四点坐标"""h0 = torch.clamp(torch.floor(h), 0, src_h - 2)  # -2相当于下边界paddingw0 = torch.clamp(torch.floor(w), 0, src_w - 2)  # -2相当于右边界paddingh0 = h0.long()  # torch坐标必须是longw0 = w0.long()  # torch坐标必须是longh1 = h0 + 1w1 = w0 + 1"""求出四点值"""q00 = src[..., h0, w0]q01 = src[..., h0, w1]q10 = src[..., h1, w0]q11 = src[..., h1, w1]"""公式计算"""r0 = (w1 - w) * q00 + (w - w0) * q01  # 双线性插值的r0r1 = (w1 - w) * q10 + (w - w0) * q11  # 双线性差值的r1dst = (h1 - h) * r0 + (h - h0) * r1  # 双线性差值的qreturn dstif __name__ == '__main__':def unit_test4():# src = torch.randint(0, 100, (1, 3, 3, 3))src = torch.arange(1, 1 + 27).view((1, 3, 3, 3))\.type(torch.float32)print(src)dst = bilinear_interpolate(src,dst_size=(4, 4),align_corners=True)print(dst)pt_dst = F.interpolate(src.float(),size=(4, 4),mode='bilinear',align_corners=True)print(pt_dst)if torch.equal(dst, pt_dst):print('success')image_file = os.path.join(os.getcwd(), 'test.jpg')image = mpimg.imread(image_file)image_in = torch.from_numpy(image.transpose(2, 0, 1))image_in = torch.unsqueeze(image_in, 0)image_out = bilinear_interpolate(image_in, (256, 256))image_out = torch.squeeze(image_out, 0).numpy().astype(int)image_out = image_out.transpose(1, 2, 0)fig, axes = plt.subplots(1, 2, figsize=(8, 10))axes = axes.flatten()axes[0].imshow(image)axes[1].imshow(image_out)axes[0].axis([0, image.shape[1], image.shape[0], 0])axes[1].axis([0, image_out.shape[1], image_out.shape[0], 0])fig.tight_layout()plt.show()unit_test4()


http://lihuaxi.xjx100.cn/news/236031.html

相关文章

2018 JVM 生态报告:79% 的 Java 开发者使用 Java 8

百度智能云 云生态狂欢季 热门云产品1折起>>> 2018 JVM 生态调查报告已于近日发布,该报告由 Snyk 和 The Java Magazine(Oracle 的双月刊)联合推出,旨在了解 JDK 的实现、工具、平台和应用方面的前景。基于超过 10200 …

Flutter环境搭建(Windows)

SDK获取 去官方网站下载最新的安装包 ,或者在Github中的Flutter项目去 下载 。 将下载的安装包解压 注意:不要将Flutter安装到高权限路径,例如 C:\Program Files\ 配置环境变量,在Path中添加flutter\bin的全路径(如:D…

numpy和torch数据操作对比

对numpy和torch数据操作进行对比,避免遗忘。 ndarray和tensor import torch import numpy as npnp_data np.arange(6).reshape((2, 3)) torch_data torch.arange(6) # 张量 tensor2array torch_data.numpy()print(\nnumpy array:\n, np_data,\ntorch tensor\n,…

微软宣布 Win10 设备数突破8亿,距离10亿还远吗?

开发四年只会写业务代码,分布式高并发都不会还做程序员? >>> 微软高管 Yusuf Mehdi 昨天在推特发布了一条推文,宣布运行 Windows 10 的设备数已突破 8 亿,比半年前增加了 1 亿。 根据之前的报道,两个月前 W…

【学习——字符串】字符串之一网打尽quq

学弟lyh上午讲课,喜闻乐见的制胡窜 一上午讲惹KMP, manachar, trie树, AC自动机 orz 例题都是洛咕咕上的, 贴一下(督促自己不要咕 AC自动机不会qaq(并且没有学的意向 manachar 没写过 P4555 […

CV02-FCN笔记

目录 一、Convolutionalization 卷积化 二、Upsample 上采样 2.1 Unpool反池化 2.2 Interpolation差值 2.3 Transposed Convolution转置卷积 三、Skip Architecture 3.1 特征融合 3.2 裁剪 FCN原理及实践,记录一些自己认为重要的要点,以免日后遗…

对Python课的看法

学习Python已经有两周的时间了,我是计算机专业的学生,我抱着可以多了解一种语言的想法报了Python的选修课,从第一次听肖老师的课开始,我便感受到一种好久没有感受到的课堂氛围,感觉十分舒服,不再是那种高中…

CV04-UNet笔记

目录 一、UNet模型 二、Encoder & Decoder 2.1 Encoder 2.2 Decoder 2.3 classifier 学习U-Net: Convolutional Networks for Biomedical Image Segmentation,记录一些自己认为重要的要点,以免日后遗忘。 代码:https://github.com/…