Updated training for new KERAS API
This commit is contained in:
@@ -7,7 +7,6 @@ import matplotlib.pyplot as plt
|
||||
from tensorflow import keras
|
||||
from time import time
|
||||
from PIL import ImageFile
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
input_shape = (224, 224, 3)
|
||||
@@ -18,7 +17,7 @@ from keras.preprocessing.image import ImageDataGenerator
|
||||
from keras.applications.inception_v3 import preprocess_input
|
||||
|
||||
train_idg = ImageDataGenerator(
|
||||
horizontal_flip=True,
|
||||
# horizontal_flip=True,
|
||||
preprocessing_function=preprocess_input
|
||||
)
|
||||
train_gen = train_idg.flow_from_directory(
|
||||
@@ -28,7 +27,7 @@ train_gen = train_idg.flow_from_directory(
|
||||
)
|
||||
|
||||
val_idg = ImageDataGenerator(
|
||||
horizontal_flip=True,
|
||||
# horizontal_flip=True,
|
||||
preprocessing_function=preprocess_input
|
||||
)
|
||||
|
||||
@@ -88,7 +87,7 @@ checkpoint = ModelCheckpoint(file_path, monitor='val_acc', verbose=1, save_best_
|
||||
early = EarlyStopping(monitor="val_acc", mode="max", patience=15)
|
||||
|
||||
tensorboard = TensorBoard(
|
||||
log_dir="logs/" + model_name + "{}".format(time()), histogram_freq=0, batch_size=32,
|
||||
log_dir="logs/" + model_name + "{}".format(time()), histogram_freq=0, batch_size=batch_size,
|
||||
write_graph=True,
|
||||
write_grads=True,
|
||||
write_images=True,
|
||||
@@ -100,6 +99,8 @@ callbacks_list = [checkpoint, early, tensorboard] # early
|
||||
history = model.fit_generator(
|
||||
train_gen,
|
||||
validation_data=val_gen,
|
||||
steps_per_epoch=len(train_gen),
|
||||
validation_steps=len(val_gen),
|
||||
epochs=2,
|
||||
shuffle=True,
|
||||
verbose=True,
|
||||
@@ -119,10 +120,11 @@ test_gen = test_idg.flow_from_directory(
|
||||
)
|
||||
len(test_gen.filenames)
|
||||
|
||||
score = model.evaluate_generator(test_gen, workers=1)
|
||||
score = model.evaluate_generator(test_gen, workers=1, steps=len(test_gen))
|
||||
|
||||
# predicts
|
||||
predicts = model.predict_generator(test_gen, verbose=True, workers=1)
|
||||
predicts = model.predict_generator(test_gen, verbose=True, workers=1, steps=len(test_gen))
|
||||
|
||||
|
||||
keras_file = 'finished.h5'
|
||||
keras.models.save_model(model, keras_file)
|
||||
|
||||
Reference in New Issue
Block a user