Skip to content

Commit 5088b43

Browse files
o7siCISC and Sigbjørn Skjæret authored
convert : fix TypeError when loading base model remotely in convert_lora_to_gguf (ggml-org#17385)
* fix: TypeError when loading base model remotely in convert_lora_to_gguf

* refactor: simplify base model loading using cache_dir from HuggingFace

* Update convert_lora_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* feat: add remote_hf_model_id to trigger lazy mode in LoRA converter

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
1 parent 845f200 commit 5088b43

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

convert_lora_to_gguf.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -277,10 +277,15 @@ def parse_args() -> argparse.Namespace:
     return parser.parse_args()


-def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
+def load_hparams_from_hf(hf_model_id: str) -> tuple[dict[str, Any], Path | None]:
+    from huggingface_hub import try_to_load_from_cache
+
     # normally, adapter does not come with base model config, we need to load it from AutoConfig
     config = AutoConfig.from_pretrained(hf_model_id)
-    return config.to_dict()
+    cache_dir = try_to_load_from_cache(hf_model_id, "config.json")
+    cache_dir = Path(cache_dir).parent if isinstance(cache_dir, str) else None
+
+    return config.to_dict(), cache_dir


 if __name__ == '__main__':
@@ -325,13 +330,13 @@ def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
     # load base model
     if base_model_id is not None:
         logger.info(f"Loading base model from Hugging Face: {base_model_id}")
-        hparams = load_hparams_from_hf(base_model_id)
+        hparams, dir_base_model = load_hparams_from_hf(base_model_id)
     elif dir_base_model is None:
         if "base_model_name_or_path" in lparams:
             model_id = lparams["base_model_name_or_path"]
             logger.info(f"Loading base model from Hugging Face: {model_id}")
             try:
-                hparams = load_hparams_from_hf(model_id)
+                hparams, dir_base_model = load_hparams_from_hf(model_id)
             except OSError as e:
                 logger.error(f"Failed to load base model config: {e}")
                 logger.error("Please try downloading the base model and add its path to --base")
@@ -480,6 +485,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             dir_lora_model=dir_lora,
             lora_alpha=alpha,
             hparams=hparams,
+            remote_hf_model_id=base_model_id,
         )

     logger.info("Exporting model...")

0 commit comments

Comments (0)