1. 事前準備
在本程式碼研究室中,您將瞭解如何使用 TensorFlow 和 gRPC,透過 TensorFlow Serving 從 Flutter 應用程式執行文字分類推論。
必要條件
- 透過 Dart 進行 Flutter 開發的基本概念
- TensorFlow 機器學習基本知識,例如訓練和部署作業
- 終端機和 Docker 的基本知識
- 使用 TensorFlow Lite Model Maker 程式碼研究室訓練垃圾留言偵測模型
課程內容
- 瞭解如何使用 TensorFlow Serving (REST 和 gRPC) 建構簡單的 Flutter 應用程式並將文字分類。
- 如何在使用者介面中顯示結果。
軟硬體需求
- Flutter SDK
- Flutter 的 Android 或 iOS 設定
- Flutter 和 Dart 適用的 Visual Studio Code (VS Code)
- Docker
- 現金
- 「通訊協定緩衝區編譯器」和「通訊協定編譯器適用的 gRPC Dart 外掛程式」 (只有當您想要自行重新產生 gRPC 存根時,才需要使用這個選項)
2. 設定 Flutter 開發環境
如要進行 Flutter 開發,您必須安裝兩個軟體:Flutter SDK 和編輯器。
您可以透過下列任一裝置執行程式碼研究室:
- iOS 模擬工具 (必須安裝 Xcode 工具)。
- Android 模擬器 (需在 Android Studio 中設定)。
- 瀏覽器 (需要安裝 Chrome 才能進行偵錯)。
- 提供 Windows、Linux 或 macOS 電腦版應用程式。您必須在預計部署的平台上進行開發。因此,如果您要開發 Windows 桌面應用程式,就必須在 Windows 上開發,才能存取適當的建構鏈。如要瞭解作業系統相關規定,請參閱 docs.flutter.dev/desktop。
3. 做好準備
若要下載此程式碼研究室的程式碼:
- 前往這項程式碼研究室的 GitHub 存放區。
- 按一下 [程式碼 >下載 zip],即可下載這個程式碼研究室的所有程式碼。
- 將下載的 ZIP 檔案解壓縮,解壓縮您需要的所有
codelabs-main
根資料夾。
在這個程式碼研究室中,您只需要存放區的 tfserving-flutter/codelab2
子目錄中的檔案,其中包含兩個資料夾:
starter
資料夾包含您為這個程式碼研究室建立的範例程式碼。finished
資料夾包含已完成範例應用程式的範例程式碼。
4. 下載專案的依附元件
- 在 VS Code 中,按一下 [File > OpenFolder],然後從先前下載的原始碼中選取
starter
資料夾。 - 如果畫面上出現對話方塊,提示您下載啟動程式應用程式所需的套件,請按一下 [取得套件]。
- 如果您沒有看到這個對話方塊,請開啟終端機,然後在
starter
資料夾中執行flutter pub get
指令。
5. 執行啟動應用程式
- 在 VS Code 中,確認 Android Emulator 或 iOS 模擬工具已正確設定並顯示在狀態列中。
舉例來說,以下為搭配 Android Emulator 使用 Pixel 5 時所顯示的內容:
透過 iOS 模擬器使用 iPhone 13 時會看到下列資訊:
- 點選 [開始偵錯]。
執行並探索應用程式
應用程式應該會在 Android Emulator 或 iOS 模擬工具上啟動。使用者介面相當簡單,這裡有文字欄位,可讓使用者輸入文字。使用者可以選擇要使用 REST 或 gRPC 將資料傳送到後端。後端使用 TensorFlow 模型對預先處理的輸入執行文字分類,並將分類結果傳回用戶端應用程式,進而更新使用者介面。
假如您按一下 [分類],系統就沒有任何作用,因為這個代理程式尚未與後端通訊。
6. 使用 TensorFlow Serving 部署文字分類模型
文字分類是一種常見的機器學習工作,可將文字分類到預先定義的類別。在這個程式碼研究室中,您將使用 TensorFlow Serving,透過使用 TensorFlow Lite Model Maker 程式碼研究室訓練垃圾留言偵測模型部署預先訓練的模型,並呼叫 Flutter 前端的後端,將輸入的文字分類為垃圾內容或非垃圾內容。
啟動 TensorFlow Serving
- 在終端機中,以 Docker 啟動 TensorFlow Serving,但將
PATH/TO/SAVEDMODEL
預留位置改成您電腦上mm_spam_savedmodel
資料夾的絕對路徑。
docker pull tensorflow/serving docker run -it --rm -p 8500:8500 -p 8501:8501 -v "PATH/TO/SAVEDMODEL:/models/spam-detection" -e MODEL_NAME=spam-detection tensorflow/serving
Docker 會先自動下載 TensorFlow Serving 映像檔,這可能需要一分鐘的時間。之後,TensorFlow Serving 會隨即啟動。記錄應如下所示:
2022-02-25 06:01:12.513231: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:206] Restoring SavedModel bundle. 2022-02-25 06:01:12.585012: I external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:114] CPU Frequency: 3000000000 Hz 2022-02-25 06:01:13.395083: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:190] Running initialization op on SavedModel bundle at path: /models/ssd_mobilenet_v2_2/123 2022-02-25 06:01:13.837562: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:277] SavedModel load for tags { serve }; Status: success: OK. Took 1928700 microseconds. 2022-02-25 06:01:13.877848: I tensorflow_serving/servables/tensorflow/saved_model_warmup_util.cc:59] No warmup data file found at /models/ssd_mobilenet_v2_2/123/assets.extra/tf_serving_warmup_requests 2022-02-25 06:01:13.929844: I tensorflow_serving/core/loader_harness.cc:87] Successfully loaded servable version {name: spam-detection version: 123} 2022-02-25 06:01:13.985848: I tensorflow_serving/model_servers/server_core.cc:486] Finished adding/updating models 2022-02-25 06:01:13.985987: I tensorflow_serving/model_servers/server.cc:367] Profiler service is enabled 2022-02-25 06:01:13.988994: I tensorflow_serving/model_servers/server.cc:393] Running gRPC ModelServer at 0.0.0.0:8500 ... [warn] getaddrinfo: address family for nodename not supported 2022-02-25 06:01:14.033872: I tensorflow_serving/model_servers/server.cc:414] Exporting HTTP/REST API at:localhost:8501 ... [evhttp_server.cc : 245] NET_LOG: Entering the event loop ...
7. 將輸入語句化
後端已就緒,因此你幾乎準備好傳送用戶端要求到 TensorFlow Serving,但必須先將輸入語句化代碼。當您檢查模型的輸入張量時,會發現模型預期會有 20 個整數,而非原始字串。「憑證化」是指您根據詞彙字典,將應用程式中的個別字詞對應至整數清單,然後再傳送至後端進行分類。舉例來說,如果您輸入 buy book online to learn more
,代碼化程序就會對應至 [32, 79, 183, 10, 224, 631, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
。具體數字會因詞彙字典而異。
- 在
lib/main.dart
檔案中,將此程式碼加入predict()
方法以建立_vocabMap
詞彙字典。
// Build _vocabMap if empty.
if (_vocabMap.isEmpty) {
final vocabFileString = await rootBundle.loadString(vocabFile);
final lines = vocabFileString.split('\n');
for (final l in lines) {
if (l != "") {
var wordAndIndex = l.split(' ');
(_vocabMap)[wordAndIndex[0]] = int.parse(wordAndIndex[1]);
}
}
}
- 加入前一個程式碼片段後,立即加入這段程式碼以導入代碼化:
// Tokenize the input sentence.
final inputWords = _inputSentenceController.text
.toLowerCase()
.replaceAll(RegExp('[^a-z ]'), '')
.split(' ');
// Initialize with padding token.
_tokenIndices = List.filled(maxSentenceLength, 0);
var i = 0;
for (final w in inputWords) {
if ((_vocabMap).containsKey(w)) {
_tokenIndices[i] = (_vocabMap)[w]!;
i++;
}
// Truncate the string if longer than maxSentenceLength.
if (i >= maxSentenceLength - 1) {
break;
}
}
此程式碼會使句子字串小寫,移除非字母字元,並根據字彙表格將字詞對應到 20 個整數索引。
8. 透過 REST 將 Flutter 應用程式與 TensorFlow Serving 連結
向 TensorFlow Serving 傳送要求的方法有兩種:
- REST
- gRPC
透過 REST 傳送要求及接收回應
透過 REST 傳送要求及接收回應的三個簡單步驟如下:
- 建立 REST 要求。
- 將 REST 要求傳送至 TensorFlow Serving。
- 從 REST 回應中擷取預測結果並顯示 UI。
您已完成 main.dart
檔案中的步驟。
建立 REST 要求並傳送至 TensorFlow Serving
predict()
函式目前無法向 REST Serving 傳送 REST 要求。您必須導入 REST 分支來建立 REST 要求:
if (_connectionMode == ConnectionModeType.rest) {
// TODO: Create and send the REST request.
}
- 將下列程式碼新增至 REST 分支版本:
//Create the REST request.
final response = await http.post(
Uri.parse('http://' +
_server +
':' +
restPort.toString() +
'/v1/models/' +
modelName +
':predict'),
body: jsonEncode(<String, List<List<int>>>{
'instances': [_tokenIndices],
}),
);
處理 TensorFlow Serving 中的 REST 回應
- 加入下列程式碼,以處理 REST 回應:
// Process the REST response.
if (response.statusCode == 200) {
Map<String, dynamic> result = jsonDecode(response.body);
if (result['predictions']![0][1] >= classificationThreshold) {
return 'This sentence is spam. Spam score is ' +
result['predictions']![0][1].toString();
}
return 'This sentence is not spam. Spam score is ' +
result['predictions']![0][1].toString();
} else {
throw Exception('Error response');
}
後續處理代碼會擷取輸入語句為垃圾訊息的機率,並在使用者介面中顯示分類結果。
執行
- 按一下 [開始偵錯],然後等待應用程式載入。
- 輸入文字,然後選取 [REST > Classify]。
9. 透過 gRPC 連結 Flutter 應用程式與 TensorFlow Serving
除了 REST 以外,TensorFlow Serving 也支援 gRPC。
gRPC 是現代化、開放原始碼的高效能遠端程序呼叫 (RPC) 架構,可在任何環境中執行。這項服務具備可連接的負載平衡、追蹤、健康狀態檢查和驗證等功能,可透過高效率的方式連結資料中心內外的服務。發現在實驗中,gRPC 的成效比 REST 更好。
使用 gRPC 傳送要求及接收回應
透過 gRPC 傳送要求並接收回應的四個簡單步驟:
- 選用:產生 gRPC 用戶端 stub 程式碼。
- 建立 gRPC 要求。
- 將 gRPC 要求傳送至 TensorFlow Serving。
- 從 gRPC 回應中擷取預測結果,然後顯示 UI。
您已完成 main.dart
檔案中的步驟。
選用:產生 gRPC 用戶端 stub 程式碼
如要搭配 TensorFlow Serving 使用 gRPC,您需要遵守 gRPC 工作流程。如要進一步瞭解相關細節,請參閱 gRPC 說明文件。
TensorFlow Serving 和 TensorFlow 會為您定義 .proto
檔案。自 TensorFlow 和 TensorFlow Serving 2.8 起,下列 .proto
檔案為必要檔案:
tensorflow/core/example/example.proto
tensorflow/core/example/feature.proto
tensorflow/core/protobuf/struct.proto
tensorflow/core/protobuf/saved_object_graph.proto
tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/trackable_object_graph.proto
tensorflow/core/protobuf/meta_graph.proto
tensorflow/core/framework/node_def.proto
tensorflow/core/framework/attr_value.proto
tensorflow/core/framework/function.proto
tensorflow/core/framework/types.proto
tensorflow/core/framework/tensor_shape.proto
tensorflow/core/framework/full_type.proto
tensorflow/core/framework/versions.proto
tensorflow/core/framework/op_def.proto
tensorflow/core/framework/graph.proto
tensorflow/core/framework/tensor.proto
tensorflow/core/framework/resource_handle.proto
tensorflow/core/framework/variable.proto
tensorflow_serving/apis/inference.proto
tensorflow_serving/apis/classification.proto
tensorflow_serving/apis/predict.proto
tensorflow_serving/apis/regression.proto
tensorflow_serving/apis/get_model_metadata.proto
tensorflow_serving/apis/input.proto
tensorflow_serving/apis/prediction_service.proto
tensorflow_serving/apis/model.proto
google/protobuf/any.proto
google/protobuf/wrappers.proto
- 在終端機中,前往
starter/lib/proto/
資料夾並產生 stub:
bash generate_grpc_stub_dart.sh
建立 gRPC 要求
與 REST 要求類似,您會在 gRPC 分支中建立 gRPC 要求。
if (_connectionMode == ConnectionModeType.rest) {
} else {
// TODO: Create and send the gRPC request.
}
- 新增這段程式碼以建立 gRPC 要求:
//Create the gRPC request.
final channel = ClientChannel(_server,
port: grpcPort,
options:
const ChannelOptions(credentials: ChannelCredentials.insecure()));
_stub = PredictionServiceClient(channel,
options: CallOptions(timeout: const Duration(seconds: 10)));
ModelSpec modelSpec = ModelSpec(
name: 'spam-detection',
signatureName: 'serving_default',
);
TensorShapeProto_Dim batchDim = TensorShapeProto_Dim(size: Int64(1));
TensorShapeProto_Dim inputDim =
TensorShapeProto_Dim(size: Int64(maxSentenceLength));
TensorShapeProto inputTensorShape =
TensorShapeProto(dim: [batchDim, inputDim]);
TensorProto inputTensor = TensorProto(
dtype: DataType.DT_INT32,
tensorShape: inputTensorShape,
intVal: _tokenIndices);
// If you train your own model, update the input and output tensor names.
const inputTensorName = 'input_3';
const outputTensorName = 'dense_5';
PredictRequest request = PredictRequest(
modelSpec: modelSpec, inputs: {inputTensorName: inputTensor});
注意:即使模型架構相同,輸入和輸出張量名稱也可以隨著模型而有所不同。訓練自己的模型時,請記得更新模型。
將 gRPC 要求傳送至 TensorFlow Serving
- 將這段程式碼加到前一個程式碼片段後方,以將 gRPC 要求傳送至 TensorFlow Serving:
// Send the gRPC request.
PredictResponse response = await _stub.predict(request);
處理 TensorFlow Serving 中的 gRPC 回應
- 將這段程式碼加到前一個程式碼片段後方,以執行回呼函式來處理回應:
// Process the response.
if (response.outputs.containsKey(outputTensorName)) {
if (response.outputs[outputTensorName]!.floatVal[1] >
classificationThreshold) {
return 'This sentence is spam. Spam score is ' +
response.outputs[outputTensorName]!.floatVal[1].toString();
} else {
return 'This sentence is not spam. Spam score is ' +
response.outputs[outputTensorName]!.floatVal[1].toString();
}
} else {
throw Exception('Error response');
}
現在,後處理程式碼會從回應中擷取分類結果,並將其顯示在使用者介面中。
執行
- 按一下 [開始偵錯],然後等待應用程式載入。
- 輸入一些文字,然後選取 [gRPC > Classify]。
10. 恭喜
您使用 TensorFlow Serving 在應用程式中新增文字分類功能!
在下一個程式碼研究室中,您將強化模型,以便偵測目前的應用程式無法偵測到的特定垃圾郵件。