initial trainer
This commit is contained in:
111
CustomCallback.py
Normal file
111
CustomCallback.py
Normal file
@ -0,0 +1,111 @@
|
||||
from streamlit.proto.Progress_pb2 import Progress
|
||||
from tensorflow import keras
|
||||
import streamlit as st
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class CustomCallback(keras.callbacks.Callback):
    """Keras callback that mirrors training progress into Streamlit widgets.

    Writes status text into *element*, drives a progress bar, and streams a
    live line chart of per-batch loss/accuracy while ``model.fit`` runs.

    Fixes over the original:
      * ``on_epoch_begin`` now formats the epoch number (the literal "{}"
        was displayed before because ``.format(epoch)`` was never applied).
      * Per-batch metrics are accumulated with ``pd.concat`` —
        ``DataFrame.append`` was removed in pandas 2.0.
      * The unused ``keys = list(logs.keys())`` lines (which would raise on
        ``logs=None``) are removed.
    """

    # Progress-bar value in [0, 100]; advanced one step per training batch.
    percent_complete = 0
    # Per-batch 'loss'/'accuracy' rows accumulated for the live chart.
    data_frame = pd.DataFrame()

    def __init__(self, element, pregress, metrics):
        # NOTE(review): `pregress` and `metrics` are accepted for call-site
        # compatibility but ignored; fresh widgets are created here instead.
        self.element = element
        self.progress = st.progress(0)
        self.metrics = st.line_chart(self.data_frame)

    def on_train_begin(self, logs=None):
        keystring = "Starting training"
        self.element.markdown(" %s*..." % keystring)

    def on_train_end(self, logs=None):
        keystring = "Finished training"
        # Force the bar to completion regardless of the batch counter.
        self.progress.progress(100)
        self.element.markdown("* %s*..." % keystring)

    def on_epoch_begin(self, epoch, logs=None):
        # BUG FIX: apply .format(epoch) so the epoch number is shown.
        keystring = "Start epoch {} of training".format(epoch)
        self.element.markdown("* %s*..." % keystring)
        # Reset the per-epoch chart data and the progress counter.
        self.data_frame = pd.DataFrame()
        self.percent_complete = 0

    def on_epoch_end(self, epoch, logs=None):
        keystring = "End epoch {} of training ".format(epoch)
        self.element.markdown("* %s*..." % keystring)

    def on_test_begin(self, logs=None):
        keystring = "Start testing"
        self.element.markdown("* %s*..." % keystring)

    def on_test_end(self, logs=None):
        keystring = "Stop testing"
        self.element.markdown("* %s*..." % keystring)

    def on_predict_begin(self, logs=None):
        # Intentionally silent during prediction (original only had
        # commented-out output here).
        pass

    def on_predict_end(self, logs=None):
        pass

    def on_train_batch_begin(self, batch, logs=None):
        keystring = "...Training: start of batch {}".format(batch)
        # Keep the bar inside Streamlit's 0-100 range: once it hits 100 it
        # is pulled back by 20 and this UI update is skipped, so the bar
        # oscillates rather than overflowing on long epochs.
        if self.percent_complete == 100:
            self.percent_complete = self.percent_complete - 20
            return
        self.percent_complete = self.percent_complete + 1
        self.progress.progress(self.percent_complete)
        self.element.markdown("* %s*..." % keystring)

    def on_train_batch_end(self, batch, logs=None):
        keystring = "...Training: end of batch {}".format(batch)
        # BUG FIX: DataFrame.append() was removed in pandas 2.0; accumulate
        # the per-batch metrics with pd.concat instead.
        row = pd.DataFrame([{'loss': logs['loss'], 'accuracy': logs['accuracy']}])
        self.data_frame = pd.concat([self.data_frame, row], ignore_index=True)
        self.metrics.line_chart(self.data_frame)
        self.element.markdown("* %s*..." % keystring)

    def on_test_batch_begin(self, batch, logs=None):
        keystring = "...Evaluating: start of batch {}".format(batch)
        self.element.markdown("* %s*..." % keystring)

    def on_test_batch_end(self, batch, logs=None):
        keystring = "...Evaluating: end of batch {}".format(batch)
        self.element.markdown("* %s*..." % keystring)

    def on_predict_batch_begin(self, batch, logs=None):
        pass

    def on_predict_batch_end(self, batch, logs=None):
        pass
