আপনার প্রক্রিয়ার জন্য বরাদ্দ করা মেমরিতে খুব বড় ডেটাসেটগুলি ফিট নাও হতে পারে৷ পূর্ববর্তী ধাপে, আমরা একটি পাইপলাইন সেট আপ করেছি যেখানে আমরা পুরো ডেটাসেট মেমরিতে আনব, ডেটা প্রস্তুত করি এবং প্রশিক্ষণ ফাংশনে কাজের সেটটি পাস করি। পরিবর্তে, কেরাস একটি বিকল্প প্রশিক্ষণ ফাংশন ( fit_generator
) প্রদান করে যা ব্যাচে ডেটা টানে। এটি আমাদের ডেটা পাইপলাইনে রূপান্তরগুলিকে ডেটার শুধুমাত্র একটি ছোট ( batch_size
একাধিক) অংশে প্রয়োগ করতে দেয়। আমাদের পরীক্ষার সময়, আমরা DBPedia , Amazon পর্যালোচনা , Ag news , এবং Yelp পর্যালোচনার মতো ডেটাসেটের জন্য ব্যাচিং (GitHub-এ কোড) ব্যবহার করি।
নিম্নলিখিত কোডটি ব্যাখ্যা করে কিভাবে ডেটা ব্যাচ তৈরি করতে হয় এবং fit_generator
এ তাদের খাওয়ানো যায়।
def _data_generator(x, y, num_features, batch_size): """Generates batches of vectorized texts for training/validation. # Arguments x: np.matrix, feature matrix. y: np.ndarray, labels. num_features: int, number of features. batch_size: int, number of samples per batch. # Returns Yields feature and label data in batches. """ num_samples = x.shape[0] num_batches = num_samples // batch_size if num_samples % batch_size: num_batches += 1 while 1: for i in range(num_batches): start_idx = i * batch_size end_idx = (i + 1) * batch_size if end_idx > num_samples: end_idx = num_samples x_batch = x[start_idx:end_idx] y_batch = y[start_idx:end_idx] yield x_batch, y_batch # Create training and validation generators. training_generator = _data_generator( x_train, train_labels, num_features, batch_size) validation_generator = _data_generator( x_val, val_labels, num_features, batch_size) # Get number of training steps. This indicated the number of steps it takes # to cover all samples in one epoch. steps_per_epoch = x_train.shape[0] // batch_size if x_train.shape[0] % batch_size: steps_per_epoch += 1 # Get number of validation steps. validation_steps = x_val.shape[0] // batch_size if x_val.shape[0] % batch_size: validation_steps += 1 # Train and validate model. history = model.fit_generator( generator=training_generator, steps_per_epoch=steps_per_epoch, validation_data=validation_generator, validation_steps=validation_steps, callbacks=callbacks, epochs=epochs, verbose=2) # Logs once per epoch.