feat: adding resnet and formatting updates

This commit is contained in:
Lucas Oskorep
2023-04-06 00:37:46 -04:00
parent ce5939d8a9
commit dc427837f6
12 changed files with 87 additions and 37 deletions
+1
View File
@@ -7,6 +7,7 @@ import multiprocessing
train_dir = "./data/train/" train_dir = "./data/train/"
test_dir = "./data/test/" test_dir = "./data/test/"
val_dir = "./data/val/" val_dir = "./data/val/"
train = .80 train = .80
test = .10 test = .10
val = .10 val = .10
+57 -17
View File
@@ -53,7 +53,7 @@ def get_gen(path, dataset_type: DatasetType = DatasetType.TRAIN):
) )
def train_model(model, model_name, train_gen, val_gen): def train_model(model, model_name, train_gen, val_gen, max_epochs):
print(model) print(model)
print(f"NOW TRAINING: {model_name}") print(f"NOW TRAINING: {model_name}")
checkpoint = keras.callbacks.ModelCheckpoint( checkpoint = keras.callbacks.ModelCheckpoint(
@@ -82,7 +82,7 @@ def train_model(model, model_name, train_gen, val_gen):
model.fit( model.fit(
train_gen, train_gen,
validation_data=val_gen, validation_data=val_gen,
epochs=100, epochs=max_epochs,
batch_size=batch_size, batch_size=batch_size,
shuffle=True, shuffle=True,
verbose=True, verbose=True,
@@ -130,7 +130,47 @@ if __name__ == "__main__":
pre_trained=True, pre_trained=True,
freeze_layers=True, freeze_layers=True,
freeze_batch_norm=True, freeze_batch_norm=True,
base_model_type=ImageClassModels.EFFICIENTNET_V2S, base_model_type=ImageClassModels.MOBILENET_V2,
dense_layer_neurons=1024,
dropout_rate=.5,
), ImageClassModelBuilder(
input_shape=input_shape,
n_classes=807,
optimizer=keras.optimizers.Adam(learning_rate=.0001),
pre_trained=True,
freeze_layers=True,
freeze_batch_norm=True,
base_model_type=ImageClassModels.INCEPTION_RESNET_V2,
dense_layer_neurons=1024,
dropout_rate=.5,
), ImageClassModelBuilder(
input_shape=input_shape,
n_classes=807,
optimizer=keras.optimizers.Adam(learning_rate=.0001),
pre_trained=True,
freeze_layers=True,
freeze_batch_norm=True,
base_model_type=ImageClassModels.INCEPTION_V3,
dense_layer_neurons=1024,
dropout_rate=.5,
), ImageClassModelBuilder(
input_shape=input_shape,
n_classes=807,
optimizer=keras.optimizers.Adam(learning_rate=.0001),
pre_trained=True,
freeze_layers=True,
freeze_batch_norm=True,
base_model_type=ImageClassModels.XCEPTION,
dense_layer_neurons=1024,
dropout_rate=.5,
), ImageClassModelBuilder(
input_shape=input_shape,
n_classes=807,
optimizer=keras.optimizers.Adam(learning_rate=.0001),
pre_trained=True,
freeze_layers=True,
freeze_batch_norm=True,
base_model_type=ImageClassModels.DENSENET201,
dense_layer_neurons=1024, dense_layer_neurons=1024,
dropout_rate=.5, dropout_rate=.5,
) )
@@ -141,17 +181,17 @@ if __name__ == "__main__":
train_gen = get_gen('./data/train', dataset_type=DatasetType.TRAIN) train_gen = get_gen('./data/train', dataset_type=DatasetType.TRAIN)
val_gen = get_gen('./data/val', dataset_type=DatasetType.VAL) val_gen = get_gen('./data/val', dataset_type=DatasetType.VAL)
test_gen = get_gen('./data/test', dataset_type=DatasetType.TEST) test_gen = get_gen('./data/test', dataset_type=DatasetType.TEST)
model = train_model(model, model_name, train_gen, val_gen) model = train_model(model, model_name, train_gen, val_gen, 1)
for layer in model.layers[2].layers: # for layer in model.layers[2].layers:
if not isinstance(layer, keras.layers.BatchNormalization): # if not isinstance(layer, keras.layers.BatchNormalization):
layer.trainable = True # layer.trainable = True
model.layers[2].trainable = True # model.layers[2].trainable = True
print(model) # print(model)
model.compile( # model.compile(
optimizer=keras.optimizers.Adam(learning_rate=.00001), # optimizer=keras.optimizers.Adam(learning_rate=.00001),
loss=keras.losses.CategoricalCrossentropy(), # loss=keras.losses.CategoricalCrossentropy(),
metrics=['accuracy', 'categorical_crossentropy'] # metrics=['accuracy', 'categorical_crossentropy']
) # )
model.summary() # model.summary()
model = train_model(model, model_name + "-second_stage", train_gen, val_gen) # model = train_model(model, model_name + "-second_stage", train_gen, val_gen, 1)
test_model(model, test_gen) # test_model(model, test_gen)
+1 -2
View File
@@ -24,7 +24,7 @@ test_gen = ImageDataGenerator().flow_from_directory(
batch_size=batch_size, batch_size=batch_size,
shuffle=False shuffle=False
) )
#
single_gen = ImageDataGenerator().flow_from_directory( single_gen = ImageDataGenerator().flow_from_directory(
'./single_image_test_set', './single_image_test_set',
target_size=(input_shape[0], input_shape[1]), target_size=(input_shape[0], input_shape[1]),
@@ -32,7 +32,6 @@ single_gen = ImageDataGenerator().flow_from_directory(
shuffle=False shuffle=False
) )
for file in glob("./models/keras/*.hdf5"): for file in glob("./models/keras/*.hdf5"):
print(file) print(file)
if file in metrics_df.values: if file in metrics_df.values:
+6 -3
View File
@@ -5,7 +5,7 @@ from pathlib import Path
import pandas as pd import pandas as pd
import tensorflow as tf import tensorflow as tf
from keras.preprocessing.image import ImageDataGenerator from keras.preprocessing.image import ImageDataGenerator
from tensorflow import keras import tensorflow as tf
# TODO: Move these to a config for the project # TODO: Move these to a config for the project
input_shape = (224, 224, 3) input_shape = (224, 224, 3)
@@ -25,8 +25,11 @@ for file in glob("./models/keras/*.hdf5"):
path = Path(file) path = Path(file)
tflite_file = f'./models/tflite/models/{path.name[:-5] + ".tflite"}' tflite_file = f'./models/tflite/models/{path.name[:-5] + ".tflite"}'
if not Path(tflite_file).exists(): if not Path(tflite_file).exists():
keras_model = keras.models.load_model(file) print(tflite_file)
keras_model = tf.keras.models.load_model(file)
keras_model.summary()
print(keras_model.input)
print(keras_model.layers)
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
tflite_model = converter.convert() tflite_model = converter.convert()
with open(tflite_file, 'wb') as f: with open(tflite_file, 'wb') as f:
+3 -1
View File
@@ -2,4 +2,6 @@ model,test_acc,test_loss,single_acc,single_loss
./models/keras\pt-fl-fbn-efficientnet_v2b0-d1024-do0.5-l11.e-04-l21.e-04-5224-second_stage.hdf5,0.6720150708068079,1.7423864365349095,0.9893048128342246,0.4364729183409372 ./models/keras\pt-fl-fbn-efficientnet_v2b0-d1024-do0.5-l11.e-04-l21.e-04-5224-second_stage.hdf5,0.6720150708068079,1.7423864365349095,0.9893048128342246,0.4364729183409372
./models/keras\pt-fl-fbn-efficientnet_v2b0-d1024-do0.5-l11.e-04-l21.e-04-5224.hdf5,0.410029881772119,3.346152696366266,0.986096256684492,0.3234976000776315 ./models/keras\pt-fl-fbn-efficientnet_v2b0-d1024-do0.5-l11.e-04-l21.e-04-5224.hdf5,0.410029881772119,3.346152696366266,0.986096256684492,0.3234976000776315
./models/keras\pt-fl-fbn-efficientnet_v2s-d1024-do0.5-l11.e-04-l21.e-04-8105-second_stage.hdf5,0.6850721060153306,1.675868156533777,0.9967914438502674,0.3373779159304851 ./models/keras\pt-fl-fbn-efficientnet_v2s-d1024-do0.5-l11.e-04-l21.e-04-8105-second_stage.hdf5,0.6850721060153306,1.675868156533777,0.9967914438502674,0.3373779159304851
./models/keras\pt-fl-fbn-efficientnet_v2s-d1024-do0.5-l11.e-04-l21.e-04-8105.hdf5,0.37553592308691697,3.5500588697038067,0.9540106951871657,0.47270425785037834 ./models/keras\pt-fl-fbn-efficientnet_v2s-d1024-do0.5-l11.e-04-l21.e-04-8105.hdf5,0.3755359230869169,3.5500588697038067,0.9540106951871656,0.4727042578503783
./models/keras\pt-fl-fbn-efficientnet_v2s-d1024-do0.5-l11.e-04-l21.e-04-9317-second_stage.hdf5,0.6121780461172843,2.197206965588216,0.9946581196581196,0.2974041509252359
./models/keras\pt-fl-fbn-efficientnet_v2s-d1024-do0.5-l11.e-04-l21.e-04-9317.hdf5,0.3702228787976106,3.601324427207316,0.9594017094017094,0.4877960320956891
1 model test_acc test_loss single_acc single_loss
2 ./models/keras\pt-fl-fbn-efficientnet_v2b0-d1024-do0.5-l11.e-04-l21.e-04-5224-second_stage.hdf5 0.6720150708068079 1.7423864365349095 0.9893048128342246 0.4364729183409372
3 ./models/keras\pt-fl-fbn-efficientnet_v2b0-d1024-do0.5-l11.e-04-l21.e-04-5224.hdf5 0.410029881772119 3.346152696366266 0.986096256684492 0.3234976000776315
4 ./models/keras\pt-fl-fbn-efficientnet_v2s-d1024-do0.5-l11.e-04-l21.e-04-8105-second_stage.hdf5 0.6850721060153306 1.675868156533777 0.9967914438502674 0.3373779159304851
5 ./models/keras\pt-fl-fbn-efficientnet_v2s-d1024-do0.5-l11.e-04-l21.e-04-8105.hdf5 0.37553592308691697 0.3755359230869169 3.5500588697038067 0.9540106951871657 0.9540106951871656 0.47270425785037834 0.4727042578503783
6 ./models/keras\pt-fl-fbn-efficientnet_v2s-d1024-do0.5-l11.e-04-l21.e-04-9317-second_stage.hdf5 0.6121780461172843 2.197206965588216 0.9946581196581196 0.2974041509252359
7 ./models/keras\pt-fl-fbn-efficientnet_v2s-d1024-do0.5-l11.e-04-l21.e-04-9317.hdf5 0.3702228787976106 3.601324427207316 0.9594017094017094 0.4877960320956891
-3
View File
@@ -1,3 +0,0 @@
import pandas as pd
import os
+10 -1
View File
@@ -15,11 +15,21 @@ class ImageClassModels(Enum):
keras.applications.inception_v3.preprocess_input, keras.applications.inception_v3.preprocess_input,
"inception_v3" "inception_v3"
) )
INCEPTION_RESNET_V2 = ModelWrapper(
keras.applications.inception_resnet_v2.InceptionResNetV2,
keras.applications.inception_resnet_v2.preprocess_input,
"inception_resnet_v2"
)
XCEPTION = ModelWrapper( XCEPTION = ModelWrapper(
keras.applications.xception.Xception, keras.applications.xception.Xception,
keras.applications.xception.preprocess_input, keras.applications.xception.preprocess_input,
"xception" "xception"
) )
DENSENET201 = ModelWrapper(
keras.applications.densenet.DenseNet201,
keras.applications.densenet.preprocess_input,
"densenet201"
)
MOBILENET_V2 = ModelWrapper( MOBILENET_V2 = ModelWrapper(
keras.applications.mobilenet_v2.MobileNetV2, keras.applications.mobilenet_v2.MobileNetV2,
keras.applications.mobilenet_v2.preprocess_input, keras.applications.mobilenet_v2.preprocess_input,
@@ -34,7 +44,6 @@ class ImageClassModels(Enum):
keras.applications.efficientnet_v2.EfficientNetV2B0, keras.applications.efficientnet_v2.EfficientNetV2B0,
tf.keras.applications.efficientnet_v2.preprocess_input, tf.keras.applications.efficientnet_v2.preprocess_input,
"efficientnet_v2b0" "efficientnet_v2b0"
) )
+1 -2
View File
@@ -1,5 +1,4 @@
from collections import Callable from typing import Callable
class ModelWrapper(object): class ModelWrapper(object):
def __init__(self, model_func:Callable, model_preprocessor:Callable, name:str): def __init__(self, model_func:Callable, model_preprocessor:Callable, name:str):
+1
View File
@@ -1,4 +1,5 @@
import pandas as pd import pandas as pd
df = pd.read_csv("models/keras/pt-fl-fbn-efficientnet_v2s-d1024-do0.5-l11.e-04-l21.e-04-8105-second_stage.csv") df = pd.read_csv("models/keras/pt-fl-fbn-efficientnet_v2s-d1024-do0.5-l11.e-04-l21.e-04-8105-second_stage.csv")
print(df.loc[df["prediction"] != df["true_val"]]) print(df.loc[df["prediction"] != df["true_val"]])
-1
View File
@@ -19,7 +19,6 @@ for index, row in df2.iterrows():
incorrect = df[df["prediction"]!= df["true_val"]] incorrect = df[df["prediction"]!= df["true_val"]]
total_same_fam = 0 total_same_fam = 0
# TODO: Add in support for figuring out if the pokemon are related/evolutions of one another
for index, row in incorrect.iterrows(): for index, row in incorrect.iterrows():
img = mpimg.imread("./SingleImageTestSet/" + row['fname']) img = mpimg.imread("./SingleImageTestSet/" + row['fname'])
imgplot = plt.imshow(img) imgplot = plt.imshow(img)
Binary file not shown.

After

Width:  |  Height:  |  Size: 90 KiB