Implementing Image Classification with a Vision Transformer

A quick note on implementing image classification with a Vision Transformer (ViT).

  • Reference site

farml1.com

The code below runs on Google Colab.
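
The data paths below point into Google Drive, so mount it first with the standard Colab helper:

from google.colab import drive
drive.mount('/content/drive')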

Install the required libraries

!pip install timm seaborn linformer tqdm pandas

Imports (probably not all of these are needed)

from __future__ import print_function

import glob
from itertools import chain
import os
import random
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from linformer import Linformer
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

from sklearn.metrics import average_precision_score

#from vit_pytorch.efficient import ViT
import seaborn as sns
import timm
from pprint import pprint
model_names = timm.list_models(pretrained=True)
pprint([x for x in model_names if "vit" in x])

This prints the list of ViT models available here; pick the one you want to use.
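
Each timm model carries a default config that includes the input size it was trained at; a quick way to check before committing to one (using the model chosen below):

m = timm.create_model('vit_medium_patch16_gap_384.sw_in12k_ft_in1k', pretrained=False)
print(m.default_cfg['input_size'])  # (3, 384, 384) for this 384px variant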

Set the hyperparameters

# Training settings
model_name = 'vit_medium_patch16_gap_384.sw_in12k_ft_in1k'
batch_size = 48
epochs = 20
lr = 3e-5
gamma = 0.7
seed = 42
img_size = 384
weight_decay = 1e-3
device = 'cuda'

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(seed)

Load the training data

data_path = "/content/drive/MyDrive/data"
train_dir = data_path + '/train'
test_dir = data_path + '/test'
train_list = sorted(glob.glob(os.path.join(train_dir,'*.png')))
test_list = sorted(glob.glob(os.path.join(test_dir, '*.png')))
print(f"Train Data: {len(train_list)}")
print(f"Test Data: {len(test_list)}")

train_labels = np.loadtxt(data_path + "/train.csv", delimiter=',', dtype='int64',
               skiprows=1, usecols=[1])
# The test set has no ground-truth labels, so fill in dummy zeros.
test_labels = np.zeros((len(test_list),))
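
The np.loadtxt call assumes train.csv has a header row and the integer label in its second column; a quick peek with pandas to confirm the layout:

meta = pd.read_csv(data_path + "/train.csv")
print(meta.head())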

train_df = pd.DataFrame({
    'data_path': train_list,
    'label': train_labels
})
test_df = pd.DataFrame({
    'data_path': test_list,
    'label': test_labels
})
# Show 9 randomly chosen training images with their labels
# (the original loop ignored random_idx and always showed the first 9).
random_idx = np.random.randint(0, len(train_df), size=9)
fig, axes = plt.subplots(3, 3, figsize=(16, 12))

for idx, ax in enumerate(axes.ravel()):
    img = Image.open(train_df.data_path[random_idx[idx]])
    ax.set_title(train_df.label[random_idx[idx]])
    ax.imshow(img)
plt.show()

Split the dataset.

train_df, valid_df = train_test_split(train_df,
                                      test_size=0.1,
                                      stratify=train_df.label,
                                      random_state=seed)
print(f"Train Data: {len(train_df)}")
print(f"Validation Data: {len(valid_df)}")
print(f"Test Data: {len(test_df)}")

Set up the preprocessing.

train_transforms = transforms.Compose(
    [
        transforms.Resize((img_size, img_size)),
        transforms.RandomResizedCrop(img_size),
        transforms.RandomRotation(degrees=180),
        #transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
        transforms.ToTensor(),
        transforms.RandomErasing(p=0.5),
    ]
)

# Resize takes a (height, width) tuple here: with a bare int it only scales
# the shorter edge, and the ViT expects a square img_size x img_size input.
val_transforms = transforms.Compose(
    [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
    ]
)

test_transforms = transforms.Compose(
    [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
    ]
)
class MyDataset(Dataset):
    def __init__(self, df, transform=None, return_filepath=False):
        self.file_list = list(df.data_path)
        self.transform = transform
        self.label = list(df.label)
        self.return_filepath = return_filepath

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        # PNGs may be grayscale or RGBA, so force 3 channels.
        img = Image.open(img_path).convert('RGB')
        img_transformed = self.transform(img)
        if not self.return_filepath:
            return img_transformed, self.label[idx]
        else:
            return img_transformed, os.path.basename(img_path)

train_data = MyDataset(train_df, transform=train_transforms)
valid_data = MyDataset(valid_df, transform=val_transforms)
test_data = MyDataset(test_df, transform=test_transforms, return_filepath=True)

train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_data, batch_size=1, shuffle=False)

print(len(train_data), len(train_loader))
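
To verify the augmentations, it helps to look at one transformed sample. A ToTensor image is C x H x W, so permute it for matplotlib:

img, lbl = train_data[0]
plt.imshow(img.permute(1, 2, 0))
plt.title(f"label: {lbl}")
plt.show()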

Prepare the model

model = timm.create_model(model_name, pretrained=True, num_classes=2).to(device)
#model = torch.load(data_path + '/model_weight.pth').to(device)
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
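
With step_size=1 and gamma=0.7, StepLR multiplies the learning rate by 0.7 after every epoch, so it decays fairly aggressively; the first few values:

# lr after epoch e is lr * gamma**e when step_size=1
for e in range(4):
    print(f"epoch {e}: lr = {lr * gamma ** e:.2e}")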

Run the training

train_acc_list = []
val_acc_list = []
train_loss_list = []
val_loss_list = []

for epoch in range(epochs):
    model.train()
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        # CrossEntropyLoss expects raw logits and applies log-softmax
        # internally, so do not pass the output through Softmax first.
        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        #tmp = F.softmax(output, dim=1).detach().cpu().numpy()
        #result = tmp[:, 1]
        #acc = average_precision_score(label.detach().cpu().numpy(), result, average='samples')

        # .item() stores plain floats, so nothing stays attached to the graph.
        epoch_accuracy += acc.item() / len(train_loader)
        epoch_loss += loss.item() / len(train_loader)

    model.eval()
    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            #tmp = F.softmax(val_output, dim=1).detach().cpu().numpy()
            #result = tmp[:, 1]
            #acc = average_precision_score(label.detach().cpu().numpy(), result, average='samples')
            epoch_val_accuracy += acc.item() / len(valid_loader)
            epoch_val_loss += val_loss.item() / len(valid_loader)

    scheduler.step()  # StepLR: decay the learning rate once per epoch

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )
    # Saves the whole model object; torch.save(model.state_dict(), ...) is the
    # more portable alternative.
    torch.save(model, data_path + '/model_weight.pth')

    train_acc_list.append(epoch_accuracy)
    val_acc_list.append(epoch_val_accuracy)
    train_loss_list.append(epoch_loss)
    val_loss_list.append(epoch_val_loss)
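
The test_loader defined earlier (which returns file names instead of labels) is not used above; here is a minimal inference sketch that writes predictions to a CSV (the file name and column layout of submission.csv are assumptions, not part of the original setup):

# Minimal test-set inference sketch; 'submission.csv' and its columns are
# assumed here, adjust to the actual submission format.
model.eval()
results = []
with torch.no_grad():
    for data, fname in test_loader:
        data = data.to(device)
        pred = model(data).argmax(dim=1).item()
        results.append((fname[0], pred))

pd.DataFrame(results, columns=['filename', 'label']).to_csv(
    data_path + '/submission.csv', index=False)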

Plot the training results

# Thanks to .item() in the training loop, the history lists already hold
# plain Python floats and can be plotted directly, with no tensor-to-CPU
# conversion needed.
train_acc = train_acc_list
train_loss = train_loss_list
val_acc = val_acc_list
val_loss = val_loss_list

# Plot the collected metrics.
sns.set()
num_epochs = epochs

fig = plt.figure(figsize=(12, 4), dpi=80)

ax1 = plt.subplot(1,2,1)
ax1.plot(range(num_epochs), train_acc, c='b', label='train acc')
ax1.plot(range(num_epochs), val_acc, c='r', label='val acc')
ax1.set_xlabel('epoch', fontsize='12')
ax1.set_ylabel('accuracy', fontsize='12')
ax1.set_title('training and val acc', fontsize='14')
ax1.legend(fontsize='12')

ax2 = plt.subplot(1,2,2)
ax2.plot(range(num_epochs), train_loss, c='b', label='train loss')
ax2.plot(range(num_epochs), val_loss, c='r', label='val loss')
ax2.set_xlabel('epoch', fontsize='12')
ax2.set_ylabel('loss', fontsize='12')
ax2.set_title('training and val loss', fontsize='14')
ax2.legend(fontsize='12')
plt.show()