||||
132
requirements.txt
Normal file
132
requirements.txt
Normal file
@ -0,0 +1,132 @@
|
||||
absl-py==0.13.0
|
||||
altair==4.1.0
|
||||
appnope==0.1.2
|
||||
argon2-cffi==21.1.0
|
||||
astor==0.8.1
|
||||
astunparse==1.6.3
|
||||
attrs==21.2.0
|
||||
backcall==0.2.0
|
||||
base58==2.1.0
|
||||
bleach==4.1.0
|
||||
blinker==1.4
|
||||
cachetools==4.2.2
|
||||
certifi==2021.5.30
|
||||
cffi==1.14.6
|
||||
charset-normalizer==2.0.6
|
||||
clang==5.0
|
||||
click==7.1.2
|
||||
cycler==0.10.0
|
||||
debugpy==1.4.3
|
||||
decorator==5.1.0
|
||||
defusedxml==0.7.1
|
||||
entrypoints==0.3
|
||||
flatbuffers==1.12
|
||||
gast==0.4.0
|
||||
gitdb==4.0.7
|
||||
GitPython==3.1.24
|
||||
google-auth==1.35.0
|
||||
google-auth-oauthlib==0.4.6
|
||||
google-pasta==0.2.0
|
||||
grpcio==1.40.0
|
||||
h5py==3.1.0
|
||||
idna==3.2
|
||||
ipykernel==6.4.1
|
||||
ipython==7.27.0
|
||||
ipython-genutils==0.2.0
|
||||
ipywidgets==7.6.5
|
||||
jedi==0.18.0
|
||||
Jinja2==3.0.1
|
||||
joblib==1.0.1
|
||||
jsonschema==3.2.0
|
||||
jupyter==1.0.0
|
||||
jupyter-client==7.0.3
|
||||
jupyter-console==6.4.0
|
||||
jupyter-core==4.8.1
|
||||
jupyterlab-pygments==0.1.2
|
||||
jupyterlab-widgets==1.0.2
|
||||
keras==2.6.0
|
||||
Keras-Preprocessing==1.1.2
|
||||
keras-tuner==1.0.4
|
||||
kiwisolver==1.3.2
|
||||
kt-legacy==1.0.4
|
||||
lxml==4.6.3
|
||||
Markdown==3.3.4
|
||||
MarkupSafe==2.0.1
|
||||
matplotlib==3.4.3
|
||||
matplotlib-inline==0.1.3
|
||||
mistune==0.8.4
|
||||
nbclient==0.5.4
|
||||
nbconvert==6.1.0
|
||||
nbformat==5.1.3
|
||||
nest-asyncio==1.5.1
|
||||
notebook==6.4.4
|
||||
numpy==1.19.5
|
||||
oauthlib==3.1.1
|
||||
opencv-contrib-python==4.5.3.56
|
||||
opencv-python==4.5.3.56
|
||||
opt-einsum==3.3.0
|
||||
packaging==21.0
|
||||
pandas==1.3.3
|
||||
pandocfilters==1.5.0
|
||||
parso==0.8.2
|
||||
pexpect==4.8.0
|
||||
pickleshare==0.7.5
|
||||
Pillow==8.3.2
|
||||
prometheus-client==0.11.0
|
||||
prompt-toolkit==3.0.20
|
||||
protobuf==3.18.0
|
||||
ptyprocess==0.7.0
|
||||
pyarrow==5.0.0
|
||||
pyasn1==0.4.8
|
||||
pyasn1-modules==0.2.8
|
||||
pycparser==2.20
|
||||
pydeck==0.7.0
|
||||
pydot==1.4.2
|
||||
Pygments==2.10.0
|
||||
pyparsing==2.4.7
|
||||
PyQt5==5.15.4
|
||||
PyQt5-Qt5==5.15.2
|
||||
PyQt5-sip==12.9.0
|
||||
pyrsistent==0.18.0
|
||||
python-dateutil==2.8.2
|
||||
pytz==2021.1
|
||||
PyYAML==5.4.1
|
||||
pyzmq==22.3.0
|
||||
qtconsole==5.1.1
|
||||
QtPy==1.11.1
|
||||
requests==2.26.0
|
||||
requests-oauthlib==1.3.0
|
||||
rsa==4.7.2
|
||||
scikit-learn==0.24.2
|
||||
scipy==1.7.1
|
||||
seaborn==0.11.2
|
||||
Send2Trash==1.8.0
|
||||
six==1.15.0
|
||||
smmap==4.0.0
|
||||
split-folders==0.4.3
|
||||
streamlit==1.0.0
|
||||
tensorboard==2.6.0
|
||||
tensorboard-data-server==0.6.1
|
||||
tensorboard-plugin-wit==1.8.0
|
||||
tensorflow==2.6.0
|
||||
tensorflow-docs==0.0.0.dev0
|
||||
tensorflow-estimator==2.6.0
|
||||
termcolor==1.1.0
|
||||
terminado==0.12.1
|
||||
testpath==0.5.0
|
||||
Theano==1.0.5
|
||||
threadpoolctl==2.2.0
|
||||
toml==0.10.2
|
||||
toolz==0.11.1
|
||||
tornado==6.1
|
||||
traitlets==5.1.0
|
||||
typing-extensions==3.7.4.3
|
||||
tzlocal==3.0
|
||||
urllib3==1.26.6
|
||||
validators==0.18.2
|
||||
watchdog==2.1.6
|
||||
wcwidth==0.2.5
|
||||
webencodings==0.5.1
|
||||
Werkzeug==2.0.1
|
||||
widgetsnbextension==3.5.1
|
||||
wrapt==1.12.1
|
||||
263
trainer.py
Normal file
263
trainer.py
Normal file
@ -0,0 +1,263 @@
|
||||
from logging import exception
|
||||
from keras.backend import constant
|
||||
from tensorflow import keras
|
||||
from tensorflow.keras import layers
|
||||
from keras.preprocessing.image import load_img, ImageDataGenerator, array_to_img, img_to_array
|
||||
from tensorflow.keras.models import load_model
|
||||
from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input, decode_predictions
|
||||
from tensorflow.keras.models import Model
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
from tensorflow.python.keras.callbacks import History
|
||||
from CustomCallback import CustomCallback
|
||||
import os, shutil
|
||||
from os import path
|
||||
import splitfolders
|
||||
import time
|
||||
|
||||
|
||||
|
||||
########## Augmentation with VGG19 #############
|
||||
def tranferLearningVGG(model, train_dir, test_dir):
    """Transfer-learn a 5-class classifier on top of a frozen VGG19 base.

    Parameters:
        model: ignored — a fresh VGG19 base is built here (the parameter is
            kept only for call-site compatibility).
        train_dir, test_dir: directories of class subfolders for
            flow_from_directory.

    Returns the Keras History from a 4-epoch fit.
    """
    model = VGG19(include_top=False, input_shape=(224, 224, 3))

    train_datagen = ImageDataGenerator(
        rescale=1 / 255,
        horizontal_flip=True,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
    )
    test_datagen = ImageDataGenerator(rescale=1 / 255)

    # Freeze the convolutional base; only the new dense head is trained.
    for layer in model.layers:
        layer.trainable = False
    flatten_layer = layers.Flatten()(model.output)
    flatten_fully_connected_layer = layers.Dense(512, activation='relu')(flatten_layer)
    flatten_fully_connected_softmax_layer = layers.Dense(5, activation='softmax')(flatten_fully_connected_layer)
    model = Model(inputs=model.inputs, outputs=flatten_fully_connected_softmax_layer)

    training_iterater = train_datagen.flow_from_directory(train_dir, batch_size=64, target_size=(224, 224))
    test_iterater = test_datagen.flow_from_directory(test_dir, batch_size=64, target_size=(224, 224))
    model.compile(loss="categorical_crossentropy", metrics=['accuracy'], optimizer='adam')
    history = model.fit(training_iterater, validation_data=test_iterater, epochs=4)
    # BUG FIX: save AFTER training so the file holds the trained weights
    # (the original saved before fit(), persisting an untrained model).
    model.save('models/vgg19_tl.h5')
    return history
