Skip to content

Commit f1b0431

Browse files
twieckiclaude
andcommitted
Migrate from PyMC v4.3.0 (Aesara) to PyMC v5+ (PyTensor)
This commit migrates the entire codebase from the Aesara backend (PyMC v4.3.0) to the PyTensor backend (PyMC v5+). Major changes: - Updated all imports from 'aesara' to 'pytensor' - Replaced 'at' (aesara.tensor) with 'pt' (pytensor.tensor) throughout - Updated aeppl imports to use pymc.logprob equivalents - Removed `return_inferencedata=True` from sampling calls (now default) - Updated pyproject.toml dependencies: - pymc: 4.3.0 → >=5.0.0 - Added pytensor>=2.8.0 - arviz: 0.13.0 → >=0.20.0 - numpyro: 0.10.0 → >=0.15.0 - Added pixi configuration for environment management - Updated Python requirement from >=3.8 to >=3.10 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 4c72d96 commit f1b0431

30 files changed

+9705
-264
lines changed

IMPLEMENTATION_PLAN.md

Lines changed: 609 additions & 0 deletions
Large diffs are not rendered by default.

MIGRATION_PLAN.md

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
# Migration Plan: Porting homepy from PyMC3 (v4.3.0) to PyMC v5+
2+
3+
## Current State Analysis
4+
5+
**Codebase Structure:**
6+
- 33 Python files found with PyMC/Aesara usage
7+
- Current dependencies: `pymc==4.3.0`, `arviz==0.13.0`
8+
- Heavy use of Aesara tensor operations (`aesara.tensor as at`)
9+
- Custom Aesara graph manipulation in `aesaraf.py`
10+
11+
## Key PyMC3 → PyMC v5 Changes Identified
12+
13+
Based on the investigation, here are the **critical changes** from PyMC3/v4 to v5:
14+
15+
### 1. **Backend Migration: Aesara → PyTensor**
16+
- **Impact:** HIGH - affects all files
17+
- **Change:** Replace `aesara` with `pytensor`
18+
```python
19+
# OLD (PyMC 4.x):
20+
import aesara
21+
import aesara.tensor as at
22+
from aesara.tensor.var import TensorVariable
23+
24+
# NEW (PyMC 5.x):
25+
import pytensor
26+
import pytensor.tensor as pt
27+
from pytensor.tensor.var import TensorVariable
28+
```
29+
30+
### 2. **AePPL Module Integration**
31+
- **Impact:** HIGH - affects `aesaraf.py`
32+
- The `aeppl` package is now merged into PyMC's `logprob` submodule
33+
```python
34+
# OLD:
35+
from aeppl.utils import get_constant_value
36+
37+
# NEW:
38+
from pymc.logprob.utils import get_constant_value
39+
```
40+
41+
### 3. **Graph Rewriting API Changes**
42+
- **Impact:** HIGH - affects `aesaraf.py`
43+
```python
44+
# OLD:
45+
from aesara.graph.rewriting.basic import in2out, node_rewriter
46+
from aesara.graph.opt import local_optimizer as node_rewriter
47+
48+
# NEW:
49+
from pytensor.graph.rewriting.basic import in2out, node_rewriter
50+
```
51+
52+
### 4. **Random Variables and Shared Variables**
53+
- **Impact:** MEDIUM
54+
```python
55+
# OLD:
56+
from aesara.tensor.random.basic import NormalRV
57+
at.random.var.RandomStateSharedVariable
58+
59+
# NEW:
60+
from pytensor.tensor.random.basic import NormalRV
61+
pt.tensor.random.var.RandomStateSharedVariable
62+
```
63+
64+
### 5. **InferenceData Changes**
65+
- **Impact:** LOW
66+
- `pm.sample()` now returns `InferenceData` by default (no need for `return_inferencedata=True`)
67+
- Log likelihood NOT computed by default anymore
68+
69+
### 6. **Dependency Updates**
70+
- Update `arviz` to latest (currently at 0.13.0, latest is 0.20+)
71+
- Remove `aesara` dependencies
72+
- Add `pytensor` dependency
73+
74+
## Files Requiring Changes
75+
76+
### **High Priority (Core functionality):**
77+
78+
1. **`homepy/aesaraf.py`** - Custom Aesara graph manipulation
79+
- Line 20-27: Import statements
80+
- Line 36: `pymc.aesaraf.find_rng_nodes` usage
81+
- All graph rewriting logic
82+
83+
2. **`homepy/models/base.py`** - Main model class
84+
- Lines 26, 33-35: Aesara imports
85+
- Line 173: `aesara.function()` calls
86+
- Lines 982, 1029: Graph cloning operations
87+
88+
3. **`homepy/blocks/gp.py`** - Gaussian Process blocks
89+
- Line 32: `aesara.tensor as at`
90+
- Line 91: `.eval()` calls on covariance matrices
91+
92+
4. **`homepy/nested_hierarchy_rvs.py`** - Hierarchical RV creation
93+
- Lines 23: `aesara.tensor as at`
94+
- Line 102: `aesara.shared()` calls
95+
96+
### **Medium Priority (Block implementations):**
97+
98+
5-11. **All files in `homepy/blocks/`:**
99+
- `base.py`, `distributions.py`, `likelihoods.py`, `linear.py`, `means.py`
100+
- Import statements only
101+
102+
### **Low Priority (Tests):**
103+
104+
12-33. **All test files** - Update imports and assertions
105+
106+
## Detailed Migration Steps
107+
108+
### **Phase 1: Dependency Updates**
109+
110+
1. Update `pyproject.toml`:
111+
```toml
112+
dependencies = [
113+
"arviz>=0.20.0",
114+
"pymc>=5.0.0",
115+
"pytensor>=2.8.0", # New
116+
"numpyro>=0.15.0", # Update
117+
"numpy>=1.23.0",
118+
"pandas>=1.4.3",
119+
"xarray>=2022.3.0",
120+
# Remove: old pinned versions
121+
# Remove: cython (not needed for modern pymc)
122+
]
123+
```
124+
125+
2. Test environment setup with pixi
126+
127+
### **Phase 2: Global Find/Replace**
128+
129+
Safe mechanical replacements across all files:
130+
131+
```python
132+
# Imports
133+
"import aesara""import pytensor"
134+
"from aesara""from pytensor"
135+
"import aesara.tensor as at""import pytensor.tensor as pt"
136+
"as at""as pt"
137+
"at.""pt."
138+
139+
# Module paths
140+
"aesara.tensor""pytensor.tensor"
141+
"aesara.graph""pytensor.graph"
142+
"aesara.shared""pytensor.shared"
143+
"aesara.compile""pytensor.compile"
144+
```
145+
146+
### **Phase 3: AePPL Migration**
147+
148+
In `aesaraf.py`:
149+
```python
150+
# OLD:
151+
from aeppl.utils import get_constant_value
152+
153+
# NEW:
154+
from pymc.logprob.utils import get_constant_value
155+
```
156+
157+
### **Phase 4: Graph Rewriting Updates**
158+
159+
Update `aesaraf.py` node_rewriter compatibility:
160+
```python
161+
# Should work as-is after pytensor import change, but verify:
162+
from pytensor.graph.rewriting.basic import in2out, node_rewriter
163+
```
164+
165+
### **Phase 5: Sampling Updates**
166+
167+
In `models/base.py`:
168+
169+
```python
170+
# OLD (line 107):
171+
idata = pm.sample_prior_predictive(*args, return_inferencedata=True, **kwargs)
172+
173+
# NEW:
174+
idata = pm.sample_prior_predictive(*args, **kwargs)
175+
# return_inferencedata=True is now default
176+
177+
# OLD (line 116):
178+
idata = pm.sample(*args, return_inferencedata=True, **kwargs)
179+
180+
# NEW:
181+
idata = pm.sample(*args, **kwargs)
182+
```
183+
184+
For log likelihood (if needed for model comparison):
185+
```python
186+
idata = pm.sample(*args, idata_kwargs=dict(log_likelihood=True), **kwargs)
187+
```
188+
189+
### **Phase 6: Test and Validate**
190+
191+
1. Run unit tests: `pytest homepy/tests/`
192+
2. Run integration tests
193+
3. Validate GPU sampling still works
194+
4. Check model comparison functionality
195+
196+
## Risk Assessment
197+
198+
**Low Risk:**
199+
- Import replacements (mechanical)
200+
- Sampling API changes (backward compatible defaults)
201+
202+
**Medium Risk:**
203+
- Graph rewriting in `aesaraf.py` (may have subtle API changes)
204+
- Custom RV creation (API mostly stable)
205+
206+
**High Risk:**
207+
- Non-centered parameterization rewrites (core functionality)
208+
- JAX/GPU compilation (may need pytensor-specific flags)
209+
210+
## Testing Strategy
211+
212+
1. **Unit tests first**: Ensure basic building blocks work
213+
2. **Integration tests**: Test full model building and sampling
214+
3. **GPU tests**: Verify JAX backend still works
215+
4. **Regression tests**: Compare posteriors from v4 vs v5 on sample data
216+
217+
## Additional Considerations
218+
219+
1. **Version pinning**: Consider if you want to support both v4 and v5, or go straight to v5
220+
2. **ArviZ compatibility**: Current version 0.13.0 is old; update to 0.20+ for better v5 support
221+
3. **Documentation**: Update README and examples with new import statements
222+
4. **CI/CD**: Update GitHub Actions to use PyMC v5
223+
224+
## Estimated Effort
225+
226+
- **Mechanical changes**: 2-4 hours
227+
- **Testing and validation**: 4-8 hours
228+
- **Debugging edge cases**: 2-6 hours
229+
- **Total**: 1-2 days of focused work
230+
231+
## Implementation Checklist
232+
233+
- [ ] Phase 1: Update dependencies in pyproject.toml
234+
- [ ] Phase 2: Global find/replace (aesara → pytensor)
235+
- [ ] Phase 3: Update AePPL imports
236+
- [ ] Phase 4: Verify graph rewriting API
237+
- [ ] Phase 5: Update sampling calls
238+
- [ ] Phase 6: Run and fix unit tests
239+
- [ ] Phase 6: Run and fix integration tests
240+
- [ ] Phase 6: Test GPU sampling
241+
- [ ] Update documentation
242+
- [ ] Update CI/CD workflows

