55
66import sqlalchemy
77from llama_index .core .bridge .pydantic import BaseModel , Field
8- from llama_index .core .vector_stores .types import VectorStoreQuery
98from sqlalchemy .sql .selectable import Select
109
1110from llama_index .vector_stores .postgres .base import (
@@ -36,7 +35,17 @@ def get_bm25_data_model(
3635 from pgvector .sqlalchemy import Vector , HALFVEC
3736 from sqlalchemy import Column
3837 from sqlalchemy .dialects .postgresql import BIGINT , JSON , JSONB , VARCHAR
39- from sqlalchemy import cast , column , String , Integer , Numeric , Float , Boolean , Date , DateTime
38+ from sqlalchemy import (
39+ cast ,
40+ column ,
41+ String ,
42+ Integer ,
43+ Numeric ,
44+ Float ,
45+ Boolean ,
46+ Date ,
47+ DateTime ,
48+ )
4049 from sqlalchemy .dialects .postgresql import DOUBLE_PRECISION , UUID
4150 from sqlalchemy .schema import Index
4251
@@ -54,7 +63,7 @@ def get_bm25_data_model(
5463 }
5564
5665 indexed_metadata_keys = indexed_metadata_keys or set ()
57-
66+
5867 for key , pg_type in indexed_metadata_keys :
5968 if pg_type not in pg_type_map :
6069 raise ValueError (
@@ -67,7 +76,9 @@ def get_bm25_data_model(
6776 indexname = f"{ index_name } _idx"
6877
6978 metadata_dtype = JSONB if use_jsonb else JSON
70- embedding_col = Column (HALFVEC (embed_dim )) if use_halfvec else Column (Vector (embed_dim ))
79+ embedding_col = (
80+ Column (HALFVEC (embed_dim )) if use_halfvec else Column (Vector (embed_dim ))
81+ )
7182
7283 metadata_indices = [
7384 Index (
@@ -107,7 +118,7 @@ class BM25AbstractData(base):
107118class ParadeDBVectorStore (PGVectorStore , BaseModel ):
108119 """
109120 ParadeDB Vector Store with BM25 search support.
110-
121+
111122 Inherits from PGVectorStore and adds BM25 full-text search capabilities
112123 using ParadeDB's pg_search extension.
113124
@@ -130,16 +141,19 @@ class ParadeDBVectorStore(PGVectorStore, BaseModel):
130141 use_halfvec=True
131142 )
132143 ```
144+
133145 """
134146
135147 connection_string : Optional [Union [str , sqlalchemy .engine .URL ]] = Field (default = None )
136- async_connection_string : Optional [Union [str , sqlalchemy .engine .URL ]] = Field (default = None )
148+ async_connection_string : Optional [Union [str , sqlalchemy .engine .URL ]] = Field (
149+ default = None
150+ )
137151 table_name : Optional [str ] = Field (default = None )
138152 schema_name : Optional [str ] = Field (default = "paradedb" )
139153 hybrid_search : bool = Field (default = False )
140154 text_search_config : str = Field (default = "english" )
141155 embed_dim : int = Field (default = 1536 )
142- cache_ok : bool = Field (default = False )
156+ cache_ok : bool = Field (default = False )
143157 perform_setup : bool = Field (default = True )
144158 debug : bool = Field (default = False )
145159 use_jsonb : bool = Field (default = False )
@@ -154,7 +168,7 @@ def __init__(
154168 table_name : Optional [str ] = None ,
155169 schema_name : Optional [str ] = None ,
156170 hybrid_search : bool = False ,
157- text_search_config : str = "english" ,
171+ text_search_config : str = "english" ,
158172 embed_dim : int = 1536 ,
159173 cache_ok : bool = False ,
160174 perform_setup : bool = True ,
@@ -176,7 +190,7 @@ def __init__(
176190 self ,
177191 connection_string = connection_string ,
178192 async_connection_string = async_connection_string ,
179- table_name = table_name ,
193+ table_name = table_name ,
180194 schema_name = schema_name or "paradedb" ,
181195 hybrid_search = hybrid_search ,
182196 text_search_config = text_search_config ,
@@ -187,14 +201,16 @@ def __init__(
187201 use_jsonb = use_jsonb ,
188202 hnsw_kwargs = hnsw_kwargs ,
189203 create_engine_kwargs = create_engine_kwargs ,
190- use_bm25 = use_bm25
204+ use_bm25 = use_bm25 ,
191205 )
192-
206+
193207 # Call parent constructor
194208 PGVectorStore .__init__ (
195209 self ,
196210 connection_string = str (connection_string ) if connection_string else None ,
197- async_connection_string = str (async_connection_string ) if async_connection_string else None ,
211+ async_connection_string = str (async_connection_string )
212+ if async_connection_string
213+ else None ,
198214 table_name = table_name ,
199215 schema_name = self .schema_name ,
200216 hybrid_search = hybrid_search ,
@@ -213,10 +229,11 @@ def __init__(
213229 indexed_metadata_keys = indexed_metadata_keys ,
214230 customize_query_fn = customize_query_fn ,
215231 )
216-
232+
217233 # Override table model if using BM25
218234 if self .use_bm25 :
219235 from sqlalchemy .orm import declarative_base
236+
220237 self ._base = declarative_base ()
221238 self ._table_class = get_bm25_data_model (
222239 self ._base ,
@@ -270,6 +287,7 @@ def from_params(
270287
271288 Returns:
272289 ParadeDBVectorStore: Instance of ParadeDBVectorStore.
290+
273291 """
274292 conn_str = (
275293 connection_string
@@ -301,7 +319,7 @@ def from_params(
301319 def _create_extension (self ) -> None :
302320 """Override to add pg_search extension for BM25."""
303321 super ()._create_extension ()
304-
322+
305323 if self .use_bm25 :
306324 with self ._session () as session , session .begin ():
307325 try :
@@ -337,7 +355,7 @@ def _initialize(self) -> None:
337355 """Override to add BM25 index creation."""
338356 if not self ._is_initialized :
339357 super ()._initialize ()
340-
358+
341359 if self .use_bm25 and self .perform_setup :
342360 try :
343361 self ._create_bm25_index ()
@@ -355,10 +373,12 @@ def _build_sparse_query(
355373 ) -> Any :
356374 """Override to use BM25 if enabled, otherwise use parent's ts_vector."""
357375 if not self .use_bm25 :
358- return super ()._build_sparse_query (query_str , limit , metadata_filters , ** kwargs )
359-
376+ return super ()._build_sparse_query (
377+ query_str , limit , metadata_filters , ** kwargs
378+ )
379+
360380 from sqlalchemy import text
361-
381+
362382 if query_str is None :
363383 raise ValueError ("query_str must be specified for a sparse vector query." )
364384
@@ -373,14 +393,12 @@ def _build_sparse_query(
373393 if metadata_filters :
374394 _logger .warning ("Metadata filters not fully implemented for BM25 raw SQL" )
375395
376- stmt = text (f"""
396+ return text (f"""
377397 { base_query }
378398 ORDER BY rank DESC
379399 LIMIT :limit
380400 """ ).bindparams (query = query_str_clean , limit = limit )
381401
382- return stmt
383-
384402 def _sparse_query_with_rank (
385403 self ,
386404 query_str : Optional [str ] = None ,
@@ -390,7 +408,7 @@ def _sparse_query_with_rank(
390408 """Override to handle BM25 results properly."""
391409 if not self .use_bm25 :
392410 return super ()._sparse_query_with_rank (query_str , limit , metadata_filters )
393-
411+
394412 stmt = self ._build_sparse_query (query_str , limit , metadata_filters )
395413 with self ._session () as session , session .begin ():
396414 res = session .execute (stmt )
@@ -417,8 +435,10 @@ async def _async_sparse_query_with_rank(
417435 ) -> List [DBEmbeddingRow ]:
418436 """Override to handle async BM25 results properly."""
419437 if not self .use_bm25 :
420- return await super ()._async_sparse_query_with_rank (query_str , limit , metadata_filters )
421-
438+ return await super ()._async_sparse_query_with_rank (
439+ query_str , limit , metadata_filters
440+ )
441+
422442 stmt = self ._build_sparse_query (query_str , limit , metadata_filters )
423443 async with self ._async_session () as session , session .begin ():
424444 res = await session .execute (stmt )
@@ -435,4 +455,4 @@ async def _async_sparse_query_with_rank(
435455 similarity = item .rank ,
436456 )
437457 for item in res .all ()
438- ]
458+ ]
0 commit comments