diff --git a/src/Enums/EdgeDirection.cs b/src/Enums/EdgeDirection.cs new file mode 100644 index 000000000..9d9a2e53f --- /dev/null +++ b/src/Enums/EdgeDirection.cs @@ -0,0 +1,36 @@ +namespace AiDotNet.Enums; + +/// +/// Specifies the direction of edges to retrieve when querying a knowledge graph. +/// +/// +/// For Beginners: In a directed graph, edges have a direction - they go FROM one node TO another. +/// +/// Think of it like Twitter follows: +/// - If Alice follows Bob, the edge goes FROM Alice TO Bob +/// - Alice has an OUTGOING edge (she's following someone) +/// - Bob has an INCOMING edge (someone is following him) +/// +/// When querying relationships: +/// - Outgoing: "Who does this person follow?" (edges starting from this node) +/// - Incoming: "Who follows this person?" (edges pointing to this node) +/// - Both: "All connections" (both directions) +/// +/// +public enum EdgeDirection +{ + /// + /// Retrieve only outgoing edges (edges where the specified node is the source). + /// + Outgoing, + + /// + /// Retrieve only incoming edges (edges where the specified node is the target). + /// + Incoming, + + /// + /// Retrieve both outgoing and incoming edges. + /// + Both +} diff --git a/src/Interfaces/IGraphStore.cs b/src/Interfaces/IGraphStore.cs new file mode 100644 index 000000000..a5d033384 --- /dev/null +++ b/src/Interfaces/IGraphStore.cs @@ -0,0 +1,381 @@ +using System.Collections.Generic; +using System.Threading.Tasks; +using AiDotNet.RetrievalAugmentedGeneration.Graph; + +namespace AiDotNet.Interfaces; + +/// +/// Defines the contract for graph storage backends that manage nodes and edges. +/// +/// +/// +/// A graph store provides persistent or in-memory storage for knowledge graphs, +/// enabling efficient storage and retrieval of entities (nodes) and their relationships (edges). +/// Implementations can range from simple in-memory dictionaries to distributed graph databases. 
+/// +/// For Beginners: A graph store is like a filing system for connected information. +/// +/// Think of it like organizing a network of friends: +/// - Nodes are people (Alice, Bob, Charlie) +/// - Edges are relationships (Alice KNOWS Bob, Bob WORKS_WITH Charlie) +/// - The graph store remembers all these connections +/// +/// Different implementations might: +/// - MemoryGraphStore: Keep everything in RAM (fast but lost when app closes) +/// - FileGraphStore: Save to disk (slower but survives restarts) +/// - Neo4jGraphStore: Use a professional graph database (production-scale) +/// +/// This interface lets you swap storage backends without changing your code! +/// +/// +/// The numeric data type used for vector calculations (typically float or double). +public interface IGraphStore +{ + /// + /// Gets the total number of nodes in the graph store. + /// + /// + /// For Beginners: This tells you how many entities (people, places, things) + /// are stored in the graph. + /// + /// + int NodeCount { get; } + + /// + /// Gets the total number of edges in the graph store. + /// + /// + /// For Beginners: This tells you how many relationships/connections + /// exist between entities in the graph. + /// + /// + int EdgeCount { get; } + + /// + /// Adds a node to the graph or updates it if it already exists. + /// + /// The node to add. + /// + /// + /// This method stores a node in the graph. If a node with the same ID already exists, + /// it will be updated with the new data. The node is automatically indexed by its label + /// for efficient label-based queries. + /// + /// For Beginners: This adds a new entity to the graph. + /// + /// Like adding a person to a social network: + /// - node.Id = "alice_001" + /// - node.Label = "PERSON" + /// - node.Properties = { "name": "Alice Smith", "age": 30 } + /// + /// If Alice already exists, her information gets updated. 
+ /// + /// + void AddNode(GraphNode node); + + /// + /// Adds an edge to the graph representing a relationship between two nodes. + /// + /// The edge to add. + /// + /// + /// This method creates a relationship between two existing nodes. Both the source + /// and target nodes must already exist in the graph, otherwise an exception is thrown. + /// The edge is indexed for efficient traversal from both directions. + /// + /// For Beginners: This adds a connection between two entities. + /// + /// Like saying "Alice knows Bob": + /// - edge.SourceId = "alice_001" + /// - edge.RelationType = "KNOWS" + /// - edge.TargetId = "bob_002" + /// - edge.Weight = 0.9 (how strong the relationship is) + /// + /// Both Alice and Bob must already be added as nodes first! + /// + /// + void AddEdge(GraphEdge edge); + + /// + /// Retrieves a node by its unique identifier. + /// + /// The unique identifier of the node. + /// The node if found; otherwise, null. + /// + /// For Beginners: This gets a specific entity if you know its ID. + /// + /// Like asking: "Show me the person with ID 'alice_001'" + /// + /// + GraphNode? GetNode(string nodeId); + + /// + /// Retrieves an edge by its unique identifier. + /// + /// The unique identifier of the edge. + /// The edge if found; otherwise, null. + /// + /// For Beginners: This gets a specific relationship if you know its ID. + /// + /// Edge IDs are usually auto-generated like: "alice_001_KNOWS_bob_002" + /// + /// + GraphEdge? GetEdge(string edgeId); + + /// + /// Removes a node and all its connected edges from the graph. + /// + /// The unique identifier of the node to remove. + /// True if the node was found and removed; otherwise, false. + /// + /// + /// This method removes a node and automatically cleans up all edges connected to it + /// (both incoming and outgoing). This ensures the graph remains consistent. + /// + /// For Beginners: This deletes an entity and all its connections. 
+ /// + /// Like removing Alice from the network: + /// - Alice's profile is deleted + /// - All "Alice KNOWS Bob" relationships are deleted + /// - All "Bob KNOWS Alice" relationships are deleted + /// + /// This keeps the graph clean - no broken connections! + /// + /// + bool RemoveNode(string nodeId); + + /// + /// Removes an edge from the graph. + /// + /// The unique identifier of the edge to remove. + /// True if the edge was found and removed; otherwise, false. + /// + /// For Beginners: This deletes a specific relationship. + /// + /// Like saying "Alice no longer knows Bob" - removes just that connection, + /// but Alice and Bob still exist in the graph. + /// + /// + bool RemoveEdge(string edgeId); + + /// + /// Gets all outgoing edges from a specific node. + /// + /// The source node ID. + /// Collection of outgoing edges from the node. + /// + /// + /// Outgoing edges represent relationships where this node is the source. + /// For example, if Alice KNOWS Bob, the "KNOWS" edge is outgoing from Alice. + /// + /// For Beginners: This finds all relationships going OUT from an entity. + /// + /// If you ask for Alice's outgoing edges, you get: + /// - Alice KNOWS Bob + /// - Alice WORKS_AT CompanyX + /// - Alice LIVES_IN Seattle + /// + /// These are things Alice does or has relationships with. + /// + /// + IEnumerable> GetOutgoingEdges(string nodeId); + + /// + /// Gets all incoming edges to a specific node. + /// + /// The target node ID. + /// Collection of incoming edges to the node. + /// + /// + /// Incoming edges represent relationships where this node is the target. + /// For example, if Alice KNOWS Bob, the "KNOWS" edge is incoming to Bob. + /// + /// For Beginners: This finds all relationships coming IN to an entity. + /// + /// If you ask for Bob's incoming edges, you get: + /// - Alice KNOWS Bob + /// - Charlie WORKS_WITH Bob + /// - CompanyY EMPLOYS Bob + /// + /// These are relationships others have WITH Bob. 
+ /// + /// + IEnumerable> GetIncomingEdges(string nodeId); + + /// + /// Gets all nodes with a specific label. + /// + /// The node label to filter by (e.g., "PERSON", "COMPANY", "LOCATION"). + /// Collection of nodes with the specified label. + /// + /// + /// Labels are used to categorize nodes by type. This enables efficient queries + /// like "find all PERSON nodes" or "find all COMPANY nodes". + /// + /// For Beginners: This finds all entities of a specific type. + /// + /// Like asking: "Show me all PERSON nodes" + /// Returns: Alice, Bob, Charlie (all people in the graph) + /// + /// Or: "Show me all COMPANY nodes" + /// Returns: Microsoft, Google, Amazon (all companies) + /// + /// Labels are like categories or tags for organizing your entities. + /// + /// + IEnumerable> GetNodesByLabel(string label); + + /// + /// Gets all nodes currently stored in the graph. + /// + /// Collection of all nodes. + /// + /// + /// This method retrieves every node without any filtering. + /// Use with caution on large graphs as it may be memory-intensive. + /// + /// For Beginners: This gets every single entity in the graph. + /// + /// Like asking: "Show me everyone and everything in the network" + /// + /// Warning: If you have millions of entities, this could be slow and use lots of memory! + /// + /// + IEnumerable> GetAllNodes(); + + /// + /// Gets all edges currently stored in the graph. + /// + /// Collection of all edges. + /// + /// + /// This method retrieves every edge without any filtering. + /// Use with caution on large graphs as it may be memory-intensive. + /// + /// For Beginners: This gets every single relationship in the graph. + /// + /// Like asking: "Show me every connection between all entities" + /// + /// Warning: Large graphs can have millions of relationships! + /// + /// + IEnumerable> GetAllEdges(); + + /// + /// Removes all nodes and edges from the graph. + /// + /// + /// + /// This method clears the entire graph, removing all data. 
For persistent stores, + /// this may involve deleting files or database records. Use with extreme caution! + /// + /// For Beginners: This deletes EVERYTHING from the graph. + /// + /// Like wiping the entire social network clean - all people and all connections gone! + /// + /// ⚠️ WARNING: This cannot be undone! Make backups first! + /// + /// + void Clear(); + + // Async methods for I/O-intensive operations + + /// + /// Asynchronously adds a node to the graph or updates it if it already exists. + /// + /// The node to add. + /// A task representing the asynchronous operation. + /// + /// + /// This is the async version of . Use this for file-based or + /// database-backed stores to avoid blocking the thread during I/O operations. + /// + /// For Beginners: This does the same as AddNode but doesn't block your app. + /// + /// When should you use async? + /// - FileGraphStore: Yes! (writes to disk) + /// - MemoryGraphStore: Optional (no I/O, but provided for consistency) + /// - Database stores: Definitely! (network I/O) + /// + /// Example: + /// ```csharp + /// await store.AddNodeAsync(node); // Non-blocking + /// ``` + /// + /// + Task AddNodeAsync(GraphNode node); + + /// + /// Asynchronously adds an edge to the graph. + /// + /// The edge to add. + /// A task representing the asynchronous operation. + Task AddEdgeAsync(GraphEdge edge); + + /// + /// Asynchronously retrieves a node by its unique identifier. + /// + /// The unique identifier of the node. + /// A task that represents the asynchronous operation. The task result contains the node if found; otherwise, null. + Task?> GetNodeAsync(string nodeId); + + /// + /// Asynchronously retrieves an edge by its unique identifier. + /// + /// The unique identifier of the edge. + /// A task that represents the asynchronous operation. The task result contains the edge if found; otherwise, null. 
+ Task?> GetEdgeAsync(string edgeId); + + /// + /// Asynchronously removes a node and all its connected edges from the graph. + /// + /// The unique identifier of the node to remove. + /// A task that represents the asynchronous operation. The task result is true if the node was found and removed; otherwise, false. + Task RemoveNodeAsync(string nodeId); + + /// + /// Asynchronously removes an edge from the graph. + /// + /// The unique identifier of the edge to remove. + /// A task that represents the asynchronous operation. The task result is true if the edge was found and removed; otherwise, false. + Task RemoveEdgeAsync(string edgeId); + + /// + /// Asynchronously gets all outgoing edges from a specific node. + /// + /// The source node ID. + /// A task that represents the asynchronous operation. The task result contains the collection of outgoing edges. + Task>> GetOutgoingEdgesAsync(string nodeId); + + /// + /// Asynchronously gets all incoming edges to a specific node. + /// + /// The target node ID. + /// A task that represents the asynchronous operation. The task result contains the collection of incoming edges. + Task>> GetIncomingEdgesAsync(string nodeId); + + /// + /// Asynchronously gets all nodes with a specific label. + /// + /// The node label to filter by. + /// A task that represents the asynchronous operation. The task result contains the collection of nodes with the specified label. + Task>> GetNodesByLabelAsync(string label); + + /// + /// Asynchronously gets all nodes currently stored in the graph. + /// + /// A task that represents the asynchronous operation. The task result contains all nodes. + Task>> GetAllNodesAsync(); + + /// + /// Asynchronously gets all edges currently stored in the graph. + /// + /// A task that represents the asynchronous operation. The task result contains all edges. + Task>> GetAllEdgesAsync(); + + /// + /// Asynchronously removes all nodes and edges from the graph. 
+ /// + /// A task representing the asynchronous operation. + Task ClearAsync(); +} diff --git a/src/Interfaces/IPredictionModelBuilder.cs b/src/Interfaces/IPredictionModelBuilder.cs index bd2de66e3..0de672424 100644 --- a/src/Interfaces/IPredictionModelBuilder.cs +++ b/src/Interfaces/IPredictionModelBuilder.cs @@ -2,6 +2,7 @@ using AiDotNet.DistributedTraining; using AiDotNet.Enums; using AiDotNet.Models; +using AiDotNet.RetrievalAugmentedGeneration.Graph; namespace AiDotNet.Interfaces; @@ -342,31 +343,52 @@ public interface IPredictionModelBuilder /// Configures the retrieval-augmented generation (RAG) components for use during model inference. /// /// + /// /// RAG enhances text generation by retrieving relevant documents from a knowledge base /// and using them as context for generating grounded, factual answers. - /// + /// + /// + /// Graph RAG: When graphStore or knowledgeGraph is provided, enables knowledge graph-based + /// retrieval that finds related entities and their relationships, providing richer context than + /// vector similarity alone. If documentStore is also provided, hybrid retrieval combines both + /// vector search and graph traversal. + /// + /// /// For Beginners: RAG is like giving your AI access to a library before answering questions. /// Instead of relying only on what it learned during training, it can: - /// 1. Search a document collection for relevant information - /// 2. Read the relevant documents - /// 3. Generate an answer based on those documents - /// 4. Cite its sources - /// - /// This makes answers more accurate, up-to-date, and traceable to source materials. - /// - /// RAG operations (GenerateAnswer, RetrieveDocuments) are performed during inference via PredictionModelResult, - /// not during model building. 
+ /// + /// Search a document collection for relevant information + /// Read the relevant documents + /// Generate an answer based on those documents + /// Cite its sources + /// + /// + /// + /// Graph RAG Example: If you ask about "Paris", Graph RAG can find not just documents + /// mentioning Paris, but also related concepts like France, Eiffel Tower, and Seine River + /// by traversing the knowledge graph. + /// + /// + /// RAG operations (GenerateAnswer, RetrieveDocuments, GraphQuery, etc.) are performed during + /// inference via PredictionModelResult, not during model building. + /// /// - /// Optional retriever for finding relevant documents. If not provided, RAG won't be available. + /// Optional retriever for finding relevant documents. If not provided, standard RAG won't be available. /// Optional reranker for improving document ranking quality. Default provided if retriever is set. /// Optional generator for producing grounded answers. Default provided if retriever is set. /// Optional query processors for improving search quality. + /// Optional graph storage backend for Graph RAG (e.g., MemoryGraphStore, FileGraphStore). + /// Optional pre-configured knowledge graph. If null but graphStore is provided, a new one is created. + /// Optional document store for hybrid vector + graph retrieval. /// The builder instance for method chaining. IPredictionModelBuilder ConfigureRetrievalAugmentedGeneration( IRetriever? retriever = null, IReranker? reranker = null, IGenerator? generator = null, - IEnumerable? queryProcessors = null); + IEnumerable? queryProcessors = null, + IGraphStore? graphStore = null, + KnowledgeGraph? knowledgeGraph = null, + IDocumentStore? documentStore = null); /// /// Configures AI agent assistance during model building and inference. 
diff --git a/src/Models/Results/PredictionModelResult.cs b/src/Models/Results/PredictionModelResult.cs index 8289e1d6f..e5e2133fb 100644 --- a/src/Models/Results/PredictionModelResult.cs +++ b/src/Models/Results/PredictionModelResult.cs @@ -2,6 +2,7 @@ global using Formatting = Newtonsoft.Json.Formatting; using AiDotNet.Data.Abstractions; using AiDotNet.Interfaces; +using AiDotNet.RetrievalAugmentedGeneration.Graph; using AiDotNet.Interpretability; using AiDotNet.Serialization; using AiDotNet.Agents; @@ -13,6 +14,7 @@ using AiDotNet.Deployment.Mobile.CoreML; using AiDotNet.Deployment.Mobile.TensorFlowLite; using AiDotNet.Deployment.Runtime; +using AiDotNet.Enums; namespace AiDotNet.Models.Results; @@ -209,6 +211,53 @@ public class PredictionModelResult : IFullModelQuery processors for preprocessing queries, or null if not configured. internal IEnumerable? QueryProcessors { get; private set; } + /// + /// Gets or sets the knowledge graph for graph-enhanced retrieval. + /// + /// A knowledge graph containing entities and relationships, or null if Graph RAG is not configured. + /// + /// For Beginners: The knowledge graph stores entities (like people, places, concepts) and their + /// relationships. When you query the model, it can traverse these relationships to find related context + /// that pure vector similarity might miss. + /// + /// + /// This property is excluded from JSON serialization because it contains runtime infrastructure + /// (graph store, file handles) that should be reconfigured when the model is loaded. + /// + /// + [JsonIgnore] + internal KnowledgeGraph? KnowledgeGraph { get; private set; } + + /// + /// Gets or sets the graph store backend for persistent graph storage. + /// + /// The graph storage backend, or null if Graph RAG is not configured. + /// + /// + /// This property is excluded from JSON serialization because it represents runtime storage + /// infrastructure (file handles, WAL) that must be reconfigured when the model is loaded. 
+ /// + /// + [JsonIgnore] + internal IGraphStore? GraphStore { get; private set; } + + /// + /// Gets or sets the hybrid graph retriever for combined vector + graph retrieval. + /// + /// A hybrid retriever combining vector similarity with graph traversal, or null if not configured. + /// + /// For Beginners: The hybrid retriever first finds similar documents using vector search, + /// then expands the context by traversing the knowledge graph to find related entities. This provides + /// richer context than pure vector search alone. + /// + /// + /// This property is excluded from JSON serialization because it contains references to + /// runtime infrastructure (knowledge graph, document store) that must be reconfigured when loaded. + /// + /// + [JsonIgnore] + internal HybridGraphRetriever? HybridGraphRetriever { get; private set; } + /// /// Gets or sets the meta-learner used for few-shot adaptation and fine-tuning. /// @@ -427,6 +476,9 @@ public PredictionModelResult(IFullModel model, /// Optional agent configuration used during model building. /// Optional agent recommendations from model building. /// Optional deployment configuration for export, caching, versioning, A/B testing, and telemetry. + /// Optional knowledge graph for graph-enhanced retrieval. + /// Optional graph store backend for persistent storage. + /// Optional hybrid retriever for combined vector + graph search. public PredictionModelResult(OptimizationResult optimizationResult, NormalizationInfo normalizationInfo, IBiasDetector? biasDetector = null, @@ -441,7 +493,10 @@ public PredictionModelResult(OptimizationResult optimization AgentRecommendation? agentRecommendation = null, DeploymentConfiguration? deploymentConfiguration = null, Func[], Tensor[]>? jitCompiledFunction = null, - AiDotNet.Configuration.InferenceOptimizationConfig? inferenceOptimizationConfig = null) + AiDotNet.Configuration.InferenceOptimizationConfig? inferenceOptimizationConfig = null, + KnowledgeGraph? 
knowledgeGraph = null, + IGraphStore? graphStore = null, + HybridGraphRetriever? hybridGraphRetriever = null) { Model = optimizationResult.BestSolution; OptimizationResult = optimizationResult; @@ -460,6 +515,9 @@ public PredictionModelResult(OptimizationResult optimization DeploymentConfiguration = deploymentConfiguration; JitCompiledFunction = jitCompiledFunction; InferenceOptimizationConfig = inferenceOptimizationConfig; + KnowledgeGraph = knowledgeGraph; + GraphStore = graphStore; + HybridGraphRetriever = hybridGraphRetriever; } /// @@ -476,6 +534,9 @@ public PredictionModelResult(OptimizationResult optimization /// Optional query processors for RAG query preprocessing. /// Optional agent configuration for AI assistance during inference. /// Optional deployment configuration for export, caching, versioning, A/B testing, and telemetry. + /// Optional knowledge graph for graph-enhanced retrieval. + /// Optional graph store backend for persistent storage. + /// Optional hybrid retriever for combined vector + graph search. /// /// /// This constructor is used when a model has been trained using meta-learning (e.g., MAML, Reptile, SEAL). @@ -513,7 +574,10 @@ public PredictionModelResult( IGenerator? ragGenerator = null, IEnumerable? queryProcessors = null, AgentConfiguration? agentConfig = null, - DeploymentConfiguration? deploymentConfiguration = null) + DeploymentConfiguration? deploymentConfiguration = null, + KnowledgeGraph? knowledgeGraph = null, + IGraphStore? graphStore = null, + HybridGraphRetriever? 
hybridGraphRetriever = null) { Model = metaLearner.BaseModel; MetaLearner = metaLearner; @@ -528,6 +592,9 @@ public PredictionModelResult( QueryProcessors = queryProcessors; AgentConfig = agentConfig; DeploymentConfiguration = deploymentConfiguration; + KnowledgeGraph = knowledgeGraph; + GraphStore = graphStore; + HybridGraphRetriever = hybridGraphRetriever; // Create placeholder OptimizationResult and NormalizationInfo for consistency OptimizationResult = new OptimizationResult(); @@ -1046,7 +1113,7 @@ public IFullModel WithParameters(Vector parameters) // Create new result with updated optimization result // Preserve all configuration properties to ensure deployment behavior, model adaptation, - // and training history are maintained across parameter updates + // training history, and Graph RAG configuration are maintained across parameter updates return new PredictionModelResult( updatedOptimizationResult, NormalizationInfo, @@ -1060,7 +1127,12 @@ public IFullModel WithParameters(Vector parameters) crossValidationResult: CrossValidationResult, agentConfig: AgentConfig, agentRecommendation: AgentRecommendation, - deploymentConfiguration: DeploymentConfiguration); + deploymentConfiguration: DeploymentConfiguration, + jitCompiledFunction: null, // JIT compilation is parameter-specific, don't copy + inferenceOptimizationConfig: InferenceOptimizationConfig, + knowledgeGraph: KnowledgeGraph, + graphStore: GraphStore, + hybridGraphRetriever: HybridGraphRetriever); } /// @@ -1147,7 +1219,7 @@ public IFullModel DeepCopy() var clonedNormalizationInfo = NormalizationInfo.DeepCopy(); // Preserve all configuration properties to ensure deployment behavior, model adaptation, - // and training history are maintained across deep copy + // training history, and Graph RAG configuration are maintained across deep copy return new PredictionModelResult( clonedOptimizationResult, clonedNormalizationInfo, @@ -1161,7 +1233,12 @@ public IFullModel DeepCopy() crossValidationResult: 
CrossValidationResult, agentConfig: AgentConfig, agentRecommendation: AgentRecommendation, - deploymentConfiguration: DeploymentConfiguration); + deploymentConfiguration: DeploymentConfiguration, + jitCompiledFunction: null, // JIT compilation is model-specific, don't copy + inferenceOptimizationConfig: InferenceOptimizationConfig, + knowledgeGraph: KnowledgeGraph, + graphStore: GraphStore, + hybridGraphRetriever: HybridGraphRetriever); } /// @@ -1585,6 +1662,208 @@ private string ProcessQueryWithProcessors(string query) return processedQuery; } + /// + /// Queries the knowledge graph to find related nodes by entity name or label. + /// + /// The search query (entity name or partial match). + /// Maximum number of results to return. + /// Collection of matching graph nodes. + /// Thrown when Graph RAG is not configured. + /// + /// For Beginners: This method searches the knowledge graph for entities matching your query. + /// Unlike vector search which finds similar documents, this finds entities by name or label. + /// + /// Example: + /// + /// var nodes = result.QueryKnowledgeGraph("Einstein", topK: 5); + /// foreach (var node in nodes) + /// { + /// Console.WriteLine($"{node.Label}: {node.Id}"); + /// } + /// + /// + /// + public IEnumerable> QueryKnowledgeGraph(string query, int topK = 10) + { + if (KnowledgeGraph == null) + { + throw new InvalidOperationException( + "Knowledge graph not configured. Configure Graph RAG using PredictionModelBuilder.ConfigureRetrievalAugmentedGeneration() before building the model."); + } + + if (string.IsNullOrWhiteSpace(query)) + throw new ArgumentException("Query cannot be null or empty", nameof(query)); + + return KnowledgeGraph.FindRelatedNodes(query, topK); + } + + /// + /// Retrieves results using hybrid vector + graph search for enhanced context retrieval. + /// + /// The query embedding vector. + /// Number of initial candidates from vector search. + /// How many hops to traverse in the graph (0 = no expansion). 
+ /// Maximum total results to return. + /// List of retrieval results with scores and source information. + /// Thrown when hybrid retriever is not configured. + /// + /// For Beginners: This method combines the best of both worlds: + /// 1. First, it finds similar documents using vector similarity (like traditional RAG) + /// 2. Then, it expands the context by traversing the knowledge graph to find related entities + /// + /// For example, searching for "photosynthesis" might: + /// - Find documents about photosynthesis via vector search + /// - Then traverse the graph to also include chlorophyll, plants, carbon dioxide + /// + /// This provides richer, more complete context than vector search alone. + /// + /// + public List> HybridRetrieve( + Vector queryEmbedding, + int topK = 5, + int expansionDepth = 1, + int maxResults = 10) + { + if (HybridGraphRetriever == null) + { + throw new InvalidOperationException( + "Hybrid graph retriever not configured. Configure Graph RAG with a document store using PredictionModelBuilder.ConfigureRetrievalAugmentedGeneration() before building the model."); + } + + if (queryEmbedding == null || queryEmbedding.Length == 0) + throw new ArgumentException("Query embedding cannot be null or empty", nameof(queryEmbedding)); + + return HybridGraphRetriever.Retrieve(queryEmbedding, topK, expansionDepth, maxResults); + } + + /// + /// Traverses the knowledge graph starting from a node using breadth-first search. + /// + /// The ID of the starting node. + /// Maximum traversal depth. + /// Collection of nodes reachable from the starting node in BFS order. + /// Thrown when knowledge graph is not configured. + /// + /// For Beginners: This method explores the graph starting from a specific entity, + /// discovering all connected entities up to a specified depth. 
+ /// + /// Example: Starting from "Paris", depth=2 might find: + /// - Depth 1: France, Eiffel Tower, Seine River + /// - Depth 2: Europe, Iron, Water + /// + /// This is useful for understanding the context around a specific entity. + /// + /// + public IEnumerable> TraverseGraph(string startNodeId, int maxDepth = 2) + { + if (KnowledgeGraph == null) + { + throw new InvalidOperationException( + "Knowledge graph not configured. Configure Graph RAG using PredictionModelBuilder.ConfigureRetrievalAugmentedGeneration() before building the model."); + } + + if (string.IsNullOrWhiteSpace(startNodeId)) + throw new ArgumentException("Start node ID cannot be null or empty", nameof(startNodeId)); + + return KnowledgeGraph.BreadthFirstTraversal(startNodeId, maxDepth); + } + + /// + /// Finds the shortest path between two nodes in the knowledge graph. + /// + /// The ID of the starting node. + /// The ID of the target node. + /// List of node IDs representing the path, or empty list if no path exists. + /// Thrown when knowledge graph is not configured. + /// + /// For Beginners: This method finds how two entities are connected. + /// + /// Example: Finding the path between "Einstein" and "Princeton University" might return: + /// ["einstein", "worked_at_princeton", "princeton_university"] + /// + /// This is useful for understanding the relationships between concepts. + /// + /// + public List FindPathInGraph(string startNodeId, string endNodeId) + { + if (KnowledgeGraph == null) + { + throw new InvalidOperationException( + "Knowledge graph not configured. 
Configure Graph RAG using PredictionModelBuilder.ConfigureRetrievalAugmentedGeneration() before building the model."); + } + + if (string.IsNullOrWhiteSpace(startNodeId)) + throw new ArgumentException("Start node ID cannot be null or empty", nameof(startNodeId)); + if (string.IsNullOrWhiteSpace(endNodeId)) + throw new ArgumentException("End node ID cannot be null or empty", nameof(endNodeId)); + + return KnowledgeGraph.FindShortestPath(startNodeId, endNodeId); + } + + /// + /// Gets all edges (relationships) connected to a node in the knowledge graph. + /// + /// The ID of the node to query. + /// The direction of edges to retrieve. + /// Collection of edges connected to the node. + /// Thrown when knowledge graph is not configured. + /// Thrown when nodeId is null or empty. + /// + /// For Beginners: This method finds all relationships connected to an entity. + /// + /// Example: Getting edges for "Einstein" might return: + /// - Outgoing: STUDIED→Physics, WORKED_AT→Princeton, BORN_IN→Germany + /// - Incoming: INFLUENCED_BY→Newton + /// + /// + public IEnumerable> GetNodeRelationships(string nodeId, EdgeDirection direction = EdgeDirection.Both) + { + if (KnowledgeGraph == null) + { + throw new InvalidOperationException( + "Knowledge graph not configured. Configure Graph RAG using PredictionModelBuilder.ConfigureRetrievalAugmentedGeneration() before building the model."); + } + + if (string.IsNullOrWhiteSpace(nodeId)) + throw new ArgumentException("Node ID cannot be null or empty", nameof(nodeId)); + + var result = new List>(); + + if (direction == EdgeDirection.Outgoing || direction == EdgeDirection.Both) + { + result.AddRange(KnowledgeGraph.GetOutgoingEdges(nodeId)); + } + + if (direction == EdgeDirection.Incoming || direction == EdgeDirection.Both) + { + result.AddRange(KnowledgeGraph.GetIncomingEdges(nodeId)); + } + + return result; + } + + /// + /// Attaches Graph RAG components to a PredictionModelResult instance. + /// + /// The knowledge graph to attach. 
+ /// The graph store backend to attach. + /// The hybrid retriever to attach. + /// + /// This method is internal and used by PredictionModelBuilder when loading/deserializing models. + /// Graph RAG components cannot be serialized (they contain file handles, WAL references, etc.), + /// so the builder automatically reattaches them when loading a model that was configured with Graph RAG. + /// Users should use PredictionModelBuilder.LoadModel() which handles this automatically. + /// + internal void AttachGraphComponents( + KnowledgeGraph? knowledgeGraph = null, + IGraphStore? graphStore = null, + HybridGraphRetriever? hybridGraphRetriever = null) + { + KnowledgeGraph = knowledgeGraph; + GraphStore = graphStore; + HybridGraphRetriever = hybridGraphRetriever; + } + /// /// Saves the prediction model result's current state to a stream. /// diff --git a/src/PredictionModelBuilder.cs b/src/PredictionModelBuilder.cs index 2edbf1c2a..83abd8b4e 100644 --- a/src/PredictionModelBuilder.cs +++ b/src/PredictionModelBuilder.cs @@ -17,6 +17,7 @@ global using AiDotNet.MixedPrecision; global using AiDotNet.KnowledgeDistillation; global using AiDotNet.Deployment.Configuration; +global using AiDotNet.RetrievalAugmentedGeneration.Graph; namespace AiDotNet; @@ -54,6 +55,11 @@ public class PredictionModelBuilder : IPredictionModelBuilde private IReranker? _ragReranker; private IGenerator? _ragGenerator; private IEnumerable? _queryProcessors; + + // Graph RAG components for knowledge graph-enhanced retrieval + private KnowledgeGraph? _knowledgeGraph; + private IGraphStore? _graphStore; + private HybridGraphRetriever? _hybridGraphRetriever; private IMetaLearner? _metaLearner; private ICommunicationBackend? 
_distributedBackend; private DistributedStrategy _distributedStrategy = DistributedStrategy.DDP; @@ -565,7 +571,10 @@ public Task> BuildAsync() ragGenerator: _ragGenerator, queryProcessors: _queryProcessors, agentConfig: _agentConfig, - deploymentConfiguration: deploymentConfig); + deploymentConfiguration: deploymentConfig, + knowledgeGraph: _knowledgeGraph, + graphStore: _graphStore, + hybridGraphRetriever: _hybridGraphRetriever); return Task.FromResult(result); } @@ -917,7 +926,10 @@ public async Task> BuildAsync(TInput x agentRecommendation, deploymentConfig, jitCompiledFunction, - _inferenceOptimizationConfig); + _inferenceOptimizationConfig, + _knowledgeGraph, + _graphStore, + _hybridGraphRetriever); return finalResult; } @@ -1086,6 +1098,8 @@ public async Task> BuildAsync(int epis _gpuAccelerationConfig); // Return standard PredictionModelResult + // Note: This Build() overload doesn't perform JIT compilation (only the main Build() does), + // so jitCompiledFunction uses its default value of null var result = new PredictionModelResult( optimizationResult, normInfo, @@ -1099,7 +1113,11 @@ public async Task> BuildAsync(int epis crossValidationResult: null, _agentConfig, agentRecommendation: null, - deploymentConfig); + deploymentConfiguration: deploymentConfig, + inferenceOptimizationConfig: _inferenceOptimizationConfig, + knowledgeGraph: _knowledgeGraph, + graphStore: _graphStore, + hybridGraphRetriever: _hybridGraphRetriever); return result; } @@ -1206,6 +1224,14 @@ public PredictionModelResult DeserializeModel(byte[] modelDa var result = new PredictionModelResult(); result.Deserialize(modelData); + // Automatically reattach Graph RAG components if they were configured on this builder + // Graph RAG components cannot be serialized (file handles, WAL, etc.), so we reattach + // them from the builder's configuration to provide a seamless experience for users + if (_knowledgeGraph != null || _graphStore != null || _hybridGraphRetriever != null) + { + 
result.AttachGraphComponents(_knowledgeGraph, _graphStore, _hybridGraphRetriever); + } + return result; } @@ -1263,31 +1289,96 @@ public IPredictionModelBuilder ConfigureLoRA(ILoRAConfigurat /// /// Configures the retrieval-augmented generation (RAG) components for use during model inference. /// - /// Optional retriever for finding relevant documents. If not provided, RAG functionality won't be available. + /// Optional retriever for finding relevant documents. If not provided, standard RAG won't be available. /// Optional reranker for improving document ranking quality. If not provided, a default reranker will be used if RAG is configured. /// Optional generator for producing grounded answers. If not provided, a default generator will be used if RAG is configured. /// Optional query processors for improving search quality. + /// Optional graph storage backend for Graph RAG (e.g., MemoryGraphStore, FileGraphStore). + /// Optional pre-configured knowledge graph. If null but graphStore is provided, a new one is created. + /// Optional document store for hybrid vector + graph retrieval. /// This builder instance for method chaining. /// + /// /// For Beginners: RAG combines retrieval and generation to create answers backed by real documents. 
/// Configure it with: - /// - A retriever (finds relevant documents from your collection) - required for RAG - /// - A reranker (improves the ordering of retrieved documents) - optional, defaults provided - /// - A generator (creates answers based on the documents) - optional, defaults provided - /// - Optional query processors (improve search queries before retrieval) - /// + /// + /// A retriever (finds relevant documents from your collection) - required for standard RAG + /// A reranker (improves the ordering of retrieved documents) - optional, defaults provided + /// A generator (creates answers based on the documents) - optional, defaults provided + /// Optional query processors (improve search queries before retrieval) + /// + /// + /// + /// Graph RAG: When graphStore or knowledgeGraph is provided, enables knowledge graph-based + /// retrieval that finds related entities and their relationships, providing richer context than + /// vector similarity alone. Traditional RAG finds similar documents using vectors. Graph RAG goes further by + /// also exploring relationships between entities. For example, if you ask about "Paris", it can find + /// not just documents mentioning Paris, but also related concepts like France, Eiffel Tower, and Seine River. + /// + /// + /// Hybrid Retrieval: When both knowledgeGraph and documentStore are provided, creates a + /// HybridGraphRetriever that combines vector search and graph traversal for optimal results. + /// + /// + /// Disabling RAG: Call with all parameters as null to disable RAG functionality completely. + /// + /// /// RAG operations are performed during inference (after model training) via the PredictionModelResult. + /// /// public IPredictionModelBuilder ConfigureRetrievalAugmentedGeneration( IRetriever? retriever = null, IReranker? reranker = null, IGenerator? generator = null, - IEnumerable? queryProcessors = null) + IEnumerable? queryProcessors = null, + IGraphStore? graphStore = null, + KnowledgeGraph? 
knowledgeGraph = null, + IDocumentStore? documentStore = null) { + // Configure standard RAG components _ragRetriever = retriever; _ragReranker = reranker; _ragGenerator = generator; _queryProcessors = queryProcessors; + + // Configure Graph RAG components + // If all Graph RAG parameters are null, clear Graph RAG fields + if (graphStore == null && knowledgeGraph == null && documentStore == null) + { + _graphStore = null; + _knowledgeGraph = null; + _hybridGraphRetriever = null; + return this; + } + + _graphStore = graphStore; + + // Use provided knowledge graph or create one from the store + if (knowledgeGraph != null) + { + _knowledgeGraph = knowledgeGraph; + } + else if (graphStore != null) + { + _knowledgeGraph = new KnowledgeGraph(graphStore); + } + else + { + // No knowledge graph source provided, clear the field + _knowledgeGraph = null; + } + + // Create or clear hybrid retriever based on available components + if (_knowledgeGraph != null && documentStore != null) + { + _hybridGraphRetriever = new HybridGraphRetriever(_knowledgeGraph, documentStore); + } + else + { + // Clear hybrid retriever if dependencies are missing + _hybridGraphRetriever = null; + } + return this; } diff --git a/src/RetrievalAugmentedGeneration/Graph/BTreeIndex.cs b/src/RetrievalAugmentedGeneration/Graph/BTreeIndex.cs new file mode 100644 index 000000000..739216123 --- /dev/null +++ b/src/RetrievalAugmentedGeneration/Graph/BTreeIndex.cs @@ -0,0 +1,319 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; + +namespace AiDotNet.RetrievalAugmentedGeneration.Graph; + +/// +/// Simple file-based index for mapping string keys to file offsets. +/// +/// +/// +/// This class provides a persistent index structure that maps string keys (e.g., node IDs) +/// to byte offsets in data files. The index is stored on disk and reloaded on restart, +/// enabling fast lookups without scanning entire data files. 
+///
+/// For Beginners: Think of this like a book's index at the back.
+///
+/// Without an index:
+/// - To find "photosynthesis", you'd read every page from start to finish
+/// - Very slow for large books (or large data files)
+///
+/// With an index:
+/// - Look up "photosynthesis" → find it's on page 157
+/// - Jump directly to page 157
+/// - Much faster!
+///
+/// This class does the same for graph data:
+/// - Key: "node_alice_001"
+/// - Value: byte offset 45678 in nodes.dat file
+/// - We can jump directly to byte 45678 to read Alice's data
+///
+/// The index itself is stored in a file so it survives application restarts.
+///
+/// Implementation Note: This is a simplified index using a sorted dictionary.
+/// For production systems with millions of entries, consider implementing a true
+/// B-Tree structure with splitting/merging nodes, or use an embedded database like
+/// SQLite or LevelDB.
+/// </remarks>
+public class BTreeIndex : IDisposable
+{
+    // On disk each entry is one line, "key|offset"; see Flush()/LoadFromDisk().
+    // Keys containing the separator or line breaks would corrupt that format,
+    // so Add() rejects them up front.
+    private static readonly char[] ReservedKeyChars = { '|', '\r', '\n' };
+
+    private readonly string _indexFilePath;
+    private readonly SortedDictionary<string, long> _index;
+    private bool _isDirty;   // true when in-memory state differs from the file
+    private bool _disposed;
+
+    /// <summary>
+    /// Gets the number of entries in the index.
+    /// </summary>
+    public int Count => _index.Count;
+
+    /// <summary>
+    /// Initializes a new instance of the <see cref="BTreeIndex"/> class.
+    /// </summary>
+    /// <param name="indexFilePath">The path to the index file on disk.</param>
+    /// <exception cref="ArgumentNullException">Thrown when <paramref name="indexFilePath"/> is null.</exception>
+    public BTreeIndex(string indexFilePath)
+    {
+        _indexFilePath = indexFilePath ?? throw new ArgumentNullException(nameof(indexFilePath));
+        _index = new SortedDictionary<string, long>();
+        _isDirty = false;
+
+        // Restore any previously persisted entries so lookups work across restarts.
+        LoadFromDisk();
+    }
+
+    /// <summary>
+    /// Adds or updates a key-offset mapping in the index.
+    /// </summary>
+    /// <param name="key">The key to index (e.g., node ID). Must not contain '|' or line breaks.</param>
+    /// <param name="offset">The byte offset in the data file. Must be non-negative.</param>
+    /// <exception cref="ArgumentException">
+    /// Thrown when the key is null/whitespace or contains reserved characters, or when the offset is negative.
+    /// </exception>
+    /// <remarks>
+    /// For Beginners: This adds an entry to the index.
+    ///
+    /// Example:
+    /// - index.Add("alice", 1024)
+    /// - This means: "Alice's data starts at byte 1024 in the file"
+    ///
+    /// If "alice" already exists, it updates to the new offset.
+    /// The index is marked as "dirty" and will be saved to disk later.
+    /// </remarks>
+    public void Add(string key, long offset)
+    {
+        if (string.IsNullOrWhiteSpace(key))
+            throw new ArgumentException("Key cannot be null or whitespace", nameof(key));
+        // Reject keys that would break the "key|offset" line format on disk;
+        // without this check such keys round-trip as corrupt/ignored entries.
+        if (key.IndexOfAny(ReservedKeyChars) >= 0)
+            throw new ArgumentException("Key cannot contain '|' or line-break characters", nameof(key));
+        if (offset < 0)
+            throw new ArgumentException("Offset cannot be negative", nameof(offset));
+
+        _index[key] = offset;
+        _isDirty = true;
+    }
+
+    /// <summary>
+    /// Retrieves the file offset associated with a key.
+    /// </summary>
+    /// <param name="key">The key to look up.</param>
+    /// <returns>The byte offset if found; otherwise, -1.</returns>
+    /// <remarks>
+    /// For Beginners: This looks up where data is stored.
+    ///
+    /// Example:
+    /// - var offset = index.Get("alice")
+    /// - Returns: 1024 (meaning Alice's data is at byte 1024)
+    /// - Or returns -1 if "alice" is not in the index
+    /// </remarks>
+    public long Get(string key)
+    {
+        if (string.IsNullOrWhiteSpace(key))
+            return -1;
+
+        return _index.TryGetValue(key, out var offset) ? offset : -1;
+    }
+
+    /// <summary>
+    /// Checks if a key exists in the index.
+    /// </summary>
+    /// <param name="key">The key to check.</param>
+    /// <returns>True if the key exists; otherwise, false.</returns>
+    public bool Contains(string key)
+    {
+        return !string.IsNullOrWhiteSpace(key) && _index.ContainsKey(key);
+    }
+
+    /// <summary>
+    /// Removes a key from the index.
+    /// </summary>
+    /// <param name="key">The key to remove.</param>
+    /// <returns>True if the key was found and removed; otherwise, false.</returns>
+    /// <remarks>
+    /// For Beginners: This removes an entry from the index.
+    ///
+    /// Example:
+    /// - index.Remove("alice")
+    /// - Now we can no longer look up where Alice's data is
+    /// - The actual data file is NOT modified - just the index
+    ///
+    /// Note: For production systems, you'd typically mark entries as deleted
+    /// rather than removing them, to avoid fragmenting the data file.
+    /// </remarks>
+    public bool Remove(string key)
+    {
+        if (string.IsNullOrWhiteSpace(key))
+            return false;
+
+        var removed = _index.Remove(key);
+        if (removed)
+            _isDirty = true;
+
+        return removed;
+    }
+
+    /// <summary>
+    /// Gets all keys in the index.
+    /// </summary>
+    /// <returns>A snapshot collection of all keys (safe to enumerate while the index changes).</returns>
+    public IEnumerable<string> GetAllKeys()
+    {
+        // ToList() snapshots the keys so callers can mutate the index mid-enumeration.
+        return _index.Keys.ToList();
+    }
+
+    /// <summary>
+    /// Removes all entries from the index. The change is persisted on the next <see cref="Flush"/>.
+    /// </summary>
+    public void Clear()
+    {
+        _index.Clear();
+        _isDirty = true;
+    }
+
+    /// <summary>
+    /// Saves the index to disk if it has been modified.
+    /// </summary>
+    /// <exception cref="IOException">Thrown when the index file cannot be written.</exception>
+    /// <remarks>
+    /// <para>
+    /// This method writes the index to disk in a simple text format:
+    /// - Each line: "key|offset"
+    /// - Example: "alice|1024"
+    ///
+    /// The file is only written if changes have been made (_isDirty flag).
+    /// This is called automatically by Dispose() to ensure data isn't lost.
+    /// </para>
+    /// <para>
+    /// For Beginners: This saves the index to a file.
+    ///
+    /// Think of it like saving your work:
+    /// - You've been adding entries to the index (in memory)
+    /// - Flush() writes everything to disk
+    /// - Now the index survives if the program crashes or restarts
+    ///
+    /// Format example (nodes_index.db):
+    /// ```
+    /// alice|1024
+    /// bob|2048
+    /// charlie|3072
+    /// ```
+    /// </para>
+    /// </remarks>
+    public void Flush()
+    {
+        if (!_isDirty)
+            return;
+
+        var tempPath = _indexFilePath + ".tmp";
+        try
+        {
+            // Ensure directory exists
+            var directory = Path.GetDirectoryName(_indexFilePath);
+            if (!string.IsNullOrEmpty(directory) && !Directory.Exists(directory))
+                Directory.CreateDirectory(directory);
+
+            // Write index to temporary file first (atomic write)
+            using (var writer = new StreamWriter(tempPath, false, Encoding.UTF8))
+            {
+                foreach (var kvp in _index)
+                {
+                    writer.WriteLine($"{kvp.Key}|{kvp.Value}");
+                }
+            }
+
+            // Replace old index file with new one atomically
+            if (File.Exists(_indexFilePath))
+            {
+                // File.Replace performs an atomic swap and leaves the old content in backupPath
+                var backupPath = _indexFilePath + ".bak";
+                File.Replace(tempPath, _indexFilePath, backupPath);
+                // Clean up backup file
+                if (File.Exists(backupPath))
+                    File.Delete(backupPath);
+            }
+            else
+            {
+                File.Move(tempPath, _indexFilePath);
+            }
+
+            _isDirty = false;
+        }
+        catch (IOException ex)
+        {
+            // Don't leave a stale .tmp behind if the write/replace failed part-way.
+            TryDeleteFile(tempPath);
+            throw new IOException($"Failed to flush index to disk: {_indexFilePath}", ex);
+        }
+        catch (UnauthorizedAccessException ex)
+        {
+            TryDeleteFile(tempPath);
+            throw new IOException($"Failed to flush index to disk: {_indexFilePath}", ex);
+        }
+    }
+
+    /// <summary>
+    /// Best-effort deletion of a leftover temporary file; failures are intentionally ignored
+    /// because the original flush exception is the one worth surfacing.
+    /// </summary>
+    /// <param name="path">The file to delete if present.</param>
+    private static void TryDeleteFile(string path)
+    {
+        try
+        {
+            if (File.Exists(path))
+                File.Delete(path);
+        }
+        catch (IOException) { }
+        catch (UnauthorizedAccessException) { }
+    }
+
+    /// <summary>
+    /// Loads the index from disk if it exists.
+    /// </summary>
+    /// <exception cref="IOException">Thrown when the index file exists but cannot be read.</exception>
+    private void LoadFromDisk()
+    {
+        if (!File.Exists(_indexFilePath))
+            return;
+
+        try
+        {
+            using var reader = new StreamReader(_indexFilePath, Encoding.UTF8);
+            string? line;
+            while ((line = reader.ReadLine()) != null)
+            {
+                var parts = line.Split('|');
+                // Silently skip malformed lines rather than failing the whole load.
+                if (parts.Length != 2)
+                    continue;
+
+                var key = parts[0];
+                if (long.TryParse(parts[1], out var offset))
+                {
+                    _index[key] = offset;
+                }
+            }
+
+            // Freshly loaded state matches the file by definition.
+            _isDirty = false;
+        }
+        catch (IOException ex)
+        {
+            throw new IOException($"Failed to load index from disk: {_indexFilePath}", ex);
+        }
+        catch (UnauthorizedAccessException ex)
+        {
+            throw new IOException($"Failed to load index from disk: {_indexFilePath}", ex);
+        }
+    }
+
+    /// <summary>
+    /// Disposes the index, ensuring all changes are saved to disk.
+    /// </summary>
+    public void Dispose()
+    {
+        Dispose(true);
+        GC.SuppressFinalize(this);
+    }
+
+    /// <summary>
+    /// Releases resources used by the index.
+    /// </summary>
+    /// <param name="disposing">True if called from Dispose(), false if called from finalizer.</param>
+ protected virtual void Dispose(bool disposing) + { + if (_disposed) + return; + + try + { + if (disposing) + { + // Flush managed resources + Flush(); + } + } + finally + { + // Ensure _disposed is set even if Flush throws + _disposed = true; + } + } +} diff --git a/src/RetrievalAugmentedGeneration/Graph/FileGraphStore.cs b/src/RetrievalAugmentedGeneration/Graph/FileGraphStore.cs new file mode 100644 index 000000000..deb172b83 --- /dev/null +++ b/src/RetrievalAugmentedGeneration/Graph/FileGraphStore.cs @@ -0,0 +1,1011 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using AiDotNet.Interfaces; +using Newtonsoft.Json; + +namespace AiDotNet.RetrievalAugmentedGeneration.Graph; + +/// +/// File-based implementation of with persistent storage on disk. +/// +/// The numeric type used for vector operations. +/// +/// +/// This implementation provides persistent graph storage using files: +/// - nodes.dat: Binary file containing serialized nodes +/// - edges.dat: Binary file containing serialized edges +/// - node_index.db: B-Tree index mapping node IDs to file offsets +/// - edge_index.db: B-Tree index mapping edge IDs to file offsets +/// +/// For Beginners: This stores your graph on disk so it survives restarts. +/// +/// How it works: +/// 1. When you add a node, it's written to nodes.dat +/// 2. The position (offset) is recorded in node_index.db +/// 3. To retrieve a node, we look up its offset and read from that position +/// 4. 
Everything is saved to disk automatically +/// +/// Pros: +/// - 💾 Data persists across restarts +/// - 🔄 Can handle graphs larger than RAM +/// - 📁 Simple file-based storage (no database required) +/// +/// Cons: +/// - 🐌 Slower than in-memory (disk I/O overhead) +/// - 🔒 Not suitable for concurrent access from multiple processes +/// - 📦 No compression (files can be large) +/// +/// Good for: +/// - Applications that need to save graph state +/// - Graphs up to a few million nodes +/// - Single-process applications +/// +/// Not good for: +/// - Real-time systems requiring sub-millisecond latency +/// - Multi-process concurrent access +/// - Distributed systems (use Neo4j or similar instead) +/// +/// +public class FileGraphStore : IGraphStore, IDisposable +{ + private readonly string _storageDirectory; + private readonly string _nodesFilePath; + private readonly string _edgesFilePath; + private readonly BTreeIndex _nodeIndex; + private readonly BTreeIndex _edgeIndex; + private readonly WriteAheadLog? _wal; + + // In-memory caches for indices and metadata (thread-safe for concurrent async access) + private readonly ConcurrentDictionary> _outgoingEdges; // nodeId -> edge IDs + private readonly ConcurrentDictionary> _incomingEdges; // nodeId -> edge IDs + private readonly ConcurrentDictionary> _nodesByLabel; // label -> node IDs + private readonly object _cacheLock = new object(); // Lock for modifying HashSet contents + + private readonly JsonSerializerSettings _jsonSettings; + private bool _disposed; + + /// + public int NodeCount => _nodeIndex.Count; + + /// + public int EdgeCount => _edgeIndex.Count; + + /// + /// Initializes a new instance of the class. + /// + /// The directory where graph data files will be stored. + /// Optional Write-Ahead Log for ACID transactions and crash recovery. + public FileGraphStore(string storageDirectory, WriteAheadLog? 
wal = null) + { + if (string.IsNullOrWhiteSpace(storageDirectory)) + throw new ArgumentException("Storage directory cannot be null or whitespace", nameof(storageDirectory)); + + _storageDirectory = storageDirectory; + _nodesFilePath = Path.Combine(storageDirectory, "nodes.dat"); + _edgesFilePath = Path.Combine(storageDirectory, "edges.dat"); + _wal = wal; + + // Create directory if it doesn't exist + if (!Directory.Exists(storageDirectory)) + Directory.CreateDirectory(storageDirectory); + + // Initialize indices + _nodeIndex = new BTreeIndex(Path.Combine(storageDirectory, "node_index.db")); + _edgeIndex = new BTreeIndex(Path.Combine(storageDirectory, "edge_index.db")); + + // Initialize in-memory structures (thread-safe) + _outgoingEdges = new ConcurrentDictionary>(); + _incomingEdges = new ConcurrentDictionary>(); + _nodesByLabel = new ConcurrentDictionary>(); + + _jsonSettings = new JsonSerializerSettings + { + Formatting = Formatting.None + }; + + // Rebuild in-memory indices from persisted data + RebuildInMemoryIndices(); + } + + /// + public void AddNode(GraphNode node) + { + if (node == null) + throw new ArgumentNullException(nameof(node)); + + try + { + // Log to WAL first (durability) + _wal?.LogAddNode(node); + + // Serialize node to JSON + var json = JsonConvert.SerializeObject(node, _jsonSettings); + var bytes = Encoding.UTF8.GetBytes(json); + + // Get current file position (or reuse existing offset if updating) + long offset; + // For updates, we append to the end (old data becomes garbage) + // In production, you'd implement compaction to reclaim space + offset = new FileInfo(_nodesFilePath).Exists ? 
new FileInfo(_nodesFilePath).Length : 0; + + // Write node data to file + using (var stream = new FileStream(_nodesFilePath, FileMode.Append, FileAccess.Write, FileShare.Read)) + { + // Write length prefix (4 bytes) + var lengthBytes = BitConverter.GetBytes(bytes.Length); + stream.Write(lengthBytes, 0, 4); + + // Write JSON data + stream.Write(bytes, 0, bytes.Length); + } + + // Update index + _nodeIndex.Add(node.Id, offset); + + // Update in-memory indices (thread-safe) + lock (_cacheLock) + { + var labelSet = _nodesByLabel.GetOrAdd(node.Label, _ => new HashSet()); + labelSet.Add(node.Id); + + _outgoingEdges.GetOrAdd(node.Id, _ => new HashSet()); + _incomingEdges.GetOrAdd(node.Id, _ => new HashSet()); + } + + // Flush indices periodically (every 100 operations for performance) + if (_nodeIndex.Count % 100 == 0) + _nodeIndex.Flush(); + } + catch (IOException ex) + { + throw new IOException($"Failed to add node '{node.Id}' to file store", ex); + } + catch (UnauthorizedAccessException ex) + { + throw new IOException($"Failed to add node '{node.Id}' to file store due to unauthorized access", ex); + } + catch (JsonSerializationException ex) + { + throw new IOException($"Failed to serialize node '{node.Id}' to JSON", ex); + } + } + + /// + public void AddEdge(GraphEdge edge) + { + if (edge == null) + throw new ArgumentNullException(nameof(edge)); + if (!_nodeIndex.Contains(edge.SourceId)) + throw new InvalidOperationException($"Source node '{edge.SourceId}' does not exist"); + if (!_nodeIndex.Contains(edge.TargetId)) + throw new InvalidOperationException($"Target node '{edge.TargetId}' does not exist"); + + try + { + // Log to WAL first (durability) + _wal?.LogAddEdge(edge); + + // Serialize edge to JSON + var json = JsonConvert.SerializeObject(edge, _jsonSettings); + var bytes = Encoding.UTF8.GetBytes(json); + + // Get current file position + long offset = new FileInfo(_edgesFilePath).Exists ? 
new FileInfo(_edgesFilePath).Length : 0; + + // Write edge data to file + using (var stream = new FileStream(_edgesFilePath, FileMode.Append, FileAccess.Write, FileShare.Read)) + { + // Write length prefix (4 bytes) + var lengthBytes = BitConverter.GetBytes(bytes.Length); + stream.Write(lengthBytes, 0, 4); + + // Write JSON data + stream.Write(bytes, 0, bytes.Length); + } + + // Update index + _edgeIndex.Add(edge.Id, offset); + + // Update in-memory edge indices (thread-safe) + lock (_cacheLock) + { + if (_outgoingEdges.TryGetValue(edge.SourceId, out var outgoingSet)) + outgoingSet.Add(edge.Id); + if (_incomingEdges.TryGetValue(edge.TargetId, out var incomingSet)) + incomingSet.Add(edge.Id); + } + + // Flush indices periodically + if (_edgeIndex.Count % 100 == 0) + _edgeIndex.Flush(); + } + catch (IOException ex) + { + throw new IOException($"Failed to add edge '{edge.Id}' to file store", ex); + } + catch (UnauthorizedAccessException ex) + { + throw new IOException($"Failed to add edge '{edge.Id}' to file store due to access permissions", ex); + } + catch (JsonSerializationException ex) + { + throw new IOException($"Failed to serialize edge '{edge.Id}' to JSON", ex); + } + } + + /// + public GraphNode? 
GetNode(string nodeId) + { + if (string.IsNullOrWhiteSpace(nodeId)) + return null; + + var offset = _nodeIndex.Get(nodeId); + if (offset < 0) + return null; + + try + { + using var stream = new FileStream(_nodesFilePath, FileMode.Open, FileAccess.Read, FileShare.Read); + stream.Seek(offset, SeekOrigin.Begin); + + // Read length prefix - ensure all 4 bytes are read + var lengthBytes = new byte[4]; + ReadExactly(stream, lengthBytes, 0, 4); + var length = BitConverter.ToInt32(lengthBytes, 0); + + // Read JSON data - ensure all bytes are read + var jsonBytes = new byte[length]; + ReadExactly(stream, jsonBytes, 0, length); + var json = Encoding.UTF8.GetString(jsonBytes); + + // Deserialize + return JsonConvert.DeserializeObject>(json, _jsonSettings); + } + catch (EndOfStreamException ex) + { + throw new IOException($"Failed to read node '{nodeId}' from file store - data may be corrupted", ex); + } + catch (IOException ex) + { + throw new IOException($"Failed to read node '{nodeId}' from file store", ex); + } + catch (UnauthorizedAccessException ex) + { + throw new IOException($"Failed to read node '{nodeId}' from file store", ex); + } + catch (JsonSerializationException ex) + { + throw new IOException($"Failed to deserialize node '{nodeId}' from JSON", ex); + } + } + + /// + public GraphEdge? 
GetEdge(string edgeId) + { + if (string.IsNullOrWhiteSpace(edgeId)) + return null; + + var offset = _edgeIndex.Get(edgeId); + if (offset < 0) + return null; + + try + { + using var stream = new FileStream(_edgesFilePath, FileMode.Open, FileAccess.Read, FileShare.Read); + stream.Seek(offset, SeekOrigin.Begin); + + // Read length prefix - ensure all 4 bytes are read + var lengthBytes = new byte[4]; + ReadExactly(stream, lengthBytes, 0, 4); + var length = BitConverter.ToInt32(lengthBytes, 0); + + // Read JSON data - ensure all bytes are read + var jsonBytes = new byte[length]; + ReadExactly(stream, jsonBytes, 0, length); + var json = Encoding.UTF8.GetString(jsonBytes); + + // Deserialize + return JsonConvert.DeserializeObject>(json, _jsonSettings); + } + catch (EndOfStreamException ex) + { + throw new IOException($"Failed to read edge '{edgeId}' from file store - data may be corrupted", ex); + } + catch (IOException ex) + { + throw new IOException($"Failed to read edge '{edgeId}' from file store", ex); + } + catch (UnauthorizedAccessException ex) + { + throw new IOException($"Failed to read edge '{edgeId}' from file store", ex); + } + catch (JsonSerializationException ex) + { + throw new IOException($"Failed to deserialize edge '{edgeId}' from JSON", ex); + } + } + + /// + public bool RemoveNode(string nodeId) + { + if (string.IsNullOrWhiteSpace(nodeId) || !_nodeIndex.Contains(nodeId)) + return false; + + try + { + var node = GetNode(nodeId); + if (node == null) + return false; + + // Log to WAL first (durability) + _wal?.LogRemoveNode(nodeId); + + // Remove all outgoing edges (thread-safe) + if (_outgoingEdges.TryGetValue(nodeId, out var outgoing)) + { + List edgesToRemove; + lock (_cacheLock) + { + edgesToRemove = outgoing.ToList(); + } + foreach (var edgeId in edgesToRemove) + { + RemoveEdge(edgeId); + } + _outgoingEdges.TryRemove(nodeId, out _); + } + + // Remove all incoming edges (thread-safe) + if (_incomingEdges.TryGetValue(nodeId, out var incoming)) + { + 
List edgesToRemove; + lock (_cacheLock) + { + edgesToRemove = incoming.ToList(); + } + foreach (var edgeId in edgesToRemove) + { + RemoveEdge(edgeId); + } + _incomingEdges.TryRemove(nodeId, out _); + } + + // Remove from label index (thread-safe) + lock (_cacheLock) + { + if (_nodesByLabel.TryGetValue(node.Label, out var nodeIds)) + { + nodeIds.Remove(nodeId); + if (nodeIds.Count == 0) + _nodesByLabel.TryRemove(node.Label, out _); + } + } + + // Remove from node index (marks as deleted, actual data remains) + _nodeIndex.Remove(nodeId); + _nodeIndex.Flush(); + + return true; + } + catch (IOException ex) + { + throw new IOException($"Failed to remove node '{nodeId}' from file store", ex); + } + catch (UnauthorizedAccessException ex) + { + throw new IOException($"Failed to remove node '{nodeId}' from file store", ex); + } + } + + /// + public bool RemoveEdge(string edgeId) + { + if (string.IsNullOrWhiteSpace(edgeId) || !_edgeIndex.Contains(edgeId)) + return false; + + try + { + var edge = GetEdge(edgeId); + if (edge == null) + return false; + + // Log to WAL first (durability) + _wal?.LogRemoveEdge(edgeId); + + // Remove from in-memory indices (thread-safe) + lock (_cacheLock) + { + if (_outgoingEdges.TryGetValue(edge.SourceId, out var outgoing)) + outgoing.Remove(edgeId); + if (_incomingEdges.TryGetValue(edge.TargetId, out var incoming)) + incoming.Remove(edgeId); + } + + // Remove from edge index + _edgeIndex.Remove(edgeId); + _edgeIndex.Flush(); + + return true; + } + catch (IOException ex) + { + throw new IOException($"Failed to remove edge '{edgeId}' from file store", ex); + } + catch (UnauthorizedAccessException ex) + { + throw new IOException($"Failed to remove edge '{edgeId}' from file store", ex); + } + } + + /// + public IEnumerable> GetOutgoingEdges(string nodeId) + { + if (!_outgoingEdges.TryGetValue(nodeId, out var edgeIds)) + return Enumerable.Empty>(); + + // Take snapshot of edge IDs under lock to avoid concurrent modification + List snapshot; + lock 
(_cacheLock) + { + snapshot = edgeIds.ToList(); + } + return snapshot.Select(id => GetEdge(id)).Where(e => e != null).Cast>(); + } + + /// + public IEnumerable> GetIncomingEdges(string nodeId) + { + if (!_incomingEdges.TryGetValue(nodeId, out var edgeIds)) + return Enumerable.Empty>(); + + // Take snapshot of edge IDs under lock to avoid concurrent modification + List snapshot; + lock (_cacheLock) + { + snapshot = edgeIds.ToList(); + } + return snapshot.Select(id => GetEdge(id)).Where(e => e != null).Cast>(); + } + + /// + public IEnumerable> GetNodesByLabel(string label) + { + if (!_nodesByLabel.TryGetValue(label, out var nodeIds)) + return Enumerable.Empty>(); + + // Take snapshot of node IDs under lock to avoid concurrent modification + List snapshot; + lock (_cacheLock) + { + snapshot = nodeIds.ToList(); + } + return snapshot.Select(id => GetNode(id)).Where(n => n != null).Cast>(); + } + + /// + public IEnumerable> GetAllNodes() + { + return _nodeIndex.GetAllKeys().Select(id => GetNode(id)).Where(n => n != null).Cast>(); + } + + /// + public IEnumerable> GetAllEdges() + { + return _edgeIndex.GetAllKeys().Select(id => GetEdge(id)).Where(e => e != null).Cast>(); + } + + /// + public void Clear() + { + try + { + // Clear in-memory structures + _outgoingEdges.Clear(); + _incomingEdges.Clear(); + _nodesByLabel.Clear(); + + // Clear indices + _nodeIndex.Clear(); + _edgeIndex.Clear(); + _nodeIndex.Flush(); + _edgeIndex.Flush(); + + // Delete data files + if (File.Exists(_nodesFilePath)) + File.Delete(_nodesFilePath); + if (File.Exists(_edgesFilePath)) + File.Delete(_edgesFilePath); + } + catch (IOException ex) + { + throw new IOException("Failed to clear file store", ex); + } + catch (UnauthorizedAccessException ex) + { + throw new IOException("Failed to clear file store", ex); + } + } + + /// + /// Rebuilds in-memory indices by scanning all nodes and edges. 
+ /// + /// + /// This is called during initialization to restore the in-memory indices + /// from persisted data. It scans all nodes to rebuild label indices and + /// all edges to rebuild outgoing/incoming edge indices. + /// + private void RebuildInMemoryIndices() + { + try + { + // Rebuild node-related indices (thread-safe using GetOrAdd) + foreach (var node in _nodeIndex.GetAllKeys().Select(GetNode).OfType>()) + { + // Rebuild label index + var labelSet = _nodesByLabel.GetOrAdd(node.Label, _ => new HashSet()); + lock (_cacheLock) + { + labelSet.Add(node.Id); + } + + // Initialize edge indices + _outgoingEdges.GetOrAdd(node.Id, _ => new HashSet()); + _incomingEdges.GetOrAdd(node.Id, _ => new HashSet()); + } + + // Rebuild edge indices (thread-safe) + foreach (var edge in _edgeIndex.GetAllKeys().Select(GetEdge).OfType>()) + { + lock (_cacheLock) + { + if (_outgoingEdges.TryGetValue(edge.SourceId, out var outgoingSet)) + outgoingSet.Add(edge.Id); + if (_incomingEdges.TryGetValue(edge.TargetId, out var incomingSet)) + incomingSet.Add(edge.Id); + } + } + } + catch (IOException ex) + { + throw new IOException("Failed to rebuild in-memory indices", ex); + } + catch (InvalidDataException ex) + { + throw new IOException("Failed to rebuild in-memory indices", ex); + } + } + + // Async methods for non-blocking I/O operations + + /// + public async Task AddNodeAsync(GraphNode node) + { + if (node == null) + throw new ArgumentNullException(nameof(node)); + + try + { + // Log to WAL first (durability) + _wal?.LogAddNode(node); + + // Serialize node to JSON + var json = JsonConvert.SerializeObject(node, _jsonSettings); + var bytes = Encoding.UTF8.GetBytes(json); + + // Get current file position + long offset = new FileInfo(_nodesFilePath).Exists ? 
new FileInfo(_nodesFilePath).Length : 0; + + // Write node data to file asynchronously + using (var stream = new FileStream(_nodesFilePath, FileMode.Append, FileAccess.Write, FileShare.Read, 4096, useAsync: true)) + { + // Write length prefix (4 bytes) + var lengthBytes = BitConverter.GetBytes(bytes.Length); + await stream.WriteAsync(lengthBytes, 0, 4); + + // Write JSON data + await stream.WriteAsync(bytes, 0, bytes.Length); + } + + // Update index + _nodeIndex.Add(node.Id, offset); + + // Update in-memory indices (thread-safe) + lock (_cacheLock) + { + var labelSet = _nodesByLabel.GetOrAdd(node.Label, _ => new HashSet()); + labelSet.Add(node.Id); + + _outgoingEdges.GetOrAdd(node.Id, _ => new HashSet()); + _incomingEdges.GetOrAdd(node.Id, _ => new HashSet()); + } + + // Flush indices periodically + if (_nodeIndex.Count % 100 == 0) + _nodeIndex.Flush(); + } + catch (IOException ex) + { + throw new IOException($"Failed to add node '{node.Id}' to file store", ex); + } + catch (UnauthorizedAccessException ex) + { + throw new IOException($"Failed to add node '{node.Id}' to file store due to access permissions", ex); + } + catch (JsonSerializationException ex) + { + throw new IOException($"Failed to serialize node '{node.Id}' to JSON", ex); + } + } + + /// + public async Task AddEdgeAsync(GraphEdge edge) + { + if (edge == null) + throw new ArgumentNullException(nameof(edge)); + if (!_nodeIndex.Contains(edge.SourceId)) + throw new InvalidOperationException($"Source node '{edge.SourceId}' does not exist"); + if (!_nodeIndex.Contains(edge.TargetId)) + throw new InvalidOperationException($"Target node '{edge.TargetId}' does not exist"); + + try + { + // Log to WAL first (durability) + _wal?.LogAddEdge(edge); + + // Serialize edge to JSON + var json = JsonConvert.SerializeObject(edge, _jsonSettings); + var bytes = Encoding.UTF8.GetBytes(json); + + // Get current file position + long offset = new FileInfo(_edgesFilePath).Exists ? 
new FileInfo(_edgesFilePath).Length : 0; + + // Write edge data to file asynchronously + using (var stream = new FileStream(_edgesFilePath, FileMode.Append, FileAccess.Write, FileShare.Read, 4096, useAsync: true)) + { + // Write length prefix (4 bytes) + var lengthBytes = BitConverter.GetBytes(bytes.Length); + await stream.WriteAsync(lengthBytes, 0, 4); + + // Write JSON data + await stream.WriteAsync(bytes, 0, bytes.Length); + } + + // Update index + _edgeIndex.Add(edge.Id, offset); + + // Update in-memory edge indices (thread-safe) + lock (_cacheLock) + { + if (_outgoingEdges.TryGetValue(edge.SourceId, out var outgoingSet)) + outgoingSet.Add(edge.Id); + if (_incomingEdges.TryGetValue(edge.TargetId, out var incomingSet)) + incomingSet.Add(edge.Id); + } + + // Flush indices periodically + if (_edgeIndex.Count % 100 == 0) + _edgeIndex.Flush(); + } + catch (IOException ex) + { + throw new IOException($"Failed to add edge '{edge.Id}' to file store", ex); + } + catch (UnauthorizedAccessException ex) + { + throw new IOException($"Failed to add edge '{edge.Id}' to file store due to access permissions", ex); + } + } + + /// + public async Task?> GetNodeAsync(string nodeId) + { + if (string.IsNullOrWhiteSpace(nodeId)) + return null; + + var offset = _nodeIndex.Get(nodeId); + if (offset < 0) + return null; + + try + { + using var stream = new FileStream(_nodesFilePath, FileMode.Open, FileAccess.Read, FileShare.Read, 4096, useAsync: true); + stream.Seek(offset, SeekOrigin.Begin); + + // Read length prefix - ensure all 4 bytes are read + var lengthBytes = new byte[4]; + await ReadExactlyAsync(stream, lengthBytes, 0, 4); + var length = BitConverter.ToInt32(lengthBytes, 0); + + // Read JSON data - ensure all bytes are read + var jsonBytes = new byte[length]; + await ReadExactlyAsync(stream, jsonBytes, 0, length); + var json = Encoding.UTF8.GetString(jsonBytes); + + // Deserialize + return JsonConvert.DeserializeObject>(json, _jsonSettings); + } + catch (EndOfStreamException ex) 
+ { + throw new IOException($"Failed to read node '{nodeId}' from file store - data may be corrupted", ex); + } + catch (IOException ex) + { + throw new IOException($"Failed to read node '{nodeId}' from file store", ex); + } + catch (UnauthorizedAccessException ex) + { + throw new IOException($"Failed to read node '{nodeId}' from file store", ex); + } + catch (JsonSerializationException ex) + { + throw new IOException($"Failed to deserialize node '{nodeId}' from JSON", ex); + } + } + + /// + public async Task?> GetEdgeAsync(string edgeId) + { + if (string.IsNullOrWhiteSpace(edgeId)) + return null; + + var offset = _edgeIndex.Get(edgeId); + if (offset < 0) + return null; + + try + { + using var stream = new FileStream(_edgesFilePath, FileMode.Open, FileAccess.Read, FileShare.Read, 4096, useAsync: true); + stream.Seek(offset, SeekOrigin.Begin); + + // Read length prefix - ensure all 4 bytes are read + var lengthBytes = new byte[4]; + await ReadExactlyAsync(stream, lengthBytes, 0, 4); + var length = BitConverter.ToInt32(lengthBytes, 0); + + // Read JSON data - ensure all bytes are read + var jsonBytes = new byte[length]; + await ReadExactlyAsync(stream, jsonBytes, 0, length); + var json = Encoding.UTF8.GetString(jsonBytes); + + // Deserialize + return JsonConvert.DeserializeObject>(json, _jsonSettings); + } + catch (EndOfStreamException ex) + { + throw new IOException($"Failed to read edge '{edgeId}' from file store - data may be corrupted", ex); + } + catch (IOException ex) + { + throw new IOException($"Failed to read edge '{edgeId}' from file store", ex); + } + catch (UnauthorizedAccessException ex) + { + throw new IOException($"Failed to read edge '{edgeId}' from file store", ex); + } + catch (JsonSerializationException ex) + { + throw new IOException($"Failed to deserialize edge '{edgeId}' from JSON", ex); + } + } + + /// + public async Task RemoveNodeAsync(string nodeId) + { + if (string.IsNullOrWhiteSpace(nodeId) || !_nodeIndex.Contains(nodeId)) + return false; 
+ + try + { + var node = await GetNodeAsync(nodeId); + if (node == null) + return false; + + // Log to WAL first (durability) + _wal?.LogRemoveNode(nodeId); + + // Remove all outgoing edges (thread-safe) + if (_outgoingEdges.TryGetValue(nodeId, out var outgoing)) + { + List edgesToRemove; + lock (_cacheLock) + { + edgesToRemove = outgoing.ToList(); + } + foreach (var edgeId in edgesToRemove) + { + await RemoveEdgeAsync(edgeId); + } + _outgoingEdges.TryRemove(nodeId, out _); + } + + // Remove all incoming edges (thread-safe) + if (_incomingEdges.TryGetValue(nodeId, out var incoming)) + { + List edgesToRemove; + lock (_cacheLock) + { + edgesToRemove = incoming.ToList(); + } + foreach (var edgeId in edgesToRemove) + { + await RemoveEdgeAsync(edgeId); + } + _incomingEdges.TryRemove(nodeId, out _); + } + + // Remove from label index (thread-safe) + lock (_cacheLock) + { + if (_nodesByLabel.TryGetValue(node.Label, out var nodeIds)) + { + nodeIds.Remove(nodeId); + if (nodeIds.Count == 0) + _nodesByLabel.TryRemove(node.Label, out _); + } + } + + // Remove from node index + _nodeIndex.Remove(nodeId); + _nodeIndex.Flush(); + + return true; + } + catch (IOException ex) + { + throw new IOException($"Failed to remove node '{nodeId}' from file store", ex); + } + catch (UnauthorizedAccessException ex) + { + throw new IOException($"Failed to remove node '{nodeId}' from file store", ex); + } + } + + /// + public Task RemoveEdgeAsync(string edgeId) + { + return Task.FromResult(RemoveEdge(edgeId)); + } + + /// + public async Task>> GetOutgoingEdgesAsync(string nodeId) + { + if (!_outgoingEdges.TryGetValue(nodeId, out var edgeIds)) + return Enumerable.Empty>(); + + // Take snapshot of edgeIds to avoid race condition during enumeration + List snapshot; + lock (_cacheLock) + { + snapshot = edgeIds.ToList(); + } + + var tasks = snapshot.Select(id => GetEdgeAsync(id)); + var results = await Task.WhenAll(tasks); + return results.OfType>().ToList(); + } + + /// + public async Task>> 
GetIncomingEdgesAsync(string nodeId) + { + if (!_incomingEdges.TryGetValue(nodeId, out var edgeIds)) + return Enumerable.Empty>(); + + // Take snapshot of edgeIds to avoid race condition during enumeration + List snapshot; + lock (_cacheLock) + { + snapshot = edgeIds.ToList(); + } + + var tasks = snapshot.Select(id => GetEdgeAsync(id)); + var results = await Task.WhenAll(tasks); + return results.OfType>().ToList(); + } + + /// + public async Task>> GetNodesByLabelAsync(string label) + { + if (!_nodesByLabel.TryGetValue(label, out var nodeIds)) + return Enumerable.Empty>(); + + // Take snapshot of nodeIds to avoid race condition during enumeration + List snapshot; + lock (_cacheLock) + { + snapshot = nodeIds.ToList(); + } + + var tasks = snapshot.Select(id => GetNodeAsync(id)); + var results = await Task.WhenAll(tasks); + return results.OfType>().ToList(); + } + + /// + public async Task>> GetAllNodesAsync() + { + var tasks = _nodeIndex.GetAllKeys().Select(id => GetNodeAsync(id)); + var results = await Task.WhenAll(tasks); + return results.OfType>().ToList(); + } + + /// + public async Task>> GetAllEdgesAsync() + { + var tasks = _edgeIndex.GetAllKeys().Select(id => GetEdgeAsync(id)); + var results = await Task.WhenAll(tasks); + return results.OfType>().ToList(); + } + + /// + public Task ClearAsync() + { + Clear(); + return Task.CompletedTask; + } + + /// + /// Disposes the file graph store, ensuring all changes are flushed to disk. + /// + public void Dispose() + { + if (_disposed) + return; + + try + { + _nodeIndex.Flush(); + _edgeIndex.Flush(); + _nodeIndex.Dispose(); + _edgeIndex.Dispose(); + } + finally + { + _disposed = true; + } + } + + /// + /// Reads exactly the specified number of bytes from the stream. + /// + /// The stream to read from. + /// The buffer to read into. + /// The offset in the buffer to start writing at. + /// The number of bytes to read. + /// Thrown if the stream ends before all bytes are read. 
+ private static void ReadExactly(Stream stream, byte[] buffer, int offset, int count) + { + int totalRead = 0; + while (totalRead < count) + { + int bytesRead = stream.Read(buffer, offset + totalRead, count - totalRead); + if (bytesRead == 0) + throw new EndOfStreamException($"Unexpected end of stream. Expected {count} bytes, got {totalRead}."); + totalRead += bytesRead; + } + } + + /// + /// Asynchronously reads exactly the specified number of bytes from the stream. + /// + /// The stream to read from. + /// The buffer to read into. + /// The offset in the buffer to start writing at. + /// The number of bytes to read. + /// Thrown if the stream ends before all bytes are read. + private static async Task ReadExactlyAsync(Stream stream, byte[] buffer, int offset, int count) + { + int totalRead = 0; + while (totalRead < count) + { + int bytesRead = await stream.ReadAsync(buffer, offset + totalRead, count - totalRead); + if (bytesRead == 0) + throw new EndOfStreamException($"Unexpected end of stream. Expected {count} bytes, got {totalRead}."); + totalRead += bytesRead; + } + } +} diff --git a/src/RetrievalAugmentedGeneration/Graph/GraphAnalytics.cs b/src/RetrievalAugmentedGeneration/Graph/GraphAnalytics.cs new file mode 100644 index 000000000..3a28df3d8 --- /dev/null +++ b/src/RetrievalAugmentedGeneration/Graph/GraphAnalytics.cs @@ -0,0 +1,740 @@ +using System; +using System.Collections.Generic; +using System.Linq; + +namespace AiDotNet.RetrievalAugmentedGeneration.Graph; + +/// +/// Provides graph analytics algorithms for analyzing knowledge graphs. +/// +/// +/// +/// This class implements common graph algorithms used to analyze the structure +/// and importance of nodes and edges in a knowledge graph. +/// +/// For Beginners: Graph analytics help you understand your graph. +/// +/// Think of a social network: +/// - PageRank: Who are the most influential people? +/// - Degree Centrality: Who has the most connections? 
+/// - Closeness Centrality: Who can reach everyone quickly? +/// - Betweenness Centrality: Who connects different groups? +/// +/// These algorithms answer "who's important?" and "how are things connected?" +/// +/// +public static class GraphAnalytics +{ + /// + /// Calculates PageRank scores for all nodes in the graph. + /// + /// The numeric type used for vector operations. + /// The knowledge graph to analyze. + /// The damping factor (default: 0.85). Must be between 0 and 1. + /// Maximum number of iterations (default: 100). + /// Convergence threshold (default: 0.0001). + /// Dictionary mapping node IDs to their PageRank scores. + /// + /// + /// PageRank is an algorithm used by Google to rank web pages. In a knowledge graph, + /// it identifies the most "important" or "central" nodes based on the structure + /// of relationships. Nodes pointed to by many important nodes get higher scores. + /// + /// For Beginners: PageRank finds the most important nodes. + /// + /// Imagine a citation network: + /// - Papers cited by many other papers get high PageRank + /// - Papers cited by highly-cited papers get even higher PageRank + /// - It's like asking: "Which papers are most influential?" + /// + /// The algorithm: + /// 1. Start: All nodes have equal rank + /// 2. Iterate: Each node shares its rank with nodes it points to + /// 3. Damping: 85% of rank flows through edges, 15% jumps randomly + /// 4. Repeat until scores stabilize + /// + /// Higher PageRank = More important/central in the graph + /// + /// + /// Thrown when graph is null. + /// Thrown when dampingFactor is not between 0 and 1. 
+ public static Dictionary CalculatePageRank(
+ KnowledgeGraph graph,
+ double dampingFactor = 0.85,
+ int maxIterations = 100,
+ double convergenceThreshold = 0.0001)
+ {
+ if (graph == null)
+ throw new ArgumentNullException(nameof(graph));
+ if (dampingFactor < 0 || dampingFactor > 1)
+ throw new ArgumentOutOfRangeException(nameof(dampingFactor), "Damping factor must be between 0 and 1");
+
+ var nodes = graph.GetAllNodes().ToList();
+ if (nodes.Count == 0)
+ return new Dictionary();
+
+ var nodeCount = nodes.Count;
+ var ranks = new Dictionary();
+ var newRanks = new Dictionary();
+
+ // Initialize all nodes with equal rank
+ var initialRank = 1.0 / nodeCount;
+ foreach (var node in nodes)
+ {
+ ranks[node.Id] = initialRank;
+ newRanks[node.Id] = 0.0;
+ }
+
+ // NOTE(review): sources with no outgoing edges are skipped below (outgoingCount > 0 check),
+ // so "dangling" nodes leak rank mass each iteration and the returned scores will not sum
+ // to 1.0. Classic PageRank redistributes dangling mass uniformly - confirm whether callers
+ // rely only on relative ordering (in which case this is fine) or on absolute scores.
+ // Iterate until convergence or max iterations
+ for (int iteration = 0; iteration < maxIterations; iteration++)
+ {
+ // Calculate new ranks
+ foreach (var node in nodes)
+ {
+ var incomingEdges = graph.GetIncomingEdges(node.Id).ToList();
+ double rankSum = 0.0;
+
+ foreach (var edge in incomingEdges)
+ {
+ var sourceNode = graph.GetNode(edge.SourceId);
+ if (sourceNode != null)
+ {
+ var outgoingCount = graph.GetOutgoingEdges(edge.SourceId).Count();
+ if (outgoingCount > 0)
+ {
+ rankSum += ranks[edge.SourceId] / outgoingCount;
+ }
+ }
+ }
+
+ // PageRank formula: PR(A) = (1-d)/N + d * sum(PR(Ti)/C(Ti))
+ newRanks[node.Id] = (1 - dampingFactor) / nodeCount + dampingFactor * rankSum;
+ }
+
+ // Check for convergence
+ // Convergence is measured as the maximum absolute per-node change (L-infinity norm),
+ // not the sum of changes; iteration stops once no node moved more than the threshold.
+ double maxChange = 0.0;
+ foreach (var node in nodes)
+ {
+ var change = Math.Abs(newRanks[node.Id] - ranks[node.Id]);
+ if (change > maxChange)
+ maxChange = change;
+ ranks[node.Id] = newRanks[node.Id];
+ }
+
+ if (maxChange < convergenceThreshold)
+ break;
+ }
+
+ return ranks;
+ }
+
+ ///
+ /// Calculates degree centrality for all nodes in the graph.
+ ///
+ /// The numeric type used for vector operations.
+ /// The knowledge graph to analyze.
+ /// Whether to normalize scores by the maximum possible degree (default: true).
+ /// Dictionary mapping node IDs to their degree centrality scores.
+ ///
+ ///
+ /// Degree centrality is the simplest centrality measure. It counts the number of
+ /// edges connected to each node. Nodes with more connections are considered more central.
+ ///
+ /// For Beginners: Degree centrality counts connections.
+ ///
+ /// In a social network:
+ /// - Person with 100 friends has higher degree centrality than person with 10 friends
+ /// - It's the simplest measure: "Who knows the most people?"
+ ///
+ /// Types:
+ /// - Out-degree: How many edges go OUT (how many people do you follow?)
+ /// - In-degree: How many edges come IN (how many followers do you have?)
+ /// - Total degree: Sum of both
+ ///
+ /// This implementation calculates total degree (in + out).
+ ///
+ /// Normalized: Divides by 2(N-1) where N is total nodes - the maximum possible total
+ /// degree (in + out) of a node in a directed graph - giving a score from 0 to 1.
+ ///
+ ///
+ public static Dictionary CalculateDegreeCentrality(
+ KnowledgeGraph graph,
+ bool normalized = true)
+ {
+ if (graph == null)
+ throw new ArgumentNullException(nameof(graph));
+
+ var nodes = graph.GetAllNodes().ToList();
+ var centrality = new Dictionary();
+
+ if (nodes.Count == 0)
+ return centrality;
+
+ foreach (var node in nodes)
+ {
+ var outDegree = graph.GetOutgoingEdges(node.Id).Count();
+ var inDegree = graph.GetIncomingEdges(node.Id).Count();
+ var totalDegree = outDegree + inDegree;
+
+ // Normalize by the maximum possible degree (n-1) for undirected,
+ // or 2(n-1) for directed graphs
+ centrality[node.Id] = (normalized && nodes.Count > 1)
+ ? totalDegree / (2.0 * (nodes.Count - 1))
+ : totalDegree;
+ }
+
+ return centrality;
+ }
+
+ ///
+ /// Calculates closeness centrality for all nodes in the graph.
+ ///
+ /// The numeric type used for vector operations.
+ /// The knowledge graph to analyze.
+ /// Dictionary mapping node IDs to their closeness centrality scores.
+ /// + /// + /// Closeness centrality measures how close a node is to all other nodes in the graph. + /// It's calculated as the inverse of the average shortest path distance to all other nodes. + /// Nodes that can quickly reach all others have high closeness centrality. + /// + /// For Beginners: Closeness centrality measures "how close" you are to everyone. + /// + /// Think of an airport network: + /// - Hub airports (like Atlanta) can reach anywhere with few layovers → high closeness + /// - Remote airports need many connections → low closeness + /// + /// Algorithm: + /// 1. For each node, find shortest path to every other node + /// 2. Calculate average distance + /// 3. Closeness = 1 / average_distance + /// + /// Higher score = Can reach everyone more quickly + /// + /// Note: If nodes are unreachable, they're excluded from the calculation. + /// + /// + public static Dictionary CalculateClosenessCentrality( + KnowledgeGraph graph) + { + if (graph == null) + throw new ArgumentNullException(nameof(graph)); + + var nodes = graph.GetAllNodes().ToList(); + var centrality = new Dictionary(); + + if (nodes.Count <= 1) + { + foreach (var node in nodes) + centrality[node.Id] = 0.0; + return centrality; + } + + foreach (var node in nodes) + { + var distances = BreadthFirstSearchDistances(graph, node.Id); + var reachableNodes = distances.Values.Where(d => d > 0 && d < int.MaxValue).ToList(); + + if (reachableNodes.Count == 0) + { + centrality[node.Id] = 0.0; + } + else + { + var averageDistance = reachableNodes.Average(); + centrality[node.Id] = (nodes.Count - 1) / (averageDistance * reachableNodes.Count); + } + } + + return centrality; + } + + /// + /// Calculates betweenness centrality for all nodes in the graph. + /// + /// The numeric type used for vector operations. + /// The knowledge graph to analyze. + /// Whether to normalize scores (default: true). + /// Dictionary mapping node IDs to their betweenness centrality scores. 
+ /// + /// + /// Betweenness centrality measures how often a node appears on shortest paths between + /// other nodes. Nodes that act as "bridges" between different parts of the graph + /// have high betweenness centrality. + /// + /// For Beginners: Betweenness centrality finds "bridge" nodes. + /// + /// Think of a transportation network: + /// - A bridge connecting two cities has high betweenness + /// - Many shortest paths go through the bridge + /// - If you remove it, people must take longer routes + /// + /// In social networks: + /// - People connecting different friend groups have high betweenness + /// - They're "brokers" or "gatekeepers" of information + /// + /// Algorithm (simplified Brandes' algorithm): + /// 1. For all pairs of nodes, find shortest paths + /// 2. Count how many paths go through each node + /// 3. Higher count = Higher betweenness + /// + /// High betweenness = Important connector/bridge in the network + /// + /// + public static Dictionary CalculateBetweennessCentrality( + KnowledgeGraph graph, + bool normalized = true) + { + if (graph == null) + throw new ArgumentNullException(nameof(graph)); + + var nodes = graph.GetAllNodes().ToList(); + var betweenness = new Dictionary(); + + // Initialize all to 0 + foreach (var node in nodes) + betweenness[node.Id] = 0.0; + + if (nodes.Count <= 2) + return betweenness; + + // For each node as source + foreach (var source in nodes) + { + var stack = new Stack(); + var paths = new Dictionary>(); + var pathCounts = new Dictionary(); + var distances = new Dictionary(); + var dependencies = new Dictionary(); + + foreach (var node in nodes) + { + paths[node.Id] = new List(); + pathCounts[node.Id] = 0; + distances[node.Id] = -1; + dependencies[node.Id] = 0.0; + } + + pathCounts[source.Id] = 1; + distances[source.Id] = 0; + + var queue = new Queue(); + queue.Enqueue(source.Id); + + // BFS + while (queue.Count > 0) + { + var v = queue.Dequeue(); + stack.Push(v); + + foreach (var edge in 
graph.GetOutgoingEdges(v)) + { + var w = edge.TargetId; + // First time we see w? + if (distances[w] < 0) + { + queue.Enqueue(w); + distances[w] = distances[v] + 1; + } + + // Shortest path to w via v? + if (distances[w] == distances[v] + 1) + { + pathCounts[w] += pathCounts[v]; + paths[w].Add(v); + } + } + } + + // Accumulate dependencies + while (stack.Count > 0) + { + var w = stack.Pop(); + foreach (var v in paths[w]) + { + dependencies[v] += (pathCounts[v] / (double)pathCounts[w]) * (1.0 + dependencies[w]); + } + + if (w != source.Id) + betweenness[w] += dependencies[w]; + } + } + + // Normalize if requested + if (normalized && nodes.Count > 2) + { + var normFactor = (nodes.Count - 1) * (nodes.Count - 2); + foreach (var node in nodes) + { + betweenness[node.Id] /= normFactor; + } + } + + return betweenness; + } + + /// + /// Performs breadth-first search to calculate distances from a source node to all others. + /// + private static Dictionary BreadthFirstSearchDistances( + KnowledgeGraph graph, + string sourceId) + { + var distances = new Dictionary(); + var nodes = graph.GetAllNodes(); + + foreach (var node in nodes) + distances[node.Id] = int.MaxValue; + + distances[sourceId] = 0; + var queue = new Queue(); + queue.Enqueue(sourceId); + + while (queue.Count > 0) + { + var current = queue.Dequeue(); + var currentDistance = distances[current]; + + foreach (var edge in graph.GetOutgoingEdges(current).Where(e => distances[e.TargetId] == int.MaxValue)) + { + distances[edge.TargetId] = currentDistance + 1; + queue.Enqueue(edge.TargetId); + } + } + + return distances; + } + + /// + /// Identifies the top-k most central nodes based on a centrality measure. + /// + /// Dictionary of node IDs to centrality scores. + /// Number of top nodes to return. + /// List of (nodeId, score) tuples for the top-k nodes, ordered by descending score. + /// + /// For Beginners: This finds the "most important" nodes. 
+ ///
+ /// After running PageRank or centrality calculations, use this to get:
+ /// - Top 10 most influential people (PageRank)
+ /// - Top 5 most connected nodes (Degree)
+ /// - Top 3 bridge nodes (Betweenness)
+ ///
+ /// Example:
+ /// ```csharp
+ /// var pageRank = GraphAnalytics.CalculatePageRank(graph);
+ /// var top10 = GraphAnalytics.GetTopKNodes(pageRank, 10);
+ /// // Returns the 10 most important nodes with their scores
+ /// ```
+ ///
+ /// Note: if k exceeds the number of entries, all entries are returned; if k is zero
+ /// or negative, the result is empty (Enumerable.Take clamps out-of-range counts).
+ ///
+ public static List<(string NodeId, double Score)> GetTopKNodes(
+ Dictionary centrality,
+ int k)
+ {
+ if (centrality == null)
+ throw new ArgumentNullException(nameof(centrality));
+
+ return centrality
+ .OrderByDescending(kvp => kvp.Value)
+ // Take clamps: fewer than k entries -> all entries; k <= 0 -> empty list.
+ .Take(k)
+ .Select(kvp => (kvp.Key, kvp.Value))
+ .ToList();
+ }
+
+ ///
+ /// Finds all connected components in the graph.
+ ///
+ /// The numeric type used for vector operations.
+ /// The knowledge graph to analyze.
+ /// List of connected components, each containing a set of node IDs.
+ ///
+ ///
+ /// A connected component is a maximal subgraph where every node can reach every other node.
+ /// This helps identify isolated clusters or communities in the graph.
+ ///
+ /// For Beginners: Connected components find separate "islands" in your graph.
+ ///
+ /// Think of a social network:
+ /// - Component 1: Alice's friend group (all connected to each other)
+ /// - Component 2: Bob's friend group (completely separate from Alice's)
+ /// - Component 3: Isolated person Charlie (no connections)
+ ///
+ /// Uses:
+ /// - Find isolated communities
+ /// - Detect fragmented knowledge bases
+ /// - Identify separate discussion topics
+ /// - Check if graph is fully connected
+ ///
+ /// Algorithm (Union-Find/DFS):
+ /// 1. Start with first unvisited node
+ /// 2. Find all nodes reachable from it (BFS/DFS)
+ /// 3. That's one component
+ /// 4. Repeat for remaining unvisited nodes
+ ///
+ /// Result: List of components, each containing node IDs in that island.
+ /// + /// + public static List> FindConnectedComponents(KnowledgeGraph graph) + { + if (graph == null) + throw new ArgumentNullException(nameof(graph)); + + var nodes = graph.GetAllNodes().ToList(); + var visited = new HashSet(); + var components = new List>(); + + foreach (var node in nodes.Where(n => !visited.Contains(n.Id))) + { + var component = new HashSet(); + var queue = new Queue(); + queue.Enqueue(node.Id); + visited.Add(node.Id); + + while (queue.Count > 0) + { + var current = queue.Dequeue(); + component.Add(current); + + // Check outgoing edges + foreach (var edge in graph.GetOutgoingEdges(current).Where(e => !visited.Contains(e.TargetId))) + { + visited.Add(edge.TargetId); + queue.Enqueue(edge.TargetId); + } + + // Check incoming edges (for undirected behavior) + foreach (var edge in graph.GetIncomingEdges(current).Where(e => !visited.Contains(e.SourceId))) + { + visited.Add(edge.SourceId); + queue.Enqueue(edge.SourceId); + } + } + + components.Add(component); + } + + return components; + } + + /// + /// Detects communities using Label Propagation algorithm. + /// + /// The numeric type used for vector operations. + /// The knowledge graph to analyze. + /// Maximum number of iterations (default: 100). + /// Dictionary mapping node IDs to their community labels. + /// + /// + /// Label Propagation is a fast community detection algorithm. Each node starts with + /// a unique label, then iteratively adopts the most common label among its neighbors. + /// Nodes in the same community will converge to the same label. + /// + /// For Beginners: Label Propagation finds groups of nodes that cluster together. + /// + /// Imagine a party where people wear colored hats: + /// 1. Everyone starts with a random color + /// 2. Every minute, you look at your friends' hats + /// 3. You change to the most popular color among your friends + /// 4. 
After a while, friend groups wear the same color
+ ///
+ /// In graphs:
+ /// - Start: Each node has unique label
+ /// - Iterate: Each node adopts most common neighbor label
+ /// - Converge: Nodes in same community have same label
+ ///
+ /// Why it works:
+ /// - Densely connected nodes influence each other
+ /// - They converge to the same label
+ /// - Weakly connected nodes drift apart
+ ///
+ /// Result: Community labels (numbers) for each node.
+ /// Nodes with the same label are in the same community.
+ ///
+ /// Fast: O(k * E) where k = iterations, E = edges
+ /// Great for: Large graphs, quick community detection
+ ///
+ ///
+ public static Dictionary DetectCommunitiesLabelPropagation(
+ KnowledgeGraph graph,
+ int maxIterations = 100)
+ {
+ if (graph == null)
+ throw new ArgumentNullException(nameof(graph));
+
+ var nodes = graph.GetAllNodes().ToList();
+ var labels = new Dictionary();
+ // NOTE(review): unseeded Random makes the node processing order and tie-breaking -
+ // and therefore the final community assignment - non-deterministic across runs.
+ // Consider accepting an optional seed parameter for reproducibility (TODO confirm
+ // whether callers need repeatable results).
+ var random = new Random();
+
+ // Initialize each node with unique label
+ for (int i = 0; i < nodes.Count; i++)
+ {
+ labels[nodes[i].Id] = i;
+ }
+
+ // Iterate until convergence or max iterations
+ for (int iteration = 0; iteration < maxIterations; iteration++)
+ {
+ bool changed = false;
+
+ // Process nodes in random order
+ // (random-key OrderBy is an ad-hoc shuffle; asynchronous update order is part of
+ // the standard label-propagation algorithm and helps avoid label oscillation)
+ var shuffledNodes = nodes.OrderBy(n => random.Next()).ToList();
+
+ foreach (var node in shuffledNodes)
+ {
+ // Get labels of all neighbors
+ // Both edge directions are considered, i.e. the graph is treated as undirected here.
+ var neighborLabels = new List();
+
+ foreach (var edge in graph.GetOutgoingEdges(node.Id))
+ {
+ if (labels.TryGetValue(edge.TargetId, out var targetLabel))
+ neighborLabels.Add(targetLabel);
+ }
+
+ foreach (var edge in graph.GetIncomingEdges(node.Id))
+ {
+ if (labels.TryGetValue(edge.SourceId, out var sourceLabel))
+ neighborLabels.Add(sourceLabel);
+ }
+
+ // Isolated nodes keep their initial unique label.
+ if (neighborLabels.Count == 0)
+ continue;
+
+ // Find most common label
+ var labelCounts = neighborLabels
+ .GroupBy(l => l)
+ .Select(g => new { Label = g.Key, Count = g.Count() })
+ .OrderByDescending(x => x.Count)
+ .ThenBy(x => random.Next())
// Random tie-breaking + .First(); + + // Update label if different + if (labels[node.Id] != labelCounts.Label) + { + labels[node.Id] = labelCounts.Label; + changed = true; + } + } + + // Converged if no changes + if (!changed) + break; + } + + return labels; + } + + /// + /// Calculates the clustering coefficient for each node. + /// + /// The numeric type used for vector operations. + /// The knowledge graph to analyze. + /// Dictionary mapping node IDs to their clustering coefficients (0 to 1). + /// + /// + /// The clustering coefficient measures how well a node's neighbors are connected to each other. + /// A high coefficient means the node is part of a tightly-knit cluster. + /// + /// For Beginners: Clustering coefficient measures how "clique-like" connections are. + /// + /// Think of your friend group: + /// - High clustering: Your friends all know each other (tight group) + /// - Low clustering: Your friends don't know each other (you're the hub) + /// + /// Formula: + /// - Count how many of your friends are friends with each other + /// - Divide by maximum possible friendships between them + /// - Result: 0 (no friends know each other) to 1 (everyone knows everyone) + /// + /// Example: + /// - You have 3 friends: Alice, Bob, Charlie + /// - Maximum connections between them: 3 (Alice-Bob, Bob-Charlie, Alice-Charlie) + /// - Actual connections: 2 (Alice-Bob, Bob-Charlie) + /// - Clustering coefficient: 2/3 = 0.67 + /// + /// Uses: + /// - Identify tight communities + /// - Find nodes embedded in groups vs connectors between groups + /// - Measure graph's overall "cliquishness" + /// + /// High coefficient = Node in dense cluster + /// Low coefficient = Node connects different groups + /// + /// + public static Dictionary CalculateClusteringCoefficient( + KnowledgeGraph graph) + { + if (graph == null) + throw new ArgumentNullException(nameof(graph)); + + var nodes = graph.GetAllNodes().ToList(); + var coefficients = new Dictionary(); + + foreach (var node 
in nodes) + { + var neighbors = new HashSet(); + + // Collect all neighbors (treat as undirected) + foreach (var edge in graph.GetOutgoingEdges(node.Id)) + neighbors.Add(edge.TargetId); + foreach (var edge in graph.GetIncomingEdges(node.Id)) + neighbors.Add(edge.SourceId); + + if (neighbors.Count < 2) + { + coefficients[node.Id] = 0.0; + continue; + } + + // Count connections between neighbors + int connectedPairs = 0; + var neighborList = neighbors.ToList(); + + for (int i = 0; i < neighborList.Count; i++) + { + for (int j = i + 1; j < neighborList.Count; j++) + { + var n1 = neighborList[i]; + var n2 = neighborList[j]; + + // Check if n1 and n2 are connected + bool connected = graph.GetOutgoingEdges(n1).Any(e => e.TargetId == n2) || + graph.GetOutgoingEdges(n2).Any(e => e.TargetId == n1); + + if (connected) + connectedPairs++; + } + } + + // Clustering coefficient = actual connections / possible connections + int possiblePairs = neighbors.Count * (neighbors.Count - 1) / 2; + coefficients[node.Id] = (double)connectedPairs / possiblePairs; + } + + return coefficients; + } + + /// + /// Calculates the average clustering coefficient for the entire graph. + /// + /// The numeric type used for vector operations. + /// The knowledge graph to analyze. + /// The average clustering coefficient (0 to 1). + /// + /// For Beginners: This measures how "clustered" the entire graph is. + /// + /// - Close to 1: Graph has many tight groups (like friend circles) + /// - Close to 0: Graph is sparse, few triangles (like a tree) + /// + /// Compare to random graphs: + /// - Random graph: Low clustering (~0.01 for large graphs) + /// - Real social networks: High clustering (~0.3-0.6) + /// - Small-world networks: High clustering + short paths + /// + /// This is one measure of graph structure used in network science. 
+ /// + /// + public static double CalculateAverageClusteringCoefficient(KnowledgeGraph graph) + { + var coefficients = CalculateClusteringCoefficient(graph); + return coefficients.Count > 0 ? coefficients.Values.Average() : 0.0; + } +} diff --git a/src/RetrievalAugmentedGeneration/Graph/GraphEdge.cs b/src/RetrievalAugmentedGeneration/Graph/GraphEdge.cs index 69c9ab743..af757cbd3 100644 --- a/src/RetrievalAugmentedGeneration/Graph/GraphEdge.cs +++ b/src/RetrievalAugmentedGeneration/Graph/GraphEdge.cs @@ -113,11 +113,50 @@ public void SetProperty(string key, object value) /// /// The expected type of the property value. /// The property key. - /// The property value, or default if not found. + /// The property value, or default if not found or conversion fails. + /// + /// + /// This method handles JSON deserialization quirks where numeric types may differ + /// (e.g., int stored as long after JSON round-trip). It uses Convert.ChangeType + /// for IConvertible types to handle such conversions gracefully. + /// + /// + /// The method catches and handles the following exceptions during conversion: + /// - InvalidCastException: When the types are incompatible + /// - FormatException: When the string representation is invalid + /// - OverflowException: When the value is outside the target type's range + /// + /// public TValue? 
GetProperty(string key) { - if (Properties.TryGetValue(key, out var value) && value is TValue typedValue) + if (!Properties.TryGetValue(key, out var value) || value == null) + return default; + + // Direct type match + if (value is TValue typedValue) return typedValue; + + // Handle numeric type conversions (JSON deserializes integers as long) + if (value is IConvertible && typeof(TValue).IsValueType) + { + try + { + return (TValue)Convert.ChangeType(value, typeof(TValue)); + } + catch (InvalidCastException) + { + return default; + } + catch (FormatException) + { + return default; + } + catch (OverflowException) + { + return default; + } + } + return default; } diff --git a/src/RetrievalAugmentedGeneration/Graph/GraphNode.cs b/src/RetrievalAugmentedGeneration/Graph/GraphNode.cs index 1a5de25b5..dbc16828e 100644 --- a/src/RetrievalAugmentedGeneration/Graph/GraphNode.cs +++ b/src/RetrievalAugmentedGeneration/Graph/GraphNode.cs @@ -99,11 +99,50 @@ public void SetProperty(string key, object value) /// /// The expected type of the property value. /// The property key. - /// The property value, or default if not found. + /// The property value, or default if not found or conversion fails. + /// + /// + /// This method handles JSON deserialization quirks where numeric types may differ + /// (e.g., int stored as long after JSON round-trip). It uses Convert.ChangeType + /// for IConvertible types to handle such conversions gracefully. + /// + /// + /// The method catches and handles the following exceptions during conversion: + /// - InvalidCastException: When the types are incompatible + /// - FormatException: When the string representation is invalid + /// - OverflowException: When the value is outside the target type's range + /// + /// public TValue? 
GetProperty(string key) { - if (Properties.TryGetValue(key, out var value) && value is TValue typedValue) + if (!Properties.TryGetValue(key, out var value) || value == null) + return default; + + // Direct type match + if (value is TValue typedValue) return typedValue; + + // Handle numeric type conversions (JSON deserializes integers as long) + if (value is IConvertible && typeof(TValue).IsValueType) + { + try + { + return (TValue)Convert.ChangeType(value, typeof(TValue)); + } + catch (InvalidCastException) + { + return default; + } + catch (FormatException) + { + return default; + } + catch (OverflowException) + { + return default; + } + } + return default; } diff --git a/src/RetrievalAugmentedGeneration/Graph/GraphQueryMatcher.cs b/src/RetrievalAugmentedGeneration/Graph/GraphQueryMatcher.cs new file mode 100644 index 000000000..5b8ff9294 --- /dev/null +++ b/src/RetrievalAugmentedGeneration/Graph/GraphQueryMatcher.cs @@ -0,0 +1,433 @@ +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Text.RegularExpressions; + +namespace AiDotNet.RetrievalAugmentedGeneration.Graph; + +/// +/// Simple pattern matching for graph queries (inspired by Cypher/SPARQL but simplified). +/// +/// The numeric type used for vector operations. +/// +/// +/// Supports basic graph pattern matching queries like: +/// - (Person)-[KNOWS]->(Person) +/// - (Person {name: "Alice"})-[WORKS_AT]->(Company) +/// - (a:Person)-[r:KNOWS]->(b:Person) +/// +/// For Beginners: Pattern matching is like SQL for graphs. 
+/// +/// SQL Example: +/// ```sql +/// SELECT * FROM persons WHERE name = 'Alice' +/// ``` +/// +/// Graph Pattern Example: +/// ``` +/// (Person {name: "Alice"})-[KNOWS]->(Person) +/// ``` +/// Meaning: Find all people that Alice knows +/// +/// Another Example: +/// ``` +/// (Person)-[WORKS_AT]->(Company {name: "Google"}) +/// ``` +/// Meaning: Find all people who work at Google +/// +/// This is much more natural for relationship-heavy data! +/// +/// +public class GraphQueryMatcher +{ + private readonly KnowledgeGraph _graph; + + /// + /// Initializes a new instance of the class. + /// + /// The knowledge graph to query. + public GraphQueryMatcher(KnowledgeGraph graph) + { + _graph = graph ?? throw new ArgumentNullException(nameof(graph)); + } + + /// + /// Finds nodes matching a label and optional property filters. + /// + /// The node label to match. + /// Optional property filters. + /// List of matching nodes. + /// + /// FindNodes("Person", new Dictionary<string, object> { { "name", "Alice" } }) + /// + public List> FindNodes(string label, Dictionary? properties = null) + { + if (string.IsNullOrWhiteSpace(label)) + throw new ArgumentException("Label cannot be null or whitespace", nameof(label)); + + var nodes = _graph.GetNodesByLabel(label).ToList(); + + if (properties == null || properties.Count == 0) + return nodes; + + // Filter by properties + return nodes.Where(node => + { + foreach (var (key, value) in properties) + { + if (!node.Properties.TryGetValue(key, out var nodeValue)) + return false; + + // Simple equality check + if (!AreEqual(nodeValue, value)) + return false; + } + return true; + }).ToList(); + } + + /// + /// Finds paths matching a pattern: (source label)-[relationship type]->(target label). + /// + /// The source node label. + /// The relationship type. + /// The target node label. + /// Optional source node property filters. + /// Optional target node property filters. + /// List of matching paths. 
+ /// + /// FindPaths("Person", "KNOWS", "Person") // Find all KNOWS relationships + /// FindPaths("Person", "WORKS_AT", "Company", + /// new Dictionary<string, object> { { "name", "Alice" } }, + /// new Dictionary<string, object> { { "name", "Google" } }) + /// + public List> FindPaths( + string sourceLabel, + string relationshipType, + string targetLabel, + Dictionary? sourceProperties = null, + Dictionary? targetProperties = null) + { + if (string.IsNullOrWhiteSpace(sourceLabel)) + throw new ArgumentException("Source label cannot be null or whitespace", nameof(sourceLabel)); + if (string.IsNullOrWhiteSpace(relationshipType)) + throw new ArgumentException("Relationship type cannot be null or whitespace", nameof(relationshipType)); + if (string.IsNullOrWhiteSpace(targetLabel)) + throw new ArgumentException("Target label cannot be null or whitespace", nameof(targetLabel)); + + var results = new List>(); + + // Find matching source nodes + var sourceNodes = FindNodes(sourceLabel, sourceProperties); + + foreach (var sourceNode in sourceNodes) + { + // Get outgoing edges + var edges = _graph.GetOutgoingEdges(sourceNode.Id) + .Where(e => e.RelationType == relationshipType); + + foreach (var edge in edges) + { + var targetNode = _graph.GetNode(edge.TargetId); + if (targetNode == null + || targetNode.Label != targetLabel + || (targetProperties != null && targetProperties.Count > 0 && !MatchesProperties(targetNode, targetProperties))) + { + continue; + } + + // Found a match! + results.Add(new GraphPath + { + SourceNode = sourceNode, + Edge = edge, + TargetNode = targetNode + }); + } + } + + return results; + } + + /// + /// Finds all paths of specified length from a source node. + /// + /// The source node ID. + /// The path length (number of hops). + /// Optional relationship type filter. + /// List of paths. + public List>> FindPathsOfLength( + string sourceId, + int pathLength, + string? 
relationshipType = null) + { + if (string.IsNullOrWhiteSpace(sourceId)) + throw new ArgumentException("Source ID cannot be null or whitespace", nameof(sourceId)); + if (pathLength <= 0) + throw new ArgumentOutOfRangeException(nameof(pathLength), "Path length must be positive"); + + var results = new List>>(); + var currentPaths = new List>>(); + + var sourceNode = _graph.GetNode(sourceId); + if (sourceNode == null) + return results; + + // Initialize with source node + currentPaths.Add(new List> { sourceNode }); + + // BFS expansion + for (int depth = 0; depth < pathLength; depth++) + { + var nextPaths = new List>>(); + + foreach (var path in currentPaths) + { + var lastNode = path[^1]; + var edges = _graph.GetOutgoingEdges(lastNode.Id); + + // Filter by relationship type if specified + if (!string.IsNullOrWhiteSpace(relationshipType)) + { + edges = edges.Where(e => e.RelationType == relationshipType); + } + + // Map edges to target nodes, filter nulls and cycles, then create new paths + var newPaths = edges + .Select(edge => _graph.GetNode(edge.TargetId)) + .OfType>() + .Where(targetNode => !path.Any(n => n.Id == targetNode.Id)) + .Select(targetNode => new List>(path) { targetNode }); + nextPaths.AddRange(newPaths); + } + + currentPaths = nextPaths; + } + + return currentPaths; + } + + /// + /// Finds all shortest paths between two nodes. + /// + /// The source node ID. + /// The target node ID. + /// Maximum depth to search (prevents infinite loops). + /// List of shortest paths. 
+ public List>> FindShortestPaths( + string sourceId, + string targetId, + int maxDepth = 10) + { + if (string.IsNullOrWhiteSpace(sourceId)) + throw new ArgumentException("Source ID cannot be null or whitespace", nameof(sourceId)); + if (string.IsNullOrWhiteSpace(targetId)) + throw new ArgumentException("Target ID cannot be null or whitespace", nameof(targetId)); + + var sourceNode = _graph.GetNode(sourceId); + var targetNode = _graph.GetNode(targetId); + + if (sourceNode == null || targetNode == null) + return new List>>(); + + if (sourceId == targetId) + return new List>> { new List> { sourceNode } }; + + // BFS to find shortest paths + var queue = new Queue>>(); + var visitedAtDepth = new Dictionary(); + var results = new List>>(); + var shortestLength = int.MaxValue; + + queue.Enqueue(new List> { sourceNode }); + visitedAtDepth[sourceId] = 0; + + while (queue.Count > 0) + { + var path = queue.Dequeue(); + var currentNode = path[^1]; + var currentDepth = path.Count - 1; + + // Check if we've exceeded max depth + if (path.Count > maxDepth) + continue; + + // If we found longer paths than shortest, stop processing this path + if (path.Count > shortestLength) + continue; + + // Get neighbors - map edges to target nodes and filter nulls + var neighbors = _graph.GetOutgoingEdges(currentNode.Id) + .Select(edge => _graph.GetNode(edge.TargetId)) + .OfType>(); + + foreach (var neighbor in neighbors) + { + // Check if we found target + if (neighbor.Id == targetId) + { + var newPath = new List>(path) { neighbor }; + results.Add(newPath); + shortestLength = Math.Min(shortestLength, newPath.Count); + continue; + } + + // Avoid cycles + if (path.Any(n => n.Id == neighbor.Id)) + continue; + + // Continue exploring - allow visiting at same or earlier depth for shortest paths + var newDepth = currentDepth + 1; + if (!visitedAtDepth.TryGetValue(neighbor.Id, out var prevDepth) || newDepth <= prevDepth) + { + visitedAtDepth[neighbor.Id] = newDepth; + var newPath = new List>(path) { 
neighbor }; + queue.Enqueue(newPath); + } + } + } + + return results.Where(p => p.Count == shortestLength).ToList(); + } + + /// + /// Executes a simple pattern query string. + /// + /// The pattern string (simplified Cypher-like syntax). + /// List of matching paths. + /// + /// ExecutePattern("(Person)-[KNOWS]->(Person)") + /// ExecutePattern("(Person {name: \"Alice\"})-[WORKS_AT]->(Company)") + /// + public List> ExecutePattern(string pattern) + { + if (string.IsNullOrWhiteSpace(pattern)) + throw new ArgumentException("Pattern cannot be null or whitespace", nameof(pattern)); + + // Simple regex-based pattern parser + // Pattern: (SourceLabel {prop: "value"})-[RELATIONSHIP]->(TargetLabel {prop: "value"}) + var regex = new Regex(@"\((\w+)(?:\s*\{([^}]+)\})?\)-\[(\w+)\]->\((\w+)(?:\s*\{([^}]+)\})?\)"); + var match = regex.Match(pattern); + + if (!match.Success) + throw new ArgumentException($"Invalid pattern format: {pattern}. Expected format: (SourceLabel)-[RELATIONSHIP]->(TargetLabel)", nameof(pattern)); + + var sourceLabel = match.Groups[1].Value; + var sourcePropsStr = match.Groups[2].Value; + var relationshipType = match.Groups[3].Value; + var targetLabel = match.Groups[4].Value; + var targetPropsStr = match.Groups[5].Value; + + var sourceProps = ParseProperties(sourcePropsStr); + var targetProps = ParseProperties(targetPropsStr); + + return FindPaths(sourceLabel, relationshipType, targetLabel, sourceProps, targetProps); + } + + /// + /// Parses property string into dictionary. + /// + /// String like: name: "Alice", age: 30 + private Dictionary? 
ParseProperties(string propsString) + { + if (string.IsNullOrWhiteSpace(propsString)) + return null; + + var props = propsString.Split(',') + .Select(pair => pair.Split(':')) + .Where(parts => parts.Length == 2) + .Select(parts => + { + var key = parts[0].Trim(); + var valueStr = parts[1].Trim().Trim('"', '\''); + object value; + if (int.TryParse(valueStr, out var intValue)) + value = intValue; + else if (double.TryParse(valueStr, NumberStyles.Float, CultureInfo.InvariantCulture, out var doubleValue)) + value = doubleValue; + else + value = valueStr; + return new KeyValuePair(key, value); + }) + .ToDictionary(kv => kv.Key, kv => kv.Value); + + return props.Count > 0 ? props : null; + } + + /// + /// Checks if a node matches property filters. + /// + private bool MatchesProperties(GraphNode node, Dictionary properties) + { + foreach (var kvp in properties) + { + if (!node.Properties.TryGetValue(kvp.Key, out var nodeValue)) + return false; + + if (!AreEqual(nodeValue, kvp.Value)) + return false; + } + return true; + } + + /// + /// Compares two objects for equality. + /// + private bool AreEqual(object obj1, object obj2) + { + if (obj1 == null && obj2 == null) + return true; + if (obj1 == null || obj2 == null) + return false; + + // Handle numeric comparisons with tolerance for floating-point values + if (IsNumeric(obj1) && IsNumeric(obj2)) + { + var d1 = Convert.ToDouble(obj1); + var d2 = Convert.ToDouble(obj2); + const double tolerance = 1e-10; + return Math.Abs(d1 - d2) < tolerance; + } + + return obj1.Equals(obj2); + } + + /// + /// Checks if an object is numeric. + /// + private bool IsNumeric(object obj) + { + return obj is int or long or float or double or decimal; + } +} + +/// +/// Represents a path in the graph: source node -> edge -> target node. +/// +/// The numeric type. +public class GraphPath +{ + /// + /// Gets or sets the source node. 
+ /// + public GraphNode SourceNode { get; set; } = null!; + + /// + /// Gets or sets the edge connecting source to target. + /// + public GraphEdge Edge { get; set; } = null!; + + /// + /// Gets or sets the target node. + /// + public GraphNode TargetNode { get; set; } = null!; + + /// + /// Returns a string representation of the path. + /// + public override string ToString() + { + return $"({SourceNode.Label}:{SourceNode.Id})-[{Edge.RelationType}]->({TargetNode.Label}:{TargetNode.Id})"; + } +} diff --git a/src/RetrievalAugmentedGeneration/Graph/GraphTransaction.cs b/src/RetrievalAugmentedGeneration/Graph/GraphTransaction.cs new file mode 100644 index 000000000..8af33426c --- /dev/null +++ b/src/RetrievalAugmentedGeneration/Graph/GraphTransaction.cs @@ -0,0 +1,446 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; + +namespace AiDotNet.RetrievalAugmentedGeneration.Graph; + +/// +/// Transaction coordinator for managing transactions on graph stores with best-effort rollback. +/// +/// The numeric type used for vector operations. +/// +/// +/// This class provides transaction management with the following guarantees: +/// - Atomicity (Best-Effort): If an operation fails during commit, compensating rollback +/// is attempted in reverse order. However, if an undo operation fails, it is swallowed and +/// rollback continues with remaining operations. Full atomicity is not guaranteed. +/// - Consistency: Graph validation rules are enforced during operations. +/// - Isolation: Single-threaded; no concurrent transaction support. +/// - Durability: When a WAL is provided, operations are logged before execution. +/// Without a WAL, durability is not guaranteed. +/// +/// +/// Important: This is a lightweight transaction implementation suitable for single-process +/// use cases. For full ACID compliance with crash recovery, ensure a WriteAheadLog is configured. +/// +/// For Beginners: Transactions ensure your changes are safe. 
+/// +/// Think of a bank transfer: +/// - Debit $100 from Alice +/// - Credit $100 to Bob +/// +/// Without transactions: +/// - If crash happens after debit but before credit, $100 disappears! +/// +/// With transactions: +/// - Begin transaction +/// - Debit Alice +/// - Credit Bob +/// - Commit (both succeed) OR Rollback (both undone) +/// - Money never disappears! +/// +/// In graphs: +/// ```csharp +/// var txn = new GraphTransaction(store, wal); +/// txn.Begin(); +/// try +/// { +/// txn.AddNode(node1); +/// txn.AddEdge(edge1); +/// txn.Commit(); // Both saved +/// } +/// catch +/// { +/// txn.Rollback(); // Both undone +/// } +/// ``` +/// +/// This ensures your graph is never in a broken state! +/// +/// +public class GraphTransaction : IDisposable +{ + private readonly IGraphStore _store; + private readonly WriteAheadLog? _wal; + private readonly List> _operations; + private TransactionState _state; + private long _transactionId; + private bool _disposed; + + /// + /// Gets the current state of the transaction. + /// + public TransactionState State => _state; + + /// + /// Gets the transaction ID. + /// + public long TransactionId => _transactionId; + + /// + /// Initializes a new instance of the class. + /// + /// The graph store to operate on. + /// Optional Write-Ahead Log for durability. + public GraphTransaction(IGraphStore store, WriteAheadLog? wal = null) + { + _store = store ?? throw new ArgumentNullException(nameof(store)); + _wal = wal; + _operations = new List>(); + _state = TransactionState.NotStarted; + _transactionId = -1; + } + + /// + /// Begins a new transaction. + /// + /// Thrown if transaction already started. 
+ public void Begin() + { + if (_state != TransactionState.NotStarted) + throw new InvalidOperationException($"Transaction already in state: {_state}"); + + _state = TransactionState.Active; + _transactionId = DateTime.UtcNow.Ticks; // Simple ID generation + _operations.Clear(); + } + + /// + /// Adds a node within the transaction. + /// + /// The node to add. + public void AddNode(GraphNode node) + { + EnsureActive(); + + _operations.Add(new TransactionOperation + { + Type = GraphOperationType.AddNode, + Node = node + }); + } + + /// + /// Adds an edge within the transaction. + /// + /// The edge to add. + public void AddEdge(GraphEdge edge) + { + EnsureActive(); + + _operations.Add(new TransactionOperation + { + Type = GraphOperationType.AddEdge, + Edge = edge + }); + } + + /// + /// Removes a node within the transaction. + /// + /// The ID of the node to remove. + public void RemoveNode(string nodeId) + { + EnsureActive(); + + // Capture original node data for potential undo + var originalNode = _store.GetNode(nodeId); + + _operations.Add(new TransactionOperation + { + Type = GraphOperationType.RemoveNode, + NodeId = nodeId, + Node = originalNode // Store for undo + }); + } + + /// + /// Removes an edge within the transaction. + /// + /// The ID of the edge to remove. + public void RemoveEdge(string edgeId) + { + EnsureActive(); + + // Capture original edge data for potential undo + var originalEdge = _store.GetEdge(edgeId); + + _operations.Add(new TransactionOperation + { + Type = GraphOperationType.RemoveEdge, + EdgeId = edgeId, + Edge = originalEdge // Store for undo + }); + } + + /// + /// Commits the transaction, applying all operations with best-effort atomicity. + /// + /// Thrown if transaction not active. + /// + /// + /// If an operation fails mid-way, compensating rollback logic is executed + /// to undo already-applied operations in reverse order. 
However, this is + /// best-effort: if an undo operation throws an exception, it is + /// caught and swallowed, and rollback continues with remaining operations. + /// + /// + /// This means that after a failed commit, the graph may be left in a + /// partially modified state if undo operations also fail. For production + /// use cases requiring strict atomicity, consider using a database-backed + /// graph store with native transaction support. + /// + /// + public void Commit() + { + EnsureActive(); + + var appliedOperations = new List>(); + + try + { + // Log to WAL first (durability) + if (_wal != null) + { + foreach (var op in _operations) + { + LogOperation(op); + } + } + + // Apply all operations, tracking which ones succeed + foreach (var op in _operations) + { + ApplyOperation(op); + appliedOperations.Add(op); + } + + // Checkpoint if using WAL + _wal?.LogCheckpoint(); + + _state = TransactionState.Committed; + } + catch (Exception) + { + // Compensating rollback: undo applied operations in reverse order + for (int i = appliedOperations.Count - 1; i >= 0; i--) + { + try + { + UndoOperation(appliedOperations[i]); + } + catch + { + // Best-effort rollback - continue with remaining undo operations + } + } + + _state = TransactionState.Failed; + throw; + } + } + + /// + /// Rolls back the transaction, discarding all operations. + /// + public void Rollback() + { + if (_state != TransactionState.Active && _state != TransactionState.Failed) + throw new InvalidOperationException($"Cannot rollback transaction in state: {_state}"); + + // Simply discard operations (they were never applied) + _operations.Clear(); + _state = TransactionState.RolledBack; + } + + /// + /// Ensures the transaction is in active state. + /// + private void EnsureActive() + { + if (_state != TransactionState.Active) + throw new InvalidOperationException($"Transaction not active. Current state: {_state}"); + } + + /// + /// Logs an operation to the WAL. 
+ /// + private void LogOperation(TransactionOperation op) + { + if (_wal == null) + return; + + switch (op.Type) + { + case GraphOperationType.AddNode: + _wal.LogAddNode(op.Node!); + break; + case GraphOperationType.AddEdge: + _wal.LogAddEdge(op.Edge!); + break; + case GraphOperationType.RemoveNode: + _wal.LogRemoveNode(op.NodeId!); + break; + case GraphOperationType.RemoveEdge: + _wal.LogRemoveEdge(op.EdgeId!); + break; + } + } + + /// + /// Applies an operation to the graph store. + /// + private void ApplyOperation(TransactionOperation op) + { + switch (op.Type) + { + case GraphOperationType.AddNode: + if (op.Node != null) + _store.AddNode(op.Node); + break; + case GraphOperationType.AddEdge: + if (op.Edge != null) + _store.AddEdge(op.Edge); + break; + case GraphOperationType.RemoveNode: + if (op.NodeId != null) + _store.RemoveNode(op.NodeId); + break; + case GraphOperationType.RemoveEdge: + if (op.EdgeId != null) + _store.RemoveEdge(op.EdgeId); + break; + } + } + + /// + /// Undoes an already-applied operation (compensating action). + /// + /// + /// + /// This method performs the reverse of each operation type: + /// - AddNode → RemoveNode (using the stored node's ID) + /// - AddEdge → RemoveEdge (using the stored edge's ID) + /// - RemoveNode → AddNode (re-adds the captured original node) + /// - RemoveEdge → AddEdge (re-adds the captured original edge) + /// + /// + /// The original node/edge data is captured during the RecordRemoveNode/RecordRemoveEdge + /// operations, allowing complete restoration during undo. 
+ /// + /// + private void UndoOperation(TransactionOperation op) + { + switch (op.Type) + { + case GraphOperationType.AddNode: + // Undo add by removing the node + if (op.Node != null) + _store.RemoveNode(op.Node.Id); + break; + case GraphOperationType.AddEdge: + // Undo add by removing the edge + if (op.Edge != null) + _store.RemoveEdge(op.Edge.Id); + break; + case GraphOperationType.RemoveNode: + // Undo remove by re-adding the node (if we have the original data) + if (op.Node != null) + _store.AddNode(op.Node); + break; + case GraphOperationType.RemoveEdge: + // Undo remove by re-adding the edge (if we have the original data) + if (op.Edge != null) + _store.AddEdge(op.Edge); + break; + } + } + + /// + /// Disposes the transaction, rolling back if still active. + /// + public void Dispose() + { + if (_disposed) + return; + + // Auto-rollback if still active + if (_state == TransactionState.Active) + { + try + { + Rollback(); + } + catch (InvalidOperationException ex) + { + // Rollback errors during dispose are suppressed per IDisposable contract, + // but traced for diagnostics. Transaction may already be in invalid state. + Debug.WriteLine($"GraphTransaction.Dispose: Rollback failed with InvalidOperationException: {ex.Message}"); + } + catch (IOException ex) + { + // I/O errors during dispose are suppressed per IDisposable contract, + // but traced for diagnostics. Underlying store may be unavailable. + Debug.WriteLine($"GraphTransaction.Dispose: Rollback failed with IOException: {ex.Message}"); + } + } + + _disposed = true; + } +} + +/// +/// Represents a single operation within a transaction. +/// +/// The numeric type. +internal class TransactionOperation +{ + public GraphOperationType Type { get; set; } + public GraphNode? Node { get; set; } + public GraphEdge? Edge { get; set; } + public string? NodeId { get; set; } + public string? EdgeId { get; set; } +} + +/// +/// Types of operations supported in graph transactions. 
+/// +internal enum GraphOperationType +{ + AddNode, + AddEdge, + RemoveNode, + RemoveEdge +} + +/// +/// Represents the state of a transaction. +/// +public enum TransactionState +{ + /// + /// Transaction has not been started. + /// + NotStarted, + + /// + /// Transaction is active and accepting operations. + /// + Active, + + /// + /// Transaction has been committed successfully. + /// + Committed, + + /// + /// Transaction has been rolled back. + /// + RolledBack, + + /// + /// Transaction failed during commit. + /// + Failed +} diff --git a/src/RetrievalAugmentedGeneration/Graph/HybridGraphRetriever.cs b/src/RetrievalAugmentedGeneration/Graph/HybridGraphRetriever.cs new file mode 100644 index 000000000..d7851c6b7 --- /dev/null +++ b/src/RetrievalAugmentedGeneration/Graph/HybridGraphRetriever.cs @@ -0,0 +1,383 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using AiDotNet.Helpers; +using AiDotNet.Interfaces; +using AiDotNet.LinearAlgebra; +using AiDotNet.RetrievalAugmentedGeneration.Models; + +namespace AiDotNet.RetrievalAugmentedGeneration.Graph; + +/// +/// Hybrid retriever that combines vector similarity search with graph traversal for enhanced RAG. +/// +/// The numeric type used for vector operations. +/// +/// +/// This retriever uses a two-stage approach: +/// 1. Vector similarity search to find initial candidate nodes +/// 2. Graph traversal to expand context with related nodes +/// +/// For Beginners: Traditional RAG uses only vector similarity: +/// +/// Query: "What is photosynthesis?" +/// Traditional RAG: +/// - Find documents similar to the query +/// - Return top-k matches +/// - Misses related context! 
+/// +/// Hybrid Graph RAG: +/// - Find initial matches via vector similarity +/// - Walk the graph to find related concepts +/// - Example: photosynthesis → chlorophyll → plants → carbon dioxide +/// - Provides richer, more complete context +/// +/// Real-world analogy: +/// - Traditional: Search "Paris" → get Paris documents +/// - Hybrid: Search "Paris" → get Paris + France + Eiffel Tower + Seine River +/// - Graph connections provide context vectors can't capture! +/// +/// +public class HybridGraphRetriever +{ + private readonly KnowledgeGraph _graph; + private readonly IDocumentStore _documentStore; + + /// + /// Initializes a new instance of the class. + /// + /// The knowledge graph containing entity relationships. + /// The document store for similarity search. + public HybridGraphRetriever( + KnowledgeGraph graph, + IDocumentStore documentStore) + { + _graph = graph ?? throw new ArgumentNullException(nameof(graph)); + _documentStore = documentStore ?? throw new ArgumentNullException(nameof(documentStore)); + } + + /// + /// Retrieves relevant nodes using hybrid vector + graph approach. + /// + /// The query embedding vector. + /// Number of initial candidates to retrieve via vector search. + /// How many hops to traverse in the graph (0 = no expansion). + /// Maximum total results to return after expansion. + /// List of retrieved nodes with their relevance scores. 
+ public List> Retrieve( + Vector queryEmbedding, + int topK = 5, + int expansionDepth = 1, + int maxResults = 10) + { + if (queryEmbedding == null || queryEmbedding.Length == 0) + throw new ArgumentException("Query embedding cannot be null or empty", nameof(queryEmbedding)); + if (topK <= 0) + throw new ArgumentOutOfRangeException(nameof(topK), "topK must be positive"); + if (expansionDepth < 0) + throw new ArgumentOutOfRangeException(nameof(expansionDepth), "expansionDepth cannot be negative"); + + // Stage 1: Vector similarity search for initial candidates using document store + var initialCandidates = _documentStore.GetSimilar(queryEmbedding, topK).ToList(); + + if (expansionDepth == 0) + { + // No graph expansion - return vector results only + return initialCandidates + .Take(maxResults) + .Select(doc => new RetrievalResult + { + NodeId = doc.Id, + Score = doc.HasRelevanceScore ? Convert.ToDouble(doc.RelevanceScore) : 0.0, + Source = RetrievalSource.VectorSearch, + Embedding = doc.Embedding + }) + .ToList(); + } + + // Stage 2: Graph expansion + var results = new Dictionary>(); + var visited = new HashSet(); + + // Add initial candidates + foreach (var candidate in initialCandidates) + { + var result = new RetrievalResult + { + NodeId = candidate.Id, + Score = candidate.HasRelevanceScore ? 
Convert.ToDouble(candidate.RelevanceScore) : 0.0, + Source = RetrievalSource.VectorSearch, + Embedding = candidate.Embedding, + Depth = 0 + }; + results[candidate.Id] = result; + visited.Add(candidate.Id); + } + + // BFS expansion from initial candidates + var queue = new Queue<(string nodeId, int depth)>(); + foreach (var candidate in initialCandidates) + { + queue.Enqueue((candidate.Id, 0)); + } + + while (queue.Count > 0) + { + var (currentId, currentDepth) = queue.Dequeue(); + + if (currentDepth >= expansionDepth) + continue; + + // Get neighbors from graph + var neighbors = GetNeighbors(currentId); + + foreach (var neighborId in neighbors.Where(n => !visited.Contains(n))) + { + visited.Add(neighborId); + + // Get neighbor's embedding from graph node + var neighborNode = _graph.GetNode(neighborId); + var neighborEmbedding = neighborNode?.Embedding; + double score = 0.0; + + if (neighborEmbedding != null && neighborEmbedding.Length > 0) + { + // Calculate similarity to query using StatisticsHelper + score = CalculateSimilarity(queryEmbedding, neighborEmbedding); + + // Apply depth penalty (closer nodes are more relevant) + var depthPenalty = Math.Pow(0.8, currentDepth + 1); // 0.8^depth + score *= depthPenalty; + } + + var result = new RetrievalResult + { + NodeId = neighborId, + Score = score, + Source = RetrievalSource.GraphTraversal, + Embedding = neighborEmbedding, + Depth = currentDepth + 1, + ParentNodeId = currentId + }; + + // Only update if new score is higher to preserve best path + if (!results.TryGetValue(neighborId, out var existing) || result.Score > existing.Score) + { + results[neighborId] = result; + } + + // Continue expanding + if (currentDepth + 1 < expansionDepth) + { + queue.Enqueue((neighborId, currentDepth + 1)); + } + } + } + + // Return top results sorted by score + return results.Values + .OrderByDescending(r => r.Score) + .Take(maxResults) + .ToList(); + } + + /// + /// Retrieves relevant nodes asynchronously using hybrid approach. 
+ /// + public async Task>> RetrieveAsync( + Vector queryEmbedding, + int topK = 5, + int expansionDepth = 1, + int maxResults = 10) + { + // For now, just wrap the synchronous version + // In a real implementation, you'd use async vector DB operations + return await Task.Run(() => Retrieve(queryEmbedding, topK, expansionDepth, maxResults)); + } + + /// + /// Retrieves nodes with relationship-aware scoring. + /// + /// The query embedding vector. + /// Number of initial candidates. + /// Weights for different relationship types. + /// Maximum results to return. + /// List of retrieved nodes with relationship-aware scores. + public List> RetrieveWithRelationships( + Vector queryEmbedding, + int topK = 5, + Dictionary? relationshipWeights = null, + int maxResults = 10) + { + if (queryEmbedding == null || queryEmbedding.Length == 0) + throw new ArgumentException("Query embedding cannot be null or empty", nameof(queryEmbedding)); + + relationshipWeights ??= new Dictionary(); + + // Stage 1: Vector similarity search + var initialCandidates = _documentStore.GetSimilar(queryEmbedding, topK).ToList(); + var results = new Dictionary>(); + + // Add initial candidates + foreach (var candidate in initialCandidates) + { + results[candidate.Id] = new RetrievalResult + { + NodeId = candidate.Id, + Score = candidate.HasRelevanceScore ? Convert.ToDouble(candidate.RelevanceScore) : 0.0, + Source = RetrievalSource.VectorSearch, + Embedding = candidate.Embedding, + Depth = 0 + }; + } + + // Stage 2: Expand via relationships with weighted scoring + foreach (var candidate in initialCandidates) + { + var node = _graph.GetNode(candidate.Id); + if (node == null) + continue; + + var outgoingEdges = _graph.GetOutgoingEdges(candidate.Id); + + foreach (var edge in outgoingEdges.Where(e => !results.ContainsKey(e.TargetId))) + { + // Get relationship weight (default to 1.0) + var weight = relationshipWeights.TryGetValue(edge.RelationType, out var w) ? 
w : 1.0; + + // Get target node's embedding from graph + var targetNode = _graph.GetNode(edge.TargetId); + var targetEmbedding = targetNode?.Embedding; + double score = 0.0; + + if (targetEmbedding != null && targetEmbedding.Length > 0) + { + score = CalculateSimilarity(queryEmbedding, targetEmbedding); + score *= weight; // Apply relationship weight + score *= 0.8; // One-hop penalty + } + + var result = new RetrievalResult + { + NodeId = edge.TargetId, + Score = score, + Source = RetrievalSource.GraphTraversal, + Embedding = targetEmbedding, + Depth = 1, + ParentNodeId = candidate.Id, + RelationType = edge.RelationType + }; + + // Only update if new score is higher to preserve best path + if (!results.TryGetValue(edge.TargetId, out var existing) || result.Score > existing.Score) + { + results[edge.TargetId] = result; + } + } + } + + // Return top results + return results.Values + .OrderByDescending(r => r.Score) + .Take(maxResults) + .ToList(); + } + + /// + /// Gets all neighbors (both incoming and outgoing) of a node. + /// + private HashSet GetNeighbors(string nodeId) + { + var neighbors = new HashSet(); + + var outgoing = _graph.GetOutgoingEdges(nodeId); + foreach (var edge in outgoing) + { + neighbors.Add(edge.TargetId); + } + + var incoming = _graph.GetIncomingEdges(nodeId); + foreach (var edge in incoming) + { + neighbors.Add(edge.SourceId); + } + + return neighbors; + } + + /// + /// Calculates similarity between two embeddings using cosine similarity. + /// + private double CalculateSimilarity(Vector embedding1, Vector embedding2) + { + if (embedding1.Length != embedding2.Length) + return 0.0; + + var similarity = StatisticsHelper.CosineSimilarity(embedding1, embedding2); + return Convert.ToDouble(similarity); + } +} + +/// +/// Represents a retrieval result from the hybrid retriever. +/// +/// The numeric type. +public class RetrievalResult +{ + /// + /// Gets or sets the node ID. 
+ /// + public string NodeId { get; set; } = string.Empty; + + /// + /// Gets or sets the relevance score (0-1, higher is more relevant). + /// + public double Score { get; set; } + + /// + /// Gets or sets how this result was retrieved. + /// + public RetrievalSource Source { get; set; } + + /// + /// Gets or sets the embedding vector. + /// + public Vector? Embedding { get; set; } + + /// + /// Gets or sets the graph traversal depth (0 for initial candidates). + /// + public int Depth { get; set; } + + /// + /// Gets or sets the parent node ID (for graph-traversed results). + /// + public string? ParentNodeId { get; set; } + + /// + /// Gets or sets the relationship type (for graph-traversed results). + /// + public string? RelationType { get; set; } +} + +/// +/// Indicates how a result was retrieved. +/// +public enum RetrievalSource +{ + /// + /// Retrieved via vector similarity search. + /// + VectorSearch, + + /// + /// Retrieved via graph traversal. + /// + GraphTraversal, + + /// + /// Retrieved via both methods. + /// + Hybrid +} diff --git a/src/RetrievalAugmentedGeneration/Graph/KnowledgeGraph.cs b/src/RetrievalAugmentedGeneration/Graph/KnowledgeGraph.cs index 9c9f1c3a7..f62ef742f 100644 --- a/src/RetrievalAugmentedGeneration/Graph/KnowledgeGraph.cs +++ b/src/RetrievalAugmentedGeneration/Graph/KnowledgeGraph.cs @@ -1,65 +1,72 @@ using System; using System.Collections.Generic; using System.Linq; +using AiDotNet.Interfaces; namespace AiDotNet.RetrievalAugmentedGeneration.Graph; /// -/// In-memory knowledge graph for storing and querying entity relationships. +/// Knowledge graph for storing and querying entity relationships using a pluggable storage backend. /// /// The numeric type used for vector operations. /// /// /// A knowledge graph stores entities (nodes) and their relationships (edges) to enable structured information retrieval. -/// This implementation uses efficient in-memory data structures optimized for graph traversal and querying. 
+/// This implementation delegates storage operations to an implementation, +/// allowing you to swap between in-memory, file-based, or database-backed storage. /// /// For Beginners: A knowledge graph is like a map of how information connects together. -/// +/// /// Imagine Wikipedia as a graph: /// - Each article is a node (Albert Einstein, Physics, Germany, etc.) /// - Links between articles are edges (Einstein STUDIED Physics, Einstein BORN_IN Germany) /// - You can traverse the graph to find related information -/// +/// /// This class lets you: /// 1. Add entities and relationships /// 2. Find connections between entities /// 3. Traverse the graph to discover related information /// 4. Query based on entity types or relationships -/// +/// /// For example, to answer "Who worked at Princeton?": /// 1. Find all edges with type "WORKED_AT" /// 2. Filter for target = "Princeton University" /// 3. Return the source entities (people who worked there) +/// +/// Storage backends you can use: +/// - MemoryGraphStore: Fast, in-memory (default) +/// - FileGraphStore: Persistent, disk-based +/// - Neo4jGraphStore: Professional graph database (future) /// /// public class KnowledgeGraph { - private readonly Dictionary> _nodes; - private readonly Dictionary> _edges; - private readonly Dictionary> _outgoingEdges; // nodeId -> edge IDs going out - private readonly Dictionary> _incomingEdges; // nodeId -> edge IDs coming in - private readonly Dictionary> _nodesByLabel; // label -> node IDs - + private readonly IGraphStore _store; + /// /// Gets the total number of nodes in the graph. /// - public int NodeCount => _nodes.Count; - + public int NodeCount => _store.NodeCount; + /// /// Gets the total number of edges in the graph. /// - public int EdgeCount => _edges.Count; - + public int EdgeCount => _store.EdgeCount; + + /// + /// Initializes a new instance of the class with a custom graph store. + /// + /// The graph store implementation to use for storage. 
+ public KnowledgeGraph(IGraphStore store) + { + _store = store ?? throw new ArgumentNullException(nameof(store)); + } + /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class with default in-memory storage. /// - public KnowledgeGraph() + public KnowledgeGraph() : this(new MemoryGraphStore()) { - _nodes = new Dictionary>(); - _edges = new Dictionary>(); - _outgoingEdges = new Dictionary>(); - _incomingEdges = new Dictionary>(); - _nodesByLabel = new Dictionary>(); } /// @@ -68,21 +75,9 @@ public KnowledgeGraph() /// The node to add. public void AddNode(GraphNode node) { - if (node == null) - throw new ArgumentNullException(nameof(node)); - - _nodes[node.Id] = node; - - if (!_nodesByLabel.ContainsKey(node.Label)) - _nodesByLabel[node.Label] = new HashSet(); - _nodesByLabel[node.Label].Add(node.Id); - - if (!_outgoingEdges.ContainsKey(node.Id)) - _outgoingEdges[node.Id] = new HashSet(); - if (!_incomingEdges.ContainsKey(node.Id)) - _incomingEdges[node.Id] = new HashSet(); + _store.AddNode(node); } - + /// /// Adds an edge to the graph. /// @@ -90,18 +85,9 @@ public void AddNode(GraphNode node) /// Thrown when source or target nodes don't exist. public void AddEdge(GraphEdge edge) { - if (edge == null) - throw new ArgumentNullException(nameof(edge)); - if (!_nodes.ContainsKey(edge.SourceId)) - throw new InvalidOperationException($"Source node '{edge.SourceId}' does not exist"); - if (!_nodes.ContainsKey(edge.TargetId)) - throw new InvalidOperationException($"Target node '{edge.TargetId}' does not exist"); - - _edges[edge.Id] = edge; - _outgoingEdges[edge.SourceId].Add(edge.Id); - _incomingEdges[edge.TargetId].Add(edge.Id); + _store.AddEdge(edge); } - + /// /// Gets a node by its ID. /// @@ -109,9 +95,9 @@ public void AddEdge(GraphEdge edge) /// The node, or null if not found. public GraphNode? GetNode(string nodeId) { - return _nodes.TryGetValue(nodeId, out var node) ? 
node : null; + return _store.GetNode(nodeId); } - + /// /// Gets all nodes with a specific label. /// @@ -119,12 +105,9 @@ public void AddEdge(GraphEdge edge) /// Collection of nodes with the specified label. public IEnumerable> GetNodesByLabel(string label) { - if (!_nodesByLabel.TryGetValue(label, out var nodeIds)) - return Enumerable.Empty>(); - - return nodeIds.Select(id => _nodes[id]); + return _store.GetNodesByLabel(label); } - + /// /// Gets all outgoing edges from a node. /// @@ -132,12 +115,9 @@ public IEnumerable> GetNodesByLabel(string label) /// Collection of outgoing edges. public IEnumerable> GetOutgoingEdges(string nodeId) { - if (!_outgoingEdges.TryGetValue(nodeId, out var edgeIds)) - return Enumerable.Empty>(); - - return edgeIds.Select(id => _edges[id]); + return _store.GetOutgoingEdges(nodeId); } - + /// /// Gets all incoming edges to a node. /// @@ -145,10 +125,7 @@ public IEnumerable> GetOutgoingEdges(string nodeId) /// Collection of incoming edges. public IEnumerable> GetIncomingEdges(string nodeId) { - if (!_incomingEdges.TryGetValue(nodeId, out var edgeIds)) - return Enumerable.Empty>(); - - return edgeIds.Select(id => _edges[id]); + return _store.GetIncomingEdges(nodeId); } /// @@ -159,7 +136,9 @@ public IEnumerable> GetIncomingEdges(string nodeId) public IEnumerable> GetNeighbors(string nodeId) { var edges = GetOutgoingEdges(nodeId); - return edges.Select(e => _nodes[e.TargetId]); + return edges + .Select(e => _store.GetNode(e.TargetId)) + .OfType>(); } /// @@ -170,22 +149,28 @@ public IEnumerable> GetNeighbors(string nodeId) /// Collection of nodes in BFS order. 
public IEnumerable> BreadthFirstTraversal(string startNodeId, int maxDepth = int.MaxValue) { - if (!_nodes.ContainsKey(startNodeId)) + if (_store.GetNode(startNodeId) == null) yield break; - + var visited = new HashSet(); var queue = new Queue<(string nodeId, int depth)>(); queue.Enqueue((startNodeId, 0)); visited.Add(startNodeId); - + while (queue.Count > 0) { var (nodeId, depth) = queue.Dequeue(); - yield return _nodes[nodeId]; - + var node = _store.GetNode(nodeId); + + // Skip if node was removed between queue add and dequeue + if (node == null) + continue; + + yield return node; + if (depth >= maxDepth) continue; - + foreach (var edge in GetOutgoingEdges(nodeId)) { if (!visited.Contains(edge.TargetId)) @@ -205,20 +190,20 @@ public IEnumerable> BreadthFirstTraversal(string startNodeId, int m /// List of node IDs representing the path, or empty if no path exists. public List FindShortestPath(string startNodeId, string endNodeId) { - if (!_nodes.ContainsKey(startNodeId) || !_nodes.ContainsKey(endNodeId)) + if (_store.GetNode(startNodeId) == null || _store.GetNode(endNodeId) == null) return new List(); - + var visited = new HashSet(); var parent = new Dictionary(); var queue = new Queue(); - + queue.Enqueue(startNodeId); visited.Add(startNodeId); - + while (queue.Count > 0) { var nodeId = queue.Dequeue(); - + if (nodeId == endNodeId) { // Reconstruct path @@ -233,7 +218,7 @@ public List FindShortestPath(string startNodeId, string endNodeId) path.Reverse(); return path; } - + foreach (var edge in GetOutgoingEdges(nodeId)) { if (!visited.Contains(edge.TargetId)) @@ -244,7 +229,7 @@ public List FindShortestPath(string startNodeId, string endNodeId) } } } - + return new List(); // No path found } @@ -257,8 +242,8 @@ public List FindShortestPath(string startNodeId, string endNodeId) public IEnumerable> FindRelatedNodes(string query, int topK = 10) { var queryLower = query.ToLowerInvariant(); - - return _nodes.Values + + return _store.GetAllNodes() .Where(node => { var 
name = node.GetProperty("name") ?? node.Id; @@ -267,34 +252,30 @@ public IEnumerable> FindRelatedNodes(string query, int topK = 10) }) .Take(topK); } - + /// /// Clears all nodes and edges from the graph. /// public void Clear() { - _nodes.Clear(); - _edges.Clear(); - _outgoingEdges.Clear(); - _incomingEdges.Clear(); - _nodesByLabel.Clear(); + _store.Clear(); } - + /// /// Gets all nodes in the graph. /// /// Collection of all nodes. public IEnumerable> GetAllNodes() { - return _nodes.Values; + return _store.GetAllNodes(); } - + /// /// Gets all edges in the graph. /// /// Collection of all edges. public IEnumerable> GetAllEdges() { - return _edges.Values; + return _store.GetAllEdges(); } } diff --git a/src/RetrievalAugmentedGeneration/Graph/MemoryGraphStore.cs b/src/RetrievalAugmentedGeneration/Graph/MemoryGraphStore.cs new file mode 100644 index 000000000..76d226b5a --- /dev/null +++ b/src/RetrievalAugmentedGeneration/Graph/MemoryGraphStore.cs @@ -0,0 +1,323 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using AiDotNet.Interfaces; + +namespace AiDotNet.RetrievalAugmentedGeneration.Graph; + +/// +/// In-memory implementation of using dictionaries for fast lookups. +/// +/// The numeric type used for vector operations. +/// +/// +/// This implementation provides high-performance graph storage entirely in RAM. +/// All operations are O(1) or O(degree) complexity. Data is lost when the application stops. +/// +/// +/// Thread Safety: This class is NOT thread-safe. Callers must ensure proper +/// synchronization when accessing from multiple threads. For thread-safe operations, +/// use external locking or consider using which provides +/// thread-safe access via ConcurrentDictionary. +/// +/// For Beginners: This stores your graph in the computer's memory (RAM). 
+/// +/// Pros: +/// - Very fast (everything in RAM) +/// - Simple to use (no setup required) +/// +/// Cons: +/// - Data lost when app closes +/// - Limited by available RAM +/// - Not thread-safe (single-threaded use only) +/// +/// Good for: +/// - Development and testing +/// - Small to medium graphs (<100K nodes) +/// - Temporary graphs that don't need persistence +/// +/// Not good for: +/// - Production systems requiring persistence +/// - Very large graphs (>1M nodes) +/// - Multi-process or multi-threaded access to the same graph +/// +/// For persistent storage, use FileGraphStore or Neo4jGraphStore instead. +/// +/// +public class MemoryGraphStore : IGraphStore +{ + private readonly Dictionary> _nodes; + private readonly Dictionary> _edges; + private readonly Dictionary> _outgoingEdges; // nodeId -> edge IDs going out + private readonly Dictionary> _incomingEdges; // nodeId -> edge IDs coming in + private readonly Dictionary> _nodesByLabel; // label -> node IDs + + /// + public int NodeCount => _nodes.Count; + + /// + public int EdgeCount => _edges.Count; + + /// + /// Initializes a new instance of the class. 
+ /// + public MemoryGraphStore() + { + _nodes = new Dictionary>(); + _edges = new Dictionary>(); + _outgoingEdges = new Dictionary>(); + _incomingEdges = new Dictionary>(); + _nodesByLabel = new Dictionary>(); + } + + /// + public void AddNode(GraphNode node) + { + if (node == null) + throw new ArgumentNullException(nameof(node)); + + // Remove old label index if node exists with different label + if (_nodes.TryGetValue(node.Id, out var existingNode) && existingNode.Label != node.Label) + { + if (_nodesByLabel.TryGetValue(existingNode.Label, out var oldLabelNodeIds)) + { + oldLabelNodeIds.Remove(node.Id); + if (oldLabelNodeIds.Count == 0) + _nodesByLabel.Remove(existingNode.Label); + } + } + + _nodes[node.Id] = node; + + if (!_nodesByLabel.ContainsKey(node.Label)) + _nodesByLabel[node.Label] = new HashSet(); + _nodesByLabel[node.Label].Add(node.Id); + + if (!_outgoingEdges.ContainsKey(node.Id)) + _outgoingEdges[node.Id] = new HashSet(); + if (!_incomingEdges.ContainsKey(node.Id)) + _incomingEdges[node.Id] = new HashSet(); + } + + /// + public void AddEdge(GraphEdge edge) + { + if (edge == null) + throw new ArgumentNullException(nameof(edge)); + if (!_nodes.ContainsKey(edge.SourceId)) + throw new InvalidOperationException($"Source node '{edge.SourceId}' does not exist"); + if (!_nodes.ContainsKey(edge.TargetId)) + throw new InvalidOperationException($"Target node '{edge.TargetId}' does not exist"); + + _edges[edge.Id] = edge; + _outgoingEdges[edge.SourceId].Add(edge.Id); + _incomingEdges[edge.TargetId].Add(edge.Id); + } + + /// + public GraphNode? GetNode(string nodeId) + { + return _nodes.TryGetValue(nodeId, out var node) ? node : null; + } + + /// + public GraphEdge? GetEdge(string edgeId) + { + return _edges.TryGetValue(edgeId, out var edge) ? 
edge : null; + } + + /// + public bool RemoveNode(string nodeId) + { + if (!_nodes.TryGetValue(nodeId, out var node)) + return false; + + // Remove all outgoing edges + if (_outgoingEdges.TryGetValue(nodeId, out var outgoing)) + { + foreach (var edgeId in outgoing.ToList()) + { + if (_edges.TryGetValue(edgeId, out var edge)) + { + _edges.Remove(edgeId); + _incomingEdges[edge.TargetId].Remove(edgeId); + } + } + _outgoingEdges.Remove(nodeId); + } + + // Remove all incoming edges + if (_incomingEdges.TryGetValue(nodeId, out var incoming)) + { + foreach (var edgeId in incoming.ToList()) + { + if (_edges.TryGetValue(edgeId, out var edge)) + { + _edges.Remove(edgeId); + _outgoingEdges[edge.SourceId].Remove(edgeId); + } + } + _incomingEdges.Remove(nodeId); + } + + // Remove from label index + if (_nodesByLabel.TryGetValue(node.Label, out var nodeIds)) + { + nodeIds.Remove(nodeId); + if (nodeIds.Count == 0) + _nodesByLabel.Remove(node.Label); + } + + // Remove the node itself + _nodes.Remove(nodeId); + return true; + } + + /// + public bool RemoveEdge(string edgeId) + { + if (!_edges.TryGetValue(edgeId, out var edge)) + return false; + + _edges.Remove(edgeId); + _outgoingEdges[edge.SourceId].Remove(edgeId); + _incomingEdges[edge.TargetId].Remove(edgeId); + return true; + } + + /// + public IEnumerable> GetOutgoingEdges(string nodeId) + { + if (!_outgoingEdges.TryGetValue(nodeId, out var edgeIds)) + return Enumerable.Empty>(); + + // Use TryGetValue to safely handle edges that may have been removed + return edgeIds + .Select(id => _edges.TryGetValue(id, out var edge) ? edge : null) + .OfType>(); + } + + /// + public IEnumerable> GetIncomingEdges(string nodeId) + { + if (!_incomingEdges.TryGetValue(nodeId, out var edgeIds)) + return Enumerable.Empty>(); + + // Use TryGetValue to safely handle edges that may have been removed + return edgeIds + .Select(id => _edges.TryGetValue(id, out var edge) ? 
edge : null) + .OfType>(); + } + + /// + public IEnumerable> GetNodesByLabel(string label) + { + if (!_nodesByLabel.TryGetValue(label, out var nodeIds)) + return Enumerable.Empty>(); + + // Use TryGetValue to safely handle nodes that may have been removed + return nodeIds + .Select(id => _nodes.TryGetValue(id, out var node) ? node : null) + .OfType>(); + } + + /// + public IEnumerable> GetAllNodes() + { + return _nodes.Values; + } + + /// + public IEnumerable> GetAllEdges() + { + return _edges.Values; + } + + /// + public void Clear() + { + _nodes.Clear(); + _edges.Clear(); + _outgoingEdges.Clear(); + _incomingEdges.Clear(); + _nodesByLabel.Clear(); + } + + // Async methods (for MemoryGraphStore, these wrap synchronous operations) + + /// + public Task AddNodeAsync(GraphNode node) + { + AddNode(node); + return Task.CompletedTask; + } + + /// + public Task AddEdgeAsync(GraphEdge edge) + { + AddEdge(edge); + return Task.CompletedTask; + } + + /// + public Task?> GetNodeAsync(string nodeId) + { + return Task.FromResult(GetNode(nodeId)); + } + + /// + public Task?> GetEdgeAsync(string edgeId) + { + return Task.FromResult(GetEdge(edgeId)); + } + + /// + public Task RemoveNodeAsync(string nodeId) + { + return Task.FromResult(RemoveNode(nodeId)); + } + + /// + public Task RemoveEdgeAsync(string edgeId) + { + return Task.FromResult(RemoveEdge(edgeId)); + } + + /// + public Task>> GetOutgoingEdgesAsync(string nodeId) + { + return Task.FromResult(GetOutgoingEdges(nodeId)); + } + + /// + public Task>> GetIncomingEdgesAsync(string nodeId) + { + return Task.FromResult(GetIncomingEdges(nodeId)); + } + + /// + public Task>> GetNodesByLabelAsync(string label) + { + return Task.FromResult(GetNodesByLabel(label)); + } + + /// + public Task>> GetAllNodesAsync() + { + return Task.FromResult(GetAllNodes()); + } + + /// + public Task>> GetAllEdgesAsync() + { + return Task.FromResult(GetAllEdges()); + } + + /// + public Task ClearAsync() + { + Clear(); + return Task.CompletedTask; + } +} 
diff --git a/src/RetrievalAugmentedGeneration/Graph/WriteAheadLog.cs b/src/RetrievalAugmentedGeneration/Graph/WriteAheadLog.cs new file mode 100644 index 000000000..0d9f278cc --- /dev/null +++ b/src/RetrievalAugmentedGeneration/Graph/WriteAheadLog.cs @@ -0,0 +1,396 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; +using Newtonsoft.Json; + +namespace AiDotNet.RetrievalAugmentedGeneration.Graph; + +/// +/// Write-Ahead Log (WAL) for ensuring ACID properties and crash recovery. +/// +/// +/// +/// A Write-Ahead Log records all changes before they're applied to the main data files. +/// This ensures data integrity and enables recovery after crashes. +/// +/// For Beginners: Think of WAL like a ship's log or diary. +/// +/// Before making any change to your graph: +/// 1. Write what you're about to do in the log (WAL) +/// 2. Make sure the log is saved to disk +/// 3. Then make the actual change +/// +/// If the system crashes: +/// - The log shows what was happening +/// - You can replay the log to restore the graph +/// - No data is lost! +/// +/// This is how databases ensure "durability" - the D in ACID. +/// +/// Real-world analogy: +/// - Bank transaction: First log "transfer $100", then move the money +/// - If crash happens after logging but before transfer, replay the log on restart +/// - Money isn't lost! +/// +/// +public class WriteAheadLog : IDisposable +{ + private readonly string _walFilePath; + private StreamWriter? _writer; + private long _currentTransactionId; + private readonly object _lock = new object(); + private bool _disposed; + + /// + /// Gets the current transaction ID. + /// + public long CurrentTransactionId => _currentTransactionId; + + /// + /// Initializes a new instance of the class. + /// + /// Path to the WAL file. + public WriteAheadLog(string walFilePath) + { + _walFilePath = walFilePath ?? 
throw new ArgumentNullException(nameof(walFilePath)); + + // Ensure directory exists + var directory = Path.GetDirectoryName(walFilePath); + if (!string.IsNullOrEmpty(directory) && !Directory.Exists(directory)) + Directory.CreateDirectory(directory); + + // Restore transaction ID from existing log to prevent duplicates after restart + _currentTransactionId = RestoreLastTransactionId(); + + // Open WAL file for append + _writer = new StreamWriter(_walFilePath, append: true, Encoding.UTF8) + { + AutoFlush = true // Critical: flush immediately for durability + }; + } + + /// + /// Restores the last transaction ID from an existing WAL file. + /// + /// The maximum transaction ID found in the log, or 0 if no log exists. + private long RestoreLastTransactionId() + { + if (!File.Exists(_walFilePath)) + return 0; + + long maxId = 0; + try + { + foreach (var line in File.ReadLines(_walFilePath)) + { + try + { + var entry = JsonConvert.DeserializeObject(line); + if (entry != null && entry.TransactionId > maxId) + maxId = entry.TransactionId; + } + catch (JsonException) + { + // Skip malformed lines + } + } + } + catch (IOException) + { + // If we can't read the file, start from 0 + return 0; + } + return maxId; + } + + /// + /// Logs a node addition operation. + /// + /// The numeric type. + /// The node being added. + /// The transaction ID for this operation. + public long LogAddNode(GraphNode node) + { + lock (_lock) + { + var txnId = ++_currentTransactionId; + var entry = new WALEntry + { + TransactionId = txnId, + Timestamp = DateTime.UtcNow, + OperationType = WALOperationType.AddNode, + NodeId = node.Id, + Data = JsonConvert.SerializeObject(node) + }; + + WriteEntry(entry); + return txnId; + } + } + + /// + /// Logs an edge addition operation. + /// + /// The numeric type. + /// The edge being added. + /// The transaction ID for this operation. 
+ public long LogAddEdge(GraphEdge edge) + { + lock (_lock) + { + var txnId = ++_currentTransactionId; + var entry = new WALEntry + { + TransactionId = txnId, + Timestamp = DateTime.UtcNow, + OperationType = WALOperationType.AddEdge, + EdgeId = edge.Id, + Data = JsonConvert.SerializeObject(edge) + }; + + WriteEntry(entry); + return txnId; + } + } + + /// + /// Logs a node removal operation. + /// + /// The ID of the node being removed. + /// The transaction ID for this operation. + public long LogRemoveNode(string nodeId) + { + lock (_lock) + { + var txnId = ++_currentTransactionId; + var entry = new WALEntry + { + TransactionId = txnId, + Timestamp = DateTime.UtcNow, + OperationType = WALOperationType.RemoveNode, + NodeId = nodeId + }; + + WriteEntry(entry); + return txnId; + } + } + + /// + /// Logs an edge removal operation. + /// + /// The ID of the edge being removed. + /// The transaction ID for this operation. + public long LogRemoveEdge(string edgeId) + { + lock (_lock) + { + var txnId = ++_currentTransactionId; + var entry = new WALEntry + { + TransactionId = txnId, + Timestamp = DateTime.UtcNow, + OperationType = WALOperationType.RemoveEdge, + EdgeId = edgeId + }; + + WriteEntry(entry); + return txnId; + } + } + + /// + /// Logs a checkpoint (all data successfully persisted to disk). + /// + /// The transaction ID for this checkpoint. + public long LogCheckpoint() + { + lock (_lock) + { + var txnId = ++_currentTransactionId; + var entry = new WALEntry + { + TransactionId = txnId, + Timestamp = DateTime.UtcNow, + OperationType = WALOperationType.Checkpoint + }; + + WriteEntry(entry); + return txnId; + } + } + + /// + /// Reads all WAL entries from the log file. + /// + /// List of WAL entries in order. 
+ public List ReadLog() + { + var entries = new List(); + + if (!File.Exists(_walFilePath)) + return entries; + + lock (_lock) + { + // Flush writer to ensure all entries are on disk + _writer?.Flush(); + + // Use FileShare.ReadWrite to allow reading while writer is still open (Windows compatibility) + using var fileStream = new FileStream(_walFilePath, FileMode.Open, FileAccess.Read, FileShare.ReadWrite); + using var reader = new StreamReader(fileStream, Encoding.UTF8); + string? line; + while ((line = reader.ReadLine()) != null) + { + try + { + var entry = JsonConvert.DeserializeObject(line); + if (entry != null) + entries.Add(entry); + } + catch (JsonSerializationException) + { + // Skip corrupted entries + } + } + } + + return entries; + } + + /// + /// Truncates the WAL after a successful checkpoint. + /// + /// + /// This removes old entries that have been successfully applied, + /// keeping the WAL file from growing indefinitely. + /// + public void Truncate() + { + lock (_lock) + { + _writer?.Close(); + _writer?.Dispose(); + + if (File.Exists(_walFilePath)) + File.Delete(_walFilePath); + + _writer = new StreamWriter(_walFilePath, append: true, Encoding.UTF8) + { + AutoFlush = true + }; + + _currentTransactionId = 0; + } + } + + /// + /// Writes a WAL entry to the log file. + /// + private void WriteEntry(WALEntry entry) + { + var json = JsonConvert.SerializeObject(entry); + _writer?.WriteLine(json); + // AutoFlush ensures it's written to disk immediately + } + + /// + /// Disposes the WAL, ensuring all entries are flushed. + /// + public void Dispose() + { + if (_disposed) + return; + + lock (_lock) + { + _writer?.Flush(); + _writer?.Close(); + _writer?.Dispose(); + _disposed = true; + } + } +} + +/// +/// Represents a single entry in the Write-Ahead Log. +/// +public class WALEntry +{ + /// + /// Gets or sets the transaction ID. + /// + public long TransactionId { get; set; } + + /// + /// Gets or sets the timestamp of the operation. 
+ /// + public DateTime Timestamp { get; set; } + + /// + /// Gets or sets the type of operation. + /// + public WALOperationType OperationType { get; set; } + + /// + /// Gets or sets the node ID (for node operations). + /// + public string? NodeId { get; set; } + + /// + /// Gets or sets the edge ID (for edge operations). + /// + public string? EdgeId { get; set; } + + /// + /// Gets or sets the serialized data for the operation. + /// + public string? Data { get; set; } +} + +/// +/// Types of operations that can be logged in the WAL. +/// +public enum WALOperationType +{ + /// + /// Add a node to the graph. + /// + AddNode, + + /// + /// Add an edge to the graph. + /// + AddEdge, + + /// + /// Remove a node from the graph. + /// + RemoveNode, + + /// + /// Remove an edge from the graph. + /// + RemoveEdge, + + /// + /// Checkpoint - all operations up to this point are persisted. + /// + Checkpoint, + + /// + /// Begin a transaction. + /// + BeginTransaction, + + /// + /// Commit a transaction. + /// + CommitTransaction, + + /// + /// Rollback a transaction. 
+ /// + RollbackTransaction +} diff --git a/tests/AiDotNet.Tests/UnitTests/RetrievalAugmentedGeneration/BTreeIndexTests.cs b/tests/AiDotNet.Tests/UnitTests/RetrievalAugmentedGeneration/BTreeIndexTests.cs new file mode 100644 index 000000000..df2d7643d --- /dev/null +++ b/tests/AiDotNet.Tests/UnitTests/RetrievalAugmentedGeneration/BTreeIndexTests.cs @@ -0,0 +1,527 @@ +using System; +using System.IO; +using System.Linq; +using AiDotNet.RetrievalAugmentedGeneration.Graph; +using Xunit; + +namespace AiDotNetTests.UnitTests.RetrievalAugmentedGeneration +{ + public class BTreeIndexTests : IDisposable + { + private readonly string _testDirectory; + + public BTreeIndexTests() + { + _testDirectory = Path.Combine(Path.GetTempPath(), "btree_tests_" + Guid.NewGuid().ToString("N")); + Directory.CreateDirectory(_testDirectory); + } + + public void Dispose() + { + if (Directory.Exists(_testDirectory)) + Directory.Delete(_testDirectory, true); + } + + private string GetTestIndexPath() + { + return Path.Combine(_testDirectory, $"test_index_{Guid.NewGuid():N}.db"); + } + + #region Constructor Tests + + [Fact] + public void Constructor_WithValidPath_CreatesIndex() + { + // Arrange & Act + var indexPath = GetTestIndexPath(); + using var index = new BTreeIndex(indexPath); + + // Assert + Assert.Equal(0, index.Count); + } + + [Fact] + public void Constructor_WithNullPath_ThrowsArgumentNullException() + { + // Act & Assert + Assert.Throws(() => new BTreeIndex(null!)); + } + + [Fact] + public void Constructor_WithNonexistentDirectory_CreatesDirectory() + { + // Arrange + var nestedPath = Path.Combine(_testDirectory, "nested", "path", "index.db"); + + // Act + using var index = new BTreeIndex(nestedPath); + index.Add("key1", 100); + index.Flush(); + + // Assert + Assert.True(File.Exists(nestedPath)); + } + + #endregion + + #region Add Tests + + [Fact] + public void Add_WithValidKeyAndOffset_IncreasesCount() + { + // Arrange + var indexPath = GetTestIndexPath(); + using var index = new 
BTreeIndex(indexPath); + + // Act + index.Add("key1", 1024); + + // Assert + Assert.Equal(1, index.Count); + } + + [Fact] + public void Add_WithNullKey_ThrowsArgumentException() + { + // Arrange + var indexPath = GetTestIndexPath(); + using var index = new BTreeIndex(indexPath); + + // Act & Assert + Assert.Throws(() => index.Add(null!, 100)); + } + + [Fact] + public void Add_WithEmptyKey_ThrowsArgumentException() + { + // Arrange + var indexPath = GetTestIndexPath(); + using var index = new BTreeIndex(indexPath); + + // Act & Assert + Assert.Throws(() => index.Add("", 100)); + } + + [Fact] + public void Add_WithWhitespaceKey_ThrowsArgumentException() + { + // Arrange + var indexPath = GetTestIndexPath(); + using var index = new BTreeIndex(indexPath); + + // Act & Assert + Assert.Throws(() => index.Add(" ", 100)); + } + + [Fact] + public void Add_WithNegativeOffset_ThrowsArgumentException() + { + // Arrange + var indexPath = GetTestIndexPath(); + using var index = new BTreeIndex(indexPath); + + // Act & Assert + Assert.Throws(() => index.Add("key1", -1)); + } + + [Fact] + public void Add_WithDuplicateKey_UpdatesOffset() + { + // Arrange + var indexPath = GetTestIndexPath(); + using var index = new BTreeIndex(indexPath); + index.Add("key1", 1024); + + // Act + index.Add("key1", 2048); + + // Assert + Assert.Equal(1, index.Count); + Assert.Equal(2048, index.Get("key1")); + } + + #endregion + + #region Get Tests + + [Fact] + public void Get_WithExistingKey_ReturnsCorrectOffset() + { + // Arrange + var indexPath = GetTestIndexPath(); + using var index = new BTreeIndex(indexPath); + index.Add("key1", 1024); + index.Add("key2", 2048); + + // Act + var offset1 = index.Get("key1"); + var offset2 = index.Get("key2"); + + // Assert + Assert.Equal(1024, offset1); + Assert.Equal(2048, offset2); + } + + [Fact] + public void Get_WithNonexistentKey_ReturnsNegativeOne() + { + // Arrange + var indexPath = GetTestIndexPath(); + using var index = new BTreeIndex(indexPath); + + // Act 
+ var offset = index.Get("nonexistent"); + + // Assert + Assert.Equal(-1, offset); + } + + [Fact] + public void Get_WithNullKey_ReturnsNegativeOne() + { + // Arrange + var indexPath = GetTestIndexPath(); + using var index = new BTreeIndex(indexPath); + + // Act + var offset = index.Get(null!); + + // Assert + Assert.Equal(-1, offset); + } + + #endregion + + #region Contains Tests + + [Fact] + public void Contains_WithExistingKey_ReturnsTrue() + { + // Arrange + var indexPath = GetTestIndexPath(); + using var index = new BTreeIndex(indexPath); + index.Add("key1", 1024); + + // Act + var contains = index.Contains("key1"); + + // Assert + Assert.True(contains); + } + + [Fact] + public void Contains_WithNonexistentKey_ReturnsFalse() + { + // Arrange + var indexPath = GetTestIndexPath(); + using var index = new BTreeIndex(indexPath); + + // Act + var contains = index.Contains("nonexistent"); + + // Assert + Assert.False(contains); + } + + [Fact] + public void Contains_WithNullKey_ReturnsFalse() + { + // Arrange + var indexPath = GetTestIndexPath(); + using var index = new BTreeIndex(indexPath); + + // Act + var contains = index.Contains(null!); + + // Assert + Assert.False(contains); + } + + #endregion + + #region Remove Tests + + [Fact] + public void Remove_WithExistingKey_RemovesKeyAndReturnsTrue() + { + // Arrange + var indexPath = GetTestIndexPath(); + using var index = new BTreeIndex(indexPath); + index.Add("key1", 1024); + + // Act + var result = index.Remove("key1"); + + // Assert + Assert.True(result); + Assert.Equal(0, index.Count); + Assert.False(index.Contains("key1")); + } + + [Fact] + public void Remove_WithNonexistentKey_ReturnsFalse() + { + // Arrange + var indexPath = GetTestIndexPath(); + using var index = new BTreeIndex(indexPath); + + // Act + var result = index.Remove("nonexistent"); + + // Assert + Assert.False(result); + } + + [Fact] + public void Remove_WithNullKey_ReturnsFalse() + { + // Arrange + var indexPath = GetTestIndexPath(); + using var 
index = new BTreeIndex(indexPath); + + // Act + var result = index.Remove(null!); + + // Assert + Assert.False(result); + } + + #endregion + + #region GetAllKeys Tests + + [Fact] + public void GetAllKeys_WithMultipleKeys_ReturnsAllKeys() + { + // Arrange + var indexPath = GetTestIndexPath(); + using var index = new BTreeIndex(indexPath); + index.Add("key1", 1024); + index.Add("key2", 2048); + index.Add("key3", 3072); + + // Act + var keys = index.GetAllKeys().ToList(); + + // Assert + Assert.Equal(3, keys.Count); + Assert.Contains("key1", keys); + Assert.Contains("key2", keys); + Assert.Contains("key3", keys); + } + + [Fact] + public void GetAllKeys_WithEmptyIndex_ReturnsEmpty() + { + // Arrange + var indexPath = GetTestIndexPath(); + using var index = new BTreeIndex(indexPath); + + // Act + var keys = index.GetAllKeys(); + + // Assert + Assert.Empty(keys); + } + + #endregion + + #region Clear Tests + + [Fact] + public void Clear_RemovesAllEntries() + { + // Arrange + var indexPath = GetTestIndexPath(); + using var index = new BTreeIndex(indexPath); + index.Add("key1", 1024); + index.Add("key2", 2048); + + // Act + index.Clear(); + + // Assert + Assert.Equal(0, index.Count); + Assert.Empty(index.GetAllKeys()); + } + + #endregion + + #region Flush and Persistence Tests + + [Fact] + public void Flush_SavesIndexToDisk() + { + // Arrange + var indexPath = GetTestIndexPath(); + using (var index = new BTreeIndex(indexPath)) + { + index.Add("key1", 1024); + index.Add("key2", 2048); + + // Act + index.Flush(); + } + + // Assert + Assert.True(File.Exists(indexPath)); + } + + [Fact] + public void Constructor_WithExistingIndexFile_LoadsData() + { + // Arrange + var indexPath = GetTestIndexPath(); + + // Create and populate index + using (var index = new BTreeIndex(indexPath)) + { + index.Add("key1", 1024); + index.Add("key2", 2048); + index.Add("key3", 3072); + index.Flush(); + } + + // Act - Create new index from same file + using var loadedIndex = new BTreeIndex(indexPath); 
+ + // Assert + Assert.Equal(3, loadedIndex.Count); + Assert.Equal(1024, loadedIndex.Get("key1")); + Assert.Equal(2048, loadedIndex.Get("key2")); + Assert.Equal(3072, loadedIndex.Get("key3")); + } + + [Fact] + public void Dispose_FlushesDataToDisk() + { + // Arrange + var indexPath = GetTestIndexPath(); + + // Create index and add data without explicit flush + using (var index = new BTreeIndex(indexPath)) + { + index.Add("key1", 1024); + index.Add("key2", 2048); + // Dispose will be called here + } + + // Act - Load from disk + using var loadedIndex = new BTreeIndex(indexPath); + + // Assert + Assert.Equal(2, loadedIndex.Count); + Assert.Equal(1024, loadedIndex.Get("key1")); + Assert.Equal(2048, loadedIndex.Get("key2")); + } + + [Fact] + public void Flush_WithNoChanges_DoesNotModifyFileContent() + { + // Arrange + var indexPath = GetTestIndexPath(); + byte[] initialContent; + + // Create index with some data and flush to establish baseline + using (var index = new BTreeIndex(indexPath)) + { + index.Add("key1", 1024); + index.Flush(); + } + + // Read the initial file content + initialContent = File.ReadAllBytes(indexPath); + + // Act - Open existing index and flush without changes + using (var index = new BTreeIndex(indexPath)) + { + // No changes made, just flush + index.Flush(); + } + + // Assert - File content should be identical + var currentContent = File.ReadAllBytes(indexPath); + Assert.Equal(initialContent, currentContent); + } + + #endregion + + #region Integration Tests + + [Fact] + public void ComplexScenario_WithMultipleOperations_MaintainsConsistency() + { + // Arrange + var indexPath = GetTestIndexPath(); + + using (var index = new BTreeIndex(indexPath)) + { + // Add entries + index.Add("alice", 1000); + index.Add("bob", 2000); + index.Add("charlie", 3000); + Assert.Equal(3, index.Count); + + // Update an entry + index.Add("bob", 2500); + Assert.Equal(3, index.Count); + Assert.Equal(2500, index.Get("bob")); + + // Remove an entry + 
index.Remove("charlie"); + Assert.Equal(2, index.Count); + + // Add new entry + index.Add("diana", 4000); + Assert.Equal(3, index.Count); + + index.Flush(); + } + + // Reload and verify + using (var loadedIndex = new BTreeIndex(indexPath)) + { + Assert.Equal(3, loadedIndex.Count); + Assert.Equal(1000, loadedIndex.Get("alice")); + Assert.Equal(2500, loadedIndex.Get("bob")); + Assert.Equal(-1, loadedIndex.Get("charlie")); + Assert.Equal(4000, loadedIndex.Get("diana")); + } + } + + [Fact] + public void LargeIndex_WithThousandsOfEntries_PerformsCorrectly() + { + // Arrange + var indexPath = GetTestIndexPath(); + const int entryCount = 10000; + + using (var index = new BTreeIndex(indexPath)) + { + // Add many entries + for (int i = 0; i < entryCount; i++) + { + index.Add($"key_{i:D6}", i * 1024L); + } + + // Act + index.Flush(); + + // Assert - Verify random samples + Assert.Equal(entryCount, index.Count); + Assert.Equal(0, index.Get("key_000000")); + Assert.Equal(5000 * 1024L, index.Get("key_005000")); + Assert.Equal(9999 * 1024L, index.Get("key_009999")); + } + + // Reload and verify + using (var loadedIndex = new BTreeIndex(indexPath)) + { + Assert.Equal(entryCount, loadedIndex.Count); + Assert.Equal(1234 * 1024L, loadedIndex.Get("key_001234")); + } + } + + #endregion + } +} diff --git a/tests/AiDotNet.Tests/UnitTests/RetrievalAugmentedGeneration/FileGraphStoreTests.cs b/tests/AiDotNet.Tests/UnitTests/RetrievalAugmentedGeneration/FileGraphStoreTests.cs new file mode 100644 index 000000000..feb4e238d --- /dev/null +++ b/tests/AiDotNet.Tests/UnitTests/RetrievalAugmentedGeneration/FileGraphStoreTests.cs @@ -0,0 +1,621 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using AiDotNet.RetrievalAugmentedGeneration.Graph; +using Xunit; + +namespace AiDotNetTests.UnitTests.RetrievalAugmentedGeneration +{ + public class FileGraphStoreTests : IDisposable + { + private readonly string _testDirectory; + + public FileGraphStoreTests() + { 
+ _testDirectory = Path.Combine(Path.GetTempPath(), "filegraph_tests_" + Guid.NewGuid().ToString("N")); + Directory.CreateDirectory(_testDirectory); + } + + public void Dispose() + { + if (Directory.Exists(_testDirectory)) + Directory.Delete(_testDirectory, true); + } + + private string GetTestStoragePath() + { + return Path.Combine(_testDirectory, Guid.NewGuid().ToString("N")); + } + + private GraphNode CreateTestNode(string id, string label, Dictionary? properties = null) + { + var node = new GraphNode(id, label); + if (properties != null) + { + foreach (var kvp in properties) + { + node.SetProperty(kvp.Key, kvp.Value); + } + } + return node; + } + + private GraphEdge CreateTestEdge(string sourceId, string relationType, string targetId, double weight = 1.0) + { + return new GraphEdge(sourceId, targetId, relationType, weight); + } + + #region Constructor Tests + + [Fact] + public void Constructor_WithValidPath_CreatesStoreAndDirectory() + { + // Arrange + var storagePath = GetTestStoragePath(); + + // Act + using var store = new FileGraphStore(storagePath); + + // Assert + Assert.True(Directory.Exists(storagePath)); + Assert.Equal(0, store.NodeCount); + Assert.Equal(0, store.EdgeCount); + } + + [Fact] + public void Constructor_WithNullPath_ThrowsArgumentException() + { + // Act & Assert + Assert.Throws(() => new FileGraphStore(null!)); + } + + [Fact] + public void Constructor_WithEmptyPath_ThrowsArgumentException() + { + // Act & Assert + Assert.Throws(() => new FileGraphStore("")); + } + + [Fact] + public void Constructor_CreatesRequiredFiles() + { + // Arrange + var storagePath = GetTestStoragePath(); + + // Act - use explicit using block to ensure dispose before assertions + using (var store = new FileGraphStore(storagePath)) + { + var node = CreateTestNode("node1", "PERSON"); + store.AddNode(node); + } // Dispose called here - flushes to disk + + // Assert + Assert.True(File.Exists(Path.Combine(storagePath, "node_index.db"))); + 
Assert.True(File.Exists(Path.Combine(storagePath, "nodes.dat"))); + } + + #endregion + + #region AddNode Tests + + [Fact] + public void AddNode_WithValidNode_IncreasesNodeCount() + { + // Arrange + var storagePath = GetTestStoragePath(); + using var store = new FileGraphStore(storagePath); + var node = CreateTestNode("node1", "PERSON"); + + // Act + store.AddNode(node); + + // Assert + Assert.Equal(1, store.NodeCount); + } + + [Fact] + public void AddNode_WithNullNode_ThrowsArgumentNullException() + { + // Arrange + var storagePath = GetTestStoragePath(); + using var store = new FileGraphStore(storagePath); + + // Act & Assert + Assert.Throws(() => store.AddNode(null!)); + } + + [Fact] + public void AddNode_WithProperties_PersistsProperties() + { + // Arrange + var storagePath = GetTestStoragePath(); + using var store = new FileGraphStore(storagePath); + var properties = new Dictionary + { + { "name", "Alice" }, + { "age", 30 }, + { "active", true } + }; + var node = CreateTestNode("node1", "PERSON", properties); + + // Act + store.AddNode(node); + var retrieved = store.GetNode("node1"); + + // Assert + Assert.NotNull(retrieved); + Assert.Equal("Alice", retrieved.GetProperty("name")); + Assert.Equal(30, retrieved.GetProperty("age")); + Assert.True(retrieved.GetProperty("active")); + } + + #endregion + + #region AddEdge Tests + + [Fact] + public void AddEdge_WithValidEdge_IncreasesEdgeCount() + { + // Arrange + var storagePath = GetTestStoragePath(); + using var store = new FileGraphStore(storagePath); + var node1 = CreateTestNode("node1", "PERSON"); + var node2 = CreateTestNode("node2", "COMPANY"); + store.AddNode(node1); + store.AddNode(node2); + var edge = CreateTestEdge("node1", "WORKS_AT", "node2"); + + // Act + store.AddEdge(edge); + + // Assert + Assert.Equal(1, store.EdgeCount); + } + + [Fact] + public void AddEdge_WithNullEdge_ThrowsArgumentNullException() + { + // Arrange + var storagePath = GetTestStoragePath(); + using var store = new 
FileGraphStore(storagePath); + + // Act & Assert + Assert.Throws(() => store.AddEdge(null!)); + } + + [Fact] + public void AddEdge_WithNonexistentSourceNode_ThrowsInvalidOperationException() + { + // Arrange + var storagePath = GetTestStoragePath(); + using var store = new FileGraphStore(storagePath); + var node = CreateTestNode("node1", "PERSON"); + store.AddNode(node); + var edge = CreateTestEdge("nonexistent", "KNOWS", "node1"); + + // Act & Assert + var exception = Assert.Throws(() => store.AddEdge(edge)); + Assert.Contains("Source node 'nonexistent' does not exist", exception.Message); + } + + #endregion + + #region GetNode Tests + + [Fact] + public void GetNode_WithExistingId_ReturnsNode() + { + // Arrange + var storagePath = GetTestStoragePath(); + using var store = new FileGraphStore(storagePath); + var node = CreateTestNode("node1", "PERSON"); + store.AddNode(node); + + // Act + var retrieved = store.GetNode("node1"); + + // Assert + Assert.NotNull(retrieved); + Assert.Equal("node1", retrieved.Id); + Assert.Equal("PERSON", retrieved.Label); + } + + [Fact] + public void GetNode_WithNonexistentId_ReturnsNull() + { + // Arrange + var storagePath = GetTestStoragePath(); + using var store = new FileGraphStore(storagePath); + + // Act + var retrieved = store.GetNode("nonexistent"); + + // Assert + Assert.Null(retrieved); + } + + #endregion + + #region Persistence Tests + + [Fact] + public void Persistence_NodesAndEdges_SurviveRestart() + { + // Arrange + var storagePath = GetTestStoragePath(); + + // Create and populate graph + using (var store = new FileGraphStore(storagePath)) + { + var alice = CreateTestNode("alice", "PERSON", new Dictionary { { "name", "Alice" } }); + var bob = CreateTestNode("bob", "PERSON", new Dictionary { { "name", "Bob" } }); + var acme = CreateTestNode("acme", "COMPANY", new Dictionary { { "name", "Acme Corp" } }); + + store.AddNode(alice); + store.AddNode(bob); + store.AddNode(acme); + + var edge1 = CreateTestEdge("alice", "KNOWS", 
"bob", 0.9); + var edge2 = CreateTestEdge("alice", "WORKS_AT", "acme", 1.0); + store.AddEdge(edge1); + store.AddEdge(edge2); + + // Dispose to flush + } + + // Act - Reload from disk + using (var reloadedStore = new FileGraphStore(storagePath)) + { + // Assert - Verify nodes + Assert.Equal(3, reloadedStore.NodeCount); + Assert.Equal(2, reloadedStore.EdgeCount); + + var alice = reloadedStore.GetNode("alice"); + Assert.NotNull(alice); + Assert.Equal("Alice", alice.GetProperty("name")); + + var bob = reloadedStore.GetNode("bob"); + Assert.NotNull(bob); + Assert.Equal("Bob", bob.GetProperty("name")); + + // Verify edges + var aliceOutgoing = reloadedStore.GetOutgoingEdges("alice").ToList(); + Assert.Equal(2, aliceOutgoing.Count); + Assert.Contains(aliceOutgoing, e => e.TargetId == "bob"); + Assert.Contains(aliceOutgoing, e => e.TargetId == "acme"); + } + } + + [Fact] + public void Persistence_LabelIndices_RebuildCorrectly() + { + // Arrange + var storagePath = GetTestStoragePath(); + + using (var store = new FileGraphStore(storagePath)) + { + store.AddNode(CreateTestNode("person1", "PERSON")); + store.AddNode(CreateTestNode("person2", "PERSON")); + store.AddNode(CreateTestNode("company1", "COMPANY")); + } + + // Act - Reload + using (var reloadedStore = new FileGraphStore(storagePath)) + { + // Assert + var persons = reloadedStore.GetNodesByLabel("PERSON").ToList(); + var companies = reloadedStore.GetNodesByLabel("COMPANY").ToList(); + + Assert.Equal(2, persons.Count); + Assert.Single(companies); + } + } + + [Fact] + public void Persistence_EdgeIndices_RebuildCorrectly() + { + // Arrange + var storagePath = GetTestStoragePath(); + + using (var store = new FileGraphStore(storagePath)) + { + store.AddNode(CreateTestNode("node1", "PERSON")); + store.AddNode(CreateTestNode("node2", "PERSON")); + store.AddNode(CreateTestNode("node3", "COMPANY")); + + store.AddEdge(CreateTestEdge("node1", "KNOWS", "node2")); + store.AddEdge(CreateTestEdge("node1", "WORKS_AT", "node3")); + 
store.AddEdge(CreateTestEdge("node2", "WORKS_AT", "node3")); + } + + // Act - Reload + using (var reloadedStore = new FileGraphStore(storagePath)) + { + // Assert - Check outgoing edges + var node1Outgoing = reloadedStore.GetOutgoingEdges("node1").ToList(); + Assert.Equal(2, node1Outgoing.Count); + + // Assert - Check incoming edges + var node3Incoming = reloadedStore.GetIncomingEdges("node3").ToList(); + Assert.Equal(2, node3Incoming.Count); + } + } + + #endregion + + #region RemoveNode Tests + + [Fact] + public void RemoveNode_WithExistingNode_RemovesNodeAndReturnsTrue() + { + // Arrange + var storagePath = GetTestStoragePath(); + using var store = new FileGraphStore(storagePath); + var node = CreateTestNode("node1", "PERSON"); + store.AddNode(node); + + // Act + var result = store.RemoveNode("node1"); + + // Assert + Assert.True(result); + Assert.Equal(0, store.NodeCount); + Assert.Null(store.GetNode("node1")); + } + + [Fact] + public void RemoveNode_RemovesAllConnectedEdges() + { + // Arrange + var storagePath = GetTestStoragePath(); + using var store = new FileGraphStore(storagePath); + var node1 = CreateTestNode("node1", "PERSON"); + var node2 = CreateTestNode("node2", "PERSON"); + var node3 = CreateTestNode("node3", "COMPANY"); + store.AddNode(node1); + store.AddNode(node2); + store.AddNode(node3); + + var edge1 = CreateTestEdge("node1", "KNOWS", "node2"); + var edge2 = CreateTestEdge("node2", "WORKS_AT", "node3"); + var edge3 = CreateTestEdge("node1", "WORKS_AT", "node3"); + store.AddEdge(edge1); + store.AddEdge(edge2); + store.AddEdge(edge3); + + // Act + store.RemoveNode("node1"); + + // Assert + Assert.Equal(2, store.NodeCount); + Assert.Equal(1, store.EdgeCount); + Assert.Null(store.GetEdge(edge1.Id)); + Assert.Null(store.GetEdge(edge3.Id)); + Assert.NotNull(store.GetEdge(edge2.Id)); + } + + [Fact] + public void RemoveNode_Persists_AfterReload() + { + // Arrange + var storagePath = GetTestStoragePath(); + + using (var store = new 
FileGraphStore(storagePath)) + { + store.AddNode(CreateTestNode("node1", "PERSON")); + store.AddNode(CreateTestNode("node2", "PERSON")); + store.RemoveNode("node1"); + } + + // Act - Reload + using (var reloadedStore = new FileGraphStore(storagePath)) + { + // Assert + Assert.Equal(1, reloadedStore.NodeCount); + Assert.Null(reloadedStore.GetNode("node1")); + Assert.NotNull(reloadedStore.GetNode("node2")); + } + } + + #endregion + + #region Clear Tests + + [Fact] + public void Clear_RemovesAllNodesAndEdges() + { + // Arrange + var storagePath = GetTestStoragePath(); + using var store = new FileGraphStore(storagePath); + var node1 = CreateTestNode("node1", "PERSON"); + var node2 = CreateTestNode("node2", "COMPANY"); + store.AddNode(node1); + store.AddNode(node2); + var edge = CreateTestEdge("node1", "WORKS_AT", "node2"); + store.AddEdge(edge); + + // Act + store.Clear(); + + // Assert + Assert.Equal(0, store.NodeCount); + Assert.Equal(0, store.EdgeCount); + Assert.Empty(store.GetAllNodes()); + Assert.Empty(store.GetAllEdges()); + } + + [Fact] + public void Clear_DeletesDataFiles() + { + // Arrange + var storagePath = GetTestStoragePath(); + using var store = new FileGraphStore(storagePath); + store.AddNode(CreateTestNode("node1", "PERSON")); + store.Dispose(); + + var nodesFile = Path.Combine(storagePath, "nodes.dat"); + Assert.True(File.Exists(nodesFile)); + + // Act + using var store2 = new FileGraphStore(storagePath); + store2.Clear(); + + // Assert + Assert.False(File.Exists(nodesFile)); + } + + #endregion + + #region Integration Tests + + [Fact] + public void ComplexGraph_WithMultipleOperations_MaintainsConsistency() + { + // Arrange + var storagePath = GetTestStoragePath(); + + using (var store = new FileGraphStore(storagePath)) + { + // Create a small social network + var alice = CreateTestNode("alice", "PERSON", new Dictionary { { "name", "Alice" } }); + var bob = CreateTestNode("bob", "PERSON", new Dictionary { { "name", "Bob" } }); + var charlie = 
CreateTestNode("charlie", "PERSON", new Dictionary { { "name", "Charlie" } }); + var acme = CreateTestNode("acme", "COMPANY", new Dictionary { { "name", "Acme Corp" } }); + + store.AddNode(alice); + store.AddNode(bob); + store.AddNode(charlie); + store.AddNode(acme); + + var edge1 = CreateTestEdge("alice", "KNOWS", "bob", 0.9); + var edge2 = CreateTestEdge("bob", "KNOWS", "charlie", 0.8); + var edge3 = CreateTestEdge("alice", "WORKS_AT", "acme", 1.0); + var edge4 = CreateTestEdge("bob", "WORKS_AT", "acme", 1.0); + + store.AddEdge(edge1); + store.AddEdge(edge2); + store.AddEdge(edge3); + store.AddEdge(edge4); + + // Verify initial state + Assert.Equal(4, store.NodeCount); + Assert.Equal(4, store.EdgeCount); + } + + // Reload and modify + using (var store = new FileGraphStore(storagePath)) + { + // Remove Bob - this removes his 3 edges: edge1 (alice->bob), edge2 (bob->charlie), edge4 (bob->acme) + // Only edge3 (alice->acme) remains + store.RemoveNode("bob"); + + Assert.Equal(3, store.NodeCount); + Assert.Equal(1, store.EdgeCount); + } + + // Reload again and verify persistence + using (var store = new FileGraphStore(storagePath)) + { + Assert.Equal(3, store.NodeCount); + Assert.Equal(1, store.EdgeCount); + Assert.Null(store.GetNode("bob")); + Assert.Single(store.GetIncomingEdges("acme")); + } + } + + [Fact] + public void LargeGraph_WithHundredsOfNodes_PerformsCorrectly() + { + // Arrange + var storagePath = GetTestStoragePath(); + const int nodeCount = 500; + + using (var store = new FileGraphStore(storagePath)) + { + // Add many nodes + for (int i = 0; i < nodeCount; i++) + { + var node = CreateTestNode($"node_{i:D4}", "PERSON", new Dictionary + { + { "name", $"Person {i}" }, + { "index", i } + }); + store.AddNode(node); + } + + // Add some edges + for (int i = 0; i < nodeCount - 1; i += 2) + { + store.AddEdge(CreateTestEdge($"node_{i:D4}", "KNOWS", $"node_{(i+1):D4}")); + } + + Assert.Equal(nodeCount, store.NodeCount); + Assert.Equal(nodeCount / 2, 
store.EdgeCount); + } + + // Reload and verify + using (var reloadedStore = new FileGraphStore(storagePath)) + { + Assert.Equal(nodeCount, reloadedStore.NodeCount); + Assert.Equal(nodeCount / 2, reloadedStore.EdgeCount); + + // Verify random samples + var node100 = reloadedStore.GetNode("node_0100"); + Assert.NotNull(node100); + Assert.Equal("Person 100", node100.GetProperty("name")); + + var node250 = reloadedStore.GetNode("node_0250"); + Assert.NotNull(node250); + Assert.Equal(250, node250.GetProperty("index")); + } + } + + #endregion + + #region KnowledgeGraph Integration Test + + [Fact] + public void KnowledgeGraph_WithFileGraphStore_PersistsCorrectly() + { + // Arrange + var storagePath = GetTestStoragePath(); + + // Create graph with file storage + using (var fileStore = new FileGraphStore(storagePath)) + { + var graph = new KnowledgeGraph(fileStore); + var alice = CreateTestNode("alice", "PERSON", new Dictionary { { "name", "Alice" } }); + var bob = CreateTestNode("bob", "PERSON", new Dictionary { { "name", "Bob" } }); + + graph.AddNode(alice); + graph.AddNode(bob); + graph.AddEdge(CreateTestEdge("alice", "KNOWS", "bob")); + + Assert.Equal(2, graph.NodeCount); + Assert.Equal(1, graph.EdgeCount); + } + + // Reload with new KnowledgeGraph instance + using (var fileStore = new FileGraphStore(storagePath)) + { + var graph = new KnowledgeGraph(fileStore); + // Assert - Data persisted + Assert.Equal(2, graph.NodeCount); + Assert.Equal(1, graph.EdgeCount); + + var alice = graph.GetNode("alice"); + Assert.NotNull(alice); + Assert.Equal("Alice", alice.GetProperty("name")); + + // Test graph algorithms still work + var path = graph.FindShortestPath("alice", "bob"); + Assert.Equal(2, path.Count); + Assert.Equal("alice", path[0]); + Assert.Equal("bob", path[1]); + } + } + + #endregion + } +} diff --git a/tests/AiDotNet.Tests/UnitTests/RetrievalAugmentedGeneration/GraphQueryMatcherTests.cs 
b/tests/AiDotNet.Tests/UnitTests/RetrievalAugmentedGeneration/GraphQueryMatcherTests.cs new file mode 100644 index 000000000..8a065d313 --- /dev/null +++ b/tests/AiDotNet.Tests/UnitTests/RetrievalAugmentedGeneration/GraphQueryMatcherTests.cs @@ -0,0 +1,541 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using AiDotNet.RetrievalAugmentedGeneration.Graph; +using Xunit; + +namespace AiDotNetTests.UnitTests.RetrievalAugmentedGeneration +{ + /// + /// Unit tests for GraphQueryMatcher. + /// + /// + /// Note: xUnit creates a fresh instance of this test class for each [Fact] test method, + /// so the constructor runs for each test, ensuring a fresh _graph and _matcher for every test. + /// This provides test isolation without needing IClassFixture or explicit per-test setup. + /// + public class GraphQueryMatcherTests + { + private readonly KnowledgeGraph _graph; + private readonly GraphQueryMatcher _matcher; + + public GraphQueryMatcherTests() + { + _graph = new KnowledgeGraph(); + _matcher = new GraphQueryMatcher(_graph); + + SetupTestData(); + } + + private GraphNode CreateNode(string id, string label, Dictionary? 
properties = null) + { + var node = new GraphNode(id, label); + if (properties != null) + { + foreach (var kvp in properties) + { + node.SetProperty(kvp.Key, kvp.Value); + } + } + return node; + } + + private GraphEdge CreateEdge(string sourceId, string relationType, string targetId, double weight = 1.0) + { + return new GraphEdge(sourceId, targetId, relationType, weight); + } + + private void SetupTestData() + { + // Create people + _graph.AddNode(CreateNode("alice", "Person", new Dictionary + { + { "name", "Alice" }, + { "age", 30 } + })); + + _graph.AddNode(CreateNode("bob", "Person", new Dictionary + { + { "name", "Bob" }, + { "age", 35 } + })); + + _graph.AddNode(CreateNode("charlie", "Person", new Dictionary + { + { "name", "Charlie" }, + { "age", 28 } + })); + + // Create companies + _graph.AddNode(CreateNode("google", "Company", new Dictionary + { + { "name", "Google" }, + { "industry", "Tech" } + })); + + _graph.AddNode(CreateNode("microsoft", "Company", new Dictionary + { + { "name", "Microsoft" }, + { "industry", "Tech" } + })); + + // Create relationships + _graph.AddEdge(CreateEdge("alice", "KNOWS", "bob")); + _graph.AddEdge(CreateEdge("bob", "KNOWS", "charlie")); + _graph.AddEdge(CreateEdge("alice", "WORKS_AT", "google")); + _graph.AddEdge(CreateEdge("bob", "WORKS_AT", "microsoft")); + _graph.AddEdge(CreateEdge("charlie", "WORKS_AT", "google")); + } + + #region FindNodes Tests + + [Fact] + public void FindNodes_ByLabel_ReturnsAllMatchingNodes() + { + // Act + var people = _matcher.FindNodes("Person"); + + // Assert + Assert.Equal(3, people.Count); + Assert.All(people, p => Assert.Equal("Person", p.Label)); + } + + [Fact] + public void FindNodes_ByLabelAndProperty_ReturnsFilteredNodes() + { + // Arrange + var props = new Dictionary { { "name", "Alice" } }; + + // Act + var results = _matcher.FindNodes("Person", props); + + // Assert + Assert.Single(results); + Assert.Equal("alice", results[0].Id); + } + + [Fact] + public void 
FindNodes_MultipleProperties_FiltersCorrectly() + { + // Arrange + var props = new Dictionary + { + { "name", "Alice" }, + { "age", 30 } + }; + + // Act + var results = _matcher.FindNodes("Person", props); + + // Assert + Assert.Single(results); + Assert.Equal("alice", results[0].Id); + } + + [Fact] + public void FindNodes_NoMatches_ReturnsEmptyList() + { + // Arrange + var props = new Dictionary { { "name", "NonExistent" } }; + + // Act + var results = _matcher.FindNodes("Person", props); + + // Assert + Assert.Empty(results); + } + + [Fact] + public void FindNodes_InvalidLabel_ThrowsException() + { + // Act & Assert + Assert.Throws(() => _matcher.FindNodes(null!)); + Assert.Throws(() => _matcher.FindNodes("")); + Assert.Throws(() => _matcher.FindNodes(" ")); + } + + #endregion + + #region FindPaths Tests + + [Fact] + public void FindPaths_SimplePattern_ReturnsMatchingPaths() + { + // Act + var paths = _matcher.FindPaths("Person", "KNOWS", "Person"); + + // Assert + Assert.Equal(2, paths.Count); // Alice->Bob and Bob->Charlie + Assert.All(paths, p => + { + Assert.Equal("Person", p.SourceNode.Label); + Assert.Equal("KNOWS", p.Edge.RelationType); + Assert.Equal("Person", p.TargetNode.Label); + }); + } + + [Fact] + public void FindPaths_WithSourceFilter_ReturnsFilteredPaths() + { + // Arrange + var sourceProps = new Dictionary { { "name", "Alice" } }; + + // Act + var paths = _matcher.FindPaths("Person", "KNOWS", "Person", sourceProps); + + // Assert + Assert.Single(paths); // Only Alice->Bob + Assert.Equal("alice", paths[0].SourceNode.Id); + Assert.Equal("bob", paths[0].TargetNode.Id); + } + + [Fact] + public void FindPaths_WithTargetFilter_ReturnsFilteredPaths() + { + // Arrange + var targetProps = new Dictionary { { "name", "Charlie" } }; + + // Act + var paths = _matcher.FindPaths("Person", "KNOWS", "Person", null, targetProps); + + // Assert + Assert.Single(paths); // Only Bob->Charlie + Assert.Equal("bob", paths[0].SourceNode.Id); + Assert.Equal("charlie", 
paths[0].TargetNode.Id); + } + + [Fact] + public void FindPaths_DifferentLabels_WorksCorrectly() + { + // Act + var paths = _matcher.FindPaths("Person", "WORKS_AT", "Company"); + + // Assert + Assert.Equal(3, paths.Count); // Alice->Google, Bob->Microsoft, Charlie->Google + Assert.All(paths, p => + { + Assert.Equal("Person", p.SourceNode.Label); + Assert.Equal("WORKS_AT", p.Edge.RelationType); + Assert.Equal("Company", p.TargetNode.Label); + }); + } + + [Fact] + public void FindPaths_WithBothFilters_ReturnsSpecificPath() + { + // Arrange + var sourceProps = new Dictionary { { "name", "Alice" } }; + var targetProps = new Dictionary { { "name", "Google" } }; + + // Act + var paths = _matcher.FindPaths("Person", "WORKS_AT", "Company", sourceProps, targetProps); + + // Assert + Assert.Single(paths); + Assert.Equal("alice", paths[0].SourceNode.Id); + Assert.Equal("google", paths[0].TargetNode.Id); + } + + [Fact] + public void FindPaths_InvalidArguments_ThrowsException() + { + // Act & Assert + Assert.Throws(() => + _matcher.FindPaths(null!, "KNOWS", "Person")); + + Assert.Throws(() => + _matcher.FindPaths("Person", null!, "Person")); + + Assert.Throws(() => + _matcher.FindPaths("Person", "KNOWS", null!)); + } + + #endregion + + #region FindPathsOfLength Tests + + [Fact] + public void FindPathsOfLength_Length1_ReturnsDirectNeighbors() + { + // Act + var paths = _matcher.FindPathsOfLength("alice", 1); + + // Assert + Assert.Equal(2, paths.Count); // Alice->Bob and Alice->Google + Assert.All(paths, p => + { + Assert.Equal(2, p.Count); // Source + Target + Assert.Equal("alice", p[0].Id); + }); + } + + [Fact] + public void FindPathsOfLength_Length2_ReturnsDistantNodes() + { + // Act + var paths = _matcher.FindPathsOfLength("alice", 2); + + // Assert + Assert.NotEmpty(paths); + Assert.All(paths, p => Assert.Equal(3, p.Count)); // Source + Intermediate + Target + } + + [Fact] + public void FindPathsOfLength_WithRelationshipFilter_FiltersCorrectly() + { + // Act + var paths = 
_matcher.FindPathsOfLength("alice", 1, "KNOWS"); + + // Assert + Assert.Single(paths); // Only Alice->Bob via KNOWS + Assert.Equal("bob", paths[0][1].Id); + } + + [Fact] + public void FindPathsOfLength_AvoidsCycles() + { + // Act - This would create a cycle if not handled + var paths = _matcher.FindPathsOfLength("alice", 5); + + // Assert - Should not contain cycles (same node appearing twice) + Assert.All(paths, p => + { + var nodeIds = p.Select(n => n.Id).ToList(); + Assert.Equal(nodeIds.Count, nodeIds.Distinct().Count()); + }); + } + + [Fact] + public void FindPathsOfLength_InvalidArguments_ThrowsException() + { + // Act & Assert + Assert.Throws(() => + _matcher.FindPathsOfLength(null!, 1)); + + Assert.Throws(() => + _matcher.FindPathsOfLength("alice", 0)); + + Assert.Throws(() => + _matcher.FindPathsOfLength("alice", -1)); + } + + #endregion + + #region FindShortestPaths Tests + + [Fact] + public void FindShortestPaths_DirectConnection_ReturnsShortestPath() + { + // Act + var paths = _matcher.FindShortestPaths("alice", "bob"); + + // Assert + Assert.Single(paths); + Assert.Equal(2, paths[0].Count); // Alice -> Bob + Assert.Equal("alice", paths[0][0].Id); + Assert.Equal("bob", paths[0][1].Id); + } + + [Fact] + public void FindShortestPaths_IndirectConnection_FindsPath() + { + // Act + var paths = _matcher.FindShortestPaths("alice", "charlie"); + + // Assert + Assert.NotEmpty(paths); + var shortestPath = paths[0]; + Assert.Equal("alice", shortestPath[0].Id); + Assert.Equal("charlie", shortestPath[^1].Id); + } + + [Fact] + public void FindShortestPaths_SameNode_ReturnsSingleNodePath() + { + // Act + var paths = _matcher.FindShortestPaths("alice", "alice"); + + // Assert + Assert.Single(paths); + Assert.Single(paths[0]); + Assert.Equal("alice", paths[0][0].Id); + } + + [Fact] + public void FindShortestPaths_NoConnection_ReturnsEmpty() + { + // Arrange - Add isolated node + _graph.AddNode(CreateNode("isolated", "Person", new Dictionary { { "name", "Isolated" } })); 
+ + // Act + var paths = _matcher.FindShortestPaths("alice", "isolated"); + + // Assert + Assert.Empty(paths); + } + + [Fact] + public void FindShortestPaths_InvalidArguments_ThrowsException() + { + // Act & Assert + Assert.Throws(() => + _matcher.FindShortestPaths(null!, "bob")); + + Assert.Throws(() => + _matcher.FindShortestPaths("alice", null!)); + } + + [Fact] + public void FindShortestPaths_RespectsMaxDepth() + { + // Act - Set maxDepth to 1, so can't reach Charlie from Alice + var paths = _matcher.FindShortestPaths("alice", "charlie", maxDepth: 1); + + // Assert - Charlie is 2 hops away, so shouldn't be found + Assert.Empty(paths); + } + + #endregion + + #region ExecutePattern Tests + + [Fact] + public void ExecutePattern_SimplePattern_ReturnsMatches() + { + // Act + var paths = _matcher.ExecutePattern("(Person)-[KNOWS]->(Person)"); + + // Assert + Assert.Equal(2, paths.Count); + } + + [Fact] + public void ExecutePattern_WithSourceProperty_FiltersCorrectly() + { + // Act + var paths = _matcher.ExecutePattern("(Person {name: \"Alice\"})-[KNOWS]->(Person)"); + + // Assert + Assert.Single(paths); + Assert.Equal("alice", paths[0].SourceNode.Id); + } + + [Fact] + public void ExecutePattern_WithTargetProperty_FiltersCorrectly() + { + // Act + var paths = _matcher.ExecutePattern("(Person)-[WORKS_AT]->(Company {name: \"Google\"})"); + + // Assert + Assert.Equal(2, paths.Count); // Alice and Charlie work at Google + Assert.All(paths, p => Assert.Equal("google", p.TargetNode.Id)); + } + + [Fact] + public void ExecutePattern_WithBothProperties_FindsSpecificPath() + { + // Act + var paths = _matcher.ExecutePattern("(Person {name: \"Alice\"})-[WORKS_AT]->(Company {name: \"Google\"})"); + + // Assert + Assert.Single(paths); + Assert.Equal("alice", paths[0].SourceNode.Id); + Assert.Equal("google", paths[0].TargetNode.Id); + } + + [Fact] + public void ExecutePattern_InvalidFormat_ThrowsException() + { + // Act & Assert + Assert.Throws(() => + 
_matcher.ExecutePattern("invalid pattern")); + + Assert.Throws(() => + _matcher.ExecutePattern("(Person)")); + + Assert.Throws(() => + _matcher.ExecutePattern(null!)); + } + + #endregion + + #region GraphPath ToString Tests + + [Fact] + public void GraphPath_ToString_FormatsCorrectly() + { + // Arrange + var paths = _matcher.FindPaths("Person", "KNOWS", "Person"); + + // Act + var pathString = paths[0].ToString(); + + // Assert + Assert.Contains("Person:", pathString); + Assert.Contains("-[KNOWS]->", pathString); + } + + #endregion + + #region Complex Scenario Tests + + [Fact] + public void FindPaths_ComplexQuery_HandlesMultipleCriteria() + { + // Arrange - Find people older than 30 who work at tech companies + var sourceProps = new Dictionary { { "age", 35 } }; + var targetProps = new Dictionary { { "industry", "Tech" } }; + + // Act + var paths = _matcher.FindPaths("Person", "WORKS_AT", "Company", sourceProps, targetProps); + + // Assert + Assert.Single(paths); // Only Bob (age 35) works at Microsoft (Tech) + Assert.Equal("bob", paths[0].SourceNode.Id); + } + + [Fact] + public void FindPathsOfLength_ComplexGraph_FindsAllPaths() + { + // Arrange - Add more connections to create multiple paths + _graph.AddEdge(CreateEdge("alice", "KNOWS", "charlie")); + + // Act - Find 1-hop paths from Alice + var paths = _matcher.FindPathsOfLength("alice", 1); + + // Assert - Should now have 3 paths (Bob, Charlie, Google) + Assert.Equal(3, paths.Count); + } + + #endregion + + #region Numeric Property Tests + + [Fact] + public void FindNodes_NumericProperty_ComparesCorrectly() + { + // Arrange + var props = new Dictionary { { "age", 30 } }; + + // Act + var results = _matcher.FindNodes("Person", props); + + // Assert + Assert.Single(results); + Assert.Equal("alice", results[0].Id); + } + + [Fact] + public void ExecutePattern_NumericProperty_ParsesCorrectly() + { + // Act - Pattern with numeric property + var paths = _matcher.ExecutePattern("(Person {age: 30})-[KNOWS]->(Person)"); + 
+
+            // Assert
+            Assert.Single(paths);
+            Assert.Equal("alice", paths[0].SourceNode.Id);
+        }
+
+        #endregion
+    }
+}
diff --git a/tests/AiDotNet.Tests/UnitTests/RetrievalAugmentedGeneration/GraphStoreAsyncTests.cs b/tests/AiDotNet.Tests/UnitTests/RetrievalAugmentedGeneration/GraphStoreAsyncTests.cs
new file mode 100644
index 000000000..aabc831d2
--- /dev/null
+++ b/tests/AiDotNet.Tests/UnitTests/RetrievalAugmentedGeneration/GraphStoreAsyncTests.cs
@@ -0,0 +1,295 @@
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Threading.Tasks;
+using AiDotNet.RetrievalAugmentedGeneration.Graph;
+using Xunit;
+
+namespace AiDotNetTests.UnitTests.RetrievalAugmentedGeneration
+{
+    /// <summary>
+    /// Unit tests for the async APIs of MemoryGraphStore and FileGraphStore.
+    /// Each test instance gets its own temp directory, deleted on Dispose,
+    /// so file-backed tests are isolated from each other.
+    /// </summary>
+    /// <remarks>
+    /// NOTE(review): this patch was recovered from text in which generic type
+    /// arguments (angle-bracket spans) had been stripped. Type arguments below
+    /// (e.g. List&lt;Task&lt;GraphNode?&gt;&gt;) were reconstructed from context;
+    /// store/node types may also have carried a numeric type argument such as
+    /// &lt;double&gt; -- confirm against the original commit.
+    /// </remarks>
+    public class GraphStoreAsyncTests : IDisposable
+    {
+        private readonly string _testDirectory;
+
+        public GraphStoreAsyncTests()
+        {
+            _testDirectory = Path.Combine(Path.GetTempPath(), "async_tests_" + Guid.NewGuid().ToString("N"));
+            Directory.CreateDirectory(_testDirectory);
+        }
+
+        public void Dispose()
+        {
+            if (Directory.Exists(_testDirectory))
+                Directory.Delete(_testDirectory, true);
+        }
+
+        // Creates a node with a deterministic "name" property so tests can assert on it.
+        private GraphNode CreateTestNode(string id, string label)
+        {
+            var node = new GraphNode(id, label);
+            node.SetProperty("name", $"Node {id}");
+            return node;
+        }
+
+        // Creates a unit-weight edge between two node ids.
+        private GraphEdge CreateTestEdge(string sourceId, string relationType, string targetId)
+        {
+            return new GraphEdge(sourceId, targetId, relationType, 1.0);
+        }
+
+        #region MemoryGraphStore Async Tests
+
+        [Fact]
+        public async Task MemoryStore_AddNodeAsync_AddsNode()
+        {
+            // Arrange
+            var store = new MemoryGraphStore();
+            var node = CreateTestNode("node1", "PERSON");
+
+            // Act
+            await store.AddNodeAsync(node);
+
+            // Assert
+            Assert.Equal(1, store.NodeCount);
+            var retrieved = await store.GetNodeAsync("node1");
+            Assert.NotNull(retrieved);
+            Assert.Equal("node1", retrieved.Id);
+        }
+
+        [Fact]
+        public async Task MemoryStore_AddEdgeAsync_AddsEdge()
+        {
+            // Arrange
+            var store = new MemoryGraphStore();
+            var node1 = CreateTestNode("node1", "PERSON");
+            var node2 = CreateTestNode("node2", "COMPANY");
+            await store.AddNodeAsync(node1);
+            await store.AddNodeAsync(node2);
+            var edge = CreateTestEdge("node1", "WORKS_AT", "node2");
+
+            // Act
+            await store.AddEdgeAsync(edge);
+
+            // Assert
+            Assert.Equal(1, store.EdgeCount);
+            var retrieved = await store.GetEdgeAsync(edge.Id);
+            Assert.NotNull(retrieved);
+        }
+
+        [Fact]
+        public async Task MemoryStore_GetNodesByLabelAsync_ReturnsCorrectNodes()
+        {
+            // Arrange
+            var store = new MemoryGraphStore();
+            await store.AddNodeAsync(CreateTestNode("person1", "PERSON"));
+            await store.AddNodeAsync(CreateTestNode("person2", "PERSON"));
+            await store.AddNodeAsync(CreateTestNode("company1", "COMPANY"));
+
+            // Act
+            var persons = await store.GetNodesByLabelAsync("PERSON");
+
+            // Assert
+            Assert.Equal(2, System.Linq.Enumerable.Count(persons));
+        }
+
+        [Fact]
+        public async Task MemoryStore_RemoveNodeAsync_RemovesNodeAndEdges()
+        {
+            // Arrange
+            var store = new MemoryGraphStore();
+            await store.AddNodeAsync(CreateTestNode("node1", "PERSON"));
+            await store.AddNodeAsync(CreateTestNode("node2", "PERSON"));
+            await store.AddEdgeAsync(CreateTestEdge("node1", "KNOWS", "node2"));
+
+            // Act
+            var result = await store.RemoveNodeAsync("node1");
+
+            // Assert - removing a node also removes its incident edges
+            Assert.True(result);
+            Assert.Equal(1, store.NodeCount);
+            Assert.Equal(0, store.EdgeCount);
+        }
+
+        #endregion
+
+        #region FileGraphStore Async Tests
+
+        [Fact]
+        public async Task FileStore_AddNodeAsync_PersistsNode()
+        {
+            // Arrange
+            var storagePath = Path.Combine(_testDirectory, Guid.NewGuid().ToString("N"));
+            using var store = new FileGraphStore(storagePath);
+            var node = CreateTestNode("node1", "PERSON");
+
+            // Act
+            await store.AddNodeAsync(node);
+
+            // Assert
+            Assert.Equal(1, store.NodeCount);
+            var retrieved = await store.GetNodeAsync("node1");
+            Assert.NotNull(retrieved);
+            Assert.Equal("node1", retrieved.Id);
+            Assert.Equal("Node node1", retrieved.GetProperty("name"));
+        }
+
+        [Fact]
+        public async Task FileStore_AddEdgeAsync_PersistsEdge()
+        {
+            // Arrange
+            var storagePath = Path.Combine(_testDirectory, Guid.NewGuid().ToString("N"));
+            using var store = new FileGraphStore(storagePath);
+            var node1 = CreateTestNode("node1", "PERSON");
+            var node2 = CreateTestNode("node2", "COMPANY");
+            await store.AddNodeAsync(node1);
+            await store.AddNodeAsync(node2);
+            var edge = CreateTestEdge("node1", "WORKS_AT", "node2");
+
+            // Act
+            await store.AddEdgeAsync(edge);
+
+            // Assert
+            Assert.Equal(1, store.EdgeCount);
+            var retrieved = await store.GetEdgeAsync(edge.Id);
+            Assert.NotNull(retrieved);
+            Assert.Equal("node1", retrieved.SourceId);
+            Assert.Equal("node2", retrieved.TargetId);
+        }
+
+        [Fact]
+        public async Task FileStore_AsyncOperations_SurviveRestart()
+        {
+            // Arrange
+            var storagePath = Path.Combine(_testDirectory, Guid.NewGuid().ToString("N"));
+
+            // Create and populate
+            using (var store = new FileGraphStore(storagePath))
+            {
+                await store.AddNodeAsync(CreateTestNode("alice", "PERSON"));
+                await store.AddNodeAsync(CreateTestNode("bob", "PERSON"));
+                await store.AddEdgeAsync(CreateTestEdge("alice", "KNOWS", "bob"));
+            }
+
+            // Act - Reload
+            using (var reloadedStore = new FileGraphStore(storagePath))
+            {
+                // Assert
+                Assert.Equal(2, reloadedStore.NodeCount);
+                Assert.Equal(1, reloadedStore.EdgeCount);
+
+                var alice = await reloadedStore.GetNodeAsync("alice");
+                Assert.NotNull(alice);
+                Assert.Equal("Node alice", alice.GetProperty("name"));
+
+                var outgoing = await reloadedStore.GetOutgoingEdgesAsync("alice");
+                Assert.Single(outgoing);
+            }
+        }
+
+        [Fact]
+        public async Task FileStore_GetAllNodesAsync_ReturnsAllNodes()
+        {
+            // Arrange
+            var storagePath = Path.Combine(_testDirectory, Guid.NewGuid().ToString("N"));
+            using var store = new FileGraphStore(storagePath);
+            await store.AddNodeAsync(CreateTestNode("node1", "PERSON"));
+            await store.AddNodeAsync(CreateTestNode("node2", "COMPANY"));
+            await store.AddNodeAsync(CreateTestNode("node3", "LOCATION"));
+
+            // Act
+            var allNodes = await store.GetAllNodesAsync();
+
+            // Assert
+            Assert.Equal(3, System.Linq.Enumerable.Count(allNodes));
+        }
+
+        [Fact]
+        public async Task FileStore_GetAllEdgesAsync_ReturnsAllEdges()
+        {
+            // Arrange
+            var storagePath = Path.Combine(_testDirectory, Guid.NewGuid().ToString("N"));
+            using var store = new FileGraphStore(storagePath);
+            await store.AddNodeAsync(CreateTestNode("node1", "PERSON"));
+            await store.AddNodeAsync(CreateTestNode("node2", "PERSON"));
+            await store.AddNodeAsync(CreateTestNode("node3", "COMPANY"));
+            await store.AddEdgeAsync(CreateTestEdge("node1", "KNOWS", "node2"));
+            await store.AddEdgeAsync(CreateTestEdge("node1", "WORKS_AT", "node3"));
+
+            // Act
+            var allEdges = await store.GetAllEdgesAsync();
+
+            // Assert
+            Assert.Equal(2, System.Linq.Enumerable.Count(allEdges));
+        }
+
+        [Fact]
+        public async Task FileStore_ClearAsync_RemovesAllData()
+        {
+            // Arrange
+            var storagePath = Path.Combine(_testDirectory, Guid.NewGuid().ToString("N"));
+            using var store = new FileGraphStore(storagePath);
+            await store.AddNodeAsync(CreateTestNode("node1", "PERSON"));
+            await store.AddNodeAsync(CreateTestNode("node2", "COMPANY"));
+
+            // Act
+            await store.ClearAsync();
+
+            // Assert
+            Assert.Equal(0, store.NodeCount);
+            Assert.Equal(0, store.EdgeCount);
+        }
+
+        [Fact]
+        public async Task FileStore_ConcurrentReads_WorkCorrectly()
+        {
+            // Arrange
+            var storagePath = Path.Combine(_testDirectory, Guid.NewGuid().ToString("N"));
+            using var store = new FileGraphStore(storagePath);
+
+            // Add multiple nodes
+            for (int i = 0; i < 10; i++)
+            {
+                await store.AddNodeAsync(CreateTestNode($"node{i}", "PERSON"));
+            }
+
+            // Act - Read concurrently
+            // (type argument reconstructed: the sanitized source read "new List?>>()")
+            var tasks = new List<Task<GraphNode?>>();
+            for (int i = 0; i < 10; i++)
+            {
+                var nodeId = $"node{i}";
+                tasks.Add(store.GetNodeAsync(nodeId));
+            }
+
+            var results = await Task.WhenAll(tasks);
+
+            // Assert
+            Assert.Equal(10, results.Length);
+            Assert.All(results, node => Assert.NotNull(node));
+        }
+
+        #endregion
+
+        #region Performance Comparison Tests
+
+        [Fact]
+        public async Task PerformanceComparison_BulkInsert_AsyncVsSync()
+        {
+            // This test demonstrates that async operations can handle bulk inserts efficiently
+            var storagePath = Path.Combine(_testDirectory, Guid.NewGuid().ToString("N"));
+            using var store = new FileGraphStore(storagePath);
+
+            const int nodeCount = 100;
+
+            // Act - Async bulk insert
+            var asyncTasks = new List<Task>();
+            for (int i = 0; i < nodeCount; i++)
+            {
+                asyncTasks.Add(store.AddNodeAsync(CreateTestNode($"node{i}", "PERSON")));
+            }
+            await Task.WhenAll(asyncTasks);
+
+            // Assert
+            Assert.Equal(nodeCount, store.NodeCount);
+        }
+
+        #endregion
+    }
+}
diff --git a/tests/AiDotNet.Tests/UnitTests/RetrievalAugmentedGeneration/GraphTransactionTests.cs b/tests/AiDotNet.Tests/UnitTests/RetrievalAugmentedGeneration/GraphTransactionTests.cs
new file mode 100644
index 000000000..b6a9d0214
--- /dev/null
+++ b/tests/AiDotNet.Tests/UnitTests/RetrievalAugmentedGeneration/GraphTransactionTests.cs
@@ -0,0 +1,670 @@
+using System;
+using System.Collections.Generic;
+using System.IO;
+using AiDotNet.RetrievalAugmentedGeneration.Graph;
+using Xunit;
+
+namespace AiDotNetTests.UnitTests.RetrievalAugmentedGeneration
+{
+    public class GraphTransactionTests : IDisposable
+    {
+        private readonly string _testDirectory;
+
+        public GraphTransactionTests()
+        {
+            _testDirectory = Path.Combine(Path.GetTempPath(), "txn_tests_" + Guid.NewGuid().ToString("N"));
+            Directory.CreateDirectory(_testDirectory);
+        }
+
+        public void Dispose()
+        {
+            if (Directory.Exists(_testDirectory))
+                Directory.Delete(_testDirectory, true);
+        }
+
+        private GraphNode CreateTestNode(string id, string label)
+        {
+            var node = new GraphNode(id, label);
+            node.SetProperty("name", $"Node {id}");
+            return node;
+        }
+
+        private GraphEdge CreateTestEdge(string sourceId, string relationType, string targetId)
+        {
+            return new GraphEdge(sourceId,
targetId, relationType, 1.0); + } + + #region Basic Transaction Tests + + [Fact] + public void Transaction_Begin_SetsStateToActive() + { + // Arrange + var store = new MemoryGraphStore(); + var txn = new GraphTransaction(store); + + // Act + txn.Begin(); + + // Assert + Assert.Equal(TransactionState.Active, txn.State); + Assert.NotEqual(-1, txn.TransactionId); + } + + [Fact] + public void Transaction_BeginTwice_ThrowsException() + { + // Arrange + var store = new MemoryGraphStore(); + var txn = new GraphTransaction(store); + txn.Begin(); + + // Act & Assert + Assert.Throws(() => txn.Begin()); + } + + [Fact] + public void Transaction_AddNodeBeforeBegin_ThrowsException() + { + // Arrange + var store = new MemoryGraphStore(); + var txn = new GraphTransaction(store); + var node = CreateTestNode("node1", "PERSON"); + + // Act & Assert + Assert.Throws(() => txn.AddNode(node)); + } + + #endregion + + #region Commit Tests + + [Fact] + public void Transaction_Commit_AppliesAllOperations() + { + // Arrange + var store = new MemoryGraphStore(); + var txn = new GraphTransaction(store); + var node1 = CreateTestNode("alice", "PERSON"); + var node2 = CreateTestNode("bob", "PERSON"); + var edge = CreateTestEdge("alice", "KNOWS", "bob"); + + // Act + txn.Begin(); + txn.AddNode(node1); + txn.AddNode(node2); + txn.AddEdge(edge); + txn.Commit(); + + // Assert + Assert.Equal(TransactionState.Committed, txn.State); + Assert.Equal(2, store.NodeCount); + Assert.Equal(1, store.EdgeCount); + Assert.NotNull(store.GetNode("alice")); + Assert.NotNull(store.GetNode("bob")); + } + + [Fact] + public void Transaction_CommitWithRemoveOperations_WorksCorrectly() + { + // Arrange + var store = new MemoryGraphStore(); + store.AddNode(CreateTestNode("alice", "PERSON")); + store.AddNode(CreateTestNode("bob", "PERSON")); + var edge = CreateTestEdge("alice", "KNOWS", "bob"); + store.AddEdge(edge); + + var txn = new GraphTransaction(store); + + // Act + txn.Begin(); + txn.RemoveEdge(edge.Id); + 
txn.RemoveNode("bob"); + txn.Commit(); + + // Assert + Assert.Equal(TransactionState.Committed, txn.State); + Assert.Equal(1, store.NodeCount); + Assert.Equal(0, store.EdgeCount); + Assert.NotNull(store.GetNode("alice")); + Assert.Null(store.GetNode("bob")); + } + + [Fact] + public void Transaction_CommitBeforeBegin_ThrowsException() + { + // Arrange + var store = new MemoryGraphStore(); + var txn = new GraphTransaction(store); + + // Act & Assert + Assert.Throws(() => txn.Commit()); + } + + [Fact] + public void Transaction_CommitAfterCommit_ThrowsException() + { + // Arrange + var store = new MemoryGraphStore(); + var txn = new GraphTransaction(store); + txn.Begin(); + txn.AddNode(CreateTestNode("node1", "PERSON")); + txn.Commit(); + + // Act & Assert + Assert.Throws(() => txn.Commit()); + } + + #endregion + + #region Rollback Tests + + [Fact] + public void Transaction_Rollback_DiscardsAllOperations() + { + // Arrange + var store = new MemoryGraphStore(); + var txn = new GraphTransaction(store); + var node1 = CreateTestNode("alice", "PERSON"); + var node2 = CreateTestNode("bob", "PERSON"); + + // Act + txn.Begin(); + txn.AddNode(node1); + txn.AddNode(node2); + txn.Rollback(); + + // Assert + Assert.Equal(TransactionState.RolledBack, txn.State); + Assert.Equal(0, store.NodeCount); + Assert.Null(store.GetNode("alice")); + Assert.Null(store.GetNode("bob")); + } + + [Fact] + public void Transaction_RollbackBeforeBegin_ThrowsException() + { + // Arrange + var store = new MemoryGraphStore(); + var txn = new GraphTransaction(store); + + // Act & Assert + Assert.Throws(() => txn.Rollback()); + } + + [Fact] + public void Transaction_RollbackAfterCommit_ThrowsException() + { + // Arrange + var store = new MemoryGraphStore(); + var txn = new GraphTransaction(store); + txn.Begin(); + txn.AddNode(CreateTestNode("node1", "PERSON")); + txn.Commit(); + + // Act & Assert + Assert.Throws(() => txn.Rollback()); + } + + #endregion + + #region WAL Integration Tests + + [Fact] + public 
void Transaction_WithWAL_LogsOperationsBeforeCommit() + { + // Arrange + var walPath = Path.Combine(_testDirectory, "test.wal"); + var wal = new WriteAheadLog(walPath); + var store = new MemoryGraphStore(); + var txn = new GraphTransaction(store, wal); + var node = CreateTestNode("alice", "PERSON"); + + // Act + txn.Begin(); + txn.AddNode(node); + txn.Commit(); + + // Assert + var entries = wal.ReadLog(); + Assert.NotEmpty(entries); + Assert.Contains(entries, e => e.OperationType == WALOperationType.AddNode && e.NodeId == "alice"); + Assert.Contains(entries, e => e.OperationType == WALOperationType.Checkpoint); + + wal.Dispose(); + } + + [Fact] + public void Transaction_WithWAL_RollbackDoesNotLogCheckpoint() + { + // Arrange + var walPath = Path.Combine(_testDirectory, "test_rollback.wal"); + var wal = new WriteAheadLog(walPath); + var store = new MemoryGraphStore(); + var txn = new GraphTransaction(store, wal); + var node = CreateTestNode("alice", "PERSON"); + + // Act + txn.Begin(); + txn.AddNode(node); + txn.Rollback(); + + // Assert + var entries = wal.ReadLog(); + // No checkpoint should be logged on rollback + Assert.DoesNotContain(entries, e => e.OperationType == WALOperationType.Checkpoint); + + wal.Dispose(); + } + + [Fact] + public void Transaction_WithWAL_SupportsMultipleOperations() + { + // Arrange + var walPath = Path.Combine(_testDirectory, "test_multi.wal"); + var wal = new WriteAheadLog(walPath); + var store = new MemoryGraphStore(); + var txn = new GraphTransaction(store, wal); + + // Act + txn.Begin(); + txn.AddNode(CreateTestNode("alice", "PERSON")); + txn.AddNode(CreateTestNode("bob", "PERSON")); + txn.AddEdge(CreateTestEdge("alice", "KNOWS", "bob")); + txn.Commit(); + + // Assert + var entries = wal.ReadLog(); + Assert.Equal(4, entries.Count); // 2 AddNode + 1 AddEdge + 1 Checkpoint + Assert.Equal(WALOperationType.AddNode, entries[0].OperationType); + Assert.Equal(WALOperationType.AddNode, entries[1].OperationType); + 
Assert.Equal(WALOperationType.AddEdge, entries[2].OperationType); + Assert.Equal(WALOperationType.Checkpoint, entries[3].OperationType); + + wal.Dispose(); + } + + #endregion + + #region FileGraphStore Integration Tests + + [Fact] + public void Transaction_WithFileStore_CommitPersistsData() + { + // Arrange + var storagePath = Path.Combine(_testDirectory, Guid.NewGuid().ToString("N")); + using var store = new FileGraphStore(storagePath); + var txn = new GraphTransaction(store); + + // Act + txn.Begin(); + txn.AddNode(CreateTestNode("alice", "PERSON")); + txn.AddNode(CreateTestNode("bob", "PERSON")); + txn.Commit(); + + // Assert + Assert.Equal(2, store.NodeCount); + } + + [Fact] + public void Transaction_WithFileStoreAndWAL_FullACIDSupport() + { + // Arrange + var storagePath = Path.Combine(_testDirectory, Guid.NewGuid().ToString("N")); + var walPath = Path.Combine(_testDirectory, "full_acid.wal"); + var wal = new WriteAheadLog(walPath); + using var store = new FileGraphStore(storagePath, wal); + var txn = new GraphTransaction(store, wal); + + // Act + txn.Begin(); + txn.AddNode(CreateTestNode("alice", "PERSON")); + txn.AddNode(CreateTestNode("bob", "PERSON")); + txn.AddEdge(CreateTestEdge("alice", "KNOWS", "bob")); + txn.Commit(); + + // Assert - Check store + Assert.Equal(2, store.NodeCount); + Assert.Equal(1, store.EdgeCount); + + // Assert - Check WAL + var entries = wal.ReadLog(); + Assert.Equal(4, entries.Count); // 2 AddNode + 1 AddEdge + 1 Checkpoint + + wal.Dispose(); + } + + [Fact] + public void Transaction_WithFileStoreAndWAL_RollbackPreventsPersistence() + { + // Arrange + var storagePath = Path.Combine(_testDirectory, Guid.NewGuid().ToString("N")); + var walPath = Path.Combine(_testDirectory, "rollback_test.wal"); + var wal = new WriteAheadLog(walPath); + using var store = new FileGraphStore(storagePath, wal); + var txn = new GraphTransaction(store, wal); + + // Act + txn.Begin(); + txn.AddNode(CreateTestNode("alice", "PERSON")); + 
txn.AddNode(CreateTestNode("bob", "PERSON")); + txn.Rollback(); + + // Assert + Assert.Equal(0, store.NodeCount); + + wal.Dispose(); + } + + #endregion + + #region Dispose Tests + + [Fact] + public void Transaction_Dispose_AutoRollbacksActiveTransaction() + { + // Arrange + var store = new MemoryGraphStore(); + var txn = new GraphTransaction(store); + var node = CreateTestNode("alice", "PERSON"); + + // Act + txn.Begin(); + txn.AddNode(node); + txn.Dispose(); // Should auto-rollback + + // Assert + Assert.Equal(0, store.NodeCount); + } + + [Fact] + public void Transaction_Dispose_DoesNotAffectCommittedTransaction() + { + // Arrange + var store = new MemoryGraphStore(); + var txn = new GraphTransaction(store); + var node = CreateTestNode("alice", "PERSON"); + + // Act + txn.Begin(); + txn.AddNode(node); + txn.Commit(); + txn.Dispose(); + + // Assert + Assert.Equal(1, store.NodeCount); + } + + [Fact] + public void Transaction_UsingStatement_AutoRollbacksOnException() + { + // Arrange + var store = new MemoryGraphStore(); + + // Act & Assert + try + { + using var txn = new GraphTransaction(store); + txn.Begin(); + txn.AddNode(CreateTestNode("alice", "PERSON")); + throw new Exception("Simulated error"); + } + catch (Exception) + { + // Swallow exception - expected in this test + } + + // Transaction should have been rolled back + Assert.Equal(0, store.NodeCount); + } + + #endregion + + #region Error Handling Tests + + [Fact] + public void Transaction_Commit_FailsIfStoreThrows() + { + // Arrange + var store = new MemoryGraphStore(); + var txn = new GraphTransaction(store); + + // Add invalid edge (nodes don't exist) + var edge = CreateTestEdge("nonexistent1", "KNOWS", "nonexistent2"); + + // Act & Assert + txn.Begin(); + txn.AddEdge(edge); + Assert.Throws(() => txn.Commit()); + Assert.Equal(TransactionState.Failed, txn.State); + } + + [Fact] + public void Transaction_FailedTransaction_CanBeRolledBack() + { + // Arrange + var store = new MemoryGraphStore(); + var txn = 
new GraphTransaction(store); + var edge = CreateTestEdge("nonexistent1", "KNOWS", "nonexistent2"); + + // Act + txn.Begin(); + txn.AddNode(CreateTestNode("alice", "PERSON")); // This is valid + txn.AddEdge(edge); // This will fail on commit + + try + { + txn.Commit(); + } + catch (InvalidOperationException) + { + // Expected + } + + // Transaction is now in Failed state + Assert.Equal(TransactionState.Failed, txn.State); + + // Should be able to rollback + txn.Rollback(); + Assert.Equal(TransactionState.RolledBack, txn.State); + } + + #endregion + + #region ACID Property Tests + + [Fact] + public void Transaction_Atomicity_AllOrNothing() + { + // Arrange + var store = new MemoryGraphStore(); + var txn = new GraphTransaction(store); + + // Act - Transaction with error in the middle + txn.Begin(); + txn.AddNode(CreateTestNode("alice", "PERSON")); + txn.AddNode(CreateTestNode("bob", "PERSON")); + txn.AddEdge(CreateTestEdge("alice", "KNOWS", "charlie")); // charlie doesn't exist - will fail + + try + { + txn.Commit(); + } + catch (InvalidOperationException) + { + // Expected failure + } + + // Assert - Nothing should be committed (Atomicity) + // Because the commit failed, the entire transaction should be reverted + Assert.Equal(TransactionState.Failed, txn.State); + + // Verify store state was rolled back - no nodes should remain + // The compensating rollback should have removed alice and bob + Assert.Equal(0, store.NodeCount); + Assert.Equal(0, store.EdgeCount); + Assert.Null(store.GetNode("alice")); + Assert.Null(store.GetNode("bob")); + } + + [Fact] + public void Transaction_Consistency_GraphRemainsValid() + { + // Arrange + var store = new MemoryGraphStore(); + var txn1 = new GraphTransaction(store); + var txn2 = new GraphTransaction(store); + + // Act - First transaction succeeds + txn1.Begin(); + txn1.AddNode(CreateTestNode("alice", "PERSON")); + txn1.AddNode(CreateTestNode("bob", "PERSON")); + txn1.Commit(); + + // Second transaction uses the nodes from first 
+ txn2.Begin(); + txn2.AddEdge(CreateTestEdge("alice", "KNOWS", "bob")); + txn2.Commit(); + + // Assert - Graph is in valid state + Assert.Equal(2, store.NodeCount); + Assert.Equal(1, store.EdgeCount); + var aliceEdges = store.GetOutgoingEdges("alice"); + Assert.Single(aliceEdges); + } + + [Fact] + public void Transaction_Durability_WithWALSurvivesCrash() + { + // Arrange + var storagePath = Path.Combine(_testDirectory, Guid.NewGuid().ToString("N")); + var walPath = Path.Combine(_testDirectory, "durability_test.wal"); + + // First session - write data + using (var wal = new WriteAheadLog(walPath)) + using (var store = new FileGraphStore(storagePath, wal)) + { + var txn = new GraphTransaction(store, wal); + txn.Begin(); + txn.AddNode(CreateTestNode("alice", "PERSON")); + txn.AddNode(CreateTestNode("bob", "PERSON")); + txn.Commit(); + } + // Simulate crash - dispose everything + + // Second session - verify data survived + using (var wal = new WriteAheadLog(walPath)) + using (var store = new FileGraphStore(storagePath, wal)) + { + // Assert - Data should still be there (Durability) + Assert.Equal(2, store.NodeCount); + Assert.NotNull(store.GetNode("alice")); + Assert.NotNull(store.GetNode("bob")); + + // WAL should show checkpoint + var entries = wal.ReadLog(); + Assert.Contains(entries, e => e.OperationType == WALOperationType.Checkpoint); + } + } + + #endregion + + #region Complex Scenario Tests + + [Fact] + public void Transaction_MultipleSequentialTransactions_WorkCorrectly() + { + // Arrange + var store = new MemoryGraphStore(); + + // Act - Multiple transactions + using (var txn1 = new GraphTransaction(store)) + { + txn1.Begin(); + txn1.AddNode(CreateTestNode("alice", "PERSON")); + txn1.Commit(); + } + + using (var txn2 = new GraphTransaction(store)) + { + txn2.Begin(); + txn2.AddNode(CreateTestNode("bob", "PERSON")); + txn2.Commit(); + } + + using (var txn3 = new GraphTransaction(store)) + { + txn3.Begin(); + txn3.AddEdge(CreateTestEdge("alice", "KNOWS", 
"bob")); + txn3.Commit(); + } + + // Assert + Assert.Equal(2, store.NodeCount); + Assert.Equal(1, store.EdgeCount); + } + + [Fact] + public void Transaction_MixedSuccessAndRollback_MaintainsConsistency() + { + // Arrange + var store = new MemoryGraphStore(); + + // Transaction 1 - Success + using (var txn1 = new GraphTransaction(store)) + { + txn1.Begin(); + txn1.AddNode(CreateTestNode("alice", "PERSON")); + txn1.Commit(); + } + + // Transaction 2 - Rollback + using (var txn2 = new GraphTransaction(store)) + { + txn2.Begin(); + txn2.AddNode(CreateTestNode("bob", "PERSON")); + txn2.Rollback(); + } + + // Transaction 3 - Success + using (var txn3 = new GraphTransaction(store)) + { + txn3.Begin(); + txn3.AddNode(CreateTestNode("charlie", "PERSON")); + txn3.Commit(); + } + + // Assert + Assert.Equal(2, store.NodeCount); // Only alice and charlie + Assert.NotNull(store.GetNode("alice")); + Assert.Null(store.GetNode("bob")); // Rolled back + Assert.NotNull(store.GetNode("charlie")); + } + + [Fact] + public void Transaction_LargeTransaction_HandlesMultipleOperations() + { + // Arrange + var store = new MemoryGraphStore(); + var txn = new GraphTransaction(store); + + // Act - Add 100 nodes and 99 edges in a single transaction + txn.Begin(); + for (int i = 0; i < 100; i++) + { + txn.AddNode(CreateTestNode($"node{i}", "PERSON")); + } + for (int i = 0; i < 99; i++) + { + txn.AddEdge(CreateTestEdge($"node{i}", "NEXT", $"node{i + 1}")); + } + txn.Commit(); + + // Assert + Assert.Equal(100, store.NodeCount); + Assert.Equal(99, store.EdgeCount); + } + + #endregion + } +} diff --git a/tests/AiDotNet.Tests/UnitTests/RetrievalAugmentedGeneration/HybridGraphRetrieverTests.cs b/tests/AiDotNet.Tests/UnitTests/RetrievalAugmentedGeneration/HybridGraphRetrieverTests.cs new file mode 100644 index 000000000..3968d0cfb --- /dev/null +++ b/tests/AiDotNet.Tests/UnitTests/RetrievalAugmentedGeneration/HybridGraphRetrieverTests.cs @@ -0,0 +1,398 @@ +using System; +using 
System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using AiDotNet.Interfaces; +using AiDotNet.LinearAlgebra; +using AiDotNet.RetrievalAugmentedGeneration.DocumentStores; +using AiDotNet.RetrievalAugmentedGeneration.Graph; +using AiDotNet.RetrievalAugmentedGeneration.Models; +using Xunit; + +namespace AiDotNetTests.UnitTests.RetrievalAugmentedGeneration +{ + public class HybridGraphRetrieverTests + { + private readonly KnowledgeGraph _graph; + private readonly InMemoryDocumentStore _documentStore; + private readonly HybridGraphRetriever _retriever; + + public HybridGraphRetrieverTests() + { + // Create knowledge graph + _graph = new KnowledgeGraph(); + + // Create document store with 3-dimensional embeddings + _documentStore = new InMemoryDocumentStore(3); + + // Create retriever (no similarity metric needed - it uses StatisticsHelper internally) + _retriever = new HybridGraphRetriever(_graph, _documentStore); + + // Setup test data + SetupTestData(); + } + + private void SetupTestData() + { + // Create nodes with embeddings + var aliceEmbedding = new Vector(new double[] { 1.0, 0.0, 0.0 }); + var alice = new GraphNode("alice", "Person") + { + Properties = new Dictionary { { "name", "Alice" } }, + Embedding = aliceEmbedding + }; + + var bobEmbedding = new Vector(new double[] { 0.8, 0.2, 0.0 }); + var bob = new GraphNode("bob", "Person") + { + Properties = new Dictionary { { "name", "Bob" } }, + Embedding = bobEmbedding + }; + + var charlieEmbedding = new Vector(new double[] { 0.5, 0.5, 0.0 }); + var charlie = new GraphNode("charlie", "Person") + { + Properties = new Dictionary { { "name", "Charlie" } }, + Embedding = charlieEmbedding + }; + + var davidEmbedding = new Vector(new double[] { 0.2, 0.8, 0.0 }); + var david = new GraphNode("david", "Person") + { + Properties = new Dictionary { { "name", "David" } }, + Embedding = davidEmbedding + }; + + // Add nodes to graph + _graph.AddNode(alice); + _graph.AddNode(bob); + 
_graph.AddNode(charlie); + _graph.AddNode(david); + + // Create edges (social network) + _graph.AddEdge(new GraphEdge("alice", "bob", "KNOWS") { Weight = 1.0 }); + _graph.AddEdge(new GraphEdge("bob", "charlie", "KNOWS") { Weight = 1.0 }); + _graph.AddEdge(new GraphEdge("charlie", "david", "KNOWS") { Weight = 1.0 }); + + // Add documents to document store + _documentStore.Add(new VectorDocument( + new Document("alice", "Alice content"), + aliceEmbedding)); + + _documentStore.Add(new VectorDocument( + new Document("bob", "Bob content"), + bobEmbedding)); + + _documentStore.Add(new VectorDocument( + new Document("charlie", "Charlie content"), + charlieEmbedding)); + + _documentStore.Add(new VectorDocument( + new Document("david", "David content"), + davidEmbedding)); + } + + #region Basic Retrieval Tests + + [Fact] + public void Retrieve_WithoutExpansion_ReturnsOnlyVectorResults() + { + // Arrange + var query = new Vector(new double[] { 1.0, 0.0, 0.0 }); // Similar to Alice + + // Act + var results = _retriever.Retrieve(query, topK: 2, expansionDepth: 0, maxResults: 10); + + // Assert + Assert.Equal(2, results.Count); + Assert.All(results, r => Assert.Equal(RetrievalSource.VectorSearch, r.Source)); + Assert.Equal(0, results[0].Depth); + Assert.Equal("alice", results[0].NodeId); // Most similar + } + + [Fact] + public void Retrieve_WithExpansion_IncludesGraphNeighbors() + { + // Arrange + var query = new Vector(new double[] { 1.0, 0.0, 0.0 }); // Similar to Alice + + // Act + var results = _retriever.Retrieve(query, topK: 1, expansionDepth: 1, maxResults: 10); + + // Assert + Assert.True(results.Count > 1); // Should include Alice + neighbors + Assert.Contains(results, r => r.NodeId == "alice" && r.Source == RetrievalSource.VectorSearch); + Assert.Contains(results, r => r.NodeId == "bob" && r.Source == RetrievalSource.GraphTraversal); + } + + [Fact] + public void Retrieve_WithDepth2_ReachesDistantNodes() + { + // Arrange + var query = new Vector(new double[] { 1.0, 
0.0, 0.0 }); // Similar to Alice + + // Act + var results = _retriever.Retrieve(query, topK: 1, expansionDepth: 2, maxResults: 10); + + // Assert + // Should reach: Alice (0-hop) -> Bob (1-hop) -> Charlie (2-hop) + Assert.Contains(results, r => r.NodeId == "charlie"); + var charlie = results.First(r => r.NodeId == "charlie"); + Assert.Equal(2, charlie.Depth); + Assert.Equal(RetrievalSource.GraphTraversal, charlie.Source); + } + + #endregion + + #region Depth Penalty Tests + + [Fact] + public void Retrieve_AppliesDepthPenalty_CloserNodesRankHigher() + { + // Arrange + var query = new Vector(new double[] { 1.0, 0.0, 0.0 }); + + // Act + var results = _retriever.Retrieve(query, topK: 1, expansionDepth: 2, maxResults: 10); + + // Assert - Closer nodes should have higher scores due to depth penalty + var bob = results.FirstOrDefault(r => r.NodeId == "bob"); + var charlie = results.FirstOrDefault(r => r.NodeId == "charlie"); + + if (bob != null && charlie != null) + { + // Bob (1-hop) should score higher than Charlie (2-hop) due to depth penalty + // even if their raw vector similarities are similar + Assert.True(bob.Depth < charlie.Depth); + } + } + + #endregion + + #region Relationship-Aware Retrieval Tests + + [Fact] + public void RetrieveWithRelationships_UsesRelationshipWeights() + { + // Arrange + var query = new Vector(new double[] { 1.0, 0.0, 0.0 }); + var weights = new Dictionary + { + { "KNOWS", 1.5 } // Boost KNOWS relationships + }; + + // Act + var results = _retriever.RetrieveWithRelationships(query, topK: 1, weights, maxResults: 10); + + // Assert + Assert.NotEmpty(results); + var traversedResults = results.Where(r => r.Source == RetrievalSource.GraphTraversal).ToList(); + Assert.NotEmpty(traversedResults); + Assert.All(traversedResults, r => Assert.Equal("KNOWS", r.RelationType)); + } + + [Fact] + public void RetrieveWithRelationships_IncludesRelationshipInfo() + { + // Arrange + var query = new Vector(new double[] { 1.0, 0.0, 0.0 }); + + // Act + var 
results = _retriever.RetrieveWithRelationships(query, topK: 1, maxResults: 10); + + // Assert + var bobResult = results.FirstOrDefault(r => r.NodeId == "bob"); + Assert.NotNull(bobResult); + Assert.Equal("KNOWS", bobResult.RelationType); + Assert.Equal("alice", bobResult.ParentNodeId); + } + + #endregion + + #region MaxResults Tests + + [Fact] + public void Retrieve_RespectsMaxResults() + { + // Arrange + var query = new Vector(new double[] { 1.0, 0.0, 0.0 }); + + // Act + var results = _retriever.Retrieve(query, topK: 2, expansionDepth: 2, maxResults: 2); + + // Assert + Assert.Equal(2, results.Count); + } + + [Fact] + public void Retrieve_SortsByScore() + { + // Arrange + var query = new Vector(new double[] { 1.0, 0.0, 0.0 }); + + // Act + var results = _retriever.Retrieve(query, topK: 2, expansionDepth: 1, maxResults: 10); + + // Assert - Results should be sorted by score descending + for (int i = 0; i < results.Count - 1; i++) + { + Assert.True(results[i].Score >= results[i + 1].Score); + } + } + + #endregion + + #region Error Handling Tests + + [Fact] + public void Retrieve_NullEmbedding_ThrowsException() + { + // Act & Assert + Assert.Throws(() => + _retriever.Retrieve(null!, topK: 5)); + } + + [Fact] + public void Retrieve_EmptyEmbedding_ThrowsException() + { + // Act & Assert + Assert.Throws(() => + _retriever.Retrieve(new Vector(Array.Empty()), topK: 5)); + } + + [Fact] + public void Retrieve_InvalidTopK_ThrowsException() + { + // Arrange + var query = new Vector(new double[] { 1.0, 0.0, 0.0 }); + + // Act & Assert + Assert.Throws(() => + _retriever.Retrieve(query, topK: 0)); + + Assert.Throws(() => + _retriever.Retrieve(query, topK: -1)); + } + + [Fact] + public void Retrieve_NegativeExpansionDepth_ThrowsException() + { + // Arrange + var query = new Vector(new double[] { 1.0, 0.0, 0.0 }); + + // Act & Assert + Assert.Throws(() => + _retriever.Retrieve(query, topK: 5, expansionDepth: -1)); + } + + #endregion + + #region Result Properties Tests + + [Fact] 
+ public void Retrieve_PopulatesResultProperties() + { + // Arrange + var query = new Vector(new double[] { 1.0, 0.0, 0.0 }); + + // Act + var results = _retriever.Retrieve(query, topK: 1, expansionDepth: 1, maxResults: 10); + + // Assert + Assert.All(results, r => + { + Assert.NotNull(r.NodeId); + Assert.True(r.Score >= 0.0); + Assert.True(r.Depth >= 0); + + if (r.Source == RetrievalSource.VectorSearch) + { + Assert.Equal(0, r.Depth); + Assert.Null(r.ParentNodeId); + } + else if (r.Source == RetrievalSource.GraphTraversal) + { + Assert.True(r.Depth > 0); + Assert.NotNull(r.ParentNodeId); + } + }); + } + + [Fact] + public void Retrieve_IncludesEmbeddings() + { + // Arrange + var query = new Vector(new double[] { 1.0, 0.0, 0.0 }); + + // Act + var results = _retriever.Retrieve(query, topK: 1, expansionDepth: 0, maxResults: 10); + + // Assert - embeddings may or may not be populated depending on the source + Assert.NotEmpty(results); + } + + #endregion + + #region Complex Scenario Tests + + [Fact] + public void Retrieve_ComplexGraph_ProducesCoherentResults() + { + // Arrange - Add more complex graph structure + var graph = new KnowledgeGraph(); + var documentStore = new InMemoryDocumentStore(3); + + // Create a small community + for (int i = 0; i < 5; i++) + { + var embedding = new Vector(new double[] { i * 0.2, 1 - i * 0.2, 0.0 }); + var node = new GraphNode($"user{i}", "Person") + { + Properties = new Dictionary { { "name", $"User{i}" } }, + Embedding = embedding + }; + graph.AddNode(node); + documentStore.Add(new VectorDocument( + new Document($"user{i}", $"User{i} content"), + embedding)); + } + + // Create connections + for (int i = 0; i < 4; i++) + { + graph.AddEdge(new GraphEdge($"user{i}", $"user{i + 1}", "FRIENDS_WITH") { Weight = 1.0 }); + } + + var retriever = new HybridGraphRetriever(graph, documentStore); + var query = new Vector(new double[] { 0.0, 1.0, 0.0 }); // Similar to user0 + + // Act + var results = retriever.Retrieve(query, topK: 1, 
expansionDepth: 2, maxResults: 5); + + // Assert + Assert.NotEmpty(results); + Assert.True(results.Count <= 5); + Assert.Contains(results, r => r.Source == RetrievalSource.VectorSearch); + } + + #endregion + + #region Async Tests + + [Fact] + public async Task RetrieveAsync_WorksCorrectly() + { + // Arrange + var query = new Vector(new double[] { 1.0, 0.0, 0.0 }); + + // Act + var results = await _retriever.RetrieveAsync(query, topK: 2, expansionDepth: 1, maxResults: 10); + + // Assert + Assert.NotEmpty(results); + Assert.Contains(results, r => r.Source == RetrievalSource.VectorSearch); + } + + #endregion + } +} diff --git a/tests/AiDotNet.Tests/UnitTests/RetrievalAugmentedGeneration/MemoryGraphStoreTests.cs b/tests/AiDotNet.Tests/UnitTests/RetrievalAugmentedGeneration/MemoryGraphStoreTests.cs new file mode 100644 index 000000000..9f803f7e2 --- /dev/null +++ b/tests/AiDotNet.Tests/UnitTests/RetrievalAugmentedGeneration/MemoryGraphStoreTests.cs @@ -0,0 +1,740 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using AiDotNet.Interfaces; +using AiDotNet.LinearAlgebra; +using AiDotNet.RetrievalAugmentedGeneration.Graph; +using Xunit; + +namespace AiDotNetTests.UnitTests.RetrievalAugmentedGeneration +{ + public class MemoryGraphStoreTests + { + private GraphNode CreateTestNode(string id, string label, Dictionary? 
properties = null) + { + var node = new GraphNode(id, label); + if (properties != null) + { + foreach (var kvp in properties) + { + node.SetProperty(kvp.Key, kvp.Value); + } + } + return node; + } + + private GraphEdge CreateTestEdge(string sourceId, string relationType, string targetId, double weight = 1.0) + { + return new GraphEdge(sourceId, targetId, relationType, weight); + } + + #region Constructor Tests + + [Fact] + public void Constructor_InitializesEmptyStore() + { + // Arrange & Act + var store = new MemoryGraphStore(); + + // Assert + Assert.Equal(0, store.NodeCount); + Assert.Equal(0, store.EdgeCount); + } + + #endregion + + #region AddNode Tests + + [Fact] + public void AddNode_WithValidNode_IncreasesNodeCount() + { + // Arrange + var store = new MemoryGraphStore(); + var node = CreateTestNode("node1", "PERSON"); + + // Act + store.AddNode(node); + + // Assert + Assert.Equal(1, store.NodeCount); + } + + [Fact] + public void AddNode_WithNullNode_ThrowsArgumentNullException() + { + // Arrange + var store = new MemoryGraphStore(); + + // Act & Assert + Assert.Throws(() => store.AddNode(null!)); + } + + [Fact] + public void AddNode_WithDuplicateId_UpdatesExistingNode() + { + // Arrange + var store = new MemoryGraphStore(); + var node1 = CreateTestNode("node1", "PERSON", new Dictionary { { "name", "Alice" } }); + var node2 = CreateTestNode("node1", "PERSON", new Dictionary { { "name", "Alice Updated" } }); + + // Act + store.AddNode(node1); + store.AddNode(node2); + + // Assert + Assert.Equal(1, store.NodeCount); + var retrieved = store.GetNode("node1"); + Assert.NotNull(retrieved); + Assert.Equal("Alice Updated", retrieved.GetProperty("name")); + } + + [Fact] + public void AddNode_WithMultipleLabels_IndexesCorrectly() + { + // Arrange + var store = new MemoryGraphStore(); + var person1 = CreateTestNode("person1", "PERSON"); + var person2 = CreateTestNode("person2", "PERSON"); + var company = CreateTestNode("company1", "COMPANY"); + + // Act + 
store.AddNode(person1); + store.AddNode(person2); + store.AddNode(company); + + // Assert + Assert.Equal(3, store.NodeCount); + Assert.Equal(2, store.GetNodesByLabel("PERSON").Count()); + Assert.Single(store.GetNodesByLabel("COMPANY")); + } + + #endregion + + #region AddEdge Tests + + [Fact] + public void AddEdge_WithValidEdge_IncreasesEdgeCount() + { + // Arrange + var store = new MemoryGraphStore(); + var node1 = CreateTestNode("node1", "PERSON"); + var node2 = CreateTestNode("node2", "COMPANY"); + store.AddNode(node1); + store.AddNode(node2); + var edge = CreateTestEdge("node1", "WORKS_AT", "node2"); + + // Act + store.AddEdge(edge); + + // Assert + Assert.Equal(1, store.EdgeCount); + } + + [Fact] + public void AddEdge_WithNullEdge_ThrowsArgumentNullException() + { + // Arrange + var store = new MemoryGraphStore(); + + // Act & Assert + Assert.Throws(() => store.AddEdge(null!)); + } + + [Fact] + public void AddEdge_WithNonexistentSourceNode_ThrowsInvalidOperationException() + { + // Arrange + var store = new MemoryGraphStore(); + var node = CreateTestNode("node1", "PERSON"); + store.AddNode(node); + var edge = CreateTestEdge("nonexistent", "KNOWS", "node1"); + + // Act & Assert + var exception = Assert.Throws(() => store.AddEdge(edge)); + Assert.Contains("Source node 'nonexistent' does not exist", exception.Message); + } + + [Fact] + public void AddEdge_WithNonexistentTargetNode_ThrowsInvalidOperationException() + { + // Arrange + var store = new MemoryGraphStore(); + var node = CreateTestNode("node1", "PERSON"); + store.AddNode(node); + var edge = CreateTestEdge("node1", "KNOWS", "nonexistent"); + + // Act & Assert + var exception = Assert.Throws(() => store.AddEdge(edge)); + Assert.Contains("Target node 'nonexistent' does not exist", exception.Message); + } + + #endregion + + #region GetNode Tests + + [Fact] + public void GetNode_WithExistingId_ReturnsNode() + { + // Arrange + var store = new MemoryGraphStore(); + var node = CreateTestNode("node1", "PERSON"); 
+ store.AddNode(node); + + // Act + var retrieved = store.GetNode("node1"); + + // Assert + Assert.NotNull(retrieved); + Assert.Equal("node1", retrieved.Id); + Assert.Equal("PERSON", retrieved.Label); + } + + [Fact] + public void GetNode_WithNonexistentId_ReturnsNull() + { + // Arrange + var store = new MemoryGraphStore(); + + // Act + var retrieved = store.GetNode("nonexistent"); + + // Assert + Assert.Null(retrieved); + } + + #endregion + + #region GetEdge Tests + + [Fact] + public void GetEdge_WithExistingId_ReturnsEdge() + { + // Arrange + var store = new MemoryGraphStore(); + var node1 = CreateTestNode("node1", "PERSON"); + var node2 = CreateTestNode("node2", "COMPANY"); + store.AddNode(node1); + store.AddNode(node2); + var edge = CreateTestEdge("node1", "WORKS_AT", "node2"); + store.AddEdge(edge); + + // Act + var retrieved = store.GetEdge(edge.Id); + + // Assert + Assert.NotNull(retrieved); + Assert.Equal(edge.Id, retrieved.Id); + Assert.Equal("node1", retrieved.SourceId); + Assert.Equal("node2", retrieved.TargetId); + } + + [Fact] + public void GetEdge_WithNonexistentId_ReturnsNull() + { + // Arrange + var store = new MemoryGraphStore(); + + // Act + var retrieved = store.GetEdge("nonexistent"); + + // Assert + Assert.Null(retrieved); + } + + #endregion + + #region RemoveNode Tests + + [Fact] + public void RemoveNode_WithExistingNode_RemovesNodeAndReturnsTrue() + { + // Arrange + var store = new MemoryGraphStore(); + var node = CreateTestNode("node1", "PERSON"); + store.AddNode(node); + + // Act + var result = store.RemoveNode("node1"); + + // Assert + Assert.True(result); + Assert.Equal(0, store.NodeCount); + Assert.Null(store.GetNode("node1")); + } + + [Fact] + public void RemoveNode_WithNonexistentNode_ReturnsFalse() + { + // Arrange + var store = new MemoryGraphStore(); + + // Act + var result = store.RemoveNode("nonexistent"); + + // Assert + Assert.False(result); + } + + [Fact] + public void RemoveNode_RemovesAllConnectedEdges() + { + // Arrange + var 
store = new MemoryGraphStore(); + var node1 = CreateTestNode("node1", "PERSON"); + var node2 = CreateTestNode("node2", "PERSON"); + var node3 = CreateTestNode("node3", "COMPANY"); + store.AddNode(node1); + store.AddNode(node2); + store.AddNode(node3); + + var edge1 = CreateTestEdge("node1", "KNOWS", "node2"); + var edge2 = CreateTestEdge("node2", "WORKS_AT", "node3"); + var edge3 = CreateTestEdge("node1", "WORKS_AT", "node3"); + store.AddEdge(edge1); + store.AddEdge(edge2); + store.AddEdge(edge3); + + // Act + store.RemoveNode("node1"); + + // Assert + Assert.Equal(2, store.NodeCount); + Assert.Equal(1, store.EdgeCount); // Only edge2 remains + Assert.Null(store.GetEdge(edge1.Id)); + Assert.Null(store.GetEdge(edge3.Id)); + Assert.NotNull(store.GetEdge(edge2.Id)); + } + + [Fact] + public void RemoveNode_RemovesFromLabelIndex() + { + // Arrange + var store = new MemoryGraphStore(); + var person1 = CreateTestNode("person1", "PERSON"); + var person2 = CreateTestNode("person2", "PERSON"); + store.AddNode(person1); + store.AddNode(person2); + + // Act + store.RemoveNode("person1"); + + // Assert + var personsRemaining = store.GetNodesByLabel("PERSON").ToList(); + Assert.Single(personsRemaining); + Assert.Equal("person2", personsRemaining[0].Id); + } + + #endregion + + #region RemoveEdge Tests + + [Fact] + public void RemoveEdge_WithExistingEdge_RemovesEdgeAndReturnsTrue() + { + // Arrange + var store = new MemoryGraphStore(); + var node1 = CreateTestNode("node1", "PERSON"); + var node2 = CreateTestNode("node2", "COMPANY"); + store.AddNode(node1); + store.AddNode(node2); + var edge = CreateTestEdge("node1", "WORKS_AT", "node2"); + store.AddEdge(edge); + + // Act + var result = store.RemoveEdge(edge.Id); + + // Assert + Assert.True(result); + Assert.Equal(0, store.EdgeCount); + Assert.Null(store.GetEdge(edge.Id)); + } + + [Fact] + public void RemoveEdge_WithNonexistentEdge_ReturnsFalse() + { + // Arrange + var store = new MemoryGraphStore(); + + // Act + var result = 
store.RemoveEdge("nonexistent"); + + // Assert + Assert.False(result); + } + + [Fact] + public void RemoveEdge_DoesNotRemoveNodes() + { + // Arrange + var store = new MemoryGraphStore(); + var node1 = CreateTestNode("node1", "PERSON"); + var node2 = CreateTestNode("node2", "COMPANY"); + store.AddNode(node1); + store.AddNode(node2); + var edge = CreateTestEdge("node1", "WORKS_AT", "node2"); + store.AddEdge(edge); + + // Act + store.RemoveEdge(edge.Id); + + // Assert + Assert.Equal(2, store.NodeCount); + Assert.NotNull(store.GetNode("node1")); + Assert.NotNull(store.GetNode("node2")); + } + + #endregion + + #region GetOutgoingEdges Tests + + [Fact] + public void GetOutgoingEdges_WithExistingEdges_ReturnsCorrectEdges() + { + // Arrange + var store = new MemoryGraphStore(); + var node1 = CreateTestNode("node1", "PERSON"); + var node2 = CreateTestNode("node2", "PERSON"); + var node3 = CreateTestNode("node3", "COMPANY"); + store.AddNode(node1); + store.AddNode(node2); + store.AddNode(node3); + + var edge1 = CreateTestEdge("node1", "KNOWS", "node2"); + var edge2 = CreateTestEdge("node1", "WORKS_AT", "node3"); + var edge3 = CreateTestEdge("node2", "WORKS_AT", "node3"); + store.AddEdge(edge1); + store.AddEdge(edge2); + store.AddEdge(edge3); + + // Act + var outgoing = store.GetOutgoingEdges("node1").ToList(); + + // Assert + Assert.Equal(2, outgoing.Count); + Assert.Contains(outgoing, e => e.Id == edge1.Id); + Assert.Contains(outgoing, e => e.Id == edge2.Id); + } + + [Fact] + public void GetOutgoingEdges_WithNoEdges_ReturnsEmpty() + { + // Arrange + var store = new MemoryGraphStore(); + var node = CreateTestNode("node1", "PERSON"); + store.AddNode(node); + + // Act + var outgoing = store.GetOutgoingEdges("node1"); + + // Assert + Assert.Empty(outgoing); + } + + [Fact] + public void GetOutgoingEdges_WithNonexistentNode_ReturnsEmpty() + { + // Arrange + var store = new MemoryGraphStore(); + + // Act + var outgoing = store.GetOutgoingEdges("nonexistent"); + + // Assert + 
Assert.Empty(outgoing); + } + + #endregion + + #region GetIncomingEdges Tests + + [Fact] + public void GetIncomingEdges_WithExistingEdges_ReturnsCorrectEdges() + { + // Arrange + var store = new MemoryGraphStore(); + var node1 = CreateTestNode("node1", "PERSON"); + var node2 = CreateTestNode("node2", "PERSON"); + var node3 = CreateTestNode("node3", "COMPANY"); + store.AddNode(node1); + store.AddNode(node2); + store.AddNode(node3); + + var edge1 = CreateTestEdge("node1", "WORKS_AT", "node3"); + var edge2 = CreateTestEdge("node2", "WORKS_AT", "node3"); + var edge3 = CreateTestEdge("node1", "KNOWS", "node2"); + store.AddEdge(edge1); + store.AddEdge(edge2); + store.AddEdge(edge3); + + // Act + var incoming = store.GetIncomingEdges("node3").ToList(); + + // Assert + Assert.Equal(2, incoming.Count); + Assert.Contains(incoming, e => e.Id == edge1.Id); + Assert.Contains(incoming, e => e.Id == edge2.Id); + } + + [Fact] + public void GetIncomingEdges_WithNoEdges_ReturnsEmpty() + { + // Arrange + var store = new MemoryGraphStore(); + var node = CreateTestNode("node1", "PERSON"); + store.AddNode(node); + + // Act + var incoming = store.GetIncomingEdges("node1"); + + // Assert + Assert.Empty(incoming); + } + + [Fact] + public void GetIncomingEdges_WithNonexistentNode_ReturnsEmpty() + { + // Arrange + var store = new MemoryGraphStore(); + + // Act + var incoming = store.GetIncomingEdges("nonexistent"); + + // Assert + Assert.Empty(incoming); + } + + #endregion + + #region GetNodesByLabel Tests + + [Fact] + public void GetNodesByLabel_WithMatchingNodes_ReturnsCorrectNodes() + { + // Arrange + var store = new MemoryGraphStore(); + var person1 = CreateTestNode("person1", "PERSON"); + var person2 = CreateTestNode("person2", "PERSON"); + var company = CreateTestNode("company1", "COMPANY"); + store.AddNode(person1); + store.AddNode(person2); + store.AddNode(company); + + // Act + var persons = store.GetNodesByLabel("PERSON").ToList(); + + // Assert + Assert.Equal(2, persons.Count); + 
Assert.Contains(persons, n => n.Id == "person1"); + Assert.Contains(persons, n => n.Id == "person2"); + } + + [Fact] + public void GetNodesByLabel_WithNoMatches_ReturnsEmpty() + { + // Arrange + var store = new MemoryGraphStore(); + var person = CreateTestNode("person1", "PERSON"); + store.AddNode(person); + + // Act + var companies = store.GetNodesByLabel("COMPANY"); + + // Assert + Assert.Empty(companies); + } + + [Fact] + public void GetNodesByLabel_WithNonexistentLabel_ReturnsEmpty() + { + // Arrange + var store = new MemoryGraphStore(); + + // Act + var nodes = store.GetNodesByLabel("NONEXISTENT"); + + // Assert + Assert.Empty(nodes); + } + + #endregion + + #region GetAllNodes Tests + + [Fact] + public void GetAllNodes_WithMultipleNodes_ReturnsAllNodes() + { + // Arrange + var store = new MemoryGraphStore(); + var node1 = CreateTestNode("node1", "PERSON"); + var node2 = CreateTestNode("node2", "COMPANY"); + var node3 = CreateTestNode("node3", "LOCATION"); + store.AddNode(node1); + store.AddNode(node2); + store.AddNode(node3); + + // Act + var allNodes = store.GetAllNodes().ToList(); + + // Assert + Assert.Equal(3, allNodes.Count); + Assert.Contains(allNodes, n => n.Id == "node1"); + Assert.Contains(allNodes, n => n.Id == "node2"); + Assert.Contains(allNodes, n => n.Id == "node3"); + } + + [Fact] + public void GetAllNodes_WithEmptyStore_ReturnsEmpty() + { + // Arrange + var store = new MemoryGraphStore(); + + // Act + var allNodes = store.GetAllNodes(); + + // Assert + Assert.Empty(allNodes); + } + + #endregion + + #region GetAllEdges Tests + + [Fact] + public void GetAllEdges_WithMultipleEdges_ReturnsAllEdges() + { + // Arrange + var store = new MemoryGraphStore(); + var node1 = CreateTestNode("node1", "PERSON"); + var node2 = CreateTestNode("node2", "PERSON"); + var node3 = CreateTestNode("node3", "COMPANY"); + store.AddNode(node1); + store.AddNode(node2); + store.AddNode(node3); + + var edge1 = CreateTestEdge("node1", "KNOWS", "node2"); + var edge2 = 
CreateTestEdge("node1", "WORKS_AT", "node3"); + var edge3 = CreateTestEdge("node2", "WORKS_AT", "node3"); + store.AddEdge(edge1); + store.AddEdge(edge2); + store.AddEdge(edge3); + + // Act + var allEdges = store.GetAllEdges().ToList(); + + // Assert + Assert.Equal(3, allEdges.Count); + Assert.Contains(allEdges, e => e.Id == edge1.Id); + Assert.Contains(allEdges, e => e.Id == edge2.Id); + Assert.Contains(allEdges, e => e.Id == edge3.Id); + } + + [Fact] + public void GetAllEdges_WithEmptyStore_ReturnsEmpty() + { + // Arrange + var store = new MemoryGraphStore(); + + // Act + var allEdges = store.GetAllEdges(); + + // Assert + Assert.Empty(allEdges); + } + + #endregion + + #region Clear Tests + + [Fact] + public void Clear_RemovesAllNodesAndEdges() + { + // Arrange + var store = new MemoryGraphStore(); + var node1 = CreateTestNode("node1", "PERSON"); + var node2 = CreateTestNode("node2", "COMPANY"); + store.AddNode(node1); + store.AddNode(node2); + var edge = CreateTestEdge("node1", "WORKS_AT", "node2"); + store.AddEdge(edge); + + // Act + store.Clear(); + + // Assert + Assert.Equal(0, store.NodeCount); + Assert.Equal(0, store.EdgeCount); + Assert.Empty(store.GetAllNodes()); + Assert.Empty(store.GetAllEdges()); + } + + [Fact] + public void Clear_OnEmptyStore_DoesNotThrow() + { + // Arrange + var store = new MemoryGraphStore(); + + // Act & Assert + store.Clear(); // Should not throw + Assert.Equal(0, store.NodeCount); + Assert.Equal(0, store.EdgeCount); + } + + #endregion + + #region Integration Tests + + [Fact] + public void ComplexGraph_WithMultipleOperations_MaintainsConsistency() + { + // Arrange + var store = new MemoryGraphStore(); + + // Create a small social network + var alice = CreateTestNode("alice", "PERSON", new Dictionary { { "name", "Alice" } }); + var bob = CreateTestNode("bob", "PERSON", new Dictionary { { "name", "Bob" } }); + var charlie = CreateTestNode("charlie", "PERSON", new Dictionary { { "name", "Charlie" } }); + var acme = 
CreateTestNode("acme", "COMPANY", new Dictionary { { "name", "Acme Corp" } }); + + store.AddNode(alice); + store.AddNode(bob); + store.AddNode(charlie); + store.AddNode(acme); + + var edge1 = CreateTestEdge("alice", "KNOWS", "bob", 0.9); + var edge2 = CreateTestEdge("bob", "KNOWS", "charlie", 0.8); + var edge3 = CreateTestEdge("alice", "WORKS_AT", "acme", 1.0); + var edge4 = CreateTestEdge("bob", "WORKS_AT", "acme", 1.0); + + store.AddEdge(edge1); + store.AddEdge(edge2); + store.AddEdge(edge3); + store.AddEdge(edge4); + + // Act & Assert - Verify initial state + Assert.Equal(4, store.NodeCount); + Assert.Equal(4, store.EdgeCount); + Assert.Equal(3, store.GetNodesByLabel("PERSON").Count()); + Assert.Single(store.GetNodesByLabel("COMPANY")); + + // Act - Remove Bob + store.RemoveNode("bob"); + + // Assert - Bob and his edges are gone + // Bob has 3 edges: edge1 (incoming from alice), edge2 (outgoing to charlie), edge4 (outgoing to acme) + // Only edge3 (alice -> acme) remains, so EdgeCount = 1 + Assert.Equal(3, store.NodeCount); + Assert.Equal(1, store.EdgeCount); + Assert.Null(store.GetNode("bob")); + Assert.Null(store.GetEdge(edge1.Id)); + Assert.Null(store.GetEdge(edge2.Id)); + Assert.Null(store.GetEdge(edge4.Id)); + + // Assert - Alice's edge to Acme still exists + Assert.NotNull(store.GetEdge(edge3.Id)); + Assert.Single(store.GetOutgoingEdges("alice")); + Assert.Single(store.GetIncomingEdges("acme")); + } + + #endregion + } +}