cifar10_eval.py
| 1 |
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|---|---|
| 2 |
#
|
| 3 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
# you may not use this file except in compliance with the License.
|
| 5 |
# You may obtain a copy of the License at
|
| 6 |
#
|
| 7 |
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
#
|
| 9 |
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
# See the License for the specific language governing permissions and
|
| 13 |
# limitations under the License.
|
| 14 |
# ==============================================================================
|
| 15 |
|
| 16 |
"""Evaluation for CIFAR-10.
|
| 17 |
|
| 18 |
Accuracy:
|
| 19 |
cifar10_train.py achieves 83.0% accuracy after 100K steps (256 epochs
|
| 20 |
of data) as judged by cifar10_eval.py.
|
| 21 |
|
| 22 |
Speed:
|
| 23 |
On a single Tesla K40, cifar10_train.py processes a single batch of 128 images
|
| 24 |
in 0.25-0.35 sec (i.e. 350 - 600 images /sec). The model reaches ~86%
|
| 25 |
accuracy after 100K steps in 8 hours of training time.
|
| 26 |
|
| 27 |
Usage:
|
| 28 |
Please see the tutorial and website for how to download the CIFAR-10
|
| 29 |
data set, compile the program and train the model.
|
| 30 |
|
| 31 |
http://tensorflow.org/tutorials/deep_cnn/
|
| 32 |
"""
|
| 33 |
from __future__ import absolute_import |
| 34 |
from __future__ import division |
| 35 |
from __future__ import print_function |
| 36 |
|
| 37 |
from datetime import datetime |
| 38 |
import math |
| 39 |
import time |
| 40 |
import os |
| 41 |
import numpy as np |
| 42 |
import tensorflow as tf |
| 43 |
|
| 44 |
import cifar10 |
| 45 |
|
| 46 |
FLAGS = tf.app.flags.FLAGS

# Per-user scratch directories live under /tmp/<username>/.  Look the name up
# defensively: USER is not set on every platform (e.g. Windows exposes
# USERNAME instead, and minimal containers may set neither), and a bare
# os.environ['USER'] would abort the import with a KeyError.
username = str(os.environ.get('USER')
               or os.environ.get('USERNAME')
               or 'unknown')

tf.app.flags.DEFINE_string('eval_dir', '/tmp/' + username + '/cifar10_eval',
                           """Directory where to write event logs.""")
tf.app.flags.DEFINE_string('eval_data', 'test',
                           """Either 'test' or 'train_eval'.""")
tf.app.flags.DEFINE_string('checkpoint_dir',
                           '/tmp/' + username + '/cifar10_train',
                           """Directory where to read model checkpoints.""")
tf.app.flags.DEFINE_integer('eval_interval_secs', 60 * 5,
                            """How often to run the eval.""")
tf.app.flags.DEFINE_integer('num_examples', 10000,
                            """Number of examples to run.""")
tf.app.flags.DEFINE_boolean('run_once', True,
                            """Whether to run eval only once.""")
|
| 62 |
|
| 63 |
|
| 64 |
def eval_once(saver, summary_writer, top_k_op, summary_op):
    """Run a single evaluation pass over the checkpointed model.

    Looks up the checkpoint state under FLAGS.checkpoint_dir, restores it into
    a fresh session, runs `top_k_op` for enough batches to cover
    FLAGS.num_examples, prints precision @ 1 and writes it (together with the
    output of `summary_op`) to `summary_writer`.  If no checkpoint is found a
    message is printed and the function returns without evaluating.

    Args:
      saver: Saver used to restore the variables from the checkpoint.
      summary_writer: Summary writer that receives the evaluation summary.
      top_k_op: Top K op; its per-batch output is summed to count the
        correct predictions.
      summary_op: Summary op whose serialized output is extended with the
        precision value.
    """
    with tf.Session() as sess:
        checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if not (checkpoint and checkpoint.model_checkpoint_path):
            print('No checkpoint file found')
            return

        checkpoint_path = checkpoint.model_checkpoint_path
        saver.restore(sess, checkpoint_path)
        # The path is assumed to look something like
        #   /my-favorite-path/cifar10_train/model.ckpt-0
        # so the global step is the text after the last '-' of the file name.
        file_name = checkpoint_path.rsplit('/', 1)[-1]
        global_step = file_name.rsplit('-', 1)[-1]

        # Start the queue runners.
        coord = tf.train.Coordinator()
        try:
            threads = []
            for runner in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                runner_threads = runner.create_threads(
                    sess, coord=coord, daemon=True, start=True)
                threads.extend(runner_threads)

            batch_size = FLAGS.batch_size
            num_iter = int(math.ceil(FLAGS.num_examples / batch_size))
            total_sample_count = num_iter * batch_size

            # Tally correct predictions batch by batch until every batch has
            # been seen or the coordinator asks us to stop.
            true_count = 0
            step = 0
            while step < num_iter and not coord.should_stop():
                predictions = sess.run([top_k_op])
                true_count += np.sum(predictions)
                step += 1

            # Compute precision @ 1.
            precision = true_count / total_sample_count
            print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))

            summary = tf.Summary()
            summary.ParseFromString(sess.run(summary_op))
            summary.value.add(tag='Precision @ 1', simple_value=precision)
            summary_writer.add_summary(summary, global_step)
        except Exception as e:  # pylint: disable=broad-except
            # Hand the error to the coordinator so the queue threads stop.
            coord.request_stop(e)

        coord.request_stop()
        coord.join(threads, stop_grace_period_secs=10)
|
| 117 |
|
| 118 |
|
| 119 |
def evaluate():
    """Eval CIFAR-10 for a number of steps.

    Builds the evaluation graph in a fresh tf.Graph, then calls eval_once()
    in a loop, sleeping FLAGS.eval_interval_secs between passes.  The loop
    ends after the first pass when FLAGS.run_once is set.  The summary writer
    is always closed on the way out so that buffered events are flushed to
    FLAGS.eval_dir even if an evaluation pass raises.
    """
    with tf.Graph().as_default() as g:
        # Get images and labels for CIFAR-10.
        eval_data = FLAGS.eval_data == 'test'
        images, labels = cifar10.inputs(eval_data=eval_data)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate predictions.
        top_k_op = tf.nn.in_top_k(logits, labels, 1)

        # Restore the moving average version of the learned variables for
        # eval.
        variable_averages = tf.train.ExponentialMovingAverage(
            cifar10.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        # Build the summary operation based on the TF collection of
        # Summaries.
        summary_op = tf.summary.merge_all()

        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)
        try:
            while True:
                eval_once(saver, summary_writer, top_k_op, summary_op)
                if FLAGS.run_once:
                    break
                time.sleep(FLAGS.eval_interval_secs)
        finally:
            # Flush pending events and release the event file.
            summary_writer.close()
| 149 |
|
| 150 |
|
| 151 |
def main(argv=None):  # pylint: disable=unused-argument
    """Prepare a clean eval directory and run the evaluation.

    Calls cifar10.maybe_download_and_extract(), removes any existing
    FLAGS.eval_dir, recreates it empty, and then calls evaluate().

    Args:
      argv: Unused.
    """
    cifar10.maybe_download_and_extract()

    # Start from an empty event-log directory on every run.
    eval_dir = FLAGS.eval_dir
    if tf.gfile.Exists(eval_dir):
        tf.gfile.DeleteRecursively(eval_dir)
    tf.gfile.MakeDirs(eval_dir)

    evaluate()
|
| 158 |
|
| 159 |
if __name__ == '__main__':
    # Script entry point: tf.app.run() parses the command-line flags and then
    # calls this module's main().
    tf.app.run()