Create a custom text-classification model with TensorFlow Lite Model Maker

1. Before you begin

In this codelab, you learn how to update the text-classification model that was built from the original blog-spam-comments dataset by adding comments of your own to that dataset, so that the model works with your data.

Prerequisites

This codelab is part of the Get started with text classification in Flutter apps pathway. The codelabs in this pathway are sequential, and you should have built the app and the model that you'll work on here while following along with the previous codelabs. If you haven't yet completed the previous activities, please stop and do so now.

What you'll learn

What you'll need

  • The Flutter app and spam-filter model that you observed and built in the previous activities.

2. Enhance text classification

  1. You can get the code for this codelab by cloning this repository and loading the app from the tfserving-flutter/codelab2/finished folder.
  2. After you start the TensorFlow Serving Docker image, enter buy my book to learn online trading in the app that you built and then click gRPC > Classify.

The app generates a low spam score because there aren't many occurrences of online trading in the original dataset and the model hasn't learned that it's spam. In this codelab, you update the model with new data so that the model identifies the same sentence as spam!

3. Edit your CSV file

To train the original model, a dataset was created as a CSV file (lmblog_comments.csv) containing almost a thousand comments, each labeled either spam or not spam. (Open the CSV file in any text editor if you want to inspect it.)

The first row of the CSV file is a header that names the two columns, commenttext and spam. Every subsequent row follows this format:

[Image: sample rows of lmblog_comments.csv]

The label in the second column is true for spam and false for not spam. For example, the third line is labeled as spam.
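To make the format concrete, here is a minimal Python sketch that reads rows in this layout and counts the labels. It assumes the header is commenttext,spam and that labels are the strings true or false (case is ignored). The count_labels function is hypothetical and is not part of the codelab's notebook.

```python
import csv
import io
from typing import Iterable, Tuple


def count_labels(lines: Iterable[str]) -> Tuple[int, int]:
    """Return (spam_count, not_spam_count) for CSV text with a commenttext,spam header."""
    spam = not_spam = 0
    for row in csv.DictReader(lines):
        # Labels are the strings "true" (spam) or "false" (not spam).
        label = (row["spam"] or "").strip().lower()
        if label == "true":
            spam += 1
        elif label == "false":
            not_spam += 1
        else:
            raise ValueError(f"unexpected label {row['spam']!r} for {row['commenttext']!r}")
    return spam, not_spam


# Small in-memory sample standing in for lmblog_comments.csv.
sample = io.StringIO(
    "commenttext,spam\n"
    "thanks for the helpful post,false\n"
    "online trading now,true\n"
)
print(count_labels(sample))  # (1, 1)
```

To run it on the real file, you could pass an open file handle such as `open("lmblog_comments.csv", newline="", encoding="utf-8")`; the UTF-8 encoding is an assumption about the file. Using the csv module rather than splitting on commas keeps quoted comments that contain commas intact.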

If people spam your website with messages about online trading, you can add examples of such spam comments to the bottom of the CSV file. For example:

online trading can be highly highly effective,true
online trading can be highly effective,true
online trading now,true
online trading here,true
online trading for the win,true
  • Save the file with a new name, such as lmblog_comments_extras.csv, so that you can use it to train a new model.
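If you prefer to script this edit instead of using a text editor, the following Python sketch copies the original file to a new name and appends the extra rows, each labeled true. It assumes the original file is lmblog_comments.csv and uses lmblog_comments_extras.csv as the output name; the add_spam_examples helper is hypothetical.

```python
import csv
import tempfile
from pathlib import Path
from typing import Iterable


def add_spam_examples(src: Path, dst: Path, comments: Iterable[str]) -> int:
    """Copy src to dst and append each comment to dst as a row labeled "true".

    The original file is not modified. Returns the number of rows appended.
    """
    data = src.read_bytes()
    dst.write_bytes(data)
    count = 0
    # newline="" lets the csv module control line endings, as the csv docs recommend.
    with dst.open("a", newline="", encoding="utf-8") as f:
        # Start on a fresh line if the original file lacks a trailing newline.
        if data and not data.endswith(b"\n"):
            f.write("\n")
        writer = csv.writer(f, lineterminator="\n")
        for text in comments:
            writer.writerow([text, "true"])
            count += 1
    return count


if __name__ == "__main__":
    # Demo on throwaway files; with the real data you would pass
    # Path("lmblog_comments.csv") and Path("lmblog_comments_extras.csv").
    with tempfile.TemporaryDirectory() as tmp:
        original = Path(tmp) / "lmblog_comments.csv"
        original.write_text("commenttext,spam\nthanks for the post,false\n", encoding="utf-8")
        extras = Path(tmp) / "lmblog_comments_extras.csv"
        added = add_spam_examples(original, extras, [
            "online trading can be highly effective",
            "online trading now",
            "online trading here",
            "online trading for the win",
        ])
        print(added)
        print(extras.read_text(encoding="utf-8"))
```

Writing to a copy keeps the original dataset available, so you can compare the original and retrained models later.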

For the rest of this codelab, you use an edited example file that already contains the online trading updates and is hosted on Cloud Storage. If you want to use your own dataset, change the URL in the code.

4. Retrain the model with the new data

To retrain the model, you can reuse the code from SpamCommentsModelMaker.ipynb but point it at the new CSV dataset, lmblog_comments_extras.csv. The full notebook with the updated contents is available as SpamCommentsUpdateModelMaker.ipynb.

If you have access to Colaboratory, you can launch the notebook directly. Otherwise, get the code from the repository and run it in your notebook environment of choice.

The updated code looks like this:

import tensorflow as tf

# Download the updated dataset, which includes the online trading examples
training_data = tf.keras.utils.get_file(
    fname='comments-spam-extras.csv',
    origin='https://storage.googleapis.com/laurencemoroney-blog.appspot.com/lmblog_comments_extras.csv',
    extract=False)

When you train, you should see that the model still trains to a high level of accuracy:

Rename the SavedModel subfolder to a version number, compress the entire /mm_update_spam_savedmodel folder, and then download the generated mm_update_spam_savedmodel.zip file.

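# Rename the SavedModel subfolder to a version number; TensorFlow Serving
# loads models from numbered version subfolders under the model's base path
!mv /mm_update_spam_savedmodel/saved_model /mm_update_spam_savedmodel/123
# Compress the whole folder so that it can be downloaded as a single file
!zip -r mm_update_spam_savedmodel.zip /mm_update_spam_savedmodel/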
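The two shell commands can also be written in plain Python, which may help outside Colab, where the ! shell escape isn't available. This is a sketch rather than the notebook's code: it assumes the export folder contains a saved_model subfolder, as the mv command implies, and the package_for_serving function is hypothetical. Unlike the zip command, which writes the archive to the current directory, this version writes it next to the export folder.

```python
import shutil
import tempfile
import zipfile
from pathlib import Path


def package_for_serving(export_dir: Path, version: int = 123) -> Path:
    """Rename export_dir/saved_model to export_dir/<version>, then zip export_dir.

    Mirrors the mv and zip -r commands: TensorFlow Serving loads models from
    numbered version subfolders, and the archive keeps the top-level folder name.
    The .zip is written next to export_dir and its path is returned.
    """
    src = export_dir / "saved_model"
    dst = export_dir / str(version)
    if dst.exists():
        raise FileExistsError(f"{dst} already exists")
    src.rename(dst)
    archive = shutil.make_archive(
        base_name=str(export_dir.parent / export_dir.name),  # ".zip" is appended
        format="zip",
        root_dir=str(export_dir.parent),
        base_dir=export_dir.name,
    )
    return Path(archive)


if __name__ == "__main__":
    # Demo on a fake export folder; in Colab you would pass
    # Path("/mm_update_spam_savedmodel") instead.
    with tempfile.TemporaryDirectory() as tmp:
        export = Path(tmp) / "mm_update_spam_savedmodel"
        (export / "saved_model").mkdir(parents=True)
        (export / "saved_model" / "saved_model.pb").write_bytes(b"placeholder")
        zip_path = package_for_serving(export)
        with zipfile.ZipFile(zip_path) as zf:
            print(zf.namelist())
```

Passing root_dir and base_dir separately keeps entries such as mm_update_spam_savedmodel/123/saved_model.pb inside the archive, so unzipping recreates the same folder layout that the zip command produces.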

5. Start Docker and update your Flutter App

  1. Unzip the downloaded mm_update_spam_savedmodel.zip file into a folder. Then stop the Docker container instance from the previous codelab and start it again with the following command, replacing the PATH/TO/UPDATE/SAVEDMODEL placeholder with the absolute path of the unzipped mm_update_spam_savedmodel folder, which contains the 123 version subfolder:
docker run -it --rm -p 8500:8500 -p 8501:8501 -v "PATH/TO/UPDATE/SAVEDMODEL:/models/spam-detection" -e MODEL_NAME=spam-detection tensorflow/serving
  2. Open the lib/main.dart file in your favorite code editor and find the part that defines the inputTensorName and outputTensorName variables:
const inputTensorName = 'input_3';
const outputTensorName = 'dense_5';
  3. Change the value of the inputTensorName variable to 'input_1' and the value of the outputTensorName variable to 'dense_1':
const inputTensorName = 'input_1';
const outputTensorName = 'dense_1';
  4. Copy the vocab.txt file that you downloaded into the lib/assets/ folder to replace the existing one.
  5. Manually remove the Text Classification Flutter app from the Android emulator.
  6. Run the flutter run command in your terminal to launch the app.
  7. In the app, enter buy my book to learn online trading and then click gRPC > Classify.
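Because the input and output names depend on how the retrained model was built, it can help to read them from TensorFlow Serving instead of guessing. The sketch below assumes the container is running with the REST port 8501 published, as in the docker run command above, and queries TensorFlow Serving's model status and metadata endpoints (/v1/models/spam-detection and /v1/models/spam-detection/metadata). The helper names are hypothetical, and the JSON layout the parsers expect is based on TensorFlow Serving's REST responses; check it against your own server's output.

```python
import json
import urllib.request
from typing import Any, Dict, List, Tuple

# Assumes MODEL_NAME=spam-detection and -p 8501:8501, as in the docker run command.
BASE_URL = "http://localhost:8501/v1/models/spam-detection"


def fetch_json(url: str) -> Dict[str, Any]:
    """GET a URL and decode the JSON body."""
    with urllib.request.urlopen(url, timeout=10) as resp:
        return json.loads(resp.read().decode("utf-8"))


def available_versions(status: Dict[str, Any]) -> List[str]:
    """Return the model versions that the status response reports as AVAILABLE."""
    return [
        v["version"]
        for v in status.get("model_version_status", [])
        if v.get("state") == "AVAILABLE"
    ]


def tensor_names(metadata: Dict[str, Any],
                 signature: str = "serving_default") -> Tuple[List[str], List[str]]:
    """Return (input_names, output_names) for one signature in the metadata response.

    The keys of the signature's inputs and outputs maps are the names that a
    Predict request uses, which is what inputTensorName and outputTensorName hold.
    """
    sig = metadata["metadata"]["signature_def"]["signature_def"][signature]
    return sorted(sig["inputs"]), sorted(sig["outputs"])
```

With the container running, `available_versions(fetch_json(BASE_URL))` should include 123 once the updated model has loaded, and `tensor_names(fetch_json(BASE_URL + "/metadata"))` should list the values to use for inputTensorName and outputTensorName.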

Now the model correctly detects buy my book to learn online trading as spam.

6. Congratulations

You retrained the model with new data, integrated it with the Flutter app, and updated the functionality to detect new spam sentences!

Learn more