diff --git a/Cargo.lock b/Cargo.lock index 22dfbac133..d3a3ef3bb8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4604,6 +4604,7 @@ dependencies = [ "rivet-logs", "rivet-metrics", "rivet-pools", + "rivet-runner-protocol", "rivet-runtime", "rustls 0.23.29", "rustls-pemfile 2.2.0", @@ -4642,7 +4643,6 @@ dependencies = [ "lazy_static", "moka", "once_cell", - "pegboard", "rand 0.8.5", "regex", "reqwest", @@ -4650,6 +4650,7 @@ dependencies = [ "rivet-config", "rivet-error", "rivet-metrics", + "rivet-runner-protocol", "rivet-runtime", "rivet-util", "rustls 0.23.29", @@ -4736,8 +4737,9 @@ name = "rivet-runner-protocol" version = "2.0.24-rc.1" dependencies = [ "anyhow", - "base64 0.22.1", "gasoline", + "hex", + "rand 0.8.5", "rivet-util", "serde", "serde_bare", diff --git a/engine/packages/guard-core/Cargo.toml b/engine/packages/guard-core/Cargo.toml index c5185cf29d..aa1382ecd4 100644 --- a/engine/packages/guard-core/Cargo.toml +++ b/engine/packages/guard-core/Cargo.toml @@ -25,17 +25,17 @@ hyper-util = { workspace = true, features = ["full"] } indoc.workspace = true lazy_static.workspace = true moka = { workspace = true, features = ["future"] } -pegboard.workspace = true rand.workspace = true regex.workspace = true rivet-api-builder.workspace = true rivet-config.workspace = true rivet-error.workspace = true rivet-metrics.workspace = true +rivet-runner-protocol.workspace = true rivet-runtime.workspace = true rivet-util.workspace = true -rustls.workspace = true rustls-pemfile.workspace = true +rustls.workspace = true serde_json.workspace = true serde.workspace = true tokio-rustls.workspace = true diff --git a/engine/packages/guard-core/src/custom_serve.rs b/engine/packages/guard-core/src/custom_serve.rs index 5bc1dcbcf1..99e2719dc4 100644 --- a/engine/packages/guard-core/src/custom_serve.rs +++ b/engine/packages/guard-core/src/custom_serve.rs @@ -3,7 +3,7 @@ use async_trait::async_trait; use bytes::Bytes; use http_body_util::Full; use hyper::{Request, Response}; -use pegboard::tunnel::id::RequestId; +use rivet_runner_protocol as protocol; use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame; use crate::WebSocketHandle; @@ -23,7 +23,7 @@ pub trait CustomServeTrait: Send + Sync { &self, req: Request>, request_context: &mut RequestContext, - request_id: RequestId, + request_id: protocol::RequestId, ) -> Result>; /// Handle a WebSocket connection after upgrade. Supports connection retries. @@ -34,7 +34,7 @@ pub trait CustomServeTrait: Send + Sync { _path: &str, _request_context: &mut RequestContext, // Identifies the websocket across retries. - _unique_request_id: RequestId, + _unique_request_id: protocol::RequestId, // True if this websocket is reconnecting after hibernation. _after_hibernation: bool, ) -> Result> { @@ -45,7 +45,7 @@ pub trait CustomServeTrait: Send + Sync { async fn handle_websocket_hibernation( &self, _websocket: WebSocketHandle, - _unique_request_id: RequestId, + _unique_request_id: protocol::RequestId, ) -> Result { bail!("service does not support websocket hibernation"); } diff --git a/engine/packages/guard-core/src/lib.rs b/engine/packages/guard-core/src/lib.rs index e2e03d8538..25f0d4a954 100644 --- a/engine/packages/guard-core/src/lib.rs +++ b/engine/packages/guard-core/src/lib.rs @@ -12,7 +12,6 @@ pub mod websocket_handle; pub use cert_resolver::CertResolverFn; pub use custom_serve::CustomServeTrait; -pub use pegboard::tunnel::id::{RequestId, generate_request_id}; pub use proxy_service::{ CacheKeyFn, MiddlewareFn, ProxyService, ProxyState, RouteTarget, RoutingFn, RoutingOutput, }; diff --git a/engine/packages/guard-core/src/proxy_service.rs b/engine/packages/guard-core/src/proxy_service.rs index 9c0f880620..ad2f8d84d1 100644 --- a/engine/packages/guard-core/src/proxy_service.rs +++ b/engine/packages/guard-core/src/proxy_service.rs @@ -14,7 +14,7 @@ use rivet_metrics::KeyValue; use rivet_util::Id; use serde_json; -use pegboard::tunnel::id::{RequestId, generate_request_id}; +use rivet_runner_protocol as protocol; use std::{ borrow::Cow, collections::{HashMap as StdHashMap, HashSet}, @@ -350,7 +350,7 @@ pub struct ProxyState { route_cache: RouteCache, rate_limiters: Cache<(Id, std::net::IpAddr), Arc>>, in_flight_counters: Cache<(Id, std::net::IpAddr), Arc>>, - inflight_requests: Arc>>, + in_flight_requests: Arc>>, port_type: PortType, clickhouse_inserter: Option, tasks: Arc, @@ -379,7 +379,7 @@ impl ProxyState { .max_capacity(10_000) .time_to_live(PROXY_STATE_CACHE_TTL) .build(), - inflight_requests: Arc::new(Mutex::new(HashSet::new())), + in_flight_requests: Arc::new(Mutex::new(HashSet::new())), port_type, clickhouse_inserter, tasks: TaskGroup::new(), @@ -603,7 +603,7 @@ impl ProxyState { ip_addr: std::net::IpAddr, actor_id: &Option, headers: &hyper::HeaderMap, - ) -> Result> { + ) -> Result> { // Check in-flight limit if actor_id is present if let Some(actor_id) = *actor_id { // Get actor-specific middleware config @@ -648,7 +648,7 @@ impl ProxyState { &self, ip_addr: std::net::IpAddr, actor_id: &Option, - request_id: RequestId, + request_id: protocol::RequestId, ) { // Release in-flight counter if actor_id is present if let Some(actor_id) = *actor_id { @@ -660,17 +660,17 @@ impl ProxyState { } // Release request ID - let mut requests = self.inflight_requests.lock().await; + let mut requests = self.in_flight_requests.lock().await; requests.remove(&request_id); } /// Generate a unique request ID that is not currently in flight - async fn generate_unique_request_id(&self) -> anyhow::Result { + async fn generate_unique_request_id(&self) -> Result { const MAX_TRIES: u32 = 100; - let mut requests = self.inflight_requests.lock().await; + let mut requests = self.in_flight_requests.lock().await; for attempt in 0..MAX_TRIES { - let request_id = generate_request_id(); + let request_id = protocol::util::generate_request_id(); // Check if this ID is already in use if !requests.contains(&request_id) { @@ -688,7 +688,7 @@ impl ProxyState { ); } - anyhow::bail!( + bail!( "failed to generate unique request id after {} attempts", MAX_TRIES ); @@ -2144,7 +2144,7 @@ impl ProxyService { .release_in_flight(client_ip, &actor_id, request_id) .await; - anyhow::Ok(()) + Ok(()) } .instrument(tracing::info_span!("handle_ws_task_custom_serve")), ); diff --git a/engine/packages/guard/Cargo.toml b/engine/packages/guard/Cargo.toml index 69692a15e4..f7d8307d43 100644 --- a/engine/packages/guard/Cargo.toml +++ b/engine/packages/guard/Cargo.toml @@ -37,6 +37,7 @@ rivet-guard-core.workspace = true rivet-logs.workspace = true rivet-metrics.workspace = true rivet-pools.workspace = true +rivet-runner-protocol.workspace = true rivet-runtime.workspace = true rustls-pemfile.workspace = true rustls.workspace = true diff --git a/engine/packages/guard/src/routing/api_public.rs b/engine/packages/guard/src/routing/api_public.rs index e7ac2342fa..b1f701c4f3 100644 --- a/engine/packages/guard/src/routing/api_public.rs +++ b/engine/packages/guard/src/routing/api_public.rs @@ -6,9 +6,9 @@ use bytes::Bytes; use gas::prelude::*; use http_body_util::{BodyExt, Full}; use hyper::{Request, Response}; -use pegboard::tunnel::id::RequestId; use rivet_guard_core::proxy_service::{ResponseBody, RoutingOutput}; use rivet_guard_core::{CustomServeTrait, request_context::RequestContext}; +use rivet_runner_protocol as protocol; use tower::Service; struct ApiPublicService { @@ -21,7 +21,7 @@ impl CustomServeTrait for ApiPublicService { &self, req: Request>, _request_context: &mut RequestContext, - _request_id: RequestId, + _request_id: protocol::RequestId, ) -> Result> { // Clone the router to get a mutable service let mut service = self.router.clone(); diff --git a/engine/packages/pegboard-gateway/src/keepalive_task.rs b/engine/packages/pegboard-gateway/src/keepalive_task.rs index 2172fbdcd1..1632cd68b0 100644 --- a/engine/packages/pegboard-gateway/src/keepalive_task.rs +++ b/engine/packages/pegboard-gateway/src/keepalive_task.rs @@ -1,8 +1,7 @@ use anyhow::Result; use gas::prelude::*; -use pegboard::tunnel::id as tunnel_id; -use pegboard::tunnel::id::{GatewayId, RequestId}; use rand::Rng; +use rivet_runner_protocol as protocol; use std::time::Duration; use tokio::sync::watch; @@ -17,8 +16,8 @@ pub async fn task( shared_state: SharedState, ctx: StandaloneCtx, actor_id: Id, - gateway_id: GatewayId, - request_id: RequestId, + gateway_id: protocol::GatewayId, + request_id: protocol::RequestId, mut keepalive_abort_rx: watch::Receiver<()>, ) -> Result { let mut ping_interval = tokio::time::interval(Duration::from_millis( @@ -44,8 +43,8 @@ pub async fn task( tracing::debug!( %actor_id, - gateway_id=%tunnel_id::gateway_id_to_string(&gateway_id), - request_id=%tunnel_id::request_id_to_string(&request_id), + gateway_id=%protocol::util::id_to_string(&gateway_id), + request_id=%protocol::util::id_to_string(&request_id), "updating hws keepalive" ); diff --git a/engine/packages/pegboard-gateway/src/lib.rs b/engine/packages/pegboard-gateway/src/lib.rs index 4af6ef1c6f..3a1a4476da 100644 --- a/engine/packages/pegboard-gateway/src/lib.rs +++ b/engine/packages/pegboard-gateway/src/lib.rs @@ -5,7 +5,6 @@ use futures_util::TryStreamExt; use gas::prelude::*; use http_body_util::{BodyExt, Full}; use hyper::{Request, Response, StatusCode}; -use pegboard::tunnel::id::{self as tunnel_id, RequestId}; use rivet_error::*; use rivet_guard_core::{ custom_serve::{CustomServeTrait, HibernationResult}, @@ -86,7 +85,7 @@ impl CustomServeTrait for PegboardGateway { &self, req: Request>, _request_context: &mut RequestContext, - request_id: RequestId, + request_id: protocol::RequestId, ) -> Result> { // Use the actor ID from the gateway instance let actor_id = self.actor_id.to_string(); @@ -213,7 +212,7 @@ impl CustomServeTrait for PegboardGateway { } } else { tracing::warn!( - request_id=%tunnel_id::request_id_to_string(&request_id), + request_id=%protocol::util::id_to_string(&request_id), "received no message response during request init", ); break; @@ -268,14 +267,14 @@ impl CustomServeTrait for PegboardGateway { Ok(response) } - #[tracing::instrument(skip_all, fields(actor_id=?self.actor_id, runner_id=?self.runner_id, request_id=%tunnel_id::request_id_to_string(&request_id)))] + #[tracing::instrument(skip_all, fields(actor_id=?self.actor_id, runner_id=?self.runner_id, request_id=%protocol::util::id_to_string(&request_id)))] async fn handle_websocket( &self, client_ws: WebSocketHandle, headers: &hyper::HeaderMap, _path: &str, _request_context: &mut RequestContext, - request_id: RequestId, + request_id: protocol::RequestId, after_hibernation: bool, ) -> Result> { // Use the actor ID from the gateway instance @@ -354,7 +353,7 @@ impl CustomServeTrait for PegboardGateway { } } else { tracing::warn!( - request_id=%tunnel_id::request_id_to_string(&request_id), + request_id=%protocol::util::id_to_string(&request_id), "received no message response during ws init", ); break; @@ -572,11 +571,11 @@ impl CustomServeTrait for PegboardGateway { } } - #[tracing::instrument(skip_all, fields(actor_id=?self.actor_id, request_id=%tunnel_id::request_id_to_string(&request_id)))] + #[tracing::instrument(skip_all, fields(actor_id=?self.actor_id, request_id=%protocol::util::id_to_string(&request_id)))] async fn handle_websocket_hibernation( &self, client_ws: WebSocketHandle, - request_id: RequestId, + request_id: protocol::RequestId, ) -> Result { // Immediately rewake if we have pending messages if self diff --git a/engine/packages/pegboard-gateway/src/shared_state.rs b/engine/packages/pegboard-gateway/src/shared_state.rs index 6c9d1bd395..cd5158b117 100644 --- a/engine/packages/pegboard-gateway/src/shared_state.rs +++ b/engine/packages/pegboard-gateway/src/shared_state.rs @@ -1,6 +1,5 @@ use anyhow::Result; use gas::prelude::*; -use pegboard::tunnel::id::{self as tunnel_id, GatewayId, RequestId}; use rivet_guard_core::errors::WebSocketServiceTimeout; use rivet_runner_protocol::{self as protocol, versioned, PROTOCOL_VERSION}; use scc::{hash_map::Entry, HashMap}; @@ -40,7 +39,7 @@ struct InFlightRequest { /// True once first message for this request has been sent (so runner learned reply_to). opened: bool, /// Message index counter for this request. - message_index: tunnel_id::MessageIndex, + message_index: protocol::MessageIndex, hibernation_state: Option, stopping: bool, last_pong: i64, @@ -56,14 +55,14 @@ struct HibernationState { pub struct PendingWebsocketMessage { payload: Vec, send_instant: Instant, - message_index: tunnel_id::MessageIndex, + message_index: protocol::MessageIndex, } pub struct SharedStateInner { ups: PubSub, - gateway_id: GatewayId, + gateway_id: protocol::GatewayId, receiver_subject: String, - in_flight_requests: HashMap, + in_flight_requests: HashMap, hibernation_timeout: i64, } @@ -72,7 +71,7 @@ pub struct SharedState(Arc); impl SharedState { pub fn new(config: &rivet_config::Config, ups: PubSub) -> Self { - let gateway_id = tunnel_id::generate_gateway_id(); + let gateway_id = protocol::util::generate_gateway_id(); let receiver_subject = pegboard::pubsub_subjects::GatewayReceiverSubject::new(gateway_id).to_string(); @@ -85,7 +84,7 @@ impl SharedState { })) } - pub fn gateway_id(&self) -> GatewayId { + pub fn gateway_id(&self) -> protocol::GatewayId { self.gateway_id } @@ -102,11 +101,11 @@ impl SharedState { Ok(()) } - #[tracing::instrument(skip_all, fields(%receiver_subject, request_id=%tunnel_id::request_id_to_string(&request_id)))] + #[tracing::instrument(skip_all, fields(%receiver_subject, request_id=%protocol::util::id_to_string(&request_id)))] pub async fn start_in_flight_request( &self, receiver_subject: String, - request_id: RequestId, + request_id: protocol::RequestId, ) -> InFlightRequestHandle { let (msg_tx, msg_rx) = mpsc::channel(128); let (drop_tx, drop_rx) = watch::channel(()); @@ -148,10 +147,10 @@ impl SharedState { } } - #[tracing::instrument(skip_all, fields(request_id=%tunnel_id::request_id_to_string(&request_id)))] + #[tracing::instrument(skip_all, fields(request_id=%protocol::util::id_to_string(&request_id)))] pub async fn send_message( &self, - request_id: RequestId, + request_id: protocol::RequestId, message_kind: protocol::ToClientTunnelMessageKind, ) -> Result<()> { let mut req = self @@ -161,8 +160,11 @@ impl SharedState { .context("request not in flight")?; // Generate message ID - let message_id = - tunnel_id::build_message_id(self.gateway_id, request_id, req.message_index)?; + let message_id = protocol::MessageId { + gateway_id: self.gateway_id, + request_id, + message_index: req.message_index, + }; // Increment message index for next message let current_message_index = req.message_index; @@ -185,8 +187,6 @@ impl SharedState { message_kind, }; - tracing::debug!(?message_id, ?payload, "shared state send message"); - // Send message let message = protocol::ToRunner::ToClientTunnelMessage(payload); let message_serialized = versioned::ToRunner::wrap_latest(message) @@ -226,8 +226,8 @@ impl SharedState { Ok(()) } - #[tracing::instrument(skip_all, fields(request_id=%tunnel_id::request_id_to_string(&request_id)))] - pub async fn send_and_check_ping(&self, request_id: RequestId) -> Result<()> { + #[tracing::instrument(skip_all, fields(request_id=%protocol::util::id_to_string(&request_id)))] + pub async fn send_and_check_ping(&self, request_id: protocol::RequestId) -> Result<()> { let req = self .in_flight_requests .get_async(&request_id) @@ -262,8 +262,8 @@ impl SharedState { Ok(()) } - #[tracing::instrument(skip_all, fields(request_id=%tunnel_id::request_id_to_string(&request_id)))] - pub async fn keepalive_hws(&self, request_id: RequestId) -> Result<()> { + #[tracing::instrument(skip_all, fields(request_id=%protocol::util::id_to_string(&request_id)))] + pub async fn keepalive_hws(&self, request_id: protocol::RequestId) -> Result<()> { let mut req = self .in_flight_requests .get_async(&request_id) @@ -293,7 +293,7 @@ impl SharedState { self.in_flight_requests.get_async(&pong.request_id).await else { tracing::debug!( - request_id=%tunnel_id::request_id_to_string(&pong.request_id), + request_id=%protocol::util::id_to_string(&pong.request_id), "in flight has already been disconnected, dropping ping" ); continue; @@ -306,22 +306,17 @@ impl SharedState { metrics::TUNNEL_PING_DURATION.record(rtt as f64 * 0.001, &[]); } Ok(protocol::ToGateway::ToServerTunnelMessage(msg)) => { - // Parse message ID to extract components - let parts = match tunnel_id::parse_message_id(msg.message_id) { - Ok(p) => p, - Err(err) => { - tracing::error!(?err, message_id=?msg.message_id, "failed to parse message id"); - continue; - } - }; + let message_id = msg.message_id; - let Some(in_flight) = - self.in_flight_requests.get_async(&parts.request_id).await + let Some(in_flight) = self + .in_flight_requests + .get_async(&message_id.request_id) + .await else { tracing::warn!( - gateway_id=%tunnel_id::gateway_id_to_string(&parts.gateway_id), - request_id=%tunnel_id::request_id_to_string(&parts.request_id), - message_index=parts.message_index, + gateway_id=%protocol::util::id_to_string(&message_id.gateway_id), + request_id=%protocol::util::id_to_string(&message_id.request_id), + message_index=message_id.message_index, "in flight has already been disconnected, dropping message" ); continue; @@ -329,9 +324,9 @@ impl SharedState { // Send message to the request handler to emulate the real network action tracing::debug!( - gateway_id=%tunnel_id::gateway_id_to_string(&parts.gateway_id), - request_id=%tunnel_id::request_id_to_string(&parts.request_id), - message_index=parts.message_index, + gateway_id=%protocol::util::id_to_string(&message_id.gateway_id), + request_id=%protocol::util::id_to_string(&message_id.request_id), + message_index=message_id.message_index, "forwarding message to request handler" ); let _ = in_flight.msg_tx.send(msg.message_kind.clone()).await; @@ -343,8 +338,12 @@ impl SharedState { } } - #[tracing::instrument(skip_all, fields(request_id=%tunnel_id::request_id_to_string(&request_id), %enable))] - pub async fn toggle_hibernation(&self, request_id: RequestId, enable: bool) -> Result<()> { + #[tracing::instrument(skip_all, fields(request_id=%protocol::util::id_to_string(&request_id), %enable))] + pub async fn toggle_hibernation( + &self, + request_id: protocol::RequestId, + enable: bool, + ) -> Result<()> { let mut req = self .in_flight_requests .get_async(&request_id) @@ -367,8 +366,11 @@ impl SharedState { Ok(()) } - #[tracing::instrument(skip_all, fields(request_id=%tunnel_id::request_id_to_string(&request_id)))] - pub async fn resend_pending_websocket_messages(&self, request_id: RequestId) -> Result<()> { + #[tracing::instrument(skip_all, fields(request_id=%protocol::util::id_to_string(&request_id)))] + pub async fn resend_pending_websocket_messages( + &self, + request_id: protocol::RequestId, + ) -> Result<()> { let Some(mut req) = self.in_flight_requests.get_async(&request_id).await else { bail!("request not in flight"); }; @@ -380,8 +382,6 @@ impl SharedState { tracing::debug!(len=?hs.pending_ws_msgs.len(), "resending pending messages"); for pending_msg in &hs.pending_ws_msgs { - tracing::info!(?pending_msg.payload, ?pending_msg.message_index, "------2---------"); - self.ups .publish(&receiver_subject, &pending_msg.payload, PublishOpts::one()) .await?; @@ -392,8 +392,11 @@ impl SharedState { Ok(()) } - #[tracing::instrument(skip_all, fields(request_id=%tunnel_id::request_id_to_string(&request_id)))] - pub async fn has_pending_websocket_messages(&self, request_id: RequestId) -> Result { + #[tracing::instrument(skip_all, fields(request_id=%protocol::util::id_to_string(&request_id)))] + pub async fn has_pending_websocket_messages( + &self, + request_id: protocol::RequestId, + ) -> Result { let Some(req) = self.in_flight_requests.get_async(&request_id).await else { bail!("request not in flight"); }; @@ -405,10 +408,10 @@ impl SharedState { } } - #[tracing::instrument(skip_all, fields(request_id=%tunnel_id::request_id_to_string(&request_id), %ack_index))] + #[tracing::instrument(skip_all, fields(request_id=%protocol::util::id_to_string(&request_id), %ack_index))] pub async fn ack_pending_websocket_messages( &self, - request_id: RequestId, + request_id: protocol::RequestId, ack_index: u16, ) -> Result<()> { let Some(mut req) = self.in_flight_requests.get_async(&request_id).await else { @@ -530,13 +533,13 @@ impl SharedState { if let Some(reason) = &reason { tracing::debug!( - request_id=%tunnel_id::request_id_to_string(request_id), + request_id=%protocol::util::id_to_string(request_id), ?reason, "gc stopping in flight request" ); if req.drop_tx.send(()).is_err() { - tracing::debug!(request_id=%tunnel_id::request_id_to_string(request_id), "failed to send timeout msg to tunnel"); + tracing::debug!(request_id=%protocol::util::id_to_string(request_id), "failed to send timeout msg to tunnel"); } // Mark req as stopping to skip this loop next time the gc is run @@ -554,7 +557,7 @@ impl SharedState { // When the websocket reconnects a new channel will be created if req.stopping && req.drop_tx.is_closed() { tracing::debug!( - request_id=%tunnel_id::request_id_to_string(request_id), + request_id=%protocol::util::id_to_string(request_id), "gc removing in flight request" ); diff --git a/engine/packages/pegboard-gateway/src/tunnel_to_ws_task.rs b/engine/packages/pegboard-gateway/src/tunnel_to_ws_task.rs index 571a2ea527..4d03b47287 100644 --- a/engine/packages/pegboard-gateway/src/tunnel_to_ws_task.rs +++ b/engine/packages/pegboard-gateway/src/tunnel_to_ws_task.rs @@ -1,6 +1,5 @@ use anyhow::Result; use gas::prelude::*; -use pegboard::tunnel::id as tunnel_id; use rivet_guard_core::{ WebSocketHandle, errors::{WebSocketServiceHibernate, WebSocketServiceTimeout, WebSocketServiceUnavailable}, @@ -39,7 +38,7 @@ pub async fn task( } protocol::ToServerTunnelMessageKind::ToServerWebSocketMessageAck(ack) => { tracing::debug!( - request_id=%tunnel_id::request_id_to_string(&request_id), + request_id=%protocol::util::id_to_string(&request_id), ack_index=?ack.index, "received WebSocketMessageAck from runner" ); @@ -74,7 +73,7 @@ pub async fn task( } } _ = drop_rx.changed() => { - tracing::warn!("websocket message timeout"); + tracing::warn!("garbage collected"); return Err(WebSocketServiceTimeout.build()); } _ = tunnel_to_ws_abort_rx.changed() => { diff --git a/engine/packages/pegboard-runner/src/conn.rs b/engine/packages/pegboard-runner/src/conn.rs index ba5e3ebb24..1565ea0a8d 100644 --- a/engine/packages/pegboard-runner/src/conn.rs +++ b/engine/packages/pegboard-runner/src/conn.rs @@ -1,3 +1,8 @@ +use std::{ + sync::{Arc, atomic::AtomicU32}, + time::Duration, +}; + use anyhow::Context; use futures_util::StreamExt; use gas::prelude::Id; @@ -5,12 +10,7 @@ use gas::prelude::*; use hyper_tungstenite::tungstenite::Message; use pegboard::ops::runner::update_alloc_idx::{Action, RunnerEligibility}; use rivet_guard_core::WebSocketHandle; -use rivet_runner_protocol as protocol; -use rivet_runner_protocol::*; -use std::{ - sync::{Arc, atomic::AtomicU32}, - time::Duration, -}; +use rivet_runner_protocol::{self as protocol, versioned}; use vbare::OwnedVersionedData; use crate::{errors::WsError, utils::UrlData}; diff --git a/engine/packages/pegboard-runner/src/lib.rs b/engine/packages/pegboard-runner/src/lib.rs index 0632f217de..d21d65b31b 100644 --- a/engine/packages/pegboard-runner/src/lib.rs +++ b/engine/packages/pegboard-runner/src/lib.rs @@ -5,11 +5,11 @@ use gas::prelude::*; use http_body_util::Full; use hyper::{Response, StatusCode}; use pegboard::ops::runner::update_alloc_idx::Action; -use pegboard::tunnel::id::RequestId; use rivet_guard_core::{ WebSocketHandle, custom_serve::CustomServeTrait, proxy_service::ResponseBody, request_context::RequestContext, }; +use rivet_runner_protocol as protocol; use std::time::Duration; use tokio::sync::watch; use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame; @@ -48,7 +48,7 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe { &self, _req: hyper::Request>, _request_context: &mut RequestContext, - _request_id: RequestId, + _request_id: protocol::RequestId, ) -> Result> { // Pegboard runner ws doesn't handle regular HTTP requests // Return a simple status response @@ -68,7 +68,7 @@ impl CustomServeTrait for PegboardRunnerWsCustomServe { _headers: &hyper::HeaderMap, path: &str, _request_context: &mut RequestContext, - _unique_request_id: pegboard::tunnel::id::RequestId, + _unique_request_id: protocol::RequestId, _after_hibernation: bool, ) -> Result> { // Get UPS diff --git a/engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs b/engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs index f37efc3474..2cb4495d83 100644 --- a/engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs +++ b/engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs @@ -5,7 +5,6 @@ use gas::prelude::*; use hyper_tungstenite::tungstenite::Message as WsMessage; use hyper_tungstenite::tungstenite::Message; use pegboard::pubsub_subjects::GatewayReceiverSubject; -use pegboard::tunnel::id as tunnel_id; use pegboard_actor_kv as kv; use rivet_guard_core::websocket_handle::WebSocketReceiver; use rivet_runner_protocol::{self as protocol, PROTOCOL_VERSION, versioned}; @@ -382,7 +381,7 @@ async fn handle_tunnel_message( if protocol::compat::version_needs_tunnel_ack(conn.protocol_version) { let ack_msg = versioned::ToClient::wrap_latest(protocol::ToClient::ToClientTunnelMessage( protocol::ToClientTunnelMessage { - message_id: msg.message_id, + message_id: msg.message_id.clone(), message_kind: protocol::ToClientTunnelMessageKind::DeprecatedTunnelAck, }, )); @@ -399,12 +398,8 @@ async fn handle_tunnel_message( .context("failed to send DeprecatedTunnelAck to runner")?; } - // Parse message ID to extract gateway_id - let parts = - tunnel_id::parse_message_id(msg.message_id).context("failed to parse message id")?; - // Publish message to UPS - let gateway_reply_to = GatewayReceiverSubject::new(parts.gateway_id).to_string(); + let gateway_reply_to = GatewayReceiverSubject::new(msg.message_id.gateway_id).to_string(); let msg_serialized = versioned::ToGateway::wrap_latest(protocol::ToGateway::ToServerTunnelMessage(msg)) .serialize_with_embedded_version(PROTOCOL_VERSION) diff --git a/engine/packages/pegboard-serverless/src/lib.rs b/engine/packages/pegboard-serverless/src/lib.rs index 5591254f2f..d351a13e2d 100644 --- a/engine/packages/pegboard-serverless/src/lib.rs +++ b/engine/packages/pegboard-serverless/src/lib.rs @@ -511,10 +511,9 @@ async fn publish_to_client_stop(ctx: &StandaloneCtx, runner_id: Id) -> Result<() let receiver_subject = pegboard::pubsub_subjects::RunnerReceiverSubject::new(runner_id).to_string(); - let message_serialized = rivet_runner_protocol::versioned::ToRunner::wrap_latest( - rivet_runner_protocol::ToRunner::ToClientClose, - ) - .serialize_with_embedded_version(rivet_runner_protocol::PROTOCOL_VERSION)?; + let message_serialized = + protocol::versioned::ToRunner::wrap_latest(protocol::ToRunner::ToClientClose) + .serialize_with_embedded_version(protocol::PROTOCOL_VERSION)?; ctx.ups()? .publish(&receiver_subject, &message_serialized, PublishOpts::one()) diff --git a/engine/packages/pegboard/src/keys/actor.rs b/engine/packages/pegboard/src/keys/actor.rs index 5348b9b191..fe1460080d 100644 --- a/engine/packages/pegboard/src/keys/actor.rs +++ b/engine/packages/pegboard/src/keys/actor.rs @@ -1,9 +1,8 @@ use anyhow::Result; use gas::prelude::*; +use rivet_runner_protocol as protocol; use universaldb::prelude::*; -use crate::tunnel::id::{GatewayId, RequestId}; - #[derive(Debug)] pub struct CreateTsKey { actor_id: Id, @@ -318,16 +317,16 @@ impl<'de> TupleUnpack<'de> for NamespaceIdKey { pub struct HibernatingRequestKey { actor_id: Id, last_ping_ts: i64, - pub gateway_id: GatewayId, - pub request_id: RequestId, + pub gateway_id: protocol::GatewayId, + pub request_id: protocol::RequestId, } impl HibernatingRequestKey { pub fn new( actor_id: Id, last_ping_ts: i64, - gateway_id: GatewayId, - request_id: RequestId, + gateway_id: protocol::GatewayId, + request_id: protocol::RequestId, ) -> Self { HibernatingRequestKey { actor_id, @@ -381,12 +380,12 @@ impl<'de> TupleUnpack<'de> for HibernatingRequestKey { let (input, (_, _, actor_id, last_ping_ts, gateway_id_bytes, request_id_bytes)) = <(usize, usize, Id, i64, Vec, Vec)>::unpack(input, tuple_depth)?; - let gateway_id: GatewayId = gateway_id_bytes + let gateway_id = gateway_id_bytes .as_slice() .try_into() .expect("invalid gateway_id length"); - let request_id: RequestId = request_id_bytes + let request_id = request_id_bytes .as_slice() .try_into() .expect("invalid request_id length"); diff --git a/engine/packages/pegboard/src/keys/hibernating_request.rs b/engine/packages/pegboard/src/keys/hibernating_request.rs index 7f91883f60..dfafe1bf97 100644 --- a/engine/packages/pegboard/src/keys/hibernating_request.rs +++ b/engine/packages/pegboard/src/keys/hibernating_request.rs @@ -1,16 +1,15 @@ use anyhow::Result; +use rivet_runner_protocol as protocol; use universaldb::prelude::*; -use crate::tunnel::id::{GatewayId, RequestId}; - #[derive(Debug)] pub struct LastPingTsKey { - gateway_id: GatewayId, - request_id: RequestId, + gateway_id: protocol::GatewayId, + request_id: protocol::RequestId, } impl LastPingTsKey { - pub fn new(gateway_id: GatewayId, request_id: RequestId) -> Self { + pub fn new(gateway_id: protocol::GatewayId, request_id: protocol::RequestId) -> Self { LastPingTsKey { gateway_id, request_id, @@ -53,12 +52,12 @@ impl<'de> TupleUnpack<'de> for LastPingTsKey { let (input, (_, _, gateway_id_bytes, request_id_bytes, _)) = <(usize, usize, Vec, Vec, usize)>::unpack(input, tuple_depth)?; - let gateway_id: GatewayId = gateway_id_bytes + let gateway_id = gateway_id_bytes .as_slice() .try_into() .expect("invalid gateway_id length"); - let request_id: RequestId = request_id_bytes + let request_id = request_id_bytes .as_slice() .try_into() .expect("invalid request_id length"); diff --git a/engine/packages/pegboard/src/lib.rs b/engine/packages/pegboard/src/lib.rs index 1574ff4503..a776a3d227 100644 --- a/engine/packages/pegboard/src/lib.rs +++ b/engine/packages/pegboard/src/lib.rs @@ -5,7 +5,6 @@ pub mod keys; mod metrics; pub mod ops; pub mod pubsub_subjects; -pub mod tunnel; mod utils; pub mod workflows; diff --git a/engine/packages/pegboard/src/ops/actor/hibernating_request/delete.rs b/engine/packages/pegboard/src/ops/actor/hibernating_request/delete.rs index 3af721741f..09f1a96bd1 100644 --- a/engine/packages/pegboard/src/ops/actor/hibernating_request/delete.rs +++ b/engine/packages/pegboard/src/ops/actor/hibernating_request/delete.rs @@ -1,14 +1,14 @@ use gas::prelude::*; +use rivet_runner_protocol as protocol; use universaldb::utils::IsolationLevel::*; use crate::keys; -use crate::tunnel::id::{GatewayId, RequestId}; #[derive(Debug, Default)] pub struct Input { pub actor_id: Id, - pub gateway_id: GatewayId, - pub request_id: RequestId, + pub gateway_id: protocol::GatewayId, + pub request_id: protocol::RequestId, } #[operation] diff --git a/engine/packages/pegboard/src/ops/actor/hibernating_request/list.rs b/engine/packages/pegboard/src/ops/actor/hibernating_request/list.rs index 875ae86ca3..0ff8f98da8 100644 --- a/engine/packages/pegboard/src/ops/actor/hibernating_request/list.rs +++ b/engine/packages/pegboard/src/ops/actor/hibernating_request/list.rs @@ -1,10 +1,10 @@ use futures_util::{StreamExt, TryStreamExt}; use gas::prelude::*; +use rivet_runner_protocol as protocol; use universaldb::options::StreamingMode; use universaldb::utils::IsolationLevel::*; use crate::keys; -use crate::tunnel::id::{GatewayId, RequestId}; #[derive(Debug, Default)] pub struct Input { @@ -13,8 +13,8 @@ pub struct Input { #[derive(Debug)] pub struct HibernatingRequestItem { - pub gateway_id: GatewayId, - pub request_id: RequestId, + pub gateway_id: protocol::GatewayId, + pub request_id: protocol::RequestId, } #[operation] diff --git a/engine/packages/pegboard/src/ops/actor/hibernating_request/upsert.rs b/engine/packages/pegboard/src/ops/actor/hibernating_request/upsert.rs index 26f8b8cb7d..6fb0981a41 100644 --- a/engine/packages/pegboard/src/ops/actor/hibernating_request/upsert.rs +++ b/engine/packages/pegboard/src/ops/actor/hibernating_request/upsert.rs @@ -1,14 +1,14 @@ use gas::prelude::*; +use rivet_runner_protocol as protocol; use universaldb::utils::IsolationLevel::*; use crate::keys; -use crate::tunnel::id::{GatewayId, RequestId}; #[derive(Debug, Default)] pub struct Input { pub actor_id: Id, - pub gateway_id: GatewayId, - pub request_id: RequestId, + pub gateway_id: protocol::GatewayId, + pub request_id: protocol::RequestId, } #[operation] diff --git a/engine/packages/pegboard/src/pubsub_subjects.rs b/engine/packages/pegboard/src/pubsub_subjects.rs index 27f40dba5b..79b3286b4c 100644 --- a/engine/packages/pegboard/src/pubsub_subjects.rs +++ b/engine/packages/pegboard/src/pubsub_subjects.rs @@ -1,6 +1,5 @@ use gas::prelude::*; - -use crate::tunnel::id as tunnel_id; +use rivet_runner_protocol as protocol; pub struct RunnerReceiverSubject { runner_id: Id, @@ -61,11 +60,11 @@ impl std::fmt::Display for RunnerEvictionByNameSubject { } pub struct GatewayReceiverSubject { - gateway_id: tunnel_id::GatewayId, + gateway_id: protocol::GatewayId, } impl GatewayReceiverSubject { - pub fn new(gateway_id: tunnel_id::GatewayId) -> Self { + pub fn new(gateway_id: protocol::GatewayId) -> Self { Self { gateway_id } } } @@ -75,7 +74,7 @@ impl std::fmt::Display for GatewayReceiverSubject { write!( f, "pegboard.gateway.{}", - tunnel_id::gateway_id_to_string(&self.gateway_id) + protocol::util::id_to_string(&self.gateway_id) ) } } diff --git a/engine/packages/pegboard/src/tunnel/id.rs b/engine/packages/pegboard/src/tunnel/id.rs deleted file mode 100644 index 67a1582e39..0000000000 --- a/engine/packages/pegboard/src/tunnel/id.rs +++ /dev/null @@ -1,86 +0,0 @@ -use anyhow::{Context, Result, ensure}; -use base64::{Engine, engine::general_purpose::STANDARD as BASE64}; -use rivet_runner_protocol as protocol; - -// Type aliases for the message ID components -pub type GatewayId = [u8; 4]; -pub type RequestId = [u8; 4]; -pub type MessageIndex = u16; -pub type MessageId = [u8; 10]; - -/// Generate a new 4-byte gateway ID from a random u32 -pub fn generate_gateway_id() -> GatewayId { - rand::random::().to_le_bytes() -} - -/// Build a MessageId from its components -pub fn build_message_id( - gateway_id: GatewayId, - request_id: RequestId, - message_index: MessageIndex, -) -> Result { - let parts = protocol::MessageIdParts { - gateway_id, - request_id, - message_index, - }; - - // Serialize directly to a fixed-size buffer on the stack - let mut message_id = [0u8; 10]; - let mut cursor = std::io::Cursor::new(&mut message_id[..]); - serde_bare::to_writer(&mut cursor, &parts).context("failed to serialize message id parts")?; - - // Verify we wrote exactly 10 bytes - let written = cursor.position() as usize; - ensure!( - written == 10, - "message id serialization produced wrong size: expected 10 bytes, got {}", - written - ); - - Ok(message_id) -} - -/// Parse a MessageId into its components -pub fn parse_message_id(message_id: MessageId) -> Result { - serde_bare::from_slice(&message_id).context("failed to deserialize message id") -} - -/// Convert a GatewayId to a base64 string -pub fn gateway_id_to_string(gateway_id: &GatewayId) -> String { - BASE64.encode(gateway_id) -} - -/// Parse a GatewayId from a base64 string -pub fn gateway_id_from_string(s: &str) -> Result { - let bytes = BASE64.decode(s).context("failed to decode base64")?; - let gateway_id: GatewayId = bytes.try_into().map_err(|v: Vec| { - anyhow::anyhow!( - "invalid gateway id length: expected 4 bytes, got {}", - v.len() - ) - })?; - Ok(gateway_id) -} - -/// Generate a new 4-byte request ID from a random u32 -pub fn generate_request_id() -> RequestId { - rand::random::().to_le_bytes() -} - -/// Convert a RequestId to a base64 string -pub fn request_id_to_string(request_id: &RequestId) -> String { - BASE64.encode(request_id) -} - -/// Parse a RequestId from a base64 string -pub fn request_id_from_string(s: &str) -> Result { - let bytes = BASE64.decode(s).context("failed to decode base64")?; - let request_id: RequestId = bytes.try_into().map_err(|v: Vec| { - anyhow::anyhow!( - "invalid request id length: expected 4 bytes, got {}", - v.len() - ) - })?; - Ok(request_id) -} diff --git a/engine/packages/pegboard/src/tunnel/mod.rs b/engine/packages/pegboard/src/tunnel/mod.rs deleted file mode 100644 index fd6bb6c432..0000000000 --- a/engine/packages/pegboard/src/tunnel/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod id; diff --git a/engine/sdks/rust/runner-protocol/Cargo.toml b/engine/sdks/rust/runner-protocol/Cargo.toml index 65d781284f..c137ca9833 100644 --- a/engine/sdks/rust/runner-protocol/Cargo.toml +++ b/engine/sdks/rust/runner-protocol/Cargo.toml @@ -7,8 +7,9 @@ edition.workspace = true [dependencies] anyhow.workspace = true -base64.workspace = true gas.workspace = true +hex.workspace = true +rand.workspace = true rivet-util.workspace = true serde_bare.workspace = true serde.workspace = true diff --git a/engine/sdks/rust/runner-protocol/src/lib.rs b/engine/sdks/rust/runner-protocol/src/lib.rs index 2c0151c639..370d763d36 100644 --- a/engine/sdks/rust/runner-protocol/src/lib.rs +++ b/engine/sdks/rust/runner-protocol/src/lib.rs @@ -1,5 +1,6 @@ pub mod compat; pub mod generated; +pub mod util; pub mod versioned; // Re-export latest diff --git a/engine/sdks/rust/runner-protocol/src/util.rs b/engine/sdks/rust/runner-protocol/src/util.rs new file mode 100644 index 0000000000..9a4034963e --- /dev/null +++ b/engine/sdks/rust/runner-protocol/src/util.rs @@ -0,0 +1,14 @@ +/// Generate a new 4-byte gateway ID from a random u32 +pub fn generate_gateway_id() -> crate::GatewayId { + rand::random::().to_le_bytes() +} + +/// Generate a new 4-byte request ID from a random u32 +pub fn generate_request_id() -> crate::RequestId { + rand::random::().to_le_bytes() +} + +/// Convert a GatewayId to a hex string +pub fn id_to_string(gateway_id: &crate::GatewayId) -> String { + hex::encode(gateway_id) +} diff --git a/engine/sdks/rust/runner-protocol/src/versioned.rs b/engine/sdks/rust/runner-protocol/src/versioned.rs index a832939f2f..2e4ac0f9c5 100644 --- a/engine/sdks/rust/runner-protocol/src/versioned.rs +++ b/engine/sdks/rust/runner-protocol/src/versioned.rs @@ -177,11 +177,17 @@ impl ToClient { // Extract v3 message_id from v2's message_id // v3: gateway_id (4) + request_id (4) + message_index (2) = 10 bytes // v2.message_id contains: entire v3 message_id (10 bytes) + padding (6 bytes) - let mut message_id = [0u8; 10]; - message_id.copy_from_slice(&msg.message_id[..10]); + let mut gateway_id = [0u8; 4]; + gateway_id.copy_from_slice(&msg.message_id[..4]); + let mut request_id = [0u8; 4]; + request_id.copy_from_slice(&msg.request_id[..4]); v3::ToClient::ToClientTunnelMessage(v3::ToClientTunnelMessage { - message_id, + message_id: v3::MessageId { + gateway_id, + request_id, + message_index: 0, + }, message_kind: convert_to_client_tunnel_message_kind_v2_to_v3( msg.message_kind, ), @@ -252,8 +258,10 @@ impl ToClient { // v2.message_id = entire v3 message_id (10 bytes) + padding (4 zeros) let mut request_id = [0u8; 16]; let mut message_id = [0u8; 16]; - request_id[..8].copy_from_slice(&msg.message_id[..8]); // gateway_id + request_id - message_id[..10].copy_from_slice(&msg.message_id); // entire v3 message_id + request_id[..4].copy_from_slice(&msg.message_id.gateway_id); + request_id[4..8].copy_from_slice(&msg.message_id.request_id); + message_id[..8].copy_from_slice(&request_id[0..8]); + request_id[8..10].copy_from_slice(&msg.message_id.message_index.to_le_bytes()); v2::ToClient::ToClientTunnelMessage(v2::ToClientTunnelMessage { request_id, @@ -505,11 +513,17 @@ impl ToServer { // Extract v3 message_id from v2's message_id // v3: gateway_id (4) + request_id (4) + message_index (2) = 10 bytes // v2.message_id contains: entire v3 message_id (10 bytes) + padding (6 bytes) - let mut message_id = [0u8; 10]; - message_id.copy_from_slice(&msg.message_id[..10]); + let mut gateway_id = [0u8; 4]; + gateway_id.copy_from_slice(&msg.message_id[..4]); + let mut request_id = [0u8; 4]; + request_id.copy_from_slice(&msg.request_id[..4]); v3::ToServer::ToServerTunnelMessage(v3::ToServerTunnelMessage { - message_id, + message_id: v3::MessageId { + gateway_id, + request_id, + message_index: 0, + }, message_kind: convert_to_server_tunnel_message_kind_v2_to_v3( msg.message_kind, ), @@ -577,8 +591,10 @@ impl ToServer { // v2.message_id = entire v3 message_id (10 bytes) + padding (4 zeros) let mut request_id = [0u8; 16]; let mut message_id = [0u8; 16]; - request_id[..8].copy_from_slice(&msg.message_id[..8]); // gateway_id + request_id - message_id[..10].copy_from_slice(&msg.message_id); // entire v3 message_id + request_id[..4].copy_from_slice(&msg.message_id.gateway_id); + request_id[4..8].copy_from_slice(&msg.message_id.request_id); + message_id[..8].copy_from_slice(&request_id[0..8]); + request_id[8..10].copy_from_slice(&msg.message_id.message_index.to_le_bytes()); v2::ToServer::ToServerTunnelMessage(v2::ToServerTunnelMessage { request_id, @@ -1262,7 +1278,7 @@ fn convert_to_client_tunnel_message_kind_v2_to_v3( fn convert_to_client_tunnel_message_kind_v3_to_v2( kind: v3::ToClientTunnelMessageKind, - message_id: &[u8; 10], + message_id: &v3::MessageId, ) -> Result { Ok(match kind { v3::ToClientTunnelMessageKind::ToClientRequestStart(req) => { @@ -1292,12 +1308,10 @@ fn convert_to_client_tunnel_message_kind_v3_to_v2( }) } v3::ToClientTunnelMessageKind::ToClientWebSocketMessage(msg) => { - // Extract message index from message_id (bytes 8-9, u16 little-endian per BARE spec) - let index = u16::from_le_bytes([message_id[8], message_id[9]]); v2::ToClientTunnelMessageKind::ToClientWebSocketMessage(v2::ToClientWebSocketMessage { data: msg.data, binary: msg.binary, - index, + index: message_id.message_index, }) } v3::ToClientTunnelMessageKind::ToClientWebSocketClose(close) => {