clean up imports. fix naming, force CPU to fill the cache faster with images using 20 workers.
This commit is contained in:
@@ -54,7 +54,7 @@ def get_gen(path, dataset_type: DatasetType = DatasetType.TRAIN):
|
|||||||
|
|
||||||
def train_model(model_builder, train_gen, val_gen):
|
def train_model(model_builder, train_gen, val_gen):
|
||||||
model = model_builder.create_model()
|
model = model_builder.create_model()
|
||||||
model_name = "rot-shift-" + model_builder.get_name()
|
model_name = model_builder.get_name()
|
||||||
print(model)
|
print(model)
|
||||||
print(f"NOW TRAINING: {model_name}")
|
print(f"NOW TRAINING: {model_name}")
|
||||||
checkpoint = keras.callbacks.ModelCheckpoint(
|
checkpoint = keras.callbacks.ModelCheckpoint(
|
||||||
@@ -83,11 +83,11 @@ def train_model(model_builder, train_gen, val_gen):
|
|||||||
history = model.fit(
|
history = model.fit(
|
||||||
train_gen,
|
train_gen,
|
||||||
validation_data=val_gen,
|
validation_data=val_gen,
|
||||||
epochs=500,
|
epochs=8,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
workers=12,
|
workers=20,
|
||||||
callbacks=[checkpoint, early, tensorboard],
|
callbacks=[checkpoint, early, tensorboard],
|
||||||
max_queue_size=1000
|
max_queue_size=1000
|
||||||
)
|
)
|
||||||
@@ -134,7 +134,7 @@ if __name__ == "__main__":
|
|||||||
fine_tune=True,
|
fine_tune=True,
|
||||||
base_model_type=ImageClassModels.MOBILENET_V2,
|
base_model_type=ImageClassModels.MOBILENET_V2,
|
||||||
dense_layer_neurons=1024,
|
dense_layer_neurons=1024,
|
||||||
dropout_rate=.33,
|
dropout_rate=.5,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
for mb in model_builders:
|
for mb in model_builders:
|
||||||
|
|||||||
+8
-24
@@ -15,15 +15,13 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|||||||
accuracies = []
|
accuracies = []
|
||||||
losses = []
|
losses = []
|
||||||
filenames = []
|
filenames = []
|
||||||
test_idg = ImageDataGenerator(
|
|
||||||
)
|
|
||||||
|
|
||||||
input_shape = (224, 224, 3)
|
input_shape = (224, 224, 3)
|
||||||
batch_size = 32
|
batch_size = 32
|
||||||
|
|
||||||
test_gen = test_idg.flow_from_directory(
|
test_gen = ImageDataGenerator().flow_from_directory(
|
||||||
# './data/test',
|
'./data/test',
|
||||||
'./single_image_test_set',
|
# './single_image_test_set',
|
||||||
target_size=(input_shape[0], input_shape[1]),
|
target_size=(input_shape[0], input_shape[1]),
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
shuffle=False
|
shuffle=False
|
||||||
@@ -34,17 +32,15 @@ for file in glob("./models/keras/*"):
|
|||||||
print(file)
|
print(file)
|
||||||
model = load_model(file)
|
model = load_model(file)
|
||||||
|
|
||||||
|
predictions = model.predict(test_gen, verbose=True, workers=12)
|
||||||
|
|
||||||
predictions = model.predict(test_gen, verbose=True, workers=12, steps=len(test_gen))
|
|
||||||
|
|
||||||
print(predictions)
|
print(predictions)
|
||||||
print(type(predictions))
|
print(type(predictions))
|
||||||
print(predictions.shape)
|
print(predictions.shape)
|
||||||
|
|
||||||
# Process the predictions
|
# Process the predictions
|
||||||
predictions = np.argmax(predictions,
|
predictions = np.argmax(predictions,
|
||||||
axis=1)
|
axis=1)
|
||||||
# test_gen.reset()
|
|
||||||
label_index = {v: k for k, v in test_gen.class_indices.items()}
|
label_index = {v: k for k, v in test_gen.class_indices.items()}
|
||||||
predictions = [label_index[p] for p in predictions]
|
predictions = [label_index[p] for p in predictions]
|
||||||
reals = [label_index[p] for p in test_gen.classes]
|
reals = [label_index[p] for p in test_gen.classes]
|
||||||
@@ -69,24 +65,12 @@ for file in glob("./models/keras/*"):
|
|||||||
print("Confusion Matrix", conf_mat)
|
print("Confusion Matrix", conf_mat)
|
||||||
|
|
||||||
accuracies.append(acc)
|
accuracies.append(acc)
|
||||||
# df_cm = pd.DataFrame(conf_mat, index=[i for i in list(set(reals))],
|
|
||||||
# columns=[i for i in list(set(reals))])
|
|
||||||
# print("made dataframe")
|
|
||||||
# plt.figure(figsize=(10, 7))
|
|
||||||
# print("made plot")
|
|
||||||
# # sn.heatmap(df_cm, annot=True)
|
|
||||||
# print("showing plot")
|
|
||||||
# plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
# with open("labels.txt", "w") as f:
|
|
||||||
# for label in label_index.values():
|
|
||||||
# f.write(label + "\n")
|
|
||||||
|
|
||||||
overall_df = pd.DataFrame(list(zip(filenames, accuracies)),
|
overall_df = pd.DataFrame(list(zip(filenames, accuracies)),
|
||||||
columns =['model', 'acc']).sort_values('acc')
|
columns =['model', 'acc']).sort_values('acc')
|
||||||
|
|
||||||
print(overall_df)
|
print(overall_df)
|
||||||
overall_df.to_csv("all_model_output.csv")
|
overall_df.plot.bar(y="acc", rot=90)
|
||||||
overall_df.plot.bar(x="model", y="acc", rot=0)
|
plt.tight_layout()
|
||||||
plt.show()
|
plt.show()
|
||||||
|
overall_df.to_csv("all_model_output.csv")
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
|
import random
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from time import time
|
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow import keras
|
from tensorflow import keras
|
||||||
|
|
||||||
@@ -94,5 +95,5 @@ class ImageClassModelBuilder(object):
|
|||||||
def get_name(self):
|
def get_name(self):
|
||||||
return f"{'pt-' if self.pre_trained else ''}{'ft-' if self.fine_tune else ''}" \
|
return f"{'pt-' if self.pre_trained else ''}{'ft-' if self.fine_tune else ''}" \
|
||||||
f"{self.base_model_type.value.name}-d{self.dense_layer_neurons}-do{self.dropout_rate}" \
|
f"{self.base_model_type.value.name}-d{self.dense_layer_neurons}-do{self.dropout_rate}" \
|
||||||
f"{'-l1' + str(self.l1) if self.l1 > 0 else ''}{'-l2' + str(self.l2) if self.l2 > 0 else ''}" \
|
f"{'-l1' + np.format_float_scientific(self.l1) if self.l1 > 0 else ''}{'-l2' + np.format_float_scientific(self.l2) if self.l2 > 0 else ''}" \
|
||||||
f"-{int(time())}"
|
f"-{random.randint(1111, 9999)}"
|
||||||
|
|||||||
Reference in New Issue
Block a user