Commit 72f5f46

fix(serverless): drain runners after url/headers change
1 parent 9145736

File tree: 5 files changed (+99, -36 lines)

Cargo.lock

Lines changed: 1 addition & 0 deletions
(Generated file; diff not rendered.)

engine/packages/pegboard/Cargo.toml

Lines changed: 1 addition & 0 deletions

@@ -22,6 +22,7 @@ rivet-data.workspace = true
 rivet-error.workspace = true
 rivet-metrics.workspace = true
 rivet-runner-protocol.workspace = true
+rivet-runtime.workspace = true
 rivet-types.workspace = true
 rivet-util.workspace = true
 serde_bare.workspace = true

engine/packages/pegboard/src/workflows/serverless/connection.rs

Lines changed: 21 additions & 10 deletions

@@ -1,12 +1,14 @@
+use std::time::Instant;
+
 use anyhow::Context;
 use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
 use futures_util::{FutureExt, StreamExt};
 use gas::prelude::*;
 use reqwest::header::{HeaderName, HeaderValue};
 use reqwest_eventsource as sse;
 use rivet_runner_protocol as protocol;
+use rivet_runtime::TermSignal;
 use rivet_types::runner_configs::RunnerConfigKind;
-use std::time::Instant;
 use tokio::time::Duration;
 use universalpubsub::PublishOpts;
 use vbare::OwnedVersionedData;
@@ -79,12 +81,12 @@ pub async fn pegboard_serverless_connection(ctx: &mut WorkflowCtx, input: &Input

         let next = backoff.step().expect("should not have max retry");
         if let Some(_sig) = ctx
-            .listen_with_timeout::<DrainSignal>(Instant::from(next) - Instant::now())
+            .listen_with_timeout::<Drain>(Instant::from(next) - Instant::now())
             .await?
         {
             tracing::debug!("drain received during serverless connection backoff");

-            // Notify parent that drain started
+            // Notify pool that drain started
             return Ok(Loop::Break(true));
         }

@@ -172,8 +174,9 @@ async fn outbound_req_inner(
        });
    }

+    let mut term_signal = TermSignal::new().await;
     let mut drain_sub = ctx
-        .subscribe::<DrainMessage>(("workflow_id", ctx.workflow_id()))
+        .subscribe::<Drain>(("workflow_id", ctx.workflow_id()))
         .await?;

     let (runner_config_res, namespace_res) = tokio::try_join!(
@@ -331,6 +334,7 @@ async fn outbound_req_inner(
        },
        _ = tokio::time::sleep(sleep_until_drain) => {}
        _ = drain_sub.next() => {}
+       _ = term_signal.recv() => {}
    };

     tracing::debug!(?runner_id, "connection reached lifespan, needs draining");
@@ -361,7 +365,8 @@
     // After we tell the pool we're draining, any remaining failures
     // don't matter as the pool already stopped caring about us.
     if let Err(err) =
-        finish_non_critical_draining(ctx, source, runner_id, runner_protocol_version).await
+        finish_non_critical_draining(ctx, term_signal, source, runner_id, runner_protocol_version)
+            .await
     {
         tracing::debug!(?err, "failed non critical draining phase");
     }
@@ -371,6 +376,9 @@ async fn outbound_req_inner(
    })
 }

+/// Reads from the adjacent serverless runner wf which is keeping track of signals while this workflow runs
+/// outbound requests.
+#[tracing::instrument(skip_all)]
 async fn is_runner_draining(ctx: &ActivityCtx, runner_wf_id: Id) -> Result<bool> {
     let runner_wf = ctx
         .get_workflows(vec![runner_wf_id])
@@ -383,8 +391,10 @@ async fn is_runner_draining(ctx: &ActivityCtx, runner_wf_id: Id) -> Result<bool>
     Ok(state.is_draining)
 }

+#[tracing::instrument(skip_all)]
 async fn finish_non_critical_draining(
     ctx: &ActivityCtx,
+    mut term_signal: TermSignal,
     mut source: sse::EventSource,
     mut runner_id: Option<Id>,
     mut runner_protocol_version: Option<u16>,
@@ -437,6 +447,7 @@ async fn finish_non_critical_draining(
        _ = tokio::time::sleep(DRAIN_GRACE_PERIOD) => {
            tracing::debug!(?runner_id, "reached drain grace period before runner shut down")
        }
+       _ = term_signal.recv() => {}
    }

     // Close connection
@@ -452,6 +463,7 @@
     Ok(())
 }

+#[tracing::instrument(skip_all)]
 async fn drain_runner(ctx: &ActivityCtx, runner_id: Id) -> Result<()> {
     let res = ctx
         .signal(crate::workflows::runner::Stop {
@@ -492,6 +504,7 @@ async fn drain_runner(ctx: &ActivityCtx, runner_id: Id) -> Result<()> {
 /// Send a stop message to the client.
 ///
 /// This will close the runner's WebSocket.
+#[tracing::instrument(skip_all)]
 async fn publish_to_client_stop(
     ctx: &ActivityCtx,
     runner_id: Id,
@@ -514,11 +527,9 @@
     Ok(())
 }

-#[message("pegboard_serverless_connection_drain_msg")]
-pub struct DrainMessage {}
-
-#[signal("pegboard_serverless_connection_drain_sig")]
-pub struct DrainSignal {}
+#[message("pegboard_serverless_connection_drain")]
+#[signal("pegboard_serverless_connection_drain")]
+pub struct Drain {}

 fn reconnect_backoff(
     retry_count: usize,
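
The heart of the connection change is a three-way race: the outbound request now drains when its lifespan elapses, when a Drain message arrives from the adjacent runner workflow, or when the process receives a termination signal via the new rivet_runtime::TermSignal. The same term_signal is reused in finish_non_critical_draining so that termination also cuts the drain grace period short. Below is a minimal sketch of that race using plain tokio primitives; the mpsc channel and ctrl_c are stand-ins for the workflow's drain subscription and TermSignal, not the actual rivet APIs.

use tokio::{
    signal,
    sync::mpsc,
    time::{sleep, Duration},
};

// Sketch only: races the connection lifespan, a drain request, and process
// termination, mirroring the select added in outbound_req_inner.
async fn wait_for_drain(mut drain_rx: mpsc::Receiver<()>, lifespan: Duration) {
    tokio::select! {
        _ = sleep(lifespan) => {
            // Connection reached its maximum lifespan.
        }
        _ = drain_rx.recv() => {
            // Drain requested by the pool via the adjacent runner workflow.
        }
        _ = signal::ctrl_c() => {
            // Process termination (stand-in for TermSignal::recv).
        }
    }
}

#[tokio::main]
async fn main() {
    let (drain_tx, drain_rx) = mpsc::channel::<()>(1);

    // Simulate the pool requesting a drain after one second.
    tokio::spawn(async move {
        sleep(Duration::from_secs(1)).await;
        let _ = drain_tx.send(()).await;
    });

    wait_for_drain(drain_rx, Duration::from_secs(60)).await;
    println!("draining");
}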

engine/packages/pegboard/src/workflows/serverless/pool.rs

Lines changed: 72 additions & 24 deletions

@@ -1,3 +1,5 @@
+use std::hash::{DefaultHasher, Hash, Hasher};
+
 use futures_util::FutureExt;
 use gas::{db::WorkflowData, prelude::*};
 use rivet_types::{keys, runner_configs::RunnerConfigKind};
@@ -12,54 +14,79 @@ pub struct Input {

 #[derive(Debug, Serialize, Deserialize, Default)]
 struct LifecycleState {
-    runners: Vec<Id>,
+    runners: Vec<RunnerState>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct RunnerState {
+    /// Serverless runner wf id, not normal runner wf id.
+    runner_wf_id: Id,
+    details_hash: u64,
 }

 #[workflow]
 pub async fn pegboard_serverless_pool(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> {
     ctx.loope(LifecycleState::default(), |ctx, state| {
         let input = input.clone();
         async move {
-            // 1. Remove completed connections
+            // Get desired count -> drain and start counts
+            let ReadDesiredOutput::Desired {
+                desired_count,
+                details_hash,
+            } = ctx.activity(ReadDesiredInput {
+                namespace_id: input.namespace_id,
+                runner_name: input.runner_name.clone(),
+            })
+            .await?
+            else {
+                return Ok(Loop::Break(()));
+            };
+
             let completed_runners = ctx
                 .activity(GetCompletedInput {
-                    runners: state.runners.clone(),
+                    runners: state.runners.iter().map(|r| r.runner_wf_id).collect(),
                 })
                 .await?;

-            state.runners.retain(|r| !completed_runners.contains(r));
-
-            // 2. Get desired count -> drain and start counts
-            let ReadDesiredOutput::Desired(desired_count) = ctx
-                .activity(ReadDesiredInput {
-                    namespace_id: input.namespace_id,
-                    runner_name: input.runner_name.clone(),
-                })
-                .await?
-            else {
-                return Ok(Loop::Break(()));
-            };
+            // Remove completed connections
+            state
+                .runners
+                .retain(|r| !completed_runners.contains(&r.runner_wf_id));
+
+            // Remove runners that have an outdated hash. This is done outside of the below draining mechanism
+            // because we drain specific runners, not just a number of runners
+            let (new, outdated) = std::mem::take(&mut state.runners)
+                .into_iter()
+                .partition::<Vec<_>, _>(|r| r.details_hash == details_hash);
+            state.runners = new;
+
+            for runner in outdated {
+                ctx.signal(runner::Drain {})
+                    .to_workflow_id(runner.runner_wf_id)
+                    .send()
+                    .await?;
+            }

             let drain_count = state.runners.len().saturating_sub(desired_count);
             let start_count = desired_count.saturating_sub(state.runners.len());

-            // 3. Drain old runners
+            // Drain unnecessary runners
             if drain_count != 0 {
                 // TODO: Implement smart logic of draining runners with the lowest allocated actors
                 let draining_runners = state.runners.iter().take(drain_count).collect::<Vec<_>>();

-                for wf_id in draining_runners {
+                for runner in draining_runners {
                     ctx.signal(runner::Drain {})
-                        .to_workflow_id(*wf_id)
+                        .to_workflow_id(runner.runner_wf_id)
                         .send()
                         .await?;
                 }
             }

-            // 4. Dispatch new runner workflows
+            // Dispatch new runner workflows
             if start_count != 0 {
                 for _ in 0..start_count {
-                    let wf_id = ctx
+                    let runner_wf_id = ctx
                         .workflow(runner::Input {
                             pool_wf_id: ctx.workflow_id(),
                             namespace_id: input.namespace_id,
@@ -70,14 +97,17 @@ pub async fn pegboard_serverless_pool(ctx: &mut WorkflowCtx, input: &Input) -> R
                         .dispatch()
                         .await?;

-                    state.runners.push(wf_id);
+                    state.runners.push(RunnerState {
+                        runner_wf_id,
+                        details_hash,
+                    });
                 }
             }

             // Wait for Bump or runner update signals until we tick again
             match ctx.listen::<Main>().await? {
                 Main::RunnerDrainStarted(sig) => {
-                    state.runners.retain(|wf_id| *wf_id != sig.runner_wf_id);
+                    state.runners.retain(|r| r.runner_wf_id != sig.runner_wf_id);
                 }
                 Main::Bump(_) => {}
             }
@@ -102,6 +132,7 @@ async fn get_completed(ctx: &ActivityCtx, input: &GetCompletedInput) -> Result<V
         .get_workflows(input.runners.clone())
         .await?
         .into_iter()
+        // When a workflow has output, it means it has completed
         .filter(WorkflowData::has_output)
         .map(|wf| wf.workflow_id)
         .collect())
@@ -115,7 +146,10 @@ struct ReadDesiredInput {

 #[derive(Debug, Serialize, Deserialize)]
 enum ReadDesiredOutput {
-    Desired(usize),
+    Desired {
+        desired_count: usize,
+        details_hash: u64,
+    },
     Stop,
 }

@@ -132,6 +166,9 @@ async fn read_desired(ctx: &ActivityCtx, input: &ReadDesiredInput) -> Result<Rea
     };

     let RunnerConfigKind::Serverless {
+        url,
+        headers,
+
         slots_per_runner,
         min_runners,
         max_runners,
@@ -177,7 +214,18 @@
         .min(max_runners)
         .try_into()?;

-    Ok(ReadDesiredOutput::Desired(desired_count))
+    // Compute consistent hash of serverless details
+    let mut hasher = DefaultHasher::new();
+    url.hash(&mut hasher);
+    let mut sorted_headers = headers.iter().collect::<Vec<_>>();
+    sorted_headers.sort();
+    sorted_headers.hash(&mut hasher);
+    let details_hash = hasher.finish();
+
+    Ok(ReadDesiredOutput::Desired {
+        desired_count,
+        details_hash,
+    })
 }

 #[signal("pegboard_serverless_bump")]
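
The drain-after-config-change behavior hinges on the details hash: read_desired fingerprints the serverless url and headers, the pool stores that hash next to each runner it dispatches, and any runner whose stored hash no longer matches the current one is drained. Below is a standalone sketch of the hashing step, assuming a plain HashMap<String, String> for the headers (the concrete header type in the runner config may differ); sorting the headers before hashing keeps the fingerprint independent of map iteration order.

use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};

// Sketch of the fingerprint computed in read_desired: hash the URL plus the
// headers in a deterministic (sorted) order.
fn details_hash(url: &str, headers: &HashMap<String, String>) -> u64 {
    let mut hasher = DefaultHasher::new();
    url.hash(&mut hasher);

    // Sort so that map iteration order cannot change the hash.
    let mut sorted_headers = headers.iter().collect::<Vec<_>>();
    sorted_headers.sort();
    sorted_headers.hash(&mut hasher);

    hasher.finish()
}

fn main() {
    let mut headers = HashMap::new();
    headers.insert("authorization".to_string(), "Bearer abc".to_string());

    let before = details_hash("https://runners.example.com/start", &headers);

    // Any change to the url or headers yields a different hash, which the
    // pool treats as an outdated runner and drains it.
    headers.insert("x-region".to_string(), "us-east-1".to_string());
    let after = details_hash("https://runners.example.com/start", &headers);
    assert_ne!(before, after);
}

Note that std's DefaultHasher is deterministic within a given toolchain but its algorithm is unspecified across Rust versions, so the value works as a change detector rather than a portable identifier.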

engine/packages/pegboard/src/workflows/serverless/runner.rs

Lines changed: 4 additions & 2 deletions

@@ -20,6 +20,8 @@ impl State {
     }
 }

+/// Runs alongside the connection workflow to process signals. This is required because the connection
+/// workflow cannot listen for signals while in an activity.
 #[workflow]
 pub async fn pegboard_serverless_runner(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> {
     ctx.activity(InitStateInput {}).await?;
@@ -38,12 +40,12 @@ pub async fn pegboard_serverless_runner(ctx: &mut WorkflowCtx, input: &Input) ->

     ctx.activity(MarkAsDrainingInput {}).await?;

-    ctx.signal(connection::DrainSignal {})
+    ctx.signal(connection::Drain {})
         .to_workflow_id(conn_wf_id)
         .send()
         .await?;

-    ctx.msg(connection::DrainMessage {})
+    ctx.msg(connection::Drain {})
         .tag("workflow_id", conn_wf_id)
         .send()
         .await?;
