Skip to content

Commit 95cad50

Browse files
shoriweMarcoPolo
authored andcommitted
feat: relay: add option for custom filter function
1 parent 72894e3 commit 95cad50

File tree

3 files changed

+68
-21
lines changed

3 files changed

+68
-21
lines changed

p2p/protocol/circuitv2/relay/options.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
package relay
22

3+
import (
4+
"github.com/multiformats/go-multiaddr"
5+
)
6+
37
type Option func(*Relay) error
48

59
// WithResources is a Relay option that sets specific relay resources for the relay.
@@ -18,6 +22,18 @@ func WithLimit(limit *RelayLimit) Option {
1822
}
1923
}
2024

25+
// Reservation address function used to promote addresses to connected nodes
26+
type ReservationAddressFilterFunc func(addr multiaddr.Multiaddr) (include bool)
27+
28+
// Overrides the default reservation address filter.
29+
// This will permit the relay let the client know it have access to non public addresses too.
30+
func WithReservationAddressFilter(filter ReservationAddressFilterFunc) (option Option) {
31+
return func(r *Relay) (err error) {
32+
r.reservationAddrFilter = filter
33+
return nil
34+
}
35+
}
36+
2137
// WithInfiniteLimits is a Relay option that disables limits.
2238
func WithInfiniteLimits() Option {
2339
return func(r *Relay) error {

p2p/protocol/circuitv2/relay/relay.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ type Relay struct {
4646
ctx context.Context
4747
cancel func()
4848

49+
reservationAddrFilter ReservationAddressFilterFunc
50+
4951
host host.Host
5052
rc Resources
5153
acl ACLFilter
@@ -75,6 +77,8 @@ func New(h host.Host, opts ...Option) (*Relay, error) {
7577
acl: nil,
7678
rsvp: make(map[peer.ID]time.Time),
7779
conns: make(map[peer.ID]int),
80+
81+
reservationAddrFilter: manet.IsPublicAddr,
7882
}
7983

8084
for _, opt := range opts {
@@ -235,6 +239,7 @@ func (r *Relay) handleReserve(s network.Stream) pbv2.Status {
235239
// For example, the stream might be reset or the connection might be closed before the reservation is received.
236240
// In that case, the reservation will just be garbage collected later.
237241
rsvp := makeReservationMsg(
242+
r.reservationAddrFilter,
238243
r.host.Peerstore().PrivKey(r.host.ID()),
239244
r.host.ID(),
240245
r.host.Addrs(),
@@ -612,6 +617,7 @@ func (r *Relay) writeResponse(s network.Stream, status pbv2.Status, rsvp *pbv2.R
612617
}
613618

614619
func makeReservationMsg(
620+
reservationAddrFilter ReservationAddressFilterFunc,
615621
signingKey crypto.PrivKey,
616622
selfID peer.ID,
617623
selfAddrs []ma.Multiaddr,
@@ -630,7 +636,7 @@ func makeReservationMsg(
630636

631637
addrBytes := make([][]byte, 0, len(selfAddrs))
632638
for _, addr := range selfAddrs {
633-
if !manet.IsPublicAddr(addr) {
639+
if !reservationAddrFilter(addr) {
634640
continue
635641
}
636642

p2p/protocol/circuitv2/relay/relay_priv_test.go

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010
"github.com/stretchr/testify/require"
1111

1212
ma "github.com/multiformats/go-multiaddr"
13+
"github.com/multiformats/go-multiaddr/matest"
14+
manet "github.com/multiformats/go-multiaddr/net"
1315
)
1416

1517
func genKeyAndID(t *testing.T) (crypto.PrivKey, peer.ID) {
@@ -28,26 +30,49 @@ func TestMakeReservationWithP2PAddrs(t *testing.T) {
2830
_, otherID := genKeyAndID(t)
2931
_, reserverID := genKeyAndID(t)
3032

31-
addrs := []ma.Multiaddr{
32-
ma.StringCast("/ip4/1.2.3.4/tcp/1234"), // No p2p part
33-
ma.StringCast("/ip4/1.2.3.4/tcp/1235/p2p/" + selfID.String()), // Already has p2p part
34-
ma.StringCast("/ip4/1.2.3.4/tcp/1236/p2p/" + otherID.String()), // Some other peer (?? Not expected, but we could get anything in this func)
35-
}
36-
37-
rsvp := makeReservationMsg(selfKey, selfID, addrs, reserverID, time.Now().Add(time.Minute))
38-
require.NotNil(t, rsvp)
33+
tcs := []struct {
34+
name string
35+
filter func(ma.Multiaddr) bool
36+
input []ma.Multiaddr
37+
expected []ma.Multiaddr
38+
}{{
39+
name: "only public",
40+
filter: manet.IsPublicAddr,
41+
input: []ma.Multiaddr{
42+
ma.StringCast("/ip4/1.2.3.4/tcp/1234"), // No p2p part
43+
ma.StringCast("/ip4/1.2.3.4/tcp/1235/p2p/" + selfID.String()), // Already has p2p part
44+
ma.StringCast("/ip4/192.168.1.9/tcp/1235/p2p/" + selfID.String()), // Already has p2p part
45+
ma.StringCast("/ip4/1.2.3.4/tcp/1236/p2p/" + otherID.String()), // Some other peer (?? Not expected, but we could get anything in this func)
46+
},
47+
expected: []ma.Multiaddr{
48+
ma.StringCast("/ip4/1.2.3.4/tcp/1234/p2p/" + selfID.String()),
49+
ma.StringCast("/ip4/1.2.3.4/tcp/1235/p2p/" + selfID.String()),
50+
},
51+
}, {
52+
name: "only not public",
53+
filter: func(m ma.Multiaddr) bool { return !manet.IsPublicAddr(m) },
54+
input: []ma.Multiaddr{
55+
ma.StringCast("/ip4/1.2.3.4/tcp/1234"), // No p2p part
56+
ma.StringCast("/ip4/1.2.3.4/tcp/1235/p2p/" + selfID.String()), // Already has p2p part
57+
ma.StringCast("/ip4/192.168.1.9/tcp/1235/p2p/" + selfID.String()), // Already has p2p part
58+
ma.StringCast("/ip4/1.2.3.4/tcp/1236/p2p/" + otherID.String()), // Some other peer (?? Not expected, but we could get anything in this func)
59+
},
60+
expected: []ma.Multiaddr{
61+
ma.StringCast("/ip4/192.168.1.9/tcp/1235/p2p/" + selfID.String()),
62+
},
63+
}}
64+
for _, tc := range tcs {
65+
t.Run(tc.name, func(t *testing.T) {
66+
rsvp := makeReservationMsg(tc.filter, selfKey, selfID, tc.input, reserverID, time.Now().Add(time.Minute))
67+
require.NotNil(t, rsvp)
3968

40-
expectedAddrs := []string{
41-
"/ip4/1.2.3.4/tcp/1234/p2p/" + selfID.String(),
42-
"/ip4/1.2.3.4/tcp/1235/p2p/" + selfID.String(),
69+
addrsFromRsvp := make([]ma.Multiaddr, 0, len(rsvp.GetAddrs()))
70+
for _, addr := range rsvp.GetAddrs() {
71+
a, err := ma.NewMultiaddrBytes(addr)
72+
require.NoError(t, err)
73+
addrsFromRsvp = append(addrsFromRsvp, a)
74+
}
75+
matest.AssertEqualMultiaddrs(t, tc.expected, addrsFromRsvp)
76+
})
4377
}
44-
45-
addrsFromRsvp := make([]string, 0, len(rsvp.GetAddrs()))
46-
for _, addr := range rsvp.GetAddrs() {
47-
a, err := ma.NewMultiaddrBytes(addr)
48-
require.NoError(t, err)
49-
addrsFromRsvp = append(addrsFromRsvp, a.String())
50-
}
51-
52-
require.Equal(t, expectedAddrs, addrsFromRsvp)
5378
}

0 commit comments

Comments
 (0)