Skip to content

Commit 24d6886

Browse files
Thomas Stromberg
authored and committed
reliability improvements
1 parent 896b678 commit 24d6886

File tree

3 files changed

+62
-33
lines changed

3 files changed

+62
-33
lines changed

pkg/client/client.go

Lines changed: 18 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,7 @@ type Config struct {
5757
OnConnect func()
5858
ServerURL string
5959
Token string
60+
TokenProvider func() (string, error) // Optional: dynamically provide fresh tokens for reconnection
6061
Organization string
6162
EventTypes []string
6263
PullRequests []string
@@ -98,8 +99,8 @@ func New(config Config) (*Client, error) {
9899
if config.Organization == "" && len(config.PullRequests) == 0 {
99100
return nil, errors.New("organization or pull requests required")
100101
}
101-
if config.Token == "" {
102-
return nil, errors.New("token is required")
102+
if config.Token == "" && config.TokenProvider == nil {
103+
return nil, errors.New("token or tokenProvider is required")
103104
}
104105

105106
// Set defaults
@@ -237,6 +238,17 @@ func (c *Client) Stop() {
237238
func (c *Client) connect(ctx context.Context) error {
238239
c.logger.Info("Establishing WebSocket connection")
239240

241+
// Get fresh token if TokenProvider is configured
242+
token := c.config.Token
243+
if c.config.TokenProvider != nil {
244+
t, err := c.config.TokenProvider()
245+
if err != nil {
246+
return fmt.Errorf("token provider: %w", err)
247+
}
248+
token = t
249+
c.logger.Debug("Using fresh token from TokenProvider")
250+
}
251+
240252
// Create WebSocket config with appropriate origin
241253
origin := "http://localhost/"
242254
if strings.HasPrefix(c.config.ServerURL, "wss://") {
@@ -249,7 +261,7 @@ func (c *Client) connect(ctx context.Context) error {
249261

250262
// Add Authorization header
251263
wsConfig.Header = make(map[string][]string)
252-
wsConfig.Header["Authorization"] = []string{fmt.Sprintf("Bearer %s", c.config.Token)}
264+
wsConfig.Header["Authorization"] = []string{fmt.Sprintf("Bearer %s", token)}
253265

254266
// Dial the server
255267
ws, err := websocket.DialConfig(wsConfig)
@@ -361,10 +373,10 @@ func (c *Client) connect(ctx context.Context) error {
361373
c.logger.Error("SUBSCRIPTION REJECTED BY SERVER!", "error_code", errorCode, "message", message)
362374
c.logger.Error(separatorLine)
363375

364-
// Return AuthenticationError for access denied errors to prevent retries
365-
if errorCode == "access_denied" {
376+
// Return AuthenticationError for authentication/authorization errors to prevent retries
377+
if errorCode == "access_denied" || errorCode == "authentication_failed" {
366378
return &AuthenticationError{
367-
message: fmt.Sprintf("Access denied: %s", message),
379+
message: fmt.Sprintf("Authentication/authorization failed: %s", message),
368380
}
369381
}
370382

@@ -537,11 +549,6 @@ func (c *Client) readEvents(ctx context.Context, ws *websocket.Conn) error {
537549
return fmt.Errorf("read: %w", err)
538550
}
539551

540-
// Clear read timeout after successful read
541-
if err := ws.SetReadDeadline(time.Time{}); err != nil {
542-
c.logger.Warn("Failed to clear read timeout", "error", err)
543-
}
544-
545552
// Check message type
546553
responseType, ok := response[msgTypeField].(string)
547554
if !ok {

pkg/srv/client.go

Lines changed: 16 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@ import (
2020
type Client struct {
2121
conn *websocket.Conn
2222
send chan Event
23+
control chan map[string]any // Control messages (pongs, shutdown notices)
2324
hub *Hub
2425
done chan struct{}
2526
userOrgs map[string]bool
@@ -49,7 +50,8 @@ func NewClient(id string, sub Subscription, conn *websocket.Conn, hub *Hub, user
4950
ID: id,
5051
subscription: sub,
5152
conn: conn,
52-
send: make(chan Event, 100), // Increased buffer to reduce dropped messages
53+
send: make(chan Event, 100), // Increased buffer to reduce dropped messages
54+
control: make(chan map[string]any, 5), // Buffer for control messages (pongs, shutdown)
5355
hub: hub,
5456
done: make(chan struct{}),
5557
userOrgs: orgsMap,
@@ -98,6 +100,18 @@ func (c *Client) Run(ctx context.Context, pingInterval, writeTimeout time.Durati
98100
return
99101
}
100102

103+
case ctrl, ok := <-c.control:
104+
if !ok {
105+
log.Printf("client %s: control channel closed", c.ID)
106+
return
107+
}
108+
109+
// Send control message (pong, shutdown notice, etc.)
110+
if err := c.write(ctrl, writeTimeout); err != nil {
111+
log.Printf("client %s: control message send failed: %v", c.ID, err)
112+
return
113+
}
114+
101115
case event, ok := <-c.send:
102116
if !ok {
103117
log.Printf("client %s: send channel closed", c.ID)
@@ -133,5 +147,6 @@ func (c *Client) Close() {
133147
c.closeOnce.Do(func() {
134148
close(c.done)
135149
close(c.send)
150+
close(c.control)
136151
})
137152
}

pkg/srv/websocket.go

Lines changed: 28 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -233,6 +233,10 @@ func sendErrorResponse(ws *websocket.Conn, errInfo errorInfo, ip string) error {
233233
return err
234234
}
235235

236+
// Allow time for client to receive error before connection closes.
237+
// Without this delay, TCP close can race with message delivery, causing clients to see EOF.
238+
time.Sleep(100 * time.Millisecond)
239+
236240
return nil
237241
}
238242

@@ -418,20 +422,20 @@ func (h *WebSocketHandler) validateAuth(ctx context.Context, ws *websocket.Conn,
418422
}
419423

420424
// closeWebSocket gracefully closes a WebSocket connection with cleanup.
421-
func closeWebSocket(ws *websocket.Conn) {
425+
// If client is provided, shutdown message is sent via control channel to avoid race.
426+
func closeWebSocket(ws *websocket.Conn, client *Client) {
422427
clientIP := security.ClientIP(ws.Request())
423428
log.Printf("WebSocket Handle() cleanup - closing connection for IP %s", clientIP)
424429

425-
// Send a final shutdown message to allow graceful client disconnect
426-
shutdownMsg := map[string]string{"type": "server_closing", "code": "1001"}
427-
if err := ws.SetWriteDeadline(time.Now().Add(200 * time.Millisecond)); err != nil {
428-
log.Printf("failed to set write deadline for shutdown message: %v", err)
429-
}
430-
if err := websocket.JSON.Send(ws, shutdownMsg); err != nil {
431-
// Expected during abrupt disconnection - don't log common cases
432-
if !strings.Contains(err.Error(), "use of closed network connection") &&
433-
!strings.Contains(err.Error(), "broken pipe") {
434-
log.Printf("failed to send shutdown message: %v", err)
430+
// Send shutdown message via control channel if client exists
431+
if client != nil {
432+
shutdownMsg := map[string]any{"type": "server_closing", "code": "1001"}
433+
select {
434+
case client.control <- shutdownMsg:
435+
// Give brief time for shutdown message to be sent
436+
time.Sleep(100 * time.Millisecond)
437+
case <-time.After(200 * time.Millisecond):
438+
log.Printf("Timeout sending shutdown message to client %s", client.ID)
435439
}
436440
}
437441

@@ -460,8 +464,11 @@ func (h *WebSocketHandler) Handle(ws *websocket.Conn) {
460464
ctx, cancel := context.WithCancel(ws.Request().Context())
461465
defer cancel()
462466

463-
// Ensure WebSocket is properly closed
464-
defer closeWebSocket(ws)
467+
// Ensure WebSocket is properly closed (client will be set later if connection succeeds)
468+
var client *Client
469+
defer func() {
470+
closeWebSocket(ws, client)
471+
}()
465472

466473
// Get client IP
467474
ip := security.ClientIP(ws.Request())
@@ -645,7 +652,7 @@ func (h *WebSocketHandler) Handle(ws *websocket.Conn) {
645652
}
646653
id[i] = charset[n.Int64()]
647654
}
648-
client := NewClient(
655+
client = NewClient(
649656
string(id),
650657
sub,
651658
ws,
@@ -765,17 +772,17 @@ func (h *WebSocketHandler) Handle(ws *websocket.Conn) {
765772
// (which happens for ANY message including pong) keeps the connection alive
766773
continue
767774
case "ping":
768-
// Client sent us a ping, send pong back
775+
// Client sent us a ping, send pong back via control channel to avoid race
769776
pong := map[string]any{"type": "pong"}
770777
if seq, ok := msgMap["seq"]; ok {
771778
pong["seq"] = seq
772779
}
773-
if err := ws.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
774-
log.Printf("failed to set write deadline for pong to client %s: %v", client.ID, err)
775-
continue
776-
}
777-
if err := websocket.JSON.Send(ws, pong); err != nil {
778-
log.Printf("failed to send pong to client %s: %v", client.ID, err)
780+
// Non-blocking send to avoid deadlock if control channel is full
781+
select {
782+
case client.control <- pong:
783+
// Pong queued successfully
784+
default:
785+
log.Printf("WARNING: client %s control channel full, dropping pong", client.ID)
779786
}
780787
continue
781788
case "keepalive", "heartbeat":

0 commit comments

Comments
 (0)