@@ -30,6 +30,10 @@ pub(crate) struct MpcTlsClient {
3030}
3131
3232enum State {
33+ Start {
34+ mpc : Pin < MpcFuture > ,
35+ inner : Box < InnerState > ,
36+ } ,
3337 Active {
3438 mpc : Pin < MpcFuture > ,
3539 inner : Box < InnerState > ,
@@ -79,7 +83,7 @@ impl MpcTlsClient {
7983
8084 Self {
8185 decrypt,
82- state : State :: Active {
86+ state : State :: Start {
8387 mpc : Box :: into_pin ( mpc) ,
8488 inner : Box :: new ( inner) ,
8589 } ,
@@ -239,6 +243,13 @@ impl TlsClient for MpcTlsClient {
239243
240244 fn poll ( & mut self , cx : & mut std:: task:: Context ) -> Poll < Result < TlsOutput , Self :: Error > > {
241245 match std:: mem:: replace ( & mut self . state , State :: Error ) {
246+ State :: Start { mpc, inner } => {
247+ self . state = State :: Busy {
248+ mpc,
249+ fut : Box :: pin ( inner. start ( ) ) ,
250+ } ;
251+ self . poll ( cx)
252+ }
242253 State :: Active { mpc, inner } => {
243254 trace ! ( "inner client is active" ) ;
244255
@@ -381,6 +392,18 @@ struct InnerState {
381392}
382393
383394impl InnerState {
395+ #[ instrument( parent = & self . span, level = "debug" , skip_all, err) ]
396+ async fn start ( mut self : Box < Self > ) -> Result < Box < Self > , ProverError > {
397+ self . tls . start ( ) . await ?;
398+ Ok ( self )
399+ }
400+
401+ #[ instrument( parent = & self . span, level = "trace" , skip_all, err) ]
402+ async fn run ( mut self : Box < Self > ) -> Result < Box < Self > , ProverError > {
403+ self . tls . process_new_packets ( ) . await ?;
404+ Ok ( self )
405+ }
406+
384407 #[ instrument( parent = & self . span, level = "debug" , skip_all, err) ]
385408 async fn set_decrypt ( self : Box < Self > , enable : bool ) -> Result < Box < Self > , ProverError > {
386409 self . mpc_ctrl . enable_decryption ( enable) . await ?;
@@ -413,12 +436,6 @@ impl InnerState {
413436 self . run ( ) . await
414437 }
415438
416- #[ instrument( parent = & self . span, level = "trace" , skip_all, err) ]
417- async fn run ( mut self : Box < Self > ) -> Result < Box < Self > , ProverError > {
418- self . tls . process_new_packets ( ) . await ?;
419- Ok ( self )
420- }
421-
422439 #[ instrument( parent = & self . span, level = "debug" , skip_all, err) ]
423440 async fn finalize (
424441 self ,
0 commit comments