Flutter アプリを作成してテキストを分類する

1. 始める前に

この Codelab では、Flutter アプリから REST または gRPC を介して TensorFlow Serving を呼び出し、テキスト分類推論を実行する方法を学びます。

前提条件

学習内容

  • 簡単な Flutter アプリを作成し、TensorFlow Serving(REST と gRPC)を使用してテキストを分類する方法。
  • 結果を UI に表示する方法。

必要なもの

2. 設定する

この Codelab のコードをダウンロードするには:

  1. この Codelab の GitHub リポジトリに移動します。
  2. [Code] > [Download zip] をクリックして、この Codelab のすべてのコードをダウンロードします。

2cd45599f51fb8a2.png

  1. ダウンロードした zip ファイルを解凍して、codelabs-master ルートフォルダを展開します。このフォルダに必要なリソースがすべて含まれています。

この Codelab では、リポジトリの tfserving-flutter/codelab2 サブディレクトリ内のファイルのみが必要です。このサブディレクトリには次の 2 つのフォルダが含まれています。

  • starter フォルダには、この Codelab で構築するスターター コードが含まれています。
  • finished フォルダには、完成したサンプルアプリの完全なコードが含まれています。

3. プロジェクトの依存関係をダウンロードする

  1. VS Code で、[File] > [Open folder] をクリックし、先ほどダウンロードしたソースコードの starter フォルダを選択します。
  2. スターター アプリに必要なパッケージのダウンロードを求めるダイアログが表示されたら、[Get packages] をクリックします。
  3. このダイアログが表示されない場合は、ターミナルを開いて starter フォルダで flutter pub get コマンドを実行します。

7ada07c300f166a6.png

4.スターター アプリを実行する

  1. VS Code で、Android Emulator または iOS Simulator が正しくセットアップされ、ステータスバーに表示されていることを確認します。

たとえば、Android Emulator で Pixel 5 を使用する場合は次のようになります。

9767649231898791.png

iOS シミュレータで iPhone 13 を使用する場合は次のようになります。

95529e3a682268b2.png

  1. [a19a0c68bc4046e6.png Start debugging] をクリックします。

アプリを実行して操作する

Android Emulator または iOS シミュレータでアプリを起動します。UI は非常にシンプルです。ユーザーがテキストを入力できるテキスト フィールドがあります。ユーザーは、データをバックエンドに送信する方法(REST または gRPC)を選択できます。バックエンドは、TensorFlow モデルを使用して前処理済みの入力に対してテキスト分類を実行します。分類結果がクライアント アプリに返され、UI が更新されます。

b298f605d64dc132.png d3ef3ccd3c338108.png

[Classify] をクリックしても、まだバックエンドと通信できないため、何も起こりません。

5. TensorFlow Serving を使用してテキスト分類モデルをデプロイする

テキスト分類は、事前定義されたカテゴリにテキストを分類するごく一般的な機械学習タスクです。この Codelab では、TensorFlow Serving を使用して、TensorFlow Lite Model Maker Codelab でコメントスパム検出モデルをトレーニングするでトレーニング済みのモデルをデプロイし、Flutter フロントエンドからバックエンドを呼び出して、入力テキストを「スパム」または「スパムでない」に分類します。

TensorFlow Serving を開始する

  • ターミナルでは、Docker を使用して TensorFlow Serving を起動しますが、PATH/TO/SAVEDMODEL プレースホルダは、パソコンの mm_spam_savedmodel フォルダの絶対パスに置き換えます。
docker pull tensorflow/serving

docker run -it --rm -p 8500:8500 -p 8501:8501 -v "PATH/TO/SAVEDMODEL:/models/spam-detection" -e MODEL_NAME=spam-detection tensorflow/serving

Docker は、まず TensorFlow Serving のイメージを自動的にダウンロードします。これには 1 分ほどかかります。その後、TensorFlow Serving が起動します。次のようなログが出力されます。

