diff --git a/src/common.pybind.h b/src/common.pybind.h index 965b658..852c823 100644 --- a/src/common.pybind.h +++ b/src/common.pybind.h @@ -1,38 +1,41 @@ #ifndef TESSERACT_COMMON_PY_H #define TESSERACT_COMMON_PY_H -#include - +#include #include #include -#include + +#include #include "common.h" namespace py = pybind11; -void add_common_module(py::module &root) -{ - auto m = root.def_submodule("common", "classes commonly used by the decoder"); - - // TODO: add as_dem_instruction_targets - py::class_(m, "Symptom") - .def(py::init, common::ObservablesMask>(), py::arg("detectors") = std::vector(), py::arg("observables") = 0) - .def_readwrite("detectors", &common::Symptom::detectors) - .def_readwrite("observables", &common::Symptom::observables) - .def("__str__", &common::Symptom::str) - .def(py::self == py::self) - .def(py::self != py::self); - - // TODO: add constructor with stim::DemInstruction. - py::class_(m, "Error") - .def_readwrite("likelihood_cost", &common::Error::likelihood_cost) - .def_readwrite("probability", &common::Error::probability) - .def_readwrite("symptom", &common::Error::symptom) - .def("__str__", &common::Error::str) - .def(py::init<>()) - .def(py::init &, common::ObservablesMask, std::vector &>()) - .def(py::init &, common::ObservablesMask, std::vector &>()); +void add_common_module(py::module &root) { + auto m = root.def_submodule("common", "classes commonly used by the decoder"); + + // TODO: add as_dem_instruction_targets + py::class_(m, "Symptom") + .def(py::init, common::ObservablesMask>(), + py::arg("detectors") = std::vector(), + py::arg("observables") = 0) + .def_readwrite("detectors", &common::Symptom::detectors) + .def_readwrite("observables", &common::Symptom::observables) + .def("__str__", &common::Symptom::str) + .def(py::self == py::self) + .def(py::self != py::self); + + // TODO: add constructor with stim::DemInstruction. + py::class_(m, "Error") + .def_readwrite("likelihood_cost", &common::Error::likelihood_cost) + .def_readwrite("probability", &common::Error::probability) + .def_readwrite("symptom", &common::Error::symptom) + .def("__str__", &common::Error::str) + .def(py::init<>()) + .def(py::init &, common::ObservablesMask, + std::vector &>()) + .def(py::init &, common::ObservablesMask, + std::vector &>()); } #endif diff --git a/src/tesseract.cc b/src/tesseract.cc index 65d29b3..2c002b4 100644 --- a/src/tesseract.cc +++ b/src/tesseract.cc @@ -48,6 +48,11 @@ TesseractDecoder::TesseractDecoder(TesseractConfig config_) : config(config_) { } assert(this->config.det_orders.size()); errors = get_errors_from_dem(config.dem.flattened()); + if (config.verbose) { + for (auto& error : errors) { + std::cout << error.str() << std::endl; + } + } num_detectors = config.dem.count_detectors(); num_errors = config.dem.count_errors(); initialize_structures(config.dem.count_detectors()); @@ -87,21 +92,24 @@ void TesseractDecoder::initialize_structures(size_t num_detectors) { struct VectorCharHash { size_t operator()(const std::vector& v) const { - size_t seed = v.size(); // Still good practice to incorporate vector size + size_t seed = v.size(); // Still good practice to incorporate vector size - // Iterate over char elements. Accessing 'b_val' is now a direct memory read. + // Iterate over char elements. Accessing 'b_val' is now a direct memory + // read. for (char b_val : v) { // The polynomial rolling hash with 31 (or another prime) // 'b_val' is already a char (an 8-bit integer). - // static_cast(b_val) ensures it's promoted to size_t before arithmetic. - // This cast is efficient (likely a simple register extension/move). + // static_cast(b_val) ensures it's promoted to size_t before + // arithmetic. This cast is efficient (likely a simple register + // extension/move). seed = seed * 31 + static_cast(b_val); } return seed; } }; -void TesseractDecoder::decode_to_errors(const std::vector& detections) { +void TesseractDecoder::decode_to_errors( + const std::vector& detections) { std::vector best_errors; double best_cost = std::numeric_limits::max(); assert(config.det_orders.size()); @@ -254,6 +262,18 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, if (node.num_dets == 0) { if (config.verbose) { + std::cout << "activated_errors = "; + for (size_t oei : node.errs) { + std::cout << oei << ", "; + } + std::cout << std::endl; + std::cout << "activated_dets = "; + for (size_t d = 0; d < num_detectors; ++d) { + if (node.dets[d]) { + std::cout << d << ", "; + } + } + std::cout << std::endl; std::cout.precision(13); std::cout << "Decoding complete. Cost: " << node.cost << " num_pq_pushed = " << num_pq_pushed << std::endl; @@ -278,10 +298,18 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, std::cout << "num_dets = " << node.num_dets << " max_num_dets = " << max_num_dets << " cost = " << node.cost << std::endl; + std::cout << "activated_errors = "; for (size_t oei : node.errs) { std::cout << oei << ", "; } std::cout << std::endl; + std::cout << "activated_dets = "; + for (size_t d = 0; d < num_detectors; ++d) { + if (node.dets[d]) { + std::cout << d << ", "; + } + } + std::cout << std::endl; } if (node.num_dets < min_num_dets) { diff --git a/src/tesseract.pybind.cc b/src/tesseract.pybind.cc index 618a1cd..33dcc62 100644 --- a/src/tesseract.pybind.cc +++ b/src/tesseract.pybind.cc @@ -1,9 +1,6 @@ #include -#include "pybind11/detail/common.h" #include "common.pybind.h" +#include "pybind11/detail/common.h" -PYBIND11_MODULE(tesseract_py, m) -{ - add_common_module(m); -} +PYBIND11_MODULE(tesseract_py, m) { add_common_module(m); } diff --git a/src/tesseract_main.cc b/src/tesseract_main.cc index 0aff8bc..eb209f4 100644 --- a/src/tesseract_main.cc +++ b/src/tesseract_main.cc @@ -179,34 +179,58 @@ struct Args { std::vector> detector_coords = get_detector_coords(config.dem); + if (verbose) { + for (size_t d = 0; d < detector_coords.size(); ++d) { + std::cout << "Detector D" << d << " coordinate ("; + size_t e = std::min(3ul, detector_coords[d].size()); + for (size_t i = 0; i < e; ++i) { + std::cout << detector_coords[d][i]; + if (i + 1 < e) std::cout << ", "; + } + std::cout << ")" << std::endl; + } + } std::vector inner_products(config.dem.count_detectors()); - for (size_t det_order = 0; det_order < num_det_orders; ++det_order) { - // Sample a direction - std::vector orientation_vector; - for (size_t i = 0; i < detector_coords.at(0).size(); ++i) { - orientation_vector.push_back(dist(rng)); + if (!detector_coords.size() or !detector_coords.at(0).size()) { + // If there are no detector coordinates, just use the standard ordering + // of the indices. + for (size_t det_order = 0; det_order < num_det_orders; ++det_order) { + config.det_orders.emplace_back(); + std::iota(config.det_orders.back().begin(), + config.det_orders.front().end(), 0); } + } else { + // Use the coordinates to order the detectors based on a random + // orientation + for (size_t det_order = 0; det_order < num_det_orders; ++det_order) { + // Sample a direction + std::vector orientation_vector; + for (size_t i = 0; i < detector_coords.at(0).size(); ++i) { + orientation_vector.push_back(dist(rng)); + } - for (size_t i = 0; i < detector_coords.size(); ++i) { - inner_products[i] = 0; - for (size_t j = 0; j < orientation_vector.size(); ++j) { - inner_products[i] += detector_coords[i][j] * orientation_vector[j]; + for (size_t i = 0; i < detector_coords.size(); ++i) { + inner_products[i] = 0; + for (size_t j = 0; j < orientation_vector.size(); ++j) { + inner_products[i] += + detector_coords[i][j] * orientation_vector[j]; + } } + std::vector perm(config.dem.count_detectors()); + std::iota(perm.begin(), perm.end(), 0); + std::sort(perm.begin(), perm.end(), + [&](const size_t& i, const size_t& j) { + return inner_products[i] > inner_products[j]; + }); + // Invert the permutation + std::vector inv_perm(config.dem.count_detectors()); + for (size_t i = 0; i < perm.size(); ++i) { + inv_perm[perm[i]] = i; + } + config.det_orders[det_order] = inv_perm; } - std::vector perm(config.dem.count_detectors()); - std::iota(perm.begin(), perm.end(), 0); - std::sort(perm.begin(), perm.end(), - [&](const size_t& i, const size_t& j) { - return inner_products[i] > inner_products[j]; - }); - // Invert the permutation - std::vector inv_perm(config.dem.count_detectors()); - for (size_t i = 0; i < perm.size(); ++i) { - inv_perm[perm[i]] = i; - } - config.det_orders[det_order] = inv_perm; } } diff --git a/src/utils.cc b/src/utils.cc index 92ea1d5..10bca4d 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -38,8 +38,8 @@ std::vector> get_detector_coords( } case stim::DemInstructionType::DEM_DETECTOR: { std::vector coord; - for (const stim::DemTarget& t : instruction.target_data) { - coord.push_back(t.val()); + for (const double& t : instruction.arg_data) { + coord.push_back(t); } detector_coords.push_back(coord); break; diff --git a/viz/3dviewer.html b/viz/3dviewer.html new file mode 100644 index 0000000..39521b0 --- /dev/null +++ b/viz/3dviewer.html @@ -0,0 +1,319 @@ + + + + + Tesseract Decoder Viewer + + + +
+

