מידע על Codelab זה
1. לפני שמתחילים
במעבדה זו אתם יכולים ללמוד איך להסיק מסקנות של זיהוי אובייקטים מאפליקציה ל-Android באמצעות TensorFlow serving עם REST ו-gRPC.
דרישות מוקדמות
- ידע בסיסי בפיתוח ב-Android עם Java
- ידע בסיסי על למידה חישובית עם TensorFlow, כמו הדרכה ופריסה
- ידע בסיסי על מסופים ותחנות עגינה
מה תלמדו
- איך למצוא מודלים לזיהוי אובייקטים באימון מראש ב-TensorFlow Hub.
- איך לבנות אפליקציה פשוטה ל-Android ולבצע חיזויים באמצעות מודל זיהוי האובייקטים שהורד דרך TensorFlow serving (REST ו-gRPC).
- איך לעבד את תוצאת הזיהוי בממשק המשתמש.
מה תצטרך להכין
- הגרסה האחרונה של Android Studio
- אביזר עגינה
- שיט
2. להגדרה
כדי להוריד את הקוד של Lab Lab זה:
- עוברים אל מאגר GitHub עבור מעבדת קוד זו.
- לוחצים על Code > הורדת zip כדי להוריד את כל הקוד של מעבדת הקוד הזו.
- יש לבטל את הדחיסה של קובץ ה-ZIP שהורדת כדי לפתוח את תיקיית הבסיס של
codelabs
עם כל המשאבים.
ב-codelab זה נדרשים רק הקבצים בספריית המשנה TFServing/ObjectDetectionAndroid
במאגר, המכיל שתי תיקיות:
- התיקייה
starter
מכילה את הקוד למתחילים שעליו כדאי לבנות את מעבדת הקוד הזו. - התיקייה
finished
מכילה את הקוד המלא לאפליקציה לדוגמה שהושלמה.
3. מוסיפים את התלות לפרויקט
ייבוא האפליקציה למתחילים ל-Android Studio
- ב-Android Studio, לוחצים על קובץ > חדש > ייבוא פרויקט ולאחר מכן בוחרים את התיקייה
starter
בקוד המקור שהורדתם בעבר.
הוספת תלות ב-OkHttp וב-gRPC
- בקובץ
app/build.gradle
של הפרויקט, מאשרים את נוכחות התלות.
dependencies {
// ...
implementation 'com.squareup.okhttp3:okhttp:4.9.0'
implementation 'javax.annotation:javax.annotation-api:1.3.2'
implementation 'io.grpc:grpc-okhttp:1.29.0'
implementation 'io.grpc:grpc-protobuf-lite:1.29.0'
implementation 'io.grpc:grpc-stub:1.29.0'
}
סנכרון הפרויקט שלך עם קובצי Gradle
- בוחרים באפשרות
סנכרון פרויקט עם קובצי Gradle מתפריט הניווט.
4. הפעלת האפליקציה למתחילים
- מפעילים את אמולטור Android, ולאחר מכן לוחצים על
הפעלת 'app' בתפריט הניווט.
הרצה ועיון באפליקציה
יש להפעיל את האפליקציה במכשיר Android. ממשק המשתמש די פשוט: יש תמונת חתול שבה רוצים לזהות אובייקטים, והמשתמש יכול לבחור איך לשלוח את הנתונים לקצה העורפי באמצעות REST או gRPC. הקצה העורפי מבצע זיהוי אובייקטים בתמונה ומחזיר את תוצאות הזיהוי לאפליקציית הלקוח, שמעבדת את ממשק המשתמש שוב.
נכון לעכשיו, אם לוחצים על הפעלת מסקנות, לא יקרה דבר. הסיבה לכך היא שעדיין לא ניתן לתקשר עם הקצה העורפי.
5. פריסת מודל לזיהוי אובייקטים עם בשירות TensorFlow
זיהוי אובייקטים הוא משימה נפוצה מאוד בלמידה חישובית, והמטרה שלו היא לזהות אובייקטים בתוך תמונות. כלומר, ניתן ליצור תחזיות לגבי קטגוריות אפשריות של אובייקטים ותיבות שמקיפות אותן. דוגמה לתוצאת חיפוש:
Google פרסמה מספר מודלים מוכנים מראש ב-TensorFlow Hub. כדי לראות את הרשימה המלאה, נכנסים לדף object_detection. אתם משתמשים במודל ה-SSD MobileNet V2 FPNLite 320x320 הפשוט יחסית, כדי שלא תצטרכו להשתמש ב-GPU כדי להפעיל אותו.
כדי לפרוס את המודל לזיהוי אובייקטים באמצעות הצגת TensorFlow:
- מורידים את קובץ הדגם.
- יש לבטל את הדחיסה של הקובץ
.tar.gz
שהורדת באמצעות כלי לדחיסת נתונים, כמו 7-Zip. - יצירת תיקייה של
ssd_mobilenet_v2_2_320
ולאחר מכן יצירת תיקיית משנה123
. - מעבירים את התיקייה
variables
ואת הקובץsaved_model.pb
לתיקיית המשנה של123
.
אפשר להתייחס לתיקייה ssd_mobilenet_v2_2_320
כתיקייה של SavedModel
. 123
הוא מספר גרסה לדוגמה. אפשר לבחור מספר אחר.
מבנה התיקייה אמור להיראות כך:
התחלת הצגה של TensorFlow
- במסוף, מפעילים את TensorFlow serving עם Docker, אבל מחליפים את ה-placeholder
PATH/TO/SAVEDMODEL
בנתיב המוחלט של תיקייתssd_mobilenet_v2_2_320
במחשב.
docker pull tensorflow/serving docker run -it --rm -p 8500:8500 -p 8501:8501 -v "PATH/TO/SAVEDMODEL:/models/ssd_mobilenet_v2_2" -e MODEL_NAME=ssd_mobilenet_v2_2 tensorflow/serving
אביזר העגינה מוריד תחילה באופן אוטומטי את התמונה של TensorFlow serving. לאחר מכן, ההצגה של TensorFlow אמורה להתחיל. היומן אמור להיראות כמו קטע קוד זה:
2022-02-25 06:01:12.513231: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:206] Restoring SavedModel bundle. 2022-02-25 06:01:12.585012: I external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:114] CPU Frequency: 3000000000 Hz 2022-02-25 06:01:13.395083: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:190] Running initialization op on SavedModel bundle at path: /models/ssd_mobilenet_v2_2/123 2022-02-25 06:01:13.837562: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:277] SavedModel load for tags { serve }; Status: success: OK. Took 1928700 microseconds. 2022-02-25 06:01:13.877848: I tensorflow_serving/servables/tensorflow/saved_model_warmup_util.cc:59] No warmup data file found at /models/ssd_mobilenet_v2_2/123/assets.extra/tf_serving_warmup_requests 2022-02-25 06:01:13.929844: I tensorflow_serving/core/loader_harness.cc:87] Successfully loaded servable version {name: ssd_mobilenet_v2_2 version: 123} 2022-02-25 06:01:13.985848: I tensorflow_serving/model_servers/server_core.cc:486] Finished adding/updating models 2022-02-25 06:01:13.985987: I tensorflow_serving/model_servers/server.cc:367] Profiler service is enabled 2022-02-25 06:01:13.988994: I tensorflow_serving/model_servers/server.cc:393] Running gRPC ModelServer at 0.0.0.0:8500 ... [warn] getaddrinfo: address family for nodename not supported 2022-02-25 06:01:14.033872: I tensorflow_serving/model_servers/server.cc:414] Exporting HTTP/REST API at:localhost:8501 ... [evhttp_server.cc : 245] NET_LOG: Entering the event loop ...
6. חיבור האפליקציה ל-Android עם TensorFlow serving דרך REST
הקצה העורפי מוכן עכשיו, ואפשר לשלוח בקשות של לקוחות אל TensorFlow serving כדי לזהות אובייקטים בתמונות. יש שתי דרכים לשלוח בקשות ל-TensorFlow serving:
- REST
- gRPC
שליחת בקשות וקבלת תשובות באמצעות REST
יש שלושה שלבים פשוטים:
- יוצרים את הבקשה ל-REST.
- שליחת בקשת REST להגשה של TensorFlow.
- מחלצים את התוצאה החזויה מתגובת ה-REST ומעבדים את ממשק המשתמש.
אפשר לעשות זאת בעוד MainActivity.java.
יצירה של בקשת REST
כרגע יש פונקציה createRESTRequest()
ריקה בקובץ MainActivity.java
. הפונקציה מיישמת את הפונקציה הזו כדי ליצור בקשת REST.
private Request createRESTRequest() {
}
ל-TenororFlow יש בקשה לבקשת POST שמכילה את דגמת הטעינה של מודל SSD MobileNet שבו אתם משתמשים. לכן, עליכם לחלץ את ערכי ה-RGB מכל פיקסל של התמונה לתוך מערך, ולאחר מכן לשרטט את המערך בקובץ JSON, שהוא המטען הייעודי של הבקשה.
- צריך להוסיף את הקוד הזה לפונקציה
createRESTRequest()
:
//Create the REST request.
int[] inputImg = new int[INPUT_IMG_HEIGHT * INPUT_IMG_WIDTH];
int[][][][] inputImgRGB = new int[1][INPUT_IMG_HEIGHT][INPUT_IMG_WIDTH][3];
inputImgBitmap.getPixels(inputImg, 0, INPUT_IMG_WIDTH, 0, 0, INPUT_IMG_WIDTH, INPUT_IMG_HEIGHT);
int pixel;
for (int i = 0; i < INPUT_IMG_HEIGHT; i++) {
for (int j = 0; j < INPUT_IMG_WIDTH; j++) {
// Extract RBG values from each pixel; alpha is ignored
pixel = inputImg[i * INPUT_IMG_WIDTH + j];
inputImgRGB[0][i][j][0] = ((pixel >> 16) & 0xff);
inputImgRGB[0][i][j][1] = ((pixel >> 8) & 0xff);
inputImgRGB[0][i][j][2] = ((pixel) & 0xff);
}
}
RequestBody requestBody =
RequestBody.create("{\"instances\": " + Arrays.deepToString(inputImgRGB) + "}", JSON);
Request request =
new Request.Builder()
.url("http://" + SERVER + ":" + REST_PORT + "/v1/models/" + MODEL_NAME + ":predict")
.post(requestBody)
.build();
return request;
שליחת בקשת REST להצגת TensorFlow
האפליקציה מאפשרת למשתמש לבחור ב-REST או ב-gRPC כדי לתקשר עם TensorFlow serving, כך שיש שני סניפים ב-onClick(View view)
בהאזנה.
predictButton.setOnClickListener(
new View.OnClickListener() {
@Override
public void onClick(View view) {
if (requestRadioGroup.getCheckedRadioButtonId() == R.id.rest) {
// TODO: REST request
}
else {
}
}
}
)
- מוסיפים את הקוד הזה לסניף של REST של הכלי להאזנה ל-
onClick(View view)
כדי להשתמש ב-OkHttp כדי לשלוח את הבקשה להצגה ב-TensorFlow:
// Send the REST request.
Request request = createRESTRequest();
try {
client =
new OkHttpClient.Builder()
.connectTimeout(20, TimeUnit.SECONDS)
.writeTimeout(20, TimeUnit.SECONDS)
.readTimeout(20, TimeUnit.SECONDS)
.callTimeout(20, TimeUnit.SECONDS)
.build();
Response response = client.newCall(request).execute();
JSONObject responseObject = new JSONObject(response.body().string());
postprocessRESTResponse(responseObject);
} catch (IOException | JSONException e) {
Log.e(TAG, e.getMessage());
responseTextView.setText(e.getMessage());
return;
}
עיבוד התגובה ל-REST מ-TensorFlow serving
מודל SSD MobileNet מחזיר מספר תוצאות, הכוללות:
num_detections
: מספר הזיהוייםdetection_scores
: ציוני זיהויdetection_classes
: האינדקס של סיווג הזיהויdetection_boxes
: הקואורדינטות של התיבה התוחמת
כדי לממש את התגובה צריך להטמיע את הפונקציה postprocessRESTResponse()
.
private void postprocessRESTResponse(Predict.PredictResponse response) {
}
- צריך להוסיף את הקוד הזה לפונקציה
postprocessRESTResponse()
:
// Process the REST response.
JSONArray predictionsArray = responseObject.getJSONArray("predictions");
//You only send one image, so you directly extract the first element.
JSONObject predictions = predictionsArray.getJSONObject(0);
// Argmax
int maxIndex = 0;
JSONArray detectionScores = predictions.getJSONArray("detection_scores");
for (int j = 0; j < predictions.getInt("num_detections"); j++) {
maxIndex =
detectionScores.getDouble(j) > detectionScores.getDouble(maxIndex + 1) ? j : maxIndex;
}
int detectionClass = predictions.getJSONArray("detection_classes").getInt(maxIndex);
JSONArray boundingBox = predictions.getJSONArray("detection_boxes").getJSONArray(maxIndex);
double ymin = boundingBox.getDouble(0);
double xmin = boundingBox.getDouble(1);
double ymax = boundingBox.getDouble(2);
double xmax = boundingBox.getDouble(3);
displayResult(detectionClass, (float) ymin, (float) xmin, (float) ymax, (float) xmax);
עכשיו, הפונקציה אחרי העיבוד מחלצת את הערכים החזויים מהתגובה, מזהה את הקטגוריה הסבירה ביותר של האובייקט ואת הקואורדינטות של קודצות התיבה התוחמת, ולבסוף מעבדת את תיבת התוחם של הזיהוי בממשק המשתמש.
הפעלה
- לוחצים על
הפעלת 'app' בתפריט הניווט ומחכים שהאפליקציה תיטען.
- בוחרים באפשרות REST > Run aמסיקה.
לוקח כמה שניות עד שהאפליקציה מעבדת את התיבה התוחמת של החתול, ומציגה את 17
כקטגוריה של האובייקט, שממופה לאובייקט cat
במערך הנתונים של COCO.
7. חיבור של אפליקציית Android עם TensorFlow serving דרך gRPC
בנוסף ל-REST, שירות TensorFlow תומך גם ב-gRPC.
gRPC היא מסגרת מודרנית ומודרנית של ביצועים טובים (RTT) שניתן להפעיל בכל סביבה. הוא יכול לחבר שירותים ביעילות למרכזי נתונים, ובהם בתמיכה, לאיזון עומסים, למעקב אחר נתונים, לבדיקת תקינות ולאימות. התגלה ש-gRPC מניב ביצועים טובים יותר בהשוואה ל-REST בפועל.
שליחה של בקשות וקבלת תשובות באמצעות gRPC
יש ארבעה שלבים פשוטים:
- [אופציונלי] יוצרים את קוד הלקוח של gRPC.
- יוצרים את בקשת ה-gRPC.
- שליחת בקשת gRPC לשירות TensorFlow.
- מחלצים את התוצאה החזויה מתגובת ה-gRPC ומעבדים את ממשק המשתמש.
אפשר לעשות זאת בעוד MainActivity.java.
אופציונלי: יצירת קוד stub של לקוח gRPC
כדי להשתמש ב-gRPC עם TensorFlow serving, יש לפעול לפי תהליך העבודה של gRPC. למידע נוסף, ניתן לעיין בתיעוד של gRPC.
TensorFlow הגשה ו-TensorFlow מגדירים את קובצי .proto
עבורך. נכון ל-TensorFlow ול-TensorFlow 2.8, אלו קובצי ה-.proto
הנדרשים:
tensorflow/core/example/example.proto
tensorflow/core/example/feature.proto
tensorflow/core/protobuf/struct.proto
tensorflow/core/protobuf/saved_object_graph.proto
tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/trackable_object_graph.proto
tensorflow/core/protobuf/meta_graph.proto
tensorflow/core/framework/node_def.proto
tensorflow/core/framework/attr_value.proto
tensorflow/core/framework/function.proto
tensorflow/core/framework/types.proto
tensorflow/core/framework/tensor_shape.proto
tensorflow/core/framework/full_type.proto
tensorflow/core/framework/versions.proto
tensorflow/core/framework/op_def.proto
tensorflow/core/framework/graph.proto
tensorflow/core/framework/tensor.proto
tensorflow/core/framework/resource_handle.proto
tensorflow/core/framework/variable.proto
tensorflow_serving/apis/inference.proto
tensorflow_serving/apis/classification.proto
tensorflow_serving/apis/predict.proto
tensorflow_serving/apis/regression.proto
tensorflow_serving/apis/get_model_metadata.proto
tensorflow_serving/apis/input.proto
tensorflow_serving/apis/prediction_service.proto
tensorflow_serving/apis/model.proto
- כדי ליצור את הגלויה, צריך להוסיף את הקוד הזה לקובץ
app/build.gradle
.
apply plugin: 'com.google.protobuf'
protobuf {
protoc { artifact = 'com.google.protobuf:protoc:3.11.0' }
plugins {
grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.29.0'
}
}
generateProtoTasks {
all().each { task ->
task.builtins {
java { option 'lite' }
}
task.plugins {
grpc { option 'lite' }
}
}
}
}
יצירת בקשת gRPC
בדומה לבקשת ה-REST, צריך ליצור את בקשת ה-gRPC בפונקציה createGRPCRequest()
.
private Request createGRPCRequest() {
}
- צריך להוסיף את הקוד הזה לפונקציה
createGRPCRequest()
:
if (stub == null) {
channel = ManagedChannelBuilder.forAddress(SERVER, GRPC_PORT).usePlaintext().build();
stub = PredictionServiceGrpc.newBlockingStub(channel);
}
Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
modelSpecBuilder.setName(MODEL_NAME);
modelSpecBuilder.setVersion(Int64Value.of(MODEL_VERSION));
modelSpecBuilder.setSignatureName(SIGNATURE_NAME);
Predict.PredictRequest.Builder builder = Predict.PredictRequest.newBuilder();
builder.setModelSpec(modelSpecBuilder);
TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
tensorProtoBuilder.setDtype(DataType.DT_UINT8);
TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(INPUT_IMG_HEIGHT));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(INPUT_IMG_WIDTH));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(3));
tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());
int[] inputImg = new int[INPUT_IMG_HEIGHT * INPUT_IMG_WIDTH];
inputImgBitmap.getPixels(inputImg, 0, INPUT_IMG_WIDTH, 0, 0, INPUT_IMG_WIDTH, INPUT_IMG_HEIGHT);
int pixel;
for (int i = 0; i < INPUT_IMG_HEIGHT; i++) {
for (int j = 0; j < INPUT_IMG_WIDTH; j++) {
// Extract RBG values from each pixel; alpha is ignored.
pixel = inputImg[i * INPUT_IMG_WIDTH + j];
tensorProtoBuilder.addIntVal((pixel >> 16) & 0xff);
tensorProtoBuilder.addIntVal((pixel >> 8) & 0xff);
tensorProtoBuilder.addIntVal((pixel) & 0xff);
}
}
TensorProto tensorProto = tensorProtoBuilder.build();
builder.putInputs("input_tensor", tensorProto);
builder.addOutputFilter("num_detections");
builder.addOutputFilter("detection_boxes");
builder.addOutputFilter("detection_classes");
builder.addOutputFilter("detection_scores");
return builder.build();
שליחת בקשת gRPC לשירות TensorFlow
עכשיו אפשר לסיים את ההאזנה של onClick(View view)
.
predictButton.setOnClickListener(
new View.OnClickListener() {
@Override
public void onClick(View view) {
if (requestRadioGroup.getCheckedRadioButtonId() == R.id.rest) {
}
else {
// TODO: gRPC request
}
}
}
)
- צריך להוסיף את הקוד לסניף של gRPC:
try {
Predict.PredictRequest request = createGRPCRequest();
Predict.PredictResponse response = stub.predict(request);
postprocessGRPCResponse(response);
} catch (Exception e) {
Log.e(TAG, e.getMessage());
responseTextView.setText(e.getMessage());
return;
}
מעבד את תגובת ה-gRPC מ-TensorFlow serving
בדומה ל-gRPC, עליך ליישם את הפונקציה postprocessGRPCResponse()
כדי לטפל בתגובה.
private void postprocessGRPCResponse(Predict.PredictResponse response) {
}
- צריך להוסיף את הקוד הזה לפונקציה
postprocessGRPCResponse()
:
// Process the response.
float numDetections = response.getOutputsMap().get("num_detections").getFloatValList().get(0);
List<Float> detectionScores = response.getOutputsMap().get("detection_scores").getFloatValList();
int maxIndex = 0;
for (int j = 0; j < numDetections; j++) {
maxIndex = detectionScores.get(j) > detectionScores.get(maxIndex + 1) ? j : maxIndex;
}
Float detectionClass = response.getOutputsMap().get("detection_classes").getFloatValList().get(maxIndex);
List<Float> boundingBoxValues = response.getOutputsMap().get("detection_boxes").getFloatValList();
float ymin = boundingBoxValues.get(maxIndex * 4);
float xmin = boundingBoxValues.get(maxIndex * 4 + 1);
float ymax = boundingBoxValues.get(maxIndex * 4 + 2);
float xmax = boundingBoxValues.get(maxIndex * 4 + 3);
displayResult(detectionClass.intValue(), ymin, xmin, ymax, xmax);
עכשיו פונקציית העיבוד החוזר יכולה לחלץ את הערכים החזויים מהתגובה ולעבד את תיבת התוחם של זיהוי המשתמש בממשק המשתמש.
הפעלה
- לוחצים על
הפעלת 'app' בתפריט הניווט ומחכים שהאפליקציה תיטען.
- יש לבחור באפשרות gRPC > Run aמסיקה.
יחלפו כמה שניות עד שהאפליקציה תעבד את התיבה התוחמת של החתול, ואז תוצג 17
בתור הקטגוריה של האובייקט, שממופה אל הקטגוריה cat
במערך הנתונים של COCO.