In [1]:
%matplotlib inline
In [2]:
import tensorflow as tf import numpy as np
import matplotlib.pyplot as plt import matplotlib.image as mpimg
In [3]:
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
In [16]:
learning_rate = 0.01 training_epochs =20 batch_size = 256 display_step =1
examples_to_show = 10 n_hidden = 256
n_input = 28*28
In [17]:
def rgb2gray(rgb):
return np.dot(rgb[...,:3], [0.299, 0.587, 0.114])
In [18]:
def read_imagebytes(imagefile):
file = open(imagefile,'rb') bytes = file.read()
return bytes
In [19]:
bytes = read_imagebytes('p16_16.jpg')
Extracting MNIST_data/train-images-idx3-ubyte.gz Extracting MNIST_data/train-labels-idx1-ubyte.gz Extracting MNIST_data/t10k-images-idx3-ubyte.gz Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
In [20]:
X = tf.placeholder(tf.float32, [None, n_input]) x_ = tf.placeholder(tf.float32, [None, n_hidden])
W_encode = tf.Variable(tf.random_normal([n_input, n_hidden])) b_encode = tf.Variable(tf.random_normal([n_hidden]))
encoder = tf.nn.sigmoid(tf.add(tf.matmul(X, W_encode), b_encode)) print(x_.shape)
print(encoder.shape)
W_decode = tf.Variable(tf.random_normal([n_hidden, n_input])) b_decode = tf.Variable(tf.random_normal([n_input]))
decoder = tf.nn.sigmoid(tf.add(tf.matmul(encoder, W_decode), b_decode)) test_decoder = tf.nn.sigmoid(tf.add(tf.matmul(x_, W_decode), b_decode)) cost = tf.reduce_mean(tf.pow(X - decoder, 2))
optimizer = tf.train.RMSPropOptimizer(learning_rate).minimize(cost)
(?, 256) (?, 256)
In [23]:
init = tf.global_variables_initializer() with tf.Session() as sess:
sess.run(init)
total_batch = int(mnist.train.num_examples/batch_size) print(total_batch)
for epoch in range(training_epochs):
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size) _, c = sess.run([optimizer, cost], feed_dict={X:batch_xs})
if epoch % display_step == 0:
print("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(c))
print("Optimization Finished!")
encode = sess.run(encoder,
feed_dict={X:mnist.test.images[:examples_to_show]})
encode_decode = sess.run(decoder,
feed_dict={X:mnist.test.images[:examples_to_show]})
# test_result = sess.run(test_decoder, feed_dict={x_:image_0}) f, a = plt.subplots(3, 10, figsize=(10,3))
for i in range(examples_to_show):
a[0][i].imshow(np.reshape(mnist.test.images[i], (28,28))) a[1][i].imshow(np.reshape(encode[i], (16,16)))
a[2][i].imshow(np.reshape(encode_decode[i], (28,28)))
# f.show()
# a[0][10].imshow(np.reshape(image_0, (16, 16)))
# a[1][10].imshow(np.reshape(test_result, (28, 28))) plt.draw()
plt.show()
In [22]:
214
Epoch: 0001 cost= 0.129844934 Epoch: 0002 cost= 0.069910102 Epoch: 0003 cost= 0.056689274 Epoch: 0004 cost= 0.051029567 Epoch: 0005 cost= 0.046199732 Epoch: 0006 cost= 0.043180466 Epoch: 0007 cost= 0.036908202 Epoch: 0008 cost= 0.037092172 Epoch: 0009 cost= 0.035878919 Epoch: 0010 cost= 0.033974003 Epoch: 0011 cost= 0.034789402 Epoch: 0012 cost= 0.032585703 Epoch: 0013 cost= 0.031352710 Epoch: 0014 cost= 0.032603778 Epoch: 0015 cost= 0.030622769 Epoch: 0016 cost= 0.030134838 Epoch: 0017 cost= 0.030411912 Epoch: 0018 cost= 0.030365063 Epoch: 0019 cost= 0.029278379 Epoch: 0020 cost= 0.028706312 Optimization Finished!
In [ ]:
--- ---
RuntimeError Traceback (most recent c all last)
<ipython-input-22-4e48b80f303c> in <module>() 1 encode = sess.run(encoder,
----> 2 feed_dict={X:mnist.test.images[:examples_t o_show]})
3
4 encode_decode = sess.run(decoder,
5 feed_dict={X:mnist.test.images[:examples_t o_show]})
~/anaconda/lib/python3.5/site-packages/tensorflow/python/client/se ssion.py in run(self, fetches, feed_dict, options, run_metadata) 893 try:
894 result = self._run(None, fetches, feed_dict, options _ptr,
--> 895 run_metadata_ptr) 896 if run_metadata:
897 proto_data = tf_session.TF_GetBuffer(run_metadata_
ptr)
~/anaconda/lib/python3.5/site-packages/tensorflow/python/client/se ssion.py in _run(self, handle, fetches, feed_dict, options, run_me tadata)
1049 # Check session.
1050 if self._closed:
-> 1051 raise RuntimeError('Attempted to use a closed Sessio n.')
1052 if self.graph.version == 0:
1053 raise RuntimeError('The Session graph is empty. Add operations to the '
RuntimeError: Attempted to use a closed Session.