|
||||
|
||||
|
||||
############ Model creation ################
|
||||
def buildSimpleModel():
    """Build a small CNN for 100x100 RGB images with 5 softmax classes.

    Three conv/pool stages (32, 64, 128 filters) followed by a 512-unit
    dense layer and the 5-way output.
    """
    return keras.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(100, 100, 3)),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(512, activation='relu'),
        layers.Dense(5, activation='softmax'),
    ])
|
||||
|
||||
###### preapre iterater ########
|
||||
def prepareIterater(folder_path, batch_size, img_size):
    """Create training and test directory iterators under *folder_path*.

    Expects 'train' and 'test' subdirectories of class folders.
    NOTE(review): rescale=1 leaves pixel values in 0-255, unlike the VGG
    path which uses 1/255 — confirm this is intentional.
    """
    gen_train = ImageDataGenerator(rescale=1)
    gen_test = ImageDataGenerator(rescale=1)
    it_train = gen_train.flow_from_directory(
        os.path.join(folder_path, 'train'), batch_size=batch_size, target_size=img_size)
    it_test = gen_test.flow_from_directory(
        os.path.join(folder_path, 'test'), batch_size=batch_size, target_size=img_size)
    return it_train, it_test
|
||||
|
||||
|
||||
########## start trining ##############
|
||||
def trainSimplaeModel(model, epochs, tran_iterater, test_iterater, model_name, text_output, progrss_bar, graph):
    """Compile and train *model*, streaming progress to Streamlit widgets.

    Parameters:
        model: an uncompiled Keras model.
        epochs: number of training epochs.
        tran_iterater / test_iterater: directory iterators for train/validation.
        model_name: file name under the session's output folder to save to.
        text_output, progrss_bar, graph: Streamlit placeholders handed to
            CustomCallback.

    Side effects: saves the trained model and records the model folder in
    ``st.session_state['model_folder']``.  Returns the Keras History.
    (Commented-out iterator-construction dead code from the original was
    removed.)
    """
    model.compile(loss="categorical_crossentropy", metrics=['accuracy'], optimizer='adam')
    history = model.fit(
        tran_iterater,
        validation_data=test_iterater,
        epochs=epochs,
        callbacks=[CustomCallback(text_output, progrss_bar, graph)],
    )
    model_path = os.path.join(st.session_state['output_folder'], model_name)
    st.session_state['model_folder'] = st.session_state["output_folder"]
    model.save(model_path)
    return history
