feat: adding resnet and formatting updates
This commit is contained in:
@@ -7,6 +7,7 @@ import multiprocessing
|
|||||||
train_dir = "./data/train/"
|
train_dir = "./data/train/"
|
||||||
test_dir = "./data/test/"
|
test_dir = "./data/test/"
|
||||||
val_dir = "./data/val/"
|
val_dir = "./data/val/"
|
||||||
|
|
||||||
train = .80
|
train = .80
|
||||||
test = .10
|
test = .10
|
||||||
val = .10
|
val = .10
|
||||||
|
|||||||
+57
-17
@@ -53,7 +53,7 @@ def get_gen(path, dataset_type: DatasetType = DatasetType.TRAIN):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def train_model(model, model_name, train_gen, val_gen):
|
def train_model(model, model_name, train_gen, val_gen, max_epochs):
|
||||||
print(model)
|
print(model)
|
||||||
print(f"NOW TRAINING: {model_name}")
|
print(f"NOW TRAINING: {model_name}")
|
||||||
checkpoint = keras.callbacks.ModelCheckpoint(
|
checkpoint = keras.callbacks.ModelCheckpoint(
|
||||||
@@ -82,7 +82,7 @@ def train_model(model, model_name, train_gen, val_gen):
|
|||||||
model.fit(
|
model.fit(
|
||||||
train_gen,
|
train_gen,
|
||||||
validation_data=val_gen,
|
validation_data=val_gen,
|
||||||
epochs=100,
|
epochs=max_epochs,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
@@ -130,7 +130,47 @@ if __name__ == "__main__":
|
|||||||
pre_trained=True,
|
pre_trained=True,
|
||||||
freeze_layers=True,
|
freeze_layers=True,
|
||||||
freeze_batch_norm=True,
|
freeze_batch_norm=True,
|
||||||
base_model_type=ImageClassModels.EFFICIENTNET_V2S,
|
base_model_type=ImageClassModels.MOBILENET_V2,
|
||||||
|
dense_layer_neurons=1024,
|
||||||
|
dropout_rate=.5,
|
||||||
|
), ImageClassModelBuilder(
|
||||||
|
input_shape=input_shape,
|
||||||
|
n_classes=807,
|
||||||
|
optimizer=keras.optimizers.Adam(learning_rate=.0001),
|
||||||
|
pre_trained=True,
|
||||||
|
freeze_layers=True,
|
||||||
|
freeze_batch_norm=True,
|
||||||
|
base_model_type=ImageClassModels.INCEPTION_RESNET_V2,
|
||||||
|
dense_layer_neurons=1024,
|
||||||
|
dropout_rate=.5,
|
||||||
|
), ImageClassModelBuilder(
|
||||||
|
input_shape=input_shape,
|
||||||
|
n_classes=807,
|
||||||
|
optimizer=keras.optimizers.Adam(learning_rate=.0001),
|
||||||
|
pre_trained=True,
|
||||||
|
freeze_layers=True,
|
||||||
|
freeze_batch_norm=True,
|
||||||
|
base_model_type=ImageClassModels.INCEPTION_V3,
|
||||||
|
dense_layer_neurons=1024,
|
||||||
|
dropout_rate=.5,
|
||||||
|
), ImageClassModelBuilder(
|
||||||
|
input_shape=input_shape,
|
||||||
|
n_classes=807,
|
||||||
|
optimizer=keras.optimizers.Adam(learning_rate=.0001),
|
||||||
|
pre_trained=True,
|
||||||
|
freeze_layers=True,
|
||||||
|
freeze_batch_norm=True,
|
||||||
|
base_model_type=ImageClassModels.XCEPTION,
|
||||||
|
dense_layer_neurons=1024,
|
||||||
|
dropout_rate=.5,
|
||||||
|
), ImageClassModelBuilder(
|
||||||
|
input_shape=input_shape,
|
||||||
|
n_classes=807,
|
||||||
|
optimizer=keras.optimizers.Adam(learning_rate=.0001),
|
||||||
|
pre_trained=True,
|
||||||
|
freeze_layers=True,
|
||||||
|
freeze_batch_norm=True,
|
||||||
|
base_model_type=ImageClassModels.DENSENET201,
|
||||||
dense_layer_neurons=1024,
|
dense_layer_neurons=1024,
|
||||||
dropout_rate=.5,
|
dropout_rate=.5,
|
||||||
)
|
)
|
||||||
@@ -141,17 +181,17 @@ if __name__ == "__main__":
|
|||||||
train_gen = get_gen('./data/train', dataset_type=DatasetType.TRAIN)
|
train_gen = get_gen('./data/train', dataset_type=DatasetType.TRAIN)
|
||||||
val_gen = get_gen('./data/val', dataset_type=DatasetType.VAL)
|
val_gen = get_gen('./data/val', dataset_type=DatasetType.VAL)
|
||||||
test_gen = get_gen('./data/test', dataset_type=DatasetType.TEST)
|
test_gen = get_gen('./data/test', dataset_type=DatasetType.TEST)
|
||||||
model = train_model(model, model_name, train_gen, val_gen)
|
model = train_model(model, model_name, train_gen, val_gen, 1)
|
||||||
for layer in model.layers[2].layers:
|
# for layer in model.layers[2].layers:
|
||||||
if not isinstance(layer, keras.layers.BatchNormalization):
|
# if not isinstance(layer, keras.layers.BatchNormalization):
|
||||||
layer.trainable = True
|
# layer.trainable = True
|
||||||
model.layers[2].trainable = True
|
# model.layers[2].trainable = True
|
||||||
print(model)
|
# print(model)
|
||||||
model.compile(
|
# model.compile(
|
||||||
optimizer=keras.optimizers.Adam(learning_rate=.00001),
|
# optimizer=keras.optimizers.Adam(learning_rate=.00001),
|
||||||
loss=keras.losses.CategoricalCrossentropy(),
|
# loss=keras.losses.CategoricalCrossentropy(),
|
||||||
metrics=['accuracy', 'categorical_crossentropy']
|
# metrics=['accuracy', 'categorical_crossentropy']
|
||||||
)
|
# )
|
||||||
model.summary()
|
# model.summary()
|
||||||
model = train_model(model, model_name + "-second_stage", train_gen, val_gen)
|
# model = train_model(model, model_name + "-second_stage", train_gen, val_gen, 1)
|
||||||
test_model(model, test_gen)
|
# test_model(model, test_gen)
|
||||||
|
|||||||
+1
-2
@@ -24,7 +24,7 @@ test_gen = ImageDataGenerator().flow_from_directory(
|
|||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
shuffle=False
|
shuffle=False
|
||||||
)
|
)
|
||||||
#
|
|
||||||
single_gen = ImageDataGenerator().flow_from_directory(
|
single_gen = ImageDataGenerator().flow_from_directory(
|
||||||
'./single_image_test_set',
|
'./single_image_test_set',
|
||||||
target_size=(input_shape[0], input_shape[1]),
|
target_size=(input_shape[0], input_shape[1]),
|
||||||
@@ -32,7 +32,6 @@ single_gen = ImageDataGenerator().flow_from_directory(
|
|||||||
shuffle=False
|
shuffle=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
for file in glob("./models/keras/*.hdf5"):
|
for file in glob("./models/keras/*.hdf5"):
|
||||||
print(file)
|
print(file)
|
||||||
if file in metrics_df.values:
|
if file in metrics_df.values:
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from pathlib import Path
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from keras.preprocessing.image import ImageDataGenerator
|
from keras.preprocessing.image import ImageDataGenerator
|
||||||
from tensorflow import keras
|
import tensorflow as tf
|
||||||
|
|
||||||
# TODO: Move these to a config for the project
|
# TODO: Move these to a config for the project
|
||||||
input_shape = (224, 224, 3)
|
input_shape = (224, 224, 3)
|
||||||
@@ -25,16 +25,19 @@ for file in glob("./models/keras/*.hdf5"):
|
|||||||
path = Path(file)
|
path = Path(file)
|
||||||
tflite_file = f'./models/tflite/models/{path.name[:-5] + ".tflite"}'
|
tflite_file = f'./models/tflite/models/{path.name[:-5] + ".tflite"}'
|
||||||
if not Path(tflite_file).exists():
|
if not Path(tflite_file).exists():
|
||||||
keras_model = keras.models.load_model(file)
|
print(tflite_file)
|
||||||
|
keras_model = tf.keras.models.load_model(file)
|
||||||
|
keras_model.summary()
|
||||||
|
print(keras_model.input)
|
||||||
|
print(keras_model.layers)
|
||||||
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
|
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
|
||||||
tflite_model = converter.convert()
|
tflite_model = converter.convert()
|
||||||
with open(tflite_file, 'wb') as f:
|
with open(tflite_file, 'wb') as f:
|
||||||
f.write(tflite_model)
|
f.write(tflite_model)
|
||||||
# TODO: Verify the model performance after converting to TFLITE
|
# TODO: Verify the model performance after converting to TFLITE
|
||||||
# interpreter = tf.lite.Interpreter(model_path=tflite_file)
|
# interpreter = tf.lite.Interpreter(model_path=tflite_file)
|
||||||
# single_acc, single_ll = get_metrics(single_gen, keras_model)
|
# single_acc, single_ll = get_metrics(single_gen, keras_model)
|
||||||
# tf_single_acc, tf_single_ll = get_metrics(single_gen, tflite_model)
|
# tf_single_acc, tf_single_ll = get_metrics(single_gen, tflite_model)
|
||||||
#
|
#
|
||||||
# print(single_acc, tf_single_acc)
|
# print(single_acc, tf_single_acc)
|
||||||
# print(single_ll, tf_single_ll)
|
# print(single_ll, tf_single_ll)
|
||||||
|
|||||||
@@ -2,4 +2,6 @@ model,test_acc,test_loss,single_acc,single_loss
|
|||||||
./models/keras\pt-fl-fbn-efficientnet_v2b0-d1024-do0.5-l11.e-04-l21.e-04-5224-second_stage.hdf5,0.6720150708068079,1.7423864365349095,0.9893048128342246,0.4364729183409372
|
./models/keras\pt-fl-fbn-efficientnet_v2b0-d1024-do0.5-l11.e-04-l21.e-04-5224-second_stage.hdf5,0.6720150708068079,1.7423864365349095,0.9893048128342246,0.4364729183409372
|
||||||
./models/keras\pt-fl-fbn-efficientnet_v2b0-d1024-do0.5-l11.e-04-l21.e-04-5224.hdf5,0.410029881772119,3.346152696366266,0.986096256684492,0.3234976000776315
|
./models/keras\pt-fl-fbn-efficientnet_v2b0-d1024-do0.5-l11.e-04-l21.e-04-5224.hdf5,0.410029881772119,3.346152696366266,0.986096256684492,0.3234976000776315
|
||||||
./models/keras\pt-fl-fbn-efficientnet_v2s-d1024-do0.5-l11.e-04-l21.e-04-8105-second_stage.hdf5,0.6850721060153306,1.675868156533777,0.9967914438502674,0.3373779159304851
|
./models/keras\pt-fl-fbn-efficientnet_v2s-d1024-do0.5-l11.e-04-l21.e-04-8105-second_stage.hdf5,0.6850721060153306,1.675868156533777,0.9967914438502674,0.3373779159304851
|
||||||
./models/keras\pt-fl-fbn-efficientnet_v2s-d1024-do0.5-l11.e-04-l21.e-04-8105.hdf5,0.37553592308691697,3.5500588697038067,0.9540106951871657,0.47270425785037834
|
./models/keras\pt-fl-fbn-efficientnet_v2s-d1024-do0.5-l11.e-04-l21.e-04-8105.hdf5,0.3755359230869169,3.5500588697038067,0.9540106951871656,0.4727042578503783
|
||||||
|
./models/keras\pt-fl-fbn-efficientnet_v2s-d1024-do0.5-l11.e-04-l21.e-04-9317-second_stage.hdf5,0.6121780461172843,2.197206965588216,0.9946581196581196,0.2974041509252359
|
||||||
|
./models/keras\pt-fl-fbn-efficientnet_v2s-d1024-do0.5-l11.e-04-l21.e-04-9317.hdf5,0.3702228787976106,3.601324427207316,0.9594017094017094,0.4877960320956891
|
||||||
|
|||||||
|
@@ -1,3 +0,0 @@
|
|||||||
import pandas as pd
|
|
||||||
import os
|
|
||||||
|
|
||||||
@@ -15,11 +15,21 @@ class ImageClassModels(Enum):
|
|||||||
keras.applications.inception_v3.preprocess_input,
|
keras.applications.inception_v3.preprocess_input,
|
||||||
"inception_v3"
|
"inception_v3"
|
||||||
)
|
)
|
||||||
|
INCEPTION_RESNET_V2 = ModelWrapper(
|
||||||
|
keras.applications.inception_resnet_v2.InceptionResNetV2,
|
||||||
|
keras.applications.inception_resnet_v2.preprocess_input,
|
||||||
|
"inception_resnet_v2"
|
||||||
|
)
|
||||||
XCEPTION = ModelWrapper(
|
XCEPTION = ModelWrapper(
|
||||||
keras.applications.xception.Xception,
|
keras.applications.xception.Xception,
|
||||||
keras.applications.xception.preprocess_input,
|
keras.applications.xception.preprocess_input,
|
||||||
"xception"
|
"xception"
|
||||||
)
|
)
|
||||||
|
DENSENET201 = ModelWrapper(
|
||||||
|
keras.applications.densenet.DenseNet201,
|
||||||
|
keras.applications.densenet.preprocess_input,
|
||||||
|
"densenet201"
|
||||||
|
)
|
||||||
MOBILENET_V2 = ModelWrapper(
|
MOBILENET_V2 = ModelWrapper(
|
||||||
keras.applications.mobilenet_v2.MobileNetV2,
|
keras.applications.mobilenet_v2.MobileNetV2,
|
||||||
keras.applications.mobilenet_v2.preprocess_input,
|
keras.applications.mobilenet_v2.preprocess_input,
|
||||||
@@ -34,7 +44,6 @@ class ImageClassModels(Enum):
|
|||||||
keras.applications.efficientnet_v2.EfficientNetV2B0,
|
keras.applications.efficientnet_v2.EfficientNetV2B0,
|
||||||
tf.keras.applications.efficientnet_v2.preprocess_input,
|
tf.keras.applications.efficientnet_v2.preprocess_input,
|
||||||
"efficientnet_v2b0"
|
"efficientnet_v2b0"
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
from collections import Callable
|
from typing import Callable
|
||||||
|
|
||||||
|
|
||||||
class ModelWrapper(object):
|
class ModelWrapper(object):
|
||||||
def __init__(self, model_func:Callable, model_preprocessor:Callable, name:str):
|
def __init__(self, model_func:Callable, model_preprocessor:Callable, name:str):
|
||||||
|
|||||||
Binary file not shown.
@@ -1,4 +1,5 @@
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
df = pd.read_csv("models/keras/pt-fl-fbn-efficientnet_v2s-d1024-do0.5-l11.e-04-l21.e-04-8105-second_stage.csv")
|
df = pd.read_csv("models/keras/pt-fl-fbn-efficientnet_v2s-d1024-do0.5-l11.e-04-l21.e-04-8105-second_stage.csv")
|
||||||
|
|
||||||
print(df.loc[df["prediction"] != df["true_val"]])
|
print(df.loc[df["prediction"] != df["true_val"]])
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ for index, row in df2.iterrows():
|
|||||||
incorrect = df[df["prediction"]!= df["true_val"]]
|
incorrect = df[df["prediction"]!= df["true_val"]]
|
||||||
|
|
||||||
total_same_fam = 0
|
total_same_fam = 0
|
||||||
# TODO: Add in support for figuring out if the pokemon are related/evolutions of one another
|
|
||||||
for index, row in incorrect.iterrows():
|
for index, row in incorrect.iterrows():
|
||||||
img = mpimg.imread("./SingleImageTestSet/" + row['fname'])
|
img = mpimg.imread("./SingleImageTestSet/" + row['fname'])
|
||||||
imgplot = plt.imshow(img)
|
imgplot = plt.imshow(img)
|
||||||
|
|||||||
Binary file not shown.
|
After Width: | Height: | Size: 90 KiB |
Reference in New Issue
Block a user