diff --git a/src/py/utils_test.py b/src/py/utils_test.py index 304ddf3..3cbec85 100644 --- a/src/py/utils_test.py +++ b/src/py/utils_test.py @@ -27,6 +27,10 @@ """ ) +_DETECTOR_ERROR_MODEL_10 = stim.DetectorErrorModel( + "\n".join(f"error(0.1) D{i}" for i in range(10)) +) + def test_module_has_global_constants(): assert tesseract_decoder.utils.EPSILON <= 1e-7 @@ -44,7 +48,7 @@ def test_build_detector_graph(): ] -def test_build_det_orders(): +def test_build_det_orders_bfs(): assert tesseract_decoder.utils.build_det_orders( _DETECTOR_ERROR_MODEL, num_det_orders=1, seed=0 ) == [[0, 1]] @@ -61,12 +65,14 @@ def test_build_det_orders_coordinate(): def test_build_det_orders_index(): res = tesseract_decoder.utils.build_det_orders( - _DETECTOR_ERROR_MODEL, + _DETECTOR_ERROR_MODEL_10, num_det_orders=1, method=tesseract_decoder.utils.DetOrder.DetIndex, seed=0, ) - assert res == [[0, 1]] or res == [[1, 0]] + expected_asc = list(range(10)) + expected_desc = list(range(9, -1, -1)) + assert res == [expected_asc] or res == [expected_desc] def test_get_errors_from_dem(): diff --git a/src/tesseract.cc b/src/tesseract.cc index 97c8f4d..ca44133 100644 --- a/src/tesseract.cc +++ b/src/tesseract.cc @@ -201,8 +201,10 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections) } if (config.beam_climbing) { - for (int beam = config.det_beam; beam >= 0; --beam) { - size_t detector_order = beam % config.det_orders.size(); + int beam = 0; + int detector_order = 0; + for (int trial = 0; trial < std::max(config.det_beam + 1, int(config.det_orders.size())); + ++trial) { decode_to_errors(detections, detector_order, beam); double local_cost = cost_from_errors(predicted_errors_buffer); if (!low_confidence_flag && local_cost < best_cost) { @@ -215,6 +217,10 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections) << " and obs_mask " << get_flipped_observables(predicted_errors_buffer) << ". Best cost so far: " << best_cost << std::endl; } + beam += 1; + detector_order += 1; + beam %= (config.det_beam + 1); + detector_order %= config.det_orders.size(); } } else { for (size_t detector_order = 0; detector_order < config.det_orders.size(); ++detector_order) { diff --git a/src/utils.cc b/src/utils.cc index 261a6aa..4100a94 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -82,104 +82,124 @@ std::vector> build_detector_graph(const stim::DetectorErrorM return neighbors; } -std::vector> build_det_orders(const stim::DetectorErrorModel& dem, - size_t num_det_orders, DetOrder method, - uint64_t seed) { +static std::vector> build_det_orders_bfs(const stim::DetectorErrorModel& dem, + size_t num_det_orders, + std::mt19937_64& rng) { std::vector> det_orders(num_det_orders); - std::mt19937_64 rng(seed); - std::normal_distribution dist(0, 1); - - auto detector_coords = get_detector_coords(dem); - - if (method == DetOrder::DetBFS) { - auto graph = build_detector_graph(dem); - std::uniform_int_distribution dist_det(0, graph.size() - 1); - for (size_t det_order = 0; det_order < num_det_orders; ++det_order) { - std::vector perm; - perm.reserve(graph.size()); - std::vector visited(graph.size(), false); - std::queue q; - size_t start = dist_det(rng); - while (perm.size() < graph.size()) { - if (!visited[start]) { - visited[start] = true; - q.push(start); - perm.push_back(start); - } - while (!q.empty()) { - size_t cur = q.front(); - q.pop(); - auto neigh = graph[cur]; - std::shuffle(neigh.begin(), neigh.end(), rng); - for (size_t n : neigh) { - if (!visited[n]) { - visited[n] = true; - q.push(n); - perm.push_back(n); - } + auto graph = build_detector_graph(dem); + std::uniform_int_distribution dist_det(0, graph.size() - 1); + for (size_t det_order = 0; det_order < num_det_orders; ++det_order) { + std::vector perm; + perm.reserve(graph.size()); + std::vector visited(graph.size(), false); + std::queue q; + size_t start = dist_det(rng); + while (perm.size() < graph.size()) { + if (!visited[start]) { + visited[start] = true; + q.push(start); + perm.push_back(start); + } + while (!q.empty()) { + size_t cur = q.front(); + q.pop(); + auto neigh = graph[cur]; + std::shuffle(neigh.begin(), neigh.end(), rng); + for (size_t n : neigh) { + if (!visited[n]) { + visited[n] = true; + q.push(n); + perm.push_back(n); } } - if (perm.size() < graph.size()) { - do { - start = dist_det(rng); - } while (visited[start]); - } } - std::vector inv_perm(graph.size()); - for (size_t i = 0; i < perm.size(); ++i) { - inv_perm[perm[i]] = i; + if (perm.size() < graph.size()) { + do { + start = dist_det(rng); + } while (visited[start]); } - det_orders[det_order] = inv_perm; } - } else if (method == DetOrder::DetCoordinate) { - std::vector inner_products(dem.count_detectors()); - if (!detector_coords.size() || !detector_coords.at(0).size()) { - for (size_t det_order = 0; det_order < num_det_orders; ++det_order) { - det_orders[det_order].resize(dem.count_detectors()); - std::iota(det_orders[det_order].begin(), det_orders[det_order].end(), 0); - } - } else { - for (size_t det_order = 0; det_order < num_det_orders; ++det_order) { - std::vector orientation_vector; - for (size_t i = 0; i < detector_coords.at(0).size(); ++i) { - orientation_vector.push_back(dist(rng)); - } + std::vector inv_perm(graph.size()); + for (size_t i = 0; i < perm.size(); ++i) { + inv_perm[perm[i]] = i; + } + det_orders[det_order] = inv_perm; + } + return det_orders; +} - for (size_t i = 0; i < detector_coords.size(); ++i) { - inner_products[i] = 0; - for (size_t j = 0; j < orientation_vector.size(); ++j) { - inner_products[i] += detector_coords[i][j] * orientation_vector[j]; - } - } - std::vector perm(dem.count_detectors()); - std::iota(perm.begin(), perm.end(), 0); - std::sort(perm.begin(), perm.end(), [&](const size_t& i, const size_t& j) { - return inner_products[i] > inner_products[j]; - }); - std::vector inv_perm(dem.count_detectors()); - for (size_t i = 0; i < perm.size(); ++i) { - inv_perm[perm[i]] = i; - } - det_orders[det_order] = inv_perm; +static std::vector> build_det_orders_coordinate( + const stim::DetectorErrorModel& dem, size_t num_det_orders, std::mt19937_64& rng) { + std::vector> det_orders(num_det_orders); + auto detector_coords = get_detector_coords(dem); + std::vector inner_products(dem.count_detectors()); + std::normal_distribution dist(0, 1); + if (detector_coords.empty() || detector_coords.at(0).empty()) { + for (size_t det_order = 0; det_order < num_det_orders; ++det_order) { + det_orders[det_order].resize(dem.count_detectors()); + std::iota(det_orders[det_order].begin(), det_orders[det_order].end(), 0); + } + return det_orders; + } + for (size_t det_order = 0; det_order < num_det_orders; ++det_order) { + std::vector orientation_vector; + for (size_t i = 0; i < detector_coords.at(0).size(); ++i) { + orientation_vector.push_back(dist(rng)); + } + for (size_t i = 0; i < detector_coords.size(); ++i) { + inner_products[i] = 0; + for (size_t j = 0; j < orientation_vector.size(); ++j) { + inner_products[i] += detector_coords[i][j] * orientation_vector[j]; } } - } else if (method == DetOrder::DetIndex) { - std::uniform_int_distribution dist_bool(0, 1); - size_t n = dem.count_detectors(); - for (size_t det_order = 0; det_order < num_det_orders; ++det_order) { - det_orders[det_order].resize(n); - if (dist_bool(rng)) { - for (size_t i = 0; i < n; ++i) { - det_orders[det_order][i] = n - 1 - i; - } - } else { - std::iota(det_orders[det_order].begin(), det_orders[det_order].end(), 0); + std::vector perm(dem.count_detectors()); + std::iota(perm.begin(), perm.end(), 0); + std::sort(perm.begin(), perm.end(), [&](const size_t& i, const size_t& j) { + return inner_products[i] > inner_products[j]; + }); + std::vector inv_perm(dem.count_detectors()); + for (size_t i = 0; i < perm.size(); ++i) { + inv_perm[perm[i]] = i; + } + det_orders[det_order] = inv_perm; + } + return det_orders; +} + +static std::vector> build_det_orders_index(const stim::DetectorErrorModel& dem, + size_t num_det_orders, + std::mt19937_64& rng) { + std::vector> det_orders(num_det_orders); + std::uniform_int_distribution dist_bool(0, 1); + size_t n = dem.count_detectors(); + for (size_t det_order = 0; det_order < num_det_orders; ++det_order) { + det_orders[det_order].resize(n); + if (dist_bool(rng)) { + for (size_t i = 0; i < n; ++i) { + det_orders[det_order][i] = n - 1 - i; } + } else { + std::iota(det_orders[det_order].begin(), det_orders[det_order].end(), 0); } } return det_orders; } +std::vector> build_det_orders(const stim::DetectorErrorModel& dem, + size_t num_det_orders, DetOrder method, + uint64_t seed) { + std::mt19937_64 rng(seed); + switch (method) { + case DetOrder::DetBFS: + return build_det_orders_bfs(dem, num_det_orders, rng); + case DetOrder::DetCoordinate: + return build_det_orders_coordinate(dem, num_det_orders, rng); + case DetOrder::DetIndex: + return build_det_orders_index(dem, num_det_orders, rng); + } + throw std::invalid_argument("Unknown det order method"); +} + bool sampling_from_dem(uint64_t seed, size_t num_shots, stim::DetectorErrorModel dem, std::vector& shots) { stim::DemSampler sampler(dem, std::mt19937_64{seed}, num_shots);