|
73 | 73 | "import jax.lax as lax\n", |
74 | 74 | "import jax.nn as jnn\n", |
75 | 75 | "import jax.numpy as jnp\n", |
76 | | - "import jax.random as jrandom\n", |
| 76 | + "import jax.random as jr\n", |
77 | 77 | "import matplotlib.pyplot as plt\n", |
78 | 78 | "import optax # https://github.com/deepmind/optax\n", |
79 | 79 | "import scipy.stats as stats\n", |
|
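The rename is mechanical: every use of the `jax.random` module now goes through the shorter `jr` alias. A minimal sketch of the key-handling idiom the notebook relies on under the new alias (names here are illustrative):

```python
import jax.random as jr

key = jr.PRNGKey(0)                  # root key for the whole program
model_key, data_key = jr.split(key)  # derive independent subkeys
noise = jr.normal(data_key, (3,))    # consume each subkey exactly once
```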
111 | 111 | "\n", |
112 | 112 | " def __init__(self, *, data_size, width_size, depth, key, **kwargs):\n", |
113 | 113 | " super().__init__(**kwargs)\n", |
114 | | - " keys = jrandom.split(key, depth + 1)\n", |
| 114 | + " keys = jr.split(key, depth + 1)\n", |
115 | 115 | " layers = []\n", |
116 | 116 | " if depth == 0:\n", |
117 | 117 | " layers.append(\n", |
|
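Splitting `key` into `depth + 1` subkeys gives each layer its own independent initialisation. A compact sketch of that pattern, assuming `eqx.nn.Linear` layers as in the notebook (`make_mlp` is illustrative):

```python
import equinox as eqx
import jax.random as jr

def make_mlp(in_size, width, depth, *, key):
    # depth + 1 Linear layers, each initialised from its own subkey.
    keys = jr.split(key, depth + 1)
    layers = [eqx.nn.Linear(in_size, width, key=keys[0])]
    for k in keys[1:]:
        layers.append(eqx.nn.Linear(width, width, key=k))
    return layers
```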
150 | 150 | "\n", |
151 | 151 | " def __init__(self, *, in_size, out_size, key, **kwargs):\n", |
152 | 152 | " super().__init__(**kwargs)\n", |
153 | | - " key1, key2, key3 = jrandom.split(key, 3)\n", |
| 153 | + " key1, key2, key3 = jr.split(key, 3)\n", |
154 | 154 | " self.lin1 = eqx.nn.Linear(in_size, out_size, key=key1)\n", |
155 | 155 | " self.lin2 = eqx.nn.Linear(1, out_size, key=key2)\n", |
156 | 156 | " self.lin3 = eqx.nn.Linear(1, out_size, use_bias=False, key=key3)\n", |
|
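Only the constructor appears in this hunk. In the FFJORD-style "concatsquash" conditioning, these three layers are typically combined so that time gates and shifts the linear map of `y`; a sketch of such a forward pass (this may not be the notebook's exact `__call__`):

```python
import jax.numpy as jnp
import jax.nn as jnn

def concat_squash(lin1, lin2, lin3, t, y):
    # out = lin1(y) * sigmoid(lin3(t)) + lin2(t): time modulates the map of y.
    t = jnp.asarray(t)[None]  # scalar t -> shape (1,) for the 1-input Linears
    return lin1(y) * jnn.sigmoid(lin3(t)) + lin2(t)
```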
251 | 251 | " **kwargs,\n", |
252 | 252 | " ):\n", |
253 | 253 | " super().__init__(**kwargs)\n", |
254 | | - " keys = jrandom.split(key, num_blocks)\n", |
| 254 | + " keys = jr.split(key, num_blocks)\n", |
255 | 255 | " self.funcs = [\n", |
256 | 256 | " Func(\n", |
257 | 257 | " data_size=data_size,\n", |
|
274 | 274 | " else:\n", |
275 | 275 | " term = diffrax.ODETerm(approx_logp_wrapper)\n", |
276 | 276 | " solver = diffrax.Tsit5()\n", |
277 | | - " eps = jrandom.normal(key, y.shape)\n", |
| 277 | + " eps = jr.normal(key, y.shape)\n", |
278 | 278 | " delta_log_likelihood = 0.0\n", |
279 | 279 | " for func in reversed(self.funcs):\n", |
280 | 280 | " y = (y, delta_log_likelihood)\n", |
|
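`approx_logp_wrapper` pairs the vector field with a Hutchinson divergence estimate: for a probe `eps ~ N(0, I)`, `E[eps^T (df/dy) eps] = trace(df/dy)`, which is the divergence term in the CNF log-likelihood. A sketch of such a wrapper; the tuple unpacking mirrors the surrounding code, but the `args` layout is an assumption:

```python
import jax
import jax.numpy as jnp

def approx_logp_wrapper(t, y, args):
    y, _ = y                        # carry is (state, accumulated log-likelihood)
    *args, eps, func = args
    fn = lambda y: func(t, y, args)
    f, vjp_fn = jax.vjp(fn, y)
    (eps_dfdy,) = vjp_fn(eps)       # eps^T (df/dy) via a single VJP
    logp = jnp.sum(eps_dfdy * eps)  # unbiased estimate of trace(df/dy)
    return f, logp
```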
286 | 286 | "\n", |
287 | 287 | " # Runs forward-in-time to draw samples from the CNF.\n", |
288 | 288 | " def sample(self, *, key):\n", |
289 | | - " y = jrandom.normal(key, (self.data_size,))\n", |
| 289 | + " y = jr.normal(key, (self.data_size,))\n", |
290 | 290 | " for func in self.funcs:\n", |
291 | 291 | " term = diffrax.ODETerm(func)\n", |
292 | 292 | " solver = diffrax.Tsit5()\n", |
|
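Sampling starts from a standard normal draw and pushes it forward through each block with an `ODETerm`/`Tsit5` solve. A sketch of one such solve; `push_forward`, `dt0`, and the default interval are illustrative, and the `ts` branch shows the `SaveAt` mechanism that `sample_flow` in the next hunk uses to record intermediate states:

```python
import diffrax

def push_forward(func, y0, t0=0.0, t1=1.0, ts=None):
    # Integrate dy/dt = func(t, y, args) from t0 to t1 with Tsit5.
    term = diffrax.ODETerm(func)
    solver = diffrax.Tsit5()
    saveat = diffrax.SaveAt(t1=True) if ts is None else diffrax.SaveAt(ts=ts)
    sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0=0.1, y0=y0, saveat=saveat)
    return sol.ys  # leading axis: one entry per saved time
```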
300 | 300 | " t_so_far = self.t0\n", |
301 | 301 | " t_end = self.t0 + (self.t1 - self.t0) * len(self.funcs)\n", |
302 | 302 | " save_times = jnp.linspace(self.t0, t_end, 6)\n", |
303 | | - " y = jrandom.normal(key, (self.data_size,))\n", |
| 303 | + " y = jr.normal(key, (self.data_size,))\n", |
304 | 304 | " out = []\n", |
305 | 305 | " for i, func in enumerate(self.funcs):\n", |
306 | 306 | " if i == len(self.funcs) - 1:\n", |
|
404 | 404 | "class DataLoader(eqx.Module):\n", |
405 | 405 | " arrays: tuple[jnp.ndarray, ...]\n", |
406 | 406 | " batch_size: int\n", |
407 | | - " key: jrandom.PRNGKey\n", |
| 407 | + " key: jr.PRNGKey\n", |
408 | 408 | "\n", |
409 | 409 | " def __check_init__(self):\n", |
410 | 410 | " dataset_size = self.arrays[0].shape[0]\n", |
|
414 | 414 | " dataset_size = self.arrays[0].shape[0]\n", |
415 | 415 | " num_batches = dataset_size // self.batch_size\n", |
416 | 416 | " epoch = step // num_batches\n", |
417 | | - " key = jrandom.fold_in(self.key, epoch)\n", |
418 | | - " perm = jrandom.permutation(key, jnp.arange(dataset_size))\n", |
| 417 | + " key = jr.fold_in(self.key, epoch)\n", |
| 418 | + " perm = jr.permutation(key, jnp.arange(dataset_size))\n", |
419 | 419 | " start = (step % num_batches) * self.batch_size\n", |
420 | 420 | " slice_size = self.batch_size\n", |
421 | 421 | " batch_indices = lax.dynamic_slice_in_dim(perm, start, slice_size)\n", |
|
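The loader is fully deterministic: the epoch's permutation is a pure function of `(self.key, epoch)` via `fold_in`, and `dynamic_slice_in_dim` cuts out the step's batch (any trailing partial batch is dropped by the floor division). A standalone sketch with illustrative sizes:

```python
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
dataset_size, batch_size = 10, 4
num_batches = dataset_size // batch_size  # = 2; the last 2 samples are dropped

def batch_indices(step):
    epoch = step // num_batches
    perm = jr.permutation(jr.fold_in(key, epoch), jnp.arange(dataset_size))
    start = (step % num_batches) * batch_size
    return lax.dynamic_slice_in_dim(perm, start, batch_size)

print(batch_indices(0), batch_indices(1))  # two disjoint batches of epoch 0
```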
464 | 464 | " else:\n", |
465 | 465 | " out_path = pathlib.Path(out_path)\n", |
466 | 466 | "\n", |
467 | | - " key = jrandom.PRNGKey(seed)\n", |
468 | | - " model_key, loader_key, loss_key, sample_key = jrandom.split(key, 4)\n", |
| 467 | + " key = jr.PRNGKey(seed)\n", |
| 468 | + " model_key, loader_key, loss_key, sample_key = jr.split(key, 4)\n", |
469 | 469 | "\n", |
470 | 470 | " dataset, weights, mean, std, img, width, height = get_data(in_path)\n", |
471 | 471 | " dataset_size, data_size = dataset.shape\n", |
|
486 | 486 | " @eqx.filter_value_and_grad\n", |
487 | 487 | " def loss(model, data, weight, loss_key):\n", |
488 | 488 | " batch_size, _ = data.shape\n", |
489 | | - " noise_key, train_key = jrandom.split(loss_key, 2)\n", |
490 | | - " train_key = jrandom.split(key, batch_size)\n", |
491 | | - " data = data + jrandom.normal(noise_key, data.shape) * 0.5 / std\n", |
| 489 | + " noise_key, train_key = jr.split(loss_key, 2)\n", |
| 490 | + " train_key = jr.split(train_key, batch_size)\n", |
| 491 | + " data = data + jr.normal(noise_key, data.shape) * 0.5 / std\n", |
492 | 492 | " log_likelihood = jax.vmap(model.train)(data, key=train_key)\n", |
493 | 493 | " return -jnp.mean(weight * log_likelihood) # minimise negative log-likelihood\n", |
494 | 494 | "\n", |
|
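Each sample gets dequantisation noise and its own PRNG key, and `model.train` is vmapped over the stacked keys. A toy sketch of vmapping over a batch of keys (`per_sample_loss` stands in for `model.train`):

```python
import jax
import jax.numpy as jnp
import jax.random as jr

def per_sample_loss(x, key):
    # Stand-in: any per-sample computation that consumes its own randomness.
    return jnp.sum(x) + jr.normal(key, ())

batch = jnp.ones((8, 2))
keys = jr.split(jr.PRNGKey(0), 8)   # stacked keys, leading axis = batch
losses = jax.vmap(per_sample_loss)(batch, keys)  # shape (8,)
```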
514 | 514 | " value = value + value_\n", |
515 | 515 | " grads = jax.tree_util.tree_map(lambda a, b: a + b, grads, grads_)\n", |
516 | 516 | " step = step + 1\n", |
517 | | - " loss_key = jrandom.split(loss_key, 1)[0]\n", |
| 517 | + " loss_key = jr.split(loss_key, 1)[0]\n", |
518 | 518 | " return value, grads, step, loss_key\n", |
519 | 519 | "\n", |
520 | 520 | " value, grads, step, loss_key = lax.fori_loop(\n", |
|
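Gradients are accumulated across the inner iterations of a `lax.fori_loop`, with `loss_key` advanced once per iteration so every micro-batch sees fresh randomness. A self-contained sketch of the pattern (`loss_fn` and `accumulate` are illustrative):

```python
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jr

def accumulate(loss_fn, params, num_micro, key):
    # loss_fn(params, key) -> (value, grads); sum both over num_micro calls.
    def body(_, carry):
        value, grads, key = carry
        value_, grads_ = loss_fn(params, key)
        grads = jax.tree_util.tree_map(lambda a, b: a + b, grads, grads_)
        key = jr.split(key, 1)[0]  # advance the key for the next micro-batch
        return value + value_, grads, key
    zero_grads = jax.tree_util.tree_map(jnp.zeros_like, params)
    return lax.fori_loop(0, num_micro, body, (0.0, zero_grads, key))
```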
537 | 537 | " print(f\"Step: {step}, Loss: {value}, Computation time: {end - start}\")\n", |
538 | 538 | "\n", |
539 | 539 | " num_samples = 5000\n", |
540 | | - " sample_key = jrandom.split(sample_key, num_samples)\n", |
| 540 | + " sample_key = jr.split(sample_key, num_samples)\n", |
541 | 541 | " samples = jax.vmap(model.sample)(key=sample_key)\n", |
542 | 542 | " sample_flows = jax.vmap(model.sample_flow, out_axes=-1)(key=sample_key)\n", |
543 | 543 | " fig, (*axs, ax, axtrue) = plt.subplots(\n", |
|
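Finally, one key per sample is split off and `model.sample` is vmapped over the keyword argument; `out_axes=-1` in the `sample_flow` call stacks the flows along a trailing axis for plotting. A toy sketch (`sample_one` stands in for `model.sample`):

```python
import jax
import jax.random as jr

def sample_one(*, key):
    return jr.normal(key, (2,))  # stand-in for one CNF sample

keys = jr.split(jr.PRNGKey(0), 5000)
samples = jax.vmap(sample_one)(key=keys)  # vmap maps over kwargs too; (5000, 2)
```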