XGBoostのScikit-Learn APIでearly stoppingを利用する
この記事では、XGBoostのScikit-Learn APIを使いながらもearly stoppingを利用する方法を紹介します。
一般的な方法
XGBoostのLearning APIとは違って、Scikit-Learn APIのXGBClassifier
クラス自体にはearly stoppingのパラメータがありません。その代わりにXGBClassifier.fit()
の引数にearly_stopping_rounds
がありますので、こちらを利用します。その際にはeval_set
を同時に指定する必要があります。
xgb_model = XGBClassifier()
xgb_model.fit(X_train,
y_train,
early_stopping_rounds=100,
eval_set=[[X_test, y_test]])
Python API Reference — xgboost 0.6 documentation
GridSearchCV/RandomizedSearchCVを併用する方法
実際にscikit-learnと組み合わせている場合には単体でのfitよりも、GridSearchCVやRandomizedSearchCVといったグリッドサーチと併用することが多いです。その際には、以下のようにfit_params
を指定することによって、グリッドサーチ内でのearly stoppingが可能になります。
fit_params = {"early_stopping_rounds": 100, "eval_set": [[X_test, y_test]]} xgb_model = xgb.XGBClassifier() gs = GridSearchCV(xgb_model, params, fit_params=fit_params, cv=10, n_jobs=-1, verbose=2) gs.fit(X_train, y_train)