From 150ffe69726476efe8d5e8401329a6d7d55c9502 Mon Sep 17 00:00:00 2001 From: hitonanode <32937551+hitonanode@users.noreply.github.com> Date: Sun, 10 Aug 2025 22:59:17 +0900 Subject: [PATCH] minimum steiner tree --- graph/minimum_steiner_tree.hpp | 113 +++++++++++++++++++++++ graph/minimum_steiner_tree.md | 21 +++++ graph/test/minimum_steiner_tree.test.cpp | 28 ++++++ 3 files changed, 162 insertions(+) create mode 100644 graph/minimum_steiner_tree.hpp create mode 100644 graph/minimum_steiner_tree.md create mode 100644 graph/test/minimum_steiner_tree.test.cpp diff --git a/graph/minimum_steiner_tree.hpp b/graph/minimum_steiner_tree.hpp new file mode 100644 index 00000000..89cc285d --- /dev/null +++ b/graph/minimum_steiner_tree.hpp @@ -0,0 +1,113 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +// Minimum Steiner tree of undirected connected graph +// n vertices, m edges, k terminals +// Complexity: O(3^k n + 2^k m log m) +// Verify: https://judge.yosupo.jp/problem/minimum_steiner_tree +template +std::pair> +MinimumSteinerTree(int n, const std::vector> &edges, + const std::vector &terminals) { + + if (n <= 1 or terminals.size() <= 1) return {T{}, {}}; + assert(!edges.empty()); + + std::vector>> to(n); + for (int i = 0; i < (int)edges.size(); ++i) { + auto [u, v, w] = edges[i]; + assert(w >= 0); + to.at(u).emplace_back(v, i, w); + to.at(v).emplace_back(u, i, w); + } + const int k = terminals.size(); + + std::vector dp(n << k); + std::vector prv(n << k, -1); + + auto f = [&](int i, int s) -> int { + assert(0 <= s and s < (1 << k)); + return (i << k) + s; + }; + + for (int i = 0; i < n; ++i) prv.at(f(i, 0)) = f(i, 0); + + for (int j = 0; j < k; ++j) { + const int i = terminals.at(j); + prv.at(f(i, 1 << j)) = f(i, 0); + } + + for (int s = 0; s < (1 << k); ++s) { + for (int i = 0; i < n; ++i) { + for (int t = (s - 1) & s; t; t = (t - 1) & s) { + if (prv.at(f(i, t)) == -1) continue; + if (prv.at(f(i, s ^ t)) == -1) continue; + const T new_cost = dp.at(f(i, t)) + dp.at(f(i, s ^ t)); + if (new_cost < dp.at(f(i, s)) or prv.at(f(i, s)) == -1) { + dp.at(f(i, s)) = new_cost; + prv.at(f(i, s)) = f(i, t); + assert(s >= t); + } + } + } + + using P = std::pair; + std::priority_queue, std::greater<>> pq; + + for (int i = 0; i < n; ++i) { + if (prv.at(f(i, s)) != -1) pq.emplace(dp.at(f(i, s)), i); + } + + while (!pq.empty()) { + auto [cost, i] = pq.top(); + pq.pop(); + if (dp.at(f(i, s)) < cost) continue; + for (auto [j, edge_id, w] : to.at(i)) { + if (prv.at(f(j, s)) != -1 and dp.at(f(j, s)) <= cost + w) continue; + dp.at(f(j, s)) = cost + w; + prv.at(f(j, s)) = f(edge_id, s); + pq.emplace(dp.at(f(j, s)), j); + } + } + } + + T ans = dp.at(f(0, (1 << k) - 1)); + int argmin = 0; + for (int i = 1; i < n; ++i) { + if (dp.at(f(i, (1 << k) - 1)) < ans) { + ans = dp.at(f(i, (1 << k) - 1)); + argmin = i; + } + } + + std::vector used_edges; + + auto rec = [&](auto &&self, int cur) -> void { + const int mask = cur & ((1 << k) - 1); + if (!mask) return; + const int i = cur >> k; + const int prv_mask = prv.at(cur) & ((1 << k) - 1); + if (prv_mask == 0) return; + + if (mask == prv_mask) { + const int edge_id = prv.at(cur) >> k; + used_edges.emplace_back(edge_id); + const int nxt = i ^ std::get<0>(edges.at(edge_id)) ^ std::get<1>(edges.at(edge_id)); + self(self, f(nxt, mask)); + } else { + self(self, f(i, prv_mask)); + self(self, f(i, mask ^ prv_mask)); + } + }; + rec(rec, f(argmin, (1 << k) - 1)); + + std::sort(used_edges.begin(), used_edges.end()); + + return {ans, used_edges}; +} diff --git a/graph/minimum_steiner_tree.md b/graph/minimum_steiner_tree.md new file mode 100644 index 00000000..69a46fec --- /dev/null +++ b/graph/minimum_steiner_tree.md @@ -0,0 +1,21 @@ +--- +title: Minimum Steiner tree (最小シュタイナー木) +documentation_of: ./minimum_steiner_tree.hpp +--- + +各辺重みが非負の $n$ 頂点 $m$ 辺無向グラフとその $k$ 個の頂点からなるターミナル集合を入力として,最小シュタイナー木を $O(3^k n + 2^k m \log m)$ で求める. + +## 使用方法 + +```cpp +int N; +vector> edges; +vector terminals; + +const auto [total_cost, used_edge_ids] = MinimumSteinerTree(N, edges, terminals); +``` + +## 問題例 + +- [Library Checker - Minimum Steiner Tree](https://judge.yosupo.jp/problem/minimum_steiner_tree) +- [No.114 遠い未来 - yukicoder](https://yukicoder.me/problems/no/114) diff --git a/graph/test/minimum_steiner_tree.test.cpp b/graph/test/minimum_steiner_tree.test.cpp new file mode 100644 index 00000000..c46855d9 --- /dev/null +++ b/graph/test/minimum_steiner_tree.test.cpp @@ -0,0 +1,28 @@ +#define PROBLEM "https://judge.yosupo.jp/problem/minimum_steiner_tree" +#include "../minimum_steiner_tree.hpp" + +#include +#include +#include +using namespace std; + +int main() { + cin.tie(nullptr), ios::sync_with_stdio(false); + + int N, M; + cin >> N >> M; + + vector> edges(M); + for (auto &[u, v, w] : edges) cin >> u >> v >> w; + + int K; + cin >> K; + vector terminals(K); + for (auto &t : terminals) cin >> t; + + const auto [cost, used_edges] = MinimumSteinerTree(N, edges, terminals); + cout << cost << ' ' << used_edges.size() << '\n'; + for (int i = 0; i < (int)used_edges.size(); ++i) { + cout << used_edges.at(i) << (i + 1 < (int)used_edges.size() ? ' ' : '\n'); + } +}