Skip to content

Commit 7723f43

Browse files
committed
feat: make mvscale layer optional
1 parent e979226 commit 7723f43

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

python/nutpie/normalizing_flow.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,7 @@ def make_flow_scan(
969969
nn_depth=None,
970970
n_embed=None,
971971
n_deembed=None,
972+
mvscale=False,
972973
):
973974
dim = n_dim
974975

@@ -1060,8 +1061,11 @@ def make_layer(key, mask, embed, embed_back):
10601061
coupling,
10611062
)
10621063

1063-
mvscale = make_mvscale(key4, dim)
1064-
return bijections.Chain([coupling, mvscale])
1064+
if mvscale:
1065+
scale = make_mvscale(key4, dim)
1066+
return bijections.Chain([coupling, scale])
1067+
else:
1068+
return bijections.Chain([coupling])
10651069

10661070
keys = jax.random.split(key, n_layers)
10671071

@@ -1214,6 +1218,7 @@ def make_flow(
12141218
n_embed=None,
12151219
n_deembed=None,
12161220
kind="subset",
1221+
mvscale=False,
12171222
):
12181223
positions = np.array(positions)
12191224
gradients = np.array(gradients)
@@ -1283,6 +1288,7 @@ def make_flow(
12831288
nn_depth=nn_depth,
12841289
n_embed=n_embed,
12851290
n_deembed=n_deembed,
1291+
mvscale=mvscale,
12861292
)
12871293
else:
12881294
raise ValueError(f"Unknown flow kind: {kind}")

python/nutpie/transform_adapter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,7 @@ def make_transform_adapter(
900900
debug_save_bijection=False,
901901
make_optimizer=None,
902902
coupling_type="masked",
903+
mvscale_layer=False,
903904
n_embed=None,
904905
n_deembed=None,
905906
):
@@ -917,6 +918,7 @@ def make_transform_adapter(
917918
nn_width=nn_width,
918919
n_embed=n_embed,
919920
n_deembed=n_deembed,
921+
mvscale=mvscale_layer,
920922
kind=coupling_type,
921923
),
922924
show_progress=show_progress,

0 commit comments

Comments
 (0)