Les ensembles de données très volumineux peuvent ne pas tenir dans la mémoire allouée à votre processus. Dans
aux étapes précédentes, nous avons créé un pipeline qui récupère l'intégralité de l'ensemble de données
dans la mémoire, préparer les données et transmettre l'ensemble de travail
. À la place, Keras propose une autre fonction d'entraînement
(fit_generator
)
qui extrait les données par lot. Cela nous permet d'appliquer les transformations
le pipeline de données qu'à une petite partie des données (un multiple de batch_size
).
Au cours de nos tests, nous avons utilisé le traitement par lot (code dans GitHub) pour les ensembles de données tels que
DBPedia, Amazon reviews, Ag news et Yelp reviews
Le code suivant montre comment générer des lots de données et les alimenter dans
fit_generator
def _data_generator(x, y, num_features, batch_size): """Generates batches of vectorized texts for training/validation. # Arguments x: np.matrix, feature matrix. y: np.ndarray, labels. num_features: int, number of features. batch_size: int, number of samples per batch. # Returns Yields feature and label data in batches. """ num_samples = x.shape[0] num_batches = num_samples // batch_size if num_samples % batch_size: num_batches += 1 while 1: for i in range(num_batches): start_idx = i * batch_size end_idx = (i + 1) * batch_size if end_idx > num_samples: end_idx = num_samples x_batch = x[start_idx:end_idx] y_batch = y[start_idx:end_idx] yield x_batch, y_batch # Create training and validation generators. training_generator = _data_generator( x_train, train_labels, num_features, batch_size) validation_generator = _data_generator( x_val, val_labels, num_features, batch_size) # Get number of training steps. This indicated the number of steps it takes # to cover all samples in one epoch. steps_per_epoch = x_train.shape[0] // batch_size if x_train.shape[0] % batch_size: steps_per_epoch += 1 # Get number of validation steps. validation_steps = x_val.shape[0] // batch_size if x_val.shape[0] % batch_size: validation_steps += 1 # Train and validate model. history = model.fit_generator( generator=training_generator, steps_per_epoch=steps_per_epoch, validation_data=validation_generator, validation_steps=validation_steps, callbacks=callbacks, epochs=epochs, verbose=2) # Logs once per epoch.