Pytorchで自作のデータセットの使い方をまとめてみた

シェアする

  • このエントリーをはてなブックマークに追加
スポンサーリンク

はじめに

 今回は、Pytorchのデータセットの作り方について、データセットのコードを見つつ、実際に可視化しながらかみ砕いていこうと思います。なぜこのような記事を書こうと思ったというと、私は現在Pytorch、Keras、Chainer、Scikit-learnを使っているのですが、フレームワーク毎要求されるものが違い、頭がこんがらがってきたのでアウトプットして整理するとともに他の方のご指摘も受けながらしっかり理解していきたいと思ったからです。

また、現在公開されている記事の多くが、MNISTやCIFAR10などで次に何をしたらいいかわからない方が多いのではないかと思います。そういったものを埋めれたらと思います。

それでは、今回は前回のフルーツのデータを使ってデータセットの作り方を解説していきたいと思います。

今回のデータセットのリンク

Flowers Recognition

関連のおすすめ記事:

データセットの説明

スタイルの種類

  • カモミール
  • チューリップ
  • バラ
  • ヒマワリ
  • タンポポ

各クラス約800枚ずつ

データセットの内訳

  • 総画像数:4323
  • 画像サイズ:約320×240ピクセル

データセットの作り方

まず、Pytorchで画像を使って学習させたいと思ったら、便利なImageFolder()というものがあります。

これは、画像データを次のように保存しておくと自動的に(画像パス,target(フォルダの名前))といったタプル型のリストを生成してくれます。

ディレクトリ構成例

  • train
    • カモミール
      • image_0.png
      • image_1.png
      • image_100.png
    • チューリップ
    • バラ
    • ヒマワリ
    • タンポポ
  • test
    • カモミール
    • チューリップ
    • バラ
    • ヒマワリ
    • タンポポ

ImageFoloder

コードを見てみましょう。

torchvision/datasets/folder.py

必要なところだけ抜粋しています。詳しく見たい方は上にあるリンクで飛んで読んでください。まず、ImageFolderを見たところ、DatasetFolderを継承しているのがわかりますね。なので、DatasetFolderを見てみましょう。

class ImageFolder(DatasetFolder):
   def __init__(self, root, transform=None, target_transform=None,
     loader=default_loader):
     super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
                                       transform=transform,
                                       target_transform=target_transform)
     self.imgs = self.samples

こちらも必要な部分のみ抜粋しています。find_classes()でディレクトリ名を読み込んで、(classes, class_to_idx)という形で返しています。そして、make_datasetでsamplesというのを作っていますね。 make_datasetを次に見ていきます。

def __getitem__:ここでのイメージとしては、1つ1つの(画像パス,ラベル)というタプル型のデータに行う処理

ここでは、self.loader(path)でPILでpathにある画像を読み込み、画像とラベルにそれぞれtransformの有無を調べ、あったら適応し返している。

class DatasetFolder(data.Dataset):
     def __init__(self, root, loader, extensions, transform=None, target_transform=None):
         classes, class_to_idx = self._find_classes(root)
         samples = make_dataset(root, class_to_idx, extensions)
     def __getitem__(self, index):
         path, target = self.samples[index]
         sample = self.loader(path)
         if self.transform is not None:
             sample = self.transform(sample)
         if self.target_transform is not None:
             target = self.target_transform(target)
         return sample, target
     def __len__(self):
         return len(self.samples)

make_dataset()では、拡張子を確認してから(パス,ラベル)というタプル型のデータをimagesというリストに入れて返しています。次のIMG_EXTENSIONSに入っているのが、使うことができる拡張子になっています。IMG_EXTENSIONSでエラーを吐いた場合は使いたい拡張子をappend()してあげればよさそう。

IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']

def make_dataset(dir, class_to_idx, extensions):
     images = []
     dir = os.path.expanduser(dir)
     for target in sorted(class_to_idx.keys()):
         d = os.path.join(dir, target)
         if not os.path.isdir(d):
             continue
         for root, _, fnames in sorted(os.walk(d)):
             for fname in sorted(fnames):
                 if has_file_allowed_extension(fname, extensions):
                     path = os.path.join(root, fname)
                     item = (path, class_to_idx[target])
                     images.append(item)
     return images

さて、これで何が返ってきているのかなんとなくイメージできましたね。つぎはtransformについてみていきます。transformでできる処理は公式のこのリンクに載っていますが今回は使いそうなものを一つ一つ可視化しながら試していきたいと思います。

transform

ImageFolderに少しコードを追加してみました。これでtransformに入っている処理をする前とした後を出力させてみます。