|
||||
|
||||
|
||||
######### plot #########
|
||||
def plotHistory(history):
    """Show training vs. validation loss curves from a Keras History."""
    # Plot training loss first, then validation loss, so the legend order
    # ('training', 'testing') lines up with the series.
    for series in ('loss', 'val_loss'):
        plt.plot(history.history[series])
    plt.title("model loss")
    plt.ylabel("loss")
    plt.xlabel("no of epochs")
    plt.legend(['training', 'testing'], loc='upper left')
    plt.show()
|
||||
|
||||
|
||||
####### print class labels #######
|
||||
def printLabels(training_iterater):
    """Print the iterator's class-name -> index mapping to stdout."""
    print(training_iterater.class_indices)
|
||||
|
||||
|
||||
####### predict ########
|
||||
def predict(modelname, labels=None):
    """Load models/<modelname> and classify a hard-coded sunflower image.

    Parameters:
        modelname: file name of the .h5 model inside 'models/'.
        labels: optional mapping of class name -> index used to print the
            predicted class name.  Defaults to the module-level
            ``class_labels`` for backward compatibility — in the original
            this was a NameError whenever the Streamlit training flow had
            not populated that global first.

    Prints the raw prediction vector, the argmax index, and the class name.
    """
    mdl = load_model('models/' + modelname)
    img = load_img('dataset/flowers/sunflower.jpeg', target_size=(100, 100))
    img = img_to_array(img)
    img = img.reshape(1, 100, 100, 3)
    res = mdl.predict(img)
    print(res)
    print(np.argmax(res))
    if labels is None:
        # Falls back to the global set by the training UI; still raises
        # NameError if that flow never ran (pass `labels` to avoid this).
        labels = class_labels
    print([key for key in labels][np.argmax(res)])
