diff --git a/src/Enums/EdgeDirection.cs b/src/Enums/EdgeDirection.cs
new file mode 100644
index 000000000..9d9a2e53f
--- /dev/null
+++ b/src/Enums/EdgeDirection.cs
@@ -0,0 +1,36 @@
+namespace AiDotNet.Enums;
+
+/// <summary>
+/// Specifies the direction of edges to retrieve when querying a knowledge graph.
+/// </summary>
+/// <remarks>
+/// <para><b>For Beginners:</b> In a directed graph, edges have a direction - they go FROM one node TO another.
+///
+/// Think of it like Twitter follows:
+/// - If Alice follows Bob, the edge goes FROM Alice TO Bob
+/// - Alice has an OUTGOING edge (she's following someone)
+/// - Bob has an INCOMING edge (someone is following him)
+///
+/// When querying relationships:
+/// - Outgoing: "Who does this person follow?" (edges starting from this node)
+/// - Incoming: "Who follows this person?" (edges pointing to this node)
+/// - Both: "All connections" (both directions)
+/// </para>
+/// </remarks>
+public enum EdgeDirection
+{
+    /// <summary>
+    /// Retrieve only outgoing edges (edges where the specified node is the source).
+    /// </summary>
+    Outgoing,
+
+    /// <summary>
+    /// Retrieve only incoming edges (edges where the specified node is the target).
+    /// </summary>
+    Incoming,
+
+    /// <summary>
+    /// Retrieve both outgoing and incoming edges.
+    /// </summary>
+    Both
+}
diff --git a/src/Interfaces/IGraphStore.cs b/src/Interfaces/IGraphStore.cs
new file mode 100644
index 000000000..a5d033384
--- /dev/null
+++ b/src/Interfaces/IGraphStore.cs
@@ -0,0 +1,381 @@
+using System.Collections.Generic;
+using System.Threading.Tasks;
+using AiDotNet.RetrievalAugmentedGeneration.Graph;
+
+namespace AiDotNet.Interfaces;
+
+/// <summary>
+/// Defines the contract for graph storage backends that manage nodes and edges.
+/// </summary>
+/// <remarks>
+/// <para>
+/// A graph store provides persistent or in-memory storage for knowledge graphs,
+/// enabling efficient storage and retrieval of entities (nodes) and their relationships (edges).
+/// Implementations can range from simple in-memory dictionaries to distributed graph databases.
+/// </para>
+/// <para><b>For Beginners:</b> A graph store is like a filing system for connected information.
+///
+/// Think of it like organizing a network of friends:
+/// - Nodes are people (Alice, Bob, Charlie)
+/// - Edges are relationships (Alice KNOWS Bob, Bob WORKS_WITH Charlie)
+/// - The graph store remembers all these connections
+///
+/// Different implementations might:
+/// - MemoryGraphStore: Keep everything in RAM (fast but lost when app closes)
+/// - FileGraphStore: Save to disk (slower but survives restarts)
+/// - Neo4jGraphStore: Use a professional graph database (production-scale)
+///
+/// This interface lets you swap storage backends without changing your code!
+/// </para>
+/// </remarks>
+/// <typeparam name="T">The numeric data type used for vector calculations (typically float or double).</typeparam>
+public interface IGraphStore<T>
+{
+    /// <summary>
+    /// Gets the total number of nodes in the graph store.
+    /// </summary>
+    /// <remarks>
+    /// <para><b>For Beginners:</b> This tells you how many entities (people, places, things)
+    /// are stored in the graph.
+    /// </para>
+    /// </remarks>
+    int NodeCount { get; }
+
+    /// <summary>
+    /// Gets the total number of edges in the graph store.
+    /// </summary>
+    /// <remarks>
+    /// <para><b>For Beginners:</b> This tells you how many relationships/connections
+    /// exist between entities in the graph.
+    /// </para>
+    /// </remarks>
+    int EdgeCount { get; }
+
+    /// <summary>
+    /// Adds a node to the graph or updates it if it already exists.
+    /// </summary>
+    /// <param name="node">The node to add.</param>
+    /// <remarks>
+    /// <para>
+    /// This method stores a node in the graph. If a node with the same ID already exists,
+    /// it will be updated with the new data. The node is automatically indexed by its label
+    /// for efficient label-based queries.
+    /// </para>
+    /// <para><b>For Beginners:</b> This adds a new entity to the graph.
+    ///
+    /// Like adding a person to a social network:
+    /// - node.Id = "alice_001"
+    /// - node.Label = "PERSON"
+    /// - node.Properties = { "name": "Alice Smith", "age": 30 }
+    ///
+    /// If Alice already exists, her information gets updated.
+    /// </para>
+    /// </remarks>
+    void AddNode(GraphNode<T> node);
+
+    /// <summary>
+    /// Adds an edge to the graph representing a relationship between two nodes.
+    /// </summary>
+    /// <param name="edge">The edge to add.</param>
+    /// <remarks>
+    /// <para>
+    /// This method creates a relationship between two existing nodes. Both the source
+    /// and target nodes must already exist in the graph, otherwise an exception is thrown.
+    /// The edge is indexed for efficient traversal from both directions.
+    /// </para>
+    /// <para><b>For Beginners:</b> This adds a connection between two entities.
+    ///
+    /// Like saying "Alice knows Bob":
+    /// - edge.SourceId = "alice_001"
+    /// - edge.RelationType = "KNOWS"
+    /// - edge.TargetId = "bob_002"
+    /// - edge.Weight = 0.9 (how strong the relationship is)
+    ///
+    /// Both Alice and Bob must already be added as nodes first!
+    /// </para>
+    /// </remarks>
+    void AddEdge(GraphEdge<T> edge);
+
+    /// <summary>
+    /// Retrieves a node by its unique identifier.
+    /// </summary>
+    /// <param name="nodeId">The unique identifier of the node.</param>
+    /// <returns>The node if found; otherwise, null.</returns>
+    /// <remarks>
+    /// <para><b>For Beginners:</b> This gets a specific entity if you know its ID.
+    ///
+    /// Like asking: "Show me the person with ID 'alice_001'"
+    /// </para>
+    /// </remarks>
+    GraphNode<T>? GetNode(string nodeId);
+
+    /// <summary>
+    /// Retrieves an edge by its unique identifier.
+    /// </summary>
+    /// <param name="edgeId">The unique identifier of the edge.</param>
+    /// <returns>The edge if found; otherwise, null.</returns>
+    /// <remarks>
+    /// <para><b>For Beginners:</b> This gets a specific relationship if you know its ID.
+    ///
+    /// Edge IDs are usually auto-generated like: "alice_001_KNOWS_bob_002"
+    /// </para>
+    /// </remarks>
+    GraphEdge<T>? GetEdge(string edgeId);
+
+    /// <summary>
+    /// Removes a node and all its connected edges from the graph.
+    /// </summary>
+    /// <param name="nodeId">The unique identifier of the node to remove.</param>
+    /// <returns>True if the node was found and removed; otherwise, false.</returns>
+    /// <remarks>
+    /// <para>
+    /// This method removes a node and automatically cleans up all edges connected to it
+    /// (both incoming and outgoing). This ensures the graph remains consistent.
+    /// </para>
+    /// <para><b>For Beginners:</b> This deletes an entity and all its connections.
+    ///
+    /// Like removing Alice from the network:
+    /// - Alice's profile is deleted
+    /// - All "Alice KNOWS Bob" relationships are deleted
+    /// - All "Bob KNOWS Alice" relationships are deleted
+    ///
+    /// This keeps the graph clean - no broken connections!
+    /// </para>
+    /// </remarks>
+    bool RemoveNode(string nodeId);
+
+    /// <summary>
+    /// Removes an edge from the graph.
+    /// </summary>
+    /// <param name="edgeId">The unique identifier of the edge to remove.</param>
+    /// <returns>True if the edge was found and removed; otherwise, false.</returns>
+    /// <remarks>
+    /// <para><b>For Beginners:</b> This deletes a specific relationship.
+    ///
+    /// Like saying "Alice no longer knows Bob" - removes just that connection,
+    /// but Alice and Bob still exist in the graph.
+    /// </para>
+    /// </remarks>
+    bool RemoveEdge(string edgeId);
+
+    /// <summary>
+    /// Gets all outgoing edges from a specific node.
+    /// </summary>
+    /// <param name="nodeId">The source node ID.</param>
+    /// <returns>Collection of outgoing edges from the node.</returns>
+    /// <remarks>
+    /// <para>
+    /// Outgoing edges represent relationships where this node is the source.
+    /// For example, if Alice KNOWS Bob, the "KNOWS" edge is outgoing from Alice.
+    /// </para>
+    /// <para><b>For Beginners:</b> This finds all relationships going OUT from an entity.
+    ///
+    /// If you ask for Alice's outgoing edges, you get:
+    /// - Alice KNOWS Bob
+    /// - Alice WORKS_AT CompanyX
+    /// - Alice LIVES_IN Seattle
+    ///
+    /// These are things Alice does or has relationships with.
+    /// </para>
+    /// </remarks>
+    IEnumerable<GraphEdge<T>> GetOutgoingEdges(string nodeId);
+
+    /// <summary>
+    /// Gets all incoming edges to a specific node.
+    /// </summary>
+    /// <param name="nodeId">The target node ID.</param>
+    /// <returns>Collection of incoming edges to the node.</returns>
+    /// <remarks>
+    /// <para>
+    /// Incoming edges represent relationships where this node is the target.
+    /// For example, if Alice KNOWS Bob, the "KNOWS" edge is incoming to Bob.
+    /// </para>
+    /// <para><b>For Beginners:</b> This finds all relationships coming IN to an entity.
+    ///
+    /// If you ask for Bob's incoming edges, you get:
+    /// - Alice KNOWS Bob
+    /// - Charlie WORKS_WITH Bob
+    /// - CompanyY EMPLOYS Bob
+    ///
+    /// These are relationships others have WITH Bob.
+    /// </para>
+    /// </remarks>
+    IEnumerable<GraphEdge<T>> GetIncomingEdges(string nodeId);
+
+    /// <summary>
+    /// Gets all nodes with a specific label.
+    /// </summary>
+    /// <param name="label">The node label to filter by (e.g., "PERSON", "COMPANY", "LOCATION").</param>
+    /// <returns>Collection of nodes with the specified label.</returns>
+    /// <remarks>
+    /// <para>
+    /// Labels are used to categorize nodes by type. This enables efficient queries
+    /// like "find all PERSON nodes" or "find all COMPANY nodes".
+    /// </para>
+    /// <para><b>For Beginners:</b> This finds all entities of a specific type.
+    ///
+    /// Like asking: "Show me all PERSON nodes"
+    /// Returns: Alice, Bob, Charlie (all people in the graph)
+    ///
+    /// Or: "Show me all COMPANY nodes"
+    /// Returns: Microsoft, Google, Amazon (all companies)
+    ///
+    /// Labels are like categories or tags for organizing your entities.
+    /// </para>
+    /// </remarks>
+    IEnumerable<GraphNode<T>> GetNodesByLabel(string label);
+
+    /// <summary>
+    /// Gets all nodes currently stored in the graph.
+    /// </summary>
+    /// <returns>Collection of all nodes.</returns>
+    /// <remarks>
+    /// <para>
+    /// This method retrieves every node without any filtering.
+    /// Use with caution on large graphs as it may be memory-intensive.
+    /// </para>
+    /// <para><b>For Beginners:</b> This gets every single entity in the graph.
+    ///
+    /// Like asking: "Show me everyone and everything in the network"
+    ///
+    /// Warning: If you have millions of entities, this could be slow and use lots of memory!
+    /// </para>
+    /// </remarks>
+    IEnumerable<GraphNode<T>> GetAllNodes();
+
+    /// <summary>
+    /// Gets all edges currently stored in the graph.
+    /// </summary>
+    /// <returns>Collection of all edges.</returns>
+    /// <remarks>
+    /// <para>
+    /// This method retrieves every edge without any filtering.
+    /// Use with caution on large graphs as it may be memory-intensive.
+    /// </para>
+    /// <para><b>For Beginners:</b> This gets every single relationship in the graph.
+    ///
+    /// Like asking: "Show me every connection between all entities"
+    ///
+    /// Warning: Large graphs can have millions of relationships!
+    /// </para>
+    /// </remarks>
+    IEnumerable<GraphEdge<T>> GetAllEdges();
+
+    /// <summary>
+    /// Removes all nodes and edges from the graph.
+    /// </summary>
+    /// <remarks>
+    /// <para>
+    /// This method clears the entire graph, removing all data. For persistent stores,
+    /// this may involve deleting files or database records. Use with extreme caution!
+    /// </para>
+    /// <para><b>For Beginners:</b> This deletes EVERYTHING from the graph.
+    ///
+    /// Like wiping the entire social network clean - all people and all connections gone!
+    ///
+    /// ⚠️ WARNING: This cannot be undone! Make backups first!
+    /// </para>
+    /// </remarks>
+    void Clear();
+
+    // Async methods for I/O-intensive operations
+
+    /// <summary>
+    /// Asynchronously adds a node to the graph or updates it if it already exists.
+    /// </summary>
+    /// <param name="node">The node to add.</param>
+    /// <returns>A task representing the asynchronous operation.</returns>
+    /// <remarks>
+    /// <para>
+    /// This is the async version of <see cref="AddNode"/>. Use this for file-based or
+    /// database-backed stores to avoid blocking the thread during I/O operations.
+    /// </para>
+    /// <para><b>For Beginners:</b> This does the same as AddNode but doesn't block your app.
+    ///
+    /// When should you use async?
+    /// - FileGraphStore: Yes! (writes to disk)
+    /// - MemoryGraphStore: Optional (no I/O, but provided for consistency)
+    /// - Database stores: Definitely! (network I/O)
+    ///
+    /// Example:
+    /// <code>
+    /// await store.AddNodeAsync(node); // Non-blocking
+    /// </code>
+    /// </para>
+    /// </remarks>
+    Task AddNodeAsync(GraphNode<T> node);
+
+    /// <summary>
+    /// Asynchronously adds an edge to the graph.
+    /// </summary>
+    /// <param name="edge">The edge to add.</param>
+    /// <returns>A task representing the asynchronous operation.</returns>
+    Task AddEdgeAsync(GraphEdge<T> edge);
+
+    /// <summary>
+    /// Asynchronously retrieves a node by its unique identifier.
+    /// </summary>
+    /// <param name="nodeId">The unique identifier of the node.</param>
+    /// <returns>A task that represents the asynchronous operation. The task result contains the node if found; otherwise, null.</returns>
+    Task<GraphNode<T>?> GetNodeAsync(string nodeId);
+
+    /// <summary>
+    /// Asynchronously retrieves an edge by its unique identifier.
+    /// </summary>
+    /// <param name="edgeId">The unique identifier of the edge.</param>
+    /// <returns>A task that represents the asynchronous operation. The task result contains the edge if found; otherwise, null.</returns>
+    Task<GraphEdge<T>?> GetEdgeAsync(string edgeId);
+
+    /// <summary>
+    /// Asynchronously removes a node and all its connected edges from the graph.
+    /// </summary>
+    /// <param name="nodeId">The unique identifier of the node to remove.</param>
+    /// <returns>A task that represents the asynchronous operation. The task result is true if the node was found and removed; otherwise, false.</returns>
+    Task<bool> RemoveNodeAsync(string nodeId);
+
+    /// <summary>
+    /// Asynchronously removes an edge from the graph.
+    /// </summary>
+    /// <param name="edgeId">The unique identifier of the edge to remove.</param>
+    /// <returns>A task that represents the asynchronous operation. The task result is true if the edge was found and removed; otherwise, false.</returns>
+    Task<bool> RemoveEdgeAsync(string edgeId);
+
+    /// <summary>
+    /// Asynchronously gets all outgoing edges from a specific node.
+    /// </summary>
+    /// <param name="nodeId">The source node ID.</param>
+    /// <returns>A task that represents the asynchronous operation. The task result contains the collection of outgoing edges.</returns>
+    Task<IEnumerable<GraphEdge<T>>> GetOutgoingEdgesAsync(string nodeId);
+
+    /// <summary>
+    /// Asynchronously gets all incoming edges to a specific node.
+    /// </summary>
+    /// <param name="nodeId">The target node ID.</param>
+    /// <returns>A task that represents the asynchronous operation. The task result contains the collection of incoming edges.</returns>
+    Task<IEnumerable<GraphEdge<T>>> GetIncomingEdgesAsync(string nodeId);
+
+    /// <summary>
+    /// Asynchronously gets all nodes with a specific label.
+    /// </summary>
+    /// <param name="label">The node label to filter by.</param>
+    /// <returns>A task that represents the asynchronous operation. The task result contains the collection of nodes with the specified label.</returns>
+    Task<IEnumerable<GraphNode<T>>> GetNodesByLabelAsync(string label);
+
+    /// <summary>
+    /// Asynchronously gets all nodes currently stored in the graph.
+    /// </summary>
+    /// <returns>A task that represents the asynchronous operation. The task result contains all nodes.</returns>
+    Task<IEnumerable<GraphNode<T>>> GetAllNodesAsync();
+
+    /// <summary>
+    /// Asynchronously gets all edges currently stored in the graph.
+    /// </summary>
+    /// <returns>A task that represents the asynchronous operation. The task result contains all edges.</returns>
+    Task<IEnumerable<GraphEdge<T>>> GetAllEdgesAsync();
+
+    /// <summary>
+    /// Asynchronously removes all nodes and edges from the graph.
+    /// </summary>
+    /// <returns>A task representing the asynchronous operation.</returns>
+    Task ClearAsync();
+}
diff --git a/src/Interfaces/IPredictionModelBuilder.cs b/src/Interfaces/IPredictionModelBuilder.cs
index bd2de66e3..0de672424 100644
--- a/src/Interfaces/IPredictionModelBuilder.cs
+++ b/src/Interfaces/IPredictionModelBuilder.cs
@@ -2,6 +2,7 @@
using AiDotNet.DistributedTraining;
using AiDotNet.Enums;
using AiDotNet.Models;
+using AiDotNet.RetrievalAugmentedGeneration.Graph;
namespace AiDotNet.Interfaces;
@@ -342,31 +343,52 @@ public interface IPredictionModelBuilder
/// Configures the retrieval-augmented generation (RAG) components for use during model inference.
///
///
+ ///
/// RAG enhances text generation by retrieving relevant documents from a knowledge base
/// and using them as context for generating grounded, factual answers.
- ///
+ ///
+ ///
+ /// Graph RAG: When graphStore or knowledgeGraph is provided, enables knowledge graph-based
+ /// retrieval that finds related entities and their relationships, providing richer context than
+ /// vector similarity alone. If documentStore is also provided, hybrid retrieval combines both
+ /// vector search and graph traversal.
+ ///
+ ///
/// For Beginners: RAG is like giving your AI access to a library before answering questions.
/// Instead of relying only on what it learned during training, it can:
- /// 1. Search a document collection for relevant information
- /// 2. Read the relevant documents
- /// 3. Generate an answer based on those documents
- /// 4. Cite its sources
- ///
- /// This makes answers more accurate, up-to-date, and traceable to source materials.
- ///
- /// RAG operations (GenerateAnswer, RetrieveDocuments) are performed during inference via PredictionModelResult,
- /// not during model building.
+ ///
+ /// - Search a document collection for relevant information
+ /// - Read the relevant documents
+ /// - Generate an answer based on those documents
+ /// - Cite its sources
+ ///
+ ///
+ ///
+ /// Graph RAG Example: If you ask about "Paris", Graph RAG can find not just documents
+ /// mentioning Paris, but also related concepts like France, Eiffel Tower, and Seine River
+ /// by traversing the knowledge graph.
+ ///
+ ///
+ /// RAG operations (GenerateAnswer, RetrieveDocuments, GraphQuery, etc.) are performed during
+ /// inference via PredictionModelResult, not during model building.
+ ///
///
- /// Optional retriever for finding relevant documents. If not provided, RAG won't be available.
+ /// Optional retriever for finding relevant documents. If not provided, standard RAG won't be available.
/// Optional reranker for improving document ranking quality. Default provided if retriever is set.
/// Optional generator for producing grounded answers. Default provided if retriever is set.
/// Optional query processors for improving search quality.
+ /// Optional graph storage backend for Graph RAG (e.g., MemoryGraphStore, FileGraphStore).
+ /// Optional pre-configured knowledge graph. If null but graphStore is provided, a new one is created.
+ /// Optional document store for hybrid vector + graph retrieval.
/// The builder instance for method chaining.
IPredictionModelBuilder ConfigureRetrievalAugmentedGeneration(
IRetriever? retriever = null,
IReranker? reranker = null,
IGenerator? generator = null,
- IEnumerable? queryProcessors = null);
+ IEnumerable? queryProcessors = null,
+ IGraphStore? graphStore = null,
+ KnowledgeGraph? knowledgeGraph = null,
+ IDocumentStore? documentStore = null);
///
/// Configures AI agent assistance during model building and inference.
diff --git a/src/Models/Results/PredictionModelResult.cs b/src/Models/Results/PredictionModelResult.cs
index 8289e1d6f..e5e2133fb 100644
--- a/src/Models/Results/PredictionModelResult.cs
+++ b/src/Models/Results/PredictionModelResult.cs
@@ -2,6 +2,7 @@
global using Formatting = Newtonsoft.Json.Formatting;
using AiDotNet.Data.Abstractions;
using AiDotNet.Interfaces;
+using AiDotNet.RetrievalAugmentedGeneration.Graph;
using AiDotNet.Interpretability;
using AiDotNet.Serialization;
using AiDotNet.Agents;
@@ -13,6 +14,7 @@
using AiDotNet.Deployment.Mobile.CoreML;
using AiDotNet.Deployment.Mobile.TensorFlowLite;
using AiDotNet.Deployment.Runtime;
+using AiDotNet.Enums;
namespace AiDotNet.Models.Results;
@@ -209,6 +211,53 @@ public class PredictionModelResult : IFullModelQuery processors for preprocessing queries, or null if not configured.
internal IEnumerable? QueryProcessors { get; private set; }
+    /// <summary>
+    /// Gets or sets the knowledge graph for graph-enhanced retrieval.
+    /// </summary>
+    /// <value>A knowledge graph containing entities and relationships, or null if Graph RAG is not configured.</value>
+    /// <remarks>
+    /// <para><b>For Beginners:</b> The knowledge graph stores entities (like people, places, concepts) and their
+    /// relationships. When you query the model, it can traverse these relationships to find related context
+    /// that pure vector similarity might miss.
+    /// </para>
+    /// <para>
+    /// This property is excluded from JSON serialization because it contains runtime infrastructure
+    /// (graph store, file handles) that should be reconfigured when the model is loaded.
+    /// </para>
+    /// </remarks>
+    [JsonIgnore]
+    internal KnowledgeGraph<T>? KnowledgeGraph { get; private set; }
+
+    /// <summary>
+    /// Gets or sets the graph store backend for persistent graph storage.
+    /// </summary>
+    /// <value>The graph storage backend, or null if Graph RAG is not configured.</value>
+    /// <remarks>
+    /// <para>
+    /// This property is excluded from JSON serialization because it represents runtime storage
+    /// infrastructure (file handles, WAL) that must be reconfigured when the model is loaded.
+    /// </para>
+    /// </remarks>
+    [JsonIgnore]
+    internal IGraphStore<T>? GraphStore { get; private set; }
+
+    /// <summary>
+    /// Gets or sets the hybrid graph retriever for combined vector + graph retrieval.
+    /// </summary>
+    /// <value>A hybrid retriever combining vector similarity with graph traversal, or null if not configured.</value>
+    /// <remarks>
+    /// <para><b>For Beginners:</b> The hybrid retriever first finds similar documents using vector search,
+    /// then expands the context by traversing the knowledge graph to find related entities. This provides
+    /// richer context than pure vector search alone.
+    /// </para>
+    /// <para>
+    /// This property is excluded from JSON serialization because it contains references to
+    /// runtime infrastructure (knowledge graph, document store) that must be reconfigured when loaded.
+    /// </para>
+    /// </remarks>
+    [JsonIgnore]
+    internal HybridGraphRetriever<T>? HybridGraphRetriever { get; private set; }
+
///
/// Gets or sets the meta-learner used for few-shot adaptation and fine-tuning.
///
@@ -427,6 +476,9 @@ public PredictionModelResult(IFullModel model,
/// Optional agent configuration used during model building.
/// Optional agent recommendations from model building.
/// Optional deployment configuration for export, caching, versioning, A/B testing, and telemetry.
+ /// Optional knowledge graph for graph-enhanced retrieval.
+ /// Optional graph store backend for persistent storage.
+ /// Optional hybrid retriever for combined vector + graph search.
public PredictionModelResult(OptimizationResult optimizationResult,
NormalizationInfo normalizationInfo,
IBiasDetector? biasDetector = null,
@@ -441,7 +493,10 @@ public PredictionModelResult(OptimizationResult optimization
AgentRecommendation? agentRecommendation = null,
DeploymentConfiguration? deploymentConfiguration = null,
Func[], Tensor[]>? jitCompiledFunction = null,
- AiDotNet.Configuration.InferenceOptimizationConfig? inferenceOptimizationConfig = null)
+ AiDotNet.Configuration.InferenceOptimizationConfig? inferenceOptimizationConfig = null,
+ KnowledgeGraph? knowledgeGraph = null,
+ IGraphStore? graphStore = null,
+ HybridGraphRetriever? hybridGraphRetriever = null)
{
Model = optimizationResult.BestSolution;
OptimizationResult = optimizationResult;
@@ -460,6 +515,9 @@ public PredictionModelResult(OptimizationResult optimization
DeploymentConfiguration = deploymentConfiguration;
JitCompiledFunction = jitCompiledFunction;
InferenceOptimizationConfig = inferenceOptimizationConfig;
+ KnowledgeGraph = knowledgeGraph;
+ GraphStore = graphStore;
+ HybridGraphRetriever = hybridGraphRetriever;
}
///
@@ -476,6 +534,9 @@ public PredictionModelResult(OptimizationResult optimization
/// Optional query processors for RAG query preprocessing.
/// Optional agent configuration for AI assistance during inference.
/// Optional deployment configuration for export, caching, versioning, A/B testing, and telemetry.
+ /// Optional knowledge graph for graph-enhanced retrieval.
+ /// Optional graph store backend for persistent storage.
+ /// Optional hybrid retriever for combined vector + graph search.
///
///
/// This constructor is used when a model has been trained using meta-learning (e.g., MAML, Reptile, SEAL).
@@ -513,7 +574,10 @@ public PredictionModelResult(
IGenerator? ragGenerator = null,
IEnumerable? queryProcessors = null,
AgentConfiguration? agentConfig = null,
- DeploymentConfiguration? deploymentConfiguration = null)
+ DeploymentConfiguration? deploymentConfiguration = null,
+ KnowledgeGraph? knowledgeGraph = null,
+ IGraphStore? graphStore = null,
+ HybridGraphRetriever? hybridGraphRetriever = null)
{
Model = metaLearner.BaseModel;
MetaLearner = metaLearner;
@@ -528,6 +592,9 @@ public PredictionModelResult(
QueryProcessors = queryProcessors;
AgentConfig = agentConfig;
DeploymentConfiguration = deploymentConfiguration;
+ KnowledgeGraph = knowledgeGraph;
+ GraphStore = graphStore;
+ HybridGraphRetriever = hybridGraphRetriever;
// Create placeholder OptimizationResult and NormalizationInfo for consistency
OptimizationResult = new OptimizationResult();
@@ -1046,7 +1113,7 @@ public IFullModel WithParameters(Vector parameters)
// Create new result with updated optimization result
// Preserve all configuration properties to ensure deployment behavior, model adaptation,
- // and training history are maintained across parameter updates
+ // training history, and Graph RAG configuration are maintained across parameter updates
return new PredictionModelResult(
updatedOptimizationResult,
NormalizationInfo,
@@ -1060,7 +1127,12 @@ public IFullModel WithParameters(Vector parameters)
crossValidationResult: CrossValidationResult,
agentConfig: AgentConfig,
agentRecommendation: AgentRecommendation,
- deploymentConfiguration: DeploymentConfiguration);
+ deploymentConfiguration: DeploymentConfiguration,
+ jitCompiledFunction: null, // JIT compilation is parameter-specific, don't copy
+ inferenceOptimizationConfig: InferenceOptimizationConfig,
+ knowledgeGraph: KnowledgeGraph,
+ graphStore: GraphStore,
+ hybridGraphRetriever: HybridGraphRetriever);
}
///
@@ -1147,7 +1219,7 @@ public IFullModel DeepCopy()
var clonedNormalizationInfo = NormalizationInfo.DeepCopy();
// Preserve all configuration properties to ensure deployment behavior, model adaptation,
- // and training history are maintained across deep copy
+ // training history, and Graph RAG configuration are maintained across deep copy
return new PredictionModelResult(
clonedOptimizationResult,
clonedNormalizationInfo,
@@ -1161,7 +1233,12 @@ public IFullModel DeepCopy()
crossValidationResult: CrossValidationResult,
agentConfig: AgentConfig,
agentRecommendation: AgentRecommendation,
- deploymentConfiguration: DeploymentConfiguration);
+ deploymentConfiguration: DeploymentConfiguration,
+ jitCompiledFunction: null, // JIT compilation is model-specific, don't copy
+ inferenceOptimizationConfig: InferenceOptimizationConfig,
+ knowledgeGraph: KnowledgeGraph,
+ graphStore: GraphStore,
+ hybridGraphRetriever: HybridGraphRetriever);
}
///
@@ -1585,6 +1662,208 @@ private string ProcessQueryWithProcessors(string query)
return processedQuery;
}
+    /// <summary>
+    /// Queries the knowledge graph to find related nodes by entity name or label.
+    /// </summary>
+    /// <param name="query">The search query (entity name or partial match).</param>
+    /// <param name="topK">Maximum number of results to return.</param>
+    /// <returns>Collection of matching graph nodes.</returns>
+    /// <exception cref="InvalidOperationException">Thrown when Graph RAG is not configured.</exception>
+    /// <remarks>
+    /// <para><b>For Beginners:</b> This method searches the knowledge graph for entities matching your query.
+    /// Unlike vector search which finds similar documents, this finds entities by name or label.
+    ///
+    /// Example:
+    /// <code>
+    /// var nodes = result.QueryKnowledgeGraph("Einstein", topK: 5);
+    /// foreach (var node in nodes)
+    /// {
+    ///     Console.WriteLine($"{node.Label}: {node.Id}");
+    /// }
+    /// </code>
+    /// </para>
+    /// </remarks>
+    public IEnumerable<GraphNode<T>> QueryKnowledgeGraph(string query, int topK = 10)
+    {
+        if (KnowledgeGraph == null)
+        {
+            throw new InvalidOperationException(
+                "Knowledge graph not configured. Configure Graph RAG using PredictionModelBuilder.ConfigureRetrievalAugmentedGeneration() before building the model.");
+        }
+
+        if (string.IsNullOrWhiteSpace(query))
+            throw new ArgumentException("Query cannot be null or empty", nameof(query));
+
+        return KnowledgeGraph.FindRelatedNodes(query, topK);
+    }
+
+    /// <summary>
+    /// Retrieves results using hybrid vector + graph search for enhanced context retrieval.
+    /// </summary>
+    /// <param name="queryEmbedding">The query embedding vector.</param>
+    /// <param name="topK">Number of initial candidates from vector search.</param>
+    /// <param name="expansionDepth">How many hops to traverse in the graph (0 = no expansion).</param>
+    /// <param name="maxResults">Maximum total results to return.</param>
+    /// <returns>List of retrieval results with scores and source information.</returns>
+    /// <exception cref="InvalidOperationException">Thrown when hybrid retriever is not configured.</exception>
+    /// <remarks>
+    /// <para><b>For Beginners:</b> This method combines the best of both worlds:
+    /// 1. First, it finds similar documents using vector similarity (like traditional RAG)
+    /// 2. Then, it expands the context by traversing the knowledge graph to find related entities
+    ///
+    /// For example, searching for "photosynthesis" might:
+    /// - Find documents about photosynthesis via vector search
+    /// - Then traverse the graph to also include chlorophyll, plants, carbon dioxide
+    ///
+    /// This provides richer, more complete context than vector search alone.
+    /// </para>
+    /// </remarks>
+    public List<GraphRagResult<T>> HybridRetrieve( // NOTE(review): element type name lost in extraction — confirm against HybridGraphRetriever<T>.Retrieve's return type
+        Vector<T> queryEmbedding,
+        int topK = 5,
+        int expansionDepth = 1,
+        int maxResults = 10)
+    {
+        if (HybridGraphRetriever == null)
+        {
+            throw new InvalidOperationException(
+                "Hybrid graph retriever not configured. Configure Graph RAG with a document store using PredictionModelBuilder.ConfigureRetrievalAugmentedGeneration() before building the model.");
+        }
+
+        if (queryEmbedding == null || queryEmbedding.Length == 0)
+            throw new ArgumentException("Query embedding cannot be null or empty", nameof(queryEmbedding));
+
+        return HybridGraphRetriever.Retrieve(queryEmbedding, topK, expansionDepth, maxResults);
+    }
+
+    /// <summary>
+    /// Traverses the knowledge graph starting from a node using breadth-first search.
+    /// </summary>
+    /// <param name="startNodeId">The ID of the starting node.</param>
+    /// <param name="maxDepth">Maximum traversal depth.</param>
+    /// <returns>Collection of nodes reachable from the starting node in BFS order.</returns>
+    /// <exception cref="InvalidOperationException">Thrown when knowledge graph is not configured.</exception>
+    /// <remarks>
+    /// <para><b>For Beginners:</b> This method explores the graph starting from a specific entity,
+    /// discovering all connected entities up to a specified depth.
+    ///
+    /// Example: Starting from "Paris", depth=2 might find:
+    /// - Depth 1: France, Eiffel Tower, Seine River
+    /// - Depth 2: Europe, Iron, Water
+    ///
+    /// This is useful for understanding the context around a specific entity.
+    /// </para>
+    /// </remarks>
+    public IEnumerable<GraphNode<T>> TraverseGraph(string startNodeId, int maxDepth = 2)
+    {
+        if (KnowledgeGraph == null)
+        {
+            throw new InvalidOperationException(
+                "Knowledge graph not configured. Configure Graph RAG using PredictionModelBuilder.ConfigureRetrievalAugmentedGeneration() before building the model.");
+        }
+
+        if (string.IsNullOrWhiteSpace(startNodeId))
+            throw new ArgumentException("Start node ID cannot be null or empty", nameof(startNodeId));
+
+        return KnowledgeGraph.BreadthFirstTraversal(startNodeId, maxDepth);
+    }
+
+    /// <summary>
+    /// Finds the shortest path between two nodes in the knowledge graph.
+    /// </summary>
+    /// <param name="startNodeId">The ID of the starting node.</param>
+    /// <param name="endNodeId">The ID of the target node.</param>
+    /// <returns>List of node IDs representing the path, or empty list if no path exists.</returns>
+    /// <exception cref="InvalidOperationException">Thrown when knowledge graph is not configured.</exception>
+    /// <remarks>
+    /// <para><b>For Beginners:</b> This method finds how two entities are connected.
+    ///
+    /// Example: Finding the path between "Einstein" and "Princeton University" might return:
+    /// ["einstein", "worked_at_princeton", "princeton_university"]
+    ///
+    /// This is useful for understanding the relationships between concepts.
+    /// </para>
+    /// </remarks>
+    public List<string> FindPathInGraph(string startNodeId, string endNodeId)
+    {
+        if (KnowledgeGraph == null)
+        {
+            throw new InvalidOperationException(
+                "Knowledge graph not configured. Configure Graph RAG using PredictionModelBuilder.ConfigureRetrievalAugmentedGeneration() before building the model.");
+        }
+
+        if (string.IsNullOrWhiteSpace(startNodeId))
+            throw new ArgumentException("Start node ID cannot be null or empty", nameof(startNodeId));
+        if (string.IsNullOrWhiteSpace(endNodeId))
+            throw new ArgumentException("End node ID cannot be null or empty", nameof(endNodeId));
+
+        return KnowledgeGraph.FindShortestPath(startNodeId, endNodeId);
+    }
+
+    /// <summary>
+    /// Gets all edges (relationships) connected to a node in the knowledge graph.
+    /// </summary>
+    /// <param name="nodeId">The ID of the node to query.</param>
+    /// <param name="direction">The direction of edges to retrieve.</param>
+    /// <returns>Collection of edges connected to the node.</returns>
+    /// <exception cref="InvalidOperationException">Thrown when knowledge graph is not configured.</exception>
+    /// <exception cref="ArgumentException">Thrown when nodeId is null or empty.</exception>
+    /// <remarks>
+    /// <para><b>For Beginners:</b> This method finds all relationships connected to an entity.
+    ///
+    /// Example: Getting edges for "Einstein" might return:
+    /// - Outgoing: STUDIED→Physics, WORKED_AT→Princeton, BORN_IN→Germany
+    /// - Incoming: INFLUENCED_BY→Newton
+    /// </para>
+    /// </remarks>
+    public IEnumerable<GraphEdge<T>> GetNodeRelationships(string nodeId, EdgeDirection direction = EdgeDirection.Both)
+    {
+        if (KnowledgeGraph == null)
+        {
+            throw new InvalidOperationException(
+                "Knowledge graph not configured. Configure Graph RAG using PredictionModelBuilder.ConfigureRetrievalAugmentedGeneration() before building the model.");
+        }
+
+        if (string.IsNullOrWhiteSpace(nodeId))
+            throw new ArgumentException("Node ID cannot be null or empty", nameof(nodeId));
+
+        var result = new List<GraphEdge<T>>();
+
+        if (direction == EdgeDirection.Outgoing || direction == EdgeDirection.Both)
+        {
+            result.AddRange(KnowledgeGraph.GetOutgoingEdges(nodeId));
+        }
+
+        if (direction == EdgeDirection.Incoming || direction == EdgeDirection.Both)
+        {
+            result.AddRange(KnowledgeGraph.GetIncomingEdges(nodeId));
+        }
+
+        return result;
+    }
+
+    /// <summary>
+    /// Attaches Graph RAG components to a PredictionModelResult instance.
+    /// </summary>
+    /// <param name="knowledgeGraph">The knowledge graph to attach.</param>
+    /// <param name="graphStore">The graph store backend to attach.</param>
+    /// <param name="hybridGraphRetriever">The hybrid retriever to attach.</param>
+    /// <remarks>
+    /// This method is internal and used by PredictionModelBuilder when loading/deserializing models.
+    /// Graph RAG components cannot be serialized (they contain file handles, WAL references, etc.),
+    /// so the builder automatically reattaches them when loading a model that was configured with Graph RAG.
+    /// Users should use PredictionModelBuilder.LoadModel() which handles this automatically.
+    /// </remarks>
+    internal void AttachGraphComponents(
+        KnowledgeGraph<T>? knowledgeGraph = null,
+        IGraphStore<T>? graphStore = null,
+        HybridGraphRetriever<T>? hybridGraphRetriever = null)
+    {
+        KnowledgeGraph = knowledgeGraph;
+        GraphStore = graphStore;
+        HybridGraphRetriever = hybridGraphRetriever;
+    }
+
///
/// Saves the prediction model result's current state to a stream.
///
diff --git a/src/PredictionModelBuilder.cs b/src/PredictionModelBuilder.cs
index 2edbf1c2a..83abd8b4e 100644
--- a/src/PredictionModelBuilder.cs
+++ b/src/PredictionModelBuilder.cs
@@ -17,6 +17,7 @@
global using AiDotNet.MixedPrecision;
global using AiDotNet.KnowledgeDistillation;
global using AiDotNet.Deployment.Configuration;
+global using AiDotNet.RetrievalAugmentedGeneration.Graph;
namespace AiDotNet;
@@ -54,6 +55,11 @@ public class PredictionModelBuilder : IPredictionModelBuilde
private IReranker? _ragReranker;
private IGenerator? _ragGenerator;
private IEnumerable? _queryProcessors;
+
+ // Graph RAG components for knowledge graph-enhanced retrieval
+ private KnowledgeGraph? _knowledgeGraph;
+ private IGraphStore? _graphStore;
+ private HybridGraphRetriever? _hybridGraphRetriever;
private IMetaLearner? _metaLearner;
private ICommunicationBackend? _distributedBackend;
private DistributedStrategy _distributedStrategy = DistributedStrategy.DDP;
@@ -565,7 +571,10 @@ public Task> BuildAsync()
ragGenerator: _ragGenerator,
queryProcessors: _queryProcessors,
agentConfig: _agentConfig,
- deploymentConfiguration: deploymentConfig);
+ deploymentConfiguration: deploymentConfig,
+ knowledgeGraph: _knowledgeGraph,
+ graphStore: _graphStore,
+ hybridGraphRetriever: _hybridGraphRetriever);
return Task.FromResult(result);
}
@@ -917,7 +926,10 @@ public async Task> BuildAsync(TInput x
agentRecommendation,
deploymentConfig,
jitCompiledFunction,
- _inferenceOptimizationConfig);
+ _inferenceOptimizationConfig,
+ _knowledgeGraph,
+ _graphStore,
+ _hybridGraphRetriever);
return finalResult;
}
@@ -1086,6 +1098,8 @@ public async Task> BuildAsync(int epis
_gpuAccelerationConfig);
// Return standard PredictionModelResult
+ // Note: This Build() overload doesn't perform JIT compilation (only the main Build() does),
+ // so jitCompiledFunction uses its default value of null
var result = new PredictionModelResult(
optimizationResult,
normInfo,
@@ -1099,7 +1113,11 @@ public async Task> BuildAsync(int epis
crossValidationResult: null,
_agentConfig,
agentRecommendation: null,
- deploymentConfig);
+ deploymentConfiguration: deploymentConfig,
+ inferenceOptimizationConfig: _inferenceOptimizationConfig,
+ knowledgeGraph: _knowledgeGraph,
+ graphStore: _graphStore,
+ hybridGraphRetriever: _hybridGraphRetriever);
return result;
}
@@ -1206,6 +1224,14 @@ public PredictionModelResult DeserializeModel(byte[] modelDa
var result = new PredictionModelResult();
result.Deserialize(modelData);
+ // Automatically reattach Graph RAG components if they were configured on this builder
+ // Graph RAG components cannot be serialized (file handles, WAL, etc.), so we reattach
+ // them from the builder's configuration to provide a seamless experience for users
+ if (_knowledgeGraph != null || _graphStore != null || _hybridGraphRetriever != null)
+ {
+ result.AttachGraphComponents(_knowledgeGraph, _graphStore, _hybridGraphRetriever);
+ }
+
return result;
}
@@ -1263,31 +1289,96 @@ public IPredictionModelBuilder ConfigureLoRA(ILoRAConfigurat
///
/// Configures the retrieval-augmented generation (RAG) components for use during model inference.
///
- /// Optional retriever for finding relevant documents. If not provided, RAG functionality won't be available.
+ /// Optional retriever for finding relevant documents. If not provided, standard RAG won't be available.
/// Optional reranker for improving document ranking quality. If not provided, a default reranker will be used if RAG is configured.
/// Optional generator for producing grounded answers. If not provided, a default generator will be used if RAG is configured.
/// Optional query processors for improving search quality.
+ /// Optional graph storage backend for Graph RAG (e.g., MemoryGraphStore, FileGraphStore).
+ /// Optional pre-configured knowledge graph. If null but graphStore is provided, a new one is created.
+ /// Optional document store for hybrid vector + graph retrieval.
/// This builder instance for method chaining.
///
+ ///
/// For Beginners: RAG combines retrieval and generation to create answers backed by real documents.
/// Configure it with:
- /// - A retriever (finds relevant documents from your collection) - required for RAG
- /// - A reranker (improves the ordering of retrieved documents) - optional, defaults provided
- /// - A generator (creates answers based on the documents) - optional, defaults provided
- /// - Optional query processors (improve search queries before retrieval)
- ///
+ ///
+ /// - A retriever (finds relevant documents from your collection) - required for standard RAG
+ /// - A reranker (improves the ordering of retrieved documents) - optional, defaults provided
+ /// - A generator (creates answers based on the documents) - optional, defaults provided
+ /// - Optional query processors (improve search queries before retrieval)
+ ///
+ ///
+ ///
+ /// Graph RAG: When graphStore or knowledgeGraph is provided, enables knowledge graph-based
+ /// retrieval that finds related entities and their relationships, providing richer context than
+ /// vector similarity alone. Traditional RAG finds similar documents using vectors. Graph RAG goes further by
+ /// also exploring relationships between entities. For example, if you ask about "Paris", it can find
+ /// not just documents mentioning Paris, but also related concepts like France, Eiffel Tower, and Seine River.
+ ///
+ ///
+ /// Hybrid Retrieval: When both knowledgeGraph and documentStore are provided, creates a
+ /// HybridGraphRetriever that combines vector search and graph traversal for optimal results.
+ ///
+ ///
+ /// Disabling RAG: Call with all parameters as null to disable RAG functionality completely.
+ ///
+ ///
/// RAG operations are performed during inference (after model training) via the PredictionModelResult.
+ ///
///
public IPredictionModelBuilder ConfigureRetrievalAugmentedGeneration(
IRetriever? retriever = null,
IReranker? reranker = null,
IGenerator? generator = null,
- IEnumerable? queryProcessors = null)
+ IEnumerable? queryProcessors = null,
+ IGraphStore? graphStore = null,
+ KnowledgeGraph? knowledgeGraph = null,
+ IDocumentStore? documentStore = null)
{
+ // Configure standard RAG components
_ragRetriever = retriever;
_ragReranker = reranker;
_ragGenerator = generator;
_queryProcessors = queryProcessors;
+
+ // Configure Graph RAG components
+ // If all Graph RAG parameters are null, clear Graph RAG fields
+ if (graphStore == null && knowledgeGraph == null && documentStore == null)
+ {
+ _graphStore = null;
+ _knowledgeGraph = null;
+ _hybridGraphRetriever = null;
+ return this;
+ }
+
+ _graphStore = graphStore;
+
+ // Use provided knowledge graph or create one from the store
+ if (knowledgeGraph != null)
+ {
+ _knowledgeGraph = knowledgeGraph;
+ }
+ else if (graphStore != null)
+ {
+ _knowledgeGraph = new KnowledgeGraph(graphStore);
+ }
+ else
+ {
+ // No knowledge graph source provided, clear the field
+ // NOTE(review): reaching here means only documentStore was supplied; the
+ // branches below then also clear the hybrid retriever, so a documentStore-only
+ // call behaves exactly like disabling Graph RAG - confirm that is intended.
+ _knowledgeGraph = null;
+ }
+
+ // Create or clear hybrid retriever based on available components
+ if (_knowledgeGraph != null && documentStore != null)
+ {
+ _hybridGraphRetriever = new HybridGraphRetriever(_knowledgeGraph, documentStore);
+ }
+ else
+ {
+ // Clear hybrid retriever if dependencies are missing
+ _hybridGraphRetriever = null;
+ }
+
return this;
}
diff --git a/src/RetrievalAugmentedGeneration/Graph/BTreeIndex.cs b/src/RetrievalAugmentedGeneration/Graph/BTreeIndex.cs
new file mode 100644
index 000000000..739216123
--- /dev/null
+++ b/src/RetrievalAugmentedGeneration/Graph/BTreeIndex.cs
@@ -0,0 +1,319 @@
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Text;
+
+namespace AiDotNet.RetrievalAugmentedGeneration.Graph;
+
+///
+/// Simple file-based index for mapping string keys to file offsets.
+///
+///
+///
+/// This class provides a persistent index structure that maps string keys (e.g., node IDs)
+/// to byte offsets in data files. The index is stored on disk and reloaded on restart,
+/// enabling fast lookups without scanning entire data files.
+///
+/// For Beginners: Think of this like a book's index at the back.
+///
+/// Without an index:
+/// - To find "photosynthesis", you'd read every page from start to finish
+/// - Very slow for large books (or large data files)
+///
+/// With an index:
+/// - Look up "photosynthesis" → find it's on page 157
+/// - Jump directly to page 157
+/// - Much faster!
+///
+/// This class does the same for graph data:
+/// - Key: "node_alice_001"
+/// - Value: byte offset 45678 in nodes.dat file
+/// - We can jump directly to byte 45678 to read Alice's data
+///
+/// The index itself is stored in a file so it survives application restarts.
+///
+///
+/// Implementation Note: This is a simplified index using a sorted dictionary.
+/// For production systems with millions of entries, consider implementing a true
+/// B-Tree structure with splitting/merging nodes, or use an embedded database like
+/// SQLite or LevelDB.
+///
+///
+public class BTreeIndex : IDisposable
+{
+    // On-disk format: one "key|offset" line per entry. '|' separates the key from
+    // the offset and newlines separate entries, so neither may appear in a key
+    // (enforced in Add; tolerated on read for legacy files).
+    private readonly string _indexFilePath;
+    private readonly SortedDictionary<string, long> _index;
+    private bool _isDirty;
+    private bool _disposed;
+
+    /// <summary>
+    /// Gets the number of entries in the index.
+    /// </summary>
+    public int Count => _index.Count;
+
+    /// <summary>
+    /// Initializes a new instance of the <see cref="BTreeIndex"/> class and loads
+    /// any previously persisted entries from disk.
+    /// </summary>
+    /// <param name="indexFilePath">The path to the index file on disk.</param>
+    /// <exception cref="ArgumentNullException">Thrown when indexFilePath is null.</exception>
+    public BTreeIndex(string indexFilePath)
+    {
+        _indexFilePath = indexFilePath ?? throw new ArgumentNullException(nameof(indexFilePath));
+        _index = new SortedDictionary<string, long>();
+        _isDirty = false;
+
+        LoadFromDisk();
+    }
+
+    /// <summary>
+    /// Adds or updates a key-offset mapping in the index.
+    /// </summary>
+    /// <param name="key">The key to index (e.g., node ID). Must not contain '|' or line breaks.</param>
+    /// <param name="offset">The byte offset in the data file.</param>
+    /// <exception cref="ArgumentException">
+    /// Thrown when the key is null/whitespace or contains a reserved character, or when the offset is negative.
+    /// </exception>
+    public void Add(string key, long offset)
+    {
+        if (string.IsNullOrWhiteSpace(key))
+            throw new ArgumentException("Key cannot be null or whitespace", nameof(key));
+        // BUG FIX: keys containing the on-disk field separator or a line break used to
+        // be accepted here, then silently skipped or mis-parsed by LoadFromDisk after a
+        // restart (data loss). Reject them up front so the failure surfaces at write time.
+        if (key.IndexOf('|') >= 0 || key.IndexOf('\n') >= 0 || key.IndexOf('\r') >= 0)
+            throw new ArgumentException("Key cannot contain '|' or line-break characters", nameof(key));
+        if (offset < 0)
+            throw new ArgumentException("Offset cannot be negative", nameof(offset));
+
+        _index[key] = offset;
+        _isDirty = true;   // persisted on the next Flush()/Dispose()
+    }
+
+    /// <summary>
+    /// Retrieves the file offset associated with a key.
+    /// </summary>
+    /// <param name="key">The key to look up.</param>
+    /// <returns>The byte offset if found; otherwise, -1.</returns>
+    public long Get(string key)
+    {
+        if (string.IsNullOrWhiteSpace(key))
+            return -1;
+
+        return _index.TryGetValue(key, out var offset) ? offset : -1;
+    }
+
+    /// <summary>
+    /// Checks if a key exists in the index.
+    /// </summary>
+    /// <param name="key">The key to check.</param>
+    /// <returns>True if the key exists; otherwise, false.</returns>
+    public bool Contains(string key)
+    {
+        return !string.IsNullOrWhiteSpace(key) && _index.ContainsKey(key);
+    }
+
+    /// <summary>
+    /// Removes a key from the index. The underlying data file is not modified -
+    /// only the mapping is dropped, leaving the old record as unreachable garbage.
+    /// </summary>
+    /// <param name="key">The key to remove.</param>
+    /// <returns>True if the key was found and removed; otherwise, false.</returns>
+    public bool Remove(string key)
+    {
+        if (string.IsNullOrWhiteSpace(key))
+            return false;
+
+        var removed = _index.Remove(key);
+        if (removed)
+            _isDirty = true;
+
+        return removed;
+    }
+
+    /// <summary>
+    /// Gets a snapshot of all keys in the index, in sorted order.
+    /// </summary>
+    /// <returns>Collection of all keys.</returns>
+    public IEnumerable<string> GetAllKeys()
+    {
+        // ToList() snapshots the keys so callers can mutate the index while iterating.
+        return _index.Keys.ToList();
+    }
+
+    /// <summary>
+    /// Removes all entries from the index.
+    /// </summary>
+    public void Clear()
+    {
+        _index.Clear();
+        _isDirty = true;
+    }
+
+    /// <summary>
+    /// Saves the index to disk if it has been modified.
+    /// Format: one "key|offset" line per entry, UTF-8 encoded.
+    /// The file is written to a temporary path and swapped in, so a crash
+    /// mid-write never leaves a half-written index behind.
+    /// </summary>
+    /// <exception cref="IOException">Thrown when the index cannot be written.</exception>
+    public void Flush()
+    {
+        if (!_isDirty)
+            return;
+
+        try
+        {
+            // Ensure directory exists
+            var directory = Path.GetDirectoryName(_indexFilePath);
+            if (!string.IsNullOrEmpty(directory) && !Directory.Exists(directory))
+                Directory.CreateDirectory(directory);
+
+            // Write index to temporary file first (atomic write)
+            var tempPath = _indexFilePath + ".tmp";
+            using (var writer = new StreamWriter(tempPath, false, Encoding.UTF8))
+            {
+                foreach (var kvp in _index)
+                {
+                    writer.WriteLine($"{kvp.Key}|{kvp.Value}");
+                }
+            }
+
+            if (File.Exists(_indexFilePath))
+            {
+                // File.Replace swaps the files and keeps the previous content as a
+                // backup; delete the backup once the swap has succeeded.
+                var backupPath = _indexFilePath + ".bak";
+                File.Replace(tempPath, _indexFilePath, backupPath);
+                if (File.Exists(backupPath))
+                    File.Delete(backupPath);
+            }
+            else
+            {
+                File.Move(tempPath, _indexFilePath);
+            }
+
+            _isDirty = false;
+        }
+        catch (IOException ex)
+        {
+            throw new IOException($"Failed to flush index to disk: {_indexFilePath}", ex);
+        }
+        catch (UnauthorizedAccessException ex)
+        {
+            throw new IOException($"Failed to flush index to disk: {_indexFilePath}", ex);
+        }
+    }
+
+    /// <summary>
+    /// Loads the index from disk if it exists.
+    /// </summary>
+    /// <exception cref="IOException">Thrown when the index file cannot be read.</exception>
+    private void LoadFromDisk()
+    {
+        if (!File.Exists(_indexFilePath))
+            return;
+
+        try
+        {
+            using var reader = new StreamReader(_indexFilePath, Encoding.UTF8);
+            string? line;
+            while ((line = reader.ReadLine()) != null)
+            {
+                // BUG FIX: split on the LAST '|' so legacy entries whose keys contain
+                // the separator (written before Add() rejected them) are recovered
+                // instead of being silently skipped by a strict two-part split.
+                var sep = line.LastIndexOf('|');
+                if (sep <= 0 || sep == line.Length - 1)
+                    continue;   // malformed line - skip
+
+                var key = line.Substring(0, sep);
+                if (long.TryParse(line.Substring(sep + 1), out var offset))
+                {
+                    _index[key] = offset;
+                }
+            }
+
+            _isDirty = false;
+        }
+        catch (IOException ex)
+        {
+            throw new IOException($"Failed to load index from disk: {_indexFilePath}", ex);
+        }
+        catch (UnauthorizedAccessException ex)
+        {
+            throw new IOException($"Failed to load index from disk: {_indexFilePath}", ex);
+        }
+    }
+
+    /// <summary>
+    /// Disposes the index, ensuring all changes are saved to disk.
+    /// </summary>
+    public void Dispose()
+    {
+        Dispose(true);
+        GC.SuppressFinalize(this);
+    }
+
+    /// <summary>
+    /// Releases resources used by the index.
+    /// </summary>
+    /// <param name="disposing">True if called from Dispose(), false if called from a finalizer.</param>
+    protected virtual void Dispose(bool disposing)
+    {
+        if (_disposed)
+            return;
+
+        try
+        {
+            if (disposing)
+            {
+                // Persist any pending changes before tearing down.
+                Flush();
+            }
+        }
+        finally
+        {
+            // Ensure _disposed is set even if Flush throws
+            _disposed = true;
+        }
+    }
+}
diff --git a/src/RetrievalAugmentedGeneration/Graph/FileGraphStore.cs b/src/RetrievalAugmentedGeneration/Graph/FileGraphStore.cs
new file mode 100644
index 000000000..deb172b83
--- /dev/null
+++ b/src/RetrievalAugmentedGeneration/Graph/FileGraphStore.cs
@@ -0,0 +1,1011 @@
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using AiDotNet.Interfaces;
+using Newtonsoft.Json;
+
+namespace AiDotNet.RetrievalAugmentedGeneration.Graph;
+
+///
+/// File-based implementation of with persistent storage on disk.
+///
+/// The numeric type used for vector operations.
+///
+///
+/// This implementation provides persistent graph storage using files:
+/// - nodes.dat: Binary file containing serialized nodes
+/// - edges.dat: Binary file containing serialized edges
+/// - node_index.db: B-Tree index mapping node IDs to file offsets
+/// - edge_index.db: B-Tree index mapping edge IDs to file offsets
+///
+/// For Beginners: This stores your graph on disk so it survives restarts.
+///
+/// How it works:
+/// 1. When you add a node, it's written to nodes.dat
+/// 2. The position (offset) is recorded in node_index.db
+/// 3. To retrieve a node, we look up its offset and read from that position
+/// 4. Everything is saved to disk automatically
+///
+/// Pros:
+/// - 💾 Data persists across restarts
+/// - 🔄 Can handle graphs larger than RAM
+/// - 📁 Simple file-based storage (no database required)
+///
+/// Cons:
+/// - 🐌 Slower than in-memory (disk I/O overhead)
+/// - 🔒 Not suitable for concurrent access from multiple processes
+/// - 📦 No compression (files can be large)
+///
+/// Good for:
+/// - Applications that need to save graph state
+/// - Graphs up to a few million nodes
+/// - Single-process applications
+///
+/// Not good for:
+/// - Real-time systems requiring sub-millisecond latency
+/// - Multi-process concurrent access
+/// - Distributed systems (use Neo4j or similar instead)
+///
+///
+public class FileGraphStore : IGraphStore, IDisposable
+{
+ private readonly string _storageDirectory;
+ private readonly string _nodesFilePath;
+ private readonly string _edgesFilePath;
+ private readonly BTreeIndex _nodeIndex;
+ private readonly BTreeIndex _edgeIndex;
+ private readonly WriteAheadLog? _wal;
+
+ // In-memory caches for indices and metadata (thread-safe for concurrent async access)
+ private readonly ConcurrentDictionary> _outgoingEdges; // nodeId -> edge IDs
+ private readonly ConcurrentDictionary> _incomingEdges; // nodeId -> edge IDs
+ private readonly ConcurrentDictionary> _nodesByLabel; // label -> node IDs
+ private readonly object _cacheLock = new object(); // Lock for modifying HashSet contents
+
+ private readonly JsonSerializerSettings _jsonSettings;
+ private bool _disposed;
+
+ ///
+ public int NodeCount => _nodeIndex.Count;
+
+ ///
+ public int EdgeCount => _edgeIndex.Count;
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The directory where graph data files will be stored.
+ /// Optional Write-Ahead Log for ACID transactions and crash recovery.
+ public FileGraphStore(string storageDirectory, WriteAheadLog? wal = null)
+ {
+ if (string.IsNullOrWhiteSpace(storageDirectory))
+ throw new ArgumentException("Storage directory cannot be null or whitespace", nameof(storageDirectory));
+
+ _storageDirectory = storageDirectory;
+ _nodesFilePath = Path.Combine(storageDirectory, "nodes.dat");
+ _edgesFilePath = Path.Combine(storageDirectory, "edges.dat");
+ // When no WAL is supplied, mutations are simply not journaled (best effort).
+ _wal = wal;
+
+ // Create directory if it doesn't exist
+ if (!Directory.Exists(storageDirectory))
+ Directory.CreateDirectory(storageDirectory);
+
+ // Initialize indices
+ // BTreeIndex loads its persisted entries from disk in its constructor.
+ _nodeIndex = new BTreeIndex(Path.Combine(storageDirectory, "node_index.db"));
+ _edgeIndex = new BTreeIndex(Path.Combine(storageDirectory, "edge_index.db"));
+
+ // Initialize in-memory structures (thread-safe)
+ _outgoingEdges = new ConcurrentDictionary>();
+ _incomingEdges = new ConcurrentDictionary>();
+ _nodesByLabel = new ConcurrentDictionary>();
+
+ _jsonSettings = new JsonSerializerSettings
+ {
+ Formatting = Formatting.None
+ };
+
+ // Rebuild in-memory indices from persisted data
+ // NOTE(review): this reads every node and edge record from disk, so construction
+ // cost grows with graph size - confirm acceptable for the intended graph sizes.
+ RebuildInMemoryIndices();
+ }
+
+ /// <inheritdoc/>
+ public void AddNode(GraphNode node)
+ {
+     if (node == null)
+         throw new ArgumentNullException(nameof(node));
+
+     try
+     {
+         // Log to WAL first (durability)
+         _wal?.LogAddNode(node);
+
+         // Serialize node to JSON (length-prefixed record appended to nodes.dat)
+         var json = JsonConvert.SerializeObject(node, _jsonSettings);
+         var bytes = Encoding.UTF8.GetBytes(json);
+
+         long offset;
+         using (var stream = new FileStream(_nodesFilePath, FileMode.Append, FileAccess.Write, FileShare.Read))
+         {
+             // BUG FIX: the offset was previously taken from FileInfo.Length before the
+             // stream was opened, so anything appended in between corrupted the index.
+             // FileMode.Append positions the stream at end-of-file, so the stream's own
+             // Position is the exact offset the record is written at.
+             offset = stream.Position;
+
+             // Record format: [4-byte length prefix][UTF-8 JSON payload]
+             var lengthBytes = BitConverter.GetBytes(bytes.Length);
+             stream.Write(lengthBytes, 0, 4);
+             stream.Write(bytes, 0, bytes.Length);
+         }
+
+         // Updates append a fresh record; the old record becomes unreachable garbage
+         // (compaction is out of scope for this store).
+         _nodeIndex.Add(node.Id, offset);
+
+         // Update in-memory indices (thread-safe)
+         lock (_cacheLock)
+         {
+             // NOTE(review): if this is an update and node.Label changed, the old
+             // label's entry is not removed here, so GetNodesByLabel can still return
+             // the node under its stale label. Fixing requires reading the old record.
+             var labelSet = _nodesByLabel.GetOrAdd(node.Label, _ => new HashSet());
+             labelSet.Add(node.Id);
+
+             _outgoingEdges.GetOrAdd(node.Id, _ => new HashSet());
+             _incomingEdges.GetOrAdd(node.Id, _ => new HashSet());
+         }
+
+         // Flush indices periodically (every 100 operations for performance)
+         if (_nodeIndex.Count % 100 == 0)
+             _nodeIndex.Flush();
+     }
+     catch (IOException ex)
+     {
+         throw new IOException($"Failed to add node '{node.Id}' to file store", ex);
+     }
+     catch (UnauthorizedAccessException ex)
+     {
+         throw new IOException($"Failed to add node '{node.Id}' to file store due to unauthorized access", ex);
+     }
+     catch (JsonSerializationException ex)
+     {
+         throw new IOException($"Failed to serialize node '{node.Id}' to JSON", ex);
+     }
+ }
+
+ /// <inheritdoc/>
+ public void AddEdge(GraphEdge edge)
+ {
+     if (edge == null)
+         throw new ArgumentNullException(nameof(edge));
+     if (!_nodeIndex.Contains(edge.SourceId))
+         throw new InvalidOperationException($"Source node '{edge.SourceId}' does not exist");
+     if (!_nodeIndex.Contains(edge.TargetId))
+         throw new InvalidOperationException($"Target node '{edge.TargetId}' does not exist");
+
+     try
+     {
+         // Log to WAL first (durability)
+         _wal?.LogAddEdge(edge);
+
+         // Serialize edge to JSON (length-prefixed record appended to edges.dat)
+         var json = JsonConvert.SerializeObject(edge, _jsonSettings);
+         var bytes = Encoding.UTF8.GetBytes(json);
+
+         long offset;
+         using (var stream = new FileStream(_edgesFilePath, FileMode.Append, FileAccess.Write, FileShare.Read))
+         {
+             // BUG FIX: the offset was previously taken from FileInfo.Length before the
+             // stream was opened, which races any interleaved append. FileMode.Append
+             // positions the stream at end-of-file, so Position is the true write offset.
+             offset = stream.Position;
+
+             // Record format: [4-byte length prefix][UTF-8 JSON payload]
+             var lengthBytes = BitConverter.GetBytes(bytes.Length);
+             stream.Write(lengthBytes, 0, 4);
+             stream.Write(bytes, 0, bytes.Length);
+         }
+
+         _edgeIndex.Add(edge.Id, offset);
+
+         // Update in-memory adjacency indices (thread-safe)
+         lock (_cacheLock)
+         {
+             if (_outgoingEdges.TryGetValue(edge.SourceId, out var outgoingSet))
+                 outgoingSet.Add(edge.Id);
+             if (_incomingEdges.TryGetValue(edge.TargetId, out var incomingSet))
+                 incomingSet.Add(edge.Id);
+         }
+
+         // Flush indices periodically (amortizes index-file writes)
+         if (_edgeIndex.Count % 100 == 0)
+             _edgeIndex.Flush();
+     }
+     catch (IOException ex)
+     {
+         throw new IOException($"Failed to add edge '{edge.Id}' to file store", ex);
+     }
+     catch (UnauthorizedAccessException ex)
+     {
+         throw new IOException($"Failed to add edge '{edge.Id}' to file store due to access permissions", ex);
+     }
+     catch (JsonSerializationException ex)
+     {
+         throw new IOException($"Failed to serialize edge '{edge.Id}' to JSON", ex);
+     }
+ }
+
+ /// <inheritdoc/>
+ public GraphNode? GetNode(string nodeId)
+ {
+     if (string.IsNullOrWhiteSpace(nodeId))
+         return null;
+
+     var position = _nodeIndex.Get(nodeId);
+     if (position < 0)
+         return null;   // not indexed => not present
+
+     try
+     {
+         // Record format: [4-byte little-endian length][UTF-8 JSON payload]
+         using var file = new FileStream(_nodesFilePath, FileMode.Open, FileAccess.Read, FileShare.Read);
+         file.Seek(position, SeekOrigin.Begin);
+
+         var header = new byte[4];
+         ReadExactly(file, header, 0, 4);
+         var payloadLength = BitConverter.ToInt32(header, 0);
+
+         var payload = new byte[payloadLength];
+         ReadExactly(file, payload, 0, payloadLength);
+
+         var json = Encoding.UTF8.GetString(payload);
+         return JsonConvert.DeserializeObject>(json, _jsonSettings);
+     }
+     catch (EndOfStreamException ex)
+     {
+         throw new IOException($"Failed to read node '{nodeId}' from file store - data may be corrupted", ex);
+     }
+     catch (IOException ex)
+     {
+         throw new IOException($"Failed to read node '{nodeId}' from file store", ex);
+     }
+     catch (UnauthorizedAccessException ex)
+     {
+         throw new IOException($"Failed to read node '{nodeId}' from file store", ex);
+     }
+     catch (JsonSerializationException ex)
+     {
+         throw new IOException($"Failed to deserialize node '{nodeId}' from JSON", ex);
+     }
+ }
+
+ /// <inheritdoc/>
+ public GraphEdge? GetEdge(string edgeId)
+ {
+     if (string.IsNullOrWhiteSpace(edgeId))
+         return null;
+
+     var position = _edgeIndex.Get(edgeId);
+     if (position < 0)
+         return null;   // not indexed => not present
+
+     try
+     {
+         // Record format: [4-byte little-endian length][UTF-8 JSON payload]
+         using var file = new FileStream(_edgesFilePath, FileMode.Open, FileAccess.Read, FileShare.Read);
+         file.Seek(position, SeekOrigin.Begin);
+
+         var header = new byte[4];
+         ReadExactly(file, header, 0, 4);
+         var payloadLength = BitConverter.ToInt32(header, 0);
+
+         var payload = new byte[payloadLength];
+         ReadExactly(file, payload, 0, payloadLength);
+
+         var json = Encoding.UTF8.GetString(payload);
+         return JsonConvert.DeserializeObject>(json, _jsonSettings);
+     }
+     catch (EndOfStreamException ex)
+     {
+         throw new IOException($"Failed to read edge '{edgeId}' from file store - data may be corrupted", ex);
+     }
+     catch (IOException ex)
+     {
+         throw new IOException($"Failed to read edge '{edgeId}' from file store", ex);
+     }
+     catch (UnauthorizedAccessException ex)
+     {
+         throw new IOException($"Failed to read edge '{edgeId}' from file store", ex);
+     }
+     catch (JsonSerializationException ex)
+     {
+         throw new IOException($"Failed to deserialize edge '{edgeId}' from JSON", ex);
+     }
+ }
+
+ ///
+ public bool RemoveNode(string nodeId)
+ {
+ if (string.IsNullOrWhiteSpace(nodeId) || !_nodeIndex.Contains(nodeId))
+ return false;
+
+ try
+ {
+ // The full record is needed so the label index can be cleaned up below.
+ var node = GetNode(nodeId);
+ if (node == null)
+ return false;
+
+ // Log to WAL first (durability)
+ // NOTE(review): the node removal is journaled before the cascaded edge removals
+ // below (each of which RemoveEdge journals itself) - confirm the WAL replay
+ // order tolerates "remove node" preceding its edge removals.
+ _wal?.LogRemoveNode(nodeId);
+
+ // Remove all outgoing edges (thread-safe)
+ // Snapshot first: RemoveEdge mutates these sets under the same lock.
+ if (_outgoingEdges.TryGetValue(nodeId, out var outgoing))
+ {
+ List edgesToRemove;
+ lock (_cacheLock)
+ {
+ edgesToRemove = outgoing.ToList();
+ }
+ foreach (var edgeId in edgesToRemove)
+ {
+ RemoveEdge(edgeId);
+ }
+ _outgoingEdges.TryRemove(nodeId, out _);
+ }
+
+ // Remove all incoming edges (thread-safe)
+ if (_incomingEdges.TryGetValue(nodeId, out var incoming))
+ {
+ List edgesToRemove;
+ lock (_cacheLock)
+ {
+ edgesToRemove = incoming.ToList();
+ }
+ foreach (var edgeId in edgesToRemove)
+ {
+ RemoveEdge(edgeId);
+ }
+ _incomingEdges.TryRemove(nodeId, out _);
+ }
+
+ // Remove from label index (thread-safe)
+ lock (_cacheLock)
+ {
+ if (_nodesByLabel.TryGetValue(node.Label, out var nodeIds))
+ {
+ nodeIds.Remove(nodeId);
+ if (nodeIds.Count == 0)
+ _nodesByLabel.TryRemove(node.Label, out _);
+ }
+ }
+
+ // Remove from node index (marks as deleted, actual data remains)
+ // The record in nodes.dat becomes unreachable garbage; no compaction here.
+ _nodeIndex.Remove(nodeId);
+ _nodeIndex.Flush();
+
+ return true;
+ }
+ catch (IOException ex)
+ {
+ throw new IOException($"Failed to remove node '{nodeId}' from file store", ex);
+ }
+ catch (UnauthorizedAccessException ex)
+ {
+ throw new IOException($"Failed to remove node '{nodeId}' from file store", ex);
+ }
+ }
+
+ ///
+ public bool RemoveEdge(string edgeId)
+ {
+ if (string.IsNullOrWhiteSpace(edgeId) || !_edgeIndex.Contains(edgeId))
+ return false;
+
+ try
+ {
+ // Read the record so source/target adjacency sets can be cleaned up below.
+ var edge = GetEdge(edgeId);
+ if (edge == null)
+ return false;
+
+ // Log to WAL first (durability)
+ _wal?.LogRemoveEdge(edgeId);
+
+ // Remove from in-memory indices (thread-safe)
+ lock (_cacheLock)
+ {
+ if (_outgoingEdges.TryGetValue(edge.SourceId, out var outgoing))
+ outgoing.Remove(edgeId);
+ if (_incomingEdges.TryGetValue(edge.TargetId, out var incoming))
+ incoming.Remove(edgeId);
+ }
+
+ // Remove from edge index
+ // The record in edges.dat stays behind as unreachable garbage (no compaction).
+ _edgeIndex.Remove(edgeId);
+ _edgeIndex.Flush();
+
+ return true;
+ }
+ catch (IOException ex)
+ {
+ throw new IOException($"Failed to remove edge '{edgeId}' from file store", ex);
+ }
+ catch (UnauthorizedAccessException ex)
+ {
+ throw new IOException($"Failed to remove edge '{edgeId}' from file store", ex);
+ }
+ }
+
+ /// <inheritdoc/>
+ public IEnumerable> GetOutgoingEdges(string nodeId)
+ {
+     if (!_outgoingEdges.TryGetValue(nodeId, out var edgeIds))
+         return Enumerable.Empty>();
+
+     // Copy the IDs while holding the lock so a concurrent writer cannot
+     // mutate the set mid-enumeration.
+     string[] ids;
+     lock (_cacheLock)
+     {
+         ids = edgeIds.ToArray();
+     }
+
+     // Resolve lazily from disk; IDs whose records are missing yield null
+     // from GetEdge and are filtered out.
+     return ids.Select(GetEdge).OfType>();
+ }
+
+ /// <inheritdoc/>
+ public IEnumerable> GetIncomingEdges(string nodeId)
+ {
+     if (!_incomingEdges.TryGetValue(nodeId, out var edgeIds))
+         return Enumerable.Empty>();
+
+     // Copy the IDs while holding the lock so a concurrent writer cannot
+     // mutate the set mid-enumeration.
+     string[] ids;
+     lock (_cacheLock)
+     {
+         ids = edgeIds.ToArray();
+     }
+
+     // Resolve lazily from disk; missing records are filtered out.
+     return ids.Select(GetEdge).OfType>();
+ }
+
+ /// <inheritdoc/>
+ public IEnumerable> GetNodesByLabel(string label)
+ {
+     if (!_nodesByLabel.TryGetValue(label, out var nodeIds))
+         return Enumerable.Empty>();
+
+     // Copy the IDs while holding the lock so a concurrent writer cannot
+     // mutate the set mid-enumeration.
+     string[] ids;
+     lock (_cacheLock)
+     {
+         ids = nodeIds.ToArray();
+     }
+
+     // Resolve lazily from disk; missing records are filtered out.
+     return ids.Select(GetNode).OfType>();
+ }
+
+ /// <inheritdoc/>
+ public IEnumerable> GetAllNodes()
+ {
+     // Lazily materializes every node from disk; entries whose records cannot be
+     // resolved yield null from GetNode and are filtered out.
+     return _nodeIndex.GetAllKeys().Select(GetNode).OfType>();
+ }
+
+ /// <inheritdoc/>
+ public IEnumerable> GetAllEdges()
+ {
+     // Lazily materializes every edge from disk; unresolvable entries yield null
+     // from GetEdge and are filtered out.
+     return _edgeIndex.GetAllKeys().Select(GetEdge).OfType>();
+ }
+
+ ///
+ public void Clear()
+ {
+ // NOTE(review): unlike the other mutating operations, Clear() is not journaled
+ // to the WAL - confirm that a crash-recovery replay cannot resurrect data that
+ // was cleared here.
+ try
+ {
+ // Clear in-memory structures
+ _outgoingEdges.Clear();
+ _incomingEdges.Clear();
+ _nodesByLabel.Clear();
+
+ // Clear indices
+ _nodeIndex.Clear();
+ _edgeIndex.Clear();
+ _nodeIndex.Flush();
+ _edgeIndex.Flush();
+
+ // Delete data files
+ if (File.Exists(_nodesFilePath))
+ File.Delete(_nodesFilePath);
+ if (File.Exists(_edgesFilePath))
+ File.Delete(_edgesFilePath);
+ }
+ catch (IOException ex)
+ {
+ throw new IOException("Failed to clear file store", ex);
+ }
+ catch (UnauthorizedAccessException ex)
+ {
+ throw new IOException("Failed to clear file store", ex);
+ }
+ }
+
+ ///
+ /// Rebuilds in-memory indices by scanning all nodes and edges.
+ ///
+ ///
+ /// This is called during initialization to restore the in-memory indices
+ /// from persisted data. It scans all nodes to rebuild label indices and
+ /// all edges to rebuild outgoing/incoming edge indices.
+ /// Cost is one disk read per node and per edge, so startup time grows
+ /// linearly with graph size.
+ ///
+ private void RebuildInMemoryIndices()
+ {
+ try
+ {
+ // Rebuild node-related indices (thread-safe using GetOrAdd)
+ // Nodes must be processed first so the adjacency sets exist before edges
+ // are linked into them below.
+ foreach (var node in _nodeIndex.GetAllKeys().Select(GetNode).OfType>())
+ {
+ // Rebuild label index
+ var labelSet = _nodesByLabel.GetOrAdd(node.Label, _ => new HashSet());
+ lock (_cacheLock)
+ {
+ labelSet.Add(node.Id);
+ }
+
+ // Initialize edge indices
+ _outgoingEdges.GetOrAdd(node.Id, _ => new HashSet());
+ _incomingEdges.GetOrAdd(node.Id, _ => new HashSet());
+ }
+
+ // Rebuild edge indices (thread-safe)
+ foreach (var edge in _edgeIndex.GetAllKeys().Select(GetEdge).OfType>())
+ {
+ lock (_cacheLock)
+ {
+ if (_outgoingEdges.TryGetValue(edge.SourceId, out var outgoingSet))
+ outgoingSet.Add(edge.Id);
+ if (_incomingEdges.TryGetValue(edge.TargetId, out var incomingSet))
+ incomingSet.Add(edge.Id);
+ }
+ }
+ }
+ catch (IOException ex)
+ {
+ throw new IOException("Failed to rebuild in-memory indices", ex);
+ }
+ catch (InvalidDataException ex)
+ {
+ throw new IOException("Failed to rebuild in-memory indices", ex);
+ }
+ }
+
+ // Async methods for non-blocking I/O operations
+
/// <inheritdoc />
/// <summary>
/// Asynchronously appends a node to the node data file and updates all indices.
/// </summary>
/// <param name="node">The node to persist. Must not be null.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="node"/> is null.</exception>
/// <exception cref="IOException">Thrown when the write fails, access is denied, or serialization fails.</exception>
public async Task AddNodeAsync(GraphNode<T> node)
{
    if (node == null)
        throw new ArgumentNullException(nameof(node));

    try
    {
        // Log to WAL first so the operation is durable before the data write.
        _wal?.LogAddNode(node);

        // Serialize node to JSON.
        var json = JsonConvert.SerializeObject(node, _jsonSettings);
        var bytes = Encoding.UTF8.GetBytes(json);

        long offset;
        using (var stream = new FileStream(_nodesFilePath, FileMode.Append, FileAccess.Write, FileShare.Read, 4096, useAsync: true))
        {
            // In Append mode the stream is positioned at end-of-file, so the
            // current length is this record's offset (avoids a FileInfo probe).
            offset = stream.Length;

            // Length-prefixed record: 4-byte length, then UTF-8 JSON payload.
            var lengthBytes = BitConverter.GetBytes(bytes.Length);
            await stream.WriteAsync(lengthBytes, 0, 4).ConfigureAwait(false);
            await stream.WriteAsync(bytes, 0, bytes.Length).ConfigureAwait(false);
        }

        // Record the node's file offset in the persistent index.
        _nodeIndex.Add(node.Id, offset);

        // Update in-memory indices (thread-safe).
        lock (_cacheLock)
        {
            var labelSet = _nodesByLabel.GetOrAdd(node.Label, _ => new HashSet<string>());
            labelSet.Add(node.Id);

            _outgoingEdges.GetOrAdd(node.Id, _ => new HashSet<string>());
            _incomingEdges.GetOrAdd(node.Id, _ => new HashSet<string>());
        }

        // Flush the index periodically to bound data loss on crash.
        if (_nodeIndex.Count % 100 == 0)
            _nodeIndex.Flush();
    }
    catch (IOException ex)
    {
        throw new IOException($"Failed to add node '{node.Id}' to file store", ex);
    }
    catch (UnauthorizedAccessException ex)
    {
        throw new IOException($"Failed to add node '{node.Id}' to file store due to access permissions", ex);
    }
    catch (JsonSerializationException ex)
    {
        throw new IOException($"Failed to serialize node '{node.Id}' to JSON", ex);
    }
}
+
/// <inheritdoc />
/// <summary>
/// Asynchronously appends an edge to the edge data file and updates the edge indices.
/// Both endpoint nodes must already exist.
/// </summary>
/// <param name="edge">The edge to persist. Must not be null.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="edge"/> is null.</exception>
/// <exception cref="InvalidOperationException">Thrown when the source or target node does not exist.</exception>
/// <exception cref="IOException">Thrown when the write fails, access is denied, or serialization fails.</exception>
public async Task AddEdgeAsync(GraphEdge<T> edge)
{
    if (edge == null)
        throw new ArgumentNullException(nameof(edge));
    if (!_nodeIndex.Contains(edge.SourceId))
        throw new InvalidOperationException($"Source node '{edge.SourceId}' does not exist");
    if (!_nodeIndex.Contains(edge.TargetId))
        throw new InvalidOperationException($"Target node '{edge.TargetId}' does not exist");

    try
    {
        // Log to WAL first so the operation is durable before the data write.
        _wal?.LogAddEdge(edge);

        // Serialize edge to JSON.
        var json = JsonConvert.SerializeObject(edge, _jsonSettings);
        var bytes = Encoding.UTF8.GetBytes(json);

        long offset;
        using (var stream = new FileStream(_edgesFilePath, FileMode.Append, FileAccess.Write, FileShare.Read, 4096, useAsync: true))
        {
            // Append mode positions the stream at end-of-file; its length is
            // this record's offset.
            offset = stream.Length;

            // Length-prefixed record: 4-byte length, then UTF-8 JSON payload.
            var lengthBytes = BitConverter.GetBytes(bytes.Length);
            await stream.WriteAsync(lengthBytes, 0, 4).ConfigureAwait(false);
            await stream.WriteAsync(bytes, 0, bytes.Length).ConfigureAwait(false);
        }

        // Record the edge's file offset in the persistent index.
        _edgeIndex.Add(edge.Id, offset);

        // Update in-memory edge indices (thread-safe).
        lock (_cacheLock)
        {
            if (_outgoingEdges.TryGetValue(edge.SourceId, out var outgoingSet))
                outgoingSet.Add(edge.Id);
            if (_incomingEdges.TryGetValue(edge.TargetId, out var incomingSet))
                incomingSet.Add(edge.Id);
        }

        // Flush the index periodically to bound data loss on crash.
        if (_edgeIndex.Count % 100 == 0)
            _edgeIndex.Flush();
    }
    catch (IOException ex)
    {
        throw new IOException($"Failed to add edge '{edge.Id}' to file store", ex);
    }
    catch (UnauthorizedAccessException ex)
    {
        throw new IOException($"Failed to add edge '{edge.Id}' to file store due to access permissions", ex);
    }
    catch (JsonSerializationException ex)
    {
        // Mirrors AddNodeAsync: surface serialization failures as IOException
        // so callers handle a single failure type for persistence errors.
        throw new IOException($"Failed to serialize edge '{edge.Id}' to JSON", ex);
    }
}
+
/// <inheritdoc />
/// <summary>
/// Asynchronously reads a node record from the node data file by its indexed offset.
/// </summary>
/// <param name="nodeId">The node identifier. Null/whitespace yields null.</param>
/// <returns>The node, or null when the id is unknown.</returns>
/// <exception cref="IOException">Thrown when the read fails or the data is corrupted.</exception>
public async Task<GraphNode<T>?> GetNodeAsync(string nodeId)
{
    if (string.IsNullOrWhiteSpace(nodeId))
        return null;

    // A negative offset is the index's "not found" sentinel.
    var offset = _nodeIndex.Get(nodeId);
    if (offset < 0)
        return null;

    try
    {
        using var stream = new FileStream(_nodesFilePath, FileMode.Open, FileAccess.Read, FileShare.Read, 4096, useAsync: true);
        stream.Seek(offset, SeekOrigin.Begin);

        // Read the 4-byte length prefix, guaranteeing all bytes arrive.
        var lengthBytes = new byte[4];
        await ReadExactlyAsync(stream, lengthBytes, 0, 4).ConfigureAwait(false);
        var length = BitConverter.ToInt32(lengthBytes, 0);

        // Read the JSON payload, guaranteeing all bytes arrive.
        var jsonBytes = new byte[length];
        await ReadExactlyAsync(stream, jsonBytes, 0, length).ConfigureAwait(false);
        var json = Encoding.UTF8.GetString(jsonBytes);

        return JsonConvert.DeserializeObject<GraphNode<T>>(json, _jsonSettings);
    }
    catch (EndOfStreamException ex)
    {
        throw new IOException($"Failed to read node '{nodeId}' from file store - data may be corrupted", ex);
    }
    catch (IOException ex)
    {
        throw new IOException($"Failed to read node '{nodeId}' from file store", ex);
    }
    catch (UnauthorizedAccessException ex)
    {
        throw new IOException($"Failed to read node '{nodeId}' from file store", ex);
    }
    catch (JsonSerializationException ex)
    {
        throw new IOException($"Failed to deserialize node '{nodeId}' from JSON", ex);
    }
}
+
/// <inheritdoc />
/// <summary>
/// Asynchronously reads an edge record from the edge data file by its indexed offset.
/// </summary>
/// <param name="edgeId">The edge identifier. Null/whitespace yields null.</param>
/// <returns>The edge, or null when the id is unknown.</returns>
/// <exception cref="IOException">Thrown when the read fails or the data is corrupted.</exception>
public async Task<GraphEdge<T>?> GetEdgeAsync(string edgeId)
{
    if (string.IsNullOrWhiteSpace(edgeId))
        return null;

    // A negative offset is the index's "not found" sentinel.
    var offset = _edgeIndex.Get(edgeId);
    if (offset < 0)
        return null;

    try
    {
        using var stream = new FileStream(_edgesFilePath, FileMode.Open, FileAccess.Read, FileShare.Read, 4096, useAsync: true);
        stream.Seek(offset, SeekOrigin.Begin);

        // Read the 4-byte length prefix, guaranteeing all bytes arrive.
        var lengthBytes = new byte[4];
        await ReadExactlyAsync(stream, lengthBytes, 0, 4).ConfigureAwait(false);
        var length = BitConverter.ToInt32(lengthBytes, 0);

        // Read the JSON payload, guaranteeing all bytes arrive.
        var jsonBytes = new byte[length];
        await ReadExactlyAsync(stream, jsonBytes, 0, length).ConfigureAwait(false);
        var json = Encoding.UTF8.GetString(jsonBytes);

        return JsonConvert.DeserializeObject<GraphEdge<T>>(json, _jsonSettings);
    }
    catch (EndOfStreamException ex)
    {
        throw new IOException($"Failed to read edge '{edgeId}' from file store - data may be corrupted", ex);
    }
    catch (IOException ex)
    {
        throw new IOException($"Failed to read edge '{edgeId}' from file store", ex);
    }
    catch (UnauthorizedAccessException ex)
    {
        throw new IOException($"Failed to read edge '{edgeId}' from file store", ex);
    }
    catch (JsonSerializationException ex)
    {
        throw new IOException($"Failed to deserialize edge '{edgeId}' from JSON", ex);
    }
}
+
/// <inheritdoc />
/// <summary>
/// Asynchronously removes a node together with all of its incident edges.
/// </summary>
/// <param name="nodeId">The node identifier.</param>
/// <returns>True when the node existed and was removed; otherwise false.</returns>
/// <exception cref="IOException">Thrown when the underlying store cannot be updated.</exception>
public async Task<bool> RemoveNodeAsync(string nodeId)
{
    if (string.IsNullOrWhiteSpace(nodeId) || !_nodeIndex.Contains(nodeId))
        return false;

    try
    {
        var node = await GetNodeAsync(nodeId).ConfigureAwait(false);
        if (node == null)
            return false;

        // Log to WAL first (durability).
        _wal?.LogRemoveNode(nodeId);

        // Remove all outgoing edges. Snapshot the id set under the lock so
        // enumeration is not invalidated while RemoveEdgeAsync mutates the sets.
        if (_outgoingEdges.TryGetValue(nodeId, out var outgoing))
        {
            List<string> edgesToRemove;
            lock (_cacheLock)
            {
                edgesToRemove = outgoing.ToList();
            }
            foreach (var edgeId in edgesToRemove)
            {
                await RemoveEdgeAsync(edgeId).ConfigureAwait(false);
            }
            _outgoingEdges.TryRemove(nodeId, out _);
        }

        // Remove all incoming edges (same snapshot pattern).
        if (_incomingEdges.TryGetValue(nodeId, out var incoming))
        {
            List<string> edgesToRemove;
            lock (_cacheLock)
            {
                edgesToRemove = incoming.ToList();
            }
            foreach (var edgeId in edgesToRemove)
            {
                await RemoveEdgeAsync(edgeId).ConfigureAwait(false);
            }
            _incomingEdges.TryRemove(nodeId, out _);
        }

        // Drop the node from its label bucket; remove empty buckets entirely
        // so GetNodesByLabel does not report stale labels.
        lock (_cacheLock)
        {
            if (_nodesByLabel.TryGetValue(node.Label, out var nodeIds))
            {
                nodeIds.Remove(nodeId);
                if (nodeIds.Count == 0)
                    _nodesByLabel.TryRemove(node.Label, out _);
            }
        }

        // Remove from the persistent node index and flush immediately so a
        // crash cannot resurrect the node.
        _nodeIndex.Remove(nodeId);
        _nodeIndex.Flush();

        return true;
    }
    catch (IOException ex)
    {
        throw new IOException($"Failed to remove node '{nodeId}' from file store", ex);
    }
    catch (UnauthorizedAccessException ex)
    {
        throw new IOException($"Failed to remove node '{nodeId}' from file store", ex);
    }
}
+
/// <inheritdoc />
/// <summary>
/// Asynchronously removes an edge by delegating to the synchronous removal,
/// which is index/in-memory work with no long-running I/O of its own.
/// </summary>
/// <param name="edgeId">The edge identifier.</param>
/// <returns>True when the edge existed and was removed; otherwise false.</returns>
public Task<bool> RemoveEdgeAsync(string edgeId)
{
    return Task.FromResult(RemoveEdge(edgeId));
}
+
/// <inheritdoc />
/// <summary>
/// Asynchronously retrieves all edges whose source is the given node.
/// </summary>
/// <param name="nodeId">The source node identifier.</param>
/// <returns>The outgoing edges, or an empty sequence when the node is unknown.</returns>
public async Task<IEnumerable<GraphEdge<T>>> GetOutgoingEdgesAsync(string nodeId)
{
    if (!_outgoingEdges.TryGetValue(nodeId, out var edgeIds))
        return Enumerable.Empty<GraphEdge<T>>();

    // Snapshot the id set under the lock so concurrent mutation cannot
    // invalidate the enumeration below.
    List<string> snapshot;
    lock (_cacheLock)
    {
        snapshot = edgeIds.ToList();
    }

    // Each GetEdgeAsync opens its own read-shared stream, so the reads can
    // run concurrently. OfType filters out edges deleted since the snapshot.
    var tasks = snapshot.Select(id => GetEdgeAsync(id));
    var results = await Task.WhenAll(tasks).ConfigureAwait(false);
    return results.OfType<GraphEdge<T>>().ToList();
}
+
/// <inheritdoc />
/// <summary>
/// Asynchronously retrieves all edges whose target is the given node.
/// </summary>
/// <param name="nodeId">The target node identifier.</param>
/// <returns>The incoming edges, or an empty sequence when the node is unknown.</returns>
public async Task<IEnumerable<GraphEdge<T>>> GetIncomingEdgesAsync(string nodeId)
{
    if (!_incomingEdges.TryGetValue(nodeId, out var edgeIds))
        return Enumerable.Empty<GraphEdge<T>>();

    // Snapshot the id set under the lock so concurrent mutation cannot
    // invalidate the enumeration below.
    List<string> snapshot;
    lock (_cacheLock)
    {
        snapshot = edgeIds.ToList();
    }

    // Concurrent reads; OfType filters edges deleted since the snapshot.
    var tasks = snapshot.Select(id => GetEdgeAsync(id));
    var results = await Task.WhenAll(tasks).ConfigureAwait(false);
    return results.OfType<GraphEdge<T>>().ToList();
}
+
/// <inheritdoc />
/// <summary>
/// Asynchronously retrieves all nodes carrying the given label.
/// </summary>
/// <param name="label">The label to look up.</param>
/// <returns>The matching nodes, or an empty sequence when the label is unknown.</returns>
public async Task<IEnumerable<GraphNode<T>>> GetNodesByLabelAsync(string label)
{
    if (!_nodesByLabel.TryGetValue(label, out var nodeIds))
        return Enumerable.Empty<GraphNode<T>>();

    // Snapshot the id set under the lock so concurrent mutation cannot
    // invalidate the enumeration below.
    List<string> snapshot;
    lock (_cacheLock)
    {
        snapshot = nodeIds.ToList();
    }

    // Concurrent reads; OfType filters nodes deleted since the snapshot.
    var tasks = snapshot.Select(id => GetNodeAsync(id));
    var results = await Task.WhenAll(tasks).ConfigureAwait(false);
    return results.OfType<GraphNode<T>>().ToList();
}
+
/// <inheritdoc />
/// <summary>
/// Asynchronously loads every node recorded in the node index.
/// </summary>
/// <returns>All persisted nodes; unreadable/deleted entries are silently skipped by OfType.</returns>
public async Task<IEnumerable<GraphNode<T>>> GetAllNodesAsync()
{
    // Reads run concurrently; each GetNodeAsync opens its own read-shared stream.
    var tasks = _nodeIndex.GetAllKeys().Select(id => GetNodeAsync(id));
    var results = await Task.WhenAll(tasks).ConfigureAwait(false);
    return results.OfType<GraphNode<T>>().ToList();
}
+
/// <inheritdoc />
/// <summary>
/// Asynchronously loads every edge recorded in the edge index.
/// </summary>
/// <returns>All persisted edges; unreadable/deleted entries are silently skipped by OfType.</returns>
public async Task<IEnumerable<GraphEdge<T>>> GetAllEdgesAsync()
{
    // Reads run concurrently; each GetEdgeAsync opens its own read-shared stream.
    var tasks = _edgeIndex.GetAllKeys().Select(id => GetEdgeAsync(id));
    var results = await Task.WhenAll(tasks).ConfigureAwait(false);
    return results.OfType<GraphEdge<T>>().ToList();
}
+
/// <inheritdoc />
/// <summary>
/// Asynchronously clears the store by delegating to the synchronous
/// <c>Clear</c>; the work completes quickly enough that no dedicated
/// asynchronous path is needed.
/// </summary>
public Task ClearAsync()
{
    Clear();
    return Task.CompletedTask;
}
+
/// <summary>
/// Disposes the file graph store, ensuring all index changes are flushed to disk.
/// </summary>
/// <remarks>
/// Idempotent: subsequent calls after the first are no-ops.
/// NOTE(review): _wal is not flushed or disposed here — confirm whether the WAL
/// owns a file handle that needs closing on dispose.
/// </remarks>
public void Dispose()
{
    if (_disposed)
        return;

    try
    {
        // Persist any buffered index entries before releasing the indices.
        _nodeIndex.Flush();
        _edgeIndex.Flush();
        _nodeIndex.Dispose();
        _edgeIndex.Dispose();
    }
    finally
    {
        // Mark disposed even if a flush failed, so Dispose stays idempotent.
        _disposed = true;
    }

    // CA1816: no finalizer work remains once Dispose has run.
    GC.SuppressFinalize(this);
}
+
/// <summary>
/// Reads exactly <paramref name="count"/> bytes from <paramref name="stream"/>
/// into <paramref name="buffer"/> starting at <paramref name="offset"/>.
/// </summary>
/// <param name="stream">The stream to read from.</param>
/// <param name="buffer">The destination buffer.</param>
/// <param name="offset">The buffer position at which to start writing.</param>
/// <param name="count">The exact number of bytes required.</param>
/// <exception cref="EndOfStreamException">The stream ended before all bytes were read.</exception>
private static void ReadExactly(Stream stream, byte[] buffer, int offset, int count)
{
    // Stream.Read may return fewer bytes than requested, so loop until the
    // requested count is satisfied or the stream reports end-of-data.
    for (int filled = 0; filled < count;)
    {
        int chunk = stream.Read(buffer, offset + filled, count - filled);
        if (chunk == 0)
            throw new EndOfStreamException($"Unexpected end of stream. Expected {count} bytes, got {filled}.");
        filled += chunk;
    }
}
+
/// <summary>
/// Asynchronously reads exactly <paramref name="count"/> bytes from
/// <paramref name="stream"/> into <paramref name="buffer"/> starting at
/// <paramref name="offset"/>.
/// </summary>
/// <param name="stream">The stream to read from.</param>
/// <param name="buffer">The destination buffer.</param>
/// <param name="offset">The buffer position at which to start writing.</param>
/// <param name="count">The exact number of bytes required.</param>
/// <exception cref="EndOfStreamException">The stream ended before all bytes were read.</exception>
private static async Task ReadExactlyAsync(Stream stream, byte[] buffer, int offset, int count)
{
    // ReadAsync may return fewer bytes than requested, so loop until the
    // requested count is satisfied or the stream reports end-of-data.
    // ConfigureAwait(false): library code with no sync-context requirement.
    int totalRead = 0;
    while (totalRead < count)
    {
        int bytesRead = await stream.ReadAsync(buffer, offset + totalRead, count - totalRead).ConfigureAwait(false);
        if (bytesRead == 0)
            throw new EndOfStreamException($"Unexpected end of stream. Expected {count} bytes, got {totalRead}.");
        totalRead += bytesRead;
    }
}
+}
diff --git a/src/RetrievalAugmentedGeneration/Graph/GraphAnalytics.cs b/src/RetrievalAugmentedGeneration/Graph/GraphAnalytics.cs
new file mode 100644
index 000000000..3a28df3d8
--- /dev/null
+++ b/src/RetrievalAugmentedGeneration/Graph/GraphAnalytics.cs
@@ -0,0 +1,740 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+namespace AiDotNet.RetrievalAugmentedGeneration.Graph;
+
/// <summary>
/// Provides graph analytics algorithms for analyzing knowledge graphs.
/// </summary>
/// <remarks>
/// <para>
/// Implements common algorithms for measuring node importance and graph structure:
/// PageRank, degree/closeness/betweenness centrality, connected components,
/// label-propagation community detection, and clustering coefficients.
/// </para>
/// <para><b>For Beginners:</b> Graph analytics help you understand your graph.
/// In a social network: PageRank asks "who is most influential?", degree centrality
/// "who has the most connections?", closeness centrality "who can reach everyone
/// quickly?", and betweenness centrality "who connects different groups?".</para>
/// </remarks>
public static class GraphAnalytics
{
    /// <summary>
    /// Calculates PageRank scores for all nodes in the graph.
    /// </summary>
    /// <typeparam name="T">The numeric type used for vector operations.</typeparam>
    /// <param name="graph">The knowledge graph to analyze.</param>
    /// <param name="dampingFactor">The damping factor (default: 0.85). Must be between 0 and 1.</param>
    /// <param name="maxIterations">Maximum number of iterations (default: 100).</param>
    /// <param name="convergenceThreshold">Convergence threshold (default: 0.0001).</param>
    /// <returns>Dictionary mapping node IDs to their PageRank scores.</returns>
    /// <remarks>
    /// <para><b>For Beginners:</b> PageRank finds the most important nodes: every node
    /// starts with equal rank, then repeatedly shares its rank with the nodes it points
    /// to. Nodes pointed at by many important nodes end up with high scores.</para>
    /// <para>NOTE(review): rank from "dangling" nodes (no outgoing edges) is not
    /// redistributed, so scores need not sum to exactly 1 — a common simplification.</para>
    /// </remarks>
    /// <exception cref="ArgumentNullException">Thrown when graph is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">Thrown when dampingFactor is not between 0 and 1.</exception>
    public static Dictionary<string, double> CalculatePageRank<T>(
        KnowledgeGraph<T> graph,
        double dampingFactor = 0.85,
        int maxIterations = 100,
        double convergenceThreshold = 0.0001)
    {
        if (graph == null)
            throw new ArgumentNullException(nameof(graph));
        if (dampingFactor < 0 || dampingFactor > 1)
            throw new ArgumentOutOfRangeException(nameof(dampingFactor), "Damping factor must be between 0 and 1");

        var nodes = graph.GetAllNodes().ToList();
        if (nodes.Count == 0)
            return new Dictionary<string, double>();

        var nodeCount = nodes.Count;
        var ranks = new Dictionary<string, double>();
        var newRanks = new Dictionary<string, double>();

        // Every node starts with an equal share of the total rank.
        var initialRank = 1.0 / nodeCount;
        foreach (var node in nodes)
        {
            ranks[node.Id] = initialRank;
            newRanks[node.Id] = 0.0;
        }

        for (int iteration = 0; iteration < maxIterations; iteration++)
        {
            // Each node's new rank is the damped sum of rank flowing in from
            // its predecessors, each divided by that predecessor's out-degree.
            foreach (var node in nodes)
            {
                double rankSum = 0.0;

                foreach (var edge in graph.GetIncomingEdges(node.Id))
                {
                    var sourceNode = graph.GetNode(edge.SourceId);
                    if (sourceNode != null)
                    {
                        var outgoingCount = graph.GetOutgoingEdges(edge.SourceId).Count();
                        if (outgoingCount > 0)
                        {
                            rankSum += ranks[edge.SourceId] / outgoingCount;
                        }
                    }
                }

                // PageRank formula: PR(A) = (1-d)/N + d * sum(PR(Ti)/C(Ti))
                newRanks[node.Id] = (1 - dampingFactor) / nodeCount + dampingFactor * rankSum;
            }

            // Stop once the largest per-node change drops below the threshold.
            double maxChange = 0.0;
            foreach (var node in nodes)
            {
                var change = Math.Abs(newRanks[node.Id] - ranks[node.Id]);
                if (change > maxChange)
                    maxChange = change;
                ranks[node.Id] = newRanks[node.Id];
            }

            if (maxChange < convergenceThreshold)
                break;
        }

        return ranks;
    }

