Kaggleチャレンジ8日目 pytorchでフルーツの画像分類をしてみた

シェアする

  • このエントリーをはてなブックマークに追加
スポンサーリンク

はじめに

 今回は、いつもとは異なり、コンペではなく、Kaggleに公開されているデータセットを使っていきたいとおもいます。Kaggleにはコンペにはなってないけどデータセットだけ公開されているものもあり、色々試してみたいのですが、今回は、基本的な画像分類をフルーツ画像のデータセットを使って行なっていきたいと思います。現在、画像分類などを調べてもMNISTやCIFAR10などのフレームワークに標準に備わっているものばかりなので、そういったデータセットはもう大丈夫だよという方はKaggleのデータセットをお勧めします。

また、今回もGoogleColaboratoryを使おうと思っていたのですが、zipファイルをGoogleColaboratory上で解凍しようとしたらファイル数の関係なのか動作が停止してしまったため、ローカル上で動かしています。

今回のデータセットのリンク

https://www.kaggle.com/moltean/fruits/home

関連のおすすめ記事:

Kaggle チャレンジ 1日目 タイタニックの問題からデータを読み解いてみる

Kaggle チャレンジ 4日目 住宅価格問題を解いていく

Kaggleチャレンジ 6日目 Googleアナリティクスの顧客収益予測 データの分析

データセットの説明

フルーツの種類

アペラシ、アボカド、熟したアボカド、バナナ(黄色、赤色)、サボテン果実、カンタロープ(2種類)、カランブラ(黄色、黄色、赤色) 、グリーバフルーツ(ピンク、ホワイト)、グアバ、ハックルベリー、キウイ、カキ、チェリー(各種品種、レーニア)、チェリーワックス(イエロー、レッド、ブラック)、クレマンチン、ココス、デー、グラナディーラ、ブドウ(ピンク、白、白2) 、Kumsquats、レモン(通常、メイヤー)、ライム、ライチ、マンダリン、マンゴー、マラクジャー、メロンピエルデサポ、クワ、ネクタリン、オレンジ、パパイヤ、パッションフルーツ、ピーチ、ペピーノ、ナシ(品種、アバート、モンスター、ウィリアムズ(ノーマル、ミニ)、ピタハヤレッド、プラム、ザクロ、クワン、ランブタン、ラズベリー、サラク、イチゴ(通常、ウェッジ)、タマリロ、タンジェロ、トマト(様々な品種、マルーン、チェリーレッド)、ウォールナット。

データセットの内訳

  • 総画像数:55244
  • トレーニングデータセット:41322枚
  • テストデータセット:13877枚
  • マルチフルーツセットサイズ:45枚(1枚ごとに複数のフルーツ(またはフルーツ))
  • クラス数:81(果物)
  • 画像サイズ:100×100ピクセル

ソースコード

ソースコードが綺麗でなくてすいません。

ライブラリの準備

import torch 
import numpy as np 
import torch.nn as nn 
import torch.optim as optim 
import torch.nn.functional as F 
import torch.utils.data 
import torch.backends.cudnn as cudnn 
import torchvision 
from torchvision import datasets, models, transforms 
from torchsummary import summary 
import torchvision 
from torchvision import datasets, models, transforms 
from sklearn.metrics import confusion_matrix, accuracy_score 
import random 
import os 
import time

torchsummaryはネットワークを表示させるのにすごい便利なのでぜひ使ってください。ローカルにない場合は、次のコードで入れられます。

pip install torchsummary

学習したモデルを保存、読み込む関数

def save_checkpoint(path, epoch, model):
 	save_path = os.path.join(path, "model_epoch_{}.pkl".format(epoch))
 	torch.save(model.state_dict(), save_path)
 	print("Checkpoint saved to {}".format(save_path))

 def load_checkpoint(model_dir,epoch,model):
     load_path = os.path.join(model_dir,"model_epoch_{}.pkl".format(epoch))
     checkpoint = torch.load(load_path)
     model.load_state_dict(checkpoint)
     print("Checkpoint loaded to {}".format(load_path))

初期化

np.random.seed(1) 
random.seed(1) 
torch.manual_seed(1) 
torch.cuda.manual_seed(1)

ネットワーク設定

今回、すぐに精度が高くなったのでエポック数を少なくしてあります。

batchsize = 32 
epochs = 5 
epoch_start =1

GPU

use_gpu = torch.cuda.is_available()
 if use_gpu:
     print("cuda is available!")
     cudnn.benchmark = True
     cudnn.deterministic = True

モデルの保存フォルダ

checkout_dir = "./checkout" 
if os.path.exists(checkout_dir) is False:
     os.mkdir(checkout_dir)

ネットワーク

今回クラス数が81なので出力数を変えています。

n_classes = 81
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
model = models.resnet18(pretrained=True).to(device) 
n_filters = model.fc.in_features 
model.fc = nn.Linear(n_filters, n_classes)
if use_gpu:
    model.cuda()

モデル読み込み(学習再開)

もし、学習を中断してしまった場合、Trueにし、再開するエポック数を入れることで再開できます。

model_load = True 
if model_load:
     epoch_start = 2
     load_checkpoint(checkout_dir,epoch_start,model)

オプティマイザ

今回は、エポック数が少ないので使いませんが、スケジューラを使うのがおすすめです。スケジューラを使う場合、最初の学習率は高い方がいいそうです。

optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=5e-4)
use_scheduler = False
if use_scheduler:
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

