@@ -43,10 +43,17 @@ def __init__(self, model_name, model_kwargs, encode_kwargs, is_query=False):
     def prepare_kwargs(self):
         ready = deepcopy(self.model_kwargs)

+        # 1) update model_kwargs
+        ready.setdefault("model_kwargs", {}).setdefault("trust_remote_code", True)
+
+        # 2) update tokenizer_kwargs
         tok_kw = ready.setdefault("tokenizer_kwargs", {})
+        tok_kw.setdefault("trust_remote_code", True)
         tok_kw.setdefault("padding", True)
         tok_kw.setdefault("truncation", True)
         tok_kw.setdefault("return_token_type_ids", False)
+        tok_kw.setdefault("use_fast", True)
+        tok_kw.setdefault("max_length", 512)

         return ready

@@ -55,12 +62,9 @@ def prepare_encode_kwargs(self):
             self.encode_kwargs['batch_size'] = 1
         return self.encode_kwargs

-    # VERIFIED CORRECT
     def create(self):
         prepared_kwargs = self.prepare_kwargs()
         prepared_encode_kwargs = self.prepare_encode_kwargs()
-        logger.debug("HF init kwargs=%s", prepared_kwargs)
-        logger.debug("encode_kwargs=%s", prepared_encode_kwargs)

         hf = HuggingFaceEmbeddings(
             model_name=self.model_name,
@@ -78,122 +82,144 @@ def create(self):

 class SnowflakeEmbedding(BaseEmbeddingModel):
     def prepare_kwargs(self):
-        # 1) inherit the padded/truncated tokenizer-kwargs from the base class
+        # 1) inherit all kwargs from the base class
         snow_kwargs = super().prepare_kwargs()

-        # 2) If this is a "large" Snowflake model, no extra tweaks are required
+        # 2) update tokenizer_kwargs for large model
         if "large" in self.model_name.lower():
-            logging.debug("Model name contains 'large' – returning base kwargs unchanged")
+            tok_kw = snow_kwargs.setdefault("tokenizer_kwargs", {})
+            tok_kw.update({"max_length": 8192})
+
             return snow_kwargs

-        # 3) Decide whether xFormers memory-efficient attention is available
-        compute_device = snow_kwargs.get("device", "")
-        is_cuda = compute_device.lower().startswith("cuda")
+        # 1) determine if xformers can be used
+        compute_device = snow_kwargs.get("device", "").lower()
+        is_cuda = compute_device.startswith("cuda")
         use_xformers = is_cuda and supports_flash_attention()

-        # 4) Merge or create the config-level overrides
-        extra_cfg = {
+        # 2) update tokenizer_kwargs for medium model
+        tok_kw = snow_kwargs.setdefault("tokenizer_kwargs", {})
+        tok_kw.update({"max_length": 8192})
+
+        # 3) update config_kwargs for medium model
+        snow_kwargs["config_kwargs"] = {
             "use_memory_efficient_attention": use_xformers,
             "unpad_inputs": use_xformers,
             "attn_implementation": "eager" if use_xformers else "sdpa",
         }
-        snow_kwargs["config_kwargs"] = {
-            **snow_kwargs.get("config_kwargs", {}),  # keep any user-supplied keys
-            **extra_cfg,
-        }

-        logging.debug("Final Snowflake config_kwargs: %s", snow_kwargs["config_kwargs"])
         return snow_kwargs


 class StellaEmbedding(BaseEmbeddingModel):
     def prepare_kwargs(self):
-        # Start with the padded/truncated tokenizer-kwargs that the base class
-        # already adds (return_token_type_ids=False, etc.).
+        # 1) inherit all kwargs from the base class
         stella_kwargs = super().prepare_kwargs()

-        # Ensure the HF loader is allowed to execute remote code for Stella.
-        stella_kwargs.setdefault("model_kwargs", {})
-        stella_kwargs["model_kwargs"]["trust_remote_code"] = True
+        # 2) update tokenizer_kwargs
+        tok_kw = stella_kwargs.setdefault("tokenizer_kwargs", {})
+        tok_kw.update({
+            "max_length": 512,
+        })

         return stella_kwargs

     def prepare_encode_kwargs(self):
         encode_kwargs = super().prepare_encode_kwargs()
+        # 1) add the appropriate prompt_name if a query is being embedded
         if self.is_query:
             encode_kwargs["prompt_name"] = "s2p_query"
+
         return encode_kwargs


 class Stella400MEmbedding(BaseEmbeddingModel):
     def prepare_kwargs(self):
-        # 1) start from the padded/truncated tokenizer defaults
+        # 1) inherit all kwargs from the base class
         stella_kwargs = super().prepare_kwargs()

-        # 2) detect whether we can use xFormers kernels
+        # 2) determine if xformers can be used
         compute_device = stella_kwargs.get("device", "")
-        is_cuda = compute_device.lower().startswith("cuda")
-        use_xformers = is_cuda and supports_flash_attention()
-
-        # 3) ensure the inner model_kwargs exists and allows remote code
-        stella_kwargs.setdefault("model_kwargs", {})
-        stella_kwargs["model_kwargs"]["trust_remote_code"] = True
+        is_cuda = compute_device.lower().startswith("cuda")
+        use_xformers = is_cuda and supports_flash_attention()

-        # 4) merge/update tokenizer settings *without* losing existing keys
+        # 3) update tokenizer_kwargs
         tok_kw = stella_kwargs.setdefault("tokenizer_kwargs", {})
         tok_kw.update({
-            "max_length": 8000,
-            "padding": True,
-            "truncation": True,
+            "max_length": 512,
         })

-        # 5) add config-level overrides
+        # 4) update config_kwargs
         stella_kwargs["config_kwargs"] = {
             "use_memory_efficient_attention": use_xformers,
             "unpad_inputs": use_xformers,
-            "attn_implementation": "eager",  # always eager for 400M
+            "attn_implementation": "eager",  # always "eager" even when not using xformers
         }

-        logger.debug("Stella400M kwargs → %s", stella_kwargs)
         return stella_kwargs

     def prepare_encode_kwargs(self):
         encode_kwargs = super().prepare_encode_kwargs()
+        # 1) add the appropriate prompt_name if a query is being embedded
         if self.is_query:
             encode_kwargs["prompt_name"] = "s2p_query"
+
         return encode_kwargs


 class AlibabaEmbedding(BaseEmbeddingModel):
     def prepare_kwargs(self):
-        logging.debug("Starting AlibabaEmbedding prepare_kwargs.")
-        ali_kwargs = deepcopy(self.model_kwargs)
-        logging.debug(f"Original model_kwargs: {self.model_kwargs}")
+        # 1) inherit all kwargs from the base class
+        ali_kwargs = super().prepare_kwargs()

-        compute_device = self.model_kwargs.get("device", "").lower()
-        is_cuda = compute_device == "cuda"
+        # 2) determine if xformers can be used
+        compute_device = ali_kwargs.get("device", "").lower()
+        is_cuda = compute_device.startswith("cuda")
         use_xformers = is_cuda and supports_flash_attention()
-        logging.debug(f"Device: {compute_device}")
-        logging.debug(f"is_cuda: {is_cuda}")
-        logging.debug(f"use_xformers: {use_xformers}")

-        ali_kwargs["tokenizer_kwargs"] = {
+        # 3) update tokenizer_kwargs
+        tok_kw = ali_kwargs.setdefault("tokenizer_kwargs", {})
+        tok_kw.update({
             "max_length": 8192,
-            "padding": True,
-            "truncation": True
-        }
+        })

+        # 4) update config_kwargs
         ali_kwargs["config_kwargs"] = {
             "use_memory_efficient_attention": use_xformers,
             "unpad_inputs": use_xformers,
-            "attn_implementation": "eager" if use_xformers else "sdpa"
+            "attn_implementation": "eager" if use_xformers else "sdpa",
         }
-        logging.debug(f"Set 'config_kwargs': {ali_kwargs['config_kwargs']}")

-        logging.debug(f"Final ali_kwargs: {ali_kwargs}")
         return ali_kwargs


+class BgeCodeEmbedding(BaseEmbeddingModel):
+    DEFAULT_INSTRUCTION = "Given a question in text, retrieve relevant code."
+
+    def prepare_kwargs(self):
+        # 1) inherit all kwargs from the base class
+        bge_kwargs = super().prepare_kwargs()
+
+        # 2) update tokenizer_kwargs
+        tok_kw = bge_kwargs.setdefault("tokenizer_kwargs", {})
+        tok_kw.update({
+            "max_length": 4096,
+        })
+
+        return bge_kwargs
+
+    def prepare_encode_kwargs(self):
+        encode_kwargs = super().prepare_encode_kwargs()
+
+        # 1) add the custom prompt formatting if a query is being embedded
+        if self.is_query:
+            encode_kwargs["prompt"] = (
+                f"<instruct>{self.DEFAULT_INSTRUCTION}\n<query>"
+            )
+
+        return encode_kwargs
+
+
 def create_vector_db_in_process(database_name):
     create_vector_db = CreateVectorDB(database_name=database_name)
     create_vector_db.run()
@@ -249,7 +275,9 @@ def initialize_vector_model(self, embedding_model_name, config_data):
             }
         }

