Tập dữ liệu rất lớn có thể không vừa với bộ nhớ được phân bổ cho quy trình của bạn. Trong
các bước trước đó, chúng tôi đã thiết lập một quy trình để đưa toàn bộ tập dữ liệu vào
vào bộ nhớ, chuẩn bị dữ liệu và truyền tập hợp làm việc vào chương trình huấn luyện
. Thay vào đó, Keras cung cấp một hàm huấn luyện thay thế
(fit_generator
)
để lấy dữ liệu theo lô. Điều này cho phép chúng ta áp dụng các phép biến đổi trong
quy trình dữ liệu chỉ đến một phần nhỏ (bội số của batch_size
) của dữ liệu.
Trong quá trình thử nghiệm, chúng tôi đã sử dụng tính năng phân lô (mã trong GitHub) cho các tập dữ liệu như
DBPedia, bài đánh giá của Amazon, Tin tức Ag và bài đánh giá về Yelp.
Mã sau đây minh hoạ cách tạo các lô dữ liệu và cấp dữ liệu cho các lô đó
fit_generator
.
def _data_generator(x, y, num_features, batch_size): """Generates batches of vectorized texts for training/validation. # Arguments x: np.matrix, feature matrix. y: np.ndarray, labels. num_features: int, number of features. batch_size: int, number of samples per batch. # Returns Yields feature and label data in batches. """ num_samples = x.shape[0] num_batches = num_samples // batch_size if num_samples % batch_size: num_batches += 1 while 1: for i in range(num_batches): start_idx = i * batch_size end_idx = (i + 1) * batch_size if end_idx > num_samples: end_idx = num_samples x_batch = x[start_idx:end_idx] y_batch = y[start_idx:end_idx] yield x_batch, y_batch # Create training and validation generators. training_generator = _data_generator( x_train, train_labels, num_features, batch_size) validation_generator = _data_generator( x_val, val_labels, num_features, batch_size) # Get number of training steps. This indicated the number of steps it takes # to cover all samples in one epoch. steps_per_epoch = x_train.shape[0] // batch_size if x_train.shape[0] % batch_size: steps_per_epoch += 1 # Get number of validation steps. validation_steps = x_val.shape[0] // batch_size if x_val.shape[0] % batch_size: validation_steps += 1 # Train and validate model. history = model.fit_generator( generator=training_generator, steps_per_epoch=steps_per_epoch, validation_data=validation_generator, validation_steps=validation_steps, callbacks=callbacks, epochs=epochs, verbose=2) # Logs once per epoch.