クロスエントロピー

今回は分類タスクなので、CrossEntropyLoss()を使います。

criterion = nn.CrossEntropyLoss()

train関数

def train(model, train_loader,epoch):
   model.train()
   print('\nEpoch: %d' % epoch)
   train_loss = 0
   correct = 0
   total = 0
   if use_scheduler:
     scheduler.step()
   for batch_idx, (image, label) in enumerate(train_loader):
     image, label = image.to(device), label.to(device)
     optimizer.zero_grad()
     outputs = model(image)
     loss = criterion(outputs, label)
     loss.backward()
     optimizer.step()
     train_loss += loss.item()
     _, predicted = outputs.max(1)
     total += label.size(0)
     correct += predicted.eq(label).sum().item()
   print('Loss:{} | Acc:{} ({}/{})'.format((train_loss/(batch_idx+1)), 100.*correct/total, correct, total))

test関数

def test(model, test_loader,epoch):
   model.eval()
   running_loss = 0
   test_loss = 0
   correct = 0
   total = 0
   for batch_idx,(image, label) in enumerate(test_loader):
     image, label = image.to(device), label.to(device)
     outputs = model(image)
     loss = criterion(outputs, label)
     test_loss += loss.item()
     _, predicted = outputs.max(1)
     total += label.size(0)
     correct += predicted.eq(label).sum().item()
   print('Loss:{} | Acc:{} ({}/{})'.format((test_loss/(batch_idx+1)), 100.*correct/total, correct, total))

evaluation関数

def evaluation(model_dir,epoch,model,test_loader):
     load_checkpoint(model_dir,epoch,model)
     model.eval()
     y_test = []
     y_pred = []
     for batch_idx,(image, label) in enumerate(test_loader):
         image, label = image.to(device), label.to(device)
         outputs = model(image)
         _, predictions = outputs.max(1)
         y_test.append(label.data.cpu().numpy())
         y_pred.append(predictions.data.cpu().numpy())
     y_test = np.concatenate(y_test)
     y_pred = np.concatenate(y_pred)
     print(accuracy_score(y_test, y_pred))

main

if __name__ == '__main__':
     summary(model, (3, 224, 224))
     data_transform = transforms.Compose([
       transforms.Resize([224, 224]),
       transforms.RandomHorizontalFlip(),
       transforms.ToTensor(),
       transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
     ])
     train_Dataset = datasets.ImageFolder(root='./fruits-360/Training',transform=data_transform)
     train_loader = torch.utils.data.DataLoader(train_Dataset,batch_size=batchsize, shuffle=True)
     test_Dataset = datasets.ImageFolder(root='./fruits-360/Test',transform=data_transform)
     test_loader = torch.utils.data.DataLoader(test_Dataset,batch_size=batchsize, shuffle=False)
     print("train images: {}".format(len(train_Dataset)))
     print("test images: {}".format(len(test_Dataset)))
     print("epoch: {}".format(epochs))
     print("batch size: {}".format(batchsize))

     for epoch in range(epoch_start, epochs+1):
         train(model, train_loader,epoch)
         test(model, test_loader,epoch)
         save_checkpoint(checkout_dir, epoch, model)
     evaluation(checkout_dir,epochs,model,test_loader)

結果

0.991280536138935

ファインチューニングをしていて、データ数も結構あるので、かなり高い精度となりました。


いかがだったでしょうか。今度はPyTorchのsklearnラッパーであるSkorchを使ってわかりやすく書きたいですね。ほかにもkaggleの記事があるのでよかったら読んでください。

最後まで読んでいただきありがとうございました。よろしければこの記事をシェアしていただけると励みになります。よろしくお願いします。

スポンサーリンク
レクタングル広告(大)
レクタングル広告(大)

シェアする

  • このエントリーをはてなブックマークに追加

フォローする

スポンサーリンク
レクタングル広告(大)