Thông tin về lớp học lập trình này
1. Trước khi bắt đầu
Trong lớp học lập trình này, bạn tìm hiểu cách chạy suy luận phát hiện đối tượng từ một ứng dụng Android bằng cách sử dụng việc phân phát TensorFlow với REST và gRPC.
Điều kiện tiên quyết
- Kiến thức cơ bản về việc phát triển Android thông qua Java
- Kiến thức cơ bản về công nghệ máy học với TensorFlow, chẳng hạn như chương trình đào tạo và triển khai
- Kiến thức cơ bản về thiết bị đầu cuối và Docker
Kiến thức bạn sẽ học được
- Cách tìm các mô hình phát hiện đối tượng được đào tạo trước trên TensorFlow Hub.
- Cách xây dựng một ứng dụng Android đơn giản và đưa ra dự đoán bằng mô hình phát hiện đối tượng được tải xuống thông qua việc phân phát TensorFlow (REST và gRPC).
- Cách hiển thị kết quả phát hiện trong giao diện người dùng.
Bạn cần có
- Phiên bản mới nhất của Android Studio
- Docker
- Bash
2. Bắt đầu thiết lập
Cách tải mã xuống cho lớp học lập trình này:
- Chuyển đến kho lưu trữ GitHub cho lớp học lập trình này.
- Nhấp vào Code > Download zip để tải tất cả mã xuống cho lớp học lập trình này.
- Giải nén tệp zip đã tải xuống để giải nén thư mục gốc
codelabs
bằng tất cả tài nguyên mà bạn cần.
Đối với lớp học lập trình này, bạn chỉ cần các tệp trong thư mục con TFServing/ObjectDetectionAndroid
trong kho lưu trữ, nơi chứa hai thư mục:
- Thư mục
starter
chứa mã dành cho người mới bắt đầu mà bạn xây dựng cho lớp học lập trình này. - Thư mục
finished
chứa mã đã hoàn tất cho ứng dụng mẫu đã hoàn thành.
3. Thêm các phần phụ thuộc vào dự án
Nhập ứng dụng dành cho người mới bắt đầu vào Android Studio
- Trong Android Studio, hãy nhấp vào Tệp > Mới > Nhập dự án, rồi chọn thư mục
starter
từ mã nguồn mà bạn đã tải xuống trước đó.
Thêm các phần phụ thuộc cho OkHttp và gRPC
- Trong tệp
app/build.gradle
của dự án, hãy xác nhận sự hiện diện của các phần phụ thuộc.
dependencies {
// ...
implementation 'com.squareup.okhttp3:okhttp:4.9.0'
implementation 'javax.annotation:javax.annotation-api:1.3.2'
implementation 'io.grpc:grpc-okhttp:1.29.0'
implementation 'io.grpc:grpc-protobuf-lite:1.29.0'
implementation 'io.grpc:grpc-stub:1.29.0'
}
Đồng bộ hóa dự án của bạn với các tệp cho Gradle
- Chọn
Đồng bộ hóa dự án với tệp Gradle từ trình đơn điều hướng.
4. Chạy ứng dụng khởi động
- Khởi động Trình mô phỏng Android,sau đó nhấp vào
Chạy "ứng dụng\39; trong trình đơn điều hướng.
Chạy và khám phá ứng dụng
Ứng dụng sẽ khởi chạy trên thiết bị Android của bạn. Giao diện người dùng khá đơn giản: có một hình ảnh mèo mà bạn muốn phát hiện các đối tượng và người dùng có thể chọn cách gửi dữ liệu tới phần phụ trợ, với REST hoặc gRPC. Phụ trợ này thực hiện chức năng phát hiện đối tượng trên hình ảnh và trả lại kết quả phát hiện cho ứng dụng khách, ứng dụng này sẽ hiển thị lại giao diện người dùng.
Ngay bây giờ, nếu bạn nhấp vào Chạy dự đoán, thì sẽ không có gì xảy ra. Điều này là do nó không thể giao tiếp với chương trình phụ trợ.
5. Triển khai mô hình phát hiện đối tượng bằng tính năng Phân phát TensorFlow
Phát hiện đối tượng là một tác vụ máy học rất phổ biến và mục tiêu của nó là phát hiện các đối tượng trong hình ảnh, cụ thể là dự đoán các danh mục có thể có của các đối tượng và hộp giới hạn xung quanh chúng. Sau đây là ví dụ về kết quả phát hiện:
Google đã phát hành một số mô hình được đào tạo trước trên TensorFlow Hub. Để xem danh sách đầy đủ, hãy truy cập vào trang object_detection. Bạn sử dụng mô hình SSD MobileNet V2 FPNLite 320x320 tương đối nhẹ cho lớp học lập trình này để bạn không nhất thiết phải sử dụng GPU để chạy mô hình đó.
Để triển khai mô hình phát hiện đối tượng bằng tính năng Phân phát TensorFlow:
- Tải tệp mô hình xuống.
- Giải nén tệp
.tar.gz
đã tải xuống bằng một công cụ giải nén, chẳng hạn như 7-zip. - Tạo thư mục
ssd_mobilenet_v2_2_320
, sau đó tạo một thư mục con123
bên trong thư mục đó. - Đặt thư mục
variables
đã trích xuất và tệpsaved_model.pb
vào thư mục con123
.
Bạn có thể tham chiếu thư mục ssd_mobilenet_v2_2_320
dưới dạng thư mục SavedModel
. 123
là số phiên bản mẫu. Nếu muốn, bạn có thể chọn một số khác.
Cấu trúc thư mục sẽ có dạng như sau:
Bắt đầu phân phát TensorFlow
- Trên thiết bị đầu cuối, hãy bắt đầu phân phát TensorFlow bằng Docker, nhưng thay thế trình giữ chỗ
PATH/TO/SAVEDMODEL
bằng đường dẫn tuyệt đối của thư mụcssd_mobilenet_v2_2_320
trên máy tính.
docker pull tensorflow/serving docker run -it --rm -p 8500:8500 -p 8501:8501 -v "PATH/TO/SAVEDMODEL:/models/ssd_mobilenet_v2_2" -e MODEL_NAME=ssd_mobilenet_v2_2 tensorflow/serving
Docker, hệ thống sẽ tự động tải hình ảnh Phân phát TensorFlow xuống trước. Quá trình này sẽ mất một phút. Sau đó, Quá trình phân phát TensorFlow sẽ bắt đầu. Nhật ký sẽ trông giống như đoạn mã này:
2022-02-25 06:01:12.513231: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:206] Restoring SavedModel bundle. 2022-02-25 06:01:12.585012: I external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:114] CPU Frequency: 3000000000 Hz 2022-02-25 06:01:13.395083: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:190] Running initialization op on SavedModel bundle at path: /models/ssd_mobilenet_v2_2/123 2022-02-25 06:01:13.837562: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:277] SavedModel load for tags { serve }; Status: success: OK. Took 1928700 microseconds. 2022-02-25 06:01:13.877848: I tensorflow_serving/servables/tensorflow/saved_model_warmup_util.cc:59] No warmup data file found at /models/ssd_mobilenet_v2_2/123/assets.extra/tf_serving_warmup_requests 2022-02-25 06:01:13.929844: I tensorflow_serving/core/loader_harness.cc:87] Successfully loaded servable version {name: ssd_mobilenet_v2_2 version: 123} 2022-02-25 06:01:13.985848: I tensorflow_serving/model_servers/server_core.cc:486] Finished adding/updating models 2022-02-25 06:01:13.985987: I tensorflow_serving/model_servers/server.cc:367] Profiler service is enabled 2022-02-25 06:01:13.988994: I tensorflow_serving/model_servers/server.cc:393] Running gRPC ModelServer at 0.0.0.0:8500 ... [warn] getaddrinfo: address family for nodename not supported 2022-02-25 06:01:14.033872: I tensorflow_serving/model_servers/server.cc:414] Exporting HTTP/REST API at:localhost:8501 ... [evhttp_server.cc : 245] NET_LOG: Entering the event loop ...
6. Kết nối ứng dụng Android với TensorFlow Delivery thông qua REST
Phần phụ trợ này hiện đã sẵn sàng, vì vậy bạn có thể gửi các yêu cầu ứng dụng đến TensorFlow Delivery để phát hiện các đối tượng trong hình ảnh. Có hai cách để gửi yêu cầu đến Phục vụ TensorFlow:
- Kiến trúc chuyển trạng thái đại diện (REST)
- gRPC
Gửi yêu cầu và nhận phản hồi qua REST
Có 3 bước đơn giản:
- Tạo yêu cầu REST.
- Gửi yêu cầu REST tới TensorFlowServe.
- Trích xuất kết quả dự đoán từ phản hồi REST và hiển thị giao diện người dùng.
Bạn sẽ đạt được các mục này trong MainActivity.java.
Tạo yêu cầu REST
Hiện tại, có một hàm createRESTRequest()
trống trong tệp MainActivity.java
. Bạn triển khai hàm này để tạo một yêu cầu REST.
private Request createRESTRequest() {
}
TensorFlow Phục vụ yêu cầu một yêu cầu POST có chứa áp lực hình ảnh cho mô hình SSD MobileNet mà bạn sử dụng, vì vậy, bạn cần trích xuất các giá trị RGB từ mỗi pixel của hình ảnh vào một mảng, sau đó gói mảng vào một JSON, là tải trọng của yêu cầu.
- Thêm mã này vào hàm
createRESTRequest()
:
//Create the REST request.
int[] inputImg = new int[INPUT_IMG_HEIGHT * INPUT_IMG_WIDTH];
int[][][][] inputImgRGB = new int[1][INPUT_IMG_HEIGHT][INPUT_IMG_WIDTH][3];
inputImgBitmap.getPixels(inputImg, 0, INPUT_IMG_WIDTH, 0, 0, INPUT_IMG_WIDTH, INPUT_IMG_HEIGHT);
int pixel;
for (int i = 0; i < INPUT_IMG_HEIGHT; i++) {
for (int j = 0; j < INPUT_IMG_WIDTH; j++) {
// Extract RBG values from each pixel; alpha is ignored
pixel = inputImg[i * INPUT_IMG_WIDTH + j];
inputImgRGB[0][i][j][0] = ((pixel >> 16) & 0xff);
inputImgRGB[0][i][j][1] = ((pixel >> 8) & 0xff);
inputImgRGB[0][i][j][2] = ((pixel) & 0xff);
}
}
RequestBody requestBody =
RequestBody.create("{\"instances\": " + Arrays.deepToString(inputImgRGB) + "}", JSON);
Request request =
new Request.Builder()
.url("http://" + SERVER + ":" + REST_PORT + "/v1/models/" + MODEL_NAME + ":predict")
.post(requestBody)
.build();
return request;
Gửi yêu cầu REST đến Phục vụ TensorFlow
Ứng dụng này cho phép người dùng chọn REST hoặc gRPC để giao tiếp với TensorFlow Delivery, vì vậy có hai nhánh trong trình nghe onClick(View view)
.
predictButton.setOnClickListener(
new View.OnClickListener() {
@Override
public void onClick(View view) {
if (requestRadioGroup.getCheckedRadioButtonId() == R.id.rest) {
// TODO: REST request
}
else {
}
}
}
)
- Thêm mã này vào nhánh REST của trình nghe
onClick(View view)
để sử dụng OkHttp để gửi yêu cầu đến TensorFlowServe:
// Send the REST request.
Request request = createRESTRequest();
try {
client =
new OkHttpClient.Builder()
.connectTimeout(20, TimeUnit.SECONDS)
.writeTimeout(20, TimeUnit.SECONDS)
.readTimeout(20, TimeUnit.SECONDS)
.callTimeout(20, TimeUnit.SECONDS)
.build();
Response response = client.newCall(request).execute();
JSONObject responseObject = new JSONObject(response.body().string());
postprocessRESTResponse(responseObject);
} catch (IOException | JSONException e) {
Log.e(TAG, e.getMessage());
responseTextView.setText(e.getMessage());
return;
}
Xử lý phản hồi REST (Phân phát) của TensorFlow
Mô hình SSD MobileNet trả về một số kết quả, bao gồm:
num_detections
: số lượt phát hiệndetection_scores
: điểm phát hiệndetection_classes
: chỉ mục của lớp phát hiệndetection_boxes
: tọa độ hộp giới hạn
Bạn triển khai hàm postprocessRESTResponse()
để xử lý phản hồi.
private void postprocessRESTResponse(Predict.PredictResponse response) {
}
- Thêm mã này vào hàm
postprocessRESTResponse()
:
// Process the REST response.
JSONArray predictionsArray = responseObject.getJSONArray("predictions");
//You only send one image, so you directly extract the first element.
JSONObject predictions = predictionsArray.getJSONObject(0);
// Argmax
int maxIndex = 0;
JSONArray detectionScores = predictions.getJSONArray("detection_scores");
for (int j = 0; j < predictions.getInt("num_detections"); j++) {
maxIndex =
detectionScores.getDouble(j) > detectionScores.getDouble(maxIndex + 1) ? j : maxIndex;
}
int detectionClass = predictions.getJSONArray("detection_classes").getInt(maxIndex);
JSONArray boundingBox = predictions.getJSONArray("detection_boxes").getJSONArray(maxIndex);
double ymin = boundingBox.getDouble(0);
double xmin = boundingBox.getDouble(1);
double ymax = boundingBox.getDouble(2);
double xmax = boundingBox.getDouble(3);
displayResult(detectionClass, (float) ymin, (float) xmin, (float) ymax, (float) xmax);
Giờ đây, hàm xử lý hậu kỳ trích xuất các giá trị dự đoán từ phản hồi, xác định danh mục có thể xảy ra nhất của đối tượng và tọa độ của các đỉnh hộp giới hạn, cuối cùng hiển thị hộp giới hạn phát hiện trên giao diện người dùng.
Chạy
- Nhấp vào
Chạy "ứng dụng\39"; trong trình đơn điều hướng rồi chờ ứng dụng tải.
- Chọn REST > Chạy dự đoán.
Ứng dụng này sẽ mất vài giây trước khi ứng dụng hiển thị hộp giới hạn của mèo và cho thấy 17
làm danh mục của đối tượng, liên kết tới đối tượng cat
trong tập dữ liệu COCO.
7. Kết nối ứng dụng Android với TensorFlow Delivery thông qua gRPC
Ngoài REST, việc phân phát TensorFlow cũng hỗ trợ gRPC.
gRPC là một khung lệnh gọi từ xa, nguồn mở hiện đại, hiệu suất cao, có thể chạy trong mọi môi trường. API này có thể kết nối các dịch vụ trong và trên các trung tâm dữ liệu một cách hiệu quả với khả năng hỗ trợ dễ dàng để cân bằng tải, theo dõi, kiểm tra tình trạng và xác thực. Theo quan sát, gRPC có hiệu suất cao hơn REST trong thực tế.
Gửi yêu cầu và nhận phản hồi bằng gRPC
Có 4 bước đơn giản:
- [Không bắt buộc] Tạo mã mã ứng dụng gRPC.
- Tạo yêu cầu gRPC.
- Gửi yêu cầu gRPC đến TensorFlowServe.
- Trích xuất kết quả dự đoán từ phản hồi gRPC và hiển thị giao diện người dùng.
Bạn sẽ đạt được các mục này trong MainActivity.java.
Không bắt buộc: Tạo mã mã ứng dụng gRPC
Để sử dụng gRPC với TensorFlowServe, bạn cần phải tuân theo quy trình làm việc của gRPC. Để tìm hiểu thêm về các chi tiết, hãy xem tài liệu gRPC.
TensorFlow và TensorFlow xác định các tệp .proto
cho bạn. Kể từ TensorFlow và TensorFlow phân phát 2.8, .proto
tệp này là những tệp cần thiết:
tensorflow/core/example/example.proto
tensorflow/core/example/feature.proto
tensorflow/core/protobuf/struct.proto
tensorflow/core/protobuf/saved_object_graph.proto
tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/trackable_object_graph.proto
tensorflow/core/protobuf/meta_graph.proto
tensorflow/core/framework/node_def.proto
tensorflow/core/framework/attr_value.proto
tensorflow/core/framework/function.proto
tensorflow/core/framework/types.proto
tensorflow/core/framework/tensor_shape.proto
tensorflow/core/framework/full_type.proto
tensorflow/core/framework/versions.proto
tensorflow/core/framework/op_def.proto
tensorflow/core/framework/graph.proto
tensorflow/core/framework/tensor.proto
tensorflow/core/framework/resource_handle.proto
tensorflow/core/framework/variable.proto
tensorflow_serving/apis/inference.proto
tensorflow_serving/apis/classification.proto
tensorflow_serving/apis/predict.proto
tensorflow_serving/apis/regression.proto
tensorflow_serving/apis/get_model_metadata.proto
tensorflow_serving/apis/input.proto
tensorflow_serving/apis/prediction_service.proto
tensorflow_serving/apis/model.proto
- Để tạo mã giả, hãy thêm mã này vào tệp
app/build.gradle
.
apply plugin: 'com.google.protobuf'
protobuf {
protoc { artifact = 'com.google.protobuf:protoc:3.11.0' }
plugins {
grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.29.0'
}
}
generateProtoTasks {
all().each { task ->
task.builtins {
java { option 'lite' }
}
task.plugins {
grpc { option 'lite' }
}
}
}
}
Tạo yêu cầu gRPC
Tương tự như yêu cầu REST, bạn tạo yêu cầu gRPC trong hàm createGRPCRequest()
.
private Request createGRPCRequest() {
}
- Thêm mã này vào hàm
createGRPCRequest()
:
if (stub == null) {
channel = ManagedChannelBuilder.forAddress(SERVER, GRPC_PORT).usePlaintext().build();
stub = PredictionServiceGrpc.newBlockingStub(channel);
}
Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
modelSpecBuilder.setName(MODEL_NAME);
modelSpecBuilder.setVersion(Int64Value.of(MODEL_VERSION));
modelSpecBuilder.setSignatureName(SIGNATURE_NAME);
Predict.PredictRequest.Builder builder = Predict.PredictRequest.newBuilder();
builder.setModelSpec(modelSpecBuilder);
TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
tensorProtoBuilder.setDtype(DataType.DT_UINT8);
TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(INPUT_IMG_HEIGHT));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(INPUT_IMG_WIDTH));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(3));
tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());
int[] inputImg = new int[INPUT_IMG_HEIGHT * INPUT_IMG_WIDTH];
inputImgBitmap.getPixels(inputImg, 0, INPUT_IMG_WIDTH, 0, 0, INPUT_IMG_WIDTH, INPUT_IMG_HEIGHT);
int pixel;
for (int i = 0; i < INPUT_IMG_HEIGHT; i++) {
for (int j = 0; j < INPUT_IMG_WIDTH; j++) {
// Extract RBG values from each pixel; alpha is ignored.
pixel = inputImg[i * INPUT_IMG_WIDTH + j];
tensorProtoBuilder.addIntVal((pixel >> 16) & 0xff);
tensorProtoBuilder.addIntVal((pixel >> 8) & 0xff);
tensorProtoBuilder.addIntVal((pixel) & 0xff);
}
}
TensorProto tensorProto = tensorProtoBuilder.build();
builder.putInputs("input_tensor", tensorProto);
builder.addOutputFilter("num_detections");
builder.addOutputFilter("detection_boxes");
builder.addOutputFilter("detection_classes");
builder.addOutputFilter("detection_scores");
return builder.build();
Gửi yêu cầu gRPC đến TensorFlowServe
Giờ đây, bạn đã có thể nghe xong trình nghe onClick(View view)
.
predictButton.setOnClickListener(
new View.OnClickListener() {
@Override
public void onClick(View view) {
if (requestRadioGroup.getCheckedRadioButtonId() == R.id.rest) {
}
else {
// TODO: gRPC request
}
}
}
)
- Thêm mã này vào chi nhánh gRPC:
try {
Predict.PredictRequest request = createGRPCRequest();
Predict.PredictResponse response = stub.predict(request);
postprocessGRPCResponse(response);
} catch (Exception e) {
Log.e(TAG, e.getMessage());
responseTextView.setText(e.getMessage());
return;
}
Xử lý phản hồi gRPC từ TensorFlowServe
Tương tự như gRPC, bạn triển khai hàm postprocessGRPCResponse()
để xử lý phản hồi.
private void postprocessGRPCResponse(Predict.PredictResponse response) {
}
- Thêm mã này vào hàm
postprocessGRPCResponse()
:
// Process the response.
float numDetections = response.getOutputsMap().get("num_detections").getFloatValList().get(0);
List<Float> detectionScores = response.getOutputsMap().get("detection_scores").getFloatValList();
int maxIndex = 0;
for (int j = 0; j < numDetections; j++) {
maxIndex = detectionScores.get(j) > detectionScores.get(maxIndex + 1) ? j : maxIndex;
}
Float detectionClass = response.getOutputsMap().get("detection_classes").getFloatValList().get(maxIndex);
List<Float> boundingBoxValues = response.getOutputsMap().get("detection_boxes").getFloatValList();
float ymin = boundingBoxValues.get(maxIndex * 4);
float xmin = boundingBoxValues.get(maxIndex * 4 + 1);
float ymax = boundingBoxValues.get(maxIndex * 4 + 2);
float xmax = boundingBoxValues.get(maxIndex * 4 + 3);
displayResult(detectionClass.intValue(), ymin, xmin, ymax, xmax);
Giờ đây, hàm xử lý hậu kỳ có thể trích xuất các giá trị dự đoán từ phản hồi và hiển thị hộp giới hạn phát hiện trong giao diện người dùng.
Chạy
- Nhấp vào
Chạy "ứng dụng\39"; trong trình đơn điều hướng rồi chờ ứng dụng tải.
- Chọn gRPC > Chạy dự đoán.
Sẽ mất vài giây trước khi ứng dụng hiển thị hộp giới hạn của mèo và hiển thị 17
dưới dạng danh mục của đối tượng, liên kết tới danh mục cat
trong tập dữ liệu COCO.
8. Xin chúc mừng
Bạn đã sử dụng tính năng Phân phát TensorFlow để thêm các tính năng phát hiện đối tượng vào ứng dụng của bạn!