创建 Flutter 应用以对文本进行分类

1. 准备工作

在此 Codelab 中,您将学习如何通过 REST 和 gRPC 使用 TensorFlow Serving 从 Flutter 应用运行文本分类推断。

前提条件

学习内容

  • 如何通过 TensorFlow Serving(REST 和 gRPC)构建简单的 Flutter 应用并对文本进行分类。
  • 如何在界面中显示结果。

您需要满足的条件

2. 进行设置

如需下载此 Codelab 的代码,请执行以下操作:

  1. 找到此 Codelab 的 GitHub 代码库
  2. 点击 Code(代码)> Download zip(下载 Zip 文件),下载此 Codelab 的所有代码。

2cd45599f51fb8a2.png

  1. 解压缩已下载的 Zip 文件,这会解压缩 codelabs-master 根文件,其中包含您需要的所有资源。

在此 Codelab 中,您只需要代码库的 tfserving-flutter/codelab2 子目录(其中包含两个文件夹)中的文件:

  • starter 文件夹包含您在此 Codelab 中执行构建的起始代码。
  • finished 文件夹包含已完成示例应用的完成后的代码。

3. 下载项目的依赖项

  1. 在 VS Code 中,点击 File(文件)> Open folder(打开文件夹),然后从您之前下载的源代码中选择 starter 文件夹。
  2. 如果您看到一个对话框,提示您下载起始应用所需的软件包,请点击 Get packages(获取软件包)。
  3. 如果您没有看到此对话框,请打开终端,然后在 starter 文件夹中运行 flutter pub get 命令。

7ada07c300f166a6.png

4. 运行起始应用

  1. 在 VS Code 中,确保 Android 模拟器或 iOS 模拟器已正确设置并显示在状态栏中。

例如,当您将 Pixel 5 与 Android 模拟器搭配使用时,会看到以下内容:

9767649231898791.png

当您将 iPhone 13 与 iOS 模拟器搭配使用时,会看到以下内容:

95529e3a682268b2.png

  1. 点击 a19a0c68bc4046e6.png Start debugging(开始调试)。

运行和探索应用

应用应在 Android 模拟器或 iOS 模拟器上启动。界面非常简单。系统提供了文本字段,可让用户输入文本。用户可以选择是使用 REST 还是 gRPC 将数据发送到后端。后端使用 TensorFlow 模型对预处理的输入执行文本分类,并将分类结果返回给客户端应用,客户端应用进而更新界面。

b298f605d64dc132.png d3ef3ccd3c338108.png

点击 Classify(分类)不起作用,因为它还不能与后端进行通信。

5. 使用 TensorFlow Serving 部署文本分类模型

文本分类是一项非常常见的机器学习任务,用于将文本分为预定义的类别。在此 Codelab 中,您将使用 TensorFlow Serving 部署《使用 TensorFlow Lite Model Maker 训练垃圾评论检测模型》Codelab 中的预训练模型,并从 Flutter 前端调用后端以将输入文本分类为垃圾文本或非垃圾文本。

启动 TensorFlow Serving

  • 在您的终端中,使用 Docker 启动 TensorFlow Serving,但需要注意将 PATH/TO/SAVEDMODEL 占位符替换为您计算机上的 mm_spam_savedmodel 文件夹的绝对路径。
docker pull tensorflow/serving

docker run -it --rm -p 8500:8500 -p 8501:8501 -v "PATH/TO/SAVEDMODEL:/models/spam-detection" -e MODEL_NAME=spam-detection tensorflow/serving

Docker 会先自动下载 TensorFlow Serving 映像,此过程需要一分钟时间。之后,TensorFlow Serving 便会启动。日志应如下面的代码段所示:

2022-02-25 06:01:12.513231: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:206] Restoring SavedModel bundle.
2022-02-25 06:01:12.585012: I external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:114] CPU Frequency: 3000000000 Hz
2022-02-25 06:01:13.395083: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:190] Running initialization op on SavedModel bundle at path: /models/ssd_mobilenet_v2_2/123
2022-02-25 06:01:13.837562: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:277] SavedModel load for tags { serve }; Status: success: OK. Took 1928700 microseconds.
2022-02-25 06:01:13.877848: I tensorflow_serving/servables/tensorflow/saved_model_warmup_util.cc:59] No warmup data file found at /models/ssd_mobilenet_v2_2/123/assets.extra/tf_serving_warmup_requests
2022-02-25 06:01:13.929844: I tensorflow_serving/core/loader_harness.cc:87] Successfully loaded servable version {name: spam-detection version: 123}
2022-02-25 06:01:13.985848: I tensorflow_serving/model_servers/server_core.cc:486] Finished adding/updating models
2022-02-25 06:01:13.985987: I tensorflow_serving/model_servers/server.cc:367] Profiler service is enabled
2022-02-25 06:01:13.988994: I tensorflow_serving/model_servers/server.cc:393] Running gRPC ModelServer at 0.0.0.0:8500 ...
[warn] getaddrinfo: address family for nodename not supported
2022-02-25 06:01:14.033872: I tensorflow_serving/model_servers/server.cc:414] Exporting HTTP/REST API at:localhost:8501 ...
[evhttp_server.cc : 245] NET_LOG: Entering the event loop ...

