diff --git a/1 - ImageGatherer.py b/1 - ImageGatherer.py index 4e92693..932193c 100644 --- a/1 - ImageGatherer.py +++ b/1 - ImageGatherer.py @@ -5,6 +5,9 @@ import json from pprint import pprint from google_images_download import google_images_download +total_per = 10 +form_increment = 1 + def create_forms_dict(df): poke_dict = {} @@ -39,22 +42,26 @@ def process_pokemon_names(df): pprint(poke_dict) pokes_to_limits = [] for pokemon, form_list in poke_dict.items(): - if len(form_list) == 0: - print(pokemon) - pokes_to_limits.append((pokemon, 200)) + print(pokemon) + num_forms = len(form_list) + if num_forms == 0: + pokes_to_limits.append((pokemon, total_per)) - elif len(form_list) == 1: - pokes_to_limits.append((pokemon, 150)) - pokes_to_limits.append((search_term(form_list[0]), 50)) + elif num_forms == 1: + pokes_to_limits.append((pokemon, total_per - form_increment)) + pokes_to_limits.append((search_term(form_list[0]), form_increment)) - elif len(form_list) == 2: - pokes_to_limits.append((pokemon, 100)) + elif num_forms == 2: + pokes_to_limits.append((pokemon, total_per - form_increment * num_forms)) for form in form_list: - pokes_to_limits.append((search_term(form), 50)) + pokes_to_limits.append((search_term(form), form_increment)) - elif len(form_list) >= 3: + elif num_forms >= 3: + revised_increment = int(total_per / len(form_list)) for form in form_list: - pokes_to_limits.append((search_term(form), int(200 / len(form_list)))) + pokes_to_limits.append((pokemon, total_per - revised_increment * num_forms)) + + pokes_to_limits.append((search_term(form), revised_increment)) return pokes_to_limits diff --git a/3 - TestTrainSplit.py b/3 - TestTrainSplit.py index b99f8a1..71b13ff 100644 --- a/3 - TestTrainSplit.py +++ b/3 - TestTrainSplit.py @@ -1,6 +1,6 @@ import os from random import random -from shutil import copyfile, rmtree +from shutil import rmtree from pathlib import Path import multiprocessing diff --git a/4 - TrainingModelKeras.py b/4 - TrainingModelKeras.py index 0c6e57f..9f492c2 100644 --- a/4 - TrainingModelKeras.py +++ b/4 - TrainingModelKeras.py @@ -11,6 +11,7 @@ from keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard from keras.layers import Dense, Dropout, GlobalAveragePooling2D from keras.models import Sequential from keras.preprocessing.image import ImageDataGenerator +from keras.utils import multi_gpu_model from sklearn.metrics import accuracy_score, confusion_matrix, classification_report @@ -22,7 +23,7 @@ from PIL import ImageFile ImageFile.LOAD_TRUNCATED_IMAGES = True input_shape = (224, 224, 3) -batch_size = 96 +batch_size = 32 model_name = "mobilenet-fixed-data" @@ -53,7 +54,7 @@ val_idg = ImageDataGenerator( ) val_gen = val_idg.flow_from_directory( - './data/val', + './data/test', target_size=(input_shape[0], input_shape[1]), batch_size=batch_size ) @@ -102,7 +103,10 @@ add_model.add(Dropout(0.5)) add_model.add(Dense(512, activation='relu')) add_model.add(Dense(len(train_gen.class_indices), activation='softmax')) # Decision layer -model = add_model +#TODO: Add in gpu support +model = multi_gpu_model(add_model, 2) +# model = add_model + model.compile(loss='categorical_crossentropy', # optimizer=optimizers.SGD(lr=1e-4, momentum=0.9), optimizer=optimizers.Adam(lr=1e-4),