使用机器学习套件的自定义模型

默认情况下,机器学习套件的 API 会使用 Google 训练的机器学习模型。这些模型旨在支持各种应用。但是,某些用例需要更具针对性的模型。因此,有些机器学习套件 API 现在允许您用自定义 TensorFlow Lite 模型替换默认模型。

图片标签对象检测和跟踪 API 均支持自定义图片分类模型。它们与 TensorFlow Hub 上精选的大量预训练模型兼容,或者可与您自己的使用 TensorFlow、AutoML Vision Edge 或 TensorFlow Lite Model Maker 训练的自定义模型兼容。

如果您需要针对其他领域或用例的自定义解决方案,请访问设备端机器学习页面,获取 Google 设备端机器学习的所有解决方案和工具指南。

将机器学习套件与自定义模型结合使用的优势

将自定义图片分类模型与机器学习套件搭配使用的优势如下:

  • 易于使用的高级别 API - 无需处理低级别的模型输入/输出,无需处理图片预处理/后处理或构建处理流水线。
  • 您无需担心标签映射,机器学习套件会从 TFLite 模型元数据中提取标签,并为您执行映射。
  • 支持各种来源的自定义模型,包括在 TensorFlow Hub 上发布的预训练模型,以及使用 TensorFlow、AutoML Vision Edge 或 TensorFlow Lite Model Maker 训练的新模型。
  • 支持通过 Firebase 托管的模型。通过按需下载模型来缩减 APK 大小。无需重新发布应用即可推送模型更新,并使用 Firebase Remote Config 执行简单的 A/B 测试。
  • 针对与 Android 的 Camera API 的集成进行了优化。

尤其是对象检测和跟踪功能:

  • 通过先定位对象并仅在相关图片区域运行分类器来提高分类准确性
  • 在检测到对象和对对象进行分类时,立即向用户提供有关对象的反馈,从而提供实时互动体验

使用预训练的图片分类模型

您可以使用预训练的 TensorFlow Lite 模型,前提是这些模型满足一组条件。我们通过 TensorFlow Hub 提供了一组经过审核的模型,这些模型由 Google 或其他模型创建者提供,他们都符合条件。

使用 TensorFlow Hub 上发布的模型

TensorFlow Hub 提供来自各种模型创建者的大量预训练图片分类模型,这些模型可与图片标签以及对象检测和跟踪 API 搭配使用。请按以下步骤操作。

  1. 与机器学习套件兼容的模型集合中选择模型。
  2. 从模型详情页面下载 .tflite 模型文件。请尽可能选择包含元数据的模型格式。
  3. 如需了解如何将模型文件与您的项目捆绑并在 Android 或 iOS 应用中使用,请参阅 Image Labeling APIObject Detection and Tracking API 的相关指南。

训练您自己的图片分类模型

如果预训练的图片分类模型无法满足您的需求,您可以通过多种方式训练您自己的 TensorFlow Lite 模型,其中一些方法已在下文中概述和讨论。

训练您自己的图片分类模型的选项
AutoML Vision Edge
  • 通过 Google Cloud AI 提供
  • 创建先进的图片分类模型
  • 轻松地评估性能和尺寸
TensorFlow Lite Model Maker
  • 与从头开始训练模型相比,重新训练模型(迁移学习)所需的时间更短、需要的数据更少
将 TensorFlow 模型转换为 TensorFlow Lite
  • 使用 TensorFlow 训练模型,然后将其转换为 TensorFlow Lite

AutoML Vision Edge

图片标签对象检测和跟踪 API API 中的自定义模型支持使用 AutoML Vision Edge 训练的图片分类模型。这些 API 还支持下载使用 Firebase 模型部署托管的模型。

如需详细了解如何在 Android 和 iOS 应用中使用通过 AutoML Vision Edge 训练的模型,请根据您的用例根据每个 API 的自定义模型指南执行操作。

TensorFlow Lite 模型制作工具

在为设备端机器学习应用部署此模型时,TFLite Model Maker 库简化了调整 TensorFlow 神经网络模型并将其转换为特定输入数据的过程。您可以参考使用 TensorFlow Lite Model Maker 对 Colab 进行图片分类

如需详细了解如何在 Android 和 iOS 应用中使用 Model Maker 训练的模型,请根据您的用例,参阅我们的 Image Labeling APIObject Detection and Tracking API 指南。

使用 TensorFlow Lite 转换器创建的模型

如果您已有 TensorFlow 图片分类模型,可以使用 TensorFlow Lite 转换器转换该模型。请确保创建的模型符合以下兼容性要求。

如需详细了解如何在 Android 和 iOS 应用中使用 TensorFlow Lite 模型,请根据您的用例,参阅我们的 Image Labeling APIObject Detection and Tracking API 指南。

TensorFlow Lite 模型兼容性

您可以使用任何预训练的 TensorFlow Lite 图片分类模型,前提是模型满足以下要求:

张量

  • 模型只能有一个具有以下约束的输入张量:
    • 数据采用 RGB 像素格式。
    • 数据为 UINT8 或 FLOAT32 类型。如果输入张量类型为 FLOAT32,则必须通过附加元数据指定 normalizationOptions。
    • 该张量有 4 个维度:BxHxWxC,其中:
      • B 是批次大小。它必须是 1(不支持较大批次的推断)。
      • W 和 H 是输入宽度和高度。
      • C 是预期的渠道数量。该值必须为 3。
  • 模型必须具有至少一个包含 N 类以及 2 个或 4 个维度的输出张量:
    • (1xN)
    • (1x1x1xN)

元数据

向 TensorFlow Lite 模型添加元数据中所述,您可以向 TensorFlow Lite 文件添加元数据。

如需使用具有 FLOAT32 输入张量的模型,您必须在元数据中指定 normalizationOptions

此外,我们还建议您将此元数据附加到输出张量 TensorMetadata