티스토리 뷰

목차



     

    기존 CNN모델들과 다르게 image patch 처리를 해줘야하는 코드가 추가되었다.

     

    1. Setup

     

    기존과 동일하다.

     

    import torch.nn as nn
    import torch
    import torchvision
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader
    import torch.optim as optim
    import time
    import numpy as np
    
    import random
    import torch.backends.cudnn as cudnn
    
    seed = 2022
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    cudnn.benchmark = False
    cudnn.deterministic = True
    random.seed(seed)

     

     

    2. Linear Projection

     

    이미지를 patch로 쪼갠 후 embedding 하는 class이다.

     

    class LinearPatchProjection(nn.Module):
        def __init__(self, device, batch_size, out_dim = 768, img_size=224, patch_size = 16, channel_size = 3,):
            super(LinearPatchProjection, self).__init__()
            self.device = device
            self.b = batch_size
            self.p = patch_size
            self.c = channel_size
            self.out_dim = out_dim
            self.n = img_size ** 2 // (patch_size ** 2)
    
    
            self.projection = nn.Linear(in_features=self.p**2 * self.c, out_features=self.out_dim)
    
        def forward(self, x):
            x = x.view(-1, self.n, (self.p ** 2) * self.c)
            x_p = self.projection(x)
            x_cls = nn.Parameter(torch.randn(x_p.size(0), 1, self.out_dim), requires_grad=True).to(device)
            x_pos = nn.Parameter(torch.randn(x_p.size(0), self.n + 1, self.out_dim), requires_grad=True).to(device)
            x_p = torch.concat((x_cls, x_p), dim = 1)
            x = torch.add(x_p, x_pos)
    
            return x

     

    class token과 position token은 batch size를 받아야 하기에 forward로 꺼내서 cuda 적용하는 방식으로 코딩을 진행했다.

     

     

    3. Self Attention

     

    Norm Scale 값은 논문에서 latent dimension 값 D를 head의 개수로 나눠준 값으로 논문에서 제시하고 있다.

     

    class SelfAttention(nn.Module):
        def __init__(self, out_dim = 768, d = 12):
            super(SelfAttention, self).__init__()
            self.out_dim = out_dim
            self.norm_scale = out_dim // d
            self.q = nn.Linear(in_features=out_dim, out_features=out_dim // d)
            self.k = nn.Linear(in_features=out_dim, out_features=out_dim // d)
            self.v = nn.Linear(in_features=out_dim, out_features=out_dim // d)
            self.soft = nn.Softmax(dim = -1)
    
        def forward(self, x):
            q = self.q(x)
            k = self.k(x)
            v = self.v(x)
            qk = torch.div(torch.matmul(q, torch.transpose(k, 1, 2)), self.norm_scale ** 0.5)
            qk = self.soft(qk)
            qkv = torch.matmul(qk, v)
            return qkv

     

     

    4. MSA(Multi head Self-Attention)

     

    class MultiHeadAttention(nn.Module):
        def __init__(self, out_dim = 768, h = 12):
            super(MultiHeadAttention, self).__init__()
            self.h = h
            self.SA = nn.ModuleList([SelfAttention(out_dim, h) for _ in range(h)])
            self.linear = nn.Linear(in_features=out_dim, out_features=out_dim)
    
        def forward(self, x):
            for i in range(self.h):
                if i == 0:
                    x_cat = self.SA[i](x)
                else:
                    x_cat = torch.cat((x_cat, self.SA[i](x)), dim = -1)
            x = self.linear(x_cat)
            return x

    ModuleList로 넣어줘 시작할 때 concat할 base를 만들어준 후 나머지는 base에 concat하는 방식으로 코딩을 진행했다.

     

     

    5. Encoder

     

    class Encoder(nn.Module):
        def __init__(self, out_dim = 768, h = 12):
            super(Encoder, self).__init__()
            self.norm1 = nn.LayerNorm(out_dim)
            self.act1 = nn.GELU()
            self.mha = MultiHeadAttention(out_dim, h = h)
            self.norm2 = nn.LayerNorm(out_dim)
            self.act2 = nn.GELU()
            self.linear = nn.Linear(in_features=out_dim, out_features=out_dim)
    
        def forward(self, x):
            x_norm = self.norm1(x)
            x_norm = self.act1(x_norm)
            x_norm = self.mha(x_norm)
            x = torch.add(x_norm, x)
            x_norm = self.norm2(x)
            x_norm = self.act2(x_norm)
            x_norm = self.linear(x_norm)
            x = torch.add(x_norm, x)
            return x

     

    Non-Linearity를 위한 Activation Function은 GELU를 적용하도록 했고, Layer Norm을 모든 Layer에 적용하도록 했다.

     

     

     

    6. ViT

     

    class VisionTransformer(nn.Module):
        def __init__(self, device, L = 12, out_dim = 768, h = 12, ML = 3096, num_classes = 10, img_size=224, patch_size = 16, channel_size = 3, batch_size = 16):
            super(VisionTransformer, self).__init__()
            self.batch_size = batch_size
            self.embedding = LinearPatchProjection(device, self.batch_size, out_dim, img_size, patch_size, channel_size)
            self.transencoder = nn.Sequential(*[Encoder(out_dim, h) for _ in range(L)])
            self.flatten = nn.Flatten()
            self.mlphead = nn.Sequential(nn.Linear(((img_size // patch_size) ** 2 + 1) * out_dim , num_classes))
            self.soft = nn.Softmax(dim = 1)
    
        def forward(self, x):
            x = self.embedding(x)
            x = self.transencoder(x)
            x = self.flatten(x)
            x = self.mlphead(x)
            x = self.soft(x)
    
            return x

     

    ML parameter는 원래 scratch에서는 필요하지만 GPU가 딸려서 마지막 classification head에서는 single Layer의 MLP block을 넣었다.

     

    https://ai.dreamkkt.com/65

     

    [논문 리뷰] ViT 살펴보기 2편 - Vision Transformer

    1편에서 ViT를 이해하는데 필요한 Transformer에 대해서 간단하게 알아봤다. https://ai.dreamkkt.com/64 [논문 리뷰] ViT 살펴보기 1편 - Transformer 비전 Task에서 많이 활용되는 ViT(Vision Transformer)를 이..

    ai.dreamkkt.com