    Let's implement a simple Vision Transformer

     

    The idea is to split the image into patches, flatten each patch so the image becomes a sequence, and build the model out of three parts: an image embedding, multi-head attention, and an MLP block. For example, a 224x224x3 image cut into 16x16 patches yields 196 patches, each flattened to a 16 * 16 * 3 = 768-dimensional vector.

     

    1. Setup

    import torch
    import torch.nn as nn
    from torch import Tensor
    # torchvision / DataLoader / optim are used by the CIFAR-10 training script linked at the end
    import torchvision
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader
    import torch.optim as optim
    
    from einops import rearrange, repeat
    from einops.layers.torch import Rearrange, Reduce

    The einops library makes this kind of tensor reshaping much easier.
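
    As a quick illustration of what a Rearrange layer does here (a standalone sketch on random data, not part of the model code):

    # a 224x224 RGB image becomes 196 patches, each flattened to 16 * 16 * 3 = 768 values
    x = torch.randn(2, 3, 224, 224)  # (batch, channel, height, width)
    to_patches = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = 16, p2 = 16)
    print(to_patches(x).shape)  # torch.Size([2, 196, 768])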

     

     

    2. Image Embedding Class

     

    class image_embedding(nn.Module):
        def __init__(self, in_channels = 3, img_size = 224, patch_size = 16, emb_dim = 16 * 16 * 3):
            super().__init__()
    
            # (b, c, H, W) -> (b, n_patches, patch_size * patch_size * c)
            self.rearrange = Rearrange('b c (num_w p1) (num_h p2) -> b (num_w num_h) (p1 p2 c)', p1 = patch_size, p2 = patch_size)
            self.linear = nn.Linear(in_channels * patch_size * patch_size, emb_dim)
    
            # learnable classification token, prepended to the patch sequence
            self.cls_token = nn.Parameter(torch.randn(1, 1, emb_dim))
    
            # learnable position embedding, one row per token (patches + cls)
            n_patches = img_size * img_size // patch_size ** 2
            self.positions = nn.Parameter(torch.randn(n_patches + 1, emb_dim))
    
        def forward(self, x):
            batch, channel, height, width = x.shape
            x = self.rearrange(x)   # (b, n_patches, p1 * p2 * c)
            x = self.linear(x)      # (b, n_patches, emb_dim)
            c = repeat(self.cls_token, '() n d -> b n d', b = batch)
            x = torch.cat((c, x), 1)          # (b, n_patches + 1, emb_dim)
            x = torch.add(x, self.positions)  # broadcast over the batch
    
            return x
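
    A quick shape check on random input, using the default hyperparameters above:

    emb = image_embedding(in_channels = 3, img_size = 224, patch_size = 16, emb_dim = 768)
    print(emb(torch.randn(2, 3, 224, 224)).shape)  # torch.Size([2, 197, 768]) -- 196 patches + 1 cls token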

     

     

    3. Multi-head attention

    class multi_head_attention(nn.Module):
        def __init__(self, emb_dim : int = 16 * 16 * 3, num_heads : int = 8, dropout_ratio : float = 0.2, verbose = False, **kwargs):
            super(multi_head_attention, self).__init__()
            self.v = verbose
    
            self.emb_dim = emb_dim
            self.num_heads = num_heads
            self.scaling = (self.emb_dim // num_heads) ** (-0.5)
    
            self.value = nn.Linear(emb_dim, emb_dim)
            self.key = nn.Linear(emb_dim, emb_dim)
            self.query = nn.Linear(emb_dim, emb_dim)
            self.att_drop = nn.Dropout(dropout_ratio)
    
            self.linear = nn.Linear(emb_dim, emb_dim)
    
        def forward(self, x : Tensor) -> Tensor:
            # query, key, value
    
            Q = self.query(x)
            K = self.key(x)
            V = self.value(x)
    
            if self.v: print(Q.size(), K.size(), V.size())
    
            # split into heads: sequence length q = k = v = n_patches + 1, and h * d = emb_dim
            # K is laid out as (b, h, d, k) so that Q @ K directly yields the attention scores
            Q = rearrange(Q, 'b q (h d) -> b h q d', h = self.num_heads)
            K = rearrange(K, 'b k (h d) -> b h d k', h = self.num_heads)
            V = rearrange(V, 'b v (h d) -> b h v d', h = self.num_heads)
            if self.v: print(Q.size(), K.size(), V.size())
    
            # scaled dot-product
            weight = torch.matmul(Q, K)
            weight = weight * self.scaling
            if self.v: print(weight.size())
    
            attention = torch.softmax(weight, dim = -1)
            attention = self.att_drop(attention)
            if self.v: print(attention.size())
    
            context = torch.matmul(attention, V)
            context = rearrange(context, 'b h q d -> b q (h d)')
            if self.v: print(context.size())
    
            x = self.linear(context)
            return x, attention
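
    A small smoke test for the attention module (sizes chosen to match the embedding above):

    mha = multi_head_attention(emb_dim = 768, num_heads = 8)
    out, att = mha(torch.randn(2, 197, 768))
    print(out.shape)  # torch.Size([2, 197, 768])
    print(att.shape)  # torch.Size([2, 8, 197, 197]) -- one 197x197 map per head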

     

     

    4. MLP Block

    class mlp_block(nn.Module):
        def __init__(self, emb_dim : int = 16 * 16 * 3, forward_dim : int = 4, dropout_ratio : float = 0.2, **kwargs):
            super(mlp_block, self).__init__()
            self.linear_1 = nn.Linear(emb_dim, forward_dim * emb_dim)
            self.activation = nn.ReLU()
            self.dropout = nn.Dropout(dropout_ratio)
            self.linear_2 = nn.Linear(forward_dim * emb_dim, emb_dim)
    
        def forward(self, x):
            x = self.linear_1(x)    # expand: emb_dim -> forward_dim * emb_dim
            x = self.activation(x)
            x = self.dropout(x)
            x = self.linear_2(x)    # project back: forward_dim * emb_dim -> emb_dim
    
            return x
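
    The MLP block keeps the token shape unchanged, which a one-line check confirms:

    mlp = mlp_block(emb_dim = 768, forward_dim = 4)
    print(mlp(torch.randn(2, 197, 768)).shape)  # torch.Size([2, 197, 768])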

     

     

    5. Encoder Block

    class encoder_block(nn.Module):
        def __init__(self, emb_dim : int = 16 * 16 * 3, num_heads : int = 8, forward_dim : int = 4, dropout_ratio : float = 0.2):
            super(encoder_block, self).__init__()
            self.norm_1 = nn.LayerNorm(emb_dim)
            self.mha = multi_head_attention(emb_dim, num_heads, dropout_ratio)
    
            self.norm_2 = nn.LayerNorm(emb_dim)
            self.mlp = mlp_block(emb_dim, forward_dim, dropout_ratio)
    
            self.residual_dropout = nn.Dropout(dropout_ratio)
    
        def forward(self, x):
            # pre-norm attention sublayer, dropout on the output, then the residual connection
            x2 = self.norm_1(x)
            x2, attention = self.mha(x2)
            x = torch.add(self.residual_dropout(x2), x)
    
            # pre-norm MLP sublayer with the same residual pattern
            x2 = self.norm_2(x)
            x2 = self.mlp(x2)
            x = torch.add(self.residual_dropout(x2), x)
    
            return x, attention
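
    Chaining the two sublayers, an encoder block also preserves the token shape (a sketch on random tokens):

    block = encoder_block(emb_dim = 768, num_heads = 8)
    out, att = block(torch.randn(2, 197, 768))
    print(out.shape, att.shape)  # torch.Size([2, 197, 768]) torch.Size([2, 8, 197, 197])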

     

     

    6. Model

    The final ViT model is assembled as follows.

    class vision_transformer(nn.Module):
        def __init__(self, in_channels : int = 3, img_size : int = 224, patch_size : int = 16, emb_dim : int = 16 * 16 * 3,
                     n_enc_layers : int = 15, num_heads : int = 3, forward_dim : int = 4, dropout_ratio : float = 0.2, n_classes : int = 1000 ):
            super(vision_transformer, self).__init__()
    
            self.image_embedding = image_embedding(in_channels, img_size, patch_size, emb_dim)
            encoder_module = [encoder_block(emb_dim, num_heads, forward_dim, dropout_ratio) for _ in range(n_enc_layers)]
            self.encoder_module = nn.ModuleList(encoder_module)
    
            self.reduce_layer = Reduce('b n e -> b e', reduction = 'mean')
            self.normalization = nn.LayerNorm(emb_dim)
            self.classification_head = nn.Linear(emb_dim, n_classes)
    
        def forward(self, x):
            x = self.image_embedding(x)
    
            # run the tokens through every encoder block, collecting the attention maps
            attentions = []
            for block in self.encoder_module:
                x, attention = block(x)
                attentions.append(attention)
    
            x = self.reduce_layer(x)  # mean-pool over all tokens rather than taking the cls token
            x = self.normalization(x)
            x = self.classification_head(x)
    
            return x, attentions
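
    A minimal end-to-end sketch; the layer count, head count, and class count below are arbitrary test values, not the paper's settings:

    model = vision_transformer(in_channels = 3, img_size = 224, patch_size = 16, emb_dim = 768,
                               n_enc_layers = 12, num_heads = 8, n_classes = 10)
    logits, attentions = model(torch.randn(2, 3, 224, 224))
    print(logits.shape)     # torch.Size([2, 10])
    print(len(attentions))  # 12 -- one attention map per encoder block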

     

    The full training code on CIFAR-10 is shared at the GitHub link below for reference.

     

    https://github.com/kkt4828/reviewpaper/blob/master/Transformer/ViT.py

     
