11use std:: io:: { Error , ErrorKind , IoSlice , Result } ;
22use std:: pin:: Pin ;
33use std:: ptr;
4- use std:: sync:: Arc ;
54use std:: task:: { Context , Poll , RawWaker , RawWakerVTable , Waker } ;
65use std:: time:: Duration ;
76
87use bytes:: buf:: BufMut ;
98use ignore_result:: Ignore ;
10- use rustls:: pki_types:: ServerName ;
11- use rustls:: ClientConfig ;
129use tokio:: io:: { AsyncBufReadExt , AsyncRead , AsyncWrite , AsyncWriteExt , BufStream , ReadBuf } ;
1310use tokio:: net:: TcpStream ;
1411use tokio:: { select, time} ;
15- use tokio_rustls:: client:: TlsStream ;
16- use tokio_rustls:: TlsConnector ;
1712use tracing:: { debug, trace} ;
1813
14+ #[ cfg( feature = "tls" ) ]
15+ mod tls {
16+ pub use std:: sync:: Arc ;
17+
18+ pub use rustls:: pki_types:: ServerName ;
19+ pub use rustls:: ClientConfig ;
20+ pub use tokio_rustls:: client:: TlsStream ;
21+ pub use tokio_rustls:: TlsConnector ;
22+ }
23+ #[ cfg( feature = "tls" ) ]
24+ use tls:: * ;
25+
1926use crate :: deadline:: Deadline ;
2027use crate :: endpoint:: { EndpointRef , IterableEndpoints } ;
2128
2229const NOOP_VTABLE : RawWakerVTable =
2330 RawWakerVTable :: new ( |_| RawWaker :: new ( ptr:: null ( ) , & NOOP_VTABLE ) , |_| { } , |_| { } , |_| { } ) ;
2431const NOOP_WAKER : RawWaker = RawWaker :: new ( ptr:: null ( ) , & NOOP_VTABLE ) ;
2532
33+ #[ derive( Debug ) ]
2634pub enum Connection {
27- Tls ( TlsStream < TcpStream > ) ,
2835 Raw ( TcpStream ) ,
36+ #[ cfg( feature = "tls" ) ]
37+ Tls ( TlsStream < TcpStream > ) ,
2938}
3039
3140impl AsyncRead for Connection {
3241 fn poll_read ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , buf : & mut ReadBuf < ' _ > ) -> Poll < Result < ( ) > > {
3342 match self . get_mut ( ) {
3443 Self :: Raw ( stream) => Pin :: new ( stream) . poll_read ( cx, buf) ,
44+ #[ cfg( feature = "tls" ) ]
3545 Self :: Tls ( stream) => Pin :: new ( stream) . poll_read ( cx, buf) ,
3646 }
3747 }
@@ -41,20 +51,23 @@ impl AsyncWrite for Connection {
4151 fn poll_write ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , buf : & [ u8 ] ) -> Poll < Result < usize > > {
4252 match self . get_mut ( ) {
4353 Self :: Raw ( stream) => Pin :: new ( stream) . poll_write ( cx, buf) ,
54+ #[ cfg( feature = "tls" ) ]
4455 Self :: Tls ( stream) => Pin :: new ( stream) . poll_write ( cx, buf) ,
4556 }
4657 }
4758
4859 fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) > > {
4960 match self . get_mut ( ) {
5061 Self :: Raw ( stream) => Pin :: new ( stream) . poll_flush ( cx) ,
62+ #[ cfg( feature = "tls" ) ]
5163 Self :: Tls ( stream) => Pin :: new ( stream) . poll_flush ( cx) ,
5264 }
5365 }
5466
5567 fn poll_shutdown ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) > > {
5668 match self . get_mut ( ) {
5769 Self :: Raw ( stream) => Pin :: new ( stream) . poll_shutdown ( cx) ,
70+ #[ cfg( feature = "tls" ) ]
5871 Self :: Tls ( stream) => Pin :: new ( stream) . poll_shutdown ( cx) ,
5972 }
6073 }
@@ -65,6 +78,7 @@ impl Connection {
6578 Self :: Raw ( stream)
6679 }
6780
81+ #[ cfg( feature = "tls" ) ]
6882 pub fn new_tls ( stream : TlsStream < TcpStream > ) -> Self {
6983 Self :: Tls ( stream)
7084 }
@@ -97,6 +111,7 @@ impl Connection {
97111 pub async fn readable ( & self ) -> Result < ( ) > {
98112 match self {
99113 Self :: Raw ( stream) => stream. readable ( ) . await ,
114+ #[ cfg( feature = "tls" ) ]
100115 Self :: Tls ( stream) => {
101116 let ( stream, session) = stream. get_ref ( ) ;
102117 if session. wants_read ( ) {
@@ -112,6 +127,7 @@ impl Connection {
112127 pub async fn writable ( & self ) -> Result < ( ) > {
113128 match self {
114129 Self :: Raw ( stream) => stream. writable ( ) . await ,
130+ #[ cfg( feature = "tls" ) ]
115131 Self :: Tls ( stream) => {
116132 let ( stream, _session) = stream. get_ref ( ) ;
117133 stream. writable ( ) . await
@@ -122,6 +138,7 @@ impl Connection {
122138 pub fn wants_write ( & self ) -> bool {
123139 match self {
124140 Self :: Raw ( _) => false ,
141+ #[ cfg( feature = "tls" ) ]
125142 Self :: Tls ( stream) => {
126143 let ( _stream, session) = stream. get_ref ( ) ;
127144 session. wants_write ( )
@@ -160,13 +177,33 @@ impl Connection {
160177
161178#[ derive( Clone ) ]
162179pub struct Connector {
163- tls : TlsConnector ,
180+ #[ cfg( feature = "tls" ) ]
181+ tls : Option < TlsConnector > ,
164182 timeout : Duration ,
165183}
166184
167185impl Connector {
168- pub fn new ( config : impl Into < Arc < ClientConfig > > ) -> Self {
169- Self { tls : TlsConnector :: from ( config. into ( ) ) , timeout : Duration :: from_secs ( 10 ) }
186+ #[ cfg( feature = "tls" ) ]
187+ #[ allow( dead_code) ]
188+ pub fn new ( ) -> Self {
189+ Self { tls : None , timeout : Duration :: from_secs ( 10 ) }
190+ }
191+
192+ #[ cfg( not( feature = "tls" ) ) ]
193+ pub fn new ( ) -> Self {
194+ Self { timeout : Duration :: from_secs ( 10 ) }
195+ }
196+
197+ #[ cfg( feature = "tls" ) ]
198+ pub fn with_tls ( config : ClientConfig ) -> Self {
199+ Self { tls : Some ( TlsConnector :: from ( Arc :: new ( config) ) ) , timeout : Duration :: from_secs ( 10 ) }
200+ }
201+
202+ #[ cfg( feature = "tls" ) ]
203+ async fn connect_tls ( & self , stream : TcpStream , host : & str ) -> Result < Connection > {
204+ let domain = ServerName :: try_from ( host) . unwrap ( ) . to_owned ( ) ;
205+ let stream = self . tls . as_ref ( ) . unwrap ( ) . connect ( domain, stream) . await ?;
206+ Ok ( Connection :: new_tls ( stream) )
170207 }
171208
172209 pub fn timeout ( & self ) -> Duration {
@@ -178,6 +215,14 @@ impl Connector {
178215 }
179216
180217 pub async fn connect ( & self , endpoint : EndpointRef < ' _ > , deadline : & mut Deadline ) -> Result < Connection > {
218+ if endpoint. tls {
219+ #[ cfg( feature = "tls" ) ]
220+ if self . tls . is_none ( ) {
221+ return Err ( Error :: new ( ErrorKind :: Unsupported , "tls not supported" ) ) ;
222+ }
223+ #[ cfg( not( feature = "tls" ) ) ]
224+ return Err ( Error :: new ( ErrorKind :: Unsupported , "tls not supported" ) ) ;
225+ }
181226 select ! {
182227 _ = unsafe { Pin :: new_unchecked( deadline) } => Err ( Error :: new( ErrorKind :: TimedOut , "deadline exceed" ) ) ,
183228 _ = time:: sleep( self . timeout) => Err ( Error :: new( ErrorKind :: TimedOut , format!( "connection timeout{:?} exceed" , self . timeout) ) ) ,
@@ -186,9 +231,10 @@ impl Connector {
186231 Err ( err) => Err ( err) ,
187232 Ok ( sock) => {
188233 let connection = if endpoint. tls {
189- let domain = ServerName :: try_from( endpoint. host) . unwrap( ) . to_owned( ) ;
190- let stream = self . tls. connect( domain, sock) . await ?;
191- Connection :: new_tls( stream)
234+ #[ cfg( not( feature = "tls" ) ) ]
235+ unreachable!( "tls not supported" ) ;
236+ #[ cfg( feature = "tls" ) ]
237+ self . connect_tls( sock, endpoint. host) . await ?
192238 } else {
193239 Connection :: new_raw( sock)
194240 } ;
@@ -231,3 +277,20 @@ impl Connector {
231277 None
232278 }
233279}
280+
281+ #[ cfg( test) ]
282+ mod tests {
283+ use std:: io:: ErrorKind ;
284+
285+ use super :: Connector ;
286+ use crate :: deadline:: Deadline ;
287+ use crate :: endpoint:: EndpointRef ;
288+
289+ #[ tokio:: test]
290+ async fn raw ( ) {
291+ let connector = Connector :: new ( ) ;
292+ let endpoint = EndpointRef :: new ( "host1" , 2181 , true ) ;
293+ let err = connector. connect ( endpoint, & mut Deadline :: never ( ) ) . await . unwrap_err ( ) ;
294+ assert_eq ! ( err. kind( ) , ErrorKind :: Unsupported ) ;
295+ }
296+ }
0 commit comments