Skip to content

Commit 1aac68f

Browse files
authored
Merge pull request #30 from tstromberg/main
fix reservation TOCTOU
2 parents efb533e + 77b9cfe commit 1aac68f

File tree

11 files changed

+1075
-192
lines changed

11 files changed

+1075
-192
lines changed

cmd/server/main.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ const (
3232
minMaskHeaderLength = 20 // Minimum header length before we show full "[REDACTED]"
3333
)
3434

35-
// getEnvOrDefault returns the value of the environment variable or the default if not set
35+
// getEnvOrDefault returns the value of the environment variable or the default if not set.
3636
func getEnvOrDefault(key, defaultValue string) string {
3737
if value := os.Getenv(key); value != "" {
3838
return value
@@ -232,8 +232,9 @@ func main() {
232232
return
233233
}
234234

235-
// Check connection limit before upgrade
236-
if !connLimiter.CanAdd(ip) {
235+
// Reserve a connection slot before upgrade (prevents TOCTOU race condition)
236+
reservationToken := connLimiter.Reserve(ip)
237+
if reservationToken == "" {
237238
log.Printf("WebSocket 429: connection limit ip=%s", ip)
238239
w.WriteHeader(http.StatusTooManyRequests)
239240
if _, err := w.Write([]byte("429 Too Many Requests: Connection limit exceeded\n")); err != nil {
@@ -242,6 +243,9 @@ func main() {
242243
return
243244
}
244245

246+
// Set reservation token in request context so websocket handler can commit it
247+
r = r.WithContext(context.WithValue(r.Context(), "reservation_token", reservationToken))
248+
245249
// Log successful auth and proceed to upgrade
246250
log.Printf("WebSocket UPGRADE: ip=%s duration=%v", ip, time.Since(startTime))
247251

pkg/client/client.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,15 +287,18 @@ func (c *Client) connect(ctx context.Context) error {
287287
}
288288
return fmt.Errorf("dial: %w", err)
289289
}
290-
c.logger.Info("✓ WebSocket connection ESTABLISHED successfully!")
290+
c.logger.Info("========================================")
291+
c.logger.Info(fmt.Sprintf("✅ WebSocket ESTABLISHED: %s (org: %s)", c.config.ServerURL, c.config.Organization))
292+
c.logger.Info("========================================")
291293

292294
// Store connection
293295
c.mu.Lock()
294296
c.ws = ws
295297
c.mu.Unlock()
296298

297299
defer func() {
298-
c.logger.Debug("Closing WebSocket connection")
300+
c.logger.Info("========================================")
301+
c.logger.Info(fmt.Sprintf("❌ WebSocket CLOSING: %s (org: %s)", c.config.ServerURL, c.config.Organization))
299302
c.mu.Lock()
300303
c.ws = nil
301304
c.mu.Unlock()
@@ -304,6 +307,7 @@ func (c *Client) connect(ctx context.Context) error {
304307
} else {
305308
c.logger.Info("✓ WebSocket connection closed cleanly")
306309
}
310+
c.logger.Info("========================================")
307311
}()
308312

309313
// Build subscription

pkg/logger/logger.go

Lines changed: 110 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,92 @@
1-
// Package logger provides structured logging utilities with field support
2-
// for better debugging and monitoring of webhook sprinkler operations.
1+
// Package logger provides structured logging using slog with hostname tracking
2+
// and short source file paths for better debugging across multiple instances.
33
package logger
44

55
import (
6+
"context"
67
"fmt"
7-
"log"
8-
"sort"
9-
"strings"
8+
"io"
9+
"log/slog"
10+
"os"
11+
"path/filepath"
12+
"runtime"
13+
"time"
1014
)
1115

1216
// Fields represents structured log fields.
1317
type Fields map[string]any
1418

15-
// WithFieldsf adds structured context to log messages with printf-style formatting.
16-
func WithFieldsf(fields Fields, format string, args ...any) {
17-
// Sort keys for consistent output
18-
keys := make([]string, 0, len(fields))
19-
for k := range fields {
20-
keys = append(keys, k)
21-
}
22-
sort.Strings(keys)
19+
var (
20+
// defaultLogger is the global logger instance.
21+
defaultLogger *slog.Logger
22+
// hostname is cached on init for performance.
23+
hostname string
24+
)
2325

24-
var parts []string
25-
for _, k := range keys {
26-
parts = append(parts, fmt.Sprintf("%s=%v", k, fields[k]))
26+
func init() {
27+
var err error
28+
hostname, err = os.Hostname()
29+
if err != nil {
30+
hostname = "unknown"
2731
}
2832

29-
msg := fmt.Sprintf(format, args...)
30-
if len(parts) > 0 {
31-
log.Printf("%s [%s]", msg, strings.Join(parts, " "))
32-
} else {
33-
log.Print(msg)
33+
// Initialize with default text handler
34+
defaultLogger = New(os.Stderr)
35+
}
36+
37+
// New creates a new slog logger with hostname and short source paths.
38+
func New(w io.Writer) *slog.Logger {
39+
opts := &slog.HandlerOptions{
40+
AddSource: true,
41+
Level: slog.LevelInfo,
42+
ReplaceAttr: func(_ []string, a slog.Attr) slog.Attr {
43+
// Shorten source file paths to just basename:line
44+
if a.Key == slog.SourceKey {
45+
if source, ok := a.Value.Any().(*slog.Source); ok {
46+
source.File = filepath.Base(source.File)
47+
// Remove function name to keep it concise
48+
source.Function = ""
49+
}
50+
}
51+
return a
52+
},
3453
}
54+
55+
handler := slog.NewTextHandler(w, opts)
56+
logger := slog.New(handler)
57+
58+
// Add hostname to all log messages
59+
return logger.With("instance", hostname)
60+
}
61+
62+
// SetDefault sets the default logger.
63+
func SetDefault(l *slog.Logger) {
64+
defaultLogger = l
65+
}
66+
67+
// SetLogger sets the default logger (alias for SetDefault).
68+
func SetLogger(l *slog.Logger) {
69+
defaultLogger = l
70+
}
71+
72+
// Default returns the default logger.
73+
func Default() *slog.Logger {
74+
return defaultLogger
75+
}
76+
77+
// Hostname returns the cached hostname.
78+
func Hostname() string {
79+
return hostname
3580
}
3681

3782
// Info logs an info message with optional fields.
3883
func Info(msg string, fields Fields) {
39-
WithFieldsf(fields, "%s", msg)
84+
defaultLogger.LogAttrs(context.Background(), slog.LevelInfo, msg, attrsFromFields(fields)...)
85+
}
86+
87+
// Warn logs a warning message with optional fields.
88+
func Warn(msg string, fields Fields) {
89+
defaultLogger.LogAttrs(context.Background(), slog.LevelWarn, msg, attrsFromFields(fields)...)
4090
}
4191

4292
// Error logs an error message with optional fields.
@@ -45,10 +95,44 @@ func Error(msg string, err error, fields Fields) {
4595
fields = Fields{}
4696
}
4797
fields["error"] = err.Error()
48-
WithFieldsf(fields, "ERROR: %s", msg)
98+
defaultLogger.LogAttrs(context.Background(), slog.LevelError, msg, attrsFromFields(fields)...)
4999
}
50100

51-
// Warn logs a warning message with optional fields.
52-
func Warn(msg string, fields Fields) {
53-
WithFieldsf(fields, "WARNING: %s", msg)
101+
// Debug logs a debug message with optional fields.
102+
func Debug(msg string, fields Fields) {
103+
defaultLogger.LogAttrs(context.Background(), slog.LevelDebug, msg, attrsFromFields(fields)...)
104+
}
105+
106+
// attrsFromFields converts Fields to slog.Attr slice.
107+
func attrsFromFields(fields Fields) []slog.Attr {
108+
if fields == nil {
109+
return nil
110+
}
111+
attrs := make([]slog.Attr, 0, len(fields))
112+
for k, v := range fields {
113+
attrs = append(attrs, slog.Any(k, v))
114+
}
115+
return attrs
116+
}
117+
118+
// LogAt logs a message at the specified level with source information.
119+
// This is useful when you want to override the default source location.
120+
func LogAt(level slog.Level, skip int, msg string, fields Fields) {
121+
var pcs [1]uintptr
122+
runtime.Callers(skip+2, pcs[:])
123+
r := slog.NewRecord(
124+
time.Now(),
125+
level,
126+
msg,
127+
pcs[0],
128+
)
129+
r.AddAttrs(attrsFromFields(fields)...)
130+
_ = defaultLogger.Handler().Handle(context.Background(), r) //nolint:errcheck // Best effort logging
131+
}
132+
133+
// WithFieldsf provides backward compatibility for tests.
134+
// Deprecated: Use Info/Warn/Error with Fields instead.
135+
func WithFieldsf(fields Fields, format string, args ...any) {
136+
msg := fmt.Sprintf(format, args...)
137+
Info(msg, fields)
54138
}

0 commit comments

Comments
 (0)