
Vision Transformer (PyTorch) Code Reading Notes

2022-01-28 14:24:52 HollowKnightZ


Preface

The official Google Research implementation of Vision Transformer is written in TensorFlow, while I mostly use PyTorch, so I used rwightman's version on GitHub instead: rwightman/pytorch-image-models/timm/models/vision_transformer.py

For an introduction to Vision Transformer, see the blog post: Paper reading notes: Vision Transformer

The code walkthrough below uses vit_base_patch16_224 (ViT-B/16: patch_size=16, img_size=224) as the example.

ViT Model

In the original paper, the model consists of three modules:
· Linear Projection of Flattened Patches
· Transformer Encoder
· MLP Head

The corresponding modules in the code are:
· patch embedding layer
· Block
· Representation layer + Classifier head

Linear Projection of Flattened Patches

[Figure: Linear Projection of Flattened Patches]
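As shown in the figure, Linear Projection of Flattened Patches is implemented as a convolution with kernel_size = stride = 16 followed by a flatten. It turns a 224×224×3 2D image into a 196×768 patch embedding. The 224×224 image is cut into 14×14 = 196 patches, and each 16×16×3 patch (768 values) is mapped to a 768-dimensional vector. The code with comments: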

import torch
import torch.nn as nn


class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding """
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        super().__init__()
        # img_size = (224, 224), patch_size = (16, 16)
        # grid_size = (224 / 16, 224 / 16) = (14, 14), num_patches = 14 * 14 = 196
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        # A conv with kernel_size = stride = 16 implements the embedding. It outputs a 14x14 map
        # with 768 channels: each 16x16x3 patch (768 values) is linearly mapped to a 768-dim vector
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        # If a norm_layer (e.g. LayerNorm) is given it is applied. The author passes none here,
        # so self.norm = nn.Identity(), which returns its input unchanged
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        # self.proj(x):    [B, 3, 224, 224] -> [B, 768, 14, 14]
        # flatten(2):      [B, 768, 14, 14] -> [B, 768, 14*14] = [B, 768, 196]
        # transpose(1, 2): [B, 768, 196] -> [B, 196, 768]
        # self.norm(x):    identity here, returns x unchanged
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x
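A convolution with kernel_size = stride = 16 really is a "linear projection of flattened patches". The sketch below makes this concrete with illustrative names and random weights. It cuts the image into 16×16 patches by hand, flattens each one, and applies an nn.Linear that reuses the conv's weights. The two outputs should agree up to floating-point error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, IMG, P, D = 2, 3, 224, 16, 768

# Conv-based patch embedding, as in PatchEmbed.proj
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)

# Equivalent linear layer reusing the conv weights: [D, C, P, P] -> [D, C*P*P]
linear = nn.Linear(C * P * P, D)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(D, -1))
    linear.bias.copy_(conv.bias)


def patchify(x):
    """Cut [B, C, H, W] into non-overlapping PxP patches -> [B, num_patches, C*P*P]."""
    b, c, h, w = x.shape
    x = x.reshape(b, c, h // P, P, w // P, P)             # [B, C, 14, 16, 14, 16]
    x = x.permute(0, 2, 4, 1, 3, 5)                       # [B, 14, 14, C, 16, 16]
    return x.reshape(b, (h // P) * (w // P), c * P * P)   # [B, 196, 768], row-major patch order


x = torch.randn(B, C, IMG, IMG)
with torch.no_grad():
    out_conv = conv(x).flatten(2).transpose(1, 2)   # [B, 196, 768]
    out_linear = linear(patchify(x))                # [B, 196, 768]

print(out_conv.shape)  # torch.Size([2, 196, 768])
print(torch.allclose(out_conv, out_linear, atol=1e-4))
```

The conv form is preferred in practice because it avoids the explicit reshape/permute and uses an optimized kernel.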

Transformer Encoder

The Transformer Encoder is built from the Attention, MLP and DropPath code. Its structure is shown below:
[Figure: Transformer Encoder structure]

Multi-Head Attention

For the structure diagram and a detailed explanation of Multi-Head Attention, see the post Paper reading notes: Attention Is All You Need.
The Attention code with comments:

class Attention(nn.Module):
    def __init__(self,
                 dim,   # dimension of each input token, 768
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention, self).__init__()
        # For ViT-B/16: num_heads = 12, head_dim = 768 // 12 = 64
        #   (d_k = d_v = d_model / h in "Attention Is All You Need")
        # scale = 64 ** -0.5 = 1/8, the 1/sqrt(d_k) factor in Scaled Dot-Product Attention
        # qkv:  linear map from the input to q, k, v
        # proj: the output matrix W^O that fuses the heads in Multi-Head Attention, implemented as a Linear
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):
        # B = batch_size, N = 197 (196 patch tokens + 1 class token), C = 768
        B, N, C = x.shape

        # qkv(x):  [B, 197, 768] -> [B, 197, 768*3]
        # reshape: [B, 197, 768*3] -> [B, 197, 3, 12, 64]  (3 = q/k/v, 12 heads, 64 dims per head)
        # permute: [B, 197, 3, 12, 64] -> [3, B, 12, 197, 64]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # q, k, v: [B, 12, 197, 64] each
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # k.transpose(-2, -1): [B, 12, 197, 64] -> [B, 12, 64, 197]
        # q @ k.transpose(-2, -1): [B, 12, 197, 64] @ [B, 12, 64, 197] -> [B, 12, 197, 197]
        # attn.softmax(dim=-1): softmax over the last dimension, so each row sums to 1
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # attn @ v: [B, 12, 197, 197] @ [B, 12, 197, 64] -> [B, 12, 197, 64]
        # transpose(1, 2): -> [B, 197, 12, 64]
        # reshape: -> [B, 197, 768]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
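The reshape/permute bookkeeping above can be checked against a reference implementation. The sketch below copies the forward pass (without dropout) into a function and loads the same weights into PyTorch's nn.MultiheadAttention, then compares the two outputs. It uses small, illustrative dimensions for speed and assumes a PyTorch version with `batch_first` (1.9 or later). If the outputs agree, splitting the 3·C outputs as [3, heads, head_dim] matches the standard multi-head layout.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, N, C, H = 2, 197, 64, 4   # small C and H for speed; ViT-B/16 uses C=768, H=12
head_dim = C // H

qkv = nn.Linear(C, 3 * C, bias=True)   # same role as Attention.qkv
proj = nn.Linear(C, C)                 # same role as Attention.proj


def vit_attention(x):
    """Forward pass of the Attention class above, without dropout."""
    b, n, c = x.shape
    t = qkv(x).reshape(b, n, 3, H, c // H).permute(2, 0, 3, 1, 4)
    q, k, v = t[0], t[1], t[2]
    attn = (q @ k.transpose(-2, -1)) * head_dim ** -0.5
    attn = attn.softmax(dim=-1)
    return proj((attn @ v).transpose(1, 2).reshape(b, n, c))


# PyTorch's built-in multi-head attention, loaded with the same weights
mha = nn.MultiheadAttention(C, H, bias=True, batch_first=True)
with torch.no_grad():
    mha.in_proj_weight.copy_(qkv.weight)   # rows ordered [q; k; v], heads contiguous within each
    mha.in_proj_bias.copy_(qkv.bias)
    mha.out_proj.weight.copy_(proj.weight)
    mha.out_proj.bias.copy_(proj.bias)

x = torch.randn(B, N, C)
with torch.no_grad():
    ours = vit_attention(x)
    ref, _ = mha(x, x, x, need_weights=False)

print(ours.shape)  # torch.Size([2, 197, 64])
print((ours - ref).abs().max())
```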

MLP

[Figure: MLP block structure]
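Both the structure and the code of the MLP are simple. It consists of a fully connected layer, an activation function and dropout, followed by a second fully connected layer and dropout. The activation function here is GELU: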

$$\mathrm{GELU}(x) = x\,\Phi(x) \approx 0.5x\left(1+\tanh\left[\sqrt{2/\pi}\,\left(x+0.044715x^{3}\right)\right]\right)$$

Here Φ is the standard normal CDF. The tanh expression is an approximation; by default nn.GELU computes the exact form x·Φ(x).
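How close are the two forms? The sketch below compares them over [-6, 6] using only the standard library. The exact definition x·Φ(x) is computed via `math.erf`, and the tanh approximation is written out term by term; the function names are illustrative.

```python
import math


def gelu_exact(x: float) -> float:
    """GELU(x) = x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    """Tanh approximation of GELU (note the sqrt(2/pi) factor)."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


xs = [i / 100 for i in range(-600, 601)]  # grid over [-6, 6]
max_err = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in xs)
print(f"max |exact - tanh approx| on [-6, 6]: {max_err:.2e}")
```

Dropping the square root in the tanh argument would change the curve noticeably, so the sqrt(2/π) factor matters.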

The MLP module code:

class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

DropPath

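In the Transformer Encoder, the code uses DropPath (stochastic depth, which drops the residual branch for whole samples) in place of the Dropout described in the paper. The code with comments: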

def drop_path(x, drop_prob: float = 0., training: bool = False):
    # x.shape: [B, 197, 768]
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # shape = (B, 1, 1): keep the batch dimension of x and set every other dimension to 1,
    # so one random value is drawn per sample
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    # Random tensor of that shape, uniform in [keep_prob, 1 + keep_prob)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    # Floor it: each entry becomes 0 (probability drop_prob) or 1 (probability keep_prob)
    random_tensor.floor_()  # binarize
    # Divide x by keep_prob and multiply by the mask: dropped samples become 0, and kept samples
    # are rescaled so the expected value of the output equals x
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
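The VisionTransformer code later in this post instantiates a Block class that the post does not list. The sketch below is a hypothetical, self-contained stand-in. It assumes the usual pre-norm ViT layout, x = x + DropPath(Attention(LayerNorm(x))) followed by x = x + DropPath(MLP(LayerNorm(x))). To keep it short, it swaps in PyTorch's nn.MultiheadAttention for the Attention class above and an nn.Sequential for Mlp, so it is not the source's implementation. Its constructor only mirrors the keyword arguments VisionTransformer passes. The block ends with a check that drop_path zeroes whole samples.

```python
from functools import partial

import torch
import torch.nn as nn


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Same logic as the drop_path function above (per-sample stochastic depth)."""
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = (keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)).floor_()
    return x.div(keep_prob) * mask


class Block(nn.Module):
    """Hypothetical pre-norm encoder block: a sketch, not the article's code."""
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
                 drop_ratio=0., attn_drop_ratio=0., drop_path_ratio=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        # qk_scale is accepted only to match the call in VisionTransformer;
        # nn.MultiheadAttention always scales scores by head_dim ** -0.5.
        self.norm1 = norm_layer(dim)
        # Here bias=qkv_bias also controls the output projection's bias (Attention.proj always has one)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=attn_drop_ratio,
                                          bias=qkv_bias, batch_first=True)
        self.proj_drop = nn.Dropout(drop_ratio)
        self.drop_path_ratio = drop_path_ratio
        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), act_layer(), nn.Dropout(drop_ratio),
                                 nn.Linear(hidden, dim), nn.Dropout(drop_ratio))

    def forward(self, x):
        # Residual branch 1: x + DropPath(MHA(LN(x)))
        h = self.norm1(x)
        h, _ = self.attn(h, h, h, need_weights=False)
        x = x + drop_path(self.proj_drop(h), self.drop_path_ratio, self.training)
        # Residual branch 2: x + DropPath(MLP(LN(x)))
        x = x + drop_path(self.mlp(self.norm2(x)), self.drop_path_ratio, self.training)
        return x


torch.manual_seed(0)
blk = Block(dim=64, num_heads=4, qkv_bias=True, drop_path_ratio=0.1,
            norm_layer=partial(nn.LayerNorm, eps=1e-6)).eval()
tokens = torch.randn(2, 197, 64)
y = blk(tokens)
print(y.shape)  # torch.Size([2, 197, 64])

# drop_path drops whole samples and scales the survivors by 1 / keep_prob
ones = torch.ones(8, 197, 64)
out = drop_path(ones, drop_prob=0.5, training=True)
per_sample = out.reshape(8, -1)
print(per_sample.amax(dim=1))  # one entry per sample: 0. if dropped, 2. if kept
```

Because the mask has shape (B, 1, 1), a dropped sample skips the whole residual branch. nn.Dropout, by contrast, zeroes individual elements.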

MLP Head

[Figure: MLP Head structure]
The MLP Head code in the source (an excerpt from VisionTransformer.__init__):

# Representation layer
if representation_size and not distilled:
    self.has_logits = True
    self.num_features = representation_size
    self.pre_logits = nn.Sequential(OrderedDict([
        ("fc", nn.Linear(embed_dim, representation_size)),
        ("act", nn.Tanh())
    ]))
else:
    self.has_logits = False
    self.pre_logits = nn.Identity()

# Classifier head(s)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = None
if distilled:
    self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

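This code is simple as well, so I won't annotate it further. Since distilled = False in this code, when representation_size is set:
self.pre_logits = nn.Sequential(nn.Linear(embed_dim, representation_size), nn.Tanh())
self.head = nn.Linear(self.num_features, num_classes)
MLPHead(x) = self.head(self.pre_logits(x[:, 0]))
If representation_size is None, pre_logits is nn.Identity() and the MLP Head reduces to the single Linear layer self.head.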
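The sketch below traces the head path with shapes. It assumes representation_size = 768 and num_classes = 1000 for illustration and uses a random tensor in place of the normalized encoder output. Only the class token (index 0) reaches the head.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

embed_dim, representation_size, num_classes = 768, 768, 1000

# Representation layer (pre_logits) followed by the classifier head, as configured above
pre_logits = nn.Sequential(OrderedDict([
    ("fc", nn.Linear(embed_dim, representation_size)),
    ("act", nn.Tanh()),
]))
head = nn.Linear(representation_size, num_classes)

encoder_out = torch.randn(4, 197, embed_dim)  # stand-in for the normalized encoder output
cls = encoder_out[:, 0]                       # class token only: [4, 768]
features = pre_logits(cls)                    # [4, 768], bounded by Tanh
logits = head(features)                       # [4, 1000]
print(logits.shape)  # torch.Size([4, 1000])
```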

VisionTransformer

The overall network structure of ViT-B/16 is shown below:
[Figure: ViT-B/16 overall network structure]
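ViT-B/16 takes 224×224×3 input images and uses 16×16×3 patches. Each patch embedding has 768 dimensions, the model has 12 transformer encoder blocks, and Multi-Head Attention uses 12 heads. The last two parameters depend on the arguments passed when the model is built: representation_size is the number of units in the fully connected layer of pre_logits, and num_classes is the total number of classes to predict.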

def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=768 if has_logits else None,
                              num_classes=num_classes)
    return model
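These settings fix the model size. As a rough cross-check, the hypothetical helper below counts the trainable parameters of the VisionTransformer class in this post (distilled=False) layer by layer. num_heads is omitted because splitting into heads does not change the number of weights.

```python
def vit_param_count(img_size=224, patch_size=16, in_c=3, embed_dim=768, depth=12,
                    mlp_ratio=4.0, qkv_bias=True, representation_size=None, num_classes=1000):
    """Count trainable parameters of the VisionTransformer above (distilled=False)."""
    d = embed_dim
    num_patches = (img_size // patch_size) ** 2
    hidden = int(d * mlp_ratio)

    patch_embed = in_c * patch_size * patch_size * d + d  # Conv2d weight + bias
    tokens = d + (num_patches + 1) * d                    # cls_token + pos_embed
    block = (
        2 * d                                       # norm1 (weight + bias)
        + d * 3 * d + (3 * d if qkv_bias else 0)    # attn.qkv
        + d * d + d                                 # attn.proj
        + 2 * d                                     # norm2
        + d * hidden + hidden                       # mlp.fc1
        + hidden * d + d                            # mlp.fc2
    )
    final_norm = 2 * d
    head_in, pre_logits = d, 0
    if representation_size:
        pre_logits = d * representation_size + representation_size  # fc before Tanh
        head_in = representation_size
    head = head_in * num_classes + num_classes if num_classes > 0 else 0
    return patch_embed + tokens + depth * block + final_norm + pre_logits + head


print(vit_param_count())  # 86567656 (ViT-B/16, 1000 classes, no pre_logits)
print(vit_param_count(representation_size=768, num_classes=21843))
```

The 12 encoder blocks hold roughly 85M of these parameters. The classifier head is the part that grows with num_classes.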

The VisionTransformer code with comments:

from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn


class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_c (int): number of input channels
            num_classes (int): number of classes
            embed_dim (int): patch embedding dimension
            depth (int): number of transformer encoder modules (Block modules)
            num_heads (int): number of heads in Multi-Head Attention
            mlp_ratio (float): ratio of the MLP hidden dimension to embed_dim
            qkv_bias (bool): whether the Linear layer that maps the input to q, k, v uses a bias
            qk_scale (float): scaling factor for q @ k^T; defaults to 1/sqrt(head_dim)
            representation_size (Optional[int]): number of units in the pre-logits fully connected
                layer; if None, pre-logits is omitted and the MLP Head has a single fully connected layer
            distilled (bool): whether to build the DeiT model (a knowledge-distillation transformer);
                False for ViT
            drop_ratio (float): dropout probability
            attn_drop_ratio (float): dropout probability on the attention weights
            drop_path_ratio (float): stochastic depth (DropPath) rate of the last Block
            embed_layer (nn.Module): patch embedding layer
            norm_layer (nn.Module): normalization layer
            act_layer (nn.Module): activation layer
        """
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        # self.num_features = self.embed_dim = 768, self.num_tokens = 1 (2 with a distillation token)
        # norm_layer = partial(nn.LayerNorm, eps=1e-6), act_layer = nn.GELU
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU
		
        # Build the patch embedding layer: num_patches = (224 / 16) * (224 / 16) = 196
        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
		
        # Learnable parameters:
        #   self.cls_token:  [1, 1, 768] class token
        #   self.dist_token: None (only created when distilled=True)
        #   self.pos_embed:  [1, 197, 768] positional embedding
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)
		
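        # Build an arithmetic sequence of length depth running from 0 to drop_path_ratio (inclusive),
        # so the DropPath rate passed to each Block increases with depth.
        # drop_path_ratio defaults to 0 here, so every rate is 0.
        # Then build depth (12) Blocks from these arguments and assign LayerNorm(embed_dim) to self.norm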
        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)
		
        # Build pre_logits:
        #   1. fully connected layer: input embed_dim (768), output representation_size (768)
        #   2. activation: Tanh
        # Representation layer
        if representation_size and not distilled:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", nn.Linear(embed_dim, representation_size)),
                ("act", nn.Tanh())
            ]))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()
		
        # Build the classifier: self.num_features = 768
        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
		
        # Initialize pos_embed and cls_token (and dist_token if present), then the weights of all other layers
        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(_init_vit_weights)

    def forward_features(self, x):
        # self.patch_embed(x): [B, 3, 224, 224] -> [B, 196, 768]
        # Concatenate the cls_token:
        #   self.cls_token: [1, 1, 768] -> cls_token: [B, 1, 768] (expand)
        #   torch.cat((cls_token, x), dim=1): [B, 197, 768]
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1) 
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
		
        # Add the positional embedding: x + self.pos_embed: [B, 197, 768]
        # Pass through the encoder Blocks and the final LayerNorm: [B, 197, 768]
        # Return the class token passed through pre_logits: self.pre_logits(x[:, 0]): [B, 768]
        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):
        # self.forward_features(x): [B, 3, 224, 224] -> [B, 768]
        # x = self.head(x): [B, 768] -> [B, num_classes]
        x = self.forward_features(x)
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])
            if self.training and not torch.jit.is_scripting():
                # during training, return both classifier predictions
                return x, x_dist
            else:
                # during inference, return the average of both classifier predictions
                return (x + x_dist) / 2
        else:
            x = self.head(x)
        return x


def _init_vit_weights(m):
    """ ViT weight initialization :param m: module """
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=.01)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)
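The embedding steps of forward_features can be traced on their own. The sketch below rebuilds only the pieces needed for shape bookkeeping, with illustrative variable names. The 12 encoder Blocks are skipped, since each maps [B, 197, 768] to [B, 197, 768].

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, embed_dim, patch = 2, 768, 16
num_patches = (224 // patch) ** 2  # 196

proj = nn.Conv2d(3, embed_dim, kernel_size=patch, stride=patch)  # PatchEmbed.proj
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
nn.init.trunc_normal_(pos_embed, std=0.02)
nn.init.trunc_normal_(cls_token, std=0.02)
norm = nn.LayerNorm(embed_dim, eps=1e-6)

img = torch.randn(B, 3, 224, 224)
x = proj(img).flatten(2).transpose(1, 2)    # [2, 196, 768]
cls = cls_token.expand(x.shape[0], -1, -1)  # [2, 1, 768]; expand creates a view, no copy
x = torch.cat((cls, x), dim=1)              # [2, 197, 768]
x = x + pos_embed                           # pos_embed broadcasts over the batch: [2, 197, 768]
# ... the 12 encoder Blocks would run here, each keeping the shape [2, 197, 768] ...
x = norm(x)
cls_out = x[:, 0]                           # [2, 768], fed to pre_logits and the head
print(x.shape, cls_out.shape)  # torch.Size([2, 197, 768]) torch.Size([2, 768])
```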

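The DeiT models referenced by the distilled parameter above are not used in this code and are not mentioned in the ViT paper. If this is unclear, see the post "Principles and usage of ViT and DeiT".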

Copyright notice
This article was written by [HollowKnightZ]. Please include a link to the original when reposting. Thanks.
https://cht.chowdera.com/2022/01/202201281424515940.html
