티스토리 뷰

 

기존 CNN모델들과 다르게 image patch 처리를 해줘야하는 코드가 추가되었다.

 

1. Setup

 

기존과 동일하다.

 

import torch.nn as nn
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.optim as optim
import time
import numpy as np

import random
import torch.backends.cudnn as cudnn

seed = 2022
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
cudnn.benchmark = False
cudnn.deterministic = True
random.seed(seed)

 

 

2. Linear Projection

 

이미지를 patch로 쪼갠 후 embedding 하는 class이다.

 

class LinearPatchProjection(nn.Module):
    def __init__(self, device, batch_size, out_dim = 768, img_size=224, patch_size = 16, channel_size = 3,):
        super(LinearPatchProjection, self).__init__()
        self.device = device
        self.b = batch_size
        self.p = patch_size
        self.c = channel_size
        self.out_dim = out_dim
        self.n = img_size ** 2 // (patch_size ** 2)


        self.projection = nn.Linear(in_features=self.p**2 * self.c, out_features=self.out_dim)

    def forward(self, x):
        x = x.view(-1, self.n, (self.p ** 2) * self.c)
        x_p = self.projection(x)
        x_cls = nn.Parameter(torch.randn(x_p.size(0), 1, self.out_dim), requires_grad=True).to(device)
        x_pos = nn.Parameter(torch.randn(x_p.size(0), self.n + 1, self.out_dim), requires_grad=True).to(device)
        x_p = torch.concat((x_cls, x_p), dim = 1)
        x = torch.add(x_p, x_pos)

        return x

 

class token과 position token은 batch size를 받아야 하기에 forward로 꺼내서 cuda 적용하는 방식으로 코딩을 진행했다.

 

 

3. Self Attention

 

Norm Scale 값은 논문에서 latent dimension 값 D를 head의 개수로 나눠준 값으로 논문에서 제시하고 있다.

 

class SelfAttention(nn.Module):
    def __init__(self, out_dim = 768, d = 12):
        super(SelfAttention, self).__init__()
        self.out_dim = out_dim
        self.norm_scale = out_dim // d
        self.q = nn.Linear(in_features=out_dim, out_features=out_dim // d)
        self.k = nn.Linear(in_features=out_dim, out_features=out_dim // d)
        self.v = nn.Linear(in_features=out_dim, out_features=out_dim // d)
        self.soft = nn.Softmax(dim = -1)

    def forward(self, x):
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        qk = torch.div(torch.matmul(q, torch.transpose(k, 1, 2)), self.norm_scale ** 0.5)
        qk = self.soft(qk)
        qkv = torch.matmul(qk, v)
        return qkv

 

 

4. MSA(Multi head Self-Attention)

 

class MultiHeadAttention(nn.Module):
    def __init__(self, out_dim = 768, h = 12):
        super(MultiHeadAttention, self).__init__()
        self.h = h
        self.SA = nn.ModuleList([SelfAttention(out_dim, h) for _ in range(h)])
        self.linear = nn.Linear(in_features=out_dim, out_features=out_dim)

    def forward(self, x):
        for i in range(self.h):
            if i == 0:
                x_cat = self.SA[i](x)
            else:
                x_cat = torch.cat((x_cat, self.SA[i](x)), dim = -1)
        x = self.linear(x_cat)
        return x

ModuleList로 넣어줘 시작할 때 concat할 base를 만들어준 후 나머지는 base에 concat하는 방식으로 코딩을 진행했다.

 

 

5. Encoder

 

class Encoder(nn.Module):
    def __init__(self, out_dim = 768, h = 12):
        super(Encoder, self).__init__()
        self.norm1 = nn.LayerNorm(out_dim)
        self.act1 = nn.GELU()
        self.mha = MultiHeadAttention(out_dim, h = h)
        self.norm2 = nn.LayerNorm(out_dim)
        self.act2 = nn.GELU()
        self.linear = nn.Linear(in_features=out_dim, out_features=out_dim)

    def forward(self, x):
        x_norm = self.norm1(x)
        x_norm = self.act1(x_norm)
        x_norm = self.mha(x_norm)
        x = torch.add(x_norm, x)
        x_norm = self.norm2(x)
        x_norm = self.act2(x_norm)
        x_norm = self.linear(x_norm)
        x = torch.add(x_norm, x)
        return x

 

Non-Linearity를 위한 Activation Function은 GELU를 적용하도록 했고, Layer Norm을 모든 Layer에 적용하도록 했다.

 

 

 

6. ViT

 

class VisionTransformer(nn.Module):
    def __init__(self, device, L = 12, out_dim = 768, h = 12, ML = 3096, num_classes = 10, img_size=224, patch_size = 16, channel_size = 3, batch_size = 16):
        super(VisionTransformer, self).__init__()
        self.batch_size = batch_size
        self.embedding = LinearPatchProjection(device, self.batch_size, out_dim, img_size, patch_size, channel_size)
        self.transencoder = nn.Sequential(*[Encoder(out_dim, h) for _ in range(L)])
        self.flatten = nn.Flatten()
        self.mlphead = nn.Sequential(nn.Linear(((img_size // patch_size) ** 2 + 1) * out_dim , num_classes))
        self.soft = nn.Softmax(dim = 1)

    def forward(self, x):
        x = self.embedding(x)
        x = self.transencoder(x)
        x = self.flatten(x)
        x = self.mlphead(x)
        x = self.soft(x)

        return x

 

ML parameter는 원래 scratch에서는 필요하지만 GPU가 딸려서 마지막 classification head에서는 single Layer의 MLP block을 넣었다.

 

https://ai.dreamkkt.com/65

 

[논문 리뷰] ViT 살펴보기 2편 - Vision Transformer

1편에서 ViT를 이해하는데 필요한 Transformer에 대해서 간단하게 알아봤다. https://ai.dreamkkt.com/64 [논문 리뷰] ViT 살펴보기 1편 - Transformer 비전 Task에서 많이 활용되는 ViT(Vision Transformer)를 이..

ai.dreamkkt.com