@@ -60,11 +60,12 @@ We require the following imports:
6060``` {code-cell} ipython3
6161import matplotlib.pyplot as plt
6262import numpy as np
63- from quantecon import MarkovChain
63+ import quantecon as qe
6464import jax
6565import jax.numpy as jnp
6666from jax import vmap
6767from typing import NamedTuple
68+ from functools import partial
6869```
6970
7071
@@ -129,7 +130,7 @@ does not grow too quickly.
129130
130131When $\{ R_t\} $ was constant we required that $\beta R < 1$.
131132
132- Since it is now stochastic, we require that
133+ Since it is now stochastic, we require (see {cite} ` ma2020income ` ) that
133134
134135``` {math}
135136:label: fpbc2
@@ -140,15 +141,15 @@ G_R := \lim_{n \to \infty}
140141\left(\mathbb E \prod_{t=1}^n R_t \right)^{1/n}
141142```
142143
143- Notice that, when $\{ R_t\} $ takes some constant value $R$, this
144- reduces to the previous restriction $\beta R < 1$.
145-
146144The value $G_R$ can be thought of as the long run (geometric) average
147145gross rate of return.
148146
149- More intuition behind {eq}` fpbc2 ` is provided in {cite}` ma2020income ` .
147+ To simplify this lecture, we will * assume that the interest rate process is
148+ IID* .
149+
150+ In that case, it is clear from the definition of $G_R$ that $G_R$ is just $\mathbb E R_t$.
150151
151- Discussion on how to check it is given below.
152+ We test the condition $\beta \mathbb E R_t < 1$ in the code below.
152153
153154Finally, we impose some routine technical restrictions on non-financial income.
154155
@@ -309,28 +310,6 @@ obtained by interpolating $\{a_i, c_i\}$ at each $z$.
309310
310311In what follows, we use linear interpolation.
311312
312- ### Testing the Assumptions
313-
314- Convergence of time iteration is dependent on the condition $\beta G_R < 1$ being satisfied.
315-
316- One can check this using the fact that $G_R$ is equal to the spectral
317- radius of the matrix $L$ defined by
318-
319- $$
320- L(z, \hat z) := P(z, \hat z) \int R(\hat z, x) \phi(x) dx
321- $$
322-
323- This identity is proved in {cite}` ma2020income ` , where $\phi$ is the
324- density of the innovation $\zeta_t$ to returns on assets.
325-
326- (Remember that $\mathsf Z$ is a finite set, so this expression defines a matrix.)
327-
328- Checking the condition is even easier when $\{ R_t\} $ is IID.
329-
330- In that case, it is clear from the definition of $G_R$ that $G_R$
331- is just $\mathbb E R_t$.
332-
333- We test the condition $\beta \mathbb E R_t < 1$ in the code below.
334313
335314## Implementation
336315
@@ -354,32 +333,31 @@ class IFP(NamedTuple):
354333 ζ_draws: jnp.ndarray
355334
356335
357- def create_ifp(γ=1.5,
358- β=0.96,
359- P=np.array([(0.9, 0.1),
360- (0.1, 0.9)]),
361- a_r=0.16,
362- b_r=0.0,
363- a_y=0.2,
364- b_y=0.5,
365- shock_draw_size=100,
366- grid_max=100,
367- grid_size=100,
368- seed=1234):
336+ def create_ifp(
337+ γ=1.5, # Utility parameter
338+ β=0.96, # Discount factor
339+ P=jnp.array([(0.9, 0.1), # Default Markov chain for Z
340+ (0.1, 0.9)]),
341+ a_r=0.16, # Volatility term in R shock
342+ b_r=0.0, # Mean shift R shock
343+ a_y=0.2, # Volatility term in Y shock
344+ b_y=0.5, # Mean shift Y shock
345+ shock_draw_size=100, # For Monte Carlo
346+ grid_max=100, # Exogenous grid max
347+ grid_size=100, # Exogenous grid size
348+ seed=1234 # Random seed
349+ ):
369350 """
370351 Create an instance of IFP with the given parameters.
352+
371353 """
372- # Test stability assuming {R_t} is IID and adopts the lognormal
373- # specification given below. The test is then β E R_t < 1.
354+ # Test stability assuming {R_t} is IID and ln R ~ N(b_r, a_r)
374355 ER = np.exp(b_r + a_r**2 / 2)
375356 assert β * ER < 1, "Stability condition failed."
376357
377- # Convert to JAX arrays
378- P = jnp.array(P)
379-
380358 # Generate random draws using JAX
381359 key = jax.random.PRNGKey(seed)
382- key, subkey1, subkey2 = jax.random.split(key, 3 )
360+ subkey1, subkey2 = jax.random.split(key)
383361 η_draws = jax.random.normal(subkey1, (shock_draw_size,))
384362 ζ_draws = jax.random.normal(subkey2, (shock_draw_size,))
385363 s_grid = jnp.linspace(0, grid_max, grid_size)
@@ -409,96 +387,83 @@ def Y(z, η, a_y, b_y):
409387Here's the Coleman-Reffett operator using JAX:
410388
411389``` {code-cell} ipython3
412- @jax.jit
413- def K(ae_vals, c_vals, ifp):
390+ def K(
391+ a_in: jnp.array, # a_in[i, z] is an asset grid
392+ c_in: jnp.array, # c_in[i, z] = consumption at a_in[i, z]
393+ ifp: IFP
394+ ):
414395 """
415396 The Coleman--Reffett operator for the income fluctuation problem,
416397 using the endogenous grid method with JAX.
417398
418- * ifp is an instance of IFP
419- * ae_vals[i, z] is an asset grid
420- * c_vals[i, z] is consumption at ae_vals[i, z]
421399 """
422400
423401 # Extract parameters from ifp
424- γ, β, P = ifp.γ, ifp.β, ifp.P
425- a_r, b_r, a_y, b_y = ifp.a_r, ifp.b_r, ifp.a_y, ifp.b_y
426- s_grid, η_draws, ζ_draws = ifp.s_grid, ifp.η_draws, ifp.ζ_draws
402+ γ, β, P, a_r, b_r, a_y, b_y, s_grid, η_draws, ζ_draws = ifp
427403 n = len(P)
428404
429- # Allocate memory
430- c_out = jnp.empty_like(c_vals)
431-
432- # Obtain c_i at each s_i, z, store in c_out[i, z], computing
433- # the expectation term by Monte Carlo
434405 def compute_expectation(s, z):
435- """Compute expectation for given s and z"""
436406 def inner_expectation(z_hat):
437- # Vectorize over shocks
438407 def compute_term(η, ζ):
439408 R_hat = R(z_hat, ζ, a_r, b_r)
440409 Y_hat = Y(z_hat, η, a_y, b_y)
441410 a_val = R_hat * s + Y_hat
442411 # Interpolate consumption
443- c_interp = jnp.interp(a_val, ae_vals[:, z_hat], c_vals[:, z_hat])
444- U = u_prime(c_interp, γ)
445- return R_hat * U
446-
412+ c_interp = jnp.interp(a_val, a_in[:, z_hat], c_in[:, z_hat])
413+ mu = u_prime(c_interp, γ)
414+ return R_hat * mu
447415 # Vectorize over all shock combinations
448416 η_grid, ζ_grid = jnp.meshgrid(η_draws, ζ_draws, indexing='ij')
449417 terms = vmap(vmap(compute_term))(η_grid, ζ_grid)
450418 return P[z, z_hat] * jnp.mean(terms)
451-
452419 # Sum over z_hat states
453420 Ez = jnp.sum(vmap(inner_expectation)(jnp.arange(n)))
454421 return u_prime_inv(β * Ez, γ)
455422
456423 # Vectorize over s_grid and z
457- c_out = vmap(vmap( compute_expectation, in_axes=(None, 0)),
458- in_axes=(0, None))(s_grid, jnp.arange(n ))
459-
424+ compute_exp_v1 = vmap(compute_expectation, in_axes=(None, 0))
425+ compute_exp_v2 = vmap(compute_exp_v1, in_axes=(0, None))
426+ c_out = compute_exp_v2(s_grid, jnp.arange(n))
460427 # Calculate endogenous asset grid
461- ae_out = s_grid[:, None] + c_out
462-
463- # Fixing a consumption-asset pair at (0, 0) improves interpolation
428+ a_out = s_grid[:, None] + c_out
429+ # Fix consumption-asset pair at (0, 0)
464430 c_out = c_out.at[0, :].set(0)
465- ae_out = ae_out .at[0, :].set(0)
431+ a_out = a_out .at[0, :].set(0)
466432
467- return ae_out , c_out
433+ return a_out , c_out
468434```
469435
470436The next function solves for an approximation of the optimal consumption policy
471437via time iteration using JAX:
472438
473439``` {code-cell} ipython3
474- def solve_model_time_iter(
475- model, # Class with model information
476- a_vec, # Initial condition for assets
477- σ_vec, # Initial condition for consumption
478- tol=1e-4,
479- max_iter=1000,
480- verbose=True,
481- print_skip=25
482- ):
483-
484- # Set up loop
485- i = 0
486- error = tol + 1
487-
488- while i < max_iter and error > tol:
489- a_new, σ_new = K(a_vec, σ_vec, model)
490- error = jnp.max(jnp.abs(σ_vec - σ_new))
440+ @jax.jit
441+ def solve_model(
442+ ifp: IFP,
443+ c_init: jnp.ndarray, # Initial guess of σ on grid endogenous grid
444+ a_init: jnp.ndarray, # Initial endogenous grid
445+ tol: float = 1e-5,
446+ max_iter: int = 1000
447+ ) -> jnp.ndarray:
448+ " Solve the model using time iteration with EGM. "
449+
450+ def condition(loop_state):
451+ c_in, a_in, i, error = loop_state
452+ return (error > tol) & (i < max_iter)
453+
454+ def body(loop_state):
455+ c_in, a_in, i, error = loop_state
456+ c_out, a_out = K(c_in, a_in, ifp)
457+ error = jnp.max(jnp.abs(c_out - c_in))
491458 i += 1
492- if verbose and i % print_skip == 0:
493- print(f"Error at iteration {i} is {error}.")
494- a_vec, σ_vec = a_new, σ_new
459+ return c_out, a_out, i, error
495460
496- if error > tol:
497- print("Failed to converge!" )
498- elif verbose:
499- print(f"\nConverged in {i} iterations.")
461+ i, error = 0, tol + 1
462+ initial_state = (c_init, a_init, i, error )
463+ final_loop_state = jax.lax.while_loop(condition, body, initial_state)
464+ c_out, a_out, i, error = final_loop_state
500465
501- return a_new, σ_new
466+ return c_out, a_out
502467```
503468
504469Now we can create an instance and solve the model using JAX:
@@ -522,10 +487,16 @@ a_init = σ_init.copy()
522487Let's generate an approximation solution with JAX:
523488
524489``` {code-cell} ipython3
525- a_star, σ_star = solve_model_time_iter (ifp, a_init, σ_init, print_skip=5 )
490+ a_star, σ_star = solve_model (ifp, a_init, σ_init)
526491```
527492
493+ Let's try it again with a timer.
528494
495+ ``` {code-cell} python3
496+ with qe.Timer(precision=8):
497+ a_star, σ_star = solve_model(ifp, a_init, σ_init)
498+ a_star.block_until_ready()
499+ ```
529500
530501## Simulation
531502
@@ -543,7 +514,6 @@ The function takes a solution pair `c_vec` and `a_vec`, understanding them
543514as representing an optimal policy associated with a given model ` ifp `
544515
545516``` {code-cell} ipython3
546- @jax.jit
547517def simulate_household(
548518 key, a_0, z_idx_0, c_vec, a_vec, ifp, T
549519 ):
@@ -593,6 +563,7 @@ def simulate_household(
593563Now we write a function to simulate many households in parallel.
594564
595565``` {code-cell} ipython3
566+ @partial(jax.jit, static_argnums=(3, 4, 5))
596567def compute_asset_stationary(
597568 c_vec, a_vec, ifp, num_households=50_000, T=500, seed=1234
598569 ):
@@ -623,7 +594,7 @@ def compute_asset_stationary(
623594 )
624595 assets = sim_all_households(keys, a_0_vector, z_idx_0_vector, c_vec, a_vec, ifp, T)
625596
626- return np .array(assets)
597+ return jnp .array(assets)
627598```
628599
629600We'll need some inequality measures for visualization, so let's define them first:
@@ -671,7 +642,7 @@ s_grid = ifp.s_grid
671642n_z = len(ifp.P)
672643a_init = s_grid[:, None] * jnp.ones(n_z)
673644c_init = a_init
674- a_vec, c_vec = solve_model_time_iter (ifp, a_init, c_init)
645+ a_vec, c_vec = solve_model (ifp, a_init, c_init)
675646assets = compute_asset_stationary(c_vec, a_vec, ifp, num_households=200_000)
676647
677648# Compute Gini coefficient for the plot
@@ -763,8 +734,8 @@ for a_r in a_r_vals:
763734 n_z_temp = len(ifp_temp.P)
764735 a_init_temp = s_grid_temp[:, None] * jnp.ones(n_z_temp)
765736 c_init_temp = a_init_temp
766- a_vec_temp, c_vec_temp = solve_model_time_iter (
767- ifp_temp, a_init_temp, c_init_temp, verbose=False
737+ a_vec_temp, c_vec_temp = solve_model (
738+ ifp_temp, a_init_temp, c_init_temp
768739 )
769740
770741 # Simulate households
@@ -840,8 +811,8 @@ for a_y in a_y_vals:
840811 n_z_temp = len(ifp_temp.P)
841812 a_init_temp = s_grid_temp[:, None] * jnp.ones(n_z_temp)
842813 c_init_temp = a_init_temp
843- a_vec_temp, c_vec_temp = solve_model_time_iter (
844- ifp_temp, a_init_temp, c_init_temp, verbose=False
814+ a_vec_temp, c_vec_temp = solve_model (
815+ ifp_temp, a_init_temp, c_init_temp
845816 )
846817
847818 # Simulate households
0 commit comments