Skip to content

Commit 9d94d6d

Browse files
committed
feat: add new socket api
1 parent b6b0922 commit 9d94d6d

File tree

3 files changed

+159
-1
lines changed

3 files changed

+159
-1
lines changed

crates/tlsn/src/prover.rs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
//! Prover.
22
33
mod client;
4+
mod conn;
5+
mod control;
46
mod error;
57
mod prove;
68
pub mod state;
79

10+
pub use conn::{ConnectionFuture, TlsConnection};
11+
pub use control::ProverControl;
812
pub use error::ProverError;
913
pub use tlsn_core::ProverOutput;
1014

@@ -21,7 +25,7 @@ use futures::{AsyncRead, AsyncWrite, FutureExt, TryFutureExt};
2125
use rustls_pki_types::CertificateDer;
2226
use serio::{SinkExt, stream::IoStreamExt};
2327
use std::{
24-
sync::Arc,
28+
sync::{Arc, Mutex},
2529
task::{Context, Poll},
2630
};
2731
use tls_client::{ClientConnection, ServerName as TlsServerName};
@@ -223,6 +227,35 @@ impl Prover<state::CommitAccepted> {
223227
};
224228
Ok(prover)
225229
}
230+
231+
/// Connects the prover and attaches a socket.
232+
///
233+
/// This is a convenience function which returns
234+
/// - [`TlsConnection`] for reading and writing traffic as well as other
235+
/// connection-specific settings.
236+
/// - [`ConnectionFuture`] which has to be polled for driving the
237+
/// connection forward.
238+
///
239+
/// # Arguments
240+
///
241+
/// * `config` - The TLS client configuration.
242+
#[instrument(parent = &self.span, level = "debug", skip_all, err)]
243+
pub async fn connect_with<S>(
244+
self,
245+
config: TlsClientConfig,
246+
socket: S,
247+
) -> Result<(TlsConnection, ConnectionFuture<S>), ProverError>
248+
where
249+
S: AsyncRead + AsyncWrite + Send,
250+
{
251+
let prover = self.connect(config).await?;
252+
let prover = Arc::new(Mutex::new(prover));
253+
254+
let conn = TlsConnection::new(prover.clone());
255+
let fut = ConnectionFuture::new(socket, prover.clone());
256+
257+
Ok((conn, fut))
258+
}
226259
}
227260

228261
impl Prover<state::Connected> {

crates/tlsn/src/prover/conn.rs

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
use std::sync::{Arc, Mutex};
2+
3+
use crate::prover::{Prover, ProverError, control::ProverControl, state};
4+
use futures::{AsyncRead, AsyncWrite};
5+
6+
pub(crate) type ProverConnected = Arc<Mutex<Prover<state::Connected>>>;
7+
8+
/// A TLS connection to a server.
9+
///
10+
/// This type implements [`AsyncRead`] and [`AsyncWrite`] and can be used to
11+
/// communicate with a server using TLS.
12+
///
13+
/// # Note
14+
///
15+
/// This connection is closed on a best-effort basis if this is dropped. To
16+
/// ensure a clean close, you should call
17+
/// [`AsyncWriteExt::close`](futures::io::AsyncWriteExt::close) to close the
18+
/// connection.
19+
pub struct TlsConnection {
20+
prover: ProverConnected,
21+
}
22+
23+
impl TlsConnection {
24+
pub(crate) fn new(prover: ProverConnected) -> Self {
25+
Self { prover }
26+
}
27+
}
28+
29+
impl AsyncRead for TlsConnection {
30+
fn poll_read(
31+
self: std::pin::Pin<&mut Self>,
32+
cx: &mut std::task::Context<'_>,
33+
buf: &mut [u8],
34+
) -> std::task::Poll<std::io::Result<usize>> {
35+
todo!()
36+
}
37+
}
38+
39+
impl AsyncWrite for TlsConnection {
40+
fn poll_write(
41+
self: std::pin::Pin<&mut Self>,
42+
cx: &mut std::task::Context<'_>,
43+
buf: &[u8],
44+
) -> std::task::Poll<std::io::Result<usize>> {
45+
todo!()
46+
}
47+
48+
fn poll_flush(
49+
self: std::pin::Pin<&mut Self>,
50+
cx: &mut std::task::Context<'_>,
51+
) -> std::task::Poll<std::io::Result<()>> {
52+
todo!()
53+
}
54+
55+
fn poll_close(
56+
self: std::pin::Pin<&mut Self>,
57+
cx: &mut std::task::Context<'_>,
58+
) -> std::task::Poll<std::io::Result<()>> {
59+
todo!()
60+
}
61+
}
62+
63+
/// A future to drive the connection. Must be polled to make progress.
64+
pub struct ConnectionFuture<S> {
65+
socket: S,
66+
prover: ProverConnected,
67+
}
68+
69+
impl<S> ConnectionFuture<S> {
70+
pub(crate) fn new(socket: S, prover: ProverConnected) -> Self {
71+
Self { socket, prover }
72+
}
73+
74+
/// Returns a handle to control the prover.
75+
pub fn handle(&self) -> ProverControl {
76+
ProverControl {
77+
prover: self.prover.clone(),
78+
}
79+
}
80+
}
81+
82+
impl<S> Future for ConnectionFuture<S>
83+
where
84+
S: AsyncRead + AsyncWrite + Send,
85+
{
86+
type Output = Result<Prover<state::Committed>, ProverError>;
87+
88+
fn poll(
89+
self: std::pin::Pin<&mut Self>,
90+
cx: &mut std::task::Context<'_>,
91+
) -> std::task::Poll<Self::Output> {
92+
todo!()
93+
}
94+
}

crates/tlsn/src/prover/control.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
use crate::prover::{ProverError, conn::ProverConnected};
2+
3+
/// A controller for the prover.
4+
#[derive(Clone)]
5+
pub struct ProverControl {
6+
pub(crate) prover: ProverConnected,
7+
}
8+
9+
impl ProverControl {
10+
/// Returns whether the prover is decrypting the server traffic.
11+
pub fn is_decrypting(&self) -> bool {
12+
let prover = self
13+
.prover
14+
.lock()
15+
.expect("should be able to acquire prover handle");
16+
prover.is_decrypting()
17+
}
18+
19+
/// Enables or disables the decryption of server traffic.
20+
///
21+
/// # Arguments
22+
///
23+
/// * `enable` - If decryption should be enabled or disabled.
24+
pub fn enable_decryption(&self, enable: bool) -> Result<(), ProverError> {
25+
let mut prover = self
26+
.prover
27+
.lock()
28+
.expect("should be able to acquire prover handle");
29+
prover.enable_decryption(enable)
30+
}
31+
}

0 commit comments

Comments
 (0)