    /// <summary>
    /// Calculates degree centrality (total in-degree + out-degree) for all nodes.
    /// </summary>
    /// <typeparam name="T">The numeric type used for vector operations.</typeparam>
    /// <param name="graph">The knowledge graph to analyze.</param>
    /// <param name="normalized">Whether to normalize scores by the maximum possible degree (default: true).</param>
    /// <returns>Dictionary mapping node IDs to their degree centrality scores.</returns>
    /// <remarks>
    /// <b>For Beginners:</b> the simplest centrality — it just counts connections.
    /// When normalized, each score is divided by 2*(N-1), the maximum possible total
    /// degree in a directed graph, giving values in [0, 1].
    /// </remarks>
    /// <exception cref="ArgumentNullException">Thrown when graph is null.</exception>
    public static Dictionary<string, double> CalculateDegreeCentrality<T>(
        KnowledgeGraph<T> graph,
        bool normalized = true)
    {
        if (graph == null)
            throw new ArgumentNullException(nameof(graph));

        var nodes = graph.GetAllNodes().ToList();
        var centrality = new Dictionary<string, double>();

        if (nodes.Count == 0)
            return centrality;

        foreach (var node in nodes)
        {
            var outDegree = graph.GetOutgoingEdges(node.Id).Count();
            var inDegree = graph.GetIncomingEdges(node.Id).Count();
            var totalDegree = outDegree + inDegree;

            // Normalize by 2*(n-1): the maximum total degree of a node in a
            // directed graph (n-1 possible in-edges plus n-1 possible out-edges).
            centrality[node.Id] = (normalized && nodes.Count > 1)
                ? totalDegree / (2.0 * (nodes.Count - 1))
                : totalDegree;
        }

        return centrality;
    }

