Случайным образом увеличивая изображения, используя Керрас
Я экспериментирую с набором данных MNIST, чтобы изучить библиотеку Keras. В MNIST есть 60k обучающих изображений и 10k проверочных изображений.
В обоих наборах я хотел бы ввести увеличение на 30% изображений.
datagen = ImageDataGenerator(horizontal_flip=True, vertical_flip=True)
datagen.fit(training_images)
datagen.fit(validation_images)
Это не увеличивает изображения, и я не уверен, как использовать метод model.fit_generator
. Мой текущий model.fit
выглядит следующим образом:
model.fit(training_images, training_labels, validation_data=(validation_images, validation_labels), epochs=10, batch_size=200, verbose=2)
Как применить увеличение к некоторым изображениям в этом наборе данных?
1 ответ:
Я бы попытался определить свой собственный генератор следующим образом:
from sklearn.model_selection import train_test_split from six import next def partial_flow(array, flags, generator, aug_percentage, batch_size): # Splitting data into arrays which will be augmented and which won't not_aug_array, not_aug_flags, aug_array, aug_flags = train_test_split( array, test_size=aug_percentage) # Preparation of generators which will be used for augmentation. aug_split_size = int(batch_size * split_size) # We will use generator without any augmentation to yield not augmented data not_augmented_gen = ImageDataGenerator() aug_gen = generator.flow( x=aug_array, y=aug_flags, batch_size=aug_split_size) not_aug_gen = not_augmented_gen.flow( x=not_aug_array, y=not_aug_flags, batch_size=batch_size - aug_split_size) # Yiedling data while True: # Getting augmented data aug_x, aug_y = next(aug_gen) # Getting not augmented data not_aug_x, not_aug_y = next(not_aug_gen) # Concatenation current_x = numpy.concatenate([aug_x, not_aug_x], axis=0) current_y = numpy.concatenate([aug_y, not_aug_y], axis=0) yield current_x, current_y
Теперь вы можете запустить обучение по:
batch_size = 200 model.fit_generator(partial_flow(training_images, training_labels, 0.7, batch_size), steps_per_epoch=int(training_images.shape[0] / batch_size), epochs=10, validation_data=partial_flow(validation_images, validation_labels, 0.7, batch_size), validation_steps=int(validation_images.shape[0] / batch_size))