2022-02-25 06:01:12.513231: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:206] Restoring SavedModel bundle.
2022-02-25 06:01:12.585012: I external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:114] CPU Frequency: 3000000000 Hz
2022-02-25 06:01:13.395083: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:190] Running initialization op on SavedModel bundle at path: /models/ssd_mobilenet_v2_2/123
2022-02-25 06:01:13.837562: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:277] SavedModel load for tags { serve }; Status: success: OK. Took 1928700 microseconds.
2022-02-25 06:01:13.877848: I tensorflow_serving/servables/tensorflow/saved_model_warmup_util.cc:59] No warmup data file found at /models/ssd_mobilenet_v2_2/123/assets.extra/tf_serving_warmup_requests
2022-02-25 06:01:13.929844: I tensorflow_serving/core/loader_harness.cc:87] Successfully loaded servable version {name: spam-detection version: 123}
2022-02-25 06:01:13.985848: I tensorflow_serving/model_servers/server_core.cc:486] Finished adding/updating models
2022-02-25 06:01:13.985987: I tensorflow_serving/model_servers/server.cc:367] Profiler service is enabled
2022-02-25 06:01:13.988994: I tensorflow_serving/model_servers/server.cc:393] Running gRPC ModelServer at 0.0.0.0:8500 ...
[warn] getaddrinfo: address family for nodename not supported
2022-02-25 06:01:14.033872: I tensorflow_serving/model_servers/server.cc:414] Exporting HTTP/REST API at:localhost:8501 ...
[evhttp_server.cc : 245] NET_LOG: Entering the event loop ...

6. 入力文をトークン化する

これでバックエンドの準備が整いました。TensorFlow Serving にクライアント リクエストを送信する準備がほぼ整いましたが、その前に入力文をトークン化する必要があります。モデルの入力テンソルを調べてみると、そのままの文字列ではなく、20 個の整数値から構成されるリストが必要であることがわかります。分類のためにバックエンドに送信する前にトークン化を行い、アプリに入力された個々の単語を語彙辞書に基づいて整数のリストにマッピングします。たとえば、「buy book online to learn more」と入力すると、この入力はトークン化プロセスにより [32, 79, 183, 10, 224, 631, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] にマッピングされます。実際の数値は語彙辞書によって異なります。

  1. lib/main.dart ファイルで、このコードを predict() メソッドに追加し、_vocabMap 語彙辞書を作成します。
// Build _vocabMap if empty.
if (_vocabMap.isEmpty) {
  final vocabFileString = await rootBundle.loadString(vocabFile);
  final lines = vocabFileString.split('\n');
  for (final l in lines) {
    if (l != "") {
      var wordAndIndex = l.split(' ');
      (_vocabMap)[wordAndIndex[0]] = int.parse(wordAndIndex[1]);
    }
  }
}
  1. 上記のコード スニペットの直後に次のコードを追加して、トークン化を実装します。
// Tokenize the input sentence.
final inputWords = _inputSentenceController.text
    .toLowerCase()
    .replaceAll(RegExp('[^a-z ]'), '')
    .split(' ');
// Initialize with padding token.
_tokenIndices = List.filled(maxSentenceLength, 0);
var i = 0;
for (final w in inputWords) {
  if ((_vocabMap).containsKey(w)) {
    _tokenIndices[i] = (_vocabMap)[w]!;
    i++;
  }

  // Truncate the string if longer than maxSentenceLength.
  if (i >= maxSentenceLength - 1) {
    break;
  }
}

このコードは、語彙テーブルに基づいて文の文字列を小文字に変換し、アルファベット以外の文字を削除して、単語を 20 個の整数インデックスにマッピングします。

7. REST を介して Flutter アプリを TensorFlow Serving に接続する

TensorFlow Serving にリクエストを送信する方法は 2 つあります。

  • REST
  • gRPC

REST でリクエストを送信してレスポンスを受信する

