clean up imports. fix naming, force CPU to fill the cache faster with images using 20 workers.

This commit is contained in:
Lucas
2022-06-01 18:52:55 -04:00
parent 1b539d6945
commit 755fcde3a9
3 changed files with 16 additions and 31 deletions
+4 -4
View File
@@ -54,7 +54,7 @@ def get_gen(path, dataset_type: DatasetType = DatasetType.TRAIN):
def train_model(model_builder, train_gen, val_gen):
model = model_builder.create_model()
model_name = "rot-shift-" + model_builder.get_name()
model_name = model_builder.get_name()
print(model)
print(f"NOW TRAINING: {model_name}")
checkpoint = keras.callbacks.ModelCheckpoint(
@@ -83,11 +83,11 @@ def train_model(model_builder, train_gen, val_gen):
history = model.fit(
train_gen,
validation_data=val_gen,
epochs=500,
epochs=8,
batch_size=batch_size,
shuffle=True,
verbose=True,
workers=12,
workers=20,
callbacks=[checkpoint, early, tensorboard],
max_queue_size=1000
)
@@ -134,7 +134,7 @@ if __name__ == "__main__":
fine_tune=True,
base_model_type=ImageClassModels.MOBILENET_V2,
dense_layer_neurons=1024,
dropout_rate=.33,
dropout_rate=.5,
)
]
for mb in model_builders: