update cate size
This commit is contained in:
14
trainer.py
14
trainer.py
@ -51,7 +51,7 @@ def tranferLearningVGG(model, train_dir, test_dir):
|
|||||||
|
|
||||||
|
|
||||||
############ Model creation ################
|
############ Model creation ################
|
||||||
def buildSimpleModel():
|
def buildSimpleModel(cats):
|
||||||
model = keras.Sequential()
|
model = keras.Sequential()
|
||||||
model.add(layers.Conv2D(32, (3,3), activation='relu', input_shape=(100,100,3)))
|
model.add(layers.Conv2D(32, (3,3), activation='relu', input_shape=(100,100,3)))
|
||||||
model.add(layers.MaxPooling2D((2,2)))
|
model.add(layers.MaxPooling2D((2,2)))
|
||||||
@ -65,7 +65,7 @@ def buildSimpleModel():
|
|||||||
model.add(layers.Flatten())
|
model.add(layers.Flatten())
|
||||||
|
|
||||||
model.add(layers.Dense(512, activation='relu'))
|
model.add(layers.Dense(512, activation='relu'))
|
||||||
model.add(layers.Dense(5, activation='softmax'))
|
model.add(layers.Dense(cats, activation='softmax'))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
###### preapre iterater ########
|
###### preapre iterater ########
|
||||||
@ -204,6 +204,7 @@ if training_mode == "Existing model":
|
|||||||
if 'model_name' in st.session_state:
|
if 'model_name' in st.session_state:
|
||||||
st.write('Selected models is :', st.session_state['model_name'])
|
st.write('Selected models is :', st.session_state['model_name'])
|
||||||
st.subheader("Predict a image")
|
st.subheader("Predict a image")
|
||||||
|
input_shape = st.selectbox('Select image input shape',[(100,100), [150, 150], (200,200), (240,240), (300, 300)])
|
||||||
file = st.file_uploader("Upload image",['jpeg', 'jpg', 'png'])
|
file = st.file_uploader("Upload image",['jpeg', 'jpg', 'png'])
|
||||||
if file:
|
if file:
|
||||||
bytes_data = file.getvalue()
|
bytes_data = file.getvalue()
|
||||||
@ -211,12 +212,12 @@ if training_mode == "Existing model":
|
|||||||
f.write(file.getbuffer())
|
f.write(file.getbuffer())
|
||||||
model_file = os.path.join(st.session_state['model_folder'], st.session_state['model_name'])
|
model_file = os.path.join(st.session_state['model_folder'], st.session_state['model_name'])
|
||||||
mdl = load_model(model_file)
|
mdl = load_model(model_file)
|
||||||
img = load_img(os.path.join(st.session_state['model_folder'], file.name), target_size=(100,100))
|
img = load_img(os.path.join(st.session_state['model_folder'], file.name), target_size=input_shape)
|
||||||
st.image(img)
|
st.image(img)
|
||||||
img = img_to_array(img)
|
img = img_to_array(img)
|
||||||
img = img.reshape(1,100,100,3)
|
img = img.reshape(1,input_shape[0],input_shape[1],3)
|
||||||
res = mdl.predict(img)
|
res = mdl.predict(img)
|
||||||
st.write(np.argmax(res))
|
st.markdown('## Category is {}'.format(np.argmax(res)))
|
||||||
#st.write([key for key in class_labels][np.argmax(res)])
|
#st.write([key for key in class_labels][np.argmax(res)])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -246,7 +247,7 @@ else:
|
|||||||
model_name = st.text_input('model name:', 'mymodel.h5')
|
model_name = st.text_input('model name:', 'mymodel.h5')
|
||||||
epochs = st.slider("Epochs",1,500, 2)
|
epochs = st.slider("Epochs",1,500, 2)
|
||||||
if st.button('begin train') and epochs and batch_size:
|
if st.button('begin train') and epochs and batch_size:
|
||||||
mdl = buildSimpleModel()
|
mdl = buildSimpleModel(len(class_labels))
|
||||||
text_output = st.empty()
|
text_output = st.empty()
|
||||||
graph = st.empty()
|
graph = st.empty()
|
||||||
my_bar = st.empty()
|
my_bar = st.empty()
|
||||||
@ -254,6 +255,7 @@ else:
|
|||||||
his_df = pd.DataFrame(history.history)
|
his_df = pd.DataFrame(history.history)
|
||||||
st.line_chart(his_df)
|
st.line_chart(his_df)
|
||||||
training_mode = "Existing model"
|
training_mode = "Existing model"
|
||||||
|
st.markdown('## Training completed, plese check the output folder for saved model.')
|
||||||
|
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
st.error(ex)
|
st.error(ex)
|
||||||
|
|||||||
Reference in New Issue
Block a user