REST でリクエストを送信してレスポンスを受信する場合は、次の 3 つの簡単なステップを行います。

  1. REST リクエストを作成します。
  2. TensorFlow Serving に REST リクエストを送信します。
  3. REST レスポンスから予測結果を抽出し、UI をレンダリングします。

これらの手順を main.dart ファイルに記述します。

REST リクエストを作成して TensorFlow Serving に送信する

  1. この状態では、predict() 関数は TensorFlow Serving に REST リクエストを送信しません。REST リクエストを作成するには、REST ブランチを実装する必要があります。
if (_connectionMode == ConnectionModeType.rest) {
  // TODO: Create and send the REST request.

}
  1. 次のコードを REST ブランチに追加します。
//Create the REST request.
final response = await http.post(
  Uri.parse('http://' +
      _server +
      ':' +
      restPort.toString() +
      '/v1/models/' +
      modelName +
      ':predict'),
  body: jsonEncode(<String, List<List<int>>>{
    'instances': [_tokenIndices],
  }),
);

TensorFlow Serving からの REST レスポンスを処理する

  • 上記のコード スニペットの直後に次のコードを追加して、REST レスポンスを処理します。
// Process the REST response.
if (response.statusCode == 200) {
  Map<String, dynamic> result = jsonDecode(response.body);
  if (result['predictions']![0][1] >= classificationThreshold) {
    return 'This sentence is spam. Spam score is ' +
        result['predictions']![0][1].toString();
  }
  return 'This sentence is not spam. Spam score is ' +
      result['predictions']![0][1].toString();
} else {
  throw Exception('Error response');
}

後処理コードは、入力文がスパム メッセージである確率をレスポンスから抽出し、分類結果を UI に表示します。

実行

  1. [a19a0c68bc4046e6.png Start debugging] をクリックして、アプリが読み込まれるまで待ちます。
  2. テキストを入力して、[REST] > [Classify] を選択します。

8e21d795af36d07a.png e79a0367a03c2169.png

8. gRPC を使用して Flutter アプリを TensorFlow Serving に接続する

TensorFlow Serving は、REST に加えて gRPC もサポートします。

b6f4449c2c850b0e.png

gRPC は、最新のオープンソースの高性能なリモート プロシージャ コール(RPC)フレームワークで、あらゆる環境で実行できます。ロード バランシング、トレース、ヘルスチェック、認証のプラグ可能なサポートにより、データセンター内やセンター間で効率的にサービスを接続できます。gRPC は、REST よりもパフォーマンスが高いことが確認されています。

gRPC でリクエストを送信してレスポンスを受信する

gRPC でリクエストを送信してレスポンスを受信する場合は、次の 4 つの簡単なステップを行います。

  1. (省略可)gRPC クライアント スタブコードを生成します。
  2. gRPC リクエストを作成します。
  3. TensorFlow Serving に gRPC リクエストを送信します。
  4. gRPC レスポンスから予測結果を抽出し、UI をレンダリングします。

これらの手順を main.dart ファイルに記述します。

省略可: gRPC クライアント スタブコードを生成する

gRPC と TensorFlow Serving を使用する場合は、gRPC ワークフローに従う必要があります。詳細については、gRPC のドキュメントをご覧ください。

a9d0e5cb543467b4.png

TensorFlow Serving と TensorFlow は必要な .proto ファイルを定義します。TensorFlow と TensorFlow Serving 2.8 では、次の .proto ファイルが必要です。

tensorflow/core/example/example.proto
tensorflow/core/example/feature.proto
tensorflow/core/protobuf/struct.proto
tensorflow/core/protobuf/saved_object_graph.proto
tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/trackable_object_graph.proto
tensorflow/core/protobuf/meta_graph.proto
tensorflow/core/framework/node_def.proto
tensorflow/core/framework/attr_value.proto
tensorflow/core/framework/function.proto
tensorflow/core/framework/types.proto
tensorflow/core/framework/tensor_shape.proto
tensorflow/core/framework/full_type.proto
tensorflow/core/framework/versions.proto
tensorflow/core/framework/op_def.proto
tensorflow/core/framework/graph.proto
tensorflow/core/framework/tensor.proto
tensorflow/core/framework/resource_handle.proto
tensorflow/core/framework/variable.proto

