Label images with a custom model on Android

You can use ML Kit to recognize entities in an image and label them. This API supports a wide range of custom image classification models. Please refer to Custom models with ML Kit for guidance on model compatibility requirements, where to find pre-trained models, and how to train your own models.

There are two ways to integrate a custom model. You can bundle the model by putting it inside your app’s asset folder, or you can dynamically download it from Firebase. The following table compares the two options.

Bundled Model | Hosted Model
The model is part of your app's APK, which increases its size. | The model is not part of your APK. It is hosted by uploading it to Firebase Machine Learning.
The model is available immediately, even when the Android device is offline. | The model is downloaded on demand.
No need for a Firebase project. | Requires a Firebase project.
You must republish your app to update the model. | Push model updates without republishing your app.
No built-in A/B testing. | Easy A/B testing with Firebase Remote Config.

See the ML Kit Vision quickstart sample or the ML Kit AutoML quickstart sample on GitHub for examples of this API in use.

Before you begin

  1. In your project-level build.gradle file, make sure to include Google's Maven repository in both your buildscript and allprojects sections.

  2. Add the dependencies for the ML Kit Android libraries to your module's app-level Gradle file, which is usually app/build.gradle:

    For bundling a model with your app:

    dependencies {
      // ...
      // Image labeling feature with custom bundled model
      implementation 'com.google.mlkit:image-labeling-custom:16.3.1'
    }
    

    For dynamically downloading a model from Firebase, add the linkFirebase dependency:

    dependencies {
      // ...
      // Image labeling feature with model downloaded
      // from Firebase
      implementation 'com.google.mlkit:image-labeling-custom:16.3.1'
      implementation 'com.google.mlkit:linkfirebase:16.1.1'
    }
    
  3. If you want to download a model, make sure you add Firebase to your Android project, if you have not already done so. This is not required when you bundle the model.

1. Load the model

Configure a local model source

To bundle the model with your app:

  1. Copy the model file (usually ending in .tflite or .lite) to your app's assets/ folder. (You might need to create the folder first by right-clicking the app/ folder, then clicking New > Folder > Assets Folder.)

  2. Then, add the following to your app's build.gradle file to ensure Gradle doesn’t compress the model file when building the app:

    android {
        // ...
        aaptOptions {
            noCompress "tflite"
            // or noCompress "lite"
        }
    }
    

    The model file will be included in the app package and available to ML Kit as a raw asset.

  3. Create a LocalModel object, specifying the path to the model file:

    Kotlin

    val localModel = LocalModel.Builder()
            .setAssetFilePath("model.tflite")
            // or .setAbsoluteFilePath(absolute file path to model file)
            // or .setUri(URI to model file)
            .build()

    Java

    LocalModel localModel =
        new LocalModel.Builder()
            .setAssetFilePath("model.tflite")
            // or .setAbsoluteFilePath(absolute file path to model file)
            // or .setUri(URI to model file)
            .build();

Configure a Firebase-hosted model source

To use the remotely-hosted model, create a CustomRemoteModel object from a FirebaseModelSource, specifying the name you assigned the model when you published it:

Kotlin

// Specify the name you assigned in the Firebase console.
val remoteModel =
    CustomRemoteModel
        .Builder(FirebaseModelSource.Builder("your_model_name").build())
        .build()

Java

// Specify the name you assigned in the Firebase console.
CustomRemoteModel remoteModel =
    new CustomRemoteModel
        .Builder(new FirebaseModelSource.Builder("your_model_name").build())
        .build();

Then, start the model download task, specifying the conditions under which you want to allow downloading. If the model isn't on the device, or if a newer version of the model is available, the task will asynchronously download the model from Firebase:

Kotlin

val downloadConditions = DownloadConditions.Builder()
    .requireWifi()
    .build()
RemoteModelManager.getInstance().download(remoteModel, downloadConditions)
    .addOnSuccessListener {
        // Success.
    }

Java

DownloadConditions downloadConditions = new DownloadConditions.Builder()
        .requireWifi()
        .build();
RemoteModelManager.getInstance().download(remoteModel, downloadConditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void unused) {
                // Success.
            }
        });

Many apps start the download task in their initialization code, but you can do so at any point before you need to use the model.

Configure the image labeler

After you configure your model sources, create an ImageLabeler object from one of them.

The following options are available:

Options
confidenceThreshold

Minimum confidence score of detected labels. If not set, any classifier threshold specified by the model’s metadata will be used. If the model does not contain any metadata or the metadata does not specify a classifier threshold, a default threshold of 0.0 will be used.

maxResultCount

Maximum number of labels to return. If not set, the default value of 10 will be used.
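
The fallback order described for these two options can be sketched in plain Java. This is only an illustration of the documented precedence, not ML Kit's implementation; the class OptionDefaults and its method names are hypothetical, and null stands for "not set".

```java
/**
 * Illustrative only: mirrors the documented fallback order for labeler options.
 * OptionDefaults and its methods are hypothetical names, not part of ML Kit.
 */
final class OptionDefaults {
    static final float DEFAULT_CONFIDENCE_THRESHOLD = 0.0f;
    static final int DEFAULT_MAX_RESULT_COUNT = 10;

    /** Explicit option first, then the model metadata value, then 0.0. */
    static float effectiveThreshold(Float explicit, Float fromMetadata) {
        if (explicit != null) {
            return explicit;
        }
        if (fromMetadata != null) {
            return fromMetadata;
        }
        return DEFAULT_CONFIDENCE_THRESHOLD;
    }

    /** Explicit option if set, otherwise the documented default of 10. */
    static int effectiveMaxResultCount(Integer explicit) {
        return explicit != null ? explicit : DEFAULT_MAX_RESULT_COUNT;
    }

    public static void main(String[] args) {
        // Option set by the app wins over the metadata value.
        System.out.println(effectiveThreshold(0.5f, 0.3f));
        // Nothing set anywhere: fall back to the defaults.
        System.out.println(effectiveThreshold(null, null));
        System.out.println(effectiveMaxResultCount(null));
    }
}
```

In practice this means a call to setConfidenceThreshold() overrides whatever threshold the model file declares.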

If you only have a locally-bundled model, just create a labeler from your LocalModel object:

Kotlin

val customImageLabelerOptions = CustomImageLabelerOptions.Builder(localModel)
    .setConfidenceThreshold(0.5f)
    .setMaxResultCount(5)
    .build()
val labeler = ImageLabeling.getClient(customImageLabelerOptions)

Java

CustomImageLabelerOptions customImageLabelerOptions =
        new CustomImageLabelerOptions.Builder(localModel)
            .setConfidenceThreshold(0.5f)
            .setMaxResultCount(5)
            .build();
ImageLabeler labeler = ImageLabeling.getClient(customImageLabelerOptions);

If you have a remotely-hosted model, you will have to check that it has been downloaded before you run it. You can check the status of the model download task using the model manager's isModelDownloaded() method.

Although you only have to confirm this before running the labeler, if you have both a remotely-hosted model and a locally-bundled model, it might make sense to perform this check when instantiating the image labeler: create a labeler from the remote model if it's been downloaded, and from the local model otherwise.

Kotlin

RemoteModelManager.getInstance().isModelDownloaded(remoteModel)
    .addOnSuccessListener { isDownloaded ->
        val optionsBuilder =
            if (isDownloaded) {
                CustomImageLabelerOptions.Builder(remoteModel)
            } else {
                CustomImageLabelerOptions.Builder(localModel)
            }
        val options = optionsBuilder
            .setConfidenceThreshold(0.5f)
            .setMaxResultCount(5)
            .build()
        val labeler = ImageLabeling.getClient(options)
    }

Java

RemoteModelManager.getInstance().isModelDownloaded(remoteModel)
        .addOnSuccessListener(new OnSuccessListener<Boolean>() {
            @Override
            public void onSuccess(Boolean isDownloaded) {
                CustomImageLabelerOptions.Builder optionsBuilder;
                if (isDownloaded) {
                    optionsBuilder = new CustomImageLabelerOptions.Builder(remoteModel);
                } else {
                    optionsBuilder = new CustomImageLabelerOptions.Builder(localModel);
                }
                CustomImageLabelerOptions options = optionsBuilder
                    .setConfidenceThreshold(0.5f)
                    .setMaxResultCount(5)
                    .build();
                ImageLabeler labeler = ImageLabeling.getClient(options);
            }
        });

If you only have a remotely-hosted model, you should disable model-related functionality (for example, gray out or hide part of your UI) until you confirm the model has been downloaded. You can do so by attaching a listener to the model manager's download() method:

Kotlin

RemoteModelManager.getInstance().download(remoteModel, downloadConditions)
    .addOnSuccessListener {
        // Download complete. Depending on your app, you could enable the ML
        // feature, or switch from the local model to the remote model, etc.
    }

Java

RemoteModelManager.getInstance().download(remoteModel, downloadConditions)
        .addOnSuccessListener(new OnSuccessListener<Void>() {
            @Override
            public void onSuccess(Void v) {
              // Download complete. Depending on your app, you could enable
              // the ML feature, or switch from the local model to the remote
              // model, etc.
            }
        });

2. Prepare the input image

Then, for each image you want to label, create an InputImage object from your image. The image labeler runs fastest when you use a Bitmap or, if you use the camera2 API, a YUV_420_888 media.Image, which are recommended when possible.

You can create an InputImage from different sources, each of which is explained below.

Using a media.Image

To create an InputImage object from a media.Image object, such as when you capture an image from a device's camera, pass the media.Image object and the image's rotation to InputImage.fromMediaImage().

If you use the CameraX library, the OnImageCapturedListener and ImageAnalysis.Analyzer classes calculate the rotation value for you.

Kotlin

private class YourImageAnalyzer : ImageAnalysis.Analyzer {

    override fun analyze(imageProxy: ImageProxy) {
        val mediaImage = imageProxy.image
        if (mediaImage != null) {
            val image = InputImage.fromMediaImage(mediaImage, imageProxy.imageInfo.rotationDegrees)
            // Pass image to an ML Kit Vision API
            // ...
        }
    }
}

Java

private class YourAnalyzer implements ImageAnalysis.Analyzer {

    @Override
    public void analyze(ImageProxy imageProxy) {
        Image mediaImage = imageProxy.getImage();
        if (mediaImage != null) {
          InputImage image =
                InputImage.fromMediaImage(mediaImage, imageProxy.getImageInfo().getRotationDegrees());
          // Pass image to an ML Kit Vision API
          // ...
        }
    }
}

If you don't use a camera library that gives you the image's rotation degree, you can calculate it from the device's rotation degree and the orientation of the camera sensor in the device:

Kotlin

private val ORIENTATIONS = SparseIntArray()

init {
    ORIENTATIONS.append(Surface.ROTATION_0, 0)
    ORIENTATIONS.append(Surface.ROTATION_90, 90)
    ORIENTATIONS.append(Surface.ROTATION_180, 180)
    ORIENTATIONS.append(Surface.ROTATION_270, 270)
}

/**
 * Get the angle by which an image must be rotated given the device's current
 * orientation.
 */
@RequiresApi(api = Build.VERSION_CODES.LOLLIPOP)
@Throws(CameraAccessException::class)
private fun getRotationCompensation(cameraId: String, activity: Activity, isFrontFacing: Boolean): Int {
    // Get the device's current rotation relative to its "native" orientation.
    // Then, from the ORIENTATIONS table, look up the angle the image must be
    // rotated to compensate for the device's rotation.
    val deviceRotation = activity.windowManager.defaultDisplay.rotation
    var rotationCompensation = ORIENTATIONS.get(deviceRotation)

    // Get the device's sensor orientation.
    val cameraManager = activity.getSystemService(CAMERA_SERVICE) as CameraManager
    val sensorOrientation = cameraManager
            .getCameraCharacteristics(cameraId)
            .get(CameraCharacteristics.SENSOR_ORIENTATION)!!

    if (isFrontFacing) {
        rotationCompensation = (sensorOrientation + rotationCompensation) % 360
    } else { // back-facing
        rotationCompensation = (sensorOrientation - rotationCompensation + 360) % 360
    }
    return rotationCompensation
}

Java

private static final SparseIntArray ORIENTATIONS = new SparseIntArray();
static {
    ORIENTATIONS.append(Surface.ROTATION_0, 0);
    ORIENTATIONS.append(Surface.ROTATION_90, 90);
    ORIENTATIONS.append(Surface.ROTATION_180, 180);
    ORIENTATIONS.append(Surface.ROTATION_270, 270);
}

/**
 * Get the angle by which an image must be rotated given the device's current
 * orientation.
 */
@RequiresApi(api = Build.VERSION_CODES.LOLLIPOP)
private int getRotationCompensation(String cameraId, Activity activity, boolean isFrontFacing)
        throws CameraAccessException {
    // Get the device's current rotation relative to its "native" orientation.
    // Then, from the ORIENTATIONS table, look up the angle the image must be
    // rotated to compensate for the device's rotation.
    int deviceRotation = activity.getWindowManager().getDefaultDisplay().getRotation();
    int rotationCompensation = ORIENTATIONS.get(deviceRotation);

    // Get the device's sensor orientation.
    CameraManager cameraManager = (CameraManager) activity.getSystemService(CAMERA_SERVICE);
    int sensorOrientation = cameraManager
            .getCameraCharacteristics(cameraId)
            .get(CameraCharacteristics.SENSOR_ORIENTATION);

    if (isFrontFacing) {
        rotationCompensation = (sensorOrientation + rotationCompensation) % 360;
    } else { // back-facing
        rotationCompensation = (sensorOrientation - rotationCompensation + 360) % 360;
    }
    return rotationCompensation;
}

Then, pass the media.Image object and the rotation degree value to InputImage.fromMediaImage():

Kotlin

val image = InputImage.fromMediaImage(mediaImage, rotation)

Java

InputImage image = InputImage.fromMediaImage(mediaImage, rotation);
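
The arithmetic inside getRotationCompensation() can be checked in isolation, without a device. The sketch below repeats only the modular arithmetic from the code above; the class name RotationMath is hypothetical, and the sensor orientations used in main (90 and 270 degrees) are example values assumed for illustration, not values guaranteed for any particular device.

```java
/**
 * Framework-free version of the rotation arithmetic shown above.
 * RotationMath is a hypothetical helper name used only for this sketch.
 */
final class RotationMath {
    /**
     * @param sensorOrientation     camera sensor orientation in degrees
     * @param deviceRotationDegrees display rotation in degrees (0, 90, 180, or 270)
     * @param isFrontFacing         whether the camera faces the user
     * @return degrees in the range 0..359 to pass as the image rotation
     */
    static int rotationCompensation(
            int sensorOrientation, int deviceRotationDegrees, boolean isFrontFacing) {
        if (isFrontFacing) {
            return (sensorOrientation + deviceRotationDegrees) % 360;
        }
        // Back-facing: subtract, adding 360 first so the result is never negative.
        return (sensorOrientation - deviceRotationDegrees + 360) % 360;
    }

    public static void main(String[] args) {
        // Back-facing camera, assumed sensor orientation 90, device upright.
        System.out.println(rotationCompensation(90, 0, false));   // 90
        // Same camera, device rotated by 270 degrees.
        System.out.println(rotationCompensation(90, 270, false)); // 180
        // Front-facing camera, assumed sensor orientation 270, device rotated by 90.
        System.out.println(rotationCompensation(270, 90, true));  // 0
    }
}
```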

Using a file URI

To create an InputImage object from a file URI, pass the app context and file URI to InputImage.fromFilePath(). This is useful when you use an ACTION_GET_CONTENT intent to prompt the user to select an image from their gallery app.

Kotlin

val image: InputImage
try {
    image = InputImage.fromFilePath(context, uri)
} catch (e: IOException) {
    e.printStackTrace()
}

Java

InputImage image;
try {
    image = InputImage.fromFilePath(context, uri);
} catch (IOException e) {
    e.printStackTrace();
}

Using a ByteBuffer or ByteArray

To create an InputImage object from a ByteBuffer or a ByteArray, first calculate the image rotation degree as previously described for media.Image input. Then, create the InputImage object with the buffer or array, together with the image's height, width, color encoding format, and rotation degree:

Kotlin

val image = InputImage.fromByteBuffer(
        byteBuffer,
        /* image width */ 480,
        /* image height */ 360,
        rotationDegrees,
        InputImage.IMAGE_FORMAT_NV21 // or IMAGE_FORMAT_YV12
)
// Or:
val image = InputImage.fromByteArray(
        byteArray,
        /* image width */ 480,
        /* image height */ 360,
        rotationDegrees,
        InputImage.IMAGE_FORMAT_NV21 // or IMAGE_FORMAT_YV12
)

Java

InputImage image = InputImage.fromByteBuffer(byteBuffer,
        /* image width */ 480,
        /* image height */ 360,
        rotationDegrees,
        InputImage.IMAGE_FORMAT_NV21 // or IMAGE_FORMAT_YV12
);
// Or:
InputImage image = InputImage.fromByteArray(
        byteArray,
        /* image width */ 480,
        /* image height */ 360,
        rotationDegrees,
        InputImage.IMAGE_FORMAT_NV21 // or IMAGE_FORMAT_YV12
);
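
Because the width and height are passed separately from the pixel data, it can help to confirm that the array length matches them before building the InputImage. The following is a minimal sketch for tightly packed NV21 data only, assuming even dimensions and no row padding (12 bits per pixel: a full-resolution Y plane followed by interleaved V/U samples at quarter resolution). Nv21Size and its methods are hypothetical helpers, not part of ML Kit, and the sketch does not cover YV12.

```java
/**
 * Hypothetical helper that computes the byte length of a tightly packed
 * NV21 frame. Assumes even width and height and no row padding.
 */
final class Nv21Size {
    static int expectedNv21Bytes(int width, int height) {
        if (width <= 0 || height <= 0 || width % 2 != 0 || height % 2 != 0) {
            throw new IllegalArgumentException("width and height must be positive and even");
        }
        int lumaBytes = width * height;      // one Y byte per pixel
        int chromaBytes = width * height / 2; // one V and one U byte per 2x2 block
        return lumaBytes + chromaBytes;
    }

    /** Returns true when the array length matches the stated dimensions. */
    static boolean hasExpectedSize(byte[] data, int width, int height) {
        return data != null && data.length == expectedNv21Bytes(width, height);
    }

    public static void main(String[] args) {
        // The 480 x 360 frame used in the examples above.
        System.out.println(expectedNv21Bytes(480, 360)); // 259200
    }
}
```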

Using a Bitmap

To create an InputImage object from a Bitmap object, make the following declaration:

Kotlin

val image = InputImage.fromBitmap(bitmap, 0)

Java

InputImage image = InputImage.fromBitmap(bitmap, rotationDegree);

The image is represented by a Bitmap object together with rotation degrees.

3. Run the image labeler

To label objects in an image, pass the image object to the ImageLabeler's process() method.

Kotlin

labeler.process(image)
        .addOnSuccessListener { labels ->
            // Task completed successfully
            // ...
        }
        .addOnFailureListener { e ->
            // Task failed with an exception
            // ...
        }

Java

labeler.process(image)
        .addOnSuccessListener(new OnSuccessListener<List<ImageLabel>>() {
            @Override
            public void onSuccess(List<ImageLabel> labels) {
                // Task completed successfully
                // ...
            }
        })
        .addOnFailureListener(new OnFailureListener() {
            @Override
            public void onFailure(@NonNull Exception e) {
                // Task failed with an exception
                // ...
            }
        });

4. Get information about labeled entities

If the image labeling operation succeeds, a list of ImageLabel objects is passed to the success listener. Each ImageLabel object represents something that was labeled in the image. You can get each label's text description (if available in the metadata of the TensorFlow Lite model file), confidence score, and index. For example:

Kotlin

for (label in labels) {
    val text = label.text
    val confidence = label.confidence
    val index = label.index
}

Java

for (ImageLabel label : labels) {
    String text = label.getText();
    float confidence = label.getConfidence();
    int index = label.getIndex();
}
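
A common follow-up is to keep only the most confident label. The sketch below shows one way to do that in plain Java. LabelResult is a hypothetical stand-in that carries the same three values as ImageLabel (text, confidence, index) so the example runs without Android; in an app you would iterate over the ImageLabel list directly. The sample labels are made-up values.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Optional;

/** Hypothetical helper for picking the highest-confidence label. */
final class LabelPicker {
    /** Stand-in for ImageLabel so the sketch runs without Android. */
    static final class LabelResult {
        final String text;
        final float confidence;
        final int index;

        LabelResult(String text, float confidence, int index) {
            this.text = text;
            this.confidence = confidence;
            this.index = index;
        }
    }

    /** Returns the label with the highest confidence, or empty if there are none. */
    static Optional<LabelResult> mostConfident(List<LabelResult> labels) {
        LabelResult best = null;
        for (LabelResult label : labels) {
            if (best == null || label.confidence > best.confidence) {
                best = label;
            }
        }
        return Optional.ofNullable(best);
    }

    public static void main(String[] args) {
        List<LabelResult> labels = Arrays.asList(
                new LabelResult("cat", 0.62f, 3),
                new LabelResult("dog", 0.91f, 7),
                new LabelResult("fox", 0.55f, 12));
        mostConfident(labels).ifPresent(l -> System.out.println(l.text)); // dog
    }
}
```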

Tips to improve real-time performance

If you want to label images in a real-time application, follow these guidelines to achieve the best frame rates:

  • If you use the Camera or camera2 API, throttle calls to the image labeler. If a new video frame becomes available while the image labeler is running, drop the frame. See the VisionProcessorBase class in the quickstart sample app for an example.
  • If you use the CameraX API, be sure that the backpressure strategy is set to its default value, ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST. This guarantees only one image will be delivered for analysis at a time. If more images are produced when the analyzer is busy, they will be dropped automatically and not queued for delivery. Once the image being analyzed is closed by calling ImageProxy.close(), the next latest image will be delivered.
  • If you use the output of the image labeler to overlay graphics on the input image, first get the result from ML Kit, then render the image and overlay in a single step. This renders to the display surface only once for each input frame. See the CameraSourcePreview and GraphicOverlay classes in the quickstart sample app for an example.
  • If you use the camera2 API, capture images in ImageFormat.YUV_420_888 format. If you use the older Camera API, capture images in ImageFormat.NV21 format.
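
The frame-dropping advice in the first tip can be sketched with a single atomic flag. This is a minimal illustration, not the approach used by the quickstart's VisionProcessorBase; FrameGate and its methods are hypothetical names. The idea is to call tryStart() when a frame arrives, skip the frame if it returns false, and call finish() once the labeler's task completes (for example, from a listener added with addOnCompleteListener()).

```java
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * Hypothetical gate that lets one frame through at a time and tells the
 * caller to drop any frame that arrives while the labeler is still busy.
 */
final class FrameGate {
    private final AtomicBoolean busy = new AtomicBoolean(false);

    /** Returns true if the caller may process this frame; false means drop it. */
    boolean tryStart() {
        return busy.compareAndSet(false, true);
    }

    /** Call when processing ends, whether it succeeded or failed. */
    void finish() {
        busy.set(false);
    }

    public static void main(String[] args) {
        FrameGate gate = new FrameGate();
        System.out.println(gate.tryStart()); // true: first frame is processed
        System.out.println(gate.tryStart()); // false: second frame is dropped
        gate.finish();                       // labeler finished the first frame
        System.out.println(gate.tryStart()); // true: next frame is processed
    }
}
```

Releasing the gate in a completion callback rather than only on success keeps the pipeline from stalling after a failed frame.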

Next steps