From 78e5ec8dc7df71226193adc513fe3bdaf605931b Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Sun, 10 Aug 2025 22:46:23 -0400 Subject: [PATCH 01/11] Add visualization library to CMake build --- CMakeLists.txt | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index bc3111bd..33420c7c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -73,10 +73,15 @@ target_include_directories(utils PUBLIC ${TESSERACT_SRC_DIR}) target_compile_options(utils PRIVATE ${OPT_COPTS}) target_link_libraries(utils PUBLIC common libstim Threads::Threads) +add_library(visualization ${TESSERACT_SRC_DIR}/visualization.cc ${TESSERACT_SRC_DIR}/visualization.h) +target_include_directories(visualization PUBLIC ${TESSERACT_SRC_DIR}) +target_compile_options(visualization PRIVATE ${OPT_COPTS}) +target_link_libraries(visualization PUBLIC common boost_headers) + add_library(tesseract_lib ${TESSERACT_SRC_DIR}/tesseract.cc ${TESSERACT_SRC_DIR}/tesseract.h) target_include_directories(tesseract_lib PUBLIC ${TESSERACT_SRC_DIR}) target_compile_options(tesseract_lib PRIVATE ${OPT_COPTS}) -target_link_libraries(tesseract_lib PUBLIC utils boost_headers) +target_link_libraries(tesseract_lib PUBLIC utils boost_headers visualization) add_library(simplex ${TESSERACT_SRC_DIR}/simplex.cc ${TESSERACT_SRC_DIR}/simplex.h) target_include_directories(simplex PUBLIC ${TESSERACT_SRC_DIR}) From 5c211bc872714ed03b2333c16d6c49a153c39d24 Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Wed, 13 Aug 2025 11:27:13 -0400 Subject: [PATCH 02/11] Fix CMake Python module placement and add agent instructions --- .gitignore | 3 +++ AGENTS.md | 5 +++++ CMakeLists.txt | 7 +++++++ 3 files changed, 15 insertions(+) create mode 100644 AGENTS.md diff --git a/.gitignore b/.gitignore index 5b5d20b4..66d8f700 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,6 @@ eclipse-*bin/ /.sass-cache # User-specific .bazelrc user.bazelrc + +# Ignore python extension module produced by CMake. +src/tesseract_decoder*.so diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..ea68e20e --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,5 @@ +# Agent Instructions + +- Use the **CMake** build system when interacting with this repository. Humans use Bazel. +- A bug in some LLM coding environments makes Bazel difficult to use, so agents should rely on CMake. +- Keep both the CMake and Bazel builds working at all times. diff --git a/CMakeLists.txt b/CMakeLists.txt index 33420c7c..10e5ea69 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -103,4 +103,11 @@ pybind11_add_module(tesseract_decoder MODULE ${TESSERACT_SRC_DIR}/tesseract.pybi target_compile_options(tesseract_decoder PRIVATE ${OPT_COPTS}) target_include_directories(tesseract_decoder PRIVATE ${TESSERACT_SRC_DIR}) target_link_libraries(tesseract_decoder PRIVATE common utils simplex tesseract_lib) +set_target_properties(tesseract_decoder PROPERTIES + LIBRARY_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}/src + LIBRARY_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/src + LIBRARY_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/src + LIBRARY_OUTPUT_DIRECTORY_MINSIZEREL ${PROJECT_SOURCE_DIR}/src + LIBRARY_OUTPUT_DIRECTORY_RELWITHDEBINFO ${PROJECT_SOURCE_DIR}/src +) From 7c10268eef32c22ce9189b195b84e83dc87e69ef Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Wed, 20 Aug 2025 15:09:33 -0700 Subject: [PATCH 03/11] Allow decode_to_errors to accept bitstring --- README.md | 6 +-- src/py/README.md | 25 +++++------ src/py/shared_decoding_tests.py | 6 +-- src/py/simplex_test.py | 3 +- src/py/tesseract_test.py | 7 +++- src/simplex.pybind.h | 34 ++++++++++++--- src/tesseract.pybind.h | 73 ++++++++++++++++++++++++++------- 7 files changed, 113 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index 30230f7f..2e19e436 100644 --- a/README.md +++ b/README.md @@ -190,15 +190,15 @@ config = tesseract.TesseractConfig(dem=dem, det_beam=50) # 3. Create a decoder instance decoder = config.compile_decoder() -# 4. Simulate detection events -syndrome = [0, 1, 1] +# 4. Simulate detector outcomes +syndrome = np.array([0, 1, 1], dtype=bool) # 5a. Decode to observables flipped_observables = decoder.decode(syndrome) print(f"Flipped observables: {flipped_observables}") # 5b. Alternatively, decode to errors -decoder.decode_to_errors(np.where(syndrome)[0]) +decoder.decode_to_errors(syndrome) predicted_errors = decoder.predicted_errors_buffer # Indices of predicted errors print(f"Predicted errors indices: {predicted_errors}") diff --git a/src/py/README.md b/src/py/README.md index 79c0c396..625e41bd 100644 --- a/src/py/README.md +++ b/src/py/README.md @@ -64,28 +64,28 @@ print(f"Custom configuration detection penalty: {config2.det_beam}") #### Class `tesseract.TesseractDecoder` This is the main class that implements the Tesseract decoding logic. * `TesseractDecoder(config: tesseract.TesseractConfig)` -* `decode_to_errors(detections: list[int])` -* `decode_to_errors(detections: list[int], det_order: int, det_beam: int)` +* `decode_to_errors(syndrome: np.ndarray)` +* `decode_to_errors(syndrome: np.ndarray, det_order: int, det_beam: int)` * `get_observables_from_errors(predicted_errors: list[int]) -> list[bool]` * `cost_from_errors(predicted_errors: list[int]) -> float` -* `decode(detections: list[int]) -> list[bool]` +* `decode(syndrome: np.ndarray) -> np.ndarray` Explanation of each method: -#### `decode_to_errors(detections: list[int])` +#### `decode_to_errors(syndrome: np.ndarray)` Decodes a single measurement shot to predict a list of errors. -* **Parameters:** `detections` is a list of integers that represent the indices of the detectors that have fired in a single shot. +* **Parameters:** `syndrome` is a 1D NumPy array of booleans representing the detector outcomes for a single shot. * **Returns:** A list of integers, where each integer is the index of a predicted error. -#### `decode_to_errors(detections: list[int], det_order: int, det_beam: int)` +#### `decode_to_errors(syndrome: np.ndarray, det_order: int, det_beam: int)` An overloaded version of the `decode_to_errors` method that allows for a different decoding strategy. * **Parameters:** - * `detections` is a list of integers representing the indices of the fired detectors. + * `syndrome` is a 1D NumPy array of booleans representing the detector outcomes for a single shot. * `det_order` is an integer that specifies a different ordering of detectors to use for the decoding. @@ -219,10 +219,10 @@ print(f"Configuration verbose enabled: {config.verbose}") This is the main class for performing decoding using the Simplex algorithm. * `SimplexDecoder(config: simplex.SimplexConfig)` * `init_ilp()` -* `decode_to_errors(detections: list[int])` +* `decode_to_errors(syndrome: np.ndarray)` * `get_observables_from_errors(predicted_errors: list[int]) -> list[bool]` * `cost_from_errors(predicted_errors: list[int]) -> float` -* `decode(detections: list[int]) -> list[bool]` +* `decode(syndrome: np.ndarray) -> np.ndarray` **Example Usage**: @@ -230,6 +230,7 @@ This is the main class for performing decoding using the Simplex algorithm. import tesseract_decoder.simplex as simplex import stim import tesseract_decoder.common as common +import numpy as np # Create a DEM and a configuration dem = stim.DetectorErrorModel(""" @@ -245,9 +246,9 @@ decoder = simplex.SimplexDecoder(config) decoder.init_ilp() # Decode a shot where detector D1 fired -detections = [1] -flipped_observables = decoder.decode(detections) -print(f"Flipped observables for detections {detections}: {flipped_observables}") +syndrome = np.array([0, 1], dtype=bool) +flipped_observables = decoder.decode(syndrome) +print(f"Flipped observables for syndrome {syndrome.tolist()}: {flipped_observables}") # Access predicted errors predicted_error_indices = decoder.predicted_errors_buffer diff --git a/src/py/shared_decoding_tests.py b/src/py/shared_decoding_tests.py index 4500b141..82259d4b 100644 --- a/src/py/shared_decoding_tests.py +++ b/src/py/shared_decoding_tests.py @@ -302,16 +302,16 @@ def shared_test_merge_errors_affects_cost(decoder_class, config_class): error(0.01) D0 """ ) - detections = [0] + syndrome = np.array([True], dtype=bool) config_no_merge = config_class(dem, merge_errors=False) decoder_no_merge = decoder_class(config_no_merge) - predicted_errors_no_merge = decoder_no_merge.decode_to_errors(detections) + predicted_errors_no_merge = decoder_no_merge.decode_to_errors(syndrome) cost_no_merge = decoder_no_merge.cost_from_errors(decoder_no_merge.predicted_errors_buffer) config_merge = config_class(dem, merge_errors=True) decoder_merge = decoder_class(config_merge) - predicted_errors_merge = decoder_merge.decode_to_errors(detections) + predicted_errors_merge = decoder_merge.decode_to_errors(syndrome) cost_merge = decoder_merge.cost_from_errors(decoder_merge.predicted_errors_buffer) p_merged = 0.1 * (1 - 0.01) + 0.01 * (1 - 0.1) diff --git a/src/py/simplex_test.py b/src/py/simplex_test.py index 3a228d9c..752f9e8f 100644 --- a/src/py/simplex_test.py +++ b/src/py/simplex_test.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest +import numpy as np import stim from src import tesseract_decoder @@ -56,7 +57,7 @@ def test_create_simplex_decoder(): decoder = tesseract_decoder.simplex.SimplexDecoder( tesseract_decoder.simplex.SimplexConfig(_DETECTOR_ERROR_MODEL, window_length=5) ) - decoder.decode_to_errors([1]) + decoder.decode_to_errors(np.array([False, True], dtype=bool)) assert decoder.get_observables_from_errors([1]) == [] assert decoder.cost_from_errors([2]) == pytest.approx(1.0986123) diff --git a/src/py/tesseract_test.py b/src/py/tesseract_test.py index 5df3e329..b7b21835 100644 --- a/src/py/tesseract_test.py +++ b/src/py/tesseract_test.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest +import numpy as np import stim from src import tesseract_decoder @@ -60,8 +61,10 @@ def test_create_tesseract_config(): def test_create_tesseract_decoder(): config = tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL) decoder = tesseract_decoder.tesseract.TesseractDecoder(config) - decoder.decode_to_errors([0]) - decoder.decode_to_errors(detections=[0], det_order=0, det_beam=0) + decoder.decode_to_errors(np.array([True, False], dtype=bool)) + decoder.decode_to_errors( + syndrome=np.array([True, False], dtype=bool), det_order=0, det_beam=0 + ) assert decoder.get_observables_from_errors([1]) == [] assert decoder.cost_from_errors([1]) == pytest.approx(0.5108256237659907) diff --git a/src/simplex.pybind.h b/src/simplex.pybind.h index 79c8d59d..9c0e6b96 100644 --- a/src/simplex.pybind.h +++ b/src/simplex.pybind.h @@ -20,6 +20,7 @@ #include #include #include +#include #include "common.h" #include "simplex.h" @@ -140,20 +141,43 @@ void add_simplex_module(py::module& root) { This method must be called before decoding. )pbdoc") - .def("decode_to_errors", &SimplexDecoder::decode_to_errors, py::arg("detections"), - py::call_guard(), R"pbdoc( + .def( + "decode_to_errors", + [](SimplexDecoder& self, const py::array_t& syndrome) { + if ((size_t)syndrome.size() != self.num_detectors) { + std::ostringstream msg; + msg << "Syndrome array size (" << syndrome.size() + << ") does not match the number of detectors in the decoder (" + << self.num_detectors << ")."; + throw std::invalid_argument(msg.str()); + } + + std::vector detections; + auto syndrome_unchecked = syndrome.unchecked<1>(); + for (size_t i = 0; i < (size_t)syndrome_unchecked.size(); ++i) { + if (syndrome_unchecked(i)) { + detections.push_back(i); + } + } + self.decode_to_errors(detections); + return self.predicted_errors_buffer; + }, + py::arg("syndrome"), + py::call_guard(), + R"pbdoc( Decodes a single shot to a list of error indices. Parameters ---------- - detections : list[int] - A list of indices of the detectors that have fired. + syndrome : np.ndarray + A 1D NumPy array of booleans representing the detector outcomes for a single shot. + The length of the array should match the number of detectors in the DEM. Returns ------- list[int] A list of predicted error indices. - )pbdoc") + )pbdoc") .def( "get_observables_from_errors", [](SimplexDecoder& self, const std::vector& predicted_errors) { diff --git a/src/tesseract.pybind.h b/src/tesseract.pybind.h index 267aa115..9c09e108 100644 --- a/src/tesseract.pybind.h +++ b/src/tesseract.pybind.h @@ -20,6 +20,7 @@ #include #include #include +#include #include "stim_utils.pybind.h" #include "tesseract.h" @@ -170,33 +171,75 @@ void add_tesseract_module(py::module& root) { config : TesseractConfig The configuration object for the decoder. )pbdoc") - .def("decode_to_errors", - py::overload_cast&>(&TesseractDecoder::decode_to_errors), - py::arg("detections"), - py::call_guard(), R"pbdoc( + .def( + "decode_to_errors", + [](TesseractDecoder& self, const py::array_t& syndrome) { + if ((size_t)syndrome.size() != self.num_detectors) { + std::ostringstream msg; + msg << "Syndrome array size (" << syndrome.size() + << ") does not match the number of detectors in the decoder (" + << self.num_detectors << ")."; + throw std::invalid_argument(msg.str()); + } + + std::vector detections; + auto syndrome_unchecked = syndrome.unchecked<1>(); + for (size_t i = 0; i < (size_t)syndrome_unchecked.size(); ++i) { + if (syndrome_unchecked(i)) { + detections.push_back(i); + } + } + self.decode_to_errors(detections); + return self.predicted_errors_buffer; + }, + py::arg("syndrome"), + py::call_guard(), + R"pbdoc( Decodes a single shot to a list of error indices. Parameters ---------- - detections : list[int] - A list of indices of the detectors that have fired. + syndrome : np.ndarray + A 1D NumPy array of booleans representing the detector outcomes for a single shot. + The length of the array should match the number of detectors in the DEM. Returns ------- list[int] A list of predicted error indices. - )pbdoc") - .def("decode_to_errors", - py::overload_cast&, size_t, size_t>( - &TesseractDecoder::decode_to_errors), - py::arg("detections"), py::arg("det_order"), py::arg("det_beam"), - py::call_guard(), R"pbdoc( + )pbdoc") + .def( + "decode_to_errors", + [](TesseractDecoder& self, const py::array_t& syndrome, size_t det_order, + size_t det_beam) { + if ((size_t)syndrome.size() != self.num_detectors) { + std::ostringstream msg; + msg << "Syndrome array size (" << syndrome.size() + << ") does not match the number of detectors in the decoder (" + << self.num_detectors << ")."; + throw std::invalid_argument(msg.str()); + } + + std::vector detections; + auto syndrome_unchecked = syndrome.unchecked<1>(); + for (size_t i = 0; i < (size_t)syndrome_unchecked.size(); ++i) { + if (syndrome_unchecked(i)) { + detections.push_back(i); + } + } + self.decode_to_errors(detections, det_order, det_beam); + return self.predicted_errors_buffer; + }, + py::arg("syndrome"), py::arg("det_order"), py::arg("det_beam"), + py::call_guard(), + R"pbdoc( Decodes a single shot using a specific detector ordering and beam size. Parameters ---------- - detections : list[int] - A list of indices of the detectors that have fired. + syndrome : np.ndarray + A 1D NumPy array of booleans representing the detector outcomes for a single shot. + The length of the array should match the number of detectors in the DEM. det_order : int The index of the detector ordering to use. det_beam : int @@ -206,7 +249,7 @@ void add_tesseract_module(py::module& root) { ------- list[int] A list of predicted error indices. - )pbdoc") + )pbdoc") .def( "get_observables_from_errors", [](TesseractDecoder& self, const std::vector& predicted_errors) { From c34dd80aec9f031bec0dcc47e6486b481f62d7b2 Mon Sep 17 00:00:00 2001 From: noajshu Date: Wed, 20 Aug 2025 22:18:00 +0000 Subject: [PATCH 04/11] clang-format --- src/simplex.pybind.h | 1 + src/tesseract.pybind.h | 1 + 2 files changed, 2 insertions(+) diff --git a/src/simplex.pybind.h b/src/simplex.pybind.h index 9c0e6b96..78f9a9e5 100644 --- a/src/simplex.pybind.h +++ b/src/simplex.pybind.h @@ -20,6 +20,7 @@ #include #include #include + #include #include "common.h" diff --git a/src/tesseract.pybind.h b/src/tesseract.pybind.h index 37c00ca8..32aa84d9 100644 --- a/src/tesseract.pybind.h +++ b/src/tesseract.pybind.h @@ -20,6 +20,7 @@ #include #include #include + #include #include "stim_utils.pybind.h" From 24c0d691dcd56e7e4aeea2341aac51e4fddbcfaa Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Wed, 20 Aug 2025 16:00:57 -0700 Subject: [PATCH 05/11] Update src/tesseract.pybind.h Co-authored-by: Noureldin --- src/tesseract.pybind.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/tesseract.pybind.h b/src/tesseract.pybind.h index 32aa84d9..efdd17d7 100644 --- a/src/tesseract.pybind.h +++ b/src/tesseract.pybind.h @@ -253,7 +253,10 @@ void add_tesseract_module(py::module& root) { msg << "Syndrome array size (" << syndrome.size() << ") does not match the number of detectors in the decoder (" << self.num_detectors << ")."; - throw std::invalid_argument(msg.str()); + std::string msg = "Syndrome array size (" + std:to_string(syndrome.size()) + + ") does not match the number of detectors in the decoder (" + + std::to_string(self.num_detectors) + ")." + throw std::invalid_argument(msg); } std::vector detections; From 96849b67693c5beceb913b8f7f521bafdce1dc01 Mon Sep 17 00:00:00 2001 From: noajshu Date: Wed, 20 Aug 2025 23:38:00 +0000 Subject: [PATCH 06/11] remove stringstream --- src/simplex.pybind.h | 9 ++++----- src/tesseract.pybind.h | 9 +++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/simplex.pybind.h b/src/simplex.pybind.h index 78f9a9e5..7297fab8 100644 --- a/src/simplex.pybind.h +++ b/src/simplex.pybind.h @@ -253,11 +253,10 @@ void add_simplex_module(py::module& root) { "decode", [](SimplexDecoder& self, const py::array_t& syndrome) { if ((size_t)syndrome.size() != self.num_detectors) { - std::ostringstream msg; - msg << "Syndrome array size (" << syndrome.size() - << ") does not match the number of detectors in the decoder (" - << self.num_detectors << ")."; - throw std::invalid_argument(msg.str()); + std::string msg = "Syndrome array size (" + std::to_string(syndrome.size()) + + ") does not match the number of detectors in the decoder (" + + std::to_string(self.num_detectors) + ")."; + throw std::invalid_argument(msg); } std::vector detections; diff --git a/src/tesseract.pybind.h b/src/tesseract.pybind.h index efdd17d7..1d27adf2 100644 --- a/src/tesseract.pybind.h +++ b/src/tesseract.pybind.h @@ -253,10 +253,11 @@ void add_tesseract_module(py::module& root) { msg << "Syndrome array size (" << syndrome.size() << ") does not match the number of detectors in the decoder (" << self.num_detectors << ")."; - std::string msg = "Syndrome array size (" + std:to_string(syndrome.size()) - + ") does not match the number of detectors in the decoder (" - + std::to_string(self.num_detectors) + ")." - throw std::invalid_argument(msg); + std::string msg = "Syndrome array size (" + + std + : to_string(syndrome.size()) + + ") does not match the number of detectors in the decoder (" + + std::to_string(self.num_detectors) + ")." throw std::invalid_argument(msg); } std::vector detections; From 3d486b0975025c822e3db1309d1c181b36faf4cf Mon Sep 17 00:00:00 2001 From: noajshu Date: Wed, 20 Aug 2025 23:51:20 +0000 Subject: [PATCH 07/11] more fixes --- src/simplex.pybind.h | 21 ++++++-------- src/tesseract.pybind.h | 43 +++++++++++----------------- src/tesseract_sinter_compat.pybind.h | 4 +-- 3 files changed, 28 insertions(+), 40 deletions(-) diff --git a/src/simplex.pybind.h b/src/simplex.pybind.h index 7297fab8..9d00d3c6 100644 --- a/src/simplex.pybind.h +++ b/src/simplex.pybind.h @@ -21,8 +21,6 @@ #include #include -#include - #include "common.h" #include "simplex.h" #include "stim_utils.pybind.h" @@ -146,11 +144,10 @@ void add_simplex_module(py::module& root) { "decode_to_errors", [](SimplexDecoder& self, const py::array_t& syndrome) { if ((size_t)syndrome.size() != self.num_detectors) { - std::ostringstream msg; - msg << "Syndrome array size (" << syndrome.size() - << ") does not match the number of detectors in the decoder (" - << self.num_detectors << ")."; - throw std::invalid_argument(msg.str()); + std::string msg = "Syndrome array size (" + std::to_string(syndrome.size()) + + ") does not match the number of detectors in the decoder (" + + std::to_string(self.num_detectors) + ")."; + throw std::invalid_argument(msg); } std::vector detections; @@ -311,11 +308,11 @@ void add_simplex_module(py::module& root) { size_t num_detectors = syndromes_unchecked.shape(1); if (num_detectors != self.num_detectors) { - std::ostringstream msg; - msg << "The number of detectors in the input array (" << num_detectors - << ") does not match the number of detectors in the decoder (" - << self.num_detectors << ")."; - throw std::invalid_argument(msg.str()); + std::string msg = "The number of detectors in the input array (" + + std::to_string(num_detectors) + + ") does not match the number of detectors in the decoder (" + + std::to_string(self.num_detectors) + ")."; + throw std::invalid_argument(msg); } // Allocate the result array. diff --git a/src/tesseract.pybind.h b/src/tesseract.pybind.h index 1d27adf2..b94c27f5 100644 --- a/src/tesseract.pybind.h +++ b/src/tesseract.pybind.h @@ -21,8 +21,6 @@ #include #include -#include - #include "stim_utils.pybind.h" #include "tesseract.h" @@ -249,15 +247,10 @@ void add_tesseract_module(py::module& root) { "decode_to_errors", [](TesseractDecoder& self, const py::array_t& syndrome) { if ((size_t)syndrome.size() != self.num_detectors) { - std::ostringstream msg; - msg << "Syndrome array size (" << syndrome.size() - << ") does not match the number of detectors in the decoder (" - << self.num_detectors << ")."; - std::string msg = "Syndrome array size (" + - std - : to_string(syndrome.size()) + - ") does not match the number of detectors in the decoder (" + - std::to_string(self.num_detectors) + ")." throw std::invalid_argument(msg); + std::string msg = "Syndrome array size (" + std::to_string(syndrome.size()) + + ") does not match the number of detectors in the decoder (" + + std::to_string(self.num_detectors) + ")."; + throw std::invalid_argument(msg); } std::vector detections; @@ -291,11 +284,10 @@ void add_tesseract_module(py::module& root) { [](TesseractDecoder& self, const py::array_t& syndrome, size_t det_order, size_t det_beam) { if ((size_t)syndrome.size() != self.num_detectors) { - std::ostringstream msg; - msg << "Syndrome array size (" << syndrome.size() - << ") does not match the number of detectors in the decoder (" - << self.num_detectors << ")."; - throw std::invalid_argument(msg.str()); + std::string msg = "Syndrome array size (" + std::to_string(syndrome.size()) + + ") does not match the number of detectors in the decoder (" + + std::to_string(self.num_detectors) + ")."; + throw std::invalid_argument(msg); } std::vector detections; @@ -403,11 +395,10 @@ void add_tesseract_module(py::module& root) { "decode", [](TesseractDecoder& self, const py::array_t& syndrome) { if ((size_t)syndrome.size() != self.num_detectors) { - std::ostringstream msg; - msg << "Syndrome array size (" << syndrome.size() - << ") does not match the number of detectors in the decoder (" - << self.num_detectors << ")."; - throw std::invalid_argument(msg.str()); + std::string msg = "Syndrome array size (" + std::to_string(syndrome.size()) + + ") does not match the number of detectors in the decoder (" + + std::to_string(self.num_detectors) + ")."; + throw std::invalid_argument(msg); } std::vector detections; @@ -461,11 +452,11 @@ void add_tesseract_module(py::module& root) { size_t num_detectors = syndromes_unchecked.shape(1); if (num_detectors != self.num_detectors) { - std::ostringstream msg; - msg << "The number of detectors in the input array (" << num_detectors - << ") does not match the number of detectors in the decoder (" - << self.num_detectors << ")."; - throw std::invalid_argument(msg.str()); + std::string msg = "The number of detectors in the input array (" + + std::to_string(num_detectors) + + ") does not match the number of detectors in the decoder (" + + std::to_string(self.num_detectors) + ")."; + throw std::invalid_argument(msg); } // Allocate the result array. diff --git a/src/tesseract_sinter_compat.pybind.h b/src/tesseract_sinter_compat.pybind.h index 3d15e6b2..623253d9 100644 --- a/src/tesseract_sinter_compat.pybind.h +++ b/src/tesseract_sinter_compat.pybind.h @@ -83,7 +83,7 @@ struct TesseractSinterCompiledDecoder { // Store predictions into the output buffer uint8_t* single_result_buffer = result_buffer + shot * num_observable_bytes; std::fill(single_result_buffer, single_result_buffer + num_observable_bytes, 0); - for (int obs_index : predictions) { + for (size_t obs_index : predictions) { if (obs_index >= 0 && obs_index < num_observables) { single_result_buffer[obs_index / 8] ^= (1 << (obs_index % 8)); } @@ -191,7 +191,7 @@ struct TesseractSinterDecoder { // Pack the predictions back into a bit-packed format. std::fill(single_result_data.begin(), single_result_data.end(), 0); - for (int obs_index : predictions) { + for (size_t obs_index : predictions) { if (obs_index >= 0 && obs_index < num_obs) { single_result_data[obs_index / 8] ^= (1 << (obs_index % 8)); } From 61867e29d82be2417f69635c26feda3d27800a2e Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Fri, 29 Aug 2025 13:47:19 -0700 Subject: [PATCH 08/11] Handle explicit DetIndex case --- src/py/utils_test.py | 17 ++++++++++++++-- src/tesseract.pybind.h | 4 ++-- src/tesseract_main.cc | 44 ++++++++++++++++++++++++++---------------- src/utils.cc | 19 +++++++++++++++--- src/utils.h | 5 ++++- src/utils.pybind.h | 24 +++++++++++++++-------- 6 files changed, 80 insertions(+), 33 deletions(-) diff --git a/src/py/utils_test.py b/src/py/utils_test.py index b5a07c90..304ddf3f 100644 --- a/src/py/utils_test.py +++ b/src/py/utils_test.py @@ -50,12 +50,25 @@ def test_build_det_orders(): ) == [[0, 1]] -def test_build_det_orders_no_bfs(): +def test_build_det_orders_coordinate(): assert tesseract_decoder.utils.build_det_orders( - _DETECTOR_ERROR_MODEL, num_det_orders=1, det_order_bfs=False, seed=0 + _DETECTOR_ERROR_MODEL, + num_det_orders=1, + method=tesseract_decoder.utils.DetOrder.DetCoordinate, + seed=0, ) == [[0, 1]] +def test_build_det_orders_index(): + res = tesseract_decoder.utils.build_det_orders( + _DETECTOR_ERROR_MODEL, + num_det_orders=1, + method=tesseract_decoder.utils.DetOrder.DetIndex, + seed=0, + ) + assert res == [[0, 1]] or res == [[1, 0]] + + def test_get_errors_from_dem(): expected = "Error{cost=1.945910, symptom=Symptom{detectors=[0], observables=[]}}, Error{cost=0.510826, symptom=Symptom{detectors=[0 1], observables=[]}}, Error{cost=1.098612, symptom=Symptom{detectors=[1], observables=[]}}" assert ( diff --git a/src/tesseract.pybind.h b/src/tesseract.pybind.h index 6c145856..65111f0b 100644 --- a/src/tesseract.pybind.h +++ b/src/tesseract.pybind.h @@ -41,7 +41,7 @@ TesseractConfig tesseract_config_maker_no_dem( double det_penalty = 0.0, bool create_visualization = false) { stim::DetectorErrorModel empty_dem; if (det_orders.empty()) { - det_orders = build_det_orders(empty_dem, 20, /*det_order_bfs=*/true, 2384753); + det_orders = build_det_orders(empty_dem, 20, DetOrder::DetBFS, 2384753); } return TesseractConfig({empty_dem, det_beam, beam_climbing, no_revisit_dets, at_most_two_errors_per_detector, verbose, merge_errors, pqlimit, @@ -57,7 +57,7 @@ TesseractConfig tesseract_config_maker( double det_penalty = 0.0, bool create_visualization = false) { stim::DetectorErrorModel input_dem = parse_py_object(dem); if (det_orders.empty()) { - det_orders = build_det_orders(input_dem, 20, true, 2384753); + det_orders = build_det_orders(input_dem, 20, DetOrder::DetBFS, 2384753); } return TesseractConfig({input_dem, det_beam, beam_climbing, no_revisit_dets, at_most_two_errors_per_detector, verbose, merge_errors, pqlimit, diff --git a/src/tesseract_main.cc b/src/tesseract_main.cc index 8853f5ae..4785570e 100644 --- a/src/tesseract_main.cc +++ b/src/tesseract_main.cc @@ -34,7 +34,9 @@ struct Args { // Manifold orientation options uint64_t det_order_seed; size_t num_det_orders = 10; - bool det_order_bfs = true; + bool det_order_bfs = false; + bool det_order_index = false; + bool det_order_coordinate = false; // Sampling options size_t sample_num_shots = 0; @@ -88,6 +90,12 @@ struct Args { throw std::invalid_argument("Must provide at least one of --circuit or --dem"); } + int det_order_flags = int(det_order_bfs) + int(det_order_index) + int(det_order_coordinate); + if (det_order_flags > 1) { + throw std::invalid_argument( + "Only one of --det-order-bfs, --det-order-index, or --det-order-coordinate may be set."); + } + int num_data_sources = int(sample_num_shots > 0) + int(!in_fname.empty()); if (num_data_sources != 1) { throw std::invalid_argument("Requires exactly 1 source of shots."); @@ -180,8 +188,13 @@ struct Args { std::cout << ")" << std::endl; } } - config.det_orders = - build_det_orders(config.dem, num_det_orders, det_order_bfs, det_order_seed); + DetOrder order = DetOrder::DetBFS; + if (det_order_index) { + order = DetOrder::DetIndex; + } else if (det_order_coordinate) { + order = DetOrder::DetCoordinate; + } + config.det_orders = build_det_orders(config.dem, num_det_orders, order, det_order_seed); } if (sample_num_shots > 0) { @@ -296,21 +309,18 @@ int main(int argc, char* argv[]) { .metavar("N") .default_value(size_t(1)) .store_into(args.num_det_orders); - program.add_argument("--no-det-order-bfs") - .help("Disable BFS-based detector ordering and use geometric orientation") - .default_value(true) - .implicit_value(false) - .store_into(args.det_order_bfs); program.add_argument("--det-order-bfs") - .action([&](auto const&) { - std::cout << "BFS-based detector ordering is the default now; " - "--det-order-bfs is ignored." - << std::endl; - }) - .default_value(true) - .implicit_value(true) - .store_into(args.det_order_bfs) - .hidden(); + .help("Use BFS-based detector ordering (default if no method specified)") + .flag() + .store_into(args.det_order_bfs); + program.add_argument("--det-order-index") + .help("Randomly choose increasing or decreasing detector index order") + .flag() + .store_into(args.det_order_index); + program.add_argument("--det-order-coordinate") + .help("Random geometric detector orientation ordering") + .flag() + .store_into(args.det_order_coordinate); program.add_argument("--det-order-seed") .help( "Seed used when initializing the random detector traversal " diff --git a/src/utils.cc b/src/utils.cc index 1c678635..261a6aa1 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -83,7 +83,7 @@ std::vector> build_detector_graph(const stim::DetectorErrorM } std::vector> build_det_orders(const stim::DetectorErrorModel& dem, - size_t num_det_orders, bool det_order_bfs, + size_t num_det_orders, DetOrder method, uint64_t seed) { std::vector> det_orders(num_det_orders); std::mt19937_64 rng(seed); @@ -91,7 +91,7 @@ std::vector> build_det_orders(const stim::DetectorErrorModel auto detector_coords = get_detector_coords(dem); - if (det_order_bfs) { + if (method == DetOrder::DetBFS) { auto graph = build_detector_graph(dem); std::uniform_int_distribution dist_det(0, graph.size() - 1); for (size_t det_order = 0; det_order < num_det_orders; ++det_order) { @@ -131,7 +131,7 @@ std::vector> build_det_orders(const stim::DetectorErrorModel } det_orders[det_order] = inv_perm; } - } else { + } else if (method == DetOrder::DetCoordinate) { std::vector inner_products(dem.count_detectors()); if (!detector_coords.size() || !detector_coords.at(0).size()) { for (size_t det_order = 0; det_order < num_det_orders; ++det_order) { @@ -163,6 +163,19 @@ std::vector> build_det_orders(const stim::DetectorErrorModel det_orders[det_order] = inv_perm; } } + } else if (method == DetOrder::DetIndex) { + std::uniform_int_distribution dist_bool(0, 1); + size_t n = dem.count_detectors(); + for (size_t det_order = 0; det_order < num_det_orders; ++det_order) { + det_orders[det_order].resize(n); + if (dist_bool(rng)) { + for (size_t i = 0; i < n; ++i) { + det_orders[det_order][i] = n - 1 - i; + } + } else { + std::iota(det_orders[det_order].begin(), det_orders[det_order].end(), 0); + } + } } return det_orders; } diff --git a/src/utils.h b/src/utils.h index b66537cd..73d7817e 100644 --- a/src/utils.h +++ b/src/utils.h @@ -34,8 +34,11 @@ std::vector> get_detector_coords(const stim::DetectorErrorMo // in the model activates them both. std::vector> build_detector_graph(const stim::DetectorErrorModel& dem); +enum class DetOrder { DetBFS, DetIndex, DetCoordinate }; + std::vector> build_det_orders(const stim::DetectorErrorModel& dem, - size_t num_det_orders, bool det_order_bfs = true, + size_t num_det_orders, + DetOrder method = DetOrder::DetBFS, uint64_t seed = 0); const double INF = std::numeric_limits::infinity(); diff --git a/src/utils.pybind.h b/src/utils.pybind.h index 2fa23d79..92ba6680 100644 --- a/src/utils.pybind.h +++ b/src/utils.pybind.h @@ -32,6 +32,12 @@ void add_utils_module(py::module& root) { m.attr("INF") = INF; m.doc() = "A representation of infinity for floating point numbers."; + py::enum_(m, "DetOrder", "Detector ordering methods") + .value("DetBFS", DetOrder::DetBFS) + .value("DetIndex", DetOrder::DetIndex) + .value("DetCoordinate", DetOrder::DetCoordinate) + .export_values(); + m.def( "get_detector_coords", [](py::object dem) { @@ -79,11 +85,11 @@ void add_utils_module(py::module& root) { )pbdoc"); m.def( "build_det_orders", - [](py::object dem, size_t num_det_orders, bool det_order_bfs, uint64_t seed) { + [](py::object dem, size_t num_det_orders, DetOrder method, uint64_t seed) { auto input_dem = parse_py_object(dem); - return build_det_orders(input_dem, num_det_orders, det_order_bfs, seed); + return build_det_orders(input_dem, num_det_orders, method, seed); }, - py::arg("dem"), py::arg("num_det_orders"), py::arg("det_order_bfs") = true, + py::arg("dem"), py::arg("num_det_orders"), py::arg("method") = DetOrder::DetBFS, py::arg("seed") = 0, R"pbdoc( Generates various detector orderings for decoding. @@ -93,17 +99,19 @@ void add_utils_module(py::module& root) { The detector error model to generate orders for. num_det_orders : int The number of detector orderings to generate. - det_order_bfs : bool, default=True - If True, uses a Breadth-First Search (BFS) to generate - the orders. If False, uses a randomized ordering. + method : tesseract_decoder.utils.DetOrder, default=tesseract_decoder.utils.DetOrder.DetBFS + Strategy for ordering detectors. ``DetBFS`` performs a breadth-first + traversal, ``DetCoordinate`` uses randomized geometric orientations, + and ``DetIndex`` chooses either increasing or decreasing detector + index order at random. seed : int, default=0 A seed for the random number generator. Returns ------- list[list[int]] - A list of detector orderings. Each inner list is a - permutation of the detector indices. + A list of detector orderings. Each inner list maps a detector index + to its position in the ordering. )pbdoc"); m.def( "get_errors_from_dem", From 454521fd3f52f7d815a38e427cfddfc5b1396e4b Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Fri, 29 Aug 2025 15:56:49 -0700 Subject: [PATCH 09/11] expand det index test --- src/py/utils_test.py | 12 ++- src/utils.cc | 186 ++++++++++++++++++++++++------------------- 2 files changed, 112 insertions(+), 86 deletions(-) diff --git a/src/py/utils_test.py b/src/py/utils_test.py index 304ddf3f..3cbec85d 100644 --- a/src/py/utils_test.py +++ b/src/py/utils_test.py @@ -27,6 +27,10 @@ """ ) +_DETECTOR_ERROR_MODEL_10 = stim.DetectorErrorModel( + "\n".join(f"error(0.1) D{i}" for i in range(10)) +) + def test_module_has_global_constants(): assert tesseract_decoder.utils.EPSILON <= 1e-7 @@ -44,7 +48,7 @@ def test_build_detector_graph(): ] -def test_build_det_orders(): +def test_build_det_orders_bfs(): assert tesseract_decoder.utils.build_det_orders( _DETECTOR_ERROR_MODEL, num_det_orders=1, seed=0 ) == [[0, 1]] @@ -61,12 +65,14 @@ def test_build_det_orders_coordinate(): def test_build_det_orders_index(): res = tesseract_decoder.utils.build_det_orders( - _DETECTOR_ERROR_MODEL, + _DETECTOR_ERROR_MODEL_10, num_det_orders=1, method=tesseract_decoder.utils.DetOrder.DetIndex, seed=0, ) - assert res == [[0, 1]] or res == [[1, 0]] + expected_asc = list(range(10)) + expected_desc = list(range(9, -1, -1)) + assert res == [expected_asc] or res == [expected_desc] def test_get_errors_from_dem(): diff --git a/src/utils.cc b/src/utils.cc index 261a6aa1..4100a945 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -82,104 +82,124 @@ std::vector> build_detector_graph(const stim::DetectorErrorM return neighbors; } -std::vector> build_det_orders(const stim::DetectorErrorModel& dem, - size_t num_det_orders, DetOrder method, - uint64_t seed) { +static std::vector> build_det_orders_bfs(const stim::DetectorErrorModel& dem, + size_t num_det_orders, + std::mt19937_64& rng) { std::vector> det_orders(num_det_orders); - std::mt19937_64 rng(seed); - std::normal_distribution dist(0, 1); - - auto detector_coords = get_detector_coords(dem); - - if (method == DetOrder::DetBFS) { - auto graph = build_detector_graph(dem); - std::uniform_int_distribution dist_det(0, graph.size() - 1); - for (size_t det_order = 0; det_order < num_det_orders; ++det_order) { - std::vector perm; - perm.reserve(graph.size()); - std::vector visited(graph.size(), false); - std::queue q; - size_t start = dist_det(rng); - while (perm.size() < graph.size()) { - if (!visited[start]) { - visited[start] = true; - q.push(start); - perm.push_back(start); - } - while (!q.empty()) { - size_t cur = q.front(); - q.pop(); - auto neigh = graph[cur]; - std::shuffle(neigh.begin(), neigh.end(), rng); - for (size_t n : neigh) { - if (!visited[n]) { - visited[n] = true; - q.push(n); - perm.push_back(n); - } + auto graph = build_detector_graph(dem); + std::uniform_int_distribution dist_det(0, graph.size() - 1); + for (size_t det_order = 0; det_order < num_det_orders; ++det_order) { + std::vector perm; + perm.reserve(graph.size()); + std::vector visited(graph.size(), false); + std::queue q; + size_t start = dist_det(rng); + while (perm.size() < graph.size()) { + if (!visited[start]) { + visited[start] = true; + q.push(start); + perm.push_back(start); + } + while (!q.empty()) { + size_t cur = q.front(); + q.pop(); + auto neigh = graph[cur]; + std::shuffle(neigh.begin(), neigh.end(), rng); + for (size_t n : neigh) { + if (!visited[n]) { + visited[n] = true; + q.push(n); + perm.push_back(n); } } - if (perm.size() < graph.size()) { - do { - start = dist_det(rng); - } while (visited[start]); - } } - std::vector inv_perm(graph.size()); - for (size_t i = 0; i < perm.size(); ++i) { - inv_perm[perm[i]] = i; + if (perm.size() < graph.size()) { + do { + start = dist_det(rng); + } while (visited[start]); } - det_orders[det_order] = inv_perm; } - } else if (method == DetOrder::DetCoordinate) { - std::vector inner_products(dem.count_detectors()); - if (!detector_coords.size() || !detector_coords.at(0).size()) { - for (size_t det_order = 0; det_order < num_det_orders; ++det_order) { - det_orders[det_order].resize(dem.count_detectors()); - std::iota(det_orders[det_order].begin(), det_orders[det_order].end(), 0); - } - } else { - for (size_t det_order = 0; det_order < num_det_orders; ++det_order) { - std::vector orientation_vector; - for (size_t i = 0; i < detector_coords.at(0).size(); ++i) { - orientation_vector.push_back(dist(rng)); - } + std::vector inv_perm(graph.size()); + for (size_t i = 0; i < perm.size(); ++i) { + inv_perm[perm[i]] = i; + } + det_orders[det_order] = inv_perm; + } + return det_orders; +} - for (size_t i = 0; i < detector_coords.size(); ++i) { - inner_products[i] = 0; - for (size_t j = 0; j < orientation_vector.size(); ++j) { - inner_products[i] += detector_coords[i][j] * orientation_vector[j]; - } - } - std::vector perm(dem.count_detectors()); - std::iota(perm.begin(), perm.end(), 0); - std::sort(perm.begin(), perm.end(), [&](const size_t& i, const size_t& j) { - return inner_products[i] > inner_products[j]; - }); - std::vector inv_perm(dem.count_detectors()); - for (size_t i = 0; i < perm.size(); ++i) { - inv_perm[perm[i]] = i; - } - det_orders[det_order] = inv_perm; +static std::vector> build_det_orders_coordinate( + const stim::DetectorErrorModel& dem, size_t num_det_orders, std::mt19937_64& rng) { + std::vector> det_orders(num_det_orders); + auto detector_coords = get_detector_coords(dem); + std::vector inner_products(dem.count_detectors()); + std::normal_distribution dist(0, 1); + if (detector_coords.empty() || detector_coords.at(0).empty()) { + for (size_t det_order = 0; det_order < num_det_orders; ++det_order) { + det_orders[det_order].resize(dem.count_detectors()); + std::iota(det_orders[det_order].begin(), det_orders[det_order].end(), 0); + } + return det_orders; + } + for (size_t det_order = 0; det_order < num_det_orders; ++det_order) { + std::vector orientation_vector; + for (size_t i = 0; i < detector_coords.at(0).size(); ++i) { + orientation_vector.push_back(dist(rng)); + } + for (size_t i = 0; i < detector_coords.size(); ++i) { + inner_products[i] = 0; + for (size_t j = 0; j < orientation_vector.size(); ++j) { + inner_products[i] += detector_coords[i][j] * orientation_vector[j]; } } - } else if (method == DetOrder::DetIndex) { - std::uniform_int_distribution dist_bool(0, 1); - size_t n = dem.count_detectors(); - for (size_t det_order = 0; det_order < num_det_orders; ++det_order) { - det_orders[det_order].resize(n); - if (dist_bool(rng)) { - for (size_t i = 0; i < n; ++i) { - det_orders[det_order][i] = n - 1 - i; - } - } else { - std::iota(det_orders[det_order].begin(), det_orders[det_order].end(), 0); + std::vector perm(dem.count_detectors()); + std::iota(perm.begin(), perm.end(), 0); + std::sort(perm.begin(), perm.end(), [&](const size_t& i, const size_t& j) { + return inner_products[i] > inner_products[j]; + }); + std::vector inv_perm(dem.count_detectors()); + for (size_t i = 0; i < perm.size(); ++i) { + inv_perm[perm[i]] = i; + } + det_orders[det_order] = inv_perm; + } + return det_orders; +} + +static std::vector> build_det_orders_index(const stim::DetectorErrorModel& dem, + size_t num_det_orders, + std::mt19937_64& rng) { + std::vector> det_orders(num_det_orders); + std::uniform_int_distribution dist_bool(0, 1); + size_t n = dem.count_detectors(); + for (size_t det_order = 0; det_order < num_det_orders; ++det_order) { + det_orders[det_order].resize(n); + if (dist_bool(rng)) { + for (size_t i = 0; i < n; ++i) { + det_orders[det_order][i] = n - 1 - i; } + } else { + std::iota(det_orders[det_order].begin(), det_orders[det_order].end(), 0); } } return det_orders; } +std::vector> build_det_orders(const stim::DetectorErrorModel& dem, + size_t num_det_orders, DetOrder method, + uint64_t seed) { + std::mt19937_64 rng(seed); + switch (method) { + case DetOrder::DetBFS: + return build_det_orders_bfs(dem, num_det_orders, rng); + case DetOrder::DetCoordinate: + return build_det_orders_coordinate(dem, num_det_orders, rng); + case DetOrder::DetIndex: + return build_det_orders_index(dem, num_det_orders, rng); + } + throw std::invalid_argument("Unknown det order method"); +} + bool sampling_from_dem(uint64_t seed, size_t num_shots, stim::DetectorErrorModel dem, std::vector& shots) { stim::DemSampler sampler(dem, std::mt19937_64{seed}, num_shots); From 41f7c3e7189259383d8fed839dc90d67a02528e8 Mon Sep 17 00:00:00 2001 From: noajshu Date: Fri, 29 Aug 2025 23:33:56 +0000 Subject: [PATCH 10/11] update beam climbing for when det orders > beam+1 --- src/tesseract.cc | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/tesseract.cc b/src/tesseract.cc index 97c8f4d8..ca441331 100644 --- a/src/tesseract.cc +++ b/src/tesseract.cc @@ -201,8 +201,10 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections) } if (config.beam_climbing) { - for (int beam = config.det_beam; beam >= 0; --beam) { - size_t detector_order = beam % config.det_orders.size(); + int beam = 0; + int detector_order = 0; + for (int trial = 0; trial < std::max(config.det_beam + 1, int(config.det_orders.size())); + ++trial) { decode_to_errors(detections, detector_order, beam); double local_cost = cost_from_errors(predicted_errors_buffer); if (!low_confidence_flag && local_cost < best_cost) { @@ -215,6 +217,10 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections) << " and obs_mask " << get_flipped_observables(predicted_errors_buffer) << ". Best cost so far: " << best_cost << std::endl; } + beam += 1; + detector_order += 1; + beam %= (config.det_beam + 1); + detector_order %= config.det_orders.size(); } } else { for (size_t detector_order = 0; detector_order < config.det_orders.size(); ++detector_order) { From b3e06d315791e08ab71d2ec686f8f16038785830 Mon Sep 17 00:00:00 2001 From: noajshu Date: Sun, 31 Aug 2025 01:31:40 +0000 Subject: [PATCH 11/11] refactor(python): Remove Node struct from pybind API The Node struct is an internal implementation detail of the C++ tesseract library and should not be exposed in the Python API. This commit removes the pybind bindings for the Node struct and also removes the corresponding test case from the Python tests. --- src/py/tesseract_test.py | 4 +--- src/tesseract.pybind.h | 26 +------------------------- 2 files changed, 2 insertions(+), 28 deletions(-) diff --git a/src/py/tesseract_test.py b/src/py/tesseract_test.py index 69b72e26..70d3b6ba 100644 --- a/src/py/tesseract_test.py +++ b/src/py/tesseract_test.py @@ -42,9 +42,7 @@ ) -def test_create_node(): - node = tesseract_decoder.tesseract.Node(errors=[1, 0]) - assert node.errors == [1, 0] + def test_create_tesseract_config(): diff --git a/src/tesseract.pybind.h b/src/tesseract.pybind.h index 72f92e5a..df17a0c6 100644 --- a/src/tesseract.pybind.h +++ b/src/tesseract.pybind.h @@ -199,31 +199,7 @@ void add_tesseract_module(py::module& root) { `TesseractConfig` object. )pbdoc"); - py::class_(m, "Node", R"pbdoc( - A class representing a node in the Tesseract search graph. - - This is used internally by the decoder to track decoding progress. - )pbdoc") - .def(py::init>(), py::arg("cost") = 0.0, - py::arg("num_dets") = 0, py::arg("errors") = std::vector(), R"pbdoc( - The constructor for the `Node` class. - - Parameters - ---------- - cost : float, default=0.0 - The cost of the path to this node. - num_dets : int, default=0 - The number of detectors this search node has. - errors : list[int], default=empty - The list of error indices this search node has. - )pbdoc") - .def_readwrite("cost", &Node::cost, "The cost of the node.") - .def_readwrite("num_dets", &Node::num_dets, "The number of detectors this search node has.") - .def_readwrite("errors", &Node::errors, "The list of error indices this search node has.") - .def(py::self > py::self, - "Comparison operator for nodes based on cost. This is necessary to prioritize " - "lower-cost nodes during the search.") - .def("__str__", &Node::str); + py::class_(m, "TesseractDecoder", R"pbdoc( A class that implements the Tesseract decoding algorithm.