    /// <summary>
    /// Calculates closeness centrality for all nodes in the graph.
    /// </summary>
    /// <typeparam name="T">The numeric type used for vector operations.</typeparam>
    /// <param name="graph">The knowledge graph to analyze.</param>
    /// <returns>Dictionary mapping node IDs to their closeness centrality scores.</returns>
    /// <remarks>
    /// <b>For Beginners:</b> measures how close a node is to everything else — hub
    /// airports that reach anywhere in few hops score high. Computed as
    /// (N-1) / sum(shortest-path distances to reachable nodes); unreachable nodes
    /// are excluded, and a node that reaches nothing scores 0.
    /// </remarks>
    /// <exception cref="ArgumentNullException">Thrown when graph is null.</exception>
    public static Dictionary<string, double> CalculateClosenessCentrality<T>(
        KnowledgeGraph<T> graph)
    {
        if (graph == null)
            throw new ArgumentNullException(nameof(graph));

        var nodes = graph.GetAllNodes().ToList();
        var centrality = new Dictionary<string, double>();

        // With zero or one node there are no distances to measure.
        if (nodes.Count <= 1)
        {
            foreach (var node in nodes)
                centrality[node.Id] = 0.0;
            return centrality;
        }

        foreach (var node in nodes)
        {
            var distances = BreadthFirstSearchDistances(graph, node.Id);
            // Keep only finite, non-self distances (int.MaxValue marks unreachable).
            var reachableNodes = distances.Values.Where(d => d > 0 && d < int.MaxValue).ToList();

            if (reachableNodes.Count == 0)
            {
                centrality[node.Id] = 0.0;
            }
            else
            {
                // (n-1) / (avg * count) simplifies to (n-1) / sum(distances).
                var averageDistance = reachableNodes.Average();
                centrality[node.Id] = (nodes.Count - 1) / (averageDistance * reachableNodes.Count);
            }
        }

        return centrality;
    }

