@@ -10,7 +10,7 @@ def shared_test_compile_decoder(config_class, decoder_class):
     dem_string = "error(0.1) D0 D1 L0"
     dem = stim.DetectorErrorModel(dem_string)
     config = config_class(dem)
-
+    config.merge_errors = False
     decoder = config.compile_decoder()
 
     assert isinstance(decoder, decoder_class)
@@ -24,31 +24,41 @@ def shared_test_cost_from_errors(decoder_class, config_class):
     predicted error chain. The cost is calculated as the sum of the log-odds
     ratio for each error mechanism.
     """
+
     dem_string = f'''
-        error(0.1) D0 D1 L0
-        error(0.2) D0 L1
-        error(0.05) D1
+        error(0.1) D0 L0
+        error(0.2) D1 L1
+        error(0.3) D2
+        error(0.4) D3 L0 L1
     '''
     dem = stim.DetectorErrorModel(dem_string)
     config = config_class(dem)
+    config.merge_errors = False
     decoder = decoder_class(config)
-    if hasattr(decoder, 'init_ilp'):
-        decoder.init_ilp()
 
-    # Case 1: Test with a single error.
-    # The cost should be ln((1 - 0.1) / 0.1) = ln(9)
-    cost1 = decoder.cost_from_errors([0])
-    assert cost1 == pytest.approx(math.log(9))
+    # Case 1: A single error mechanism that flips one observable (L0).
+    errors1 = [0]
+    decoder.predicted_errors_buffer = errors1
+    cost1 = decoder.cost_from_errors(errors1)
+    assert cost1 == pytest.approx(math.log((1 - 0.1) / 0.1))
 
-    # Case 2: Test with multiple errors.
-    # The cost should be ln(9) + ln((1 - 0.2) / 0.2) = ln(9) + ln(4) +
-    # ln((1 - 0.05) / 0.05) = ln(9) + ln(4) + ln(19)
-    cost2 = decoder.cost_from_errors([0, 1, 2])
-    assert cost2 == pytest.approx(math.log(9) + math.log(4) + math.log(19))
+    # Case 2: A single error mechanism that flips multiple observables (L0, L1).
+    errors2 = [3]
+    decoder.predicted_errors_buffer = errors2
+    cost2 = decoder.cost_from_errors(errors2)
+    assert cost2 == pytest.approx(math.log((1 - 0.4) / 0.4))
 
-    # Case 3: Test with no errors.
-    cost3 = decoder.cost_from_errors([])
-    assert cost3 == pytest.approx(0)
+    # Case 3: Multiple error mechanisms; their costs add even though their observable flips cancel (L0 from error 0, L0 from error 3).
+    errors3 = [0, 3]
+    decoder.predicted_errors_buffer = errors3
+    cost3 = decoder.cost_from_errors(errors3)
+    assert cost3 == pytest.approx(math.log((1 - 0.1) / 0.1) + math.log((1 - 0.4) / 0.4))
+
+    # Case 4: No errors.
+    errors4 = []
+    decoder.predicted_errors_buffer = errors4
+    cost4 = decoder.cost_from_errors(errors4)
+    assert cost4 == pytest.approx(0)
 
 def shared_test_get_observables_from_errors(decoder_class, config_class):
     """
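
The reworked cost cases above all follow the log-odds rule stated in the docstring: each selected error mechanism with probability p contributes ln((1 - p) / p), and the contributions are summed. A minimal sketch of that rule, where the helper name `log_odds_cost` is illustrative and the probabilities are taken from the test's DEM:

```python
import math

def log_odds_cost(error_probabilities, error_indices):
    """Sum of ln((1 - p) / p) over the selected error mechanisms."""
    return sum(
        math.log((1 - error_probabilities[i]) / error_probabilities[i])
        for i in error_indices
    )

# Probabilities of the four error mechanisms in shared_test_cost_from_errors.
probs = [0.1, 0.2, 0.3, 0.4]
assert math.isclose(log_odds_cost(probs, [0]), math.log(9))       # Case 1
assert math.isclose(log_odds_cost(probs, [0, 3]),
                    math.log(9) + math.log(1.5))                  # Case 3
assert log_odds_cost(probs, []) == 0                              # Case 4
```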
@@ -63,45 +73,41 @@ def shared_test_get_observables_from_errors(decoder_class, config_class):
     '''
     dem = stim.DetectorErrorModel(dem_string)
     config = config_class(dem)
+    config.merge_errors = False
     decoder = decoder_class(config)
-
     num_observables = dem.num_observables
 
     # Case 1: A single error mechanism that flips one observable (L0).
     errors1 = [0]
-    expected1 = np.array([True, False], dtype=bool)
-    predicted1_list = decoder.get_observables_from_errors(errors1)
-    assert isinstance(predicted1_list, list)
-    predicted1_arr = np.array(predicted1_list, dtype=bool)
-    assert predicted1_arr.shape[0] == num_observables
-    assert np.array_equal(predicted1_arr, expected1)
+    predicted_list1 = decoder.get_observables_from_errors(errors1)
+    assert isinstance(predicted_list1, list)
+    predicted_array1 = np.array(predicted_list1, dtype=bool)
+    assert predicted_array1.shape[0] == num_observables
+    assert np.array_equal(predicted_array1, np.array([True, False], dtype=bool))
 
     # Case 2: A single error mechanism that flips multiple observables (L0, L1).
     errors2 = [3]
-    expected2 = np.array([True, True], dtype=bool)
-    predicted2_list = decoder.get_observables_from_errors(errors2)
-    assert isinstance(predicted2_list, list)
-    predicted2_arr = np.array(predicted2_list, dtype=bool)
-    assert predicted2_arr.shape[0] == num_observables
-    assert np.array_equal(predicted2_arr, expected2)
+    predicted_list2 = decoder.get_observables_from_errors(errors2)
+    assert isinstance(predicted_list2, list)
+    predicted_array2 = np.array(predicted_list2, dtype=bool)
+    assert predicted_array2.shape[0] == num_observables
+    assert np.array_equal(predicted_array2, np.array([True, True], dtype=bool))
 
     # Case 3: Multiple error mechanisms whose effects cancel out (L0 from error 0, L0 from error 3).
     errors3 = [0, 3]
-    expected3 = np.array([False, True], dtype=bool)
-    predicted3_list = decoder.get_observables_from_errors(errors3)
-    assert isinstance(predicted3_list, list)
-    predicted3_arr = np.array(predicted3_list, dtype=bool)
-    assert predicted3_arr.shape[0] == num_observables
-    assert np.array_equal(predicted3_arr, expected3)
+    predicted_list3 = decoder.get_observables_from_errors(errors3)
+    assert isinstance(predicted_list3, list)
+    predicted_array3 = np.array(predicted_list3, dtype=bool)
+    assert predicted_array3.shape[0] == num_observables
+    assert np.array_equal(predicted_array3, np.array([False, True], dtype=bool))
 
     # Case 4: No errors.
     errors4 = []
-    expected4 = np.zeros(num_observables, dtype=bool)
-    predicted4_list = decoder.get_observables_from_errors(errors4)
-    assert isinstance(predicted4_list, list)
-    predicted4_arr = np.array(predicted4_list, dtype=bool)
-    assert predicted4_arr.shape[0] == num_observables
-    assert np.array_equal(predicted4_arr, expected4)
+    predicted_list4 = decoder.get_observables_from_errors(errors4)
+    assert isinstance(predicted_list4, list)
+    predicted_array4 = np.array(predicted_list4, dtype=bool)
+    assert predicted_array4.shape[0] == num_observables
+    assert np.array_equal(predicted_array4, np.zeros(num_observables, dtype=bool))
 
 
 def shared_test_decoder_predicts_various_observable_flips(decoder_class, config_class):
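
The expected arrays in the observables test can be reproduced by XOR-accumulating the observables each chosen error flips, which is why errors 0 and 3 cancel on L0 in Case 3. A small sketch of that accumulation; the mapping dict and helper name below are illustrative assumptions about the test DEM (error 0 flips L0, error 3 flips L0 and L1), not part of the library:

```python
import numpy as np

# Assumed observable flips per error mechanism in the test DEM (illustration only).
ERROR_TO_OBSERVABLES = {0: [0], 1: [1], 2: [], 3: [0, 1]}

def observables_from_errors(errors, num_observables=2):
    flips = np.zeros(num_observables, dtype=bool)
    for error in errors:
        for observable in ERROR_TO_OBSERVABLES[error]:
            flips[observable] ^= True  # two flips of the same observable cancel
    return flips

# Case 3: L0 is flipped twice and cancels, L1 remains flipped.
assert np.array_equal(observables_from_errors([0, 3]), np.array([False, True]))
```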
@@ -118,8 +124,6 @@ def shared_test_decoder_predicts_various_observable_flips(decoder_class, config_
     config = config_class(dem)
     decoder = decoder_class(config)
 
-    if hasattr(decoder, 'init_ilp'):
-        decoder.init_ilp()
     syndrome = np.zeros(dem.num_detectors, dtype=bool)
     syndrome[0] = True
     predicted_logical_flips_array = decoder.decode(syndrome)
@@ -142,8 +146,6 @@ def shared_test_decode_from_detection_events(decoder_class, config_class):
     dem = stim.DetectorErrorModel(dem_string)
     config = config_class(dem)
     decoder = decoder_class(config)
-    if hasattr(decoder, 'init_ilp'):
-        decoder.init_ilp()
 
     # Case 1: Detections corresponding to the D0 D1 L0 error
     detections1 = [0, 1]
@@ -177,8 +179,6 @@ def shared_test_decode(decoder_class, config_class):
     dem = stim.DetectorErrorModel(dem_string)
     config = config_class(dem)
     decoder = decoder_class(config)
-    if hasattr(decoder, 'init_ilp'):
-        decoder.init_ilp()
 
     syndrome1 = np.array([True, True], dtype=bool)
     predicted1 = decoder.decode(syndrome1)
@@ -205,8 +205,6 @@ def shared_test_decode_complex_dem(decoder_class, config_class):
     dem = stim.DetectorErrorModel(dem_string)
     config = config_class(dem)
     decoder = decoder_class(config)
-    if hasattr(decoder, 'init_ilp'):
-        decoder.init_ilp()
 
     syndrome = np.array([True, True, False], dtype=bool)
     predicted = decoder.decode(syndrome)
@@ -225,8 +223,6 @@ def shared_test_decode_batch_with_invalid_dimensions(decoder_class, config_class
     dem = stim.DetectorErrorModel(dem_string)
     config = config_class(dem)
     decoder = decoder_class(config)
-    if hasattr(decoder, 'init_ilp'):
-        decoder.init_ilp()
     invalid_syndrome = np.array([True, True], dtype=bool)
     with pytest.raises(RuntimeError, match="Input syndromes must be a 2D NumPy array."):
         decoder.decode_batch(invalid_syndrome)
@@ -243,14 +239,12 @@ def shared_test_decode_batch(decoder_class, config_class):
     dem = stim.DetectorErrorModel(dem_string)
     config = config_class(dem)
     decoder = decoder_class(config)
-    if hasattr(decoder, 'init_ilp'):
-        decoder.init_ilp()
-
     syndromes = np.array([
         [True, True],
         [True, False],
         [False, True],
     ], dtype=bool)
+
     predictions = decoder.decode_batch(syndromes)
     assert isinstance(predictions, np.ndarray)
     assert predictions.dtype.type == np.bool_
@@ -276,9 +270,6 @@ def shared_test_decode_batch_with_complex_model(decoder_class, config_class):
     dem = stim.DetectorErrorModel(dem_string)
     config = config_class(dem)
     decoder = decoder_class(config)
-    if hasattr(decoder, 'init_ilp'):
-        decoder.init_ilp()
-
     batch_syndromes = np.array([
         [True, True, False, False],
         [True, False, True, False],
@@ -298,3 +289,70 @@ def shared_test_decode_batch_with_complex_model(decoder_class, config_class):
         [False, False, False],
     ], dtype=bool)
     assert np.array_equal(predictions, expected_predictions)
+
+
+
+def shared_test_merge_errors_affects_cost(decoder_class, config_class):
+    """
+    Test that the error's cost changes based on the 'merge_errors' setting.
+    """
+    dem = stim.DetectorErrorModel(
+        """
+        error(0.1) D0
+        error(0.01) D0
+        """
+    )
+    detections = [0]
+
+    config_no_merge = config_class(dem, merge_errors=False)
+    decoder_no_merge = decoder_class(config_no_merge)
+    predicted_errors_no_merge = decoder_no_merge.decode_to_errors(detections)
+    cost_no_merge = decoder_no_merge.cost_from_errors(decoder_no_merge.predicted_errors_buffer)
+
+    config_merge = config_class(dem, merge_errors=True)
+    decoder_merge = decoder_class(config_merge)
+    predicted_errors_merge = decoder_merge.decode_to_errors(detections)
+    cost_merge = decoder_merge.cost_from_errors(decoder_merge.predicted_errors_buffer)
+
+    p_merged = 0.1 * (1 - 0.01) + 0.01 * (1 - 0.1)
+    expected_cost_no_merge = math.log((1 - 0.1) / 0.1)
+    expected_cost_merge = math.log((1 - p_merged) / p_merged)
+
+    assert predicted_errors_no_merge == predicted_errors_merge
+    assert cost_no_merge == pytest.approx(expected_cost_no_merge)
+    assert cost_merge == pytest.approx(expected_cost_merge)
+    assert cost_no_merge != cost_merge
+
+def shared_test_decode_with_mismatched_syndrome_size(decoder_class, config_class):
+    """
+    Tests that `decode` raises an error when the input syndrome's length does
+    not match the number of detectors in the DEM.
+    """
+    dem_string = f'''
+        error(0.1) D0 D1
+    '''
+    dem = stim.DetectorErrorModel(dem_string)
+    config = config_class(dem)
+    decoder = decoder_class(config)
+
+    # Syndrome has 1 detector, but DEM has 2
+    invalid_syndrome = np.array([True], dtype=bool)
+    with pytest.raises(ValueError, match=r"Syndrome array size \(1\) does not match the number of detectors in the decoder \(2\)\."):
+        decoder.decode(invalid_syndrome)
+
+def shared_test_decode_batch_with_mismatched_syndrome_size(decoder_class, config_class):
+    """
+    Tests that `decode_batch` raises an error when the input syndromes' width
+    does not match the number of detectors in the DEM.
+    """
+    dem_string = f'''
+        error(0.1) D0 D1
+    '''
+    dem = stim.DetectorErrorModel(dem_string)
+    config = config_class(dem)
+    decoder = decoder_class(config)
+
+    # Syndrome batch has 1 column, but DEM has 2
+    invalid_syndromes = np.array([[True], [False]], dtype=bool)
+    with pytest.raises(ValueError, match=r"The number of detectors in the input array \(1\) does not match the number of detectors in the decoder \(2\)."):
+        decoder.decode_batch(invalid_syndromes)
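
The `p_merged` expression in `shared_test_merge_errors_affects_cost` is the usual combination rule for two independent error mechanisms with identical detector and observable effects: the merged mechanism fires when exactly one of the two fires. A short numeric sketch of why the merged mechanism is slightly cheaper here (variable names are illustrative):

```python
import math

p1, p2 = 0.1, 0.01
# Probability that exactly one of the two independent mechanisms occurs.
p_merged = p1 * (1 - p2) + p2 * (1 - p1)  # 0.108

cost_no_merge = math.log((1 - p1) / p1)            # ~2.197: cheapest single mechanism
cost_merge = math.log((1 - p_merged) / p_merged)   # ~2.112: merged mechanism

assert math.isclose(p_merged, 0.108)
assert cost_merge < cost_no_merge
```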