Commit 3f27182

v7.8.0
1 parent 1e04fb4 commit 3f27182

File tree: 1 file changed (+180, -55 lines changed)

src/database_interactions.py

Lines changed: 180 additions & 55 deletions
@@ -43,10 +43,17 @@ def __init__(self, model_name, model_kwargs, encode_kwargs, is_query=False):
     def prepare_kwargs(self):
         ready = deepcopy(self.model_kwargs)
 
+        # 1) update model_kwargs
+        ready.setdefault("model_kwargs", {}).setdefault("trust_remote_code", True)
+
+        # 2) update tokenizer_kwargs
         tok_kw = ready.setdefault("tokenizer_kwargs", {})
+        tok_kw.setdefault("trust_remote_code", True)
         tok_kw.setdefault("padding", True)
         tok_kw.setdefault("truncation", True)
         tok_kw.setdefault("return_token_type_ids", False)
+        tok_kw.setdefault("use_fast", True)
+        tok_kw.setdefault("max_length", 512)
 
         return ready
 
@@ -55,12 +62,9 @@ def prepare_encode_kwargs(self):
         self.encode_kwargs['batch_size'] = 1
         return self.encode_kwargs
 
-    # VERIFIED CORRECT
     def create(self):
         prepared_kwargs = self.prepare_kwargs()
         prepared_encode_kwargs = self.prepare_encode_kwargs()
-        logger.debug("HF init kwargs=%s", prepared_kwargs)
-        logger.debug("encode_kwargs=%s", prepared_encode_kwargs)
 
         hf = HuggingFaceEmbeddings(
             model_name=self.model_name,
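
For reference, a minimal standalone sketch of the kwargs layering that the revised base prepare_kwargs performs (the helper name prepare_kwargs_sketch and the sample input are illustrative, not part of this commit). Because setdefault never overwrites existing keys, caller-supplied values such as a custom max_length take precedence over the defaults added above.

    from copy import deepcopy

    def prepare_kwargs_sketch(model_kwargs):
        # Mirror the base-class logic: setdefault only fills in keys
        # the caller did not already provide.
        ready = deepcopy(model_kwargs)
        ready.setdefault("model_kwargs", {}).setdefault("trust_remote_code", True)

        tok_kw = ready.setdefault("tokenizer_kwargs", {})
        tok_kw.setdefault("trust_remote_code", True)
        tok_kw.setdefault("padding", True)
        tok_kw.setdefault("truncation", True)
        tok_kw.setdefault("return_token_type_ids", False)
        tok_kw.setdefault("use_fast", True)
        tok_kw.setdefault("max_length", 512)
        return ready

    # The caller-supplied max_length (1024) survives; missing keys get the defaults.
    print(prepare_kwargs_sketch({"device": "cuda", "tokenizer_kwargs": {"max_length": 1024}}))
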
@@ -78,122 +82,144 @@ def create(self):
 
 class SnowflakeEmbedding(BaseEmbeddingModel):
     def prepare_kwargs(self):
-        # 1) inherit the padded / truncated tokenizer-kwargs from the base class
+        # 1) inherit all kwargs from the base class
         snow_kwargs = super().prepare_kwargs()
 
-        # 2) If this is a “large” Snowflake model, no extra tweaks are required
+        # 2) update tokenizer_kwargs for large model
         if "large" in self.model_name.lower():
-            logging.debug("Model name contains 'large' – returning base kwargs unchanged")
+            tok_kw = snow_kwargs.setdefault("tokenizer_kwargs", {})
+            tok_kw.update({"max_length": 8192})
+
             return snow_kwargs
 
-        # 3) Decide whether xFormers memory-efficient attention is available
-        compute_device = snow_kwargs.get("device", "")
-        is_cuda = compute_device.lower().startswith("cuda")
+        # 1) determine if xformers can be used
+        compute_device = snow_kwargs.get("device", "").lower()
+        is_cuda = compute_device.startswith("cuda")
         use_xformers = is_cuda and supports_flash_attention()
 
-        # 4) Merge or create the config-level overrides
-        extra_cfg = {
+        # 2) update tokenizer_kwargs for medium model
+        tok_kw = snow_kwargs.setdefault("tokenizer_kwargs", {})
+        tok_kw.update({"max_length": 8192})
+
+        # 3) update config_kwargs for medium model
+        snow_kwargs["config_kwargs"] = {
             "use_memory_efficient_attention": use_xformers,
             "unpad_inputs": use_xformers,
            "attn_implementation": "eager" if use_xformers else "sdpa",
         }
-        snow_kwargs["config_kwargs"] = {
-            **snow_kwargs.get("config_kwargs", {}),  # keep any user-supplied keys
-            **extra_cfg,
-        }
 
-        logging.debug("Final Snowflake config_kwargs: %s", snow_kwargs["config_kwargs"])
         return snow_kwargs
 
 
 class StellaEmbedding(BaseEmbeddingModel):
     def prepare_kwargs(self):
-        # Start with the padded/truncated tokenizer-kwargs that the base class
-        # already adds (return_token_type_ids=False, etc.).
+        # 1) inherit all kwargs from the base class
         stella_kwargs = super().prepare_kwargs()
 
-        # Ensure the HF loader is allowed to execute remote code for Stella.
-        stella_kwargs.setdefault("model_kwargs", {})
-        stella_kwargs["model_kwargs"]["trust_remote_code"] = True
+        # 2) update tokenizer_kwargs
+        tok_kw = stella_kwargs.setdefault("tokenizer_kwargs", {})
+        tok_kw.update({
+            "max_length": 512,
+        })
 
         return stella_kwargs
 
     def prepare_encode_kwargs(self):
         encode_kwargs = super().prepare_encode_kwargs()
+        # 1) add the appropriate prompt_name if a query is being embedded
         if self.is_query:
             encode_kwargs["prompt_name"] = "s2p_query"
+
         return encode_kwargs
 
 
 class Stella400MEmbedding(BaseEmbeddingModel):
     def prepare_kwargs(self):
-        # 1) start from the padded/truncated tokenizer defaults
+        # 1) inherit all kwargs from the base class
         stella_kwargs = super().prepare_kwargs()
 
-        # 2) detect whether we can use xFormers kernels
+        # 2) determine if xformers can be used
         compute_device = stella_kwargs.get("device", "")
-        is_cuda = compute_device.lower().startswith("cuda")
-        use_xformers = is_cuda and supports_flash_attention()
-
-        # 3) ensure the inner model_kwargs exists and allows remote code
-        stella_kwargs.setdefault("model_kwargs", {})
-        stella_kwargs["model_kwargs"]["trust_remote_code"] = True
+        is_cuda = compute_device.lower().startswith("cuda")
+        use_xformers = is_cuda and supports_flash_attention()
 
-        # 4) merge/update tokenizer settings *without* losing existing keys
+        # 3) update tokenizer_kwargs
         tok_kw = stella_kwargs.setdefault("tokenizer_kwargs", {})
         tok_kw.update({
-            "max_length": 8000,
-            "padding": True,
-            "truncation": True,
+            "max_length": 512,
         })
 
-        # 5) add config-level overrides
+        # 4) update config_kwargs
         stella_kwargs["config_kwargs"] = {
             "use_memory_efficient_attention": use_xformers,
             "unpad_inputs": use_xformers,
-            "attn_implementation": "eager",  # always eager for 400 M
+            "attn_implementation": "eager",  # always "eager" even when not using xformers
         }
 
-        logger.debug("Stella400M kwargs → %s", stella_kwargs)
         return stella_kwargs
 
     def prepare_encode_kwargs(self):
         encode_kwargs = super().prepare_encode_kwargs()
+        # 1) add the appropriate prompt_name if a query is being embedded
         if self.is_query:
             encode_kwargs["prompt_name"] = "s2p_query"
+
         return encode_kwargs
 
 
 class AlibabaEmbedding(BaseEmbeddingModel):
     def prepare_kwargs(self):
-        logging.debug("Starting AlibabaEmbedding prepare_kwargs.")
-        ali_kwargs = deepcopy(self.model_kwargs)
-        logging.debug(f"Original model_kwargs: {self.model_kwargs}")
+        # 1) inherit all kwargs from the base class
+        ali_kwargs = super().prepare_kwargs()
 
-        compute_device = self.model_kwargs.get("device", "").lower()
-        is_cuda = compute_device == "cuda"
+        # 2) determine if xformers can be used
+        compute_device = ali_kwargs.get("device", "").lower()
+        is_cuda = compute_device.startswith("cuda")
         use_xformers = is_cuda and supports_flash_attention()
-        logging.debug(f"Device: {compute_device}")
-        logging.debug(f"is_cuda: {is_cuda}")
-        logging.debug(f"use_xformers: {use_xformers}")
 
-        ali_kwargs["tokenizer_kwargs"] = {
+        # 3) update tokenizer_kwargs
+        tok_kw = ali_kwargs.setdefault("tokenizer_kwargs", {})
+        tok_kw.update({
             "max_length": 8192,
-            "padding": True,
-            "truncation": True
-        }
+        })
 
+        # 4) update config_kwargs
         ali_kwargs["config_kwargs"] = {
             "use_memory_efficient_attention": use_xformers,
             "unpad_inputs": use_xformers,
-            "attn_implementation": "eager" if use_xformers else "sdpa"
+            "attn_implementation": "eager" if use_xformers else "sdpa",
         }
-        logging.debug(f"Set 'config_kwargs': {ali_kwargs['config_kwargs']}")
 
-        logging.debug(f"Final ali_kwargs: {ali_kwargs}")
         return ali_kwargs
 
 
+class BgeCodeEmbedding(BaseEmbeddingModel):
+    DEFAULT_INSTRUCTION = ("Given a question in text, retrieve relevant code that is relevant.")
+
+    def prepare_kwargs(self):
+        # 1) inherit all kwargs from the base class
+        bge_kwargs = super().prepare_kwargs()
+
+        # 2) update tokenizer_kwargs
+        tok_kw = bge_kwargs.setdefault("tokenizer_kwargs", {})
+        tok_kw.update({
+            "max_length": 4096,
+        })
+
+        return bge_kwargs
+
+    def prepare_encode_kwargs(self):
+        encode_kwargs = super().prepare_encode_kwargs()
+
+        # 1) add the custom prompt formatting if a query is being embedded
+        if self.is_query:
+            encode_kwargs["prompt"] = (
+                f"<instruct>{self.DEFAULT_INSTRUCTION}\n<query>"
+            )
+
+        return encode_kwargs
+
+
 def create_vector_db_in_process(database_name):
     create_vector_db = CreateVectorDB(database_name=database_name)
     create_vector_db.run()
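
A hedged usage sketch of the embedding classes refactored above; the import path, model id, and device are assumptions, and the constructor signature follows the base class shown in the first hunk.

    # Illustrative only: assumes src/ is on PYTHONPATH and a CUDA device is present.
    from database_interactions import SnowflakeEmbedding

    embedder = SnowflakeEmbedding(
        model_name="Snowflake/snowflake-arctic-embed-m-v1.5",   # example model id
        model_kwargs={"device": "cuda", "model_kwargs": {}},
        encode_kwargs={"batch_size": 8},
        is_query=False,
    )
    hf = embedder.create()   # create() wires the prepared kwargs into HuggingFaceEmbeddings

The same pattern applies to StellaEmbedding, Stella400MEmbedding, AlibabaEmbedding, and the new BgeCodeEmbedding; they differ only in the tokenizer_kwargs and config_kwargs they layer on top of the base defaults.
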
@@ -249,7 +275,9 @@ def initialize_vector_model(self, embedding_model_name, config_data):
             }
         }
 
-        encode_kwargs = {'normalize_embeddings': True, 'batch_size': 8}
+        # encode_kwargs = {'normalize_embeddings': True, 'batch_size': 8}
+        # encode_kwargs = {'max_length': 512, 'batch_size': 8}
+        encode_kwargs = {'batch_size': 8}
 
         if compute_device.lower() == 'cpu':
             encode_kwargs['batch_size'] = 2
@@ -265,7 +293,7 @@ def initialize_vector_model(self, embedding_model_name, config_data):
             'bge-small': 12,
             'gte-base': 14,
             'arctic-embed-m': 14,
-            'stella_en_400M_v5': 12,
+            'stella_en_400M_v5': 20,
         }
 
         for key, value in batch_size_mapping.items():
@@ -355,7 +383,7 @@ def create_database(self, texts, embeddings):
         with open(self.ROOT_DIRECTORY / "config.yaml", 'r', encoding='utf-8') as config_file:
             config_data = yaml.safe_load(config_file)
 
-        # ── NEW EMBEDDING FLOW: pre‑compute vectors, then write DB ─────────
+        # pre‑compute vectors, then write DB
         vectors = embeddings.embed_documents(all_texts)
         text_embed_pairs = [
             (txt, np.asarray(vec, dtype=np.float32))
@@ -580,7 +608,7 @@ def initialize_vector_model(self):
             "trust_remote_code": True,
             "model_kwargs": {}
         }
-        encode_kwargs = {'normalize_embeddings': True}
+        # encode_kwargs = {'normalize_embeddings': True}
 
         if "snowflake" in model_path.lower():
             logger.debug("Matched Snowflake condition")
@@ -676,3 +704,100 @@ def cleanup(self):
 
         gc.collect()
         logging.debug(f"Cleanup completed for instance {self._debug_id}")
+
+"""
+╔══════════════════════════════════════════════════════════════════════════╗
+║  DEVELOPMENT NOTES – xFormers flags, attention-impl, and tokenization     ║
+╚══════════════════════════════════════════════════════════════════════════╝
+
+────────────────────────────────────────────────────────────────────────────
+1.  Which models can use xFormers memory-efficient attention?
+────────────────────────────────────────────────────────────────────────────
+• Snowflake-GTE family (all sizes except the “-large” variants)
+• Alibaba-GTE family
+• Stella-400M (v5)
+
+  Stella-1.5B **cannot** use xFormers kernels at the time of writing.
+
+────────────────────────────────────────────────────────────────────────────
+2.  Snowflake-GTE & Alibaba-GTE (shared behaviour)
+────────────────────────────────────────────────────────────────────────────
+✔ Flags belong in ✧config_kwargs✧ (which LangChain forwards to AutoConfig):
+
+      {
+          "config_kwargs": {
+              "use_memory_efficient_attention": <bool>,  # enable xFormers
+              "unpad_inputs": <bool>,                    # strip padding tokens
+              "attn_implementation": "eager"             # MUST be "eager"
+          }
+      }
+
+  Implementation rules inside the GTE source:
+
+  • If use_memory_efficient_attention is **True**
+      – xFormers must be importable, otherwise an assertion fires.
+      – attn_implementation is automatically coerced to "eager"
+        (the code does this for you, but supplying "eager" is clearer).
+
+  • If use_memory_efficient_attention is **False**
+      – You may still set unpad_inputs=True. The model will unpad/re-pad
+        tensors using pure-PyTorch helpers (slower but functional).
+      – attn_implementation can be "sdpa" or "eager". Either works.
+
+────────────────────────────────────────────────────────────────────────────
+3.  Stella-400M (v5)
+────────────────────────────────────────────────────────────────────────────
+✔ Same flag block, but with stricter rules:
+
+      {
+          "config_kwargs": {
+              "use_memory_efficient_attention": <bool>,  # optional
+              "unpad_inputs": <bool>,                    # should match the flag above
+              "attn_implementation": "eager"             # ALWAYS "eager"
+          }
+      }
+
+  • The 400M code path **does not implement an SDPA class** yet, so
+    "eager" is mandatory even when xFormers is disabled.
+
+  • If you set use_memory_efficient_attention=True while xFormers is
+    missing, an assertion will raise at runtime.
+
+────────────────────────────────────────────────────────────────────────────
+4.  Flag placement summary
+────────────────────────────────────────────────────────────────────────────
+outer model_kwargs (passed to SentenceTransformer)
+
+├── tokenizer_kwargs → forwarded to AutoTokenizer  ← configure padding
+│                                                    & truncation here
+
+├── model_kwargs     → forwarded to AutoModel      ← runtime knobs
+│                                                    (dtype, quantisation, ...)
+
+└── config_kwargs    → forwarded to AutoConfig     ← put the three
+    (BEFORE weights load)                            xFormers flags here
+                                                     • use_memory_efficient_attention
+                                                     • unpad_inputs
+                                                     • attn_implementation
+
+────────────────────────────────────────────────────────────────────────────
+5.  Tokenization vs. encode_kwargs (common pitfall)
+────────────────────────────────────────────────────────────────────────────
+• SentenceTransformer.encode() *never* forwards encode_kwargs into the
+  tokenizer. It tokenizes first, then passes encode_kwargs into the model’s
+  forward() call.
+
+• Therefore:
+    – padding / truncation / max_length  → tokenizer_kwargs
+    – batch_size, convert_to_numpy, etc. → encode_kwargs
+
+  Padding flags placed in encode_kwargs are silently ignored.
+
+────────────────────────────────────────────────────────────────────────────
+6.  Runtime checklist
+────────────────────────────────────────────────────────────────────────────
+□ GPU build?            → supports_flash_attention() must confirm
+□ xFormers installed?   → import xformers.ops succeeds
+□ Flags consistent?     → unpad_inputs should mirror use_memory_efficient_attention
+□ attn_implementation   → "eager" for 400M; "eager"/"sdpa" for others
+"""
