@@ -214,12 +214,26 @@ def prepare_encode_kwargs(self):
214214 # 1) add the custom prompt formatting if a query is being embedded
215215 if self .is_query :
216216 encode_kwargs ["prompt" ] = (
217- f"<instruct>{ self .DEFAULT_INSTRUCTION } \n <query>"
217+ f"<instruct>{ self .DEFAULT_INSTRUCTION } <query>"
218218 )
219219
220220 return encode_kwargs
221221
222222
class InflyEmbedding(BaseEmbeddingModel):
    """Embedding wrapper for infly retriever models.

    Behaves exactly like the base embedding model except that the
    tokenizer is capped at 8192 tokens via ``tokenizer_kwargs``.
    """

    def prepare_kwargs(self):
        """Return the base kwargs with ``tokenizer_kwargs["max_length"]`` set to 8192."""
        # Start from everything the base class configures.
        kwargs = super().prepare_kwargs()

        # Make sure a tokenizer_kwargs dict exists, then cap the sequence length.
        tokenizer_kwargs = kwargs.setdefault("tokenizer_kwargs", {})
        tokenizer_kwargs["max_length"] = 8192

        return kwargs
235+
236+
def create_vector_db_in_process(database_name):
    """Build the vector database named *database_name*.

    Thin entry point intended to run in a separate process: constructs a
    CreateVectorDB for the given name and runs it to completion.
    """
    CreateVectorDB(database_name=database_name).run()
