[AITech][Image Classification][P stage] 20220305 - Code Review

15 minute read


Code Review

지난 [Code Template] 포스팅에서는 AITech에서 진행한 대회의 개요와 데이터, 그리고 PyTorch Project의 코드 템플릿에 대해 살펴보았습니다.

이번 포스팅에서는 지난 포스팅에서 소개한 각 파이썬 파일을 코드 레벨에서 살펴봄으로써, 어떻게 작성해야 하는지 알아보겠습니다.

EDA.ipynb

이번에 진행한 대회는 이미지 분류 대회인만큼, 이미지에 대한 EDA를 수행합니다. 이미지는 다른 종류의 데이터들보다는 EDA 단계에서 할 일이 많지는 않지만, 어떤 것을 수행할 수 있는지 알아보겠습니다.

EDA 방법은 크게 3가지 방법으로 나눌 수 있습니다.

  • Input이 될 X에 대한 분석
  • Target이 될 y에 대한 분석
  • X, y 관계를 확인할 수 있는 분석

그리고 위 세 가지를 달성하기 위해 여러 visualization 기법을 사용할 수 있습니다. EDA를 잘하기 위해서는, 데이터에 대한 이해, 데이터의 종류에 따라 사용할 수 있는 plot에 대한 이해, 그리고 그 plot을 구현할 수 있는 코딩 능력 등이 필요합니다.

Import and Settings

이미지 데이터를 분석할 때는 다음과 같은 module들을 사용합니다.

  • numpy와 pandas: 데이터 처리/가공
  • matplotlib.pyplot과 seaborn: 데이터 시각화
  • PIL.Image와 cv2: 이미지 객체를 다룸
  • os: 시스템 레벨에서 사용하는 명령들(파일/디렉토리 확인 및 생성, 경로 생성 등)
  • %matplotlib inline: matplotlib 시각화 시 더 예쁘게 보여줌

이외에도 tqdm, warnings 등을 추가로 import 했습니다.

import os
from tqdm.notebook import tqdm

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib.image as img
import seaborn as sns
from PIL import Image
import cv2

import warnings
warnings.filterwarnings('ignore')

%matplotlib inline

그리고 EDA 시에는 보통 맨 위에 data path를 지정해놓습니다.

TRAIN_DATA_PATH = '../../input/data/train/train.csv'
TRAIN_IMG_PATH = '../../input/data/train/images'
SUBMISSION_PATH = '../../input/data/eval/info.csv'
EVAL_IMG_PATH = '../../input/data/eval/images'

Data Loading and Check

이후에는 데이터에 대한 정보가 기록되어 있는 csv 파일을 불러와 확인했습니다.

# 데이터 가져오기
train_data = pd.read_csv(TRAIN_DATA_PATH)
train_data

image-20220306164834937

csv 파일을 불러온 후 결측치를 확인하거나 여러 통계치를 확인할 수 있습니다.

각 클래스별 데이터의 개수를 세는 것도 여기서 할 수 있습니다. 따로 코드를 올리지는 않지만, 이번 대회에서는 csv 파일을 약간 가공해서 클래스별 개수를 확인해야 했습니다.

Plotting

csv 파일을 불러왔다면, 이를 이용해 plot을 그릴 수 있습니다. 여기서는 여러분의 지식과 창의력이 EDA를 풍성하게 해주며, 정답을 없습니다. Plotting을 통해 데이터의 특징을 시각적으로 명확히 파악할 수 있습니다. 이번 대회의 경우 데이터의 불균형이 매우 심했습니다.

제가 진행한 EDA중 몇 개를 소개해보겠습니다.

# gender 피쳐 분포 보기
fig = plt.figure(figsize=(12,7))
ax = fig.add_subplot(111)

sns.countplot(x='gender', data=train_data, 
    order=sorted(train_data['gender'].unique()), 
    ax=ax
    )

for idx, val in enumerate(train_data['gender'].sort_index().value_counts()):
    ax.text(x=idx, y=val+3, s=val, 
        va='bottom', ha='center', 
        fontsize=11, fontweight='semibold')

for s in ['top', 'right']:
    ax.spines[s].set_visible(False)



plt.show()

image-20220306165648425

# age 피쳐 분포 보기
fig, axes = plt.subplots(1,2,figsize=(15,7), sharex=True)

sns.histplot(x='age', data=train_data, ax=axes[0], 
    hue='gender',
    element='step', 
    multiple='stack')
