mediapipe_model_maker.text_classifier.BertHParams
The hyperparameters for a BERT text classifier.
Inherits From: BaseHParams
mediapipe_model_maker.text_classifier.BertHParams(
    learning_rate: float = 3e-05,
    batch_size: int = 48,
    epochs: int = 2,
    steps_per_epoch: Optional[int] = None,
    class_weights: Optional[Mapping[int, float]] = None,
    shuffle: bool = False,
    repeat: bool = False,
    export_dir: str = tempfile.mkdtemp(),
    distribution_strategy: str = 'off',
    num_gpus: int = 0,
    tpu: str = '',
    end_learning_rate: float = 0.0,
    optimizer: mediapipe_model_maker.text_classifier.BertOptimizer = mediapipe_model_maker.text_classifier.BertOptimizer.ADAMW,
    weight_decay: float = 0.01,
    desired_precisions: Sequence[float] = dataclasses.field(default_factory=list),
    desired_recalls: Sequence[float] = dataclasses.field(default_factory=list),
    gamma: float = 2.0,
    tokenizer: mediapipe_model_maker.text_classifier.SupportedBertTokenizers = mediapipe_model_maker.text_classifier.SupportedBertTokenizers.FULL_TOKENIZER,
    checkpoint_frequency: int = 0
)
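Because the default for export_dir is the expression tempfile.mkdtemp() in a dataclass field, Python evaluates it once, when the class statement runs, not once per instance. The sketch below illustrates the consequence with a hypothetical stand-in dataclass, HParamsSketch (not part of mediapipe_model_maker and not the real BertHParams): instances that do not pass export_dir all share one directory. When training several models in one process, passing an explicit export_dir keeps their outputs apart.

```python
import dataclasses
import os
import tempfile


@dataclasses.dataclass
class HParamsSketch:
    """Hypothetical stand-in mirroring two BertHParams fields; not the real class."""

    epochs: int = 2
    # Evaluated once, when this class statement runs, just like the
    # `export_dir: str = tempfile.mkdtemp()` default in the signature.
    export_dir: str = tempfile.mkdtemp()


a = HParamsSketch()
b = HParamsSketch(epochs=3)
# Both instances reuse the single directory created at class-definition time.
print(a.export_dir == b.export_dir)
print(os.path.isdir(a.export_dir))

# Passing export_dir explicitly gives this run its own output location.
c = HParamsSketch(export_dir=tempfile.mkdtemp())
print(c.export_dir == a.export_dir)
```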
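Attributes
learning_rate: Learning rate to use for gradient descent training. Defaults to 3e-05.
end_learning_rate: Final learning rate of the linear decay schedule. Defaults to 0.0.
batch_size: Batch size for training. Defaults to 48.
epochs: Number of training iterations over the dataset. Defaults to 2.
optimizer: Optimizer to use for training. Supported values are defined in the BertOptimizer enum: ADAMW and LAMB. Defaults to ADAMW.
weight_decay: Weight decay of the optimizer. Defaults to 0.01.
desired_precisions: If specified, adds one RecallAtPrecision metric per entry desired_precisions[i], which tracks the recall subject to the constraint that precision is at least desired_precisions[i]. Only supported for binary classification. Defaults to an empty list.
desired_recalls: If specified, adds one PrecisionAtRecall metric per entry desired_recalls[i], which tracks the precision subject to the constraint that recall is at least desired_recalls[i]. Only supported for binary classification. Defaults to an empty list.
gamma: Gamma parameter for focal loss. Set it to 0 to use cross-entropy loss instead. Defaults to 2.0.
tokenizer: Tokenizer to use for preprocessing. Must be one of the SupportedBertTokenizers enum options. Defaults to FULL_TOKENIZER.
checkpoint_frequency: Frequency, in epochs, at which training checkpoints are saved. Defaults to 0, which disables saving training checkpoints.
steps_per_epoch: Optional number of training steps per epoch. If None, the training pipeline derives it from the training dataset size and batch_size. Defaults to None.
class_weights: Optional mapping from class index to a weight applied to that class's loss during training. Defaults to None.
shuffle: Whether to shuffle the training dataset before training. Defaults to False.
repeat: Whether to repeat the training dataset indefinitely. Defaults to False.
export_dir: Directory where checkpoints and exported model files are written. The default is a single temporary directory created by tempfile.mkdtemp() when the class is defined, so every instance that does not set export_dir shares it.
distribution_strategy: Name of the TensorFlow distribution strategy to use for training; 'off' (the default) disables distribution strategies.
num_gpus: Number of GPUs to use for training. Defaults to 0.
tpu: TPU resource to use for training. Defaults to an empty string.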
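To illustrate the gamma attribute, here is a minimal sketch of the standard single-example focal loss, FL = -(1 - p_t)^gamma * log(p_t), where p_t is the predicted probability of the true class. It assumes softmax probabilities as input; the helper names focal_loss and cross_entropy are hypothetical, and the library's own loss may differ in details such as per-class weighting or batch reduction. With gamma = 0 the modulating factor (1 - p_t)^gamma equals 1, so the loss reduces to cross entropy, which is why setting gamma to 0 switches to cross-entropy loss.

```python
import math
from typing import Sequence


def focal_loss(probs: Sequence[float], label: int, gamma: float = 2.0) -> float:
    """Hypothetical helper: focal loss for one example.

    probs are softmax probabilities; label is the index of the true class.
    The factor (1 - p_t) ** gamma shrinks the loss of well-classified
    examples (p_t close to 1) so training focuses on hard examples.
    """
    p_t = probs[label]
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


def cross_entropy(probs: Sequence[float], label: int) -> float:
    """Hypothetical helper: categorical cross entropy for one example."""
    return -math.log(probs[label])


probs = [0.9, 0.1]  # a confident, correct prediction for class 0
# gamma = 0 makes the modulating factor 1, recovering cross entropy.
print(focal_loss(probs, label=0, gamma=0.0) == cross_entropy(probs, label=0))
# gamma = 2 (the default) strongly down-weights this easy example.
print(focal_loss(probs, label=0, gamma=2.0) < cross_entropy(probs, label=0))
```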
Methods
get_strategy
get_strategy()
__eq__
__eq__(other)
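The page lists __eq__ without a description. BertHParams is a dataclass (its signature uses dataclasses.field), and assuming it relies on the dataclass-generated __eq__, two instances compare equal when they are of the same class and their compared fields are equal. The sketch below demonstrates that generated behavior with a hypothetical stand-in class, HParamsEqSketch; it does not import or exercise the real BertHParams.

```python
import dataclasses


@dataclasses.dataclass
class HParamsEqSketch:
    """Hypothetical stand-in with a few BertHParams-like fields."""

    learning_rate: float = 3e-05
    batch_size: int = 48
    epochs: int = 2


base = HParamsEqSketch()
# The generated __eq__ compares the fields, in declaration order, as tuples.
print(base == HParamsEqSketch())
print(base == HParamsEqSketch(epochs=3))
# For an object of another type, __eq__ returns NotImplemented, so Python
# falls back to an identity check and == evaluates to False.
print(base == {"learning_rate": 3e-05, "batch_size": 48, "epochs": 2})
```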
Class Variables
batch_size: 48
checkpoint_frequency: 0
class_weights: None
distribution_strategy: 'off'
end_learning_rate: 0.0
epochs: 2
export_dir: '/tmpfs/tmp/tmpnt_h4p9w' (the temporary directory that tempfile.mkdtemp() created when this page was generated; the path differs in each Python process)
gamma: 2.0
learning_rate: 3e-05
num_gpus: 0
optimizer: <BertOptimizer.ADAMW: 'adamw'>
repeat: False
shuffle: False
steps_per_epoch: None
tokenizer: <SupportedBertTokenizers.FULL_TOKENIZER: 'fulltokenizer'>
tpu: ''
weight_decay: 0.01
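The end_learning_rate attribute refers to a linear decay of the learning rate. As a minimal sketch, assuming the rate falls linearly from learning_rate to end_learning_rate over a total step count (taken here as epochs times a hypothetical steps-per-epoch value) and ignoring any warmup phase the library may apply, the defaults learning_rate = 3e-05 and end_learning_rate = 0.0 behave as follows. The function linear_decay is hypothetical and not part of mediapipe_model_maker.

```python
def linear_decay(step: int, total_steps: int,
                 learning_rate: float = 3e-05,
                 end_learning_rate: float = 0.0) -> float:
    """Hypothetical helper: learning rate at `step` under linear decay, no warmup."""
    # Clamp so that steps past the end of training stay at end_learning_rate.
    fraction = min(step, total_steps) / total_steps
    return learning_rate + (end_learning_rate - learning_rate) * fraction


# epochs=2 (the default) with a hypothetical 100 steps per epoch.
total = 2 * 100
for step in (0, 100, 200):
    print(step, linear_decay(step, total))
```

Halfway through training the rate is half of learning_rate, and it reaches end_learning_rate on the final step.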
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2026-05-28 UTC.