@@ -270,6 +284,7 @@ def initialize_vector_model(self, embedding_model_name, config_data):
270284 model_kwargs = {
271285 "device" : compute_device ,
272286 "trust_remote_code" : True ,
287+ "similarity_fn_name" : "euclidean" , # (str, optional); "cosine" (default), "dot", "euclidean", "manhattan"
273288 "model_kwargs" : {
274289 "torch_dtype" : torch_dtype if torch_dtype is not None else None
275290 }
@@ -294,6 +309,10 @@ def initialize_vector_model(self, embedding_model_name, config_data):
294309 'gte-base' : 14 ,
295310 'arctic-embed-m' : 14 ,
296311 'stella_en_400M_v5' : 20 ,
312+ 'bge-code' : 2 ,
313+ 'infly-retriever-v1-1.5b' : 4 ,
314+ 'infly-retriever-v1-7b' : 2 ,
315+ 'stella_en_1.5b_v5' : 4 ,
297316 }
298317
299318 for key , value in batch_size_mapping .items ():
@@ -311,13 +330,19 @@ def initialize_vector_model(self, embedding_model_name, config_data):
311330 model = SnowflakeEmbedding (embedding_model_name , model_kwargs , encode_kwargs ).create ()
312331 elif "alibaba" in embedding_model_name .lower ():
313332 logger .debug ("Matched Alibaba condition" )
314- model = AlibabaEmbedding (embedding_model_name , model_kwargs , encode_kwargs ).create ()
333+ model = InflyEmbedding (embedding_model_name , model_kwargs , encode_kwargs ).create ()
315334 elif "400m" in embedding_model_name .lower ():
316335 logger .debug ("Matched Stella 400m condition" )
317336 model = Stella400MEmbedding (embedding_model_name , model_kwargs , encode_kwargs ).create ()
318- elif "1.5b " in embedding_model_name .lower ():
337+ elif "stella_en_1.5b_v5 " in embedding_model_name .lower ():
319338 logger .debug ("Matched Stella 1.5B condition" )
320339 model = StellaEmbedding (embedding_model_name , model_kwargs , encode_kwargs ).create ()
340+ elif "bge-code" in embedding_model_name .lower ():
341+ logger .debug ("Matches bge-code condition" )
342+ model = BgeCodeEmbedding (embedding_model_name , model_kwargs , encode_kwargs ).create ()
343+ elif "infly" in embedding_model_name .lower ():
344+ logger .debug ("Matches infly condition" )
345+ model = InflyEmbedding (embedding_model_name , model_kwargs , encode_kwargs ).create ()
321346 else :
322347 logger .debug ("No conditions matched - using base model" )
323348 model = BaseEmbeddingModel (embedding_model_name , model_kwargs , encode_kwargs ).create ()
@@ -359,7 +384,7 @@ def create_database(self, texts, embeddings):
359384 tiledb_id = str (random .randint (0 , MAX_UINT64 - 1 ))
360385
361386 text_str = str (doc .page_content or "" ).strip ()
362- if not text_str : # silently drop zero-length chunks
387+ if not text_str : # silently drop zero-length chunks
363388 continue
364389 all_texts .append (text_str )
365390
@@ -383,7 +408,7 @@ def create_database(self, texts, embeddings):
383408 with open (self .ROOT_DIRECTORY / "config.yaml" , 'r' , encoding = 'utf-8' ) as config_file :
384409 config_data = yaml .safe_load (config_file )
385410
386- # pre‑compute vectors, then write DB
411+ # precompute vectors, then write DB
387412 vectors = embeddings .embed_documents (all_texts )
388413 text_embed_pairs = [
389414 (txt , np .asarray (vec , dtype = np .float32 ))
@@ -470,7 +495,6 @@ def create_metadata_db(self, documents, hash_id_mappings):
470495 finally :
471496 conn .close ()
472497
473-
474498 def load_audio_documents (self , source_dir : Path = None ) -> list :
475499 if source_dir is None :
476500 source_dir = self .SOURCE_DIRECTORY
@@ -598,39 +622,50 @@ def load_configuration(self):
598622 raise
599623
600624 @torch .inference_mode ()
601- def initialize_vector_model (self ):
602- model_path = self .config ['created_databases' ][self .selected_database ]['model' ]
625+ def initialize_vector_model (self ):
626+ model_path = self .config ['created_databases' ][self .selected_database ]['model' ]
603627 self .model_name = os .path .basename (model_path )
604- compute_device = self .config ['Compute_Device' ]['database_query' ]
628+ compute_device = self .config ['Compute_Device' ]['database_query' ]
605629
630+ # ── outer kwargs passed to SentenceTransformer ──────────────
606631 model_kwargs = {
607- "device" : compute_device ,
632+ "device" : compute_device ,
608633 "trust_remote_code" : True ,
609- "model_kwargs" : {}
634+ "similarity_fn_name" : "euclidean" , # (str, optional); "cosine" (default), "dot", "euclidean", "manhattan"
635+ "model_kwargs" : {
636+ "trust_remote_code" : True ,
637+ },
638+ "tokenizer_kwargs" : {
639+ "use_fast" : True ,
640+ "trust_remote_code" : True ,
641+ },
610642 }
611- # encode_kwargs = {'normalize_embeddings': True}
612643
613- if "snowflake" in model_path .lower ():
614- logger .debug ("Matched Snowflake condition" )
644+ encode_kwargs = {"batch_size" : 1 }
645+
646+ mp_lower = model_path .lower ()
647+ if "snowflake" in mp_lower :
615648 embeddings = SnowflakeEmbedding (model_path , model_kwargs , encode_kwargs , is_query = True ).create ()
616- elif "alibaba" in model_path .lower ():
617- logger .debug ("Matched Alibaba condition" )
618- embeddings = AlibabaEmbedding (model_path , model_kwargs , encode_kwargs , is_query = True ).create ()
619- elif "400m" in model_path .lower ():
620- logger .debug ("Matched Stella 400m condition" )
649+ elif "alibaba" in mp_lower :
650+ embeddings = InflyEmbedding (model_path , model_kwargs , encode_kwargs , is_query = True ).create ()
651+ elif "400m" in mp_lower :
621652 embeddings = Stella400MEmbedding (model_path , model_kwargs , encode_kwargs , is_query = True ).create ()
622- elif "1.5b" in model_path .lower ():
623- logger .debug ("Matched Stella 1.5B condition" )
653+ elif "stella_en_1.5b_v5" in mp_lower :
624654 embeddings = StellaEmbedding (model_path , model_kwargs , encode_kwargs , is_query = True ).create ()
655+ elif "infly" in mp_lower :
656+ embeddings = InflyEmbedding (model_path , model_kwargs , encode_kwargs , is_query = True ).create ()
657+ elif "bge-code" in mp_lower :
658+ embeddings = BgeCodeEmbedding (model_path , model_kwargs , encode_kwargs , is_query = True ).create ()
625659 else :
626- if "bge" in model_path . lower () :
627- logger . debug ( "Matched BGE condition - setting prompt in encode_kwargs" )
628- encode_kwargs [ "prompt" ] = "Represent this sentence for searching relevant passages: "
629- logger . debug ( "No specific condition matched - using base model" )
660+ if "bge" in mp_lower :
661+ encode_kwargs [ " prompt" ] = (
662+ "Represent this sentence for searching relevant passages: "
663+ )
630664 embeddings = BaseEmbeddingModel (model_path , model_kwargs , encode_kwargs , is_query = True ).create ()
631665
632666 return embeddings
633667
668+
634669 def initialize_database (self ):
635670 persist_directory = Path (__file__ ).resolve ().parent / "Vector_DB" / self .selected_database
636671
0 commit comments