Implementing a Gradient Reversal Layer (GRL) in PyTorch



This post is a correction of two earlier articles (Article 1, Article 2). The idea under test: a gradient reversal layer (GRL) acts as the identity in the forward pass, and in the backward pass multiplies the incoming gradient by -coeff, so every layer that precedes it in the forward pass receives a reversed gradient.



Code

def check_GRL():
    '''
    Check that the GRL layer behaves as intended.
    '''
    from torch.autograd import Function
    from typing import Any, Optional, Tuple
    import torch.nn as nn
    import torch
    import torch.nn.functional as F
    # Define the custom gradient-reversal function
    class GradientReverseFunction(Function):
        """
        重写自定义的梯度计算方式
        """

        @staticmethod
        def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor:
            ctx.coeff = coeff
            # Identity in the forward pass; multiplying by 1.0 returns a new tensor.
            output = input * 1.0
            return output

        @staticmethod
        def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]:
            # Negate and scale the incoming gradient; coeff itself gets no gradient.
            return grad_output.neg() * ctx.coeff, None

    class GRL_Layer(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, *inputs):
            return GradientReverseFunction.apply(*inputs)

    class NormalClassifier(nn.Module):

        def __init__(self, num_features, num_classes):
            super().__init__()
            self.linear = nn.Linear(num_features, num_classes)
            self.grl = GRL_Layer()

        def forward(self, x):
            return self.linear(x)

        def grl_forward(self, x):
            # Same as forward, but the output additionally passes through the GRL.
            x = self.linear(x)
            x = self.grl(x)
            return x

    net1 = NormalClassifier(3, 6)
    net2 = NormalClassifier(6, 10)
    net3 = NormalClassifier(10, 2)

    data = torch.rand((4, 3))
    label = torch.ones((4), dtype=torch.long)
    out = net3(net2(net1(data)))
    loss = F.cross_entropy(out, label)
    loss.backward()
    print("第一次前向传播,没有GRL层")
    print('net1.linear.weight.grad', net1.linear.weight.grad)
    print('net2.linear.weight.grad', net2.linear.weight.grad)
    print('net3.linear.weight.grad', net3.linear.weight.grad)


    print("第二次前向传播,有GRL层")
    net1.zero_grad()
    net2.zero_grad()
    net3.zero_grad()

    out = net3(net2(net1.grl_forward(data)))  # net1 applies linear first, then the GRL
    # Forward:  net1 ---> GRL ---> net2 ---> net3
    # Backward: net3 ---> net2 ---> GRL ---> net1
    loss = F.cross_entropy(out, label)
    loss.backward()
    print('net1.linear.weight.grad', net1.linear.weight.grad)
    print('net2.linear.weight.grad', net2.linear.weight.grad)
    print('net3.linear.weight.grad', net3.linear.weight.grad)


check_GRL()



Output

First forward pass, without the GRL layer
net1.linear.weight.grad tensor([[ 0.0820,  0.0691,  0.0513],
        [-0.1229, -0.1036, -0.0769],
        [ 0.1044,  0.0880,  0.0653],
        [-0.0945, -0.0797, -0.0592],
        [-0.0188, -0.0158, -0.0117],
        [-0.0675, -0.0569, -0.0423]])
net2.linear.weight.grad tensor([[ 0.0957, -0.0512, -0.1555,  0.1024, -0.0103, -0.0861],
        [-0.0637,  0.0341,  0.1035, -0.0682,  0.0069,  0.0573],
        [-0.0634,  0.0339,  0.1030, -0.0678,  0.0068,  0.0570],
        [ 0.1854, -0.0991, -0.3011,  0.1984, -0.0199, -0.1668],
        [ 0.0259, -0.0138, -0.0421,  0.0277, -0.0028, -0.0233],
        [ 0.0182, -0.0097, -0.0295,  0.0194, -0.0020, -0.0163],
        [ 0.1247, -0.0666, -0.2024,  0.1334, -0.0134, -0.1122],
        [-0.0786,  0.0420,  0.1276, -0.0841,  0.0085,  0.0707],
        [-0.0368,  0.0197,  0.0598, -0.0394,  0.0040,  0.0331],
        [-0.0383,  0.0204,  0.0621, -0.0409,  0.0041,  0.0344]])
net3.linear.weight.grad tensor([[-0.1790,  0.1751, -0.3633, -0.2654, -0.1151,  0.2412,  0.1301, -0.3185,
         -0.1266, -0.1912],
        [ 0.1790, -0.1751,  0.3633,  0.2654,  0.1151, -0.2412, -0.1301,  0.3185,
          0.1266,  0.1912]])
Second forward pass, with the GRL layer
net1.linear.weight.grad tensor([[-0.0820, -0.0691, -0.0513],
        [ 0.1229,  0.1036,  0.0769],
        [-0.1044, -0.0880, -0.0653],
        [ 0.0945,  0.0797,  0.0592],
        [ 0.0188,  0.0158,  0.0117],
        [ 0.0675,  0.0569,  0.0423]])
net2.linear.weight.grad tensor([[ 0.0957, -0.0512, -0.1555,  0.1024, -0.0103, -0.0861],
        [-0.0637,  0.0341,  0.1035, -0.0682,  0.0069,  0.0573],
        [-0.0634,  0.0339,  0.1030, -0.0678,  0.0068,  0.0570],
        [ 0.1854, -0.0991, -0.3011,  0.1984, -0.0199, -0.1668],
        [ 0.0259, -0.0138, -0.0421,  0.0277, -0.0028, -0.0233],
        [ 0.0182, -0.0097, -0.0295,  0.0194, -0.0020, -0.0163],
        [ 0.1247, -0.0666, -0.2024,  0.1334, -0.0134, -0.1122],
        [-0.0786,  0.0420,  0.1276, -0.0841,  0.0085,  0.0707],
        [-0.0368,  0.0197,  0.0598, -0.0394,  0.0040,  0.0331],
        [-0.0383,  0.0204,  0.0621, -0.0409,  0.0041,  0.0344]])
net3.linear.weight.grad tensor([[-0.1790,  0.1751, -0.3633, -0.2654, -0.1151,  0.2412,  0.1301, -0.3185,
         -0.1266, -0.1912],
        [ 0.1790, -0.1751,  0.3633,  0.2654,  0.1151, -0.2412, -0.1301,  0.3185,
          0.1266,  0.1912]])
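
As expected, net1's gradients in the second run are the exact negation of the first run, while net2 and net3, which sit after the GRL in the forward pass, receive identical gradients. Below is a minimal sketch to assert this programmatically; it assumes NormalClassifier (with its GRL_Layer) from the code above has been moved to module scope:

# Sketch (assumes NormalClassifier / GRL_Layer are defined at module scope).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
net1, net2, net3 = NormalClassifier(3, 6), NormalClassifier(6, 10), NormalClassifier(10, 2)
data = torch.rand((4, 3))
label = torch.ones((4,), dtype=torch.long)

# Run 1: plain forward pass, no GRL.
F.cross_entropy(net3(net2(net1(data))), label).backward()
g1 = net1.linear.weight.grad.clone()
g2 = net2.linear.weight.grad.clone()
g3 = net3.linear.weight.grad.clone()
for net in (net1, net2, net3):
    net.zero_grad()

# Run 2: route net1's output through the GRL.
F.cross_entropy(net3(net2(net1.grl_forward(data))), label).backward()
assert torch.allclose(net1.linear.weight.grad, -g1)  # reversed before the GRL
assert torch.allclose(net2.linear.weight.grad, g2)   # unchanged after the GRL
assert torch.allclose(net3.linear.weight.grad, g3)
print('GRL check passed')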


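For context, the usual home of a GRL is domain-adversarial training (DANN), where the feature extractor is pushed to produce domain-invariant features. The sketch below is illustrative only: the module names (feature_extractor, label_clf, domain_clf) are hypothetical and not from the original post, and it reuses GRL_Layer from the code above.

# Illustrative DANN-style training step; all module names are hypothetical.
import torch
import torch.nn as nn
import torch.nn.functional as F

feature_extractor = nn.Sequential(nn.Linear(3, 16), nn.ReLU())
label_clf = nn.Linear(16, 2)   # task-label classifier
domain_clf = nn.Linear(16, 2)  # source-vs-target domain discriminator
grl = GRL_Layer()              # from the code above

params = (list(feature_extractor.parameters())
          + list(label_clf.parameters())
          + list(domain_clf.parameters()))
opt = torch.optim.SGD(params, lr=0.1)

x = torch.rand(8, 3)
y = torch.randint(0, 2, (8,))   # task labels
d = torch.randint(0, 2, (8,))   # domain labels

feats = feature_extractor(x)
task_loss = F.cross_entropy(label_clf(feats), y)
# Through the GRL, minimizing domain_loss for domain_clf simultaneously
# *maximizes* it for feature_extractor; coeff sets the reversal strength.
domain_loss = F.cross_entropy(domain_clf(grl(feats, 1.0)), d)

opt.zero_grad()
(task_loss + domain_loss).backward()
opt.step()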

Copyright notice: This is an original article by qq_45193988, released under the CC 4.0 BY-SA license. When reposting, please include a link to the original source and this notice.