-        encode_kwargs = {'normalize_embeddings': True, 'batch_size': 8}
+        # encode_kwargs = {'normalize_embeddings': True, 'batch_size': 8}
+        # encode_kwargs = {'max_length': 512, 'batch_size': 8}
+        encode_kwargs = {'batch_size': 8}

         if compute_device.lower() == 'cpu':
             encode_kwargs['batch_size'] = 2
@@ -265,7 +293,7 @@ def initialize_vector_model(self, embedding_model_name, config_data):
             'bge-small': 12,
             'gte-base': 14,
             'arctic-embed-m': 14,
-            'stella_en_400M_v5': 12,
+            'stella_en_400M_v5': 20,
         }

         for key, value in batch_size_mapping.items():
@@ -355,7 +383,7 @@ def create_database(self, texts, embeddings):
         with open(self.ROOT_DIRECTORY / "config.yaml", 'r', encoding='utf-8') as config_file:
             config_data = yaml.safe_load(config_file)

-        # ── NEW EMBEDDING FLOW: pre-compute vectors, then write DB ─────────
+        # pre-compute vectors, then write DB
         vectors = embeddings.embed_documents(all_texts)
         text_embed_pairs = [
             (txt, np.asarray(vec, dtype=np.float32))
@@ -580,7 +608,7 @@ def initialize_vector_model(self):
580608 "trust_remote_code" : True ,
581609 "model_kwargs" : {}
582610 }
583- encode_kwargs = {'normalize_embeddings' : True }
611+ # encode_kwargs = {'normalize_embeddings': True}
584612
585613 if "snowflake" in model_path .lower ():
586614 logger .debug ("Matched Snowflake condition" )
@@ -676,3 +704,100 @@ def cleanup(self):

         gc.collect()
         logging.debug(f"Cleanup completed for instance {self._debug_id}")
