| 1 | +# Migration Plan: Porting homepy from PyMC v4 (4.3.0) to PyMC v5+ |
| 2 | + |
| 3 | +## Current State Analysis |
| 4 | + |
| 5 | +**Codebase Structure:** |
| 6 | +- 33 Python files found with PyMC/Aesara usage |
| 7 | +- Current dependencies: `pymc==4.3.0`, `arviz==0.13.0` |
| 8 | +- Heavy use of Aesara tensor operations (`aesara.tensor as at`) |
| 9 | +- Custom Aesara graph manipulation in `aesaraf.py` |
| 10 | + |
| 11 | +## Key PyMC v4 → PyMC v5 Changes Identified |
| 12 | + |
| 13 | +Based on the investigation, these are the **critical changes** from PyMC v4 to v5: |
| 14 | + |
| 15 | +### 1. **Backend Migration: Aesara → PyTensor** |
| 16 | + - **Impact:** HIGH - affects all files |
| 17 | + - **Change:** Replace `aesara` with `pytensor` |
| 18 | + ```python |
| 19 | + # OLD (PyMC 4.x): |
| 20 | + import aesara |
| 21 | + import aesara.tensor as at |
| 22 | + from aesara.tensor.var import TensorVariable |
| 23 | + |
| 24 | + # NEW (PyMC 5.x): |
| 25 | + import pytensor |
| 26 | + import pytensor.tensor as pt |
| 27 | + from pytensor.tensor import TensorVariable  # top-level export; the `pytensor.tensor.var` module was renamed in later PyTensor releases |
| 28 | + ``` |
| 29 | + |
| 30 | +### 2. **AePPL Module Integration** |
| 31 | + - **Impact:** HIGH - affects `aesaraf.py` |
| 32 | + - The `aeppl` package is now merged into PyMC's `logprob` submodule |
| 33 | + ```python |
| 34 | + # OLD: |
| 35 | + from aeppl.utils import get_constant_value |
| 36 | + |
| 37 | + # NEW: |
| 38 | + from pymc.logprob.utils import get_constant_value |
| 39 | + ``` |
| 40 | + |
| 41 | +### 3. **Graph Rewriting API Changes** |
| 42 | + - **Impact:** HIGH - affects `aesaraf.py` |
| 43 | + ```python |
| 44 | + # OLD: |
| 45 | + from aesara.graph.rewriting.basic import in2out, node_rewriter |
| 46 | + from aesara.graph.opt import local_optimizer as node_rewriter  # fallback alias for older Aesara releases |
| 47 | + |
| 48 | + # NEW: |
| 49 | + from pytensor.graph.rewriting.basic import in2out, node_rewriter |
| 50 | + ``` |
| 51 | + |
| 52 | +### 4. **Random Variables and Shared Variables** |
| 53 | + - **Impact:** MEDIUM |
| 54 | + ```python |
| 55 | + # OLD: |
| 56 | + from aesara.tensor.random.basic import NormalRV |
| 57 | + at.random.var.RandomStateSharedVariable |
| 58 | + |
| 59 | + # NEW: |
| 60 | + from pytensor.tensor.random.basic import NormalRV |
| 61 | + pt.random.var.RandomStateSharedVariable  # `pt` is already `pytensor.tensor` |
| 62 | + ``` |
| 63 | + |
| 64 | +### 5. **InferenceData Changes** |
| 65 | + - **Impact:** LOW |
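| 66 | + - `pm.sample()` returns `InferenceData` by default (already the case in v4), so `return_inferencedata=True` is redundant and can be dropped |
| 67 | + - Log likelihood is no longer computed by default in v5; request it explicitly when model comparison needs it |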
| 68 | + |
| 69 | +### 6. **Dependency Updates** |
| 70 | + - Update `arviz` to a recent release (currently pinned at 0.13.0; latest is 0.20+) |
| 71 | + - Remove `aesara` dependencies |
| 72 | + - Add `pytensor` dependency |
| 73 | + |
| 74 | +## Files Requiring Changes |
| 75 | + |
| 76 | +### **High Priority (Core functionality):** |
| 77 | + |
| 78 | +1. **`homepy/aesaraf.py`** - Custom Aesara graph manipulation |
| 79 | + - Lines 20-27: Import statements |
| 80 | + - Line 36: `pymc.aesaraf.find_rng_nodes` usage (the module is named `pymc.pytensorf` in v5) |
| 81 | + - All graph rewriting logic |
| 82 | + |
| 83 | +2. **`homepy/models/base.py`** - Main model class |
| 84 | + - Lines 26, 33-35: Aesara imports |
| 85 | + - Line 173: `aesara.function()` calls |
| 86 | + - Lines 982, 1029: Graph cloning operations |
| 87 | + |
| 88 | +3. **`homepy/blocks/gp.py`** - Gaussian Process blocks |
| 89 | + - Line 32: `aesara.tensor as at` |
| 90 | + - Line 91: `.eval()` calls on covariance matrices |
| 91 | + |
| 92 | +4. **`homepy/nested_hierarchy_rvs.py`** - Hierarchical RV creation |
| 93 | + - Line 23: `aesara.tensor as at` |
| 94 | + - Line 102: `aesara.shared()` calls |
| 95 | + |
| 96 | +### **Medium Priority (Block implementations):** |
| 97 | + |
| 98 | +5-11. **All files in `homepy/blocks/`:** |
| 99 | + - `base.py`, `distributions.py`, `likelihoods.py`, `linear.py`, `means.py` |
| 100 | + - Import statements only |
| 101 | + |
| 102 | +### **Low Priority (Tests):** |
| 103 | + |
| 104 | +12-33. **All test files** - Update imports and assertions |
| 105 | + |
| 106 | +## Detailed Migration Steps |
| 107 | + |
| 108 | +### **Phase 1: Dependency Updates** |
| 109 | + |
| 110 | +1. Update `pyproject.toml`: |
| 111 | +```toml |
| 112 | +dependencies = [ |
| 113 | + "arviz>=0.20.0", |
| 114 | + "pymc>=5.0.0", |
| 115 | + "pytensor>=2.8.0", # New; PyMC pins its own compatible PyTensor range, so keep this bound loose |
| 116 | + "numpyro>=0.15.0", # Update |
| 117 | + "numpy>=1.23.0", |
| 118 | + "pandas>=1.4.3", |
| 119 | + "xarray>=2022.3.0", |
| 120 | + # Remove: old pinned versions |
| 121 | + # Remove: cython (not needed for modern pymc) |
| 122 | +] |
| 123 | +``` |
| 124 | + |
| 125 | +2. Test environment setup with pixi |
| 126 | + |
| 127 | +### **Phase 2: Global Find/Replace** |
| 128 | + |
| 129 | +Mostly mechanical replacements across all files. Apply the two `at` alias rules as regular expressions with word boundaries; a plain-text replace of `at.` would also rewrite names such as `concat.` and `format.`: |
| 130 | + |
| 131 | +```text |
| 132 | +# Imports |
| 133 | +"import aesara" → "import pytensor" |
| 134 | +"from aesara" → "from pytensor" |
| 135 | +"import aesara.tensor as at" → "import pytensor.tensor as pt" |
| 136 | +"\bas at\b" → "as pt" |
| 137 | +"\bat\." → "pt." |
| 138 | + |
| 139 | +# Module paths (PyMC's own helper module also changes: "pymc.aesaraf" → "pymc.pytensorf") |
| 140 | +"aesara.tensor" → "pytensor.tensor" |
| 141 | +"aesara.graph" → "pytensor.graph" |
| 142 | +"aesara.shared" → "pytensor.shared" |
| 143 | +"aesara.compile" → "pytensor.compile" |
| 144 | +``` |
| 145 | + |
| 146 | +### **Phase 3: AePPL Migration** |
| 147 | + |
| 148 | +In `aesaraf.py`: |
| 149 | +```python |
| 150 | +# OLD: |
| 151 | +from aeppl.utils import get_constant_value |
| 152 | + |
| 153 | +# NEW: |
| 154 | +from pymc.logprob.utils import get_constant_value |
| 155 | +``` |
| 156 | + |
| 157 | +### **Phase 4: Graph Rewriting Updates** |
| 158 | + |
| 159 | +Update `aesaraf.py` node_rewriter compatibility: |
| 160 | +```python |
| 161 | +# Should work as-is after pytensor import change, but verify: |
| 162 | +from pytensor.graph.rewriting.basic import in2out, node_rewriter |
| 163 | +``` |
| 164 | + |
| 165 | +### **Phase 5: Sampling Updates** |
| 166 | + |
| 167 | +In `models/base.py`: |
| 168 | + |
| 169 | +```python |
| 170 | +# OLD (line 107): |
| 171 | +idata = pm.sample_prior_predictive(*args, return_inferencedata=True, **kwargs) |
| 172 | + |
| 173 | +# NEW: |
| 174 | +idata = pm.sample_prior_predictive(*args, **kwargs) |
| 175 | +# return_inferencedata=True is already the default, so the argument can be dropped |
| 176 | + |
| 177 | +# OLD (line 116): |
| 178 | +idata = pm.sample(*args, return_inferencedata=True, **kwargs) |
| 179 | + |
| 180 | +# NEW: |
| 181 | +idata = pm.sample(*args, **kwargs) |
| 182 | +``` |
| 183 | + |
| 184 | +For log likelihood (if needed for model comparison): |
| 185 | +```python |
| 186 | +idata = pm.sample(*args, idata_kwargs=dict(log_likelihood=True), **kwargs) |
| 187 | +``` |
| 188 | + |
| 189 | +### **Phase 6: Test and Validate** |
| 190 | + |
| 191 | +1. Run unit tests: `pytest homepy/tests/` |
| 192 | +2. Run integration tests |
| 193 | +3. Validate GPU sampling still works |
| 194 | +4. Check model comparison functionality |
| 195 | + |
| 196 | +## Risk Assessment |
| 197 | + |
| 198 | +**Low Risk:** |
| 199 | +- Import replacements (mechanical) |
| 200 | +- Sampling API changes (the `return_inferencedata` default is unchanged; log likelihood must now be requested explicitly) |
| 201 | + |
| 202 | +**Medium Risk:** |
| 203 | +- Graph rewriting in `aesaraf.py` (may have subtle API changes) |
| 204 | +- Custom RV creation (API mostly stable) |
| 205 | + |
| 206 | +**High Risk:** |
| 207 | +- Non-centered parameterization rewrites (core functionality) |
| 208 | +- JAX/GPU compilation (may need pytensor-specific flags) |
| 209 | + |
| 210 | +## Testing Strategy |
| 211 | + |
| 212 | +1. **Unit tests first**: Ensure basic building blocks work |
| 213 | +2. **Integration tests**: Test full model building and sampling |
| 214 | +3. **GPU tests**: Verify JAX backend still works |
| 215 | +4. **Regression tests**: Compare posteriors from v4 vs v5 on sample data |
| 216 | + |
| 217 | +## Additional Considerations |
| 218 | + |
| 219 | +1. **Version pinning**: Decide whether to support both v4 and v5 or to move straight to v5 |
| 220 | +2. **ArviZ compatibility**: Current version 0.13.0 is old; update to 0.20+ for better v5 support |
| 221 | +3. **Documentation**: Update README and examples with new import statements |
| 222 | +4. **CI/CD**: Update GitHub Actions to use PyMC v5 |
| 223 | + |
| 224 | +## Estimated Effort |
| 225 | + |
| 226 | +- **Mechanical changes**: 2-4 hours |
| 227 | +- **Testing and validation**: 4-8 hours |
| 228 | +- **Debugging edge cases**: 2-6 hours |
| 229 | +- **Total**: 8-18 hours, roughly 1-2 days of focused work |
| 230 | + |
| 231 | +## Implementation Checklist |
| 232 | + |
| 233 | +- [ ] Phase 1: Update dependencies in pyproject.toml |
| 234 | +- [ ] Phase 2: Global find/replace (aesara → pytensor) |
| 235 | +- [ ] Phase 3: Update AePPL imports |
| 236 | +- [ ] Phase 4: Verify graph rewriting API |
| 237 | +- [ ] Phase 5: Update sampling calls |
| 238 | +- [ ] Phase 6: Run and fix unit tests |
| 239 | +- [ ] Phase 6: Run and fix integration tests |
| 240 | +- [ ] Phase 6: Test GPU sampling |
| 241 | +- [ ] Update documentation |
| 242 | +- [ ] Update CI/CD workflows |
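
The Phase 2 replacements can be sketched as a small script. This is a minimal sketch under stated assumptions, not the project's tooling: the `RULES` list and the `rewrite_source` helper are hypothetical names, and the patterns assume the `at` alias convention described in the plan. Word boundaries keep the local `homepy/aesaraf.py` module name and identifiers such as `concat.` untouched.

```python
import re

# Ordered (pattern, replacement) rules. Illustrative only: adjust to the codebase.
RULES = [
    # PyMC's own helper module was renamed in v5.
    (re.compile(r"\bpymc\.aesaraf\b"), "pymc.pytensorf"),
    # Whole-word match, so a local module named `aesaraf` is left alone.
    (re.compile(r"\baesara\b"), "pytensor"),
    # Import alias: `... as at` becomes `... as pt`.
    (re.compile(r"\bas at\b"), "as pt"),
    # Alias usage: `at.` only when it starts a name (not `concat.` or `df.at.`).
    (re.compile(r"(?<![\w.])at\."), "pt."),
]


def rewrite_source(text: str) -> str:
    """Apply every rule in order and return the rewritten source text."""
    for pattern, replacement in RULES:
        text = pattern.sub(replacement, text)
    return text


if __name__ == "__main__":
    # `something` is a placeholder import used only for this demonstration.
    sample = (
        "import aesara.tensor as at\n"
        "from pymc.aesaraf import find_rng_nodes\n"
        "from homepy.aesaraf import something\n"
        "y = at.exp(x) + concat.values\n"
    )
    print(rewrite_source(sample))
```

The rules operate on raw text, so matches inside strings and comments are rewritten too, and accesses through an alias such as `pm.aesaraf` are not covered. Reviewing the resulting diff before committing is still advisable.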