skorchを使ってPyTorchでCross-Validationを試してみる

シェアする

  • このエントリーをはてなブックマークに追加
スポンサーリンク

はじめに

 今回は、Skorchを使って画像分類をしたいと思います。今回使うデータは、前回も使ったkaggleの花のデータセットです。Skorchとは、PyTorchのsklearnラッパーで、sklearnのインターフェースでPyTorchを使うことができます。Skorchを使ってPyTorchでcross validationしてみたいと思います。

今回のデータセットのリンク

Flowers Recognition

関連のおすすめ記事:

Kaggleチャレンジ8日目 pytorchでフルーツの画像分類をしてみた

kaggleチャレンジ 10日目 花のデータセットを使ってPytorchで画像分類をしてみた

データセットの説明

スタイルの種類

  • カモミール
  • チューリップ
  • バラ
  • ヒマワリ
  • タンポポ

各クラス約800枚ずつ

データセットの内訳

  • 総画像数:4323
  • 画像サイズ:約320×240ピクセル

ソースコード

綺麗なコードではありませんが許してください。

まず、kaggleの花のデータセットをGoogleColaboratory上で扱えるようにします。わからない方は関連のおすすめ記事にあるリンクから確認してください。

ライブラリの準備

import torch  
import numpy as np  
import torch.nn as nn  
import torch.optim as optim  
import torch.nn.functional as F  
import torch.utils.data  
import torch.backends.cudnn as cudnn 
import torchvision  
from torchvision import datasets, models, transforms  
from sklearn.metrics import confusion_matrix, accuracy_score  
import random  
import os  
import time 
import skorch 
from skorch import NeuralNetClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report 
from PIL import Image 
import matplotlib.pyplot as plt 
% matplotlib inline

ネットワークの構築

Pytorchのmodelsからresnet18を呼び出して、今回は5種類の分類なので出力を5にします。今回は、学習済みモデルの重みを使っていません。理由は後で話します。

n_classes = 5 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  
model = models.resnet18().to(device)  
n_filters = model.fc.in_features  
model.fc = nn.Linear(n_filters, n_classes) 
model.to(device)

データセット

これをしないとGoogleColaboratory上でPIL周りでエラーを吐くのでやっておきます。

!pip install Pillow==4.0.0 
!pip install PIL 
!pip install image 
import PIL 
import image

kerasの前処理のライブラリを読んできて、画像をnumpy配列のリストに変換します。本当はPytorchのImageFolderあたりでうまいことやりたかったのですがうまく動かせなかったのでこちらでやります。また、解凍したデータセット内に画像ファイル以外のものがありエラーを吐いたので、try文で囲んでいます。

from keras.preprocessing.image import load_img,img_to_array 
img_size = (224,224) 
dir_name = 'flowers/' 
labels = os.listdir(dir_name) 
print(labels) 
temp_img_array_list = [] 
temp_label_array_list = [] 
for label in labels:
     img_list = dir_name + label+ '/'
     imgs = os.listdir(img_list)
     for img in imgs:
       img_path = img_list + img
       try:
         img = PIL.Image.open(img_path)
         img_resize = img.resize(img_size)
         temp_img_array = img_to_array(img_resize) /255
         temp_img_array_list.append(temp_img_array)
         temp_label_array_list.append(label)
       except:
         print('not_jpg')
X_train = np.array(temp_img_array_list)
Y_train = np.array(temp_label_array_list)

とりあえず、画像で出力させてみましょう。

plt.imshow(X_train[1]) 
plt.show()

次に、今のままだとラベルがstring型で怒られてしまうので変えておきます。

from sklearn import preprocessing 
le = preprocessing.LabelEncoder()
le.fit(['tulip', 'rose', 'sunflower', 'dandelion', 'daisy']) 
Y_train = le.transform(Y_train)

ちなみにどんなことをしているかというとこんな感じです。適当に数値化してくれています。

