From 964d058f3a85ddde413d02f2b8e47ccf468d5c95 Mon Sep 17 00:00:00 2001
From: Mohamed Nouffer
Date: Fri, 22 Oct 2021 08:43:17 +0530
Subject: [PATCH] initial trainer

---
 CustomCallback.py | 111 +++++++++++++++++++
 requirements.txt  | 132 +++++++++++++++++++++
 trainer.py        | 263 ++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 506 insertions(+)
 create mode 100644 CustomCallback.py
 create mode 100644 requirements.txt
 create mode 100644 trainer.py

diff --git a/CustomCallback.py b/CustomCallback.py
new file mode 100644
index 0000000..747ada9
--- /dev/null
+++ b/CustomCallback.py
@@ -0,0 +1,111 @@
+from tensorflow import keras
+import streamlit as st
+import pandas as pd
+
+
+class CustomCallback(keras.callbacks.Callback):
+    """Keras callback that streams training progress into Streamlit placeholders."""
+
+    def __init__(self, element, progress, metrics):
+        super().__init__()
+        self.percent_complete = 0
+        self.data_frame = pd.DataFrame(columns=['loss', 'accuracy'])
+        self.element = element                              # status text placeholder
+        self.progress = progress.progress(0)                # progress bar drawn into the supplied placeholder
+        self.metrics = metrics.line_chart(self.data_frame)  # live loss/accuracy chart
+
+    def on_train_begin(self, logs=None):
+        self.element.markdown("*Starting training...*")
+
+    def on_train_end(self, logs=None):
+        self.progress.progress(100)
+        self.element.markdown("*Finished training...*")
+
+    def on_epoch_begin(self, epoch, logs=None):
+        self.element.markdown("*Start epoch {} of training...*".format(epoch))
+        self.percent_complete = 0  # the bar tracks progress within the current epoch
+
+    def on_epoch_end(self, epoch, logs=None):
+        self.element.markdown("*End epoch {} of training...*".format(epoch))
+
+    def on_test_begin(self, logs=None):
+        self.element.markdown("*Start testing...*")
+
+    def on_test_end(self, logs=None):
+        self.element.markdown("*Stop testing...*")
+
+    def on_predict_begin(self, logs=None):
+        pass  # prediction progress is not surfaced in the UI
+
+    def on_predict_end(self, logs=None):
+        pass
+
+    def on_train_batch_begin(self, batch, logs=None):
+        # Roughly one percent per batch, clamped so the bar never overflows
+        # when an epoch has more than 100 batches.
+        self.percent_complete = min(self.percent_complete + 1, 100)
+        self.progress.progress(self.percent_complete)
+        self.element.markdown("*...Training: start of batch {}...*".format(batch))
+
+    def on_train_batch_end(self, batch, logs=None):
+        logs = logs or {}
+        new_row = pd.DataFrame({'loss': [logs.get('loss')], 'accuracy': [logs.get('accuracy')]})
+        self.data_frame = self.data_frame.append(new_row, ignore_index=True)
+        self.metrics.add_rows(new_row)  # stream the latest point into the chart
+        self.element.markdown("*...Training: end of batch {}...*".format(batch))
+
+    def on_test_batch_begin(self, batch, logs=None):
+        self.element.markdown("*...Evaluating: start of batch {}...*".format(batch))
+
+    def on_test_batch_end(self, batch, logs=None):
+        self.element.markdown("*...Evaluating: end of batch {}...*".format(batch))
+
+    def on_predict_batch_begin(self, batch, logs=None):
+        pass
+
+    def on_predict_batch_end(self, batch, logs=None):
+        pass
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..3437cfd
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,132 @@
+absl-py==0.13.0
+altair==4.1.0
+appnope==0.1.2
+argon2-cffi==21.1.0
+astor==0.8.1
+astunparse==1.6.3
+attrs==21.2.0
+backcall==0.2.0
+base58==2.1.0
+bleach==4.1.0
+blinker==1.4
+cachetools==4.2.2
+certifi==2021.5.30
+cffi==1.14.6
+charset-normalizer==2.0.6
+clang==5.0
+click==7.1.2
+cycler==0.10.0
+debugpy==1.4.3
+decorator==5.1.0
+defusedxml==0.7.1
+entrypoints==0.3
+flatbuffers==1.12
+gast==0.4.0
+gitdb==4.0.7
+GitPython==3.1.24
+google-auth==1.35.0
+google-auth-oauthlib==0.4.6
+google-pasta==0.2.0
+grpcio==1.40.0
+h5py==3.1.0
+idna==3.2
+ipykernel==6.4.1
+ipython==7.27.0
+ipython-genutils==0.2.0
+ipywidgets==7.6.5
+jedi==0.18.0
+Jinja2==3.0.1
+joblib==1.0.1
+jsonschema==3.2.0
+jupyter==1.0.0
+jupyter-client==7.0.3
+jupyter-console==6.4.0
+jupyter-core==4.8.1
+jupyterlab-pygments==0.1.2
+jupyterlab-widgets==1.0.2
+keras==2.6.0
+Keras-Preprocessing==1.1.2
+keras-tuner==1.0.4
+kiwisolver==1.3.2
+kt-legacy==1.0.4
+lxml==4.6.3
+Markdown==3.3.4
+MarkupSafe==2.0.1
+matplotlib==3.4.3
+matplotlib-inline==0.1.3
+mistune==0.8.4
+nbclient==0.5.4
+nbconvert==6.1.0
+nbformat==5.1.3
+nest-asyncio==1.5.1
+notebook==6.4.4
+numpy==1.19.5
+oauthlib==3.1.1
+opencv-contrib-python==4.5.3.56
+opencv-python==4.5.3.56
+opt-einsum==3.3.0
+packaging==21.0
+pandas==1.3.3
+pandocfilters==1.5.0
+parso==0.8.2
+pexpect==4.8.0
+pickleshare==0.7.5
+Pillow==8.3.2
+prometheus-client==0.11.0
+prompt-toolkit==3.0.20
+protobuf==3.18.0
+ptyprocess==0.7.0
+pyarrow==5.0.0
+pyasn1==0.4.8
+pyasn1-modules==0.2.8
+pycparser==2.20
+pydeck==0.7.0
+pydot==1.4.2
+Pygments==2.10.0
+pyparsing==2.4.7
+PyQt5==5.15.4
+PyQt5-Qt5==5.15.2
+PyQt5-sip==12.9.0
+pyrsistent==0.18.0
+python-dateutil==2.8.2
+pytz==2021.1
+PyYAML==5.4.1
+pyzmq==22.3.0
+qtconsole==5.1.1
+QtPy==1.11.1
+requests==2.26.0
+requests-oauthlib==1.3.0
+rsa==4.7.2
+scikit-learn==0.24.2
+scipy==1.7.1
+seaborn==0.11.2
+Send2Trash==1.8.0
+six==1.15.0
+smmap==4.0.0
+split-folders==0.4.3
+streamlit==1.0.0
+tensorboard==2.6.0
+tensorboard-data-server==0.6.1
+tensorboard-plugin-wit==1.8.0
+tensorflow==2.6.0
+tensorflow-docs==0.0.0.dev0
+tensorflow-estimator==2.6.0
+termcolor==1.1.0
+terminado==0.12.1
+testpath==0.5.0
+Theano==1.0.5
+threadpoolctl==2.2.0
+toml==0.10.2
+toolz==0.11.1
+tornado==6.1
+traitlets==5.1.0
+typing-extensions==3.7.4.3
+tzlocal==3.0
+urllib3==1.26.6
+validators==0.18.2
+watchdog==2.1.6
+wcwidth==0.2.5
+webencodings==0.5.1
+Werkzeug==2.0.1
+widgetsnbextension==3.5.1
+wrapt==1.12.1
diff --git a/trainer.py b/trainer.py
new file mode 100644
index 0000000..e34e63f
--- /dev/null
+++ b/trainer.py
@@ -0,0 +1,263 @@
+from tensorflow import keras
+from tensorflow.keras import layers
+from keras.preprocessing.image import load_img, ImageDataGenerator, img_to_array
+from tensorflow.keras.models import load_model
+from tensorflow.keras.applications.vgg19 import VGG19
+from tensorflow.keras.models import Model
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import streamlit as st
+from CustomCallback import CustomCallback
+import os, shutil
+import splitfolders
+
+
+########## Augmentation with VGG19 #############
+def tranferLearningVGG(model, train_dir, test_dir):
+    # Transfer learning: freeze the VGG19 convolutional base and train a new classification head.
+    model = VGG19(include_top=False, input_shape=(224, 224, 3))
+
+    train_datagen = ImageDataGenerator(
+        rescale=1 / 255,
+        horizontal_flip=True,
+        rotation_range=20,
+        width_shift_range=0.2,
+        height_shift_range=0.2
+    )
+    test_datagen = ImageDataGenerator(
+        rescale=1 / 255,
+    )
+
+    for layer in model.layers:
+        layer.trainable = False
+    flatten_layer = layers.Flatten()(model.output)
+
+    flatten_fully_connected_layer = layers.Dense(512, activation='relu')(flatten_layer)
+    flatten_fully_connected_softmax_layer = layers.Dense(5, activation='softmax')(flatten_fully_connected_layer)
+    model = Model(inputs=model.inputs, outputs=flatten_fully_connected_softmax_layer)
+
+    training_iterater = train_datagen.flow_from_directory(train_dir, batch_size=64, target_size=(224, 224))
+    test_iterater = test_datagen.flow_from_directory(test_dir, batch_size=64, target_size=(224, 224))
+    model.compile(loss="categorical_crossentropy", metrics=['accuracy'], optimizer='adam')
+    history = model.fit(training_iterater, validation_data=test_iterater, epochs=4)
+    model.save('models/vgg19_tl.h5')  # save the trained model, not the freshly built one
+    return history
+
+
+############ Model creation ################
+def buildSimpleModel(img_size=(100, 100), num_classes=5):
+    # Small CNN; input size and class count follow the values chosen in the UI.
+    model = keras.Sequential()
+    model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(img_size[0], img_size[1], 3)))
+    model.add(layers.MaxPooling2D((2, 2)))
+
+    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
+    model.add(layers.MaxPooling2D((2, 2)))
+
+    model.add(layers.Conv2D(128, (3, 3), activation='relu'))
+    model.add(layers.MaxPooling2D((2, 2)))
+
+    model.add(layers.Flatten())
+
+    model.add(layers.Dense(512, activation='relu'))
+    model.add(layers.Dense(num_classes, activation='softmax'))
+    return model
+
+
+###### prepare iterators ########
+def prepareIterater(folder_path, batch_size, img_size):
+    train_datagen = ImageDataGenerator(rescale=1 / 255)
+    test_datagen = ImageDataGenerator(rescale=1 / 255)
+    training_iterater = train_datagen.flow_from_directory(os.path.join(folder_path, 'train'), batch_size=batch_size, target_size=img_size)
+    test_iterater = test_datagen.flow_from_directory(os.path.join(folder_path, 'test'), batch_size=batch_size, target_size=img_size)
+    return training_iterater, test_iterater
+
+
+########## start training ##############
+def trainSimplaeModel(model, epochs, tran_iterater, test_iterater, model_name, text_output, progrss_bar, graph):
+    model.compile(loss="categorical_crossentropy", metrics=['accuracy'], optimizer='adam')
+    history = model.fit(tran_iterater, validation_data=test_iterater, epochs=epochs,
+                        callbacks=[CustomCallback(text_output, progrss_bar, graph)])
+    model_path = os.path.join(st.session_state['output_folder'], model_name)
+    st.session_state['model_folder'] = st.session_state["output_folder"]
+    model.save(model_path)
+    return history
+
+
+######### plot #########
+def plotHistory(history):
+    plt.plot(history.history['loss'])
+    plt.plot(history.history['val_loss'])
+    plt.title("model loss")
+    plt.ylabel("loss")
+    plt.xlabel("no of epochs")
+    plt.legend(['training', 'testing'], loc='upper left')
+    plt.show()
+
+
+####### print class labels #######
+def printLabels(training_iterater):
+    class_labels = training_iterater.class_indices
+    print(class_labels)
+
+
+####### predict ########
+def predict(modelname, class_labels):
+    mdl = load_model('models/' + modelname)
+    img = load_img('dataset/flowers/sunflower.jpeg', target_size=(100, 100))
+    img = img_to_array(img).reshape(1, 100, 100, 3) / 255.0  # match the training-time rescaling
+    res = mdl.predict(img)
+    print(res)
+    print(np.argmax(res))
+    print(list(class_labels)[np.argmax(res)])
+
+
+def getModelNames(model_folder):
+    # List all .h5 model files in the given folder.
+    if model_folder == "" or not os.path.isdir(model_folder):
+        return []
+    param = []
+    for entry in os.scandir(model_folder):
+        if entry.name.lower().endswith('.h5'):
+            param.append(entry.name)
+    return param
+
+
+##### prepare image folders ######
+def prepareFolders(folder_path, output_folder):
+    # Remove any previous output, then split the raw images into train/val/test (80/10/10).
+    try:
+        if os.path.isfile(output_folder) or os.path.islink(output_folder):
+            os.unlink(output_folder)
+        elif os.path.isdir(output_folder):
+            shutil.rmtree(output_folder)
+        splitfolders.ratio(folder_path, output=output_folder, seed=1337, ratio=(0.8, 0.1, 0.1))
+    except Exception as e:
+        st.error("Could not prepare image folders: {}".format(e))
+
+
+##### handle change on image input folder #####
+def handlerImageFolderChanged():
+    output_folder = st.session_state["output_folder"]
+
+
+def handlePrepare():
+    # Split the folder chosen in the UI, falling back to the default raw-image path.
+    prepareFolders(st.session_state.get('selected_image_folder', raw_image_folder), st.session_state['output_folder'])
+
+
+def handSelectModel():
+    st.write('#####')
+
+
+raw_image_folder = "/Users/mohamednouffer/workspace/akira_san/image_classifier/dataset/raw_data"
+
+st.title("Sumasen AI")
+st.sidebar.header("Sumasen Trainer")
+st.session_state["image_arranged"] = False
+
+options = ("New", "Existing model")
+
+mode = st.sidebar.empty()
+model_empty = st.sidebar.empty()
+
+if "training_mode" not in st.session_state:
+    st.session_state['training_mode'] = "New"
+
+training_mode = mode.radio("Training mode:", options, 0)
+st.session_state['training_mode'] = training_mode
+
+if training_mode == "Existing model":
+
+    if 'model_folder' in st.session_state:
+        model_folder = model_empty.text_input("model folder", st.session_state['model_folder'])
+        if st.sidebar.button("Load models"):
+            selected_model = st.sidebar.selectbox("Select a Model", getModelNames(model_folder), on_change=handSelectModel, key='model_name')
+            st.session_state['selected_model'] = selected_model
+    else:
+        model_folder = model_empty.text_input("Enter model folder", '')
+        if st.sidebar.button("Load models"):
+            selected_model = st.sidebar.selectbox("Select a Model", getModelNames(model_folder))
+            st.session_state['model_folder'] = model_folder
+            st.session_state['selected_model'] = selected_model
+
+    if 'model_name' in st.session_state:
+        st.write('Selected model is:', st.session_state['model_name'])
+        st.subheader("Predict an image")
+        file = st.file_uploader("Upload image", ['jpeg', 'jpg', 'png'])
+        if file:
+            # Save the upload next to the models so it can be re-read with load_img.
+            with open(os.path.join(st.session_state['model_folder'], file.name), "wb") as f:
+                f.write(file.getbuffer())
+            model_file = os.path.join(st.session_state['model_folder'], st.session_state['model_name'])
+            mdl = load_model(model_file)
+            # target_size must match the input shape the model was trained with
+            img = load_img(os.path.join(st.session_state['model_folder'], file.name), target_size=(100, 100))
+            st.image(img)
+            img = img_to_array(img).reshape(1, 100, 100, 3) / 255.0  # match the training-time rescaling
+            res = mdl.predict(img)
+            st.write(np.argmax(res))
+
+else:
+    if "output_folder" in st.session_state and os.path.isdir(st.session_state["output_folder"]):
+        for fls in os.scandir(st.session_state["output_folder"]):
+            if os.path.isdir(fls.path) and fls.name in ("train", "test", "val"):
+                st.info("Images are ready for training")
+                st.session_state["image_arranged"] = True
+
+    selected_image_folder = st.text_input("Enter image folder (make sure only category folders are in this directory)", "Enter images folder ...")
+    output_folder = os.path.join(selected_image_folder, "arranged")
+    output_folder = st.text_input("Enter a folder to prepare images", output_folder)
+    input_shape = st.selectbox('Select image input shape', [(100, 100), (150, 150), (200, 200), (240, 240), (300, 300)])
+    batch_size = st.slider('Batch size', 1, 1000, 40)
+    if selected_image_folder == "Enter images folder ..." or selected_image_folder == "":
+        st.error("Need a valid image folder")
+    else:
+        try:
+            st.session_state["selected_image_folder"] = selected_image_folder
+            st.session_state["output_folder"] = output_folder
+            arranged = st.button("arrange images", on_click=handlePrepare)
+            if st.session_state["image_arranged"]:
+                traing_iterater, test_iterater = prepareIterater(st.session_state['output_folder'], batch_size, input_shape)
+                class_labels = traing_iterater.class_indices
+                st.write('class labels', class_labels)
+                model_name = st.text_input('model name:', 'mymodel.h5')
+                epochs = st.slider("Epochs", 1, 500, 2)
+                if st.button('begin train') and epochs and batch_size:
+                    mdl = buildSimpleModel(input_shape, len(class_labels))
+                    text_output = st.empty()
+                    graph = st.empty()
+                    my_bar = st.empty()
+                    history = trainSimplaeModel(mdl, epochs, traing_iterater, test_iterater, model_name, text_output, my_bar, graph)
+                    his_df = pd.DataFrame(history.history)
+                    st.line_chart(his_df)
+                    training_mode = "Existing model"
+
+        except Exception as ex:
+            st.error(ex)
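
Usage note: the app is started with `streamlit run trainer.py`; the "New" mode splits a raw image folder with split-folders, trains the small CNN, and saves it as an .h5 file, while the "Existing model" mode loads a saved .h5 and predicts on an uploaded image. For reference, the minimal sketch below shows how CustomCallback is meant to be wired into an ordinary model.fit() call from a Streamlit script. The dataset path, image size, batch size, and epoch count are illustrative placeholders, not part of the patch.

# sketch.py -- illustration only, not part of the patch.
# Assumes an "arranged" folder produced by the "arrange images" step
# (train/ and test/ sub-folders, one directory per class).
import streamlit as st
from tensorflow import keras
from tensorflow.keras import layers
from keras.preprocessing.image import ImageDataGenerator
from CustomCallback import CustomCallback

datagen = ImageDataGenerator(rescale=1 / 255)
train_it = datagen.flow_from_directory("dataset/arranged/train", target_size=(100, 100), batch_size=32)
val_it = datagen.flow_from_directory("dataset/arranged/test", target_size=(100, 100), batch_size=32)

# The callback expects three placeholders: status text, progress bar, metrics chart.
status, progress, chart = st.empty(), st.empty(), st.empty()

model = keras.Sequential([
    layers.Conv2D(32, (3, 3), activation="relu", input_shape=(100, 100, 3)),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(train_it.num_classes, activation="softmax"),
])
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.fit(train_it, validation_data=val_it, epochs=2,
          callbacks=[CustomCallback(status, progress, chart)])

Run the sketch with `streamlit run sketch.py` from the repository root; the status line, progress bar, and loss/accuracy chart update live as batches complete.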