|
||||
|
||||
def getModelNames(model_folder):
    """Return file names in *model_folder* that end in '.h5' (case-insensitive).

    Returns [] when *model_folder* is the empty string; otherwise raises the
    usual OSError if the folder does not exist (unchanged from the original).
    The original also computed an unused ``length`` from ``os.listdir('.')``;
    that dead code is removed.
    """
    if model_folder == "":
        return []
    return [entry.name for entry in os.scandir(model_folder)
            if entry.name.lower().endswith('.h5')]
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
##### prepare image folders ######
|
||||
def prepareFolders(folder_path, output_folder):
    """Clear the session's output folder, then split *folder_path* into
    train/val/test subsets (80/10/10) under *output_folder*.

    NOTE(review): all errors are deliberately swallowed so the Streamlit
    flow keeps running; consider surfacing them (st.warning/logging)
    instead of a silent pass.  The cleanup target comes from
    ``st.session_state['output_folder']``, not the *output_folder*
    argument — callers currently pass the same value.
    """
    file_path = st.session_state['output_folder']
    try:
        # Remove any previous output, whether file, symlink, or directory.
        if os.path.isfile(file_path) or os.path.islink(file_path):
            os.unlink(file_path)
        elif os.path.isdir(file_path):
            shutil.rmtree(file_path)
        splitfolders.ratio(folder_path, output=output_folder, seed=1337, ratio=(.8, 0.1, 0.1))
    except Exception:
        # Best-effort: ignore cleanup/split failures (unused `as e` removed).
        pass
|
||||
|
||||
##### handle change on image input folder #####
|
||||
def handlerImageFolderChanged():
    """on_change hook for the image-folder input.

    Reads the session's output folder but takes no further action — the
    assignment is dead code, preserved here for behavioral parity.
    """
    output_folder = st.session_state["output_folder"]
|
||||
|
||||
def handlePrepare():
    """Button hook: split the raw image folder into train/val/test sets."""
    prepareFolders(raw_image_folder, st.session_state['output_folder'])
|
||||
|
||||
def handSelectModel():
    """on_change hook for the model selectbox.

    Currently only renders a markdown separator; the commented-out
    session-state echo from the original was removed as dead code.
    """
    st.write('#####')
|
||||
|
||||
|
||||
# --- Streamlit app scaffolding (module-level script) ---

# Hard-coded source of raw images; the prepared/output folders are entered
# through the UI below.
raw_image_folder = "/Users/mohamednouffer/workspace/akira_san/image_classifier/dataset/raw_data"

st.title("Sumasen AI")
st.sidebar.header("Sumasen Trainer")
st.session_state["image_arranged"] = False

options = ("New", "Existing model")

# Placeholders so the sidebar widgets can be (re)filled conditionally.
mode = st.sidebar.empty()
model_empty = st.sidebar.empty()

# Default the training mode on first run of the session.
if "training_mode" not in st.session_state:
    st.session_state['training_mode'] = "New"

