Let's build a simple DenseNet implementation based on the paper!
1. Setup
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.optim as optim
import time
import random
import torch.backends.cudnn as cudnn
# Fix every source of randomness for reproducibility
seed = 2022
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
cudnn.benchmark = False      # disable the autotuner so conv algorithms stay fixed
cudnn.deterministic = True   # force deterministic cuDNN kernels
2. DenseBlock
First, let's implement the layer class used inside each block.
class DenseUnitBlock(nn.Module):
    def __init__(self, in_features, out_features):
        super(DenseUnitBlock, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Bottleneck: BN -> ReLU -> 1x1 conv producing 4k channels
        self.batch1 = nn.BatchNorm2d(self.in_features)
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(
            in_channels=self.in_features,
            out_channels=self.out_features * 4,
            kernel_size=1
        )
        # BN -> ReLU -> 3x3 conv reducing back down to k channels
        self.batch2 = nn.BatchNorm2d(self.out_features * 4)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv2d(
            in_channels=self.out_features * 4,
            out_channels=self.out_features,
            kernel_size=3,
            stride=1,
            padding=1
        )
        self.net = nn.Sequential(
            self.batch1, self.relu1, self.conv1, self.batch2, self.relu2, self.conv2
        )
        # 1x1 projection so the skip connection matches the output channels
        self.proj = nn.Conv2d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1
        )

    def forward(self, x):
        out = self.net(x)
        # Residual-style skip connection on top of the dense connectivity
        return out + self.proj(x)
I read the paper as saying the 1x1 conv's output channels should be 4k; since the unit's final output channels are k (= out_features), I implemented the bottleneck width as out_features * 4. I also added an extra skip connection (the 1x1 projection), so that every conv unit carries a skip connection.
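As a quick sanity check (this snippet is my own addition, not part of the original post), we can push a dummy tensor through one unit and confirm the channel bookkeeping:

# Sanity check: a unit taking 28 = k_0 + 1*k input channels
# (k_0 = 16, k = 12) should emit k = 12 channels at the same resolution.
unit = DenseUnitBlock(in_features=28, out_features=12)
x = torch.randn(2, 28, 32, 32)   # (batch, channels, H, W)
y = unit(x)
print(y.shape)                   # torch.Size([2, 12, 32, 32])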
The DenseBlock itself is implemented as follows.
class DenseBlock(nn.Module):
    def __init__(self, L, k_0=16, k=12):
        super(DenseBlock, self).__init__()
        self.k = k      # growth rate
        self.k_0 = k_0  # channels entering the block
        # The i-th unit sees k_0 + i*k input channels (all previous outputs concatenated)
        net_list = [DenseUnitBlock(self.k_0 + i * self.k, self.k) for i in range(L)]
        self.net_list = nn.ModuleList(net_list)

    def forward(self, x):
        input_list = [x]
        for block in self.net_list:
            # Concatenate the block input and every previous unit's output along channels
            t = torch.cat(input_list, dim=1)
            out = block(t)
            input_list.append(out)
        return out
I keep a separate input_list that stores the feature map produced by each layer, so that the next layer receives the concatenation of all previous features as its input. The block then returns the last layer's output.
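A small shape check of my own to make this concrete: with k_0 = 16 and k = 12, the third unit sees 16 + 2*12 = 40 input channels, while the block as written returns only the last unit's k channels, not the k_0 + L*k concatenation that the blocks in the paper expose:

# Shape check (my addition, not from the original post)
block = DenseBlock(L=3, k_0=16, k=12)
x = torch.randn(2, 16, 32, 32)
y = block(x)
print(y.shape)  # torch.Size([2, 12, 32, 32]) -- only the last unit's output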
3. Transition Layers
These layers play the pooling role: they shrink the feature-map size to drop unneeded pixels.
class TransitionLayers(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(TransitionLayers, self).__init__()
        # 1x1 conv to adjust channels, then 2x2 average pooling to halve H and W
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=1,
            kernel_size=1)
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        return x
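A one-line check (again my own addition) shows the layer keeps the requested channel count and halves the spatial resolution:

# Quick check: 2x2 average pooling halves H and W
trans = TransitionLayers(in_channels=12, out_channels=12)
x = torch.randn(2, 12, 32, 32)
print(trans(x).shape)  # torch.Size([2, 12, 16, 16])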
4. Final DenseNet
class DenseNet(nn.Module):
    def __init__(self, in_channels, k, num_classes):
        super().__init__()
        self.num_classes = num_classes
        self.k = k
        self.in_channels = in_channels
        # Stem: BN -> ReLU -> 7x7/2 conv -> 3x3/2 max pool
        self.batch1 = nn.BatchNorm2d(num_features=in_channels)
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=self.k * 2,
            kernel_size=7,
            stride=2,
            padding=0
        )
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.net1 = nn.Sequential(
            self.batch1, self.relu1, self.conv1, self.pool1
        )
        # Four dense blocks (6/12/24/16 units) with transition layers in between.
        # Each block returns k channels, so every later block starts from k_0 = k.
        self.block1 = DenseBlock(L=6, k_0=self.k * 2, k=self.k)
        self.trans1 = TransitionLayers(in_channels=self.k, out_channels=self.k)
        self.block2 = DenseBlock(L=12, k_0=self.k, k=self.k)
        self.trans2 = TransitionLayers(in_channels=self.k, out_channels=self.k)
        self.block3 = DenseBlock(L=24, k_0=self.k, k=self.k)
        self.trans3 = TransitionLayers(in_channels=self.k, out_channels=self.k)
        self.block4 = DenseBlock(L=16, k_0=self.k, k=self.k)
        # Global average pooling so the classifier sees a k-dim vector
        self.pool2 = nn.AdaptiveAvgPool2d(output_size=1)
        self.flat = nn.Flatten()
        self.fc = nn.Linear(in_features=self.k, out_features=self.num_classes)
        self.soft = nn.Softmax(dim=1)
        self.net2 = nn.Sequential(
            self.block1,
            self.trans1,
            self.block2,
            self.trans2,
            self.block3,
            self.trans3,
            self.block4,
            self.pool2,  # the global pooling must sit here; without it the Linear layer's input shape breaks
            self.flat,
            self.fc,
            self.soft
        )

    def forward(self, x):
        x = self.net1(x)
        x = self.net2(x)
        return x
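To see that the pieces fit together, here is a smoke test of my own (not from the post). Because the stem conv uses padding=0 and there are several stride-2 poolings, a 32x32 CIFAR-sized input shrinks to 1x1 before the last transition layer and errors out, so I feed a 224x224 dummy input:

# Smoke test (my addition): full forward pass on an ImageNet-sized input
model = DenseNet(in_channels=3, k=12, num_classes=10)
x = torch.randn(2, 3, 224, 224)
out = model(x)
print(out.shape)       # torch.Size([2, 10])
print(out.sum(dim=1))  # each row sums to 1 because of the final Softmax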
For the full training code, see my Github!
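Since the training script itself lives on Github, here is only a rough sketch of what such a loop could look like; the dataset, hyperparameters, and loss pairing below are my assumptions, not the author's actual settings:

# Minimal training sketch (my assumption of a typical setup, not the
# author's actual code). Note: nn.CrossEntropyLoss expects raw logits,
# so with the Softmax kept inside the model, NLLLoss on log-probabilities
# is the safer pairing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([
    transforms.Resize(224),  # keep the input large enough for the stem and poolings
    transforms.ToTensor(),
])
train_set = torchvision.datasets.CIFAR10(root="./data", train=True,
                                         download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)

model = DenseNet(in_channels=3, k=12, num_classes=10).to(device)
criterion = nn.NLLLoss()  # paired with the log of the model's Softmax output
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

for epoch in range(10):   # hypothetical epoch count
    start = time.time()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        probs = model(images)
        loss = criterion(torch.log(probs + 1e-8), labels)  # log-probs for NLLLoss
        loss.backward()
        optimizer.step()
    print(f"epoch {epoch}: loss {loss.item():.4f}, {time.time() - start:.1f}s")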