homepy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import re
1919
import warnings
2020

21-
# Suppress JAX omnistaging from aesara
21+
# Suppress JAX omnistaging from pytensor
2222
warnings.filterwarnings(
2323
"ignore",
2424
category=UserWarning,

homepy/aesaraf.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,21 @@
1717

1818
from typing import List, Union
1919

20-
import aesara
21-
import aesara.tensor as at
20+
import pytensor
21+
import pytensor.tensor as pt
2222
import numpy as np
2323

24-
from aeppl.utils import get_constant_value
25-
from aesara.graph import FunctionGraph
26-
from aesara.graph.basic import clone_get_equiv
27-
from aesara.tensor.random.basic import NormalRV
24+
from pymc.logprob.utils import get_constant_value
25+
from pytensor.graph import FunctionGraph
26+
from pytensor.graph.basic import clone_get_equiv
27+
from pytensor.tensor.random.basic import NormalRV
2828

2929
try:
3030
# Support aesara versions >= 2.8.0
31-
from aesara.graph.rewriting.basic import in2out, node_rewriter
31+
from pytensor.graph.rewriting.basic import in2out, node_rewriter
3232
except ImportError:
33-
from aesara.graph.opt import in2out
34-
from aesara.graph.opt import local_optimizer as node_rewriter
33+
from pytensor.graph.opt import in2out
34+
from pytensor.graph.opt import local_optimizer as node_rewriter
3535

3636
from pymc.aesaraf import find_rng_nodes
3737

@@ -58,7 +58,7 @@ def make_normal_not_centered(fgraph, node):
5858

5959
name = getattr(node.outputs[1], "name", None)
6060
if not loc_is_zero and name in resampled_vars_mapping:
61-
raw = at.random.normal(0, 1, rng=rng, size=size, dtype=dtype)
61+
raw = pt.random.normal(0, 1, rng=rng, size=size, dtype=dtype)
6262
raw.name = name + "_raw"
6363
og_index = resampled_vars_mapping[name]
6464
resampled_vars[og_index[0]][og_index[1]] = raw
@@ -95,11 +95,11 @@ def clone_replace_rv_consistent(outputs, free_RVs, replace):
9595
new_rng_nodes: List[Union[np.random.RandomState, np.random.Generator]] = []
9696
for rng_node in rng_nodes:
9797
rng_cls: type
98-
if isinstance(rng_node, at.random.var.RandomStateSharedVariable):
98+
if isinstance(rng_node, pt.random.var.RandomStateSharedVariable):
9999
rng_cls = np.random.RandomState
100100
else:
101101
rng_cls = np.random.Generator
102-
new_rng_nodes.append(aesara.shared(rng_cls(np.random.PCG64())))
102+
new_rng_nodes.append(pytensor.shared(rng_cls(np.random.PCG64())))
103103
orig_replace = {clone_map[rv]: rv for rv in free_RVs if rv in clone_map}
104104
orig_replace.update(dict(zip(rng_nodes, new_rng_nodes)))
105105
# replace_var can only be constant values or shared, not graph that depend on nodes

homepy/blocks/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from typing import Dict, Iterable, List, Optional
2121

22-
from aesara.tensor import TensorVariable
22+
from pytensor.tensor import TensorVariable
2323
from arviz import InferenceData
2424
from pandas import DataFrame
2525

0 commit comments

Comments
 (0)