Kaggleチャレンジ7日目 レシピから料理の種類を予測してみた

シェアする

  • このエントリーをはてなブックマークに追加
スポンサーリンク

はじめに

 今回は、レシピから料理を予測するWhat’s Cooking? (Kernels Only)を試していきたいと思います。面白そうなデータセットだったので選んでみました。

前回の記事:

Kaggle チャレンジ 1日目 タイタニックの問題からデータを読み解いてみる

Kaggle チャレンジ 4日目 住宅価格問題を解いていく

今回もGoogleColaboratoryを使って進めていくので、はじめ方などは前回の記事を参考にしてください。

コンペの説明

 あなたの地元の市場を散策している自分の姿を想像してみて下さい…  あなたには何が見えますか?あなたには何のにおいをしていますか?今夜は何を食べますか?

もし、あなたが北カリフォルニアにいれば、暗い紫色のケールやブッシェル(?)といった食材、韓国ならキムチ、インドだったら豊富な色合いと数十種類の香辛料、ウコン、アザミ、ケシの種子、ガラムマサラなどを市場で見ることができるでしょう。

私たちの地域や文化と地元の食糧には強い結びつきがあります。ここでは、食材のリストをもとに、料理の料理のカテゴリーを予測するように求められています。

評価方法

提出は、分類精度(正確に分類する料理の割合)で評価されます。

フォーマット:

id,cuisine
35203,italian
17600,italian
35200,italian
17602,italian
...
etc.

(コンペの説明欄から引用)


やること

 GoogleColaboratoryとKernelsを使ってレシピから料理を予測していきたいと思います。。

・今回参考にするKernels:

GoogleColaboratoryの使い方はこちら

対象者

機械学習をKaggleを使って学びたい方、Kaggleに興味がある方、いろいろなデータセットを試してみたい方。

 

What’s Cooking? (Kernels Only)

今回のコンペはKernels Onlyという事になっています。この場合は、予測されたデータのファイルだけでなく、ファイルの事前処理を含む全てのモデルのコードのカーネルの提出が必要となります。

食材を見て、私は料理を理解するのに役立つかもしれない要素をみつけました。

  • 外れ値
  • 特殊文字
  • 大文字
  • アポストロフィ
  • ハイフン
  • 数字
  • 単位
  • 地域名
  • アクセント
  • ユニークな食材
  • 言語
  • スペルミス

データ準備

データは、こちらからダウンロードして、zipファイルをGoogleドライブにアップロードしてGoogleColaboratory上で解凍するのが簡単だと思うので、今回もそれで行きたいと思います。

今回必要な物をインストールしておく

In[1]:

!pip install pydrive langdetect unidecode ipywidgets

In[2]:

from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

In[3]:

id = '**********************' # 共有リンクで取得した id= より後の部分を*の部分に入力
downloaded = drive.CreateFile({'id': id})
downloaded.GetContentFile('cook.zip') #ファイルの名前

In[4]:

!unzip cook.zip

今回使うライブラリをよんでおきます。

In[5]:

import json
import langdetect
import re
import time
import unidecode
import ipywidgets as widgets
import numpy as np
import pandas as pd
from ipywidgets import interact
from nltk.stem import WordNetLemmatizer
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold, cross_validate, train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import make_pipeline, make_union
from sklearn.preprocessing import FunctionTransformer, LabelEncoder, MultiLabelBinarizer

いつものように、データセットの最初の数件を表示させてみましょう。

In[6]:

