過学習とプルーニング
上記のアルゴリズムを使用すると、トレーニング サンプルが分離可能であると仮定して、トレーニング サンプルを完全に分類するディシジョン ツリーをトレーニングできます。ただし、データセットにノイズが含まれていると、このツリーがデータに過学習し、テストの精度が低下します。
次の図は、特徴 x とラベル y との間に線形関係があるノイズの多いデータセットを示しています。この図は、このデータセットでトレーニングされたディシジョン ツリーも示しています。正則化は一切行いません。このモデルは、すべてのトレーニング例を正しく予測します(モデルの予測はトレーニング例と一致します)。ただし、同じ線形パターンと異なるノイズ インスタンスを含む新しいデータセットでは、モデルのパフォーマンスが低下します。
図 12. ノイズの多いデータセット。
ディシジョン ツリーの過学習を制限するには、ディシジョン ツリーをトレーニングする際に、次の正則化基準のいずれかまたは両方を適用します。
- 最大深度を設定する: ディシジョン ツリーが最大深度(10 など)を超えて成長しないようにします。
- リーフのサンプルの最小数を設定する: 特定のサンプル数未満のリーフは分割の対象とみなされません。
次の図は、リーフごとに最小サンプル数が異なる場合の影響を示しています。モデルでキャプチャされるノイズが少なくなる。
図 13. リーフごとの最小サンプル数の違い。
また、特定のブランチを選択的に削除(プルーニング)することで、つまり、特定のリーフ以外のノードをリーフに変換することで、トレーニング後に正則化することもできます。削除するブランチを選択する一般的な解決策は、検証データセットを使用することです。つまり、ブランチを削除することで検証データセットのモデルの品質が向上する場合、そのブランチは削除されます。
次の図は、この考えを示しています。ここでは、リーフ以外の緑色のノードがリーフに変換された場合(オレンジ色のノードをプルーニングする場合)に、ディシジョン ツリーの検証精度が向上するかどうかをテストします。
図 14. 条件とその子をリーフにプルーニングする。
次の図は、ディシジョン ツリーをプルーニングするために、データセットの 20% を検証として使用した場合の効果を示しています。
図 15. データセットの 20% を使用してディシジョン ツリーをプルーニングします。
検証データセットを使用すると、ディシジョン ツリーの初期トレーニングに使用できるサンプルの数が少なくなります。
多くのモデル作成者は複数の基準を適用します。たとえば、次のすべてを行うことができます。
- リーフごとに最小数のサンプルを適用します。
- ディシジョン ツリーの成長を制限するために最大深度を適用します。
- ディシジョン ツリーをプルーニングします。
- サンプルの最小数は 5(
min_examples = 5
)です - トレーニング データセットの 10% が検証用に保持されます(
validation_ratio = 0.1
)。
validation_ratio=0.0
を設定します。
これらの基準により、調整が必要な新しいハイパーパラメータ(ツリーの最大深度など)が導入されます。多くの場合、ハイパーパラメータ チューニングは自動で行われます。ディシジョン ツリーは通常、ハイパーパラメータ チューニングと交差検証を使用するためにトレーニングできる速度で十分です。たとえば、「n」サンプルを含むデータセットでは次のようになります。
- トレーニング サンプルを重複しない p 個のグループに分割します。例:
p=10
。 - 使用可能なすべてのハイパーパラメータ値(例: {3,5,6,7,8,9} の最大深度、{5,8,10,20} の最小例)。
- 各グループで、他の p-1 グループでトレーニングしたディシジョン ツリーの品質を評価します。
- グループ全体で評価を平均化する。
- 平均評価が最も高いハイパーパラメータ値を選択します。
- 「n」個のすべてのサンプルと選択したハイパーパラメータを使用して、最終的なディシジョン ツリーをトレーニングします。
このセクションでは、ディシジョン ツリーで過学習を制限する方法について説明しました。これらの方法にもかかわらず、学習不足と過学習がディシジョン ツリーの主な弱点です。デシジョン フォレストは、過学習を制限する新しい方法を導入します。これについては後で説明します。
直接的なディシジョン ツリーの解釈
ディシジョン ツリーは簡単に解釈できます。とはいえ、少数の例を変更しただけでも、ディシジョン ツリーの構造、ひいては解釈が完全に変わる可能性があります。
ディシジョン ツリーの構築方法(トレーニング サンプルのパーティショニング)により、ディシジョン ツリーを使用して(モデルではなく)データセット自体を解釈できます。各リーフは、データセットの特定の角を表します。
model.describe()
関数を使用してツリーを表示できます。model.get_tree()
を使用して、個々のツリーにアクセスしてプロットすることもできます。詳細については、
YDF のモデル検査チュートリアルをご覧ください。しかし、間接的な解釈も有益です。