diff --git a/cmd/riverui/auth_middleware.go b/cmd/riverui/auth_middleware.go index dddf634c..f3b3f53b 100644 --- a/cmd/riverui/auth_middleware.go +++ b/cmd/riverui/auth_middleware.go @@ -7,13 +7,14 @@ import ( ) type authMiddleware struct { - username string - password string + password string + pathPrefix string // HTTP path prefix + username string } func (m *authMiddleware) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { - if isReqAuthorized(req, m.username, m.password) { + if m.isReqAuthorized(req) { next.ServeHTTP(res, req) return } @@ -23,13 +24,13 @@ func (m *authMiddleware) Middleware(next http.Handler) http.Handler { }) } -func isReqAuthorized(req *http.Request, username, password string) bool { - reqUsername, reqPassword, ok := req.BasicAuth() - - isHealthCheck := strings.Contains(req.URL.Path, "/api/health-checks/") - isValidAuth := ok && - subtle.ConstantTimeCompare([]byte(reqUsername), []byte(username)) == 1 && - subtle.ConstantTimeCompare([]byte(reqPassword), []byte(password)) == 1 +func (m *authMiddleware) isReqAuthorized(req *http.Request) bool { + if strings.HasPrefix(req.URL.Path, m.pathPrefix+"/api/health-checks/") { + return true + } - return isHealthCheck || isValidAuth + reqUsername, reqPassword, ok := req.BasicAuth() + return ok && + subtle.ConstantTimeCompare([]byte(reqUsername), []byte(m.username)) == 1 && + subtle.ConstantTimeCompare([]byte(reqPassword), []byte(m.password)) == 1 } diff --git a/cmd/riverui/auth_middleware_test.go b/cmd/riverui/auth_middleware_test.go index 392b1ddb..b55e6211 100644 --- a/cmd/riverui/auth_middleware_test.go +++ b/cmd/riverui/auth_middleware_test.go @@ -1,79 +1,91 @@ package main import ( - "cmp" "context" "net/http" "net/http/httptest" - "os" "testing" "github.com/stretchr/testify/require" - "github.com/riverqueue/river/rivershared/riversharedtest" + "github.com/riverqueue/apiframe/apimiddleware" ) func TestAuthMiddleware(t *testing.T) { - var ( - ctx = context.Background() - databaseURL = cmp.Or(os.Getenv("TEST_DATABASE_URL"), "postgres://localhost/river_test") - basicAuthUser = "test_auth_user" + t.Parallel() + + const ( basicAuthPassword = "test_auth_pass" + basicAuthUsername = "test_auth_user" ) - t.Setenv("DEV", "true") - t.Setenv("DATABASE_URL", databaseURL) - t.Setenv("RIVER_BASIC_AUTH_USER", basicAuthUser) - t.Setenv("RIVER_BASIC_AUTH_PASS", basicAuthPassword) + ctx := context.Background() + + type testBundle struct { + handler http.Handler + } - setup := func(t *testing.T, prefix string) http.Handler { + setup := func(t *testing.T) (*authMiddleware, *testBundle) { t.Helper() - initRes, err := initServer(ctx, riversharedtest.Logger(t), prefix) - require.NoError(t, err) - t.Cleanup(initRes.dbPool.Close) - return initRes.httpServer.Handler + authMiddleware := &authMiddleware{username: basicAuthUsername, password: basicAuthPassword} + + return authMiddleware, &testBundle{ + handler: apimiddleware.NewMiddlewareStack( + authMiddleware, + ).Mount(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })), + } } - t.Run("Unauthorized", func(t *testing.T) { //nolint:paralleltest - handler := setup(t, "/") - req := httptest.NewRequest(http.MethodGet, "/api/jobs", nil) - recorder := httptest.NewRecorder() + t.Run("Unauthorized", func(t *testing.T) { + t.Parallel() + + _, bundle := setup(t) - handler.ServeHTTP(recorder, req) + req := httptest.NewRequestWithContext(ctx, http.MethodGet, "/api/jobs", nil) + recorder := httptest.NewRecorder() + bundle.handler.ServeHTTP(recorder, req) require.Equal(t, http.StatusUnauthorized, recorder.Code) }) t.Run("Authorized", func(t *testing.T) { //nolint:paralleltest - handler := setup(t, "/") - req := httptest.NewRequest(http.MethodGet, "/api/jobs", nil) - req.SetBasicAuth(basicAuthUser, basicAuthPassword) + t.Parallel() - recorder := httptest.NewRecorder() + _, bundle := setup(t) - handler.ServeHTTP(recorder, req) + req := httptest.NewRequestWithContext(ctx, http.MethodGet, "/api/jobs", nil) + req.SetBasicAuth(basicAuthUsername, basicAuthPassword) + recorder := httptest.NewRecorder() + bundle.handler.ServeHTTP(recorder, req) require.Equal(t, http.StatusOK, recorder.Code) }) - t.Run("Healthcheck exemption", func(t *testing.T) { //nolint:paralleltest - handler := setup(t, "/") - req := httptest.NewRequest(http.MethodGet, "/api/health-checks/complete", nil) - recorder := httptest.NewRecorder() + t.Run("HealthCheckExemption", func(t *testing.T) { //nolint:paralleltest + t.Parallel() - handler.ServeHTTP(recorder, req) + _, bundle := setup(t) + req := httptest.NewRequestWithContext(ctx, http.MethodGet, "/api/health-checks/complete", nil) + + recorder := httptest.NewRecorder() + bundle.handler.ServeHTTP(recorder, req) require.Equal(t, http.StatusOK, recorder.Code) }) - t.Run("Healthcheck exemption with prefix", func(t *testing.T) { //nolint:paralleltest - handler := setup(t, "/test-prefix") - req := httptest.NewRequest(http.MethodGet, "/test-prefix/api/health-checks/complete", nil) - recorder := httptest.NewRecorder() + t.Run("HealthCheckExemptionWithPrefix", func(t *testing.T) { //nolint:paralleltest + t.Parallel() + + middleware, bundle := setup(t) + middleware.pathPrefix = "/test-prefix" - handler.ServeHTTP(recorder, req) + req := httptest.NewRequestWithContext(ctx, http.MethodGet, "/test-prefix/api/health-checks/complete", nil) + recorder := httptest.NewRecorder() + bundle.handler.ServeHTTP(recorder, req) require.Equal(t, http.StatusOK, recorder.Code) }) } diff --git a/cmd/riverui/main.go b/cmd/riverui/main.go index 42b83d25..5df149dd 100644 --- a/cmd/riverui/main.go +++ b/cmd/riverui/main.go @@ -163,7 +163,7 @@ func initServer(ctx context.Context, logger *slog.Logger, pathPrefix string) (*i apimiddleware.MiddlewareFunc(logHandler), ) if basicAuthUsername != "" && basicAuthPassword != "" { - middlewareStack.Use(&authMiddleware{username: basicAuthUsername, password: basicAuthPassword}) + middlewareStack.Use(&authMiddleware{username: basicAuthUsername, password: basicAuthPassword, pathPrefix: pathPrefix}) } return &initServerResult{