diff --git a/docs/features/structured_outputs.md b/docs/features/structured_outputs.md
index 7c81bf3e188..f58a6b1393f 100644
--- a/docs/features/structured_outputs.md
+++ b/docs/features/structured_outputs.md
@@ -7,6 +7,7 @@
 Structured Outputs refer to predefined format constraints that force large language models to generate content strictly following specified structures. This feature significantly improves output controllability and is suitable for scenarios requiring precise format outputs (such as API calls, data parsing, code generation, etc.), while supporting dynamic grammar extensions to balance flexibility and standardization.
 
 FastDeploy supports using the [XGrammar](https://xgrammar.mlc.ai/docs/) backend to generate structured outputs.
+FastDeploy also supports using the [llguidance](https://github.com/guidance-ai/llguidance) backend to generate structured outputs.
 
 Supported output formats:
 
diff --git a/docs/parameters.md b/docs/parameters.md
index c26e471625c..3a104f65482 100644
--- a/docs/parameters.md
+++ b/docs/parameters.md
@@ -44,7 +44,7 @@ When using FastDeploy to deploy models (including offline inference and service
 | ```disable_sequence_parallel_moe``` | `bool` | Disable sequence parallel moe, default: False |
 | ```splitwise_role``` | `str` | Whether to enable splitwise inference, default value: mixed, supported parameters: ["mixed", "decode", "prefill"] |
 | ```innode_prefill_ports``` | `str` | Internal engine startup ports for prefill instances (only required for single-machine PD separation), default: None |
-| ```guided_decoding_backend``` | `str` | Specify the guided decoding backend to use, supports `auto`, `xgrammar`, `off`, default: `off` |
+| ```guided_decoding_backend``` | `str` | Specify the guided decoding backend to use, supports `auto`, `xgrammar`, `guidance`, `off`, default: `off` |
 | ```guided_decoding_disable_any_whitespace``` | `bool` | Whether to disable whitespace generation during guided decoding, default: False |
 | ```speculative_config``` | `dict[str]` | Speculative decoding configuration, only supports standard format JSON string, default: None |
 | ```dynamic_load_weight``` | `int` | Whether to enable dynamic weight loading, default: 0 |
diff --git a/docs/zh/features/structured_outputs.md b/docs/zh/features/structured_outputs.md
index 50c010c1498..f7503c3a6e0 100644
--- a/docs/zh/features/structured_outputs.md
+++ b/docs/zh/features/structured_outputs.md
@@ -7,6 +7,7 @@
 Structured Outputs 是指通过预定义格式约束,使大模型生成内容严格遵循指定结构。该功能可显著提升生成结果的可控性,适用于需要精确格式输出的场景(如API调用、数据解析、代码生成等),同时支持动态语法扩展,平衡灵活性与规范性。
 
 FastDeploy 支持使用 [XGrammar](https://xgrammar.mlc.ai/docs/) 后端生成结构化输出。
+FastDeploy 也支持使用 [llguidance](https://github.com/guidance-ai/llguidance) 后端生成结构化输出。
 
 支持输出格式
 
diff --git a/docs/zh/parameters.md b/docs/zh/parameters.md
index 841441ab2d5..6d745d668c2 100644
--- a/docs/zh/parameters.md
+++ b/docs/zh/parameters.md
@@ -42,7 +42,7 @@
 | ```disable_sequence_parallel_moe``` | `bool` | 禁止在TP+EP中使用序列并行优化, default: False |
 | ```splitwise_role``` | `str` | 是否开启splitwise推理,默认值mixed, 支持参数为["mixed", "decode", "prefill"] |
 | ```innode_prefill_ports``` | `str` | prefill 实例内部引擎启动端口 (仅单机PD分离需要),默认值None |
-| ```guided_decoding_backend``` | `str` | 指定要使用的guided decoding后端,支持 `auto`、`xgrammar`、`off`, 默认为 `off` |
+| ```guided_decoding_backend``` | `str` | 指定要使用的guided decoding后端,支持 `auto`、`xgrammar`、`guidance`、`off`, 默认为 `off` |
 | ```guided_decoding_disable_any_whitespace``` | `bool` | guided decoding期间是否禁止生成空格,默认False |
 | ```speculative_config``` | `dict[str]` | 投机解码配置,仅支持标准格式json字符串,默认为None |
 | ```dynamic_load_weight``` | `int` | 是否动态加载权重,默认0 |
diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 4cf05ea989f..c07ca0926e2 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -1646,13 +1646,25 @@ def postprocess(self):
 
         if (
             self.structured_outputs_config is not None
-            and self.structured_outputs_config.guided_decoding_backend == "auto"
+            and self.structured_outputs_config.guided_decoding_backend != "off"
         ):
             if current_platform.is_xpu() or self.speculative_config.method is not None:
                 logger.warning("Speculative Decoding and XPU currently do not support Guided decoding, set off.")
                 self.structured_outputs_config.guided_decoding_backend = "off"
-            else:
+            elif self.structured_outputs_config.guided_decoding_backend in ["auto", "xgrammar"]:
                 self.structured_outputs_config.guided_decoding_backend = "xgrammar"
+            elif self.structured_outputs_config.guided_decoding_backend == "guidance":
+                try:
+                    import llguidance.torch  # noqa: F401 -- availability check only
+                except ImportError:
+                    raise ImportError(
+                        "The 'llguidance' package is required for using guidance as the guided decoding backend. "
+                        "Please install it via `pip install llguidance` or `pip install -r requirements_guided_decoding.txt`."
+                    )
+            else:
+                raise NotImplementedError(
+                    f"Guided decoding backend '{self.structured_outputs_config.guided_decoding_backend}' is not implemented. Supported backends: [auto, xgrammar, guidance, off]."
+                )
 
         if self.model_config.enable_mm:
             if self.cache_config.max_encoder_cache is None or self.cache_config.max_encoder_cache < 0:
@@ -1772,7 +1786,8 @@ def check(self):
                 "XGrammar",
                 "auto",
                 "off",
-            ], f"Only support xgrammar、auto guided decoding backend, but got {self.structured_outputs_config.guided_decoding_backend}."
+                "guidance",
+            ], f"Only [auto, xgrammar, guidance, off] guided decoding backends are supported, but got {self.structured_outputs_config.guided_decoding_backend}."
 
         if self.structured_outputs_config.guided_decoding_backend != "off":
             # TODO: speculative decoding support guided_decoding
diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py
index 6a161d14038..f0127785ad1 100644
--- a/fastdeploy/envs.py
+++ b/fastdeploy/envs.py
@@ -148,6 +148,8 @@
     "FD_ENGINE_TASK_QUEUE_WITH_SHM": lambda: int(os.getenv("FD_ENGINE_TASK_QUEUE_WITH_SHM", "0")),
     "FD_FILL_BITMASK_BATCH": lambda: int(os.getenv("FD_FILL_BITMASK_BATCH", "4")),
     "FD_ENABLE_PDL": lambda: int(os.getenv("FD_ENABLE_PDL", "1")),
+    "FD_GUIDANCE_DISABLE_ADDITIONAL": lambda: bool(int(os.getenv("FD_GUIDANCE_DISABLE_ADDITIONAL", "1"))),
+    "FD_LLGUIDANCE_LOG_LEVEL": lambda: int(os.getenv("FD_LLGUIDANCE_LOG_LEVEL", "0")),
 }
diff --git a/fastdeploy/lazy_loader.py b/fastdeploy/lazy_loader.py
new file mode 100644
index 00000000000..c3bc3ec43aa
--- /dev/null
+++ b/fastdeploy/lazy_loader.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
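+# NOTE: guidance_backend.py uses this loader to defer importing `torch` and
+# `llguidance` until the guidance backend is actually exercised.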
+"""A LazyLoader class.""" + +import importlib +import sys +import types +from typing import Any + + +class LazyLoader(types.ModuleType): + """ + LazyLoader module borrowed from Tensorflow + https://github.com/tensorflow/tensorflow/blob/main/tensorflow/python/util/lazy_loader.py + with an addition of "module caching". + + Lazily import a module, mainly to avoid pulling in large dependencies. + Modules such as `xgrammar` might do additional side effects, so we + only want to use this when it is needed, delaying all eager effects + """ + + def __init__( + self, + local_name: str, + parent_module_globals: dict[str, Any], + name: str, + ): + self._local_name = local_name + self._parent_module_globals = parent_module_globals + self._module: types.ModuleType | None = None + + super().__init__(str(name)) + + def _load(self) -> types.ModuleType: + # Import the target module and insert it into the parent's namespace + try: + module = importlib.import_module(self.__name__) + self._parent_module_globals[self._local_name] = module + # The additional add to sys.modules + # ensures library is actually loaded. + sys.modules[self._local_name] = module + except ModuleNotFoundError as err: + raise err from None + + # Update this object's dict so that if someone keeps a + # reference to the LazyLoader, lookups are efficient + # (__getattr__ is only called on lookups that fail). + self.__dict__.update(module.__dict__) + return module + + def __getattr__(self, item: Any) -> Any: + if self._module is None: + self._module = self._load() + return getattr(self._module, item) + + def __dir__(self) -> list[str]: + if self._module is None: + self._module = self._load() + return dir(self._module) diff --git a/fastdeploy/model_executor/guided_decoding/__init__.py b/fastdeploy/model_executor/guided_decoding/__init__.py index dbfc70215d1..989ce0fd63f 100644 --- a/fastdeploy/model_executor/guided_decoding/__init__.py +++ b/fastdeploy/model_executor/guided_decoding/__init__.py @@ -50,6 +50,15 @@ def get_guided_backend( fd_config=fd_config, **kwargs, ) + elif fd_config.structured_outputs_config.guided_decoding_backend.lower() == "guidance": + from fastdeploy.model_executor.guided_decoding.guidance_backend import ( + LLGuidanceBackend, + ) + + return LLGuidanceBackend( + fd_config=fd_config, + **kwargs, + ) else: raise ValueError( f"Get unsupported backend {fd_config.structured_outputs_config.guided_decoding_backend}," @@ -77,5 +86,11 @@ def schema_checker(backend_name: str, **kwargs): ) return XGrammarChecker(**kwargs) + elif backend_name.lower() == "guidance": + from fastdeploy.model_executor.guided_decoding.guidance_backend import ( + LLGuidanceChecker, + ) + + return LLGuidanceChecker(**kwargs) else: raise ValueError(f"Get unsupported backend {backend_name}, please check your configuration.") diff --git a/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py b/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py index c4e235afc36..717cafdcdf8 100644 --- a/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py +++ b/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py @@ -294,7 +294,12 @@ def _get_tokenizer_hf(self): """ try: architectures = self.fd_config.model_config.architectures - if not ErnieArchitectures.contains_ernie_arch(architectures): + is_guidance_backend = ( + self.fd_config.structured_outputs_config is not None + and self.fd_config.structured_outputs_config.guided_decoding_backend is not None + and 
self.fd_config.structured_outputs_config.guided_decoding_backend == "guidance" + ) + if not ErnieArchitectures.contains_ernie_arch(architectures) or is_guidance_backend: from transformers import AutoTokenizer, PreTrainedTokenizerFast tokenizer = AutoTokenizer.from_pretrained( diff --git a/fastdeploy/model_executor/guided_decoding/guidance_backend.py b/fastdeploy/model_executor/guided_decoding/guidance_backend.py new file mode 100644 index 00000000000..613256e81cc --- /dev/null +++ b/fastdeploy/model_executor/guided_decoding/guidance_backend.py @@ -0,0 +1,315 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import copy +import json +import traceback +from typing import Any, Optional, Tuple, Union + +from fastdeploy.config import FDConfig +from fastdeploy.engine.request import Request +from fastdeploy.envs import FD_GUIDANCE_DISABLE_ADDITIONAL, FD_LLGUIDANCE_LOG_LEVEL +from fastdeploy.lazy_loader import LazyLoader +from fastdeploy.model_executor.guided_decoding import ( + BackendBase, + BaseChecker, + LogitsProcessorBase, +) +from fastdeploy.utils import llm_logger + +torch = LazyLoader("torch", globals(), "torch") +llguidance = LazyLoader("llguidance", globals(), "llguidance") +llguidance_hf = LazyLoader("llguidance.hf", globals(), "llguidance.hf") +llguidance_torch = LazyLoader("llguidance.torch", globals(), "llguidance.torch") + + +class LLGuidanceProcessor(LogitsProcessorBase): + """ + LLGuidance-specific implementation of LogitsProcessorBase. + + This processor enforces grammar constraints during token generation using llguidance. + It manages the grammar matching state and applies token masks to logits. + """ + + def __init__( + self, + ll_matcher: llguidance.LLMatcher, + ll_tokenizer: llguidance.LLTokenizer, + serialized_grammar: str, + vocab_size: int, + batch_size: int, + enable_thinking: bool = False, + ): + super().__init__(enable_reasoning=enable_thinking) + self.matcher = ll_matcher + self.ll_tokenizer = ll_tokenizer + self.serialized_grammar = serialized_grammar + self.vocab_size = vocab_size + self.batch_size = batch_size + self.is_terminated: bool = False + self._printed_error: bool = False + + def _check_error(self): + """Checks for and logs any errors from the LLMatcher.""" + if not self._printed_error: + err = self.matcher.get_error() + if err: + self._printed_error = True + llm_logger.warning(f"LLGuidance Matcher error: {err}") + + def allocate_token_bitmask(self) -> torch.Tensor: + """ + Allocate a token bitmask tensor for grammar constraints. + """ + return llguidance_torch.allocate_token_bitmask(self.batch_size, self.vocab_size) + + def fill_token_bitmask(self, token_bitmask: torch.Tensor, idx: int) -> None: + """ + Fill the token bitmask with allowed tokens for the given index. + This will automatically provide an EOS mask if the matcher is stopped. 
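+        The mask is written in place into row ``idx`` of ``token_bitmask``;
+        any matcher error is logged once via ``_check_error``.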
+        """
+        llguidance_torch.fill_next_token_bitmask(self.matcher, token_bitmask, idx)
+        self._check_error()
+
+    def reset(self) -> None:
+        """
+        Reset the grammar matcher state to initial conditions.
+        """
+        self.matcher.reset()
+        self.is_terminated = False
+        self._printed_error = False
+        self._check_error()
+
+    def accept_token(self, token: int) -> bool:
+        """
+        Validate and accept a generated token against the grammar constraints.
+        Returns True if the token is accepted, False otherwise.
+        """
+        if self.is_terminated:
+            return False
+        if self.ll_tokenizer.eos_token == token:
+            self.is_terminated = True
+            return True
+
+        result = self.matcher.consume_tokens([token])
+        self._check_error()
+
+        return result
+
+
+class LLGuidanceBackend(BackendBase):
+    """
+    LLGuidance-specific implementation of BackendBase.
+
+    This backend handles the compilation of various schema types (JSON, regex, etc.)
+    into LLGuidance processors.
+    """
+
+    def __init__(self, fd_config: FDConfig, **kwargs):
+        super().__init__(fd_config=fd_config)
+        self.vocab_size = fd_config.model_config.vocab_size
+        self.batch_size = fd_config.scheduler_config.max_num_seqs
+        self.any_whitespace = not fd_config.structured_outputs_config.disable_any_whitespace
+
+        llm_logger.info(f"LLGuidanceBackend vocab_size={self.vocab_size} batch_size={self.batch_size}")
+        try:
+            self.ll_tokenizer = llguidance_hf.from_tokenizer(self.hf_tokenizer, self.vocab_size)
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to initialize llguidance tokenizer from HuggingFace tokenizer: {e} {traceback.format_exc()}"
+            )
+
+    def _create_processor(
+        self,
+        compiled_grammar: str,
+        enable_thinking: bool = False,
+    ) -> Optional[LLGuidanceProcessor]:
+        """
+        Create a logits processor instance for the given compiled grammar.
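+        Returns ``None`` (after logging the error) if the matcher cannot be
+        built from the compiled grammar.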
+        """
+        try:
+            ll_matcher = llguidance.LLMatcher(
+                self.ll_tokenizer,
+                compiled_grammar,
+                log_level=FD_LLGUIDANCE_LOG_LEVEL,
+            )
+
+            return LLGuidanceProcessor(
+                ll_matcher=ll_matcher,
+                ll_tokenizer=self.ll_tokenizer,
+                serialized_grammar=compiled_grammar,
+                vocab_size=self.vocab_size,
+                batch_size=self.batch_size,
+                enable_thinking=enable_thinking,
+            )
+        except Exception as e:
+            llm_logger.error(f"Failed to create llguidance processor: {e}, {str(traceback.format_exc())}")
+            return None
+
+    def _json_processor(self, compiled_grammar: str, enable_thinking: bool = False) -> Optional[LLGuidanceProcessor]:
+        return self._create_processor(compiled_grammar, enable_thinking)
+
+    def _regex_processor(self, compiled_grammar: str, enable_thinking: bool = False) -> Optional[LLGuidanceProcessor]:
+        return self._create_processor(compiled_grammar, enable_thinking)
+
+    def _grammar_processor(
+        self, compiled_grammar: str, enable_thinking: bool = False
+    ) -> Optional[LLGuidanceProcessor]:
+        return self._create_processor(compiled_grammar, enable_thinking)
+
+    def _structural_tag_processor(
+        self, compiled_grammar: str, enable_thinking: bool = False
+    ) -> Optional[LLGuidanceProcessor]:
+        return self._create_processor(compiled_grammar, enable_thinking)
+
+
+def _walk_json_for_additional_properties(data: object):
+    if isinstance(data, dict):
+        for value in data.values():
+            _walk_json_for_additional_properties(value)
+        if "additionalProperties" not in data and ("properties" in data or "patternProperties" in data):
+            data["additionalProperties"] = False
+    elif isinstance(data, list):
+        for item in data:
+            _walk_json_for_additional_properties(item)
+
+
+def process_for_additional_properties(guide_json: Union[str, dict[str, Any]]) -> dict[str, Any]:
+    if isinstance(guide_json, str):
+        guide_json_obj = json.loads(guide_json)
+    else:
+        # copy for modifications
+        guide_json_obj = copy.deepcopy(guide_json)
+    _walk_json_for_additional_properties(guide_json_obj)
+    return guide_json_obj
+
+
+class LLGuidanceChecker(BaseChecker):
+    """
+    LLGuidance-specific implementation of BaseChecker.
+
+    This checker validates various schema types for compatibility with the
+    llguidance library before processing.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__()
+        # Although the backend handles serialization, we can perform a quick
+        # static check here without a full tokenizer.
+        self.any_whitespace = not kwargs.get("disable_any_whitespace", False)
+        self.disable_additional_properties = FD_GUIDANCE_DISABLE_ADDITIONAL
+        """If `True`, JSON schemas passed to the `guidance` backend gain
+        `"additionalProperties": false` wherever they do not set it, to better
+        align its behaviour with `xgrammar`."""
+
+    def serialize_guidance_grammar(self, request: Request):
+        def _process_schema(
+            grammar_spec: Union[str, dict[str, Any]],
+        ) -> str:
+            if self.disable_additional_properties:
+                grammar_spec = process_for_additional_properties(grammar_spec)
+            return llguidance.LLMatcher.grammar_from_json_schema(
+                grammar_spec,
+                defaults={
+                    "whitespace_flexible": self.any_whitespace,
+                },
+            )
+
+        if request.guided_json:
+            if isinstance(request.guided_json, dict):
+                guided_json = json.dumps(request.guided_json)
+            else:
+                guided_json = request.guided_json
+            return _process_schema(guided_json)
+        elif request.guided_json_object:
+            return llguidance.LLMatcher.grammar_from_json_schema(
+                '{"type": "object"}',
+                defaults={
+                    "whitespace_flexible": self.any_whitespace,
+                },
+            )
+
+        if request.structural_tag:
+            if isinstance(request.structural_tag, str):
+                s_tag = json.loads(request.structural_tag)
+            else:
+                s_tag = request.structural_tag
+            triggers: list[str] = s_tag["triggers"]
+            tags: list[llguidance.StructTag] = []
+            for s in s_tag["structures"]:
+                begin: str = s["begin"]
+                trig = next((t for t in triggers if begin.startswith(t)), None)
+                if trig is None:
+                    raise ValueError(f"Trigger {begin} not found in triggers {triggers}")
+                tags.append(
+                    llguidance.StructTag(
+                        trigger=trig,
+                        begin=s["begin"],
+                        grammar=_process_schema(s["schema"]),
+                        end=s["end"],
+                    )
+                )
+            if not tags:
+                raise ValueError("No structural tags found in the grammar spec.")
+            return llguidance.StructTag.to_grammar(tags)
+
+        if request.guided_regex:
+            tp = "regex"
+            grammar_spec = request.guided_regex
+        elif request.guided_choice:
+            tp = "choice"
+            grammar_spec = request.guided_choice
+        elif request.guided_grammar:
+            tp = "grammar"
+            grammar_spec = request.guided_grammar
+        else:
+            llm_logger.error("Validation should have already occurred. Please file an issue.")
+            raise ValueError("grammar is not one of the supported types.")
+        return llguidance.grammar_from(tp, grammar_spec)
+
+    def schema_format(self, request: Request) -> Tuple[Request, Optional[str]]:
+        """
+        Validates and formats the schema for the LLGuidance backend.
+        """
+        try:
+            guidance_grm = self.serialize_guidance_grammar(request)
+            err = llguidance.LLMatcher.validate_grammar(guidance_grm, None)
+            if err:
+                raise ValueError(f"Grammar error: {err}")
+            else:
+                llm_logger.info(f"Validated schema format: {guidance_grm} {request}")
+            if request.guided_regex:
+                request.guided_regex = guidance_grm
+            elif request.guided_choice:
+                request.guided_grammar = guidance_grm
+                request.guided_choice = None
+            elif request.guided_grammar:
+                request.guided_grammar = guidance_grm
+            elif request.guided_json:
+                request.guided_json = guidance_grm
+
+        except (ValueError, TypeError, json.JSONDecodeError) as e:
+            err_msg = f"Invalid format for guided decoding: {e!s} request={request}"
+            return request, err_msg
+
+        except Exception as e:
+            err_msg = f"An unexpected error occurred during schema validation: {e!s}"
+            return request, err_msg
+
+        return request, None
diff --git a/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py b/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py
index e0f195e47df..8ef67ef6abb 100644
--- a/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py
+++ b/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py
@@ -73,7 +73,6 @@ def __init__(
         enable_thinking: bool = False,
     ):
         super().__init__(enable_reasoning=enable_thinking)
-        self.max_rollback_tokens = 200
         self.vocab_size = vocab_size
         self.batch_size = batch_size
         self.compiled_grammar = compiled_grammar
@@ -82,7 +81,6 @@ def __init__(
 
         self.matcher = GrammarMatcher(
             compiled_grammar=compiled_grammar,
-            max_rollback_tokens=self.max_rollback_tokens,
             terminate_without_stop_token=terminate_without_stop_token,
             override_stop_tokens=override_stop_tokens,
         )
diff --git a/requirements_guided_decoding.txt b/requirements_guided_decoding.txt
new file mode 100644
index 00000000000..627ecc25e1a
--- /dev/null
+++ b/requirements_guided_decoding.txt
@@ -0,0 +1,3 @@
+xgrammar==0.1.25
+llguidance==1.3.0
+torch==2.8.0
diff --git a/tests/model_executor/guided_decoding/test_guidance_backend.py b/tests/model_executor/guided_decoding/test_guidance_backend.py
new file mode 100644
index 00000000000..a545aba800e
--- /dev/null
+++ b/tests/model_executor/guided_decoding/test_guidance_backend.py
@@ -0,0 +1,173 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" + +import sys +import unittest +from unittest.mock import MagicMock, patch + +from fastdeploy.model_executor.guided_decoding import BackendBase + +mock_llguidance = MagicMock() +mock_llguidancehf = MagicMock() +mock_torch = MagicMock() +sys.modules["llguidance"] = mock_llguidance +sys.modules["llguidance.hf"] = mock_llguidancehf +sys.modules["torch"] = mock_torch + +# Import the module to be tested +from fastdeploy.model_executor.guided_decoding.guidance_backend import ( + LLGuidanceBackend, + LLGuidanceProcessor, + process_for_additional_properties, +) + + +class TestProcessForAdditionalProperties(unittest.TestCase): + def test_process_json_string(self): + # Test string input + json_str = '{"type": "object", "properties": {"name": {"type": "string"}}}' + result = process_for_additional_properties(json_str) + self.assertFalse(result["additionalProperties"]) + + def test_process_json_dict(self): + # Test dictionary input + json_dict = {"type": "object", "properties": {"name": {"type": "string"}}} + result = process_for_additional_properties(json_dict) + self.assertFalse(result["additionalProperties"]) + # Ensure the original dictionary is not modified + self.assertNotIn("additionalProperties", json_dict) + + def test_nested_objects(self): + # Test nested objects + json_dict = { + "type": "object", + "properties": {"person": {"type": "object", "properties": {"name": {"type": "string"}}}}, + } + result = process_for_additional_properties(json_dict) + self.assertFalse(result["additionalProperties"]) + self.assertFalse(result["properties"]["person"]["additionalProperties"]) + + +@patch("llguidance.LLMatcher") +@patch("llguidance.LLTokenizer") +class TestLLGuidanceProcessor(unittest.TestCase): + def setUp(self): + self.vocab_size = 100 + self.batch_size = 2 + + def test_initialization(self, mock_tokenizer, mock_matcher): + # Test initialization + processor = LLGuidanceProcessor( + ll_matcher=mock_matcher, + ll_tokenizer=mock_tokenizer, + serialized_grammar="test_grammar", + vocab_size=self.vocab_size, + batch_size=self.batch_size, + ) + + self.assertEqual(processor.vocab_size, self.vocab_size) + self.assertEqual(processor.batch_size, self.batch_size) + self.assertFalse(processor.is_terminated) + + def test_reset(self, mock_tokenizer, mock_matcher): + # Test reset functionality + processor = LLGuidanceProcessor( + ll_matcher=mock_matcher, + ll_tokenizer=mock_tokenizer, + serialized_grammar="test_grammar", + vocab_size=self.vocab_size, + batch_size=self.batch_size, + ) + + processor.is_terminated = True + processor.reset() + + mock_matcher.reset.assert_called_once() + self.assertFalse(processor.is_terminated) + + def test_accept_token(self, mock_tokenizer, mock_matcher): + # Test accept_token functionality + mock_matcher.is_stopped.return_value = False + mock_matcher.consume_tokens.return_value = True + mock_tokenizer.eos_token = 1 + + processor = LLGuidanceProcessor( + ll_matcher=mock_matcher, + ll_tokenizer=mock_tokenizer, + serialized_grammar="test_grammar", + vocab_size=self.vocab_size, + batch_size=self.batch_size, + ) + + # Normal token + result = processor.accept_token(0) + self.assertTrue(result) + mock_matcher.consume_tokens.assert_called_with([0]) + + # EOS token + result = processor.accept_token(1) + self.assertTrue(result) + self.assertTrue(processor.is_terminated) + + +@patch("llguidance.LLMatcher") +@patch("llguidance.hf.from_tokenizer") +class TestLLGuidanceBackend(unittest.TestCase): + def setUp(self): + # Create a mock FDConfig + self.fd_config = MagicMock() + 
self.fd_config.model_config.vocab_size = 100 + self.fd_config.scheduler_config.max_num_seqs = 2 + self.fd_config.structured_outputs_config.disable_any_whitespace = False + self.fd_config.structured_outputs_config.disable_additional_properties = False + self.fd_config.structured_outputs_config.reasoning_parser = None + + def test_initialization(self, mock_from_tokenizer, mock_matcher): + # Test backend initialization + mock_tokenizer = MagicMock() + with patch.object(BackendBase, "_get_tokenizer_hf", return_value=mock_tokenizer): + backend = LLGuidanceBackend(fd_config=self.fd_config) + + self.assertEqual(backend.vocab_size, 100) + self.assertEqual(backend.batch_size, 2) + self.assertTrue(backend.any_whitespace) + + @patch("llguidance.LLMatcher") + def test_create_processor(self, mock_matcher_class, mock_from_tokenizer, mock_matcher): + # Test creating a processor + with patch.object(LLGuidanceBackend, "__init__", return_value=None): + backend = LLGuidanceBackend(fd_config=None) # Arguments are not important because __init__ is mocked + + # Manually set all required attributes + backend.hf_tokenizer = MagicMock() + backend.ll_tokenizer = MagicMock() + backend.vocab_size = 100 + backend.batch_size = 2 + backend.any_whitespace = True + backend.disable_additional_properties = False + + mock_matcher = MagicMock() + mock_matcher_class.return_value = mock_matcher + + processor = backend._create_processor("test_grammar") + + self.assertIsInstance(processor, LLGuidanceProcessor) + self.assertEqual(processor.vocab_size, 100) + self.assertEqual(processor.batch_size, 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/model_executor/guided_decoding/test_guidance_checker.py b/tests/model_executor/guided_decoding/test_guidance_checker.py new file mode 100644 index 00000000000..d1dd09c61dc --- /dev/null +++ b/tests/model_executor/guided_decoding/test_guidance_checker.py @@ -0,0 +1,591 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import json +import sys +import unittest +from unittest.mock import MagicMock, patch + +import pytest + +# Check if llguidance can be imported +HAS_LLGUIDANCE = False +try: + import llguidance + + llguidance + HAS_LLGUIDANCE = True +except ImportError: + mock_llguidance = MagicMock() + mock_torch = MagicMock() + sys.modules["llguidance"] = mock_llguidance + sys.modules["torch"] = mock_torch + + +@pytest.fixture +def llguidance_checker(): + """Return an LLGuidanceChecker instance for testing.""" + return LLGuidanceChecker() + + +@pytest.fixture +def llguidance_checker_with_options(): + """Return an LLGuidanceChecker instance configured with specific options.""" + return LLGuidanceChecker(disable_any_whitespace=True) + + +from fastdeploy.model_executor.guided_decoding.guidance_backend import LLGuidanceChecker + + +def MockRequest(): + request = MagicMock() + request.guided_json = None + request.guided_json_object = None + request.structural_tag = None + request.guided_regex = None + request.guided_choice = None + request.guided_grammar = None + return request + + +class TestLLGuidanceCheckerMocked: + """Test LLGuidanceChecker using Mock, suitable for environments without the llguidance library.""" + + @patch("llguidance.LLMatcher.grammar_from_json_schema") + @patch("llguidance.LLMatcher.validate_grammar") + def test_serialize_guided_json_as_string(self, mock_validate, mock_from_schema, llguidance_checker): + """Test processing guided_json string type.""" + mock_from_schema.return_value = "serialized_grammar" + mock_validate.return_value = None + + request = MockRequest() + request.guided_json = '{"type": "object", "properties": {"name": {"type": "string"}}}' + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + mock_from_schema.assert_called_once() + assert grammar == "serialized_grammar" + + @patch("llguidance.LLMatcher.grammar_from_json_schema") + @patch("llguidance.LLMatcher.validate_grammar") + def test_serialize_guided_json_as_dict(self, mock_validate, mock_from_schema, llguidance_checker): + """Test processing guided_json dictionary type.""" + mock_from_schema.return_value = "serialized_grammar" + mock_validate.return_value = None + + request = MockRequest() + request.guided_json = {"type": "object", "properties": {"name": {"type": "string"}}} + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + mock_from_schema.assert_called_once() + assert isinstance(request.guided_json, dict) # Verify that the dictionary has been converted to a string + assert grammar == "serialized_grammar" + + @patch("llguidance.LLMatcher.grammar_from_json_schema") + @patch("llguidance.LLMatcher.validate_grammar") + def test_serialize_guided_json_object(self, mock_validate, mock_from_schema, llguidance_checker): + """Test processing guided_json_object.""" + mock_from_schema.return_value = "serialized_grammar" + mock_validate.return_value = None + + request = MockRequest() + request.guided_json_object = True + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + mock_from_schema.assert_called_once() + assert request.guided_json_object + assert grammar == "serialized_grammar" + + @patch("llguidance.grammar_from") + @patch("llguidance.LLMatcher.validate_grammar") + def test_serialize_guided_regex(self, mock_validate, mock_grammar_from, llguidance_checker): + """Test processing guided_regex.""" + mock_grammar_from.return_value = "serialized_regex_grammar" + mock_validate.return_value = None + + request = MockRequest() + request.guided_regex = "[a-zA-Z]+" 
+ + grammar = llguidance_checker.serialize_guidance_grammar(request) + + mock_grammar_from.assert_called_once_with("regex", "[a-zA-Z]+") + assert grammar == "serialized_regex_grammar" + + @patch("llguidance.grammar_from") + @patch("llguidance.LLMatcher.validate_grammar") + def test_serialize_guided_choice(self, mock_validate, mock_grammar_from, llguidance_checker): + """Test processing guided_choice.""" + mock_grammar_from.return_value = "serialized_choice_grammar" + mock_validate.return_value = None + + request = MockRequest() + request.guided_choice = ["option1", "option2"] + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + mock_grammar_from.assert_called_once_with("choice", ["option1", "option2"]) + assert grammar == "serialized_choice_grammar" + + @patch("llguidance.grammar_from") + @patch("llguidance.LLMatcher.validate_grammar") + def test_serialize_guided_grammar(self, mock_validate, mock_grammar_from, llguidance_checker): + """Test processing guided_grammar.""" + mock_grammar_from.return_value = "serialized_grammar_spec" + mock_validate.return_value = None + + request = MockRequest() + request.guided_grammar = "grammar specification" + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + mock_grammar_from.assert_called_once_with("grammar", "grammar specification") + assert grammar == "serialized_grammar_spec" + + @patch("llguidance.StructTag") + @patch("llguidance.LLMatcher.grammar_from_json_schema") + def test_serialize_structural_tag(self, mock_from_schema, mock_struct_tag, llguidance_checker): + """Test processing structural_tag.""" + # Configure mock objects + mock_from_schema.return_value = "serialized_schema" + mock_struct_tag.to_grammar.return_value = "serialized_structural_grammar" + struct_tag_instance = MagicMock() + mock_struct_tag.return_value = struct_tag_instance + + request = MockRequest() + request.structural_tag = { + "triggers": [""], + "structures": [{"begin": "", "schema": {"type": "object"}, "end": ""}], + } + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + mock_from_schema.assert_called_once() + mock_struct_tag.assert_called_once() + mock_struct_tag.to_grammar.assert_called_once() + assert grammar == "serialized_structural_grammar" + + @patch("llguidance.StructTag") + def test_serialize_structural_tag_missing_trigger(self, mock_struct_tag, llguidance_checker): + """Test processing structural_tag when a trigger is missing.""" + request = MockRequest() + request.structural_tag = { + "triggers": [""], + "structures": [{"begin": "", "schema": {"type": "object"}, "end": ""}], + } + + with pytest.raises(ValueError, match="Trigger .* not found in triggers"): + llguidance_checker.serialize_guidance_grammar(request) + + @patch("llguidance.StructTag") + def test_serialize_structural_tag_empty_structures(self, mock_struct_tag, llguidance_checker): + """Test processing structural_tag when structures are empty.""" + request = MockRequest() + request.structural_tag = {"triggers": [""], "structures": []} + + with pytest.raises(ValueError, match="No structural tags found in the grammar spec"): + llguidance_checker.serialize_guidance_grammar(request) + + def test_serialize_invalid_grammar_type(self, llguidance_checker): + """Test processing invalid grammar types.""" + request = MockRequest() + # No grammar type set + + with pytest.raises(ValueError, match="grammar is not of valid supported types"): + llguidance_checker.serialize_guidance_grammar(request) + + @patch("llguidance.LLMatcher.grammar_from_json_schema") + 
@patch("llguidance.LLMatcher.validate_grammar") + def test_schema_format_valid_json(self, mock_validate, mock_from_schema, llguidance_checker): + """Test schema_format method processing valid JSON.""" + mock_from_schema.return_value = "serialized_grammar" + mock_validate.return_value = None + + request = MockRequest() + request.guided_json = '{"type": "object"}' + + result_request, error = llguidance_checker.schema_format(request) + + assert error is None + assert result_request is request + + @patch("llguidance.LLMatcher.grammar_from_json_schema") + @patch("llguidance.LLMatcher.validate_grammar") + def test_schema_format_invalid_grammar(self, mock_validate, mock_from_schema, llguidance_checker): + """Test schema_format method processing invalid grammar.""" + mock_from_schema.return_value = "serialized_grammar" + mock_validate.return_value = "Invalid grammar" + + request = MockRequest() + request.guided_json = '{"type": "object"}' + + result_request, error = llguidance_checker.schema_format(request) + + assert error is not None + assert "Grammar error: Invalid grammar" in error + + @patch("llguidance.LLMatcher.grammar_from_json_schema") + def test_schema_format_json_decode_error(self, mock_from_schema, llguidance_checker): + """Test schema_format method processing JSON decode error.""" + mock_from_schema.side_effect = json.JSONDecodeError("Invalid JSON", "", 0) + + request = MockRequest() + request.guided_json = "{invalid json}" + + result_request, error = llguidance_checker.schema_format(request) + + assert error is not None + assert "Invalid format for guided decoding" in error + + @patch("llguidance.LLMatcher.grammar_from_json_schema") + def test_schema_format_unexpected_error(self, mock_from_schema, llguidance_checker): + """Test schema_format method processing unexpected errors.""" + mock_from_schema.side_effect = Exception("Unexpected error") + + request = MockRequest() + request.guided_json = '{"type": "object"}' + + result_request, error = llguidance_checker.schema_format(request) + + assert error is not None + assert "An unexpected error occurred during schema validation" in error + + def test_init_with_disable_whitespace(self, llguidance_checker_with_options): + """Test setting the disable_any_whitespace option during initialization.""" + assert llguidance_checker_with_options.any_whitespace is False + assert llguidance_checker_with_options.disable_additional_properties is True + assert LLGuidanceChecker(disable_any_whitespace=True).any_whitespace is False + assert LLGuidanceChecker(disable_any_whitespace=False).any_whitespace is True + + # default check + from fastdeploy.envs import FD_GUIDANCE_DISABLE_ADDITIONAL + + assert FD_GUIDANCE_DISABLE_ADDITIONAL + + assert LLGuidanceChecker().disable_additional_properties is True + with patch("fastdeploy.model_executor.guided_decoding.guidance_backend.FD_GUIDANCE_DISABLE_ADDITIONAL", False): + assert LLGuidanceChecker().disable_additional_properties is False + + +@pytest.mark.skipif(not HAS_LLGUIDANCE, reason="llguidance library not installed, skipping actual dependency tests") +class TestLLGuidanceCheckerReal: + """Test using the actual llguidance library, suitable for development environments.""" + + def test_serialize_guided_json_string_real(self, llguidance_checker): + """Test processing guided_json string using the actual library.""" + request = MockRequest() + request.guided_json = '{"type": "object", "properties": {"name": {"type": "string"}}}' + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + # Verify if 
the returned grammar is a valid string + assert isinstance(grammar, str) + assert len(grammar) > 0 + print("grammar", grammar) + + def test_serialize_guided_json_dict_real(self, llguidance_checker): + """Test processing guided_json dictionary using the actual library.""" + request = MockRequest() + request.guided_json = {"type": "object", "properties": {"name": {"type": "string"}}} + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + assert isinstance(request.guided_json, dict) + assert isinstance(grammar, str) + assert len(grammar) > 0 + + def test_serialize_guided_json_object_real(self, llguidance_checker): + """Test processing guided_json_object using the actual library.""" + request = MockRequest() + request.guided_json_object = True + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + assert request.guided_json_object + assert isinstance(grammar, str) + assert len(grammar) > 0 + + def test_serialize_guided_regex_real(self, llguidance_checker): + """Test processing guided_regex using the actual library.""" + request = MockRequest() + request.guided_regex = "[a-zA-Z]+" + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + assert isinstance(grammar, str) + assert len(grammar) > 0 + + def test_serialize_guided_choice_real(self, llguidance_checker): + """Test processing guided_choice using the actual library.""" + request = MockRequest() + request.guided_choice = ["option1", "option2"] + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + assert isinstance(grammar, str) + assert len(grammar) > 0 + + def test_serialize_guided_grammar_real(self, llguidance_checker): + """Test processing guided_grammar using the actual library.""" + request = MockRequest() + # Use a simple CFG grammar example + request.guided_grammar = """ + root ::= greeting name + greeting ::= "Hello" | "Hi" + name ::= "world" | "everyone" + """ + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + assert isinstance(grammar, str) + assert len(grammar) > 0 + + def test_serialize_structural_tag_real(self, llguidance_checker): + """Test processing structural_tag using the actual library.""" + request = MockRequest() + request.structural_tag = { + "triggers": [""], + "structures": [{"begin": "", "schema": {"type": "object"}, "end": ""}], + } + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + assert isinstance(grammar, str) + assert len(grammar) > 0 + + def test_schema_format_valid_json_real(self, llguidance_checker): + """Test schema_format method processing valid JSON using the actual library.""" + request = MockRequest() + request.guided_json = '{"type": "object", "properties": {"name": {"type": "string"}}}' + + result_request, error = llguidance_checker.schema_format(request) + + assert error is None + assert result_request is request + assert result_request.guided_json != '{"type": "object", "properties": {"name": {"type": "string"}}}' + + def test_schema_format_invalid_json_real(self, llguidance_checker): + """Test schema_format method processing invalid JSON using the actual library.""" + request = MockRequest() + request.guided_json = "{invalid json}" + + result_request, error = llguidance_checker.schema_format(request) + + assert error is not None + assert "Invalid format for guided decoding" in error + + def test_whitespace_flexibility_option_real(self): + """Test the impact of the whitespace flexibility option using the actual library.""" + # Create two instances with different configurations + 
flexible = LLGuidanceChecker(disable_any_whitespace=False) + strict = LLGuidanceChecker(disable_any_whitespace=True) + + request_flexible = MockRequest() + request_flexible.guided_json = '{"type": "object"}' + + request_strict = MockRequest() + request_strict.guided_json = '{"type": "object"}' + + grammar_flexible = flexible.serialize_guidance_grammar(request_flexible) + grammar_strict = strict.serialize_guidance_grammar(request_strict) + print("grammar_flexible", grammar_flexible) + print("grammar_strict", grammar_strict) + + # Expect grammars generated by the two configurations to be different + assert grammar_flexible != grammar_strict + + def test_schema_format_guided_json_object_real(self, llguidance_checker): + """Test schema_format processing guided_json_object.""" + request = MockRequest() + request.guided_json_object = True + + result_request, error = llguidance_checker.schema_format(request) + + assert error is None + assert result_request is request + + def test_schema_format_guided_regex_real(self, llguidance_checker): + """Test schema_format processing valid regular expressions.""" + request = MockRequest() + request.guided_regex = r"[a-zA-Z0-9]+" + + result_request, error = llguidance_checker.schema_format(request) + + assert error is None + assert result_request is request + assert result_request.guided_regex != r"[a-zA-Z0-9]+" # Should be converted to grammar format + + def test_schema_format_invalid_guided_regex_real(self, llguidance_checker): + """Test schema_format processing invalid regular expressions.""" + request = MockRequest() + request.guided_regex = r"[" # Invalid regular expression + + result_request, error = llguidance_checker.schema_format(request) + + assert error is not None + assert "Invalid format for guided decoding" in error + + def test_schema_format_guided_choice_real(self, llguidance_checker): + """Test schema_format processing guided_choice.""" + request = MockRequest() + request.guided_choice = ["option1", "option2", "option3"] + + result_request, error = llguidance_checker.schema_format(request) + + assert error is None + assert result_request is request + assert result_request.guided_choice != [ + "option1", + "option2", + "option3", + ] # Should be converted to grammar format + + def test_schema_format_guided_grammar_real(self, llguidance_checker): + """Test schema_format processing guided_grammar.""" + request = MockRequest() + # Use the correct grammar format supported by LLGuidance + request.guided_grammar = """ + start: number + number: DIGIT+ + DIGIT: "0"|"1"|"2"|"3"|"4"|"5"|"6"|"7"|"8"|"9" + """ + + result_request, error = llguidance_checker.schema_format(request) + + assert error is None + assert result_request is request + assert isinstance(result_request.guided_grammar, str) + + def test_schema_format_structural_tag_real(self, llguidance_checker): + """Test schema_format processing structural_tag.""" + request = MockRequest() + request.structural_tag = { + "triggers": ["```json"], + "structures": [ + { + "begin": "```json", + "schema": {"type": "object", "properties": {"name": {"type": "string"}}}, + "end": "```", + } + ], + } + + result_request, error = llguidance_checker.schema_format(request) + + assert error is None + assert result_request is request + + def test_schema_format_structural_tag_string_real(self, llguidance_checker): + """Test schema_format processing structural_tag in string format.""" + request = MockRequest() + request.structural_tag = json.dumps( + { + "triggers": ["```json"], + "structures": [ + { + "begin": 
"```json", + "schema": {"type": "object", "properties": {"name": {"type": "string"}}}, + "end": "```", + } + ], + } + ) + + result_request, error = llguidance_checker.schema_format(request) + + assert error is None + assert result_request is request + + def test_schema_format_structural_tag_invalid_trigger_real(self, llguidance_checker): + """Test schema_format processing structural_tag with invalid triggers.""" + request = MockRequest() + request.structural_tag = { + "triggers": ["```xml"], # Trigger does not match begin + "structures": [ + { + "begin": "```json", + "schema": {"type": "object"}, + "end": "```", + } # Does not contain any prefix from triggers here + ], + } + + result_request, error = llguidance_checker.schema_format(request) + + assert error is not None + assert "Invalid format for guided decoding" in error + + def test_schema_format_structural_tag_empty_structures_real(self, llguidance_checker): + """Test schema_format processing structural_tag with empty structures.""" + request = MockRequest() + request.structural_tag = {"triggers": ["```json"], "structures": []} # Empty structure + + result_request, error = llguidance_checker.schema_format(request) + + assert error is not None + assert "Invalid format for guided decoding" in error + + def test_schema_format_json_dict_real(self, llguidance_checker): + """Test schema_format processing guided_json in dictionary format.""" + request = MockRequest() + request.guided_json = {"type": "object", "properties": {"name": {"type": "string"}}} + + result_request, error = llguidance_checker.schema_format(request) + + assert error is None + assert result_request is request + + def test_schema_format_disable_additional_properties_real(self): + """Test schema_format processing disable_additional_properties parameter.""" + checker = LLGuidanceChecker(disable_additional_properties=True) + request = MockRequest() + request.guided_json = {"type": "object", "properties": {"name": {"type": "string"}}} + + result_request, error = checker.schema_format(request) + + assert error is None + assert result_request is request + + def test_schema_format_unexpected_error_real(self, monkeypatch, llguidance_checker): + """Test schema_format processing unexpected errors.""" + request = MockRequest() + request.guided_json = '{"type": "object"}' + + # Mock unexpected exception + def mock_serialize_grammar(*args, **kwargs): + raise Exception("Unexpected error") + + monkeypatch.setattr(llguidance_checker, "serialize_guidance_grammar", mock_serialize_grammar) + + result_request, error = llguidance_checker.schema_format(request) + + assert error is not None + assert "An unexpected error occurred during schema validation" in error + + def test_schema_format_no_valid_grammar_real(self, llguidance_checker): + """Test schema_format processing requests without valid grammar.""" + request = MockRequest() + # No grammar-related attributes set + + with pytest.raises(ValueError, match="grammar is not of valid supported types"): + llguidance_checker.serialize_guidance_grammar(request) + result_request, error = llguidance_checker.schema_format(request) + assert error is not None + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/model_executor/guided_decoding/test_guidance_processor.py b/tests/model_executor/guided_decoding/test_guidance_processor.py new file mode 100644 index 00000000000..cdeb91e378d --- /dev/null +++ b/tests/model_executor/guided_decoding/test_guidance_processor.py @@ -0,0 +1,172 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. 
All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import sys
+import unittest
+from unittest.mock import MagicMock, patch
+
+# --- Mocking Setup ---
+# Mock these lazy-loaded modules up front so the tests run in environments where the libraries are not installed.
+mock_torch = MagicMock()
+mock_llguidance = MagicMock()
+mock_llguidance_hf = MagicMock()
+mock_llguidance_torch = MagicMock()
+
+mock_torch.__spec__ = MagicMock()
+mock_torch.distributed = MagicMock()
+
+sys.modules["torch"] = mock_torch
+sys.modules["llguidance"] = mock_llguidance
+sys.modules["llguidance.hf"] = mock_llguidance_hf
+sys.modules["llguidance.torch"] = mock_llguidance_torch
+
+# Import the module to be tested after the mock setup is complete
+from fastdeploy.model_executor.guided_decoding.guidance_backend import (
+    LLGuidanceProcessor,
+)
+
+
+def MockFDConfig():
+    """Create a mock FDConfig object for testing"""
+    config = MagicMock()
+    # Explicitly set model as a string so HF tokenizer validation passes
+    config.model_config.model = "test-model-path"
+    config.model_config.architectures = []  # Set to empty list to prevent errors when iterating over the Mock
+
+    config.model_config.vocab_size = 1000
+    config.scheduler_config.max_num_seqs = 4
+    config.structured_outputs_config.disable_any_whitespace = False
+    # Ensure the backend check logic passes
+    config.structured_outputs_config.guided_decoding_backend = "guidance"
+    return config
+
+
+def MockHFTokenizer():
+    """Create a mock Hugging Face Tokenizer object for testing"""
+    return MagicMock()
+
+
+class TestLLGuidanceProcessorMocked(unittest.TestCase):
+    """
+    Unit tests for LLGuidanceProcessor using Mock.
+    This test class is suitable for environments where the llguidance library is not installed.
+    """
+
+    def setUp(self):
+        """Set up a new LLGuidanceProcessor instance for each test case"""
+        self.mock_matcher = MagicMock()
+        self.mock_tokenizer = MagicMock()
+        self.mock_tokenizer.eos_token = 2  # Example EOS token ID
+        self.processor = LLGuidanceProcessor(
+            ll_matcher=self.mock_matcher,
+            ll_tokenizer=self.mock_tokenizer,
+            serialized_grammar="test_grammar",
+            vocab_size=1000,
+            batch_size=4,
+            enable_thinking=False,
+        )
+
+    def test_init(self):
+        """Test the constructor of LLGuidanceProcessor"""
+        self.assertIs(self.processor.matcher, self.mock_matcher)
+        self.assertEqual(self.processor.vocab_size, 1000)
+        self.assertEqual(self.processor.batch_size, 4)
+        self.assertFalse(self.processor.is_terminated)
+
+    @patch("fastdeploy.utils.llm_logger.warning")
+    def test_check_error_logs_warning_once(self, mock_log_warning):
+        """Test that the _check_error method logs a warning when the matcher errors, and only logs it once"""
+        self.mock_matcher.get_error.return_value = "A test error."
+
+        # First call, should log the message
+        self.processor._check_error()
+        mock_log_warning.assert_called_once_with("LLGuidance Matcher error: A test error.")
+
+        # Second call, should not log repeatedly
+        self.processor._check_error()
+        mock_log_warning.assert_called_once()
+
+    @patch("fastdeploy.model_executor.guided_decoding.guidance_backend.llguidance_torch")
+    def test_allocate_token_bitmask(self, mock_backend_torch):
+        """
+        Test the allocation of token bitmask.
+        Note: We patch the llguidance_torch variable imported in the guidance_backend module here,
+        instead of the global mock in sys.modules, to resolve inconsistent references caused by LazyLoader.
+        """
+        mock_backend_torch.allocate_token_bitmask.return_value = "fake_bitmask_tensor"
+
+        result = self.processor.allocate_token_bitmask()
+
+        mock_backend_torch.allocate_token_bitmask.assert_called_once_with(4, 1000)
+        self.assertEqual(result, "fake_bitmask_tensor")
+
+    @patch("fastdeploy.model_executor.guided_decoding.guidance_backend.llguidance_torch")
+    def test_fill_token_bitmask(self, mock_backend_torch):
+        """Test the filling of token bitmask"""
+        mock_bitmask = MagicMock()
+
+        self.processor.fill_token_bitmask(mock_bitmask, idx=2)
+
+        mock_backend_torch.fill_next_token_bitmask.assert_called_once_with(self.mock_matcher, mock_bitmask, 2)
+        self.mock_matcher.get_error.assert_called_once()
+
+    def test_reset(self):
+        """Test the state reset of the processor"""
+        self.processor.is_terminated = True
+        self.processor._printed_error = True
+        self.mock_matcher.get_error.return_value = ""
+
+        self.processor.reset()
+
+        self.mock_matcher.reset.assert_called_once()
+        self.assertFalse(self.processor.is_terminated)
+        self.assertFalse(self.processor._printed_error)
+
+    def test_accept_token_when_terminated(self):
+        """Test that accept_token returns False immediately when status is is_terminated"""
+        self.processor.is_terminated = True
+        self.assertFalse(self.processor.accept_token(123))
+
+    def test_accept_token_when_matcher_stopped(self):
+        """Test that accept_token ignores the matcher's stopped flag: the token is still consumed and the processor only terminates on EOS"""
+        self.mock_matcher.is_stopped.return_value = True
+        self.assertTrue(self.processor.accept_token(123))
+        self.assertFalse(self.processor.is_terminated)
+
+    def test_accept_token_is_eos(self):
+        """Test the behavior when an EOS token is received"""
+        self.mock_matcher.is_stopped.return_value = False
+        self.assertTrue(self.processor.accept_token(self.mock_tokenizer.eos_token))
+        self.assertTrue(self.processor.is_terminated)
+
+    def test_accept_token_consumes_and_succeeds(self):
+        """Test successfully consuming a token"""
+        self.mock_matcher.is_stopped.return_value = False
+        self.mock_matcher.consume_tokens.return_value = True
+        self.assertTrue(self.processor.accept_token(123))
+        self.mock_matcher.consume_tokens.assert_called_once_with([123])
+        self.mock_matcher.get_error.assert_called_once()
+
+    def test_accept_token_consumes_and_fails(self):
+        """Test failing to consume a token"""
+        self.mock_matcher.is_stopped.return_value = False
+        self.mock_matcher.consume_tokens.return_value = False
+        self.assertFalse(self.processor.accept_token(123))
+        self.mock_matcher.consume_tokens.assert_called_once_with([123])
+
+
+if __name__ == "__main__":
+    unittest.main()
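---

Reviewer notes (not part of the patch): a few hedged usage sketches.

First, a minimal client-side sketch of exercising the new backend end to end. The endpoint, model name, and the `guided_json` extra-body field are placeholders/assumptions inferred from the structured-outputs docs and the `Request.guided_json` field touched above, not something this patch defines.

```python
# Hypothetical usage sketch -- assumes a FastDeploy OpenAI-compatible server
# started with guided_decoding_backend="guidance" (see docs/parameters.md).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8188/v1", api_key="EMPTY")  # placeholder endpoint

schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}

resp = client.chat.completions.create(
    model="default",  # placeholder model name
    messages=[{"role": "user", "content": "Describe a person as JSON."}],
    extra_body={"guided_json": schema},  # assumed to map onto Request.guided_json
)
print(resp.choices[0].message.content)  # output is constrained to match `schema`
```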
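The LazyLoader contract introduced in fastdeploy/lazy_loader.py is easiest to see in isolation; a small sketch, using `numpy` as a stand-in for any heavy optional dependency:

```python
# Nothing is imported at construction time; the first attribute access
# triggers the real import and caches the module.
from fastdeploy.lazy_loader import LazyLoader

np = LazyLoader("np", globals(), "numpy")  # no numpy import has happened yet
print(type(np))      # <class 'fastdeploy.lazy_loader.LazyLoader'>
print(np.arange(3))  # first lookup imports numpy and rebinds globals()["np"]
```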
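What `FD_GUIDANCE_DISABLE_ADDITIONAL` (default on) does to a schema before it reaches llguidance, per `process_for_additional_properties` above:

```python
from fastdeploy.model_executor.guided_decoding.guidance_backend import (
    process_for_additional_properties,
)

schema = {"type": "object", "properties": {"name": {"type": "string"}}}
print(process_for_additional_properties(schema))
# -> {'type': 'object', 'properties': {'name': {'type': 'string'}},
#     'additionalProperties': False}
# The input is deep-copied, so `schema` itself is left unmodified.
```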
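And the checker's end-to-end effect on a request, mirroring test_guidance_checker.py; this requires llguidance to be installed, and `_Req` is a hypothetical stand-in for fastdeploy.engine.request.Request:

```python
from fastdeploy.model_executor.guided_decoding.guidance_backend import LLGuidanceChecker

class _Req:  # minimal stand-in exposing the fields schema_format() reads
    guided_json = '{"type": "object", "properties": {"name": {"type": "string"}}}'
    guided_json_object = None
    structural_tag = None
    guided_regex = None
    guided_choice = None
    guided_grammar = None

req, err = LLGuidanceChecker().schema_format(_Req())
assert err is None
print(req.guided_json)  # now a serialized llguidance grammar, not the raw schema
```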