@@ -19,7 +19,7 @@ use rustls::{
1919use crate :: error:: Error ;
2020use crate :: io:: ReadBuf ;
2121use crate :: net:: tls:: util:: StdSocket ;
22- use crate :: net:: tls:: TlsConfig ;
22+ use crate :: net:: tls:: { RawTlsConfig , TlsConfig } ;
2323use crate :: net:: Socket ;
2424
2525pub struct RustlsSocket < S : Socket > {
@@ -87,100 +87,136 @@ impl<S: Socket> Socket for RustlsSocket<S> {
8787 }
8888}
8989
90- pub async fn handshake < S > ( socket : S , tls_config : TlsConfig < ' _ > ) -> Result < RustlsSocket < S > , Error >
91- where
92- S : Socket ,
93- {
94- #[ cfg( all(
95- feature = "_tls-rustls-aws-lc-rs" ,
96- not( feature = "_tls-rustls-ring-webpki" ) ,
97- not( feature = "_tls-rustls-ring-native-roots" )
98- ) ) ]
99- let provider = Arc :: new ( rustls:: crypto:: aws_lc_rs:: default_provider ( ) ) ;
100- #[ cfg( any(
101- feature = "_tls-rustls-ring-webpki" ,
102- feature = "_tls-rustls-ring-native-roots"
103- ) ) ]
104- let provider = Arc :: new ( rustls:: crypto:: ring:: default_provider ( ) ) ;
105-
106- // Unwrapping is safe here because we use a default provider.
107- let config = ClientConfig :: builder_with_provider ( provider. clone ( ) )
90+ impl TlsConfig < ' _ > {
91+ async fn rustls_config ( & self ) -> crate :: Result < ( rustls:: ClientConfig , & str ) , Error > {
92+ let RawTlsConfig {
93+ accept_invalid_certs,
94+ accept_invalid_hostnames,
95+ hostname,
96+ root_cert,
97+ client_cert,
98+ client_key,
99+ } = match self {
100+ TlsConfig :: RawTlsConfig ( raw) => raw,
101+ TlsConfig :: PrebuiltRustls { config, hostname } => {
102+ return Ok ( ( ( * config) . to_owned ( ) , hostname) ) ;
103+ }
104+ } ;
105+
106+ #[ cfg( all(
107+ feature = "_tls-rustls-aws-lc-rs" ,
108+ not( feature = "_tls-rustls-ring-webpki" ) ,
109+ not( feature = "_tls-rustls-ring-native-roots" )
110+ ) ) ]
111+ let config = ClientConfig :: builder_with_provider ( Arc :: new (
112+ rustls:: crypto:: aws_lc_rs:: default_provider ( ) ,
113+ ) )
108114 . with_safe_default_protocol_versions ( )
109115 . unwrap ( ) ;
116+ #[ cfg( any(
117+ feature = "_tls-rustls-ring-webpki" ,
118+ feature = "_tls-rustls-ring-native-roots"
119+ ) ) ]
120+ let config =
121+ ClientConfig :: builder_with_provider ( Arc :: new ( rustls:: crypto:: ring:: default_provider ( ) ) )
122+ . with_safe_default_protocol_versions ( )
123+ . unwrap ( ) ;
124+ #[ cfg( all(
125+ not( feature = "_tls-rustls-ring-webpki" ) ,
126+ not( feature = "_tls-rustls-ring-native-roots" )
127+ ) ) ]
128+ let config = ClientConfig :: builder ( )
129+ . with_safe_default_protocol_versions ( )
130+ . unwrap ( ) ;
131+
132+ // authentication using user's key and its associated certificate
133+ let user_auth = match ( client_cert, client_key) {
134+ ( Some ( cert) , Some ( key) ) => {
135+ let cert_chain = certs_from_pem ( cert. data ( ) . await ?) ?;
136+ let key_der = private_key_from_pem ( key. data ( ) . await ?) ?;
137+ Some ( ( cert_chain, key_der) )
138+ }
139+ ( None , None ) => None ,
140+ ( _, _) => {
141+ return Err ( Error :: Configuration (
142+ "user auth key and certs must be given together" . into ( ) ,
143+ ) )
144+ }
145+ } ;
110146
111- // authentication using user's key and its associated certificate
112- let user_auth = match ( tls_config. client_cert_path , tls_config. client_key_path ) {
113- ( Some ( cert_path) , Some ( key_path) ) => {
114- let cert_chain = certs_from_pem ( cert_path. data ( ) . await ?) ?;
115- let key_der = private_key_from_pem ( key_path. data ( ) . await ?) ?;
116- Some ( ( cert_chain, key_der) )
117- }
118- ( None , None ) => None ,
119- ( _, _) => {
120- return Err ( Error :: Configuration (
121- "user auth key and certs must be given together" . into ( ) ,
122- ) )
123- }
124- } ;
147+ let provider = config. crypto_provider ( ) . clone ( ) ;
125148
126- let config = if tls_config. accept_invalid_certs {
127- if let Some ( user_auth) = user_auth {
128- config
129- . dangerous ( )
130- . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
131- . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
132- . map_err ( Error :: tls) ?
149+ let config = if * accept_invalid_certs {
150+ if let Some ( user_auth) = user_auth {
151+ config
152+ . dangerous ( )
153+ . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
154+ . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
155+ . map_err ( Error :: tls) ?
156+ } else {
157+ config
158+ . dangerous ( )
159+ . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
160+ . with_no_client_auth ( )
161+ }
133162 } else {
134- config
135- . dangerous ( )
136- . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
137- . with_no_client_auth ( )
138- }
139- } else {
140- let mut cert_store = import_root_certs ( ) ;
163+ let mut cert_store = import_root_certs ( ) ;
141164
142- if let Some ( ca) = tls_config . root_cert_path {
143- let data = ca. data ( ) . await ?;
165+ if let Some ( ca) = root_cert {
166+ let data = ca. data ( ) . await ?;
144167
145- for result in CertificateDer :: pem_slice_iter ( & data) {
146- let Ok ( cert) = result else {
147- return Err ( Error :: Tls ( format ! ( "Invalid certificate {ca}" ) . into ( ) ) ) ;
148- } ;
168+ for result in CertificateDer :: pem_slice_iter ( & data) {
169+ let Ok ( cert) = result else {
170+ return Err ( Error :: Tls ( format ! ( "Invalid certificate {ca}" ) . into ( ) ) ) ;
171+ } ;
149172
150- cert_store. add ( cert) . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
173+ cert_store. add ( cert) . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
174+ }
151175 }
152- }
153-
154- if tls_config. accept_invalid_hostnames {
155- let verifier = WebPkiServerVerifier :: builder ( Arc :: new ( cert_store) )
156- . build ( )
157- . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
158176
159- if let Some ( user_auth) = user_auth {
177+ if * accept_invalid_hostnames {
178+ let verifier = WebPkiServerVerifier :: builder ( Arc :: new ( cert_store) )
179+ . build ( )
180+ . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
181+
182+ if let Some ( user_auth) = user_auth {
183+ config
184+ . dangerous ( )
185+ . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier {
186+ verifier,
187+ } ) )
188+ . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
189+ . map_err ( Error :: tls) ?
190+ } else {
191+ config
192+ . dangerous ( )
193+ . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier {
194+ verifier,
195+ } ) )
196+ . with_no_client_auth ( )
197+ }
198+ } else if let Some ( user_auth) = user_auth {
160199 config
161- . dangerous ( )
162- . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier { verifier } ) )
200+ . with_root_certificates ( cert_store)
163201 . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
164202 . map_err ( Error :: tls) ?
165203 } else {
166204 config
167- . dangerous ( )
168- . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier { verifier } ) )
205+ . with_root_certificates ( cert_store)
169206 . with_no_client_auth ( )
170207 }
171- } else if let Some ( user_auth) = user_auth {
172- config
173- . with_root_certificates ( cert_store)
174- . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
175- . map_err ( Error :: tls) ?
176- } else {
177- config
178- . with_root_certificates ( cert_store)
179- . with_no_client_auth ( )
180- }
181- } ;
208+ } ;
209+
210+ Ok ( ( config, hostname) )
211+ }
212+ }
182213
183- let host = ServerName :: try_from ( tls_config. hostname . to_owned ( ) ) . map_err ( Error :: tls) ?;
214+ pub async fn handshake < S > ( socket : S , tls_config : TlsConfig < ' _ > ) -> Result < RustlsSocket < S > , Error >
215+ where
216+ S : Socket ,
217+ {
218+ let ( config, hostname) = tls_config. rustls_config ( ) . await ?;
219+ let host = ServerName :: try_from ( hostname. to_owned ( ) ) . map_err ( Error :: tls) ?;
184220
185221 let mut socket = RustlsSocket {
186222 inner : StdSocket :: new ( socket) ,
0 commit comments