    /// <summary>
    /// Calculates betweenness centrality for all nodes using Brandes' algorithm.
    /// </summary>
    /// <typeparam name="T">The numeric type used for vector operations.</typeparam>
    /// <param name="graph">The knowledge graph to analyze.</param>
    /// <param name="normalized">Whether to normalize scores by (N-1)(N-2) (default: true).</param>
    /// <returns>Dictionary mapping node IDs to their betweenness centrality scores.</returns>
    /// <remarks>
    /// <b>For Beginners:</b> finds "bridge" nodes — those that many shortest paths
    /// pass through, like a bridge between two cities or a person connecting two
    /// friend groups. For each source node a BFS counts shortest paths, then
    /// dependencies are accumulated back along the predecessor lists.
    /// </remarks>
    /// <exception cref="ArgumentNullException">Thrown when graph is null.</exception>
    public static Dictionary<string, double> CalculateBetweennessCentrality<T>(
        KnowledgeGraph<T> graph,
        bool normalized = true)
    {
        if (graph == null)
            throw new ArgumentNullException(nameof(graph));

        var nodes = graph.GetAllNodes().ToList();
        var betweenness = new Dictionary<string, double>();

        foreach (var node in nodes)
            betweenness[node.Id] = 0.0;

        // With two or fewer nodes no node can lie between two others.
        if (nodes.Count <= 2)
            return betweenness;

        foreach (var source in nodes)
        {
            var stack = new Stack<string>();
            var predecessors = new Dictionary<string, List<string>>();
            var pathCounts = new Dictionary<string, long>();
            var distances = new Dictionary<string, int>();
            var dependencies = new Dictionary<string, double>();

            foreach (var node in nodes)
            {
                predecessors[node.Id] = new List<string>();
                pathCounts[node.Id] = 0;
                distances[node.Id] = -1; // -1 marks "not yet visited"
                dependencies[node.Id] = 0.0;
            }

            pathCounts[source.Id] = 1;
            distances[source.Id] = 0;

            var queue = new Queue<string>();
            queue.Enqueue(source.Id);

            // BFS from the source, counting shortest paths and recording
            // predecessors on those paths.
            while (queue.Count > 0)
            {
                var v = queue.Dequeue();
                stack.Push(v);

                foreach (var edge in graph.GetOutgoingEdges(v))
                {
                    var w = edge.TargetId;
                    // First time we reach w?
                    if (distances[w] < 0)
                    {
                        queue.Enqueue(w);
                        distances[w] = distances[v] + 1;
                    }

                    // Is v on a shortest path to w?
                    if (distances[w] == distances[v] + 1)
                    {
                        pathCounts[w] += pathCounts[v];
                        predecessors[w].Add(v);
                    }
                }
            }

            // Accumulate dependencies in reverse BFS order (Brandes).
            while (stack.Count > 0)
            {
                var w = stack.Pop();
                foreach (var v in predecessors[w])
                {
                    dependencies[v] += (pathCounts[v] / (double)pathCounts[w]) * (1.0 + dependencies[w]);
                }

                if (w != source.Id)
                    betweenness[w] += dependencies[w];
            }
        }

        // Normalize by the number of ordered node pairs excluding the node itself.
        if (normalized && nodes.Count > 2)
        {
            var normFactor = (nodes.Count - 1) * (nodes.Count - 2);
            foreach (var node in nodes)
            {
                betweenness[node.Id] /= normFactor;
            }
        }

        return betweenness;
    }

