Skip to content

Commit dd6256e

Browse files
fix(ers): Do not use auth header jwt in MultiStrategy ERS (#2862)
### Proposed Changes * the ERS should not use the auth JWT, it should us the entities or token provided in the request body * the V1 ERS needed to add a value to the context so it could be used by one of the providers ### Checklist - [ ] I have added or updated unit tests - [ ] I have added or updated integration tests (if appropriate) - [ ] I have added or updated documentation ### Testing Instructions
1 parent 16d0b1e commit dd6256e

File tree

5 files changed

+82
-40
lines changed

5 files changed

+82
-40
lines changed

service/entityresolution/multi-strategy/registration.go

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package multistrategy
22

33
import (
44
"context"
5+
"encoding/json"
56
"fmt"
67
"log/slog"
78

@@ -14,6 +15,7 @@ import (
1415
"github.com/opentdf/platform/service/logger"
1516
"github.com/opentdf/platform/service/pkg/serviceregistry"
1617
"go.opentelemetry.io/otel/trace"
18+
"google.golang.org/protobuf/encoding/protojson"
1719
"google.golang.org/protobuf/types/known/structpb"
1820
)
1921

@@ -43,25 +45,42 @@ func (ers *ERS) ResolveEntities(
4345
ctx context.Context,
4446
req *connect.Request[entityresolution.ResolveEntitiesRequest],
4547
) (*connect.Response[entityresolution.ResolveEntitiesResponse], error) {
46-
// Extract JWT claims from context (this would be set by authentication middleware)
47-
jwtClaims, ok := ctx.Value(types.JWTClaimsContextKey).(types.JWTClaims)
48-
if !ok {
49-
ers.logger.Warn("no JWT claims found in context for multi-strategy ERS")
50-
jwtClaims = make(types.JWTClaims)
51-
}
52-
5348
payload := req.Msg.GetEntities()
5449
resolvedEntities := make([]*entityresolution.EntityRepresentation, 0, len(payload))
55-
5650
for _, entity := range payload {
5751
entityID := entity.GetId()
5852
if entityID == "" {
5953
ers.logger.Warn("empty entity ID in request")
6054
continue
6155
}
6256

57+
var claimsMap types.JWTClaims
58+
switch entity.GetEntityType().(type) {
59+
case *authorization.Entity_Claims:
60+
claims := entity.GetClaims()
61+
if claims != nil {
62+
// First unmarshal to structpb.Struct
63+
var claimsStruct structpb.Struct
64+
err := claims.UnmarshalTo(&claimsStruct)
65+
if err != nil {
66+
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("error unpacking anypb.Any to structpb.Struct: %w", err))
67+
}
68+
// Convert to map[string]interface{}
69+
claimsMap = claimsStruct.AsMap()
70+
}
71+
default:
72+
entityBytes, err := protojson.Marshal(entity)
73+
if err != nil {
74+
return nil, err
75+
}
76+
err = json.Unmarshal(entityBytes, &claimsMap)
77+
if err != nil {
78+
return nil, err
79+
}
80+
}
81+
6382
// Resolve entity using multi-strategy service
64-
result, err := ers.service.ResolveEntity(ctx, entityID, jwtClaims)
83+
result, err := ers.service.ResolveEntity(ctx, entityID, claimsMap)
6584
if err != nil {
6685
ers.logger.Error("failed to resolve entity",
6786
slog.String("entity_id", entityID),
@@ -212,8 +231,11 @@ func (ers *ERS) createEntityChainFromSingleToken(ctx context.Context, token *aut
212231
for _, strategy := range strategies {
213232
attemptedStrategies = append(attemptedStrategies, strategy.Name)
214233

234+
// Put JWT claims into context for providers to access
235+
ctxWithClaims := context.WithValue(ctx, types.JWTClaimsContextKey, jwtClaims)
236+
215237
// Resolve entity using this strategy
216-
entityResult, err := ers.service.ResolveEntity(ctx, token.GetId(), jwtClaims)
238+
entityResult, err := ers.service.ResolveEntity(ctxWithClaims, token.GetId(), jwtClaims)
217239
if err != nil {
218240
lastError = err
219241
ers.logger.WarnContext(ctx, "strategy failed for token",

service/entityresolution/multi-strategy/service.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,17 +60,17 @@ func (s *Service) GetConfig() types.MultiStrategyConfig {
6060
}
6161

6262
// ResolveEntity resolves entity information using the configured strategies
63-
func (s *Service) ResolveEntity(ctx context.Context, entityID string, jwtClaims types.JWTClaims) (*types.EntityResult, error) {
63+
func (s *Service) ResolveEntity(ctx context.Context, entityID string, claimsMap types.JWTClaims) (*types.EntityResult, error) {
6464
// Get all matching strategies based on JWT claims
65-
strategies, err := s.strategyMatcher.SelectStrategies(ctx, jwtClaims)
65+
strategies, err := s.strategyMatcher.SelectStrategies(ctx, claimsMap)
6666
if err != nil {
6767
return nil, types.WrapMultiStrategyError(
6868
types.ErrorTypeStrategy,
6969
"failed to select strategies",
7070
err,
7171
map[string]interface{}{
7272
"entity_id": entityID,
73-
"jwt_claims": extractClaimNames(jwtClaims),
73+
"entity_map": extractClaimNames(claimsMap),
7474
},
7575
)
7676
}
@@ -88,7 +88,7 @@ func (s *Service) ResolveEntity(ctx context.Context, entityID string, jwtClaims
8888
for _, strategy := range strategies {
8989
attemptedStrategies = append(attemptedStrategies, strategy.Name)
9090

91-
result, err := s.executeStrategy(ctx, entityID, jwtClaims, strategy)
91+
result, err := s.executeStrategy(ctx, entityID, claimsMap, strategy)
9292
if err != nil {
9393
lastError = err
9494

@@ -130,7 +130,7 @@ func (s *Service) ResolveEntity(ctx context.Context, entityID string, jwtClaims
130130
"entity_id": entityID,
131131
"failure_strategy": failureStrategy,
132132
"attempted_strategies": attemptedStrategies,
133-
"jwt_claims": extractClaimNames(jwtClaims),
133+
"entity_map": extractClaimNames(claimsMap),
134134
},
135135
)
136136
}

service/entityresolution/multi-strategy/strategy_matcher.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,34 +31,34 @@ func (sm *StrategyMatcher) SelectStrategy(_ context.Context, claims types.JWTCla
3131

3232
return nil, types.NewStrategyError("no matching strategy found", map[string]interface{}{
3333
"available_strategies": len(sm.strategies),
34-
"jwt_claims": extractClaimNames(claims),
34+
"entity_map": extractClaimNames(claims),
3535
})
3636
}
3737

3838
// SelectStrategies returns all strategies that match the JWT claims in configuration order
39-
func (sm *StrategyMatcher) SelectStrategies(_ context.Context, claims types.JWTClaims) ([]*types.MappingStrategy, error) {
39+
func (sm *StrategyMatcher) SelectStrategies(_ context.Context, claimsMap types.JWTClaims) ([]*types.MappingStrategy, error) {
4040
var matchingStrategies []*types.MappingStrategy
4141

4242
for _, strategy := range sm.strategies {
43-
if sm.matchesConditions(claims, strategy.Conditions) {
43+
if sm.matchesConditions(claimsMap, strategy.Conditions) {
4444
matchingStrategies = append(matchingStrategies, &strategy)
4545
}
4646
}
4747

4848
if len(matchingStrategies) == 0 {
4949
return nil, types.NewStrategyError("no matching strategy found", map[string]interface{}{
5050
"available_strategies": len(sm.strategies),
51-
"jwt_claims": extractClaimNames(claims),
51+
"entity_map": extractClaimNames(claimsMap),
5252
})
5353
}
5454

5555
return matchingStrategies, nil
5656
}
5757

5858
// matchesConditions checks if JWT claims match strategy conditions
59-
func (sm *StrategyMatcher) matchesConditions(claims types.JWTClaims, conditions types.StrategyConditions) bool {
59+
func (sm *StrategyMatcher) matchesConditions(claimsMap types.JWTClaims, conditions types.StrategyConditions) bool {
6060
for _, claimCondition := range conditions.JWTClaims {
61-
if !sm.matchesClaimCondition(claims, claimCondition) {
61+
if !sm.matchesClaimCondition(claimsMap, claimCondition) {
6262
return false
6363
}
6464
}

service/entityresolution/multi-strategy/v2/registration.go

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,23 @@ package multistrategy
22

33
import (
44
"context"
5+
"encoding/json"
56
"fmt"
67
"log/slog"
8+
"strconv"
79

810
"connectrpc.com/connect"
911
"github.com/go-viper/mapstructure/v2"
1012
"github.com/lestrrat-go/jwx/v2/jwt"
1113
"github.com/opentdf/platform/protocol/go/entity"
1214
ersV2 "github.com/opentdf/platform/protocol/go/entityresolution/v2"
15+
ent "github.com/opentdf/platform/service/entity"
1316
multistrategy "github.com/opentdf/platform/service/entityresolution/multi-strategy"
1417
"github.com/opentdf/platform/service/entityresolution/multi-strategy/types"
1518
"github.com/opentdf/platform/service/logger"
1619
"github.com/opentdf/platform/service/pkg/serviceregistry"
1720
"go.opentelemetry.io/otel/trace"
21+
"google.golang.org/protobuf/encoding/protojson"
1822
"google.golang.org/protobuf/types/known/structpb"
1923
)
2024

@@ -49,27 +53,43 @@ func (ers *ERSV2) ResolveEntities(
4953
ctx context.Context,
5054
req *connect.Request[ersV2.ResolveEntitiesRequest],
5155
) (*connect.Response[ersV2.ResolveEntitiesResponse], error) {
52-
// Extract JWT claims from context (this would be set by authentication middleware)
53-
jwtClaims, ok := ctx.Value(types.JWTClaimsContextKey).(types.JWTClaims)
54-
if !ok {
55-
ers.logger.Warn("no JWT claims found in context for multi-strategy ERS v2")
56-
// For ResolveEntities, we need JWT claims to be provided by middleware
57-
// This is different from CreateEntityChainsFromTokens which has the JWT token directly
58-
jwtClaims = make(types.JWTClaims)
59-
}
60-
6156
payload := req.Msg.GetEntities()
6257
resolvedEntities := make([]*ersV2.EntityRepresentation, 0, len(payload))
6358

64-
for _, entityV2 := range payload {
59+
for idx, entityV2 := range payload {
6560
entityID := entityV2.GetEphemeralId()
6661
if entityID == "" {
67-
ers.logger.Warn("empty entity ID in request")
68-
continue
62+
entityID = ent.EntityIDPrefix + strconv.Itoa(idx)
63+
ers.logger.Warn("empty entity ID in request; using generated ID", slog.String("entity_id", entityID))
64+
}
65+
66+
var claimsMap types.JWTClaims
67+
switch entityV2.GetEntityType().(type) {
68+
case *entity.Entity_Claims:
69+
claims := entityV2.GetClaims()
70+
if claims != nil {
71+
// First unmarshal to structpb.Struct
72+
var claimsStruct structpb.Struct
73+
err := claims.UnmarshalTo(&claimsStruct)
74+
if err != nil {
75+
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("error unpacking anypb.Any to structpb.Struct: %w", err))
76+
}
77+
// Convert to map[string]interface{}
78+
claimsMap = claimsStruct.AsMap()
79+
}
80+
default:
81+
entityBytes, err := protojson.Marshal(entityV2)
82+
if err != nil {
83+
return nil, err
84+
}
85+
err = json.Unmarshal(entityBytes, &claimsMap)
86+
if err != nil {
87+
return nil, err
88+
}
6989
}
7090

7191
// Resolve entity using multi-strategy service
72-
result, err := ers.service.ResolveEntity(ctx, entityID, jwtClaims)
92+
result, err := ers.service.ResolveEntity(ctx, entityID, claimsMap)
7393
if err != nil {
7494
ers.logger.Error("failed to resolve entity",
7595
slog.String("entity_id", entityID),
@@ -188,7 +208,7 @@ func (ers *ERSV2) createEntityChainFromSingleTokenV2(ctx context.Context, token
188208
err,
189209
map[string]interface{}{
190210
"token_id": token.GetEphemeralId(),
191-
"jwt_claims": extractClaimNames(jwtClaims),
211+
"entity_map": extractClaimNames(jwtClaims),
192212
},
193213
)
194214
}
@@ -198,7 +218,7 @@ func (ers *ERSV2) createEntityChainFromSingleTokenV2(ctx context.Context, token
198218
"no matching strategies found for JWT claims",
199219
map[string]interface{}{
200220
"token_id": token.GetEphemeralId(),
201-
"jwt_claims": extractClaimNames(jwtClaims),
221+
"entity_map": extractClaimNames(jwtClaims),
202222
},
203223
)
204224
}
@@ -276,7 +296,7 @@ func (ers *ERSV2) createEntityChainFromSingleTokenV2(ctx context.Context, token
276296
"token_id": token.GetEphemeralId(),
277297
"failure_strategy": failureStrategy,
278298
"attempted_strategies": attemptedStrategies,
279-
"jwt_claims": extractClaimNames(jwtClaims),
299+
"entity_map": extractClaimNames(jwtClaims),
280300
},
281301
)
282302
}
@@ -416,7 +436,7 @@ func getEntityValueV2(entityType interface{}) string {
416436
}
417437

418438
// RegisterMultiStrategyERSV2 registers the v2 multi-strategy ERS service
419-
func RegisterERSV2(config map[string]interface{}, logger *logger.Logger) (*ERSV2, serviceregistry.HandlerServer) {
439+
func RegisterMultiStrategyERSV2(config map[string]interface{}, logger *logger.Logger) (*ERSV2, serviceregistry.HandlerServer) {
420440
var multiStrategyConfig types.MultiStrategyConfig
421441

422442
if err := mapstructure.Decode(config, &multiStrategyConfig); err != nil {
@@ -433,7 +453,7 @@ func RegisterERSV2(config map[string]interface{}, logger *logger.Logger) (*ERSV2
433453
return ers, nil
434454
}
435455

436-
// extractClaimNames extracts the names of claims from JWTClaims for logging
456+
// extractClaimNames extracts the names of fields from JWTClaims for logging
437457
func extractClaimNames(claims types.JWTClaims) []string {
438458
names := make([]string, 0, len(claims))
439459
for name := range claims {

service/entityresolution/v2/entity_resolution.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ func NewRegistration() *serviceregistry.Service[entityresolutionv2connect.Entity
7171
claimsSVC.Tracer = srp.Tracer
7272
return EntityResolution{EntityResolutionServiceHandler: claimsSVC}, claimsHandler
7373
case MultiStrategyMode:
74-
multiSVC, multiHandler := multistrategyv2.RegisterERSV2(srp.Config, srp.Logger)
74+
multiSVC, multiHandler := multistrategyv2.RegisterMultiStrategyERSV2(srp.Config, srp.Logger)
7575
multiSVC.Tracer = srp.Tracer
7676
return EntityResolution{EntityResolutionServiceHandler: multiSVC}, multiHandler
7777
default:

0 commit comments

Comments
 (0)