clean up imports. fix naming, force CPU to fill the cache faster with images using 20 workers.
This commit is contained in:
@@ -54,7 +54,7 @@ def get_gen(path, dataset_type: DatasetType = DatasetType.TRAIN):
|
||||
|
||||
def train_model(model_builder, train_gen, val_gen):
|
||||
model = model_builder.create_model()
|
||||
model_name = "rot-shift-" + model_builder.get_name()
|
||||
model_name = model_builder.get_name()
|
||||
print(model)
|
||||
print(f"NOW TRAINING: {model_name}")
|
||||
checkpoint = keras.callbacks.ModelCheckpoint(
|
||||
@@ -83,11 +83,11 @@ def train_model(model_builder, train_gen, val_gen):
|
||||
history = model.fit(
|
||||
train_gen,
|
||||
validation_data=val_gen,
|
||||
epochs=500,
|
||||
epochs=8,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
verbose=True,
|
||||
workers=12,
|
||||
workers=20,
|
||||
callbacks=[checkpoint, early, tensorboard],
|
||||
max_queue_size=1000
|
||||
)
|
||||
@@ -134,7 +134,7 @@ if __name__ == "__main__":
|
||||
fine_tune=True,
|
||||
base_model_type=ImageClassModels.MOBILENET_V2,
|
||||
dense_layer_neurons=1024,
|
||||
dropout_rate=.33,
|
||||
dropout_rate=.5,
|
||||
)
|
||||
]
|
||||
for mb in model_builders:
|
||||
|
||||
Reference in New Issue
Block a user