@@ -1,7 +1,6 @@
 import gc
 import os
 import time
-import json
 from copy import deepcopy
 from pathlib import Path
 from typing import Optional
@@ -75,8 +74,12 @@ def prepare_kwargs(self):
             "model_max_length": 512,
         })

-        hf_embed_kw.pop("model_kwargs", None)
-        hf_embed_kw.pop("config_kwargs", None)
+        inner = hf_embed_kw.get("model_kwargs", {})
+        inner = {k: v for k, v in inner.items() if v is not None}
+        if inner:
+            hf_embed_kw["model_kwargs"] = inner
+        else:
+            hf_embed_kw.pop("model_kwargs", None)

         return hf_embed_kw

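For context, a minimal sketch of what the new filtering accomplishes; the dict contents below are illustrative, not taken from the codebase. Entries whose value is None are stripped from the nested model_kwargs, and the key is dropped entirely when nothing survives, so a bare {"torch_dtype": None} is never forwarded to the model loader.

# Illustrative sketch (values are made up): None entries are filtered out of the
# nested model_kwargs, and the key disappears when the dict would be empty.
hf_embed_kw = {
    "device": "cpu",
    "trust_remote_code": True,
    "model_kwargs": {"torch_dtype": None, "attn_implementation": "sdpa"},
}

inner = {k: v for k, v in hf_embed_kw.get("model_kwargs", {}).items() if v is not None}
if inner:
    hf_embed_kw["model_kwargs"] = inner      # keep only the non-None entries
else:
    hf_embed_kw.pop("model_kwargs", None)    # nothing left: omit the key entirely

print(hf_embed_kw)
# {'device': 'cpu', 'trust_remote_code': True, 'model_kwargs': {'attn_implementation': 'sdpa'}}
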
@@ -179,17 +182,21 @@ def prepare_kwargs(self):
         is_cuda = device.startswith("cuda")
         use_flash = is_cuda and supports_flash_attention()

+        inner = q_kwargs.setdefault("model_kwargs", {})
+
         if use_flash:
             try:
                 dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8 else torch.float16
             except Exception:
                 dtype = torch.float16
-            q_kwargs.update({
+            inner.update({
                 "torch_dtype": dtype,
                 "attn_implementation": "flash_attention_2",
             })
         else:
-            q_kwargs["attn_implementation"] = "sdpa"
+            inner.update({
+                "attn_implementation": "sdpa",
+            })

         tok_kw = q_kwargs.setdefault("tokenizer_kwargs", {})
         tok_kw.update({
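As a rough sketch of the decision this hunk encodes (supports_flash_attention is the project's own helper, assumed here to report whether FlashAttention-2 is usable on the current CUDA device): Ampere-or-newer GPUs get bfloat16 plus flash_attention_2, older CUDA devices fall back to float16 plus flash_attention_2, and everything else requests PyTorch's SDPA implementation.

import torch

# Hedged sketch of the nested model_kwargs produced for different devices; the
# helper name and the exact return shape are assumptions based on this diff.
def attention_kwargs(device: str, flash_available: bool) -> dict:
    if device.startswith("cuda") and flash_available:
        major = torch.cuda.get_device_capability(0)[0] if torch.cuda.is_available() else 0
        dtype = torch.bfloat16 if major >= 8 else torch.float16  # bf16 needs compute capability >= 8 (Ampere+)
        return {"torch_dtype": dtype, "attn_implementation": "flash_attention_2"}
    return {"attn_implementation": "sdpa"}  # scaled_dot_product_attention fallback

# attention_kwargs("cuda:0", True) -> {'torch_dtype': torch.bfloat16, 'attn_implementation': 'flash_attention_2'} on Ampere+
# attention_kwargs("cpu", False)   -> {'attn_implementation': 'sdpa'}
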
@@ -258,8 +265,10 @@ def initialize_vector_model(self, embedding_model_name, config_data):
         outer_model_kwargs = {
             "device": compute_device,
             "trust_remote_code": True,
-            "torch_dtype": torch_dtype if torch_dtype is not None else None,
-            "attn_implementation": "sdpa",
+            "model_kwargs": {
+                "torch_dtype": torch_dtype if torch_dtype is not None else None,
+                "attn_implementation": "sdpa",
+            },
         }

         encode_kwargs = {
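For orientation, a hedged sketch of how kwargs shaped like this are typically consumed. That the target is langchain_huggingface's HuggingFaceEmbeddings (backed by sentence_transformers.SentenceTransformer) is an assumption from the kwarg names, not something the diff states: the outer dict holds SentenceTransformer-level arguments, while the nested model_kwargs is forwarded to transformers' from_pretrained, which is why torch_dtype and attn_implementation now live one level down.

# Hedged sketch, assuming the kwargs feed langchain_huggingface's HuggingFaceEmbeddings;
# the model name and values below are placeholders, not taken from the codebase.
from langchain_huggingface import HuggingFaceEmbeddings

outer_model_kwargs = {
    "device": "cuda",
    "trust_remote_code": True,
    # forwarded by SentenceTransformer to transformers' from_pretrained(...)
    "model_kwargs": {"attn_implementation": "sdpa"},
}
encode_kwargs = {"batch_size": 8, "normalize_embeddings": True}

embedder = HuggingFaceEmbeddings(
    model_name="BAAI/bge-small-en-v1.5",   # placeholder model
    model_kwargs=outer_model_kwargs,       # SentenceTransformer-level arguments
    encode_kwargs=encode_kwargs,           # arguments passed to .encode()
)
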
@@ -634,8 +643,10 @@ def initialize_vector_model(self):
         outer_model_kwargs = {
             "device": compute_device,
             "trust_remote_code": True,
-            "torch_dtype": torch_dtype if torch_dtype is not None else None,
-            "attn_implementation": "sdpa",
+            "model_kwargs": {
+                "torch_dtype": torch_dtype if torch_dtype is not None else None,
+                "attn_implementation": "sdpa",
+            },
         }

         encode_kwargs = {