torch.roll()这个函数看官方解释很懵,直接对照可视化来理解
参考:
torch.roll 函数的理解
torch.roll(x, shifts=(40, 40), dims=(1, 2))
这里img的shape是[1,56,56,96],即[B,H,W,C]格式。
dim=1,shift=40指的就是数据沿着H维度,将数据朝正反向滚动40,超出部分循环回到图像中
dim=2,shift=40指的就是数据沿着W维度,将数据朝正反向滚动40,超出部分循环回到图像中
这里的原点是左上角,H的正方向向下,W正方向向右
可视化代码:
import torch
import numpy as np
import matplotlib.pyplot as plt
shift_size = 3
'''构造多维张量'''
x=np.arange(301056).reshape(1,56,56,96)
x=torch.from_numpy(x)
if shift_size > 0:
shifted_x = torch.roll(x, shifts=(40, 40), dims=(1, 2))
#shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
print("---------经过循环位移了---------")
else:
shifted_x = x
'''可视化部分'''
plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.imshow(x[0,:,:,0])
plt.title("orgin_img")
plt.subplot(1,2,2)
plt.imshow(shifted_x[0,:,:,0])
if torch.equal(shifted_x, x):
plt.title("non_shifted")
else:
plt.title("shifted_img")
plt.show()
plt.pause(5)
plt.close()
版权声明:本文为weixin_44823313原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。