diff --git a/async-opcua-client/src/transport/channel.rs b/async-opcua-client/src/transport/channel.rs index c9f007eb..76ad05eb 100644 --- a/async-opcua-client/src/transport/channel.rs +++ b/async-opcua-client/src/transport/channel.rs @@ -1,6 +1,9 @@ use std::{str::FromStr, sync::Arc, time::Duration}; -use crate::{session::EndpointInfo, transport::core::TransportPollResult}; +use crate::{ + session::{process_unexpected_response, EndpointInfo}, + transport::core::TransportPollResult, +}; use arc_swap::{ArcSwap, ArcSwapOption}; use opcua_core::{ comms::secure_channel::{Role, SecureChannel}, @@ -37,7 +40,7 @@ pub struct AsyncSecureChannel { pub(crate) secure_channel: Arc>, certificate_store: Arc>, transport_config: TransportConfiguration, - state: SecureChannelState, + state: Arc, issue_channel_lock: tokio::sync::Mutex<()>, connector: Box, channel_lifetime: u32, @@ -147,7 +150,11 @@ impl AsyncSecureChannel { Self { transport_config, issue_channel_lock: tokio::sync::Mutex::new(()), - state: SecureChannelState::new(ignore_clock_skew, secure_channel.clone(), auth_token), + state: Arc::new(SecureChannelState::new( + ignore_clock_skew, + secure_channel.clone(), + auth_token, + )), endpoint_info, secure_channel, certificate_store, @@ -196,7 +203,9 @@ impl AsyncSecureChannel { let resp = request.send().await?; - self.state.end_issue_or_renew_secure_channel(resp)?; + if !matches!(resp, ResponseMessage::OpenSecureChannel(_)) { + return Err(process_unexpected_response(resp)); + } } drop(guard); @@ -258,7 +267,9 @@ impl AsyncSecureChannel { }; self.request_send.store(Some(Arc::new(send))); - self.state.end_issue_or_renew_secure_channel(resp)?; + if !matches!(resp, ResponseMessage::OpenSecureChannel(_)) { + return Err(process_unexpected_response(resp)); + } Ok(SecureChannelEventLoop { transport }) } @@ -305,11 +316,7 @@ impl AsyncSecureChannel { let (send, recv) = tokio::sync::mpsc::channel(MAX_INFLIGHT_MESSAGES); let transport = self .connector - .connect( - self.secure_channel.clone(), - recv, - self.transport_config.clone(), - ) + .connect(self.state.clone(), recv, self.transport_config.clone()) .await?; Ok((transport, send)) diff --git a/async-opcua-client/src/transport/connect.rs b/async-opcua-client/src/transport/connect.rs index a353fbf7..c2553357 100644 --- a/async-opcua-client/src/transport/connect.rs +++ b/async-opcua-client/src/transport/connect.rs @@ -1,9 +1,10 @@ use std::{future::Future, sync::Arc}; use async_trait::async_trait; -use opcua_core::{comms::secure_channel::SecureChannel, sync::RwLock}; use opcua_types::{EndpointDescription, Error, StatusCode}; +use crate::transport::state::SecureChannelState; + use super::{ tcp::{TcpTransport, TransportConfiguration}, OutgoingMessage, TcpConnector, TransportPollResult, @@ -23,7 +24,7 @@ pub trait Connector: Send + Sync { /// calling `run` on the returned transport in order to actually send and receive messages. async fn connect( &self, - channel: Arc>, + channel: Arc, outgoing_recv: tokio::sync::mpsc::Receiver, config: TransportConfiguration, ) -> Result; diff --git a/async-opcua-client/src/transport/core.rs b/async-opcua-client/src/transport/core.rs index 721f348b..4de57012 100644 --- a/async-opcua-client/src/transport/core.rs +++ b/async-opcua-client/src/transport/core.rs @@ -5,17 +5,18 @@ use std::time::Instant; use futures::future::Either; use opcua_core::comms::sequence_number::SequenceNumberHandle; use opcua_core::{trace_read_lock, trace_write_lock, RequestMessage, ResponseMessage}; -use parking_lot::RwLock; use tracing::{debug, error, trace, warn}; use opcua_core::comms::buffer::SendBuffer; use opcua_core::comms::message_chunk::MessageIsFinalType; use opcua_core::comms::{ chunker::Chunker, message_chunk::MessageChunk, message_chunk_info::ChunkInfo, - secure_channel::SecureChannel, tcp_codec::Message, + tcp_codec::Message, }; use opcua_types::{Error, StatusCode}; +use crate::transport::state::SecureChannelState; + #[derive(Debug)] struct MessageChunkWithChunkInfo { header: ChunkInfo, @@ -34,7 +35,7 @@ pub(super) struct TransportState { /// State of pending requests message_states: HashMap, /// Secure channel - pub(super) secure_channel: Arc>, + pub(super) channel_state: Arc, /// Max pending incoming messages max_chunk_count: usize, /// Last decoded sequence number @@ -69,17 +70,18 @@ pub struct OutgoingMessage { impl TransportState { pub(super) fn new( - secure_channel: Arc>, + channel_state: Arc, outgoing_recv: tokio::sync::mpsc::Receiver, max_chunk_count: usize, receive_buffer_size: usize, ) -> Self { - let legacy_sequence_numbers = secure_channel + let legacy_sequence_numbers = channel_state + .secure_channel() .read() .security_policy() .legacy_sequence_numbers(); Self { - secure_channel, + channel_state, outgoing_recv, message_states: HashMap::new(), sequence_numbers: SequenceNumberHandle::new(legacy_sequence_numbers), @@ -180,7 +182,9 @@ impl TransportState { } fn process_chunk(&mut self, chunk: MessageChunk) -> Result<(), StatusCode> { - let mut secure_channel = trace_write_lock!(self.secure_channel); + let mut secure_channel = trace_write_lock!(self.channel_state.secure_channel()); + // TODO: This is mut only because it _might_ be an open secure channel chunk. We should refactor + // this to avoid the write lock in most cases. let chunk = secure_channel.verify_and_remove_security(&chunk.data)?; let chunk_info = chunk.chunk_info(&secure_channel)?; @@ -239,6 +243,13 @@ impl TransportState { let in_chunks = Self::merge_chunks(message_state.chunks)?; let message = self.turn_received_chunks_into_message(&in_chunks)?; + // If the message is a response to opening a secure channel, we need to update encryption keys + // right now. If we wait, we risk new messages using the new encryption keys arriving before + // we've updated the secure channel. + if let ResponseMessage::OpenSecureChannel(msg) = &message { + self.channel_state.end_issue_or_renew_secure_channel(msg)?; + } + let _ = message_state.callback.send(Ok(message)); } } @@ -250,7 +261,7 @@ impl TransportState { chunks: &[MessageChunk], ) -> Result { // Validate that all chunks have incrementing sequence numbers and valid chunk types - let secure_channel = trace_read_lock!(self.secure_channel); + let secure_channel = trace_read_lock!(self.channel_state.secure_channel()); self.sequence_numbers.set(Chunker::validate_chunks( self.sequence_numbers.clone(), &secure_channel, diff --git a/async-opcua-client/src/transport/state.rs b/async-opcua-client/src/transport/state.rs index f742ae8b..43ab4556 100644 --- a/async-opcua-client/src/transport/state.rs +++ b/async-opcua-client/src/transport/state.rs @@ -6,7 +6,7 @@ use std::{ use tokio::sync::mpsc::error::SendTimeoutError; use tracing::{debug, trace}; -use crate::{session::process_unexpected_response, transport::OutgoingMessage}; +use crate::transport::OutgoingMessage; use arc_swap::ArcSwap; use opcua_core::{ comms::secure_channel::SecureChannel, handle::AtomicHandle, sync::RwLock, trace_write_lock, @@ -15,12 +15,13 @@ use opcua_core::{ use opcua_crypto::SecurityPolicy; use opcua_types::{ DateTime, DiagnosticBits, IntegerId, MessageSecurityMode, NodeId, OpenSecureChannelRequest, - RequestHeader, SecurityTokenRequestType, StatusCode, + OpenSecureChannelResponse, RequestHeader, SecurityTokenRequestType, StatusCode, }; pub(crate) type RequestSend = tokio::sync::mpsc::Sender; -pub(super) struct SecureChannelState { +/// The state of the secure channel used by the transport. +pub struct SecureChannelState { /// Time offset between the client and the server. client_offset: ArcSwap, /// Ignore clock skew between the client and the server. @@ -152,45 +153,41 @@ impl SecureChannelState { pub(super) fn end_issue_or_renew_secure_channel( &self, - response: ResponseMessage, + response: &OpenSecureChannelResponse, ) -> Result<(), StatusCode> { - if let ResponseMessage::OpenSecureChannel(response) = response { - // Extract the security token from the response. - let mut security_token = response.security_token.clone(); - - // When ignoring clock skew, we calculate the time offset between the client and the - // server and use that offset to compensate for the difference in time when setting - // the timestamps in the request headers and when decoding timestamps in messages - // received from the server. - if self.ignore_clock_skew && !response.response_header.timestamp.is_null() { - let offset = response.response_header.timestamp - DateTime::now(); - // Make sure to apply the offset to the security token in the current response. - security_token.created_at = security_token.created_at - offset; - // Update the client offset by adding the new offset. When the secure channel is - // renewed its already using the client offset calculated when issuing the secure - // channel and only needs to be updated to accommodate any additional clock skew. - self.set_client_offset(offset); - } + // Extract the security token from the response. + let mut security_token = response.security_token.clone(); + + // When ignoring clock skew, we calculate the time offset between the client and the + // server and use that offset to compensate for the difference in time when setting + // the timestamps in the request headers and when decoding timestamps in messages + // received from the server. + if self.ignore_clock_skew && !response.response_header.timestamp.is_null() { + let offset = response.response_header.timestamp - DateTime::now(); + // Make sure to apply the offset to the security token in the current response. + security_token.created_at = security_token.created_at - offset; + // Update the client offset by adding the new offset. When the secure channel is + // renewed its already using the client offset calculated when issuing the secure + // channel and only needs to be updated to accommodate any additional clock skew. + self.set_client_offset(offset); + } + + debug!("Setting transport's security token"); + { + let mut secure_channel = trace_write_lock!(self.secure_channel); + secure_channel.set_client_offset(**self.client_offset.load()); + secure_channel.set_security_token(security_token); - debug!("Setting transport's security token"); + if secure_channel.security_policy() != SecurityPolicy::None + && (secure_channel.security_mode() == MessageSecurityMode::Sign + || secure_channel.security_mode() == MessageSecurityMode::SignAndEncrypt) { - let mut secure_channel = trace_write_lock!(self.secure_channel); - secure_channel.set_client_offset(**self.client_offset.load()); - secure_channel.set_security_token(security_token); - - if secure_channel.security_policy() != SecurityPolicy::None - && (secure_channel.security_mode() == MessageSecurityMode::Sign - || secure_channel.security_mode() == MessageSecurityMode::SignAndEncrypt) - { - secure_channel.validate_secure_channel_nonce_length(&response.server_nonce)?; - secure_channel.set_remote_nonce_from_byte_string(&response.server_nonce)?; - secure_channel.derive_keys(); - } + secure_channel.validate_secure_channel_nonce_length(&response.server_nonce)?; + secure_channel.set_remote_nonce_from_byte_string(&response.server_nonce)?; + secure_channel.derive_keys(); } - Ok(()) - } else { - Err(process_unexpected_response(response)) } + Ok(()) } /// Construct a request header for the session. All requests after create session are expected @@ -213,4 +210,8 @@ impl SecureChannelState { pub(super) fn set_auth_token(&self, token: NodeId) { self.authentication_token.store(Arc::new(token)); } + + pub fn secure_channel(&self) -> &RwLock { + &self.secure_channel + } } diff --git a/async-opcua-client/src/transport/tcp.rs b/async-opcua-client/src/transport/tcp.rs index 6fe64edb..2f8111de 100644 --- a/async-opcua-client/src/transport/tcp.rs +++ b/async-opcua-client/src/transport/tcp.rs @@ -1,6 +1,8 @@ use std::net::SocketAddr; use std::sync::Arc; +use crate::transport::state::SecureChannelState; + use super::connect::{Connector, Transport}; use super::core::{OutgoingMessage, TransportPollResult, TransportState}; use async_trait::async_trait; @@ -180,15 +182,20 @@ impl TcpConnector { impl Connector for TcpConnector { async fn connect( &self, - channel: Arc>, + channel: Arc, outgoing_recv: tokio::sync::mpsc::Receiver, config: TransportConfiguration, ) -> Result { - let (framed_read, writer, ack, policy) = - match Self::connect_inner(&channel, &config, &self.endpoint_url).await { - Ok(k) => k, - Err(status) => return Err(status), - }; + let (framed_read, writer, ack, policy) = match Self::connect_inner( + channel.secure_channel(), + &config, + &self.endpoint_url, + ) + .await + { + Ok(k) => k, + Err(status) => return Err(status), + }; let mut buffer = SendBuffer::new( config.send_buffer_size, config.max_message_size, @@ -405,20 +412,22 @@ impl ReverseTcpConnector { impl Connector for ReverseTcpConnector { async fn connect( &self, - channel: Arc>, + channel: Arc, outgoing_recv: tokio::sync::mpsc::Receiver, config: TransportConfiguration, ) -> Result { let (framed_read, writer, ack, policy, endpoint_url) = match &self.listener { TcpConnectorReceiver::Listener(listener) => { - self.connect_inner(listener, &channel, &config).await? + self.connect_inner(listener, channel.secure_channel(), &config) + .await? } TcpConnectorReceiver::Address(addr) => { let listener = TcpListener::bind(addr).await.map_err(|err| { error!("Could not bind to address {}, {:?}", addr, err); StatusCode::BadCommunicationError })?; - self.connect_inner(&listener, &channel, &config).await? + self.connect_inner(&listener, channel.secure_channel(), &config) + .await? } }; @@ -487,7 +496,7 @@ impl TcpTransport { // If there's nothing in the send buffer, but there are chunks available, // write them to the send buffer before proceeding. if self.send_buffer.should_encode_chunks() { - let secure_channel = trace_read_lock!(self.state.secure_channel); + let secure_channel = trace_read_lock!(self.state.channel_state.secure_channel()); if let Err(e) = self.send_buffer.encode_next_chunk(&secure_channel) { return TransportPollResult::Closed(e); } @@ -525,7 +534,7 @@ impl TcpTransport { self.should_close = true; debug!("Writer is about to send a CloseSecureChannelRequest which means it should close in a moment"); } - let secure_channel = trace_read_lock!(self.state.secure_channel); + let secure_channel = trace_read_lock!(self.state.channel_state.secure_channel()); if let Err(e) = self.send_buffer.write(request_id, outgoing, &secure_channel) { drop(secure_channel); if let Some((request_id, request_handle)) = e.full_context() {