sns.kdeplot(x='age', data=train_data, ax=axes[1], 
    hue='gender', 
    fill=True, 
    multiple="layer", 
    cumulative=False)

plt.show()

image-20220306165659190

# gender-age 간 분포
fig, axes = plt.subplots(1, 2, figsize=(15,7))

sns.countplot(x='age_range',
    data=train_data_with_agerange, 
    order=sorted(train_data_with_agerange['age_range'].unique()),
    hue='gender', 
    ax=axes[0])

sns.countplot(x='gender', data=train_data_with_agerange, 
    order=sorted(train_data['gender'].unique()), 
    hue='age_range', 
    hue_order=sorted(train_data['age_range'].unique()), 
    color='red', 
    saturation=1, 
    ax=axes[1])

plt.show()

image-20220306165728196

Check Images

이미지 데이터인 만큼, 이미지를 직접 불러와 확인하는 작업도 필요합니다.

이번 대회의 데이터셋은 각 사람마다 폴더 안에 5장의 mask 사진, 1장의 incorrect_mask 사진, 1장의 normal 사진이 있었습니다.

# 이미지 보기
# sampling
sample = train_data.sample()

# sampling된 사람의 7개 사진의 경로 가져오기
sample_img_path = sample.path.values[0] # values: value, dtype
sample_img_list = [img for img in os.listdir(TRAIN_IMG_PATH+'/'+sample_img_path) if '._' not in img]
sample_img_list = sorted(sample_img_list)

# 이미지 출력
fig = plt.figure(figsize=(18,8))

for i, filename in enumerate(sample_img_list):
    img_path = os.path.join(TRAIN_IMG_PATH, sample_img_path, filename)
    print(img_path)
    img = cv2.imread(img_path)
    ax = fig.add_subplot(2, 4, i+1)
    ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    ax.axis('off')
    ax.set_title(filename.split('.')[0], fontsize=15)

plt.tight_layout()
plt.show()

image-20220306165917028

이를 확장하여 5명의 사람에 대해 이미지를 출력할 수도 있습니다.

# 5개 샘플에 대해 이미지 출력
sample = train_data.sample(5)
sample_img_path = sample.path.values
img_list = []

for img in sample_img_path:
    lists = []
    imgs = [img for img in os.listdir(TRAIN_IMG_PATH+'/'+img) if '._' not in img]
    
    for data in imgs:
        path = TRAIN_IMG_PATH+'/'+img+'/'+data
        img_kind = data.split('.')[0]
        lists.append(path)
    lists = sorted(lists)
    img_list.append(lists)

fig = plt.figure(figsize=(35, 20))

i=1

for files in img_list:
    for filename in files:
        img = cv2.imread(filename)
        ax = fig.add_subplot(5, 7, i)
        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        ax.axis('off')
        
        if i < 8:
            ax.set_title(filename.split('/')[-1].split('.')[0], fontsize=30, fontdict={'weight':'semibold'})
        
        i += 1

plt.tight_layout()
plt.show()

image-20220306170254608


dataset.py

dataset.py에서는 Dataset, Augmentation, Custom Transform 등을 위한 클래스들이 정의되어 있습니다.

Import and Settings

dataset.py 파일의 맨 위에는 import와 image file checking이 이루어집니다.

Import

os, numpy, PIL.Image 등의 라이브러리는 기본적으로 불러옵니다.

이에 더해 torch, torchvision, albumentations 등의 라이브러리도 불러옵니다. 이외에는 자신이 필요한대로 그때그때 추가하면 됩니다.

import os
import random
from collections import defaultdict
from enum import Enum
from typing import Tuple, List

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset, Subset, random_split
from torchvision import transforms
from torchvision.transforms import *

import albumentations

Settings

이미지 파일을 불러올 때는, 그 이미지 파일이 적절한 파일인지 검사하는 과정이 필수적입니다. 이를 해주지 않으면, 나중에 뜻하지 않은 에러에 애를 먹을 수도 있습니다.

IMG_EXTENSIONS = [
    ".jpg", ".JPG", ".jpeg", ".JPEG", ".png",
    ".PNG", ".ppm", ".PPM", ".bmp", ".BMP",
]

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

Custom Transform

Custom Transform은 Augmentation 클래스에서 사용하고, Augmentation은 Dataset 클래스에서 사용하게 됩니다. 따라서 Custom Transform => Augmentation => Dataset 순으로 소개하겠습니다.

