Skip to content

Commit 7f3187f

Browse files
committed
Switch to index-based operations inside the search engine and add a small optimization.
(The comments have not been updated yet.)
1 parent dcad24d commit 7f3187f

File tree

1 file changed

+110
-73
lines changed

1 file changed

+110
-73
lines changed

src/search.cpp

Lines changed: 110 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@
3434
* - Space Complexity: O(1) (No additional memory allocated).
3535
*/
3636

37-
static inline int baseHeuristicFunc(int current_node, int goal_node)
37+
static inline int baseHeuristicFunc(size_t current_idx, size_t goal_idx, const graph& gdata)
3838
{
39-
int diff = current_node - goal_node;
40-
return (diff < 0) ? -diff : diff;
39+
int c_node = gdata.index_to_node[current_idx];
40+
int g_node = gdata.index_to_node[goal_idx];
41+
int diff = c_node - g_node;
42+
return diff < 0 ? -diff : diff;
4143
}
4244

4345
// ✅ function + comment verified.
@@ -71,27 +73,16 @@ static inline int baseHeuristicFunc(int current_node, int goal_node)
7173
* - Time Complexity: O(L) (L is the number of landmarks used; loop unrolling helps reduce overhead but remains linear).
7274
* - Space Complexity: O(1) (No caching is used).
7375
*/
74-
static inline int altHeuristicFunc(int current_node, int goal_node, const graph& gdata)
76+
static inline int altHeuristicFunc(size_t current_idx, size_t goal_idx, const graph& gdata)
7577
{
76-
if (current_node == goal_node) {
78+
if (current_idx == goal_idx) {
7779
return 0;
7880
}
79-
80-
auto it_cur = gdata.node_to_index.find(current_node);
81-
auto it_goal = gdata.node_to_index.find(goal_node);
82-
if (it_cur == gdata.node_to_index.end() || it_goal == gdata.node_to_index.end()) {
83-
return 0;
84-
}
85-
86-
size_t idx_c = it_cur->second;
87-
size_t idx_g = it_goal->second;
88-
89-
const auto& dm_c = gdata.dist_landmark[idx_c];
90-
const auto& dm_g = gdata.dist_landmark[idx_g];
81+
const auto &dm_c = gdata.dist_landmark[current_idx];
82+
const auto &dm_g = gdata.dist_landmark[goal_idx];
9183

9284
int best_val = 0;
9385
int count_l = (int)dm_c.size();
94-
9586
const int* dm_c_ptr = dm_c.data();
9687
const int* dm_g_ptr = dm_g.data();
9788

@@ -118,10 +109,6 @@ static inline int altHeuristicFunc(int current_node, int goal_node, const graph&
118109
}
119110

120111
#undef ALT_HEURISTIC_STEP
121-
122-
if (best_val < 0) {
123-
best_val = 0;
124-
}
125112
return best_val;
126113
}
127114

@@ -152,10 +139,10 @@ static inline int altHeuristicFunc(int current_node, int goal_node, const graph&
152139
* - Time Complexity: O(L) (if `altHeuristicFunc` is used), O(1) (for `baseHeuristicFunc`).
153140
* - Space Complexity: O(1) (No extra allocations performed).
154141
*/
155-
static inline int computeHeuristic(int current_node, int goal_node, const graph& gdata, const config& conf)
142+
static inline int computeHeuristic(size_t current_idx, size_t goal_idx, const graph& gdata, const config& conf)
156143
{
157-
return conf.use_alt ? altHeuristicFunc(current_node, goal_node, gdata)
158-
: baseHeuristicFunc(current_node, goal_node);
144+
return conf.use_alt ? altHeuristicFunc(current_idx, goal_idx, gdata)
145+
: baseHeuristicFunc(current_idx, goal_idx, gdata);
159146
}
160147

161148
// ✅ function + comment verified.
@@ -296,7 +283,7 @@ path_result findShortestPath(const graph& gdata, search_buffers& buffers, const
296283
PROFILE_STOP("buffers_update");
297284

298285
PROFILE_START("priority_queue_setup");
299-
using pq_item = std::pair<int, int>;
286+
using pq_item = std::pair<int, size_t>;
300287
auto cmp = [](const pq_item& a, const pq_item& b) {
301288
return a.first > b.first;
302289
};
@@ -305,8 +292,8 @@ path_result findShortestPath(const graph& gdata, search_buffers& buffers, const
305292
PROFILE_STOP("priority_queue_setup");
306293

307294
PROFILE_START("heuristic_computation");
308-
int h_start = computeHeuristic(start_node, end_node, gdata, conf);
309-
int h_end = computeHeuristic(end_node, start_node, gdata, conf);
295+
int h_start = computeHeuristic(start_idx, end_idx, gdata, conf);
296+
int h_end = computeHeuristic(end_idx, start_idx, gdata, conf);
310297
setHForward(buffers, start_idx, h_start);
311298
setHBackward(buffers, end_idx, h_end);
312299
PROFILE_STOP("heuristic_computation");
@@ -329,8 +316,7 @@ path_result findShortestPath(const graph& gdata, search_buffers& buffers, const
329316
if (hv >= 0) {
330317
return hv;
331318
}
332-
int real_node = gdata.index_to_node[idx];
333-
hv = computeHeuristic(real_node, end_node, gdata, conf);
319+
hv = computeHeuristic(idx, end_idx, gdata, conf);
334320
setHForward(buffers, idx, hv);
335321
return hv;
336322
};
@@ -342,31 +328,37 @@ path_result findShortestPath(const graph& gdata, search_buffers& buffers, const
342328
if (hv >= 0) {
343329
return hv;
344330
}
345-
int real_node = gdata.index_to_node[idx];
346-
hv = computeHeuristic(real_node, start_node, gdata, conf);
331+
hv = computeHeuristic(idx, start_idx, gdata, conf);
347332
setHBackward(buffers, idx, hv);
348333
return hv;
349334
};
350335
PROFILE_STOP("lambda_getBackwardH_setup");
351336

352-
auto expandForward = [&](int cur_idx, int cur_f) {
337+
auto expandForward = [&](int cur_idx_int, int cur_f) {
353338
PROFILE_START("expandForward_total");
339+
size_t cur_idx = (size_t)cur_idx_int;
340+
341+
int best_dist_snapshot = best_distance.load(std::memory_order_relaxed);
342+
int local_best_dist = best_dist_snapshot;
343+
int local_best_node = best_meet_node.load(std::memory_order_relaxed);
344+
354345
int cur_g = getDistFromStart(buffers, cur_idx);
355346
int h_val = getForwardH(cur_idx);
356347
if (cur_g + (int)(weight * h_val) < cur_f) {
357348
PROFILE_STOP("expandForward_total");
358349
return;
359350
}
360-
PROFILE_START("expandForward_neighbors");
361351

352+
PROFILE_START("expandForward_neighbors");
362353
size_t start_edge = gdata.offsets[cur_idx];
363-
size_t end_edge = ((static_cast<size_t>(cur_idx) + 1) < gdata.offsets.size()) ? gdata.offsets[static_cast<size_t>(cur_idx) + 1] : gdata.edges.size();
354+
size_t end_edge = ((cur_idx + 1) < gdata.offsets.size()) ? gdata.offsets[cur_idx + 1] : gdata.edges.size();
355+
356+
std::vector<pq_item> local_queue_insert;
357+
local_queue_insert.reserve(end_edge - start_edge);
364358

365-
int cnt = 0;
366-
int best_dist_snapshot = best_distance.load(std::memory_order_relaxed);
367359
for (size_t i = start_edge; i < end_edge; i++) {
368360
PROFILE_START("expandForward_iteration");
369-
if (++cnt % 4 == 0 && search_done.load(std::memory_order_relaxed)) {
361+
if (search_done.load(std::memory_order_relaxed)) {
370362
PROFILE_STOP("expandForward_iteration");
371363
break;
372364
}
@@ -375,95 +367,139 @@ path_result findShortestPath(const graph& gdata, search_buffers& buffers, const
375367
int cost = edge.weight;
376368
int new_g = cur_g + cost;
377369

378-
if (getDistFromEnd(buffers, nbr_idx) >= 0 && (new_g + getDistFromEnd(buffers, nbr_idx)) >= best_dist_snapshot) {
379-
PROFILE_STOP("expandForward_iteration");
380-
continue;
370+
if (getDistFromEnd(buffers, nbr_idx) >= 0) {
371+
int possible_dist = new_g + getDistFromEnd(buffers, nbr_idx);
372+
if (possible_dist >= local_best_dist) {
373+
PROFILE_STOP("expandForward_iteration");
374+
continue;
375+
}
381376
}
382-
377+
383378
int existing = getDistFromStart(buffers, nbr_idx);
384379
if (existing < 0 || new_g < existing) {
385380
setDistFromStart(buffers, nbr_idx, new_g);
386-
setParentForward(buffers, nbr_idx, { cur_idx, cost });
381+
setParentForward(buffers, nbr_idx, { (int)cur_idx, cost });
387382
int h_nbr = getForwardH(nbr_idx);
388383
int f_cost = new_g + (int)(weight * h_nbr);
389-
{
390-
std::lock_guard<std::mutex> lk(forward_mutex);
391-
forward_queue.push({ f_cost, nbr_idx });
392-
}
384+
local_queue_insert.push_back({ f_cost, nbr_idx });
393385
}
394-
if (getDistFromEnd(buffers, nbr_idx) >= 0) {
395-
int total_cost = new_g + getDistFromEnd(buffers, nbr_idx);
396-
if (total_cost < best_dist_snapshot) {
397-
best_distance.store(total_cost, std::memory_order_relaxed);
398-
best_meet_node.store(nbr_idx, std::memory_order_relaxed);
386+
387+
int dist_from_end = getDistFromEnd(buffers, nbr_idx);
388+
if (dist_from_end >= 0) {
389+
int total_cost = new_g + dist_from_end;
390+
if (total_cost < local_best_dist) {
391+
local_best_dist = total_cost;
392+
local_best_node = nbr_idx;
399393
}
400394
}
401395
PROFILE_STOP("expandForward_iteration");
402396
}
397+
398+
if (!local_queue_insert.empty()) {
399+
std::lock_guard<std::mutex> lk(forward_mutex);
400+
for (auto &ins : local_queue_insert) {
401+
forward_queue.push(ins);
402+
}
403+
}
404+
405+
if (local_best_dist < best_dist_snapshot) {
406+
int global_best = best_distance.load(std::memory_order_relaxed);
407+
if (local_best_dist < global_best) {
408+
best_distance.store(local_best_dist, std::memory_order_relaxed);
409+
best_meet_node.store(local_best_node, std::memory_order_relaxed);
410+
}
411+
}
412+
403413
PROFILE_STOP("expandForward_neighbors");
404414
PROFILE_STOP("expandForward_total");
405415
};
406416

407-
auto expandBackward = [&](int cur_idx, int cur_f) {
417+
auto expandBackward = [&](int cur_idx_int, int cur_f) {
408418
PROFILE_START("expandBackward_total");
419+
size_t cur_idx = (size_t)cur_idx_int;
420+
421+
int best_dist_snapshot = best_distance.load(std::memory_order_relaxed);
422+
int local_best_dist = best_dist_snapshot;
423+
int local_best_node = best_meet_node.load(std::memory_order_relaxed);
424+
409425
int cur_g = getDistFromEnd(buffers, cur_idx);
410426
int h_val = getBackwardH(cur_idx);
411427
if (cur_g + (int)(weight * h_val) < cur_f) {
412428
PROFILE_STOP("expandBackward_total");
413429
return;
414430
}
415-
PROFILE_START("expandBackward_neighbors");
416431

432+
PROFILE_START("expandBackward_neighbors");
417433
size_t start_edge = gdata.offsets[cur_idx];
418-
size_t end_edge = ((static_cast<size_t>(cur_idx) + 1) < gdata.offsets.size()) ? gdata.offsets[static_cast<size_t>(cur_idx) + 1] : gdata.edges.size();
419-
420-
int cnt = 0;
421-
int best_dist_snapshot = best_distance.load(std::memory_order_relaxed);
434+
size_t end_edge = ((cur_idx + 1) < gdata.offsets.size()) ? gdata.offsets[cur_idx + 1] : gdata.edges.size();
435+
436+
std::vector<pq_item> local_queue_insert;
437+
local_queue_insert.reserve(end_edge - start_edge);
438+
422439
for (size_t i = start_edge; i < end_edge; i++) {
423440
PROFILE_START("expandBackward_iteration");
424-
if (++cnt % 4 == 0 && search_done.load(std::memory_order_relaxed)) {
441+
if (search_done.load(std::memory_order_relaxed)) {
425442
PROFILE_STOP("expandBackward_iteration");
426443
break;
427444
}
428445
const auto& edge = gdata.edges[i];
429446
int nbr_idx = edge.target;
430447
int cost = edge.weight;
431448
int new_g = cur_g + cost;
432-
if (getDistFromStart(buffers, nbr_idx) >= 0 &&
433-
(new_g + getDistFromStart(buffers, nbr_idx)) >= best_dist_snapshot) {
434-
PROFILE_STOP("expandBackward_iteration");
435-
continue;
449+
450+
if (getDistFromStart(buffers, nbr_idx) >= 0) {
451+
int possible_dist = new_g + getDistFromStart(buffers, nbr_idx);
452+
if (possible_dist >= local_best_dist) {
453+
PROFILE_STOP("expandBackward_iteration");
454+
continue;
455+
}
436456
}
457+
437458
int existing = getDistFromEnd(buffers, nbr_idx);
438459
if (existing < 0 || new_g < existing) {
439460
setDistFromEnd(buffers, nbr_idx, new_g);
440-
setParentBackward(buffers, nbr_idx, { cur_idx, cost });
441-
int h_nbr = getBackwardH(nbr_idx);
461+
setParentBackward(buffers, nbr_idx, { (int)cur_idx, cost });
462+
int h_nbr = getBackwardH(nbr_idx);
442463
int f_cost = new_g + (int)(weight * h_nbr);
443-
{
444-
std::lock_guard<std::mutex> lk(backward_mutex);
445-
backward_queue.push({ f_cost, nbr_idx });
446-
}
464+
local_queue_insert.push_back({ f_cost, nbr_idx });
447465
}
466+
448467
if (getDistFromStart(buffers, nbr_idx) >= 0) {
449468
int total_cost = new_g + getDistFromStart(buffers, nbr_idx);
450-
if (total_cost < best_dist_snapshot) {
451-
best_distance.store(total_cost, std::memory_order_relaxed);
452-
best_meet_node.store(nbr_idx, std::memory_order_relaxed);
469+
if (total_cost < local_best_dist) {
470+
local_best_dist = total_cost;
471+
local_best_node = nbr_idx;
453472
}
454473
}
455474
PROFILE_STOP("expandBackward_iteration");
456475
}
476+
477+
if (!local_queue_insert.empty()) {
478+
std::lock_guard<std::mutex> lk(backward_mutex);
479+
for (auto &ins : local_queue_insert) {
480+
backward_queue.push(ins);
481+
}
482+
}
483+
484+
if (local_best_dist < best_dist_snapshot) {
485+
int global_best = best_distance.load(std::memory_order_relaxed);
486+
if (local_best_dist < global_best) {
487+
best_distance.store(local_best_dist, std::memory_order_relaxed);
488+
best_meet_node.store(local_best_node, std::memory_order_relaxed);
489+
}
490+
}
491+
457492
PROFILE_STOP("expandBackward_neighbors");
458493
PROFILE_STOP("expandBackward_total");
459494
};
460-
495+
461496
auto forwardThreadFunc = [&]() {
462497
PROFILE_START("forwardThread_total");
463498
while (!search_done.load(std::memory_order_relaxed)) {
464499
PROFILE_START("forwardThread_iteration");
465500
int cur_f, cur_idx;
466501
{
502+
std::lock_guard<std::mutex> lk(forward_mutex);
467503
if (forward_queue.empty()) {
468504
search_done.store(true, std::memory_order_relaxed);
469505
PROFILE_STOP("forwardThread_iteration");
@@ -492,6 +528,7 @@ path_result findShortestPath(const graph& gdata, search_buffers& buffers, const
492528
PROFILE_START("backwardThread_iteration");
493529
int cur_f, cur_idx;
494530
{
531+
std::lock_guard<std::mutex> lk(backward_mutex);
495532
if (backward_queue.empty()) {
496533
search_done.store(true, std::memory_order_relaxed);
497534
PROFILE_STOP("backwardThread_iteration");

0 commit comments

Comments
 (0)