1. 准备工作
在此 Codelab 中,您将学习如何通过 REST 和 gRPC 使用 TensorFlow Serving 从 Flutter 应用运行文本分类推断。
前提条件
- 了解有关使用 Dart 开发 Flutter 应用的基本知识
- 了解有关使用 TensorFlow 进行机器学习(例如训练与部署)的基本知识
- 了解有关终端和 Docker 的基本知识
- 学习《使用 TensorFlow Lite Model Maker 训练垃圾评论检测模型》Codelab
学习内容
- 如何通过 TensorFlow Serving(REST 和 gRPC)构建简单的 Flutter 应用并对文本进行分类。
- 如何在界面中显示结果。
您需要满足的条件
- Flutter SDK
- 适用于 Flutter 的 Android 或 iOS 设置
- 适用于 Flutter 和 Dart 的 Visual Studio Code (VS Code) 设置
- Docker
- Bash
- 协议缓冲区编译器和适用于协议编译器的 gRPC Dart 插件(仅当您要自行重新生成 gRPC 桩时才需要)
2. 进行设置
如需下载此 Codelab 的代码,请执行以下操作:
- 找到此 Codelab 的 GitHub 代码库。
- 点击 Code(代码)> Download zip(下载 Zip 文件),下载此 Codelab 的所有代码。
- 解压缩已下载的 Zip 文件,这会解压缩
codelabs-master
根文件,其中包含您需要的所有资源。
在此 Codelab 中,您只需要代码库的 tfserving-flutter/codelab2
子目录(其中包含两个文件夹)中的文件:
starter
文件夹包含您在此 Codelab 中执行构建的起始代码。finished
文件夹包含已完成示例应用的完成后的代码。
3. 下载项目的依赖项
- 在 VS Code 中,点击 File(文件)> Open folder(打开文件夹),然后从您之前下载的源代码中选择
starter
文件夹。 - 如果您看到一个对话框,提示您下载起始应用所需的软件包,请点击 Get packages(获取软件包)。
- 如果您没有看到此对话框,请打开终端,然后在
starter
文件夹中运行flutter pub get
命令。
4. 运行起始应用
- 在 VS Code 中,确保 Android 模拟器或 iOS 模拟器已正确设置并显示在状态栏中。
例如,当您将 Pixel 5 与 Android 模拟器搭配使用时,会看到以下内容:
当您将 iPhone 13 与 iOS 模拟器搭配使用时,会看到以下内容:
- 点击 Start debugging(开始调试)。
运行和探索应用
应用应在 Android 模拟器或 iOS 模拟器上启动。界面非常简单。系统提供了文本字段,可让用户输入文本。用户可以选择是使用 REST 还是 gRPC 将数据发送到后端。后端使用 TensorFlow 模型对预处理的输入执行文本分类,并将分类结果返回给客户端应用,客户端应用进而更新界面。
点击 Classify(分类)不起作用,因为它还不能与后端进行通信。
5. 使用 TensorFlow Serving 部署文本分类模型
文本分类是一项非常常见的机器学习任务,用于将文本分为预定义的类别。在此 Codelab 中,您将使用 TensorFlow Serving 部署《使用 TensorFlow Lite Model Maker 训练垃圾评论检测模型》Codelab 中的预训练模型,并从 Flutter 前端调用后端以将输入文本分类为垃圾文本或非垃圾文本。
启动 TensorFlow Serving
- 在您的终端中,使用 Docker 启动 TensorFlow Serving,但需要注意将
PATH/TO/SAVEDMODEL
占位符替换为您计算机上的mm_spam_savedmodel
文件夹的绝对路径。
docker pull tensorflow/serving docker run -it --rm -p 8500:8500 -p 8501:8501 -v "PATH/TO/SAVEDMODEL:/models/spam-detection" -e MODEL_NAME=spam-detection tensorflow/serving
Docker 会先自动下载 TensorFlow Serving 映像,此过程需要一分钟时间。之后,TensorFlow Serving 便会启动。日志应如下面的代码段所示:
2022-02-25 06:01:12.513231: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:206] Restoring SavedModel bundle. 2022-02-25 06:01:12.585012: I external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:114] CPU Frequency: 3000000000 Hz 2022-02-25 06:01:13.395083: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:190] Running initialization op on SavedModel bundle at path: /models/ssd_mobilenet_v2_2/123 2022-02-25 06:01:13.837562: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:277] SavedModel load for tags { serve }; Status: success: OK. Took 1928700 microseconds. 2022-02-25 06:01:13.877848: I tensorflow_serving/servables/tensorflow/saved_model_warmup_util.cc:59] No warmup data file found at /models/ssd_mobilenet_v2_2/123/assets.extra/tf_serving_warmup_requests 2022-02-25 06:01:13.929844: I tensorflow_serving/core/loader_harness.cc:87] Successfully loaded servable version {name: spam-detection version: 123} 2022-02-25 06:01:13.985848: I tensorflow_serving/model_servers/server_core.cc:486] Finished adding/updating models 2022-02-25 06:01:13.985987: I tensorflow_serving/model_servers/server.cc:367] Profiler service is enabled 2022-02-25 06:01:13.988994: I tensorflow_serving/model_servers/server.cc:393] Running gRPC ModelServer at 0.0.0.0:8500 ... [warn] getaddrinfo: address family for nodename not supported 2022-02-25 06:01:14.033872: I tensorflow_serving/model_servers/server.cc:414] Exporting HTTP/REST API at:localhost:8501 ... [evhttp_server.cc : 245] NET_LOG: Entering the event loop ...
6. 对输入句子进行词元化处理
现在,后端已准备就绪,您也就基本上作好了将客户端请求发送到 TensorFlow Serving 的准备,但还需要先将输入句子进行词元化处理。如果您检查模型的输入张量,会发现它需要的是一个包含 20 个整数的列表,而不是原始字符串。词元化是指根据词典将您输入到应用的各个字词映射为一个整数列表的过程,然后将其发送到后端进行分类。例如,如果您输入 buy book online to learn more
,则词元化过程会将其映射为 [32, 79, 183, 10, 224, 631, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
。具体数字因词典而异。
- 在
lib/main.dart
文件中,将此代码添加到predict()
方法以构建_vocabMap
词典。
// Build _vocabMap if empty.
if (_vocabMap.isEmpty) {
final vocabFileString = await rootBundle.loadString(vocabFile);
final lines = vocabFileString.split('\n');
for (final l in lines) {
if (l != "") {
var wordAndIndex = l.split(' ');
(_vocabMap)[wordAndIndex[0]] = int.parse(wordAndIndex[1]);
}
}
}
- 紧接着上述代码段之后,添加以下代码以实现字元化处理:
// Tokenize the input sentence.
final inputWords = _inputSentenceController.text
.toLowerCase()
.replaceAll(RegExp('[^a-z ]'), '')
.split(' ');
// Initialize with padding token.
_tokenIndices = List.filled(maxSentenceLength, 0);
var i = 0;
for (final w in inputWords) {
if ((_vocabMap).containsKey(w)) {
_tokenIndices[i] = (_vocabMap)[w]!;
i++;
}
// Truncate the string if longer than maxSentenceLength.
if (i >= maxSentenceLength - 1) {
break;
}
}
此代码会将句子中的字符串转换为小写,移除非字母字符,并根据词汇表将字词映射成 20 个整数索引。
7. 通过 REST 将 Flutter 应用与 TensorFlow Serving 关联起来
您可以通过以下两种方式向 TensorFlow Serving 发送请求:
- REST
- gRPC
通过 REST 发送请求和接收响应
要通过 REST 发送请求和接收响应需要执行三个简单的步骤:
- 创建 REST 请求。
- 将 REST 请求发送到 TensorFlow Serving。
- 从 REST 响应中提取预测结果,并呈现界面。
您需要在 main.dart
文件中完成这些步骤。
创建 REST 请求并将其发送到 TensorFlow Serving
- 目前,
predict()
函数不会将 REST 请求发送到 TensorFlow Serving。您需要实现 REST 分支才能创建 REST 请求:
if (_connectionMode == ConnectionModeType.rest) {
// TODO: Create and send the REST request.
}
- 将以下代码添加到 REST 分支:
//Create the REST request.
final response = await http.post(
Uri.parse('http://' +
_server +
':' +
restPort.toString() +
'/v1/models/' +
modelName +
':predict'),
body: jsonEncode(<String, List<List<int>>>{
'instances': [_tokenIndices],
}),
);
处理来自 TensorFlow Serving 的 REST 响应
- 将以下代码添加到上述代码段之后,以处理 REST 响应:
// Process the REST response.
if (response.statusCode == 200) {
Map<String, dynamic> result = jsonDecode(response.body);
if (result['predictions']![0][1] >= classificationThreshold) {
return 'This sentence is spam. Spam score is ' +
result['predictions']![0][1].toString();
}
return 'This sentence is not spam. Spam score is ' +
result['predictions']![0][1].toString();
} else {
throw Exception('Error response');
}
后处理代码会从响应中提取输入句子是垃圾信息的概率,并在界面中显示分类结果。
运行应用
- 点击 Start debugging(开始调试),然后等待应用加载。
- 输入一些文本,然后选择 REST > Classify(分类)。
8. 通过 gRPC 将 Flutter 应用与 TensorFlow Serving 关联起来
除了 REST 之外,TensorFlow Serving 还支持 gRPC。
gRPC 是一种开放源代码的现代高性能远程过程调用 (RPC) 框架,可以在任何环境中运行。借助可插拔支持,它可以在数据中心内和跨数据中心高效地连接服务,以实现负载均衡、跟踪、健康检查和身份验证。我们发现,在实践中,gRPC 的性能比 REST 更高。
使用 gRPC 发送请求和接收响应
要使用 gRPC 发送请求和接收响应需要执行四个简单的步骤:
- 可选:生成 gRPC 客户端桩代码。
- 创建 gRPC 请求。
- 向 TensorFlow Serving 发送 gRPC 请求。
- 从 gRPC 响应中提取预测结果,并呈现界面。
您需要在 main.dart
文件中完成这些步骤。
可选:生成 gRPC 客户端桩代码
如需将 gRPC 与 TensorFlow Serving 搭配使用,您需要遵循 gRPC 工作流。如需了解详情,请参阅 gRPC 文档。
TensorFlow Serving 和 TensorFlow 会为您定义 .proto
文件。从 TensorFlow 和 TensorFlow Serving 2.8 开始,以下 .proto
文件是必须提供的文件:
tensorflow/core/example/example.proto
tensorflow/core/example/feature.proto
tensorflow/core/protobuf/struct.proto
tensorflow/core/protobuf/saved_object_graph.proto
tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/trackable_object_graph.proto
tensorflow/core/protobuf/meta_graph.proto
tensorflow/core/framework/node_def.proto
tensorflow/core/framework/attr_value.proto
tensorflow/core/framework/function.proto
tensorflow/core/framework/types.proto
tensorflow/core/framework/tensor_shape.proto
tensorflow/core/framework/full_type.proto
tensorflow/core/framework/versions.proto
tensorflow/core/framework/op_def.proto
tensorflow/core/framework/graph.proto
tensorflow/core/framework/tensor.proto
tensorflow/core/framework/resource_handle.proto
tensorflow/core/framework/variable.proto
tensorflow_serving/apis/inference.proto
tensorflow_serving/apis/classification.proto
tensorflow_serving/apis/predict.proto
tensorflow_serving/apis/regression.proto
tensorflow_serving/apis/get_model_metadata.proto
tensorflow_serving/apis/input.proto
tensorflow_serving/apis/prediction_service.proto
tensorflow_serving/apis/model.proto
google/protobuf/any.proto
google/protobuf/wrappers.proto
- 在您的终端中,找到
starter/lib/proto/
文件夹并生成桩:
bash generate_grpc_stub_dart.sh
创建 gRPC 请求
与 REST 请求类似,您可以在 gRPC 分支中创建 gRPC 请求。
if (_connectionMode == ConnectionModeType.rest) {
} else {
// TODO: Create and send the gRPC request.
}
- 添加以下代码以创建 gRPC 请求:
//Create the gRPC request.
final channel = ClientChannel(_server,
port: grpcPort,
options:
const ChannelOptions(credentials: ChannelCredentials.insecure()));
_stub = PredictionServiceClient(channel,
options: CallOptions(timeout: const Duration(seconds: 10)));
ModelSpec modelSpec = ModelSpec(
name: 'spam-detection',
signatureName: 'serving_default',
);
TensorShapeProto_Dim batchDim = TensorShapeProto_Dim(size: Int64(1));
TensorShapeProto_Dim inputDim =
TensorShapeProto_Dim(size: Int64(maxSentenceLength));
TensorShapeProto inputTensorShape =
TensorShapeProto(dim: [batchDim, inputDim]);
TensorProto inputTensor = TensorProto(
dtype: DataType.DT_INT32,
tensorShape: inputTensorShape,
intVal: _tokenIndices);
// If you train your own model, update the input and output tensor names.
const inputTensorName = 'input_3';
const outputTensorName = 'dense_5';
PredictRequest request = PredictRequest(
modelSpec: modelSpec, inputs: {inputTensorName: inputTensor});
注意:输入和输出张量名称可能因模型而异,即使模型架构相同也是如此。如果您训练自己的模型,请务必更新这些名称。
向 TensorFlow Serving 发送 gRPC 请求
- 在上述代码段之后添加以下代码,以将 gRPC 请求发送到 TensorFlow Serving:
// Send the gRPC request.
PredictResponse response = await _stub.predict(request);
处理来自 TensorFlow Serving 的 gRPC 响应
- 在上述代码段之后添加以下代码,以实现用于处理响应的回调函数:
// Process the response.
if (response.outputs.containsKey(outputTensorName)) {
if (response.outputs[outputTensorName]!.floatVal[1] >
classificationThreshold) {
return 'This sentence is spam. Spam score is ' +
response.outputs[outputTensorName]!.floatVal[1].toString();
} else {
return 'This sentence is not spam. Spam score is ' +
response.outputs[outputTensorName]!.floatVal[1].toString();
}
} else {
throw Exception('Error response');
}
现在,后处理代码会从响应中提取分类结果,并将其显示在界面中。
运行应用
- 点击 Start debugging(开始调试),然后等待应用加载。
- 输入一些文本,然后选择 gRPC > Classify(分类)。
9. 恭喜
您已使用 TensorFlow Serving 为应用添加了文本分类功能!
在下一个 Codelab 中,您将增强该模型,以便检测当前应用无法检测到的特定垃圾信息。