TensorFlow Lite Model Maker を使用してカスタム テキスト分類モデルを作成する

1. 始める前に

この Codelab では、ブログ スパムコメント データセットから作成されたテキスト分類モデルを更新し、独自のデータでモデルが機能するように、独自のコメントを使ってモデルを改善する方法を学びます。

前提条件

この Codelab は、Flutter アプリでテキスト分類を行うスタートガイドのパスウェイの一部です。このパスウェイの Codelab は順番に行っていきます。この Codelab を学習する前に、使用するアプリとモデルをあらかじめ構築しておく必要があります。前のアクティビティをまだ完了していない場合は、ここで停止して、次のことを行ってください。

学習内容

必要なもの

  • 前のアクティビティで確認して構築した Flutter アプリとスパムフィルタ モデル。

2. テキスト分類を強化する

  1. このコードを取得するには、このリポジトリのクローンを作成して、tfserving-flutter/codelab2/finished フォルダからアプリを読み込みます。
  2. TensorFlow Serving Docker イメージを起動した後、作成したアプリで「buy my book to learn online trading」と入力し、[gRPC] > [Classify] の順にクリックします。

8f1e1974522f274d.png

元のデータセット内でオンライン取引が頻繁に発生しておらず、モデルがスパムを認識していないため、低めのスパムスコアが生成されます。この Codelab では、新しいデータでモデルを更新し、同じ文がスパムとして識別されるようにします。

2bd68691a26aa3da.png

3. CSV ファイルを編集する

元のモデルをトレーニングするために、「スパム」または「スパムでない」とラベル付けされた約 1, 000 件のコメントを含むデータセットを CSV(lmblog_comments.csv)として作成しました(CSV ファイルを確認する場合は、任意のテキスト エディタで開きます)。

CSV ファイルでは、最初の行に commenttextspam というラベルの付いた列が記述されています。以降の行はすべて次の形式になります。

62025273971c9a7f.png

スパムの場合、右側のラベルに true 値が割り当てられ、スパムでなければ false 値が割り当てられます。たとえば、3 行目はスパムと見なされます。

オンライン取引に関するメッセージを含むウェブサイトがスパムして認識されている場合は、ウェブサイトの下部にスパムコメントの例を追加できます。例:

online trading can be highly highly effective,true
online trading can be highly effective,true
online trading now,true
online trading here,true
online trading for the win,true
  • 新しいモデルのトレーニングで使用できるように、lmblog_comments.csv などの新しい名前でファイルを保存します。

この Codelab の残りの部分では、オンライン取引を更新しながら、Cloud Storage で提供、編集、ホストされたサンプルを使用します。独自のデータセットを使用する場合は、コード内の URL を変更します。

4.新しいデータでモデルを再トレーニングする

モデルを再トレーニングするには、(SpamCommentsModelMaker.ipynb)のコードを再利用して、新しい CSV データセット(lmblog_comments_extras.csv)を指定します。更新済みのコンテンツを含む完全なノートブックが必要な場合は、SpamCommentsUpdateModelMaker.ipynb. を探してください。

Colaboratory に対するアクセス権がある場合は、Colaboratory を直接起動できます。それ以外の場合は、リポジトリからコードを取得し、選択したノートブック環境で実行します。

更新されたコードは、以下のコード スニペットのようになります。

training_data = tf.keras.utils.get_file(fname='comments-spam-extras.csv',
          origin='https://storage.googleapis.com/laurencemoroney-blog.appspot.com/
                  lmblog_comments_extras.csv',
          extract=False)

トレーニングを行うと、トレーニングによってモデルの精度が高くなっていることがわかります。

96a1547ddb6edf5b.png

/mm_update_spam_savedmodel フォルダ全体を圧縮し、生成された mm_update_spam_savedmodel.zip ファイルをダウンロードします。

# Rename the SavedModel subfolder to a version number
!mv /mm_update_spam_savedmodel/saved_model /mm_update_spam_savedmodel/123
!zip -r mm_update_spam_savedmodel.zip /mm_update_spam_savedmodel/

5. Docker を起動して Flutter アプリを更新する

  1. ダウンロードした mm_update_spam_savedmodel.zip ファイルをフォルダに解凍し、前の Codelab で作成した Docker コンテナ インスタンスを停止してから再起動します。ただし、PATH/TO/UPDATE/SAVEDMODEL プレースホルダは、ダウンロードしたファイルをホストするフォルダの絶対パスで置き換えます。
docker run -it --rm -p 8500:8500 -p 8501:8501 -v "PATH/TO/UPDATE/SAVEDMODEL:/models/spam-detection" -e MODEL_NAME=spam-detection tensorflow/serving
  1. 任意のコードエディタで lib/main.dart ファイルを開き、inputTensorName 変数と outTensorName 変数の定義部分を探します。
const inputTensorName = 'input_3';
const outputTensorName = 'dense_5';
  1. inputTensorName 変数を input_1' 値に、outputTensorName 変数を 'dense_1' 値に再割り当てします。
const inputTensorName = 'input_1';
const outputTensorName = 'dense_1';
  1. ダウンロードした vocab.txt ファイルを lib/assets/ フォルダにコピーして、既存のファイルを置き換えます。
  2. Android Emulator からテキスト分類の Flutter アプリを手動で削除します。
  3. ターミナルで 'flutter run' コマンドを実行してアプリを起動します。
  4. アプリで「buy my book to learn online trading」と入力し、[gRPC] > [Classify] をクリックします。

モデルが改善され、「buy my book to online trading」がスパムとして検出されるようになりました。

6. 完了

新しいデータでモデルを再トレーニングし、Flutter アプリと統合して、新しいスパム文を検出する機能を更新しました。

詳細