Skip to content

Commit e9a0774

Browse files
committed
Merge remote-tracking branch 'origin'
2 parents 41f7c3e + 18b12cd commit e9a0774

File tree

2 files changed

+112
-86
lines changed

2 files changed

+112
-86
lines changed

src/py/utils_test.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@
2727
"""
2828
)
2929

30+
_DETECTOR_ERROR_MODEL_10 = stim.DetectorErrorModel(
31+
"\n".join(f"error(0.1) D{i}" for i in range(10))
32+
)
33+
3034

3135
def test_module_has_global_constants():
3236
assert tesseract_decoder.utils.EPSILON <= 1e-7
@@ -44,7 +48,7 @@ def test_build_detector_graph():
4448
]
4549

4650

47-
def test_build_det_orders():
51+
def test_build_det_orders_bfs():
4852
assert tesseract_decoder.utils.build_det_orders(
4953
_DETECTOR_ERROR_MODEL, num_det_orders=1, seed=0
5054
) == [[0, 1]]
@@ -61,12 +65,14 @@ def test_build_det_orders_coordinate():
6165

6266
def test_build_det_orders_index():
6367
res = tesseract_decoder.utils.build_det_orders(
64-
_DETECTOR_ERROR_MODEL,
68+
_DETECTOR_ERROR_MODEL_10,
6569
num_det_orders=1,
6670
method=tesseract_decoder.utils.DetOrder.DetIndex,
6771
seed=0,
6872
)
69-
assert res == [[0, 1]] or res == [[1, 0]]
73+
expected_asc = list(range(10))
74+
expected_desc = list(range(9, -1, -1))
75+
assert res == [expected_asc] or res == [expected_desc]
7076

7177

7278
def test_get_errors_from_dem():

src/utils.cc

Lines changed: 103 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -82,104 +82,124 @@ std::vector<std::vector<size_t>> build_detector_graph(const stim::DetectorErrorM
8282
return neighbors;
8383
}
8484

85-
std::vector<std::vector<size_t>> build_det_orders(const stim::DetectorErrorModel& dem,
86-
size_t num_det_orders, DetOrder method,
87-
uint64_t seed) {
85+
static std::vector<std::vector<size_t>> build_det_orders_bfs(const stim::DetectorErrorModel& dem,
86+
size_t num_det_orders,
87+
std::mt19937_64& rng) {
8888
std::vector<std::vector<size_t>> det_orders(num_det_orders);
89-
std::mt19937_64 rng(seed);
90-
std::normal_distribution<double> dist(0, 1);
91-
92-
auto detector_coords = get_detector_coords(dem);
93-
94-
if (method == DetOrder::DetBFS) {
95-
auto graph = build_detector_graph(dem);
96-
std::uniform_int_distribution<size_t> dist_det(0, graph.size() - 1);
97-
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
98-
std::vector<size_t> perm;
99-
perm.reserve(graph.size());
100-
std::vector<char> visited(graph.size(), false);
101-
std::queue<size_t> q;
102-
size_t start = dist_det(rng);
103-
while (perm.size() < graph.size()) {
104-
if (!visited[start]) {
105-
visited[start] = true;
106-
q.push(start);
107-
perm.push_back(start);
108-
}
109-
while (!q.empty()) {
110-
size_t cur = q.front();
111-
q.pop();
112-
auto neigh = graph[cur];
113-
std::shuffle(neigh.begin(), neigh.end(), rng);
114-
for (size_t n : neigh) {
115-
if (!visited[n]) {
116-
visited[n] = true;
117-
q.push(n);
118-
perm.push_back(n);
119-
}
89+
auto graph = build_detector_graph(dem);
90+
std::uniform_int_distribution<size_t> dist_det(0, graph.size() - 1);
91+
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
92+
std::vector<size_t> perm;
93+
perm.reserve(graph.size());
94+
std::vector<char> visited(graph.size(), false);
95+
std::queue<size_t> q;
96+
size_t start = dist_det(rng);
97+
while (perm.size() < graph.size()) {
98+
if (!visited[start]) {
99+
visited[start] = true;
100+
q.push(start);
101+
perm.push_back(start);
102+
}
103+
while (!q.empty()) {
104+
size_t cur = q.front();
105+
q.pop();
106+
auto neigh = graph[cur];
107+
std::shuffle(neigh.begin(), neigh.end(), rng);
108+
for (size_t n : neigh) {
109+
if (!visited[n]) {
110+
visited[n] = true;
111+
q.push(n);
112+
perm.push_back(n);
120113
}
121114
}
122-
if (perm.size() < graph.size()) {
123-
do {
124-
start = dist_det(rng);
125-
} while (visited[start]);
126-
}
127115
}
128-
std::vector<size_t> inv_perm(graph.size());
129-
for (size_t i = 0; i < perm.size(); ++i) {
130-
inv_perm[perm[i]] = i;
116+
if (perm.size() < graph.size()) {
117+
do {
118+
start = dist_det(rng);
119+
} while (visited[start]);
131120
}
132-
det_orders[det_order] = inv_perm;
133121
}
134-
} else if (method == DetOrder::DetCoordinate) {
135-
std::vector<double> inner_products(dem.count_detectors());
136-
if (!detector_coords.size() || !detector_coords.at(0).size()) {
137-
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
138-
det_orders[det_order].resize(dem.count_detectors());
139-
std::iota(det_orders[det_order].begin(), det_orders[det_order].end(), 0);
140-
}
141-
} else {
142-
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
143-
std::vector<double> orientation_vector;
144-
for (size_t i = 0; i < detector_coords.at(0).size(); ++i) {
145-
orientation_vector.push_back(dist(rng));
146-
}
122+
std::vector<size_t> inv_perm(graph.size());
123+
for (size_t i = 0; i < perm.size(); ++i) {
124+
inv_perm[perm[i]] = i;
125+
}
126+
det_orders[det_order] = inv_perm;
127+
}
128+
return det_orders;
129+
}
147130

148-
for (size_t i = 0; i < detector_coords.size(); ++i) {
149-
inner_products[i] = 0;
150-
for (size_t j = 0; j < orientation_vector.size(); ++j) {
151-
inner_products[i] += detector_coords[i][j] * orientation_vector[j];
152-
}
153-
}
154-
std::vector<size_t> perm(dem.count_detectors());
155-
std::iota(perm.begin(), perm.end(), 0);
156-
std::sort(perm.begin(), perm.end(), [&](const size_t& i, const size_t& j) {
157-
return inner_products[i] > inner_products[j];
158-
});
159-
std::vector<size_t> inv_perm(dem.count_detectors());
160-
for (size_t i = 0; i < perm.size(); ++i) {
161-
inv_perm[perm[i]] = i;
162-
}
163-
det_orders[det_order] = inv_perm;
131+
static std::vector<std::vector<size_t>> build_det_orders_coordinate(
132+
const stim::DetectorErrorModel& dem, size_t num_det_orders, std::mt19937_64& rng) {
133+
std::vector<std::vector<size_t>> det_orders(num_det_orders);
134+
auto detector_coords = get_detector_coords(dem);
135+
std::vector<double> inner_products(dem.count_detectors());
136+
std::normal_distribution<double> dist(0, 1);
137+
if (detector_coords.empty() || detector_coords.at(0).empty()) {
138+
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
139+
det_orders[det_order].resize(dem.count_detectors());
140+
std::iota(det_orders[det_order].begin(), det_orders[det_order].end(), 0);
141+
}
142+
return det_orders;
143+
}
144+
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
145+
std::vector<double> orientation_vector;
146+
for (size_t i = 0; i < detector_coords.at(0).size(); ++i) {
147+
orientation_vector.push_back(dist(rng));
148+
}
149+
for (size_t i = 0; i < detector_coords.size(); ++i) {
150+
inner_products[i] = 0;
151+
for (size_t j = 0; j < orientation_vector.size(); ++j) {
152+
inner_products[i] += detector_coords[i][j] * orientation_vector[j];
164153
}
165154
}
166-
} else if (method == DetOrder::DetIndex) {
167-
std::uniform_int_distribution<int> dist_bool(0, 1);
168-
size_t n = dem.count_detectors();
169-
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
170-
det_orders[det_order].resize(n);
171-
if (dist_bool(rng)) {
172-
for (size_t i = 0; i < n; ++i) {
173-
det_orders[det_order][i] = n - 1 - i;
174-
}
175-
} else {
176-
std::iota(det_orders[det_order].begin(), det_orders[det_order].end(), 0);
155+
std::vector<size_t> perm(dem.count_detectors());
156+
std::iota(perm.begin(), perm.end(), 0);
157+
std::sort(perm.begin(), perm.end(), [&](const size_t& i, const size_t& j) {
158+
return inner_products[i] > inner_products[j];
159+
});
160+
std::vector<size_t> inv_perm(dem.count_detectors());
161+
for (size_t i = 0; i < perm.size(); ++i) {
162+
inv_perm[perm[i]] = i;
163+
}
164+
det_orders[det_order] = inv_perm;
165+
}
166+
return det_orders;
167+
}
168+
169+
static std::vector<std::vector<size_t>> build_det_orders_index(const stim::DetectorErrorModel& dem,
170+
size_t num_det_orders,
171+
std::mt19937_64& rng) {
172+
std::vector<std::vector<size_t>> det_orders(num_det_orders);
173+
std::uniform_int_distribution<int> dist_bool(0, 1);
174+
size_t n = dem.count_detectors();
175+
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
176+
det_orders[det_order].resize(n);
177+
if (dist_bool(rng)) {
178+
for (size_t i = 0; i < n; ++i) {
179+
det_orders[det_order][i] = n - 1 - i;
177180
}
181+
} else {
182+
std::iota(det_orders[det_order].begin(), det_orders[det_order].end(), 0);
178183
}
179184
}
180185
return det_orders;
181186
}
182187

188+
std::vector<std::vector<size_t>> build_det_orders(const stim::DetectorErrorModel& dem,
189+
size_t num_det_orders, DetOrder method,
190+
uint64_t seed) {
191+
std::mt19937_64 rng(seed);
192+
switch (method) {
193+
case DetOrder::DetBFS:
194+
return build_det_orders_bfs(dem, num_det_orders, rng);
195+
case DetOrder::DetCoordinate:
196+
return build_det_orders_coordinate(dem, num_det_orders, rng);
197+
case DetOrder::DetIndex:
198+
return build_det_orders_index(dem, num_det_orders, rng);
199+
}
200+
throw std::invalid_argument("Unknown det order method");
201+
}
202+
183203
bool sampling_from_dem(uint64_t seed, size_t num_shots, stim::DetectorErrorModel dem,
184204
std::vector<stim::SparseShot>& shots) {
185205
stim::DemSampler<stim::MAX_BITWORD_WIDTH> sampler(dem, std::mt19937_64{seed}, num_shots);

0 commit comments

Comments
 (0)