Skip to content

Commit fb513d2

Browse files
Thomas StrombergThomas Stromberg
authored andcommitted
fix panic
1 parent 77b9cfe commit fb513d2

File tree

6 files changed

+55
-61
lines changed

6 files changed

+55
-61
lines changed

cmd/server/main.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,10 @@ const (
3232
minMaskHeaderLength = 20 // Minimum header length before we show full "[REDACTED]"
3333
)
3434

35-
// getEnvOrDefault returns the value of the environment variable or the default if not set.
36-
func getEnvOrDefault(key, defaultValue string) string {
37-
if value := os.Getenv(key); value != "" {
38-
return value
39-
}
40-
return defaultValue
41-
}
35+
// contextKey is a custom type for context keys to avoid collisions.
36+
type contextKey string
37+
38+
const reservationTokenKey contextKey = "reservation_token"
4239

4340
var (
4441
webhookSecret = flag.String("webhook-secret", os.Getenv("GITHUB_WEBHOOK_SECRET"), "GitHub webhook secret for signature verification")
@@ -50,8 +47,12 @@ var (
5047
maxConnsPerIP = flag.Int("max-conns-per-ip", 10, "Maximum WebSocket connections per IP")
5148
maxConnsTotal = flag.Int("max-conns-total", 1000, "Maximum total WebSocket connections")
5249
rateLimit = flag.Int("rate-limit", 100, "Maximum requests per minute per IP")
53-
allowedEvents = flag.String("allowed-events", getEnvOrDefault("ALLOWED_WEBHOOK_EVENTS", "*"),
54-
"Comma-separated list of allowed webhook event types (use '*' for all, default: '*')")
50+
allowedEvents = flag.String("allowed-events", func() string {
51+
if value := os.Getenv("ALLOWED_WEBHOOK_EVENTS"); value != "" {
52+
return value
53+
}
54+
return "*"
55+
}(), "Comma-separated list of allowed webhook event types (use '*' for all, default: '*')")
5556
debugHeaders = flag.Bool("debug-headers", false, "Log request headers for debugging (security warning: may log sensitive data)")
5657
)
5758

@@ -244,7 +245,7 @@ func main() {
244245
}
245246

246247
// Set reservation token in request context so websocket handler can commit it
247-
r = r.WithContext(context.WithValue(r.Context(), "reservation_token", reservationToken))
248+
r = r.WithContext(context.WithValue(r.Context(), reservationTokenKey, reservationToken))
248249

249250
// Log successful auth and proceed to upgrade
250251
log.Printf("WebSocket UPGRADE: ip=%s duration=%v", ip, time.Since(startTime))

pkg/client/client.go

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ type Client struct {
8585
ws *websocket.Conn
8686
stopCh chan struct{}
8787
stoppedCh chan struct{}
88-
writeCh chan any // Channel for serializing all writes
88+
stopOnce sync.Once // Ensures Stop() is only executed once
89+
writeCh chan any // Channel for serializing all writes
8990
eventCount int
9091
retries int
9192
}
@@ -220,16 +221,27 @@ func (c *Client) Start(ctx context.Context) error {
220221
}
221222

222223
// Stop gracefully stops the client.
224+
// Safe to call multiple times - only the first call will take effect.
225+
// Also safe to call before Start() or if Start() was never called.
223226
func (c *Client) Stop() {
224-
close(c.stopCh)
225-
c.mu.Lock()
226-
if c.ws != nil {
227-
if closeErr := c.ws.Close(); closeErr != nil {
228-
c.logger.Error("Error closing websocket on shutdown", "error", closeErr)
227+
c.stopOnce.Do(func() {
228+
close(c.stopCh)
229+
c.mu.Lock()
230+
if c.ws != nil {
231+
if closeErr := c.ws.Close(); closeErr != nil {
232+
c.logger.Error("Error closing websocket on shutdown", "error", closeErr)
233+
}
229234
}
230-
}
231-
c.mu.Unlock()
232-
<-c.stoppedCh
235+
c.mu.Unlock()
236+
237+
// Wait for Start() to finish, but with timeout in case Start() was never called
238+
select {
239+
case <-c.stoppedCh:
240+
// Start() completed normally
241+
case <-time.After(100 * time.Millisecond):
242+
// Start() was never called or hasn't started yet - that's ok
243+
}
244+
})
233245
}
234246

235247
// connect establishes a WebSocket connection and handles events.

pkg/security/connlimiter.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ type connectionInfo struct {
2525

2626
// reservation represents a reserved connection slot.
2727
type reservation struct {
28-
ip string
2928
createdAt time.Time
29+
ip string
3030
}
3131

3232
// ConnectionLimiter tracks connections per IP and total.
@@ -185,7 +185,7 @@ func (cl *ConnectionLimiter) CancelReservation(token string) {
185185
}
186186

187187
// CanAdd checks if a connection can be added for the given IP without actually adding it.
188-
// DEPRECATED: This method has a TOCTOU race condition. Use Reserve() instead, which
188+
// Deprecated: This method has a TOCTOU race condition. Use Reserve() instead, which
189189
// atomically checks and reserves a slot, preventing the race.
190190
//
191191
// This method is kept for backward compatibility and testing only.

pkg/security/race_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,9 @@ func TestConnectionLimiterTOCTOU_Documentation(t *testing.T) {
114114
ip := "192.168.1.1"
115115

116116
var wg sync.WaitGroup
117-
var canAddSuccess int32 // How many times CanAdd returned true
118-
var addSuccess int32 // How many times Add actually succeeded
119-
var addFailed int32 // How many times Add failed despite CanAdd=true
117+
var canAddSuccess int32 // How many times CanAdd returned true
118+
var addSuccess int32 // How many times Add actually succeeded
119+
var addFailed int32 // How many times Add failed despite CanAdd=true
120120

121121
// Launch many goroutines simultaneously trying to add connections
122122
// This simulates multiple HTTP handlers racing to add connections

pkg/srv/websocket.go

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -150,14 +150,6 @@ func (h *WebSocketHandler) extractGitHubToken(ws *websocket.Conn, ip string) (st
150150
return githubToken, true
151151
}
152152

153-
// tokenDebugInfo extracts token prefix for debug logging.
154-
func tokenDebugInfo(token string) string {
155-
if len(token) >= tokenPrefixLength {
156-
return token[:tokenPrefixLength]
157-
}
158-
return ""
159-
}
160-
161153
// errorInfo holds error response details.
162154
type errorInfo struct {
163155
code string
@@ -281,7 +273,10 @@ func (*WebSocketHandler) handleAuthError(
281273
logContext string,
282274
) error {
283275
errInfo := determineErrorInfo(err, username, orgName, userOrgs)
284-
tokenPrefix := tokenDebugInfo(githubToken)
276+
tokenPrefix := ""
277+
if len(githubToken) >= tokenPrefixLength {
278+
tokenPrefix = githubToken[:tokenPrefixLength]
279+
}
285280

286281
logger.Error(logContext, err, logger.Fields{
287282
"ip": ip,
@@ -411,11 +406,6 @@ type wsCloser struct {
411406
mu sync.Mutex
412407
}
413408

414-
// newWSCloser creates a new WebSocket closer wrapper.
415-
func newWSCloser(ws *websocket.Conn) *wsCloser {
416-
return &wsCloser{ws: ws}
417-
}
418-
419409
// Close closes the WebSocket connection exactly once.
420410
func (wc *wsCloser) Close() error {
421411
var err error
@@ -491,7 +481,7 @@ func (h *WebSocketHandler) Handle(ws *websocket.Conn) {
491481
log.Printf("WebSocket Handle() got IP: %s", ip)
492482

493483
// Wrap WebSocket with sync.Once closer to prevent double-close
494-
wc := newWSCloser(ws)
484+
wc := &wsCloser{ws: ws}
495485

496486
// Ensure WebSocket is properly closed (client will be set later if connection succeeds)
497487
var client *Client
@@ -508,7 +498,8 @@ func (h *WebSocketHandler) Handle(ws *websocket.Conn) {
508498
})
509499

510500
// Get reservation token from context (set by main.go before upgrade)
511-
reservationToken, _ := ws.Request().Context().Value("reservation_token").(string)
501+
// Context key is a string type for package boundary crossing
502+
reservationToken, _ := ws.Request().Context().Value("reservation_token").(string) //nolint:errcheck // Type assertion intentionally unchecked - empty string is valid default
512503
if reservationToken == "" {
513504
// No reservation token - this should not happen in production
514505
// (main.go always sets it), but handle gracefully for tests

pkg/webhook/handler.go

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
182182
// GitHub webhooks can fire before the pull_requests array is populated
183183
commitSHA := extractCommitSHA(eventType, payload)
184184
// Extract repo URL as fallback for org-based matching
185-
repoURL := extractRepoURL(payload)
185+
repoURL := ""
186+
if repo, ok := payload["repository"].(map[string]any); ok {
187+
if htmlURL, ok := repo["html_url"].(string); ok {
188+
repoURL = htmlURL
189+
}
190+
}
186191

187192
// If we can't extract repo URL, drop the event
188193
if repoURL == "" {
@@ -299,11 +304,15 @@ func ExtractPRURL(eventType string, payload map[string]any) string {
299304
}
300305
}
301306
// Log when we can't extract PR URL from check event
307+
payloadKeys := make([]string, 0, len(payload))
308+
for k := range payload {
309+
payloadKeys = append(payloadKeys, k)
310+
}
302311
logger.Warn("no PR URL found in check event", logger.Fields{
303312
"event_type": eventType,
304313
"has_check_run": payload["check_run"] != nil,
305314
"has_check_suite": payload["check_suite"] != nil,
306-
"payload_keys": getPayloadKeys(payload),
315+
"payload_keys": payloadKeys,
307316
})
308317
default:
309318
// For other event types, no PR URL can be extracted
@@ -378,15 +387,6 @@ func extractPRFromCheckEvent(checkEvent map[string]any, payload map[string]any,
378387
return constructedURL
379388
}
380389

381-
// getPayloadKeys returns the keys from a payload map for logging.
382-
func getPayloadKeys(payload map[string]any) []string {
383-
keys := make([]string, 0, len(payload))
384-
for k := range payload {
385-
keys = append(keys, k)
386-
}
387-
return keys
388-
}
389-
390390
// getMapKeys returns the keys from a map for logging.
391391
func getMapKeys(m map[string]any) []string {
392392
keys := make([]string, 0, len(m))
@@ -417,13 +417,3 @@ func extractCommitSHA(eventType string, payload map[string]any) string {
417417
return ""
418418
}
419419

420-
// extractRepoURL extracts the repository HTML URL from the payload.
421-
// This is used as a fallback when PR URL cannot be extracted (e.g., check event race condition).
422-
func extractRepoURL(payload map[string]any) string {
423-
if repo, ok := payload["repository"].(map[string]any); ok {
424-
if htmlURL, ok := repo["html_url"].(string); ok {
425-
return htmlURL
426-
}
427-
}
428-
return ""
429-
}

0 commit comments

Comments
 (0)