Skip to content

Commit 34ebad4

Browse files
committed
feat: add layer norm in normalizing flow
1 parent 49dfa2a commit 34ebad4

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

python/nutpie/normalizing_flow.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ def _generate_sequences(k, r_vals):
3434
Returns:
3535
A NumPy boolean array of shape (comb(k, r), k) containing all sequences.
3636
"""
37+
if k > 30:
38+
raise ValueError("Too many sequences to enumerate.")
3739
all_sequences = []
3840
for r in r_vals:
3941
N = math.comb(k, r) # number of sequences
@@ -807,12 +809,12 @@ def __init__(
807809
)
808810

809811
def transform_and_log_det(self, x, condition=None):
810-
transformer_params = self.conditioner(x)
812+
transformer_params = self.conditioner(x.astype(jnp.float32))
811813
transformer = self._flat_params_to_transformer(transformer_params)
812814
return transformer.transform_and_log_det(x)
813815

814816
def inverse_and_log_det(self, y, condition=None):
815-
transformer_params = self.conditioner(y)
817+
transformer_params = self.conditioner(y.astype(jnp.float32))
816818
transformer = self._flat_params_to_transformer(transformer_params)
817819
return transformer.inverse_and_log_det(y)
818820

@@ -987,7 +989,12 @@ def make_flow_scan(
987989
size = MaskedCoupling.conditioner_output_size(dim, transformer)
988990

989991
key, key1 = jax.random.split(key)
990-
embed = eqx.nn.Linear(dim, n_embed, key=key1, dtype=jnp.float32)
992+
embed = eqx.nn.Sequential(
993+
[
994+
eqx.nn.Linear(dim, n_embed, key=key1, dtype=jnp.float32),
995+
eqx.nn.LayerNorm(shape=(n_embed,), dtype=jnp.float32),
996+
]
997+
)
991998
key, key1 = jax.random.split(key)
992999
embed_back = eqx.nn.Linear(n_deembed, size, key=key1, dtype=jnp.float32)
9931000

python/nutpie/transform_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -878,7 +878,7 @@ def make_transform_adapter(
878878
verbose=False,
879879
window_size=600,
880880
show_progress=False,
881-
nn_depth=1,
881+
nn_depth=None,
882882
nn_width=None,
883883
num_layers=9,
884884
num_diag_windows=9,

0 commit comments

Comments
 (0)