    /// <summary>
    /// Performs breadth-first search to calculate hop distances from a source node
    /// to all others; unreachable nodes keep the sentinel <see cref="int.MaxValue"/>.
    /// </summary>
    private static Dictionary<string, int> BreadthFirstSearchDistances<T>(
        KnowledgeGraph<T> graph,
        string sourceId)
    {
        var distances = new Dictionary<string, int>();

        foreach (var node in graph.GetAllNodes())
            distances[node.Id] = int.MaxValue;

        distances[sourceId] = 0;
        var queue = new Queue<string>();
        queue.Enqueue(sourceId);

        while (queue.Count > 0)
        {
            var current = queue.Dequeue();
            var currentDistance = distances[current];

            // Only follow edges to nodes not yet assigned a distance.
            foreach (var edge in graph.GetOutgoingEdges(current).Where(e => distances[e.TargetId] == int.MaxValue))
            {
                distances[edge.TargetId] = currentDistance + 1;
                queue.Enqueue(edge.TargetId);
            }
        }

        return distances;
    }

    /// <summary>
    /// Identifies the top-k most central nodes based on a centrality measure.
    /// </summary>
    /// <param name="centrality">Dictionary of node IDs to centrality scores.</param>
    /// <param name="k">Number of top nodes to return; values ≤ 0 yield an empty list.</param>
    /// <returns>List of (nodeId, score) tuples for the top-k nodes, ordered by descending score.</returns>
    /// <remarks>
    /// <b>For Beginners:</b> after running PageRank or a centrality calculation,
    /// use this to get e.g. the 10 most influential nodes with their scores.
    /// </remarks>
    /// <exception cref="ArgumentNullException">Thrown when centrality is null.</exception>
    public static List<(string NodeId, double Score)> GetTopKNodes(
        Dictionary<string, double> centrality,
        int k)
    {
        if (centrality == null)
            throw new ArgumentNullException(nameof(centrality));

        return centrality
            .OrderByDescending(kvp => kvp.Value)
            .Take(k)
            .Select(kvp => (kvp.Key, kvp.Value))
            .ToList();
    }

    /// <summary>
    /// Finds all connected components in the graph, treating edges as undirected.
    /// </summary>
    /// <typeparam name="T">The numeric type used for vector operations.</typeparam>
    /// <param name="graph">The knowledge graph to analyze.</param>
    /// <returns>List of connected components, each containing a set of node IDs.</returns>
    /// <remarks>
    /// <b>For Beginners:</b> finds the separate "islands" of your graph — groups of
    /// nodes that can reach each other but have no connection to the other groups.
    /// BFS is run from each unvisited node, following both edge directions so that
    /// direction does not split components.
    /// </remarks>
    /// <exception cref="ArgumentNullException">Thrown when graph is null.</exception>
    public static List<HashSet<string>> FindConnectedComponents<T>(KnowledgeGraph<T> graph)
    {
        if (graph == null)
            throw new ArgumentNullException(nameof(graph));

        var nodes = graph.GetAllNodes().ToList();
        var visited = new HashSet<string>();
        var components = new List<HashSet<string>>();

        // The Where is evaluated lazily per node, so nodes swallowed into an
        // earlier component are skipped as BFS marks them visited.
        foreach (var node in nodes.Where(n => !visited.Contains(n.Id)))
        {
            var component = new HashSet<string>();
            var queue = new Queue<string>();
            queue.Enqueue(node.Id);
            visited.Add(node.Id);

            while (queue.Count > 0)
            {
                var current = queue.Dequeue();
                component.Add(current);

                // Follow outgoing edges.
                foreach (var edge in graph.GetOutgoingEdges(current).Where(e => !visited.Contains(e.TargetId)))
                {
                    visited.Add(edge.TargetId);
                    queue.Enqueue(edge.TargetId);
                }

                // Follow incoming edges too, for undirected component semantics.
                foreach (var edge in graph.GetIncomingEdges(current).Where(e => !visited.Contains(e.SourceId)))
                {
                    visited.Add(edge.SourceId);
                    queue.Enqueue(edge.SourceId);
                }
            }

            components.Add(component);
        }

        return components;
    }

    /// <summary>
    /// Detects communities using the Label Propagation algorithm.
    /// </summary>
    /// <typeparam name="T">The numeric type used for vector operations.</typeparam>
    /// <param name="graph">The knowledge graph to analyze.</param>
    /// <param name="maxIterations">Maximum number of iterations (default: 100).</param>
    /// <returns>Dictionary mapping node IDs to their community labels; nodes sharing a label are in the same community.</returns>
    /// <remarks>
    /// <para><b>For Beginners:</b> every node starts with a unique label and then
    /// repeatedly adopts the most common label among its neighbors; densely
    /// connected groups converge to a shared label. Runs in O(k * E).</para>
    /// <para>NOTE(review): node order and ties are randomized with an unseeded
    /// Random, so results are nondeterministic across runs — inherent to the
    /// algorithm; seed if reproducibility is required.</para>
    /// </remarks>
    /// <exception cref="ArgumentNullException">Thrown when graph is null.</exception>
    public static Dictionary<string, int> DetectCommunitiesLabelPropagation<T>(
        KnowledgeGraph<T> graph,
        int maxIterations = 100)
    {
        if (graph == null)
            throw new ArgumentNullException(nameof(graph));

        var nodes = graph.GetAllNodes().ToList();
        var labels = new Dictionary<string, int>();
        var random = new Random();

        // Each node starts in its own community.
        for (int i = 0; i < nodes.Count; i++)
        {
            labels[nodes[i].Id] = i;
        }

        for (int iteration = 0; iteration < maxIterations; iteration++)
        {
            bool changed = false;

            // Random processing order prevents update-order artifacts.
            var shuffledNodes = nodes.OrderBy(n => random.Next()).ToList();

            foreach (var node in shuffledNodes)
            {
                // Collect neighbor labels from both edge directions.
                var neighborLabels = new List<int>();

                foreach (var edge in graph.GetOutgoingEdges(node.Id))
                {
                    if (labels.TryGetValue(edge.TargetId, out var targetLabel))
                        neighborLabels.Add(targetLabel);
                }

                foreach (var edge in graph.GetIncomingEdges(node.Id))
                {
                    if (labels.TryGetValue(edge.SourceId, out var sourceLabel))
                        neighborLabels.Add(sourceLabel);
                }

                // Isolated nodes keep their own label.
                if (neighborLabels.Count == 0)
                    continue;

                // Adopt the most common neighbor label, breaking ties randomly.
                var bestLabel = neighborLabels
                    .GroupBy(l => l)
                    .Select(g => new { Label = g.Key, Count = g.Count() })
                    .OrderByDescending(x => x.Count)
                    .ThenBy(x => random.Next())
                    .First();

                if (labels[node.Id] != bestLabel.Label)
                {
                    labels[node.Id] = bestLabel.Label;
                    changed = true;
                }
            }

            // Converged once a full pass produces no label changes.
            if (!changed)
                break;
        }

        return labels;
    }

    /// <summary>
    /// Calculates the clustering coefficient for each node, treating edges as undirected.
    /// </summary>
    /// <typeparam name="T">The numeric type used for vector operations.</typeparam>
    /// <param name="graph">The knowledge graph to analyze.</param>
    /// <returns>Dictionary mapping node IDs to their clustering coefficients (0 to 1).</returns>
    /// <remarks>
    /// <b>For Beginners:</b> measures how "clique-like" a node's neighborhood is —
    /// how many of your friends are friends with each other, divided by the maximum
    /// possible. 1 means everyone knows everyone; 0 means none of your friends are
    /// connected (you're the hub). Nodes with fewer than two neighbors score 0.
    /// </remarks>
    /// <exception cref="ArgumentNullException">Thrown when graph is null.</exception>
    public static Dictionary<string, double> CalculateClusteringCoefficient<T>(
        KnowledgeGraph<T> graph)
    {
        if (graph == null)
            throw new ArgumentNullException(nameof(graph));

        var nodes = graph.GetAllNodes().ToList();
        var coefficients = new Dictionary<string, double>();

        foreach (var node in nodes)
        {
            // Collect all neighbors (both directions, treated as undirected).
            var neighbors = new HashSet<string>();
            foreach (var edge in graph.GetOutgoingEdges(node.Id))
                neighbors.Add(edge.TargetId);
            foreach (var edge in graph.GetIncomingEdges(node.Id))
                neighbors.Add(edge.SourceId);

            if (neighbors.Count < 2)
            {
                coefficients[node.Id] = 0.0;
                continue;
            }

            // Hoist each neighbor's outgoing-target set once so the pair loop
            // below is an O(1) set lookup instead of re-enumerating edges for
            // every pair (avoids accidental O(k^2 * deg) work).
            var neighborList = neighbors.ToList();
            var outTargets = new Dictionary<string, HashSet<string>>(neighborList.Count);
            foreach (var neighborId in neighborList)
                outTargets[neighborId] = new HashSet<string>(graph.GetOutgoingEdges(neighborId).Select(e => e.TargetId));

            // Count neighbor pairs connected in either direction.
            int connectedPairs = 0;
            for (int i = 0; i < neighborList.Count; i++)
            {
                for (int j = i + 1; j < neighborList.Count; j++)
                {
                    var n1 = neighborList[i];
                    var n2 = neighborList[j];

                    if (outTargets[n1].Contains(n2) || outTargets[n2].Contains(n1))
                        connectedPairs++;
                }
            }

            // Clustering coefficient = actual connections / possible connections.
            int possiblePairs = neighbors.Count * (neighbors.Count - 1) / 2;
            coefficients[node.Id] = (double)connectedPairs / possiblePairs;
        }

        return coefficients;
    }

    /// <summary>
    /// Calculates the average clustering coefficient for the entire graph.
    /// </summary>
    /// <typeparam name="T">The numeric type used for vector operations.</typeparam>
    /// <param name="graph">The knowledge graph to analyze.</param>
    /// <returns>The average clustering coefficient (0 to 1); 0 for an empty graph.</returns>
    /// <remarks>
    /// <b>For Beginners:</b> a single number for how clustered the whole graph is.
    /// Near 1: many tight groups (friend circles); near 0: sparse, tree-like.
    /// Real social networks typically score ~0.3–0.6 versus ~0.01 for random graphs.
    /// </remarks>
    public static double CalculateAverageClusteringCoefficient<T>(KnowledgeGraph<T> graph)
    {
        var coefficients = CalculateClusteringCoefficient(graph);
        return coefficients.Count > 0 ? coefficients.Values.Average() : 0.0;
    }
}
diff --git a/src/RetrievalAugmentedGeneration/Graph/GraphEdge.cs b/src/RetrievalAugmentedGeneration/Graph/GraphEdge.cs
index 69c9ab743..af757cbd3 100644
--- a/src/RetrievalAugmentedGeneration/Graph/GraphEdge.cs
+++ b/src/RetrievalAugmentedGeneration/Graph/GraphEdge.cs
@@ -113,11 +113,50 @@ public void SetProperty(string key, object value)
///
/// The expected type of the property value.
/// The property key.
- /// The property value, or default if not found.
+ /// The property value, or default if not found or conversion fails.
+ ///
+ ///
+ /// This method handles JSON deserialization quirks where numeric types may differ
+ /// (e.g., int stored as long after JSON round-trip). It uses Convert.ChangeType
+ /// for IConvertible types to handle such conversions gracefully.
+ ///
+ ///
+ /// The method catches and handles the following exceptions during conversion:
+ /// - InvalidCastException: When the types are incompatible
+ /// - FormatException: When the string representation is invalid
+ /// - OverflowException: When the value is outside the target type's range
+ ///
+ ///
public TValue? GetProperty<TValue>(string key)
{
    if (!Properties.TryGetValue(key, out var value) || value == null)
        return default;

    // Direct type match (also covers reference types such as string).
    if (value is TValue typedValue)
        return typedValue;

    // JSON round-trips store integers as long; Convert.ChangeType bridges the gap.
    // Unwrap Nullable<T> first: ChangeType cannot target Nullable<T> directly and
    // would otherwise throw InvalidCastException, silently discarding the value.
    var targetType = Nullable.GetUnderlyingType(typeof(TValue)) ?? typeof(TValue);

    if (value is IConvertible && targetType.IsValueType)
    {
        try
        {
            return (TValue)Convert.ChangeType(value, targetType);
        }
        catch (Exception ex) when (ex is InvalidCastException or FormatException or OverflowException)
        {
            // Incompatible, malformed, or out-of-range value: treat as "not found".
            return default;
        }
    }

    return default;
}
diff --git a/src/RetrievalAugmentedGeneration/Graph/GraphNode.cs b/src/RetrievalAugmentedGeneration/Graph/GraphNode.cs
index 1a5de25b5..dbc16828e 100644
--- a/src/RetrievalAugmentedGeneration/Graph/GraphNode.cs
+++ b/src/RetrievalAugmentedGeneration/Graph/GraphNode.cs
@@ -99,11 +99,50 @@ public void SetProperty(string key, object value)
///
/// The expected type of the property value.
/// The property key.
- /// The property value, or default if not found.
+ /// The property value, or default if not found or conversion fails.
+ ///
+ ///
+ /// This method handles JSON deserialization quirks where numeric types may differ
+ /// (e.g., int stored as long after JSON round-trip). It uses Convert.ChangeType
+ /// for IConvertible types to handle such conversions gracefully.
+ ///
+ ///
+ /// The method catches and handles the following exceptions during conversion:
+ /// - InvalidCastException: When the types are incompatible
+ /// - FormatException: When the string representation is invalid
+ /// - OverflowException: When the value is outside the target type's range
+ ///
+ ///
public TValue? GetProperty<TValue>(string key)
{
    if (!Properties.TryGetValue(key, out var value) || value == null)
        return default;

    // Direct type match (also covers reference types such as string).
    if (value is TValue typedValue)
        return typedValue;

    // JSON round-trips store integers as long; Convert.ChangeType bridges the gap.
    // Unwrap Nullable<T> first: ChangeType cannot target Nullable<T> directly and
    // would otherwise throw InvalidCastException, silently discarding the value.
    var targetType = Nullable.GetUnderlyingType(typeof(TValue)) ?? typeof(TValue);

    if (value is IConvertible && targetType.IsValueType)
    {
        try
        {
            return (TValue)Convert.ChangeType(value, targetType);
        }
        catch (Exception ex) when (ex is InvalidCastException or FormatException or OverflowException)
        {
            // Incompatible, malformed, or out-of-range value: treat as "not found".
            return default;
        }
    }

    return default;
}
diff --git a/src/RetrievalAugmentedGeneration/Graph/GraphQueryMatcher.cs b/src/RetrievalAugmentedGeneration/Graph/GraphQueryMatcher.cs
new file mode 100644
index 000000000..5b8ff9294
--- /dev/null
+++ b/src/RetrievalAugmentedGeneration/Graph/GraphQueryMatcher.cs
@@ -0,0 +1,433 @@
+using System;
+using System.Collections.Generic;
+using System.Globalization;
+using System.Linq;
+using System.Text.RegularExpressions;
+
+namespace AiDotNet.RetrievalAugmentedGeneration.Graph;
+
/// <summary>
/// Simple pattern matching for graph queries (inspired by Cypher/SPARQL but simplified).
/// </summary>
/// <typeparam name="T">The numeric type used for vector operations.</typeparam>
/// <remarks>
/// <para>
/// Supports basic graph pattern matching queries like:
/// <c>(Person)-[KNOWS]-&gt;(Person)</c>,
/// <c>(Person {name: "Alice"})-[WORKS_AT]-&gt;(Company)</c>.
/// </para>
/// <para>
/// <b>For Beginners:</b> Pattern matching is like SQL for graphs. Where SQL says
/// <c>SELECT * FROM persons WHERE name = 'Alice'</c>, a graph pattern says
/// <c>(Person {name: "Alice"})-[KNOWS]-&gt;(Person)</c>, meaning "find all people
/// that Alice knows". This is much more natural for relationship-heavy data.
/// </para>
/// </remarks>
public class GraphQueryMatcher<T>
{
    private readonly KnowledgeGraph<T> _graph;

