LLM on Android with Keras and TensorFlow Lite
Train and deploy your own large language model (LLM) on Android using Keras and TensorFlow Lite.
Overview
Large language models (LLMs) have revolutionized tasks like text generation and language translation, and provide accurate, automated support that closely resembles human writing. LLMs are trained on massive amounts of text data, which allows them to learn the statistical patterns and relationships between words and phrases.
Because of their substantial storage requirements and computing demands, LLMs typically require cloud deployment. While on-device machine learning (ODML) can be challenging, smaller-scale LLMs like GPT-2 can be effectively run on modern Android devices and deliver impressive performance.
This pathway shows you how to train and deploy your own large language model on Android. It features a video with developer comments, steps that guide you through the process, definitions of commonly used terms, and links to the final code.
On-device large language models with Keras and Android
Watch this video to learn how to load a large language model (LLM) built with Keras, optimize it, and deploy it on your Android device.
Problem framing
Problem framing is the process of analyzing a problem to isolate the individual elements that need to be addressed to solve it. It helps determine your project's technical feasibility and provides a clear set of goals and success criteria. When considering an ML solution, effective problem framing can determine whether or not your product ultimately succeeds. This process involves two distinct steps: determining whether ML is the best approach and framing the problem in ML terms.
You need the following to implement a model that generates text in response to user prompts for a mobile application:
- A powerful model that generates sophisticated responses in a desired writing style.
- A model that can be run on mobile devices.
- A model that is fast and lightweight.
In this pathway, we're going to:
- Fine-tune a pre-trained GPT-2 model.
- Convert the model to TensorFlow Lite.
- Quantize the model to reduce its size and latency.
Dataset search
Datasets are collections of raw data, commonly (but not exclusively) organized in either a spreadsheet or a comma-separated values (CSV) file.
You can use the Dataset Search tool to find datasets that are suitable for your project. You can fine-tune the pre-trained model so that it generates text in a specific style based on your dataset.
In this pathway, we are using the cnn_dailymail dataset, which we downloaded using TensorFlow Datasets.
Feature engineering
Feature engineering is a process that involves determining which features might be useful in training a model, and converting raw data from the dataset into efficient versions of those features. For example, you might determine that temperature could be a useful feature. Then, you might experiment with bucketing to optimize what the model can learn from different temperature ranges.
TensorFlow's tf.data API enables you to build complex input pipelines from simple, reusable
pieces. For example, the pipeline for an image model might aggregate data from files in a
distributed file system, apply random perturbations to each image, and merge randomly selected
images into a batch for training. Then, you can create a TensorFlow dataset and run
preprocessing on it. You can use the tf.data API to map the gpt2_preprocessor
against the training dataset, as shown in the code sample for this step.
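A minimal sketch of such a pipeline, using a small in-memory stand-in for the cnn_dailymail text and a simple whitespace splitter in place of gpt2_preprocessor (the real preprocessor would map words to GPT-2 token IDs):

```python
import tensorflow as tf

# Tiny in-memory stand-in for the training text; in the pathway this
# would be the cnn_dailymail articles loaded via TensorFlow Datasets.
texts = ["the quick brown fox", "jumps over the lazy dog"]
ds = tf.data.Dataset.from_tensor_slices(texts)

# Stand-in for gpt2_preprocessor: split each string into word tokens.
def preprocess(text):
    return tf.strings.split(text)

ds = (
    ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
      .ragged_batch(2)          # sequences per example vary in length
      .prefetch(tf.data.AUTOTUNE)
)

for batch in ds:
    print(batch.shape)  # (2, None): 2 examples, variable sequence length
```

Mapping the real gpt2_preprocessor against the dataset follows the same `ds.map(...)` pattern.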
TensorFlow Text can perform the preprocessing regularly required by text-based models, and includes other features useful for sequence modeling not provided by core TensorFlow. TensorFlow Text is used inside the KerasNLP models for tokenization. It maps words to token IDs that can be further processed by the backbone models in KerasNLP.
If you are new to NLP, KerasNLP is a great place to start. It's a natural language processing library that's easy to use and provides a wide range of features such as the following:
- Supports users through their entire development cycle.
- Provides a high-level API for building NLP models.
- Includes a variety of pre-trained models and modules.
- Includes tokenizers and preprocessing layers.
Modeling
A model is, generally, any mathematical construct that processes input data and returns output. Phrased differently, a model is the set of parameters and structure needed for a system to make predictions. In supervised machine learning, a model takes an example as input and infers a prediction as output. Unsupervised machine learning also generates models, typically a function that can map an input example to the most appropriate cluster.
Large language models (LLMs) are a type of machine learning model trained on a large corpus of text data, which makes them well suited to text generation problems. They are complex to build and expensive to train from scratch, but pre-trained LLMs are available for use right away.
KerasNLP is a natural language processing library that supports users through their entire development cycle, offering pre-trained models, tokenizers, and preprocessing layers.
The TensorFlow Model Optimization Toolkit is a suite of tools that minimizes the complexity of optimizing ML models for deployment and execution. Inference efficiency is a critical concern when deploying machine learning models because of latency and memory utilization. The TensorFlow Model Optimization Toolkit addresses these challenges and improves the optimization process, as shown in the code sample for this step.
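A minimal sketch of post-training dynamic-range quantization with the TensorFlow Lite converter, using a tiny stand-in model in place of the fine-tuned GPT-2:

```python
import tensorflow as tf

# A tiny stand-in model; in the pathway this would be the fine-tuned
# GPT-2 model exported from KerasNLP.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(4),
])

# Convert to TensorFlow Lite with post-training dynamic-range
# quantization, which stores weights in 8 bits to cut size and latency.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# The serialized flatbuffer is what ships inside the Android app.
with open("model.tflite", "wb") as f:
    f.write(tflite_model)
```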
Deployment
Deployment in ML refers to integrating your model into an existing production environment where it can take in input and return output. At this point, the model can generate text in the intended style.
But this runs on a powerful GPU sitting in the cloud. How can you leverage it on a mobile device? You can build a backend service using KerasNLP and send requests from mobile devices to the endpoint, but the best practice is to run the model purely on the device using on-device machine learning (ODML) with TensorFlow Lite.
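On Android you would bundle the quantized .tflite file and load it with the TensorFlow Lite Interpreter from the Java/Kotlin runtime. The Python Interpreter follows the same lifecycle (load model, allocate tensors, set input, invoke, read output) and makes for a compact sketch; the tiny model here is again a stand-in for the quantized GPT-2:

```python
import numpy as np
import tensorflow as tf

# Convert a tiny stand-in model so the sketch is self-contained.
model = tf.keras.Sequential([tf.keras.Input(shape=(8,)), tf.keras.layers.Dense(4)])
tflite_model = tf.lite.TFLiteConverter.from_keras_model(model).convert()

# Load the model and allocate its input/output tensors.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

input_detail = interpreter.get_input_details()[0]
output_detail = interpreter.get_output_details()[0]

# Feed one input example, run inference, and read the result back.
interpreter.set_tensor(input_detail["index"], np.ones((1, 8), dtype=np.float32))
interpreter.invoke()
output = interpreter.get_tensor(output_detail["index"])
print(output.shape)  # (1, 4)
```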