Skip to content

Commit 60235eb

Browse files
jstacclaude
andauthored
Fix bugs in ifp_advanced lecture (#757)
- Fix typo: jnp.block_until_ready(a_start) → a_star.block_until_ready() - Add missing functools.partial import for JAX JIT static arguments - Fix compute_asset_stationary JIT compilation with static_argnums - Remove invalid verbose parameter from solve_model calls - Refactor solve_model to use jax.lax.while_loop for better performance - Clean up code structure and improve function signatures 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: Claude <noreply@anthropic.com>
1 parent 5efda54 commit 60235eb

File tree

1 file changed

+79
-108
lines changed

1 file changed

+79
-108
lines changed

lectures/ifp_advanced.md

Lines changed: 79 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,12 @@ We require the following imports:
6060
```{code-cell} ipython3
6161
import matplotlib.pyplot as plt
6262
import numpy as np
63-
from quantecon import MarkovChain
63+
import quantecon as qe
6464
import jax
6565
import jax.numpy as jnp
6666
from jax import vmap
6767
from typing import NamedTuple
68+
from functools import partial
6869
```
6970

7071

@@ -129,7 +130,7 @@ does not grow too quickly.
129130

130131
When $\{R_t\}$ was constant we required that $\beta R < 1$.
131132

132-
Since it is now stochastic, we require that
133+
Since it is now stochastic, we require (see {cite}`ma2020income`) that
133134

134135
```{math}
135136
:label: fpbc2
@@ -140,15 +141,15 @@ G_R := \lim_{n \to \infty}
140141
\left(\mathbb E \prod_{t=1}^n R_t \right)^{1/n}
141142
```
142143

143-
Notice that, when $\{R_t\}$ takes some constant value $R$, this
144-
reduces to the previous restriction $\beta R < 1$.
145-
146144
The value $G_R$ can be thought of as the long run (geometric) average
147145
gross rate of return.
148146

149-
More intuition behind {eq}`fpbc2` is provided in {cite}`ma2020income`.
147+
To simplify this lecture, we will *assume that the interest rate process is
148+
IID*.
149+
150+
In that case, it is clear from the definition of $G_R$ that $G_R$ is just $\mathbb E R_t$.
150151

151-
Discussion on how to check it is given below.
152+
We test the condition $\beta \mathbb E R_t < 1$ in the code below.
152153

153154
Finally, we impose some routine technical restrictions on non-financial income.
154155

@@ -309,28 +310,6 @@ obtained by interpolating $\{a_i, c_i\}$ at each $z$.
309310

310311
In what follows, we use linear interpolation.
311312

312-
### Testing the Assumptions
313-
314-
Convergence of time iteration is dependent on the condition $\beta G_R < 1$ being satisfied.
315-
316-
One can check this using the fact that $G_R$ is equal to the spectral
317-
radius of the matrix $L$ defined by
318-
319-
$$
320-
L(z, \hat z) := P(z, \hat z) \int R(\hat z, x) \phi(x) dx
321-
$$
322-
323-
This identity is proved in {cite}`ma2020income`, where $\phi$ is the
324-
density of the innovation $\zeta_t$ to returns on assets.
325-
326-
(Remember that $\mathsf Z$ is a finite set, so this expression defines a matrix.)
327-
328-
Checking the condition is even easier when $\{R_t\}$ is IID.
329-
330-
In that case, it is clear from the definition of $G_R$ that $G_R$
331-
is just $\mathbb E R_t$.
332-
333-
We test the condition $\beta \mathbb E R_t < 1$ in the code below.
334313

335314
## Implementation
336315

@@ -354,32 +333,31 @@ class IFP(NamedTuple):
354333
ζ_draws: jnp.ndarray
355334
356335
357-
def create_ifp(γ=1.5,
358-
β=0.96,
359-
P=np.array([(0.9, 0.1),
360-
(0.1, 0.9)]),
361-
a_r=0.16,
362-
b_r=0.0,
363-
a_y=0.2,
364-
b_y=0.5,
365-
shock_draw_size=100,
366-
grid_max=100,
367-
grid_size=100,
368-
seed=1234):
336+
def create_ifp(
337+
γ=1.5, # Utility parameter
338+
β=0.96, # Discount factor
339+
P=jnp.array([(0.9, 0.1), # Default Markov chain for Z
340+
(0.1, 0.9)]),
341+
a_r=0.16, # Volatility term in R shock
342+
b_r=0.0, # Mean shift R shock
343+
a_y=0.2, # Volatility term in Y shock
344+
b_y=0.5, # Mean shift Y shock
345+
shock_draw_size=100, # For Monte Carlo
346+
grid_max=100, # Exogenous grid max
347+
grid_size=100, # Exogenous grid size
348+
seed=1234 # Random seed
349+
):
369350
"""
370351
Create an instance of IFP with the given parameters.
352+
371353
"""
372-
# Test stability assuming {R_t} is IID and adopts the lognormal
373-
# specification given below. The test is then β E R_t < 1.
354+
# Test stability assuming {R_t} is IID and ln R ~ N(b_r, a_r)
374355
ER = np.exp(b_r + a_r**2 / 2)
375356
assert β * ER < 1, "Stability condition failed."
376357
377-
# Convert to JAX arrays
378-
P = jnp.array(P)
379-
380358
# Generate random draws using JAX
381359
key = jax.random.PRNGKey(seed)
382-
key, subkey1, subkey2 = jax.random.split(key, 3)
360+
subkey1, subkey2 = jax.random.split(key)
383361
η_draws = jax.random.normal(subkey1, (shock_draw_size,))
384362
ζ_draws = jax.random.normal(subkey2, (shock_draw_size,))
385363
s_grid = jnp.linspace(0, grid_max, grid_size)
@@ -409,96 +387,83 @@ def Y(z, η, a_y, b_y):
409387
Here's the Coleman-Reffett operator using JAX:
410388

411389
```{code-cell} ipython3
412-
@jax.jit
413-
def K(ae_vals, c_vals, ifp):
390+
def K(
391+
a_in: jnp.array, # a_in[i, z] is an asset grid
392+
c_in: jnp.array, # c_in[i, z] = consumption at a_in[i, z]
393+
ifp: IFP
394+
):
414395
"""
415396
The Coleman--Reffett operator for the income fluctuation problem,
416397
using the endogenous grid method with JAX.
417398
418-
* ifp is an instance of IFP
419-
* ae_vals[i, z] is an asset grid
420-
* c_vals[i, z] is consumption at ae_vals[i, z]
421399
"""
422400
423401
# Extract parameters from ifp
424-
γ, β, P = ifp.γ, ifp.β, ifp.P
425-
a_r, b_r, a_y, b_y = ifp.a_r, ifp.b_r, ifp.a_y, ifp.b_y
426-
s_grid, η_draws, ζ_draws = ifp.s_grid, ifp.η_draws, ifp.ζ_draws
402+
γ, β, P, a_r, b_r, a_y, b_y, s_grid, η_draws, ζ_draws = ifp
427403
n = len(P)
428404
429-
# Allocate memory
430-
c_out = jnp.empty_like(c_vals)
431-
432-
# Obtain c_i at each s_i, z, store in c_out[i, z], computing
433-
# the expectation term by Monte Carlo
434405
def compute_expectation(s, z):
435-
"""Compute expectation for given s and z"""
436406
def inner_expectation(z_hat):
437-
# Vectorize over shocks
438407
def compute_term(η, ζ):
439408
R_hat = R(z_hat, ζ, a_r, b_r)
440409
Y_hat = Y(z_hat, η, a_y, b_y)
441410
a_val = R_hat * s + Y_hat
442411
# Interpolate consumption
443-
c_interp = jnp.interp(a_val, ae_vals[:, z_hat], c_vals[:, z_hat])
444-
U = u_prime(c_interp, γ)
445-
return R_hat * U
446-
412+
c_interp = jnp.interp(a_val, a_in[:, z_hat], c_in[:, z_hat])
413+
mu = u_prime(c_interp, γ)
414+
return R_hat * mu
447415
# Vectorize over all shock combinations
448416
η_grid, ζ_grid = jnp.meshgrid(η_draws, ζ_draws, indexing='ij')
449417
terms = vmap(vmap(compute_term))(η_grid, ζ_grid)
450418
return P[z, z_hat] * jnp.mean(terms)
451-
452419
# Sum over z_hat states
453420
Ez = jnp.sum(vmap(inner_expectation)(jnp.arange(n)))
454421
return u_prime_inv(β * Ez, γ)
455422
456423
# Vectorize over s_grid and z
457-
c_out = vmap(vmap(compute_expectation, in_axes=(None, 0)),
458-
in_axes=(0, None))(s_grid, jnp.arange(n))
459-
424+
compute_exp_v1 = vmap(compute_expectation, in_axes=(None, 0))
425+
compute_exp_v2 = vmap(compute_exp_v1, in_axes=(0, None))
426+
c_out = compute_exp_v2(s_grid, jnp.arange(n))
460427
# Calculate endogenous asset grid
461-
ae_out = s_grid[:, None] + c_out
462-
463-
# Fixing a consumption-asset pair at (0, 0) improves interpolation
428+
a_out = s_grid[:, None] + c_out
429+
# Fix consumption-asset pair at (0, 0)
464430
c_out = c_out.at[0, :].set(0)
465-
ae_out = ae_out.at[0, :].set(0)
431+
a_out = a_out.at[0, :].set(0)
466432
467-
return ae_out, c_out
433+
return a_out, c_out
468434
```
469435

470436
The next function solves for an approximation of the optimal consumption policy
471437
via time iteration using JAX:
472438

473439
```{code-cell} ipython3
474-
def solve_model_time_iter(
475-
model, # Class with model information
476-
a_vec, # Initial condition for assets
477-
σ_vec, # Initial condition for consumption
478-
tol=1e-4,
479-
max_iter=1000,
480-
verbose=True,
481-
print_skip=25
482-
):
483-
484-
# Set up loop
485-
i = 0
486-
error = tol + 1
487-
488-
while i < max_iter and error > tol:
489-
a_new, σ_new = K(a_vec, σ_vec, model)
490-
error = jnp.max(jnp.abs(σ_vec - σ_new))
440+
@jax.jit
441+
def solve_model(
442+
ifp: IFP,
443+
c_init: jnp.ndarray, # Initial guess of σ on grid endogenous grid
444+
a_init: jnp.ndarray, # Initial endogenous grid
445+
tol: float = 1e-5,
446+
max_iter: int = 1000
447+
) -> jnp.ndarray:
448+
" Solve the model using time iteration with EGM. "
449+
450+
def condition(loop_state):
451+
c_in, a_in, i, error = loop_state
452+
return (error > tol) & (i < max_iter)
453+
454+
def body(loop_state):
455+
c_in, a_in, i, error = loop_state
456+
c_out, a_out = K(c_in, a_in, ifp)
457+
error = jnp.max(jnp.abs(c_out - c_in))
491458
i += 1
492-
if verbose and i % print_skip == 0:
493-
print(f"Error at iteration {i} is {error}.")
494-
a_vec, σ_vec = a_new, σ_new
459+
return c_out, a_out, i, error
495460
496-
if error > tol:
497-
print("Failed to converge!")
498-
elif verbose:
499-
print(f"\nConverged in {i} iterations.")
461+
i, error = 0, tol + 1
462+
initial_state = (c_init, a_init, i, error)
463+
final_loop_state = jax.lax.while_loop(condition, body, initial_state)
464+
c_out, a_out, i, error = final_loop_state
500465
501-
return a_new, σ_new
466+
return c_out, a_out
502467
```
503468

504469
Now we can create an instance and solve the model using JAX:
@@ -522,10 +487,16 @@ a_init = σ_init.copy()
522487
Let's generate an approximation solution with JAX:
523488

524489
```{code-cell} ipython3
525-
a_star, σ_star = solve_model_time_iter(ifp, a_init, σ_init, print_skip=5)
490+
a_star, σ_star = solve_model(ifp, a_init, σ_init)
526491
```
527492

493+
Let's try it again with a timer.
528494

495+
```{code-cell} python3
496+
with qe.Timer(precision=8):
497+
a_star, σ_star = solve_model(ifp, a_init, σ_init)
498+
a_star.block_until_ready()
499+
```
529500

530501
## Simulation
531502

@@ -543,7 +514,6 @@ The function takes a solution pair `c_vec` and `a_vec`, understanding them
543514
as representing an optimal policy associated with a given model `ifp`
544515

545516
```{code-cell} ipython3
546-
@jax.jit
547517
def simulate_household(
548518
key, a_0, z_idx_0, c_vec, a_vec, ifp, T
549519
):
@@ -593,6 +563,7 @@ def simulate_household(
593563
Now we write a function to simulate many households in parallel.
594564

595565
```{code-cell} ipython3
566+
@partial(jax.jit, static_argnums=(3, 4, 5))
596567
def compute_asset_stationary(
597568
c_vec, a_vec, ifp, num_households=50_000, T=500, seed=1234
598569
):
@@ -623,7 +594,7 @@ def compute_asset_stationary(
623594
)
624595
assets = sim_all_households(keys, a_0_vector, z_idx_0_vector, c_vec, a_vec, ifp, T)
625596
626-
return np.array(assets)
597+
return jnp.array(assets)
627598
```
628599

629600
We'll need some inequality measures for visualization, so let's define them first:
@@ -671,7 +642,7 @@ s_grid = ifp.s_grid
671642
n_z = len(ifp.P)
672643
a_init = s_grid[:, None] * jnp.ones(n_z)
673644
c_init = a_init
674-
a_vec, c_vec = solve_model_time_iter(ifp, a_init, c_init)
645+
a_vec, c_vec = solve_model(ifp, a_init, c_init)
675646
assets = compute_asset_stationary(c_vec, a_vec, ifp, num_households=200_000)
676647
677648
# Compute Gini coefficient for the plot
@@ -763,8 +734,8 @@ for a_r in a_r_vals:
763734
n_z_temp = len(ifp_temp.P)
764735
a_init_temp = s_grid_temp[:, None] * jnp.ones(n_z_temp)
765736
c_init_temp = a_init_temp
766-
a_vec_temp, c_vec_temp = solve_model_time_iter(
767-
ifp_temp, a_init_temp, c_init_temp, verbose=False
737+
a_vec_temp, c_vec_temp = solve_model(
738+
ifp_temp, a_init_temp, c_init_temp
768739
)
769740
770741
# Simulate households
@@ -840,8 +811,8 @@ for a_y in a_y_vals:
840811
n_z_temp = len(ifp_temp.P)
841812
a_init_temp = s_grid_temp[:, None] * jnp.ones(n_z_temp)
842813
c_init_temp = a_init_temp
843-
a_vec_temp, c_vec_temp = solve_model_time_iter(
844-
ifp_temp, a_init_temp, c_init_temp, verbose=False
814+
a_vec_temp, c_vec_temp = solve_model(
815+
ifp_temp, a_init_temp, c_init_temp
845816
)
846817
847818
# Simulate households

0 commit comments

Comments
 (0)