0.写在前面:
什么是孪生网络,什么是rpn网络就不讲了,这里只对卷积过程梳理以及展示对应代码,代码主要是看卷积过程数据怎么变来变去的。rpn理解网络可参考我的博文
:
https://blog.csdn.net/gbz3300255/article/details/105493407
1.SiamRpn图:
siamrpn可以看成是孪生网络 + RPN组成。经典图如下
2.SiamRpn代码:
下面为示例代码,
示例代码的主要作用就是知道输入是啥,输出是啥,中间孪生网络怎么输出结果是啥,RPN网络输出结果又是啥。
孪生网络部分采用胡乱写的一个网络结构做主干(正常用VGG alexnet等吧)anchor个数就直接用9吧。具体值和上面图中有出入,主要是为了理解意思。想看哪一步的输出结果可以随意打印,
例如我一直对上图☆部分怎么计算不理解,一看代码其对应的就是F.conv2d()啊 ,其是拿4*4*2k*256这个东西做卷积核,在20*20*256这个图上做卷积啊。
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
x = torch.randn(1,3, 255, 255)#待跟踪图像
z = torch.randn(1,3, 127, 127)#模板图像
feature = nn.Sequential(
# conv1
nn.Conv2d(3, 192, 11, 2),
nn.BatchNorm2d(192),
nn.ReLU(inplace=True),
nn.MaxPool2d(3, 2),
# conv2
nn.Conv2d(192, 512, 5, 1),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.MaxPool2d(3, 2),
# conv3
nn.Conv2d(512, 768, 3, 1),
nn.BatchNorm2d(768),
nn.ReLU(inplace=True),
# conv4
nn.Conv2d(768, 768, 3, 1),
nn.BatchNorm2d(768),
nn.ReLU(inplace=True),
# conv5
nn.Conv2d(768, 256, 3, 1),
nn.BatchNorm2d(256))
conv_reg_z = nn.Conv2d(256, 256 * 4 * 9, 3, 1)
conv_reg_x = nn.Conv2d(256, 256, 3)
conv_cls_z = nn.Conv2d(256, 256 * 2 * 9, 3, 1)
conv_cls_x = nn.Conv2d(256, 256, 3)
adjust_reg = nn.Conv2d(4 * 9, 4 * 9, 1)
z = feature(z)#模板图像孪生网络计算的特征 其是rpn的输入
print("template frame siam shape = ", z.shape)
kernel_reg = conv_reg_z(z)
kernel_cls = conv_cls_z(z)
k = kernel_reg.size()[-1]
# 这块的语句将kernel_reg 变成了一个【4*9 256 k k】形式维度的一个东西 为了后面做F.conv2d操作准# 备卷积核呢
kernel_reg = kernel_reg.view(4 * 9, 256, k, k)
kernel_cls = kernel_cls.view(2 * 9, 256, k, k)
x = feature(x)#待跟踪图像孪生网络计算的特征 其是rpn的输入
print("detection frame siam shape = ", x.shape)
x_reg = conv_reg_x(x)
x_cls = conv_cls_x(x)
out_reg = adjust_reg(F.conv2d(x_reg, kernel_reg)) #回归分支网络的输出结果
out_cls = F.conv2d(x_cls, kernel_cls)#分类分支网络的输出结果
print("rpn out_reg shape = ", out_reg.shape)
print("rpn out_cls shape = ", out_cls.shape)
注意:F.conv2d(a, b)是普通的卷积操作
输出
template frame siam shape = torch.Size([1, 256, 6, 6])
detection frame siam shape = torch.Size([1, 256, 22, 22])
rpn out_reg shape = torch.Size([1, 36, 17, 17])
rpn out_cls shape = torch.Size([1, 18, 17, 17])
下面这个链接是配置商汤科技的那个跟踪器的博客,以后可能会用到
版权声明:本文为gbz3300255原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。