Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 28 additions & 25 deletions src/common.pybind.h
Original file line number Diff line number Diff line change
@@ -1,38 +1,41 @@
#ifndef TESSERACT_COMMON_PY_H
#define TESSERACT_COMMON_PY_H

#include <vector>

#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/operators.h>

#include <vector>

#include "common.h"

namespace py = pybind11;

void add_common_module(py::module &root)
{
auto m = root.def_submodule("common", "classes commonly used by the decoder");

// TODO: add as_dem_instruction_targets
py::class_<common::Symptom>(m, "Symptom")
.def(py::init<std::vector<int>, common::ObservablesMask>(), py::arg("detectors") = std::vector<int>(), py::arg("observables") = 0)
.def_readwrite("detectors", &common::Symptom::detectors)
.def_readwrite("observables", &common::Symptom::observables)
.def("__str__", &common::Symptom::str)
.def(py::self == py::self)
.def(py::self != py::self);

// TODO: add constructor with stim::DemInstruction.
py::class_<common::Error>(m, "Error")
.def_readwrite("likelihood_cost", &common::Error::likelihood_cost)
.def_readwrite("probability", &common::Error::probability)
.def_readwrite("symptom", &common::Error::symptom)
.def("__str__", &common::Error::str)
.def(py::init<>())
.def(py::init<double, std::vector<int> &, common::ObservablesMask, std::vector<bool> &>())
.def(py::init<double, double, std::vector<int> &, common::ObservablesMask, std::vector<bool> &>());
void add_common_module(py::module &root) {
auto m = root.def_submodule("common", "classes commonly used by the decoder");

// TODO: add as_dem_instruction_targets
py::class_<common::Symptom>(m, "Symptom")
.def(py::init<std::vector<int>, common::ObservablesMask>(),
py::arg("detectors") = std::vector<int>(),
py::arg("observables") = 0)
.def_readwrite("detectors", &common::Symptom::detectors)
.def_readwrite("observables", &common::Symptom::observables)
.def("__str__", &common::Symptom::str)
.def(py::self == py::self)
.def(py::self != py::self);

// TODO: add constructor with stim::DemInstruction.
py::class_<common::Error>(m, "Error")
.def_readwrite("likelihood_cost", &common::Error::likelihood_cost)
.def_readwrite("probability", &common::Error::probability)
.def_readwrite("symptom", &common::Error::symptom)
.def("__str__", &common::Error::str)
.def(py::init<>())
.def(py::init<double, std::vector<int> &, common::ObservablesMask,
std::vector<bool> &>())
.def(py::init<double, double, std::vector<int> &, common::ObservablesMask,
std::vector<bool> &>());
}

#endif
38 changes: 33 additions & 5 deletions src/tesseract.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ TesseractDecoder::TesseractDecoder(TesseractConfig config_) : config(config_) {
}
assert(this->config.det_orders.size());
errors = get_errors_from_dem(config.dem.flattened());
if (config.verbose) {
for (auto& error : errors) {
std::cout << error.str() << std::endl;
}
}
num_detectors = config.dem.count_detectors();
num_errors = config.dem.count_errors();
initialize_structures(config.dem.count_detectors());
Expand Down Expand Up @@ -87,21 +92,24 @@ void TesseractDecoder::initialize_structures(size_t num_detectors) {

struct VectorCharHash {
size_t operator()(const std::vector<char>& v) const {
size_t seed = v.size(); // Still good practice to incorporate vector size
size_t seed = v.size(); // Still good practice to incorporate vector size

// Iterate over char elements. Accessing 'b_val' is now a direct memory read.
// Iterate over char elements. Accessing 'b_val' is now a direct memory
// read.
for (char b_val : v) {
// The polynomial rolling hash with 31 (or another prime)
// 'b_val' is already a char (an 8-bit integer).
// static_cast<size_t>(b_val) ensures it's promoted to size_t before arithmetic.
// This cast is efficient (likely a simple register extension/move).
// static_cast<size_t>(b_val) ensures it's promoted to size_t before
// arithmetic. This cast is efficient (likely a simple register
// extension/move).
seed = seed * 31 + static_cast<size_t>(b_val);
}
return seed;
}
};

void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections) {
void TesseractDecoder::decode_to_errors(
const std::vector<uint64_t>& detections) {
std::vector<size_t> best_errors;
double best_cost = std::numeric_limits<double>::max();
assert(config.det_orders.size());
Expand Down Expand Up @@ -254,6 +262,18 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,

if (node.num_dets == 0) {
if (config.verbose) {
std::cout << "activated_errors = ";
for (size_t oei : node.errs) {
std::cout << oei << ", ";
}
std::cout << std::endl;
std::cout << "activated_dets = ";
for (size_t d = 0; d < num_detectors; ++d) {
if (node.dets[d]) {
std::cout << d << ", ";
}
}
std::cout << std::endl;
std::cout.precision(13);
std::cout << "Decoding complete. Cost: " << node.cost
<< " num_pq_pushed = " << num_pq_pushed << std::endl;
Expand All @@ -278,10 +298,18 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
std::cout << "num_dets = " << node.num_dets
<< " max_num_dets = " << max_num_dets << " cost = " << node.cost
<< std::endl;
std::cout << "activated_errors = ";
for (size_t oei : node.errs) {
std::cout << oei << ", ";
}
std::cout << std::endl;
std::cout << "activated_dets = ";
for (size_t d = 0; d < num_detectors; ++d) {
if (node.dets[d]) {
std::cout << d << ", ";
}
}
std::cout << std::endl;
}

if (node.num_dets < min_num_dets) {
Expand Down
7 changes: 2 additions & 5 deletions src/tesseract.pybind.cc
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
#include <pybind11/pybind11.h>
#include "pybind11/detail/common.h"

#include "common.pybind.h"
#include "pybind11/detail/common.h"

PYBIND11_MODULE(tesseract_py, m)
{
add_common_module(m);
}
PYBIND11_MODULE(tesseract_py, m) { add_common_module(m); }
66 changes: 45 additions & 21 deletions src/tesseract_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,34 +179,58 @@ struct Args {

std::vector<std::vector<double>> detector_coords =
get_detector_coords(config.dem);
if (verbose) {
for (size_t d = 0; d < detector_coords.size(); ++d) {
std::cout << "Detector D" << d << " coordinate (";
size_t e = std::min(3ul, detector_coords[d].size());
for (size_t i = 0; i < e; ++i) {
std::cout << detector_coords[d][i];
if (i + 1 < e) std::cout << ", ";
}
std::cout << ")" << std::endl;
}
}

std::vector<double> inner_products(config.dem.count_detectors());

for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
// Sample a direction
std::vector<double> orientation_vector;
for (size_t i = 0; i < detector_coords.at(0).size(); ++i) {
orientation_vector.push_back(dist(rng));
if (!detector_coords.size() or !detector_coords.at(0).size()) {
// If there are no detector coordinates, just use the standard ordering
// of the indices.
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
config.det_orders.emplace_back();
std::iota(config.det_orders.back().begin(),
config.det_orders.front().end(), 0);
}
} else {
// Use the coordinates to order the detectors based on a random
// orientation
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
// Sample a direction
std::vector<double> orientation_vector;
for (size_t i = 0; i < detector_coords.at(0).size(); ++i) {
orientation_vector.push_back(dist(rng));
}

for (size_t i = 0; i < detector_coords.size(); ++i) {
inner_products[i] = 0;
for (size_t j = 0; j < orientation_vector.size(); ++j) {
inner_products[i] += detector_coords[i][j] * orientation_vector[j];
for (size_t i = 0; i < detector_coords.size(); ++i) {
inner_products[i] = 0;
for (size_t j = 0; j < orientation_vector.size(); ++j) {
inner_products[i] +=
detector_coords[i][j] * orientation_vector[j];
}
}
std::vector<size_t> perm(config.dem.count_detectors());
std::iota(perm.begin(), perm.end(), 0);
std::sort(perm.begin(), perm.end(),
[&](const size_t& i, const size_t& j) {
return inner_products[i] > inner_products[j];
});
// Invert the permutation
std::vector<size_t> inv_perm(config.dem.count_detectors());
for (size_t i = 0; i < perm.size(); ++i) {
inv_perm[perm[i]] = i;
}
config.det_orders[det_order] = inv_perm;
}
std::vector<size_t> perm(config.dem.count_detectors());
std::iota(perm.begin(), perm.end(), 0);
std::sort(perm.begin(), perm.end(),
[&](const size_t& i, const size_t& j) {
return inner_products[i] > inner_products[j];
});
// Invert the permutation
std::vector<size_t> inv_perm(config.dem.count_detectors());
for (size_t i = 0; i < perm.size(); ++i) {
inv_perm[perm[i]] = i;
}
config.det_orders[det_order] = inv_perm;
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ std::vector<std::vector<double>> get_detector_coords(
}
case stim::DemInstructionType::DEM_DETECTOR: {
std::vector<double> coord;
for (const stim::DemTarget& t : instruction.target_data) {
coord.push_back(t.val());
for (const double& t : instruction.arg_data) {
coord.push_back(t);
}
detector_coords.push_back(coord);
break;
Expand Down
Loading