Commit c937e34

fix try 2
1 parent 7787d0a commit c937e34

1 file changed: +20 -9 lines changed

src/database_interactions.py

Lines changed: 20 additions & 9 deletions
@@ -1,7 +1,6 @@
 import gc
 import os
 import time
-import json
 from copy import deepcopy
 from pathlib import Path
 from typing import Optional
@@ -75,8 +74,12 @@ def prepare_kwargs(self):
             "model_max_length": 512,
         })
 
-        hf_embed_kw.pop("model_kwargs", None)
-        hf_embed_kw.pop("config_kwargs", None)
+        inner = hf_embed_kw.get("model_kwargs", {})
+        inner = {k: v for k, v in inner.items() if v is not None}
+        if inner:
+            hf_embed_kw["model_kwargs"] = inner
+        else:
+            hf_embed_kw.pop("model_kwargs", None)
 
         return hf_embed_kw
 
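The old pop calls discarded the nested model_kwargs and config_kwargs unconditionally; the replacement only strips entries whose value is None and re-attaches the dict whenever anything meaningful survives (config_kwargs is now left untouched). A minimal sketch of that pruning behavior, with an illustrative helper name not taken from the repository:

def prune_model_kwargs(kwargs: dict) -> dict:
    # Drop None-valued entries from the nested "model_kwargs" dict; remove
    # the key entirely when nothing survives, so downstream constructors
    # never see an empty or None-laden dict.
    inner = {k: v for k, v in kwargs.get("model_kwargs", {}).items() if v is not None}
    if inner:
        kwargs["model_kwargs"] = inner
    else:
        kwargs.pop("model_kwargs", None)
    return kwargs

print(prune_model_kwargs({"device": "cpu", "model_kwargs": {"torch_dtype": None}}))
# -> {'device': 'cpu'}
print(prune_model_kwargs({"model_kwargs": {"attn_implementation": "sdpa", "torch_dtype": None}}))
# -> {'model_kwargs': {'attn_implementation': 'sdpa'}}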
@@ -179,17 +182,21 @@ def prepare_kwargs(self):
         is_cuda = device.startswith("cuda")
         use_flash = is_cuda and supports_flash_attention()
 
+        inner = q_kwargs.setdefault("model_kwargs", {})
+
         if use_flash:
             try:
                 dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8 else torch.float16
             except Exception:
                 dtype = torch.float16
-            q_kwargs.update({
+            inner.update({
                 "torch_dtype": dtype,
                 "attn_implementation": "flash_attention_2",
             })
         else:
-            q_kwargs["attn_implementation"] = "sdpa"
+            inner.update({
+                "attn_implementation": "sdpa",
+            })
 
         tok_kw = q_kwargs.setdefault("tokenizer_kwargs", {})
         tok_kw.update({
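Both branches now write torch_dtype and attn_implementation into the nested dict created by setdefault rather than into the top level of q_kwargs, so the values travel on to transformers' from_pretrained instead of reaching the wrapper constructor as unexpected keywords. The dtype choice itself keys off CUDA compute capability; a hedged sketch of that selection (the helper name is illustrative):

import torch

def pick_flash_dtype() -> torch.dtype:
    # bfloat16 requires compute capability 8.0+ (Ampere or newer);
    # older CUDA GPUs fall back to float16 for flash attention.
    try:
        if torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8:
            return torch.bfloat16
    except Exception:
        pass  # stay conservative if the capability query fails
    return torch.float16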
@@ -258,8 +265,10 @@ def initialize_vector_model(self, embedding_model_name, config_data):
         outer_model_kwargs = {
             "device": compute_device,
             "trust_remote_code": True,
-            "torch_dtype": torch_dtype if torch_dtype is not None else None,
-            "attn_implementation": "sdpa",
+            "model_kwargs": {
+                "torch_dtype": torch_dtype if torch_dtype is not None else None,
+                "attn_implementation": "sdpa",
+            },
         }
 
         encode_kwargs = {
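The nesting matters because device and trust_remote_code are keywords SentenceTransformer accepts directly, while torch_dtype and attn_implementation are only understood by AutoModel.from_pretrained, which recent sentence-transformers releases reach through their own model_kwargs parameter. A hedged sketch of the intended call shape; the model name is illustrative and the actual call site is outside this diff:

import torch
from sentence_transformers import SentenceTransformer

model = SentenceTransformer(
    "BAAI/bge-small-en-v1.5",      # illustrative model name
    device="cuda",                 # consumed by SentenceTransformer itself
    trust_remote_code=True,        # consumed by SentenceTransformer itself
    model_kwargs={                 # forwarded to AutoModel.from_pretrained
        "torch_dtype": torch.float16,
        "attn_implementation": "sdpa",
    },
)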
@@ -634,8 +643,10 @@ def initialize_vector_model(self):
         outer_model_kwargs = {
             "device": compute_device,
             "trust_remote_code": True,
-            "torch_dtype": torch_dtype if torch_dtype is not None else None,
-            "attn_implementation": "sdpa",
+            "model_kwargs": {
+                "torch_dtype": torch_dtype if torch_dtype is not None else None,
+                "attn_implementation": "sdpa",
+            },
         }
 
         encode_kwargs = {
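The same restructuring lands in the second initialize_vector_model. If these kwargs feed a langchain HuggingFaceEmbeddings wrapper (an assumption; the wrapper call is not visible in this diff), the nesting composes one level further up:

from langchain_huggingface import HuggingFaceEmbeddings

emb = HuggingFaceEmbeddings(
    model_name="BAAI/bge-small-en-v1.5",  # illustrative model name
    model_kwargs={                        # unpacked into SentenceTransformer(...)
        "device": "cuda",
        "trust_remote_code": True,
        "model_kwargs": {                 # inner dict reaches from_pretrained
            "attn_implementation": "sdpa",
        },
    },
    encode_kwargs={"normalize_embeddings": True},
)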
