diff --git a/examples/api/model_wrapper.py b/examples/api/model_wrapper.py index 73b31283..ffc37199 100644 --- a/examples/api/model_wrapper.py +++ b/examples/api/model_wrapper.py @@ -114,7 +114,11 @@ def query_model(self, model: str, prompt: str, **kwargs) -> Dict[str, Any]: elif model == "ollama": response = self.query_ollama(prompt, kwargs.get('ollama_model', 'mistral')) elif model == "huggingface": - response = self.query_huggingface(prompt, kwargs.get('hf_model_id')) + hf_model_id = kwargs.get('hf_model_id') + if hf_model_id: + response = self.query_huggingface(prompt, hf_model_id) + else: + response = self.query_huggingface(prompt) else: raise ValueError(f"Unsupported model: {model}")