diff --git a/Cargo.lock b/Cargo.lock index 99ac1effbd..3a4ad02a28 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7176,6 +7176,7 @@ name = "tlsn" version = "0.1.0-alpha.14-pre" dependencies = [ "aes 0.8.4", + "bytes", "ctr 0.9.2", "futures", "ghash 0.5.1", @@ -7195,6 +7196,7 @@ dependencies = [ "mpz-zk", "once_cell", "opaque-debug", + "pin-project-lite", "rand 0.9.2", "rangeset", "rstest", @@ -7212,7 +7214,6 @@ dependencies = [ "tlsn-server-fixture", "tlsn-server-fixture-certs", "tlsn-tls-client", - "tlsn-tls-client-async", "tlsn-tls-core", "tokio", "tokio-util", @@ -7528,7 +7529,6 @@ dependencies = [ "tlsn-key-exchange", "tlsn-tls-backend", "tlsn-tls-client", - "tlsn-tls-client-async", "tlsn-tls-core", "tokio", "tokio-util", @@ -7603,26 +7603,6 @@ dependencies = [ "webpki-roots 1.0.3", ] -[[package]] -name = "tlsn-tls-client-async" -version = "0.1.0-alpha.14-pre" -dependencies = [ - "bytes", - "futures", - "http-body-util", - "hyper", - "hyper-util", - "rstest", - "rustls-pki-types", - "rustls-webpki 0.103.7", - "thiserror 1.0.69", - "tls-server-fixture", - "tlsn-tls-client", - "tokio", - "tokio-util", - "tracing", -] - [[package]] name = "tlsn-tls-core" version = "0.1.0-alpha.14-pre" @@ -7669,7 +7649,6 @@ dependencies = [ "tlsn", "tlsn-core", "tlsn-server-fixture-certs", - "tlsn-tls-client-async", "tlsn-tls-core", "tracing", "tracing-subscriber", diff --git a/Cargo.toml b/Cargo.toml index 7681dd0119..cb07a4c1f8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,6 @@ members = [ "crates/server-fixture/server", "crates/tls/backend", "crates/tls/client", - "crates/tls/client-async", "crates/tls/core", "crates/mpc-tls", "crates/tls/server-fixture", @@ -57,7 +56,6 @@ tlsn-server-fixture = { path = "crates/server-fixture/server" } tlsn-server-fixture-certs = { path = "crates/server-fixture/certs" } tlsn-tls-backend = { path = "crates/tls/backend" } tlsn-tls-client = { path = "crates/tls/client" } -tlsn-tls-client-async = { path = "crates/tls/client-async" } tlsn-tls-core = { path = "crates/tls/core" } tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" } tlsn-harness-core = { path = "crates/harness/core" } diff --git a/crates/examples/attestation/prove.rs b/crates/examples/attestation/prove.rs index 7b9805f46f..a3fd23f320 100644 --- a/crates/examples/attestation/prove.rs +++ b/crates/examples/attestation/prove.rs @@ -124,7 +124,7 @@ async fn prover( // Bind the prover to the server connection. let (tls_connection, prover_fut) = prover - .connect( + .connect_with( TlsClientConfig::builder() .server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?)) // Create a root certificate store with the server-fixture's self-signed diff --git a/crates/examples/interactive/interactive.rs b/crates/examples/interactive/interactive.rs index dc3aef504b..7dbafc6a73 100644 --- a/crates/examples/interactive/interactive.rs +++ b/crates/examples/interactive/interactive.rs @@ -102,7 +102,7 @@ async fn prover( // Bind the prover to the server connection. let (tls_connection, prover_fut) = prover - .connect( + .connect_with( TlsClientConfig::builder() .server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?)) // Create a root certificate store with the server-fixture's self-signed diff --git a/crates/examples/interactive_zk/prover.rs b/crates/examples/interactive_zk/prover.rs index 63497f810e..fd6280d631 100644 --- a/crates/examples/interactive_zk/prover.rs +++ b/crates/examples/interactive_zk/prover.rs @@ -89,7 +89,7 @@ pub async fn prover( // Bind the prover to the server connection. let (tls_connection, prover_fut) = prover - .connect( + .connect_with( TlsClientConfig::builder() .server_name(ServerName::Dns(SERVER_DOMAIN.try_into()?)) // Create a root certificate store with the server-fixture's self-signed diff --git a/crates/harness/executor/src/bench/prover.rs b/crates/harness/executor/src/bench/prover.rs index ece42b7822..587d13aebf 100644 --- a/crates/harness/executor/src/bench/prover.rs +++ b/crates/harness/executor/src/bench/prover.rs @@ -59,7 +59,7 @@ pub async fn bench_prover(provider: &IoProvider, config: &Bench) -> Result (MpcTlsLeader, MpcTlsFollower) { - let mut rng = StdRng::seed_from_u64(0); - - let (mut mt_a, mut mt_b) = test_mt_context(8); - - let ctx_a = futures::executor::block_on(mt_a.new_context()).unwrap(); - let ctx_b = futures::executor::block_on(mt_b.new_context()).unwrap(); - - let delta_a = Delta::new(Block::random(&mut rng)); - let delta_b = Delta::new(Block::random(&mut rng)); - - let (rcot_send_a, rcot_recv_b) = ideal_rcot(Block::random(&mut rng), delta_a.into_inner()); - let (rcot_send_b, rcot_recv_a) = ideal_rcot(Block::random(&mut rng), delta_b.into_inner()); - - let rcot_send_a = SharedRCOTSender::new(rcot_send_a); - let rcot_send_b = SharedRCOTSender::new(rcot_send_b); - let rcot_recv_a = SharedRCOTReceiver::new(rcot_recv_a); - let rcot_recv_b = SharedRCOTReceiver::new(rcot_recv_b); - - let mpc_a = Arc::new(Mutex::new(IdealVm::new())); - let mpc_b = Arc::new(Mutex::new(IdealVm::new())); - - let leader = MpcTlsLeader::new( - config.clone(), - ctx_a, - mpc_a, - (rcot_send_a.clone(), rcot_send_a.clone(), rcot_send_a), - rcot_recv_a, - ); - - let follower = MpcTlsFollower::new( - config, - ctx_b, - mpc_b, - rcot_send_b, - (rcot_recv_b.clone(), rcot_recv_b.clone(), rcot_recv_b), - ); - - (leader, follower) -} diff --git a/crates/tls/client-async/Cargo.toml b/crates/tls/client-async/Cargo.toml deleted file mode 100644 index bd4475cb6b..0000000000 --- a/crates/tls/client-async/Cargo.toml +++ /dev/null @@ -1,39 +0,0 @@ -[package] -name = "tlsn-tls-client-async" -authors = ["TLSNotary Team"] -description = "An async TLS client for TLSNotary" -keywords = ["tls", "mpc", "2pc", "client", "async"] -categories = ["cryptography"] -license = "MIT OR Apache-2.0" -version = "0.1.0-alpha.14-pre" -edition = "2021" - -[lints] -workspace = true - -[lib] -name = "tls_client_async" - -[features] -default = ["tracing"] -tracing = ["dep:tracing"] - -[dependencies] -tlsn-tls-client = { workspace = true } - -bytes = { workspace = true } -futures = { workspace = true } -thiserror = { workspace = true } -tokio-util = { workspace = true, features = ["io", "compat"] } -tracing = { workspace = true, optional = true } - -[dev-dependencies] -tls-server-fixture = { workspace = true } - -http-body-util = { workspace = true } -hyper = { workspace = true, features = ["client", "http1"] } -hyper-util = { workspace = true, features = ["full"] } -rstest = { workspace = true } -tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros"] } -rustls-webpki = { workspace = true } -rustls-pki-types = { workspace = true } diff --git a/crates/tls/client-async/src/conn.rs b/crates/tls/client-async/src/conn.rs deleted file mode 100644 index aec79c92b6..0000000000 --- a/crates/tls/client-async/src/conn.rs +++ /dev/null @@ -1,89 +0,0 @@ -use bytes::Bytes; -use futures::{ - channel::mpsc::{Receiver, SendError, Sender}, - sink::SinkMapErr, - AsyncRead, AsyncWrite, SinkExt, -}; -use std::{ - io::{Error as IoError, ErrorKind as IoErrorKind}, - pin::Pin, - task::{Context, Poll}, -}; -use tokio_util::{ - compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}, - io::{CopyToBytes, SinkWriter, StreamReader}, -}; - -type CompatSinkWriter = - Compat, fn(SendError) -> IoError>>>>; - -/// A TLS connection to a server. -/// -/// This type implements `AsyncRead` and `AsyncWrite` and can be used to -/// communicate with a server using TLS. -/// -/// # Note -/// -/// This connection is closed on a best-effort basis if this is dropped. To -/// ensure a clean close, you should call -/// [`AsyncWriteExt::close`](futures::io::AsyncWriteExt::close) to close the -/// connection. -#[derive(Debug)] -pub struct TlsConnection { - /// The data to be transmitted to the server is sent to this sink. - tx_sender: CompatSinkWriter, - /// The data to be received from the server is received from this stream. - rx_receiver: Compat>, Bytes>>, -} - -impl TlsConnection { - /// Creates a new TLS connection. - pub(crate) fn new( - tx_sender: Sender, - rx_receiver: Receiver>, - ) -> Self { - fn convert_error(err: SendError) -> IoError { - if err.is_disconnected() { - IoErrorKind::BrokenPipe.into() - } else { - IoErrorKind::WouldBlock.into() - } - } - - Self { - tx_sender: SinkWriter::new(CopyToBytes::new( - tx_sender.sink_map_err(convert_error as fn(SendError) -> IoError), - )) - .compat_write(), - rx_receiver: StreamReader::new(rx_receiver).compat(), - } - } -} - -impl AsyncRead for TlsConnection { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - Pin::new(&mut self.rx_receiver).poll_read(cx, buf) - } -} - -impl AsyncWrite for TlsConnection { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.tx_sender).poll_write(cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.tx_sender).poll_flush(cx) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.tx_sender).poll_close(cx) - } -} diff --git a/crates/tls/client-async/src/lib.rs b/crates/tls/client-async/src/lib.rs deleted file mode 100644 index ca7ffaacc3..0000000000 --- a/crates/tls/client-async/src/lib.rs +++ /dev/null @@ -1,269 +0,0 @@ -//! Provides a TLS client which exposes an async socket. -//! -//! This library provides the [bind_client] function which attaches a TLS client -//! to a socket connection and then exposes a [TlsConnection] object, which -//! provides an async socket API for reading and writing cleartext. The TLS -//! client will then automatically encrypt and decrypt traffic and forward that -//! to the provided socket. - -#![deny(missing_docs, unreachable_pub, unused_must_use)] -#![deny(clippy::all)] -#![forbid(unsafe_code)] - -mod conn; - -use bytes::{Buf, Bytes}; -use futures::{ - channel::mpsc, future::Fuse, select_biased, stream::Next, AsyncRead, AsyncReadExt, AsyncWrite, - AsyncWriteExt, Future, FutureExt, SinkExt, StreamExt, -}; - -use std::{ - pin::Pin, - task::{Context, Poll}, -}; - -#[cfg(feature = "tracing")] -use tracing::{debug, debug_span, error, trace, warn, Instrument}; - -use tls_client::ClientConnection; - -pub use conn::TlsConnection; - -const RX_TLS_BUF_SIZE: usize = 1 << 13; // 8 KiB -const RX_BUF_SIZE: usize = 1 << 13; // 8 KiB - -/// An error that can occur during a TLS connection. -#[allow(missing_docs)] -#[derive(Debug, thiserror::Error)] -pub enum ConnectionError { - #[error(transparent)] - TlsError(#[from] tls_client::Error), - #[error(transparent)] - IOError(#[from] std::io::Error), -} - -/// Closed connection data. -#[derive(Debug)] -pub struct ClosedConnection { - /// The connection for the client - pub client: ClientConnection, - /// Sent plaintext bytes - pub sent: Vec, - /// Received plaintext bytes - pub recv: Vec, -} - -/// A future which runs the TLS connection to completion. -/// -/// This future must be polled in order for the connection to make progress. -#[must_use = "futures do nothing unless polled"] -pub struct ConnectionFuture { - fut: Pin> + Send>>, -} - -impl Future for ConnectionFuture { - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.fut.poll_unpin(cx) - } -} - -/// Binds a client connection to the provided socket. -/// -/// Returns a connection handle and a future which runs the connection to -/// completion. -/// -/// # Errors -/// -/// Any connection errors that occur will be returned from the future, not -/// [`TlsConnection`]. -pub fn bind_client( - socket: T, - mut client: ClientConnection, -) -> (TlsConnection, ConnectionFuture) { - let (tx_sender, mut tx_receiver) = mpsc::channel(1 << 14); - let (mut rx_sender, rx_receiver) = mpsc::channel(1 << 14); - - let conn = TlsConnection::new(tx_sender, rx_receiver); - - let fut = async move { - client.start().await?; - let mut notify = client.get_notify().await?; - - let (mut server_rx, mut server_tx) = socket.split(); - - let mut rx_tls_buf = [0u8; RX_TLS_BUF_SIZE]; - let mut rx_buf = [0u8; RX_BUF_SIZE]; - - let mut handshake_done = false; - let mut client_closed = false; - let mut server_closed = false; - - let mut sent = Vec::with_capacity(1024); - let mut recv = Vec::with_capacity(1024); - - let mut rx_tls_fut = server_rx.read(&mut rx_tls_buf).fuse(); - // We don't start writing application data until the handshake is complete. - let mut tx_recv_fut: Fuse>> = Fuse::terminated(); - - // Runs both the tx and rx halves of the connection to completion. - // This loop does not terminate until the *SERVER* closes the connection and - // we've processed all received data. If an error occurs, the `TlsConnection` - // channels will be closed and the error will be returned from this future. - 'conn: loop { - // Write all pending TLS data to the server. - if client.wants_write() && !client_closed { - #[cfg(feature = "tracing")] - trace!("client wants to write"); - while client.wants_write() { - let _sent = client.write_tls_async(&mut server_tx).await?; - #[cfg(feature = "tracing")] - trace!("sent {} tls bytes to server", _sent); - } - server_tx.flush().await?; - } - - // Forward received plaintext to `TlsConnection`. - while !client.plaintext_is_empty() { - let read = client.read_plaintext(&mut rx_buf)?; - recv.extend(&rx_buf[..read]); - // Ignore if the receiver has hung up. - _ = rx_sender - .send(Ok(Bytes::copy_from_slice(&rx_buf[..read]))) - .await; - #[cfg(feature = "tracing")] - trace!("forwarded {} plaintext bytes to conn", read); - } - - if !client.is_handshaking() && !handshake_done { - #[cfg(feature = "tracing")] - debug!("handshake complete"); - handshake_done = true; - // Start reading application data that needs to be transmitted from the - // `TlsConnection`. - tx_recv_fut = tx_receiver.next().fuse(); - } - - if server_closed && client.plaintext_is_empty() && client.is_empty().await? { - break 'conn; - } - - select_biased! { - // Reads TLS data from the server and writes it into the client. - received = &mut rx_tls_fut => { - let received = received?; - #[cfg(feature = "tracing")] - trace!("received {} tls bytes from server", received); - - // Loop until we've processed all the data we received in this read. - // Note that we must make one iteration even if `received == 0`. - let mut processed = 0; - let mut reader = rx_tls_buf[..received].reader(); - loop { - processed += client.read_tls(&mut reader)?; - client.process_new_packets().await?; - - debug_assert!(processed <= received); - if processed >= received { - break; - } - } - - #[cfg(feature = "tracing")] - trace!("processed {} tls bytes from server", processed); - - // By convention if `AsyncRead::read` returns 0, it means EOF, i.e. the peer - // has closed the socket. - if received == 0 { - #[cfg(feature = "tracing")] - debug!("server closed connection"); - server_closed = true; - client.server_closed().await?; - // Do not read from the socket again. - rx_tls_fut = Fuse::terminated(); - } else { - // Reset the read future so next iteration we can read again. - rx_tls_fut = server_rx.read(&mut rx_tls_buf).fuse(); - } - } - // If we receive None from `TlsConnection`, it has closed, so we - // send a close_notify to the server. - data = &mut tx_recv_fut => { - if let Some(data) = data { - #[cfg(feature = "tracing")] - trace!("writing {} plaintext bytes to client", data.len()); - - sent.extend(&data); - client - .write_all_plaintext(&data)?; - client.process_new_packets().await?; - - tx_recv_fut = tx_receiver.next().fuse(); - } else { - if !server_closed { - if let Err(e) = send_close_notify(&mut client, &mut server_tx).await { - #[cfg(feature = "tracing")] - warn!("failed to send close_notify to server: {}", e); - } - } - - client_closed = true; - - tx_recv_fut = Fuse::terminated(); - } - } - // Waits for a notification from the backend that it is ready to decrypt data. - _ = &mut notify => { - #[cfg(feature = "tracing")] - trace!("backend is ready to decrypt"); - - client.process_new_packets().await?; - } - } - } - - #[cfg(feature = "tracing")] - debug!("client shutdown"); - - _ = server_tx.close().await; - tx_receiver.close(); - rx_sender.close_channel(); - - #[cfg(feature = "tracing")] - trace!( - "server close notify: {}, sent: {}, recv: {}", - client.received_close_notify(), - sent.len(), - recv.len() - ); - - Ok(ClosedConnection { client, sent, recv }) - }; - - #[cfg(feature = "tracing")] - let fut = fut.instrument(debug_span!("tls_connection")); - - let fut = ConnectionFuture { fut: Box::pin(fut) }; - - (conn, fut) -} - -async fn send_close_notify( - client: &mut ClientConnection, - server_tx: &mut (impl AsyncWrite + Unpin), -) -> Result<(), ConnectionError> { - #[cfg(feature = "tracing")] - trace!("sending close_notify to server"); - client.send_close_notify().await?; - client.process_new_packets().await?; - - // Flush all remaining plaintext - while client.wants_write() { - client.write_tls_async(server_tx).await?; - } - server_tx.flush().await?; - - Ok(()) -} diff --git a/crates/tls/client-async/tests/test.rs b/crates/tls/client-async/tests/test.rs deleted file mode 100644 index f588665dbc..0000000000 --- a/crates/tls/client-async/tests/test.rs +++ /dev/null @@ -1,438 +0,0 @@ -use std::{str, sync::Arc}; - -use core::future::Future; -use futures::{AsyncReadExt, AsyncWriteExt}; -use http_body_util::{BodyExt as _, Full}; -use hyper::{body::Bytes, Request, StatusCode}; -use hyper_util::rt::TokioIo; -use rstest::{fixture, rstest}; -use rustls_pki_types::CertificateDer; -use tls_client::{ClientConfig, ClientConnection, RustCryptoBackend, ServerName}; -use tls_client_async::{bind_client, ClosedConnection, ConnectionError, TlsConnection}; -use tls_server_fixture::{ - bind_test_server, bind_test_server_hyper, APP_RECORD_LENGTH, CA_CERT_DER, CLOSE_DELAY, - SERVER_DOMAIN, -}; -use tokio::task::JoinHandle; -use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; -use webpki::anchor_from_trusted_cert; - -const CA_CERT: CertificateDer = CertificateDer::from_slice(CA_CERT_DER); - -// An established client TLS connection -struct TlsFixture { - client_tls_conn: TlsConnection, - // a handle that must be `.await`ed to get the result of a TLS connection - closed_tls_task: JoinHandle>, -} - -// Sets up a TLS connection between client and server and sends a hello message -#[fixture] -async fn set_up_tls() -> TlsFixture { - let (client_socket, server_socket) = tokio::io::duplex(1 << 16); - - let _server_task = tokio::spawn(bind_test_server(server_socket.compat())); - - let mut root_store = tls_client::RootCertStore::empty(); - root_store - .roots - .push(anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned()); - let config = ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_store) - .with_no_client_auth(); - let client = ClientConnection::new( - Arc::new(config), - Box::new(RustCryptoBackend::new()), - ServerName::try_from(SERVER_DOMAIN).unwrap(), - ) - .unwrap(); - - let (mut client_tls_conn, tls_fut) = bind_client(client_socket.compat(), client); - - let closed_tls_task = tokio::spawn(tls_fut); - - client_tls_conn - .write_all(&pad("expecting you to send back hello".to_string())) - .await - .unwrap(); - - // give the server some time to respond - std::thread::sleep(std::time::Duration::from_millis(10)); - - let mut plaintext = vec![0u8; 320]; - let n = client_tls_conn.read(&mut plaintext).await.unwrap(); - let s = str::from_utf8(&plaintext[0..n]).unwrap(); - - assert_eq!(s, "hello"); - - TlsFixture { - client_tls_conn, - closed_tls_task, - } -} - -// Expect the async tls client wrapped in `hyper::client` to make a successful -// request and receive the expected response -#[tokio::test] -async fn test_hyper_ok() { - let (client_socket, server_socket) = tokio::io::duplex(1 << 16); - - let server_task = tokio::spawn(bind_test_server_hyper(server_socket.compat())); - - let mut root_store = tls_client::RootCertStore::empty(); - root_store - .roots - .push(anchor_from_trusted_cert(&CA_CERT).unwrap().to_owned()); - let config = ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_store) - .with_no_client_auth(); - let client = ClientConnection::new( - Arc::new(config), - Box::new(RustCryptoBackend::new()), - ServerName::try_from(SERVER_DOMAIN).unwrap(), - ) - .unwrap(); - - let (conn, tls_fut) = bind_client(client_socket.compat(), client); - - let closed_tls_task = tokio::spawn(tls_fut); - - let (mut request_sender, connection) = - hyper::client::conn::http1::handshake(TokioIo::new(conn.compat())) - .await - .unwrap(); - - tokio::spawn(connection); - - let request = Request::builder() - .uri(format!("https://{SERVER_DOMAIN}/echo")) - .header("Host", SERVER_DOMAIN) - .header("Connection", "close") - .method("POST") - .body(Full::::new("hello".into())) - .unwrap(); - - let response = request_sender.send_request(request).await.unwrap(); - - assert!(response.status() == StatusCode::OK); - - // Process the response body - response.into_body().collect().await.unwrap().to_bytes(); - - let _ = server_task.await.unwrap(); - - let closed_conn = closed_tls_task.await.unwrap().unwrap(); - - assert!(closed_conn.client.received_close_notify()); -} - -// Expect a clean TLS connection closure when server responds to the client's -// close_notify but doesn't close the socket -#[rstest] -#[tokio::test] -async fn test_ok_server_no_socket_close(set_up_tls: impl Future) { - let TlsFixture { - mut client_tls_conn, - closed_tls_task, - } = set_up_tls.await; - - // instruct the server to send close_notify back to us after 10 ms - client_tls_conn - .write_all(&pad("send_close_notify".to_string())) - .await - .unwrap(); - client_tls_conn.flush().await.unwrap(); - - // closing `client_tls_conn` will cause close_notify to be sent by the client; - client_tls_conn.close().await.unwrap(); - - let closed_conn = closed_tls_task.await.unwrap().unwrap(); - - assert!(closed_conn.client.received_close_notify()); -} - -// Expect a clean TLS connection closure when server responds to the client's -// close_notify AND also closes the socket -#[rstest] -#[tokio::test] -async fn test_ok_server_socket_close(set_up_tls: impl Future) { - let TlsFixture { - mut client_tls_conn, - closed_tls_task, - } = set_up_tls.await; - - // instruct the server to send close_notify back to us AND close the socket - // after 10 ms - client_tls_conn - .write_all(&pad("send_close_notify_and_close_socket".to_string())) - .await - .unwrap(); - client_tls_conn.flush().await.unwrap(); - - // closing `client_tls_conn` will cause close_notify to be sent by the client; - client_tls_conn.close().await.unwrap(); - - let closed_conn = closed_tls_task.await.unwrap().unwrap(); - - assert!(closed_conn.client.received_close_notify()); -} - -// Expect a clean TLS connection closure when server sends close_notify first -// but doesn't close the socket -#[rstest] -#[tokio::test] -async fn test_ok_server_close_notify(set_up_tls: impl Future) { - let TlsFixture { - mut client_tls_conn, - closed_tls_task, - } = set_up_tls.await; - - // instruct the server to send close_notify back to us after 10 ms - client_tls_conn - .write_all(&pad("send_close_notify".to_string())) - .await - .unwrap(); - client_tls_conn.flush().await.unwrap(); - - // give enough time for server's close_notify to arrive - tokio::time::sleep(std::time::Duration::from_millis(20)).await; - - client_tls_conn.close().await.unwrap(); - - let closed_conn = closed_tls_task.await.unwrap().unwrap(); - - assert!(closed_conn.client.received_close_notify()); -} - -// Expect a clean TLS connection closure when server sends close_notify first -// AND also closes the socket -#[rstest] -#[tokio::test] -async fn test_ok_server_close_notify_and_socket_close( - set_up_tls: impl Future, -) { - let TlsFixture { - mut client_tls_conn, - closed_tls_task, - } = set_up_tls.await; - - // instruct the server to send close_notify back to us after 10 ms - client_tls_conn - .write_all(&pad("send_close_notify_and_close_socket".to_string())) - .await - .unwrap(); - client_tls_conn.flush().await.unwrap(); - - // give enough time for server's close_notify to arrive - tokio::time::sleep(std::time::Duration::from_millis(20)).await; - - client_tls_conn.close().await.unwrap(); - - let closed_conn = closed_tls_task.await.unwrap().unwrap(); - - assert!(closed_conn.client.received_close_notify()); -} - -// Expect to be able to read the data after server closes the socket abruptly -#[rstest] -#[tokio::test] -async fn test_ok_read_after_close(set_up_tls: impl Future) { - let TlsFixture { - mut client_tls_conn, - .. - } = set_up_tls.await; - - // instruct the server to send us a hello message - client_tls_conn - .write_all(&pad("send a hello message".to_string())) - .await - .unwrap(); - client_tls_conn.flush().await.unwrap(); - - // instruct the server to close the socket - client_tls_conn - .write_all(&pad("close_socket".to_string())) - .await - .unwrap(); - client_tls_conn.flush().await.unwrap(); - - // give enough time to close the socket - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - - // try to read some more data - let mut buf = vec![0u8; 10]; - let n = client_tls_conn.read(&mut buf).await.unwrap(); - - assert_eq!(std::str::from_utf8(&buf[0..n]).unwrap(), "hello"); -} - -// Expect there to be no error when server DOES NOT send close_notify but just -// closes the socket -#[rstest] -#[tokio::test] -async fn test_ok_server_no_close_notify(set_up_tls: impl Future) { - let TlsFixture { - mut client_tls_conn, - closed_tls_task, - } = set_up_tls.await; - - // instruct the server to close the socket - client_tls_conn - .write_all(&pad("close_socket".to_string())) - .await - .unwrap(); - client_tls_conn.flush().await.unwrap(); - - // give enough time to close the socket - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - - client_tls_conn.close().await.unwrap(); - - let closed_conn = closed_tls_task.await.unwrap().unwrap(); - - assert!(!closed_conn.client.received_close_notify()); -} - -// Expect to register a delay when the server delays closing the socket -#[rstest] -#[tokio::test] -async fn test_ok_delay_close(set_up_tls: impl Future) { - let TlsFixture { - mut client_tls_conn, - closed_tls_task, - } = set_up_tls.await; - - client_tls_conn - .write_all(&pad("must_delay_when_closing".to_string())) - .await - .unwrap(); - client_tls_conn.flush().await.unwrap(); - - // closing `client_tls_conn` will cause close_notify to be sent by the client - client_tls_conn.close().await.unwrap(); - - use std::time::Instant; - let now = Instant::now(); - // this will resolve when the server stops delaying closing the socket - let closed_conn = closed_tls_task.await.unwrap().unwrap(); - let elapsed = now.elapsed(); - - // the elapsed time must be roughly equal to the server's delay - // (give or take timing variations) - assert!(elapsed.as_millis() as u64 > CLOSE_DELAY - 50); - - assert!(!closed_conn.client.received_close_notify()); -} - -// Expect client to error when server sends a corrupted message -#[rstest] -#[tokio::test] -async fn test_err_corrupted(set_up_tls: impl Future) { - let TlsFixture { - mut client_tls_conn, - closed_tls_task, - } = set_up_tls.await; - - // instruct the server to send a corrupted message - client_tls_conn - .write_all(&pad("send_corrupted_message".to_string())) - .await - .unwrap(); - client_tls_conn.flush().await.unwrap(); - - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - client_tls_conn.close().await.unwrap(); - - assert_eq!( - closed_tls_task.await.unwrap().err().unwrap().to_string(), - "received corrupt message" - ); -} - -// Expect client to error when server sends a TLS record with a bad MAC -#[rstest] -#[tokio::test] -async fn test_err_bad_mac(set_up_tls: impl Future) { - let TlsFixture { - mut client_tls_conn, - closed_tls_task, - } = set_up_tls.await; - - // instruct the server to send us a TLS record with a bad MAC - client_tls_conn - .write_all(&pad("send_record_with_bad_mac".to_string())) - .await - .unwrap(); - client_tls_conn.flush().await.unwrap(); - - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - client_tls_conn.close().await.unwrap(); - - assert_eq!( - closed_tls_task.await.unwrap().err().unwrap().to_string(), - "backend error: Decryption error: \"aead::Error\"" - ); -} - -// Expect client to error when server sends a fatal alert -#[rstest] -#[tokio::test] -async fn test_err_alert(set_up_tls: impl Future) { - let TlsFixture { - mut client_tls_conn, - closed_tls_task, - } = set_up_tls.await; - - // instruct the server to send us a TLS record with a bad MAC - client_tls_conn - .write_all(&pad("send_alert".to_string())) - .await - .unwrap(); - client_tls_conn.flush().await.unwrap(); - - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - client_tls_conn.close().await.unwrap(); - - assert_eq!( - closed_tls_task.await.unwrap().err().unwrap().to_string(), - "received fatal alert: BadRecordMac" - ); -} - -// Expect an error when trying to write data to a connection which server closed -// abruptly -#[rstest] -#[tokio::test] -async fn test_err_write_after_close(set_up_tls: impl Future) { - let TlsFixture { - mut client_tls_conn, - .. - } = set_up_tls.await; - - // instruct the server to close the socket - client_tls_conn - .write_all(&pad("close_socket".to_string())) - .await - .unwrap(); - client_tls_conn.flush().await.unwrap(); - - // give enough time to close the socket - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - - // try to send some more data - let res = client_tls_conn - .write_all(&pad("more data".to_string())) - .await; - - assert_eq!(res.err().unwrap().kind(), std::io::ErrorKind::BrokenPipe); -} - -// Converts a string into a slice zero-padded to APP_RECORD_LENGTH -fn pad(s: String) -> Vec { - assert!(s.len() <= APP_RECORD_LENGTH); - let mut buf = vec![0u8; APP_RECORD_LENGTH]; - buf[..s.len()].copy_from_slice(s.as_bytes()); - buf -} diff --git a/crates/tls/client/src/conn.rs b/crates/tls/client/src/conn.rs index 11854619fc..491ef2871f 100644 --- a/crates/tls/client/src/conn.rs +++ b/crates/tls/client/src/conn.rs @@ -227,6 +227,7 @@ impl ConnectionCommon { /// Signals that the server has closed the connection. pub async fn server_closed(&mut self) -> Result<(), Error> { + self.common_state.has_seen_eof = true; self.common_state.backend.server_closed().await?; Ok(()) } diff --git a/crates/tlsn/Cargo.toml b/crates/tlsn/Cargo.toml index 18d5926918..de9aed994d 100644 --- a/crates/tlsn/Cargo.toml +++ b/crates/tlsn/Cargo.toml @@ -21,7 +21,6 @@ tlsn-attestation = { workspace = true } tlsn-core = { workspace = true } tlsn-deap = { workspace = true } tlsn-tls-client = { workspace = true } -tlsn-tls-client-async = { workspace = true } tlsn-tls-core = { workspace = true } tlsn-mpc-tls = { workspace = true } tlsn-cipher = { workspace = true } @@ -44,6 +43,7 @@ mpz-zk = { workspace = true } mpz-ideal-vm = { workspace = true } aes = { workspace = true } +bytes = { workspace = true } ctr = { workspace = true } futures = { workspace = true } opaque-debug = { workspace = true } @@ -57,6 +57,7 @@ serde = { workspace = true, features = ["derive"] } ghash = { workspace = true } semver = { workspace = true, features = ["serde"] } once_cell = { workspace = true } +pin-project-lite = { workspace = true } rangeset = { workspace = true } webpki-roots = { workspace = true } diff --git a/crates/tlsn/src/prover.rs b/crates/tlsn/src/prover.rs index 574fcd77a3..a4a97335f1 100644 --- a/crates/tlsn/src/prover.rs +++ b/crates/tlsn/src/prover.rs @@ -1,10 +1,14 @@ //! Prover. mod client; +mod conn; +mod control; mod error; mod prove; pub mod state; +pub use conn::{ConnectionFuture, TlsConnection}; +pub use control::ProverControl; pub use error::ProverError; pub use tlsn_core::ProverOutput; @@ -21,7 +25,7 @@ use futures::{AsyncRead, AsyncWrite, FutureExt, TryFutureExt}; use rustls_pki_types::CertificateDer; use serio::{SinkExt, stream::IoStreamExt}; use std::{ - sync::Arc, + sync::{Arc, Mutex}, task::{Context, Poll}, }; use tls_client::{ClientConnection, ServerName as TlsServerName}; @@ -223,6 +227,42 @@ impl Prover { }; Ok(prover) } + + /// Connects the prover and attaches a socket. + /// + /// This is a convenience function which returns + /// - [`TlsConnection`] for reading and writing traffic. + /// - [`ConnectionFuture`] which has to be polled for driving the + /// connection forward. + /// + /// # Arguments + /// + /// * `config` - The TLS client configuration. + /// * `socket` - The socket for IO. + #[instrument(parent = &self.span, level = "debug", skip_all, err)] + pub async fn connect_with( + self, + config: TlsClientConfig, + socket: S, + ) -> Result<(TlsConnection, ConnectionFuture), ProverError> + where + S: AsyncRead + AsyncWrite + Send, + { + let prover = self.connect(config).await?; + + let prover = Arc::new(Mutex::new(prover)); + let conn_waker = Arc::new(Mutex::new(None)); + let fut_waker = Arc::new(Mutex::new(None)); + + let conn = TlsConnection::new( + Arc::downgrade(&prover), + conn_waker.clone(), + fut_waker.clone(), + ); + let fut = ConnectionFuture::new(socket, prover, conn_waker, fut_waker); + + Ok((conn, fut)) + } } impl Prover { @@ -317,6 +357,7 @@ impl Prover { match self.state.tls_client.poll(cx)? { Poll::Ready(output) => { + let _ = self.state.mux_fut.poll_unpin(cx)?; self.state.output = Some(output); Poll::Ready(Ok(())) } diff --git a/crates/tlsn/src/prover/client/mpc.rs b/crates/tlsn/src/prover/client/mpc.rs index cf65c06a33..6230fce787 100644 --- a/crates/tlsn/src/prover/client/mpc.rs +++ b/crates/tlsn/src/prover/client/mpc.rs @@ -12,21 +12,30 @@ use futures::{Future, FutureExt}; use mpc_tls::{LeaderCtrl, SessionKeys}; use mpz_common::Context; use mpz_vm_core::Execute; -use std::{pin::Pin, sync::Arc, task::Poll}; +use std::{collections::VecDeque, pin::Pin, sync::Arc, task::Poll}; use tls_client::ClientConnection; use tlsn_core::transcript::TlsTranscript; use tlsn_deap::Deap; use tokio::sync::Mutex; use tracing::{Span, debug, instrument, trace, warn}; -pub(crate) type MpcFuture = Box>>; +pub(crate) type MpcFuture = + Box> + Send>; type FinalizeFuture = - Box>>; + Box> + Send>; pub(crate) struct MpcTlsClient { state: State, decrypt: bool, + cmds: VecDeque, +} + +#[derive(Debug, Clone, Copy)] +pub(crate) enum Command { + ClientClose, + ServerClose, + Decrypt(bool), } enum State { @@ -40,20 +49,20 @@ enum State { }, Busy { mpc: Pin, - fut: Pin, ProverError>>>>, + fut: Pin, ProverError>> + Send>>, }, - ClientClose { + MpcStop { mpc: Pin, - fut: Pin, ProverError>>>>, + inner: Box, }, - ServerClose { + CloseBusy { mpc: Pin, - fut: Pin, ProverError>>>>, + fut: Pin, ProverError>> + Send>>, }, - Closing { + Finishing { ctx: Context, transcript: Box, - fut: Pin, ProverError>>>>, + fut: Pin, ProverError>> + Send>>, }, Finalizing { fut: Pin, @@ -78,7 +87,7 @@ impl MpcTlsClient { vm, keys, mpc_ctrl, - closed: false, + mpc_stopped: false, }; Self { @@ -87,11 +96,12 @@ impl MpcTlsClient { mpc: Box::into_pin(mpc), inner: Box::new(inner), }, + cmds: VecDeque::default(), } } fn inner_client_mut(&mut self) -> Option<&mut ClientConnection> { - if let State::Active { inner, .. } = &mut self.state { + if let State::Active { inner, .. } | State::MpcStop { inner, .. } = &mut self.state { Some(&mut inner.tls) } else { None @@ -99,7 +109,7 @@ impl MpcTlsClient { } fn inner_client(&self) -> Option<&ClientConnection> { - if let State::Active { inner, .. } = &self.state { + if let State::Active { inner, .. } | State::MpcStop { inner, .. } = &self.state { Some(&inner.tls) } else { None @@ -148,7 +158,7 @@ impl TlsClient for MpcTlsClient { fn wants_read(&self) -> bool { if let Some(client) = self.inner_client() { - !client.sendable_plaintext_is_full() + !client.plaintext_is_empty() } else { false } @@ -156,7 +166,7 @@ impl TlsClient for MpcTlsClient { fn wants_write(&self) -> bool { if let Some(client) = self.inner_client() { - !client.plaintext_is_empty() + !client.sendable_plaintext_is_full() } else { false } @@ -164,7 +174,7 @@ impl TlsClient for MpcTlsClient { fn read(&mut self, buf: &mut [u8]) -> Result { if let Some(client) = self.inner_client_mut() - && !client.sendable_plaintext_is_full() + && !client.plaintext_is_empty() { client.read_plaintext(buf).map_err(ProverError::from) } else { @@ -174,7 +184,7 @@ impl TlsClient for MpcTlsClient { fn write(&mut self, buf: &[u8]) -> Result { if let Some(client) = self.inner_client_mut() - && !client.plaintext_is_empty() + && !client.sendable_plaintext_is_full() { client.write_plaintext(buf).map_err(ProverError::from) } else { @@ -183,58 +193,18 @@ impl TlsClient for MpcTlsClient { } fn client_close(&mut self) -> Result<(), Self::Error> { - match std::mem::replace(&mut self.state, State::Error) { - State::Active { inner, mpc } => { - self.state = State::ClientClose { - mpc, - fut: Box::pin(inner.client_close()), - }; - Ok(()) - } - other => { - self.state = other; - Err(ProverError::state( - "unable to close connection, client is not in active state", - )) - } - } + self.cmds.push_back(Command::ClientClose); + Ok(()) } fn server_close(&mut self) -> Result<(), Self::Error> { - match std::mem::replace(&mut self.state, State::Error) { - State::Active { inner, mpc } => { - self.state = State::ServerClose { - mpc, - fut: Box::pin(inner.server_close()), - }; - Ok(()) - } - other => { - self.state = other; - Err(ProverError::state( - "unable to close connection, client is not in active state", - )) - } - } + self.cmds.push_back(Command::ServerClose); + Ok(()) } fn enable_decryption(&mut self, enable: bool) -> Result<(), Self::Error> { - match std::mem::replace(&mut self.state, State::Error) { - State::Active { inner, mpc } => { - self.decrypt = enable; - self.state = State::Busy { - mpc, - fut: Box::pin(inner.set_decrypt(enable)), - }; - Ok(()) - } - other => { - self.state = other; - Err(ProverError::state( - "unable to enable decryption, client is not in active state", - )) - } - } + self.cmds.push_back(Command::Decrypt(enable)); + Ok(()) } fn is_decrypting(&self) -> bool { @@ -244,6 +214,7 @@ impl TlsClient for MpcTlsClient { fn poll(&mut self, cx: &mut std::task::Context) -> Poll> { match std::mem::replace(&mut self.state, State::Error) { State::Start { mpc, inner } => { + trace!("inner client is starting"); self.state = State::Busy { mpc, fut: Box::pin(inner.start()), @@ -253,10 +224,36 @@ impl TlsClient for MpcTlsClient { State::Active { mpc, inner } => { trace!("inner client is active"); - self.state = State::Busy { - mpc, - fut: Box::pin(inner.run()), - }; + if !inner.tls.is_handshaking() + && let Some(cmd) = self.cmds.pop_front() + { + match cmd { + Command::ClientClose => { + self.state = State::Busy { + mpc, + fut: Box::pin(inner.client_close()), + }; + } + Command::ServerClose => { + self.state = State::CloseBusy { + mpc, + fut: Box::pin(inner.server_close()), + }; + } + Command::Decrypt(enable) => { + self.decrypt = enable; + self.state = State::Busy { + mpc, + fut: Box::pin(inner.set_decrypt(enable)), + }; + } + } + } else { + self.state = State::Busy { + mpc, + fut: Box::pin(inner.run()), + }; + } self.poll(cx) } State::Busy { mut mpc, mut fut } => { @@ -277,73 +274,60 @@ impl TlsClient for MpcTlsClient { } Poll::Pending } - State::ClientClose { mut mpc, mut fut } => { - debug!("attempting to close connection clientside"); - match (fut.poll_unpin(cx)?, mpc.poll_unpin(cx)?) { - (Poll::Ready(inner), Poll::Ready((ctx, transcript))) => { - self.state = State::Finalizing { - fut: Box::pin(inner.finalize(ctx, transcript)), - }; - } - (Poll::Ready(inner), Poll::Pending) => { - self.state = State::ClientClose { - mpc, - fut: Box::pin(inner.client_close()), - }; - } - (Poll::Pending, Poll::Ready((ctx, transcript))) => { - self.state = State::Closing { - ctx, - transcript: Box::new(transcript), - fut, - }; - } - (Poll::Pending, Poll::Pending) => self.state = State::ClientClose { mpc, fut }, - } + State::MpcStop { mpc, inner } => { + trace!("inner client is stopping mpc"); + self.state = State::CloseBusy { + mpc, + fut: Box::pin(inner.stop()), + }; self.poll(cx) } - State::ServerClose { mut mpc, mut fut } => { - debug!("attempting to close connection serverside"); + State::CloseBusy { mut mpc, mut fut } => { + trace!("inner client is busy closing"); match (fut.poll_unpin(cx)?, mpc.poll_unpin(cx)?) { (Poll::Ready(inner), Poll::Ready((ctx, transcript))) => { self.state = State::Finalizing { fut: Box::pin(inner.finalize(ctx, transcript)), }; + self.poll(cx) } (Poll::Ready(inner), Poll::Pending) => { - self.state = State::ServerClose { - mpc, - fut: Box::pin(inner.server_close()), - }; + self.state = State::MpcStop { mpc, inner }; + Poll::Pending } (Poll::Pending, Poll::Ready((ctx, transcript))) => { - self.state = State::Closing { + self.state = State::Finishing { ctx, transcript: Box::new(transcript), fut, }; + Poll::Pending + } + (Poll::Pending, Poll::Pending) => { + self.state = State::CloseBusy { mpc, fut }; + Poll::Pending } - (Poll::Pending, Poll::Pending) => self.state = State::ServerClose { mpc, fut }, } - self.poll(cx) } - State::Closing { + State::Finishing { ctx, transcript, mut fut, } => { + trace!("inner client is finishing"); if let Poll::Ready(inner) = fut.poll_unpin(cx)? { self.state = State::Finalizing { fut: Box::pin(inner.finalize(ctx, *transcript)), }; + self.poll(cx) } else { - self.state = State::Closing { + self.state = State::Finishing { ctx, transcript, fut, }; + Poll::Pending } - self.poll(cx) } State::Finalizing { mut fut } => match fut.poll_unpin(cx) { Poll::Ready(output) => { @@ -372,7 +356,7 @@ impl TlsClient for MpcTlsClient { } Poll::Pending => { self.state = State::Finalizing { fut }; - self.poll(cx) + Poll::Pending } }, State::Finished => Poll::Ready(Err(ProverError::state( @@ -391,7 +375,7 @@ struct InnerState { vm: Arc>>, keys: SessionKeys, mpc_ctrl: LeaderCtrl, - closed: bool, + mpc_stopped: bool, } impl InnerState { @@ -415,28 +399,32 @@ impl InnerState { #[instrument(parent = &self.span, level = "debug", skip_all, err)] async fn client_close(mut self: Box) -> Result, ProverError> { - if self.tls.plaintext_is_empty() && self.tls.is_empty().await? && !self.closed { - if let Err(e) = self.tls.send_close_notify().await { - warn!("failed to send close_notify to server: {}", e); - }; - - self.mpc_ctrl.stop().await?; - self.closed = true; - debug!("closed connection"); + debug!("sending close notify"); + if let Err(e) = self.tls.send_close_notify().await { + warn!("failed to send close_notify to server: {}", e); } - self.run().await + Ok(self) } #[instrument(parent = &self.span, level = "debug", skip_all, err)] async fn server_close(mut self: Box) -> Result, ProverError> { - if self.tls.plaintext_is_empty() && self.tls.is_empty().await? && !self.closed { - self.tls.server_closed().await?; + self.tls.process_new_packets().await?; + self.tls.server_closed().await?; + debug!("closed connection serverside"); + Ok(self) + } + + #[instrument(parent = &self.span, level = "debug", skip_all, err)] + async fn stop(mut self: Box) -> Result, ProverError> { + self.tls.process_new_packets().await?; + if !self.mpc_stopped && self.tls.plaintext_is_empty() && self.tls.is_empty().await? { self.mpc_ctrl.stop().await?; - self.closed = true; - debug!("closed connection"); + self.mpc_stopped = true; + debug!("stopped mpc"); } - self.run().await + + Ok(self) } #[instrument(parent = &self.span, level = "debug", skip_all, err)] diff --git a/crates/tlsn/src/prover/conn.rs b/crates/tlsn/src/prover/conn.rs new file mode 100644 index 0000000000..708c615d27 --- /dev/null +++ b/crates/tlsn/src/prover/conn.rs @@ -0,0 +1,295 @@ +use crate::prover::{ + Prover, ProverError, conn::buffer::SimpleBuffer, control::ProverControl, state, +}; +use futures::{AsyncRead, AsyncWrite}; +use std::{ + pin::Pin, + sync::{Arc, Mutex, MutexGuard, Weak}, + task::{Context, Poll, Waker}, +}; + +mod buffer; + +const BUF_CAP: usize = 8 * 1024; + +/// A TLS connection to a server. +/// +/// This type implements [`AsyncRead`] and [`AsyncWrite`] and can be used to +/// communicate with a server using TLS. +/// +/// # Note +/// +/// This connection is closed on a best-effort basis if this is dropped. To +/// ensure a clean close, you should call +/// [`AsyncWriteExt::close`](futures::io::AsyncWriteExt::close) to close the +/// connection. +pub struct TlsConnection { + prover: Weak>>, + conn_waker: Arc>>, + fut_waker: Arc>>, + closed: bool, +} + +impl TlsConnection { + pub(crate) fn new( + prover: Weak>>, + conn_waker: Arc>>, + fut_waker: Arc>>, + ) -> Self { + Self { + prover, + conn_waker, + fut_waker, + closed: false, + } + } + + fn conn_waker(&self) -> MutexGuard<'_, Option> { + self.conn_waker + .lock() + .expect("should be able to acquire lock for waker") + } + + fn fut_waker(&self) -> MutexGuard<'_, Option> { + self.fut_waker + .lock() + .expect("should be able to acquire lock for waker") + } +} + +impl Drop for TlsConnection { + fn drop(&mut self) { + if !self.closed + && let Some(prover) = self.prover.upgrade() + { + let mut prover = prover + .lock() + .expect("should be able to acquire lock for prover"); + prover + .client_close() + .expect("should be able to close connection clientside"); + + if let Some(waker) = self.fut_waker().as_ref() { + waker.wake_by_ref(); + } + self.closed = true; + } + } +} + +impl AsyncRead for TlsConnection { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } + + if let Some(prover) = self.prover.upgrade() { + let mut prover = prover + .lock() + .expect("should be able to acquire lock for prover"); + + let read = prover.read(buf)?; + + if read != 0 { + if let Some(waker) = self.fut_waker().as_ref() { + waker.wake_by_ref(); + } + Poll::Ready(Ok(read)) + } else if self.closed { + Poll::Ready(Ok(0)) + } else { + *self.conn_waker() = Some(cx.waker().clone()); + Poll::Pending + } + } else { + Poll::Ready(Ok(0)) + } + } +} + +impl AsyncWrite for TlsConnection { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } + + if let Some(prover) = self.prover.upgrade() { + let mut prover = prover + .lock() + .expect("should be able to acquire lock for prover"); + + let write = prover.write(buf)?; + if write != 0 { + if let Some(waker) = self.fut_waker().as_ref() { + waker.wake_by_ref(); + } + Poll::Ready(Ok(write)) + } else { + *self.conn_waker() = Some(cx.waker().clone()); + Poll::Pending + } + } else { + Poll::Ready(Ok(0)) + } + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if !self.closed + && let Some(prover) = self.prover.upgrade() + { + *self.conn_waker() = Some(cx.waker().clone()); + let mut prover = prover + .lock() + .expect("should be able to acquire lock for prover"); + prover.client_close()?; + + if let Some(waker) = self.fut_waker().as_ref() { + waker.wake_by_ref(); + } + + self.closed = true; + return Poll::Pending; + } + Poll::Ready(Ok(())) + } +} + +pin_project_lite::pin_project! { + /// A future to drive the connection. Must be polled to make progress. + pub struct ConnectionFuture { + #[pin] + socket: S, + prover: Option>>>, + conn_waker: Arc>>, + fut_waker: Arc>>, + read_buf: SimpleBuffer, + write_buf: SimpleBuffer, + } +} + +impl ConnectionFuture { + pub(crate) fn new( + socket: S, + prover: Arc>>, + conn_waker: Arc>>, + fut_waker: Arc>>, + ) -> Self { + Self { + socket, + prover: Some(prover), + conn_waker, + fut_waker, + read_buf: SimpleBuffer::default(), + write_buf: SimpleBuffer::default(), + } + } + + /// Returns a handle to control the prover. + pub fn handle(&self) -> Option { + if let Some(prover) = &self.prover { + let ctrl = ProverControl { + prover: Arc::downgrade(prover), + }; + Some(ctrl) + } else { + None + } + } +} + +impl Future for ConnectionFuture +where + S: AsyncRead + AsyncWrite + Send, +{ + type Output = Result, ProverError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + *self + .fut_waker + .lock() + .expect("should be able to acquire lock for waker") = Some(cx.waker().clone()); + + let mut this = self.project(); + let mut prover = this + .prover + .as_mut() + .expect("prover should be available") + .lock() + .expect("should be able to acquire lock for prover"); + + // read from socket into client + let mut tmp_read_buf = [0_u8; BUF_CAP]; + + if let Poll::Ready(read) = this.socket.as_mut().poll_read(cx, &mut tmp_read_buf)? { + if read > 0 { + this.read_buf.extend(&tmp_read_buf[..read]); + } else { + prover.server_close()?; + } + } + + if this.read_buf.len() > 0 { + let read = prover.read_tls(this.read_buf.inner())?; + this.read_buf.consume(read); + } + + // write from client into socket + let mut tmp_write_buf = [0_u8; BUF_CAP]; + let write = prover.write_tls(&mut tmp_write_buf)?; + + if write > 0 { + this.write_buf.extend(&tmp_write_buf[..write]); + } + + if this.write_buf.len() > 0 + && let Poll::Ready(write) = this + .socket + .as_mut() + .poll_write(cx, this.write_buf.inner())? + { + this.write_buf.consume(write); + let _ = this.socket.as_mut().poll_flush(cx)?; + } + + // poll prover + if let Poll::Ready(()) = prover.poll(cx)? { + std::mem::drop(prover); + + let mut prover = this.prover.take().expect("prover should be available"); + let prover = loop { + std::hint::spin_loop(); + + match Arc::try_unwrap(prover) { + Ok(prover) => break prover, + Err(arc_prover) => prover = arc_prover, + } + }; + + let prover = Mutex::into_inner(prover).expect("prover should be available"); + return Poll::Ready(prover.finish()); + } + + if let Some(waker) = this + .conn_waker + .lock() + .expect("should be able to acquire lock for waker") + .as_ref() + { + waker.wake_by_ref(); + } + + Poll::Pending + } +} diff --git a/crates/tlsn/src/prover/conn/buffer.rs b/crates/tlsn/src/prover/conn/buffer.rs new file mode 100644 index 0000000000..efa8fe22cb --- /dev/null +++ b/crates/tlsn/src/prover/conn/buffer.rs @@ -0,0 +1,47 @@ +//! Simple buffer implementation. + +use crate::prover::conn::BUF_CAP; +use bytes::{Buf, BufMut, BytesMut}; + +pub(crate) struct SimpleBuffer { + buf: BytesMut, +} + +impl Default for SimpleBuffer { + fn default() -> Self { + Self { + buf: BytesMut::with_capacity(BUF_CAP), + } + } +} + +impl SimpleBuffer { + /// Returns the underlying parts of the buffer which has not yet been + /// consumed. + pub(crate) fn inner(&self) -> &[u8] { + &self.buf + } + + /// Marks bytes as consumed. + /// + /// # Arguments + /// + /// * `n` - How many bytes to mark consumed. + pub(crate) fn consume(&mut self, n: usize) { + self.buf.advance(n); + } + + /// Appends bytes to the end of the buffer. + /// + /// # Arguments + /// + /// * `bytes` - The byte slice to append. + pub(crate) fn extend(&mut self, bytes: &[u8]) { + self.buf.put_slice(bytes); + } + + /// Returns the number of consumable bytes. + pub(crate) fn len(&self) -> usize { + self.buf.len() + } +} diff --git a/crates/tlsn/src/prover/control.rs b/crates/tlsn/src/prover/control.rs new file mode 100644 index 0000000000..f2f1fffdbb --- /dev/null +++ b/crates/tlsn/src/prover/control.rs @@ -0,0 +1,37 @@ +use std::sync::{Mutex, Weak}; + +use crate::prover::{Prover, ProverError, state}; + +/// A controller for the prover. +#[derive(Clone)] +pub struct ProverControl { + pub(crate) prover: Weak>>, +} + +impl ProverControl { + /// Returns whether the prover is decrypting the server traffic. + pub fn is_decrypting(&self) -> bool { + let Some(prover) = self.prover.upgrade() else { + return false; + }; + let prover = prover + .lock() + .expect("should be able to acquire lock for prover"); + prover.is_decrypting() + } + + /// Enables or disables the decryption of server traffic. + /// + /// # Arguments + /// + /// * `enable` - If decryption should be enabled or disabled. + pub fn enable_decryption(&self, enable: bool) -> Result<(), ProverError> { + let Some(prover) = self.prover.upgrade() else { + return Err(ProverError::state("prover not available anymore")); + }; + let mut prover = prover + .lock() + .expect("should be able to acquire lock for prover"); + prover.enable_decryption(enable) + } +} diff --git a/crates/tlsn/src/prover/error.rs b/crates/tlsn/src/prover/error.rs index 701e826583..d0d3007c8b 100644 --- a/crates/tlsn/src/prover/error.rs +++ b/crates/tlsn/src/prover/error.rs @@ -124,3 +124,9 @@ impl From for ProverError { Self::new(ErrorKind::Commit, e) } } + +impl From for std::io::Error { + fn from(value: ProverError) -> Self { + Self::other(value) + } +} diff --git a/crates/tlsn/src/prover/state.rs b/crates/tlsn/src/prover/state.rs index 83143aebc5..2a43dd8210 100644 --- a/crates/tlsn/src/prover/state.rs +++ b/crates/tlsn/src/prover/state.rs @@ -42,7 +42,7 @@ pub struct Connected { pub(crate) mux_ctrl: MuxControl, pub(crate) mux_fut: MuxFuture, pub(crate) server_name: ServerName, - pub(crate) tls_client: Box>, + pub(crate) tls_client: Box + Send>, pub(crate) output: Option, } diff --git a/crates/tlsn/tests/test.rs b/crates/tlsn/tests/test.rs index 27d597e016..6028d4600f 100644 --- a/crates/tlsn/tests/test.rs +++ b/crates/tlsn/tests/test.rs @@ -140,7 +140,7 @@ async fn prover( .unwrap(); let (mut tls_connection, prover_fut) = prover - .connect( + .connect_with( TlsClientConfig::builder() .server_name(ServerName::Dns(SERVER_DOMAIN.try_into().unwrap())) .root_store(RootCertStore { @@ -158,10 +158,10 @@ async fn prover( .write_all(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n") .await .unwrap(); - tls_connection.close().await.unwrap(); let mut response = vec![0u8; 1024]; tls_connection.read_to_end(&mut response).await.unwrap(); + tls_connection.close().await.unwrap(); let _ = server_task.await.unwrap(); diff --git a/crates/wasm/Cargo.toml b/crates/wasm/Cargo.toml index 3393444c35..ace11d9c2a 100644 --- a/crates/wasm/Cargo.toml +++ b/crates/wasm/Cargo.toml @@ -23,7 +23,6 @@ no-bundler = ["web-spawn/no-bundler"] tlsn-core = { workspace = true } tlsn = { workspace = true, features = ["web", "mozilla-certs"] } tlsn-server-fixture-certs = { workspace = true } -tlsn-tls-client-async = { workspace = true } tlsn-tls-core = { workspace = true } bincode = { workspace = true } diff --git a/crates/wasm/src/prover/mod.rs b/crates/wasm/src/prover/mod.rs index d14276dc07..87a9ceffea 100644 --- a/crates/wasm/src/prover/mod.rs +++ b/crates/wasm/src/prover/mod.rs @@ -6,7 +6,6 @@ use enum_try_as_inner::EnumTryAsInner; use futures::TryFutureExt; use http_body_util::{BodyExt, Full}; use hyper::body::Bytes; -use tls_client_async::TlsConnection; use tlsn::{ config::{ prove::ProveConfig, @@ -14,7 +13,7 @@ use tlsn::{ tls_commit::{mpc::MpcTlsConfig, TlsCommitConfig}, }, connection::ServerName, - prover::{state, Prover}, + prover::{state, Prover, TlsConnection}, webpki::{CertificateDer, PrivateKeyDer, RootCertStore}, }; use tracing::info; @@ -148,7 +147,7 @@ impl JsProver { info!("connected to server"); - let (tls_conn, prover_fut) = prover.connect(config, server_conn.into_io()).await?; + let (tls_conn, prover_fut) = prover.connect_with(config, server_conn.into_io()).await?; info!("sending request");