From f8ce682583929b54fbeae3756a8d9e3fc896065b Mon Sep 17 00:00:00 2001 From: Mohamed Nouffer Date: Fri, 22 Oct 2021 13:12:47 +0530 Subject: [PATCH] update cate size --- trainer.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/trainer.py b/trainer.py index e34e63f..ec08195 100644 --- a/trainer.py +++ b/trainer.py @@ -51,7 +51,7 @@ def tranferLearningVGG(model, train_dir, test_dir): ############ Model creation ################ -def buildSimpleModel(): +def buildSimpleModel(cats): model = keras.Sequential() model.add(layers.Conv2D(32, (3,3), activation='relu', input_shape=(100,100,3))) model.add(layers.MaxPooling2D((2,2))) @@ -65,7 +65,7 @@ def buildSimpleModel(): model.add(layers.Flatten()) model.add(layers.Dense(512, activation='relu')) - model.add(layers.Dense(5, activation='softmax')) + model.add(layers.Dense(cats, activation='softmax')) return model ###### preapre iterater ######## @@ -204,6 +204,7 @@ if training_mode == "Existing model": if 'model_name' in st.session_state: st.write('Selected models is :', st.session_state['model_name']) st.subheader("Predict a image") + input_shape = st.selectbox('Select image input shape',[(100,100), [150, 150], (200,200), (240,240), (300, 300)]) file = st.file_uploader("Upload image",['jpeg', 'jpg', 'png']) if file: bytes_data = file.getvalue() @@ -211,12 +212,12 @@ if training_mode == "Existing model": f.write(file.getbuffer()) model_file = os.path.join(st.session_state['model_folder'], st.session_state['model_name']) mdl = load_model(model_file) - img = load_img(os.path.join(st.session_state['model_folder'], file.name), target_size=(100,100)) + img = load_img(os.path.join(st.session_state['model_folder'], file.name), target_size=input_shape) st.image(img) img = img_to_array(img) - img = img.reshape(1,100,100,3) + img = img.reshape(1,input_shape[0],input_shape[1],3) res = mdl.predict(img) - st.write(np.argmax(res)) + st.markdown('## Category is {}'.format(np.argmax(res))) #st.write([key for key in class_labels][np.argmax(res)]) else: @@ -246,7 +247,7 @@ else: model_name = st.text_input('model name:', 'mymodel.h5') epochs = st.slider("Epochs",1,500, 2) if st.button('begin train') and epochs and batch_size: - mdl = buildSimpleModel() + mdl = buildSimpleModel(len(class_labels)) text_output = st.empty() graph = st.empty() my_bar = st.empty() @@ -254,6 +255,7 @@ else: his_df = pd.DataFrame(history.history) st.line_chart(his_df) training_mode = "Existing model" + st.markdown('## Training completed, plese check the output folder for saved model.') except Exception as ex: st.error(ex)