Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions engine/packages/pegboard/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ rivet-data.workspace = true
rivet-error.workspace = true
rivet-metrics.workspace = true
rivet-runner-protocol.workspace = true
rivet-runtime.workspace = true
rivet-types.workspace = true
rivet-util.workspace = true
serde_bare.workspace = true
Expand Down
31 changes: 21 additions & 10 deletions engine/packages/pegboard/src/workflows/serverless/connection.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use std::time::Instant;

use anyhow::Context;
use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
use futures_util::{FutureExt, StreamExt};
use gas::prelude::*;
use reqwest::header::{HeaderName, HeaderValue};
use reqwest_eventsource as sse;
use rivet_runner_protocol as protocol;
use rivet_runtime::TermSignal;
use rivet_types::runner_configs::RunnerConfigKind;
use std::time::Instant;
use tokio::time::Duration;
use universalpubsub::PublishOpts;
use vbare::OwnedVersionedData;
Expand Down Expand Up @@ -79,12 +81,12 @@ pub async fn pegboard_serverless_connection(ctx: &mut WorkflowCtx, input: &Input

let next = backoff.step().expect("should not have max retry");
if let Some(_sig) = ctx
.listen_with_timeout::<DrainSignal>(Instant::from(next) - Instant::now())
.listen_with_timeout::<Drain>(Instant::from(next) - Instant::now())
.await?
{
tracing::debug!("drain received during serverless connection backoff");

// Notify parent that drain started
// Notify pool that drain started
return Ok(Loop::Break(true));
}

Expand Down Expand Up @@ -172,8 +174,9 @@ async fn outbound_req_inner(
});
}

let mut term_signal = TermSignal::new().await;
let mut drain_sub = ctx
.subscribe::<DrainMessage>(("workflow_id", ctx.workflow_id()))
.subscribe::<Drain>(("workflow_id", ctx.workflow_id()))
.await?;

let (runner_config_res, namespace_res) = tokio::try_join!(
Expand Down Expand Up @@ -331,6 +334,7 @@ async fn outbound_req_inner(
},
_ = tokio::time::sleep(sleep_until_drain) => {}
_ = drain_sub.next() => {}
_ = term_signal.recv() => {}
};

tracing::debug!(?runner_id, "connection reached lifespan, needs draining");
Expand Down Expand Up @@ -361,7 +365,8 @@ async fn outbound_req_inner(
// After we tell the pool we're draining, any remaining failures
// don't matter as the pool already stopped caring about us.
if let Err(err) =
finish_non_critical_draining(ctx, source, runner_id, runner_protocol_version).await
finish_non_critical_draining(ctx, term_signal, source, runner_id, runner_protocol_version)
.await
{
tracing::debug!(?err, "failed non critical draining phase");
}
Expand All @@ -371,6 +376,9 @@ async fn outbound_req_inner(
})
}

/// Reads from the adjacent serverless runner wf which is keeping track of signals while this workflow runs
/// outbound requests.
#[tracing::instrument(skip_all)]
async fn is_runner_draining(ctx: &ActivityCtx, runner_wf_id: Id) -> Result<bool> {
let runner_wf = ctx
.get_workflows(vec![runner_wf_id])
Expand All @@ -383,8 +391,10 @@ async fn is_runner_draining(ctx: &ActivityCtx, runner_wf_id: Id) -> Result<bool>
Ok(state.is_draining)
}

