diff --git a/.env.example b/.env.example index 1b9013f..168bcc3 100644 --- a/.env.example +++ b/.env.example @@ -25,3 +25,7 @@ PINECONE_INDEX_NAME=... ## Mongo Atlas MONGODB_URI=... # Full connection string + +## pgvector +PGVECTOR_CONNECTION_STRING=postgresql+psycopg://username:password@localhost:5432/your_database_name +PGVECTOR_COLLECTION_NAME=langchain # Optional, defaults to 'langchain' diff --git a/README.md b/README.md index 568ac32..f27ac2d 100644 --- a/README.md +++ b/README.md @@ -168,6 +168,71 @@ PINECONE_API_KEY=your-api-key PINECONE_INDEX_NAME=your-index-name ``` +#### pgvector + +`pgvector` is an open-source PostgreSQL extension for vector similarity search. It allows you to store and query embeddings directly within your PostgreSQL database. + +##### Setup pgvector + +1. **Install PostgreSQL and pgvector Extension:** + + - Ensure you have PostgreSQL installed (version 14 or higher is recommended). + - Install the `pgvector` extension: + - For PostgreSQL 14+ on Ubuntu/Debian: + + ```bash + # Install PostgreSQL if not installed + sudo apt-get install postgresql postgresql-contrib + + # Install pgvector extension + sudo apt-get install postgresql-{postgresql-version}-pgvector + ``` + + - Alternatively, you can install from source. Follow the instructions in the [pgvector GitHub repository](https://github.com/pgvector/pgvector#installation). + +2. **Create a Database and Enable pgvector:** + + - Create a new PostgreSQL database or use an existing one. + + ```sql + -- In the psql shell + CREATE DATABASE your_database_name; + ``` + + - Connect to your database and enable the `pgvector` extension: + + ```sql + -- Connect to your database + \c your_database_name + + -- Enable the pgvector extension + CREATE EXTENSION vector; + ``` + +3. 
**Set Up Environment Variables:** + + - In your `.env` file, add the following variables: + + ``` + PGVECTOR_CONNECTION_STRING=postgresql+psycopg://username:password@localhost:5432/your_database_name + PGVECTOR_COLLECTION_NAME=langchain # You can change this to your preferred collection/table name + ``` + + - Replace `username`, `password`, `localhost`, `5432`, and `your_database_name` with your actual PostgreSQL credentials and connection details. + +4. **Configure the Retriever Provider:** + + - Update the `retriever_provider` in your configuration to use `pgvector`: + + ```yaml + retriever_provider: pgvector + ``` + +5. **Verify the Setup:** + + - Ensure your application can connect to the PostgreSQL database. + - Test indexing and retrieval to confirm that `pgvector` is functioning correctly. + ### Setup Model diff --git a/pyproject.toml b/pyproject.toml index 047745a..e07e2c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "msgspec>=0.18.6", "langchain-mongodb>=0.1.9", "langchain-cohere>=0.2.4", + "langchain-postgres>=0.0.12", ] [project.optional-dependencies] diff --git a/src/retrieval_graph/configuration.py b/src/retrieval_graph/configuration.py index 677f389..125be89 100644 --- a/src/retrieval_graph/configuration.py +++ b/src/retrieval_graph/configuration.py @@ -32,12 +32,13 @@ class IndexConfiguration: ) retriever_provider: Annotated[ - Literal["elastic", "elastic-local", "pinecone", "mongodb"], + Literal["elastic", "elastic-local", "pinecone", "mongodb", "pgvector"], {"__template_metadata__": {"kind": "retriever"}}, ] = field( default="elastic", metadata={ - "description": "The vector store provider to use for retrieval. Options are 'elastic', 'pinecone', or 'mongodb'." + "description": "The vector store provider to use for retrieval. Options are 'elastic', 'pinecone', 'mongodb', or 'pgvector'." 
@contextmanager
def make_pgvector_retriever(
    configuration: IndexConfiguration, embedding_model: Embeddings
) -> Generator[VectorStoreRetriever, None, None]:
    """Configure this agent to connect to a pgvector index.

    Connection settings are read from the environment:
      - ``PGVECTOR_CONNECTION_STRING`` (required): SQLAlchemy-style URL,
        e.g. ``postgresql+psycopg://user:pass@host:5432/dbname``.
      - ``PGVECTOR_COLLECTION_NAME`` (optional): collection/table name,
        defaults to ``"langchain"``.

    Yields:
        A VectorStoreRetriever scoped to the configured user via a
        ``user_id`` metadata filter.

    Raises:
        ValueError: if PGVECTOR_CONNECTION_STRING is not set.
    """
    import json
    from typing import Any, List, Tuple

    from langchain_core.documents import Document
    from langchain_postgres.vectorstores import PGVector as BasePGVector

    class PGVector(BasePGVector):
        """PGVector with a metadata-deserialization workaround for async mode.

        Works around langchain-postgres issue #124 ("Metadata field not
        properly deserialized when using async_mode=True with PGVector"):
        in async mode ``cmetadata`` may come back as raw byte data (an
        asyncpg Fragment) or as a JSON string instead of a dict. This
        override normalizes it to a dict before building Documents.
        """

        def _results_to_docs_and_scores(
            self, results: Any
        ) -> List[Tuple[Document, float]]:
            """Return (Document, score) pairs with metadata coerced to a dict."""
            docs: List[Tuple[Document, float]] = []
            for result in results:
                metadata = result.EmbeddingStore.cmetadata
                if not isinstance(metadata, dict):
                    if hasattr(metadata, "buf"):
                        # asyncpg Fragment types expose the raw JSONB bytes
                        # via .buf; decode and parse them.
                        metadata = json.loads(metadata.buf.decode("utf-8"))
                    elif isinstance(metadata, str):
                        # JSON text that was never parsed.
                        metadata = json.loads(metadata)
                    else:
                        # Unknown representation: fall back to empty metadata
                        # rather than failing the whole query.
                        metadata = {}
                doc = Document(
                    id=str(result.EmbeddingStore.id),
                    page_content=result.EmbeddingStore.document,
                    metadata=metadata,
                )
                score = result.distance if self.embeddings is not None else None
                docs.append((doc, score))
            return docs

    connection_string = os.environ.get("PGVECTOR_CONNECTION_STRING")
    if not connection_string:
        raise ValueError("PGVECTOR_CONNECTION_STRING environment variable is not set.")

    collection_name = os.environ.get("PGVECTOR_COLLECTION_NAME", "langchain")

    # Initialize the PGVector vector store with async_mode=True
    vstore = PGVector(
        connection=connection_string,
        collection_name=collection_name,
        embeddings=embedding_model,
        use_jsonb=True,
        pre_delete_collection=False,  # Set to True to drop existing data on startup
        async_mode=True,
    )

    # Copy the configured kwargs instead of mutating them in place: using
    # setdefault() directly on configuration.search_kwargs permanently
    # injects the user filter into the shared configuration object, leaking
    # it into later calls and other retriever providers.
    search_kwargs = {**configuration.search_kwargs}
    metadata_filter = {**search_kwargs.get("filter", {})}
    metadata_filter["user_id"] = {"$eq": configuration.user_id}
    search_kwargs["filter"] = metadata_filter
    yield vstore.as_retriever(search_kwargs=search_kwargs)
as retriever: + yield retriever + case _: raise ValueError( "Unrecognized retriever_provider in configuration. " f"Expected one of: {', '.join(Configuration.__annotations__['retriever_provider'].__args__)}\n" f"Got: {configuration.retriever_provider}" - ) + ) \ No newline at end of file