tensorflow_serving/apis/inference.proto
tensorflow_serving/apis/classification.proto
tensorflow_serving/apis/predict.proto
tensorflow_serving/apis/regression.proto
tensorflow_serving/apis/get_model_metadata.proto
tensorflow_serving/apis/input.proto
tensorflow_serving/apis/prediction_service.proto
tensorflow_serving/apis/model.proto

google/protobuf/any.proto
google/protobuf/wrappers.proto
  • ターミナルで、starter/lib/proto/ フォルダに移動し、スタブを生成します。
bash generate_grpc_stub_dart.sh

gRPC リクエストを作成する

REST リクエストと同様に、gRPC ブランチに gRPC リクエストを作成します。

if (_connectionMode == ConnectionModeType.rest) {

} else {
  // TODO: Create and send the gRPC request.

}
  • 次のコードを追加して、gRPC リクエストを作成します。
//Create the gRPC request.
final channel = ClientChannel(_server,
    port: grpcPort,
    options:
        const ChannelOptions(credentials: ChannelCredentials.insecure()));
_stub = PredictionServiceClient(channel,
    options: CallOptions(timeout: const Duration(seconds: 10)));

ModelSpec modelSpec = ModelSpec(
  name: 'spam-detection',
  signatureName: 'serving_default',
);

TensorShapeProto_Dim batchDim = TensorShapeProto_Dim(size: Int64(1));
TensorShapeProto_Dim inputDim =
    TensorShapeProto_Dim(size: Int64(maxSentenceLength));
TensorShapeProto inputTensorShape =
    TensorShapeProto(dim: [batchDim, inputDim]);
TensorProto inputTensor = TensorProto(
    dtype: DataType.DT_INT32,
    tensorShape: inputTensorShape,
    intVal: _tokenIndices);

// If you train your own model, update the input and output tensor names.
const inputTensorName = 'input_3';
const outputTensorName = 'dense_5';
PredictRequest request = PredictRequest(
    modelSpec: modelSpec, inputs: {inputTensorName: inputTensor});

注: モデルのアーキテクチャが同じであっても、モデルによって入力テンソルと出力テンソルの名前が異なる場合があります。独自のモデルをトレーニングする場合は、必ず更新してください。

TensorFlow Serving に gRPC リクエストを送信する

  • 上記のコード スニペットの後に次のコードを追加して、gRPC リクエストを TensorFlow Serving に送信します。
// Send the gRPC request.
PredictResponse response = await _stub.predict(request);

TensorFlow Serving からの gRPC レスポンスを処理する

  • 前のコード スニペットの後に次のコードを追加して、レスポンスを処理するコールバック関数を実装します。
// Process the response.
if (response.outputs.containsKey(outputTensorName)) {
  if (response.outputs[outputTensorName]!.floatVal[1] >
      classificationThreshold) {
    return 'This sentence is spam. Spam score is ' +
        response.outputs[outputTensorName]!.floatVal[1].toString();
  } else {
    return 'This sentence is not spam. Spam score is ' +
        response.outputs[outputTensorName]!.floatVal[1].toString();
  }
} else {
  throw Exception('Error response');
}

後処理コードはレスポンスから分類結果を抽出し、UI に表示します。

実行

  1. [a19a0c68bc4046e6.png Start debugging] をクリックして、アプリが読み込まれるまで待ちます。
  2. テキストを入力して、[gRPC] > [Classify] を選択します。

e44e6e9a5bde2188.png 92644d723f61968c.png

9. 完了

TensorFlow Serving を使用して、アプリにテキスト分類機能を追加しました。

次の Codelab では、現在のアプリで検出できない特定のスパム メッセージを検出できるように、モデルを改善します。

詳細