yag's blog

Twitter以上Zenn以下なことを書く場所

XGBoostのScikit-Learn APIでearly stoppingを利用する

この記事では、XGBoostのScikit-Learn APIを使いながらもearly stoppingを利用する方法を紹介します。

一般的な方法

XGBoostのLearning APIとは違って、Scikit-Learn APIXGBClassifierクラス自体にはearly stoppingのパラメータがありません。その代わりにXGBClassifier.fit()の引数にearly_stopping_roundsがありますので、こちらを利用します。その際にはeval_setを同時に指定する必要があります。

xgb_model = XGBClassifier()
xgb_model.fit(X_train,
              y_train,
              early_stopping_rounds=100,
              eval_set=[[X_test, y_test]])

Python API Reference — xgboost 0.6 documentation

GridSearchCV/RandomizedSearchCVを併用する方法

実際にscikit-learnと組み合わせている場合には単体でのfitよりも、GridSearchCVやRandomizedSearchCVといったグリッドサーチと併用することが多いです。その際には、以下のようにfit_paramsを指定することによって、グリッドサーチ内でのearly stoppingが可能になります。

fit_params = {"early_stopping_rounds": 100,
              "eval_set": [[X_test, y_test]]}

xgb_model = xgb.XGBClassifier()
gs = GridSearchCV(xgb_model,
                  params,
                  fit_params=fit_params,
                  cv=10,
                  n_jobs=-1,
                  verbose=2)
gs.fit(X_train, y_train)

参考