    // The pattern grammar never changes at runtime, so parse and compile the regex
    // once per process rather than rebuilding it on every ExecutePattern call.
    private static readonly Regex PatternRegex = new Regex(
        @"\((\w+)(?:\s*\{([^}]+)\})?\)-\[(\w+)\]->\((\w+)(?:\s*\{([^}]+)\})?\)",
        RegexOptions.Compiled);

    /// <summary>
    /// Initializes a new instance of the <see cref="GraphQueryMatcher{T}"/> class.
    /// </summary>
    /// <param name="graph">The knowledge graph to query.</param>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="graph"/> is null.</exception>
    public GraphQueryMatcher(KnowledgeGraph<T> graph)
    {
        _graph = graph ?? throw new ArgumentNullException(nameof(graph));
    }

    /// <summary>
    /// Finds nodes matching a label and optional property filters.
    /// </summary>
    /// <param name="label">The node label to match.</param>
    /// <param name="properties">Optional property filters; all must match.</param>
    /// <returns>List of matching nodes.</returns>
    /// <exception cref="ArgumentException">Thrown when <paramref name="label"/> is null or whitespace.</exception>
    /// <example>
    /// FindNodes("Person", new Dictionary&lt;string, object&gt; { { "name", "Alice" } })
    /// </example>
    public List<GraphNode<T>> FindNodes(string label, Dictionary<string, object>? properties = null)
    {
        if (string.IsNullOrWhiteSpace(label))
            throw new ArgumentException("Label cannot be null or whitespace", nameof(label));

        var nodes = _graph.GetNodesByLabel(label).ToList();

        if (properties == null || properties.Count == 0)
            return nodes;

        // Keep only nodes on which every requested property matches.
        return nodes.Where(node => MatchesProperties(node, properties)).ToList();
    }

    /// <summary>
    /// Finds paths matching a pattern: (source label)-[relationship type]->(target label).
    /// </summary>
    /// <param name="sourceLabel">The source node label.</param>
    /// <param name="relationshipType">The relationship type.</param>
    /// <param name="targetLabel">The target node label.</param>
    /// <param name="sourceProperties">Optional source node property filters.</param>
    /// <param name="targetProperties">Optional target node property filters.</param>
    /// <returns>List of matching single-hop paths.</returns>
    /// <exception cref="ArgumentException">Thrown when any label or the relationship type is null or whitespace.</exception>
    /// <example>
    /// FindPaths("Person", "KNOWS", "Person") // Find all KNOWS relationships
    /// </example>
    public List<GraphPath<T>> FindPaths(
        string sourceLabel,
        string relationshipType,
        string targetLabel,
        Dictionary<string, object>? sourceProperties = null,
        Dictionary<string, object>? targetProperties = null)
    {
        if (string.IsNullOrWhiteSpace(sourceLabel))
            throw new ArgumentException("Source label cannot be null or whitespace", nameof(sourceLabel));
        if (string.IsNullOrWhiteSpace(relationshipType))
            throw new ArgumentException("Relationship type cannot be null or whitespace", nameof(relationshipType));
        if (string.IsNullOrWhiteSpace(targetLabel))
            throw new ArgumentException("Target label cannot be null or whitespace", nameof(targetLabel));

        var results = new List<GraphPath<T>>();

        foreach (var sourceNode in FindNodes(sourceLabel, sourceProperties))
        {
            // Only follow edges of the requested relationship type.
            var edges = _graph.GetOutgoingEdges(sourceNode.Id)
                .Where(e => e.RelationType == relationshipType);

            foreach (var edge in edges)
            {
                var targetNode = _graph.GetNode(edge.TargetId);
                if (targetNode == null || targetNode.Label != targetLabel)
                    continue;
                if (targetProperties is { Count: > 0 } && !MatchesProperties(targetNode, targetProperties))
                    continue;

                results.Add(new GraphPath<T>
                {
                    SourceNode = sourceNode,
                    Edge = edge,
                    TargetNode = targetNode
                });
            }
        }

        return results;
    }

    /// <summary>
    /// Finds all paths of the specified length from a source node.
    /// </summary>
    /// <param name="sourceId">The source node ID.</param>
    /// <param name="pathLength">The path length (number of hops).</param>
    /// <param name="relationshipType">Optional relationship type filter.</param>
    /// <returns>List of paths (each with pathLength + 1 nodes); empty if the source does not exist.</returns>
    /// <exception cref="ArgumentException">Thrown when <paramref name="sourceId"/> is null or whitespace.</exception>
    /// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="pathLength"/> is not positive.</exception>
    public List<List<GraphNode<T>>> FindPathsOfLength(
        string sourceId,
        int pathLength,
        string? relationshipType = null)
    {
        if (string.IsNullOrWhiteSpace(sourceId))
            throw new ArgumentException("Source ID cannot be null or whitespace", nameof(sourceId));
        if (pathLength <= 0)
            throw new ArgumentOutOfRangeException(nameof(pathLength), "Path length must be positive");

        var sourceNode = _graph.GetNode(sourceId);
        if (sourceNode == null)
            return new List<List<GraphNode<T>>>();

        // Breadth-first expansion: at each step extend every partial path by one hop.
        var currentPaths = new List<List<GraphNode<T>>>
        {
            new List<GraphNode<T>> { sourceNode }
        };

        for (int depth = 0; depth < pathLength; depth++)
        {
            var nextPaths = new List<List<GraphNode<T>>>();

            foreach (var path in currentPaths)
            {
                var lastNode = path[^1];
                var edges = _graph.GetOutgoingEdges(lastNode.Id);

                if (!string.IsNullOrWhiteSpace(relationshipType))
                    edges = edges.Where(e => e.RelationType == relationshipType);

                // Resolve targets, drop dangling edges and cycles, then extend the path.
                nextPaths.AddRange(edges
                    .Select(edge => _graph.GetNode(edge.TargetId))
                    .OfType<GraphNode<T>>()
                    .Where(targetNode => !path.Any(n => n.Id == targetNode.Id))
                    .Select(targetNode => new List<GraphNode<T>>(path) { targetNode }));
            }

            currentPaths = nextPaths;
        }

        return currentPaths;
    }

    /// <summary>
    /// Finds all shortest paths between two nodes.
    /// </summary>
    /// <param name="sourceId">The source node ID.</param>
    /// <param name="targetId">The target node ID.</param>
    /// <param name="maxDepth">Maximum path length in nodes (prevents unbounded search).</param>
    /// <returns>All paths tied for the minimum length; empty when no path exists.</returns>
    /// <exception cref="ArgumentException">Thrown when either ID is null or whitespace.</exception>
    public List<List<GraphNode<T>>> FindShortestPaths(
        string sourceId,
        string targetId,
        int maxDepth = 10)
    {
        if (string.IsNullOrWhiteSpace(sourceId))
            throw new ArgumentException("Source ID cannot be null or whitespace", nameof(sourceId));
        if (string.IsNullOrWhiteSpace(targetId))
            throw new ArgumentException("Target ID cannot be null or whitespace", nameof(targetId));

        var sourceNode = _graph.GetNode(sourceId);
        var targetNode = _graph.GetNode(targetId);

        if (sourceNode == null || targetNode == null)
            return new List<List<GraphNode<T>>>();

        if (sourceId == targetId)
            return new List<List<GraphNode<T>>> { new List<GraphNode<T>> { sourceNode } };

        // BFS over partial paths. visitedAtDepth permits revisiting a node at the same
        // or an earlier depth so that ALL distinct shortest paths are found, not just one.
        var queue = new Queue<List<GraphNode<T>>>();
        var visitedAtDepth = new Dictionary<string, int>();
        var results = new List<List<GraphNode<T>>>();
        var shortestLength = int.MaxValue;

        queue.Enqueue(new List<GraphNode<T>> { sourceNode });
        visitedAtDepth[sourceId] = 0;

        while (queue.Count > 0)
        {
            var path = queue.Dequeue();
            var currentNode = path[^1];
            var currentDepth = path.Count - 1;

            // Abandon paths that exceed maxDepth or a known shortest length.
            if (path.Count > maxDepth || path.Count > shortestLength)
                continue;

            var neighbors = _graph.GetOutgoingEdges(currentNode.Id)
                .Select(edge => _graph.GetNode(edge.TargetId))
                .OfType<GraphNode<T>>();

            foreach (var neighbor in neighbors)
            {
                if (neighbor.Id == targetId)
                {
                    var found = new List<GraphNode<T>>(path) { neighbor };
                    results.Add(found);
                    shortestLength = Math.Min(shortestLength, found.Count);
                    continue;
                }

                // Avoid cycles within a single path.
                if (path.Any(n => n.Id == neighbor.Id))
                    continue;

                var newDepth = currentDepth + 1;
                if (!visitedAtDepth.TryGetValue(neighbor.Id, out var prevDepth) || newDepth <= prevDepth)
                {
                    visitedAtDepth[neighbor.Id] = newDepth;
                    queue.Enqueue(new List<GraphNode<T>>(path) { neighbor });
                }
            }
        }

        // Longer matches may have been recorded before the true shortest was found.
        return results.Where(p => p.Count == shortestLength).ToList();
    }

    /// <summary>
    /// Executes a simple pattern query string.
    /// </summary>
    /// <param name="pattern">The pattern string (simplified Cypher-like syntax).</param>
    /// <returns>List of matching paths.</returns>
    /// <exception cref="ArgumentException">Thrown when the pattern is null, whitespace, or malformed.</exception>
    /// <example>
    /// ExecutePattern("(Person)-[KNOWS]->(Person)")
    /// ExecutePattern("(Person {name: \"Alice\"})-[WORKS_AT]->(Company)")
    /// </example>
    public List<GraphPath<T>> ExecutePattern(string pattern)
    {
        if (string.IsNullOrWhiteSpace(pattern))
            throw new ArgumentException("Pattern cannot be null or whitespace", nameof(pattern));

        // Pattern: (SourceLabel {prop: "value"})-[RELATIONSHIP]->(TargetLabel {prop: "value"})
        var match = PatternRegex.Match(pattern);
        if (!match.Success)
            throw new ArgumentException($"Invalid pattern format: {pattern}. Expected format: (SourceLabel)-[RELATIONSHIP]->(TargetLabel)", nameof(pattern));

        var sourceLabel = match.Groups[1].Value;
        var sourceProps = ParseProperties(match.Groups[2].Value);
        var relationshipType = match.Groups[3].Value;
        var targetLabel = match.Groups[4].Value;
        var targetProps = ParseProperties(match.Groups[5].Value);

        return FindPaths(sourceLabel, relationshipType, targetLabel, sourceProps, targetProps);
    }

    /// <summary>
    /// Parses a property string such as <c>name: "Alice", age: 30</c> into a dictionary.
    /// </summary>
    /// <param name="propsString">Comma-separated key: value pairs; quotes around values are stripped.</param>
    /// <returns>The parsed properties, or null when the string is empty or yields no pairs.</returns>
    private Dictionary<string, object>? ParseProperties(string propsString)
    {
        if (string.IsNullOrWhiteSpace(propsString))
            return null;

        var props = new Dictionary<string, object>();

        foreach (var pair in propsString.Split(','))
        {
            // NOTE: the simplified grammar does not support ':' inside values;
            // such pairs are skipped rather than parsed.
            var parts = pair.Split(':');
            if (parts.Length != 2)
                continue;

            var key = parts[0].Trim();
            var valueStr = parts[1].Trim().Trim('"', '\'');

            // Use the invariant culture for BOTH int and double so parsing a pattern
            // does not depend on the machine's regional settings.
            if (int.TryParse(valueStr, NumberStyles.Integer, CultureInfo.InvariantCulture, out var intValue))
                props[key] = intValue;
            else if (double.TryParse(valueStr, NumberStyles.Float, CultureInfo.InvariantCulture, out var doubleValue))
                props[key] = doubleValue;
            else
                props[key] = valueStr;
        }

        return props.Count > 0 ? props : null;
    }

    /// <summary>
    /// Checks whether a node satisfies every property filter.
    /// </summary>
    private bool MatchesProperties(GraphNode<T> node, Dictionary<string, object> properties)
    {
        foreach (var kvp in properties)
        {
            if (!node.Properties.TryGetValue(kvp.Key, out var nodeValue))
                return false;

            if (!AreEqual(nodeValue, kvp.Value))
                return false;
        }

        return true;
    }

    /// <summary>
    /// Compares two property values for equality, treating numerics uniformly.
    /// </summary>
    private bool AreEqual(object obj1, object obj2)
    {
        if (obj1 == null && obj2 == null)
            return true;
        if (obj1 == null || obj2 == null)
            return false;

        // Compare numerics as doubles with a tolerance so int 30 matches long 30,
        // and floating-point representation noise does not cause false mismatches.
        if (IsNumeric(obj1) && IsNumeric(obj2))
        {
            var d1 = Convert.ToDouble(obj1);
            var d2 = Convert.ToDouble(obj2);
            const double tolerance = 1e-10;
            return Math.Abs(d1 - d2) < tolerance;
        }

        return obj1.Equals(obj2);
    }

    /// <summary>
    /// Checks whether an object is one of the supported numeric types.
    /// </summary>
    private bool IsNumeric(object obj)
    {
        return obj is int or long or float or double or decimal;
    }
}
+
/// <summary>
/// Represents a single-hop path in the graph: source node -> edge -> target node.
/// </summary>
/// <typeparam name="T">The numeric type.</typeparam>
public class GraphPath<T>
{
    /// <summary>
    /// Gets or sets the node the path starts from.
    /// </summary>
    public GraphNode<T> SourceNode { get; set; } = null!;

    /// <summary>
    /// Gets or sets the edge that connects the source node to the target node.
    /// </summary>
    public GraphEdge<T> Edge { get; set; } = null!;

    /// <summary>
    /// Gets or sets the node the path ends at.
    /// </summary>
    public GraphNode<T> TargetNode { get; set; } = null!;

    /// <summary>
    /// Returns a human-readable rendering such as (Person:p1)-[KNOWS]->(Person:p2).
    /// </summary>
    public override string ToString() =>
        $"({SourceNode.Label}:{SourceNode.Id})-[{Edge.RelationType}]->({TargetNode.Label}:{TargetNode.Id})";
}
diff --git a/src/RetrievalAugmentedGeneration/Graph/GraphTransaction.cs b/src/RetrievalAugmentedGeneration/Graph/GraphTransaction.cs
new file mode 100644
index 000000000..8af33426c
--- /dev/null
+++ b/src/RetrievalAugmentedGeneration/Graph/GraphTransaction.cs
@@ -0,0 +1,446 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Threading;
+
+namespace AiDotNet.RetrievalAugmentedGeneration.Graph;
+
/// <summary>
/// Transaction coordinator for managing transactions on graph stores with best-effort rollback.
/// </summary>
/// <typeparam name="T">The numeric type used for vector operations.</typeparam>
/// <remarks>
/// <para>
/// This class provides transaction management with the following guarantees:
/// - Atomicity (Best-Effort): if an operation fails during commit, compensating rollback
///   is attempted in reverse order. If an undo operation fails, it is swallowed and
///   rollback continues with the remaining operations; full atomicity is not guaranteed.
/// - Consistency: graph validation rules are enforced during operations.
/// - Isolation: single-threaded; no concurrent transaction support.
/// - Durability: when a WAL is provided, operations are logged before execution.
///   Without a WAL, durability is not guaranteed.
/// </para>
/// <para>
/// <b>Important:</b> This is a lightweight transaction implementation suitable for
/// single-process use cases. For full ACID compliance with crash recovery, configure
/// a WriteAheadLog, or use a database-backed store with native transactions.
/// </para>
/// <para>
/// <b>For Beginners:</b> Transactions ensure your changes are applied together or not at
/// all — like a bank transfer, where debiting Alice and crediting Bob must both happen
/// or neither should. Typical usage:
/// <code>
/// using var txn = new GraphTransaction&lt;double&gt;(store, wal);
/// txn.Begin();
/// txn.AddNode(node1);
/// txn.AddEdge(edge1);
/// txn.Commit(); // both saved; on exception call txn.Rollback()
/// </code>
/// This keeps your graph from ending up in a broken, half-updated state.
/// </para>
/// </remarks>
public class GraphTransaction<T> : IDisposable
{
    // Process-wide monotonic counter for transaction IDs. DateTime.UtcNow.Ticks alone
    // can hand the same ID to two transactions begun within one clock tick.
    private static long s_lastTransactionId;

    private readonly IGraphStore<T> _store;
    private readonly WriteAheadLog<T>? _wal;
    private readonly List<TransactionOperation<T>> _operations;
    private TransactionState _state;
    private long _transactionId;
    private bool _disposed;

    /// <summary>
    /// Gets the current state of the transaction.
    /// </summary>
    public TransactionState State => _state;

    /// <summary>
    /// Gets the transaction ID (-1 before <see cref="Begin"/> is called).
    /// </summary>
    public long TransactionId => _transactionId;

    /// <summary>
    /// Initializes a new instance of the <see cref="GraphTransaction{T}"/> class.
    /// </summary>
    /// <param name="store">The graph store to operate on.</param>
    /// <param name="wal">Optional Write-Ahead Log for durability.</param>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="store"/> is null.</exception>
    public GraphTransaction(IGraphStore<T> store, WriteAheadLog<T>? wal = null)
    {
        _store = store ?? throw new ArgumentNullException(nameof(store));
        _wal = wal;
        _operations = new List<TransactionOperation<T>>();
        _state = TransactionState.NotStarted;
        _transactionId = -1;
    }

    /// <summary>
    /// Begins a new transaction.
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown if the transaction was already started.</exception>
    public void Begin()
    {
        if (_state != TransactionState.NotStarted)
            throw new InvalidOperationException($"Transaction already in state: {_state}");

        _state = TransactionState.Active;
        // Interlocked guarantees a unique ID even if multiple transactions begin
        // concurrently or within the same clock tick.
        _transactionId = Interlocked.Increment(ref s_lastTransactionId);
        _operations.Clear();
    }

    /// <summary>
    /// Records a node addition within the transaction (applied on <see cref="Commit"/>).
    /// </summary>
    /// <param name="node">The node to add.</param>
    public void AddNode(GraphNode<T> node)
    {
        EnsureActive();

        _operations.Add(new TransactionOperation<T>
        {
            Type = GraphOperationType.AddNode,
            Node = node
        });
    }

    /// <summary>
    /// Records an edge addition within the transaction (applied on <see cref="Commit"/>).
    /// </summary>
    /// <param name="edge">The edge to add.</param>
    public void AddEdge(GraphEdge<T> edge)
    {
        EnsureActive();

        _operations.Add(new TransactionOperation<T>
        {
            Type = GraphOperationType.AddEdge,
            Edge = edge
        });
    }

    /// <summary>
    /// Records a node removal within the transaction (applied on <see cref="Commit"/>).
    /// </summary>
    /// <param name="nodeId">The ID of the node to remove.</param>
    public void RemoveNode(string nodeId)
    {
        EnsureActive();

        // Capture the node's current state so a failed commit can restore it.
        // NOTE(review): only the node itself is captured; if the store also removes
        // incident edges as part of RemoveNode, those edges are NOT restored by undo —
        // confirm IGraphStore<T>.RemoveNode semantics.
        var originalNode = _store.GetNode(nodeId);

        _operations.Add(new TransactionOperation<T>
        {
            Type = GraphOperationType.RemoveNode,
            NodeId = nodeId,
            Node = originalNode // Stored for undo
        });
    }

    /// <summary>
    /// Records an edge removal within the transaction (applied on <see cref="Commit"/>).
    /// </summary>
    /// <param name="edgeId">The ID of the edge to remove.</param>
    public void RemoveEdge(string edgeId)
    {
        EnsureActive();

        // Capture the edge's current state so a failed commit can restore it.
        var originalEdge = _store.GetEdge(edgeId);

        _operations.Add(new TransactionOperation<T>
        {
            Type = GraphOperationType.RemoveEdge,
            EdgeId = edgeId,
            Edge = originalEdge // Stored for undo
        });
    }

    /// <summary>
    /// Commits the transaction, applying all operations with best-effort atomicity.
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown if the transaction is not active.</exception>
    /// <remarks>
    /// <para>
    /// If an operation fails mid-way, compensating rollback logic undoes the
    /// already-applied operations in reverse order. This is best-effort: if an
    /// undo operation throws, the exception is swallowed and rollback continues
    /// with the remaining operations.
    /// </para>
    /// <para>
    /// After a failed commit, the graph may therefore be left partially modified
    /// if undo operations also fail. For strict atomicity, use a database-backed
    /// graph store with native transaction support.
    /// </para>
    /// </remarks>
    public void Commit()
    {
        EnsureActive();

        var appliedOperations = new List<TransactionOperation<T>>();

        try
        {
            // Durability: log intent to the WAL before touching the store.
            if (_wal != null)
            {
                foreach (var op in _operations)
                {
                    LogOperation(op);
                }
            }

            // Apply all operations, tracking which ones succeed for possible undo.
            foreach (var op in _operations)
            {
                ApplyOperation(op);
                appliedOperations.Add(op);
            }

            // Checkpoint marks the logged operations as fully applied.
            _wal?.LogCheckpoint();

            _state = TransactionState.Committed;
        }
        catch (Exception)
        {
            // Compensating rollback: undo applied operations in reverse order.
            for (int i = appliedOperations.Count - 1; i >= 0; i--)
            {
                try
                {
                    UndoOperation(appliedOperations[i]);
                }
                catch
                {
                    // Best-effort rollback - continue with remaining undo operations.
                }
            }

            _state = TransactionState.Failed;
            throw;
        }
    }

    /// <summary>
    /// Rolls back the transaction, discarding all recorded operations.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// Thrown when the transaction is neither Active nor Failed.
    /// </exception>
    public void Rollback()
    {
        if (_state != TransactionState.Active && _state != TransactionState.Failed)
            throw new InvalidOperationException($"Cannot rollback transaction in state: {_state}");

        // Active: nothing was applied yet, so discarding the queue is sufficient.
        // Failed: compensating undo already ran inside Commit; only the queue remains.
        _operations.Clear();
        _state = TransactionState.RolledBack;
    }

