Vision Transformer (ViT)によるクラス分類を実装したのでメモしておく。
- 参考にしたサイト様
以下のコードは、Google colabで実行する。
必要なライブラリのインストール
!pip install timm seaborn linformer tqdm pandas
必要なライブラリのimport（全部は使用しないかもしれない）
# Imports, one per line and grouped stdlib / third-party (some may be unused).
from __future__ import print_function

import glob
import os
import random
import zipfile
from itertools import chain
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from linformer import Linformer
from PIL import Image
from sklearn.metrics import average_precision_score
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

# from vit_pytorch.efficient import ViT

# Print the pretrained ViT variants timm ships so one can be picked below.
model_names = timm.list_models(pretrained=True)
pprint([x for x in model_names if "vit" in x])
ここで使えるViTのモデル一覧が表示されるので、使いたいものを選択する。
ハイパーパラメータの設定
# Training settings
model_name = 'vit_medium_patch16_gap_384.sw_in12k_ft_in1k'
batch_size = 48
epochs = 20
lr = 3e-5
gamma = 0.7
seed = 42
img_size = 384
weight_decay = 1e-3
device = 'cuda'


def seed_everything(seed):
    """Seed every RNG in use (python, numpy, torch CPU/GPU) for reproducibility."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade kernel-selection speed for deterministic cuDNN behavior.
    torch.backends.cudnn.deterministic = True


seed_everything(seed)
学習データのロード
# Paths to the image data on Google Drive.
data_path = "/content/drive/MyDrive/data"
train_dir = data_path + '/train'
test_dir = data_path + '/test'

train_list = sorted(glob.glob(os.path.join(train_dir, '*.png')))
test_list = sorted(glob.glob(os.path.join(test_dir, '*.png')))
print(f"Train Data: {len(train_list)}")
print(f"Test Data: {len(test_list)}")

# Labels come from the second CSV column (header skipped). Rows are assumed
# to be in the same (sorted) order as the image files -- TODO confirm.
train_labels = np.loadtxt(data_path + "/train.csv", delimiter=',',
                          dtype='int64', skiprows=1, usecols=[1])
# Test labels are unknown; use zeros as placeholders.
# (Bug fix: the original first loaded train.csv into test_labels and then
# immediately overwrote it with zeros -- that dead load is removed.)
test_labels = np.zeros((len(test_list),))

train_df = pd.DataFrame({'data_path': train_list, 'label': train_labels})
test_df = pd.DataFrame({'data_path': test_list, 'label': test_labels})

# Preview 9 randomly chosen training images with their labels.
# (Bug fix: random_idx was computed but never used, so the first 9 images
# were always shown; it is now used for indexing, and the randint lower
# bound starts at 0 so the first image can also be drawn.)
random_idx = np.random.randint(0, len(train_df.data_path), size=9)
fig, axes = plt.subplots(3, 3, figsize=(16, 12))
for idx, ax in enumerate(axes.ravel()):
    img = Image.open(train_df.data_path[random_idx[idx]])
    ax.set_title(train_df.label[random_idx[idx]])
    ax.imshow(img)
データセットを分割。
# Hold out 10% of the training data for validation, preserving class ratios.
train_df, valid_df = train_test_split(
    train_df,
    test_size=0.1,
    stratify=train_df.label,
    random_state=seed,
)
for split_name, split_df in (("Train", train_df),
                             ("Validation", valid_df),
                             ("Test", test_df)):
    print(f"{split_name} Data: {len(split_df)}")
前処理の設定。
# Data augmentation / preprocessing pipelines.
train_transforms = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.RandomResizedCrop(img_size),
    transforms.RandomRotation(degrees=180),
    # transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.5),
])

# Bug fix: Resize(int) only rescales the *shorter* side, so non-square
# inputs keep their aspect ratio and come out with differing shapes --
# stacking them into a batch in the validation DataLoader then fails.
# Resize((h, w)) forces the fixed square size the model expects, matching
# what the training pipeline produces.
val_transforms = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
])

test_transforms = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
])
import numpy as np


class MyDataset(Dataset):
    """Image dataset backed by a DataFrame with `data_path` / `label` columns.

    Args:
        df: DataFrame with a `data_path` column (image file paths) and a
            `label` column.
        transform: optional torchvision transform applied to each image.
        return_filepath: when True, __getitem__ returns the image together
            with the file's basename instead of the label (used for the
            unlabeled test set).
    """

    def __init__(self, df, transform=None, return_filepath=False):
        self.file_list = list(df.data_path)
        self.transform = transform
        self.label = list(df.label)
        self.return_filepath = return_filepath

    def __len__(self):
        # (Cleanup: the original cached this in a self.filelength attribute
        # as a side effect; a plain return is enough.)
        return len(self.file_list)

    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        # NOTE(review): images are assumed to already have the channel layout
        # the model expects; add .convert('RGB') here if that is not the case.
        img = Image.open(img_path)
        # Bug fix: `transform` defaults to None but was applied
        # unconditionally, so the default argument crashed. Apply it only
        # when one was given.
        if self.transform is not None:
            img = self.transform(img)
        if self.return_filepath:
            return img, os.path.basename(img_path)
        return img, self.label[idx]
train_data = MyDataset(train_df, transform=train_transforms)
valid_data = MyDataset(valid_df, transform=test_transforms)
test_data = MyDataset(test_df, transform=test_transforms, return_filepath=True)

# Only the training loader needs shuffling; shuffling the validation set has
# no effect on the metrics (fix: the original used shuffle=True there too).
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size, shuffle=False)
# Batch size 1 keeps one prediction per file for the test/inference loop.
test_loader = DataLoader(dataset=test_data, batch_size=1, shuffle=False)

print(len(train_data), len(train_loader))
モデルの準備
# Build the pretrained ViT from timm, replacing its head with a 2-class
# classifier, then move it to the target device.
model = timm.create_model(model_name, pretrained=True, num_classes=2)
model = model.to(device)
# To resume from a previously saved checkpoint instead:
# model = torch.load(data_path + '/model_weight.pth').to(device)
# Loss: cross entropy over the 2 classes.
criterion = nn.CrossEntropyLoss()
# Optimizer: Adam with L2 regularization via weight_decay.
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
# Scheduler: multiply the learning rate by `gamma` after every step.
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
学習の実行
train_acc_list = []
val_acc_list = []
train_loss_list = []
val_loss_list = []
##############
epochs = 20

for epoch in range(epochs):
    # --- training ---
    # Bug fix: the original applied nn.Softmax(dim=1) to the logits before
    # nn.CrossEntropyLoss. CrossEntropyLoss already contains log-softmax,
    # so softmaxing twice squashes the gradients; feed raw logits instead.
    model.train()  # bug fix: training mode was never set explicitly
    epoch_loss = 0
    epoch_accuracy = 0
    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # argmax is unaffected by softmax, so accuracy is computed the same way.
        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        # detach() so the running sum does not keep every batch's autograd
        # graph alive for the whole epoch (memory leak in the original).
        epoch_loss += loss.detach() / len(train_loader)

    # --- validation ---
    model.eval()  # bug fix: eval mode was never set (dropout etc. stayed on)
    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)

    # Bug fix: the StepLR scheduler was created but never stepped, so the
    # learning rate never decayed; step it once per epoch.
    scheduler.step()

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )
    # Checkpoint the whole model after every epoch (overwrites the last one).
    torch.save(model, data_path + '/model_weight.pth')

    train_acc_list.append(epoch_accuracy)
    val_acc_list.append(epoch_val_accuracy)
    train_loss_list.append(epoch_loss)
    val_loss_list.append(epoch_val_loss)
学習結果のプロット
# Convert the per-epoch metric tensors to plain Python floats for plotting.
# detach() is applied uniformly: the accumulated training-loss tensors may
# still carry autograd history (the original code detached only train_loss
# and would crash with "requires grad" on any other such tensor).
def _to_float(t):
    """Return the scalar value of a (possibly GPU / grad-carrying) tensor."""
    return t.detach().cpu().item()

train_acc = [_to_float(t) for t in train_acc_list]
train_loss = [_to_float(t) for t in train_loss_list]
val_acc = [_to_float(t) for t in val_acc_list]
val_loss = [_to_float(t) for t in val_loss_list]

# 取得したデータをグラフ化する
sns.set()
num_epochs = epochs
fig = plt.subplots(figsize=(12, 4), dpi=80)

# Left panel: accuracy curves.
ax1 = plt.subplot(1, 2, 1)
ax1.plot(range(num_epochs), train_acc, c='b', label='train acc')
ax1.plot(range(num_epochs), val_acc, c='r', label='val acc')
ax1.set_xlabel('epoch', fontsize='12')
ax1.set_ylabel('accuracy', fontsize='12')
ax1.set_title('training and val acc', fontsize='14')
ax1.legend(fontsize='12')

# Right panel: loss curves.
ax2 = plt.subplot(1, 2, 2)
ax2.plot(range(num_epochs), train_loss, c='b', label='train loss')
ax2.plot(range(num_epochs), val_loss, c='r', label='val loss')
ax2.set_xlabel('epoch', fontsize='12')
ax2.set_ylabel('loss', fontsize='12')
ax2.set_title('training and val loss', fontsize='14')
ax2.legend(fontsize='12')
plt.show()