AIDXナレッジ - INSIGHT LAB

#30DaysOfStreamlit Day29 Zero-shot学習モデルでのテキスト分類器

作成者: Budo Ogimoto|2023年12月19日

はじめに

この記事では、#30DaysOfStreamlitの内容の紹介を行います。
#30DaysOfStreamlitについてはコチラの記事を参照してください。

課題について

今回の課題は、Charly Wargnier氏がブログで紹介していたZero-shotモデルでのテキスト分類機能を再現する形になります。
元のブログはこちらになります。

Zero-shot学習モデルとは

ディープラーニングや機械学習に於いては、学習用のデータセットを使って学習を行い、学習させたモデルで分類や回帰問題の固有タスクを解くのが一般的な利用方法になります。
ですが、近年のLLM(大規模言語モデル)はタスク毎の学習(ファインチューニング)の前に事前学習を行う事で言語特性等を習得することが分かっており、事前学習の段階で入力文に従うような調整を受けたモデルであれば、学習用のデータセットが無くとも分類や文章生成が可能となります。
このようなモデルをZero-shot学習モデルといいます。

今回の課題では、Hugging Face APIを通してZero-shot学習モデルを利用します。

構築する目標

Hugging Face の API 推論と Distilbart を使用して、ゼロショット学習テキスト分類器を作成します。

アプリケーションを使用すると、ML学習なしで、キーフレーズをその場で高速に分類できます。
これらのラベルは、次のように任意に最大3種類設定できます。

  • 感情分析のための「ポジティブ」、「ネガティブ」、「ニュートラル」
  • 感情分析用の「Angry」、「Happy」、「Emotional」
  • 意図の分類を目的とした「ナビゲーション」、「トランザクション」、「情報」
  • 製品範囲 (バッグ、靴、ブーツなど)

また、今回は元のブログのコードを元に上記機能を持つアプリケーションを構築します。

環境準備

最初に今回利用するコンポーネントとライブラリをインストールしておきます。

続いて、アプリケーション上で画像を利用するので、元ブログで紹介されているGithubから画像を拝借します。
この画像は、アプリケーションのコードが配置してあるディレクトリに配置します。

続いてHugging Face APIを利用するために、Hugging Faceにサインアップします。 コチラからアクセスして右上の「Sign Up」からユーザー作成をします。
ユーザー作成後、こちらからAPIアクセストークンを作成してメモしてください。

アプリケーションの構築

下記のようなPythonスクリプトを作成します。

上記のコードを実行すると下記のような画面が展開します。

上部のテキストボックスで最大3種までのラベルを設定できます。
また、分類の対象となるフレーズは下部のテキストボックスで入力します。
最後に「Submit」を押下すると下記のような画面に遷移します。

分類が成功すると上記のような画面となり、結果をテーブル形式で確認できます。
また、一番下の「Download results as CSV」を押下するとテーブル形式で表示している結果をCSV形式でダウンロードできます。

コードの解説

最初に必要なライブラリをインポートします。

続いて、セッションデータを参照してwideが指定されている場合はlayoutwideに変更するようにします。

ページの設定を行います。
この時に、layoutは上記のコードで指定したものとして、ページタイトルとアイコンを設定しています。

Hugging Face APIに送るユーザーが入力したデータが存在するかのフラグをセッションデータに記述します。
初期起動なので、valid_inputs_receivedが存在しない場合、valid_inputs_receivedFalseで保存します。

続いて、ページのタイトルを作成していきます。
st.columns()を利用して画像領域(c1)とテキスト領域(c2)に分けて設定していきます。

そのあと、アプリケーションの説明文を作成します。

そして、アプリケーション利用ユーザー用の入力フォームを作成します。
この時にAPI_KEYに準備編で取得したAPIアクセストークンを記入します。(本来はセキュリティ上、環境変数等に埋め込む等の対応が必要でありますが、、、)
APIアクセス用のURLをリクエスト用ヘッダーを作成します。

streamlit-tagのコンポーネントからst_tagsを使い、タグの入力UIを作成しています。
valueに初期値を入力して、タグ数をmaxtagsで3種に限定しています。
ここでの入力がZero-shot learning分類でのラベルに相当します。

続いて、分類したいフレーズを入力するUIを作成します。
最初に、サンプル用の文章を作成します。

サンプル文をnumsに用意してnew_line.join(map(str, nums))で改行で各文を区切った文字列を作成して、それをsampleに格納してます。

さらに、st.text_area()を使ってフレーズ入力用UIを作成しています。

この時に、MAX_LINES_FULLで入力フレーズの限界数を設定しています。
初期値として先ほど作成したsampleを利用しています。

入力用インターフェースを作成後、入力したフレーズをリストに格納します。
この時に重複文削除(下から2行目)とNull削除(下から1行目)を行っています。

入力されたフレーズが設定した既定の数を超えてないか確認を行います。
この時に既定数を超えていれば、infoでユーザーに既定数越えのメッセージを出してフレーズリストを既定数に揃えます。
最後にこの入力フォームのsubmitボタンを作成します。

ここからは、submitボタンが押されたときに入力データが異常なものであるかチェックを行っています。

最初のif文で初期状態検知して入力時間を維持しています。(submitが押されてなく、インプットデータもない状態)

次のelifで分類対象のフレーズがない場合、その旨の警告を表示しvalid_inputs_receivedFalseに変更します。(submitが押されており、分類対象フレーズがない状態)

さらに次のelifで分類ラベルがない場合、その旨の警告を表示し、valid_inputs_receivedFalseに変更します。(submitが押されており、分類ラベルがない状態)

最後にelifで分類ラベルが1種類しかない場合、その旨の警告を表示しvalid_inputs_receivedFalseに変更します。(submitが押されており、分類ラベルが1種類しかない状態)

そして、submitが押されているもしくは、valid_inputs_receivedTrueであれば、推論処理に移行します。

まずは、APIを呼び出すのでネットワーク障害等の不確定要素でエラーが出る可能性があるため、try文を利用します。
その後、submitが押されたときにvalid_inputs_receivedをTrueに変更します。
そして、query()というAPIコール用の関数を作成します。 推論結果を格納するlistToAppendというリストを作成します。

入力したフレーズを一つずつ分類推論させていきます。
この時、payload内に必要な情報を記述してポストします。
推論結果は、listToAppendに格納していきます。
すべて完了したら、"✅ Done!"を表示します。

得られた推論結果をpandasDataFrameに変換します。
そのあと、推論結果表示用のUIのタイトルを作成して、場合によってはワイド表示のほうが見やすいので、ワイド表示に変更できるようにチェックボックスを作成します。
DataFrame内の値を調整します。
まず、分類のスコア(確信度)を小数点2桁までに制限します。 そのあと、sequenceというカラムをkeyphraseに変更します。

streamlit-aggridを使って推論結果のDataFrameを表示します。
streamlit-aggridを利用することで以下の機能を追加できます。

  • 列の並べ替え、フィルター、検索
  • 列をドラッグして順序を変更する
  • 列をグループ化して固定し、集計を計算します
  • 大きなデータフレームもページ分割可能

最初の行で表示したいデータフレームから必要な設定情報を作成しています。
二行目以降は、設定を追加しています。(詳細はドキュメントを参照ください。)
その後、UI作成のためresponseを作成しています。

そして、推論結果をCSVファイルで出力できるように出力用ボタンを作成しています。

最後に、APIアクセスでエラーが出た場合の処理を記述しています。