@@ -34,6 +34,8 @@ def _generate_sequences(k, r_vals):
3434 Returns:
3535 A NumPy boolean array of shape (comb(k, r), k) containing all sequences.
3636 """
37+ if k > 30 :
38+ raise ValueError ("Too many sequences to enumerate." )
3739 all_sequences = []
3840 for r in r_vals :
3941 N = math .comb (k , r ) # number of sequences
@@ -807,12 +809,12 @@ def __init__(
807809 )
808810
809811 def transform_and_log_det (self , x , condition = None ):
810- transformer_params = self .conditioner (x )
812+ transformer_params = self .conditioner (x . astype ( jnp . float32 ) )
811813 transformer = self ._flat_params_to_transformer (transformer_params )
812814 return transformer .transform_and_log_det (x )
813815
814816 def inverse_and_log_det (self , y , condition = None ):
815- transformer_params = self .conditioner (y )
817+ transformer_params = self .conditioner (y . astype ( jnp . float32 ) )
816818 transformer = self ._flat_params_to_transformer (transformer_params )
817819 return transformer .inverse_and_log_det (y )
818820
@@ -987,7 +989,12 @@ def make_flow_scan(
987989 size = MaskedCoupling .conditioner_output_size (dim , transformer )
988990
989991 key , key1 = jax .random .split (key )
990- embed = eqx .nn .Linear (dim , n_embed , key = key1 , dtype = jnp .float32 )
992+ embed = eqx .nn .Sequential (
993+ [
994+ eqx .nn .Linear (dim , n_embed , key = key1 , dtype = jnp .float32 ),
995+ eqx .nn .LayerNorm (shape = (n_embed ,), dtype = jnp .float32 ),
996+ ]
997+ )
991998 key , key1 = jax .random .split (key )
992999 embed_back = eqx .nn .Linear (n_deembed , size , key = key1 , dtype = jnp .float32 )
9931000
0 commit comments