Skip to content

Commit 106ce02

Browse files
authored
Merge branch 'main' into main
2 parents 482cb81 + e67576d commit 106ce02

18 files changed

+386
-266
lines changed

README.md

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -170,13 +170,15 @@ This repository contains the C++ implementation of the Tesseract quantum error c
170170
The following example demonstrates how to create and use the Tesseract decoder using the Python interface.
171171

172172
```python
173-
import tesseract_decoder.tesseract as tesseract
173+
from tesseract_decoder import tesseract
174174
import stim
175+
import numpy as np
176+
175177

176178
# 1. Define a detector error model (DEM)
177179
dem = stim.DetectorErrorModel("""
178-
error(0.1) D0 D1
179-
error(0.2) D1 D2 L0
180+
error(0.1) D0 D1 L0
181+
error(0.2) D1 D2 L1
180182
detector(0, 0, 0) D0
181183
detector(1, 0, 0) D1
182184
detector(2, 0, 0) D2
@@ -185,15 +187,24 @@ dem = stim.DetectorErrorModel("""
185187
# 2. Create the decoder configuration
186188
config = tesseract.TesseractConfig(dem=dem, det_beam=50)
187189

188-
# 3. Configure and create a decoder instance
189-
decoder = tesseract.TesseractDecoder(config)
190+
# 3. Create a decoder instance
191+
decoder = config.compile_decoder()
190192

191-
# 4. Simulate detection events and decode it
192-
detections = [1, 2]
193-
flipped_observables = decoder.decode(detections)
193+
# 4. Simulate detection events
194+
syndrome = [0, 1, 1]
194195

195-
print(f"Detections: {detections}")
196+
# 5a. Decode to observables
197+
flipped_observables = decoder.decode(syndrome)
196198
print(f"Flipped observables: {flipped_observables}")
199+
200+
# 5b. Alternatively, decode to errors
201+
decoder.decode_to_errors(np.where(syndrome)[0])
202+
predicted_errors = decoder.predicted_errors_buffer
203+
# Indices of predicted errors
204+
print(f"Predicted errors indices: {predicted_errors}")
205+
# Print properties of predicted errors
206+
for i in predicted_errors:
207+
print(f" {i}: {decoder.errors[i]}")
197208
```
198209

199210

docs/tutorial.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -846,7 +846,7 @@
846846
" pqlimit=1000,\n",
847847
" # no_revisit_dets=False,\n",
848848
" # at_most_two_errors_per_detector = False,\n",
849-
" det_penalty = False,\n",
849+
" det_penalty = 0,\n",
850850
" # det_orders=tesseract_decoder.utils.build_det_orders(\n",
851851
" # dem, num_det_orders=2, det_order_bfs=True, seed=2384753),\n",
852852
")\n",
@@ -1613,4 +1613,4 @@
16131613
},
16141614
"nbformat": 4,
16151615
"nbformat_minor": 0
1616-
}
1616+
}

