ייתכן שמערךי נתונים גדולים מאוד לא ייכללו בזיכרון שהוקצה לתהליך שלכם. בשלבים הקודמים, הגדרנו צינור עיבוד שבו אנחנו מכניסים את כל מערך הנתונים לזיכרון, מכינים את הנתונים ומעבירים את קבוצת העבודה לפונקציית האימון. במקום זאת, Keras מספק פונקציית הדרכה חלופית (fit_generator
) שולפת את הנתונים בקבוצות. כך אנחנו יכולים להחיל את הטרנספורמציות בצינור הנתונים רק על חלק קטן (כפול של batch_size
) מהנתונים.
במהלך הניסויים, השתמשנו במקבצים (קוד ב-GitHub) למערכי נתונים כמו
DBPedia, ביקורות ב-Amazon, ביקורות ב-Ag וביקורות ב-Yelp.
דוגמת הקוד הבאה ממחישה איך ליצור קבוצות נתונים ולהזין אותן בפיד
fit_generator
.
def _data_generator(x, y, num_features, batch_size): """Generates batches of vectorized texts for training/validation. # Arguments x: np.matrix, feature matrix. y: np.ndarray, labels. num_features: int, number of features. batch_size: int, number of samples per batch. # Returns Yields feature and label data in batches. """ num_samples = x.shape[0] num_batches = num_samples // batch_size if num_samples % batch_size: num_batches += 1 while 1: for i in range(num_batches): start_idx = i * batch_size end_idx = (i + 1) * batch_size if end_idx > num_samples: end_idx = num_samples x_batch = x[start_idx:end_idx] y_batch = y[start_idx:end_idx] yield x_batch, y_batch # Create training and validation generators. training_generator = _data_generator( x_train, train_labels, num_features, batch_size) validation_generator = _data_generator( x_val, val_labels, num_features, batch_size) # Get number of training steps. This indicated the number of steps it takes # to cover all samples in one epoch. steps_per_epoch = x_train.shape[0] // batch_size if x_train.shape[0] % batch_size: steps_per_epoch += 1 # Get number of validation steps. validation_steps = x_val.shape[0] // batch_size if x_val.shape[0] % batch_size: validation_steps += 1 # Train and validate model. history = model.fit_generator( generator=training_generator, steps_per_epoch=steps_per_epoch, validation_data=validation_generator, validation_steps=validation_steps, callbacks=callbacks, epochs=epochs, verbose=2) # Logs once per epoch.