Skip to content

Commit f4e6be2

Browse files
authored
Merge pull request #335 from gatewayd-io/shutdown-metrics-server-gracefully
Shutdown metrics server gracefully
2 parents cb97fbe + 2b6f503 commit f4e6be2

File tree

9 files changed

+78
-13
lines changed

9 files changed

+78
-13
lines changed

cmd/run.go

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package cmd
33
import (
44
"context"
55
"crypto/tls"
6+
"errors"
67
"fmt"
78
"log"
89
"net/http"
@@ -54,6 +55,7 @@ var (
5455
globalConfigFile string
5556
conf *config.Config
5657
pluginRegistry *plugin.Registry
58+
metricsServer *http.Server
5759

5860
UsageReportURL = "localhost:59091"
5961

@@ -72,6 +74,7 @@ func StopGracefully(
7274
pluginTimeoutCtx context.Context,
7375
sig os.Signal,
7476
metricsMerger *metrics.Merger,
77+
metricsServer *http.Server,
7578
pluginRegistry *plugin.Registry,
7679
logger zerolog.Logger,
7780
servers map[string]*network.Server,
@@ -110,6 +113,16 @@ func StopGracefully(
110113
logger.Info().Msg("Stopped metrics merger")
111114
span.AddEvent("Stopped metrics merger")
112115
}
116+
if metricsServer != nil {
117+
//nolint:contextcheck
118+
if err := metricsServer.Shutdown(context.Background()); err != nil {
119+
logger.Error().Err(err).Msg("Failed to stop metrics server")
120+
span.RecordError(err)
121+
} else {
122+
logger.Info().Msg("Stopped metrics server")
123+
span.AddEvent("Stopped metrics server")
124+
}
125+
}
113126
for name, server := range servers {
114127
logger.Info().Str("name", name).Msg("Stopping server")
115128
server.Shutdown() //nolint:contextcheck
@@ -352,7 +365,14 @@ var runCmd = &cobra.Command{
352365
span.RecordError(err)
353366
sentry.CaptureException(err)
354367
}
355-
next.ServeHTTP(responseWriter, request)
368+
// The WriteHeader method intentionally does nothing, to prevent a bug
369+
// in the merging metrics that causes the headers to be written twice,
370+
// which results in an error: "http: superfluous response.WriteHeader call".
371+
next.ServeHTTP(
372+
&metrics.HeaderBypassResponseWriter{
373+
ResponseWriter: responseWriter,
374+
},
375+
request)
356376
}
357377
return http.HandlerFunc(handler)
358378
}
@@ -371,6 +391,7 @@ var runCmd = &cobra.Command{
371391
if conf.Plugin.EnableMetricsMerger && metricsMerger != nil {
372392
handler = mergedMetricsHandler(handler)
373393
}
394+
374395
// Check if the metrics server is already running before registering the handler.
375396
if _, err = http.Get(address); err != nil { //nolint:gosec
376397
http.Handle(metricsConfig.Path, gziphandler.GzipHandler(handler))
@@ -379,16 +400,21 @@ var runCmd = &cobra.Command{
379400
span.RecordError(err)
380401
}
381402

382-
//nolint:gosec
383-
if err = http.ListenAndServe(
384-
metricsConfig.Address, nil); err != nil {
403+
// Create a new metrics server.
404+
metricsServer = &http.Server{
405+
Addr: metricsConfig.Address,
406+
Handler: handler,
407+
ReadHeaderTimeout: metricsConfig.GetReadHeaderTimeout(),
408+
}
409+
410+
// Start the metrics server.
411+
if err = metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
385412
logger.Error().Err(err).Msg("Failed to start metrics server")
386413
span.RecordError(err)
387414
}
388415
}(conf.Global.Metrics[config.Default], logger)
389416

390417
// This is a notification hook, so we don't care about the result.
391-
// TODO: Use a context with a timeout
392418
if data, ok := conf.GlobalKoanf.Get("loggers").(map[string]interface{}); ok {
393419
_, err = pluginRegistry.Run(
394420
pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_LOGGER)
@@ -723,6 +749,7 @@ var runCmd = &cobra.Command{
723749
pluginTimeoutCtx,
724750
sig,
725751
metricsMerger,
752+
metricsServer,
726753
pluginRegistry,
727754
logger,
728755
servers,

cmd/run_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ func Test_runCmd(t *testing.T) {
3434
nil,
3535
nil,
3636
nil,
37+
nil,
3738
loggers[config.Default],
3839
servers,
3940
stopChan,
@@ -115,6 +116,7 @@ func Test_runCmdWithCachePlugin(t *testing.T) {
115116
nil,
116117
nil,
117118
nil,
119+
nil,
118120
loggers[config.Default],
119121
servers,
120122
stopChan,

config/config.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,10 @@ func (c *Config) LoadDefaults(ctx context.Context) {
101101
}
102102

103103
defaultMetric := Metrics{
104-
Enabled: true,
105-
Address: DefaultMetricsAddress,
106-
Path: DefaultMetricsPath,
104+
Enabled: true,
105+
Address: DefaultMetricsAddress,
106+
Path: DefaultMetricsPath,
107+
ReadHeaderTimeout: DefaultReadHeaderTimeout,
107108
}
108109

109110
defaultClient := Client{

config/constants.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,9 @@ const (
123123
ChecksumBufferSize = 65536
124124

125125
// Metrics constants.
126-
DefaultMetricsAddress = "localhost:9090"
127-
DefaultMetricsPath = "/metrics"
126+
DefaultMetricsAddress = "localhost:9090"
127+
DefaultMetricsPath = "/metrics"
128+
DefaultReadHeaderTimeout = 10 * time.Second
128129

129130
// Sentry constants.
130131
DefaultTraceSampleRate = 0.2

config/getters.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,3 +269,10 @@ func GetDefaultConfigFilePath(filename string) string {
269269
// The fallback is the current directory.
270270
return filepath.Join("./", filename)
271271
}
272+
273+
func (m Metrics) GetReadHeaderTimeout() time.Duration {
274+
if m.ReadHeaderTimeout <= 0 {
275+
return DefaultReadHeaderTimeout
276+
}
277+
return m.ReadHeaderTimeout
278+
}

config/getters_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,9 @@ func TestGetPlugins(t *testing.T) {
128128
func TestGetDefaultConfigFilePath(t *testing.T) {
129129
assert.Equal(t, GlobalConfigFilename, GetDefaultConfigFilePath(GlobalConfigFilename))
130130
}
131+
132+
// TestGetReadHeaderTimeout tests the GetReadHeaderTimeout function.
133+
func TestGetReadHeaderTimeout(t *testing.T) {
134+
metrics := Metrics{}
135+
assert.Equal(t, DefaultReadHeaderTimeout, metrics.GetReadHeaderTimeout())
136+
}

config/types.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,10 @@ type Logger struct {
6868
}
6969

7070
type Metrics struct {
71-
Enabled bool `json:"enabled"`
72-
Address string `json:"address"`
73-
Path string `json:"path"`
71+
Enabled bool `json:"enabled"`
72+
Address string `json:"address"`
73+
Path string `json:"path"`
74+
ReadHeaderTimeout time.Duration `json:"readHeaderTimeout" jsonschema:"oneof_type=string;integer"`
7475
}
7576

7677
type Pool struct {

gatewayd.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ metrics:
2424
enabled: True
2525
address: localhost:9090
2626
path: /metrics
27+
readHeaderTimeout: 10s # duration, prevents Slowloris attacks
2728

2829
clients:
2930
default:

metrics/utils.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package metrics
2+
3+
import "net/http"
4+
5+
// HeaderBypassResponseWriter implements the http.ResponseWriter interface
6+
// and allows us to bypass the response header when writing to the response.
7+
// This is useful for merging metrics from multiple sources.
8+
type HeaderBypassResponseWriter struct {
9+
http.ResponseWriter
10+
}
11+
12+
// WriteHeader intentionally does nothing, but is required to
13+
// implement the http.ResponseWriter interface.
14+
func (w *HeaderBypassResponseWriter) WriteHeader(int) {}
15+
16+
// Write writes the data to the response.
17+
func (w *HeaderBypassResponseWriter) Write(data []byte) (int, error) {
18+
return w.ResponseWriter.Write(data) //nolint:wrapcheck
19+
}

0 commit comments

Comments
 (0)