Tworzenie niestandardowego modelu klasyfikacji tekstu za pomocą narzędzia TensorFlow Lite Model Maker

1. Zanim zaczniesz

Z tego ćwiczenia w Codelabs dowiesz się, jak zaktualizować model klasyfikacji tekstu utworzony na podstawie oryginalnego zbioru danych blog-spam-comments, ale wzbogacony o Twoje komentarze, aby uzyskać model, który działa z Twoimi danymi.

Wymagania wstępne

Te warsztaty są częścią ścieżki Pierwsze kroki z klasyfikacją tekstu w aplikacjach na Fluttera. Ćwiczenia w Codelabs w tej ścieżce są ułożone w kolejności. Aplikacja i model, nad którymi będziesz pracować, powinny być już utworzone podczas wykonywania ćwiczeń z programowania. Jeśli nie udało Ci się jeszcze wykonać poprzednich działań, przerwij i zrób to teraz:

Czego się nauczysz

Czego potrzebujesz

  • aplikacja Flutter i model filtra spamu, które zostały utworzone w ramach poprzednich działań.

2. Ulepszanie klasyfikacji tekstu

  1. Kod możesz uzyskać, klonując to repozytorium i wczytując aplikację z folderu tfserving-flutter/codelab2/finished.
  2. Po uruchomieniu obrazu Dockera TensorFlow Serving w utworzonej aplikacji wpisz buy my book to learn online trading, a potem kliknij gRPC > Classify (gRPC > Klasyfikuj).

8f1e1974522f274d.png

Aplikacja generuje niski wynik spamu, ponieważ w oryginalnym zbiorze danych nie ma wielu wystąpień handlu online, a model nie nauczył się, że jest to spam. W tym samouczku zaktualizujesz model za pomocą nowych danych, aby model rozpoznawał to samo zdanie jako spam.

2bd68691a26aa3da.png

3. Edytowanie pliku CSV

Do wytrenowania pierwotnego modelu utworzono zbiór danych w formacie CSV (lmblog_comments.csv) zawierający prawie tysiąc komentarzy oznaczonych jako spam lub nie spam. (Jeśli chcesz sprawdzić plik CSV, otwórz go w dowolnym edytorze tekstu).

W pierwszym wierszu pliku CSV znajdują się opisy kolumn, które są oznaczone etykietami commenttextspam. Każdy kolejny wiersz ma ten format:

62025273971c9a7f.png

Etykieta po prawej stronie ma wartość true w przypadku spamu i false w przypadku braku spamu. Na przykład trzeci wiersz jest uznawany za spam.

Jeśli użytkownicy spamują Twoją witrynę wiadomościami o handlu online, możesz dodać przykłady takich komentarzy na dole witryny. Na przykład:

online trading can be highly highly effective,true
online trading can be highly effective,true
online trading now,true
online trading here,true
online trading for the win,true
  • Zapisz plik pod nową nazwą, np. lmblog_comments.csv, aby można było go użyć do trenowania nowego modelu.

W pozostałej części tego laboratorium będziesz używać podanego przykładu, który został zmodyfikowany i jest hostowany w Cloud Storage wraz z aktualizacjami dotyczącymi handlu online. Jeśli chcesz użyć własnego zbioru danych, możesz zmienić adres URL w kodzie.

4. Ponowne trenowanie modelu na podstawie nowych danych

Aby ponownie wytrenować model, możesz po prostu użyć kodu z (SpamCommentsModelMaker.ipynb), ale skierować go na nowy zbiór danych CSV o nazwie lmblog_comments_extras.csv. Jeśli chcesz uzyskać pełny notatnik ze zaktualizowaną zawartością, znajdziesz go pod adresem SpamCommentsUpdateModelMaker.ipynb.

Jeśli masz dostęp do Colaboratory, możesz uruchomić go bezpośrednio. W przeciwnym razie pobierz kod z repozytorium, a potem uruchom go w wybranym środowisku notatnika.

Zaktualizowany kod wygląda tak:

training_data = tf.keras.utils.get_file(fname='comments-spam-extras.csv',   
          origin='https://storage.googleapis.com/laurencemoroney-blog.appspot.com/
                  lmblog_comments_extras.csv', 
          extract=False)

Podczas trenowania zobaczysz, że model nadal osiąga wysoki poziom dokładności:

96a1547ddb6edf5b.png

Skompresuj cały folder /mm_update_spam_savedmodel i pobierz wygenerowany plik mm_update_spam_savedmodel.zip.

# Rename the SavedModel subfolder to a version number
!mv /mm_update_spam_savedmodel/saved_model /mm_update_spam_savedmodel/123
!zip -r mm_update_spam_savedmodel.zip /mm_update_spam_savedmodel/

5. Uruchom Dockera i zaktualizuj aplikację Flutter

  1. Rozpakuj pobrany plik mm_update_spam_savedmodel.zip do folderu, a następnie zatrzymaj instancję kontenera Docker z poprzedniego ćwiczenia i uruchom ją ponownie, ale zastąp obiekt zastępczy PATH/TO/UPDATE/SAVEDMODEL bezwzględną ścieżką do folderu, w którym znajdują się pobrane pliki:
docker run -it --rm -p 8500:8500 -p 8501:8501 -v "PATH/TO/UPDATE/SAVEDMODEL:/models/spam-detection" -e MODEL_NAME=spam-detection tensorflow/serving
  1. Otwórz plik lib/main.dart w ulubionym edytorze kodu, a następnie znajdź część, która definiuje zmienne inputTensorName i outTensorName:
const inputTensorName = 'input_3';
const outputTensorName = 'dense_5';
  1. Przypisz zmiennej inputTensorName wartość „input_1'”, a zmiennej outputTensorName wartość 'dense_1':
const inputTensorName = 'input_1';
const outputTensorName = 'dense_1';
  1. Skopiuj pobrany plik vocab.txt do folderu lib/assets/, aby zastąpić istniejący plik.
  2. Ręcznie usuń aplikację Text Classification Flutter z emulatora Androida.
  3. Uruchom w terminalu polecenie 'flutter run', aby uruchomić aplikację.
  4. W aplikacji wpisz buy my book to learn online trading, a potem kliknij gRPC > Classify.

Model został ulepszony i wykrywa teraz jako spam wiadomości typu kup moją książkę, aby handlować online.

6. Gratulacje

Model został ponownie wytrenowany na podstawie nowych danych, zintegrowany z aplikacją Flutter i zaktualizowany, aby wykrywać nowe zdania spamowe.

Więcej informacji