From 82393f9eab2de0b1b947fdba57b664d95f55b1dd Mon Sep 17 00:00:00 2001 From: Suphanat Chunhapanya Date: Mon, 17 Nov 2025 16:01:59 -0300 Subject: [PATCH 1/4] ms-select2: initial commit for multistream-select2 --- lazyClient.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/lazyClient.go b/lazyClient.go index 3ff48f9..9e9850d 100644 --- a/lazyClient.go +++ b/lazyClient.go @@ -17,6 +17,17 @@ func NewMSSelect[T StringLike](c io.ReadWriteCloser, proto T) LazyConn { } } +func NewMSSelect2[T StringLike](c io.ReadWriteCloser, proto T, peerProtos []T) LazyConn { + // TODO: put peerProtos into lazyClientConn so that it knows what protocols the other peer supports + return &lazyClientConn[T]{ + protos: []T{ProtocolID, proto}, + con: c, + + rhandshakeOnce: newOnce(), + whandshakeOnce: newOnce(), + } +} + // NewMultistream returns a multistream for the given protocol. This will not // perform any protocol selection. If you are using a MultistreamMuxer, use // NewMSSelect. From 33c7cf49a7e4f36479d441a45215115c6e0a9363 Mon Sep 17 00:00:00 2001 From: Suphanat Chunhapanya Date: Thu, 20 Nov 2025 18:01:12 -0300 Subject: [PATCH 2/4] ms-select2: abbreviation tree and NewMSSelect2 --- abbrevTree.go | 98 ++++++++++++++++++++++++++++ abbrevTree_test.go | 155 +++++++++++++++++++++++++++++++++++++++++++++ lazyClient.go | 11 +++- multistream.go | 3 + 4 files changed, 265 insertions(+), 2 deletions(-) create mode 100644 abbrevTree.go create mode 100644 abbrevTree_test.go diff --git a/abbrevTree.go b/abbrevTree.go new file mode 100644 index 0000000..31e0fd2 --- /dev/null +++ b/abbrevTree.go @@ -0,0 +1,98 @@ +package multistream + +import ( + "crypto/sha256" +) + +type nodeProtocol[T StringLike] struct { + protocolID T + tombstoneBit bool +} + +type abbrevTree[T StringLike] struct { + root *abbrevNode[T] +} + +type abbrevNode[T StringLike] struct { + p *nodeProtocol[T] + children [256]*abbrevNode[T] +} + +func (at *abbrevTree[T]) Abbreviate(pid T) []byte { + var result []byte + hash := sha256.Sum256([]byte(pid)) + + if at.root == nil { + return nil + } + + current := at.root + // go furthest in the tree + for _, b := range hash { + if current.children[b] != nil { + result = append(result, b) + current = current.children[b] + } + } + + if current.p != nil && current.p.protocolID == pid && !current.p.tombstoneBit { + return result + } + return nil +} + +func (at *abbrevTree[T]) AddProtocol(pid T) { + hash := sha256.Sum256([]byte(pid)) + + if at.root == nil { + at.root = &abbrevNode[T]{} + } + + current := at.root + for idx, b := range hash { + if current.children[b] == nil { + current.children[b] = &abbrevNode[T]{ + p: &nodeProtocol[T]{ + protocolID: pid, + tombstoneBit: false, + }, + } + return + } + current = current.children[b] + + if current.p != nil { + if current.p.protocolID == pid { + // Resurrect the protocol ID. + current.p.tombstoneBit = false + } else if !current.p.tombstoneBit { + // There is another protocol in this node, so we need to duplicate it down. + h := sha256.Sum256([]byte(current.p.protocolID)) + + if current.children[h[idx+1]] == nil { + // It should be fine to reference the same nodeProtocol instance. + current.children[h[idx+1]] = &abbrevNode[T]{p: current.p} + } + } + } + } +} + +func (at *abbrevTree[T]) RemoveProtocol(pid T) { + hash := sha256.Sum256([]byte(pid)) + + if at.root == nil { + return + } + current := at.root + for _, b := range hash { + if current.children[b] == nil { + break + } + current = current.children[b] + + if current.p.protocolID == pid { + current.p.tombstoneBit = true + } + } +} diff --git a/abbrevTree_test.go b/abbrevTree_test.go new file mode 100644 index 0000000..cd435e7 --- /dev/null +++ b/abbrevTree_test.go @@ -0,0 +1,155 @@ +package multistream + +import ( + "bytes" + "crypto/sha256" + "testing" +) + +func TestAbbrevTreeAddProtocol(t *testing.T) { + tree := &abbrevTree[string]{} + + proto1 := "protocol1" + hash1 := sha256.Sum256([]byte(proto1)) + proto2 := "protocol2" + hash2 := sha256.Sum256([]byte(proto2)) + proto3 := "protocol251" // this one has the same first byte as "protocol1" + hash3 := sha256.Sum256([]byte(proto3)) + + // make sure we don't make mistakes on the hashes + if hash1[0] == hash2[0] { + t.Fatal("the first bytes of hash1 and hash2 should be different") + } + if hash1[0] != hash3[0] { + t.Fatal("the first bytes of hash1 and hash3 should be the same") + } + if hash1[1] == hash3[1] { + t.Fatal("the second bytes of hash1 and hash3 should be different") + } + + // add only proto1 + tree.AddProtocol(proto1) + + if tree.root == nil { + t.Fatal("root should not be nil after adding protocol") + } + if tree.root.children[hash1[0]] == nil || tree.root.children[hash1[0]].p == nil { + t.Fatal("the protocol was not added") + } + if tree.root.children[hash1[0]].p.protocolID != proto1 { + t.Fatal("the protocol ID was wrong") + } + if tree.root.children[hash1[0]].p.tombstoneBit { + t.Fatal("tombstoneBit of proto1 must not be set") + } + if !bytes.Equal(tree.Abbreviate(proto1), []byte{hash1[0]}) { + t.Fatal("abbreviation of proto1 is incorrect") + } + + // also add proto2 + tree.AddProtocol(proto2) + + if tree.root.children[hash2[0]] == nil || tree.root.children[hash2[0]].p == nil { + t.Fatal("the protocol was not added") + } + if tree.root.children[hash2[0]].p.protocolID != proto2 { + t.Fatal("the protocol ID was wrong") + } + if tree.root.children[hash2[0]].p.tombstoneBit { + t.Fatal("tombstoneBit of proto2 must not be set") + } + if !bytes.Equal(tree.Abbreviate(proto2), []byte{hash2[0]}) { + t.Fatal("abbreviation of proto2 is incorrect") + } + + // add proto3 which has the same first byte of the hash as proto1 + tree.AddProtocol(proto3) + + n1 := tree.root.children[hash1[0]] + // the node at the first level should still be proto1 + if n1.p.protocolID != proto1 { + t.Fatal("the node in the first level should not be modified") + } + // proto1 should be duplicated down + if n1.children[hash1[1]] == nil || n1.children[hash1[1]].p == nil { + t.Fatal("proto1 was not duplicated") + } + if n1.children[hash1[1]].p.tombstoneBit { + t.Fatal("tombstoneBit of proto1 must not be set") + } + if !bytes.Equal(tree.Abbreviate(proto1), []byte{hash1[0], hash1[1]}) { + t.Fatal("abbreviation of proto1 is incorrect") + } + // proto3 should be added in the second level + if n1.children[hash3[1]] == nil || n1.children[hash3[1]].p == nil { + t.Fatal("proto3 was not added") + } + if n1.children[hash3[1]].p.tombstoneBit { + t.Fatal("tombstoneBit of proto3 must not be set") + } + if !bytes.Equal(tree.Abbreviate(proto3), []byte{hash3[0], hash3[1]}) { + t.Fatal("abbreviation of proto3 is incorrect") + } +} + +func TestAbbrevTreeRemoveProtocol(t *testing.T) { + tree := &abbrevTree[string]{} + + proto1 := "protocol1" + hash1 := sha256.Sum256([]byte(proto1)) + proto2 := "protocol2" + proto3 := "protocol251" // this one has the same first byte as "protocol1" + + tree.AddProtocol(proto1) + tree.AddProtocol(proto2) + tree.AddProtocol(proto3) + + // remove only proto1 + tree.RemoveProtocol(proto1) + + n1 := tree.root.children[hash1[0]] + if !n1.p.tombstoneBit { + t.Fatal("tombstoneBit of proto1 must be set") + } + if !n1.children[hash1[1]].p.tombstoneBit { + t.Fatal("tombstoneBit of proto1 must be set") + } + if tree.Abbreviate(proto1) != nil { + t.Fatal("abbreviation of proto1 should be nil") + } +} + +func TestAbbrevTreeResurrectProtocol(t *testing.T) { + tree := &abbrevTree[string]{} + + proto1 := "protocol1" + hash1 := sha256.Sum256([]byte(proto1)) + proto2 := "protocol2" + proto3 := "protocol251" // this one has the same first byte as "protocol1" + + tree.AddProtocol(proto1) + tree.AddProtocol(proto2) + tree.AddProtocol(proto3) + tree.RemoveProtocol(proto1) + tree.AddProtocol(proto1) + + n1 := tree.root.children[hash1[0]] + if n1.p.tombstoneBit { + t.Fatal("tombstoneBit of proto1 must not be set") + } + if n1.children[hash1[1]].p.tombstoneBit { + t.Fatal("tombstoneBit of proto1 must not be set") + } + + // There should be another leaf node added for proto1 + n2 := n1.children[hash1[1]] + if n2.children[hash1[2]] == nil || n2.children[hash1[2]].p == nil { + t.Fatal("proto1 was not added") + } + if n2.children[hash1[2]].p.tombstoneBit { + t.Fatal("tombstoneBit of proto1 must not be set") + } + if !bytes.Equal(tree.Abbreviate(proto1), []byte{hash1[0], hash1[1], hash1[2]}) { + t.Fatal("abbreviation of proto1 is incorrect") + } +} diff --git a/lazyClient.go b/lazyClient.go index 9e9850d..7690b0b 100644 --- a/lazyClient.go +++ b/lazyClient.go @@ -1,6 +1,7 @@ package multistream import ( + "encoding/hex" "fmt" "io" ) @@ -18,9 +19,15 @@ func NewMSSelect[T StringLike](c io.ReadWriteCloser, proto T) LazyConn { } func NewMSSelect2[T StringLike](c io.ReadWriteCloser, proto T, peerProtos []T) LazyConn { - // TODO: put peerProtos into lazyClientConn so that it knows what protocols the other peer supports + t := &abbrevTree[T]{} + for _, p := range peerProtos { + t.AddProtocol(p) + } + + // TODO: use a proper varint instead of a hex string later + abbrv := T(hex.EncodeToString(t.Abbreviate(proto))) return &lazyClientConn[T]{ - protos: []T{ProtocolID, proto}, + protos: []T{ProtocolID, abbrv}, con: c, rhandshakeOnce: newOnce(), diff --git a/multistream.go b/multistream.go index 17e1ef7..eb2d830 100644 --- a/multistream.go +++ b/multistream.go @@ -22,6 +22,9 @@ var ErrTooLarge = errors.New("incoming message was too large") // the multistream muxers on both sides of a channel can work with each other. const ProtocolID = "/multistream/1.0.0" +// Multistream-select version that protocol abbreviation is supported +const AbbrevSupportedMSSVersion = 2 + var writerPool = sync.Pool{ New: func() interface{} { return bufio.NewWriter(nil) From 6b42268f6db7a0bd9f09d93a8c1d039bfab30f2f Mon Sep 17 00:00:00 2001 From: Suphanat Chunhapanya Date: Thu, 20 Nov 2025 19:30:53 -0300 Subject: [PATCH 3/4] ms-select2: abbreviation on the server side --- abbrevTree.go | 19 +++++++++++++++++++ multistream.go | 34 ++++++++++++++++++++++++++++++++-- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/abbrevTree.go b/abbrevTree.go index 31e0fd2..d66f9e2 100644 --- a/abbrevTree.go +++ b/abbrevTree.go @@ -32,6 +32,8 @@ func (at *abbrevTree[T]) Abbreviate(pid T) []byte { if current.children[b] != nil { result = append(result, b) current = current.children[b] + } else { + break } } @@ -41,6 +43,23 @@ func (at *abbrevTree[T]) Abbreviate(pid T) []byte { return nil } +func (at *abbrevTree[T]) GetProtocolID(prefix []byte) (T, error) { + if at.root == nil { + return "", ErrUnknownPrefix + } + current := at.root + for _, b := range prefix { + if current.children[b] == nil { + return "", ErrUnknownPrefix + } + current = current.children[b] + } + if current.p == nil { + return "", ErrUnknownPrefix + } + return current.p.protocolID, nil +} + func (at *abbrevTree[T]) AddProtocol(pid T) { hash := sha256.Sum256([]byte(pid)) diff --git a/multistream.go b/multistream.go index eb2d830..2e1e604 100644 --- a/multistream.go +++ b/multistream.go @@ -5,6 +5,7 @@ package multistream import ( "bufio" + "encoding/hex" "errors" "fmt" "io" @@ -18,6 +19,9 @@ import ( // ErrTooLarge is an error to signal that an incoming message was too large var ErrTooLarge = errors.New("incoming message was too large") +// ErrUnknownPrefix is an error to signal that the protocol hash prefix is unknown +var ErrUnknownPrefix = errors.New("unknown protocol hash prefix") + // ProtocolID identifies the multistream protocol itself and makes sure // the multistream muxers on both sides of a channel can work with each other. const ProtocolID = "/multistream/1.0.0" @@ -55,6 +59,7 @@ type Handler[T StringLike] struct { type MultistreamMuxer[T StringLike] struct { handlerlock sync.RWMutex handlers []Handler[T] + abbrevTree abbrevTree[T] } // NewMultistreamMuxer creates a muxer. @@ -137,6 +142,7 @@ func (msm *MultistreamMuxer[T]) AddHandlerWithFunc(protocol T, match func(T) boo msm.handlerlock.Lock() defer msm.handlerlock.Unlock() + msm.abbrevTree.AddProtocol(protocol) msm.removeHandler(protocol) msm.handlers = append(msm.handlers, Handler[T]{ MatchFunc: match, @@ -150,6 +156,7 @@ func (msm *MultistreamMuxer[T]) RemoveHandler(protocol T) { msm.handlerlock.Lock() defer msm.handlerlock.Unlock() + msm.abbrevTree.RemoveProtocol(protocol) msm.removeHandler(protocol) } @@ -179,6 +186,24 @@ func (msm *MultistreamMuxer[T]) Protocols() []T { // fails because of a ProtocolID mismatch. var ErrIncorrectVersion = errors.New("client connected with incorrect version") +func (msm *MultistreamMuxer[T]) decodeProtocol(s T) (T, error) { + msm.handlerlock.RLock() + defer msm.handlerlock.RUnlock() + + bytes, err := hex.DecodeString(string(s)) + // TODO: decide whether to compare strings or use abbrevTree by looking at + // multistream version instead. + if err != nil { + return s, nil + } + + proto, err := msm.abbrevTree.GetProtocolID(bytes) + if err != nil { + return "", err + } + return proto, nil +} + func (msm *MultistreamMuxer[T]) findHandler(proto T) *Handler[T] { msm.handlerlock.RLock() defer msm.handlerlock.RUnlock() @@ -225,7 +250,12 @@ loop: return "", nil, err } - h := msm.findHandler(tok) + p, err := msm.decodeProtocol(tok) + if err != nil { + return "", nil, err + } + + h := msm.findHandler(p) if h == nil { if err := delimWriteBuffered(rwc, []byte("na")); err != nil { return "", nil, err @@ -239,7 +269,7 @@ loop: _ = delimWriteBuffered(rwc, []byte(tok)) // hand off processing to the sub-protocol handler - return tok, h.Handle, nil + return p, h.Handle, nil } } From e2fcc7710d49de1cb426c2bae7fc015437068eca Mon Sep 17 00:00:00 2001 From: Suphanat Chunhapanya Date: Sun, 30 Nov 2025 22:43:52 +0700 Subject: [PATCH 4/4] ms-select2: compress multiselect proto id --- lazyClient.go | 43 +++++++++++++++++++++----------- multistream.go | 67 +++++++++++++++++++++++++------------------------- 2 files changed, 62 insertions(+), 48 deletions(-) diff --git a/lazyClient.go b/lazyClient.go index 7690b0b..a6ec50d 100644 --- a/lazyClient.go +++ b/lazyClient.go @@ -1,7 +1,7 @@ package multistream import ( - "encoding/hex" + "bytes" "fmt" "io" ) @@ -10,7 +10,7 @@ import ( // protocol selection with a MultistreamMuxer. func NewMSSelect[T StringLike](c io.ReadWriteCloser, proto T) LazyConn { return &lazyClientConn[T]{ - protos: []T{ProtocolID, proto}, + protos: []protoInfo[T]{{ID: ProtocolID}, {ID: proto}}, con: c, rhandshakeOnce: newOnce(), @@ -24,11 +24,13 @@ func NewMSSelect2[T StringLike](c io.ReadWriteCloser, proto T, peerProtos []T) L t.AddProtocol(p) } - // TODO: use a proper varint instead of a hex string later - abbrv := T(hex.EncodeToString(t.Abbreviate(proto))) + abbrv := t.Abbreviate(proto) return &lazyClientConn[T]{ - protos: []T{ProtocolID, abbrv}, - con: c, + protos: []protoInfo[T]{ + {ID: ProtocolID, Abbrev: ProtocolAbbrev}, + {ID: proto, Abbrev: abbrv}, + }, + con: c, rhandshakeOnce: newOnce(), whandshakeOnce: newOnce(), @@ -40,7 +42,7 @@ func NewMSSelect2[T StringLike](c io.ReadWriteCloser, proto T, peerProtos []T) L // NewMSSelect. func NewMultistream[T StringLike](c io.ReadWriteCloser, proto T) LazyConn { return &lazyClientConn[T]{ - protos: []T{proto}, + protos: []protoInfo[T]{{ID: proto}}, con: c, rhandshakeOnce: newOnce(), @@ -76,6 +78,11 @@ func (o *once) Do(f func()) { f() } +type protoInfo[T StringLike] struct { + ID T + Abbrev []byte +} + // lazyClientConn is a ReadWriteCloser adapter that lazily negotiates a protocol // using multistream-select on first use. // @@ -92,7 +99,7 @@ type lazyClientConn[T StringLike] struct { werr error // The sequence of protocols to negotiate. - protos []T + protos []protoInfo[T] // The inner connection. con io.ReadWriteCloser @@ -122,18 +129,22 @@ func (l *lazyClientConn[T]) Read(b []byte) (int, error) { func (l *lazyClientConn[T]) doReadHandshake() { for _, proto := range l.protos { // read protocol - tok, err := ReadNextToken[T](l.con) + tok, err := ReadNextTokenBytes(l.con) if err != nil { l.rerr = err return } - if tok == "na" { - l.rerr = ErrNotSupported[T]{[]T{proto}} + if bytes.Equal(tok, []byte("na")) { + l.rerr = ErrNotSupported[T]{[]T{proto.ID}} return } - if tok != proto { - l.rerr = fmt.Errorf("protocol mismatch in lazy handshake ( %s != %s )", tok, proto) + if proto.Abbrev != nil && !bytes.Equal(tok, proto.Abbrev) { + l.rerr = fmt.Errorf("protocol mismatch in lazy handshake ( %x != %x )", tok, proto.Abbrev) + return + } + if proto.Abbrev == nil && T(tok) != proto.ID { + l.rerr = fmt.Errorf("protocol mismatch in lazy handshake ( %s != %s )", T(tok), proto.ID) return } } @@ -149,7 +160,11 @@ func (l *lazyClientConn[T]) doWriteHandshakeWithData(extra []byte) int { defer putWriter(buf) for _, proto := range l.protos { - l.werr = delimWrite(buf, []byte(proto)) + if proto.Abbrev != nil { + l.werr = delimWrite(buf, proto.Abbrev) + } else { + l.werr = delimWrite(buf, []byte(proto.ID)) + } if l.werr != nil { return 0 } diff --git a/multistream.go b/multistream.go index 2e1e604..1c059d7 100644 --- a/multistream.go +++ b/multistream.go @@ -5,7 +5,7 @@ package multistream import ( "bufio" - "encoding/hex" + "bytes" "errors" "fmt" "io" @@ -26,6 +26,9 @@ var ErrUnknownPrefix = errors.New("unknown protocol hash prefix") // the multistream muxers on both sides of a channel can work with each other. const ProtocolID = "/multistream/1.0.0" +// ProtocolID identifies the multistream protocol abbreviation support +var ProtocolAbbrev = []byte{0xff, 0x11} + // Multistream-select version that protocol abbreviation is supported const AbbrevSupportedMSSVersion = 2 @@ -186,24 +189,6 @@ func (msm *MultistreamMuxer[T]) Protocols() []T { // fails because of a ProtocolID mismatch. var ErrIncorrectVersion = errors.New("client connected with incorrect version") -func (msm *MultistreamMuxer[T]) decodeProtocol(s T) (T, error) { - msm.handlerlock.RLock() - defer msm.handlerlock.RUnlock() - - bytes, err := hex.DecodeString(string(s)) - // TODO: decide whether to compare strings or use abbrevTree by looking at - // multistream version instead. - if err != nil { - return s, nil - } - - proto, err := msm.abbrevTree.GetProtocolID(bytes) - if err != nil { - return "", err - } - return proto, nil -} - func (msm *MultistreamMuxer[T]) findHandler(proto T) *Handler[T] { msm.handlerlock.RLock() defer msm.handlerlock.RUnlock() @@ -227,17 +212,21 @@ func (msm *MultistreamMuxer[T]) Negotiate(rwc io.ReadWriteCloser) (proto T, hand } }() - // Send the multistream protocol ID - // Ignore the error here. We want the handshake to finish, even if the - // other side has closed this rwc for writing. They may have sent us a - // message and closed. Future writers will get an error anyways. - _ = delimWriteBuffered(rwc, []byte(ProtocolID)) - line, err := ReadNextToken[T](rwc) + token, err := ReadNextTokenBytes(rwc) if err != nil { return "", nil, err } - - if line != ProtocolID { + supportAbbrev := false + // Send the multistream protocol ID or the mulstream protocol abbreviation + // Ignore the error here. We want the handshake to finish, even if the + // other side has closed this rwc for writing. They may have sent us a + // message and closed. Future writers will get an error anyways. + if bytes.Equal(token, ProtocolAbbrev) { + supportAbbrev = true + _ = delimWriteBuffered(rwc, ProtocolAbbrev) + } else if T(token) == ProtocolID { + _ = delimWriteBuffered(rwc, []byte(ProtocolID)) + } else { rwc.Close() return "", nil, ErrIncorrectVersion } @@ -245,17 +234,27 @@ func (msm *MultistreamMuxer[T]) Negotiate(rwc io.ReadWriteCloser) (proto T, hand loop: for { // Now read and respond to commands until they send a valid protocol id - tok, err := ReadNextToken[T](rwc) + var proto T + + tok, err := ReadNextTokenBytes(rwc) if err != nil { return "", nil, err } - p, err := msm.decodeProtocol(tok) - if err != nil { - return "", nil, err + if supportAbbrev { + // decode the protocol abbreviation using the abbreviation tree + msm.handlerlock.RLock() + proto, err = msm.abbrevTree.GetProtocolID(tok) + msm.handlerlock.RUnlock() + + if err != nil { + return "", nil, err + } + } else { + proto = T(tok) } - h := msm.findHandler(p) + h := msm.findHandler(proto) if h == nil { if err := delimWriteBuffered(rwc, []byte("na")); err != nil { return "", nil, err @@ -266,10 +265,10 @@ loop: // Ignore the error here. We want the handshake to finish, even if the // other side has closed this rwc for writing. They may have sent us a // message and closed. Future writers will get an error anyways. - _ = delimWriteBuffered(rwc, []byte(tok)) + _ = delimWriteBuffered(rwc, tok) // hand off processing to the sub-protocol handler - return p, h.Handle, nil + return proto, h.Handle, nil } }