    /// <summary>
    /// Ensures the transaction is in the Active state.
    /// </summary>
    private void EnsureActive()
    {
        if (_state != TransactionState.Active)
            throw new InvalidOperationException($"Transaction not active. Current state: {_state}");
    }

    /// <summary>
    /// Logs an operation to the WAL (no-op when no WAL is configured).
    /// </summary>
    private void LogOperation(TransactionOperation<T> op)
    {
        if (_wal == null)
            return;

        switch (op.Type)
        {
            case GraphOperationType.AddNode:
                _wal.LogAddNode(op.Node!);
                break;
            case GraphOperationType.AddEdge:
                _wal.LogAddEdge(op.Edge!);
                break;
            case GraphOperationType.RemoveNode:
                _wal.LogRemoveNode(op.NodeId!);
                break;
            case GraphOperationType.RemoveEdge:
                _wal.LogRemoveEdge(op.EdgeId!);
                break;
        }
    }

    /// <summary>
    /// Applies a recorded operation to the graph store.
    /// </summary>
    private void ApplyOperation(TransactionOperation<T> op)
    {
        switch (op.Type)
        {
            case GraphOperationType.AddNode:
                if (op.Node != null)
                    _store.AddNode(op.Node);
                break;
            case GraphOperationType.AddEdge:
                if (op.Edge != null)
                    _store.AddEdge(op.Edge);
                break;
            case GraphOperationType.RemoveNode:
                if (op.NodeId != null)
                    _store.RemoveNode(op.NodeId);
                break;
            case GraphOperationType.RemoveEdge:
                if (op.EdgeId != null)
                    _store.RemoveEdge(op.EdgeId);
                break;
        }
    }

    /// <summary>
    /// Undoes an already-applied operation (compensating action).
    /// </summary>
    /// <remarks>
    /// <para>
    /// Each operation type is reversed:
    /// - AddNode → RemoveNode (using the stored node's ID)
    /// - AddEdge → RemoveEdge (using the stored edge's ID)
    /// - RemoveNode → AddNode (re-adds the captured original node)
    /// - RemoveEdge → AddEdge (re-adds the captured original edge)
    /// </para>
    /// <para>
    /// The original node/edge data is captured when <see cref="RemoveNode"/> /
    /// <see cref="RemoveEdge"/> record the operation, allowing restoration here.
    /// </para>
    /// </remarks>
    private void UndoOperation(TransactionOperation<T> op)
    {
        switch (op.Type)
        {
            case GraphOperationType.AddNode:
                // Undo add by removing the node
                if (op.Node != null)
                    _store.RemoveNode(op.Node.Id);
                break;
            case GraphOperationType.AddEdge:
                // Undo add by removing the edge
                if (op.Edge != null)
                    _store.RemoveEdge(op.Edge.Id);
                break;
            case GraphOperationType.RemoveNode:
                // Undo remove by re-adding the node (if we have the original data)
                if (op.Node != null)
                    _store.AddNode(op.Node);
                break;
            case GraphOperationType.RemoveEdge:
                // Undo remove by re-adding the edge (if we have the original data)
                if (op.Edge != null)
                    _store.AddEdge(op.Edge);
                break;
        }
    }

    /// <summary>
    /// Disposes the transaction, rolling back if still active.
    /// </summary>
    public void Dispose()
    {
        if (_disposed)
            return;

        // Auto-rollback if still active: Dispose must not throw per IDisposable contract.
        if (_state == TransactionState.Active)
        {
            try
            {
                Rollback();
            }
            catch (InvalidOperationException ex)
            {
                // Suppressed but traced: transaction may already be in an invalid state.
                Debug.WriteLine($"GraphTransaction.Dispose: Rollback failed with InvalidOperationException: {ex.Message}");
            }
            catch (IOException ex)
            {
                // Suppressed but traced: the underlying store may be unavailable.
                Debug.WriteLine($"GraphTransaction.Dispose: Rollback failed with IOException: {ex.Message}");
            }
        }

        _disposed = true;
    }
}
+
/// <summary>
/// Represents a single recorded operation within a transaction.
/// </summary>
/// <typeparam name="T">The numeric type.</typeparam>
internal class TransactionOperation<T>
{
    /// <summary>The kind of graph mutation this entry represents.</summary>
    public GraphOperationType Type { get; set; }

    /// <summary>Node payload for AddNode; for RemoveNode, the captured original node used for undo.</summary>
    public GraphNode<T>? Node { get; set; }

    /// <summary>Edge payload for AddEdge; for RemoveEdge, the captured original edge used for undo.</summary>
    public GraphEdge<T>? Edge { get; set; }

    /// <summary>Target node ID for RemoveNode operations.</summary>
    public string? NodeId { get; set; }

    /// <summary>Target edge ID for RemoveEdge operations.</summary>
    public string? EdgeId { get; set; }
}
+
/// <summary>
/// Types of operations supported in graph transactions.
/// </summary>
internal enum GraphOperationType
{
    /// <summary>Add a node to the graph store.</summary>
    AddNode,

    /// <summary>Add an edge to the graph store.</summary>
    AddEdge,

    /// <summary>Remove a node from the graph store.</summary>
    RemoveNode,

    /// <summary>Remove an edge from the graph store.</summary>
    RemoveEdge
}
+
/// <summary>
/// Represents the lifecycle state of a graph transaction.
/// </summary>
public enum TransactionState
{
    /// <summary>
    /// No Begin() call has been made yet.
    /// </summary>
    NotStarted,

    /// <summary>
    /// The transaction is open and recording operations.
    /// </summary>
    Active,

    /// <summary>
    /// Commit completed and all operations were applied.
    /// </summary>
    Committed,

    /// <summary>
    /// The transaction was rolled back and its operations discarded.
    /// </summary>
    RolledBack,

    /// <summary>
    /// Commit failed; compensating rollback was attempted.
    /// </summary>
    Failed
}
diff --git a/src/RetrievalAugmentedGeneration/Graph/HybridGraphRetriever.cs b/src/RetrievalAugmentedGeneration/Graph/HybridGraphRetriever.cs
new file mode 100644
index 000000000..d7851c6b7
--- /dev/null
+++ b/src/RetrievalAugmentedGeneration/Graph/HybridGraphRetriever.cs
@@ -0,0 +1,383 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+using AiDotNet.Helpers;
+using AiDotNet.Interfaces;
+using AiDotNet.LinearAlgebra;
+using AiDotNet.RetrievalAugmentedGeneration.Models;
+
+namespace AiDotNet.RetrievalAugmentedGeneration.Graph;
+
/// <summary>
/// Hybrid retriever that combines vector similarity search with graph traversal for enhanced RAG.
/// </summary>
/// <typeparam name="T">The numeric type used for vector operations.</typeparam>
/// <remarks>
/// <para>
/// This retriever uses a two-stage approach:
/// 1. Vector similarity search to find initial candidate nodes.
/// 2. Graph traversal to expand context with related nodes.
/// </para>
/// <para><b>For Beginners:</b> Traditional RAG uses only vector similarity:
///
/// Query: "What is photosynthesis?"
/// Traditional RAG:
/// - Find documents similar to the query
/// - Return top-k matches
/// - Misses related context!
///
/// Hybrid Graph RAG:
/// - Find initial matches via vector similarity
/// - Walk the graph to find related concepts
/// - Example: photosynthesis → chlorophyll → plants → carbon dioxide
/// - Provides richer, more complete context
///
/// Real-world analogy:
/// - Traditional: Search "Paris" → get Paris documents
/// - Hybrid: Search "Paris" → get Paris + France + Eiffel Tower + Seine River
/// - Graph connections provide context vectors can't capture!
/// </para>
/// </remarks>
public class HybridGraphRetriever<T>
{
    private readonly KnowledgeGraph<T> _graph;
    private readonly IDocumentStore<T> _documentStore;

    /// <summary>
    /// Per-hop attenuation applied to graph-traversed scores (score *= 0.8^depth),
    /// so nodes closer to a vector-matched candidate rank higher.
    /// </summary>
    private const double DepthDecayFactor = 0.8;

    /// <summary>
    /// Initializes a new instance of the <see cref="HybridGraphRetriever{T}"/> class.
    /// </summary>
    /// <param name="graph">The knowledge graph containing entity relationships.</param>
    /// <param name="documentStore">The document store used for vector similarity search.</param>
    /// <exception cref="ArgumentNullException">Thrown when either argument is null.</exception>
    public HybridGraphRetriever(
        KnowledgeGraph<T> graph,
        IDocumentStore<T> documentStore)
    {
        _graph = graph ?? throw new ArgumentNullException(nameof(graph));
        _documentStore = documentStore ?? throw new ArgumentNullException(nameof(documentStore));
    }

    /// <summary>
    /// Retrieves relevant nodes using the hybrid vector + graph approach.
    /// </summary>
    /// <param name="queryEmbedding">The query embedding vector.</param>
    /// <param name="topK">Number of initial candidates to retrieve via vector search.</param>
    /// <param name="expansionDepth">How many hops to traverse in the graph (0 = no expansion).</param>
    /// <param name="maxResults">Maximum total results to return after expansion.</param>
    /// <returns>List of retrieved nodes with their relevance scores, best first.</returns>
    /// <exception cref="ArgumentException">Thrown when the query embedding is null or empty.</exception>
    /// <exception cref="ArgumentOutOfRangeException">Thrown when a count argument is out of range.</exception>
    public List<RetrievalResult<T>> Retrieve(
        Vector<T> queryEmbedding,
        int topK = 5,
        int expansionDepth = 1,
        int maxResults = 10)
    {
        if (queryEmbedding == null || queryEmbedding.Length == 0)
            throw new ArgumentException("Query embedding cannot be null or empty", nameof(queryEmbedding));
        if (topK <= 0)
            throw new ArgumentOutOfRangeException(nameof(topK), "topK must be positive");
        if (expansionDepth < 0)
            throw new ArgumentOutOfRangeException(nameof(expansionDepth), "expansionDepth cannot be negative");
        if (maxResults <= 0)
            throw new ArgumentOutOfRangeException(nameof(maxResults), "maxResults must be positive");

        // Stage 1: vector similarity search for initial candidates.
        var initialCandidates = _documentStore.GetSimilar(queryEmbedding, topK).ToList();

        if (expansionDepth == 0)
        {
            // No graph expansion requested - return vector results only.
            return initialCandidates
                .Take(maxResults)
                .Select(doc => new RetrievalResult<T>
                {
                    NodeId = doc.Id,
                    Score = doc.HasRelevanceScore ? Convert.ToDouble(doc.RelevanceScore) : 0.0,
                    Source = RetrievalSource.VectorSearch,
                    Embedding = doc.Embedding
                })
                .ToList();
        }

        // Stage 2: BFS graph expansion outward from the initial candidates.
        var results = new Dictionary<string, RetrievalResult<T>>();
        var visited = new HashSet<string>();
        var queue = new Queue<(string NodeId, int Depth)>();

        foreach (var candidate in initialCandidates)
        {
            results[candidate.Id] = new RetrievalResult<T>
            {
                NodeId = candidate.Id,
                Score = candidate.HasRelevanceScore ? Convert.ToDouble(candidate.RelevanceScore) : 0.0,
                Source = RetrievalSource.VectorSearch,
                Embedding = candidate.Embedding,
                Depth = 0
            };
            visited.Add(candidate.Id);
            queue.Enqueue((candidate.Id, 0));
        }

        while (queue.Count > 0)
        {
            var (currentId, currentDepth) = queue.Dequeue();

            if (currentDepth >= expansionDepth)
                continue;

            foreach (var neighborId in GetNeighbors(currentId).Where(n => !visited.Contains(n)))
            {
                visited.Add(neighborId);

                // Score the neighbor by its graph node's embedding, if it has one.
                var neighborNode = _graph.GetNode(neighborId);
                var neighborEmbedding = neighborNode?.Embedding;
                double score = 0.0;

                if (neighborEmbedding != null && neighborEmbedding.Length > 0)
                {
                    // Similarity to the query, attenuated by traversal depth.
                    score = CalculateSimilarity(queryEmbedding, neighborEmbedding)
                          * Math.Pow(DepthDecayFactor, currentDepth + 1);
                }

                var result = new RetrievalResult<T>
                {
                    NodeId = neighborId,
                    Score = score,
                    Source = RetrievalSource.GraphTraversal,
                    Embedding = neighborEmbedding,
                    Depth = currentDepth + 1,
                    ParentNodeId = currentId
                };

                // Keep only the highest-scoring path to each node.
                if (!results.TryGetValue(neighborId, out var existing) || result.Score > existing.Score)
                    results[neighborId] = result;

                if (currentDepth + 1 < expansionDepth)
                    queue.Enqueue((neighborId, currentDepth + 1));
            }
        }

        return results.Values
            .OrderByDescending(r => r.Score)
            .Take(maxResults)
            .ToList();
    }

    /// <summary>
    /// Retrieves relevant nodes asynchronously using the hybrid approach.
    /// </summary>
    /// <remarks>
    /// Currently offloads the synchronous <see cref="Retrieve"/> to a background
    /// task; a future implementation could use natively async store operations.
    /// </remarks>
    public async Task<List<RetrievalResult<T>>> RetrieveAsync(
        Vector<T> queryEmbedding,
        int topK = 5,
        int expansionDepth = 1,
        int maxResults = 10)
    {
        return await Task.Run(() => Retrieve(queryEmbedding, topK, expansionDepth, maxResults));
    }

    /// <summary>
    /// Retrieves nodes with relationship-aware scoring (one-hop expansion only).
    /// </summary>
    /// <param name="queryEmbedding">The query embedding vector.</param>
    /// <param name="topK">Number of initial candidates.</param>
    /// <param name="relationshipWeights">Weights for different relationship types (default weight is 1.0).</param>
    /// <param name="maxResults">Maximum results to return.</param>
    /// <returns>List of retrieved nodes with relationship-aware scores, best first.</returns>
    /// <exception cref="ArgumentException">Thrown when the query embedding is null or empty.</exception>
    /// <exception cref="ArgumentOutOfRangeException">Thrown when a count argument is out of range.</exception>
    public List<RetrievalResult<T>> RetrieveWithRelationships(
        Vector<T> queryEmbedding,
        int topK = 5,
        Dictionary<string, double>? relationshipWeights = null,
        int maxResults = 10)
    {
        if (queryEmbedding == null || queryEmbedding.Length == 0)
            throw new ArgumentException("Query embedding cannot be null or empty", nameof(queryEmbedding));
        if (topK <= 0)
            throw new ArgumentOutOfRangeException(nameof(topK), "topK must be positive");
        if (maxResults <= 0)
            throw new ArgumentOutOfRangeException(nameof(maxResults), "maxResults must be positive");

        relationshipWeights ??= new Dictionary<string, double>();

        // Stage 1: vector similarity search.
        var initialCandidates = _documentStore.GetSimilar(queryEmbedding, topK).ToList();
        var results = new Dictionary<string, RetrievalResult<T>>();

        foreach (var candidate in initialCandidates)
        {
            results[candidate.Id] = new RetrievalResult<T>
            {
                NodeId = candidate.Id,
                Score = candidate.HasRelevanceScore ? Convert.ToDouble(candidate.RelevanceScore) : 0.0,
                Source = RetrievalSource.VectorSearch,
                Embedding = candidate.Embedding,
                Depth = 0
            };
        }

        // Stage 2: expand one hop via outgoing relationships with weighted scoring.
        foreach (var candidate in initialCandidates)
        {
            var node = _graph.GetNode(candidate.Id);
            if (node == null)
                continue;

            var outgoingEdges = _graph.GetOutgoingEdges(candidate.Id);

            foreach (var edge in outgoingEdges.Where(e => !results.ContainsKey(e.TargetId)))
            {
                // Relationship weight defaults to 1.0 for unlisted relation types.
                var weight = relationshipWeights.TryGetValue(edge.RelationType, out var w) ? w : 1.0;

                var targetNode = _graph.GetNode(edge.TargetId);
                var targetEmbedding = targetNode?.Embedding;
                double score = 0.0;

                if (targetEmbedding != null && targetEmbedding.Length > 0)
                {
                    score = CalculateSimilarity(queryEmbedding, targetEmbedding);
                    score *= weight;            // Relationship-type weight.
                    score *= DepthDecayFactor;  // One-hop penalty.
                }

                var result = new RetrievalResult<T>
                {
                    NodeId = edge.TargetId,
                    Score = score,
                    Source = RetrievalSource.GraphTraversal,
                    Embedding = targetEmbedding,
                    Depth = 1,
                    ParentNodeId = candidate.Id,
                    RelationType = edge.RelationType
                };

                // Keep only the highest-scoring path to each node.
                if (!results.TryGetValue(edge.TargetId, out var existing) || result.Score > existing.Score)
                    results[edge.TargetId] = result;
            }
        }

        return results.Values
            .OrderByDescending(r => r.Score)
            .Take(maxResults)
            .ToList();
    }

    /// <summary>
    /// Gets the IDs of all neighbors (both incoming and outgoing) of a node.
    /// </summary>
    private HashSet<string> GetNeighbors(string nodeId)
    {
        var neighbors = new HashSet<string>();

        foreach (var edge in _graph.GetOutgoingEdges(nodeId))
            neighbors.Add(edge.TargetId);

        foreach (var edge in _graph.GetIncomingEdges(nodeId))
            neighbors.Add(edge.SourceId);

        return neighbors;
    }

    /// <summary>
    /// Calculates cosine similarity between two embeddings; returns 0 for
    /// mismatched dimensions rather than throwing.
    /// </summary>
    private double CalculateSimilarity(Vector<T> embedding1, Vector<T> embedding2)
    {
        if (embedding1.Length != embedding2.Length)
            return 0.0;

        var similarity = StatisticsHelper.CosineSimilarity(embedding1, embedding2);
        return Convert.ToDouble(similarity);
    }
}
+
/// <summary>
/// A single result produced by the hybrid retriever.
/// </summary>
/// <typeparam name="T">The numeric type.</typeparam>
public class RetrievalResult<T>
{
    /// <summary>The ID of the retrieved node.</summary>
    public string NodeId { get; set; } = string.Empty;

    /// <summary>Relevance score; higher is more relevant (typically 0-1).</summary>
    public double Score { get; set; }

    /// <summary>How this result was retrieved.</summary>
    public RetrievalSource Source { get; set; }

    /// <summary>The node's embedding vector, when available.</summary>
    public Vector<T>? Embedding { get; set; }

    /// <summary>Graph traversal depth; 0 for initial vector-search candidates.</summary>
    public int Depth { get; set; }

    /// <summary>ID of the node this result was reached from (graph-traversed results only).</summary>
    public string? ParentNodeId { get; set; }

    /// <summary>Relationship type of the edge that was traversed (graph-traversed results only).</summary>
    public string? RelationType { get; set; }
}
+
/// <summary>
/// Indicates how a retrieval result was obtained.
/// </summary>
public enum RetrievalSource
{
    /// <summary>Found via vector similarity search.</summary>
    VectorSearch,

    /// <summary>Found via graph traversal.</summary>
    GraphTraversal,