6. 对输入句子进行词元化处理

现在,后端已准备就绪,您也就基本上作好了将客户端请求发送到 TensorFlow Serving 的准备,但还需要先将输入句子进行词元化处理。如果您检查模型的输入张量,会发现它需要的是一个包含 20 个整数的列表,而不是原始字符串。词元化是指根据词典将您输入到应用的各个字词映射为一个整数列表的过程,然后将其发送到后端进行分类。例如,如果您输入 buy book online to learn more,则词元化过程会将其映射为 [32, 79, 183, 10, 224, 631, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]。具体数字因词典而异。

  1. lib/main.dart 文件中,将此代码添加到 predict() 方法以构建 _vocabMap 词典。
// Build _vocabMap if empty.
if (_vocabMap.isEmpty) {
  final vocabFileString = await rootBundle.loadString(vocabFile);
  final lines = vocabFileString.split('\n');
  for (final l in lines) {
    if (l != "") {
      var wordAndIndex = l.split(' ');
      (_vocabMap)[wordAndIndex[0]] = int.parse(wordAndIndex[1]);
    }
  }
}
  1. 紧接着上述代码段之后,添加以下代码以实现字元化处理:
// Tokenize the input sentence.
final inputWords = _inputSentenceController.text
    .toLowerCase()
    .replaceAll(RegExp('[^a-z ]'), '')
    .split(' ');
// Initialize with padding token.
_tokenIndices = List.filled(maxSentenceLength, 0);
var i = 0;
for (final w in inputWords) {
  if ((_vocabMap).containsKey(w)) {
    _tokenIndices[i] = (_vocabMap)[w]!;
    i++;
  }

  // Truncate the string if longer than maxSentenceLength.
  if (i >= maxSentenceLength - 1) {
    break;
  }
}

此代码会将句子中的字符串转换为小写,移除非字母字符,并根据词汇表将字词映射成 20 个整数索引。

7. 通过 REST 将 Flutter 应用与 TensorFlow Serving 关联起来

您可以通过以下两种方式向 TensorFlow Serving 发送请求:

  • REST
  • gRPC

通过 REST 发送请求和接收响应

要通过 REST 发送请求和接收响应需要执行三个简单的步骤:

  1. 创建 REST 请求。
  2. 将 REST 请求发送到 TensorFlow Serving。
  3. 从 REST 响应中提取预测结果,并呈现界面。

您需要在 main.dart 文件中完成这些步骤。

创建 REST 请求并将其发送到 TensorFlow Serving

  1. 目前,predict() 函数不会将 REST 请求发送到 TensorFlow Serving。您需要实现 REST 分支才能创建 REST 请求:
if (_connectionMode == ConnectionModeType.rest) {
  // TODO: Create and send the REST request.

}
  1. 将以下代码添加到 REST 分支:
//Create the REST request.
final response = await http.post(
  Uri.parse('http://' +
      _server +
      ':' +
      restPort.toString() +
      '/v1/models/' +
      modelName +
      ':predict'),
  body: jsonEncode(<String, List<List<int>>>{
    'instances': [_tokenIndices],
  }),
);

处理来自 TensorFlow Serving 的 REST 响应

  • 将以下代码添加到上述代码段之后,以处理 REST 响应:
// Process the REST response.
if (response.statusCode == 200) {
  Map<String, dynamic> result = jsonDecode(response.body);
  if (result['predictions']![0][1] >= classificationThreshold) {
    return 'This sentence is spam. Spam score is ' +
        result['predictions']![0][1].toString();
  }
  return 'This sentence is not spam. Spam score is ' +
      result['predictions']![0][1].toString();
} else {
  throw Exception('Error response');
}

后处理代码会从响应中提取输入句子是垃圾信息的概率,并在界面中显示分类结果。

运行应用

  1. 点击 a19a0c68bc4046e6.png Start debugging(开始调试),然后等待应用加载。
  2. 输入一些文本,然后选择 REST > Classify(分类)

8e21d795af36d07a.png e79a0367a03c2169.png

8. 通过 gRPC 将 Flutter 应用与 TensorFlow Serving 关联起来

除了 REST 之外,TensorFlow Serving 还支持 gRPC

b6f4449c2c850b0e.png

gRPC 是一种开放源代码的现代高性能远程过程调用 (RPC) 框架,可以在任何环境中运行。借助可插拔支持,它可以在数据中心内和跨数据中心高效地连接服务,以实现负载均衡、跟踪、健康检查和身份验证。我们发现,在实践中,gRPC 的性能比 REST 更高。

使用 gRPC 发送请求和接收响应

要使用 gRPC 发送请求和接收响应需要执行四个简单的步骤:

  1. 可选:生成 gRPC 客户端桩代码。
  2. 创建 gRPC 请求。
  3. 向 TensorFlow Serving 发送 gRPC 请求。
  4. 从 gRPC 响应中提取预测结果,并呈现界面。

