@@ -112,24 +112,14 @@ def fit_to_data(
 
     for i in loop:
         # Shuffle data
-        start = time.time()
         key, *subkeys = jr.split(key, 3)
         train_data = [jr.permutation(subkeys[0], a) for a in train_data]
         val_data = [jr.permutation(subkeys[1], a) for a in val_data]
-        if verbose and i == 0:
-            print("shuffle timing:", time.time() - start)
-
-        start = time.time()
 
         key, subkey = jr.split(key)
         batches = get_batches(train_data, batch_size)
         batch_losses = []
 
-        if verbose and i == 0:
-            print("batch timing:", time.time() - start)
-
-        start = time.time()
-
         if True:
             for batch in zip(*batches, strict=True):
                 key, subkey = jr.split(key)
@@ -156,10 +146,6 @@ def fit_to_data(
 
         losses["train"].append((sum(batch_losses) / len(batch_losses)).item())
 
-        if verbose and i == 0:
-            print("step timing:", time.time() - start)
-
-        start = time.time()
         # Val epoch
         batch_losses = []
         for batch in zip(*get_batches(val_data, batch_size), strict=True):
@@ -168,9 +154,6 @@ def fit_to_data(
             batch_losses.append(loss_i)
         losses["val"].append(sum(batch_losses) / len(batch_losses))
 
-        if verbose and i == 0:
-            print("val timing:", time.time() - start)
-
         loop.set_postfix({k: v[-1] for k, v in losses.items()})
         if losses["val"][-1] == min(losses["val"]):
             best_params = params
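The three hunks above delete the ad-hoc first-epoch timing instrumentation from fit_to_data: each `start = time.time()` and its matching `print("... timing:", time.time() - start)` around the shuffle, batching, train-step, and validation phases. If phase timing is wanted again later, a small context manager keeps it out of the loop body. A minimal sketch, not part of this patch; only `verbose` and the epoch index `i` are taken from the surrounding code:

import time
from contextlib import contextmanager

@contextmanager
def timed(label, enabled=True):
    # Report the wall-clock duration of the wrapped block when enabled.
    start = time.time()
    try:
        yield
    finally:
        if enabled:
            print(f"{label} timing:", time.time() - start)

# Hypothetical usage inside the epoch loop:
# with timed("shuffle", enabled=verbose and i == 0):
#     train_data = [jr.permutation(subkeys[0], a) for a in train_data]
#     val_data = [jr.permutation(subkeys[1], a) for a in val_data]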
@@ -228,7 +211,7 @@ def inverse_gradient_and_val(bijection, draw, grad, logp):
         )
     elif isinstance(bijection, bijections.Affine):
         draw, logdet = bijection.inverse_and_log_det(draw)
-        grad = grad * bijection.scale
+        grad = grad * unwrap(bijection.scale)
         return (draw, grad, logp - logdet)
     elif isinstance(bijection, bijections.Vmap):
 
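The Affine branch of inverse_gradient_and_val now multiplies the gradient by `unwrap(bijection.scale)` rather than by the raw attribute. In flowjax, bijection parameters such as the affine scale can be stored behind paramax wrappers (for instance a softplus reparameterization that keeps the scale positive), so `bijection.scale` is not necessarily a plain array; `unwrap` materializes the concrete value before it enters the arithmetic. The hunk also assumes `unwrap` is imported in this module, e.g. `from paramax import unwrap`. A minimal sketch of the distinction, assuming paramax's Parameterize wrapper; the exact wrapper flowjax uses for Affine.scale may differ:

import jax.numpy as jnp
from jax.nn import softplus
from paramax import Parameterize, unwrap

# Store the scale unconstrained; softplus is applied on unwrap,
# so the materialized value is always positive.
scale = Parameterize(softplus, jnp.array(0.5))
value = unwrap(scale)  # a concrete positive array, safe for arithmetic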
@@ -710,12 +693,9 @@ def update(self, seed, positions, gradients, logps):
         )
         params, static = eqx.partition(flow, eqx.is_inexact_array)
 
-        start = time.time()
         new_loss = self._loss_fn(
             params, static, positions[-128:], gradients[-128:], logps[-128:]
         )
-        if self._verbose:
-            print("new loss function time: ", time.time() - start)
 
         if self._verbose:
             print(f"Chain {self._chain}: New loss {new_loss}, old loss {old_loss}")
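The hunk above is the same cleanup applied inside update: the one-off timing print around the loss evaluation is dropped, while the loss itself is still computed on the 128 most recent draws (`positions[-128:]` and friends) and compared against the old loss in the `_verbose` branch that remains.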
@@ -903,8 +883,8 @@ def make_transform_adapter(
     make_optimizer=None,
     coupling_type="masked",
     mvscale_layer=False,
-    n_embed=None,
-    n_deembed=None,
+    num_project=None,
+    num_embed=None,
 ):
     if extension_windows is None:
         extension_windows = []
@@ -918,8 +898,8 @@ def make_transform_adapter(
         dct_layer=dct_layer,
         nn_depth=nn_depth,
         nn_width=nn_width,
-        n_embed=n_embed,
-        n_deembed=n_deembed,
+        n_embed=num_project,
+        n_deembed=num_embed,
         mvscale=mvscale_layer,
         kind=coupling_type,
     ),
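These last two hunks rename the public keyword arguments of make_transform_adapter from `n_embed`/`n_deembed` to `num_project`/`num_embed` while the inner layer keeps its old parameter names, so note the cross-mapping at the call site: `n_embed=num_project` and `n_deembed=num_embed`. A hypothetical call under the new signature; the numeric values are invented for illustration:

adapter = make_transform_adapter(
    coupling_type="masked",
    mvscale_layer=False,
    num_project=32,  # forwarded internally as n_embed
    num_embed=16,    # forwarded internally as n_deembed
)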