torchvisoin.transform이나 albumentations, imgaug 등의 라이브러리들에는 다양한 transform 클래스들이 있습니다. 이미 존재하는 것 외에, 자신만의 transform 클래스를 만들어 적용할 수 있습니다.

class AddGaussianNoise(object):
    """
        transform 에 없는 기능들은 이런식으로 __init__, __call__, __repr__ 부분을
        직접 구현하여 사용할 수 있습니다.
    """
    def __init__(self, mean=0., std=1., p=0.5):
        self.std = std
        self.mean = mean
        self.p = p

    def __call__(self, tensor):
        if torch.rand(1)[0] < self.p:
            return tensor + torch.randn(tensor.size()) * self.std + self.mean
        else:
            return tensor

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

Custom Transform 클래스를 생성할 때는 위와 같이 __init__, __call__, __repr__ 매직 메서드들을 구현해줍니다.

Augmentation

앞서 소개한 augmentation 라이브러리들이나 직접 만든 transform들을 이용해 Augmentation 클래스를 정의할 수 있습니다.

class CustomAugmentation:
    def __init__(self, resize, mean, std, **args):
        self.transform = transforms.Compose([
            CenterCrop((320, 256)),
            Resize(resize, Image.BILINEAR),
            ColorJitter(0.1, 0.1, 0.1, 0.1),
            ToTensor(),
            Normalize(mean=mean, std=std),
            AddGaussianNoise()
        ])

    def __call__(self, image):
        return self.transform(image)

Custom Augmentation 클래스를 정의할 때는 __init__에서 self.transform을 정의하고, __call__에서 self.transform이 적용된 image 객체를 반환합니다. Input으로 들어오는 image는 PIL Image 객체여야 하고, output으로 반환되는 객체는 torch.Tensor 객체입니다.

여기서 궁금한 것이, __init__의 파라미터들입니다. resize는 해당 모델이 요구하는 input size에 맞추거나 한다고 치고, mean과 std에는 무엇을 전달해야 할까요? 이에 대한 것은 다음 부분에서 다루도록 하겠습니다.

Dataset

드디어 dataset 클래스입니다. 여기서는 Dataset은 클래스로 만들고, DataLoader는 따로 만들지 않고 torch.data에서 제공하는 클래스를 바로 사용합니다. 어찌됐든, Dataset이나 DataLoader 클래스를 커스텀으로 만들 때는 보통 BaseDataset(BaseDataLoader)을 하나 생성합니다. 모든 dataset 클래스들에서 필요한 것들이 정의되어 있는 base class를 만들고, 이 base class를 상속하여 다양한 형태의 dataset 클래스를 정의합니다.

BaseDataset

전체 코드는 아니고, 핵심을 이해하기 위한 코드들만 남겨두었습니다.

  • __init__: 이미지를 PIL.Image 객체로 불러와서 저장하고, self.XX 프로퍼티들을 저장합니다.
    • setup: 데이터 경로로부터 이미지를 읽어와 저장해두는 과정을 진행합니다.
    • calc_statistics: 데이터 normalizing을 위해 필요한 mean과 std를 계산합니다. 여기서 계산된 mean과 std가 앞서 얘기한 Augmentation 클래스의 파라미터인 mean, std로 전달되어 normalize가 진행됩니다.
  • __getitem__: 인자로 전달받은 index에 해당하는 이미지를 불러오고, transform이 적용된 이미지와 label을 함께 return합니다. 이때 반환되는 값은 torch.Tensor입니다.
    • set_transform: __getitem__ 메서드에서 transform이 적용되기 위해서는 set_transform 메서드를 통해 self.transform이 지정되어 있어야 합니다. set_transform 메서드에 전달되는 것이 앞서 정의한 Augmentation 클래스입니다.
    • denormalize_image: 편의 함수로, transform된 이미지를 다시 원본 이미지로 변환해 반환해주는 클래스입니다.
  • __len__: 데이터 셋의 총 크기를 반환해줍니다.
  • split_dataset: 호출하면 dataset을 self.val_ratio비율만큼 train set/validation set으로 나누어 반환해줍니다.
