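# Page header captured from the project file viewer:
#   "Проект / Общее / Профиль" (Project / General / Profile)
#   File: cifar10_eval.py, uploaded by Сергей Мальковский (Sergey Malkovsky),
#   27.09.2017 15:44, "Загрузить (5,53 КБ)" (Download, 5.53 KB).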
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

    
"""Evaluation for CIFAR-10.

Accuracy:
cifar10_train.py achieves 83.0% accuracy after 100K steps (256 epochs
of data) as judged by cifar10_eval.py.

Speed:
On a single Tesla K40, cifar10_train.py processes a single batch of 128 images
in 0.25-0.35 sec (i.e. 350 - 600 images /sec). The model reaches ~86%
accuracy after 100K steps in 8 hours of training time.

Usage:
Please see the tutorial and website for how to download the CIFAR-10
data set, compile the program and train the model.

http://tensorflow.org/tutorials/deep_cnn/
"""
from __future__ import absolute_import
34
from __future__ import division
35
from __future__ import print_function
36

    
37
from datetime import datetime
38
import math
39
import time
40
import os
41
import numpy as np
42
import tensorflow as tf
43

    
44
import cifar10
45

    
46
FLAGS = tf.app.flags.FLAGS

# Per-user default directories, so several people sharing one machine do not
# overwrite each other's files under /tmp. USER is not guaranteed to be set
# (USERNAME is the usual Windows equivalent), so fall back to it and then to a
# fixed name instead of failing with KeyError at import time.
username = os.environ.get('USER') or os.environ.get('USERNAME') or 'default'

tf.app.flags.DEFINE_string('eval_dir',
                           os.path.join('/tmp', username, 'cifar10_eval'),
                           """Directory where to write event logs.""")
tf.app.flags.DEFINE_string('eval_data', 'test',
                           """Either 'test' or 'train_eval'.""")
tf.app.flags.DEFINE_string('checkpoint_dir',
                           os.path.join('/tmp', username, 'cifar10_train'),
                           """Directory where to read model checkpoints.""")
tf.app.flags.DEFINE_integer('eval_interval_secs', 60 * 5,
                            """How often to run the eval.""")
tf.app.flags.DEFINE_integer('num_examples', 10000,
                            """Number of examples to run.""")
tf.app.flags.DEFINE_boolean('run_once', True,
                            """Whether to run eval only once.""")
62

    
63

    
64
def eval_once(saver, summary_writer, top_k_op, summary_op):
  """Run one evaluation pass over the latest checkpoint.

  Restores the newest checkpoint listed in FLAGS.checkpoint_dir, runs
  `top_k_op` for ceil(FLAGS.num_examples / FLAGS.batch_size) batches (or until
  the coordinator asks to stop), prints precision @ 1 and writes it, together
  with the merged summaries, to `summary_writer`. If no checkpoint exists it
  only prints a message and returns.

  Args:
    saver: Saver used to restore the model variables.
    summary_writer: Summary writer the evaluation summary is added to.
    top_k_op: Boolean in_top_k op; summing its output counts the correct
      predictions in one batch.
    summary_op: Merged summary op, or None when no summaries were defined.
  """
  with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
      print('No checkpoint file found')
      return

    # Restores from checkpoint.
    saver.restore(sess, ckpt.model_checkpoint_path)
    # Assuming model_checkpoint_path looks something like:
    #   /my-favorite-path/cifar10_train/model.ckpt-0,
    # extract global_step from it.
    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]

    # Start the queue runners.
    coord = tf.train.Coordinator()
    threads = []
    try:
      for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
                                         start=True))

      num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
      true_count = 0  # Counts the number of correct predictions.
      step = 0
      while step < num_iter and not coord.should_stop():
        predictions = sess.run(top_k_op)
        true_count += np.sum(predictions)
        step += 1

      # Divide by the batches that actually ran: if the coordinator stopped
      # the loop early, num_iter * batch_size would understate precision.
      total_sample_count = step * FLAGS.batch_size
      if total_sample_count:
        # Compute precision @ 1.
        precision = true_count / total_sample_count
        print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))
        summary = tf.Summary()
        # tf.summary.merge_all() returns None when no summaries exist, and
        # sess.run(None) would fail.
        if summary_op is not None:
          summary.ParseFromString(sess.run(summary_op))
        summary.value.add(tag='Precision @ 1', simple_value=precision)
        summary_writer.add_summary(summary, global_step)
      else:
        print('%s: no evaluation batches were run' % datetime.now())
    except Exception as e:  # pylint: disable=broad-except
      # Recorded on the coordinator; coord.join below re-raises it after the
      # queue-runner threads have been stopped.
      coord.request_stop(e)

    coord.request_stop()
    coord.join(threads, stop_grace_period_secs=10)
117

    
118

    
119
def evaluate():
  """Eval CIFAR-10 for a number of steps.

  Builds the evaluation graph (inputs -> inference -> in_top_k), restores the
  exponential-moving-average versions of the variables, and calls eval_once
  either once (FLAGS.run_once) or every FLAGS.eval_interval_secs seconds
  indefinitely.
  """
  with tf.Graph().as_default() as g:
    # Get images and labels for CIFAR-10. eval_data is True when
    # FLAGS.eval_data == 'test'; presumably cifar10.inputs then reads the
    # test split (verify in cifar10.py).
    eval_data = FLAGS.eval_data == 'test'
    images, labels = cifar10.inputs(eval_data=eval_data)

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate predictions: true where the label is the top-1 logit.
    top_k_op = tf.nn.in_top_k(logits, labels, 1)

    # Restore the moving average version of the learned variables for eval.
    variable_averages = tf.train.ExponentialMovingAverage(
        cifar10.MOVING_AVERAGE_DECAY)
    variables_to_restore = variable_averages.variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.summary.merge_all()

    summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)
    try:
      while True:
        eval_once(saver, summary_writer, top_k_op, summary_op)
        if FLAGS.run_once:
          break
        time.sleep(FLAGS.eval_interval_secs)
    finally:
      # Flush pending events and release the event file. Without this, the
      # last evaluation summary may never reach disk when the process exits.
      summary_writer.close()
149

    
150

    
151
def main(argv=None):  # pylint: disable=unused-argument
  """Entry point used by tf.app.run.

  Calls cifar10.maybe_download_and_extract() to make the data set available,
  deletes any existing FLAGS.eval_dir and recreates it empty (so the event
  logs only contain this run), then runs evaluate().
  """
  cifar10.maybe_download_and_extract()
  if tf.gfile.Exists(FLAGS.eval_dir):
    tf.gfile.DeleteRecursively(FLAGS.eval_dir)
  tf.gfile.MakeDirs(FLAGS.eval_dir)
  evaluate()
158

    
159
if __name__ == '__main__':
  # tf.app.run parses the command-line flags and then calls main().
  tf.app.run()