
 

Let's implement a Vision Transformer (ViT) in a simple way.

 

The idea is to reshape the image into a sequence of patches and build the model from an image embedding, multi-head attention, and MLP blocks.

 

1. Setup

import torch
import torch.nn as nn
from torch import Tensor
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.optim as optim

from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce

The einops library is used to make the tensor reshaping a bit easier.
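
For example, Rearrange can cut an image into flattened patches and repeat can broadcast a single token over the batch. A small standalone check using the imports above (the shapes are chosen only for illustration):

imgs = torch.randn(2, 3, 224, 224)                  # toy batch: 2 RGB images of 224 x 224

# split each image into 16 x 16 patches and flatten every patch
to_patches = Rearrange('b c (nh p1) (nw p2) -> b (nh nw) (p1 p2 c)', p1 = 16, p2 = 16)
print(to_patches(imgs).shape)                       # torch.Size([2, 196, 768])

# broadcast one token to every sample in the batch
cls = torch.randn(1, 1, 768)
print(repeat(cls, '() n d -> b n d', b = 2).shape)  # torch.Size([2, 1, 768])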

 

 

2. Image Embedding Class

 

class image_embedding(nn.Module):
    def __init__(self, in_channels = 3, img_size = 224, patch_size = 16, emb_dim = 16 * 16 * 3):
        super().__init__()

        # (B, C, H, W) -> (B, n_patches, patch_size * patch_size * C)
        self.rearrange = Rearrange('b c (num_h p1) (num_w p2) -> b (num_h num_w) (p1 p2 c)', p1 = patch_size, p2 = patch_size)
        self.linear = nn.Linear(in_channels * patch_size * patch_size, emb_dim)

        # learnable classification token prepended to the patch sequence
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_dim))

        # learnable position embedding for n_patches + 1 tokens
        n_patches = img_size * img_size // patch_size ** 2
        self.positions = nn.Parameter(torch.randn(n_patches + 1, emb_dim))


    def forward(self, x):
        batch, channel, height, width = x.shape

        x = self.rearrange(x)                                     # (B, n_patches, p1 * p2 * C)
        x = self.linear(x)                                        # (B, n_patches, emb_dim)

        c = repeat(self.cls_token, '() n d -> b n d', b = batch)  # cls token for every sample
        x = torch.cat((c, x), 1)                                  # (B, n_patches + 1, emb_dim)
        x = torch.add(x, self.positions)                          # add position embedding

        return x
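
As a quick shape check, a random input with the default sizes above gives 196 patch tokens plus the cls token:

emb = image_embedding(in_channels = 3, img_size = 224, patch_size = 16, emb_dim = 768)
out = emb(torch.randn(2, 3, 224, 224))
print(out.shape)   # torch.Size([2, 197, 768]) -> 196 patches + 1 cls token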

 

 

3. Multi-head attention

class multi_head_attention(nn.Module):
    def __init__(self, emb_dim : int = 16 * 16 * 3, num_heads : int = 8, dropout_ratio : float = 0.2, verbose = False, **kwargs):
        super(multi_head_attention, self).__init__()
        self.v = verbose

        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.scaling = (self.emb_dim // num_heads) ** (-0.5)

        self.value = nn.Linear(emb_dim, emb_dim)
        self.key = nn.Linear(emb_dim, emb_dim)
        self.query = nn.Linear(emb_dim, emb_dim)
        self.att_drop = nn.Dropout(dropout_ratio)

        self.linear = nn.Linear(emb_dim, emb_dim)

    def forward(self, x : Tensor) -> Tensor:
        # query, key, value

        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        if self.v: print(Q.size(), K.size(), V.size())

        # sequence length: q = k = v = n_patches + 1, and h * d = emb_dim
        Q = rearrange(Q, 'b q (h d) -> b h q d', h = self.num_heads)
        K = rearrange(K, 'b k (h d) -> b h d k', h = self.num_heads)
        V = rearrange(V, 'b v (h d) -> b h v d', h = self.num_heads)
        if self.v: print(Q.size(), K.size(), V.size())

        # scaled dot-product
        weight = torch.matmul(Q, K)
        weight = weight * self.scaling
        if self.v: print(weight.size())

        attention = torch.softmax(weight, dim = -1)
        attention = self.att_drop(attention)
        if self.v: print(attention.size())

        context = torch.matmul(attention, V)
        context = rearrange(context, 'b h q d -> b q (h d)')
        if self.v: print(context.size())

        x = self.linear(context)
        return x, attention
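
Passing the embedded tokens through the attention block keeps the (batch, tokens, emb_dim) shape and also returns a (batch, heads, tokens, tokens) attention map. A quick sanity check, reusing the shapes from the embedding example:

mha = multi_head_attention(emb_dim = 768, num_heads = 8)
out, att = mha(torch.randn(2, 197, 768))
print(out.shape)   # torch.Size([2, 197, 768])
print(att.shape)   # torch.Size([2, 8, 197, 197])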

 

 

4. MLP Block

class mlp_block(nn.Module):
    def __init__(self, emb_dim : int = 16 * 16 * 3, forward_dim : int = 4, dropout_ratio : float = 0.2, **kwargs):
        super(mlp_block, self).__init__()
        self.linear_1 = nn.Linear(emb_dim, forward_dim * emb_dim)   # expand by forward_dim
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout_ratio)
        self.linear_2 = nn.Linear(forward_dim * emb_dim, emb_dim)   # project back to emb_dim

    def forward(self, x):
        x = self.linear_1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.linear_2(x)

        return x
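
The MLP expands each token by forward_dim and projects it back, so the token shape is unchanged:

mlp = mlp_block(emb_dim = 768, forward_dim = 4)
print(mlp(torch.randn(2, 197, 768)).shape)   # torch.Size([2, 197, 768])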

 

 

5. Encoder Block

class encoder_block(nn.Module):
    def __init__(self, emb_dim : int = 16 * 16 * 3, num_heads : int = 8, forward_dim : int = 4, dropout_ratio : float = 0.2):
        super(encoder_block, self).__init__()
        self.norm_1 = nn.LayerNorm(emb_dim)
        self.mha = multi_head_attention(emb_dim, num_heads, dropout_ratio)

        self.norm_2 = nn.LayerNorm(emb_dim)
        self.mlp = mlp_block(emb_dim, forward_dim, dropout_ratio)

        self.residual_dropout = nn.Dropout(dropout_ratio)

    def forward(self, x):
        # pre-norm multi-head attention with a residual connection
        x2 = self.norm_1(x)
        x2, attention = self.mha(x2)
        x = torch.add(self.residual_dropout(x2), x)

        # pre-norm MLP with a residual connection
        x2 = self.norm_2(x)
        x2 = self.mlp(x2)
        x = torch.add(self.residual_dropout(x2), x)

        return x, attention
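
One encoder block also preserves the token shape and returns the attention map of its attention sub-layer:

block = encoder_block(emb_dim = 768, num_heads = 8)
x, att = block(torch.randn(2, 197, 768))
print(x.shape, att.shape)   # torch.Size([2, 197, 768]) torch.Size([2, 8, 197, 197])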

 

 

6. Model

The final ViT model is put together as follows.

class vision_transformer(nn.Module):
    def __init__(self, in_channels : int = 3, img_size : int = 224, patch_size : int = 16, emb_dim : int = 16 * 16 * 3,
                 n_enc_layers : int = 15, num_heads : int = 3, forward_dim : int = 4, dropout_ratio : float = 0.2, n_classes : int = 1000 ):
        super(vision_transformer, self).__init__()

        self.image_embedding = image_embedding(in_channels, img_size, patch_size, emb_dim)
        encoder_module = [encoder_block(emb_dim, num_heads, forward_dim, dropout_ratio) for _ in range(n_enc_layers)]
        self.encoder_module = nn.ModuleList(encoder_module)

        self.reduce_layer = Reduce('b n e -> b e', reduction = 'mean')
        self.normalization = nn.LayerNorm(emb_dim)
        self.classification_head = nn.Linear(emb_dim, n_classes)

    def forward(self, x):
        x = self.image_embedding(x)

        # pass the tokens through every encoder block, collecting the attention maps
        attentions = []
        for block in self.encoder_module:
            x, attention = block(x)
            attentions.append(attention)

        x = self.reduce_layer(x)          # mean pooling over the token dimension
        x = self.normalization(x)
        x = self.classification_head(x)

        return x, attentions
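
A quick forward pass with a dummy batch (a small sketch; the CIFAR-10-sized arguments are example values chosen here for illustration, not the settings of the repo linked below):

model = vision_transformer(in_channels = 3, img_size = 32, patch_size = 4, emb_dim = 48,
                           n_enc_layers = 6, num_heads = 4, forward_dim = 4,
                           dropout_ratio = 0.1, n_classes = 10)
logits, attentions = model(torch.randn(8, 3, 32, 32))
print(logits.shape)      # torch.Size([8, 10])
print(len(attentions))   # 6 -> one attention map per encoder block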

 

The code for training on CIFAR-10 is shared at the GitHub link below for reference.

 

https://github.com/kkt4828/reviewpaper/blob/master/Transformer/ViT.py
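
The full script is in the repo above; as a rough idea of what the loop looks like, here is a minimal sketch (the transforms and hyperparameters below are assumptions for illustration, not the repo's settings):

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_set = torchvision.datasets.CIFAR10(root = './data', train = True, download = True, transform = transform)
train_loader = DataLoader(train_set, batch_size = 128, shuffle = True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = vision_transformer(img_size = 32, patch_size = 4, emb_dim = 48,
                           n_enc_layers = 6, num_heads = 4, n_classes = 10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = 1e-3)

for epoch in range(10):
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        logits, _ = model(images)            # the model returns (logits, attentions)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'epoch {epoch + 1} loss {loss.item():.4f}')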

 