    /// <summary>Found via both vector search and graph traversal.</summary>
    Hybrid
}
diff --git a/src/RetrievalAugmentedGeneration/Graph/KnowledgeGraph.cs b/src/RetrievalAugmentedGeneration/Graph/KnowledgeGraph.cs
index 9c9f1c3a7..f62ef742f 100644
--- a/src/RetrievalAugmentedGeneration/Graph/KnowledgeGraph.cs
+++ b/src/RetrievalAugmentedGeneration/Graph/KnowledgeGraph.cs
@@ -1,65 +1,72 @@
using System;
using System.Collections.Generic;
using System.Linq;
+using AiDotNet.Interfaces;
namespace AiDotNet.RetrievalAugmentedGeneration.Graph;
///
-/// In-memory knowledge graph for storing and querying entity relationships.
+/// Knowledge graph for storing and querying entity relationships using a pluggable storage backend.
///
/// The numeric type used for vector operations.
///
///
/// A knowledge graph stores entities (nodes) and their relationships (edges) to enable structured information retrieval.
-/// This implementation uses efficient in-memory data structures optimized for graph traversal and querying.
+/// This implementation delegates storage operations to an implementation,
+/// allowing you to swap between in-memory, file-based, or database-backed storage.
///
/// For Beginners: A knowledge graph is like a map of how information connects together.
-///
+///
/// Imagine Wikipedia as a graph:
/// - Each article is a node (Albert Einstein, Physics, Germany, etc.)
/// - Links between articles are edges (Einstein STUDIED Physics, Einstein BORN_IN Germany)
/// - You can traverse the graph to find related information
-///
+///
/// This class lets you:
/// 1. Add entities and relationships
/// 2. Find connections between entities
/// 3. Traverse the graph to discover related information
/// 4. Query based on entity types or relationships
-///
+///
/// For example, to answer "Who worked at Princeton?":
/// 1. Find all edges with type "WORKED_AT"
/// 2. Filter for target = "Princeton University"
/// 3. Return the source entities (people who worked there)
+///
+/// Storage backends you can use:
+/// - MemoryGraphStore: Fast, in-memory (default)
+/// - FileGraphStore: Persistent, disk-based
+/// - Neo4jGraphStore: Professional graph database (future)
///
///
public class KnowledgeGraph
{
- private readonly Dictionary> _nodes;
- private readonly Dictionary> _edges;
- private readonly Dictionary> _outgoingEdges; // nodeId -> edge IDs going out
- private readonly Dictionary> _incomingEdges; // nodeId -> edge IDs coming in
- private readonly Dictionary> _nodesByLabel; // label -> node IDs
-
+ private readonly IGraphStore _store;
+
///
/// Gets the total number of nodes in the graph.
///
- public int NodeCount => _nodes.Count;
-
+ public int NodeCount => _store.NodeCount;
+
///
/// Gets the total number of edges in the graph.
///
- public int EdgeCount => _edges.Count;
-
+ public int EdgeCount => _store.EdgeCount;
+
+ ///
+ /// Initializes a new instance of the class with a custom graph store.
+ ///
+ /// The graph store implementation to use for storage.
+ public KnowledgeGraph(IGraphStore store)
+ {
+ _store = store ?? throw new ArgumentNullException(nameof(store));
+ }
+
///
- /// Initializes a new instance of the class.
+ /// Initializes a new instance of the class with default in-memory storage.
///
- public KnowledgeGraph()
+ public KnowledgeGraph() : this(new MemoryGraphStore())
{
- _nodes = new Dictionary>();
- _edges = new Dictionary>();
- _outgoingEdges = new Dictionary>();
- _incomingEdges = new Dictionary>();
- _nodesByLabel = new Dictionary>();
}
///
@@ -68,21 +75,9 @@ public KnowledgeGraph()
/// The node to add.
public void AddNode(GraphNode node)
{
- if (node == null)
- throw new ArgumentNullException(nameof(node));
-
- _nodes[node.Id] = node;
-
- if (!_nodesByLabel.ContainsKey(node.Label))
- _nodesByLabel[node.Label] = new HashSet();
- _nodesByLabel[node.Label].Add(node.Id);
-
- if (!_outgoingEdges.ContainsKey(node.Id))
- _outgoingEdges[node.Id] = new HashSet();
- if (!_incomingEdges.ContainsKey(node.Id))
- _incomingEdges[node.Id] = new HashSet();
+ _store.AddNode(node);
}
-
+
///
/// Adds an edge to the graph.
///
@@ -90,18 +85,9 @@ public void AddNode(GraphNode node)
/// Thrown when source or target nodes don't exist.
public void AddEdge(GraphEdge edge)
{
- if (edge == null)
- throw new ArgumentNullException(nameof(edge));
- if (!_nodes.ContainsKey(edge.SourceId))
- throw new InvalidOperationException($"Source node '{edge.SourceId}' does not exist");
- if (!_nodes.ContainsKey(edge.TargetId))
- throw new InvalidOperationException($"Target node '{edge.TargetId}' does not exist");
-
- _edges[edge.Id] = edge;
- _outgoingEdges[edge.SourceId].Add(edge.Id);
- _incomingEdges[edge.TargetId].Add(edge.Id);
+ _store.AddEdge(edge);
}
-
+
///
/// Gets a node by its ID.
///
@@ -109,9 +95,9 @@ public void AddEdge(GraphEdge edge)
/// The node, or null if not found.
public GraphNode? GetNode(string nodeId)
{
- return _nodes.TryGetValue(nodeId, out var node) ? node : null;
+ return _store.GetNode(nodeId);
}
-
+
///
/// Gets all nodes with a specific label.
///
@@ -119,12 +105,9 @@ public void AddEdge(GraphEdge edge)
/// Collection of nodes with the specified label.
public IEnumerable> GetNodesByLabel(string label)
{
- if (!_nodesByLabel.TryGetValue(label, out var nodeIds))
- return Enumerable.Empty>();
-
- return nodeIds.Select(id => _nodes[id]);
+ return _store.GetNodesByLabel(label);
}
-
+
///
/// Gets all outgoing edges from a node.
///
@@ -132,12 +115,9 @@ public IEnumerable> GetNodesByLabel(string label)
/// Collection of outgoing edges.
public IEnumerable> GetOutgoingEdges(string nodeId)
{
- if (!_outgoingEdges.TryGetValue(nodeId, out var edgeIds))
- return Enumerable.Empty>();
-
- return edgeIds.Select(id => _edges[id]);
+ return _store.GetOutgoingEdges(nodeId);
}
-
+
///
/// Gets all incoming edges to a node.
///
@@ -145,10 +125,7 @@ public IEnumerable> GetOutgoingEdges(string nodeId)
/// Collection of incoming edges.
public IEnumerable> GetIncomingEdges(string nodeId)
{
- if (!_incomingEdges.TryGetValue(nodeId, out var edgeIds))
- return Enumerable.Empty>();
-
- return edgeIds.Select(id => _edges[id]);
+ return _store.GetIncomingEdges(nodeId);
}
///
@@ -159,7 +136,9 @@ public IEnumerable> GetIncomingEdges(string nodeId)
public IEnumerable> GetNeighbors(string nodeId)
{
var edges = GetOutgoingEdges(nodeId);
- return edges.Select(e => _nodes[e.TargetId]);
+ return edges
+ .Select(e => _store.GetNode(e.TargetId))
+ .OfType>();
}
///
@@ -170,22 +149,28 @@ public IEnumerable> GetNeighbors(string nodeId)
/// Collection of nodes in BFS order.
public IEnumerable> BreadthFirstTraversal(string startNodeId, int maxDepth = int.MaxValue)
{
- if (!_nodes.ContainsKey(startNodeId))
+ if (_store.GetNode(startNodeId) == null)
yield break;
-
+
var visited = new HashSet();
var queue = new Queue<(string nodeId, int depth)>();
queue.Enqueue((startNodeId, 0));
visited.Add(startNodeId);
-
+
while (queue.Count > 0)
{
var (nodeId, depth) = queue.Dequeue();
- yield return _nodes[nodeId];
-
+ var node = _store.GetNode(nodeId);
+
+ // Skip if node was removed between queue add and dequeue
+ if (node == null)
+ continue;
+
+ yield return node;
+
if (depth >= maxDepth)
continue;
-
+
foreach (var edge in GetOutgoingEdges(nodeId))
{
if (!visited.Contains(edge.TargetId))
@@ -205,20 +190,20 @@ public IEnumerable> BreadthFirstTraversal(string startNodeId, int m
/// List of node IDs representing the path, or empty if no path exists.
public List FindShortestPath(string startNodeId, string endNodeId)
{
- if (!_nodes.ContainsKey(startNodeId) || !_nodes.ContainsKey(endNodeId))
+ if (_store.GetNode(startNodeId) == null || _store.GetNode(endNodeId) == null)
return new List();
-
+
var visited = new HashSet();
var parent = new Dictionary();
var queue = new Queue();
-
+
queue.Enqueue(startNodeId);
visited.Add(startNodeId);
-
+
while (queue.Count > 0)
{
var nodeId = queue.Dequeue();
-
+
if (nodeId == endNodeId)
{
// Reconstruct path
@@ -233,7 +218,7 @@ public List FindShortestPath(string startNodeId, string endNodeId)
path.Reverse();
return path;
}
-
+
foreach (var edge in GetOutgoingEdges(nodeId))
{
if (!visited.Contains(edge.TargetId))
@@ -244,7 +229,7 @@ public List FindShortestPath(string startNodeId, string endNodeId)
}
}
}
-
+
return new List(); // No path found
}
@@ -257,8 +242,8 @@ public List FindShortestPath(string startNodeId, string endNodeId)
public IEnumerable> FindRelatedNodes(string query, int topK = 10)
{
var queryLower = query.ToLowerInvariant();
-
- return _nodes.Values
+
+ return _store.GetAllNodes()
.Where(node =>
{
var name = node.GetProperty("name") ?? node.Id;
@@ -267,34 +252,30 @@ public IEnumerable> FindRelatedNodes(string query, int topK = 10)
})
.Take(topK);
}
-
+
///
/// Clears all nodes and edges from the graph.
///
public void Clear()
{
- _nodes.Clear();
- _edges.Clear();
- _outgoingEdges.Clear();
- _incomingEdges.Clear();
- _nodesByLabel.Clear();
+ _store.Clear();
}
-
+
///
/// Gets all nodes in the graph.
///
/// Collection of all nodes.
public IEnumerable> GetAllNodes()
{
- return _nodes.Values;
+ return _store.GetAllNodes();
}
-
+
///
/// Gets all edges in the graph.
///
/// Collection of all edges.
public IEnumerable> GetAllEdges()
{
- return _edges.Values;
+ return _store.GetAllEdges();
}
}
diff --git a/src/RetrievalAugmentedGeneration/Graph/MemoryGraphStore.cs b/src/RetrievalAugmentedGeneration/Graph/MemoryGraphStore.cs
new file mode 100644
index 000000000..76d226b5a
--- /dev/null
+++ b/src/RetrievalAugmentedGeneration/Graph/MemoryGraphStore.cs
@@ -0,0 +1,323 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+using AiDotNet.Interfaces;
+
+namespace AiDotNet.RetrievalAugmentedGeneration.Graph;
+
/// <summary>
/// In-memory implementation of <see cref="IGraphStore{T}"/> using dictionaries for fast lookups.
/// </summary>
/// <typeparam name="T">The numeric type used for vector operations.</typeparam>
/// <remarks>
/// <para>
/// This implementation provides high-performance graph storage entirely in RAM.
/// All operations are O(1) or O(degree) complexity. Data is lost when the application stops.
/// </para>
/// <para>
/// <b>Thread Safety:</b> This class is NOT thread-safe. Callers must ensure proper
/// synchronization when accessing from multiple threads.
/// </para>
/// <para><b>For Beginners:</b> This stores your graph in the computer's memory (RAM).
///
/// Pros:
/// - Very fast (everything in RAM)
/// - Simple to use (no setup required)
///
/// Cons:
/// - Data lost when app closes
/// - Limited by available RAM
/// - Not thread-safe (single-threaded use only)
///
/// Good for development, testing, and small-to-medium temporary graphs.
/// For persistent storage, use FileGraphStore or Neo4jGraphStore instead.
/// </para>
/// </remarks>
public class MemoryGraphStore<T> : IGraphStore<T>
{
    // Primary storage: id -> entity.
    private readonly Dictionary<string, GraphNode<T>> _nodes;
    private readonly Dictionary<string, GraphEdge<T>> _edges;

    // Adjacency indexes: nodeId -> IDs of edges leaving / entering that node.
    private readonly Dictionary<string, HashSet<string>> _outgoingEdges;
    private readonly Dictionary<string, HashSet<string>> _incomingEdges;

    // Secondary index: label -> IDs of nodes carrying that label.
    private readonly Dictionary<string, HashSet<string>> _nodesByLabel;

    /// <inheritdoc />
    public int NodeCount => _nodes.Count;

    /// <inheritdoc />
    public int EdgeCount => _edges.Count;

    /// <summary>
    /// Initializes a new, empty instance of the <see cref="MemoryGraphStore{T}"/> class.
    /// </summary>
    public MemoryGraphStore()
    {
        _nodes = new Dictionary<string, GraphNode<T>>();
        _edges = new Dictionary<string, GraphEdge<T>>();
        _outgoingEdges = new Dictionary<string, HashSet<string>>();
        _incomingEdges = new Dictionary<string, HashSet<string>>();
        _nodesByLabel = new Dictionary<string, HashSet<string>>();
    }

    /// <inheritdoc />
    public void AddNode(GraphNode<T> node)
    {
        if (node == null)
            throw new ArgumentNullException(nameof(node));

        // If the node is being replaced under a new label, drop the old label
        // index entry so GetNodesByLabel doesn't report it twice.
        if (_nodes.TryGetValue(node.Id, out var existingNode) && existingNode.Label != node.Label)
        {
            if (_nodesByLabel.TryGetValue(existingNode.Label, out var oldLabelNodeIds))
            {
                oldLabelNodeIds.Remove(node.Id);
                if (oldLabelNodeIds.Count == 0)
                    _nodesByLabel.Remove(existingNode.Label);
            }
        }

        _nodes[node.Id] = node;

        if (!_nodesByLabel.TryGetValue(node.Label, out var labelNodeIds))
        {
            labelNodeIds = new HashSet<string>();
            _nodesByLabel[node.Label] = labelNodeIds;
        }
        labelNodeIds.Add(node.Id);

        // Ensure adjacency sets exist so AddEdge can index unconditionally.
        if (!_outgoingEdges.ContainsKey(node.Id))
            _outgoingEdges[node.Id] = new HashSet<string>();
        if (!_incomingEdges.ContainsKey(node.Id))
            _incomingEdges[node.Id] = new HashSet<string>();
    }

    /// <inheritdoc />
    public void AddEdge(GraphEdge<T> edge)
    {
        if (edge == null)
            throw new ArgumentNullException(nameof(edge));
        if (!_nodes.ContainsKey(edge.SourceId))
            throw new InvalidOperationException($"Source node '{edge.SourceId}' does not exist");
        if (!_nodes.ContainsKey(edge.TargetId))
            throw new InvalidOperationException($"Target node '{edge.TargetId}' does not exist");

        // If an edge with this ID already exists, remove its old adjacency
        // entries first; otherwise re-adding an edge with different endpoints
        // would leave stale IDs in the previous source/target index sets.
        if (_edges.TryGetValue(edge.Id, out var existingEdge))
        {
            if (_outgoingEdges.TryGetValue(existingEdge.SourceId, out var oldOutgoing))
                oldOutgoing.Remove(edge.Id);
            if (_incomingEdges.TryGetValue(existingEdge.TargetId, out var oldIncoming))
                oldIncoming.Remove(edge.Id);
        }

        _edges[edge.Id] = edge;
        _outgoingEdges[edge.SourceId].Add(edge.Id);
        _incomingEdges[edge.TargetId].Add(edge.Id);
    }

    /// <inheritdoc />
    public GraphNode<T>? GetNode(string nodeId)
    {
        return _nodes.TryGetValue(nodeId, out var node) ? node : null;
    }

    /// <inheritdoc />
    public GraphEdge<T>? GetEdge(string edgeId)
    {
        return _edges.TryGetValue(edgeId, out var edge) ? edge : null;
    }

    /// <inheritdoc />
    public bool RemoveNode(string nodeId)
    {
        if (!_nodes.TryGetValue(nodeId, out var node))
            return false;

        // Remove all outgoing edges (and their entries in targets' incoming sets).
        if (_outgoingEdges.TryGetValue(nodeId, out var outgoing))
        {
            foreach (var edgeId in outgoing.ToList())
            {
                if (_edges.TryGetValue(edgeId, out var edge))
                {
                    _edges.Remove(edgeId);
                    _incomingEdges[edge.TargetId].Remove(edgeId);
                }
            }
            _outgoingEdges.Remove(nodeId);
        }

        // Remove all incoming edges (and their entries in sources' outgoing sets).
        if (_incomingEdges.TryGetValue(nodeId, out var incoming))
        {
            foreach (var edgeId in incoming.ToList())
            {
                if (_edges.TryGetValue(edgeId, out var edge))
                {
                    _edges.Remove(edgeId);
                    _outgoingEdges[edge.SourceId].Remove(edgeId);
                }
            }
            _incomingEdges.Remove(nodeId);
        }

        // Remove from the label index, pruning empty label buckets.
        if (_nodesByLabel.TryGetValue(node.Label, out var nodeIds))
        {
            nodeIds.Remove(nodeId);
            if (nodeIds.Count == 0)
                _nodesByLabel.Remove(node.Label);
        }

        _nodes.Remove(nodeId);
        return true;
    }

    /// <inheritdoc />
    public bool RemoveEdge(string edgeId)
    {
        if (!_edges.TryGetValue(edgeId, out var edge))
            return false;

        _edges.Remove(edgeId);

        // TryGetValue is defensive: adjacency sets should always exist for
        // live endpoints, but a missing set must not crash removal.
        if (_outgoingEdges.TryGetValue(edge.SourceId, out var outgoing))
            outgoing.Remove(edgeId);
        if (_incomingEdges.TryGetValue(edge.TargetId, out var incoming))
            incoming.Remove(edgeId);
        return true;
    }

    /// <inheritdoc />
    public IEnumerable<GraphEdge<T>> GetOutgoingEdges(string nodeId)
    {
        if (!_outgoingEdges.TryGetValue(nodeId, out var edgeIds))
            return Enumerable.Empty<GraphEdge<T>>();

        // TryGetValue safely skips edges that may have been removed.
        return edgeIds
            .Select(id => _edges.TryGetValue(id, out var edge) ? edge : null)
            .OfType<GraphEdge<T>>();
    }

    /// <inheritdoc />
    public IEnumerable<GraphEdge<T>> GetIncomingEdges(string nodeId)
    {
        if (!_incomingEdges.TryGetValue(nodeId, out var edgeIds))
            return Enumerable.Empty<GraphEdge<T>>();

        // TryGetValue safely skips edges that may have been removed.
        return edgeIds
            .Select(id => _edges.TryGetValue(id, out var edge) ? edge : null)
            .OfType<GraphEdge<T>>();
    }

    /// <inheritdoc />
    public IEnumerable<GraphNode<T>> GetNodesByLabel(string label)
    {
        if (!_nodesByLabel.TryGetValue(label, out var nodeIds))
            return Enumerable.Empty<GraphNode<T>>();

        // TryGetValue safely skips nodes that may have been removed.
        return nodeIds
            .Select(id => _nodes.TryGetValue(id, out var node) ? node : null)
            .OfType<GraphNode<T>>();
    }

    /// <inheritdoc />
    public IEnumerable<GraphNode<T>> GetAllNodes()
    {
        return _nodes.Values;
    }

    /// <inheritdoc />
    public IEnumerable<GraphEdge<T>> GetAllEdges()
    {
        return _edges.Values;
    }

    /// <inheritdoc />
    public void Clear()
    {
        _nodes.Clear();
        _edges.Clear();
        _outgoingEdges.Clear();
        _incomingEdges.Clear();
        _nodesByLabel.Clear();
    }

    // Async methods: for the in-memory store these simply wrap the
    // synchronous operations in completed tasks.

    /// <inheritdoc />
    public Task AddNodeAsync(GraphNode<T> node)
    {
        AddNode(node);
        return Task.CompletedTask;
    }

    /// <inheritdoc />
    public Task AddEdgeAsync(GraphEdge<T> edge)
    {
        AddEdge(edge);
        return Task.CompletedTask;
    }

    /// <inheritdoc />
    public Task<GraphNode<T>?> GetNodeAsync(string nodeId)
    {
        return Task.FromResult(GetNode(nodeId));
    }

    /// <inheritdoc />
    public Task<GraphEdge<T>?> GetEdgeAsync(string edgeId)
    {
        return Task.FromResult(GetEdge(edgeId));
    }

    /// <inheritdoc />
    public Task<bool> RemoveNodeAsync(string nodeId)
    {
        return Task.FromResult(RemoveNode(nodeId));
    }

    /// <inheritdoc />
    public Task<bool> RemoveEdgeAsync(string edgeId)
    {
        return Task.FromResult(RemoveEdge(edgeId));
    }

    /// <inheritdoc />
    public Task<IEnumerable<GraphEdge<T>>> GetOutgoingEdgesAsync(string nodeId)
    {
        return Task.FromResult(GetOutgoingEdges(nodeId));
    }

    /// <inheritdoc />
    public Task<IEnumerable<GraphEdge<T>>> GetIncomingEdgesAsync(string nodeId)
    {
        return Task.FromResult(GetIncomingEdges(nodeId));
    }

    /// <inheritdoc />
    public Task<IEnumerable<GraphNode<T>>> GetNodesByLabelAsync(string label)
    {
        return Task.FromResult(GetNodesByLabel(label));
    }

    /// <inheritdoc />
    public Task<IEnumerable<GraphNode<T>>> GetAllNodesAsync()
    {
        return Task.FromResult(GetAllNodes());
    }

    /// <inheritdoc />
    public Task<IEnumerable<GraphEdge<T>>> GetAllEdgesAsync()
    {
        return Task.FromResult(GetAllEdges());
    }

    /// <inheritdoc />
    public Task ClearAsync()
    {
        Clear();
        return Task.CompletedTask;
    }
}
diff --git a/src/RetrievalAugmentedGeneration/Graph/WriteAheadLog.cs b/src/RetrievalAugmentedGeneration/Graph/WriteAheadLog.cs
new file mode 100644
index 000000000..0d9f278cc
--- /dev/null
+++ b/src/RetrievalAugmentedGeneration/Graph/WriteAheadLog.cs
@@ -0,0 +1,396 @@
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Text;
+using Newtonsoft.Json;
+
+namespace AiDotNet.RetrievalAugmentedGeneration.Graph;
+
/// <summary>
/// Write-Ahead Log (WAL) for ensuring ACID properties and crash recovery.
/// </summary>
/// <remarks>
/// <para>
/// A Write-Ahead Log records all changes before they're applied to the main data files.
/// This ensures data integrity and enables recovery after crashes.
/// </para>
/// <para><b>For Beginners:</b> Think of WAL like a ship's log or diary.
///
/// Before making any change to your graph:
/// 1. Write what you're about to do in the log (WAL)
/// 2. Make sure the log is saved to disk
/// 3. Then make the actual change
///
/// If the system crashes:
/// - The log shows what was happening
/// - You can replay the log to restore the graph
/// - No data is lost!
///
/// This is how databases ensure "durability" - the D in ACID.
///
/// Real-world analogy:
/// - Bank transaction: First log "transfer $100", then move the money
/// - If crash happens after logging but before transfer, replay the log on restart
/// - Money isn't lost!
/// </para>
/// </remarks>
+public class WriteAheadLog : IDisposable
+{
    // Path of the append-only log file, fixed at construction.
    private readonly string _walFilePath;
    // Writer kept open for the log's lifetime with AutoFlush enabled for durability.
    private StreamWriter? _writer;
    // Monotonically increasing transaction counter; incremented under _lock.
    private long _currentTransactionId;
    // Guards _currentTransactionId and all log file writes/reads.
    private readonly object _lock = new object();
    // Disposal flag — presumably set by Dispose (outside this view); TODO confirm.
    private bool _disposed;

    /// <summary>
    /// Gets the current transaction ID.
    /// </summary>
    public long CurrentTransactionId => _currentTransactionId;
+
    /// <summary>
    /// Initializes a new instance of the <see cref="WriteAheadLog"/> class.
    /// </summary>
    /// <param name="walFilePath">Path to the WAL file.</param>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="walFilePath"/> is null.</exception>
    public WriteAheadLog(string walFilePath)
    {
        _walFilePath = walFilePath ?? throw new ArgumentNullException(nameof(walFilePath));

        // Ensure directory exists
        var directory = Path.GetDirectoryName(walFilePath);
        if (!string.IsNullOrEmpty(directory) && !Directory.Exists(directory))
            Directory.CreateDirectory(directory);

        // Restore transaction ID from existing log to prevent duplicates after restart
        _currentTransactionId = RestoreLastTransactionId();

        // Open WAL file for append
        _writer = new StreamWriter(_walFilePath, append: true, Encoding.UTF8)
        {
            AutoFlush = true // Critical: flush immediately for durability
        };
    }
+
+ ///
+ /// Restores the last transaction ID from an existing WAL file.
+ ///
+ /// The maximum transaction ID found in the log, or 0 if no log exists.
+ private long RestoreLastTransactionId()
+ {
+ if (!File.Exists(_walFilePath))
+ return 0;
+
+ long maxId = 0;
+ try
+ {
+ foreach (var line in File.ReadLines(_walFilePath))
+ {
+ try
+ {
+ var entry = JsonConvert.DeserializeObject(line);
+ if (entry != null && entry.TransactionId > maxId)
+ maxId = entry.TransactionId;
+ }
+ catch (JsonException)
+ {
+ // Skip malformed lines
+ }
+ }
+ }
+ catch (IOException)
+ {
+ // If we can't read the file, start from 0
+ return 0;
+ }
+ return maxId;
+ }
+
+ ///
+ /// Logs a node addition operation.
+ ///
+ /// The numeric type.
+ /// The node being added.
+ /// The transaction ID for this operation.
+ public long LogAddNode(GraphNode node)
+ {
+ lock (_lock)
+ {
+ var txnId = ++_currentTransactionId;
+ var entry = new WALEntry
+ {
+ TransactionId = txnId,
+ Timestamp = DateTime.UtcNow,
+ OperationType = WALOperationType.AddNode,
+ NodeId = node.Id,
+ Data = JsonConvert.SerializeObject(node)
+ };
+
+ WriteEntry(entry);
+ return txnId;
+ }
+ }
+
+ ///
+ /// Logs an edge addition operation.
+ ///
+ /// The numeric type.
+ /// The edge being added.
+ /// The transaction ID for this operation.
+ public long LogAddEdge(GraphEdge edge)
+ {
+ lock (_lock)
+ {
+ var txnId = ++_currentTransactionId;
+ var entry = new WALEntry
+ {
+ TransactionId = txnId,
+ Timestamp = DateTime.UtcNow,
+ OperationType = WALOperationType.AddEdge,
+ EdgeId = edge.Id,
+ Data = JsonConvert.SerializeObject(edge)
+ };
+
+ WriteEntry(entry);
+ return txnId;
+ }
+ }
+
    /// <summary>
    /// Logs a node removal operation.
    /// </summary>
    /// <param name="nodeId">The ID of the node being removed.</param>
    /// <returns>The transaction ID for this operation.</returns>
    public long LogRemoveNode(string nodeId)
    {
        lock (_lock)
        {
            // Allocate the ID under the lock so IDs are unique and strictly increasing.
            var txnId = ++_currentTransactionId;
            // No Data payload: the node ID alone is enough to replay a removal.
            var entry = new WALEntry
            {
                TransactionId = txnId,
                Timestamp = DateTime.UtcNow,
                OperationType = WALOperationType.RemoveNode,
                NodeId = nodeId
            };

            WriteEntry(entry);
            return txnId;
        }
    }
+
    /// <summary>
    /// Logs an edge removal operation.
    /// </summary>
    /// <param name="edgeId">The ID of the edge being removed.</param>
    /// <returns>The transaction ID for this operation.</returns>
    public long LogRemoveEdge(string edgeId)
    {
        lock (_lock)
        {
            // Allocate the ID under the lock so IDs are unique and strictly increasing.
            var txnId = ++_currentTransactionId;
            // No Data payload: the edge ID alone is enough to replay a removal.
            var entry = new WALEntry
            {
                TransactionId = txnId,
                Timestamp = DateTime.UtcNow,
                OperationType = WALOperationType.RemoveEdge,
                EdgeId = edgeId
            };

            WriteEntry(entry);
            return txnId;
        }
    }
+
    /// <summary>
    /// Logs a checkpoint (all data successfully persisted to disk).
    /// </summary>
    /// <returns>The transaction ID for this checkpoint.</returns>
    public long LogCheckpoint()
    {
        lock (_lock)
        {
            var txnId = ++_currentTransactionId;
            // Checkpoint entries carry no node/edge payload — presumably recovery can
            // skip replaying entries before the last checkpoint; confirm against the
            // recovery logic, which is outside this view.
            var entry = new WALEntry
            {
                TransactionId = txnId,
                Timestamp = DateTime.UtcNow,
                OperationType = WALOperationType.Checkpoint
            };

            WriteEntry(entry);
            return txnId;
        }
    }
+
+ ///
+ /// Reads all WAL entries from the log file.
+ ///
+ /// List of WAL entries in order.
+ public List ReadLog()
+ {
+ var entries = new List();
+
+ if (!File.Exists(_walFilePath))
+ return entries;
+
+ lock (_lock)
+ {
+ // Flush writer to ensure all entries are on disk
+ _writer?.Flush();
+
+ // Use FileShare.ReadWrite to allow reading while writer is still open (Windows compatibility)
+ using var fileStream = new FileStream(_walFilePath, FileMode.Open, FileAccess.Read, FileShare.ReadWrite);
+ using var reader = new StreamReader(fileStream, Encoding.UTF8);
+ string? line;
+ while ((line = reader.ReadLine()) != null)
+ {
+ try
+ {
+ var entry = JsonConvert.DeserializeObject(line);
+ if (entry != null)
+ entries.Add(entry);
+ }
+ catch (JsonSerializationException)
+ {
+ // Skip corrupted entries
+ }
+ }
+ }
+
+ return entries;
+ }
+
+ ///