initial trainer
This commit is contained in:
111
CustomCallback.py
Normal file
111
CustomCallback.py
Normal file
@ -0,0 +1,111 @@
|
||||
from streamlit.proto.Progress_pb2 import Progress
|
||||
from tensorflow import keras
|
||||
import streamlit as st
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class CustomCallback(keras.callbacks.Callback):
|
||||
percent_complete = 0
|
||||
data_frame = pd.DataFrame()
|
||||
|
||||
def __init__(self, element, pregress, metrics):
|
||||
self.element = element
|
||||
self.progress = st.progress(0)
|
||||
self.metrics = st.line_chart(self.data_frame)
|
||||
|
||||
def on_train_begin(self, logs=None):
|
||||
keys = list(logs.keys())
|
||||
keystring = "Starting training"
|
||||
self.element.markdown(" %s*..." % keystring)
|
||||
#st.write(keystring)
|
||||
#print("Starting training; got log keys: {}".format(keys))
|
||||
|
||||
def on_train_end(self, logs=None):
|
||||
keys = list(logs.keys())
|
||||
keystring = "Finished training"
|
||||
self.progress.progress(100)
|
||||
self.element.markdown("* %s*..." % keystring)
|
||||
#st.write(keystring)
|
||||
|
||||
def on_epoch_begin(self, epoch, logs=None):
|
||||
keys = list(logs.keys())
|
||||
keystring = "Start epoch {} of training"
|
||||
self.element.markdown("* %s*..." % keystring)
|
||||
self.data_frame = pd.DataFrame()
|
||||
self.percent_complete = 0
|
||||
#st.write(keystring)
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
keys = list(logs.keys())
|
||||
keystring = "End epoch {} of training ".format(epoch)
|
||||
self.element.markdown("* %s*..." % keystring)
|
||||
#st.write()
|
||||
|
||||
def on_test_begin(self, logs=None):
|
||||
keys = list(logs.keys())
|
||||
keystring = "Start testing"
|
||||
self.element.markdown("* %s*..." % keystring)
|
||||
#st.write(keystring)
|
||||
|
||||
def on_test_end(self, logs=None):
|
||||
keys = list(logs.keys())
|
||||
keystring = "Stop testing"
|
||||
self.element.markdown("* %s*..." % keystring)
|
||||
#st.write(keystring)
|
||||
|
||||
def on_predict_begin(self, logs=None):
|
||||
keys = list(logs.keys())
|
||||
keystring = "Start predicting"
|
||||
#self.element.markdown("*Training %s*..." % keystring)
|
||||
#st.write()
|
||||
|
||||
def on_predict_end(self, logs=None):
|
||||
keys = list(logs.keys())
|
||||
keystring ="Stop predicting; got log keys"
|
||||
#self.element.markdown("*Training %s*..." % keystring)
|
||||
#st.write(keystring)
|
||||
|
||||
def on_train_batch_begin(self, batch, logs=None):
|
||||
keys = list(logs.keys())
|
||||
keystring = "...Training: start of batch {}".format(batch)
|
||||
if self.percent_complete == 100:
|
||||
self.percent_complete = self.percent_complete - 20
|
||||
return
|
||||
self.percent_complete = self.percent_complete + 1
|
||||
#st.write(self.percent_complete)
|
||||
self.progress.progress(self.percent_complete)
|
||||
self.element.markdown("* %s*..." % keystring)
|
||||
#st.write(keystring)
|
||||
|
||||
def on_train_batch_end(self, batch, logs=None):
|
||||
keys = list(logs.keys())
|
||||
keystring = "...Training: end of batch {}".format(batch)
|
||||
#st.write(logs)
|
||||
self.data_frame = self.data_frame.append({'loss':logs['loss'], 'accuracy':logs['accuracy']}, ignore_index=True)
|
||||
self.metrics.line_chart(self.data_frame)
|
||||
self.element.markdown("* %s*..." % keystring)
|
||||
#st.write(keystring)
|
||||
|
||||
def on_test_batch_begin(self, batch, logs=None):
|
||||
keys = list(logs.keys())
|
||||
keystring = "...Evaluating: start of batch {}".format(batch)
|
||||
self.element.markdown("* %s*..." % keystring)
|
||||
#st.write()
|
||||
|
||||
def on_test_batch_end(self, batch, logs=None):
|
||||
keys = list(logs.keys())
|
||||
keystring = "...Evaluating: end of batch {}".format(batch)
|
||||
self.element.markdown("* %s*..." % keystring)
|
||||
#st.write()
|
||||
|
||||
def on_predict_batch_begin(self, batch, logs=None):
|
||||
keys = list(logs.keys())
|
||||
keystring = "...Predicting: start of batch {}".format(batch)
|
||||
#self.element.markdown("*Training %s*..." % keystring)
|
||||
#st.write()
|
||||
|
||||
def on_predict_batch_end(self, batch, logs=None):
|
||||
keys = list(logs.keys())
|
||||
keystring = "...Predicting: end of batch {}".format(batch)
|
||||
#self.element.markdown("*Training %s*..." % keystring)
|
||||
#st.write(keystring)
|
||||
Reference in New Issue
Block a user