+

+ + +

+ + +
+ + +
+ +
+ +
+
+ +
+
All Detectors
+
Activated Detectors
+
Activated Errors
+
+
+ + + + + diff --git a/viz/to_json.py b/viz/to_json.py new file mode 100644 index 0000000..567b431 --- /dev/null +++ b/viz/to_json.py @@ -0,0 +1,103 @@ +import json +import argparse +import re +import numpy as np + +def parse_implicit_list(line, prefix): + if not line.startswith(prefix): + raise ValueError(f"Expected line to start with '{prefix}', got: {line}") + list_part = line[len(prefix):].strip().rstrip(',') + if not list_part: + return [] + return [int(x.strip()) for x in list_part.split(',') if x.strip()] + +def parse_logfile(filepath): + detector_coords = {} + error_to_detectors = [] + frames = [] + + with open(filepath, 'r') as f: + lines = f.readlines() + + i = 0 + while i < len(lines): + line = lines[i].strip() + + if not any(line.startswith(s) for s in ['Error', 'Detector', 'activated_errors', 'activated_dets']): + continue + + if line.startswith("Detector D"): + match = re.match(r'Detector D(\d+) coordinate \(([-\d.]+), ([-\d.]+), ([-\d.]+)\)', line) + if match: + idx = int(match.group(1)) + coord = tuple(float(match.group(j)) for j in range(2, 5)) + detector_coords[idx] = coord + + elif line.startswith("Error{"): + match = re.search(r'Symptom\{([^\}]+)\}', line) + if match: + dets = match.group(1).split() + det_indices = [int(d[1:]) for d in dets if d.startswith('D')] + error_to_detectors.append(det_indices) + + elif line.startswith("activated_errors"): + try: + error_line = lines[i].strip() + det_line = lines[i + 1].strip() + + activated_errors = parse_implicit_list(error_line, "activated_errors =") + activated_dets = parse_implicit_list(det_line, "activated_dets =") + + frame = { + "activated": activated_dets, + "activated_errors": activated_errors + } + frames.append(frame) + i += 1 + except Exception as e: + print(f"\n⚠️ Error parsing frame at lines {i}-{i+1}: {e}") + print(f" {lines[i].strip()}") + print(f" {lines[i+1].strip() if i+1 < len(lines) else ''}") + i += 1 + + if not detector_coords: + raise RuntimeError("No detectors parsed!") + + coords_array = np.array(list(detector_coords.values())) + mean_coord = coords_array.mean(axis=0) + for k in detector_coords: + detector_coords[k] = (np.array(detector_coords[k]) - mean_coord).tolist() + + error_coords = {} + for i, det_list in enumerate(error_to_detectors): + try: + pts = np.array([detector_coords[d] for d in det_list if d in detector_coords]) + if len(pts) > 0: + error_coords[i] = pts.mean(axis=0).tolist() + except KeyError as e: + print(f"⚠️ Skipping error {i}: unknown detector {e}") + + error_to_detectors_dict = {str(i): dets for i, dets in enumerate(error_to_detectors)} + + return { + "detectorCoords": {str(k): v for k, v in detector_coords.items()}, + "errorCoords": {str(k): v for k, v in error_coords.items()}, + "errorToDetectors": error_to_detectors_dict, + "frames": frames + } + +def main(): + parser = argparse.ArgumentParser(description="Convert a tesseract decoder logfile to a 3D visualization JSON.") + parser.add_argument("logfile", help="Path to the logfile.txt") + parser.add_argument("-o", "--output", default="tesseract_visualization.json", help="Output JSON filename") + args = parser.parse_args() + + data = parse_logfile(args.logfile) + + with open(args.output, 'w') as f: + json.dump(data, f, indent=2) + + print(f"✅ JSON written to {args.output} with {len(data['frames'])} frames and {len(data['errorCoords'])} error coords.") + +if __name__ == "__main__": + main()