From 23c99e15ca42ec8ba4a93bdbcdd6cd2c52b3801f Mon Sep 17 00:00:00 2001 From: Elad Shaham <7040645+eshaham@users.noreply.github.com> Date: Mon, 8 Oct 2018 20:56:17 +0300 Subject: [PATCH] save final version --- mnist_cnn.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/mnist_cnn.py b/mnist_cnn.py index 8a300c3..997e339 100644 --- a/mnist_cnn.py +++ b/mnist_cnn.py @@ -11,6 +11,9 @@ import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data +import missinglink + +project = missinglink.TensorFlowProject() # Input params NUM_CLASSES = 10 # The MNIST dataset has 10 classes, representing the digits 0 through 9. @@ -115,18 +118,20 @@ def run_training(): session = tf.Session() session.run(init) - # Start the training loop - for step in range(MAX_STEPS): - feed_dict = fill_feed_dict(data_sets.train, images_placeholder, labels_placeholder) - - _, loss_value = session.run([train_op, loss], feed_dict=feed_dict) + with project.create_experiment() as experiment: + # Start the training loop + for step in experiment.loop(max_iterations=MAX_STEPS): + feed_dict = fill_feed_dict(data_sets.train, images_placeholder, labels_placeholder) - # Validate the model with the validation dataset - if (step + 1) % 500 == 0 or (step + 1) == MAX_STEPS: - print('Step %d: loss = %.2f' % (step, loss_value)) - print('Running on validation dataset...') - do_eval(session, eval_correct, images_placeholder, labels_placeholder, data_sets.validation) + with experiment.train(monitored_metrics={'loss': loss, 'acc': eval_correct}): + _, loss_value = session.run([train_op, loss], feed_dict=feed_dict) + # Validate the model with the validation dataset + if (step + 1) % 500 == 0 or (step + 1) == MAX_STEPS: + print('Step %d: loss = %.2f' % (step, loss_value)) + print('Running on validation dataset...') + with experiment.validation(monitored_metrics={'loss': loss, 'acc': eval_correct}): + do_eval(session, eval_correct, images_placeholder, labels_placeholder, data_sets.validation) if __name__ == '__main__': run_training() \ No newline at end of file