Merge remote-tracking branch 'origin/master'
# Conflicts: # 1 - ImageGatherer.py # 3 - TestTrainSplit.py # 4 - TransferLearningKeras.py
This commit is contained in:
+18
-11
@@ -5,6 +5,9 @@ import json
|
|||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
from google_images_download import google_images_download
|
from google_images_download import google_images_download
|
||||||
|
|
||||||
|
total_per = 10
|
||||||
|
form_increment = 1
|
||||||
|
|
||||||
|
|
||||||
def create_forms_dict(df):
|
def create_forms_dict(df):
|
||||||
poke_dict = {}
|
poke_dict = {}
|
||||||
@@ -39,22 +42,26 @@ def process_pokemon_names(df):
|
|||||||
pprint(poke_dict)
|
pprint(poke_dict)
|
||||||
pokes_to_limits = []
|
pokes_to_limits = []
|
||||||
for pokemon, form_list in poke_dict.items():
|
for pokemon, form_list in poke_dict.items():
|
||||||
if len(form_list) == 0:
|
print(pokemon)
|
||||||
print(pokemon)
|
num_forms = len(form_list)
|
||||||
pokes_to_limits.append((pokemon, 200))
|
if num_forms == 0:
|
||||||
|
pokes_to_limits.append((pokemon, total_per))
|
||||||
|
|
||||||
elif len(form_list) == 1:
|
elif num_forms == 1:
|
||||||
pokes_to_limits.append((pokemon, 150))
|
pokes_to_limits.append((pokemon, total_per - form_increment))
|
||||||
pokes_to_limits.append((search_term(form_list[0]), 50))
|
pokes_to_limits.append((search_term(form_list[0]), form_increment))
|
||||||
|
|
||||||
elif len(form_list) == 2:
|
elif num_forms == 2:
|
||||||
pokes_to_limits.append((pokemon, 100))
|
pokes_to_limits.append((pokemon, total_per - form_increment * num_forms))
|
||||||
for form in form_list:
|
for form in form_list:
|
||||||
pokes_to_limits.append((search_term(form), 50))
|
pokes_to_limits.append((search_term(form), form_increment))
|
||||||
|
|
||||||
elif len(form_list) >= 3:
|
elif num_forms >= 3:
|
||||||
|
revised_increment = int(total_per / len(form_list))
|
||||||
for form in form_list:
|
for form in form_list:
|
||||||
pokes_to_limits.append((search_term(form), int(200 / len(form_list))))
|
pokes_to_limits.append((pokemon, total_per - revised_increment * num_forms))
|
||||||
|
|
||||||
|
pokes_to_limits.append((search_term(form), revised_increment))
|
||||||
|
|
||||||
return pokes_to_limits
|
return pokes_to_limits
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from random import random
|
from random import random
|
||||||
from shutil import copyfile, rmtree
|
from shutil import rmtree
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard
|
|||||||
from keras.layers import Dense, Dropout, GlobalAveragePooling2D
|
from keras.layers import Dense, Dropout, GlobalAveragePooling2D
|
||||||
from keras.models import Sequential
|
from keras.models import Sequential
|
||||||
from keras.preprocessing.image import ImageDataGenerator
|
from keras.preprocessing.image import ImageDataGenerator
|
||||||
|
from keras.utils import multi_gpu_model
|
||||||
|
|
||||||
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
|
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
|
||||||
|
|
||||||
@@ -22,7 +23,7 @@ from PIL import ImageFile
|
|||||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||||
|
|
||||||
input_shape = (224, 224, 3)
|
input_shape = (224, 224, 3)
|
||||||
batch_size = 96
|
batch_size = 32
|
||||||
|
|
||||||
model_name = "mobilenet-fixed-data"
|
model_name = "mobilenet-fixed-data"
|
||||||
|
|
||||||
@@ -53,7 +54,7 @@ val_idg = ImageDataGenerator(
|
|||||||
)
|
)
|
||||||
|
|
||||||
val_gen = val_idg.flow_from_directory(
|
val_gen = val_idg.flow_from_directory(
|
||||||
'./data/val',
|
'./data/test',
|
||||||
target_size=(input_shape[0], input_shape[1]),
|
target_size=(input_shape[0], input_shape[1]),
|
||||||
batch_size=batch_size
|
batch_size=batch_size
|
||||||
)
|
)
|
||||||
@@ -102,7 +103,10 @@ add_model.add(Dropout(0.5))
|
|||||||
add_model.add(Dense(512, activation='relu'))
|
add_model.add(Dense(512, activation='relu'))
|
||||||
add_model.add(Dense(len(train_gen.class_indices), activation='softmax')) # Decision layer
|
add_model.add(Dense(len(train_gen.class_indices), activation='softmax')) # Decision layer
|
||||||
|
|
||||||
model = add_model
|
#TODO: Add in gpu support
|
||||||
|
model = multi_gpu_model(add_model, 2)
|
||||||
|
# model = add_model
|
||||||
|
|
||||||
model.compile(loss='categorical_crossentropy',
|
model.compile(loss='categorical_crossentropy',
|
||||||
# optimizer=optimizers.SGD(lr=1e-4, momentum=0.9),
|
# optimizer=optimizers.SGD(lr=1e-4, momentum=0.9),
|
||||||
optimizer=optimizers.Adam(lr=1e-4),
|
optimizer=optimizers.Adam(lr=1e-4),
|
||||||
|
|||||||
Reference in New Issue
Block a user