From 755fcde3a9e9c616ae592109402862df370ded6d Mon Sep 17 00:00:00 2001 From: Lucas Date: Wed, 1 Jun 2022 18:52:55 -0400 Subject: [PATCH] clean up imports. fix naming, force CPU to fill the cache faster with images using 20 workers. --- 4_train_keras_model.py | 8 +++---- 5_test_models.py | 32 +++++++--------------------- model_builder/image_class_builder.py | 7 +++--- 3 files changed, 16 insertions(+), 31 deletions(-) diff --git a/4_train_keras_model.py b/4_train_keras_model.py index 702a255..9641b32 100644 --- a/4_train_keras_model.py +++ b/4_train_keras_model.py @@ -54,7 +54,7 @@ def get_gen(path, dataset_type: DatasetType = DatasetType.TRAIN): def train_model(model_builder, train_gen, val_gen): model = model_builder.create_model() - model_name = "rot-shift-" + model_builder.get_name() + model_name = model_builder.get_name() print(model) print(f"NOW TRAINING: {model_name}") checkpoint = keras.callbacks.ModelCheckpoint( @@ -83,11 +83,11 @@ def train_model(model_builder, train_gen, val_gen): history = model.fit( train_gen, validation_data=val_gen, - epochs=500, + epochs=8, batch_size=batch_size, shuffle=True, verbose=True, - workers=12, + workers=20, callbacks=[checkpoint, early, tensorboard], max_queue_size=1000 ) @@ -134,7 +134,7 @@ if __name__ == "__main__": fine_tune=True, base_model_type=ImageClassModels.MOBILENET_V2, dense_layer_neurons=1024, - dropout_rate=.33, + dropout_rate=.5, ) ] for mb in model_builders: diff --git a/5_test_models.py b/5_test_models.py index 90f3cb0..731a3a3 100644 --- a/5_test_models.py +++ b/5_test_models.py @@ -15,15 +15,13 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True accuracies = [] losses = [] filenames = [] -test_idg = ImageDataGenerator( -) input_shape = (224, 224, 3) batch_size = 32 -test_gen = test_idg.flow_from_directory( - # './data/test', - './single_image_test_set', +test_gen = ImageDataGenerator().flow_from_directory( + './data/test', + # './single_image_test_set', target_size=(input_shape[0], input_shape[1]), batch_size=batch_size, shuffle=False @@ -34,17 +32,15 @@ for file in glob("./models/keras/*"): print(file) model = load_model(file) - - - predictions = model.predict(test_gen, verbose=True, workers=12, steps=len(test_gen)) + predictions = model.predict(test_gen, verbose=True, workers=12) print(predictions) print(type(predictions)) print(predictions.shape) + # Process the predictions predictions = np.argmax(predictions, axis=1) - # test_gen.reset() label_index = {v: k for k, v in test_gen.class_indices.items()} predictions = [label_index[p] for p in predictions] reals = [label_index[p] for p in test_gen.classes] @@ -69,24 +65,12 @@ for file in glob("./models/keras/*"): print("Confusion Matrix", conf_mat) accuracies.append(acc) - # df_cm = pd.DataFrame(conf_mat, index=[i for i in list(set(reals))], - # columns=[i for i in list(set(reals))]) - # print("made dataframe") - # plt.figure(figsize=(10, 7)) - # print("made plot") - # # sn.heatmap(df_cm, annot=True) - # print("showing plot") - # plt.show() - - - # with open("labels.txt", "w") as f: - # for label in label_index.values(): - # f.write(label + "\n") overall_df = pd.DataFrame(list(zip(filenames, accuracies)), columns =['model', 'acc']).sort_values('acc') print(overall_df) -overall_df.to_csv("all_model_output.csv") -overall_df.plot.bar(x="model", y="acc", rot=0) +overall_df.plot.bar(y="acc", rot=90) +plt.tight_layout() plt.show() +overall_df.to_csv("all_model_output.csv") diff --git a/model_builder/image_class_builder.py b/model_builder/image_class_builder.py index d14f903..caec94d 100644 --- a/model_builder/image_class_builder.py +++ b/model_builder/image_class_builder.py @@ -1,7 +1,8 @@ +import random from enum import Enum -from time import time from typing import Tuple +import numpy as np import tensorflow as tf from tensorflow import keras @@ -94,5 +95,5 @@ class ImageClassModelBuilder(object): def get_name(self): return f"{'pt-' if self.pre_trained else ''}{'ft-' if self.fine_tune else ''}" \ f"{self.base_model_type.value.name}-d{self.dense_layer_neurons}-do{self.dropout_rate}" \ - f"{'-l1' + str(self.l1) if self.l1 > 0 else ''}{'-l2' + str(self.l2) if self.l2 > 0 else ''}" \ - f"-{int(time())}" + f"{'-l1' + np.format_float_scientific(self.l1) if self.l1 > 0 else ''}{'-l2' + np.format_float_scientific(self.l2) if self.l2 > 0 else ''}" \ + f"-{random.randint(1111, 9999)}"