diff --git a/execution/engine/execution_engine.go b/execution/engine/execution_engine.go index 51b274282..ff77b855d 100644 --- a/execution/engine/execution_engine.go +++ b/execution/engine/execution_engine.go @@ -109,6 +109,12 @@ func WithRequestTraceOptions(options resolve.TraceOptions) ExecutionOptions { } } +func WithSubgraphHeadersBuilder(builder resolve.SubgraphHeadersBuilder) ExecutionOptions { + return func(ctx *internalExecutionContext) { + ctx.resolveContext.SubgraphHeadersBuilder = builder + } +} + func NewExecutionEngine(ctx context.Context, logger abstractlogger.Logger, engineConfig Configuration, resolverOptions resolve.ResolverOptions) (*ExecutionEngine, error) { executionPlanCache, err := lru.New(1024) if err != nil { diff --git a/execution/engine/federation_caching_test.go b/execution/engine/federation_caching_test.go new file mode 100644 index 000000000..4d8508372 --- /dev/null +++ b/execution/engine/federation_caching_test.go @@ -0,0 +1,962 @@ +package engine_test + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "path" + "sort" + "strings" + "sync" + "testing" + "time" + + "github.com/jensneuse/abstractlogger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/graphql-go-tools/execution/federationtesting" + "github.com/wundergraph/graphql-go-tools/execution/federationtesting/gateway" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" +) + +func TestFederationCaching(t *testing.T) { + t.Run("two subgraphs - miss then hit", func(t *testing.T) { + defaultCache := NewFakeLoaderCache() + caches := map[string]resolve.LoaderCache{ + "default": defaultCache, + } + + // Create HTTP client with tracking + tracker := newSubgraphCallTracker(http.DefaultTransport) + trackingClient := &http.Client{ + Transport: tracker, + } + + setup := federationtesting.NewFederationSetup(addCachingGateway(withCachingEnableART(false), withCachingLoaderCache(caches), withHTTPClient(trackingClient))) + t.Cleanup(setup.Close) + gqlClient := NewGraphqlClient(http.DefaultClient) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + // Extract hostnames for tracking (URL.Host includes host:port) + accountsURLParsed, _ := url.Parse(setup.AccountsUpstreamServer.URL) + productsURLParsed, _ := url.Parse(setup.ProductsUpstreamServer.URL) + reviewsURLParsed, _ := url.Parse(setup.ReviewsUpstreamServer.URL) + accountsHost := accountsURLParsed.Host + productsHost := productsURLParsed.Host + reviewsHost := reviewsURLParsed.Host + + // First query - should miss cache and then set + defaultCache.ClearLog() + tracker.Reset() + resp := gqlClient.Query(ctx, setup.GatewayServer.URL, cachingTestQueryPath("queries/multiple_upstream.query"), nil, t) + assert.Equal(t, `{"data":{"topProducts":[{"name":"Trilby","reviews":[{"body":"A highly effective form of birth control.","author":{"username":"Me"}}]},{"name":"Fedora","reviews":[{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","author":{"username":"Me"}}]}]}}`, string(resp)) + + logAfterFirst := defaultCache.GetLog() + assert.Equal(t, 4, len(logAfterFirst)) + + wantLog := []CacheLogEntry{ + { + Operation: "get", + Keys: []string{`{"__typename":"Query","field":"topProducts"}`}, + Hits: []bool{false}, + }, + { + Operation: "set", + Keys: []string{`{"__typename":"Query","field":"topProducts"}`}, + }, + { + Operation: "get", + Keys: []string{ + `{"__typename":"Product","key":{"upc":"top-1"}}`, + 
`{"__typename":"Product","key":{"upc":"top-2"}}`, + }, + Hits: []bool{false, false}, + }, + { + Operation: "set", + Keys: []string{ + `{"__typename":"Product","key":{"upc":"top-1"}}`, + `{"__typename":"Product","key":{"upc":"top-2"}}`, + }, + }, + } + assert.Equal(t, sortCacheLogKeys(wantLog), sortCacheLogKeys(logAfterFirst)) + + // Verify subgraph calls for first query + // First query should call products (topProducts) and reviews (reviews) + // Accounts is not called directly because username is provided via reviews @provides + productsCallsFirst := tracker.GetCount(productsHost) + reviewsCallsFirst := tracker.GetCount(reviewsHost) + accountsCallsFirst := tracker.GetCount(accountsHost) + + assert.Equal(t, 1, productsCallsFirst, "First query should call products subgraph exactly once") + assert.Equal(t, 1, reviewsCallsFirst, "First query should call reviews subgraph exactly once") + assert.Equal(t, 0, accountsCallsFirst, "First query should not call accounts subgraph (username provided via reviews @provides)") + + // Second query - should hit cache and then set + defaultCache.ClearLog() + tracker.Reset() + resp = gqlClient.Query(ctx, setup.GatewayServer.URL, cachingTestQueryPath("queries/multiple_upstream.query"), nil, t) + assert.Equal(t, `{"data":{"topProducts":[{"name":"Trilby","reviews":[{"body":"A highly effective form of birth control.","author":{"username":"Me"}}]},{"name":"Fedora","reviews":[{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","author":{"username":"Me"}}]}]}}`, string(resp)) + + logAfterSecond := defaultCache.GetLog() + assert.Equal(t, 2, len(logAfterSecond)) + + wantLogSecond := []CacheLogEntry{ + { + Operation: "get", + Keys: []string{`{"__typename":"Query","field":"topProducts"}`}, + Hits: []bool{true}, // Should be a hit now + }, + { + Operation: "get", + Keys: []string{ + `{"__typename":"Product","key":{"upc":"top-1"}}`, + `{"__typename":"Product","key":{"upc":"top-2"}}`, + }, + Hits: []bool{true, true}, // Should be hits now, no misses + }, + } + assert.Equal(t, sortCacheLogKeys(wantLogSecond), sortCacheLogKeys(logAfterSecond)) + + // Verify subgraph calls for second query + productsCallsSecond := tracker.GetCount(productsHost) + reviewsCallsSecond := tracker.GetCount(reviewsHost) + accountsCallsSecond := tracker.GetCount(accountsHost) + + assert.Equal(t, 0, productsCallsSecond, "Second query should hit cache and not call products subgraph again") + assert.Equal(t, 0, reviewsCallsSecond, "Second query should hit cache and not call reviews subgraph again") + assert.Equal(t, 0, accountsCallsSecond, "accounts not involved") + }) + + t.Run("two subgraphs - partial fields then full fields", func(t *testing.T) { + defaultCache := NewFakeLoaderCache() + caches := map[string]resolve.LoaderCache{ + "default": defaultCache, + } + + // Create HTTP client with tracking + tracker := newSubgraphCallTracker(http.DefaultTransport) + trackingClient := &http.Client{ + Transport: tracker, + } + + setup := federationtesting.NewFederationSetup(addCachingGateway(withCachingEnableART(false), withCachingLoaderCache(caches), withHTTPClient(trackingClient))) + t.Cleanup(setup.Close) + gqlClient := NewGraphqlClient(http.DefaultClient) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + // Extract hostnames for tracking (URL.Host includes host:port) + accountsURLParsed, _ := url.Parse(setup.AccountsUpstreamServer.URL) + productsURLParsed, _ := url.Parse(setup.ProductsUpstreamServer.URL) + 
reviewsURLParsed, _ := url.Parse(setup.ReviewsUpstreamServer.URL) + accountsHost := accountsURLParsed.Host + productsHost := productsURLParsed.Host + reviewsHost := reviewsURLParsed.Host + + // First query - only ask for name field (products subgraph only) + defaultCache.ClearLog() + tracker.Reset() + firstQuery := `query { + topProducts { + name + } + }` + resp := gqlClient.QueryString(ctx, setup.GatewayServer.URL, firstQuery, nil, t) + assert.Equal(t, `{"data":{"topProducts":[{"name":"Trilby"},{"name":"Fedora"}]}}`, string(resp)) + + logAfterFirst := defaultCache.GetLog() + assert.Equal(t, 2, len(logAfterFirst)) + + wantLogFirst := []CacheLogEntry{ + { + Operation: "get", + Keys: []string{`{"__typename":"Query","field":"topProducts"}`}, + Hits: []bool{false}, + }, + { + Operation: "set", + Keys: []string{`{"__typename":"Query","field":"topProducts"}`}, + }, + } + assert.Equal(t, sortCacheLogKeys(wantLogFirst), sortCacheLogKeys(logAfterFirst)) + + // Verify first query calls products subgraph only + productsCallsFirst := tracker.GetCount(productsHost) + reviewsCallsFirst := tracker.GetCount(reviewsHost) + accountsCallsFirst := tracker.GetCount(accountsHost) + assert.Equal(t, 1, productsCallsFirst, "First query calls products subgraph once") + assert.Equal(t, 0, reviewsCallsFirst, "First query does not call reviews subgraph") + assert.Equal(t, 0, accountsCallsFirst, "First query does not call accounts subgraph") + + // Second query - ask for full fields including reviews (products + reviews; accounts is not needed because username is provided via @provides) + defaultCache.ClearLog() + tracker.Reset() + secondQuery := `query { + topProducts { + name + reviews { + body + author { + username + } + } + } + }` + resp = gqlClient.QueryString(ctx, setup.GatewayServer.URL, secondQuery, nil, t) + assert.Equal(t, `{"data":{"topProducts":[{"name":"Trilby","reviews":[{"body":"A highly effective form of birth control.","author":{"username":"Me"}}]},{"name":"Fedora","reviews":[{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","author":{"username":"Me"}}]}]}}`, string(resp)) + + logAfterSecond := defaultCache.GetLog() + assert.Equal(t, 4, len(logAfterSecond)) + + wantLogSecond := []CacheLogEntry{ + { + Operation: "get", + Keys: []string{`{"__typename":"Query","field":"topProducts"}`}, + Hits: []bool{true}, // Should be a hit from first query + }, + { + Operation: "set", + Keys: []string{`{"__typename":"Query","field":"topProducts"}`}, + }, + { + Operation: "get", + Keys: []string{ + `{"__typename":"Product","key":{"upc":"top-1"}}`, + `{"__typename":"Product","key":{"upc":"top-2"}}`, + }, + Hits: []bool{false, false}, // Miss because second query requests different fields (reviews) + }, + { + Operation: "set", + Keys: []string{ + `{"__typename":"Product","key":{"upc":"top-1"}}`, + `{"__typename":"Product","key":{"upc":"top-2"}}`, + }, + }, + } + assert.Equal(t, sortCacheLogKeys(wantLogSecond), sortCacheLogKeys(logAfterSecond)) + + // Verify second query: products name is cached, but reviews still need to be fetched + productsCallsSecond := tracker.GetCount(productsHost) + reviewsCallsSecond := tracker.GetCount(reviewsHost) + accountsCallsSecond := tracker.GetCount(accountsHost) + + assert.Equal(t, 1, productsCallsSecond, "Second query calls products subgraph once (cached entry does not cover the newly requested fields)") + assert.Equal(t, 1, reviewsCallsSecond, "Second query calls reviews subgraph once (reviews not cached)") + assert.Equal(t, 0, accountsCallsSecond, "Second query does not call accounts subgraph") + + // Third query - repeat 
the second query (full fields) + defaultCache.ClearLog() + tracker.Reset() + thirdQuery := `query { + topProducts { + name + reviews { + body + author { + username + } + } + } + }` + resp = gqlClient.QueryString(ctx, setup.GatewayServer.URL, thirdQuery, nil, t) + assert.Equal(t, `{"data":{"topProducts":[{"name":"Trilby","reviews":[{"body":"A highly effective form of birth control.","author":{"username":"Me"}}]},{"name":"Fedora","reviews":[{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","author":{"username":"Me"}}]}]}}`, string(resp)) + + logAfterThird := defaultCache.GetLog() + assert.Equal(t, 2, len(logAfterThird)) + + wantLogThird := []CacheLogEntry{ + { + Operation: "get", + Keys: []string{`{"__typename":"Query","field":"topProducts"}`}, + Hits: []bool{true}, // Should be a hit from second query + }, + { + Operation: "get", + Keys: []string{ + `{"__typename":"Product","key":{"upc":"top-1"}}`, + `{"__typename":"Product","key":{"upc":"top-2"}}`, + }, + Hits: []bool{true, true}, // Should be hits from second query + }, + } + assert.Equal(t, sortCacheLogKeys(wantLogThird), sortCacheLogKeys(logAfterThird)) + + // Verify third query: all data should be cached, no subgraph calls + productsCallsThird := tracker.GetCount(productsHost) + reviewsCallsThird := tracker.GetCount(reviewsHost) + accountsCallsThird := tracker.GetCount(accountsHost) + + // All cache entries show hits, so no subgraph calls should be made + assert.Equal(t, 0, productsCallsThird, "Third query does not call products subgraph (all cache hits)") + assert.Equal(t, 0, reviewsCallsThird, "Third query does not call reviews subgraph (all cache hits)") + assert.Equal(t, 0, accountsCallsThird, "Third query does not call accounts subgraph") + }) + + t.Run("two subgraphs - with subgraph header prefix", func(t *testing.T) { + defaultCache := NewFakeLoaderCache() + caches := map[string]resolve.LoaderCache{ + "default": defaultCache, + } + + // Create HTTP client with tracking + tracker := newSubgraphCallTracker(http.DefaultTransport) + trackingClient := &http.Client{ + Transport: tracker, + } + + // Create mock SubgraphHeadersBuilder that returns a fixed hash for each subgraph + mockHeadersBuilder := &mockSubgraphHeadersBuilder{ + hashes: map[string]uint64{ + "1": 11111, + "2": 22222, + "3": 33333, + }, + } + + setup := federationtesting.NewFederationSetup(addCachingGateway( + withCachingEnableART(false), + withCachingLoaderCache(caches), + withHTTPClient(trackingClient), + withSubgraphHeadersBuilder(mockHeadersBuilder), + )) + t.Cleanup(setup.Close) + gqlClient := NewGraphqlClient(http.DefaultClient) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + // Extract hostnames for tracking (URL.Host includes host:port) + accountsURLParsed, _ := url.Parse(setup.AccountsUpstreamServer.URL) + productsURLParsed, _ := url.Parse(setup.ProductsUpstreamServer.URL) + reviewsURLParsed, _ := url.Parse(setup.ReviewsUpstreamServer.URL) + accountsHost := accountsURLParsed.Host + productsHost := productsURLParsed.Host + reviewsHost := reviewsURLParsed.Host + + // First query - should miss cache and then set with prefixed keys + defaultCache.ClearLog() + tracker.Reset() + resp := gqlClient.Query(ctx, setup.GatewayServer.URL, cachingTestQueryPath("queries/multiple_upstream.query"), nil, t) + assert.Equal(t, `{"data":{"topProducts":[{"name":"Trilby","reviews":[{"body":"A highly effective form of birth 
control.","author":{"username":"Me"}}]},{"name":"Fedora","reviews":[{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","author":{"username":"Me"}}]}]}}`, string(resp)) + + logAfterFirst := defaultCache.GetLog() + assert.Equal(t, 4, len(logAfterFirst)) + + wantLog := []CacheLogEntry{ + { + Operation: "get", + Keys: []string{`11111:{"__typename":"Query","field":"topProducts"}`}, + Hits: []bool{false}, + }, + { + Operation: "set", + Keys: []string{`11111:{"__typename":"Query","field":"topProducts"}`}, + }, + { + Operation: "get", + Keys: []string{ + `22222:{"__typename":"Product","key":{"upc":"top-1"}}`, + `22222:{"__typename":"Product","key":{"upc":"top-2"}}`, + }, + Hits: []bool{false, false}, + }, + { + Operation: "set", + Keys: []string{ + `22222:{"__typename":"Product","key":{"upc":"top-1"}}`, + `22222:{"__typename":"Product","key":{"upc":"top-2"}}`, + }, + }, + } + assert.Equal(t, sortCacheLogKeys(wantLog), sortCacheLogKeys(logAfterFirst)) + + // Verify subgraph calls for first query + productsCallsFirst := tracker.GetCount(productsHost) + reviewsCallsFirst := tracker.GetCount(reviewsHost) + accountsCallsFirst := tracker.GetCount(accountsHost) + + assert.Equal(t, 1, productsCallsFirst, "First query should call products subgraph exactly once") + assert.Equal(t, 1, reviewsCallsFirst, "First query should call reviews subgraph exactly once") + assert.Equal(t, 0, accountsCallsFirst, "First query should not call accounts subgraph") + + // Second query - should hit cache with prefixed keys + defaultCache.ClearLog() + tracker.Reset() + resp = gqlClient.Query(ctx, setup.GatewayServer.URL, cachingTestQueryPath("queries/multiple_upstream.query"), nil, t) + assert.Equal(t, `{"data":{"topProducts":[{"name":"Trilby","reviews":[{"body":"A highly effective form of birth control.","author":{"username":"Me"}}]},{"name":"Fedora","reviews":[{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","author":{"username":"Me"}}]}]}}`, string(resp)) + + logAfterSecond := defaultCache.GetLog() + assert.Equal(t, 2, len(logAfterSecond)) + + wantLogSecond := []CacheLogEntry{ + { + Operation: "get", + Keys: []string{`11111:{"__typename":"Query","field":"topProducts"}`}, + Hits: []bool{true}, // Should be a hit now + }, + { + Operation: "get", + Keys: []string{ + `22222:{"__typename":"Product","key":{"upc":"top-1"}}`, + `22222:{"__typename":"Product","key":{"upc":"top-2"}}`, + }, + Hits: []bool{true, true}, // Should be hits now + }, + } + assert.Equal(t, sortCacheLogKeys(wantLogSecond), sortCacheLogKeys(logAfterSecond)) + + // Verify subgraph calls for second query + productsCallsSecond := tracker.GetCount(productsHost) + reviewsCallsSecond := tracker.GetCount(reviewsHost) + accountsCallsSecond := tracker.GetCount(accountsHost) + + assert.Equal(t, 0, productsCallsSecond, "Second query should hit cache and not call products subgraph again") + assert.Equal(t, 0, reviewsCallsSecond, "Second query should hit cache and not call reviews subgraph again") + assert.Equal(t, 0, accountsCallsSecond, "accounts not involved") + }) +} + +// subgraphCallTracker tracks HTTP requests made to subgraph servers +type subgraphCallTracker struct { + mu sync.RWMutex + counts map[string]int // Maps subgraph URL to call count + original http.RoundTripper +} + +func newSubgraphCallTracker(original http.RoundTripper) *subgraphCallTracker { + return &subgraphCallTracker{ + counts: make(map[string]int), + original: original, + } +} + 
+func (t *subgraphCallTracker) RoundTrip(req *http.Request) (*http.Response, error) { + t.mu.Lock() + host := req.URL.Host + t.counts[host]++ + t.mu.Unlock() + return t.original.RoundTrip(req) +} + +func (t *subgraphCallTracker) GetCount(host string) int { + t.mu.RLock() + defer t.mu.RUnlock() + return t.counts[host] +} + +func (t *subgraphCallTracker) Reset() { + t.mu.Lock() + defer t.mu.Unlock() + t.counts = make(map[string]int) +} + +func (t *subgraphCallTracker) GetCounts() map[string]int { + t.mu.RLock() + defer t.mu.RUnlock() + result := make(map[string]int) + for k, v := range t.counts { + result[k] = v + } + return result +} + +func (t *subgraphCallTracker) DebugPrint() string { + t.mu.RLock() + defer t.mu.RUnlock() + return fmt.Sprintf("%v", t.counts) +} + +// Helper functions for gateway setup with HTTP client support +type cachingGatewayOptions struct { + enableART bool + withLoaderCache map[string]resolve.LoaderCache + httpClient *http.Client + subgraphHeadersBuilder resolve.SubgraphHeadersBuilder +} + +func withCachingEnableART(enableART bool) func(*cachingGatewayOptions) { + return func(opts *cachingGatewayOptions) { + opts.enableART = enableART + } +} + +func withCachingLoaderCache(loaderCache map[string]resolve.LoaderCache) func(*cachingGatewayOptions) { + return func(opts *cachingGatewayOptions) { + opts.withLoaderCache = loaderCache + } +} + +func withHTTPClient(client *http.Client) func(*cachingGatewayOptions) { + return func(opts *cachingGatewayOptions) { + opts.httpClient = client + } +} + +func withSubgraphHeadersBuilder(builder resolve.SubgraphHeadersBuilder) func(*cachingGatewayOptions) { + return func(opts *cachingGatewayOptions) { + opts.subgraphHeadersBuilder = builder + } +} + +type cachingGatewayOptionsToFunc func(opts *cachingGatewayOptions) + +func addCachingGateway(options ...cachingGatewayOptionsToFunc) func(setup *federationtesting.FederationSetup) *httptest.Server { + opts := &cachingGatewayOptions{} + for _, option := range options { + option(opts) + } + return func(setup *federationtesting.FederationSetup) *httptest.Server { + httpClient := opts.httpClient + if httpClient == nil { + httpClient = http.DefaultClient + } + + poller := gateway.NewDatasource([]gateway.ServiceConfig{ + {Name: "accounts", URL: setup.AccountsUpstreamServer.URL}, + {Name: "products", URL: setup.ProductsUpstreamServer.URL, WS: strings.ReplaceAll(setup.ProductsUpstreamServer.URL, "http:", "ws:")}, + {Name: "reviews", URL: setup.ReviewsUpstreamServer.URL}, + }, httpClient) + + gtw := gateway.Handler(abstractlogger.NoopLogger, poller, httpClient, opts.enableART, opts.withLoaderCache, opts.subgraphHeadersBuilder) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + poller.Run(ctx) + return httptest.NewServer(gtw) + } +} + +// mockSubgraphHeadersBuilder is a mock implementation of SubgraphHeadersBuilder +type mockSubgraphHeadersBuilder struct { + hashes map[string]uint64 +} + +func (m *mockSubgraphHeadersBuilder) HeadersForSubgraph(subgraphName string) (http.Header, uint64) { + hash := m.hashes[subgraphName] + if hash == 0 { + // Return default hash if not found - this helps debug what names are being requested + // Note: This will cause test failures if subgraph names don't match + return nil, 99999 + } + return nil, hash +} + +func (m *mockSubgraphHeadersBuilder) HashAll() uint64 { + // Return a simple hash of all subgraph hashes combined + var result uint64 + for _, hash := range m.hashes { + result ^= hash + } + return result +} + +func 
cachingTestQueryPath(name string) string { + return path.Join("..", "federationtesting", "testdata", name) +} + +type CacheLogEntry struct { + Operation string // "get", "set", "delete" + Keys []string // Keys involved in the operation + Hits []bool // For Get: whether each key was a hit (true) or miss (false) +} + +// normalizeCacheLog creates a shallow copy of log entries for comparison +func normalizeCacheLog(log []CacheLogEntry) []CacheLogEntry { + normalized := make([]CacheLogEntry, len(log)) + for i, entry := range log { + normalized[i] = CacheLogEntry{ + Operation: entry.Operation, + Keys: entry.Keys, + Hits: entry.Hits, + // Keys and Hits still alias the original slices; only the struct is copied + } + } + return normalized +} + +// sortCacheLogKeys sorts the keys (and corresponding hits) in each cache log entry +// This makes comparisons order-independent when multiple keys are present +func sortCacheLogKeys(log []CacheLogEntry) []CacheLogEntry { + sorted := make([]CacheLogEntry, len(log)) + for i, entry := range log { + // Only sort if there are multiple keys + if len(entry.Keys) <= 1 { + sorted[i] = entry + continue + } + + // Create pairs of (key, hit) to sort together + pairs := make([]struct { + key string + hit bool + }, len(entry.Keys)) + for j := range entry.Keys { + pairs[j].key = entry.Keys[j] + if j < len(entry.Hits) { + pairs[j].hit = entry.Hits[j] + } + } + + // Sort pairs by key + sort.Slice(pairs, func(a, b int) bool { + return pairs[a].key < pairs[b].key + }) + + // Extract sorted keys and hits + sorted[i] = CacheLogEntry{ + Operation: entry.Operation, + Keys: make([]string, len(pairs)), + Hits: nil, + } + if len(entry.Hits) > 0 { + sorted[i].Hits = make([]bool, len(pairs)) + } + for j := range pairs { + sorted[i].Keys[j] = pairs[j].key + if sorted[i].Hits != nil { + sorted[i].Hits[j] = pairs[j].hit + } + } + } + return sorted +} + +type cacheEntry struct { + data []byte + expiresAt *time.Time +} + +type FakeLoaderCache struct { + mu sync.RWMutex + storage map[string]cacheEntry + log []CacheLogEntry +} + +func NewFakeLoaderCache() *FakeLoaderCache { + return &FakeLoaderCache{ + storage: make(map[string]cacheEntry), + log: make([]CacheLogEntry, 0), + } +} + +func (f *FakeLoaderCache) cleanupExpired() { + now := time.Now() + for key, entry := range f.storage { + if entry.expiresAt != nil && now.After(*entry.expiresAt) { + delete(f.storage, key) + } + } +} + +func (f *FakeLoaderCache) Get(ctx context.Context, keys []string) ([]*resolve.CacheEntry, error) { + f.mu.Lock() + defer f.mu.Unlock() + + // Clean up expired entries before executing the operation + f.cleanupExpired() + + hits := make([]bool, len(keys)) + result := make([]*resolve.CacheEntry, len(keys)) + for i, key := range keys { + if entry, exists := f.storage[key]; exists { + // Make a copy of the data to prevent external modifications + dataCopy := make([]byte, len(entry.data)) + copy(dataCopy, entry.data) + result[i] = &resolve.CacheEntry{ + Key: key, + Value: dataCopy, + } + hits[i] = true + } else { + result[i] = nil + hits[i] = false + } + } + + // Log the operation + f.log = append(f.log, CacheLogEntry{ + Operation: "get", + Keys: keys, + Hits: hits, + }) + + return result, nil +} + +func (f *FakeLoaderCache) Set(ctx context.Context, entries []*resolve.CacheEntry, ttl time.Duration) error { + if len(entries) == 0 { + return nil + } + + f.mu.Lock() + defer f.mu.Unlock() + + // Clean up expired entries before executing the operation + f.cleanupExpired() + + keys := make([]string, 0, len(entries))
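+ // Collect the keys that are actually stored (nil entries are skipped) so the log records exactly what was written.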
+ for _, entry := range entries { + if entry == nil { + continue + } + ce := cacheEntry{ + // Make a copy of the data to prevent external modifications + data: make([]byte, len(entry.Value)), + } + copy(ce.data, entry.Value) + + // If ttl is 0, store without expiration + if ttl > 0 { + expiresAt := time.Now().Add(ttl) + ce.expiresAt = &expiresAt + } + + f.storage[entry.Key] = ce + keys = append(keys, entry.Key) + } + + // Log the operation + f.log = append(f.log, CacheLogEntry{ + Operation: "set", + Keys: keys, + Hits: nil, // Set operations don't have hits/misses + }) + + return nil +} + +func (f *FakeLoaderCache) Delete(ctx context.Context, keys []string) error { + f.mu.Lock() + defer f.mu.Unlock() + + // Clean up expired entries before executing the operation + f.cleanupExpired() + + for _, key := range keys { + delete(f.storage, key) + } + + // Log the operation + f.log = append(f.log, CacheLogEntry{ + Operation: "delete", + Keys: keys, + Hits: nil, // Delete operations don't have hits/misses + }) + + return nil +} + +// GetLog returns a copy of the cache operation log +func (f *FakeLoaderCache) GetLog() []CacheLogEntry { + f.mu.RLock() + defer f.mu.RUnlock() + logCopy := make([]CacheLogEntry, len(f.log)) + copy(logCopy, f.log) + return logCopy +} + +// ClearLog clears the cache operation log +func (f *FakeLoaderCache) ClearLog() { + f.mu.Lock() + defer f.mu.Unlock() + f.log = make([]CacheLogEntry, 0) +} + +// TestFakeLoaderCache tests the cache implementation itself +func TestFakeLoaderCache(t *testing.T) { + ctx := context.Background() + cache := NewFakeLoaderCache() + + t.Run("SetAndGet", func(t *testing.T) { + // Test basic set and get + keys := []string{"key1", "key2", "key3"} + entries := []*resolve.CacheEntry{ + {Key: "key1", Value: []byte("value1")}, + {Key: "key2", Value: []byte("value2")}, + {Key: "key3", Value: []byte("value3")}, + } + + err := cache.Set(ctx, entries, 0) // No TTL + require.NoError(t, err) + + // Get all keys + result, err := cache.Get(ctx, keys) + require.NoError(t, err) + require.Len(t, result, 3) + assert.NotNil(t, result[0]) + assert.Equal(t, "value1", string(result[0].Value)) + assert.NotNil(t, result[1]) + assert.Equal(t, "value2", string(result[1].Value)) + assert.NotNil(t, result[2]) + assert.Equal(t, "value3", string(result[2].Value)) + + // Get partial keys + result, err = cache.Get(ctx, []string{"key2", "key4", "key1"}) + require.NoError(t, err) + require.Len(t, result, 3) + assert.NotNil(t, result[0]) + assert.Equal(t, "value2", string(result[0].Value)) + assert.Nil(t, result[1]) // key4 doesn't exist + assert.NotNil(t, result[2]) + assert.Equal(t, "value1", string(result[2].Value)) + }) + + t.Run("Delete", func(t *testing.T) { + // Set some keys + entries := []*resolve.CacheEntry{ + {Key: "del1", Value: []byte("v1")}, + {Key: "del2", Value: []byte("v2")}, + {Key: "del3", Value: []byte("v3")}, + } + err := cache.Set(ctx, entries, 0) + require.NoError(t, err) + + // Delete some keys + err = cache.Delete(ctx, []string{"del1", "del3"}) + require.NoError(t, err) + + // Check remaining keys + result, err := cache.Get(ctx, []string{"del1", "del2", "del3"}) + require.NoError(t, err) + assert.Nil(t, result[0]) // del1 was deleted + assert.NotNil(t, result[1]) // del2 still exists + assert.Equal(t, "v2", string(result[1].Value)) + assert.Nil(t, result[2]) // del3 was deleted + }) + + t.Run("TTL", func(t *testing.T) { + // Set with 50ms TTL + entries := []*resolve.CacheEntry{ + {Key: "ttl1", Value: []byte("expire1")}, + {Key: 
"ttl2", Value: []byte("expire2")}, + } + err := cache.Set(ctx, entries, 50*time.Millisecond) + require.NoError(t, err) + + // Immediately get - should exist + result, err := cache.Get(ctx, []string{"ttl1", "ttl2"}) + require.NoError(t, err) + assert.NotNil(t, result[0]) + assert.Equal(t, "expire1", string(result[0].Value)) + assert.NotNil(t, result[1]) + assert.Equal(t, "expire2", string(result[1].Value)) + + // Wait for expiration + time.Sleep(60 * time.Millisecond) + + // Get again - should be nil + result, err = cache.Get(ctx, []string{"ttl1", "ttl2"}) + require.NoError(t, err) + assert.Nil(t, result[0]) + assert.Nil(t, result[1]) + }) + + t.Run("MixedTTL", func(t *testing.T) { + // Set some with TTL, some without + err := cache.Set(ctx, []*resolve.CacheEntry{{Key: "perm1", Value: []byte("permanent")}}, 0) + require.NoError(t, err) + + err = cache.Set(ctx, []*resolve.CacheEntry{{Key: "temp1", Value: []byte("temporary")}}, 50*time.Millisecond) + require.NoError(t, err) + + // Wait for temporary to expire + time.Sleep(60 * time.Millisecond) + + // Check both + result, err := cache.Get(ctx, []string{"perm1", "temp1"}) + require.NoError(t, err) + assert.NotNil(t, result[0]) + assert.Equal(t, "permanent", string(result[0].Value)) // Still exists + assert.Nil(t, result[1]) // Expired + }) + + t.Run("ThreadSafety", func(t *testing.T) { + // Test concurrent access + done := make(chan bool) + + // Writer goroutine + go func() { + for i := 0; i < 100; i++ { + key := fmt.Sprintf("concurrent_%d", i) + value := fmt.Sprintf("value_%d", i) + err := cache.Set(ctx, []*resolve.CacheEntry{{Key: key, Value: []byte(value)}}, 0) + assert.NoError(t, err) + } + done <- true + }() + + // Reader goroutine + go func() { + for i := 0; i < 100; i++ { + key := fmt.Sprintf("concurrent_%d", i%50) + _, err := cache.Get(ctx, []string{key}) + assert.NoError(t, err) + } + done <- true + }() + + // Deleter goroutine + go func() { + for i := 0; i < 50; i++ { + key := fmt.Sprintf("concurrent_%d", i*2) + err := cache.Delete(ctx, []string{key}) + assert.NoError(t, err) + } + done <- true + }() + + // Wait for all goroutines + <-done + <-done + <-done + }) + + t.Run("ResultLengthMatchesKeysLength", func(t *testing.T) { + // Test that result length always matches input keys length + + // Set some data + err := cache.Set(ctx, []*resolve.CacheEntry{ + {Key: "exist1", Value: []byte("data1")}, + {Key: "exist3", Value: []byte("data3")}, + }, 0) + require.NoError(t, err) + + // Request mix of existing and non-existing keys + keys := []string{"exist1", "missing1", "exist3", "missing2", "missing3"} + result, err := cache.Get(ctx, keys) + require.NoError(t, err) + + // Verify length matches exactly + assert.Len(t, result, len(keys), "Result length must match keys length") + assert.Len(t, result, 5, "Should return exactly 5 results") + + // Verify correct values + assert.NotNil(t, result[0]) + assert.Equal(t, "data1", string(result[0].Value)) // exist1 + assert.Nil(t, result[1]) // missing1 + assert.NotNil(t, result[2]) + assert.Equal(t, "data3", string(result[2].Value)) // exist3 + assert.Nil(t, result[3]) // missing2 + assert.Nil(t, result[4]) // missing3 + + // Test with all missing keys + allMissingKeys := []string{"missing4", "missing5", "missing6"} + result, err = cache.Get(ctx, allMissingKeys) + require.NoError(t, err) + assert.Len(t, result, 3, "Should return 3 results for 3 keys") + assert.Nil(t, result[0]) + assert.Nil(t, result[1]) + assert.Nil(t, result[2]) + + // Test with empty keys + result, err = cache.Get(ctx, []string{}) 
+ require.NoError(t, err) + assert.Len(t, result, 0, "Should return empty slice for empty keys") + }) +} diff --git a/execution/engine/federation_integration_test.go b/execution/engine/federation_integration_test.go index 1b867bb37..e93231f21 100644 --- a/execution/engine/federation_integration_test.go +++ b/execution/engine/federation_integration_test.go @@ -18,13 +18,37 @@ import ( "github.com/sebdah/goldie/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "github.com/wundergraph/graphql-go-tools/execution/federationtesting" "github.com/wundergraph/graphql-go-tools/execution/federationtesting/gateway" products "github.com/wundergraph/graphql-go-tools/execution/federationtesting/products/graph" ) -func addGateway(enableART bool) func(setup *federationtesting.FederationSetup) *httptest.Server { +type gatewayOptions struct { + enableART bool + withLoaderCache map[string]resolve.LoaderCache +} + +func withEnableART(enableART bool) func(*gatewayOptions) { + return func(opts *gatewayOptions) { + opts.enableART = enableART + } +} + +func withLoaderCache(loaderCache map[string]resolve.LoaderCache) func(*gatewayOptions) { + return func(opts *gatewayOptions) { + opts.withLoaderCache = loaderCache + } +} + +type gatewayOptionsToFunc func(opts *gatewayOptions) + +func addGateway(options ...gatewayOptionsToFunc) func(setup *federationtesting.FederationSetup) *httptest.Server { + opts := &gatewayOptions{} + for _, option := range options { + option(opts) + } return func(setup *federationtesting.FederationSetup) *httptest.Server { httpClient := http.DefaultClient @@ -34,7 +58,7 @@ func addGateway(enableART bool) func(setup *federationtesting.FederationSetup) * {Name: "reviews", URL: setup.ReviewsUpstreamServer.URL}, }, httpClient) - gtw := gateway.Handler(abstractlogger.NoopLogger, poller, httpClient, enableART) + gtw := gateway.Handler(abstractlogger.NoopLogger, poller, httpClient, opts.enableART, opts.withLoaderCache, nil) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() @@ -52,7 +76,7 @@ func TestFederationIntegrationTestWithArt(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - setup := federationtesting.NewFederationSetup(addGateway(true)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(true))) defer setup.Close() gqlClient := NewGraphqlClient(http.DefaultClient) @@ -82,7 +106,7 @@ func TestFederationIntegrationTestWithArt(t *testing.T) { func TestFederationIntegrationTest(t *testing.T) { t.Run("single upstream query operation", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -92,7 +116,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("query spans multiple federated servers", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -102,7 +126,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("mutation operation with variables", func(t *testing.T) { - setup := 
federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -116,7 +140,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("union query", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -126,7 +150,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("interface query", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -136,7 +160,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("subscription query through WebSocket transport", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -155,7 +179,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Multiple queries and nested fragments", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -205,7 +229,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Multiple queries with __typename", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -237,7 +261,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Query that returns union", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -316,7 +340,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Object response type with interface and object fragment", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -335,7 +359,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Interface response type with object fragment", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := 
context.WithCancel(context.Background()) @@ -355,7 +379,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("recursive fragment", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -365,7 +389,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("empty fragment", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -375,7 +399,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("empty fragment variant", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -385,7 +409,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Union response type with interface fragments", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) @@ -429,7 +453,7 @@ func TestFederationIntegrationTest(t *testing.T) { // Duplicated properties (and therefore invalid JSON) are usually removed during normalization processes. // It is not yet decided whether this should be addressed before these normalization processes. 
t.Run("Complex nesting", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) defer setup.Close() gqlClient := NewGraphqlClient(http.DefaultClient) @@ -441,7 +465,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("More complex nesting", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) defer setup.Close() gqlClient := NewGraphqlClient(http.DefaultClient) @@ -453,7 +477,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Multiple nested interfaces", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) defer setup.Close() gqlClient := NewGraphqlClient(http.DefaultClient) @@ -465,7 +489,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Multiple nested unions", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) defer setup.Close() gqlClient := NewGraphqlClient(http.DefaultClient) @@ -477,7 +501,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("More complex nesting typename variant", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) defer setup.Close() gqlClient := NewGraphqlClient(http.DefaultClient) @@ -489,7 +513,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Abstract object", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) defer setup.Close() gqlClient := NewGraphqlClient(http.DefaultClient) @@ -501,7 +525,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Abstract object non shared", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) defer setup.Close() gqlClient := NewGraphqlClient(http.DefaultClient) @@ -513,7 +537,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Abstract object nested", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) defer setup.Close() gqlClient := NewGraphqlClient(http.DefaultClient) @@ -525,7 +549,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Abstract object nested reverse", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) defer setup.Close() gqlClient := NewGraphqlClient(http.DefaultClient) @@ -537,7 +561,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Abstract object mixed", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) defer setup.Close() gqlClient := NewGraphqlClient(http.DefaultClient) @@ -549,7 +573,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Abstract interface field", func(t *testing.T) { - setup := 
federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) defer setup.Close() gqlClient := NewGraphqlClient(http.DefaultClient) @@ -561,7 +585,7 @@ func TestFederationIntegrationTest(t *testing.T) { }) t.Run("Merged fields are still resolved", func(t *testing.T) { - setup := federationtesting.NewFederationSetup(addGateway(false)) + setup := federationtesting.NewFederationSetup(addGateway(withEnableART(false))) t.Cleanup(setup.Close) gqlClient := NewGraphqlClient(http.DefaultClient) ctx, cancel := context.WithCancel(context.Background()) diff --git a/execution/engine/graphql_client_test.go b/execution/engine/graphql_client_test.go index 23ed0c6e3..40b0018ac 100644 --- a/execution/engine/graphql_client_test.go +++ b/execution/engine/graphql_client_test.go @@ -74,6 +74,22 @@ func (g *GraphqlClient) Query(ctx context.Context, addr, queryFilePath string, v return responseBodyBytes } +func (g *GraphqlClient) QueryString(ctx context.Context, addr, query string, variables queryVariables, t *testing.T) []byte { + reqBody := requestBody(t, query, variables) + req, err := http.NewRequest(http.MethodPost, addr, bytes.NewBuffer(reqBody)) + require.NoError(t, err) + req = req.WithContext(ctx) + resp, err := g.httpClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + responseBodyBytes, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Contains(t, resp.Header.Get("Content-Type"), "application/json") + + return responseBodyBytes +} + func (g *GraphqlClient) QueryStatusCode(ctx context.Context, addr, queryFilePath string, variables queryVariables, expectedStatusCode int, t *testing.T) []byte { reqBody := loadQuery(t, queryFilePath, variables) req, err := http.NewRequest(http.MethodPost, addr, bytes.NewBuffer(reqBody)) diff --git a/execution/engine/testdata/complex_nesting_query_with_art.json b/execution/engine/testdata/complex_nesting_query_with_art.json index 69a208fe4..ec85c1e5c 100644 --- a/execution/engine/testdata/complex_nesting_query_with_art.json +++ b/execution/engine/testdata/complex_nesting_query_with_art.json @@ -170,7 +170,7 @@ "duration_since_start_pretty": "1ns", "duration_load_nanoseconds": 1, "duration_load_pretty": "1ns", - "single_flight_used": false, + "single_flight_used": true, "single_flight_shared_response": false, "load_skipped": false, "load_stats": { @@ -310,7 +310,7 @@ "duration_since_start_pretty": "1ns", "duration_load_nanoseconds": 1, "duration_load_pretty": "1ns", - "single_flight_used": false, + "single_flight_used": true, "single_flight_shared_response": false, "load_skipped": false, "load_stats": { @@ -496,7 +496,7 @@ "duration_since_start_pretty": "1ns", "duration_load_nanoseconds": 1, "duration_load_pretty": "1ns", - "single_flight_used": false, + "single_flight_used": true, "single_flight_shared_response": false, "load_skipped": false, "load_stats": { diff --git a/execution/federationtesting/gateway/gateway.go b/execution/federationtesting/gateway/gateway.go index ffc62eb7d..728736a36 100644 --- a/execution/federationtesting/gateway/gateway.go +++ b/execution/federationtesting/gateway/gateway.go @@ -34,11 +34,13 @@ func NewGateway( gqlHandlerFactory HandlerFactory, httpClient *http.Client, logger log.Logger, + loaderCaches map[string]resolve.LoaderCache, ) *Gateway { return &Gateway{ gqlHandlerFactory: gqlHandlerFactory, httpClient: httpClient, logger: logger, + loaderCaches: loaderCaches, mu: &sync.Mutex{}, 
readyCh: make(chan struct{}), @@ -50,6 +52,7 @@ type Gateway struct { gqlHandlerFactory HandlerFactory httpClient *http.Client logger log.Logger + loaderCaches map[string]resolve.LoaderCache gqlHandler http.Handler mu *sync.Mutex @@ -82,6 +85,7 @@ func (g *Gateway) UpdateDataSources(subgraphsConfigs []engine.SubgraphConfigurat executionEngine, err := engine.NewExecutionEngine(ctx, g.logger, engineConfig, resolve.ResolverOptions{ MaxConcurrency: 1024, + Caches: g.loaderCaches, }) if err != nil { g.logger.Error("create engine: %v", log.Error(err)) diff --git a/execution/federationtesting/gateway/http/handler.go b/execution/federationtesting/gateway/http/handler.go index e6d575cd7..2e8983395 100644 --- a/execution/federationtesting/gateway/http/handler.go +++ b/execution/federationtesting/gateway/http/handler.go @@ -5,6 +5,7 @@ import ( "github.com/gobwas/ws" log "github.com/jensneuse/abstractlogger" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "github.com/wundergraph/graphql-go-tools/execution/engine" "github.com/wundergraph/graphql-go-tools/execution/graphql" @@ -20,22 +21,25 @@ func NewGraphqlHTTPHandler( upgrader *ws.HTTPUpgrader, logger log.Logger, enableART bool, + subgraphHeadersBuilder resolve.SubgraphHeadersBuilder, ) http.Handler { return &GraphQLHTTPRequestHandler{ - schema: schema, - engine: engine, - wsUpgrader: upgrader, - log: logger, - enableART: enableART, + schema: schema, + engine: engine, + wsUpgrader: upgrader, + log: logger, + enableART: enableART, + subgraphHeadersBuilder: subgraphHeadersBuilder, } } type GraphQLHTTPRequestHandler struct { - log log.Logger - wsUpgrader *ws.HTTPUpgrader - engine *engine.ExecutionEngine - schema *graphql.Schema - enableART bool + log log.Logger + wsUpgrader *ws.HTTPUpgrader + engine *engine.ExecutionEngine + schema *graphql.Schema + enableART bool + subgraphHeadersBuilder resolve.SubgraphHeadersBuilder } func (g *GraphQLHTTPRequestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { diff --git a/execution/federationtesting/gateway/http/http.go b/execution/federationtesting/gateway/http/http.go index 5a255e01c..0d0a50e3f 100644 --- a/execution/federationtesting/gateway/http/http.go +++ b/execution/federationtesting/gateway/http/http.go @@ -45,6 +45,10 @@ func (g *GraphQLHTTPRequestHandler) handleHTTP(w http.ResponseWriter, r *http.Re opts = append(opts, engine.WithRequestTraceOptions(tracingOpts)) } + if g.subgraphHeadersBuilder != nil { + opts = append(opts, engine.WithSubgraphHeadersBuilder(g.subgraphHeadersBuilder)) + } + buf := bytes.NewBuffer(make([]byte, 0, 4096)) resultWriter := graphql.NewEngineResultWriterFromBuffer(buf) if err = g.engine.Execute(r.Context(), &gqlRequest, &resultWriter, opts...); err != nil { diff --git a/execution/federationtesting/gateway/main.go b/execution/federationtesting/gateway/main.go index 61f97b0a1..dddfb372c 100644 --- a/execution/federationtesting/gateway/main.go +++ b/execution/federationtesting/gateway/main.go @@ -10,6 +10,7 @@ import ( "github.com/wundergraph/graphql-go-tools/execution/engine" http2 "github.com/wundergraph/graphql-go-tools/execution/federationtesting/gateway/http" "github.com/wundergraph/graphql-go-tools/execution/graphql" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) func NewDatasource(serviceConfig []ServiceConfig, httpClient *http.Client) *DatasourcePollerPoller { @@ -24,6 +25,8 @@ func Handler( datasourcePoller *DatasourcePollerPoller, httpClient *http.Client, enableART bool, + loaderCaches map[string]resolve.LoaderCache, + 
subgraphHeadersBuilder resolve.SubgraphHeadersBuilder, ) *Gateway { upgrader := &ws.DefaultHTTPUpgrader upgrader.Header = http.Header{} @@ -32,10 +35,10 @@ func Handler( datasourceWatcher := datasourcePoller var gqlHandlerFactory HandlerFactoryFn = func(schema *graphql.Schema, engine *engine.ExecutionEngine) http.Handler { - return http2.NewGraphqlHTTPHandler(schema, engine, upgrader, logger, enableART) + return http2.NewGraphqlHTTPHandler(schema, engine, upgrader, logger, enableART, subgraphHeadersBuilder) } - gateway := NewGateway(gqlHandlerFactory, httpClient, logger) + gateway := NewGateway(gqlHandlerFactory, httpClient, logger, loaderCaches) datasourceWatcher.Register(gateway) diff --git a/go.work.sum b/go.work.sum index d66f874bf..1aecd8d22 100644 --- a/go.work.sum +++ b/go.work.sum @@ -246,6 +246,8 @@ github.com/twmb/franz-go/pkg/kmsg v1.7.0/go.mod h1:se9Mjdt0Nwzc9lnjJ0HyDtLyBnaBD github.com/vbatts/tar-split v0.12.1 h1:CqKoORW7BUWBe7UL/iqTVvkTBOF8UvOMKOIZykxnnbo= github.com/vbatts/tar-split v0.12.1/go.mod h1:eF6B6i6ftWQcDqEn3/iGFRFRo8cBIMSJVOpnNdfTMFA= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wundergraph/go-arena v0.0.0-20251008210416-55cb97e6f68f h1:5snewyMaIpajTu4wj22L/DgrGimICqXtUVjkZInBH3Y= +github.com/wundergraph/go-arena v0.0.0-20251008210416-55cb97e6f68f/go.mod h1:ROOysEHWJjLQ8FSfNxZCziagb7Qw2nXY3/vgKRh7eWw= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= diff --git a/v2/go.mod b/v2/go.mod index 83fbcc291..43ada453b 100644 --- a/v2/go.mod +++ b/v2/go.mod @@ -28,7 +28,8 @@ require ( github.com/tidwall/gjson v1.17.0 github.com/tidwall/sjson v1.2.5 github.com/vektah/gqlparser/v2 v2.5.30 - github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 + github.com/wundergraph/astjson v1.0.0 + github.com/wundergraph/go-arena v1.0.0 go.uber.org/atomic v1.11.0 go.uber.org/goleak v1.3.0 go.uber.org/zap v1.26.0 diff --git a/v2/go.sum b/v2/go.sum index 690d15a88..6d0fb3636 100644 --- a/v2/go.sum +++ b/v2/go.sum @@ -134,8 +134,10 @@ github.com/urfave/cli/v2 v2.27.7 h1:bH59vdhbjLv3LAvIu6gd0usJHgoTTPhCFib8qqOwXYU= github.com/urfave/cli/v2 v2.27.7/go.mod h1:CyNAG/xg+iAOg0N4MPGZqVmv2rCoP267496AOXUZjA4= github.com/vektah/gqlparser/v2 v2.5.30 h1:EqLwGAFLIzt1wpx1IPpY67DwUujF1OfzgEyDsLrN6kE= github.com/vektah/gqlparser/v2 v2.5.30/go.mod h1:D1/VCZtV3LPnQrcPBeR/q5jkSQIPti0uYCP/RI0gIeo= -github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTBjW+SZK4mhxTTBVpxcqeBgWF1Rfmltbfk= -github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= +github.com/wundergraph/astjson v1.0.0 h1:rETLJuQkMWWW03HCF6WBttEBOu8gi5vznj5KEUPVV2Q= +github.com/wundergraph/astjson v1.0.0/go.mod h1:h12D/dxxnedtLzsKyBLK7/Oe4TAoGpRVC9nDpDrZSWw= +github.com/wundergraph/go-arena v1.0.0 h1:RVYWpDkJ1/6851BRHYehBeEcTLKmZygYIZsvBorcOjw= +github.com/wundergraph/go-arena v1.0.0/go.mod h1:ROOysEHWJjLQ8FSfNxZCziagb7Qw2nXY3/vgKRh7eWw= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAzt5X7s6266i6cSVkkFPS0TuXWbIg= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= diff --git 
a/v2/pkg/astnormalization/uploads/upload_finder.go b/v2/pkg/astnormalization/uploads/upload_finder.go index b69a8bef2..0fd2d44c1 100644 --- a/v2/pkg/astnormalization/uploads/upload_finder.go +++ b/v2/pkg/astnormalization/uploads/upload_finder.go @@ -74,7 +74,7 @@ func (v *UploadFinder) FindUploads(operation, definition *ast.Document, variable variables = []byte("{}") } - v.variables, err = astjson.ParseBytesWithoutCache(variables) + v.variables, err = astjson.ParseBytes(variables) if err != nil { return nil, err } diff --git a/v2/pkg/astvisitor/visitor.go b/v2/pkg/astvisitor/visitor.go index 86f284c72..cae364bce 100644 --- a/v2/pkg/astvisitor/visitor.go +++ b/v2/pkg/astvisitor/visitor.go @@ -5,6 +5,8 @@ import ( "fmt" "sync" + "github.com/wundergraph/go-arena" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/lexer/literal" "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" @@ -95,6 +97,8 @@ type Walker struct { deferred []func() OnExternalError func(err *operationreport.ExternalError) + + arena arena.Arena } func NewWalkerWithID(ancestorSize int, id string) Walker { @@ -140,6 +144,9 @@ func WalkerFromPool() *Walker { } func (w *Walker) Release() { + if w.arena != nil { + w.arena.Reset() + } w.ResetVisitors() w.Report = nil w.document = nil @@ -1391,6 +1398,11 @@ func (w *Walker) Walk(document, definition *ast.Document, report *operationrepor } else { w.Report = report } + if w.arena == nil { + w.arena = arena.NewMonotonicArena(arena.WithMinBufferSize(64)) + } else { + w.arena.Reset() + } w.Ancestors = w.Ancestors[:0] w.Path = w.Path[:0] w.TypeDefinitions = w.TypeDefinitions[:0] @@ -1843,8 +1855,7 @@ func (w *Walker) walkSelectionSet(ref int, skipFor SkipVisitors) { RefsChanged: for { - refs := make([]int, 0, len(w.document.SelectionSets[ref].SelectionRefs)) - refs = append(refs, w.document.SelectionSets[ref].SelectionRefs...) + refs := arena.SliceAppend(w.arena, nil, w.document.SelectionSets[ref].SelectionRefs...) for i, j := range refs { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go index 11c61d9e5..d917c7e8d 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go @@ -14,7 +14,6 @@ import ( "unicode" "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" "github.com/jensneuse/abstractlogger" "github.com/pkg/errors" "github.com/tidwall/sjson" @@ -84,6 +83,11 @@ type Planner[T Configuration] struct { // to the downstream subgraph fetch. 
propagatedOperationName string + // caching + + cacheKeyTemplate resolve.CacheKeyTemplate + rootFields []resolve.QueryField // tracks root fields and their arguments for cache key generation + // federation addedInlineFragments map[onTypeInlineFragment]struct{} @@ -376,6 +380,17 @@ func (p *Planner[T]) ConfigureFetch() resolve.FetchConfiguration { } } + // Set cache key template for non-entity calls (root queries) + if !requiresEntityFetch && !requiresEntityBatchFetch { + if len(p.rootFields) > 0 { + rootFieldsCopy := make([]resolve.QueryField, len(p.rootFields)) + copy(rootFieldsCopy, p.rootFields) + p.cacheKeyTemplate = &resolve.RootQueryCacheKeyTemplate{ + RootFields: rootFieldsCopy, + } + } + } + return resolve.FetchConfiguration{ Input: string(input), DataSource: dataSource, @@ -386,6 +401,9 @@ func (p *Planner[T]) ConfigureFetch() resolve.FetchConfiguration { SetTemplateOutputToNullOnVariableNull: requiresEntityFetch || requiresEntityBatchFetch, QueryPlan: p.queryPlan, OperationName: p.propagatedOperationName, + Caching: resolve.FetchCacheConfiguration{ + CacheKeyTemplate: p.cacheKeyTemplate, + }, } } @@ -716,6 +734,15 @@ func (p *Planner[T]) EnterField(ref int) { } } + // Track all root fields for cache key generation + if p.isRootField() { + coordinate := resolve.GraphCoordinate{ + TypeName: p.visitor.Walker.EnclosingTypeDefinition.NameString(p.visitor.Definition), + FieldName: fieldName, + } + p.trackCacheKeyCoordinate(coordinate) + } + // store root field name and ref if p.rootFieldName == "" { p.rootFieldName = fieldName @@ -731,6 +758,16 @@ func (p *Planner[T]) EnterField(ref int) { p.addFieldArguments(p.addField(ref), ref, fieldConfiguration) } +// isRootField returns false if an ancestor ast.Node is of kind field +func (p *Planner[T]) isRootField() bool { + for i := 0; i < len(p.visitor.Walker.Ancestors); i++ { + if p.visitor.Walker.Ancestors[i].Kind == ast.NodeKindField { + return false + } + } + return true +} + func (p *Planner[T]) addFieldArguments(upstreamFieldRef int, fieldRef int, fieldConfiguration *plan.FieldConfiguration) { if fieldConfiguration != nil { for i := range fieldConfiguration.Arguments { @@ -740,6 +777,44 @@ func (p *Planner[T]) addFieldArguments(upstreamFieldRef int, fieldRef int, field } } +// trackCacheKeyCoordinate ensures a root field is tracked for cache key generation, +// initializing an empty args slice if it doesn't exist yet +func (p *Planner[T]) trackCacheKeyCoordinate(coordinate resolve.GraphCoordinate) { + + // Check if the field is already tracked + for i := range p.rootFields { + if p.rootFields[i].Coordinate.TypeName == coordinate.TypeName && + p.rootFields[i].Coordinate.FieldName == coordinate.FieldName { + // Field already tracked + return + } + } + // Add the field to the slice + p.rootFields = append(p.rootFields, resolve.QueryField{ + Coordinate: coordinate, + }) +} + +// trackFieldWithArgument adds an argument (name + variable) to the field's tracking for cache key generation +func (p *Planner[T]) trackFieldWithArgument(coordinate resolve.GraphCoordinate, argName string, variable resolve.Variable) { + if coordinate.FieldName == "" { + return + } + // Ensure the field is tracked first + p.trackCacheKeyCoordinate(coordinate) + // Find the field and add the argument + for i := range p.rootFields { + if p.rootFields[i].Coordinate.TypeName == coordinate.TypeName && + p.rootFields[i].Coordinate.FieldName == coordinate.FieldName { + p.rootFields[i].Args = append(p.rootFields[i].Args, resolve.FieldArgument{ + Name: argName, + Variable: 
variable, + }) + return + } + } +} + func (p *Planner[T]) addCustomField(ref int) (upstreamFieldRef int) { fieldName, alias := p.handleFieldAlias(ref) fieldNode := p.upstreamOperation.AddField(ast.Field{ @@ -821,6 +896,12 @@ func (p *Planner[T]) EnterDocument(_, _ *ast.Document) { p.addDirectivesToVariableDefinitions = map[int][]int{} p.addedInlineFragments = map[onTypeInlineFragment]struct{}{} + + // reset root fields tracking for cache key generation + for i := 0; i < len(p.rootFields); i++ { + p.rootFields[i].Args = nil + } + p.rootFields = p.rootFields[:0] } func (p *Planner[T]) LeaveDocument(_, _ *ast.Document) { @@ -836,12 +917,16 @@ func (p *Planner[T]) addRepresentationsVariable() { return } - variable, _ := p.variables.AddVariable(p.buildRepresentationsVariable()) + representationsVariable := resolve.NewResolvableObjectVariable(p.buildRepresentationsVariable()) + p.cacheKeyTemplate = &resolve.EntityQueryCacheKeyTemplate{ + Keys: representationsVariable, + } + variable, _ := p.variables.AddVariable(representationsVariable) p.upstreamVariables, _ = sjson.SetRawBytes(p.upstreamVariables, "representations", []byte(fmt.Sprintf("[%s]", variable))) } -func (p *Planner[T]) buildRepresentationsVariable() resolve.Variable { +func (p *Planner[T]) buildRepresentationsVariable() *resolve.Object { objects := make([]*resolve.Object, 0, len(p.dataSourcePlannerConfig.RequiredFields)) for _, cfg := range p.dataSourcePlannerConfig.RequiredFields { node, err := buildRepresentationVariableNode(p.visitor.Definition, cfg, p.dataSourceConfig.FederationConfiguration()) @@ -853,9 +938,7 @@ func (p *Planner[T]) buildRepresentationsVariable() resolve.Variable { objects = append(objects, node) } - return resolve.NewResolvableObjectVariable( - mergeRepresentationVariableNodes(objects), - ) + return mergeRepresentationVariableNodes(objects) } func (p *Planner[T]) addRepresentationsQuery() { @@ -1091,7 +1174,7 @@ func (p *Planner[T]) configureArgument(upstreamFieldRef, downstreamFieldRef int, switch argumentConfiguration.SourceType { case plan.FieldArgumentSource: - p.configureFieldArgumentSource(upstreamFieldRef, downstreamFieldRef, argumentConfiguration) + p.configureFieldArgumentSource(upstreamFieldRef, downstreamFieldRef, fieldConfig, argumentConfiguration) case plan.ObjectFieldSource: p.configureObjectFieldSource(upstreamFieldRef, downstreamFieldRef, fieldConfig, argumentConfiguration) } @@ -1100,7 +1183,7 @@ func (p *Planner[T]) configureArgument(upstreamFieldRef, downstreamFieldRef int, } // configureFieldArgumentSource - creates variables for a plain argument types, in case object or list types goes deep and calls applyInlineFieldArgument -func (p *Planner[T]) configureFieldArgumentSource(upstreamFieldRef, downstreamFieldRef int, argumentConfiguration plan.ArgumentConfiguration) { +func (p *Planner[T]) configureFieldArgumentSource(upstreamFieldRef, downstreamFieldRef int, fieldConfig plan.FieldConfiguration, argumentConfiguration plan.ArgumentConfiguration) { fieldArgument, ok := p.visitor.Operation.FieldArgument(downstreamFieldRef, []byte(argumentConfiguration.Name)) if !ok { return @@ -1122,6 +1205,12 @@ func (p *Planner[T]) configureFieldArgumentSource(upstreamFieldRef, downstreamFi variableValueRef, argRef := p.upstreamOperation.AddVariableValueArgument([]byte(argumentConfiguration.Name), variableName) // add the argument to the field, but don't redefine it p.upstreamOperation.AddArgumentToField(upstreamFieldRef, argRef) + coordinate := resolve.GraphCoordinate{ + TypeName: fieldConfig.TypeName, + 
FieldName: fieldConfig.FieldName, + } + p.trackFieldWithArgument(coordinate, argumentConfiguration.Name, contextVariable) + if exists { // if the variable exists we don't have to put it onto the variables declaration again, skip return } @@ -1273,6 +1362,12 @@ func (p *Planner[T]) configureObjectFieldSource(upstreamFieldRef, downstreamFiel Renderer: resolve.NewJSONVariableRenderer(), } + coordinate := resolve.GraphCoordinate{ + TypeName: fieldConfiguration.TypeName, + FieldName: fieldConfiguration.FieldName, + } + p.trackFieldWithArgument(coordinate, argumentConfiguration.Name, variable) + objectVariableName, exists := p.variables.AddVariable(variable) if !exists { p.upstreamVariables, _ = sjson.SetRawBytes(p.upstreamVariables, string(variableName), []byte(objectVariableName)) @@ -1907,20 +2002,19 @@ func (s *Source) replaceEmptyObject(variables []byte) ([]byte, bool) { return variables, false } -func (s *Source) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { +func (s *Source) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { input = s.compactAndUnNullVariables(input) - return httpclient.DoMultipartForm(s.httpClient, ctx, input, files, out) + return httpclient.DoMultipartForm(s.httpClient, ctx, headers, input, files) } -func (s *Source) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) { +func (s *Source) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { input = s.compactAndUnNullVariables(input) - return httpclient.Do(s.httpClient, ctx, input, out) + return httpclient.Do(s.httpClient, ctx, headers, input) } type GraphQLSubscriptionClient interface { // Subscribe to the origin source. The implementation must not block the calling goroutine. Subscribe(ctx *resolve.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error - UniqueRequestID(ctx *resolve.Context, options GraphQLSubscriptionOptions, hash *xxhash.Digest) (err error) SubscribeAsync(ctx *resolve.Context, id uint64, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error Unsubscribe(id uint64) } @@ -1956,12 +2050,13 @@ type SubscriptionSource struct { client GraphQLSubscriptionClient } -func (s *SubscriptionSource) AsyncStart(ctx *resolve.Context, id uint64, input []byte, updater resolve.SubscriptionUpdater) error { +func (s *SubscriptionSource) AsyncStart(ctx *resolve.Context, id uint64, headers http.Header, input []byte, updater resolve.SubscriptionUpdater) error { var options GraphQLSubscriptionOptions err := json.Unmarshal(input, &options) if err != nil { return err } + options.Header = headers if options.Body.Query == "" { return resolve.ErrUnableToResolve } @@ -1975,12 +2070,13 @@ func (s *SubscriptionSource) AsyncStop(id uint64) { } // Start the subscription. The updater is called on new events. Start needs to be called in a separate goroutine. 
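Stepping back from this file's hunks: the planner changes above introduce two cache key template kinds. Root queries get a resolve.RootQueryCacheKeyTemplate built from the tracked root fields and their arguments, while entity fetches get a resolve.EntityQueryCacheKeyTemplate keyed on the representations variable. A minimal sketch of the deduplicating bookkeeping behind trackCacheKeyCoordinate and trackFieldWithArgument, using hypothetical local stand-ins for the resolve types:

package main

import "fmt"

// Hypothetical stand-ins for resolve.GraphCoordinate, resolve.FieldArgument
// and resolve.QueryField; the real types live in pkg/engine/resolve.
type coordinate struct{ TypeName, FieldName string }

type fieldArgument struct{ Name, Variable string }

type queryField struct {
	Coordinate coordinate
	Args       []fieldArgument
}

type tracker struct{ rootFields []queryField }

// trackCoordinate mirrors trackCacheKeyCoordinate: a root field is added at
// most once, identified by its (TypeName, FieldName) pair.
func (t *tracker) trackCoordinate(c coordinate) {
	for i := range t.rootFields {
		if t.rootFields[i].Coordinate == c {
			return // already tracked
		}
	}
	t.rootFields = append(t.rootFields, queryField{Coordinate: c})
}

// trackArgument mirrors trackFieldWithArgument: ensure the field is tracked,
// then append the argument to the matching entry.
func (t *tracker) trackArgument(c coordinate, name, variable string) {
	t.trackCoordinate(c)
	for i := range t.rootFields {
		if t.rootFields[i].Coordinate == c {
			t.rootFields[i].Args = append(t.rootFields[i].Args, fieldArgument{Name: name, Variable: variable})
			return
		}
	}
}

func main() {
	var t tracker
	t.trackCoordinate(coordinate{"Query", "droid"})
	t.trackArgument(coordinate{"Query", "droid"}, "id", "$id") // no duplicate entry
	fmt.Printf("%+v\n", t.rootFields)
}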
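The hunks just below change SubscriptionSource.Start and AsyncStart, like Source.Load and LoadWithFiles earlier in this file, to accept an explicit http.Header that is set on the subscription options rather than parsed out of the serialized input. A minimal sketch of the new shape, with a hypothetical echoSource standing in for a real data source:

package main

import (
	"context"
	"fmt"
	"net/http"
)

type echoSource struct{}

// Load mirrors the updated signature: headers arrive as an explicit argument,
// and the response bytes are returned instead of written into a caller buffer.
func (echoSource) Load(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
	auth := headers.Get("Authorization")
	return []byte(fmt.Sprintf(`{"echo":%q,"auth":%q}`, input, auth)), nil
}

func main() {
	h := http.Header{}
	h.Set("Authorization", "Bearer token")
	data, err := echoSource{}.Load(context.Background(), h, []byte(`{"query":"{hello}"}`))
	if err != nil {
		panic(err)
	}
	fmt.Println(string(data))
}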
-func (s *SubscriptionSource) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error { +func (s *SubscriptionSource) Start(ctx *resolve.Context, headers http.Header, input []byte, updater resolve.SubscriptionUpdater) error { var options GraphQLSubscriptionOptions err := json.Unmarshal(input, &options) if err != nil { return err } + options.Header = headers if options.Body.Query == "" { return resolve.ErrUnableToResolve } @@ -1990,16 +2086,3 @@ func (s *SubscriptionSource) Start(ctx *resolve.Context, input []byte, updater r var ( dataSouceName = []byte("graphql") ) - -func (s *SubscriptionSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) (err error) { - _, err = xxh.Write(dataSouceName) - if err != nil { - return err - } - var options GraphQLSubscriptionOptions - err = json.Unmarshal(input, &options) - if err != nil { - return err - } - return s.client.UniqueRequestID(ctx, options, xxh) -} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go index 4a84acf92..0926ae9a2 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_federation_test.go @@ -2,6 +2,7 @@ package graphql_datasource import ( "testing" + "time" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" . "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasourcetesting" @@ -1536,9 +1537,10 @@ func TestGraphQLDataSourceFederation(t *testing.T) { query CompositeKeys { user { account { + __typename name shippingInfo { - zip + z: zip } } } @@ -1556,6 +1558,23 @@ func TestGraphQLDataSourceFederation(t *testing.T) { Input: `{"method":"POST","url":"http://user.service","body":{"query":"{user {account {__typename id info {a b}}}}"}}`, DataSource: &Source{}, PostProcessing: DefaultPostProcessingConfiguration, + Caching: resolve.FetchCacheConfiguration{ + Enabled: true, + CacheName: "default", + TTL: 30 * time.Second, + IncludeSubgraphHeaderPrefix: true, + CacheKeyTemplate: &resolve.RootQueryCacheKeyTemplate{ + RootFields: []resolve.QueryField{ + { + Coordinate: resolve.GraphCoordinate{ + TypeName: "Query", + FieldName: "user", + }, + Args: []resolve.FieldArgument{}, + }, + }, + }, + }, }, Info: &resolve.FetchInfo{ DataSourceID: "user.service", @@ -1567,6 +1586,61 @@ func TestGraphQLDataSourceFederation(t *testing.T) { FieldName: "user", }, }, + ProvidesData: &resolve.Object{ + Fields: []*resolve.Field{ + { + Name: []byte("user"), + Value: &resolve.Object{ + Path: []string{"user"}, + Nullable: true, + Fields: []*resolve.Field{ + { + Name: []byte("account"), + Value: &resolve.Object{ + Path: []string{"account"}, + Nullable: true, + Fields: []*resolve.Field{ + { + Name: []byte("__typename"), + Value: &resolve.Scalar{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("id"), + Value: &resolve.Scalar{ + Path: []string{"id"}, + }, + }, + { + Name: []byte("info"), + Value: &resolve.Object{ + Path: []string{"info"}, + Nullable: true, + Fields: []*resolve.Field{ + { + Name: []byte("a"), + Value: &resolve.Scalar{ + Path: []string{"a"}, + }, + }, + { + Name: []byte("b"), + Value: &resolve.Scalar{ + Path: []string{"b"}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, }, }), resolve.SingleWithPath(&resolve.SingleFetch{ @@ -1588,11 +1662,144 @@ func TestGraphQLDataSourceFederation(t *testing.T) { HasAuthorizationRule: true, }, }, 
+ CoordinateDependencies: []resolve.FetchDependency{ + { + Coordinate: resolve.GraphCoordinate{ + TypeName: "Account", + FieldName: "name", + }, + IsUserRequested: true, + DependsOn: []resolve.FetchDependencyOrigin{ + { + FetchID: 0, + Subgraph: "user.service", + Coordinate: resolve.GraphCoordinate{ + TypeName: "Account", + FieldName: "id", + }, + IsKey: true, + IsRequires: false, + }, + { + FetchID: 0, + Subgraph: "user.service", + Coordinate: resolve.GraphCoordinate{ + TypeName: "Account", + FieldName: "info", + }, + IsKey: true, + IsRequires: false, + }, + { + FetchID: 0, + Subgraph: "user.service", + Coordinate: resolve.GraphCoordinate{ + TypeName: "Info", + FieldName: "a", + }, + IsKey: true, + IsRequires: false, + }, + { + FetchID: 0, + Subgraph: "user.service", + Coordinate: resolve.GraphCoordinate{ + TypeName: "Info", + FieldName: "b", + }, + IsKey: true, + IsRequires: false, + }, + }, + }, + { + Coordinate: resolve.GraphCoordinate{ + TypeName: "Account", + FieldName: "shippingInfo", + }, + IsUserRequested: true, + DependsOn: []resolve.FetchDependencyOrigin{ + { + FetchID: 0, + Subgraph: "user.service", + Coordinate: resolve.GraphCoordinate{ + TypeName: "Account", + FieldName: "id", + }, + IsKey: true, + IsRequires: false, + }, + { + FetchID: 0, + Subgraph: "user.service", + Coordinate: resolve.GraphCoordinate{ + TypeName: "Account", + FieldName: "info", + }, + IsKey: true, + IsRequires: false, + }, + { + FetchID: 0, + Subgraph: "user.service", + Coordinate: resolve.GraphCoordinate{ + TypeName: "Info", + FieldName: "a", + }, + IsKey: true, + IsRequires: false, + }, + { + FetchID: 0, + Subgraph: "user.service", + Coordinate: resolve.GraphCoordinate{ + TypeName: "Info", + FieldName: "b", + }, + IsKey: true, + IsRequires: false, + }, + }, + }, + }, OperationType: ast.OperationTypeQuery, + ProvidesData: &resolve.Object{ + Fields: []*resolve.Field{ + { + Name: []byte("__typename"), + Value: &resolve.Scalar{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("name"), + OnTypeNames: [][]byte{[]byte("Account")}, + Value: &resolve.Scalar{ + Path: []string{"name"}, + }, + }, + { + Name: []byte("shippingInfo"), + OnTypeNames: [][]byte{[]byte("Account")}, + Value: &resolve.Object{ + Path: []string{"shippingInfo"}, + Nullable: true, + Fields: []*resolve.Field{ + { + Name: []byte("z"), + Value: &resolve.Scalar{ + Path: []string{"z"}, + }, + }, + }, + }, + }, + }, + }, }, DataSourceIdentifier: []byte("graphql_datasource.Source"), FetchConfiguration: resolve.FetchConfiguration{ - Input: `{"method":"POST","url":"http://account.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Account {__typename name shippingInfo {zip}}}}","variables":{"representations":[$$0$$]}}}`, + Input: `{"method":"POST","url":"http://account.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on Account {__typename name shippingInfo {z: zip}}}}","variables":{"representations":[$$0$$]}}}`, DataSource: &Source{}, SetTemplateOutputToNullOnVariableNull: true, RequiresEntityFetch: true, @@ -1642,6 +1849,55 @@ func TestGraphQLDataSourceFederation(t *testing.T) { }, }, PostProcessing: SingleEntityPostProcessingConfiguration, + Caching: resolve.FetchCacheConfiguration{ + Enabled: true, + CacheName: "default", + TTL: time.Second * 30, + IncludeSubgraphHeaderPrefix: true, + CacheKeyTemplate: &resolve.EntityQueryCacheKeyTemplate{ + Keys: resolve.NewResolvableObjectVariable(&resolve.Object{ + Nullable: true, + Fields: []*resolve.Field{ + { + Name: []byte("__typename"), + OnTypeNames: [][]byte{[]byte("Account")}, + Value: &resolve.String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("id"), + OnTypeNames: [][]byte{[]byte("Account")}, + Value: &resolve.Scalar{ + Path: []string{"id"}, + }, + }, + { + Name: []byte("info"), + OnTypeNames: [][]byte{[]byte("Account")}, + Value: &resolve.Object{ + Path: []string{"info"}, + Nullable: true, + Fields: []*resolve.Field{ + { + Name: []byte("a"), + Value: &resolve.Scalar{ + Path: []string{"a"}, + }, + }, + { + Name: []byte("b"), + Value: &resolve.Scalar{ + Path: []string{"b"}, + }, + }, + }, + }, + }, + }, + }), + }, + }, }, }, "user.account", resolve.ObjectPath("user"), resolve.ObjectPath("account")), ), @@ -1692,6 +1948,23 @@ func TestGraphQLDataSourceFederation(t *testing.T) { TypeName: "Account", SourceName: "user.service", Fields: []*resolve.Field{ + { + Name: []byte("__typename"), + Info: &resolve.FieldInfo{ + Name: "__typename", + NamedType: "String", + ParentTypeNames: []string{"Account"}, + Source: resolve.TypeFieldSource{ + IDs: []string{"user.service"}, + Names: []string{"user.service"}, + }, + ExactParentTypeName: "Account", + }, + Value: &resolve.String{ + Path: []string{"__typename"}, + IsTypeName: true, + }, + }, { Name: []byte("name"), Info: &resolve.FieldInfo{ @@ -1731,7 +2004,7 @@ func TestGraphQLDataSourceFederation(t *testing.T) { SourceName: "account.service", Fields: []*resolve.Field{ { - Name: []byte("zip"), + Name: []byte("z"), Info: &resolve.FieldInfo{ Name: "zip", NamedType: "String", @@ -1743,7 +2016,7 @@ func TestGraphQLDataSourceFederation(t *testing.T) { ExactParentTypeName: "ShippingInfo", }, Value: &resolve.String{ - Path: []string{"zip"}, + Path: []string{"z"}, }, }, }, @@ -1759,7 +2032,7 @@ func TestGraphQLDataSourceFederation(t *testing.T) { }, }, }, - planConfiguration, WithFieldInfo(), WithDefaultPostProcessor())) + planConfiguration, WithFieldInfo(), WithDefaultPostProcessor(), WithFieldDependencies(), WithEntityCaching(), WithFetchProvidesData())) }) t.Run("composite keys variant", func(t *testing.T) { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go index 7f41943c8..81c4266f7 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go @@ -16,7 +16,6 @@ import ( "testing" "time" - "github.com/cespare/xxhash/v2" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -394,6 +393,49 @@ func TestGraphQLDataSource(t *testing.T) { }, ), PostProcessing: DefaultPostProcessingConfiguration, + Caching: resolve.FetchCacheConfiguration{ + Enabled: true, + CacheName: "default", + TTL: 30 * time.Second, + IncludeSubgraphHeaderPrefix: true, + CacheKeyTemplate: 
&resolve.RootQueryCacheKeyTemplate{ + RootFields: []resolve.QueryField{ + { + Coordinate: resolve.GraphCoordinate{ + TypeName: "Query", + FieldName: "droid", + }, + Args: []resolve.FieldArgument{ + { + Name: "id", + Variable: &resolve.ContextVariable{ + Path: []string{"id"}, + Renderer: resolve.NewJSONVariableRenderer(), + }, + }, + }, + }, + { + Coordinate: resolve.GraphCoordinate{ + TypeName: "Query", + FieldName: "hero", + }, + }, + { + Coordinate: resolve.GraphCoordinate{ + TypeName: "Query", + FieldName: "stringList", + }, + }, + { + Coordinate: resolve.GraphCoordinate{ + TypeName: "Query", + FieldName: "nestedStringList", + }, + }, + }, + }, + }, }, Info: &resolve.FetchInfo{ OperationType: ast.OperationTypeQuery, @@ -417,6 +459,100 @@ func TestGraphQLDataSource(t *testing.T) { FieldName: "nestedStringList", }, }, + ProvidesData: &resolve.Object{ + Nullable: false, + Path: []string{}, + Fields: []*resolve.Field{ + { + Name: []byte("droid"), + Value: &resolve.Object{ + Nullable: true, + Path: []string{"droid"}, + Fields: []*resolve.Field{ + { + Name: []byte("name"), + Value: &resolve.Scalar{ + Path: []string{"name"}, + Nullable: false, + }, + }, + { + Name: []byte("aliased"), + Value: &resolve.Scalar{ + Path: []string{"aliased"}, + Nullable: false, + }, + }, + { + Name: []byte("friends"), + Value: &resolve.Array{ + Path: []string{"friends"}, + Nullable: true, + Item: &resolve.Object{ + Nullable: true, + Path: []string{}, + Fields: []*resolve.Field{ + { + Name: []byte("name"), + Value: &resolve.Scalar{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + { + Name: []byte("primaryFunction"), + Value: &resolve.Scalar{ + Path: []string{"primaryFunction"}, + Nullable: false, + }, + }, + }, + }, + }, + { + Name: []byte("hero"), + Value: &resolve.Object{ + Nullable: true, + Path: []string{"hero"}, + Fields: []*resolve.Field{ + { + Name: []byte("name"), + Value: &resolve.Scalar{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + { + Name: []byte("stringList"), + Value: &resolve.Array{ + Path: []string{"stringList"}, + Nullable: true, + Item: &resolve.Scalar{ + Path: []string{}, + Nullable: true, + }, + }, + }, + { + Name: []byte("nestedStringList"), + Value: &resolve.Array{ + Path: []string{"nestedStringList"}, + Nullable: true, + Item: &resolve.Scalar{ + Path: []string{}, + Nullable: true, + }, + }, + }, + }, + }, }, })), Info: &resolve.GraphQLResponseInfo{ @@ -681,7 +817,7 @@ func TestGraphQLDataSource(t *testing.T) { }, }, DisableResolveFieldPositions: true, - }, WithFieldInfo(), WithDefaultPostProcessor())) + }, WithFieldInfo(), WithDefaultPostProcessor(), WithFetchProvidesData(), WithEntityCaching())) t.Run("selections on interface type", RunTest(interfaceSelectionSchema, ` query MyQuery { @@ -4009,6 +4145,8 @@ func TestGraphQLDataSource(t *testing.T) { NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, ctx), }, PostProcessing: DefaultPostProcessingConfiguration, + SourceName: "ds-id", + SourceID: "ds-id", }, Response: &resolve.GraphQLResponse{ Fetches: resolve.Sequence(), @@ -4050,6 +4188,8 @@ func TestGraphQLDataSource(t *testing.T) { client: NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, ctx), }, PostProcessing: DefaultPostProcessingConfiguration, + SourceName: "ds-id", + SourceID: "ds-id", }, Response: &resolve.GraphQLResponse{ Data: &resolve.Object{ @@ -8246,10 +8386,6 @@ func (f *FailingSubscriptionClient) Subscribe(ctx *resolve.Context, options Grap return errSubscriptionClientFail } -func (f 
*FailingSubscriptionClient) UniqueRequestID(ctx *resolve.Context, options GraphQLSubscriptionOptions, hash *xxhash.Digest) (err error) { - return errSubscriptionClientFail -} - type testSubscriptionUpdaterChan struct { updates chan string complete chan struct{} @@ -8441,13 +8577,13 @@ func TestSubscriptionSource_Start(t *testing.T) { t.Run("should return error when input is invalid", func(t *testing.T) { source := SubscriptionSource{client: &FailingSubscriptionClient{}} - err := source.Start(resolve.NewContext(context.Background()), []byte(`{"url": "", "body": "", "header": null}`), nil) + err := source.Start(resolve.NewContext(context.Background()), nil, []byte(`{"url": "", "body": "", "header": null}`), nil) assert.Error(t, err) }) t.Run("should return error when subscription client returns an error", func(t *testing.T) { source := SubscriptionSource{client: &FailingSubscriptionClient{}} - err := source.Start(resolve.NewContext(context.Background()), []byte(`{"url": "", "body": {}, "header": null}`), nil) + err := source.Start(resolve.NewContext(context.Background()), nil, []byte(`{"url": "", "body": {}, "header": null}`), nil) assert.Error(t, err) assert.Equal(t, resolve.ErrUnableToResolve, err) }) @@ -8460,7 +8596,7 @@ func TestSubscriptionSource_Start(t *testing.T) { source := newSubscriptionSource(ctx.Context()) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: "#test") { text createdBy } }"}`) - err := source.Start(ctx, chatSubscriptionOptions, updater) + err := source.Start(ctx, nil, chatSubscriptionOptions, updater) require.ErrorIs(t, err, resolve.ErrUnableToResolve) }) @@ -8472,7 +8608,7 @@ func TestSubscriptionSource_Start(t *testing.T) { source := newSubscriptionSource(ctx.Context()) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomNam: \"#test\") { text createdBy } }"}`) - err := source.Start(ctx, chatSubscriptionOptions, updater) + err := source.Start(ctx, nil, chatSubscriptionOptions, updater) require.NoError(t, err) updater.AwaitUpdates(t, time.Second, 1) assert.Len(t, updater.updates, 1) @@ -8490,7 +8626,7 @@ func TestSubscriptionSource_Start(t *testing.T) { source := newSubscriptionSource(resolverLifecycle) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`) - err := source.Start(resolve.NewContext(subscriptionLifecycle), chatSubscriptionOptions, updater) + err := source.Start(resolve.NewContext(subscriptionLifecycle), nil, chatSubscriptionOptions, updater) require.NoError(t, err) username := "myuser" @@ -8513,7 +8649,7 @@ func TestSubscriptionSource_Start(t *testing.T) { source := newSubscriptionSource(ctx.Context()) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`) - err := source.Start(ctx, chatSubscriptionOptions, updater) + err := source.Start(ctx, nil, chatSubscriptionOptions, updater) require.NoError(t, err) username := "myuser" @@ -8577,7 +8713,7 @@ func TestSubscription_GTWS_SubProtocol(t *testing.T) { source := 
newSubscriptionSource(ctx.Context()) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomNam: \"#test\") { text createdBy } }"}`) - err := source.Start(ctx, chatSubscriptionOptions, updater) + err := source.Start(ctx, nil, chatSubscriptionOptions, updater) require.NoError(t, err) updater.AwaitUpdates(t, time.Second, 1) @@ -8597,7 +8733,7 @@ func TestSubscription_GTWS_SubProtocol(t *testing.T) { source := newSubscriptionSource(resolverLifecycle) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`) - err := source.Start(resolve.NewContext(subscriptionLifecycle), chatSubscriptionOptions, updater) + err := source.Start(resolve.NewContext(subscriptionLifecycle), nil, chatSubscriptionOptions, updater) require.NoError(t, err) username := "myuser" @@ -8621,7 +8757,7 @@ func TestSubscription_GTWS_SubProtocol(t *testing.T) { source := newSubscriptionSource(ctx.Context()) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`) - err := source.Start(ctx, chatSubscriptionOptions, updater) + err := source.Start(ctx, nil, chatSubscriptionOptions, updater) require.NoError(t, err) username := "myuser" @@ -8759,10 +8895,9 @@ func TestSource_Load(t *testing.T) { input = httpclient.SetInputBodyWithPath(input, variables, "variables") input = httpclient.SetInputURL(input, []byte(serverUrl)) - buf := bytes.NewBuffer(nil) - - require.NoError(t, src.Load(context.Background(), input, buf)) - assert.Equal(t, `{"variables":{"a":null,"b":"b","c":{}}}`, buf.String()) + data, err := src.Load(context.Background(), nil, input) + require.NoError(t, err) + assert.Equal(t, `{"variables":{"a":null,"b":"b","c":{}}}`, string(data)) }) }) t.Run("remove undefined variables", func(t *testing.T) { @@ -8775,7 +8910,6 @@ func TestSource_Load(t *testing.T) { var input []byte input = httpclient.SetInputBodyWithPath(input, variables, "variables") input = httpclient.SetInputURL(input, []byte(serverUrl)) - buf := bytes.NewBuffer(nil) undefinedVariables := []string{"a", "c"} ctx := context.Background() @@ -8783,8 +8917,9 @@ func TestSource_Load(t *testing.T) { input, err = httpclient.SetUndefinedVariables(input, undefinedVariables) assert.NoError(t, err) - require.NoError(t, src.Load(ctx, input, buf)) - assert.Equal(t, `{"variables":{"b":null}}`, buf.String()) + data, err := src.Load(ctx, nil, input) + require.NoError(t, err) + assert.Equal(t, `{"variables":{"b":null}}`, string(data)) }) }) } @@ -8866,10 +9001,10 @@ func TestLoadFiles(t *testing.T) { input = httpclient.SetInputBodyWithPath(input, variables, "variables") input = httpclient.SetInputBodyWithPath(input, query, "query") input = httpclient.SetInputURL(input, []byte(serverUrl)) - buf := bytes.NewBuffer(nil) ctx := context.Background() - require.NoError(t, src.LoadWithFiles(ctx, input, []*httpclient.FileUpload{httpclient.NewFileUpload(f.Name(), fileName, "variables.file")}, buf)) + _, err = src.LoadWithFiles(ctx, nil, input, []*httpclient.FileUpload{httpclient.NewFileUpload(f.Name(), fileName, "variables.file")}) + require.NoError(t, err) }) t.Run("multiple files", func(t *testing.T) { @@ -8910,7 +9045,6 @@ func 
TestLoadFiles(t *testing.T) { input = httpclient.SetInputBodyWithPath(input, variables, "variables") input = httpclient.SetInputBodyWithPath(input, query, "query") input = httpclient.SetInputURL(input, []byte(serverUrl)) - buf := bytes.NewBuffer(nil) dir := t.TempDir() f1, err := os.CreateTemp(dir, file1Name) @@ -8924,11 +9058,11 @@ func TestLoadFiles(t *testing.T) { assert.NoError(t, err) ctx := context.Background() - require.NoError(t, src.LoadWithFiles(ctx, input, + _, err = src.LoadWithFiles(ctx, nil, input, []*httpclient.FileUpload{ httpclient.NewFileUpload(f1.Name(), file1Name, "variables.files.0"), - httpclient.NewFileUpload(f2.Name(), file2Name, "variables.files.1")}, - buf)) + httpclient.NewFileUpload(f2.Name(), file2Name, "variables.files.1")}) + require.NoError(t, err) }) } diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index c5a52a476..c8a08df03 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -9,13 +9,9 @@ import ( "errors" "fmt" "io" - "maps" "net" "net/http" "net/http/httptrace" - "net/textproto" - "slices" - "strconv" "strings" "sync" "syscall" @@ -295,27 +291,6 @@ func (c *subscriptionClient) Subscribe(ctx *resolve.Context, options GraphQLSubs return c.subscribeWS(ctx.Context(), c.engineCtx, options, updater) } -var ( - withSSE = []byte(`sse:true`) - withSSEMethodPost = []byte(`sse_method_post:true`) -) - -func (c *subscriptionClient) UniqueRequestID(ctx *resolve.Context, options GraphQLSubscriptionOptions, hash *xxhash.Digest) (err error) { - if options.UseSSE { - _, err = hash.Write(withSSE) - if err != nil { - return err - } - } - if options.SSEMethodPost { - _, err = hash.Write(withSSEMethodPost) - if err != nil { - return err - } - } - return c.requestHash(ctx, options, hash) -} - func (c *subscriptionClient) subscribeSSE(requestContext, engineContext context.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { options.readTimeout = c.readTimeout if c.streamingClient == nil { @@ -409,89 +384,6 @@ func (c *subscriptionClient) asyncSubscribeWS(requestContext, engineContext cont return nil } -// generateHandlerIDHash generates a Hash based on: URL and Headers to uniquely identify Upgrade Requests -func (c *subscriptionClient) requestHash(ctx *resolve.Context, options GraphQLSubscriptionOptions, xxh *xxhash.Digest) (err error) { - if _, err = xxh.WriteString(options.URL); err != nil { - return err - } - if err := options.Header.Write(xxh); err != nil { - return err - } - // Make sure any header that will be forwarded to the subgraph - // is hashed to create the handlerID, this way requests with - // different headers will use separate connections. 
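This requestHash helper is deleted below along with UniqueRequestID, but its one subtle point is worth keeping in mind wherever a deduplication key is computed next: Go map iteration order is randomized, so header names must be sorted before hashing or the digest is not deterministic. A self-contained sketch of that idea, using the same xxhash dependency the removed code used:

package main

import (
	"fmt"
	"maps"
	"net/http"
	"slices"

	"github.com/cespare/xxhash/v2"
)

// hashHeader shows the core trick of the removed requestHash: sort the header
// names first so repeated runs over the same header map produce the same digest.
func hashHeader(h http.Header) uint64 {
	d := xxhash.New()
	for _, k := range slices.Sorted(maps.Keys(h)) {
		_, _ = d.WriteString(k)
		for _, v := range h[k] {
			_, _ = d.WriteString(v)
		}
	}
	return d.Sum64()
}

func main() {
	h := http.Header{"X-User-Id": {"123"}, "X-Role": {"admin"}}
	fmt.Printf("%x\n", hashHeader(h))
}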
- for _, headerName := range options.ForwardedClientHeaderNames { - if _, err = xxh.WriteString(headerName); err != nil { - return err - } - for _, val := range ctx.Request.Header[textproto.CanonicalMIMEHeaderKey(headerName)] { - if _, err = xxh.WriteString(val); err != nil { - return err - } - } - } - - // Sort header names for deterministic hashing since looping through maps - // results in a non-deterministic order of elements - headerKeys := slices.Sorted(maps.Keys(ctx.Request.Header)) - - for _, headerRegexp := range options.ForwardedClientHeaderRegularExpressions { - // Write header pattern - if _, err = xxh.WriteString(headerRegexp.Pattern.String()); err != nil { - return err - } - - // Write negate match - if _, err = xxh.WriteString(strconv.FormatBool(headerRegexp.NegateMatch)); err != nil { - return err - } - - for _, headerName := range headerKeys { - values := ctx.Request.Header[headerName] - result := headerRegexp.Pattern.MatchString(headerName) - if headerRegexp.NegateMatch { - result = !result - } - if result { - for _, val := range values { - if _, err = xxh.WriteString(val); err != nil { - return err - } - } - } - } - } - if len(ctx.InitialPayload) > 0 { - if _, err = xxh.Write(ctx.InitialPayload); err != nil { - return err - } - } - if options.Body.Extensions != nil { - if _, err = xxh.Write(options.Body.Extensions); err != nil { - return err - } - } - if options.Body.Query != "" { - _, err = xxh.WriteString(options.Body.Query) - if err != nil { - return err - } - } - if options.Body.Variables != nil { - _, err = xxh.Write(options.Body.Variables) - if err != nil { - return err - } - } - if options.Body.OperationName != "" { - _, err = xxh.WriteString(options.Body.OperationName) - if err != nil { - return err - } - } - return nil -} - type UpgradeRequestError struct { URL string StatusCode int diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index 279c4bfe8..86dd57c03 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -7,7 +7,6 @@ import ( "fmt" "net/http" "net/http/httptest" - "regexp" "runtime" "strings" "sync" @@ -15,7 +14,6 @@ import ( "time" "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" "github.com/coder/websocket" ll "github.com/jensneuse/abstractlogger" "github.com/stretchr/testify/assert" @@ -2439,7 +2437,7 @@ func TestWebSocketUpgradeFailures(t *testing.T) { w.Header().Set(key, value) } w.WriteHeader(tc.statusCode) - fmt.Fprintf(w, `{"error": "WebSocket upgrade failed", "status": %d}`, tc.statusCode) + _, _ = fmt.Fprintf(w, `{"error": "WebSocket upgrade failed", "status": %d}`, tc.statusCode) })) defer server.Close() @@ -2571,203 +2569,3 @@ func TestInvalidWebSocketAcceptKey(t *testing.T) { }) } } - -func TestRequestHash(t *testing.T) { - t.Parallel() - client := &subscriptionClient{} - - t.Run("basic request with URL and headers", func(t *testing.T) { - t.Parallel() - - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: http.Header{}, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - Header: http.Header{ - "Authorization": []string{"Bearer token"}, - }, - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0xacbca06c541c2a79), hash.Sum64()) - }) - - t.Run("request with 
forwarded client headers", func(t *testing.T) { - t.Parallel() - - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: http.Header{ - "X-User-Id": []string{"123"}, - "X-Role": []string{"admin"}, - }, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - ForwardedClientHeaderNames: []string{"X-User-Id", "X-Role"}, - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0xf428bef25952044c), hash.Sum64()) - }) - - t.Run("request with forwarded client header regex patterns", func(t *testing.T) { - t.Parallel() - - t.Run("with normal", func(t *testing.T) { - header := http.Header{ - "X-Custom-1": []string{"value1"}, - "X-There-2": []string{"value2"}, - "X-Alright-3": []string{"value3"}, - } - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: header, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - ForwardedClientHeaderRegularExpressions: []RegularExpression{ - { - Pattern: regexp.MustCompile("^X-Custom-.*$"), - NegateMatch: false, - }, - }, - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0xb1557904bfa9d86a), hash.Sum64()) - }) - - t.Run("with negative", func(t *testing.T) { - t.Parallel() - - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: http.Header{ - "X-Custom-1": []string{"valueThere1"}, - "X-Custom-2": []string{"valueThere2"}, - }, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - ForwardedClientHeaderRegularExpressions: []RegularExpression{ - { - Pattern: regexp.MustCompile("^X-Custom-2"), - NegateMatch: true, - }, - }, - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0x5888642db454ccab), hash.Sum64()) - }) - - t.Run("with multiple tries to ensure the hash is idempotent", func(t *testing.T) { - for range 100 { - header := http.Header{ - "X-Custom-1": []string{"a1"}, - "X-There-2": []string{"a2"}, - "X-Custom-6": []string{"a3"}, - "X-Alright-3": []string{"a4"}, - "X-Custom-5": []string{"a5"}, - } - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: header, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - ForwardedClientHeaderRegularExpressions: []RegularExpression{ - { - Pattern: regexp.MustCompile("^X-Custom-.*$"), - NegateMatch: false, - }, - }, - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0x6c9c1099adab987d), hash.Sum64()) - } - }) - }) - - t.Run("request with initial payload", func(t *testing.T) { - t.Parallel() - - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: http.Header{}, - }, - InitialPayload: []byte(`{"auth": "token"}`), - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0x3c5af329478bfcce), hash.Sum64()) - - }) - - t.Run("request with body components", func(t *testing.T) { - t.Parallel() - - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: http.Header{}, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - Body: GraphQLBody{ - Query: "query { hello }", - Variables: []byte(`{"var": "value"}`), - OperationName: "HelloQuery", - Extensions: []byte(`{"ext": 
"value"}`), - }, - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0xd8d5588c8a466cf2), hash.Sum64()) - }) - - t.Run("empty components", func(t *testing.T) { - t.Parallel() - - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: http.Header{}, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0x767db2231989769), hash.Sum64()) - }) - -} diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index 8a196cbc6..6cbc4ca12 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -7,10 +7,12 @@ package grpcdatasource import ( - "bytes" "context" - "errors" + "encoding/binary" + "fmt" + "net/http" + "github.com/cespare/xxhash/v2" "github.com/tidwall/gjson" "golang.org/x/sync/errgroup" "google.golang.org/grpc" @@ -44,6 +46,8 @@ type DataSource struct { mapping *GRPCMapping federationConfigs plan.FederationFieldConfigurations disabled bool + + pool *resolve.ArenaPool } type ProtoConfig struct { @@ -79,28 +83,36 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D mapping: config.Mapping, federationConfigs: config.FederationConfigs, disabled: config.Disabled, + pool: resolve.NewArenaPool(), }, nil } // Load implements resolve.DataSource interface. -// It processes the input JSON data to make gRPC calls and writes -// the response to the output buffer. +// It processes the input JSON data to make gRPC calls and returns +// the response data. // // The input is expected to contain the necessary information to make // a gRPC call, including service name, method name, and request data. 
-func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) { +func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { // get variables from input variables := gjson.Parse(unsafebytes.BytesToString(input)).Get("body.variables") - builder := newJSONBuilder(d.mapping, variables) + + var ( + poolItems []*resolve.ArenaPoolItem + ) + defer func() { + d.pool.ReleaseMany(poolItems) + }() + + item := d.acquirePoolItem(input, 0) + poolItems = append(poolItems, item) + builder := newJSONBuilder(item.Arena, d.mapping, variables) if d.disabled { - out.Write(builder.writeErrorBytes(errors.New("gRPC datasource needs to be enabled to be used"))) - return nil + return builder.writeErrorBytes(fmt.Errorf("gRPC datasource needs to be enabled to be used")), nil } - arena := astjson.Arena{} - defer arena.Reset() - root := arena.NewObject() + root := astjson.ObjectValue(nil) failed := false @@ -115,8 +127,10 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) // make gRPC calls for index, serviceCall := range serviceCalls { + item := d.acquirePoolItem(input, index) + poolItems = append(poolItems, item) + builder := newJSONBuilder(item.Arena, d.mapping, variables) errGrp.Go(func() error { - a := astjson.Arena{} // Invoke the gRPC method - this will populate serviceCall.Output err := d.cc.Invoke(errGrpCtx, serviceCall.MethodFullName(), serviceCall.Input, serviceCall.Output) @@ -124,7 +138,7 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) return err } - response, err := builder.marshalResponseJSON(&a, &serviceCall.RPC.Response, serviceCall.Output) + response, err := builder.marshalResponseJSON(&serviceCall.RPC.Response, serviceCall.Output) if err != nil { return err } @@ -149,7 +163,7 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) } if err := errGrp.Wait(); err != nil { - out.Write(builder.writeErrorBytes(err)) + data = builder.writeErrorBytes(err) failed = true return nil } @@ -162,19 +176,29 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) root, err = builder.mergeValues(root, result.response) } if err != nil { - out.Write(builder.writeErrorBytes(err)) + data = builder.writeErrorBytes(err) return err } } return nil }); err != nil || failed { - return err + return data, err } - data := builder.toDataObject(root) - out.Write(data.MarshalTo(nil)) - return nil + value := builder.toDataObject(root) + return value.MarshalTo(nil), err +} + +func (d *DataSource) acquirePoolItem(input []byte, index int) *resolve.ArenaPoolItem { + keyGen := xxhash.New() + _, _ = keyGen.Write(input) + var b [8]byte + binary.LittleEndian.PutUint64(b[:], uint64(index)) + _, _ = keyGen.Write(b[:]) + key := keyGen.Sum64() + item := d.pool.Acquire(key) + return item } // LoadWithFiles implements resolve.DataSource interface. @@ -184,6 +208,6 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) // might not be applicable for most gRPC use cases. // // Currently unimplemented. 
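The reworked gRPC Load above returns the response bytes directly and draws per-call scratch memory from an arena pool; acquirePoolItem derives the pool key from the raw input plus the service-call index, so each concurrent call gets its own arena slot. A sketch of just that key derivation, mirroring the code above:

package main

import (
	"encoding/binary"
	"fmt"

	"github.com/cespare/xxhash/v2"
)

// arenaKey mirrors acquirePoolItem: hash the input bytes, then append the
// little-endian call index so parallel service calls map to distinct keys.
func arenaKey(input []byte, index int) uint64 {
	d := xxhash.New()
	_, _ = d.Write(input)
	var b [8]byte
	binary.LittleEndian.PutUint64(b[:], uint64(index))
	_, _ = d.Write(b[:])
	return d.Sum64()
}

func main() {
	in := []byte(`{"query":"{ hello }"}`)
	fmt.Println(arenaKey(in, 0), arenaKey(in, 1)) // distinct keys per call index
}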
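Across this diff (the astvisitor Walker, the gRPC datasource, the JSON builder), per-request allocations move onto monotonic arenas that are reset in bulk rather than freed piecemeal. A minimal sketch, assuming the go-arena calls exactly as they appear in these hunks:

package main

import (
	"fmt"

	"github.com/wundergraph/go-arena"
)

func main() {
	// Same constructor and option as Walker.Walk uses.
	a := arena.NewMonotonicArena(arena.WithMinBufferSize(64))

	// Arena-backed copy of a slice, replacing a make+append pair per walk,
	// as walkSelectionSet now does.
	refs := arena.SliceAppend(a, nil, 1, 2, 3)
	fmt.Println(refs)

	// One Reset recycles everything the arena handed out, as Walker.Release does.
	a.Reset()
}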
-func (d *DataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { +func (d *DataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("unimplemented") } diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index 8191b5b08..9a427809a 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -1,7 +1,6 @@ package grpcdatasource import ( - "bytes" "context" "encoding/json" "fmt" @@ -19,8 +18,6 @@ import ( protoref "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/dynamicpb" - "github.com/wundergraph/astjson" - "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" "github.com/wundergraph/graphql-go-tools/v2/pkg/grpctest" @@ -57,8 +54,7 @@ func Benchmark_DataSource_Load(b *testing.B) { b.ReportAllocs() b.ResetTimer() for b.Loop() { - output := new(bytes.Buffer) - err = ds.Load(context.Background(), []byte(`{"query":"`+query+`","body":`+variables+`}`), output) + _, err = ds.Load(context.Background(), nil, []byte(`{"query":"`+query+`","body":`+variables+`}`)) require.NoError(b, err) } } @@ -96,7 +92,7 @@ func Benchmark_DataSource_Load_WithFieldArguments(b *testing.B) { }) require.NoError(b, err) - err = ds.Load(context.Background(), []byte(`{"query":"`+query+`","body":`+variables+`}`), new(bytes.Buffer)) + _, err = ds.Load(context.Background(), nil, []byte(`{"query":"`+query+`","body":`+variables+`}`)) require.NoError(b, err) } } @@ -223,12 +219,10 @@ func Test_DataSource_Load(t *testing.T) { require.NoError(t, err) - output := new(bytes.Buffer) - - err = ds.Load(context.Background(), []byte(`{"query":"`+query+`","variables":`+variables+`}`), output) + output, err := ds.Load(context.Background(), nil, []byte(`{"query":"`+query+`","variables":`+variables+`}`)) require.NoError(t, err) - fmt.Println(output.String()) + fmt.Println(string(output)) } // Test_DataSource_Load_WithMockService tests the datasource.Load method with an actual gRPC server @@ -296,12 +290,11 @@ func Test_DataSource_Load_WithMockService(t *testing.T) { require.NoError(t, err) // 3. Execute the query through our datasource - output := new(bytes.Buffer) - err = ds.Load(context.Background(), []byte(`{"query":"`+query+`","body":`+variables+`}`), output) + output, err := ds.Load(context.Background(), nil, []byte(`{"query":"`+query+`","body":`+variables+`}`)) require.NoError(t, err) // Print the response for debugging - // fmt.Println(output.String()) + // fmt.Println(string(output)) type response struct { Data struct { @@ -314,7 +307,7 @@ func Test_DataSource_Load_WithMockService(t *testing.T) { var resp response - bytes := output.Bytes() + bytes := output fmt.Println(string(bytes)) err = json.Unmarshal(bytes, &resp) @@ -386,12 +379,10 @@ func Test_DataSource_Load_WithMockService_WithResponseMapping(t *testing.T) { require.NoError(t, err) // 3. 
Execute the query through our datasource - output := new(bytes.Buffer) - // Format the input with query and variables inputJSON := fmt.Sprintf(`{"query":%q,"body":%s}`, query, variables) - err = ds.Load(context.Background(), []byte(inputJSON), output) + output, err := ds.Load(context.Background(), nil, []byte(inputJSON)) require.NoError(t, err) // Set up the correct response structure based on your GraphQL schema @@ -408,7 +399,7 @@ func Test_DataSource_Load_WithMockService_WithResponseMapping(t *testing.T) { } var resp response - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") // Check if there are any errors in the response @@ -483,11 +474,10 @@ func Test_DataSource_Load_WithGrpcError(t *testing.T) { require.NoError(t, err) // 4. Execute the query - output := new(bytes.Buffer) - err = ds.Load(context.Background(), []byte(`{"query":"`+query+`","body":`+variables+`}`), output) + output, err := ds.Load(context.Background(), nil, []byte(`{"query":"`+query+`","body":`+variables+`}`)) require.NoError(t, err, "Load should not return an error even when the gRPC call fails") - responseJson := output.String() + responseJson := string(output) // 5. Verify the response format according to GraphQL specification // The response should have an "errors" array with the error message @@ -501,7 +491,7 @@ func Test_DataSource_Load_WithGrpcError(t *testing.T) { } `json:"errors"` } - err = json.Unmarshal(output.Bytes(), &response) + err = json.Unmarshal(output, &response) require.NoError(t, err, "Failed to parse response JSON") // Verify there's at least one error @@ -573,9 +563,8 @@ func TestMarshalResponseJSON(t *testing.T) { responseMessage := dynamicpb.NewMessage(responseMessageDesc) responseMessage.Mutable(responseMessageDesc.Fields().ByName("result")).List().Append(protoref.ValueOfMessage(productMessage)) - arena := astjson.Arena{} - jsonBuilder := newJSONBuilder(nil, gjson.Result{}) - responseJSON, err := jsonBuilder.marshalResponseJSON(&arena, &response, responseMessage) + jsonBuilder := newJSONBuilder(nil, nil, gjson.Result{}) + responseJSON, err := jsonBuilder.marshalResponseJSON(&response, responseMessage) require.NoError(t, err) require.Equal(t, `{"_entities":[{"__typename":"Product","id":"123","name_different":"test","price_different":123.45}]}`, responseJSON.String()) } @@ -810,9 +799,8 @@ func Test_DataSource_Load_WithAnimalInterface(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -823,7 +811,7 @@ func Test_DataSource_Load_WithAnimalInterface(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") require.NotEmpty(t, resp.Data, "Response should contain data") @@ -1081,9 +1069,8 @@ func Test_Datasource_Load_WithUnionTypes(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), nil, 
[]byte(input)) require.NoError(t, err) // Parse the response @@ -1094,7 +1081,7 @@ func Test_Datasource_Load_WithUnionTypes(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") require.NotEmpty(t, resp.Data, "Response should contain data") @@ -1218,9 +1205,8 @@ func Test_DataSource_Load_WithCategoryQueries(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -1231,7 +1217,7 @@ func Test_DataSource_Load_WithCategoryQueries(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") require.NotEmpty(t, resp.Data, "Response should contain data") @@ -1299,9 +1285,8 @@ func Test_DataSource_Load_WithTotalCalculation(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, query, variables) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -1323,7 +1308,7 @@ func Test_DataSource_Load_WithTotalCalculation(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") @@ -1390,9 +1375,8 @@ func Test_DataSource_Load_WithTypename(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":{}}`, query) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -1409,7 +1393,7 @@ func Test_DataSource_Load_WithTypename(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") @@ -1860,9 +1844,8 @@ func Test_DataSource_Load_WithAliases(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -1873,7 +1856,7 @@ func Test_DataSource_Load_WithAliases(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") require.NotEmpty(t, resp.Data, "Response should contain data") @@ -2239,9 +2222,8 @@ func 
Test_DataSource_Load_WithNullableFieldsType(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -2252,7 +2234,7 @@ func Test_DataSource_Load_WithNullableFieldsType(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") require.NotEmpty(t, resp.Data, "Response should contain data") @@ -3541,9 +3523,8 @@ func Test_DataSource_Load_WithNestedLists(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -3554,7 +3535,7 @@ func Test_DataSource_Load_WithNestedLists(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") require.NotEmpty(t, resp.Data, "Response should contain data") @@ -3741,15 +3722,14 @@ func Test_DataSource_Load_WithEntity_Calls(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response var resp graphqlResponse - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") tc.validate(t, resp.Data) @@ -3928,15 +3908,14 @@ func Test_Datasource_Load_WithFieldResolvers(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response var resp graphqlResponse - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") tc.validate(t, resp.Data) diff --git a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go index 6e521b041..0b2edc07c 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go +++ b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go @@ -11,6 +11,7 @@ import ( protoref "google.golang.org/protobuf/reflect/protoreflect" "github.com/wundergraph/astjson" + "github.com/wundergraph/go-arena" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" ) @@ -107,16 +108,18 @@ type jsonBuilder struct { mapping *GRPCMapping // Mapping configuration for GraphQL to gRPC translation variables gjson.Result // GraphQL variables containing entity representations indexMap indexMap // Entity index 
mapping for federation ordering + jsonArena arena.Arena } // newJSONBuilder creates a new JSON builder instance with the provided mapping // and variables. The builder automatically creates an index map for proper // federation entity ordering if representations are present in the variables. -func newJSONBuilder(mapping *GRPCMapping, variables gjson.Result) *jsonBuilder { +func newJSONBuilder(a arena.Arena, mapping *GRPCMapping, variables gjson.Result) *jsonBuilder { return &jsonBuilder{ mapping: mapping, variables: variables, indexMap: createRepresentationIndexMap(variables), + jsonArena: a, } } @@ -163,7 +166,7 @@ func (j *jsonBuilder) mergeValues(left *astjson.Value, right *astjson.Value) (*a if len(j.indexMap) == 0 { // No federation index map available - use simple merge // This path is taken for non-federated queries - root, _, err := astjson.MergeValues(left, right) + root, _, err := astjson.MergeValues(j.jsonArena, left, right) if err != nil { return nil, err } @@ -189,11 +192,10 @@ func (j *jsonBuilder) mergeValues(left *astjson.Value, right *astjson.Value) (*a // This function ensures that entities are placed in the correct positions in the final response // array based on their original representation order, which is critical for GraphQL federation. func (j *jsonBuilder) mergeEntities(left *astjson.Value, right *astjson.Value) (*astjson.Value, error) { - arena := astjson.Arena{} // Create the response structure with _entities array - entities := arena.NewObject() - entities.Set(entityPath, arena.NewArray()) + entities := astjson.ObjectValue(j.jsonArena) + entities.Set(j.jsonArena, entityPath, astjson.ArrayValue(j.jsonArena)) arr := entities.Get(entityPath) // Extract entity arrays from both responses @@ -209,12 +211,12 @@ func (j *jsonBuilder) mergeEntities(left *astjson.Value, right *astjson.Value) ( // Merge left entities using index mapping to preserve order for index, lr := range leftRepresentations { - arr.SetArrayItem(j.indexMap.getResultIndex(lr, index), lr) + arr.SetArrayItem(j.jsonArena, j.indexMap.getResultIndex(lr, index), lr) } // Merge right entities using index mapping to preserve order for index, rr := range rightRepresentations { - arr.SetArrayItem(j.indexMap.getResultIndex(rr, index), rr) + arr.SetArrayItem(j.jsonArena, j.indexMap.getResultIndex(rr, index), rr) } return entities, nil @@ -257,7 +259,7 @@ func (j *jsonBuilder) mergeWithPath(base *astjson.Value, resolved *astjson.Value } for i := range responseValues { - responseValues[i].Set(elementName, resolvedValues[i].Get(elementName)) + responseValues[i].Set(j.jsonArena, elementName, resolvedValues[i].Get(elementName)) } return nil @@ -315,12 +317,12 @@ func (j *jsonBuilder) flattenList(items []*astjson.Value, path ast.Path) ([]*ast // marshalResponseJSON converts a protobuf message into a GraphQL-compatible JSON response. // This is the core marshaling function that handles all the complex type conversions, // including oneOf types, nested messages, lists, and scalar values. 
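The hunks that follow all route allocations through this shared arena instead of per-call astjson.Arena instances. A condensed, illustrative sketch of the arena-style astjson calls used in this file (buildSketch itself is not part of the diff):

func buildSketch(a arena.Arena) (*astjson.Value, error) {
	// Every node allocation takes the arena explicitly.
	root := astjson.ObjectValue(a)
	root.Set(a, "name", astjson.StringValue(a, "Trilby"))
	ids := astjson.ArrayValue(a)
	ids.SetArrayItem(a, 0, astjson.IntValue(a, 1))
	root.Set(a, "ids", ids)
	patch := astjson.ObjectValue(a)
	patch.Set(a, "inStock", astjson.TrueValue(a))
	// MergeValues now takes the arena as its first argument as well.
	merged, _, err := astjson.MergeValues(a, root, patch)
	return merged, err
}
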
-func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMessage, data protoref.Message) (*astjson.Value, error) { +func (j *jsonBuilder) marshalResponseJSON(message *RPCMessage, data protoref.Message) (*astjson.Value, error) { if message == nil { - return arena.NewNull(), nil + return astjson.NullValue, nil } - root := arena.NewObject() + root := astjson.ObjectValue(j.jsonArena) // Handle protobuf oneOf types - these represent GraphQL union/interface types if message.IsOneOf() { @@ -354,14 +356,14 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess if field.StaticValue != "" { if len(message.MemberTypes) == 0 { // Simple static value - use as-is - root.Set(field.AliasOrPath(), arena.NewString(field.StaticValue)) + root.Set(j.jsonArena, field.AliasOrPath(), astjson.StringValue(j.jsonArena, field.StaticValue)) continue } // Type-specific static value - match against member types for _, memberTypes := range message.MemberTypes { if memberTypes == string(data.Type().Descriptor().Name()) { - root.Set(field.AliasOrPath(), arena.NewString(memberTypes)) + root.Set(j.jsonArena, field.AliasOrPath(), astjson.StringValue(j.jsonArena, memberTypes)) break } } @@ -379,8 +381,8 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess // Handle list fields (repeated in protobuf) if fd.IsList() { list := data.Get(fd).List() - arr := arena.NewArray() - root.Set(field.AliasOrPath(), arr) + arr := astjson.ArrayValue(j.jsonArena) + root.Set(j.jsonArena, field.AliasOrPath(), arr) if !list.IsValid() { // Invalid list - leave as empty array @@ -393,15 +395,15 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess case protoref.MessageKind: // List of messages - recursively marshal each message message := list.Get(i).Message() - value, err := j.marshalResponseJSON(arena, field.Message, message) + value, err := j.marshalResponseJSON(field.Message, message) if err != nil { return nil, err } - arr.SetArrayItem(i, value) + arr.SetArrayItem(j.jsonArena, i, value) default: // List of scalar values - convert directly - j.setArrayItem(i, arena, arr, list.Get(i), fd) + j.setArrayItem(i, arr, list.Get(i), fd) } } @@ -413,24 +415,24 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess msg := data.Get(fd).Message() if !msg.IsValid() { // Invalid message - set to null - root.Set(field.AliasOrPath(), arena.NewNull()) + root.Set(j.jsonArena, field.AliasOrPath(), astjson.NullValue) continue } // Handle special list wrapper types for complex nested lists if field.IsListType { - arr, err := j.flattenListStructure(arena, field.ListMetadata, msg, field.Message) + arr, err := j.flattenListStructure(field.ListMetadata, msg, field.Message) if err != nil { return nil, fmt.Errorf("unable to flatten list structure for field %q: %w", field.AliasOrPath(), err) } - root.Set(field.AliasOrPath(), arr) + root.Set(j.jsonArena, field.AliasOrPath(), arr) continue } // Handle optional scalar wrapper types (e.g., google.protobuf.StringValue) if field.IsOptionalScalar() { - err := j.resolveOptionalField(arena, root, field.AliasOrPath(), msg) + err := j.resolveOptionalField(root, field.AliasOrPath(), msg) if err != nil { return nil, err } @@ -439,27 +441,27 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess } // Regular nested message - recursively marshal - value, err := j.marshalResponseJSON(arena, field.Message, msg) + value, err := j.marshalResponseJSON(field.Message, msg) if err 
!= nil { return nil, err } if field.JSONPath == "" { // Field should be merged into parent object (flattened) - root, _, err = astjson.MergeValues(root, value) + root, _, err = astjson.MergeValues(j.jsonArena, root, value) if err != nil { return nil, err } } else { // Field should be nested under its own key - root.Set(field.AliasOrPath(), value) + root.Set(j.jsonArena, field.AliasOrPath(), value) } continue } // Handle scalar fields (string, int, bool, etc.) - j.setJSONValue(arena, root, field.AliasOrPath(), data, fd) + j.setJSONValue(root, field.AliasOrPath(), data, fd) } return root, nil @@ -469,34 +471,34 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess // messages to support nullable and multi-dimensional lists. This is necessary because // protobuf doesn't directly support nullable list items or complex nesting scenarios // that GraphQL allows. -func (j *jsonBuilder) flattenListStructure(arena *astjson.Arena, md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { +func (j *jsonBuilder) flattenListStructure(md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { if md == nil { - return arena.NewNull(), errors.New("list metadata not found") + return astjson.NullValue, errors.New("list metadata not found") } // Validate metadata consistency if len(md.LevelInfo) < md.NestingLevel { - return arena.NewNull(), errors.New("nesting level data does not match the number of levels in the list metadata") + return astjson.NullValue, errors.New("nesting level data does not match the number of levels in the list metadata") } // Handle null data with proper nullability checking if !data.IsValid() { if md.LevelInfo[0].Optional { - return arena.NewNull(), nil + return astjson.NullValue, nil } - return arena.NewNull(), errors.New("cannot add null item to response for non nullable list") + return astjson.NullValue, errors.New("cannot add null item to response for non nullable list") } // Start recursive traversal of the nested list structure - root := arena.NewArray() - return j.traverseList(0, arena, root, md, data, message) + root := astjson.ArrayValue(j.jsonArena) + return j.traverseList(0, root, md, data, message) } // traverseList recursively traverses nested list wrapper structures to extract the actual // list data. This handles multi-dimensional lists like [[String]] or [[[User]]] by // unwrapping the protobuf message wrappers at each level. 
-func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *astjson.Value, md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { +func (j *jsonBuilder) traverseList(level int, current *astjson.Value, md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { if level > md.NestingLevel { return current, nil } @@ -504,11 +506,11 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast // List wrappers always use field number 1 in the generated protobuf fd := data.Descriptor().Fields().ByNumber(1) if fd == nil { - return arena.NewNull(), fmt.Errorf("field with number %d not found in message %q", 1, data.Descriptor().Name()) + return astjson.NullValue, fmt.Errorf("field with number %d not found in message %q", 1, data.Descriptor().Name()) } if fd.Kind() != protoref.MessageKind { - return arena.NewNull(), fmt.Errorf("field %q is not a message", fd.Name()) + return astjson.NullValue, fmt.Errorf("field %q is not a message", fd.Name()) } // Get the wrapper message containing the list @@ -516,16 +518,16 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast if !msg.IsValid() { // Handle null wrapper based on nullability rules if md.LevelInfo[level].Optional { - return arena.NewNull(), nil + return astjson.NullValue, nil } - return arena.NewArray(), errors.New("cannot add null item to response for non nullable list") + return astjson.ArrayValue(j.jsonArena), errors.New("cannot add null item to response for non nullable list") } // The actual list is always at field number 1 in the wrapper fd = msg.Descriptor().Fields().ByNumber(1) if !fd.IsList() { - return arena.NewNull(), fmt.Errorf("field %q is not a list", fd.Name()) + return astjson.NullValue, fmt.Errorf("field %q is not a list", fd.Name()) } // Handle intermediate nesting levels (not the final level) @@ -533,13 +535,13 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast list := msg.Get(fd).List() for i := 0; i < list.Len(); i++ { // Create nested array for next level - next := arena.NewArray() - val, err := j.traverseList(level+1, arena, next, md, list.Get(i).Message(), message) + next := astjson.ArrayValue(j.jsonArena) + val, err := j.traverseList(level+1, next, md, list.Get(i).Message(), message) if err != nil { return nil, err } - current.SetArrayItem(i, val) + current.SetArrayItem(j.jsonArena, i, val) } return current, nil @@ -550,22 +552,22 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast if !list.IsValid() { // Invalid list at final level - return empty array // Nullability is checked at the wrapper level, not the list level - return arena.NewArray(), nil + return astjson.ArrayValue(j.jsonArena), nil } // Process each item in the final list for i := 0; i < list.Len(); i++ { if message != nil { // List of complex objects - recursively marshal each item - val, err := j.marshalResponseJSON(arena, message, list.Get(i).Message()) + val, err := j.marshalResponseJSON(message, list.Get(i).Message()) if err != nil { return nil, err } - current.SetArrayItem(i, val) + current.SetArrayItem(j.jsonArena, i, val) } else { // List of scalar values - convert directly - j.setArrayItem(i, arena, current, list.Get(i), fd) + j.setArrayItem(i, current, list.Get(i), fd) } } @@ -575,7 +577,7 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast // resolveOptionalField extracts the value from optional scalar wrapper types like // 
google.protobuf.StringValue, google.protobuf.Int32Value, etc. These wrappers // are used to represent nullable scalar values in protobuf. -func (j *jsonBuilder) resolveOptionalField(arena *astjson.Arena, root *astjson.Value, name string, data protoref.Message) error { +func (j *jsonBuilder) resolveOptionalField(root *astjson.Value, name string, data protoref.Message) error { // Optional scalar wrappers always have a "value" field fd := data.Descriptor().Fields().ByName(protoref.Name("value")) if fd == nil { @@ -583,16 +585,16 @@ func (j *jsonBuilder) resolveOptionalField(arena *astjson.Arena, root *astjson.V } // Extract and set the wrapped value - j.setJSONValue(arena, root, name, data, fd) + j.setJSONValue(root, name, data, fd) return nil } // setJSONValue converts a protobuf field value to the appropriate JSON representation // and sets it on the provided JSON object. This handles all protobuf scalar types // and enum values with proper GraphQL mapping. -func (j *jsonBuilder) setJSONValue(arena *astjson.Arena, root *astjson.Value, name string, data protoref.Message, fd protoref.FieldDescriptor) { +func (j *jsonBuilder) setJSONValue(root *astjson.Value, name string, data protoref.Message, fd protoref.FieldDescriptor) { if !data.IsValid() { - root.Set(name, arena.NewNull()) + root.Set(j.jsonArena, name, astjson.NullValue) return } @@ -600,27 +602,27 @@ func (j *jsonBuilder) setJSONValue(arena *astjson.Arena, root *astjson.Value, na case protoref.BoolKind: boolValue := data.Get(fd).Bool() if boolValue { - root.Set(name, arena.NewTrue()) + root.Set(j.jsonArena, name, astjson.TrueValue(j.jsonArena)) } else { - root.Set(name, arena.NewFalse()) + root.Set(j.jsonArena, name, astjson.FalseValue(j.jsonArena)) } case protoref.StringKind: - root.Set(name, arena.NewString(data.Get(fd).String())) + root.Set(j.jsonArena, name, astjson.StringValue(j.jsonArena, data.Get(fd).String())) case protoref.Int32Kind: - root.Set(name, arena.NewNumberInt(int(data.Get(fd).Int()))) + root.Set(j.jsonArena, name, astjson.IntValue(j.jsonArena, int(data.Get(fd).Int()))) case protoref.Int64Kind: - root.Set(name, arena.NewNumberString(strconv.FormatInt(data.Get(fd).Int(), 10))) + root.Set(j.jsonArena, name, astjson.NumberValue(j.jsonArena, strconv.FormatInt(data.Get(fd).Int(), 10))) case protoref.Uint32Kind, protoref.Uint64Kind: - root.Set(name, arena.NewNumberString(strconv.FormatUint(data.Get(fd).Uint(), 10))) + root.Set(j.jsonArena, name, astjson.NumberValue(j.jsonArena, strconv.FormatUint(data.Get(fd).Uint(), 10))) case protoref.FloatKind, protoref.DoubleKind: - root.Set(name, arena.NewNumberFloat64(data.Get(fd).Float())) + root.Set(j.jsonArena, name, astjson.FloatValue(j.jsonArena, data.Get(fd).Float())) case protoref.BytesKind: - root.Set(name, arena.NewStringBytes(data.Get(fd).Bytes())) + root.Set(j.jsonArena, name, astjson.StringValueBytes(j.jsonArena, data.Get(fd).Bytes())) case protoref.EnumKind: enumDesc := fd.Enum() enumValueDesc := enumDesc.Values().ByNumber(data.Get(fd).Enum()) if enumValueDesc == nil { - root.Set(name, arena.NewNull()) + root.Set(j.jsonArena, name, astjson.NullValue) return } @@ -628,20 +630,20 @@ func (j *jsonBuilder) setJSONValue(arena *astjson.Arena, root *astjson.Value, na graphqlValue, ok := j.mapping.FindEnumValueMapping(string(enumDesc.Name()), string(enumValueDesc.Name())) if !ok { // No mapping found - set to null - root.Set(name, arena.NewNull()) + root.Set(j.jsonArena, name, astjson.NullValue) return } - root.Set(name, arena.NewString(graphqlValue)) + root.Set(j.jsonArena, name, 
astjson.StringValue(j.jsonArena, graphqlValue)) } } // setArrayItem converts a protobuf list item value to JSON and sets it at the specified // array index. This is similar to setJSONValue but operates on array elements rather // than object properties, and works with protobuf Value types rather than Message types. -func (j *jsonBuilder) setArrayItem(index int, arena *astjson.Arena, array *astjson.Value, data protoref.Value, fd protoref.FieldDescriptor) { +func (j *jsonBuilder) setArrayItem(index int, array *astjson.Value, data protoref.Value, fd protoref.FieldDescriptor) { if !data.IsValid() { - array.SetArrayItem(index, arena.NewNull()) + array.SetArrayItem(j.jsonArena, index, astjson.NullValue) return } @@ -649,27 +651,27 @@ func (j *jsonBuilder) setArrayItem(index int, arena *astjson.Arena, array *astjs case protoref.BoolKind: boolValue := data.Bool() if boolValue { - array.SetArrayItem(index, arena.NewTrue()) + array.SetArrayItem(j.jsonArena, index, astjson.TrueValue(j.jsonArena)) } else { - array.SetArrayItem(index, arena.NewFalse()) + array.SetArrayItem(j.jsonArena, index, astjson.FalseValue(j.jsonArena)) } case protoref.StringKind: - array.SetArrayItem(index, arena.NewString(data.String())) + array.SetArrayItem(j.jsonArena, index, astjson.StringValue(j.jsonArena, data.String())) case protoref.Int32Kind: - array.SetArrayItem(index, arena.NewNumberInt(int(data.Int()))) + array.SetArrayItem(j.jsonArena, index, astjson.IntValue(j.jsonArena, int(data.Int()))) case protoref.Int64Kind: - array.SetArrayItem(index, arena.NewNumberString(strconv.FormatInt(data.Int(), 10))) + array.SetArrayItem(j.jsonArena, index, astjson.NumberValue(j.jsonArena, strconv.FormatInt(data.Int(), 10))) case protoref.Uint32Kind, protoref.Uint64Kind: - array.SetArrayItem(index, arena.NewNumberString(strconv.FormatUint(data.Uint(), 10))) + array.SetArrayItem(j.jsonArena, index, astjson.NumberValue(j.jsonArena, strconv.FormatUint(data.Uint(), 10))) case protoref.FloatKind, protoref.DoubleKind: - array.SetArrayItem(index, arena.NewNumberFloat64(data.Float())) + array.SetArrayItem(j.jsonArena, index, astjson.FloatValue(j.jsonArena, data.Float())) case protoref.BytesKind: - array.SetArrayItem(index, arena.NewStringBytes(data.Bytes())) + array.SetArrayItem(j.jsonArena, index, astjson.StringValueBytes(j.jsonArena, data.Bytes())) case protoref.EnumKind: enumDesc := fd.Enum() enumValueDesc := enumDesc.Values().ByNumber(data.Enum()) if enumValueDesc == nil { - array.SetArrayItem(index, arena.NewNull()) + array.SetArrayItem(j.jsonArena, index, astjson.NullValue) return } @@ -677,20 +679,19 @@ func (j *jsonBuilder) setArrayItem(index int, arena *astjson.Arena, array *astjs graphqlValue, ok := j.mapping.FindEnumValueMapping(string(enumDesc.Name()), string(enumValueDesc.Name())) if !ok { // No mapping found - use null - array.SetArrayItem(index, arena.NewNull()) + array.SetArrayItem(j.jsonArena, index, astjson.NullValue) return } - array.SetArrayItem(index, arena.NewString(graphqlValue)) + array.SetArrayItem(j.jsonArena, index, astjson.StringValue(j.jsonArena, graphqlValue)) } } // toDataObject wraps a response value in the standard GraphQL data envelope. // This creates the top-level structure { "data": ... } that GraphQL clients expect. 
func (j *jsonBuilder) toDataObject(root *astjson.Value) *astjson.Value { - a := astjson.Arena{} - data := a.NewObject() - data.Set(dataPath, root) + data := astjson.ObjectValue(j.jsonArena) + data.Set(j.jsonArena, dataPath, root) return data } @@ -698,30 +699,27 @@ func (j *jsonBuilder) toDataObject(root *astjson.Value) *astjson.Value { // This includes the error message and gRPC status code information in the extensions // field, following GraphQL error specification standards. func (j *jsonBuilder) writeErrorBytes(err error) []byte { - a := astjson.Arena{} - defer a.Reset() - // Create standard GraphQL error structure - errorRoot := a.NewObject() - errorArray := a.NewArray() - errorRoot.Set(errorsPath, errorArray) + errorRoot := astjson.ObjectValue(j.jsonArena) + errorArray := astjson.ArrayValue(j.jsonArena) + errorRoot.Set(j.jsonArena, errorsPath, errorArray) // Create individual error object - errorItem := a.NewObject() - errorItem.Set("message", a.NewString(err.Error())) + errorItem := astjson.ObjectValue(j.jsonArena) + errorItem.Set(j.jsonArena, "message", astjson.StringValue(j.jsonArena, err.Error())) // Add gRPC status code information to extensions - extensions := a.NewObject() + extensions := astjson.ObjectValue(j.jsonArena) if st, ok := status.FromError(err); ok { // gRPC error - include the specific status code - extensions.Set("code", a.NewString(st.Code().String())) + extensions.Set(j.jsonArena, "code", astjson.StringValue(j.jsonArena, st.Code().String())) } else { // Generic error - default to INTERNAL status - extensions.Set("code", a.NewString(codes.Internal.String())) + extensions.Set(j.jsonArena, "code", astjson.StringValue(j.jsonArena, codes.Internal.String())) } - errorItem.Set("extensions", extensions) - errorArray.SetArrayItem(0, errorItem) + errorItem.Set(j.jsonArena, "extensions", extensions) + errorArray.SetArrayItem(j.jsonArena, 0, errorItem) return errorRoot.MarshalTo(nil) } diff --git a/v2/pkg/engine/datasource/httpclient/httpclient_test.go b/v2/pkg/engine/datasource/httpclient/httpclient_test.go index 223e5d833..98685cece 100644 --- a/v2/pkg/engine/datasource/httpclient/httpclient_test.go +++ b/v2/pkg/engine/datasource/httpclient/httpclient_test.go @@ -1,7 +1,6 @@ package httpclient import ( - "bytes" "compress/gzip" "context" "io" @@ -80,10 +79,9 @@ func TestHttpClientDo(t *testing.T) { runTest := func(ctx context.Context, input []byte, expectedOutput string) func(t *testing.T) { return func(t *testing.T) { - out := &bytes.Buffer{} - err := Do(http.DefaultClient, ctx, input, out) + output, err := Do(http.DefaultClient, ctx, nil, input) assert.NoError(t, err) - assert.Equal(t, expectedOutput, out.String()) + assert.Equal(t, expectedOutput, string(output)) } } @@ -211,9 +209,8 @@ func TestHttpClientDo(t *testing.T) { input = SetInputURL(input, []byte(server.URL)) input, err := sjson.SetBytes(input, TRACE, true) assert.NoError(t, err) - out := &bytes.Buffer{} - err = Do(http.DefaultClient, context.Background(), input, out) + output, err := Do(http.DefaultClient, context.Background(), nil, input) assert.NoError(t, err) - assert.Contains(t, out.String(), `"Authorization":["****"]`) + assert.Contains(t, string(output), `"Authorization":["****"]`) }) } diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index 4e8ca9b31..46af845e4 100644 --- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -20,7 +20,6 @@ import ( 
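For reference, writeErrorBytes produces the standard GraphQL error shape; assuming errorsPath resolves to the conventional "errors" key, a gRPC NotFound error would serialize to:

{"errors":[{"message":"product not found","extensions":{"code":"NotFound"}}]}
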
"github.com/buger/jsonparser" "github.com/wundergraph/graphql-go-tools/v2/pkg/lexer/literal" - "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) const ( @@ -28,6 +27,7 @@ const ( AcceptEncodingHeader = "Accept-Encoding" AcceptHeader = "Accept" ContentTypeHeader = "Content-Type" + ContentLengthHeader = "Content-Length" EncodingGzip = "gzip" EncodingDeflate = "deflate" @@ -130,21 +130,38 @@ func respBodyReader(res *http.Response) (io.Reader, error) { } } -type bodyHashContextKey struct{} +type httpClientContext string -func BodyHashFromContext(ctx context.Context) (uint64, bool) { - value := ctx.Value(bodyHashContextKey{}) - if value == nil { - return 0, false +const ( + sizeHintKey httpClientContext = "size-hint" +) + +// WithHTTPClientSizeHint allows the engine to keep track of response sizes per subgraph fetch +// If a hint is supplied, we can create a buffer of size close to the required size +// This reduces allocations by reducing the buffer grow calls, which always copies the buffer +func WithHTTPClientSizeHint(ctx context.Context, size int) context.Context { + return context.WithValue(ctx, sizeHintKey, size) +} + +func buffer(ctx context.Context) *bytes.Buffer { + if sizeHint, ok := ctx.Value(sizeHintKey).(int); ok && sizeHint > 0 { + return bytes.NewBuffer(make([]byte, 0, sizeHint)) } - return value.(uint64), true + // if we start with zero, doubling will take a while until we reach the required size + // if we start with a high number, e.g. 1024, we just increase the memory usage of the engine + // 64 seems to be a healthy middle ground + return bytes.NewBuffer(make([]byte, 0, 64)) } -func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, headers, queryParams []byte, body io.Reader, enableTrace bool, out *bytes.Buffer, contentType string) (err error) { +func makeHTTPRequest(client *http.Client, ctx context.Context, baseHeaders http.Header, url, method, headers, queryParams []byte, body io.Reader, enableTrace bool, contentType string, contentLength int) ([]byte, error) { request, err := http.NewRequestWithContext(ctx, string(method), string(url), body) if err != nil { - return err + return nil, err + } + + if baseHeaders != nil { + request.Header = baseHeaders } if headers != nil { @@ -161,7 +178,7 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head return err }) if err != nil { - return err + return nil, err } } @@ -190,7 +207,7 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head } }) if err != nil { - return err + return nil, err } request.URL.RawQuery = query.Encode() } @@ -199,12 +216,17 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head request.Header.Add(ContentTypeHeader, contentType) request.Header.Set(AcceptEncodingHeader, EncodingGzip) request.Header.Add(AcceptEncodingHeader, EncodingDeflate) + if contentLength > 0 { + // always set the ContentLength field so that chunking can be avoided + // and other parties can more efficiently parse + request.ContentLength = int64(contentLength) + } setRequest(ctx, request) response, err := client.Do(request) if err != nil { - return err + return nil, err } defer response.Body.Close() @@ -212,23 +234,26 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head respReader, err := respBodyReader(response) if err != nil { - return err + return nil, err } - if !enableTrace { - if response.ContentLength > 0 { - out.Grow(int(response.ContentLength)) - } else { - out.Grow(1024 * 4) - } - _, err 
= out.ReadFrom(respReader) - return + // we intentionally don't use a pool of sorts here + // we're buffering the response and then later, in the engine, + // parsing it into a JSON AST with the use of an arena, which is quite efficient + // Through trial and error it turned out that it's best to leave this buffer to the GC + // It'll know best the lifecycle of the buffer + // Using an arena here just increased overall memory usage + out := buffer(ctx) + _, err = out.ReadFrom(respReader) + if err != nil { + return nil, err } - data, err := io.ReadAll(respReader) - if err != nil { - return err + if !enableTrace { + return out.Bytes(), nil } + + data := out.Bytes() responseTrace := TraceHTTP{ Request: TraceHTTPRequest{ Method: request.Method, @@ -244,39 +269,29 @@ } trace, err := json.Marshal(responseTrace) if err != nil { - return err + return nil, err } responseWithTraceExtension, err := jsonparser.Set(data, trace, "extensions", "trace") if err != nil { - return err + return nil, err } - _, err = out.Write(responseWithTraceExtension) - return err + return responseWithTraceExtension, nil } -func Do(client *http.Client, ctx context.Context, requestInput []byte, out *bytes.Buffer) (err error) { +func Do(client *http.Client, ctx context.Context, baseHeaders http.Header, requestInput []byte) (data []byte, err error) { url, method, body, headers, queryParams, enableTrace := requestInputParams(requestInput) - h := pool.Hash64.Get() - _, _ = h.Write(body) - bodyHash := h.Sum64() - pool.Hash64.Put(h) - ctx = context.WithValue(ctx, bodyHashContextKey{}, bodyHash) - return makeHTTPRequest(client, ctx, url, method, headers, queryParams, bytes.NewReader(body), enableTrace, out, ContentTypeJSON) + return makeHTTPRequest(client, ctx, baseHeaders, url, method, headers, queryParams, bytes.NewReader(body), enableTrace, ContentTypeJSON, len(body)) } func DoMultipartForm( - client *http.Client, ctx context.Context, requestInput []byte, files []*FileUpload, out *bytes.Buffer, -) (err error) { + client *http.Client, ctx context.Context, baseHeaders http.Header, requestInput []byte, files []*FileUpload, +) (data []byte, err error) { if len(files) == 0 { - return errors.New("no files provided") + return nil, errors.New("no files provided") } url, method, body, headers, queryParams, enableTrace := requestInputParams(requestInput) - h := pool.Hash64.Get() - defer pool.Hash64.Put(h) - _, _ = h.Write(body) - formValues := map[string]io.Reader{ "operations": bytes.NewReader(body), } @@ -293,14 +308,13 @@ } hasWrittenFileName = true - fmt.Fprintf(fileMap, `"%d":["%s"]`, i, file.variablePath) + _, _ = fmt.Fprintf(fileMap, `"%d":["%s"]`, i, file.variablePath) key := fmt.Sprintf("%d", i) - _, _ = h.WriteString(file.Path()) temporaryFile, err := os.Open(file.Path()) tempFiles = append(tempFiles, temporaryFile) if err != nil { - return err + return nil, err } formValues[key] = bufio.NewReader(temporaryFile) } @@ -309,7 +323,7 @@ multipartBody, contentType, err := multipartBytes(formValues, files) if err != nil { - return err + return nil, err } defer func() { @@ -324,10 +338,7 @@ } }() - bodyHash := h.Sum64() - ctx = context.WithValue(ctx, bodyHashContextKey{}, bodyHash) - - return makeHTTPRequest(client, ctx, url, method, headers, queryParams, 
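A caller-side sketch of the size-hint mechanism introduced above; fetchWithHint and lastSize are illustrative bookkeeping, not part of this diff:

func fetchWithHint(client *http.Client, ctx context.Context, input []byte, lastSize int) ([]byte, error) {
	if lastSize > 0 {
		// Pre-size the response buffer to roughly the previous response size
		// for this subgraph, avoiding repeated buffer-grow copies.
		ctx = WithHTTPClientSizeHint(ctx, lastSize)
	}
	return Do(client, ctx, nil, input)
}
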
multipartBody, enableTrace, contentType, 0) } func multipartBytes(values map[string]io.Reader, files []*FileUpload) (*io.PipeReader, string, error) { diff --git a/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection.golden b/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection.golden index f6ab07228..d6f62343c 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection.golden +++ b/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection.golden @@ -363,4 +363,4 @@ } ], "__typename": "__Schema" -} +} \ No newline at end of file diff --git a/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection_with_custom_root_operation_types.golden b/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection_with_custom_root_operation_types.golden index 6b73ac8dc..f56fee360 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection_with_custom_root_operation_types.golden +++ b/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection_with_custom_root_operation_types.golden @@ -511,4 +511,4 @@ } ], "__typename": "__Schema" -} +} \ No newline at end of file diff --git a/v2/pkg/engine/datasource/introspection_datasource/fixtures/type_introspection.golden b/v2/pkg/engine/datasource/introspection_datasource/fixtures/type_introspection.golden index 41827c0f6..16017d131 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/fixtures/type_introspection.golden +++ b/v2/pkg/engine/datasource/introspection_datasource/fixtures/type_introspection.golden @@ -56,4 +56,4 @@ "interfaces": [], "possibleTypes": [], "__typename": "__Type" -} +} \ No newline at end of file diff --git a/v2/pkg/engine/datasource/introspection_datasource/source.go b/v2/pkg/engine/datasource/introspection_datasource/source.go index b9a06489d..67195e44a 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/source.go +++ b/v2/pkg/engine/datasource/introspection_datasource/source.go @@ -1,11 +1,11 @@ package introspection_datasource import ( - "bytes" "context" "encoding/json" "errors" "io" + "net/http" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/introspection" @@ -19,21 +19,21 @@ type Source struct { introspectionData *introspection.Data } -func (s *Source) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) { +func (s *Source) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { var req introspectionInput if err := json.Unmarshal(input, &req); err != nil { - return err + return nil, err } if req.RequestType == TypeRequestType { - return s.singleType(out, req.TypeName) + return s.singleTypeBytes(req.TypeName) } - return json.NewEncoder(out).Encode(s.introspectionData.Schema) + return json.Marshal(s.introspectionData.Schema) } -func (s *Source) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { - return errors.New("introspection data source does not support file uploads") +func (s *Source) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { + return nil, errors.New("introspection data source does not support file uploads") } func (s *Source) typeInfo(typeName *string) *introspection.FullType { @@ -57,3 +57,12 @@ func (s *Source) singleType(w io.Writer, 
typeName *string) error { return json.NewEncoder(w).Encode(typeInfo) } + +func (s *Source) singleTypeBytes(typeName *string) ([]byte, error) { + typeInfo := s.typeInfo(typeName) + if typeInfo == nil { + return null, nil + } + + return json.Marshal(typeInfo) +} diff --git a/v2/pkg/engine/datasource/introspection_datasource/source_test.go b/v2/pkg/engine/datasource/introspection_datasource/source_test.go index bb4a91143..9737a4ee9 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/source_test.go +++ b/v2/pkg/engine/datasource/introspection_datasource/source_test.go @@ -27,13 +27,18 @@ func TestSource_Load(t *testing.T) { gen.Generate(&def, &report, &data) require.False(t, report.HasErrors()) - buf := &bytes.Buffer{} source := &Source{introspectionData: &data} - require.NoError(t, source.Load(context.Background(), []byte(input), buf)) + responseData, err := source.Load(context.Background(), nil, []byte(input)) + require.NoError(t, err) actualResponse := &bytes.Buffer{} - require.NoError(t, json.Indent(actualResponse, buf.Bytes(), "", " ")) - goldie.Assert(t, fixtureName, actualResponse.Bytes()) + require.NoError(t, json.Indent(actualResponse, responseData, "", " ")) + // Trim the trailing newline that json.Indent adds + responseBytes := actualResponse.Bytes() + if len(responseBytes) > 0 && responseBytes[len(responseBytes)-1] == '\n' { + responseBytes = responseBytes[:len(responseBytes)-1] + } + goldie.Assert(t, fixtureName, responseBytes) } } diff --git a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_datasource_test.go b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_datasource_test.go index 28a37df33..2ea8114ad 100644 --- a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_datasource_test.go +++ b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_datasource_test.go @@ -424,6 +424,8 @@ func TestPubSub(t *testing.T) { PostProcessing: resolve.PostProcessingConfiguration{ MergePath: []string{"helloSubscription"}, }, + SourceName: "test", + SourceID: "test", }, Response: &resolve.GraphQLResponse{ Data: &resolve.Object{ @@ -487,6 +489,8 @@ func TestPubSub(t *testing.T) { PostProcessing: resolve.PostProcessingConfiguration{ MergePath: []string{"subscriptionWithMultipleSubjects"}, }, + SourceName: "test", + SourceID: "test", }, Response: &resolve.GraphQLResponse{ Data: &resolve.Object{ @@ -532,6 +536,8 @@ func TestPubSub(t *testing.T) { PostProcessing: resolve.PostProcessingConfiguration{ MergePath: []string{"subscriptionWithStaticValues"}, }, + SourceName: "test", + SourceID: "test", }, Response: &resolve.GraphQLResponse{ Data: &resolve.Object{ @@ -583,6 +589,8 @@ func TestPubSub(t *testing.T) { PostProcessing: resolve.PostProcessingConfiguration{ MergePath: []string{"subscriptionWithArgTemplateAndStaticValue"}, }, + SourceName: "test", + SourceID: "test", }, Response: &resolve.GraphQLResponse{ Data: &resolve.Object{ diff --git a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go index cc562b803..3f688b6b1 100644 --- a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go +++ b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go @@ -1,13 +1,9 @@ package pubsub_datasource import ( - "bytes" "context" "encoding/json" - "io" - - "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" + "net/http" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" @@ -33,28 +29,7 @@ type KafkaSubscriptionSource struct { 
pubSub KafkaPubSub } -func (s *KafkaSubscriptionSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { - - val, _, _, err := jsonparser.Get(input, "topics") - if err != nil { - return err - } - - _, err = xxh.Write(val) - if err != nil { - return err - } - - val, _, _, err = jsonparser.Get(input, "providerId") - if err != nil { - return err - } - - _, err = xxh.Write(val) - return err -} - -func (s *KafkaSubscriptionSource) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error { +func (s *KafkaSubscriptionSource) Start(ctx *resolve.Context, headers http.Header, input []byte, updater resolve.SubscriptionUpdater) error { var subscriptionConfiguration KafkaSubscriptionEventConfiguration err := json.Unmarshal(input, &subscriptionConfiguration) if err != nil { @@ -68,21 +43,19 @@ type KafkaPublishDataSource struct { pubSub KafkaPubSub } -func (s *KafkaPublishDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { +func (s *KafkaPublishDataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { var publishConfiguration KafkaPublishEventConfiguration - err := json.Unmarshal(input, &publishConfiguration) + err = json.Unmarshal(input, &publishConfiguration) if err != nil { - return err + return nil, err } if err := s.pubSub.Publish(ctx, publishConfiguration); err != nil { - _, err = io.WriteString(out, `{"success": false}`) - return err + return []byte(`{"success": false}`), err } - _, err = io.WriteString(out, `{"success": true}`) - return err + return []byte(`{"success": true}`), nil } -func (s *KafkaPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { +func (s *KafkaPublishDataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("not implemented") } diff --git a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go index 31cb6d415..776b5deac 100644 --- a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go +++ b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go @@ -5,9 +5,7 @@ import ( "context" "encoding/json" "io" - - "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" + "net/http" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" @@ -42,28 +40,7 @@ type NatsSubscriptionSource struct { pubSub NatsPubSub } -func (s *NatsSubscriptionSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { - - val, _, _, err := jsonparser.Get(input, "subjects") - if err != nil { - return err - } - - _, err = xxh.Write(val) - if err != nil { - return err - } - - val, _, _, err = jsonparser.Get(input, "providerId") - if err != nil { - return err - } - - _, err = xxh.Write(val) - return err -} - -func (s *NatsSubscriptionSource) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error { +func (s *NatsSubscriptionSource) Start(ctx *resolve.Context, headers http.Header, input []byte, updater resolve.SubscriptionUpdater) error { var subscriptionConfiguration NatsSubscriptionEventConfiguration err := json.Unmarshal(input, &subscriptionConfiguration) if err != nil { @@ -77,23 +54,21 @@ type NatsPublishDataSource struct { pubSub NatsPubSub } -func (s *NatsPublishDataSource) Load(ctx 
context.Context, input []byte, out *bytes.Buffer) error { +func (s *NatsPublishDataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { var publishConfiguration NatsPublishAndRequestEventConfiguration - err := json.Unmarshal(input, &publishConfiguration) + err = json.Unmarshal(input, &publishConfiguration) if err != nil { - return err + return nil, err } if err := s.pubSub.Publish(ctx, publishConfiguration); err != nil { - _, err = io.WriteString(out, `{"success": false}`) - return err + return []byte(`{"success": false}`), err } - _, err = io.WriteString(out, `{"success": true}`) - return err + return []byte(`{"success": true}`), nil } -func (s *NatsPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) error { +func (s *NatsPublishDataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("not implemented") } @@ -101,16 +76,22 @@ type NatsRequestDataSource struct { pubSub NatsPubSub } -func (s *NatsRequestDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { +func (s *NatsRequestDataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { var subscriptionConfiguration NatsPublishAndRequestEventConfiguration - err := json.Unmarshal(input, &subscriptionConfiguration) + err = json.Unmarshal(input, &subscriptionConfiguration) if err != nil { - return err + return nil, err + } + + var buf bytes.Buffer + err = s.pubSub.Request(ctx, subscriptionConfiguration, &buf) + if err != nil { + return nil, err } - return s.pubSub.Request(ctx, subscriptionConfiguration, out) + return buf.Bytes(), nil } -func (s *NatsRequestDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) error { +func (s *NatsRequestDataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("not implemented") } diff --git a/v2/pkg/engine/datasource/staticdatasource/static_datasource.go b/v2/pkg/engine/datasource/staticdatasource/static_datasource.go index e9074635c..3fb75c8b3 100644 --- a/v2/pkg/engine/datasource/staticdatasource/static_datasource.go +++ b/v2/pkg/engine/datasource/staticdatasource/static_datasource.go @@ -1,8 +1,8 @@ package staticdatasource import ( - "bytes" "context" + "net/http" "github.com/jensneuse/abstractlogger" @@ -71,11 +71,10 @@ func (p *Planner[T]) ConfigureSubscription() plan.SubscriptionConfiguration { type Source struct{} -func (Source) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) { - _, err = out.Write(input) - return +func (Source) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { + return input, nil } -func (Source) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { +func (Source) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("not implemented") } diff --git a/v2/pkg/engine/datasourcetesting/datasourcetesting.go b/v2/pkg/engine/datasourcetesting/datasourcetesting.go index 598766644..66495b615 100644 --- a/v2/pkg/engine/datasourcetesting/datasourcetesting.go +++ b/v2/pkg/engine/datasourcetesting/datasourcetesting.go @@ -34,6 +34,8 @@ type testOptions struct { withPrintPlan bool 
withFieldDependencies bool withFetchReasons bool + withEntityCaching bool + withFetchProvidesData bool } func WithPostProcessors(postProcessors ...*postprocess.Processor) func(*testOptions) { @@ -84,6 +86,22 @@ func WithFetchReasons() func(*testOptions) { } } +func WithEntityCaching() func(*testOptions) { + return func(o *testOptions) { + o.withFieldInfo = true + o.withFieldDependencies = true + o.withEntityCaching = true + } +} + +func WithFetchProvidesData() func(*testOptions) { + return func(o *testOptions) { + o.withFieldInfo = true + o.withFieldDependencies = true + o.withFetchProvidesData = true + } +} + func RunWithPermutations(t *testing.T, definition, operation, operationName string, expectedPlan plan.Plan, config plan.Configuration, options ...func(*testOptions)) { t.Helper() @@ -143,6 +161,8 @@ func RunTestWithVariables(definition, operation, operationName, variables string // by default, we don't want to have field info in the tests because it's too verbose config.DisableIncludeInfo = true config.DisableIncludeFieldDependencies = true + config.DisableEntityCaching = true + config.DisableFetchProvidesData = true opts := &testOptions{} for _, o := range options { @@ -161,6 +181,14 @@ func RunTestWithVariables(definition, operation, operationName, variables string config.BuildFetchReasons = true } + if opts.withEntityCaching { + config.DisableEntityCaching = false + } + + if opts.withFetchProvidesData { + config.DisableFetchProvidesData = false + } + if opts.skipReason != "" { t.Skip(opts.skipReason) } diff --git a/v2/pkg/engine/plan/configuration.go b/v2/pkg/engine/plan/configuration.go index dafcb021c..d7a7fee67 100644 --- a/v2/pkg/engine/plan/configuration.go +++ b/v2/pkg/engine/plan/configuration.go @@ -46,6 +46,11 @@ type Configuration struct { // entity. // This option requires BuildFetchReasons set to true. 
ValidateRequiredExternalFields bool + + // DisableEntityCaching disables planning of entity caching behavior or generating relevant metadata + DisableEntityCaching bool + // DisableFetchProvidesData disables planning of meta information about which fields are provided by a fetch + DisableFetchProvidesData bool } type DebugConfiguration struct { diff --git a/v2/pkg/engine/plan/planner_test.go b/v2/pkg/engine/plan/planner_test.go index 270140381..32d42d198 100644 --- a/v2/pkg/engine/plan/planner_test.go +++ b/v2/pkg/engine/plan/planner_test.go @@ -1,10 +1,10 @@ package plan import ( - "bytes" "context" "encoding/json" "fmt" + "net/http" "reflect" "slices" "testing" @@ -172,6 +172,7 @@ func TestPlanner_Plan(t *testing.T) { }, Configuration{ DisableResolveFieldPositions: true, DisableIncludeInfo: true, + DisableEntityCaching: true, DataSources: []DataSource{testDefinitionDSConfiguration}, })) @@ -226,6 +227,7 @@ func TestPlanner_Plan(t *testing.T) { }, Configuration{ DisableResolveFieldPositions: true, DisableIncludeInfo: true, + DisableEntityCaching: true, DataSources: []DataSource{testDefinitionDSConfiguration}, })) @@ -292,6 +294,7 @@ func TestPlanner_Plan(t *testing.T) { }, Configuration{ DisableResolveFieldPositions: true, DisableIncludeInfo: true, + DisableEntityCaching: true, DataSources: []DataSource{testDefinitionDSConfiguration}, })) @@ -363,6 +366,7 @@ func TestPlanner_Plan(t *testing.T) { }, Configuration{ DisableResolveFieldPositions: true, DisableIncludeInfo: true, + DisableEntityCaching: true, DataSources: []DataSource{testDefinitionDSConfiguration}, })) @@ -425,14 +429,16 @@ func TestPlanner_Plan(t *testing.T) { }, Configuration{ DisableResolveFieldPositions: true, DisableIncludeInfo: true, + DisableEntityCaching: true, DataSources: []DataSource{testDefinitionDSConfiguration}, })) }) t.Run("operation selection", func(t *testing.T) { cfg := Configuration{ - DataSources: []DataSource{testDefinitionDSConfiguration}, - DisableIncludeInfo: true, + DataSources: []DataSource{testDefinitionDSConfiguration}, + DisableIncludeInfo: true, + DisableEntityCaching: true, } t.Run("should successfully plan a single named query by providing an operation name", test(testDefinition, ` @@ -585,6 +591,7 @@ func TestPlanner_Plan(t *testing.T) { Configuration{ DisableResolveFieldPositions: true, DisableIncludeInfo: true, + DisableEntityCaching: true, Fields: FieldConfigurations{ FieldConfiguration{ TypeName: "Character", @@ -644,6 +651,7 @@ func TestPlanner_Plan(t *testing.T) { Configuration{ DisableResolveFieldPositions: true, DisableIncludeInfo: true, + DisableEntityCaching: true, Fields: FieldConfigurations{ FieldConfiguration{ TypeName: "Character", @@ -703,6 +711,7 @@ func TestPlanner_Plan(t *testing.T) { Configuration{ DisableResolveFieldPositions: true, DisableIncludeInfo: true, + DisableEntityCaching: true, DataSources: []DataSource{dsConfig}, }, )) @@ -1075,10 +1084,10 @@ type FakeDataSource struct { source *StatefulSource } -func (f *FakeDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) { - return +func (f *FakeDataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { + return nil, nil } -func (f *FakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { - return +func (f *FakeDataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { + return nil, nil } diff --git 
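Taken together, the updated planner tests opt out of both new behaviors explicitly; spelled out as a full configuration (ds stands in for a configured data source):

cfg := Configuration{
	DataSources:              []DataSource{ds},
	DisableIncludeInfo:       true,
	DisableEntityCaching:     true, // no cache metadata in generated plans
	DisableFetchProvidesData: true, // no ProvidesData tracking on fetches
}
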
a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index ebbd0d5c1..4d7441ab0 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -8,6 +8,7 @@ import ( "regexp" "slices" "strings" + "time" "github.com/wundergraph/astjson" @@ -64,6 +65,18 @@ type Visitor struct { // fieldEnclosingTypeNames maps fieldRef to the enclosing type name. fieldEnclosingTypeNames map[int]string + // plannerObjects stores the root object for each planner's ProvidesData + // map plannerID -> root object + plannerObjects map[int]*resolve.Object + // plannerCurrentFields stores the current field stack for each planner + // map plannerID -> field stack + plannerCurrentFields map[int][]objectFields + // plannerResponsePaths stores the response paths relative to each planner's root + // map plannerID -> response path stack + plannerResponsePaths map[int][]string + // plannerEntityBoundaryPaths stores the entity boundary paths for each planner + // map plannerID -> entity boundary path + plannerEntityBoundaryPaths map[int]string } type indirectInterfaceField struct { @@ -343,6 +356,12 @@ func (v *Visitor) EnterField(ref int) { if !v.Config.DisableIncludeFieldDependencies { v.fieldEnclosingTypeNames[ref] = strings.Clone(v.Walker.EnclosingTypeDefinition.NameString(v.Definition)) } + + // Track field for each planner that should handle it + for plannerID := range v.planners { + v.trackFieldForPlanner(plannerID, ref) + } + // check if we have to skip the field in the response // it means it was requested by the planner not the user if v.skipField(ref) { @@ -613,6 +632,11 @@ func (v *Visitor) addInterfaceObjectNameToTypeNames(fieldRef int, typeName []byt func (v *Visitor) LeaveField(ref int) { v.debugOnLeaveNode(ast.NodeKindField, ref) + // Pop fields for each planner that tracked this field + for plannerID := range v.planners { + v.popFieldsForPlanner(plannerID, ref) + } + if v.skipField(ref) { // we should also check skips on field leave // cause on nested keys we could mistakenly remove wrong object @@ -1004,6 +1028,9 @@ func (v *Visitor) EnterOperationDefinition(ref int) { } } + // Initialize per-planner structures for ProvidesData tracking + v.initializePlannerStructures() + if operationKind == ast.OperationTypeSubscription { v.subscription = &resolve.GraphQLSubscription{ Response: v.response, @@ -1070,6 +1097,9 @@ func (v *Visitor) EnterDocument(operation, definition *ast.Document) { v.plannerFields = map[int][]int{} v.fieldPlanners = map[int][]int{} v.fieldEnclosingTypeNames = map[int]string{} + v.plannerObjects = map[int]*resolve.Object{} + v.plannerCurrentFields = map[int][]objectFields{} + v.plannerResponsePaths = map[int][]string{} } func (v *Visitor) LeaveDocument(_, _ *ast.Document) { @@ -1124,6 +1154,302 @@ func (v *Visitor) pathDeepness(path string) int { return strings.Count(path, ".") } +func (v *Visitor) initializePlannerStructures() { + // Initialize root objects and field stacks for each potential planner + // We'll populate these as we traverse fields + if v.planners == nil { + return + } + + for i := range v.planners { + v.plannerObjects[i] = &resolve.Object{ + Fields: []*resolve.Field{}, + } + v.plannerCurrentFields[i] = []objectFields{{ + fields: &v.plannerObjects[i].Fields, + popOnField: -1, + }} + v.plannerResponsePaths[i] = []string{} + } + v.plannerEntityBoundaryPaths = map[int]string{} +} + +func (v *Visitor) trackFieldForPlanner(plannerID int, fieldRef int) { + // Safety checks + if v.planners == nil || plannerID >= len(v.planners) { + return + } + if 
v.plannerObjects == nil || v.plannerCurrentFields == nil { + return + } + + // Check if this planner should handle this field + if !v.shouldPlannerHandleField(plannerID, fieldRef) { + return + } + + // Get field information + fieldName := v.Operation.FieldNameBytes(fieldRef) + fieldAliasOrName := v.Operation.FieldAliasOrNameString(fieldRef) + + // For nested entity fetches, check if this field represents the entity boundary + // If so, we should skip adding this field to ProvidesData and instead add its children + if v.isEntityBoundaryField(plannerID, fieldRef) { + // Add a __typename field to the current object for entity boundary + v.addTypenameFieldForPlanner(plannerID) + return + } + + // Check if this is a __typename field and if we already have one with the same name and path + if bytes.Equal(fieldName, literal.TYPENAME) && len(v.plannerCurrentFields[plannerID]) > 0 { + currentFields := v.plannerCurrentFields[plannerID][len(v.plannerCurrentFields[plannerID])-1] + + // Check if we already have a __typename field with the same name and path + for _, existingField := range *currentFields.fields { + if bytes.Equal(existingField.Name, []byte(fieldAliasOrName)) { + // For __typename fields, the path is [fieldAliasOrName] + // Check if the existing field has the same path + if existingValue, ok := existingField.Value.(*resolve.Scalar); ok { + if len(existingValue.Path) > 0 && existingValue.Path[0] == fieldAliasOrName { + // We already have this __typename field with the same name and path, skip it + return + } + } + } + } + } + + // Get the field definition + fieldDefinition, ok := v.Walker.FieldDefinition(fieldRef) + if !ok { + return + } + fieldType := v.Definition.FieldDefinitionType(fieldDefinition) + + // Create a simple field value for tracking purposes + fieldValue := v.createFieldValueForPlanner(fieldRef, fieldType, []string{fieldAliasOrName}) + + onTypeNames := v.resolveEntityOnTypeNames(plannerID, fieldRef, fieldName) + + // Create the field + field := &resolve.Field{ + Name: []byte(fieldAliasOrName), + Value: fieldValue, + OnTypeNames: onTypeNames, + } + + // Add the field to the current object for this planner + if len(v.plannerCurrentFields[plannerID]) > 0 { + currentFields := v.plannerCurrentFields[plannerID][len(v.plannerCurrentFields[plannerID])-1] + *currentFields.fields = append(*currentFields.fields, field) + } + + for { + // for loop to unwrap array item + switch node := fieldValue.(type) { + case *resolve.Array: + // unwrap and check type again + fieldValue = node.Item + case *resolve.Object: + // if the field value is an object, add it to the current fields stack + v.Walker.DefferOnEnterField(func() { + v.plannerCurrentFields[plannerID] = append(v.plannerCurrentFields[plannerID], objectFields{ + popOnField: fieldRef, + fields: &node.Fields, + }) + }) + return + default: + // field value is a scalar or null, we don't add it to the stack + return + } + } +} + +func (v *Visitor) resolveEntityOnTypeNames(plannerID, fieldRef int, fieldName ast.ByteSlice) (onTypeNames [][]byte) { + // If this is an entity root field, return the enclosing type name + if v.isEntityRootField(plannerID, fieldRef) { + enclosingTypeName := v.Walker.EnclosingTypeDefinition.NameBytes(v.Definition) + if enclosingTypeName != nil { + return [][]byte{enclosingTypeName} + } + } + + // Otherwise, use the regular resolution logic + onTypeNames = v.resolveOnTypeNames(fieldRef, fieldName) + return onTypeNames +} + +// createFieldValueForPlanner creates a simplified field value for planner tracking +// 
without relying on the full visitor state like resolveFieldValue does +func (v *Visitor) createFieldValueForPlanner(fieldRef, typeRef int, path []string) resolve.Node { + ofType := v.Definition.Types[typeRef].OfType + + switch v.Definition.Types[typeRef].TypeKind { + case ast.TypeKindNonNull: + node := v.createFieldValueForPlanner(fieldRef, ofType, path) + // Set nullable to false for the returned node + switch n := node.(type) { + case *resolve.Scalar: + n.Nullable = false + case *resolve.Object: + n.Nullable = false + case *resolve.Array: + n.Nullable = false + } + return node + case ast.TypeKindList: + listItem := v.createFieldValueForPlanner(fieldRef, ofType, nil) + return &resolve.Array{ + Nullable: true, + Path: path, + Item: listItem, + } + case ast.TypeKindNamed: + typeName := v.Definition.ResolveTypeNameString(typeRef) + typeDefinitionNode, ok := v.Definition.Index.FirstNodeByNameStr(typeName) + if !ok { + return &resolve.Null{} + } + switch typeDefinitionNode.Kind { + case ast.NodeKindScalarTypeDefinition, ast.NodeKindEnumTypeDefinition: + return &resolve.Scalar{ + Nullable: true, + Path: path, + } + case ast.NodeKindObjectTypeDefinition, ast.NodeKindInterfaceTypeDefinition, ast.NodeKindUnionTypeDefinition: + // For object types, create a new object that will be populated by child fields + obj := &resolve.Object{ + Nullable: true, + Path: path, + Fields: []*resolve.Field{}, + } + return obj + default: + return &resolve.Null{} + } + default: + return &resolve.Null{} + } +} + +// isEntityBoundaryField checks if this field represents the entity boundary for a nested entity fetch +// For nested entity fetches, the field at the response path boundary should be skipped in ProvidesData +func (v *Visitor) isEntityBoundaryField(plannerID int, fieldRef int) bool { + config := v.planners[plannerID] + fetchConfig := config.ObjectFetchConfiguration() + if fetchConfig == nil || fetchConfig.fetchItem == nil { + return false + } + + // Check if this is a nested fetch (has "." in response path) + responsePath := "query." + fetchConfig.fetchItem.ResponsePath + if !strings.Contains(responsePath, ".") { + return false // Root fetch, no boundary field to skip + } + + // For nested fetches, check if this field is at the entity boundary + currentPath := v.Walker.Path.DotDelimitedString() + fieldName := v.Operation.FieldAliasOrNameString(fieldRef) + fullFieldPath := currentPath + "." + fieldName + + // If this field path matches the response path, it's the entity boundary + if fullFieldPath == responsePath { + // Store the entity boundary path for this planner + v.plannerEntityBoundaryPaths[plannerID] = fullFieldPath + return true + } + return false +} + +// isEntityRootField checks if this field is at the root of an entity +// This means it has one additional path element compared to the stored entity boundary path +func (v *Visitor) isEntityRootField(plannerID int, fieldRef int) bool { + // Check if we have a stored entity boundary path for this planner + boundaryPath, hasBoundary := v.plannerEntityBoundaryPaths[plannerID] + if !hasBoundary { + return false + } + + // Get the current field path + currentPath := v.Walker.Path.DotDelimitedString() + fieldName := v.Operation.FieldAliasOrNameString(fieldRef) + fullFieldPath := currentPath + "." 
+ fieldName + + // Check if this field is a direct child of the entity boundary + // It should start with the boundary path and have exactly one more segment + if !strings.HasPrefix(fullFieldPath, boundaryPath+".") { + return false + } + + // Remove the boundary path prefix and check if there's exactly one segment left + remainingPath := strings.TrimPrefix(fullFieldPath, boundaryPath+".") + // If there are no more dots, this is a root field of the entity + return !strings.Contains(remainingPath, ".") +} + +// addTypenameFieldForPlanner adds a __typename field to the current object for entity boundary fields +func (v *Visitor) addTypenameFieldForPlanner(plannerID int) { + + // Create a __typename field + typenameField := &resolve.Field{ + Name: []byte("__typename"), + Value: &resolve.Scalar{ + Path: []string{"__typename"}, + }, + } + + // Add the __typename field to the current object for this planner + if len(v.plannerCurrentFields[plannerID]) > 0 { + currentFields := v.plannerCurrentFields[plannerID][len(v.plannerCurrentFields[plannerID])-1] + *currentFields.fields = append(*currentFields.fields, typenameField) + } +} + +func (v *Visitor) shouldPlannerHandleField(plannerID int, fieldRef int) bool { + // Safety checks + if v.planners == nil || plannerID >= len(v.planners) { + return false + } + + // Use the same logic as AllowVisitor to check if a planner handles a field + path := v.Walker.Path.DotDelimitedString() + if v.Walker.CurrentKind == ast.NodeKindField { + path = path + "." + v.Operation.FieldAliasOrNameString(fieldRef) + } + + config := v.planners[plannerID] + if !config.HasPath(path) { + return false + } + + enclosingTypeName := v.Walker.EnclosingTypeDefinition.NameString(v.Definition) + + allow := config.HasPathWithFieldRef(fieldRef) || config.HasParent(path) + if !allow { + return false + } + + shouldWalkFieldsOnPath := config.ShouldWalkFieldsOnPath(path, enclosingTypeName) || + config.ShouldWalkFieldsOnPath(path, "") + + return shouldWalkFieldsOnPath +} + +func (v *Visitor) popFieldsForPlanner(plannerID int, fieldRef int) { + // Safety checks + if v.plannerCurrentFields == nil || plannerID >= len(v.plannerCurrentFields) { + return + } + + if len(v.plannerCurrentFields[plannerID]) > 0 { + last := len(v.plannerCurrentFields[plannerID]) - 1 + if v.plannerCurrentFields[plannerID][last].popOnField == fieldRef { + v.plannerCurrentFields[plannerID] = v.plannerCurrentFields[plannerID][:last] + } + } +} + func (v *Visitor) resolveInputTemplates(config *objectFetchConfiguration, input *string, variables *resolve.Variables) { *input = templateRegex.ReplaceAllStringFunc(*input, func(s string) string { selectors := selectorRegex.FindStringSubmatch(s) @@ -1291,6 +1617,8 @@ func (v *Visitor) configureSubscription(config *objectFetchConfiguration) { v.subscription.Trigger.QueryPlan = subscription.QueryPlan v.resolveInputTemplates(config, &subscription.Input, &v.subscription.Trigger.Variables) v.subscription.Trigger.Input = []byte(subscription.Input) + v.subscription.Trigger.SourceName = config.sourceName + v.subscription.Trigger.SourceID = config.sourceID v.subscription.Filter = config.filter } @@ -1317,6 +1645,21 @@ func (v *Visitor) configureFetch(internal *objectFetchConfiguration, external re dataSourceType := reflect.TypeOf(external.DataSource).String() dataSourceType = strings.TrimPrefix(dataSourceType, "*") + if !v.Config.DisableEntityCaching { + external.Caching = resolve.FetchCacheConfiguration{ + Enabled: true, + CacheName: "default", + TTL: time.Second * time.Duration(30), + // 
templates come prepared from the DataSource + CacheKeyTemplate: external.Caching.CacheKeyTemplate, + IncludeSubgraphHeaderPrefix: true, + } + } else { + external.Caching = resolve.FetchCacheConfiguration{ + Enabled: false, + } + } + singleFetch := &resolve.SingleFetch{ FetchConfiguration: external, FetchDependencies: resolve.FetchDependencies{ @@ -1336,7 +1679,12 @@ func (v *Visitor) configureFetch(internal *objectFetchConfiguration, external re OperationType: internal.operationType, QueryPlan: external.QueryPlan, } - + if !v.Config.DisableFetchProvidesData { + // Set ProvidesData from the planner's object structure + if providesData, ok := v.plannerObjects[internal.fetchID]; ok { + singleFetch.Info.ProvidesData = providesData + } + } if v.Config.DisableIncludeFieldDependencies { return singleFetch } diff --git a/v2/pkg/engine/postprocess/create_concrete_single_fetch_types.go b/v2/pkg/engine/postprocess/create_concrete_single_fetch_types.go index f5d0b2ae2..1fdccad92 100644 --- a/v2/pkg/engine/postprocess/create_concrete_single_fetch_types.go +++ b/v2/pkg/engine/postprocess/create_concrete_single_fetch_types.go @@ -51,19 +51,11 @@ func (d *createConcreteSingleFetchTypes) traverseSingleFetch(fetch *resolve.Sing return d.createEntityBatchFetch(fetch) case fetch.RequiresEntityFetch: return d.createEntityFetch(fetch) - case fetch.RequiresParallelListItemFetch: - return d.createParallelListItemFetch(fetch) default: return fetch } } -func (d *createConcreteSingleFetchTypes) createParallelListItemFetch(fetch *resolve.SingleFetch) resolve.Fetch { - return &resolve.ParallelListItemFetch{ - Fetch: fetch, - } -} - func (d *createConcreteSingleFetchTypes) createEntityBatchFetch(fetch *resolve.SingleFetch) resolve.Fetch { representationsVariableIndex := -1 for i, segment := range fetch.InputTemplate.Segments { @@ -106,6 +98,7 @@ func (d *createConcreteSingleFetchTypes) createEntityBatchFetch(fetch *resolve.S }, DataSource: fetch.DataSource, PostProcessing: fetch.PostProcessing, + Caching: fetch.Caching, } } @@ -139,5 +132,6 @@ func (d *createConcreteSingleFetchTypes) createEntityFetch(fetch *resolve.Single }, DataSource: fetch.DataSource, PostProcessing: fetch.PostProcessing, + Caching: fetch.Caching, } } diff --git a/v2/pkg/engine/resolve/arena.go b/v2/pkg/engine/resolve/arena.go new file mode 100644 index 000000000..7909460b2 --- /dev/null +++ b/v2/pkg/engine/resolve/arena.go @@ -0,0 +1,144 @@ +package resolve + +import ( + "sync" + "weak" + + "github.com/wundergraph/go-arena" +) + +// ArenaPool provides a thread-safe pool of arena.Arena instances for memory-efficient allocations. +// It uses weak pointers to allow garbage collection of unused arenas while maintaining +// a pool of reusable arenas for high-frequency allocation patterns. 
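+//
+// A minimal usage sketch (illustrative only; the key, 42 here, is an arbitrary
+// caller-chosen identifier for the allocation use case being size-tracked):
+//
+//	pool := NewArenaPool()
+//	item := pool.Acquire(42)
+//	buf := arena.NewArenaBuffer(item.Arena)
+//	_, _ = buf.WriteString("scratch data")
+//	pool.Release(item) // records peak usage, resets the arena, re-pools it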
+//
+// By storing ArenaPoolItem values as weak pointers, the GC can collect them at any time.
+// Before using an ArenaPoolItem, we try to obtain a strong pointer while removing it from the pool.
+// Once Release is called, the item is returned to the pool and turned back into a weak pointer.
+// This means the GC can reclaim the memory at any time if required, allowing it to
+// automatically manage an appropriate pool size depending on available memory and GC pressure.
+type ArenaPool struct {
+	// pool is a slice of weak pointers to the struct holding the arena.Arena
+	pool  []weak.Pointer[ArenaPoolItem]
+	sizes map[uint64]*arenaPoolItemSize
+	mu    sync.Mutex
+}
+
+// arenaPoolItemSize tracks the peak memory required across a moving window of the last 50 arenas per use case
+type arenaPoolItemSize struct {
+	count      int
+	totalBytes int
+}
+
+// ArenaPoolItem wraps an arena.Arena for use in the pool
+type ArenaPoolItem struct {
+	Arena arena.Arena
+	Key   uint64
+}
+
+// NewArenaPool creates a new ArenaPool instance
+func NewArenaPool() *ArenaPool {
+	return &ArenaPool{
+		sizes: make(map[uint64]*arenaPoolItemSize),
+	}
+}
+
+// Acquire gets an arena from the pool or creates a new one if none are available.
+// The key parameter is used to track arena sizes per use case for optimization.
+func (p *ArenaPool) Acquire(key uint64) *ArenaPoolItem {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	// Try to find an available arena in the pool
+	for len(p.pool) > 0 {
+		// Pop the last item
+		lastIdx := len(p.pool) - 1
+		wp := p.pool[lastIdx]
+		p.pool = p.pool[:lastIdx]
+
+		v := wp.Value()
+		if v != nil {
+			v.Key = key
+			return v
+		}
+		// If the weak pointer was nil (GC collected), continue to the next item
+	}
+
+	// No arena available, create a new one
+	size := arena.WithMinBufferSize(p.getArenaSize(key))
+	return &ArenaPoolItem{
+		Arena: arena.NewMonotonicArena(size),
+		Key:   key,
+	}
+}
+
+// Release returns an arena to the pool for reuse.
+// The peak memory usage is recorded to optimize future arena sizes for this use case.
+func (p *ArenaPool) Release(item *ArenaPoolItem) {
+	peak := item.Arena.Peak()
+	item.Arena.Reset()
+
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	// Record the peak usage for this use case
+	if size, ok := p.sizes[item.Key]; ok {
+		if size.count == 50 {
+			size.count = 1
+			size.totalBytes = size.totalBytes / 50
+		}
+		size.count++
+		size.totalBytes += peak
+	} else {
+		p.sizes[item.Key] = &arenaPoolItemSize{
+			count:      1,
+			totalBytes: peak,
+		}
+	}
+
+	item.Key = 0
+
+	// Add the arena back to the pool using a weak pointer
+	w := weak.Make(item)
+	p.pool = append(p.pool, w)
+}
+
+// ReleaseMany returns multiple arenas to the pool under a single lock acquisition,
+// recording peak usage per item exactly like Release.
+func (p *ArenaPool) ReleaseMany(items []*ArenaPoolItem) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	for _, item := range items {
+		peak := item.Arena.Peak()
+		item.Arena.Reset()
+
+		// Record the peak usage for this use case
+		if size, ok := p.sizes[item.Key]; ok {
+			if size.count == 50 {
+				size.count = 1
+				size.totalBytes = size.totalBytes / 50
+			}
+			size.count++
+			size.totalBytes += peak
+		} else {
+			p.sizes[item.Key] = &arenaPoolItemSize{
+				count:      1,
+				totalBytes: peak,
+			}
+		}
+
+		item.Key = 0
+
+		// Add the arena back to the pool using a weak pointer
+		w := weak.Make(item)
+		p.pool = append(p.pool, w)
+	}
+}
+
+// getArenaSize returns the optimal arena size for a given use case ID.
+// If no size is recorded, it defaults to 1MB.
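+// The returned size is the arithmetic mean of the peaks recorded for this key
+// (totalBytes / count) over a moving window of at most 50 releases; e.g. recorded
+// peaks of 2048, 4096, and 6144 bytes yield a 4096-byte minimum buffer for the
+// next arena created for this use case.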
+func (p *ArenaPool) getArenaSize(id uint64) int { + if size, ok := p.sizes[id]; ok { + return size.totalBytes / size.count + } + return 1024 * 1024 // Default 1MB +} diff --git a/v2/pkg/engine/resolve/arena_test.go b/v2/pkg/engine/resolve/arena_test.go new file mode 100644 index 000000000..c884434f1 --- /dev/null +++ b/v2/pkg/engine/resolve/arena_test.go @@ -0,0 +1,261 @@ +package resolve + +import ( + "runtime" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/wundergraph/go-arena" +) + +func TestNewArenaPool(t *testing.T) { + pool := NewArenaPool() + + require.NotNil(t, pool, "NewArenaPool returned nil") + assert.Equal(t, 0, len(pool.pool), "expected empty pool") + assert.Equal(t, 0, len(pool.sizes), "expected empty sizes map") +} + +func TestArenaPool_Acquire_EmptyPool(t *testing.T) { + pool := NewArenaPool() + + item := pool.Acquire(1) + + require.NotNil(t, item, "Acquire returned nil") + assert.NotNil(t, item.Arena, "Arena is nil") + + // Verify we can use the arena + buf := arena.NewArenaBuffer(item.Arena) + _, err := buf.WriteString("test") + assert.NoError(t, err) + + assert.Equal(t, 0, len(pool.pool), "pool should still be empty") +} + +func TestArenaPool_ReleaseAndAcquire(t *testing.T) { + pool := NewArenaPool() + id := uint64(42) + + // Acquire first arena + item1 := pool.Acquire(id) + + // Use the arena + buf := arena.NewArenaBuffer(item1.Arena) + _, err := buf.WriteString("test data") + assert.NoError(t, err) + + // Release it + pool.Release(item1) + + // Pool should have one item + assert.Equal(t, 1, len(pool.pool), "expected pool to have 1 item") + + // Acquire from pool + item2 := pool.Acquire(id) + + require.NotNil(t, item2, "Acquire returned nil") + + // Pool should be empty again + assert.Equal(t, 0, len(pool.pool), "expected empty pool after acquire") + + // The acquired arena should be reset and usable + buf2 := arena.NewArenaBuffer(item2.Arena) + _, err = buf2.WriteString("new data") + assert.NoError(t, err) + + assert.Equal(t, "new data", buf2.String()) +} + +func TestArenaPool_Acquire_ProvesBugFix(t *testing.T) { + // This test specifically proves the bug fix works + // Creates multiple items, clears some references, then acquires + // to ensure all items are checked without skipping + pool := NewArenaPool() + id := uint64(800) + + numItems := 10 + items := make([]*ArenaPoolItem, numItems) + + // Acquire all items + for i := 0; i < numItems; i++ { + items[i] = pool.Acquire(id) + buf := arena.NewArenaBuffer(items[i].Arena) + _, err := buf.WriteString("item data") + assert.NoError(t, err) + } + + // Release all while keeping strong references + for i := 0; i < numItems; i++ { + pool.Release(items[i]) + } + + // Pool should have all items + assert.Equal(t, numItems, len(pool.pool), "expected items in pool") + + // Clear every other item to simulate partial GC + for i := 0; i < numItems; i += 2 { + items[i] = nil + } + + // Force GC + runtime.GC() + runtime.GC() + + // Acquire items - should process ALL items without skipping + processed := 0 + acquired := 0 + + for len(pool.pool) > 0 && processed < numItems*2 { + poolSizeBefore := len(pool.pool) + item := pool.Acquire(id) + poolSizeAfter := len(pool.pool) + processed++ + + assert.Less(t, poolSizeAfter, poolSizeBefore, "Pool size did not decrease - item not removed properly!") + + if item != nil { + acquired++ + } + } + + // Pool should be empty + assert.Equal(t, 0, len(pool.pool), "expected empty pool") +} + +func TestArenaPool_Release_PeakTracking(t 
*testing.T) { + pool := NewArenaPool() + id := uint64(200) + + // First arena + item1 := pool.Acquire(id) + buf1 := arena.NewArenaBuffer(item1.Arena) + _, err := buf1.WriteString("small") + assert.NoError(t, err) + + peak1 := item1.Arena.Peak() + assert.Equal(t, peak1, 5) + + pool.Release(item1) + + // Check that size was tracked + size, exists := pool.sizes[id] + require.True(t, exists, "size tracking not created") + assert.Equal(t, 1, size.count, "expected count 1") + + // Second arena + item2 := pool.Acquire(id) + buf2 := arena.NewArenaBuffer(item2.Arena) + _, err = buf2.WriteString("larger data") + assert.NoError(t, err) + + pool.Release(item2) + + // Check updated tracking + assert.Equal(t, 2, size.count, "expected count 2") +} + +func TestArenaPool_GetArenaSize(t *testing.T) { + pool := NewArenaPool() + + // Test default size for unknown ID + size1 := pool.getArenaSize(999) + expectedDefault := 1024 * 1024 + assert.Equal(t, expectedDefault, size1, "expected default size") + + // Test calculated size after usage + id := uint64(400) + item := pool.Acquire(id) + buf := arena.NewArenaBuffer(item.Arena) + _, err := buf.WriteString("some data") + assert.NoError(t, err) + pool.Release(item) + + size2 := pool.getArenaSize(id) + assert.NotEqual(t, 0, size2, "expected non-zero size after usage") +} + +func TestArenaPool_MultipleItemsInPool(t *testing.T) { + pool := NewArenaPool() + id := uint64(500) + + // Acquire multiple distinct items + numItems := 3 + items := make([]*ArenaPoolItem, numItems) + + for i := 0; i < numItems; i++ { + items[i] = pool.Acquire(id) + buf := arena.NewArenaBuffer(items[i].Arena) + _, err := buf.WriteString("data") + assert.NoError(t, err) + } + + // Release all while keeping references + for i := 0; i < numItems; i++ { + pool.Release(items[i]) + } + + // Should have all items in pool + assert.Equal(t, numItems, len(pool.pool), "expected items in pool") + + // Acquire all back + acquired := 0 + for len(pool.pool) > 0 { + item := pool.Acquire(id) + if item != nil { + acquired++ + } + } + + assert.Equal(t, numItems, acquired, "expected to acquire all items") +} + +func TestArenaPool_Release_MovingWindow(t *testing.T) { + pool := NewArenaPool() + id := uint64(600) + + // Release exactly 50 items + for i := 0; i < 50; i++ { + item := pool.Acquire(id) + buf := arena.NewArenaBuffer(item.Arena) + _, err := buf.WriteString("test data") + assert.NoError(t, err) + pool.Release(item) + } + + // After 50 releases, verify count and total + size := pool.sizes[id] + require.NotNil(t, size, "size tracking should exist") + assert.Equal(t, 50, size.count, "expected count to be 50") + + totalBytesAfter50 := size.totalBytes + + // Release one more item to trigger the window reset + item51 := pool.Acquire(id) + buf51 := arena.NewArenaBuffer(item51.Arena) + _, err := buf51.WriteString("test data") + assert.NoError(t, err) + peak51 := item51.Arena.Peak() + pool.Release(item51) + + // After 51st release, verify the window was reset + // count should be 2 (reset to 1, then incremented) + // totalBytes should be (totalBytesAfter50 / 50) + peak51 + assert.Equal(t, 2, size.count, "expected count to be 2 after window reset") + + expectedTotalBytes := (totalBytesAfter50 / 50) + peak51 + assert.Equal(t, expectedTotalBytes, size.totalBytes, "expected totalBytes to be divided by 50 and new peak added") + + // Verify we can continue releasing and counting works correctly + for i := 0; i < 10; i++ { + item := pool.Acquire(id) + buf := arena.NewArenaBuffer(item.Arena) + _, err := buf.WriteString("more 
data") + assert.NoError(t, err) + pool.Release(item) + } + + // After 10 more releases, count should be 12 (2 + 10) + assert.Equal(t, 12, size.count, "expected count to continue incrementing after window reset") +} diff --git a/v2/pkg/engine/resolve/authorization_test.go b/v2/pkg/engine/resolve/authorization_test.go index 263724a77..95051def7 100644 --- a/v2/pkg/engine/resolve/authorization_test.go +++ b/v2/pkg/engine/resolve/authorization_test.go @@ -1,11 +1,11 @@ package resolve import ( - "bytes" "context" "encoding/json" "errors" "io" + "net/http" "sync/atomic" "testing" @@ -510,38 +510,32 @@ func TestAuthorization(t *testing.T) { func generateTestFederationGraphQLResponse(t *testing.T, ctrl *gomock.Controller) *GraphQLResponse { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me":{"id":"1234","username":"Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me":{"id":"1234","username":"Me","__typename": "User"}}}`), nil }).AnyTimes() reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}}`), nil }).AnyTimes() productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}}`), nil }).AnyTimes() return &GraphQLResponse{ @@ -821,38 +815,32 @@ func generateTestFederationGraphQLResponse(t *testing.T, ctrl *gomock.Controller func generateTestFederationGraphQLResponseWithoutAuthorizationRules(t *testing.T, ctrl *gomock.Controller) *GraphQLResponse { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me":{"id":"1234","username":"Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me":{"id":"1234","username":"Me","__typename": "User"}}}`), nil }).AnyTimes() reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}}`), nil }).AnyTimes() productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}}`), nil }).AnyTimes() return &GraphQLResponse{ diff --git a/v2/pkg/engine/resolve/caching.go b/v2/pkg/engine/resolve/caching.go new file mode 100644 index 000000000..10566075a --- /dev/null +++ b/v2/pkg/engine/resolve/caching.go @@ -0,0 +1,240 @@ +package resolve + +import ( + "github.com/wundergraph/astjson" + "github.com/wundergraph/go-arena" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/unsafebytes" +) + +type CacheKeyTemplate interface { + // RenderCacheKeys returns multiple cache keys (one per root field or entity) + // Generates keys for all items at once + RenderCacheKeys(a arena.Arena, ctx *Context, items []*astjson.Value, prefix string) ([]*CacheKey, error) +} + +type CacheKey struct { + Item *astjson.Value + FromCache *astjson.Value + Keys []string +} + +type RootQueryCacheKeyTemplate struct { + RootFields []QueryField +} + +type QueryField struct { + Coordinate GraphCoordinate + Args []FieldArgument +} + +type FieldArgument struct { + Name string + Variable Variable +} + +// RenderCacheKeys returns multiple cache keys, one per item +// Each cache key contains one or more KeyEntry objects (one per root field) +func (r *RootQueryCacheKeyTemplate) RenderCacheKeys(a arena.Arena, ctx *Context, items []*astjson.Value, prefix string) ([]*CacheKey, error) { + if len(r.RootFields) == 0 { + return nil, nil + } + // Estimate capacity: one CacheKey per item + cacheKeys := arena.AllocateSlice[*CacheKey](a, 0, len(items)) + jsonBytes := arena.AllocateSlice[byte](a, 0, 64) + + for _, item := range items { + // Create KeyEntry for each root field + keyEntries := arena.AllocateSlice[string](a, 0, len(r.RootFields)) + for _, field := range r.RootFields { + var key string + key, jsonBytes = r.renderField(a, ctx, item, jsonBytes, field) + if prefix != "" { + l := len(prefix) + 1 + len(key) + tmp := arena.AllocateSlice[byte](a, 0, l) + tmp = arena.SliceAppend(a, tmp, unsafebytes.StringToBytes(prefix)...) + tmp = arena.SliceAppend(a, tmp, []byte(`:`)...) + tmp = arena.SliceAppend(a, tmp, unsafebytes.StringToBytes(key)...) 
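+				// tmp is arena-allocated and not mutated after this point, so the
+				// zero-copy string view created below remains valid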
+				key = unsafebytes.BytesToString(tmp)
+			}
+			keyEntries = arena.SliceAppend(a, keyEntries, key)
+		}
+		cacheKeys = arena.SliceAppend(a, cacheKeys, &CacheKey{
+			Item: item,
+			Keys: keyEntries,
+		})
+	}
+	return cacheKeys, nil
+}
+
+// renderField renders a single field cache key as JSON
+func (r *RootQueryCacheKeyTemplate) renderField(a arena.Arena, ctx *Context, item *astjson.Value, jsonBytes []byte, field QueryField) (string, []byte) {
+	// Build JSON object starting with __typename
+	keyObj := astjson.ObjectValue(a)
+	typeName := field.Coordinate.TypeName
+	keyObj.Set(a, "__typename", astjson.StringValue(a, typeName))
+	keyObj.Set(a, "field", astjson.StringValue(a, field.Coordinate.FieldName))
+
+	// Build args object if there are any arguments
+	if len(field.Args) > 0 {
+		argsObj := astjson.ObjectValue(a)
+		for _, arg := range field.Args {
+			var argValue *astjson.Value
+			segment := arg.Variable.TemplateSegment()
+			if segment.Renderer != nil {
+				switch segment.VariableKind {
+				case ContextVariableKind:
+					// Extract value from context variables
+					variableSourcePath := segment.VariableSourcePath
+					if len(variableSourcePath) == 1 && ctx.RemapVariables != nil {
+						if nameToUse, hasMapping := ctx.RemapVariables[variableSourcePath[0]]; hasMapping && nameToUse != variableSourcePath[0] {
+							variableSourcePath = []string{nameToUse}
+						}
+					}
+					argValue = ctx.Variables.Get(variableSourcePath...)
+					if argValue == nil {
+						argValue = astjson.NullValue
+					}
+				case ObjectVariableKind:
+					// Use data parameter for object variables
+					if item != nil {
+						value := item.Get(segment.VariableSourcePath...)
+						if value == nil || value.Type() == astjson.TypeNull {
+							argValue = astjson.NullValue
+						} else {
+							// Values are already JSON-compatible astjson.Value
+							argValue = value
+						}
+					} else {
+						argValue = astjson.NullValue
+					}
+				default:
+					// For other variable kinds, use data parameter
+					if item != nil {
+						argValue = item
+					} else {
+						argValue = astjson.NullValue
+					}
+				}
+			} else {
+				argValue = astjson.NullValue
+			}
+			argsObj.Set(a, arg.Name, argValue)
+		}
+		keyObj.Set(a, "args", argsObj)
+	}
+
+	// Marshal to JSON and write to output
+	jsonBytes = keyObj.MarshalTo(jsonBytes[:0])
+	slice := arena.AllocateSlice[byte](a, len(jsonBytes), len(jsonBytes))
+	copy(slice, jsonBytes)
+	return unsafebytes.BytesToString(slice), jsonBytes
+}
+
+type EntityQueryCacheKeyTemplate struct {
+	Keys *ResolvableObjectVariable
+}
+
+// RenderCacheKeys returns one cache key per item for entity queries, with the entity key fields nested under "key"
+func (e *EntityQueryCacheKeyTemplate) RenderCacheKeys(a arena.Arena, ctx *Context, items []*astjson.Value, prefix string) ([]*CacheKey, error) {
+	jsonBytes := arena.AllocateSlice[byte](a, 0, 64)
+	cacheKeys := arena.AllocateSlice[*CacheKey](a, 0, len(items))
+
+	for _, item := range items {
+		if item == nil {
+			continue
+		}
+
+		// Build JSON object starting with __typename
+		keyObj := astjson.ObjectValue(a)
+
+		// Extract __typename from the data
+		typename := item.Get("__typename")
+		if typename == nil {
+			// Fallback if no __typename in data
+			keyObj.Set(a, "__typename", astjson.StringValue(a, "Entity"))
+		} else {
+			keyObj.Set(a, "__typename", typename)
+		}
+
+		// Put the entity key fields under a nested "key" object
+		keysObj := astjson.ObjectValue(a)
+
+		// Extract only the fields defined in the Keys template (not all fields from data)
+		if e.Keys != nil && e.Keys.Renderer != nil {
+			if obj, ok := e.Keys.Renderer.Node.(*Object); ok {
+				for _, field := range obj.Fields {
+					fieldName := 
unsafebytes.BytesToString(field.Name) + // Skip __typename as it's already handled separately + if fieldName == "__typename" { + continue + } + // Resolve field value based on its template definition + fieldValue := e.resolveFieldValue(a, field.Value, item) + if fieldValue != nil && fieldValue.Type() != astjson.TypeNull { + keysObj.Set(a, fieldName, fieldValue) + } + } + } + } + + keyObj.Set(a, "key", keysObj) + + // Marshal to JSON and write to buffer + jsonBytes = keyObj.MarshalTo(jsonBytes[:0]) + l := len(jsonBytes) + if prefix != "" { + l += 1 + len(prefix) + } + slice := arena.AllocateSlice[byte](a, 0, l) + if prefix != "" { + slice = arena.SliceAppend(a, slice, unsafebytes.StringToBytes(prefix)...) + slice = arena.SliceAppend(a, slice, []byte(`:`)...) + } + slice = arena.SliceAppend(a, slice, jsonBytes...) + + // Create KeyEntry with empty path for entity queries + keyEntries := arena.AllocateSlice[string](a, 0, 1) + keyEntries = arena.SliceAppend(a, keyEntries, unsafebytes.BytesToString(slice)) + + cacheKeys = arena.SliceAppend(a, cacheKeys, &CacheKey{ + Item: item, + Keys: keyEntries, + }) + } + + return cacheKeys, nil +} + +// resolveFieldValue resolves a field value from data based on its template definition +func (e *EntityQueryCacheKeyTemplate) resolveFieldValue(a arena.Arena, valueNode Node, data *astjson.Value) *astjson.Value { + switch node := valueNode.(type) { + case *String: + // Extract string value from data using the path + return data.Get(node.Path...) + case *Object: + // For nested objects, recursively build the object using only template-defined fields + nestedObj := astjson.ObjectValue(a) + // Get the base object from data using the object's path + baseData := data.Get(node.Path...) + if baseData == nil || baseData.Type() == astjson.TypeNull { + return nil + } + // Recursively resolve each field in the nested object template + for _, field := range node.Fields { + fieldName := unsafebytes.BytesToString(field.Name) + // Skip __typename in nested objects + if fieldName == "__typename" { + continue + } + fieldValue := e.resolveFieldValue(a, field.Value, baseData) + if fieldValue != nil && fieldValue.Type() != astjson.TypeNull { + nestedObj.Set(a, fieldName, fieldValue) + } + } + return nestedObj + default: + // For other types not handled above, return nil + return nil + } +} diff --git a/v2/pkg/engine/resolve/caching_test.go b/v2/pkg/engine/resolve/caching_test.go new file mode 100644 index 000000000..f382f58f3 --- /dev/null +++ b/v2/pkg/engine/resolve/caching_test.go @@ -0,0 +1,960 @@ +package resolve + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/wundergraph/astjson" + "github.com/wundergraph/go-arena" +) + +func TestCachingRenderRootQueryCacheKeyTemplate(t *testing.T) { + t.Run("single field no arguments", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "users", + }, + Args: []FieldArgument{}, + }, + }, + } + + ctx := &Context{ + Variables: astjson.MustParse(`{}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{}`) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") + assert.NoError(t, err) + expected := []*CacheKey{ + { + Item: data, + Keys: []string{`{"__typename":"Query","field":"users"}`}, + }, + } + assert.Equal(t, expected, cacheKeys) + }) + + t.Run("single field single argument", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: 
[]QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "droid", + }, + Args: []FieldArgument{ + { + Name: "id", + Variable: &ContextVariable{ + Path: []string{"id"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + ctx := &Context{ + Variables: astjson.MustParse(`{"id":1}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{}`) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") + assert.NoError(t, err) + expected := []*CacheKey{ + { + Item: data, + Keys: []string{`{"__typename":"Query","field":"droid","args":{"id":1}}`}, + }, + } + assert.Equal(t, expected, cacheKeys) + }) + + t.Run("single field single string argument", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "user", + }, + Args: []FieldArgument{ + { + Name: "name", + Variable: &ContextVariable{ + Path: []string{"name"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + ctx := &Context{ + Variables: astjson.MustParse(`{"name":"john"}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{}`) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") + assert.NoError(t, err) + expected := []*CacheKey{ + { + Item: data, + Keys: []string{`{"__typename":"Query","field":"user","args":{"name":"john"}}`}, + }, + } + assert.Equal(t, expected, cacheKeys) + }) + + t.Run("single field multiple arguments", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "search", + }, + Args: []FieldArgument{ + { + Name: "term", + Variable: &ContextVariable{ + Path: []string{"term"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + { + Name: "max", + Variable: &ContextVariable{ + Path: []string{"max"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + ctx := &Context{ + Variables: astjson.MustParse(`{"term":"C3PO","max":10}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{}`) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") + assert.NoError(t, err) + expected := []*CacheKey{ + { + Item: data, + Keys: []string{`{"__typename":"Query","field":"search","args":{"term":"C3PO","max":10}}`}, + }, + } + assert.Equal(t, expected, cacheKeys) + }) + + t.Run("single field multiple arguments with boolean", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "products", + }, + Args: []FieldArgument{ + { + Name: "includeDeleted", + Variable: &ContextVariable{ + Path: []string{"includeDeleted"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + { + Name: "limit", + Variable: &ContextVariable{ + Path: []string{"limit"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + ctx := &Context{ + Variables: astjson.MustParse(`{"includeDeleted":true,"limit":20}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{}`) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") + assert.NoError(t, err) + expected := []*CacheKey{ + { + Item: data, + Keys: []string{`{"__typename":"Query","field":"products","args":{"includeDeleted":true,"limit":20}}`}, + }, + } + assert.Equal(t, expected, cacheKeys) + }) + + t.Run("multiple fields single argument each", func(t *testing.T) { + tmpl := 
&RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "droid", + }, + Args: []FieldArgument{ + { + Name: "id", + Variable: &ContextVariable{ + Path: []string{"id"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "user", + }, + Args: []FieldArgument{ + { + Name: "name", + Variable: &ContextVariable{ + Path: []string{"name"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + ctx := &Context{ + Variables: astjson.MustParse(`{"id":1,"name":"john"}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{}`) + + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") + assert.NoError(t, err) + expected := []*CacheKey{ + { + Item: data, + Keys: []string{ + `{"__typename":"Query","field":"droid","args":{"id":1}}`, + `{"__typename":"Query","field":"user","args":{"name":"john"}}`, + }, + }, + } + assert.Equal(t, expected, cacheKeys) + }) + + t.Run("multiple fields with mixed arguments", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "product", + }, + Args: []FieldArgument{ + { + Name: "id", + Variable: &ContextVariable{ + Path: []string{"id"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + { + Name: "includeReviews", + Variable: &ContextVariable{ + Path: []string{"includeReviews"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "hero", + }, + Args: []FieldArgument{}, + }, + }, + } + + ctx := &Context{ + Variables: astjson.MustParse(`{"id":"123","includeReviews":true}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{}`) + + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") + assert.NoError(t, err) + expected := []*CacheKey{ + { + Item: data, + Keys: []string{ + `{"__typename":"Query","field":"product","args":{"id":"123","includeReviews":true}}`, + `{"__typename":"Query","field":"hero"}`, + }, + }, + } + assert.Equal(t, expected, cacheKeys) + }) + + t.Run("field with object variable argument", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "search", + }, + Args: []FieldArgument{ + { + Name: "filter", + Variable: &ObjectVariable{ + Path: []string{"filter"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + ctx := &Context{ + Variables: astjson.MustParse(`{}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{"filter":{"category":"electronics","price":100}}`) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") + assert.NoError(t, err) + expected := []*CacheKey{ + { + Item: data, + Keys: []string{`{"__typename":"Query","field":"search","args":{"filter":{"category":"electronics","price":100}}}`}, + }, + } + assert.Equal(t, expected, cacheKeys) + }) + + t.Run("field with null argument", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "user", + }, + Args: []FieldArgument{ + { + Name: "id", + Variable: &ContextVariable{ + Path: []string{"id"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + ctx := &Context{ + Variables: 
astjson.MustParse(`{"id":null}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{}`) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") + assert.NoError(t, err) + expected := []*CacheKey{ + { + Item: data, + Keys: []string{`{"__typename":"Query","field":"user","args":{"id":null}}`}, + }, + } + assert.Equal(t, expected, cacheKeys) + }) + + t.Run("field with missing argument", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "user", + }, + Args: []FieldArgument{ + { + Name: "id", + Variable: &ContextVariable{ + Path: []string{"id"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + ctx := &Context{ + Variables: astjson.MustParse(`{}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{}`) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") + assert.NoError(t, err) + expected := []*CacheKey{ + { + Item: data, + Keys: []string{`{"__typename":"Query","field":"user","args":{"id":null}}`}, + }, + } + assert.Equal(t, expected, cacheKeys) + }) + + t.Run("field with array argument", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "products", + }, + Args: []FieldArgument{ + { + Name: "ids", + Variable: &ContextVariable{ + Path: []string{"ids"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + ctx := &Context{ + Variables: astjson.MustParse(`{"ids":[1,2,3]}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{}`) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") + assert.NoError(t, err) + expected := []*CacheKey{ + { + Item: data, + Keys: []string{`{"__typename":"Query","field":"products","args":{"ids":[1,2,3]}}`}, + }, + } + assert.Equal(t, expected, cacheKeys) + }) + + t.Run("non-Query type", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Subscription", + FieldName: "messageAdded", + }, + Args: []FieldArgument{ + { + Name: "roomId", + Variable: &ContextVariable{ + Path: []string{"roomId"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + ctx := &Context{ + Variables: astjson.MustParse(`{"roomId":"123"}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{}`) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") + assert.NoError(t, err) + expected := []*CacheKey{ + { + Item: data, + Keys: []string{`{"__typename":"Subscription","field":"messageAdded","args":{"roomId":"123"}}`}, + }, + } + assert.Equal(t, expected, cacheKeys) + }) + + t.Run("single field with arena", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "user", + }, + Args: []FieldArgument{ + { + Name: "name", + Variable: &ContextVariable{ + Path: []string{"name"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + ar := arena.NewMonotonicArena(arena.WithMinBufferSize(1024)) + ctx := &Context{ + Variables: astjson.MustParse(`{"name":"john"}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{}`) + cacheKeys, err := tmpl.RenderCacheKeys(ar, ctx, []*astjson.Value{data}, "") + assert.NoError(t, err) + expected := []*CacheKey{ + { + Item: data, + Keys: 
[]string{`{"__typename":"Query","field":"user","args":{"name":"john"}}`}, + }, + } + assert.Equal(t, expected, cacheKeys) + }) + + t.Run("single field with prefix", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "user", + }, + Args: []FieldArgument{ + { + Name: "id", + Variable: &ContextVariable{ + Path: []string{"id"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + ctx := &Context{ + Variables: astjson.MustParse(`{"id":1}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{}`) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "prefix") + assert.NoError(t, err) + expected := []*CacheKey{ + { + Item: data, + Keys: []string{`prefix:{"__typename":"Query","field":"user","args":{"id":1}}`}, + }, + } + assert.Equal(t, expected, cacheKeys) + }) + + t.Run("multiple fields with prefix", func(t *testing.T) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "droid", + }, + Args: []FieldArgument{ + { + Name: "id", + Variable: &ContextVariable{ + Path: []string{"id"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "user", + }, + Args: []FieldArgument{ + { + Name: "name", + Variable: &ContextVariable{ + Path: []string{"name"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + ctx := &Context{ + Variables: astjson.MustParse(`{"id":1,"name":"john"}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{}`) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "my-prefix") + assert.NoError(t, err) + expected := []*CacheKey{ + { + Item: data, + Keys: []string{ + `my-prefix:{"__typename":"Query","field":"droid","args":{"id":1}}`, + `my-prefix:{"__typename":"Query","field":"user","args":{"name":"john"}}`, + }, + }, + } + assert.Equal(t, expected, cacheKeys) + }) +} + +func TestCachingRenderEntityQueryCacheKeyTemplate(t *testing.T) { + t.Run("single entity with typename and id", func(t *testing.T) { + tmpl := &EntityQueryCacheKeyTemplate{ + Keys: NewResolvableObjectVariable(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("id"), + Value: &String{ + Path: []string{"id"}, + }, + }, + }, + }), + } + + ctx := &Context{ + Variables: astjson.MustParse(`{}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{"__typename":"Product","id":"123"}`) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") + assert.NoError(t, err) + expected := []*CacheKey{ + { + Item: data, + Keys: []string{`{"__typename":"Product","key":{"id":"123"}}`}, + }, + } + assert.Equal(t, expected, cacheKeys) + }) + + t.Run("single entity with multiple keys", func(t *testing.T) { + tmpl := &EntityQueryCacheKeyTemplate{ + Keys: NewResolvableObjectVariable(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("sku"), + Value: &String{ + Path: []string{"sku"}, + }, + }, + { + Name: []byte("upc"), + Value: &String{ + Path: []string{"upc"}, + }, + }, + }, + }), + } + + ctx := &Context{ + Variables: astjson.MustParse(`{}`), + ctx: context.Background(), + } + data := 
astjson.MustParse(`{"__typename":"Product","sku":"ABC123","upc":"DEF456","name":"Trilby"}`) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "") + assert.NoError(t, err) + expected := []*CacheKey{ + { + Item: data, + Keys: []string{`{"__typename":"Product","key":{"sku":"ABC123","upc":"DEF456"}}`}, + }, + } + assert.Equal(t, expected, cacheKeys) + }) + + t.Run("single entity with prefix", func(t *testing.T) { + tmpl := &EntityQueryCacheKeyTemplate{ + Keys: NewResolvableObjectVariable(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("id"), + Value: &String{ + Path: []string{"id"}, + }, + }, + }, + }), + } + + ctx := &Context{ + Variables: astjson.MustParse(`{}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{"__typename":"Product","id":"123"}`) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "entity-prefix") + assert.NoError(t, err) + expected := []*CacheKey{ + { + Item: data, + Keys: []string{`entity-prefix:{"__typename":"Product","key":{"id":"123"}}`}, + }, + } + assert.Equal(t, expected, cacheKeys) + }) + + t.Run("entity with multiple keys and prefix", func(t *testing.T) { + tmpl := &EntityQueryCacheKeyTemplate{ + Keys: NewResolvableObjectVariable(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("sku"), + Value: &String{ + Path: []string{"sku"}, + }, + }, + { + Name: []byte("upc"), + Value: &String{ + Path: []string{"upc"}, + }, + }, + }, + }), + } + + ctx := &Context{ + Variables: astjson.MustParse(`{}`), + ctx: context.Background(), + } + data := astjson.MustParse(`{"__typename":"Product","sku":"ABC123","upc":"DEF456","name":"Trilby"}`) + cacheKeys, err := tmpl.RenderCacheKeys(nil, ctx, []*astjson.Value{data}, "cache") + assert.NoError(t, err) + expected := []*CacheKey{ + { + Item: data, + Keys: []string{`cache:{"__typename":"Product","key":{"sku":"ABC123","upc":"DEF456"}}`}, + }, + } + assert.Equal(t, expected, cacheKeys) + }) +} + +func BenchmarkRenderCacheKeys(b *testing.B) { + a := arena.NewMonotonicArena(arena.WithMinBufferSize(1024)) + + ctxRootQuery := &Context{ + Variables: astjson.MustParse(`{"id":1,"name":"john","term":"C3PO","max":10}`), + ctx: context.Background(), + } + + ctxEntityQuery := &Context{ + Variables: astjson.MustParse(`{}`), + ctx: context.Background(), + } + + b.Run("RootQuery/SingleField", func(b *testing.B) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "user", + }, + Args: []FieldArgument{ + { + Name: "id", + Variable: &ContextVariable{ + Path: []string{"id"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + data := astjson.MustParse(`{}`) + items := []*astjson.Value{data} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + a.Reset() + _, err := tmpl.RenderCacheKeys(a, ctxRootQuery, items, "") + if err != nil { + b.Fatal(err) + } + } + }) + + b.Run("RootQuery/MultipleFields", func(b *testing.B) { + tmpl := &RootQueryCacheKeyTemplate{ + RootFields: []QueryField{ + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "droid", + }, + Args: []FieldArgument{ + { + Name: "id", + Variable: &ContextVariable{ + Path: []string{"id"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: 
"user", + }, + Args: []FieldArgument{ + { + Name: "name", + Variable: &ContextVariable{ + Path: []string{"name"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + { + Coordinate: GraphCoordinate{ + TypeName: "Query", + FieldName: "search", + }, + Args: []FieldArgument{ + { + Name: "term", + Variable: &ContextVariable{ + Path: []string{"term"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + { + Name: "max", + Variable: &ContextVariable{ + Path: []string{"max"}, + Renderer: NewCacheKeyVariableRenderer(), + }, + }, + }, + }, + }, + } + + data := astjson.MustParse(`{}`) + items := []*astjson.Value{data} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + a.Reset() + _, err := tmpl.RenderCacheKeys(a, ctxRootQuery, items, "") + if err != nil { + b.Fatal(err) + } + } + }) + + b.Run("EntityQuery", func(b *testing.B) { + tmpl := &EntityQueryCacheKeyTemplate{ + Keys: NewResolvableObjectVariable(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("id"), + Value: &String{ + Path: []string{"id"}, + }, + }, + { + Name: []byte("sku"), + Value: &String{ + Path: []string{"sku"}, + }, + }, + { + Name: []byte("upc"), + Value: &String{ + Path: []string{"upc"}, + }, + }, + }, + }), + } + + data1 := astjson.MustParse(`{"__typename":"Product","id":"123","sku":"ABC123","upc":"DEF456","name":"Trilby"}`) + data2 := astjson.MustParse(`{"__typename":"Product","id":"456","sku":"XYZ789","upc":"GHI012","name":"Fedora"}`) + data3 := astjson.MustParse(`{"__typename":"Product","id":"789","sku":"JKL345","upc":"MNO678","name":"Boater"}`) + items := []*astjson.Value{data1, data2, data3} + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + a.Reset() + _, err := tmpl.RenderCacheKeys(a, ctxEntityQuery, items, "") + if err != nil { + b.Fatal(err) + } + } + }) +} diff --git a/v2/pkg/engine/resolve/const.go b/v2/pkg/engine/resolve/const.go index 8a259494e..2958fe1f5 100644 --- a/v2/pkg/engine/resolve/const.go +++ b/v2/pkg/engine/resolve/const.go @@ -8,6 +8,8 @@ var ( lBrack = []byte("[") rBrack = []byte("]") comma = []byte(",") + pipe = []byte("|") + dot = []byte(".") colon = []byte(":") quote = []byte("\"") null = []byte("null") diff --git a/v2/pkg/engine/resolve/context.go b/v2/pkg/engine/resolve/context.go index 65d2d6b90..5783b29a5 100644 --- a/v2/pkg/engine/resolve/context.go +++ b/v2/pkg/engine/resolve/context.go @@ -16,6 +16,7 @@ import ( type Context struct { ctx context.Context Variables *astjson.Value + VariablesHash uint64 Files []*httpclient.FileUpload Request Request RenameTypeNames []RenameTypeName @@ -32,12 +33,48 @@ type Context struct { fieldRenderer FieldValueRenderer subgraphErrors error + + SubgraphHeadersBuilder SubgraphHeadersBuilder +} + +// SubgraphHeadersBuilder allows the user of the engine to "define" the headers for a subgraph request +// Instead of going back and forth between engine & transport, +// you can simply define a function that returns headers for a Subgraph request +// In addition to just the header, the implementer can return a hash for the header which will be used by request deduplication +type SubgraphHeadersBuilder interface { + // HeadersForSubgraph must return the headers and a hash for a Subgraph Request + // The hash will be used for request deduplication + HeadersForSubgraph(subgraphName string) (http.Header, uint64) + // HashAll must return the hash for all subgraph requests combined + HashAll() uint64 +} + +// 
HeadersForSubgraphRequest returns headers and a hash for a request that the engine will make to a subgraph
+func (c *Context) HeadersForSubgraphRequest(subgraphName string) (http.Header, uint64) {
+	if c.SubgraphHeadersBuilder == nil {
+		return nil, 0
+	}
+	return c.SubgraphHeadersBuilder.HeadersForSubgraph(subgraphName)
 }
 
 type ExecutionOptions struct {
-	SkipLoader                 bool
+	// SkipLoader skips loading data while still resolving a response.
+	// This can be useful, e.g. in combination with IncludeQueryPlanInResponse,
+	// to obtain a QueryPlan (even for Subscriptions).
+	SkipLoader bool
+	// IncludeQueryPlanInResponse generates a QueryPlan as part of the response in Resolvable
 	IncludeQueryPlanInResponse bool
-	SendHeartbeat              bool
+	// SendHeartbeat sends regular heartbeats for Subscriptions
+	SendHeartbeat bool
+	// DisableSubgraphRequestDeduplication disables deduplication of requests to the same subgraph with the same input within a single operation execution.
+	DisableSubgraphRequestDeduplication bool
+	// DisableInboundRequestDeduplication disables deduplication of inbound client requests.
+	// The engine hashes the normalized operation, variables, and forwarded headers to achieve robust deduplication.
+	// The overhead is negligible, so this should remain false (deduplication enabled) most of the time.
+	// However, if you're benchmarking internals of the engine, it can be helpful to switch it off.
+	// When disabled (set to true), the deduplication code becomes a no-op.
+	DisableInboundRequestDeduplication bool
 }
 
 type FieldValue struct {
@@ -146,7 +183,7 @@ func (c *Context) appendSubgraphErrors(errs ...error) {
 }
 
 type Request struct {
-	ID     string
+	ID     uint64
 	Header http.Header
 }
diff --git a/v2/pkg/engine/resolve/datasource.go b/v2/pkg/engine/resolve/datasource.go
index c679d7693..7855fa637 100644
--- a/v2/pkg/engine/resolve/datasource.go
+++ b/v2/pkg/engine/resolve/datasource.go
@@ -1,28 +1,24 @@
 package resolve
 
 import (
-	"bytes"
 	"context"
-
-	"github.com/cespare/xxhash/v2"
+	"net/http"
 
 	"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
 )
 
 type DataSource interface {
-	Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error)
-	LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error)
+	Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error)
+	LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error)
 }
 
 type SubscriptionDataSource interface {
 	// Start is called when a new subscription is created. It establishes the connection to the data source.
 	// The updater is used to send updates to the client. Deduplication of the request must be done before calling this method.
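+	// The headers argument carries the per-subgraph request headers; presumably these
+	// are the headers produced via Context.HeadersForSubgraphRequest and the
+	// SubgraphHeadersBuilder.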
- Start(ctx *Context, input []byte, updater SubscriptionUpdater) error - UniqueRequestID(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) + Start(ctx *Context, headers http.Header, input []byte, updater SubscriptionUpdater) error } type AsyncSubscriptionDataSource interface { - AsyncStart(ctx *Context, id uint64, input []byte, updater SubscriptionUpdater) error + AsyncStart(ctx *Context, id uint64, headers http.Header, input []byte, updater SubscriptionUpdater) error AsyncStop(id uint64) - UniqueRequestID(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) } diff --git a/v2/pkg/engine/resolve/event_loop_test.go b/v2/pkg/engine/resolve/event_loop_test.go index 11389630a..ba8b7c8e2 100644 --- a/v2/pkg/engine/resolve/event_loop_test.go +++ b/v2/pkg/engine/resolve/event_loop_test.go @@ -3,12 +3,12 @@ package resolve import ( "context" "io" + "net/http" "sync" "sync/atomic" "testing" "time" - "github.com/cespare/xxhash/v2" "github.com/stretchr/testify/require" ) @@ -71,12 +71,7 @@ type FakeSource struct { interval time.Duration } -func (f *FakeSource) UniqueRequestID(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) { - _, err = xxh.Write(input) - return err -} - -func (f *FakeSource) Start(ctx *Context, input []byte, updater SubscriptionUpdater) error { +func (f *FakeSource) Start(ctx *Context, headers http.Header, input []byte, updater SubscriptionUpdater) error { go func() { for i, u := range f.updates { updater.Update([]byte(u)) diff --git a/v2/pkg/engine/resolve/fetch.go b/v2/pkg/engine/resolve/fetch.go index deeea25a4..c6792ae68 100644 --- a/v2/pkg/engine/resolve/fetch.go +++ b/v2/pkg/engine/resolve/fetch.go @@ -4,6 +4,7 @@ import ( "encoding/json" "slices" "strings" + "time" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" ) @@ -12,7 +13,6 @@ type FetchKind int const ( FetchKindSingle FetchKind = iota + 1 - FetchKindParallelListItem FetchKindEntity FetchKindEntityBatch ) @@ -160,12 +160,14 @@ func (*SingleFetch) FetchKind() FetchKind { type BatchEntityFetch struct { FetchDependencies - Input BatchInput - DataSource DataSource - PostProcessing PostProcessingConfiguration - DataSourceIdentifier []byte - Trace *DataSourceLoadTrace - Info *FetchInfo + Input BatchInput + DataSource DataSource + PostProcessing PostProcessingConfiguration + DataSourceIdentifier []byte + Trace *DataSourceLoadTrace + Info *FetchInfo + CoordinateDependencies []FetchDependency + Caching FetchCacheConfiguration } func (b *BatchEntityFetch) Dependencies() *FetchDependencies { @@ -200,12 +202,14 @@ func (*BatchEntityFetch) FetchKind() FetchKind { type EntityFetch struct { FetchDependencies - Input EntityInput - DataSource DataSource - PostProcessing PostProcessingConfiguration - DataSourceIdentifier []byte - Trace *DataSourceLoadTrace - Info *FetchInfo + CoordinateDependencies []FetchDependency + Input EntityInput + DataSource DataSource + PostProcessing PostProcessingConfiguration + DataSourceIdentifier []byte + Trace *DataSourceLoadTrace + Info *FetchInfo + Caching FetchCacheConfiguration } func (e *EntityFetch) Dependencies() *FetchDependencies { @@ -227,27 +231,6 @@ func (*EntityFetch) FetchKind() FetchKind { return FetchKindEntity } -// The ParallelListItemFetch can be used to make nested parallel fetches within a list -// Usually, you want to batch fetches within a list, which is the default behavior of SingleFetch -// However, if the data source does not support batching, you can use this fetch to make parallel fetches within a list -type ParallelListItemFetch struct { - Fetch 
*SingleFetch - Traces []*SingleFetch - Trace *DataSourceLoadTrace -} - -func (p *ParallelListItemFetch) Dependencies() *FetchDependencies { - return &p.Fetch.FetchDependencies -} - -func (p *ParallelListItemFetch) FetchInfo() *FetchInfo { - return p.Fetch.Info -} - -func (*ParallelListItemFetch) FetchKind() FetchKind { - return FetchKindParallelListItem -} - type QueryPlan struct { DependsOnFields []Representation Query string @@ -272,12 +255,6 @@ type FetchConfiguration struct { Variables Variables DataSource DataSource - // RequiresParallelListItemFetch indicates that the single fetches should be executed without batching. - // If we have multiple fetches attached to the object, then after post-processing of a plan - // we will get ParallelListItemFetch instead of ParallelFetch. - // Happens only for objects under the array path and used only for the introspection. - RequiresParallelListItemFetch bool - // RequiresEntityFetch will be set to true if the fetch is an entity fetch on an object. // After post-processing, we will get EntityFetch. RequiresEntityFetch bool @@ -297,8 +274,16 @@ type FetchConfiguration struct { QueryPlan *QueryPlan + // CoordinateDependencies contain a list of GraphCoordinates (typeName+fieldName) + // and which fields from other fetches they depend on. + // This information is useful to understand why a fetch depends on other fetches, + // and how multiple dependencies lead to a chain of fetches + CoordinateDependencies []FetchDependency + // OperationName is non-empty when the operation name is propagated to the upstream subgraph fetch. OperationName string + + Caching FetchCacheConfiguration } func (fc *FetchConfiguration) Equals(other *FetchConfiguration) bool { @@ -313,9 +298,6 @@ func (fc *FetchConfiguration) Equals(other *FetchConfiguration) bool { // Note: we do not compare datasources, as they will always be a different instance. - if fc.RequiresParallelListItemFetch != other.RequiresParallelListItemFetch { - return false - } if fc.RequiresEntityFetch != other.RequiresEntityFetch { return false } @@ -332,6 +314,23 @@ func (fc *FetchConfiguration) Equals(other *FetchConfiguration) bool { return true } +type FetchCacheConfiguration struct { + // Enabled indicates if caching is enabled for this fetch + Enabled bool + // CacheName is the name of the cache to use for this fetch + CacheName string + // TTL is the time to live which will be set for new cache entries + TTL time.Duration + // CacheKeyTemplate can be used to render a cache key for the fetch. + // In case of a root fetch, the variables will be one or more field arguments + // For entity fetches, the variables will be a single Object Variable with @key and @requires fields + CacheKeyTemplate CacheKeyTemplate + // IncludeSubgraphHeaderPrefix indicates if cache keys should be prefixed with the subgraph header hash. + // The prefix format is "id:cacheKey" where id is the hash from HeadersForSubgraph. + // Defaults to true. + IncludeSubgraphHeaderPrefix bool +} + // FetchDependency explains how a GraphCoordinate depends on other GraphCoordinates from other fetches type FetchDependency struct { // Coordinate is the type+field which depends on one or more FetchDependencyOrigin @@ -394,6 +393,7 @@ type FetchInfo struct { // with the request to the subgraph as part of the "fetch_reason" extension. // Specifically, it is created only for fields stored in the DataSource.RequireFetchReasons(). 
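// Editor's note: an illustrative FetchCacheConfiguration, not part of this change.
// The cache name and TTL are arbitrary; CacheKeyTemplate is left out because its
// concrete construction is up to the planner (for entity fetches it renders a single
// object variable carrying the @key and @requires fields, per the doc comment above).
//
//	caching := FetchCacheConfiguration{
//		Enabled:                     true,
//		CacheName:                   "default", // must match a cache registered with the Loader
//		TTL:                         30 * time.Second,
//		IncludeSubgraphHeaderPrefix: true, // keys are stored as "id:cacheKey"
//	}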
PropagatedFetchReasons []FetchReason + ProvidesData *Object } type GraphCoordinate struct { @@ -505,5 +505,4 @@ var ( _ Fetch = (*SingleFetch)(nil) _ Fetch = (*BatchEntityFetch)(nil) _ Fetch = (*EntityFetch)(nil) - _ Fetch = (*ParallelListItemFetch)(nil) ) diff --git a/v2/pkg/engine/resolve/fetchtree.go b/v2/pkg/engine/resolve/fetchtree.go index f4fd987ce..9bc38497c 100644 --- a/v2/pkg/engine/resolve/fetchtree.go +++ b/v2/pkg/engine/resolve/fetchtree.go @@ -130,17 +130,6 @@ func (n *FetchTreeNode) Trace() *FetchTreeTraceNode { Trace: f.Trace, Path: n.Item.ResponsePath, } - case *ParallelListItemFetch: - trace.Fetch = &FetchTraceNode{ - Kind: "ParallelList", - SourceID: f.Fetch.Info.DataSourceID, - SourceName: f.Fetch.Info.DataSourceName, - Traces: make([]*DataSourceLoadTrace, len(f.Traces)), - Path: n.Item.ResponsePath, - } - for i, t := range f.Traces { - trace.Fetch.Traces[i] = t.Trace - } default: } case FetchTreeNodeKindSequence, FetchTreeNodeKindParallel: @@ -253,20 +242,6 @@ func (n *FetchTreeNode) queryPlan() *FetchTreeQueryPlanNode { queryPlan.Fetch.Query = f.Info.QueryPlan.Query queryPlan.Fetch.Representations = f.Info.QueryPlan.DependsOnFields } - case *ParallelListItemFetch: - queryPlan.Fetch = &FetchTreeQueryPlan{ - Kind: "ParallelList", - FetchID: f.Fetch.FetchDependencies.FetchID, - DependsOnFetchIDs: f.Fetch.FetchDependencies.DependsOnFetchIDs, - SubgraphName: f.Fetch.Info.DataSourceName, - SubgraphID: f.Fetch.Info.DataSourceID, - Path: n.Item.ResponsePath, - } - - if f.Fetch.Info.QueryPlan != nil { - queryPlan.Fetch.Query = f.Fetch.Info.QueryPlan.Query - queryPlan.Fetch.Representations = f.Fetch.Info.QueryPlan.DependsOnFields - } default: } case FetchTreeNodeKindSequence, FetchTreeNodeKindParallel: diff --git a/v2/pkg/engine/resolve/inbound_request_singleflight.go b/v2/pkg/engine/resolve/inbound_request_singleflight.go new file mode 100644 index 000000000..66505a36a --- /dev/null +++ b/v2/pkg/engine/resolve/inbound_request_singleflight.go @@ -0,0 +1,138 @@ +package resolve + +import ( + "encoding/binary" + "sync" + + "github.com/cespare/xxhash/v2" +) + +// InboundRequestSingleFlight is a sharded, goroutine-safe single-flight implementation to de-duplicate inbound requests +// It takes the normalized operation hash, variables hash, and headers hash into account, +// making it robust against collisions +// For scalability, you can add more shards in case the mutexes are a bottleneck +type InboundRequestSingleFlight struct { + shards []requestShard +} + +type requestShard struct { + mu sync.Mutex + m map[uint64]*InflightRequest +} + +const defaultRequestSingleFlightShardCount = 4 + +// NewRequestSingleFlight creates an InboundRequestSingleFlight with the provided +// number of shards. If shardCount <= 0, the default of 4 is used.
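// Editor's note: a sketch of the intended leader/follower call pattern, not part of
// this change; it is inferred from the API below, with resolveResponse as a
// hypothetical stand-in for actual response resolution. The first caller for a key
// (the leader) resolves and publishes; followers block in GetOrCreate and reuse the
// leader's bytes.
//
//	req, err := sf.GetOrCreate(ctx, response)
//	if err != nil {
//		return err // the leader failed, or ctx was cancelled while waiting
//	}
//	if req == nil {
//		// deduplication disabled or not allowed for this response; resolve normally
//	} else if req.Data != nil {
//		// follower: reuse the response body produced by the leader
//	} else {
//		// leader: resolve, then publish the result to any followers
//		data, err := resolveResponse(ctx)
//		if err != nil {
//			sf.FinishErr(req, err)
//			return err
//		}
//		sf.FinishOk(req, data)
//	}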
+func NewRequestSingleFlight(shardCount int) *InboundRequestSingleFlight { + if shardCount <= 0 { + shardCount = defaultRequestSingleFlightShardCount + } + r := &InboundRequestSingleFlight{ + shards: make([]requestShard, shardCount), + } + for i := range r.shards { + r.shards[i] = requestShard{ + m: make(map[uint64]*InflightRequest), + } + } + return r +} + +type InflightRequest struct { + Done chan struct{} + Data []byte + Err error + ID uint64 + HasFollowers bool +} + +// GetOrCreate creates a new InflightRequest or returns an existing (shared) one +// The first caller to create an InflightRequest for a given key is a leader, everyone else a follower +// GetOrCreate blocks until ctx.ctx.Done() returns or InflightRequest.Done is closed +// It returns an error if the leader returned an error +// It returns nil,nil if the inbound request is not eligible for request deduplication +// or if DisableInboundRequestDeduplication is set to true on Context +func (r *InboundRequestSingleFlight) GetOrCreate(ctx *Context, response *GraphQLResponse) (*InflightRequest, error) { + + if ctx.ExecutionOptions.DisableInboundRequestDeduplication { + return nil, nil + } + + if !response.SingleFlightAllowed() { + return nil, nil + } + + // Derive a robust key from request ID, variables hash and (optional) headers hash + var b [24]byte + binary.LittleEndian.PutUint64(b[0:8], ctx.Request.ID) + binary.LittleEndian.PutUint64(b[8:16], ctx.VariablesHash) + hh := uint64(0) + if ctx.SubgraphHeadersBuilder != nil { + hh = ctx.SubgraphHeadersBuilder.HashAll() + } + binary.LittleEndian.PutUint64(b[16:24], hh) + key := xxhash.Sum64(b[:]) + + shard := r.shardFor(key) + shard.mu.Lock() + req, shared := shard.m[key] + if shared { + req.HasFollowers = true + shard.mu.Unlock() + select { + case <-req.Done: + if req.Err != nil { + return nil, req.Err + } + return req, nil + case <-ctx.ctx.Done(): + return nil, ctx.ctx.Err() + } + } + + req = &InflightRequest{ + Done: make(chan struct{}), + ID: key, + } + + shard.m[key] = req + shard.mu.Unlock() + return req, nil +} + +func (r *InboundRequestSingleFlight) FinishOk(req *InflightRequest, data []byte) { + if req == nil { + return + } + shard := r.shardFor(req.ID) + shard.mu.Lock() + delete(shard.m, req.ID) + hasFollowers := req.HasFollowers + shard.mu.Unlock() + if hasFollowers { + // optimization to only copy when we actually have to + req.Data = make([]byte, len(data)) + copy(req.Data, data) + } + close(req.Done) +} + +func (r *InboundRequestSingleFlight) FinishErr(req *InflightRequest, err error) { + if req == nil { + return + } + shard := r.shardFor(req.ID) + shard.mu.Lock() + delete(shard.m, req.ID) + shard.mu.Unlock() + req.Err = err + close(req.Done) +} + +func (r *InboundRequestSingleFlight) shardFor(key uint64) *requestShard { + // Fast modulo using power-of-two shard count if desired in the future. + // For now, use standard modulo for clarity. 
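// Editor's note (illustration only): with a power-of-two shard count, the modulo
// below could be replaced by a mask, since key % n == key & (n-1) when n is a
// power of two:
//
//	idx := int(key & uint64(len(r.shards)-1))
//
// NewRequestSingleFlight would then have to round shardCount up to the next power
// of two for this to be correct.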
+ idx := int(key % uint64(len(r.shards))) + return &r.shards[idx] +} diff --git a/v2/pkg/engine/resolve/inputtemplate.go b/v2/pkg/engine/resolve/inputtemplate.go index 82825cac7..e0fc97aa6 100644 --- a/v2/pkg/engine/resolve/inputtemplate.go +++ b/v2/pkg/engine/resolve/inputtemplate.go @@ -1,10 +1,10 @@ package resolve import ( - "bytes" "context" "errors" "fmt" + "io" "github.com/wundergraph/astjson" @@ -36,7 +36,7 @@ type InputTemplate struct { SetTemplateOutputToNullOnVariableNull bool } -func SetInputUndefinedVariables(preparedInput *bytes.Buffer, undefinedVariables []string) error { +func SetInputUndefinedVariables(preparedInput InputTemplateWriter, undefinedVariables []string) error { if len(undefinedVariables) > 0 { output, err := httpclient.SetUndefinedVariables(preparedInput.Bytes(), undefinedVariables) if err != nil { @@ -55,7 +55,16 @@ func SetInputUndefinedVariables(preparedInput *bytes.Buffer, undefinedVariables // to callers; renderSegments intercepts it and writes literal.NULL instead. var errSetTemplateOutputNull = errors.New("set to null") -func (i *InputTemplate) Render(ctx *Context, data *astjson.Value, preparedInput *bytes.Buffer) error { +// InputTemplateWriter is used to decouple Buffer implementations from InputTemplate +// This way, the implementation can easily be swapped, e.g. between bytes.Buffer and similar implementations +type InputTemplateWriter interface { + io.Writer + io.StringWriter + Reset() + Bytes() []byte +} + +func (i *InputTemplate) Render(ctx *Context, data *astjson.Value, preparedInput InputTemplateWriter) error { var undefinedVariables []string if err := i.renderSegments(ctx, data, i.Segments, preparedInput, &undefinedVariables); err != nil { @@ -65,12 +74,12 @@ func (i *InputTemplate) Render(ctx *Context, data *astjson.Value, preparedInput return SetInputUndefinedVariables(preparedInput, undefinedVariables) } -func (i *InputTemplate) RenderAndCollectUndefinedVariables(ctx *Context, data *astjson.Value, preparedInput *bytes.Buffer, undefinedVariables *[]string) (err error) { +func (i *InputTemplate) RenderAndCollectUndefinedVariables(ctx *Context, data *astjson.Value, preparedInput InputTemplateWriter, undefinedVariables *[]string) (err error) { err = i.renderSegments(ctx, data, i.Segments, preparedInput, undefinedVariables) return } -func (i *InputTemplate) renderSegments(ctx *Context, data *astjson.Value, segments []TemplateSegment, preparedInput *bytes.Buffer, undefinedVariables *[]string) (err error) { +func (i *InputTemplate) renderSegments(ctx *Context, data *astjson.Value, segments []TemplateSegment, preparedInput InputTemplateWriter, undefinedVariables *[]string) (err error) { for _, segment := range segments { switch segment.SegmentType { case StaticSegmentType: @@ -107,7 +116,7 @@ func (i *InputTemplate) renderSegments(ctx *Context, data *astjson.Value, segmen return err } -func (i *InputTemplate) renderObjectVariable(ctx context.Context, variables *astjson.Value, segment TemplateSegment, preparedInput *bytes.Buffer) error { +func (i *InputTemplate) renderObjectVariable(ctx context.Context, variables *astjson.Value, segment TemplateSegment, preparedInput InputTemplateWriter) error { value := variables.Get(segment.VariableSourcePath...) 
if value == nil || value.Type() == astjson.TypeNull { if i.SetTemplateOutputToNullOnVariableNull { @@ -119,11 +128,11 @@ func (i *InputTemplate) renderObjectVariable(ctx context.Context, variables *ast return segment.Renderer.RenderVariable(ctx, value, preparedInput) } -func (i *InputTemplate) renderResolvableObjectVariable(ctx context.Context, objectData *astjson.Value, segment TemplateSegment, preparedInput *bytes.Buffer) error { +func (i *InputTemplate) renderResolvableObjectVariable(ctx context.Context, objectData *astjson.Value, segment TemplateSegment, preparedInput InputTemplateWriter) error { return segment.Renderer.RenderVariable(ctx, objectData, preparedInput) } -func (i *InputTemplate) renderContextVariable(ctx *Context, segment TemplateSegment, preparedInput *bytes.Buffer) (variableWasUndefined bool, err error) { +func (i *InputTemplate) renderContextVariable(ctx *Context, segment TemplateSegment, preparedInput InputTemplateWriter) (variableWasUndefined bool, err error) { variableSourcePath := segment.VariableSourcePath if len(variableSourcePath) == 1 && ctx.RemapVariables != nil { nameToUse, hasMapping := ctx.RemapVariables[variableSourcePath[0]] @@ -142,7 +151,7 @@ func (i *InputTemplate) renderContextVariable(ctx *Context, segment TemplateSegm return false, segment.Renderer.RenderVariable(ctx.Context(), value, preparedInput) } -func (i *InputTemplate) renderHeaderVariable(ctx *Context, path []string, preparedInput *bytes.Buffer) error { +func (i *InputTemplate) renderHeaderVariable(ctx *Context, path []string, preparedInput InputTemplateWriter) error { if len(path) != 1 { return errHeaderPathInvalid } @@ -151,14 +160,20 @@ func (i *InputTemplate) renderHeaderVariable(ctx *Context, path []string, prepar return nil } if len(value) == 1 { - preparedInput.WriteString(value[0]) + if _, err := preparedInput.WriteString(value[0]); err != nil { + return err + } return nil } for j := range value { if j != 0 { - _, _ = preparedInput.Write(literal.COMMA) + if _, err := preparedInput.Write(literal.COMMA); err != nil { + return err + } + } + if _, err := preparedInput.WriteString(value[j]); err != nil { + return err } - preparedInput.WriteString(value[j]) } return nil } diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 7a14d61dc..0d23b2e6e 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -22,6 +22,7 @@ import ( "golang.org/x/sync/errgroup" "github.com/wundergraph/astjson" + "github.com/wundergraph/go-arena" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" @@ -56,11 +57,11 @@ type ResponseInfo struct { // ResponseHeaders contains a clone of the headers of the response from the subgraph. ResponseHeaders http.Header // This should be private as we do not want user's to access the raw responseBody directly - responseBody *bytes.Buffer + responseBody []byte } -func (ri *ResponseInfo) GetResponseBody() string { - return ri.responseBody.String() +func (r *ResponseInfo) GetResponseBody() string { + return string(r.responseBody) } func newResponseInfo(res *result, subgraphError error) *ResponseInfo { @@ -91,35 +92,21 @@ func newResponseInfo(res *result, subgraphError error) *ResponseInfo { return responseInfo } -// batchStats represents an index map for batched items. -// It is used to ensure that the correct json values will be merged with the correct items from the batch. 
-// -// Example: -// [[0],[1],[0],[1]] We originally have 4 items, but we have 2 unique indexes (0 and 1). -// This means we are deduplicating 2 items by merging them from their response entity indexes. -// 0 -> 0, 1 -> 1, 2 -> 0, 3 -> 1 -type batchStats [][]int - -// getUniqueIndexes returns the number of unique indexes in the batchStats. -// This is used to ensure that we can provide a valid error message in case of differing array lengths. -func (b *batchStats) getUniqueIndexes() int { - uniqueIndexes := make(map[int]struct{}) - for _, bi := range *b { - for _, index := range bi { - if index < 0 { - continue - } - uniqueIndexes[index] = struct{}{} - } - } - - return len(uniqueIndexes) -} - type result struct { - postProcessing PostProcessingConfiguration - out *bytes.Buffer - batchStats batchStats + postProcessing PostProcessingConfiguration + // batchStats represents per-unique-batch-item merge targets. + // Outer slice index corresponds to the unique representation index in the request batch, + // and the inner slice contains all target values that should be merged with the response at that index. + // + // Example: + // For 4 original items that deduplicate to 2 unique representations, we might have: + // [ + // + // [item0, item2], // merge response[0] into item0 and item2 + // [item1, item3], // merge response[1] into item1 and item3 + // + // ] + batchStats [][]*astjson.Value fetchSkipped bool nestedMergeItems []*result @@ -138,16 +125,30 @@ type result struct { loaderHookContext context.Context httpResponseContext *httpclient.ResponseContext + // out is the subgraph response body + out []byte + singleFlightStats *singleFlightStats + tools *batchEntityTools + + cache LoaderCache + cacheMustBeUpdated bool + cacheKeys []*CacheKey + cacheSkipFetch bool + cacheConfig FetchCacheConfiguration } -func (r *result) init(postProcessing PostProcessingConfiguration, info *FetchInfo) { - r.postProcessing = postProcessing +func (l *Loader) createOrInitResult(res *result, postProcessing PostProcessingConfiguration, info *FetchInfo) *result { + if res == nil { + res = &result{} + } + res.postProcessing = postProcessing if info != nil { - r.ds = DataSourceInfo{ + res.ds = DataSourceInfo{ ID: info.DataSourceID, Name: info.DataSourceName, } } + return res } func IsIntrospectionDataSource(dataSourceID string) bool { @@ -159,6 +160,8 @@ type Loader struct { ctx *Context info *GraphQLResponseInfo + caches map[string]LoaderCache + propagateSubgraphErrors bool propagateSubgraphStatusCodes bool subgraphErrorPropagationMode SubgraphErrorPropagationMode @@ -180,6 +183,19 @@ type Loader struct { validateRequiredExternalFields bool taintedObjs taintedObjects + + // jsonArena is the arena used to allocate json, supplied by the Resolver + // Disclaimer: this arena is NOT thread safe! + // Only use it from the main goroutine + // Don't Reset or Release, the Resolver handles this + // Disclaimer: When parsing json into the arena, the underlying bytes must also be allocated on the arena!
+ // This is very important to "tie" their lifecycles together + // If you're not doing this, you will see segfaults + // Example of correct usage in func "mergeResult" + jsonArena arena.Arena + // sf is the SubgraphRequestSingleFlight object shared across all client requests + // it's thread safe and can be used to de-duplicate subgraph requests + sf *SubgraphRequestSingleFlight } func (l *Loader) Free() { @@ -218,6 +234,12 @@ func (l *Loader) resolveParallel(nodes []*FetchTreeNode) error { return nil } results := make([]*result, len(nodes)) + defer func() { + for i := range results { + // no-op if tools == nil + batchEntityToolPool.Put(results[i].tools) + } + }() itemsItems := make([][]*astjson.Value, len(nodes)) g, ctx := errgroup.WithContext(l.ctx.ctx) for i := range nodes { @@ -280,89 +302,66 @@ func (l *Loader) resolveSingle(item *FetchItem) error { switch f := item.Fetch.(type) { case *SingleFetch: - res := &result{ - out: &bytes.Buffer{}, - } - err := l.loadSingleFetch(l.ctx.ctx, f, item, items, res) + res := l.createOrInitResult(nil, f.PostProcessing, f.Info) + skip, err := l.tryCacheLoadFetch(l.ctx.ctx, f.Info, f.Caching, items, res) if err != nil { - return err + return errors.WithStack(err) + } + if !skip { + err = l.loadSingleFetch(l.ctx.ctx, f, item, items, res) + if err != nil { + return err + } } err = l.mergeResult(item, res, items) if l.ctx.LoaderHooks != nil { l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.ds, newResponseInfo(res, l.ctx.subgraphErrors)) } - return err case *BatchEntityFetch: - res := &result{ - out: &bytes.Buffer{}, - } - err := l.loadBatchEntityFetch(l.ctx.ctx, item, f, items, res) + res := l.createOrInitResult(nil, f.PostProcessing, f.Info) + defer batchEntityToolPool.Put(res.tools) + skip, err := l.tryCacheLoadFetch(l.ctx.ctx, f.Info, f.Caching, items, res) if err != nil { return errors.WithStack(err) } + if !skip { + err = l.loadBatchEntityFetch(l.ctx.ctx, item, f, items, res) + if err != nil { + return errors.WithStack(err) + } + } err = l.mergeResult(item, res, items) if l.ctx.LoaderHooks != nil { l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.ds, newResponseInfo(res, l.ctx.subgraphErrors)) } return err case *EntityFetch: - res := &result{ - out: &bytes.Buffer{}, - } - err := l.loadEntityFetch(l.ctx.ctx, item, f, items, res) + res := l.createOrInitResult(nil, f.PostProcessing, f.Info) + skip, err := l.tryCacheLoadFetch(l.ctx.ctx, f.Info, f.Caching, items, res) if err != nil { return errors.WithStack(err) } + if !skip { + err = l.loadEntityFetch(l.ctx.ctx, item, f, items, res) + if err != nil { + return errors.WithStack(err) + } + } err = l.mergeResult(item, res, items) if l.ctx.LoaderHooks != nil { l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.ds, newResponseInfo(res, l.ctx.subgraphErrors)) } return err - case *ParallelListItemFetch: - results := make([]*result, len(items)) - if l.ctx.TracingOptions.Enable { - f.Traces = make([]*SingleFetch, len(items)) - } - g, ctx := errgroup.WithContext(l.ctx.ctx) - for i := range items { - i := i - results[i] = &result{ - out: &bytes.Buffer{}, - } - if l.ctx.TracingOptions.Enable { - f.Traces[i] = new(SingleFetch) - *f.Traces[i] = *f.Fetch - g.Go(func() error { - return l.loadFetch(ctx, f.Traces[i], item, items[i:i+1], results[i]) - }) - continue - } - g.Go(func() error { - return l.loadFetch(ctx, f.Fetch, item, items[i:i+1], results[i]) - }) - } - err := g.Wait() - if err != nil { - return errors.WithStack(err) - } - for i := range results { - err = l.mergeResult(item, results[i], 
items[i:i+1]) - if l.ctx.LoaderHooks != nil { - l.ctx.LoaderHooks.OnFinished(results[i].loaderHookContext, results[i].ds, newResponseInfo(results[i], l.ctx.subgraphErrors)) - } - if err != nil { - return errors.WithStack(err) - } - } - return nil default: return nil } } func (l *Loader) selectItemsForPath(path []FetchItemPathElement) []*astjson.Value { - items := []*astjson.Value{l.resolvable.data} + // Use arena allocation for the initial items slice + items := arena.AllocateSlice[*astjson.Value](l.jsonArena, 1, 1) + items[0] = l.resolvable.data if len(path) == 0 { return l.taintedObjs.filterOutTainted(items) } @@ -370,7 +369,7 @@ func (l *Loader) selectItemsForPath(path []FetchItemPathElement) []*astjson.Valu if len(items) == 0 { break } - items = selectItems(items, path[i]) + items = selectItems(l.jsonArena, items, path[i]) } return l.taintedObjs.filterOutTainted(items) } @@ -391,7 +390,7 @@ func isItemAllowedByTypename(obj *astjson.Value, typeNames []string) bool { return slices.Contains(typeNames, __typeNameStr) } -func selectItems(items []*astjson.Value, element FetchItemPathElement) []*astjson.Value { +func selectItems(a arena.Arena, items []*astjson.Value, element FetchItemPathElement) []*astjson.Value { if len(items) == 0 { return nil } @@ -413,7 +412,7 @@ } return []*astjson.Value{field} } - selected := make([]*astjson.Value, 0, len(items)) + selected := arena.AllocateSlice[*astjson.Value](a, 0, len(items)) for _, item := range items { if !isItemAllowedByTypename(item, element.TypeNames) { continue } @@ -423,15 +422,15 @@ } if field.Type() == astjson.TypeArray { - selected = append(selected, field.GetArray()...) + selected = arena.SliceAppend(a, selected, field.GetArray()...) continue } - selected = append(selected, field) + selected = arena.SliceAppend(a, selected, field) } return selected } -func itemsData(items []*astjson.Value) *astjson.Value { +func (l *Loader) itemsData(items []*astjson.Value) *astjson.Value { if len(items) == 0 { return astjson.NullValue } @@ -442,50 +441,169 @@ // however, itemsData can be called concurrently, so this might result in a race arr := astjson.MustParseBytes([]byte(`[]`)) for i, item := range items { - arr.SetArrayItem(i, item) + arr.SetArrayItem(nil, i, item) } return arr } -func (l *Loader) loadFetch(ctx context.Context, fetch Fetch, fetchItem *FetchItem, items []*astjson.Value, res *result) error { - switch f := fetch.(type) { - case *SingleFetch: - res.out = &bytes.Buffer{} - return l.loadSingleFetch(ctx, f, fetchItem, items, res) - case *ParallelListItemFetch: - results := make([]*result, len(items)) - if l.ctx.TracingOptions.Enable { - f.Traces = make([]*SingleFetch, len(items)) +type CacheEntry struct { + Key string + Value []byte +} + +type LoaderCache interface { + Get(ctx context.Context, keys []string) ([]*CacheEntry, error) + Set(ctx context.Context, entries []*CacheEntry, ttl time.Duration) error + Delete(ctx context.Context, keys []string) error +} + +// extractCacheKeysStrings extracts all unique cache key strings from CacheKeys +// The subgraph header hash prefix, if enabled, has already been applied when the keys were rendered.
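// Editor's note: a minimal in-memory LoaderCache sketch, not part of this change,
// to illustrate the contract as the loader uses it: Get returns one entry per
// requested key (a nil entry or nil Value signals a miss; entries are matched back
// to keys by the Key field), and Set should copy values, because the engine may
// allocate entry values on a per-request arena that is reset afterwards.
// The type memoryLoaderCache is a hypothetical name; it ignores the TTL.

type memoryLoaderCache struct {
	mu sync.Mutex
	m  map[string][]byte
}

func newMemoryLoaderCache() *memoryLoaderCache {
	return &memoryLoaderCache{m: map[string][]byte{}}
}

func (c *memoryLoaderCache) Get(ctx context.Context, keys []string) ([]*CacheEntry, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	out := make([]*CacheEntry, len(keys))
	for i, k := range keys {
		if v, ok := c.m[k]; ok {
			out[i] = &CacheEntry{Key: k, Value: v}
		}
	}
	return out, nil
}

func (c *memoryLoaderCache) Set(ctx context.Context, entries []*CacheEntry, ttl time.Duration) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	for _, e := range entries {
		v := make([]byte, len(e.Value))
		copy(v, e.Value) // copy the value off the caller's arena
		c.m[e.Key] = v
	}
	return nil
}

func (c *memoryLoaderCache) Delete(ctx context.Context, keys []string) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	for _, k := range keys {
		delete(c.m, k)
	}
	return nil
}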
+func (l *Loader) extractCacheKeysStrings(a arena.Arena, cacheKeys []*CacheKey) []string { + if len(cacheKeys) == 0 { + return nil + } + out := arena.AllocateSlice[string](a, 0, len(cacheKeys)) + for i := range cacheKeys { + for j := range cacheKeys[i].Keys { + keyLen := len(cacheKeys[i].Keys[j]) + key := arena.AllocateSlice[byte](a, 0, keyLen) + key = arena.SliceAppend(a, key, unsafebytes.StringToBytes(cacheKeys[i].Keys[j])...) + out = arena.SliceAppend(a, out, unsafebytes.BytesToString(key)) + } + } + return out +} + +// populateFromCache populates CacheKey.FromCache fields from cache entries +// Entries are matched to keys by exact key string; nil entries and nil values are skipped. +func (l *Loader) populateFromCache(a arena.Arena, cacheKeys []*CacheKey, entries []*CacheEntry) (err error) { + for i := range entries { + if entries[i] == nil || entries[i].Value == nil { + continue } + for j := range cacheKeys { + for k := range cacheKeys[j].Keys { + if cacheKeys[j].Keys[k] == entries[i].Key { + cacheKeys[j].FromCache, err = astjson.ParseBytesWithArena(a, entries[i].Value) + if err != nil { + return errors.WithStack(err) + } + } } + } + } + return nil +} + +// cacheKeysToEntries converts CacheKeys to CacheEntries for storage +// For each CacheKey, creates entries for all its KeyEntries with the same value +// The subgraph header hash prefix, if enabled, is already part of the key strings, applied when the keys were rendered. +func (l *Loader) cacheKeysToEntries(a arena.Arena, cacheKeys []*CacheKey) ([]*CacheEntry, error) { + out := arena.AllocateSlice[*CacheEntry](a, 0, len(cacheKeys)) + buf := arena.AllocateSlice[byte](a, 64, 64) + for i := range cacheKeys { + for j := range cacheKeys[i].Keys { + if cacheKeys[i].Item == nil { continue } + buf = cacheKeys[i].Item.MarshalTo(buf[:0]) + entry := &CacheEntry{ + Key: cacheKeys[i].Keys[j], + Value: arena.AllocateSlice[byte](a, len(buf), len(buf)), + } + copy(entry.Value, buf) + out = arena.SliceAppend(a, out, entry) } + } + return out, nil +} + +func (l *Loader) tryCacheLoadFetch(ctx context.Context, info *FetchInfo, cfg FetchCacheConfiguration, inputItems []*astjson.Value, res *result) (skipFetch bool, err error) { + if !cfg.Enabled { + return false, nil + } + if cfg.CacheKeyTemplate == nil { + return false, nil + } + if l.caches == nil { + return false, nil + } + res.cacheConfig = cfg + res.cache = l.caches[cfg.CacheName] + if res.cache == nil { + return false, nil + } + var prefix string + if cfg.IncludeSubgraphHeaderPrefix && l.ctx.SubgraphHeadersBuilder != nil { + _, headersHash := l.ctx.SubgraphHeadersBuilder.HeadersForSubgraph(info.DataSourceName) + var buf [20]byte + b := strconv.AppendUint(buf[:0], headersHash, 10) + prefix = string(b) + } + // Generate cache keys for all items at once + res.cacheKeys, err = cfg.CacheKeyTemplate.RenderCacheKeys(nil, l.ctx, inputItems, prefix) + if err != nil { + return false, err + } + if len(res.cacheKeys) == 0 { + // If no cache keys were generated, we skip the cache + return false, nil + } + cacheKeyStrings := l.extractCacheKeysStrings(nil, res.cacheKeys) + if
len(cacheKeyStrings) == 0 { + return false, nil + } + // Get cache entries + cacheEntries, err := res.cache.Get(ctx, cacheKeyStrings) + if err != nil { + return false, err + } + // Populate FromCache fields in CacheKeys + err = l.populateFromCache(nil, res.cacheKeys, cacheEntries) + if err != nil { + return false, err + } + canSkip := l.canSkipFetch(info, res) + if canSkip { + res.cacheSkipFetch = true + return true, nil + } + res.cacheMustBeUpdated = true + return false, nil +} + +func (l *Loader) loadFetch(ctx context.Context, fetch Fetch, fetchItem *FetchItem, items []*astjson.Value, res *result) error { + switch f := fetch.(type) { + case *SingleFetch: + res = l.createOrInitResult(res, f.PostProcessing, f.Info) + skip, err := l.tryCacheLoadFetch(ctx, f.Info, f.Caching, items, res) if err != nil { return errors.WithStack(err) } - res.nestedMergeItems = results - return nil + if skip { + return nil + } + return l.loadSingleFetch(ctx, f, fetchItem, items, res) case *EntityFetch: - res.out = &bytes.Buffer{} + res = l.createOrInitResult(res, f.PostProcessing, f.Info) + skip, err := l.tryCacheLoadFetch(ctx, f.Info, f.Caching, items, res) + if err != nil { + return errors.WithStack(err) + } + if skip { + return nil + } return l.loadEntityFetch(ctx, fetchItem, f, items, res) case *BatchEntityFetch: - res.out = &bytes.Buffer{} + res = l.createOrInitResult(res, f.PostProcessing, f.Info) + skip, err := l.tryCacheLoadFetch(ctx, f.Info, f.Caching, items, res) + if err != nil { + return errors.WithStack(err) + } + if skip { + return nil + } return l.loadBatchEntityFetch(ctx, fetchItem, f, items, res) } return nil @@ -517,42 +635,32 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson if res.err != nil { return l.renderErrorsFailedToFetch(fetchItem, res, failedToFetchNoReason) } - if res.authorizationRejected { - err := l.renderAuthorizationRejectedErrors(fetchItem, res) - if err != nil { - return err - } - trueValue := astjson.MustParse(`true`) - skipErrorsPath := make([]string, len(res.postProcessing.MergePath)+1) - copy(skipErrorsPath, res.postProcessing.MergePath) - skipErrorsPath[len(skipErrorsPath)-1] = "__skipErrors" - for _, item := range items { - astjson.SetValue(item, trueValue, skipErrorsPath...) - } - return nil + if rejected, err := l.evaluateRejected(fetchItem, res, items); err != nil || rejected { + return err } - if res.rateLimitRejected { - err := l.renderRateLimitRejectedErrors(fetchItem, res) - if err != nil { - return err - } - trueValue := astjson.MustParse(`true`) - skipErrorsPath := make([]string, len(res.postProcessing.MergePath)+1) - copy(skipErrorsPath, res.postProcessing.MergePath) - skipErrorsPath[len(skipErrorsPath)-1] = "__skipErrors" - for _, item := range items { - astjson.SetValue(item, trueValue, skipErrorsPath...) 
+ if res.cacheSkipFetch { + // Merge cached data into items + for _, key := range res.cacheKeys { + // Merge cached data into item + _, _, err := astjson.MergeValues(l.jsonArena, key.Item, key.FromCache) + if err != nil { + return l.renderErrorsFailedToFetch(fetchItem, res, "invalid cache item") + } } return nil } if res.fetchSkipped { return nil } - if res.out.Len() == 0 { + if len(res.out) == 0 { return l.renderErrorsFailedToFetch(fetchItem, res, emptyGraphQLResponse) } - - response, err := astjson.ParseBytesWithoutCache(res.out.Bytes()) + // before parsing bytes with an arena.Arena, it's important to first allocate the bytes ON the same arena.Arena + // this ties their lifecycles together + // if you don't do this, you'll get segfaults + slice := arena.AllocateSlice[byte](l.jsonArena, len(res.out), len(res.out)) + copy(slice, res.out) + response, err := astjson.ParseBytesWithArena(l.jsonArena, slice) if err != nil { // Fall back to status code if parsing fails and non-2XX if (res.statusCode > 0 && res.statusCode < 200) || res.statusCode >= 300 { @@ -623,7 +731,7 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson // no data return nil } - + defer l.updateCache(res) if len(items) == 0 { // If the data is set, it must be an object according to GraphQL over HTTP spec if responseData.Type() != astjson.TypeObject { @@ -633,7 +741,7 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson return nil } if len(items) == 1 && res.batchStats == nil { - items[0], _, err = astjson.MergeValuesWithPath(items[0], responseData, res.postProcessing.MergePath...) + items[0], _, err = astjson.MergeValuesWithPath(l.jsonArena, items[0], responseData, res.postProcessing.MergePath...) if err != nil { return errors.WithStack(ErrMergeResult{ Subgraph: res.ds.Name, @@ -652,26 +760,23 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson } if res.batchStats != nil { - uniqueIndexes := res.batchStats.getUniqueIndexes() - if uniqueIndexes != len(batch) { - return l.renderErrorsFailedToFetch(fetchItem, res, fmt.Sprintf(invalidBatchItemCount, uniqueIndexes, len(batch))) + if len(res.batchStats) != len(batch) { + return l.renderErrorsFailedToFetch(fetchItem, res, fmt.Sprintf(invalidBatchItemCount, len(res.batchStats), len(batch))) } - for i, stats := range res.batchStats { - for _, idx := range stats { - if idx == -1 { - continue - } - items[i], _, err = astjson.MergeValuesWithPath(items[i], batch[idx], res.postProcessing.MergePath...) - if err != nil { + for batchIndex, targets := range res.batchStats { + src := batch[batchIndex] + for _, target := range targets { + _, _, mErr := astjson.MergeValuesWithPath(l.jsonArena, target, src, res.postProcessing.MergePath...) + if mErr != nil { return errors.WithStack(ErrMergeResult{ Subgraph: res.ds.Name, - Reason: err, + Reason: mErr, Path: fetchItem.ResponsePath, }) } - if slices.Contains(taintedIndices, idx) { - l.taintedObjs.add(items[i]) + if slices.Contains(taintedIndices, batchIndex) { + l.taintedObjs.add(target) } } } @@ -683,7 +788,7 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson } for i := range items { - items[i], _, err = astjson.MergeValuesWithPath(items[i], batch[i], res.postProcessing.MergePath...) + items[i], _, err = astjson.MergeValuesWithPath(l.jsonArena, items[i], batch[i], res.postProcessing.MergePath...) 
if err != nil { return errors.WithStack(ErrMergeResult{ Subgraph: res.ds.Name, @@ -695,15 +800,47 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson l.taintedObjs.add(items[i]) } } + return nil } +func (l *Loader) evaluateRejected(fetchItem *FetchItem, res *result, items []*astjson.Value) (bool, error) { + if res.authorizationRejected { + err := l.renderAuthorizationRejectedErrors(fetchItem, res) + if err != nil { + return false, err + } + l.setSkipErrors(res, items) + return true, nil + } + if res.rateLimitRejected { + err := l.renderRateLimitRejectedErrors(fetchItem, res) + if err != nil { + return false, err + } + l.setSkipErrors(res, items) + return true, nil + } + return false, nil +} + +func (l *Loader) setSkipErrors(res *result, items []*astjson.Value) { + trueValue := astjson.TrueValue(l.jsonArena) + skipErrorsPath := make([]string, len(res.postProcessing.MergePath)+1) + copy(skipErrorsPath, res.postProcessing.MergePath) + skipErrorsPath[len(skipErrorsPath)-1] = "__skipErrors" + for _, item := range items { + astjson.SetValue(item, trueValue, skipErrorsPath...) + } +} + var ( errorsInvalidInputHeader = []byte(`{"errors":[{"message":"Failed to render Fetch Input","path":[`) errorsInvalidInputFooter = []byte(`]}]}`) ) -func (l *Loader) renderErrorsInvalidInput(fetchItem *FetchItem, out *bytes.Buffer) error { +func (l *Loader) renderErrorsInvalidInput(fetchItem *FetchItem) []byte { + out := bytes.NewBuffer(nil) elements := fetchItem.ResponsePathElements if len(elements) > 0 && elements[len(elements)-1] == "@" { elements = elements[:len(elements)-1] @@ -721,7 +858,29 @@ func (l *Loader) renderErrorsInvalidInput(fetchItem *FetchItem, out *bytes.Buffe _, _ = out.Write(quote) } _, _ = out.Write(errorsInvalidInputFooter) - return nil + return out.Bytes() +} + +func (l *Loader) updateCache(res *result) { + if res.cache == nil || len(res.cacheKeys) == 0 || !res.cacheMustBeUpdated { + return + } + + // Convert CacheKeys to CacheEntries + cacheEntries, err := l.cacheKeysToEntries(l.jsonArena, res.cacheKeys) + if err != nil { + fmt.Printf("error converting cache keys to entries: %s", err) + return + } + + if len(cacheEntries) == 0 { + return + } + + err = res.cache.Set(l.ctx.ctx, cacheEntries, res.cacheConfig.TTL) + if err != nil { + fmt.Printf("error cache.Set: %s", err) + } } func (l *Loader) appendSubgraphError(res *result, fetchItem *FetchItem, value *astjson.Value, values []*astjson.Value) error { @@ -749,7 +908,7 @@ func (l *Loader) mergeErrors(res *result, fetchItem *FetchItem, value *astjson.V values := value.GetArray() l.optionallyOmitErrorLocations(values) if l.rewriteSubgraphErrorPaths { - rewriteErrorPaths(fetchItem, values) + rewriteErrorPaths(l.jsonArena, fetchItem, values) } l.optionallyEnsureExtensionErrorCode(values) @@ -778,7 +937,10 @@ func (l *Loader) mergeErrors(res *result, fetchItem *FetchItem, value *astjson.V return err } } - + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it + l.resolvable.ensureErrorsInitialized() // If the error propagation mode is pass-through, we append the errors to the root array l.resolvable.errors.AppendArrayItems(value) return nil @@ -792,7 +954,7 @@ func (l *Loader) mergeErrors(res *result, fetchItem *FetchItem, value *astjson.V } // Wrap mode (default) - errorObject, err := astjson.ParseWithoutCache(l.renderSubgraphBaseError(res.ds, fetchItem.ResponsePath, failedToFetchNoReason)) + 
errorObject, err := astjson.ParseWithArena(l.jsonArena, l.renderSubgraphBaseError(res.ds, fetchItem.ResponsePath, failedToFetchNoReason)) if err != nil { return err } @@ -815,7 +977,10 @@ func (l *Loader) mergeErrors(res *result, fetchItem *FetchItem, value *astjson.V if err := l.addApolloRouterCompatibilityError(res); err != nil { return err } - + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it + l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil @@ -861,17 +1026,17 @@ func (l *Loader) optionallyEnsureExtensionErrorCode(values []*astjson.Value) { switch extensions.Type() { case astjson.TypeObject: if !extensions.Exists("code") { - extensions.Set("code", l.resolvable.astjsonArena.NewString(l.defaultErrorExtensionCode)) + extensions.Set(l.jsonArena, "code", astjson.StringValue(l.jsonArena, l.defaultErrorExtensionCode)) } case astjson.TypeNull: - extensionsObj := l.resolvable.astjsonArena.NewObject() - extensionsObj.Set("code", l.resolvable.astjsonArena.NewString(l.defaultErrorExtensionCode)) - value.Set("extensions", extensionsObj) + extensionsObj := astjson.ObjectValue(l.jsonArena) + extensionsObj.Set(l.jsonArena, "code", astjson.StringValue(l.jsonArena, l.defaultErrorExtensionCode)) + value.Set(l.jsonArena, "extensions", extensionsObj) } } else { - extensionsObj := l.resolvable.astjsonArena.NewObject() - extensionsObj.Set("code", l.resolvable.astjsonArena.NewString(l.defaultErrorExtensionCode)) - value.Set("extensions", extensionsObj) + extensionsObj := astjson.ObjectValue(l.jsonArena) + extensionsObj.Set(l.jsonArena, "code", astjson.StringValue(l.jsonArena, l.defaultErrorExtensionCode)) + value.Set(l.jsonArena, "extensions", extensionsObj) } } } @@ -888,16 +1053,16 @@ func (l *Loader) optionallyAttachServiceNameToErrorExtension(values []*astjson.V extensions := value.Get("extensions") switch extensions.Type() { case astjson.TypeObject: - extensions.Set("serviceName", l.resolvable.astjsonArena.NewString(serviceName)) + extensions.Set(l.jsonArena, "serviceName", astjson.StringValue(l.jsonArena, serviceName)) case astjson.TypeNull: - extensionsObj := l.resolvable.astjsonArena.NewObject() - extensionsObj.Set("serviceName", l.resolvable.astjsonArena.NewString(serviceName)) - value.Set("extensions", extensionsObj) + extensionsObj := astjson.ObjectValue(l.jsonArena) + extensionsObj.Set(l.jsonArena, "serviceName", astjson.StringValue(l.jsonArena, serviceName)) + value.Set(l.jsonArena, "extensions", extensionsObj) } } else { - extensionsObj := l.resolvable.astjsonArena.NewObject() - extensionsObj.Set("serviceName", l.resolvable.astjsonArena.NewString(serviceName)) - value.Set("extensions", extensionsObj) + extensionsObj := astjson.ObjectValue(l.jsonArena) + extensionsObj.Set(l.jsonArena, "serviceName", astjson.StringValue(l.jsonArena, serviceName)) + value.Set(l.jsonArena, "extensions", extensionsObj) } } } @@ -951,7 +1116,7 @@ func (l *Loader) optionallyOmitErrorLocations(values []*astjson.Value) { // - Drops the numeric index immediately following "_entities". // - Converts all subsequent numeric segments to strings (e.g., 1 -> "1"). // - Skips non-string/non-number segments. 
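// Editor's note: a worked example of the rewrite described above, for illustration
// only (the concrete paths are hypothetical). For a fetch item whose response path
// is ["topProducts", "@"] and a subgraph error path
// ["_entities", 0, "reviews", 1, "body"], the rewritten client-facing path would be
// ["topProducts", "reviews", "1", "body"]: the trailing "@" is dropped from the
// prefix, "_entities" and the numeric index that follows it are removed, and the
// remaining numeric segments are converted to strings.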
-func rewriteErrorPaths(fetchItem *FetchItem, values []*astjson.Value) { +func rewriteErrorPaths(a arena.Arena, fetchItem *FetchItem, values []*astjson.Value) { pathPrefix := make([]string, len(fetchItem.ResponsePathElements)) copy(pathPrefix, fetchItem.ResponsePathElements) // remove the trailing @ in case we're in an array as it looks weird in the path @@ -993,11 +1158,11 @@ func rewriteErrorPaths(fetchItem *FetchItem, values []*astjson.Value) { } } newPathJSON, _ := json.Marshal(newPath) - pathBytes, err := astjson.ParseBytesWithoutCache(newPathJSON) + pathBytes, err := astjson.ParseBytesWithArena(a, newPathJSON) if err != nil { continue } - value.Set("path", pathBytes) + value.Set(a, "path", pathBytes) break } } @@ -1018,17 +1183,17 @@ func (l *Loader) setSubgraphStatusCode(values []*astjson.Value, statusCode int) if extensions.Type() != astjson.TypeObject { continue } - v, err := astjson.ParseWithoutCache(strconv.Itoa(statusCode)) + v, err := astjson.ParseWithArena(l.jsonArena, strconv.Itoa(statusCode)) if err != nil { continue } - extensions.Set("statusCode", v) + extensions.Set(l.jsonArena, "statusCode", v) } else { - v, err := astjson.ParseWithoutCache(`{"statusCode":` + strconv.Itoa(statusCode) + `}`) + v, err := astjson.ParseWithArena(l.jsonArena, `{"statusCode":`+strconv.Itoa(statusCode)+`}`) if err != nil { continue } - value.Set("extensions", v) + value.Set(l.jsonArena, "extensions", v) } } } @@ -1065,11 +1230,14 @@ func (l *Loader) addApolloRouterCompatibilityError(res *result) error { } } }`, res.ds.Name, http.StatusText(res.statusCode), res.statusCode) - apolloRouterStatusError, err := astjson.ParseWithoutCache(apolloRouterStatusErrorJSON) + apolloRouterStatusError, err := astjson.ParseWithArena(l.jsonArena, apolloRouterStatusErrorJSON) if err != nil { return err } - + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it + l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, apolloRouterStatusError) return nil @@ -1078,22 +1246,30 @@ func (l *Loader) addApolloRouterCompatibilityError(res *result) error { func (l *Loader) renderErrorsFailedDeps(fetchItem *FetchItem, res *result) error { path := l.renderAtPathErrorPart(fetchItem.ResponsePath) msg := fmt.Sprintf(`{"message":"Failed to obtain field dependencies from Subgraph '%s'%s."}`, res.ds.Name, path) - errorObject, err := astjson.ParseWithoutCache(msg) + errorObject, err := astjson.ParseWithArena(l.jsonArena, msg) if err != nil { return err } l.setSubgraphStatusCode([]*astjson.Value{errorObject}, res.statusCode) + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it + l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil } func (l *Loader) renderErrorsFailedToFetch(fetchItem *FetchItem, res *result, reason string) error { l.ctx.appendSubgraphErrors(res.err, NewSubgraphError(res.ds, fetchItem.ResponsePath, reason, res.statusCode)) - errorObject, err := astjson.ParseWithoutCache(l.renderSubgraphBaseError(res.ds, fetchItem.ResponsePath, reason)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, l.renderSubgraphBaseError(res.ds, fetchItem.ResponsePath, reason)) if err != nil { return err } l.setSubgraphStatusCode([]*astjson.Value{errorObject}, res.statusCode) + // for efficiency purposes, 
resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it + l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil } @@ -1106,13 +1282,16 @@ func (l *Loader) renderErrorsStatusFallback(fetchItem *FetchItem, res *result, s l.ctx.appendSubgraphErrors(res.err, NewSubgraphError(res.ds, fetchItem.ResponsePath, reason, res.statusCode)) - errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"%s"}`, reason)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"%s"}`, reason)) if err != nil { return err } l.setSubgraphStatusCode([]*astjson.Value{errorObject}, res.statusCode) - + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it + l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil } @@ -1137,16 +1316,20 @@ func (l *Loader) renderAuthorizationRejectedErrors(fetchItem *FetchItem, res *re } pathPart := l.renderAtPathErrorPart(fetchItem.ResponsePath) extensionErrorCode := fmt.Sprintf(`"extensions":{"code":"%s"}`, errorcodes.UnauthorizedFieldOrType) + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it + l.resolvable.ensureErrorsInitialized() if res.ds.Name == "" { for _, reason := range res.authorizationRejectedReasons { if reason == "" { - errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Unauthorized Subgraph request%s.",%s}`, pathPart, extensionErrorCode)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized Subgraph request%s.",%s}`, pathPart, extensionErrorCode)) if err != nil { continue } astjson.AppendToArray(l.resolvable.errors, errorObject) } else { - errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Unauthorized Subgraph request%s, Reason: %s.",%s}`, pathPart, reason, extensionErrorCode)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized Subgraph request%s, Reason: %s.",%s}`, pathPart, reason, extensionErrorCode)) if err != nil { continue } @@ -1156,13 +1339,13 @@ func (l *Loader) renderAuthorizationRejectedErrors(fetchItem *FetchItem, res *re } else { for _, reason := range res.authorizationRejectedReasons { if reason == "" { - errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s'%s.",%s}`, res.ds.Name, pathPart, extensionErrorCode)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s'%s.",%s}`, res.ds.Name, pathPart, extensionErrorCode)) if err != nil { continue } astjson.AppendToArray(l.resolvable.errors, errorObject) } else { - errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s'%s, Reason: %s.",%s}`, res.ds.Name, pathPart, reason, extensionErrorCode)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s'%s, Reason: %s.",%s}`, res.ds.Name, pathPart, reason, extensionErrorCode)) if err != nil { continue } @@ -1182,39 +1365,43 @@ func (l *Loader) renderRateLimitRejectedErrors(fetchItem *FetchItem, res *result ) if res.ds.Name == "" 
{ if res.rateLimitRejectedReason == "" { - errorObject, err = astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s."}`, pathPart)) + errorObject, err = astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s."}`, pathPart)) if err != nil { return err } } else { - errorObject, err = astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s, Reason: %s."}`, pathPart, res.rateLimitRejectedReason)) + errorObject, err = astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s, Reason: %s."}`, pathPart, res.rateLimitRejectedReason)) if err != nil { return err } } } else { if res.rateLimitRejectedReason == "" { - errorObject, err = astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s."}`, res.ds.Name, pathPart)) + errorObject, err = astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s."}`, res.ds.Name, pathPart)) if err != nil { return err } } else { - errorObject, err = astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s, Reason: %s."}`, res.ds.Name, pathPart, res.rateLimitRejectedReason)) + errorObject, err = astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s, Reason: %s."}`, res.ds.Name, pathPart, res.rateLimitRejectedReason)) if err != nil { return err } } } if l.ctx.RateLimitOptions.ErrorExtensionCode.Enabled { - extension, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"code":"%s"}`, l.ctx.RateLimitOptions.ErrorExtensionCode.Code)) + extension, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"code":"%s"}`, l.ctx.RateLimitOptions.ErrorExtensionCode.Code)) if err != nil { return err } - errorObject, _, err = astjson.MergeValuesWithPath(errorObject, extension, "extensions") + errorObject, _, err = astjson.MergeValuesWithPath(l.jsonArena, errorObject, extension, "extensions") if err != nil { return err } } + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it + l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil } @@ -1284,10 +1471,8 @@ func (l *Loader) validatePreFetch(input []byte, info *FetchInfo, res *result) (a } func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, fetchItem *FetchItem, items []*astjson.Value, res *result) error { - res.init(fetch.PostProcessing, fetch.Info) - buf := &bytes.Buffer{} - - inputData := itemsData(items) + buf := bytes.NewBuffer(nil) + inputData := l.itemsData(items) if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} if !l.ctx.TracingOptions.ExcludeRawInputData && inputData != nil { @@ -1309,7 +1494,8 @@ func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, fetchI err := fetch.InputTemplate.Render(l.ctx, inputData, buf) if err != nil { - return l.renderErrorsInvalidInput(fetchItem, res.out) + res.out = l.renderErrorsInvalidInput(fetchItem) + return nil } fetchInput := buf.Bytes() allowed, err := l.validatePreFetch(fetchInput, fetch.Info, res) @@ -1323,37 +1509,8 @@ func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, fetchI return nil } -var ( - entityFetchPool = sync.Pool{ - New: func() any { - return &entityFetchBuffer{ - item: &bytes.Buffer{}, - 
preparedInput: &bytes.Buffer{}, - } - }, - } -) - -type entityFetchBuffer struct { - item *bytes.Buffer - preparedInput *bytes.Buffer -} - -func acquireEntityFetchBuffer() *entityFetchBuffer { - return entityFetchPool.Get().(*entityFetchBuffer) -} - -func releaseEntityFetchBuffer(buf *entityFetchBuffer) { - buf.item.Reset() - buf.preparedInput.Reset() - entityFetchPool.Put(buf) -} - func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetch *EntityFetch, items []*astjson.Value, res *result) error { - res.init(fetch.PostProcessing, fetch.Info) - buf := acquireEntityFetchBuffer() - defer releaseEntityFetchBuffer(buf) - input := itemsData(items) + input := l.itemsData(items) if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} if !l.ctx.TracingOptions.ExcludeRawInputData && input != nil { @@ -1361,14 +1518,17 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetc } } + preparedInput := bytes.NewBuffer(nil) + item := bytes.NewBuffer(nil) + var undefinedVariables []string - err := fetch.Input.Header.RenderAndCollectUndefinedVariables(l.ctx, nil, buf.preparedInput, &undefinedVariables) + err := fetch.Input.Header.RenderAndCollectUndefinedVariables(l.ctx, nil, preparedInput, &undefinedVariables) if err != nil { return errors.WithStack(err) } - err = fetch.Input.Item.Render(l.ctx, input, buf.item) + err = fetch.Input.Item.Render(l.ctx, input, item) if err != nil { if fetch.Input.SkipErrItem { // skip fetch on render item error @@ -1380,7 +1540,7 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetc } return errors.WithStack(err) } - renderedItem := buf.item.Bytes() + renderedItem := item.Bytes() if bytes.Equal(renderedItem, null) { // skip fetch if item is null res.fetchSkipped = true @@ -1399,17 +1559,17 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetc return nil } } - _, _ = buf.item.WriteTo(buf.preparedInput) - err = fetch.Input.Footer.RenderAndCollectUndefinedVariables(l.ctx, nil, buf.preparedInput, &undefinedVariables) + _, _ = item.WriteTo(preparedInput) + err = fetch.Input.Footer.RenderAndCollectUndefinedVariables(l.ctx, nil, preparedInput, &undefinedVariables) if err != nil { return errors.WithStack(err) } - err = SetInputUndefinedVariables(buf.preparedInput, undefinedVariables) + err = SetInputUndefinedVariables(preparedInput, undefinedVariables) if err != nil { return errors.WithStack(err) } - fetchInput := buf.preparedInput.Bytes() + fetchInput := preparedInput.Bytes() if l.ctx.TracingOptions.Enable && res.fetchSkipped { l.setTracingInput(fetchItem, fetchInput, fetch.Trace) @@ -1427,71 +1587,92 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetc return nil } -var ( - batchEntityFetchPool = sync.Pool{} -) +type batchEntityTools struct { + keyGen *xxhash.Digest + batchHashToIndex map[uint64]int + a arena.Arena +} + +func (b *batchEntityTools) reset() { + b.keyGen.Reset() + b.a.Reset() + for i := range b.batchHashToIndex { + delete(b.batchHashToIndex, i) + } +} -type batchEntityFetchBuffer struct { - preparedInput *bytes.Buffer - itemInput *bytes.Buffer - keyGen *xxhash.Digest +type _batchEntityToolPool struct { + pool sync.Pool } -func acquireBatchEntityFetchBuffer() *batchEntityFetchBuffer { - buf := batchEntityFetchPool.Get() - if buf == nil { - return &batchEntityFetchBuffer{ - preparedInput: &bytes.Buffer{}, - itemInput: &bytes.Buffer{}, - keyGen: xxhash.New(), +func (p *_batchEntityToolPool) Get(items int) 
*batchEntityTools { + item := p.pool.Get() + if item == nil { + return &batchEntityTools{ + keyGen: xxhash.New(), + batchHashToIndex: make(map[uint64]int, items), + a: arena.NewMonotonicArena(arena.WithMinBufferSize(1024)), } } - return buf.(*batchEntityFetchBuffer) + return item.(*batchEntityTools) } -func releaseBatchEntityFetchBuffer(buf *batchEntityFetchBuffer) { - buf.preparedInput.Reset() - buf.itemInput.Reset() - buf.keyGen.Reset() - batchEntityFetchPool.Put(buf) +func (p *_batchEntityToolPool) Put(item *batchEntityTools) { + if item == nil { + return + } + item.reset() + p.pool.Put(item) } -func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, fetch *BatchEntityFetch, items []*astjson.Value, res *result) error { - res.init(fetch.PostProcessing, fetch.Info) - - buf := acquireBatchEntityFetchBuffer() - defer releaseBatchEntityFetchBuffer(buf) +var ( + batchEntityToolPool = _batchEntityToolPool{} +) +func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, fetch *BatchEntityFetch, items []*astjson.Value, res *result) error { if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} if !l.ctx.TracingOptions.ExcludeRawInputData && len(items) != 0 { - data := itemsData(items) + data := l.itemsData(items) if data != nil { fetch.Trace.RawInputData, _ = l.compactJSON(data.MarshalTo(nil)) } } } + res.tools = batchEntityToolPool.Get(len(items)) + preparedInput := arena.NewArenaBuffer(res.tools.a) + itemInput := arena.NewArenaBuffer(res.tools.a) + batchStats := arena.AllocateSlice[[]*astjson.Value](res.tools.a, 0, len(items)) + defer func() { + // we need to clear the batchStats slice to avoid memory corruption + // once the outer func returns, we must not keep pointers to items on the arena + for i := range batchStats { + // nolint:ineffassign + batchStats[i] = nil + } + // nolint:ineffassign + batchStats = nil + }() + + // I tried using arena here, but it only worsened the situation var undefinedVariables []string - err := fetch.Input.Header.RenderAndCollectUndefinedVariables(l.ctx, nil, buf.preparedInput, &undefinedVariables) + err := fetch.Input.Header.RenderAndCollectUndefinedVariables(l.ctx, nil, preparedInput, &undefinedVariables) if err != nil { return errors.WithStack(err) } - res.batchStats = make(batchStats, len(items)) - itemHashes := make([]uint64, 0, len(items)) batchItemIndex := 0 addSeparator := false WithNextItem: for i, item := range items { for j := range fetch.Input.Items { - buf.itemInput.Reset() - err = fetch.Input.Items[j].Render(l.ctx, item, buf.itemInput) + itemInput.Reset() + err = fetch.Input.Items[j].Render(l.ctx, item, itemInput) if err != nil { if fetch.Input.SkipErrItems { err = nil // nolint:ineffassign - res.batchStats[i] = append(res.batchStats[i], -1) continue } if l.ctx.TracingOptions.Enable { @@ -1499,39 +1680,38 @@ WithNextItem: } return errors.WithStack(err) } - if fetch.Input.SkipNullItems && buf.itemInput.Len() == 4 && bytes.Equal(buf.itemInput.Bytes(), null) { - res.batchStats[i] = append(res.batchStats[i], -1) + if fetch.Input.SkipNullItems && itemInput.Len() == 4 && bytes.Equal(itemInput.Bytes(), null) { continue } - if fetch.Input.SkipEmptyObjectItems && buf.itemInput.Len() == 2 && bytes.Equal(buf.itemInput.Bytes(), emptyObject) { - res.batchStats[i] = append(res.batchStats[i], -1) + if fetch.Input.SkipEmptyObjectItems && itemInput.Len() == 2 && bytes.Equal(itemInput.Bytes(), emptyObject) { continue } - buf.keyGen.Reset() - _, _ = buf.keyGen.Write(buf.itemInput.Bytes()) - itemHash := 
buf.keyGen.Sum64() - for k := range itemHashes { - if itemHashes[k] == itemHash { - res.batchStats[i] = append(res.batchStats[i], k) - continue WithNextItem - } - } - itemHashes = append(itemHashes, itemHash) - if addSeparator { - err = fetch.Input.Separator.Render(l.ctx, nil, buf.preparedInput) - if err != nil { - return errors.WithStack(err) + res.tools.keyGen.Reset() + _, _ = res.tools.keyGen.Write(itemInput.Bytes()) + itemHash := res.tools.keyGen.Sum64() + if existingIndex, ok := res.tools.batchHashToIndex[itemHash]; ok { + batchStats[existingIndex] = arena.SliceAppend(res.tools.a, batchStats[existingIndex], items[i]) + continue WithNextItem + } else { + if addSeparator { + err = fetch.Input.Separator.Render(l.ctx, nil, preparedInput) + if err != nil { + return errors.WithStack(err) + } } + _, _ = itemInput.WriteTo(preparedInput) + // new unique representation + res.tools.batchHashToIndex[itemHash] = batchItemIndex + // create a new targets bucket for this unique index + batchStats = arena.SliceAppend(res.tools.a, batchStats, []*astjson.Value{items[i]}) + batchItemIndex++ + addSeparator = true } - _, _ = buf.itemInput.WriteTo(buf.preparedInput) - res.batchStats[i] = append(res.batchStats[i], batchItemIndex) - batchItemIndex++ - addSeparator = true } } - if len(itemHashes) == 0 { + if len(batchStats) == 0 { // all items were skipped - discard fetch res.fetchSkipped = true if l.ctx.TracingOptions.Enable { @@ -1541,16 +1721,23 @@ WithNextItem: } } - err = fetch.Input.Footer.RenderAndCollectUndefinedVariables(l.ctx, nil, buf.preparedInput, &undefinedVariables) + err = fetch.Input.Footer.RenderAndCollectUndefinedVariables(l.ctx, nil, preparedInput, &undefinedVariables) if err != nil { return errors.WithStack(err) } - err = SetInputUndefinedVariables(buf.preparedInput, undefinedVariables) + err = SetInputUndefinedVariables(preparedInput, undefinedVariables) if err != nil { return errors.WithStack(err) } - fetchInput := buf.preparedInput.Bytes() + + fetchInput := preparedInput.Bytes() + // it's important to copy the *astjson.Value's off the arena to avoid memory corruption + res.batchStats = make([][]*astjson.Value, len(batchStats)) + for i := range batchStats { + res.batchStats[i] = make([]*astjson.Value, len(batchStats[i])) + copy(res.batchStats[i], batchStats[i]) + } if l.ctx.TracingOptions.Enable && res.fetchSkipped { l.setTracingInput(fetchItem, fetchInput, fetch.Trace) @@ -1564,6 +1751,7 @@ WithNextItem: if !allowed { return nil } + l.executeSourceLoad(ctx, fetchItem, fetch.DataSource, fetchInput, res, fetch.Trace) return nil } @@ -1605,29 +1793,8 @@ func redactHeaders(rawJSON json.RawMessage) (json.RawMessage, error) { return redactedJSON, nil } -type disallowSingleFlightContextKey struct{} - -func SingleFlightDisallowed(ctx context.Context) bool { - return ctx.Value(disallowSingleFlightContextKey{}) != nil -} - -type singleFlightStatsKey struct{} - -type SingleFlightStats struct { - SingleFlightUsed bool - SingleFlightSharedResponse bool -} - -func GetSingleFlightStats(ctx context.Context) *SingleFlightStats { - maybeStats := ctx.Value(singleFlightStatsKey{}) - if maybeStats == nil { - return nil - } - return maybeStats.(*SingleFlightStats) -} - -func setSingleFlightStats(ctx context.Context, stats *SingleFlightStats) context.Context { - return context.WithValue(ctx, singleFlightStatsKey{}, stats) +type singleFlightStats struct { + used, shared bool } func (l *Loader) setTracingInput(fetchItem *FetchItem, input []byte, trace *DataSourceLoadTrace) { @@ -1643,11 +1810,120 @@ func (l 
*Loader) setTracingInput(fetchItem *FetchItem, input []byte, trace *Data } } -func (l *Loader) loadByContext(ctx context.Context, source DataSource, input []byte, res *result) error { +type loaderContextKey string + +const ( + operationTypeContextKey loaderContextKey = "operationType" +) + +// GetOperationTypeFromContext can be used, e.g. by the transport, to check if the operation is a Mutation +func GetOperationTypeFromContext(ctx context.Context) ast.OperationType { + if ctx == nil { + return ast.OperationTypeQuery + } + if v := ctx.Value(operationTypeContextKey); v != nil { + if opType, ok := v.(ast.OperationType); ok { + return opType + } + } + return ast.OperationTypeQuery +} + +func (l *Loader) headersForSubgraphRequest(fetchItem *FetchItem) (http.Header, uint64) { + if fetchItem == nil || fetchItem.Fetch == nil { + return nil, 0 + } + info := fetchItem.Fetch.FetchInfo() + if info == nil { + return nil, 0 + } + return l.ctx.HeadersForSubgraphRequest(info.DataSourceName) +} + +// singleFlightAllowed returns true if the specific GraphQL Operation is a Query +// even if the root operation type is a Mutation or Subscription +// sub-operations can still be of type Query +// even in such cases we allow request de-duplication because such requests are idempotent +func (l *Loader) singleFlightAllowed(fetchItem *FetchItem) bool { + if l.ctx.ExecutionOptions.DisableSubgraphRequestDeduplication { + return false + } + if fetchItem == nil { + return false + } + if fetchItem.Fetch == nil { + return false + } + info := fetchItem.Fetch.FetchInfo() + if info == nil { + return false + } + if info.OperationType == ast.OperationTypeQuery { + return true + } + return false +} + +func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem *FetchItem, input []byte, res *result) error { + + if l.info != nil { + ctx = context.WithValue(ctx, operationTypeContextKey, l.info.OperationType) + } + + headers, extraKey := l.headersForSubgraphRequest(fetchItem) + + if !l.singleFlightAllowed(fetchItem) { + // Disable single flight for mutations + return l.loadByContextDirect(ctx, source, headers, input, res) + } + + sfKey, fetchKey, item, shared := l.sf.GetOrCreateItem(fetchItem, input, extraKey) + if res.singleFlightStats != nil { + res.singleFlightStats.used = true + res.singleFlightStats.shared = shared + } + + if shared { + select { + case <-item.loaded: + case <-ctx.Done(): + return ctx.Err() + } + + if item.err != nil { + return item.err + } + + res.out = item.response + return nil + } + + // helps the http client to create buffers at the right size + ctx = httpclient.WithHTTPClientSizeHint(ctx, item.sizeHint) + + defer l.sf.Finish(sfKey, fetchKey, item) + + // Perform the actual load + err := l.loadByContextDirect(ctx, source, headers, input, res) + if err != nil { + item.err = err + return err + } + + item.response = res.out + return nil +} + +func (l *Loader) loadByContextDirect(ctx context.Context, source DataSource, headers http.Header, input []byte, res *result) error { if l.ctx.Files != nil { - return source.LoadWithFiles(ctx, input, l.ctx.Files, res.out) + res.out, res.err = source.LoadWithFiles(ctx, headers, input, l.ctx.Files) + } else { + res.out, res.err = source.Load(ctx, headers, input) } - return source.Load(ctx, input, res.out) + if res.err != nil { + return errors.WithStack(res.err) + } + return nil } func (l *Loader) executeSourceLoad(ctx context.Context, fetchItem *FetchItem, source DataSource, input []byte, res *result, trace *DataSourceLoadTrace) { @@ -1676,7 +1952,7 
@@ func (l *Loader) executeSourceLoad(ctx context.Context, fetchItem *FetchItem, so } } if l.ctx.TracingOptions.Enable { - ctx = setSingleFlightStats(ctx, &SingleFlightStats{}) + res.singleFlightStats = &singleFlightStats{} trace.Path = fetchItem.ResponsePath if !l.ctx.TracingOptions.ExcludeInput { trace.Input = make([]byte, len(input)) @@ -1780,9 +2056,6 @@ func (l *Loader) executeSourceLoad(ctx context.Context, fetchItem *FetchItem, so ctx = httptrace.WithClientTrace(ctx, clientTrace) } } - if l.info != nil && l.info.OperationType == ast.OperationTypeMutation { - ctx = context.WithValue(ctx, disallowSingleFlightContextKey{}, true) - } var responseContext *httpclient.ResponseContext ctx, responseContext = httpclient.InjectResponseContext(ctx) @@ -1791,27 +2064,26 @@ func (l *Loader) executeSourceLoad(ctx context.Context, fetchItem *FetchItem, so // Prevent that the context is destroyed when the loader hook return an empty context if res.loaderHookContext != nil { - res.err = l.loadByContext(res.loaderHookContext, source, input, res) + res.err = l.loadByContext(res.loaderHookContext, source, fetchItem, input, res) } else { - res.err = l.loadByContext(ctx, source, input, res) + res.err = l.loadByContext(ctx, source, fetchItem, input, res) res.loaderHookContext = ctx // Set the context to the original context to ensure that OnFinished hook gets valid context } } else { - res.err = l.loadByContext(ctx, source, input, res) + res.err = l.loadByContext(ctx, source, fetchItem, input, res) } res.statusCode = responseContext.StatusCode res.httpResponseContext = responseContext if l.ctx.TracingOptions.Enable { - stats := GetSingleFlightStats(ctx) - if stats != nil { - trace.SingleFlightUsed = stats.SingleFlightUsed - trace.SingleFlightSharedResponse = stats.SingleFlightSharedResponse + if res.singleFlightStats != nil { + trace.SingleFlightUsed = res.singleFlightStats.used + trace.SingleFlightSharedResponse = res.singleFlightStats.shared } - if !l.ctx.TracingOptions.ExcludeOutput && res.out.Len() > 0 { - trace.Output, _ = l.compactJSON(res.out.Bytes()) + if !l.ctx.TracingOptions.ExcludeOutput && len(res.out) > 0 { + trace.Output, _ = l.compactJSON(res.out) if l.ctx.TracingOptions.EnablePredictableDebugTimings { trace.Output, _ = sjson.DeleteBytes(trace.Output, "extensions.trace.response.headers.Date") } @@ -1840,10 +2112,133 @@ func (l *Loader) compactJSON(data []byte) ([]byte, error) { return nil, err } out := dst.Bytes() - v, err := astjson.ParseBytesWithoutCache(out) + // don't use arena here or segfault + // it's also not a hot path and not important to optimize + // arena requires the parsed content to be on the arena as well + v, err := astjson.ParseBytes(out) if err != nil { return nil, err } astjson.DeduplicateObjectKeysRecursively(v) return v.MarshalTo(nil), nil } + +// canSkipFetch returns true if the cache provided exactly the information required to satisfy the query plan +// the query planner generates info.ProvidesData which tells precisely which fields the fetch must load +// if a single value is missing, we will execute the fetch +func (l *Loader) canSkipFetch(info *FetchInfo, res *result) bool { + if info == nil || info.OperationType != ast.OperationTypeQuery || info.ProvidesData == nil { + return false + } + for i := range res.cacheKeys { + if !l.validateItemHasRequiredData(res.cacheKeys[i].FromCache, info.ProvidesData) { + return false + } + } + return true +} + +// validateItemHasRequiredData checks if the given item contains all required data +// as specified by the provided 
Object schema +func (l *Loader) validateItemHasRequiredData(item *astjson.Value, obj *Object) bool { + if item == nil { + return false + } + // Validate each field in the object + for _, field := range obj.Fields { + if !l.validateFieldData(item, field) { + return false + } + } + + return true +} + +// validateFieldData validates a single field against the item data +func (l *Loader) validateFieldData(item *astjson.Value, field *Field) bool { + fieldValue := item.Get(unsafebytes.BytesToString(field.Name)) + + // Check if field exists + if fieldValue == nil { + // Field is missing - this fails validation regardless of nullability + // Even nullable fields must be present (can be null, but not missing) + return false + } + + // Validate the field value against its specification + return l.validateNodeValue(fieldValue, field.Value) +} + +// validateScalarData validates scalar field data +func (l *Loader) validateScalarData(value *astjson.Value, scalar *Scalar) bool { + if value.Type() == astjson.TypeNull { + // Null is only allowed if the scalar is nullable + return scalar.Nullable + } + + // Any non-null value is acceptable for a scalar + return true +} + +// validateObjectData validates object field data +func (l *Loader) validateObjectData(value *astjson.Value, obj *Object) bool { + if value.Type() == astjson.TypeNull { + // Null is only allowed if the object is nullable + return obj.Nullable + } + + if value.Type() != astjson.TypeObject { + // Must be an object (or null if nullable) + return false + } + + // Recursively validate the object's fields + return l.validateItemHasRequiredData(value, obj) +} + +// validateArrayData validates array field data +func (l *Loader) validateArrayData(value *astjson.Value, arr *Array) bool { + if value.Type() == astjson.TypeNull { + // Null is only allowed if the array is nullable + return arr.Nullable + } + + if value.Type() != astjson.TypeArray { + // Must be an array (or null if nullable) + return false + } + + // If there's no item specification, we just validate the array exists + if arr.Item == nil { + return true + } + + // Validate each item in the array + arrayItems, err := value.Array() + if err != nil { + return false + } + + for _, item := range arrayItems { + if !l.validateNodeValue(item, arr.Item) { + return false + } + } + + return true +} + +// validateNodeValue validates a value against a Node specification +func (l *Loader) validateNodeValue(value *astjson.Value, nodeSpec Node) bool { + switch v := nodeSpec.(type) { + case *Scalar: + return l.validateScalarData(value, v) + case *Object: + return l.validateObjectData(value, v) + case *Array: + return l.validateArrayData(value, v) + default: + // Unknown type - assume invalid + return false + } +} diff --git a/v2/pkg/engine/resolve/loader_hooks_test.go b/v2/pkg/engine/resolve/loader_hooks_test.go index 4b7b3ea6c..4a2ce9cb2 100644 --- a/v2/pkg/engine/resolve/loader_hooks_test.go +++ b/v2/pkg/engine/resolve/loader_hooks_test.go @@ -3,7 +3,7 @@ package resolve import ( "bytes" "context" - "io" + "net/http" "sync" "sync/atomic" "testing" @@ -50,11 +50,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("simple fetch with simple subgraph error", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). 
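The mock rewrites in this test file trace the new DataSource contract: Load no longer streams into an io.Writer but returns the response body directly, and it now receives per-subgraph headers. A minimal sketch of the assumed post-refactor interface, inferred from the mock signatures and the loader call sites in this diff; the files parameter type is a placeholder, not confirmed here.

package sketch

import (
	"context"
	"net/http"
)

// Inferred shape only: Load returns the raw subgraph response instead of
// writing it, so the loader can hand the bytes to single flight and caching.
type DataSource interface {
	Load(ctx context.Context, headers http.Header, input []byte) ([]byte, error)
	// files is a placeholder type; the real signature takes whatever l.ctx.Files holds.
	LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files any) ([]byte, error)
}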
- DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) resolveCtx := Context{ ctx: context.Background(), @@ -124,11 +122,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) resolveCtx := &Context{ ctx: context.Background(), @@ -192,11 +188,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("parallel fetch with simple subgraph error", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) resolveCtx := &Context{ ctx: context.Background(), @@ -254,80 +248,12 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { } })) - t.Run("parallel list item fetch with simple subgraph error", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { - mockDataSource := NewMockDataSource(ctrl) - mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). 
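Stepping back to the cache-skip helpers added earlier in this diff (canSkipFetch and the validate* walkers): they accept cached data only when every field demanded by FetchInfo.ProvidesData is present, and null is tolerated solely on nullable positions. A self-contained illustration of that rule on plain encoding/json values, assuming illustrative names and shapes rather than the engine's types:

package main

import (
	"encoding/json"
	"fmt"
)

// spec mirrors the Object/Scalar walk from the diff in miniature: every
// declared field must be present; null passes only where nullable is set.
// This is an illustration, not the engine code.
type spec struct {
	nullable bool
	fields   map[string]spec // nil for scalars
}

func satisfies(v any, s spec) bool {
	if v == nil {
		return s.nullable
	}
	if s.fields == nil {
		return true // any non-null value satisfies a scalar
	}
	obj, ok := v.(map[string]any)
	if !ok {
		return false
	}
	for name, fs := range s.fields {
		fv, present := obj[name]
		if !present {
			return false // missing fields fail even when nullable
		}
		if !satisfies(fv, fs) {
			return false
		}
	}
	return true
}

func main() {
	var doc any
	_ = json.Unmarshal([]byte(`{"user":{"id":"123","email":null}}`), &doc)
	userSpec := spec{fields: map[string]spec{
		"id":    {},
		"email": {nullable: true},
	}}
	root := spec{fields: map[string]spec{"user": userSpec}}
	fmt.Println(satisfies(doc, root)) // true: email is present and nullable
}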
- DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) - }) - resolveCtx := Context{ - ctx: context.Background(), - LoaderHooks: NewTestLoaderHooks(), - } - return &GraphQLResponse{ - Info: &GraphQLResponseInfo{ - OperationType: ast.OperationTypeQuery, - }, - Fetches: SingleWithPath(&ParallelListItemFetch{ - Fetch: &SingleFetch{ - FetchConfiguration: FetchConfiguration{ - DataSource: mockDataSource, - PostProcessing: PostProcessingConfiguration{ - SelectResponseErrorsPath: []string{"errors"}, - }, - }, - Info: &FetchInfo{ - DataSourceID: "Users", - DataSourceName: "Users", - }, - }, - }, "query"), - Data: &Object{ - Nullable: false, - Fields: []*Field{ - { - Name: []byte("name"), - Value: &String{ - Path: []string{"name"}, - Nullable: true, - }, - }, - }, - }, - }, &resolveCtx, `{"errors":[{"message":"Failed to fetch from Subgraph 'Users' at Path 'query'.","extensions":{"errors":[{"message":"errorMessage"}]}}],"data":{"name":null}}`, - func(t *testing.T) { - loaderHooks := resolveCtx.LoaderHooks.(*TestLoaderHooks) - - assert.Equal(t, int64(1), loaderHooks.preFetchCalls.Load()) - assert.Equal(t, int64(1), loaderHooks.postFetchCalls.Load()) - - var subgraphError *SubgraphError - assert.Len(t, loaderHooks.errors, 1) - assert.ErrorAs(t, loaderHooks.errors[0], &subgraphError) - assert.Equal(t, "Users", subgraphError.DataSourceInfo.Name) - assert.Equal(t, "query", subgraphError.Path) - assert.Equal(t, "", subgraphError.Reason) - assert.Equal(t, 0, subgraphError.ResponseCode) - assert.Len(t, subgraphError.DownstreamErrors, 1) - assert.Equal(t, "errorMessage", subgraphError.DownstreamErrors[0].Message) - assert.Nil(t, subgraphError.DownstreamErrors[0].Extensions) - - assert.NotNil(t, resolveCtx.SubgraphErrors()) - } - })) - t.Run("fetch with subgraph error and custom extension code. No extension fields are propagated by default", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{\"code\":\"GRAPHQL_VALIDATION_FAILED\"}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, []byte("{\"code\":\"BAD_USER_INPUT\"}")) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT"}}]}`), nil }) resolveCtx := Context{ ctx: context.Background(), @@ -388,12 +314,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Propagate only extension code field from subgraph errors", testFnSubgraphErrorsWithExtensionFieldCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). 
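Also worth pinning down from the loader changes above: followers of a deduplicated fetch block on the shared item's loaded channel or on their own context, whichever fires first, while the leader fills in response and err before publishing. A toy version of that wait/publish pattern, assuming the Finish step publishes by closing loaded; names here are illustrative, and the engine's item carries more fields such as sizeHint.

package main

import (
	"context"
	"fmt"
)

// inflight is a toy version of the shared single-flight item: the leader
// closes loaded after filling response/err, waking all followers at once.
type inflight struct {
	loaded   chan struct{}
	response []byte
	err      error
}

func follow(ctx context.Context, it *inflight) ([]byte, error) {
	select {
	case <-it.loaded:
		return it.response, it.err
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

func main() {
	it := &inflight{loaded: make(chan struct{})}
	go func() {
		it.response = []byte(`{"data":{}}`)
		close(it.loaded) // publish to all followers
	}()
	out, err := follow(context.Background(), it)
	fmt.Println(string(out), err)
}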
- DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{\"code\":\"GRAPHQL_VALIDATION_FAILED\",\"foo\":\"bar\"}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, []byte("{\"code\":\"BAD_USER_INPUT\"}")) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED","foo":"bar"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT"}}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -426,12 +349,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Propagate all extension fields from subgraph errors when allow all option is enabled", testFnSubgraphErrorsWithAllowAllExtensionFields(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{\"code\":\"GRAPHQL_VALIDATION_FAILED\",\"foo\":\"bar\"}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, []byte("{\"code\":\"BAD_USER_INPUT\"}")) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED","foo":"bar"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT"}}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -464,12 +384,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Include datasource name as serviceName extension field", testFnSubgraphErrorsWithExtensionFieldServiceName(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{\"code\":\"GRAPHQL_VALIDATION_FAILED\"}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, []byte("{\"code\":\"BAD_USER_INPUT\"}")) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT"}}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -502,12 +419,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Include datasource name as serviceName when extensions is null", testFnSubgraphErrorsWithExtensionFieldServiceName(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). 
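These cases pin the default propagation policy for subgraph error extensions: only the code field survives into the gateway error, and anything else (foo above) is dropped unless the allow-all option is enabled. A small sketch of that filtering step under those assumptions; this is illustrative, not the engine's implementation:

package main

import (
	"encoding/json"
	"fmt"
)

// filterExtensions keeps only the "code" key, mirroring the default
// propagation policy exercised by these tests; allowAll bypasses the filter.
func filterExtensions(ext map[string]json.RawMessage, allowAll bool) map[string]json.RawMessage {
	if allowAll || ext == nil {
		return ext
	}
	if code, ok := ext["code"]; ok {
		return map[string]json.RawMessage{"code": code}
	}
	return nil
}

func main() {
	var ext map[string]json.RawMessage
	_ = json.Unmarshal([]byte(`{"code":"BAD_USER_INPUT","foo":"bar"}`), &ext)
	out, _ := json.Marshal(filterExtensions(ext, false))
	fmt.Println(string(out)) // {"code":"BAD_USER_INPUT"}
}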
- Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("null")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, []byte("null")) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":null},{"message":"errorMessage2","extensions":null}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -540,12 +454,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Include datasource name as serviceName when extensions is an empty object", testFnSubgraphErrorsWithExtensionFieldServiceName(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, []byte("null")) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{}},{"message":"errorMessage2","extensions":null}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -578,12 +489,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Fallback to default extension code value when no code field was set", testFnSubgraphErrorsWithExtensionDefaultCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{\"code\":\"GRAPHQL_VALIDATION_FAILED\"}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"errorMessage2"}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -616,12 +524,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Fallback to default extension code value when extensions is null", testFnSubgraphErrorsWithExtensionDefaultCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). 
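One more pattern from earlier in this diff that merits a compact illustration: batchEntityTools are pooled, sized by the expected item count on Get, and fully reset before Put so a reused instance never leaks hashes from a previous batch. A generic sketch of that acquire/release discipline with a toy type; clear requires Go 1.21 and stands in for the diff's delete loop:

package main

import (
	"fmt"
	"sync"
)

// tool mirrors batchEntityTools in miniature: the map is cleared on
// release so a pooled instance carries no state between batches.
type tool struct {
	hashToIndex map[uint64]int
}

var pool = sync.Pool{}

func acquire(sizeHint int) *tool {
	if v := pool.Get(); v != nil {
		return v.(*tool)
	}
	return &tool{hashToIndex: make(map[uint64]int, sizeHint)}
}

func release(t *tool) {
	clear(t.hashToIndex) // Go 1.21+; equivalent to the delete loop in the diff
	pool.Put(t)
}

func main() {
	t := acquire(8)
	t.hashToIndex[42] = 0
	release(t)
	t2 := acquire(8)
	fmt.Println(len(t2.hashToIndex)) // 0, whether fresh or reused
}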
- DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("null")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":null},{"message":"errorMessage2"}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -654,12 +559,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Fallback to default extension code value when extensions is an empty object", testFnSubgraphErrorsWithExtensionDefaultCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{}},{"message":"errorMessage2"}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ diff --git a/v2/pkg/engine/resolve/loader_skip_fetch_test.go b/v2/pkg/engine/resolve/loader_skip_fetch_test.go new file mode 100644 index 000000000..aadac1584 --- /dev/null +++ b/v2/pkg/engine/resolve/loader_skip_fetch_test.go @@ -0,0 +1,878 @@ +package resolve + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/wundergraph/astjson" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" +) + +func TestLoader_canSkipFetch(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + info *FetchInfo + items []*astjson.Value + expectSkipFetch bool + }{ + { + name: "single item with Query operation", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"id": "123"}`)), + }, + expectSkipFetch: true, + }, + { + name: "single item with Mutation operation", + info: &FetchInfo{ + OperationType: ast.OperationTypeMutation, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"id": "123"}`)), + }, + expectSkipFetch: false, + }, + { + name: "single item with null type", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{Fields: []*Field{}}, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`null`)), + }, + expectSkipFetch: true, + }, + { + name: "single item with all required data", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Path: []string{"user"}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: 
[]string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &Scalar{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"user": {"id": "123", "name": "John"}}`)), + }, + expectSkipFetch: true, + }, + { + name: "single item missing required field", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Path: []string{"user"}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &Scalar{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"user": {"id": "123"}}`)), // missing "name" + }, + expectSkipFetch: false, + }, + { + name: "single item missing nullable field", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Path: []string{"user"}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("email"), + Value: &Scalar{ + Path: []string{"email"}, + Nullable: true, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"user": {"id": "123"}}`)), // missing nullable "email" + }, + expectSkipFetch: false, + }, + { + name: "single item with null value on required path", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Path: []string{"user"}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"user": {"id": null}}`)), // null value on required field + }, + expectSkipFetch: false, + }, + { + name: "single item with null value on nullable path", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Path: []string{"user"}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("email"), + Value: &Scalar{ + Path: []string{"email"}, + Nullable: true, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"user": {"id": "123", "email": null}}`)), // null value on nullable field + }, + expectSkipFetch: true, + }, + { + name: "multiple items all can be skipped", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"id": "123"}`)), + astjson.MustParseBytes([]byte(`{"id": "456"}`)), + astjson.MustParseBytes([]byte(`{"id": "789"}`)), + }, + expectSkipFetch: true, + }, + { + name: "multiple items some can be skipped", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Path: []string{"user"}, + 
Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &Scalar{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"user": {"id": "123", "name": "John"}}`)), // complete + astjson.MustParseBytes([]byte(`{"user": {"id": "456"}}`)), // missing name + astjson.MustParseBytes([]byte(`{"user": {"id": "789", "name": "Alice"}}`)), // complete + }, + expectSkipFetch: false, + }, + { + name: "multiple items none can be skipped", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Path: []string{"user"}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &Scalar{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"user": {"id": "123"}}`)), // missing name + astjson.MustParseBytes([]byte(`{"user": {"id": "456"}}`)), // missing name + astjson.MustParseBytes([]byte(`{"user": {"id": "789"}}`)), // missing name + }, + expectSkipFetch: false, + }, + { + name: "nullable array that is null", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Path: []string{"user"}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("tags"), + Value: &Array{ + Path: []string{"tags"}, + Nullable: true, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"user": {"id": "123", "tags": null}}`)), + }, + expectSkipFetch: true, + }, + { + name: "nullable array that is empty", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Path: []string{"user"}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("tags"), + Value: &Array{ + Path: []string{"tags"}, + Nullable: true, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"user": {"id": "123", "tags": []}}`)), + }, + expectSkipFetch: true, + }, + { + name: "deeply nested structure", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Path: []string{"user"}, + Nullable: true, + Fields: []*Field{ + { + Name: []byte("account"), + Value: &Object{ + Path: []string{"account"}, + Nullable: true, + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &Scalar{ + Path: []string{"__typename"}, + Nullable: false, + }, + }, + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("info"), + Value: &Object{ + Path: []string{"info"}, + Nullable: true, + Fields: []*Field{ + { + Name: []byte("a"), + Value: &Scalar{ + Path: []string{"a"}, + Nullable: false, + }, + }, + { + Name: []byte("b"), + Value: &Scalar{ + Path: []string{"b"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, 
+ }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{ + "user": { + "account": { + "__typename": "Account", + "id": "123", + "info": { + "a": "valueA", + "b": "valueB" + } + } + } + }`)), + }, + expectSkipFetch: true, + }, + { + name: "nil info", + info: nil, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"id": "123"}`)), + }, + expectSkipFetch: false, + }, + { + name: "nil ProvidesData", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: nil, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"id": "123"}`)), + }, + expectSkipFetch: false, + }, + { + name: "array with scalar items - valid", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("tags"), + Value: &Array{ + Path: []string{"tags"}, + Nullable: false, + Item: &Scalar{ + Path: []string{}, + Nullable: false, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"tags": ["tag1", "tag2", "tag3"]}`)), + }, + expectSkipFetch: true, + }, + { + name: "array with scalar items - invalid (null item in non-nullable array)", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("tags"), + Value: &Array{ + Path: []string{"tags"}, + Nullable: false, + Item: &Scalar{ + Path: []string{}, + Nullable: false, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"tags": ["tag1", null, "tag3"]}`)), // null item in non-nullable array + }, + expectSkipFetch: false, + }, + { + name: "array with scalar items - valid (null item in nullable array)", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("tags"), + Value: &Array{ + Path: []string{"tags"}, + Nullable: false, + Item: &Scalar{ + Path: []string{}, + Nullable: true, // nullable scalar items + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"tags": ["tag1", null, "tag3"]}`)), // null item in nullable array + }, + expectSkipFetch: true, + }, + { + name: "array with object items - valid", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("users"), + Value: &Array{ + Path: []string{"users"}, + Nullable: false, + Item: &Object{ + Path: []string{}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &Scalar{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"users": [{"id": "1", "name": "John"}, {"id": "2", "name": "Jane"}]}`)), + }, + expectSkipFetch: true, + }, + { + name: "array with object items - invalid (missing required field)", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("users"), + Value: &Array{ + Path: []string{"users"}, + Nullable: false, + Item: &Object{ + Path: []string{}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &Scalar{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"users": [{"id": "1", 
"name": "John"}, {"id": "2"}]}`)), // missing "name" field + }, + expectSkipFetch: false, + }, + { + name: "nested arrays - valid", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("matrix"), + Value: &Array{ + Path: []string{"matrix"}, + Nullable: false, + Item: &Array{ + Path: []string{}, + Nullable: false, + Item: &Scalar{ + Path: []string{}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"matrix": [["a", "b"], ["c", "d"], ["e", "f"]]}`)), + }, + expectSkipFetch: true, + }, + { + name: "nested arrays - invalid (null in inner non-nullable array)", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("matrix"), + Value: &Array{ + Path: []string{"matrix"}, + Nullable: false, + Item: &Array{ + Path: []string{}, + Nullable: false, + Item: &Scalar{ + Path: []string{}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"matrix": [["a", "b"], ["c", null], ["e", "f"]]}`)), // null in inner array + }, + expectSkipFetch: false, + }, + { + name: "array of objects with nested arrays - complex valid case", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("groups"), + Value: &Array{ + Path: []string{"groups"}, + Nullable: false, + Item: &Object{ + Path: []string{}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("name"), + Value: &Scalar{ + Path: []string{"name"}, + Nullable: false, + }, + }, + { + Name: []byte("members"), + Value: &Array{ + Path: []string{"members"}, + Nullable: false, + Item: &Object{ + Path: []string{}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"groups": [{"name": "admins", "members": [{"id": "1"}, {"id": "2"}]}, {"name": "users", "members": [{"id": "3"}]}]}`)), + }, + expectSkipFetch: true, + }, + { + name: "array of objects with nested arrays - complex invalid case", + info: &FetchInfo{ + OperationType: ast.OperationTypeQuery, + ProvidesData: &Object{ + Fields: []*Field{ + { + Name: []byte("groups"), + Value: &Array{ + Path: []string{"groups"}, + Nullable: false, + Item: &Object{ + Path: []string{}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("name"), + Value: &Scalar{ + Path: []string{"name"}, + Nullable: false, + }, + }, + { + Name: []byte("members"), + Value: &Array{ + Path: []string{"members"}, + Nullable: false, + Item: &Object{ + Path: []string{}, + Nullable: false, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Scalar{ + Path: []string{"id"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + items: []*astjson.Value{ + astjson.MustParseBytes([]byte(`{"groups": [{"name": "admins", "members": [{"id": "1"}, {}]}, {"name": "users", "members": [{"id": "3"}]}]}`)), // missing id in one member + }, + expectSkipFetch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + loader := &Loader{} + + // Make a copy of items to avoid mutation affecting the test data + itemsCopy := make([]*astjson.Value, len(tt.items)) + copy(itemsCopy, tt.items) + + // Create cache keys with Item set to the corresponding test items + 
cacheKeys := make([]*CacheKey, len(itemsCopy)) + for i, item := range itemsCopy { + cacheKeys[i] = &CacheKey{ + FromCache: item, + } + } + + // Create a result struct for canSkipFetch + res := &result{ + cacheKeys: cacheKeys, + } + + canSkipFetch := loader.canSkipFetch(tt.info, res) + assert.Equal(t, tt.expectSkipFetch, canSkipFetch, "skip fetch") + }) + } +} diff --git a/v2/pkg/engine/resolve/loader_test.go b/v2/pkg/engine/resolve/loader_test.go index 4ed83d444..f88d7227f 100644 --- a/v2/pkg/engine/resolve/loader_test.go +++ b/v2/pkg/engine/resolve/loader_test.go @@ -19,19 +19,19 @@ func TestLoader_LoadGraphQLResponseData(t *testing.T) { ctrl := gomock.NewController(t) productsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`, - `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}`) + `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`) reviewsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`, - `{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}`) + `{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}}`) stockService := mockedDS(t, ctrl, `{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`, - `{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}`) + `{"data":{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}}`) usersService := mockedDS(t, ctrl, `{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"2"}]}}}`, - `{"_entities":[{"name":"user-1"},{"name":"user-2"}]}`) + `{"data":{"_entities":[{"name":"user-1"},{"name":"user-2"}]}}`) response := &GraphQLResponse{ Fetches: Sequence( Single(&SingleFetch{ @@ -287,7 +287,7 @@ func TestLoader_LoadGraphQLResponseData(t *testing.T) { ctx := &Context{ ctx: context.Background(), } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) assert.NoError(t, err) @@ -296,7 +296,7 @@ func TestLoader_LoadGraphQLResponseData(t *testing.T) { ctrl.Finish() out := fastjsonext.PrintGraphQLResponse(resolvable.data, resolvable.errors) assert.NoError(t, err) - expected := `{"errors":[],"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` + expected := `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` assert.Equal(t, expected, out) } @@ -376,7 +376,7 @@ func TestLoader_MergeErrorDifferingTypes(t *testing.T) { ctx := &Context{ ctx: context.Background(), } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) assert.NoError(t, err) @@ -467,7 +467,7 @@ func TestLoader_MergeErrorDifferingArrayLength(t *testing.T) { ctx := &Context{ ctx: context.Background(), } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) assert.NoError(t, err) @@ -480,19 +480,19 @@ func TestLoader_LoadGraphQLResponseDataWithExtensions(t *testing.T) { ctrl := gomock.NewController(t) productsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}","extensions":{"foo":"bar"}}}`, - `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}`) + `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`) reviewsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: 
$representations){__typename ... on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]},"extensions":{"foo":"bar"}}}`, - `{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}`) + `{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}}`) stockService := mockedDS(t, ctrl, `{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]},"extensions":{"foo":"bar"}}}`, - `{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}`) + `{"data":{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}}`) usersService := mockedDS(t, ctrl, `{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"2"}]},"extensions":{"foo":"bar"}}}`, - `{"_entities":[{"name":"user-1"},{"name":"user-2"}]}`) + `{"data":{"_entities":[{"name":"user-1"},{"name":"user-2"}]}}`) response := &GraphQLResponse{ Fetches: Sequence( Single(&SingleFetch{ @@ -749,7 +749,7 @@ func TestLoader_LoadGraphQLResponseDataWithExtensions(t *testing.T) { ctx: context.Background(), Extensions: []byte(`{"foo":"bar"}`), } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) assert.NoError(t, err) @@ -758,7 +758,7 @@ func TestLoader_LoadGraphQLResponseDataWithExtensions(t *testing.T) { ctrl.Finish() out := fastjsonext.PrintGraphQLResponse(resolvable.data, resolvable.errors) assert.NoError(t, err) - expected := `{"errors":[],"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` + expected := `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other 
Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` assert.Equal(t, expected, out) } @@ -1024,9 +1024,9 @@ func BenchmarkLoader_LoadGraphQLResponseData(b *testing.B) { ctx := &Context{ ctx: context.Background(), } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} - expected := `{"errors":[],"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` + expected := `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` b.SetBytes(int64(len(expected))) b.ReportAllocs() b.ResetTimer() @@ -1054,7 +1054,7 @@ func TestLoader_RedactHeaders(t *testing.T) { productsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://products","header":{"Authorization":"value"},"body":{"query":"query{topProducts{name __typename upc}}"},"__trace__":true}`, - `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}`) + `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`) response := &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -1125,7 +1125,7 @@ func TestLoader_RedactHeaders(t *testing.T) { Enable: true, }, } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) @@ -1153,19 +1153,19 @@ func TestLoader_InvalidBatchItemCount(t *testing.T) { ctrl := gomock.NewController(t) productsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`, - `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}`) + `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`) reviewsService := 
mockedDS(t, ctrl, `{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`, - `{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}`) + `{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}}`) stockService := mockedDS(t, ctrl, `{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`, - `{"_entities":[{"stock":8},{"stock":2}]}`) // 3 items expected, 2 returned + `{"data":{"_entities":[{"stock":8},{"stock":2}]}}`) // 3 items expected, 2 returned usersService := mockedDS(t, ctrl, `{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"2"}]}}}`, - `{"_entities":[{"name":"user-1"},{"name":"user-2"},{"name":"user-3"}]}`) // 2 items expected, 3 returned + `{"data":{"_entities":[{"name":"user-1"},{"name":"user-2"},{"name":"user-3"}]}}`) // 2 items expected, 3 returned response := &GraphQLResponse{ Fetches: Sequence( Single(&SingleFetch{ @@ -1421,7 +1421,7 @@ func TestLoader_InvalidBatchItemCount(t *testing.T) { ctx := &Context{ ctx: context.Background(), } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) assert.NoError(t, err) @@ -1521,13 +1521,13 @@ func TestRewriteErrorPaths(t *testing.T) { for i, inputError := range tc.inputErrors { // Create a copy by marshaling and parsing again data := inputError.MarshalTo(nil) - value, err := astjson.ParseBytesWithoutCache(data) + value, err := astjson.ParseBytesWithArena(nil, data) assert.NoError(t, err, "Failed to copy input error") values[i] = value } // Call the function under test - rewriteErrorPaths(fetchItem, values) + rewriteErrorPaths(nil, fetchItem, values) // Compare the results assert.Equal(t, len(tc.expectedErrors), len(values), diff --git a/v2/pkg/engine/resolve/resolvable.go b/v2/pkg/engine/resolve/resolvable.go index 5219c910d..226705a70 100644 --- a/v2/pkg/engine/resolve/resolvable.go +++ b/v2/pkg/engine/resolve/resolvable.go @@ -13,6 +13,7 @@ import ( "github.com/tidwall/gjson" "github.com/wundergraph/astjson" + "github.com/wundergraph/go-arena" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/errorcodes" @@ -30,8 +31,9 @@ type Resolvable struct { errors *astjson.Value valueCompletion *astjson.Value skipAddingNullErrors bool - - astjsonArena *astjson.Arena + // astjsonArena is the arena to handle json, supplied by Resolver + // not thread safe, but Resolvable is single threaded anyways + astjsonArena arena.Arena parsers []*astjson.Parser print bool @@ -67,13 +69,13 @@ type ResolvableOptions struct { ApolloCompatibilityReplaceInvalidVarError bool } -func NewResolvable(options ResolvableOptions) *Resolvable { +func NewResolvable(a arena.Arena, options ResolvableOptions) *Resolvable { return &Resolvable{ options: options, xxh: xxhash.New(), authorizationAllow: make(map[uint64]struct{}), authorizationDeny: make(map[uint64]string), - astjsonArena: &astjson.Arena{}, + astjsonArena: a, } } @@ -95,7 +97,7 @@ func (r *Resolvable) Reset() { r.operationType = ast.OperationTypeUnknown r.renameTypeNames = r.renameTypeNames[:0] r.authorizationError = nil - r.astjsonArena.Reset() + r.astjsonArena = nil r.xxh.Reset() for k := range r.authorizationAllow { delete(r.authorizationAllow, k) @@ -109,14 +111,15 @@ func (r *Resolvable) Init(ctx *Context, initialData []byte, operationType ast.Op r.ctx = ctx r.operationType = operationType r.renameTypeNames = ctx.RenameTypeNames - r.data = r.astjsonArena.NewObject() - r.errors = r.astjsonArena.NewArray() + r.data = astjson.ObjectValue(r.astjsonArena) + // don't init errors! 
It will heavily increase memory usage + r.errors = nil if initialData != nil { - initialValue, err := astjson.ParseBytesWithoutCache(initialData) + initialValue, err := astjson.ParseBytesWithArena(r.astjsonArena, initialData) if err != nil { return err } - r.data, _, err = astjson.MergeValues(r.data, initialValue) + r.data, _, err = astjson.MergeValues(r.astjsonArena, r.data, initialValue) if err != nil { return err } @@ -128,20 +131,22 @@ func (r *Resolvable) InitSubscription(ctx *Context, initialData []byte, postProc r.ctx = ctx r.operationType = ast.OperationTypeSubscription r.renameTypeNames = ctx.RenameTypeNames + // don't init errors! It will heavily increase memory usage + r.errors = nil if initialData != nil { - initialValue, err := astjson.ParseBytesWithoutCache(initialData) + initialValue, err := astjson.ParseBytesWithArena(r.astjsonArena, initialData) if err != nil { return err } if postProcessing.SelectResponseDataPath == nil { - r.data, _, err = astjson.MergeValuesWithPath(r.data, initialValue, postProcessing.MergePath...) + r.data, _, err = astjson.MergeValuesWithPath(r.astjsonArena, r.data, initialValue, postProcessing.MergePath...) if err != nil { return err } } else { selectedInitialValue := initialValue.Get(postProcessing.SelectResponseDataPath...) if selectedInitialValue != nil { - r.data, _, err = astjson.MergeValuesWithPath(r.data, selectedInitialValue, postProcessing.MergePath...) + r.data, _, err = astjson.MergeValuesWithPath(r.astjsonArena, r.data, selectedInitialValue, postProcessing.MergePath...) if err != nil { return err } @@ -155,10 +160,7 @@ func (r *Resolvable) InitSubscription(ctx *Context, initialData []byte, postProc } } if r.data == nil { - r.data = r.astjsonArena.NewObject() - } - if r.errors == nil { - r.errors = r.astjsonArena.NewArray() + r.data = astjson.ObjectValue(r.astjsonArena) } return } @@ -168,7 +170,8 @@ func (r *Resolvable) ResolveNode(node Node, data *astjson.Value, out io.Writer) r.print = false r.printErr = nil r.authorizationError = nil - r.errors = r.astjsonArena.NewArray() + // don't init errors! It will heavily increase memory usage + r.errors = nil hasErrors := r.walkNode(node, data) if hasErrors { @@ -234,6 +237,13 @@ func (r *Resolvable) Resolve(ctx context.Context, rootData *Object, fetchTree *F return r.printErr } +// ensureErrorsInitialized is used to lazily init r.errors if needed +func (r *Resolvable) ensureErrorsInitialized() { + if r.errors == nil { + r.errors = astjson.ArrayValue(r.astjsonArena) + } +} + func (r *Resolvable) enclosingTypeName() string { if len(r.enclosingTypeNames) > 0 { return r.enclosingTypeNames[len(r.enclosingTypeNames)-1] @@ -464,7 +474,7 @@ func (r *Resolvable) renderScalarFieldValue(value *astjson.Value, nullable bool) // renderScalarFieldString - is used when value require some pre-processing, e.g. 
unescaping or custom rendering func (r *Resolvable) renderScalarFieldBytes(data []byte, nullable bool) { - value, err := astjson.ParseBytesWithoutCache(data) + value, err := astjson.ParseBytesWithArena(r.astjsonArena, data) if err != nil { r.printErr = err return @@ -760,6 +770,7 @@ func (r *Resolvable) addRejectFieldError(reason string, ds DataSourceInfo, field } r.ctx.appendSubgraphErrors(errors.New(errorMessage), NewSubgraphError(ds, fieldPath, reason, 0)) + r.ensureErrorsInitialized() fastjsonext.AppendErrorWithExtensionsCodeToArray(r.astjsonArena, r.errors, errorMessage, errorcodes.UnauthorizedFieldOrType, r.path) r.popNodePathElement(nodePath) } @@ -853,7 +864,7 @@ func (r *Resolvable) walkArray(arr *Array, value *astjson.Value) bool { r.popArrayPathElement() if err { if arr.Item.NodeKind() == NodeKindObject && arr.Item.NodeNullable() { - value.SetArrayItem(i, astjson.NullValue) + value.SetArrayItem(r.astjsonArena, i, astjson.NullValue) continue } if arr.Nullable { @@ -1201,6 +1212,7 @@ func (r *Resolvable) addNonNullableFieldError(fieldPath []string, parent *astjso r.addValueCompletion(r.renderApolloCompatibleNonNullableErrorMessage(), errorcodes.InvalidGraphql) } else { errorMessage := fmt.Sprintf("Cannot return null for non-nullable field '%s'.", r.renderFieldPath()) + r.ensureErrorsInitialized() fastjsonext.AppendErrorToArray(r.astjsonArena, r.errors, errorMessage, r.path) } r.popNodePathElement(fieldPath) @@ -1271,30 +1283,33 @@ func (r *Resolvable) renderFieldCoordinates() string { func (r *Resolvable) addError(message string, fieldPath []string) { r.pushNodePathElement(fieldPath) + r.ensureErrorsInitialized() fastjsonext.AppendErrorToArray(r.astjsonArena, r.errors, message, r.path) r.popNodePathElement(fieldPath) } func (r *Resolvable) addErrorWithCode(message, code string) { + r.ensureErrorsInitialized() fastjsonext.AppendErrorWithExtensionsCodeToArray(r.astjsonArena, r.errors, message, code, r.path) } func (r *Resolvable) addErrorWithCodeAndPath(message, code string, fieldPath []string) { r.pushNodePathElement(fieldPath) + r.ensureErrorsInitialized() fastjsonext.AppendErrorWithExtensionsCodeToArray(r.astjsonArena, r.errors, message, code, r.path) r.popNodePathElement(fieldPath) } func (r *Resolvable) addValueCompletion(message, code string) { if r.valueCompletion == nil { - r.valueCompletion = r.astjsonArena.NewArray() + r.valueCompletion = astjson.ArrayValue(r.astjsonArena) } fastjsonext.AppendErrorWithExtensionsCodeToArray(r.astjsonArena, r.valueCompletion, message, code, r.path) } func (r *Resolvable) addValueCompletionWithPath(message, code string, fieldPath []string) { if r.valueCompletion == nil { - r.valueCompletion = r.astjsonArena.NewArray() + r.valueCompletion = astjson.ArrayValue(r.astjsonArena) } r.pushNodePathElement(fieldPath) fastjsonext.AppendErrorWithExtensionsCodeToArray(r.astjsonArena, r.valueCompletion, message, code, r.path) diff --git a/v2/pkg/engine/resolve/resolvable_custom_field_renderer_test.go b/v2/pkg/engine/resolve/resolvable_custom_field_renderer_test.go index 843c6e696..0dbb0394b 100644 --- a/v2/pkg/engine/resolve/resolvable_custom_field_renderer_test.go +++ b/v2/pkg/engine/resolve/resolvable_custom_field_renderer_test.go @@ -440,7 +440,7 @@ func TestResolvable_CustomFieldRenderer(t *testing.T) { t.Parallel() // Setup - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{} var input []byte @@ -543,7 +543,7 @@ func TestResolvable_CustomFieldRenderer(t *testing.T) { t.Parallel() input := 
[]byte(tc.input) - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{} err := res.Init(ctx, input, ast.OperationTypeQuery) assert.NoError(t, err) diff --git a/v2/pkg/engine/resolve/resolvable_test.go b/v2/pkg/engine/resolve/resolvable_test.go index 4b92f8591..aea4e78ef 100644 --- a/v2/pkg/engine/resolve/resolvable_test.go +++ b/v2/pkg/engine/resolve/resolvable_test.go @@ -12,7 +12,7 @@ import ( func TestResolvable_Resolve(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -84,7 +84,7 @@ func TestResolvable_Resolve(t *testing.T) { func TestResolvable_ResolveWithTypeMismatch(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":true}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -157,7 +157,7 @@ func TestResolvable_ResolveWithTypeMismatch(t *testing.T) { func TestResolvable_ResolveWithErrorBubbleUp(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -231,7 +231,7 @@ func TestResolvable_ResolveWithErrorBubbleUp(t *testing.T) { func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { t.Run("Non-nullable root field", func(t *testing.T) { topProducts := `{"topProducts":null}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -258,7 +258,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { }) t.Run("Non-Nullable root field and nested field", func(t *testing.T) { topProducts := 
`{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -333,7 +333,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { }) t.Run("Nullable root field and non-Nullable nested field", func(t *testing.T) { topProducts := `{"topProduct":{"name":null}}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -370,7 +370,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { }) t.Run("Non-Nullable sibling field", func(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","reviews":[{"author":{"__typename":"User","name":"Bob"},"body":null}]}]}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -439,7 +439,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { }) t.Run("Non-nullable array and array item", func(t *testing.T) { topProducts := `{"topProducts":[null]}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -469,7 +469,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { }) t.Run("Nullable array and non-nullable array item", func(t *testing.T) { topProducts := `{"topProducts":[null]}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -500,7 +500,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { }) t.Run("Non-Nullable array, array item, and array item field", func(t *testing.T) { topProducts := `{"topProducts":[{"author":{"name":"Name"}},{"author":null}]}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -549,7 +549,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { func TestResolvable_ResolveWithErrorBubbleUpUntilData(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, 
ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -622,7 +622,7 @@ func TestResolvable_ResolveWithErrorBubbleUpUntilData(t *testing.T) { func TestResolvable_InvalidEnumValues(t *testing.T) { t.Run("Invalid enum value", func(t *testing.T) { enum := `{"enum":"B"}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -653,7 +653,7 @@ func TestResolvable_InvalidEnumValues(t *testing.T) { t.Run("Inaccessible enum value", func(t *testing.T) { enum := `{"enum":"B"}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -686,7 +686,7 @@ func TestResolvable_InvalidEnumValues(t *testing.T) { t.Run("Invalid enum value with value completion Apollo compatibility flag", func(t *testing.T) { enum := `{"enum":"B"}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -719,7 +719,7 @@ func TestResolvable_InvalidEnumValues(t *testing.T) { t.Run("Inaccessible enum value with value completion Apollo compatibility flag", func(t *testing.T) { enum := `{"enum":"B"}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -755,7 +755,7 @@ func TestResolvable_InvalidEnumValues(t *testing.T) { func BenchmarkResolvable_Resolve(b *testing.B) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -838,7 +838,7 @@ func BenchmarkResolvable_Resolve(b *testing.B) { func BenchmarkResolvable_ResolveWithErrorBubbleUp(b *testing.B) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -923,7 +923,7 @@ func BenchmarkResolvable_ResolveWithErrorBubbleUp(b *testing.B) { } func TestResolvable_WithTracingNotStarted(t *testing.T) { - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) // Do not start a trace with SetTraceStart(), but request it to be output ctx := NewContext(context.Background()) ctx.TracingOptions.Enable = true @@ -950,7 +950,7 @@ func TestResolvable_WithTracingNotStarted(t *testing.T) { func 
TestResolveFloat(t *testing.T) { t.Run("default behaviour", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := NewContext(context.Background()) err := res.Init(ctx, []byte(`{"f":1.0}`), ast.OperationTypeQuery) assert.NoError(t, err) @@ -972,7 +972,7 @@ func TestResolveFloat(t *testing.T) { assert.Equal(t, `{"data":{"f":1.0}}`, out.String()) }) t.Run("invalid float", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := NewContext(context.Background()) err := res.Init(ctx, []byte(`{"f":false}`), ast.OperationTypeQuery) assert.NoError(t, err) @@ -994,7 +994,7 @@ func TestResolveFloat(t *testing.T) { assert.Equal(t, `{"errors":[{"message":"Float cannot represent non-float value: \"false\"","path":["f"]}],"data":null}`, out.String()) }) t.Run("truncate float", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityTruncateFloatValues: true, }) ctx := NewContext(context.Background()) @@ -1018,7 +1018,7 @@ func TestResolveFloat(t *testing.T) { assert.Equal(t, `{"data":{"f":1}}`, out.String()) }) t.Run("truncate float with decimal place", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityTruncateFloatValues: true, }) ctx := NewContext(context.Background()) @@ -1045,7 +1045,7 @@ func TestResolveFloat(t *testing.T) { func TestResolvable_ValueCompletion(t *testing.T) { t.Run("nested object", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := NewContext(context.Background()) @@ -1143,7 +1143,7 @@ func TestResolvable_ValueCompletion(t *testing.T) { }`) t.Run("nullable", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := NewContext(context.Background()) @@ -1241,7 +1241,7 @@ func TestResolvable_ValueCompletion(t *testing.T) { }) t.Run("mixed nullability", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := NewContext(context.Background()) @@ -1342,7 +1342,7 @@ func TestResolvable_ValueCompletion(t *testing.T) { func TestResolvable_WithTracing(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) background := SetTraceStart(context.Background(), true) ctx := NewContext(background) ctx.TracingOptions.Enable = true diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 14d8ad4b5..00bef245e 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -5,14 +5,18 @@ 
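
Note: the resolvable.go changes below replace the eagerly allocated errors array with lazy initialization via ensureErrorsInitialized, since most responses carry no errors at all. A minimal, runnable stand-in for that pattern (plain Go slices instead of the astjson arena values the actual code uses):

package main

import "fmt"

type resolvable struct {
	errors []string // stand-in for the astjson errors array
}

// ensureErrorsInitialized mirrors the lazy-init helper in the diff:
// the happy path never allocates, because the array is only created
// when the first error is appended.
func (r *resolvable) ensureErrorsInitialized() {
	if r.errors == nil {
		r.errors = make([]string, 0, 4)
	}
}

func (r *resolvable) addError(msg string) {
	r.ensureErrorsInitialized()
	r.errors = append(r.errors, msg)
}

func main() {
	r := &resolvable{}
	fmt.Println(r.errors == nil) // true: no allocation until the first error
	r.addError("Cannot return null for non-nullable field 'name'.")
	fmt.Println(r.errors)
}
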
package resolve import ( "bytes" "context" + "encoding/binary" "fmt" "io" + "net/http" "time" "github.com/buger/jsonparser" "github.com/pkg/errors" "go.uber.org/atomic" + "github.com/wundergraph/go-arena" + "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/xcontext" "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) @@ -69,6 +73,20 @@ type Resolver struct { heartbeatInterval time.Duration // maxSubscriptionFetchTimeout defines the maximum time a subscription fetch can take before it is considered timed out maxSubscriptionFetchTimeout time.Duration + + // resolveArenaPool is the arena pool dedicated for Loader & Resolvable + // ArenaPool automatically adjusts arena buffer sizes per workload + // resolving & response buffering are very different tasks + // as such, it was best to have two arena pools in terms of memory usage + // A single pool for both was much less efficient + resolveArenaPool *ArenaPool + // responseBufferPool is the arena pool dedicated for response buffering before sending to the client + responseBufferPool *ArenaPool + + // subgraphRequestSingleFlight is used to de-duplicate subgraph requests + subgraphRequestSingleFlight *SubgraphRequestSingleFlight + // inboundRequestSingleFlight is used to de-duplicate inbound client requests + inboundRequestSingleFlight *InboundRequestSingleFlight } func (r *Resolver) SetAsyncErrorWriter(w AsyncErrorWriter) { @@ -167,6 +185,8 @@ type ResolverOptions struct { PropagateFetchReasons bool ValidateRequiredExternalFields bool + + Caches map[string]LoaderCache } // New returns a new Resolver. ctx.Done() is used to cancel all active subscriptions and streams. @@ -222,6 +242,10 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { allowedErrorFields: allowedErrorFields, heartbeatInterval: options.SubscriptionHeartbeatInterval, maxSubscriptionFetchTimeout: options.MaxSubscriptionFetchTimeout, + resolveArenaPool: NewArenaPool(), + responseBufferPool: NewArenaPool(), + subgraphRequestSingleFlight: NewSingleFlight(8), + inboundRequestSingleFlight: NewRequestSingleFlight(8), } resolver.maxConcurrency = make(chan struct{}, options.MaxConcurrency) for i := 0; i < options.MaxConcurrency; i++ { @@ -233,9 +257,9 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { return resolver } -func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{}, allowedErrorFields map[string]struct{}) *tools { +func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{}, allowedErrorFields map[string]struct{}, sf *SubgraphRequestSingleFlight, a arena.Arena) *tools { return &tools{ - resolvable: NewResolvable(options.ResolvableOptions), + resolvable: NewResolvable(a, options.ResolvableOptions), loader: &Loader{ propagateSubgraphErrors: options.PropagateSubgraphErrors, propagateSubgraphStatusCodes: options.PropagateSubgraphStatusCodes, @@ -251,6 +275,9 @@ func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{ apolloRouterCompatibilitySubrequestHTTPError: options.ApolloRouterCompatibilitySubrequestHTTPError, propagateFetchReasons: options.PropagateFetchReasons, validateRequiredExternalFields: options.ValidateRequiredExternalFields, + sf: sf, + jsonArena: a, + caches: options.Caches, }, } } @@ -269,7 +296,7 @@ func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLRespons r.maxConcurrency <- struct{}{} }() - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields) + t := newTools(r.options, r.allowedErrorExtensionFields, 
r.allowedErrorFields, r.subgraphRequestSingleFlight, nil) err := t.resolvable.Init(ctx, data, response.Info.OperationType) if err != nil { @@ -291,6 +318,72 @@ func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLRespons return resp, err } +func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLResponse, writer io.Writer) (*GraphQLResolveInfo, error) { + resp := &GraphQLResolveInfo{} + + inflight, err := r.inboundRequestSingleFlight.GetOrCreate(ctx, response) + if err != nil { + return nil, err + } + + if inflight != nil && inflight.Data != nil { // follower + _, err = writer.Write(inflight.Data) + return resp, err + } + + start := time.Now() + <-r.maxConcurrency + resp.ResolveAcquireWaitTime = time.Since(start) + defer func() { + r.maxConcurrency <- struct{}{} + }() + + resolveArena := r.resolveArenaPool.Acquire(ctx.Request.ID) + // we're intentionally not using defer Release to have more control over the timing (see below) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.subgraphRequestSingleFlight, resolveArena.Arena) + + err = t.resolvable.Init(ctx, nil, response.Info.OperationType) + if err != nil { + r.inboundRequestSingleFlight.FinishErr(inflight, err) + r.resolveArenaPool.Release(resolveArena) + return nil, err + } + + if !ctx.ExecutionOptions.SkipLoader { + err = t.loader.LoadGraphQLResponseData(ctx, response, t.resolvable) + if err != nil { + r.inboundRequestSingleFlight.FinishErr(inflight, err) + r.resolveArenaPool.Release(resolveArena) + return nil, err + } + } + + // only when loading is done, acquire an arena for the response buffer + responseArena := r.responseBufferPool.Acquire(ctx.Request.ID) + buf := arena.NewArenaBuffer(responseArena.Arena) + err = t.resolvable.Resolve(ctx.ctx, response.Data, response.Fetches, buf) + if err != nil { + r.inboundRequestSingleFlight.FinishErr(inflight, err) + r.resolveArenaPool.Release(resolveArena) + r.responseBufferPool.Release(responseArena) + return nil, err + } + + // first release resolveArena + // all data is resolved and written into the response arena + r.resolveArenaPool.Release(resolveArena) + // next we write back to the client + // this includes flushing and syscalls + // as such, it can take some time + // which is why we split the arenas and released the first one + _, err = writer.Write(buf.Bytes()) + r.inboundRequestSingleFlight.FinishOk(inflight, buf.Bytes()) + // all data is written to the client + // we're safe to release our buffer + r.responseBufferPool.Release(responseArena) + return resp, err +} + type trigger struct { id uint64 cancel context.CancelFunc @@ -421,9 +514,11 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar input := make([]byte, len(sharedInput)) copy(input, sharedInput) - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields) + resolveArena := r.resolveArenaPool.Acquire(resolveCtx.Request.ID) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.subgraphRequestSingleFlight, resolveArena.Arena) if err := t.resolvable.InitSubscription(resolveCtx, input, sub.resolve.Trigger.PostProcessing); err != nil { + r.resolveArenaPool.Release(resolveArena) r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:init:failed:%d\n", sub.id.SubscriptionID) @@ -435,6 +530,7 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar } if err := 
t.loader.LoadGraphQLResponseData(resolveCtx, sub.resolve.Response, t.resolvable); err != nil { + r.resolveArenaPool.Release(resolveArena) r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:load:failed:%d\n", sub.id.SubscriptionID) @@ -446,6 +542,7 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar } if err := t.resolvable.Resolve(resolveCtx.ctx, sub.resolve.Response.Data, sub.resolve.Response.Fetches, sub.writer); err != nil { + r.resolveArenaPool.Release(resolveArena) r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:resolve:failed:%d\n", sub.id.SubscriptionID) @@ -456,6 +553,8 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar return } + r.resolveArenaPool.Release(resolveArena) + if err := sub.writer.Flush(); err != nil { // If flush fails (e.g. client disconnected), remove the subscription. _ = r.AsyncUnsubscribeSubscription(sub.id) @@ -656,9 +755,9 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) fmt.Printf("resolver:trigger:start:%d\n", triggerID) } if asyncDataSource != nil { - err = asyncDataSource.AsyncStart(cloneCtx, triggerID, add.input, updater) + err = asyncDataSource.AsyncStart(cloneCtx, triggerID, add.headers, add.input, updater) } else { - err = add.resolve.Trigger.Source.Start(cloneCtx, add.input, updater) + err = add.resolve.Trigger.Source.Start(cloneCtx, add.headers, add.input, updater) } if err != nil { if r.options.Debug { @@ -1001,6 +1100,24 @@ func (r *Resolver) AsyncUnsubscribeClient(connectionID int64) error { return nil } +// prepareTrigger safely gets the headers for the trigger Subgraph and computes the hash across headers and input + the generated hash is the unique triggerID + the headers must be forwarded to the DataSource to create the trigger +func (r *Resolver) prepareTrigger(ctx *Context, sourceName string, input []byte) (headers http.Header, triggerID uint64) { + if ctx.SubgraphHeadersBuilder != nil { + header, headerHash := ctx.SubgraphHeadersBuilder.HeadersForSubgraph(sourceName) + keyGen := pool.Hash64.Get() + _, _ = keyGen.Write(input) + var b [8]byte + binary.LittleEndian.PutUint64(b[:], headerHash) + _, _ = keyGen.Write(b[:]) + triggerID = keyGen.Sum64() + pool.Hash64.Put(keyGen) + return header, triggerID + } + return nil, 0 +} + func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQLSubscription, writer SubscriptionResponseWriter) error { if subscription.Trigger.Source == nil { return errors.New("no data source found") @@ -1014,7 +1131,7 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ // If SkipLoader is enabled, we skip retrieving actual data. For example, this is useful when requesting a query plan. // By returning early, we avoid starting a subscription and resolve with empty data instead. 
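
Note: prepareTrigger above derives the trigger ID by hashing the subscription input followed by the little-endian encoding of the per-subgraph header hash, so identical inputs with different forwarded headers get distinct triggers. A runnable sketch of that derivation, assuming the pooled pool.Hash64 is xxhash-backed (the package already uses github.com/cespare/xxhash/v2 in resolvable.go):

package main

import (
	"encoding/binary"
	"fmt"

	"github.com/cespare/xxhash/v2"
)

// triggerID mirrors the hash composition in prepareTrigger:
// input bytes first, then the 8-byte header hash.
func triggerID(input []byte, headerHash uint64) uint64 {
	h := xxhash.New()
	_, _ = h.Write(input)
	var b [8]byte
	binary.LittleEndian.PutUint64(b[:], headerHash)
	_, _ = h.Write(b[:])
	return h.Sum64()
}

func main() {
	in := []byte(`{"body":{"query":"subscription{counter}"}}`)
	fmt.Println(triggerID(in, 1) != triggerID(in, 2)) // true: headers disambiguate
}
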
if ctx.ExecutionOptions.SkipLoader { - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.subgraphRequestSingleFlight, nil) err = t.resolvable.InitSubscription(ctx, nil, subscription.Trigger.PostProcessing) if err != nil { @@ -1038,20 +1155,13 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ return nil } - xxh := pool.Hash64.Get() - defer pool.Hash64.Put(xxh) - err = subscription.Trigger.Source.UniqueRequestID(ctx, input, xxh) - if err != nil { - msg := []byte(`{"errors":[{"message":"unable to resolve"}]}`) - return writeFlushComplete(writer, msg) - } - uniqueID := xxh.Sum64() + headers, triggerID := r.prepareTrigger(ctx, subscription.Trigger.SourceName, input) id := SubscriptionIdentifier{ ConnectionID: ConnectionIDs.Inc(), SubscriptionID: 0, } if r.options.Debug { - fmt.Printf("resolver:trigger:subscribe:sync:%d:%d\n", uniqueID, id.SubscriptionID) + fmt.Printf("resolver:trigger:subscribe:sync:%d:%d\n", triggerID, id.SubscriptionID) } completed := make(chan struct{}) @@ -1061,15 +1171,17 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ // Stop processing if the resolver is shutting down return r.ctx.Err() case r.events <- subscriptionEvent{ - triggerID: uniqueID, + triggerID: triggerID, kind: subscriptionEventKindAddSubscription, addSubscription: &addSubscription{ - ctx: ctx, - input: input, - resolve: subscription, - writer: writer, - id: id, - completed: completed, + ctx: ctx, + input: input, + resolve: subscription, + writer: writer, + id: id, + completed: completed, + sourceName: subscription.Trigger.SourceName, + headers: headers, }, }: } @@ -1096,13 +1208,13 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ } if r.options.Debug { - fmt.Printf("resolver:trigger:unsubscribe:sync:%d:%d\n", uniqueID, id.SubscriptionID) + fmt.Printf("resolver:trigger:unsubscribe:sync:%d:%d\n", triggerID, id.SubscriptionID) } // Remove the subscription when the client disconnects. r.events <- subscriptionEvent{ - triggerID: uniqueID, + triggerID: triggerID, kind: subscriptionEventKindRemoveSubscription, id: id, } @@ -1123,7 +1235,7 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G // If SkipLoader is enabled, we skip retrieving actual data. For example, this is useful when requesting a query plan. // By returning early, we avoid starting a subscription and resolve with empty data instead. 
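
Note: ArenaResolveGraphQLResponse above is built around two separate arena pools: the resolve arena is released as soon as the response bytes live in a buffer from the second pool, so resolver scratch memory is not held across the potentially slow, syscall-heavy client write. A rough, runnable analogy using sync.Pool in place of the go-arena ArenaPool:

package main

import (
	"bytes"
	"fmt"
	"sync"
)

// Two pools, mirroring resolveArenaPool / responseBufferPool: resolving and
// response buffering have different allocation profiles, so each pool can
// adapt to its own workload.
var (
	resolvePool  = sync.Pool{New: func() any { return new(bytes.Buffer) }}
	responsePool = sync.Pool{New: func() any { return new(bytes.Buffer) }}
)

func handle(write func([]byte)) {
	scratch := resolvePool.Get().(*bytes.Buffer)
	scratch.Reset()
	scratch.WriteString(`{"data":{}}`) // stand-in for load + resolve

	out := responsePool.Get().(*bytes.Buffer)
	out.Reset()
	out.Write(scratch.Bytes())
	resolvePool.Put(scratch) // release resolve memory before the slow client write

	write(out.Bytes())
	responsePool.Put(out) // only after the write completes
}

func main() {
	handle(func(b []byte) { fmt.Println(string(b)) })
}
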
if ctx.ExecutionOptions.SkipLoader { - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.subgraphRequestSingleFlight, nil) err = t.resolvable.InitSubscription(ctx, nil, subscription.Trigger.PostProcessing) if err != nil { @@ -1147,13 +1259,7 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G return nil } - xxh := pool.Hash64.Get() - defer pool.Hash64.Put(xxh) - err = subscription.Trigger.Source.UniqueRequestID(ctx, input, xxh) - if err != nil { - msg := []byte(`{"errors":[{"message":"unable to resolve"}]}`) - return writeFlushComplete(writer, msg) - } + headers, triggerID := r.prepareTrigger(ctx, subscription.Trigger.SourceName, input) select { case <-r.ctx.Done(): @@ -1163,15 +1269,17 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G // Stop resolving if the client is gone return ctx.ctx.Err() case r.events <- subscriptionEvent{ - triggerID: xxh.Sum64(), + triggerID: triggerID, kind: subscriptionEventKindAddSubscription, addSubscription: &addSubscription{ - ctx: ctx, - input: input, - resolve: subscription, - writer: writer, - id: id, - completed: make(chan struct{}), + ctx: ctx, + input: input, + resolve: subscription, + writer: writer, + id: id, + completed: make(chan struct{}), + sourceName: subscription.Trigger.SourceName, + headers: headers, }, }: } @@ -1279,12 +1387,14 @@ type subscriptionEvent struct { } type addSubscription struct { - ctx *Context - input []byte - resolve *GraphQLSubscription - writer SubscriptionResponseWriter - id SubscriptionIdentifier - completed chan struct{} + ctx *Context + input []byte + resolve *GraphQLSubscription + writer SubscriptionResponseWriter + id SubscriptionIdentifier + completed chan struct{} + sourceName string + headers http.Header } type subscriptionEventKind int diff --git a/v2/pkg/engine/resolve/resolve_caching_test.go b/v2/pkg/engine/resolve/resolve_caching_test.go new file mode 100644 index 000000000..a0f9ae7be --- /dev/null +++ b/v2/pkg/engine/resolve/resolve_caching_test.go @@ -0,0 +1,146 @@ +package resolve + +import ( + "context" + "testing" + + "github.com/golang/mock/gomock" +) + +func TestResolveCaching(t *testing.T) { + t.Run("nested batching single root result", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + + listingRoot := mockedDS(t, ctrl, + `{"method":"POST","url":"http://listing","body":{"query":"query{listing{__typename id name}}"}}`, + `{"data":{"listing":{"__typename":"Listing","id":1,"name":"L1"}}}`) + + nested := mockedDS(t, ctrl, + `{"method":"POST","url":"http://nested","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on Listing { nested { id price listing { __typename id }} }}}","variables":{"representations":[{"__typename":"Listing","id":1}]}}}`, + `{"data":{"_entities":[{"__typename":"Listing","nested":{"id":1.1,"price":123,"listing":{"__typename":"Listing","id":1}}}]}}`) + + return &GraphQLResponse{ + Fetches: Sequence( + SingleWithPath(&SingleFetch{ + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://listing","body":{"query":"query{listing{__typename id name}}"}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + FetchConfiguration: FetchConfiguration{ + DataSource: listingRoot, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + }, + }, + }, "query"), + SingleWithPath(&BatchEntityFetch{ + Input: BatchInput{ + Header: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://nested","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Listing { nested { id price listing { __typename id }} }}}","variables":{"representations":[`), + SegmentType: StaticSegmentType, + }, + }, + }, + Items: []InputTemplate{ + { + Segments: []TemplateSegment{ + { + SegmentType: VariableSegmentType, + VariableKind: ResolvableObjectVariableKind, + Renderer: NewGraphQLVariableResolveRenderer(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("id"), + Value: &Integer{ + Path: []string{"id"}, + }, + }, + }, + }), + }, + }, + }, + }, + Separator: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`,`), + SegmentType: StaticSegmentType, + }, + }, + }, + Footer: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`]}}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + }, + DataSource: nested, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data", "_entities"}, + }, + }, "query.listing", ObjectPath("listing")), + ), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("listing"), + Value: &Object{ + Path: []string{"listing"}, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Integer{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + { + Name: []byte("nested"), + Value: &Object{ + Path: []string{"nested"}, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Float{ + Path: []string{"id"}, + }, + }, + { + Name: []byte("price"), + Value: &Integer{ + Path: []string{"price"}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, Context{ctx: context.Background(), Variables: nil}, `{"data":{"listing":{"id":1,"name":"L1","nested":{"id":1.1,"price":123}}}}` + })) +} diff --git a/v2/pkg/engine/resolve/resolve_federation_test.go b/v2/pkg/engine/resolve/resolve_federation_test.go index 2547c6d10..1c32db689 100644 --- a/v2/pkg/engine/resolve/resolve_federation_test.go +++ b/v2/pkg/engine/resolve/resolve_federation_test.go @@ -1,9 +1,8 @@ package resolve import ( - "bytes" "context" - "io" + "net/http" "testing" "github.com/golang/mock/gomock" @@ -21,18 +20,11 @@ func mockedDS(t TestingTB, ctrl *gomock.Controller, expectedInput, responseData t.Helper() service := NewMockDataSource(ctrl) service.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). 
- DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - actual := string(input) - expected := expectedInput - - require.Equal(t, expected, actual) - - pair := NewBufPair() - pair.Data.WriteString(responseData) - - return writeGraphqlResponse(pair, w, false) - }).AnyTimes() + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + require.Equal(t, expectedInput, string(input)) + return []byte(responseData), nil + }).Times(1) return service } @@ -48,7 +40,7 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { DataSource: mockedDS( t, ctrl, `{"method":"POST","url":"http://user.service","body":{"query":"{user {account {__typename id info {a b}}}}"}}`, - `{"user":{"account":{"__typename":"Account","id":"1234","info":{"a":"foo","b":"bar"}}}}`, + `{"data":{"user":{"account":{"__typename":"Account","id":"1234","info":{"a":"foo","b":"bar"}}}}}`, ), Input: `{"method":"POST","url":"http://user.service","body":{"query":"{user {account {__typename id info {a b}}}}"}}`, PostProcessing: PostProcessingConfiguration{ @@ -70,7 +62,7 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { DataSource: mockedDS( t, ctrl, expectedAccountsQuery, - `{"_entities":[{"__typename":"Account","name":"John Doe","shippingInfo":{"zip":"12345"}}]}`, + `{"data":{"_entities":[{"__typename":"Account","name":"John Doe","shippingInfo":{"zip":"12345"}}]}}`, ), Input: `{"method":"POST","url":"http://account.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Account {name shippingInfo {zip}}}}","variables":{"representations":$$0$$}}}`, PostProcessing: PostProcessingConfiguration{ @@ -182,38 +174,38 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("federation with shareable", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { firstService := NewMockDataSource(ctrl) firstService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://first.service","body":{"query":"{me {details {forename middlename} __typename id}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"me": {"__typename": "User", "id": "1234", "details": {"forename": "John", "middlename": "A"}}}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"me": {"__typename": "User", "id": "1234", "details": {"forename": "John", "middlename": "A"}}}}`) + return pair.Data.Bytes(), nil }) secondService := NewMockDataSource(ctrl) secondService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://second.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on User {details {surname}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"__typename": "User", "details": {"surname": "Smith"}}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities": [{"__typename": "User", "details": {"surname": "Smith"}}]}}`) + return pair.Data.Bytes(), nil }) thirdService := NewMockDataSource(ctrl) thirdService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://third.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {details {age}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"__typename": "User", "details": {"age": 21}}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities": [{"__typename": "User", "details": {"age": 21}}]}}`) + return pair.Data.Bytes(), nil }) return &GraphQLResponse{ @@ -377,26 +369,26 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name infoOrAddress { ... on Info {id __typename} ... on Address {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"user":{"name":"Bill","infoOrAddress":[{"id":11,"__typename":"Info"},{"id": 55,"__typename":"Address"}]}}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"user":{"name":"Bill","infoOrAddress":[{"id":11,"__typename":"Info"},{"id": 55,"__typename":"Address"}]}}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age } ... 
on Address { line1 }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"},{"id":55,"__typename":"Address"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"age":21,"__typename":"Info"},{"line1":"Munich","__typename":"Address"}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities":[{"age":21,"__typename":"Info"},{"line1":"Munich","__typename":"Address"}]}}`) + return pair.Data.Bytes(), nil }) return &GraphQLResponse{ @@ -530,19 +522,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name infoOrAddress { ... on Info {id __typename} ... on Address {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"user":{"name":"Bill","infoOrAddress":[{"id":11,"__typename":"Whatever"},{"id": 55,"__typename":"Whatever"}]}}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"user":{"name":"Bill","infoOrAddress":[{"id":11,"__typename":"Whatever"},{"id": 55,"__typename":"Whatever"}]}}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). + Load(gomock.Any(), gomock.Any(), gomock.Any()). Times(0) return &GraphQLResponse{ @@ -675,26 +667,26 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching on a field", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"users":[{"name":"Bill","info":{"id":11,"__typename":"Info"}},{"name":"John","info":{"id":12,"__typename":"Info"}},{"name":"Jane","info":{"id":13,"__typename":"Info"}}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"users":[{"name":"Bill","info":{"id":11,"__typename":"Info"}},{"name":"John","info":{"id":12,"__typename":"Info"}},{"name":"Jane","info":{"id":13,"__typename":"Info"}}]}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 				actual := string(input)
 				expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"},{"id":12,"__typename":"Info"},{"id":13,"__typename":"Info"}]}}}`
 				assert.Equal(t, expected, actual)
 				pair := NewBufPair()
-				pair.Data.WriteString(`{"_entities":[{"age":21,"__typename":"Info"},{"age":22,"__typename":"Info"},{"age":23,"__typename":"Info"}]}`)
-				return writeGraphqlResponse(pair, w, false)
+				pair.Data.WriteString(`{"data":{"_entities":[{"age":21,"__typename":"Info"},{"age":22,"__typename":"Info"},{"age":23,"__typename":"Info"}]}}`)
+				return pair.Data.Bytes(), nil
 			})

 		return &GraphQLResponse{
@@ -819,26 +811,26 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) {
 	t.Run("batching with duplicates", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {
 		userService := NewMockDataSource(ctrl)
 		userService.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 				actual := string(input)
 				expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}`
 				assert.Equal(t, expected, actual)
 				pair := NewBufPair()
-				pair.Data.WriteString(`{"users":[{"name":"Bill","info":{"id":11,"__typename":"Info"}},{"name":"John","info":{"id":11,"__typename":"Info"}},{"name":"Jane","info":{"id":11,"__typename":"Info"}}]}`)
-				return writeGraphqlResponse(pair, w, false)
+				pair.Data.WriteString(`{"data":{"users":[{"name":"Bill","info":{"id":11,"__typename":"Info"}},{"name":"John","info":{"id":11,"__typename":"Info"}},{"name":"Jane","info":{"id":11,"__typename":"Info"}}]}}`)
+				return pair.Data.Bytes(), nil
 			})

 		infoService := NewMockDataSource(ctrl)
 		infoService.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 				actual := string(input)
 				expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"}]}}}`
 				assert.Equal(t, expected, actual)
 				pair := NewBufPair()
-				pair.Data.WriteString(`{"_entities":[{"age":77,"__typename":"Info"}]}`)
-				return writeGraphqlResponse(pair, w, false)
+				pair.Data.WriteString(`{"data":{"_entities":[{"age":77,"__typename":"Info"}]}}`)
+				return pair.Data.Bytes(), nil
 			})

 		return &GraphQLResponse{
@@ -960,26 +952,26 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) {
 	t.Run("batching with null entry", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {
 		userService := NewMockDataSource(ctrl)
 		userService.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 				actual := string(input)
 				expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}`
 				assert.Equal(t, expected, actual)
 				pair := NewBufPair()
-				pair.Data.WriteString(`{"users":[{"name":"Bill","info":{"id":11,"__typename":"Info"}},{"name":"John","info":null},{"name":"Jane","info":{"id":13,"__typename":"Info"}}]}`)
-				return writeGraphqlResponse(pair, w, false)
+				pair.Data.WriteString(`{"data":{"users":[{"name":"Bill","info":{"id":11,"__typename":"Info"}},{"name":"John","info":null},{"name":"Jane","info":{"id":13,"__typename":"Info"}}]}}`)
+				return pair.Data.Bytes(), nil
 			})

 		infoService := NewMockDataSource(ctrl)
 		infoService.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 				actual := string(input)
 				expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"},{"id":13,"__typename":"Info"}]}}}`
 				assert.Equal(t, expected, actual)
 				pair := NewBufPair()
-				pair.Data.WriteString(`{"_entities":[{"age":21,"__typename":"Info"},{"age":23,"__typename":"Info"}]}`)
-				return writeGraphqlResponse(pair, w, false)
+				pair.Data.WriteString(`{"data":{"_entities":[{"age":21,"__typename":"Info"},{"age":23,"__typename":"Info"}]}}`)
+				return pair.Data.Bytes(), nil
 			})

 		return &GraphQLResponse{
@@ -1105,19 +1097,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) {
 	t.Run("batching with all null entries", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {
 		userService := NewMockDataSource(ctrl)
 		userService.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 				actual := string(input)
 				expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}`
 				assert.Equal(t, expected, actual)
 				pair := NewBufPair()
-				pair.Data.WriteString(`{"users":[{"name":"Bill","info":null},{"name":"John","info":null},{"name":"Jane","info":null}]}`)
-				return writeGraphqlResponse(pair, w, false)
+				pair.Data.WriteString(`{"data":{"users":[{"name":"Bill","info":null},{"name":"John","info":null},{"name":"Jane","info":null}]}}`)
+				return pair.Data.Bytes(), nil
 			})

 		infoService := NewMockDataSource(ctrl)
 		infoService.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
 			Times(0)

 		return &GraphQLResponse{
@@ -1243,27 +1235,27 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) {
 	t.Run("batching with render error", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {
 		userService := NewMockDataSource(ctrl)
 		userService.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 				actual := string(input)
 				expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}`
 				assert.Equal(t, expected, actual)
 				pair := NewBufPair()
 				// render error - first item id is boolean
-				pair.Data.WriteString(`{"users":[{"name":"Bill","info":{"id":true,"__typename":"Info"}},{"name":"John","info":{"id":12,"__typename":"Info"}},{"name":"Jane","info":{"id":13,"__typename":"Info"}}]}`)
-				return writeGraphqlResponse(pair, w, false)
+				pair.Data.WriteString(`{"data":{"users":[{"name":"Bill","info":{"id":true,"__typename":"Info"}},{"name":"John","info":{"id":12,"__typename":"Info"}},{"name":"Jane","info":{"id":13,"__typename":"Info"}}]}}`)
+				return pair.Data.Bytes(), nil
 			})

 		infoService := NewMockDataSource(ctrl)
 		infoService.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 				actual := string(input)
 				expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age }}}}}","variables":{"representations":[{"id":12,"__typename":"Info"},{"id":13,"__typename":"Info"}]}}}`
 				assert.Equal(t, expected, actual)
 				pair := NewBufPair()
-				pair.Data.WriteString(`{"_entities":[{"age":21,"__typename":"Info"},{"age":22,"__typename":"Info"}]}`)
-				return writeGraphqlResponse(pair, w, false)
+				pair.Data.WriteString(`{"data":{"_entities":[{"age":21,"__typename":"Info"},{"age":22,"__typename":"Info"}]}}`)
+				return pair.Data.Bytes(), nil
 			})

 		return &GraphQLResponse{
@@ -1390,26 +1382,26 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) {
 	t.Run("all data", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {
 		userService := NewMockDataSource(ctrl)
 		userService.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 				actual := string(input)
 				expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name info {id __typename}}}}"}}`
 				assert.Equal(t, expected, actual)
 				pair := NewBufPair()
-				pair.Data.WriteString(`{"user":{"name":"Bill","info":{"id":11,"__typename":"Info"}}}`)
-				return writeGraphqlResponse(pair, w, false)
+				pair.Data.WriteString(`{"data":{"user":{"name":"Bill","info":{"id":11,"__typename":"Info"}}}}`)
+				return pair.Data.Bytes(), nil
 			})

 		infoService := NewMockDataSource(ctrl)
 		infoService.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 				actual := string(input)
 				expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"}]}}}`
 				assert.Equal(t, expected, actual)
 				pair := NewBufPair()
-				pair.Data.WriteString(`{"_entities":[{"age":21,"__typename":"Info"}]}`)
-				return writeGraphqlResponse(pair, w, false)
+				pair.Data.WriteString(`{"data":{"_entities":[{"age":21,"__typename":"Info"}]}}`)
+				return pair.Data.Bytes(), nil
 			})

 		return &GraphQLResponse{
@@ -1524,19 +1516,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) {
 	t.Run("null info data", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {
 		userService := NewMockDataSource(ctrl)
 		userService.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 				actual := string(input)
 				expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name info {id __typename}}}}"}}`
 				assert.Equal(t, expected, actual)
 				pair := NewBufPair()
-				pair.Data.WriteString(`{"user":{"name":"Bill","info":null}}`)
-				return writeGraphqlResponse(pair, w, false)
+				pair.Data.WriteString(`{"data":{"user":{"name":"Bill","info":null}}}`)
+				return pair.Data.Bytes(), nil
 			})

 		infoService := NewMockDataSource(ctrl)
 		infoService.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
 			Times(0)

 		return &GraphQLResponse{
@@ -1652,19 +1644,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) {
 	t.Run("wrong type data", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {
 		userService := NewMockDataSource(ctrl)
 		userService.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 				actual := string(input)
 				expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name info {id __typename}}}}"}}`
 				assert.Equal(t, expected, actual)
 				pair := NewBufPair()
-				pair.Data.WriteString(`{"user":{"name":"Bill","info":{"id":false,"__typename":"Info"}}}`)
-				return writeGraphqlResponse(pair, w, false)
+				pair.Data.WriteString(`{"data":{"user":{"name":"Bill","info":{"id":false,"__typename":"Info"}}}}`)
+				return pair.Data.Bytes(), nil
 			})

 		infoService := NewMockDataSource(ctrl)
 		infoService.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
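+			// No entity fetch is expected here: the info object's id has the
+			// wrong JSON type, so no valid representation can be built.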
 			Times(0)

 		return &GraphQLResponse{
@@ -1780,19 +1772,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) {
 	t.Run("not matching type data", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {
 		userService := NewMockDataSource(ctrl)
 		userService.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 				actual := string(input)
 				expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name info {id __typename}}}}"}}`
 				assert.Equal(t, expected, actual)
 				pair := NewBufPair()
-				pair.Data.WriteString(`{"user":{"name":"Bill","info":{"id":1,"__typename":"Whatever"}}}`)
-				return writeGraphqlResponse(pair, w, false)
+				pair.Data.WriteString(`{"data":{"user":{"name":"Bill","info":{"id":1,"__typename":"Whatever"}}}}`)
+				return pair.Data.Bytes(), nil
 			})

 		infoService := NewMockDataSource(ctrl)
 		infoService.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
 			Times(0)

 		return &GraphQLResponse{
@@ -1912,19 +1904,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) {
 		user := mockedDS(t, ctrl,
 			`{"method":"POST","url":"http://user.service","body":{"query":"{user {account {address {__typename id line1 line2}}}}"}}`,
-			`{"user":{"account":{"address":{"__typename":"Address","id":"address-1","line1":"line1","line2":"line2"}}}}`)
+			`{"data":{"user":{"account":{"address":{"__typename":"Address","id":"address-1","line1":"line1","line2":"line2"}}}}}`)

 		addressEnricher := mockedDS(t, ctrl,
 			`{"method":"POST","url":"http://address-enricher.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Address {country city}}}","variables":{"representations":[{"__typename":"Address","id":"address-1"}]}}}`,
-			`{"__typename":"Address","country":"country-1","city":"city-1"}`)
+			`{"data":{"__typename":"Address","country":"country-1","city":"city-1"}}`)

 		address := mockedDS(t, ctrl,
 			`{"method":"POST","url":"http://address.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Address {line3(test: "BOOM") zip}}}","variables":{"representations":[{"__typename":"Address","id":"address-1","country":"country-1","city":"city-1"}]}}}`,
-			`{"__typename": "Address", "line3": "line3-1", "zip": "zip-1"}`)
+			`{"data":{"__typename": "Address", "line3": "line3-1", "zip": "zip-1"}}`)

 		account := mockedDS(t, ctrl,
 			`{"method":"POST","url":"http://account.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Address {fullAddress}}}","variables":{"representations":[{"__typename":"Address","id":"address-1","line1":"line1","line2":"line2","line3":"line3-1","zip":"zip-1"}]}}}`,
-			`{"__typename":"Address","fullAddress":"line1 line2 line3-1 city-1 country-1 zip-1"}`)
+			`{"data":{"__typename":"Address","fullAddress":"line1 line2 line3-1 city-1 country-1 zip-1"}}`)

 		return &GraphQLResponse{
 			Fetches: Sequence(
@@ -2152,19 +2144,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) {
 		productsService := mockedDS(t, ctrl,
 			`{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`,
-			`{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}`)
+			`{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`)

 		reviewsService := mockedDS(t, ctrl,
 			`{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`,
-			`{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}`)
+			`{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}}`)

 		stockService := mockedDS(t, ctrl,
 			`{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`,
-			`{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}`)
+			`{"data":{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}}`)

 		usersService := mockedDS(t, ctrl,
 			`{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"2"}]}}}`,
-			`{"_entities":[{"name":"user-1"},{"name":"user-2"}]}`)
+			`{"data":{"_entities":[{"name":"user-1"},{"name":"user-2"}]}}`)

 		return &GraphQLResponse{
 			Fetches: Sequence(
@@ -2424,19 +2416,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) {
 		productsService := mockedDS(t, ctrl,
 			`{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`,
-			`{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"}]}`)
+			`{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"}]}}`)

 		reviewsService := mockedDS(t, ctrl,
 			`{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"}]}}}`,
-			`{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}}]}]}`)
+			`{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}}]}]}}`)

 		stockService := mockedDS(t, ctrl,
 			`{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"}]}}}`,
-			`{"_entities":[{"stock":8}]}`)
+			`{"data":{"_entities":[{"stock":8}]}}`)

 		usersService := mockedDS(t, ctrl,
 			`{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"}]}}}`,
-			`{"_entities":[{"name":"user-1"}]}`)
+			`{"data":{"_entities":[{"name":"user-1"}]}}`)

 		return &GraphQLResponse{
 			Fetches: Sequence(
@@ -2696,11 +2688,11 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) {
 		accountsService := mockedDS(t, ctrl,
 			`{"method":"POST","url":"http://accounts","body":{"query":"{accounts{__typename ... on User {__typename id} ... on Moderator {__typename moderatorID} ... on Admin {__typename adminID}}}"}}`,
-			`{"accounts":[{"__typename":"User","id":"3"},{"__typename":"Admin","adminID":"2"},{"__typename":"Moderator","moderatorID":"1"}]}`)
+			`{"data":{"accounts":[{"__typename":"User","id":"3"},{"__typename":"Admin","adminID":"2"},{"__typename":"Moderator","moderatorID":"1"}]}}`)

 		namesService := mockedDS(t, ctrl,
 			`{"method":"POST","url":"http://names","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name} ... on Moderator {subject} ... on Admin {type}}}","variables":{"representations":[{"__typename":"User","id":"3"},{"__typename":"Admin","adminID":"2"},{"__typename":"Moderator","moderatorID":"1"}]}}}`,
-			`{"_entities":[{"__typename":"User","name":"User"},{"__typename":"Admin","type":"super"},{"__typename":"Moderator","subject":"posts"}]}`)
+			`{"data":{"_entities":[{"__typename":"User","name":"User"},{"__typename":"Admin","type":"super"},{"__typename":"Moderator","subject":"posts"}]}}`)

 		return &GraphQLResponse{
 			Fetches: Sequence(
@@ -2836,11 +2828,11 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) {
 		accountsService := mockedDS(t, ctrl,
 			`{"method":"POST","url":"http://accounts","body":{"query":"{accounts {__typename ... on User {some {__typename id}} ... on Admin {some {__typename id}}}}"}}`,
-			`{"accounts":[{"__typename":"User","some":{"__typename":"User","id":"1"}},{"__typename":"Admin","some":{"__typename":"User","id":"2"}},{"__typename":"User","some":{"__typename":"User","id":"3"}}]}`)
+			`{"data":{"accounts":[{"__typename":"User","some":{"__typename":"User","id":"1"}},{"__typename":"Admin","some":{"__typename":"User","id":"2"}},{"__typename":"User","some":{"__typename":"User","id":"3"}}]}}`)

 		namesService := mockedDS(t, ctrl,
 			`{"method":"POST","url":"http://names","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {__typename title}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"3"}]}}}`,
-			`{"_entities":[{"__typename":"User","title":"User1"},{"__typename":"User","title":"User3"}]}`)
+			`{"data":{"_entities":[{"__typename":"User","title":"User1"},{"__typename":"User","title":"User3"}]}}`)

 		return &GraphQLResponse{
 			Fetches: Sequence(
diff --git a/v2/pkg/engine/resolve/resolve_mock_test.go b/v2/pkg/engine/resolve/resolve_mock_test.go
index 3f72cc3d8..a64b7dd83 100644
--- a/v2/pkg/engine/resolve/resolve_mock_test.go
+++ b/v2/pkg/engine/resolve/resolve_mock_test.go
@@ -5,8 +5,8 @@ package resolve
 import (
-	bytes "bytes"
 	context "context"
+	http "net/http"
 	reflect "reflect"

 	gomock "github.com/golang/mock/gomock"
@@ -37,11 +37,12 @@ func (m *MockDataSource) EXPECT() *MockDataSourceMockRecorder {
 }

 // Load mocks base method.
-func (m *MockDataSource) Load(arg0 context.Context, arg1 []byte, arg2 *bytes.Buffer) error {
+func (m *MockDataSource) Load(arg0 context.Context, arg1 http.Header, arg2 []byte) ([]byte, error) {
 	m.ctrl.T.Helper()
 	ret := m.ctrl.Call(m, "Load", arg0, arg1, arg2)
-	ret0, _ := ret[0].(error)
-	return ret0
+	ret0, _ := ret[0].([]byte)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
 }

 // Load indicates an expected call of Load.
@@ -51,11 +52,12 @@ func (mr *MockDataSourceMockRecorder) Load(arg0, arg1, arg2 interface{}) *gomock
 }

 // LoadWithFiles mocks base method.
-func (m *MockDataSource) LoadWithFiles(arg0 context.Context, arg1 []byte, arg2 []*httpclient.FileUpload, arg3 *bytes.Buffer) error {
+func (m *MockDataSource) LoadWithFiles(arg0 context.Context, arg1 http.Header, arg2 []byte, arg3 []*httpclient.FileUpload) ([]byte, error) {
 	m.ctrl.T.Helper()
 	ret := m.ctrl.Call(m, "LoadWithFiles", arg0, arg1, arg2, arg3)
-	ret0, _ := ret[0].(error)
-	return ret0
+	ret0, _ := ret[0].([]byte)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
 }

 // LoadWithFiles indicates an expected call of LoadWithFiles.
diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go
index 8e15ff98a..112776037 100644
--- a/v2/pkg/engine/resolve/resolve_test.go
+++ b/v2/pkg/engine/resolve/resolve_test.go
@@ -13,7 +13,6 @@ import (
 	"testing"
 	"time"

-	"github.com/cespare/xxhash/v2"
 	"github.com/golang/mock/gomock"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
@@ -32,7 +31,7 @@ type _fakeDataSource struct {
 	artificialLatency time.Duration
 }

-func (f *_fakeDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) {
+func (f *_fakeDataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) {
 	if f.artificialLatency != 0 {
 		time.Sleep(f.artificialLatency)
 	}
@@ -41,11 +40,10 @@ func (f *_fakeDataSource) Load(ctx context.Context, input []byte, w *bytes.Buf
 			require.Equal(f.t, string(f.input), string(input), "input mismatch")
 		}
 	}
-	_, err = out.Write(f.data)
-	return
+	return f.data, nil
 }

-func (f *_fakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) {
+func (f *_fakeDataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) {
 	if f.artificialLatency != 0 {
 		time.Sleep(f.artificialLatency)
 	}
@@ -54,8 +52,7 @@ func (f *_fakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files
 			require.Equal(f.t, string(f.input), string(input), "input mismatch")
 		}
 	}
-	_, err = out.Write(f.data)
-	return
+	return f.data, nil
 }

 func FakeDataSource(data string) *_fakeDataSource {
@@ -351,12 +348,11 @@ func TestResolver_ResolveNode(t *testing.T) {
 	t.Run("fetch with context variable resolver", testFn(true, func(t *testing.T, ctrl *gomock.Controller) (response *GraphQLResponse, ctx Context, expectedOutput string) {
 		mockDataSource := NewMockDataSource(ctrl)
 		mockDataSource.EXPECT().
-			Load(gomock.Any(), []byte(`{"id":1}`), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			Do(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) {
-				_, err = w.Write([]byte(`{"name":"Jens"}`))
-				return
+			Load(gomock.Any(), gomock.Any(), []byte(`{"id":1}`)).
+			Do(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
+				return []byte(`{"name":"Jens"}`), nil
 			}).
-			Return(nil)
+			Return([]byte(`{"name":"Jens"}`), nil)
 		return &GraphQLResponse{
 			Fetches: Single(&SingleFetch{
 				FetchConfiguration: FetchConfiguration{
@@ -1802,11 +1798,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) {
 	t.Run("fetch with simple error without datasource ID", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {
 		mockDataSource := NewMockDataSource(ctrl)
 		mockDataSource.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
-				pair := NewBufPair()
-				pair.WriteErr([]byte("errorMessage"), nil, nil, nil)
-				return writeGraphqlResponse(pair, w, false)
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
+				return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil
 			})
 		return &GraphQLResponse{
 			Fetches: SingleWithPath(&SingleFetch{
@@ -1834,11 +1828,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) {
 	t.Run("fetch with simple error without datasource ID no subgraph error forwarding", testFnNoSubgraphErrorForwarding(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {
 		mockDataSource := NewMockDataSource(ctrl)
 		mockDataSource.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
-				pair := NewBufPair()
-				pair.WriteErr([]byte("errorMessage"), nil, nil, nil)
-				return writeGraphqlResponse(pair, w, false)
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
+				return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil
 			})
 		return &GraphQLResponse{
 			Fetches: SingleWithPath(&SingleFetch{
@@ -1866,11 +1858,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) {
 	t.Run("fetch with simple error", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {
 		mockDataSource := NewMockDataSource(ctrl)
 		mockDataSource.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
-				pair := NewBufPair()
-				pair.WriteErr([]byte("errorMessage"), nil, nil, nil)
-				return writeGraphqlResponse(pair, w, false)
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
+				return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil
 			})
 		return &GraphQLResponse{
 			Fetches: SingleWithPath(&SingleFetch{
@@ -1902,11 +1892,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) {
 	t.Run("fetch with simple error in pass through Subgraph Error Mode", testFnSubgraphErrorsPassthrough(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {
 		mockDataSource := NewMockDataSource(ctrl)
 		mockDataSource.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
-				pair := NewBufPair()
-				pair.WriteErr([]byte("errorMessage"), nil, nil, nil)
-				return writeGraphqlResponse(pair, w, false)
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
+				return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil
 			})
 		return &GraphQLResponse{
 			Fetches: Single(&SingleFetch{
@@ -1938,10 +1926,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) {
 	t.Run("fetch with pass through mode and omit custom fields", testFnSubgraphErrorsPassthroughAndOmitCustomFields(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {
 		mockDataSource := NewMockDataSource(ctrl)
 		mockDataSource.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) error {
-				_, err := w.Write([]byte(`{"errors":[{"message":"errorMessage","longMessage":"This is a long message","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}],"data":{"name":null}}`))
-				return err
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
+				return []byte(`{"errors":[{"message":"errorMessage","longMessage":"This is a long message","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}],"data":{"name":null}}`), nil
 			})
 		return &GraphQLResponse{
 			Info: &GraphQLResponseInfo{
@@ -1976,9 +1963,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) {
 	t.Run("fetch with returned err (with DataSourceID)", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {
 		mockDataSource := NewMockDataSource(ctrl)
 		mockDataSource.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
-				return &net.AddrError{}
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
+				return nil, &net.AddrError{}
 			})
 		return &GraphQLResponse{
 			Fetches: SingleWithPath(&SingleFetch{
@@ -2010,9 +1997,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) {
 	t.Run("fetch with returned err (no DataSourceID)", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {
 		mockDataSource := NewMockDataSource(ctrl)
 		mockDataSource.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
-				return &net.AddrError{}
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
+				return nil, &net.AddrError{}
 			})
 		return &GraphQLResponse{
 			Fetches: SingleWithPath(&SingleFetch{
@@ -2040,9 +2027,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) {
 	t.Run("fetch with returned err and non-nullable root field", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {
 		mockDataSource := NewMockDataSource(ctrl)
 		mockDataSource.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
-				return &net.AddrError{}
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
+				return nil, &net.AddrError{}
 			})
 		return &GraphQLResponse{
 			Fetches: SingleWithPath(&SingleFetch{
@@ -2218,14 +2205,10 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) {
 	t.Run("fetch with two Errors", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {
 		mockDataSource := NewMockDataSource(ctrl)
 		mockDataSource.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			Do(func(ctx context.Context, input []byte, w io.Writer) (err error) {
-				pair := NewBufPair()
-				pair.WriteErr([]byte("errorMessage1"), nil, nil, nil)
-				pair.WriteErr([]byte("errorMessage2"), nil, nil, nil)
-				return writeGraphqlResponse(pair, w, false)
-			}).
-			Return(nil)
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
+				return []byte(`{"errors":[{"message":"errorMessage1"},{"message":"errorMessage2"}]}`), nil
+			}).Times(1)
 		return &GraphQLResponse{
 			Fetches: SingleWithPath(&SingleFetch{
 				FetchConfiguration: FetchConfiguration{
@@ -2578,39 +2561,32 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) {
 	t.Run("complex GraphQL Server plan", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {
 		serviceOne := NewMockDataSource(ctrl)
 		serviceOne.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 				actual := string(input)
 				expected := `{"url":"https://service.one","body":{"query":"query($firstArg: String, $thirdArg: Int){serviceOne(serviceOneArg: $firstArg){fieldOne} anotherServiceOne(anotherServiceOneArg: $thirdArg){fieldOne} reusingServiceOne(reusingServiceOneArg: $firstArg){fieldOne}}","variables":{"thirdArg":123,"firstArg":"firstArgValue"}}}`
 				assert.Equal(t, expected, actual)
-				pair := NewBufPair()
-				pair.Data.WriteString(`{"serviceOne":{"fieldOne":"fieldOneValue"},"anotherServiceOne":{"fieldOne":"anotherFieldOneValue"},"reusingServiceOne":{"fieldOne":"reUsingFieldOneValue"}}`)
-				return writeGraphqlResponse(pair, w, false)
+				return []byte(`{"data":{"serviceOne":{"fieldOne":"fieldOneValue"},"anotherServiceOne":{"fieldOne":"anotherFieldOneValue"},"reusingServiceOne":{"fieldOne":"reUsingFieldOneValue"}}}`), nil
 			})

 		serviceTwo := NewMockDataSource(ctrl)
 		serviceTwo.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 				actual := string(input)
 				expected := `{"url":"https://service.two","body":{"query":"query($secondArg: Boolean, $fourthArg: Float){serviceTwo(serviceTwoArg: $secondArg){fieldTwo} secondServiceTwo(secondServiceTwoArg: $fourthArg){fieldTwo}}","variables":{"fourthArg":12.34,"secondArg":true}}}`
 				assert.Equal(t, expected, actual)
-
-				pair := NewBufPair()
-				pair.Data.WriteString(`{"serviceTwo":{"fieldTwo":"fieldTwoValue"},"secondServiceTwo":{"fieldTwo":"secondFieldTwoValue"}}`)
-				return writeGraphqlResponse(pair, w, false)
+				return []byte(`{"data":{"serviceTwo":{"fieldTwo":"fieldTwoValue"},"secondServiceTwo":{"fieldTwo":"secondFieldTwoValue"}}}`), nil
 			})

 		nestedServiceOne := NewMockDataSource(ctrl)
 		nestedServiceOne.EXPECT().
-			Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-			DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+			Load(gomock.Any(), gomock.Any(), gomock.Any()).
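+			// Same pattern as the other services: assert on the rendered input,
+			// then return the canned GraphQL response body directly.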
+			DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 				actual := string(input)
 				expected := `{"url":"https://service.one","body":{"query":"{serviceOne {fieldOne}}"}}`
 				assert.Equal(t, expected, actual)
-				pair := NewBufPair()
-				pair.Data.WriteString(`{"serviceOne":{"fieldOne":"fieldOneValue"}}`)
-				return writeGraphqlResponse(pair, w, false)
+				return []byte(`{"data":{"serviceOne":{"fieldOne":"fieldOneValue"}}}`), nil
 			})

 		return &GraphQLResponse{
@@ -2817,259 +2793,35 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) {
 		}, Context{ctx: context.Background(), Variables: astjson.MustParseBytes([]byte(`{"firstArg":"firstArgValue","thirdArg":123,"secondArg": true, "fourthArg": 12.34}`))}, `{"data":{"serviceOne":{"fieldOne":"fieldOneValue"},"serviceTwo":{"fieldTwo":"fieldTwoValue","serviceOneResponse":{"fieldOne":"fieldOneValue"}},"anotherServiceOne":{"fieldOne":"anotherFieldOneValue"},"secondServiceTwo":{"fieldTwo":"secondFieldTwoValue"},"reusingServiceOne":{"fieldOne":"reUsingFieldOneValue"}}}`
 	}))
 	t.Run("federation", func(t *testing.T) {
-		t.Run("simple", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {
-
-			userService := NewMockDataSource(ctrl)
-			userService.EXPECT().
-				Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-				DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
-					actual := string(input)
-					expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}`
-					assert.Equal(t, expected, actual)
-					pair := NewBufPair()
-					pair.Data.WriteString(`{"me":{"id":"1234","username":"Me","__typename":"User"}}`)
-					return writeGraphqlResponse(pair, w, false)
-				})
-
-			reviewsService := NewMockDataSource(ctrl)
-			reviewsService.EXPECT().
-				Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-				DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
-					actual := string(input)
-					// {"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":["id":"1234","__typename":"User"]}}}
-					expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}`
-					assert.Equal(t, expected, actual)
-					pair := NewBufPair()
-					pair.Data.WriteString(`{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}`)
-					return writeGraphqlResponse(pair, w, false)
-				})
-
-			var productServiceCallCount atomic.Int64
-
-			productService := NewMockDataSource(ctrl)
-			productService.EXPECT().
-				Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-				Do(func(ctx context.Context, input []byte, w io.Writer) (err error) {
-					actual := string(input)
-					productServiceCallCount.Add(1)
-					switch actual {
-					case `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"}]}}}`:
-						pair := NewBufPair()
-						pair.Data.WriteString(`{"_entities":[{"name": "Furby"}]}`)
-						return writeGraphqlResponse(pair, w, false)
-					case `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-2","__typename":"Product"}]}}}`:
-						pair := NewBufPair()
-						pair.Data.WriteString(`{"_entities":[{"name": "Trilby"}]}`)
-						return writeGraphqlResponse(pair, w, false)
-					default:
-						t.Fatalf("unexpected request: %s", actual)
-					}
-					return
-				}).
-				Return(nil).Times(2)
-
-			return &GraphQLResponse{
-				Fetches: Sequence(
-					SingleWithPath(&SingleFetch{
-						InputTemplate: InputTemplate{
-							Segments: []TemplateSegment{
-								{
-									Data:        []byte(`{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}`),
-									SegmentType: StaticSegmentType,
-								},
-							},
-						},
-						FetchConfiguration: FetchConfiguration{
-							DataSource: userService,
-							PostProcessing: PostProcessingConfiguration{
-								SelectResponseDataPath: []string{"data"},
-							},
-						},
-					}, "query"),
-					SingleWithPath(&SingleFetch{
-						InputTemplate: InputTemplate{
-							Segments: []TemplateSegment{
-								{
-									Data:        []byte(`{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[`),
-									SegmentType: StaticSegmentType,
-								},
-								{
-									SegmentType:  VariableSegmentType,
-									VariableKind: ResolvableObjectVariableKind,
-									Renderer: NewGraphQLVariableResolveRenderer(&Object{
-										Fields: []*Field{
-											{
-												Name: []byte("id"),
-												Value: &String{
-													Path: []string{"id"},
-												},
-											},
-											{
-												Name: []byte("__typename"),
-												Value: &String{
-													Path: []string{"__typename"},
-												},
-											},
-										},
-									}),
-								},
-								{
-									Data:        []byte(`]}}}`),
-									SegmentType: StaticSegmentType,
-								},
-							},
-						},
-						FetchConfiguration: FetchConfiguration{
-							DataSource: reviewsService,
-							PostProcessing: PostProcessingConfiguration{
-								SelectResponseDataPath: []string{"data", "_entities", "0"},
-							},
-						},
-					}, "query.me", ObjectPath("me")),
-					SingleWithPath(&ParallelListItemFetch{
-						Fetch: &SingleFetch{
-							FetchConfiguration: FetchConfiguration{
-								DataSource: productService,
-								PostProcessing: PostProcessingConfiguration{
-									SelectResponseDataPath: []string{"data", "_entities", "0"},
-								},
-							},
-							InputTemplate: InputTemplate{
-								Segments: []TemplateSegment{
-									{
-										Data:        []byte(`{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[`),
-										SegmentType: StaticSegmentType,
-									},
-									{
-										SegmentType:  VariableSegmentType,
-										VariableKind: ResolvableObjectVariableKind,
-										Renderer: NewGraphQLVariableResolveRenderer(&Object{
-											Fields: []*Field{
-												{
-													Name: []byte("upc"),
-													Value: &String{
-														Path: []string{"upc"},
-													},
-												},
-												{
-													Name: []byte("__typename"),
-													Value: &String{
-														Path: []string{"__typename"},
-													},
-												},
-											},
-										}),
-									},
-									{
-										Data:        []byte(`]}}}`),
-										SegmentType: StaticSegmentType,
-									},
-								},
-							},
-						},
-					}, "query.me.reviews.@.product", ObjectPath("me"), ArrayPath("reviews"), ObjectPath("product")),
-				),
-				Data: &Object{
-					Fields: []*Field{
-						{
-							Name: []byte("me"),
-							Value: &Object{
-								Path:     []string{"me"},
-								Nullable: true,
-								Fields: []*Field{
-									{
-										Name: []byte("id"),
-										Value: &String{
-											Path: []string{"id"},
-										},
-									},
-									{
-										Name: []byte("username"),
-										Value: &String{
-											Path: []string{"username"},
-										},
-									},
-									{
-
-										Name: []byte("reviews"),
-										Value: &Array{
-											Path:     []string{"reviews"},
-											Nullable: true,
-											Item: &Object{
-												Nullable: true,
-												Fields: []*Field{
-													{
-														Name: []byte("body"),
-														Value: &String{
-															Path: []string{"body"},
-														},
-													},
-													{
-														Name: []byte("product"),
-														Value: &Object{
-															Path: []string{"product"},
-															Fields: []*Field{
-																{
-																	Name: []byte("upc"),
-																	Value: &String{
-																		Path: []string{"upc"},
-																	},
-																},
-																{
-																	Name: []byte("name"),
-																	Value: &String{
-																		Path: []string{"name"},
-																	},
-																},
-															},
-														},
-													},
-												},
-											},
-										},
-									},
-								},
-							},
-						},
-					},
-				},
-			}, Context{ctx: context.Background(), Variables: nil}, `{"data":{"me":{"id":"1234","username":"Me","reviews":[{"body":"A highly effective form of birth control.","product":{"upc":"top-1","name":"Furby"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","name":"Trilby"}}]}}}`
-		}))
 		t.Run("federation with batch", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {
 			userService := NewMockDataSource(ctrl)
 			userService.EXPECT().
-				Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-				DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) {
+				Load(gomock.Any(), gomock.Any(), gomock.Any()).
+				DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 					actual := string(input)
 					expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}`
 					assert.Equal(t, expected, actual)
-					pair := NewBufPair()
-					pair.Data.WriteString(`{"me":{"id":"1234","username":"Me","__typename": "User"}}`)
-					return writeGraphqlResponse(pair, w, false)
+					return []byte(`{"data":{"me":{"id":"1234","username":"Me","__typename": "User"}}}`), nil
 				})

 			reviewsService := NewMockDataSource(ctrl)
 			reviewsService.EXPECT().
-				Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-				DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) {
+				Load(gomock.Any(), gomock.Any(), gomock.Any()).
+				DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 					actual := string(input)
 					expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}`
 					assert.Equal(t, expected, actual)
-					pair := NewBufPair()
-					pair.Data.WriteString(`{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}`)
-					return writeGraphqlResponse(pair, w, false)
+					return []byte(`{"data":{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}}`), nil
 				})

 			productService := NewMockDataSource(ctrl)
 			productService.EXPECT().
-				Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-				DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) {
+				Load(gomock.Any(), gomock.Any(), gomock.Any()).
+				DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 					actual := string(input)
 					expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}}}`
 					assert.Equal(t, expected, actual)
-					pair := NewBufPair()
-					pair.Data.WriteString(`{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}`)
-					return writeGraphqlResponse(pair, w, false)
+					return []byte(`{"data":{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}}`), nil
 				})

 			return &GraphQLResponse{
@@ -3241,38 +2993,32 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) {
 		t.Run("federation with merge paths", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {
 			userService := NewMockDataSource(ctrl)
 			userService.EXPECT().
-				Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-				DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) {
+				Load(gomock.Any(), gomock.Any(), gomock.Any()).
+				DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 					actual := string(input)
 					expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}`
 					assert.Equal(t, expected, actual)
-					pair := NewBufPair()
-					pair.Data.WriteString(`{"me":{"id":"1234","username":"Me","__typename": "User"}}`)
-					return writeGraphqlResponse(pair, w, false)
+					return []byte(`{"data":{"me":{"id":"1234","username":"Me","__typename": "User"}}}`), nil
 				})

 			reviewsService := NewMockDataSource(ctrl)
 			reviewsService.EXPECT().
-				Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-				DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) {
+				Load(gomock.Any(), gomock.Any(), gomock.Any()).
+				DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 					actual := string(input)
 					expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}`
 					assert.Equal(t, expected, actual)
-					pair := NewBufPair()
-					pair.Data.WriteString(`{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}`)
-					return writeGraphqlResponse(pair, w, false)
+					return []byte(`{"data":{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}}`), nil
 				})

 			productService := NewMockDataSource(ctrl)
 			productService.EXPECT().
-				Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-				DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) {
+				Load(gomock.Any(), gomock.Any(), gomock.Any()).
+				DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 					actual := string(input)
 					expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}}}`
 					assert.Equal(t, expected, actual)
-					pair := NewBufPair()
-					pair.Data.WriteString(`{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}`)
-					return writeGraphqlResponse(pair, w, false)
+					return []byte(`{"data":{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}}`), nil
 				})

 			return &GraphQLResponse{
@@ -3445,45 +3191,39 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) {
 		t.Run("federation with null response", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {
 			userService := NewMockDataSource(ctrl)
 			userService.EXPECT().
-				Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-				DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+				Load(gomock.Any(), gomock.Any(), gomock.Any()).
+				DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 					actual := string(input)
 					expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}`
 					assert.Equal(t, expected, actual)
-					pair := NewBufPair()
-					pair.Data.WriteString(`{"me":{"id":"1234","username":"Me","__typename": "User"}}`)
-					return writeGraphqlResponse(pair, w, false)
+					return []byte(`{"data":{"me":{"id":"1234","username":"Me","__typename": "User"}}}`), nil
 				})

 			reviewsService := NewMockDataSource(ctrl)
 			reviewsService.EXPECT().
-				Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-				DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+				Load(gomock.Any(), gomock.Any(), gomock.Any()).
+				DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 					actual := string(input)
 					expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}`
 					assert.Equal(t, expected, actual)
-					pair := NewBufPair()
-					pair.Data.WriteString(`{"_entities":[{"reviews": [
+					return []byte(`{"data":{"_entities":[{"reviews": [
 {"body": "foo","product": {"upc": "top-1","__typename": "Product"}},
 {"body": "bar","product": {"upc": "top-2","__typename": "Product"}},
 {"body": "baz","product": null},
 {"body": "bat","product": {"upc": "top-4","__typename": "Product"}},
 {"body": "bal","product": {"upc": "top-5","__typename": "Product"}},
 {"body": "ban","product": {"upc": "top-6","__typename": "Product"}}
-]}]}`)
-					return writeGraphqlResponse(pair, w, false)
+]}]}}`), nil
 				})

 			productService := NewMockDataSource(ctrl)
 			productService.EXPECT().
-				Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-				DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+				Load(gomock.Any(), gomock.Any(), gomock.Any()).
+				DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 					actual := string(input)
 					expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"},{"upc":"top-4","__typename":"Product"},{"upc":"top-5","__typename":"Product"},{"upc":"top-6","__typename":"Product"}]}}}`
 					assert.Equal(t, expected, actual)
-					pair := NewBufPair()
-					pair.Data.WriteString(`{"_entities":[{"name":"Trilby"},{"name":"Fedora"},{"name":"Boater"},{"name":"Top Hat"},{"name":"Bowler"}]}`)
-					return writeGraphqlResponse(pair, w, false)
+					return []byte(`{"data":{"_entities":[{"name":"Trilby"},{"name":"Fedora"},{"name":"Boater"},{"name":"Top Hat"},{"name":"Bowler"}]}}`), nil
 				})

 			return &GraphQLResponse{
@@ -3678,38 +3418,32 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) {

 			userService := NewMockDataSource(ctrl)
 			userService.EXPECT().
-				Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-				DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+				Load(gomock.Any(), gomock.Any(), gomock.Any()).
+				DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 					actual := string(input)
 					expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}`
 					assert.Equal(t, expected, actual)
-					pair := NewBufPair()
-					pair.Data.WriteString(`{"me": {"id": "1234","username": "Me","__typename": "User"}}`)
-					return writeGraphqlResponse(pair, w, false)
+					return []byte(`{"data":{"me": {"id": "1234","username": "Me","__typename": "User"}}}`), nil
 				})

 			reviewsService := NewMockDataSource(ctrl)
 			reviewsService.EXPECT().
-				Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-				DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+				Load(gomock.Any(), gomock.Any(), gomock.Any()).
+				DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 					actual := string(input)
 					expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}`
 					assert.Equal(t, expected, actual)
-					pair := NewBufPair()
-					pair.Data.WriteString(`{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product":{"upc": "top-1","__typename":"Product"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product"}}]}]}`)
-					return writeGraphqlResponse(pair, w, false)
+					return []byte(`{"data":{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product":{"upc": "top-1","__typename":"Product"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product"}}]}]}}`), nil
 				})

 			productService := NewMockDataSource(ctrl)
 			productService.EXPECT().
-				Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-				DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+				Load(gomock.Any(), gomock.Any(), gomock.Any()).
+				DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 					actual := string(input)
 					expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"}]}}}`
 					assert.Equal(t, expected, actual)
-					pair := NewBufPair()
-					pair.WriteErr([]byte("errorMessage"), nil, nil, nil)
-					return writeGraphqlResponse(pair, w, false)
+					return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil
 				})

 			return &GraphQLResponse{
@@ -3871,38 +3605,32 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) {

 			userService := NewMockDataSource(ctrl)
 			userService.EXPECT().
-				Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-				DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+				Load(gomock.Any(), gomock.Any(), gomock.Any()).
+				DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 					actual := string(input)
 					expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}`
 					assert.Equal(t, expected, actual)
-					pair := NewBufPair()
-					pair.Data.WriteString(`{"me": {"id": "1234","username": "Me","__typename": "User"}}`)
-					return writeGraphqlResponse(pair, w, false)
+					return []byte(`{"data":{"me": {"id": "1234","username": "Me","__typename": "User"}}}`), nil
 				})

 			reviewsService := NewMockDataSource(ctrl)
 			reviewsService.EXPECT().
-				Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-				DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+				Load(gomock.Any(), gomock.Any(), gomock.Any()).
+				DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 					actual := string(input)
 					expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}`
 					assert.Equal(t, expected, actual)
-					pair := NewBufPair()
-					pair.Data.WriteString(`{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product":{"upc": "top-1","__typename":"Product"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product"}}]}]}`)
-					return writeGraphqlResponse(pair, w, false)
+					return []byte(`{"data":{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product":{"upc": "top-1","__typename":"Product"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product"}}]}]}}`), nil
 				})

 			productService := NewMockDataSource(ctrl)
 			productService.EXPECT().
-				Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-				DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+				Load(gomock.Any(), gomock.Any(), gomock.Any()).
+				DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 					actual := string(input)
 					expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"}]}}}`
 					assert.Equal(t, expected, actual)
-					pair := NewBufPair()
-					pair.WriteErr([]byte("errorMessage"), nil, nil, nil)
-					return writeGraphqlResponse(pair, w, false)
+					return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil
 				})

 			return &GraphQLResponse{
@@ -4061,38 +3789,32 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) {
 		t.Run("federation with optional variable", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {
 			userService := NewMockDataSource(ctrl)
 			userService.EXPECT().
-				Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-				DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+				Load(gomock.Any(), gomock.Any(), gomock.Any()).
+				DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 					actual := string(input)
 					expected := `{"method":"POST","url":"http://localhost:8080/query","body":{"query":"{me {id}}"}}`
 					assert.Equal(t, expected, actual)
-					pair := NewBufPair()
-					pair.Data.WriteString(`{"me":{"id":"1234","__typename":"User"}}`)
-					return writeGraphqlResponse(pair, w, false)
+					return []byte(`{"data":{"me":{"id":"1234","__typename":"User"}}}`), nil
 				})

 			employeeService := NewMockDataSource(ctrl)
 			employeeService.EXPECT().
-				Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})).
-				DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) {
+				Load(gomock.Any(), gomock.Any(), gomock.Any()).
+				DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
 					actual := string(input)
 					expected := `{"method":"POST","url":"http://localhost:8081/query","body":{"query":"query($representations: [_Any!]!, $companyId: ID!){_entities(representations: $representations){...
on User {employment(companyId: $companyId){id}}}}","variables":{"companyId":"abc123","representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"employment":{"id":"xyz987"}}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"employment":{"id":"xyz987"}}]}}`), nil }) timeService := NewMockDataSource(ctrl) timeService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:8082/query","body":{"query":"query($representations: [_Any!]!, $date: LocalTime){_entities(representations: $representations){... on Employee {times(date: $date){id employee {id} start end}}}}","variables":{"date":null,"representations":[{"id":"xyz987","__typename":"Employee"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"times":[{"id": "t1","employee":{"id":"xyz987"},"start":"2022-11-02T08:00:00","end":"2022-11-02T12:00:00"}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"times":[{"id": "t1","employee":{"id":"xyz987"},"start":"2022-11-02T08:00:00","end":"2022-11-02T12:00:00"}]}]}}`), nil }) return &GraphQLResponse{ @@ -4263,62 +3985,517 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }) } -func TestResolver_ApolloCompatibilityMode_FetchError(t *testing.T) { - options := apolloCompatibilityOptions{ - valueCompletion: true, - suppressFetchErrors: true, +// testFnArena is a helper function for testing ArenaResolveGraphQLResponse +func testFnArena(fn func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string)) func(t *testing.T) { + return func(t *testing.T) { + t.Helper() + + ctrl := gomock.NewController(t) + rCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + r := newResolver(rCtx) + node, ctx, expectedOutput := fn(t, ctrl) + + if node.Info == nil { + node.Info = &GraphQLResponseInfo{ + OperationType: ast.OperationTypeQuery, + } + } + + if t.Skipped() { + return + } + + buf := &bytes.Buffer{} + _, err := r.ArenaResolveGraphQLResponse(&ctx, node, buf) + assert.NoError(t, err) + assert.Equal(t, expectedOutput, buf.String()) + ctrl.Finish() } - t.Run("simple fetch with fetch error suppression - empty response", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { - mockDataSource := NewMockDataSource(ctrl) - mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). 
- DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - _, _ = w.Write([]byte("{}")) - return - }) +} + +func TestResolver_ArenaResolveGraphQLResponse(t *testing.T) { + + t.Run("empty graphql response", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return &GraphQLResponse{ - Fetches: SingleWithPath(&SingleFetch{ - InputTemplate: InputTemplate{ - Segments: []TemplateSegment{ - { - Data: []byte(`{"method":"POST","url":"http://localhost:4001","body":{"query":"{query{name}}"}}`), - SegmentType: StaticSegmentType, - }, - }, - }, - FetchConfiguration: FetchConfiguration{ - DataSource: mockDataSource, - PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data"}, - SelectResponseErrorsPath: []string{"errors"}, - }, - }, - }, "query"), Data: &Object{ - Fields: []*Field{ - { - Name: []byte("name"), - Value: &String{ - Path: []string{"name"}, - }, - }, - }, + Nullable: true, }, - }, Context{ctx: context.Background()}, `{"data":null,"extensions":{"valueCompletion":[{"message":"Cannot return null for non-nullable field Query.name.","path":["name"],"extensions":{"code":"INVALID_GRAPHQL"}}]}}` - }, &options)) + }, Context{ctx: context.Background()}, `{"data":{}}` + })) - t.Run("simple fetch with fetch error suppression - response with error", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { - mockDataSource := NewMockDataSource(ctrl) - mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - _, _ = w.Write([]byte(`{"errors":[{"message":"Cannot query field 'name' on type 'Query'"}]}`)) - return - }) + t.Run("simple data source", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return &GraphQLResponse{ - Fetches: SingleWithPath(&SingleFetch{ - InputTemplate: InputTemplate{ - Segments: []TemplateSegment{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"id":"1","name":"Jens","registered":true}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &String{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + { + Name: []byte("registered"), + Value: &Boolean{ + Path: []string{"registered"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"user":{"id":"1","name":"Jens","registered":true}}}` + })) + + t.Run("array of strings", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"strings": ["Alex", "true", "123"]}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("strings"), + Value: &Array{ + Path: []string{"strings"}, + Item: &String{ + Nullable: false, + }, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"strings":["Alex","true","123"]}}` + })) + + t.Run("array of objects", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, 
expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"friends":[{"id":1,"name":"Alex"},{"id":2,"name":"Patric"}]}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("friends"), + Value: &Array{ + Path: []string{"friends"}, + Item: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Integer{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"friends":[{"id":1,"name":"Alex"},{"id":2,"name":"Patric"}]}}` + })) + + t.Run("nested objects", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"id":"1","name":"Jens","pet":{"name":"Barky","kind":"Dog"}}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &String{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + { + Name: []byte("pet"), + Value: &Object{ + Path: []string{"pet"}, + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + { + Name: []byte("kind"), + Value: &String{ + Path: []string{"kind"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"user":{"id":"1","name":"Jens","pet":{"name":"Barky","kind":"Dog"}}}}` + })) + + t.Run("scalar types", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"int": 12345, "float": 3.5, "str":"value", "bool": true}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("int"), + Value: &Integer{ + Path: []string{"int"}, + Nullable: false, + }, + }, + { + Name: []byte("float"), + Value: &Float{ + Path: []string{"float"}, + Nullable: false, + }, + }, + { + Name: []byte("str"), + Value: &String{ + Path: []string{"str"}, + Nullable: false, + }, + }, + { + Name: []byte("bool"), + Value: &Boolean{ + Path: []string{"bool"}, + Nullable: false, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"int":12345,"float":3.5,"str":"value","bool":true}}` + })) + + t.Run("null field", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("foo"), + Value: &Null{}, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"foo":null}}` + })) + + t.Run("__typename field", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"id":1,"name":"Jannik","__typename":"User"}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Integer{ + Path: 
[]string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + Nullable: false, + IsTypeName: true, + }, + }, + }, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"user":{"id":1,"name":"Jannik","__typename":"User"}}}` + })) + + t.Run("multiple fetches", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"user1":{"id":1,"name":"User1"},"user2":{"id":2,"name":"User2"}}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("user1"), + Value: &Object{ + Path: []string{"user1"}, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Integer{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + { + Name: []byte("user2"), + Value: &Object{ + Path: []string{"user2"}, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Integer{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"user1":{"id":1,"name":"User1"},"user2":{"id":2,"name":"User2"}}}` + })) + + t.Run("with variables", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + mockDataSource := NewMockDataSource(ctrl) + mockDataSource.EXPECT(). + Load(gomock.Any(), gomock.Any(), []byte(`{"id":1}`)). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"name":"Jens"}`), nil + }) + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: mockDataSource}, + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"id":`), + SegmentType: StaticSegmentType, + }, + { + Data: []byte(`{{.arguments.id}}`), + SegmentType: VariableSegmentType, + VariableKind: ContextVariableKind, + VariableSourcePath: []string{"id"}, + Renderer: NewPlainVariableRenderer(), + }, + { + Data: []byte(`}`), + SegmentType: StaticSegmentType, + }, + }, + }, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, Context{ctx: context.Background(), Variables: astjson.MustParseBytes([]byte(`{"id":1}`))}, `{"data":{"name":"Jens"}}` + })) + + t.Run("error handling", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + mockDataSource := NewMockDataSource(ctrl) + mockDataSource.EXPECT(). + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return nil, errors.New("data source error") + }) + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: mockDataSource}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"errors":[{"message":"Failed to fetch from Subgraph."}],"data":null}` + })) + + t.Run("bigint handling", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"n": 12345, "ns_small": "12346", "ns_big": "1152921504606846976"}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("n"), + Value: &BigInt{ + Path: []string{"n"}, + Nullable: false, + }, + }, + { + Name: []byte("ns_small"), + Value: &BigInt{ + Path: []string{"ns_small"}, + Nullable: false, + }, + }, + { + Name: []byte("ns_big"), + Value: &BigInt{ + Path: []string{"ns_big"}, + Nullable: false, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"n":12345,"ns_small":"12346","ns_big":"1152921504606846976"}}` + })) + + t.Run("skip loader", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("static"), + Value: &Null{}, + }, + }, + }, + }, Context{ctx: context.Background(), ExecutionOptions: ExecutionOptions{SkipLoader: true}}, `{"data":null}` + })) +} + +func TestResolver_ApolloCompatibilityMode_FetchError(t *testing.T) { + options := apolloCompatibilityOptions{ + valueCompletion: true, + suppressFetchErrors: true, + } + t.Run("simple fetch with fetch error suppression - empty response", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + mockDataSource := NewMockDataSource(ctrl) + mockDataSource.EXPECT(). + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte("{}"), nil + }) + return &GraphQLResponse{ + Fetches: SingleWithPath(&SingleFetch{ + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://localhost:4001","body":{"query":"{query{name}}"}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + FetchConfiguration: FetchConfiguration{ + DataSource: mockDataSource, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + SelectResponseErrorsPath: []string{"errors"}, + }, + }, + }, "query"), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":null,"extensions":{"valueCompletion":[{"message":"Cannot return null for non-nullable field Query.name.","path":["name"],"extensions":{"code":"INVALID_GRAPHQL"}}]}}` + }, &options)) + + t.Run("simple fetch with fetch error suppression - response with error", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + mockDataSource := NewMockDataSource(ctrl) + mockDataSource.EXPECT(). 
+ Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"Cannot query field 'name' on type 'Query'"}]}`), nil + }) + return &GraphQLResponse{ + Fetches: SingleWithPath(&SingleFetch{ + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ { Data: []byte(`{"method":"POST","url":"http://localhost:4001","body":{"query":"{query{name}}"}}`), SegmentType: StaticSegmentType, @@ -4349,38 +4526,32 @@ func TestResolver_ApolloCompatibilityMode_FetchError(t *testing.T) { t.Run("complex fetch with fetch error suppression", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me": {"id": "1234","username": "Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me": {"id": "1234","username": "Me","__typename": "User"}}}`), nil }) reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product":{"upc": "top-1","__typename":"Product"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product":{"upc": "top-1","__typename":"Product"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product"}}]}]}}`), nil }) productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ @@ -4566,14 +4737,12 @@ func TestResolver_WithHeader(t *testing.T) { ctrl := gomock.NewController(t) fakeService := NewMockDataSource(ctrl) fakeService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - Do(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) assert.Equal(t, "foo", actual) - _, err = w.Write([]byte(`{"bar":"baz"}`)) - return - }). - Return(nil) + return []byte(`{"bar":"baz"}`), nil + }) out := &bytes.Buffer{} res := &GraphQLResponse{ @@ -4639,14 +4808,12 @@ func TestResolver_WithVariableRemapping(t *testing.T) { ctrl := gomock.NewController(t) fakeService := NewMockDataSource(ctrl) fakeService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - Do(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) assert.Equal(t, tc.expectedOutput, actual) - _, err = w.Write([]byte(`{"bar":"baz"}`)) - return - }). - Return(nil) + return []byte(`{"bar":"baz"}`), nil + }) out := &bytes.Buffer{} res := &GraphQLResponse{ @@ -4827,16 +4994,7 @@ func (f *_fakeStream) AwaitIsDone(t *testing.T, timeout time.Duration) { } } -func (f *_fakeStream) UniqueRequestID(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) { - _, err = fmt.Fprint(xxh, fakeStreamRequestId.Add(1)) - if err != nil { - return - } - _, err = xxh.Write(input) - return -} - -func (f *_fakeStream) Start(ctx *Context, input []byte, updater SubscriptionUpdater) error { +func (f *_fakeStream) Start(ctx *Context, headers http.Header, input []byte, updater SubscriptionUpdater) error { if f.onStart != nil { f.onStart(input) } @@ -5909,50 +6067,353 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { Data: []byte(`{"method":"POST","url":"http://localhost:4000"}`), }, }, - }, - }, - Filter: &SubscriptionFilter{ - In: &SubscriptionFieldFilter{ - FieldPath: []string{"id"}, - Values: []InputTemplate{ - { + }, + }, + Filter: &SubscriptionFilter{ + In: &SubscriptionFieldFilter{ + FieldPath: []string{"id"}, + Values: []InputTemplate{ + { + Segments: []TemplateSegment{ + { + SegmentType: StaticSegmentType, + Data: []byte(`x.`), + }, + { + SegmentType: VariableSegmentType, + VariableKind: ContextVariableKind, + VariableSourcePath: []string{"a"}, + Renderer: NewPlainVariableRenderer(), + }, + { + SegmentType: StaticSegmentType, + Data: []byte(`.`), + }, + { + SegmentType: VariableSegmentType, + VariableKind: ContextVariableKind, + VariableSourcePath: []string{"b"}, + Renderer: NewPlainVariableRenderer(), + }, + }, + }, + }, + }, + }, + Response: &GraphQLResponse{ + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("oneUserByID"), + Value: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &String{ + Path: []string{"id"}, + }, + }, + }, + }, + }, + }, + }, + }, + } + + out := &SubscriptionRecorder{ 
+ buf: &bytes.Buffer{}, + messages: []string{}, + complete: atomic.Bool{}, + } + out.complete.Store(false) + + id := SubscriptionIdentifier{ + ConnectionID: 1, + SubscriptionID: 1, + } + + resolver := newResolver(c) + + ctx := &Context{ + ctx: context.Background(), + Variables: astjson.MustParseBytes([]byte(`{"a":[1,2],"b":[3,4]}`)), + } + + err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, out, id) + assert.NoError(t, err) + out.AwaitComplete(t, defaultTimeout) + assert.Equal(t, 4, len(out.Messages())) + assert.ElementsMatch(t, []string{ + `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, + `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, + `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, + `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, + }, out.Messages()) + }) +} + +func Benchmark_NestedBatching(b *testing.B) { + rCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := newResolver(rCtx) + + productsService := fakeDataSourceWithInputCheck(b, + []byte(`{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`), + []byte(`{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`)) + stockService := fakeDataSourceWithInputCheck(b, + []byte(`{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`), + []byte(`{"data":{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}}`)) + reviewsService := fakeDataSourceWithInputCheck(b, + []byte(`{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`), + []byte(`{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}}`)) + usersService := fakeDataSourceWithInputCheck(b, + []byte(`{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"2"}]}}}`), + []byte(`{"data":{"_entities":[{"name":"user-1"},{"name":"user-2"}]}}`)) + + plan := &GraphQLResponse{ + Fetches: Sequence( + SingleWithPath(&SingleFetch{ + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + FetchConfiguration: FetchConfiguration{ + DataSource: productsService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + }, + }, + }, ""), + Parallel( + SingleWithPath(&BatchEntityFetch{ + Input: BatchInput{ + Header: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[`), + SegmentType: StaticSegmentType, + }, + }, + }, + Items: []InputTemplate{ + { + Segments: []TemplateSegment{ + { + SegmentType: VariableSegmentType, + VariableKind: ResolvableObjectVariableKind, + Renderer: NewGraphQLVariableResolveRenderer(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("upc"), + Value: &String{ + Path: []string{"upc"}, + }, + }, + }, + }), + }, + }, + }, + }, + Separator: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`,`), + SegmentType: StaticSegmentType, + }, + }, + }, + Footer: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`]}}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + }, + DataSource: reviewsService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data", "_entities"}, + }, + }, "topProducts", ArrayPath("topProducts")), + SingleWithPath(&BatchEntityFetch{ + Input: BatchInput{ + Header: InputTemplate{ Segments: []TemplateSegment{ { + Data: []byte(`{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on Product {stock}}}","variables":{"representations":[`), SegmentType: StaticSegmentType, - Data: []byte(`x.`), }, + }, + }, + Items: []InputTemplate{ + { + Segments: []TemplateSegment{ + { + SegmentType: VariableSegmentType, + VariableKind: ResolvableObjectVariableKind, + Renderer: NewGraphQLVariableResolveRenderer(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("upc"), + Value: &String{ + Path: []string{"upc"}, + }, + }, + }, + }), + }, + }, + }, + }, + Separator: InputTemplate{ + Segments: []TemplateSegment{ { - SegmentType: VariableSegmentType, - VariableKind: ContextVariableKind, - VariableSourcePath: []string{"a"}, - Renderer: NewPlainVariableRenderer(), + Data: []byte(`,`), + SegmentType: StaticSegmentType, }, + }, + }, + Footer: InputTemplate{ + Segments: []TemplateSegment{ { + Data: []byte(`]}}}`), SegmentType: StaticSegmentType, - Data: []byte(`.`), }, + }, + }, + }, + DataSource: stockService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data", "_entities"}, + }, + }, "topProducts", ArrayPath("topProducts")), + ), + SingleWithPath(&BatchEntityFetch{ + Input: BatchInput{ + Header: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name}}}","variables":{"representations":[`), + SegmentType: StaticSegmentType, + }, + }, + }, + Items: []InputTemplate{ + { + Segments: []TemplateSegment{ { - SegmentType: VariableSegmentType, - VariableKind: ContextVariableKind, - VariableSourcePath: []string{"b"}, - Renderer: NewPlainVariableRenderer(), + SegmentType: VariableSegmentType, + VariableKind: ResolvableObjectVariableKind, + Renderer: NewGraphQLVariableResolveRenderer(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("id"), + Value: &String{ + Path: []string{"id"}, + }, + }, + }, + }), }, }, }, }, + Separator: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`,`), + SegmentType: StaticSegmentType, + }, + }, + }, + Footer: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`]}}}`), + SegmentType: StaticSegmentType, + }, + }, + }, }, - }, - Response: &GraphQLResponse{ - Data: &Object{ - Fields: []*Field{ - { - Name: []byte("oneUserByID"), - Value: &Object{ - Fields: []*Field{ - { - Name: []byte("id"), - Value: &String{ - Path: []string{"id"}, + DataSource: usersService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data", "_entities"}, + }, + }, "topProducts.@.reviews.@.author", ArrayPath("topProducts"), ArrayPath("reviews"), ObjectPath("author")), + ), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("topProducts"), + Value: &Array{ + Path: []string{"topProducts"}, + Item: &Object{ + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + }, + }, + { + Name: []byte("stock"), + Value: &Integer{ + Path: []string{"stock"}, + }, + }, + { + Name: []byte("reviews"), + Value: &Array{ + Path: []string{"reviews"}, + Item: &Object{ + Fields: []*Field{ + { + Name: []byte("body"), + Value: &String{ + Path: []string{"body"}, + }, + }, + { + Name: []byte("author"), + Value: &Object{ + Path: []string{"author"}, + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + }, + 
}, + }, + }, + }, + }, }, }, }, @@ -5961,41 +6422,53 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { }, }, }, - } + }, + Info: &GraphQLResponseInfo{ + OperationType: ast.OperationTypeQuery, + }, + } - out := &SubscriptionRecorder{ - buf: &bytes.Buffer{}, - messages: []string{}, - complete: atomic.Bool{}, - } - out.complete.Store(false) + expected := []byte(`{"data":{"topProducts":[{"name":"Table","stock":8,"reviews":[{"body":"Love Table!","author":{"name":"user-1"}},{"body":"Prefer other Table.","author":{"name":"user-2"}}]},{"name":"Couch","stock":2,"reviews":[{"body":"Couch Too expensive.","author":{"name":"user-1"}}]},{"name":"Chair","stock":5,"reviews":[{"body":"Chair Could be better.","author":{"name":"user-2"}}]}]}}`) - id := SubscriptionIdentifier{ - ConnectionID: 1, - SubscriptionID: 1, - } + pool := sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(make([]byte, 0, 1024)) + }, + } - resolver := newResolver(c) + ctxPool := sync.Pool{ + New: func() interface{} { + return NewContext(context.Background()) + }, + } - ctx := &Context{ - ctx: context.Background(), - Variables: astjson.MustParseBytes([]byte(`{"a":[1,2],"b":[3,4]}`)), - } + b.ReportAllocs() + b.SetBytes(int64(len(expected))) + b.ResetTimer() - err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, out, id) - assert.NoError(t, err) - out.AwaitComplete(t, defaultTimeout) - assert.Equal(t, 4, len(out.Messages())) - assert.ElementsMatch(t, []string{ - `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, - `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, - `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, - `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, - }, out.Messages()) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + ctx := ctxPool.Get().(*Context) + buf := pool.Get().(*bytes.Buffer) + ctx.ctx = context.Background() + _, err := resolver.ResolveGraphQLResponse(ctx, plan, nil, buf) + if err != nil { + b.Fatal(err) + } + if !bytes.Equal(expected, buf.Bytes()) { + require.Equal(b, string(expected), buf.String()) + } + + buf.Reset() + pool.Put(buf) + + ctx.Free() + ctxPool.Put(ctx) + } }) } -func Benchmark_NestedBatching(b *testing.B) { +func Benchmark_NestedBatchingArena(b *testing.B) { rCtx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -6293,7 +6766,7 @@ func Benchmark_NestedBatching(b *testing.B) { ctx := ctxPool.Get().(*Context) buf := pool.Get().(*bytes.Buffer) ctx.ctx = context.Background() - _, err := resolver.ResolveGraphQLResponse(ctx, plan, nil, buf) + _, err := resolver.ArenaResolveGraphQLResponse(ctx, plan, buf) if err != nil { b.Fatal(err) } @@ -6310,7 +6783,7 @@ func Benchmark_NestedBatching(b *testing.B) { }) } -func Benchmark_NestedBatchingWithoutChecks(b *testing.B) { +func Benchmark_NoCheckNestedBatching(b *testing.B) { rCtx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/v2/pkg/engine/resolve/response.go b/v2/pkg/engine/resolve/response.go index b98f4c00f..d8af8d017 100644 --- a/v2/pkg/engine/resolve/response.go +++ b/v2/pkg/engine/resolve/response.go @@ -22,6 +22,8 @@ type GraphQLSubscriptionTrigger struct { Source SubscriptionDataSource PostProcessing PostProcessingConfiguration QueryPlan *QueryPlan + SourceName string + SourceID string } // GraphQLResponse contains an ordered tree of fetches and the response shape. 
@@ -41,6 +43,19 @@ type GraphQLResponse struct { DataSources []DataSourceInfo } +func (g *GraphQLResponse) SingleFlightAllowed() bool { + if g == nil { + return false + } + if g.Info == nil { + return false + } + if g.Info.OperationType == ast.OperationTypeQuery { + return true + } + return false +} + type GraphQLResponseInfo struct { OperationType ast.OperationType } diff --git a/v2/pkg/engine/resolve/subgraph_request_singleflight.go b/v2/pkg/engine/resolve/subgraph_request_singleflight.go new file mode 100644 index 000000000..013d90677 --- /dev/null +++ b/v2/pkg/engine/resolve/subgraph_request_singleflight.go @@ -0,0 +1,192 @@ +package resolve + +import ( + "sync" + + "github.com/cespare/xxhash/v2" +) + +// SubgraphRequestSingleFlight is a sharded, goroutine-safe single-flight implementation to de-duplicate subgraph requests +// It hashes the input and adds the pre-computed subgraph headers hash to avoid collisions +// In addition to single flight, it provides size hints to create right-sized buffers for subgraph requests +type SubgraphRequestSingleFlight struct { + shards []singleFlightShard + xxPool *sync.Pool + cleanup chan func() +} + +type singleFlightShard struct { + mu sync.RWMutex + items map[uint64]*SingleFlightItem + sizes map[uint64]*fetchSize +} + +const defaultSingleFlightShardCount = 4 + +// SingleFlightItem is used to communicate between leader and followers +// If an Item for a key doesn't exist, the leader creates it and followers can join +type SingleFlightItem struct { + // loaded will be closed by the leader to indicate to followers when the work is done + loaded chan struct{} + // response is the shared result; it must not be modified + response []byte + // err is non-nil if the leader produced an error while doing the work + err error + // sizeHint is an estimate based on the last 50 responses per fetchKey + // this gives a leader a hint on how much space it should pre-allocate for buffers when fetching + // this reduces memory usage + sizeHint int +} + +// fetchSize gives an estimate of the required buffer size for a given fetchKey by dividing totalBytes by count +type fetchSize struct { + // count is the number of fetches tracked + count int + // totalBytes is the cumulative bytes across tracked fetches + totalBytes int +} +
+func NewSingleFlight(shardCount int) *SubgraphRequestSingleFlight { + if shardCount <= 0 { + shardCount = defaultSingleFlightShardCount + } + s := &SubgraphRequestSingleFlight{ + shards: make([]singleFlightShard, shardCount), + xxPool: &sync.Pool{ + New: func() any { + return xxhash.New() + }, + }, + cleanup: make(chan func()), + } + for i := range s.shards { + s.shards[i] = singleFlightShard{ + items: make(map[uint64]*SingleFlightItem), + sizes: make(map[uint64]*fetchSize), + } + } + return s +} + +// GetOrCreateItem generates a single flight key (100% identical fetches) and a fetchKey (similar fetches, collisions possible but unproblematic) +// and returns a SingleFlightItem as well as an indication of whether it's shared +// If shared == false, the caller is a leader +// If shared == true, the caller is a follower +// item.sizeHint can be used by a leader to create an optimal buffer for the fetch +// item.err must always be checked +// item.response must never be mutated +func (s *SubgraphRequestSingleFlight) GetOrCreateItem(fetchItem *FetchItem, input []byte, extraKey uint64) (sfKey, fetchKey uint64, item *SingleFlightItem, shared bool) { + sfKey, fetchKey = s.keys(fetchItem, input, extraKey) + + // Get shard based on sfKey for items + shard := s.shardFor(sfKey) + + // First, try to get the item with a read lock on its shard + shard.mu.RLock() + item, exists := shard.items[sfKey] + shard.mu.RUnlock() + if exists { + return sfKey, fetchKey, item, true + } + + // If it doesn't exist, acquire a write lock to create the item + shard.mu.Lock() + // Double-check if the item was created while acquiring the write lock + item, exists = shard.items[sfKey] + if exists { + shard.mu.Unlock() + return sfKey, fetchKey, item, true + } + + // Create a new item + item = &SingleFlightItem{ + // empty chan to indicate to all followers when we're done (close) + loaded: make(chan struct{}), + } + // Read size hint from the same shard (both items and sizes use the same shard now) + if size, ok := shard.sizes[fetchKey]; ok { + item.sizeHint = size.totalBytes / size.count + } + shard.items[sfKey] = item + shard.mu.Unlock() + return sfKey, fetchKey, item, false +} +
+func (s *SubgraphRequestSingleFlight) keys(fetchItem *FetchItem, input []byte, extraKey uint64) (sfKey, fetchKey uint64) { + h := s.xxPool.Get().(*xxhash.Digest) + sfKey = s.sfKey(h, fetchItem, input, extraKey) + h.Reset() + fetchKey = s.fetchKey(h, fetchItem) + h.Reset() + s.xxPool.Put(h) + return sfKey, fetchKey +} + +// sfKey returns a key that uniquely identifies a fetch with no collisions +// two sfKeys are equal only when the fetches are 100% identical +func (s *SubgraphRequestSingleFlight) sfKey(h *xxhash.Digest, fetchItem *FetchItem, input []byte, extraKey uint64) uint64 { + if fetchItem != nil && fetchItem.Fetch != nil { + info := fetchItem.Fetch.FetchInfo() + if info != nil { + _, _ = h.WriteString(info.DataSourceID) + _, _ = h.WriteString(":") + } + } + _, _ = h.Write(input) + return h.Sum64() + extraKey // extraKey in this case is the pre-generated hash for the headers +} + +// fetchKey is a less robust key compared to sfKey +// the purpose is to create a key from the DataSourceID and root fields to get lower cardinality +// the goal is to get an estimated buffer size for similar fetches +// there's no point in hashing headers or the body for this purpose +func (s *SubgraphRequestSingleFlight) fetchKey(h *xxhash.Digest, fetchItem *FetchItem) uint64 { + if fetchItem == nil || fetchItem.Fetch == nil { + return 0 + } + info := fetchItem.Fetch.FetchInfo() + if info == nil { + return 0 + } + _, _ = h.WriteString(info.DataSourceID) + _, _ = h.Write(pipe) + for i := range info.RootFields { + if i != 0 { + _, _ = h.Write(comma) + } + _, _ = h.WriteString(info.RootFields[i].TypeName) + _, _ = h.Write(dot) + _, _ = h.WriteString(info.RootFields[i].FieldName) + } + return h.Sum64() +} +
+// Finish is for the leader to mark the SingleFlightItem as "done", +// trigger all followers to look at the item's err & response, +// and to update the size estimates +func (s *SubgraphRequestSingleFlight) Finish(sfKey, fetchKey uint64, item *SingleFlightItem) { + close(item.loaded) + // Update sizes in the same shard as the item (using sfKey to get the shard) + shard := s.shardFor(sfKey) + shard.mu.Lock() + delete(shard.items, sfKey) + if size, ok := shard.sizes[fetchKey]; ok { + if size.count == 50 { + size.count = 1 + size.totalBytes = size.totalBytes / 50 + } + size.count++ + size.totalBytes += len(item.response) + } else { + shard.sizes[fetchKey] = &fetchSize{ + count: 1, + totalBytes: len(item.response), + } + } + shard.mu.Unlock() +} + +func (s *SubgraphRequestSingleFlight) shardFor(key uint64) *singleFlightShard { + idx := int(key % uint64(len(s.shards))) + return
&s.shards[idx] + } diff --git a/v2/pkg/engine/resolve/tainted_objects_test.go b/v2/pkg/engine/resolve/tainted_objects_test.go index 0eeb34440..b8205dc72 100644 --- a/v2/pkg/engine/resolve/tainted_objects_test.go +++ b/v2/pkg/engine/resolve/tainted_objects_test.go @@ -70,7 +70,7 @@ func TestSelectObjectAndIndex(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - response, err := astjson.ParseBytesWithoutCache([]byte(tt.responseJSON)) + response, err := astjson.ParseBytes([]byte(tt.responseJSON)) assert.NoError(t, err, "Failed to parse response JSON") // Convert path elements to astjson.Value slice @@ -94,7 +94,7 @@ func TestSelectObjectAndIndex(t *testing.T) { assert.Nil(t, entity, "Expected nil entity") } else { assert.NotNil(t, entity, "Expected non-nil entity") - expectedEntity, err := astjson.ParseBytesWithoutCache([]byte(tt.expectedEntity)) + expectedEntity, err := astjson.ParseBytes([]byte(tt.expectedEntity)) assert.NoError(t, err, "Failed to parse expected entity JSON") // Compare JSON representations @@ -320,10 +320,10 @@ func TestGetTaintedIndices(t *testing.T) { } mockFetch := &mockFetchWithInfo{info: fetchInfo} - response, err := astjson.ParseBytesWithoutCache([]byte(tt.responseJSON)) + response, err := astjson.ParseBytes([]byte(tt.responseJSON)) assert.NoError(t, err, "Failed to parse response JSON") - errors, err := astjson.ParseBytesWithoutCache([]byte(tt.errorsJSON)) + errors, err := astjson.ParseBytes([]byte(tt.errorsJSON)) assert.NoError(t, err, "Failed to parse errors JSON") indices := getTaintedIndices(mockFetch, response, errors) diff --git a/v2/pkg/engine/resolve/variables.go b/v2/pkg/engine/resolve/variables.go index afc00459a..3f54993d9 100644 --- a/v2/pkg/engine/resolve/variables.go +++ b/v2/pkg/engine/resolve/variables.go @@ -11,7 +11,6 @@ const ( ObjectVariableKind HeaderVariableKind ResolvableObjectVariableKind - ListVariableKind ) const ( diff --git a/v2/pkg/engine/resolve/variables_renderer.go b/v2/pkg/engine/resolve/variables_renderer.go index 4cbb471f8..572892557 100644 --- a/v2/pkg/engine/resolve/variables_renderer.go +++ b/v2/pkg/engine/resolve/variables_renderer.go @@ -277,6 +277,82 @@ func (g *GraphQLVariableRenderer) renderGraphQLValue(data *astjson.Value, out io return } +func NewCacheKeyVariableRenderer() *CacheKeyVariableRenderer { + return &CacheKeyVariableRenderer{} +} + +type CacheKeyVariableRenderer struct { +} + +func (g *CacheKeyVariableRenderer) GetKind() string { + return "cacheKey" +} + +// add a renderer that renders both the variable name and the variable value +// before rendering, evaluate whether the value contains null values +// if an object contains only null values, set the object to null +// do this recursively until reaching the root of the object + +func (g *CacheKeyVariableRenderer) RenderVariable(ctx context.Context, data *astjson.Value, out io.Writer) error { + return g.renderGraphQLValue(data, out) +} + +func (g *CacheKeyVariableRenderer) renderGraphQLValue(data *astjson.Value, out io.Writer) (err error) { + if data == nil { + _, _ = out.Write(literal.NULL) + return + } + switch data.Type() { + case astjson.TypeString: + b := data.GetStringBytes() + _, _ = out.Write(b) + case astjson.TypeObject: + _, _ = out.Write(literal.LBRACE) + o := data.GetObject() + first := true + o.Visit(func(k []byte, v *astjson.Value) { + if err != nil { + return + } + if !first { + _, _ = out.Write(literal.COMMA) + } else { + first = false + } + _, _ = out.Write(k) + _, _ = out.Write(literal.COLON) + err = g.renderGraphQLValue(v, out) +
}) + if err != nil { + return err + } + _, _ = out.Write(literal.RBRACE) + case astjson.TypeNull: + _, _ = out.Write(literal.NULL) + case astjson.TypeTrue: + _, _ = out.Write(literal.TRUE) + case astjson.TypeFalse: + _, _ = out.Write(literal.FALSE) + case astjson.TypeArray: + _, _ = out.Write(literal.LBRACK) + arr := data.GetArray() + for i, value := range arr { + if i > 0 { + _, _ = out.Write(literal.COMMA) + } + err = g.renderGraphQLValue(value, out) + if err != nil { + return err + } + } + _, _ = out.Write(literal.RBRACK) + case astjson.TypeNumber: + b := data.MarshalTo(nil) + _, _ = out.Write(b) + } + return +} + func NewCSVVariableRenderer(arrayValueType JsonRootType) *CSVVariableRenderer { return &CSVVariableRenderer{ Kind: VariableRendererKindCsv, @@ -350,7 +426,7 @@ var ( func (g *GraphQLVariableResolveRenderer) getResolvable() *Resolvable { v := _graphQLVariableResolveRendererPool.Get() if v == nil { - return NewResolvable(ResolvableOptions{}) + return NewResolvable(nil, ResolvableOptions{}) } return v.(*Resolvable) } diff --git a/v2/pkg/fastjsonext/fastjsonext.go b/v2/pkg/fastjsonext/fastjsonext.go index 0480fcbd4..4929e8a96 100644 --- a/v2/pkg/fastjsonext/fastjsonext.go +++ b/v2/pkg/fastjsonext/fastjsonext.go @@ -2,27 +2,28 @@ package fastjsonext import ( "github.com/wundergraph/astjson" + "github.com/wundergraph/go-arena" ) -func AppendErrorToArray(arena *astjson.Arena, v *astjson.Value, msg string, path []PathElement) { +func AppendErrorToArray(a arena.Arena, v *astjson.Value, msg string, path []PathElement) { if v.Type() != astjson.TypeArray { return } - errorObject := CreateErrorObjectWithPath(arena, msg, path) + errorObject := CreateErrorObjectWithPath(a, msg, path) items, _ := v.Array() - v.SetArrayItem(len(items), errorObject) + v.SetArrayItem(a, len(items), errorObject) } -func AppendErrorWithExtensionsCodeToArray(arena *astjson.Arena, v *astjson.Value, msg, code string, path []PathElement) { +func AppendErrorWithExtensionsCodeToArray(a arena.Arena, v *astjson.Value, msg, code string, path []PathElement) { if v.Type() != astjson.TypeArray { return } - errorObject := CreateErrorObjectWithPath(arena, msg, path) - extensions := arena.NewObject() - extensions.Set("code", arena.NewString(code)) - errorObject.Set("extensions", extensions) + errorObject := CreateErrorObjectWithPath(a, msg, path) + extensions := astjson.ObjectValue(a) + extensions.Set(a, "code", astjson.StringValue(a, code)) + errorObject.Set(a, "extensions", extensions) items, _ := v.Array() - v.SetArrayItem(len(items), errorObject) + v.SetArrayItem(a, len(items), errorObject) } type PathElement struct { @@ -30,29 +31,29 @@ type PathElement struct { Idx int } -func CreateErrorObjectWithPath(arena *astjson.Arena, message string, path []PathElement) *astjson.Value { - errorObject := arena.NewObject() - errorObject.Set("message", arena.NewString(message)) +func CreateErrorObjectWithPath(a arena.Arena, message string, path []PathElement) *astjson.Value { + errorObject := astjson.ObjectValue(a) + errorObject.Set(a, "message", astjson.StringValue(a, message)) if len(path) == 0 { return errorObject } - errorPath := arena.NewArray() + errorPath := astjson.ArrayValue(a) for i := range path { if path[i].Name != "" { - errorPath.SetArrayItem(i, arena.NewString(path[i].Name)) + errorPath.SetArrayItem(a, i, astjson.StringValue(a, path[i].Name)) } else { - errorPath.SetArrayItem(i, arena.NewNumberInt(path[i].Idx)) + errorPath.SetArrayItem(a, i, astjson.IntValue(a, path[i].Idx)) } } - errorObject.Set("path", errorPath) + 
errorObject.Set(a, "path", errorPath) return errorObject } func PrintGraphQLResponse(data, errors *astjson.Value) string { out := astjson.MustParse(`{}`) if astjson.ValueIsNonNull(errors) { - out.Set("errors", errors) + out.Set(nil, "errors", errors) } - out.Set("data", data) + out.Set(nil, "data", data) return string(out.MarshalTo(nil)) } diff --git a/v2/pkg/fastjsonext/fastjsonext_test.go b/v2/pkg/fastjsonext/fastjsonext_test.go index af4271630..e48a2ad1c 100644 --- a/v2/pkg/fastjsonext/fastjsonext_test.go +++ b/v2/pkg/fastjsonext/fastjsonext_test.go @@ -21,28 +21,28 @@ func TestGetArray(t *testing.T) { func TestAppendErrorWithMessage(t *testing.T) { a := astjson.MustParse(`[]`) - AppendErrorToArray(&astjson.Arena{}, a, "error", nil) + AppendErrorToArray(nil, a, "error", nil) out := a.MarshalTo(nil) require.Equal(t, `[{"message":"error"}]`, string(out)) - AppendErrorToArray(&astjson.Arena{}, a, "error2", []PathElement{{Name: "a"}}) + AppendErrorToArray(nil, a, "error2", []PathElement{{Name: "a"}}) out = a.MarshalTo(nil) require.Equal(t, `[{"message":"error"},{"message":"error2","path":["a"]}]`, string(out)) } func TestCreateErrorObjectWithPath(t *testing.T) { - v := CreateErrorObjectWithPath(&astjson.Arena{}, "my error message", []PathElement{ + v := CreateErrorObjectWithPath(nil, "my error message", []PathElement{ {Name: "a"}, }) out := v.MarshalTo(nil) require.Equal(t, `{"message":"my error message","path":["a"]}`, string(out)) - v = CreateErrorObjectWithPath(&astjson.Arena{}, "my error message", []PathElement{ + v = CreateErrorObjectWithPath(nil, "my error message", []PathElement{ {Name: "a"}, {Idx: 1}, {Name: "b"}, }) out = v.MarshalTo(nil) require.Equal(t, `{"message":"my error message","path":["a",1,"b"]}`, string(out)) - v = CreateErrorObjectWithPath(&astjson.Arena{}, "my error message", []PathElement{ + v = CreateErrorObjectWithPath(nil, "my error message", []PathElement{ {Name: "a"}, {Name: "b"}, }) diff --git a/v2/pkg/variablesvalidation/variablesvalidation.go b/v2/pkg/variablesvalidation/variablesvalidation.go index d9631739e..b1af4f40e 100644 --- a/v2/pkg/variablesvalidation/variablesvalidation.go +++ b/v2/pkg/variablesvalidation/variablesvalidation.go @@ -98,7 +98,7 @@ func (v *VariablesValidator) ValidateWithRemap(operation, definition *ast.Docume func (v *VariablesValidator) Validate(operation, definition *ast.Document, variables []byte) error { v.visitor.definition = definition v.visitor.operation = operation - v.visitor.variables, v.visitor.err = astjson.ParseBytesWithoutCache(variables) + v.visitor.variables, v.visitor.err = astjson.ParseBytes(variables) if v.visitor.err != nil { return v.visitor.err }
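
A note on the data source contract this patch migrates to: throughout the test changes above, Load no longer receives an io.Writer and instead takes the per-request headers and returns the response bytes. Below is a minimal sketch of a data source under the new signature, derived from the DoAndReturn callbacks in this diff; the type name staticDataSource is a hypothetical illustration, not part of the change.

package resolve_test

import (
	"context"
	"net/http"
)

// staticDataSource satisfies the refactored Load signature:
// headers in, response bytes out, no writer.
type staticDataSource struct {
	response []byte
}

func (s *staticDataSource) Load(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
	// A real implementation would build the upstream request from the
	// input JSON (method/url/body) and attach the supplied headers;
	// this stub returns a canned response, like FakeDataSource above.
	return s.response, nil
}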
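
How SubgraphRequestSingleFlight is meant to be driven is implied, but not shown, by this diff. The following is a hedged sketch of the leader/follower call pattern suggested by GetOrCreateItem and Finish; loadWithSingleFlight and its doFetch callback are assumptions, and the code would have to live in package resolve because it reads unexported fields.

// Hypothetical caller-side flow (not part of this patch). Exactly one
// goroutine per sfKey becomes the leader and performs the fetch; all
// followers wait on the loaded channel and share the leader's result.
func loadWithSingleFlight(sf *SubgraphRequestSingleFlight, fetchItem *FetchItem, input []byte, headersHash uint64, doFetch func(sizeHint int) ([]byte, error)) ([]byte, error) {
	sfKey, fetchKey, item, shared := sf.GetOrCreateItem(fetchItem, input, headersHash)
	if shared {
		// Follower: wait for the leader to close item.loaded, then read
		// the shared result; item.response must never be mutated.
		<-item.loaded
		return item.response, item.err
	}
	// Leader: item.sizeHint suggests how large a buffer to pre-allocate.
	item.response, item.err = doFetch(item.sizeHint)
	// Finish closes item.loaded, drops the in-flight entry, and feeds the
	// response size into the rolling estimate for this fetchKey.
	sf.Finish(sfKey, fetchKey, item)
	return item.response, item.err
}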
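
CacheKeyVariableRenderer writes object keys and string values without quotes, so its output is a compact key fragment rather than valid JSON. A small standalone illustration follows, assuming astjson's Visit preserves key order; the surrounding cache-key plumbing is not part of this diff.

package resolve_test

import (
	"bytes"
	"context"
	"fmt"

	"github.com/wundergraph/astjson"
	"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
)

func ExampleCacheKeyVariableRenderer() {
	v := astjson.MustParse(`{"__typename":"Product","upc":"top-1","inStock":true}`)
	var buf bytes.Buffer
	r := resolve.NewCacheKeyVariableRenderer()
	// Strings render unquoted; objects and arrays render recursively.
	_ = r.RenderVariable(context.Background(), v, &buf)
	fmt.Println(buf.String())
	// Output: {__typename:Product,upc:top-1,inStock:true}
}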