Creare un'app Android per rilevare oggetti all'interno delle immagini
Informazioni su questo codelab
1. Prima di iniziare
In questo codelab, imparerai a eseguire un'inferenza con rilevamento degli oggetti da un'app Android utilizzando TensorFlow Serving con REST e gRPC.
Prerequisiti
- Conoscenza di base dello sviluppo di Android con Java
- Conoscenza di base del machine learning con TensorFlow, ad esempio addestramento e deployment
- Conoscenza di base dei terminali e Docker
Obiettivi didattici
- Come trovare modelli di rilevamento di oggetti preaddestrati su TensorFlow Hub.
- Come creare un'app Android semplice e fare previsioni con il modello di rilevamento di oggetti scaricato tramite TensorFlow Serving (REST e gRPC).
- Come visualizzare il risultato del rilevamento nell'interfaccia utente.
Che cosa ti serve
- L'ultima versione di Android Studio
- Docker
- Palla
2. Configura
Per scaricare il codice per questo codelab:
- Accedi al repository GitHub per questo codelab.
- Fai clic su Codice > Scarica zip per scaricare tutto il codice di questo codelab.
- Decomprimi il file ZIP scaricato per decomprimere una cartella principale
codelabs
con tutte le risorse necessarie.
Per questo codelab, hai bisogno solo dei file nella sottodirectory TFServing/ObjectDetectionAndroid
del repository, che contiene due cartelle:
- La cartella
starter
contiene il codice di avvio che crei per questo codelab. - La cartella
finished
contiene il codice completato per l'app di esempio terminata.
3. Aggiungi le dipendenze al progetto
Importare l'app iniziale in Android Studio
- In Android Studio, fai clic su File > New > Import project (File > importa progetto) e scegli la cartella
starter
dal codice sorgente che hai scaricato in precedenza.
Aggiungi le dipendenze per OkHttp e gRPC
- Nel file
app/build.gradle
del progetto, conferma la presenza delle dipendenze.
dependencies {
// ...
implementation 'com.squareup.okhttp3:okhttp:4.9.0'
implementation 'javax.annotation:javax.annotation-api:1.3.2'
implementation 'io.grpc:grpc-okhttp:1.29.0'
implementation 'io.grpc:grpc-protobuf-lite:1.29.0'
implementation 'io.grpc:grpc-stub:1.29.0'
}
Sincronizzare il progetto con i file Gradle
- Seleziona
Sync Project with Gradle Files (Sincronizza progetto con file Gradle) dal menu di navigazione.
4. Esegui l'app iniziale
- Avvia l'emulatore Android, quindi fai clic su
Esegui "app' nel menu di navigazione.
Esegui ed esplora l'app
L'app dovrebbe essere avviata sul tuo dispositivo Android. L'interfaccia utente è piuttosto semplice: è presente un'immagine gatto in cui vuoi rilevare gli oggetti e l'utente può scegliere il modo in cui inviare i dati al backend, con REST o gRPC. Il backend esegue il rilevamento degli oggetti sull'immagine e restituisce i risultati di rilevamento all'app client, che esegue nuovamente il rendering dell'interfaccia utente.
Al momento, se fai clic su Esegui inferenza, non accade nulla. Questo perché non è ancora in grado di comunicare con il backend.
5. Esegui il deployment di un modello di rilevamento degli oggetti con TensorFlow Serving
Il rilevamento di oggetti è un'attività di machine learning molto comune e il suo obiettivo è rilevare gli oggetti all'interno delle immagini, in particolare per prevedere possibili categorie di oggetti e riquadri di delimitazione degli oggetti. Ecco un esempio di risultato di rilevamento:
Google ha pubblicato una serie di modelli preaddestrati su TensorFlow Hub. Per vedere l'elenco completo, visita la pagina object_detection. Utilizzi il modello relativamente SSD MobileNet V2 FPNLite 320x320 per questo codelab, quindi non è necessario utilizzare una GPU per eseguirla.
Per eseguire il deployment del modello di rilevamento oggetti con TensorFlow Serving:
- Scarica il file del modello.
- Decomprimi il file
.tar.gz
scaricato con uno strumento di decompressione, ad esempio 7-Zip. - Crea una cartella
ssd_mobilenet_v2_2_320
, quindi crea una sottocartella123
al suo interno. - Inserisci la cartella
variables
e il filesaved_model.pb
estratti nella sottocartella123
.
Puoi impostare la cartella ssd_mobilenet_v2_2_320
come cartella SavedModel
. 123
è un numero di versione di esempio. Se vuoi, puoi scegliere un altro numero.
La struttura delle cartelle dovrebbe essere simile a questa:
Avvia la pubblicazione TensorFlow
- Nel terminale, avvia TensorFlow Serving con Docker, ma sostituisci il segnaposto
PATH/TO/SAVEDMODEL
con il percorso assoluto della cartellassd_mobilenet_v2_2_320
sul tuo computer.
docker pull tensorflow/serving docker run -it --rm -p 8500:8500 -p 8501:8501 -v "PATH/TO/SAVEDMODEL:/models/ssd_mobilenet_v2_2" -e MODEL_NAME=ssd_mobilenet_v2_2 tensorflow/serving
Docker scarica automaticamente prima l'immagine di TensorFlow Serving, che richiede un minuto. In seguito, TensorFlow Serving dovrebbe iniziare. Il log dovrebbe avere il seguente aspetto:
2022-02-25 06:01:12.513231: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:206] Restoring SavedModel bundle. 2022-02-25 06:01:12.585012: I external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:114] CPU Frequency: 3000000000 Hz 2022-02-25 06:01:13.395083: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:190] Running initialization op on SavedModel bundle at path: /models/ssd_mobilenet_v2_2/123 2022-02-25 06:01:13.837562: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:277] SavedModel load for tags { serve }; Status: success: OK. Took 1928700 microseconds. 2022-02-25 06:01:13.877848: I tensorflow_serving/servables/tensorflow/saved_model_warmup_util.cc:59] No warmup data file found at /models/ssd_mobilenet_v2_2/123/assets.extra/tf_serving_warmup_requests 2022-02-25 06:01:13.929844: I tensorflow_serving/core/loader_harness.cc:87] Successfully loaded servable version {name: ssd_mobilenet_v2_2 version: 123} 2022-02-25 06:01:13.985848: I tensorflow_serving/model_servers/server_core.cc:486] Finished adding/updating models 2022-02-25 06:01:13.985987: I tensorflow_serving/model_servers/server.cc:367] Profiler service is enabled 2022-02-25 06:01:13.988994: I tensorflow_serving/model_servers/server.cc:393] Running gRPC ModelServer at 0.0.0.0:8500 ... [warn] getaddrinfo: address family for nodename not supported 2022-02-25 06:01:14.033872: I tensorflow_serving/model_servers/server.cc:414] Exporting HTTP/REST API at:localhost:8501 ... [evhttp_server.cc : 245] NET_LOG: Entering the event loop ...
6. Collega l'app Android a TensorFlow Serving tramite REST
Il backend è pronto, quindi puoi inviare richieste client a TensorFlow Serving per rilevare gli oggetti all'interno delle immagini. Esistono due modi per inviare richieste a TensorFlow Serving:
- REST
- gRPC
Inviare richieste e ricevere risposte tramite REST
Ci sono tre semplici passaggi:
- Crea la richiesta REST.
- Invia la richiesta REST a TensorFlow Serving.
- Estrai il risultato previsto dalla risposta REST e visualizza l'interfaccia utente.
Li potrai raggiungere tra MainActivity.java.
Crea la richiesta REST
Attualmente c'è una funzione createRESTRequest()
vuota nel file MainActivity.java
. Questa funzione viene implementata per creare una richiesta REST.
private Request createRESTRequest() {
}
TensorFlow Serving prevede una richiesta POST che contiene il tensore di immagine per il modello SSD MobileNet utilizzato, quindi dovrai estrarre i valori RGB da ciascun pixel dell'immagine in un array e quindi aggregare l'array in un JSON, che è il payload della richiesta.
- Aggiungi questo codice alla funzione
createRESTRequest()
:
//Create the REST request.
int[] inputImg = new int[INPUT_IMG_HEIGHT * INPUT_IMG_WIDTH];
int[][][][] inputImgRGB = new int[1][INPUT_IMG_HEIGHT][INPUT_IMG_WIDTH][3];
inputImgBitmap.getPixels(inputImg, 0, INPUT_IMG_WIDTH, 0, 0, INPUT_IMG_WIDTH, INPUT_IMG_HEIGHT);
int pixel;
for (int i = 0; i < INPUT_IMG_HEIGHT; i++) {
for (int j = 0; j < INPUT_IMG_WIDTH; j++) {
// Extract RBG values from each pixel; alpha is ignored
pixel = inputImg[i * INPUT_IMG_WIDTH + j];
inputImgRGB[0][i][j][0] = ((pixel >> 16) & 0xff);
inputImgRGB[0][i][j][1] = ((pixel >> 8) & 0xff);
inputImgRGB[0][i][j][2] = ((pixel) & 0xff);
}
}
RequestBody requestBody =
RequestBody.create("{\"instances\": " + Arrays.deepToString(inputImgRGB) + "}", JSON);
Request request =
new Request.Builder()
.url("http://" + SERVER + ":" + REST_PORT + "/v1/models/" + MODEL_NAME + ":predict")
.post(requestBody)
.build();
return request;
Invia la richiesta REST a TensorFlow Serving
L'app consente all'utente di scegliere REST o gRPC per comunicare con TensorFlow Serving, quindi ci sono due rami nel listener onClick(View view)
.
predictButton.setOnClickListener(
new View.OnClickListener() {
@Override
public void onClick(View view) {
if (requestRadioGroup.getCheckedRadioButtonId() == R.id.rest) {
// TODO: REST request
}
else {
}
}
}
)
- Aggiungi questo codice al ramo REST del listener
onClick(View view)
per utilizzare OkHttp per inviare la richiesta a TensorFlow Serving:
// Send the REST request.
Request request = createRESTRequest();
try {
client =
new OkHttpClient.Builder()
.connectTimeout(20, TimeUnit.SECONDS)
.writeTimeout(20, TimeUnit.SECONDS)
.readTimeout(20, TimeUnit.SECONDS)
.callTimeout(20, TimeUnit.SECONDS)
.build();
Response response = client.newCall(request).execute();
JSONObject responseObject = new JSONObject(response.body().string());
postprocessRESTResponse(responseObject);
} catch (IOException | JSONException e) {
Log.e(TAG, e.getMessage());
responseTextView.setText(e.getMessage());
return;
}
Elabora la risposta REST da TensorFlow Serving
Il modello SSD MobileNet restituisce una serie di risultati, tra cui:
num_detections
: il numero di rilevazionidetection_scores
: punteggi di rilevamentodetection_classes
: indice della classe di rilevamentodetection_boxes
: le coordinate del riquadro di delimitazione
Implementi la funzione postprocessRESTResponse()
per gestire la risposta.
private void postprocessRESTResponse(Predict.PredictResponse response) {
}
- Aggiungi questo codice alla funzione
postprocessRESTResponse()
:
// Process the REST response.
JSONArray predictionsArray = responseObject.getJSONArray("predictions");
//You only send one image, so you directly extract the first element.
JSONObject predictions = predictionsArray.getJSONObject(0);
// Argmax
int maxIndex = 0;
JSONArray detectionScores = predictions.getJSONArray("detection_scores");
for (int j = 0; j < predictions.getInt("num_detections"); j++) {
maxIndex =
detectionScores.getDouble(j) > detectionScores.getDouble(maxIndex + 1) ? j : maxIndex;
}
int detectionClass = predictions.getJSONArray("detection_classes").getInt(maxIndex);
JSONArray boundingBox = predictions.getJSONArray("detection_boxes").getJSONArray(maxIndex);
double ymin = boundingBox.getDouble(0);
double xmin = boundingBox.getDouble(1);
double ymax = boundingBox.getDouble(2);
double xmax = boundingBox.getDouble(3);
displayResult(detectionClass, (float) ymin, (float) xmin, (float) ymax, (float) xmax);
Ora la funzione di post-elaborazione estrae i valori previsti dalla risposta, identifica la categoria più probabile dell'oggetto e le coordinate dei vertici del riquadro di delimitazione, oltre a visualizzare il riquadro di delimitazione del rilevamento nell'interfaccia utente.
Esegui
- Fai clic su
Esegui App' nel menu di navigazione, quindi attendi il caricamento dell'app.
- Seleziona REST > Run inference.
Bastano pochi secondi prima che l'app esegua il rendering del riquadro di delimitazione del gatto e mostri 17
come categoria dell'oggetto, che viene mappato all'oggetto cat
nel set di dati COCO.
7. Collega l'app Android a TensorFlow Serving tramite gRPC
Oltre a REST, TensorFlow Serving supporta anche gRPC.
gRPC è un framework RPC (Open Procedure Call) moderno e open source eseguibile in qualsiasi ambiente. Può connettere i servizi in modo efficiente all'interno dei diversi data center e fornire un supporto collegabile per bilanciamento del carico, tracciamento, controllo di integrità e autenticazione. È stato osservato che gRPC ha un rendimento migliore rispetto a REST nella pratica.
Inviare richieste e ricevere risposte con gRPC
Sono disponibili quattro semplici passaggi:
- (Facoltativo) Genera il codice stub client gRPC.
- Crea la richiesta gRPC.
- Invia la richiesta gRPC a TensorFlow Serving.
- Estrai il risultato previsto dalla risposta gRPC e visualizza l'interfaccia utente.
Li potrai raggiungere tra MainActivity.java.
Facoltativo: genera il codice stub client gRPC
Per utilizzare gRPC con TensorFlow Serving, devi seguire il flusso di lavoro gRPC. Per saperne di più, consulta la documentazione di gRPC.
TensorFlow Serving e TensorFlow definiscono i file .proto
per tuo conto. In data TensorFlow e TensorFlow Serving 2.8, questi file .proto
sono quelli necessari:
tensorflow/core/example/example.proto
tensorflow/core/example/feature.proto
tensorflow/core/protobuf/struct.proto
tensorflow/core/protobuf/saved_object_graph.proto
tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/trackable_object_graph.proto
tensorflow/core/protobuf/meta_graph.proto
tensorflow/core/framework/node_def.proto
tensorflow/core/framework/attr_value.proto
tensorflow/core/framework/function.proto
tensorflow/core/framework/types.proto
tensorflow/core/framework/tensor_shape.proto
tensorflow/core/framework/full_type.proto
tensorflow/core/framework/versions.proto
tensorflow/core/framework/op_def.proto
tensorflow/core/framework/graph.proto
tensorflow/core/framework/tensor.proto
tensorflow/core/framework/resource_handle.proto
tensorflow/core/framework/variable.proto
tensorflow_serving/apis/inference.proto
tensorflow_serving/apis/classification.proto
tensorflow_serving/apis/predict.proto
tensorflow_serving/apis/regression.proto
tensorflow_serving/apis/get_model_metadata.proto
tensorflow_serving/apis/input.proto
tensorflow_serving/apis/prediction_service.proto
tensorflow_serving/apis/model.proto
- Per generare lo stub, aggiungi questo codice al file
app/build.gradle
.
apply plugin: 'com.google.protobuf'
protobuf {
protoc { artifact = 'com.google.protobuf:protoc:3.11.0' }
plugins {
grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.29.0'
}
}
generateProtoTasks {
all().each { task ->
task.builtins {
java { option 'lite' }
}
task.plugins {
grpc { option 'lite' }
}
}
}
}
Crea la richiesta gRPC
Come per la richiesta REST, puoi creare la richiesta gRPC nella funzione createGRPCRequest()
.
private Request createGRPCRequest() {
}
- Aggiungi questo codice alla funzione
createGRPCRequest()
:
if (stub == null) {
channel = ManagedChannelBuilder.forAddress(SERVER, GRPC_PORT).usePlaintext().build();
stub = PredictionServiceGrpc.newBlockingStub(channel);
}
Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
modelSpecBuilder.setName(MODEL_NAME);
modelSpecBuilder.setVersion(Int64Value.of(MODEL_VERSION));
modelSpecBuilder.setSignatureName(SIGNATURE_NAME);
Predict.PredictRequest.Builder builder = Predict.PredictRequest.newBuilder();
builder.setModelSpec(modelSpecBuilder);
TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
tensorProtoBuilder.setDtype(DataType.DT_UINT8);
TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(INPUT_IMG_HEIGHT));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(INPUT_IMG_WIDTH));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(3));
tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());
int[] inputImg = new int[INPUT_IMG_HEIGHT * INPUT_IMG_WIDTH];
inputImgBitmap.getPixels(inputImg, 0, INPUT_IMG_WIDTH, 0, 0, INPUT_IMG_WIDTH, INPUT_IMG_HEIGHT);
int pixel;
for (int i = 0; i < INPUT_IMG_HEIGHT; i++) {
for (int j = 0; j < INPUT_IMG_WIDTH; j++) {
// Extract RBG values from each pixel; alpha is ignored.
pixel = inputImg[i * INPUT_IMG_WIDTH + j];
tensorProtoBuilder.addIntVal((pixel >> 16) & 0xff);
tensorProtoBuilder.addIntVal((pixel >> 8) & 0xff);
tensorProtoBuilder.addIntVal((pixel) & 0xff);
}
}
TensorProto tensorProto = tensorProtoBuilder.build();
builder.putInputs("input_tensor", tensorProto);
builder.addOutputFilter("num_detections");
builder.addOutputFilter("detection_boxes");
builder.addOutputFilter("detection_classes");
builder.addOutputFilter("detection_scores");
return builder.build();
Invia la richiesta gRPC a TensorFlow Serving
Ora puoi completare il listener onClick(View view)
.
predictButton.setOnClickListener(
new View.OnClickListener() {
@Override
public void onClick(View view) {
if (requestRadioGroup.getCheckedRadioButtonId() == R.id.rest) {
}
else {
// TODO: gRPC request
}
}
}
)
- Aggiungi questo codice al ramo gRPC:
try {
Predict.PredictRequest request = createGRPCRequest();
Predict.PredictResponse response = stub.predict(request);
postprocessGRPCResponse(response);
} catch (Exception e) {
Log.e(TAG, e.getMessage());
responseTextView.setText(e.getMessage());
return;
}
Elabora la risposta gRPC da TensorFlow Serving
Come per gRPC, implementi la funzione postprocessGRPCResponse()
per gestire la risposta.
private void postprocessGRPCResponse(Predict.PredictResponse response) {
}
- Aggiungi questo codice alla funzione
postprocessGRPCResponse()
:
// Process the response.
float numDetections = response.getOutputsMap().get("num_detections").getFloatValList().get(0);
List<Float> detectionScores = response.getOutputsMap().get("detection_scores").getFloatValList();
int maxIndex = 0;
for (int j = 0; j < numDetections; j++) {
maxIndex = detectionScores.get(j) > detectionScores.get(maxIndex + 1) ? j : maxIndex;
}
Float detectionClass = response.getOutputsMap().get("detection_classes").getFloatValList().get(maxIndex);
List<Float> boundingBoxValues = response.getOutputsMap().get("detection_boxes").getFloatValList();
float ymin = boundingBoxValues.get(maxIndex * 4);
float xmin = boundingBoxValues.get(maxIndex * 4 + 1);
float ymax = boundingBoxValues.get(maxIndex * 4 + 2);
float xmax = boundingBoxValues.get(maxIndex * 4 + 3);
displayResult(detectionClass.intValue(), ymin, xmin, ymax, xmax);
Ora la funzione di post-elaborazione può estrarre i valori previsti dalla risposta e visualizzare il riquadro di delimitazione del rilevamento nell'interfaccia utente.
Esegui
- Fai clic su
Esegui App' nel menu di navigazione, quindi attendi il caricamento dell'app.
- Seleziona gRPC > Run inference.
Bastano pochi secondi prima che l'app esegua il rendering del riquadro di delimitazione del gatto e mostri 17
come categoria dell'oggetto, che mappa alla categoria cat
nel set di dati COCO.
8. Complimenti
Hai utilizzato TensorFlow Serving per aggiungere funzionalità di rilevamento oggetti alla tua app.