class BaseDataset(Dataset):
    num_classes = 3 * 2 * 3

    # ...

    image_paths = []
    mask_labels = []
    gender_labels = []
    age_labels = []

    def __init__(self, data_dir, mean=(0.548, 0.504, 0.479), std=(0.237, 0.247, 0.246), val_ratio=0.2):
        self.data_dir = data_dir
        self.mean = mean
        self.std = std
        self.val_ratio = val_ratio

        self.transform = None
        self.setup()
        self.calc_statistics()

    def setup(self):
        profiles = os.listdir(self.data_dir)
        for profile in profiles:
            if profile.startswith("."):  # "." 로 시작하는 파일은 무시합니다
                continue

            img_folder = os.path.join(self.data_dir, profile)
            for file_name in os.listdir(img_folder):
                _file_name, ext = os.path.splitext(file_name)
                if _file_name not in self._file_names:  # "." 로 시작하는 파일 및 invalid 한 파일들은 무시합니다
                    continue

                img_path = os.path.join(self.data_dir, profile, file_name)
                mask_label = self._file_names[_file_name]

                id, gender, race, age = profile.split("_")
                gender_label = GenderLabels.from_str(gender)
                age_label = AgeLabels.from_number(age)

                self.image_paths.append(img_path)
                self.mask_labels.append(mask_label)
                self.gender_labels.append(gender_label)
                self.age_labels.append(age_label)

    def calc_statistics(self):
        has_statistics = self.mean is not None and self.std is not None
        if not has_statistics:
            print("[Warning] Calculating statistics... It can take a long time depending on your CPU machine")
            sums = []
            squared = []
            for image_path in self.image_paths[:3000]:
                image = np.array(Image.open(image_path)).astype(np.int32)
                sums.append(image.mean(axis=(0, 1)))
                squared.append((image ** 2).mean(axis=(0, 1)))

            self.mean = np.mean(sums, axis=0) / 255
            self.std = (np.mean(squared, axis=0) - self.mean ** 2) ** 0.5 / 255

    def __getitem__(self, index):
        assert self.transform is not None, ".set_tranform 메소드를 이용하여 transform 을 주입해주세요"

        image = self.read_image(index)
        mask_label = self.get_mask_label(index)
        gender_label = self.get_gender_label(index)
        age_label = self.get_age_label(index)
        multi_class_label = self.encode_multi_class(mask_label, gender_label, age_label)

        image_transform = self.transform(image)
        return image_transform, multi_class_label
    
    def set_transform(self, transform):
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)
    
    # ...
    
    @staticmethod
    def denormalize_image(image, mean, std):
        img_cp = image.copy()
        img_cp *= std
        img_cp += mean
        img_cp *= 255.0
        img_cp = np.clip(img_cp, 0, 255).astype(np.uint8)
        return img_cp

    # ...

    def split_dataset(self) -> Tuple[Subset, Subset]:
        """
        데이터셋을 train 과 val 로 나눕니다.
        pytorch 내부의 torch.utils.data.random_split 함수를 사용하여
        torch.utils.data.Subset 클래스 둘로 나눕니다.
        구현이 어렵지 않으니 구글링 혹은 IDE (e.g. pycharm) 의 navigation 기능을 통해 코드를 한 번 읽어보는 것을 추천드립니다^^
        """
        n_val = int(len(self) * self.val_ratio)
        n_train = len(self) - n_val
        train_set, val_set = random_split(self, [n_train, n_val])
        return train_set, val_set

위와 같은 base dataset이 잘 정의되어 있으면, 다른 dataset 클래스들은 이 클래스를 상속하여 일부 메서드들을 오버라이딩하여 쉽게 정의할 수 있습니다.

TestDataset

테스트용 데이터셋을 위한 TestDataset 클래스는 base dataset 클래스를 상속받는 형태가 아닌, 별도로 작성해야 합니다.

class TestDataset(Dataset):
    def __init__(self, img_paths, resize, mean=(0.548, 0.504, 0.479), std=(0.237, 0.247, 0.246)):
        self.img_paths = img_paths
        self.transform = transforms.Compose([
            Resize(resize, Image.BICUBIC),
            ToTensor(),
            Normalize(mean=mean, std=std),
        ])

    def __getitem__(self, index):
        image = Image.open(self.img_paths[index])

        if self.transform:
            image = self.transform(image)
        return image

    def __len__(self):
        return len(self.img_paths)


loss.py

Loss 클래스들이 정의되어 있는 파일입니다.

Loss 클래스

Loss 클래스는 nn.Module 클래스를 상속받는 클래스입니다. 따라서 기본적으로 __init__forward 함수를 구현해줘야 합니다.