train = pd.read_json('train.json')
test = pd.read_json('test.json')
train.head()
Out[6]:
cuisine id ingredients
0 greek 10259 [romaine lettuce, black olives, grape tomatoes…
1 southern_us 25693 [plain flour, ground pepper, salt, tomatoes, g…
2 filipino 20130 [eggs, pepper, salt, mayonaise, cooking oil, g…
3 indian 22213 [water, vegetable oil, wheat, salt]
4 indian 13162 [black pepper, shallots, cornflour, cayenne pe…

In[7]:

df = pd.concat([train, test])
df['ingredients_text'] = df['ingredients'].apply(lambda x: ', '.join(x))
df['num_ingredients'] = df['ingredients'].apply(lambda x: len(x))
raw_ingredients = [ingredient for ingredients in df.ingredients.values for ingredient in ingredients]
df.head()
Out[7]:
cuisine id ingredients ingredients_text num_ingredients
0 greek 10259 [romaine lettuce, black olives, grape tomatoes… romaine lettuce, black olives, grape tomatoes,… 9
1 southern_us 25693 [plain flour, ground pepper, salt, tomatoes, g… plain flour, ground pepper, salt, tomatoes, gr… 11
2 filipino 20130 [eggs, pepper, salt, mayonaise, cooking oil, g… eggs, pepper, salt, mayonaise, cooking oil, gr… 12
3 indian 22213 [water, vegetable oil, wheat, salt] water, vegetable oil, wheat, salt 4
4 indian 13162 [black pepper, shallots, cornflour, cayenne pe… black pepper, shallots, cornflour, cayenne pep… 20

データの分析

What are ingredients?を参考にしながらデータを見ていきたいと思います。

外れ値の除去

次のような1つの成分のみからなるレシピをいくつか見つけました。

水=>日本語
バター=>インド
バター=>フランス語
モデルに悪影響を与える可能性があります。しかし、そのようなレシピはテストデータセットにも存在します。

In [8]:
from matplotlib import pyplot as plt
import seaborn as sns

plt.figure(figsize=(16,4))
sns.countplot(x='num_ingredients', data=df)
Out[8]:
<matplotlib.axes._subplots.AxesSubplot at 0x7f7b48e000f0>

In[9]:

df[df['num_ingredients'] <= 1]

Out[9]:
cuisine id ingredients ingredients_text num_ingredients
940 japanese 4734 [sushi rice] sushi rice 1
2088 vietnamese 7833 [dried rice noodles] dried rice noodles 1
6787 indian 36818 [plain low-fat yogurt] plain low-fat yogurt 1
7011 indian 19772 [unsalted butter] unsalted butter 1
8181 japanese 16116 [udon] udon 1
8852 thai 29738 [sticky rice] sticky rice 1
8990 indian 41124 [butter] butter 1
10506 mexican 32631 [corn tortillas] corn tortillas 1
13178 thai 29570 [grained] grained 1
17804 southern_us 29849 [lemonade concentrate] lemonade concentrate 1
18136 thai 39186 [jasmine rice] jasmine rice 1
18324 indian 14335 [unsalted butter] unsalted butter 1
21008 italian 39221 [cherry tomatoes] cherry tomatoes 1
22119 french 41135 [butter] butter 1
22387 indian 36874 [cumin seed] cumin seed 1
23512 french 35028 [haricots verts] haricots verts 1
26887 mexican 18593 [vegetable oil] vegetable oil 1
29294 spanish 7460 [spanish chorizo] spanish chorizo 1
30636 spanish 32772 [sweetened condensed milk] sweetened condensed milk 1
32105 japanese 12805 [water] water 1
34531 greek 10816 [phyllo] phyllo 1
37220 indian 27192 [unsalted butter] unsalted butter 1
544 NaN 36822 [plain low-fat yogurt] plain low-fat yogurt 1
3248 NaN 34002 [glutinous rice] glutinous rice 1
3444 NaN 28414 [pimentos] pimentos 1
3621 NaN 10077 [sweetened condensed milk] sweetened condensed milk 1
4021 NaN 32883 [unsalted butter] unsalted butter 1
7417 NaN 45798 [chiles] chiles 1
8081 NaN 45398 [parmesan cheese] parmesan cheese 1
9407 NaN 32743 [shiitake] shiitake 1

すべてのレシピ成分は有効でしょうか? たとえば、2文字以下で構成されているものは意味があるのでしょうか?

In[10]:

[ingredient for ingredient in raw_ingredients if len(ingredient) <= 2]

Out[10]:

['mi', 'mi', 'v8', 'v8', 'mi', 'la', 'mi', 'mi']

特殊文字

どのような特殊文字が含まれているかを確認します。

例えば、

  • “Bertolli® Alfredo Sauce”
  • “Progresso Chicken Broth”
  • “green bell pepper, slice”
  • “half & half”
  • “asafetida (powder)
  • “Spring! Water”

のようなものがあげられます。

In[11]:

' '.join(sorted([char for char in set(' '.join(raw_ingredients)) if re.findall('[^A-Za-z]', char)]))
Out[11]:
"  ! % & ' ( ) , - . / 0 1 2 3 4 5 6 7 8 9 ® â ç è é í î ú ’ € ™"

大文字

大文字を使用している場合、それは固有名詞の可能性があります。

例えば、

  • 会社名
    • Oscar Mayer Deli Fresh Smoked Ham”
  • 地域名
    • Shaoxing wine”
    • California bay leaves”
    • Italian parsley leaves”

In[12]:

list(set([ingredient for ingredient in raw_ingredients if re.findall('[A-Z]+', ingredient)]))[:5]
Out[12]:
['Asian chili sauce',
 'Spanish tuna',
 'Green Giant™ frozen chopped spinach',
 'Spanish olives',
 'Bramley apples']

アポストロフィ

例えば、

  • “Zatarain’s Jambalaya Mix”
  • “Breakstone’s Sour Cream”
  • “sheep’s milk cheese”

のようなものがあります。

データセット内にアポストロフィが多数ある場合は便利ですが、調べたところ数は多くないようです。

In[13]:

list(set([ingredient for ingredient in raw_ingredients if '’' in ingredient]))
Out[13]:
['Breakstone’s Sour Cream', 'sheep’s milk cheese', 'Zatarain’s Jambalaya Mix']

ハイフン( – )

” – “を “”に置き換えても問題ないでしょう。

  • “chicken-apple sausage”
  • “chocolate-hazelnut spread”
  • “bone-in chicken breasts”

In[14]:

list(set([ingredient for ingredient in raw_ingredients if re.findall('-', ingredient)]))[:5]
Out[14]:
['soft-shelled crabs',
 'soft-boiled egg',
 'liquid non-dairy creamer',
 'gluten-free oyster sauce',
 '1% low-fat milk']

数字
数字は数量または密度を示します。

例えば、

  • “1% low-fat milk”
  • “40% less sodium taco seasoning”
  • “mexican style 4 cheese blend”

のように表記されます。厳密に言えば、量は料理を特定する要素になる可能性がありますが、このデータセットには少ししか含まれていません。

In[15]:

list(set([ingredient for ingredient in raw_ingredients if re.findall('[0-9]', ingredient)]))[:5]
Out[15]:
['low sodium 96% fat free ham',
 'KRAFT Mexican Style 2% Milk Finely Shredded Four Cheese',
 '7 Up',
 '1% low-fat milk',
 '(14.5 oz.) diced tomatoes']

単位

ユニットには数字が付いています。

例えば、

  • “(15 oz.) refried beans”
  • “2 1/2 to 3 lb. chicken, cut into serving pieces”
  • “pork chops, 1 inch thick”

単位には特定の地域でのみ使用されるものがあるので、 分類に役立つかもしれません。

In[16]:

units = ['inch', 'oz', 'lb', 'ounc', '%'] # ounc is a misspelling of ounce?

@interact(unit=units)
def f(unit):
    ingredients_df = pd.DataFrame([ingredient for ingredient in raw_ingredients if unit in ingredient], columns=['ingredient'])
    return ingredients_df.groupby(['ingredient']).size().reset_index(name='count').sort_values(['count'], ascending=False)
Out[16]:
ingredient count
0 kinchay 4
2 pork chops, 1 inch thick 3
1 mccormick perfect pinch italian seasoning 1

地域名

In[17]:

keywords = [
    # It indicates the cusine directly
    'american', 'greek', 'filipino', 'indian', 'jamaican', 'spanish', 'italian', 'mexican', 'chinese', 'thai',
    'vietnamese', 'cajun', 'creole', 'french', 'japanese', 'irish', 'korean', 'moroccan', 'russian',
    # Region names I found in the dataset
    'tokyo', 'shaoxing', 'california'
]

@interact(keyword=keywords)
def f(keyword):
    ingredients_df = pd.DataFrame([ingredient for ingredient in raw_ingredients if keyword in ingredient], columns=['ingredient'])
    return ingredients_df.groupby(['ingredient']).size().reset_index(name='count').sort_values(['count'], ascending=False)
Out[17]:
ingredient count
1 american cheese slices 6
2 american eggplant 2
3 american long grain rice 2
0 american cheese food 1

アクセント

一部のアクセントも単位と同様に特定の地域でのみ使用される場合があります。 この情報は使用できるかもしれません。

In[18]:

accents = ['â', 'ç', 'è', 'é', 'í', 'î', 'ú']

@interact(accent=accents)
def f(accent):
    ingredients_df = pd.DataFrame([ingredient for ingredient in raw_ingredients if accent in ingredient], columns=['ingredient'])
    return ingredients_df.groupby(['ingredient']).size().reset_index(name='count').sort_values(['count'], ascending=False)
Out[18]:
ingredient count
0 Neufchâtel 7
4 pâté 3
3 pâte brisée 2
1 bâtarde 1
2 hellmann’ or best food canola cholesterol fr… 1

ユニークな食材

一部の食材は特定の地域でのみ使用されます。この情報も分類に役に立ちそうです。

In[19]:

lemmatizer = WordNetLemmatizer()
def preprocess(ingredients):
    ingredients = ' '.join(ingredients).lower().replace('-', ' ')
    ingredients = re.sub("\d+", "", ingredients)
    return [lemmatizer.lemmatize(ingredient) for ingredient in ingredients.split()]

ingredients_df = df.groupby(['cuisine'])['ingredients'].sum().apply(lambda ingredients: preprocess(ingredients)).reset_index()
unique_ingredients = []
for cuisine in ingredients_df['cuisine'].unique():
    target = set(ingredients_df[ingredients_df['cuisine'] == cuisine]['ingredients'].values[0])
    others = set(ingredients_df[ingredients_df['cuisine'] != cuisine]['ingredients'].sum())
    unique_ingredients.append({
        'cuisine': cuisine,
        'ingredients': target - others
    })
pd.DataFrame(unique_ingredients, columns=['cuisine', 'ingredients'])
Out[19]:
cuisine ingredients
0 brazilian {maca, longaniza, licor, linguica, acai, farof…
1 british {kippered, standing, poundcake, swede, bénédic…
2 cajun_creole {tuttorosso, smart, seven, blackened, tony, wh…
3 chinese {cheong, loofah, ploy, wolfberries, tangzhong,…
4 filipino {nestle, ampalaya, bihon, kangkong, dew, blueb…
5 french {glacés, lillet, verbena, frankfurter, kamut, …
6 greek {ammonium, mezzetta, graviera, mahlab, cavende…
7 indian {chana, puffed, barberry, chapatti, indian, ar…
8 irish {porter, nettle, maraschino, challenge, scone,…
9 italian {calabrese, ricard, gnocchetti, fume, robusto,…
10 jamaican {caribbean, patty, cho, any, cara, callaloo, p…
11 japanese {tsuyu, unagi, mentsuyu, dhaniya, maitake, cha…
12 korean {azuki, riso, angus, gochugaru, kochu, trimmed…
13 mexican {guanabana, beater, unhulled, panela, anejo, m…
14 moroccan {pareve, mince, moroccan, bordelaise, arak, ha…
15 russian {beluga, pullman, bacardi®, bear, bermuda, tus…
16 southern_us {secret, velvet, jarred, better, chowchow, alm…
17 spanish {vera, foccacia, cabrales, atlantic, valencia,…
18 thai {wine,, prik, based, belacan, basa, muscavado,…
19 vietnamese {olie, paddy, romanesco, hubbard, woksaus, chu…

ここで、エラーがでる可能性があります。出た場合は次のコードを入力してください。

import nltk
nltk.download('wordnet')

言語

アクセントと同様に、私たちは文字のシーケンスを見ることによって、どの言語がその成分であるかを推測することができます。

  • tofu => Japanese
  • purée => French

食材から言語情報を抽出できるでしょうか? この場合、食材の言語をどのように検出するかについて考える必要があります。

In[20]:

text_languages = []
for text in [
    'ein, zwei, drei, vier',
    'purée',
    'taco',
    'tofu',
    'tangzhong',
    'xuxu',
]:
    text_languages.append({
        'text': text,
        'detected language': langdetect.detect(text)
    })
pd.DataFrame(text_languages, columns=['text', 'detected language'])
Out[20]:
text detected language
0 ein, zwei, drei, vier de
1 purée fr
2 taco es
3 tofu en
4 tangzhong tl
5 xuxu so

スペルミス

データセットにスペルミスがあることがわかりました。

  • ounc (ounce)
  • wasabe (wasabi)

正規化

私は特殊文字、スペルミスを含むいくつかの食材を見たところ、私はおそらく食材を正規化しなければならないでしょう。

In[21]:

from IPython.display import clear_output

ingredients = ['romaine lettuce', 'Eggs', 'Beef demi-glace', 'Sugar 10g', 'Pumpkin purée', 'Kahlúa']
labels = [widgets.Label(ingredient) for ingredient in ingredients]

lower_checkbox = widgets.Checkbox(value=False, description='lower', indent=False)
lemmatize_checkbox = widgets.Checkbox(value=False, description='lemmatize', indent=False)
remove_hyphens_checkbox = widgets.Checkbox(value=False, description='remove hyphens', indent=False)
remove_numbers_checkbox = widgets.Checkbox(value=False, description='remove numbers', indent=False)
strip_accents_checkbox = widgets.Checkbox(value=False, description='strip accents', indent=False)

lemmatizer = WordNetLemmatizer()
def lemmatize(sentence):
    return ' '.join([lemmatizer.lemmatize(word) for word in sentence.split()])
assert lemmatize('eggs') == 'egg'

def remove_numbers(sentence):
    words = []
    for word in sentence.split():
        if re.findall('[0-9]', word): continue
        if len(word) > 0: words.append(word)
    return ' '.join(words)

def update_ingredients(widget):
    for i, ingredient in enumerate(ingredients):
        processed = ingredient
        if lower_checkbox.value: processed = processed.lower()
        if lemmatize_checkbox.value: processed = lemmatize(processed)
        if remove_hyphens_checkbox.value: processed = processed.replace('-', ' ')
        if remove_numbers_checkbox.value: processed = remove_numbers(processed)
        if strip_accents_checkbox.value: processed = unidecode.unidecode(processed)
        if processed == ingredient:
            labels[i].value = ingredient
        else:
            labels[i].value = f'{ingredient} => {processed}'

lower_checkbox.observe(update_ingredients)
lemmatize_checkbox.observe(update_ingredients)
remove_hyphens_checkbox.observe(update_ingredients)
remove_numbers_checkbox.observe(update_ingredients)
strip_accents_checkbox.observe(update_ingredients)

display(widgets.VBox([
    widgets.Box([lower_checkbox, lemmatize_checkbox, remove_hyphens_checkbox, remove_numbers_checkbox, strip_accents_checkbox]),
    widgets.VBox(labels)
]))

モデル作り

次は、Let’s cook modelと今までに得たデータをもとにモデルを作っていきたいと思います。

次のような流れになります。しかし、いくつかは

  1. Load dataset (済み)
  2. Remove outliers (上の分析をもとに外れ値を取り除く)
  3. Preprocess (
  4. Create model
  5. Check local CV
  6. Train model
  7. Check predicted values

データのロード

データセットをロードするのは上でやりましたが、いろいろいじったので改めて作り直します。

In[22]:

train = pd.read_json('train.json')
test = pd.read_json('test.json')

ライブラリも入れなおします。

In[23]:

import json
import re
import unidecode
import numpy as np
import pandas as pd
from collections import defaultdict
from nltk.stem import WordNetLemmatizer
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import cross_validate
from sklearn.multiclass import OneVsRestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from sklearn.pipeline import make_pipeline, make_union
from sklearn.preprocessing import FunctionTransformer, LabelEncoder
from tqdm import tqdm
tqdm.pandas()

外れ値の除去

上でも言った通り、食材が一つのレシピがあります。このようなレシピにフィルタをかけてみましょう。

In[24]:

train['num_ingredients'] = train['ingredients'].apply(len)
train = train[train['num_ingredients'] > 1]

前処理

データ分析から、前処理は次のようになっています。

  • 小文字に変換
  • ハイフンを削除する
  • 番号を削除する
  • 2文字未満の単語を削除する
  • 最適化する

In[25]:

lemmatizer = WordNetLemmatizer()
def preprocess(ingredients):
    ingredients_text = ' '.join(ingredients)
    ingredients_text = ingredients_text.lower()
    ingredients_text = ingredients_text.replace('-', ' ')
    words = []
    for word in ingredients_text.split():
        if re.findall('[0-9]', word): continue
        if len(word) <= 2: continue
        if '’' in word: continue
        word = lemmatizer.lemmatize(word)
        if len(word) > 0: words.append(word)
    return ' '.join(words)

for ingredient, expected in [
    ('Eggs', 'egg'),
    ('all-purpose flour', 'all purpose flour'),
    ('purée', 'purée'),
    ('1% low-fat milk', 'low fat milk'),
    ('half & half', 'half half'),
    ('safetida (powder)', 'safetida (powder)')
]:
    actual = preprocess([ingredient])
    assert actual == expected, f'"{expected}" is excpected but got "{actual}"'
In [26]:
train['x'] = train['ingredients'].progress_apply(preprocess)
test['x'] = test['ingredients'].progress_apply(preprocess)
train.head()
100%|██████████| 39752/39752 [00:04<00:00, 8436.79it/s]
100%|██████████| 9944/9944 [00:01<00:00, 8387.72it/s]
Out[26]:
cuisine id ingredients num_ingredients x
0 greek 10259 [romaine lettuce, black olives, grape tomatoes… 9 romaine lettuce black olive grape tomato garli…
1 southern_us 25693 [plain flour, ground pepper, salt, tomatoes, g… 11 plain flour ground pepper salt tomato ground b…
2 filipino 20130 [eggs, pepper, salt, mayonaise, cooking oil, g… 12 egg pepper salt mayonaise cooking oil green ch…
3 indian 22213 [water, vegetable oil, wheat, salt] 4 water vegetable oil wheat salt
4 indian 13162 [black pepper, shallots, cornflour, cayenne pe… 20 black pepper shallot cornflour cayenne pepper …

後でTfidfVectorizerのパラメータを調整する必要があります。

In[27]:

vectorizer = make_pipeline(
    TfidfVectorizer(sublinear_tf=True),
    FunctionTransformer(lambda x: x.astype('float16'), validate=False)
)

x_train = vectorizer.fit_transform(train['x'].values)
x_train.sort_indices()
x_test = vectorizer.transform(test['x'].values)

ラベルエンコード

In [28]:
label_encoder = LabelEncoder()
y_train = label_encoder.fit_transform(train['cuisine'].values)
dict(zip(label_encoder.classes_, label_encoder.transform(label_encoder.classes_)))
Out[28]:
{'brazilian': 0,
 'british': 1,
 'cajun_creole': 2,
 'chinese': 3,
 'filipino': 4,
 'french': 5,
 'greek': 6,
 'indian': 7,
 'irish': 8,
 'italian': 9,
 'jamaican': 10,
 'japanese': 11,
 'korean': 12,
 'mexican': 13,
 'moroccan': 14,
 'russian': 15,
 'southern_us': 16,
 'spanish': 17,
 'thai': 18,
 'vietnamese': 19}

モデルを作成する

LogisticRegression、GaussianProcessClassifier、GradientBoostingClassifier、MLPClassifier、LGBMClassifier、SGDClassifier、Kerasを試したところ、SVCがいまのところうまく機能しています。

モデルとパラメータをもっと詳しく調べる必要があります。

In[29]:

estimator = SVC(
    C=50,
    kernel='rbf',
    gamma=1.4,
    coef0=1,
    cache_size=500,
)
classifier = OneVsRestClassifier(estimator, n_jobs=-1)

 local CVを確認する

あなたの local CVを信頼してください。これは一番大事なことです。local CVを見ながら、異なる処理とパラメータを試してみてください。

In[30]:

%%time
scores = cross_validate(classifier, x_train, y_train, cv=3)
scores['test_score'].mean()
Out[30]:
CPU times: user 1h 37min 38s, sys: 32.3 s, total: 1h 38min 10s
Wall time: 1h 38min 1s

Train model

モデルに信用が持てるようになると、私はtrainデータ全体を提出するように学習します。

In [31]:
%%time
classifier.fit(x_train, y_train)
Out[31]:
CPU times: user 296 ms, sys: 208 ms, total: 504 ms
Wall time: 22min 3s
OneVsRestClassifier(estimator=SVC(C=50, cache_size=500, class_weight=None, coef0=1,
  decision_function_shape='ovr', degree=3, gamma=1.4, kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False),
          n_jobs=-1)

予測値をチェックする

モデルがうまく学習しているかどうか確認してみます。

In[32]:

y_pred = label_encoder.inverse_transform(classifier.predict(x_train))
y_true = label_encoder.inverse_transform(y_train)

print(f'accuracy score on train data: {accuracy_score(y_true, y_pred)}')

def report2dict(cr):
    rows = []
    for row in cr.split("\n"):
        parsed_row = [x for x in row.split("  ") if len(x) > 0]
        if len(parsed_row) > 0: rows.append(parsed_row)
    measures = rows[0]
    classes = defaultdict(dict)
    for row in rows[1:]:
        class_label = row[0]
        for j, m in enumerate(measures):
            classes[class_label][m.strip()] = float(row[j + 1].strip())
    return classes
report = classification_report(y_true, y_pred)
pd.DataFrame(report2dict(report)).T
Out[32]:
accuracy score on train data: 0.9996478164620648
Out[33]:
f1-score precision recall support
brazilian 1.0 1.0 1.0 467.0
british 1.0 1.0 1.0 804.0
cajun_creole 1.0 1.0 1.0 1546.0
chinese 1.0 1.0 1.0 2673.0
filipino 1.0 1.0 1.0 755.0
french 1.0 1.0 1.0 2644.0
greek 1.0 1.0 1.0 1174.0
indian 1.0 1.0 1.0 2997.0
irish 1.0 1.0 1.0 667.0
italian 1.0 1.0 1.0 7837.0
jamaican 1.0 1.0 1.0 526.0
japanese 1.0 1.0 1.0 1420.0
korean 1.0 1.0 1.0 830.0
mexican 1.0 1.0 1.0 6436.0
moroccan 1.0 1.0 1.0 821.0
russian 1.0 1.0 1.0 489.0
southern_us 1.0 1.0 1.0 4319.0
spanish 1.0 1.0 1.0 987.0
thai 1.0 1.0 1.0 1536.0
vietnamese 1.0 1.0 1.0 824.0
micro avg 1.0 1.0 1.0 39752.0
macro avg 1.0 1.0 1.0 39752.0
weighted avg 1.0 1.0 1.0 39752.0

提出

In[34]:

y_pred = label_encoder.inverse_transform(classifier.predict(x_test))
test['cuisine'] = y_pred
test[['id', 'cuisine']].to_csv('submission.csv', index=False)
test[['id', 'cuisine']].head()
Out[34]:
id cuisine
0 18009 irish
1 28583 southern_us
2 41580 italian
3 29752 cajun_creole
4 35687 italian

最後に、出力させたcsvファイルをローカルに落として提出しましょう。

In[35]:

from google.colab import files
files.download('submission.csv')

今回は、レシピの食材から料理の種類を予測する問題を試してみました。色々なデータセットを見ることで、アプローチが違うのが面白いですね。ほかのデータセットを使った記事もあるのでぜひ見てください。

最後まで読んでいただきありがとうございました。よろしければこの記事をシェアしていただけると励みになります。よろしくお願いします。

スポンサーリンク
レクタングル広告(大)
レクタングル広告(大)

シェアする

  • このエントリーをはてなブックマークに追加

フォローする

スポンサーリンク
レクタングル広告(大)