使用 TensorFlow Lite Model Maker 创建自定义文本分类模型

1. 准备工作

在此 Codelab 中,您将学习如何更新基于原始 blog-spam-comments 数据集构建的文本分类模型,该模型已使用您自己的评论加以增强,以便拥有一个可处理您的数据的模型。

前提条件

此 Codelab 是《Flutter 应用文本分类入门》在线课程的一部分。此在线课程中的 Codelab 是按顺序编排的。您将使用的应用和模型应该在之前您依次按照 Codelab 学习时已构建完毕。如果您还没有完成之前的活动,请立即停下来先完成这些活动

学习内容

所需条件

  • 您在之前的活动中观察和构建的 Flutter 应用和垃圾内容过滤器模型。

2. 增强文本分类

  1. 您可以通过克隆此代码库并从 tfserving-flutter/codelab2/finished 文件夹加载该应用来获取此代码。
  2. 启动 TensorFlow Serving Docker 映像后,在您构建的应用中,输入 buy my book to learn online trading,然后点击 gRPC > Classify(分类)

8f1e1974522f274d.png

该应用会生成较低的垃圾内容得分,因为原始数据集中没有许多在线交易并且模型不知道它是垃圾内容。在此 Codelab 中,您将使用新数据更新模型,以便模型将相同的句子识别为垃圾内容!

2bd68691a26aa3da.png

3. 修改 CSV 文件

为了训练原始模型,系统创建了一个 CSV 格式 (lmblog_comments.csv) 的数据集,其中包含近一千条标记为垃圾内容或非垃圾内容的评论。如果您想检查该 CSV 文件,请在任意文本编辑器中打开它。

在该 CSV 文件中,第一行描述列(标记为 commenttextspam)。后续的每一行都遵循以下格式:

62025273971c9a7f.png

对于垃圾内容和非垃圾内容,系统会为右侧的标签分别分配 true 值和 false 值。例如,第三行会被视为垃圾内容。

如果有人在您的网站上发布有关在线交易的垃圾信息,则您可以在网站底部添加垃圾评论样本。例如:

online trading can be highly highly effective,true
online trading can be highly effective,true
online trading now,true
online trading here,true
online trading for the win,true
  • 使用新名称(例如 lmblog_comments.csv)保存该文件,以便使用它来训练新模型。

在此 Codelab 的其余部分,您会将 Cloud Storage 上提供、修改和托管的样本与在线交易更新搭配使用。如果您想使用自己的数据集,可以更改代码中的网址。

4. 使用新数据重新训练模型

如需重新训练模型,您只需重复使用来自 SpamCommentsModelMaker.ipynb 的代码,但需要将其指向新的 CSV 数据集(称为 lmblog_comments_extras.csv)。如果您需要包含更新内容的完整笔记本,可以在 SpamCommentsUpdateModelMaker.ipynb 中找到它。

如果您有权访问 Colaboratory,则可以直接启动它。否则,请从代码库获取代码,然后在您选择的笔记本环境中运行它。

更新后的代码类似于以下代码段:

training_data = tf.keras.utils.get_file(fname='comments-spam-extras.csv',
          origin='https://storage.googleapis.com/laurencemoroney-blog.appspot.com/
                  lmblog_comments_extras.csv',
          extract=False)

在训练时,您应该看到模型仍会训练到较高级别的准确率:

96a1547ddb6edf5b.png

压缩整个 /mm_update_spam_savedmodel 文件夹,然后下载生成的 mm_update_spam_savedmodel.zip 文件。

# Rename the SavedModel subfolder to a version number
!mv /mm_update_spam_savedmodel/saved_model /mm_update_spam_savedmodel/123
!zip -r mm_update_spam_savedmodel.zip /mm_update_spam_savedmodel/

5. 启动 Docker 并更新您的 Flutter 应用

  1. 将下载的 mm_update_spam_savedmodel.zip 文件解压缩到某个文件夹中,然后停止来自上一个 Codelab 的 Docker 容器实例并再次启动,但需要将 PATH/TO/UPDATE/SAVEDMODEL 占位符替换为托管下载文件的文件夹的绝对路径:
docker run -it --rm -p 8500:8500 -p 8501:8501 -v "PATH/TO/UPDATE/SAVEDMODEL:/models/spam-detection" -e MODEL_NAME=spam-detection tensorflow/serving
  1. 使用您最喜爱的代码编辑器打开 lib/main.dart 文件,然后找到定义 inputTensorNameoutTensorName 变量的部分:
const inputTensorName = 'input_3';
const outputTensorName = 'dense_5';
  1. 将 'input_1' 值重新赋给 inputTensorName 变量,将 'dense_1' 值重新赋给 outputTensorName 变量:
const inputTensorName = 'input_1';
const outputTensorName = 'dense_1';
  1. 将下载的 vocab.txt 文件复制到 lib/assets/ 文件夹中,以替换现有文件。
  2. 从 Android 模拟器中手动移除 Text Classification Flutter 应用。
  3. 在终端中运行 'flutter run' 命令以启动该应用。
  4. 在该应用中,输入 buy my book to learn online trading,然后点击 gRPC > Classify(分类)

现在,模型已改进,可将 buy my book to online trading 检测为垃圾内容。

6. 恭喜

您使用新数据重新训练了模型,将模型与 Flutter 应用进行了集成,并且更新了功能以检测新的垃圾句子!

了解详情