From d0e22fad4ab80138d97c34be8314a7f064f307fe Mon Sep 17 00:00:00 2001 From: Elias Rohrer Date: Mon, 1 Dec 2025 13:39:45 +0100 Subject: [PATCH 1/3] Consistently use `wire::Message` for encoding network messages Previously, `enqueue_message` took an `M: Type + Writeable` reference, which didn't make use of our `wire::Message` type, which turned out to be rather confusing. Here, we use `Message` consistently in `PeerManager`'s `enqueue_message`, but also in `encrypt_message`, etc. While at it we also switch to move semantics, which is a nice cleanup. --- lightning/src/ln/channelmanager.rs | 2 + lightning/src/ln/functional_test_utils.rs | 2 + lightning/src/ln/msgs.rs | 2 + lightning/src/ln/peer_channel_encryptor.rs | 4 +- lightning/src/ln/peer_handler.rs | 428 +++++++++++++-------- 5 files changed, 285 insertions(+), 153 deletions(-) diff --git a/lightning/src/ln/channelmanager.rs b/lightning/src/ln/channelmanager.rs index 399c51b9d9a..d52d8535114 100644 --- a/lightning/src/ln/channelmanager.rs +++ b/lightning/src/ln/channelmanager.rs @@ -13872,7 +13872,9 @@ where &MessageSendEvent::UpdateHTLCs { .. } => false, &MessageSendEvent::SendRevokeAndACK { .. } => false, &MessageSendEvent::SendClosingSigned { .. } => false, + #[cfg(simple_close)] &MessageSendEvent::SendClosingComplete { .. } => false, + #[cfg(simple_close)] &MessageSendEvent::SendClosingSig { .. } => false, &MessageSendEvent::SendShutdown { .. } => false, &MessageSendEvent::SendChannelReestablish { .. } => false, diff --git a/lightning/src/ln/functional_test_utils.rs b/lightning/src/ln/functional_test_utils.rs index e31630a4926..e4e9c583f32 100644 --- a/lightning/src/ln/functional_test_utils.rs +++ b/lightning/src/ln/functional_test_utils.rs @@ -1124,7 +1124,9 @@ pub fn remove_first_msg_event_to_node( MessageSendEvent::UpdateHTLCs { node_id, .. } => node_id == msg_node_id, MessageSendEvent::SendRevokeAndACK { node_id, .. } => node_id == msg_node_id, MessageSendEvent::SendClosingSigned { node_id, .. } => node_id == msg_node_id, + #[cfg(simple_close)] MessageSendEvent::SendClosingComplete { node_id, .. } => node_id == msg_node_id, + #[cfg(simple_close)] MessageSendEvent::SendClosingSig { node_id, .. } => node_id == msg_node_id, MessageSendEvent::SendShutdown { node_id, .. } => node_id == msg_node_id, MessageSendEvent::SendChannelReestablish { node_id, .. } => node_id == msg_node_id, diff --git a/lightning/src/ln/msgs.rs b/lightning/src/ln/msgs.rs index 8e230fab1d9..0484ebe7530 100644 --- a/lightning/src/ln/msgs.rs +++ b/lightning/src/ln/msgs.rs @@ -1857,6 +1857,7 @@ pub enum MessageSendEvent { msg: ClosingSigned, }, /// Used to indicate that a `closing_complete` message should be sent to the peer with the given `node_id`. + #[cfg(simple_close)] SendClosingComplete { /// The node_id of the node which should receive this message node_id: PublicKey, @@ -1864,6 +1865,7 @@ pub enum MessageSendEvent { msg: ClosingComplete, }, /// Used to indicate that a `closing_sig` message should be sent to the peer with the given `node_id`. + #[cfg(simple_close)] SendClosingSig { /// The node_id of the node which should receive this message node_id: PublicKey, diff --git a/lightning/src/ln/peer_channel_encryptor.rs b/lightning/src/ln/peer_channel_encryptor.rs index 09b970a9ab2..1d34d9a8674 100644 --- a/lightning/src/ln/peer_channel_encryptor.rs +++ b/lightning/src/ln/peer_channel_encryptor.rs @@ -565,12 +565,12 @@ impl PeerChannelEncryptor { /// Encrypts the given message, returning the encrypted version. /// panics if the length of `message`, once encoded, is greater than 65535 or if the Noise /// handshake has not finished. - pub fn encrypt_message(&mut self, message: &M) -> Vec { + pub fn encrypt_message(&mut self, message: wire::Message) -> Vec { // Allocate a buffer with 2KB, fitting most common messages. Reserve the first 16+2 bytes // for the 2-byte message type prefix and its MAC. let mut res = VecWriter(Vec::with_capacity(MSG_BUF_ALLOC_SIZE)); res.0.resize(16 + 2, 0); - wire::write(message, &mut res).expect("In-memory messages must never fail to serialize"); + wire::write(&message, &mut res).expect("In-memory messages must never fail to serialize"); self.encrypt_message_with_header_0s(&mut res.0); res.0 diff --git a/lightning/src/ln/peer_handler.rs b/lightning/src/ln/peer_handler.rs index c3b490ef31a..7967208cbe4 100644 --- a/lightning/src/ln/peer_handler.rs +++ b/lightning/src/ln/peer_handler.rs @@ -29,7 +29,7 @@ use crate::ln::peer_channel_encryptor::{ }; use crate::ln::types::ChannelId; use crate::ln::wire; -use crate::ln::wire::{Encode, Type}; +use crate::ln::wire::{Encode, Message, Type}; use crate::onion_message::async_payments::{ AsyncPaymentsMessageHandler, HeldHtlcAvailable, OfferPaths, OfferPathsRequest, ReleaseHeldHtlc, ServeStaticInvoice, StaticInvoicePersisted, @@ -53,12 +53,14 @@ use crate::util::ser::{VecWriter, Writeable, Writer}; #[allow(unused_imports)] use crate::prelude::*; +use super::wire::CustomMessageReader; use crate::io; use crate::sync::{FairRwLock, Mutex, MutexGuard}; use core::convert::Infallible; use core::ops::Deref; use core::sync::atomic::{AtomicBool, AtomicI32, AtomicU32, Ordering}; use core::{cmp, fmt, hash, mem}; + #[cfg(not(c_bindings))] use { crate::chain::chainmonitor::ChainMonitor, @@ -1121,7 +1123,7 @@ pub struct PeerManager< } enum LogicalMessage { - FromWire(wire::Message), + FromWire(Message), CommitmentSignedBatch(ChannelId, Vec), } @@ -1572,7 +1574,9 @@ where if let Some(next_onion_message) = handler.next_onion_message_for_peer(peer_node_id) { - self.enqueue_message(peer, &next_onion_message); + let msg: Message<::CustomMessage> = + Message::OnionMessage(next_onion_message); + self.enqueue_message(peer, msg); } } } @@ -1590,16 +1594,25 @@ where if let Some((announce, update_a_option, update_b_option)) = self.message_handler.route_handler.get_next_channel_announcement(c) { - self.enqueue_message(peer, &announce); + peer.sync_status = InitSyncTracker::ChannelsSyncing( + announce.contents.short_channel_id + 1, + ); + let msg: Message<::CustomMessage> = + Message::ChannelAnnouncement(announce); + self.enqueue_message(peer, msg); + if let Some(update_a) = update_a_option { - self.enqueue_message(peer, &update_a); + let msg: Message< + ::CustomMessage, + > = Message::ChannelUpdate(update_a); + self.enqueue_message(peer, msg); } if let Some(update_b) = update_b_option { - self.enqueue_message(peer, &update_b); + let msg: Message< + ::CustomMessage, + > = Message::ChannelUpdate(update_b); + self.enqueue_message(peer, msg); } - peer.sync_status = InitSyncTracker::ChannelsSyncing( - announce.contents.short_channel_id + 1, - ); } else { peer.sync_status = InitSyncTracker::ChannelsSyncing(0xffff_ffff_ffff_ffff); @@ -1608,8 +1621,10 @@ where InitSyncTracker::ChannelsSyncing(c) if c == 0xffff_ffff_ffff_ffff => { let handler = &self.message_handler.route_handler; if let Some(msg) = handler.get_next_node_announcement(None) { - self.enqueue_message(peer, &msg); peer.sync_status = InitSyncTracker::NodesSyncing(msg.contents.node_id); + let msg: Message<::CustomMessage> = + Message::NodeAnnouncement(msg); + self.enqueue_message(peer, msg); } else { peer.sync_status = InitSyncTracker::NoSyncRequested; } @@ -1618,8 +1633,10 @@ where InitSyncTracker::NodesSyncing(sync_node_id) => { let handler = &self.message_handler.route_handler; if let Some(msg) = handler.get_next_node_announcement(Some(&sync_node_id)) { - self.enqueue_message(peer, &msg); peer.sync_status = InitSyncTracker::NodesSyncing(msg.contents.node_id); + let msg: Message<::CustomMessage> = + Message::NodeAnnouncement(msg); + self.enqueue_message(peer, msg); } else { peer.sync_status = InitSyncTracker::NoSyncRequested; } @@ -1727,7 +1744,7 @@ where } /// Append a message to a peer's pending outbound/write buffer - fn enqueue_message(&self, peer: &mut Peer, message: &M) { + fn enqueue_message(&self, peer: &mut Peer, message: Message) { let their_node_id = peer.their_node_id.map(|p| p.0); if their_node_id.is_some() { let logger = WithContext::from(&self.logger, their_node_id, None, None); @@ -1792,12 +1809,14 @@ where }, msgs::ErrorAction::SendErrorMessage { msg } => { log_debug!(logger, "Error handling message{}; sending error message with: {}", OptionalFromDebugger(&peer_node_id), e.err); - self.enqueue_message($peer, &msg); + let msg: Message<::CustomMessage> = Message::Error(msg); + self.enqueue_message($peer, msg); continue; }, msgs::ErrorAction::SendWarningMessage { msg, log_level } => { log_given_level!(logger, log_level, "Error handling message{}; sending warning message with: {}", OptionalFromDebugger(&peer_node_id), e.err); - self.enqueue_message($peer, &msg); + let msg: Message<::CustomMessage> = Message::Warning(msg); + self.enqueue_message($peer, msg); continue; }, } @@ -1892,7 +1911,9 @@ where peer.their_socket_address.clone(), ), }; - self.enqueue_message(peer, &resp); + let msg: Message<::CustomMessage> = + Message::Init(resp); + self.enqueue_message(peer, msg); }, NextNoiseStep::ActThree => { let res = peer @@ -1912,7 +1933,9 @@ where peer.their_socket_address.clone(), ), }; - self.enqueue_message(peer, &resp); + let msg: Message<::CustomMessage> = + Message::Init(resp); + self.enqueue_message(peer, msg); }, NextNoiseStep::NoiseComplete => { if peer.pending_read_is_header { @@ -1973,7 +1996,8 @@ where let data = "Unsupported message compression: zlib" .to_owned(); let msg = msgs::WarningMessage { channel_id, data }; - self.enqueue_message(peer, &msg); + let msg: Message<::CustomMessage> = Message::Warning(msg); + self.enqueue_message(peer, msg); continue; }, (_, Some(ty)) if is_gossip_msg(ty) => { @@ -1984,7 +2008,8 @@ where ty ); let msg = msgs::WarningMessage { channel_id, data }; - self.enqueue_message(peer, &msg); + let msg: Message<::CustomMessage> = Message::Warning(msg); + self.enqueue_message(peer, msg); continue; }, (msgs::DecodeError::UnknownRequiredFeature, _) => { @@ -2060,9 +2085,7 @@ where /// Returns the message back if it needs to be broadcasted to all other peers. fn handle_message( &self, peer_mutex: &Mutex, peer_lock: MutexGuard, - message: wire::Message< - <::Target as wire::CustomMessageReader>::CustomMessage, - >, + message: Message<<::Target as wire::CustomMessageReader>::CustomMessage>, ) -> Result, MessageHandlingError> { let their_node_id = peer_lock .their_node_id @@ -2103,9 +2126,7 @@ where // allow it to be subsequently processed by `do_handle_message_without_peer_lock`. fn do_handle_message_holding_peer_lock<'a>( &self, mut peer_lock: MutexGuard, - message: wire::Message< - <::Target as wire::CustomMessageReader>::CustomMessage, - >, + message: Message<<::Target as wire::CustomMessageReader>::CustomMessage>, their_node_id: PublicKey, logger: &WithContext<'a, L>, ) -> Result< Option< @@ -2116,7 +2137,7 @@ where peer_lock.received_message_since_timer_tick = true; // Need an Init as first message - if let wire::Message::Init(msg) = message { + if let Message::Init(msg) = message { // Check if we have any compatible chains if the `networks` field is specified. if let Some(networks) = &msg.networks { let chan_handler = &self.message_handler.chan_handler; @@ -2225,7 +2246,7 @@ where // During splicing, commitment_signed messages need to be collected into a single batch // before they are handled. - if let wire::Message::StartBatch(msg) = message { + if let Message::StartBatch(msg) = message { if peer_lock.message_batch.is_some() { let error = format!( "Peer {} sent start_batch for channel {} before previous batch completed", @@ -2296,7 +2317,7 @@ where return Ok(None); } - if let wire::Message::CommitmentSigned(msg) = message { + if let Message::CommitmentSigned(msg) = message { if let Some(message_batch) = &mut peer_lock.message_batch { let MessageBatchImpl::CommitmentSigned(ref mut messages) = &mut message_batch.messages; @@ -2325,7 +2346,7 @@ where return Ok(None); } } else { - return Ok(Some(LogicalMessage::FromWire(wire::Message::CommitmentSigned(msg)))); + return Ok(Some(LogicalMessage::FromWire(Message::CommitmentSigned(msg)))); } } else if let Some(message_batch) = &peer_lock.message_batch { match message_batch.messages { @@ -2341,7 +2362,7 @@ where return Err(PeerHandleError {}.into()); } - if let wire::Message::GossipTimestampFilter(_msg) = message { + if let Message::GossipTimestampFilter(_msg) = message { // When supporting gossip messages, start initial gossip sync only after we receive // a GossipTimestampFilter if peer_lock.their_features.as_ref().unwrap().supports_gossip_queries() @@ -2373,7 +2394,7 @@ where return Ok(None); } - if let wire::Message::ChannelAnnouncement(ref _msg) = message { + if let Message::ChannelAnnouncement(ref _msg) = message { peer_lock.received_channel_announce_since_backlogged = true; } @@ -2385,9 +2406,7 @@ where // Returns the message back if it needs to be broadcasted to all other peers. fn do_handle_message_without_peer_lock<'a>( &self, peer_mutex: &Mutex, - message: wire::Message< - <::Target as wire::CustomMessageReader>::CustomMessage, - >, + message: Message<<::Target as wire::CustomMessageReader>::CustomMessage>, their_node_id: PublicKey, logger: &WithContext<'a, L>, ) -> Result, MessageHandlingError> { if is_gossip_msg(message.type_id()) { @@ -2400,13 +2419,13 @@ where match message { // Setup and Control messages: - wire::Message::Init(_) => { + Message::Init(_) => { // Handled above }, - wire::Message::GossipTimestampFilter(_) => { + Message::GossipTimestampFilter(_) => { // Handled above }, - wire::Message::Error(msg) => { + Message::Error(msg) => { log_debug!( logger, "Got Err message from {}: {}", @@ -2418,149 +2437,151 @@ where return Err(PeerHandleError {}.into()); } }, - wire::Message::Warning(msg) => { + Message::Warning(msg) => { log_debug!(logger, "Got warning message: {}", PrintableString(&msg.data)); }, - wire::Message::Ping(msg) => { + Message::Ping(msg) => { if msg.ponglen < 65532 { let resp = msgs::Pong { byteslen: msg.ponglen }; - self.enqueue_message(&mut *peer_mutex.lock().unwrap(), &resp); + let msg: Message<::CustomMessage> = + Message::Pong(resp); + self.enqueue_message(&mut *peer_mutex.lock().unwrap(), msg); } }, - wire::Message::Pong(_msg) => { + Message::Pong(_msg) => { let mut peer_lock = peer_mutex.lock().unwrap(); peer_lock.awaiting_pong_timer_tick_intervals = 0; peer_lock.msgs_sent_since_pong = 0; }, // Channel messages: - wire::Message::StartBatch(_msg) => { + Message::StartBatch(_msg) => { debug_assert!(false); }, - wire::Message::OpenChannel(msg) => { + Message::OpenChannel(msg) => { self.message_handler.chan_handler.handle_open_channel(their_node_id, &msg); }, - wire::Message::OpenChannelV2(_msg) => { + Message::OpenChannelV2(_msg) => { self.message_handler.chan_handler.handle_open_channel_v2(their_node_id, &_msg); }, - wire::Message::AcceptChannel(msg) => { + Message::AcceptChannel(msg) => { self.message_handler.chan_handler.handle_accept_channel(their_node_id, &msg); }, - wire::Message::AcceptChannelV2(msg) => { + Message::AcceptChannelV2(msg) => { self.message_handler.chan_handler.handle_accept_channel_v2(their_node_id, &msg); }, - wire::Message::FundingCreated(msg) => { + Message::FundingCreated(msg) => { self.message_handler.chan_handler.handle_funding_created(their_node_id, &msg); }, - wire::Message::FundingSigned(msg) => { + Message::FundingSigned(msg) => { self.message_handler.chan_handler.handle_funding_signed(their_node_id, &msg); }, - wire::Message::ChannelReady(msg) => { + Message::ChannelReady(msg) => { self.message_handler.chan_handler.handle_channel_ready(their_node_id, &msg); }, - wire::Message::PeerStorage(msg) => { + Message::PeerStorage(msg) => { self.message_handler.chan_handler.handle_peer_storage(their_node_id, msg); }, - wire::Message::PeerStorageRetrieval(msg) => { + Message::PeerStorageRetrieval(msg) => { self.message_handler.chan_handler.handle_peer_storage_retrieval(their_node_id, msg); }, // Quiescence messages: - wire::Message::Stfu(msg) => { + Message::Stfu(msg) => { self.message_handler.chan_handler.handle_stfu(their_node_id, &msg); }, // Splicing messages: - wire::Message::SpliceInit(msg) => { + Message::SpliceInit(msg) => { self.message_handler.chan_handler.handle_splice_init(their_node_id, &msg); }, - wire::Message::SpliceAck(msg) => { + Message::SpliceAck(msg) => { self.message_handler.chan_handler.handle_splice_ack(their_node_id, &msg); }, - wire::Message::SpliceLocked(msg) => { + Message::SpliceLocked(msg) => { self.message_handler.chan_handler.handle_splice_locked(their_node_id, &msg); }, // Interactive transaction construction messages: - wire::Message::TxAddInput(msg) => { + Message::TxAddInput(msg) => { self.message_handler.chan_handler.handle_tx_add_input(their_node_id, &msg); }, - wire::Message::TxAddOutput(msg) => { + Message::TxAddOutput(msg) => { self.message_handler.chan_handler.handle_tx_add_output(their_node_id, &msg); }, - wire::Message::TxRemoveInput(msg) => { + Message::TxRemoveInput(msg) => { self.message_handler.chan_handler.handle_tx_remove_input(their_node_id, &msg); }, - wire::Message::TxRemoveOutput(msg) => { + Message::TxRemoveOutput(msg) => { self.message_handler.chan_handler.handle_tx_remove_output(their_node_id, &msg); }, - wire::Message::TxComplete(msg) => { + Message::TxComplete(msg) => { self.message_handler.chan_handler.handle_tx_complete(their_node_id, &msg); }, - wire::Message::TxSignatures(msg) => { + Message::TxSignatures(msg) => { self.message_handler.chan_handler.handle_tx_signatures(their_node_id, &msg); }, - wire::Message::TxInitRbf(msg) => { + Message::TxInitRbf(msg) => { self.message_handler.chan_handler.handle_tx_init_rbf(their_node_id, &msg); }, - wire::Message::TxAckRbf(msg) => { + Message::TxAckRbf(msg) => { self.message_handler.chan_handler.handle_tx_ack_rbf(their_node_id, &msg); }, - wire::Message::TxAbort(msg) => { + Message::TxAbort(msg) => { self.message_handler.chan_handler.handle_tx_abort(their_node_id, &msg); }, - wire::Message::Shutdown(msg) => { + Message::Shutdown(msg) => { self.message_handler.chan_handler.handle_shutdown(their_node_id, &msg); }, - wire::Message::ClosingSigned(msg) => { + Message::ClosingSigned(msg) => { self.message_handler.chan_handler.handle_closing_signed(their_node_id, &msg); }, #[cfg(simple_close)] - wire::Message::ClosingComplete(msg) => { + Message::ClosingComplete(msg) => { self.message_handler.chan_handler.handle_closing_complete(their_node_id, msg); }, #[cfg(simple_close)] - wire::Message::ClosingSig(msg) => { + Message::ClosingSig(msg) => { self.message_handler.chan_handler.handle_closing_sig(their_node_id, msg); }, // Commitment messages: - wire::Message::UpdateAddHTLC(msg) => { + Message::UpdateAddHTLC(msg) => { self.message_handler.chan_handler.handle_update_add_htlc(their_node_id, &msg); }, - wire::Message::UpdateFulfillHTLC(msg) => { + Message::UpdateFulfillHTLC(msg) => { self.message_handler.chan_handler.handle_update_fulfill_htlc(their_node_id, msg); }, - wire::Message::UpdateFailHTLC(msg) => { + Message::UpdateFailHTLC(msg) => { self.message_handler.chan_handler.handle_update_fail_htlc(their_node_id, &msg); }, - wire::Message::UpdateFailMalformedHTLC(msg) => { + Message::UpdateFailMalformedHTLC(msg) => { let chan_handler = &self.message_handler.chan_handler; chan_handler.handle_update_fail_malformed_htlc(their_node_id, &msg); }, - wire::Message::CommitmentSigned(msg) => { + Message::CommitmentSigned(msg) => { self.message_handler.chan_handler.handle_commitment_signed(their_node_id, &msg); }, - wire::Message::RevokeAndACK(msg) => { + Message::RevokeAndACK(msg) => { self.message_handler.chan_handler.handle_revoke_and_ack(their_node_id, &msg); }, - wire::Message::UpdateFee(msg) => { + Message::UpdateFee(msg) => { self.message_handler.chan_handler.handle_update_fee(their_node_id, &msg); }, - wire::Message::ChannelReestablish(msg) => { + Message::ChannelReestablish(msg) => { self.message_handler.chan_handler.handle_channel_reestablish(their_node_id, &msg); }, // Routing messages: - wire::Message::AnnouncementSignatures(msg) => { + Message::AnnouncementSignatures(msg) => { let chan_handler = &self.message_handler.chan_handler; chan_handler.handle_announcement_signatures(their_node_id, &msg); }, - wire::Message::ChannelAnnouncement(msg) => { + Message::ChannelAnnouncement(msg) => { let route_handler = &self.message_handler.route_handler; if route_handler .handle_channel_announcement(Some(their_node_id), &msg) @@ -2570,7 +2591,7 @@ where } self.update_gossip_backlogged(); }, - wire::Message::NodeAnnouncement(msg) => { + Message::NodeAnnouncement(msg) => { let route_handler = &self.message_handler.route_handler; if route_handler .handle_node_announcement(Some(their_node_id), &msg) @@ -2580,7 +2601,7 @@ where } self.update_gossip_backlogged(); }, - wire::Message::ChannelUpdate(msg) => { + Message::ChannelUpdate(msg) => { let chan_handler = &self.message_handler.chan_handler; chan_handler.handle_channel_update(their_node_id, &msg); @@ -2594,31 +2615,31 @@ where } self.update_gossip_backlogged(); }, - wire::Message::QueryShortChannelIds(msg) => { + Message::QueryShortChannelIds(msg) => { let route_handler = &self.message_handler.route_handler; route_handler.handle_query_short_channel_ids(their_node_id, msg)?; }, - wire::Message::ReplyShortChannelIdsEnd(msg) => { + Message::ReplyShortChannelIdsEnd(msg) => { let route_handler = &self.message_handler.route_handler; route_handler.handle_reply_short_channel_ids_end(their_node_id, msg)?; }, - wire::Message::QueryChannelRange(msg) => { + Message::QueryChannelRange(msg) => { let route_handler = &self.message_handler.route_handler; route_handler.handle_query_channel_range(their_node_id, msg)?; }, - wire::Message::ReplyChannelRange(msg) => { + Message::ReplyChannelRange(msg) => { let route_handler = &self.message_handler.route_handler; route_handler.handle_reply_channel_range(their_node_id, msg)?; }, // Onion message: - wire::Message::OnionMessage(msg) => { + Message::OnionMessage(msg) => { let onion_message_handler = &self.message_handler.onion_message_handler; onion_message_handler.handle_onion_message(their_node_id, &msg); }, // Unknown messages: - wire::Message::Unknown(type_id) if message.is_even() => { + Message::Unknown(type_id) if message.is_even() => { log_debug!( logger, "Received unknown even message of type {}, disconnecting peer!", @@ -2626,10 +2647,10 @@ where ); return Err(PeerHandleError {}.into()); }, - wire::Message::Unknown(type_id) => { + Message::Unknown(type_id) => { log_trace!(logger, "Received unknown odd message of type {}, ignoring", type_id); }, - wire::Message::Custom(custom) => { + Message::Custom(custom) => { let custom_message_handler = &self.message_handler.custom_message_handler; custom_message_handler.handle_custom_message(custom, their_node_id)?; }, @@ -2858,68 +2879,86 @@ where // robustly gossip broadcast events even if a peer's message buffer is full. let mut handle_event = |event, from_chan_handler| { match event { - MessageSendEvent::SendPeerStorage { ref node_id, ref msg } => { + MessageSendEvent::SendPeerStorage { ref node_id, msg } => { log_debug!( WithContext::from(&self.logger, Some(*node_id), None, None), "Handling SendPeerStorage event in peer_handler for {}", node_id, ); + let msg: Message<::CustomMessage> = + Message::PeerStorage(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendPeerStorageRetrieval { ref node_id, ref msg } => { + MessageSendEvent::SendPeerStorageRetrieval { ref node_id, msg } => { log_debug!( WithContext::from(&self.logger, Some(*node_id), None, None), "Handling SendPeerStorageRetrieval event in peer_handler for {}", node_id, ); + let msg: Message<::CustomMessage> = + Message::PeerStorageRetrieval(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendAcceptChannel { ref node_id, ref msg } => { + MessageSendEvent::SendAcceptChannel { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.common_fields.temporary_channel_id), None), "Handling SendAcceptChannel event in peer_handler for node {} for channel {}", node_id, &msg.common_fields.temporary_channel_id); + let msg: Message<::CustomMessage> = + Message::AcceptChannel(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendAcceptChannelV2 { ref node_id, ref msg } => { + MessageSendEvent::SendAcceptChannelV2 { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.common_fields.temporary_channel_id), None), "Handling SendAcceptChannelV2 event in peer_handler for node {} for channel {}", node_id, &msg.common_fields.temporary_channel_id); + let msg: Message<::CustomMessage> = + Message::AcceptChannelV2(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendOpenChannel { ref node_id, ref msg } => { + MessageSendEvent::SendOpenChannel { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.common_fields.temporary_channel_id), None), "Handling SendOpenChannel event in peer_handler for node {} for channel {}", node_id, &msg.common_fields.temporary_channel_id); + let msg: Message<::CustomMessage> = + Message::OpenChannel(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendOpenChannelV2 { ref node_id, ref msg } => { + MessageSendEvent::SendOpenChannelV2 { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.common_fields.temporary_channel_id), None), "Handling SendOpenChannelV2 event in peer_handler for node {} for channel {}", node_id, &msg.common_fields.temporary_channel_id); + let msg: Message<::CustomMessage> = + Message::OpenChannelV2(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendFundingCreated { ref node_id, ref msg } => { + MessageSendEvent::SendFundingCreated { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.temporary_channel_id), None), "Handling SendFundingCreated event in peer_handler for node {} for channel {} (which becomes {})", node_id, &msg.temporary_channel_id, ChannelId::v1_from_funding_txid(msg.funding_txid.as_byte_array(), msg.funding_output_index)); // TODO: If the peer is gone we should generate a DiscardFunding event // indicating to the wallet that they should just throw away this funding transaction + let msg: Message<::CustomMessage> = + Message::FundingCreated(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendFundingSigned { ref node_id, ref msg } => { + MessageSendEvent::SendFundingSigned { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendFundingSigned event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); + let msg: Message<::CustomMessage> = + Message::FundingSigned(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendChannelReady { ref node_id, ref msg } => { + MessageSendEvent::SendChannelReady { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendChannelReady event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); + let msg: Message<::CustomMessage> = + Message::ChannelReady(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendStfu { ref node_id, ref msg } => { + MessageSendEvent::SendStfu { ref node_id, msg } => { let logger = WithContext::from( &self.logger, Some(*node_id), @@ -2929,9 +2968,11 @@ where log_debug!(logger, "Handling SendStfu event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); + let msg: Message<::CustomMessage> = + Message::Stfu(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendSpliceInit { ref node_id, ref msg } => { + MessageSendEvent::SendSpliceInit { ref node_id, msg } => { let logger = WithContext::from( &self.logger, Some(*node_id), @@ -2941,9 +2982,11 @@ where log_debug!(logger, "Handling SendSpliceInit event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); + let msg: Message<::CustomMessage> = + Message::SpliceInit(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendSpliceAck { ref node_id, ref msg } => { + MessageSendEvent::SendSpliceAck { ref node_id, msg } => { let logger = WithContext::from( &self.logger, Some(*node_id), @@ -2953,9 +2996,11 @@ where log_debug!(logger, "Handling SendSpliceAck event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); + let msg: Message<::CustomMessage> = + Message::SpliceAck(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendSpliceLocked { ref node_id, ref msg } => { + MessageSendEvent::SendSpliceLocked { ref node_id, msg } => { let logger = WithContext::from( &self.logger, Some(*node_id), @@ -2965,66 +3010,88 @@ where log_debug!(logger, "Handling SendSpliceLocked event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); + let msg: Message<::CustomMessage> = + Message::SpliceLocked(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendTxAddInput { ref node_id, ref msg } => { + MessageSendEvent::SendTxAddInput { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendTxAddInput event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); + let msg: Message<::CustomMessage> = + Message::TxAddInput(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendTxAddOutput { ref node_id, ref msg } => { + MessageSendEvent::SendTxAddOutput { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendTxAddOutput event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); + let msg: Message<::CustomMessage> = + Message::TxAddOutput(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendTxRemoveInput { ref node_id, ref msg } => { + MessageSendEvent::SendTxRemoveInput { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendTxRemoveInput event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); + let msg: Message<::CustomMessage> = + Message::TxRemoveInput(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendTxRemoveOutput { ref node_id, ref msg } => { + MessageSendEvent::SendTxRemoveOutput { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendTxRemoveOutput event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); + let msg: Message<::CustomMessage> = + Message::TxRemoveOutput(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendTxComplete { ref node_id, ref msg } => { + MessageSendEvent::SendTxComplete { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendTxComplete event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); + let msg: Message<::CustomMessage> = + Message::TxComplete(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendTxSignatures { ref node_id, ref msg } => { + MessageSendEvent::SendTxSignatures { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendTxSignatures event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); + let msg: Message<::CustomMessage> = + Message::TxSignatures(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendTxInitRbf { ref node_id, ref msg } => { + MessageSendEvent::SendTxInitRbf { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendTxInitRbf event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); + let msg: Message<::CustomMessage> = + Message::TxInitRbf(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendTxAckRbf { ref node_id, ref msg } => { + MessageSendEvent::SendTxAckRbf { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendTxAckRbf event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); + let msg: Message<::CustomMessage> = + Message::TxAckRbf(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendTxAbort { ref node_id, ref msg } => { + MessageSendEvent::SendTxAbort { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendTxAbort event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); + let msg: Message<::CustomMessage> = + Message::TxAbort(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendAnnouncementSignatures { ref node_id, ref msg } => { + MessageSendEvent::SendAnnouncementSignatures { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendAnnouncementSignatures event in peer_handler for node {} for channel {})", node_id, &msg.channel_id); + let msg: Message<::CustomMessage> = + Message::AnnouncementSignatures(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::UpdateHTLCs { @@ -3032,12 +3099,12 @@ where ref channel_id, updates: msgs::CommitmentUpdate { - ref update_add_htlcs, - ref update_fulfill_htlcs, - ref update_fail_htlcs, - ref update_fail_malformed_htlcs, - ref update_fee, - ref commitment_signed, + update_add_htlcs, + update_fulfill_htlcs, + update_fail_htlcs, + update_fail_malformed_htlcs, + update_fee, + commitment_signed, }, } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(*channel_id), None), "Handling UpdateHTLCs event in peer_handler for node {} with {} adds, {} fulfills, {} fails, {} commits for channel {}", @@ -3049,18 +3116,33 @@ where channel_id); let mut peer = get_peer_for_forwarding!(node_id)?; for msg in update_fulfill_htlcs { + let msg: Message< + ::CustomMessage, + > = Message::UpdateFulfillHTLC(msg); self.enqueue_message(&mut *peer, msg); } for msg in update_fail_htlcs { + let msg: Message< + ::CustomMessage, + > = Message::UpdateFailHTLC(msg); self.enqueue_message(&mut *peer, msg); } for msg in update_fail_malformed_htlcs { + let msg: Message< + ::CustomMessage, + > = Message::UpdateFailMalformedHTLC(msg); self.enqueue_message(&mut *peer, msg); } for msg in update_add_htlcs { + let msg: Message< + ::CustomMessage, + > = Message::UpdateAddHTLC(msg); self.enqueue_message(&mut *peer, msg); } - if let &Some(ref msg) = update_fee { + if let Some(msg) = update_fee { + let msg: Message< + ::CustomMessage, + > = Message::UpdateFee(msg); self.enqueue_message(&mut *peer, msg); } if commitment_signed.len() > 1 { @@ -3069,37 +3151,53 @@ where batch_size: commitment_signed.len() as u16, message_type: Some(msgs::CommitmentSigned::TYPE), }; - self.enqueue_message(&mut *peer, &msg); + let msg: Message< + ::CustomMessage, + > = Message::StartBatch(msg); + self.enqueue_message(&mut *peer, msg); } for msg in commitment_signed { + let msg: Message< + ::CustomMessage, + > = Message::CommitmentSigned(msg); self.enqueue_message(&mut *peer, msg); } }, - MessageSendEvent::SendRevokeAndACK { ref node_id, ref msg } => { + MessageSendEvent::SendRevokeAndACK { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendRevokeAndACK event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); + let msg: Message<::CustomMessage> = + Message::RevokeAndACK(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendClosingSigned { ref node_id, ref msg } => { + MessageSendEvent::SendClosingSigned { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendClosingSigned event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); + let msg: Message<::CustomMessage> = + Message::ClosingSigned(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendClosingComplete { ref node_id, ref msg } => { + #[cfg(simple_close)] + MessageSendEvent::SendClosingComplete { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendClosingComplete event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); + let msg: Message<::CustomMessage> = + Message::ClosingComplete(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendClosingSig { ref node_id, ref msg } => { + #[cfg(simple_close)] + MessageSendEvent::SendClosingSig { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendClosingSig event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); + let msg: Message<::CustomMessage> = + Message::ClosingSig(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendShutdown { ref node_id, ref msg } => { + MessageSendEvent::SendShutdown { ref node_id, msg } => { log_debug!( WithContext::from( &self.logger, @@ -3109,23 +3207,32 @@ where ), "Handling Shutdown event in peer_handler", ); + let msg: Message<::CustomMessage> = + Message::Shutdown(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendChannelReestablish { ref node_id, ref msg } => { + MessageSendEvent::SendChannelReestablish { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendChannelReestablish event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); + let msg: Message<::CustomMessage> = + Message::ChannelReestablish(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendChannelAnnouncement { ref node_id, - ref msg, - ref update_msg, + msg, + update_msg, } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), None, None), "Handling SendChannelAnnouncement event in peer_handler for node {} for short channel id {}", node_id, msg.contents.short_channel_id); + let msg: Message<::CustomMessage> = + Message::ChannelAnnouncement(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); + let update_msg: Message< + ::CustomMessage, + > = Message::ChannelUpdate(update_msg); self.enqueue_message( &mut *get_peer_for_forwarding!(node_id)?, update_msg, @@ -3216,12 +3323,14 @@ where _ => {}, } }, - MessageSendEvent::SendChannelUpdate { ref node_id, ref msg } => { + MessageSendEvent::SendChannelUpdate { ref node_id, msg } => { log_trace!( WithContext::from(&self.logger, Some(*node_id), None, None), "Handling SendChannelUpdate event in peer_handler for channel {}", msg.contents.short_channel_id ); + let msg: Message<::CustomMessage> = + Message::ChannelUpdate(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::HandleError { node_id, action } => { @@ -3239,7 +3348,7 @@ where // about to disconnect the peer and do it after we finish // processing most messages. let msg = msg.map(|msg| { - wire::Message::<<::Target as wire::CustomMessageReader>::CustomMessage>::Error(msg) + Message::<<::Target as wire::CustomMessageReader>::CustomMessage>::Error(msg) }); peers_to_disconnect.insert(node_id, msg); }, @@ -3250,7 +3359,7 @@ where // about to disconnect the peer and do it after we finish // processing most messages. peers_to_disconnect - .insert(node_id, Some(wire::Message::Warning(msg))); + .insert(node_id, Some(Message::Warning(msg))); }, msgs::ErrorAction::IgnoreAndLog(level) => { log_given_level!( @@ -3266,22 +3375,25 @@ where "Received a HandleError event to be ignored", ); }, - msgs::ErrorAction::SendErrorMessage { ref msg } => { + msgs::ErrorAction::SendErrorMessage { msg } => { log_trace!(logger, "Handling SendErrorMessage HandleError event in peer_handler with message {}", msg.data); + let msg: Message< + ::CustomMessage, + > = Message::Error(msg); self.enqueue_message( &mut *get_peer_for_forwarding!(&node_id)?, msg, ); }, - msgs::ErrorAction::SendWarningMessage { - ref msg, - ref log_level, - } => { + msgs::ErrorAction::SendWarningMessage { msg, ref log_level } => { log_given_level!(logger, *log_level, "Handling SendWarningMessage HandleError event in peer_handler with message {}", msg.data); + let msg: Message< + ::CustomMessage, + > = Message::Warning(msg); self.enqueue_message( &mut *get_peer_for_forwarding!(&node_id)?, msg, @@ -3289,33 +3401,41 @@ where }, } }, - MessageSendEvent::SendChannelRangeQuery { ref node_id, ref msg } => { + MessageSendEvent::SendChannelRangeQuery { ref node_id, msg } => { log_gossip!(WithContext::from(&self.logger, Some(*node_id), None, None), "Handling SendChannelRangeQuery event in peer_handler with first_blocknum={}, number_of_blocks={}", msg.first_blocknum, msg.number_of_blocks); + let msg: Message<::CustomMessage> = + Message::QueryChannelRange(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendShortIdsQuery { ref node_id, ref msg } => { + MessageSendEvent::SendShortIdsQuery { ref node_id, msg } => { log_gossip!(WithContext::from(&self.logger, Some(*node_id), None, None), "Handling SendShortIdsQuery event in peer_handler with num_scids={}", msg.short_channel_ids.len()); + let msg: Message<::CustomMessage> = + Message::QueryShortChannelIds(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendReplyChannelRange { ref node_id, ref msg } => { + MessageSendEvent::SendReplyChannelRange { ref node_id, msg } => { log_gossip!(WithContext::from(&self.logger, Some(*node_id), None, None), "Handling SendReplyChannelRange event in peer_handler with num_scids={} first_blocknum={} number_of_blocks={}, sync_complete={}", msg.short_channel_ids.len(), msg.first_blocknum, msg.number_of_blocks, msg.sync_complete); + let msg: Message<::CustomMessage> = + Message::ReplyChannelRange(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, - MessageSendEvent::SendGossipTimestampFilter { ref node_id, ref msg } => { + MessageSendEvent::SendGossipTimestampFilter { ref node_id, msg } => { log_gossip!(WithContext::from(&self.logger, Some(*node_id), None, None), "Handling SendGossipTimestampFilter event in peer_handler with first_timestamp={}, timestamp_range={}", msg.first_timestamp, msg.timestamp_range); + let msg: Message<::CustomMessage> = + Message::GossipTimestampFilter(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, } @@ -3351,7 +3471,9 @@ where } else { continue; }; - self.enqueue_message(&mut peer, &msg); + let msg: Message<::CustomMessage> = + Message::Custom(msg); + self.enqueue_message(&mut peer, msg); } for (descriptor, peer_mutex) in peers.iter() { @@ -3381,7 +3503,7 @@ where if let Some(peer_mutex) = peers.remove(&descriptor) { let mut peer = peer_mutex.lock().unwrap(); if let Some(msg) = msg { - self.enqueue_message(&mut *peer, &msg); + self.enqueue_message(&mut *peer, msg); // This isn't guaranteed to work, but if there is enough free // room in the send buffer, put the error message there... self.do_attempt_write_data(&mut descriptor, &mut *peer, false); @@ -3506,7 +3628,9 @@ where if peer.awaiting_pong_timer_tick_intervals == 0 { peer.awaiting_pong_timer_tick_intervals = -1; let ping = msgs::Ping { ponglen: 0, byteslen: 64 }; - self.enqueue_message(peer, &ping); + let msg: Message<::CustomMessage> = + Message::Ping(ping); + self.enqueue_message(peer, msg); } } @@ -3577,7 +3701,9 @@ where peer.awaiting_pong_timer_tick_intervals = 1; let ping = msgs::Ping { ponglen: 0, byteslen: 64 }; - self.enqueue_message(&mut *peer, &ping); + let msg: Message<::CustomMessage> = + Message::Ping(ping); + self.enqueue_message(&mut *peer, msg); break; } self.do_attempt_write_data( @@ -4226,7 +4352,7 @@ mod tests { .push(MessageSendEvent::SendShutdown { node_id: their_id, msg: msg.clone() }); peers[0].message_handler.chan_handler = &a_chan_handler; - b_chan_handler.expect_receive_msg(wire::Message::Shutdown(msg)); + b_chan_handler.expect_receive_msg(Message::Shutdown(msg)); peers[1].message_handler.chan_handler = &b_chan_handler; peers[0].process_events(); @@ -4261,7 +4387,8 @@ mod tests { peers[0].read_event(&mut fd_dup, &act_three).unwrap(); let not_init_msg = msgs::Ping { ponglen: 4, byteslen: 0 }; - let msg_bytes = dup_encryptor.encrypt_message(¬_init_msg); + let msg: Message<()> = Message::Ping(not_init_msg); + let msg_bytes = dup_encryptor.encrypt_message(msg); assert!(peers[0].read_event(&mut fd_dup, &msg_bytes).is_err()); } @@ -4639,13 +4766,12 @@ mod tests { { let peers = peer_a.peers.read().unwrap(); let mut peer_b = peers.get(&fd_a).unwrap().lock().unwrap(); - peer_a.enqueue_message( - &mut peer_b, - &msgs::WarningMessage { - channel_id: ChannelId([0; 32]), - data: "no disconnect plz".to_string(), - }, - ); + let warning = msgs::WarningMessage { + channel_id: ChannelId([0; 32]), + data: "no disconnect plz".to_string(), + }; + let msg: Message<()> = Message::Warning(warning); + peer_a.enqueue_message(&mut peer_b, msg); } peer_a.process_events(); let msg = fd_a.outbound_data.lock().unwrap().split_off(0); From 2c7f1cfea145f1629a7d54cf9755755b04852092 Mon Sep 17 00:00:00 2001 From: Elias Rohrer Date: Fri, 5 Dec 2025 14:44:20 +0100 Subject: [PATCH 2/3] f Concretize `enqueue_message` --- lightning/src/ln/peer_handler.rs | 203 +++++++++++-------------------- 1 file changed, 71 insertions(+), 132 deletions(-) diff --git a/lightning/src/ln/peer_handler.rs b/lightning/src/ln/peer_handler.rs index 7967208cbe4..8a6c6a786b1 100644 --- a/lightning/src/ln/peer_handler.rs +++ b/lightning/src/ln/peer_handler.rs @@ -1574,8 +1574,7 @@ where if let Some(next_onion_message) = handler.next_onion_message_for_peer(peer_node_id) { - let msg: Message<::CustomMessage> = - Message::OnionMessage(next_onion_message); + let msg = Message::OnionMessage(next_onion_message); self.enqueue_message(peer, msg); } } @@ -1597,20 +1596,15 @@ where peer.sync_status = InitSyncTracker::ChannelsSyncing( announce.contents.short_channel_id + 1, ); - let msg: Message<::CustomMessage> = - Message::ChannelAnnouncement(announce); + let msg = Message::ChannelAnnouncement(announce); self.enqueue_message(peer, msg); if let Some(update_a) = update_a_option { - let msg: Message< - ::CustomMessage, - > = Message::ChannelUpdate(update_a); + let msg = Message::ChannelUpdate(update_a); self.enqueue_message(peer, msg); } if let Some(update_b) = update_b_option { - let msg: Message< - ::CustomMessage, - > = Message::ChannelUpdate(update_b); + let msg = Message::ChannelUpdate(update_b); self.enqueue_message(peer, msg); } } else { @@ -1622,8 +1616,7 @@ where let handler = &self.message_handler.route_handler; if let Some(msg) = handler.get_next_node_announcement(None) { peer.sync_status = InitSyncTracker::NodesSyncing(msg.contents.node_id); - let msg: Message<::CustomMessage> = - Message::NodeAnnouncement(msg); + let msg = Message::NodeAnnouncement(msg); self.enqueue_message(peer, msg); } else { peer.sync_status = InitSyncTracker::NoSyncRequested; @@ -1634,8 +1627,7 @@ where let handler = &self.message_handler.route_handler; if let Some(msg) = handler.get_next_node_announcement(Some(&sync_node_id)) { peer.sync_status = InitSyncTracker::NodesSyncing(msg.contents.node_id); - let msg: Message<::CustomMessage> = - Message::NodeAnnouncement(msg); + let msg = Message::NodeAnnouncement(msg); self.enqueue_message(peer, msg); } else { peer.sync_status = InitSyncTracker::NoSyncRequested; @@ -1744,7 +1736,10 @@ where } /// Append a message to a peer's pending outbound/write buffer - fn enqueue_message(&self, peer: &mut Peer, message: Message) { + fn enqueue_message( + &self, peer: &mut Peer, + message: Message<::CustomMessage>, + ) { let their_node_id = peer.their_node_id.map(|p| p.0); if their_node_id.is_some() { let logger = WithContext::from(&self.logger, their_node_id, None, None); @@ -1809,13 +1804,13 @@ where }, msgs::ErrorAction::SendErrorMessage { msg } => { log_debug!(logger, "Error handling message{}; sending error message with: {}", OptionalFromDebugger(&peer_node_id), e.err); - let msg: Message<::CustomMessage> = Message::Error(msg); + let msg = Message::Error(msg); self.enqueue_message($peer, msg); continue; }, msgs::ErrorAction::SendWarningMessage { msg, log_level } => { log_given_level!(logger, log_level, "Error handling message{}; sending warning message with: {}", OptionalFromDebugger(&peer_node_id), e.err); - let msg: Message<::CustomMessage> = Message::Warning(msg); + let msg = Message::Warning(msg); self.enqueue_message($peer, msg); continue; }, @@ -1911,8 +1906,7 @@ where peer.their_socket_address.clone(), ), }; - let msg: Message<::CustomMessage> = - Message::Init(resp); + let msg = Message::Init(resp); self.enqueue_message(peer, msg); }, NextNoiseStep::ActThree => { @@ -1933,8 +1927,7 @@ where peer.their_socket_address.clone(), ), }; - let msg: Message<::CustomMessage> = - Message::Init(resp); + let msg = Message::Init(resp); self.enqueue_message(peer, msg); }, NextNoiseStep::NoiseComplete => { @@ -1995,8 +1988,10 @@ where let channel_id = ChannelId::new_zero(); let data = "Unsupported message compression: zlib" .to_owned(); - let msg = msgs::WarningMessage { channel_id, data }; - let msg: Message<::CustomMessage> = Message::Warning(msg); + let msg = Message::Warning(msgs::WarningMessage { + channel_id, + data, + }); self.enqueue_message(peer, msg); continue; }, @@ -2007,8 +2002,10 @@ where "Unreadable/bogus gossip message of type {}", ty ); - let msg = msgs::WarningMessage { channel_id, data }; - let msg: Message<::CustomMessage> = Message::Warning(msg); + let msg = Message::Warning(msgs::WarningMessage { + channel_id, + data, + }); self.enqueue_message(peer, msg); continue; }, @@ -2444,8 +2441,7 @@ where Message::Ping(msg) => { if msg.ponglen < 65532 { let resp = msgs::Pong { byteslen: msg.ponglen }; - let msg: Message<::CustomMessage> = - Message::Pong(resp); + let msg = Message::Pong(resp); self.enqueue_message(&mut *peer_mutex.lock().unwrap(), msg); } }, @@ -2885,8 +2881,7 @@ where "Handling SendPeerStorage event in peer_handler for {}", node_id, ); - let msg: Message<::CustomMessage> = - Message::PeerStorage(msg); + let msg = Message::PeerStorage(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendPeerStorageRetrieval { ref node_id, msg } => { @@ -2895,40 +2890,35 @@ where "Handling SendPeerStorageRetrieval event in peer_handler for {}", node_id, ); - let msg: Message<::CustomMessage> = - Message::PeerStorageRetrieval(msg); + let msg = Message::PeerStorageRetrieval(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendAcceptChannel { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.common_fields.temporary_channel_id), None), "Handling SendAcceptChannel event in peer_handler for node {} for channel {}", node_id, &msg.common_fields.temporary_channel_id); - let msg: Message<::CustomMessage> = - Message::AcceptChannel(msg); + let msg = Message::AcceptChannel(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendAcceptChannelV2 { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.common_fields.temporary_channel_id), None), "Handling SendAcceptChannelV2 event in peer_handler for node {} for channel {}", node_id, &msg.common_fields.temporary_channel_id); - let msg: Message<::CustomMessage> = - Message::AcceptChannelV2(msg); + let msg = Message::AcceptChannelV2(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendOpenChannel { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.common_fields.temporary_channel_id), None), "Handling SendOpenChannel event in peer_handler for node {} for channel {}", node_id, &msg.common_fields.temporary_channel_id); - let msg: Message<::CustomMessage> = - Message::OpenChannel(msg); + let msg = Message::OpenChannel(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendOpenChannelV2 { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.common_fields.temporary_channel_id), None), "Handling SendOpenChannelV2 event in peer_handler for node {} for channel {}", node_id, &msg.common_fields.temporary_channel_id); - let msg: Message<::CustomMessage> = - Message::OpenChannelV2(msg); + let msg = Message::OpenChannelV2(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendFundingCreated { ref node_id, msg } => { @@ -2938,24 +2928,21 @@ where ChannelId::v1_from_funding_txid(msg.funding_txid.as_byte_array(), msg.funding_output_index)); // TODO: If the peer is gone we should generate a DiscardFunding event // indicating to the wallet that they should just throw away this funding transaction - let msg: Message<::CustomMessage> = - Message::FundingCreated(msg); + let msg = Message::FundingCreated(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendFundingSigned { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendFundingSigned event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); - let msg: Message<::CustomMessage> = - Message::FundingSigned(msg); + let msg = Message::FundingSigned(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendChannelReady { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendChannelReady event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); - let msg: Message<::CustomMessage> = - Message::ChannelReady(msg); + let msg = Message::ChannelReady(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendStfu { ref node_id, msg } => { @@ -2968,8 +2955,7 @@ where log_debug!(logger, "Handling SendStfu event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); - let msg: Message<::CustomMessage> = - Message::Stfu(msg); + let msg = Message::Stfu(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendSpliceInit { ref node_id, msg } => { @@ -2982,8 +2968,7 @@ where log_debug!(logger, "Handling SendSpliceInit event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); - let msg: Message<::CustomMessage> = - Message::SpliceInit(msg); + let msg = Message::SpliceInit(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendSpliceAck { ref node_id, msg } => { @@ -2996,8 +2981,7 @@ where log_debug!(logger, "Handling SendSpliceAck event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); - let msg: Message<::CustomMessage> = - Message::SpliceAck(msg); + let msg = Message::SpliceAck(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendSpliceLocked { ref node_id, msg } => { @@ -3010,88 +2994,77 @@ where log_debug!(logger, "Handling SendSpliceLocked event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); - let msg: Message<::CustomMessage> = - Message::SpliceLocked(msg); + let msg = Message::SpliceLocked(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendTxAddInput { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendTxAddInput event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); - let msg: Message<::CustomMessage> = - Message::TxAddInput(msg); + let msg = Message::TxAddInput(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendTxAddOutput { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendTxAddOutput event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); - let msg: Message<::CustomMessage> = - Message::TxAddOutput(msg); + let msg = Message::TxAddOutput(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendTxRemoveInput { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendTxRemoveInput event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); - let msg: Message<::CustomMessage> = - Message::TxRemoveInput(msg); + let msg = Message::TxRemoveInput(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendTxRemoveOutput { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendTxRemoveOutput event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); - let msg: Message<::CustomMessage> = - Message::TxRemoveOutput(msg); + let msg = Message::TxRemoveOutput(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendTxComplete { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendTxComplete event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); - let msg: Message<::CustomMessage> = - Message::TxComplete(msg); + let msg = Message::TxComplete(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendTxSignatures { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendTxSignatures event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); - let msg: Message<::CustomMessage> = - Message::TxSignatures(msg); + let msg = Message::TxSignatures(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendTxInitRbf { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendTxInitRbf event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); - let msg: Message<::CustomMessage> = - Message::TxInitRbf(msg); + let msg = Message::TxInitRbf(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendTxAckRbf { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendTxAckRbf event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); - let msg: Message<::CustomMessage> = - Message::TxAckRbf(msg); + let msg = Message::TxAckRbf(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendTxAbort { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendTxAbort event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); - let msg: Message<::CustomMessage> = - Message::TxAbort(msg); + let msg = Message::TxAbort(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendAnnouncementSignatures { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendAnnouncementSignatures event in peer_handler for node {} for channel {})", node_id, &msg.channel_id); - let msg: Message<::CustomMessage> = - Message::AnnouncementSignatures(msg); + let msg = Message::AnnouncementSignatures(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::UpdateHTLCs { @@ -3116,33 +3089,23 @@ where channel_id); let mut peer = get_peer_for_forwarding!(node_id)?; for msg in update_fulfill_htlcs { - let msg: Message< - ::CustomMessage, - > = Message::UpdateFulfillHTLC(msg); + let msg = Message::UpdateFulfillHTLC(msg); self.enqueue_message(&mut *peer, msg); } for msg in update_fail_htlcs { - let msg: Message< - ::CustomMessage, - > = Message::UpdateFailHTLC(msg); + let msg = Message::UpdateFailHTLC(msg); self.enqueue_message(&mut *peer, msg); } for msg in update_fail_malformed_htlcs { - let msg: Message< - ::CustomMessage, - > = Message::UpdateFailMalformedHTLC(msg); + let msg = Message::UpdateFailMalformedHTLC(msg); self.enqueue_message(&mut *peer, msg); } for msg in update_add_htlcs { - let msg: Message< - ::CustomMessage, - > = Message::UpdateAddHTLC(msg); + let msg = Message::UpdateAddHTLC(msg); self.enqueue_message(&mut *peer, msg); } if let Some(msg) = update_fee { - let msg: Message< - ::CustomMessage, - > = Message::UpdateFee(msg); + let msg = Message::UpdateFee(msg); self.enqueue_message(&mut *peer, msg); } if commitment_signed.len() > 1 { @@ -3151,15 +3114,11 @@ where batch_size: commitment_signed.len() as u16, message_type: Some(msgs::CommitmentSigned::TYPE), }; - let msg: Message< - ::CustomMessage, - > = Message::StartBatch(msg); + let msg = Message::StartBatch(msg); self.enqueue_message(&mut *peer, msg); } for msg in commitment_signed { - let msg: Message< - ::CustomMessage, - > = Message::CommitmentSigned(msg); + let msg = Message::CommitmentSigned(msg); self.enqueue_message(&mut *peer, msg); } }, @@ -3167,16 +3126,14 @@ where log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendRevokeAndACK event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); - let msg: Message<::CustomMessage> = - Message::RevokeAndACK(msg); + let msg = Message::RevokeAndACK(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendClosingSigned { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendClosingSigned event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); - let msg: Message<::CustomMessage> = - Message::ClosingSigned(msg); + let msg = Message::ClosingSigned(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, #[cfg(simple_close)] @@ -3184,8 +3141,7 @@ where log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendClosingComplete event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); - let msg: Message<::CustomMessage> = - Message::ClosingComplete(msg); + let msg = Message::ClosingComplete(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, #[cfg(simple_close)] @@ -3193,8 +3149,7 @@ where log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendClosingSig event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); - let msg: Message<::CustomMessage> = - Message::ClosingSig(msg); + let msg = Message::ClosingSig(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendShutdown { ref node_id, msg } => { @@ -3207,16 +3162,14 @@ where ), "Handling Shutdown event in peer_handler", ); - let msg: Message<::CustomMessage> = - Message::Shutdown(msg); + let msg = Message::Shutdown(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendChannelReestablish { ref node_id, msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendChannelReestablish event in peer_handler for node {} for channel {}", node_id, &msg.channel_id); - let msg: Message<::CustomMessage> = - Message::ChannelReestablish(msg); + let msg = Message::ChannelReestablish(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendChannelAnnouncement { @@ -3227,12 +3180,9 @@ where log_debug!(WithContext::from(&self.logger, Some(*node_id), None, None), "Handling SendChannelAnnouncement event in peer_handler for node {} for short channel id {}", node_id, msg.contents.short_channel_id); - let msg: Message<::CustomMessage> = - Message::ChannelAnnouncement(msg); + let msg = Message::ChannelAnnouncement(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); - let update_msg: Message< - ::CustomMessage, - > = Message::ChannelUpdate(update_msg); + let update_msg = Message::ChannelUpdate(update_msg); self.enqueue_message( &mut *get_peer_for_forwarding!(node_id)?, update_msg, @@ -3329,8 +3279,7 @@ where "Handling SendChannelUpdate event in peer_handler for channel {}", msg.contents.short_channel_id ); - let msg: Message<::CustomMessage> = - Message::ChannelUpdate(msg); + let msg = Message::ChannelUpdate(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::HandleError { node_id, action } => { @@ -3379,9 +3328,7 @@ where log_trace!(logger, "Handling SendErrorMessage HandleError event in peer_handler with message {}", msg.data); - let msg: Message< - ::CustomMessage, - > = Message::Error(msg); + let msg = Message::Error(msg); self.enqueue_message( &mut *get_peer_for_forwarding!(&node_id)?, msg, @@ -3391,9 +3338,7 @@ where log_given_level!(logger, *log_level, "Handling SendWarningMessage HandleError event in peer_handler with message {}", msg.data); - let msg: Message< - ::CustomMessage, - > = Message::Warning(msg); + let msg = Message::Warning(msg); self.enqueue_message( &mut *get_peer_for_forwarding!(&node_id)?, msg, @@ -3406,16 +3351,14 @@ where msg.first_blocknum, msg.number_of_blocks); - let msg: Message<::CustomMessage> = - Message::QueryChannelRange(msg); + let msg = Message::QueryChannelRange(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendShortIdsQuery { ref node_id, msg } => { log_gossip!(WithContext::from(&self.logger, Some(*node_id), None, None), "Handling SendShortIdsQuery event in peer_handler with num_scids={}", msg.short_channel_ids.len()); - let msg: Message<::CustomMessage> = - Message::QueryShortChannelIds(msg); + let msg = Message::QueryShortChannelIds(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendReplyChannelRange { ref node_id, msg } => { @@ -3425,8 +3368,7 @@ where msg.first_blocknum, msg.number_of_blocks, msg.sync_complete); - let msg: Message<::CustomMessage> = - Message::ReplyChannelRange(msg); + let msg = Message::ReplyChannelRange(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendGossipTimestampFilter { ref node_id, msg } => { @@ -3434,8 +3376,7 @@ where msg.first_timestamp, msg.timestamp_range); - let msg: Message<::CustomMessage> = - Message::GossipTimestampFilter(msg); + let msg = Message::GossipTimestampFilter(msg); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, } @@ -3471,8 +3412,7 @@ where } else { continue; }; - let msg: Message<::CustomMessage> = - Message::Custom(msg); + let msg = Message::Custom(msg); self.enqueue_message(&mut peer, msg); } @@ -3701,8 +3641,7 @@ where peer.awaiting_pong_timer_tick_intervals = 1; let ping = msgs::Ping { ponglen: 0, byteslen: 64 }; - let msg: Message<::CustomMessage> = - Message::Ping(ping); + let msg = Message::Ping(ping); self.enqueue_message(&mut *peer, msg); break; } @@ -4770,7 +4709,7 @@ mod tests { channel_id: ChannelId([0; 32]), data: "no disconnect plz".to_string(), }; - let msg: Message<()> = Message::Warning(warning); + let msg = Message::Warning(warning); peer_a.enqueue_message(&mut peer_b, msg); } peer_a.process_events(); From 3b186c247a085559ce8fec90df66ef6cf1402c1d Mon Sep 17 00:00:00 2001 From: Elias Rohrer Date: Mon, 1 Dec 2025 14:08:17 +0100 Subject: [PATCH 3/3] Drop `wire::write` and replace `encode_msg!` macro Now that we consistently use `wire::Message` everywhere, it's easier to simply use `Message::write`/`Type::write` instead of heaving yet another `wire::write` around. Here we drop `wire::write`, replace the `encode_msg` macro with a method that takes `wire::Message`, and convert a bunch of additional places to move semantics. --- lightning/src/ln/peer_channel_encryptor.rs | 6 +- lightning/src/ln/peer_handler.rs | 64 ++++++++++++---------- lightning/src/ln/wire.rs | 41 -------------- 3 files changed, 39 insertions(+), 72 deletions(-) diff --git a/lightning/src/ln/peer_channel_encryptor.rs b/lightning/src/ln/peer_channel_encryptor.rs index 1d34d9a8674..894de045b14 100644 --- a/lightning/src/ln/peer_channel_encryptor.rs +++ b/lightning/src/ln/peer_channel_encryptor.rs @@ -12,7 +12,9 @@ use crate::prelude::*; use crate::ln::msgs; use crate::ln::msgs::LightningError; use crate::ln::wire; +use crate::ln::wire::Type; use crate::sign::{NodeSigner, Recipient}; +use crate::util::ser::Writeable; use bitcoin::hashes::sha256::Hash as Sha256; use bitcoin::hashes::{Hash, HashEngine}; @@ -570,7 +572,9 @@ impl PeerChannelEncryptor { // for the 2-byte message type prefix and its MAC. let mut res = VecWriter(Vec::with_capacity(MSG_BUF_ALLOC_SIZE)); res.0.resize(16 + 2, 0); - wire::write(&message, &mut res).expect("In-memory messages must never fail to serialize"); + + message.type_id().write(&mut res).expect("In-memory messages must never fail to serialize"); + message.write(&mut res).expect("In-memory messages must never fail to serialize"); self.encrypt_message_with_header_0s(&mut res.0); res.0 diff --git a/lightning/src/ln/peer_handler.rs b/lightning/src/ln/peer_handler.rs index 8a6c6a786b1..4d1dff9cd52 100644 --- a/lightning/src/ln/peer_handler.rs +++ b/lightning/src/ln/peer_handler.rs @@ -1144,12 +1144,11 @@ impl From for MessageHandlingError { } } -macro_rules! encode_msg { - ($msg: expr) => {{ - let mut buffer = VecWriter(Vec::with_capacity(MSG_BUF_ALLOC_SIZE)); - wire::write($msg, &mut buffer).unwrap(); - buffer.0 - }}; +fn encode_message(message: wire::Message) -> Vec { + let mut buffer = VecWriter(Vec::with_capacity(MSG_BUF_ALLOC_SIZE)); + message.type_id().write(&mut buffer).expect("In-memory messages must never fail to serialize"); + message.write(&mut buffer).expect("In-memory messages must never fail to serialize"); + buffer.0 } impl @@ -2068,7 +2067,7 @@ where for msg in msgs_to_forward.drain(..) { self.forward_broadcast_msg( &*peers, - &msg, + msg, peer_node_id.as_ref().map(|(pk, _)| pk), false, ); @@ -2661,22 +2660,25 @@ where /// unless `allow_large_buffer` is set, in which case the message will be treated as critical /// and delivered no matter the available buffer space. fn forward_broadcast_msg( - &self, peers: &HashMap>, msg: &BroadcastGossipMessage, + &self, peers: &HashMap>, msg: BroadcastGossipMessage, except_node: Option<&PublicKey>, allow_large_buffer: bool, ) { match msg { - BroadcastGossipMessage::ChannelAnnouncement(ref msg) => { + BroadcastGossipMessage::ChannelAnnouncement(msg) => { log_gossip!(self.logger, "Sending message to all peers except {:?} or the announced channel's counterparties: {:?}", except_node, msg); - let encoded_msg = encode_msg!(msg); let our_channel = self.our_node_id == msg.contents.node_id_1 || self.our_node_id == msg.contents.node_id_2; - + let scid = msg.contents.short_channel_id; + let node_id_1 = msg.contents.node_id_1; + let node_id_2 = msg.contents.node_id_2; + let msg: Message<::CustomMessage> = + Message::ChannelAnnouncement(msg); + let encoded_msg = encode_message(msg); for (_, peer_mutex) in peers.iter() { let mut peer = peer_mutex.lock().unwrap(); if !peer.handshake_complete() { continue; } - let scid = msg.contents.short_channel_id; if !our_channel && !peer.should_forward_channel_announcement(scid) { continue; } @@ -2693,9 +2695,7 @@ where continue; } if let Some((_, their_node_id)) = peer.their_node_id { - if their_node_id == msg.contents.node_id_1 - || their_node_id == msg.contents.node_id_2 - { + if their_node_id == node_id_1 || their_node_id == node_id_2 { continue; } } @@ -2708,23 +2708,25 @@ where peer.gossip_broadcast_buffer.push_back(encoded_message); } }, - BroadcastGossipMessage::NodeAnnouncement(ref msg) => { + BroadcastGossipMessage::NodeAnnouncement(msg) => { log_gossip!( self.logger, "Sending message to all peers except {:?} or the announced node: {:?}", except_node, msg ); - let encoded_msg = encode_msg!(msg); let our_announcement = self.our_node_id == msg.contents.node_id; + let msg_node_id = msg.contents.node_id; + let msg: Message<::CustomMessage> = + Message::NodeAnnouncement(msg); + let encoded_msg = encode_message(msg); for (_, peer_mutex) in peers.iter() { let mut peer = peer_mutex.lock().unwrap(); if !peer.handshake_complete() { continue; } - let node_id = msg.contents.node_id; - if !our_announcement && !peer.should_forward_node_announcement(node_id) { + if !our_announcement && !peer.should_forward_node_announcement(msg_node_id) { continue; } debug_assert!(peer.their_node_id.is_some()); @@ -2740,7 +2742,7 @@ where continue; } if let Some((_, their_node_id)) = peer.their_node_id { - if their_node_id == msg.contents.node_id { + if their_node_id == msg_node_id { continue; } } @@ -2760,15 +2762,16 @@ where except_node, msg ); - let encoded_msg = encode_msg!(msg); - let our_channel = self.our_node_id == *node_id_1 || self.our_node_id == *node_id_2; - + let our_channel = self.our_node_id == node_id_1 || self.our_node_id == node_id_2; + let scid = msg.contents.short_channel_id; + let msg: Message<::CustomMessage> = + Message::ChannelUpdate(msg); + let encoded_msg = encode_message(msg); for (_, peer_mutex) in peers.iter() { let mut peer = peer_mutex.lock().unwrap(); if !peer.handshake_complete() { continue; } - let scid = msg.contents.short_channel_id; if !our_channel && !peer.should_forward_channel_announcement(scid) { continue; } @@ -3201,7 +3204,7 @@ where let forward = BroadcastGossipMessage::ChannelAnnouncement(msg); self.forward_broadcast_msg( peers, - &forward, + forward, None, from_chan_handler, ); @@ -3222,7 +3225,7 @@ where }; self.forward_broadcast_msg( peers, - &forward, + forward, None, from_chan_handler, ); @@ -3246,7 +3249,7 @@ where }; self.forward_broadcast_msg( peers, - &forward, + forward, None, from_chan_handler, ); @@ -3265,7 +3268,7 @@ where let forward = BroadcastGossipMessage::NodeAnnouncement(msg); self.forward_broadcast_msg( peers, - &forward, + forward, None, from_chan_handler, ); @@ -3742,7 +3745,7 @@ where let _ = self.message_handler.route_handler.handle_node_announcement(None, &msg); self.forward_broadcast_msg( &*self.peers.read().unwrap(), - &BroadcastGossipMessage::NodeAnnouncement(msg), + BroadcastGossipMessage::NodeAnnouncement(msg), None, true, ); @@ -4557,7 +4560,8 @@ mod tests { assert_eq!(peer.gossip_broadcast_buffer.len(), 1); let pending_msg = &peer.gossip_broadcast_buffer[0]; - let expected = encode_msg!(&msg_100); + let msg: Message<()> = Message::ChannelUpdate(msg_100); + let expected = encode_message(msg); assert_eq!(expected, pending_msg.fetch_encoded_msg_with_type_pfx()); } } diff --git a/lightning/src/ln/wire.rs b/lightning/src/ln/wire.rs index bc1d83adb68..9065c49c676 100644 --- a/lightning/src/ln/wire.rs +++ b/lightning/src/ln/wire.rs @@ -425,19 +425,6 @@ where } } -/// Writes a message to the data buffer encoded as a 2-byte big-endian type and a variable-length -/// payload. -/// -/// # Errors -/// -/// Returns an I/O error if the write could not be completed. -pub(crate) fn write( - message: &M, buffer: &mut W, -) -> Result<(), io::Error> { - message.type_id().write(buffer)?; - message.write(buffer) -} - mod encode { /// Defines a constant type identifier for reading messages from the wire. pub trait Encode { @@ -737,34 +724,6 @@ mod tests { } } - #[test] - fn write_message_with_type() { - let message = msgs::Pong { byteslen: 2u16 }; - let mut buffer = Vec::new(); - assert!(write(&message, &mut buffer).is_ok()); - - let type_length = ::core::mem::size_of::(); - let (type_bytes, payload_bytes) = buffer.split_at(type_length); - assert_eq!(u16::from_be_bytes(type_bytes.try_into().unwrap()), msgs::Pong::TYPE); - assert_eq!(payload_bytes, &ENCODED_PONG[type_length..]); - } - - #[test] - fn read_message_encoded_with_write() { - let message = msgs::Pong { byteslen: 2u16 }; - let mut buffer = Vec::new(); - assert!(write(&message, &mut buffer).is_ok()); - - let decoded_message = read(&mut &buffer[..], &IgnoringMessageHandler {}).unwrap(); - match decoded_message { - Message::Pong(msgs::Pong { byteslen: 2u16 }) => (), - Message::Pong(msgs::Pong { byteslen }) => { - panic!("Expected byteslen {}; found: {}", message.byteslen, byteslen); - }, - _ => panic!("Expected pong message; found message type: {}", decoded_message.type_id()), - } - } - #[test] fn is_even_message_type() { let message = Message::<()>::Unknown(42);