from streamlit.proto.Progress_pb2 import Progress
from tensorflow import keras
import streamlit as st
import pandas as pd


class CustomCallback(keras.callbacks.Callback):
    """Keras callback that reports training status to a Streamlit page.

    Status messages are written through ``element.markdown``. A progress bar
    (``st.progress``) and a loss/accuracy line chart (``st.line_chart``) are
    created when the callback is constructed.
    """

    # Class-level defaults kept for backward compatibility; __init__ gives
    # each instance its own copies so instances never share chart state.
    percent_complete = 0
    data_frame = pd.DataFrame()

    # Metric keys plotted on the live chart, when present in the batch logs.
    _CHART_KEYS = ("loss", "accuracy")

    def __init__(self, element, pregress, metrics):
        """Create the Streamlit widgets used for reporting.

        Args:
            element: Streamlit element whose ``markdown`` method receives the
                status messages (presumably an ``st.empty()`` placeholder;
                confirm with caller).
            pregress: Unused; a new ``st.progress`` bar is always created.
                Kept, with its original spelling, for caller compatibility.
            metrics: Unused; a new ``st.line_chart`` is always created.
                Kept for caller compatibility.
        """
        # Initialize the Keras base class so its bookkeeping attributes exist.
        super().__init__()
        self.element = element
        self.percent_complete = 0
        self.data_frame = pd.DataFrame()
        self.progress = st.progress(0)
        self.metrics = st.line_chart(self.data_frame)

    def _status(self, message):
        """Render *message* in ``self.element`` using the shared markdown format."""
        self.element.markdown("* %s*..." % message)

    def on_train_begin(self, logs=None):
        self._status("Starting training")

    def on_train_end(self, logs=None):
        self.progress.progress(100)
        self._status("Finished training")

    def on_epoch_begin(self, epoch, logs=None):
        # The epoch number was previously never substituted into the message.
        self._status("Start epoch {} of training".format(epoch))
        # Each epoch starts a fresh chart history and progress count.
        self.data_frame = pd.DataFrame()
        self.percent_complete = 0

    def on_epoch_end(self, epoch, logs=None):
        self._status("End epoch {} of training ".format(epoch))

    def on_test_begin(self, logs=None):
        self._status("Start testing")

    def on_test_end(self, logs=None):
        self._status("Stop testing")

    # on_predict_begin/_end and on_predict_batch_begin/_end are intentionally
    # not overridden: the previous overrides only computed unused locals, so
    # the inherited base-class hooks give the same (no-op) behavior.

    def on_train_batch_begin(self, batch, logs=None):
        if self.percent_complete == 100:
            # The bar advances 1% per batch. Instead of exceeding 100 it
            # drops back to 80% (presumably because the number of batches
            # per epoch is not used here). This batch updates neither the
            # bar nor the status text.
            self.percent_complete -= 20
            return
        self.percent_complete += 1
        self.progress.progress(self.percent_complete)
        self._status("...Training: start of batch {}".format(batch))

    def on_train_batch_end(self, batch, logs=None):
        logs = logs or {}
        # Only chart the metrics that are actually logged, so a model compiled
        # without an "accuracy" metric no longer raises KeyError.
        row = {key: logs[key] for key in self._CHART_KEYS if key in logs}
        if row:
            new_row = pd.DataFrame([row])
            # DataFrame.append was removed in pandas 2.0; use concat instead.
            # Skipping concat for the empty initial frame avoids dtype warnings.
            if self.data_frame.empty:
                self.data_frame = new_row
            else:
                self.data_frame = pd.concat(
                    [self.data_frame, new_row], ignore_index=True
                )
            self.metrics.line_chart(self.data_frame)
        self._status("...Training: end of batch {}".format(batch))

    def on_test_batch_begin(self, batch, logs=None):
        self._status("...Evaluating: start of batch {}".format(batch))

    def on_test_batch_end(self, batch, logs=None):
        self._status("...Evaluating: end of batch {}".format(batch))