training_mode = mode.radio("Training mode:", options, 0)
st.session_state['training_mode'] = training_mode
|
||||
if training_mode == "Existing model":
    # ---- Inference path: pick an existing .h5 model and classify an upload ----
    if 'model_folder' in st.session_state:
        model_folder = model_empty.text_input("model folder", st.session_state['model_folder'])
        if st.sidebar.button("Load models"):
            selected_model = st.sidebar.selectbox("Select a Model", getModelNames(model_folder), on_change=handSelectModel, key='model_name')
            st.session_state['selected_model'] = selected_model
    else:
        model_folder = model_empty.text_input("Enter model folder", '')
        if st.sidebar.button("Load models"):
            selected_model = st.sidebar.selectbox("Select a Model", getModelNames(model_folder))
            st.session_state['model_folder'] = model_folder
            st.session_state['selected_model'] = selected_model

    if 'model_name' in st.session_state:
        st.write('Selected models is :', st.session_state['model_name'])
        st.subheader("Predict a image")
        file = st.file_uploader("Upload image", ['jpeg', 'jpg', 'png'])
        if file:
            bytes_data = file.getvalue()
            # Persist the upload next to the models so load_img can read it.
            with open(os.path.join(st.session_state['model_folder'], file.name), "wb") as f:
                f.write(file.getbuffer())
            model_file = os.path.join(st.session_state['model_folder'], st.session_state['model_name'])
            mdl = load_model(model_file)
            img = load_img(os.path.join(st.session_state['model_folder'], file.name), target_size=(100, 100))
            st.image(img)
            img = img_to_array(img)
            img = img.reshape(1, 100, 100, 3)
            res = mdl.predict(img)
            st.write(np.argmax(res))

else:
    # ---- Training path: arrange images, then build and fit a model ----
    if "output_folder" in st.session_state:
        for fls in os.scandir(st.session_state["output_folder"]):
            # BUG FIX: splitfolders.ratio() creates "train"/"val"/"test";
            # the original tested for "tran", so a "train" directory never
            # matched and the ready-state could not be detected from it.
            if os.path.isdir(fls.path) and fls.name in ("train", "test", "val"):
                st.info("Images are ready for training")
                st.session_state["image_arranged"] = True

    selected_image_folder = st.text_input("Enter image folder (make sure only category folders in this directory)", "Enter images folder ...")
    output_folder = os.path.join(selected_image_folder, "arranged")
    output_folder = st.text_input("Enter a folder to prepare images", output_folder)
    input_shape = st.selectbox('Select image input shape', [(100, 100), [150, 150], (200, 200), (240, 240), (300, 300)])
    batch_size = st.slider('Batch size', 1, 1000, 40)

    if selected_image_folder == "Enter images folder ..." or selected_image_folder == "":
        st.error("Need a valid image folder")
    else:
        try:
            st.session_state["selected_image_folder"] = selected_image_folder
            st.session_state["output_folder"] = output_folder
            arranged = st.button("arrange images", on_click=handlePrepare)
            if st.session_state["image_arranged"] == True:
                traing_iterater, test_iterater = prepareIterater(st.session_state['output_folder'], batch_size, input_shape)
                class_labels = traing_iterater.class_indices
                st.write('class labels', class_labels)
                model_name = st.text_input('model name:', 'mymodel.h5')
                epochs = st.slider("Epochs", 1, 500, 2)
                if st.button('begin train') and epochs and batch_size:
                    mdl = buildSimpleModel()
                    text_output = st.empty()
                    graph = st.empty()
                    my_bar = st.empty()
                    history = trainSimplaeModel(mdl, epochs, traing_iterater, test_iterater, model_name, text_output, my_bar, graph)
                    his_df = pd.DataFrame(history.history)
                    st.line_chart(his_df)
                    # NOTE(review): this reassignment has no effect until the
                    # next Streamlit rerun; kept from the original.
                    training_mode = "Existing model"
        except Exception as ex:
            # Surface any failure in the Streamlit UI rather than crashing.
            st.error(ex)
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user