From 72f5f460b9ade12ce4615edf7f4067378c779780 Mon Sep 17 00:00:00 2001
From: MasterPtato
Date: Tue, 2 Dec 2025 10:53:47 -0800
Subject: [PATCH] fix(serverless): drain runners after url/headers change

---
 Cargo.lock                                    |  1 +
 engine/packages/pegboard/Cargo.toml           |  1 +
 .../src/workflows/serverless/connection.rs    | 31 ++++--
 .../pegboard/src/workflows/serverless/pool.rs | 96 ++++++++++++++-----
 .../src/workflows/serverless/runner.rs        |  6 +-
 5 files changed, 99 insertions(+), 36 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index b63d0245b1..f4262236de 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3422,6 +3422,7 @@ dependencies = [
  "rivet-error",
  "rivet-metrics",
  "rivet-runner-protocol",
+ "rivet-runtime",
  "rivet-types",
  "rivet-util",
  "serde",
diff --git a/engine/packages/pegboard/Cargo.toml b/engine/packages/pegboard/Cargo.toml
index 21a6d8c2ea..44f012147b 100644
--- a/engine/packages/pegboard/Cargo.toml
+++ b/engine/packages/pegboard/Cargo.toml
@@ -22,6 +22,7 @@ rivet-data.workspace = true
 rivet-error.workspace = true
 rivet-metrics.workspace = true
 rivet-runner-protocol.workspace = true
+rivet-runtime.workspace = true
 rivet-types.workspace = true
 rivet-util.workspace = true
 serde_bare.workspace = true
diff --git a/engine/packages/pegboard/src/workflows/serverless/connection.rs b/engine/packages/pegboard/src/workflows/serverless/connection.rs
index 6d3b575a6a..b027aed542 100644
--- a/engine/packages/pegboard/src/workflows/serverless/connection.rs
+++ b/engine/packages/pegboard/src/workflows/serverless/connection.rs
@@ -1,3 +1,5 @@
+use std::time::Instant;
+
 use anyhow::Context;
 use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
 use futures_util::{FutureExt, StreamExt};
@@ -5,8 +7,8 @@ use gas::prelude::*;
 use reqwest::header::{HeaderName, HeaderValue};
 use reqwest_eventsource as sse;
 use rivet_runner_protocol as protocol;
+use rivet_runtime::TermSignal;
 use rivet_types::runner_configs::RunnerConfigKind;
-use std::time::Instant;
 use tokio::time::Duration;
 use universalpubsub::PublishOpts;
 use vbare::OwnedVersionedData;
@@ -79,12 +81,12 @@ pub async fn pegboard_serverless_connection(ctx: &mut WorkflowCtx, input: &Input
             let next = backoff.step().expect("should not have max retry");

             if let Some(_sig) = ctx
-                .listen_with_timeout::<DrainSignal>(Instant::from(next) - Instant::now())
+                .listen_with_timeout::<Drain>(Instant::from(next) - Instant::now())
                 .await?
             {
                 tracing::debug!("drain received during serverless connection backoff");

-                // Notify parent that drain started
+                // Notify pool that drain started
                 return Ok(Loop::Break(true));
             }

@@ -172,8 +174,9 @@
         });
     }

+    let mut term_signal = TermSignal::new().await;
     let mut drain_sub = ctx
-        .subscribe::<DrainMessage>(("workflow_id", ctx.workflow_id()))
+        .subscribe::<Drain>(("workflow_id", ctx.workflow_id()))
         .await?;

     let (runner_config_res, namespace_res) = tokio::try_join!(
@@ -331,6 +334,7 @@
         },
         _ = tokio::time::sleep(sleep_until_drain) => {}
         _ = drain_sub.next() => {}
+        _ = term_signal.recv() => {}
     };

     tracing::debug!(?runner_id, "connection reached lifespan, needs draining");
@@ -361,7 +365,8 @@
     // After we tell the pool we're draining, any remaining failures
     // don't matter as the pool already stopped caring about us.
     if let Err(err) =
-        finish_non_critical_draining(ctx, source, runner_id, runner_protocol_version).await
+        finish_non_critical_draining(ctx, term_signal, source, runner_id, runner_protocol_version)
+            .await
     {
         tracing::debug!(?err, "failed non critical draining phase");
     }
@@ -371,6 +376,9 @@
     })
 }

+/// Reads from the adjacent serverless runner wf which is keeping track of signals while this workflow runs
+/// outbound requests.
+#[tracing::instrument(skip_all)]
 async fn is_runner_draining(ctx: &ActivityCtx, runner_wf_id: Id) -> Result<bool> {
     let runner_wf = ctx
         .get_workflows(vec![runner_wf_id])
@@ -383,8 +391,10 @@ async fn is_runner_draining(ctx: &ActivityCtx, runner_wf_id: Id) -> Result<bool>
     Ok(state.is_draining)
 }

+#[tracing::instrument(skip_all)]
 async fn finish_non_critical_draining(
     ctx: &ActivityCtx,
+    mut term_signal: TermSignal,
     mut source: sse::EventSource,
     mut runner_id: Option<Id>,
     mut runner_protocol_version: Option<u16>,
@@ -437,6 +447,7 @@
         _ = tokio::time::sleep(DRAIN_GRACE_PERIOD) => {
             tracing::debug!(?runner_id, "reached drain grace period before runner shut down")
         }
+        _ = term_signal.recv() => {}
     }

     // Close connection
@@ -452,6 +463,7 @@
     Ok(())
 }

+#[tracing::instrument(skip_all)]
 async fn drain_runner(ctx: &ActivityCtx, runner_id: Id) -> Result<()> {
     let res = ctx
         .signal(crate::workflows::runner::Stop {
@@ -492,6 +504,7 @@
 /// Send a stop message to the client.
 ///
 /// This will close the runner's WebSocket.
+#[tracing::instrument(skip_all)]
 async fn publish_to_client_stop(
     ctx: &ActivityCtx,
     runner_id: Id,
@@ -514,11 +527,9 @@
     Ok(())
 }

-#[message("pegboard_serverless_connection_drain_msg")]
-pub struct DrainMessage {}
-
-#[signal("pegboard_serverless_connection_drain_sig")]
-pub struct DrainSignal {}
+#[message("pegboard_serverless_connection_drain")]
+#[signal("pegboard_serverless_connection_drain")]
+pub struct Drain {}

 fn reconnect_backoff(
     retry_count: usize,
diff --git a/engine/packages/pegboard/src/workflows/serverless/pool.rs b/engine/packages/pegboard/src/workflows/serverless/pool.rs
index 1b3cb422ed..ec657c0744 100644
--- a/engine/packages/pegboard/src/workflows/serverless/pool.rs
+++ b/engine/packages/pegboard/src/workflows/serverless/pool.rs
@@ -1,3 +1,5 @@
+use std::hash::{DefaultHasher, Hash, Hasher};
+
 use futures_util::FutureExt;
 use gas::{db::WorkflowData, prelude::*};
 use rivet_types::{keys, runner_configs::RunnerConfigKind};
@@ -12,7 +14,14 @@ pub struct Input {

 #[derive(Debug, Serialize, Deserialize, Default)]
 struct LifecycleState {
-    runners: Vec<Id>,
+    runners: Vec<RunnerState>,
 }

+#[derive(Debug, Serialize, Deserialize)]
+struct RunnerState {
+    /// Serverless runner wf id, not normal runner wf id.
+    runner_wf_id: Id,
+    details_hash: u64,
+}
+
 #[workflow]
@@ -20,46 +29,64 @@ pub async fn pegboard_serverless_pool(ctx: &mut WorkflowCtx, input: &Input) -> R
     ctx.loope(LifecycleState::default(), |ctx, state| {
         let input = input.clone();
         async move {
-            // 1. Remove completed connections
+            // Get desired count -> drain and start counts
+            let ReadDesiredOutput::Desired {
+                desired_count,
+                details_hash,
+            } = ctx.activity(ReadDesiredInput {
+                namespace_id: input.namespace_id,
+                runner_name: input.runner_name.clone(),
+            })
+            .await?
+            else {
+                return Ok(Loop::Break(()));
+            };
+
             let completed_runners = ctx
                 .activity(GetCompletedInput {
-                    runners: state.runners.clone(),
+                    runners: state.runners.iter().map(|r| r.runner_wf_id).collect(),
                 })
                 .await?;

-            state.runners.retain(|r| !completed_runners.contains(r));
-
-            // 2. Get desired count -> drain and start counts
-            let ReadDesiredOutput::Desired(desired_count) = ctx
-                .activity(ReadDesiredInput {
-                    namespace_id: input.namespace_id,
-                    runner_name: input.runner_name.clone(),
-                })
-                .await?
-            else {
-                return Ok(Loop::Break(()));
-            };
+            // Remove completed connections
+            state
+                .runners
+                .retain(|r| !completed_runners.contains(&r.runner_wf_id));
+
+            // Remove runners that have an outdated hash. This is done outside of the below draining mechanism
+            // because we drain specific runners, not just a number of runners
+            let (new, outdated) = std::mem::take(&mut state.runners)
+                .into_iter()
+                .partition::<Vec<_>, _>(|r| r.details_hash == details_hash);
+            state.runners = new;
+
+            for runner in outdated {
+                ctx.signal(runner::Drain {})
+                    .to_workflow_id(runner.runner_wf_id)
+                    .send()
+                    .await?;
+            }

             let drain_count = state.runners.len().saturating_sub(desired_count);
             let start_count = desired_count.saturating_sub(state.runners.len());

-            // 3. Drain old runners
+            // Drain unnecessary runners
             if drain_count != 0 {
                 // TODO: Implement smart logic of draining runners with the lowest allocated actors
                 let draining_runners = state.runners.iter().take(drain_count).collect::<Vec<_>>();

-                for wf_id in draining_runners {
+                for runner in draining_runners {
                     ctx.signal(runner::Drain {})
-                        .to_workflow_id(*wf_id)
+                        .to_workflow_id(runner.runner_wf_id)
                         .send()
                         .await?;
                 }
             }

-            // 4. Dispatch new runner workflows
+            // Dispatch new runner workflows
             if start_count != 0 {
                 for _ in 0..start_count {
-                    let wf_id = ctx
+                    let runner_wf_id = ctx
                         .workflow(runner::Input {
                             pool_wf_id: ctx.workflow_id(),
                             namespace_id: input.namespace_id,
@@ -70,14 +97,17 @@ pub async fn pegboard_serverless_pool(ctx: &mut WorkflowCtx, input: &Input) -> R
                         .dispatch()
                         .await?;

-                    state.runners.push(wf_id);
+                    state.runners.push(RunnerState {
+                        runner_wf_id,
+                        details_hash,
+                    });
                 }
             }

             // Wait for Bump or runner update signals until we tick again
             match ctx.listen::<Main>().await? {
                 Main::RunnerDrainStarted(sig) => {
-                    state.runners.retain(|wf_id| *wf_id != sig.runner_wf_id);
+                    state.runners.retain(|r| r.runner_wf_id != sig.runner_wf_id);
                 }
                 Main::Bump(_) => {}
             }
@@ -102,6 +132,7 @@ async fn get_completed(ctx: &ActivityCtx, input: &GetCompletedInput) -> Result Result Result>();
+    sorted_headers.sort();
+    sorted_headers.hash(&mut hasher);
+    let details_hash = hasher.finish();
+
+    Ok(ReadDesiredOutput::Desired {
+        desired_count,
+        details_hash,
+    })
 }

 #[signal("pegboard_serverless_bump")]
diff --git a/engine/packages/pegboard/src/workflows/serverless/runner.rs b/engine/packages/pegboard/src/workflows/serverless/runner.rs
index 6fd82161ce..55eab09f75 100644
--- a/engine/packages/pegboard/src/workflows/serverless/runner.rs
+++ b/engine/packages/pegboard/src/workflows/serverless/runner.rs
@@ -20,6 +20,8 @@ impl State {
     }
 }

+/// Runs alongside the connection workflow to process signals. This is required because the connection
+/// workflow cannot listen for signals while in an activity.
 #[workflow]
 pub async fn pegboard_serverless_runner(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> {
     ctx.activity(InitStateInput {}).await?;
@@ -38,12 +40,12 @@ pub async fn pegboard_serverless_runner(ctx: &mut WorkflowCtx, input: &Input) ->

     ctx.activity(MarkAsDrainingInput {}).await?;

-    ctx.signal(connection::DrainSignal {})
+    ctx.signal(connection::Drain {})
         .to_workflow_id(conn_wf_id)
         .send()
         .await?;

-    ctx.msg(connection::DrainMessage {})
+    ctx.msg(connection::Drain {})
         .tag("workflow_id", conn_wf_id)
         .send()
         .await?;
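
Reviewer note (not part of the patch): the core of this change is the details hash that the ReadDesired activity in pool.rs now returns alongside the desired count, which the pool compares against each RunnerState::details_hash to decide which runners to drain after the serverless URL or headers change. The sketch below is a minimal standalone illustration of that fingerprinting idea only; ServerlessDetails, its fields, and the details_hash helper are assumed names and stand in for however the real runner config exposes its URL and headers. It also shows why the headers are collected into a Vec and sorted before hashing: map iteration order is unspecified, so hashing unsorted entries would not give a stable value.

use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};

// Illustrative stand-in for the serverless connection details that feed the hash.
struct ServerlessDetails {
    url: String,
    headers: HashMap<String, String>,
}

// Compute a stable fingerprint of the connection details. Headers are sorted
// first so that map iteration order cannot change the result.
fn details_hash(details: &ServerlessDetails) -> u64 {
    let mut hasher = DefaultHasher::new();
    details.url.hash(&mut hasher);

    let mut sorted_headers = details.headers.iter().collect::<Vec<_>>();
    sorted_headers.sort();
    sorted_headers.hash(&mut hasher);

    hasher.finish()
}

fn main() {
    let mut headers = HashMap::new();
    headers.insert("authorization".to_string(), "Bearer abc".to_string());

    let before = details_hash(&ServerlessDetails {
        url: "https://runner.example.com".to_string(),
        headers: headers.clone(),
    });

    // Changing the header set (or the URL) changes the fingerprint.
    headers.insert("x-extra".to_string(), "1".to_string());
    let after = details_hash(&ServerlessDetails {
        url: "https://runner.example.com".to_string(),
        headers,
    });

    assert_ne!(before, after);
}

DefaultHasher is deliberately imported from std::hash to mirror the import added in pool.rs, which requires a reasonably recent Rust toolchain.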
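
A second sketch, also not part of the patch: a simplified model of the reconcile pass the pool workflow now performs each loop iteration, which is to drain every runner whose stored hash no longer matches and then scale the remaining runners toward the desired count. RunnerState here mirrors the struct added in pool.rs but uses a plain u64 in place of the real Id type, and the actual signal sending and workflow dispatch are omitted.

// RunnerState mirrors the struct added in pool.rs; u64 stands in for the real Id type.
#[derive(Debug, Clone)]
struct RunnerState {
    runner_wf_id: u64,
    details_hash: u64,
}

// What one reconcile pass decides: which runner workflows to signal with
// runner::Drain and how many new runner workflows to dispatch.
struct Reconcile {
    keep: Vec<RunnerState>,
    drain: Vec<u64>,
    start: usize,
}

fn reconcile(runners: Vec<RunnerState>, details_hash: u64, desired_count: usize) -> Reconcile {
    // Runners whose hash no longer matches the current url/headers are always
    // drained, independent of the desired count, so they get replaced.
    let (mut keep, outdated): (Vec<_>, Vec<_>) = runners
        .into_iter()
        .partition(|r| r.details_hash == details_hash);
    let mut drain: Vec<u64> = outdated.into_iter().map(|r| r.runner_wf_id).collect();

    // Then scale the remaining up-to-date runners toward the desired count.
    let excess = keep.len().saturating_sub(desired_count);
    drain.extend(keep.drain(..excess).map(|r| r.runner_wf_id));
    let start = desired_count.saturating_sub(keep.len());

    Reconcile { keep, drain, start }
}

fn main() {
    let runners = vec![
        RunnerState { runner_wf_id: 1, details_hash: 10 },
        RunnerState { runner_wf_id: 2, details_hash: 99 }, // stale url/headers
    ];

    let plan = reconcile(runners, 10, 2);
    assert_eq!(plan.keep.len(), 1);
    assert_eq!(plan.drain, vec![2]); // runner 2 is drained because its hash changed
    assert_eq!(plan.start, 1); // and one replacement runner is started
    println!("drain {:?}, start {}", plan.drain, plan.start);
}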