111 lines
4.0 KiB
Python
111 lines
4.0 KiB
Python
from streamlit.proto.Progress_pb2 import Progress
|
|
from tensorflow import keras
|
|
import streamlit as st
|
|
import pandas as pd
|
|
|
|
|
|
class CustomCallback(keras.callbacks.Callback):
    """Keras callback that mirrors training progress into Streamlit widgets.

    While a model trains or evaluates, the callback:

    * writes a one-line status message into ``element`` via its
      ``markdown`` method,
    * advances a Streamlit progress bar once per training batch, and
    * redraws a line chart of the per-batch ``loss`` / ``accuracy`` values
      reported in ``logs`` for the current epoch.

    The prediction hooks are overridden but intentionally silent.
    """

    # Class-level defaults, kept so code reading these names on the class
    # itself keeps working; every instance gets its own copies in __init__.
    percent_complete = 0
    data_frame = pd.DataFrame()

    # Upper bound of the progress value this callback sends to the bar.
    _PROGRESS_MAX = 100
    # Amount the bar is rewound once it hits the upper bound mid-epoch.
    _PROGRESS_REWIND = 20
    # Keys of ``logs`` that are charted when Keras reports them.
    _PLOTTED_METRICS = ("loss", "accuracy")

    def __init__(self, element, pregress, metrics):
        """Create the progress bar and the (initially empty) metrics chart.

        Args:
            element: Object exposing ``markdown(text)``; receives the status
                messages.
            pregress: Accepted but not used -- a fresh ``st.progress`` bar is
                always created. The (misspelled) name is kept so existing
                keyword callers keep working.
            metrics: Accepted but not used -- a fresh ``st.line_chart`` is
                always created.
        """
        # Let the Keras base class initialise its own state.
        super().__init__()
        self.element = element
        self.percent_complete = 0
        self.data_frame = pd.DataFrame()
        self.progress = st.progress(0)
        self.metrics = st.line_chart(self.data_frame)

    def _show_status(self, message):
        """Render ``message`` in the status element using the shared format."""
        self.element.markdown("* %s*..." % message)

    def on_train_begin(self, logs=None):
        """Announce that training has started."""
        self._show_status("Starting training")

    def on_train_end(self, logs=None):
        """Fill the progress bar and announce that training has finished."""
        self.progress.progress(self._PROGRESS_MAX)
        self._show_status("Finished training")

    def on_epoch_begin(self, epoch, logs=None):
        """Announce the epoch and reset the per-epoch chart data and progress."""
        self._show_status("Start epoch {} of training".format(epoch))
        self.data_frame = pd.DataFrame()
        self.percent_complete = 0

    def on_epoch_end(self, epoch, logs=None):
        """Announce that the epoch has ended."""
        self._show_status("End epoch {} of training ".format(epoch))

    def on_test_begin(self, logs=None):
        """Announce that evaluation has started."""
        self._show_status("Start testing")

    def on_test_end(self, logs=None):
        """Announce that evaluation has finished."""
        self._show_status("Stop testing")

    def on_predict_begin(self, logs=None):
        """Do nothing; status output for prediction is disabled."""

    def on_predict_end(self, logs=None):
        """Do nothing; status output for prediction is disabled."""

    def on_train_batch_begin(self, batch, logs=None):
        """Advance the progress bar by one step and announce the batch.

        When the bar is already full it is rewound by ``_PROGRESS_REWIND``
        and this call returns without touching any widget, so the value sent
        to the bar never exceeds ``_PROGRESS_MAX``.
        """
        if self.percent_complete >= self._PROGRESS_MAX:
            self.percent_complete -= self._PROGRESS_REWIND
            return
        self.percent_complete += 1
        self.progress.progress(self.percent_complete)
        self._show_status("...Training: start of batch {}".format(batch))

    def on_train_batch_end(self, batch, logs=None):
        """Append this batch's metrics to the chart and announce the batch.

        Only the metrics actually present in ``logs`` are recorded, so a
        model compiled without ``accuracy`` no longer raises ``KeyError``.
        """
        logs = logs or {}
        row = {
            name: logs[name]
            for name in self._PLOTTED_METRICS
            if name in logs
        }
        if row:
            new_row = pd.DataFrame([row])
            if self.data_frame.empty:
                self.data_frame = new_row
            else:
                # DataFrame.append was removed in pandas 2.0; concat is the
                # supported way to add rows.
                self.data_frame = pd.concat(
                    [self.data_frame, new_row], ignore_index=True
                )
            self.metrics.line_chart(self.data_frame)
        self._show_status("...Training: end of batch {}".format(batch))

    def on_test_batch_begin(self, batch, logs=None):
        """Announce the start of an evaluation batch."""
        self._show_status("...Evaluating: start of batch {}".format(batch))

    def on_test_batch_end(self, batch, logs=None):
        """Announce the end of an evaluation batch."""
        self._show_status("...Evaluating: end of batch {}".format(batch))

    def on_predict_batch_begin(self, batch, logs=None):
        """Do nothing; status output for prediction is disabled."""

    def on_predict_batch_end(self, batch, logs=None):
        """Do nothing; status output for prediction is disabled."""