Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions src/py/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,25 @@ def test_build_det_orders():
) == [[0, 1]]


def test_build_det_orders_no_bfs():
def test_build_det_orders_coordinate():
assert tesseract_decoder.utils.build_det_orders(
_DETECTOR_ERROR_MODEL, num_det_orders=1, det_order_bfs=False, seed=0
_DETECTOR_ERROR_MODEL,
num_det_orders=1,
method=tesseract_decoder.utils.DetOrder.DetCoordinate,
seed=0,
) == [[0, 1]]


def test_build_det_orders_index():
res = tesseract_decoder.utils.build_det_orders(
_DETECTOR_ERROR_MODEL,
num_det_orders=1,
method=tesseract_decoder.utils.DetOrder.DetIndex,
seed=0,
)
assert res == [[0, 1]] or res == [[1, 0]]


def test_get_errors_from_dem():
expected = "Error{cost=1.945910, symptom=Symptom{detectors=[0], observables=[]}}, Error{cost=0.510826, symptom=Symptom{detectors=[0 1], observables=[]}}, Error{cost=1.098612, symptom=Symptom{detectors=[1], observables=[]}}"
assert (
Expand Down
4 changes: 2 additions & 2 deletions src/tesseract.pybind.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ TesseractConfig tesseract_config_maker_no_dem(
double det_penalty = 0.0, bool create_visualization = false) {
stim::DetectorErrorModel empty_dem;
if (det_orders.empty()) {
det_orders = build_det_orders(empty_dem, 20, /*det_order_bfs=*/true, 2384753);
det_orders = build_det_orders(empty_dem, 20, DetOrder::DetBFS, 2384753);
}
return TesseractConfig({empty_dem, det_beam, beam_climbing, no_revisit_dets,
at_most_two_errors_per_detector, verbose, merge_errors, pqlimit,
Expand All @@ -57,7 +57,7 @@ TesseractConfig tesseract_config_maker(
double det_penalty = 0.0, bool create_visualization = false) {
stim::DetectorErrorModel input_dem = parse_py_object<stim::DetectorErrorModel>(dem);
if (det_orders.empty()) {
det_orders = build_det_orders(input_dem, 20, true, 2384753);
det_orders = build_det_orders(input_dem, 20, DetOrder::DetBFS, 2384753);
}
return TesseractConfig({input_dem, det_beam, beam_climbing, no_revisit_dets,
at_most_two_errors_per_detector, verbose, merge_errors, pqlimit,
Expand Down
44 changes: 27 additions & 17 deletions src/tesseract_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ struct Args {
// Manifold orientation options
uint64_t det_order_seed;
size_t num_det_orders = 10;
bool det_order_bfs = true;
bool det_order_bfs = false;
bool det_order_index = false;
bool det_order_coordinate = false;

// Sampling options
size_t sample_num_shots = 0;
Expand Down Expand Up @@ -88,6 +90,12 @@ struct Args {
throw std::invalid_argument("Must provide at least one of --circuit or --dem");
}

int det_order_flags = int(det_order_bfs) + int(det_order_index) + int(det_order_coordinate);
if (det_order_flags > 1) {
throw std::invalid_argument(
"Only one of --det-order-bfs, --det-order-index, or --det-order-coordinate may be set.");
}

int num_data_sources = int(sample_num_shots > 0) + int(!in_fname.empty());
if (num_data_sources != 1) {
throw std::invalid_argument("Requires exactly 1 source of shots.");
Expand Down Expand Up @@ -180,8 +188,13 @@ struct Args {
std::cout << ")" << std::endl;
}
}
config.det_orders =
build_det_orders(config.dem, num_det_orders, det_order_bfs, det_order_seed);
DetOrder order = DetOrder::DetBFS;
if (det_order_index) {
order = DetOrder::DetIndex;
} else if (det_order_coordinate) {
order = DetOrder::DetCoordinate;
}
config.det_orders = build_det_orders(config.dem, num_det_orders, order, det_order_seed);
}

if (sample_num_shots > 0) {
Expand Down Expand Up @@ -296,21 +309,18 @@ int main(int argc, char* argv[]) {
.metavar("N")
.default_value(size_t(1))
.store_into(args.num_det_orders);
program.add_argument("--no-det-order-bfs")
.help("Disable BFS-based detector ordering and use geometric orientation")
.default_value(true)
.implicit_value(false)
.store_into(args.det_order_bfs);
program.add_argument("--det-order-bfs")
.action([&](auto const&) {
std::cout << "BFS-based detector ordering is the default now; "
"--det-order-bfs is ignored."
<< std::endl;
})
.default_value(true)
.implicit_value(true)
.store_into(args.det_order_bfs)
.hidden();
.help("Use BFS-based detector ordering (default if no method specified)")
.flag()
.store_into(args.det_order_bfs);
program.add_argument("--det-order-index")
.help("Randomly choose increasing or decreasing detector index order")
.flag()
.store_into(args.det_order_index);
program.add_argument("--det-order-coordinate")
.help("Random geometric detector orientation ordering")
.flag()
.store_into(args.det_order_coordinate);
program.add_argument("--det-order-seed")
.help(
"Seed used when initializing the random detector traversal "
Expand Down
19 changes: 16 additions & 3 deletions src/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,15 @@ std::vector<std::vector<size_t>> build_detector_graph(const stim::DetectorErrorM
}

std::vector<std::vector<size_t>> build_det_orders(const stim::DetectorErrorModel& dem,
size_t num_det_orders, bool det_order_bfs,
size_t num_det_orders, DetOrder method,
uint64_t seed) {
std::vector<std::vector<size_t>> det_orders(num_det_orders);
std::mt19937_64 rng(seed);
std::normal_distribution<double> dist(0, 1);

auto detector_coords = get_detector_coords(dem);

if (det_order_bfs) {
if (method == DetOrder::DetBFS) {
auto graph = build_detector_graph(dem);
std::uniform_int_distribution<size_t> dist_det(0, graph.size() - 1);
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
Expand Down Expand Up @@ -131,7 +131,7 @@ std::vector<std::vector<size_t>> build_det_orders(const stim::DetectorErrorModel
}
det_orders[det_order] = inv_perm;
}
} else {
} else if (method == DetOrder::DetCoordinate) {
std::vector<double> inner_products(dem.count_detectors());
if (!detector_coords.size() || !detector_coords.at(0).size()) {
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
Expand Down Expand Up @@ -163,6 +163,19 @@ std::vector<std::vector<size_t>> build_det_orders(const stim::DetectorErrorModel
det_orders[det_order] = inv_perm;
}
}
} else if (method == DetOrder::DetIndex) {
std::uniform_int_distribution<int> dist_bool(0, 1);
size_t n = dem.count_detectors();
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
det_orders[det_order].resize(n);
if (dist_bool(rng)) {
for (size_t i = 0; i < n; ++i) {
det_orders[det_order][i] = n - 1 - i;
}
} else {
std::iota(det_orders[det_order].begin(), det_orders[det_order].end(), 0);
}
}
}
return det_orders;
}
Expand Down
5 changes: 4 additions & 1 deletion src/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@ std::vector<std::vector<double>> get_detector_coords(const stim::DetectorErrorMo
// in the model activates them both.
std::vector<std::vector<size_t>> build_detector_graph(const stim::DetectorErrorModel& dem);

enum class DetOrder { DetBFS, DetIndex, DetCoordinate };

std::vector<std::vector<size_t>> build_det_orders(const stim::DetectorErrorModel& dem,
size_t num_det_orders, bool det_order_bfs = true,
size_t num_det_orders,
DetOrder method = DetOrder::DetBFS,
uint64_t seed = 0);

const double INF = std::numeric_limits<double>::infinity();
Expand Down
24 changes: 16 additions & 8 deletions src/utils.pybind.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ void add_utils_module(py::module& root) {
m.attr("INF") = INF;
m.doc() = "A representation of infinity for floating point numbers.";

py::enum_<DetOrder>(m, "DetOrder", "Detector ordering methods")
.value("DetBFS", DetOrder::DetBFS)
.value("DetIndex", DetOrder::DetIndex)
.value("DetCoordinate", DetOrder::DetCoordinate)
.export_values();

m.def(
"get_detector_coords",
[](py::object dem) {
Expand Down Expand Up @@ -79,11 +85,11 @@ void add_utils_module(py::module& root) {
)pbdoc");
m.def(
"build_det_orders",
[](py::object dem, size_t num_det_orders, bool det_order_bfs, uint64_t seed) {
[](py::object dem, size_t num_det_orders, DetOrder method, uint64_t seed) {
auto input_dem = parse_py_object<stim::DetectorErrorModel>(dem);
return build_det_orders(input_dem, num_det_orders, det_order_bfs, seed);
return build_det_orders(input_dem, num_det_orders, method, seed);
},
py::arg("dem"), py::arg("num_det_orders"), py::arg("det_order_bfs") = true,
py::arg("dem"), py::arg("num_det_orders"), py::arg("method") = DetOrder::DetBFS,
py::arg("seed") = 0, R"pbdoc(
Generates various detector orderings for decoding.

Expand All @@ -93,17 +99,19 @@ void add_utils_module(py::module& root) {
The detector error model to generate orders for.
num_det_orders : int
The number of detector orderings to generate.
det_order_bfs : bool, default=True
If True, uses a Breadth-First Search (BFS) to generate
the orders. If False, uses a randomized ordering.
method : tesseract_decoder.utils.DetOrder, default=tesseract_decoder.utils.DetOrder.DetBFS
Strategy for ordering detectors. ``DetBFS`` performs a breadth-first
traversal, ``DetCoordinate`` uses randomized geometric orientations,
and ``DetIndex`` chooses either increasing or decreasing detector
index order at random.
seed : int, default=0
A seed for the random number generator.

Returns
-------
list[list[int]]
A list of detector orderings. Each inner list is a
permutation of the detector indices.
A list of detector orderings. Each inner list maps a detector index
to its position in the ordering.
)pbdoc");
m.def(
"get_errors_from_dem",
Expand Down
Loading