Skip to content

Commit 05bd941

Browse files
committed
wip: improve bandwidth limiting
1 parent 120263a commit 05bd941

File tree

17 files changed

+2301
-850
lines changed

17 files changed

+2301
-850
lines changed

Cargo.lock

Lines changed: 1285 additions & 672 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ version = "0.0.8"
1010

1111
[dependencies]
1212
anyhow = "1.0.86"
13-
askama = { version = "0.12.1", features = ["with-axum"] }
14-
askama_axum = "0.4.0"
13+
askama = { version = "0.12.1", features = ["with-axum"], optional = true }
14+
askama_axum = { version = "0.4.0", optional = true }
1515
atty = "0.2"
1616
aws-config = { version = "1.5.3", optional = true, features = [
1717
"behavior-version-latest",
@@ -20,17 +20,19 @@ aws-sdk-ec2 = { version = "1.55.0", optional = true }
2020
aws-sdk-cloudformation = { version = "1.0", optional = true }
2121
aws-sdk-route53 = { version = "1.0", optional = true }
2222
aws-sdk-ssm = { version = "1.0", optional = true }
23-
axum = "0.7.5"
23+
axum = { version = "0.7.5", optional = true }
2424
base64 = "0.22"
25-
chrono = "0.4"
25+
chrono = { version = "0.4", optional = true }
2626
clap = { version = "4.5.4", features = ["string", "derive", "env"] }
27-
rust-embed = { version = "8.4.0", features = ["axum", "debug-embed"] }
27+
rust-embed = { version = "8.4.0", features = ["axum", "debug-embed"], optional = true }
2828
futures = "0.3.30"
29-
indicatif = "0.17"
30-
mime_guess = "2.0"
31-
nix = { version = "0.29", features = ["net"] }
32-
rand = "0.8"
33-
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] }
29+
indicatif = "0.18.3"
30+
mime_guess = { version = "2.0", optional = true }
31+
nix = { version = "0.30.1", features = ["net"] }
32+
rand = "0.9.2"
33+
reqwest = { version = "0.12", default-features = false, features = [
34+
"rustls-tls",
35+
] }
3436
serde = { version = "1.0.203", features = ["derive"] }
3537
serde_json = "1.0.117"
3638
serde_yaml = "0.9.27"
@@ -39,12 +41,17 @@ tokio = { version = "1.38.0", features = ["full"] }
3941
tracing = "0.1.40"
4042
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
4143
url = "2.5"
44+
validator = { version = "0.18", features = ["derive"] }
4245

4346
[build-dependencies]
4447
anyhow = "1.0.86"
45-
reqwest = { version = "0.12.4", default-features = false, features = ["blocking", "rustls-tls"] }
48+
reqwest = { version = "0.12.4", default-features = false, features = [
49+
"blocking",
50+
"rustls-tls",
51+
] }
4652

