From 1dc7c2dee2ad02d3880faea76bc84cef55fd6aac Mon Sep 17 00:00:00 2001 From: Lucas Oskorep Date: Wed, 17 Jul 2019 12:30:16 -0500 Subject: [PATCH] Updates....updates everywhere --- 1 - ImageGatherer.py | 33 +++++++++++++++++++-------------- 3 - TestTrainSplit.py | 4 ++-- 4 - TrainingModelKeras.py | 10 +++++++--- 4 - TransferLearningKeras.py | 2 +- 4 files changed, 29 insertions(+), 20 deletions(-) diff --git a/1 - ImageGatherer.py b/1 - ImageGatherer.py index 2fd9932..932193c 100644 --- a/1 - ImageGatherer.py +++ b/1 - ImageGatherer.py @@ -5,6 +5,9 @@ import json from pprint import pprint from google_images_download import google_images_download +total_per = 10 +form_increment = 1 + def create_forms_dict(df): poke_dict = {} @@ -39,28 +42,30 @@ def process_pokemon_names(df): pprint(poke_dict) pokes_to_limits = [] for pokemon, form_list in poke_dict.items(): - if len(form_list) == 0: - print(pokemon) - pokes_to_limits.append((pokemon, 200)) + print(pokemon) + num_forms = len(form_list) + if num_forms == 0: + pokes_to_limits.append((pokemon, total_per)) - elif len(form_list) == 1: - pokes_to_limits.append((pokemon, 150)) - pokes_to_limits.append((search_term(form_list[0]), 50)) + elif num_forms == 1: + pokes_to_limits.append((pokemon, total_per - form_increment)) + pokes_to_limits.append((search_term(form_list[0]), form_increment)) - elif len(form_list) == 2: - pokes_to_limits.append((pokemon, 100)) + elif num_forms == 2: + pokes_to_limits.append((pokemon, total_per - form_increment * num_forms)) for form in form_list: - pokes_to_limits.append((search_term(form), 50)) + pokes_to_limits.append((search_term(form), form_increment)) - elif len(form_list) >= 3: + elif num_forms >= 3: + revised_increment = int(total_per / len(form_list)) for form in form_list: - pokes_to_limits.append((search_term(form), int(200 / len(form_list)))) + pokes_to_limits.append((pokemon, total_per - revised_increment * num_forms)) + + pokes_to_limits.append((search_term(form), revised_increment)) return pokes_to_limits -import os - def get_images_for_pokemon(poke_to_limit): pokemon = poke_to_limit[0] @@ -69,7 +74,7 @@ def get_images_for_pokemon(poke_to_limit): response.download( { "keywords": pokemon + " pokemon", - "limit": 1,#limit, + "limit": limit, "chromedriver": "chromedriver" # Add chromedriver to your path or just point this var directly to your chromedriverv } diff --git a/3 - TestTrainSplit.py b/3 - TestTrainSplit.py index a1bd120..1337905 100644 --- a/3 - TestTrainSplit.py +++ b/3 - TestTrainSplit.py @@ -6,8 +6,8 @@ import multiprocessing train_dir = "./data/train/" test_dir = "./data/test/" val_dir = "./data/val/" -train = .80 -test = .15 +train = .5 +test = .5 val = .05 diff --git a/4 - TrainingModelKeras.py b/4 - TrainingModelKeras.py index 09ea00a..12a6514 100644 --- a/4 - TrainingModelKeras.py +++ b/4 - TrainingModelKeras.py @@ -11,6 +11,7 @@ from keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard from keras.layers import Dense, Dropout, GlobalAveragePooling2D from keras.models import Sequential from keras.preprocessing.image import ImageDataGenerator +from keras.utils import multi_gpu_model from sklearn.metrics import accuracy_score, confusion_matrix, classification_report @@ -22,7 +23,7 @@ from PIL import ImageFile ImageFile.LOAD_TRUNCATED_IMAGES = True input_shape = (224, 224, 3) -batch_size = 96 +batch_size = 32 model_name = "mobilenet-fixed-data" @@ -53,7 +54,7 @@ val_idg = ImageDataGenerator( ) val_gen = val_idg.flow_from_directory( - './data/val', + './data/test', target_size=(input_shape[0], input_shape[1]), batch_size=batch_size ) @@ -103,7 +104,10 @@ add_model.add(Dropout(0.5)) add_model.add(Dense(512, activation='relu')) add_model.add(Dense(len(train_gen.class_indices), activation='softmax')) # Decision layer -model = add_model +#TODO: Add in gpu support +model = multi_gpu_model(add_model, 2) +# model = add_model + model.compile(loss='categorical_crossentropy', # optimizer=optimizers.SGD(lr=1e-4, momentum=0.9), optimizer=optimizers.Adam(lr=1e-4), diff --git a/4 - TransferLearningKeras.py b/4 - TransferLearningKeras.py index af0b81f..3f0ffdc 100644 --- a/4 - TransferLearningKeras.py +++ b/4 - TransferLearningKeras.py @@ -8,7 +8,7 @@ from PIL import ImageFile ImageFile.LOAD_TRUNCATED_IMAGES = True -input_shape = (244, 244, 3) +input_shape = (224, 224, 3) batch_size = 60 model_name = "MobileNetV2FullDataset"