Skip to content

Commit 15906a9

Browse files
committed
Backport Connections method to GetChannel
Signed-off-by: Lorenzo <lorenzo.donini90@gmail.com>
1 parent 1d8ffca commit 15906a9

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

ws/server.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,10 @@ type Server interface {
117117
// Addr gives the address on which the server is listening, useful if, for
118118
// example, the port is system-defined (set to 0).
119119
Addr() *net.TCPAddr
120+
// GetChannel retrieves an active Channel connection by its unique identifier.
121+
// If a connection with the given ID exists, it returns the corresponding webSocket instance.
122+
// If no connection is found with the specified ID, it returns nil and a false flag.
123+
GetChannel(websocketId string) (Channel, bool)
120124
}
121125

122126
// Default implementation of a Websocket server.
@@ -304,6 +308,13 @@ func (s *server) StopConnection(id string, closeError websocket.CloseError) erro
304308
return w.Close(closeError)
305309
}
306310

311+
func (s *server) GetChannel(websocketId string) (Channel, bool) {
312+
s.connMutex.RLock()
313+
defer s.connMutex.RUnlock()
314+
c, ok := s.connections[websocketId]
315+
return c, ok
316+
}
317+
307318
func (s *server) stopConnections() {
308319
s.connMutex.RLock()
309320
defer s.connMutex.RUnlock()

ws/websocket_test.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,7 @@ func (s *WebSocketSuite) TestServerStopConnection() {
486486
Code: websocket.CloseGoingAway,
487487
Text: "CloseClientConnection",
488488
}
489+
wsID := "testws"
489490
s.server = newWebsocketServer(s.T(), nil)
490491
s.server.SetNewClientHandler(func(ws Channel) {
491492
triggerC <- struct{}{}
@@ -508,14 +509,25 @@ func (s *WebSocketSuite) TestServerStopConnection() {
508509
// Start server
509510
go s.server.Start(serverPort, serverPath)
510511
time.Sleep(100 * time.Millisecond)
512+
var c Channel
513+
var ok bool
514+
c, ok = s.server.GetChannel(wsID)
515+
s.False(ok)
516+
s.Nil(c)
511517
// Connect client
512518
host := fmt.Sprintf("localhost:%v", serverPort)
513519
u := url.URL{Scheme: "ws", Host: host, Path: testPath}
514520
err := s.client.Start(u.String())
515521
s.NoError(err)
516522
// Wait for client to connect
517-
_, ok := <-triggerC
523+
_, ok = <-triggerC
524+
s.True(ok)
525+
// Verify channel
526+
c, ok = s.server.GetChannel(wsID)
518527
s.True(ok)
528+
s.NotNil(c)
529+
s.Equal(wsID, c.ID())
530+
s.True(c.IsConnected())
519531
// Close connection and wait for client to be closed
520532
err = s.server.StopConnection(path.Base(testPath), closeError)
521533
s.NoError(err)

0 commit comments

Comments
 (0)