4753
[features]
54+
default = ["aws", "dashboard"]
4855
cloudflare = []
4956
aws = [
5057
"dep:aws-config",
@@ -53,3 +60,27 @@ aws = [
5360
"dep:aws-sdk-route53",
5461
"dep:aws-sdk-ssm",
5562
]
63+
dashboard = [
64+
"dep:axum",
65+
"dep:askama",
66+
"dep:askama_axum",
67+
"dep:rust-embed",
68+
"dep:mime_guess",
69+
"dep:chrono",
70+
]
71+
72+
[[example]]
73+
name = "ui_aws"
74+
required-features = ["dashboard"]
75+
76+
[[example]]
77+
name = "ui_cloudflare"
78+
required-features = ["dashboard"]
79+
80+
[[example]]
81+
name = "ui_no_proxy"
82+
required-features = ["dashboard"]
83+
84+
[[example]]
85+
name = "ui_high_traffic"
86+
required-features = ["dashboard"]

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
FROM rust:slim as build
1+
FROM rust:slim AS build
22
RUN rustup target add x86_64-unknown-linux-musl
33

44
RUN apt-get update && apt-get install -y musl-tools

examples/ui_aws.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ async fn index() -> Html<String> {
3030
tunnel_stats: stats,
3131
proxy_info,
3232
cloudfront_info: None,
33+
upload_limit: Some(100),
34+
download_limit: Some(100),
3335
};
3436

3537
Html(template.render().unwrap())

examples/ui_cloudflare.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ async fn index() -> Html<String> {
2424
let template = IndexTemplate {
2525
tunnel_stats: stats,
2626
proxy_info,
27+
cloudfront_info: None,
28+
upload_limit: None,
29+
download_limit: None,
2730
};
2831

2932
Html(template.render().unwrap())

examples/ui_high_traffic.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ async fn index() -> Html<String> {
3030
tunnel_stats: stats,
3131
proxy_info,
3232
cloudfront_info: None,
33+
upload_limit: Some(1000),
34+
download_limit: Some(1000),
3335
};
3436

3537
Html(template.render().unwrap())

examples/ui_no_proxy.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ async fn index() -> Html<String> {
1818
let template = IndexTemplate {
1919
tunnel_stats: stats,
2020
proxy_info: None,
21+
cloudfront_info: None,
22+
upload_limit: None,
23+
download_limit: None,
2124
};
2225

2326
Html(template.render().unwrap())

shell.nix

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
{ pkgs ? import <nixpkgs> { } }:
1+
{ pkgs ? import (fetchTarball "https://github.com/NixOS/nixpkgs/archive/nixos-unstable.tar.gz") { } }:
22

33
with pkgs;
44

src/api.rs

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,27 @@ use tokio::sync::RwLock;
1414
#[folder = "assets/"]
1515
struct Assets;
1616

17-
#[derive(Clone, Default)]
17+
#[derive(Clone)]
1818
pub struct AppState {
1919
pub stats: Arc<RwLock<TunnelStats>>,
2020
pub proxy_info: Arc<RwLock<Option<ProxyInfo>>>,
2121
pub cloudfront_info: Arc<RwLock<Option<CloudFrontInfo>>>,
22+
pub tunnel: Arc<RwLock<Option<Arc<crate::wireguard::OriginTunnel>>>>,
23+
pub upload_limit: Option<u32>,
24+
pub download_limit: Option<u32>,
25+
}
26+
27+
impl Default for AppState {
28+
fn default() -> Self {
29+
Self {
30+
stats: Arc::new(RwLock::new(TunnelStats::default())),
31+
proxy_info: Arc::new(RwLock::new(None)),
32+
cloudfront_info: Arc::new(RwLock::new(None)),
33+
tunnel: Arc::new(RwLock::new(None)),
34+
upload_limit: None,
35+
download_limit: None,
36+
}
37+
}
2238
}
2339

2440
#[derive(Clone, Default)]
@@ -166,6 +182,8 @@ pub struct IndexTemplate {
166182
pub tunnel_stats: TunnelStats,
167183
pub proxy_info: Option<ProxyInfo>,
168184
pub cloudfront_info: Option<CloudFrontInfo>,
185+
pub upload_limit: Option<u32>,
186+
pub download_limit: Option<u32>,
169187
}
170188

171189
pub async fn assets(axum::extract::Path(file): axum::extract::Path<String>) -> Response {
@@ -188,12 +206,58 @@ pub async fn assets(axum::extract::Path(file): axum::extract::Path<String>) -> R
188206
pub fn router(state: AppState) -> Router {
189207
Router::new()
190208
.route("/", get(index))
209+
.route("/api/stats", get(stats_api))
191210
.route("/assets/*file", get(assets))
192211
.with_state(state)
193212
}
194213

214+
pub async fn stats_api(State(state): State<AppState>) -> impl IntoResponse {
215+
use axum::Json;
216+
use serde::Serialize;
217+
218+
#[derive(Serialize)]
219+
struct StatsResponse {
220+
bytes_sent: u64,
221+
bytes_received: u64,
222+
bytes_sent_formatted: String,
223+
bytes_received_formatted: String,
224+
packets_sent: u64,
225+
packets_received: u64,
226+
}
227+
228+
let mut stats = state.stats.read().await.clone();
229+
230+
// Update stats from iptables if tunnel is available
231+
if let Some(tunnel) = state.tunnel.read().await.as_ref() {
232+
if let Ok((bytes_sent, bytes_received)) = tunnel.get_traffic_stats().await {
233+
stats.bytes_sent = bytes_sent;
234+
stats.bytes_received = bytes_received;
235+
}
236+
}
237+
238+
stats.format_sizes();
239+
240+
Json(StatsResponse {
241+
bytes_sent: stats.bytes_sent,
242+
bytes_received: stats.bytes_received,
243+
bytes_sent_formatted: stats.bytes_sent_formatted,
244+
bytes_received_formatted: stats.bytes_received_formatted,
245+
packets_sent: stats.packets_sent,
246+
packets_received: stats.packets_received,
247+
})
248+
}
249+
195250
pub async fn index(State(state): State<AppState>) -> impl IntoResponse {
196251
let mut stats = state.stats.read().await.clone();
252+
253+
// Update stats from iptables if tunnel is available
254+
if let Some(tunnel) = state.tunnel.read().await.as_ref() {
255+
if let Ok((bytes_sent, bytes_received)) = tunnel.get_traffic_stats().await {
256+
stats.bytes_sent = bytes_sent;
257+
stats.bytes_received = bytes_received;
258+
}
259+
}
260+
197261
stats.format_sizes();
198262
let mut proxy_info = state.proxy_info.read().await.clone();
199263
let cloudfront_info = state.cloudfront_info.read().await.clone();
@@ -212,6 +276,8 @@ pub async fn index(State(state): State<AppState>) -> impl IntoResponse {
212276
tunnel_stats: stats,
213277
proxy_info,
214278
cloudfront_info,
279+
upload_limit: state.upload_limit,
280+
download_limit: state.download_limit,
215281
};
216282

217283
Html(template.render().unwrap())

src/aws/cloudformation.rs

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ pub struct CloudFormationTemplate {
55
pub stack_name: String,
66
pub region: String,
77
pub ingress_host: String,
8-
pub ingress_port: u16,
8+
pub ingress_port: u16, // Primary port (first ingress) for backwards compat
99
pub ingress_protocol: String,
10+
pub port_mappings: Vec<(u16, String)>, // All port mappings (port, protocol)
1011
pub origin_host: String,
1112
pub origin_port: u16,
1213
pub origin_ip: String,
@@ -334,15 +335,19 @@ impl CloudFormationTemplate {
334335
"CidrIp": format!("{}/32", self.origin_ip),
335336
"Description": "WireGuard from origin"
336337
}),
337-
json!({
338-
"IpProtocol": self.ingress_protocol.as_str(),
339-
"FromPort": self.ingress_port,
340-
"ToPort": self.ingress_port,
341-
"CidrIp": "0.0.0.0/0",
342-
"Description": "Ingress traffic"
343-
}),
344338
];
345339

340+
// Add rules for each port mapping
341+
for (port, protocol) in &self.port_mappings {
342+
rules.push(json!({
343+
"IpProtocol": protocol.to_lowercase(),
344+
"FromPort": port,
345+
"ToPort": port,
346+
"CidrIp": "0.0.0.0/0",
347+
"Description": format!("Ingress {} traffic on port {}", protocol.to_uppercase(), port)
348+
}));
349+
}
350+
346351
if self.debug {
347352
rules.push(json!({
348353
"IpProtocol": "tcp",
@@ -386,16 +391,32 @@ impl CloudFormationTemplate {
386391
.map(|s| format!("{}.0", s))
387392
.unwrap_or_else(|| "172.17.0.0".to_string());
388393

394+
// Generate Nix list expression for port mappings
395+
// Format: [ { port = 80; protocol = "tcp"; } { port = 443; protocol = "tcp"; } ]
396+
let port_mappings_nix = if self.port_mappings.is_empty() {
397+
"[ ]".to_string()
398+
} else {
399+
let mappings: Vec<String> = self.port_mappings
400+
.iter()
401+
.map(|(port, protocol)| {
402+
format!(
403+
"{{ port = {}; protocol = \"{}\"; }}",
404+
port,
405+
protocol.to_lowercase()
406+
)
407+
})
408+
.collect();
409+
format!("[\n {}\n ]", mappings.join("\n "))
410+
};
411+
389412
// Replace placeholders in the Nix template
390413
let nix_config = NIX_TEMPLATE
391414
.replace(
392415
"debug = false",
393416
&format!("debug = {}", if self.debug { "true" } else { "false" }),
394417
)
395418
.replace("{PROXY_WG_PRIVATE_KEY}", &self.proxy_wg_private_key)
396-
.replace("{PROTOCOL}", &self.ingress_protocol)
397-
.replace("{INGRESS_PORT}", &self.ingress_port.to_string())
398-
.replace("{ORIGIN_PORT}", &self.origin_port.to_string())
419+
.replace("{PORT_MAPPINGS}", &port_mappings_nix)
399420
.replace("{ORIGIN_WG_PUBLIC_KEY}", &self.origin_wg_public_key)
400421
.replace("{PRESHARED_KEY}", &self.preshared_key)
401422
.replace("{ORIGIN_IP}", &self.wg_origin_ip)
@@ -420,6 +441,7 @@ mod tests {
420441
ingress_host: "test.example.com".to_string(),
421442
ingress_port: 80,
422443
ingress_protocol: "tcp".to_string(),
444+
port_mappings: vec![(80, "tcp".to_string())],
423445
origin_host: "localhost".to_string(),
424446
origin_port: 8080,
425447
origin_ip: "1.2.3.4".to_string(),
@@ -445,6 +467,7 @@ mod tests {
445467
ingress_host: "test.example.com".to_string(),
446468
ingress_port: 80,
447469
ingress_protocol: "tcp".to_string(),
470+
port_mappings: vec![(80, "tcp".to_string())],
448471
origin_host: "localhost".to_string(),
449472
origin_port: 8080,
450473
origin_ip: "1.2.3.4".to_string(),
@@ -470,6 +493,7 @@ mod tests {
470493
ingress_host: "test.example.com".to_string(),
471494
ingress_port: 80,
472495
ingress_protocol: "tcp".to_string(),
496+
port_mappings: vec![(80, "tcp".to_string())],
473497
origin_host: "localhost".to_string(),
474498
origin_port: 8080,
475499
origin_ip: "1.2.3.4".to_string(),
@@ -486,10 +510,11 @@ mod tests {
486510

487511
let userdata = template.generate_userdata();
488512
let userdata_str = serde_json::to_string(&userdata).unwrap();
489-
// Check for NixOS configuration syntax and TCP protocol
513+
// Check for NixOS configuration syntax and port mappings
490514
assert!(userdata_str.contains("{ config, pkgs, lib, ... }:"));
491515
assert!(userdata_str.contains("debug = false"));
492-
assert!(userdata_str.contains("-p tcp --dport 80"));
516+
assert!(userdata_str.contains("port = 80"));
517+
assert!(userdata_str.contains("protocol = \\\"tcp\\\""));
493518
}
494519

495520
#[test]
@@ -500,6 +525,7 @@ mod tests {
500525
ingress_host: "test.example.com".to_string(),
501526
ingress_port: 53,
502527
ingress_protocol: "udp".to_string(),
528+
port_mappings: vec![(53, "udp".to_string())],
503529
origin_host: "localhost".to_string(),
504530
origin_port: 53,
505531
origin_ip: "1.2.3.4".to_string(),
@@ -516,10 +542,11 @@ mod tests {
516542

517543
let userdata = template.generate_userdata();
518544
let userdata_str = serde_json::to_string(&userdata).unwrap();
519-
// Check for NixOS configuration syntax and UDP protocol
545+
// Check for NixOS configuration syntax and port mappings
520546
assert!(userdata_str.contains("{ config, pkgs, lib, ... }:"));
521547
assert!(userdata_str.contains("debug = false"));
522-
assert!(userdata_str.contains("-p udp --dport 53"));
548+
assert!(userdata_str.contains("port = 53"));
549+
assert!(userdata_str.contains("protocol = \\\"udp\\\""));
523550
}
524551

525552
#[test]
@@ -530,6 +557,7 @@ mod tests {
530557
ingress_host: "test.example.com".to_string(),
531558
ingress_port: 80,
532559
ingress_protocol: "tcp".to_string(),
560+
port_mappings: vec![(80, "tcp".to_string())],
533561
origin_host: "localhost".to_string(),
534562
origin_port: 8080,
535563
origin_ip: "1.2.3.4".to_string(),

0 commit comments

Comments
 (0)