The code provides two structural variants of the Tokens-to-Token stage: a transformer version and a performer version.
# Imports required by the code below; the relative imports follow the layout
# of the original T2T-ViT repository.
import torch
import torch.nn as nn
import numpy as np
from timm.models.helpers import load_pretrained
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_
from .token_transformer import Token_transformer
from .token_performer import Token_performer
from .transformer_block import Block, get_sinusoid_encoding

def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225),
'classifier': 'head',
**kwargs
}
default_cfgs = {
'T2t_vit_7': _cfg(),
'T2t_vit_10': _cfg(),
'T2t_vit_12': _cfg(),
'T2t_vit_14': _cfg(),
'T2t_vit_19': _cfg(),
'T2t_vit_24': _cfg(),
'T2t_vit_t_14': _cfg(),
'T2t_vit_t_19': _cfg(),
'T2t_vit_t_24': _cfg(),
'T2t_vit_14_resnext': _cfg(),
'T2t_vit_14_wide': _cfg(),
}
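A quick check of how _cfg merges keyword overrides into the defaults (a minimal illustration, not part of the original file):
cfg = _cfg(crop_pct=1.0)
print(cfg['crop_pct'])     # 1.0, overriding the default 0.9
print(cfg['num_classes'])  # 1000, taken from the defaults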
class T2T_module(nn.Module):
"""
Tokens-to-Token encoding module
"""
def __init__(self, img_size=224, tokens_type='performer', in_chans=3, embed_dim=768, token_dim=64):
super().__init__()
        # choose the token encoder type; this stage performs three unfold (soft split) operations
if tokens_type == 'transformer':
print('adopt transformer encoder for tokens-to-token')
self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
self.attention1 = Token_transformer(dim=in_chans * 7 * 7, in_dim=token_dim, num_heads=1, mlp_ratio=1.0)
self.attention2 = Token_transformer(dim=token_dim * 3 * 3, in_dim=token_dim, num_heads=1, mlp_ratio=1.0)
self.project = nn.Linear(token_dim * 3 * 3, embed_dim)
elif tokens_type == 'performer':
print('adopt performer encoder for tokens-to-token')
self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
#self.attention1 = Token_performer(dim=token_dim, in_dim=in_chans*7*7, kernel_ratio=0.5)
#self.attention2 = Token_performer(dim=token_dim, in_dim=token_dim*3*3, kernel_ratio=0.5)
self.attention1 = Token_performer(dim=in_chans*7*7, in_dim=token_dim, kernel_ratio=0.5)
self.attention2 = Token_performer(dim=token_dim*3*3, in_dim=token_dim, kernel_ratio=0.5)
self.project = nn.Linear(token_dim * 3 * 3, embed_dim)
        elif tokens_type == 'convolution':  # just for comparison with convolution, not our model
            # for this tokens_type you need to change forward() into three convolution operations
print('adopt convolution layers for tokens-to-token')
self.soft_split0 = nn.Conv2d(3, token_dim, kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)) # the 1st convolution
self.soft_split1 = nn.Conv2d(token_dim, token_dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) # the 2nd convolution
self.project = nn.Conv2d(token_dim, embed_dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) # the 3rd convolution
        self.num_patches = (img_size // (4 * 2 * 2)) * (img_size // (4 * 2 * 2))  # there are 3 soft splits with strides 4, 2 and 2 respectively
def forward(self, x):
        # first soft split: the 224x224 input is unfolded with kernel_size=7, stride=4, padding=2
        # Unfold extracts the vector inside each sliding window and flattens it
# step0: soft split
x = self.soft_split0(x).transpose(1, 2)
# iteration1: re-structurization/reconstruction
        # attention1 here is the standard MSA structure: one attention computation plus one MLP
x = self.attention1(x)
        # re-structurization: reshape the tokens back into a 2D feature map
B, new_HW, C = x.shape
x = x.transpose(1,2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW)))
# iteration1: soft split
        # repeat the soft split with k=3, s=2, p=1
x = self.soft_split1(x).transpose(1, 2)
# iteration2: re-structurization/reconstruction
x = self.attention2(x)
B, new_HW, C = x.shape
x = x.transpose(1, 2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW)))
# iteration2: soft split
x = self.soft_split2(x).transpose(1, 2)
        # after three soft splits and two re-structurizations we obtain the final token sequence
# final tokens
        # a single linear projection to embed_dim
x = self.project(x)
return x
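To make the three soft splits concrete, here is a standalone shape trace for a 224x224 input, using only nn.Unfold (a sketch that is independent of the attention weights; token_dim=64 is the default above):
x = torch.randn(1, 3, 224, 224)
u0 = nn.Unfold(kernel_size=7, stride=4, padding=2)
print(u0(x).shape)  # [1, 3*7*7, 56*56] = [1, 147, 3136]
# attention1 maps each 147-dim patch vector to token_dim=64, and the 3136
# tokens are reshaped back into a 64x56x56 map before the next split
u1 = nn.Unfold(kernel_size=3, stride=2, padding=1)
print(u1(torch.randn(1, 64, 56, 56)).shape)  # [1, 64*3*3, 28*28] = [1, 576, 784]
print(u1(torch.randn(1, 64, 28, 28)).shape)  # [1, 576, 14*14] = [1, 576, 196]
# 196 = (224 // 16)**2 final tokens, matching self.num_patches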
class T2T_ViT(nn.Module):
def __init__(self, img_size=224, tokens_type='performer', in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., norm_layer=nn.LayerNorm, token_dim=64):
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.tokens_to_token = T2T_module(
img_size=img_size, tokens_type=tokens_type, in_chans=in_chans, embed_dim=embed_dim, token_dim=token_dim)
num_patches = self.tokens_to_token.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(data=get_sinusoid_encoding(n_position=num_patches + 1, d_hid=embed_dim), requires_grad=False)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
# Classifier head
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'cls_token'}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
B = x.shape[0]
        x = self.tokens_to_token(x)  # T2T module: the three soft-split stages described above
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
x = self.pos_drop(x)
        # different blocks can be plugged in here depending on the structure: the standard transformer MSA, a DenseNet-style block, or an SE-augmented attention
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x[:, 0]
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
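A minimal smoke test of the full model, mirroring the t2t_vit_14 configuration below (assumes the rest of the repository, e.g. Token_performer and Block, is importable):
model = T2T_ViT(tokens_type='performer', embed_dim=384, depth=14, num_heads=6, mlp_ratio=3.)
out = model(torch.randn(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 1000])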
@register_model
def t2t_vit_7(pretrained=False, **kwargs): # adopt performer for tokens to token
if pretrained:
kwargs.setdefault('qk_scale', 256 ** -0.5)
model = T2T_ViT(tokens_type='performer', embed_dim=256, depth=7, num_heads=4, mlp_ratio=2., **kwargs)
model.default_cfg = default_cfgs['T2t_vit_7']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
return model
@register_model
def t2t_vit_10(pretrained=False, **kwargs): # adopt performer for tokens to token
if pretrained:
kwargs.setdefault('qk_scale', 256 ** -0.5)
model = T2T_ViT(tokens_type='performer', embed_dim=256, depth=10, num_heads=4, mlp_ratio=2., **kwargs)
model.default_cfg = default_cfgs['T2t_vit_10']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
return model
@register_model
def t2t_vit_12(pretrained=False, **kwargs): # adopt performer for tokens to token
if pretrained:
kwargs.setdefault('qk_scale', 256 ** -0.5)
model = T2T_ViT(tokens_type='performer', embed_dim=256, depth=12, num_heads=4, mlp_ratio=2., **kwargs)
model.default_cfg = default_cfgs['T2t_vit_12']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
return model
@register_model
def t2t_vit_14(pretrained=False, **kwargs): # adopt performer for tokens to token
if pretrained:
kwargs.setdefault('qk_scale', 384 ** -0.5)
model = T2T_ViT(tokens_type='performer', embed_dim=384, depth=14, num_heads=6, mlp_ratio=3., **kwargs)
model.default_cfg = default_cfgs['T2t_vit_14']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
return model
@register_model
def t2t_vit_19(pretrained=False, **kwargs): # adopt performer for tokens to token
if pretrained:
kwargs.setdefault('qk_scale', 448 ** -0.5)
model = T2T_ViT(tokens_type='performer', embed_dim=448, depth=19, num_heads=7, mlp_ratio=3., **kwargs)
model.default_cfg = default_cfgs['T2t_vit_19']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
return model
@register_model
def t2t_vit_24(pretrained=False, **kwargs): # adopt performer for tokens to token
if pretrained:
kwargs.setdefault('qk_scale', 512 ** -0.5)
model = T2T_ViT(tokens_type='performer', embed_dim=512, depth=24, num_heads=8, mlp_ratio=3., **kwargs)
model.default_cfg = default_cfgs['T2t_vit_24']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
return model
@register_model
def t2t_vit_t_14(pretrained=False, **kwargs): # adopt transformers for tokens to token
if pretrained:
kwargs.setdefault('qk_scale', 384 ** -0.5)
model = T2T_ViT(tokens_type='transformer', embed_dim=384, depth=14, num_heads=6, mlp_ratio=3., **kwargs)
model.default_cfg = default_cfgs['T2t_vit_t_14']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
return model
@register_model
def t2t_vit_t_19(pretrained=False, **kwargs): # adopt transformers for tokens to token
if pretrained:
kwargs.setdefault('qk_scale', 448 ** -0.5)
model = T2T_ViT(tokens_type='transformer', embed_dim=448, depth=19, num_heads=7, mlp_ratio=3., **kwargs)
model.default_cfg = default_cfgs['T2t_vit_t_19']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
return model
@register_model
def t2t_vit_t_24(pretrained=False, **kwargs): # adopt transformers for tokens to token
if pretrained:
kwargs.setdefault('qk_scale', 512 ** -0.5)
model = T2T_ViT(tokens_type='transformer', embed_dim=512, depth=24, num_heads=8, mlp_ratio=3., **kwargs)
model.default_cfg = default_cfgs['T2t_vit_t_24']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
return model
# resnext and wide structures
@register_model
def t2t_vit_14_resnext(pretrained=False, **kwargs):
if pretrained:
kwargs.setdefault('qk_scale', 384 ** -0.5)
model = T2T_ViT(tokens_type='performer', embed_dim=384, depth=14, num_heads=32, mlp_ratio=3., **kwargs)
model.default_cfg = default_cfgs['T2t_vit_14_resnext']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
return model
@register_model
def t2t_vit_14_wide(pretrained=False, **kwargs):
if pretrained:
kwargs.setdefault('qk_scale', 512 ** -0.5)
model = T2T_ViT(tokens_type='performer', embed_dim=768, depth=4, num_heads=12, mlp_ratio=3., **kwargs)
model.default_cfg = default_cfgs['T2t_vit_14_wide']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
return model
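Because every factory is decorated with timm's @register_model, the models can also be built through timm's registry once this module has been imported (a usage sketch, assuming a timm version compatible with the repo):
import timm
model = timm.create_model('t2t_vit_14', pretrained=False)
# equivalent to calling the factory directly:
model = t2t_vit_14(pretrained=False)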
Taking SE as an example, the open-source code provides t2t_vit_se:
class SELayer(nn.Module):
def __init__(self, channel, reduction=16):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool1d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid()
)
def forward(self, x): # x: [B, N, C]
x = torch.transpose(x, 1, 2) # [B, C, N]
b, c, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1)
x = x * y.expand_as(x)
x = torch.transpose(x, 1, 2) # [B, N, C]
return x
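SELayer squeezes over the token dimension and re-weights the C channels, so the [B, N, C] shape is unchanged (a quick check):
se = SELayer(channel=384)
tokens = torch.randn(2, 197, 384)  # [B, N, C]
print(se(tokens).shape)            # torch.Size([2, 197, 384])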
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
        # the SE layer is used here
self.se_layer = SELayer(dim)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
        # apply SE after the projection; different attention variants give different blocks
x = self.se_layer(x)
x = self.proj_drop(x)
return x
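This SE-augmented Attention is a drop-in replacement for the standard MSA, with identical input and output shapes (a quick check):
attn = Attention(dim=384, num_heads=6)
x = torch.randn(2, 197, 384)  # [B, N, C], e.g. 196 patches + 1 cls token
print(attn(x).shape)          # torch.Size([2, 197, 384])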