From ffa7fb266d5c6d429522ca16205cae9e5efb671f Mon Sep 17 00:00:00 2001 From: Thomas Stromberg Date: Thu, 4 Dec 2025 22:01:06 -0500 Subject: [PATCH] improve mock accuracy --- pkg/datastore/iterator.go | 4 +- pkg/datastore/query.go | 17 +- pkg/mock/mock.go | 612 ++++++++++++++++++++++++++++---------- pkg/mock/mock_test.go | 313 +++++++++++++++++++ 4 files changed, 789 insertions(+), 157 deletions(-) diff --git a/pkg/datastore/iterator.go b/pkg/datastore/iterator.go index 2e39d9f..a79424f 100644 --- a/pkg/datastore/iterator.go +++ b/pkg/datastore/iterator.go @@ -145,8 +145,10 @@ func (it *Iterator) fetch() error { it.index = 0 // Check if there are more results + // MORE_RESULTS_AFTER_LIMIT means we hit the query limit - don't auto-fetch more + // NOT_FINISHED and MORE_RESULTS_AFTER_CURSOR mean we should continue fetching moreResults := result.Batch.MoreResults - it.fetchNext = moreResults == "NOT_FINISHED" || moreResults == "MORE_RESULTS_AFTER_LIMIT" || moreResults == "MORE_RESULTS_AFTER_CURSOR" + it.fetchNext = moreResults == "NOT_FINISHED" || moreResults == "MORE_RESULTS_AFTER_CURSOR" if result.Batch.EndCursor != "" { it.cursor = Cursor(result.Batch.EndCursor) diff --git a/pkg/datastore/query.go b/pkg/datastore/query.go index ada56ad..6a0ef2a 100644 --- a/pkg/datastore/query.go +++ b/pkg/datastore/query.go @@ -384,7 +384,7 @@ func (c *Client) AllKeys(ctx context.Context, q *Query) ([]*Key, error) { } // GetAll retrieves all entities matching the query and stores them in dst. -// dst must be a pointer to a slice of structs. +// dst must be a pointer to a slice of structs, or nil for KeysOnly queries. // Returns the keys of the retrieved entities and any error. // This matches the API of cloud.google.com/go/datastore. func (c *Client) GetAll(ctx context.Context, query *Query, dst any) ([]*Key, error) { @@ -431,6 +431,21 @@ func (c *Client) GetAll(ctx context.Context, query *Query, dst any) ([]*Key, err return nil, fmt.Errorf("failed to parse response: %w", err) } + // For KeysOnly queries, dst can be nil - just return keys + if query.keysOnly && dst == nil { + keys := make([]*Key, 0, len(result.Batch.EntityResults)) + for _, er := range result.Batch.EntityResults { + key, err := keyFromJSON(er.Entity["key"]) + if err != nil { + c.logger.ErrorContext(ctx, "failed to parse key from response", "error", err) + return nil, err + } + keys = append(keys, key) + } + c.logger.DebugContext(ctx, "keys-only query completed successfully", "kind", query.kind, "keys_found", len(keys)) + return keys, nil + } + // Verify dst is a pointer to slice v := reflect.ValueOf(dst) if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Slice { diff --git a/pkg/mock/mock.go b/pkg/mock/mock.go index 8602af2..0ecb7b4 100644 --- a/pkg/mock/mock.go +++ b/pkg/mock/mock.go @@ -18,32 +18,56 @@ package mock import ( + "encoding/base64" "encoding/json" "fmt" "log" "net/http" "net/http/httptest" + "sort" "strconv" "sync" "testing" + "time" ) const metadataFlavor = "Google" +// Datastore limits (matching real Google Cloud Datastore). +const ( + maxMutationsPerCommit = 500 + maxEntitySizeBytes = 1048572 // 1 MiB - 4 bytes + maxKeySizeBytes = 6144 // 6 KiB + transactionTimeout = 270 // seconds (4.5 minutes, real is ~5 minutes) +) + // Store holds the in-memory entity storage. // //nolint:govet // Field alignment not optimized to maintain readability type Store struct { - mu sync.RWMutex - entities map[string]map[string]any - nextID int64 // Counter for allocating unique IDs + mu sync.RWMutex + entities map[string]map[string]any + transactions map[string]*transactionState + nextID int64 // Counter for allocating unique IDs + nextTxID int64 // Counter for transaction IDs +} + +// transactionState tracks the state of an active transaction. +// +//nolint:govet // Field order prioritizes logical grouping over memory optimization +type transactionState struct { + id string + createdAt time.Time + readKeys map[string]bool // Keys read during this transaction } // NewStore creates a new in-memory store. func NewStore() *Store { return &Store{ - entities: make(map[string]map[string]any), - nextID: 1000, // Start IDs at 1000 + entities: make(map[string]map[string]any), + transactions: make(map[string]*transactionState), + nextID: 1000, // Start IDs at 1000 + nextTxID: 1, } } @@ -112,7 +136,7 @@ func NewMockServers(t *testing.T) (metadataURL, apiURL string, cleanup func()) { } if r.URL.Path == "/projects/test-project:beginTransaction" { - handleBeginTransaction(w, r) + store.handleBeginTransaction(w, r) return } @@ -226,12 +250,13 @@ func (s *Store) handleLookup(w http.ResponseWriter, r *http.Request) { // handleCommit handles commit (put/delete) requests. // -//nolint:gocognit,maintidx // Complex logic required for handling multiple mutation types +//nolint:gocognit // Complex validation logic required to match real Datastore behavior func (s *Store) handleCommit(w http.ResponseWriter, r *http.Request) { var req struct { - Mode string `json:"mode"` - DatabaseID string `json:"databaseId"` - Mutations []map[string]any `json:"mutations"` + Mode string `json:"mode"` + DatabaseID string `json:"databaseId"` + Transaction string `json:"transaction"` + Mutations []map[string]any `json:"mutations"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { @@ -243,17 +268,22 @@ func (s *Store) handleCommit(w http.ResponseWriter, r *http.Request) { if req.DatabaseID != "" { routingHeader := r.Header.Get("X-Goog-Request-Params") if routingHeader == "" { - w.WriteHeader(http.StatusBadRequest) - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "error": map[string]any{ - "code": 400, - "message": "Missing routing header for named database", - "status": "INVALID_ARGUMENT", - }, - }); err != nil { - log.Printf("failed to encode error response: %v", err) - } + s.writeError(w, http.StatusBadRequest, "INVALID_ARGUMENT", "Missing routing header for named database") + return + } + } + + // Validate mutation count limit + if len(req.Mutations) > maxMutationsPerCommit { + s.writeError(w, http.StatusBadRequest, "INVALID_ARGUMENT", + fmt.Sprintf("Too many mutations: %d exceeds limit of %d", len(req.Mutations), maxMutationsPerCommit)) + return + } + + // Validate transaction mode + if req.Mode == "TRANSACTIONAL" { + if req.Transaction == "" { + s.writeError(w, http.StatusBadRequest, "INVALID_ARGUMENT", "Transaction ID required for TRANSACTIONAL mode") return } } @@ -261,134 +291,138 @@ func (s *Store) handleCommit(w http.ResponseWriter, r *http.Request) { s.mu.Lock() defer s.mu.Unlock() + // Validate transaction if provided + if req.Transaction != "" { + txState, exists := s.transactions[req.Transaction] + if !exists { + s.writeErrorLocked(w, http.StatusBadRequest, "INVALID_ARGUMENT", "Invalid or expired transaction") + return + } + + // Check transaction timeout + if time.Since(txState.createdAt) > transactionTimeout*time.Second { + delete(s.transactions, req.Transaction) + s.writeErrorLocked(w, http.StatusBadRequest, "ABORTED", "Transaction has expired") + return + } + + // Remove transaction after commit (whether successful or not) + defer delete(s.transactions, req.Transaction) + } + var mutationResults []map[string]any for _, mutation := range req.Mutations { var resultKey map[string]any - // Handle insert + // Handle insert - fails if entity already exists (like real Datastore) if insert, ok := mutation["insert"].(map[string]any); ok { keyData, ok := insert["key"].(map[string]any) if !ok { continue } - path, ok := keyData["path"].([]any) - if !ok || len(path) == 0 { - continue + + // Validate entity size + if err := s.validateEntitySize(insert); err != nil { + s.writeErrorLocked(w, http.StatusBadRequest, "INVALID_ARGUMENT", err.Error()) + return } - pathElem, ok := path[0].(map[string]any) - if !ok { - continue + + // Validate key size + if err := s.validateKeySize(keyData); err != nil { + s.writeErrorLocked(w, http.StatusBadRequest, "INVALID_ARGUMENT", err.Error()) + return } - kind, ok := pathElem["kind"].(string) + + // Handle incomplete keys - allocate ID + keyStr, keyData, ok := s.resolveKey(keyData) if !ok { continue } - // Handle both name and ID keys - var keyStr string - if name, ok := pathElem["name"].(string); ok { - keyStr = kind + "/" + name - } else if id, ok := pathElem["id"].(string); ok { - keyStr = kind + "/" + id - } else { - continue + // Check if entity already exists - insert should fail + if _, exists := s.entities[keyStr]; exists { + s.writeErrorLocked(w, http.StatusConflict, "ALREADY_EXISTS", "Entity already exists") + return } + // Update the entity's key with potentially allocated ID + insert["key"] = keyData s.entities[keyStr] = insert resultKey = keyData } - // Handle update + // Handle update - fails if entity doesn't exist (like real Datastore) if update, ok := mutation["update"].(map[string]any); ok { keyData, ok := update["key"].(map[string]any) if !ok { continue } - path, ok := keyData["path"].([]any) - if !ok || len(path) == 0 { - continue + + // Validate entity size + if err := s.validateEntitySize(update); err != nil { + s.writeErrorLocked(w, http.StatusBadRequest, "INVALID_ARGUMENT", err.Error()) + return } - pathElem, ok := path[0].(map[string]any) - if !ok { - continue + + // Validate key size + if err := s.validateKeySize(keyData); err != nil { + s.writeErrorLocked(w, http.StatusBadRequest, "INVALID_ARGUMENT", err.Error()) + return } - kind, ok := pathElem["kind"].(string) + + keyStr, ok := s.extractKeyString(keyData) if !ok { continue } - // Handle both name and ID keys - var keyStr string - if name, ok := pathElem["name"].(string); ok { - keyStr = kind + "/" + name - } else if id, ok := pathElem["id"].(string); ok { - keyStr = kind + "/" + id - } else { - continue + // Check if entity exists - update should fail if not + if _, exists := s.entities[keyStr]; !exists { + s.writeErrorLocked(w, http.StatusNotFound, "NOT_FOUND", "No entity to update") + return } s.entities[keyStr] = update resultKey = keyData } - // Handle upsert + // Handle upsert - creates or updates (always succeeds) if upsert, ok := mutation["upsert"].(map[string]any); ok { keyData, ok := upsert["key"].(map[string]any) if !ok { continue } - path, ok := keyData["path"].([]any) - if !ok || len(path) == 0 { - continue - } - pathElem, ok := path[0].(map[string]any) - if !ok { - continue + + // Validate entity size + if err := s.validateEntitySize(upsert); err != nil { + s.writeErrorLocked(w, http.StatusBadRequest, "INVALID_ARGUMENT", err.Error()) + return } - kind, ok := pathElem["kind"].(string) - if !ok { - continue + + // Validate key size + if err := s.validateKeySize(keyData); err != nil { + s.writeErrorLocked(w, http.StatusBadRequest, "INVALID_ARGUMENT", err.Error()) + return } - // Handle both name and ID keys - var keyStr string - if name, ok := pathElem["name"].(string); ok { - keyStr = kind + "/" + name - } else if id, ok := pathElem["id"].(string); ok { - keyStr = kind + "/" + id - } else { + // Handle incomplete keys - allocate ID + keyStr, keyData, ok := s.resolveKey(keyData) + if !ok { continue } + // Update the entity's key with potentially allocated ID + upsert["key"] = keyData s.entities[keyStr] = upsert resultKey = keyData } // Handle delete if deleteKey, ok := mutation["delete"].(map[string]any); ok { - path, ok := deleteKey["path"].([]any) - if !ok || len(path) == 0 { - continue - } - pathElem, ok := path[0].(map[string]any) + keyStr, ok := s.extractKeyString(deleteKey) if !ok { continue } - kind, ok := pathElem["kind"].(string) - if !ok { - continue - } - - // Handle both name and ID keys - var keyStr string - if name, ok := pathElem["name"].(string); ok { - keyStr = kind + "/" + name - } else if id, ok := pathElem["id"].(string); ok { - keyStr = kind + "/" + id - } else { - continue - } delete(s.entities, keyStr) resultKey = deleteKey @@ -411,6 +445,127 @@ func (s *Store) handleCommit(w http.ResponseWriter, r *http.Request) { } } +// writeError writes an error response (must NOT hold lock). +// +//nolint:unparam // code parameter kept for consistency with writeErrorLocked +func (*Store) writeError(w http.ResponseWriter, code int, status, message string) { + w.WriteHeader(code) + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{ + "code": code, + "message": message, + "status": status, + }, + }); err != nil { + log.Printf("failed to encode error response: %v", err) + } +} + +// writeErrorLocked writes an error response (caller holds lock, but we don't release it). +func (*Store) writeErrorLocked(w http.ResponseWriter, statusCode int, status, message string) { + w.WriteHeader(statusCode) + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{ + "code": statusCode, + "message": message, + "status": status, + }, + }); err != nil { + log.Printf("failed to encode error response: %v", err) + } +} + +// validateEntitySize checks if an entity exceeds the size limit. +func (*Store) validateEntitySize(entity map[string]any) error { + data, err := json.Marshal(entity) + if err != nil { + return fmt.Errorf("failed to measure entity size: %w", err) + } + if len(data) > maxEntitySizeBytes { + return fmt.Errorf("entity size %d exceeds limit of %d bytes", len(data), maxEntitySizeBytes) + } + return nil +} + +// validateKeySize checks if a key exceeds the size limit. +func (*Store) validateKeySize(keyData map[string]any) error { + data, err := json.Marshal(keyData) + if err != nil { + return fmt.Errorf("failed to measure key size: %w", err) + } + if len(data) > maxKeySizeBytes { + return fmt.Errorf("key size %d exceeds limit of %d bytes", len(data), maxKeySizeBytes) + } + return nil +} + +// resolveKey handles incomplete keys by allocating an ID if needed. +// Returns the key string, updated key data, and success flag. +func (s *Store) resolveKey(keyData map[string]any) (keyStr string, updatedKey map[string]any, ok bool) { + path, ok := keyData["path"].([]any) + if !ok || len(path) == 0 { + return "", nil, false + } + pathElem, ok := path[0].(map[string]any) + if !ok { + return "", nil, false + } + kind, ok := pathElem["kind"].(string) + if !ok { + return "", nil, false + } + + // Handle both name and ID keys + if name, ok := pathElem["name"].(string); ok { + return kind + "/" + name, keyData, true + } + if id, ok := pathElem["id"].(string); ok { + return kind + "/" + id, keyData, true + } + + // Incomplete key - allocate an ID + s.nextID++ + allocatedID := strconv.FormatInt(s.nextID, 10) + pathElem["id"] = allocatedID + + return kind + "/" + allocatedID, keyData, true +} + +// extractKeyString extracts the key string from key data. +func (*Store) extractKeyString(keyData map[string]any) (string, bool) { + path, ok := keyData["path"].([]any) + if !ok || len(path) == 0 { + return "", false + } + pathElem, ok := path[0].(map[string]any) + if !ok { + return "", false + } + kind, ok := pathElem["kind"].(string) + if !ok { + return "", false + } + + // Handle both name and ID keys + if name, ok := pathElem["name"].(string); ok { + return kind + "/" + name, true + } + if id, ok := pathElem["id"].(string); ok { + return kind + "/" + id, true + } + return "", false +} + +// queryResult holds an entity with its key for sorting. +// +//nolint:govet // Field order prioritizes logical grouping over memory optimization +type queryResult struct { + keyStr string + entity map[string]any +} + // handleRunQuery handles query requests. func (s *Store) handleRunQuery(w http.ResponseWriter, r *http.Request) { var req struct { @@ -427,17 +582,7 @@ func (s *Store) handleRunQuery(w http.ResponseWriter, r *http.Request) { if req.DatabaseID != "" { routingHeader := r.Header.Get("X-Goog-Request-Params") if routingHeader == "" { - w.WriteHeader(http.StatusBadRequest) - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "error": map[string]any{ - "code": 400, - "message": "Missing routing header for named database", - "status": "INVALID_ARGUMENT", - }, - }); err != nil { - log.Printf("failed to encode error response: %v", err) - } + s.writeError(w, http.StatusBadRequest, "INVALID_ARGUMENT", "Missing routing header for named database") return } } @@ -464,38 +609,24 @@ func (s *Store) handleRunQuery(w http.ResponseWriter, r *http.Request) { limit = int(l) } - // Check for startCursor - if present, we've already returned results - // For simplicity in the mock, return empty results when cursor is used - var startCursor string - if sc, ok := query["startCursor"].(string); ok { - startCursor = sc + var offset int + if o, ok := query["offset"].(float64); ok { + offset = int(o) + } + + // Parse cursor to get starting position + var startIdx int + if sc, ok := query["startCursor"].(string); ok && sc != "" { + startIdx = s.decodeCursor(sc) } // Find all entities of this kind s.mu.RLock() defer s.mu.RUnlock() - var results []any - - // If there's a start cursor, we simulate pagination by returning no more results - // This is a simplified mock behavior - a real implementation would track position - if startCursor != "" { - // Return empty results to indicate end of pagination - response := map[string]any{ - "batch": map[string]any{ - "entityResults": []any{}, - "moreResults": "NO_MORE_RESULTS", - }, - } - w.WriteHeader(http.StatusOK) - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(response); err != nil { - log.Printf("failed to encode query response: %v", err) - } - return - } - - for _, entity := range s.entities { + // Collect all matching entities + var matches []queryResult + for keyStr, entity := range s.entities { keyData, ok := entity["key"].(map[string]any) if !ok { continue @@ -521,34 +652,58 @@ func (s *Store) handleRunQuery(w http.ResponseWriter, r *http.Request) { } } - results = append(results, map[string]any{ - "entity": entity, - }) - - if limit > 0 && len(results) >= limit { - break - } + matches = append(matches, queryResult{keyStr: keyStr, entity: entity}) } } - // Add cursor if there are more results (for pagination testing) + // Sort results deterministically by key string for consistent ordering + sort.Slice(matches, func(i, j int) bool { + return matches[i].keyStr < matches[j].keyStr + }) + + // Apply ordering from query if specified + if orders, ok := query["order"].([]any); ok && len(orders) > 0 { + s.applyOrdering(matches, orders) + } + + // Apply offset and cursor + skipCount := offset + startIdx + if skipCount > len(matches) { + skipCount = len(matches) + } + matches = matches[skipCount:] + + // Apply limit + totalMatches := len(matches) + if limit > 0 && len(matches) > limit { + matches = matches[:limit] + } + + // Build results + results := make([]any, 0, len(matches)) + for _, m := range matches { + results = append(results, map[string]any{ + "entity": m.entity, + }) + } + + // Generate cursor for pagination var endCursor string - if limit > 0 && len(results) == limit { - // Generate a simple cursor to indicate more results might exist - endCursor = fmt.Sprintf("cursor-after-%d", limit) + moreResults := "NO_MORE_RESULTS" + if limit > 0 && totalMatches > limit { + // Encode cursor as position in sorted results + endCursor = s.encodeCursor(skipCount + limit) + moreResults = "MORE_RESULTS_AFTER_LIMIT" } // Build response batch := map[string]any{ "entityResults": results, + "moreResults": moreResults, } - // Add cursor if available if endCursor != "" { batch["endCursor"] = endCursor - batch["moreResults"] = "MORE_RESULTS_AFTER_LIMIT" - } else { - batch["moreResults"] = "NO_MORE_RESULTS" } response := map[string]any{ @@ -562,8 +717,151 @@ func (s *Store) handleRunQuery(w http.ResponseWriter, r *http.Request) { } } +// encodeCursor creates a base64-encoded cursor from a position. +func (*Store) encodeCursor(pos int) string { + return base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("pos:%d", pos))) +} + +// decodeCursor extracts the position from a cursor string. +func (*Store) decodeCursor(cursor string) int { + data, err := base64.StdEncoding.DecodeString(cursor) + if err != nil { + return 0 + } + var pos int + if _, err := fmt.Sscanf(string(data), "pos:%d", &pos); err != nil { + return 0 + } + return pos +} + +// applyOrdering sorts query results based on order specifications. +func (*Store) applyOrdering(matches []queryResult, orders []any) { + sort.SliceStable(matches, func(i, j int) bool { + for _, orderAny := range orders { + order, ok := orderAny.(map[string]any) + if !ok { + continue + } + propMap, ok := order["property"].(map[string]any) + if !ok { + continue + } + propName, ok := propMap["name"].(string) + if !ok { + continue + } + direction, ok := order["direction"].(string) + descending := ok && direction == "DESCENDING" + + // Get property values from both entities + propsI, okI := matches[i].entity["properties"].(map[string]any) + propsJ, okJ := matches[j].entity["properties"].(map[string]any) + if !okI || !okJ { + continue + } + + valI := getPropertyValue(propsI, propName) + valJ := getPropertyValue(propsJ, propName) + + cmp := compareValues(valI, valJ) + if cmp != 0 { + if descending { + return cmp > 0 + } + return cmp < 0 + } + } + return false + }) +} + +// getPropertyValue extracts a comparable value from entity properties. +func getPropertyValue(props map[string]any, name string) any { + if props == nil { + return nil + } + prop, ok := props[name].(map[string]any) + if !ok { + return nil + } + if v, ok := prop["integerValue"].(string); ok { + var i int64 + if _, err := fmt.Sscanf(v, "%d", &i); err == nil { + return i + } + } + if v, ok := prop["stringValue"].(string); ok { + return v + } + if v, ok := prop["doubleValue"].(float64); ok { + return v + } + if v, ok := prop["booleanValue"].(bool); ok { + return v + } + return nil +} + +// compareValues compares two property values. +func compareValues(a, b any) int { + if a == nil && b == nil { + return 0 + } + if a == nil { + return -1 + } + if b == nil { + return 1 + } + + switch va := a.(type) { + case int64: + if vb, ok := b.(int64); ok { + if va < vb { + return -1 + } + if va > vb { + return 1 + } + return 0 + } + case string: + if vb, ok := b.(string); ok { + if va < vb { + return -1 + } + if va > vb { + return 1 + } + return 0 + } + case float64: + if vb, ok := b.(float64); ok { + if va < vb { + return -1 + } + if va > vb { + return 1 + } + return 0 + } + case bool: + if vb, ok := b.(bool); ok { + if !va && vb { + return -1 + } + if va && !vb { + return 1 + } + return 0 + } + } + return 0 +} + // handleBeginTransaction handles transaction begin requests. -func handleBeginTransaction(w http.ResponseWriter, r *http.Request) { +func (s *Store) handleBeginTransaction(w http.ResponseWriter, r *http.Request) { var req struct { DatabaseID string `json:"databaseId"` } @@ -577,25 +875,29 @@ func handleBeginTransaction(w http.ResponseWriter, r *http.Request) { if req.DatabaseID != "" { routingHeader := r.Header.Get("X-Goog-Request-Params") if routingHeader == "" { - w.WriteHeader(http.StatusBadRequest) - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "error": map[string]any{ - "code": 400, - "message": "Missing routing header for named database", - "status": "INVALID_ARGUMENT", - }, - }); err != nil { - log.Printf("failed to encode error response: %v", err) - } + s.writeError(w, http.StatusBadRequest, "INVALID_ARGUMENT", "Missing routing header for named database") return } } + s.mu.Lock() + defer s.mu.Unlock() + + // Generate unique transaction ID + s.nextTxID++ + txID := fmt.Sprintf("tx-%d-%d", time.Now().UnixNano(), s.nextTxID) + + // Store transaction state + s.transactions[txID] = &transactionState{ + id: txID, + createdAt: time.Now(), + readKeys: make(map[string]bool), + } + w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(map[string]any{ - "transaction": "test-transaction-id", + "transaction": txID, }); err != nil { log.Printf("failed to encode transaction response: %v", err) } diff --git a/pkg/mock/mock_test.go b/pkg/mock/mock_test.go index 984cfe4..d826bd0 100644 --- a/pkg/mock/mock_test.go +++ b/pkg/mock/mock_test.go @@ -2,6 +2,7 @@ package mock_test import ( "context" + "errors" "testing" "github.com/codeGROOVE-dev/ds9/pkg/datastore" @@ -442,3 +443,315 @@ func TestMockConcurrentQuery(t *testing.T) { <-done } } + +func TestMockInsertAlreadyExists(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + type TestEntity struct { + Name string `datastore:"name"` + } + + key := datastore.NameKey("InsertTest", "test-key", nil) + entity := &TestEntity{Name: "first"} + + // First insert should succeed + mut := datastore.NewInsert(key, entity) + _, err := client.Mutate(ctx, mut) + if err != nil { + t.Fatalf("First insert failed: %v", err) + } + + // Second insert with same key should fail with ALREADY_EXISTS + entity2 := &TestEntity{Name: "second"} + mut2 := datastore.NewInsert(key, entity2) + _, err = client.Mutate(ctx, mut2) + if err == nil { + t.Error("Expected error for duplicate insert, got nil") + } + + // Verify original entity is unchanged + var retrieved TestEntity + if err := client.Get(ctx, key, &retrieved); err != nil { + t.Fatalf("Get failed: %v", err) + } + if retrieved.Name != "first" { + t.Errorf("Expected Name 'first', got %q", retrieved.Name) + } +} + +func TestMockUpdateNotFound(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + type TestEntity struct { + Name string `datastore:"name"` + } + + // Try to update non-existent entity + key := datastore.NameKey("UpdateTest", "nonexistent", nil) + entity := &TestEntity{Name: "updated"} + mut := datastore.NewUpdate(key, entity) + _, err := client.Mutate(ctx, mut) + if err == nil { + t.Error("Expected error for update on non-existent entity, got nil") + } + + // Verify entity was not created + var retrieved TestEntity + err = client.Get(ctx, key, &retrieved) + if err == nil { + t.Error("Expected ErrNoSuchEntity, but entity was found") + } +} + +func TestMockUpdateExisting(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + type TestEntity struct { + Name string `datastore:"name"` + } + + key := datastore.NameKey("UpdateTest2", "existing", nil) + entity := &TestEntity{Name: "original"} + + // Create entity first + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Update should succeed + updated := &TestEntity{Name: "updated"} + mut := datastore.NewUpdate(key, updated) + _, err := client.Mutate(ctx, mut) + if err != nil { + t.Fatalf("Update failed: %v", err) + } + + // Verify update + var retrieved TestEntity + if err := client.Get(ctx, key, &retrieved); err != nil { + t.Fatalf("Get failed: %v", err) + } + if retrieved.Name != "updated" { + t.Errorf("Expected Name 'updated', got %q", retrieved.Name) + } +} + +func TestMockUpsertAlwaysSucceeds(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + type TestEntity struct { + Name string `datastore:"name"` + } + + key := datastore.NameKey("UpsertTest", "test-key", nil) + + // First upsert (creates) + entity1 := &TestEntity{Name: "first"} + mut1 := datastore.NewUpsert(key, entity1) + _, err := client.Mutate(ctx, mut1) + if err != nil { + t.Fatalf("First upsert failed: %v", err) + } + + // Second upsert (updates) + entity2 := &TestEntity{Name: "second"} + mut2 := datastore.NewUpsert(key, entity2) + _, err = client.Mutate(ctx, mut2) + if err != nil { + t.Fatalf("Second upsert failed: %v", err) + } + + // Verify update + var retrieved TestEntity + if err := client.Get(ctx, key, &retrieved); err != nil { + t.Fatalf("Get failed: %v", err) + } + if retrieved.Name != "second" { + t.Errorf("Expected Name 'second', got %q", retrieved.Name) + } +} + +func TestGetAllKeysOnlyNilDst(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + type TestEntity struct { + Name string `datastore:"name"` + } + + // Create some entities + for i := range 5 { + key := datastore.NameKey("GetAllNilTest", string(rune('a'+i)), nil) + entity := &TestEntity{Name: "test"} + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // GetAll with KeysOnly and nil dst should work (like real Datastore) + query := datastore.NewQuery("GetAllNilTest").KeysOnly() + keys, err := client.GetAll(ctx, query, nil) + if err != nil { + t.Fatalf("GetAll with nil dst failed: %v", err) + } + + if len(keys) != 5 { + t.Errorf("Expected 5 keys, got %d", len(keys)) + } + + for _, key := range keys { + if key.Kind != "GetAllNilTest" { + t.Errorf("Expected kind 'GetAllNilTest', got %q", key.Kind) + } + } +} + +func TestMockQueryDeterministicOrder(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + type TestEntity struct { + Name string `datastore:"name"` + } + + // Create entities with non-alphabetical insertion order + keys := []string{"zebra", "apple", "mango", "banana"} + for _, name := range keys { + key := datastore.NameKey("OrderTest", name, nil) + entity := &TestEntity{Name: name} + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Query multiple times and verify order is deterministic + for run := range 3 { + query := datastore.NewQuery("OrderTest").KeysOnly() + resultKeys, err := client.AllKeys(ctx, query) + if err != nil { + t.Fatalf("AllKeys failed on run %d: %v", run, err) + } + + if len(resultKeys) != 4 { + t.Fatalf("Expected 4 keys, got %d", len(resultKeys)) + } + + // Results should be in alphabetical order by key name + expected := []string{"apple", "banana", "mango", "zebra"} + for i, key := range resultKeys { + if key.Name != expected[i] { + t.Errorf("Run %d: expected key %d to be %q, got %q", run, i, expected[i], key.Name) + } + } + } +} + +func TestMockPaginationWithCursor(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + type TestEntity struct { + Count int64 `datastore:"count"` + } + + // Create 10 entities + for i := range 10 { + key := datastore.NameKey("PaginationTest", string(rune('a'+i)), nil) + entity := &TestEntity{Count: int64(i)} + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Query with limit 3 and iterate through all pages + var allKeys []*datastore.Key + query := datastore.NewQuery("PaginationTest").Limit(3) + it := client.Run(ctx, query) + + for { + var entity TestEntity + key, err := it.Next(&entity) + if errors.Is(err, datastore.Done) { + break + } + if err != nil { + t.Fatalf("Next failed: %v", err) + } + allKeys = append(allKeys, key) + } + + // With limit=3, we should only get 3 results (not paginate automatically) + if len(allKeys) != 3 { + t.Errorf("Expected 3 keys with limit, got %d", len(allKeys)) + } + + // Verify cursor is available after iteration + cursor, err := it.Cursor() + if err != nil { + t.Logf("Cursor not available: %v (this is OK if at end of results)", err) + } else if cursor == "" { + t.Error("Expected non-empty cursor") + } +} + +func TestMockTransactionValidation(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + type TestEntity struct { + Name string `datastore:"name"` + } + + // Test that transactions work correctly + key := datastore.NameKey("TxTest", "test-key", nil) + entity := &TestEntity{Name: "original"} + + // Put initial entity + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Initial Put failed: %v", err) + } + + // Run a transaction that reads and writes + _, err := client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + var retrieved TestEntity + if err := tx.Get(key, &retrieved); err != nil { + return err + } + + retrieved.Name = "modified" + _, err := tx.Put(key, &retrieved) + return err + }) + if err != nil { + t.Fatalf("Transaction failed: %v", err) + } + + // Verify the modification persisted + var result TestEntity + if err := client.Get(ctx, key, &result); err != nil { + t.Fatalf("Get after transaction failed: %v", err) + } + if result.Name != "modified" { + t.Errorf("Expected Name 'modified', got %q", result.Name) + } +}