Skip to content

Commit b6b0922

Browse files
committed
add start state
1 parent 619ad9c commit b6b0922

File tree

1 file changed

+24
-7
lines changed
  • crates/tlsn/src/prover/client

1 file changed

+24
-7
lines changed

crates/tlsn/src/prover/client/mpc.rs

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ pub(crate) struct MpcTlsClient {
3030
}
3131

3232
enum State {
33+
Start {
34+
mpc: Pin<MpcFuture>,
35+
inner: Box<InnerState>,
36+
},
3337
Active {
3438
mpc: Pin<MpcFuture>,
3539
inner: Box<InnerState>,
@@ -79,7 +83,7 @@ impl MpcTlsClient {
7983

8084
Self {
8185
decrypt,
82-
state: State::Active {
86+
state: State::Start {
8387
mpc: Box::into_pin(mpc),
8488
inner: Box::new(inner),
8589
},
@@ -239,6 +243,13 @@ impl TlsClient for MpcTlsClient {
239243

240244
fn poll(&mut self, cx: &mut std::task::Context) -> Poll<Result<TlsOutput, Self::Error>> {
241245
match std::mem::replace(&mut self.state, State::Error) {
246+
State::Start { mpc, inner } => {
247+
self.state = State::Busy {
248+
mpc,
249+
fut: Box::pin(inner.start()),
250+
};
251+
self.poll(cx)
252+
}
242253
State::Active { mpc, inner } => {
243254
trace!("inner client is active");
244255

@@ -381,6 +392,18 @@ struct InnerState {
381392
}
382393

383394
impl InnerState {
395+
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
396+
async fn start(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
397+
self.tls.start().await?;
398+
Ok(self)
399+
}
400+
401+
#[instrument(parent = &self.span, level = "trace", skip_all, err)]
402+
async fn run(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
403+
self.tls.process_new_packets().await?;
404+
Ok(self)
405+
}
406+
384407
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
385408
async fn set_decrypt(self: Box<Self>, enable: bool) -> Result<Box<Self>, ProverError> {
386409
self.mpc_ctrl.enable_decryption(enable).await?;
@@ -413,12 +436,6 @@ impl InnerState {
413436
self.run().await
414437
}
415438

416-
#[instrument(parent = &self.span, level = "trace", skip_all, err)]
417-
async fn run(mut self: Box<Self>) -> Result<Box<Self>, ProverError> {
418-
self.tls.process_new_packets().await?;
419-
Ok(self)
420-
}
421-
422439
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
423440
async fn finalize(
424441
self,

0 commit comments

Comments
 (0)