Notes on register_hook and register_forward_hook

Preface

Model pruning, model quantization, and analysis of intermediate layers all need a hook to capture intermediate results. This note records two PyTorch mechanisms: register_hook and register_forward_hook.



register_hook

register_hook is registered on a Tensor and is used to obtain the gradient computed for that tensor during the backward pass.

import torch

def get_grad(grad):
    # Called during backward() with the gradient of y w.r.t. z
    print("backward grad is:", grad)

x = 2
w = torch.randn((2, 1), requires_grad=True)
print("init w is:", w)
z = w * x + 1
# z is a non-leaf tensor, so z.grad is not retained by default;
# a hook is the way to observe its gradient.
y = torch.mean(torch.pow(z, 2))
z.register_hook(get_grad)
y.backward()
print("output:", y)
lr = 0.01
update_w = w.data - lr * w.grad.data
print("update w:", update_w)

# Chain rule:
# dy/dw = dy/dz * dz/dw
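
Besides printing the gradient, a hook passed to register_hook can also return a new tensor, which then replaces the original gradient for the rest of the backward pass. Below is a minimal sketch of that behavior; the clipping threshold and the names clip_grad are my own illustration, not from the original example.

import torch

def clip_grad(grad):
    # Returning a tensor replaces the gradient flowing back from z
    return grad.clamp(min=-0.1, max=0.1)

w = torch.randn((2, 1), requires_grad=True)
z = w * 2 + 1
y = torch.mean(torch.pow(z, 2))
z.register_hook(clip_grad)
y.backward()
print("w.grad after clipping z's gradient:", w.grad)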



register_forward_hook

register_forward_hook is registered on a Module and is used to capture intermediate results (the module's inputs and outputs) during the forward pass.

import torch
import torch.nn as nn

class Mynet(nn.Module):
    def __init__(self):
        super(Mynet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=2, stride=2, padding=0)

    def forward(self, x):
        x = self.conv1(x)
        y = self.conv2(x)
        return y

model = Mynet()
x = torch.randn((2, 3, 102, 102), requires_grad=True)

def forward_hook(module, input, output):
    # Runs after the module's forward(); input is a tuple of the module's inputs
    print("conv input:", input[0].shape)
    print("conv out:", output.shape)

def forward_pre_hook(module, input):
    # Runs before the module's forward(); only the input is available here
    for i in input:
        i.cur_mod = module   # attach the module to the input tensor for inspection
        print(i.cur_mod)

# handle = model.conv2.register_forward_hook(forward_hook)
handle = model.conv2.register_forward_pre_hook(forward_pre_hook)
y = model(x)

# Remove the hook once it is no longer needed
handle.remove()
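
A common pattern built on register_forward_hook, e.g. for the intermediate-layer analysis mentioned in the preface, is to store each layer's output in a dictionary keyed by module name. The sketch below uses a small nn.Sequential model; the features dict and the make_hook factory are my own naming for illustration, not part of the original post.

import torch
import torch.nn as nn

features = {}

def make_hook(name):
    # Factory so each hook knows which module it belongs to
    def hook(module, input, output):
        # detach so the stored activation does not keep the autograd graph alive
        features[name] = output.detach()
    return hook

model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=2, stride=2),
)

handles = [m.register_forward_hook(make_hook(name))
           for name, m in model.named_modules()
           if isinstance(m, nn.Conv2d)]

_ = model(torch.randn(2, 3, 102, 102))
for name, feat in features.items():
    print(name, feat.shape)

# Clean up all hooks afterwards
for h in handles:
    h.remove()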



Copyright notice: this is an original article by weixin_41803339, licensed under CC 4.0 BY-SA. Please include a link to the original source and this notice when reposting.