Skip to content

Commit 40ddfaa

Browse files
committed
fix: add mutex protection for concurrent map access in p2p client
1 parent afdb01d commit 40ddfaa

File tree

2 files changed

+86
-24
lines changed

2 files changed

+86
-24
lines changed

client.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"io"
1010
"log"
1111
"os"
12+
"sync"
1213
"time"
1314

1415
"github.com/libp2p/go-libp2p"
@@ -32,6 +33,7 @@ type Client struct {
3233
topics map[string]*pubsub.Topic
3334
subs map[string]*pubsub.Subscription
3435
msgChans map[string]chan Message
36+
mu sync.RWMutex
3537
peerTracker *peerTracker
3638
ctx context.Context
3739
cancel context.CancelFunc
@@ -179,11 +181,17 @@ func NewClient(config Config) (*Client, error) {
179181
// The returned channel will be closed when the client is closed.
180182
func (c *Client) Subscribe(topic string) <-chan Message {
181183
msgChan := make(chan Message, 100)
184+
185+
c.mu.Lock()
182186
c.msgChans[topic] = msgChan
187+
c.mu.Unlock()
183188

184189
go func() {
185190
// Join or get existing topic
191+
c.mu.Lock()
186192
t, ok := c.topics[topic]
193+
c.mu.Unlock()
194+
187195
if !ok {
188196
var err error
189197
t, err = c.pubsub.Join(topic)
@@ -192,7 +200,10 @@ func (c *Client) Subscribe(topic string) <-chan Message {
192200
close(msgChan)
193201
return
194202
}
203+
204+
c.mu.Lock()
195205
c.topics[topic] = t
206+
c.mu.Unlock()
196207
}
197208

198209
// Subscribe to topic
@@ -202,7 +213,10 @@ func (c *Client) Subscribe(topic string) <-chan Message {
202213
close(msgChan)
203214
return
204215
}
216+
217+
c.mu.Lock()
205218
c.subs[topic] = sub
219+
c.mu.Unlock()
206220

207221
// Set up peer connection notifications for this topic
208222
c.host.Network().Notify(&network.NotifyBundle{
@@ -245,14 +259,20 @@ func (c *Client) Subscribe(topic string) <-chan Message {
245259

246260
// Publish publishes a message to the specified topic.
247261
func (c *Client) Publish(ctx context.Context, topic string, data []byte) error {
262+
c.mu.RLock()
248263
t, ok := c.topics[topic]
264+
c.mu.RUnlock()
265+
249266
if !ok {
250267
var err error
251268
t, err = c.pubsub.Join(topic)
252269
if err != nil {
253270
return fmt.Errorf("failed to join topic: %w", err)
254271
}
272+
273+
c.mu.Lock()
255274
c.topics[topic] = t
275+
c.mu.Unlock()
256276
}
257277

258278
// Wrap data with metadata
@@ -305,6 +325,7 @@ func (c *Client) Close() error {
305325

306326
done := make(chan struct{})
307327
go func() {
328+
c.mu.Lock()
308329
// Close all message channels
309330
for _, ch := range c.msgChans {
310331
close(ch)
@@ -319,6 +340,7 @@ func (c *Client) Close() error {
319340
for _, topic := range c.topics {
320341
topic.Close()
321342
}
343+
c.mu.Unlock()
322344

323345
// Close services
324346
if c.mdnsService != nil {
@@ -355,7 +377,14 @@ func (c *Client) waitForDHTAndAdvertise(ctx context.Context, routingDiscovery *d
355377
case <-ticker.C:
356378
if len(c.dht.RoutingTable().ListPeers()) > 0 {
357379
// Advertise on DHT for all topics
380+
c.mu.RLock()
381+
topicsCopy := make([]string, 0, len(c.topics))
358382
for topic := range c.topics {
383+
topicsCopy = append(topicsCopy, topic)
384+
}
385+
c.mu.RUnlock()
386+
387+
for _, topic := range topicsCopy {
359388
_, err := routingDiscovery.Advertise(ctx, topic)
360389
if err != nil {
361390
c.logger.Warnf("Failed to advertise topic %s: %v", topic, err)
@@ -378,7 +407,14 @@ func (c *Client) discoverPeers(ctx context.Context, routingDiscovery *drouting.R
378407
case <-ctx.Done():
379408
return
380409
case <-ticker.C:
410+
c.mu.RLock()
411+
topicsCopy := make([]string, 0, len(c.topics))
381412
for topic := range c.topics {
413+
topicsCopy = append(topicsCopy, topic)
414+
}
415+
c.mu.RUnlock()
416+
417+
for _, topic := range topicsCopy {
382418
peerChan, err := routingDiscovery.FindPeers(ctx, topic)
383419
if err != nil {
384420
continue

example/main.go

Lines changed: 50 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"os"
88
"os/signal"
9+
"strings"
910
"syscall"
1011
"time"
1112

@@ -14,11 +15,11 @@ import (
1415
p2p "github.com/ordishs/p2p_poc"
1516
)
1617

17-
const topicName = "broadcast_p2p_poc"
18-
1918
func main() {
2019
name := flag.String("name", "", "Your node name")
2120
privateKey := flag.String("key", "", "Private key hex (will generate if not provided)")
21+
topics := flag.String("topics", "broadcast_p2p_poc", "Comma-separated list of topics to subscribe to")
22+
noBroadcast := flag.Bool("no-broadcast", false, "Disable message broadcasting")
2223

2324
flag.Parse()
2425

@@ -71,37 +72,62 @@ func main() {
7172
}
7273
defer client.Close()
7374

74-
// Subscribe to topic
75-
msgChan := client.Subscribe(topicName)
75+
// Parse topics list
76+
topicList := strings.Split(*topics, ",")
77+
for i, t := range topicList {
78+
topicList[i] = strings.TrimSpace(t)
79+
}
80+
81+
logger.Infof("Subscribing to topics: %v", topicList)
82+
83+
// Subscribe to all topics and merge messages into single channel
84+
allMsgChan := make(chan p2p.Message, 100)
85+
86+
for _, topic := range topicList {
87+
msgChan := client.Subscribe(topic)
88+
logger.Infof("Subscribed to topic: %s", topic)
89+
90+
go func(ch <-chan p2p.Message) {
91+
for msg := range ch {
92+
allMsgChan <- msg
93+
}
94+
}(msgChan)
95+
96+
logger.Infof("Subscribed to topic: %s", topic)
97+
}
7698

7799
// Start message receiver
78100
go func() {
79-
for msg := range msgChan {
80-
fmt.Printf("[%-20s] %s: %s\n", msg.FromID, msg.From, string(msg.Data))
101+
for msg := range allMsgChan {
102+
fmt.Printf("[%-52s] %s: %s (topic: %s)\n", msg.FromID, msg.From, string(msg.Data), msg.Topic)
81103
}
82104
}()
83105

84-
// Start message broadcaster
85-
go func() {
86-
counter := 0
87-
ticker := time.NewTicker(1 * time.Second)
88-
defer ticker.Stop()
106+
// Start message broadcaster (publishes to all topics)
107+
if !*noBroadcast {
108+
go func() {
109+
counter := 0
110+
ticker := time.NewTicker(1 * time.Second)
111+
defer ticker.Stop()
89112

90-
for {
91-
select {
92-
case <-ctx.Done():
93-
return
94-
case <-ticker.C:
95-
counter++
96-
data := fmt.Sprintf("Message #%d", counter)
97-
98-
if err := client.Publish(topicName, []byte(data)); err != nil {
113+
for {
114+
select {
115+
case <-ctx.Done():
99116
return
117+
case <-ticker.C:
118+
counter++
119+
data := fmt.Sprintf("Message #%d", counter)
120+
121+
for _, topic := range topicList {
122+
if err := client.Publish(ctx, topic, []byte(data)); err != nil {
123+
return
124+
}
125+
}
126+
fmt.Printf("[%-52s] %s: %s\n", "local", *name, data)
100127
}
101-
fmt.Printf("[%-52s] %s: %s\n", "local", *name, data)
102128
}
103-
}
104-
}()
129+
}()
130+
}
105131

106132
// Periodically display peer information
107133
go func() {
@@ -117,7 +143,7 @@ func main() {
117143
if len(peers) > 0 {
118144
fmt.Printf("\n=== Connected Peers: %d ===\n", len(peers))
119145
for _, peer := range peers {
120-
fmt.Printf(" - %s [%s]\n", peer.Name, peer.ID[:16])
146+
fmt.Printf(" - %s [%s]\n", peer.Name, peer.ID)
121147
for _, addr := range peer.Addrs {
122148
fmt.Printf(" %s\n", addr)
123149
}

0 commit comments

Comments
 (0)