예를 들어 Focal Loss는 다음과 같이 구현할 수 있습니다.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class FocalLoss(nn.Module):
    def __init__(self, weight=None,
                 gamma=2., reduction='mean'):
        nn.Module.__init__(self)
        self.weight = weight
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, input_tensor, target_tensor):
        log_prob = F.log_softmax(input_tensor, dim=-1)
        prob = torch.exp(log_prob)
        return F.nll_loss(
            ((1 - prob) ** self.gamma) * log_prob,
            target_tensor,
            weight=self.weight,
            reduction=self.reduction
        )

다양한 Loss들에 대한 코드가 이미 github나 stack overflow에 많이 있어서, 그 코드들을 가져와 loss.py 파일 내에 클래스로 정의하고 사용하면 됩니다.

편의 함수

이번에 제공된 baseline code에는 아래와 같은 편의 함수들이 포함되어 있었습니다. train.py와 같은 곳에서 loss 객체를 생성할 때, 클래스를 직접 호출하는 것이 아니라 create_criterion에 인자로 criterion_name과 criterion parameters를 dictionary 형태로 전달하면서 호출하여 객체를 생성합니다.

이것이 꼭 필요한 것인지 아직 잘 모르겠지만, 적절히 사용하면 더 좋은 코드를 작성할 수 있으리라 생각합니다.

# Loss 목록
_criterion_entrypoints = {
    'cross_entropy': nn.CrossEntropyLoss,
    'focal': FocalLoss,
    'label_smoothing': LabelSmoothingLoss,
    'f1': F1Loss,
    'ldam': LDAMLoss,
    'custom_ldam': CustomLDAMLoss,
    'weighted_cross_entropy': WeightedCrossEntropyLoss
}

# criterion 클래스 반환
def criterion_entrypoint(criterion_name):
    return _criterion_entrypoints[criterion_name]

# criterion이 정의되어 있는지 확인
def is_criterion(criterion_name):
    return criterion_name in _criterion_entrypoints

# 생성된 criterion 객체 반환
def create_criterion(criterion_name, **kwargs):
    if is_criterion(criterion_name):
        create_fn = criterion_entrypoint(criterion_name)
        criterion = create_fn(**kwargs)
    else:
        raise RuntimeError('Unknown loss (%s)' % criterion_name)
    return criterion


model.py

model.py 파일도 loss.py 파일이 작성되는 형태와 비슷합니다. Model class도 nn.Module 클래스를 상속하고, __init__forward 메서드를 정의합니다.

이때, Model 클래스는 직접 구현하거나 pretrained model을 가져오는 형태가 가능합니다. 두 경우에 코드는 아래와 같이 작성됩니다.

Custom Model

여기서는 간단한 모델이기 때문에 하나의 클래스 내에 모든 코드를 작성했지만, 큰 모델의 경우 여러 클래스로 나누어 각 부분을 정의하고 최종 모델 클래스 하나에서 각 클래스들을 이용해 모델을 생성하는 식으로 구현됩니다.

class BaseModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.25)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)

        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)

        x = self.conv3(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout2(x)

        x = self.avgpool(x)
        x = x.view(-1, 128)
        return self.fc(x)

Pretrained Model

torchvision.models나 timm에서 pretrained model을 가져오는 형태로 모델 클래스를 정의할 수도 있습니다.

이 경우에 인자로 전달받은 num_classes를 이용해 마지막 classification layer를 변경해주어야 합니다.

  • Timm library
class TimmEfficientNetB4(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        
        self.model = timm.create_model('efficientnet_b4', pretrained=True, num_classes=num_classes)
        
        # for param in self.model.parameters():
        #     param.require_grads = False
        
        in_features = self.model.classifier.in_features # 1536
        self.model.classifier = nn.Sequential(
            nn.BatchNorm1d(in_features), 
            nn.Dropout(0.5), 
            nn.Linear(in_features=in_features, out_features=num_classes, bias=False)
        )


    def forward(self, x):
        x = self.model.forward(x)
        return x
  • torchvision library
class TVResNext50(nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        self.model = models.resnext50_32x4d(pretrained=True)
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, out_features=num_classes, bias=True)
        torch.nn.init.xavier_uniform_(self.model.fc.weight)
        stdv = 1. / math.sqrt(self.model.fc.weight.size(1))
        self.model.fc.bias.data.uniform_(-stdv, stdv)

    def forward(self, x):
        x = self.model.forward(x)
        return x


train.py

가장 복잡하게 느껴지는 train.py 파일입니다. 이름처럼 모델의 학습을 진행합니다. 코드 자체는 300줄 정도 되는데, 일단 어떤 흐름으로 작성되어 있는지부터 봅시다.

# import modules 
import argparse
# ...
wandb.init(project="level 1-p stage", entity="wowo0709")

# seeding 
def seed_everything(seed):
    # ...

# Image visualization in Tensorboard 
def grid_image(np_images, gts, preds, n=16, shuffle=False):
    # ...
    
# Make automatically incrementing path 
def increment_path(path, exist_ok=False):
    # ...

# train
def train(data_dir, model_dir, args):
    # -- settings
    # ...

    # -- dataset
    # ...

    # -- augmentation
    # ...

    # -- data_loader
    # ...

    # -- model
    # ...

    # -- loss, optimizer, scheduler
    # ...

    # -- compile options
    # ...

    # -- logging
    # ...

    # -- training&validating
    best_val_acc = 0
    best_val_loss = np.inf
    for epoch in range(args.epochs):
        # train loop
        # ...

        # val loop
        # ...

        # early stopping
        # ...

# main
if __name__ == '__main__':
    # -- arguments
    parser = argparse.ArgumentParser()
    
    # -- print verbose
    from dotenv import load_dotenv
    import os
    load_dotenv(verbose=True)

    # -- register arguments
    parser.add_argument('--seed', type=int, default=42, help='random seed (default: 42)')
    # ...
    data_dir = args.data_dir
    model_dir = args.model_dir
    
    # -- call train function 
    train(data_dir, model_dir, args)

전체 학습 코드를 한 눈에 보려 하면 그 길이에 압도 당하지만, 위와 같이 흐름만 본다면 충분히 납득이 가는 흐름으로 작성되어 있습니다.

다른 부분들은 제외하고, train function 내의 각 부분들에 대한 코드가 어떻게 작성되는지 보겠습니다.

settings

settings 파트에서는 다음의 것들을 합니다.

  • 실험 과정 고정을 위한 seeding
  • 각종 로그 및 모델 데이터 저장 경로 생성
  • cpu/gpu 사용 여부
    # -- settings
    seed_everything(args.seed)
    save_dir = increment_path(os.path.join(model_dir, args.name))
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

dataset

train/validation dataset을 생성합니다.

    # -- dataset
    dataset_module = getattr(import_module("dataset"), args.dataset)  # default: BaseAugmentation
    dataset = dataset_module(
        data_dir=data_dir,
    )
    train_set, val_set = dataset.split_dataset()
    num_classes = dataset.num_classes  # 18

augmentation

사용할 augmentation 객체를 생성하고 dataset의 set_transform 메서드를 호출하며 인자로 전달합니다.

    # -- augmentation
    transform_module = getattr(import_module("dataset"), args.augmentation)  # default: BaseAugmentation
    transform = transform_module(
        resize=args.resize,
        mean=dataset.mean,
        std=dataset.std,
    )
    dataset.set_transform(transform)

data_loader

앞서 만든 train/validation set을 이용해 train/validation DataLoader 객체를 생성합니다.

    # -- data_loader
    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        num_workers=multiprocessing.cpu_count()//2,
        shuffle=True,
        pin_memory=use_cuda,
        drop_last=True,
    )

    val_loader = DataLoader(
        val_set,
        batch_size=args.valid_batch_size,
        num_workers=multiprocessing.cpu_count()//2,
        shuffle=False,
        pin_memory=use_cuda,
        drop_last=True,
    )

model

model 객체를 생성합니다.

    # -- model
    model_module = getattr(import_module("model"), args.model)  # default: BaseModel
    model = model_module(
        num_classes=num_classes
    ).to(device)
    model = torch.nn.DataParallel(model)

loss, optimizer, scheduler

loss, optimizer, lr scheduler 객체를 생성합니다.

    # -- loss & metric
    criterion = create_criterion(args.criterion)  # default: cross_entropy
    opt_module = getattr(import_module("torch.optim"), args.optimizer)  # default: SGD
    optimizer = opt_module(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr,
        weight_decay=1e-3 # 5e-4
    )
    scheduler = StepLR(optimizer, args.lr_decay_step, gamma=0.5)

compile options

Early Stopping 등 학습 과정에서 사용할 compile option용 객체들을 생성합니다.

    # -- compile options
    early_stopping = EarlyStopping(patience=7, verbose=True, path=os.path.join(save_dir, 'early_stopping.pth'))

logging

Tensorboard, wandb 등 logging을 위한 파일을 저장하는 코드를 작성합니다.

    # -- logging
    # Tensorboard logging
    logger = SummaryWriter(log_dir=save_dir)
    with open(os.path.join(save_dir, 'config.json'), 'w', encoding='utf-8') as f:
        json.dump(vars(args), f, ensure_ascii=False, indent=4)
        # wandb logging
        wandb.config = f

training&validating

위의 코드들을 작성하고 나면 최종적으로 training과 validating 코드를 작성합니다. Train과 Validation을 위한 loop 코드도 매우 다양하게 존재합니다. 따라서 여기서 제시하는 코드는 하나의 좋은 예시이고, 개인마다 좀 더 자신에게 맞는 혹은 상황에 맞는 코드로 수정해도 좋습니다.

train loop

		# train loop
        model.train()
        loss_value = 0
        matches = 0
        for idx, train_batch in enumerate(train_loader):
            inputs, labels = train_batch
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outs = model(inputs)
            preds = torch.argmax(outs, dim=-1)
            loss = criterion(outs, labels)

            loss.backward()
            optimizer.step()

            loss_value += loss.item()
            matches += (preds == labels).sum().item()
            if (idx + 1) % args.log_interval == 0:
                train_loss = loss_value / args.log_interval
                train_acc = matches / args.batch_size / args.log_interval
                current_lr = get_lr(optimizer)
                print(
                    f"Epoch[{epoch+1}/{args.epochs}]({idx + 1}/{len(train_loader)}) || "
                    f"training loss {train_loss:4.4} || training accuracy {train_acc:4.2%} || lr {current_lr}"
                )
                logger.add_scalar("Train/loss", train_loss, epoch * len(train_loader) + idx)
                logger.add_scalar("Train/accuracy", train_acc, epoch * len(train_loader) + idx)

                loss_value = 0
                matches = 0

        scheduler.step()

        # weights&biases logging - train
        wandb.log({'train_accuracy': train_acc, 'train_loss': train_loss})

validation loop

        # val loop
        with torch.no_grad():
            print("Calculating validation results...")
            model.eval()
            val_loss_items = []
            val_acc_items = []
            figure = None
            for val_batch in val_loader:
                inputs, labels = val_batch
                inputs = inputs.to(device)
                labels = labels.to(device)

                outs = model(inputs)
                preds = torch.argmax(outs, dim=-1)

                loss_item = criterion(outs, labels).item()
                acc_item = (labels == preds).sum().item()
                val_loss_items.append(loss_item)
                val_acc_items.append(acc_item)

                if figure is None:
                    inputs_np = torch.clone(inputs).detach().cpu().permute(0, 2, 3, 1).numpy()
                    inputs_np = dataset_module.denormalize_image(inputs_np, dataset.mean, dataset.std)
                    figure = grid_image(
                        inputs_np, labels, preds, n=16, shuffle=args.dataset != "MaskSplitByProfileDataset"
                    )

            val_loss = np.sum(val_loss_items) / len(val_loader)
            val_acc = np.sum(val_acc_items) / len(val_set)
            best_val_loss = min(best_val_loss, val_loss)
            if val_acc > best_val_acc:
                print(f"New best model for val accuracy : {val_acc:4.2%}! saving the best model..")
                torch.save(model.module.state_dict(), f"{save_dir}/best.pth")
                best_val_acc = val_acc
            torch.save(model.module.state_dict(), f"{save_dir}/last.pth")
            print(
                f"[Val] acc : {val_acc:4.2%}, loss: {val_loss:4.2} || "
                f"best acc : {best_val_acc:4.2%}, best loss: {best_val_loss:4.2}"
            )
            logger.add_scalar("Val/loss", val_loss, epoch)
            logger.add_scalar("Val/accuracy", val_acc, epoch)
            logger.add_figure("results", figure, epoch)
            print()

        # weights&biases logging - validation
        wandb.log({'val_accuracy': val_acc, 'val_loss': val_loss})

compile options (early stopping)

        # early stopping
        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print("Early Stopping")
            break


inference.py

inference.py 파일도 터미널에서 직접적으로 실행하는 파일인만큼 argument들을 등록하는 코드가 존재합니다. 해당 부분을 제외하면 inference.py 는 크게 다음의 것들을 수행합니다.

  • 저장된 모델 불러오기
  • 추론할 데이터 불러오기
  • 모델 추론 후 결과 출력(저장)

train.py 파일의 흐름을 이해했다면, inference.py의 흐름은 쉽게 이해할 수 잇을 것입니다.

load_model

def load_model(saved_model, num_classes, device):
    model_cls = getattr(import_module("model"), args.model)
    model = model_cls(
        num_classes=num_classes
    )

    # tarpath = os.path.join(saved_model, 'best.tar.gz')
    # tar = tarfile.open(tarpath, 'r:gz')
    # tar.extractall(path=saved_model)

    try:
        model_path = os.path.join(saved_model, 'best.pth')
        model.load_state_dict(torch.load(model_path, map_location=device))
    except: # module. prefix -> nn.Parallel
        model = torch.nn.DataParallel(model)
        model_path = os.path.join(saved_model, 'best.pth')
        model.load_state_dict(torch.load(model_path, map_location=device))
        # Or you can try like, 
        '''
        # original saved file with DataParallel
        state_dict = torch.load('myfile.pth.tar')
        # create new OrderedDict that does not contain `module.`
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:] # remove `module.`
            new_state_dict[name] = v
        # load params
        model.load_state_dict(new_state_dict)
        '''

    return model

inference

@torch.no_grad()
def inference(data_dir, model_dir, output_dir, args):
    """
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    num_classes = MaskBaseDataset.num_classes  # 18
    model = load_model(model_dir, num_classes, device).to(device)
    model.eval()

    img_root = os.path.join(data_dir, 'images')
    info_path = os.path.join(data_dir, 'info.csv')
    info = pd.read_csv(info_path)

    img_paths = [os.path.join(img_root, img_id) for img_id in info.ImageID]
    dataset = TestDataset(img_paths, args.resize)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=8,
        shuffle=False,
        pin_memory=use_cuda,
        drop_last=False,
    )

    print("Calculating inference results..")
    preds = []
    with torch.no_grad():
        for idx, images in enumerate(loader):
            images = images.to(device)
            pred = model(images)
            pred = pred.argmax(dim=-1)
            preds.extend(pred.cpu().numpy())

    info['ans'] = preds
    nickname = args.model_dir.split('/')[-2]
    info.to_csv(os.path.join(output_dir, f'{nickname}_output.csv'), index=False)
    print(f'Inference Done!')

main

if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Data and model checkpoints directories
    parser.add_argument('--batch_size', type=int, default=32, help='input batch size for validing (default: 1000)')
    parser.add_argument('--resize', nargs="+", type=int, default=[300, 300], help='resize size for image when you trained (default: (96, 128))')
    parser.add_argument('--model', type=str, default='TimmEfficientNetB3', help='model type (default: BaseModel)')

    # Container environment
    parser.add_argument('--data_dir', type=str, default=os.environ.get('SM_CHANNEL_EVAL', '/opt/ml/input/data/eval'))
    parser.add_argument('--model_dir', type=str, default=os.environ.get('SM_CHANNEL_MODEL', './model'))
    parser.add_argument('--output_dir', type=str, default=os.environ.get('SM_OUTPUT_DATA_DIR', './output'))

    args = parser.parse_args()

    data_dir = args.data_dir
    model_dir = args.model_dir
    output_dir = args.output_dir

    os.makedirs(output_dir, exist_ok=True)

    inference(data_dir, model_dir, output_dir, args)


utils.py

utils.py 파일은 말 그대로 편의용 클래스나 함수들을 정의해놓는 곳입니다. 앞서 코드에서 살펴본 조기종료를 위한 EarlyStopping 클래스, optimizer의 learning rate를 얻기 위한 get_lr 함수 등이 있습니다.


requirements.txt

requirements.txt는 다른 사용자가 코드를 사용할 때 환경 구축을 위한 파일입니다.

torch==1.7.1
torchvision==0.8.2
tensorboard==2.4.1
pandas==1.1.5
opencv-python==4.1.2.30
scikit-learn~=0.24.1
matplotlib==3.2.1



이상으로 PyTorch Project를 구성하는 Code Template에서 각 Python file의 code level에서의 흐름과, 어떻게 작성할 수 있는지에 대해 살펴보았습니다.

다음 [More Tips] 포스팅에서는 대회 과정에서 사용한 다양한 기법이나 라이브러리들을 소개하고 어떻게 사용할 수 있는지에 대해 살펴보겠습니다.


Leave a comment