Skip to content

Commit fb2f13b

Browse files
authored
Merge pull request #30 from noajshu/codex/refactor-det-order-creation-logic
Expose detector order creation
2 parents f57496c + 9776627 commit fb2f13b

File tree

5 files changed

+116
-89
lines changed

5 files changed

+116
-89
lines changed

src/py/utils_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@ def test_build_detector_graph():
4444
]
4545

4646

47+
def test_build_det_orders():
48+
assert tesseract_decoder.utils.build_det_orders(
49+
_DETECTOR_ERROR_MODEL, num_det_orders=1, seed=0
50+
) == [[0, 1]]
51+
52+
53+
def test_build_det_orders_no_bfs():
54+
assert tesseract_decoder.utils.build_det_orders(
55+
_DETECTOR_ERROR_MODEL, num_det_orders=1, det_order_bfs=False, seed=0
56+
) == [[0, 1]]
57+
58+
4759
def test_get_errors_from_dem():
4860
expected = "Error{cost=1.945910, symptom=Symptom{D0 }}, Error{cost=0.510826, symptom=Symptom{D0 D1 }}, Error{cost=1.098612, symptom=Symptom{D1 }}"
4961
assert (

src/tesseract_main.cc

Lines changed: 3 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -170,12 +170,8 @@ struct Args {
170170

171171
// Sample orientations of the error model to use for the det priority
172172
{
173-
config.det_orders.resize(num_det_orders);
174-
std::mt19937_64 rng(det_order_seed);
175-
std::normal_distribution<double> dist(/*mean=*/0, /*stddev=*/1);
176-
177-
std::vector<std::vector<double>> detector_coords = get_detector_coords(config.dem);
178173
if (verbose) {
174+
auto detector_coords = get_detector_coords(config.dem);
179175
for (size_t d = 0; d < detector_coords.size(); ++d) {
180176
std::cout << "Detector D" << d << " coordinate (";
181177
size_t e = std::min(3ul, detector_coords[d].size());
@@ -186,88 +182,8 @@ struct Args {
186182
std::cout << ")" << std::endl;
187183
}
188184
}
189-
190-
if (det_order_bfs) {
191-
auto graph = build_detector_graph(config.dem);
192-
std::uniform_int_distribution<size_t> dist_det(0, graph.size() - 1);
193-
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
194-
std::vector<size_t> perm;
195-
perm.reserve(graph.size());
196-
std::vector<char> visited(graph.size(), false);
197-
std::queue<size_t> q;
198-
size_t start = dist_det(rng);
199-
while (perm.size() < graph.size()) {
200-
if (!visited[start]) {
201-
visited[start] = true;
202-
q.push(start);
203-
perm.push_back(start);
204-
}
205-
while (!q.empty()) {
206-
size_t cur = q.front();
207-
q.pop();
208-
auto neigh = graph[cur];
209-
std::shuffle(neigh.begin(), neigh.end(), rng);
210-
for (size_t n : neigh) {
211-
if (!visited[n]) {
212-
visited[n] = true;
213-
q.push(n);
214-
perm.push_back(n);
215-
}
216-
}
217-
}
218-
if (perm.size() < graph.size()) {
219-
do {
220-
start = dist_det(rng);
221-
} while (visited[start]);
222-
}
223-
}
224-
std::vector<size_t> inv_perm(graph.size());
225-
for (size_t i = 0; i < perm.size(); ++i) {
226-
inv_perm[perm[i]] = i;
227-
}
228-
config.det_orders[det_order] = inv_perm;
229-
}
230-
} else {
231-
std::vector<double> inner_products(config.dem.count_detectors());
232-
233-
if (!detector_coords.size() || !detector_coords.at(0).size()) {
234-
// If there are no detector coordinates, just use the standard
235-
// ordering of the indices.
236-
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
237-
config.det_orders[det_order].resize(config.dem.count_detectors());
238-
std::iota(config.det_orders[det_order].begin(), config.det_orders[det_order].end(), 0);
239-
}
240-
241-
} else {
242-
// Use the coordinates to order the detectors based on a random
243-
// orientation
244-
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
245-
// Sample a direction
246-
std::vector<double> orientation_vector;
247-
for (size_t i = 0; i < detector_coords.at(0).size(); ++i) {
248-
orientation_vector.push_back(dist(rng));
249-
}
250-
251-
for (size_t i = 0; i < detector_coords.size(); ++i) {
252-
inner_products[i] = 0;
253-
for (size_t j = 0; j < orientation_vector.size(); ++j) {
254-
inner_products[i] += detector_coords[i][j] * orientation_vector[j];
255-
}
256-
}
257-
std::vector<size_t> perm(config.dem.count_detectors());
258-
std::iota(perm.begin(), perm.end(), 0);
259-
std::sort(perm.begin(), perm.end(), [&](const size_t& i, const size_t& j) {
260-
return inner_products[i] > inner_products[j];
261-
});
262-
// Invert the permutation
263-
std::vector<size_t> inv_perm(config.dem.count_detectors());
264-
for (size_t i = 0; i < perm.size(); ++i) {
265-
inv_perm[perm[i]] = i;
266-
}
267-
config.det_orders[det_order] = inv_perm;
268-
}
269-
}
270-
}
185+
config.det_orders =
186+
build_det_orders(config.dem, num_det_orders, det_order_bfs, det_order_seed);
271187
}
272188