您需要在 main.dart 文件中完成这些步骤。

可选:生成 gRPC 客户端桩代码

如需将 gRPC 与 TensorFlow Serving 搭配使用,您需要遵循 gRPC 工作流。如需了解详情,请参阅 gRPC 文档

a9d0e5cb543467b4.png

TensorFlow Serving 和 TensorFlow 会为您定义 .proto 文件。从 TensorFlow 和 TensorFlow Serving 2.8 开始,以下 .proto 文件是必须提供的文件:

tensorflow/core/example/example.proto
tensorflow/core/example/feature.proto
tensorflow/core/protobuf/struct.proto
tensorflow/core/protobuf/saved_object_graph.proto
tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/trackable_object_graph.proto
tensorflow/core/protobuf/meta_graph.proto
tensorflow/core/framework/node_def.proto
tensorflow/core/framework/attr_value.proto
tensorflow/core/framework/function.proto
tensorflow/core/framework/types.proto
tensorflow/core/framework/tensor_shape.proto
tensorflow/core/framework/full_type.proto
tensorflow/core/framework/versions.proto
tensorflow/core/framework/op_def.proto
tensorflow/core/framework/graph.proto
tensorflow/core/framework/tensor.proto
tensorflow/core/framework/resource_handle.proto
tensorflow/core/framework/variable.proto

tensorflow_serving/apis/inference.proto
tensorflow_serving/apis/classification.proto
tensorflow_serving/apis/predict.proto
tensorflow_serving/apis/regression.proto
tensorflow_serving/apis/get_model_metadata.proto
tensorflow_serving/apis/input.proto
tensorflow_serving/apis/prediction_service.proto
tensorflow_serving/apis/model.proto

google/protobuf/any.proto
google/protobuf/wrappers.proto
  • 在您的终端中,找到 starter/lib/proto/ 文件夹并生成桩:
bash generate_grpc_stub_dart.sh

创建 gRPC 请求

与 REST 请求类似,您可以在 gRPC 分支中创建 gRPC 请求。

if (_connectionMode == ConnectionModeType.rest) {

} else {
  // TODO: Create and send the gRPC request.

}
  • 添加以下代码以创建 gRPC 请求:
//Create the gRPC request.
final channel = ClientChannel(_server,
    port: grpcPort,
    options:
        const ChannelOptions(credentials: ChannelCredentials.insecure()));
_stub = PredictionServiceClient(channel,
    options: CallOptions(timeout: const Duration(seconds: 10)));

ModelSpec modelSpec = ModelSpec(
  name: 'spam-detection',
  signatureName: 'serving_default',
);

TensorShapeProto_Dim batchDim = TensorShapeProto_Dim(size: Int64(1));
TensorShapeProto_Dim inputDim =
    TensorShapeProto_Dim(size: Int64(maxSentenceLength));
TensorShapeProto inputTensorShape =
    TensorShapeProto(dim: [batchDim, inputDim]);
TensorProto inputTensor = TensorProto(
    dtype: DataType.DT_INT32,
    tensorShape: inputTensorShape,
    intVal: _tokenIndices);

// If you train your own model, update the input and output tensor names.
const inputTensorName = 'input_3';
const outputTensorName = 'dense_5';
PredictRequest request = PredictRequest(
    modelSpec: modelSpec, inputs: {inputTensorName: inputTensor});

注意:输入和输出张量名称可能因模型而异,即使模型架构相同也是如此。如果您训练自己的模型,请务必更新这些名称。

向 TensorFlow Serving 发送 gRPC 请求

  • 在上述代码段之后添加以下代码,以将 gRPC 请求发送到 TensorFlow Serving:
// Send the gRPC request.
PredictResponse response = await _stub.predict(request);

处理来自 TensorFlow Serving 的 gRPC 响应

  • 在上述代码段之后添加以下代码,以实现用于处理响应的回调函数:
// Process the response.
if (response.outputs.containsKey(outputTensorName)) {
  if (response.outputs[outputTensorName]!.floatVal[1] >
      classificationThreshold) {
    return 'This sentence is spam. Spam score is ' +
        response.outputs[outputTensorName]!.floatVal[1].toString();
  } else {
    return 'This sentence is not spam. Spam score is ' +
        response.outputs[outputTensorName]!.floatVal[1].toString();
  }
} else {
  throw Exception('Error response');
}

现在,后处理代码会从响应中提取分类结果,并将其显示在界面中。

运行应用

  1. 点击 a19a0c68bc4046e6.png Start debugging(开始调试),然后等待应用加载。
  2. 输入一些文本,然后选择 gRPC > Classify(分类)

e44e6e9a5bde2188.png 92644d723f61968c.png

9. 恭喜

您已使用 TensorFlow Serving 为应用添加了文本分类功能!

在下一个 Codelab 中,您将增强该模型,以便检测当前应用无法检测到的特定垃圾信息。

了解详情