目次
はじめに
この記事では、#30DaysOfStreamlitの内容の紹介を行います。
#30DaysOfStreamlitについてはコチラの記事を参照してください。
SHAPとは
SHAP(SHapley Additive exPlanations)は、機械学習モデルの予測に対する各特徴量の寄与度を評価するための手法です。
この手法は、シャプレー値と呼ばれるゲーム理論から派生した「貢献度」を用いて、各特徴量が予測値にどの程度貢献しているかを算出することができます。
SHAPは、機械学習モデルの正確性と解釈可能性の両方を向上することができます。
モデルの予測結果を解釈することは、重要な決定を行う上で必要不可欠です。
SHAPは、ローカル解釈可能性(LIME)やパーティシパル・インタープリタブル・マシンラーニング(PDPbox)などの手法と比較して、モデル全体に対する解釈性を提供することができます。
また、SHAPは、異なる特徴量の組み合わせに対する相互作用の影響を評価することもできます。
SHAPは、Pythonのライブラリとして、様々な機械学習フレームワークに統合されています。
例えば、XGBoostやLightGBM、scikit-learnなどに統合されており、容易に利用することができます。
SHAPを使用することで、モデルの予測をより正確に理解し、説明することができます。
詳細については、こちらの記事を参照してください。
Streamlit-shapとは
streamlit-shap
とは、StreamlitでSHAPプロットを表示する為のラッパーを提供するStreamlitコンポーネントです。
SHAPプロットに関しては、こちらの記事を参照してください。(利用しているライブラリは同じです。)
導入手順
必要なライブラリをインストールします。
また、今回は機械学習モデルの作成も行う為、インストールしてない場合は、以下の様にライブラリを追加します。(コンポーネントが依存しているmatplotlibも導入します。)
構築する目標
Pythonライブラリであるshapでのデモデータを利用して、xgboostでの機械学習モデルを作成して、そのモデルに対してSHAP分析を行うアプリケーションを構築します。
今回利用するデータセットには、説明変数としてアメリカ合衆国の成人の収入に関するデータで、教育レベル、年齢、労働時間などの特徴量を含んでいます。 目的変数は、収入が50,000ドルを超えるか否かを予測するためのもので、"<=50K"(50,000ドル以下)または">50K"(50,000ドル超)の2つのクラスで構成されます。
また、説明変数のすべてが数値化(量的・数値カテゴリ)された状態で格納されています。 以下にデータの例を記述します。
Age | Workclass | Education-Num | ・・・ | Country | target_label |
39.0 | 7 | 13.0 | ・・・ | 39 | False |
50.0 | 6 | 13.0 | ・・・ | 39 | False |
38.0 | 4 | 9.0 | ・・・ | 39 | False |
: | : | : | ・・・ | : | : |
52.0 | 6 | 9.0 | ・・・ | 39 | True |
上記のデータを利用して作成した機械学習モデルを使い、以下のプロットでSHAPの機能を紹介するアプリケーションを構築します。
- Waterfall plot
- Beeswarm plot
- Force plot
アプリケーションの構築
下記のようなPythonスクリプトを準備します。
上記のスクリプトを起動させると下記のような画面が展開されます。
入力データの確認タブを展開すると以下のような表示が出ます。
Xが説明変数であり、yが目的変数の状況です。
SHAPの出力に関して、Waterfall plotは以下のような表示になります。
コチラのプロットは、とあるレコードの各説明変数に対して目的変数の推論予測がどのように影響しているか表示しています。
マイナスに振れている場合はFalseの要因として、プラスに振れている場合はTrueの要因と認識できます。
続いて、Beeswarm plotは以下のような表示になります。
コチラは、データセットの全件に対して横軸で各説明変数がどれほどのSHAP値をマークしたのかを散布図で表しており、色合いでその説明変数の大小を表しています。
コチラのプロットでは、説明変数の大小とSHAP値がどのように関連しているか確認できます。
最後に、Force plotは以下のような表示になります。
First data instanceでは、とあるレコードのSHAP値の広がりを確認することができます。
赤いバーがプラス(True要因)で青いバーがマイナス(False要因)です。
最終的なSHAP値は赤いバーと青いバーの境目となります。
First thousand data instanceは、データセットの最初の1,000件に対し縦軸でSHAP値を表し、横軸がデータセットのインデックスになります。
また、縦軸と横軸は可変であり、縦軸は指定の説明変数のSHAP値のみに絞り込むことができ、横軸はレコードの並びを指定の説明変数で並べ替えが可能です。
このグラフにより、SHAP値の影響を調査することができます。
また、これらのグラフはDarkモードだと見づらいため、Lightモードにすることをお勧めします。
コードの解説
まずは、必要なライブラリをインポートします。
ページ設定をwideに変更します。
shapライブラリからデータをロードします。
続いて取り込んだデータを元にXGBoostのモデルを作成する関数を記述します。
この際に、リロードによる再実行を短時間に抑える為に、cache_data、cache_resourceを利用します。
アプリケーションのタイトルと説明を作成します。
次に、利用するデータセットをロードします。
この時にst.headerを利用してページのヘッダーも設定していきます。
また、データのロードに関しては、モデル作成用にload_data()を利用したX、yを作成し、データ確認用にshap.datasets.adult()を引数であるdisplayをTrueにしてX_display、y_displayを作成します。
データ確認用の表示UIを作成します。
st.expander()を利用してドリルダウン形式で表示します。
続いて、学習モデルの作成とSHAP計算を行います。
SHAPの計算結果を表示する為、ページのヘッダーを追加します。
Waterfall plotとBeeswarm plotを表示させます。
データ確認のUIと同様ドリルダウン形式で表示させています。
続いて、Force plotを表示させる為、グラフのインスタンスを生成後、ドリルダウン形式で表示させています。