Skip to content

Commit 3b186c2

Browse files
committed
Drop wire::write and replace encode_msg! macro
Now that we consistently use `wire::Message` everywhere, it's easier to simply use `Message::write`/`Type::write` instead of heaving yet another `wire::write` around. Here we drop `wire::write`, replace the `encode_msg` macro with a method that takes `wire::Message`, and convert a bunch of additional places to move semantics.
1 parent 2c7f1cf commit 3b186c2

File tree

3 files changed

+39
-72
lines changed

3 files changed

+39
-72
lines changed

lightning/src/ln/peer_channel_encryptor.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ use crate::prelude::*;
1212
use crate::ln::msgs;
1313
use crate::ln::msgs::LightningError;
1414
use crate::ln::wire;
15+
use crate::ln::wire::Type;
1516
use crate::sign::{NodeSigner, Recipient};
17+
use crate::util::ser::Writeable;
1618

1719
use bitcoin::hashes::sha256::Hash as Sha256;
1820
use bitcoin::hashes::{Hash, HashEngine};
@@ -570,7 +572,9 @@ impl PeerChannelEncryptor {
570572
// for the 2-byte message type prefix and its MAC.
571573
let mut res = VecWriter(Vec::with_capacity(MSG_BUF_ALLOC_SIZE));
572574
res.0.resize(16 + 2, 0);
573-
wire::write(&message, &mut res).expect("In-memory messages must never fail to serialize");
575+
576+
message.type_id().write(&mut res).expect("In-memory messages must never fail to serialize");
577+
message.write(&mut res).expect("In-memory messages must never fail to serialize");
574578

575579
self.encrypt_message_with_header_0s(&mut res.0);
576580
res.0

lightning/src/ln/peer_handler.rs

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,12 +1144,11 @@ impl From<LightningError> for MessageHandlingError {
11441144
}
11451145
}
11461146