273189
if (sample_num_shots > 0) {

src/utils.cc

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@
1919
#include <filesystem>
2020
#include <fstream>
2121
#include <iostream>
22+
#include <numeric>
23+
#include <queue>
2224
#include <random>
2325
#include <string>
2426

2527
#include "common.h"
2628
#include "stim.h"
2729

28-
std::vector<std::vector<double>> get_detector_coords(stim::DetectorErrorModel& dem) {
30+
std::vector<std::vector<double>> get_detector_coords(const stim::DetectorErrorModel& dem) {
2931
std::vector<std::vector<double>> detector_coords;
3032
for (const stim::DemInstruction& instruction : dem.flattened().instructions) {
3133
switch (instruction.type) {
@@ -79,6 +81,91 @@ std::vector<std::vector<size_t>> build_detector_graph(const stim::DetectorErrorM
7981
return neighbors;
8082
}
8183

84+
std::vector<std::vector<size_t>> build_det_orders(const stim::DetectorErrorModel& dem,
85+
size_t num_det_orders, bool det_order_bfs,
86+
uint64_t seed) {
87+
std::vector<std::vector<size_t>> det_orders(num_det_orders);
88+
std::mt19937_64 rng(seed);
89+
std::normal_distribution<double> dist(0, 1);
90+
91+
auto detector_coords = get_detector_coords(dem);
92+
93+
if (det_order_bfs) {
94+
auto graph = build_detector_graph(dem);
95+
std::uniform_int_distribution<size_t> dist_det(0, graph.size() - 1);
96+
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
97+
std::vector<size_t> perm;
98+
perm.reserve(graph.size());
99+
std::vector<char> visited(graph.size(), false);
100+
std::queue<size_t> q;
101+
size_t start = dist_det(rng);
102+
while (perm.size() < graph.size()) {
103+
if (!visited[start]) {
104+
visited[start] = true;
105+
q.push(start);
106+
perm.push_back(start);
107+
}
108+
while (!q.empty()) {
109+
size_t cur = q.front();
110+
q.pop();
111+
auto neigh = graph[cur];
112+
std::shuffle(neigh.begin(), neigh.end(), rng);
113+
for (size_t n : neigh) {
114+
if (!visited[n]) {
115+
visited[n] = true;
116+
q.push(n);
117+
perm.push_back(n);
118+
}
119+
}
120+
}
121+
if (perm.size() < graph.size()) {
122+
do {
123+
start = dist_det(rng);
124+
} while (visited[start]);
125+
}
126+
}
127+
std::vector<size_t> inv_perm(graph.size());
128+
for (size_t i = 0; i < perm.size(); ++i) {
129+
inv_perm[perm[i]] = i;
130+
}
131+
det_orders[det_order] = inv_perm;
132+
}
133+
} else {
134+
std::vector<double> inner_products(dem.count_detectors());
135+
if (!detector_coords.size() || !detector_coords.at(0).size()) {
136+
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
137+
det_orders[det_order].resize(dem.count_detectors());
138+
std::iota(det_orders[det_order].begin(), det_orders[det_order].end(), 0);
139+
}
140+
} else {
141+
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
142+
std::vector<double> orientation_vector;
143+
for (size_t i = 0; i < detector_coords.at(0).size(); ++i) {
144+
orientation_vector.push_back(dist(rng));
145+
}
146+
147+
for (size_t i = 0; i < detector_coords.size(); ++i) {
148+
inner_products[i] = 0;
149+
for (size_t j = 0; j < orientation_vector.size(); ++j) {
150+
inner_products[i] += detector_coords[i][j] * orientation_vector[j];
151+
}
152+
}
153+
std::vector<size_t> perm(dem.count_detectors());
154+
std::iota(perm.begin(), perm.end(), 0);
155+
std::sort(perm.begin(), perm.end(), [&](const size_t& i, const size_t& j) {
156+
return inner_products[i] > inner_products[j];
157+
});
158+
std::vector<size_t> inv_perm(dem.count_detectors());
159+
for (size_t i = 0; i < perm.size(); ++i) {
160+
inv_perm[perm[i]] = i;
161+
}
162+
det_orders[det_order] = inv_perm;
163+
}
164+
}
165+
}
166+
return det_orders;
167+
}
168+
82169
bool sampling_from_dem(uint64_t seed, size_t num_shots, stim::DetectorErrorModel dem,
83170
std::vector<stim::SparseShot>& shots) {
84171
stim::DemSampler<stim::MAX_BITWORD_WIDTH> sampler(dem, std::mt19937_64{seed}, num_shots);

src/utils.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,16 @@
2828

2929
constexpr const double EPSILON = 1e-7;
3030

31-
std::vector<std::vector<double>> get_detector_coords(stim::DetectorErrorModel& dem);
31+
std::vector<std::vector<double>> get_detector_coords(const stim::DetectorErrorModel& dem);
3232

3333
// Builds an adjacency list graph where two detectors share an edge iff an error
3434
// in the model activates them both.
3535
std::vector<std::vector<size_t>> build_detector_graph(const stim::DetectorErrorModel& dem);
3636

37+
std::vector<std::vector<size_t>> build_det_orders(const stim::DetectorErrorModel& dem,
38+
size_t num_det_orders, bool det_order_bfs = true,
39+
uint64_t seed = 0);
40+
3741
const double INF = std::numeric_limits<double>::infinity();
3842

3943
bool sampling_from_dem(uint64_t seed, size_t num_shots, stim::DetectorErrorModel dem,

src/utils.pybind.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ void add_utils_module(py::module &root) {
4242
return build_detector_graph(input_dem);
4343
},
4444
py::arg("dem"));
45+
m.def(
46+
"build_det_orders",
47+
[](py::object dem, size_t num_det_orders, bool det_order_bfs, uint64_t seed) {
48+
auto input_dem = parse_py_object<stim::DetectorErrorModel>(dem);
49+
return build_det_orders(input_dem, num_det_orders, det_order_bfs, seed);
50+
},
51+
py::arg("dem"), py::arg("num_det_orders"), py::arg("det_order_bfs") = true,
52+
py::arg("seed") = 0);
4553
m.def(
4654
"get_errors_from_dem",
4755
[](py::object dem) {

0 commit comments

Comments
 (0)