clean up imports. fix naming, force CPU to fill the cache faster with images using 20 workers.
This commit is contained in:
+8
-24
@@ -15,15 +15,13 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
accuracies = []
|
||||
losses = []
|
||||
filenames = []
|
||||
test_idg = ImageDataGenerator(
|
||||
)
|
||||
|
||||
input_shape = (224, 224, 3)
|
||||
batch_size = 32
|
||||
|
||||
test_gen = test_idg.flow_from_directory(
|
||||
# './data/test',
|
||||
'./single_image_test_set',
|
||||
test_gen = ImageDataGenerator().flow_from_directory(
|
||||
'./data/test',
|
||||
# './single_image_test_set',
|
||||
target_size=(input_shape[0], input_shape[1]),
|
||||
batch_size=batch_size,
|
||||
shuffle=False
|
||||
@@ -34,17 +32,15 @@ for file in glob("./models/keras/*"):
|
||||
print(file)
|
||||
model = load_model(file)
|
||||
|
||||
|
||||
|
||||
predictions = model.predict(test_gen, verbose=True, workers=12, steps=len(test_gen))
|
||||
predictions = model.predict(test_gen, verbose=True, workers=12)
|
||||
|
||||
print(predictions)
|
||||
print(type(predictions))
|
||||
print(predictions.shape)
|
||||
|
||||
# Process the predictions
|
||||
predictions = np.argmax(predictions,
|
||||
axis=1)
|
||||
# test_gen.reset()
|
||||
label_index = {v: k for k, v in test_gen.class_indices.items()}
|
||||
predictions = [label_index[p] for p in predictions]
|
||||
reals = [label_index[p] for p in test_gen.classes]
|
||||
@@ -69,24 +65,12 @@ for file in glob("./models/keras/*"):
|
||||
print("Confusion Matrix", conf_mat)
|
||||
|
||||
accuracies.append(acc)
|
||||
# df_cm = pd.DataFrame(conf_mat, index=[i for i in list(set(reals))],
|
||||
# columns=[i for i in list(set(reals))])
|
||||
# print("made dataframe")
|
||||
# plt.figure(figsize=(10, 7))
|
||||
# print("made plot")
|
||||
# # sn.heatmap(df_cm, annot=True)
|
||||
# print("showing plot")
|
||||
# plt.show()
|
||||
|
||||
|
||||
# with open("labels.txt", "w") as f:
|
||||
# for label in label_index.values():
|
||||
# f.write(label + "\n")
|
||||
|
||||
overall_df = pd.DataFrame(list(zip(filenames, accuracies)),
|
||||
columns =['model', 'acc']).sort_values('acc')
|
||||
|
||||
print(overall_df)
|
||||
overall_df.to_csv("all_model_output.csv")
|
||||
overall_df.plot.bar(x="model", y="acc", rot=0)
|
||||
overall_df.plot.bar(y="acc", rot=90)
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
overall_df.to_csv("all_model_output.csv")
|
||||
|
||||
Reference in New Issue
Block a user