Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 6 additions & 11 deletions lectures/ifp_egm_transient_shocks.md
Original file line number Diff line number Diff line change
Expand Up @@ -456,17 +456,13 @@ def K(
return jnp.exp(a_y * η + z * b_y)

def compute_c(i, j):
" Function to compute consumption for one (i, j) pair where i >= 1. "
" Compute c_ij when i >= 1 (interior choice). "

def compute_expectation_k(k):
"""
For each k, approximate the integral

∫ u'(σ(R s_i + y(z_k, η'), z_k)) φ(η') dη'
"""
def expected_mu(k):
" Approximate ∫ u'(σ(R s_i + y(z_k, η'), z_k)) φ(η') dη' "

def compute_mu_at_eta(η):
" For each η draw, compute u'(σ(R * s_i + y(z_k, η), z_k)) "
" Compute u'(σ(R * s_i + y(z_k, η), z_k)) "
next_a = R * s[i] + y(z_grid[k], η)
# Interpolate to get σ(R * s_i + y(z_k, η), z_k)
next_c = jnp.interp(next_a, a_in[:, k], c_in[:, k])
Expand All @@ -479,10 +475,9 @@ def K(
return jnp.mean(all_draws)

# Compute expectation: Σ_k [∫ u'(σ(...)) φ(η) dη] * Π[j, k]
expectations = jax.vmap(compute_expectation_k)(jnp.arange(n_z))
expectations = jax.vmap(expected_mu)(jnp.arange(n_z))
expectation = jnp.sum(expectations * Π[j, :])

# Invert to get consumption c_{ij} at (s_i, z_j)
# Invert to get consumption c_ij at (s_i, z_j)
return u_prime_inv(β * R * expectation)

# Set up index grids for vmap computation of all c_{ij}
Expand Down
6 changes: 4 additions & 2 deletions lectures/os_egm.md
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,12 @@ def K(
# Allocate memory for new consumption array
c_out = np.empty_like(s_grid)

# Solve for updated consumption value
for i, s in enumerate(s_grid):
# Approximate marginal utility ∫ u'(σ(f(s, α)z)) f'(s, α) z ϕ(z)dz
vals = u_prime(σ(f(s, α) * shocks)) * f_prime(s, α) * shocks
c_out[i] = u_prime_inv(β * np.mean(vals))
mu = np.mean(vals)
# Compute consumption
c_out[i] = u_prime_inv(β * mu)

# Determine corresponding endogenous grid
x_out = s_grid + c_out # x_i = s_i + c_i
Expand Down
60 changes: 33 additions & 27 deletions lectures/os_egm_jax.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,16 @@ class Model(NamedTuple):
α: float # production function parameter


def create_model(β: float = 0.96,
μ: float = 0.0,
s: float = 0.1,
grid_max: float = 4.0,
grid_size: int = 120,
shock_size: int = 250,
seed: int = 1234,
α: float = 0.4) -> Model:
def create_model(
β: float = 0.96,
μ: float = 0.0,
s: float = 0.1,
grid_max: float = 4.0,
grid_size: int = 120,
shock_size: int = 250,
seed: int = 1234,
α: float = 0.4
) -> Model:
"""
Creates an instance of the optimal savings model.
"""
Expand All @@ -114,6 +116,17 @@ def create_model(β: float = 0.96,
return Model(β=β, μ=μ, s=s, s_grid=s_grid, shocks=shocks, α=α)
```


We define utility and production functions globally.

```{code-cell} python3
# Define utility and production functions with derivatives
u = lambda c: jnp.log(c)
u_prime = lambda c: 1 / c
u_prime_inv = lambda x: 1 / x
f = lambda k, α: k**α
f_prime = lambda k, α: α * k**(α - 1)
```
Here's the Coleman-Reffett operator using EGM.

The key JAX feature here is `vmap`, which vectorizes the computation over the grid points.
Expand All @@ -138,10 +151,13 @@ def K(

# Define function to compute consumption at a single grid point
def compute_c(s):
# Approximate marginal utility ∫ u'(σ(f(s, α)z)) f'(s, α) z ϕ(z)dz
vals = u_prime(σ(f(s, α) * shocks)) * f_prime(s, α) * shocks
return u_prime_inv(β * jnp.mean(vals))
mu = jnp.mean(vals)
# Calculate consumption
return u_prime_inv(β * mu)

# Vectorize over grid using vmap
# Vectorize and calculate on all exogenous grid points
compute_c_vectorized = jax.vmap(compute_c)
c_out = compute_c_vectorized(s_grid)

Expand All @@ -151,18 +167,6 @@ def K(
return c_out, x_out
```

We define utility and production functions globally.

Note that `f` and `f_prime` take `α` as an explicit argument, allowing them to work with JAX's functional programming model.

```{code-cell} python3
# Define utility and production functions with derivatives
u = lambda c: jnp.log(c)
u_prime = lambda c: 1 / c
u_prime_inv = lambda x: 1 / x
f = lambda k, α: k**α
f_prime = lambda k, α: α * k**(α - 1)
```

Now we create a model instance.

Expand All @@ -175,11 +179,13 @@ The solver uses JAX's `jax.lax.while_loop` for the iteration and is JIT-compiled

```{code-cell} python3
@jax.jit
def solve_model_time_iter(model: Model,
c_init: jnp.ndarray,
x_init: jnp.ndarray,
tol: float = 1e-5,
max_iter: int = 1000):
def solve_model_time_iter(
model: Model,
c_init: jnp.ndarray,
x_init: jnp.ndarray,
tol: float = 1e-5,
max_iter: int = 1000
):
"""
Solve the model using time iteration with EGM.
"""
Expand Down
31 changes: 15 additions & 16 deletions lectures/os_numerical.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,14 @@ This is a form of **successive approximation**, and was discussed in our {doc}`l
The basic idea is:

1. Take an arbitrary initial guess of $v$.
1. Obtain an update $w$ defined by
1. Obtain an update $\hat v$ defined by

$$
w(x) = \max_{0\leq c \leq x} \{u(c) + \beta v(x-c)\}
\hat v(x) = \max_{0\leq c \leq x} \{u(c) + \beta v(x-c)\}
$$

1. Stop if $w$ is approximately equal to $v$, otherwise set
$v=w$ and go back to step 2.
1. Stop if $\hat v$ is approximately equal to $v$, otherwise set
$v=\hat v$ and go back to step 2.

Let's write this a bit more mathematically.

Expand All @@ -109,7 +109,7 @@ We introduce the **Bellman operator** $T$ that takes a function v as an
argument and returns a new function $Tv$ defined by

$$
Tv(x) = \max_{0 \leq c \leq x} \{u(c) + \beta v(x - c)\}
Tv(x) = \max_{0 \leq c \leq x} \{u(c) + \beta v(x - c)\}
$$

From $v$ we get $Tv$, and applying $T$ to this yields
Expand Down Expand Up @@ -206,13 +206,7 @@ Here's the CRRA utility function.

```{code-cell} python3
def u(c, γ):
"""
Utility function.
"""
if γ == 1:
return np.log(c)
else:
return (c ** (1 - γ)) / (1 - γ)
return (c ** (1 - γ)) / (1 - γ)
```

To work with the Bellman equation, let's write it as
Expand Down Expand Up @@ -240,8 +234,8 @@ def B(
Right hand side of the Bellman equation given x and c.

"""
# Unpack
β, γ, x_grid = model.β, model.γ, model.x_grid
# Unpack (simplify names)
β, γ, x_grid = model

# Convert array v into a function by linear interpolation
vf = lambda x: np.interp(x, x_grid, v)
Expand All @@ -250,7 +244,12 @@ def B(
return u(c, γ) + β * vf(x - c)
```

We now define the Bellman operation:
We now define the Bellman operator acting on grid points:

$$
Tv(x_i) = \max_{0 \leq c \leq x_i} B(x_i, c, v)
\qquad \text{for all } i
$$

```{code-cell} python3
def T(
Expand Down Expand Up @@ -280,7 +279,7 @@ model = create_cake_eating_model()
β, γ, x_grid = model
```

Now let's see the iteration of the value function in action.
Now let's see iteration of the value function in action.

We start from guess $v$ given by $v(x) = u(x)$ for every
$x$ grid point.
Expand Down
Loading
Loading