Code
def check_GRL():
    '''
    Check that the gradient reversal layer (GRL) behaves as expected.
    :return:
    '''
    from torch.autograd import Function
    from typing import Any, Optional, Tuple
    import torch.nn as nn
    import torch
    import torch.nn.functional as F

    ## Define the gradient reversal function
    class GradientReverseFunction(Function):
        """
        Custom gradient computation: identity in the forward pass,
        negated (and scaled by coeff) in the backward pass.
        """
        @staticmethod
        def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor:
            ctx.coeff = coeff
            output = input * 1.0
            return output

        @staticmethod
        def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]:
            return grad_output.neg() * ctx.coeff, None

    class GRL_Layer(nn.Module):
        def __init__(self):
            super(GRL_Layer, self).__init__()

        def forward(self, *input):
            return GradientReverseFunction.apply(*input)

    class NormalClassifier(nn.Module):
        def __init__(self, num_features, num_classes):
            super().__init__()
            self.linear = nn.Linear(num_features, num_classes)
            self.grl = GRL_Layer()

        def forward(self, x):
            return self.linear(x)

        def grl_forward(self, x):
            x = self.linear(x)
            x = self.grl(x)
            return x

    net1 = NormalClassifier(3, 6)
    net2 = NormalClassifier(6, 10)
    net3 = NormalClassifier(10, 2)
    data = torch.rand((4, 3))
    label = torch.ones((4,), dtype=torch.long)

    out = net3(net2(net1(data)))
    loss = F.cross_entropy(out, label)
    loss.backward()
    print("First pass: without the GRL layer")
    print('net1.linear.weight.grad', net1.linear.weight.grad)
    print('net2.linear.weight.grad', net2.linear.weight.grad)
    print('net3.linear.weight.grad', net3.linear.weight.grad)

    print("Second pass: with the GRL layer")
    net1.zero_grad()
    net2.zero_grad()
    net3.zero_grad()
    out = net3(net2(net1.grl_forward(data)))  ## here net1 goes through linear first, then the GRL
    ## Forward:  net1 ---> GRL ---> net2 ---> net3
    ## Backward: net3 ---> net2 ---> GRL ---> net1
    loss = F.cross_entropy(out, label)
    loss.backward()
    print('net1.linear.weight.grad', net1.linear.weight.grad)
    print('net2.linear.weight.grad', net2.linear.weight.grad)
    print('net3.linear.weight.grad', net3.linear.weight.grad)


check_GRL()
Output
First pass: without the GRL layer
net1.linear.weight.grad tensor([[ 0.0820, 0.0691, 0.0513],
[-0.1229, -0.1036, -0.0769],
[ 0.1044, 0.0880, 0.0653],
[-0.0945, -0.0797, -0.0592],
[-0.0188, -0.0158, -0.0117],
[-0.0675, -0.0569, -0.0423]])
net2.linear.weight.grad tensor([[ 0.0957, -0.0512, -0.1555, 0.1024, -0.0103, -0.0861],
[-0.0637, 0.0341, 0.1035, -0.0682, 0.0069, 0.0573],
[-0.0634, 0.0339, 0.1030, -0.0678, 0.0068, 0.0570],
[ 0.1854, -0.0991, -0.3011, 0.1984, -0.0199, -0.1668],
[ 0.0259, -0.0138, -0.0421, 0.0277, -0.0028, -0.0233],
[ 0.0182, -0.0097, -0.0295, 0.0194, -0.0020, -0.0163],
[ 0.1247, -0.0666, -0.2024, 0.1334, -0.0134, -0.1122],
[-0.0786, 0.0420, 0.1276, -0.0841, 0.0085, 0.0707],
[-0.0368, 0.0197, 0.0598, -0.0394, 0.0040, 0.0331],
[-0.0383, 0.0204, 0.0621, -0.0409, 0.0041, 0.0344]])
net3.linear.weight.grad tensor([[-0.1790, 0.1751, -0.3633, -0.2654, -0.1151, 0.2412, 0.1301, -0.3185,
-0.1266, -0.1912],
[ 0.1790, -0.1751, 0.3633, 0.2654, 0.1151, -0.2412, -0.1301, 0.3185,
0.1266, 0.1912]])
Second pass: with the GRL layer
net1.linear.weight.grad tensor([[-0.0820, -0.0691, -0.0513],
[ 0.1229, 0.1036, 0.0769],
[-0.1044, -0.0880, -0.0653],
[ 0.0945, 0.0797, 0.0592],
[ 0.0188, 0.0158, 0.0117],
[ 0.0675, 0.0569, 0.0423]])
net2.linear.weight.grad tensor([[ 0.0957, -0.0512, -0.1555, 0.1024, -0.0103, -0.0861],
[-0.0637, 0.0341, 0.1035, -0.0682, 0.0069, 0.0573],
[-0.0634, 0.0339, 0.1030, -0.0678, 0.0068, 0.0570],
[ 0.1854, -0.0991, -0.3011, 0.1984, -0.0199, -0.1668],
[ 0.0259, -0.0138, -0.0421, 0.0277, -0.0028, -0.0233],
[ 0.0182, -0.0097, -0.0295, 0.0194, -0.0020, -0.0163],
[ 0.1247, -0.0666, -0.2024, 0.1334, -0.0134, -0.1122],
[-0.0786, 0.0420, 0.1276, -0.0841, 0.0085, 0.0707],
[-0.0368, 0.0197, 0.0598, -0.0394, 0.0040, 0.0331],
[-0.0383, 0.0204, 0.0621, -0.0409, 0.0041, 0.0344]])
net3.linear.weight.grad tensor([[-0.1790, 0.1751, -0.3633, -0.2654, -0.1151, 0.2412, 0.1301, -0.3185,
-0.1266, -0.1912],
[ 0.1790, -0.1751, 0.3633, 0.2654, 0.1151, -0.2412, -0.1301, 0.3185,
0.1266, 0.1912]])
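Comparing the two runs: the gradients of net2 and net3 are identical, while every entry of net1.linear.weight.grad has its sign flipped. That is exactly what the GRL is supposed to do: it acts as the identity in the forward pass and negates the gradient in the backward pass, so only the layers upstream of it in the backward direction (here net1) receive reversed gradients.

The typical application is domain-adversarial training (DANN-style), where a domain classifier sits behind a GRL so that the feature extractor is pushed to produce domain-invariant features. Below is a minimal sketch of that usage, reusing the GradientReverseFunction / GRL_Layer classes from above (repeated so the snippet is self-contained). The module names, shapes, and data here are made up for illustration and are not part of the original article.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from typing import Any, Optional, Tuple


class GradientReverseFunction(Function):
    """Identity in the forward pass; gradient negated and scaled by coeff in the backward pass."""
    @staticmethod
    def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor:
        ctx.coeff = coeff
        return input * 1.0

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        return grad_output.neg() * ctx.coeff, None


class GRL_Layer(nn.Module):
    def forward(self, *input):
        return GradientReverseFunction.apply(*input)


# Hypothetical modules for illustration: feature extractor, label classifier, domain classifier.
feature_extractor = nn.Sequential(nn.Linear(3, 16), nn.ReLU())
label_classifier = nn.Linear(16, 2)
domain_classifier = nn.Sequential(GRL_Layer(), nn.Linear(16, 2))  # GRL sits in front of the domain head

x = torch.rand(4, 3)                    # a mini-batch of samples
y = torch.ones(4, dtype=torch.long)     # class labels
d = torch.zeros(4, dtype=torch.long)    # domain labels (0 = source domain)

feats = feature_extractor(x)
cls_loss = F.cross_entropy(label_classifier(feats), y)
dom_loss = F.cross_entropy(domain_classifier(feats), d)

# A single backward pass: the GRL flips the domain-loss gradient before it reaches
# the feature extractor, so the features are trained to fool the domain classifier
# while the domain classifier itself is still trained to predict the domain.
(cls_loss + dom_loss).backward()

Because both losses are summed into one backward pass, the domain classifier's own weights still receive ordinary gradients, while the feature extractor receives the negated domain gradient on top of the normal classification gradient.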