序言
在模型剪枝、模型量化以及模型中间层分析过程中,都需要一个 hook 来得到中间结果。这里记录 torch 的 register_hook 和 register_forward_hook 两种方法。
register_hook
register_hook是为了获取反向计算的梯度值。
def get_grad(grad):
    """Backward hook: receives the gradient flowing into the hooked tensor and prints it.

    Registered via ``z.register_hook(get_grad)`` below; ``grad`` is dL/dz.
    Returning nothing leaves the gradient unmodified.
    """
    print("backward grad is:", grad)
# Demo: capture the gradient of a NON-leaf tensor with Tensor.register_hook.
# z is an intermediate result, so autograd does not populate z.grad; a hook
# is the supported way to observe its gradient during backward().
x = 2
w = torch.randn((2, 1), requires_grad=True)
print("init w is:", w)

z = w * x + 1
z.register_hook(get_grad)  # fires during backward() with dy/dz

y = (z ** 2).mean()
y.backward()
print("output:", y)

# One manual SGD step on w.
lr = 0.01
update_w = w.data - lr * w.grad.data
print("update w:", update_w)
# Chain rule: dy/dw = dy/dz * dz/dw
register_forward_hook
register_forward_hook是为了获取前向推理的一些中间结果。
class Mynet(nn.Module):
    """Minimal two-layer convolutional net used to demonstrate forward hooks.

    conv1 preserves spatial size (3x3, stride 1, pad 1); conv2 halves it
    (2x2, stride 2, no padding).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=2, stride=2, padding=0)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        return out
# Instantiate the demo network and a random batch of two 3x102x102 inputs;
# both names are referenced later when the hooks are registered and run.
model = Mynet()
x = torch.randn((2, 3, 102, 102), requires_grad=True)
def forward_hook(module, input, output):
    """Forward hook: print the input/output shapes of the hooked module.

    ``input`` is the tuple of positional arguments passed to the module's
    ``forward()``; ``output`` is its return value.  Returning None leaves
    the module's output unchanged.
    """
    # Fixed typo in the label: "conv iup:" -> "conv input:".
    print("conv input:", input[0].shape)
    print("conv out:", output.shape)
def forward_pre_hook(module, input):
    """Forward pre-hook: runs before the module's forward().

    Tags every input tensor with the module about to consume it (torch
    tensors accept ad-hoc Python attributes) and prints the tag.
    """
    for tensor in input:
        tensor.cur_mod = module
        print(tensor.cur_mod)
# Attach the pre-hook to conv2; swap in register_forward_hook(forward_hook)
# instead to inspect input/output shapes after the layer runs.
handle = model.conv2.register_forward_pre_hook(forward_pre_hook)
y = model(x)  # forward pass triggers the hook
handle.remove()  # detach the hook once we are done observing
版权声明:本文为weixin_41803339原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。