学習曲線(Learning Curve)で過学習、学習不足を検証

データサイエンス
スポンサーリンク

前回はvalidation_curveでパラメータの範囲を絞り込む方法を使ってGridSearchCVの実行時間削減に挑戦しました。各パラメータの最適値についてはGridSearchCVで求めることはできるようになりました。

その一方で、学習データ数によって学習不足だったり過学習を起こしていないか?という心配が出てきます。サンプル数が少なかったり、学習データのパラメータが多すぎると過学習を起こしてしまって、検証データにうまく適応できてない、ということはよくあるので、過学習のチェックは必須でしょう。

そこで使うのが学習曲線(learning_curve)です。learning_curveはサンプル数を変えながら学習データと検証データの正解率について、

  • 両者がどのくらいの正解率に着地するか?(=漸近線)
  • 両者の乖離はどれくらいか?

といった観点で比較・検証していきます。

スポンサーリンク

Learning Curveのサンプル

それではさっそく、Learning Curveのコードを書いていきます。今までと同様にKaggleのTitanic課題のデータを使っていますので、学習データと検証データの作り方については過去の記事を参考にしてください。

今回のモデルはランダムフォレストのデフォルトを使ってみます。learning_curve関数の引数は以下になります。

  • estimator:検証したいモデル
  • X :入力データ
  • y : 出力データ
  • train_sizes : 試したいサンプル数([100, 200, 300, …, 1000])
  • cv : バリデーションデータセットを作成する際の分割方法(デフォルトは5-fold法)
from sklearn.model_selection import learning_curve
clf = RandomForestClassifier()
train_sizes, train_scores, test_scores = learning_curve(estimator = clf, X = X_train,train_sizes=train_sizes, y = y_train, cv=10, n_jobs=1)
train_mean = np.mean(train_scores, axis=1)
train_std  = np.std(train_scores, axis=1)
test_mean = np.mean(test_scores, axis=1)
test_std  = np.std(test_scores, axis=1)

learning_curveは学習データの結果と検証データの結果の2つの結果をそれぞれ配列で返します。そこでこれらの結果を比較するためにグラフ化します。

import matplotlib.pyplot as plt
train_scores_mean = np.mean(train_scores, axis=1)
train_scores_std = np.std(train_scores, axis=1)
test_scores_mean = np.mean(test_scores, axis=1)
test_scores_std = np.std(test_scores, axis=1)
plt.figure()
plt.title("Learning Curve")
plt.xlabel("Training examples")
plt.ylabel("Score")
# Traing score と Test score をプロット
plt.plot(train_sizes, train_scores_mean, 'o-', color="r", label="Training score")
plt.plot(train_sizes, test_scores_mean, 'o-', color="g", label="Validation score")
# 標準偏差の範囲を色付け
plt.fill_between(train_sizes, train_scores_mean - train_scores_std, train_scores_mean + train_scores_std, color="r", alpha=0.2)
plt.fill_between(train_sizes, test_scores_mean - test_scores_std, test_scores_mean + test_scores_std, color="g", alpha=0.2)
# Y軸の範囲
plt.ylim(0.7, 1.0)
# 凡例の表示位置
plt.legend(loc="best")
plt.show()

これだけでは検証データの正解率が上がっていないのはわかりますが、だから何?と言った感じで、どうしたら良いのかがわかりませんね。

Learning Curveの使い方

Learning Curveの結果は大きく3つに分けられます。

成功パターン(両方とも高い正解率で収束)

学習データ、検証データともに高い正解率に収束し、サンプル数が増えても検証データが下がることがない場合は成功パターンと言えます。

学習不足パターン(両方とも低い正解率)

学習データと検証データが近づいているものの、正解率が低い場合は学習不足が考えられます。このような学習不足の場合はパラメータを増やすのが一般的です。

過学習(学習データだけ高い正解率)

学習データの正解率だけ高くて、検証データの正解率が低い場合、過学習を起こしている可能性が高いです。これが発生するのは主にサンプル数が少なくて、パラメータ数が多すぎる場合です。

こちらは検証データの正解率が0.83くらいで頭打ちになっていることがわかります。このことからサンプル数を増やしても正解率が上がらないことが予想されます。

この場合は、パラメータが多すぎる可能性が高いので、パラメータを減らしてみる事をオススメします。

一方、下記の場合は、サンプル数が1400の時点でも、それ以降正解率が上がり続ける可能性がありそうです。このような場合はサンプル数を増やして再度検証するのが良いでしょう。

Learning Curveでモデル作成後の過学習の検証に使う

Learning Curveは過学習の検証に使うのが多いようです。そのため、使う流れとしては、GridSearchCVなどでモデルを作成後で、このLearning Curveで過学習の可能性がある場合は、パラメータの変更が必要になってきます。

次回はパラメータ(特徴量)の選定方法について紹介したいと思います。

コメント

タイトルとURLをコピーしました