この記事では、#30DaysOfStreamlitの内容の紹介を行います。
#30DaysOfStreamlitについてはコチラの記事を参照してください。
今回の課題は、Charly Wargnier氏がブログで紹介していたZero-shotモデルでのテキスト分類機能を再現する形になります。
元のブログはこちらになります。
ディープラーニングや機械学習に於いては、学習用のデータセットを使って学習を行い、学習させたモデルで分類や回帰問題の固有タスクを解くのが一般的な利用方法になります。
ですが、近年のLLM(大規模言語モデル)はタスク毎の学習(ファインチューニング)の前に事前学習を行う事で言語特性等を習得することが分かっており、事前学習の段階で入力文に従うような調整を受けたモデルであれば、学習用のデータセットが無くとも分類や文章生成が可能となります。
このようなモデルをZero-shot学習モデルといいます。
今回の課題では、Hugging Face APIを通してZero-shot学習モデルを利用します。
Hugging Face の API 推論と Distilbart を使用して、ゼロショット学習テキスト分類器を作成します。
アプリケーションを使用すると、ML学習なしで、キーフレーズをその場で高速に分類できます。
これらのラベルは、次のように任意に最大3種類設定できます。
また、今回は元のブログのコードを元に上記機能を持つアプリケーションを構築します。
最初に今回利用するコンポーネントとライブラリをインストールしておきます。
続いて、アプリケーション上で画像を利用するので、元ブログで紹介されているGithubから画像を拝借します。
この画像は、アプリケーションのコードが配置してあるディレクトリに配置します。
続いてHugging Face APIを利用するために、Hugging Faceにサインアップします。 コチラからアクセスして右上の「Sign Up」からユーザー作成をします。
ユーザー作成後、こちらからAPIアクセストークンを作成してメモしてください。
下記のようなPythonスクリプトを作成します。
上記のコードを実行すると下記のような画面が展開します。
上部のテキストボックスで最大3種までのラベルを設定できます。
また、分類の対象となるフレーズは下部のテキストボックスで入力します。
最後に「Submit」を押下すると下記のような画面に遷移します。
分類が成功すると上記のような画面となり、結果をテーブル形式で確認できます。
また、一番下の「Download results as CSV」を押下するとテーブル形式で表示している結果をCSV形式でダウンロードできます。
最初に必要なライブラリをインポートします。
続いて、セッションデータを参照してwideが指定されている場合はlayoutをwideに変更するようにします。
ページの設定を行います。
この時に、layoutは上記のコードで指定したものとして、ページタイトルとアイコンを設定しています。
Hugging Face APIに送るユーザーが入力したデータが存在するかのフラグをセッションデータに記述します。
初期起動なので、valid_inputs_receivedが存在しない場合、valid_inputs_receivedをFalseで保存します。
続いて、ページのタイトルを作成していきます。
st.columns()を利用して画像領域(c1)とテキスト領域(c2)に分けて設定していきます。
そのあと、アプリケーションの説明文を作成します。
そして、アプリケーション利用ユーザー用の入力フォームを作成します。
この時にAPI_KEY
に準備編で取得したAPIアクセストークンを記入します。(本来はセキュリティ上、環境変数等に埋め込む等の対応が必要でありますが、、、)
APIアクセス用のURLをリクエスト用ヘッダーを作成します。
streamlit-tagのコンポーネントからst_tagsを使い、タグの入力UIを作成しています。
valueに初期値を入力して、タグ数をmaxtagsで3種に限定しています。
ここでの入力がZero-shot learning分類でのラベルに相当します。
続いて、分類したいフレーズを入力するUIを作成します。
最初に、サンプル用の文章を作成します。
サンプル文をnums
に用意してnew_line.join(map(str, nums))
で改行で各文を区切った文字列を作成して、それをsample
に格納してます。
さらに、st.text_area()
を使ってフレーズ入力用UIを作成しています。
この時に、MAX_LINES_FULL
で入力フレーズの限界数を設定しています。
初期値として先ほど作成したsample
を利用しています。
入力用インターフェースを作成後、入力したフレーズをリストに格納します。
この時に重複文削除(下から2行目)とNull削除(下から1行目)を行っています。
入力されたフレーズが設定した既定の数を超えてないか確認を行います。
この時に既定数を超えていれば、infoでユーザーに既定数越えのメッセージを出してフレーズリストを既定数に揃えます。
最後にこの入力フォームのsubmitボタンを作成します。
ここからは、submit
ボタンが押されたときに入力データが異常なものであるかチェックを行っています。
最初のif
文で初期状態検知して入力時間を維持しています。(submit
が押されてなく、インプットデータもない状態)
次のelif
で分類対象のフレーズがない場合、その旨の警告を表示しvalid_inputs_received
をFalse
に変更します。(submit
が押されており、分類対象フレーズがない状態)
さらに次のelif
で分類ラベルがない場合、その旨の警告を表示し、valid_inputs_received
をFalse
に変更します。(submit
が押されており、分類ラベルがない状態)
最後にelif
で分類ラベルが1種類しかない場合、その旨の警告を表示しvalid_inputs_received
をFalse
に変更します。(submit
が押されており、分類ラベルが1種類しかない状態)
そして、submit
が押されているもしくは、valid_inputs_received
がTrue
であれば、推論処理に移行します。
まずは、APIを呼び出すのでネットワーク障害等の不確定要素でエラーが出る可能性があるため、try文を利用します。
その後、submit
が押されたときにvalid_inputs_received
をTrueに変更します。
そして、query()
というAPIコール用の関数を作成します。 推論結果を格納するlistToAppend
というリストを作成します。
入力したフレーズを一つずつ分類推論させていきます。
この時、payload
内に必要な情報を記述してポストします。
推論結果は、listToAppend
に格納していきます。
すべて完了したら、"✅ Done!"を表示します。
得られた推論結果をpandas
のDataFrame
に変換します。
そのあと、推論結果表示用のUIのタイトルを作成して、場合によってはワイド表示のほうが見やすいので、ワイド表示に変更できるようにチェックボックスを作成します。
DataFrame内の値を調整します。
まず、分類のスコア(確信度)を小数点2桁までに制限します。 そのあと、sequenceというカラムをkeyphraseに変更します。
streamlit-aggrid
を使って推論結果のDataFrame
を表示します。streamlit-aggrid
を利用することで以下の機能を追加できます。
最初の行で表示したいデータフレームから必要な設定情報を作成しています。
二行目以降は、設定を追加しています。(詳細はドキュメントを参照ください。)
その後、UI作成のためresponse
を作成しています。
そして、推論結果をCSVファイルで出力できるように出力用ボタンを作成しています。
最後に、APIアクセスでエラーが出た場合の処理を記述しています。