Updated training for new KERAS API

This commit is contained in:
Lucas Oskorep
2019-04-14 15:17:57 -05:00
parent bc44d30180
commit fce361ddcb
2 changed files with 11 additions and 8 deletions
+8 -6
View File
@@ -7,7 +7,6 @@ import matplotlib.pyplot as plt
from tensorflow import keras
from time import time
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
input_shape = (224, 224, 3)
@@ -18,7 +17,7 @@ from keras.preprocessing.image import ImageDataGenerator
from keras.applications.inception_v3 import preprocess_input
train_idg = ImageDataGenerator(
horizontal_flip=True,
# horizontal_flip=True,
preprocessing_function=preprocess_input
)
train_gen = train_idg.flow_from_directory(
@@ -28,7 +27,7 @@ train_gen = train_idg.flow_from_directory(
)
val_idg = ImageDataGenerator(
horizontal_flip=True,
# horizontal_flip=True,
preprocessing_function=preprocess_input
)
@@ -88,7 +87,7 @@ checkpoint = ModelCheckpoint(file_path, monitor='val_acc', verbose=1, save_best_
early = EarlyStopping(monitor="val_acc", mode="max", patience=15)
tensorboard = TensorBoard(
log_dir="logs/" + model_name + "{}".format(time()), histogram_freq=0, batch_size=32,
log_dir="logs/" + model_name + "{}".format(time()), histogram_freq=0, batch_size=batch_size,
write_graph=True,
write_grads=True,
write_images=True,
@@ -100,6 +99,8 @@ callbacks_list = [checkpoint, early, tensorboard] # early
history = model.fit_generator(
train_gen,
validation_data=val_gen,
steps_per_epoch=len(train_gen),
validation_steps=len(val_gen),
epochs=2,
shuffle=True,
verbose=True,
@@ -119,10 +120,11 @@ test_gen = test_idg.flow_from_directory(
)
len(test_gen.filenames)
score = model.evaluate_generator(test_gen, workers=1)
score = model.evaluate_generator(test_gen, workers=1, steps=len(test_gen))
# predicts
predicts = model.predict_generator(test_gen, verbose=True, workers=1)
predicts = model.predict_generator(test_gen, verbose=True, workers=1, steps=len(test_gen))
keras_file = 'finished.h5'
keras.models.save_model(model, keras_file)