class ImageFolder(DatasetFolder):
     def __init__(self, root, transform=None, target_transform=None,
                  loader=default_loader):
         super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
                                           transform=transform,
                                           target_transform=target_transform)
         self.imgs = self.samples
         i = 0
         for img,label in self.imgs:
           if label == i:
             print(img)
             im = Image.open(img)
             im_list = np.asarray(im)
             plt.imshow(im_list)
             plt.show()
             im_trans = self.transform(im)
             print(im_trans)
             im_list = np.asarray(im_trans)
             plt.imshow(im_list)
             plt.show()
             i +=1

data_transform = transforms.Compose([
 #ここの中身を変えていく  ])
 Dataset =ImageFolder(root='flowers',transform=data_transform)

transforms.CenterCrop

transforms.CenterCrop(size)

次のコードをtransforms.Composeの中に入れてください。真ん中でsizeに合わせてクロッピングします。

transforms.CenterCrop(224)

出力として、カモミール、チューリップ、バラ、ヒマワリ、タンポポの画像を一枚ずつ表示できているとおもいます。この記事では全て表示させると長くなってしまうためカモミールだけにしておきます。

もしPIL周りでエラーを吐いている人がいましたら、次のコードを使ってください。GoogleColaboratoryの特有のバグみたいなものらしいです。

!pip install Pillow==4.0.0  
!pip install PIL  
!pip install image  
import PIL  
import image

transforms.ColorJitter

transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)

画像の明るさ、コントラスト、彩度をランダムに変更します。

  • 明るさ(float)
  • コントラスト(float)
  • 彩度(float)
  • 色相(float)
transforms.ColorJitter(brightness=1, contrast=1, saturation=1, hue=0.5)

transforms.Grayscale

transforms.Grayscale(channels)

画像をグレースケールに変換します。

transforms.Grayscale(1)

transforms.RandomCrop

transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant')

ランダムに画像をサイズに合わせてクロッピングする。

transforms.RandomCrop(224)

transforms.RandomHorizontalFlip

transforms.RandomHorizontalFlip(p=0.5)

画像を与えられた確率で反転する。(デフォルトp=0.5)

transforms.RandomHorizontalFlip(p=0.5)

transforms.RandomResizedCrop

transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333), interpolation=2)

与えられた画像をランダムなサイズとアスペクト比にクロッピングします。

 transforms.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333), interpolation=2)

transforms.RandomRotation

transforms.RandomRotation(degrees, resample=False, expand=False, center=None)

画像を回転させます。

transforms.RandomRotation(80, resample=False, expand=False, center=None)

transforms.RandomVerticalFlip

transforms.RandomVerticalFlip(p=0.5)

画像を与えられた確率でランダムに垂直に反転させる。

transforms.RandomVerticalFlip(p=0.5)

transforms.Resize

transforms.Resize(size, interpolation=2)

入力されたPILイメージのサイズを指定したサイズに変更します。

transforms.Resize((224,224))

全部は紹介できないのでこのくらいにしておこうと思います。

utils.data.random_split

random_splitを使うことでデータセットを分けることができます。主に学習用と検証用に分けるときに使うことができます。

sklearnのtrain_test_splitのpytorch版のようなイメージですね。

full_Dataset = datasets.ImageFolder(root='flowers',transform=data_transform)  
print(len(full_Dataset))  
train_size = int(0.8 * len(full_Dataset))  
test_size = len(full_Dataset) - train_size  
train_Dataset, test_Dataset = torch.utils.data.random_split(full_Dataset, [train_size, test_size])  
print(len(train_Dataset))  
print(len(test_Dataset))

utils.data.DataLoader

パラメータは次のようになっています。utils.data.Datalodaerを使うと、Datasetを渡すことでミニバッチを返すIterableなオブジェクトにしてくれます。

よく使うパラメータはデータセットとバッチサイズとシャッフルでしょうか。シャッフルするかどうかはデータによって変わってきます。

  • dataset (Dataset)
  • batch_size (int, optional)
  • shuffle (bool, optional) (default: False).
  • sampler (Sampler, optional)
  • batch_sampler (Sampler, optional)
  • num_workers (int, optional)
  • collate_fn (callable, optional)
  • pin_memory (bool, optional)
  • drop_last (bool, optional)
  • timeout (numeric, optional)
  • worker_init_fn (callable, optional)

今回は画像をメインにまとめました。ほかのデータセットを使うときの書き方などはまた書きたいと思います。

最後まで読んでいただきありがとうございました。よろしければこの記事をシェアしていただけると励みになります。よろしくお願いします。

スポンサーリンク
レクタングル広告(大)
レクタングル広告(大)

シェアする

  • このエントリーをはてなブックマークに追加

フォローする

スポンサーリンク
レクタングル広告(大)