Swin Transformer中torch.roll()详解

  • Post author:
  • Post category:其他


torch.roll()这个函数看官方解释很懵,直接对照可视化来理解

参考:

torch.roll 函数的理解


torch.roll(x, shifts=(40, 40), dims=(1, 2))


在这里插入图片描述

这里img的shape是[1,56,56,96],即[B,H,W,C]格式。

dim=1,shift=40指的就是数据沿着H维度,将数据朝正反向滚动40,超出部分循环回到图像中

dim=2,shift=40指的就是数据沿着W维度,将数据朝正反向滚动40,超出部分循环回到图像中

这里的原点是左上角,H的正方向向下,W正方向向右


可视化代码:

import torch    
import numpy as np   
import matplotlib.pyplot as plt

shift_size = 3
'''构造多维张量'''
x=np.arange(301056).reshape(1,56,56,96)
x=torch.from_numpy(x)

if shift_size > 0:
    shifted_x = torch.roll(x, shifts=(40, 40), dims=(1, 2))
    #shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
    print("---------经过循环位移了---------")
else:
    shifted_x = x
   
'''可视化部分'''
plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.imshow(x[0,:,:,0])
plt.title("orgin_img")
plt.subplot(1,2,2)
plt.imshow(shifted_x[0,:,:,0])
if torch.equal(shifted_x, x):
    plt.title("non_shifted")
else:
    plt.title("shifted_img")
plt.show()
plt.pause(5)
plt.close()



版权声明:本文为weixin_44823313原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。