+
708+ """
709+ ╔══════════════════════════════════════════════════════════════════════════╗
710+ ║ DEVELOPMENT NOTES – xFormers flags, attention-impl, and tokenization ║
711+ ╚══════════════════════════════════════════════════════════════════════════╝
712+
713+ ────────────────────────────────────────────────────────────────────────────
714+ 1. Which models can use xFormers memory-efficient attention?
715+ ────────────────────────────────────────────────────────────────────────────
716+ • Snowflake-GTE family (all sizes except the “-large” variants)
717+ • Alibaba-GTE family
718+ • Stella-400 M (v5)
719+
720+ Stella-1.5 B **cannot** use xFormers kernels at the time of writing.
721+
722+ ────────────────────────────────────────────────────────────────────────────
723+ 2. Snowflake-GTE & Alibaba-GTE (shared behaviour)
724+ ────────────────────────────────────────────────────────────────────────────
725+ ✔ Flags belong in ✧config_kwargs✧ (which LangChain forwards to AutoConfig):
726+
727+ {
728+ "config_kwargs": {
729+ "use_memory_efficient_attention": <bool>, # enable xFormers
730+ "unpad_inputs": <bool>, # strip padding tokens
731+ "attn_implementation": "eager" # MUST be "eager"
732+ }
733+ }
734+
735+ Implementation rules inside the GTE source:
736+
737+ • If use_memory_efficient_attention is **True**
738+ – xFormers must be importable, otherwise an assertion fires.
739+ – attn_implementation is automatically coerced to "eager"
740+ (the code does this for you, but supplying "eager" is clearer).
741+
742+ • If use_memory_efficient_attention is **False**
743+ – You may still set unpad_inputs=True. The model will unpad/re-pad
744+ tensors using pure-PyTorch helpers (slower but functional).
745+ – attn_implementation can be "sdpa" or "eager". Either works.
746+
747+ ────────────────────────────────────────────────────────────────────────────
748+ 3. Stella-400 M (v5)
749+ ────────────────────────────────────────────────────────────────────────────
750+ ✔ Same flag block, but with stricter rules:
751+
752+ {
753+ "config_kwargs": {
754+ "use_memory_efficient_attention": <bool>, # optional
755+ "unpad_inputs": <bool>, # should match the flag above
756+ "attn_implementation": "eager" # ALWAYS "eager"
757+ }
758+ }
759+
760+ • The 400 M code path **does not implement an SDPA class** yet, so
761+ "eager" is mandatory even when xFormers is disabled.
762+
763+ • If you set use_memory_efficient_attention=True while xFormers is
764+ missing, an assertion will raise at runtime.
765+
766+ ────────────────────────────────────────────────────────────────────────────
767+ 4. Flag placement summary
768+ ────────────────────────────────────────────────────────────────────────────
769+ outer model_kwargs (passed to SentenceTransformer)
770+ │
771+ ├── tokenizer_kwargs → forwarded to AutoTokenizer ← configure padding
772+ │ & truncation here
773+ │
774+ ├── model_kwargs → forwarded to AutoModel ← runtime knobs
775+ │ (dtype, quantisation, ...)
776+ │
777+ └── config_kwargs → forwarded to AutoConfig ← put the three
778+ (BEFORE weights load) xFormers flags here
779+ • use_memory_efficient_attention
780+ • unpad_inputs
781+ • attn_implementation
782+
783+ ────────────────────────────────────────────────────────────────────────────
784+ 5. Tokenization vs. encode_kwargs (common pit-fall)
785+ ────────────────────────────────────────────────────────────────────────────
786+ • SentenceTransformer.encode() *never* forwards encode_kwargs into the
787+ tokenizer. It tokenizes first, then passes encode_kwargs into the model’s
788+ forward() call.
789+
790+ • Therefore:
791+ – padding / truncation / max_length → tokenizer_kwargs
792+ – batch_size, convert_to_numpy, etc. → encode_kwargs
793+
794+ Putting padding flags in encode_kwargs is silently ignored.
795+
796+ ────────────────────────────────────────────────────────────────────────────
797+ 6. Runtime checklist
798+ ────────────────────────────────────────────────────────────────────────────
799+ □ GPU build? → supports_flash_attention() must confirm
800+ □ xFormers installed? → import xformers.ops succeeds
801+ □ Flags consistent? → unpad_inputs should mirror use_memory_efficient_attention
802+ □ attn_implementation → "eager" for 400 M; "eager"/"sdpa" for others
803+ """