1147-
macro_rules! encode_msg {
1148-
($msg: expr) => {{
1149-
let mut buffer = VecWriter(Vec::with_capacity(MSG_BUF_ALLOC_SIZE));
1150-
wire::write($msg, &mut buffer).unwrap();
1151-
buffer.0
1152-
}};
1147+
fn encode_message<T: wire::Type>(message: wire::Message<T>) -> Vec<u8> {
1148+
let mut buffer = VecWriter(Vec::with_capacity(MSG_BUF_ALLOC_SIZE));
1149+
message.type_id().write(&mut buffer).expect("In-memory messages must never fail to serialize");
1150+
message.write(&mut buffer).expect("In-memory messages must never fail to serialize");
1151+
buffer.0
11531152
}
11541153

11551154
impl<Descriptor: SocketDescriptor, CM: Deref, OM: Deref, L: Deref, NS: Deref, SM: Deref>
@@ -2068,7 +2067,7 @@ where
20682067
for msg in msgs_to_forward.drain(..) {
20692068
self.forward_broadcast_msg(
20702069
&*peers,
2071-
&msg,
2070+
msg,
20722071
peer_node_id.as_ref().map(|(pk, _)| pk),
20732072
false,
20742073
);
@@ -2661,22 +2660,25 @@ where
26612660
/// unless `allow_large_buffer` is set, in which case the message will be treated as critical
26622661
/// and delivered no matter the available buffer space.
26632662
fn forward_broadcast_msg(
2664-
&self, peers: &HashMap<Descriptor, Mutex<Peer>>, msg: &BroadcastGossipMessage,
2663+
&self, peers: &HashMap<Descriptor, Mutex<Peer>>, msg: BroadcastGossipMessage,
26652664
except_node: Option<&PublicKey>, allow_large_buffer: bool,
26662665
) {
26672666
match msg {
2668-
BroadcastGossipMessage::ChannelAnnouncement(ref msg) => {
2667+
BroadcastGossipMessage::ChannelAnnouncement(msg) => {
26692668
log_gossip!(self.logger, "Sending message to all peers except {:?} or the announced channel's counterparties: {:?}", except_node, msg);
2670-
let encoded_msg = encode_msg!(msg);
26712669
let our_channel = self.our_node_id == msg.contents.node_id_1
26722670
|| self.our_node_id == msg.contents.node_id_2;
2673-
2671+
let scid = msg.contents.short_channel_id;
2672+
let node_id_1 = msg.contents.node_id_1;
2673+
let node_id_2 = msg.contents.node_id_2;
2674+
let msg: Message<<CMH::Target as CustomMessageReader>::CustomMessage> =
2675+
Message::ChannelAnnouncement(msg);
2676+
let encoded_msg = encode_message(msg);
26742677
for (_, peer_mutex) in peers.iter() {
26752678
let mut peer = peer_mutex.lock().unwrap();
26762679
if !peer.handshake_complete() {
26772680
continue;
26782681
}
2679-
let scid = msg.contents.short_channel_id;
26802682
if !our_channel && !peer.should_forward_channel_announcement(scid) {
26812683
continue;
26822684
}
@@ -2693,9 +2695,7 @@ where
26932695
continue;
26942696
}
26952697
if let Some((_, their_node_id)) = peer.their_node_id {
2696-
if their_node_id == msg.contents.node_id_1
2697-
|| their_node_id == msg.contents.node_id_2
2698-
{
2698+
if their_node_id == node_id_1 || their_node_id == node_id_2 {
26992699
continue;
27002700
}
27012701
}
@@ -2708,23 +2708,25 @@ where
27082708
peer.gossip_broadcast_buffer.push_back(encoded_message);
27092709
}
27102710
},
2711-
BroadcastGossipMessage::NodeAnnouncement(ref msg) => {
2711+
BroadcastGossipMessage::NodeAnnouncement(msg) => {
27122712
log_gossip!(
27132713
self.logger,
27142714
"Sending message to all peers except {:?} or the announced node: {:?}",
27152715
except_node,
27162716
msg
27172717
);
2718-
let encoded_msg = encode_msg!(msg);
27192718
let our_announcement = self.our_node_id == msg.contents.node_id;
2719+
let msg_node_id = msg.contents.node_id;
27202720

2721+
let msg: Message<<CMH::Target as CustomMessageReader>::CustomMessage> =
2722+
Message::NodeAnnouncement(msg);
2723+
let encoded_msg = encode_message(msg);
27212724
for (_, peer_mutex) in peers.iter() {
27222725
let mut peer = peer_mutex.lock().unwrap();
27232726
if !peer.handshake_complete() {
27242727
continue;
27252728
}
2726-
let node_id = msg.contents.node_id;
2727-
if !our_announcement && !peer.should_forward_node_announcement(node_id) {
2729+
if !our_announcement && !peer.should_forward_node_announcement(msg_node_id) {
27282730
continue;
27292731
}
27302732
debug_assert!(peer.their_node_id.is_some());
@@ -2740,7 +2742,7 @@ where
27402742
continue;
27412743
}
27422744
if let Some((_, their_node_id)) = peer.their_node_id {
2743-
if their_node_id == msg.contents.node_id {
2745+
if their_node_id == msg_node_id {
27442746
continue;
27452747
}
27462748
}
@@ -2760,15 +2762,16 @@ where
27602762
except_node,
27612763
msg
27622764
);
2763-
let encoded_msg = encode_msg!(msg);
2764-
let our_channel = self.our_node_id == *node_id_1 || self.our_node_id == *node_id_2;
2765-
2765+
let our_channel = self.our_node_id == node_id_1 || self.our_node_id == node_id_2;
2766+
let scid = msg.contents.short_channel_id;
2767+
let msg: Message<<CMH::Target as CustomMessageReader>::CustomMessage> =
2768+
Message::ChannelUpdate(msg);
2769+
let encoded_msg = encode_message(msg);
27662770
for (_, peer_mutex) in peers.iter() {
27672771
let mut peer = peer_mutex.lock().unwrap();
27682772
if !peer.handshake_complete() {
27692773
continue;
27702774
}
2771-
let scid = msg.contents.short_channel_id;
27722775
if !our_channel && !peer.should_forward_channel_announcement(scid) {
27732776
continue;
27742777
}
@@ -3201,7 +3204,7 @@ where
32013204
let forward = BroadcastGossipMessage::ChannelAnnouncement(msg);
32023205
self.forward_broadcast_msg(
32033206
peers,
3204-
&forward,
3207+
forward,
32053208
None,
32063209
from_chan_handler,
32073210
);
@@ -3222,7 +3225,7 @@ where
32223225
};
32233226
self.forward_broadcast_msg(
32243227
peers,
3225-
&forward,
3228+
forward,
32263229
None,
32273230
from_chan_handler,
32283231
);
@@ -3246,7 +3249,7 @@ where
32463249
};
32473250
self.forward_broadcast_msg(
32483251
peers,
3249-
&forward,
3252+
forward,
32503253
None,
32513254
from_chan_handler,
32523255
);
@@ -3265,7 +3268,7 @@ where
32653268
let forward = BroadcastGossipMessage::NodeAnnouncement(msg);
32663269
self.forward_broadcast_msg(
32673270
peers,
3268-
&forward,
3271+
forward,
32693272
None,
32703273
from_chan_handler,
32713274
);
@@ -3742,7 +3745,7 @@ where
37423745
let _ = self.message_handler.route_handler.handle_node_announcement(None, &msg);
37433746
self.forward_broadcast_msg(
37443747
&*self.peers.read().unwrap(),
3745-
&BroadcastGossipMessage::NodeAnnouncement(msg),
3748+
BroadcastGossipMessage::NodeAnnouncement(msg),
37463749
None,
37473750
true,
37483751
);
@@ -4557,7 +4560,8 @@ mod tests {
45574560
assert_eq!(peer.gossip_broadcast_buffer.len(), 1);
45584561

45594562
let pending_msg = &peer.gossip_broadcast_buffer[0];
4560-
let expected = encode_msg!(&msg_100);
4563+
let msg: Message<()> = Message::ChannelUpdate(msg_100);
4564+
let expected = encode_message(msg);
45614565
assert_eq!(expected, pending_msg.fetch_encoded_msg_with_type_pfx());
45624566
}
45634567
}

lightning/src/ln/wire.rs

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -425,19 +425,6 @@ where
425425
}
426426
}
427427

428-
/// Writes a message to the data buffer encoded as a 2-byte big-endian type and a variable-length
429-
/// payload.
430-
///
431-
/// # Errors
432-
///
433-
/// Returns an I/O error if the write could not be completed.
434-
pub(crate) fn write<M: Type + Writeable, W: Writer>(
435-
message: &M, buffer: &mut W,
436-
) -> Result<(), io::Error> {
437-
message.type_id().write(buffer)?;
438-
message.write(buffer)
439-
}
440-
441428
mod encode {
442429
/// Defines a constant type identifier for reading messages from the wire.
443430
pub trait Encode {
@@ -737,34 +724,6 @@ mod tests {
737724
}
738725
}
739726

740-
#[test]
741-
fn write_message_with_type() {
742-
let message = msgs::Pong { byteslen: 2u16 };
743-
let mut buffer = Vec::new();
744-
assert!(write(&message, &mut buffer).is_ok());
745-
746-
let type_length = ::core::mem::size_of::<u16>();
747-
let (type_bytes, payload_bytes) = buffer.split_at(type_length);
748-
assert_eq!(u16::from_be_bytes(type_bytes.try_into().unwrap()), msgs::Pong::TYPE);
749-
assert_eq!(payload_bytes, &ENCODED_PONG[type_length..]);
750-
}
751-
752-
#[test]
753-
fn read_message_encoded_with_write() {
754-
let message = msgs::Pong { byteslen: 2u16 };
755-
let mut buffer = Vec::new();
756-
assert!(write(&message, &mut buffer).is_ok());
757-
758-
let decoded_message = read(&mut &buffer[..], &IgnoringMessageHandler {}).unwrap();
759-
match decoded_message {
760-
Message::Pong(msgs::Pong { byteslen: 2u16 }) => (),
761-
Message::Pong(msgs::Pong { byteslen }) => {
762-
panic!("Expected byteslen {}; found: {}", message.byteslen, byteslen);
763-
},
764-
_ => panic!("Expected pong message; found message type: {}", decoded_message.type_id()),
765-
}
766-
}
767-
768727
#[test]
769728
fn is_even_message_type() {
770729
let message = Message::<()>::Unknown(42);

0 commit comments

Comments
 (0)