From bb94956de539446e4018e991cd96ae6c12131c98 Mon Sep 17 00:00:00 2001 From: Einar Omang Date: Tue, 18 Nov 2025 14:26:39 +0100 Subject: [PATCH] Update secure channel params earlier Issue #175 is likely caused by a race condition, which happens because we don't update the secure channel until we've received the response over the oneshot channel from the transport. This fixes that by giving the transport access to the secure channel state directly, and calling `end_issue_or_renew_secure_channel` there instead of as part of sending the request. Nothing else has really changed, we still hold a lock while renewing the channel, this just makes it so that there is no way we can receive any more chunks before the secure channel has been updated. I noticed when doing this that we hold a write lock on the secure channel which we almost never need. I'd like us to fix that, at some point. --- async-opcua-client/src/transport/channel.rs | 27 +++++--- async-opcua-client/src/transport/connect.rs | 5 +- async-opcua-client/src/transport/core.rs | 27 +++++--- async-opcua-client/src/transport/state.rs | 75 +++++++++++---------- async-opcua-client/src/transport/tcp.rs | 31 ++++++--- 5 files changed, 97 insertions(+), 68 deletions(-) diff --git a/async-opcua-client/src/transport/channel.rs b/async-opcua-client/src/transport/channel.rs index c9f007eb..76ad05eb 100644 --- a/async-opcua-client/src/transport/channel.rs +++ b/async-opcua-client/src/transport/channel.rs @@ -1,6 +1,9 @@ use std::{str::FromStr, sync::Arc, time::Duration}; -use crate::{session::EndpointInfo, transport::core::TransportPollResult}; +use crate::{ + session::{process_unexpected_response, EndpointInfo}, + transport::core::TransportPollResult, +}; use arc_swap::{ArcSwap, ArcSwapOption}; use opcua_core::{ comms::secure_channel::{Role, SecureChannel}, @@ -37,7 +40,7 @@ pub struct AsyncSecureChannel { pub(crate) secure_channel: Arc>, certificate_store: Arc>, transport_config: TransportConfiguration, - state: SecureChannelState, + state: Arc, issue_channel_lock: tokio::sync::Mutex<()>, connector: Box, channel_lifetime: u32, @@ -147,7 +150,11 @@ impl AsyncSecureChannel { Self { transport_config, issue_channel_lock: tokio::sync::Mutex::new(()), - state: SecureChannelState::new(ignore_clock_skew, secure_channel.clone(), auth_token), + state: Arc::new(SecureChannelState::new( + ignore_clock_skew, + secure_channel.clone(), + auth_token, + )), endpoint_info, secure_channel, certificate_store, @@ -196,7 +203,9 @@ impl AsyncSecureChannel { let resp = request.send().await?; - self.state.end_issue_or_renew_secure_channel(resp)?; + if !matches!(resp, ResponseMessage::OpenSecureChannel(_)) { + return Err(process_unexpected_response(resp)); + } } drop(guard); @@ -258,7 +267,9 @@ impl AsyncSecureChannel { }; self.request_send.store(Some(Arc::new(send))); - self.state.end_issue_or_renew_secure_channel(resp)?; + if !matches!(resp, ResponseMessage::OpenSecureChannel(_)) { + return Err(process_unexpected_response(resp)); + } Ok(SecureChannelEventLoop { transport }) } @@ -305,11 +316,7 @@ impl AsyncSecureChannel { let (send, recv) = tokio::sync::mpsc::channel(MAX_INFLIGHT_MESSAGES); let transport = self .connector - .connect( - self.secure_channel.clone(), - recv, - self.transport_config.clone(), - ) + .connect(self.state.clone(), recv, self.transport_config.clone()) .await?; Ok((transport, send)) diff --git a/async-opcua-client/src/transport/connect.rs b/async-opcua-client/src/transport/connect.rs index a353fbf7..c2553357 100644 --- a/async-opcua-client/src/transport/connect.rs +++ b/async-opcua-client/src/transport/connect.rs @@ -1,9 +1,10 @@ use std::{future::Future, sync::Arc}; use async_trait::async_trait; -use opcua_core::{comms::secure_channel::SecureChannel, sync::RwLock}; use opcua_types::{EndpointDescription, Error, StatusCode}; +use crate::transport::state::SecureChannelState; + use super::{ tcp::{TcpTransport, TransportConfiguration}, OutgoingMessage, TcpConnector, TransportPollResult, @@ -23,7 +24,7 @@ pub trait Connector: Send + Sync { /// calling `run` on the returned transport in order to actually send and receive messages. async fn connect( &self, - channel: Arc>, + channel: Arc, outgoing_recv: tokio::sync::mpsc::Receiver, config: TransportConfiguration, ) -> Result; diff --git a/async-opcua-client/src/transport/core.rs b/async-opcua-client/src/transport/core.rs index 721f348b..4de57012 100644 --- a/async-opcua-client/src/transport/core.rs +++ b/async-opcua-client/src/transport/core.rs @@ -5,17 +5,18 @@ use std::time::Instant; use futures::future::Either; use opcua_core::comms::sequence_number::SequenceNumberHandle; use opcua_core::{trace_read_lock, trace_write_lock, RequestMessage, ResponseMessage}; -use parking_lot::RwLock; use tracing::{debug, error, trace, warn}; use opcua_core::comms::buffer::SendBuffer; use opcua_core::comms::message_chunk::MessageIsFinalType; use opcua_core::comms::{ chunker::Chunker, message_chunk::MessageChunk, message_chunk_info::ChunkInfo, - secure_channel::SecureChannel, tcp_codec::Message, + tcp_codec::Message, }; use opcua_types::{Error, StatusCode}; +use crate::transport::state::SecureChannelState; + #[derive(Debug)] struct MessageChunkWithChunkInfo { header: ChunkInfo, @@ -34,7 +35,7 @@ pub(super) struct TransportState { /// State of pending requests message_states: HashMap, /// Secure channel - pub(super) secure_channel: Arc>, + pub(super) channel_state: Arc, /// Max pending incoming messages max_chunk_count: usize, /// Last decoded sequence number @@ -69,17 +70,18 @@ pub struct OutgoingMessage { impl TransportState { pub(super) fn new( - secure_channel: Arc>, + channel_state: Arc, outgoing_recv: tokio::sync::mpsc::Receiver, max_chunk_count: usize, receive_buffer_size: usize, ) -> Self { - let legacy_sequence_numbers = secure_channel + let legacy_sequence_numbers = channel_state + .secure_channel() .read() .security_policy() .legacy_sequence_numbers(); Self { - secure_channel, + channel_state, outgoing_recv, message_states: HashMap::new(), sequence_numbers: SequenceNumberHandle::new(legacy_sequence_numbers), @@ -180,7 +182,9 @@ impl TransportState { } fn process_chunk(&mut self, chunk: MessageChunk) -> Result<(), StatusCode> { - let mut secure_channel = trace_write_lock!(self.secure_channel); + let mut secure_channel = trace_write_lock!(self.channel_state.secure_channel()); + // TODO: This is mut only because it _might_ be an open secure channel chunk. We should refactor + // this to avoid the write lock in most cases. let chunk = secure_channel.verify_and_remove_security(&chunk.data)?; let chunk_info = chunk.chunk_info(&secure_channel)?; @@ -239,6 +243,13 @@ impl TransportState { let in_chunks = Self::merge_chunks(message_state.chunks)?; let message = self.turn_received_chunks_into_message(&in_chunks)?; + // If the message is a response to opening a secure channel, we need to update encryption keys + // right now. If we wait, we risk new messages using the new encryption keys arriving before + // we've updated the secure channel. + if let ResponseMessage::OpenSecureChannel(msg) = &message { + self.channel_state.end_issue_or_renew_secure_channel(msg)?; + } + let _ = message_state.callback.send(Ok(message)); } } @@ -250,7 +261,7 @@ impl TransportState { chunks: &[MessageChunk], ) -> Result { // Validate that all chunks have incrementing sequence numbers and valid chunk types - let secure_channel = trace_read_lock!(self.secure_channel); + let secure_channel = trace_read_lock!(self.channel_state.secure_channel()); self.sequence_numbers.set(Chunker::validate_chunks( self.sequence_numbers.clone(), &secure_channel, diff --git a/async-opcua-client/src/transport/state.rs b/async-opcua-client/src/transport/state.rs index f742ae8b..43ab4556 100644 --- a/async-opcua-client/src/transport/state.rs +++ b/async-opcua-client/src/transport/state.rs @@ -6,7 +6,7 @@ use std::{ use tokio::sync::mpsc::error::SendTimeoutError; use tracing::{debug, trace}; -use crate::{session::process_unexpected_response, transport::OutgoingMessage}; +use crate::transport::OutgoingMessage; use arc_swap::ArcSwap; use opcua_core::{ comms::secure_channel::SecureChannel, handle::AtomicHandle, sync::RwLock, trace_write_lock, @@ -15,12 +15,13 @@ use opcua_core::{ use opcua_crypto::SecurityPolicy; use opcua_types::{ DateTime, DiagnosticBits, IntegerId, MessageSecurityMode, NodeId, OpenSecureChannelRequest, - RequestHeader, SecurityTokenRequestType, StatusCode, + OpenSecureChannelResponse, RequestHeader, SecurityTokenRequestType, StatusCode, }; pub(crate) type RequestSend = tokio::sync::mpsc::Sender; -pub(super) struct SecureChannelState { +/// The state of the secure channel used by the transport. +pub struct SecureChannelState { /// Time offset between the client and the server. client_offset: ArcSwap, /// Ignore clock skew between the client and the server. @@ -152,45 +153,41 @@ impl SecureChannelState { pub(super) fn end_issue_or_renew_secure_channel( &self, - response: ResponseMessage, + response: &OpenSecureChannelResponse, ) -> Result<(), StatusCode> { - if let ResponseMessage::OpenSecureChannel(response) = response { - // Extract the security token from the response. - let mut security_token = response.security_token.clone(); - - // When ignoring clock skew, we calculate the time offset between the client and the - // server and use that offset to compensate for the difference in time when setting - // the timestamps in the request headers and when decoding timestamps in messages - // received from the server. - if self.ignore_clock_skew && !response.response_header.timestamp.is_null() { - let offset = response.response_header.timestamp - DateTime::now(); - // Make sure to apply the offset to the security token in the current response. - security_token.created_at = security_token.created_at - offset; - // Update the client offset by adding the new offset. When the secure channel is - // renewed its already using the client offset calculated when issuing the secure - // channel and only needs to be updated to accommodate any additional clock skew. - self.set_client_offset(offset); - } + // Extract the security token from the response. + let mut security_token = response.security_token.clone(); + + // When ignoring clock skew, we calculate the time offset between the client and the + // server and use that offset to compensate for the difference in time when setting + // the timestamps in the request headers and when decoding timestamps in messages + // received from the server. + if self.ignore_clock_skew && !response.response_header.timestamp.is_null() { + let offset = response.response_header.timestamp - DateTime::now(); + // Make sure to apply the offset to the security token in the current response. + security_token.created_at = security_token.created_at - offset; + // Update the client offset by adding the new offset. When the secure channel is + // renewed its already using the client offset calculated when issuing the secure + // channel and only needs to be updated to accommodate any additional clock skew. + self.set_client_offset(offset); + } + + debug!("Setting transport's security token"); + { + let mut secure_channel = trace_write_lock!(self.secure_channel); + secure_channel.set_client_offset(**self.client_offset.load()); + secure_channel.set_security_token(security_token); - debug!("Setting transport's security token"); + if secure_channel.security_policy() != SecurityPolicy::None + && (secure_channel.security_mode() == MessageSecurityMode::Sign + || secure_channel.security_mode() == MessageSecurityMode::SignAndEncrypt) { - let mut secure_channel = trace_write_lock!(self.secure_channel); - secure_channel.set_client_offset(**self.client_offset.load()); - secure_channel.set_security_token(security_token); - - if secure_channel.security_policy() != SecurityPolicy::None - && (secure_channel.security_mode() == MessageSecurityMode::Sign - || secure_channel.security_mode() == MessageSecurityMode::SignAndEncrypt) - { - secure_channel.validate_secure_channel_nonce_length(&response.server_nonce)?; - secure_channel.set_remote_nonce_from_byte_string(&response.server_nonce)?; - secure_channel.derive_keys(); - } + secure_channel.validate_secure_channel_nonce_length(&response.server_nonce)?; + secure_channel.set_remote_nonce_from_byte_string(&response.server_nonce)?; + secure_channel.derive_keys(); } - Ok(()) - } else { - Err(process_unexpected_response(response)) } + Ok(()) } /// Construct a request header for the session. All requests after create session are expected @@ -213,4 +210,8 @@ impl SecureChannelState { pub(super) fn set_auth_token(&self, token: NodeId) { self.authentication_token.store(Arc::new(token)); } + + pub fn secure_channel(&self) -> &RwLock { + &self.secure_channel + } } diff --git a/async-opcua-client/src/transport/tcp.rs b/async-opcua-client/src/transport/tcp.rs index 6fe64edb..2f8111de 100644 --- a/async-opcua-client/src/transport/tcp.rs +++ b/async-opcua-client/src/transport/tcp.rs @@ -1,6 +1,8 @@ use std::net::SocketAddr; use std::sync::Arc; +use crate::transport::state::SecureChannelState; + use super::connect::{Connector, Transport}; use super::core::{OutgoingMessage, TransportPollResult, TransportState}; use async_trait::async_trait; @@ -180,15 +182,20 @@ impl TcpConnector { impl Connector for TcpConnector { async fn connect( &self, - channel: Arc>, + channel: Arc, outgoing_recv: tokio::sync::mpsc::Receiver, config: TransportConfiguration, ) -> Result { - let (framed_read, writer, ack, policy) = - match Self::connect_inner(&channel, &config, &self.endpoint_url).await { - Ok(k) => k, - Err(status) => return Err(status), - }; + let (framed_read, writer, ack, policy) = match Self::connect_inner( + channel.secure_channel(), + &config, + &self.endpoint_url, + ) + .await + { + Ok(k) => k, + Err(status) => return Err(status), + }; let mut buffer = SendBuffer::new( config.send_buffer_size, config.max_message_size, @@ -405,20 +412,22 @@ impl ReverseTcpConnector { impl Connector for ReverseTcpConnector { async fn connect( &self, - channel: Arc>, + channel: Arc, outgoing_recv: tokio::sync::mpsc::Receiver, config: TransportConfiguration, ) -> Result { let (framed_read, writer, ack, policy, endpoint_url) = match &self.listener { TcpConnectorReceiver::Listener(listener) => { - self.connect_inner(listener, &channel, &config).await? + self.connect_inner(listener, channel.secure_channel(), &config) + .await? } TcpConnectorReceiver::Address(addr) => { let listener = TcpListener::bind(addr).await.map_err(|err| { error!("Could not bind to address {}, {:?}", addr, err); StatusCode::BadCommunicationError })?; - self.connect_inner(&listener, &channel, &config).await? + self.connect_inner(&listener, channel.secure_channel(), &config) + .await? } }; @@ -487,7 +496,7 @@ impl TcpTransport { // If there's nothing in the send buffer, but there are chunks available, // write them to the send buffer before proceeding. if self.send_buffer.should_encode_chunks() { - let secure_channel = trace_read_lock!(self.state.secure_channel); + let secure_channel = trace_read_lock!(self.state.channel_state.secure_channel()); if let Err(e) = self.send_buffer.encode_next_chunk(&secure_channel) { return TransportPollResult::Closed(e); } @@ -525,7 +534,7 @@ impl TcpTransport { self.should_close = true; debug!("Writer is about to send a CloseSecureChannelRequest which means it should close in a moment"); } - let secure_channel = trace_read_lock!(self.state.secure_channel); + let secure_channel = trace_read_lock!(self.state.channel_state.secure_channel()); if let Err(e) = self.send_buffer.write(request_id, outgoing, &secure_channel) { drop(secure_channel); if let Some((request_id, request_handle)) = e.full_context() {