#[tracing::instrument(skip_all)]
async fn finish_non_critical_draining(
ctx: &ActivityCtx,
mut term_signal: TermSignal,
mut source: sse::EventSource,
mut runner_id: Option<Id>,
mut runner_protocol_version: Option<u16>,
Expand Down Expand Up @@ -437,6 +447,7 @@ async fn finish_non_critical_draining(
_ = tokio::time::sleep(DRAIN_GRACE_PERIOD) => {
tracing::debug!(?runner_id, "reached drain grace period before runner shut down")
}
_ = term_signal.recv() => {}
}

// Close connection
Expand All @@ -452,6 +463,7 @@ async fn finish_non_critical_draining(
Ok(())
}

#[tracing::instrument(skip_all)]
async fn drain_runner(ctx: &ActivityCtx, runner_id: Id) -> Result<()> {
let res = ctx
.signal(crate::workflows::runner::Stop {
Expand Down Expand Up @@ -492,6 +504,7 @@ async fn drain_runner(ctx: &ActivityCtx, runner_id: Id) -> Result<()> {
/// Send a stop message to the client.
///
/// This will close the runner's WebSocket.
#[tracing::instrument(skip_all)]
async fn publish_to_client_stop(
ctx: &ActivityCtx,
runner_id: Id,
Expand All @@ -514,11 +527,9 @@ async fn publish_to_client_stop(
Ok(())
}

#[message("pegboard_serverless_connection_drain_msg")]
pub struct DrainMessage {}

#[signal("pegboard_serverless_connection_drain_sig")]
pub struct DrainSignal {}
#[message("pegboard_serverless_connection_drain")]
#[signal("pegboard_serverless_connection_drain")]
pub struct Drain {}

fn reconnect_backoff(
retry_count: usize,
Expand Down
96 changes: 72 additions & 24 deletions engine/packages/pegboard/src/workflows/serverless/pool.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::hash::{DefaultHasher, Hash, Hasher};

use futures_util::FutureExt;
use gas::{db::WorkflowData, prelude::*};
use rivet_types::{keys, runner_configs::RunnerConfigKind};
Expand All @@ -12,54 +14,79 @@ pub struct Input {

#[derive(Debug, Serialize, Deserialize, Default)]
struct LifecycleState {
runners: Vec<Id>,
runners: Vec<RunnerState>,
}

#[derive(Debug, Serialize, Deserialize)]
struct RunnerState {
/// Serverless runner wf id, not normal runner wf id.
runner_wf_id: Id,
details_hash: u64,
}

#[workflow]
pub async fn pegboard_serverless_pool(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> {
ctx.loope(LifecycleState::default(), |ctx, state| {
let input = input.clone();
async move {
// 1. Remove completed connections
// Get desired count -> drain and start counts
let ReadDesiredOutput::Desired {
desired_count,
details_hash,
} = ctx.activity(ReadDesiredInput {
namespace_id: input.namespace_id,
runner_name: input.runner_name.clone(),
})
.await?
else {
return Ok(Loop::Break(()));
};

let completed_runners = ctx
.activity(GetCompletedInput {
runners: state.runners.clone(),
runners: state.runners.iter().map(|r| r.runner_wf_id).collect(),
})
.await?;

state.runners.retain(|r| !completed_runners.contains(r));

// 2. Get desired count -> drain and start counts
let ReadDesiredOutput::Desired(desired_count) = ctx
.activity(ReadDesiredInput {
namespace_id: input.namespace_id,
runner_name: input.runner_name.clone(),
})
.await?
else {
return Ok(Loop::Break(()));
};
// Remove completed connections
state
.runners
.retain(|r| !completed_runners.contains(&r.runner_wf_id));

// Remove runners that have an outdated hash. This is done outside of the below draining mechanism
// because we drain specific runners, not just a number of runners
let (new, outdated) = std::mem::take(&mut state.runners)
.into_iter()
.partition::<Vec<_>, _>(|r| r.details_hash == details_hash);
state.runners = new;

for runner in outdated {
ctx.signal(runner::Drain {})
.to_workflow_id(runner.runner_wf_id)
.send()
.await?;
}

let drain_count = state.runners.len().saturating_sub(desired_count);
let start_count = desired_count.saturating_sub(state.runners.len());

// 3. Drain old runners
// Drain unnecessary runners
if drain_count != 0 {
// TODO: Implement smart logic of draining runners with the lowest allocated actors
let draining_runners = state.runners.iter().take(drain_count).collect::<Vec<_>>();

for wf_id in draining_runners {
for runner in draining_runners {
ctx.signal(runner::Drain {})
.to_workflow_id(*wf_id)
.to_workflow_id(runner.runner_wf_id)
.send()
.await?;
}
}

// 4. Dispatch new runner workflows
// Dispatch new runner workflows
if start_count != 0 {
for _ in 0..start_count {
let wf_id = ctx
let runner_wf_id = ctx
.workflow(runner::Input {
pool_wf_id: ctx.workflow_id(),
namespace_id: input.namespace_id,
Expand All @@ -70,14 +97,17 @@ pub async fn pegboard_serverless_pool(ctx: &mut WorkflowCtx, input: &Input) -> R
.dispatch()
.await?;

state.runners.push(wf_id);
state.runners.push(RunnerState {
runner_wf_id,
details_hash,
});
}
}

// Wait for Bump or runner update signals until we tick again
match ctx.listen::<Main>().await? {
Main::RunnerDrainStarted(sig) => {
state.runners.retain(|wf_id| *wf_id != sig.runner_wf_id);
state.runners.retain(|r| r.runner_wf_id != sig.runner_wf_id);
}
Main::Bump(_) => {}
}
Expand All @@ -102,6 +132,7 @@ async fn get_completed(ctx: &ActivityCtx, input: &GetCompletedInput) -> Result<V
.get_workflows(input.runners.clone())
.await?
.into_iter()
// When a workflow has output, it means it has completed
.filter(WorkflowData::has_output)
.map(|wf| wf.workflow_id)
.collect())
Expand All @@ -115,7 +146,10 @@ struct ReadDesiredInput {

#[derive(Debug, Serialize, Deserialize)]
enum ReadDesiredOutput {
Desired(usize),
Desired {
desired_count: usize,
details_hash: u64,
},
Stop,
}

Expand All @@ -132,6 +166,9 @@ async fn read_desired(ctx: &ActivityCtx, input: &ReadDesiredInput) -> Result<Rea
};

let RunnerConfigKind::Serverless {
url,
headers,

slots_per_runner,
min_runners,
max_runners,
Expand Down Expand Up @@ -177,7 +214,18 @@ async fn read_desired(ctx: &ActivityCtx, input: &ReadDesiredInput) -> Result<Rea
.min(max_runners)
.try_into()?;

Ok(ReadDesiredOutput::Desired(desired_count))
// Compute consistent hash of serverless details
let mut hasher = DefaultHasher::new();
url.hash(&mut hasher);
let mut sorted_headers = headers.iter().collect::<Vec<_>>();
sorted_headers.sort();
sorted_headers.hash(&mut hasher);
let details_hash = hasher.finish();

Ok(ReadDesiredOutput::Desired {
desired_count,
details_hash,
})
}

#[signal("pegboard_serverless_bump")]
Expand Down
6 changes: 4 additions & 2 deletions engine/packages/pegboard/src/workflows/serverless/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ impl State {
}
}

/// Runs alongside the connection workflow to process signals. This is required because the connection
/// workflow cannot listen for signals while in an activity.
#[workflow]
pub async fn pegboard_serverless_runner(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> {
ctx.activity(InitStateInput {}).await?;
Expand All @@ -38,12 +40,12 @@ pub async fn pegboard_serverless_runner(ctx: &mut WorkflowCtx, input: &Input) ->

ctx.activity(MarkAsDrainingInput {}).await?;

ctx.signal(connection::DrainSignal {})
ctx.signal(connection::Drain {})
.to_workflow_id(conn_wf_id)
.send()
.await?;

ctx.msg(connection::DrainMessage {})
ctx.msg(connection::Drain {})
.tag("workflow_id", conn_wf_id)
.send()
.await?;
Expand Down
Loading