src/common.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ common::Error::Error(const stim::DemInstruction& error) {
2929
throw std::invalid_argument(
3030
"Error must be loaded from an error dem instruction, but received: " + error.str());
3131
}
32-
assert(error.type == stim::DemInstructionType::DEM_ERROR);
3332
double probability = error.arg_data[0];
3433
if (probability < 0 || probability > 1) {
3534
throw std::invalid_argument("Probability must be between 0 and 1, but received: " +

src/common.test.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,16 +98,16 @@ TEST(common, RemoveZeroProbabilityErrors) {
9898

9999
// Helper function to compare the two methods.
100100
void assert_merged_probabilities_are_equal(double p1, double p2) {
101-
// Method 1: Merge probabilities directly using the exclusive OR formula.
101+
// Merge probabilities using the exclusive OR formula.
102102
double merged_p_direct = p1 + p2 - 2 * p1 * p2;
103103

104-
// Method 2: Convert to likelihood costs, merge them, then convert back.
104+
// Convert to likelihood costs, merge, and then convert back.
105105
double cost1 = std::log(p1 / (1 - p1));
106106
double cost2 = std::log(p2 / (1 - p2));
107107
double merged_cost = common::merge_weights(cost1, cost2);
108108
double merged_p_via_costs = 1 / (1 + std::exp(merged_cost));
109109

110-
// The two methods should produce nearly identical results.
110+
// The two methods should produce nearly same results.
111111
ASSERT_NEAR(merged_p_direct, merged_p_via_costs, 1e-12);
112112
}
113113

@@ -131,7 +131,6 @@ TEST(CommonTest, merge_weights_is_equivalent_to_probability_xor) {
131131
// Helper function to create a simple DEM with two identical errors.
132132
stim::DetectorErrorModel create_dem_with_two_errors(double p1, double p2) {
133133
stim::DetectorErrorModel dem;
134-
135134
std::vector<stim::DemTarget> targets = {stim::DemTarget::relative_detector_id(0)};
136135

137136
dem.append_error_instruction(p1, targets, "");

src/py/common_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ def test_as_dem_instruction_targets():
6666
def test_error_from_dem_instruction():
6767
di = stim.DemInstruction("error", [0.125], [stim.target_logical_observable_id(3)])
6868
error = tesseract_decoder.common.Error(di)
69-
7069
assert str(error) == "Error{cost=1.945910, symptom=Symptom{}}"
7170

7271
def test_error_get_set_probability():

src/py/shared_decoding_tests.py

Lines changed: 117 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def shared_test_compile_decoder(config_class, decoder_class):
1010
dem_string = "error(0.1) D0 D1 L0"
1111
dem = stim.DetectorErrorModel(dem_string)
1212
config = config_class(dem)
13-
13+
config.merge_errors = False
1414
decoder = config.compile_decoder()
1515

1616
assert isinstance(decoder, decoder_class)
@@ -24,31 +24,41 @@ def shared_test_cost_from_errors(decoder_class, config_class):
2424
predicted error chain. The cost is calculated as the sum of the log-odds
2525
ratio for each error mechanism.
2626
"""
27+
2728
dem_string = f'''
28-
error(0.1) D0 D1 L0
29-
error(0.2) D0 L1
30-
error(0.05) D1
29+
error(0.1) D0 L0
30+
error(0.2) D1 L1
31+
error(0.3) D2
32+
error(0.4) D3 L0 L1
3133
'''
3234
dem = stim.DetectorErrorModel(dem_string)
3335
config = config_class(dem)
36+
config.merge_errors = False
3437
decoder = decoder_class(config)
35-
if hasattr(decoder, 'init_ilp'):
36-
decoder.init_ilp()
3738

38-
# Case 1: Test with a single error.
39-
# The cost should be ln((1 - 0.1) / 0.1) = ln(9)
40-
cost1 = decoder.cost_from_errors([0])
41-
assert cost1 == pytest.approx(math.log(9))
39+
# Case 1: A single error mechanism that flips one observable (L0).
40+
errors1 = [0]
41+
decoder.predicted_errors_buffer = errors1
42+
cost1 = decoder.cost_from_errors(errors1)
43+
assert cost1 == pytest.approx(math.log((1 - 0.1) / 0.1))
4244

43-
# Case 2: Test with multiple errors.
44-
# The cost should be ln(9) + ln((1 - 0.2) / 0.2) = ln(9) + ln(4) +
45-
# ln((1 - 0.05) / 0.05) = ln(9) + ln(4) + ln(19)
46-
cost2 = decoder.cost_from_errors([0, 1, 2])
47-
assert cost2 == pytest.approx(math.log(9) + math.log(4) + math.log(19))
45+
# Case 2: A single error mechanism that flips multiple observables (L0, L1).
46+
errors2 = [3]
47+
decoder.predicted_errors_buffer = errors2
48+
cost2 = decoder.cost_from_errors(errors2)
49+
assert cost2 == pytest.approx(math.log((1 - 0.4) / 0.4))
4850

49-
# Case 3: Test with no errors.
50-
cost3 = decoder.cost_from_errors([])
51-
assert cost3 == pytest.approx(0)
51+
# Case 3: Multiple error mechanisms whose effects cancel out (L0 from error 0, L0 from error 3).
52+
errors3 = [0, 3]
53+
decoder.predicted_errors_buffer = errors3
54+
cost3 = decoder.cost_from_errors(errors3)
55+
assert cost3 == pytest.approx(math.log((1 - 0.1) / 0.1) + math.log((1 - 0.4) / 0.4))
56+
57+
# Case 4: No errors.
58+
errors4 = []
59+
decoder.predicted_errors_buffer = errors4
60+
cost4 = decoder.cost_from_errors(errors4)
61+
assert cost4 == pytest.approx(0)
5262

5363
def shared_test_get_observables_from_errors(decoder_class, config_class):
5464
"""
@@ -63,45 +73,41 @@ def shared_test_get_observables_from_errors(decoder_class, config_class):
6373
'''
6474
dem = stim.DetectorErrorModel(dem_string)
6575
config = config_class(dem)
76+
config.merge_errors = False
6677
decoder = decoder_class(config)
67-
6878
num_observables = dem.num_observables
6979

7080
# Case 1: A single error mechanism that flips one observable (L0).
7181
errors1 = [0]
72-
expected1 = np.array([True, False], dtype=bool)
73-
predicted1_list = decoder.get_observables_from_errors(errors1)
74-
assert isinstance(predicted1_list, list)
75-
predicted1_arr = np.array(predicted1_list, dtype=bool)
76-
assert predicted1_arr.shape[0] == num_observables
77-
assert np.array_equal(predicted1_arr, expected1)
82+
predicted_list1 = decoder.get_observables_from_errors(errors1)
83+
assert isinstance(predicted_list1, list)
84+
predicted_array1 = np.array(predicted_list1, dtype=bool)
85+
assert predicted_array1.shape[0] == num_observables
86+
assert np.array_equal(predicted_array1, np.array([True, False], dtype=bool))
7887

7988
# Case 2: A single error mechanism that flips multiple observables (L0, L1).
8089
errors2 = [3]
81-
expected2 = np.array([True, True], dtype=bool)
82-
predicted2_list = decoder.get_observables_from_errors(errors2)
83-
assert isinstance(predicted2_list, list)
84-
predicted2_arr = np.array(predicted2_list, dtype=bool)
85-
assert predicted2_arr.shape[0] == num_observables
86-
assert np.array_equal(predicted2_arr, expected2)
90+
predicted_list2 = decoder.get_observables_from_errors(errors2)
91+
assert isinstance(predicted_list2, list)
92+
predicted_array2 = np.array(predicted_list2, dtype=bool)
93+
assert predicted_array2.shape[0] == num_observables
94+
assert np.array_equal(predicted_array2, np.array([True, True], dtype=bool))
8795

8896
# Case 3: Multiple error mechanisms whose effects cancel out (L0 from error 0, L0 from error 3).
8997
errors3 = [0, 3]
90-
expected3 = np.array([False, True], dtype=bool)
91-
predicted3_list = decoder.get_observables_from_errors(errors3)
92-
assert isinstance(predicted3_list, list)
93-
predicted3_arr = np.array(predicted3_list, dtype=bool)
94-
assert predicted3_arr.shape[0] == num_observables
95-
assert np.array_equal(predicted3_arr, expected3)
98+
predicted_list3 = decoder.get_observables_from_errors(errors3)
99+
assert isinstance(predicted_list3, list)
100+
predicted_array3 = np.array(predicted_list3, dtype=bool)
101+
assert predicted_array3.shape[0] == num_observables
102+
assert np.array_equal(predicted_array3, np.array([False, True], dtype=bool))
96103

97104
# Case 4: No errors.
98105
errors4 = []
99-
expected4 = np.zeros(num_observables, dtype=bool)
100-
predicted4_list = decoder.get_observables_from_errors(errors4)
101-
assert isinstance(predicted4_list, list)
102-
predicted4_arr = np.array(predicted4_list, dtype=bool)
103-
assert predicted4_arr.shape[0] == num_observables
104-
assert np.array_equal(predicted4_arr, expected4)
106+
predicted_list4 = decoder.get_observables_from_errors(errors4)
107+
assert isinstance(predicted_list4, list)
108+
predicted_array4 = np.array(predicted_list4, dtype=bool)
109+
assert predicted_array4.shape[0] == num_observables
110+
assert np.array_equal(predicted_array4, np.zeros(num_observables, dtype=bool))
105111

106112

107113
def shared_test_decoder_predicts_various_observable_flips(decoder_class, config_class):
@@ -118,8 +124,6 @@ def shared_test_decoder_predicts_various_observable_flips(decoder_class, config_
118124
config = config_class(dem)
119125
decoder = decoder_class(config)
120126

121-
if hasattr(decoder, 'init_ilp'):
122-
decoder.init_ilp()
123127
syndrome = np.zeros(dem.num_detectors, dtype=bool)
124128
syndrome[0] = True
125129
predicted_logical_flips_array = decoder.decode(syndrome)
@@ -142,8 +146,6 @@ def shared_test_decode_from_detection_events(decoder_class, config_class):
142146
dem = stim.DetectorErrorModel(dem_string)
143147
config = config_class(dem)
144148
decoder = decoder_class(config)
145-
if hasattr(decoder, 'init_ilp'):
146-
decoder.init_ilp()
147149

148150
# Case 1: Detections corresponding to the D0 D1 L0 error
149151
detections1 = [0, 1]
@@ -177,8 +179,6 @@ def shared_test_decode(decoder_class, config_class):
177179
dem = stim.DetectorErrorModel(dem_string)
178180
config = config_class(dem)
179181
decoder = decoder_class(config)
180-
if hasattr(decoder, 'init_ilp'):
181-
decoder.init_ilp()
182182

183183
syndrome1 = np.array([True, True], dtype=bool)
184184
predicted1 = decoder.decode(syndrome1)
@@ -205,8 +205,6 @@ def shared_test_decode_complex_dem(decoder_class, config_class):
205205
dem = stim.DetectorErrorModel(dem_string)
206206
config = config_class(dem)
207207
decoder = decoder_class(config)
208-
if hasattr(decoder, 'init_ilp'):
209-
decoder.init_ilp()
210208

211209
syndrome = np.array([True, True, False], dtype=bool)
212210
predicted = decoder.decode(syndrome)
@@ -225,8 +223,6 @@ def shared_test_decode_batch_with_invalid_dimensions(decoder_class, config_class
225223
dem = stim.DetectorErrorModel(dem_string)
226224
config = config_class(dem)
227225
decoder = decoder_class(config)
228-
if hasattr(decoder, 'init_ilp'):
229-
decoder.init_ilp()
230226
invalid_syndrome = np.array([True, True], dtype=bool)
231227
with pytest.raises(RuntimeError, match="Input syndromes must be a 2D NumPy array."):
232228
decoder.decode_batch(invalid_syndrome)
@@ -243,14 +239,12 @@ def shared_test_decode_batch(decoder_class, config_class):
243239
dem = stim.DetectorErrorModel(dem_string)
244240
config = config_class(dem)
245241
decoder = decoder_class(config)
246-
if hasattr(decoder, 'init_ilp'):
247-
decoder.init_ilp()
248-
249242
syndromes = np.array([
250243
[True, True],
251244
[True, False],
252245
[False, True],
253246
], dtype=bool)
247+
254248
predictions = decoder.decode_batch(syndromes)
255249
assert isinstance(predictions, np.ndarray)
256250
assert predictions.dtype.type == np.bool_
@@ -276,9 +270,6 @@ def shared_test_decode_batch_with_complex_model(decoder_class, config_class):
276270
dem = stim.DetectorErrorModel(dem_string)
277271
config = config_class(dem)
278272
decoder = decoder_class(config)
279-
if hasattr(decoder, 'init_ilp'):
280-
decoder.init_ilp()
281-
282273
batch_syndromes = np.array([
283274
[True, True, False, False],
284275
[True, False, True, False],
@@ -298,3 +289,70 @@ def shared_test_decode_batch_with_complex_model(decoder_class, config_class):
298289
[False, False, False],
299290
], dtype=bool)
300291
assert np.array_equal(predictions, expected_predictions)
292+
293+
294+
295+
def shared_test_merge_errors_affects_cost(decoder_class, config_class):
296+
"""
297+
Test that the error's cost changes based on the 'merge_errors' setting.
298+
"""
299+
dem = stim.DetectorErrorModel(
300+
"""
301+
error(0.1) D0
302+
error(0.01) D0
303+
"""
304+
)
305+
detections = [0]
306+
307+
config_no_merge = config_class(dem, merge_errors=False)
308+
decoder_no_merge = decoder_class(config_no_merge)
309+
predicted_errors_no_merge = decoder_no_merge.decode_to_errors(detections)
310+
cost_no_merge = decoder_no_merge.cost_from_errors(decoder_no_merge.predicted_errors_buffer)
311+
312+
config_merge = config_class(dem, merge_errors=True)
313+
decoder_merge = decoder_class(config_merge)
314+
predicted_errors_merge = decoder_merge.decode_to_errors(detections)
315+
cost_merge = decoder_merge.cost_from_errors(decoder_merge.predicted_errors_buffer)
316+
317+
p_merged = 0.1 * (1 - 0.01) + 0.01 * (1 - 0.1)
318+
expected_cost_no_merge = math.log((1 - 0.1) / 0.1)
319+
expected_cost_merge = math.log((1 - p_merged) / p_merged)
320+
321+
assert predicted_errors_no_merge == predicted_errors_merge
322+
assert cost_no_merge == pytest.approx(expected_cost_no_merge)
323+
assert cost_merge == pytest.approx(expected_cost_merge)
324+
assert cost_no_merge != cost_merge
325+
326+
def shared_test_decode_with_mismatched_syndrome_size(decoder_class, config_class):
327+
"""
328+
Tests that `decode` raises an error when the input syndrome's length does
329+
not match the number of detectors in the DEM.
330+
"""
331+
dem_string = f'''
332+
error(0.1) D0 D1
333+
'''
334+
dem = stim.DetectorErrorModel(dem_string)
335+
config = config_class(dem)
336+
decoder = decoder_class(config)
337+
338+
# Syndrome has 1 detector, but DEM has 2
339+
invalid_syndrome = np.array([True], dtype=bool)
340+
with pytest.raises(ValueError, match=r"Syndrome array size \(1\) does not match the number of detectors in the decoder \(2\)\."):
341+
decoder.decode(invalid_syndrome)
342+
343+
def shared_test_decode_batch_with_mismatched_syndrome_size(decoder_class, config_class):
344+
"""
345+
Tests that `decode_batch` raises an error when the input syndromes' width
346+
does not match the number of detectors in the DEM.
347+
"""
348+
dem_string = f'''
349+
error(0.1) D0 D1
350+
'''
351+
dem = stim.DetectorErrorModel(dem_string)
352+
config = config_class(dem)
353+
decoder = decoder_class(config)
354+
355+
# Syndrome batch has 1 column, but DEM has 2
356+
invalid_syndromes = np.array([[True], [False]], dtype=bool)
357+
with pytest.raises(ValueError, match=r"The number of detectors in the input array \(1\) does not match the number of detectors in the decoder \(2\)."):
358+
decoder.decode_batch(invalid_syndromes)

0 commit comments

Comments
 (0)