Commit 922ed4a

committed
fix: item read bug
1 parent 939fe9a commit 922ed4a

7 files changed

+1051
-4
lines changed

BUG_REPRODUCTION_PACKAGE.md

Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
1+
# MLX Scalar Layout Bug - Complete Reproduction Package
2+
3+
## 📦 Package Contents
4+
5+
This directory contains everything needed to understand, reproduce, and report the MLX scalar layout bug discovered in EMLX.
6+
7+
### 📄 Documentation
8+
9+
1. **`BUG_REPRODUCTION_README.md`** (START HERE)
10+
- Complete overview of the bug
11+
- How to run reproduction scripts
12+
- Explanation of the bug mechanics
13+
- Next steps
14+
15+
2. **`MLX_BUG_REPORT.md`**
16+
- Formal technical bug report
17+
- Root cause analysis
18+
- Suggested fixes
19+
20+
3. **`MLX_GITHUB_ISSUE.md`**
21+
- Ready-to-post GitHub issue template
22+
- Concise reproduction case
23+
- Test case for MLX maintainers
24+
25+
4. **`BUG_FIX_SUMMARY.md`**
26+
- How the bug was discovered
27+
- Debug methodology
28+
- Workaround implementation in EMLX
29+
30+
### 🧪 Reproduction Scripts
31+
32+
**Main Reproductions:**
33+
- `test_mlx_scalar_bug_raw.exs` - **Demonstrates the bug** (bypasses workaround)
34+
- `test_mlx_scalar_bug.exs` - Shows that the bug no longer appears with the workaround in place
35+
- `test_scheduler_proper_state.exs` - Real-world impact test
36+
37+
**Debug Tools:**
38+
- `debug_scheduler_divergence.exs` - Traces operations with Nx.Defn.Evaluator
39+
- `find_divergence.exs` - Finds first divergence point
40+
- `compare_debug_traces.exs` - Compares operation traces
41+
42+
### 🎯 Quick Start
43+
44+
```bash
# 1. See the bug in action
elixir test_mlx_scalar_bug_raw.exs

# 2. Verify the fix works
elixir test_mlx_scalar_bug.exs

# 3. Test real-world impact
elixir test_scheduler_proper_state.exs
```
54+
55+
## 🐛 Bug Summary
56+
57+
**What**: MLX creates scalar tensors with an invalid memory layout after a `slice` followed by `squeeze`
58+
59+
**Why**: `squeeze` doesn't materialize scalars, leaving them as views into the source array
60+
61+
**Impact**: `item()` reads 8 bytes spanning two consecutive values instead of the 4 bytes of the requested element
62+
63+
**Example**:
64+
- Create array `[0, 1, ..., 951, 952, ...]`
65+
- Extract scalar at index 951
66+
- `item<int64>()` reads `[951, 952]` as a single value
67+
- Returns `4,088,808,866,743` instead of `951`
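
The arithmetic behind that value can be reproduced without MLX. The sketch below is only an illustration, not EMLX or MLX code: it assumes the source array stores little-endian signed 32-bit integers (consistent with the 4-byte versus 8-byte reads described above), and the module and function names (`ScalarLayoutDemo`, `pack_int32/1`, `read_int32/2`, `read_int64_at_int32_index/2`) are hypothetical.

```elixir
defmodule ScalarLayoutDemo do
  @moduledoc false

  # Pack integers as consecutive little-endian signed 32-bit values
  # (assumed layout of the source array).
  def pack_int32(values) do
    for v <- values, into: <<>>, do: <<v::little-signed-32>>
  end

  # Correct read: 4 bytes at the byte offset of element `index`.
  def read_int32(buffer, index) do
    offset = index * 4
    <<_::binary-size(offset), value::little-signed-32, _::binary>> = buffer
    value
  end

  # Faulty read: 8 bytes starting at the same offset, so the next
  # element ends up in the upper 32 bits of the result.
  def read_int64_at_int32_index(buffer, index) do
    offset = index * 4
    <<_::binary-size(offset), value::little-signed-64, _::binary>> = buffer
    value
  end
end

buffer = ScalarLayoutDemo.pack_int32(Enum.to_list(0..999))

IO.inspect(ScalarLayoutDemo.read_int32(buffer, 951))
# 951
IO.inspect(ScalarLayoutDemo.read_int64_at_int32_index(buffer, 951))
# 4088808866743
```

Under these assumptions the 8-byte read yields 951 + 952 × 2^32, which is exactly the value reported above.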
68+
69+
## 📊 Test Results
70+
71+
### Raw Bug (test_mlx_scalar_bug_raw.exs)
72+
```
Index 951: Expected 951, Got 4,088,808,866,743 ❌
Index 998: Expected 998, Got 4,290,672,329,702 ❌
8 out of 9 test cases fail
```
77+
78+
### With Workaround (test_mlx_scalar_bug.exs)
79+
```
All tests pass ✅
Mean difference: 1.0e-8
Std difference: 2.4e-7
```
84+
85+
## 🔧 The Fix
86+
87+
### Current Workaround (in EMLX)
88+
89+
```elixir
defp to_number(%T{} = t) do
  device_tuple = from_nx(t)

  # Force materialization by adding 0, so that item/1 reads from a
  # freshly computed array rather than a view into the source array
  scalar_zero = EMLX.scalar_tensor(0, EMLX.scalar_type(device_tuple), elem(device_tuple, 0))
  ref_fixed = EMLX.add(device_tuple, scalar_zero)

  EMLX.item(ref_fixed)
end
```
100+
101+
**Location**: `lib/emlx/backend.ex:445-462`
102+
103+
### Needed in MLX
104+
105+
Either:
106+
1. Fix `squeeze()` to materialize scalar results
107+
2. Fix `item()` to handle views correctly
108+
3. Both (recommended)
109+
110+
## 📝 How It Was Found
111+
112+
1. **Observed**: Bumblebee Stable Diffusion had 0.5 std deviation error on EMLX
113+
2. **Traced**: Used `Nx.Defn.Evaluator` debug mode to log all 57 operations
114+
3. **Compared**: Found first divergence at operation #22 (slice)
115+
4. **Debugged**: Added instrumentation to `mlx_slice` function
116+
5. **Discovered**: `to_number()` returned garbage: `3,869,765,534,647` instead of `951`
117+
6. **Analyzed**: Examined memory layout, found repeating pattern
118+
7. **Fixed**: Implemented workaround in `to_number()`
119+
8. **Verified**: All tests pass, numerical differences eliminated
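
Steps 3 and 6 can be sketched in plain Elixir. This is a hypothetical illustration, not the contents of `find_divergence.exs` or `compare_debug_traces.exs`: it assumes each trace has been reduced to a list with one number per operation (for example a mean), and the names `TraceDiff`, `first_divergence/3` and `split_words/1` are made up for this example.

```elixir
defmodule TraceDiff do
  @moduledoc false

  # Return {index, reference_value, candidate_value} for the first pair of
  # per-operation summaries that differ by more than `tol`, or nil if none do.
  def first_divergence(reference, candidate, tol \\ 1.0e-6) do
    reference
    |> Enum.zip(candidate)
    |> Enum.with_index()
    |> Enum.find_value(fn {{a, b}, i} ->
      if abs(a - b) > tol, do: {i, a, b}
    end)
  end

  # Split a suspicious 64-bit value into its {low, high} 32-bit words to
  # check whether it is two neighbouring 32-bit elements read as one.
  def split_words(value) do
    <<high::unsigned-32, low::unsigned-32>> = <<value::unsigned-64>>
    {low, high}
  end
end

IO.inspect(TraceDiff.first_divergence([1.0, 2.0, 3.0], [1.0, 2.0, 3.5]))
# {2, 3.0, 3.5}
IO.inspect(TraceDiff.split_words(3_869_765_534_647))
# {951, 901}
```

Splitting the garbage value from step 5 gives the expected `951` in the low word and `901` in the high word, which fits the picture of two adjacent elements being read as a single 64-bit value.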
120+
121+
## 🚀 Next Steps
122+
123+
### For Reporting to MLX
124+
125+
1. Use `MLX_GITHUB_ISSUE.md` as the issue template
126+
2. Link to this reproduction package
127+
3. Include output from `test_mlx_scalar_bug_raw.exs`
128+
129+
### For EMLX Development
130+
131+
- ✅ Workaround implemented and working
132+
- ⏳ Monitor MLX for upstream fix
133+
- 🔄 Remove workaround once MLX is fixed
134+
- 📝 Add regression tests
135+
136+
## 📫 Files You Need
137+
138+
**To report the bug to MLX:**
139+
- `MLX_GITHUB_ISSUE.md` (copy/paste to GitHub)
140+
- Output from `test_mlx_scalar_bug_raw.exs`
141+
142+
**To understand the bug:**
143+
- `BUG_REPRODUCTION_README.md`
144+
- `MLX_BUG_REPORT.md`
145+
146+
**To verify in your environment:**
147+
- `test_mlx_scalar_bug_raw.exs`
148+
- `test_mlx_scalar_bug.exs`
149+
150+
## 🎓 Key Learnings
151+
152+
1. **Use debug tracing**: `Nx.Defn.Evaluator` with `debug_options` is invaluable
153+
2. **Compare operation-by-operation**: Find the exact divergence point
154+
3. **Examine memory**: Sometimes the issue is in the tensor layout, not the operation
155+
4. **Test at multiple levels**: From unit tests to integration tests
156+
5. **Document thoroughly**: Makes bug reports actionable
157+
158+
## ⚠️ Important Notes
159+
160+
- The bug is **deterministic** and **reproducible**
161+
- Affects **any integer type** (int8, int16, int32, int64)
162+
- Only manifests when extracting scalars from views/slices
163+
- Direct scalar creation works fine
164+
- Workaround has minimal performance impact
165+
166+
---
167+
168+
**Created**: October 23, 2025
169+
**Status**: Bug confirmed, workaround implemented, ready to report upstream
170+
**Impact**: Critical for numerical correctness
171+
