Updated training for new KERAS API
This commit is contained in:
@@ -7,7 +7,6 @@ import matplotlib.pyplot as plt
|
|||||||
from tensorflow import keras
|
from tensorflow import keras
|
||||||
from time import time
|
from time import time
|
||||||
from PIL import ImageFile
|
from PIL import ImageFile
|
||||||
|
|
||||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||||
|
|
||||||
input_shape = (224, 224, 3)
|
input_shape = (224, 224, 3)
|
||||||
@@ -18,7 +17,7 @@ from keras.preprocessing.image import ImageDataGenerator
|
|||||||
from keras.applications.inception_v3 import preprocess_input
|
from keras.applications.inception_v3 import preprocess_input
|
||||||
|
|
||||||
train_idg = ImageDataGenerator(
|
train_idg = ImageDataGenerator(
|
||||||
horizontal_flip=True,
|
# horizontal_flip=True,
|
||||||
preprocessing_function=preprocess_input
|
preprocessing_function=preprocess_input
|
||||||
)
|
)
|
||||||
train_gen = train_idg.flow_from_directory(
|
train_gen = train_idg.flow_from_directory(
|
||||||
@@ -28,7 +27,7 @@ train_gen = train_idg.flow_from_directory(
|
|||||||
)
|
)
|
||||||
|
|
||||||
val_idg = ImageDataGenerator(
|
val_idg = ImageDataGenerator(
|
||||||
horizontal_flip=True,
|
# horizontal_flip=True,
|
||||||
preprocessing_function=preprocess_input
|
preprocessing_function=preprocess_input
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -88,7 +87,7 @@ checkpoint = ModelCheckpoint(file_path, monitor='val_acc', verbose=1, save_best_
|
|||||||
early = EarlyStopping(monitor="val_acc", mode="max", patience=15)
|
early = EarlyStopping(monitor="val_acc", mode="max", patience=15)
|
||||||
|
|
||||||
tensorboard = TensorBoard(
|
tensorboard = TensorBoard(
|
||||||
log_dir="logs/" + model_name + "{}".format(time()), histogram_freq=0, batch_size=32,
|
log_dir="logs/" + model_name + "{}".format(time()), histogram_freq=0, batch_size=batch_size,
|
||||||
write_graph=True,
|
write_graph=True,
|
||||||
write_grads=True,
|
write_grads=True,
|
||||||
write_images=True,
|
write_images=True,
|
||||||
@@ -100,6 +99,8 @@ callbacks_list = [checkpoint, early, tensorboard] # early
|
|||||||
history = model.fit_generator(
|
history = model.fit_generator(
|
||||||
train_gen,
|
train_gen,
|
||||||
validation_data=val_gen,
|
validation_data=val_gen,
|
||||||
|
steps_per_epoch=len(train_gen),
|
||||||
|
validation_steps=len(val_gen),
|
||||||
epochs=2,
|
epochs=2,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
@@ -119,10 +120,11 @@ test_gen = test_idg.flow_from_directory(
|
|||||||
)
|
)
|
||||||
len(test_gen.filenames)
|
len(test_gen.filenames)
|
||||||
|
|
||||||
score = model.evaluate_generator(test_gen, workers=1)
|
score = model.evaluate_generator(test_gen, workers=1, steps=len(test_gen))
|
||||||
|
|
||||||
# predicts
|
# predicts
|
||||||
predicts = model.predict_generator(test_gen, verbose=True, workers=1)
|
predicts = model.predict_generator(test_gen, verbose=True, workers=1, steps=len(test_gen))
|
||||||
|
|
||||||
|
|
||||||
keras_file = 'finished.h5'
|
keras_file = 'finished.h5'
|
||||||
keras.models.save_model(model, keras_file)
|
keras.models.save_model(model, keras_file)
|
||||||
|
|||||||
@@ -122,10 +122,11 @@ test_gen = test_idg.flow_from_directory(
|
|||||||
|
|
||||||
len(test_gen.filenames)
|
len(test_gen.filenames)
|
||||||
|
|
||||||
score = model.evaluate_generator(test_gen, workers=1)
|
score = model.evaluate_generator(test_gen, workers=1, steps=len(test_gen))
|
||||||
|
|
||||||
# predicts
|
# predicts
|
||||||
predicts = model.predict_generator(test_gen, verbose=True, workers=1)
|
predicts = model.predict_generator(test_gen, verbose=True, workers=1, steps=len(test_gen))
|
||||||
|
|
||||||
|
|
||||||
print("Loss: ", score[0], "Accuracy: ", score[1])
|
print("Loss: ", score[0], "Accuracy: ", score[1])
|
||||||
print(score)
|
print(score)
|
||||||
|
|||||||
Reference in New Issue
Block a user