Let's do a simple implementation based on the paper!

 

1. Setup

import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
import torch.optim as optim
import time

import random
import torch.backends.cudnn as cudnn

# Fix every source of randomness for reproducibility
seed = 2022
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
cudnn.benchmark = False      # disable autotuning, which is non-deterministic
cudnn.deterministic = True   # force deterministic cuDNN kernels
random.seed(seed)
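
For reference, a minimal sketch of the data pipeline these imports support might look like the following (CIFAR-10 is an assumption here, resized to 224x224 so the input fits the 7x7 stride-2 stem used later in Section 4):

# Minimal sketch of a data pipeline, assuming CIFAR-10
transform = transforms.Compose([
    transforms.Resize(224),  # assumed size; matches the ImageNet-style stem below
    transforms.ToTensor(),
])
train_set = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)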

2. DenseBlock

First, let's implement the layer class used inside a block.

class DenseUnitBlock(nn.Module):
    def __init__(self, in_features, out_features):
        super(DenseUnitBlock, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # BN-ReLU-Conv(1x1) bottleneck expanding to 4k channels, as in the paper
        self.batch1 = nn.BatchNorm2d(self.in_features)
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(
            in_channels=self.in_features,
            out_channels=self.out_features * 4,
            kernel_size=1
        )
        # BN-ReLU-Conv(3x3) producing the final k channels
        self.batch2 = nn.BatchNorm2d(self.out_features * 4)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv2d(
            in_channels=self.out_features * 4,
            out_channels=self.out_features,
            kernel_size=3,
            stride=1,
            padding=1
        )
        self.net = nn.Sequential(
            self.batch1, self.relu1, self.conv1, self.batch2, self.relu2, self.conv2
        )
        # 1x1 projection so the skip connection matches the output channel count
        self.proj = nn.Conv2d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1
        )

    def forward(self, x):
        out = self.net(x)    # bottleneck path
        skip = self.proj(x)  # projected skip connection
        return out + skip

Since the paper describes the 1x1 conv as producing 4k output channels, and the final output of the unit should be k channels, I implemented the bottleneck width as out_features * 4. I also additionally applied a skip connection (with a 1x1 projection to match channels) so that every conv in the unit gets a skip connection.
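
As a quick sanity check (input sizes here are assumed for illustration), the unit keeps the spatial size and maps its k_0 input channels down to k output channels:

# Shape check, with assumed values k_0 = 16, k = 12
unit = DenseUnitBlock(in_features=16, out_features=12)
x = torch.randn(8, 16, 32, 32)
print(unit(x).shape)  # torch.Size([8, 12, 32, 32]): channels 16 -> 12, spatial size kept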

The Dense Block itself is implemented as follows.

class DenseBlock(nn.Module):
    def __init__(self, L, k_0=16, k=12):
        super(DenseBlock, self).__init__()
        self.k = k      # growth rate
        self.k_0 = k_0  # number of channels entering the block
        # the i-th unit receives k_0 + i * k channels (all previous outputs concatenated)
        net_list = [DenseUnitBlock(self.k_0 + i * self.k, self.k) for i in range(L)]
        self.net_list = nn.ModuleList(net_list)

    def forward(self, x):
        input_list = [x]
        for block in self.net_list:
            # concatenate the block input and every previous unit's output
            t = torch.cat(input_list, dim=1)
            out = block(t)
            input_list.append(out)
        # only the last unit's k-channel output is returned
        return out

A separate input_list holds the feature map produced by each layer; before each new layer, everything in the list is concatenated along the channel dimension and fed in. At the end, the last layer's output is returned.
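
A shape check of the whole block (values assumed for illustration): with k_0 = 16 and k = 12, the i-th unit receives 16 + 12i channels, and the block returns the 12-channel output of its last unit:

# Shape check of a full block, assumed values L = 6, k_0 = 16, k = 12
block = DenseBlock(L=6, k_0=16, k=12)
x = torch.randn(8, 16, 32, 32)
print(block(x).shape)  # torch.Size([8, 12, 32, 32]): only the last unit's output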

 

3. Transition Layers

This layer plays a pooling role between dense blocks: a 1x1 convolution adjusts the channels, and 2x2 average pooling shrinks the feature maps to drop redundant pixels.

class TransitionLayers(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(TransitionLayers, self).__init__()
        # 1x1 conv to adjust the channel count
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=1,
            kernel_size=1)
        # 2x2 average pooling halves each spatial dimension
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        return x
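
A quick check (channel counts assumed for illustration): the 1x1 conv maps channels as requested and the average pooling halves each spatial dimension:

# Shape check with assumed channel counts
trans = TransitionLayers(in_channels=12, out_channels=12)
x = torch.randn(8, 12, 32, 32)
print(trans(x).shape)  # torch.Size([8, 12, 16, 16]): spatial size halved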

4. Final DenseNet

class DenseNet(nn.Module):
    def __init__(self, in_channels, k, num_classes):
        super().__init__()
        self.num_classes = num_classes
        self.k = k
        self.in_channels = in_channels
        # Stem: BN-ReLU, 7x7 stride-2 conv to 2k channels, then 3x3 max pooling
        self.batch1 = nn.BatchNorm2d(num_features=in_channels)
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=self.k * 2,
            kernel_size=7,
            stride=2,
            padding=0
        )
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.net1 = nn.Sequential(
            self.batch1, self.relu1, self.conv1, self.pool1
        )

        # Four dense blocks (6, 12, 24, 16 units: the DenseNet-121 layout)
        # with transition layers in between. Since each block here returns
        # only its last unit's k channels, every transition is k -> k.
        self.block1 = DenseBlock(L=6, k_0=self.k * 2, k=self.k)
        self.trans1 = TransitionLayers(in_channels=self.k, out_channels=self.k)
        self.block2 = DenseBlock(L=12, k_0=self.k, k=self.k)
        self.trans2 = TransitionLayers(in_channels=self.k, out_channels=self.k)
        self.block3 = DenseBlock(L=24, k_0=self.k, k=self.k)
        self.trans3 = TransitionLayers(in_channels=self.k, out_channels=self.k)
        self.block4 = DenseBlock(L=16, k_0=self.k, k=self.k)
        self.pool2 = nn.AdaptiveAvgPool2d(output_size=1)
        self.flat = nn.Flatten()
        self.fc = nn.Linear(in_features=self.k, out_features=self.num_classes)
        # Note: drop this Softmax if training with nn.CrossEntropyLoss,
        # which applies log-softmax internally.
        self.soft = nn.Softmax(dim=1)
        self.net2 = nn.Sequential(
            self.block1,
            self.trans1,
            self.block2,
            self.trans2,
            self.block3,
            self.trans3,
            self.block4,
            self.pool2,  # global average pooling to 1x1 before the classifier
            self.flat,
            self.fc,
            self.soft
        )

    def forward(self, x):
        x = self.net1(x)
        x = self.net2(x)
        return x
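
A usage sketch (the 224x224 input size is an assumption: with this stem and three transitions, much smaller inputs such as raw 32x32 CIFAR images shrink below the 2x2 pooling window, hence the resize in the data pipeline above):

# Forward-pass shape check, assumed 224x224 input and 10 classes
model = DenseNet(in_channels=3, k=12, num_classes=10)
x = torch.randn(2, 3, 224, 224)
print(model(x).shape)  # torch.Size([2, 10]): per-class softmax probabilities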

 

 

For the full training code, see the following GitHub!

https://github.com/kkt4828/reviewpaper/blob/857da870c553b1e97872c361808bb1f20b844a22/DenseNet/DenseNet.py

 
