import numpy as np
import os
import time

import tensorflow as tf
import tensorflow.keras as tfk
import tensorflow_datasets as tfds
import tensorflow_probability as tfp

from imageio import imwrite

tfkl = tfk.layers
tfc = tf.compat.v1

flags = tfc.app.flags
flags.DEFINE_string('data_dir', '/tmp/dat/', 'Directory for data')
flags.DEFINE_string('logdir', '/tmp/log/', 'Directory for logs')
flags.DEFINE_integer('latent_dim', 100, 'Latent dimensionality of model')
flags.DEFINE_integer('batch_size', 64, 'Minibatch size')
flags.DEFINE_integer('n_samples', 1, 'Number of samples to save')
flags.DEFINE_integer('print_every', 1000, 'Print every n iterations')
flags.DEFINE_integer('hidden_size', 512, 'Hidden size for neural networks')
flags.DEFINE_integer('n_iterations', 100000, 'number of iterations')

FLAGS = flags.FLAGS


def inference_network(x, latent_dim, hidden_size):
  """Construct an inference network parametrizing a Gaussian.

  Args:
    x: A batch of MNIST digits.
    latent_dim: The latent dimensionality.
    hidden_size: The size of the neural net hidden layers.

  Returns:
    mu: Mean parameters for the variational family Normal
    sigma: Standard deviation parameters for the variational family Normal
  """
  inference_net = tfk.Sequential([
      tfkl.Flatten(),
      tfkl.Dense(hidden_size, activation=tf.nn.relu),
      tfkl.Dense(hidden_size, activation=tf.nn.relu),
      tfkl.Dense(latent_dim * 2, activation=None)
  ])
  gaussian_params = inference_net(x)
  # The mean parameter is unconstrained
  mu = gaussian_params[:, :latent_dim]
  # The standard deviation must be positive. Parametrize with a softplus
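  # (softplus(a) = log(1 + exp(a)), which is smooth and strictly positive)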
  sigma = tf.nn.softplus(gaussian_params[:, latent_dim:])
  return mu, sigma


def generative_network(z, hidden_size):
  """Build a generative network parametrizing the likelihood of the data

  Args:
    z: Samples of latent variables
    hidden_size: Size of the hidden state of the neural net

  Returns:
    bernoulli_logits: logits for the Bernoulli likelihood of the data
  """
  # Build the decoder once and reuse it on later calls: the three call sites
  # in train() must share weights, and a tfc.variable_scope with reuse=True
  # does not reuse variables created by Keras layers.
  if not hasattr(generative_network, 'net'):
    generative_network.net = tfk.Sequential([
        tfkl.Dense(hidden_size, activation=tf.nn.relu),
        tfkl.Dense(hidden_size, activation=tf.nn.relu),
        tfkl.Dense(28 * 28, activation=None)
    ])
  bernoulli_logits = generative_network.net(z)
  return tf.reshape(bernoulli_logits, [-1, 28, 28, 1])


def train():
  # Train a Variational Autoencoder on MNIST

  # Input placeholders
  with tf.name_scope('data'):
    x = tfc.placeholder(tf.float32, [None, 28, 28, 1])
    tfc.summary.image('data', x)

  with tfc.variable_scope('variational'):
    q_mu, q_sigma = inference_network(x=x,
                                      latent_dim=FLAGS.latent_dim,
                                      hidden_size=FLAGS.hidden_size)
    # The variational distribution is a Normal with mean and standard
    # deviation given by the inference network
    q_z = tfp.distributions.Normal(loc=q_mu, scale=q_sigma)
    assert q_z.reparameterization_type == tfp.distributions.FULLY_REPARAMETERIZED
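    # FULLY_REPARAMETERIZED means samples are drawn as z = mu + sigma * eps,
    # eps ~ N(0, 1), so ELBO gradients can flow through the sampling step.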

  with tfc.variable_scope('model'):
    # The likelihood is Bernoulli-distributed with logits given by the
    # generative network
    p_x_given_z_logits = generative_network(z=q_z.sample(),
                                            hidden_size=FLAGS.hidden_size)
    p_x_given_z = tfp.distributions.Bernoulli(logits=p_x_given_z_logits)
    posterior_predictive_samples = p_x_given_z.sample()
    tfc.summary.image('posterior_predictive',
                      tf.cast(posterior_predictive_samples, tf.float32))

  # Take samples from the prior
  with tfc.variable_scope('model', reuse=True):
    p_z = tfp.distributions.Normal(loc=np.zeros(FLAGS.latent_dim, dtype=np.float32),
                                   scale=np.ones(FLAGS.latent_dim, dtype=np.float32))
    p_z_sample = p_z.sample(FLAGS.n_samples)
    p_x_given_z_logits = generative_network(z=p_z_sample,
                                            hidden_size=FLAGS.hidden_size)
    prior_predictive = tfp.distributions.Bernoulli(logits=p_x_given_z_logits)
    prior_predictive_samples = prior_predictive.sample()
    tfc.summary.image('prior_predictive',
                      tf.cast(prior_predictive_samples, tf.float32))

  # Take samples from the prior with a placeholder
  with tfc.variable_scope('model', reuse=True):
    z_input = tfc.placeholder(tf.float32, [None, FLAGS.latent_dim])
    p_x_given_z_logits = generative_network(z=z_input,
                                            hidden_size=FLAGS.hidden_size)
    prior_predictive_inp = tfp.distributions.Bernoulli(logits=p_x_given_z_logits)
    prior_predictive_inp_sample = prior_predictive_inp.sample()
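  # Hypothetical usage after training, e.g. to decode arbitrary latent codes:
  #   np_z = np.random.randn(1, FLAGS.latent_dim).astype(np.float32)
  #   np_img = sess.run(prior_predictive_inp_sample, {z_input: np_z})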

  # Build the evidence lower bound (ELBO) or the negative loss
  kl = tf.reduce_sum(tfp.distributions.kl_divergence(q_z, p_z), 1)
  expected_log_likelihood = tf.reduce_sum(p_x_given_z.log_prob(x),
                                          [1, 2, 3])

  elbo = tf.reduce_sum(expected_log_likelihood - kl, 0)
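  # ELBO(x) = E_{q(z|x)}[log p(x|z)] - KL(q(z|x) || p(z)), summed over the
  # minibatch; minimizing -elbo below maximizes this lower bound on log p(x).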
  optimizer = tfc.train.RMSPropOptimizer(learning_rate=0.001)
  train_op = optimizer.minimize(-elbo)

  # Merge all the summaries
  summary_op = tfc.summary.merge_all()

  init_op = tfc.global_variables_initializer()

  # Run training
  sess = tfc.InteractiveSession()
  sess.run(init_op)
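  # 'binarized_mnist' is a fixed binarization of MNIST: uint8 images with
  # values in {0, 1}, so no per-batch re-binarization is needed here.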
  mnist_data = tfds.load(name='binarized_mnist', split='train', shuffle_files=False)
  dataset = mnist_data.repeat().shuffle(buffer_size=1024).batch(FLAGS.batch_size)

  print('Saving TensorBoard summaries and images to: %s' % FLAGS.logdir)
  train_writer = tfc.summary.FileWriter(FLAGS.logdir, sess.graph)
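  # (inspect the summaries with: tensorboard --logdir=/tmp/log/)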

  t0 = time.time()
  # dataset.repeat() is infinite, so bound training by n_iterations.
  for i, batch in enumerate(tfds.as_numpy(dataset)):
    if i >= FLAGS.n_iterations:
      break
    np_x = batch['image']
    sess.run(train_op, {x: np_x})
    if i % FLAGS.print_every == 0:
      np_elbo, summary_str = sess.run([elbo, summary_op], {x: np_x})
      train_writer.add_summary(summary_str, i)
      print('Iteration: {0:d} ELBO: {1:.3f} s/iter: {2:.3e}'.format(
          i,
          np_elbo / FLAGS.batch_size,
          (time.time() - t0) / FLAGS.print_every))
      # Save samples; scale the {0, 1} pixel values to {0, 255} so the
      # written JPEGs are visible.
      np_posterior_samples, np_prior_samples = sess.run(
          [posterior_predictive_samples, prior_predictive_samples], {x: np_x})
      for k in range(FLAGS.n_samples):
        f_name = os.path.join(
            FLAGS.logdir, 'iter_%d_posterior_predictive_%d_data.jpg' % (i, k))
        imwrite(f_name, (np_x[k, :, :, 0] * 255).astype(np.uint8))
        f_name = os.path.join(
            FLAGS.logdir, 'iter_%d_posterior_predictive_%d_sample.jpg' % (i, k))
        imwrite(f_name, (np_posterior_samples[k, :, :, 0] * 255).astype(np.uint8))
        f_name = os.path.join(
            FLAGS.logdir, 'iter_%d_prior_predictive_%d.jpg' % (i, k))
        imwrite(f_name, (np_prior_samples[k, :, :, 0] * 255).astype(np.uint8))
      t0 = time.time()


if __name__ == '__main__':
  train()