From a4a05b0ab3494d6924a231f3f19c3a55caf1790c Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Wed, 30 Jul 2025 21:06:54 +0200 Subject: [PATCH 01/61] chore: add ProvidesData + tests --- .../graphql_datasource_federation_test.go | 89 +++++++++++++++++++ v2/pkg/engine/resolve/fetch.go | 1 + 2 files changed, 90 insertions(+) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go index 7fa1a7f3e..057144c31 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go @@ -1565,6 +1565,61 @@ func TestGraphQLDataSourceFederation(t *testing.T) { FieldName: "user", }, }, + ProvidesData: &resolve.Object{ + Fields: []*resolve.Field{ + { + Name: []byte("user"), + Value: &resolve.Object{ + Path: []string{"user"}, + Nullable: true, + Fields: []*resolve.Field{ + { + Name: []byte("account"), + Value: &resolve.Object{ + Path: []string{"account"}, + Nullable: true, + Fields: []*resolve.Field{ + { + Name: []byte("__typename"), + Value: &resolve.String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("id"), + Value: &resolve.Scalar{ + Path: []string{"id"}, + }, + }, + { + Name: []byte("info"), + Value: &resolve.Object{ + Path: []string{"info"}, + Nullable: true, + Fields: []*resolve.Field{ + { + Name: []byte("a"), + Value: &resolve.Scalar{ + Path: []string{"a"}, + }, + }, + { + Name: []byte("b"), + Value: &resolve.Scalar{ + Path: []string{"b"}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, }, }), resolve.SingleWithPath(&resolve.SingleFetch{ @@ -1587,6 +1642,40 @@ func TestGraphQLDataSourceFederation(t *testing.T) { }, }, OperationType: ast.OperationTypeQuery, + ProvidesData: &resolve.Object{ + Fields: []*resolve.Field{ + { + Name: []byte("__typename"), + OnTypeNames: [][]byte{[]byte("Account")}, + Value: &resolve.String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("name"), + OnTypeNames: [][]byte{[]byte("Account")}, + Value: &resolve.Scalar{ + Path: []string{"name"}, + }, + }, + { + Name: []byte("shippingInfo"), + OnTypeNames: [][]byte{[]byte("Account")}, + Value: &resolve.Object{ + Path: []string{"shippingInfo"}, + Nullable: true, + Fields: []*resolve.Field{ + { + Name: []byte("zip"), + Value: &resolve.Scalar{ + Path: []string{"zip"}, + }, + }, + }, + }, + }, + }, + }, }, DataSourceIdentifier: []byte("graphql_datasource.Source"), FetchConfiguration: resolve.FetchConfiguration{ diff --git a/v2/pkg/engine/resolve/fetch.go b/v2/pkg/engine/resolve/fetch.go index 2bf7b8a3f..1122a44ba 100644 --- a/v2/pkg/engine/resolve/fetch.go +++ b/v2/pkg/engine/resolve/fetch.go @@ -376,6 +376,7 @@ type FetchInfo struct { RootFields []GraphCoordinate OperationType ast.OperationType QueryPlan *QueryPlan + ProvidesData *Object } type GraphCoordinate struct { From 499737b93c70dbae7bc3595a5459695316e2121d Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 31 Jul 2025 09:43:38 +0200 Subject: [PATCH 02/61] feat: implement ProvidesData on fetch --- .../graphql_datasource_federation_test.go | 7 +- v2/pkg/engine/plan/visitor.go | 300 ++++++++++++++++++ 2 files changed, 303 insertions(+), 4 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go index 057144c31..b953ce192 100644 --- 
a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go @@ -1581,7 +1581,7 @@ func TestGraphQLDataSourceFederation(t *testing.T) { Fields: []*resolve.Field{ { Name: []byte("__typename"), - Value: &resolve.String{ + Value: &resolve.Scalar{ Path: []string{"__typename"}, }, }, @@ -1645,9 +1645,8 @@ func TestGraphQLDataSourceFederation(t *testing.T) { ProvidesData: &resolve.Object{ Fields: []*resolve.Field{ { - Name: []byte("__typename"), - OnTypeNames: [][]byte{[]byte("Account")}, - Value: &resolve.String{ + Name: []byte("__typename"), + Value: &resolve.Scalar{ Path: []string{"__typename"}, }, }, diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index 239e1666e..e878370b8 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -61,6 +61,18 @@ type Visitor struct { fieldPlanners map[int][]int // fieldEnclosingTypeNames stores the enclosing type names for each field ref fieldEnclosingTypeNames map[int]string + // plannerObjects stores the root object for each planner's ProvidesData + // map plannerID -> root object + plannerObjects map[int]*resolve.Object + // plannerCurrentFields stores the current field stack for each planner + // map plannerID -> field stack + plannerCurrentFields map[int][]objectFields + // plannerResponsePaths stores the response paths relative to each planner's root + // map plannerID -> response path stack + plannerResponsePaths map[int][]string + // plannerEntityBoundaryPaths stores the entity boundary paths for each planner + // map plannerID -> entity boundary path + plannerEntityBoundaryPaths map[int]string } type indirectInterfaceField struct { @@ -340,6 +352,12 @@ func (v *Visitor) EnterField(ref int) { if !v.Config.DisableIncludeFieldDependencies { v.fieldEnclosingTypeNames[ref] = strings.Clone(v.Walker.EnclosingTypeDefinition.NameString(v.Definition)) } + + // Track field for each planner that should handle it + for plannerID := range v.planners { + v.trackFieldForPlanner(plannerID, ref) + } + // check if we have to skip the field in the response // it means it was requested by the planner not the user if v.skipField(ref) { @@ -610,6 +628,11 @@ func (v *Visitor) addInterfaceObjectNameToTypeNames(fieldRef int, typeName []byt func (v *Visitor) LeaveField(ref int) { v.debugOnLeaveNode(ast.NodeKindField, ref) + // Pop fields for each planner that tracked this field + for plannerID := range v.planners { + v.popFieldsForPlanner(plannerID, ref) + } + if v.skipField(ref) { // we should also check skips on field leave // cause on nested keys we could mistakenly remove wrong object @@ -997,6 +1020,9 @@ func (v *Visitor) EnterOperationDefinition(ref int) { } } + // Initialize per-planner structures for ProvidesData tracking + v.initializePlannerStructures() + if operationKind == ast.OperationTypeSubscription { v.subscription = &resolve.GraphQLSubscription{ Response: v.response, @@ -1061,6 +1087,9 @@ func (v *Visitor) EnterDocument(operation, definition *ast.Document) { v.plannerFields = map[int][]int{} v.fieldPlanners = map[int][]int{} v.fieldEnclosingTypeNames = map[int]string{} + v.plannerObjects = map[int]*resolve.Object{} + v.plannerCurrentFields = map[int][]objectFields{} + v.plannerResponsePaths = map[int][]string{} } func (v *Visitor) LeaveDocument(_, _ *ast.Document) { @@ -1115,6 +1144,272 @@ func (v *Visitor) pathDeepness(path string) int { return strings.Count(path, ".") } +func (v *Visitor) 
initializePlannerStructures() { + // Initialize root objects and field stacks for each potential planner + // We'll populate these as we traverse fields + if v.planners == nil { + return + } + + for i := range v.planners { + v.plannerObjects[i] = &resolve.Object{ + Fields: []*resolve.Field{}, + } + v.plannerCurrentFields[i] = []objectFields{{ + fields: &v.plannerObjects[i].Fields, + popOnField: -1, + }} + v.plannerResponsePaths[i] = []string{} + } + v.plannerEntityBoundaryPaths = map[int]string{} +} + +func (v *Visitor) trackFieldForPlanner(plannerID int, fieldRef int) { + // Safety checks + if v.planners == nil || plannerID >= len(v.planners) { + return + } + if v.plannerObjects == nil || v.plannerCurrentFields == nil { + return + } + + // Check if this planner should handle this field + if !v.shouldPlannerHandleField(plannerID, fieldRef) { + return + } + + // Get field information + fieldName := v.Operation.FieldNameBytes(fieldRef) + fieldAliasOrName := v.Operation.FieldAliasOrNameString(fieldRef) + + // For nested entity fetches, check if this field represents the entity boundary + // If so, we should skip adding this field to ProvidesData and instead add its children + if v.isEntityBoundaryField(plannerID, fieldRef) { + // Add a __typename field to the current object for entity boundary + v.addTypenameFieldForPlanner(plannerID) + return + } + + // Get the field definition + fieldDefinition, ok := v.Walker.FieldDefinition(fieldRef) + if !ok { + return + } + fieldType := v.Definition.FieldDefinitionType(fieldDefinition) + + // Create a simple field value for tracking purposes + fieldValue := v.createFieldValueForPlanner(fieldRef, fieldType, []string{fieldAliasOrName}) + + onTypeNames := v.resolveEntityOnTypeNames(plannerID, fieldRef, fieldName) + + // Create the field + field := &resolve.Field{ + Name: fieldName, + Value: fieldValue, + OnTypeNames: onTypeNames, + } + + // Add the field to the current object for this planner + if len(v.plannerCurrentFields[plannerID]) > 0 { + currentFields := v.plannerCurrentFields[plannerID][len(v.plannerCurrentFields[plannerID])-1] + *currentFields.fields = append(*currentFields.fields, field) + } + + // If the field value is an object, push it onto the stack for this planner + if obj, ok := fieldValue.(*resolve.Object); ok { + v.Walker.DefferOnEnterField(func() { + v.plannerCurrentFields[plannerID] = append(v.plannerCurrentFields[plannerID], objectFields{ + popOnField: fieldRef, + fields: &obj.Fields, + }) + }) + } +} + +func (v *Visitor) resolveEntityOnTypeNames(plannerID, fieldRef int, fieldName ast.ByteSlice) (onTypeNames [][]byte) { + // If this is an entity root field, return the enclosing type name + if v.isEntityRootField(plannerID, fieldRef) { + enclosingTypeName := v.Walker.EnclosingTypeDefinition.NameBytes(v.Definition) + if enclosingTypeName != nil { + return [][]byte{enclosingTypeName} + } + } + + // Otherwise, use the regular resolution logic + onTypeNames = v.resolveOnTypeNames(fieldRef, fieldName) + return onTypeNames +} + +// createFieldValueForPlanner creates a simplified field value for planner tracking +// without relying on the full visitor state like resolveFieldValue does +func (v *Visitor) createFieldValueForPlanner(fieldRef, typeRef int, path []string) resolve.Node { + ofType := v.Definition.Types[typeRef].OfType + + switch v.Definition.Types[typeRef].TypeKind { + case ast.TypeKindNonNull: + node := v.createFieldValueForPlanner(fieldRef, ofType, path) + // Set nullable to false for the returned node + switch n := node.(type) { + 
case *resolve.Scalar: + n.Nullable = false + case *resolve.Object: + n.Nullable = false + case *resolve.Array: + n.Nullable = false + } + return node + case ast.TypeKindList: + listItem := v.createFieldValueForPlanner(fieldRef, ofType, nil) + return &resolve.Array{ + Nullable: true, + Path: path, + Item: listItem, + } + case ast.TypeKindNamed: + typeName := v.Definition.ResolveTypeNameString(typeRef) + typeDefinitionNode, ok := v.Definition.Index.FirstNodeByNameStr(typeName) + if !ok { + return &resolve.Null{} + } + switch typeDefinitionNode.Kind { + case ast.NodeKindScalarTypeDefinition, ast.NodeKindEnumTypeDefinition: + return &resolve.Scalar{ + Nullable: true, + Path: path, + } + case ast.NodeKindObjectTypeDefinition, ast.NodeKindInterfaceTypeDefinition, ast.NodeKindUnionTypeDefinition: + // For object types, create a new object that will be populated by child fields + obj := &resolve.Object{ + Nullable: true, + Path: path, + Fields: []*resolve.Field{}, + } + return obj + default: + return &resolve.Null{} + } + default: + return &resolve.Null{} + } +} + +// isEntityBoundaryField checks if this field represents the entity boundary for a nested entity fetch +// For nested entity fetches, the field at the response path boundary should be skipped in ProvidesData +func (v *Visitor) isEntityBoundaryField(plannerID int, fieldRef int) bool { + config := v.planners[plannerID] + fetchConfig := config.ObjectFetchConfiguration() + if fetchConfig == nil || fetchConfig.fetchItem == nil { + return false + } + + // Check if this is a nested fetch (has "." in response path) + responsePath := "query." + fetchConfig.fetchItem.ResponsePath + if !strings.Contains(responsePath, ".") { + return false // Root fetch, no boundary field to skip + } + + // For nested fetches, check if this field is at the entity boundary + currentPath := v.Walker.Path.DotDelimitedString() + fieldName := v.Operation.FieldAliasOrNameString(fieldRef) + fullFieldPath := currentPath + "." + fieldName + + // If this field path matches the response path, it's the entity boundary + if fullFieldPath == responsePath { + // Store the entity boundary path for this planner + v.plannerEntityBoundaryPaths[plannerID] = fullFieldPath + return true + } + return false +} + +// isEntityRootField checks if this field is at the root of an entity +// This means it has one additional path element compared to the stored entity boundary path +func (v *Visitor) isEntityRootField(plannerID int, fieldRef int) bool { + // Check if we have a stored entity boundary path for this planner + boundaryPath, hasBoundary := v.plannerEntityBoundaryPaths[plannerID] + if !hasBoundary { + return false + } + + // Get the current field path + currentPath := v.Walker.Path.DotDelimitedString() + fieldName := v.Operation.FieldAliasOrNameString(fieldRef) + fullFieldPath := currentPath + "." 
+ fieldName + + // Check if this field is a direct child of the entity boundary + // It should start with the boundary path and have exactly one more segment + if !strings.HasPrefix(fullFieldPath, boundaryPath+".") { + return false + } + + // Remove the boundary path prefix and check if there's exactly one segment left + remainingPath := strings.TrimPrefix(fullFieldPath, boundaryPath+".") + // If there are no more dots, this is a root field of the entity + return !strings.Contains(remainingPath, ".") +} + +// addTypenameFieldForPlanner adds a __typename field to the current object for entity boundary fields +func (v *Visitor) addTypenameFieldForPlanner(plannerID int) { + + // Create a __typename field + typenameField := &resolve.Field{ + Name: []byte("__typename"), + Value: &resolve.Scalar{ + Path: []string{"__typename"}, + }, + } + + // Add the __typename field to the current object for this planner + if len(v.plannerCurrentFields[plannerID]) > 0 { + currentFields := v.plannerCurrentFields[plannerID][len(v.plannerCurrentFields[plannerID])-1] + *currentFields.fields = append(*currentFields.fields, typenameField) + } +} + +func (v *Visitor) shouldPlannerHandleField(plannerID int, fieldRef int) bool { + // Safety checks + if v.planners == nil || plannerID >= len(v.planners) { + return false + } + + // Use the same logic as AllowVisitor to check if a planner handles a field + path := v.Walker.Path.DotDelimitedString() + if v.Walker.CurrentKind == ast.NodeKindField { + path = path + "." + v.Operation.FieldAliasOrNameString(fieldRef) + } + + config := v.planners[plannerID] + if !config.HasPath(path) { + return false + } + + enclosingTypeName := v.Walker.EnclosingTypeDefinition.NameString(v.Definition) + + allow := config.HasPathWithFieldRef(fieldRef) || config.HasParent(path) + if !allow { + return false + } + + shouldWalkFieldsOnPath := config.ShouldWalkFieldsOnPath(path, enclosingTypeName) || + config.ShouldWalkFieldsOnPath(path, "") + + return shouldWalkFieldsOnPath +} + +func (v *Visitor) popFieldsForPlanner(plannerID int, fieldRef int) { + // Safety checks + if v.plannerCurrentFields == nil || plannerID >= len(v.plannerCurrentFields) { + return + } + + if len(v.plannerCurrentFields[plannerID]) > 0 { + last := len(v.plannerCurrentFields[plannerID]) - 1 + if v.plannerCurrentFields[plannerID][last].popOnField == fieldRef { + v.plannerCurrentFields[plannerID] = v.plannerCurrentFields[plannerID][:last] + } + } +} + func (v *Visitor) resolveInputTemplates(config *objectFetchConfiguration, input *string, variables *resolve.Variables) { *input = templateRegex.ReplaceAllStringFunc(*input, func(s string) string { selectors := selectorRegex.FindStringSubmatch(s) @@ -1324,6 +1619,11 @@ func (v *Visitor) configureFetch(internal *objectFetchConfiguration, external re OperationType: internal.operationType, QueryPlan: external.QueryPlan, } + + // Set ProvidesData from the planner's object structure + if providesData, ok := v.plannerObjects[internal.fetchID]; ok { + singleFetch.Info.ProvidesData = providesData + } } if !v.Config.DisableIncludeFieldDependencies { From 571cfe0bc516d0f4a758b9dad7308e12b566a8fc Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 31 Jul 2025 10:26:06 +0200 Subject: [PATCH 03/61] chore: improve __typename handling --- .../graphql_datasource_federation_test.go | 30 +++++++++++++++---- v2/pkg/engine/plan/visitor.go | 21 ++++++++++++- 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go 
b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go index b953ce192..9924eb9ec 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go @@ -1534,9 +1534,10 @@ func TestGraphQLDataSourceFederation(t *testing.T) { query CompositeKeys { user { account { + __typename name shippingInfo { - zip + z: zip } } } @@ -1665,9 +1666,9 @@ func TestGraphQLDataSourceFederation(t *testing.T) { Nullable: true, Fields: []*resolve.Field{ { - Name: []byte("zip"), + Name: []byte("z"), Value: &resolve.Scalar{ - Path: []string{"zip"}, + Path: []string{"z"}, }, }, }, @@ -1678,7 +1679,7 @@ func TestGraphQLDataSourceFederation(t *testing.T) { }, DataSourceIdentifier: []byte("graphql_datasource.Source"), FetchConfiguration: resolve.FetchConfiguration{ - Input: `{"method":"POST","url":"http://account.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Account {__typename name shippingInfo {zip}}}}","variables":{"representations":[$$0$$]}}}`, + Input: `{"method":"POST","url":"http://account.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Account {__typename name shippingInfo {z: zip}}}}","variables":{"representations":[$$0$$]}}}`, DataSource: &Source{}, SetTemplateOutputToNullOnVariableNull: true, RequiresEntityFetch: true, @@ -1778,6 +1779,23 @@ func TestGraphQLDataSourceFederation(t *testing.T) { TypeName: "Account", SourceName: "user.service", Fields: []*resolve.Field{ + { + Name: []byte("__typename"), + Info: &resolve.FieldInfo{ + Name: "__typename", + NamedType: "String", + ParentTypeNames: []string{"Account"}, + Source: resolve.TypeFieldSource{ + IDs: []string{"user.service"}, + Names: []string{"user.service"}, + }, + ExactParentTypeName: "Account", + }, + Value: &resolve.String{ + Path: []string{"__typename"}, + IsTypeName: true, + }, + }, { Name: []byte("name"), Info: &resolve.FieldInfo{ @@ -1817,7 +1835,7 @@ func TestGraphQLDataSourceFederation(t *testing.T) { SourceName: "account.service", Fields: []*resolve.Field{ { - Name: []byte("zip"), + Name: []byte("z"), Info: &resolve.FieldInfo{ Name: "zip", NamedType: "String", @@ -1829,7 +1847,7 @@ func TestGraphQLDataSourceFederation(t *testing.T) { ExactParentTypeName: "ShippingInfo", }, Value: &resolve.String{ - Path: []string{"zip"}, + Path: []string{"z"}, }, }, }, diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index e878370b8..5357450b8 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -1190,6 +1190,25 @@ func (v *Visitor) trackFieldForPlanner(plannerID int, fieldRef int) { return } + // Check if this is a __typename field and if we already have one with the same name and path + if bytes.Equal(fieldName, literal.TYPENAME) && len(v.plannerCurrentFields[plannerID]) > 0 { + currentFields := v.plannerCurrentFields[plannerID][len(v.plannerCurrentFields[plannerID])-1] + + // Check if we already have a __typename field with the same name and path + for _, existingField := range *currentFields.fields { + if bytes.Equal(existingField.Name, []byte(fieldAliasOrName)) { + // For __typename fields, the path is [fieldAliasOrName] + // Check if the existing field has the same path + if existingValue, ok := existingField.Value.(*resolve.Scalar); ok { + if len(existingValue.Path) > 0 && existingValue.Path[0] == fieldAliasOrName { + // We 
already have this __typename field with the same name and path, skip it + return + } + } + } + } + } + // Get the field definition fieldDefinition, ok := v.Walker.FieldDefinition(fieldRef) if !ok { @@ -1204,7 +1223,7 @@ func (v *Visitor) trackFieldForPlanner(plannerID int, fieldRef int) { // Create the field field := &resolve.Field{ - Name: fieldName, + Name: []byte(fieldAliasOrName), Value: fieldValue, OnTypeNames: onTypeNames, } From 6322c7b5773b4e27e49f2d3f1261fdfa24481a25 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 31 Jul 2025 11:59:10 +0200 Subject: [PATCH 04/61] chore: correctly handle objects in lists --- .../graphql_datasource_test.go | 94 +++++++++++++++++++ v2/pkg/engine/plan/visitor.go | 25 +++-- 2 files changed, 112 insertions(+), 7 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go index 0ce09599d..92e99a0e8 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go @@ -422,6 +422,100 @@ func TestGraphQLDataSource(t *testing.T) { FieldName: "nestedStringList", }, }, + ProvidesData: &resolve.Object{ + Nullable: false, + Path: []string{}, + Fields: []*resolve.Field{ + { + Name: []byte("droid"), + Value: &resolve.Object{ + Nullable: true, + Path: []string{"droid"}, + Fields: []*resolve.Field{ + { + Name: []byte("name"), + Value: &resolve.Scalar{ + Path: []string{"name"}, + Nullable: false, + }, + }, + { + Name: []byte("aliased"), + Value: &resolve.Scalar{ + Path: []string{"aliased"}, + Nullable: false, + }, + }, + { + Name: []byte("friends"), + Value: &resolve.Array{ + Path: []string{"friends"}, + Nullable: true, + Item: &resolve.Object{ + Nullable: true, + Path: []string{}, + Fields: []*resolve.Field{ + { + Name: []byte("name"), + Value: &resolve.Scalar{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + { + Name: []byte("primaryFunction"), + Value: &resolve.Scalar{ + Path: []string{"primaryFunction"}, + Nullable: false, + }, + }, + }, + }, + }, + { + Name: []byte("hero"), + Value: &resolve.Object{ + Nullable: true, + Path: []string{"hero"}, + Fields: []*resolve.Field{ + { + Name: []byte("name"), + Value: &resolve.Scalar{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + { + Name: []byte("stringList"), + Value: &resolve.Array{ + Path: []string{"stringList"}, + Nullable: true, + Item: &resolve.Scalar{ + Path: []string{}, + Nullable: true, + }, + }, + }, + { + Name: []byte("nestedStringList"), + Value: &resolve.Array{ + Path: []string{"nestedStringList"}, + Nullable: true, + Item: &resolve.Scalar{ + Path: []string{}, + Nullable: true, + }, + }, + }, + }, + }, }, })), Info: &resolve.GraphQLResponseInfo{ diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index 5357450b8..5702a9a3e 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -1234,14 +1234,25 @@ func (v *Visitor) trackFieldForPlanner(plannerID int, fieldRef int) { *currentFields.fields = append(*currentFields.fields, field) } - // If the field value is an object, push it onto the stack for this planner - if obj, ok := fieldValue.(*resolve.Object); ok { - v.Walker.DefferOnEnterField(func() { - v.plannerCurrentFields[plannerID] = append(v.plannerCurrentFields[plannerID], objectFields{ - popOnField: fieldRef, - fields: &obj.Fields, + for { + // for loop to unwrap array item + switch node := 
fieldValue.(type) { + case *resolve.Array: + // unwrap and check type again + fieldValue = node.Item + case *resolve.Object: + // if the field value is an object, add it to the current fields stack + v.Walker.DefferOnEnterField(func() { + v.plannerCurrentFields[plannerID] = append(v.plannerCurrentFields[plannerID], objectFields{ + popOnField: fieldRef, + fields: &node.Fields, + }) }) - }) + return + default: + // field value is a scalar or null, we don't add it to the stack + return + } } } From 8bcdfa34bf671b703c655b54f88e6c6fe8c2b311 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 31 Jul 2025 15:44:36 +0200 Subject: [PATCH 05/61] chore: implement loader.canSkipFetch --- v2/pkg/engine/resolve/loader.go | 131 +++ .../engine/resolve/loader_skip_fetch_test.go | 906 ++++++++++++++++++ 2 files changed, 1037 insertions(+) create mode 100644 v2/pkg/engine/resolve/loader_skip_fetch_test.go diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index f15a0a858..5abe2e4bc 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -1732,3 +1732,134 @@ func (l *Loader) compactJSON(data []byte) ([]byte, error) { astjson.DeduplicateObjectKeysRecursively(v) return v.MarshalTo(nil), nil } + +func (l *Loader) canSkipFetch(info *FetchInfo, items []*astjson.Value) ([]*astjson.Value, bool) { + if info == nil || info.OperationType != ast.OperationTypeQuery { + return items, false + } + if len(items) == 1 && items[0].Type() == astjson.TypeNull { + return items, true + } + + // If ProvidesData is nil, we cannot validate the data - do not skip fetch + if info.ProvidesData == nil { + return items, false + } + + // Check each item and remove those that have sufficient data + remaining := make([]*astjson.Value, 0, len(items)) + for _, item := range items { + if !l.validateItemHasRequiredData(item, info.ProvidesData) { + remaining = append(remaining, item) + } + } + + // Return the remaining items and whether fetch can be skipped + return remaining, len(remaining) == 0 +} + +// validateItemHasRequiredData checks if the given item contains all required data +// as specified by the provided Object schema +func (l *Loader) validateItemHasRequiredData(item *astjson.Value, obj *Object) bool { + if obj == nil { + return true + } + + // Validate each field in the object + for _, field := range obj.Fields { + if !l.validateFieldData(item, field) { + return false + } + } + + return true +} + +// validateFieldData validates a single field against the item data +func (l *Loader) validateFieldData(item *astjson.Value, field *Field) bool { + fieldValue := item.Get(unsafebytes.BytesToString(field.Name)) + + // Check if field exists + if fieldValue == nil { + // Field is missing - this fails validation regardless of nullability + // Even nullable fields must be present (can be null, but not missing) + return false + } + + // Validate the field value against its specification + return l.validateNodeValue(fieldValue, field.Value) +} + +// validateScalarData validates scalar field data +func (l *Loader) validateScalarData(value *astjson.Value, scalar *Scalar) bool { + if value.Type() == astjson.TypeNull { + // Null is only allowed if the scalar is nullable + return scalar.Nullable + } + + // Any non-null value is acceptable for a scalar + return true +} + +// validateObjectData validates object field data +func (l *Loader) validateObjectData(value *astjson.Value, obj *Object) bool { + if value.Type() == astjson.TypeNull { + // Null is only allowed if the object is nullable + 
return obj.Nullable + } + + if value.Type() != astjson.TypeObject { + // Must be an object (or null if nullable) + return false + } + + // Recursively validate the object's fields + return l.validateItemHasRequiredData(value, obj) +} + +// validateArrayData validates array field data +func (l *Loader) validateArrayData(value *astjson.Value, arr *Array) bool { + if value.Type() == astjson.TypeNull { + // Null is only allowed if the array is nullable + return arr.Nullable + } + + if value.Type() != astjson.TypeArray { + // Must be an array (or null if nullable) + return false + } + + // If there's no item specification, we just validate the array exists + if arr.Item == nil { + return true + } + + // Validate each item in the array + arrayItems, err := value.Array() + if err != nil { + return false + } + + for _, item := range arrayItems { + if !l.validateNodeValue(item, arr.Item) { + return false + } + } + + return true +} + +// validateNodeValue validates a value against a Node specification +func (l *Loader) validateNodeValue(value *astjson.Value, nodeSpec Node) bool { + switch v := nodeSpec.(type) { + case *Scalar: + return l.validateScalarData(value, v) + case *Object: + return l.validateObjectData(value, v) + case *Array: + return l.validateArrayData(value, v) + default: + // Unknown type - assume invalid + return false + } +} diff --git a/v2/pkg/engine/resolve/loader_skip_fetch_test.go b/v2/pkg/engine/resolve/loader_skip_fetch_test.go new file mode 100644 index 000000000..0d9a5c649 --- /dev/null +++ b/v2/pkg/engine/resolve/loader_skip_fetch_test.go @@ -0,0 +1,906 @@ +package resolve + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/wundergraph/astjson" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" +) + +func TestLoader_canSkipFetch(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + info *FetchInfo + items []*astjson.Value + wantResult bool + wantRemaining int // -1 means check for empty, otherwise check exact count + checkFn func(t *testing.T, remaining []*astjson.Value) // optional custom validation + }{ + { + name: "single item with Query operation", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"id": "123"}`)), + }, + wantResult: true, + wantRemaining: -1, // empty + }, + { + name: "single item with Mutation operation", + info: &FetchInfo{ + OperationType: ast.OperationTypeMutation, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"id": "123"}`)), + }, + wantResult: false, + wantRemaining: 1, + }, + { + name: "single item with null type", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{Fields: []*Field{}}, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`null`)), + }, + wantResult: true, + wantRemaining: 1, // null item remains + }, + { + name: "single item with all required data", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Path: []string{"user"}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: 
false, + }, + }, + { + Name: []byte("name"), + Value: &Scalar{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"user": {"id": "123", "name": "John"}}`)), + }, + wantResult: true, + wantRemaining: -1, // empty + }, + { + name: "single item missing required field", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Path: []string{"user"}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &Scalar{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"user": {"id": "123"}}`)), // missing "name" + }, + wantResult: false, + wantRemaining: 1, + }, + { + name: "single item missing nullable field", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Path: []string{"user"}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("email"), + Value: &Scalar{ + Path: []string{"email"}, + Nullable: true, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"user": {"id": "123"}}`)), // missing nullable "email" + }, + wantResult: false, + wantRemaining: 1, + }, + { + name: "single item with null value on required path", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Path: []string{"user"}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"user": {"id": null}}`)), // null value on required field + }, + wantResult: false, + wantRemaining: 1, + }, + { + name: "single item with null value on nullable path", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Path: []string{"user"}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("email"), + Value: &Scalar{ + Path: []string{"email"}, + Nullable: true, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"user": {"id": "123", "email": null}}`)), // null value on nullable field + }, + wantResult: true, + wantRemaining: -1, // empty + }, + { + name: "multiple items all can be skipped", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"id": "123"}`)), + astjson.MustParseBytes([]byte(`{"id": "456"}`)), + astjson.MustParseBytes([]byte(`{"id": "789"}`)), + }, + wantResult: true, + wantRemaining: -1, // empty + }, + { + name: "multiple items some can be skipped", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: 
[]*Field{ + { + Name: []byte("user"), + Value: &Object{ + Path: []string{"user"}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &Scalar{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"user": {"id": "123", "name": "John"}}`)), // complete + astjson.MustParseBytes([]byte(`{"user": {"id": "456"}}`)), // missing name + astjson.MustParseBytes([]byte(`{"user": {"id": "789", "name": "Alice"}}`)), // complete + }, + wantResult: false, + wantRemaining: 1, + checkFn: func(t *testing.T, remaining []*astjson.Value) { + // Check that the remaining item is the incomplete one + user := remaining[0].Get("user") + assert.Equal(t, "456", string(user.Get("id").GetStringBytes())) + }, + }, + { + name: "multiple items none can be skipped", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Path: []string{"user"}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &Scalar{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"user": {"id": "123"}}`)), // missing name + astjson.MustParseBytes([]byte(`{"user": {"id": "456"}}`)), // missing name + astjson.MustParseBytes([]byte(`{"user": {"id": "789"}}`)), // missing name + }, + wantResult: false, + wantRemaining: 3, + }, + { + name: "nullable array that is null", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Path: []string{"user"}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("tags"), + Value: &Array{ + Path: []string{"tags"}, + Nullable: true, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"user": {"id": "123", "tags": null}}`)), + }, + wantResult: true, + wantRemaining: -1, // empty + }, + { + name: "nullable array that is empty", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Path: []string{"user"}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("tags"), + Value: &Array{ + Path: []string{"tags"}, + Nullable: true, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"user": {"id": "123", "tags": []}}`)), + }, + wantResult: true, + wantRemaining: -1, // empty + }, + { + name: "deeply nested structure", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Path: []string{"user"}, + Nullable: true, + Fields: []*Field{ + { + Name: []byte("account"), + Value: &Object{ + Path: []string{"account"}, + Nullable: true, + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &Scalar{ + Path: []string{"__typename"}, + Nullable: false, + }, + }, + { + Name: []byte("id"), + Value: &Scalar{ + Path: 
[]string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("info"), + Value: &Object{ + Path: []string{"info"}, + Nullable: true, + Fields: []*Field{ + { + Name: []byte("a"), + Value: &Scalar{ + Path: []string{"a"}, + Nullable: false, + }, + }, + { + Name: []byte("b"), + Value: &Scalar{ + Path: []string{"b"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{ + "user": { + "account": { + "__typename": "Account", + "id": "123", + "info": { + "a": "valueA", + "b": "valueB" + } + } + } + }`)), + }, + wantResult: true, + wantRemaining: -1, // empty + }, + { + name: "nil info", + info: nil, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"id": "123"}`)), + }, + wantResult: false, + wantRemaining: 1, + }, + { + name: "nil ProvidesData", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: nil, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"id": "123"}`)), + }, + wantResult: false, + wantRemaining: 1, + }, + { + name: "array with scalar items - valid", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("tags"), + Value: &Array{ + Path: []string{"tags"}, + Nullable: false, + Item: &Scalar{ + Path: []string{}, + Nullable: false, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"tags": ["tag1", "tag2", "tag3"]}`)), + }, + wantResult: true, + wantRemaining: -1, // empty + }, + { + name: "array with scalar items - invalid (null item in non-nullable array)", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("tags"), + Value: &Array{ + Path: []string{"tags"}, + Nullable: false, + Item: &Scalar{ + Path: []string{}, + Nullable: false, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"tags": ["tag1", null, "tag3"]}`)), // null item in non-nullable array + }, + wantResult: false, + wantRemaining: 1, + }, + { + name: "array with scalar items - valid (null item in nullable array)", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("tags"), + Value: &Array{ + Path: []string{"tags"}, + Nullable: false, + Item: &Scalar{ + Path: []string{}, + Nullable: true, // nullable scalar items + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"tags": ["tag1", null, "tag3"]}`)), // null item in nullable array + }, + wantResult: true, + wantRemaining: -1, // empty + }, + { + name: "array with object items - valid", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("users"), + Value: &Array{ + Path: []string{"users"}, + Nullable: false, + Item: &Object{ + Path: []string{}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &Scalar{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"users": [{"id": "1", "name": "John"}, {"id": "2", "name": "Jane"}]}`)), + }, + wantResult: true, + wantRemaining: -1, // empty + }, + { + name: "array with object items - invalid (missing required field)", + info: &FetchInfo{ + OperationType: 
ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("users"), + Value: &Array{ + Path: []string{"users"}, + Nullable: false, + Item: &Object{ + Path: []string{}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &Scalar{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"users": [{"id": "1", "name": "John"}, {"id": "2"}]}`)), // missing "name" field + }, + wantResult: false, + wantRemaining: 1, + }, + { + name: "nested arrays - valid", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("matrix"), + Value: &Array{ + Path: []string{"matrix"}, + Nullable: false, + Item: &Array{ + Path: []string{}, + Nullable: false, + Item: &Scalar{ + Path: []string{}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"matrix": [["a", "b"], ["c", "d"], ["e", "f"]]}`)), + }, + wantResult: true, + wantRemaining: -1, // empty + }, + { + name: "nested arrays - invalid (null in inner non-nullable array)", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("matrix"), + Value: &Array{ + Path: []string{"matrix"}, + Nullable: false, + Item: &Array{ + Path: []string{}, + Nullable: false, + Item: &Scalar{ + Path: []string{}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"matrix": [["a", "b"], ["c", null], ["e", "f"]]}`)), // null in inner array + }, + wantResult: false, + wantRemaining: 1, + }, + { + name: "array of objects with nested arrays - complex valid case", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("groups"), + Value: &Array{ + Path: []string{"groups"}, + Nullable: false, + Item: &Object{ + Path: []string{}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("name"), + Value: &Scalar{ + Path: []string{"name"}, + Nullable: false, + }, + }, + { + Name: []byte("members"), + Value: &Array{ + Path: []string{"members"}, + Nullable: false, + Item: &Object{ + Path: []string{}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"groups": [{"name": "admins", "members": [{"id": "1"}, {"id": "2"}]}, {"name": "users", "members": [{"id": "3"}]}]}`)), + }, + wantResult: true, + wantRemaining: -1, // empty + }, + { + name: "array of objects with nested arrays - complex invalid case", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("groups"), + Value: &Array{ + Path: []string{"groups"}, + Nullable: false, + Item: &Object{ + Path: []string{}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("name"), + Value: &Scalar{ + Path: []string{"name"}, + Nullable: false, + }, + }, + { + Name: []byte("members"), + Value: &Array{ + Path: []string{"members"}, + Nullable: false, + Item: &Object{ + Path: []string{}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + 
}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"groups": [{"name": "admins", "members": [{"id": "1"}, {}]}, {"name": "users", "members": [{"id": "3"}]}]}`)), // missing id in one member + }, + wantResult: false, + wantRemaining: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + loader := &Loader{} + + // Make a copy of items to avoid mutation affecting the test data + itemsCopy := make([]*astjson.Value, len(tt.items)) + copy(itemsCopy, tt.items) + + remaining, result := loader.canSkipFetch(tt.info, itemsCopy) + + assert.Equal(t, tt.wantResult, result, "result mismatch") + + if tt.wantRemaining == -1 { + assert.Empty(t, remaining, "expected empty remaining items") + } else { + assert.Len(t, remaining, tt.wantRemaining, "remaining items count mismatch") + } + + if tt.checkFn != nil { + tt.checkFn(t, remaining) + } + }) + } +} From 101e813980b57f16a387941389bbc417caeee0eb Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Tue, 5 Aug 2025 23:45:38 +0200 Subject: [PATCH 06/61] chore: add cache config to resolve & gateway --- .../engine/federation_integration_test.go | 84 ++++++++++++------- .../federationtesting/gateway/gateway.go | 4 + execution/federationtesting/gateway/main.go | 4 +- .../create_concrete_single_fetch_types.go | 2 + v2/pkg/engine/resolve/resolve.go | 3 + 5 files changed, 66 insertions(+), 31 deletions(-) diff --git a/execution/engine/federation_integration_test.go b/execution/engine/federation_integration_test.go index 1b867bb37..4b0f702a1 100644 --- a/execution/engine/federation_integration_test.go +++ b/execution/engine/federation_integration_test.go @@ -18,13 +18,37 @@ import ( "github.com/sebdah/goldie/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "github.com/wundergraph/graphql-go-tools/execution/federationtesting" "github.com/wundergraph/graphql-go-tools/execution/federationtesting/gateway" products "github.com/wundergraph/graphql-go-tools/execution/federationtesting/products/graph" ) -func addGateway(enableART bool) func(setup *federationtesting.FederationSetup) *httptest.Server { +type gatewayOptions struct { + enableART bool + withLoaderCache map[string]resolve.LoaderCache +} + +func withEnableART(enableART bool) func(*gatewayOptions) { + return func(opts *gatewayOptions) { + opts.enableART = enableART + } +} + +func withLoaderCache(loaderCache map[string]resolve.LoaderCache) func(*gatewayOptions) { + return func(opts *gatewayOptions) { + opts.withLoaderCache = loaderCache + } +} + +type gatewayOptionsToFunc func(opts *gatewayOptions) + +func addGateway(options ...gatewayOptionsToFunc) func(setup *federationtesting.FederationSetup) *httptest.Server { + opts := &gatewayOptions{} + for _, option := range options { + option(opts) + } return func(setup *federationtesting.FederationSetup) *httptest.Server { httpClient := http.DefaultClient @@ -34,7 +58,7 @@ func addGateway(enableART bool) func(setup *federationtesting.FederationSetup) * {Name: "reviews", URL: setup.ReviewsUpstreamServer.URL}, }, httpClient) - gtw := gateway.Handler(abstractlogger.NoopLogger, poller, httpClient, enableART) + gtw := gateway.Handler(abstractlogger.NoopLogger, poller, httpClient, opts.enableART, opts.withLoaderCache) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() @@ -52,7 +76,7 @@ func TestFederationIntegrationTestWithArt(t *testing.T) { 
ctx, cancel := context.WithCancel(context.Background()) defer cancel() - setup := federationtesting.NewFederationSetup(addGateway(true)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(true))) defer setup.Close() gqlClient := NewGraphqlClient(http.DefaultClient) @@ -82,7 +106,7 @@ func TestFederationIntegrationTestWithArt(t *testing.T) { func TestFederationIntegrationTest(t *testing.T) { t.Run("single upstream query operation", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -92,7 +116,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("query spans multiple federated servers", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -102,7 +126,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("mutation operation with variables", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -116,7 +140,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("union query", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -126,7 +150,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("interface query", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -136,7 +160,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("subscription query through WebSocket transport", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -155,7 +179,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Multiple queries and nested fragments", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -205,7 +229,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Multiple queries with __typename", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := 
NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -237,7 +261,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Query that returns union", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -316,7 +340,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Object response type with interface and object fragment", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -335,7 +359,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Interface response type with object fragment", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -355,7 +379,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("recursive fragment", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -365,7 +389,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("empty fragment", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -375,7 +399,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("empty fragment variant", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -385,7 +409,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Union response type with interface fragments", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -429,7 +453,7 @@ func TestFederationIntegrationTest(t *testing.T) { // Duplicated properties (and therefore invalid JSON) are usually removed during normalization processes. // It is not yet decided whether this should be addressed before these normalization processes. 
t.Run("Complex nesting", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) defer setup.Close() gqlClient := NewGraphqlClient(http.DefaultClient) @@ -441,7 +465,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("More complex nesting", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) defer setup.Close() gqlClient := NewGraphqlClient(http.DefaultClient) @@ -453,7 +477,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Multiple nested interfaces", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) defer setup.Close() gqlClient := NewGraphqlClient(http.DefaultClient) @@ -465,7 +489,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Multiple nested unions", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) defer setup.Close() gqlClient := NewGraphqlClient(http.DefaultClient) @@ -477,7 +501,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("More complex nesting typename variant", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) defer setup.Close() gqlClient := NewGraphqlClient(http.DefaultClient) @@ -489,7 +513,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Abstract object", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) defer setup.Close() gqlClient := NewGraphqlClient(http.DefaultClient) @@ -501,7 +525,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Abstract object non shared", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) defer setup.Close() gqlClient := NewGraphqlClient(http.DefaultClient) @@ -513,7 +537,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Abstract object nested", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) defer setup.Close() gqlClient := NewGraphqlClient(http.DefaultClient) @@ -525,7 +549,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Abstract object nested reverse", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) defer setup.Close() gqlClient := NewGraphqlClient(http.DefaultClient) @@ -537,7 +561,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Abstract object mixed", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) defer setup.Close() gqlClient := NewGraphqlClient(http.DefaultClient) @@ -549,7 +573,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Abstract interface field", func(t *testing.T) { - setup := 
federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) defer setup.Close() gqlClient := NewGraphqlClient(http.DefaultClient) @@ -561,7 +585,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Merged fields are still resolved", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) diff --git a/execution/federationtesting/gateway/gateway.go b/execution/federationtesting/gateway/gateway.go index ffc62eb7d..728736a36 100644 --- a/execution/federationtesting/gateway/gateway.go +++ b/execution/federationtesting/gateway/gateway.go @@ -34,11 +34,13 @@ func NewGateway( gqlHandlerFactory HandlerFactory, httpClient *http.Client, logger log.Logger, + loaderCaches map[string]resolve.LoaderCache, ) *Gateway { return &Gateway{ gqlHandlerFactory: gqlHandlerFactory, httpClient: httpClient, logger: logger, + loaderCaches: loaderCaches, mu: &sync.Mutex{}, readyCh: make(chan struct{}), @@ -50,6 +52,7 @@ type Gateway struct { gqlHandlerFactory HandlerFactory httpClient *http.Client logger log.Logger + loaderCaches map[string]resolve.LoaderCache gqlHandler http.Handler mu *sync.Mutex @@ -82,6 +85,7 @@ func (g *Gateway) UpdateDataSources(subgraphsConfigs []engine.SubgraphConfigurat executionEngine, err := engine.NewExecutionEngine(ctx, g.logger, engineConfig, resolve.ResolverOptions{ MaxConcurrency: 1024, + Caches: g.loaderCaches, }) if err != nil { g.logger.Error("create engine: %v", log.Error(err)) diff --git a/execution/federationtesting/gateway/main.go b/execution/federationtesting/gateway/main.go index 61f97b0a1..39da34d0f 100644 --- a/execution/federationtesting/gateway/main.go +++ b/execution/federationtesting/gateway/main.go @@ -10,6 +10,7 @@ import ( "github.com/wundergraph/graphql-go-tools/execution/engine" http2 "github.com/wundergraph/graphql-go-tools/execution/federationtesting/gateway/http" "github.com/wundergraph/graphql-go-tools/execution/graphql" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) func NewDatasource(serviceConfig []ServiceConfig, httpClient *http.Client) *DatasourcePollerPoller { @@ -24,6 +25,7 @@ func Handler( datasourcePoller *DatasourcePollerPoller, httpClient *http.Client, enableART bool, + loaderCaches map[string]resolve.LoaderCache, ) *Gateway { upgrader := &ws.DefaultHTTPUpgrader upgrader.Header = http.Header{} @@ -35,7 +37,7 @@ func Handler( return http2.NewGraphqlHTTPHandler(schema, engine, upgrader, logger, enableART) } - gateway := NewGateway(gqlHandlerFactory, httpClient, logger) + gateway := NewGateway(gqlHandlerFactory, httpClient, logger, loaderCaches) datasourceWatcher.Register(gateway) diff --git a/v2/pkg/engine/postprocess/create_concrete_single_fetch_types.go b/v2/pkg/engine/postprocess/create_concrete_single_fetch_types.go index c8a0a17e1..2adfd6e38 100644 --- a/v2/pkg/engine/postprocess/create_concrete_single_fetch_types.go +++ b/v2/pkg/engine/postprocess/create_concrete_single_fetch_types.go @@ -107,6 +107,7 @@ func (d *createConcreteSingleFetchTypes) createEntityBatchFetch(fetch *resolve.S }, DataSource: fetch.DataSource, PostProcessing: fetch.PostProcessing, + Caching: fetch.Caching, } } @@ -141,5 +142,6 @@ func (d *createConcreteSingleFetchTypes) createEntityFetch(fetch *resolve.Single }, DataSource: 
fetch.DataSource, PostProcessing: fetch.PostProcessing, + Caching: fetch.Caching, } } diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 0d0bed5f5..204b0b734 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -149,6 +149,8 @@ type ResolverOptions struct { MaxSubscriptionFetchTimeout time.Duration // ApolloRouterCompatibilitySubrequestHTTPError is a compatibility flag for Apollo Router, it is used to handle HTTP errors in subrequests differently ApolloRouterCompatibilitySubrequestHTTPError bool + + Caches map[string]LoaderCache } // New returns a new Resolver, ctx.Done() is used to cancel all active subscriptions & streams @@ -231,6 +233,7 @@ func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{ allowedSubgraphErrorFields: allowedErrorFields, allowAllErrorExtensionFields: options.AllowAllErrorExtensionFields, apolloRouterCompatibilitySubrequestHTTPError: options.ApolloRouterCompatibilitySubrequestHTTPError, + caches: options.Caches, }, } } From b57a86e03cbf511bf71fa6986039b6450bb4fcca Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Tue, 5 Aug 2025 23:45:55 +0200 Subject: [PATCH 07/61] chore: add federation cache test --- execution/engine/federation_caching_test.go | 297 ++++++++++++++++++++ 1 file changed, 297 insertions(+) create mode 100644 execution/engine/federation_caching_test.go diff --git a/execution/engine/federation_caching_test.go b/execution/engine/federation_caching_test.go new file mode 100644 index 000000000..3eb61f977 --- /dev/null +++ b/execution/engine/federation_caching_test.go @@ -0,0 +1,297 @@ +package engine_test + +import ( + "context" + "fmt" + "net/http" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/graphql-go-tools/execution/federationtesting" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" +) + +func TestFederationCaching(t *testing.T) { + t.Run("query spans multiple federated servers", func(t *testing.T) { + defaultCache := NewFakeLoaderCache() + caches := map[string]resolve.LoaderCache{ + "default": defaultCache, + } + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false), withLoaderCache(caches))) + t.Cleanup(setup.Close) + gqlClient := NewGraphqlClient(http.DefaultClient) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + resp := gqlClient.Query(ctx, setup.GatewayServer.URL, testQueryPath("queries/multiple_upstream.query"), nil, t) + assert.Equal(t, `{"data":{"topProducts":[{"name":"Trilby","reviews":[{"body":"A highly effective form of birth control.","author":{"username":"Me"}}]},{"name":"Fedora","reviews":[{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","author":{"username":"Me"}}]}]}}`, string(resp)) + resp = gqlClient.Query(ctx, setup.GatewayServer.URL, testQueryPath("queries/multiple_upstream.query"), nil, t) + assert.Equal(t, `{"data":{"topProducts":[{"name":"Trilby","reviews":[{"body":"A highly effective form of birth control.","author":{"username":"Me"}}]},{"name":"Fedora","reviews":[{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","author":{"username":"Me"}}]}]}}`, string(resp)) + defaultCache.mu.Lock() + defer defaultCache.mu.Unlock() + _, ok := defaultCache.storage[`{"__typename":"Product","upc":"top-1"}`] + assert.True(t, ok) + _, ok = 
defaultCache.storage[`{"__typename":"Product","upc":"top-2"}`] + assert.True(t, ok) + }) +} + +type cacheEntry struct { + data []byte + expiresAt *time.Time +} + +type FakeLoaderCache struct { + mu sync.RWMutex + storage map[string]cacheEntry +} + +func NewFakeLoaderCache() *FakeLoaderCache { + return &FakeLoaderCache{ + storage: make(map[string]cacheEntry), + } +} + +func (f *FakeLoaderCache) cleanupExpired() { + now := time.Now() + for key, entry := range f.storage { + if entry.expiresAt != nil && now.After(*entry.expiresAt) { + delete(f.storage, key) + } + } +} + +func (f *FakeLoaderCache) Get(ctx context.Context, keys []string) ([][]byte, error) { + f.mu.Lock() + defer f.mu.Unlock() + + // Clean up expired entries before executing command + f.cleanupExpired() + + result := make([][]byte, len(keys)) + for i, key := range keys { + if entry, exists := f.storage[key]; exists { + // Make a copy of the data to prevent external modifications + dataCopy := make([]byte, len(entry.data)) + copy(dataCopy, entry.data) + result[i] = dataCopy + } else { + result[i] = nil + } + } + return result, nil +} + +func (f *FakeLoaderCache) Set(ctx context.Context, keys []string, items [][]byte, ttl time.Duration) error { + if len(keys) != len(items) { + return nil // Silently ignore mismatched lengths like Redis would + } + + f.mu.Lock() + defer f.mu.Unlock() + + // Clean up expired entries before executing command + f.cleanupExpired() + + for i, key := range keys { + entry := cacheEntry{ + // Make a copy of the data to prevent external modifications + data: make([]byte, len(items[i])), + } + copy(entry.data, items[i]) + + // If ttl is 0, store without expiration + if ttl > 0 { + expiresAt := time.Now().Add(ttl) + entry.expiresAt = &expiresAt + } + + f.storage[key] = entry + } + return nil +} + +func (f *FakeLoaderCache) Delete(ctx context.Context, keys []string) error { + f.mu.Lock() + defer f.mu.Unlock() + + // Clean up expired entries before executing command + f.cleanupExpired() + + for _, key := range keys { + delete(f.storage, key) + } + return nil +} + +// TestFakeLoaderCache tests the cache implementation itself +func TestFakeLoaderCache(t *testing.T) { + ctx := context.Background() + cache := NewFakeLoaderCache() + + t.Run("SetAndGet", func(t *testing.T) { + // Test basic set and get + keys := []string{"key1", "key2", "key3"} + items := [][]byte{[]byte("value1"), []byte("value2"), []byte("value3")} + + err := cache.Set(ctx, keys, items, 0) // No TTL + require.NoError(t, err) + + // Get all keys + result, err := cache.Get(ctx, keys) + require.NoError(t, err) + require.Len(t, result, 3) + assert.Equal(t, "value1", string(result[0])) + assert.Equal(t, "value2", string(result[1])) + assert.Equal(t, "value3", string(result[2])) + + // Get partial keys + result, err = cache.Get(ctx, []string{"key2", "key4", "key1"}) + require.NoError(t, err) + require.Len(t, result, 3) + assert.Equal(t, "value2", string(result[0])) + assert.Nil(t, result[1]) // key4 doesn't exist + assert.Equal(t, "value1", string(result[2])) + }) + + t.Run("Delete", func(t *testing.T) { + // Set some keys + keys := []string{"del1", "del2", "del3"} + items := [][]byte{[]byte("v1"), []byte("v2"), []byte("v3")} + err := cache.Set(ctx, keys, items, 0) + require.NoError(t, err) + + // Delete some keys + err = cache.Delete(ctx, []string{"del1", "del3"}) + require.NoError(t, err) + + // Check remaining keys + result, err := cache.Get(ctx, keys) + require.NoError(t, err) + assert.Nil(t, result[0]) // del1 was deleted + assert.Equal(t, "v2", 
string(result[1])) // del2 still exists + assert.Nil(t, result[2]) // del3 was deleted + }) + + t.Run("TTL", func(t *testing.T) { + // Set with 50ms TTL + keys := []string{"ttl1", "ttl2"} + items := [][]byte{[]byte("expire1"), []byte("expire2")} + err := cache.Set(ctx, keys, items, 50*time.Millisecond) + require.NoError(t, err) + + // Immediately get - should exist + result, err := cache.Get(ctx, keys) + require.NoError(t, err) + assert.Equal(t, "expire1", string(result[0])) + assert.Equal(t, "expire2", string(result[1])) + + // Wait for expiration + time.Sleep(60 * time.Millisecond) + + // Get again - should be nil + result, err = cache.Get(ctx, keys) + require.NoError(t, err) + assert.Nil(t, result[0]) + assert.Nil(t, result[1]) + }) + + t.Run("MixedTTL", func(t *testing.T) { + // Set some with TTL, some without + err := cache.Set(ctx, []string{"perm1"}, [][]byte{[]byte("permanent")}, 0) + require.NoError(t, err) + + err = cache.Set(ctx, []string{"temp1"}, [][]byte{[]byte("temporary")}, 50*time.Millisecond) + require.NoError(t, err) + + // Wait for temporary to expire + time.Sleep(60 * time.Millisecond) + + // Check both + result, err := cache.Get(ctx, []string{"perm1", "temp1"}) + require.NoError(t, err) + assert.Equal(t, "permanent", string(result[0])) // Still exists + assert.Nil(t, result[1]) // Expired + }) + + t.Run("ThreadSafety", func(t *testing.T) { + // Test concurrent access + done := make(chan bool) + + // Writer goroutine + go func() { + for i := 0; i < 100; i++ { + key := fmt.Sprintf("concurrent_%d", i) + value := fmt.Sprintf("value_%d", i) + err := cache.Set(ctx, []string{key}, [][]byte{[]byte(value)}, 0) + assert.NoError(t, err) + } + done <- true + }() + + // Reader goroutine + go func() { + for i := 0; i < 100; i++ { + key := fmt.Sprintf("concurrent_%d", i%50) + _, err := cache.Get(ctx, []string{key}) + assert.NoError(t, err) + } + done <- true + }() + + // Deleter goroutine + go func() { + for i := 0; i < 50; i++ { + key := fmt.Sprintf("concurrent_%d", i*2) + err := cache.Delete(ctx, []string{key}) + assert.NoError(t, err) + } + done <- true + }() + + // Wait for all goroutines + <-done + <-done + <-done + }) + + t.Run("ResultLengthMatchesKeysLength", func(t *testing.T) { + // Test that result length always matches input keys length + + // Set some data + err := cache.Set(ctx, []string{"exist1", "exist3"}, [][]byte{[]byte("data1"), []byte("data3")}, 0) + require.NoError(t, err) + + // Request mix of existing and non-existing keys + keys := []string{"exist1", "missing1", "exist3", "missing2", "missing3"} + result, err := cache.Get(ctx, keys) + require.NoError(t, err) + + // Verify length matches exactly + assert.Len(t, result, len(keys), "Result length must match keys length") + assert.Len(t, result, 5, "Should return exactly 5 results") + + // Verify correct values + assert.Equal(t, "data1", string(result[0])) // exist1 + assert.Nil(t, result[1]) // missing1 + assert.Equal(t, "data3", string(result[2])) // exist3 + assert.Nil(t, result[3]) // missing2 + assert.Nil(t, result[4]) // missing3 + + // Test with all missing keys + allMissingKeys := []string{"missing4", "missing5", "missing6"} + result, err = cache.Get(ctx, allMissingKeys) + require.NoError(t, err) + assert.Len(t, result, 3, "Should return 3 results for 3 keys") + assert.Nil(t, result[0]) + assert.Nil(t, result[1]) + assert.Nil(t, result[2]) + + // Test with empty keys + result, err = cache.Get(ctx, []string{}) + require.NoError(t, err) + assert.Len(t, result, 0, "Should return empty slice for empty keys") + 
}) +} From 057e5571c33cc30c957c966e5947ab89bf5595a3 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Tue, 5 Aug 2025 23:46:36 +0200 Subject: [PATCH 08/61] chore: add cache config to fetch & loader --- .../graphql_datasource/graphql_datasource.go | 9 +- v2/pkg/engine/plan/visitor.go | 18 +++ v2/pkg/engine/resolve/fetch.go | 18 +++ v2/pkg/engine/resolve/loader.go | 153 +++++++++++++++++- 4 files changed, 189 insertions(+), 9 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go index 6f35dbee3..d357fcf95 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go @@ -843,12 +843,13 @@ func (p *Planner[T]) addRepresentationsVariable() { return } - variable, _ := p.variables.AddVariable(p.buildRepresentationsVariable()) + representationsVariable := resolve.NewResolvableObjectVariable(p.buildRepresentationsVariable()) + variable, _ := p.variables.AddVariable(representationsVariable) p.upstreamVariables, _ = sjson.SetRawBytes(p.upstreamVariables, "representations", []byte(fmt.Sprintf("[%s]", variable))) } -func (p *Planner[T]) buildRepresentationsVariable() resolve.Variable { +func (p *Planner[T]) buildRepresentationsVariable() *resolve.Object { objects := make([]*resolve.Object, 0, len(p.dataSourcePlannerConfig.RequiredFields)) for _, cfg := range p.dataSourcePlannerConfig.RequiredFields { node, err := buildRepresentationVariableNode(p.visitor.Definition, cfg, p.dataSourceConfig.FederationConfiguration()) @@ -860,9 +861,7 @@ func (p *Planner[T]) buildRepresentationsVariable() resolve.Variable { objects = append(objects, node) } - return resolve.NewResolvableObjectVariable( - mergeRepresentationVariableNodes(objects), - ) + return mergeRepresentationVariableNodes(objects) } func (p *Planner[T]) addRepresentationsQuery() { diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index 7b6889d18..295b09bea 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -7,6 +7,7 @@ import ( "regexp" "slices" "strings" + "time" "github.com/wundergraph/astjson" @@ -1633,6 +1634,23 @@ func (v *Visitor) configureFetch(internal *objectFetchConfiguration, external re dataSourceType := reflect.TypeOf(external.DataSource).String() dataSourceType = strings.TrimPrefix(dataSourceType, "*") + cacheKeyTemplate := &resolve.InputTemplate{ + SetTemplateOutputToNullOnVariableNull: false, + Segments: make([]resolve.TemplateSegment, len(external.Variables)), + } + + for i, variable := range external.Variables { + segment := variable.TemplateSegment() + cacheKeyTemplate.Segments[i] = segment + } + + external.Caching = resolve.FetchCacheConfiguration{ + Enabled: true, + CacheName: "default", + TTL: time.Second * time.Duration(30), + CacheKeyTemplate: cacheKeyTemplate, + } + singleFetch := &resolve.SingleFetch{ FetchConfiguration: external, FetchDependencies: resolve.FetchDependencies{ diff --git a/v2/pkg/engine/resolve/fetch.go b/v2/pkg/engine/resolve/fetch.go index 448b9fac5..fc4f84a10 100644 --- a/v2/pkg/engine/resolve/fetch.go +++ b/v2/pkg/engine/resolve/fetch.go @@ -4,6 +4,7 @@ import ( "encoding/json" "slices" "strings" + "time" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" ) @@ -168,6 +169,7 @@ type BatchEntityFetch struct { Trace *DataSourceLoadTrace Info *FetchInfo CoordinateDependencies []FetchDependency + Caching FetchCacheConfiguration } func (b *BatchEntityFetch) 
Dependencies() *FetchDependencies { @@ -215,6 +217,7 @@ type EntityFetch struct { DataSourceIdentifier []byte Trace *DataSourceLoadTrace Info *FetchInfo + Caching FetchCacheConfiguration } func (e *EntityFetch) Dependencies() *FetchDependencies { @@ -325,6 +328,8 @@ type FetchConfiguration struct { // OperationName is non-empty when the operation name is propagated the downstream subgraph fetch. OperationName string + + Caching FetchCacheConfiguration } func (fc *FetchConfiguration) Equals(other *FetchConfiguration) bool { @@ -360,6 +365,19 @@ func (fc *FetchConfiguration) Equals(other *FetchConfiguration) bool { return true } +type FetchCacheConfiguration struct { + // Enabled indicates if caching is enabled for this fetch + Enabled bool + // CacheName is the name of the cache to use for this fetch + CacheName string + // TTL is the time to live which will be set for new cache entries + TTL time.Duration + // CacheKeyTemplate can be used to render a cache key for the fetch. + // In case of a root fetch, the variables will be one or more field arguments + // For entity fetches, the variables will be a single Object Variable with @key and @requires fields + CacheKeyTemplate *InputTemplate +} + // FetchDependency explains how a GraphCoordinate depends on other GraphCoordinates from other fetches type FetchDependency struct { // Coordinate is the type+field which depends on one or more FetchDependencyOrigin diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 5abe2e4bc..3efe4d6e6 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -114,6 +114,13 @@ type result struct { loaderHookContext context.Context httpResponseContext *httpclient.ResponseContext + + cache LoaderCache + cacheMustBeUpdated bool + cacheKeys []string + cacheItems []*astjson.Value + cacheTTL time.Duration + cacheSkippedFetch bool } func (r *result) init(postProcessing PostProcessingConfiguration, info *FetchInfo) { @@ -135,6 +142,8 @@ type Loader struct { ctx *Context info *GraphQLResponseInfo + caches map[string]LoaderCache + propagateSubgraphErrors bool propagateSubgraphStatusCodes bool subgraphErrorPropagationMode SubgraphErrorPropagationMode @@ -251,9 +260,15 @@ func (l *Loader) resolveSingle(item *FetchItem) error { res := &result{ out: &bytes.Buffer{}, } - err := l.loadSingleFetch(l.ctx.ctx, f, item, items, res) + skip, err := l.tryCacheLoadFetch(l.ctx.ctx, f.Info, f.Caching, items, res) if err != nil { - return err + return errors.WithStack(err) + } + if !skip { + err = l.loadSingleFetch(l.ctx.ctx, f, item, items, res) + if err != nil { + return err + } } err = l.mergeResult(item, res, items) if l.ctx.LoaderHooks != nil { @@ -265,10 +280,16 @@ func (l *Loader) resolveSingle(item *FetchItem) error { res := &result{ out: &bytes.Buffer{}, } - err := l.loadBatchEntityFetch(l.ctx.ctx, item, f, items, res) + skip, err := l.tryCacheLoadFetch(l.ctx.ctx, f.Info, f.Caching, items, res) if err != nil { return errors.WithStack(err) } + if !skip { + err = l.loadBatchEntityFetch(l.ctx.ctx, item, f, items, res) + if err != nil { + return errors.WithStack(err) + } + } err = l.mergeResult(item, res, items) if l.ctx.LoaderHooks != nil { l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.ds, newResponseInfo(res, l.ctx.subgraphErrors)) @@ -278,10 +299,16 @@ func (l *Loader) resolveSingle(item *FetchItem) error { res := &result{ out: &bytes.Buffer{}, } - err := l.loadEntityFetch(l.ctx.ctx, item, f, items, res) + skip, err := l.tryCacheLoadFetch(l.ctx.ctx, f.Info, 
f.Caching, items, res) if err != nil { return errors.WithStack(err) } + if !skip { + err = l.loadEntityFetch(l.ctx.ctx, item, f, items, res) + if err != nil { + return errors.WithStack(err) + } + } err = l.mergeResult(item, res, items) if l.ctx.LoaderHooks != nil { l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.ds, newResponseInfo(res, l.ctx.subgraphErrors)) @@ -415,10 +442,81 @@ func (l *Loader) itemsData(items []*astjson.Value) *astjson.Value { return arr } +type LoaderCache interface { + Get(ctx context.Context, keys []string) ([][]byte, error) + Set(ctx context.Context, keys []string, items [][]byte, ttl time.Duration) error + Delete(ctx context.Context, keys []string) error +} + +func (l *Loader) tryCacheLoadFetch(ctx context.Context, info *FetchInfo, cfg FetchCacheConfiguration, inputItems []*astjson.Value, res *result) (skipFetch bool, err error) { + if !cfg.Enabled { + return false, nil + } + if cfg.CacheKeyTemplate == nil { + return false, nil + } + if l.caches == nil { + return false, nil + } + res.cache = l.caches[cfg.CacheName] + if res.cache == nil { + return false, nil + } + res.cacheKeys = make([]string, 0, len(inputItems)) + buf := &bytes.Buffer{} + for _, item := range inputItems { + err = cfg.CacheKeyTemplate.Render(l.ctx, item, buf) + if err != nil { + return false, err + } + if buf.Len() == 0 { + // If the cache key is empty, we skip the cache + continue + } + res.cacheKeys = append(res.cacheKeys, buf.String()) + buf.Reset() + } + if len(res.cacheKeys) == 0 { + // If no cache keys were generated, we skip the cache + return false, nil + } + cachedItems, err := res.cache.Get(ctx, res.cacheKeys) + if err != nil { + return false, err + } + res.cacheItems = make([]*astjson.Value, len(cachedItems)) + for i := range cachedItems { + if cachedItems[i] == nil { + res.cacheItems[i] = astjson.NullValue + continue + } + res.cacheItems[i], err = astjson.ParseBytesWithoutCache(cachedItems[i]) + if err != nil { + return false, errors.WithStack(err) + } + } + missing, canSkip := l.canSkipFetch(info, res.cacheItems) + if canSkip { + res.cacheSkippedFetch = true + return true, nil + } + res.cacheMustBeUpdated = true + res.cacheTTL = cfg.TTL + _ = missing + return false, nil +} + func (l *Loader) loadFetch(ctx context.Context, fetch Fetch, fetchItem *FetchItem, items []*astjson.Value, res *result) error { switch f := fetch.(type) { case *SingleFetch: res.out = &bytes.Buffer{} + skip, err := l.tryCacheLoadFetch(ctx, f.Info, f.Caching, items, res) + if err != nil { + return errors.WithStack(err) + } + if skip { + return nil + } return l.loadSingleFetch(ctx, f, fetchItem, items, res) case *ParallelListItemFetch: results := make([]*result, len(items)) @@ -451,9 +549,23 @@ func (l *Loader) loadFetch(ctx context.Context, fetch Fetch, fetchItem *FetchIte return nil case *EntityFetch: res.out = &bytes.Buffer{} + skip, err := l.tryCacheLoadFetch(ctx, f.Info, f.Caching, items, res) + if err != nil { + return errors.WithStack(err) + } + if skip { + return nil + } return l.loadEntityFetch(ctx, fetchItem, f, items, res) case *BatchEntityFetch: res.out = &bytes.Buffer{} + skip, err := l.tryCacheLoadFetch(ctx, f.Info, f.Caching, items, res) + if err != nil { + return errors.WithStack(err) + } + if skip { + return nil + } return l.loadBatchEntityFetch(ctx, fetchItem, f, items, res) } return nil @@ -513,12 +625,24 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson } return nil } + if res.cacheSkippedFetch { + for i, item := range res.cacheItems { + _, _, err := 
astjson.MergeValues(items[i], item) + if err != nil { + return l.renderErrorsFailedToFetch(fetchItem, res, "invalid cache item") + } + } + return nil + } if res.fetchSkipped { return nil } if res.out.Len() == 0 { return l.renderErrorsFailedToFetch(fetchItem, res, emptyGraphQLResponse) } + if res.cacheMustBeUpdated { + defer l.updateCache(res, items) + } value, err := astjson.ParseBytesWithoutCache(res.out.Bytes()) if err != nil { // Fall back to status code if parsing fails and non-2XX @@ -658,6 +782,27 @@ func (l *Loader) renderErrorsInvalidInput(fetchItem *FetchItem, out *bytes.Buffe return nil } +func (l *Loader) updateCache(res *result, items []*astjson.Value) { + if res.cache == nil || len(res.cacheKeys) == 0 || len(res.cacheItems) == 0 { + return + } + var ( + keys []string + cacheItems [][]byte + ) + for i, item := range res.cacheItems { + if item != nil && item.Type() == astjson.TypeNull && items[i] != nil && items[i].Type() != astjson.TypeNull { + keys = append(keys, res.cacheKeys[i]) + value := items[i].MarshalTo(nil) + cacheItems = append(cacheItems, value) + } + } + err := res.cache.Set(context.Background(), keys, cacheItems, res.cacheTTL) + if err != nil { + panic(err) + } +} + func (l *Loader) appendSubgraphError(res *result, fetchItem *FetchItem, value *astjson.Value, values []*astjson.Value) error { // print them into the buffer to be able to parse them errorsJSON := value.MarshalTo(nil) From b585cd4bb4184dd9ed72299fe9f70358e044266c Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Wed, 17 Sep 2025 12:13:56 +0200 Subject: [PATCH 09/61] chore: fix tests --- .../graphql_datasource_federation_test.go | 165 +++++++++++++++++- .../graphql_datasource_test.go | 2 +- .../datasourcetesting/datasourcetesting.go | 28 +++ v2/pkg/engine/plan/configuration.go | 5 + v2/pkg/engine/plan/planner_test.go | 13 +- v2/pkg/engine/plan/visitor.go | 47 ++--- 6 files changed, 235 insertions(+), 25 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go index f4a41bedc..3968c9db9 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go @@ -2,6 +2,7 @@ package graphql_datasource import ( "testing" + "time" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" . 
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasourcetesting" @@ -1557,6 +1558,14 @@ func TestGraphQLDataSourceFederation(t *testing.T) { Input: `{"method":"POST","url":"http://user.service","body":{"query":"{user {account {__typename id info {a b}}}}"}}`, DataSource: &Source{}, PostProcessing: DefaultPostProcessingConfiguration, + Caching: resolve.FetchCacheConfiguration{ + Enabled: true, + CacheName: "default", + TTL: time.Second * 30, + CacheKeyTemplate: &resolve.InputTemplate{ + Segments: []resolve.TemplateSegment{}, + }, + }, }, Info: &resolve.FetchInfo{ DataSourceID: "user.service", @@ -1644,6 +1653,106 @@ func TestGraphQLDataSourceFederation(t *testing.T) { HasAuthorizationRule: true, }, }, + CoordinateDependencies: []resolve.FetchDependency{ + { + Coordinate: resolve.GraphCoordinate{ + TypeName: "Account", + FieldName: "name", + }, + IsUserRequested: true, + DependsOn: []resolve.FetchDependencyOrigin{ + { + FetchID: 0, + Subgraph: "user.service", + Coordinate: resolve.GraphCoordinate{ + TypeName: "Account", + FieldName: "id", + }, + IsKey: true, + IsRequires: false, + }, + { + FetchID: 0, + Subgraph: "user.service", + Coordinate: resolve.GraphCoordinate{ + TypeName: "Account", + FieldName: "info", + }, + IsKey: true, + IsRequires: false, + }, + { + FetchID: 0, + Subgraph: "user.service", + Coordinate: resolve.GraphCoordinate{ + TypeName: "Info", + FieldName: "a", + }, + IsKey: true, + IsRequires: false, + }, + { + FetchID: 0, + Subgraph: "user.service", + Coordinate: resolve.GraphCoordinate{ + TypeName: "Info", + FieldName: "b", + }, + IsKey: true, + IsRequires: false, + }, + }, + }, + { + Coordinate: resolve.GraphCoordinate{ + TypeName: "Account", + FieldName: "shippingInfo", + }, + IsUserRequested: true, + DependsOn: []resolve.FetchDependencyOrigin{ + { + FetchID: 0, + Subgraph: "user.service", + Coordinate: resolve.GraphCoordinate{ + TypeName: "Account", + FieldName: "id", + }, + IsKey: true, + IsRequires: false, + }, + { + FetchID: 0, + Subgraph: "user.service", + Coordinate: resolve.GraphCoordinate{ + TypeName: "Account", + FieldName: "info", + }, + IsKey: true, + IsRequires: false, + }, + { + FetchID: 0, + Subgraph: "user.service", + Coordinate: resolve.GraphCoordinate{ + TypeName: "Info", + FieldName: "a", + }, + IsKey: true, + IsRequires: false, + }, + { + FetchID: 0, + Subgraph: "user.service", + Coordinate: resolve.GraphCoordinate{ + TypeName: "Info", + FieldName: "b", + }, + IsKey: true, + IsRequires: false, + }, + }, + }, + }, OperationType: ast.OperationTypeQuery, ProvidesData: &resolve.Object{ Fields: []*resolve.Field{ @@ -1731,6 +1840,60 @@ func TestGraphQLDataSourceFederation(t *testing.T) { }, }, PostProcessing: SingleEntityPostProcessingConfiguration, + Caching: resolve.FetchCacheConfiguration{ + Enabled: true, + CacheName: "default", + TTL: time.Second * 30, + CacheKeyTemplate: &resolve.InputTemplate{ + Segments: []resolve.TemplateSegment{ + { + SegmentType: resolve.VariableSegmentType, + VariableKind: resolve.ResolvableObjectVariableKind, + Renderer: resolve.NewGraphQLVariableResolveRenderer(&resolve.Object{ + Nullable: true, + Fields: []*resolve.Field{ + { + Name: []byte("__typename"), + OnTypeNames: [][]byte{[]byte("Account")}, + Value: &resolve.String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("id"), + OnTypeNames: [][]byte{[]byte("Account")}, + Value: &resolve.Scalar{ + Path: []string{"id"}, + }, + }, + { + Name: []byte("info"), + OnTypeNames: [][]byte{[]byte("Account")}, + Value: &resolve.Object{ + Path: []string{"info"}, + 
Nullable: true, + Fields: []*resolve.Field{ + { + Name: []byte("a"), + Value: &resolve.Scalar{ + Path: []string{"a"}, + }, + }, + { + Name: []byte("b"), + Value: &resolve.Scalar{ + Path: []string{"b"}, + }, + }, + }, + }, + }, + }, + }), + }, + }, + }, + }, }, }, "user.account", resolve.ObjectPath("user"), resolve.ObjectPath("account")), ), @@ -1865,7 +2028,7 @@ func TestGraphQLDataSourceFederation(t *testing.T) { }, }, }, - planConfiguration, WithFieldInfo(), WithDefaultPostProcessor())) + planConfiguration, WithFieldInfo(), WithDefaultPostProcessor(), WithFieldDependencies(), WithEntityCaching(), WithFetchProvidesData())) }) t.Run("composite keys variant", func(t *testing.T) { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go index ff02fb4b3..5ad4dbf87 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go @@ -783,7 +783,7 @@ func TestGraphQLDataSource(t *testing.T) { }, }, DisableResolveFieldPositions: true, - }, WithFieldInfo(), WithDefaultPostProcessor())) + }, WithFieldInfo(), WithDefaultPostProcessor(), WithFetchProvidesData())) t.Run("selections on interface type", RunTest(interfaceSelectionSchema, ` query MyQuery { diff --git a/v2/pkg/engine/datasourcetesting/datasourcetesting.go b/v2/pkg/engine/datasourcetesting/datasourcetesting.go index 598766644..66495b615 100644 --- a/v2/pkg/engine/datasourcetesting/datasourcetesting.go +++ b/v2/pkg/engine/datasourcetesting/datasourcetesting.go @@ -34,6 +34,8 @@ type testOptions struct { withPrintPlan bool withFieldDependencies bool withFetchReasons bool + withEntityCaching bool + withFetchProvidesData bool } func WithPostProcessors(postProcessors ...*postprocess.Processor) func(*testOptions) { @@ -84,6 +86,22 @@ func WithFetchReasons() func(*testOptions) { } } +func WithEntityCaching() func(*testOptions) { + return func(o *testOptions) { + o.withFieldInfo = true + o.withFieldDependencies = true + o.withEntityCaching = true + } +} + +func WithFetchProvidesData() func(*testOptions) { + return func(o *testOptions) { + o.withFieldInfo = true + o.withFieldDependencies = true + o.withFetchProvidesData = true + } +} + func RunWithPermutations(t *testing.T, definition, operation, operationName string, expectedPlan plan.Plan, config plan.Configuration, options ...func(*testOptions)) { t.Helper() @@ -143,6 +161,8 @@ func RunTestWithVariables(definition, operation, operationName, variables string // by default, we don't want to have field info in the tests because it's too verbose config.DisableIncludeInfo = true config.DisableIncludeFieldDependencies = true + config.DisableEntityCaching = true + config.DisableFetchProvidesData = true opts := &testOptions{} for _, o := range options { @@ -161,6 +181,14 @@ func RunTestWithVariables(definition, operation, operationName, variables string config.BuildFetchReasons = true } + if opts.withEntityCaching { + config.DisableEntityCaching = false + } + + if opts.withFetchProvidesData { + config.DisableFetchProvidesData = false + } + if opts.skipReason != "" { t.Skip(opts.skipReason) } diff --git a/v2/pkg/engine/plan/configuration.go b/v2/pkg/engine/plan/configuration.go index 215bbbcbd..9d64934ee 100644 --- a/v2/pkg/engine/plan/configuration.go +++ b/v2/pkg/engine/plan/configuration.go @@ -39,6 +39,11 @@ type Configuration struct { // It may be enabled by some other components of the engine. 
// It requires DisableIncludeInfo and DisableIncludeFieldDependencies set to false. BuildFetchReasons bool + + // DisableEntityCaching disables planning of entity caching behavior or generating relevant metadata + DisableEntityCaching bool + // DisableFetchProvidesData disables planning of meta information about which fields are provided by a fetch + DisableFetchProvidesData bool } type DebugConfiguration struct { diff --git a/v2/pkg/engine/plan/planner_test.go b/v2/pkg/engine/plan/planner_test.go index 270140381..be00907d8 100644 --- a/v2/pkg/engine/plan/planner_test.go +++ b/v2/pkg/engine/plan/planner_test.go @@ -172,6 +172,7 @@ func TestPlanner_Plan(t *testing.T) { }, Configuration{ DisableResolveFieldPositions: true, DisableIncludeInfo: true, + DisableEntityCaching: true, DataSources: []DataSource{testDefinitionDSConfiguration}, })) @@ -226,6 +227,7 @@ func TestPlanner_Plan(t *testing.T) { }, Configuration{ DisableResolveFieldPositions: true, DisableIncludeInfo: true, + DisableEntityCaching: true, DataSources: []DataSource{testDefinitionDSConfiguration}, })) @@ -292,6 +294,7 @@ func TestPlanner_Plan(t *testing.T) { }, Configuration{ DisableResolveFieldPositions: true, DisableIncludeInfo: true, + DisableEntityCaching: true, DataSources: []DataSource{testDefinitionDSConfiguration}, })) @@ -363,6 +366,7 @@ func TestPlanner_Plan(t *testing.T) { }, Configuration{ DisableResolveFieldPositions: true, DisableIncludeInfo: true, + DisableEntityCaching: true, DataSources: []DataSource{testDefinitionDSConfiguration}, })) @@ -425,14 +429,16 @@ func TestPlanner_Plan(t *testing.T) { }, Configuration{ DisableResolveFieldPositions: true, DisableIncludeInfo: true, + DisableEntityCaching: true, DataSources: []DataSource{testDefinitionDSConfiguration}, })) }) t.Run("operation selection", func(t *testing.T) { cfg := Configuration{ - DataSources: []DataSource{testDefinitionDSConfiguration}, - DisableIncludeInfo: true, + DataSources: []DataSource{testDefinitionDSConfiguration}, + DisableIncludeInfo: true, + DisableEntityCaching: true, } t.Run("should successfully plan a single named query by providing an operation name", test(testDefinition, ` @@ -585,6 +591,7 @@ func TestPlanner_Plan(t *testing.T) { Configuration{ DisableResolveFieldPositions: true, DisableIncludeInfo: true, + DisableEntityCaching: true, Fields: FieldConfigurations{ FieldConfiguration{ TypeName: "Character", @@ -644,6 +651,7 @@ func TestPlanner_Plan(t *testing.T) { Configuration{ DisableResolveFieldPositions: true, DisableIncludeInfo: true, + DisableEntityCaching: true, Fields: FieldConfigurations{ FieldConfiguration{ TypeName: "Character", @@ -703,6 +711,7 @@ func TestPlanner_Plan(t *testing.T) { Configuration{ DisableResolveFieldPositions: true, DisableIncludeInfo: true, + DisableEntityCaching: true, DataSources: []DataSource{dsConfig}, }, )) diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index a3dff3c3d..1b0646290 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -1642,21 +1642,27 @@ func (v *Visitor) configureFetch(internal *objectFetchConfiguration, external re dataSourceType := reflect.TypeOf(external.DataSource).String() dataSourceType = strings.TrimPrefix(dataSourceType, "*") - cacheKeyTemplate := &resolve.InputTemplate{ - SetTemplateOutputToNullOnVariableNull: false, - Segments: make([]resolve.TemplateSegment, len(external.Variables)), - } + if !v.Config.DisableEntityCaching { + cacheKeyTemplate := &resolve.InputTemplate{ + SetTemplateOutputToNullOnVariableNull: false, + 
Segments: make([]resolve.TemplateSegment, len(external.Variables)), + } - for i, variable := range external.Variables { - segment := variable.TemplateSegment() - cacheKeyTemplate.Segments[i] = segment - } + for i, variable := range external.Variables { + segment := variable.TemplateSegment() + cacheKeyTemplate.Segments[i] = segment + } - external.Caching = resolve.FetchCacheConfiguration{ - Enabled: true, - CacheName: "default", - TTL: time.Second * time.Duration(30), - CacheKeyTemplate: cacheKeyTemplate, + external.Caching = resolve.FetchCacheConfiguration{ + Enabled: true, + CacheName: "default", + TTL: time.Second * time.Duration(30), + CacheKeyTemplate: cacheKeyTemplate, + } + } else { + external.Caching = resolve.FetchCacheConfiguration{ + Enabled: false, + } } singleFetch := &resolve.SingleFetch{ @@ -1678,12 +1684,16 @@ func (v *Visitor) configureFetch(internal *objectFetchConfiguration, external re OperationType: internal.operationType, QueryPlan: external.QueryPlan, } - + if !v.Config.DisableFetchProvidesData { + // Set ProvidesData from the planner's object structure + if providesData, ok := v.plannerObjects[internal.fetchID]; ok { + singleFetch.Info.ProvidesData = providesData + } + } + singleFetch.Info.CoordinateDependencies = v.resolveFetchDependencies(internal.fetchID) if v.Config.DisableIncludeFieldDependencies { return singleFetch } - singleFetch.Info.CoordinateDependencies = v.resolveFetchDependencies(internal.fetchID) - if !v.Config.BuildFetchReasons { return singleFetch } @@ -1700,11 +1710,6 @@ func (v *Visitor) configureFetch(internal *objectFetchConfiguration, external re if _, ok := lookup[field]; ok { propagated = append(propagated, fr) } - - // Set ProvidesData from the planner's object structure - if providesData, ok := v.plannerObjects[internal.fetchID]; ok { - singleFetch.Info.ProvidesData = providesData - } } singleFetch.Info.PropagatedFetchReasons = propagated return singleFetch From f7d2b9499b7c97e367e0a317d5104ecec0982b4f Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Fri, 26 Sep 2025 15:26:32 +0200 Subject: [PATCH 10/61] chore: fix tests --- .../graphql_datasource_test.go | 25 ++- v2/pkg/engine/resolve/caching.go | 82 ++++++++++ v2/pkg/engine/resolve/caching_test.go | 153 ++++++++++++++++++ v2/pkg/engine/resolve/fetch.go | 4 +- v2/pkg/engine/resolve/loader.go | 2 +- v2/pkg/engine/resolve/variables.go | 1 - v2/pkg/engine/resolve/variables_renderer.go | 76 +++++++++ 7 files changed, 338 insertions(+), 5 deletions(-) create mode 100644 v2/pkg/engine/resolve/caching.go create mode 100644 v2/pkg/engine/resolve/caching_test.go diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go index 5ad4dbf87..e93a0cf42 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go @@ -398,6 +398,29 @@ func TestGraphQLDataSource(t *testing.T) { }, ), PostProcessing: DefaultPostProcessingConfiguration, + Caching: resolve.FetchCacheConfiguration{ + Enabled: true, + CacheName: "default", + TTL: 30 * time.Second, + CacheKeyTemplate: &resolve.RootQueryCacheKeyTemplate{ + Fields: []resolve.CacheKeyQueryRootField{ + { + Name: "droid", + Args: []resolve.CacheKeyQueryRootFieldArgument{ + { + Name: "id", + Variables: resolve.NewVariables( + &resolve.ContextVariable{ + Path: []string{"id"}, + Renderer: resolve.NewJSONVariableRenderer(), + }, + ), + }, + }, + }, + }, + }, + }, }, Info: 
&resolve.FetchInfo{ OperationType: ast.OperationTypeQuery, @@ -783,7 +806,7 @@ func TestGraphQLDataSource(t *testing.T) { }, }, DisableResolveFieldPositions: true, - }, WithFieldInfo(), WithDefaultPostProcessor(), WithFetchProvidesData())) + }, WithFieldInfo(), WithDefaultPostProcessor(), WithFetchProvidesData(), WithEntityCaching())) t.Run("selections on interface type", RunTest(interfaceSelectionSchema, ` query MyQuery { diff --git a/v2/pkg/engine/resolve/caching.go b/v2/pkg/engine/resolve/caching.go new file mode 100644 index 000000000..10593982a --- /dev/null +++ b/v2/pkg/engine/resolve/caching.go @@ -0,0 +1,82 @@ +package resolve + +import ( + "bytes" + + "github.com/wundergraph/astjson" +) + +type CacheKeyTemplate interface { + RenderCacheKey(ctx *Context, data *astjson.Value, out *bytes.Buffer) error +} + +type RootQueryCacheKeyTemplate struct { + Fields []CacheKeyQueryRootField +} + +type CacheKeyQueryRootField struct { + Name string + Args []CacheKeyQueryRootFieldArgument +} + +type CacheKeyQueryRootFieldArgument struct { + Name string + Variables InputTemplate +} + +func (r *RootQueryCacheKeyTemplate) RenderCacheKey(ctx *Context, data *astjson.Value, out *bytes.Buffer) error { + _, err := out.WriteString("Query") + if err != nil { + return err + } + + // Process each field + for _, field := range r.Fields { + _, err = out.WriteString("::") + if err != nil { + return err + } + + // Add field name + _, err = out.WriteString(field.Name) + if err != nil { + return err + } + + // Process each argument + for _, arg := range field.Args { + // Add argument separator ":" + _, err = out.WriteString(":") + if err != nil { + return err + } + + // Add argument name + _, err = out.WriteString(arg.Name) + if err != nil { + return err + } + + // Add argument separator ":" + _, err = out.WriteString(":") + if err != nil { + return err + } + + err = arg.Variables.Render(ctx, data, out) + if err != nil { + return err + } + } + } + + return nil +} + +type EntityQueryCacheKeyTemplate struct { + Keys *ResolvableObjectVariable +} + +func (e *EntityQueryCacheKeyTemplate) RenderCacheKey(ctx *Context, data *astjson.Value, out *bytes.Buffer) error { + return e.Keys.Renderer.RenderVariable(ctx.ctx, data, out) +} diff --git a/v2/pkg/engine/resolve/caching_test.go b/v2/pkg/engine/resolve/caching_test.go new file mode 100644 index 000000000..a515e598e --- /dev/null +++ b/v2/pkg/engine/resolve/caching_test.go @@ -0,0 +1,153 @@ +package resolve + +import ( + "bytes" + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/wundergraph/astjson" +) + +func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { + t.Run("single field single argument", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + Fields: []CacheKeyQueryRootField{ + { + Name: "droid", + Args: []CacheKeyQueryRootFieldArgument{ + { + Name: "id", + Variables: InputTemplate{ + SetTemplateOutputToNullOnVariableNull: true, + Segments: []TemplateSegment{ + { + SegmentType: VariableSegmentType, + VariableKind: ContextVariableKind, + VariableSourcePath: []string{"id"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + }, + }, + } + + ctx := &Context{ + Variables: astjson.MustParse(`{"id":1}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{}`) + out := &bytes.Buffer{} + err := tmpl.RenderCacheKey(ctx, data, out) + assert.NoError(t, err) + assert.Equal(t, `Query::droid:id:1`, out.String()) + }) + + t.Run("single field multiple arguments", func(t *testing.T) { + tmpl := 
&RootQueryCacheKeyTemplate{ + Fields: []CacheKeyQueryRootField{ + { + Name: "search", + Args: []CacheKeyQueryRootFieldArgument{ + { + Name: "term", + Variables: InputTemplate{ + SetTemplateOutputToNullOnVariableNull: true, + Segments: []TemplateSegment{ + { + SegmentType: VariableSegmentType, + VariableKind: ContextVariableKind, + VariableSourcePath: []string{"term"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + { + Name: "max", + Variables: InputTemplate{ + SetTemplateOutputToNullOnVariableNull: true, + Segments: []TemplateSegment{ + { + SegmentType: VariableSegmentType, + VariableKind: ContextVariableKind, + VariableSourcePath: []string{"max"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + }, + }, + } + + ctx := &Context{ + Variables: astjson.MustParse(`{"term":"C3PO","max":10}`), + ctx: context.Background(), + } + out := &bytes.Buffer{} + data := astjson.MustParse(`{}`) + err := tmpl.RenderCacheKey(ctx, data, out) + assert.NoError(t, err) + assert.Equal(t, `Query::search:term:C3PO:max:10`, out.String()) + }) + + t.Run("multiple fields single argument each", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + Fields: []CacheKeyQueryRootField{ + { + Name: "droid", + Args: []CacheKeyQueryRootFieldArgument{ + { + Name: "id", + Variables: InputTemplate{ + SetTemplateOutputToNullOnVariableNull: true, + Segments: []TemplateSegment{ + { + SegmentType: VariableSegmentType, + VariableKind: ContextVariableKind, + VariableSourcePath: []string{"id"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + }, + { + Name: "user", + Args: []CacheKeyQueryRootFieldArgument{ + { + Name: "name", + Variables: InputTemplate{ + SetTemplateOutputToNullOnVariableNull: true, + Segments: []TemplateSegment{ + { + SegmentType: VariableSegmentType, + VariableKind: ContextVariableKind, + VariableSourcePath: []string{"name"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + }, + }, + } + + ctx := &Context{ + Variables: astjson.MustParse(`{"id":1,"name":"john"}`), + ctx: context.Background(), + } + out := &bytes.Buffer{} + data := astjson.MustParse(`{}`) + err := tmpl.RenderCacheKey(ctx, data, out) + assert.NoError(t, err) + assert.Equal(t, `Query::droid:id:1::user:name:john`, out.String()) + }) +} diff --git a/v2/pkg/engine/resolve/fetch.go b/v2/pkg/engine/resolve/fetch.go index 1eab8961a..1c714d238 100644 --- a/v2/pkg/engine/resolve/fetch.go +++ b/v2/pkg/engine/resolve/fetch.go @@ -354,7 +354,7 @@ type FetchCacheConfiguration struct { // CacheKeyTemplate can be used to render a cache key for the fetch. // In case of a root fetch, the variables will be one or more field arguments // For entity fetches, the variables will be a single Object Variable with @key and @requires fields - CacheKeyTemplate *InputTemplate + CacheKeyTemplate CacheKeyTemplate } // FetchDependency explains how a GraphCoordinate depends on other GraphCoordinates from other fetches @@ -418,7 +418,7 @@ type FetchInfo struct { // with the request to the subgraph as part of the "fetch_reason" extension. // Specifically, it is created only for fields stored in the DataSource.RequireFetchReasons(). 
PropagatedFetchReasons []FetchReason - ProvidesData *Object + ProvidesData *Object } type GraphCoordinate struct { diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 030b58e4f..8972224ed 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -492,7 +492,7 @@ func (l *Loader) tryCacheLoadFetch(ctx context.Context, info *FetchInfo, cfg Fet res.cacheKeys = make([]string, 0, len(inputItems)) buf := &bytes.Buffer{} for _, item := range inputItems { - err = cfg.CacheKeyTemplate.Render(l.ctx, item, buf) + err = cfg.CacheKeyTemplate.RenderCacheKey(l.ctx, item, buf) if err != nil { return false, err } diff --git a/v2/pkg/engine/resolve/variables.go b/v2/pkg/engine/resolve/variables.go index afc00459a..3f54993d9 100644 --- a/v2/pkg/engine/resolve/variables.go +++ b/v2/pkg/engine/resolve/variables.go @@ -11,7 +11,6 @@ const ( ObjectVariableKind HeaderVariableKind ResolvableObjectVariableKind - ListVariableKind ) const ( diff --git a/v2/pkg/engine/resolve/variables_renderer.go b/v2/pkg/engine/resolve/variables_renderer.go index 4cbb471f8..8ae58c655 100644 --- a/v2/pkg/engine/resolve/variables_renderer.go +++ b/v2/pkg/engine/resolve/variables_renderer.go @@ -277,6 +277,82 @@ func (g *GraphQLVariableRenderer) renderGraphQLValue(data *astjson.Value, out io return } +func NewCacheKeyVariableRenderer() *CacheKeyVariableRenderer { + return &CacheKeyVariableRenderer{} +} + +type CacheKeyVariableRenderer struct { +} + +func (g *CacheKeyVariableRenderer) GetKind() string { + return "cacheKey" +} + +// add renderer that renders both variable name and variable value +// before rendering, evaluate if the value contains null values +// if an object contains only null values, set the object to null +// do this recursively until reaching the root of the object + +func (g *CacheKeyVariableRenderer) RenderVariable(ctx context.Context, data *astjson.Value, out io.Writer) error { + return g.renderGraphQLValue(data, out) +} + +func (g *CacheKeyVariableRenderer) renderGraphQLValue(data *astjson.Value, out io.Writer) (err error) { + if data == nil { + _, _ = out.Write(literal.NULL) + return + } + switch data.Type() { + case astjson.TypeString: + b := data.GetStringBytes() + _, _ = out.Write(b) + case astjson.TypeObject: + _, _ = out.Write(literal.LBRACE) + o := data.GetObject() + first := true + o.Visit(func(k []byte, v *astjson.Value) { + if err != nil { + return + } + if !first { + _, _ = out.Write(literal.COMMA) + } else { + first = false + } + _, _ = out.Write(k) + _, _ = out.Write(literal.COLON) + err = g.renderGraphQLValue(v, out) + }) + if err != nil { + return err + } + _, _ = out.Write(literal.RBRACE) + case astjson.TypeNull: + _, _ = out.Write(literal.NULL) + case astjson.TypeTrue: + _, _ = out.Write(literal.TRUE) + case astjson.TypeFalse: + _, _ = out.Write(literal.FALSE) + case astjson.TypeArray: + _, _ = out.Write(literal.LBRACK) + arr := data.GetArray() + for i, value := range arr { + if i > 0 { + _, _ = out.Write(literal.COMMA) + } + err = g.renderGraphQLValue(value, out) + if err != nil { + return err + } + } + _, _ = out.Write(literal.RBRACK) + case astjson.TypeNumber: + b := data.MarshalTo(nil) + _, _ = out.Write(b) + } + return +} + func NewCSVVariableRenderer(arrayValueType JsonRootType) *CSVVariableRenderer { return &CSVVariableRenderer{ Kind: VariableRendererKindCsv, From 72ca42a7e6ba7b3f281575d441a71c229a28ab8e Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Wed, 15 Oct 2025 07:02:39 +0200 Subject: [PATCH 11/61] feat: add astjson & 
ArenaResolveGraphQLResponse --- go.work.sum | 5 +- v2/go.mod | 8 +- v2/go.sum | 6 +- .../astnormalization/uploads/upload_finder.go | 2 +- .../grpc_datasource/grpc_datasource.go | 6 +- .../grpc_datasource/grpc_datasource_test.go | 5 +- .../grpc_datasource/json_builder.go | 174 +++++++++--------- v2/pkg/engine/resolve/context.go | 2 +- v2/pkg/engine/resolve/loader.go | 97 +++++----- v2/pkg/engine/resolve/loader_test.go | 18 +- v2/pkg/engine/resolve/resolvable.go | 37 ++-- .../resolvable_custom_field_renderer_test.go | 4 +- v2/pkg/engine/resolve/resolvable_test.go | 52 +++--- v2/pkg/engine/resolve/resolve.go | 34 +++- v2/pkg/engine/resolve/tainted_objects_test.go | 8 +- v2/pkg/engine/resolve/variables_renderer.go | 2 +- v2/pkg/fastjsonext/fastjsonext.go | 37 ++-- v2/pkg/fastjsonext/fastjsonext_test.go | 10 +- .../variablesvalidation.go | 2 +- 19 files changed, 273 insertions(+), 236 deletions(-) diff --git a/go.work.sum b/go.work.sum index 5f48a89a0..9e675e2c3 100644 --- a/go.work.sum +++ b/go.work.sum @@ -247,6 +247,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tidwall/sjson v1.0.4 h1:UcdIRXff12Lpnu3OLtZvnc03g4vH2suXDXhBwBqmzYg= github.com/tidwall/sjson v1.0.4/go.mod h1:bURseu1nuBkFpIES5cz6zBtjmYeOQmEESshn7VpF15Y= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= @@ -268,6 +270,8 @@ github.com/wundergraph/astjson v0.0.0-20241210135722-15ca0ac078f8/go.mod h1:eOTL github.com/wundergraph/cosmo/composition-go v0.0.0-20240404083832-79d2290084c6/go.mod h1:Ib+rknmwn4oZFN9SQ4VMP3uF/C/tEINEug5iPQxfrPc= github.com/wundergraph/cosmo/composition-go v0.0.0-20240729154441-b20b00e892c6/go.mod h1:WbKC2jd0g6BFsMpNDRVSoQyZ0QB6sWqpRfe0/1pTah4= github.com/wundergraph/cosmo/router v0.0.0-20240404083832-79d2290084c6/go.mod h1:LS+5qlr4fQVEW7JMXXI1sz7CH5cdnqx3BNc10p+UbW4= +github.com/wundergraph/go-arena v0.0.0-20251008210416-55cb97e6f68f h1:5snewyMaIpajTu4wj22L/DgrGimICqXtUVjkZInBH3Y= +github.com/wundergraph/go-arena v0.0.0-20251008210416-55cb97e6f68f/go.mod h1:ROOysEHWJjLQ8FSfNxZCziagb7Qw2nXY3/vgKRh7eWw= github.com/xdg/scram v1.0.3 h1:nTadYh2Fs4BK2xdldEa2g5bbaZp0/+1nJMMPtPxS/to= github.com/xdg/scram v1.0.3/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I= github.com/xdg/stringprep v1.0.3 h1:cmL5Enob4W83ti/ZHuZLuKD/xqJfus4fVPwE+/BDm+4= @@ -438,7 +442,6 @@ google.golang.org/genproto/googleapis/api v0.0.0-20240102182953-50ed04b92917 h1: google.golang.org/genproto/googleapis/api v0.0.0-20240102182953-50ed04b92917/go.mod h1:CmlNWB9lSezaYELKS5Ym1r44VrrbPUa7JTvw+6MbpJ0= google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 h1:GVIKPyP/kLIyVOgOnTwFOrvQaQUzOzGMCxgFUOEmm24= google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422/go.mod h1:b6h1vNKhxaSoEI+5jc3PJUCustfli/mRab7295pY7rw= -google.golang.org/grpc v1.68.1/go.mod h1:+q1XYFJjShcqn0QZHvCyeR4CXPA+llXIeUIfIe00waw= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= 
google.golang.org/protobuf v1.36.3/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= diff --git a/v2/go.mod b/v2/go.mod index 8ff4759fb..50365c0a9 100644 --- a/v2/go.mod +++ b/v2/go.mod @@ -24,11 +24,12 @@ require ( github.com/r3labs/sse/v2 v2.8.1 github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 github.com/sebdah/goldie/v2 v2.7.1 - github.com/stretchr/testify v1.10.0 + github.com/stretchr/testify v1.11.1 github.com/tidwall/gjson v1.17.0 github.com/tidwall/sjson v1.2.5 github.com/vektah/gqlparser/v2 v2.5.14 github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 + github.com/wundergraph/go-arena v0.0.1 go.uber.org/atomic v1.11.0 go.uber.org/goleak v1.3.0 go.uber.org/zap v1.26.0 @@ -70,3 +71,8 @@ require ( gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace ( + github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 => ../../wundergraph-projects/astjson + github.com/wundergraph/go-arena v0.0.1 => ../../wundergraph-projects/go-arena +) diff --git a/v2/go.sum b/v2/go.sum index a98384ae8..f2c6a7e00 100644 --- a/v2/go.sum +++ b/v2/go.sum @@ -115,8 +115,8 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM= github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= @@ -129,8 +129,6 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/vektah/gqlparser/v2 v2.5.14 h1:dzLq75BJe03jjQm6n56PdH1oweB8ana42wj7E4jRy70= github.com/vektah/gqlparser/v2 v2.5.14/go.mod h1:WQQjFc+I1YIzoPvZBhUQX7waZgg3pMLi0r8KymvAE2w= -github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTBjW+SZK4mhxTTBVpxcqeBgWF1Rfmltbfk= -github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= diff --git a/v2/pkg/astnormalization/uploads/upload_finder.go b/v2/pkg/astnormalization/uploads/upload_finder.go index b69a8bef2..0fd2d44c1 100644 --- a/v2/pkg/astnormalization/uploads/upload_finder.go +++ b/v2/pkg/astnormalization/uploads/upload_finder.go @@ -74,7 +74,7 @@ func (v *UploadFinder) FindUploads(operation, definition *ast.Document, variable variables = []byte("{}") } - v.variables, err = astjson.ParseBytesWithoutCache(variables) + v.variables, err = astjson.ParseBytes(variables) if err != nil { return nil, err } diff --git 
a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index 4d9babc60..78cdce9f7 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -101,7 +101,6 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) // make gRPC calls for index, invocation := range invocations { errGrp.Go(func() error { - a := astjson.Arena{} // Invoke the gRPC method - this will populate invocation.Output methodName := fmt.Sprintf("/%s/%s", invocation.ServiceName, invocation.MethodName) @@ -113,7 +112,7 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) mu.Lock() defer mu.Unlock() - response, err := builder.marshalResponseJSON(&a, &invocation.Call.Response, invocation.Output) + response, err := builder.marshalResponseJSON(&invocation.Call.Response, invocation.Output) if err != nil { return err } @@ -135,8 +134,7 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) return nil } - a := astjson.Arena{} - root := a.NewObject() + root := astjson.ObjectValue(builder.jsonArena) for _, response := range responses { root, err = builder.mergeValues(root, response) if err != nil { diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index 3ae711d51..f7340cec8 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -19,8 +19,6 @@ import ( protoref "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/dynamicpb" - "github.com/wundergraph/astjson" - "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" "github.com/wundergraph/graphql-go-tools/v2/pkg/grpctest" @@ -499,9 +497,8 @@ func TestMarshalResponseJSON(t *testing.T) { responseMessage := dynamicpb.NewMessage(responseMessageDesc) responseMessage.Mutable(responseMessageDesc.Fields().ByName("result")).List().Append(protoref.ValueOfMessage(productMessage)) - arena := astjson.Arena{} jsonBuilder := newJSONBuilder(nil, gjson.Result{}) - responseJSON, err := jsonBuilder.marshalResponseJSON(&arena, &response, responseMessage) + responseJSON, err := jsonBuilder.marshalResponseJSON(&response, responseMessage) require.NoError(t, err) require.Equal(t, `{"_entities":[{"__typename":"Product","id":"123","name_different":"test","price_different":123.45}]}`, responseJSON.String()) } diff --git a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go index 7c1fc81d7..8fe71a321 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go +++ b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go @@ -11,6 +11,7 @@ import ( protoref "google.golang.org/protobuf/reflect/protoreflect" "github.com/wundergraph/astjson" + "github.com/wundergraph/go-arena" ) // Standard GraphQL response paths @@ -104,6 +105,7 @@ type jsonBuilder struct { mapping *GRPCMapping // Mapping configuration for GraphQL to gRPC translation variables gjson.Result // GraphQL variables containing entity representations indexMap indexMap // Entity index mapping for federation ordering + jsonArena arena.Arena } // newJSONBuilder creates a new JSON builder instance with the provided mapping @@ -114,6 +116,7 @@ func newJSONBuilder(mapping 
*GRPCMapping, variables gjson.Result) *jsonBuilder { mapping: mapping, variables: variables, indexMap: createRepresentationIndexMap(variables), + jsonArena: arena.NewMonotonicArena(), } } @@ -160,7 +163,7 @@ func (j *jsonBuilder) mergeValues(left *astjson.Value, right *astjson.Value) (*a if len(j.indexMap) == 0 { // No federation index map available - use simple merge // This path is taken for non-federated queries - root, _, err := astjson.MergeValues(left, right) + root, _, err := astjson.MergeValues(j.jsonArena, left, right) if err != nil { return nil, err } @@ -186,11 +189,10 @@ func (j *jsonBuilder) mergeValues(left *astjson.Value, right *astjson.Value) (*a // This function ensures that entities are placed in the correct positions in the final response // array based on their original representation order, which is critical for GraphQL federation. func (j *jsonBuilder) mergeEntities(left *astjson.Value, right *astjson.Value) (*astjson.Value, error) { - root := astjson.Arena{} // Create the response structure with _entities array - entities := root.NewObject() - entities.Set(entityPath, root.NewArray()) + entities := astjson.ObjectValue(j.jsonArena) + entities.Set(j.jsonArena, entityPath, astjson.ArrayValue(j.jsonArena)) arr := entities.Get(entityPath) // Extract entity arrays from both responses @@ -206,12 +208,12 @@ func (j *jsonBuilder) mergeEntities(left *astjson.Value, right *astjson.Value) ( // Merge left entities using index mapping to preserve order for index, lr := range leftRepresentations { - arr.SetArrayItem(j.indexMap.getResultIndex(lr, index), lr) + arr.SetArrayItem(j.jsonArena, j.indexMap.getResultIndex(lr, index), lr) } // Merge right entities using index mapping to preserve order for index, rr := range rightRepresentations { - arr.SetArrayItem(j.indexMap.getResultIndex(rr, index), rr) + arr.SetArrayItem(j.jsonArena, j.indexMap.getResultIndex(rr, index), rr) } return entities, nil @@ -220,12 +222,12 @@ func (j *jsonBuilder) mergeEntities(left *astjson.Value, right *astjson.Value) ( // marshalResponseJSON converts a protobuf message into a GraphQL-compatible JSON response. // This is the core marshaling function that handles all the complex type conversions, // including oneOf types, nested messages, lists, and scalar values. 
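Note on the construction API used throughout this hunk: values are no longer allocated from a per-call astjson.Arena; every constructor and mutator now takes the builder's shared arena.Arena explicitly. A minimal, self-contained sketch of the pattern, assuming the arena-first signatures implied by this patch's call sites (ObjectValue, ArrayValue, StringValue, IntValue, Set, SetArrayItem; the exact signatures live in the local astjson/go-arena forks referenced by the go.mod replace directives above):

package main

import (
	"fmt"

	"github.com/wundergraph/astjson"
	"github.com/wundergraph/go-arena"
)

func main() {
	a := arena.NewMonotonicArena() // one arena shared across all values of a response

	root := astjson.ObjectValue(a)
	root.Set(a, "message", astjson.StringValue(a, "hello"))

	items := astjson.ArrayValue(a)
	items.SetArrayItem(a, 0, astjson.IntValue(a, 42))
	items.SetArrayItem(a, 1, astjson.NullValue) // NullValue is a shared sentinel, not arena-allocated
	root.Set(a, "items", items)

	fmt.Println(string(root.MarshalTo(nil))) // {"message":"hello","items":[42,null]}
}

Because the arena is monotonic, nothing is freed per value; presumably the whole region is reclaimed at once when the builder is done, which is why the per-call astjson.Arena{} allocations could be deleted here.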
-func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMessage, data protoref.Message) (*astjson.Value, error) { +func (j *jsonBuilder) marshalResponseJSON(message *RPCMessage, data protoref.Message) (*astjson.Value, error) { if message == nil { - return arena.NewNull(), nil + return astjson.NullValue, nil } - root := arena.NewObject() + root := astjson.ObjectValue(j.jsonArena) // Handle protobuf oneOf types - these represent GraphQL union/interface types if message.IsOneOf() { @@ -259,14 +261,14 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess if field.StaticValue != "" { if len(message.MemberTypes) == 0 { // Simple static value - use as-is - root.Set(field.AliasOrPath(), arena.NewString(field.StaticValue)) + root.Set(j.jsonArena, field.AliasOrPath(), astjson.StringValue(j.jsonArena, field.StaticValue)) continue } // Type-specific static value - match against member types for _, memberTypes := range message.MemberTypes { if memberTypes == string(data.Type().Descriptor().Name()) { - root.Set(field.AliasOrPath(), arena.NewString(memberTypes)) + root.Set(j.jsonArena, field.AliasOrPath(), astjson.StringValue(j.jsonArena, memberTypes)) break } } @@ -284,8 +286,8 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess // Handle list fields (repeated in protobuf) if fd.IsList() { list := data.Get(fd).List() - arr := arena.NewArray() - root.Set(field.AliasOrPath(), arr) + arr := astjson.ArrayValue(j.jsonArena) + root.Set(j.jsonArena, field.AliasOrPath(), arr) if !list.IsValid() { // Invalid list - leave as empty array @@ -298,15 +300,15 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess case protoref.MessageKind: // List of messages - recursively marshal each message message := list.Get(i).Message() - value, err := j.marshalResponseJSON(arena, field.Message, message) + value, err := j.marshalResponseJSON(field.Message, message) if err != nil { return nil, err } - arr.SetArrayItem(i, value) + arr.SetArrayItem(j.jsonArena, i, value) default: // List of scalar values - convert directly - j.setArrayItem(i, arena, arr, list.Get(i), fd) + j.setArrayItem(i, arr, list.Get(i), fd) } } @@ -318,24 +320,24 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess msg := data.Get(fd).Message() if !msg.IsValid() { // Invalid message - set to null - root.Set(field.AliasOrPath(), arena.NewNull()) + root.Set(j.jsonArena, field.AliasOrPath(), astjson.NullValue) continue } // Handle special list wrapper types for complex nested lists if field.IsListType { - arr, err := j.flattenListStructure(arena, field.ListMetadata, msg, field.Message) + arr, err := j.flattenListStructure(field.ListMetadata, msg, field.Message) if err != nil { return nil, fmt.Errorf("unable to flatten list structure for field %q: %w", field.AliasOrPath(), err) } - root.Set(field.AliasOrPath(), arr) + root.Set(j.jsonArena, field.AliasOrPath(), arr) continue } // Handle optional scalar wrapper types (e.g., google.protobuf.StringValue) if field.IsOptionalScalar() { - err := j.resolveOptionalField(arena, root, field.AliasOrPath(), msg) + err := j.resolveOptionalField(root, field.AliasOrPath(), msg) if err != nil { return nil, err } @@ -344,27 +346,27 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess } // Regular nested message - recursively marshal - value, err := j.marshalResponseJSON(arena, field.Message, msg) + value, err := j.marshalResponseJSON(field.Message, msg) if err 
!= nil { return nil, err } if field.JSONPath == "" { // Field should be merged into parent object (flattened) - root, _, err = astjson.MergeValues(root, value) + root, _, err = astjson.MergeValues(j.jsonArena, root, value) if err != nil { return nil, err } } else { // Field should be nested under its own key - root.Set(field.AliasOrPath(), value) + root.Set(j.jsonArena, field.AliasOrPath(), value) } continue } // Handle scalar fields (string, int, bool, etc.) - j.setJSONValue(arena, root, field.AliasOrPath(), data, fd) + j.setJSONValue(root, field.AliasOrPath(), data, fd) } return root, nil @@ -374,34 +376,34 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess // messages to support nullable and multi-dimensional lists. This is necessary because // protobuf doesn't directly support nullable list items or complex nesting scenarios // that GraphQL allows. -func (j *jsonBuilder) flattenListStructure(arena *astjson.Arena, md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { +func (j *jsonBuilder) flattenListStructure(md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { if md == nil { - return arena.NewNull(), errors.New("list metadata not found") + return astjson.NullValue, errors.New("list metadata not found") } // Validate metadata consistency if len(md.LevelInfo) < md.NestingLevel { - return arena.NewNull(), errors.New("nesting level data does not match the number of levels in the list metadata") + return astjson.NullValue, errors.New("nesting level data does not match the number of levels in the list metadata") } // Handle null data with proper nullability checking if !data.IsValid() { if md.LevelInfo[0].Optional { - return arena.NewNull(), nil + return astjson.NullValue, nil } - return arena.NewNull(), errors.New("cannot add null item to response for non nullable list") + return astjson.NullValue, errors.New("cannot add null item to response for non nullable list") } // Start recursive traversal of the nested list structure - root := arena.NewArray() - return j.traverseList(0, arena, root, md, data, message) + root := astjson.ArrayValue(j.jsonArena) + return j.traverseList(0, root, md, data, message) } // traverseList recursively traverses nested list wrapper structures to extract the actual // list data. This handles multi-dimensional lists like [[String]] or [[[User]]] by // unwrapping the protobuf message wrappers at each level. 
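To make the wrapper convention concrete before the signature change below: a field typed [[String]] arrives as two nested single-field messages, each holding its payload at field number 1, and traverseList unwraps one wrapper per level, applying that level's Optional flag when a wrapper is nil. A dependency-free sketch of the unwrapping; the wrapper structs are hypothetical stand-ins for the generated protobuf messages, not types from this repo:

package main

import (
	"encoding/json"
	"fmt"
)

// Hypothetical stand-ins for the generated wrapper messages; per the
// convention above, each level keeps its payload in "field number 1".
type listOfListOfString struct{ list []*listOfString } // level 0
type listOfString struct{ list []string }              // level 1 (final)

func main() {
	data := &listOfListOfString{list: []*listOfString{
		{list: []string{"a", "b"}},
		nil, // null wrapper: legal only if LevelInfo[1].Optional is true
	}}
	out := make([]any, len(data.list))
	for i, inner := range data.list {
		if inner == nil {
			out[i] = nil // an Optional level renders as JSON null
			continue
		}
		out[i] = inner.list
	}
	b, _ := json.Marshal(out)
	fmt.Println(string(b)) // [["a","b"],null]
}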
-func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *astjson.Value, md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { +func (j *jsonBuilder) traverseList(level int, current *astjson.Value, md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { if level > md.NestingLevel { return current, nil } @@ -409,11 +411,11 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast // List wrappers always use field number 1 in the generated protobuf fd := data.Descriptor().Fields().ByNumber(1) if fd == nil { - return arena.NewNull(), fmt.Errorf("field with number %d not found in message %q", 1, data.Descriptor().Name()) + return astjson.NullValue, fmt.Errorf("field with number %d not found in message %q", 1, data.Descriptor().Name()) } if fd.Kind() != protoref.MessageKind { - return arena.NewNull(), fmt.Errorf("field %q is not a message", fd.Name()) + return astjson.NullValue, fmt.Errorf("field %q is not a message", fd.Name()) } // Get the wrapper message containing the list @@ -421,16 +423,16 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast if !msg.IsValid() { // Handle null wrapper based on nullability rules if md.LevelInfo[level].Optional { - return arena.NewNull(), nil + return astjson.NullValue, nil } - return arena.NewArray(), fmt.Errorf("cannot add null item to response for non nullable list") + return astjson.ArrayValue(j.jsonArena), fmt.Errorf("cannot add null item to response for non nullable list") } // The actual list is always at field number 1 in the wrapper fd = msg.Descriptor().Fields().ByNumber(1) if !fd.IsList() { - return arena.NewNull(), fmt.Errorf("field %q is not a list", fd.Name()) + return astjson.NullValue, fmt.Errorf("field %q is not a list", fd.Name()) } // Handle intermediate nesting levels (not the final level) @@ -438,13 +440,13 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast list := msg.Get(fd).List() for i := 0; i < list.Len(); i++ { // Create nested array for next level - next := arena.NewArray() - val, err := j.traverseList(level+1, arena, next, md, list.Get(i).Message(), message) + next := astjson.ArrayValue(j.jsonArena) + val, err := j.traverseList(level+1, next, md, list.Get(i).Message(), message) if err != nil { return nil, err } - current.SetArrayItem(i, val) + current.SetArrayItem(j.jsonArena, i, val) } return current, nil @@ -455,22 +457,22 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast if !list.IsValid() { // Invalid list at final level - return empty array // Nullability is checked at the wrapper level, not the list level - return arena.NewArray(), nil + return astjson.ArrayValue(j.jsonArena), nil } // Process each item in the final list for i := 0; i < list.Len(); i++ { if message != nil { // List of complex objects - recursively marshal each item - val, err := j.marshalResponseJSON(arena, message, list.Get(i).Message()) + val, err := j.marshalResponseJSON(message, list.Get(i).Message()) if err != nil { return nil, err } - current.SetArrayItem(i, val) + current.SetArrayItem(j.jsonArena, i, val) } else { // List of scalar values - convert directly - j.setArrayItem(i, arena, current, list.Get(i), fd) + j.setArrayItem(i, current, list.Get(i), fd) } } @@ -480,7 +482,7 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast // resolveOptionalField extracts the value from optional scalar wrapper types like // 
google.protobuf.StringValue, google.protobuf.Int32Value, etc. These wrappers // are used to represent nullable scalar values in protobuf. -func (j *jsonBuilder) resolveOptionalField(arena *astjson.Arena, root *astjson.Value, name string, data protoref.Message) error { +func (j *jsonBuilder) resolveOptionalField(root *astjson.Value, name string, data protoref.Message) error { // Optional scalar wrappers always have a "value" field fd := data.Descriptor().Fields().ByName(protoref.Name("value")) if fd == nil { @@ -488,16 +490,16 @@ func (j *jsonBuilder) resolveOptionalField(arena *astjson.Arena, root *astjson.V } // Extract and set the wrapped value - j.setJSONValue(arena, root, name, data, fd) + j.setJSONValue(root, name, data, fd) return nil } // setJSONValue converts a protobuf field value to the appropriate JSON representation // and sets it on the provided JSON object. This handles all protobuf scalar types // and enum values with proper GraphQL mapping. -func (j *jsonBuilder) setJSONValue(arena *astjson.Arena, root *astjson.Value, name string, data protoref.Message, fd protoref.FieldDescriptor) { +func (j *jsonBuilder) setJSONValue(root *astjson.Value, name string, data protoref.Message, fd protoref.FieldDescriptor) { if !data.IsValid() { - root.Set(name, arena.NewNull()) + root.Set(j.jsonArena, name, astjson.NullValue) return } @@ -505,27 +507,27 @@ func (j *jsonBuilder) setJSONValue(arena *astjson.Arena, root *astjson.Value, na case protoref.BoolKind: boolValue := data.Get(fd).Bool() if boolValue { - root.Set(name, arena.NewTrue()) + root.Set(j.jsonArena, name, astjson.TrueValue(j.jsonArena)) } else { - root.Set(name, arena.NewFalse()) + root.Set(j.jsonArena, name, astjson.FalseValue(j.jsonArena)) } case protoref.StringKind: - root.Set(name, arena.NewString(data.Get(fd).String())) + root.Set(j.jsonArena, name, astjson.StringValue(j.jsonArena, data.Get(fd).String())) case protoref.Int32Kind: - root.Set(name, arena.NewNumberInt(int(data.Get(fd).Int()))) + root.Set(j.jsonArena, name, astjson.IntValue(j.jsonArena, int(data.Get(fd).Int()))) case protoref.Int64Kind: - root.Set(name, arena.NewNumberString(strconv.FormatInt(data.Get(fd).Int(), 10))) + root.Set(j.jsonArena, name, astjson.NumberValue(j.jsonArena, strconv.FormatInt(data.Get(fd).Int(), 10))) case protoref.Uint32Kind, protoref.Uint64Kind: - root.Set(name, arena.NewNumberString(strconv.FormatUint(data.Get(fd).Uint(), 10))) + root.Set(j.jsonArena, name, astjson.NumberValue(j.jsonArena, strconv.FormatUint(data.Get(fd).Uint(), 10))) case protoref.FloatKind, protoref.DoubleKind: - root.Set(name, arena.NewNumberFloat64(data.Get(fd).Float())) + root.Set(j.jsonArena, name, astjson.FloatValue(j.jsonArena, data.Get(fd).Float())) case protoref.BytesKind: - root.Set(name, arena.NewStringBytes(data.Get(fd).Bytes())) + root.Set(j.jsonArena, name, astjson.StringValueBytes(j.jsonArena, data.Get(fd).Bytes())) case protoref.EnumKind: enumDesc := fd.Enum() enumValueDesc := enumDesc.Values().ByNumber(data.Get(fd).Enum()) if enumValueDesc == nil { - root.Set(name, arena.NewNull()) + root.Set(j.jsonArena, name, astjson.NullValue) return } @@ -533,20 +535,20 @@ func (j *jsonBuilder) setJSONValue(arena *astjson.Arena, root *astjson.Value, na graphqlValue, ok := j.mapping.ResolveEnumValue(string(enumDesc.Name()), string(enumValueDesc.Name())) if !ok { // No mapping found - set to null - root.Set(name, arena.NewNull()) + root.Set(j.jsonArena, name, astjson.NullValue) return } - root.Set(name, arena.NewString(graphqlValue)) + root.Set(j.jsonArena, name, 
astjson.StringValue(j.jsonArena, graphqlValue)) } } // setArrayItem converts a protobuf list item value to JSON and sets it at the specified // array index. This is similar to setJSONValue but operates on array elements rather // than object properties, and works with protobuf Value types rather than Message types. -func (j *jsonBuilder) setArrayItem(index int, arena *astjson.Arena, array *astjson.Value, data protoref.Value, fd protoref.FieldDescriptor) { +func (j *jsonBuilder) setArrayItem(index int, array *astjson.Value, data protoref.Value, fd protoref.FieldDescriptor) { if !data.IsValid() { - array.SetArrayItem(index, arena.NewNull()) + array.SetArrayItem(j.jsonArena, index, astjson.NullValue) return } @@ -554,27 +556,27 @@ func (j *jsonBuilder) setArrayItem(index int, arena *astjson.Arena, array *astjs case protoref.BoolKind: boolValue := data.Bool() if boolValue { - array.SetArrayItem(index, arena.NewTrue()) + array.SetArrayItem(j.jsonArena, index, astjson.TrueValue(j.jsonArena)) } else { - array.SetArrayItem(index, arena.NewFalse()) + array.SetArrayItem(j.jsonArena, index, astjson.FalseValue(j.jsonArena)) } case protoref.StringKind: - array.SetArrayItem(index, arena.NewString(data.String())) + array.SetArrayItem(j.jsonArena, index, astjson.StringValue(j.jsonArena, data.String())) case protoref.Int32Kind: - array.SetArrayItem(index, arena.NewNumberInt(int(data.Int()))) + array.SetArrayItem(j.jsonArena, index, astjson.IntValue(j.jsonArena, int(data.Int()))) case protoref.Int64Kind: - array.SetArrayItem(index, arena.NewNumberString(strconv.FormatInt(data.Int(), 10))) + array.SetArrayItem(j.jsonArena, index, astjson.NumberValue(j.jsonArena, strconv.FormatInt(data.Int(), 10))) case protoref.Uint32Kind, protoref.Uint64Kind: - array.SetArrayItem(index, arena.NewNumberString(strconv.FormatUint(data.Uint(), 10))) + array.SetArrayItem(j.jsonArena, index, astjson.NumberValue(j.jsonArena, strconv.FormatUint(data.Uint(), 10))) case protoref.FloatKind, protoref.DoubleKind: - array.SetArrayItem(index, arena.NewNumberFloat64(data.Float())) + array.SetArrayItem(j.jsonArena, index, astjson.FloatValue(j.jsonArena, data.Float())) case protoref.BytesKind: - array.SetArrayItem(index, arena.NewStringBytes(data.Bytes())) + array.SetArrayItem(j.jsonArena, index, astjson.StringValueBytes(j.jsonArena, data.Bytes())) case protoref.EnumKind: enumDesc := fd.Enum() enumValueDesc := enumDesc.Values().ByNumber(data.Enum()) if enumValueDesc == nil { - array.SetArrayItem(index, arena.NewNull()) + array.SetArrayItem(j.jsonArena, index, astjson.NullValue) return } @@ -582,20 +584,19 @@ func (j *jsonBuilder) setArrayItem(index int, arena *astjson.Arena, array *astjs graphqlValue, ok := j.mapping.ResolveEnumValue(string(enumDesc.Name()), string(enumValueDesc.Name())) if !ok { // No mapping found - use null - array.SetArrayItem(index, arena.NewNull()) + array.SetArrayItem(j.jsonArena, index, astjson.NullValue) return } - array.SetArrayItem(index, arena.NewString(graphqlValue)) + array.SetArrayItem(j.jsonArena, index, astjson.StringValue(j.jsonArena, graphqlValue)) } } // toDataObject wraps a response value in the standard GraphQL data envelope. // This creates the top-level structure { "data": ... } that GraphQL clients expect. 
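For reference, the two envelopes this file produces: toDataObject (below) yields {"data":{...}}, and writeErrorBytes maps any error into the GraphQL errors array, copying the gRPC status code into extensions.code, or codes.Internal for non-gRPC errors. A small sketch of the status handling using the real google.golang.org/grpc API; the printed JSON mirrors the shape the builder constructs:

package main

import (
	"fmt"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

func main() {
	err := status.Error(codes.NotFound, "product not found")
	code := codes.Internal.String() // default for plain errors
	if st, ok := status.FromError(err); ok {
		code = st.Code().String() // "NotFound"
	}
	// Shape produced by writeErrorBytes:
	fmt.Printf(`{"errors":[{"message":%q,"extensions":{"code":%q}}]}`+"\n", err.Error(), code)
}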
func (j *jsonBuilder) toDataObject(root *astjson.Value) *astjson.Value { - a := astjson.Arena{} - data := a.NewObject() - data.Set(dataPath, root) + data := astjson.ObjectValue(j.jsonArena) + data.Set(j.jsonArena, dataPath, root) return data } @@ -603,30 +604,27 @@ func (j *jsonBuilder) toDataObject(root *astjson.Value) *astjson.Value { // This includes the error message and gRPC status code information in the extensions // field, following GraphQL error specification standards. func (j *jsonBuilder) writeErrorBytes(err error) []byte { - a := astjson.Arena{} - defer a.Reset() - // Create standard GraphQL error structure - errorRoot := a.NewObject() - errorArray := a.NewArray() - errorRoot.Set(errorsPath, errorArray) + errorRoot := astjson.ObjectValue(j.jsonArena) + errorArray := astjson.ArrayValue(j.jsonArena) + errorRoot.Set(j.jsonArena, errorsPath, errorArray) // Create individual error object - errorItem := a.NewObject() - errorItem.Set("message", a.NewString(err.Error())) + errorItem := astjson.ObjectValue(j.jsonArena) + errorItem.Set(j.jsonArena, "message", astjson.StringValue(j.jsonArena, err.Error())) // Add gRPC status code information to extensions - extensions := a.NewObject() + extensions := astjson.ObjectValue(j.jsonArena) if st, ok := status.FromError(err); ok { // gRPC error - include the specific status code - extensions.Set("code", a.NewString(st.Code().String())) + extensions.Set(j.jsonArena, "code", astjson.StringValue(j.jsonArena, st.Code().String())) } else { // Generic error - default to INTERNAL status - extensions.Set("code", a.NewString(codes.Internal.String())) + extensions.Set(j.jsonArena, "code", astjson.StringValue(j.jsonArena, codes.Internal.String())) } - errorItem.Set("extensions", extensions) - errorArray.SetArrayItem(0, errorItem) + errorItem.Set(j.jsonArena, "extensions", extensions) + errorArray.SetArrayItem(j.jsonArena, 0, errorItem) return errorRoot.MarshalTo(nil) } diff --git a/v2/pkg/engine/resolve/context.go b/v2/pkg/engine/resolve/context.go index 65d2d6b90..e9958d24e 100644 --- a/v2/pkg/engine/resolve/context.go +++ b/v2/pkg/engine/resolve/context.go @@ -146,7 +146,7 @@ func (c *Context) appendSubgraphErrors(errs ...error) { } type Request struct { - ID string + ID uint64 Header http.Header } diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 7a14d61dc..ad4e78e47 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -22,6 +22,7 @@ import ( "golang.org/x/sync/errgroup" "github.com/wundergraph/astjson" + "github.com/wundergraph/go-arena" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" @@ -180,6 +181,8 @@ type Loader struct { validateRequiredExternalFields bool taintedObjs taintedObjects + + jsonArena arena.Arena } func (l *Loader) Free() { @@ -431,7 +434,7 @@ func selectItems(items []*astjson.Value, element FetchItemPathElement) []*astjso return selected } -func itemsData(items []*astjson.Value) *astjson.Value { +func itemsData(a arena.Arena, items []*astjson.Value) *astjson.Value { if len(items) == 0 { return astjson.NullValue } @@ -442,7 +445,7 @@ func itemsData(items []*astjson.Value) *astjson.Value { // however, itemsData can be called concurrently, so this might result in a race arr := astjson.MustParseBytes([]byte(`[]`)) for i, item := range items { - arr.SetArrayItem(i, item) + arr.SetArrayItem(a, i, item) } return arr } @@ -552,7 +555,7 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res 
*result, items []*astjson return l.renderErrorsFailedToFetch(fetchItem, res, emptyGraphQLResponse) } - response, err := astjson.ParseBytesWithoutCache(res.out.Bytes()) + response, err := astjson.ParseBytesWithArena(l.jsonArena, res.out.Bytes()) if err != nil { // Fall back to status code if parsing fails and non-2XX if (res.statusCode > 0 && res.statusCode < 200) || res.statusCode >= 300 { @@ -633,7 +636,7 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson return nil } if len(items) == 1 && res.batchStats == nil { - items[0], _, err = astjson.MergeValuesWithPath(items[0], responseData, res.postProcessing.MergePath...) + items[0], _, err = astjson.MergeValuesWithPath(l.jsonArena, items[0], responseData, res.postProcessing.MergePath...) if err != nil { return errors.WithStack(ErrMergeResult{ Subgraph: res.ds.Name, @@ -662,7 +665,7 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson if idx == -1 { continue } - items[i], _, err = astjson.MergeValuesWithPath(items[i], batch[idx], res.postProcessing.MergePath...) + items[i], _, err = astjson.MergeValuesWithPath(l.jsonArena, items[i], batch[idx], res.postProcessing.MergePath...) if err != nil { return errors.WithStack(ErrMergeResult{ Subgraph: res.ds.Name, @@ -683,7 +686,7 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson } for i := range items { - items[i], _, err = astjson.MergeValuesWithPath(items[i], batch[i], res.postProcessing.MergePath...) + items[i], _, err = astjson.MergeValuesWithPath(l.jsonArena, items[i], batch[i], res.postProcessing.MergePath...) if err != nil { return errors.WithStack(ErrMergeResult{ Subgraph: res.ds.Name, @@ -749,7 +752,7 @@ func (l *Loader) mergeErrors(res *result, fetchItem *FetchItem, value *astjson.V values := value.GetArray() l.optionallyOmitErrorLocations(values) if l.rewriteSubgraphErrorPaths { - rewriteErrorPaths(fetchItem, values) + rewriteErrorPaths(l.jsonArena, fetchItem, values) } l.optionallyEnsureExtensionErrorCode(values) @@ -792,7 +795,7 @@ func (l *Loader) mergeErrors(res *result, fetchItem *FetchItem, value *astjson.V } // Wrap mode (default) - errorObject, err := astjson.ParseWithoutCache(l.renderSubgraphBaseError(res.ds, fetchItem.ResponsePath, failedToFetchNoReason)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, l.renderSubgraphBaseError(res.ds, fetchItem.ResponsePath, failedToFetchNoReason)) if err != nil { return err } @@ -861,17 +864,17 @@ func (l *Loader) optionallyEnsureExtensionErrorCode(values []*astjson.Value) { switch extensions.Type() { case astjson.TypeObject: if !extensions.Exists("code") { - extensions.Set("code", l.resolvable.astjsonArena.NewString(l.defaultErrorExtensionCode)) + extensions.Set(l.jsonArena, "code", astjson.StringValue(l.jsonArena, l.defaultErrorExtensionCode)) } case astjson.TypeNull: - extensionsObj := l.resolvable.astjsonArena.NewObject() - extensionsObj.Set("code", l.resolvable.astjsonArena.NewString(l.defaultErrorExtensionCode)) - value.Set("extensions", extensionsObj) + extensionsObj := astjson.ObjectValue(l.jsonArena) + extensionsObj.Set(l.jsonArena, "code", astjson.StringValue(l.jsonArena, l.defaultErrorExtensionCode)) + value.Set(l.jsonArena, "extensions", extensionsObj) } } else { - extensionsObj := l.resolvable.astjsonArena.NewObject() - extensionsObj.Set("code", l.resolvable.astjsonArena.NewString(l.defaultErrorExtensionCode)) - value.Set("extensions", extensionsObj) + extensionsObj := astjson.ObjectValue(l.jsonArena) + 
extensionsObj.Set(l.jsonArena, "code", astjson.StringValue(l.jsonArena, l.defaultErrorExtensionCode)) + value.Set(l.jsonArena, "extensions", extensionsObj) } } } @@ -888,16 +891,16 @@ func (l *Loader) optionallyAttachServiceNameToErrorExtension(values []*astjson.V extensions := value.Get("extensions") switch extensions.Type() { case astjson.TypeObject: - extensions.Set("serviceName", l.resolvable.astjsonArena.NewString(serviceName)) + extensions.Set(l.jsonArena, "serviceName", astjson.StringValue(l.jsonArena, serviceName)) case astjson.TypeNull: - extensionsObj := l.resolvable.astjsonArena.NewObject() - extensionsObj.Set("serviceName", l.resolvable.astjsonArena.NewString(serviceName)) - value.Set("extensions", extensionsObj) + extensionsObj := astjson.ObjectValue(l.jsonArena) + extensionsObj.Set(l.jsonArena, "serviceName", astjson.StringValue(l.jsonArena, serviceName)) + value.Set(l.jsonArena, "extensions", extensionsObj) } } else { - extensionsObj := l.resolvable.astjsonArena.NewObject() - extensionsObj.Set("serviceName", l.resolvable.astjsonArena.NewString(serviceName)) - value.Set("extensions", extensionsObj) + extensionsObj := astjson.ObjectValue(l.jsonArena) + extensionsObj.Set(l.jsonArena, "serviceName", astjson.StringValue(l.jsonArena, serviceName)) + value.Set(l.jsonArena, "extensions", extensionsObj) } } } @@ -951,7 +954,7 @@ func (l *Loader) optionallyOmitErrorLocations(values []*astjson.Value) { // - Drops the numeric index immediately following "_entities". // - Converts all subsequent numeric segments to strings (e.g., 1 -> "1"). // - Skips non-string/non-number segments. -func rewriteErrorPaths(fetchItem *FetchItem, values []*astjson.Value) { +func rewriteErrorPaths(a arena.Arena, fetchItem *FetchItem, values []*astjson.Value) { pathPrefix := make([]string, len(fetchItem.ResponsePathElements)) copy(pathPrefix, fetchItem.ResponsePathElements) // remove the trailing @ in case we're in an array as it looks weird in the path @@ -993,11 +996,11 @@ func rewriteErrorPaths(fetchItem *FetchItem, values []*astjson.Value) { } } newPathJSON, _ := json.Marshal(newPath) - pathBytes, err := astjson.ParseBytesWithoutCache(newPathJSON) + pathBytes, err := astjson.ParseBytesWithArena(a, newPathJSON) if err != nil { continue } - value.Set("path", pathBytes) + value.Set(a, "path", pathBytes) break } } @@ -1018,17 +1021,17 @@ func (l *Loader) setSubgraphStatusCode(values []*astjson.Value, statusCode int) if extensions.Type() != astjson.TypeObject { continue } - v, err := astjson.ParseWithoutCache(strconv.Itoa(statusCode)) + v, err := astjson.ParseWithArena(l.jsonArena, strconv.Itoa(statusCode)) if err != nil { continue } - extensions.Set("statusCode", v) + extensions.Set(l.jsonArena, "statusCode", v) } else { - v, err := astjson.ParseWithoutCache(`{"statusCode":` + strconv.Itoa(statusCode) + `}`) + v, err := astjson.ParseWithArena(l.jsonArena, `{"statusCode":`+strconv.Itoa(statusCode)+`}`) if err != nil { continue } - value.Set("extensions", v) + value.Set(l.jsonArena, "extensions", v) } } } @@ -1065,7 +1068,7 @@ func (l *Loader) addApolloRouterCompatibilityError(res *result) error { } } }`, res.ds.Name, http.StatusText(res.statusCode), res.statusCode) - apolloRouterStatusError, err := astjson.ParseWithoutCache(apolloRouterStatusErrorJSON) + apolloRouterStatusError, err := astjson.ParseWithArena(l.jsonArena, apolloRouterStatusErrorJSON) if err != nil { return err } @@ -1078,7 +1081,7 @@ func (l *Loader) addApolloRouterCompatibilityError(res *result) error { func (l *Loader) 
renderErrorsFailedDeps(fetchItem *FetchItem, res *result) error { path := l.renderAtPathErrorPart(fetchItem.ResponsePath) msg := fmt.Sprintf(`{"message":"Failed to obtain field dependencies from Subgraph '%s'%s."}`, res.ds.Name, path) - errorObject, err := astjson.ParseWithoutCache(msg) + errorObject, err := astjson.ParseWithArena(l.jsonArena, msg) if err != nil { return err } @@ -1089,7 +1092,7 @@ func (l *Loader) renderErrorsFailedDeps(fetchItem *FetchItem, res *result) error func (l *Loader) renderErrorsFailedToFetch(fetchItem *FetchItem, res *result, reason string) error { l.ctx.appendSubgraphErrors(res.err, NewSubgraphError(res.ds, fetchItem.ResponsePath, reason, res.statusCode)) - errorObject, err := astjson.ParseWithoutCache(l.renderSubgraphBaseError(res.ds, fetchItem.ResponsePath, reason)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, l.renderSubgraphBaseError(res.ds, fetchItem.ResponsePath, reason)) if err != nil { return err } @@ -1106,7 +1109,7 @@ func (l *Loader) renderErrorsStatusFallback(fetchItem *FetchItem, res *result, s l.ctx.appendSubgraphErrors(res.err, NewSubgraphError(res.ds, fetchItem.ResponsePath, reason, res.statusCode)) - errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"%s"}`, reason)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"%s"}`, reason)) if err != nil { return err } @@ -1140,13 +1143,13 @@ func (l *Loader) renderAuthorizationRejectedErrors(fetchItem *FetchItem, res *re if res.ds.Name == "" { for _, reason := range res.authorizationRejectedReasons { if reason == "" { - errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Unauthorized Subgraph request%s.",%s}`, pathPart, extensionErrorCode)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized Subgraph request%s.",%s}`, pathPart, extensionErrorCode)) if err != nil { continue } astjson.AppendToArray(l.resolvable.errors, errorObject) } else { - errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Unauthorized Subgraph request%s, Reason: %s.",%s}`, pathPart, reason, extensionErrorCode)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized Subgraph request%s, Reason: %s.",%s}`, pathPart, reason, extensionErrorCode)) if err != nil { continue } @@ -1156,13 +1159,13 @@ func (l *Loader) renderAuthorizationRejectedErrors(fetchItem *FetchItem, res *re } else { for _, reason := range res.authorizationRejectedReasons { if reason == "" { - errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s'%s.",%s}`, res.ds.Name, pathPart, extensionErrorCode)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s'%s.",%s}`, res.ds.Name, pathPart, extensionErrorCode)) if err != nil { continue } astjson.AppendToArray(l.resolvable.errors, errorObject) } else { - errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s'%s, Reason: %s.",%s}`, res.ds.Name, pathPart, reason, extensionErrorCode)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s'%s, Reason: %s.",%s}`, res.ds.Name, pathPart, reason, extensionErrorCode)) if err != nil { continue } @@ -1182,35 +1185,35 @@ func (l *Loader) renderRateLimitRejectedErrors(fetchItem *FetchItem, res *result ) if res.ds.Name == "" { if res.rateLimitRejectedReason == "" { - 
errorObject, err = astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s."}`, pathPart)) + errorObject, err = astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s."}`, pathPart)) if err != nil { return err } } else { - errorObject, err = astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s, Reason: %s."}`, pathPart, res.rateLimitRejectedReason)) + errorObject, err = astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s, Reason: %s."}`, pathPart, res.rateLimitRejectedReason)) if err != nil { return err } } } else { if res.rateLimitRejectedReason == "" { - errorObject, err = astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s."}`, res.ds.Name, pathPart)) + errorObject, err = astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s."}`, res.ds.Name, pathPart)) if err != nil { return err } } else { - errorObject, err = astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s, Reason: %s."}`, res.ds.Name, pathPart, res.rateLimitRejectedReason)) + errorObject, err = astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s, Reason: %s."}`, res.ds.Name, pathPart, res.rateLimitRejectedReason)) if err != nil { return err } } } if l.ctx.RateLimitOptions.ErrorExtensionCode.Enabled { - extension, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"code":"%s"}`, l.ctx.RateLimitOptions.ErrorExtensionCode.Code)) + extension, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"code":"%s"}`, l.ctx.RateLimitOptions.ErrorExtensionCode.Code)) if err != nil { return err } - errorObject, _, err = astjson.MergeValuesWithPath(errorObject, extension, "extensions") + errorObject, _, err = astjson.MergeValuesWithPath(l.jsonArena, errorObject, extension, "extensions") if err != nil { return err } @@ -1287,7 +1290,7 @@ func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, fetchI res.init(fetch.PostProcessing, fetch.Info) buf := &bytes.Buffer{} - inputData := itemsData(items) + inputData := itemsData(l.jsonArena, items) if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} if !l.ctx.TracingOptions.ExcludeRawInputData && inputData != nil { @@ -1353,7 +1356,7 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetc res.init(fetch.PostProcessing, fetch.Info) buf := acquireEntityFetchBuffer() defer releaseEntityFetchBuffer(buf) - input := itemsData(items) + input := itemsData(l.jsonArena, items) if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} if !l.ctx.TracingOptions.ExcludeRawInputData && input != nil { @@ -1465,7 +1468,7 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} if !l.ctx.TracingOptions.ExcludeRawInputData && len(items) != 0 { - data := itemsData(items) + data := itemsData(l.jsonArena, items) if data != nil { fetch.Trace.RawInputData, _ = l.compactJSON(data.MarshalTo(nil)) } @@ -1840,7 +1843,7 @@ func (l *Loader) compactJSON(data []byte) ([]byte, error) { return nil, err } out := dst.Bytes() - v, err := astjson.ParseBytesWithoutCache(out) + v, err := astjson.ParseBytesWithArena(l.jsonArena, out) if err != nil { return nil, err } diff --git a/v2/pkg/engine/resolve/loader_test.go 
b/v2/pkg/engine/resolve/loader_test.go index 4ed83d444..01c5ef5dc 100644 --- a/v2/pkg/engine/resolve/loader_test.go +++ b/v2/pkg/engine/resolve/loader_test.go @@ -287,7 +287,7 @@ func TestLoader_LoadGraphQLResponseData(t *testing.T) { ctx := &Context{ ctx: context.Background(), } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) assert.NoError(t, err) @@ -376,7 +376,7 @@ func TestLoader_MergeErrorDifferingTypes(t *testing.T) { ctx := &Context{ ctx: context.Background(), } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) assert.NoError(t, err) @@ -467,7 +467,7 @@ func TestLoader_MergeErrorDifferingArrayLength(t *testing.T) { ctx := &Context{ ctx: context.Background(), } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) assert.NoError(t, err) @@ -749,7 +749,7 @@ func TestLoader_LoadGraphQLResponseDataWithExtensions(t *testing.T) { ctx: context.Background(), Extensions: []byte(`{"foo":"bar"}`), } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) assert.NoError(t, err) @@ -1024,7 +1024,7 @@ func BenchmarkLoader_LoadGraphQLResponseData(b *testing.B) { ctx := &Context{ ctx: context.Background(), } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} expected := `{"errors":[],"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` b.SetBytes(int64(len(expected))) @@ -1125,7 +1125,7 @@ func TestLoader_RedactHeaders(t *testing.T) { Enable: true, }, } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) @@ -1421,7 +1421,7 @@ func TestLoader_InvalidBatchItemCount(t *testing.T) { ctx := &Context{ ctx: context.Background(), } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) assert.NoError(t, err) @@ -1521,13 +1521,13 @@ func TestRewriteErrorPaths(t *testing.T) { for i, inputError := range tc.inputErrors { // Create a copy by marshaling and parsing again data := inputError.MarshalTo(nil) - value, err := astjson.ParseBytesWithoutCache(data) + value, err := astjson.ParseBytesWithArena(nil, data) assert.NoError(t, err, "Failed to copy input error") values[i] = value } // Call the function under test - rewriteErrorPaths(fetchItem, values) + rewriteErrorPaths(nil, fetchItem, values) // Compare the results assert.Equal(t, 
len(tc.expectedErrors), len(values), diff --git a/v2/pkg/engine/resolve/resolvable.go b/v2/pkg/engine/resolve/resolvable.go index 5219c910d..5aceb2110 100644 --- a/v2/pkg/engine/resolve/resolvable.go +++ b/v2/pkg/engine/resolve/resolvable.go @@ -11,6 +11,7 @@ import ( "github.com/cespare/xxhash/v2" "github.com/pkg/errors" "github.com/tidwall/gjson" + "github.com/wundergraph/go-arena" "github.com/wundergraph/astjson" @@ -31,7 +32,7 @@ type Resolvable struct { valueCompletion *astjson.Value skipAddingNullErrors bool - astjsonArena *astjson.Arena + astjsonArena arena.Arena parsers []*astjson.Parser print bool @@ -67,13 +68,13 @@ type ResolvableOptions struct { ApolloCompatibilityReplaceInvalidVarError bool } -func NewResolvable(options ResolvableOptions) *Resolvable { +func NewResolvable(a arena.Arena, options ResolvableOptions) *Resolvable { return &Resolvable{ options: options, xxh: xxhash.New(), authorizationAllow: make(map[uint64]struct{}), authorizationDeny: make(map[uint64]string), - astjsonArena: &astjson.Arena{}, + astjsonArena: a, } } @@ -95,7 +96,7 @@ func (r *Resolvable) Reset() { r.operationType = ast.OperationTypeUnknown r.renameTypeNames = r.renameTypeNames[:0] r.authorizationError = nil - r.astjsonArena.Reset() + r.astjsonArena = nil r.xxh.Reset() for k := range r.authorizationAllow { delete(r.authorizationAllow, k) @@ -109,14 +110,14 @@ func (r *Resolvable) Init(ctx *Context, initialData []byte, operationType ast.Op r.ctx = ctx r.operationType = operationType r.renameTypeNames = ctx.RenameTypeNames - r.data = r.astjsonArena.NewObject() - r.errors = r.astjsonArena.NewArray() + r.data = astjson.ObjectValue(r.astjsonArena) + r.errors = astjson.ArrayValue(r.astjsonArena) if initialData != nil { - initialValue, err := astjson.ParseBytesWithoutCache(initialData) + initialValue, err := astjson.ParseBytesWithArena(r.astjsonArena, initialData) if err != nil { return err } - r.data, _, err = astjson.MergeValues(r.data, initialValue) + r.data, _, err = astjson.MergeValues(r.astjsonArena, r.data, initialValue) if err != nil { return err } @@ -129,19 +130,19 @@ func (r *Resolvable) InitSubscription(ctx *Context, initialData []byte, postProc r.operationType = ast.OperationTypeSubscription r.renameTypeNames = ctx.RenameTypeNames if initialData != nil { - initialValue, err := astjson.ParseBytesWithoutCache(initialData) + initialValue, err := astjson.ParseBytesWithArena(r.astjsonArena, initialData) if err != nil { return err } if postProcessing.SelectResponseDataPath == nil { - r.data, _, err = astjson.MergeValuesWithPath(r.data, initialValue, postProcessing.MergePath...) + r.data, _, err = astjson.MergeValuesWithPath(r.astjsonArena, r.data, initialValue, postProcessing.MergePath...) if err != nil { return err } } else { selectedInitialValue := initialValue.Get(postProcessing.SelectResponseDataPath...) if selectedInitialValue != nil { - r.data, _, err = astjson.MergeValuesWithPath(r.data, selectedInitialValue, postProcessing.MergePath...) + r.data, _, err = astjson.MergeValuesWithPath(r.astjsonArena, r.data, selectedInitialValue, postProcessing.MergePath...) 
if err != nil { return err } @@ -155,10 +156,10 @@ func (r *Resolvable) InitSubscription(ctx *Context, initialData []byte, postProc } } if r.data == nil { - r.data = r.astjsonArena.NewObject() + r.data = astjson.ObjectValue(r.astjsonArena) } if r.errors == nil { - r.errors = r.astjsonArena.NewArray() + r.errors = astjson.ArrayValue(r.astjsonArena) } return } @@ -168,7 +169,7 @@ func (r *Resolvable) ResolveNode(node Node, data *astjson.Value, out io.Writer) r.print = false r.printErr = nil r.authorizationError = nil - r.errors = r.astjsonArena.NewArray() + r.errors = astjson.ArrayValue(r.astjsonArena) hasErrors := r.walkNode(node, data) if hasErrors { @@ -464,7 +465,7 @@ func (r *Resolvable) renderScalarFieldValue(value *astjson.Value, nullable bool) // renderScalarFieldString - is used when value require some pre-processing, e.g. unescaping or custom rendering func (r *Resolvable) renderScalarFieldBytes(data []byte, nullable bool) { - value, err := astjson.ParseBytesWithoutCache(data) + value, err := astjson.ParseBytesWithArena(r.astjsonArena, data) if err != nil { r.printErr = err return @@ -853,7 +854,7 @@ func (r *Resolvable) walkArray(arr *Array, value *astjson.Value) bool { r.popArrayPathElement() if err { if arr.Item.NodeKind() == NodeKindObject && arr.Item.NodeNullable() { - value.SetArrayItem(i, astjson.NullValue) + value.SetArrayItem(r.astjsonArena, i, astjson.NullValue) continue } if arr.Nullable { @@ -1287,14 +1288,14 @@ func (r *Resolvable) addErrorWithCodeAndPath(message, code string, fieldPath []s func (r *Resolvable) addValueCompletion(message, code string) { if r.valueCompletion == nil { - r.valueCompletion = r.astjsonArena.NewArray() + r.valueCompletion = astjson.ArrayValue(r.astjsonArena) } fastjsonext.AppendErrorWithExtensionsCodeToArray(r.astjsonArena, r.valueCompletion, message, code, r.path) } func (r *Resolvable) addValueCompletionWithPath(message, code string, fieldPath []string) { if r.valueCompletion == nil { - r.valueCompletion = r.astjsonArena.NewArray() + r.valueCompletion = astjson.ArrayValue(r.astjsonArena) } r.pushNodePathElement(fieldPath) fastjsonext.AppendErrorWithExtensionsCodeToArray(r.astjsonArena, r.valueCompletion, message, code, r.path) diff --git a/v2/pkg/engine/resolve/resolvable_custom_field_renderer_test.go b/v2/pkg/engine/resolve/resolvable_custom_field_renderer_test.go index 843c6e696..0dbb0394b 100644 --- a/v2/pkg/engine/resolve/resolvable_custom_field_renderer_test.go +++ b/v2/pkg/engine/resolve/resolvable_custom_field_renderer_test.go @@ -440,7 +440,7 @@ func TestResolvable_CustomFieldRenderer(t *testing.T) { t.Parallel() // Setup - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{} var input []byte @@ -543,7 +543,7 @@ func TestResolvable_CustomFieldRenderer(t *testing.T) { t.Parallel() input := []byte(tc.input) - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{} err := res.Init(ctx, input, ast.OperationTypeQuery) assert.NoError(t, err) diff --git a/v2/pkg/engine/resolve/resolvable_test.go b/v2/pkg/engine/resolve/resolvable_test.go index 4b92f8591..aea4e78ef 100644 --- a/v2/pkg/engine/resolve/resolvable_test.go +++ b/v2/pkg/engine/resolve/resolvable_test.go @@ -12,7 +12,7 @@ import ( func TestResolvable_Resolve(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other 
Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -84,7 +84,7 @@ func TestResolvable_Resolve(t *testing.T) { func TestResolvable_ResolveWithTypeMismatch(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":true}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -157,7 +157,7 @@ func TestResolvable_ResolveWithTypeMismatch(t *testing.T) { func TestResolvable_ResolveWithErrorBubbleUp(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -231,7 +231,7 @@ func TestResolvable_ResolveWithErrorBubbleUp(t *testing.T) { func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { t.Run("Non-nullable root field", func(t *testing.T) { topProducts := `{"topProducts":null}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -258,7 +258,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { }) t.Run("Non-Nullable root field and nested field", func(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -333,7 
+333,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { }) t.Run("Nullable root field and non-Nullable nested field", func(t *testing.T) { topProducts := `{"topProduct":{"name":null}}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -370,7 +370,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { }) t.Run("Non-Nullable sibling field", func(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","reviews":[{"author":{"__typename":"User","name":"Bob"},"body":null}]}]}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -439,7 +439,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { }) t.Run("Non-nullable array and array item", func(t *testing.T) { topProducts := `{"topProducts":[null]}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -469,7 +469,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { }) t.Run("Nullable array and non-nullable array item", func(t *testing.T) { topProducts := `{"topProducts":[null]}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -500,7 +500,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { }) t.Run("Non-Nullable array, array item, and array item field", func(t *testing.T) { topProducts := `{"topProducts":[{"author":{"name":"Name"}},{"author":null}]}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -549,7 +549,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { func TestResolvable_ResolveWithErrorBubbleUpUntilData(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -622,7 +622,7 @@ func TestResolvable_ResolveWithErrorBubbleUpUntilData(t *testing.T) { func TestResolvable_InvalidEnumValues(t *testing.T) { t.Run("Invalid enum value", func(t *testing.T) { enum := `{"enum":"B"}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -653,7 +653,7 @@ func TestResolvable_InvalidEnumValues(t *testing.T) { t.Run("Inaccessible enum value", func(t *testing.T) { enum := `{"enum":"B"}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -686,7 +686,7 @@ func 
TestResolvable_InvalidEnumValues(t *testing.T) { t.Run("Invalid enum value with value completion Apollo compatibility flag", func(t *testing.T) { enum := `{"enum":"B"}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -719,7 +719,7 @@ func TestResolvable_InvalidEnumValues(t *testing.T) { t.Run("Inaccessible enum value with value completion Apollo compatibility flag", func(t *testing.T) { enum := `{"enum":"B"}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -755,7 +755,7 @@ func TestResolvable_InvalidEnumValues(t *testing.T) { func BenchmarkResolvable_Resolve(b *testing.B) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -838,7 +838,7 @@ func BenchmarkResolvable_Resolve(b *testing.B) { func BenchmarkResolvable_ResolveWithErrorBubbleUp(b *testing.B) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -923,7 +923,7 @@ func BenchmarkResolvable_ResolveWithErrorBubbleUp(b *testing.B) { } func TestResolvable_WithTracingNotStarted(t *testing.T) { - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) // Do not start a trace with SetTraceStart(), but request it to be output ctx := NewContext(context.Background()) ctx.TracingOptions.Enable = true @@ -950,7 +950,7 @@ func TestResolvable_WithTracingNotStarted(t *testing.T) { func TestResolveFloat(t *testing.T) { t.Run("default behaviour", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := NewContext(context.Background()) err := res.Init(ctx, []byte(`{"f":1.0}`), ast.OperationTypeQuery) assert.NoError(t, err) @@ -972,7 +972,7 @@ func TestResolveFloat(t *testing.T) { assert.Equal(t, `{"data":{"f":1.0}}`, out.String()) }) t.Run("invalid float", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := NewContext(context.Background()) err := res.Init(ctx, []byte(`{"f":false}`), ast.OperationTypeQuery) assert.NoError(t, err) @@ -994,7 +994,7 @@ 
func TestResolveFloat(t *testing.T) { assert.Equal(t, `{"errors":[{"message":"Float cannot represent non-float value: \"false\"","path":["f"]}],"data":null}`, out.String()) }) t.Run("truncate float", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityTruncateFloatValues: true, }) ctx := NewContext(context.Background()) @@ -1018,7 +1018,7 @@ func TestResolveFloat(t *testing.T) { assert.Equal(t, `{"data":{"f":1}}`, out.String()) }) t.Run("truncate float with decimal place", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityTruncateFloatValues: true, }) ctx := NewContext(context.Background()) @@ -1045,7 +1045,7 @@ func TestResolveFloat(t *testing.T) { func TestResolvable_ValueCompletion(t *testing.T) { t.Run("nested object", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := NewContext(context.Background()) @@ -1143,7 +1143,7 @@ func TestResolvable_ValueCompletion(t *testing.T) { }`) t.Run("nullable", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := NewContext(context.Background()) @@ -1241,7 +1241,7 @@ func TestResolvable_ValueCompletion(t *testing.T) { }) t.Run("mixed nullability", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := NewContext(context.Background()) @@ -1342,7 +1342,7 @@ func TestResolvable_ValueCompletion(t *testing.T) { func TestResolvable_WithTracing(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) background := SetTraceStart(context.Background(), true) ctx := NewContext(background) ctx.TracingOptions.Enable = true diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 14d8ad4b5..92501bd2e 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -235,7 +235,7 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{}, allowedErrorFields map[string]struct{}) *tools { return &tools{ - resolvable: NewResolvable(options.ResolvableOptions), + resolvable: NewResolvable(nil, options.ResolvableOptions), loader: &Loader{ propagateSubgraphErrors: options.PropagateSubgraphErrors, propagateSubgraphStatusCodes: options.PropagateSubgraphStatusCodes, @@ -291,6 +291,38 @@ func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLRespons return resp, err } +func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLResponse, writer io.Writer) 
(*GraphQLResolveInfo, error) { + resp := &GraphQLResolveInfo{} + + start := time.Now() + <-r.maxConcurrency + resp.ResolveAcquireWaitTime = time.Since(start) + defer func() { + r.maxConcurrency <- struct{}{} + }() + + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields) + + err := t.resolvable.Init(ctx, nil, response.Info.OperationType) + if err != nil { + return nil, err + } + + if !ctx.ExecutionOptions.SkipLoader { + err = t.loader.LoadGraphQLResponseData(ctx, response, t.resolvable) + if err != nil { + return nil, err + } + } + + err = t.resolvable.Resolve(ctx.ctx, response.Data, response.Fetches, writer) + if err != nil { + return nil, err + } + + return resp, err +} + type trigger struct { id uint64 cancel context.CancelFunc diff --git a/v2/pkg/engine/resolve/tainted_objects_test.go b/v2/pkg/engine/resolve/tainted_objects_test.go index 0eeb34440..b8205dc72 100644 --- a/v2/pkg/engine/resolve/tainted_objects_test.go +++ b/v2/pkg/engine/resolve/tainted_objects_test.go @@ -70,7 +70,7 @@ func TestSelectObjectAndIndex(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - response, err := astjson.ParseBytesWithoutCache([]byte(tt.responseJSON)) + response, err := astjson.ParseBytes([]byte(tt.responseJSON)) assert.NoError(t, err, "Failed to parse response JSON") // Convert path elements to astjson.Value slice @@ -94,7 +94,7 @@ func TestSelectObjectAndIndex(t *testing.T) { assert.Nil(t, entity, "Expected nil entity") } else { assert.NotNil(t, entity, "Expected non-nil entity") - expectedEntity, err := astjson.ParseBytesWithoutCache([]byte(tt.expectedEntity)) + expectedEntity, err := astjson.ParseBytes([]byte(tt.expectedEntity)) assert.NoError(t, err, "Failed to parse expected entity JSON") // Compare JSON representations @@ -320,10 +320,10 @@ func TestGetTaintedIndices(t *testing.T) { } mockFetch := &mockFetchWithInfo{info: fetchInfo} - response, err := astjson.ParseBytesWithoutCache([]byte(tt.responseJSON)) + response, err := astjson.ParseBytes([]byte(tt.responseJSON)) assert.NoError(t, err, "Failed to parse response JSON") - errors, err := astjson.ParseBytesWithoutCache([]byte(tt.errorsJSON)) + errors, err := astjson.ParseBytes([]byte(tt.errorsJSON)) assert.NoError(t, err, "Failed to parse errors JSON") indices := getTaintedIndices(mockFetch, response, errors) diff --git a/v2/pkg/engine/resolve/variables_renderer.go b/v2/pkg/engine/resolve/variables_renderer.go index 4cbb471f8..0fa1d3ee1 100644 --- a/v2/pkg/engine/resolve/variables_renderer.go +++ b/v2/pkg/engine/resolve/variables_renderer.go @@ -350,7 +350,7 @@ var ( func (g *GraphQLVariableResolveRenderer) getResolvable() *Resolvable { v := _graphQLVariableResolveRendererPool.Get() if v == nil { - return NewResolvable(ResolvableOptions{}) + return NewResolvable(nil, ResolvableOptions{}) } return v.(*Resolvable) } diff --git a/v2/pkg/fastjsonext/fastjsonext.go b/v2/pkg/fastjsonext/fastjsonext.go index 0480fcbd4..4929e8a96 100644 --- a/v2/pkg/fastjsonext/fastjsonext.go +++ b/v2/pkg/fastjsonext/fastjsonext.go @@ -2,27 +2,28 @@ package fastjsonext import ( "github.com/wundergraph/astjson" + "github.com/wundergraph/go-arena" ) -func AppendErrorToArray(arena *astjson.Arena, v *astjson.Value, msg string, path []PathElement) { +func AppendErrorToArray(a arena.Arena, v *astjson.Value, msg string, path []PathElement) { if v.Type() != astjson.TypeArray { return } - errorObject := CreateErrorObjectWithPath(arena, msg, path) + errorObject := CreateErrorObjectWithPath(a, msg, path) items, _ := v.Array() 
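
[Editor's note] A minimal usage sketch of the arena-threaded fastjsonext helpers changed in the hunk above. It assumes only what the hunks show: the helpers now take an arena.Arena as their first argument, and a nil arena is acceptable (the updated tests below call AppendErrorToArray(nil, ...)). The message, code, and path values are illustrative, not part of the patch:

    package main

    import (
    	"fmt"

    	"github.com/wundergraph/astjson"
    	"github.com/wundergraph/graphql-go-tools/v2/pkg/fastjsonext"
    )

    func main() {
    	errs := astjson.MustParse(`[]`)
    	// Passing a nil arena falls back to ordinary allocation,
    	// mirroring the updated tests in this patch.
    	fastjsonext.AppendErrorWithExtensionsCodeToArray(nil, errs, "subgraph unreachable", "SUBGRAPH_ERROR",
    		[]fastjsonext.PathElement{{Name: "topProducts"}, {Idx: 0}})
    	fmt.Println(string(errs.MarshalTo(nil)))
    	// roughly: [{"message":"subgraph unreachable","path":["topProducts",0],"extensions":{"code":"SUBGRAPH_ERROR"}}]
    }
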
- v.SetArrayItem(len(items), errorObject) + v.SetArrayItem(a, len(items), errorObject) } -func AppendErrorWithExtensionsCodeToArray(arena *astjson.Arena, v *astjson.Value, msg, code string, path []PathElement) { +func AppendErrorWithExtensionsCodeToArray(a arena.Arena, v *astjson.Value, msg, code string, path []PathElement) { if v.Type() != astjson.TypeArray { return } - errorObject := CreateErrorObjectWithPath(arena, msg, path) - extensions := arena.NewObject() - extensions.Set("code", arena.NewString(code)) - errorObject.Set("extensions", extensions) + errorObject := CreateErrorObjectWithPath(a, msg, path) + extensions := astjson.ObjectValue(a) + extensions.Set(a, "code", astjson.StringValue(a, code)) + errorObject.Set(a, "extensions", extensions) items, _ := v.Array() - v.SetArrayItem(len(items), errorObject) + v.SetArrayItem(a, len(items), errorObject) } type PathElement struct { @@ -30,29 +31,29 @@ type PathElement struct { Idx int } -func CreateErrorObjectWithPath(arena *astjson.Arena, message string, path []PathElement) *astjson.Value { - errorObject := arena.NewObject() - errorObject.Set("message", arena.NewString(message)) +func CreateErrorObjectWithPath(a arena.Arena, message string, path []PathElement) *astjson.Value { + errorObject := astjson.ObjectValue(a) + errorObject.Set(a, "message", astjson.StringValue(a, message)) if len(path) == 0 { return errorObject } - errorPath := arena.NewArray() + errorPath := astjson.ArrayValue(a) for i := range path { if path[i].Name != "" { - errorPath.SetArrayItem(i, arena.NewString(path[i].Name)) + errorPath.SetArrayItem(a, i, astjson.StringValue(a, path[i].Name)) } else { - errorPath.SetArrayItem(i, arena.NewNumberInt(path[i].Idx)) + errorPath.SetArrayItem(a, i, astjson.IntValue(a, path[i].Idx)) } } - errorObject.Set("path", errorPath) + errorObject.Set(a, "path", errorPath) return errorObject } func PrintGraphQLResponse(data, errors *astjson.Value) string { out := astjson.MustParse(`{}`) if astjson.ValueIsNonNull(errors) { - out.Set("errors", errors) + out.Set(nil, "errors", errors) } - out.Set("data", data) + out.Set(nil, "data", data) return string(out.MarshalTo(nil)) } diff --git a/v2/pkg/fastjsonext/fastjsonext_test.go b/v2/pkg/fastjsonext/fastjsonext_test.go index af4271630..e48a2ad1c 100644 --- a/v2/pkg/fastjsonext/fastjsonext_test.go +++ b/v2/pkg/fastjsonext/fastjsonext_test.go @@ -21,28 +21,28 @@ func TestGetArray(t *testing.T) { func TestAppendErrorWithMessage(t *testing.T) { a := astjson.MustParse(`[]`) - AppendErrorToArray(&astjson.Arena{}, a, "error", nil) + AppendErrorToArray(nil, a, "error", nil) out := a.MarshalTo(nil) require.Equal(t, `[{"message":"error"}]`, string(out)) - AppendErrorToArray(&astjson.Arena{}, a, "error2", []PathElement{{Name: "a"}}) + AppendErrorToArray(nil, a, "error2", []PathElement{{Name: "a"}}) out = a.MarshalTo(nil) require.Equal(t, `[{"message":"error"},{"message":"error2","path":["a"]}]`, string(out)) } func TestCreateErrorObjectWithPath(t *testing.T) { - v := CreateErrorObjectWithPath(&astjson.Arena{}, "my error message", []PathElement{ + v := CreateErrorObjectWithPath(nil, "my error message", []PathElement{ {Name: "a"}, }) out := v.MarshalTo(nil) require.Equal(t, `{"message":"my error message","path":["a"]}`, string(out)) - v = CreateErrorObjectWithPath(&astjson.Arena{}, "my error message", []PathElement{ + v = CreateErrorObjectWithPath(nil, "my error message", []PathElement{ {Name: "a"}, {Idx: 1}, {Name: "b"}, }) out = v.MarshalTo(nil) require.Equal(t, `{"message":"my error 
message","path":["a",1,"b"]}`, string(out)) - v = CreateErrorObjectWithPath(&astjson.Arena{}, "my error message", []PathElement{ + v = CreateErrorObjectWithPath(nil, "my error message", []PathElement{ {Name: "a"}, {Name: "b"}, }) diff --git a/v2/pkg/variablesvalidation/variablesvalidation.go b/v2/pkg/variablesvalidation/variablesvalidation.go index 70bb6033b..6953a5970 100644 --- a/v2/pkg/variablesvalidation/variablesvalidation.go +++ b/v2/pkg/variablesvalidation/variablesvalidation.go @@ -98,7 +98,7 @@ func (v *VariablesValidator) ValidateWithRemap(operation, definition *ast.Docume func (v *VariablesValidator) Validate(operation, definition *ast.Document, variables []byte) error { v.visitor.definition = definition v.visitor.operation = operation - v.visitor.variables, v.visitor.err = astjson.ParseBytesWithoutCache(variables) + v.visitor.variables, v.visitor.err = astjson.ParseBytes(variables) if v.visitor.err != nil { return v.visitor.err } From 20bf416b618279626dd4c1bb4f60c9c808c473e6 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Wed, 15 Oct 2025 16:03:32 +0200 Subject: [PATCH 12/61] chore: refactor & simplify DataSource interface --- .../graphql_datasource/graphql_datasource.go | 8 +- .../graphql_datasource_test.go | 24 +- .../grpc_datasource/grpc_datasource.go | 26 +- .../grpc_datasource/grpc_datasource_test.go | 72 +- .../datasource/httpclient/httpclient_test.go | 11 +- .../datasource/httpclient/nethttpclient.go | 27 +- .../fixtures/schema_introspection.golden | 2 +- ...on_with_custom_root_operation_types.golden | 2 +- .../fixtures/type_introspection.golden | 2 +- .../introspection_datasource/source.go | 22 +- .../introspection_datasource/source_test.go | 13 +- .../pubsub_datasource/pubsub_kafka.go | 16 +- .../pubsub_datasource/pubsub_nats.go | 30 +- .../staticdatasource/static_datasource.go | 8 +- v2/pkg/engine/plan/planner_test.go | 9 +- v2/pkg/engine/resolve/authorization_test.go | 49 +- v2/pkg/engine/resolve/datasource.go | 5 +- v2/pkg/engine/resolve/loader.go | 59 +- v2/pkg/engine/resolve/loader_hooks_test.go | 114 +- v2/pkg/engine/resolve/loader_test.go | 26 +- v2/pkg/engine/resolve/resolve.go | 6 + .../engine/resolve/resolve_federation_test.go | 225 ++- v2/pkg/engine/resolve/resolve_mock_test.go | 27 +- v2/pkg/engine/resolve/resolve_test.go | 1417 ++++++++++++----- 24 files changed, 1412 insertions(+), 788 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go index 3acdd0760..6f301d52d 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go @@ -1907,14 +1907,14 @@ func (s *Source) replaceEmptyObject(variables []byte) ([]byte, bool) { return variables, false } -func (s *Source) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { +func (s *Source) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { input = s.compactAndUnNullVariables(input) - return httpclient.DoMultipartForm(s.httpClient, ctx, input, files, out) + return httpclient.DoMultipartForm(s.httpClient, ctx, input, files) } -func (s *Source) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) { +func (s *Source) Load(ctx context.Context, input []byte) (data []byte, err error) { input = s.compactAndUnNullVariables(input) - return httpclient.Do(s.httpClient, ctx, input, out) + return 
httpclient.Do(s.httpClient, ctx, input) } type GraphQLSubscriptionClient interface { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go index f7031fc3a..75a23f5ed 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go @@ -8693,10 +8693,9 @@ func TestSource_Load(t *testing.T) { input = httpclient.SetInputBodyWithPath(input, variables, "variables") input = httpclient.SetInputURL(input, []byte(serverUrl)) - buf := bytes.NewBuffer(nil) - - require.NoError(t, src.Load(context.Background(), input, buf)) - assert.Equal(t, `{"variables":{"a":null,"b":"b","c":{}}}`, buf.String()) + data, err := src.Load(context.Background(), input) + require.NoError(t, err) + assert.Equal(t, `{"variables":{"a":null,"b":"b","c":{}}}`, string(data)) }) }) t.Run("remove undefined variables", func(t *testing.T) { @@ -8709,7 +8708,6 @@ func TestSource_Load(t *testing.T) { var input []byte input = httpclient.SetInputBodyWithPath(input, variables, "variables") input = httpclient.SetInputURL(input, []byte(serverUrl)) - buf := bytes.NewBuffer(nil) undefinedVariables := []string{"a", "c"} ctx := context.Background() @@ -8717,8 +8715,9 @@ func TestSource_Load(t *testing.T) { input, err = httpclient.SetUndefinedVariables(input, undefinedVariables) assert.NoError(t, err) - require.NoError(t, src.Load(ctx, input, buf)) - assert.Equal(t, `{"variables":{"b":null}}`, buf.String()) + data, err := src.Load(ctx, input) + require.NoError(t, err) + assert.Equal(t, `{"variables":{"b":null}}`, string(data)) }) }) } @@ -8800,10 +8799,10 @@ func TestLoadFiles(t *testing.T) { input = httpclient.SetInputBodyWithPath(input, variables, "variables") input = httpclient.SetInputBodyWithPath(input, query, "query") input = httpclient.SetInputURL(input, []byte(serverUrl)) - buf := bytes.NewBuffer(nil) ctx := context.Background() - require.NoError(t, src.LoadWithFiles(ctx, input, []*httpclient.FileUpload{httpclient.NewFileUpload(f.Name(), fileName, "variables.file")}, buf)) + _, err = src.LoadWithFiles(ctx, input, []*httpclient.FileUpload{httpclient.NewFileUpload(f.Name(), fileName, "variables.file")}) + require.NoError(t, err) }) t.Run("multiple files", func(t *testing.T) { @@ -8844,7 +8843,6 @@ func TestLoadFiles(t *testing.T) { input = httpclient.SetInputBodyWithPath(input, variables, "variables") input = httpclient.SetInputBodyWithPath(input, query, "query") input = httpclient.SetInputURL(input, []byte(serverUrl)) - buf := bytes.NewBuffer(nil) dir := t.TempDir() f1, err := os.CreateTemp(dir, file1Name) @@ -8858,11 +8856,11 @@ func TestLoadFiles(t *testing.T) { assert.NoError(t, err) ctx := context.Background() - require.NoError(t, src.LoadWithFiles(ctx, input, + _, err = src.LoadWithFiles(ctx, input, []*httpclient.FileUpload{ httpclient.NewFileUpload(f1.Name(), file1Name, "variables.files.0"), - httpclient.NewFileUpload(f2.Name(), file2Name, "variables.files.1")}, - buf)) + httpclient.NewFileUpload(f2.Name(), file2Name, "variables.files.1")}) + require.NoError(t, err) }) } diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index 78cdce9f7..58729e33c 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -7,7 +7,6 @@ package grpcdatasource import ( - "bytes" "context" 
"fmt" "sync" @@ -73,25 +72,24 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D } // Load implements resolve.DataSource interface. -// It processes the input JSON data to make gRPC calls and writes -// the response to the output buffer. +// It processes the input JSON data to make gRPC calls and returns +// the response data. // // The input is expected to contain the necessary information to make // a gRPC call, including service name, method name, and request data. -func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) { +func (d *DataSource) Load(ctx context.Context, input []byte) (data []byte, err error) { // get variables from input variables := gjson.Parse(string(input)).Get("body.variables") builder := newJSONBuilder(d.mapping, variables) if d.disabled { - out.Write(builder.writeErrorBytes(fmt.Errorf("gRPC datasource needs to be enabled to be used"))) - return nil + return builder.writeErrorBytes(fmt.Errorf("gRPC datasource needs to be enabled to be used")), nil } // get invocations from plan invocations, err := d.rc.Compile(d.plan, variables) if err != nil { - return err + return nil, err } responses := make([]*astjson.Value, len(invocations)) @@ -130,23 +128,19 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) } if err := errGrp.Wait(); err != nil { - out.Write(builder.writeErrorBytes(err)) - return nil + return builder.writeErrorBytes(err), nil } root := astjson.ObjectValue(builder.jsonArena) for _, response := range responses { root, err = builder.mergeValues(root, response) if err != nil { - out.Write(builder.writeErrorBytes(err)) - return err + return builder.writeErrorBytes(err), err } } - data := builder.toDataObject(root) - out.Write(data.MarshalTo(nil)) - - return nil + dataObj := builder.toDataObject(root) + return dataObj.MarshalTo(nil), nil } // LoadWithFiles implements resolve.DataSource interface. @@ -156,6 +150,6 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) // might not be applicable for most gRPC use cases. // // Currently unimplemented. -func (d *DataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { +func (d *DataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("unimplemented") } diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index f7340cec8..2a18e2f17 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -1,7 +1,6 @@ package grpcdatasource import ( - "bytes" "context" "encoding/json" "fmt" @@ -147,12 +146,10 @@ func Test_DataSource_Load(t *testing.T) { require.NoError(t, err) - output := new(bytes.Buffer) - - err = ds.Load(context.Background(), []byte(`{"query":"`+query+`","variables":`+variables+`}`), output) + output, err := ds.Load(context.Background(), []byte(`{"query":"`+query+`","variables":`+variables+`}`)) require.NoError(t, err) - fmt.Println(output.String()) + fmt.Println(string(output)) } // Test_DataSource_Load_WithMockService tests the datasource.Load method with an actual gRPC server @@ -220,12 +217,11 @@ func Test_DataSource_Load_WithMockService(t *testing.T) { require.NoError(t, err) // 3. 
Execute the query through our datasource - output := new(bytes.Buffer) - err = ds.Load(context.Background(), []byte(`{"query":"`+query+`","body":`+variables+`}`), output) + output, err := ds.Load(context.Background(), []byte(`{"query":"`+query+`","body":`+variables+`}`)) require.NoError(t, err) // Print the response for debugging - // fmt.Println(output.String()) + // fmt.Println(string(output)) type response struct { Data struct { @@ -238,7 +234,7 @@ func Test_DataSource_Load_WithMockService(t *testing.T) { var resp response - bytes := output.Bytes() + bytes := output fmt.Println(string(bytes)) err = json.Unmarshal(bytes, &resp) @@ -310,12 +306,10 @@ func Test_DataSource_Load_WithMockService_WithResponseMapping(t *testing.T) { require.NoError(t, err) // 3. Execute the query through our datasource - output := new(bytes.Buffer) - // Format the input with query and variables inputJSON := fmt.Sprintf(`{"query":%q,"body":%s}`, query, variables) - err = ds.Load(context.Background(), []byte(inputJSON), output) + output, err := ds.Load(context.Background(), []byte(inputJSON)) require.NoError(t, err) // Set up the correct response structure based on your GraphQL schema @@ -332,7 +326,7 @@ func Test_DataSource_Load_WithMockService_WithResponseMapping(t *testing.T) { } var resp response - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") // Check if there are any errors in the response @@ -407,11 +401,10 @@ func Test_DataSource_Load_WithGrpcError(t *testing.T) { require.NoError(t, err) // 4. Execute the query - output := new(bytes.Buffer) - err = ds.Load(context.Background(), []byte(`{"query":"`+query+`","body":`+variables+`}`), output) + output, err := ds.Load(context.Background(), []byte(`{"query":"`+query+`","body":`+variables+`}`)) require.NoError(t, err, "Load should not return an error even when the gRPC call fails") - responseJson := output.String() + responseJson := string(output) // 5. 
Verify the response format according to GraphQL specification // The response should have an "errors" array with the error message @@ -425,7 +418,7 @@ func Test_DataSource_Load_WithGrpcError(t *testing.T) { } `json:"errors"` } - err = json.Unmarshal(output.Bytes(), &response) + err = json.Unmarshal(output, &response) require.NoError(t, err, "Failed to parse response JSON") // Verify there's at least one error @@ -733,9 +726,8 @@ func Test_DataSource_Load_WithAnimalInterface(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), []byte(input)) require.NoError(t, err) // Parse the response @@ -746,7 +738,7 @@ func Test_DataSource_Load_WithAnimalInterface(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") require.NotEmpty(t, resp.Data, "Response should contain data") @@ -1004,9 +996,8 @@ func Test_Datasource_Load_WithUnionTypes(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), []byte(input)) require.NoError(t, err) // Parse the response @@ -1017,7 +1008,7 @@ func Test_Datasource_Load_WithUnionTypes(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") require.NotEmpty(t, resp.Data, "Response should contain data") @@ -1141,9 +1132,8 @@ func Test_DataSource_Load_WithCategoryQueries(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), []byte(input)) require.NoError(t, err) // Parse the response @@ -1154,7 +1144,7 @@ func Test_DataSource_Load_WithCategoryQueries(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") require.NotEmpty(t, resp.Data, "Response should contain data") @@ -1222,9 +1212,8 @@ func Test_DataSource_Load_WithTotalCalculation(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, query, variables) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), []byte(input)) require.NoError(t, err) // Parse the response @@ -1246,7 +1235,7 @@ func Test_DataSource_Load_WithTotalCalculation(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") @@ -1313,9 
+1302,8 @@ func Test_DataSource_Load_WithTypename(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":{}}`, query) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), []byte(input)) require.NoError(t, err) // Parse the response @@ -1332,7 +1320,7 @@ func Test_DataSource_Load_WithTypename(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") @@ -1783,9 +1771,8 @@ func Test_DataSource_Load_WithAliases(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), []byte(input)) require.NoError(t, err) // Parse the response @@ -1796,7 +1783,7 @@ func Test_DataSource_Load_WithAliases(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") require.NotEmpty(t, resp.Data, "Response should contain data") @@ -2162,9 +2149,8 @@ func Test_DataSource_Load_WithNullableFieldsType(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), []byte(input)) require.NoError(t, err) // Parse the response @@ -2175,7 +2161,7 @@ func Test_DataSource_Load_WithNullableFieldsType(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") require.NotEmpty(t, resp.Data, "Response should contain data") @@ -3464,9 +3450,8 @@ func Test_DataSource_Load_WithNestedLists(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), []byte(input)) require.NoError(t, err) // Parse the response @@ -3477,7 +3462,7 @@ func Test_DataSource_Load_WithNestedLists(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") require.NotEmpty(t, resp.Data, "Response should contain data") @@ -3617,15 +3602,14 @@ func Test_DataSource_Load_WithEntity_Calls(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), []byte(input)) require.NoError(t, err) // Parse the response var resp graphqlResponse - err = 
json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") tc.validate(t, resp.Data) diff --git a/v2/pkg/engine/datasource/httpclient/httpclient_test.go b/v2/pkg/engine/datasource/httpclient/httpclient_test.go index 223e5d833..cbef2d1f7 100644 --- a/v2/pkg/engine/datasource/httpclient/httpclient_test.go +++ b/v2/pkg/engine/datasource/httpclient/httpclient_test.go @@ -1,7 +1,6 @@ package httpclient import ( - "bytes" "compress/gzip" "context" "io" @@ -80,10 +79,9 @@ func TestHttpClientDo(t *testing.T) { runTest := func(ctx context.Context, input []byte, expectedOutput string) func(t *testing.T) { return func(t *testing.T) { - out := &bytes.Buffer{} - err := Do(http.DefaultClient, ctx, input, out) + output, err := Do(http.DefaultClient, ctx, input) assert.NoError(t, err) - assert.Equal(t, expectedOutput, out.String()) + assert.Equal(t, expectedOutput, string(output)) } } @@ -211,9 +209,8 @@ func TestHttpClientDo(t *testing.T) { input = SetInputURL(input, []byte(server.URL)) input, err := sjson.SetBytes(input, TRACE, true) assert.NoError(t, err) - out := &bytes.Buffer{} - err = Do(http.DefaultClient, context.Background(), input, out) + output, err := Do(http.DefaultClient, context.Background(), input) assert.NoError(t, err) - assert.Contains(t, out.String(), `"Authorization":["****"]`) + assert.Contains(t, string(output), `"Authorization":["****"]`) }) } diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index 4e8ca9b31..0eb4360fa 100644 --- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -254,21 +254,27 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head return err } -func Do(client *http.Client, ctx context.Context, requestInput []byte, out *bytes.Buffer) (err error) { +func Do(client *http.Client, ctx context.Context, requestInput []byte) (data []byte, err error) { url, method, body, headers, queryParams, enableTrace := requestInputParams(requestInput) h := pool.Hash64.Get() _, _ = h.Write(body) bodyHash := h.Sum64() pool.Hash64.Put(h) ctx = context.WithValue(ctx, bodyHashContextKey{}, bodyHash) - return makeHTTPRequest(client, ctx, url, method, headers, queryParams, bytes.NewReader(body), enableTrace, out, ContentTypeJSON) + + var buf bytes.Buffer + err = makeHTTPRequest(client, ctx, url, method, headers, queryParams, bytes.NewReader(body), enableTrace, &buf, ContentTypeJSON) + if err != nil { + return nil, err + } + return buf.Bytes(), nil } func DoMultipartForm( - client *http.Client, ctx context.Context, requestInput []byte, files []*FileUpload, out *bytes.Buffer, -) (err error) { + client *http.Client, ctx context.Context, requestInput []byte, files []*FileUpload, +) (data []byte, err error) { if len(files) == 0 { - return errors.New("no files provided") + return nil, errors.New("no files provided") } url, method, body, headers, queryParams, enableTrace := requestInputParams(requestInput) @@ -300,7 +306,7 @@ func DoMultipartForm( temporaryFile, err := os.Open(file.Path()) tempFiles = append(tempFiles, temporaryFile) if err != nil { - return err + return nil, err } formValues[key] = bufio.NewReader(temporaryFile) } @@ -309,7 +315,7 @@ func DoMultipartForm( multipartBody, contentType, err := multipartBytes(formValues, files) if err != nil { - return err + return nil, err } defer func() { @@ -327,7 +333,12 @@ func DoMultipartForm( bodyHash := 
h.Sum64() ctx = context.WithValue(ctx, bodyHashContextKey{}, bodyHash) - return makeHTTPRequest(client, ctx, url, method, headers, queryParams, multipartBody, enableTrace, out, contentType) + var buf bytes.Buffer + err = makeHTTPRequest(client, ctx, url, method, headers, queryParams, multipartBody, enableTrace, &buf, contentType) + if err != nil { + return nil, err + } + return buf.Bytes(), nil } func multipartBytes(values map[string]io.Reader, files []*FileUpload) (*io.PipeReader, string, error) { diff --git a/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection.golden b/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection.golden index 0064f2d6b..43d477605 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection.golden +++ b/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection.golden @@ -353,4 +353,4 @@ } ], "__typename": "__Schema" -} +} \ No newline at end of file diff --git a/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection_with_custom_root_operation_types.golden b/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection_with_custom_root_operation_types.golden index 0e8d299c2..240e7f0c3 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection_with_custom_root_operation_types.golden +++ b/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection_with_custom_root_operation_types.golden @@ -501,4 +501,4 @@ } ], "__typename": "__Schema" -} +} \ No newline at end of file diff --git a/v2/pkg/engine/datasource/introspection_datasource/fixtures/type_introspection.golden b/v2/pkg/engine/datasource/introspection_datasource/fixtures/type_introspection.golden index 41827c0f6..16017d131 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/fixtures/type_introspection.golden +++ b/v2/pkg/engine/datasource/introspection_datasource/fixtures/type_introspection.golden @@ -56,4 +56,4 @@ "interfaces": [], "possibleTypes": [], "__typename": "__Type" -} +} \ No newline at end of file diff --git a/v2/pkg/engine/datasource/introspection_datasource/source.go b/v2/pkg/engine/datasource/introspection_datasource/source.go index b9a06489d..a55549ace 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/source.go +++ b/v2/pkg/engine/datasource/introspection_datasource/source.go @@ -1,7 +1,6 @@ package introspection_datasource import ( - "bytes" "context" "encoding/json" "errors" @@ -19,21 +18,21 @@ type Source struct { introspectionData *introspection.Data } -func (s *Source) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) { +func (s *Source) Load(ctx context.Context, input []byte) (data []byte, err error) { var req introspectionInput if err := json.Unmarshal(input, &req); err != nil { - return err + return nil, err } if req.RequestType == TypeRequestType { - return s.singleType(out, req.TypeName) + return s.singleTypeBytes(req.TypeName) } - return json.NewEncoder(out).Encode(s.introspectionData.Schema) + return json.Marshal(s.introspectionData.Schema) } -func (s *Source) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { - return errors.New("introspection data source does not support file uploads") +func (s *Source) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { + return nil, errors.New("introspection data source does not 
support file uploads") } func (s *Source) typeInfo(typeName *string) *introspection.FullType { @@ -57,3 +56,12 @@ func (s *Source) singleType(w io.Writer, typeName *string) error { return json.NewEncoder(w).Encode(typeInfo) } + +func (s *Source) singleTypeBytes(typeName *string) ([]byte, error) { + typeInfo := s.typeInfo(typeName) + if typeInfo == nil { + return null, nil + } + + return json.Marshal(typeInfo) +} diff --git a/v2/pkg/engine/datasource/introspection_datasource/source_test.go b/v2/pkg/engine/datasource/introspection_datasource/source_test.go index bb4a91143..7c331b7d1 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/source_test.go +++ b/v2/pkg/engine/datasource/introspection_datasource/source_test.go @@ -27,13 +27,18 @@ func TestSource_Load(t *testing.T) { gen.Generate(&def, &report, &data) require.False(t, report.HasErrors()) - buf := &bytes.Buffer{} source := &Source{introspectionData: &data} - require.NoError(t, source.Load(context.Background(), []byte(input), buf)) + responseData, err := source.Load(context.Background(), []byte(input)) + require.NoError(t, err) actualResponse := &bytes.Buffer{} - require.NoError(t, json.Indent(actualResponse, buf.Bytes(), "", " ")) - goldie.Assert(t, fixtureName, actualResponse.Bytes()) + require.NoError(t, json.Indent(actualResponse, responseData, "", " ")) + // Trim the trailing newline that json.Indent adds + responseBytes := actualResponse.Bytes() + if len(responseBytes) > 0 && responseBytes[len(responseBytes)-1] == '\n' { + responseBytes = responseBytes[:len(responseBytes)-1] + } + goldie.Assert(t, fixtureName, responseBytes) } } diff --git a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go index cc562b803..7f1a6226b 100644 --- a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go +++ b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go @@ -1,10 +1,8 @@ package pubsub_datasource import ( - "bytes" "context" "encoding/json" - "io" "github.com/buger/jsonparser" "github.com/cespare/xxhash/v2" @@ -68,21 +66,19 @@ type KafkaPublishDataSource struct { pubSub KafkaPubSub } -func (s *KafkaPublishDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { +func (s *KafkaPublishDataSource) Load(ctx context.Context, input []byte) (data []byte, err error) { var publishConfiguration KafkaPublishEventConfiguration - err := json.Unmarshal(input, &publishConfiguration) + err = json.Unmarshal(input, &publishConfiguration) if err != nil { - return err + return nil, err } if err := s.pubSub.Publish(ctx, publishConfiguration); err != nil { - _, err = io.WriteString(out, `{"success": false}`) - return err + return []byte(`{"success": false}`), err } - _, err = io.WriteString(out, `{"success": true}`) - return err + return []byte(`{"success": true}`), nil } -func (s *KafkaPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { +func (s *KafkaPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("not implemented") } diff --git a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go index 31cb6d415..e5d3bec0f 100644 --- a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go +++ b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go @@ -77,23 +77,21 @@ type NatsPublishDataSource struct { pubSub NatsPubSub } -func (s 
diff --git a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go
index 31cb6d415..e5d3bec0f 100644
--- a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go
+++ b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go
@@ -77,23 +77,21 @@ type NatsPublishDataSource struct {
 	pubSub NatsPubSub
 }
 
-func (s *NatsPublishDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error {
+func (s *NatsPublishDataSource) Load(ctx context.Context, input []byte) (data []byte, err error) {
 	var publishConfiguration NatsPublishAndRequestEventConfiguration
-	err := json.Unmarshal(input, &publishConfiguration)
+	err = json.Unmarshal(input, &publishConfiguration)
 	if err != nil {
-		return err
+		return nil, err
 	}
 
 	if err := s.pubSub.Publish(ctx, publishConfiguration); err != nil {
-		_, err = io.WriteString(out, `{"success": false}`)
-		return err
+		return []byte(`{"success": false}`), err
 	}
 
-	_, err = io.WriteString(out, `{"success": true}`)
-	return err
+	return []byte(`{"success": true}`), nil
 }
 
-func (s *NatsPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) error {
+func (s *NatsPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) {
 	panic("not implemented")
 }
@@ -101,16 +99,22 @@ type NatsRequestDataSource struct {
 	pubSub NatsPubSub
 }
 
-func (s *NatsRequestDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error {
+func (s *NatsRequestDataSource) Load(ctx context.Context, input []byte) (data []byte, err error) {
 	var subscriptionConfiguration NatsPublishAndRequestEventConfiguration
-	err := json.Unmarshal(input, &subscriptionConfiguration)
+	err = json.Unmarshal(input, &subscriptionConfiguration)
 	if err != nil {
-		return err
+		return nil, err
+	}
+
+	var buf bytes.Buffer
+	err = s.pubSub.Request(ctx, subscriptionConfiguration, &buf)
+	if err != nil {
+		return nil, err
 	}
-	return s.pubSub.Request(ctx, subscriptionConfiguration, out)
+	return buf.Bytes(), nil
 }
 
-func (s *NatsRequestDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) error {
+func (s *NatsRequestDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) {
 	panic("not implemented")
 }
diff --git a/v2/pkg/engine/datasource/staticdatasource/static_datasource.go b/v2/pkg/engine/datasource/staticdatasource/static_datasource.go
index e9074635c..626a1d9f9 100644
--- a/v2/pkg/engine/datasource/staticdatasource/static_datasource.go
+++ b/v2/pkg/engine/datasource/staticdatasource/static_datasource.go
@@ -1,7 +1,6 @@
 package staticdatasource
 
 import (
-	"bytes"
 	"context"
 
 	"github.com/jensneuse/abstractlogger"
@@ -71,11 +70,10 @@ func (p *Planner[T]) ConfigureSubscription() plan.SubscriptionConfiguration {
 
 type Source struct{}
 
-func (Source) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) {
-	_, err = out.Write(input)
-	return
+func (Source) Load(ctx context.Context, input []byte) (data []byte, err error) {
+	return input, nil
 }
 
-func (Source) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) {
+func (Source) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) {
 	panic("not implemented")
 }
diff --git a/v2/pkg/engine/plan/planner_test.go b/v2/pkg/engine/plan/planner_test.go
index 270140381..658ff3fc7 100644
--- a/v2/pkg/engine/plan/planner_test.go
+++ b/v2/pkg/engine/plan/planner_test.go
@@ -1,7 +1,6 @@
 package plan
 
 import (
-	"bytes"
 	"context"
 	"encoding/json"
 	"fmt"
@@ -1075,10 +1074,10 @@ type FakeDataSource struct {
 	source *StatefulSource
 }
 
-func (f *FakeDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error)
{ - return +func (f *FakeDataSource) Load(ctx context.Context, input []byte) (data []byte, err error) { + return nil, nil } -func (f *FakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { - return +func (f *FakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { + return nil, nil } diff --git a/v2/pkg/engine/resolve/authorization_test.go b/v2/pkg/engine/resolve/authorization_test.go index 263724a77..ea83c7725 100644 --- a/v2/pkg/engine/resolve/authorization_test.go +++ b/v2/pkg/engine/resolve/authorization_test.go @@ -1,7 +1,6 @@ package resolve import ( - "bytes" "context" "encoding/json" "errors" @@ -510,38 +509,32 @@ func TestAuthorization(t *testing.T) { func generateTestFederationGraphQLResponse(t *testing.T, ctrl *gomock.Controller) *GraphQLResponse { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me":{"id":"1234","username":"Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me":{"id":"1234","username":"Me","__typename": "User"}}}`), nil }).AnyTimes() reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}}`), nil }).AnyTimes() productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}}`), nil }).AnyTimes() return &GraphQLResponse{ @@ -821,38 +814,32 @@ func generateTestFederationGraphQLResponse(t *testing.T, ctrl *gomock.Controller func generateTestFederationGraphQLResponseWithoutAuthorizationRules(t *testing.T, ctrl *gomock.Controller) *GraphQLResponse { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me":{"id":"1234","username":"Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me":{"id":"1234","username":"Me","__typename": "User"}}}`), nil }).AnyTimes() reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}}`), nil }).AnyTimes() productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). 
+		DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) {
 			actual := string(input)
 			expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}}}`
 			assert.Equal(t, expected, actual)
-			pair := NewBufPair()
-			pair.Data.WriteString(`{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}`)
-			return writeGraphqlResponse(pair, w, false)
+			return []byte(`{"data":{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}}`), nil
 		}).AnyTimes()
 
 	return &GraphQLResponse{
diff --git a/v2/pkg/engine/resolve/datasource.go b/v2/pkg/engine/resolve/datasource.go
index c679d7693..8063541f6 100644
--- a/v2/pkg/engine/resolve/datasource.go
+++ b/v2/pkg/engine/resolve/datasource.go
@@ -1,7 +1,6 @@
 package resolve
 
 import (
-	"bytes"
 	"context"
 
 	"github.com/cespare/xxhash/v2"
@@ -10,8 +9,8 @@ import (
 )
 
 type DataSource interface {
-	Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error)
-	LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error)
+	Load(ctx context.Context, input []byte) (data []byte, err error)
+	LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error)
 }
 
 type SubscriptionDataSource interface {
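
[Editor's note] For reviewers, a minimal sketch of what an implementer of the refactored DataSource interface looks like after this change. The type and its behavior are hypothetical stand-ins, not part of the patch:

    package example

    import (
    	"context"
    	"errors"

    	"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
    )

    // EchoSource illustrates the new contract: the response payload is
    // returned as a byte slice instead of being written into a
    // caller-provided *bytes.Buffer.
    type EchoSource struct{}

    func (EchoSource) Load(ctx context.Context, input []byte) ([]byte, error) {
    	// A real data source would perform I/O here; echoing the input
    	// mirrors the staticdatasource change in this patch.
    	return input, nil
    }

    func (EchoSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) ([]byte, error) {
    	// This sketch does not support uploads.
    	return nil, errors.New("file uploads not supported")
    }
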
diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go
index ad4e78e47..1bab9779b 100644
--- a/v2/pkg/engine/resolve/loader.go
+++ b/v2/pkg/engine/resolve/loader.go
@@ -57,11 +57,7 @@ type ResponseInfo struct {
 	// ResponseHeaders contains a clone of the headers of the response from the subgraph.
 	ResponseHeaders http.Header
 	// This should be private as we do not want user's to access the raw responseBody directly
-	responseBody *bytes.Buffer
-}
-
-func (ri *ResponseInfo) GetResponseBody() string {
-	return ri.responseBody.String()
+	responseBody []byte
 }
 
 func newResponseInfo(res *result, subgraphError error) *ResponseInfo {
@@ -119,7 +115,6 @@ func (b *batchStats) getUniqueIndexes() int {
 
 type result struct {
 	postProcessing PostProcessingConfiguration
-	out *bytes.Buffer
 	batchStats batchStats
 	fetchSkipped bool
 	nestedMergeItems []*result
@@ -139,6 +134,7 @@ type result struct {
 
 	loaderHookContext context.Context
 	httpResponseContext *httpclient.ResponseContext
+	out []byte
 }
 
 func (r *result) init(postProcessing PostProcessingConfiguration, info *FetchInfo) {
@@ -283,9 +279,7 @@ func (l *Loader) resolveSingle(item *FetchItem) error {
 	switch f := item.Fetch.(type) {
 	case *SingleFetch:
-		res := &result{
-			out: &bytes.Buffer{},
-		}
+		res := &result{}
 		err := l.loadSingleFetch(l.ctx.ctx, f, item, items, res)
 		if err != nil {
 			return err
 		}
@@ -297,9 +291,7 @@
 		return err
 	case *BatchEntityFetch:
-		res := &result{
-			out: &bytes.Buffer{},
-		}
+		res := &result{}
 		err := l.loadBatchEntityFetch(l.ctx.ctx, item, f, items, res)
 		if err != nil {
 			return errors.WithStack(err)
 		}
@@ -310,9 +302,7 @@
 		}
 		return err
 	case *EntityFetch:
-		res := &result{
-			out: &bytes.Buffer{},
-		}
+		res := &result{}
 		err := l.loadEntityFetch(l.ctx.ctx, item, f, items, res)
 		if err != nil {
 			return errors.WithStack(err)
 		}
@@ -330,9 +320,7 @@
 	g, ctx := errgroup.WithContext(l.ctx.ctx)
 	for i := range items {
 		i := i
-		results[i] = &result{
-			out: &bytes.Buffer{},
-		}
+		results[i] = &result{}
 		if l.ctx.TracingOptions.Enable {
 			f.Traces[i] = new(SingleFetch)
 			*f.Traces[i] = *f.Fetch
@@ -453,7 +441,6 @@ func itemsData(a arena.Arena, items []*astjson.Value) *astjson.Value {
 func (l *Loader) loadFetch(ctx context.Context, fetch Fetch, fetchItem *FetchItem, items []*astjson.Value, res *result) error {
 	switch f := fetch.(type) {
 	case *SingleFetch:
-		res.out = &bytes.Buffer{}
 		return l.loadSingleFetch(ctx, f, fetchItem, items, res)
 	case *ParallelListItemFetch:
 		results := make([]*result, len(items))
@@ -463,9 +450,7 @@ func (l *Loader) loadFetch(ctx context.Context, fetch Fetch, fetchItem *FetchIte
 	g, ctx := errgroup.WithContext(l.ctx.ctx)
 	for i := range items {
 		i := i
-		results[i] = &result{
-			out: &bytes.Buffer{},
-		}
+		results[i] = &result{}
 		if l.ctx.TracingOptions.Enable {
 			f.Traces[i] = new(SingleFetch)
 			*f.Traces[i] = *f.Fetch
@@ -485,10 +470,8 @@ func (l *Loader) loadFetch(ctx context.Context, fetch Fetch, fetchItem *FetchIte
 	res.nestedMergeItems = results
 	return nil
 	case *EntityFetch:
-		res.out = &bytes.Buffer{}
 		return l.loadEntityFetch(ctx, fetchItem, f, items, res)
 	case *BatchEntityFetch:
-		res.out = &bytes.Buffer{}
 		return l.loadBatchEntityFetch(ctx, fetchItem, f, items, res)
 	}
 	return nil
@@ -551,11 +534,12 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson
 	if res.fetchSkipped {
 		return nil
 	}
-	if res.out.Len() == 0 {
+	if len(res.out) == 0 {
 		return l.renderErrorsFailedToFetch(fetchItem, res, emptyGraphQLResponse)
 	}
-
-	response, err := astjson.ParseBytesWithArena(l.jsonArena, res.out.Bytes())
+	slice := arena.AllocateSlice[byte](l.jsonArena, len(res.out), len(res.out))
+	copy(slice, res.out)
+	response, err := astjson.ParseBytesWithArena(l.jsonArena,
slice) if err != nil { // Fall back to status code if parsing fails and non-2XX if (res.statusCode > 0 && res.statusCode < 200) || res.statusCode >= 300 { @@ -706,7 +690,8 @@ var ( errorsInvalidInputFooter = []byte(`]}]}`) ) -func (l *Loader) renderErrorsInvalidInput(fetchItem *FetchItem, out *bytes.Buffer) error { +func (l *Loader) renderErrorsInvalidInput(fetchItem *FetchItem) []byte { + out := &bytes.Buffer{} elements := fetchItem.ResponsePathElements if len(elements) > 0 && elements[len(elements)-1] == "@" { elements = elements[:len(elements)-1] @@ -724,7 +709,7 @@ func (l *Loader) renderErrorsInvalidInput(fetchItem *FetchItem, out *bytes.Buffe _, _ = out.Write(quote) } _, _ = out.Write(errorsInvalidInputFooter) - return nil + return out.Bytes() } func (l *Loader) appendSubgraphError(res *result, fetchItem *FetchItem, value *astjson.Value, values []*astjson.Value) error { @@ -1312,7 +1297,8 @@ func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, fetchI err := fetch.InputTemplate.Render(l.ctx, inputData, buf) if err != nil { - return l.renderErrorsInvalidInput(fetchItem, res.out) + res.out = l.renderErrorsInvalidInput(fetchItem) + return nil } fetchInput := buf.Bytes() allowed, err := l.validatePreFetch(fetchInput, fetch.Info, res) @@ -1648,9 +1634,14 @@ func (l *Loader) setTracingInput(fetchItem *FetchItem, input []byte, trace *Data func (l *Loader) loadByContext(ctx context.Context, source DataSource, input []byte, res *result) error { if l.ctx.Files != nil { - return source.LoadWithFiles(ctx, input, l.ctx.Files, res.out) + res.out, res.err = source.LoadWithFiles(ctx, input, l.ctx.Files) + } else { + res.out, res.err = source.Load(ctx, input) } - return source.Load(ctx, input, res.out) + if res.err != nil { + return errors.WithStack(res.err) + } + return nil } func (l *Loader) executeSourceLoad(ctx context.Context, fetchItem *FetchItem, source DataSource, input []byte, res *result, trace *DataSourceLoadTrace) { @@ -1813,8 +1804,8 @@ func (l *Loader) executeSourceLoad(ctx context.Context, fetchItem *FetchItem, so trace.SingleFlightUsed = stats.SingleFlightUsed trace.SingleFlightSharedResponse = stats.SingleFlightSharedResponse } - if !l.ctx.TracingOptions.ExcludeOutput && res.out.Len() > 0 { - trace.Output, _ = l.compactJSON(res.out.Bytes()) + if !l.ctx.TracingOptions.ExcludeOutput && len(res.out) > 0 { + trace.Output, _ = l.compactJSON(res.out) if l.ctx.TracingOptions.EnablePredictableDebugTimings { trace.Output, _ = sjson.DeleteBytes(trace.Output, "extensions.trace.response.headers.Date") } diff --git a/v2/pkg/engine/resolve/loader_hooks_test.go b/v2/pkg/engine/resolve/loader_hooks_test.go index 4b7b3ea6c..d82857598 100644 --- a/v2/pkg/engine/resolve/loader_hooks_test.go +++ b/v2/pkg/engine/resolve/loader_hooks_test.go @@ -3,7 +3,6 @@ package resolve import ( "bytes" "context" - "io" "sync" "sync/atomic" "testing" @@ -50,11 +49,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("simple fetch with simple subgraph error", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). 
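Two details in the loader.go hunks above are easy to miss. First, mergeResult clones res.out into the JSON arena before parsing: astjson.ParseBytesWithArena can keep references into the slice it is given, and the response bytes are now owned by the data source rather than by a loader-managed buffer. The pattern, pulled out into a hypothetical helper (assuming go-arena's AllocateSlice hands back a slice of the requested length, as the diff uses it):

package example

import arena "github.com/wundergraph/go-arena"

// cloneIntoArena copies src into arena-owned memory so that values
// parsed from it can safely outlive the data source's returned slice.
func cloneIntoArena(a arena.Arena, src []byte) []byte {
	dst := arena.AllocateSlice[byte](a, len(src), len(src))
	copy(dst, src)
	return dst
}

Second, a failed InputTemplate.Render no longer aborts the fetch with a Go error: renderErrorsInvalidInput now returns a ready-made {"errors":[...]} document, which loadSingleFetch stores in res.out so it flows through the normal merge path like any other subgraph response.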
- DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) resolveCtx := Context{ ctx: context.Background(), @@ -124,11 +121,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) resolveCtx := &Context{ ctx: context.Background(), @@ -192,11 +187,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("parallel fetch with simple subgraph error", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) resolveCtx := &Context{ ctx: context.Background(), @@ -257,11 +250,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("parallel list item fetch with simple subgraph error", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) resolveCtx := Context{ ctx: context.Background(), @@ -322,12 +313,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("fetch with subgraph error and custom extension code. No extension fields are propagated by default", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). 
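The loader_hooks_test rewrites above and below all follow one pattern: with BufPair and writeGraphqlResponse gone, each mock returns a complete GraphQL document itself, envelope included (`{"errors":[...]}` here, `{"data":...}` in the loader tests further down). A hypothetical helper that would collapse the repetition, assuming the test package's existing context and gomock imports; expectLoad is not part of the patch:

// expectLoad registers a single-call expectation in the new style.
// The fixture must already be a full GraphQL response document.
func expectLoad(ds *MockDataSource, fixture string) {
	ds.EXPECT().
		Load(gomock.Any(), gomock.Any()).
		DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) {
			return []byte(fixture), nil
		})
}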
- DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{\"code\":\"GRAPHQL_VALIDATION_FAILED\"}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, []byte("{\"code\":\"BAD_USER_INPUT\"}")) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT"}}]}`), nil }) resolveCtx := Context{ ctx: context.Background(), @@ -388,12 +376,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Propagate only extension code field from subgraph errors", testFnSubgraphErrorsWithExtensionFieldCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{\"code\":\"GRAPHQL_VALIDATION_FAILED\",\"foo\":\"bar\"}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, []byte("{\"code\":\"BAD_USER_INPUT\"}")) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED","foo":"bar"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT"}}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -426,12 +411,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Propagate all extension fields from subgraph errors when allow all option is enabled", testFnSubgraphErrorsWithAllowAllExtensionFields(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{\"code\":\"GRAPHQL_VALIDATION_FAILED\",\"foo\":\"bar\"}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, []byte("{\"code\":\"BAD_USER_INPUT\"}")) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED","foo":"bar"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT"}}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -464,12 +446,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Include datasource name as serviceName extension field", testFnSubgraphErrorsWithExtensionFieldServiceName(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). 
- DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{\"code\":\"GRAPHQL_VALIDATION_FAILED\"}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, []byte("{\"code\":\"BAD_USER_INPUT\"}")) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT"}}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -502,12 +481,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Include datasource name as serviceName when extensions is null", testFnSubgraphErrorsWithExtensionFieldServiceName(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("null")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, []byte("null")) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":null},{"message":"errorMessage2","extensions":null}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -540,12 +516,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Include datasource name as serviceName when extensions is an empty object", testFnSubgraphErrorsWithExtensionFieldServiceName(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, []byte("null")) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{}},{"message":"errorMessage2","extensions":null}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -578,12 +551,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Fallback to default extension code value when no code field was set", testFnSubgraphErrorsWithExtensionDefaultCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{\"code\":\"GRAPHQL_VALIDATION_FAILED\"}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"errorMessage2"}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -616,12 +586,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Fallback to default extension code value when extensions is null", testFnSubgraphErrorsWithExtensionDefaultCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("null")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":null},{"message":"errorMessage2"}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -654,12 +621,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Fallback to default extension code value when extensions is an empty object", testFnSubgraphErrorsWithExtensionDefaultCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{}},{"message":"errorMessage2"}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ diff --git a/v2/pkg/engine/resolve/loader_test.go b/v2/pkg/engine/resolve/loader_test.go index 01c5ef5dc..0fe38ddc7 100644 --- a/v2/pkg/engine/resolve/loader_test.go +++ b/v2/pkg/engine/resolve/loader_test.go @@ -19,19 +19,19 @@ func TestLoader_LoadGraphQLResponseData(t *testing.T) { ctrl := gomock.NewController(t) productsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`, - `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}`) + `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`) reviewsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`, - `{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}`) + `{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}}`) stockService := mockedDS(t, ctrl, `{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`, - `{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}`) + `{"data":{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}}`) usersService := mockedDS(t, ctrl, `{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"2"}]}}}`, - `{"_entities":[{"name":"user-1"},{"name":"user-2"}]}`) + `{"data":{"_entities":[{"name":"user-1"},{"name":"user-2"}]}}`) response := &GraphQLResponse{ Fetches: Sequence( Single(&SingleFetch{ @@ -480,19 +480,19 @@ func TestLoader_LoadGraphQLResponseDataWithExtensions(t *testing.T) { ctrl := gomock.NewController(t) productsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}","extensions":{"foo":"bar"}}}`, - `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}`) + `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`) reviewsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]},"extensions":{"foo":"bar"}}}`, - `{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}`) + `{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}}`) stockService := mockedDS(t, ctrl, `{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]},"extensions":{"foo":"bar"}}}`, - `{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}`) + `{"data":{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}}`) usersService := mockedDS(t, ctrl, `{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"2"}]},"extensions":{"foo":"bar"}}}`, - `{"_entities":[{"name":"user-1"},{"name":"user-2"}]}`) + `{"data":{"_entities":[{"name":"user-1"},{"name":"user-2"}]}}`) response := &GraphQLResponse{ Fetches: Sequence( Single(&SingleFetch{ @@ -1054,7 +1054,7 @@ func TestLoader_RedactHeaders(t *testing.T) { productsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://products","header":{"Authorization":"value"},"body":{"query":"query{topProducts{name __typename upc}}"},"__trace__":true}`, - `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}`) + `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`) response := &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -1153,19 +1153,19 @@ func TestLoader_InvalidBatchItemCount(t *testing.T) { ctrl := gomock.NewController(t) productsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`, - `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}`) + `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`) reviewsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`, - `{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}`) + `{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}}`) stockService := mockedDS(t, ctrl, `{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`, - `{"_entities":[{"stock":8},{"stock":2}]}`) // 3 items expected, 2 returned + `{"data":{"_entities":[{"stock":8},{"stock":2}]}}`) // 3 items expected, 2 returned usersService := mockedDS(t, ctrl, `{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"2"}]}}}`, - `{"_entities":[{"name":"user-1"},{"name":"user-2"},{"name":"user-3"}]}`) // 2 items expected, 3 returned + `{"data":{"_entities":[{"name":"user-1"},{"name":"user-2"},{"name":"user-3"}]}}`) // 2 items expected, 3 returned response := &GraphQLResponse{ Fetches: Sequence( Single(&SingleFetch{ diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 92501bd2e..4a0075f6b 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -13,6 +13,7 @@ import ( "github.com/pkg/errors" "go.uber.org/atomic" + "github.com/wundergraph/go-arena" "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/xcontext" "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) @@ -303,6 +304,11 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields) + jsonArena := arena.NewMonotonicArena() + defer jsonArena.Release() + t.loader.jsonArena = jsonArena + t.resolvable.astjsonArena = jsonArena + err := t.resolvable.Init(ctx, nil, response.Info.OperationType) if err != nil { return nil, err diff --git a/v2/pkg/engine/resolve/resolve_federation_test.go b/v2/pkg/engine/resolve/resolve_federation_test.go index 2547c6d10..64d969c6c 100644 --- a/v2/pkg/engine/resolve/resolve_federation_test.go +++ b/v2/pkg/engine/resolve/resolve_federation_test.go @@ -1,9 +1,7 @@ package resolve import ( - "bytes" "context" - "io" "testing" "github.com/golang/mock/gomock" @@ -21,18 +19,11 @@ func mockedDS(t TestingTB, ctrl *gomock.Controller, expectedInput, responseData t.Helper() service := NewMockDataSource(ctrl) service.EXPECT(). 
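The resolve.go hunk above is the piece that makes the []byte handover safe end to end: ArenaResolveGraphQLResponse creates one monotonic arena per request, shares it between the loader's JSON parsing and the resolvable, and releases everything in a single call once the response is rendered. Reduced to its shape (resolveWithArena is an illustrative stand-in, not code from the patch):

package example

import arena "github.com/wundergraph/go-arena"

// resolveWithArena shows the per-request arena discipline: all JSON
// allocations for one request share an arena and are freed as a unit.
func resolveWithArena(work func(a arena.Arena) ([]byte, error)) ([]byte, error) {
	a := arena.NewMonotonicArena()
	defer a.Release() // frees loader and resolvable allocations at once
	return work(a)
}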
- Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - actual := string(input) - expected := expectedInput - - require.Equal(t, expected, actual) - - pair := NewBufPair() - pair.Data.WriteString(responseData) - - return writeGraphqlResponse(pair, w, false) - }).AnyTimes() + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + require.Equal(t, expectedInput, string(input)) + return []byte(responseData), nil + }).Times(1) return service } @@ -48,7 +39,7 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { DataSource: mockedDS( t, ctrl, `{"method":"POST","url":"http://user.service","body":{"query":"{user {account {__typename id info {a b}}}}"}}`, - `{"user":{"account":{"__typename":"Account","id":"1234","info":{"a":"foo","b":"bar"}}}}`, + `{"data":{"user":{"account":{"__typename":"Account","id":"1234","info":{"a":"foo","b":"bar"}}}}}`, ), Input: `{"method":"POST","url":"http://user.service","body":{"query":"{user {account {__typename id info {a b}}}}"}}`, PostProcessing: PostProcessingConfiguration{ @@ -70,7 +61,7 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { DataSource: mockedDS( t, ctrl, expectedAccountsQuery, - `{"_entities":[{"__typename":"Account","name":"John Doe","shippingInfo":{"zip":"12345"}}]}`, + `{"data":{"_entities":[{"__typename":"Account","name":"John Doe","shippingInfo":{"zip":"12345"}}]}}`, ), Input: `{"method":"POST","url":"http://account.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Account {name shippingInfo {zip}}}}","variables":{"representations":$$0$$}}}`, PostProcessing: PostProcessingConfiguration{ @@ -182,38 +173,38 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("federation with shareable", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { firstService := NewMockDataSource(ctrl) firstService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://first.service","body":{"query":"{me {details {forename middlename} __typename id}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"me": {"__typename": "User", "id": "1234", "details": {"forename": "John", "middlename": "A"}}}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"me": {"__typename": "User", "id": "1234", "details": {"forename": "John", "middlename": "A"}}}}`) + return pair.Data.Bytes(), nil }) secondService := NewMockDataSource(ctrl) secondService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://second.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on User {details {surname}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"__typename": "User", "details": {"surname": "Smith"}}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities": [{"__typename": "User", "details": {"surname": "Smith"}}]}}`) + return pair.Data.Bytes(), nil }) thirdService := NewMockDataSource(ctrl) thirdService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://third.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {details {age}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"__typename": "User", "details": {"age": 21}}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities": [{"__typename": "User", "details": {"age": 21}}]}}`) + return pair.Data.Bytes(), nil }) return &GraphQLResponse{ @@ -377,26 +368,26 @@ userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name infoOrAddress { ... on Info {id __typename} ... on Address {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"user":{"name":"Bill","infoOrAddress":[{"id":11,"__typename":"Info"},{"id": 55,"__typename":"Address"}]}}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"user":{"name":"Bill","infoOrAddress":[{"id":11,"__typename":"Info"},{"id": 55,"__typename":"Address"}]}}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age } ... 
on Address { line1 }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"},{"id":55,"__typename":"Address"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"age":21,"__typename":"Info"},{"line1":"Munich","__typename":"Address"}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities":[{"age":21,"__typename":"Info"},{"line1":"Munich","__typename":"Address"}]}}`) + return pair.Data.Bytes(), nil }) return &GraphQLResponse{ @@ -530,19 +521,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name infoOrAddress { ... on Info {id __typename} ... on Address {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"user":{"name":"Bill","infoOrAddress":[{"id":11,"__typename":"Whatever"},{"id": 55,"__typename":"Whatever"}]}}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"user":{"name":"Bill","infoOrAddress":[{"id":11,"__typename":"Whatever"},{"id": 55,"__typename":"Whatever"}]}}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). + Load(gomock.Any(), gomock.Any()). Times(0) return &GraphQLResponse{ @@ -675,26 +666,26 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching on a field", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"users":[{"name":"Bill","info":{"id":11,"__typename":"Info"}},{"name":"John","info":{"id":12,"__typename":"Info"}},{"name":"Jane","info":{"id":13,"__typename":"Info"}}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"users":[{"name":"Bill","info":{"id":11,"__typename":"Info"}},{"name":"John","info":{"id":12,"__typename":"Info"}},{"name":"Jane","info":{"id":13,"__typename":"Info"}}]}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... 
on Info { age }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"},{"id":12,"__typename":"Info"},{"id":13,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"age":21,"__typename":"Info"},{"age":22,"__typename":"Info"},{"age":23,"__typename":"Info"}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities":[{"age":21,"__typename":"Info"},{"age":22,"__typename":"Info"},{"age":23,"__typename":"Info"}]}}`) + return pair.Data.Bytes(), nil }) return &GraphQLResponse{ @@ -819,26 +810,26 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching with duplicates", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"users":[{"name":"Bill","info":{"id":11,"__typename":"Info"}},{"name":"John","info":{"id":11,"__typename":"Info"}},{"name":"Jane","info":{"id":11,"__typename":"Info"}}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"users":[{"name":"Bill","info":{"id":11,"__typename":"Info"}},{"name":"John","info":{"id":11,"__typename":"Info"}},{"name":"Jane","info":{"id":11,"__typename":"Info"}}]}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"age":77,"__typename":"Info"}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities":[{"age":77,"__typename":"Info"}]}}`) + return pair.Data.Bytes(), nil }) return &GraphQLResponse{ @@ -960,26 +951,26 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching with null entry", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"users":[{"name":"Bill","info":{"id":11,"__typename":"Info"}},{"name":"John","info":null},{"name":"Jane","info":{"id":13,"__typename":"Info"}}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"users":[{"name":"Bill","info":{"id":11,"__typename":"Info"}},{"name":"John","info":null},{"name":"Jane","info":{"id":13,"__typename":"Info"}}]}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"},{"id":13,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"age":21,"__typename":"Info"},{"age":23,"__typename":"Info"}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities":[{"age":21,"__typename":"Info"},{"age":23,"__typename":"Info"}]}}`) + return pair.Data.Bytes(), nil }) return &GraphQLResponse{ @@ -1105,19 +1096,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching with all null entries", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"users":[{"name":"Bill","info":null},{"name":"John","info":null},{"name":"Jane","info":null}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"users":[{"name":"Bill","info":null},{"name":"John","info":null},{"name":"Jane","info":null}]}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). + Load(gomock.Any(), gomock.Any()). Times(0) return &GraphQLResponse{ @@ -1243,27 +1234,27 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching with render error", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). 
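The batching fixtures above ("batching on a field", "batching with duplicates", "batching with null entry") pin down how representations are assembled for an entity fetch: null key objects contribute nothing, duplicate keys collapse into a single batch entry, and fetched entities fan back out by index, which is what the loader's batchStats bookkeeping tracks. A conceptual, standalone sketch of that assembly; none of these names exist in the patch:

package example

// buildBatch mirrors what the fixtures assert: nulls are skipped,
// duplicates are sent once, and index maps every input item to the
// entity that will be merged back into it (-1 means "no fetch").
func buildBatch(reps []string) (unique []string, index []int) {
	seen := make(map[string]int, len(reps))
	index = make([]int, len(reps))
	for i, r := range reps {
		if r == "" { // stand-in for a null key object
			index[i] = -1
			continue
		}
		j, ok := seen[r]
		if !ok {
			j = len(unique)
			seen[r] = j
			unique = append(unique, r)
		}
		index[i] = j
	}
	return unique, index
}

In the duplicates fixture this is why three users sharing info id 11 trigger a single representation, and the one returned age of 77 is written back into all three items.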
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() // render error - first item id is boolean - pair.Data.WriteString(`{"users":[{"name":"Bill","info":{"id":true,"__typename":"Info"}},{"name":"John","info":{"id":12,"__typename":"Info"}},{"name":"Jane","info":{"id":13,"__typename":"Info"}}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"users":[{"name":"Bill","info":{"id":true,"__typename":"Info"}},{"name":"John","info":{"id":12,"__typename":"Info"}},{"name":"Jane","info":{"id":13,"__typename":"Info"}}]}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age }}}}}","variables":{"representations":[{"id":12,"__typename":"Info"},{"id":13,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"age":21,"__typename":"Info"},{"age":22,"__typename":"Info"}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities":[{"age":21,"__typename":"Info"},{"age":22,"__typename":"Info"}]}}`) + return pair.Data.Bytes(), nil }) return &GraphQLResponse{ @@ -1390,26 +1381,26 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("all data", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"user":{"name":"Bill","info":{"id":11,"__typename":"Info"}}}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"user":{"name":"Bill","info":{"id":11,"__typename":"Info"}}}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... 
on Info { age }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"age":21,"__typename":"Info"}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities":[{"age":21,"__typename":"Info"}]}}`) + return pair.Data.Bytes(), nil }) return &GraphQLResponse{ @@ -1524,19 +1515,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("null info data", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"user":{"name":"Bill","info":null}}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"user":{"name":"Bill","info":null}}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). + Load(gomock.Any(), gomock.Any()). Times(0) return &GraphQLResponse{ @@ -1652,19 +1643,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("wrong type data", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"user":{"name":"Bill","info":{"id":false,"__typename":"Info"}}}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"user":{"name":"Bill","info":{"id":false,"__typename":"Info"}}}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). + Load(gomock.Any(), gomock.Any()). Times(0) return &GraphQLResponse{ @@ -1780,19 +1771,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("not matching type data", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"user":{"name":"Bill","info":{"id":1,"__typename":"Whatever"}}}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"user":{"name":"Bill","info":{"id":1,"__typename":"Whatever"}}}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). + Load(gomock.Any(), gomock.Any()). Times(0) return &GraphQLResponse{ @@ -1912,19 +1903,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { user := mockedDS(t, ctrl, `{"method":"POST","url":"http://user.service","body":{"query":"{user {account {address {__typename id line1 line2}}}}"}}`, - `{"user":{"account":{"address":{"__typename":"Address","id":"address-1","line1":"line1","line2":"line2"}}}}`) + `{"data":{"user":{"account":{"address":{"__typename":"Address","id":"address-1","line1":"line1","line2":"line2"}}}}}`) addressEnricher := mockedDS(t, ctrl, `{"method":"POST","url":"http://address-enricher.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Address {country city}}}","variables":{"representations":[{"__typename":"Address","id":"address-1"}]}}}`, - `{"__typename":"Address","country":"country-1","city":"city-1"}`) + `{"data":{"__typename":"Address","country":"country-1","city":"city-1"}}`) address := mockedDS(t, ctrl, `{"method":"POST","url":"http://address.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Address {line3(test: "BOOM") zip}}}","variables":{"representations":[{"__typename":"Address","id":"address-1","country":"country-1","city":"city-1"}]}}}`, - `{"__typename": "Address", "line3": "line3-1", "zip": "zip-1"}`) + `{"data":{"__typename": "Address", "line3": "line3-1", "zip": "zip-1"}}`) account := mockedDS(t, ctrl, `{"method":"POST","url":"http://account.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Address {fullAddress}}}","variables":{"representations":[{"__typename":"Address","id":"address-1","line1":"line1","line2":"line2","line3":"line3-1","zip":"zip-1"}]}}}`, - `{"__typename":"Address","fullAddress":"line1 line2 line3-1 city-1 country-1 zip-1"}`) + `{"data":{"__typename":"Address","fullAddress":"line1 line2 line3-1 city-1 country-1 zip-1"}}`) return &GraphQLResponse{ Fetches: Sequence( @@ -2152,19 +2143,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { productsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`, - `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}`) + `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`) reviewsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`, - `{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}`) + `{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}}`) stockService := mockedDS(t, ctrl, `{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`, - `{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}`) + `{"data":{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}}`) usersService := mockedDS(t, ctrl, `{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"2"}]}}}`, - `{"_entities":[{"name":"user-1"},{"name":"user-2"}]}`) + `{"data":{"_entities":[{"name":"user-1"},{"name":"user-2"}]}}`) return &GraphQLResponse{ Fetches: Sequence( @@ -2424,19 +2415,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { productsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`, - `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"}]}`) + `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"}]}}`) reviewsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"}]}}}`, - `{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}}]}]}`) + `{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}}]}]}}`) stockService := mockedDS(t, ctrl, `{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"}]}}}`, - `{"_entities":[{"stock":8}]}`) + `{"data":{"_entities":[{"stock":8}]}}`) usersService := mockedDS(t, ctrl, `{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"}]}}}`, - `{"_entities":[{"name":"user-1"}]}`) + `{"data":{"_entities":[{"name":"user-1"}]}}`) return &GraphQLResponse{ Fetches: Sequence( @@ -2696,11 +2687,11 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { accountsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://accounts","body":{"query":"{accounts{__typename ... on User {__typename id} ... on Moderator {__typename moderatorID} ... on Admin {__typename adminID}}}"}}`, - `{"accounts":[{"__typename":"User","id":"3"},{"__typename":"Admin","adminID":"2"},{"__typename":"Moderator","moderatorID":"1"}]}`) + `{"data":{"accounts":[{"__typename":"User","id":"3"},{"__typename":"Admin","adminID":"2"},{"__typename":"Moderator","moderatorID":"1"}]}}`) namesService := mockedDS(t, ctrl, `{"method":"POST","url":"http://names","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name} ... on Moderator {subject} ... on Admin {type}}}","variables":{"representations":[{"__typename":"User","id":"3"},{"__typename":"Admin","adminID":"2"},{"__typename":"Moderator","moderatorID":"1"}]}}}`, - `{"_entities":[{"__typename":"User","name":"User"},{"__typename":"Admin","type":"super"},{"__typename":"Moderator","subject":"posts"}]}`) + `{"data":{"_entities":[{"__typename":"User","name":"User"},{"__typename":"Admin","type":"super"},{"__typename":"Moderator","subject":"posts"}]}}`) return &GraphQLResponse{ Fetches: Sequence( @@ -2836,11 +2827,11 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { accountsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://accounts","body":{"query":"{accounts {__typename ... on User {some {__typename id}} ... on Admin {some {__typename id}}}}"}}`, - `{"accounts":[{"__typename":"User","some":{"__typename":"User","id":"1"}},{"__typename":"Admin","some":{"__typename":"User","id":"2"}},{"__typename":"User","some":{"__typename":"User","id":"3"}}]}`) + `{"data":{"accounts":[{"__typename":"User","some":{"__typename":"User","id":"1"}},{"__typename":"Admin","some":{"__typename":"User","id":"2"}},{"__typename":"User","some":{"__typename":"User","id":"3"}}]}}`) namesService := mockedDS(t, ctrl, `{"method":"POST","url":"http://names","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {__typename title}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"3"}]}}}`, - `{"_entities":[{"__typename":"User","title":"User1"},{"__typename":"User","title":"User3"}]}`) + `{"data":{"_entities":[{"__typename":"User","title":"User1"},{"__typename":"User","title":"User3"}]}}`) return &GraphQLResponse{ Fetches: Sequence( diff --git a/v2/pkg/engine/resolve/resolve_mock_test.go b/v2/pkg/engine/resolve/resolve_mock_test.go index 3f72cc3d8..d493ff4bd 100644 --- a/v2/pkg/engine/resolve/resolve_mock_test.go +++ b/v2/pkg/engine/resolve/resolve_mock_test.go @@ -5,7 +5,6 @@ package resolve import ( - bytes "bytes" context "context" reflect "reflect" @@ -37,29 +36,31 @@ func (m *MockDataSource) EXPECT() *MockDataSourceMockRecorder { } // Load mocks base method. 
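// Sketch, inferred from the regenerated mock below (the canonical interface
// lives in the resolve package and is not part of this hunk): the DataSource
// contract changes from a writer-based shape:
//
//	Load(ctx context.Context, input []byte, out *bytes.Buffer) error
//	LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) error
//
// to a return-based shape:
//
//	Load(ctx context.Context, input []byte) ([]byte, error)
//	LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) ([]byte, error)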
-func (m *MockDataSource) Load(arg0 context.Context, arg1 []byte, arg2 *bytes.Buffer) error { +func (m *MockDataSource) Load(arg0 context.Context, arg1 []byte) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Load", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "Load", arg0, arg1) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 } // Load indicates an expected call of Load. -func (mr *MockDataSourceMockRecorder) Load(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockDataSourceMockRecorder) Load(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockDataSource)(nil).Load), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockDataSource)(nil).Load), arg0, arg1) } // LoadWithFiles mocks base method. -func (m *MockDataSource) LoadWithFiles(arg0 context.Context, arg1 []byte, arg2 []*httpclient.FileUpload, arg3 *bytes.Buffer) error { +func (m *MockDataSource) LoadWithFiles(arg0 context.Context, arg1 []byte, arg2 []*httpclient.FileUpload) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LoadWithFiles", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "LoadWithFiles", arg0, arg1, arg2) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 } // LoadWithFiles indicates an expected call of LoadWithFiles. -func (mr *MockDataSourceMockRecorder) LoadWithFiles(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockDataSourceMockRecorder) LoadWithFiles(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadWithFiles", reflect.TypeOf((*MockDataSource)(nil).LoadWithFiles), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadWithFiles", reflect.TypeOf((*MockDataSource)(nil).LoadWithFiles), arg0, arg1, arg2) } diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 8e15ff98a..d19156f36 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -32,7 +32,7 @@ type _fakeDataSource struct { artificialLatency time.Duration } -func (f *_fakeDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) { +func (f *_fakeDataSource) Load(ctx context.Context, input []byte) (data []byte, err error) { if f.artificialLatency != 0 { time.Sleep(f.artificialLatency) } @@ -41,11 +41,10 @@ func (f *_fakeDataSource) Load(ctx context.Context, input []byte, out *bytes.Buf require.Equal(f.t, string(f.input), string(input), "input mismatch") } } - _, err = out.Write(f.data) - return + return f.data, nil } -func (f *_fakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { +func (f *_fakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { if f.artificialLatency != 0 { time.Sleep(f.artificialLatency) } @@ -54,8 +53,7 @@ func (f *_fakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files require.Equal(f.t, string(f.input), string(input), "input mismatch") } } - _, err = out.Write(f.data) - return + return f.data, nil } func FakeDataSource(data string) *_fakeDataSource { @@ -351,12 +349,11 @@ func TestResolver_ResolveNode(t *testing.T) { t.Run("fetch with context 
variable resolver", testFn(true, func(t *testing.T, ctrl *gomock.Controller) (response *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), []byte(`{"id":1}`), gomock.AssignableToTypeOf(&bytes.Buffer{})). - Do(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { - _, err = w.Write([]byte(`{"name":"Jens"}`)) - return + Load(gomock.Any(), []byte(`{"id":1}`)). + Do(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"name":"Jens"}`), nil }). - Return(nil) + Return([]byte(`{"name":"Jens"}`), nil) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ FetchConfiguration: FetchConfiguration{ @@ -1802,11 +1799,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with simple error without datasource ID", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ Fetches: SingleWithPath(&SingleFetch{ @@ -1834,11 +1829,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with simple error without datasource ID no subgraph error forwarding", testFnNoSubgraphErrorForwarding(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ Fetches: SingleWithPath(&SingleFetch{ @@ -1866,11 +1859,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with simple error", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ Fetches: SingleWithPath(&SingleFetch{ @@ -1902,11 +1893,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with simple error in pass through Subgraph Error Mode", testFnSubgraphErrorsPassthrough(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -1938,10 +1927,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with pass through mode and omit custom fields", testFnSubgraphErrorsPassthroughAndOmitCustomFields(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) error { - _, err := w.Write([]byte(`{"errors":[{"message":"errorMessage","longMessage":"This is a long message","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}],"data":{"name":null}}`)) - return err + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","longMessage":"This is a long message","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}],"data":{"name":null}}`), nil }) return &GraphQLResponse{ Info: &GraphQLResponseInfo{ @@ -1976,9 +1964,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with returned err (with DataSourceID)", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - return &net.AddrError{} + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return nil, &net.AddrError{} }) return &GraphQLResponse{ Fetches: SingleWithPath(&SingleFetch{ @@ -2010,9 +1998,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with returned err (no DataSourceID)", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - return &net.AddrError{} + Load(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return nil, &net.AddrError{} }) return &GraphQLResponse{ Fetches: SingleWithPath(&SingleFetch{ @@ -2040,9 +2028,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with returned err and non-nullable root field", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - return &net.AddrError{} + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return nil, &net.AddrError{} }) return &GraphQLResponse{ Fetches: SingleWithPath(&SingleFetch{ @@ -2218,14 +2206,10 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with two Errors", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - Do(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage1"), nil, nil, nil) - pair.WriteErr([]byte("errorMessage2"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) - }). - Return(nil) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage1"},{"message":"errorMessage2"}]}`), nil + }).Times(1) return &GraphQLResponse{ Fetches: SingleWithPath(&SingleFetch{ FetchConfiguration: FetchConfiguration{ @@ -2578,39 +2562,32 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("complex GraphQL Server plan", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { serviceOne := NewMockDataSource(ctrl) serviceOne.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"url":"https://service.one","body":{"query":"query($firstArg: String, $thirdArg: Int){serviceOne(serviceOneArg: $firstArg){fieldOne} anotherServiceOne(anotherServiceOneArg: $thirdArg){fieldOne} reusingServiceOne(reusingServiceOneArg: $firstArg){fieldOne}}","variables":{"thirdArg":123,"firstArg":"firstArgValue"}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"serviceOne":{"fieldOne":"fieldOneValue"},"anotherServiceOne":{"fieldOne":"anotherFieldOneValue"},"reusingServiceOne":{"fieldOne":"reUsingFieldOneValue"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"serviceOne":{"fieldOne":"fieldOneValue"},"anotherServiceOne":{"fieldOne":"anotherFieldOneValue"},"reusingServiceOne":{"fieldOne":"reUsingFieldOneValue"}}}`), nil }) serviceTwo := NewMockDataSource(ctrl) serviceTwo.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). 
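// The response envelope moves with the signature change: stubs previously
// wrote the bare data object and relied on writeGraphqlResponse for wrapping,
// whereas they now return a full GraphQL response such as `{"data":{...}}`,
// which the resolver unwraps in tests that configure
// PostProcessing.SelectResponseDataPath ([]string{"data"}).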
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"url":"https://service.two","body":{"query":"query($secondArg: Boolean, $fourthArg: Float){serviceTwo(serviceTwoArg: $secondArg){fieldTwo} secondServiceTwo(secondServiceTwoArg: $fourthArg){fieldTwo}}","variables":{"fourthArg":12.34,"secondArg":true}}}` assert.Equal(t, expected, actual) - - pair := NewBufPair() - pair.Data.WriteString(`{"serviceTwo":{"fieldTwo":"fieldTwoValue"},"secondServiceTwo":{"fieldTwo":"secondFieldTwoValue"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"serviceTwo":{"fieldTwo":"fieldTwoValue"},"secondServiceTwo":{"fieldTwo":"secondFieldTwoValue"}}}`), nil }) nestedServiceOne := NewMockDataSource(ctrl) nestedServiceOne.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"url":"https://service.one","body":{"query":"{serviceOne {fieldOne}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"serviceOne":{"fieldOne":"fieldOneValue"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"serviceOne":{"fieldOne":"fieldOneValue"}}}`), nil }) return &GraphQLResponse{ @@ -2821,52 +2798,42 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me":{"id":"1234","username":"Me","__typename":"User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me":{"id":"1234","username":"Me","__typename":"User"}}}`), nil }) reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) - // {"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":["id":"1234","__typename":"User"]}}} expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}}`), nil }) var productServiceCallCount atomic.Int64 productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - Do(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) productServiceCallCount.Add(1) switch actual { case `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"}]}}}`: - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"name": "Furby"}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"name": "Furby"}]}}`), nil case `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-2","__typename":"Product"}]}}}`: - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"name": "Trilby"}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"name": "Trilby"}]}}`), nil default: t.Fatalf("unexpected request: %s", actual) } - return - }). - Return(nil).Times(2) + return nil, nil + }).Times(2) return &GraphQLResponse{ Fetches: Sequence( @@ -3038,38 +3005,32 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("federation with batch", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me":{"id":"1234","username":"Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me":{"id":"1234","username":"Me","__typename": "User"}}}`), nil }) reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). 
- DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}}`), nil }) productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}}`), nil }) return &GraphQLResponse{ @@ -3241,38 +3202,32 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("federation with merge paths", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me":{"id":"1234","username":"Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me":{"id":"1234","username":"Me","__typename": "User"}}}`), nil }) reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). 
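// Shape note for the federation batch tests: the representations list carries
// one entry per entity, and the subgraph's _entities array is expected to be
// position-aligned with it, e.g. (taken from the productService stub below):
//
//	variables: {"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}
//	response:  {"data":{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}}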
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}}`), nil }) productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}}`), nil }) return &GraphQLResponse{ @@ -3445,45 +3400,39 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("federation with null response", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me":{"id":"1234","username":"Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me":{"id":"1234","username":"Me","__typename": "User"}}}`), nil }) reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"reviews": [ + return []byte(`{"data":{"_entities":[{"reviews": [ {"body": "foo","product": {"upc": "top-1","__typename": "Product"}}, {"body": "bar","product": {"upc": "top-2","__typename": "Product"}}, {"body": "baz","product": null}, {"body": "bat","product": {"upc": "top-4","__typename": "Product"}}, {"body": "bal","product": {"upc": "top-5","__typename": "Product"}}, {"body": "ban","product": {"upc": "top-6","__typename": "Product"}} -]}]}`) - return writeGraphqlResponse(pair, w, false) +]}]}}`), nil }) productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"},{"upc":"top-4","__typename":"Product"},{"upc":"top-5","__typename":"Product"},{"upc":"top-6","__typename":"Product"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"name":"Trilby"},{"name":"Fedora"},{"name":"Boater"},{"name":"Top Hat"},{"name":"Bowler"}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"name":"Trilby"},{"name":"Fedora"},{"name":"Boater"},{"name":"Top Hat"},{"name":"Bowler"}]}}`), nil }) return &GraphQLResponse{ @@ -3678,38 +3627,32 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me": {"id": "1234","username": "Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me": {"id": "1234","username": "Me","__typename": "User"}}}`), nil }) reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product":{"upc": "top-1","__typename":"Product"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product":{"upc": "top-1","__typename":"Product"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product"}}]}]}}`), nil }) productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ @@ -3871,38 +3814,32 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me": {"id": "1234","username": "Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me": {"id": "1234","username": "Me","__typename": "User"}}}`), nil }) reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product":{"upc": "top-1","__typename":"Product"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product":{"upc": "top-1","__typename":"Product"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product"}}]}]}}`), nil }) productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ @@ -4061,38 +3998,32 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("federation with optional variable", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:8080/query","body":{"query":"{me {id}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me":{"id":"1234","__typename":"User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me":{"id":"1234","__typename":"User"}}}`), nil }) employeeService := NewMockDataSource(ctrl) employeeService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:8081/query","body":{"query":"query($representations: [_Any!]!, $companyId: ID!){_entities(representations: $representations){... 
on User {employment(companyId: $companyId){id}}}}","variables":{"companyId":"abc123","representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"employment":{"id":"xyz987"}}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"employment":{"id":"xyz987"}}]}}`), nil }) timeService := NewMockDataSource(ctrl) timeService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:8082/query","body":{"query":"query($representations: [_Any!]!, $date: LocalTime){_entities(representations: $representations){... on Employee {times(date: $date){id employee {id} start end}}}}","variables":{"date":null,"representations":[{"id":"xyz987","__typename":"Employee"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"times":[{"id": "t1","employee":{"id":"xyz987"},"start":"2022-11-02T08:00:00","end":"2022-11-02T12:00:00"}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"times":[{"id": "t1","employee":{"id":"xyz987"},"start":"2022-11-02T08:00:00","end":"2022-11-02T12:00:00"}]}]}}`), nil }) return &GraphQLResponse{ @@ -4263,148 +4194,597 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }) } -func TestResolver_ApolloCompatibilityMode_FetchError(t *testing.T) { - options := apolloCompatibilityOptions{ - valueCompletion: true, - suppressFetchErrors: true, +// testFnArena is a helper function for testing ArenaResolveGraphQLResponse +func testFnArena(fn func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string)) func(t *testing.T) { + return func(t *testing.T) { + t.Helper() + + ctrl := gomock.NewController(t) + rCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + r := newResolver(rCtx) + node, ctx, expectedOutput := fn(t, ctrl) + + if node.Info == nil { + node.Info = &GraphQLResponseInfo{ + OperationType: ast.OperationTypeQuery, + } + } + + if t.Skipped() { + return + } + + buf := &bytes.Buffer{} + _, err := r.ArenaResolveGraphQLResponse(&ctx, node, buf) + assert.NoError(t, err) + assert.Equal(t, expectedOutput, buf.String()) + ctrl.Finish() } - t.Run("simple fetch with fetch error suppression - empty response", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { - mockDataSource := NewMockDataSource(ctrl) - mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). 
- DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - _, _ = w.Write([]byte("{}")) - return - }) +} + +func TestResolver_ArenaResolveGraphQLResponse(t *testing.T) { + + t.Run("empty graphql response", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return &GraphQLResponse{ - Fetches: SingleWithPath(&SingleFetch{ - InputTemplate: InputTemplate{ - Segments: []TemplateSegment{ - { - Data: []byte(`{"method":"POST","url":"http://localhost:4001","body":{"query":"{query{name}}"}}`), - SegmentType: StaticSegmentType, + Data: &Object{ + Nullable: true, + }, + }, Context{ctx: context.Background()}, `{"data":{}}` + })) + + t.Run("simple data source", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"id":"1","name":"Jens","registered":true}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &String{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + { + Name: []byte("registered"), + Value: &Boolean{ + Path: []string{"registered"}, + Nullable: false, + }, + }, + }, }, }, }, - FetchConfiguration: FetchConfiguration{ - DataSource: mockDataSource, - PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data"}, - SelectResponseErrorsPath: []string{"errors"}, - }, - }, - }, "query"), + }, + }, Context{ctx: context.Background()}, `{"data":{"user":{"id":"1","name":"Jens","registered":true}}}` + })) + + t.Run("array of strings", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"strings": ["Alex", "true", "123"]}`)}, + }), Data: &Object{ Fields: []*Field{ { - Name: []byte("name"), - Value: &String{ - Path: []string{"name"}, + Name: []byte("strings"), + Value: &Array{ + Path: []string{"strings"}, + Item: &String{ + Nullable: false, + }, }, }, }, }, - }, Context{ctx: context.Background()}, `{"data":null,"extensions":{"valueCompletion":[{"message":"Cannot return null for non-nullable field Query.name.","path":["name"],"extensions":{"code":"INVALID_GRAPHQL"}}]}}` - }, &options)) + }, Context{ctx: context.Background()}, `{"data":{"strings":["Alex","true","123"]}}` + })) - t.Run("simple fetch with fetch error suppression - response with error", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { - mockDataSource := NewMockDataSource(ctrl) - mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). 
- DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - _, _ = w.Write([]byte(`{"errors":[{"message":"Cannot query field 'name' on type 'Query'"}]}`)) - return - }) + t.Run("array of objects", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return &GraphQLResponse{ - Fetches: SingleWithPath(&SingleFetch{ - InputTemplate: InputTemplate{ - Segments: []TemplateSegment{ - { - Data: []byte(`{"method":"POST","url":"http://localhost:4001","body":{"query":"{query{name}}"}}`), - SegmentType: StaticSegmentType, + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"friends":[{"id":1,"name":"Alex"},{"id":2,"name":"Patric"}]}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("friends"), + Value: &Array{ + Path: []string{"friends"}, + Item: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Integer{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, }, }, }, - FetchConfiguration: FetchConfiguration{ - DataSource: mockDataSource, - PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data"}, - SelectResponseErrorsPath: []string{"errors"}, - }, - }, - }, "query"), + }, + }, Context{ctx: context.Background()}, `{"data":{"friends":[{"id":1,"name":"Alex"},{"id":2,"name":"Patric"}]}}` + })) + + t.Run("nested objects", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"id":"1","name":"Jens","pet":{"name":"Barky","kind":"Dog"}}`)}, + }), Data: &Object{ Fields: []*Field{ { - Name: []byte("name"), - Value: &String{ - Path: []string{"name"}, + Name: []byte("user"), + Value: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &String{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + { + Name: []byte("pet"), + Value: &Object{ + Path: []string{"pet"}, + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + { + Name: []byte("kind"), + Value: &String{ + Path: []string{"kind"}, + Nullable: false, + }, + }, + }, + }, + }, + }, }, }, }, }, - }, Context{ctx: context.Background()}, `{"errors":[{"message":"Cannot query field 'name' on type 'Query'"}],"data":null}` - }, &options)) - - t.Run("complex fetch with fetch error suppression", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { - userService := NewMockDataSource(ctrl) - userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - actual := string(input) - expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` - assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me": {"id": "1234","username": "Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) - }) - - reviewsService := NewMockDataSource(ctrl) - reviewsService.EXPECT(). 
- Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - actual := string(input) - expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` - assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product":{"upc": "top-1","__typename":"Product"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) - }) - - productService := NewMockDataSource(ctrl) - productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - actual := string(input) - expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"}]}}}` - assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) - }) + }, Context{ctx: context.Background()}, `{"data":{"user":{"id":"1","name":"Jens","pet":{"name":"Barky","kind":"Dog"}}}}` + })) + t.Run("scalar types", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return &GraphQLResponse{ - Fetches: Sequence( - SingleWithPath(&SingleFetch{ - InputTemplate: InputTemplate{ - Segments: []TemplateSegment{ - { - Data: []byte(`{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}`), - SegmentType: StaticSegmentType, - }, + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"int": 12345, "float": 3.5, "str":"value", "bool": true}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("int"), + Value: &Integer{ + Path: []string{"int"}, + Nullable: false, }, }, - FetchConfiguration: FetchConfiguration{ - DataSource: userService, - PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data"}, + { + Name: []byte("float"), + Value: &Float{ + Path: []string{"float"}, + Nullable: false, }, }, - }, "query"), - SingleWithPath(&SingleFetch{ - InputTemplate: InputTemplate{ - Segments: []TemplateSegment{ - { + { + Name: []byte("str"), + Value: &String{ + Path: []string{"str"}, + Nullable: false, + }, + }, + { + Name: []byte("bool"), + Value: &Boolean{ + Path: []string{"bool"}, + Nullable: false, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"int":12345,"float":3.5,"str":"value","bool":true}}` + })) + + t.Run("null field", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("foo"), + Value: &Null{}, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"foo":null}}` + })) + + t.Run("__typename field", 
testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"id":1,"name":"Jannik","__typename":"User"}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Integer{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + Nullable: false, + IsTypeName: true, + }, + }, + }, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"user":{"id":1,"name":"Jannik","__typename":"User"}}}` + })) + + t.Run("multiple fetches", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"user1":{"id":1,"name":"User1"},"user2":{"id":2,"name":"User2"}}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("user1"), + Value: &Object{ + Path: []string{"user1"}, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Integer{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + { + Name: []byte("user2"), + Value: &Object{ + Path: []string{"user2"}, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Integer{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"user1":{"id":1,"name":"User1"},"user2":{"id":2,"name":"User2"}}}` + })) + + t.Run("with variables", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + mockDataSource := NewMockDataSource(ctrl) + mockDataSource.EXPECT(). + Load(gomock.Any(), []byte(`{"id":1}`)). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"name":"Jens"}`), nil + }) + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: mockDataSource}, + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"id":`), + SegmentType: StaticSegmentType, + }, + { + Data: []byte(`{{.arguments.id}}`), + SegmentType: VariableSegmentType, + VariableKind: ContextVariableKind, + VariableSourcePath: []string{"id"}, + Renderer: NewPlainVariableRenderer(), + }, + { + Data: []byte(`}`), + SegmentType: StaticSegmentType, + }, + }, + }, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, Context{ctx: context.Background(), Variables: astjson.MustParseBytes([]byte(`{"id":1}`))}, `{"data":{"name":"Jens"}}` + })) + + t.Run("error handling", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + mockDataSource := NewMockDataSource(ctrl) + mockDataSource.EXPECT(). + Load(gomock.Any(), gomock.Any()). 
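// ArenaResolveGraphQLResponse is driven through testFnArena above and shares
// the (ctx, response, writer) call shape of the regular resolver; a minimal
// usage sketch, reusing the helper's names:
//
//	buf := &bytes.Buffer{}
//	_, err := r.ArenaResolveGraphQLResponse(&ctx, node, buf)
//	assert.NoError(t, err)
//	assert.Equal(t, expectedOutput, buf.String())
//
// In this subtest the stub fails, so the expected output degrades to
// {"errors":[{"message":"Failed to fetch from Subgraph."}],"data":null}.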
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return nil, errors.New("data source error") + }) + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: mockDataSource}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"errors":[{"message":"Failed to fetch from Subgraph."}],"data":null}` + })) + + t.Run("bigint handling", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"n": 12345, "ns_small": "12346", "ns_big": "1152921504606846976"}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("n"), + Value: &BigInt{ + Path: []string{"n"}, + Nullable: false, + }, + }, + { + Name: []byte("ns_small"), + Value: &BigInt{ + Path: []string{"ns_small"}, + Nullable: false, + }, + }, + { + Name: []byte("ns_big"), + Value: &BigInt{ + Path: []string{"ns_big"}, + Nullable: false, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"n":12345,"ns_small":"12346","ns_big":"1152921504606846976"}}` + })) + + t.Run("skip loader", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("static"), + Value: &Null{}, + }, + }, + }, + }, Context{ctx: context.Background(), ExecutionOptions: ExecutionOptions{SkipLoader: true}}, `{"data":null}` + })) +} + +func TestResolver_ApolloCompatibilityMode_FetchError(t *testing.T) { + options := apolloCompatibilityOptions{ + valueCompletion: true, + suppressFetchErrors: true, + } + t.Run("simple fetch with fetch error suppression - empty response", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + mockDataSource := NewMockDataSource(ctrl) + mockDataSource.EXPECT(). + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte("{}"), nil + }) + return &GraphQLResponse{ + Fetches: SingleWithPath(&SingleFetch{ + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://localhost:4001","body":{"query":"{query{name}}"}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + FetchConfiguration: FetchConfiguration{ + DataSource: mockDataSource, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + SelectResponseErrorsPath: []string{"errors"}, + }, + }, + }, "query"), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":null,"extensions":{"valueCompletion":[{"message":"Cannot return null for non-nullable field Query.name.","path":["name"],"extensions":{"code":"INVALID_GRAPHQL"}}]}}` + }, &options)) + + t.Run("simple fetch with fetch error suppression - response with error", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + mockDataSource := NewMockDataSource(ctrl) + mockDataSource.EXPECT(). + Load(gomock.Any(), gomock.Any()). 
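// Contrast with the previous subtest: with suppressFetchErrors enabled, an
// empty `{}` subgraph response surfaces only as a valueCompletion extension,
// while errors returned by the subgraph itself (stubbed below) are forwarded
// in the "errors" array of the final response.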
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"Cannot query field 'name' on type 'Query'"}]}`), nil + }) + return &GraphQLResponse{ + Fetches: SingleWithPath(&SingleFetch{ + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://localhost:4001","body":{"query":"{query{name}}"}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + FetchConfiguration: FetchConfiguration{ + DataSource: mockDataSource, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + SelectResponseErrorsPath: []string{"errors"}, + }, + }, + }, "query"), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"errors":[{"message":"Cannot query field 'name' on type 'Query'"}],"data":null}` + }, &options)) + + t.Run("complex fetch with fetch error suppression", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + userService := NewMockDataSource(ctrl) + userService.EXPECT(). + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + actual := string(input) + expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` + assert.Equal(t, expected, actual) + return []byte(`{"data":{"me": {"id": "1234","username": "Me","__typename": "User"}}}`), nil + }) + + reviewsService := NewMockDataSource(ctrl) + reviewsService.EXPECT(). + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + actual := string(input) + expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` + assert.Equal(t, expected, actual) + return []byte(`{"data":{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product":{"upc": "top-1","__typename":"Product"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product"}}]}]}}`), nil + }) + + productService := NewMockDataSource(ctrl) + productService.EXPECT(). + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + actual := string(input) + expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"}]}}}` + assert.Equal(t, expected, actual) + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil + }) + + return &GraphQLResponse{ + Fetches: Sequence( + SingleWithPath(&SingleFetch{ + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + FetchConfiguration: FetchConfiguration{ + DataSource: userService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + }, + }, + }, "query"), + SingleWithPath(&SingleFetch{ + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { Data: []byte(`{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"`), SegmentType: StaticSegmentType, }, @@ -4566,14 +4946,12 @@ func TestResolver_WithHeader(t *testing.T) { ctrl := gomock.NewController(t) fakeService := NewMockDataSource(ctrl) fakeService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - Do(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) assert.Equal(t, "foo", actual) - _, err = w.Write([]byte(`{"bar":"baz"}`)) - return - }). - Return(nil) + return []byte(`{"bar":"baz"}`), nil + }) out := &bytes.Buffer{} res := &GraphQLResponse{ @@ -4639,14 +5017,12 @@ func TestResolver_WithVariableRemapping(t *testing.T) { ctrl := gomock.NewController(t) fakeService := NewMockDataSource(ctrl) fakeService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - Do(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) assert.Equal(t, tc.expectedOutput, actual) - _, err = w.Write([]byte(`{"bar":"baz"}`)) - return - }). 
- Return(nil) + return []byte(`{"bar":"baz"}`), nil + }) out := &bytes.Buffer{} res := &GraphQLResponse{ @@ -5909,50 +6285,353 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { Data: []byte(`{"method":"POST","url":"http://localhost:4000"}`), }, }, - }, - }, - Filter: &SubscriptionFilter{ - In: &SubscriptionFieldFilter{ - FieldPath: []string{"id"}, - Values: []InputTemplate{ - { + }, + }, + Filter: &SubscriptionFilter{ + In: &SubscriptionFieldFilter{ + FieldPath: []string{"id"}, + Values: []InputTemplate{ + { + Segments: []TemplateSegment{ + { + SegmentType: StaticSegmentType, + Data: []byte(`x.`), + }, + { + SegmentType: VariableSegmentType, + VariableKind: ContextVariableKind, + VariableSourcePath: []string{"a"}, + Renderer: NewPlainVariableRenderer(), + }, + { + SegmentType: StaticSegmentType, + Data: []byte(`.`), + }, + { + SegmentType: VariableSegmentType, + VariableKind: ContextVariableKind, + VariableSourcePath: []string{"b"}, + Renderer: NewPlainVariableRenderer(), + }, + }, + }, + }, + }, + }, + Response: &GraphQLResponse{ + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("oneUserByID"), + Value: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &String{ + Path: []string{"id"}, + }, + }, + }, + }, + }, + }, + }, + }, + } + + out := &SubscriptionRecorder{ + buf: &bytes.Buffer{}, + messages: []string{}, + complete: atomic.Bool{}, + } + out.complete.Store(false) + + id := SubscriptionIdentifier{ + ConnectionID: 1, + SubscriptionID: 1, + } + + resolver := newResolver(c) + + ctx := &Context{ + ctx: context.Background(), + Variables: astjson.MustParseBytes([]byte(`{"a":[1,2],"b":[3,4]}`)), + } + + err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, out, id) + assert.NoError(t, err) + out.AwaitComplete(t, defaultTimeout) + assert.Equal(t, 4, len(out.Messages())) + assert.ElementsMatch(t, []string{ + `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, + `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, + `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, + `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, + }, out.Messages()) + }) +} + +func Benchmark_NestedBatching(b *testing.B) { + rCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := newResolver(rCtx) + + productsService := fakeDataSourceWithInputCheck(b, + []byte(`{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`), + []byte(`{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`)) + stockService := fakeDataSourceWithInputCheck(b, + []byte(`{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`), + []byte(`{"data":{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}}`)) + reviewsService := fakeDataSourceWithInputCheck(b, + []byte(`{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`), + []byte(`{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}}`)) + usersService := fakeDataSourceWithInputCheck(b, + []byte(`{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"2"}]}}}`), + []byte(`{"data":{"_entities":[{"name":"user-1"},{"name":"user-2"}]}}`)) + + plan := &GraphQLResponse{ + Fetches: Sequence( + SingleWithPath(&SingleFetch{ + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + FetchConfiguration: FetchConfiguration{ + DataSource: productsService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + }, + }, + }, ""), + Parallel( + SingleWithPath(&BatchEntityFetch{ + Input: BatchInput{ + Header: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[`), + SegmentType: StaticSegmentType, + }, + }, + }, + Items: []InputTemplate{ + { + Segments: []TemplateSegment{ + { + SegmentType: VariableSegmentType, + VariableKind: ResolvableObjectVariableKind, + Renderer: NewGraphQLVariableResolveRenderer(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("upc"), + Value: &String{ + Path: []string{"upc"}, + }, + }, + }, + }), + }, + }, + }, + }, + Separator: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`,`), + SegmentType: StaticSegmentType, + }, + }, + }, + Footer: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`]}}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + }, + DataSource: reviewsService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data", "_entities"}, + }, + }, "topProducts", ArrayPath("topProducts")), + SingleWithPath(&BatchEntityFetch{ + Input: BatchInput{ + Header: InputTemplate{ Segments: []TemplateSegment{ { + Data: []byte(`{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on Product {stock}}}","variables":{"representations":[`), SegmentType: StaticSegmentType, - Data: []byte(`x.`), }, + }, + }, + Items: []InputTemplate{ + { + Segments: []TemplateSegment{ + { + SegmentType: VariableSegmentType, + VariableKind: ResolvableObjectVariableKind, + Renderer: NewGraphQLVariableResolveRenderer(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("upc"), + Value: &String{ + Path: []string{"upc"}, + }, + }, + }, + }), + }, + }, + }, + }, + Separator: InputTemplate{ + Segments: []TemplateSegment{ { - SegmentType: VariableSegmentType, - VariableKind: ContextVariableKind, - VariableSourcePath: []string{"a"}, - Renderer: NewPlainVariableRenderer(), + Data: []byte(`,`), + SegmentType: StaticSegmentType, }, + }, + }, + Footer: InputTemplate{ + Segments: []TemplateSegment{ { + Data: []byte(`]}}}`), SegmentType: StaticSegmentType, - Data: []byte(`.`), }, + }, + }, + }, + DataSource: stockService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data", "_entities"}, + }, + }, "topProducts", ArrayPath("topProducts")), + ), + SingleWithPath(&BatchEntityFetch{ + Input: BatchInput{ + Header: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name}}}","variables":{"representations":[`), + SegmentType: StaticSegmentType, + }, + }, + }, + Items: []InputTemplate{ + { + Segments: []TemplateSegment{ { - SegmentType: VariableSegmentType, - VariableKind: ContextVariableKind, - VariableSourcePath: []string{"b"}, - Renderer: NewPlainVariableRenderer(), + SegmentType: VariableSegmentType, + VariableKind: ResolvableObjectVariableKind, + Renderer: NewGraphQLVariableResolveRenderer(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("id"), + Value: &String{ + Path: []string{"id"}, + }, + }, + }, + }), }, }, }, }, + Separator: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`,`), + SegmentType: StaticSegmentType, + }, + }, + }, + Footer: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`]}}}`), + SegmentType: StaticSegmentType, + }, + }, + }, }, - }, - Response: &GraphQLResponse{ - Data: &Object{ - Fields: []*Field{ - { - Name: []byte("oneUserByID"), - Value: &Object{ - Fields: []*Field{ - { - Name: []byte("id"), - Value: &String{ - Path: []string{"id"}, + DataSource: usersService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data", "_entities"}, + }, + }, "topProducts.@.reviews.@.author", ArrayPath("topProducts"), ArrayPath("reviews"), ObjectPath("author")), + ), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("topProducts"), + Value: &Array{ + Path: []string{"topProducts"}, + Item: &Object{ + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + }, + }, + { + Name: []byte("stock"), + Value: &Integer{ + Path: []string{"stock"}, + }, + }, + { + Name: []byte("reviews"), + Value: &Array{ + Path: []string{"reviews"}, + Item: &Object{ + Fields: []*Field{ + { + Name: []byte("body"), + Value: &String{ + Path: []string{"body"}, + }, + }, + { + Name: []byte("author"), + Value: &Object{ + Path: []string{"author"}, + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + }, + 
}, + }, + }, + }, + }, }, }, }, @@ -5961,41 +6640,53 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { }, }, }, - } + }, + Info: &GraphQLResponseInfo{ + OperationType: ast.OperationTypeQuery, + }, + } - out := &SubscriptionRecorder{ - buf: &bytes.Buffer{}, - messages: []string{}, - complete: atomic.Bool{}, - } - out.complete.Store(false) + expected := []byte(`{"data":{"topProducts":[{"name":"Table","stock":8,"reviews":[{"body":"Love Table!","author":{"name":"user-1"}},{"body":"Prefer other Table.","author":{"name":"user-2"}}]},{"name":"Couch","stock":2,"reviews":[{"body":"Couch Too expensive.","author":{"name":"user-1"}}]},{"name":"Chair","stock":5,"reviews":[{"body":"Chair Could be better.","author":{"name":"user-2"}}]}]}}`) - id := SubscriptionIdentifier{ - ConnectionID: 1, - SubscriptionID: 1, - } + pool := sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(make([]byte, 0, 1024)) + }, + } - resolver := newResolver(c) + ctxPool := sync.Pool{ + New: func() interface{} { + return NewContext(context.Background()) + }, + } - ctx := &Context{ - ctx: context.Background(), - Variables: astjson.MustParseBytes([]byte(`{"a":[1,2],"b":[3,4]}`)), - } + b.ReportAllocs() + b.SetBytes(int64(len(expected))) + b.ResetTimer() - err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, out, id) - assert.NoError(t, err) - out.AwaitComplete(t, defaultTimeout) - assert.Equal(t, 4, len(out.Messages())) - assert.ElementsMatch(t, []string{ - `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, - `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, - `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, - `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, - }, out.Messages()) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + ctx := ctxPool.Get().(*Context) + buf := pool.Get().(*bytes.Buffer) + ctx.ctx = context.Background() + _, err := resolver.ResolveGraphQLResponse(ctx, plan, nil, buf) + if err != nil { + b.Fatal(err) + } + if !bytes.Equal(expected, buf.Bytes()) { + require.Equal(b, string(expected), buf.String()) + } + + buf.Reset() + pool.Put(buf) + + ctx.Free() + ctxPool.Put(ctx) + } }) } -func Benchmark_NestedBatching(b *testing.B) { +func Benchmark_NestedBatchingArena(b *testing.B) { rCtx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -6293,7 +6984,7 @@ func Benchmark_NestedBatching(b *testing.B) { ctx := ctxPool.Get().(*Context) buf := pool.Get().(*bytes.Buffer) ctx.ctx = context.Background() - _, err := resolver.ResolveGraphQLResponse(ctx, plan, nil, buf) + _, err := resolver.ArenaResolveGraphQLResponse(ctx, plan, buf) if err != nil { b.Fatal(err) } @@ -6310,7 +7001,7 @@ func Benchmark_NestedBatching(b *testing.B) { }) } -func Benchmark_NestedBatchingWithoutChecks(b *testing.B) { +func Benchmark_NoCheckNestedBatching(b *testing.B) { rCtx, cancel := context.WithCancel(context.Background()) defer cancel() From 3142c9011da0c0e587c6464fe8df2f7c13b620bf Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Wed, 15 Oct 2025 16:28:42 +0200 Subject: [PATCH 13/61] chore: implement weak arena pool --- v2/pkg/engine/resolve/loader.go | 4 +++ v2/pkg/engine/resolve/resolve.go | 60 +++++++++++++++++++++++++++++--- 2 files changed, 60 insertions(+), 4 deletions(-) diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 1bab9779b..70626cbe4 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -60,6 
+60,10 @@ type ResponseInfo struct { responseBody []byte } +func (r *ResponseInfo) GetResponseBody() string { + return string(r.responseBody) +} + func newResponseInfo(res *result, subgraphError error) *ResponseInfo { responseInfo := &ResponseInfo{ StatusCode: res.statusCode, diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 4a0075f6b..01417606f 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -7,7 +7,9 @@ import ( "context" "fmt" "io" + "sync" "time" + "weak" "github.com/buger/jsonparser" "github.com/pkg/errors" @@ -70,6 +72,14 @@ type Resolver struct { heartbeatInterval time.Duration // maxSubscriptionFetchTimeout defines the maximum time a subscription fetch can take before it is considered timed out maxSubscriptionFetchTimeout time.Duration + + arenaPool []weak.Pointer[arenaPoolItem] + arenaSize map[uint64]int + arenaPoolMu sync.Mutex +} + +type arenaPoolItem struct { + jsonArena arena.Arena } func (r *Resolver) SetAsyncErrorWriter(w AsyncErrorWriter) { @@ -229,6 +239,8 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { resolver.maxConcurrency <- struct{}{} } + resolver.arenaSize = make(map[uint64]int) + go resolver.processEvents() return resolver @@ -292,6 +304,46 @@ func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLRespons return resp, err } +func (r *Resolver) acquireArena(id uint64) *arenaPoolItem { + r.arenaPoolMu.Lock() + defer r.arenaPoolMu.Unlock() + + for i := 0; i < len(r.arenaPool); i++ { + v := r.arenaPool[i].Value() + r.arenaPool = append(r.arenaPool[:i], r.arenaPool[i+1:]...) + if v == nil { + continue + } + return v + } + + size := arena.WithMinBufferSize(r.getArenaSize(id)) + + return &arenaPoolItem{ + jsonArena: arena.NewMonotonicArena(size), + } +} + +func (r *Resolver) getArenaSize(id uint64) int { + if size, ok := r.arenaSize[id]; ok { + return size + } + return 1024 * 1024 +} + +func (r *Resolver) releaseArena(id uint64, item *arenaPoolItem) { + peak := item.jsonArena.Peak() + item.jsonArena.Reset() + + r.arenaPoolMu.Lock() + defer r.arenaPoolMu.Unlock() + + r.arenaSize[id] = peak + + w := weak.Make(item) + r.arenaPool = append(r.arenaPool, w) +} + func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLResponse, writer io.Writer) (*GraphQLResolveInfo, error) { resp := &GraphQLResolveInfo{} @@ -304,10 +356,10 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields) - jsonArena := arena.NewMonotonicArena() - defer jsonArena.Release() - t.loader.jsonArena = jsonArena - t.resolvable.astjsonArena = jsonArena + poolItem := r.acquireArena(ctx.Request.ID) + defer r.releaseArena(ctx.Request.ID, poolItem) + t.loader.jsonArena = poolItem.jsonArena + t.resolvable.astjsonArena = poolItem.jsonArena err := t.resolvable.Init(ctx, nil, response.Info.OperationType) if err != nil { From 1c9b87758cc4e212a8f66a3233cf24c218119911 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Wed, 15 Oct 2025 17:49:34 +0200 Subject: [PATCH 14/61] chore: default buffer size --- v2/pkg/engine/datasource/httpclient/nethttpclient.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index 0eb4360fa..30b01f012 100644 --- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go 
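A note on the acquire/release pair above: released arenas are retained only as weak.Pointer values, so the garbage collector stays free to reclaim them under memory pressure, while dead pointers are simply skipped on the next acquire. Below is a minimal, self-contained sketch of the same pattern with a plain *bytes.Buffer standing in for the arena — hypothetical names, not the engine's API; requires Go 1.24+ for the weak package:

    package main

    import (
        "bytes"
        "fmt"
        "sync"
        "weak"
    )

    type weakPool struct {
        mu   sync.Mutex
        pool []weak.Pointer[bytes.Buffer]
    }

    // acquire returns a pooled buffer if one is still alive; otherwise it
    // allocates a new one. Dead weak pointers are discarded as they are found.
    func (p *weakPool) acquire() *bytes.Buffer {
        p.mu.Lock()
        defer p.mu.Unlock()
        for len(p.pool) > 0 {
            last := len(p.pool) - 1
            v := p.pool[last].Value() // nil if the buffer was garbage collected
            p.pool = p.pool[:last]
            if v != nil {
                return v
            }
        }
        return bytes.NewBuffer(make([]byte, 0, 1024))
    }

    // release resets the buffer and stores only a weak reference, so pooled
    // buffers never pin memory the way a strong slice of pointers would.
    func (p *weakPool) release(b *bytes.Buffer) {
        b.Reset()
        p.mu.Lock()
        p.pool = append(p.pool, weak.Make(b))
        p.mu.Unlock()
    }

    func main() {
        p := &weakPool{}
        b := p.acquire()
        b.WriteString("hello")
        fmt.Println(b.String())
        p.release(b)
    }

Unlike the pooled arenas themselves, the per-request-ID size map in the real code survives collection, so a freshly allocated arena can still be pre-sized from the peak usage observed for that operation.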
@@ -262,8 +262,9 @@ func Do(client *http.Client, ctx context.Context, requestInput []byte) (data []b pool.Hash64.Put(h) ctx = context.WithValue(ctx, bodyHashContextKey{}, bodyHash) - var buf bytes.Buffer - err = makeHTTPRequest(client, ctx, url, method, headers, queryParams, bytes.NewReader(body), enableTrace, &buf, ContentTypeJSON) + buf := bytes.NewBuffer(make([]byte, 0, 1024*4)) + + err = makeHTTPRequest(client, ctx, url, method, headers, queryParams, bytes.NewReader(body), enableTrace, buf, ContentTypeJSON) if err != nil { return nil, err } @@ -333,8 +334,9 @@ func DoMultipartForm( bodyHash := h.Sum64() ctx = context.WithValue(ctx, bodyHashContextKey{}, bodyHash) - var buf bytes.Buffer - err = makeHTTPRequest(client, ctx, url, method, headers, queryParams, multipartBody, enableTrace, &buf, contentType) + buf := bytes.NewBuffer(make([]byte, 0, 1024*4)) + + err = makeHTTPRequest(client, ctx, url, method, headers, queryParams, multipartBody, enableTrace, buf, contentType) if err != nil { return nil, err } From 112171e9515440da04ab8d06c7ff267c70aaa5ad Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 16 Oct 2025 23:37:08 +0200 Subject: [PATCH 15/61] chore: move single flight into loader --- .../datasource/httpclient/nethttpclient.go | 81 +++--------- v2/pkg/engine/resolve/loader.go | 119 ++++++++++++------ v2/pkg/engine/resolve/resolvable.go | 19 ++- v2/pkg/engine/resolve/resolve.go | 17 ++- v2/pkg/engine/resolve/singleflight.go | 86 +++++++++++++ 5 files changed, 214 insertions(+), 108 deletions(-) create mode 100644 v2/pkg/engine/resolve/singleflight.go diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index 30b01f012..3fa74b949 100644 --- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -20,7 +20,6 @@ import ( "github.com/buger/jsonparser" "github.com/wundergraph/graphql-go-tools/v2/pkg/lexer/literal" - "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) const ( @@ -130,21 +129,11 @@ func respBodyReader(res *http.Response) (io.Reader, error) { } } -type bodyHashContextKey struct{} - -func BodyHashFromContext(ctx context.Context) (uint64, bool) { - value := ctx.Value(bodyHashContextKey{}) - if value == nil { - return 0, false - } - return value.(uint64), true -} - -func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, headers, queryParams []byte, body io.Reader, enableTrace bool, out *bytes.Buffer, contentType string) (err error) { +func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, headers, queryParams []byte, body io.Reader, enableTrace bool, contentType string) ([]byte, error) { request, err := http.NewRequestWithContext(ctx, string(method), string(url), body) if err != nil { - return err + return nil, err } if headers != nil { @@ -161,7 +150,7 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head return err }) if err != nil { - return err + return nil, err } } @@ -190,7 +179,7 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head } }) if err != nil { - return err + return nil, err } request.URL.RawQuery = query.Encode() } @@ -204,7 +193,7 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head response, err := client.Do(request) if err != nil { - return err + return nil, err } defer response.Body.Close() @@ -212,23 +201,20 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head 
respReader, err := respBodyReader(response) if err != nil { - return err + return nil, err } - if !enableTrace { - if response.ContentLength > 0 { - out.Grow(int(response.ContentLength)) - } else { - out.Grow(1024 * 4) - } - _, err = out.ReadFrom(respReader) - return + out := bytes.NewBuffer(make([]byte, 0, 1024*4)) + _, err = out.ReadFrom(respReader) + if err != nil { + return nil, err } - data, err := io.ReadAll(respReader) - if err != nil { - return err + if !enableTrace { + return out.Bytes(), nil } + + data := out.Bytes() responseTrace := TraceHTTP{ Request: TraceHTTPRequest{ Method: request.Method, @@ -244,31 +230,18 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head } trace, err := json.Marshal(responseTrace) if err != nil { - return err + return nil, err } responseWithTraceExtension, err := jsonparser.Set(data, trace, "extensions", "trace") if err != nil { - return err + return nil, err } - _, err = out.Write(responseWithTraceExtension) - return err + return responseWithTraceExtension, nil } func Do(client *http.Client, ctx context.Context, requestInput []byte) (data []byte, err error) { url, method, body, headers, queryParams, enableTrace := requestInputParams(requestInput) - h := pool.Hash64.Get() - _, _ = h.Write(body) - bodyHash := h.Sum64() - pool.Hash64.Put(h) - ctx = context.WithValue(ctx, bodyHashContextKey{}, bodyHash) - - buf := bytes.NewBuffer(make([]byte, 0, 1024*4)) - - err = makeHTTPRequest(client, ctx, url, method, headers, queryParams, bytes.NewReader(body), enableTrace, buf, ContentTypeJSON) - if err != nil { - return nil, err - } - return buf.Bytes(), nil + return makeHTTPRequest(client, ctx, url, method, headers, queryParams, bytes.NewReader(body), enableTrace, ContentTypeJSON) } func DoMultipartForm( @@ -280,10 +253,6 @@ func DoMultipartForm( url, method, body, headers, queryParams, enableTrace := requestInputParams(requestInput) - h := pool.Hash64.Get() - defer pool.Hash64.Put(h) - _, _ = h.Write(body) - formValues := map[string]io.Reader{ "operations": bytes.NewReader(body), } @@ -300,10 +269,9 @@ func DoMultipartForm( } hasWrittenFileName = true - fmt.Fprintf(fileMap, `"%d":["%s"]`, i, file.variablePath) + _, _ = fmt.Fprintf(fileMap, `"%d":["%s"]`, i, file.variablePath) key := fmt.Sprintf("%d", i) - _, _ = h.WriteString(file.Path()) temporaryFile, err := os.Open(file.Path()) tempFiles = append(tempFiles, temporaryFile) if err != nil { @@ -331,16 +299,7 @@ func DoMultipartForm( } }() - bodyHash := h.Sum64() - ctx = context.WithValue(ctx, bodyHashContextKey{}, bodyHash) - - buf := bytes.NewBuffer(make([]byte, 0, 1024*4)) - - err = makeHTTPRequest(client, ctx, url, method, headers, queryParams, multipartBody, enableTrace, buf, contentType) - if err != nil { - return nil, err - } - return buf.Bytes(), nil + return makeHTTPRequest(client, ctx, url, method, headers, queryParams, multipartBody, enableTrace, contentType) } func multipartBytes(values map[string]io.Reader, files []*FileUpload) (*io.PipeReader, string, error) { diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 70626cbe4..703119053 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -139,6 +139,7 @@ type result struct { httpResponseContext *httpclient.ResponseContext out []byte + singleFlightStats *singleFlightStats } func (r *result) init(postProcessing PostProcessingConfiguration, info *FetchInfo) { @@ -183,6 +184,7 @@ type Loader struct { taintedObjs taintedObjects jsonArena arena.Arena + sf 
*SingleFlight } func (l *Loader) Free() { @@ -772,6 +774,7 @@ func (l *Loader) mergeErrors(res *result, fetchItem *FetchItem, value *astjson.V } // If the error propagation mode is pass-through, we append the errors to the root array + l.resolvable.ensureErrorsInitialized() l.resolvable.errors.AppendArrayItems(value) return nil } @@ -808,6 +811,7 @@ func (l *Loader) mergeErrors(res *result, fetchItem *FetchItem, value *astjson.V return err } + l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil @@ -1062,6 +1066,7 @@ func (l *Loader) addApolloRouterCompatibilityError(res *result) error { return err } + l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, apolloRouterStatusError) return nil @@ -1075,6 +1080,7 @@ func (l *Loader) renderErrorsFailedDeps(fetchItem *FetchItem, res *result) error return err } l.setSubgraphStatusCode([]*astjson.Value{errorObject}, res.statusCode) + l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil } @@ -1086,6 +1092,7 @@ func (l *Loader) renderErrorsFailedToFetch(fetchItem *FetchItem, res *result, re return err } l.setSubgraphStatusCode([]*astjson.Value{errorObject}, res.statusCode) + l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil } @@ -1104,7 +1111,7 @@ func (l *Loader) renderErrorsStatusFallback(fetchItem *FetchItem, res *result, s } l.setSubgraphStatusCode([]*astjson.Value{errorObject}, res.statusCode) - + l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil } @@ -1129,6 +1136,7 @@ func (l *Loader) renderAuthorizationRejectedErrors(fetchItem *FetchItem, res *re } pathPart := l.renderAtPathErrorPart(fetchItem.ResponsePath) extensionErrorCode := fmt.Sprintf(`"extensions":{"code":"%s"}`, errorcodes.UnauthorizedFieldOrType) + l.resolvable.ensureErrorsInitialized() if res.ds.Name == "" { for _, reason := range res.authorizationRejectedReasons { if reason == "" { @@ -1207,6 +1215,7 @@ func (l *Loader) renderRateLimitRejectedErrors(fetchItem *FetchItem, res *result return err } } + l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil } @@ -1598,29 +1607,8 @@ func redactHeaders(rawJSON json.RawMessage) (json.RawMessage, error) { return redactedJSON, nil } -type disallowSingleFlightContextKey struct{} - -func SingleFlightDisallowed(ctx context.Context) bool { - return ctx.Value(disallowSingleFlightContextKey{}) != nil -} - -type singleFlightStatsKey struct{} - -type SingleFlightStats struct { - SingleFlightUsed bool - SingleFlightSharedResponse bool -} - -func GetSingleFlightStats(ctx context.Context) *SingleFlightStats { - maybeStats := ctx.Value(singleFlightStatsKey{}) - if maybeStats == nil { - return nil - } - return maybeStats.(*SingleFlightStats) -} - -func setSingleFlightStats(ctx context.Context, stats *SingleFlightStats) context.Context { - return context.WithValue(ctx, singleFlightStatsKey{}, stats) +type singleFlightStats struct { + used, shared bool } func (l *Loader) setTracingInput(fetchItem *FetchItem, input []byte, trace *DataSourceLoadTrace) { @@ -1636,7 +1624,70 @@ func (l *Loader) setTracingInput(fetchItem *FetchItem, input []byte, trace *Data } } -func (l *Loader) loadByContext(ctx context.Context, source DataSource, input []byte, res *result) error { +type loaderContextKey string + +const ( + operationTypeContextKey loaderContextKey = "operationType" 
+) + +func GetOperationTypeFromContext(ctx context.Context) ast.OperationType { + if ctx == nil { + return ast.OperationTypeQuery + } + if v := ctx.Value(operationTypeContextKey); v != nil { + if opType, ok := v.(ast.OperationType); ok { + return opType + } + } + return ast.OperationTypeQuery +} + +func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem *FetchItem, input []byte, res *result) error { + + if l.info != nil { + ctx = context.WithValue(ctx, operationTypeContextKey, l.info.OperationType) + } + + if l.info == nil || l.info.OperationType == ast.OperationTypeMutation { + // Disable single flight for mutations + return l.loadByContextDirect(ctx, source, input, res) + } + + key, item, shared := l.sf.GetOrCreateItem(ctx, fetchItem, input) + if res.singleFlightStats != nil { + res.singleFlightStats.used = shared + res.singleFlightStats.shared = shared + } + + if shared { + select { + case <-item.loaded: + case <-ctx.Done(): + return ctx.Err() + } + + if item.err != nil { + return item.err + } + + res.out = item.response + return nil + } + + defer l.sf.Finish(key, item) + + // Perform the actual load + err := l.loadByContextDirect(ctx, source, input, res) + if err != nil { + item.err = err + return err + } + + item.response = res.out + return nil +} + +func (l *Loader) loadByContextDirect(ctx context.Context, source DataSource, input []byte, res *result) error { if l.ctx.Files != nil { res.out, res.err = source.LoadWithFiles(ctx, input, l.ctx.Files) } else { @@ -1674,7 +1725,7 @@ func (l *Loader) executeSourceLoad(ctx context.Context, fetchItem *FetchItem, so } } if l.ctx.TracingOptions.Enable { - ctx = setSingleFlightStats(ctx, &SingleFlightStats{}) + res.singleFlightStats = &singleFlightStats{} trace.Path = fetchItem.ResponsePath if !l.ctx.TracingOptions.ExcludeInput { trace.Input = make([]byte, len(input)) @@ -1778,9 +1829,6 @@ func (l *Loader) executeSourceLoad(ctx context.Context, fetchItem *FetchItem, so ctx = httptrace.WithClientTrace(ctx, clientTrace) } } - if l.info != nil && l.info.OperationType == ast.OperationTypeMutation { - ctx = context.WithValue(ctx, disallowSingleFlightContextKey{}, true) - } var responseContext *httpclient.ResponseContext ctx, responseContext = httpclient.InjectResponseContext(ctx) @@ -1789,24 +1837,23 @@ func (l *Loader) executeSourceLoad(ctx context.Context, fetchItem *FetchItem, so // Prevent that the context is destroyed when the loader hook return an empty context if res.loaderHookContext != nil { - res.err = l.loadByContext(res.loaderHookContext, source, input, res) + res.err = l.loadByContext(res.loaderHookContext, source, fetchItem, input, res) } else { - res.err = l.loadByContext(ctx, source, input, res) + res.err = l.loadByContext(ctx, source, fetchItem, input, res) res.loaderHookContext = ctx // Set the context to the original context to ensure that OnFinished hook gets valid context } } else { - res.err = l.loadByContext(ctx, source, input, res) + res.err = l.loadByContext(ctx, source, fetchItem, input, res) } res.statusCode = responseContext.StatusCode res.httpResponseContext = responseContext if l.ctx.TracingOptions.Enable { - stats := GetSingleFlightStats(ctx) - if stats != nil { - trace.SingleFlightUsed = stats.SingleFlightUsed - trace.SingleFlightSharedResponse = stats.SingleFlightSharedResponse + if res.singleFlightStats != nil { + trace.SingleFlightUsed = res.singleFlightStats.used + trace.SingleFlightSharedResponse = res.singleFlightStats.shared } if !l.ctx.TracingOptions.ExcludeOutput && len(res.out) > 0 { 
trace.Output, _ = l.compactJSON(res.out) diff --git a/v2/pkg/engine/resolve/resolvable.go b/v2/pkg/engine/resolve/resolvable.go index 5aceb2110..21470f475 100644 --- a/v2/pkg/engine/resolve/resolvable.go +++ b/v2/pkg/engine/resolve/resolvable.go @@ -111,7 +111,7 @@ func (r *Resolvable) Init(ctx *Context, initialData []byte, operationType ast.Op r.operationType = operationType r.renameTypeNames = ctx.RenameTypeNames r.data = astjson.ObjectValue(r.astjsonArena) - r.errors = astjson.ArrayValue(r.astjsonArena) + r.errors = nil if initialData != nil { initialValue, err := astjson.ParseBytesWithArena(r.astjsonArena, initialData) if err != nil { @@ -129,6 +129,7 @@ func (r *Resolvable) InitSubscription(ctx *Context, initialData []byte, postProc r.ctx = ctx r.operationType = ast.OperationTypeSubscription r.renameTypeNames = ctx.RenameTypeNames + r.errors = nil if initialData != nil { initialValue, err := astjson.ParseBytesWithArena(r.astjsonArena, initialData) if err != nil { @@ -158,9 +159,6 @@ func (r *Resolvable) InitSubscription(ctx *Context, initialData []byte, postProc if r.data == nil { r.data = astjson.ObjectValue(r.astjsonArena) } - if r.errors == nil { - r.errors = astjson.ArrayValue(r.astjsonArena) - } return } @@ -169,7 +167,7 @@ func (r *Resolvable) ResolveNode(node Node, data *astjson.Value, out io.Writer) r.print = false r.printErr = nil r.authorizationError = nil - r.errors = astjson.ArrayValue(r.astjsonArena) + r.errors = nil hasErrors := r.walkNode(node, data) if hasErrors { @@ -235,6 +233,12 @@ func (r *Resolvable) Resolve(ctx context.Context, rootData *Object, fetchTree *F return r.printErr } +func (r *Resolvable) ensureErrorsInitialized() { + if r.errors == nil { + r.errors = astjson.ArrayValue(r.astjsonArena) + } +} + func (r *Resolvable) enclosingTypeName() string { if len(r.enclosingTypeNames) > 0 { return r.enclosingTypeNames[len(r.enclosingTypeNames)-1] @@ -761,6 +765,7 @@ func (r *Resolvable) addRejectFieldError(reason string, ds DataSourceInfo, field } r.ctx.appendSubgraphErrors(errors.New(errorMessage), NewSubgraphError(ds, fieldPath, reason, 0)) + r.ensureErrorsInitialized() fastjsonext.AppendErrorWithExtensionsCodeToArray(r.astjsonArena, r.errors, errorMessage, errorcodes.UnauthorizedFieldOrType, r.path) r.popNodePathElement(nodePath) } @@ -1202,6 +1207,7 @@ func (r *Resolvable) addNonNullableFieldError(fieldPath []string, parent *astjso r.addValueCompletion(r.renderApolloCompatibleNonNullableErrorMessage(), errorcodes.InvalidGraphql) } else { errorMessage := fmt.Sprintf("Cannot return null for non-nullable field '%s'.", r.renderFieldPath()) + r.ensureErrorsInitialized() fastjsonext.AppendErrorToArray(r.astjsonArena, r.errors, errorMessage, r.path) } r.popNodePathElement(fieldPath) @@ -1272,16 +1278,19 @@ func (r *Resolvable) renderFieldCoordinates() string { func (r *Resolvable) addError(message string, fieldPath []string) { r.pushNodePathElement(fieldPath) + r.ensureErrorsInitialized() fastjsonext.AppendErrorToArray(r.astjsonArena, r.errors, message, r.path) r.popNodePathElement(fieldPath) } func (r *Resolvable) addErrorWithCode(message, code string) { + r.ensureErrorsInitialized() fastjsonext.AppendErrorWithExtensionsCodeToArray(r.astjsonArena, r.errors, message, code, r.path) } func (r *Resolvable) addErrorWithCodeAndPath(message, code string, fieldPath []string) { r.pushNodePathElement(fieldPath) + r.ensureErrorsInitialized() fastjsonext.AppendErrorWithExtensionsCodeToArray(r.astjsonArena, r.errors, message, code, r.path) r.popNodePathElement(fieldPath) } diff 
--git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 01417606f..eef77b5b8 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -76,6 +76,9 @@ type Resolver struct { arenaPool []weak.Pointer[arenaPoolItem] arenaSize map[uint64]int arenaPoolMu sync.Mutex + + // Single flight cache for deduplicating requests across all loaders + sf *SingleFlight } type arenaPoolItem struct { @@ -233,6 +236,7 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { allowedErrorFields: allowedErrorFields, heartbeatInterval: options.SubscriptionHeartbeatInterval, maxSubscriptionFetchTimeout: options.MaxSubscriptionFetchTimeout, + sf: NewSingleFlight(), } resolver.maxConcurrency = make(chan struct{}, options.MaxConcurrency) for i := 0; i < options.MaxConcurrency; i++ { @@ -246,7 +250,7 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { return resolver } -func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{}, allowedErrorFields map[string]struct{}) *tools { +func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{}, allowedErrorFields map[string]struct{}, sf *SingleFlight) *tools { return &tools{ resolvable: NewResolvable(nil, options.ResolvableOptions), loader: &Loader{ @@ -264,6 +268,7 @@ func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{ apolloRouterCompatibilitySubrequestHTTPError: options.ApolloRouterCompatibilitySubrequestHTTPError, propagateFetchReasons: options.PropagateFetchReasons, validateRequiredExternalFields: options.ValidateRequiredExternalFields, + sf: sf, }, } } @@ -282,7 +287,7 @@ func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLRespons r.maxConcurrency <- struct{}{} }() - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf) err := t.resolvable.Init(ctx, data, response.Info.OperationType) if err != nil { @@ -354,7 +359,7 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe r.maxConcurrency <- struct{}{} }() - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf) poolItem := r.acquireArena(ctx.Request.ID) defer r.releaseArena(ctx.Request.ID, poolItem) @@ -511,7 +516,7 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar input := make([]byte, len(sharedInput)) copy(input, sharedInput) - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf) if err := t.resolvable.InitSubscription(resolveCtx, input, sub.resolve.Trigger.PostProcessing); err != nil { r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) @@ -1104,7 +1109,7 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ // If SkipLoader is enabled, we skip retrieving actual data. For example, this is useful when requesting a query plan. // By returning early, we avoid starting a subscription and resolve with empty data instead. 
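Two details of the single-flight path in loadByContext are worth spelling out: item creation uses double-checked locking (a cheap read lock first, then a write lock with a re-check), and result sharing is synchronized purely by closing the loaded channel, which also lets waiters bail out on context cancellation; mutations never enter this path. A minimal sketch of the pattern, simplified and with hypothetical names:

    package main

    import (
        "context"
        "fmt"
        "sync"
        "time"
    )

    type flightItem struct {
        loaded   chan struct{}
        response []byte
        err      error
    }

    type flight struct {
        mu    sync.RWMutex
        items map[string]*flightItem
    }

    // getOrCreate uses double-checked locking so the common shared path only
    // takes the read lock; the write lock re-checks before inserting.
    func (f *flight) getOrCreate(key string) (*flightItem, bool) {
        f.mu.RLock()
        item, ok := f.items[key]
        f.mu.RUnlock()
        if ok {
            return item, true
        }
        f.mu.Lock()
        defer f.mu.Unlock()
        if item, ok = f.items[key]; ok {
            return item, true
        }
        item = &flightItem{loaded: make(chan struct{})}
        f.items[key] = item
        return item, false
    }

    func (f *flight) do(ctx context.Context, key string, load func() ([]byte, error)) ([]byte, error) {
        item, shared := f.getOrCreate(key)
        if shared {
            select {
            case <-item.loaded: // closed by the loading goroutine
                return item.response, item.err
            case <-ctx.Done():
                return nil, ctx.Err()
            }
        }
        item.response, item.err = load()
        f.mu.Lock()
        delete(f.items, key) // only in-flight requests share results; nothing is cached
        f.mu.Unlock()
        close(item.loaded) // publishes response/err to all waiters
        return item.response, item.err
    }

    func main() {
        f := &flight{items: map[string]*flightItem{}}
        var wg sync.WaitGroup
        for i := 0; i < 3; i++ {
            wg.Add(1)
            go func() {
                defer wg.Done()
                out, _ := f.do(context.Background(), "users", func() ([]byte, error) {
                    time.Sleep(10 * time.Millisecond) // simulate one subgraph request
                    return []byte(`{"data":{}}`), nil
                })
                fmt.Println(string(out))
            }()
        }
        wg.Wait()
    }

Writing item.response before close(item.loaded) is what makes the unsynchronized reads by waiters safe: the channel close establishes the happens-before edge.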
if ctx.ExecutionOptions.SkipLoader { - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf) err = t.resolvable.InitSubscription(ctx, nil, subscription.Trigger.PostProcessing) if err != nil { @@ -1213,7 +1218,7 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G // If SkipLoader is enabled, we skip retrieving actual data. For example, this is useful when requesting a query plan. // By returning early, we avoid starting a subscription and resolve with empty data instead. if ctx.ExecutionOptions.SkipLoader { - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf) err = t.resolvable.InitSubscription(ctx, nil, subscription.Trigger.PostProcessing) if err != nil { diff --git a/v2/pkg/engine/resolve/singleflight.go b/v2/pkg/engine/resolve/singleflight.go new file mode 100644 index 000000000..7843bafec --- /dev/null +++ b/v2/pkg/engine/resolve/singleflight.go @@ -0,0 +1,86 @@ +package resolve + +import ( + "context" + "sync" + + "github.com/cespare/xxhash/v2" +) + +type SingleFlightItem struct { + loaded chan struct{} + response []byte + err error +} + +type SingleFlight struct { + mu *sync.RWMutex + items map[uint64]*SingleFlightItem + xxPool *sync.Pool + cleanup chan func() +} + +func NewSingleFlight() *SingleFlight { + return &SingleFlight{ + items: make(map[uint64]*SingleFlightItem), + mu: new(sync.RWMutex), + xxPool: &sync.Pool{ + New: func() any { + return xxhash.New() + }, + }, + cleanup: make(chan func()), + } +} + +func (s *SingleFlight) GetOrCreateItem(ctx context.Context, fetchItem *FetchItem, input []byte) (key uint64, item *SingleFlightItem, shared bool) { + key = s.key(fetchItem, input) + + // First, try to get the item with a read lock + s.mu.RLock() + item, exists := s.items[key] + s.mu.RUnlock() + if exists { + return key, item, true + } + + // If not exists, acquire a write lock to create the item + s.mu.Lock() + // Double-check if the item was created while acquiring the write lock + item, exists = s.items[key] + if exists { + s.mu.Unlock() + return key, item, true + } + + // Create a new item + item = &SingleFlightItem{ + loaded: make(chan struct{}), + } + s.items[key] = item + s.mu.Unlock() + return key, item, false +} + +func (s *SingleFlight) key(fetchItem *FetchItem, input []byte) uint64 { + h := s.xxPool.Get().(*xxhash.Digest) + if fetchItem != nil && fetchItem.Fetch != nil { + info := fetchItem.Fetch.FetchInfo() + if info != nil { + _, _ = h.WriteString(info.DataSourceID) + _, _ = h.WriteString(":") + } + } + _, _ = h.Write(input) + key := h.Sum64() + h.Reset() + s.xxPool.Put(h) + return key +} + +func (s *SingleFlight) Finish(key uint64, item *SingleFlightItem) { + close(item.loaded) + s.mu.Lock() + delete(s.items, key) + s.mu.Unlock() +} From 7a777ea9f163206b40c9a85623931e0526a05b58 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Fri, 17 Oct 2025 00:09:00 +0200 Subject: [PATCH 16/61] chore: add http client buffer size hint --- .../datasource/httpclient/nethttpclient.go | 19 ++++- v2/pkg/engine/resolve/loader.go | 6 +- v2/pkg/engine/resolve/singleflight.go | 81 +++++++++++++++---- 3 files changed, 88 insertions(+), 18 deletions(-) diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index 3fa74b949..27b0434c1 100644 --- 
a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -129,6 +129,23 @@ func respBodyReader(res *http.Response) (io.Reader, error) { } } +type httpClientContext string + +const ( + sizeHintKey httpClientContext = "size-hint" +) + +func WithHTTPClientSizeHint(ctx context.Context, size int) context.Context { + return context.WithValue(ctx, sizeHintKey, size) +} + +func buffer(ctx context.Context) *bytes.Buffer { + if sizeHint, ok := ctx.Value(sizeHintKey).(int); ok && sizeHint > 0 { + return bytes.NewBuffer(make([]byte, 0, sizeHint)) + } + return bytes.NewBuffer(make([]byte, 0, 1024*4)) // default to 4KB +} + func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, headers, queryParams []byte, body io.Reader, enableTrace bool, contentType string) ([]byte, error) { request, err := http.NewRequestWithContext(ctx, string(method), string(url), body) @@ -204,7 +221,7 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head return nil, err } - out := bytes.NewBuffer(make([]byte, 0, 1024*4)) + out := buffer(ctx) _, err = out.ReadFrom(respReader) if err != nil { return nil, err diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 703119053..2c923b2c9 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -1653,7 +1653,7 @@ func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem return l.loadByContextDirect(ctx, source, input, res) } - key, item, shared := l.sf.GetOrCreateItem(ctx, fetchItem, input) + sfKey, fetchKey, item, shared := l.sf.GetOrCreateItem(ctx, fetchItem, input) if res.singleFlightStats != nil { res.singleFlightStats.used = shared res.singleFlightStats.shared = shared @@ -1674,7 +1674,9 @@ func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem return nil } - defer l.sf.Finish(key, item) + ctx = httpclient.WithHTTPClientSizeHint(ctx, item.sizeHint) + + defer l.sf.Finish(sfKey, fetchKey, item) // Perform the actual load err := l.loadByContextDirect(ctx, source, input, res) diff --git a/v2/pkg/engine/resolve/singleflight.go b/v2/pkg/engine/resolve/singleflight.go index 7843bafec..e29853196 100644 --- a/v2/pkg/engine/resolve/singleflight.go +++ b/v2/pkg/engine/resolve/singleflight.go @@ -11,18 +11,26 @@ type SingleFlightItem struct { loaded chan struct{} response []byte err error + sizeHint int } type SingleFlight struct { mu *sync.RWMutex items map[uint64]*SingleFlightItem + sizes map[uint64]*fetchSize xxPool *sync.Pool cleanup chan func() } +type fetchSize struct { + count int + totalBytes int +} + func NewSingleFlight() *SingleFlight { return &SingleFlight{ items: make(map[uint64]*SingleFlightItem), + sizes: make(map[uint64]*fetchSize), mu: new(sync.RWMutex), xxPool: &sync.Pool{ New: func() any { @@ -33,37 +41,49 @@ func NewSingleFlight() *SingleFlight { } } -func (s *SingleFlight) GetOrCreateItem(ctx context.Context, fetchItem *FetchItem, input []byte) (key uint64, item *SingleFlightItem, shared bool) { - key = s.key(fetchItem, input) +func (s *SingleFlight) GetOrCreateItem(ctx context.Context, fetchItem *FetchItem, input []byte) (sfKey, fetchKey uint64, item *SingleFlightItem, shared bool) { + sfKey, fetchKey = s.keys(fetchItem, input) // First, try to get the item with a read lock s.mu.RLock() - item, exists := s.items[key] + item, exists := s.items[sfKey] s.mu.RUnlock() if exists { - return key, item, true + return sfKey, fetchKey, item, true } // If not exists, 
acquire a write lock to create the item s.mu.Lock() // Double-check if the item was created while acquiring the write lock - item, exists = s.items[key] + item, exists = s.items[sfKey] if exists { s.mu.Unlock() - return key, item, true + return sfKey, fetchKey, item, true } // Create a new item item = &SingleFlightItem{ loaded: make(chan struct{}), } - s.items[key] = item + if size, ok := s.sizes[fetchKey]; ok { + item.sizeHint = size.totalBytes / size.count + } + s.items[sfKey] = item s.mu.Unlock() - return key, item, false + return sfKey, fetchKey, item, false } -func (s *SingleFlight) key(fetchItem *FetchItem, input []byte) uint64 { +func (s *SingleFlight) keys(fetchItem *FetchItem, input []byte) (sfKey, fetchKey uint64) { h := s.xxPool.Get().(*xxhash.Digest) + sfKey = s.sfKey(h, fetchItem, input) + h.Reset() + fetchKey = s.fetchKey(h, fetchItem) + h.Reset() + s.xxPool.Put(h) + return sfKey, fetchKey +} + +func (s *SingleFlight) sfKey(h *xxhash.Digest, fetchItem *FetchItem, input []byte) uint64 { if fetchItem != nil && fetchItem.Fetch != nil { info := fetchItem.Fetch.FetchInfo() if info != nil { @@ -72,15 +92,46 @@ func (s *SingleFlight) key(fetchItem *FetchItem, input []byte) uint64 { } } _, _ = h.Write(input) - key := h.Sum64() - h.Reset() - s.xxPool.Put(h) - return key + return h.Sum64() } -func (s *SingleFlight) Finish(key uint64, item *SingleFlightItem) { +func (s *SingleFlight) fetchKey(h *xxhash.Digest, fetchItem *FetchItem) uint64 { + if fetchItem == nil || fetchItem.Fetch == nil { + return 0 + } + info := fetchItem.Fetch.FetchInfo() + if info == nil { + return 0 + } + _, _ = h.WriteString(info.DataSourceID) + _, _ = h.WriteString("|") + for i := range info.RootFields { + if i != 0 { + _, _ = h.WriteString(",") + } + _, _ = h.WriteString(info.RootFields[i].TypeName) + _, _ = h.WriteString(".") + _, _ = h.WriteString(info.RootFields[i].FieldName) + } + return h.Sum64() +} + +func (s *SingleFlight) Finish(sfKey, fetchKey uint64, item *SingleFlightItem) { close(item.loaded) s.mu.Lock() - delete(s.items, key) + delete(s.items, sfKey) + if size, ok := s.sizes[fetchKey]; ok { + if size.count == 50 { + size.count = 1 + size.totalBytes = size.totalBytes / 50 + } + size.count++ + size.totalBytes += len(item.response) + } else { + s.sizes[fetchKey] = &fetchSize{ + count: 1, + totalBytes: len(item.response), + } + } s.mu.Unlock() } From c41b4b6300dc500f6fea2795d3b6056b5dfbe6a1 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Fri, 17 Oct 2025 00:20:53 +0200 Subject: [PATCH 17/61] chore: selectItems on arena --- v2/pkg/engine/resolve/loader.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 2c923b2c9..23da2cbe0 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -367,7 +367,7 @@ func (l *Loader) selectItemsForPath(path []FetchItemPathElement) []*astjson.Valu if len(items) == 0 { break } - items = selectItems(items, path[i]) + items = selectItems(l.jsonArena, items, path[i]) } return l.taintedObjs.filterOutTainted(items) } @@ -388,7 +388,7 @@ func isItemAllowedByTypename(obj *astjson.Value, typeNames []string) bool { return slices.Contains(typeNames, __typeNameStr) } -func selectItems(items []*astjson.Value, element FetchItemPathElement) []*astjson.Value { +func selectItems(a arena.Arena, items []*astjson.Value, element FetchItemPathElement) []*astjson.Value { if len(items) == 0 { return nil } @@ -410,7 +410,7 @@ func selectItems(items []*astjson.Value, element 
FetchItemPathElement) []*astjso } return []*astjson.Value{field} } - selected := make([]*astjson.Value, 0, len(items)) + selected := arena.AllocateSlice[*astjson.Value](a, 0, len(items)) for _, item := range items { if !isItemAllowedByTypename(item, element.TypeNames) { continue @@ -420,10 +420,10 @@ func selectItems(items []*astjson.Value, element FetchItemPathElement) []*astjso continue } if field.Type() == astjson.TypeArray { - selected = append(selected, field.GetArray()...) + selected = arena.SliceAppend(a, selected, field.GetArray()...) continue } - selected = append(selected, field) + selected = arena.SliceAppend(a, selected, field) } return selected } From 3e1454f355faf8a1b9f4060cf41f3bd5cafa4336 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Fri, 17 Oct 2025 12:40:07 +0200 Subject: [PATCH 18/61] chore: refactor arena pool into separate file --- v2/pkg/engine/resolve/arena.go | 78 +++++++++++++++++ v2/pkg/engine/resolve/inputtemplate.go | 25 ++++-- v2/pkg/engine/resolve/loader.go | 114 +++++++------------------ v2/pkg/engine/resolve/resolve.go | 62 ++------------ 4 files changed, 131 insertions(+), 148 deletions(-) create mode 100644 v2/pkg/engine/resolve/arena.go diff --git a/v2/pkg/engine/resolve/arena.go b/v2/pkg/engine/resolve/arena.go new file mode 100644 index 000000000..1bd5ee495 --- /dev/null +++ b/v2/pkg/engine/resolve/arena.go @@ -0,0 +1,78 @@ +package resolve + +import ( + "sync" + "weak" + + "github.com/wundergraph/go-arena" +) + +// ArenaPool provides a thread-safe pool of arena.Arena instances for memory-efficient allocations. +// It uses weak pointers to allow garbage collection of unused arenas while maintaining +// a pool of reusable arenas for high-frequency allocation patterns. +type ArenaPool struct { + pool []weak.Pointer[ArenaPoolItem] + sizes map[uint64]int + mu sync.Mutex +} + +// ArenaPoolItem wraps an arena.Arena for use in the pool +type ArenaPoolItem struct { + Arena arena.Arena +} + +// NewArenaPool creates a new ArenaPool instance +func NewArenaPool() *ArenaPool { + return &ArenaPool{ + sizes: make(map[uint64]int), + } +} + +// Acquire gets an arena from the pool or creates a new one if none are available. +// The id parameter is used to track arena sizes per use case for optimization. +func (p *ArenaPool) Acquire(id uint64) *ArenaPoolItem { + p.mu.Lock() + defer p.mu.Unlock() + + // Try to find an available arena in the pool + for i := 0; i < len(p.pool); i++ { + v := p.pool[i].Value() + p.pool = append(p.pool[:i], p.pool[i+1:]...) + if v == nil { + continue + } + return v + } + + // No arena available, create a new one + size := arena.WithMinBufferSize(p.getArenaSize(id)) + return &ArenaPoolItem{ + Arena: arena.NewMonotonicArena(size), + } +} + +// Release returns an arena to the pool for reuse. +// The peak memory usage is recorded to optimize future arena sizes for this use case. +func (p *ArenaPool) Release(id uint64, item *ArenaPoolItem) { + peak := item.Arena.Peak() + item.Arena.Reset() + + p.mu.Lock() + defer p.mu.Unlock() + + // Record the peak usage for this use case + p.sizes[id] = peak + + // Add the arena back to the pool using a weak pointer + w := weak.Make(item) + p.pool = append(p.pool, w) +} + +// getArenaSize returns the optimal arena size for a given use case ID. +// If no size is recorded, it defaults to 1MB. 
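The same size-estimation idea appears twice in this series: the arena pool records the Peak of each released arena per use-case ID, and the single flight keeps a rolling average of response sizes that is collapsed every 50 samples so stale traffic patterns decay. A minimal sketch of that rolling estimator — hypothetical names, not part of the engine's API:

    package main

    import "fmt"

    type sizeEstimator struct {
        count      int
        totalBytes int
    }

    // observe records one response size. After 50 samples the window collapses
    // to its current average, keeping the estimate adaptive over time.
    func (e *sizeEstimator) observe(n int) {
        if e.count == 50 {
            e.totalBytes /= 50
            e.count = 1
        }
        e.count++
        e.totalBytes += n
    }

    // hint returns the average observed size, or a 4KB default before any samples.
    func (e *sizeEstimator) hint() int {
        if e.count == 0 {
            return 4 * 1024
        }
        return e.totalBytes / e.count
    }

    func main() {
        var e sizeEstimator
        for _, n := range []int{1200, 1800, 1500} {
            e.observe(n)
        }
        fmt.Println(e.hint()) // pre-size the next response buffer with this value
    }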
+func (p *ArenaPool) getArenaSize(id uint64) int { + if size, ok := p.sizes[id]; ok { + return size + } + return 1024 * 1024 // Default 1MB +} diff --git a/v2/pkg/engine/resolve/inputtemplate.go b/v2/pkg/engine/resolve/inputtemplate.go index 82825cac7..80db3cdd8 100644 --- a/v2/pkg/engine/resolve/inputtemplate.go +++ b/v2/pkg/engine/resolve/inputtemplate.go @@ -1,10 +1,10 @@ package resolve import ( - "bytes" "context" "errors" "fmt" + "io" "github.com/wundergraph/astjson" @@ -36,7 +36,7 @@ type InputTemplate struct { SetTemplateOutputToNullOnVariableNull bool } -func SetInputUndefinedVariables(preparedInput *bytes.Buffer, undefinedVariables []string) error { +func SetInputUndefinedVariables(preparedInput InputTemplateWriter, undefinedVariables []string) error { if len(undefinedVariables) > 0 { output, err := httpclient.SetUndefinedVariables(preparedInput.Bytes(), undefinedVariables) if err != nil { @@ -55,7 +55,14 @@ func SetInputUndefinedVariables(preparedInput *bytes.Buffer, undefinedVariables // to callers; renderSegments intercepts it and writes literal.NULL instead. var errSetTemplateOutputNull = errors.New("set to null") -func (i *InputTemplate) Render(ctx *Context, data *astjson.Value, preparedInput *bytes.Buffer) error { +type InputTemplateWriter interface { + io.Writer + io.StringWriter + Reset() + Bytes() []byte +} + +func (i *InputTemplate) Render(ctx *Context, data *astjson.Value, preparedInput InputTemplateWriter) error { var undefinedVariables []string if err := i.renderSegments(ctx, data, i.Segments, preparedInput, &undefinedVariables); err != nil { @@ -65,12 +72,12 @@ func (i *InputTemplate) Render(ctx *Context, data *astjson.Value, preparedInput return SetInputUndefinedVariables(preparedInput, undefinedVariables) } -func (i *InputTemplate) RenderAndCollectUndefinedVariables(ctx *Context, data *astjson.Value, preparedInput *bytes.Buffer, undefinedVariables *[]string) (err error) { +func (i *InputTemplate) RenderAndCollectUndefinedVariables(ctx *Context, data *astjson.Value, preparedInput InputTemplateWriter, undefinedVariables *[]string) (err error) { err = i.renderSegments(ctx, data, i.Segments, preparedInput, undefinedVariables) return } -func (i *InputTemplate) renderSegments(ctx *Context, data *astjson.Value, segments []TemplateSegment, preparedInput *bytes.Buffer, undefinedVariables *[]string) (err error) { +func (i *InputTemplate) renderSegments(ctx *Context, data *astjson.Value, segments []TemplateSegment, preparedInput InputTemplateWriter, undefinedVariables *[]string) (err error) { for _, segment := range segments { switch segment.SegmentType { case StaticSegmentType: @@ -107,7 +114,7 @@ func (i *InputTemplate) renderSegments(ctx *Context, data *astjson.Value, segmen return err } -func (i *InputTemplate) renderObjectVariable(ctx context.Context, variables *astjson.Value, segment TemplateSegment, preparedInput *bytes.Buffer) error { +func (i *InputTemplate) renderObjectVariable(ctx context.Context, variables *astjson.Value, segment TemplateSegment, preparedInput InputTemplateWriter) error { value := variables.Get(segment.VariableSourcePath...) 
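// Interface note: Render and its helpers now accept InputTemplateWriter
// instead of a concrete *bytes.Buffer. *bytes.Buffer already satisfies the
// interface (Write, WriteString, Reset, Bytes), so existing call sites keep
// working, and the abstraction leaves room for an arena-backed buffer later
// in the series. A minimal alternative implementation, purely illustrative:
//
//	type sliceWriter struct{ buf []byte }
//
//	func (w *sliceWriter) Write(p []byte) (int, error) {
//		w.buf = append(w.buf, p...)
//		return len(p), nil
//	}
//
//	func (w *sliceWriter) WriteString(s string) (int, error) {
//		w.buf = append(w.buf, s...)
//		return len(s), nil
//	}
//
//	func (w *sliceWriter) Reset()        { w.buf = w.buf[:0] }
//	func (w *sliceWriter) Bytes() []byte { return w.buf }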
if value == nil || value.Type() == astjson.TypeNull { if i.SetTemplateOutputToNullOnVariableNull { @@ -119,11 +126,11 @@ func (i *InputTemplate) renderObjectVariable(ctx context.Context, variables *ast return segment.Renderer.RenderVariable(ctx, value, preparedInput) } -func (i *InputTemplate) renderResolvableObjectVariable(ctx context.Context, objectData *astjson.Value, segment TemplateSegment, preparedInput *bytes.Buffer) error { +func (i *InputTemplate) renderResolvableObjectVariable(ctx context.Context, objectData *astjson.Value, segment TemplateSegment, preparedInput InputTemplateWriter) error { return segment.Renderer.RenderVariable(ctx, objectData, preparedInput) } -func (i *InputTemplate) renderContextVariable(ctx *Context, segment TemplateSegment, preparedInput *bytes.Buffer) (variableWasUndefined bool, err error) { +func (i *InputTemplate) renderContextVariable(ctx *Context, segment TemplateSegment, preparedInput InputTemplateWriter) (variableWasUndefined bool, err error) { variableSourcePath := segment.VariableSourcePath if len(variableSourcePath) == 1 && ctx.RemapVariables != nil { nameToUse, hasMapping := ctx.RemapVariables[variableSourcePath[0]] @@ -142,7 +149,7 @@ func (i *InputTemplate) renderContextVariable(ctx *Context, segment TemplateSegm return false, segment.Renderer.RenderVariable(ctx.Context(), value, preparedInput) } -func (i *InputTemplate) renderHeaderVariable(ctx *Context, path []string, preparedInput *bytes.Buffer) error { +func (i *InputTemplate) renderHeaderVariable(ctx *Context, path []string, preparedInput InputTemplateWriter) error { if len(path) != 1 { return errHeaderPathInvalid } diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 23da2cbe0..71a3c5304 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -11,7 +11,6 @@ import ( "slices" "strconv" "strings" - "sync" "time" "github.com/buger/jsonparser" @@ -359,7 +358,9 @@ func (l *Loader) resolveSingle(item *FetchItem) error { } func (l *Loader) selectItemsForPath(path []FetchItemPathElement) []*astjson.Value { - items := []*astjson.Value{l.resolvable.data} + // Use arena allocation for the initial items slice + items := arena.AllocateSlice[*astjson.Value](l.jsonArena, 1, 1) + items[0] = l.resolvable.data if len(path) == 0 { return l.taintedObjs.filterOutTainted(items) } @@ -1286,7 +1287,7 @@ func (l *Loader) validatePreFetch(input []byte, info *FetchInfo, res *result) (a func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, fetchItem *FetchItem, items []*astjson.Value, res *result) error { res.init(fetch.PostProcessing, fetch.Info) - buf := &bytes.Buffer{} + buf := bytes.NewBuffer(nil) inputData := itemsData(l.jsonArena, items) if l.ctx.TracingOptions.Enable { @@ -1325,36 +1326,8 @@ func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, fetchI return nil } -var ( - entityFetchPool = sync.Pool{ - New: func() any { - return &entityFetchBuffer{ - item: &bytes.Buffer{}, - preparedInput: &bytes.Buffer{}, - } - }, - } -) - -type entityFetchBuffer struct { - item *bytes.Buffer - preparedInput *bytes.Buffer -} - -func acquireEntityFetchBuffer() *entityFetchBuffer { - return entityFetchPool.Get().(*entityFetchBuffer) -} - -func releaseEntityFetchBuffer(buf *entityFetchBuffer) { - buf.item.Reset() - buf.preparedInput.Reset() - entityFetchPool.Put(buf) -} - func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetch *EntityFetch, items []*astjson.Value, res *result) error { 
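// Buffer note: the entityFetchBuffer sync.Pool above is deleted, and this
// function now allocates plain buffers per call; the series is moving such
// allocations toward the request-scoped arena instead. For contrast, the
// removed pattern in its general form (a sketch assuming the bytes and sync
// packages are imported, not the exact deleted code):
//
//	var bufPool = sync.Pool{
//		New: func() any { return &bytes.Buffer{} },
//	}
//
//	func withBuffer(f func(*bytes.Buffer)) {
//		buf := bufPool.Get().(*bytes.Buffer)
//		defer func() {
//			buf.Reset()
//			bufPool.Put(buf)
//		}()
//		f(buf)
//	}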
res.init(fetch.PostProcessing, fetch.Info) - buf := acquireEntityFetchBuffer() - defer releaseEntityFetchBuffer(buf) input := itemsData(l.jsonArena, items) if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} @@ -1363,14 +1336,17 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetc } } + preparedInput := bytes.NewBuffer(nil) + item := bytes.NewBuffer(nil) + var undefinedVariables []string - err := fetch.Input.Header.RenderAndCollectUndefinedVariables(l.ctx, nil, buf.preparedInput, &undefinedVariables) + err := fetch.Input.Header.RenderAndCollectUndefinedVariables(l.ctx, nil, preparedInput, &undefinedVariables) if err != nil { return errors.WithStack(err) } - err = fetch.Input.Item.Render(l.ctx, input, buf.item) + err = fetch.Input.Item.Render(l.ctx, input, item) if err != nil { if fetch.Input.SkipErrItem { // skip fetch on render item error @@ -1382,7 +1358,7 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetc } return errors.WithStack(err) } - renderedItem := buf.item.Bytes() + renderedItem := item.Bytes() if bytes.Equal(renderedItem, null) { // skip fetch if item is null res.fetchSkipped = true @@ -1401,17 +1377,17 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetc return nil } } - _, _ = buf.item.WriteTo(buf.preparedInput) - err = fetch.Input.Footer.RenderAndCollectUndefinedVariables(l.ctx, nil, buf.preparedInput, &undefinedVariables) + _, _ = item.WriteTo(preparedInput) + err = fetch.Input.Footer.RenderAndCollectUndefinedVariables(l.ctx, nil, preparedInput, &undefinedVariables) if err != nil { return errors.WithStack(err) } - err = SetInputUndefinedVariables(buf.preparedInput, undefinedVariables) + err = SetInputUndefinedVariables(preparedInput, undefinedVariables) if err != nil { return errors.WithStack(err) } - fetchInput := buf.preparedInput.Bytes() + fetchInput := preparedInput.Bytes() if l.ctx.TracingOptions.Enable && res.fetchSkipped { l.setTracingInput(fetchItem, fetchInput, fetch.Trace) @@ -1429,41 +1405,9 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetc return nil } -var ( - batchEntityFetchPool = sync.Pool{} -) - -type batchEntityFetchBuffer struct { - preparedInput *bytes.Buffer - itemInput *bytes.Buffer - keyGen *xxhash.Digest -} - -func acquireBatchEntityFetchBuffer() *batchEntityFetchBuffer { - buf := batchEntityFetchPool.Get() - if buf == nil { - return &batchEntityFetchBuffer{ - preparedInput: &bytes.Buffer{}, - itemInput: &bytes.Buffer{}, - keyGen: xxhash.New(), - } - } - return buf.(*batchEntityFetchBuffer) -} - -func releaseBatchEntityFetchBuffer(buf *batchEntityFetchBuffer) { - buf.preparedInput.Reset() - buf.itemInput.Reset() - buf.keyGen.Reset() - batchEntityFetchPool.Put(buf) -} - func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, fetch *BatchEntityFetch, items []*astjson.Value, res *result) error { res.init(fetch.PostProcessing, fetch.Info) - buf := acquireBatchEntityFetchBuffer() - defer releaseBatchEntityFetchBuffer(buf) - if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} if !l.ctx.TracingOptions.ExcludeRawInputData && len(items) != 0 { @@ -1474,9 +1418,13 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, } } + preparedInput := bytes.NewBuffer(make([]byte, 0, 64)) + itemInput := bytes.NewBuffer(make([]byte, 0, 32)) + keyGen := xxhash.New() + var undefinedVariables []string - err := 
fetch.Input.Header.RenderAndCollectUndefinedVariables(l.ctx, nil, buf.preparedInput, &undefinedVariables) + err := fetch.Input.Header.RenderAndCollectUndefinedVariables(l.ctx, nil, preparedInput, &undefinedVariables) if err != nil { return errors.WithStack(err) } @@ -1488,8 +1436,8 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, WithNextItem: for i, item := range items { for j := range fetch.Input.Items { - buf.itemInput.Reset() - err = fetch.Input.Items[j].Render(l.ctx, item, buf.itemInput) + itemInput.Reset() + err = fetch.Input.Items[j].Render(l.ctx, item, itemInput) if err != nil { if fetch.Input.SkipErrItems { err = nil // nolint:ineffassign @@ -1501,18 +1449,18 @@ WithNextItem: } return errors.WithStack(err) } - if fetch.Input.SkipNullItems && buf.itemInput.Len() == 4 && bytes.Equal(buf.itemInput.Bytes(), null) { + if fetch.Input.SkipNullItems && itemInput.Len() == 4 && bytes.Equal(itemInput.Bytes(), null) { res.batchStats[i] = append(res.batchStats[i], -1) continue } - if fetch.Input.SkipEmptyObjectItems && buf.itemInput.Len() == 2 && bytes.Equal(buf.itemInput.Bytes(), emptyObject) { + if fetch.Input.SkipEmptyObjectItems && itemInput.Len() == 2 && bytes.Equal(itemInput.Bytes(), emptyObject) { res.batchStats[i] = append(res.batchStats[i], -1) continue } - buf.keyGen.Reset() - _, _ = buf.keyGen.Write(buf.itemInput.Bytes()) - itemHash := buf.keyGen.Sum64() + keyGen.Reset() + _, _ = keyGen.Write(itemInput.Bytes()) + itemHash := keyGen.Sum64() for k := range itemHashes { if itemHashes[k] == itemHash { res.batchStats[i] = append(res.batchStats[i], k) @@ -1521,12 +1469,12 @@ WithNextItem: } itemHashes = append(itemHashes, itemHash) if addSeparator { - err = fetch.Input.Separator.Render(l.ctx, nil, buf.preparedInput) + err = fetch.Input.Separator.Render(l.ctx, nil, preparedInput) if err != nil { return errors.WithStack(err) } } - _, _ = buf.itemInput.WriteTo(buf.preparedInput) + _, _ = itemInput.WriteTo(preparedInput) res.batchStats[i] = append(res.batchStats[i], batchItemIndex) batchItemIndex++ addSeparator = true @@ -1543,16 +1491,16 @@ WithNextItem: } } - err = fetch.Input.Footer.RenderAndCollectUndefinedVariables(l.ctx, nil, buf.preparedInput, &undefinedVariables) + err = fetch.Input.Footer.RenderAndCollectUndefinedVariables(l.ctx, nil, preparedInput, &undefinedVariables) if err != nil { return errors.WithStack(err) } - err = SetInputUndefinedVariables(buf.preparedInput, undefinedVariables) + err = SetInputUndefinedVariables(preparedInput, undefinedVariables) if err != nil { return errors.WithStack(err) } - fetchInput := buf.preparedInput.Bytes() + fetchInput := preparedInput.Bytes() if l.ctx.TracingOptions.Enable && res.fetchSkipped { l.setTracingInput(fetchItem, fetchInput, fetch.Trace) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index eef77b5b8..ce09fe086 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -7,15 +7,12 @@ import ( "context" "fmt" "io" - "sync" "time" - "weak" "github.com/buger/jsonparser" "github.com/pkg/errors" "go.uber.org/atomic" - "github.com/wundergraph/go-arena" "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/xcontext" "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) @@ -73,18 +70,12 @@ type Resolver struct { // maxSubscriptionFetchTimeout defines the maximum time a subscription fetch can take before it is considered timed out maxSubscriptionFetchTimeout time.Duration - arenaPool []weak.Pointer[arenaPoolItem] - arenaSize map[uint64]int - 
arenaPoolMu sync.Mutex + arenaPool *ArenaPool // Single flight cache for deduplicating requests across all loaders sf *SingleFlight } -type arenaPoolItem struct { - jsonArena arena.Arena -} - func (r *Resolver) SetAsyncErrorWriter(w AsyncErrorWriter) { r.asyncErrorWriter = w } @@ -236,6 +227,7 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { allowedErrorFields: allowedErrorFields, heartbeatInterval: options.SubscriptionHeartbeatInterval, maxSubscriptionFetchTimeout: options.MaxSubscriptionFetchTimeout, + arenaPool: NewArenaPool(), sf: NewSingleFlight(), } resolver.maxConcurrency = make(chan struct{}, options.MaxConcurrency) @@ -243,8 +235,6 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { resolver.maxConcurrency <- struct{}{} } - resolver.arenaSize = make(map[uint64]int) - go resolver.processEvents() return resolver @@ -309,46 +299,6 @@ func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLRespons return resp, err } -func (r *Resolver) acquireArena(id uint64) *arenaPoolItem { - r.arenaPoolMu.Lock() - defer r.arenaPoolMu.Unlock() - - for i := 0; i < len(r.arenaPool); i++ { - v := r.arenaPool[i].Value() - r.arenaPool = append(r.arenaPool[:i], r.arenaPool[i+1:]...) - if v == nil { - continue - } - return v - } - - size := arena.WithMinBufferSize(r.getArenaSize(id)) - - return &arenaPoolItem{ - jsonArena: arena.NewMonotonicArena(size), - } -} - -func (r *Resolver) getArenaSize(id uint64) int { - if size, ok := r.arenaSize[id]; ok { - return size - } - return 1024 * 1024 -} - -func (r *Resolver) releaseArena(id uint64, item *arenaPoolItem) { - peak := item.jsonArena.Peak() - item.jsonArena.Reset() - - r.arenaPoolMu.Lock() - defer r.arenaPoolMu.Unlock() - - r.arenaSize[id] = peak - - w := weak.Make(item) - r.arenaPool = append(r.arenaPool, w) -} - func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLResponse, writer io.Writer) (*GraphQLResolveInfo, error) { resp := &GraphQLResolveInfo{} @@ -361,10 +311,10 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf) - poolItem := r.acquireArena(ctx.Request.ID) - defer r.releaseArena(ctx.Request.ID, poolItem) - t.loader.jsonArena = poolItem.jsonArena - t.resolvable.astjsonArena = poolItem.jsonArena + poolItem := r.arenaPool.Acquire(ctx.Request.ID) + defer r.arenaPool.Release(ctx.Request.ID, poolItem) + t.loader.jsonArena = poolItem.Arena + t.resolvable.astjsonArena = poolItem.Arena err := t.resolvable.Init(ctx, nil, response.Info.OperationType) if err != nil { From a41ec06ba8ca13b0121889ce35fa48d9381a8951 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sun, 19 Oct 2025 19:50:20 +0200 Subject: [PATCH 19/61] refactor: update buffer size in HTTP client and enhance arena pool size tracking --- .../datasource/httpclient/nethttpclient.go | 2 +- v2/pkg/engine/resolve/arena.go | 25 ++++++++++++++++--- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index 27b0434c1..d6276c837 100644 --- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -143,7 +143,7 @@ func buffer(ctx context.Context) *bytes.Buffer { if sizeHint, ok := ctx.Value(sizeHintKey).(int); ok && sizeHint > 0 { return bytes.NewBuffer(make([]byte, 0, sizeHint)) } - return bytes.NewBuffer(make([]byte, 0, 1024*4)) // default 
to 4KB + return bytes.NewBuffer(make([]byte, 0, 64)) } func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, headers, queryParams []byte, body io.Reader, enableTrace bool, contentType string) ([]byte, error) { diff --git a/v2/pkg/engine/resolve/arena.go b/v2/pkg/engine/resolve/arena.go index 1bd5ee495..0aae88974 100644 --- a/v2/pkg/engine/resolve/arena.go +++ b/v2/pkg/engine/resolve/arena.go @@ -12,10 +12,15 @@ import ( // a pool of reusable arenas for high-frequency allocation patterns. type ArenaPool struct { pool []weak.Pointer[ArenaPoolItem] - sizes map[uint64]int + sizes map[uint64]*arenaPoolItemSize mu sync.Mutex } +type arenaPoolItemSize struct { + count int + totalBytes int +} + // ArenaPoolItem wraps an arena.Arena for use in the pool type ArenaPoolItem struct { Arena arena.Arena @@ -24,7 +29,7 @@ type ArenaPoolItem struct { // NewArenaPool creates a new ArenaPool instance func NewArenaPool() *ArenaPool { return &ArenaPool{ - sizes: make(map[uint64]int), + sizes: make(map[uint64]*arenaPoolItemSize), } } @@ -61,7 +66,19 @@ func (p *ArenaPool) Release(id uint64, item *ArenaPoolItem) { defer p.mu.Unlock() // Record the peak usage for this use case - p.sizes[id] = peak + if size, ok := p.sizes[id]; ok { + if size.count == 50 { + size.count = 1 + size.totalBytes = size.totalBytes / 50 + } + size.count++ + size.totalBytes += peak + } else { + p.sizes[id] = &arenaPoolItemSize{ + count: 1, + totalBytes: peak, + } + } // Add the arena back to the pool using a weak pointer w := weak.Make(item) @@ -72,7 +89,7 @@ func (p *ArenaPool) Release(id uint64, item *ArenaPoolItem) { // If no size is recorded, it defaults to 1MB. func (p *ArenaPool) getArenaSize(id uint64) int { if size, ok := p.sizes[id]; ok { - return size + return size.totalBytes / size.count } return 1024 * 1024 // Default 1MB } From ced27f30f64b24b461e1c1e28b44878ea7c28723 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 20 Oct 2025 20:24:59 +0200 Subject: [PATCH 20/61] chore: add second arena for response buffer --- v2/pkg/engine/resolve/resolve.go | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index ce09fe086..90b534174 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -11,6 +11,7 @@ import ( "github.com/buger/jsonparser" "github.com/pkg/errors" + "github.com/wundergraph/go-arena" "go.uber.org/atomic" "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/xcontext" @@ -70,7 +71,8 @@ type Resolver struct { // maxSubscriptionFetchTimeout defines the maximum time a subscription fetch can take before it is considered timed out maxSubscriptionFetchTimeout time.Duration - arenaPool *ArenaPool + resolveArenaPool *ArenaPool + responseBufferPool *ArenaPool // Single flight cache for deduplicating requests across all loaders sf *SingleFlight @@ -227,7 +229,8 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { allowedErrorFields: allowedErrorFields, heartbeatInterval: options.SubscriptionHeartbeatInterval, maxSubscriptionFetchTimeout: options.MaxSubscriptionFetchTimeout, - arenaPool: NewArenaPool(), + resolveArenaPool: NewArenaPool(), + responseBufferPool: NewArenaPool(), sf: NewSingleFlight(), } resolver.maxConcurrency = make(chan struct{}, options.MaxConcurrency) @@ -311,28 +314,36 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, 
r.sf) - poolItem := r.arenaPool.Acquire(ctx.Request.ID) - defer r.arenaPool.Release(ctx.Request.ID, poolItem) - t.loader.jsonArena = poolItem.Arena - t.resolvable.astjsonArena = poolItem.Arena + resolveArena := r.resolveArenaPool.Acquire(ctx.Request.ID) + t.loader.jsonArena = resolveArena.Arena + t.resolvable.astjsonArena = resolveArena.Arena err := t.resolvable.Init(ctx, nil, response.Info.OperationType) if err != nil { + r.resolveArenaPool.Release(ctx.Request.ID, resolveArena) return nil, err } if !ctx.ExecutionOptions.SkipLoader { err = t.loader.LoadGraphQLResponseData(ctx, response, t.resolvable) if err != nil { + r.resolveArenaPool.Release(ctx.Request.ID, resolveArena) return nil, err } } - err = t.resolvable.Resolve(ctx.ctx, response.Data, response.Fetches, writer) + responseArena := r.responseBufferPool.Acquire(ctx.Request.ID) + buf := arena.NewArenaBuffer(responseArena.Arena) + err = t.resolvable.Resolve(ctx.ctx, response.Data, response.Fetches, buf) if err != nil { + r.resolveArenaPool.Release(ctx.Request.ID, resolveArena) + r.responseBufferPool.Release(ctx.Request.ID, responseArena) return nil, err } + r.resolveArenaPool.Release(ctx.Request.ID, resolveArena) + _, err = writer.Write(buf.Bytes()) + r.responseBufferPool.Release(ctx.Request.ID, responseArena) return resp, err } From 67db907e1f1a9a94ff05b09af31d7ee6fb9fdcb2 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Fri, 24 Oct 2025 12:28:20 +0200 Subject: [PATCH 21/61] chore: add headers to DataSource args, add HeadersForSubgraphRequest to resolve Context --- .../graphql_datasource/graphql_datasource.go | 29 +-- .../graphql_datasource_test.go | 35 ++- .../graphql_subscription_client.go | 108 ---------- .../graphql_subscription_client_test.go | 202 ------------------ .../grpc_datasource/grpc_datasource.go | 5 +- .../grpc_datasource/grpc_datasource_test.go | 26 +-- .../datasource/httpclient/httpclient_test.go | 4 +- .../datasource/httpclient/nethttpclient.go | 18 +- .../introspection_datasource/source.go | 5 +- .../introspection_datasource/source_test.go | 2 +- .../pubsub_datasource_test.go | 8 + .../pubsub_datasource/pubsub_kafka.go | 31 +-- .../pubsub_datasource/pubsub_nats.go | 35 +-- .../staticdatasource/static_datasource.go | 5 +- v2/pkg/engine/plan/planner_test.go | 5 +- v2/pkg/engine/plan/visitor.go | 2 + v2/pkg/engine/resolve/authorization_test.go | 25 +-- v2/pkg/engine/resolve/context.go | 15 ++ v2/pkg/engine/resolve/datasource.go | 13 +- v2/pkg/engine/resolve/event_loop_test.go | 9 +- v2/pkg/engine/resolve/loader.go | 34 ++- v2/pkg/engine/resolve/loader_hooks_test.go | 53 ++--- v2/pkg/engine/resolve/loader_test.go | 4 +- v2/pkg/engine/resolve/resolve.go | 82 ++++--- .../engine/resolve/resolve_federation_test.go | 95 ++++---- v2/pkg/engine/resolve/resolve_mock_test.go | 17 +- v2/pkg/engine/resolve/resolve_test.go | 188 ++++++++-------- v2/pkg/engine/resolve/response.go | 13 +- v2/pkg/engine/resolve/singleflight.go | 13 +- 29 files changed, 382 insertions(+), 699 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go index 6f301d52d..457468184 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go @@ -14,7 +14,6 @@ import ( "unicode" "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" "github.com/jensneuse/abstractlogger" "github.com/pkg/errors" "github.com/tidwall/sjson" @@ -1907,20 +1906,19 @@ func (s *Source) 
replaceEmptyObject(variables []byte) ([]byte, bool) { return variables, false } -func (s *Source) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { +func (s *Source) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { input = s.compactAndUnNullVariables(input) - return httpclient.DoMultipartForm(s.httpClient, ctx, input, files) + return httpclient.DoMultipartForm(s.httpClient, ctx, headers, input, files) } -func (s *Source) Load(ctx context.Context, input []byte) (data []byte, err error) { +func (s *Source) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { input = s.compactAndUnNullVariables(input) - return httpclient.Do(s.httpClient, ctx, input) + return httpclient.Do(s.httpClient, ctx, headers, input) } type GraphQLSubscriptionClient interface { // Subscribe to the origin source. The implementation must not block the calling goroutine. Subscribe(ctx *resolve.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error - UniqueRequestID(ctx *resolve.Context, options GraphQLSubscriptionOptions, hash *xxhash.Digest) (err error) SubscribeAsync(ctx *resolve.Context, id uint64, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error Unsubscribe(id uint64) } @@ -1956,12 +1954,13 @@ type SubscriptionSource struct { client GraphQLSubscriptionClient } -func (s *SubscriptionSource) AsyncStart(ctx *resolve.Context, id uint64, input []byte, updater resolve.SubscriptionUpdater) error { +func (s *SubscriptionSource) AsyncStart(ctx *resolve.Context, id uint64, headers http.Header, input []byte, updater resolve.SubscriptionUpdater) error { var options GraphQLSubscriptionOptions err := json.Unmarshal(input, &options) if err != nil { return err } + options.Header = headers if options.Body.Query == "" { return resolve.ErrUnableToResolve } @@ -1975,12 +1974,13 @@ func (s *SubscriptionSource) AsyncStop(id uint64) { } // Start the subscription. The updater is called on new events. Start needs to be called in a separate goroutine. 
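// Headers note: with UniqueRequestID removed, per-request headers are no
// longer hashed into a request identity; instead the resolver passes an
// http.Header directly into Start/AsyncStart, which overwrites
// options.Header after the input is unmarshalled. A caller-side sketch
// (names illustrative, and the token variable is assumed to exist):
//
//	headers := http.Header{}
//	headers.Set("Authorization", "Bearer "+token)
//	if err := source.Start(resolveCtx, headers, inputJSON, updater); err != nil {
//		// handle subscription setup failure
//	}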
-func (s *SubscriptionSource) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error { +func (s *SubscriptionSource) Start(ctx *resolve.Context, headers http.Header, input []byte, updater resolve.SubscriptionUpdater) error { var options GraphQLSubscriptionOptions err := json.Unmarshal(input, &options) if err != nil { return err } + options.Header = headers if options.Body.Query == "" { return resolve.ErrUnableToResolve } @@ -1990,16 +1990,3 @@ func (s *SubscriptionSource) Start(ctx *resolve.Context, input []byte, updater r var ( dataSouceName = []byte("graphql") ) - -func (s *SubscriptionSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) (err error) { - _, err = xxh.Write(dataSouceName) - if err != nil { - return err - } - var options GraphQLSubscriptionOptions - err = json.Unmarshal(input, &options) - if err != nil { - return err - } - return s.client.UniqueRequestID(ctx, options, xxh) -} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go index 75a23f5ed..e064b607e 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go @@ -16,7 +16,6 @@ import ( "testing" "time" - "github.com/cespare/xxhash/v2" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -4021,6 +4020,8 @@ func TestGraphQLDataSource(t *testing.T) { NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, ctx), }, PostProcessing: DefaultPostProcessingConfiguration, + SourceName: "ds-id", + SourceID: "ds-id", }, Response: &resolve.GraphQLResponse{ Fetches: resolve.Sequence(), @@ -4062,6 +4063,8 @@ func TestGraphQLDataSource(t *testing.T) { client: NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, ctx), }, PostProcessing: DefaultPostProcessingConfiguration, + SourceName: "ds-id", + SourceID: "ds-id", }, Response: &resolve.GraphQLResponse{ Data: &resolve.Object{ @@ -8258,10 +8261,6 @@ func (f *FailingSubscriptionClient) Subscribe(ctx *resolve.Context, options Grap return errSubscriptionClientFail } -func (f *FailingSubscriptionClient) UniqueRequestID(ctx *resolve.Context, options GraphQLSubscriptionOptions, hash *xxhash.Digest) (err error) { - return errSubscriptionClientFail -} - type testSubscriptionUpdater struct { updates []string done bool @@ -8375,13 +8374,13 @@ func TestSubscriptionSource_Start(t *testing.T) { t.Run("should return error when input is invalid", func(t *testing.T) { source := SubscriptionSource{client: &FailingSubscriptionClient{}} - err := source.Start(resolve.NewContext(context.Background()), []byte(`{"url": "", "body": "", "header": null}`), nil) + err := source.Start(resolve.NewContext(context.Background()), nil, []byte(`{"url": "", "body": "", "header": null}`), nil) assert.Error(t, err) }) t.Run("should return error when subscription client returns an error", func(t *testing.T) { source := SubscriptionSource{client: &FailingSubscriptionClient{}} - err := source.Start(resolve.NewContext(context.Background()), []byte(`{"url": "", "body": {}, "header": null}`), nil) + err := source.Start(resolve.NewContext(context.Background()), nil, []byte(`{"url": "", "body": {}, "header": null}`), nil) assert.Error(t, err) assert.Equal(t, resolve.ErrUnableToResolve, err) }) @@ -8394,7 +8393,7 @@ func TestSubscriptionSource_Start(t *testing.T) { source := 
newSubscriptionSource(ctx.Context()) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: "#test") { text createdBy } }"}`) - err := source.Start(ctx, chatSubscriptionOptions, updater) + err := source.Start(ctx, nil, chatSubscriptionOptions, updater) require.ErrorIs(t, err, resolve.ErrUnableToResolve) }) @@ -8406,7 +8405,7 @@ func TestSubscriptionSource_Start(t *testing.T) { source := newSubscriptionSource(ctx.Context()) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomNam: \"#test\") { text createdBy } }"}`) - err := source.Start(ctx, chatSubscriptionOptions, updater) + err := source.Start(ctx, nil, chatSubscriptionOptions, updater) require.NoError(t, err) updater.AwaitUpdates(t, time.Second, 1) assert.Len(t, updater.updates, 1) @@ -8424,7 +8423,7 @@ func TestSubscriptionSource_Start(t *testing.T) { source := newSubscriptionSource(resolverLifecycle) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`) - err := source.Start(resolve.NewContext(subscriptionLifecycle), chatSubscriptionOptions, updater) + err := source.Start(resolve.NewContext(subscriptionLifecycle), nil, chatSubscriptionOptions, updater) require.NoError(t, err) username := "myuser" @@ -8447,7 +8446,7 @@ func TestSubscriptionSource_Start(t *testing.T) { source := newSubscriptionSource(ctx.Context()) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`) - err := source.Start(ctx, chatSubscriptionOptions, updater) + err := source.Start(ctx, nil, chatSubscriptionOptions, updater) require.NoError(t, err) username := "myuser" @@ -8511,7 +8510,7 @@ func TestSubscription_GTWS_SubProtocol(t *testing.T) { source := newSubscriptionSource(ctx.Context()) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomNam: \"#test\") { text createdBy } }"}`) - err := source.Start(ctx, chatSubscriptionOptions, updater) + err := source.Start(ctx, nil, chatSubscriptionOptions, updater) require.NoError(t, err) updater.AwaitUpdates(t, time.Second, 1) @@ -8531,7 +8530,7 @@ func TestSubscription_GTWS_SubProtocol(t *testing.T) { source := newSubscriptionSource(resolverLifecycle) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`) - err := source.Start(resolve.NewContext(subscriptionLifecycle), chatSubscriptionOptions, updater) + err := source.Start(resolve.NewContext(subscriptionLifecycle), nil, chatSubscriptionOptions, updater) require.NoError(t, err) username := "myuser" @@ -8555,7 +8554,7 @@ func TestSubscription_GTWS_SubProtocol(t *testing.T) { source := newSubscriptionSource(ctx.Context()) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription 
LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`) - err := source.Start(ctx, chatSubscriptionOptions, updater) + err := source.Start(ctx, nil, chatSubscriptionOptions, updater) require.NoError(t, err) username := "myuser" @@ -8693,7 +8692,7 @@ func TestSource_Load(t *testing.T) { input = httpclient.SetInputBodyWithPath(input, variables, "variables") input = httpclient.SetInputURL(input, []byte(serverUrl)) - data, err := src.Load(context.Background(), input) + data, err := src.Load(context.Background(), nil, input) require.NoError(t, err) assert.Equal(t, `{"variables":{"a":null,"b":"b","c":{}}}`, string(data)) }) @@ -8715,7 +8714,7 @@ func TestSource_Load(t *testing.T) { input, err = httpclient.SetUndefinedVariables(input, undefinedVariables) assert.NoError(t, err) - data, err := src.Load(ctx, input) + data, err := src.Load(ctx, nil, input) require.NoError(t, err) assert.Equal(t, `{"variables":{"b":null}}`, string(data)) }) @@ -8801,7 +8800,7 @@ func TestLoadFiles(t *testing.T) { input = httpclient.SetInputURL(input, []byte(serverUrl)) ctx := context.Background() - _, err = src.LoadWithFiles(ctx, input, []*httpclient.FileUpload{httpclient.NewFileUpload(f.Name(), fileName, "variables.file")}) + _, err = src.LoadWithFiles(ctx, nil, input, []*httpclient.FileUpload{httpclient.NewFileUpload(f.Name(), fileName, "variables.file")}) require.NoError(t, err) }) @@ -8856,7 +8855,7 @@ func TestLoadFiles(t *testing.T) { assert.NoError(t, err) ctx := context.Background() - _, err = src.LoadWithFiles(ctx, input, + _, err = src.LoadWithFiles(ctx, nil, input, []*httpclient.FileUpload{ httpclient.NewFileUpload(f1.Name(), file1Name, "variables.files.0"), httpclient.NewFileUpload(f2.Name(), file2Name, "variables.files.1")}) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index c5a52a476..c8a08df03 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -9,13 +9,9 @@ import ( "errors" "fmt" "io" - "maps" "net" "net/http" "net/http/httptrace" - "net/textproto" - "slices" - "strconv" "strings" "sync" "syscall" @@ -295,27 +291,6 @@ func (c *subscriptionClient) Subscribe(ctx *resolve.Context, options GraphQLSubs return c.subscribeWS(ctx.Context(), c.engineCtx, options, updater) } -var ( - withSSE = []byte(`sse:true`) - withSSEMethodPost = []byte(`sse_method_post:true`) -) - -func (c *subscriptionClient) UniqueRequestID(ctx *resolve.Context, options GraphQLSubscriptionOptions, hash *xxhash.Digest) (err error) { - if options.UseSSE { - _, err = hash.Write(withSSE) - if err != nil { - return err - } - } - if options.SSEMethodPost { - _, err = hash.Write(withSSEMethodPost) - if err != nil { - return err - } - } - return c.requestHash(ctx, options, hash) -} - func (c *subscriptionClient) subscribeSSE(requestContext, engineContext context.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { options.readTimeout = c.readTimeout if c.streamingClient == nil { @@ -409,89 +384,6 @@ func (c *subscriptionClient) asyncSubscribeWS(requestContext, engineContext cont return nil } -// generateHandlerIDHash generates a Hash based on: URL and Headers to uniquely identify Upgrade Requests -func (c *subscriptionClient) requestHash(ctx *resolve.Context, options GraphQLSubscriptionOptions, xxh *xxhash.Digest) (err error) { - if _, err = 
xxh.WriteString(options.URL); err != nil { - return err - } - if err := options.Header.Write(xxh); err != nil { - return err - } - // Make sure any header that will be forwarded to the subgraph - // is hashed to create the handlerID, this way requests with - // different headers will use separate connections. - for _, headerName := range options.ForwardedClientHeaderNames { - if _, err = xxh.WriteString(headerName); err != nil { - return err - } - for _, val := range ctx.Request.Header[textproto.CanonicalMIMEHeaderKey(headerName)] { - if _, err = xxh.WriteString(val); err != nil { - return err - } - } - } - - // Sort header names for deterministic hashing since looping through maps - // results in a non-deterministic order of elements - headerKeys := slices.Sorted(maps.Keys(ctx.Request.Header)) - - for _, headerRegexp := range options.ForwardedClientHeaderRegularExpressions { - // Write header pattern - if _, err = xxh.WriteString(headerRegexp.Pattern.String()); err != nil { - return err - } - - // Write negate match - if _, err = xxh.WriteString(strconv.FormatBool(headerRegexp.NegateMatch)); err != nil { - return err - } - - for _, headerName := range headerKeys { - values := ctx.Request.Header[headerName] - result := headerRegexp.Pattern.MatchString(headerName) - if headerRegexp.NegateMatch { - result = !result - } - if result { - for _, val := range values { - if _, err = xxh.WriteString(val); err != nil { - return err - } - } - } - } - } - if len(ctx.InitialPayload) > 0 { - if _, err = xxh.Write(ctx.InitialPayload); err != nil { - return err - } - } - if options.Body.Extensions != nil { - if _, err = xxh.Write(options.Body.Extensions); err != nil { - return err - } - } - if options.Body.Query != "" { - _, err = xxh.WriteString(options.Body.Query) - if err != nil { - return err - } - } - if options.Body.Variables != nil { - _, err = xxh.Write(options.Body.Variables) - if err != nil { - return err - } - } - if options.Body.OperationName != "" { - _, err = xxh.WriteString(options.Body.OperationName) - if err != nil { - return err - } - } - return nil -} - type UpgradeRequestError struct { URL string StatusCode int diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index 279c4bfe8..25eaa29f7 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -7,7 +7,6 @@ import ( "fmt" "net/http" "net/http/httptest" - "regexp" "runtime" "strings" "sync" @@ -15,7 +14,6 @@ import ( "time" "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" "github.com/coder/websocket" ll "github.com/jensneuse/abstractlogger" "github.com/stretchr/testify/assert" @@ -2571,203 +2569,3 @@ func TestInvalidWebSocketAcceptKey(t *testing.T) { }) } } - -func TestRequestHash(t *testing.T) { - t.Parallel() - client := &subscriptionClient{} - - t.Run("basic request with URL and headers", func(t *testing.T) { - t.Parallel() - - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: http.Header{}, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - Header: http.Header{ - "Authorization": []string{"Bearer token"}, - }, - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0xacbca06c541c2a79), hash.Sum64()) - }) - - t.Run("request with forwarded client headers", 
func(t *testing.T) { - t.Parallel() - - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: http.Header{ - "X-User-Id": []string{"123"}, - "X-Role": []string{"admin"}, - }, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - ForwardedClientHeaderNames: []string{"X-User-Id", "X-Role"}, - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0xf428bef25952044c), hash.Sum64()) - }) - - t.Run("request with forwarded client header regex patterns", func(t *testing.T) { - t.Parallel() - - t.Run("with normal", func(t *testing.T) { - header := http.Header{ - "X-Custom-1": []string{"value1"}, - "X-There-2": []string{"value2"}, - "X-Alright-3": []string{"value3"}, - } - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: header, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - ForwardedClientHeaderRegularExpressions: []RegularExpression{ - { - Pattern: regexp.MustCompile("^X-Custom-.*$"), - NegateMatch: false, - }, - }, - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0xb1557904bfa9d86a), hash.Sum64()) - }) - - t.Run("with negative", func(t *testing.T) { - t.Parallel() - - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: http.Header{ - "X-Custom-1": []string{"valueThere1"}, - "X-Custom-2": []string{"valueThere2"}, - }, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - ForwardedClientHeaderRegularExpressions: []RegularExpression{ - { - Pattern: regexp.MustCompile("^X-Custom-2"), - NegateMatch: true, - }, - }, - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0x5888642db454ccab), hash.Sum64()) - }) - - t.Run("with multiple tries to ensure the hash is idempotent", func(t *testing.T) { - for range 100 { - header := http.Header{ - "X-Custom-1": []string{"a1"}, - "X-There-2": []string{"a2"}, - "X-Custom-6": []string{"a3"}, - "X-Alright-3": []string{"a4"}, - "X-Custom-5": []string{"a5"}, - } - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: header, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - ForwardedClientHeaderRegularExpressions: []RegularExpression{ - { - Pattern: regexp.MustCompile("^X-Custom-.*$"), - NegateMatch: false, - }, - }, - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0x6c9c1099adab987d), hash.Sum64()) - } - }) - }) - - t.Run("request with initial payload", func(t *testing.T) { - t.Parallel() - - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: http.Header{}, - }, - InitialPayload: []byte(`{"auth": "token"}`), - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0x3c5af329478bfcce), hash.Sum64()) - - }) - - t.Run("request with body components", func(t *testing.T) { - t.Parallel() - - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: http.Header{}, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - Body: GraphQLBody{ - Query: "query { hello }", - Variables: []byte(`{"var": "value"}`), - OperationName: "HelloQuery", - Extensions: []byte(`{"ext": "value"}`), - }, - } - 
hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0xd8d5588c8a466cf2), hash.Sum64()) - }) - - t.Run("empty components", func(t *testing.T) { - t.Parallel() - - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: http.Header{}, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0x767db2231989769), hash.Sum64()) - }) - -} diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index 58729e33c..1305fda5f 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -9,6 +9,7 @@ package grpcdatasource import ( "context" "fmt" + "net/http" "sync" "github.com/tidwall/gjson" @@ -77,7 +78,7 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D // // The input is expected to contain the necessary information to make // a gRPC call, including service name, method name, and request data. -func (d *DataSource) Load(ctx context.Context, input []byte) (data []byte, err error) { +func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { // get variables from input variables := gjson.Parse(string(input)).Get("body.variables") builder := newJSONBuilder(d.mapping, variables) @@ -150,6 +151,6 @@ func (d *DataSource) Load(ctx context.Context, input []byte) (data []byte, err e // might not be applicable for most gRPC use cases. // // Currently unimplemented. -func (d *DataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { +func (d *DataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("unimplemented") } diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index 2a18e2f17..348a502d7 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -146,7 +146,7 @@ func Test_DataSource_Load(t *testing.T) { require.NoError(t, err) - output, err := ds.Load(context.Background(), []byte(`{"query":"`+query+`","variables":`+variables+`}`)) + output, err := ds.Load(context.Background(), nil, []byte(`{"query":"`+query+`","variables":`+variables+`}`)) require.NoError(t, err) fmt.Println(string(output)) @@ -217,7 +217,7 @@ func Test_DataSource_Load_WithMockService(t *testing.T) { require.NoError(t, err) // 3. 
Execute the query through our datasource - output, err := ds.Load(context.Background(), []byte(`{"query":"`+query+`","body":`+variables+`}`)) + output, err := ds.Load(context.Background(), nil, []byte(`{"query":"`+query+`","body":`+variables+`}`)) require.NoError(t, err) // Print the response for debugging @@ -309,7 +309,7 @@ func Test_DataSource_Load_WithMockService_WithResponseMapping(t *testing.T) { // Format the input with query and variables inputJSON := fmt.Sprintf(`{"query":%q,"body":%s}`, query, variables) - output, err := ds.Load(context.Background(), []byte(inputJSON)) + output, err := ds.Load(context.Background(), nil, []byte(inputJSON)) require.NoError(t, err) // Set up the correct response structure based on your GraphQL schema @@ -401,7 +401,7 @@ func Test_DataSource_Load_WithGrpcError(t *testing.T) { require.NoError(t, err) // 4. Execute the query - output, err := ds.Load(context.Background(), []byte(`{"query":"`+query+`","body":`+variables+`}`)) + output, err := ds.Load(context.Background(), nil, []byte(`{"query":"`+query+`","body":`+variables+`}`)) require.NoError(t, err, "Load should not return an error even when the gRPC call fails") responseJson := string(output) @@ -727,7 +727,7 @@ func Test_DataSource_Load_WithAnimalInterface(t *testing.T) { // Execute the query through our datasource input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - output, err := ds.Load(context.Background(), []byte(input)) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -997,7 +997,7 @@ func Test_Datasource_Load_WithUnionTypes(t *testing.T) { // Execute the query through our datasource input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - output, err := ds.Load(context.Background(), []byte(input)) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -1133,7 +1133,7 @@ func Test_DataSource_Load_WithCategoryQueries(t *testing.T) { // Execute the query through our datasource input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - output, err := ds.Load(context.Background(), []byte(input)) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -1213,7 +1213,7 @@ func Test_DataSource_Load_WithTotalCalculation(t *testing.T) { // Execute the query through our datasource input := fmt.Sprintf(`{"query":%q,"body":%s}`, query, variables) - output, err := ds.Load(context.Background(), []byte(input)) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -1303,7 +1303,7 @@ func Test_DataSource_Load_WithTypename(t *testing.T) { // Execute the query through our datasource input := fmt.Sprintf(`{"query":%q,"body":{}}`, query) - output, err := ds.Load(context.Background(), []byte(input)) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -1772,7 +1772,7 @@ func Test_DataSource_Load_WithAliases(t *testing.T) { // Execute the query through our datasource input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - output, err := ds.Load(context.Background(), []byte(input)) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -2150,7 +2150,7 @@ func Test_DataSource_Load_WithNullableFieldsType(t *testing.T) { // Execute the query through our datasource input := 
fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - output, err := ds.Load(context.Background(), []byte(input)) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -3451,7 +3451,7 @@ func Test_DataSource_Load_WithNestedLists(t *testing.T) { // Execute the query through our datasource input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - output, err := ds.Load(context.Background(), []byte(input)) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -3603,7 +3603,7 @@ func Test_DataSource_Load_WithEntity_Calls(t *testing.T) { // Execute the query through our datasource input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - output, err := ds.Load(context.Background(), []byte(input)) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response diff --git a/v2/pkg/engine/datasource/httpclient/httpclient_test.go b/v2/pkg/engine/datasource/httpclient/httpclient_test.go index cbef2d1f7..98685cece 100644 --- a/v2/pkg/engine/datasource/httpclient/httpclient_test.go +++ b/v2/pkg/engine/datasource/httpclient/httpclient_test.go @@ -79,7 +79,7 @@ func TestHttpClientDo(t *testing.T) { runTest := func(ctx context.Context, input []byte, expectedOutput string) func(t *testing.T) { return func(t *testing.T) { - output, err := Do(http.DefaultClient, ctx, input) + output, err := Do(http.DefaultClient, ctx, nil, input) assert.NoError(t, err) assert.Equal(t, expectedOutput, string(output)) } @@ -209,7 +209,7 @@ func TestHttpClientDo(t *testing.T) { input = SetInputURL(input, []byte(server.URL)) input, err := sjson.SetBytes(input, TRACE, true) assert.NoError(t, err) - output, err := Do(http.DefaultClient, context.Background(), input) + output, err := Do(http.DefaultClient, context.Background(), nil, input) assert.NoError(t, err) assert.Contains(t, string(output), `"Authorization":["****"]`) }) diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index d6276c837..4c4f2de3d 100644 --- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -27,6 +27,7 @@ const ( AcceptEncodingHeader = "Accept-Encoding" AcceptHeader = "Accept" ContentTypeHeader = "Content-Type" + ContentLengthHeader = "Content-Length" EncodingGzip = "gzip" EncodingDeflate = "deflate" @@ -146,13 +147,17 @@ func buffer(ctx context.Context) *bytes.Buffer { return bytes.NewBuffer(make([]byte, 0, 64)) } -func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, headers, queryParams []byte, body io.Reader, enableTrace bool, contentType string) ([]byte, error) { +func makeHTTPRequest(client *http.Client, ctx context.Context, baseHeaders http.Header, url, method, headers, queryParams []byte, body io.Reader, enableTrace bool, contentType string, contentLength int) ([]byte, error) { request, err := http.NewRequestWithContext(ctx, string(method), string(url), body) if err != nil { return nil, err } + if baseHeaders != nil { + request.Header = baseHeaders + } + if headers != nil { err = jsonparser.ObjectEach(headers, func(key []byte, value []byte, dataType jsonparser.ValueType, offset int) error { _, err := jsonparser.ArrayEach(value, func(value []byte, dataType jsonparser.ValueType, offset int, err error) { @@ -205,6 +210,9 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head 
request.Header.Add(ContentTypeHeader, contentType) request.Header.Set(AcceptEncodingHeader, EncodingGzip) request.Header.Add(AcceptEncodingHeader, EncodingDeflate) + if contentLength > 0 { + request.Header.Set(ContentLengthHeader, fmt.Sprintf("%d", contentLength)) + } setRequest(ctx, request) @@ -256,13 +264,13 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head return responseWithTraceExtension, nil } -func Do(client *http.Client, ctx context.Context, requestInput []byte) (data []byte, err error) { +func Do(client *http.Client, ctx context.Context, baseHeaders http.Header, requestInput []byte) (data []byte, err error) { url, method, body, headers, queryParams, enableTrace := requestInputParams(requestInput) - return makeHTTPRequest(client, ctx, url, method, headers, queryParams, bytes.NewReader(body), enableTrace, ContentTypeJSON) + return makeHTTPRequest(client, ctx, baseHeaders, url, method, headers, queryParams, bytes.NewReader(body), enableTrace, ContentTypeJSON, len(body)) } func DoMultipartForm( - client *http.Client, ctx context.Context, requestInput []byte, files []*FileUpload, + client *http.Client, ctx context.Context, baseHeaders http.Header, requestInput []byte, files []*FileUpload, ) (data []byte, err error) { if len(files) == 0 { return nil, errors.New("no files provided") @@ -316,7 +324,7 @@ func DoMultipartForm( } }() - return makeHTTPRequest(client, ctx, url, method, headers, queryParams, multipartBody, enableTrace, contentType) + return makeHTTPRequest(client, ctx, baseHeaders, url, method, headers, queryParams, multipartBody, enableTrace, contentType, 0) } func multipartBytes(values map[string]io.Reader, files []*FileUpload) (*io.PipeReader, string, error) { diff --git a/v2/pkg/engine/datasource/introspection_datasource/source.go b/v2/pkg/engine/datasource/introspection_datasource/source.go index a55549ace..67195e44a 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/source.go +++ b/v2/pkg/engine/datasource/introspection_datasource/source.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "io" + "net/http" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/introspection" @@ -18,7 +19,7 @@ type Source struct { introspectionData *introspection.Data } -func (s *Source) Load(ctx context.Context, input []byte) (data []byte, err error) { +func (s *Source) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { var req introspectionInput if err := json.Unmarshal(input, &req); err != nil { return nil, err @@ -31,7 +32,7 @@ func (s *Source) Load(ctx context.Context, input []byte) (data []byte, err error return json.Marshal(s.introspectionData.Schema) } -func (s *Source) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { +func (s *Source) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { return nil, errors.New("introspection data source does not support file uploads") } diff --git a/v2/pkg/engine/datasource/introspection_datasource/source_test.go b/v2/pkg/engine/datasource/introspection_datasource/source_test.go index 7c331b7d1..9737a4ee9 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/source_test.go +++ b/v2/pkg/engine/datasource/introspection_datasource/source_test.go @@ -28,7 +28,7 @@ func TestSource_Load(t *testing.T) { require.False(t, report.HasErrors()) source := 
&Source{introspectionData: &data} - responseData, err := source.Load(context.Background(), []byte(input)) + responseData, err := source.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) actualResponse := &bytes.Buffer{} diff --git a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_datasource_test.go b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_datasource_test.go index 28a37df33..2ea8114ad 100644 --- a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_datasource_test.go +++ b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_datasource_test.go @@ -424,6 +424,8 @@ func TestPubSub(t *testing.T) { PostProcessing: resolve.PostProcessingConfiguration{ MergePath: []string{"helloSubscription"}, }, + SourceName: "test", + SourceID: "test", }, Response: &resolve.GraphQLResponse{ Data: &resolve.Object{ @@ -487,6 +489,8 @@ func TestPubSub(t *testing.T) { PostProcessing: resolve.PostProcessingConfiguration{ MergePath: []string{"subscriptionWithMultipleSubjects"}, }, + SourceName: "test", + SourceID: "test", }, Response: &resolve.GraphQLResponse{ Data: &resolve.Object{ @@ -532,6 +536,8 @@ func TestPubSub(t *testing.T) { PostProcessing: resolve.PostProcessingConfiguration{ MergePath: []string{"subscriptionWithStaticValues"}, }, + SourceName: "test", + SourceID: "test", }, Response: &resolve.GraphQLResponse{ Data: &resolve.Object{ @@ -583,6 +589,8 @@ func TestPubSub(t *testing.T) { PostProcessing: resolve.PostProcessingConfiguration{ MergePath: []string{"subscriptionWithArgTemplateAndStaticValue"}, }, + SourceName: "test", + SourceID: "test", }, Response: &resolve.GraphQLResponse{ Data: &resolve.Object{ diff --git a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go index 7f1a6226b..3f688b6b1 100644 --- a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go +++ b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go @@ -3,9 +3,7 @@ package pubsub_datasource import ( "context" "encoding/json" - - "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" + "net/http" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" @@ -31,28 +29,7 @@ type KafkaSubscriptionSource struct { pubSub KafkaPubSub } -func (s *KafkaSubscriptionSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { - - val, _, _, err := jsonparser.Get(input, "topics") - if err != nil { - return err - } - - _, err = xxh.Write(val) - if err != nil { - return err - } - - val, _, _, err = jsonparser.Get(input, "providerId") - if err != nil { - return err - } - - _, err = xxh.Write(val) - return err -} - -func (s *KafkaSubscriptionSource) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error { +func (s *KafkaSubscriptionSource) Start(ctx *resolve.Context, headers http.Header, input []byte, updater resolve.SubscriptionUpdater) error { var subscriptionConfiguration KafkaSubscriptionEventConfiguration err := json.Unmarshal(input, &subscriptionConfiguration) if err != nil { @@ -66,7 +43,7 @@ type KafkaPublishDataSource struct { pubSub KafkaPubSub } -func (s *KafkaPublishDataSource) Load(ctx context.Context, input []byte) (data []byte, err error) { +func (s *KafkaPublishDataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { var publishConfiguration KafkaPublishEventConfiguration err = json.Unmarshal(input, &publishConfiguration) if err != nil { @@ 
-79,6 +56,6 @@ func (s *KafkaPublishDataSource) Load(ctx context.Context, input []byte) (data [ return []byte(`{"success": true}`), nil } -func (s *KafkaPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { +func (s *KafkaPublishDataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("not implemented") } diff --git a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go index e5d3bec0f..776b5deac 100644 --- a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go +++ b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go @@ -5,9 +5,7 @@ import ( "context" "encoding/json" "io" - - "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" + "net/http" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" @@ -42,28 +40,7 @@ type NatsSubscriptionSource struct { pubSub NatsPubSub } -func (s *NatsSubscriptionSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { - - val, _, _, err := jsonparser.Get(input, "subjects") - if err != nil { - return err - } - - _, err = xxh.Write(val) - if err != nil { - return err - } - - val, _, _, err = jsonparser.Get(input, "providerId") - if err != nil { - return err - } - - _, err = xxh.Write(val) - return err -} - -func (s *NatsSubscriptionSource) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error { +func (s *NatsSubscriptionSource) Start(ctx *resolve.Context, headers http.Header, input []byte, updater resolve.SubscriptionUpdater) error { var subscriptionConfiguration NatsSubscriptionEventConfiguration err := json.Unmarshal(input, &subscriptionConfiguration) if err != nil { @@ -77,7 +54,7 @@ type NatsPublishDataSource struct { pubSub NatsPubSub } -func (s *NatsPublishDataSource) Load(ctx context.Context, input []byte) (data []byte, err error) { +func (s *NatsPublishDataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { var publishConfiguration NatsPublishAndRequestEventConfiguration err = json.Unmarshal(input, &publishConfiguration) if err != nil { @@ -91,7 +68,7 @@ func (s *NatsPublishDataSource) Load(ctx context.Context, input []byte) (data [] return []byte(`{"success": true}`), nil } -func (s *NatsPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { +func (s *NatsPublishDataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("not implemented") } @@ -99,7 +76,7 @@ type NatsRequestDataSource struct { pubSub NatsPubSub } -func (s *NatsRequestDataSource) Load(ctx context.Context, input []byte) (data []byte, err error) { +func (s *NatsRequestDataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { var subscriptionConfiguration NatsPublishAndRequestEventConfiguration err = json.Unmarshal(input, &subscriptionConfiguration) if err != nil { @@ -115,6 +92,6 @@ func (s *NatsRequestDataSource) Load(ctx context.Context, input []byte) (data [] return buf.Bytes(), nil } -func (s *NatsRequestDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { +func (s 
*NatsRequestDataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("not implemented") } diff --git a/v2/pkg/engine/datasource/staticdatasource/static_datasource.go b/v2/pkg/engine/datasource/staticdatasource/static_datasource.go index 626a1d9f9..3fb75c8b3 100644 --- a/v2/pkg/engine/datasource/staticdatasource/static_datasource.go +++ b/v2/pkg/engine/datasource/staticdatasource/static_datasource.go @@ -2,6 +2,7 @@ package staticdatasource import ( "context" + "net/http" "github.com/jensneuse/abstractlogger" @@ -70,10 +71,10 @@ func (p *Planner[T]) ConfigureSubscription() plan.SubscriptionConfiguration { type Source struct{} -func (Source) Load(ctx context.Context, input []byte) (data []byte, err error) { +func (Source) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { return input, nil } -func (Source) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { +func (Source) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("not implemented") } diff --git a/v2/pkg/engine/plan/planner_test.go b/v2/pkg/engine/plan/planner_test.go index 658ff3fc7..b952107f0 100644 --- a/v2/pkg/engine/plan/planner_test.go +++ b/v2/pkg/engine/plan/planner_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "reflect" "slices" "testing" @@ -1074,10 +1075,10 @@ type FakeDataSource struct { source *StatefulSource } -func (f *FakeDataSource) Load(ctx context.Context, input []byte) (data []byte, err error) { +func (f *FakeDataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { return nil, nil } -func (f *FakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { +func (f *FakeDataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { return nil, nil } diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index ef8d09475..72dbe719c 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -1290,6 +1290,8 @@ func (v *Visitor) configureSubscription(config *objectFetchConfiguration) { v.subscription.Trigger.QueryPlan = subscription.QueryPlan v.resolveInputTemplates(config, &subscription.Input, &v.subscription.Trigger.Variables) v.subscription.Trigger.Input = []byte(subscription.Input) + v.subscription.Trigger.SourceName = config.sourceName + v.subscription.Trigger.SourceID = config.sourceID v.subscription.Filter = config.filter } diff --git a/v2/pkg/engine/resolve/authorization_test.go b/v2/pkg/engine/resolve/authorization_test.go index ea83c7725..95051def7 100644 --- a/v2/pkg/engine/resolve/authorization_test.go +++ b/v2/pkg/engine/resolve/authorization_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "io" + "net/http" "sync/atomic" "testing" @@ -509,8 +510,8 @@ func TestAuthorization(t *testing.T) { func generateTestFederationGraphQLResponse(t *testing.T, ctrl *gomock.Controller) *GraphQLResponse { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
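Alongside the header plumbing, the pubsub fixtures pin SourceName and SourceID on the subscription trigger, and configureSubscription above copies them from the fetch configuration; resolve.go further down uses Trigger.SourceName to look up per-subgraph headers for the trigger. A sketch of a trigger literal carrying the new fields, assuming the trigger type is resolve.GraphQLSubscriptionTrigger as elsewhere in this package (input and values are illustrative):

package sketch

import "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"

// newTrigger shows the new identity fields on a subscription trigger;
// the resolver derives the trigger's header set from SourceName.
func newTrigger() resolve.GraphQLSubscriptionTrigger {
	return resolve.GraphQLSubscriptionTrigger{
		Input:      []byte(`{"topics":["orders"]}`), // illustrative input
		SourceName: "test", // mirrors the pubsub test fixtures above
		SourceID:   "test",
	}
}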
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) @@ -519,8 +520,8 @@ func generateTestFederationGraphQLResponse(t *testing.T, ctrl *gomock.Controller reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) @@ -529,8 +530,8 @@ func generateTestFederationGraphQLResponse(t *testing.T, ctrl *gomock.Controller productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}}}` assert.Equal(t, expected, actual) @@ -814,8 +815,8 @@ func generateTestFederationGraphQLResponse(t *testing.T, ctrl *gomock.Controller func generateTestFederationGraphQLResponseWithoutAuthorizationRules(t *testing.T, ctrl *gomock.Controller) *GraphQLResponse { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) @@ -824,8 +825,8 @@ func generateTestFederationGraphQLResponseWithoutAuthorizationRules(t *testing.T reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) @@ -834,8 +835,8 @@ func generateTestFederationGraphQLResponseWithoutAuthorizationRules(t *testing.T productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any()). 
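Every DataSource now receives the resolved subgraph headers even when it has no use for them; the static and fake sources above simply accept and ignore the parameter, and the interface itself changes in datasource.go below. A minimal conforming implementation under the new contract (the type is invented for illustration):

package sketch

import (
	"context"
	"net/http"

	"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
)

// echoSource ignores the headers argument and echoes its input, much like
// the static data source above.
type echoSource struct{}

func (echoSource) Load(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
	return input, nil
}

func (echoSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) ([]byte, error) {
	return input, nil // uploads ignored in this sketch
}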
- DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}}}` assert.Equal(t, expected, actual) diff --git a/v2/pkg/engine/resolve/context.go b/v2/pkg/engine/resolve/context.go index e9958d24e..b0b82f578 100644 --- a/v2/pkg/engine/resolve/context.go +++ b/v2/pkg/engine/resolve/context.go @@ -32,12 +32,27 @@ type Context struct { fieldRenderer FieldValueRenderer subgraphErrors error + + SubgraphHeadersBuilder HeadersForSubgraphRequest +} + +type HeadersForSubgraphRequest interface { + HeadersForSubgraph(subgraphName string) (http.Header, uint64) +} + +func (c *Context) HeadersForSubgraphRequest(subgraphName string) (http.Header, uint64) { + if c.SubgraphHeadersBuilder == nil { + return nil, 0 + } + return c.SubgraphHeadersBuilder.HeadersForSubgraph(subgraphName) } type ExecutionOptions struct { SkipLoader bool IncludeQueryPlanInResponse bool SendHeartbeat bool + // DisableRequestDeduplication disables deduplication of requests to the same subgraph with the same input within a single operation execution. + DisableRequestDeduplication bool } type FieldValue struct { diff --git a/v2/pkg/engine/resolve/datasource.go b/v2/pkg/engine/resolve/datasource.go index 8063541f6..7855fa637 100644 --- a/v2/pkg/engine/resolve/datasource.go +++ b/v2/pkg/engine/resolve/datasource.go @@ -2,26 +2,23 @@ package resolve import ( "context" - - "github.com/cespare/xxhash/v2" + "net/http" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" ) type DataSource interface { - Load(ctx context.Context, input []byte) (data []byte, err error) - LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) + Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) + LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) } type SubscriptionDataSource interface { // Start is called when a new subscription is created. It establishes the connection to the data source. // The updater is used to send updates to the client. Deduplication of the request must be done before calling this method. 
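Context gains an optional SubgraphHeadersBuilder, and HeadersForSubgraphRequest degrades to (nil, 0) when it is unset. Returning a hash next to the headers lets the engine fold header identity into deduplication keys without rehashing on every fetch, which is also why UniqueRequestID disappears from the subscription source interfaces just below. A sketch of a builder that serves one static header set to all subgraphs, assuming xxhash (already a dependency of this module); a real builder would likely vary the headers by subgraph name:

package sketch

import (
	"net/http"
	"sort"

	"github.com/cespare/xxhash/v2"
)

// staticHeaders hands the same extra headers to every subgraph. The hash
// is computed once over sorted keys so equal header sets always produce
// equal hashes regardless of map iteration order.
type staticHeaders struct {
	headers http.Header
	hash    uint64
}

func newStaticHeaders(h http.Header) *staticHeaders {
	keys := make([]string, 0, len(h))
	for k := range h {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	d := xxhash.New()
	for _, k := range keys {
		_, _ = d.WriteString(k)
		for _, v := range h[k] {
			_, _ = d.WriteString(v)
		}
	}
	return &staticHeaders{headers: h, hash: d.Sum64()}
}

// HeadersForSubgraph satisfies the HeadersForSubgraphRequest interface above.
func (s *staticHeaders) HeadersForSubgraph(subgraphName string) (http.Header, uint64) {
	return s.headers, s.hash
}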
- Start(ctx *Context, input []byte, updater SubscriptionUpdater) error - UniqueRequestID(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) + Start(ctx *Context, headers http.Header, input []byte, updater SubscriptionUpdater) error } type AsyncSubscriptionDataSource interface { - AsyncStart(ctx *Context, id uint64, input []byte, updater SubscriptionUpdater) error + AsyncStart(ctx *Context, id uint64, headers http.Header, input []byte, updater SubscriptionUpdater) error AsyncStop(id uint64) - UniqueRequestID(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) } diff --git a/v2/pkg/engine/resolve/event_loop_test.go b/v2/pkg/engine/resolve/event_loop_test.go index 11389630a..ba8b7c8e2 100644 --- a/v2/pkg/engine/resolve/event_loop_test.go +++ b/v2/pkg/engine/resolve/event_loop_test.go @@ -3,12 +3,12 @@ package resolve import ( "context" "io" + "net/http" "sync" "sync/atomic" "testing" "time" - "github.com/cespare/xxhash/v2" "github.com/stretchr/testify/require" ) @@ -71,12 +71,7 @@ type FakeSource struct { interval time.Duration } -func (f *FakeSource) UniqueRequestID(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) { - _, err = xxh.Write(input) - return err -} - -func (f *FakeSource) Start(ctx *Context, input []byte, updater SubscriptionUpdater) error { +func (f *FakeSource) Start(ctx *Context, headers http.Header, input []byte, updater SubscriptionUpdater) error { go func() { for i, u := range f.updates { updater.Update([]byte(u)) diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 71a3c5304..a429087d0 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -14,10 +14,10 @@ import ( "time" "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" "github.com/pkg/errors" "github.com/tidwall/gjson" "github.com/tidwall/sjson" + "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" "golang.org/x/sync/errgroup" "github.com/wundergraph/astjson" @@ -1420,7 +1420,8 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, preparedInput := bytes.NewBuffer(make([]byte, 0, 64)) itemInput := bytes.NewBuffer(make([]byte, 0, 32)) - keyGen := xxhash.New() + keyGen := pool.Hash64.Get() + defer pool.Hash64.Put(keyGen) var undefinedVariables []string @@ -1590,18 +1591,33 @@ func GetOperationTypeFromContext(ctx context.Context) ast.OperationType { return ast.OperationTypeQuery } +func (l *Loader) headersForSubgraphRequest(fetchItem *FetchItem) (http.Header, uint64) { + if fetchItem == nil || fetchItem.Fetch == nil { + return nil, 0 + } + info := fetchItem.Fetch.FetchInfo() + if info == nil { + return nil, 0 + } + return l.ctx.HeadersForSubgraphRequest(info.DataSourceName) +} + func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem *FetchItem, input []byte, res *result) error { if l.info != nil { ctx = context.WithValue(ctx, operationTypeContextKey, l.info.OperationType) } - if l.info == nil || l.info.OperationType == ast.OperationTypeMutation { + headers, extraKey := l.headersForSubgraphRequest(fetchItem) + + if l.info == nil || + l.info.OperationType == ast.OperationTypeMutation || + l.ctx.ExecutionOptions.DisableRequestDeduplication { // Disable single flight for mutations - return l.loadByContextDirect(ctx, source, input, res) + return l.loadByContextDirect(ctx, source, headers, input, res) } - sfKey, fetchKey, item, shared := l.sf.GetOrCreateItem(ctx, fetchItem, input) + sfKey, fetchKey, item, shared := l.sf.GetOrCreateItem(fetchItem, input, extraKey) if 
res.singleFlightStats != nil { res.singleFlightStats.used = shared res.singleFlightStats.shared = shared @@ -1627,7 +1643,7 @@ func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem defer l.sf.Finish(sfKey, fetchKey, item) // Perform the actual load - err := l.loadByContextDirect(ctx, source, input, res) + err := l.loadByContextDirect(ctx, source, headers, input, res) if err != nil { item.err = err return err @@ -1637,11 +1653,11 @@ func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem return nil } -func (l *Loader) loadByContextDirect(ctx context.Context, source DataSource, input []byte, res *result) error { +func (l *Loader) loadByContextDirect(ctx context.Context, source DataSource, headers http.Header, input []byte, res *result) error { if l.ctx.Files != nil { - res.out, res.err = source.LoadWithFiles(ctx, input, l.ctx.Files) + res.out, res.err = source.LoadWithFiles(ctx, headers, input, l.ctx.Files) } else { - res.out, res.err = source.Load(ctx, input) + res.out, res.err = source.Load(ctx, headers, input) } if res.err != nil { return errors.WithStack(res.err) diff --git a/v2/pkg/engine/resolve/loader_hooks_test.go b/v2/pkg/engine/resolve/loader_hooks_test.go index d82857598..ebe263dcd 100644 --- a/v2/pkg/engine/resolve/loader_hooks_test.go +++ b/v2/pkg/engine/resolve/loader_hooks_test.go @@ -3,6 +3,7 @@ package resolve import ( "bytes" "context" + "net/http" "sync" "sync/atomic" "testing" @@ -49,8 +50,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("simple fetch with simple subgraph error", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) resolveCtx := Context{ @@ -121,8 +122,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) resolveCtx := &Context{ @@ -187,8 +188,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("parallel fetch with simple subgraph error", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
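To recap the loader change: headers and their hash are resolved per fetch from FetchInfo().DataSourceName, single flight is skipped for mutations or when the new execution option is set, and otherwise the header hash rides into GetOrCreateItem as extraKey so identical inputs under different headers never share a result. Opting a request out of deduplication is then just (a sketch; the option is defined in context.go above):

package sketch

import "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"

// perRequestContext disables deduplication for one execution, e.g. when
// coalescing identical subgraph requests is undesirable.
func perRequestContext() *resolve.Context {
	return &resolve.Context{
		ExecutionOptions: resolve.ExecutionOptions{
			DisableRequestDeduplication: true,
		},
	}
}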
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) resolveCtx := &Context{ @@ -250,8 +251,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("parallel list item fetch with simple subgraph error", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) resolveCtx := Context{ @@ -313,8 +314,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("fetch with subgraph error and custom extension code. No extension fields are propagated by default", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT"}}]}`), nil }) resolveCtx := Context{ @@ -376,8 +377,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Propagate only extension code field from subgraph errors", testFnSubgraphErrorsWithExtensionFieldCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED","foo":"bar"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT"}}]}`), nil }) return &GraphQLResponse{ @@ -411,8 +412,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Propagate all extension fields from subgraph errors when allow all option is enabled", testFnSubgraphErrorsWithAllowAllExtensionFields(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED","foo":"bar"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT"}}]}`), nil }) return &GraphQLResponse{ @@ -446,8 +447,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Include datasource name as serviceName extension field", testFnSubgraphErrorsWithExtensionFieldServiceName(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT"}}]}`), nil }) return &GraphQLResponse{ @@ -481,8 +482,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Include datasource name as serviceName when extensions is null", testFnSubgraphErrorsWithExtensionFieldServiceName(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage","extensions":null},{"message":"errorMessage2","extensions":null}]}`), nil }) return &GraphQLResponse{ @@ -516,8 +517,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Include datasource name as serviceName when extensions is an empty object", testFnSubgraphErrorsWithExtensionFieldServiceName(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage","extensions":{}},{"message":"errorMessage2","extensions":null}]}`), nil }) return &GraphQLResponse{ @@ -551,8 +552,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Fallback to default extension code value when no code field was set", testFnSubgraphErrorsWithExtensionDefaultCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"errorMessage2"}]}`), nil }) return &GraphQLResponse{ @@ -586,8 +587,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Fallback to default extension code value when extensions is null", testFnSubgraphErrorsWithExtensionDefaultCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage","extensions":null},{"message":"errorMessage2"}]}`), nil }) return &GraphQLResponse{ @@ -621,8 +622,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Fallback to default extension code value when extensions is an empty object", testFnSubgraphErrorsWithExtensionDefaultCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage","extensions":{}},{"message":"errorMessage2"}]}`), nil }) return &GraphQLResponse{ diff --git a/v2/pkg/engine/resolve/loader_test.go b/v2/pkg/engine/resolve/loader_test.go index 0fe38ddc7..d6c002393 100644 --- a/v2/pkg/engine/resolve/loader_test.go +++ b/v2/pkg/engine/resolve/loader_test.go @@ -296,7 +296,7 @@ func TestLoader_LoadGraphQLResponseData(t *testing.T) { ctrl.Finish() out := fastjsonext.PrintGraphQLResponse(resolvable.data, resolvable.errors) assert.NoError(t, err) - expected := `{"errors":[],"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` + expected := `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` assert.Equal(t, expected, out) } @@ -758,7 +758,7 @@ func TestLoader_LoadGraphQLResponseDataWithExtensions(t *testing.T) { ctrl.Finish() out := 
fastjsonext.PrintGraphQLResponse(resolvable.data, resolvable.errors) assert.NoError(t, err) - expected := `{"errors":[],"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` + expected := `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` assert.Equal(t, expected, out) } diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 90b534174..107f0cb79 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "io" + "net/http" "time" "github.com/buger/jsonparser" @@ -707,14 +708,16 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) asyncDataSource = async } + headers, _ := r.triggerHeaders(add.ctx, add.sourceName) + go func() { if r.options.Debug { fmt.Printf("resolver:trigger:start:%d\n", triggerID) } if asyncDataSource != nil { - err = asyncDataSource.AsyncStart(cloneCtx, triggerID, add.input, updater) + err = asyncDataSource.AsyncStart(cloneCtx, triggerID, headers, add.input, updater) } else { - err = add.resolve.Trigger.Source.Start(cloneCtx, add.input, updater) + err = add.resolve.Trigger.Source.Start(cloneCtx, headers, add.input, updater) } if err != nil { if r.options.Debug { @@ -1057,6 +1060,13 @@ func (r *Resolver) AsyncUnsubscribeClient(connectionID int64) error { return nil } +func (r *Resolver) triggerHeaders(ctx *Context, sourceName string) (http.Header, uint64) { + if ctx.SubgraphHeadersBuilder != nil { + return ctx.SubgraphHeadersBuilder.HeadersForSubgraph(sourceName) + } + return nil, 0 +} + func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQLSubscription, writer SubscriptionResponseWriter) error { if subscription.Trigger.Source == nil { return errors.New("no data source found") @@ -1094,14 +1104,14 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ return nil } + _, headersHash := r.triggerHeaders(ctx, subscription.Trigger.SourceName) + xxh := pool.Hash64.Get() - defer pool.Hash64.Put(xxh) - err = subscription.Trigger.Source.UniqueRequestID(ctx, input, xxh) - if err != nil { - msg := []byte(`{"errors":[{"message":"unable to resolve"}]}`) - return writeFlushComplete(writer, msg) - } - uniqueID := xxh.Sum64() + _, _ = xxh.Write(input) + // the hash for subgraph headers is pre-computed + // we can just add it to the input hash to get a unique id + uniqueID := xxh.Sum64() + headersHash + pool.Hash64.Put(xxh) id := SubscriptionIdentifier{ 
ConnectionID: ConnectionIDs.Inc(), SubscriptionID: 0, @@ -1120,12 +1130,13 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ triggerID: uniqueID, kind: subscriptionEventKindAddSubscription, addSubscription: &addSubscription{ - ctx: ctx, - input: input, - resolve: subscription, - writer: writer, - id: id, - completed: completed, + ctx: ctx, + input: input, + resolve: subscription, + writer: writer, + id: id, + completed: completed, + sourceName: subscription.Trigger.SourceName, }, }: } @@ -1203,13 +1214,14 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G return nil } + _, headersHash := r.triggerHeaders(ctx, subscription.Trigger.SourceName) + xxh := pool.Hash64.Get() - defer pool.Hash64.Put(xxh) - err = subscription.Trigger.Source.UniqueRequestID(ctx, input, xxh) - if err != nil { - msg := []byte(`{"errors":[{"message":"unable to resolve"}]}`) - return writeFlushComplete(writer, msg) - } + _, _ = xxh.Write(input) + // the hash for subgraph headers is pre-computed + // we can just add it to the input hash to get a unique id + uniqueID := xxh.Sum64() + headersHash + pool.Hash64.Put(xxh) select { case <-r.ctx.Done(): @@ -1219,15 +1231,16 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G // Stop resolving if the client is gone return ctx.ctx.Err() case r.events <- subscriptionEvent{ - triggerID: xxh.Sum64(), + triggerID: uniqueID, kind: subscriptionEventKindAddSubscription, addSubscription: &addSubscription{ - ctx: ctx, - input: input, - resolve: subscription, - writer: writer, - id: id, - completed: make(chan struct{}), + ctx: ctx, + input: input, + resolve: subscription, + writer: writer, + id: id, + completed: make(chan struct{}), + sourceName: subscription.Trigger.SourceName, }, }: } @@ -1335,12 +1348,13 @@ type subscriptionEvent struct { } type addSubscription struct { - ctx *Context - input []byte - resolve *GraphQLSubscription - writer SubscriptionResponseWriter - id SubscriptionIdentifier - completed chan struct{} + ctx *Context + input []byte + resolve *GraphQLSubscription + writer SubscriptionResponseWriter + id SubscriptionIdentifier + completed chan struct{} + sourceName string } type subscriptionEventKind int diff --git a/v2/pkg/engine/resolve/resolve_federation_test.go b/v2/pkg/engine/resolve/resolve_federation_test.go index 64d969c6c..1c32db689 100644 --- a/v2/pkg/engine/resolve/resolve_federation_test.go +++ b/v2/pkg/engine/resolve/resolve_federation_test.go @@ -2,6 +2,7 @@ package resolve import ( "context" + "net/http" "testing" "github.com/golang/mock/gomock" @@ -19,8 +20,8 @@ func mockedDS(t TestingTB, ctrl *gomock.Controller, expectedInput, responseData t.Helper() service := NewMockDataSource(ctrl) service.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { require.Equal(t, expectedInput, string(input)) return []byte(responseData), nil }).Times(1) @@ -173,8 +174,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("federation with shareable", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { firstService := NewMockDataSource(ctrl) firstService.EXPECT(). - Load(gomock.Any(), gomock.Any()). 
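With UniqueRequestID gone, the sync and async subscription paths share one trigger-identity scheme: hash the rendered trigger input, then add the pre-computed headers hash for the trigger's source. Subscribers with equal input and equal effective headers land on the same trigger (one upstream connection); a differing header set yields a distinct trigger. The arithmetic in isolation, mirroring the code above:

package sketch

import "github.com/cespare/xxhash/v2"

// triggerID combines the input hash with the pre-computed subgraph-headers
// hash. The combination is additive rather than a hash over both values,
// matching the resolver code above.
func triggerID(input []byte, headersHash uint64) uint64 {
	d := xxhash.New()
	_, _ = d.Write(input)
	return d.Sum64() + headersHash
}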
- DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://first.service","body":{"query":"{me {details {forename middlename} __typename id}}"}}` assert.Equal(t, expected, actual) @@ -185,8 +186,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { secondService := NewMockDataSource(ctrl) secondService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://second.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {details {surname}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) @@ -197,8 +198,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { thirdService := NewMockDataSource(ctrl) thirdService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://third.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {details {age}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) @@ -368,8 +369,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name infoOrAddress { ... on Info {id __typename} ... on Address {id __typename}}}}"}}` assert.Equal(t, expected, actual) @@ -380,8 +381,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age } ... on Address { line1 }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"},{"id":55,"__typename":"Address"}]}}}` assert.Equal(t, expected, actual) @@ -521,8 +522,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). 
- DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name infoOrAddress { ... on Info {id __typename} ... on Address {id __typename}}}}"}}` assert.Equal(t, expected, actual) @@ -533,7 +534,7 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any()). + Load(gomock.Any(), gomock.Any(), gomock.Any()). Times(0) return &GraphQLResponse{ @@ -666,8 +667,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching on a field", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) @@ -678,8 +679,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"},{"id":12,"__typename":"Info"},{"id":13,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) @@ -810,8 +811,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching with duplicates", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) @@ -822,8 +823,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... 
on Info { age }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) @@ -951,8 +952,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching with null entry", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) @@ -963,8 +964,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"},{"id":13,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) @@ -1096,8 +1097,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching with all null entries", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) @@ -1108,7 +1109,7 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any()). + Load(gomock.Any(), gomock.Any(), gomock.Any()). Times(0) return &GraphQLResponse{ @@ -1234,8 +1235,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching with render error", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) @@ -1247,8 +1248,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any()). 
- DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age }}}}}","variables":{"representations":[{"id":12,"__typename":"Info"},{"id":13,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) @@ -1381,8 +1382,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("all data", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) @@ -1393,8 +1394,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) @@ -1515,8 +1516,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("null info data", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) @@ -1527,7 +1528,7 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any()). + Load(gomock.Any(), gomock.Any(), gomock.Any()). Times(0) return &GraphQLResponse{ @@ -1643,8 +1644,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("wrong type data", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) @@ -1655,7 +1656,7 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any()). + Load(gomock.Any(), gomock.Any(), gomock.Any()). Times(0) return &GraphQLResponse{ @@ -1771,8 +1772,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("not matching type data", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) @@ -1783,7 +1784,7 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any()). + Load(gomock.Any(), gomock.Any(), gomock.Any()). Times(0) return &GraphQLResponse{ diff --git a/v2/pkg/engine/resolve/resolve_mock_test.go b/v2/pkg/engine/resolve/resolve_mock_test.go index d493ff4bd..a64b7dd83 100644 --- a/v2/pkg/engine/resolve/resolve_mock_test.go +++ b/v2/pkg/engine/resolve/resolve_mock_test.go @@ -6,6 +6,7 @@ package resolve import ( context "context" + http "net/http" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -36,31 +37,31 @@ func (m *MockDataSource) EXPECT() *MockDataSourceMockRecorder { } // Load mocks base method. -func (m *MockDataSource) Load(arg0 context.Context, arg1 []byte) ([]byte, error) { +func (m *MockDataSource) Load(arg0 context.Context, arg1 http.Header, arg2 []byte) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Load", arg0, arg1) + ret := m.ctrl.Call(m, "Load", arg0, arg1, arg2) ret0, _ := ret[0].([]byte) ret1, _ := ret[1].(error) return ret0, ret1 } // Load indicates an expected call of Load. -func (mr *MockDataSourceMockRecorder) Load(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockDataSourceMockRecorder) Load(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockDataSource)(nil).Load), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockDataSource)(nil).Load), arg0, arg1, arg2) } // LoadWithFiles mocks base method. -func (m *MockDataSource) LoadWithFiles(arg0 context.Context, arg1 []byte, arg2 []*httpclient.FileUpload) ([]byte, error) { +func (m *MockDataSource) LoadWithFiles(arg0 context.Context, arg1 http.Header, arg2 []byte, arg3 []*httpclient.FileUpload) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LoadWithFiles", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "LoadWithFiles", arg0, arg1, arg2, arg3) ret0, _ := ret[0].([]byte) ret1, _ := ret[1].(error) return ret0, ret1 } // LoadWithFiles indicates an expected call of LoadWithFiles. 
-func (mr *MockDataSourceMockRecorder) LoadWithFiles(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockDataSourceMockRecorder) LoadWithFiles(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadWithFiles", reflect.TypeOf((*MockDataSource)(nil).LoadWithFiles), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadWithFiles", reflect.TypeOf((*MockDataSource)(nil).LoadWithFiles), arg0, arg1, arg2, arg3) } diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index d19156f36..5c2ea4ed6 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -13,7 +13,6 @@ import ( "testing" "time" - "github.com/cespare/xxhash/v2" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -32,7 +31,7 @@ type _fakeDataSource struct { artificialLatency time.Duration } -func (f *_fakeDataSource) Load(ctx context.Context, input []byte) (data []byte, err error) { +func (f *_fakeDataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { if f.artificialLatency != 0 { time.Sleep(f.artificialLatency) } @@ -44,7 +43,7 @@ func (f *_fakeDataSource) Load(ctx context.Context, input []byte) (data []byte, return f.data, nil } -func (f *_fakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { +func (f *_fakeDataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { if f.artificialLatency != 0 { time.Sleep(f.artificialLatency) } @@ -349,8 +348,8 @@ func TestResolver_ResolveNode(t *testing.T) { t.Run("fetch with context variable resolver", testFn(true, func(t *testing.T, ctrl *gomock.Controller) (response *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), []byte(`{"id":1}`)). - Do(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), []byte(`{"id":1}`)). + Do(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"name":"Jens"}`), nil }). Return([]byte(`{"name":"Jens"}`), nil) @@ -1799,8 +1798,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with simple error without datasource ID", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ @@ -1829,8 +1828,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with simple error without datasource ID no subgraph error forwarding", testFnNoSubgraphErrorForwarding(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). 
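// The hunks above and below all apply the same mechanical change: DataSource.Load
// and LoadWithFiles now receive the headers resolved for the subgraph request as a
// second argument. A minimal sketch of a conforming implementation (the interface
// shape is inferred from the regenerated MockDataSource and _fakeDataSource above;
// staticDataSource itself is made up for illustration):
//
//	type staticDataSource struct{ data []byte }
//
//	// a real implementation would forward `headers` on the outgoing request
//	func (s *staticDataSource) Load(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
//		return s.data, nil
//	}
//
//	func (s *staticDataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) ([]byte, error) {
//		return s.data, nil
//	}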
- DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ @@ -1859,8 +1858,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with simple error", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ @@ -1893,8 +1892,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with simple error in pass through Subgraph Error Mode", testFnSubgraphErrorsPassthrough(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ @@ -1927,8 +1926,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with pass through mode and omit custom fields", testFnSubgraphErrorsPassthroughAndOmitCustomFields(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage","longMessage":"This is a long message","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}],"data":{"name":null}}`), nil }) return &GraphQLResponse{ @@ -1964,8 +1963,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with returned err (with DataSourceID)", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return nil, &net.AddrError{} }) return &GraphQLResponse{ @@ -1998,8 +1997,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with returned err (no DataSourceID)", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return nil, &net.AddrError{} }) return &GraphQLResponse{ @@ -2028,8 +2027,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with returned err and non-nullable root field", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return nil, &net.AddrError{} }) return &GraphQLResponse{ @@ -2206,8 +2205,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with two Errors", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage1"},{"message":"errorMessage2"}]}`), nil }).Times(1) return &GraphQLResponse{ @@ -2562,8 +2561,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("complex GraphQL Server plan", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { serviceOne := NewMockDataSource(ctrl) serviceOne.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"url":"https://service.one","body":{"query":"query($firstArg: String, $thirdArg: Int){serviceOne(serviceOneArg: $firstArg){fieldOne} anotherServiceOne(anotherServiceOneArg: $thirdArg){fieldOne} reusingServiceOne(reusingServiceOneArg: $firstArg){fieldOne}}","variables":{"thirdArg":123,"firstArg":"firstArgValue"}}}` assert.Equal(t, expected, actual) @@ -2572,8 +2571,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { serviceTwo := NewMockDataSource(ctrl) serviceTwo.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"url":"https://service.two","body":{"query":"query($secondArg: Boolean, $fourthArg: Float){serviceTwo(serviceTwoArg: $secondArg){fieldTwo} secondServiceTwo(secondServiceTwoArg: $fourthArg){fieldTwo}}","variables":{"fourthArg":12.34,"secondArg":true}}}` assert.Equal(t, expected, actual) @@ -2582,8 +2581,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { nestedServiceOne := NewMockDataSource(ctrl) nestedServiceOne.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"url":"https://service.one","body":{"query":"{serviceOne {fieldOne}}"}}` assert.Equal(t, expected, actual) @@ -2798,8 +2797,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) @@ -2808,8 +2807,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) @@ -2820,8 +2819,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) productServiceCallCount.Add(1) switch actual { @@ -3005,8 +3004,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("federation with batch", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) @@ -3015,8 +3014,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) @@ -3025,8 +3024,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}}}` assert.Equal(t, expected, actual) @@ -3202,8 +3201,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("federation with merge paths", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) @@ -3212,8 +3211,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) @@ -3222,8 +3221,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}}}` assert.Equal(t, expected, actual) @@ -3400,8 +3399,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("federation with null response", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) @@ -3410,8 +3409,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) @@ -3427,8 +3426,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"},{"upc":"top-4","__typename":"Product"},{"upc":"top-5","__typename":"Product"},{"upc":"top-6","__typename":"Product"}]}}}` assert.Equal(t, expected, actual) @@ -3627,8 +3626,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) @@ -3637,8 +3636,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) @@ -3647,8 +3646,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
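// A note on the "federation with null response" expectations above: the products
// representations jump from top-2 straight to top-4, top-5, top-6. Presumably the
// reviews payload in that test yields null for the third product, and the loader
// skips null items when building representations instead of forwarding them, e.g.
// (illustrative data, not taken from the test):
//
//	reviews input:   [{"upc":"top-1"},{"upc":"top-2"},null,{"upc":"top-4"},...]
//	representations: [{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"},{"upc":"top-4","__typename":"Product"},...]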
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"}]}}}` assert.Equal(t, expected, actual) @@ -3814,8 +3813,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) @@ -3824,8 +3823,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) @@ -3834,8 +3833,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"}]}}}` assert.Equal(t, expected, actual) @@ -3998,8 +3997,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("federation with optional variable", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:8080/query","body":{"query":"{me {id}}"}}` assert.Equal(t, expected, actual) @@ -4008,8 +4007,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { employeeService := NewMockDataSource(ctrl) employeeService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:8081/query","body":{"query":"query($representations: [_Any!]!, $companyId: ID!){_entities(representations: $representations){... on User {employment(companyId: $companyId){id}}}}","variables":{"companyId":"abc123","representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) @@ -4018,8 +4017,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { timeService := NewMockDataSource(ctrl) timeService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:8082/query","body":{"query":"query($representations: [_Any!]!, $date: LocalTime){_entities(representations: $representations){... on Employee {times(date: $date){id employee {id} start end}}}}","variables":{"date":null,"representations":[{"id":"xyz987","__typename":"Employee"}]}}}` assert.Equal(t, expected, actual) @@ -4538,8 +4537,8 @@ func TestResolver_ArenaResolveGraphQLResponse(t *testing.T) { t.Run("with variables", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), []byte(`{"id":1}`)). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), []byte(`{"id":1}`)). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"name":"Jens"}`), nil }) return &GraphQLResponse{ @@ -4582,8 +4581,8 @@ func TestResolver_ArenaResolveGraphQLResponse(t *testing.T) { t.Run("error handling", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return nil, errors.New("data source error") }) return &GraphQLResponse{ @@ -4659,8 +4658,8 @@ func TestResolver_ApolloCompatibilityMode_FetchError(t *testing.T) { t.Run("simple fetch with fetch error suppression - empty response", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte("{}"), nil }) return &GraphQLResponse{ @@ -4697,8 +4696,8 @@ func TestResolver_ApolloCompatibilityMode_FetchError(t *testing.T) { t.Run("simple fetch with fetch error suppression - response with error", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). 
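// One detail visible in the timeService expectation above: the optional $date
// variable is rendered as an explicit "date":null inside the variables object
// rather than being omitted. That stays valid because LocalTime is nullable,
// and it keeps variable rendering uniform across required and optional inputs:
//
//	query($representations: [_Any!]!, $date: LocalTime){ ... }
//	{"variables":{"date":null,"representations":[...]}}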
- Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"Cannot query field 'name' on type 'Query'"}]}`), nil }) return &GraphQLResponse{ @@ -4735,8 +4734,8 @@ func TestResolver_ApolloCompatibilityMode_FetchError(t *testing.T) { t.Run("complex fetch with fetch error suppression", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) @@ -4745,8 +4744,8 @@ func TestResolver_ApolloCompatibilityMode_FetchError(t *testing.T) { reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) @@ -4755,8 +4754,8 @@ func TestResolver_ApolloCompatibilityMode_FetchError(t *testing.T) { productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"}]}}}` assert.Equal(t, expected, actual) @@ -4946,8 +4945,8 @@ func TestResolver_WithHeader(t *testing.T) { ctrl := gomock.NewController(t) fakeService := NewMockDataSource(ctrl) fakeService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) assert.Equal(t, "foo", actual) return []byte(`{"bar":"baz"}`), nil @@ -5017,8 +5016,8 @@ func TestResolver_WithVariableRemapping(t *testing.T) { ctrl := gomock.NewController(t) fakeService := NewMockDataSource(ctrl) fakeService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
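// Since Load now receives the forwarded headers, expectations like the ones in
// TestResolver_WithHeader above can also assert on them directly. A hypothetical
// variant (the header name and value are made up for illustration):
//
//	fakeService.EXPECT().
//		Load(gomock.Any(), gomock.Any(), gomock.Any()).
//		DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
//			assert.Equal(t, "bar", headers.Get("X-Foo")) // assumed header set up on the Context
//			return []byte(`{"bar":"baz"}`), nil
//		})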
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) assert.Equal(t, tc.expectedOutput, actual) return []byte(`{"bar":"baz"}`), nil @@ -5203,16 +5202,7 @@ func (f *_fakeStream) AwaitIsDone(t *testing.T, timeout time.Duration) { } } -func (f *_fakeStream) UniqueRequestID(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) { - _, err = fmt.Fprint(xxh, fakeStreamRequestId.Add(1)) - if err != nil { - return - } - _, err = xxh.Write(input) - return -} - -func (f *_fakeStream) Start(ctx *Context, input []byte, updater SubscriptionUpdater) error { +func (f *_fakeStream) Start(ctx *Context, headers http.Header, input []byte, updater SubscriptionUpdater) error { if f.onStart != nil { f.onStart(input) } diff --git a/v2/pkg/engine/resolve/response.go b/v2/pkg/engine/resolve/response.go index b98f4c00f..c02d92f49 100644 --- a/v2/pkg/engine/resolve/response.go +++ b/v2/pkg/engine/resolve/response.go @@ -16,12 +16,13 @@ type GraphQLSubscription struct { } type GraphQLSubscriptionTrigger struct { - Input []byte - InputTemplate InputTemplate - Variables Variables - Source SubscriptionDataSource - PostProcessing PostProcessingConfiguration - QueryPlan *QueryPlan + Input []byte + InputTemplate InputTemplate + Variables Variables + Source SubscriptionDataSource + PostProcessing PostProcessingConfiguration + QueryPlan *QueryPlan + SourceName, SourceID string } // GraphQLResponse contains an ordered tree of fetches and the response shape. diff --git a/v2/pkg/engine/resolve/singleflight.go b/v2/pkg/engine/resolve/singleflight.go index e29853196..a17960249 100644 --- a/v2/pkg/engine/resolve/singleflight.go +++ b/v2/pkg/engine/resolve/singleflight.go @@ -1,7 +1,6 @@ package resolve import ( - "context" "sync" "github.com/cespare/xxhash/v2" @@ -41,8 +40,8 @@ func NewSingleFlight() *SingleFlight { } } -func (s *SingleFlight) GetOrCreateItem(ctx context.Context, fetchItem *FetchItem, input []byte) (sfKey, fetchKey uint64, item *SingleFlightItem, shared bool) { - sfKey, fetchKey = s.keys(fetchItem, input) +func (s *SingleFlight) GetOrCreateItem(fetchItem *FetchItem, input []byte, extraKey uint64) (sfKey, fetchKey uint64, item *SingleFlightItem, shared bool) { + sfKey, fetchKey = s.keys(fetchItem, input, extraKey) // First, try to get the item with a read lock s.mu.RLock() @@ -73,9 +72,9 @@ func (s *SingleFlight) GetOrCreateItem(ctx context.Context, fetchItem *FetchItem return sfKey, fetchKey, item, false } -func (s *SingleFlight) keys(fetchItem *FetchItem, input []byte) (sfKey, fetchKey uint64) { +func (s *SingleFlight) keys(fetchItem *FetchItem, input []byte, extraKey uint64) (sfKey, fetchKey uint64) { h := s.xxPool.Get().(*xxhash.Digest) - sfKey = s.sfKey(h, fetchItem, input) + sfKey = s.sfKey(h, fetchItem, input, extraKey) h.Reset() fetchKey = s.fetchKey(h, fetchItem) h.Reset() @@ -83,7 +82,7 @@ func (s *SingleFlight) keys(fetchItem *FetchItem, input []byte) (sfKey, fetchKey return sfKey, fetchKey } -func (s *SingleFlight) sfKey(h *xxhash.Digest, fetchItem *FetchItem, input []byte) uint64 { +func (s *SingleFlight) sfKey(h *xxhash.Digest, fetchItem *FetchItem, input []byte, extraKey uint64) uint64 { if fetchItem != nil && fetchItem.Fetch != nil { info := fetchItem.Fetch.FetchInfo() if info != nil { @@ -92,7 +91,7 @@ func (s *SingleFlight) sfKey(h *xxhash.Digest, fetchItem *FetchItem, input []byt } } _, _ = h.Write(input) - return h.Sum64() + return h.Sum64() + extraKey } func (s *SingleFlight) fetchKey(h *xxhash.Digest, fetchItem *FetchItem) 
uint64 {

From 26f22b33f89c94a6ec64682d870b45600bfae244 Mon Sep 17 00:00:00 2001
From: Jens Neuse
Date: Fri, 24 Oct 2025 18:56:18 +0200
Subject: [PATCH 22/61] chore: rename HeadersForSubgraphRequest to SubgraphHeadersBuilder

---
 v2/pkg/engine/resolve/context.go | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/v2/pkg/engine/resolve/context.go b/v2/pkg/engine/resolve/context.go
index b0b82f578..dd4f32e8c 100644
--- a/v2/pkg/engine/resolve/context.go
+++ b/v2/pkg/engine/resolve/context.go
@@ -33,10 +33,10 @@ type Context struct {

 	subgraphErrors error

-	SubgraphHeadersBuilder HeadersForSubgraphRequest
+	SubgraphHeadersBuilder SubgraphHeadersBuilder
 }

-type HeadersForSubgraphRequest interface {
+type SubgraphHeadersBuilder interface {
 	HeadersForSubgraph(subgraphName string) (http.Header, uint64)
 }

From 4392770d09eb34630a0e10666d693fdfdd118780 Mon Sep 17 00:00:00 2001
From: Jens Neuse
Date: Fri, 24 Oct 2025 18:56:41 +0200
Subject: [PATCH 23/61] chore: fix bug

---
 v2/pkg/engine/resolve/loader.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go
index a429087d0..826976809 100644
--- a/v2/pkg/engine/resolve/loader.go
+++ b/v2/pkg/engine/resolve/loader.go
@@ -1851,7 +1851,7 @@ func (l *Loader) compactJSON(data []byte) ([]byte, error) {
 		return nil, err
 	}
 	out := dst.Bytes()
-	v, err := astjson.ParseBytesWithArena(l.jsonArena, out)
+	v, err := astjson.ParseBytes(out)
 	if err != nil {
 		return nil, err
 	}

From 94f3d27c578ebd85761b6f97c482675c4b778b96 Mon Sep 17 00:00:00 2001
From: Jens Neuse
Date: Sat, 25 Oct 2025 13:30:24 +0200
Subject: [PATCH 24/61] chore: use arena to execute subscription updates

---
 v2/pkg/engine/resolve/resolve.go | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go
index 107f0cb79..5acfc6aad 100644
--- a/v2/pkg/engine/resolve/resolve.go
+++ b/v2/pkg/engine/resolve/resolve.go
@@ -480,7 +480,12 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar

 	t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf)

+	resolveArena := r.resolveArenaPool.Acquire(resolveCtx.Request.ID)
+	t.loader.jsonArena = resolveArena.Arena
+	t.resolvable.astjsonArena = resolveArena.Arena
+
 	if err := t.resolvable.InitSubscription(resolveCtx, input, sub.resolve.Trigger.PostProcessing); err != nil {
+		r.resolveArenaPool.Release(resolveCtx.Request.ID, resolveArena)
 		r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer)
 		if r.options.Debug {
 			fmt.Printf("resolver:trigger:subscription:init:failed:%d\n", sub.id.SubscriptionID)
@@ -492,6 +497,7 @@
 	}

 	if err := t.loader.LoadGraphQLResponseData(resolveCtx, sub.resolve.Response, t.resolvable); err != nil {
+		r.resolveArenaPool.Release(resolveCtx.Request.ID, resolveArena)
 		r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer)
 		if r.options.Debug {
 			fmt.Printf("resolver:trigger:subscription:load:failed:%d\n", sub.id.SubscriptionID)
@@ -503,6 +509,7 @@
 	}

 	if err := t.resolvable.Resolve(resolveCtx.ctx, sub.resolve.Response.Data, sub.resolve.Response.Fetches, sub.writer); err != nil {
+		r.resolveArenaPool.Release(resolveCtx.Request.ID, resolveArena)
 		r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer)
 		if r.options.Debug {
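// The hunk above follows one pattern throughout executeSubscriptionUpdate:
// acquire an arena per subscription update, share it between loader and
// resolvable, and release it on every exit path. Condensed (a sketch, with
// doWork standing in for the init/load/resolve steps):
//
//	resolveArena := r.resolveArenaPool.Acquire(resolveCtx.Request.ID)
//	t.loader.jsonArena = resolveArena.Arena
//	t.resolvable.astjsonArena = resolveArena.Arena
//	if err := doWork(); err != nil {
//		r.resolveArenaPool.Release(resolveCtx.Request.ID, resolveArena) // release before writing the error
//		return
//	}
//	r.resolveArenaPool.Release(resolveCtx.Request.ID, resolveArena) // and on success, before flushing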
fmt.Printf("resolver:trigger:subscription:resolve:failed:%d\n", sub.id.SubscriptionID) @@ -513,6 +520,8 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar return } + r.resolveArenaPool.Release(resolveCtx.Request.ID, resolveArena) + if err := sub.writer.Flush(); err != nil { // If flush fails (e.g. client disconnected), remove the subscription. _ = r.AsyncUnsubscribeSubscription(sub.id) From e7407d1fd2a3023989eb572fe30ef9bad4d694d5 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sat, 25 Oct 2025 15:36:20 +0200 Subject: [PATCH 25/61] chore: merge main --- v2/go.sum | 2 -- 1 file changed, 2 deletions(-) diff --git a/v2/go.sum b/v2/go.sum index 690d15a88..5a7781e3a 100644 --- a/v2/go.sum +++ b/v2/go.sum @@ -134,8 +134,6 @@ github.com/urfave/cli/v2 v2.27.7 h1:bH59vdhbjLv3LAvIu6gd0usJHgoTTPhCFib8qqOwXYU= github.com/urfave/cli/v2 v2.27.7/go.mod h1:CyNAG/xg+iAOg0N4MPGZqVmv2rCoP267496AOXUZjA4= github.com/vektah/gqlparser/v2 v2.5.30 h1:EqLwGAFLIzt1wpx1IPpY67DwUujF1OfzgEyDsLrN6kE= github.com/vektah/gqlparser/v2 v2.5.30/go.mod h1:D1/VCZtV3LPnQrcPBeR/q5jkSQIPti0uYCP/RI0gIeo= -github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTBjW+SZK4mhxTTBVpxcqeBgWF1Rfmltbfk= -github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAzt5X7s6266i6cSVkkFPS0TuXWbIg= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= From 60b5c3b390d0af3e1e54a71fd669138744e72689 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sat, 25 Oct 2025 20:59:31 +0200 Subject: [PATCH 26/61] chore: update deps --- v2/go.mod | 9 ++------- v2/go.sum | 4 ++++ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/v2/go.mod b/v2/go.mod index 308eea534..43ada453b 100644 --- a/v2/go.mod +++ b/v2/go.mod @@ -28,8 +28,8 @@ require ( github.com/tidwall/gjson v1.17.0 github.com/tidwall/sjson v1.2.5 github.com/vektah/gqlparser/v2 v2.5.30 - github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 - github.com/wundergraph/go-arena v0.0.1 + github.com/wundergraph/astjson v1.0.0 + github.com/wundergraph/go-arena v1.0.0 go.uber.org/atomic v1.11.0 go.uber.org/goleak v1.3.0 go.uber.org/zap v1.26.0 @@ -79,8 +79,3 @@ require ( ) tool github.com/99designs/gqlgen - -replace ( - github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 => ../../wundergraph-projects/astjson - github.com/wundergraph/go-arena v0.0.1 => ../../wundergraph-projects/go-arena -) diff --git a/v2/go.sum b/v2/go.sum index 5a7781e3a..6d0fb3636 100644 --- a/v2/go.sum +++ b/v2/go.sum @@ -134,6 +134,10 @@ github.com/urfave/cli/v2 v2.27.7 h1:bH59vdhbjLv3LAvIu6gd0usJHgoTTPhCFib8qqOwXYU= github.com/urfave/cli/v2 v2.27.7/go.mod h1:CyNAG/xg+iAOg0N4MPGZqVmv2rCoP267496AOXUZjA4= github.com/vektah/gqlparser/v2 v2.5.30 h1:EqLwGAFLIzt1wpx1IPpY67DwUujF1OfzgEyDsLrN6kE= github.com/vektah/gqlparser/v2 v2.5.30/go.mod h1:D1/VCZtV3LPnQrcPBeR/q5jkSQIPti0uYCP/RI0gIeo= +github.com/wundergraph/astjson v1.0.0 h1:rETLJuQkMWWW03HCF6WBttEBOu8gi5vznj5KEUPVV2Q= +github.com/wundergraph/astjson v1.0.0/go.mod h1:h12D/dxxnedtLzsKyBLK7/Oe4TAoGpRVC9nDpDrZSWw= +github.com/wundergraph/go-arena v1.0.0 h1:RVYWpDkJ1/6851BRHYehBeEcTLKmZygYIZsvBorcOjw= +github.com/wundergraph/go-arena v1.0.0/go.mod h1:ROOysEHWJjLQ8FSfNxZCziagb7Qw2nXY3/vgKRh7eWw= github.com/xrash/smetrics 
v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAzt5X7s6266i6cSVkkFPS0TuXWbIg= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= From 3fb0272893d828d5a574d43396e8278b250a5ef8 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sat, 25 Oct 2025 21:45:26 +0200 Subject: [PATCH 27/61] chore: add comments --- .../graphql_subscription_client_test.go | 2 +- .../datasource/httpclient/nethttpclient.go | 14 ++++++ v2/pkg/engine/resolve/context.go | 7 +++ v2/pkg/engine/resolve/loader.go | 49 ++++++++++++++++--- v2/pkg/engine/resolve/resolvable.go | 7 ++- v2/pkg/engine/resolve/resolve.go | 9 +++- 6 files changed, 78 insertions(+), 10 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index 25eaa29f7..86dd57c03 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -2437,7 +2437,7 @@ func TestWebSocketUpgradeFailures(t *testing.T) { w.Header().Set(key, value) } w.WriteHeader(tc.statusCode) - fmt.Fprintf(w, `{"error": "WebSocket upgrade failed", "status": %d}`, tc.statusCode) + _, _ = fmt.Fprintf(w, `{"error": "WebSocket upgrade failed", "status": %d}`, tc.statusCode) })) defer server.Close() diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index 4c4f2de3d..c4ce9915f 100644 --- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -136,6 +136,9 @@ const ( sizeHintKey httpClientContext = "size-hint" ) +// WithHTTPClientSizeHint allows the engine to keep track of response sizes per subgraph fetch +// If a hint is supplied, we can create a buffer of size close to the required size +// This reduces allocations by reducing the buffer grow calls, which always copies the buffer func WithHTTPClientSizeHint(ctx context.Context, size int) context.Context { return context.WithValue(ctx, sizeHintKey, size) } @@ -144,6 +147,9 @@ func buffer(ctx context.Context) *bytes.Buffer { if sizeHint, ok := ctx.Value(sizeHintKey).(int); ok && sizeHint > 0 { return bytes.NewBuffer(make([]byte, 0, sizeHint)) } + // if we start with zero, doubling will take a while until we reach the required size + // if we start with a high number, e.g. 1024, we just increase the memory usage of the engine + // 64 seems to be a healthy middle ground return bytes.NewBuffer(make([]byte, 0, 64)) } @@ -211,6 +217,8 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, baseHeaders http. request.Header.Set(AcceptEncodingHeader, EncodingGzip) request.Header.Add(AcceptEncodingHeader, EncodingDeflate) if contentLength > 0 { + // always set the Content-Length Header so that chunking can be avoided + // and other parties can more efficiently parse request.Header.Set(ContentLengthHeader, fmt.Sprintf("%d", contentLength)) } @@ -229,6 +237,12 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, baseHeaders http. 
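// The size hint above pays for itself in buffer growth: for a 16 KiB response,
// a bytes.Buffer that starts at 64 bytes roughly doubles (and copies) eight
// times (64 -> 128 -> ... -> 16384), while an accurate hint needs a single
// allocation. Caller-side usage, as the loader wires it up later in this
// series (sizeHint being the size of the previous response for this fetch):
//
//	ctx = httpclient.WithHTTPClientSizeHint(ctx, sizeHint)
//	out := buffer(ctx) // pre-sized to the hint, else 64 bytes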
 		return nil, err
 	}

+	// we intentionally don't use a pool of sorts here
+	// we're buffering the response and then later, in the engine,
+	// parsing it into a JSON AST with the use of an arena, which is quite efficient
+	// Through trial and error it turned out that it's best to leave this buffer to the GC
+	// It'll know best the lifecycle of the buffer
+	// Using an arena here just increased overall memory usage
 	out := buffer(ctx)
 	_, err = out.ReadFrom(respReader)
 	if err != nil {
diff --git a/v2/pkg/engine/resolve/context.go b/v2/pkg/engine/resolve/context.go
index dd4f32e8c..fdb2ebb58 100644
--- a/v2/pkg/engine/resolve/context.go
+++ b/v2/pkg/engine/resolve/context.go
@@ -36,10 +36,17 @@ type Context struct {

 	SubgraphHeadersBuilder SubgraphHeadersBuilder
 }

+// SubgraphHeadersBuilder allows the user of the engine to "define" the headers for a subgraph request
+// Instead of going back and forth between engine & transport,
+// you can simply define a function that returns headers for a Subgraph request
+// In addition to just the header, the implementer can return a hash for the header which will be used by request deduplication
 type SubgraphHeadersBuilder interface {
+	// HeadersForSubgraph must return the headers and a hash for a Subgraph Request
+	// The hash will be used for request deduplication
 	HeadersForSubgraph(subgraphName string) (http.Header, uint64)
 }

+// HeadersForSubgraphRequest returns headers and a hash for a request that the engine will make to a subgraph
 func (c *Context) HeadersForSubgraphRequest(subgraphName string) (http.Header, uint64) {
 	if c.SubgraphHeadersBuilder == nil {
 		return nil, 0
diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go
index 826976809..4b22dbbcf 100644
--- a/v2/pkg/engine/resolve/loader.go
+++ b/v2/pkg/engine/resolve/loader.go
@@ -137,8 +137,9 @@ type result struct {
 	loaderHookContext   context.Context
 	httpResponseContext *httpclient.ResponseContext

-	out []byte
-	singleFlightStats *singleFlightStats
+	// out is the subgraph response body
+	out               []byte
+	singleFlightStats *singleFlightStats
 }

 func (r *result) init(postProcessing PostProcessingConfiguration, info *FetchInfo) {
@@ -182,6 +183,14 @@ type Loader struct {

 	taintedObjs taintedObjects

+	// jsonArena is the arena used to allocate json, supplied by the Resolver
+	// Disclaimer: this arena is NOT thread safe!
+	// Only use from main goroutine
+	// Don't Reset or Release, the Resolver handles this
+	// Disclaimer: When parsing json into the arena, the underlying bytes must also be allocated on the arena!
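// A concrete reading of the disclaimer above (a sketch; ParseBytesWithArena is
// the arena-aware parser seen elsewhere in this series, and arenaBytes stands
// in for a byte slice that was itself allocated on the same arena a):
//
//	v, err := astjson.ParseBytesWithArena(a, arenaBytes)
//	// v keeps references into arenaBytes; if arenaBytes lived on the heap and
//	// were freed or the arena reset independently, those references would dangle
//
// which is what the following lines mean by tying the lifecycles together.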
+ // This is very important to "tie" their lifecycles together + // If you're not doing this, you will see segfaults + // Example of correct usage in func "mergeResult" jsonArena arena.Arena sf *SingleFlight } @@ -773,9 +782,11 @@ func (l *Loader) mergeErrors(res *result, fetchItem *FetchItem, value *astjson.V return err } } - - // If the error propagation mode is pass-through, we append the errors to the root array + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it l.resolvable.ensureErrorsInitialized() + // If the error propagation mode is pass-through, we append the errors to the root array l.resolvable.errors.AppendArrayItems(value) return nil } @@ -811,7 +822,9 @@ func (l *Loader) mergeErrors(res *result, fetchItem *FetchItem, value *astjson.V if err := l.addApolloRouterCompatibilityError(res); err != nil { return err } - + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) @@ -1066,7 +1079,9 @@ func (l *Loader) addApolloRouterCompatibilityError(res *result) error { if err != nil { return err } - + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, apolloRouterStatusError) @@ -1081,6 +1096,9 @@ func (l *Loader) renderErrorsFailedDeps(fetchItem *FetchItem, res *result) error return err } l.setSubgraphStatusCode([]*astjson.Value{errorObject}, res.statusCode) + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil @@ -1093,6 +1111,9 @@ func (l *Loader) renderErrorsFailedToFetch(fetchItem *FetchItem, res *result, re return err } l.setSubgraphStatusCode([]*astjson.Value{errorObject}, res.statusCode) + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil @@ -1112,6 +1133,9 @@ func (l *Loader) renderErrorsStatusFallback(fetchItem *FetchItem, res *result, s } l.setSubgraphStatusCode([]*astjson.Value{errorObject}, res.statusCode) + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil @@ -1137,6 +1161,9 @@ func (l *Loader) renderAuthorizationRejectedErrors(fetchItem *FetchItem, res *re } pathPart := l.renderAtPathErrorPart(fetchItem.ResponsePath) extensionErrorCode := fmt.Sprintf(`"extensions":{"code":"%s"}`, errorcodes.UnauthorizedFieldOrType) + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it l.resolvable.ensureErrorsInitialized() if 
res.ds.Name == "" { for _, reason := range res.authorizationRejectedReasons { @@ -1216,6 +1243,9 @@ func (l *Loader) renderRateLimitRejectedErrors(fetchItem *FetchItem, res *result return err } } + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil @@ -1417,7 +1447,7 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, } } } - + // I tried using arena here but it only worsened the situation preparedInput := bytes.NewBuffer(make([]byte, 0, 64)) itemInput := bytes.NewBuffer(make([]byte, 0, 32)) keyGen := pool.Hash64.Get() @@ -1579,6 +1609,7 @@ const ( operationTypeContextKey loaderContextKey = "operationType" ) +// GetOperationTypeFromContext can be used, e.g. by the transport, to check if the operation is a Mutation func GetOperationTypeFromContext(ctx context.Context) ast.OperationType { if ctx == nil { return ast.OperationTypeQuery @@ -1638,6 +1669,7 @@ func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem return nil } + // helps the http client to create buffers at the right size ctx = httpclient.WithHTTPClientSizeHint(ctx, item.sizeHint) defer l.sf.Finish(sfKey, fetchKey, item) @@ -1851,6 +1883,9 @@ func (l *Loader) compactJSON(data []byte) ([]byte, error) { return nil, err } out := dst.Bytes() + // don't use arena here or segfault + // it's also not a hot path and not important to optimize + // arena requires the parsed content to be on the arena as well v, err := astjson.ParseBytes(out) if err != nil { return nil, err diff --git a/v2/pkg/engine/resolve/resolvable.go b/v2/pkg/engine/resolve/resolvable.go index 21470f475..cbd1df5ea 100644 --- a/v2/pkg/engine/resolve/resolvable.go +++ b/v2/pkg/engine/resolve/resolvable.go @@ -31,7 +31,8 @@ type Resolvable struct { errors *astjson.Value valueCompletion *astjson.Value skipAddingNullErrors bool - + // astjsonArena is the arena to handle json, supplied by Resolver + // not thread safe, but Resolvable is single threaded anyways astjsonArena arena.Arena parsers []*astjson.Parser @@ -111,6 +112,7 @@ func (r *Resolvable) Init(ctx *Context, initialData []byte, operationType ast.Op r.operationType = operationType r.renameTypeNames = ctx.RenameTypeNames r.data = astjson.ObjectValue(r.astjsonArena) + // don't init errors! It will heavily increase memory usage r.errors = nil if initialData != nil { initialValue, err := astjson.ParseBytesWithArena(r.astjsonArena, initialData) @@ -129,6 +131,7 @@ func (r *Resolvable) InitSubscription(ctx *Context, initialData []byte, postProc r.ctx = ctx r.operationType = ast.OperationTypeSubscription r.renameTypeNames = ctx.RenameTypeNames + // don't init errors! It will heavily increase memory usage r.errors = nil if initialData != nil { initialValue, err := astjson.ParseBytesWithArena(r.astjsonArena, initialData) @@ -167,6 +170,7 @@ func (r *Resolvable) ResolveNode(node Node, data *astjson.Value, out io.Writer) r.print = false r.printErr = nil r.authorizationError = nil + // don't init errors! 
It will heavily increase memory usage r.errors = nil hasErrors := r.walkNode(node, data) @@ -233,6 +237,7 @@ func (r *Resolvable) Resolve(ctx context.Context, rootData *Object, fetchTree *F return r.printErr } +// ensureErrorsInitialized is used to lazily init r.errors if needed func (r *Resolvable) ensureErrorsInitialized() { if r.errors == nil { r.errors = astjson.ArrayValue(r.astjsonArena) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 5acfc6aad..39b1a3bea 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -72,7 +72,13 @@ type Resolver struct { // maxSubscriptionFetchTimeout defines the maximum time a subscription fetch can take before it is considered timed out maxSubscriptionFetchTimeout time.Duration - resolveArenaPool *ArenaPool + // resolveArenaPool is the arena pool dedicated for Loader & Resolvable + // ArenaPool automatically adjusts arena buffer sizes per workload + // resolving & response buffering are very different tasks + // as such, it was best to have two arena pools in terms of memory usage + // A single pool for both was much less efficient + resolveArenaPool *ArenaPool + // responseBufferPool is the arena pool dedicated for response buffering before sending to the client responseBufferPool *ArenaPool // Single flight cache for deduplicating requests across all loaders @@ -246,6 +252,7 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{}, allowedErrorFields map[string]struct{}, sf *SingleFlight) *tools { return &tools{ + // we set the arena manually resolvable: NewResolvable(nil, options.ResolvableOptions), loader: &Loader{ propagateSubgraphErrors: options.PropagateSubgraphErrors, From bb33b4b527ada0a4622d1aa1a63927448595d9e7 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sat, 25 Oct 2025 21:48:27 +0200 Subject: [PATCH 28/61] chore: set content length correctly --- v2/pkg/engine/datasource/httpclient/nethttpclient.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index c4ce9915f..c5f53c7e0 100644 --- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -219,7 +219,7 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, baseHeaders http. 
if contentLength > 0 { // always set the Content-Length Header so that chunking can be avoided // and other parties can more efficiently parse - request.Header.Set(ContentLengthHeader, fmt.Sprintf("%d", contentLength)) + request.ContentLength = int64(contentLength) } setRequest(ctx, request) From bb31735c6849ba3340a0ded2440450d6d1ea84a4 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sat, 25 Oct 2025 22:12:22 +0200 Subject: [PATCH 29/61] chore: fix bench --- v2/pkg/engine/resolve/loader_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/v2/pkg/engine/resolve/loader_test.go b/v2/pkg/engine/resolve/loader_test.go index d6c002393..f88d7227f 100644 --- a/v2/pkg/engine/resolve/loader_test.go +++ b/v2/pkg/engine/resolve/loader_test.go @@ -1026,7 +1026,7 @@ func BenchmarkLoader_LoadGraphQLResponseData(b *testing.B) { } resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} - expected := `{"errors":[],"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` + expected := `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` b.SetBytes(int64(len(expected))) b.ReportAllocs() b.ResetTimer() From ce83a7b763be51b37445310919d2d7241b96fa7e Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sat, 25 Oct 2025 22:18:35 +0200 Subject: [PATCH 30/61] chore: fix lint --- v2/pkg/engine/resolve/inputtemplate.go | 12 +++++++++--- v2/pkg/engine/resolve/loader.go | 3 +-- v2/pkg/engine/resolve/resolvable.go | 3 +-- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/v2/pkg/engine/resolve/inputtemplate.go b/v2/pkg/engine/resolve/inputtemplate.go index 80db3cdd8..0ad72ec94 100644 --- a/v2/pkg/engine/resolve/inputtemplate.go +++ b/v2/pkg/engine/resolve/inputtemplate.go @@ -158,14 +158,20 @@ func (i *InputTemplate) renderHeaderVariable(ctx *Context, path []string, prepar return nil } if len(value) == 1 { - preparedInput.WriteString(value[0]) + if _, err := preparedInput.WriteString(value[0]); err != nil { + return err + } return nil } for j := range value { if j != 0 { - _, _ = preparedInput.Write(literal.COMMA) + if _, err := preparedInput.Write(literal.COMMA); err != nil { + return err + } + } + if _, err := preparedInput.WriteString(value[j]); err != nil { + return err } - preparedInput.WriteString(value[j]) } return nil } diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 4b22dbbcf..73cef311f 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -17,16 +17,15 
@@ import ( "github.com/pkg/errors" "github.com/tidwall/gjson" "github.com/tidwall/sjson" - "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" "golang.org/x/sync/errgroup" "github.com/wundergraph/astjson" "github.com/wundergraph/go-arena" - "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/errorcodes" "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/unsafebytes" + "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) const ( diff --git a/v2/pkg/engine/resolve/resolvable.go b/v2/pkg/engine/resolve/resolvable.go index cbd1df5ea..5879396e7 100644 --- a/v2/pkg/engine/resolve/resolvable.go +++ b/v2/pkg/engine/resolve/resolvable.go @@ -11,10 +11,9 @@ import ( "github.com/cespare/xxhash/v2" "github.com/pkg/errors" "github.com/tidwall/gjson" - "github.com/wundergraph/go-arena" "github.com/wundergraph/astjson" - + "github.com/wundergraph/go-arena" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/errorcodes" "github.com/wundergraph/graphql-go-tools/v2/pkg/fastjsonext" From 5cfd72d8da0d3074ab4ffc2d139ec71d706da2bc Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sat, 25 Oct 2025 22:32:35 +0200 Subject: [PATCH 31/61] chore: fix lint --- v2/pkg/engine/datasource/httpclient/nethttpclient.go | 2 +- v2/pkg/engine/resolve/loader.go | 1 + v2/pkg/engine/resolve/resolvable.go | 1 + v2/pkg/engine/resolve/resolve.go | 3 ++- 4 files changed, 5 insertions(+), 2 deletions(-) diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index c5f53c7e0..46af845e4 100644 --- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -217,7 +217,7 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, baseHeaders http. 
request.Header.Set(AcceptEncodingHeader, EncodingGzip) request.Header.Add(AcceptEncodingHeader, EncodingDeflate) if contentLength > 0 { - // always set the Content-Length Header so that chunking can be avoided + // always set the ContentLength field so that chunking can be avoided // and other parties can more efficiently parse request.ContentLength = int64(contentLength) } diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 73cef311f..340c41894 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -21,6 +21,7 @@ import ( "github.com/wundergraph/astjson" "github.com/wundergraph/go-arena" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/errorcodes" diff --git a/v2/pkg/engine/resolve/resolvable.go b/v2/pkg/engine/resolve/resolvable.go index 5879396e7..226705a70 100644 --- a/v2/pkg/engine/resolve/resolvable.go +++ b/v2/pkg/engine/resolve/resolvable.go @@ -14,6 +14,7 @@ import ( "github.com/wundergraph/astjson" "github.com/wundergraph/go-arena" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/errorcodes" "github.com/wundergraph/graphql-go-tools/v2/pkg/fastjsonext" diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 39b1a3bea..2f7bec660 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -12,9 +12,10 @@ import ( "github.com/buger/jsonparser" "github.com/pkg/errors" - "github.com/wundergraph/go-arena" "go.uber.org/atomic" + "github.com/wundergraph/go-arena" + "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/xcontext" "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) From 4d4b4c5f1679eed3e5761a596835d2409f1ace1f Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sun, 26 Oct 2025 08:55:27 +0100 Subject: [PATCH 32/61] chore: cleanup & comments --- v2/pkg/engine/resolve/resolve.go | 21 ++++++------ v2/pkg/engine/resolve/response.go | 15 +++++---- v2/pkg/engine/resolve/singleflight.go | 46 +++++++++++++++++++++------ 3 files changed, 54 insertions(+), 28 deletions(-) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 2f7bec660..3420e9327 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -251,10 +251,9 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { return resolver } -func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{}, allowedErrorFields map[string]struct{}, sf *SingleFlight) *tools { +func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{}, allowedErrorFields map[string]struct{}, sf *SingleFlight, a arena.Arena) *tools { return &tools{ - // we set the arena manually - resolvable: NewResolvable(nil, options.ResolvableOptions), + resolvable: NewResolvable(a, options.ResolvableOptions), loader: &Loader{ propagateSubgraphErrors: options.PropagateSubgraphErrors, propagateSubgraphStatusCodes: options.PropagateSubgraphStatusCodes, @@ -271,6 +270,7 @@ func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{ propagateFetchReasons: options.PropagateFetchReasons, validateRequiredExternalFields: options.ValidateRequiredExternalFields, sf: sf, + jsonArena: a, }, } } @@ -289,7 +289,7 @@ func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLRespons r.maxConcurrency <- struct{}{} }() - t := 
newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf, nil) err := t.resolvable.Init(ctx, data, response.Info.OperationType) if err != nil { @@ -321,9 +321,9 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe r.maxConcurrency <- struct{}{} }() - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf) - resolveArena := r.resolveArenaPool.Acquire(ctx.Request.ID) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf, resolveArena.Arena) + t.loader.jsonArena = resolveArena.Arena t.resolvable.astjsonArena = resolveArena.Arena @@ -486,11 +486,8 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar input := make([]byte, len(sharedInput)) copy(input, sharedInput) - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf) - resolveArena := r.resolveArenaPool.Acquire(resolveCtx.Request.ID) - t.loader.jsonArena = resolveArena.Arena - t.resolvable.astjsonArena = resolveArena.Arena + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf, resolveArena.Arena) if err := t.resolvable.InitSubscription(resolveCtx, input, sub.resolve.Trigger.PostProcessing); err != nil { r.resolveArenaPool.Release(resolveCtx.Request.ID, resolveArena) @@ -1097,7 +1094,7 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ // If SkipLoader is enabled, we skip retrieving actual data. For example, this is useful when requesting a query plan. // By returning early, we avoid starting a subscription and resolve with empty data instead. if ctx.ExecutionOptions.SkipLoader { - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf, nil) err = t.resolvable.InitSubscription(ctx, nil, subscription.Trigger.PostProcessing) if err != nil { @@ -1207,7 +1204,7 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G // If SkipLoader is enabled, we skip retrieving actual data. For example, this is useful when requesting a query plan. // By returning early, we avoid starting a subscription and resolve with empty data instead. if ctx.ExecutionOptions.SkipLoader { - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf, nil) err = t.resolvable.InitSubscription(ctx, nil, subscription.Trigger.PostProcessing) if err != nil { diff --git a/v2/pkg/engine/resolve/response.go b/v2/pkg/engine/resolve/response.go index c02d92f49..1efe078cc 100644 --- a/v2/pkg/engine/resolve/response.go +++ b/v2/pkg/engine/resolve/response.go @@ -16,13 +16,14 @@ type GraphQLSubscription struct { } type GraphQLSubscriptionTrigger struct { - Input []byte - InputTemplate InputTemplate - Variables Variables - Source SubscriptionDataSource - PostProcessing PostProcessingConfiguration - QueryPlan *QueryPlan - SourceName, SourceID string + Input []byte + InputTemplate InputTemplate + Variables Variables + Source SubscriptionDataSource + PostProcessing PostProcessingConfiguration + QueryPlan *QueryPlan + SourceName string + SourceID string } // GraphQLResponse contains an ordered tree of fetches and the response shape. 
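The singleflight.go diff below documents a leader/follower protocol around GetOrCreateItem and Finish: the first caller for a key becomes the leader and does the fetch, everyone else blocks until the leader publishes the shared result. As a rough, self-contained sketch of that flow (simplified stand-in names and types, not the engine's actual API):

package main

import (
	"fmt"
	"sync"
)

// item mirrors the role of SingleFlightItem in the diff below:
// the leader closes loaded once the work is done, followers block
// on it and then read the shared result.
type item struct {
	loaded   chan struct{}
	response []byte
	err      error
}

type flight struct {
	mu    sync.Mutex
	items map[uint64]*item
}

// getOrCreate returns shared == false for the leader, which must do
// the work and call finish, and shared == true for followers.
func (f *flight) getOrCreate(key uint64) (it *item, shared bool) {
	f.mu.Lock()
	defer f.mu.Unlock()
	if existing, ok := f.items[key]; ok {
		return existing, true
	}
	it = &item{loaded: make(chan struct{})}
	f.items[key] = it
	return it, false
}

// finish publishes the result to all followers by closing the channel
// and removes the key so later requests start a fresh flight.
func (f *flight) finish(key uint64, it *item) {
	close(it.loaded)
	f.mu.Lock()
	delete(f.items, key)
	f.mu.Unlock()
}

func main() {
	f := &flight{items: map[uint64]*item{}}
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			it, shared := f.getOrCreate(42)
			if !shared {
				// the leader performs the fetch exactly once
				it.response = []byte(`{"data":{}}`)
				f.finish(42, it)
			}
			<-it.loaded // no-op for the leader, blocks followers until done
			// err must always be checked; response must never be mutated
			fmt.Println(it.err, string(it.response))
		}()
	}
	wg.Wait()
}

The real implementation additionally tracks size hints per fetchKey, as documented below, so a leader can pre-allocate an appropriately sized buffer for the fetch.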
diff --git a/v2/pkg/engine/resolve/singleflight.go b/v2/pkg/engine/resolve/singleflight.go
index a17960249..76121d98e 100644
--- a/v2/pkg/engine/resolve/singleflight.go
+++ b/v2/pkg/engine/resolve/singleflight.go
@@ -6,13 +6,6 @@ import (
 	"github.com/cespare/xxhash/v2"
 )
 
-type SingleFlightItem struct {
-	loaded   chan struct{}
-	response []byte
-	err      error
-	sizeHint int
-}
-
 type SingleFlight struct {
 	mu    *sync.RWMutex
 	items map[uint64]*SingleFlightItem
@@ -21,8 +14,26 @@ type SingleFlight struct {
 	cleanup chan func()
 }
 
+// SingleFlightItem is used to communicate between leader and followers
+// If an Item for a key doesn't exist, the leader creates it and followers can join
+type SingleFlightItem struct {
+	// loaded will be closed by the leader to indicate to followers when the work is done
+	loaded chan struct{}
+	// response is the shared result; it must not be modified
+	response []byte
+	// err is non-nil if the leader produced an error while doing the work
+	err error
+	// sizeHint keeps track of the last 50 responses per fetchKey to give an estimate of the size
+	// this gives a leader a hint on how much space it should pre-allocate for buffers when fetching
+	// this reduces memory usage
+	sizeHint int
+}
+
+// fetchSize gives an estimate of the required buffer size for a given fetchKey by dividing totalBytes / count
 type fetchSize struct {
-	count      int
+	// count is the number of fetches tracked
+	count int
+	// totalBytes is the cumulative bytes across tracked fetches
 	totalBytes int
 }
 
@@ -40,6 +51,13 @@ func NewSingleFlight() *SingleFlight {
 	}
 }
 
+// GetOrCreateItem generates a single flight key (100% identical fetches) and a fetchKey (similar fetches, collisions possible but unproblematic)
+// and returns a SingleFlightItem as well as an indication of whether it's shared
+// If shared == false, the caller is a leader
+// If shared == true, the caller is a follower
+// item.sizeHint can be used to create an optimal buffer for the fetch in case of a leader
+// item.err must always be checked
+// item.response must never be mutated
 func (s *SingleFlight) GetOrCreateItem(fetchItem *FetchItem, input []byte, extraKey uint64) (sfKey, fetchKey uint64, item *SingleFlightItem, shared bool) {
 	sfKey, fetchKey = s.keys(fetchItem, input, extraKey)
 
@@ -62,6 +80,7 @@ func (s *SingleFlight) GetOrCreateItem(fetchItem *FetchItem, input []byte, extra
 
 	// Create a new item
 	item = &SingleFlightItem{
+		// empty chan to indicate to all followers when we're done (close)
 		loaded: make(chan struct{}),
 	}
 	if size, ok := s.sizes[fetchKey]; ok {
@@ -82,6 +101,8 @@ func (s *SingleFlight) keys(fetchItem *FetchItem, input []byte, extraKey uint64)
 	return sfKey, fetchKey
 }
 
+// sfKey returns a key that 100% uniquely identifies a fetch with no collision
+// two sfKeys are only the same when the fetches are 100% equal
 func (s *SingleFlight) sfKey(h *xxhash.Digest, fetchItem *FetchItem, input []byte, extraKey uint64) uint64 {
 	if fetchItem != nil && fetchItem.Fetch != nil {
 		info := fetchItem.Fetch.FetchInfo()
@@ -91,9 +112,13 @@ func (s *SingleFlight) sfKey(h *xxhash.Digest, fetchItem *FetchItem, input []byt
 		}
 	}
 	_, _ = h.Write(input)
-	return h.Sum64()
+	return h.Sum64() + extraKey // extraKey in this case is the pre-generated hash for the headers
 }
 
+// fetchKey is a less robust key compared to sfKey
+// the purpose is to create a key from the DataSourceID and root fields to have lower cardinality
+// the goal is to get an estimated buffer size for similar fetches
+// there's no point in hashing headers or the body for this purpose
 func (s *SingleFlight) fetchKey(h *xxhash.Digest, fetchItem *FetchItem) uint64 {
 	if fetchItem == nil || fetchItem.Fetch == nil {
 		return 0
 	}
@@ -115,6 +140,9 @@ func (s *SingleFlight) fetchKey(h *xxhash.Digest, fetchItem *FetchItem) uint64 {
 	return h.Sum64()
 }
 
+// Finish is for the leader to mark the SingleFlightItem as "done",
+// to trigger all followers to look at the err & response of the item,
+// and to update the size estimates
 func (s *SingleFlight) Finish(sfKey, fetchKey uint64, item *SingleFlightItem) {
 	close(item.loaded)
 	s.mu.Lock()

From 48de6512dede3af421afcf66cceb00ac30e74763 Mon Sep 17 00:00:00 2001
From: Jens Neuse
Date: Sun, 26 Oct 2025 10:13:09 +0100
Subject: [PATCH 33/61] chore: refactor

---
 v2/pkg/engine/resolve/resolve.go | 62 ++++++++++++++++----------------
 1 file changed, 32 insertions(+), 30 deletions(-)

diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go
index 3420e9327..b5e3ff14b 100644
--- a/v2/pkg/engine/resolve/resolve.go
+++ b/v2/pkg/engine/resolve/resolve.go
@@ -322,11 +322,9 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe
 	}()
 
 	resolveArena := r.resolveArenaPool.Acquire(ctx.Request.ID)
+	// we're intentionally not using defer Release to have more control over the timing (see below)
 	t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf, resolveArena.Arena)
-	t.loader.jsonArena = resolveArena.Arena
-	t.resolvable.astjsonArena = resolveArena.Arena
-
 	err := t.resolvable.Init(ctx, nil, response.Info.OperationType)
 	if err != nil {
 		r.resolveArenaPool.Release(ctx.Request.ID, resolveArena)
@@ -341,6 +339,7 @@
 		}
 	}
 
+	// only when loading is done, acquire an arena for the response buffer
 	responseArena := r.responseBufferPool.Acquire(ctx.Request.ID)
 	buf := arena.NewArenaBuffer(responseArena.Arena)
 	err = t.resolvable.Resolve(ctx.ctx, response.Data, response.Fetches, buf)
@@ -350,8 +349,16 @@
 	}
 
+	// first release resolveArena
+	// all data is resolved and written into the response arena
 	r.resolveArenaPool.Release(ctx.Request.ID, resolveArena)
+	// next we write back to the client
+	// this includes flushing and syscalls
+	// as such, it can take some time
+	// which is why we split the arenas and release the first one early
 	_, err = writer.Write(buf.Bytes())
+	// all data is written to the client
+	// we're safe to release our buffer
 	r.responseBufferPool.Release(ctx.Request.ID, responseArena)
 	return resp, err
 }
@@ -722,16 +729,14 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription)
 		asyncDataSource = async
 	}
 
-	headers, _ := r.triggerHeaders(add.ctx, add.sourceName)
-
 	go func() {
 		if r.options.Debug {
 			fmt.Printf("resolver:trigger:start:%d\n", triggerID)
 		}
 		if asyncDataSource != nil {
-			err = asyncDataSource.AsyncStart(cloneCtx, triggerID, headers, add.input, updater)
+			err = asyncDataSource.AsyncStart(cloneCtx, triggerID, add.headers, add.input, updater)
 		} else {
-			err = add.resolve.Trigger.Source.Start(cloneCtx, headers, add.input, updater)
+			err = add.resolve.Trigger.Source.Start(cloneCtx, add.headers, add.input, updater)
 		}
 		if err != nil {
 			if r.options.Debug {
@@ -1074,9 +1079,17 @@ func (r *Resolver) AsyncUnsubscribeClient(connectionID int64) error {
 	return nil
 }
 
-func (r *Resolver) triggerHeaders(ctx *Context, sourceName string) (http.Header, uint64) {
+// prepareTrigger safely gets the headers for the trigger Subgraph and computes the hash across headers and input
+// the generated hash is the unique triggerID
+// the headers must be forwarded to the DataSource to create the trigger
+func (r *Resolver) prepareTrigger(ctx *Context, sourceName string, input []byte) (headers http.Header, triggerID uint64) {
 	if ctx.SubgraphHeadersBuilder != nil {
-		return ctx.SubgraphHeadersBuilder.HeadersForSubgraph(sourceName)
+		header, headerHash := ctx.SubgraphHeadersBuilder.HeadersForSubgraph(sourceName)
+		keyGen := pool.Hash64.Get()
+		_, _ = keyGen.Write(input)
+		triggerID = keyGen.Sum64() + headerHash
+		pool.Hash64.Put(keyGen)
+		return header, triggerID
 	}
 	return nil, 0
 }
@@ -1118,20 +1131,13 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ
 		return nil
 	}
 
-	_, headersHash := r.triggerHeaders(ctx, subscription.Trigger.SourceName)
-
-	xxh := pool.Hash64.Get()
-	_, _ = xxh.Write(input)
-	// the hash for subgraph headers is pre-computed
-	// we can just add it to the input hash to get a unique id
-	uniqueID := xxh.Sum64() + headersHash
-	pool.Hash64.Put(xxh)
+	headers, triggerID := r.prepareTrigger(ctx, subscription.Trigger.SourceName, input)
 	id := SubscriptionIdentifier{
 		ConnectionID:   ConnectionIDs.Inc(),
 		SubscriptionID: 0,
 	}
 	if r.options.Debug {
-		fmt.Printf("resolver:trigger:subscribe:sync:%d:%d\n", uniqueID, id.SubscriptionID)
+		fmt.Printf("resolver:trigger:subscribe:sync:%d:%d\n", triggerID, id.SubscriptionID)
 	}
 	completed := make(chan struct{})
@@ -1141,7 +1147,7 @@
 		// Stop processing if the resolver is shutting down
 		return r.ctx.Err()
 	case r.events <- subscriptionEvent{
-		triggerID: uniqueID,
+		triggerID: triggerID,
 		kind:      subscriptionEventKindAddSubscription,
 		addSubscription: &addSubscription{
 			ctx:       ctx,
@@ -1151,6 +1157,7 @@
 			id:         id,
 			completed:  completed,
 			sourceName: subscription.Trigger.SourceName,
+			headers:    headers,
 		},
 	}:
 	}
@@ -1177,13 +1184,13 @@
 	}
 
 	if r.options.Debug {
-		fmt.Printf("resolver:trigger:unsubscribe:sync:%d:%d\n", uniqueID, id.SubscriptionID)
+		fmt.Printf("resolver:trigger:unsubscribe:sync:%d:%d\n", triggerID, id.SubscriptionID)
 	}
 
 	// Remove the subscription when the client disconnects.
r.events <- subscriptionEvent{ - triggerID: uniqueID, + triggerID: triggerID, kind: subscriptionEventKindRemoveSubscription, id: id, } @@ -1228,14 +1235,7 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G return nil } - _, headersHash := r.triggerHeaders(ctx, subscription.Trigger.SourceName) - - xxh := pool.Hash64.Get() - _, _ = xxh.Write(input) - // the hash for subgraph headers is pre-computed - // we can just add it to the input hash to get a unique id - uniqueID := xxh.Sum64() + headersHash - pool.Hash64.Put(xxh) + headers, triggerID := r.prepareTrigger(ctx, subscription.Trigger.SourceName, input) select { case <-r.ctx.Done(): @@ -1245,7 +1245,7 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G // Stop resolving if the client is gone return ctx.ctx.Err() case r.events <- subscriptionEvent{ - triggerID: uniqueID, + triggerID: triggerID, kind: subscriptionEventKindAddSubscription, addSubscription: &addSubscription{ ctx: ctx, @@ -1255,6 +1255,7 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G id: id, completed: make(chan struct{}), sourceName: subscription.Trigger.SourceName, + headers: headers, }, }: } @@ -1369,6 +1370,7 @@ type addSubscription struct { id SubscriptionIdentifier completed chan struct{} sourceName string + headers http.Header } type subscriptionEventKind int From 6653948325e9f4bf91994f0d743968709771ca24 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sun, 26 Oct 2025 11:56:58 +0100 Subject: [PATCH 34/61] chore: refactor & comments --- v2/pkg/engine/resolve/arena.go | 8 ++++++ v2/pkg/engine/resolve/context.go | 10 ++++++-- v2/pkg/engine/resolve/inputtemplate.go | 2 ++ v2/pkg/engine/resolve/loader.go | 34 ++++++++++++++++++-------- 4 files changed, 42 insertions(+), 12 deletions(-) diff --git a/v2/pkg/engine/resolve/arena.go b/v2/pkg/engine/resolve/arena.go index 0aae88974..cca1f3312 100644 --- a/v2/pkg/engine/resolve/arena.go +++ b/v2/pkg/engine/resolve/arena.go @@ -10,12 +10,20 @@ import ( // ArenaPool provides a thread-safe pool of arena.Arena instances for memory-efficient allocations. // It uses weak pointers to allow garbage collection of unused arenas while maintaining // a pool of reusable arenas for high-frequency allocation patterns. 
+//
+// by storing ArenaPoolItem as weak pointers, the GC can collect them at any time
+// before using an ArenaPoolItem, we try to get a strong pointer while removing it from the pool
+// once we call Release, we return the item to the pool and make it a weak pointer again
+// this means that at any time, GC can claim back the memory if required,
+// allowing GC to automatically manage an appropriate pool size depending on available memory and GC pressure
 type ArenaPool struct {
+	// pool is a slice of weak pointers to the struct holding the arena.Arena
 	pool  []weak.Pointer[ArenaPoolItem]
 	sizes map[uint64]*arenaPoolItemSize
 	mu    sync.Mutex
 }
 
+// arenaPoolItemSize is used to track the required memory across the last 50 arenas in the pool
 type arenaPoolItemSize struct {
 	count      int
 	totalBytes int
 }
diff --git a/v2/pkg/engine/resolve/context.go b/v2/pkg/engine/resolve/context.go
index fdb2ebb58..52f2eb3bb 100644
--- a/v2/pkg/engine/resolve/context.go
+++ b/v2/pkg/engine/resolve/context.go
@@ -55,9 +55,15 @@ func (c *Context) HeadersForSubgraphRequest(subgraphName string) (http.Header, u
 }
 
 type ExecutionOptions struct {
-	SkipLoader                 bool
+	// SkipLoader will, as the name indicates, skip loading data
+	// However, it does indeed resolve a response
+	// This can be useful, e.g. in combination with IncludeQueryPlanInResponse
+	// The purpose is to get a QueryPlan (even for Subscriptions)
+	SkipLoader bool
+	// IncludeQueryPlanInResponse generates a QueryPlan as part of the response in Resolvable
 	IncludeQueryPlanInResponse bool
-	SendHeartbeat              bool
+	// SendHeartbeat sends regular HeartBeats for Subscriptions
+	SendHeartbeat bool
 	// DisableRequestDeduplication disables deduplication of requests to the same subgraph with the same input within a single operation execution.
 	DisableRequestDeduplication bool
 }
diff --git a/v2/pkg/engine/resolve/inputtemplate.go b/v2/pkg/engine/resolve/inputtemplate.go
index 0ad72ec94..e0fc97aa6 100644
--- a/v2/pkg/engine/resolve/inputtemplate.go
+++ b/v2/pkg/engine/resolve/inputtemplate.go
@@ -55,6 +55,8 @@ func SetInputUndefinedVariables(preparedInput InputTemplateWriter, undefinedVari
 // to callers; renderSegments intercepts it and writes literal.NULL instead.
 var errSetTemplateOutputNull = errors.New("set to null")
 
+// InputTemplateWriter is used to decouple Buffer implementations from InputTemplate
+// This way, the implementation can easily be swapped, e.g.
between bytes.Buffer and similar implementations type InputTemplateWriter interface { io.Writer io.StringWriter diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 340c41894..4b51df7e6 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -192,7 +192,9 @@ type Loader struct { // If you're not doing this, you will see segfaults // Example of correct usage in func "mergeResult" jsonArena arena.Arena - sf *SingleFlight + // sf is the SingleFlight object shared across all client requests + // it's thread safe and can be used to de-duplicate subgraph requests + sf *SingleFlight } func (l *Loader) Free() { @@ -302,7 +304,6 @@ func (l *Loader) resolveSingle(item *FetchItem) error { if l.ctx.LoaderHooks != nil { l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.ds, newResponseInfo(res, l.ctx.subgraphErrors)) } - return err case *BatchEntityFetch: res := &result{} @@ -438,7 +439,7 @@ func selectItems(a arena.Arena, items []*astjson.Value, element FetchItemPathEle return selected } -func itemsData(a arena.Arena, items []*astjson.Value) *astjson.Value { +func (l *Loader) itemsData(items []*astjson.Value) *astjson.Value { if len(items) == 0 { return astjson.NullValue } @@ -449,7 +450,7 @@ func itemsData(a arena.Arena, items []*astjson.Value) *astjson.Value { // however, itemsData can be called concurrently, so this might result in a race arr := astjson.MustParseBytes([]byte(`[]`)) for i, item := range items { - arr.SetArrayItem(a, i, item) + arr.SetArrayItem(nil, i, item) } return arr } @@ -553,6 +554,9 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson if len(res.out) == 0 { return l.renderErrorsFailedToFetch(fetchItem, res, emptyGraphQLResponse) } + // before parsing bytes with an arena.Arena, it's important to first allocate the bytes ON the same arena.Arena + // this ties their lifecycles together + // if you don't do this, you'll get segfaults slice := arena.AllocateSlice[byte](l.jsonArena, len(res.out), len(res.out)) copy(slice, res.out) response, err := astjson.ParseBytesWithArena(l.jsonArena, slice) @@ -707,7 +711,7 @@ var ( ) func (l *Loader) renderErrorsInvalidInput(fetchItem *FetchItem) []byte { - out := &bytes.Buffer{} + out := bytes.NewBuffer(nil) elements := fetchItem.ResponsePathElements if len(elements) > 0 && elements[len(elements)-1] == "@" { elements = elements[:len(elements)-1] @@ -1319,7 +1323,7 @@ func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, fetchI res.init(fetch.PostProcessing, fetch.Info) buf := bytes.NewBuffer(nil) - inputData := itemsData(l.jsonArena, items) + inputData := l.itemsData(items) if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} if !l.ctx.TracingOptions.ExcludeRawInputData && inputData != nil { @@ -1358,7 +1362,7 @@ func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, fetchI func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetch *EntityFetch, items []*astjson.Value, res *result) error { res.init(fetch.PostProcessing, fetch.Info) - input := itemsData(l.jsonArena, items) + input := l.itemsData(items) if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} if !l.ctx.TracingOptions.ExcludeRawInputData && input != nil { @@ -1441,17 +1445,22 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} if !l.ctx.TracingOptions.ExcludeRawInputData && len(items) != 0 { - 
data := itemsData(l.jsonArena, items) + data := l.itemsData(items) if data != nil { fetch.Trace.RawInputData, _ = l.compactJSON(data.MarshalTo(nil)) } } } - // I tried using arena here but it only worsened the situation + // I tried using arena here, but it only worsened the situation preparedInput := bytes.NewBuffer(make([]byte, 0, 64)) itemInput := bytes.NewBuffer(make([]byte, 0, 32)) keyGen := pool.Hash64.Get() - defer pool.Hash64.Put(keyGen) + defer func() { + if keyGen == nil { + return + } + pool.Hash64.Put(keyGen) + }() var undefinedVariables []string @@ -1512,6 +1521,11 @@ WithNextItem: } } + // not used anymore + pool.Hash64.Put(keyGen) + // setting to nil so that the defer func doesn't return it twice + keyGen = nil + if len(itemHashes) == 0 { // all items were skipped - discard fetch res.fetchSkipped = true From 6cbfed0eacdd78479892e925c8bddb8ed905ecce Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sun, 26 Oct 2025 11:57:13 +0100 Subject: [PATCH 35/61] chore: remove unused ParallelListItemFetch --- .../create_concrete_single_fetch_types.go | 8 - v2/pkg/engine/resolve/fetch.go | 32 --- v2/pkg/engine/resolve/fetchtree.go | 25 --- v2/pkg/engine/resolve/loader.go | 62 ------ v2/pkg/engine/resolve/loader_hooks_test.go | 63 ------ v2/pkg/engine/resolve/resolve_test.go | 208 ------------------ 6 files changed, 398 deletions(-) diff --git a/v2/pkg/engine/postprocess/create_concrete_single_fetch_types.go b/v2/pkg/engine/postprocess/create_concrete_single_fetch_types.go index f5d0b2ae2..44b3225fb 100644 --- a/v2/pkg/engine/postprocess/create_concrete_single_fetch_types.go +++ b/v2/pkg/engine/postprocess/create_concrete_single_fetch_types.go @@ -51,19 +51,11 @@ func (d *createConcreteSingleFetchTypes) traverseSingleFetch(fetch *resolve.Sing return d.createEntityBatchFetch(fetch) case fetch.RequiresEntityFetch: return d.createEntityFetch(fetch) - case fetch.RequiresParallelListItemFetch: - return d.createParallelListItemFetch(fetch) default: return fetch } } -func (d *createConcreteSingleFetchTypes) createParallelListItemFetch(fetch *resolve.SingleFetch) resolve.Fetch { - return &resolve.ParallelListItemFetch{ - Fetch: fetch, - } -} - func (d *createConcreteSingleFetchTypes) createEntityBatchFetch(fetch *resolve.SingleFetch) resolve.Fetch { representationsVariableIndex := -1 for i, segment := range fetch.InputTemplate.Segments { diff --git a/v2/pkg/engine/resolve/fetch.go b/v2/pkg/engine/resolve/fetch.go index deeea25a4..622e731c4 100644 --- a/v2/pkg/engine/resolve/fetch.go +++ b/v2/pkg/engine/resolve/fetch.go @@ -12,7 +12,6 @@ type FetchKind int const ( FetchKindSingle FetchKind = iota + 1 - FetchKindParallelListItem FetchKindEntity FetchKindEntityBatch ) @@ -227,27 +226,6 @@ func (*EntityFetch) FetchKind() FetchKind { return FetchKindEntity } -// The ParallelListItemFetch can be used to make nested parallel fetches within a list -// Usually, you want to batch fetches within a list, which is the default behavior of SingleFetch -// However, if the data source does not support batching, you can use this fetch to make parallel fetches within a list -type ParallelListItemFetch struct { - Fetch *SingleFetch - Traces []*SingleFetch - Trace *DataSourceLoadTrace -} - -func (p *ParallelListItemFetch) Dependencies() *FetchDependencies { - return &p.Fetch.FetchDependencies -} - -func (p *ParallelListItemFetch) FetchInfo() *FetchInfo { - return p.Fetch.Info -} - -func (*ParallelListItemFetch) FetchKind() FetchKind { - return FetchKindParallelListItem -} - type QueryPlan struct { DependsOnFields 
[]Representation Query string @@ -272,12 +250,6 @@ type FetchConfiguration struct { Variables Variables DataSource DataSource - // RequiresParallelListItemFetch indicates that the single fetches should be executed without batching. - // If we have multiple fetches attached to the object, then after post-processing of a plan - // we will get ParallelListItemFetch instead of ParallelFetch. - // Happens only for objects under the array path and used only for the introspection. - RequiresParallelListItemFetch bool - // RequiresEntityFetch will be set to true if the fetch is an entity fetch on an object. // After post-processing, we will get EntityFetch. RequiresEntityFetch bool @@ -313,9 +285,6 @@ func (fc *FetchConfiguration) Equals(other *FetchConfiguration) bool { // Note: we do not compare datasources, as they will always be a different instance. - if fc.RequiresParallelListItemFetch != other.RequiresParallelListItemFetch { - return false - } if fc.RequiresEntityFetch != other.RequiresEntityFetch { return false } @@ -505,5 +474,4 @@ var ( _ Fetch = (*SingleFetch)(nil) _ Fetch = (*BatchEntityFetch)(nil) _ Fetch = (*EntityFetch)(nil) - _ Fetch = (*ParallelListItemFetch)(nil) ) diff --git a/v2/pkg/engine/resolve/fetchtree.go b/v2/pkg/engine/resolve/fetchtree.go index f4fd987ce..9bc38497c 100644 --- a/v2/pkg/engine/resolve/fetchtree.go +++ b/v2/pkg/engine/resolve/fetchtree.go @@ -130,17 +130,6 @@ func (n *FetchTreeNode) Trace() *FetchTreeTraceNode { Trace: f.Trace, Path: n.Item.ResponsePath, } - case *ParallelListItemFetch: - trace.Fetch = &FetchTraceNode{ - Kind: "ParallelList", - SourceID: f.Fetch.Info.DataSourceID, - SourceName: f.Fetch.Info.DataSourceName, - Traces: make([]*DataSourceLoadTrace, len(f.Traces)), - Path: n.Item.ResponsePath, - } - for i, t := range f.Traces { - trace.Fetch.Traces[i] = t.Trace - } default: } case FetchTreeNodeKindSequence, FetchTreeNodeKindParallel: @@ -253,20 +242,6 @@ func (n *FetchTreeNode) queryPlan() *FetchTreeQueryPlanNode { queryPlan.Fetch.Query = f.Info.QueryPlan.Query queryPlan.Fetch.Representations = f.Info.QueryPlan.DependsOnFields } - case *ParallelListItemFetch: - queryPlan.Fetch = &FetchTreeQueryPlan{ - Kind: "ParallelList", - FetchID: f.Fetch.FetchDependencies.FetchID, - DependsOnFetchIDs: f.Fetch.FetchDependencies.DependsOnFetchIDs, - SubgraphName: f.Fetch.Info.DataSourceName, - SubgraphID: f.Fetch.Info.DataSourceID, - Path: n.Item.ResponsePath, - } - - if f.Fetch.Info.QueryPlan != nil { - queryPlan.Fetch.Query = f.Fetch.Info.QueryPlan.Query - queryPlan.Fetch.Representations = f.Fetch.Info.QueryPlan.DependsOnFields - } default: } case FetchTreeNodeKindSequence, FetchTreeNodeKindParallel: diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 4b51df7e6..cff02a488 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -327,41 +327,6 @@ func (l *Loader) resolveSingle(item *FetchItem) error { l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.ds, newResponseInfo(res, l.ctx.subgraphErrors)) } return err - case *ParallelListItemFetch: - results := make([]*result, len(items)) - if l.ctx.TracingOptions.Enable { - f.Traces = make([]*SingleFetch, len(items)) - } - g, ctx := errgroup.WithContext(l.ctx.ctx) - for i := range items { - i := i - results[i] = &result{} - if l.ctx.TracingOptions.Enable { - f.Traces[i] = new(SingleFetch) - *f.Traces[i] = *f.Fetch - g.Go(func() error { - return l.loadFetch(ctx, f.Traces[i], item, items[i:i+1], results[i]) - }) - continue - } - g.Go(func() error { - 
return l.loadFetch(ctx, f.Fetch, item, items[i:i+1], results[i]) - }) - } - err := g.Wait() - if err != nil { - return errors.WithStack(err) - } - for i := range results { - err = l.mergeResult(item, results[i], items[i:i+1]) - if l.ctx.LoaderHooks != nil { - l.ctx.LoaderHooks.OnFinished(results[i].loaderHookContext, results[i].ds, newResponseInfo(results[i], l.ctx.subgraphErrors)) - } - if err != nil { - return errors.WithStack(err) - } - } - return nil default: return nil } @@ -459,33 +424,6 @@ func (l *Loader) loadFetch(ctx context.Context, fetch Fetch, fetchItem *FetchIte switch f := fetch.(type) { case *SingleFetch: return l.loadSingleFetch(ctx, f, fetchItem, items, res) - case *ParallelListItemFetch: - results := make([]*result, len(items)) - if l.ctx.TracingOptions.Enable { - f.Traces = make([]*SingleFetch, len(items)) - } - g, ctx := errgroup.WithContext(l.ctx.ctx) - for i := range items { - i := i - results[i] = &result{} - if l.ctx.TracingOptions.Enable { - f.Traces[i] = new(SingleFetch) - *f.Traces[i] = *f.Fetch - g.Go(func() error { - return l.loadFetch(ctx, f.Traces[i], fetchItem, items[i:i+1], results[i]) - }) - continue - } - g.Go(func() error { - return l.loadFetch(ctx, f.Fetch, fetchItem, items[i:i+1], results[i]) - }) - } - err := g.Wait() - if err != nil { - return errors.WithStack(err) - } - res.nestedMergeItems = results - return nil case *EntityFetch: return l.loadEntityFetch(ctx, fetchItem, f, items, res) case *BatchEntityFetch: diff --git a/v2/pkg/engine/resolve/loader_hooks_test.go b/v2/pkg/engine/resolve/loader_hooks_test.go index ebe263dcd..4a2ce9cb2 100644 --- a/v2/pkg/engine/resolve/loader_hooks_test.go +++ b/v2/pkg/engine/resolve/loader_hooks_test.go @@ -248,69 +248,6 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { } })) - t.Run("parallel list item fetch with simple subgraph error", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { - mockDataSource := NewMockDataSource(ctrl) - mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.Any()). 
- DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { - return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil - }) - resolveCtx := Context{ - ctx: context.Background(), - LoaderHooks: NewTestLoaderHooks(), - } - return &GraphQLResponse{ - Info: &GraphQLResponseInfo{ - OperationType: ast.OperationTypeQuery, - }, - Fetches: SingleWithPath(&ParallelListItemFetch{ - Fetch: &SingleFetch{ - FetchConfiguration: FetchConfiguration{ - DataSource: mockDataSource, - PostProcessing: PostProcessingConfiguration{ - SelectResponseErrorsPath: []string{"errors"}, - }, - }, - Info: &FetchInfo{ - DataSourceID: "Users", - DataSourceName: "Users", - }, - }, - }, "query"), - Data: &Object{ - Nullable: false, - Fields: []*Field{ - { - Name: []byte("name"), - Value: &String{ - Path: []string{"name"}, - Nullable: true, - }, - }, - }, - }, - }, &resolveCtx, `{"errors":[{"message":"Failed to fetch from Subgraph 'Users' at Path 'query'.","extensions":{"errors":[{"message":"errorMessage"}]}}],"data":{"name":null}}`, - func(t *testing.T) { - loaderHooks := resolveCtx.LoaderHooks.(*TestLoaderHooks) - - assert.Equal(t, int64(1), loaderHooks.preFetchCalls.Load()) - assert.Equal(t, int64(1), loaderHooks.postFetchCalls.Load()) - - var subgraphError *SubgraphError - assert.Len(t, loaderHooks.errors, 1) - assert.ErrorAs(t, loaderHooks.errors[0], &subgraphError) - assert.Equal(t, "Users", subgraphError.DataSourceInfo.Name) - assert.Equal(t, "query", subgraphError.Path) - assert.Equal(t, "", subgraphError.Reason) - assert.Equal(t, 0, subgraphError.ResponseCode) - assert.Len(t, subgraphError.DownstreamErrors, 1) - assert.Equal(t, "errorMessage", subgraphError.DownstreamErrors[0].Message) - assert.Nil(t, subgraphError.DownstreamErrors[0].Extensions) - - assert.NotNil(t, resolveCtx.SubgraphErrors()) - } - })) - t.Run("fetch with subgraph error and custom extension code. No extension fields are propagated by default", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 5c2ea4ed6..112776037 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -2793,214 +2793,6 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, Context{ctx: context.Background(), Variables: astjson.MustParseBytes([]byte(`{"firstArg":"firstArgValue","thirdArg":123,"secondArg": true, "fourthArg": 12.34}`))}, `{"data":{"serviceOne":{"fieldOne":"fieldOneValue"},"serviceTwo":{"fieldTwo":"fieldTwoValue","serviceOneResponse":{"fieldOne":"fieldOneValue"}},"anotherServiceOne":{"fieldOne":"anotherFieldOneValue"},"secondServiceTwo":{"fieldTwo":"secondFieldTwoValue"},"reusingServiceOne":{"fieldOne":"reUsingFieldOneValue"}}}` })) t.Run("federation", func(t *testing.T) { - t.Run("simple", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { - - userService := NewMockDataSource(ctrl) - userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.Any()). 
- DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { - actual := string(input) - expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` - assert.Equal(t, expected, actual) - return []byte(`{"data":{"me":{"id":"1234","username":"Me","__typename":"User"}}}`), nil - }) - - reviewsService := NewMockDataSource(ctrl) - reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { - actual := string(input) - expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` - assert.Equal(t, expected, actual) - return []byte(`{"data":{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}}`), nil - }) - - var productServiceCallCount atomic.Int64 - - productService := NewMockDataSource(ctrl) - productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { - actual := string(input) - productServiceCallCount.Add(1) - switch actual { - case `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"}]}}}`: - return []byte(`{"data":{"_entities":[{"name": "Furby"}]}}`), nil - case `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-2","__typename":"Product"}]}}}`: - return []byte(`{"data":{"_entities":[{"name": "Trilby"}]}}`), nil - default: - t.Fatalf("unexpected request: %s", actual) - } - return nil, nil - }).Times(2) - - return &GraphQLResponse{ - Fetches: Sequence( - SingleWithPath(&SingleFetch{ - InputTemplate: InputTemplate{ - Segments: []TemplateSegment{ - { - Data: []byte(`{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}`), - SegmentType: StaticSegmentType, - }, - }, - }, - FetchConfiguration: FetchConfiguration{ - DataSource: userService, - PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data"}, - }, - }, - }, "query"), - SingleWithPath(&SingleFetch{ - InputTemplate: InputTemplate{ - Segments: []TemplateSegment{ - { - Data: []byte(`{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[`), - SegmentType: StaticSegmentType, - }, - { - SegmentType: VariableSegmentType, - VariableKind: ResolvableObjectVariableKind, - Renderer: NewGraphQLVariableResolveRenderer(&Object{ - Fields: []*Field{ - { - Name: []byte("id"), - Value: &String{ - Path: []string{"id"}, - }, - }, - { - Name: []byte("__typename"), - Value: &String{ - Path: []string{"__typename"}, - }, - }, - }, - }), - }, - { - Data: []byte(`]}}}`), - SegmentType: StaticSegmentType, - }, - }, - }, - FetchConfiguration: FetchConfiguration{ - DataSource: reviewsService, - PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data", "_entities", "0"}, - }, - }, - }, "query.me", ObjectPath("me")), - SingleWithPath(&ParallelListItemFetch{ - Fetch: &SingleFetch{ - FetchConfiguration: FetchConfiguration{ - DataSource: productService, - PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data", "_entities", "0"}, - }, - }, - InputTemplate: InputTemplate{ - Segments: []TemplateSegment{ - { - Data: []byte(`{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[`), - SegmentType: StaticSegmentType, - }, - { - SegmentType: VariableSegmentType, - VariableKind: ResolvableObjectVariableKind, - Renderer: NewGraphQLVariableResolveRenderer(&Object{ - Fields: []*Field{ - { - Name: []byte("upc"), - Value: &String{ - Path: []string{"upc"}, - }, - }, - { - Name: []byte("__typename"), - Value: &String{ - Path: []string{"__typename"}, - }, - }, - }, - }), - }, - { - Data: []byte(`]}}}`), - SegmentType: StaticSegmentType, - }, - }, - }, - }, - }, "query.me.reviews.@.product", ObjectPath("me"), ArrayPath("reviews"), ObjectPath("product")), - ), - Data: &Object{ - Fields: []*Field{ - { - Name: []byte("me"), - Value: &Object{ - Path: []string{"me"}, - Nullable: true, - Fields: []*Field{ - { - Name: []byte("id"), - Value: &String{ - Path: []string{"id"}, - }, - }, - { - Name: []byte("username"), - Value: &String{ - Path: []string{"username"}, - }, - }, - { - - Name: []byte("reviews"), - Value: &Array{ - Path: []string{"reviews"}, - Nullable: true, - Item: &Object{ - Nullable: true, - Fields: []*Field{ - { - Name: []byte("body"), - Value: &String{ - Path: []string{"body"}, - }, - }, - { - Name: []byte("product"), - Value: &Object{ - Path: []string{"product"}, - Fields: []*Field{ - { - Name: []byte("upc"), - Value: &String{ - Path: []string{"upc"}, - }, - }, - { - Name: []byte("name"), - Value: &String{ - Path: []string{"name"}, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, Context{ctx: context.Background(), Variables: nil}, `{"data":{"me":{"id":"1234","username":"Me","reviews":[{"body":"A highly effective form of birth control.","product":{"upc":"top-1","name":"Furby"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","name":"Trilby"}}]}}}` - })) t.Run("federation with batch", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). 
From daa18e84c305e3d99d62749706034b27d38c1aad Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 27 Oct 2025 07:56:36 +0100 Subject: [PATCH 36/61] chore: simplify batchStats logic --- v2/pkg/engine/resolve/loader.go | 106 +++++++++++++++----------------- 1 file changed, 50 insertions(+), 56 deletions(-) diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index cff02a488..1f6337560 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -91,34 +91,32 @@ func newResponseInfo(res *result, subgraphError error) *ResponseInfo { return responseInfo } -// batchStats represents an index map for batched items. -// It is used to ensure that the correct json values will be merged with the correct items from the batch. +// batchStats represents per-unique-batch-item merge targets. +// Outer slice index corresponds to the unique representation index in the request batch, +// and the inner slice contains all target values that should be merged with the response at that index. // // Example: -// [[0],[1],[0],[1]] We originally have 4 items, but we have 2 unique indexes (0 and 1). -// This means we are deduplicating 2 items by merging them from their response entity indexes. -// 0 -> 0, 1 -> 1, 2 -> 0, 3 -> 1 -type batchStats [][]int - -// getUniqueIndexes returns the number of unique indexes in the batchStats. -// This is used to ensure that we can provide a valid error message in case of differing array lengths. -func (b *batchStats) getUniqueIndexes() int { - uniqueIndexes := make(map[int]struct{}) - for _, bi := range *b { - for _, index := range bi { - if index < 0 { - continue - } - uniqueIndexes[index] = struct{}{} - } - } +// For 4 original items that deduplicate to 2 unique representations, we might have: +// [ +// +// [item0, item2], // merge response[0] into item0 and item2 +// [item1, item3], // merge response[1] into item1 and item3 +// +// ] +type batchStats [][]*astjson.Value - return len(uniqueIndexes) +// expectedNumberOfBatchItems returns the number of unique indexes in the batchStats. +// With the new structure, this equals the outer slice length. +func (b *batchStats) expectedNumberOfBatchItems() int { + return len(*b) } type result struct { - postProcessing PostProcessingConfiguration - batchStats batchStats + postProcessing PostProcessingConfiguration + batchStats batchStats + // batchHashToIndex maps a request item hash to its unique batch index. + // Used during request construction and to avoid recomputing uniqueness. + batchHashToIndex map[uint64]int fetchSkipped bool nestedMergeItems []*result @@ -597,26 +595,24 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson } if res.batchStats != nil { - uniqueIndexes := res.batchStats.getUniqueIndexes() - if uniqueIndexes != len(batch) { - return l.renderErrorsFailedToFetch(fetchItem, res, fmt.Sprintf(invalidBatchItemCount, uniqueIndexes, len(batch))) + expectedBatchItems := res.batchStats.expectedNumberOfBatchItems() + if expectedBatchItems != len(batch) { + return l.renderErrorsFailedToFetch(fetchItem, res, fmt.Sprintf(invalidBatchItemCount, expectedBatchItems, len(batch))) } - for i, stats := range res.batchStats { - for _, idx := range stats { - if idx == -1 { - continue - } - items[i], _, err = astjson.MergeValuesWithPath(l.jsonArena, items[i], batch[idx], res.postProcessing.MergePath...) 
- if err != nil { + for batchIndex, targets := range res.batchStats { + src := batch[batchIndex] + for _, target := range targets { + _, _, mErr := astjson.MergeValuesWithPath(l.jsonArena, target, src, res.postProcessing.MergePath...) + if mErr != nil { return errors.WithStack(ErrMergeResult{ Subgraph: res.ds.Name, - Reason: err, + Reason: mErr, Path: fetchItem.ResponsePath, }) } - if slices.Contains(taintedIndices, idx) { - l.taintedObjs.add(items[i]) + if slices.Contains(taintedIndices, batchIndex) { + l.taintedObjs.add(target) } } } @@ -1406,8 +1402,8 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, if err != nil { return errors.WithStack(err) } - res.batchStats = make(batchStats, len(items)) - itemHashes := make([]uint64, 0, len(items)) + res.batchStats = make(batchStats, 0, len(items)) + res.batchHashToIndex = make(map[uint64]int, len(items)) batchItemIndex := 0 addSeparator := false @@ -1419,7 +1415,6 @@ WithNextItem: if err != nil { if fetch.Input.SkipErrItems { err = nil // nolint:ineffassign - res.batchStats[i] = append(res.batchStats[i], -1) continue } if l.ctx.TracingOptions.Enable { @@ -1428,34 +1423,33 @@ WithNextItem: return errors.WithStack(err) } if fetch.Input.SkipNullItems && itemInput.Len() == 4 && bytes.Equal(itemInput.Bytes(), null) { - res.batchStats[i] = append(res.batchStats[i], -1) continue } if fetch.Input.SkipEmptyObjectItems && itemInput.Len() == 2 && bytes.Equal(itemInput.Bytes(), emptyObject) { - res.batchStats[i] = append(res.batchStats[i], -1) continue } keyGen.Reset() _, _ = keyGen.Write(itemInput.Bytes()) itemHash := keyGen.Sum64() - for k := range itemHashes { - if itemHashes[k] == itemHash { - res.batchStats[i] = append(res.batchStats[i], k) - continue WithNextItem - } - } - itemHashes = append(itemHashes, itemHash) - if addSeparator { - err = fetch.Input.Separator.Render(l.ctx, nil, preparedInput) - if err != nil { - return errors.WithStack(err) + if existingIndex, ok := res.batchHashToIndex[itemHash]; ok { + res.batchStats[existingIndex] = append(res.batchStats[existingIndex], items[i]) + continue WithNextItem + } else { + if addSeparator { + err = fetch.Input.Separator.Render(l.ctx, nil, preparedInput) + if err != nil { + return errors.WithStack(err) + } } + _, _ = itemInput.WriteTo(preparedInput) + // new unique representation + res.batchHashToIndex[itemHash] = batchItemIndex + // create a new targets bucket for this unique index + res.batchStats = append(res.batchStats, []*astjson.Value{items[i]}) + batchItemIndex++ + addSeparator = true } - _, _ = itemInput.WriteTo(preparedInput) - res.batchStats[i] = append(res.batchStats[i], batchItemIndex) - batchItemIndex++ - addSeparator = true } } @@ -1464,7 +1458,7 @@ WithNextItem: // setting to nil so that the defer func doesn't return it twice keyGen = nil - if len(itemHashes) == 0 { + if len(res.batchStats) == 0 { // all items were skipped - discard fetch res.fetchSkipped = true if l.ctx.TracingOptions.Enable { From 2003186c30fa9680eb6900f6b3e6662146631149 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 27 Oct 2025 08:24:33 +0100 Subject: [PATCH 37/61] chore: simplify --- v2/pkg/engine/resolve/loader.go | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 1f6337560..971dc4a16 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -1389,12 +1389,6 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, 
preparedInput := bytes.NewBuffer(make([]byte, 0, 64)) itemInput := bytes.NewBuffer(make([]byte, 0, 32)) keyGen := pool.Hash64.Get() - defer func() { - if keyGen == nil { - return - } - pool.Hash64.Put(keyGen) - }() var undefinedVariables []string @@ -1420,6 +1414,7 @@ WithNextItem: if l.ctx.TracingOptions.Enable { fetch.Trace.LoadSkipped = true } + pool.Hash64.Put(keyGen) return errors.WithStack(err) } if fetch.Input.SkipNullItems && itemInput.Len() == 4 && bytes.Equal(itemInput.Bytes(), null) { @@ -1439,6 +1434,7 @@ WithNextItem: if addSeparator { err = fetch.Input.Separator.Render(l.ctx, nil, preparedInput) if err != nil { + pool.Hash64.Put(keyGen) return errors.WithStack(err) } } @@ -1453,10 +1449,7 @@ WithNextItem: } } - // not used anymore pool.Hash64.Put(keyGen) - // setting to nil so that the defer func doesn't return it twice - keyGen = nil if len(res.batchStats) == 0 { // all items were skipped - discard fetch From 0c0e1ce22ae21f98941d54d4d41deed7948cd3a4 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 27 Oct 2025 11:37:24 +0100 Subject: [PATCH 38/61] chore: add tools pool for loadBatchEntityFetch --- v2/pkg/engine/resolve/loader.go | 137 +++++++++++++++++++++----------- 1 file changed, 91 insertions(+), 46 deletions(-) diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 971dc4a16..a4893ef73 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -11,9 +11,11 @@ import ( "slices" "strconv" "strings" + "sync" "time" "github.com/buger/jsonparser" + "github.com/cespare/xxhash/v2" "github.com/pkg/errors" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -26,7 +28,6 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/errorcodes" "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/unsafebytes" - "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) const ( @@ -91,32 +92,21 @@ func newResponseInfo(res *result, subgraphError error) *ResponseInfo { return responseInfo } -// batchStats represents per-unique-batch-item merge targets. -// Outer slice index corresponds to the unique representation index in the request batch, -// and the inner slice contains all target values that should be merged with the response at that index. -// -// Example: -// For 4 original items that deduplicate to 2 unique representations, we might have: -// [ -// -// [item0, item2], // merge response[0] into item0 and item2 -// [item1, item3], // merge response[1] into item1 and item3 -// -// ] -type batchStats [][]*astjson.Value - -// expectedNumberOfBatchItems returns the number of unique indexes in the batchStats. -// With the new structure, this equals the outer slice length. -func (b *batchStats) expectedNumberOfBatchItems() int { - return len(*b) -} - type result struct { postProcessing PostProcessingConfiguration - batchStats batchStats - // batchHashToIndex maps a request item hash to its unique batch index. - // Used during request construction and to avoid recomputing uniqueness. - batchHashToIndex map[uint64]int + // batchStats represents per-unique-batch-item merge targets. + // Outer slice index corresponds to the unique representation index in the request batch, + // and the inner slice contains all target values that should be merged with the response at that index. 
+ // + // Example: + // For 4 original items that deduplicate to 2 unique representations, we might have: + // [ + // + // [item0, item2], // merge response[0] into item0 and item2 + // [item1, item3], // merge response[1] into item1 and item3 + // + // ] + batchStats [][]*astjson.Value fetchSkipped bool nestedMergeItems []*result @@ -138,6 +128,7 @@ type result struct { // out is the subgraph response body out []byte singleFlightStats *singleFlightStats + tools *batchEntityTools } func (r *result) init(postProcessing PostProcessingConfiguration, info *FetchInfo) { @@ -231,6 +222,12 @@ func (l *Loader) resolveParallel(nodes []*FetchTreeNode) error { return nil } results := make([]*result, len(nodes)) + defer func() { + for i := range results { + // no-op if tools == nil + batchEntityToolPool.Put(results[i].tools) + } + }() itemsItems := make([][]*astjson.Value, len(nodes)) g, ctx := errgroup.WithContext(l.ctx.ctx) for i := range nodes { @@ -305,6 +302,7 @@ func (l *Loader) resolveSingle(item *FetchItem) error { return err case *BatchEntityFetch: res := &result{} + defer batchEntityToolPool.Put(res.tools) err := l.loadBatchEntityFetch(l.ctx.ctx, item, f, items, res) if err != nil { return errors.WithStack(err) @@ -595,9 +593,8 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson } if res.batchStats != nil { - expectedBatchItems := res.batchStats.expectedNumberOfBatchItems() - if expectedBatchItems != len(batch) { - return l.renderErrorsFailedToFetch(fetchItem, res, fmt.Sprintf(invalidBatchItemCount, expectedBatchItems, len(batch))) + if len(res.batchStats) != len(batch) { + return l.renderErrorsFailedToFetch(fetchItem, res, fmt.Sprintf(invalidBatchItemCount, len(res.batchStats), len(batch))) } for batchIndex, targets := range res.batchStats { @@ -1373,6 +1370,48 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetc return nil } +type batchEntityTools struct { + keyGen *xxhash.Digest + batchHashToIndex map[uint64]int + a arena.Arena +} + +func (b *batchEntityTools) reset() { + b.keyGen.Reset() + b.a.Reset() + for i := range b.batchHashToIndex { + delete(b.batchHashToIndex, i) + } +} + +type _batchEntityToolPool struct { + pool sync.Pool +} + +func (p *_batchEntityToolPool) Get(items int) *batchEntityTools { + item := p.pool.Get() + if item == nil { + return &batchEntityTools{ + keyGen: xxhash.New(), + batchHashToIndex: make(map[uint64]int, items), + a: arena.NewMonotonicArena(arena.WithMinBufferSize(1024)), + } + } + return item.(*batchEntityTools) +} + +func (p *_batchEntityToolPool) Put(item *batchEntityTools) { + if item == nil { + return + } + item.reset() + p.pool.Put(item) +} + +var ( + batchEntityToolPool = _batchEntityToolPool{} +) + func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, fetch *BatchEntityFetch, items []*astjson.Value, res *result) error { res.init(fetch.PostProcessing, fetch.Info) @@ -1385,19 +1424,19 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, } } } - // I tried using arena here, but it only worsened the situation - preparedInput := bytes.NewBuffer(make([]byte, 0, 64)) - itemInput := bytes.NewBuffer(make([]byte, 0, 32)) - keyGen := pool.Hash64.Get() + res.tools = batchEntityToolPool.Get(len(items)) + preparedInput := arena.NewArenaBuffer(res.tools.a) + itemInput := arena.NewArenaBuffer(res.tools.a) + batchStats := arena.AllocateSlice[[]*astjson.Value](res.tools.a, 0, len(items)) + + // I tried using arena here, but it only worsened the 
situation var undefinedVariables []string err := fetch.Input.Header.RenderAndCollectUndefinedVariables(l.ctx, nil, preparedInput, &undefinedVariables) if err != nil { return errors.WithStack(err) } - res.batchStats = make(batchStats, 0, len(items)) - res.batchHashToIndex = make(map[uint64]int, len(items)) batchItemIndex := 0 addSeparator := false @@ -1414,7 +1453,6 @@ WithNextItem: if l.ctx.TracingOptions.Enable { fetch.Trace.LoadSkipped = true } - pool.Hash64.Put(keyGen) return errors.WithStack(err) } if fetch.Input.SkipNullItems && itemInput.Len() == 4 && bytes.Equal(itemInput.Bytes(), null) { @@ -1424,34 +1462,31 @@ WithNextItem: continue } - keyGen.Reset() - _, _ = keyGen.Write(itemInput.Bytes()) - itemHash := keyGen.Sum64() - if existingIndex, ok := res.batchHashToIndex[itemHash]; ok { - res.batchStats[existingIndex] = append(res.batchStats[existingIndex], items[i]) + res.tools.keyGen.Reset() + _, _ = res.tools.keyGen.Write(itemInput.Bytes()) + itemHash := res.tools.keyGen.Sum64() + if existingIndex, ok := res.tools.batchHashToIndex[itemHash]; ok { + batchStats[existingIndex] = arena.SliceAppend(res.tools.a, batchStats[existingIndex], items[i]) continue WithNextItem } else { if addSeparator { err = fetch.Input.Separator.Render(l.ctx, nil, preparedInput) if err != nil { - pool.Hash64.Put(keyGen) return errors.WithStack(err) } } _, _ = itemInput.WriteTo(preparedInput) // new unique representation - res.batchHashToIndex[itemHash] = batchItemIndex + res.tools.batchHashToIndex[itemHash] = batchItemIndex // create a new targets bucket for this unique index - res.batchStats = append(res.batchStats, []*astjson.Value{items[i]}) + batchStats = arena.SliceAppend(res.tools.a, batchStats, []*astjson.Value{items[i]}) batchItemIndex++ addSeparator = true } } } - pool.Hash64.Put(keyGen) - - if len(res.batchStats) == 0 { + if len(batchStats) == 0 { // all items were skipped - discard fetch res.fetchSkipped = true if l.ctx.TracingOptions.Enable { @@ -1470,7 +1505,16 @@ WithNextItem: if err != nil { return errors.WithStack(err) } + fetchInput := preparedInput.Bytes() + // it's important to copy the *astjson.Value's off the arena to avoid memory corruption + res.batchStats = make([][]*astjson.Value, len(batchStats)) + for i := range batchStats { + res.batchStats[i] = make([]*astjson.Value, len(batchStats[i])) + copy(res.batchStats[i], batchStats[i]) + batchStats[i] = nil + } + batchStats = nil if l.ctx.TracingOptions.Enable && res.fetchSkipped { l.setTracingInput(fetchItem, fetchInput, fetch.Trace) @@ -1484,6 +1528,7 @@ WithNextItem: if !allowed { return nil } + l.executeSourceLoad(ctx, fetchItem, fetch.DataSource, fetchInput, res, fetch.Trace) return nil } From 8e3d0df3ed11e4c8a2799f2c16c1759ba160fd0f Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 27 Oct 2025 13:31:29 +0100 Subject: [PATCH 39/61] chore: improved cleanup --- v2/pkg/engine/resolve/loader.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index a4893ef73..e4bd36d81 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -1429,6 +1429,16 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, preparedInput := arena.NewArenaBuffer(res.tools.a) itemInput := arena.NewArenaBuffer(res.tools.a) batchStats := arena.AllocateSlice[[]*astjson.Value](res.tools.a, 0, len(items)) + defer func() { + // we need to clear the batchStats slice to avoid memory corruption + // once the outer func 
returns, we must not keep pointers to items on the arena + for i := range batchStats { + // nolint:ineffassign + batchStats[i] = nil + } + // nolint:ineffassign + batchStats = nil + }() // I tried using arena here, but it only worsened the situation var undefinedVariables []string @@ -1512,9 +1522,7 @@ WithNextItem: for i := range batchStats { res.batchStats[i] = make([]*astjson.Value, len(batchStats[i])) copy(res.batchStats[i], batchStats[i]) - batchStats[i] = nil } - batchStats = nil if l.ctx.TracingOptions.Enable && res.fetchSkipped { l.setTracingInput(fetchItem, fetchInput, fetch.Trace) From f3f2a8ef3dea9f5d59522ac3ec530bf82bac312e Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 27 Oct 2025 19:28:46 +0100 Subject: [PATCH 40/61] chore: refactor, docs, inbound sf --- v2/pkg/engine/resolve/const.go | 2 + v2/pkg/engine/resolve/context.go | 3 + .../resolve/inbound_request_singleflight.go | 138 ++++++++++++++++++ v2/pkg/engine/resolve/loader.go | 4 +- v2/pkg/engine/resolve/resolve.go | 37 +++-- ...ht.go => subgraph_request_singleflight.go} | 92 ++++++++---- 6 files changed, 232 insertions(+), 44 deletions(-) create mode 100644 v2/pkg/engine/resolve/inbound_request_singleflight.go rename v2/pkg/engine/resolve/{singleflight.go => subgraph_request_singleflight.go} (61%) diff --git a/v2/pkg/engine/resolve/const.go b/v2/pkg/engine/resolve/const.go index 8a259494e..2958fe1f5 100644 --- a/v2/pkg/engine/resolve/const.go +++ b/v2/pkg/engine/resolve/const.go @@ -8,6 +8,8 @@ var ( lBrack = []byte("[") rBrack = []byte("]") comma = []byte(",") + pipe = []byte("|") + dot = []byte(".") colon = []byte(":") quote = []byte("\"") null = []byte("null") diff --git a/v2/pkg/engine/resolve/context.go b/v2/pkg/engine/resolve/context.go index 52f2eb3bb..d6a8657e4 100644 --- a/v2/pkg/engine/resolve/context.go +++ b/v2/pkg/engine/resolve/context.go @@ -16,6 +16,7 @@ import ( type Context struct { ctx context.Context Variables *astjson.Value + VariablesHash uint64 Files []*httpclient.FileUpload Request Request RenameTypeNames []RenameTypeName @@ -44,6 +45,8 @@ type SubgraphHeadersBuilder interface { // HeadersForSubgraph must return the headers and a hash for a Subgraph Request // The hash will be used for request deduplication HeadersForSubgraph(subgraphName string) (http.Header, uint64) + // HashAll must return the hash for all subgraph requests combined + HashAll() uint64 } // HeadersForSubgraphRequest returns headers and a hash for a request that the engine will make to a subgraph diff --git a/v2/pkg/engine/resolve/inbound_request_singleflight.go b/v2/pkg/engine/resolve/inbound_request_singleflight.go new file mode 100644 index 000000000..995ee390c --- /dev/null +++ b/v2/pkg/engine/resolve/inbound_request_singleflight.go @@ -0,0 +1,138 @@ +package resolve + +import ( + "sync" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" +) + +// InboundRequestSingleFlight is a sharded goroutine safe single flight implementation to de-couple inbound requests +// It's taking into consideration the normalized operation hash, variables hash and headers hash +// making it robust against collisions +// for scalability, you can add more shards in case the mutexes are a bottleneck +type InboundRequestSingleFlight struct { + shards []requestShard +} + +type requestShard struct { + mu sync.Mutex + m map[uint64]*InflightRequest +} + +const defaultRequestSingleFlightShardCount = 4 + +// NewRequestSingleFlight creates a InboundRequestSingleFlight with the provided +// number of shards. 
If shardCount <= 0, the default of 4 is used. +func NewRequestSingleFlight(shardCount int) *InboundRequestSingleFlight { + if shardCount <= 0 { + shardCount = defaultRequestSingleFlightShardCount + } + r := &InboundRequestSingleFlight{ + shards: make([]requestShard, shardCount), + } + for i := range r.shards { + r.shards[i] = requestShard{ + m: make(map[uint64]*InflightRequest), + } + } + return r +} + +type InflightRequest struct { + Done chan struct{} + Data []byte + Err error + ID uint64 + HasFollowers bool +} + +// GetOrCreate creates a new InflightRequest or returns an existing (shared) one +// The first caller to create an InflightRequest for a given key is a leader, everyone else a follower +// GetOrCreate blocks until ctx.ctx.Done() returns or InflightRequest.Done is closed +// It returns an error if the leader returned an error +// It returns nil,nil if the inbound request is not eligible for request deduplication +func (r *InboundRequestSingleFlight) GetOrCreate(ctx *Context, response *GraphQLResponse) (*InflightRequest, error) { + + if ctx.ExecutionOptions.DisableRequestDeduplication { + return nil, nil + } + + if response != nil && response.Info != nil && response.Info.OperationType == ast.OperationTypeMutation { + return nil, nil + } + + // ctx.Request.ID is the unique ID of the normalized GraphQL document +1 (offset) + key := ctx.Request.ID + 1 + // ctx.VariablesHash is the hash of the normalized variables from the client request + // this makes the key unique across different variables + key += ctx.VariablesHash + 1 + if ctx.SubgraphHeadersBuilder != nil { + // ctx.SubgraphHeadersBuilder.HashAll() returns the hash of all headers that will be forwarded to all subgraphs + // this makes the key unique across different client request headers, given that we forward them + // we pre-compute all headers that will be forwarded to each subgraph + // if we combine all the subgraph header hashes, the key will be stable across all headers + key += ctx.SubgraphHeadersBuilder.HashAll() + } + + shard := r.shardFor(key) + shard.mu.Lock() + req, shared := shard.m[key] + if shared { + req.HasFollowers = true + shard.mu.Unlock() + select { + case <-req.Done: + if req.Err != nil { + return nil, req.Err + } + return req, nil + case <-ctx.ctx.Done(): + return nil, ctx.ctx.Err() + } + } + + req = &InflightRequest{ + Done: make(chan struct{}), + ID: key, + } + + shard.m[key] = req + shard.mu.Unlock() + return req, nil +} + +func (r *InboundRequestSingleFlight) FinishOk(req *InflightRequest, data []byte) { + if req == nil { + return + } + shard := r.shardFor(req.ID) + shard.mu.Lock() + delete(shard.m, req.ID) + hasFollowers := req.HasFollowers + shard.mu.Unlock() + if hasFollowers { + // optimization to only copy when we actually have to + req.Data = make([]byte, len(data)) + copy(req.Data, data) + } + close(req.Done) +} + +func (r *InboundRequestSingleFlight) FinishErr(req *InflightRequest, err error) { + if req == nil { + return + } + shard := r.shardFor(req.ID) + shard.mu.Lock() + delete(shard.m, req.ID) + shard.mu.Unlock() + req.Err = err + close(req.Done) +} + +func (r *InboundRequestSingleFlight) shardFor(key uint64) *requestShard { + // Fast modulo using power-of-two shard count if desired in the future. + // For now, use standard modulo for clarity. 
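+	// A power-of-two shard count would allow masking instead of modulo,
+	// e.g. idx := int(key & uint64(len(r.shards)-1)).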
+ idx := int(key % uint64(len(r.shards))) + return &r.shards[idx] +} diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index e4bd36d81..63cda90b2 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -181,9 +181,9 @@ type Loader struct { // If you're not doing this, you will see segfaults // Example of correct usage in func "mergeResult" jsonArena arena.Arena - // sf is the SingleFlight object shared across all client requests + // sf is the SubgraphRequestSingleFlight object shared across all client requests // it's thread safe and can be used to de-duplicate subgraph requests - sf *SingleFlight + sf *SubgraphRequestSingleFlight } func (l *Loader) Free() { diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index b5e3ff14b..dc1f0ba85 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -82,8 +82,10 @@ type Resolver struct { // responseBufferPool is the arena pool dedicated for response buffering before sending to the client responseBufferPool *ArenaPool - // Single flight cache for deduplicating requests across all loaders - sf *SingleFlight + // subgraphRequestSingleFlight is used to de-duplicate subgraph requests + subgraphRequestSingleFlight *SubgraphRequestSingleFlight + // inboundRequestSingleFlight is used to de-duplicate subgraph requests + inboundRequestSingleFlight *InboundRequestSingleFlight } func (r *Resolver) SetAsyncErrorWriter(w AsyncErrorWriter) { @@ -239,7 +241,8 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { maxSubscriptionFetchTimeout: options.MaxSubscriptionFetchTimeout, resolveArenaPool: NewArenaPool(), responseBufferPool: NewArenaPool(), - sf: NewSingleFlight(), + subgraphRequestSingleFlight: NewSingleFlight(8), + inboundRequestSingleFlight: NewRequestSingleFlight(8), } resolver.maxConcurrency = make(chan struct{}, options.MaxConcurrency) for i := 0; i < options.MaxConcurrency; i++ { @@ -251,7 +254,7 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { return resolver } -func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{}, allowedErrorFields map[string]struct{}, sf *SingleFlight, a arena.Arena) *tools { +func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{}, allowedErrorFields map[string]struct{}, sf *SubgraphRequestSingleFlight, a arena.Arena) *tools { return &tools{ resolvable: NewResolvable(a, options.ResolvableOptions), loader: &Loader{ @@ -289,7 +292,7 @@ func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLRespons r.maxConcurrency <- struct{}{} }() - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf, nil) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.subgraphRequestSingleFlight, nil) err := t.resolvable.Init(ctx, data, response.Info.OperationType) if err != nil { @@ -314,6 +317,16 @@ func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLRespons func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLResponse, writer io.Writer) (*GraphQLResolveInfo, error) { resp := &GraphQLResolveInfo{} + inflight, err := r.inboundRequestSingleFlight.GetOrCreate(ctx, response) + if err != nil { + return nil, err + } + + if inflight != nil && inflight.Data != nil { // follower + _, err = writer.Write(inflight.Data) + return resp, err + } + start := time.Now() <-r.maxConcurrency resp.ResolveAcquireWaitTime = time.Since(start) @@ 
-323,10 +336,11 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe resolveArena := r.resolveArenaPool.Acquire(ctx.Request.ID) // we're intentionally not using defer Release to have more control over the timing (see below) - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf, resolveArena.Arena) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.subgraphRequestSingleFlight, resolveArena.Arena) - err := t.resolvable.Init(ctx, nil, response.Info.OperationType) + err = t.resolvable.Init(ctx, nil, response.Info.OperationType) if err != nil { + r.inboundRequestSingleFlight.FinishErr(inflight, err) r.resolveArenaPool.Release(ctx.Request.ID, resolveArena) return nil, err } @@ -334,6 +348,7 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe if !ctx.ExecutionOptions.SkipLoader { err = t.loader.LoadGraphQLResponseData(ctx, response, t.resolvable) if err != nil { + r.inboundRequestSingleFlight.FinishErr(inflight, err) r.resolveArenaPool.Release(ctx.Request.ID, resolveArena) return nil, err } @@ -344,6 +359,7 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe buf := arena.NewArenaBuffer(responseArena.Arena) err = t.resolvable.Resolve(ctx.ctx, response.Data, response.Fetches, buf) if err != nil { + r.inboundRequestSingleFlight.FinishErr(inflight, err) r.resolveArenaPool.Release(ctx.Request.ID, resolveArena) r.responseBufferPool.Release(ctx.Request.ID, responseArena) return nil, err @@ -357,6 +373,7 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe // as such, it can take some time // which is why we split the arenas and released the first one _, err = writer.Write(buf.Bytes()) + r.inboundRequestSingleFlight.FinishOk(inflight, buf.Bytes()) // all data is written to the client // we're safe to release our buffer r.responseBufferPool.Release(ctx.Request.ID, responseArena) @@ -494,7 +511,7 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar copy(input, sharedInput) resolveArena := r.resolveArenaPool.Acquire(resolveCtx.Request.ID) - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf, resolveArena.Arena) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.subgraphRequestSingleFlight, resolveArena.Arena) if err := t.resolvable.InitSubscription(resolveCtx, input, sub.resolve.Trigger.PostProcessing); err != nil { r.resolveArenaPool.Release(resolveCtx.Request.ID, resolveArena) @@ -1107,7 +1124,7 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ // If SkipLoader is enabled, we skip retrieving actual data. For example, this is useful when requesting a query plan. // By returning early, we avoid starting a subscription and resolve with empty data instead. if ctx.ExecutionOptions.SkipLoader { - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf, nil) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.subgraphRequestSingleFlight, nil) err = t.resolvable.InitSubscription(ctx, nil, subscription.Trigger.PostProcessing) if err != nil { @@ -1211,7 +1228,7 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G // If SkipLoader is enabled, we skip retrieving actual data. For example, this is useful when requesting a query plan. 
// By returning early, we avoid starting a subscription and resolve with empty data instead. if ctx.ExecutionOptions.SkipLoader { - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf, nil) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.subgraphRequestSingleFlight, nil) err = t.resolvable.InitSubscription(ctx, nil, subscription.Trigger.PostProcessing) if err != nil { diff --git a/v2/pkg/engine/resolve/singleflight.go b/v2/pkg/engine/resolve/subgraph_request_singleflight.go similarity index 61% rename from v2/pkg/engine/resolve/singleflight.go rename to v2/pkg/engine/resolve/subgraph_request_singleflight.go index 76121d98e..013d90677 100644 --- a/v2/pkg/engine/resolve/singleflight.go +++ b/v2/pkg/engine/resolve/subgraph_request_singleflight.go @@ -6,14 +6,23 @@ import ( "github.com/cespare/xxhash/v2" ) -type SingleFlight struct { - mu *sync.RWMutex - items map[uint64]*SingleFlightItem - sizes map[uint64]*fetchSize +// SubgraphRequestSingleFlight is a sharded, goroutine safe single flight implementation to de-duplicate subgraph requests +// It's hashing the input and adds the pre-computed subgraph headers hash to avoid collisions +// In addition to single flight, it provides size hints to create right-sized buffers for subgraph requests +type SubgraphRequestSingleFlight struct { + shards []singleFlightShard xxPool *sync.Pool cleanup chan func() } +type singleFlightShard struct { + mu sync.RWMutex + items map[uint64]*SingleFlightItem + sizes map[uint64]*fetchSize +} + +const defaultSingleFlightShardCount = 4 + // SingleFlightItem is used to communicate between leader and followers // If an Item for a key doesn't exist, the leader creates and followers can join type SingleFlightItem struct { @@ -37,11 +46,12 @@ type fetchSize struct { totalBytes int } -func NewSingleFlight() *SingleFlight { - return &SingleFlight{ - items: make(map[uint64]*SingleFlightItem), - sizes: make(map[uint64]*fetchSize), - mu: new(sync.RWMutex), +func NewSingleFlight(shardCount int) *SubgraphRequestSingleFlight { + if shardCount <= 0 { + shardCount = defaultSingleFlightShardCount + } + s := &SubgraphRequestSingleFlight{ + shards: make([]singleFlightShard, shardCount), xxPool: &sync.Pool{ New: func() any { return xxhash.New() @@ -49,6 +59,13 @@ func NewSingleFlight() *SingleFlight { }, cleanup: make(chan func()), } + for i := range s.shards { + s.shards[i] = singleFlightShard{ + items: make(map[uint64]*SingleFlightItem), + sizes: make(map[uint64]*fetchSize), + } + } + return s } // GetOrCreateItem generates a single flight key (100% identical fetches) and a fetchKey (similar fetches, collisions possible but unproblematic) @@ -58,23 +75,26 @@ func NewSingleFlight() *SingleFlight { // item.sizeHint can be used to create an optimal buffer for the fetch in case of a leader // item.err must always be checked // item.response must never be mutated -func (s *SingleFlight) GetOrCreateItem(fetchItem *FetchItem, input []byte, extraKey uint64) (sfKey, fetchKey uint64, item *SingleFlightItem, shared bool) { +func (s *SubgraphRequestSingleFlight) GetOrCreateItem(fetchItem *FetchItem, input []byte, extraKey uint64) (sfKey, fetchKey uint64, item *SingleFlightItem, shared bool) { sfKey, fetchKey = s.keys(fetchItem, input, extraKey) - // First, try to get the item with a read lock - s.mu.RLock() - item, exists := s.items[sfKey] - s.mu.RUnlock() + // Get shard based on sfKey for items + shard := s.shardFor(sfKey) + + // First, try to get the item with a read lock on its 
shard + shard.mu.RLock() + item, exists := shard.items[sfKey] + shard.mu.RUnlock() if exists { return sfKey, fetchKey, item, true } // If not exists, acquire a write lock to create the item - s.mu.Lock() + shard.mu.Lock() // Double-check if the item was created while acquiring the write lock - item, exists = s.items[sfKey] + item, exists = shard.items[sfKey] if exists { - s.mu.Unlock() + shard.mu.Unlock() return sfKey, fetchKey, item, true } @@ -83,15 +103,16 @@ func (s *SingleFlight) GetOrCreateItem(fetchItem *FetchItem, input []byte, extra // empty chan to indicate to all followers when we're done (close) loaded: make(chan struct{}), } - if size, ok := s.sizes[fetchKey]; ok { + // Read size hint from the same shard (both items and sizes use the same shard now) + if size, ok := shard.sizes[fetchKey]; ok { item.sizeHint = size.totalBytes / size.count } - s.items[sfKey] = item - s.mu.Unlock() + shard.items[sfKey] = item + shard.mu.Unlock() return sfKey, fetchKey, item, false } -func (s *SingleFlight) keys(fetchItem *FetchItem, input []byte, extraKey uint64) (sfKey, fetchKey uint64) { +func (s *SubgraphRequestSingleFlight) keys(fetchItem *FetchItem, input []byte, extraKey uint64) (sfKey, fetchKey uint64) { h := s.xxPool.Get().(*xxhash.Digest) sfKey = s.sfKey(h, fetchItem, input, extraKey) h.Reset() @@ -103,7 +124,7 @@ func (s *SingleFlight) keys(fetchItem *FetchItem, input []byte, extraKey uint64) // sfKey returns a key that 100% uniquely identifies a fetch with no collision // two sfKey are only the same when the fetches are 100% equal -func (s *SingleFlight) sfKey(h *xxhash.Digest, fetchItem *FetchItem, input []byte, extraKey uint64) uint64 { +func (s *SubgraphRequestSingleFlight) sfKey(h *xxhash.Digest, fetchItem *FetchItem, input []byte, extraKey uint64) uint64 { if fetchItem != nil && fetchItem.Fetch != nil { info := fetchItem.Fetch.FetchInfo() if info != nil { @@ -119,7 +140,7 @@ func (s *SingleFlight) sfKey(h *xxhash.Digest, fetchItem *FetchItem, input []byt // the purpose is to create a key from the DataSourceID and root fields to have less cardinality // the goal is to get an estimate buffer size for similar fetches // there's no point in hashing headers or the body for this purpose -func (s *SingleFlight) fetchKey(h *xxhash.Digest, fetchItem *FetchItem) uint64 { +func (s *SubgraphRequestSingleFlight) fetchKey(h *xxhash.Digest, fetchItem *FetchItem) uint64 { if fetchItem == nil || fetchItem.Fetch == nil { return 0 } @@ -128,13 +149,13 @@ func (s *SingleFlight) fetchKey(h *xxhash.Digest, fetchItem *FetchItem) uint64 { return 0 } _, _ = h.WriteString(info.DataSourceID) - _, _ = h.WriteString("|") + _, _ = h.Write(pipe) for i := range info.RootFields { if i != 0 { - _, _ = h.WriteString(",") + _, _ = h.Write(comma) } _, _ = h.WriteString(info.RootFields[i].TypeName) - _, _ = h.WriteString(".") + _, _ = h.Write(dot) _, _ = h.WriteString(info.RootFields[i].FieldName) } return h.Sum64() @@ -143,11 +164,13 @@ func (s *SingleFlight) fetchKey(h *xxhash.Digest, fetchItem *FetchItem) uint64 { // Finish is for the leader to mark the SingleFlightItem as "done" // trigger all followers to look at the err & response of the item // and to update the size estimates -func (s *SingleFlight) Finish(sfKey, fetchKey uint64, item *SingleFlightItem) { +func (s *SubgraphRequestSingleFlight) Finish(sfKey, fetchKey uint64, item *SingleFlightItem) { close(item.loaded) - s.mu.Lock() - delete(s.items, sfKey) - if size, ok := s.sizes[fetchKey]; ok { + // Update sizes in the same shard as the item (using sfKey to 
get the shard) + shard := s.shardFor(sfKey) + shard.mu.Lock() + delete(shard.items, sfKey) + if size, ok := shard.sizes[fetchKey]; ok { if size.count == 50 { size.count = 1 size.totalBytes = size.totalBytes / 50 @@ -155,10 +178,15 @@ func (s *SingleFlight) Finish(sfKey, fetchKey uint64, item *SingleFlightItem) { size.count++ size.totalBytes += len(item.response) } else { - s.sizes[fetchKey] = &fetchSize{ + shard.sizes[fetchKey] = &fetchSize{ count: 1, totalBytes: len(item.response), } } - s.mu.Unlock() + shard.mu.Unlock() +} + +func (s *SubgraphRequestSingleFlight) shardFor(key uint64) *singleFlightShard { + idx := int(key % uint64(len(s.shards))) + return &s.shards[idx] } From cd59d03f8ea2b60440b28011850a3f7997bc0b0f Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 27 Oct 2025 20:23:51 +0100 Subject: [PATCH 41/61] chore: refactor --- .../resolve/inbound_request_singleflight.go | 20 +++++++++---------- v2/pkg/engine/resolve/loader.go | 2 +- v2/pkg/engine/resolve/resolve.go | 6 +++++- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/v2/pkg/engine/resolve/inbound_request_singleflight.go b/v2/pkg/engine/resolve/inbound_request_singleflight.go index 995ee390c..1dbe8c9a7 100644 --- a/v2/pkg/engine/resolve/inbound_request_singleflight.go +++ b/v2/pkg/engine/resolve/inbound_request_singleflight.go @@ -1,8 +1,10 @@ package resolve import ( + "encoding/binary" "sync" + "github.com/cespare/xxhash/v2" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" ) @@ -61,18 +63,16 @@ func (r *InboundRequestSingleFlight) GetOrCreate(ctx *Context, response *GraphQL return nil, nil } - // ctx.Request.ID is the unique ID of the normalized GraphQL document +1 (offset) - key := ctx.Request.ID + 1 - // ctx.VariablesHash is the hash of the normalized variables from the client request - // this makes the key unique across different variables - key += ctx.VariablesHash + 1 + // Derive a robust key from request ID, variables hash and (optional) headers hash + var b [24]byte + binary.LittleEndian.PutUint64(b[0:8], ctx.Request.ID) + binary.LittleEndian.PutUint64(b[8:16], ctx.VariablesHash) + hh := uint64(0) if ctx.SubgraphHeadersBuilder != nil { - // ctx.SubgraphHeadersBuilder.HashAll() returns the hash of all headers that will be forwarded to all subgraphs - // this makes the key unique across different client request headers, given that we forward them - // we pre-compute all headers that will be forwarded to each subgraph - // if we combine all the subgraph header hashes, the key will be stable across all headers - key += ctx.SubgraphHeadersBuilder.HashAll() + hh = ctx.SubgraphHeadersBuilder.HashAll() } + binary.LittleEndian.PutUint64(b[16:24], hh) + key := xxhash.Sum64(b[:]) shard := r.shardFor(key) shard.mu.Lock() diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 63cda90b2..88ef6fec2 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -1642,7 +1642,7 @@ func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem sfKey, fetchKey, item, shared := l.sf.GetOrCreateItem(fetchItem, input, extraKey) if res.singleFlightStats != nil { - res.singleFlightStats.used = shared + res.singleFlightStats.used = true res.singleFlightStats.shared = shared } diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index dc1f0ba85..b93888a79 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -5,6 +5,7 @@ package resolve import ( "bytes" "context" + "encoding/binary" "fmt" 
"io" "net/http" @@ -1104,7 +1105,10 @@ func (r *Resolver) prepareTrigger(ctx *Context, sourceName string, input []byte) header, headerHash := ctx.SubgraphHeadersBuilder.HeadersForSubgraph(sourceName) keyGen := pool.Hash64.Get() _, _ = keyGen.Write(input) - triggerID = keyGen.Sum64() + headerHash + var b [8]byte + binary.LittleEndian.PutUint64(b[:], headerHash) + _, _ = keyGen.Write(b[:]) + triggerID = keyGen.Sum64() pool.Hash64.Put(keyGen) return header, triggerID } From c579f4898d41ff07ef19746e75dfed4d35d783df Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 27 Oct 2025 20:24:32 +0100 Subject: [PATCH 42/61] chore: fmt --- v2/pkg/engine/resolve/inbound_request_singleflight.go | 1 + 1 file changed, 1 insertion(+) diff --git a/v2/pkg/engine/resolve/inbound_request_singleflight.go b/v2/pkg/engine/resolve/inbound_request_singleflight.go index 1dbe8c9a7..6db40dc70 100644 --- a/v2/pkg/engine/resolve/inbound_request_singleflight.go +++ b/v2/pkg/engine/resolve/inbound_request_singleflight.go @@ -5,6 +5,7 @@ import ( "sync" "github.com/cespare/xxhash/v2" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" ) From 319126c5c61ee5eb75571f9b6af64b52f9aed45a Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 27 Oct 2025 20:35:11 +0100 Subject: [PATCH 43/61] chore: fix test --- .../engine/testdata/complex_nesting_query_with_art.json | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/execution/engine/testdata/complex_nesting_query_with_art.json b/execution/engine/testdata/complex_nesting_query_with_art.json index 69a208fe4..ec85c1e5c 100644 --- a/execution/engine/testdata/complex_nesting_query_with_art.json +++ b/execution/engine/testdata/complex_nesting_query_with_art.json @@ -170,7 +170,7 @@ "duration_since_start_pretty": "1ns", "duration_load_nanoseconds": 1, "duration_load_pretty": "1ns", - "single_flight_used": false, + "single_flight_used": true, "single_flight_shared_response": false, "load_skipped": false, "load_stats": { @@ -310,7 +310,7 @@ "duration_since_start_pretty": "1ns", "duration_load_nanoseconds": 1, "duration_load_pretty": "1ns", - "single_flight_used": false, + "single_flight_used": true, "single_flight_shared_response": false, "load_skipped": false, "load_stats": { @@ -496,7 +496,7 @@ "duration_since_start_pretty": "1ns", "duration_load_nanoseconds": 1, "duration_load_pretty": "1ns", - "single_flight_used": false, + "single_flight_used": true, "single_flight_shared_response": false, "load_skipped": false, "load_stats": { From 0bf8fb37ad1272532e1c81c9bf4ef0f6b75d7ddf Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Tue, 28 Oct 2025 08:54:54 +0100 Subject: [PATCH 44/61] chore: refactor --- v2/pkg/engine/resolve/context.go | 10 ++++++++-- .../resolve/inbound_request_singleflight.go | 7 +++---- v2/pkg/engine/resolve/loader.go | 17 ++++++++++++++--- v2/pkg/engine/resolve/response.go | 13 +++++++++++++ 4 files changed, 38 insertions(+), 9 deletions(-) diff --git a/v2/pkg/engine/resolve/context.go b/v2/pkg/engine/resolve/context.go index d6a8657e4..5783b29a5 100644 --- a/v2/pkg/engine/resolve/context.go +++ b/v2/pkg/engine/resolve/context.go @@ -67,8 +67,14 @@ type ExecutionOptions struct { IncludeQueryPlanInResponse bool // SendHeartbeat sends regular HeartBeats for Subscriptions SendHeartbeat bool - // DisableRequestDeduplication disables deduplication of requests to the same subgraph with the same input within a single operation execution. 
- DisableRequestDeduplication bool + // DisableSubgraphRequestDeduplication disables deduplication of requests to the same subgraph with the same input within a single operation execution. + DisableSubgraphRequestDeduplication bool + // DisableInboundRequestDeduplication disables deduplication of inbound client requests + // The engine is hashing the normalized operation, variables, and forwarded headers to achieve robust deduplication + // By default, overhead is negligible and as such this should be false (not disabled) most of the time + // However, if you're benchmarking internals of the engine, it can be helpful to switch it off + // When disabled (set to true) the code becomes a no-op + DisableInboundRequestDeduplication bool } type FieldValue struct { diff --git a/v2/pkg/engine/resolve/inbound_request_singleflight.go b/v2/pkg/engine/resolve/inbound_request_singleflight.go index 6db40dc70..f5ad8eb4a 100644 --- a/v2/pkg/engine/resolve/inbound_request_singleflight.go +++ b/v2/pkg/engine/resolve/inbound_request_singleflight.go @@ -5,8 +5,6 @@ import ( "sync" "github.com/cespare/xxhash/v2" - - "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" ) // InboundRequestSingleFlight is a sharded goroutine safe single flight implementation to de-couple inbound requests @@ -54,13 +52,14 @@ type InflightRequest struct { // GetOrCreate blocks until ctx.ctx.Done() returns or InflightRequest.Done is closed // It returns an error if the leader returned an error // It returns nil,nil if the inbound request is not eligible for request deduplication +// or if DisableSubgraphRequestDeduplication or DisableInboundRequestDeduplication is set to true on Context func (r *InboundRequestSingleFlight) GetOrCreate(ctx *Context, response *GraphQLResponse) (*InflightRequest, error) { - if ctx.ExecutionOptions.DisableRequestDeduplication { + if ctx.ExecutionOptions.DisableSubgraphRequestDeduplication || ctx.ExecutionOptions.DisableInboundRequestDeduplication { return nil, nil } - if response != nil && response.Info != nil && response.Info.OperationType == ast.OperationTypeMutation { + if !response.SingleFlightAllowed() { return nil, nil } diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 88ef6fec2..893b70638 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -1625,6 +1625,19 @@ func (l *Loader) headersForSubgraphRequest(fetchItem *FetchItem) (http.Header, u return l.ctx.HeadersForSubgraphRequest(info.DataSourceName) } +func (l *Loader) singleFlightAllowed() bool { + if l.ctx.ExecutionOptions.DisableSubgraphRequestDeduplication { + return false + } + if l.info == nil { + return false + } + if l.info.OperationType == ast.OperationTypeQuery { + return true + } + return false +} + func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem *FetchItem, input []byte, res *result) error { if l.info != nil { @@ -1633,9 +1646,7 @@ func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem headers, extraKey := l.headersForSubgraphRequest(fetchItem) - if l.info == nil || - l.info.OperationType == ast.OperationTypeMutation || - l.ctx.ExecutionOptions.DisableRequestDeduplication { + if !l.singleFlightAllowed() { // Disable single flight for mutations return l.loadByContextDirect(ctx, source, headers, input, res) } diff --git a/v2/pkg/engine/resolve/response.go b/v2/pkg/engine/resolve/response.go index 1efe078cc..d8af8d017 100644 --- a/v2/pkg/engine/resolve/response.go +++ b/v2/pkg/engine/resolve/response.go @@ -43,6 
+43,19 @@ type GraphQLResponse struct { DataSources []DataSourceInfo } +func (g *GraphQLResponse) SingleFlightAllowed() bool { + if g == nil { + return false + } + if g.Info == nil { + return false + } + if g.Info.OperationType == ast.OperationTypeQuery { + return true + } + return false +} + type GraphQLResponseInfo struct { OperationType ast.OperationType } From 1ae36b46599e570b4ef31eca674a64c28297040f Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Tue, 28 Oct 2025 09:33:24 +0100 Subject: [PATCH 45/61] chore: refactor --- v2/pkg/engine/resolve/inbound_request_singleflight.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/v2/pkg/engine/resolve/inbound_request_singleflight.go b/v2/pkg/engine/resolve/inbound_request_singleflight.go index f5ad8eb4a..66505a36a 100644 --- a/v2/pkg/engine/resolve/inbound_request_singleflight.go +++ b/v2/pkg/engine/resolve/inbound_request_singleflight.go @@ -52,10 +52,10 @@ type InflightRequest struct { // GetOrCreate blocks until ctx.ctx.Done() returns or InflightRequest.Done is closed // It returns an error if the leader returned an error // It returns nil,nil if the inbound request is not eligible for request deduplication -// or if DisableSubgraphRequestDeduplication or DisableInboundRequestDeduplication is set to true on Context +// or if DisableInboundRequestDeduplication is set to true on Context func (r *InboundRequestSingleFlight) GetOrCreate(ctx *Context, response *GraphQLResponse) (*InflightRequest, error) { - if ctx.ExecutionOptions.DisableSubgraphRequestDeduplication || ctx.ExecutionOptions.DisableInboundRequestDeduplication { + if ctx.ExecutionOptions.DisableInboundRequestDeduplication { return nil, nil } From 57e688cc32728979bf942354d2dace6178160763 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Tue, 28 Oct 2025 10:06:54 +0100 Subject: [PATCH 46/61] chore: allow single flight in loader for sub Queries, even if root operation type is Mutation or Subscription --- v2/pkg/engine/resolve/loader.go | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 893b70638..a33242bc1 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -1625,14 +1625,25 @@ func (l *Loader) headersForSubgraphRequest(fetchItem *FetchItem) (http.Header, u return l.ctx.HeadersForSubgraphRequest(info.DataSourceName) } -func (l *Loader) singleFlightAllowed() bool { +// singleFlightAllowed returns true if the specific GraphQL Operation is a Query +// even if the root operation type is a Mutation or Subscription +// sub-operations can still be of type Query +// even in such cases we allow request de-duplication because such requests are idempotent +func (l *Loader) singleFlightAllowed(fetchItem *FetchItem) bool { if l.ctx.ExecutionOptions.DisableSubgraphRequestDeduplication { return false } - if l.info == nil { + if fetchItem == nil { return false } - if l.info.OperationType == ast.OperationTypeQuery { + if fetchItem.Fetch == nil { + return false + } + info := fetchItem.Fetch.FetchInfo() + if info == nil { + return false + } + if info.OperationType == ast.OperationTypeQuery { return true } return false @@ -1646,7 +1657,7 @@ func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem headers, extraKey := l.headersForSubgraphRequest(fetchItem) - if !l.singleFlightAllowed() { + if !l.singleFlightAllowed(fetchItem) { // Disable single flight for mutations return l.loadByContextDirect(ctx, source, headers, input, 
res) } From a5e62898bc920a3ac43fb2193149b7d4ec06bbb2 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Tue, 28 Oct 2025 16:44:39 +0100 Subject: [PATCH 47/61] chore: merge main --- v2/pkg/engine/resolve/loader.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 16166da84..ee99babd2 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -491,7 +491,7 @@ func (l *Loader) tryCacheLoadFetch(ctx context.Context, info *FetchInfo, cfg Fet res.cacheItems[i] = astjson.NullValue continue } - res.cacheItems[i], err = astjson.ParseBytesWithoutCache(cachedItems[i]) + res.cacheItems[i], err = astjson.ParseBytes(cachedItems[i]) if err != nil { return false, errors.WithStack(err) } @@ -596,7 +596,7 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson } if res.cacheSkippedFetch { for i, item := range res.cacheItems { - _, _, err := astjson.MergeValues(items[i], item) + _, _, err := astjson.MergeValues(l.jsonArena, items[i], item) if err != nil { return l.renderErrorsFailedToFetch(fetchItem, res, "invalid cache item") } From 8f3e30f68444125efe5a83f08cd241f3037f4a11 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Wed, 29 Oct 2025 09:12:39 +0100 Subject: [PATCH 48/61] chore: improve arena pool & add tests --- v2/pkg/engine/resolve/arena.go | 16 +- v2/pkg/engine/resolve/arena_test.go | 257 ++++++++++++++++++++++++++++ 2 files changed, 267 insertions(+), 6 deletions(-) create mode 100644 v2/pkg/engine/resolve/arena_test.go diff --git a/v2/pkg/engine/resolve/arena.go b/v2/pkg/engine/resolve/arena.go index cca1f3312..98bd93087 100644 --- a/v2/pkg/engine/resolve/arena.go +++ b/v2/pkg/engine/resolve/arena.go @@ -48,13 +48,17 @@ func (p *ArenaPool) Acquire(id uint64) *ArenaPoolItem { defer p.mu.Unlock() // Try to find an available arena in the pool - for i := 0; i < len(p.pool); i++ { - v := p.pool[i].Value() - p.pool = append(p.pool[:i], p.pool[i+1:]...) 
- if v == nil { - continue + for len(p.pool) > 0 { + // Pop the last item + lastIdx := len(p.pool) - 1 + wp := p.pool[lastIdx] + p.pool = p.pool[:lastIdx] + + v := wp.Value() + if v != nil { + return v } - return v + // If weak pointer was nil (GC collected), continue to next item } // No arena available, create a new one diff --git a/v2/pkg/engine/resolve/arena_test.go b/v2/pkg/engine/resolve/arena_test.go new file mode 100644 index 000000000..a6bb0f557 --- /dev/null +++ b/v2/pkg/engine/resolve/arena_test.go @@ -0,0 +1,257 @@ +package resolve + +import ( + "runtime" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/go-arena" +) + +func TestNewArenaPool(t *testing.T) { + pool := NewArenaPool() + + require.NotNil(t, pool, "NewArenaPool returned nil") + assert.Equal(t, 0, len(pool.pool), "expected empty pool") + assert.Equal(t, 0, len(pool.sizes), "expected empty sizes map") +} + +func TestArenaPool_Acquire_EmptyPool(t *testing.T) { + pool := NewArenaPool() + + item := pool.Acquire(1) + + require.NotNil(t, item, "Acquire returned nil") + assert.NotNil(t, item.Arena, "Arena is nil") + + // Verify we can use the arena + buf := arena.NewArenaBuffer(item.Arena) + buf.WriteString("test") + + assert.Equal(t, 0, len(pool.pool), "pool should still be empty") +} + +func TestArenaPool_ReleaseAndAcquire(t *testing.T) { + pool := NewArenaPool() + id := uint64(42) + + // Acquire first arena + item1 := pool.Acquire(id) + + // Use the arena + buf := arena.NewArenaBuffer(item1.Arena) + buf.WriteString("test data") + + // Release it + pool.Release(id, item1) + + // Pool should have one item + assert.Equal(t, 1, len(pool.pool), "expected pool to have 1 item") + + // Acquire from pool + item2 := pool.Acquire(id) + + require.NotNil(t, item2, "Acquire returned nil") + + // Pool should be empty again + assert.Equal(t, 0, len(pool.pool), "expected empty pool after acquire") + + // The acquired arena should be reset and usable + buf2 := arena.NewArenaBuffer(item2.Arena) + buf2.WriteString("new data") + + assert.Equal(t, "new data", buf2.String()) +} + +func TestArenaPool_Acquire_ProvesBugFix(t *testing.T) { + // This test specifically proves the bug fix works + // Creates multiple items, clears some references, then acquires + // to ensure all items are checked without skipping + pool := NewArenaPool() + id := uint64(800) + + numItems := 10 + items := make([]*ArenaPoolItem, numItems) + + // Acquire all items + for i := 0; i < numItems; i++ { + items[i] = pool.Acquire(id) + buf := arena.NewArenaBuffer(items[i].Arena) + _, err := buf.WriteString("item data") + assert.NoError(t, err) + } + + // Release all while keeping strong references + for i := 0; i < numItems; i++ { + pool.Release(id, items[i]) + } + + // Pool should have all items + assert.Equal(t, numItems, len(pool.pool), "expected items in pool") + + // Clear every other item to simulate partial GC + for i := 0; i < numItems; i += 2 { + items[i] = nil + } + + // Force GC + runtime.GC() + runtime.GC() + + // Acquire items - should process ALL items without skipping + processed := 0 + acquired := 0 + + for len(pool.pool) > 0 && processed < numItems*2 { + poolSizeBefore := len(pool.pool) + item := pool.Acquire(id) + poolSizeAfter := len(pool.pool) + processed++ + + assert.Less(t, poolSizeAfter, poolSizeBefore, "Pool size did not decrease - item not removed properly!") + + if item != nil { + acquired++ + } + } + + // Pool should be empty + assert.Equal(t, 0, len(pool.pool), "expected empty 
pool") +} + +func TestArenaPool_Release_PeakTracking(t *testing.T) { + pool := NewArenaPool() + id := uint64(200) + + // First arena + item1 := pool.Acquire(id) + buf1 := arena.NewArenaBuffer(item1.Arena) + _, err := buf1.WriteString("small") + assert.NoError(t, err) + + peak1 := item1.Arena.Peak() + assert.Equal(t, peak1, 5) + + pool.Release(id, item1) + + // Check that size was tracked + size, exists := pool.sizes[id] + require.True(t, exists, "size tracking not created") + assert.Equal(t, 1, size.count, "expected count 1") + + // Second arena + item2 := pool.Acquire(id) + buf2 := arena.NewArenaBuffer(item2.Arena) + _, err = buf2.WriteString("larger data") + assert.NoError(t, err) + + pool.Release(id, item2) + + // Check updated tracking + assert.Equal(t, 2, size.count, "expected count 2") +} + +func TestArenaPool_GetArenaSize(t *testing.T) { + pool := NewArenaPool() + + // Test default size for unknown ID + size1 := pool.getArenaSize(999) + expectedDefault := 1024 * 1024 + assert.Equal(t, expectedDefault, size1, "expected default size") + + // Test calculated size after usage + id := uint64(400) + item := pool.Acquire(id) + buf := arena.NewArenaBuffer(item.Arena) + _, err := buf.WriteString("some data") + assert.NoError(t, err) + pool.Release(id, item) + + size2 := pool.getArenaSize(id) + assert.NotEqual(t, 0, size2, "expected non-zero size after usage") +} + +func TestArenaPool_MultipleItemsInPool(t *testing.T) { + pool := NewArenaPool() + id := uint64(500) + + // Acquire multiple distinct items + numItems := 3 + items := make([]*ArenaPoolItem, numItems) + + for i := 0; i < numItems; i++ { + items[i] = pool.Acquire(id) + buf := arena.NewArenaBuffer(items[i].Arena) + _, err := buf.WriteString("data") + assert.NoError(t, err) + } + + // Release all while keeping references + for i := 0; i < numItems; i++ { + pool.Release(id, items[i]) + } + + // Should have all items in pool + assert.Equal(t, numItems, len(pool.pool), "expected items in pool") + + // Acquire all back + acquired := 0 + for len(pool.pool) > 0 { + item := pool.Acquire(id) + if item != nil { + acquired++ + } + } + + assert.Equal(t, numItems, acquired, "expected to acquire all items") +} + +func TestArenaPool_Release_MovingWindow(t *testing.T) { + pool := NewArenaPool() + id := uint64(600) + + // Release exactly 50 items + for i := 0; i < 50; i++ { + item := pool.Acquire(id) + buf := arena.NewArenaBuffer(item.Arena) + _, err := buf.WriteString("test data") + assert.NoError(t, err) + pool.Release(id, item) + } + + // After 50 releases, verify count and total + size := pool.sizes[id] + require.NotNil(t, size, "size tracking should exist") + assert.Equal(t, 50, size.count, "expected count to be 50") + + totalBytesAfter50 := size.totalBytes + + // Release one more item to trigger the window reset + item51 := pool.Acquire(id) + buf51 := arena.NewArenaBuffer(item51.Arena) + _, err := buf51.WriteString("test data") + assert.NoError(t, err) + peak51 := item51.Arena.Peak() + pool.Release(id, item51) + + // After 51st release, verify the window was reset + // count should be 2 (reset to 1, then incremented) + // totalBytes should be (totalBytesAfter50 / 50) + peak51 + assert.Equal(t, 2, size.count, "expected count to be 2 after window reset") + + expectedTotalBytes := (totalBytesAfter50 / 50) + peak51 + assert.Equal(t, expectedTotalBytes, size.totalBytes, "expected totalBytes to be divided by 50 and new peak added") + + // Verify we can continue releasing and counting works correctly + for i := 0; i < 10; i++ { + item := 
pool.Acquire(id) + buf := arena.NewArenaBuffer(item.Arena) + _, err := buf.WriteString("more data") + assert.NoError(t, err) + pool.Release(id, item) + } + + // After 10 more releases, count should be 12 (2 + 10) + assert.Equal(t, 12, size.count, "expected count to continue incrementing after window reset") +} From 3df9e01d9dcf796ad6910a3266ad9a788a8d89a0 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Wed, 29 Oct 2025 09:13:04 +0100 Subject: [PATCH 49/61] chore: use arena in Walker --- v2/pkg/astvisitor/visitor.go | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/v2/pkg/astvisitor/visitor.go b/v2/pkg/astvisitor/visitor.go index a2cbb102d..bd48ad692 100644 --- a/v2/pkg/astvisitor/visitor.go +++ b/v2/pkg/astvisitor/visitor.go @@ -5,6 +5,7 @@ import ( "fmt" "sync" + "github.com/wundergraph/go-arena" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/lexer/literal" "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" @@ -94,6 +95,8 @@ type Walker struct { deferred []func() OnExternalError func(err *operationreport.ExternalError) + + arena arena.Arena } // NewWalker returns a fully initialized Walker @@ -125,6 +128,9 @@ func WalkerFromPool() *Walker { } func (w *Walker) Release() { + if w.arena != nil { + w.arena.Reset() + } w.ResetVisitors() w.Report = nil w.document = nil @@ -1370,6 +1376,11 @@ func (w *Walker) Walk(document, definition *ast.Document, report *operationrepor } else { w.Report = report } + if w.arena == nil { + w.arena = arena.NewMonotonicArena(arena.WithMinBufferSize(64)) + } else { + w.arena.Reset() + } w.Ancestors = w.Ancestors[:0] w.Path = w.Path[:0] w.TypeDefinitions = w.TypeDefinitions[:0] @@ -1822,8 +1833,7 @@ func (w *Walker) walkSelectionSet(ref int, skipFor SkipVisitors) { RefsChanged: for { - refs := make([]int, 0, len(w.document.SelectionSets[ref].SelectionRefs)) - refs = append(refs, w.document.SelectionSets[ref].SelectionRefs...) + refs := arena.SliceAppend(w.arena, nil, w.document.SelectionSets[ref].SelectionRefs...) 
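+		// snapshot the refs on the walker arena: visitors may add or remove
+		// selections while we iterate; the arena is reset on Release and at
+		// the start of the next Walk, so the copy avoids per-call heap allocations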
for i, j := range refs { From aa789e070ea383a384355228b2b22e5061451d50 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Wed, 29 Oct 2025 09:14:37 +0100 Subject: [PATCH 50/61] chore: fix lint --- v2/pkg/astvisitor/visitor.go | 1 + v2/pkg/engine/resolve/arena_test.go | 10 +++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/v2/pkg/astvisitor/visitor.go b/v2/pkg/astvisitor/visitor.go index bd48ad692..86a29c0c7 100644 --- a/v2/pkg/astvisitor/visitor.go +++ b/v2/pkg/astvisitor/visitor.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/wundergraph/go-arena" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/lexer/literal" "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" diff --git a/v2/pkg/engine/resolve/arena_test.go b/v2/pkg/engine/resolve/arena_test.go index a6bb0f557..20c1069b8 100644 --- a/v2/pkg/engine/resolve/arena_test.go +++ b/v2/pkg/engine/resolve/arena_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/wundergraph/go-arena" ) @@ -27,7 +28,8 @@ func TestArenaPool_Acquire_EmptyPool(t *testing.T) { // Verify we can use the arena buf := arena.NewArenaBuffer(item.Arena) - buf.WriteString("test") + _, err := buf.WriteString("test") + assert.NoError(t, err) assert.Equal(t, 0, len(pool.pool), "pool should still be empty") } @@ -41,7 +43,8 @@ func TestArenaPool_ReleaseAndAcquire(t *testing.T) { // Use the arena buf := arena.NewArenaBuffer(item1.Arena) - buf.WriteString("test data") + _, err := buf.WriteString("test data") + assert.NoError(t, err) // Release it pool.Release(id, item1) @@ -59,7 +62,8 @@ func TestArenaPool_ReleaseAndAcquire(t *testing.T) { // The acquired arena should be reset and usable buf2 := arena.NewArenaBuffer(item2.Arena) - buf2.WriteString("new data") + _, err = buf2.WriteString("new data") + assert.NoError(t, err) assert.Equal(t, "new data", buf2.String()) } From 6d40307333f6bc00743a8dcb333d46f229902845 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Wed, 29 Oct 2025 11:14:54 +0100 Subject: [PATCH 51/61] chore: fix test with cache key --- .../graphql_datasource/graphql_datasource.go | 10 +++ .../graphql_datasource_federation_test.go | 86 ++++++++----------- v2/pkg/engine/plan/visitor.go | 22 ++--- 3 files changed, 53 insertions(+), 65 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go index ba65e1860..9ae12a0a1 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go @@ -83,6 +83,10 @@ type Planner[T Configuration] struct { // to the downstream subgraph fetch. 
propagatedOperationName string + // caching + + cacheKeyTemplate resolve.CacheKeyTemplate + // federation addedInlineFragments map[onTypeInlineFragment]struct{} @@ -385,6 +389,9 @@ func (p *Planner[T]) ConfigureFetch() resolve.FetchConfiguration { SetTemplateOutputToNullOnVariableNull: requiresEntityFetch || requiresEntityBatchFetch, QueryPlan: p.queryPlan, OperationName: p.propagatedOperationName, + Caching: resolve.FetchCacheConfiguration{ + CacheKeyTemplate: p.cacheKeyTemplate, + }, } } @@ -836,6 +843,9 @@ func (p *Planner[T]) addRepresentationsVariable() { } representationsVariable := resolve.NewResolvableObjectVariable(p.buildRepresentationsVariable()) + p.cacheKeyTemplate = &resolve.EntityQueryCacheKeyTemplate{ + Keys: representationsVariable, + } variable, _ := p.variables.AddVariable(representationsVariable) p.upstreamVariables, _ = sjson.SetRawBytes(p.upstreamVariables, "representations", []byte(fmt.Sprintf("[%s]", variable))) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go index 3943accab..45fe7205d 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go @@ -1558,14 +1558,6 @@ func TestGraphQLDataSourceFederation(t *testing.T) { Input: `{"method":"POST","url":"http://user.service","body":{"query":"{user {account {__typename id info {a b}}}}"}}`, DataSource: &Source{}, PostProcessing: DefaultPostProcessingConfiguration, - Caching: resolve.FetchCacheConfiguration{ - Enabled: true, - CacheName: "default", - TTL: time.Second * 30, - CacheKeyTemplate: &resolve.InputTemplate{ - Segments: []resolve.TemplateSegment{}, - }, - }, }, Info: &resolve.FetchInfo{ DataSourceID: "user.service", @@ -1844,54 +1836,48 @@ func TestGraphQLDataSourceFederation(t *testing.T) { Enabled: true, CacheName: "default", TTL: time.Second * 30, - CacheKeyTemplate: &resolve.InputTemplate{ - Segments: []resolve.TemplateSegment{ - { - SegmentType: resolve.VariableSegmentType, - VariableKind: resolve.ResolvableObjectVariableKind, - Renderer: resolve.NewGraphQLVariableResolveRenderer(&resolve.Object{ - Nullable: true, - Fields: []*resolve.Field{ - { - Name: []byte("__typename"), - OnTypeNames: [][]byte{[]byte("Account")}, - Value: &resolve.String{ - Path: []string{"__typename"}, - }, - }, - { - Name: []byte("id"), - OnTypeNames: [][]byte{[]byte("Account")}, - Value: &resolve.Scalar{ - Path: []string{"id"}, + CacheKeyTemplate: &resolve.EntityQueryCacheKeyTemplate{ + Keys: resolve.NewResolvableObjectVariable(&resolve.Object{ + Nullable: true, + Fields: []*resolve.Field{ + { + Name: []byte("__typename"), + OnTypeNames: [][]byte{[]byte("Account")}, + Value: &resolve.String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("id"), + OnTypeNames: [][]byte{[]byte("Account")}, + Value: &resolve.Scalar{ + Path: []string{"id"}, + }, + }, + { + Name: []byte("info"), + OnTypeNames: [][]byte{[]byte("Account")}, + Value: &resolve.Object{ + Path: []string{"info"}, + Nullable: true, + Fields: []*resolve.Field{ + { + Name: []byte("a"), + Value: &resolve.Scalar{ + Path: []string{"a"}, + }, }, - }, - { - Name: []byte("info"), - OnTypeNames: [][]byte{[]byte("Account")}, - Value: &resolve.Object{ - Path: []string{"info"}, - Nullable: true, - Fields: []*resolve.Field{ - { - Name: []byte("a"), - Value: &resolve.Scalar{ - Path: []string{"a"}, - }, - }, - { - Name: []byte("b"), 
- Value: &resolve.Scalar{ - Path: []string{"b"}, - }, - }, + { + Name: []byte("b"), + Value: &resolve.Scalar{ + Path: []string{"b"}, }, }, }, }, - }), + }, }, - }, + }), }, }, }, diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index 46bc10163..57b969909 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -1645,21 +1645,14 @@ func (v *Visitor) configureFetch(internal *objectFetchConfiguration, external re dataSourceType = strings.TrimPrefix(dataSourceType, "*") if !v.Config.DisableEntityCaching { - cacheKeyTemplate := &resolve.InputTemplate{ - SetTemplateOutputToNullOnVariableNull: false, - Segments: make([]resolve.TemplateSegment, len(external.Variables)), - } - - for i, variable := range external.Variables { - segment := variable.TemplateSegment() - cacheKeyTemplate.Segments[i] = segment - } - external.Caching = resolve.FetchCacheConfiguration{ - Enabled: true, - CacheName: "default", - TTL: time.Second * time.Duration(30), - CacheKeyTemplate: cacheKeyTemplate, + if external.RequiresEntityFetch || external.RequiresEntityBatchFetch { + external.Caching = resolve.FetchCacheConfiguration{ + Enabled: true, + CacheName: "default", + TTL: time.Second * time.Duration(30), + CacheKeyTemplate: external.Caching.CacheKeyTemplate, + } } } else { external.Caching = resolve.FetchCacheConfiguration{ @@ -1692,7 +1685,6 @@ func (v *Visitor) configureFetch(internal *objectFetchConfiguration, external re singleFetch.Info.ProvidesData = providesData } } - singleFetch.Info.CoordinateDependencies = v.resolveFetchDependencies(internal.fetchID) if v.Config.DisableIncludeFieldDependencies { return singleFetch } From 9d802ac9d86d3bd887b8d2c6453d92e6451350c8 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 30 Oct 2025 10:13:38 +0100 Subject: [PATCH 52/61] chore: implement multi cache keys --- .../graphql_datasource/graphql_datasource.go | 91 ++- .../graphql_datasource_federation_test.go | 16 + .../graphql_datasource_test.go | 37 +- v2/pkg/engine/plan/visitor.go | 14 +- v2/pkg/engine/resolve/caching.go | 223 ++++-- v2/pkg/engine/resolve/caching_test.go | 727 ++++++++++++++++-- v2/pkg/engine/resolve/loader.go | 20 +- 7 files changed, 973 insertions(+), 155 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go index 9ae12a0a1..7185e10d0 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go @@ -86,6 +86,7 @@ type Planner[T Configuration] struct { // caching cacheKeyTemplate resolve.CacheKeyTemplate + rootFields []resolve.QueryField // tracks root fields and their arguments for cache key generation // federation @@ -379,6 +380,17 @@ func (p *Planner[T]) ConfigureFetch() resolve.FetchConfiguration { } } + // Set cache key template for non-entity calls (root queries) + if !requiresEntityFetch && !requiresEntityBatchFetch { + if len(p.rootFields) > 0 { + rootFieldsCopy := make([]resolve.QueryField, len(p.rootFields)) + copy(rootFieldsCopy, p.rootFields) + p.cacheKeyTemplate = &resolve.RootQueryCacheKeyTemplate{ + RootFields: rootFieldsCopy, + } + } + } + return resolve.FetchConfiguration{ Input: string(input), DataSource: dataSource, @@ -722,6 +734,15 @@ func (p *Planner[T]) EnterField(ref int) { } } + // Track all root fields for cache key generation + if p.isRootField() { + coordinate := resolve.GraphCoordinate{ + TypeName: 
p.visitor.Walker.EnclosingTypeDefinition.NameString(p.visitor.Definition), + FieldName: fieldName, + } + p.trackCacheKeyCoordinate(coordinate) + } + // store root field name and ref if p.rootFieldName == "" { p.rootFieldName = fieldName @@ -737,6 +758,16 @@ func (p *Planner[T]) EnterField(ref int) { p.addFieldArguments(p.addField(ref), ref, fieldConfiguration) } +// isRootField returns false if an ancestor ast.Node is of kind field +func (p *Planner[T]) isRootField() bool { + for i := 0; i < len(p.visitor.Walker.Ancestors); i++ { + if p.visitor.Walker.Ancestors[i].Kind == ast.NodeKindField { + return false + } + } + return true +} + func (p *Planner[T]) addFieldArguments(upstreamFieldRef int, fieldRef int, fieldConfiguration *plan.FieldConfiguration) { if fieldConfiguration != nil { for i := range fieldConfiguration.Arguments { @@ -746,6 +777,44 @@ func (p *Planner[T]) addFieldArguments(upstreamFieldRef int, fieldRef int, field } } +// trackCacheKeyCoordinate ensures a root field is tracked for cache key generation, +// initializing an empty args slice if it doesn't exist yet +func (p *Planner[T]) trackCacheKeyCoordinate(coordinate resolve.GraphCoordinate) { + + // Check if the field is already tracked + for i := range p.rootFields { + if p.rootFields[i].Coordinate.TypeName == coordinate.TypeName && + p.rootFields[i].Coordinate.FieldName == coordinate.FieldName { + // Field already tracked + return + } + } + // Add the field to the slice + p.rootFields = append(p.rootFields, resolve.QueryField{ + Coordinate: coordinate, + }) +} + +// trackFieldWithArgument adds an argument (name + variable) to the field's tracking for cache key generation +func (p *Planner[T]) trackFieldWithArgument(coordinate resolve.GraphCoordinate, argName string, variable resolve.Variable) { + if coordinate.FieldName == "" { + return + } + // Ensure the field is tracked first + p.trackCacheKeyCoordinate(coordinate) + // Find the field and add the argument + for i := range p.rootFields { + if p.rootFields[i].Coordinate.TypeName == coordinate.TypeName && + p.rootFields[i].Coordinate.FieldName == coordinate.FieldName { + p.rootFields[i].Args = append(p.rootFields[i].Args, resolve.FieldArgument{ + Name: argName, + Variable: variable, + }) + return + } + } +} + func (p *Planner[T]) addCustomField(ref int) (upstreamFieldRef int) { fieldName, alias := p.handleFieldAlias(ref) fieldNode := p.upstreamOperation.AddField(ast.Field{ @@ -827,6 +896,12 @@ func (p *Planner[T]) EnterDocument(_, _ *ast.Document) { p.addDirectivesToVariableDefinitions = map[int][]int{} p.addedInlineFragments = map[onTypeInlineFragment]struct{}{} + + // reset root fields tracking for cache key generation + for i := 0; i < len(p.rootFields); i++ { + p.rootFields[i].Args = nil + } + p.rootFields = p.rootFields[:0] } func (p *Planner[T]) LeaveDocument(_, _ *ast.Document) { @@ -1099,7 +1174,7 @@ func (p *Planner[T]) configureArgument(upstreamFieldRef, downstreamFieldRef int, switch argumentConfiguration.SourceType { case plan.FieldArgumentSource: - p.configureFieldArgumentSource(upstreamFieldRef, downstreamFieldRef, argumentConfiguration) + p.configureFieldArgumentSource(upstreamFieldRef, downstreamFieldRef, fieldConfig, argumentConfiguration) case plan.ObjectFieldSource: p.configureObjectFieldSource(upstreamFieldRef, downstreamFieldRef, fieldConfig, argumentConfiguration) } @@ -1108,7 +1183,7 @@ func (p *Planner[T]) configureArgument(upstreamFieldRef, downstreamFieldRef int, } // configureFieldArgumentSource - creates variables for a plain argument types, in 
case object or list types goes deep and calls applyInlineFieldArgument -func (p *Planner[T]) configureFieldArgumentSource(upstreamFieldRef, downstreamFieldRef int, argumentConfiguration plan.ArgumentConfiguration) { +func (p *Planner[T]) configureFieldArgumentSource(upstreamFieldRef, downstreamFieldRef int, fieldConfig plan.FieldConfiguration, argumentConfiguration plan.ArgumentConfiguration) { fieldArgument, ok := p.visitor.Operation.FieldArgument(downstreamFieldRef, []byte(argumentConfiguration.Name)) if !ok { return @@ -1130,6 +1205,12 @@ func (p *Planner[T]) configureFieldArgumentSource(upstreamFieldRef, downstreamFi variableValueRef, argRef := p.upstreamOperation.AddVariableValueArgument([]byte(argumentConfiguration.Name), variableName) // add the argument to the field, but don't redefine it p.upstreamOperation.AddArgumentToField(upstreamFieldRef, argRef) + coordinate := resolve.GraphCoordinate{ + TypeName: fieldConfig.TypeName, + FieldName: fieldConfig.FieldName, + } + p.trackFieldWithArgument(coordinate, argumentConfiguration.Name, contextVariable) + if exists { // if the variable exists we don't have to put it onto the variables declaration again, skip return } @@ -1281,6 +1362,12 @@ func (p *Planner[T]) configureObjectFieldSource(upstreamFieldRef, downstreamFiel Renderer: resolve.NewJSONVariableRenderer(), } + coordinate := resolve.GraphCoordinate{ + TypeName: fieldConfiguration.TypeName, + FieldName: fieldConfiguration.FieldName, + } + p.trackFieldWithArgument(coordinate, argumentConfiguration.Name, variable) + objectVariableName, exists := p.variables.AddVariable(variable) if !exists { p.upstreamVariables, _ = sjson.SetRawBytes(p.upstreamVariables, string(variableName), []byte(objectVariableName)) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go index 45fe7205d..73990c556 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go @@ -1558,6 +1558,22 @@ func TestGraphQLDataSourceFederation(t *testing.T) { Input: `{"method":"POST","url":"http://user.service","body":{"query":"{user {account {__typename id info {a b}}}}"}}`, DataSource: &Source{}, PostProcessing: DefaultPostProcessingConfiguration, + Caching: resolve.FetchCacheConfiguration{ + Enabled: true, + CacheName: "default", + TTL: 30 * time.Second, + CacheKeyTemplate: &resolve.RootQueryCacheKeyTemplate{ + RootFields: []resolve.QueryField{ + { + Coordinate: resolve.GraphCoordinate{ + TypeName: "Query", + FieldName: "user", + }, + Args: []resolve.FieldArgument{}, + }, + }, + }, + }, }, Info: &resolve.FetchInfo{ DataSourceID: "user.service", diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go index fe0f8725a..cccc11f3c 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go @@ -402,21 +402,40 @@ func TestGraphQLDataSource(t *testing.T) { CacheName: "default", TTL: 30 * time.Second, CacheKeyTemplate: &resolve.RootQueryCacheKeyTemplate{ - Fields: []resolve.CacheKeyQueryRootField{ + RootFields: []resolve.QueryField{ { - Name: "droid", - Args: []resolve.CacheKeyQueryRootFieldArgument{ + Coordinate: resolve.GraphCoordinate{ + TypeName: "Query", + FieldName: "droid", + }, + 
Args: []resolve.FieldArgument{
 									{
 										Name: "id",
-										Variables: resolve.NewVariables(
-											&resolve.ContextVariable{
-												Path:     []string{"id"},
-												Renderer: resolve.NewJSONVariableRenderer(),
-											},
-										),
+										Variable: &resolve.ContextVariable{
+											Path:     []string{"id"},
+											Renderer: resolve.NewJSONVariableRenderer(),
+										},
 									},
 								},
 							},
+							{
+								Coordinate: resolve.GraphCoordinate{
+									TypeName:  "Query",
+									FieldName: "hero",
+								},
+							},
+							{
+								Coordinate: resolve.GraphCoordinate{
+									TypeName:  "Query",
+									FieldName: "stringList",
+								},
+							},
+							{
+								Coordinate: resolve.GraphCoordinate{
+									TypeName:  "Query",
+									FieldName: "nestedStringList",
+								},
+							},
 						},
 					},
 				},
diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go
index 57b969909..71da3b87e 100644
--- a/v2/pkg/engine/plan/visitor.go
+++ b/v2/pkg/engine/plan/visitor.go
@@ -1645,14 +1645,12 @@ func (v *Visitor) configureFetch(internal *objectFetchConfiguration, external re
 	dataSourceType = strings.TrimPrefix(dataSourceType, "*")
 
 	if !v.Config.DisableEntityCaching {
-
-		if external.RequiresEntityFetch || external.RequiresEntityBatchFetch {
-			external.Caching = resolve.FetchCacheConfiguration{
-				Enabled:          true,
-				CacheName:        "default",
-				TTL:              time.Second * time.Duration(30),
-				CacheKeyTemplate: external.Caching.CacheKeyTemplate,
-			}
+		external.Caching = resolve.FetchCacheConfiguration{
+			Enabled:   true,
+			CacheName: "default",
+			TTL:       time.Second * time.Duration(30),
+			// templates come prepared from the DataSource
+			CacheKeyTemplate: external.Caching.CacheKeyTemplate,
 		}
 	} else {
 		external.Caching = resolve.FetchCacheConfiguration{
diff --git a/v2/pkg/engine/resolve/caching.go b/v2/pkg/engine/resolve/caching.go
index 10593982a..d2fe544b5 100644
--- a/v2/pkg/engine/resolve/caching.go
+++ b/v2/pkg/engine/resolve/caching.go
@@ -1,82 +1,203 @@
 package resolve
 
 import (
-	"bytes"
-
 	"github.com/wundergraph/astjson"
+	"github.com/wundergraph/go-arena"
+
+	"github.com/wundergraph/graphql-go-tools/v2/pkg/internal/unsafebytes"
 )
 
 type CacheKeyTemplate interface {
-	RenderCacheKey(ctx *Context, data *astjson.Value, out *bytes.Buffer) error
+	// RenderCacheKeys returns multiple cache keys (one per root field or entity)
+	// Generates keys for all items at once
+	RenderCacheKeys(a arena.Arena, ctx *Context, items []*astjson.Value) ([]string, error)
 }
 
 type RootQueryCacheKeyTemplate struct {
-	Fields []CacheKeyQueryRootField
+	RootFields []QueryField
 }
 
-type CacheKeyQueryRootField struct {
-	Name string
-	Args []CacheKeyQueryRootFieldArgument
+type QueryField struct {
+	Coordinate GraphCoordinate
+	Args       []FieldArgument
 }
 
-type CacheKeyQueryRootFieldArgument struct {
-	Name      string
-	Variables InputTemplate
+type FieldArgument struct {
+	Name     string
+	Variable Variable
 }
 
-func (r *RootQueryCacheKeyTemplate) RenderCacheKey(ctx *Context, data *astjson.Value, out *bytes.Buffer) error {
-	_, err := out.WriteString("Query")
-	if err != nil {
-		return err
+// RenderCacheKeys returns multiple cache keys, one per root field per item
+func (r *RootQueryCacheKeyTemplate) RenderCacheKeys(a arena.Arena, ctx *Context, items []*astjson.Value) ([]string, error) {
+	if len(r.RootFields) == 0 {
+		return nil, nil
 	}
-
-	// Process each field
-	for _, field := range r.Fields {
-		_, err = out.WriteString("::")
-		if err != nil {
-			return err
-		}
-
-		// Add field name
-		_, err = out.WriteString(field.Name)
-		if err != nil {
-			return err
+	// Estimate capacity: each item can generate keys for all root fields
+	keys := arena.AllocateSlice[string](a, 0, len(r.RootFields)*len(items))
+	jsonBytes := arena.AllocateSlice[byte](a, 0, 64)
+	var (
+		key string
+	)
+	for _, item := range items {
+		for _, field := range r.RootFields {
+			key, jsonBytes = r.renderField(a, ctx, item, jsonBytes, field)
+			keys = arena.SliceAppend(a, keys, key)
 		}
+	}
+	return keys, nil
+}
 
-		// Process each argument
+// renderField renders a single field cache key as JSON
+func (r *RootQueryCacheKeyTemplate) renderField(a arena.Arena, ctx *Context, item *astjson.Value, jsonBytes []byte, field QueryField) (string, []byte) {
+	// Build JSON object starting with __typename
+	keyObj := astjson.ObjectValue(a)
+	typeName := field.Coordinate.TypeName
+	keyObj.Set(a, "__typename", astjson.StringValue(a, typeName))
+	keyObj.Set(a, "field", astjson.StringValue(a, field.Coordinate.FieldName))
+
+	// Build args object if there are any arguments
+	if len(field.Args) > 0 {
+		argsObj := astjson.ObjectValue(a)
 		for _, arg := range field.Args {
-			// Add argument separator ":"
-			_, err = out.WriteString(":")
-			if err != nil {
-				return err
-			}
-
-			// Add argument name
-			_, err = out.WriteString(arg.Name)
-			if err != nil {
-				return err
-			}
-
-			// Add argument separator ":"
-			_, err = out.WriteString(":")
-			if err != nil {
-				return err
-			}
-
-			err = arg.Variables.Render(ctx, data, out)
-			if err != nil {
-				return err
+			var argValue *astjson.Value
+			segment := arg.Variable.TemplateSegment()
+			if segment.Renderer != nil {
+				switch segment.VariableKind {
+				case ContextVariableKind:
+					// Extract value from context variables
+					variableSourcePath := segment.VariableSourcePath
+					if len(variableSourcePath) == 1 && ctx.RemapVariables != nil {
+						if nameToUse, hasMapping := ctx.RemapVariables[variableSourcePath[0]]; hasMapping && nameToUse != variableSourcePath[0] {
+							variableSourcePath = []string{nameToUse}
+						}
+					}
+					argValue = ctx.Variables.Get(variableSourcePath...)
+					if argValue == nil {
+						argValue = astjson.NullValue
+					}
+				case ObjectVariableKind:
+					// Use data parameter for object variables
+					if item != nil {
+						value := item.Get(segment.VariableSourcePath...)
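+						// Both a missing source path (nil) and an explicit JSON null
+						// are folded into astjson.NullValue below, so "argument absent"
+						// and "argument set to null" render the same stable cache key.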
+						if value == nil || value.Type() == astjson.TypeNull {
+							argValue = astjson.NullValue
+						} else {
+							// Values are already JSON-compatible astjson.Value
+							argValue = value
+						}
+					} else {
+						argValue = astjson.NullValue
+					}
+				default:
+					// For other variable kinds, use data parameter
+					if item != nil {
+						argValue = item
+					} else {
+						argValue = astjson.NullValue
+					}
+				}
+			} else {
+				argValue = astjson.NullValue
 			}
+			argsObj.Set(a, arg.Name, argValue)
 		}
+		keyObj.Set(a, "args", argsObj)
 	}
 
-	return nil
+	// Marshal to JSON and write to output
+	jsonBytes = keyObj.MarshalTo(jsonBytes[:0])
+	slice := arena.AllocateSlice[byte](a, len(jsonBytes), len(jsonBytes))
+	copy(slice, jsonBytes)
+	return unsafebytes.BytesToString(slice), jsonBytes
 }
 
 type EntityQueryCacheKeyTemplate struct {
 	Keys *ResolvableObjectVariable
 }
 
-func (e *EntityQueryCacheKeyTemplate) RenderCacheKey(ctx *Context, data *astjson.Value, out *bytes.Buffer) error {
-	return e.Keys.Renderer.RenderVariable(ctx.ctx, data, out)
+// RenderCacheKeys returns one cache key per item for entity queries with keys nested under "keys"
+func (e *EntityQueryCacheKeyTemplate) RenderCacheKeys(a arena.Arena, ctx *Context, items []*astjson.Value) ([]string, error) {
+	jsonBytes := arena.AllocateSlice[byte](a, 0, 64)
+	keys := arena.AllocateSlice[string](a, 0, len(items))
+
+	for _, item := range items {
+		if item == nil {
+			continue
+		}
+
+		// Build JSON object starting with __typename
+		keyObj := astjson.ObjectValue(a)
+
+		// Extract __typename from the data
+		typename := item.Get("__typename")
+		if typename == nil {
+			// Fallback if no __typename in data
+			keyObj.Set(a, "__typename", astjson.StringValue(a, "Entity"))
+		} else {
+			keyObj.Set(a, "__typename", typename)
+		}
+
+		// Put entity keys under "keys" nested object
+		keysObj := astjson.ObjectValue(a)
+
+		// Extract only the fields defined in the Keys template (not all fields from data)
+		if e.Keys != nil && e.Keys.Renderer != nil {
+			if obj, ok := e.Keys.Renderer.Node.(*Object); ok {
+				for _, field := range obj.Fields {
+					fieldName := unsafebytes.BytesToString(field.Name)
+					// Skip __typename as it's already handled separately
+					if fieldName == "__typename" {
+						continue
+					}
+					// Resolve field value based on its template definition
+					fieldValue := e.resolveFieldValue(a, field.Value, item)
+					if fieldValue != nil && fieldValue.Type() != astjson.TypeNull {
+						keysObj.Set(a, fieldName, fieldValue)
+					}
+				}
+			}
+		}
+
+		keyObj.Set(a, "keys", keysObj)
+
+		// Marshal to JSON and write to buffer
+		jsonBytes = keyObj.MarshalTo(jsonBytes[:0])
+		slice := arena.AllocateSlice[byte](a, len(jsonBytes), len(jsonBytes))
+		copy(slice, jsonBytes)
+		keys = arena.SliceAppend(a, keys, unsafebytes.BytesToString(slice))
+	}
+
+	return keys, nil
+}
+
+// resolveFieldValue resolves a field value from data based on its template definition
+func (e *EntityQueryCacheKeyTemplate) resolveFieldValue(a arena.Arena, valueNode Node, data *astjson.Value) *astjson.Value {
+	switch node := valueNode.(type) {
+	case *String:
+		// Extract string value from data using the path
+		return data.Get(node.Path...)
+	case *Object:
+		// For nested objects, recursively build the object using only template-defined fields
+		nestedObj := astjson.ObjectValue(a)
+		// Get the base object from data using the object's path
+		baseData := data.Get(node.Path...)
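+		// If the nested key object is absent or null in the data, the whole
+		// field is omitted from the cache key rather than serialized as null.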
+		if baseData == nil || baseData.Type() == astjson.TypeNull {
+			return nil
+		}
+		// Recursively resolve each field in the nested object template
+		for _, field := range node.Fields {
+			fieldName := unsafebytes.BytesToString(field.Name)
+			// Skip __typename in nested objects
+			if fieldName == "__typename" {
+				continue
+			}
+			fieldValue := e.resolveFieldValue(a, field.Value, baseData)
+			if fieldValue != nil && fieldValue.Type() != astjson.TypeNull {
+				nestedObj.Set(a, fieldName, fieldValue)
+			}
+		}
+		return nestedObj
+	default:
+		// For other types not handled above, return nil
+		return nil
+	}
 }
diff --git a/v2/pkg/engine/resolve/caching_test.go b/v2/pkg/engine/resolve/caching_test.go
index a515e598e..011b2a7c0 100644
--- a/v2/pkg/engine/resolve/caching_test.go
+++ b/v2/pkg/engine/resolve/caching_test.go
@@ -1,33 +1,52 @@
 package resolve
 
 import (
-	"bytes"
 	"context"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
 	"github.com/wundergraph/astjson"
+	"github.com/wundergraph/go-arena"
 )
 
 func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) {
+	t.Run("single field no arguments", func(t *testing.T) {
+		tmpl := &RootQueryCacheKeyTemplate{
+			RootFields: []QueryField{
+				{
+					Coordinate: GraphCoordinate{
+						TypeName:  "Query",
+						FieldName: "users",
+					},
+					Args: []FieldArgument{},
+				},
+			},
+		}
+
+		ctx := &Context{
+			Variables: astjson.MustParse(`{}`),
+			ctx:       context.Background(),
+		}
+		data := astjson.MustParse(`{}`)
+		keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data})
+		assert.NoError(t, err)
+		assert.Equal(t, []string{`{"__typename":"Query","field":"users"}`}, keys)
+	})
+
 	t.Run("single field single argument", func(t *testing.T) {
 		tmpl := &RootQueryCacheKeyTemplate{
-			Fields: []CacheKeyQueryRootField{
+			RootFields: []QueryField{
 				{
-					Name: "droid",
-					Args: []CacheKeyQueryRootFieldArgument{
+					Coordinate: GraphCoordinate{
+						TypeName:  "Query",
+						FieldName: "droid",
+					},
+					Args: []FieldArgument{
 						{
 							Name: "id",
-							Variables: InputTemplate{
-								SetTemplateOutputToNullOnVariableNull: true,
-								Segments: []TemplateSegment{
-									{
-										SegmentType:        VariableSegmentType,
-										VariableKind:       ContextVariableKind,
-										VariableSourcePath: []string{"id"},
-										Renderer:           NewCacheKeyVariableRenderer(),
-									},
-								},
+							Variable: &ContextVariable{
+								Path:     []string{"id"},
+								Renderer: NewCacheKeyVariableRenderer(),
 							},
 						},
 					},
@@ -40,44 +59,63 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) {
 			ctx: context.Background(),
 		}
 		data := astjson.MustParse(`{}`)
-		out := &bytes.Buffer{}
-		err := tmpl.RenderCacheKey(ctx, data, out)
+		keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data})
 		assert.NoError(t, err)
-		assert.Equal(t, `Query::droid:id:1`, out.String())
+		assert.Equal(t, []string{`{"__typename":"Query","field":"droid","args":{"id":1}}`}, keys)
+	})
+
+	t.Run("single field single string argument", func(t *testing.T) {
+		tmpl := &RootQueryCacheKeyTemplate{
+			RootFields: []QueryField{
+				{
+					Coordinate: GraphCoordinate{
+						TypeName:  "Query",
+						FieldName: "user",
+					},
+					Args: []FieldArgument{
+						{
+							Name: "name",
+							Variable: &ContextVariable{
+								Path:     []string{"name"},
+								Renderer: NewCacheKeyVariableRenderer(),
+							},
+						},
+					},
+				},
+			},
+		}
+
+		ctx := &Context{
+			Variables: astjson.MustParse(`{"name":"john"}`),
+			ctx:       context.Background(),
+		}
+		data := astjson.MustParse(`{}`)
+		keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data})
+		assert.NoError(t, err)
+		assert.Equal(t, []string{`{"__typename":"Query","field":"user","args":{"name":"john"}}`}, keys)
+	})
 
 	t.Run("single field multiple arguments",
func(t *testing.T) { tmpl := &RootQueryCacheKeyTemplate{ - Fields: []CacheKeyQueryRootField{ + RootFields: []QueryField{ { - Name: "search", - Args: []CacheKeyQueryRootFieldArgument{ + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "search", + }, + Args: []FieldArgument{ { Name: "term", - Variables: InputTemplate{ - SetTemplateOutputToNullOnVariableNull: true, - Segments: []TemplateSegment{ - { - SegmentType: VariableSegmentType, - VariableKind: ContextVariableKind, - VariableSourcePath: []string{"term"}, - Renderer: NewCacheKeyVariableRenderer(), - }, - }, + Variable: &ContextVariable{ + Path: []string{"term"}, + Renderer: NewCacheKeyVariableRenderer(), }, }, { Name: "max", - Variables: InputTemplate{ - SetTemplateOutputToNullOnVariableNull: true, - Segments: []TemplateSegment{ - { - SegmentType: VariableSegmentType, - VariableKind: ContextVariableKind, - VariableSourcePath: []string{"max"}, - Renderer: NewCacheKeyVariableRenderer(), - }, - }, + Variable: &ContextVariable{ + Path: []string{"max"}, + Renderer: NewCacheKeyVariableRenderer(), }, }, }, @@ -89,50 +127,79 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { Variables: astjson.MustParse(`{"term":"C3PO","max":10}`), ctx: context.Background(), } - out := &bytes.Buffer{} data := astjson.MustParse(`{}`) - err := tmpl.RenderCacheKey(ctx, data, out) + keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Equal(t, `Query::search:term:C3PO:max:10`, out.String()) + assert.Equal(t, []string{`{"__typename":"Query","field":"search","args":{"term":"C3PO","max":10}}`}, keys) + }) + + t.Run("single field multiple arguments with boolean", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "products", + }, + Args: []FieldArgument{ + { + Name: "includeDeleted", + Variable: &ContextVariable{ + Path: []string{"includeDeleted"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + { + Name: "limit", + Variable: &ContextVariable{ + Path: []string{"limit"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + ctx := &Context{ + Variables: astjson.MustParse(`{"includeDeleted":true,"limit":20}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{}`) + keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + assert.NoError(t, err) + assert.Equal(t, []string{`{"__typename":"Query","field":"products","args":{"includeDeleted":true,"limit":20}}`}, keys) }) t.Run("multiple fields single argument each", func(t *testing.T) { tmpl := &RootQueryCacheKeyTemplate{ - Fields: []CacheKeyQueryRootField{ + RootFields: []QueryField{ { - Name: "droid", - Args: []CacheKeyQueryRootFieldArgument{ + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "droid", + }, + Args: []FieldArgument{ { Name: "id", - Variables: InputTemplate{ - SetTemplateOutputToNullOnVariableNull: true, - Segments: []TemplateSegment{ - { - SegmentType: VariableSegmentType, - VariableKind: ContextVariableKind, - VariableSourcePath: []string{"id"}, - Renderer: NewCacheKeyVariableRenderer(), - }, - }, + Variable: &ContextVariable{ + Path: []string{"id"}, + Renderer: NewCacheKeyVariableRenderer(), }, }, }, }, { - Name: "user", - Args: []CacheKeyQueryRootFieldArgument{ + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "user", + }, + Args: []FieldArgument{ { Name: "name", - Variables: InputTemplate{ - SetTemplateOutputToNullOnVariableNull: true, - 
Segments: []TemplateSegment{ - { - SegmentType: VariableSegmentType, - VariableKind: ContextVariableKind, - VariableSourcePath: []string{"name"}, - Renderer: NewCacheKeyVariableRenderer(), - }, - }, + Variable: &ContextVariable{ + Path: []string{"name"}, + Renderer: NewCacheKeyVariableRenderer(), }, }, }, @@ -144,10 +211,528 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { Variables: astjson.MustParse(`{"id":1,"name":"john"}`), ctx: context.Background(), } - out := &bytes.Buffer{} data := astjson.MustParse(`{}`) - err := tmpl.RenderCacheKey(ctx, data, out) + + // Test RenderCacheKeys returns multiple keys + keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + assert.NoError(t, err) + assert.Equal(t, []string{`{"__typename":"Query","field":"droid","args":{"id":1}}`, `{"__typename":"Query","field":"user","args":{"name":"john"}}`}, keys) + }) + + t.Run("multiple fields with mixed arguments", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "product", + }, + Args: []FieldArgument{ + { + Name: "id", + Variable: &ContextVariable{ + Path: []string{"id"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + { + Name: "includeReviews", + Variable: &ContextVariable{ + Path: []string{"includeReviews"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "hero", + }, + Args: []FieldArgument{}, + }, + }, + } + + ctx := &Context{ + Variables: astjson.MustParse(`{"id":"123","includeReviews":true}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{}`) + + // Test RenderCacheKeys returns multiple keys + keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + assert.NoError(t, err) + assert.Equal(t, []string{`{"__typename":"Query","field":"product","args":{"id":"123","includeReviews":true}}`, `{"__typename":"Query","field":"hero"}`}, keys) + }) + + t.Run("field with object variable argument", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "search", + }, + Args: []FieldArgument{ + { + Name: "filter", + Variable: &ObjectVariable{ + Path: []string{"filter"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + ctx := &Context{ + Variables: astjson.MustParse(`{}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{"filter":{"category":"electronics","price":100}}`) + keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + assert.NoError(t, err) + assert.Equal(t, []string{`{"__typename":"Query","field":"search","args":{"filter":{"category":"electronics","price":100}}}`}, keys) + }) + + t.Run("field with null argument", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "user", + }, + Args: []FieldArgument{ + { + Name: "id", + Variable: &ContextVariable{ + Path: []string{"id"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + ctx := &Context{ + Variables: astjson.MustParse(`{"id":null}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{}`) + keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + assert.NoError(t, err) + assert.Equal(t, []string{`{"__typename":"Query","field":"user","args":{"id":null}}`}, keys) + }) + + t.Run("field with missing 
argument", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "user", + }, + Args: []FieldArgument{ + { + Name: "id", + Variable: &ContextVariable{ + Path: []string{"id"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + ctx := &Context{ + Variables: astjson.MustParse(`{}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{}`) + keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + assert.NoError(t, err) + assert.Equal(t, []string{`{"__typename":"Query","field":"user","args":{"id":null}}`}, keys) + }) + + t.Run("field with array argument", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "products", + }, + Args: []FieldArgument{ + { + Name: "ids", + Variable: &ContextVariable{ + Path: []string{"ids"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + ctx := &Context{ + Variables: astjson.MustParse(`{"ids":[1,2,3]}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{}`) + keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + assert.NoError(t, err) + assert.Equal(t, []string{`{"__typename":"Query","field":"products","args":{"ids":[1,2,3]}}`}, keys) + }) + + t.Run("non-Query type", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Subscription", + FieldName: "messageAdded", + }, + Args: []FieldArgument{ + { + Name: "roomId", + Variable: &ContextVariable{ + Path: []string{"roomId"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + ctx := &Context{ + Variables: astjson.MustParse(`{"roomId":"123"}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{}`) + keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + assert.NoError(t, err) + assert.Equal(t, []string{`{"__typename":"Subscription","field":"messageAdded","args":{"roomId":"123"}}`}, keys) + }) + + t.Run("single field with arena", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "user", + }, + Args: []FieldArgument{ + { + Name: "name", + Variable: &ContextVariable{ + Path: []string{"name"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + ar := arena.NewMonotonicArena(arena.WithMinBufferSize(1024)) + ctx := &Context{ + Variables: astjson.MustParse(`{"name":"john"}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{}`) + keys, err := tmpl.RenderCacheKeys(ar, ctx, []*astjson.Value{data}) + assert.NoError(t, err) + assert.Equal(t, []string{`{"__typename":"Query","field":"user","args":{"name":"john"}}`}, keys) + }) +} + +func TestCachingRenderEntityQueryCacheKeyTemplate(t *testing.T) { + t.Run("single entity with typename and id", func(t *testing.T) { + tmpl := &EntityQueryCacheKeyTemplate{ + Keys: NewResolvableObjectVariable(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("id"), + Value: &String{ + Path: []string{"id"}, + }, + }, + }, + }), + } + + ctx := &Context{ + Variables: astjson.MustParse(`{}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{"__typename":"Product","id":"123"}`) + keys, err := tmpl.RenderCacheKeys(nil, ctx, 
[]*astjson.Value{data}) + assert.NoError(t, err) + assert.Equal(t, []string{`{"__typename":"Product","keys":{"id":"123"}}`}, keys) + }) + + t.Run("single entity with multiple keys", func(t *testing.T) { + tmpl := &EntityQueryCacheKeyTemplate{ + Keys: NewResolvableObjectVariable(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("sku"), + Value: &String{ + Path: []string{"sku"}, + }, + }, + { + Name: []byte("upc"), + Value: &String{ + Path: []string{"upc"}, + }, + }, + }, + }), + } + + ctx := &Context{ + Variables: astjson.MustParse(`{}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{"__typename":"Product","sku":"ABC123","upc":"DEF456","name":"Trilby"}`) + keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + assert.NoError(t, err) + assert.Equal(t, []string{`{"__typename":"Product","keys":{"sku":"ABC123","upc":"DEF456"}}`}, keys) + }) + + t.Run("entity with nested object key", func(t *testing.T) { + tmpl := &EntityQueryCacheKeyTemplate{ + Keys: NewResolvableObjectVariable(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("key"), + Value: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &String{ + Path: []string{"key", "id"}, + }, + }, + { + Name: []byte("version"), + Value: &String{ + Path: []string{"key", "version"}, + }, + }, + }, + }, + }, + }, + }), + } + + ctx := &Context{ + Variables: astjson.MustParse(`{}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{"__typename":"VersionedEntity","key":{"id":"123","version":"1"}}`) + keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Equal(t, `Query::droid:id:1::user:name:john`, out.String()) + assert.Equal(t, []string{`{"__typename":"VersionedEntity","keys":{"key":{"id":"123","version":"1"}}}`}, keys) + }) +} + +func BenchmarkRenderCacheKeys(b *testing.B) { + a := arena.NewMonotonicArena(arena.WithMinBufferSize(1024)) + + ctxRootQuery := &Context{ + Variables: astjson.MustParse(`{"id":1,"name":"john","term":"C3PO","max":10}`), + ctx: context.Background(), + } + + ctxEntityQuery := &Context{ + Variables: astjson.MustParse(`{}`), + ctx: context.Background(), + } + + b.Run("RootQuery/SingleField", func(b *testing.B) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "user", + }, + Args: []FieldArgument{ + { + Name: "id", + Variable: &ContextVariable{ + Path: []string{"id"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + data := astjson.MustParse(`{}`) + items := []*astjson.Value{data} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + a.Reset() + _, err := tmpl.RenderCacheKeys(a, ctxRootQuery, items) + if err != nil { + b.Fatal(err) + } + } + }) + + b.Run("RootQuery/MultipleFields", func(b *testing.B) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "droid", + }, + Args: []FieldArgument{ + { + Name: "id", + Variable: &ContextVariable{ + Path: []string{"id"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "user", + }, + Args: []FieldArgument{ + { + Name: "name", + Variable: &ContextVariable{ + Path: []string{"name"}, + Renderer: 
NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "search", + }, + Args: []FieldArgument{ + { + Name: "term", + Variable: &ContextVariable{ + Path: []string{"term"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + { + Name: "max", + Variable: &ContextVariable{ + Path: []string{"max"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + data := astjson.MustParse(`{}`) + items := []*astjson.Value{data} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + a.Reset() + _, err := tmpl.RenderCacheKeys(a, ctxRootQuery, items) + if err != nil { + b.Fatal(err) + } + } + }) + + b.Run("EntityQuery", func(b *testing.B) { + tmpl := &EntityQueryCacheKeyTemplate{ + Keys: NewResolvableObjectVariable(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("id"), + Value: &String{ + Path: []string{"id"}, + }, + }, + { + Name: []byte("sku"), + Value: &String{ + Path: []string{"sku"}, + }, + }, + { + Name: []byte("upc"), + Value: &String{ + Path: []string{"upc"}, + }, + }, + }, + }), + } + + data1 := astjson.MustParse(`{"__typename":"Product","id":"123","sku":"ABC123","upc":"DEF456","name":"Trilby"}`) + data2 := astjson.MustParse(`{"__typename":"Product","id":"456","sku":"XYZ789","upc":"GHI012","name":"Fedora"}`) + data3 := astjson.MustParse(`{"__typename":"Product","id":"789","sku":"JKL345","upc":"MNO678","name":"Boater"}`) + items := []*astjson.Value{data1, data2, data3} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + a.Reset() + _, err := tmpl.RenderCacheKeys(a, ctxEntityQuery, items) + if err != nil { + b.Fatal(err) + } + } }) } diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index ee99babd2..f2e52dcc0 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -463,24 +463,16 @@ func (l *Loader) tryCacheLoadFetch(ctx context.Context, info *FetchInfo, cfg Fet if res.cache == nil { return false, nil } - res.cacheKeys = make([]string, 0, len(inputItems)) - buf := &bytes.Buffer{} - for _, item := range inputItems { - err = cfg.CacheKeyTemplate.RenderCacheKey(l.ctx, item, buf) - if err != nil { - return false, err - } - if buf.Len() == 0 { - // If the cache key is empty, we skip the cache - continue - } - res.cacheKeys = append(res.cacheKeys, buf.String()) - buf.Reset() + // Generate cache keys for all items at once + keys, err := cfg.CacheKeyTemplate.RenderCacheKeys(nil, l.ctx, inputItems) + if err != nil { + return false, err } - if len(res.cacheKeys) == 0 { + if len(keys) == 0 { // If no cache keys were generated, we skip the cache return false, nil } + res.cacheKeys = keys cachedItems, err := res.cache.Get(ctx, res.cacheKeys) if err != nil { return false, err From 8ec26700ede4ae155aa5770f1f12e454b91c9923 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 30 Oct 2025 13:19:03 +0100 Subject: [PATCH 53/61] chore: refactor cache keys --- v2/pkg/engine/resolve/caching.go | 59 ++++++-- v2/pkg/engine/resolve/caching_test.go | 132 +++++++++++++----- v2/pkg/engine/resolve/loader.go | 81 ++++------- .../engine/resolve/loader_skip_fetch_test.go | 25 +++- 4 files changed, 188 insertions(+), 109 deletions(-) diff --git a/v2/pkg/engine/resolve/caching.go b/v2/pkg/engine/resolve/caching.go index d2fe544b5..af74fc217 100644 --- a/v2/pkg/engine/resolve/caching.go +++ b/v2/pkg/engine/resolve/caching.go @@ -9,7 +9,18 @@ import ( type 
CacheKeyTemplate interface {
 	// RenderCacheKeys returns multiple cache keys (one per root field or entity)
 	// Generates keys for all items at once
-	RenderCacheKeys(a arena.Arena, ctx *Context, items []*astjson.Value) ([]string, error)
+	RenderCacheKeys(a arena.Arena, ctx *Context, items []*astjson.Value) ([]*CacheKey, error)
+}
+
+type CacheKey struct {
+	Item      *astjson.Value
+	FromCache *astjson.Value
+	Keys      []KeyEntry
+}
+
+type KeyEntry struct {
+	Name string
+	Path string
 }
 
 type RootQueryCacheKeyTemplate struct {
@@ -26,24 +37,33 @@ type FieldArgument struct {
 	Variable Variable
 }
 
-// RenderCacheKeys returns multiple cache keys, one per root field per item
-func (r *RootQueryCacheKeyTemplate) RenderCacheKeys(a arena.Arena, ctx *Context, items []*astjson.Value) ([]string, error) {
+// RenderCacheKeys returns multiple cache keys, one per item
+// Each cache key contains one or more KeyEntry objects (one per root field)
+func (r *RootQueryCacheKeyTemplate) RenderCacheKeys(a arena.Arena, ctx *Context, items []*astjson.Value) ([]*CacheKey, error) {
 	if len(r.RootFields) == 0 {
 		return nil, nil
 	}
-	// Estimate capacity: each item can generate keys for all root fields
-	keys := arena.AllocateSlice[string](a, 0, len(r.RootFields)*len(items))
+	// Estimate capacity: one CacheKey per item
+	cacheKeys := arena.AllocateSlice[*CacheKey](a, 0, len(items))
 	jsonBytes := arena.AllocateSlice[byte](a, 0, 64)
-	var (
-		key string
-	)
+
 	for _, item := range items {
+		// Create KeyEntry for each root field
+		keyEntries := arena.AllocateSlice[KeyEntry](a, 0, len(r.RootFields))
 		for _, field := range r.RootFields {
+			var key string
 			key, jsonBytes = r.renderField(a, ctx, item, jsonBytes, field)
-			keys = arena.SliceAppend(a, keys, key)
+			keyEntries = arena.SliceAppend(a, keyEntries, KeyEntry{
+				Name: key,
+				Path: field.Coordinate.FieldName,
+			})
 		}
+		cacheKeys = arena.SliceAppend(a, cacheKeys, &CacheKey{
+			Item: item,
+			Keys: keyEntries,
+		})
 	}
-	return keys, nil
+	return cacheKeys, nil
 }
 
 // renderField renders a single field cache key as JSON
@@ -115,9 +135,9 @@ type EntityQueryCacheKeyTemplate struct {
 }
 
 // RenderCacheKeys returns one cache key per item for entity queries with keys nested under "keys"
-func (e *EntityQueryCacheKeyTemplate) RenderCacheKeys(a arena.Arena, ctx *Context, items []*astjson.Value) ([]string, error) {
+func (e *EntityQueryCacheKeyTemplate) RenderCacheKeys(a arena.Arena, ctx *Context, items []*astjson.Value) ([]*CacheKey, error) {
 	jsonBytes := arena.AllocateSlice[byte](a, 0, 64)
-	keys := arena.AllocateSlice[string](a, 0, len(items))
+	cacheKeys := arena.AllocateSlice[*CacheKey](a, 0, len(items))
 
 	for _, item := range items {
 		if item == nil {
@@ -163,10 +183,21 @@
 		jsonBytes = keyObj.MarshalTo(jsonBytes[:0])
 		slice := arena.AllocateSlice[byte](a, len(jsonBytes), len(jsonBytes))
 		copy(slice, jsonBytes)
-		keys = arena.SliceAppend(a, keys, unsafebytes.BytesToString(slice))
+
+		// Create KeyEntry with empty path for entity queries
+		keyEntries := arena.AllocateSlice[KeyEntry](a, 0, 1)
+		keyEntries = arena.SliceAppend(a, keyEntries, KeyEntry{
+			Name: unsafebytes.BytesToString(slice),
+			Path: "",
+		})
+
+		cacheKeys = arena.SliceAppend(a, cacheKeys, &CacheKey{
+			Item: item,
+			Keys: keyEntries,
+		})
 	}
 
-	return keys, nil
+	return cacheKeys, nil
 }
 
 // resolveFieldValue resolves a field value from data based on its template definition
diff --git a/v2/pkg/engine/resolve/caching_test.go b/v2/pkg/engine/resolve/caching_test.go
index 011b2a7c0..5e1d965e6 100644 --- a/v2/pkg/engine/resolve/caching_test.go +++ b/v2/pkg/engine/resolve/caching_test.go @@ -28,9 +28,13 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{}`) - keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Equal(t, []string{`{"__typename":"Query","field":"users"}`}, keys) + assert.Len(t, cacheKeys, 1) + assert.Equal(t, data, cacheKeys[0].Item) + assert.Len(t, cacheKeys[0].Keys, 1) + assert.Equal(t, `{"__typename":"Query","field":"users"}`, cacheKeys[0].Keys[0].Name) + assert.Equal(t, "users", cacheKeys[0].Keys[0].Path) }) t.Run("single field single argument", func(t *testing.T) { @@ -59,9 +63,13 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{}`) - keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Equal(t, []string{`{"__typename":"Query","field":"droid","args":{"id":1}}`}, keys) + assert.Len(t, cacheKeys, 1) + assert.Equal(t, data, cacheKeys[0].Item) + assert.Len(t, cacheKeys[0].Keys, 1) + assert.Equal(t, `{"__typename":"Query","field":"droid","args":{"id":1}}`, cacheKeys[0].Keys[0].Name) + assert.Equal(t, "droid", cacheKeys[0].Keys[0].Path) }) t.Run("single field single string argument", func(t *testing.T) { @@ -90,9 +98,13 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{}`) - keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Equal(t, []string{`{"__typename":"Query","field":"user","args":{"name":"john"}}`}, keys) + assert.Len(t, cacheKeys, 1) + assert.Equal(t, data, cacheKeys[0].Item) + assert.Len(t, cacheKeys[0].Keys, 1) + assert.Equal(t, `{"__typename":"Query","field":"user","args":{"name":"john"}}`, cacheKeys[0].Keys[0].Name) + assert.Equal(t, "user", cacheKeys[0].Keys[0].Path) }) t.Run("single field multiple arguments", func(t *testing.T) { @@ -128,9 +140,13 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{}`) - keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Equal(t, []string{`{"__typename":"Query","field":"search","args":{"term":"C3PO","max":10}}`}, keys) + assert.Len(t, cacheKeys, 1) + assert.Equal(t, data, cacheKeys[0].Item) + assert.Len(t, cacheKeys[0].Keys, 1) + assert.Equal(t, `{"__typename":"Query","field":"search","args":{"term":"C3PO","max":10}}`, cacheKeys[0].Keys[0].Name) + assert.Equal(t, "search", cacheKeys[0].Keys[0].Path) }) t.Run("single field multiple arguments with boolean", func(t *testing.T) { @@ -166,9 +182,13 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{}`) - keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Equal(t, []string{`{"__typename":"Query","field":"products","args":{"includeDeleted":true,"limit":20}}`}, keys) + assert.Len(t, 
cacheKeys, 1) + assert.Equal(t, data, cacheKeys[0].Item) + assert.Len(t, cacheKeys[0].Keys, 1) + assert.Equal(t, `{"__typename":"Query","field":"products","args":{"includeDeleted":true,"limit":20}}`, cacheKeys[0].Keys[0].Name) + assert.Equal(t, "products", cacheKeys[0].Keys[0].Path) }) t.Run("multiple fields single argument each", func(t *testing.T) { @@ -214,9 +234,15 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { data := astjson.MustParse(`{}`) // Test RenderCacheKeys returns multiple keys - keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Equal(t, []string{`{"__typename":"Query","field":"droid","args":{"id":1}}`, `{"__typename":"Query","field":"user","args":{"name":"john"}}`}, keys) + assert.Len(t, cacheKeys, 1) + assert.Equal(t, data, cacheKeys[0].Item) + assert.Len(t, cacheKeys[0].Keys, 2) + assert.Equal(t, `{"__typename":"Query","field":"droid","args":{"id":1}}`, cacheKeys[0].Keys[0].Name) + assert.Equal(t, "droid", cacheKeys[0].Keys[0].Path) + assert.Equal(t, `{"__typename":"Query","field":"user","args":{"name":"john"}}`, cacheKeys[0].Keys[1].Name) + assert.Equal(t, "user", cacheKeys[0].Keys[1].Path) }) t.Run("multiple fields with mixed arguments", func(t *testing.T) { @@ -261,9 +287,15 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { data := astjson.MustParse(`{}`) // Test RenderCacheKeys returns multiple keys - keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Equal(t, []string{`{"__typename":"Query","field":"product","args":{"id":"123","includeReviews":true}}`, `{"__typename":"Query","field":"hero"}`}, keys) + assert.Len(t, cacheKeys, 1) + assert.Equal(t, data, cacheKeys[0].Item) + assert.Len(t, cacheKeys[0].Keys, 2) + assert.Equal(t, `{"__typename":"Query","field":"product","args":{"id":"123","includeReviews":true}}`, cacheKeys[0].Keys[0].Name) + assert.Equal(t, "product", cacheKeys[0].Keys[0].Path) + assert.Equal(t, `{"__typename":"Query","field":"hero"}`, cacheKeys[0].Keys[1].Name) + assert.Equal(t, "hero", cacheKeys[0].Keys[1].Path) }) t.Run("field with object variable argument", func(t *testing.T) { @@ -292,9 +324,13 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{"filter":{"category":"electronics","price":100}}`) - keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Equal(t, []string{`{"__typename":"Query","field":"search","args":{"filter":{"category":"electronics","price":100}}}`}, keys) + assert.Len(t, cacheKeys, 1) + assert.Equal(t, data, cacheKeys[0].Item) + assert.Len(t, cacheKeys[0].Keys, 1) + assert.Equal(t, `{"__typename":"Query","field":"search","args":{"filter":{"category":"electronics","price":100}}}`, cacheKeys[0].Keys[0].Name) + assert.Equal(t, "search", cacheKeys[0].Keys[0].Path) }) t.Run("field with null argument", func(t *testing.T) { @@ -323,9 +359,13 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{}`) - keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Equal(t, 
[]string{`{"__typename":"Query","field":"user","args":{"id":null}}`}, keys) + assert.Len(t, cacheKeys, 1) + assert.Equal(t, data, cacheKeys[0].Item) + assert.Len(t, cacheKeys[0].Keys, 1) + assert.Equal(t, `{"__typename":"Query","field":"user","args":{"id":null}}`, cacheKeys[0].Keys[0].Name) + assert.Equal(t, "user", cacheKeys[0].Keys[0].Path) }) t.Run("field with missing argument", func(t *testing.T) { @@ -354,9 +394,13 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{}`) - keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Equal(t, []string{`{"__typename":"Query","field":"user","args":{"id":null}}`}, keys) + assert.Len(t, cacheKeys, 1) + assert.Equal(t, data, cacheKeys[0].Item) + assert.Len(t, cacheKeys[0].Keys, 1) + assert.Equal(t, `{"__typename":"Query","field":"user","args":{"id":null}}`, cacheKeys[0].Keys[0].Name) + assert.Equal(t, "user", cacheKeys[0].Keys[0].Path) }) t.Run("field with array argument", func(t *testing.T) { @@ -385,9 +429,13 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{}`) - keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Equal(t, []string{`{"__typename":"Query","field":"products","args":{"ids":[1,2,3]}}`}, keys) + assert.Len(t, cacheKeys, 1) + assert.Equal(t, data, cacheKeys[0].Item) + assert.Len(t, cacheKeys[0].Keys, 1) + assert.Equal(t, `{"__typename":"Query","field":"products","args":{"ids":[1,2,3]}}`, cacheKeys[0].Keys[0].Name) + assert.Equal(t, "products", cacheKeys[0].Keys[0].Path) }) t.Run("non-Query type", func(t *testing.T) { @@ -416,9 +464,13 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{}`) - keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Equal(t, []string{`{"__typename":"Subscription","field":"messageAdded","args":{"roomId":"123"}}`}, keys) + assert.Len(t, cacheKeys, 1) + assert.Equal(t, data, cacheKeys[0].Item) + assert.Len(t, cacheKeys[0].Keys, 1) + assert.Equal(t, `{"__typename":"Subscription","field":"messageAdded","args":{"roomId":"123"}}`, cacheKeys[0].Keys[0].Name) + assert.Equal(t, "messageAdded", cacheKeys[0].Keys[0].Path) }) t.Run("single field with arena", func(t *testing.T) { @@ -448,9 +500,13 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{}`) - keys, err := tmpl.RenderCacheKeys(ar, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(ar, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Equal(t, []string{`{"__typename":"Query","field":"user","args":{"name":"john"}}`}, keys) + assert.Len(t, cacheKeys, 1) + assert.Equal(t, data, cacheKeys[0].Item) + assert.Len(t, cacheKeys[0].Keys, 1) + assert.Equal(t, `{"__typename":"Query","field":"user","args":{"name":"john"}}`, cacheKeys[0].Keys[0].Name) + assert.Equal(t, "user", cacheKeys[0].Keys[0].Path) }) } @@ -480,9 +536,13 @@ func TestCachingRenderEntityQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{"__typename":"Product","id":"123"}`) - keys, err := 
tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Equal(t, []string{`{"__typename":"Product","keys":{"id":"123"}}`}, keys) + assert.Len(t, cacheKeys, 1) + assert.Equal(t, data, cacheKeys[0].Item) + assert.Len(t, cacheKeys[0].Keys, 1) + assert.Equal(t, `{"__typename":"Product","keys":{"id":"123"}}`, cacheKeys[0].Keys[0].Name) + assert.Equal(t, "", cacheKeys[0].Keys[0].Path) }) t.Run("single entity with multiple keys", func(t *testing.T) { @@ -516,9 +576,13 @@ func TestCachingRenderEntityQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{"__typename":"Product","sku":"ABC123","upc":"DEF456","name":"Trilby"}`) - keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Equal(t, []string{`{"__typename":"Product","keys":{"sku":"ABC123","upc":"DEF456"}}`}, keys) + assert.Len(t, cacheKeys, 1) + assert.Equal(t, data, cacheKeys[0].Item) + assert.Len(t, cacheKeys[0].Keys, 1) + assert.Equal(t, `{"__typename":"Product","keys":{"sku":"ABC123","upc":"DEF456"}}`, cacheKeys[0].Keys[0].Name) + assert.Equal(t, "", cacheKeys[0].Keys[0].Path) }) t.Run("entity with nested object key", func(t *testing.T) { @@ -559,9 +623,13 @@ func TestCachingRenderEntityQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{"__typename":"VersionedEntity","key":{"id":"123","version":"1"}}`) - keys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Equal(t, []string{`{"__typename":"VersionedEntity","keys":{"key":{"id":"123","version":"1"}}}`}, keys) + assert.Len(t, cacheKeys, 1) + assert.Equal(t, data, cacheKeys[0].Item) + assert.Len(t, cacheKeys[0].Keys, 1) + assert.Equal(t, `{"__typename":"VersionedEntity","keys":{"key":{"id":"123","version":"1"}}}`, cacheKeys[0].Keys[0].Name) + assert.Equal(t, "", cacheKeys[0].Keys[0].Path) }) } diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index f2e52dcc0..40e1c2e25 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -132,8 +132,7 @@ type result struct { cache LoaderCache cacheMustBeUpdated bool - cacheKeys []string - cacheItems []*astjson.Value + cacheKeys []*CacheKey cacheTTL time.Duration cacheSkippedFetch bool } @@ -444,9 +443,9 @@ func (l *Loader) itemsData(items []*astjson.Value) *astjson.Value { } type LoaderCache interface { - Get(ctx context.Context, keys []string) ([][]byte, error) - Set(ctx context.Context, keys []string, items [][]byte, ttl time.Duration) error - Delete(ctx context.Context, keys []string) error + Get(ctx context.Context, keys []*CacheKey) error + Set(ctx context.Context, keys []*CacheKey, ttl time.Duration) error + Delete(ctx context.Context, keys []*CacheKey) error } func (l *Loader) tryCacheLoadFetch(ctx context.Context, info *FetchInfo, cfg FetchCacheConfiguration, inputItems []*astjson.Value, res *result) (skipFetch bool, err error) { @@ -464,31 +463,19 @@ func (l *Loader) tryCacheLoadFetch(ctx context.Context, info *FetchInfo, cfg Fet return false, nil } // Generate cache keys for all items at once - keys, err := cfg.CacheKeyTemplate.RenderCacheKeys(nil, l.ctx, inputItems) + res.cacheKeys, err = cfg.CacheKeyTemplate.RenderCacheKeys(nil, l.ctx, inputItems) if err != nil 
{ return false, err } - if len(keys) == 0 { + if len(res.cacheKeys) == 0 { // If no cache keys were generated, we skip the cache return false, nil } - res.cacheKeys = keys - cachedItems, err := res.cache.Get(ctx, res.cacheKeys) + err = res.cache.Get(ctx, res.cacheKeys) if err != nil { return false, err } - res.cacheItems = make([]*astjson.Value, len(cachedItems)) - for i := range cachedItems { - if cachedItems[i] == nil { - res.cacheItems[i] = astjson.NullValue - continue - } - res.cacheItems[i], err = astjson.ParseBytes(cachedItems[i]) - if err != nil { - return false, errors.WithStack(err) - } - } - missing, canSkip := l.canSkipFetch(info, res.cacheItems) + missing, canSkip := l.canSkipFetch(info, res) if canSkip { res.cacheSkippedFetch = true return true, nil @@ -587,8 +574,8 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson return nil } if res.cacheSkippedFetch { - for i, item := range res.cacheItems { - _, _, err := astjson.MergeValues(l.jsonArena, items[i], item) + for i, key := range res.cacheKeys { + _, _, err := astjson.MergeValues(l.jsonArena, items[i], key.FromCache) if err != nil { return l.renderErrorsFailedToFetch(fetchItem, res, "invalid cache item") } @@ -602,7 +589,7 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson return l.renderErrorsFailedToFetch(fetchItem, res, emptyGraphQLResponse) } if res.cacheMustBeUpdated { - defer l.updateCache(res, items) + defer l.updateCache(res) } // before parsing bytes with an arena.Arena, it's important to first allocate the bytes ON the same arena.Arena // this ties their lifecycles together @@ -779,24 +766,13 @@ func (l *Loader) renderErrorsInvalidInput(fetchItem *FetchItem) []byte { return out.Bytes() } -func (l *Loader) updateCache(res *result, items []*astjson.Value) { - if res.cache == nil || len(res.cacheKeys) == 0 || len(res.cacheItems) == 0 { +func (l *Loader) updateCache(res *result) { + if res.cache == nil || len(res.cacheKeys) == 0 { return } - var ( - keys []string - cacheItems [][]byte - ) - for i, item := range res.cacheItems { - if item != nil && item.Type() == astjson.TypeNull && items[i] != nil && items[i].Type() != astjson.TypeNull { - keys = append(keys, res.cacheKeys[i]) - value := items[i].MarshalTo(nil) - cacheItems = append(cacheItems, value) - } - } - err := res.cache.Set(context.Background(), keys, cacheItems, res.cacheTTL) + err := res.cache.Set(context.Background(), res.cacheKeys, res.cacheTTL) if err != nil { - panic(err) + fmt.Printf("error cache.Set: %s", err) } } @@ -2045,24 +2021,16 @@ func (l *Loader) compactJSON(data []byte) ([]byte, error) { return v.MarshalTo(nil), nil } -func (l *Loader) canSkipFetch(info *FetchInfo, items []*astjson.Value) ([]*astjson.Value, bool) { - if info == nil || info.OperationType != ast.OperationTypeQuery { - return items, false - } - if len(items) == 1 && items[0].Type() == astjson.TypeNull { - return items, true - } - - // If ProvidesData is nil, we cannot validate the data - do not skip fetch - if info.ProvidesData == nil { - return items, false +func (l *Loader) canSkipFetch(info *FetchInfo, res *result) ([]*CacheKey, bool) { + if info == nil || info.OperationType != ast.OperationTypeQuery || info.ProvidesData == nil { + return res.cacheKeys, false } // Check each item and remove those that have sufficient data - remaining := make([]*astjson.Value, 0, len(items)) - for _, item := range items { - if !l.validateItemHasRequiredData(item, info.ProvidesData) { - remaining = append(remaining, item) + remaining 
:= make([]*CacheKey, 0, len(res.cacheKeys)) + for i, key := range res.cacheKeys { + if !l.validateItemHasRequiredData(key.Item, info.ProvidesData) { + remaining = append(remaining, res.cacheKeys[i]) } } @@ -2073,10 +2041,9 @@ func (l *Loader) canSkipFetch(info *FetchInfo, items []*astjson.Value) ([]*astjs // validateItemHasRequiredData checks if the given item contains all required data // as specified by the provided Object schema func (l *Loader) validateItemHasRequiredData(item *astjson.Value, obj *Object) bool { - if obj == nil { - return true + if item == nil { + return false } - // Validate each field in the object for _, field := range obj.Fields { if !l.validateFieldData(item, field) { diff --git a/v2/pkg/engine/resolve/loader_skip_fetch_test.go b/v2/pkg/engine/resolve/loader_skip_fetch_test.go index 0d9a5c649..0afa54931 100644 --- a/v2/pkg/engine/resolve/loader_skip_fetch_test.go +++ b/v2/pkg/engine/resolve/loader_skip_fetch_test.go @@ -16,8 +16,8 @@ func TestLoader_canSkipFetch(t *testing.T) { info *FetchInfo items []*astjson.Value wantResult bool - wantRemaining int // -1 means check for empty, otherwise check exact count - checkFn func(t *testing.T, remaining []*astjson.Value) // optional custom validation + wantRemaining int // -1 means check for empty, otherwise check exact count + checkFn func(t *testing.T, remaining []*CacheKey) // optional custom validation }{ { name: "single item with Query operation", @@ -73,7 +73,7 @@ func TestLoader_canSkipFetch(t *testing.T) { astjson.MustParseBytes([]byte(`null`)), }, wantResult: true, - wantRemaining: 1, // null item remains + wantRemaining: -1, // empty - can skip fetch since no fields required }, { name: "single item with all required data", @@ -321,9 +321,9 @@ func TestLoader_canSkipFetch(t *testing.T) { }, wantResult: false, wantRemaining: 1, - checkFn: func(t *testing.T, remaining []*astjson.Value) { + checkFn: func(t *testing.T, remaining []*CacheKey) { // Check that the remaining item is the incomplete one - user := remaining[0].Get("user") + user := remaining[0].Item.Get("user") assert.Equal(t, "456", string(user.Get("id").GetStringBytes())) }, }, @@ -888,7 +888,20 @@ func TestLoader_canSkipFetch(t *testing.T) { itemsCopy := make([]*astjson.Value, len(tt.items)) copy(itemsCopy, tt.items) - remaining, result := loader.canSkipFetch(tt.info, itemsCopy) + // Create cache keys with Item set to the corresponding test items + cacheKeys := make([]*CacheKey, len(itemsCopy)) + for i, item := range itemsCopy { + cacheKeys[i] = &CacheKey{ + Item: item, + } + } + + // Create a result struct for canSkipFetch + res := &result{ + cacheKeys: cacheKeys, + } + + remaining, result := loader.canSkipFetch(tt.info, res) assert.Equal(t, tt.wantResult, result, "result mismatch") From e15c01f419adcda7ac4feab342e3fe905395b369 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 30 Oct 2025 13:33:32 +0100 Subject: [PATCH 54/61] chore: refactor cache key tests --- v2/pkg/engine/resolve/caching_test.go | 284 ++++++++++++++++++-------- 1 file changed, 198 insertions(+), 86 deletions(-) diff --git a/v2/pkg/engine/resolve/caching_test.go b/v2/pkg/engine/resolve/caching_test.go index 5e1d965e6..d980f1078 100644 --- a/v2/pkg/engine/resolve/caching_test.go +++ b/v2/pkg/engine/resolve/caching_test.go @@ -30,11 +30,18 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { data := astjson.MustParse(`{}`) cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Len(t, cacheKeys, 1) - assert.Equal(t, data, 
cacheKeys[0].Item) - assert.Len(t, cacheKeys[0].Keys, 1) - assert.Equal(t, `{"__typename":"Query","field":"users"}`, cacheKeys[0].Keys[0].Name) - assert.Equal(t, "users", cacheKeys[0].Keys[0].Path) + expected := []*CacheKey{ + { + Item: data, + Keys: []KeyEntry{ + { + Name: `{"__typename":"Query","field":"users"}`, + Path: "users", + }, + }, + }, + } + assert.Equal(t, expected, cacheKeys) }) t.Run("single field single argument", func(t *testing.T) { @@ -65,11 +72,18 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { data := astjson.MustParse(`{}`) cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Len(t, cacheKeys, 1) - assert.Equal(t, data, cacheKeys[0].Item) - assert.Len(t, cacheKeys[0].Keys, 1) - assert.Equal(t, `{"__typename":"Query","field":"droid","args":{"id":1}}`, cacheKeys[0].Keys[0].Name) - assert.Equal(t, "droid", cacheKeys[0].Keys[0].Path) + expected := []*CacheKey{ + { + Item: data, + Keys: []KeyEntry{ + { + Name: `{"__typename":"Query","field":"droid","args":{"id":1}}`, + Path: "droid", + }, + }, + }, + } + assert.Equal(t, expected, cacheKeys) }) t.Run("single field single string argument", func(t *testing.T) { @@ -100,11 +114,18 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { data := astjson.MustParse(`{}`) cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Len(t, cacheKeys, 1) - assert.Equal(t, data, cacheKeys[0].Item) - assert.Len(t, cacheKeys[0].Keys, 1) - assert.Equal(t, `{"__typename":"Query","field":"user","args":{"name":"john"}}`, cacheKeys[0].Keys[0].Name) - assert.Equal(t, "user", cacheKeys[0].Keys[0].Path) + expected := []*CacheKey{ + { + Item: data, + Keys: []KeyEntry{ + { + Name: `{"__typename":"Query","field":"user","args":{"name":"john"}}`, + Path: "user", + }, + }, + }, + } + assert.Equal(t, expected, cacheKeys) }) t.Run("single field multiple arguments", func(t *testing.T) { @@ -142,11 +163,18 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { data := astjson.MustParse(`{}`) cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Len(t, cacheKeys, 1) - assert.Equal(t, data, cacheKeys[0].Item) - assert.Len(t, cacheKeys[0].Keys, 1) - assert.Equal(t, `{"__typename":"Query","field":"search","args":{"term":"C3PO","max":10}}`, cacheKeys[0].Keys[0].Name) - assert.Equal(t, "search", cacheKeys[0].Keys[0].Path) + expected := []*CacheKey{ + { + Item: data, + Keys: []KeyEntry{ + { + Name: `{"__typename":"Query","field":"search","args":{"term":"C3PO","max":10}}`, + Path: "search", + }, + }, + }, + } + assert.Equal(t, expected, cacheKeys) }) t.Run("single field multiple arguments with boolean", func(t *testing.T) { @@ -184,11 +212,18 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { data := astjson.MustParse(`{}`) cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Len(t, cacheKeys, 1) - assert.Equal(t, data, cacheKeys[0].Item) - assert.Len(t, cacheKeys[0].Keys, 1) - assert.Equal(t, `{"__typename":"Query","field":"products","args":{"includeDeleted":true,"limit":20}}`, cacheKeys[0].Keys[0].Name) - assert.Equal(t, "products", cacheKeys[0].Keys[0].Path) + expected := []*CacheKey{ + { + Item: data, + Keys: []KeyEntry{ + { + Name: `{"__typename":"Query","field":"products","args":{"includeDeleted":true,"limit":20}}`, + Path: "products", + }, + }, + }, + } + assert.Equal(t, expected, cacheKeys) }) 
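The expected-slice pattern above repeats across these subtests with only the key name and path changing. A small helper in the same spirit could cut the repetition; this is only a sketch, and singleFieldCacheKey is a hypothetical name, not part of the patch:

// singleFieldCacheKey builds the expected result for a root-query test
// case that renders exactly one CacheKey carrying a single KeyEntry.
func singleFieldCacheKey(item *astjson.Value, name, path string) []*CacheKey {
	return []*CacheKey{
		{
			Item: item,
			Keys: []KeyEntry{
				{Name: name, Path: path},
			},
		},
	}
}

With it, a subtest body reduces to assert.Equal(t, singleFieldCacheKey(data, `{"__typename":"Query","field":"droid","args":{"id":1}}`, "droid"), cacheKeys).
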
t.Run("multiple fields single argument each", func(t *testing.T) { @@ -233,16 +268,24 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { } data := astjson.MustParse(`{}`) - // Test RenderCacheKeys returns multiple keys cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Len(t, cacheKeys, 1) - assert.Equal(t, data, cacheKeys[0].Item) - assert.Len(t, cacheKeys[0].Keys, 2) - assert.Equal(t, `{"__typename":"Query","field":"droid","args":{"id":1}}`, cacheKeys[0].Keys[0].Name) - assert.Equal(t, "droid", cacheKeys[0].Keys[0].Path) - assert.Equal(t, `{"__typename":"Query","field":"user","args":{"name":"john"}}`, cacheKeys[0].Keys[1].Name) - assert.Equal(t, "user", cacheKeys[0].Keys[1].Path) + expected := []*CacheKey{ + { + Item: data, + Keys: []KeyEntry{ + { + Name: `{"__typename":"Query","field":"droid","args":{"id":1}}`, + Path: "droid", + }, + { + Name: `{"__typename":"Query","field":"user","args":{"name":"john"}}`, + Path: "user", + }, + }, + }, + } + assert.Equal(t, expected, cacheKeys) }) t.Run("multiple fields with mixed arguments", func(t *testing.T) { @@ -286,16 +329,24 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { } data := astjson.MustParse(`{}`) - // Test RenderCacheKeys returns multiple keys cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Len(t, cacheKeys, 1) - assert.Equal(t, data, cacheKeys[0].Item) - assert.Len(t, cacheKeys[0].Keys, 2) - assert.Equal(t, `{"__typename":"Query","field":"product","args":{"id":"123","includeReviews":true}}`, cacheKeys[0].Keys[0].Name) - assert.Equal(t, "product", cacheKeys[0].Keys[0].Path) - assert.Equal(t, `{"__typename":"Query","field":"hero"}`, cacheKeys[0].Keys[1].Name) - assert.Equal(t, "hero", cacheKeys[0].Keys[1].Path) + expected := []*CacheKey{ + { + Item: data, + Keys: []KeyEntry{ + { + Name: `{"__typename":"Query","field":"product","args":{"id":"123","includeReviews":true}}`, + Path: "product", + }, + { + Name: `{"__typename":"Query","field":"hero"}`, + Path: "hero", + }, + }, + }, + } + assert.Equal(t, expected, cacheKeys) }) t.Run("field with object variable argument", func(t *testing.T) { @@ -326,11 +377,18 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { data := astjson.MustParse(`{"filter":{"category":"electronics","price":100}}`) cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Len(t, cacheKeys, 1) - assert.Equal(t, data, cacheKeys[0].Item) - assert.Len(t, cacheKeys[0].Keys, 1) - assert.Equal(t, `{"__typename":"Query","field":"search","args":{"filter":{"category":"electronics","price":100}}}`, cacheKeys[0].Keys[0].Name) - assert.Equal(t, "search", cacheKeys[0].Keys[0].Path) + expected := []*CacheKey{ + { + Item: data, + Keys: []KeyEntry{ + { + Name: `{"__typename":"Query","field":"search","args":{"filter":{"category":"electronics","price":100}}}`, + Path: "search", + }, + }, + }, + } + assert.Equal(t, expected, cacheKeys) }) t.Run("field with null argument", func(t *testing.T) { @@ -361,11 +419,18 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { data := astjson.MustParse(`{}`) cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Len(t, cacheKeys, 1) - assert.Equal(t, data, cacheKeys[0].Item) - assert.Len(t, cacheKeys[0].Keys, 1) - assert.Equal(t, `{"__typename":"Query","field":"user","args":{"id":null}}`, cacheKeys[0].Keys[0].Name) - assert.Equal(t, 
"user", cacheKeys[0].Keys[0].Path) + expected := []*CacheKey{ + { + Item: data, + Keys: []KeyEntry{ + { + Name: `{"__typename":"Query","field":"user","args":{"id":null}}`, + Path: "user", + }, + }, + }, + } + assert.Equal(t, expected, cacheKeys) }) t.Run("field with missing argument", func(t *testing.T) { @@ -396,11 +461,18 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { data := astjson.MustParse(`{}`) cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Len(t, cacheKeys, 1) - assert.Equal(t, data, cacheKeys[0].Item) - assert.Len(t, cacheKeys[0].Keys, 1) - assert.Equal(t, `{"__typename":"Query","field":"user","args":{"id":null}}`, cacheKeys[0].Keys[0].Name) - assert.Equal(t, "user", cacheKeys[0].Keys[0].Path) + expected := []*CacheKey{ + { + Item: data, + Keys: []KeyEntry{ + { + Name: `{"__typename":"Query","field":"user","args":{"id":null}}`, + Path: "user", + }, + }, + }, + } + assert.Equal(t, expected, cacheKeys) }) t.Run("field with array argument", func(t *testing.T) { @@ -431,11 +503,18 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { data := astjson.MustParse(`{}`) cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Len(t, cacheKeys, 1) - assert.Equal(t, data, cacheKeys[0].Item) - assert.Len(t, cacheKeys[0].Keys, 1) - assert.Equal(t, `{"__typename":"Query","field":"products","args":{"ids":[1,2,3]}}`, cacheKeys[0].Keys[0].Name) - assert.Equal(t, "products", cacheKeys[0].Keys[0].Path) + expected := []*CacheKey{ + { + Item: data, + Keys: []KeyEntry{ + { + Name: `{"__typename":"Query","field":"products","args":{"ids":[1,2,3]}}`, + Path: "products", + }, + }, + }, + } + assert.Equal(t, expected, cacheKeys) }) t.Run("non-Query type", func(t *testing.T) { @@ -466,11 +545,18 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { data := astjson.MustParse(`{}`) cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Len(t, cacheKeys, 1) - assert.Equal(t, data, cacheKeys[0].Item) - assert.Len(t, cacheKeys[0].Keys, 1) - assert.Equal(t, `{"__typename":"Subscription","field":"messageAdded","args":{"roomId":"123"}}`, cacheKeys[0].Keys[0].Name) - assert.Equal(t, "messageAdded", cacheKeys[0].Keys[0].Path) + expected := []*CacheKey{ + { + Item: data, + Keys: []KeyEntry{ + { + Name: `{"__typename":"Subscription","field":"messageAdded","args":{"roomId":"123"}}`, + Path: "messageAdded", + }, + }, + }, + } + assert.Equal(t, expected, cacheKeys) }) t.Run("single field with arena", func(t *testing.T) { @@ -502,11 +588,18 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { data := astjson.MustParse(`{}`) cacheKeys, err := tmpl.RenderCacheKeys(ar, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Len(t, cacheKeys, 1) - assert.Equal(t, data, cacheKeys[0].Item) - assert.Len(t, cacheKeys[0].Keys, 1) - assert.Equal(t, `{"__typename":"Query","field":"user","args":{"name":"john"}}`, cacheKeys[0].Keys[0].Name) - assert.Equal(t, "user", cacheKeys[0].Keys[0].Path) + expected := []*CacheKey{ + { + Item: data, + Keys: []KeyEntry{ + { + Name: `{"__typename":"Query","field":"user","args":{"name":"john"}}`, + Path: "user", + }, + }, + }, + } + assert.Equal(t, expected, cacheKeys) }) } @@ -538,11 +631,17 @@ func TestCachingRenderEntityQueryCacheKeyTemplate(t *testing.T) { data := astjson.MustParse(`{"__typename":"Product","id":"123"}`) cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, 
[]*astjson.Value{data}) assert.NoError(t, err) - assert.Len(t, cacheKeys, 1) - assert.Equal(t, data, cacheKeys[0].Item) - assert.Len(t, cacheKeys[0].Keys, 1) - assert.Equal(t, `{"__typename":"Product","keys":{"id":"123"}}`, cacheKeys[0].Keys[0].Name) - assert.Equal(t, "", cacheKeys[0].Keys[0].Path) + expected := []*CacheKey{ + { + Item: data, + Keys: []KeyEntry{ + { + Name: `{"__typename":"Product","keys":{"id":"123"}}`, + }, + }, + }, + } + assert.Equal(t, expected, cacheKeys) }) t.Run("single entity with multiple keys", func(t *testing.T) { @@ -578,11 +677,17 @@ func TestCachingRenderEntityQueryCacheKeyTemplate(t *testing.T) { data := astjson.MustParse(`{"__typename":"Product","sku":"ABC123","upc":"DEF456","name":"Trilby"}`) cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Len(t, cacheKeys, 1) - assert.Equal(t, data, cacheKeys[0].Item) - assert.Len(t, cacheKeys[0].Keys, 1) - assert.Equal(t, `{"__typename":"Product","keys":{"sku":"ABC123","upc":"DEF456"}}`, cacheKeys[0].Keys[0].Name) - assert.Equal(t, "", cacheKeys[0].Keys[0].Path) + expected := []*CacheKey{ + { + Item: data, + Keys: []KeyEntry{ + { + Name: `{"__typename":"Product","keys":{"sku":"ABC123","upc":"DEF456"}}`, + }, + }, + }, + } + assert.Equal(t, expected, cacheKeys) }) t.Run("entity with nested object key", func(t *testing.T) { @@ -625,11 +730,18 @@ func TestCachingRenderEntityQueryCacheKeyTemplate(t *testing.T) { data := astjson.MustParse(`{"__typename":"VersionedEntity","key":{"id":"123","version":"1"}}`) cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) assert.NoError(t, err) - assert.Len(t, cacheKeys, 1) - assert.Equal(t, data, cacheKeys[0].Item) - assert.Len(t, cacheKeys[0].Keys, 1) - assert.Equal(t, `{"__typename":"VersionedEntity","keys":{"key":{"id":"123","version":"1"}}}`, cacheKeys[0].Keys[0].Name) - assert.Equal(t, "", cacheKeys[0].Keys[0].Path) + expected := []*CacheKey{ + { + Item: data, + Keys: []KeyEntry{ + { + Name: `{"__typename":"VersionedEntity","keys":{"key":{"id":"123","version":"1"}}}`, + Path: "", + }, + }, + }, + } + assert.Equal(t, expected, cacheKeys) }) } From 9a4ba5be6aade20053a9eba3f268128477ceb93a Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 30 Oct 2025 16:21:17 +0100 Subject: [PATCH 55/61] chore: refactor execution cache test for miss then hit --- execution/engine/federation_caching_test.go | 260 ++++++++++++++++---- v2/pkg/engine/resolve/loader.go | 162 +++++++++++- 2 files changed, 362 insertions(+), 60 deletions(-) diff --git a/execution/engine/federation_caching_test.go b/execution/engine/federation_caching_test.go index 3eb61f977..c01e2f363 100644 --- a/execution/engine/federation_caching_test.go +++ b/execution/engine/federation_caching_test.go @@ -15,7 +15,7 @@ import ( ) func TestFederationCaching(t *testing.T) { - t.Run("query spans multiple federated servers", func(t *testing.T) { + t.Run("two subgraphs - miss then hit", func(t *testing.T) { defaultCache := NewFakeLoaderCache() caches := map[string]resolve.LoaderCache{ "default": defaultCache, @@ -25,19 +25,101 @@ func TestFederationCaching(t *testing.T) { gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) + + // First query - should miss cache and then set + defaultCache.ClearLog() resp := gqlClient.Query(ctx, setup.GatewayServer.URL, testQueryPath("queries/multiple_upstream.query"), nil, t) assert.Equal(t, `{"data":{"topProducts":[{"name":"Trilby","reviews":[{"body":"A highly 
effective form of birth control.","author":{"username":"Me"}}]},{"name":"Fedora","reviews":[{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","author":{"username":"Me"}}]}]}}`, string(resp)) + + logAfterFirst := defaultCache.GetLog() + assert.Equal(t, 4, len(logAfterFirst)) + + wantLog := []CacheLogEntry{ + { + Operation: "get", + Keys: []string{`{"__typename":"Query","field":"topProducts"}`}, + Hits: []bool{false}, + }, + { + Operation: "set", + Keys: []string{`{"__typename":"Query","field":"topProducts"}`}, + }, + { + Operation: "get", + Keys: []string{ + `{"__typename":"Product","keys":{"upc":"top-1"}}`, + `{"__typename":"Product","keys":{"upc":"top-2"}}`, + }, + Hits: []bool{false, false}, + }, + { + Operation: "set", + Keys: []string{ + `{"__typename":"Product","keys":{"upc":"top-1"}}`, + `{"__typename":"Product","keys":{"upc":"top-2"}}`, + }, + }, + } + assert.Equal(t, wantLog, logAfterFirst) + + // Second query - should hit cache and then set + defaultCache.ClearLog() resp = gqlClient.Query(ctx, setup.GatewayServer.URL, testQueryPath("queries/multiple_upstream.query"), nil, t) assert.Equal(t, `{"data":{"topProducts":[{"name":"Trilby","reviews":[{"body":"A highly effective form of birth control.","author":{"username":"Me"}}]},{"name":"Fedora","reviews":[{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","author":{"username":"Me"}}]}]}}`, string(resp)) - defaultCache.mu.Lock() - defer defaultCache.mu.Unlock() - _, ok := defaultCache.storage[`{"__typename":"Product","upc":"top-1"}`] - assert.True(t, ok) - _, ok = defaultCache.storage[`{"__typename":"Product","upc":"top-2"}`] - assert.True(t, ok) + + logAfterSecond := defaultCache.GetLog() + assert.Equal(t, 4, len(logAfterSecond)) + + wantLogSecond := []CacheLogEntry{ + { + Operation: "get", + Keys: []string{`{"__typename":"Query","field":"topProducts"}`}, + Hits: []bool{true}, // Should be a hit now + }, + { + Operation: "set", + Keys: []string{`{"__typename":"Query","field":"topProducts"}`}, + }, + { + Operation: "get", + Keys: []string{ + `{"__typename":"Product","keys":{"upc":"top-1"}}`, + `{"__typename":"Product","keys":{"upc":"top-2"}}`, + }, + Hits: []bool{true, true}, // Should be hits now, no misses + }, + { + Operation: "set", + Keys: []string{ + `{"__typename":"Product","keys":{"upc":"top-1"}}`, + `{"__typename":"Product","keys":{"upc":"top-2"}}`, + }, + }, + } + assert.Equal(t, wantLogSecond, logAfterSecond) }) } +type CacheLogEntry struct { + Operation string // "get", "set", "delete" + Keys []string // Keys involved in the operation + Hits []bool // For Get: whether each key was a hit (true) or miss (false) +} + +// normalizeCacheLog creates a copy of log entries without timestamps for comparison +func normalizeCacheLog(log []CacheLogEntry) []CacheLogEntry { + normalized := make([]CacheLogEntry, len(log)) + for i, entry := range log { + normalized[i] = CacheLogEntry{ + Operation: entry.Operation, + Keys: entry.Keys, + Hits: entry.Hits, + // Timestamp is zero value for comparison + } + } + return normalized +} + type cacheEntry struct { data []byte expiresAt *time.Time @@ -46,11 +128,13 @@ type cacheEntry struct { type FakeLoaderCache struct { mu sync.RWMutex storage map[string]cacheEntry + log []CacheLogEntry } func NewFakeLoaderCache() *FakeLoaderCache { return &FakeLoaderCache{ storage: make(map[string]cacheEntry), + log: make([]CacheLogEntry, 0), } } @@ -63,30 +147,44 @@ func (f *FakeLoaderCache) 
cleanupExpired() { } } -func (f *FakeLoaderCache) Get(ctx context.Context, keys []string) ([][]byte, error) { +func (f *FakeLoaderCache) Get(ctx context.Context, keys []string) ([]*resolve.CacheEntry, error) { f.mu.Lock() defer f.mu.Unlock() // Clean up expired entries before executing command f.cleanupExpired() - result := make([][]byte, len(keys)) + hits := make([]bool, len(keys)) + result := make([]*resolve.CacheEntry, len(keys)) for i, key := range keys { if entry, exists := f.storage[key]; exists { // Make a copy of the data to prevent external modifications dataCopy := make([]byte, len(entry.data)) copy(dataCopy, entry.data) - result[i] = dataCopy + result[i] = &resolve.CacheEntry{ + Key: key, + Value: dataCopy, + } + hits[i] = true } else { result[i] = nil + hits[i] = false } } + + // Log the operation + f.log = append(f.log, CacheLogEntry{ + Operation: "get", + Keys: keys, + Hits: hits, + }) + return result, nil } -func (f *FakeLoaderCache) Set(ctx context.Context, keys []string, items [][]byte, ttl time.Duration) error { - if len(keys) != len(items) { - return nil // Silently ignore mismatched lengths like Redis would +func (f *FakeLoaderCache) Set(ctx context.Context, entries []*resolve.CacheEntry, ttl time.Duration) error { + if len(entries) == 0 { + return nil } f.mu.Lock() @@ -95,21 +193,34 @@ func (f *FakeLoaderCache) Set(ctx context.Context, keys []string, items [][]byte // Clean up expired entries before executing command f.cleanupExpired() - for i, key := range keys { - entry := cacheEntry{ + keys := make([]string, 0, len(entries)) + for _, entry := range entries { + if entry == nil { + continue + } + cacheEntry := cacheEntry{ // Make a copy of the data to prevent external modifications - data: make([]byte, len(items[i])), + data: make([]byte, len(entry.Value)), } - copy(entry.data, items[i]) + copy(cacheEntry.data, entry.Value) // If ttl is 0, store without expiration if ttl > 0 { expiresAt := time.Now().Add(ttl) - entry.expiresAt = &expiresAt + cacheEntry.expiresAt = &expiresAt } - f.storage[key] = entry + f.storage[entry.Key] = cacheEntry + keys = append(keys, entry.Key) } + + // Log the operation + f.log = append(f.log, CacheLogEntry{ + Operation: "set", + Keys: keys, + Hits: nil, // Set operations don't have hits/misses + }) + return nil } @@ -123,9 +234,33 @@ func (f *FakeLoaderCache) Delete(ctx context.Context, keys []string) error { for _, key := range keys { delete(f.storage, key) } + + // Log the operation + f.log = append(f.log, CacheLogEntry{ + Operation: "delete", + Keys: keys, + Hits: nil, // Delete operations don't have hits/misses + }) + return nil } +// GetLog returns a copy of the cache operation log +func (f *FakeLoaderCache) GetLog() []CacheLogEntry { + f.mu.RLock() + defer f.mu.RUnlock() + logCopy := make([]CacheLogEntry, len(f.log)) + copy(logCopy, f.log) + return logCopy +} + +// ClearLog clears the cache operation log +func (f *FakeLoaderCache) ClearLog() { + f.mu.Lock() + defer f.mu.Unlock() + f.log = make([]CacheLogEntry, 0) +} + // TestFakeLoaderCache tests the cache implementation itself func TestFakeLoaderCache(t *testing.T) { ctx := context.Background() @@ -134,33 +269,45 @@ func TestFakeLoaderCache(t *testing.T) { t.Run("SetAndGet", func(t *testing.T) { // Test basic set and get keys := []string{"key1", "key2", "key3"} - items := [][]byte{[]byte("value1"), []byte("value2"), []byte("value3")} + entries := []*resolve.CacheEntry{ + {Key: "key1", Value: []byte("value1")}, + {Key: "key2", Value: []byte("value2")}, + {Key: "key3", Value: 
[]byte("value3")}, + } - err := cache.Set(ctx, keys, items, 0) // No TTL + err := cache.Set(ctx, entries, 0) // No TTL require.NoError(t, err) // Get all keys result, err := cache.Get(ctx, keys) require.NoError(t, err) require.Len(t, result, 3) - assert.Equal(t, "value1", string(result[0])) - assert.Equal(t, "value2", string(result[1])) - assert.Equal(t, "value3", string(result[2])) + assert.NotNil(t, result[0]) + assert.Equal(t, "value1", string(result[0].Value)) + assert.NotNil(t, result[1]) + assert.Equal(t, "value2", string(result[1].Value)) + assert.NotNil(t, result[2]) + assert.Equal(t, "value3", string(result[2].Value)) // Get partial keys result, err = cache.Get(ctx, []string{"key2", "key4", "key1"}) require.NoError(t, err) require.Len(t, result, 3) - assert.Equal(t, "value2", string(result[0])) + assert.NotNil(t, result[0]) + assert.Equal(t, "value2", string(result[0].Value)) assert.Nil(t, result[1]) // key4 doesn't exist - assert.Equal(t, "value1", string(result[2])) + assert.NotNil(t, result[2]) + assert.Equal(t, "value1", string(result[2].Value)) }) t.Run("Delete", func(t *testing.T) { // Set some keys - keys := []string{"del1", "del2", "del3"} - items := [][]byte{[]byte("v1"), []byte("v2"), []byte("v3")} - err := cache.Set(ctx, keys, items, 0) + entries := []*resolve.CacheEntry{ + {Key: "del1", Value: []byte("v1")}, + {Key: "del2", Value: []byte("v2")}, + {Key: "del3", Value: []byte("v3")}, + } + err := cache.Set(ctx, entries, 0) require.NoError(t, err) // Delete some keys @@ -168,31 +315,36 @@ func TestFakeLoaderCache(t *testing.T) { require.NoError(t, err) // Check remaining keys - result, err := cache.Get(ctx, keys) + result, err := cache.Get(ctx, []string{"del1", "del2", "del3"}) require.NoError(t, err) - assert.Nil(t, result[0]) // del1 was deleted - assert.Equal(t, "v2", string(result[1])) // del2 still exists - assert.Nil(t, result[2]) // del3 was deleted + assert.Nil(t, result[0]) // del1 was deleted + assert.NotNil(t, result[1]) // del2 still exists + assert.Equal(t, "v2", string(result[1].Value)) + assert.Nil(t, result[2]) // del3 was deleted }) t.Run("TTL", func(t *testing.T) { // Set with 50ms TTL - keys := []string{"ttl1", "ttl2"} - items := [][]byte{[]byte("expire1"), []byte("expire2")} - err := cache.Set(ctx, keys, items, 50*time.Millisecond) + entries := []*resolve.CacheEntry{ + {Key: "ttl1", Value: []byte("expire1")}, + {Key: "ttl2", Value: []byte("expire2")}, + } + err := cache.Set(ctx, entries, 50*time.Millisecond) require.NoError(t, err) // Immediately get - should exist - result, err := cache.Get(ctx, keys) + result, err := cache.Get(ctx, []string{"ttl1", "ttl2"}) require.NoError(t, err) - assert.Equal(t, "expire1", string(result[0])) - assert.Equal(t, "expire2", string(result[1])) + assert.NotNil(t, result[0]) + assert.Equal(t, "expire1", string(result[0].Value)) + assert.NotNil(t, result[1]) + assert.Equal(t, "expire2", string(result[1].Value)) // Wait for expiration time.Sleep(60 * time.Millisecond) // Get again - should be nil - result, err = cache.Get(ctx, keys) + result, err = cache.Get(ctx, []string{"ttl1", "ttl2"}) require.NoError(t, err) assert.Nil(t, result[0]) assert.Nil(t, result[1]) @@ -200,10 +352,10 @@ func TestFakeLoaderCache(t *testing.T) { t.Run("MixedTTL", func(t *testing.T) { // Set some with TTL, some without - err := cache.Set(ctx, []string{"perm1"}, [][]byte{[]byte("permanent")}, 0) + err := cache.Set(ctx, []*resolve.CacheEntry{{Key: "perm1", Value: []byte("permanent")}}, 0) require.NoError(t, err) - err = cache.Set(ctx, 
[]string{"temp1"}, [][]byte{[]byte("temporary")}, 50*time.Millisecond) + err = cache.Set(ctx, []*resolve.CacheEntry{{Key: "temp1", Value: []byte("temporary")}}, 50*time.Millisecond) require.NoError(t, err) // Wait for temporary to expire @@ -212,8 +364,9 @@ func TestFakeLoaderCache(t *testing.T) { // Check both result, err := cache.Get(ctx, []string{"perm1", "temp1"}) require.NoError(t, err) - assert.Equal(t, "permanent", string(result[0])) // Still exists - assert.Nil(t, result[1]) // Expired + assert.NotNil(t, result[0]) + assert.Equal(t, "permanent", string(result[0].Value)) // Still exists + assert.Nil(t, result[1]) // Expired }) t.Run("ThreadSafety", func(t *testing.T) { @@ -225,7 +378,7 @@ func TestFakeLoaderCache(t *testing.T) { for i := 0; i < 100; i++ { key := fmt.Sprintf("concurrent_%d", i) value := fmt.Sprintf("value_%d", i) - err := cache.Set(ctx, []string{key}, [][]byte{[]byte(value)}, 0) + err := cache.Set(ctx, []*resolve.CacheEntry{{Key: key, Value: []byte(value)}}, 0) assert.NoError(t, err) } done <- true @@ -261,7 +414,10 @@ func TestFakeLoaderCache(t *testing.T) { // Test that result length always matches input keys length // Set some data - err := cache.Set(ctx, []string{"exist1", "exist3"}, [][]byte{[]byte("data1"), []byte("data3")}, 0) + err := cache.Set(ctx, []*resolve.CacheEntry{ + {Key: "exist1", Value: []byte("data1")}, + {Key: "exist3", Value: []byte("data3")}, + }, 0) require.NoError(t, err) // Request mix of existing and non-existing keys @@ -274,11 +430,13 @@ func TestFakeLoaderCache(t *testing.T) { assert.Len(t, result, 5, "Should return exactly 5 results") // Verify correct values - assert.Equal(t, "data1", string(result[0])) // exist1 - assert.Nil(t, result[1]) // missing1 - assert.Equal(t, "data3", string(result[2])) // exist3 - assert.Nil(t, result[3]) // missing2 - assert.Nil(t, result[4]) // missing3 + assert.NotNil(t, result[0]) + assert.Equal(t, "data1", string(result[0].Value)) // exist1 + assert.Nil(t, result[1]) // missing1 + assert.NotNil(t, result[2]) + assert.Equal(t, "data3", string(result[2].Value)) // exist3 + assert.Nil(t, result[3]) // missing2 + assert.Nil(t, result[4]) // missing3 // Test with all missing keys allMissingKeys := []string{"missing4", "missing5", "missing6"} diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 40e1c2e25..362eec858 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -135,6 +135,7 @@ type result struct { cacheKeys []*CacheKey cacheTTL time.Duration cacheSkippedFetch bool + cacheResponseData *astjson.Value // Response data to cache (set in mergeResult) } func (r *result) init(postProcessing PostProcessingConfiguration, info *FetchInfo) { @@ -442,10 +443,127 @@ func (l *Loader) itemsData(items []*astjson.Value) *astjson.Value { return arr } +type CacheEntry struct { + Key string + Value []byte +} + type LoaderCache interface { - Get(ctx context.Context, keys []*CacheKey) error - Set(ctx context.Context, keys []*CacheKey, ttl time.Duration) error - Delete(ctx context.Context, keys []*CacheKey) error + Get(ctx context.Context, keys []string) ([]*CacheEntry, error) + Set(ctx context.Context, entries []*CacheEntry, ttl time.Duration) error + Delete(ctx context.Context, keys []string) error +} + +// extractCacheKeysStrings extracts all unique cache key strings from CacheKeys +func extractCacheKeysStrings(cacheKeys []*CacheKey) []string { + if len(cacheKeys) == 0 { + return nil + } + keySet := make(map[string]struct{}) + for _, cacheKey := range 
cacheKeys { + for _, entry := range cacheKey.Keys { + keySet[entry.Name] = struct{}{} + } + } + keys := make([]string, 0, len(keySet)) + for key := range keySet { + keys = append(keys, key) + } + return keys +} + +// populateFromCache populates CacheKey.FromCache fields from cache entries +func populateFromCache(cacheKeys []*CacheKey, entries []*CacheEntry) error { + // Create a map of key -> value for quick lookup + entryMap := make(map[string][]byte) + for _, entry := range entries { + if entry != nil && entry.Value != nil { + entryMap[entry.Key] = entry.Value + } + } + + // For each CacheKey, find matching entries and populate FromCache + // Since multiple KeyEntries can map to the same value, we use the first match + for _, cacheKey := range cacheKeys { + if cacheKey.FromCache != nil { + // Already populated, skip + continue + } + for _, keyEntry := range cacheKey.Keys { + if cachedValue, found := entryMap[keyEntry.Name]; found { + // Parse the cached JSON value + // Note: We use nil arena here because this is temporary data + // The FromCache will be merged into items which are on the jsonArena + parsedValue, err := astjson.ParseBytes(cachedValue) + if err != nil { + return errors.WithStack(err) + } + cacheKey.FromCache = parsedValue + break // Use first match + } + } + } + return nil +} + +// cacheKeysToEntries converts CacheKeys to CacheEntries for storage +// For each CacheKey, creates entries for all its KeyEntries with the same value +func cacheKeysToEntries(cacheKeys []*CacheKey, responseData *astjson.Value, jsonArena arena.Arena) ([]*CacheEntry, error) { + if len(cacheKeys) == 0 { + return nil, nil + } + + entries := make([]*CacheEntry, 0) + + // Check if responseData is an array + responseArray := responseData.GetArray() + + if responseArray != nil && len(responseArray) > 1 { + // Multiple items: extract per-item data from batch response + if len(responseArray) != len(cacheKeys) { + return nil, errors.Errorf("cache key count (%d) doesn't match response array length (%d)", len(cacheKeys), len(responseArray)) + } + + // For each CacheKey, serialize its corresponding item and store under all its KeyEntries + for i, cacheKey := range cacheKeys { + itemData := responseArray[i] + itemBytes := itemData.MarshalTo(nil) + + for _, keyEntry := range cacheKey.Keys { + valueCopy := make([]byte, len(itemBytes)) + copy(valueCopy, itemBytes) + entries = append(entries, &CacheEntry{ + Key: keyEntry.Name, + Value: valueCopy, + }) + } + } + } else { + // Single item: store same value under all keys + // This handles both single object and single-item array cases + var dataToStore *astjson.Value + if responseArray != nil && len(responseArray) == 1 { + dataToStore = responseArray[0] + } else { + dataToStore = responseData + } + + dataBytes := dataToStore.MarshalTo(nil) + + // Store under all KeyEntries for all CacheKeys + for _, cacheKey := range cacheKeys { + for _, keyEntry := range cacheKey.Keys { + valueCopy := make([]byte, len(dataBytes)) + copy(valueCopy, dataBytes) + entries = append(entries, &CacheEntry{ + Key: keyEntry.Name, + Value: valueCopy, + }) + } + } + } + + return entries, nil } func (l *Loader) tryCacheLoadFetch(ctx context.Context, info *FetchInfo, cfg FetchCacheConfiguration, inputItems []*astjson.Value, res *result) (skipFetch bool, err error) { @@ -471,7 +589,18 @@ func (l *Loader) tryCacheLoadFetch(ctx context.Context, info *FetchInfo, cfg Fet // If no cache keys were generated, we skip the cache return false, nil } - err = res.cache.Get(ctx, res.cacheKeys) + // Extract all 
unique cache key strings + cacheKeyStrings := extractCacheKeysStrings(res.cacheKeys) + if len(cacheKeyStrings) == 0 { + return false, nil + } + // Get cache entries + cacheEntries, err := res.cache.Get(ctx, cacheKeyStrings) + if err != nil { + return false, err + } + // Populate FromCache fields in CacheKeys + err = populateFromCache(res.cacheKeys, cacheEntries) if err != nil { return false, err } @@ -588,9 +717,6 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson if len(res.out) == 0 { return l.renderErrorsFailedToFetch(fetchItem, res, emptyGraphQLResponse) } - if res.cacheMustBeUpdated { - defer l.updateCache(res) - } // before parsing bytes with an arena.Arena, it's important to first allocate the bytes ON the same arena.Arena // this ties their lifecycles together // if you don't do this, you'll get segfaults @@ -612,6 +738,12 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson responseData = response } + // Store responseData for caching if needed + if res.cacheMustBeUpdated { + res.cacheResponseData = responseData + defer l.updateCache(res) + } + hasErrors := false var taintedIndices []int @@ -767,10 +899,22 @@ func (l *Loader) renderErrorsInvalidInput(fetchItem *FetchItem) []byte { } func (l *Loader) updateCache(res *result) { - if res.cache == nil || len(res.cacheKeys) == 0 { + if res.cache == nil || len(res.cacheKeys) == 0 || res.cacheResponseData == nil { return } - err := res.cache.Set(context.Background(), res.cacheKeys, res.cacheTTL) + + // Convert CacheKeys to CacheEntries + cacheEntries, err := cacheKeysToEntries(res.cacheKeys, res.cacheResponseData, l.jsonArena) + if err != nil { + fmt.Printf("error converting cache keys to entries: %s", err) + return + } + + if len(cacheEntries) == 0 { + return + } + + err = res.cache.Set(context.Background(), cacheEntries, res.cacheTTL) if err != nil { fmt.Printf("error cache.Set: %s", err) } From 7547964ff42596781059269b17b5b42d1a1db913 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 30 Oct 2025 16:53:12 +0100 Subject: [PATCH 56/61] chore: expand federation caching tests --- execution/engine/federation_caching_test.go | 386 +++++++++++++++++++- execution/engine/graphql_client_test.go | 16 + 2 files changed, 397 insertions(+), 5 deletions(-) diff --git a/execution/engine/federation_caching_test.go b/execution/engine/federation_caching_test.go index c01e2f363..07da97ea8 100644 --- a/execution/engine/federation_caching_test.go +++ b/execution/engine/federation_caching_test.go @@ -4,13 +4,20 @@ import ( "context" "fmt" "net/http" + "net/http/httptest" + "net/url" + "path" + "sort" + "strings" "sync" "testing" "time" + "github.com/jensneuse/abstractlogger" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/wundergraph/graphql-go-tools/execution/federationtesting" + "github.com/wundergraph/graphql-go-tools/execution/federationtesting/gateway" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) @@ -20,15 +27,31 @@ func TestFederationCaching(t *testing.T) { caches := map[string]resolve.LoaderCache{ "default": defaultCache, } - setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false), withLoaderCache(caches))) + + // Create HTTP client with tracking + tracker := newSubgraphCallTracker(http.DefaultTransport) + trackingClient := &http.Client{ + Transport: tracker, + } + + setup := federationtesting.NewFederationSetup(addCachingGateway(withCachingEnableART(false), withCachingLoaderCache(caches), 
withHTTPClient(trackingClient))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) + // Extract hostnames for tracking (URL.Host includes host:port) + accountsURLParsed, _ := url.Parse(setup.AccountsUpstreamServer.URL) + productsURLParsed, _ := url.Parse(setup.ProductsUpstreamServer.URL) + reviewsURLParsed, _ := url.Parse(setup.ReviewsUpstreamServer.URL) + accountsHost := accountsURLParsed.Host + productsHost := productsURLParsed.Host + reviewsHost := reviewsURLParsed.Host + // First query - should miss cache and then set defaultCache.ClearLog() - resp := gqlClient.Query(ctx, setup.GatewayServer.URL, testQueryPath("queries/multiple_upstream.query"), nil, t) + tracker.Reset() + resp := gqlClient.Query(ctx, setup.GatewayServer.URL, cachingTestQueryPath("queries/multiple_upstream.query"), nil, t) assert.Equal(t, `{"data":{"topProducts":[{"name":"Trilby","reviews":[{"body":"A highly effective form of birth control.","author":{"username":"Me"}}]},{"name":"Fedora","reviews":[{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","author":{"username":"Me"}}]}]}}`, string(resp)) logAfterFirst := defaultCache.GetLog() @@ -60,11 +83,23 @@ func TestFederationCaching(t *testing.T) { }, }, } - assert.Equal(t, wantLog, logAfterFirst) + assert.Equal(t, sortCacheLogKeys(wantLog), sortCacheLogKeys(logAfterFirst)) + + // Verify subgraph calls for first query + // First query should call products (topProducts) and reviews (reviews) + // Accounts is not called directly because username is provided via reviews @provides + productsCallsFirst := tracker.GetCount(productsHost) + reviewsCallsFirst := tracker.GetCount(reviewsHost) + accountsCallsFirst := tracker.GetCount(accountsHost) + + assert.Equal(t, 1, productsCallsFirst, "First query should call products subgraph exactly once") + assert.Equal(t, 1, reviewsCallsFirst, "First query should call reviews subgraph exactly once") + assert.Equal(t, 0, accountsCallsFirst, "First query should not call accounts subgraph (username provided via reviews @provides)") // Second query - should hit cache and then set defaultCache.ClearLog() - resp = gqlClient.Query(ctx, setup.GatewayServer.URL, testQueryPath("queries/multiple_upstream.query"), nil, t) + tracker.Reset() + resp = gqlClient.Query(ctx, setup.GatewayServer.URL, cachingTestQueryPath("queries/multiple_upstream.query"), nil, t) assert.Equal(t, `{"data":{"topProducts":[{"name":"Trilby","reviews":[{"body":"A highly effective form of birth control.","author":{"username":"Me"}}]},{"name":"Fedora","reviews":[{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","author":{"username":"Me"}}]}]}}`, string(resp)) logAfterSecond := defaultCache.GetLog() @@ -96,8 +131,302 @@ func TestFederationCaching(t *testing.T) { }, }, } - assert.Equal(t, wantLogSecond, logAfterSecond) + assert.Equal(t, sortCacheLogKeys(wantLogSecond), sortCacheLogKeys(logAfterSecond)) + + // Verify subgraph calls for second query + productsCallsSecond := tracker.GetCount(productsHost) + reviewsCallsSecond := tracker.GetCount(reviewsHost) + accountsCallsSecond := tracker.GetCount(accountsHost) + + assert.Equal(t, 1, productsCallsSecond, "Second query should hit cache and not call products subgraph again") + assert.Equal(t, 1, reviewsCallsSecond, "Second query should hit cache and not call reviews subgraph again") + assert.Equal(t, 0, accountsCallsSecond, 
"accounts not involved") }) + + t.Run("two subgraphs - partial fields then full fields", func(t *testing.T) { + defaultCache := NewFakeLoaderCache() + caches := map[string]resolve.LoaderCache{ + "default": defaultCache, + } + + // Create HTTP client with tracking + tracker := newSubgraphCallTracker(http.DefaultTransport) + trackingClient := &http.Client{ + Transport: tracker, + } + + setup := federationtesting.NewFederationSetup(addCachingGateway(withCachingEnableART(false), withCachingLoaderCache(caches), withHTTPClient(trackingClient))) + t.Cleanup(setup.Close) + gqlClient := NewGraphqlClient(http.DefaultClient) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + // Extract hostnames for tracking (URL.Host includes host:port) + accountsURLParsed, _ := url.Parse(setup.AccountsUpstreamServer.URL) + productsURLParsed, _ := url.Parse(setup.ProductsUpstreamServer.URL) + reviewsURLParsed, _ := url.Parse(setup.ReviewsUpstreamServer.URL) + accountsHost := accountsURLParsed.Host + productsHost := productsURLParsed.Host + reviewsHost := reviewsURLParsed.Host + + // First query - only ask for name field (products subgraph only) + defaultCache.ClearLog() + tracker.Reset() + firstQuery := `query { + topProducts { + name + } + }` + resp := gqlClient.QueryString(ctx, setup.GatewayServer.URL, firstQuery, nil, t) + assert.Equal(t, `{"data":{"topProducts":[{"name":"Trilby"},{"name":"Fedora"}]}}`, string(resp)) + + logAfterFirst := defaultCache.GetLog() + assert.Equal(t, 2, len(logAfterFirst)) + + wantLogFirst := []CacheLogEntry{ + { + Operation: "get", + Keys: []string{`{"__typename":"Query","field":"topProducts"}`}, + Hits: []bool{false}, + }, + { + Operation: "set", + Keys: []string{`{"__typename":"Query","field":"topProducts"}`}, + }, + } + assert.Equal(t, sortCacheLogKeys(wantLogFirst), sortCacheLogKeys(logAfterFirst)) + + // Verify first query calls products subgraph only + productsCallsFirst := tracker.GetCount(productsHost) + reviewsCallsFirst := tracker.GetCount(reviewsHost) + accountsCallsFirst := tracker.GetCount(accountsHost) + assert.Equal(t, 1, productsCallsFirst, "First query calls products subgraph once") + assert.Equal(t, 0, reviewsCallsFirst, "First query does not call reviews subgraph") + assert.Equal(t, 0, accountsCallsFirst, "First query does not call accounts subgraph") + + // Second query - ask for full fields including reviews (products + reviews + accounts) + defaultCache.ClearLog() + tracker.Reset() + secondQuery := `query { + topProducts { + name + reviews { + body + author { + username + } + } + } + }` + resp = gqlClient.QueryString(ctx, setup.GatewayServer.URL, secondQuery, nil, t) + assert.Equal(t, `{"data":{"topProducts":[{"name":"Trilby","reviews":[{"body":"A highly effective form of birth control.","author":{"username":"Me"}}]},{"name":"Fedora","reviews":[{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","author":{"username":"Me"}}]}]}}`, string(resp)) + + logAfterSecond := defaultCache.GetLog() + assert.Equal(t, 4, len(logAfterSecond)) + + wantLogSecond := []CacheLogEntry{ + { + Operation: "get", + Keys: []string{`{"__typename":"Query","field":"topProducts"}`}, + Hits: []bool{true}, // Should be a hit from first query + }, + { + Operation: "set", + Keys: []string{`{"__typename":"Query","field":"topProducts"}`}, + }, + { + Operation: "get", + Keys: []string{ + `{"__typename":"Product","keys":{"upc":"top-1"}}`, + `{"__typename":"Product","keys":{"upc":"top-2"}}`, + }, + Hits: 
[]bool{false, false}, // Miss because second query requests different fields (reviews) + }, + { + Operation: "set", + Keys: []string{ + `{"__typename":"Product","keys":{"upc":"top-1"}}`, + `{"__typename":"Product","keys":{"upc":"top-2"}}`, + }, + }, + } + assert.Equal(t, sortCacheLogKeys(wantLogSecond), sortCacheLogKeys(logAfterSecond)) + + // Verify second query: products name is cached, but reviews still need to be fetched + productsCallsSecond := tracker.GetCount(productsHost) + reviewsCallsSecond := tracker.GetCount(reviewsHost) + accountsCallsSecond := tracker.GetCount(accountsHost) + + assert.Equal(t, 1, productsCallsSecond, "Second query calls products subgraph once (for reviews data)") + assert.Equal(t, 1, reviewsCallsSecond, "Second query calls reviews subgraph once (reviews not cached)") + assert.Equal(t, 0, accountsCallsSecond, "Second query does not call accounts subgraph") + + // Third query - repeat the second query (full fields) + defaultCache.ClearLog() + tracker.Reset() + thirdQuery := `query { + topProducts { + name + reviews { + body + author { + username + } + } + } + }` + resp = gqlClient.QueryString(ctx, setup.GatewayServer.URL, thirdQuery, nil, t) + assert.Equal(t, `{"data":{"topProducts":[{"name":"Trilby","reviews":[{"body":"A highly effective form of birth control.","author":{"username":"Me"}}]},{"name":"Fedora","reviews":[{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","author":{"username":"Me"}}]}]}}`, string(resp)) + + logAfterThird := defaultCache.GetLog() + assert.Equal(t, 4, len(logAfterThird)) + + wantLogThird := []CacheLogEntry{ + { + Operation: "get", + Keys: []string{`{"__typename":"Query","field":"topProducts"}`}, + Hits: []bool{true}, // Should be a hit from second query + }, + { + Operation: "set", + Keys: []string{`{"__typename":"Query","field":"topProducts"}`}, + }, + { + Operation: "get", + Keys: []string{ + `{"__typename":"Product","keys":{"upc":"top-1"}}`, + `{"__typename":"Product","keys":{"upc":"top-2"}}`, + }, + Hits: []bool{true, true}, // Should be hits from second query + }, + { + Operation: "set", + Keys: []string{ + `{"__typename":"Product","keys":{"upc":"top-1"}}`, + `{"__typename":"Product","keys":{"upc":"top-2"}}`, + }, + }, + } + assert.Equal(t, sortCacheLogKeys(wantLogThird), sortCacheLogKeys(logAfterThird)) + + // Verify third query: all data should be cached, no subgraph calls + productsCallsThird := tracker.GetCount(productsHost) + reviewsCallsThird := tracker.GetCount(reviewsHost) + accountsCallsThird := tracker.GetCount(accountsHost) + + // All cache entries show hits, so no subgraph calls should be made + assert.Equal(t, 0, productsCallsThird, "Third query does not call products subgraph (all cache hits)") + assert.Equal(t, 0, reviewsCallsThird, "Third query does not call reviews subgraph (all cache hits)") + assert.Equal(t, 0, accountsCallsThird, "Third query does not call accounts subgraph") + }) +} + +// subgraphCallTracker tracks HTTP requests made to subgraph servers +type subgraphCallTracker struct { + mu sync.RWMutex + counts map[string]int // Maps subgraph URL to call count + original http.RoundTripper +} + +func newSubgraphCallTracker(original http.RoundTripper) *subgraphCallTracker { + return &subgraphCallTracker{ + counts: make(map[string]int), + original: original, + } +} + +func (t *subgraphCallTracker) RoundTrip(req *http.Request) (*http.Response, error) { + t.mu.Lock() + host := req.URL.Host + t.counts[host]++ + t.mu.Unlock() + return 
t.original.RoundTrip(req) +} + +func (t *subgraphCallTracker) GetCount(url string) int { + t.mu.RLock() + defer t.mu.RUnlock() + return t.counts[url] +} + +func (t *subgraphCallTracker) Reset() { + t.mu.Lock() + defer t.mu.Unlock() + t.counts = make(map[string]int) +} + +func (t *subgraphCallTracker) GetCounts() map[string]int { + t.mu.RLock() + defer t.mu.RUnlock() + result := make(map[string]int) + for k, v := range t.counts { + result[k] = v + } + return result +} + +func (t *subgraphCallTracker) DebugPrint() string { + t.mu.RLock() + defer t.mu.RUnlock() + return fmt.Sprintf("%v", t.counts) +} + +// Helper functions for gateway setup with HTTP client support +type cachingGatewayOptions struct { + enableART bool + withLoaderCache map[string]resolve.LoaderCache + httpClient *http.Client +} + +func withCachingEnableART(enableART bool) func(*cachingGatewayOptions) { + return func(opts *cachingGatewayOptions) { + opts.enableART = enableART + } +} + +func withCachingLoaderCache(loaderCache map[string]resolve.LoaderCache) func(*cachingGatewayOptions) { + return func(opts *cachingGatewayOptions) { + opts.withLoaderCache = loaderCache + } +} + +func withHTTPClient(client *http.Client) func(*cachingGatewayOptions) { + return func(opts *cachingGatewayOptions) { + opts.httpClient = client + } +} + +type cachingGatewayOptionsToFunc func(opts *cachingGatewayOptions) + +func addCachingGateway(options ...cachingGatewayOptionsToFunc) func(setup *federationtesting.FederationSetup) *httptest.Server { + opts := &cachingGatewayOptions{} + for _, option := range options { + option(opts) + } + return func(setup *federationtesting.FederationSetup) *httptest.Server { + httpClient := opts.httpClient + if httpClient == nil { + httpClient = http.DefaultClient + } + + poller := gateway.NewDatasource([]gateway.ServiceConfig{ + {Name: "accounts", URL: setup.AccountsUpstreamServer.URL}, + {Name: "products", URL: setup.ProductsUpstreamServer.URL, WS: strings.ReplaceAll(setup.ProductsUpstreamServer.URL, "http:", "ws:")}, + {Name: "reviews", URL: setup.ReviewsUpstreamServer.URL}, + }, httpClient) + + gtw := gateway.Handler(abstractlogger.NoopLogger, poller, httpClient, opts.enableART, opts.withLoaderCache) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + poller.Run(ctx) + return httptest.NewServer(gtw) + } +} + +func cachingTestQueryPath(name string) string { + return path.Join("..", "federationtesting", "testdata", name) } type CacheLogEntry struct { @@ -120,6 +449,53 @@ func normalizeCacheLog(log []CacheLogEntry) []CacheLogEntry { return normalized } +// sortCacheLogKeys sorts the keys (and corresponding hits) in each cache log entry +// This makes comparisons order-independent when multiple keys are present +func sortCacheLogKeys(log []CacheLogEntry) []CacheLogEntry { + sorted := make([]CacheLogEntry, len(log)) + for i, entry := range log { + // Only sort if there are multiple keys + if len(entry.Keys) <= 1 { + sorted[i] = entry + continue + } + + // Create pairs of (key, hit) to sort together + pairs := make([]struct { + key string + hit bool + }, len(entry.Keys)) + for j := range entry.Keys { + pairs[j].key = entry.Keys[j] + if entry.Hits != nil && j < len(entry.Hits) { + pairs[j].hit = entry.Hits[j] + } + } + + // Sort pairs by key + sort.Slice(pairs, func(a, b int) bool { + return pairs[a].key < pairs[b].key + }) + + // Extract sorted keys and hits + sorted[i] = CacheLogEntry{ + Operation: entry.Operation, + Keys: make([]string, len(pairs)), + Hits: nil, + } + if 
entry.Hits != nil && len(entry.Hits) > 0 { + sorted[i].Hits = make([]bool, len(pairs)) + } + for j := range pairs { + sorted[i].Keys[j] = pairs[j].key + if sorted[i].Hits != nil { + sorted[i].Hits[j] = pairs[j].hit + } + } + } + return sorted +} + type cacheEntry struct { data []byte expiresAt *time.Time diff --git a/execution/engine/graphql_client_test.go b/execution/engine/graphql_client_test.go index 23ed0c6e3..40b0018ac 100644 --- a/execution/engine/graphql_client_test.go +++ b/execution/engine/graphql_client_test.go @@ -74,6 +74,22 @@ func (g *GraphqlClient) Query(ctx context.Context, addr, queryFilePath string, v return responseBodyBytes } +func (g *GraphqlClient) QueryString(ctx context.Context, addr, query string, variables queryVariables, t *testing.T) []byte { + reqBody := requestBody(t, query, variables) + req, err := http.NewRequest(http.MethodPost, addr, bytes.NewBuffer(reqBody)) + require.NoError(t, err) + req = req.WithContext(ctx) + resp, err := g.httpClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + responseBodyBytes, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Contains(t, resp.Header.Get("Content-Type"), "application/json") + + return responseBodyBytes +} + func (g *GraphqlClient) QueryStatusCode(ctx context.Context, addr, queryFilePath string, variables queryVariables, expectedStatusCode int, t *testing.T) []byte { reqBody := loadQuery(t, queryFilePath, variables) req, err := http.NewRequest(http.MethodPost, addr, bytes.NewBuffer(reqBody)) From 5ce59bae170b2325c6fb1870496a3bbd94784866 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 30 Oct 2025 17:18:40 +0100 Subject: [PATCH 57/61] chore: don't save to cache when we didn't fetch from origin --- execution/engine/federation_caching_test.go | 13 +----- v2/pkg/engine/resolve/loader.go | 52 ++++++++++++++++++++- 2 files changed, 51 insertions(+), 14 deletions(-) diff --git a/execution/engine/federation_caching_test.go b/execution/engine/federation_caching_test.go index 07da97ea8..e08f34a44 100644 --- a/execution/engine/federation_caching_test.go +++ b/execution/engine/federation_caching_test.go @@ -279,7 +279,7 @@ func TestFederationCaching(t *testing.T) { assert.Equal(t, `{"data":{"topProducts":[{"name":"Trilby","reviews":[{"body":"A highly effective form of birth control.","author":{"username":"Me"}}]},{"name":"Fedora","reviews":[{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","author":{"username":"Me"}}]}]}}`, string(resp)) logAfterThird := defaultCache.GetLog() - assert.Equal(t, 4, len(logAfterThird)) + assert.Equal(t, 2, len(logAfterThird)) wantLogThird := []CacheLogEntry{ { @@ -287,10 +287,6 @@ func TestFederationCaching(t *testing.T) { Keys: []string{`{"__typename":"Query","field":"topProducts"}`}, Hits: []bool{true}, // Should be a hit from second query }, - { - Operation: "set", - Keys: []string{`{"__typename":"Query","field":"topProducts"}`}, - }, { Operation: "get", Keys: []string{ @@ -299,13 +295,6 @@ func TestFederationCaching(t *testing.T) { }, Hits: []bool{true, true}, // Should be hits from second query }, - { - Operation: "set", - Keys: []string{ - `{"__typename":"Product","keys":{"upc":"top-1"}}`, - `{"__typename":"Product","keys":{"upc":"top-2"}}`, - }, - }, } assert.Equal(t, sortCacheLogKeys(wantLogThird), sortCacheLogKeys(logAfterThird)) diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 362eec858..19ca10da5 100644 --- 
a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -604,6 +604,7 @@ func (l *Loader) tryCacheLoadFetch(ctx context.Context, info *FetchInfo, cfg Fet if err != nil { return false, err } + res.cacheTTL = cfg.TTL missing, canSkip := l.canSkipFetch(info, res) if canSkip { res.cacheSkippedFetch = true @@ -703,12 +704,35 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson return nil } if res.cacheSkippedFetch { + // Merge cached data into items + mergedData := make([]*astjson.Value, len(res.cacheKeys)) for i, key := range res.cacheKeys { - _, _, err := astjson.MergeValues(l.jsonArena, items[i], key.FromCache) + // Merge cached data into item + merged, _, err := astjson.MergeValues(l.jsonArena, items[i], key.FromCache) if err != nil { return l.renderErrorsFailedToFetch(fetchItem, res, "invalid cache item") } + mergedData[i] = merged + } + + // Update cache with merged data to refresh TTL, even when skipping fetch + if res.cacheMustBeUpdated && len(mergedData) > 0 { + // Construct responseData from merged items for cache update + // For batch responses, create an array; for single items, use the first item + var responseData *astjson.Value + if len(mergedData) == 1 { + responseData = mergedData[0] + } else { + // Create array from merged items + responseData = astjson.ArrayValue(l.jsonArena) + for i, item := range mergedData { + responseData.SetArrayItem(l.jsonArena, i, item) + } + } + res.cacheResponseData = responseData + defer l.updateCache(res) } + return nil } if res.fetchSkipped { @@ -2173,7 +2197,31 @@ func (l *Loader) canSkipFetch(info *FetchInfo, res *result) ([]*CacheKey, bool) // Check each item and remove those that have sufficient data remaining := make([]*CacheKey, 0, len(res.cacheKeys)) for i, key := range res.cacheKeys { - if !l.validateItemHasRequiredData(key.Item, info.ProvidesData) { + // When we have cached data, we should check if merging Item + FromCache gives us all required fields + // Otherwise, check Item. 
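+ //
+ // Illustrative example (values are hypothetical): with key.Item holding
+ // {"user":{"id":"1"}} and key.FromCache holding {"user":{"name":"Ada"}},
+ // the merged view {"user":{"id":"1","name":"Ada"}} is what gets validated
+ // against info.ProvidesData below. Only keys whose merged view still lacks
+ // a required field stay in the remaining set and trigger an origin fetch.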
+ var dataToCheck *astjson.Value + if key.FromCache != nil { + // If we have cached data, merge it with Item to get the complete picture + if key.Item != nil { + // Create a temporary merged value to check + // Note: We use a temporary arena here since we're just checking, not storing + merged, _, err := astjson.MergeValues(nil, key.Item, key.FromCache) + if err == nil && merged != nil { + dataToCheck = merged + } else { + // Fallback to FromCache if merge fails + dataToCheck = key.FromCache + } + } else { + dataToCheck = key.FromCache + } + } else { + dataToCheck = key.Item + } + + hasRequiredData := l.validateItemHasRequiredData(dataToCheck, info.ProvidesData) + + if !hasRequiredData { remaining = append(remaining, res.cacheKeys[i]) } } From 69937611e714842cfb9ae0077c8919470c2f49b1 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 30 Oct 2025 17:21:04 +0100 Subject: [PATCH 58/61] chore: lint --- v2/pkg/engine/resolve/caching.go | 1 + v2/pkg/engine/resolve/caching_test.go | 1 + v2/pkg/engine/resolve/fetch.go | 1 + v2/pkg/engine/resolve/loader.go | 4 ++-- v2/pkg/engine/resolve/loader_skip_fetch_test.go | 2 ++ v2/pkg/engine/resolve/resolve.go | 2 +- 6 files changed, 8 insertions(+), 3 deletions(-) diff --git a/v2/pkg/engine/resolve/caching.go b/v2/pkg/engine/resolve/caching.go index af74fc217..fcfbc45eb 100644 --- a/v2/pkg/engine/resolve/caching.go +++ b/v2/pkg/engine/resolve/caching.go @@ -3,6 +3,7 @@ package resolve import ( "github.com/wundergraph/astjson" "github.com/wundergraph/go-arena" + "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/unsafebytes" ) diff --git a/v2/pkg/engine/resolve/caching_test.go b/v2/pkg/engine/resolve/caching_test.go index d980f1078..c09a0dcbf 100644 --- a/v2/pkg/engine/resolve/caching_test.go +++ b/v2/pkg/engine/resolve/caching_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/wundergraph/astjson" "github.com/wundergraph/go-arena" ) diff --git a/v2/pkg/engine/resolve/fetch.go b/v2/pkg/engine/resolve/fetch.go index dd5292e17..59c4c7c7a 100644 --- a/v2/pkg/engine/resolve/fetch.go +++ b/v2/pkg/engine/resolve/fetch.go @@ -201,6 +201,7 @@ func (*BatchEntityFetch) FetchKind() FetchKind { // representations variable will contain single item type EntityFetch struct { FetchDependencies + CoordinateDependencies []FetchDependency Input EntityInput DataSource DataSource diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 19ca10da5..853ffe871 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -518,7 +518,7 @@ func cacheKeysToEntries(cacheKeys []*CacheKey, responseData *astjson.Value, json // Check if responseData is an array responseArray := responseData.GetArray() - if responseArray != nil && len(responseArray) > 1 { + if len(responseArray) > 1 { // Multiple items: extract per-item data from batch response if len(responseArray) != len(cacheKeys) { return nil, errors.Errorf("cache key count (%d) doesn't match response array length (%d)", len(cacheKeys), len(responseArray)) @@ -542,7 +542,7 @@ func cacheKeysToEntries(cacheKeys []*CacheKey, responseData *astjson.Value, json // Single item: store same value under all keys // This handles both single object and single-item array cases var dataToStore *astjson.Value - if responseArray != nil && len(responseArray) == 1 { + if len(responseArray) == 1 { dataToStore = responseArray[0] } else { dataToStore = responseData diff --git a/v2/pkg/engine/resolve/loader_skip_fetch_test.go 
b/v2/pkg/engine/resolve/loader_skip_fetch_test.go index 0afa54931..31f41adb5 100644 --- a/v2/pkg/engine/resolve/loader_skip_fetch_test.go +++ b/v2/pkg/engine/resolve/loader_skip_fetch_test.go @@ -4,7 +4,9 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/wundergraph/astjson" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" ) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 971cc4e43..f2323ad18 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -277,7 +277,7 @@ func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{ validateRequiredExternalFields: options.ValidateRequiredExternalFields, sf: sf, jsonArena: a, - caches: options.Caches, + caches: options.Caches, }, } } From ca8a003503234084b0e40ee9dee5535b063971fc Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Fri, 31 Oct 2025 19:45:52 +0100 Subject: [PATCH 59/61] chore: refactor key handling --- execution/engine/execution_engine.go | 6 + execution/engine/federation_caching_test.go | 202 +++++++++-- .../engine/federation_integration_test.go | 2 +- .../federationtesting/gateway/http/handler.go | 24 +- .../federationtesting/gateway/http/http.go | 4 + execution/federationtesting/gateway/main.go | 3 +- .../graphql_datasource_federation_test.go | 14 +- .../graphql_datasource_test.go | 7 +- v2/pkg/engine/plan/visitor.go | 3 +- v2/pkg/engine/resolve/caching.go | 49 +-- v2/pkg/engine/resolve/caching_test.go | 305 +++++++++-------- v2/pkg/engine/resolve/fetch.go | 4 + v2/pkg/engine/resolve/loader.go | 320 +++++++----------- .../engine/resolve/loader_skip_fetch_test.go | 107 ++---- 14 files changed, 563 insertions(+), 487 deletions(-) diff --git a/execution/engine/execution_engine.go b/execution/engine/execution_engine.go index 51b274282..ff77b855d 100644 --- a/execution/engine/execution_engine.go +++ b/execution/engine/execution_engine.go @@ -109,6 +109,12 @@ func WithRequestTraceOptions(options resolve.TraceOptions) ExecutionOptions { } } +func WithSubgraphHeadersBuilder(builder resolve.SubgraphHeadersBuilder) ExecutionOptions { + return func(ctx *internalExecutionContext) { + ctx.resolveContext.SubgraphHeadersBuilder = builder + } +} + func NewExecutionEngine(ctx context.Context, logger abstractlogger.Logger, engineConfig Configuration, resolverOptions resolve.ResolverOptions) (*ExecutionEngine, error) { executionPlanCache, err := lru.New(1024) if err != nil { diff --git a/execution/engine/federation_caching_test.go b/execution/engine/federation_caching_test.go index e08f34a44..4d8508372 100644 --- a/execution/engine/federation_caching_test.go +++ b/execution/engine/federation_caching_test.go @@ -70,16 +70,16 @@ func TestFederationCaching(t *testing.T) { { Operation: "get", Keys: []string{ - `{"__typename":"Product","keys":{"upc":"top-1"}}`, - `{"__typename":"Product","keys":{"upc":"top-2"}}`, + `{"__typename":"Product","key":{"upc":"top-1"}}`, + `{"__typename":"Product","key":{"upc":"top-2"}}`, }, Hits: []bool{false, false}, }, { Operation: "set", Keys: []string{ - `{"__typename":"Product","keys":{"upc":"top-1"}}`, - `{"__typename":"Product","keys":{"upc":"top-2"}}`, + `{"__typename":"Product","key":{"upc":"top-1"}}`, + `{"__typename":"Product","key":{"upc":"top-2"}}`, }, }, } @@ -103,7 +103,7 @@ func TestFederationCaching(t *testing.T) { assert.Equal(t, `{"data":{"topProducts":[{"name":"Trilby","reviews":[{"body":"A highly effective form of birth 
control.","author":{"username":"Me"}}]},{"name":"Fedora","reviews":[{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","author":{"username":"Me"}}]}]}}`, string(resp)) logAfterSecond := defaultCache.GetLog() - assert.Equal(t, 4, len(logAfterSecond)) + assert.Equal(t, 2, len(logAfterSecond)) wantLogSecond := []CacheLogEntry{ { @@ -111,25 +111,14 @@ func TestFederationCaching(t *testing.T) { Keys: []string{`{"__typename":"Query","field":"topProducts"}`}, Hits: []bool{true}, // Should be a hit now }, - { - Operation: "set", - Keys: []string{`{"__typename":"Query","field":"topProducts"}`}, - }, { Operation: "get", Keys: []string{ - `{"__typename":"Product","keys":{"upc":"top-1"}}`, - `{"__typename":"Product","keys":{"upc":"top-2"}}`, + `{"__typename":"Product","key":{"upc":"top-1"}}`, + `{"__typename":"Product","key":{"upc":"top-2"}}`, }, Hits: []bool{true, true}, // Should be hits now, no misses }, - { - Operation: "set", - Keys: []string{ - `{"__typename":"Product","keys":{"upc":"top-1"}}`, - `{"__typename":"Product","keys":{"upc":"top-2"}}`, - }, - }, } assert.Equal(t, sortCacheLogKeys(wantLogSecond), sortCacheLogKeys(logAfterSecond)) @@ -138,8 +127,8 @@ func TestFederationCaching(t *testing.T) { reviewsCallsSecond := tracker.GetCount(reviewsHost) accountsCallsSecond := tracker.GetCount(accountsHost) - assert.Equal(t, 1, productsCallsSecond, "Second query should hit cache and not call products subgraph again") - assert.Equal(t, 1, reviewsCallsSecond, "Second query should hit cache and not call reviews subgraph again") + assert.Equal(t, 0, productsCallsSecond, "Second query should hit cache and not call products subgraph again") + assert.Equal(t, 0, reviewsCallsSecond, "Second query should hit cache and not call reviews subgraph again") assert.Equal(t, 0, accountsCallsSecond, "accounts not involved") }) @@ -237,16 +226,16 @@ func TestFederationCaching(t *testing.T) { { Operation: "get", Keys: []string{ - `{"__typename":"Product","keys":{"upc":"top-1"}}`, - `{"__typename":"Product","keys":{"upc":"top-2"}}`, + `{"__typename":"Product","key":{"upc":"top-1"}}`, + `{"__typename":"Product","key":{"upc":"top-2"}}`, }, Hits: []bool{false, false}, // Miss because second query requests different fields (reviews) }, { Operation: "set", Keys: []string{ - `{"__typename":"Product","keys":{"upc":"top-1"}}`, - `{"__typename":"Product","keys":{"upc":"top-2"}}`, + `{"__typename":"Product","key":{"upc":"top-1"}}`, + `{"__typename":"Product","key":{"upc":"top-2"}}`, }, }, } @@ -290,8 +279,8 @@ func TestFederationCaching(t *testing.T) { { Operation: "get", Keys: []string{ - `{"__typename":"Product","keys":{"upc":"top-1"}}`, - `{"__typename":"Product","keys":{"upc":"top-2"}}`, + `{"__typename":"Product","key":{"upc":"top-1"}}`, + `{"__typename":"Product","key":{"upc":"top-2"}}`, }, Hits: []bool{true, true}, // Should be hits from second query }, @@ -308,6 +297,128 @@ func TestFederationCaching(t *testing.T) { assert.Equal(t, 0, reviewsCallsThird, "Third query does not call reviews subgraph (all cache hits)") assert.Equal(t, 0, accountsCallsThird, "Third query does not call accounts subgraph") }) + + t.Run("two subgraphs - with subgraph header prefix", func(t *testing.T) { + defaultCache := NewFakeLoaderCache() + caches := map[string]resolve.LoaderCache{ + "default": defaultCache, + } + + // Create HTTP client with tracking + tracker := newSubgraphCallTracker(http.DefaultTransport) + trackingClient := &http.Client{ + Transport: tracker, + } + + // Create 
mock SubgraphHeadersBuilder that returns a fixed hash for each subgraph + mockHeadersBuilder := &mockSubgraphHeadersBuilder{ + hashes: map[string]uint64{ + "1": 11111, + "2": 22222, + "3": 33333, + }, + } + + setup := federationtesting.NewFederationSetup(addCachingGateway( + withCachingEnableART(false), + withCachingLoaderCache(caches), + withHTTPClient(trackingClient), + withSubgraphHeadersBuilder(mockHeadersBuilder), + )) + t.Cleanup(setup.Close) + gqlClient := NewGraphqlClient(http.DefaultClient) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + // Extract hostnames for tracking (URL.Host includes host:port) + accountsURLParsed, _ := url.Parse(setup.AccountsUpstreamServer.URL) + productsURLParsed, _ := url.Parse(setup.ProductsUpstreamServer.URL) + reviewsURLParsed, _ := url.Parse(setup.ReviewsUpstreamServer.URL) + accountsHost := accountsURLParsed.Host + productsHost := productsURLParsed.Host + reviewsHost := reviewsURLParsed.Host + + // First query - should miss cache and then set with prefixed keys + defaultCache.ClearLog() + tracker.Reset() + resp := gqlClient.Query(ctx, setup.GatewayServer.URL, cachingTestQueryPath("queries/multiple_upstream.query"), nil, t) + assert.Equal(t, `{"data":{"topProducts":[{"name":"Trilby","reviews":[{"body":"A highly effective form of birth control.","author":{"username":"Me"}}]},{"name":"Fedora","reviews":[{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","author":{"username":"Me"}}]}]}}`, string(resp)) + + logAfterFirst := defaultCache.GetLog() + assert.Equal(t, 4, len(logAfterFirst)) + + wantLog := []CacheLogEntry{ + { + Operation: "get", + Keys: []string{`11111:{"__typename":"Query","field":"topProducts"}`}, + Hits: []bool{false}, + }, + { + Operation: "set", + Keys: []string{`11111:{"__typename":"Query","field":"topProducts"}`}, + }, + { + Operation: "get", + Keys: []string{ + `22222:{"__typename":"Product","key":{"upc":"top-1"}}`, + `22222:{"__typename":"Product","key":{"upc":"top-2"}}`, + }, + Hits: []bool{false, false}, + }, + { + Operation: "set", + Keys: []string{ + `22222:{"__typename":"Product","key":{"upc":"top-1"}}`, + `22222:{"__typename":"Product","key":{"upc":"top-2"}}`, + }, + }, + } + assert.Equal(t, sortCacheLogKeys(wantLog), sortCacheLogKeys(logAfterFirst)) + + // Verify subgraph calls for first query + productsCallsFirst := tracker.GetCount(productsHost) + reviewsCallsFirst := tracker.GetCount(reviewsHost) + accountsCallsFirst := tracker.GetCount(accountsHost) + + assert.Equal(t, 1, productsCallsFirst, "First query should call products subgraph exactly once") + assert.Equal(t, 1, reviewsCallsFirst, "First query should call reviews subgraph exactly once") + assert.Equal(t, 0, accountsCallsFirst, "First query should not call accounts subgraph") + + // Second query - should hit cache with prefixed keys + defaultCache.ClearLog() + tracker.Reset() + resp = gqlClient.Query(ctx, setup.GatewayServer.URL, cachingTestQueryPath("queries/multiple_upstream.query"), nil, t) + assert.Equal(t, `{"data":{"topProducts":[{"name":"Trilby","reviews":[{"body":"A highly effective form of birth control.","author":{"username":"Me"}}]},{"name":"Fedora","reviews":[{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","author":{"username":"Me"}}]}]}}`, string(resp)) + + logAfterSecond := defaultCache.GetLog() + assert.Equal(t, 2, len(logAfterSecond)) + + wantLogSecond := []CacheLogEntry{ + { + Operation: "get", + Keys: 
[]string{`11111:{"__typename":"Query","field":"topProducts"}`}, + Hits: []bool{true}, // Should be a hit now + }, + { + Operation: "get", + Keys: []string{ + `22222:{"__typename":"Product","key":{"upc":"top-1"}}`, + `22222:{"__typename":"Product","key":{"upc":"top-2"}}`, + }, + Hits: []bool{true, true}, // Should be hits now + }, + } + assert.Equal(t, sortCacheLogKeys(wantLogSecond), sortCacheLogKeys(logAfterSecond)) + + // Verify subgraph calls for second query + productsCallsSecond := tracker.GetCount(productsHost) + reviewsCallsSecond := tracker.GetCount(reviewsHost) + accountsCallsSecond := tracker.GetCount(accountsHost) + + assert.Equal(t, 0, productsCallsSecond, "Second query should hit cache and not call products subgraph again") + assert.Equal(t, 0, reviewsCallsSecond, "Second query should hit cache and not call reviews subgraph again") + assert.Equal(t, 0, accountsCallsSecond, "accounts not involved") + }) } // subgraphCallTracker tracks HTTP requests made to subgraph servers @@ -362,9 +473,10 @@ func (t *subgraphCallTracker) DebugPrint() string { // Helper functions for gateway setup with HTTP client support type cachingGatewayOptions struct { - enableART bool - withLoaderCache map[string]resolve.LoaderCache - httpClient *http.Client + enableART bool + withLoaderCache map[string]resolve.LoaderCache + httpClient *http.Client + subgraphHeadersBuilder resolve.SubgraphHeadersBuilder } func withCachingEnableART(enableART bool) func(*cachingGatewayOptions) { @@ -385,6 +497,12 @@ func withHTTPClient(client *http.Client) func(*cachingGatewayOptions) { } } +func withSubgraphHeadersBuilder(builder resolve.SubgraphHeadersBuilder) func(*cachingGatewayOptions) { + return func(opts *cachingGatewayOptions) { + opts.subgraphHeadersBuilder = builder + } +} + type cachingGatewayOptionsToFunc func(opts *cachingGatewayOptions) func addCachingGateway(options ...cachingGatewayOptionsToFunc) func(setup *federationtesting.FederationSetup) *httptest.Server { @@ -404,7 +522,7 @@ func addCachingGateway(options ...cachingGatewayOptionsToFunc) func(setup *feder {Name: "reviews", URL: setup.ReviewsUpstreamServer.URL}, }, httpClient) - gtw := gateway.Handler(abstractlogger.NoopLogger, poller, httpClient, opts.enableART, opts.withLoaderCache) + gtw := gateway.Handler(abstractlogger.NoopLogger, poller, httpClient, opts.enableART, opts.withLoaderCache, opts.subgraphHeadersBuilder) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() @@ -414,6 +532,30 @@ func addCachingGateway(options ...cachingGatewayOptionsToFunc) func(setup *feder } } +// mockSubgraphHeadersBuilder is a mock implementation of SubgraphHeadersBuilder +type mockSubgraphHeadersBuilder struct { + hashes map[string]uint64 +} + +func (m *mockSubgraphHeadersBuilder) HeadersForSubgraph(subgraphName string) (http.Header, uint64) { + hash := m.hashes[subgraphName] + if hash == 0 { + // Return default hash if not found - this helps debug what names are being requested + // Note: This will cause test failures if subgraph names don't match + return nil, 99999 + } + return nil, hash +} + +func (m *mockSubgraphHeadersBuilder) HashAll() uint64 { + // Return a simple hash of all subgraph hashes combined + var result uint64 + for _, hash := range m.hashes { + result ^= hash + } + return result +} + func cachingTestQueryPath(name string) string { return path.Join("..", "federationtesting", "testdata", name) } diff --git a/execution/engine/federation_integration_test.go b/execution/engine/federation_integration_test.go index 
4b0f702a1..e93231f21 100644 --- a/execution/engine/federation_integration_test.go +++ b/execution/engine/federation_integration_test.go @@ -58,7 +58,7 @@ func addGateway(options ...gatewayOptionsToFunc) func(setup *federationtesting.F {Name: "reviews", URL: setup.ReviewsUpstreamServer.URL}, }, httpClient) - gtw := gateway.Handler(abstractlogger.NoopLogger, poller, httpClient, opts.enableART, opts.withLoaderCache) + gtw := gateway.Handler(abstractlogger.NoopLogger, poller, httpClient, opts.enableART, opts.withLoaderCache, nil) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() diff --git a/execution/federationtesting/gateway/http/handler.go b/execution/federationtesting/gateway/http/handler.go index e6d575cd7..2e8983395 100644 --- a/execution/federationtesting/gateway/http/handler.go +++ b/execution/federationtesting/gateway/http/handler.go @@ -5,6 +5,7 @@ import ( "github.com/gobwas/ws" log "github.com/jensneuse/abstractlogger" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "github.com/wundergraph/graphql-go-tools/execution/engine" "github.com/wundergraph/graphql-go-tools/execution/graphql" @@ -20,22 +21,25 @@ func NewGraphqlHTTPHandler( upgrader *ws.HTTPUpgrader, logger log.Logger, enableART bool, + subgraphHeadersBuilder resolve.SubgraphHeadersBuilder, ) http.Handler { return &GraphQLHTTPRequestHandler{ - schema: schema, - engine: engine, - wsUpgrader: upgrader, - log: logger, - enableART: enableART, + schema: schema, + engine: engine, + wsUpgrader: upgrader, + log: logger, + enableART: enableART, + subgraphHeadersBuilder: subgraphHeadersBuilder, } } type GraphQLHTTPRequestHandler struct { - log log.Logger - wsUpgrader *ws.HTTPUpgrader - engine *engine.ExecutionEngine - schema *graphql.Schema - enableART bool + log log.Logger + wsUpgrader *ws.HTTPUpgrader + engine *engine.ExecutionEngine + schema *graphql.Schema + enableART bool + subgraphHeadersBuilder resolve.SubgraphHeadersBuilder } func (g *GraphQLHTTPRequestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { diff --git a/execution/federationtesting/gateway/http/http.go b/execution/federationtesting/gateway/http/http.go index 5a255e01c..0d0a50e3f 100644 --- a/execution/federationtesting/gateway/http/http.go +++ b/execution/federationtesting/gateway/http/http.go @@ -45,6 +45,10 @@ func (g *GraphQLHTTPRequestHandler) handleHTTP(w http.ResponseWriter, r *http.Re opts = append(opts, engine.WithRequestTraceOptions(tracingOpts)) } + if g.subgraphHeadersBuilder != nil { + opts = append(opts, engine.WithSubgraphHeadersBuilder(g.subgraphHeadersBuilder)) + } + buf := bytes.NewBuffer(make([]byte, 0, 4096)) resultWriter := graphql.NewEngineResultWriterFromBuffer(buf) if err = g.engine.Execute(r.Context(), &gqlRequest, &resultWriter, opts...); err != nil { diff --git a/execution/federationtesting/gateway/main.go b/execution/federationtesting/gateway/main.go index 39da34d0f..dddfb372c 100644 --- a/execution/federationtesting/gateway/main.go +++ b/execution/federationtesting/gateway/main.go @@ -26,6 +26,7 @@ func Handler( httpClient *http.Client, enableART bool, loaderCaches map[string]resolve.LoaderCache, + subgraphHeadersBuilder resolve.SubgraphHeadersBuilder, ) *Gateway { upgrader := &ws.DefaultHTTPUpgrader upgrader.Header = http.Header{} @@ -34,7 +35,7 @@ func Handler( datasourceWatcher := datasourcePoller var gqlHandlerFactory HandlerFactoryFn = func(schema *graphql.Schema, engine *engine.ExecutionEngine) http.Handler { - return http2.NewGraphqlHTTPHandler(schema, engine, 
upgrader, logger, enableART) + return http2.NewGraphqlHTTPHandler(schema, engine, upgrader, logger, enableART, subgraphHeadersBuilder) } gateway := NewGateway(gqlHandlerFactory, httpClient, logger, loaderCaches) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go index 73990c556..01c4af8f5 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go @@ -1559,9 +1559,10 @@ func TestGraphQLDataSourceFederation(t *testing.T) { DataSource: &Source{}, PostProcessing: DefaultPostProcessingConfiguration, Caching: resolve.FetchCacheConfiguration{ - Enabled: true, - CacheName: "default", - TTL: 30 * time.Second, + Enabled: true, + CacheName: "default", + TTL: 30 * time.Second, + IncludeSubgraphHeaderPrefix: true, CacheKeyTemplate: &resolve.RootQueryCacheKeyTemplate{ RootFields: []resolve.QueryField{ { @@ -1849,9 +1850,10 @@ func TestGraphQLDataSourceFederation(t *testing.T) { }, PostProcessing: SingleEntityPostProcessingConfiguration, Caching: resolve.FetchCacheConfiguration{ - Enabled: true, - CacheName: "default", - TTL: time.Second * 30, + Enabled: true, + CacheName: "default", + TTL: time.Second * 30, + IncludeSubgraphHeaderPrefix: true, CacheKeyTemplate: &resolve.EntityQueryCacheKeyTemplate{ Keys: resolve.NewResolvableObjectVariable(&resolve.Object{ Nullable: true, diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go index cccc11f3c..6e8850ef5 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go @@ -398,9 +398,10 @@ func TestGraphQLDataSource(t *testing.T) { ), PostProcessing: DefaultPostProcessingConfiguration, Caching: resolve.FetchCacheConfiguration{ - Enabled: true, - CacheName: "default", - TTL: 30 * time.Second, + Enabled: true, + CacheName: "default", + TTL: 30 * time.Second, + IncludeSubgraphHeaderPrefix: true, CacheKeyTemplate: &resolve.RootQueryCacheKeyTemplate{ RootFields: []resolve.QueryField{ { diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index 71da3b87e..1f189558b 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -1650,7 +1650,8 @@ func (v *Visitor) configureFetch(internal *objectFetchConfiguration, external re CacheName: "default", TTL: time.Second * time.Duration(30), // templates come prepared from the DataSource - CacheKeyTemplate: external.Caching.CacheKeyTemplate, + CacheKeyTemplate: external.Caching.CacheKeyTemplate, + IncludeSubgraphHeaderPrefix: true, } } else { external.Caching = resolve.FetchCacheConfiguration{ diff --git a/v2/pkg/engine/resolve/caching.go b/v2/pkg/engine/resolve/caching.go index fcfbc45eb..10566075a 100644 --- a/v2/pkg/engine/resolve/caching.go +++ b/v2/pkg/engine/resolve/caching.go @@ -10,18 +10,13 @@ import ( type CacheKeyTemplate interface { // RenderCacheKeys returns multiple cache keys (one per root field or entity) // Generates keys for all items at once - RenderCacheKeys(a arena.Arena, ctx *Context, items []*astjson.Value) ([]*CacheKey, error) + RenderCacheKeys(a arena.Arena, ctx *Context, items []*astjson.Value, prefix string) ([]*CacheKey, error) } type CacheKey struct { Item *astjson.Value FromCache *astjson.Value - Keys 
[]KeyEntry -} - -type KeyEntry struct { - Name string - Path string + Keys []string } type RootQueryCacheKeyTemplate struct { @@ -40,7 +35,7 @@ type FieldArgument struct { // RenderCacheKeys returns multiple cache keys, one per item // Each cache key contains one or more KeyEntry objects (one per root field) -func (r *RootQueryCacheKeyTemplate) RenderCacheKeys(a arena.Arena, ctx *Context, items []*astjson.Value) ([]*CacheKey, error) { +func (r *RootQueryCacheKeyTemplate) RenderCacheKeys(a arena.Arena, ctx *Context, items []*astjson.Value, prefix string) ([]*CacheKey, error) { if len(r.RootFields) == 0 { return nil, nil } @@ -50,14 +45,19 @@ func (r *RootQueryCacheKeyTemplate) RenderCacheKeys(a arena.Arena, ctx *Context, for _, item := range items { // Create KeyEntry for each root field - keyEntries := arena.AllocateSlice[KeyEntry](a, 0, len(r.RootFields)) + keyEntries := arena.AllocateSlice[string](a, 0, len(r.RootFields)) for _, field := range r.RootFields { var key string key, jsonBytes = r.renderField(a, ctx, item, jsonBytes, field) - keyEntries = arena.SliceAppend(a, keyEntries, KeyEntry{ - Name: key, - Path: field.Coordinate.FieldName, - }) + if prefix != "" { + l := len(prefix) + 1 + len(key) + tmp := arena.AllocateSlice[byte](a, 0, l) + tmp = arena.SliceAppend(a, tmp, unsafebytes.StringToBytes(prefix)...) + tmp = arena.SliceAppend(a, tmp, []byte(`:`)...) + tmp = arena.SliceAppend(a, tmp, unsafebytes.StringToBytes(key)...) + key = unsafebytes.BytesToString(tmp) + } + keyEntries = arena.SliceAppend(a, keyEntries, key) } cacheKeys = arena.SliceAppend(a, cacheKeys, &CacheKey{ Item: item, @@ -136,7 +136,7 @@ type EntityQueryCacheKeyTemplate struct { } // RenderCacheKeys returns one cache key per item for entity queries with keys nested under "keys" -func (e *EntityQueryCacheKeyTemplate) RenderCacheKeys(a arena.Arena, ctx *Context, items []*astjson.Value) ([]*CacheKey, error) { +func (e *EntityQueryCacheKeyTemplate) RenderCacheKeys(a arena.Arena, ctx *Context, items []*astjson.Value, prefix string) ([]*CacheKey, error) { jsonBytes := arena.AllocateSlice[byte](a, 0, 64) cacheKeys := arena.AllocateSlice[*CacheKey](a, 0, len(items)) @@ -178,19 +178,24 @@ func (e *EntityQueryCacheKeyTemplate) RenderCacheKeys(a arena.Arena, ctx *Contex } } - keyObj.Set(a, "keys", keysObj) + keyObj.Set(a, "key", keysObj) // Marshal to JSON and write to buffer jsonBytes = keyObj.MarshalTo(jsonBytes[:0]) - slice := arena.AllocateSlice[byte](a, len(jsonBytes), len(jsonBytes)) - copy(slice, jsonBytes) + l := len(jsonBytes) + if prefix != "" { + l += 1 + len(prefix) + } + slice := arena.AllocateSlice[byte](a, 0, l) + if prefix != "" { + slice = arena.SliceAppend(a, slice, unsafebytes.StringToBytes(prefix)...) + slice = arena.SliceAppend(a, slice, []byte(`:`)...) + } + slice = arena.SliceAppend(a, slice, jsonBytes...) 
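+		// the final key layout is "<prefix>:<json>", e.g.
+		// 22222:{"__typename":"Product","key":{"upc":"top-1"}} as asserted in the caching tests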
// Create KeyEntry with empty path for entity queries - keyEntries := arena.AllocateSlice[KeyEntry](a, 0, 1) - keyEntries = arena.SliceAppend(a, keyEntries, KeyEntry{ - Name: unsafebytes.BytesToString(slice), - Path: "", - }) + keyEntries := arena.AllocateSlice[string](a, 0, 1) + keyEntries = arena.SliceAppend(a, keyEntries, unsafebytes.BytesToString(slice)) cacheKeys = arena.SliceAppend(a, cacheKeys, &CacheKey{ Item: item, diff --git a/v2/pkg/engine/resolve/caching_test.go b/v2/pkg/engine/resolve/caching_test.go index c09a0dcbf..f382f58f3 100644 --- a/v2/pkg/engine/resolve/caching_test.go +++ b/v2/pkg/engine/resolve/caching_test.go @@ -29,17 +29,12 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{}`) - cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") assert.NoError(t, err) expected := []*CacheKey{ { Item: data, - Keys: []KeyEntry{ - { - Name: `{"__typename":"Query","field":"users"}`, - Path: "users", - }, - }, + Keys: []string{`{"__typename":"Query","field":"users"}`}, }, } assert.Equal(t, expected, cacheKeys) @@ -71,17 +66,12 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{}`) - cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") assert.NoError(t, err) expected := []*CacheKey{ { Item: data, - Keys: []KeyEntry{ - { - Name: `{"__typename":"Query","field":"droid","args":{"id":1}}`, - Path: "droid", - }, - }, + Keys: []string{`{"__typename":"Query","field":"droid","args":{"id":1}}`}, }, } assert.Equal(t, expected, cacheKeys) @@ -113,17 +103,12 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{}`) - cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") assert.NoError(t, err) expected := []*CacheKey{ { Item: data, - Keys: []KeyEntry{ - { - Name: `{"__typename":"Query","field":"user","args":{"name":"john"}}`, - Path: "user", - }, - }, + Keys: []string{`{"__typename":"Query","field":"user","args":{"name":"john"}}`}, }, } assert.Equal(t, expected, cacheKeys) @@ -162,17 +147,12 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{}`) - cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") assert.NoError(t, err) expected := []*CacheKey{ { Item: data, - Keys: []KeyEntry{ - { - Name: `{"__typename":"Query","field":"search","args":{"term":"C3PO","max":10}}`, - Path: "search", - }, - }, + Keys: []string{`{"__typename":"Query","field":"search","args":{"term":"C3PO","max":10}}`}, }, } assert.Equal(t, expected, cacheKeys) @@ -211,17 +191,12 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{}`) - cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") assert.NoError(t, err) expected := []*CacheKey{ { Item: data, - Keys: []KeyEntry{ - { - Name: `{"__typename":"Query","field":"products","args":{"includeDeleted":true,"limit":20}}`, - Path: "products", - }, - }, 
+ Keys: []string{`{"__typename":"Query","field":"products","args":{"includeDeleted":true,"limit":20}}`}, }, } assert.Equal(t, expected, cacheKeys) @@ -269,20 +244,14 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { } data := astjson.MustParse(`{}`) - cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") assert.NoError(t, err) expected := []*CacheKey{ { Item: data, - Keys: []KeyEntry{ - { - Name: `{"__typename":"Query","field":"droid","args":{"id":1}}`, - Path: "droid", - }, - { - Name: `{"__typename":"Query","field":"user","args":{"name":"john"}}`, - Path: "user", - }, + Keys: []string{ + `{"__typename":"Query","field":"droid","args":{"id":1}}`, + `{"__typename":"Query","field":"user","args":{"name":"john"}}`, }, }, } @@ -330,20 +299,14 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { } data := astjson.MustParse(`{}`) - cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") assert.NoError(t, err) expected := []*CacheKey{ { Item: data, - Keys: []KeyEntry{ - { - Name: `{"__typename":"Query","field":"product","args":{"id":"123","includeReviews":true}}`, - Path: "product", - }, - { - Name: `{"__typename":"Query","field":"hero"}`, - Path: "hero", - }, + Keys: []string{ + `{"__typename":"Query","field":"product","args":{"id":"123","includeReviews":true}}`, + `{"__typename":"Query","field":"hero"}`, }, }, } @@ -376,17 +339,12 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{"filter":{"category":"electronics","price":100}}`) - cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") assert.NoError(t, err) expected := []*CacheKey{ { Item: data, - Keys: []KeyEntry{ - { - Name: `{"__typename":"Query","field":"search","args":{"filter":{"category":"electronics","price":100}}}`, - Path: "search", - }, - }, + Keys: []string{`{"__typename":"Query","field":"search","args":{"filter":{"category":"electronics","price":100}}}`}, }, } assert.Equal(t, expected, cacheKeys) @@ -418,17 +376,12 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{}`) - cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") assert.NoError(t, err) expected := []*CacheKey{ { Item: data, - Keys: []KeyEntry{ - { - Name: `{"__typename":"Query","field":"user","args":{"id":null}}`, - Path: "user", - }, - }, + Keys: []string{`{"__typename":"Query","field":"user","args":{"id":null}}`}, }, } assert.Equal(t, expected, cacheKeys) @@ -460,17 +413,12 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{}`) - cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") assert.NoError(t, err) expected := []*CacheKey{ { Item: data, - Keys: []KeyEntry{ - { - Name: `{"__typename":"Query","field":"user","args":{"id":null}}`, - Path: "user", - }, - }, + Keys: []string{`{"__typename":"Query","field":"user","args":{"id":null}}`}, }, } assert.Equal(t, expected, cacheKeys) @@ -502,17 +450,12 @@ func 
TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{}`) - cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") assert.NoError(t, err) expected := []*CacheKey{ { Item: data, - Keys: []KeyEntry{ - { - Name: `{"__typename":"Query","field":"products","args":{"ids":[1,2,3]}}`, - Path: "products", - }, - }, + Keys: []string{`{"__typename":"Query","field":"products","args":{"ids":[1,2,3]}}`}, }, } assert.Equal(t, expected, cacheKeys) @@ -544,17 +487,12 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{}`) - cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") assert.NoError(t, err) expected := []*CacheKey{ { Item: data, - Keys: []KeyEntry{ - { - Name: `{"__typename":"Subscription","field":"messageAdded","args":{"roomId":"123"}}`, - Path: "messageAdded", - }, - }, + Keys: []string{`{"__typename":"Subscription","field":"messageAdded","args":{"roomId":"123"}}`}, }, } assert.Equal(t, expected, cacheKeys) @@ -587,19 +525,106 @@ func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{}`) - cacheKeys, err := tmpl.RenderCacheKeys(ar, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(ar, ctx, []*astjson.Value{data}, "") assert.NoError(t, err) expected := []*CacheKey{ { Item: data, - Keys: []KeyEntry{ - { - Name: `{"__typename":"Query","field":"user","args":{"name":"john"}}`, - Path: "user", + Keys: []string{`{"__typename":"Query","field":"user","args":{"name":"john"}}`}, + }, + } + assert.Equal(t, expected, cacheKeys) + }) + + t.Run("single field with prefix", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "user", + }, + Args: []FieldArgument{ + { + Name: "id", + Variable: &ContextVariable{ + Path: []string{"id"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + ctx := &Context{ + Variables: astjson.MustParse(`{"id":1}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{}`) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "prefix") + assert.NoError(t, err) + expected := []*CacheKey{ + { + Item: data, + Keys: []string{`prefix:{"__typename":"Query","field":"user","args":{"id":1}}`}, + }, + } + assert.Equal(t, expected, cacheKeys) + }) + + t.Run("multiple fields with prefix", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "droid", + }, + Args: []FieldArgument{ + { + Name: "id", + Variable: &ContextVariable{ + Path: []string{"id"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "user", + }, + Args: []FieldArgument{ + { + Name: "name", + Variable: &ContextVariable{ + Path: []string{"name"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, }, }, }, } + + ctx := &Context{ + Variables: astjson.MustParse(`{"id":1,"name":"john"}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{}`) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "my-prefix") + assert.NoError(t, err) + expected := []*CacheKey{ 
+ { + Item: data, + Keys: []string{ + `my-prefix:{"__typename":"Query","field":"droid","args":{"id":1}}`, + `my-prefix:{"__typename":"Query","field":"user","args":{"name":"john"}}`, + }, + }, + } assert.Equal(t, expected, cacheKeys) }) } @@ -630,16 +655,12 @@ func TestCachingRenderEntityQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{"__typename":"Product","id":"123"}`) - cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") assert.NoError(t, err) expected := []*CacheKey{ { Item: data, - Keys: []KeyEntry{ - { - Name: `{"__typename":"Product","keys":{"id":"123"}}`, - }, - }, + Keys: []string{`{"__typename":"Product","key":{"id":"123"}}`}, }, } assert.Equal(t, expected, cacheKeys) @@ -676,22 +697,18 @@ func TestCachingRenderEntityQueryCacheKeyTemplate(t *testing.T) { ctx: context.Background(), } data := astjson.MustParse(`{"__typename":"Product","sku":"ABC123","upc":"DEF456","name":"Trilby"}`) - cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") assert.NoError(t, err) expected := []*CacheKey{ { Item: data, - Keys: []KeyEntry{ - { - Name: `{"__typename":"Product","keys":{"sku":"ABC123","upc":"DEF456"}}`, - }, - }, + Keys: []string{`{"__typename":"Product","key":{"sku":"ABC123","upc":"DEF456"}}`}, }, } assert.Equal(t, expected, cacheKeys) }) - t.Run("entity with nested object key", func(t *testing.T) { + t.Run("single entity with prefix", func(t *testing.T) { tmpl := &EntityQueryCacheKeyTemplate{ Keys: NewResolvableObjectVariable(&Object{ Fields: []*Field{ @@ -702,22 +719,9 @@ func TestCachingRenderEntityQueryCacheKeyTemplate(t *testing.T) { }, }, { - Name: []byte("key"), - Value: &Object{ - Fields: []*Field{ - { - Name: []byte("id"), - Value: &String{ - Path: []string{"key", "id"}, - }, - }, - { - Name: []byte("version"), - Value: &String{ - Path: []string{"key", "version"}, - }, - }, - }, + Name: []byte("id"), + Value: &String{ + Path: []string{"id"}, }, }, }, @@ -728,18 +732,55 @@ func TestCachingRenderEntityQueryCacheKeyTemplate(t *testing.T) { Variables: astjson.MustParse(`{}`), ctx: context.Background(), } - data := astjson.MustParse(`{"__typename":"VersionedEntity","key":{"id":"123","version":"1"}}`) - cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}) + data := astjson.MustParse(`{"__typename":"Product","id":"123"}`) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "entity-prefix") assert.NoError(t, err) expected := []*CacheKey{ { Item: data, - Keys: []KeyEntry{ + Keys: []string{`entity-prefix:{"__typename":"Product","key":{"id":"123"}}`}, + }, + } + assert.Equal(t, expected, cacheKeys) + }) + + t.Run("entity with multiple keys and prefix", func(t *testing.T) { + tmpl := &EntityQueryCacheKeyTemplate{ + Keys: NewResolvableObjectVariable(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("sku"), + Value: &String{ + Path: []string{"sku"}, + }, + }, { - Name: `{"__typename":"VersionedEntity","keys":{"key":{"id":"123","version":"1"}}}`, - Path: "", + Name: []byte("upc"), + Value: &String{ + Path: []string{"upc"}, + }, }, }, + }), + } + + ctx := &Context{ + Variables: astjson.MustParse(`{}`), + ctx: context.Background(), + } + data := 
astjson.MustParse(`{"__typename":"Product","sku":"ABC123","upc":"DEF456","name":"Trilby"}`) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "cache") + assert.NoError(t, err) + expected := []*CacheKey{ + { + Item: data, + Keys: []string{`cache:{"__typename":"Product","key":{"sku":"ABC123","upc":"DEF456"}}`}, }, } assert.Equal(t, expected, cacheKeys) @@ -788,7 +829,7 @@ func BenchmarkRenderCacheKeys(b *testing.B) { for i := 0; i < b.N; i++ { a.Reset() - _, err := tmpl.RenderCacheKeys(a, ctxRootQuery, items) + _, err := tmpl.RenderCacheKeys(a, ctxRootQuery, items, "") if err != nil { b.Fatal(err) } @@ -861,7 +902,7 @@ func BenchmarkRenderCacheKeys(b *testing.B) { for i := 0; i < b.N; i++ { a.Reset() - _, err := tmpl.RenderCacheKeys(a, ctxRootQuery, items) + _, err := tmpl.RenderCacheKeys(a, ctxRootQuery, items, "") if err != nil { b.Fatal(err) } @@ -910,7 +951,7 @@ func BenchmarkRenderCacheKeys(b *testing.B) { for i := 0; i < b.N; i++ { a.Reset() - _, err := tmpl.RenderCacheKeys(a, ctxEntityQuery, items) + _, err := tmpl.RenderCacheKeys(a, ctxEntityQuery, items, "") if err != nil { b.Fatal(err) } diff --git a/v2/pkg/engine/resolve/fetch.go b/v2/pkg/engine/resolve/fetch.go index 59c4c7c7a..c6792ae68 100644 --- a/v2/pkg/engine/resolve/fetch.go +++ b/v2/pkg/engine/resolve/fetch.go @@ -325,6 +325,10 @@ type FetchCacheConfiguration struct { // In case of a root fetch, the variables will be one or more field arguments // For entity fetches, the variables will be a single Object Variable with @key and @requires fields CacheKeyTemplate CacheKeyTemplate + // IncludeSubgraphHeaderPrefix indicates if cache keys should be prefixed with the subgraph header hash. + // The prefix format is "id:cacheKey" where id is the hash from HeadersForSubgraph. + // Defaults to true. 
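+	// The planner sets it to true for subgraph fetches; a prefixed root query key then
+	// looks like 11111:{"__typename":"Query","field":"topProducts"}, as asserted in the caching tests.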
+	IncludeSubgraphHeaderPrefix bool
 }
 
 // FetchDependency explains how a GraphCoordinate depends on other GraphCoordinates from other fetches
diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go
index 853ffe871..0d23b2e6e 100644
--- a/v2/pkg/engine/resolve/loader.go
+++ b/v2/pkg/engine/resolve/loader.go
@@ -133,19 +133,22 @@ type result struct {
 	cache              LoaderCache
 	cacheMustBeUpdated bool
 	cacheKeys          []*CacheKey
-	cacheTTL           time.Duration
-	cacheSkippedFetch  bool
-	cacheResponseData  *astjson.Value // Response data to cache (set in mergeResult)
+	cacheSkipFetch     bool
+	cacheConfig        FetchCacheConfiguration
 }
 
-func (r *result) init(postProcessing PostProcessingConfiguration, info *FetchInfo) {
-	r.postProcessing = postProcessing
+func (l *Loader) createOrInitResult(res *result, postProcessing PostProcessingConfiguration, info *FetchInfo) *result {
+	if res == nil {
+		res = &result{}
+	}
+	res.postProcessing = postProcessing
 	if info != nil {
-		r.ds = DataSourceInfo{
+		res.ds = DataSourceInfo{
 			ID:   info.DataSourceID,
 			Name: info.DataSourceName,
 		}
 	}
+	return res
 }
 
 func IsIntrospectionDataSource(dataSourceID string) bool {
@@ -299,7 +302,7 @@ func (l *Loader) resolveSingle(item *FetchItem) error {
 
 	switch f := item.Fetch.(type) {
 	case *SingleFetch:
-		res := &result{}
+		res := l.createOrInitResult(nil, f.PostProcessing, f.Info)
 		skip, err := l.tryCacheLoadFetch(l.ctx.ctx, f.Info, f.Caching, items, res)
 		if err != nil {
 			return errors.WithStack(err)
 		}
@@ -316,7 +319,7 @@
 		}
 		return err
 	case *BatchEntityFetch:
-		res := &result{}
+		res := l.createOrInitResult(nil, f.PostProcessing, f.Info)
 		defer batchEntityToolPool.Put(res.tools)
 		skip, err := l.tryCacheLoadFetch(l.ctx.ctx, f.Info, f.Caching, items, res)
 		if err != nil {
@@ -334,7 +337,7 @@
 		}
 		return err
 	case *EntityFetch:
-		res := &result{}
+		res := l.createOrInitResult(nil, f.PostProcessing, f.Info)
 		skip, err := l.tryCacheLoadFetch(l.ctx.ctx, f.Info, f.Caching, items, res)
 		if err != nil {
 			return errors.WithStack(err)
 		}
@@ -455,51 +458,38 @@ type LoaderCache interface {
 }
 
 // extractCacheKeysStrings extracts all unique cache key strings from CacheKeys
-func extractCacheKeysStrings(cacheKeys []*CacheKey) []string {
+// Keys arrive already rendered (and prefixed, if configured) from RenderCacheKeys and are copied into the arena as-is.
+func (l *Loader) extractCacheKeysStrings(a arena.Arena, cacheKeys []*CacheKey) []string {
 	if len(cacheKeys) == 0 {
 		return nil
 	}
-	keySet := make(map[string]struct{})
-	for _, cacheKey := range cacheKeys {
-		for _, entry := range cacheKey.Keys {
-			keySet[entry.Name] = struct{}{}
+	out := arena.AllocateSlice[string](a, 0, len(cacheKeys))
+	for i := range cacheKeys {
+		for j := range cacheKeys[i].Keys {
+			n := len(cacheKeys[i].Keys[j])
+			key := arena.AllocateSlice[byte](a, 0, n)
+			key = arena.SliceAppend(a, key, unsafebytes.StringToBytes(cacheKeys[i].Keys[j])...)
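+			// copy the key bytes into the target arena so the returned
+			// strings do not alias the caller's CacheKey storage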
+ out = arena.SliceAppend(a, out, unsafebytes.BytesToString(key)) } } - keys := make([]string, 0, len(keySet)) - for key := range keySet { - keys = append(keys, key) - } - return keys + return out } // populateFromCache populates CacheKey.FromCache fields from cache entries -func populateFromCache(cacheKeys []*CacheKey, entries []*CacheEntry) error { - // Create a map of key -> value for quick lookup - entryMap := make(map[string][]byte) - for _, entry := range entries { - if entry != nil && entry.Value != nil { - entryMap[entry.Key] = entry.Value - } - } - - // For each CacheKey, find matching entries and populate FromCache - // Since multiple KeyEntries can map to the same value, we use the first match - for _, cacheKey := range cacheKeys { - if cacheKey.FromCache != nil { - // Already populated, skip +// If includePrefix is true and subgraphName is provided, keys are looked up with the subgraph header hash prefix. +func (l *Loader) populateFromCache(a arena.Arena, cacheKeys []*CacheKey, entries []*CacheEntry) (err error) { + for i := range entries { + if entries[i] == nil || entries[i].Value == nil { continue } - for _, keyEntry := range cacheKey.Keys { - if cachedValue, found := entryMap[keyEntry.Name]; found { - // Parse the cached JSON value - // Note: We use nil arena here because this is temporary data - // The FromCache will be merged into items which are on the jsonArena - parsedValue, err := astjson.ParseBytes(cachedValue) - if err != nil { - return errors.WithStack(err) + for j := range cacheKeys { + for k := range cacheKeys[j].Keys { + if cacheKeys[j].Keys[k] == entries[i].Key { + cacheKeys[j].FromCache, err = astjson.ParseBytesWithArena(a, entries[i].Value) + if err != nil { + return errors.WithStack(err) + } } - cacheKey.FromCache = parsedValue - break // Use first match } } } @@ -508,62 +498,25 @@ func populateFromCache(cacheKeys []*CacheKey, entries []*CacheEntry) error { // cacheKeysToEntries converts CacheKeys to CacheEntries for storage // For each CacheKey, creates entries for all its KeyEntries with the same value -func cacheKeysToEntries(cacheKeys []*CacheKey, responseData *astjson.Value, jsonArena arena.Arena) ([]*CacheEntry, error) { - if len(cacheKeys) == 0 { - return nil, nil - } - - entries := make([]*CacheEntry, 0) - - // Check if responseData is an array - responseArray := responseData.GetArray() - - if len(responseArray) > 1 { - // Multiple items: extract per-item data from batch response - if len(responseArray) != len(cacheKeys) { - return nil, errors.Errorf("cache key count (%d) doesn't match response array length (%d)", len(cacheKeys), len(responseArray)) - } - - // For each CacheKey, serialize its corresponding item and store under all its KeyEntries - for i, cacheKey := range cacheKeys { - itemData := responseArray[i] - itemBytes := itemData.MarshalTo(nil) - - for _, keyEntry := range cacheKey.Keys { - valueCopy := make([]byte, len(itemBytes)) - copy(valueCopy, itemBytes) - entries = append(entries, &CacheEntry{ - Key: keyEntry.Name, - Value: valueCopy, - }) +// If includePrefix is true and subgraphName is provided, keys are prefixed with the subgraph header hash. 
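+// One entry is created per key of each CacheKey, and the stored value is the marshaled Item snapshot.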
+func (l *Loader) cacheKeysToEntries(a arena.Arena, cacheKeys []*CacheKey) ([]*CacheEntry, error) { + out := arena.AllocateSlice[*CacheEntry](a, 0, len(cacheKeys)) + buf := arena.AllocateSlice[byte](a, 64, 64) + for i := range cacheKeys { + for j := range cacheKeys[i].Keys { + if cacheKeys[i].Item == nil { + continue } - } - } else { - // Single item: store same value under all keys - // This handles both single object and single-item array cases - var dataToStore *astjson.Value - if len(responseArray) == 1 { - dataToStore = responseArray[0] - } else { - dataToStore = responseData - } - - dataBytes := dataToStore.MarshalTo(nil) - - // Store under all KeyEntries for all CacheKeys - for _, cacheKey := range cacheKeys { - for _, keyEntry := range cacheKey.Keys { - valueCopy := make([]byte, len(dataBytes)) - copy(valueCopy, dataBytes) - entries = append(entries, &CacheEntry{ - Key: keyEntry.Name, - Value: valueCopy, - }) + buf = cacheKeys[i].Item.MarshalTo(buf[:0]) + entry := &CacheEntry{ + Key: cacheKeys[i].Keys[j], + Value: arena.AllocateSlice[byte](a, len(buf), len(buf)), } + copy(entry.Value, buf) + out = arena.SliceAppend(a, out, entry) } } - - return entries, nil + return out, nil } func (l *Loader) tryCacheLoadFetch(ctx context.Context, info *FetchInfo, cfg FetchCacheConfiguration, inputItems []*astjson.Value, res *result) (skipFetch bool, err error) { @@ -576,12 +529,20 @@ func (l *Loader) tryCacheLoadFetch(ctx context.Context, info *FetchInfo, cfg Fet if l.caches == nil { return false, nil } + res.cacheConfig = cfg res.cache = l.caches[cfg.CacheName] if res.cache == nil { return false, nil } + var prefix string + if cfg.IncludeSubgraphHeaderPrefix && l.ctx.SubgraphHeadersBuilder != nil { + _, headersHash := l.ctx.SubgraphHeadersBuilder.HeadersForSubgraph(info.DataSourceName) + var buf [20]byte + b := strconv.AppendUint(buf[:0], headersHash, 10) + prefix = string(b) + } // Generate cache keys for all items at once - res.cacheKeys, err = cfg.CacheKeyTemplate.RenderCacheKeys(nil, l.ctx, inputItems) + res.cacheKeys, err = cfg.CacheKeyTemplate.RenderCacheKeys(nil, l.ctx, inputItems, prefix) if err != nil { return false, err } @@ -589,8 +550,7 @@ func (l *Loader) tryCacheLoadFetch(ctx context.Context, info *FetchInfo, cfg Fet // If no cache keys were generated, we skip the cache return false, nil } - // Extract all unique cache key strings - cacheKeyStrings := extractCacheKeysStrings(res.cacheKeys) + cacheKeyStrings := l.extractCacheKeysStrings(nil, res.cacheKeys) if len(cacheKeyStrings) == 0 { return false, nil } @@ -600,25 +560,23 @@ func (l *Loader) tryCacheLoadFetch(ctx context.Context, info *FetchInfo, cfg Fet return false, err } // Populate FromCache fields in CacheKeys - err = populateFromCache(res.cacheKeys, cacheEntries) + err = l.populateFromCache(nil, res.cacheKeys, cacheEntries) if err != nil { return false, err } - res.cacheTTL = cfg.TTL - missing, canSkip := l.canSkipFetch(info, res) + canSkip := l.canSkipFetch(info, res) if canSkip { - res.cacheSkippedFetch = true + res.cacheSkipFetch = true return true, nil } res.cacheMustBeUpdated = true - res.cacheTTL = cfg.TTL - _ = missing return false, nil } func (l *Loader) loadFetch(ctx context.Context, fetch Fetch, fetchItem *FetchItem, items []*astjson.Value, res *result) error { switch f := fetch.(type) { case *SingleFetch: + res = l.createOrInitResult(res, f.PostProcessing, f.Info) skip, err := l.tryCacheLoadFetch(ctx, f.Info, f.Caching, items, res) if err != nil { return errors.WithStack(err) @@ -628,6 +586,7 @@ func (l *Loader) 
loadFetch(ctx context.Context, fetch Fetch, fetchItem *FetchIte } return l.loadSingleFetch(ctx, f, fetchItem, items, res) case *EntityFetch: + res = l.createOrInitResult(res, f.PostProcessing, f.Info) skip, err := l.tryCacheLoadFetch(ctx, f.Info, f.Caching, items, res) if err != nil { return errors.WithStack(err) @@ -637,6 +596,7 @@ func (l *Loader) loadFetch(ctx context.Context, fetch Fetch, fetchItem *FetchIte } return l.loadEntityFetch(ctx, fetchItem, f, items, res) case *BatchEntityFetch: + res = l.createOrInitResult(res, f.PostProcessing, f.Info) skip, err := l.tryCacheLoadFetch(ctx, f.Info, f.Caching, items, res) if err != nil { return errors.WithStack(err) @@ -675,64 +635,18 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson if res.err != nil { return l.renderErrorsFailedToFetch(fetchItem, res, failedToFetchNoReason) } - if res.authorizationRejected { - err := l.renderAuthorizationRejectedErrors(fetchItem, res) - if err != nil { - return err - } - trueValue := astjson.MustParse(`true`) - skipErrorsPath := make([]string, len(res.postProcessing.MergePath)+1) - copy(skipErrorsPath, res.postProcessing.MergePath) - skipErrorsPath[len(skipErrorsPath)-1] = "__skipErrors" - for _, item := range items { - astjson.SetValue(item, trueValue, skipErrorsPath...) - } - return nil - } - if res.rateLimitRejected { - err := l.renderRateLimitRejectedErrors(fetchItem, res) - if err != nil { - return err - } - trueValue := astjson.MustParse(`true`) - skipErrorsPath := make([]string, len(res.postProcessing.MergePath)+1) - copy(skipErrorsPath, res.postProcessing.MergePath) - skipErrorsPath[len(skipErrorsPath)-1] = "__skipErrors" - for _, item := range items { - astjson.SetValue(item, trueValue, skipErrorsPath...) - } - return nil + if rejected, err := l.evaluateRejected(fetchItem, res, items); err != nil || rejected { + return err } - if res.cacheSkippedFetch { + if res.cacheSkipFetch { // Merge cached data into items - mergedData := make([]*astjson.Value, len(res.cacheKeys)) - for i, key := range res.cacheKeys { + for _, key := range res.cacheKeys { // Merge cached data into item - merged, _, err := astjson.MergeValues(l.jsonArena, items[i], key.FromCache) + _, _, err := astjson.MergeValues(l.jsonArena, key.Item, key.FromCache) if err != nil { return l.renderErrorsFailedToFetch(fetchItem, res, "invalid cache item") } - mergedData[i] = merged } - - // Update cache with merged data to refresh TTL, even when skipping fetch - if res.cacheMustBeUpdated && len(mergedData) > 0 { - // Construct responseData from merged items for cache update - // For batch responses, create an array; for single items, use the first item - var responseData *astjson.Value - if len(mergedData) == 1 { - responseData = mergedData[0] - } else { - // Create array from merged items - responseData = astjson.ArrayValue(l.jsonArena) - for i, item := range mergedData { - responseData.SetArrayItem(l.jsonArena, i, item) - } - } - res.cacheResponseData = responseData - defer l.updateCache(res) - } - return nil } if res.fetchSkipped { @@ -762,12 +676,6 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson responseData = response } - // Store responseData for caching if needed - if res.cacheMustBeUpdated { - res.cacheResponseData = responseData - defer l.updateCache(res) - } - hasErrors := false var taintedIndices []int @@ -823,7 +731,7 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson // no data return nil } - + defer l.updateCache(res) if 
len(items) == 0 { // If the data is set, it must be an object according to GraphQL over HTTP spec if responseData.Type() != astjson.TypeObject { @@ -892,9 +800,40 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson l.taintedObjs.add(items[i]) } } + return nil } +func (l *Loader) evaluateRejected(fetchItem *FetchItem, res *result, items []*astjson.Value) (bool, error) { + if res.authorizationRejected { + err := l.renderAuthorizationRejectedErrors(fetchItem, res) + if err != nil { + return false, err + } + l.setSkipErrors(res, items) + return true, nil + } + if res.rateLimitRejected { + err := l.renderRateLimitRejectedErrors(fetchItem, res) + if err != nil { + return false, err + } + l.setSkipErrors(res, items) + return true, nil + } + return false, nil +} + +func (l *Loader) setSkipErrors(res *result, items []*astjson.Value) { + trueValue := astjson.TrueValue(l.jsonArena) + skipErrorsPath := make([]string, len(res.postProcessing.MergePath)+1) + copy(skipErrorsPath, res.postProcessing.MergePath) + skipErrorsPath[len(skipErrorsPath)-1] = "__skipErrors" + for _, item := range items { + astjson.SetValue(item, trueValue, skipErrorsPath...) + } +} + var ( errorsInvalidInputHeader = []byte(`{"errors":[{"message":"Failed to render Fetch Input","path":[`) errorsInvalidInputFooter = []byte(`]}]}`) @@ -923,12 +862,12 @@ func (l *Loader) renderErrorsInvalidInput(fetchItem *FetchItem) []byte { } func (l *Loader) updateCache(res *result) { - if res.cache == nil || len(res.cacheKeys) == 0 || res.cacheResponseData == nil { + if res.cache == nil || len(res.cacheKeys) == 0 || !res.cacheMustBeUpdated { return } // Convert CacheKeys to CacheEntries - cacheEntries, err := cacheKeysToEntries(res.cacheKeys, res.cacheResponseData, l.jsonArena) + cacheEntries, err := l.cacheKeysToEntries(l.jsonArena, res.cacheKeys) if err != nil { fmt.Printf("error converting cache keys to entries: %s", err) return @@ -938,7 +877,7 @@ func (l *Loader) updateCache(res *result) { return } - err = res.cache.Set(context.Background(), cacheEntries, res.cacheTTL) + err = res.cache.Set(l.ctx.ctx, cacheEntries, res.cacheConfig.TTL) if err != nil { fmt.Printf("error cache.Set: %s", err) } @@ -1532,9 +1471,7 @@ func (l *Loader) validatePreFetch(input []byte, info *FetchInfo, res *result) (a } func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, fetchItem *FetchItem, items []*astjson.Value, res *result) error { - res.init(fetch.PostProcessing, fetch.Info) buf := bytes.NewBuffer(nil) - inputData := l.itemsData(items) if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} @@ -1573,7 +1510,6 @@ func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, fetchI } func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetch *EntityFetch, items []*astjson.Value, res *result) error { - res.init(fetch.PostProcessing, fetch.Info) input := l.itemsData(items) if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} @@ -1694,8 +1630,6 @@ var ( ) func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, fetch *BatchEntityFetch, items []*astjson.Value, res *result) error { - res.init(fetch.PostProcessing, fetch.Info) - if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} if !l.ctx.TracingOptions.ExcludeRawInputData && len(items) != 0 { @@ -2189,45 +2123,19 @@ func (l *Loader) compactJSON(data []byte) ([]byte, error) { return v.MarshalTo(nil), nil } -func (l *Loader) canSkipFetch(info *FetchInfo, res 
*result) ([]*CacheKey, bool) { +// canSkipFetch returns true if the cache provided exactly the information required to satisfy the query plan +// the query planner generates info.ProvidesData which tells precisely which fields the fetch must load +// if a single value is missing, we will execute the fetch +func (l *Loader) canSkipFetch(info *FetchInfo, res *result) bool { if info == nil || info.OperationType != ast.OperationTypeQuery || info.ProvidesData == nil { - return res.cacheKeys, false - } - - // Check each item and remove those that have sufficient data - remaining := make([]*CacheKey, 0, len(res.cacheKeys)) - for i, key := range res.cacheKeys { - // When we have cached data, we should check if merging Item + FromCache gives us all required fields - // Otherwise, check Item. - var dataToCheck *astjson.Value - if key.FromCache != nil { - // If we have cached data, merge it with Item to get the complete picture - if key.Item != nil { - // Create a temporary merged value to check - // Note: We use a temporary arena here since we're just checking, not storing - merged, _, err := astjson.MergeValues(nil, key.Item, key.FromCache) - if err == nil && merged != nil { - dataToCheck = merged - } else { - // Fallback to FromCache if merge fails - dataToCheck = key.FromCache - } - } else { - dataToCheck = key.FromCache - } - } else { - dataToCheck = key.Item - } - - hasRequiredData := l.validateItemHasRequiredData(dataToCheck, info.ProvidesData) - - if !hasRequiredData { - remaining = append(remaining, res.cacheKeys[i]) + return false + } + for i := range res.cacheKeys { + if !l.validateItemHasRequiredData(res.cacheKeys[i].FromCache, info.ProvidesData) { + return false } } - - // Return the remaining items and whether fetch can be skipped - return remaining, len(remaining) == 0 + return true } // validateItemHasRequiredData checks if the given item contains all required data diff --git a/v2/pkg/engine/resolve/loader_skip_fetch_test.go b/v2/pkg/engine/resolve/loader_skip_fetch_test.go index 31f41adb5..aadac1584 100644 --- a/v2/pkg/engine/resolve/loader_skip_fetch_test.go +++ b/v2/pkg/engine/resolve/loader_skip_fetch_test.go @@ -14,12 +14,10 @@ func TestLoader_canSkipFetch(t *testing.T) { t.Parallel() tests := []struct { - name string - info *FetchInfo - items []*astjson.Value - wantResult bool - wantRemaining int // -1 means check for empty, otherwise check exact count - checkFn func(t *testing.T, remaining []*CacheKey) // optional custom validation + name string + info *FetchInfo + items []*astjson.Value + expectSkipFetch bool }{ { name: "single item with Query operation", @@ -40,8 +38,7 @@ func TestLoader_canSkipFetch(t *testing.T) { items: []*astjson.Value{ astjson.MustParseBytes([]byte(`{"id": "123"}`)), }, - wantResult: true, - wantRemaining: -1, // empty + expectSkipFetch: true, }, { name: "single item with Mutation operation", @@ -62,8 +59,7 @@ func TestLoader_canSkipFetch(t *testing.T) { items: []*astjson.Value{ astjson.MustParseBytes([]byte(`{"id": "123"}`)), }, - wantResult: false, - wantRemaining: 1, + expectSkipFetch: false, }, { name: "single item with null type", @@ -74,8 +70,7 @@ func TestLoader_canSkipFetch(t *testing.T) { items: []*astjson.Value{ astjson.MustParseBytes([]byte(`null`)), }, - wantResult: true, - wantRemaining: -1, // empty - can skip fetch since no fields required + expectSkipFetch: true, }, { name: "single item with all required data", @@ -112,8 +107,7 @@ func TestLoader_canSkipFetch(t *testing.T) { items: []*astjson.Value{ astjson.MustParseBytes([]byte(`{"user": 
{"id": "123", "name": "John"}}`)), }, - wantResult: true, - wantRemaining: -1, // empty + expectSkipFetch: true, }, { name: "single item missing required field", @@ -150,8 +144,7 @@ func TestLoader_canSkipFetch(t *testing.T) { items: []*astjson.Value{ astjson.MustParseBytes([]byte(`{"user": {"id": "123"}}`)), // missing "name" }, - wantResult: false, - wantRemaining: 1, + expectSkipFetch: false, }, { name: "single item missing nullable field", @@ -188,8 +181,7 @@ func TestLoader_canSkipFetch(t *testing.T) { items: []*astjson.Value{ astjson.MustParseBytes([]byte(`{"user": {"id": "123"}}`)), // missing nullable "email" }, - wantResult: false, - wantRemaining: 1, + expectSkipFetch: false, }, { name: "single item with null value on required path", @@ -219,8 +211,7 @@ func TestLoader_canSkipFetch(t *testing.T) { items: []*astjson.Value{ astjson.MustParseBytes([]byte(`{"user": {"id": null}}`)), // null value on required field }, - wantResult: false, - wantRemaining: 1, + expectSkipFetch: false, }, { name: "single item with null value on nullable path", @@ -257,8 +248,7 @@ func TestLoader_canSkipFetch(t *testing.T) { items: []*astjson.Value{ astjson.MustParseBytes([]byte(`{"user": {"id": "123", "email": null}}`)), // null value on nullable field }, - wantResult: true, - wantRemaining: -1, // empty + expectSkipFetch: true, }, { name: "multiple items all can be skipped", @@ -281,8 +271,7 @@ func TestLoader_canSkipFetch(t *testing.T) { astjson.MustParseBytes([]byte(`{"id": "456"}`)), astjson.MustParseBytes([]byte(`{"id": "789"}`)), }, - wantResult: true, - wantRemaining: -1, // empty + expectSkipFetch: true, }, { name: "multiple items some can be skipped", @@ -321,13 +310,7 @@ func TestLoader_canSkipFetch(t *testing.T) { astjson.MustParseBytes([]byte(`{"user": {"id": "456"}}`)), // missing name astjson.MustParseBytes([]byte(`{"user": {"id": "789", "name": "Alice"}}`)), // complete }, - wantResult: false, - wantRemaining: 1, - checkFn: func(t *testing.T, remaining []*CacheKey) { - // Check that the remaining item is the incomplete one - user := remaining[0].Item.Get("user") - assert.Equal(t, "456", string(user.Get("id").GetStringBytes())) - }, + expectSkipFetch: false, }, { name: "multiple items none can be skipped", @@ -366,8 +349,7 @@ func TestLoader_canSkipFetch(t *testing.T) { astjson.MustParseBytes([]byte(`{"user": {"id": "456"}}`)), // missing name astjson.MustParseBytes([]byte(`{"user": {"id": "789"}}`)), // missing name }, - wantResult: false, - wantRemaining: 3, + expectSkipFetch: false, }, { name: "nullable array that is null", @@ -404,8 +386,7 @@ func TestLoader_canSkipFetch(t *testing.T) { items: []*astjson.Value{ astjson.MustParseBytes([]byte(`{"user": {"id": "123", "tags": null}}`)), }, - wantResult: true, - wantRemaining: -1, // empty + expectSkipFetch: true, }, { name: "nullable array that is empty", @@ -442,8 +423,7 @@ func TestLoader_canSkipFetch(t *testing.T) { items: []*astjson.Value{ astjson.MustParseBytes([]byte(`{"user": {"id": "123", "tags": []}}`)), }, - wantResult: true, - wantRemaining: -1, // empty + expectSkipFetch: true, }, { name: "deeply nested structure", @@ -523,8 +503,7 @@ func TestLoader_canSkipFetch(t *testing.T) { } }`)), }, - wantResult: true, - wantRemaining: -1, // empty + expectSkipFetch: true, }, { name: "nil info", @@ -532,8 +511,7 @@ func TestLoader_canSkipFetch(t *testing.T) { items: []*astjson.Value{ astjson.MustParseBytes([]byte(`{"id": "123"}`)), }, - wantResult: false, - wantRemaining: 1, + expectSkipFetch: false, }, { name: "nil ProvidesData", @@ 
-544,8 +522,7 @@ func TestLoader_canSkipFetch(t *testing.T) { items: []*astjson.Value{ astjson.MustParseBytes([]byte(`{"id": "123"}`)), }, - wantResult: false, - wantRemaining: 1, + expectSkipFetch: false, }, { name: "array with scalar items - valid", @@ -570,8 +547,7 @@ func TestLoader_canSkipFetch(t *testing.T) { items: []*astjson.Value{ astjson.MustParseBytes([]byte(`{"tags": ["tag1", "tag2", "tag3"]}`)), }, - wantResult: true, - wantRemaining: -1, // empty + expectSkipFetch: true, }, { name: "array with scalar items - invalid (null item in non-nullable array)", @@ -596,8 +572,7 @@ func TestLoader_canSkipFetch(t *testing.T) { items: []*astjson.Value{ astjson.MustParseBytes([]byte(`{"tags": ["tag1", null, "tag3"]}`)), // null item in non-nullable array }, - wantResult: false, - wantRemaining: 1, + expectSkipFetch: false, }, { name: "array with scalar items - valid (null item in nullable array)", @@ -622,8 +597,7 @@ func TestLoader_canSkipFetch(t *testing.T) { items: []*astjson.Value{ astjson.MustParseBytes([]byte(`{"tags": ["tag1", null, "tag3"]}`)), // null item in nullable array }, - wantResult: true, - wantRemaining: -1, // empty + expectSkipFetch: true, }, { name: "array with object items - valid", @@ -664,8 +638,7 @@ func TestLoader_canSkipFetch(t *testing.T) { items: []*astjson.Value{ astjson.MustParseBytes([]byte(`{"users": [{"id": "1", "name": "John"}, {"id": "2", "name": "Jane"}]}`)), }, - wantResult: true, - wantRemaining: -1, // empty + expectSkipFetch: true, }, { name: "array with object items - invalid (missing required field)", @@ -706,8 +679,7 @@ func TestLoader_canSkipFetch(t *testing.T) { items: []*astjson.Value{ astjson.MustParseBytes([]byte(`{"users": [{"id": "1", "name": "John"}, {"id": "2"}]}`)), // missing "name" field }, - wantResult: false, - wantRemaining: 1, + expectSkipFetch: false, }, { name: "nested arrays - valid", @@ -736,8 +708,7 @@ func TestLoader_canSkipFetch(t *testing.T) { items: []*astjson.Value{ astjson.MustParseBytes([]byte(`{"matrix": [["a", "b"], ["c", "d"], ["e", "f"]]}`)), }, - wantResult: true, - wantRemaining: -1, // empty + expectSkipFetch: true, }, { name: "nested arrays - invalid (null in inner non-nullable array)", @@ -766,8 +737,7 @@ func TestLoader_canSkipFetch(t *testing.T) { items: []*astjson.Value{ astjson.MustParseBytes([]byte(`{"matrix": [["a", "b"], ["c", null], ["e", "f"]]}`)), // null in inner array }, - wantResult: false, - wantRemaining: 1, + expectSkipFetch: false, }, { name: "array of objects with nested arrays - complex valid case", @@ -821,8 +791,7 @@ func TestLoader_canSkipFetch(t *testing.T) { items: []*astjson.Value{ astjson.MustParseBytes([]byte(`{"groups": [{"name": "admins", "members": [{"id": "1"}, {"id": "2"}]}, {"name": "users", "members": [{"id": "3"}]}]}`)), }, - wantResult: true, - wantRemaining: -1, // empty + expectSkipFetch: true, }, { name: "array of objects with nested arrays - complex invalid case", @@ -876,8 +845,7 @@ func TestLoader_canSkipFetch(t *testing.T) { items: []*astjson.Value{ astjson.MustParseBytes([]byte(`{"groups": [{"name": "admins", "members": [{"id": "1"}, {}]}, {"name": "users", "members": [{"id": "3"}]}]}`)), // missing id in one member }, - wantResult: false, - wantRemaining: 1, + expectSkipFetch: false, }, } @@ -894,7 +862,7 @@ func TestLoader_canSkipFetch(t *testing.T) { cacheKeys := make([]*CacheKey, len(itemsCopy)) for i, item := range itemsCopy { cacheKeys[i] = &CacheKey{ - Item: item, + FromCache: item, } } @@ -903,19 +871,8 @@ func TestLoader_canSkipFetch(t *testing.T) { 
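+				// each CacheKey now carries its payload in FromCache, which is
+				// the value canSkipFetch validates against info.ProvidesData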
cacheKeys: cacheKeys, } - remaining, result := loader.canSkipFetch(tt.info, res) - - assert.Equal(t, tt.wantResult, result, "result mismatch") - - if tt.wantRemaining == -1 { - assert.Empty(t, remaining, "expected empty remaining items") - } else { - assert.Len(t, remaining, tt.wantRemaining, "remaining items count mismatch") - } - - if tt.checkFn != nil { - tt.checkFn(t, remaining) - } + canSkipFetch := loader.canSkipFetch(tt.info, res) + assert.Equal(t, tt.expectSkipFetch, canSkipFetch, "skip fetch") }) } } From d8f04cabe7e7be1d748c657e315ed9e38b15a52b Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sun, 2 Nov 2025 13:33:20 +0100 Subject: [PATCH 60/61] chore: refactor arena handling --- .../grpc_datasource/grpc_datasource.go | 51 ++++++++++++++----- .../grpc_datasource/grpc_datasource_test.go | 12 ++--- .../grpc_datasource/json_builder.go | 6 +-- v2/pkg/engine/resolve/arena.go | 47 +++++++++++++++-- v2/pkg/engine/resolve/arena_test.go | 18 +++---- v2/pkg/engine/resolve/resolve.go | 20 ++++---- 6 files changed, 106 insertions(+), 48 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index c9c37891f..6cbc4ca12 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -8,11 +8,11 @@ package grpcdatasource import ( "context" + "encoding/binary" "fmt" "net/http" - "sync" - "errors" + "github.com/cespare/xxhash/v2" "github.com/tidwall/gjson" "golang.org/x/sync/errgroup" "google.golang.org/grpc" @@ -46,6 +46,8 @@ type DataSource struct { mapping *GRPCMapping federationConfigs plan.FederationFieldConfigurations disabled bool + + pool *resolve.ArenaPool } type ProtoConfig struct { @@ -81,6 +83,7 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D mapping: config.Mapping, federationConfigs: config.FederationConfigs, disabled: config.Disabled, + pool: resolve.NewArenaPool(), }, nil } @@ -93,15 +96,23 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { // get variables from input variables := gjson.Parse(unsafebytes.BytesToString(input)).Get("body.variables") - builder := newJSONBuilder(d.mapping, variables) + + var ( + poolItems []*resolve.ArenaPoolItem + ) + defer func() { + d.pool.ReleaseMany(poolItems) + }() + + item := d.acquirePoolItem(input, 0) + poolItems = append(poolItems, item) + builder := newJSONBuilder(item.Arena, d.mapping, variables) if d.disabled { return builder.writeErrorBytes(fmt.Errorf("gRPC datasource needs to be enabled to be used")), nil } - arena := astjson.Arena{} - defer arena.Reset() - root := arena.NewObject() + root := astjson.ObjectValue(nil) failed := false @@ -116,8 +127,10 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte // make gRPC calls for index, serviceCall := range serviceCalls { + item := d.acquirePoolItem(input, index) + poolItems = append(poolItems, item) + builder := newJSONBuilder(item.Arena, d.mapping, variables) errGrp.Go(func() error { - a := astjson.Arena{} // Invoke the gRPC method - this will populate serviceCall.Output err := d.cc.Invoke(errGrpCtx, serviceCall.MethodFullName(), serviceCall.Input, serviceCall.Output) @@ -125,7 +138,7 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte return err } - response, err := 
builder.marshalResponseJSON(&a, &serviceCall.RPC.Response, serviceCall.Output) + response, err := builder.marshalResponseJSON(&serviceCall.RPC.Response, serviceCall.Output) if err != nil { return err } @@ -150,7 +163,7 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte } if err := errGrp.Wait(); err != nil { - out.Write(builder.writeErrorBytes(err)) + data = builder.writeErrorBytes(err) failed = true return nil } @@ -163,19 +176,29 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte root, err = builder.mergeValues(root, result.response) } if err != nil { - out.Write(builder.writeErrorBytes(err)) + data = builder.writeErrorBytes(err) return err } } return nil }); err != nil || failed { - return err + return data, err } - data := builder.toDataObject(root) - out.Write(data.MarshalTo(nil)) - return nil + value := builder.toDataObject(root) + return value.MarshalTo(nil), err +} + +func (d *DataSource) acquirePoolItem(input []byte, index int) *resolve.ArenaPoolItem { + keyGen := xxhash.New() + _, _ = keyGen.Write(input) + var b [8]byte + binary.LittleEndian.PutUint64(b[:], uint64(index)) + _, _ = keyGen.Write(b[:]) + key := keyGen.Sum64() + item := d.pool.Acquire(key) + return item } // LoadWithFiles implements resolve.DataSource interface. diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index 9b4d6be43..9a427809a 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -54,8 +54,7 @@ func Benchmark_DataSource_Load(b *testing.B) { b.ReportAllocs() b.ResetTimer() for b.Loop() { - output := new(bytes.Buffer) - err = ds.Load(context.Background(), []byte(`{"query":"`+query+`","body":`+variables+`}`), output) + _, err = ds.Load(context.Background(), nil, []byte(`{"query":"`+query+`","body":`+variables+`}`)) require.NoError(b, err) } } @@ -93,7 +92,7 @@ func Benchmark_DataSource_Load_WithFieldArguments(b *testing.B) { }) require.NoError(b, err) - err = ds.Load(context.Background(), []byte(`{"query":"`+query+`","body":`+variables+`}`), new(bytes.Buffer)) + _, err = ds.Load(context.Background(), nil, []byte(`{"query":"`+query+`","body":`+variables+`}`)) require.NoError(b, err) } } @@ -564,7 +563,7 @@ func TestMarshalResponseJSON(t *testing.T) { responseMessage := dynamicpb.NewMessage(responseMessageDesc) responseMessage.Mutable(responseMessageDesc.Fields().ByName("result")).List().Append(protoref.ValueOfMessage(productMessage)) - jsonBuilder := newJSONBuilder(nil, gjson.Result{}) + jsonBuilder := newJSONBuilder(nil, nil, gjson.Result{}) responseJSON, err := jsonBuilder.marshalResponseJSON(&response, responseMessage) require.NoError(t, err) require.Equal(t, `{"_entities":[{"__typename":"Product","id":"123","name_different":"test","price_different":123.45}]}`, responseJSON.String()) @@ -3723,15 +3722,14 @@ func Test_DataSource_Load_WithEntity_Calls(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response var resp graphqlResponse - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal 
response") tc.validate(t, resp.Data) diff --git a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go index 7eb874514..0b2edc07c 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go +++ b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go @@ -114,12 +114,12 @@ type jsonBuilder struct { // newJSONBuilder creates a new JSON builder instance with the provided mapping // and variables. The builder automatically creates an index map for proper // federation entity ordering if representations are present in the variables. -func newJSONBuilder(mapping *GRPCMapping, variables gjson.Result) *jsonBuilder { +func newJSONBuilder(a arena.Arena, mapping *GRPCMapping, variables gjson.Result) *jsonBuilder { return &jsonBuilder{ mapping: mapping, variables: variables, indexMap: createRepresentationIndexMap(variables), - jsonArena: arena.NewMonotonicArena(), + jsonArena: a, } } @@ -259,7 +259,7 @@ func (j *jsonBuilder) mergeWithPath(base *astjson.Value, resolved *astjson.Value } for i := range responseValues { - responseValues[i].Set(elementName, resolvedValues[i].Get(elementName)) + responseValues[i].Set(j.jsonArena, elementName, resolvedValues[i].Get(elementName)) } return nil diff --git a/v2/pkg/engine/resolve/arena.go b/v2/pkg/engine/resolve/arena.go index 98bd93087..7909460b2 100644 --- a/v2/pkg/engine/resolve/arena.go +++ b/v2/pkg/engine/resolve/arena.go @@ -32,6 +32,7 @@ type arenaPoolItemSize struct { // ArenaPoolItem wraps an arena.Arena for use in the pool type ArenaPoolItem struct { Arena arena.Arena + Key uint64 } // NewArenaPool creates a new ArenaPool instance @@ -43,7 +44,7 @@ func NewArenaPool() *ArenaPool { // Acquire gets an arena from the pool or creates a new one if none are available. // The id parameter is used to track arena sizes per use case for optimization. -func (p *ArenaPool) Acquire(id uint64) *ArenaPoolItem { +func (p *ArenaPool) Acquire(key uint64) *ArenaPoolItem { p.mu.Lock() defer p.mu.Unlock() @@ -56,21 +57,23 @@ func (p *ArenaPool) Acquire(id uint64) *ArenaPoolItem { v := wp.Value() if v != nil { + v.Key = key return v } // If weak pointer was nil (GC collected), continue to next item } // No arena available, create a new one - size := arena.WithMinBufferSize(p.getArenaSize(id)) + size := arena.WithMinBufferSize(p.getArenaSize(key)) return &ArenaPoolItem{ Arena: arena.NewMonotonicArena(size), + Key: key, } } // Release returns an arena to the pool for reuse. // The peak memory usage is recorded to optimize future arena sizes for this use case. 
-func (p *ArenaPool) Release(id uint64, item *ArenaPoolItem) { +func (p *ArenaPool) Release(item *ArenaPoolItem) { peak := item.Arena.Peak() item.Arena.Reset() @@ -78,7 +81,7 @@ func (p *ArenaPool) Release(id uint64, item *ArenaPoolItem) { defer p.mu.Unlock() // Record the peak usage for this use case - if size, ok := p.sizes[id]; ok { + if size, ok := p.sizes[item.Key]; ok { if size.count == 50 { size.count = 1 size.totalBytes = size.totalBytes / 50 @@ -86,17 +89,51 @@ func (p *ArenaPool) Release(id uint64, item *ArenaPoolItem) { size.count++ size.totalBytes += peak } else { - p.sizes[id] = &arenaPoolItemSize{ + p.sizes[item.Key] = &arenaPoolItemSize{ count: 1, totalBytes: peak, } } + item.Key = 0 + // Add the arena back to the pool using a weak pointer w := weak.Make(item) p.pool = append(p.pool, w) } +func (p *ArenaPool) ReleaseMany(items []*ArenaPoolItem) { + p.mu.Lock() + defer p.mu.Unlock() + + for _, item := range items { + + peak := item.Arena.Peak() + item.Arena.Reset() + + // Record the peak usage for this use case + if size, ok := p.sizes[item.Key]; ok { + if size.count == 50 { + size.count = 1 + size.totalBytes = size.totalBytes / 50 + } + size.count++ + size.totalBytes += peak + } else { + p.sizes[item.Key] = &arenaPoolItemSize{ + count: 1, + totalBytes: peak, + } + } + + item.Key = 0 + + // Add the arena back to the pool using a weak pointer + w := weak.Make(item) + p.pool = append(p.pool, w) + } +} + // getArenaSize returns the optimal arena size for a given use case ID. // If no size is recorded, it defaults to 1MB. func (p *ArenaPool) getArenaSize(id uint64) int { diff --git a/v2/pkg/engine/resolve/arena_test.go b/v2/pkg/engine/resolve/arena_test.go index 20c1069b8..c884434f1 100644 --- a/v2/pkg/engine/resolve/arena_test.go +++ b/v2/pkg/engine/resolve/arena_test.go @@ -47,7 +47,7 @@ func TestArenaPool_ReleaseAndAcquire(t *testing.T) { assert.NoError(t, err) // Release it - pool.Release(id, item1) + pool.Release(item1) // Pool should have one item assert.Equal(t, 1, len(pool.pool), "expected pool to have 1 item") @@ -88,7 +88,7 @@ func TestArenaPool_Acquire_ProvesBugFix(t *testing.T) { // Release all while keeping strong references for i := 0; i < numItems; i++ { - pool.Release(id, items[i]) + pool.Release(items[i]) } // Pool should have all items @@ -137,7 +137,7 @@ func TestArenaPool_Release_PeakTracking(t *testing.T) { peak1 := item1.Arena.Peak() assert.Equal(t, peak1, 5) - pool.Release(id, item1) + pool.Release(item1) // Check that size was tracked size, exists := pool.sizes[id] @@ -150,7 +150,7 @@ func TestArenaPool_Release_PeakTracking(t *testing.T) { _, err = buf2.WriteString("larger data") assert.NoError(t, err) - pool.Release(id, item2) + pool.Release(item2) // Check updated tracking assert.Equal(t, 2, size.count, "expected count 2") @@ -170,7 +170,7 @@ func TestArenaPool_GetArenaSize(t *testing.T) { buf := arena.NewArenaBuffer(item.Arena) _, err := buf.WriteString("some data") assert.NoError(t, err) - pool.Release(id, item) + pool.Release(item) size2 := pool.getArenaSize(id) assert.NotEqual(t, 0, size2, "expected non-zero size after usage") @@ -193,7 +193,7 @@ func TestArenaPool_MultipleItemsInPool(t *testing.T) { // Release all while keeping references for i := 0; i < numItems; i++ { - pool.Release(id, items[i]) + pool.Release(items[i]) } // Should have all items in pool @@ -221,7 +221,7 @@ func TestArenaPool_Release_MovingWindow(t *testing.T) { buf := arena.NewArenaBuffer(item.Arena) _, err := buf.WriteString("test data") assert.NoError(t, err) - 
pool.Release(id, item) + pool.Release(item) } // After 50 releases, verify count and total @@ -237,7 +237,7 @@ func TestArenaPool_Release_MovingWindow(t *testing.T) { _, err := buf51.WriteString("test data") assert.NoError(t, err) peak51 := item51.Arena.Peak() - pool.Release(id, item51) + pool.Release(item51) // After 51st release, verify the window was reset // count should be 2 (reset to 1, then incremented) @@ -253,7 +253,7 @@ func TestArenaPool_Release_MovingWindow(t *testing.T) { buf := arena.NewArenaBuffer(item.Arena) _, err := buf.WriteString("more data") assert.NoError(t, err) - pool.Release(id, item) + pool.Release(item) } // After 10 more releases, count should be 12 (2 + 10) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index b93888a79..747ee02c4 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -342,7 +342,7 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe err = t.resolvable.Init(ctx, nil, response.Info.OperationType) if err != nil { r.inboundRequestSingleFlight.FinishErr(inflight, err) - r.resolveArenaPool.Release(ctx.Request.ID, resolveArena) + r.resolveArenaPool.Release(resolveArena) return nil, err } @@ -350,7 +350,7 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe err = t.loader.LoadGraphQLResponseData(ctx, response, t.resolvable) if err != nil { r.inboundRequestSingleFlight.FinishErr(inflight, err) - r.resolveArenaPool.Release(ctx.Request.ID, resolveArena) + r.resolveArenaPool.Release(resolveArena) return nil, err } } @@ -361,14 +361,14 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe err = t.resolvable.Resolve(ctx.ctx, response.Data, response.Fetches, buf) if err != nil { r.inboundRequestSingleFlight.FinishErr(inflight, err) - r.resolveArenaPool.Release(ctx.Request.ID, resolveArena) - r.responseBufferPool.Release(ctx.Request.ID, responseArena) + r.resolveArenaPool.Release(resolveArena) + r.responseBufferPool.Release(responseArena) return nil, err } // first release resolverArena // all data is resolved and written into the response arena - r.resolveArenaPool.Release(ctx.Request.ID, resolveArena) + r.resolveArenaPool.Release(resolveArena) // next we write back to the client // this includes flushing and syscalls // as such, it can take some time @@ -377,7 +377,7 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe r.inboundRequestSingleFlight.FinishOk(inflight, buf.Bytes()) // all data is written to the client // we're safe to release our buffer - r.responseBufferPool.Release(ctx.Request.ID, responseArena) + r.responseBufferPool.Release(responseArena) return resp, err } @@ -515,7 +515,7 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.subgraphRequestSingleFlight, resolveArena.Arena) if err := t.resolvable.InitSubscription(resolveCtx, input, sub.resolve.Trigger.PostProcessing); err != nil { - r.resolveArenaPool.Release(resolveCtx.Request.ID, resolveArena) + r.resolveArenaPool.Release(resolveArena) r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:init:failed:%d\n", sub.id.SubscriptionID) @@ -527,7 +527,7 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar } if err := t.loader.LoadGraphQLResponseData(resolveCtx, 
sub.resolve.Response, t.resolvable); err != nil { - r.resolveArenaPool.Release(resolveCtx.Request.ID, resolveArena) + r.resolveArenaPool.Release(resolveArena) r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:load:failed:%d\n", sub.id.SubscriptionID) @@ -539,7 +539,7 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar } if err := t.resolvable.Resolve(resolveCtx.ctx, sub.resolve.Response.Data, sub.resolve.Response.Fetches, sub.writer); err != nil { - r.resolveArenaPool.Release(resolveCtx.Request.ID, resolveArena) + r.resolveArenaPool.Release(resolveArena) r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:resolve:failed:%d\n", sub.id.SubscriptionID) @@ -550,7 +550,7 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar return } - r.resolveArenaPool.Release(resolveCtx.Request.ID, resolveArena) + r.resolveArenaPool.Release(resolveArena) if err := sub.writer.Flush(); err != nil { // If flush fails (e.g. client disconnected), remove the subscription. From 648dd0213d6fc17e0a52fd99e446d4504b0e22b7 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Wed, 12 Nov 2025 12:10:11 +0100 Subject: [PATCH 61/61] chore: add resolve caching test --- v2/pkg/engine/resolve/resolve_caching_test.go | 146 ++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 v2/pkg/engine/resolve/resolve_caching_test.go diff --git a/v2/pkg/engine/resolve/resolve_caching_test.go b/v2/pkg/engine/resolve/resolve_caching_test.go new file mode 100644 index 000000000..a0f9ae7be --- /dev/null +++ b/v2/pkg/engine/resolve/resolve_caching_test.go @@ -0,0 +1,146 @@ +package resolve + +import ( + "context" + "testing" + + "github.com/golang/mock/gomock" +) + +func TestResolveCaching(t *testing.T) { + t.Run("nested batching single root result", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + + listingRoot := mockedDS(t, ctrl, + `{"method":"POST","url":"http://listing","body":{"query":"query{listing{__typename id name}}"}}`, + `{"data":{"listing":{"__typename":"Listing","id":1,"name":"L1"}}}`) + + nested := mockedDS(t, ctrl, + `{"method":"POST","url":"http://nested","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Listing { nested { id price listing { __typename id }} }}}","variables":{"representations":[{"__typename":"Listing","id":1}]}}}`, + `{"data":{"_entities":[{"__typename":"Listing","nested":{"id":1.1,"price":123,"listing":{"__typename":"Listing","id":1}}}]}}`) + + return &GraphQLResponse{ + Fetches: Sequence( + SingleWithPath(&SingleFetch{ + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://listing","body":{"query":"query{listing{__typename id name}}"}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + FetchConfiguration: FetchConfiguration{ + DataSource: listingRoot, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + }, + }, + }, "query"), + SingleWithPath(&BatchEntityFetch{ + Input: BatchInput{ + Header: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://nested","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on Listing { nested { id price listing { __typename id }} }}}","variables":{"representations":[`), + SegmentType: StaticSegmentType, + }, + }, + }, + Items: []InputTemplate{ + { + Segments: []TemplateSegment{ + { + SegmentType: VariableSegmentType, + VariableKind: ResolvableObjectVariableKind, + Renderer: NewGraphQLVariableResolveRenderer(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("id"), + Value: &Integer{ + Path: []string{"id"}, + }, + }, + }, + }), + }, + }, + }, + }, + Separator: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`,`), + SegmentType: StaticSegmentType, + }, + }, + }, + Footer: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`]}}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + }, + DataSource: nested, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data", "_entities"}, + }, + }, "query.listing", ObjectPath("listing")), + ), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("listing"), + Value: &Object{ + Path: []string{"listing"}, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Integer{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + { + Name: []byte("nested"), + Value: &Object{ + Path: []string{"nested"}, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Float{ + Path: []string{"id"}, + }, + }, + { + Name: []byte("price"), + Value: &Integer{ + Path: []string{"price"}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, Context{ctx: context.Background(), Variables: nil}, `{"data":{"listing":{"id":1,"name":"L1","nested":{"id":1.1,"price":123}}}}` + })) +}
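
The threads above meet in the loader's cache path: ProvidesData declares the
exact response shape a fetch is responsible for, canSkipFetch checks every
CacheKey.FromCache value against that shape, and the arena pool keeps the
per-request allocations cheap. A minimal sketch of the skip check in package
resolve, using only types shown in these patches — the field names and JSON
literal are illustrative, and `loader` stands for a *Loader constructed as in
the tests:

	info := &FetchInfo{
		OperationType: ast.OperationTypeQuery,
		ProvidesData: &Object{
			Fields: []*Field{
				{
					Name: []byte("user"),
					Value: &Object{
						Path:     []string{"user"},
						Nullable: true,
						Fields: []*Field{
							{Name: []byte("id"), Value: &Scalar{Path: []string{"id"}}},
							{Name: []byte("name"), Value: &Scalar{Path: []string{"name"}}},
						},
					},
				},
			},
		},
	}
	res := &result{
		cacheKeys: []*CacheKey{
			{FromCache: astjson.MustParseBytes([]byte(`{"user":{"id":"1","name":"Jane"}}`))},
		},
	}
	// every path demanded by ProvidesData is present and non-null in the
	// cached value, so the fetch is served entirely from cache
	canSkip := loader.canSkipFetch(info, res) // true; drop "name" and it flips to false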