update cate size
This commit is contained in:
14
trainer.py
14
trainer.py
@ -51,7 +51,7 @@ def tranferLearningVGG(model, train_dir, test_dir):
|
||||
|
||||
|
||||
############ Model creation ################
|
||||
def buildSimpleModel():
|
||||
def buildSimpleModel(cats):
|
||||
model = keras.Sequential()
|
||||
model.add(layers.Conv2D(32, (3,3), activation='relu', input_shape=(100,100,3)))
|
||||
model.add(layers.MaxPooling2D((2,2)))
|
||||
@ -65,7 +65,7 @@ def buildSimpleModel():
|
||||
model.add(layers.Flatten())
|
||||
|
||||
model.add(layers.Dense(512, activation='relu'))
|
||||
model.add(layers.Dense(5, activation='softmax'))
|
||||
model.add(layers.Dense(cats, activation='softmax'))
|
||||
return model
|
||||
|
||||
###### preapre iterater ########
|
||||
@ -204,6 +204,7 @@ if training_mode == "Existing model":
|
||||
if 'model_name' in st.session_state:
|
||||
st.write('Selected models is :', st.session_state['model_name'])
|
||||
st.subheader("Predict a image")
|
||||
input_shape = st.selectbox('Select image input shape',[(100,100), [150, 150], (200,200), (240,240), (300, 300)])
|
||||
file = st.file_uploader("Upload image",['jpeg', 'jpg', 'png'])
|
||||
if file:
|
||||
bytes_data = file.getvalue()
|
||||
@ -211,12 +212,12 @@ if training_mode == "Existing model":
|
||||
f.write(file.getbuffer())
|
||||
model_file = os.path.join(st.session_state['model_folder'], st.session_state['model_name'])
|
||||
mdl = load_model(model_file)
|
||||
img = load_img(os.path.join(st.session_state['model_folder'], file.name), target_size=(100,100))
|
||||
img = load_img(os.path.join(st.session_state['model_folder'], file.name), target_size=input_shape)
|
||||
st.image(img)
|
||||
img = img_to_array(img)
|
||||
img = img.reshape(1,100,100,3)
|
||||
img = img.reshape(1,input_shape[0],input_shape[1],3)
|
||||
res = mdl.predict(img)
|
||||
st.write(np.argmax(res))
|
||||
st.markdown('## Category is {}'.format(np.argmax(res)))
|
||||
#st.write([key for key in class_labels][np.argmax(res)])
|
||||
|
||||
else:
|
||||
@ -246,7 +247,7 @@ else:
|
||||
model_name = st.text_input('model name:', 'mymodel.h5')
|
||||
epochs = st.slider("Epochs",1,500, 2)
|
||||
if st.button('begin train') and epochs and batch_size:
|
||||
mdl = buildSimpleModel()
|
||||
mdl = buildSimpleModel(len(class_labels))
|
||||
text_output = st.empty()
|
||||
graph = st.empty()
|
||||
my_bar = st.empty()
|
||||
@ -254,6 +255,7 @@ else:
|
||||
his_df = pd.DataFrame(history.history)
|
||||
st.line_chart(his_df)
|
||||
training_mode = "Existing model"
|
||||
st.markdown('## Training completed, plese check the output folder for saved model.')
|
||||
|
||||
except Exception as ex:
|
||||
st.error(ex)
|
||||
|
||||
Reference in New Issue
Block a user