Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 44 additions & 38 deletions access_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import (
"context"
"errors"
"fmt"
"net/http"
"strings"

"crypto/rand"
"encoding/binary"

"github.com/go-chi/transport"
"github.com/goware/base64"
"github.com/jxskiss/base62"
)
Expand All @@ -22,16 +24,10 @@ var (
ErrInvalidKeyLength = errors.New("invalid access key length")
)

type AccessKey string

func (a AccessKey) String() string {
return string(a)
}

func (a AccessKey) GetProjectID() (projectID uint64, err error) {
func GetProjectIDFromAccessKey(accessKey string) (projectID uint64, err error) {
var errs []error
for _, e := range SupportedEncodings {
projectID, err := e.Decode(a)
projectID, err := e.Decode(accessKey)
if err != nil {
errs = append(errs, fmt.Errorf("decode v%d: %w", e.Version(), err))
continue
Expand All @@ -41,34 +37,44 @@ func (a AccessKey) GetProjectID() (projectID uint64, err error) {
return 0, errors.Join(errs...)
}

func (a AccessKey) GetPrefix() string {
parts := strings.Split(a.String(), Separator)
if len(parts) < 2 {
return ""
}
return strings.Join(parts[:len(parts)-1], Separator)
}

var ErrUnsupportedEncoding = errors.New("unsupported access key encoding")

func GenerateAccessKey(ctx context.Context, projectID uint64) (AccessKey, error) {
func GenerateAccessKey(ctx context.Context, projectID uint64) string {
version, ok := GetVersion(ctx)
if !ok {
return DefaultEncoding.Encode(ctx, projectID), nil
return DefaultEncoding.Encode(ctx, projectID)
}

for _, e := range SupportedEncodings {
if e.Version() == version {
return e.Encode(ctx, projectID), nil
return e.Encode(ctx, projectID)
}
}
return "", ErrUnsupportedEncoding
return ""
}

func GetAccessKeyPrefix(accessKey string) string {
parts := strings.Split(accessKey, Separator)
if len(parts) < 2 {
return ""
}
return strings.Join(parts[:len(parts)-1], Separator)
}

func ForwardAccessKeyTransport(next http.RoundTripper) http.RoundTripper {
return transport.RoundTripFunc(func(req *http.Request) (resp *http.Response, err error) {
r := transport.CloneRequest(req)

if accessKey, ok := GetAccessKey(req.Context()); ok {
r.Header.Set(HeaderAccessKey, accessKey)
}

return next.RoundTrip(r)
})
}

type Encoding interface {
Version() byte
Encode(ctx context.Context, projectID uint64) AccessKey
Decode(accessKey AccessKey) (projectID uint64, err error)
Encode(ctx context.Context, projectID uint64) string
Decode(accessKey string) (projectID uint64, err error)
}

const (
Expand All @@ -83,15 +89,15 @@ type V0 struct{}

func (V0) Version() byte { return 0 }

func (V0) Encode(_ context.Context, projectID uint64) AccessKey {
func (V0) Encode(_ context.Context, projectID uint64) string {
buf := make([]byte, sizeV0)
binary.BigEndian.PutUint64(buf, projectID)
_, _ = rand.Read(buf[8:])
return AccessKey(base62.EncodeToString(buf))
return base62.EncodeToString(buf)
}

func (V0) Decode(accessKey AccessKey) (projectID uint64, err error) {
buf, err := base62.DecodeString(accessKey.String())
func (V0) Decode(accessKey string) (projectID uint64, err error) {
buf, err := base62.DecodeString(accessKey)
if err != nil {
return 0, fmt.Errorf("base62 decode: %w", err)
}
Expand All @@ -107,16 +113,16 @@ type V1 struct{}

func (V1) Version() byte { return 1 }

func (v V1) Encode(_ context.Context, projectID uint64) AccessKey {
func (v V1) Encode(_ context.Context, projectID uint64) string {
buf := make([]byte, sizeV1)
buf[0] = v.Version()
binary.BigEndian.PutUint64(buf[1:], projectID)
_, _ = rand.Read(buf[9:])
return AccessKey(base64.Base64UrlEncode(buf))
return base64.Base64UrlEncode(buf)
}

func (V1) Decode(accessKey AccessKey) (projectID uint64, err error) {
buf, err := base64.Base64UrlDecode(accessKey.String())
func (V1) Decode(accessKey string) (projectID uint64, err error) {
buf, err := base64.Base64UrlDecode(accessKey)
if err != nil {
return 0, fmt.Errorf("base64 decode: %w", err)
}
Expand All @@ -137,19 +143,19 @@ const (

func (V2) Version() byte { return 2 }

func (v V2) Encode(ctx context.Context, projectID uint64) AccessKey {
func (v V2) Encode(ctx context.Context, projectID uint64) string {
buf := make([]byte, sizeV2)
buf[0] = v.Version()
binary.BigEndian.PutUint64(buf[1:], projectID)
_, _ = rand.Read(buf[9:])
return AccessKey(getPrefix(ctx) + Separator + base64.Base64UrlEncode(buf))
return getPrefix(ctx) + Separator + base64.Base64UrlEncode(buf)
}

func (V2) Decode(accessKey AccessKey) (projectID uint64, err error) {
parts := strings.Split(accessKey.String(), Separator)
raw := parts[len(parts)-1]
func (V2) Decode(accessKey string) (projectID uint64, err error) {
parts := strings.Split(accessKey, Separator)
accessKey = parts[len(parts)-1]

buf, err := base64.Base64UrlDecode(raw)
buf, err := base64.Base64UrlDecode(accessKey)
if err != nil {
return 0, fmt.Errorf("base64 decode: %w", err)
}
Expand Down
33 changes: 14 additions & 19 deletions access_key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,55 +14,50 @@ func TestAccessKeyEncoding(t *testing.T) {
t.Run("v0", func(t *testing.T) {
ctx := authcontrol.WithVersion(context.Background(), 0)
projectID := uint64(12345)
accessKey, err := authcontrol.GenerateAccessKey(ctx, projectID)
require.NoError(t, err)
accessKey := authcontrol.GenerateAccessKey(ctx, projectID)
t.Log("=> k", accessKey)

outID, err := accessKey.GetProjectID()
outID, err := authcontrol.GetProjectIDFromAccessKey(accessKey)
require.NoError(t, err)
require.Equal(t, projectID, outID)
})

t.Run("v1", func(t *testing.T) {
ctx := authcontrol.WithVersion(context.Background(), 1)
projectID := uint64(12345)
accessKey, err := authcontrol.GenerateAccessKey(ctx, projectID)
require.NoError(t, err)
accessKey := authcontrol.GenerateAccessKey(ctx, projectID)
t.Log("=> k", accessKey)
outID, err := accessKey.GetProjectID()
outID, err := authcontrol.GetProjectIDFromAccessKey(accessKey)
require.NoError(t, err)
require.Equal(t, projectID, outID)
})
t.Run("v2", func(t *testing.T) {
ctx := authcontrol.WithVersion(context.Background(), 2)
projectID := uint64(12345)
accessKey, err := authcontrol.GenerateAccessKey(ctx, projectID)
require.NoError(t, err)
t.Log("=> k", accessKey, "| prefix =>", accessKey.GetPrefix())
outID, err := accessKey.GetProjectID()
accessKey := authcontrol.GenerateAccessKey(ctx, projectID)
t.Log("=> k", accessKey, "| prefix =>", authcontrol.GetAccessKeyPrefix(accessKey))
outID, err := authcontrol.GetProjectIDFromAccessKey(accessKey)
require.NoError(t, err)
require.Equal(t, projectID, outID)

ctx = authcontrol.WithPrefix(ctx, "newprefix:dev")

accessKey2, err := authcontrol.GenerateAccessKey(ctx, projectID)
require.NoError(t, err)
t.Log("=> k", accessKey2, "| prefix =>", accessKey2.GetPrefix())
outID, err = accessKey2.GetProjectID()
accessKey2 := authcontrol.GenerateAccessKey(ctx, projectID)
t.Log("=> k", accessKey2, "| prefix =>", authcontrol.GetAccessKeyPrefix(accessKey2))
outID, err = authcontrol.GetProjectIDFromAccessKey(accessKey2)
require.NoError(t, err)
require.Equal(t, projectID, outID)
// retrocompatibility with the older prefix
outID, err = accessKey.GetProjectID()
outID, err = authcontrol.GetProjectIDFromAccessKey(accessKey)
require.NoError(t, err)
require.Equal(t, projectID, outID)
})
}

func TestDecode(t *testing.T) {
ctx := authcontrol.WithVersion(context.Background(), 2)
accessKey, err := authcontrol.GenerateAccessKey(ctx, 237)
require.NoError(t, err)
t.Log("=> k", accessKey, "| prefix =>", accessKey.GetPrefix())
accessKey := authcontrol.GenerateAccessKey(ctx, 237)
t.Log("=> k", accessKey, "| prefix =>", authcontrol.GetAccessKeyPrefix(accessKey))
}

func TestForwardAccessKeyTransport(t *testing.T) {
Expand All @@ -76,7 +71,7 @@ func TestForwardAccessKeyTransport(t *testing.T) {

// Create context with access key
accessKey := "test-access-key-123"
ctx := authcontrol.WithAccessKey(context.Background(), authcontrol.AccessKey(accessKey))
ctx := authcontrol.WithAccessKey(context.Background(), accessKey)

// Create HTTP client with ForwardAccessKeyTransport
client := &http.Client{
Expand Down
2 changes: 1 addition & 1 deletion cmd/access_key/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ var decodeCmd = &cobra.Command{
if len(args) != 1 {
return fmt.Errorf("access key is required")
}
accessKey := authcontrol.AccessKey(args[0])
accessKey := args[0]
var (
projectID uint64
version byte
Expand Down
20 changes: 3 additions & 17 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (

"github.com/0xsequence/authcontrol/proto"
"github.com/go-chi/jwtauth/v5"
"github.com/go-chi/transport"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwt"
)
Expand All @@ -23,10 +22,10 @@ const (
HeaderAccessKey = "X-Access-Key"
)

type AccessKeyFunc func(*http.Request) AccessKey
type AccessKeyFunc func(*http.Request) string

func AccessKeyFromHeader(r *http.Request) AccessKey {
return AccessKey(r.Header.Get(HeaderAccessKey))
func AccessKeyFromHeader(r *http.Request) string {
return r.Header.Get(HeaderAccessKey)
}

type ErrHandler func(r *http.Request, w http.ResponseWriter, err error)
Expand Down Expand Up @@ -199,16 +198,3 @@ func findProjectClaim(r *http.Request) (uint64, error) {
return 0, fmt.Errorf("invalid type: %T", val)
}
}

// ForwardAccessKeyTransport is a RoundTripper that forwards the access key from the request context to the request header.
func ForwardAccessKeyTransport(next http.RoundTripper) http.RoundTripper {
return transport.RoundTripFunc(func(req *http.Request) (resp *http.Response, err error) {
r := transport.CloneRequest(req)

if accessKey, ok := GetAccessKey(req.Context()); ok {
r.Header.Set(HeaderAccessKey, accessKey.String())
}

return next.RoundTrip(r)
})
}
4 changes: 2 additions & 2 deletions common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ import (

const HeaderKey = "Test-Key"

func keyFunc(r *http.Request) authcontrol.AccessKey {
return authcontrol.AccessKey(r.Header.Get(HeaderKey))
func keyFunc(r *http.Request) string {
return r.Header.Get(HeaderKey)
}

type requestOption func(r *http.Request)
Expand Down
6 changes: 3 additions & 3 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,13 @@ func GetService(ctx context.Context) (string, bool) {
// WithAccessKey adds the access key to the context.
//
// TODO: Deprecate this in favor of Session middleware with a JWT token.
func WithAccessKey(ctx context.Context, accessKey AccessKey) context.Context {
func WithAccessKey(ctx context.Context, accessKey string) context.Context {
return context.WithValue(ctx, ctxKeyAccessKey, accessKey)
}

// GetAccessKey returns the access key from the context.
func GetAccessKey(ctx context.Context) (AccessKey, bool) {
v, ok := ctx.Value(ctxKeyAccessKey).(AccessKey)
func GetAccessKey(ctx context.Context) (string, bool) {
v, ok := ctx.Value(ctxKeyAccessKey).(string)
return v, ok
}

Expand Down
6 changes: 3 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@ require (
github.com/go-chi/metrics v0.1.0
github.com/go-chi/traceid v0.2.0
github.com/go-chi/transport v0.4.0
github.com/goware/base64 v0.1.0
github.com/jxskiss/base62 v1.1.0
github.com/lestrrat-go/jwx/v2 v2.1.3
github.com/spf13/cobra v1.9.1
github.com/stretchr/testify v1.10.0
)

Expand All @@ -23,7 +20,9 @@ require (
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect
github.com/goccy/go-json v0.10.3 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/goware/base64 v0.1.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jxskiss/base62 v1.1.0 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
github.com/lestrrat-go/httpcc v1.0.1 // indirect
Expand All @@ -38,6 +37,7 @@ require (
github.com/prometheus/procfs v0.15.1 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
github.com/segmentio/asm v1.2.0 // indirect
github.com/spf13/cobra v1.9.1 // indirect
github.com/spf13/pflag v1.0.6 // indirect
golang.org/x/crypto v0.31.0 // indirect
golang.org/x/sync v0.10.0 // indirect
Expand Down
1 change: 0 additions & 1 deletion go.work.sum
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtX
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
Expand Down
4 changes: 2 additions & 2 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ func Session(cfg Options) func(next http.Handler) http.Handler {

ctx = WithAccessKey(ctx, accessKey)

projectID, _ = accessKey.GetProjectID()
projectID, _ = GetProjectIDFromAccessKey(accessKey)
ctx = withProjectID(ctx, projectID)
httplog.SetAttrs(ctx, slog.Uint64("projectId", projectID))
break
Expand Down Expand Up @@ -332,7 +332,7 @@ func PropagateAccessKey(headerContextFuncs ...func(context.Context, http.Header)

if accessKey, ok := GetAccessKey(ctx); ok {
h := http.Header{
HeaderAccessKey: []string{accessKey.String()},
HeaderAccessKey: []string{accessKey},
}
for _, fn := range headerContextFuncs {
ctx, _ = fn(ctx, h)
Expand Down
3 changes: 2 additions & 1 deletion middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,8 @@ func TestCustomErrHandler(t *testing.T) {

r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))

claims := map[string]any{"service": "client_service"}
var claims map[string]any
claims = map[string]any{"service": "client_service"}

// Valid Request
ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
Expand Down