[PYTHON] tensorflow mnist_deep.py

Overview

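--Some fixes so that mnist_deep.py runs on a low-spec PC.

Evaluating the whole test set in one call consumes too much memory, so the test set is evaluated in small batches and the per-batch accuracies are averaged.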

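  # Needs `import numpy as np` at the top of the script.
  batch_size = 50

  def my_eval(st, en):
    # Accuracy on test examples st..en-1, with dropout disabled.
    x_test = mnist.test.images[st:en]
    y_test = mnist.test.labels[st:en]
    return accuracy.eval(feed_dict={x: x_test, y_: y_test, keep_prob: 1.0})

  # The plain mean is exact only when batch_size divides the test set size evenly.
  test_accuracy = np.mean([my_eval(i, i + batch_size) for i in range(0, mnist.test.num_examples, batch_size)])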
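If the batch size does not divide the test set size evenly, a plain mean of per-batch accuracies gives the short last batch too much weight. Below is a minimal sketch of a size-weighted average in plain Python. The function name `batched_accuracy` and the list-based inputs are hypothetical stand-ins for the `accuracy.eval(...)` call, not part of mnist_deep.py.

```python
def batched_accuracy(predictions, labels, batch_size=50):
    """Average accuracy over batches, weighting each batch by its size.

    Hypothetical helper: `predictions` and `labels` are plain lists of class
    ids, standing in for one accuracy.eval(...) call per slice.
    """
    total = len(labels)
    weighted_sum = 0.0
    for start in range(0, total, batch_size):
        end = min(start + batch_size, total)
        size = end - start
        # Per-batch accuracy, as the TensorFlow graph would return it.
        hits = sum(p == t for p, t in zip(predictions[start:end], labels[start:end]))
        batch_acc = hits / size
        # Weight by batch size so a short final batch is not over-counted.
        weighted_sum += batch_acc * size
    return weighted_sum / total
```

With 5 examples and a batch size of 4, the second batch holds a single example; the weighted version counts it once instead of giving it half the weight.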

Improve training speed

--Use He initialization for the weights

def weight_variable(shape):
  """weight_variable generates a He-initialized weight variable of a given shape."""
  # fan_in is the product of every dimension except the last (output) one:
  # height*width*in_channels for a conv kernel, in_features for a dense layer.
  fan_in = np.prod(shape[:-1])
  sd = np.sqrt(2.0 / fan_in)
  initial = tf.truncated_normal(shape, stddev=sd)
  return tf.Variable(initial)
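To see what standard deviations this rule produces, here is a small sketch in plain Python, with no TensorFlow needed. The helper name `he_stddev` is hypothetical, and the layer shapes are the ones I believe mnist_deep.py uses; treat them as an assumption and substitute your own.

```python
import math

def he_stddev(shape):
    """He standard deviation sqrt(2 / fan_in) for a weight of the given shape.

    fan_in is taken as the product of all dimensions except the last one.
    """
    fan_in = 1
    for dim in shape[:-1]:
        fan_in *= dim
    return math.sqrt(2.0 / fan_in)

# Assumed layer shapes (two 5x5 conv layers, then two dense layers).
layer_shapes = [
    [5, 5, 1, 32],
    [5, 5, 32, 64],
    [7 * 7 * 64, 1024],
    [1024, 10],
]

for shape in layer_shapes:
    print(shape, he_stddev(shape))
```

The value differs per layer because each layer has a different fan-in, whereas a single fixed standard deviation treats all layers alike.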

--Increase the learning rate to 1e-3 (the original script uses 1e-4)

  train_step = tf.train.AdamOptimizer(1e-3).minimize(cross_entropy)