from sklearn import preprocessing 
print(Y_train[:5]) 
le = preprocessing.LabelEncoder()
le.fit(['tulip', 'rose', 'sunflower', 'dandelion', 'daisy']) 
Y = le.transform(['tulip', 'rose', 'sunflower', 'dandelion', 'daisy']) 
print(Y)

['tulip' 'tulip' 'tulip' 'tulip' 'tulip']
#['tulip', 'rose', 'sunflower', 'dandelion', 'daisy'] 
[4 2 3 1 0]

あとは、Xも欲しがっている形式に変形させてあげます。

X_train = X_train.reshape(-1, 3, 224, 224) 
print(X_train.shape)

(4323, 3, 224, 224)

あとは、学習用と検証用に分けます。

x_train, x_test, y_train, y_test = train_test_split(X_train, Y_train, test_size=0.2, stratify=Y_train)

学習

Skorchのいいところは、sklearnのように書けるのでわかりやすくて楽ですね。fitとpredictでいいので。あと、今回学習回数を少なくしています。というのもGoogleColaboratoryの12GBという制限上、あまり多く回すと今回の場合途中で止まってしまったからですね。

  • NeuralNetClassifier :分類器
  • NeuralNetRegressor:回帰
net = NeuralNetClassifier(
     model,
     optimizer=torch.optim.Adam,
     criterion=torch.nn.CrossEntropyLoss,
     max_epochs=3, 
    lr=0.01, # default=0.01 
    iterator_train__batch_size=32, # default=128 
    iterator_train__shuffle=True, 
    device='cuda', ) 
net.fit(x_train, y_train) 
y_pred = net.predict(x_test) 
print(classification_report(y_test, y_pred))

  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        1.8696       0.2968        1.6170  24.4634
      2        1.5463       0.3343        1.9325  24.2520
      3        1.4379       0.4078        1.3483  24.2573
             precision    recall  f1-score   support

          0       0.44      0.10      0.17       154
          1       0.37      0.64      0.47       210
          2       0.00      0.00      0.00       157
          3       0.54      0.44      0.48       147
          4       0.37      0.66      0.48       197

avg / total       0.35      0.40      0.34       865

出力がわかりやすいですね。ちなみにcpuでやったら1epochあたり650secぐらいかかりました。

Cross Validation

お待ちかねのCross Validationをやっていきましょう。といってもこの二文でいけます。

from sklearn.model_selection import cross_val_predict 
y_pred = cross_val_predict(net, x_train, y_train, cv=5)

 epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        1.4302       0.2734        1.5954  19.5158
      2        1.2904       0.4191        1.4027  19.3534
      3        1.1973       0.4335        1.4994  19.3831
  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        1.4519       0.3964        2.1590  19.4022
      2        1.2786       0.4216        1.5153  19.3832
      3        1.1872       0.4414        1.5136  19.3925
  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        1.4183       0.4227        1.4543  19.4129
      2        1.2711       0.3831        1.4921  19.3797
      3        1.2470       0.3993        1.3312  19.3706
  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        1.4737       0.3237        2.5694  19.4980
      2        1.4228       0.3417        1.8766  19.4501
      3        1.3026       0.4838        1.2620  19.4042

cv=5でまわしたのに4回の時点でメモリ制限によりランタイムリセットがかかり試合終了。

といってもほぼ花のデータセットに吸われているので違うデータセット使えば問題ないでしょう。


いかがだったでしょうか。PyTorchでsklearnみたく書きたい人、またCross Validationをしたい人はぜひskorchを使ってみてください。また、Pytorchのデータセットでうまく学習するやり方わかる人は教えてくださると助かります。

最後まで読んでいただきありがとうございました。よろしければこの記事をシェアしていただけると励みになります。よろしくお願いします。

スポンサーリンク
レクタングル広告(大)
レクタングル広告(大)

シェアする

  • このエントリーをはてなブックマークに追加

フォローする

スポンサーリンク
レクタングル広告(大)