diff --git a/access_key.go b/access_key.go index 15cd072..99a23ec 100644 --- a/access_key.go +++ b/access_key.go @@ -4,11 +4,13 @@ import ( "context" "errors" "fmt" + "net/http" "strings" "crypto/rand" "encoding/binary" + "github.com/go-chi/transport" "github.com/goware/base64" "github.com/jxskiss/base62" ) @@ -22,16 +24,10 @@ var ( ErrInvalidKeyLength = errors.New("invalid access key length") ) -type AccessKey string - -func (a AccessKey) String() string { - return string(a) -} - -func (a AccessKey) GetProjectID() (projectID uint64, err error) { +func GetProjectIDFromAccessKey(accessKey string) (projectID uint64, err error) { var errs []error for _, e := range SupportedEncodings { - projectID, err := e.Decode(a) + projectID, err := e.Decode(accessKey) if err != nil { errs = append(errs, fmt.Errorf("decode v%d: %w", e.Version(), err)) continue @@ -41,34 +37,44 @@ func (a AccessKey) GetProjectID() (projectID uint64, err error) { return 0, errors.Join(errs...) } -func (a AccessKey) GetPrefix() string { - parts := strings.Split(a.String(), Separator) - if len(parts) < 2 { - return "" - } - return strings.Join(parts[:len(parts)-1], Separator) -} - -var ErrUnsupportedEncoding = errors.New("unsupported access key encoding") - -func GenerateAccessKey(ctx context.Context, projectID uint64) (AccessKey, error) { +func GenerateAccessKey(ctx context.Context, projectID uint64) string { version, ok := GetVersion(ctx) if !ok { - return DefaultEncoding.Encode(ctx, projectID), nil + return DefaultEncoding.Encode(ctx, projectID) } for _, e := range SupportedEncodings { if e.Version() == version { - return e.Encode(ctx, projectID), nil + return e.Encode(ctx, projectID) } } - return "", ErrUnsupportedEncoding + return "" +} + +func GetAccessKeyPrefix(accessKey string) string { + parts := strings.Split(accessKey, Separator) + if len(parts) < 2 { + return "" + } + return strings.Join(parts[:len(parts)-1], Separator) +} + +func ForwardAccessKeyTransport(next http.RoundTripper) http.RoundTripper { + return transport.RoundTripFunc(func(req *http.Request) (resp *http.Response, err error) { + r := transport.CloneRequest(req) + + if accessKey, ok := GetAccessKey(req.Context()); ok { + r.Header.Set(HeaderAccessKey, accessKey) + } + + return next.RoundTrip(r) + }) } type Encoding interface { Version() byte - Encode(ctx context.Context, projectID uint64) AccessKey - Decode(accessKey AccessKey) (projectID uint64, err error) + Encode(ctx context.Context, projectID uint64) string + Decode(accessKey string) (projectID uint64, err error) } const ( @@ -83,15 +89,15 @@ type V0 struct{} func (V0) Version() byte { return 0 } -func (V0) Encode(_ context.Context, projectID uint64) AccessKey { +func (V0) Encode(_ context.Context, projectID uint64) string { buf := make([]byte, sizeV0) binary.BigEndian.PutUint64(buf, projectID) _, _ = rand.Read(buf[8:]) - return AccessKey(base62.EncodeToString(buf)) + return base62.EncodeToString(buf) } -func (V0) Decode(accessKey AccessKey) (projectID uint64, err error) { - buf, err := base62.DecodeString(accessKey.String()) +func (V0) Decode(accessKey string) (projectID uint64, err error) { + buf, err := base62.DecodeString(accessKey) if err != nil { return 0, fmt.Errorf("base62 decode: %w", err) } @@ -107,16 +113,16 @@ type V1 struct{} func (V1) Version() byte { return 1 } -func (v V1) Encode(_ context.Context, projectID uint64) AccessKey { +func (v V1) Encode(_ context.Context, projectID uint64) string { buf := make([]byte, sizeV1) buf[0] = v.Version() binary.BigEndian.PutUint64(buf[1:], projectID) _, _ = rand.Read(buf[9:]) - return AccessKey(base64.Base64UrlEncode(buf)) + return base64.Base64UrlEncode(buf) } -func (V1) Decode(accessKey AccessKey) (projectID uint64, err error) { - buf, err := base64.Base64UrlDecode(accessKey.String()) +func (V1) Decode(accessKey string) (projectID uint64, err error) { + buf, err := base64.Base64UrlDecode(accessKey) if err != nil { return 0, fmt.Errorf("base64 decode: %w", err) } @@ -137,19 +143,19 @@ const ( func (V2) Version() byte { return 2 } -func (v V2) Encode(ctx context.Context, projectID uint64) AccessKey { +func (v V2) Encode(ctx context.Context, projectID uint64) string { buf := make([]byte, sizeV2) buf[0] = v.Version() binary.BigEndian.PutUint64(buf[1:], projectID) _, _ = rand.Read(buf[9:]) - return AccessKey(getPrefix(ctx) + Separator + base64.Base64UrlEncode(buf)) + return getPrefix(ctx) + Separator + base64.Base64UrlEncode(buf) } -func (V2) Decode(accessKey AccessKey) (projectID uint64, err error) { - parts := strings.Split(accessKey.String(), Separator) - raw := parts[len(parts)-1] +func (V2) Decode(accessKey string) (projectID uint64, err error) { + parts := strings.Split(accessKey, Separator) + accessKey = parts[len(parts)-1] - buf, err := base64.Base64UrlDecode(raw) + buf, err := base64.Base64UrlDecode(accessKey) if err != nil { return 0, fmt.Errorf("base64 decode: %w", err) } diff --git a/access_key_test.go b/access_key_test.go index 2c27497..4f98249 100644 --- a/access_key_test.go +++ b/access_key_test.go @@ -14,11 +14,10 @@ func TestAccessKeyEncoding(t *testing.T) { t.Run("v0", func(t *testing.T) { ctx := authcontrol.WithVersion(context.Background(), 0) projectID := uint64(12345) - accessKey, err := authcontrol.GenerateAccessKey(ctx, projectID) - require.NoError(t, err) + accessKey := authcontrol.GenerateAccessKey(ctx, projectID) t.Log("=> k", accessKey) - outID, err := accessKey.GetProjectID() + outID, err := authcontrol.GetProjectIDFromAccessKey(accessKey) require.NoError(t, err) require.Equal(t, projectID, outID) }) @@ -26,33 +25,30 @@ func TestAccessKeyEncoding(t *testing.T) { t.Run("v1", func(t *testing.T) { ctx := authcontrol.WithVersion(context.Background(), 1) projectID := uint64(12345) - accessKey, err := authcontrol.GenerateAccessKey(ctx, projectID) - require.NoError(t, err) + accessKey := authcontrol.GenerateAccessKey(ctx, projectID) t.Log("=> k", accessKey) - outID, err := accessKey.GetProjectID() + outID, err := authcontrol.GetProjectIDFromAccessKey(accessKey) require.NoError(t, err) require.Equal(t, projectID, outID) }) t.Run("v2", func(t *testing.T) { ctx := authcontrol.WithVersion(context.Background(), 2) projectID := uint64(12345) - accessKey, err := authcontrol.GenerateAccessKey(ctx, projectID) - require.NoError(t, err) - t.Log("=> k", accessKey, "| prefix =>", accessKey.GetPrefix()) - outID, err := accessKey.GetProjectID() + accessKey := authcontrol.GenerateAccessKey(ctx, projectID) + t.Log("=> k", accessKey, "| prefix =>", authcontrol.GetAccessKeyPrefix(accessKey)) + outID, err := authcontrol.GetProjectIDFromAccessKey(accessKey) require.NoError(t, err) require.Equal(t, projectID, outID) ctx = authcontrol.WithPrefix(ctx, "newprefix:dev") - accessKey2, err := authcontrol.GenerateAccessKey(ctx, projectID) - require.NoError(t, err) - t.Log("=> k", accessKey2, "| prefix =>", accessKey2.GetPrefix()) - outID, err = accessKey2.GetProjectID() + accessKey2 := authcontrol.GenerateAccessKey(ctx, projectID) + t.Log("=> k", accessKey2, "| prefix =>", authcontrol.GetAccessKeyPrefix(accessKey2)) + outID, err = authcontrol.GetProjectIDFromAccessKey(accessKey2) require.NoError(t, err) require.Equal(t, projectID, outID) // retrocompatibility with the older prefix - outID, err = accessKey.GetProjectID() + outID, err = authcontrol.GetProjectIDFromAccessKey(accessKey) require.NoError(t, err) require.Equal(t, projectID, outID) }) @@ -60,9 +56,8 @@ func TestAccessKeyEncoding(t *testing.T) { func TestDecode(t *testing.T) { ctx := authcontrol.WithVersion(context.Background(), 2) - accessKey, err := authcontrol.GenerateAccessKey(ctx, 237) - require.NoError(t, err) - t.Log("=> k", accessKey, "| prefix =>", accessKey.GetPrefix()) + accessKey := authcontrol.GenerateAccessKey(ctx, 237) + t.Log("=> k", accessKey, "| prefix =>", authcontrol.GetAccessKeyPrefix(accessKey)) } func TestForwardAccessKeyTransport(t *testing.T) { @@ -76,7 +71,7 @@ func TestForwardAccessKeyTransport(t *testing.T) { // Create context with access key accessKey := "test-access-key-123" - ctx := authcontrol.WithAccessKey(context.Background(), authcontrol.AccessKey(accessKey)) + ctx := authcontrol.WithAccessKey(context.Background(), accessKey) // Create HTTP client with ForwardAccessKeyTransport client := &http.Client{ diff --git a/cmd/access_key/main.go b/cmd/access_key/main.go index 9b1b125..152b1d5 100644 --- a/cmd/access_key/main.go +++ b/cmd/access_key/main.go @@ -28,7 +28,7 @@ var decodeCmd = &cobra.Command{ if len(args) != 1 { return fmt.Errorf("access key is required") } - accessKey := authcontrol.AccessKey(args[0]) + accessKey := args[0] var ( projectID uint64 version byte diff --git a/common.go b/common.go index 8d4cebe..1dc3b8a 100644 --- a/common.go +++ b/common.go @@ -14,7 +14,6 @@ import ( "github.com/0xsequence/authcontrol/proto" "github.com/go-chi/jwtauth/v5" - "github.com/go-chi/transport" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwt" ) @@ -23,10 +22,10 @@ const ( HeaderAccessKey = "X-Access-Key" ) -type AccessKeyFunc func(*http.Request) AccessKey +type AccessKeyFunc func(*http.Request) string -func AccessKeyFromHeader(r *http.Request) AccessKey { - return AccessKey(r.Header.Get(HeaderAccessKey)) +func AccessKeyFromHeader(r *http.Request) string { + return r.Header.Get(HeaderAccessKey) } type ErrHandler func(r *http.Request, w http.ResponseWriter, err error) @@ -199,16 +198,3 @@ func findProjectClaim(r *http.Request) (uint64, error) { return 0, fmt.Errorf("invalid type: %T", val) } } - -// ForwardAccessKeyTransport is a RoundTripper that forwards the access key from the request context to the request header. -func ForwardAccessKeyTransport(next http.RoundTripper) http.RoundTripper { - return transport.RoundTripFunc(func(req *http.Request) (resp *http.Response, err error) { - r := transport.CloneRequest(req) - - if accessKey, ok := GetAccessKey(req.Context()); ok { - r.Header.Set(HeaderAccessKey, accessKey.String()) - } - - return next.RoundTrip(r) - }) -} diff --git a/common_test.go b/common_test.go index b456779..17a1e19 100644 --- a/common_test.go +++ b/common_test.go @@ -17,8 +17,8 @@ import ( const HeaderKey = "Test-Key" -func keyFunc(r *http.Request) authcontrol.AccessKey { - return authcontrol.AccessKey(r.Header.Get(HeaderKey)) +func keyFunc(r *http.Request) string { + return r.Header.Get(HeaderKey) } type requestOption func(r *http.Request) diff --git a/context.go b/context.go index da76542..b16a1e2 100644 --- a/context.go +++ b/context.go @@ -102,13 +102,13 @@ func GetService(ctx context.Context) (string, bool) { // WithAccessKey adds the access key to the context. // // TODO: Deprecate this in favor of Session middleware with a JWT token. -func WithAccessKey(ctx context.Context, accessKey AccessKey) context.Context { +func WithAccessKey(ctx context.Context, accessKey string) context.Context { return context.WithValue(ctx, ctxKeyAccessKey, accessKey) } // GetAccessKey returns the access key from the context. -func GetAccessKey(ctx context.Context) (AccessKey, bool) { - v, ok := ctx.Value(ctxKeyAccessKey).(AccessKey) +func GetAccessKey(ctx context.Context) (string, bool) { + v, ok := ctx.Value(ctxKeyAccessKey).(string) return v, ok } diff --git a/go.mod b/go.mod index 36c8dc2..5195ab9 100644 --- a/go.mod +++ b/go.mod @@ -9,10 +9,7 @@ require ( github.com/go-chi/metrics v0.1.0 github.com/go-chi/traceid v0.2.0 github.com/go-chi/transport v0.4.0 - github.com/goware/base64 v0.1.0 - github.com/jxskiss/base62 v1.1.0 github.com/lestrrat-go/jwx/v2 v2.1.3 - github.com/spf13/cobra v1.9.1 github.com/stretchr/testify v1.10.0 ) @@ -23,7 +20,9 @@ require ( github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect github.com/goccy/go-json v0.10.3 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/goware/base64 v0.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jxskiss/base62 v1.1.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/lestrrat-go/blackmagic v1.0.2 // indirect github.com/lestrrat-go/httpcc v1.0.1 // indirect @@ -38,6 +37,7 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/segmentio/asm v1.2.0 // indirect + github.com/spf13/cobra v1.9.1 // indirect github.com/spf13/pflag v1.0.6 // indirect golang.org/x/crypto v0.31.0 // indirect golang.org/x/sync v0.10.0 // indirect diff --git a/go.work.sum b/go.work.sum index 01f0bef..4c8099d 100644 --- a/go.work.sum +++ b/go.work.sum @@ -15,7 +15,6 @@ github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtX golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= -golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= diff --git a/middleware.go b/middleware.go index fbfebe0..338a3f6 100644 --- a/middleware.go +++ b/middleware.go @@ -251,7 +251,7 @@ func Session(cfg Options) func(next http.Handler) http.Handler { ctx = WithAccessKey(ctx, accessKey) - projectID, _ = accessKey.GetProjectID() + projectID, _ = GetProjectIDFromAccessKey(accessKey) ctx = withProjectID(ctx, projectID) httplog.SetAttrs(ctx, slog.Uint64("projectId", projectID)) break @@ -332,7 +332,7 @@ func PropagateAccessKey(headerContextFuncs ...func(context.Context, http.Header) if accessKey, ok := GetAccessKey(ctx); ok { h := http.Header{ - HeaderAccessKey: []string{accessKey.String()}, + HeaderAccessKey: []string{accessKey}, } for _, fn := range headerContextFuncs { ctx, _ = fn(ctx, h) diff --git a/middleware_test.go b/middleware_test.go index 2e73571..147a23c 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -325,7 +325,8 @@ func TestCustomErrHandler(t *testing.T) { r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - claims := map[string]any{"service": "client_service"} + var claims map[string]any + claims = map[string]any{"service": "client_service"} // Valid Request ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))