Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions async-opcua-client/src/transport/channel.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::{str::FromStr, sync::Arc, time::Duration};

use crate::{session::EndpointInfo, transport::core::TransportPollResult};
use crate::{
session::{process_unexpected_response, EndpointInfo},
transport::core::TransportPollResult,
};
use arc_swap::{ArcSwap, ArcSwapOption};
use opcua_core::{
comms::secure_channel::{Role, SecureChannel},
Expand Down Expand Up @@ -37,7 +40,7 @@ pub struct AsyncSecureChannel {
pub(crate) secure_channel: Arc<RwLock<SecureChannel>>,
certificate_store: Arc<RwLock<CertificateStore>>,
transport_config: TransportConfiguration,
state: SecureChannelState,
state: Arc<SecureChannelState>,
issue_channel_lock: tokio::sync::Mutex<()>,
connector: Box<dyn Connector>,
channel_lifetime: u32,
Expand Down Expand Up @@ -147,7 +150,11 @@ impl AsyncSecureChannel {
Self {
transport_config,
issue_channel_lock: tokio::sync::Mutex::new(()),
state: SecureChannelState::new(ignore_clock_skew, secure_channel.clone(), auth_token),
state: Arc::new(SecureChannelState::new(
ignore_clock_skew,
secure_channel.clone(),
auth_token,
)),
endpoint_info,
secure_channel,
certificate_store,
Expand Down Expand Up @@ -196,7 +203,9 @@ impl AsyncSecureChannel {

let resp = request.send().await?;

self.state.end_issue_or_renew_secure_channel(resp)?;
if !matches!(resp, ResponseMessage::OpenSecureChannel(_)) {
return Err(process_unexpected_response(resp));
}
}

drop(guard);
Expand Down Expand Up @@ -258,7 +267,9 @@ impl AsyncSecureChannel {
};

self.request_send.store(Some(Arc::new(send)));
self.state.end_issue_or_renew_secure_channel(resp)?;
if !matches!(resp, ResponseMessage::OpenSecureChannel(_)) {
return Err(process_unexpected_response(resp));
}

Ok(SecureChannelEventLoop { transport })
}
Expand Down Expand Up @@ -305,11 +316,7 @@ impl AsyncSecureChannel {
let (send, recv) = tokio::sync::mpsc::channel(MAX_INFLIGHT_MESSAGES);
let transport = self
.connector
.connect(
self.secure_channel.clone(),
recv,
self.transport_config.clone(),
)
.connect(self.state.clone(), recv, self.transport_config.clone())
.await?;

Ok((transport, send))
Expand Down
5 changes: 3 additions & 2 deletions async-opcua-client/src/transport/connect.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::{future::Future, sync::Arc};

use async_trait::async_trait;
use opcua_core::{comms::secure_channel::SecureChannel, sync::RwLock};
use opcua_types::{EndpointDescription, Error, StatusCode};

use crate::transport::state::SecureChannelState;

use super::{
tcp::{TcpTransport, TransportConfiguration},
OutgoingMessage, TcpConnector, TransportPollResult,
Expand All @@ -23,7 +24,7 @@ pub trait Connector: Send + Sync {
/// calling `run` on the returned transport in order to actually send and receive messages.
async fn connect(
&self,
channel: Arc<RwLock<SecureChannel>>,
channel: Arc<SecureChannelState>,
outgoing_recv: tokio::sync::mpsc::Receiver<OutgoingMessage>,
config: TransportConfiguration,
) -> Result<TcpTransport, StatusCode>;
Expand Down
27 changes: 19 additions & 8 deletions async-opcua-client/src/transport/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@ use std::time::Instant;
use futures::future::Either;
use opcua_core::comms::sequence_number::SequenceNumberHandle;
use opcua_core::{trace_read_lock, trace_write_lock, RequestMessage, ResponseMessage};
use parking_lot::RwLock;
use tracing::{debug, error, trace, warn};

use opcua_core::comms::buffer::SendBuffer;
use opcua_core::comms::message_chunk::MessageIsFinalType;
use opcua_core::comms::{
chunker::Chunker, message_chunk::MessageChunk, message_chunk_info::ChunkInfo,
secure_channel::SecureChannel, tcp_codec::Message,
tcp_codec::Message,
};
use opcua_types::{Error, StatusCode};

use crate::transport::state::SecureChannelState;

#[derive(Debug)]
struct MessageChunkWithChunkInfo {
header: ChunkInfo,
Expand All @@ -34,7 +35,7 @@ pub(super) struct TransportState {
/// State of pending requests
message_states: HashMap<u32, MessageState>,
/// Secure channel
pub(super) secure_channel: Arc<RwLock<SecureChannel>>,
pub(super) channel_state: Arc<SecureChannelState>,
/// Max pending incoming messages
max_chunk_count: usize,
/// Last decoded sequence number
Expand Down Expand Up @@ -69,17 +70,18 @@ pub struct OutgoingMessage {

impl TransportState {
pub(super) fn new(
secure_channel: Arc<RwLock<SecureChannel>>,
channel_state: Arc<SecureChannelState>,
outgoing_recv: tokio::sync::mpsc::Receiver<OutgoingMessage>,
max_chunk_count: usize,
receive_buffer_size: usize,
) -> Self {
let legacy_sequence_numbers = secure_channel
let legacy_sequence_numbers = channel_state
.secure_channel()
.read()
.security_policy()
.legacy_sequence_numbers();
Self {
secure_channel,
channel_state,
outgoing_recv,
message_states: HashMap::new(),
sequence_numbers: SequenceNumberHandle::new(legacy_sequence_numbers),
Expand Down Expand Up @@ -180,7 +182,9 @@ impl TransportState {
}

fn process_chunk(&mut self, chunk: MessageChunk) -> Result<(), StatusCode> {
let mut secure_channel = trace_write_lock!(self.secure_channel);
let mut secure_channel = trace_write_lock!(self.channel_state.secure_channel());
// TODO: This is mut only because it _might_ be an open secure channel chunk. We should refactor
// this to avoid the write lock in most cases.
let chunk = secure_channel.verify_and_remove_security(&chunk.data)?;

let chunk_info = chunk.chunk_info(&secure_channel)?;
Expand Down Expand Up @@ -239,6 +243,13 @@ impl TransportState {
let in_chunks = Self::merge_chunks(message_state.chunks)?;
let message = self.turn_received_chunks_into_message(&in_chunks)?;

// If the message is a response to opening a secure channel, we need to update encryption keys
// right now. If we wait, we risk new messages using the new encryption keys arriving before
// we've updated the secure channel.
if let ResponseMessage::OpenSecureChannel(msg) = &message {
self.channel_state.end_issue_or_renew_secure_channel(msg)?;
}

let _ = message_state.callback.send(Ok(message));
}
}
Expand All @@ -250,7 +261,7 @@ impl TransportState {
chunks: &[MessageChunk],
) -> Result<ResponseMessage, Error> {
// Validate that all chunks have incrementing sequence numbers and valid chunk types
let secure_channel = trace_read_lock!(self.secure_channel);
let secure_channel = trace_read_lock!(self.channel_state.secure_channel());
self.sequence_numbers.set(Chunker::validate_chunks(
self.sequence_numbers.clone(),
&secure_channel,
Expand Down
75 changes: 38 additions & 37 deletions async-opcua-client/src/transport/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
use tokio::sync::mpsc::error::SendTimeoutError;
use tracing::{debug, trace};

use crate::{session::process_unexpected_response, transport::OutgoingMessage};
use crate::transport::OutgoingMessage;
use arc_swap::ArcSwap;
use opcua_core::{
comms::secure_channel::SecureChannel, handle::AtomicHandle, sync::RwLock, trace_write_lock,
Expand All @@ -15,12 +15,13 @@ use opcua_core::{
use opcua_crypto::SecurityPolicy;
use opcua_types::{
DateTime, DiagnosticBits, IntegerId, MessageSecurityMode, NodeId, OpenSecureChannelRequest,
RequestHeader, SecurityTokenRequestType, StatusCode,
OpenSecureChannelResponse, RequestHeader, SecurityTokenRequestType, StatusCode,
};

pub(crate) type RequestSend = tokio::sync::mpsc::Sender<OutgoingMessage>;

pub(super) struct SecureChannelState {
/// The state of the secure channel used by the transport.
pub struct SecureChannelState {
/// Time offset between the client and the server.
client_offset: ArcSwap<chrono::Duration>,
/// Ignore clock skew between the client and the server.
Expand Down Expand Up @@ -152,45 +153,41 @@ impl SecureChannelState {

pub(super) fn end_issue_or_renew_secure_channel(
&self,
response: ResponseMessage,
response: &OpenSecureChannelResponse,
) -> Result<(), StatusCode> {
if let ResponseMessage::OpenSecureChannel(response) = response {
// Extract the security token from the response.
let mut security_token = response.security_token.clone();

// When ignoring clock skew, we calculate the time offset between the client and the
// server and use that offset to compensate for the difference in time when setting
// the timestamps in the request headers and when decoding timestamps in messages
// received from the server.
if self.ignore_clock_skew && !response.response_header.timestamp.is_null() {
let offset = response.response_header.timestamp - DateTime::now();
// Make sure to apply the offset to the security token in the current response.
security_token.created_at = security_token.created_at - offset;
// Update the client offset by adding the new offset. When the secure channel is
// renewed its already using the client offset calculated when issuing the secure
// channel and only needs to be updated to accommodate any additional clock skew.
self.set_client_offset(offset);
}
// Extract the security token from the response.
let mut security_token = response.security_token.clone();

// When ignoring clock skew, we calculate the time offset between the client and the
// server and use that offset to compensate for the difference in time when setting
// the timestamps in the request headers and when decoding timestamps in messages
// received from the server.
if self.ignore_clock_skew && !response.response_header.timestamp.is_null() {
let offset = response.response_header.timestamp - DateTime::now();
// Make sure to apply the offset to the security token in the current response.
security_token.created_at = security_token.created_at - offset;
// Update the client offset by adding the new offset. When the secure channel is
// renewed its already using the client offset calculated when issuing the secure
// channel and only needs to be updated to accommodate any additional clock skew.
self.set_client_offset(offset);
}

debug!("Setting transport's security token");
{
let mut secure_channel = trace_write_lock!(self.secure_channel);
secure_channel.set_client_offset(**self.client_offset.load());
secure_channel.set_security_token(security_token);

debug!("Setting transport's security token");
if secure_channel.security_policy() != SecurityPolicy::None
&& (secure_channel.security_mode() == MessageSecurityMode::Sign
|| secure_channel.security_mode() == MessageSecurityMode::SignAndEncrypt)
{
let mut secure_channel = trace_write_lock!(self.secure_channel);
secure_channel.set_client_offset(**self.client_offset.load());
secure_channel.set_security_token(security_token);

if secure_channel.security_policy() != SecurityPolicy::None
&& (secure_channel.security_mode() == MessageSecurityMode::Sign
|| secure_channel.security_mode() == MessageSecurityMode::SignAndEncrypt)
{
secure_channel.validate_secure_channel_nonce_length(&response.server_nonce)?;
secure_channel.set_remote_nonce_from_byte_string(&response.server_nonce)?;
secure_channel.derive_keys();
}
secure_channel.validate_secure_channel_nonce_length(&response.server_nonce)?;
secure_channel.set_remote_nonce_from_byte_string(&response.server_nonce)?;
secure_channel.derive_keys();
}
Ok(())
} else {
Err(process_unexpected_response(response))
}
Ok(())
}

/// Construct a request header for the session. All requests after create session are expected
Expand All @@ -213,4 +210,8 @@ impl SecureChannelState {
pub(super) fn set_auth_token(&self, token: NodeId) {
self.authentication_token.store(Arc::new(token));
}

pub fn secure_channel(&self) -> &RwLock<SecureChannel> {
&self.secure_channel
}
}
31 changes: 20 additions & 11 deletions async-opcua-client/src/transport/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::net::SocketAddr;
use std::sync::Arc;

use crate::transport::state::SecureChannelState;

use super::connect::{Connector, Transport};
use super::core::{OutgoingMessage, TransportPollResult, TransportState};
use async_trait::async_trait;
Expand Down Expand Up @@ -180,15 +182,20 @@ impl TcpConnector {
impl Connector for TcpConnector {
async fn connect(
&self,
channel: Arc<RwLock<SecureChannel>>,
channel: Arc<SecureChannelState>,
outgoing_recv: tokio::sync::mpsc::Receiver<OutgoingMessage>,
config: TransportConfiguration,
) -> Result<TcpTransport, StatusCode> {
let (framed_read, writer, ack, policy) =
match Self::connect_inner(&channel, &config, &self.endpoint_url).await {
Ok(k) => k,
Err(status) => return Err(status),
};
let (framed_read, writer, ack, policy) = match Self::connect_inner(
channel.secure_channel(),
&config,
&self.endpoint_url,
)
.await
{
Ok(k) => k,
Err(status) => return Err(status),
};
let mut buffer = SendBuffer::new(
config.send_buffer_size,
config.max_message_size,
Expand Down Expand Up @@ -405,20 +412,22 @@ impl ReverseTcpConnector {
impl Connector for ReverseTcpConnector {
async fn connect(
&self,
channel: Arc<RwLock<SecureChannel>>,
channel: Arc<SecureChannelState>,
outgoing_recv: tokio::sync::mpsc::Receiver<OutgoingMessage>,
config: TransportConfiguration,
) -> Result<TcpTransport, StatusCode> {
let (framed_read, writer, ack, policy, endpoint_url) = match &self.listener {
TcpConnectorReceiver::Listener(listener) => {
self.connect_inner(listener, &channel, &config).await?
self.connect_inner(listener, channel.secure_channel(), &config)
.await?
}
TcpConnectorReceiver::Address(addr) => {
let listener = TcpListener::bind(addr).await.map_err(|err| {
error!("Could not bind to address {}, {:?}", addr, err);
StatusCode::BadCommunicationError
})?;
self.connect_inner(&listener, &channel, &config).await?
self.connect_inner(&listener, channel.secure_channel(), &config)
.await?
}
};

Expand Down Expand Up @@ -487,7 +496,7 @@ impl TcpTransport {
// If there's nothing in the send buffer, but there are chunks available,
// write them to the send buffer before proceeding.
if self.send_buffer.should_encode_chunks() {
let secure_channel = trace_read_lock!(self.state.secure_channel);
let secure_channel = trace_read_lock!(self.state.channel_state.secure_channel());
if let Err(e) = self.send_buffer.encode_next_chunk(&secure_channel) {
return TransportPollResult::Closed(e);
}
Expand Down Expand Up @@ -525,7 +534,7 @@ impl TcpTransport {
self.should_close = true;
debug!("Writer is about to send a CloseSecureChannelRequest which means it should close in a moment");
}
let secure_channel = trace_read_lock!(self.state.secure_channel);
let secure_channel = trace_read_lock!(self.state.channel_state.secure_channel());
if let Err(e) = self.send_buffer.write(request_id, outgoing, &secure_channel) {
drop(secure_channel);
if let Some((request_id, request_handle)) = e.full_context() {
Expand Down