From ae39e37bb6680ad90f6c8accbe8fd1798f9d96f8 Mon Sep 17 00:00:00 2001 From: ST-XX <15625257+ST-XX@users.noreply.github.com> Date: Wed, 19 Nov 2025 14:23:17 +0800 Subject: [PATCH 1/6] llguidance --- fastdeploy/config.py | 20 +- fastdeploy/envs.py | 2 + fastdeploy/lazy_loader.py | 70 +++ .../guided_decoding/__init__.py | 15 + .../guided_decoding/base_guided_decoding.py | 7 +- .../guided_decoding/guidance_backend.py | 314 ++++++++++ .../guided_decoding/xgrammar_backend.py | 2 - tests/layers/test_guidance_checker.py | 562 ++++++++++++++++++ tests/layers/test_guidance_processor.py | 172 ++++++ tests/model_executor/test_gidance_backend.py | 213 +++++++ 10 files changed, 1371 insertions(+), 6 deletions(-) create mode 100644 fastdeploy/lazy_loader.py create mode 100644 fastdeploy/model_executor/guided_decoding/guidance_backend.py create mode 100644 tests/layers/test_guidance_checker.py create mode 100644 tests/layers/test_guidance_processor.py create mode 100644 tests/model_executor/test_gidance_backend.py diff --git a/fastdeploy/config.py b/fastdeploy/config.py index afe37f076b4..e9d537c190b 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1593,13 +1593,26 @@ def postprocess(self): if ( self.structured_outputs_config is not None - and self.structured_outputs_config.guided_decoding_backend == "auto" + and self.structured_outputs_config.guided_decoding_backend != "off" ): if current_platform.is_xpu() or self.speculative_config.method is not None: logger.warning("Speculative Decoding and XPU currently do not support Guided decoding, set off.") self.structured_outputs_config.guided_decoding_backend = "off" - else: + elif self.structured_outputs_config.guided_decoding_backend in ["auto", "xgrammar"]: self.structured_outputs_config.guided_decoding_backend = "xgrammar" + elif self.structured_outputs_config.guided_decoding_backend == "guidance": + pass + # try: + # import llguidance + # import llguidance.hf + # import llguidance.torch + # except ImportError: + # raise ImportError( + # "The 'xgrammar' package is required for using xgrammar as the guided decoding backend. " + # "Please install it with 'pip install xgrammar'." + # ) + else: + raise NotImplementedError(f"{self.structured_outputs_config.guided_decoding_backend}") if self.model_config.enable_mm: if self.cache_config.max_encoder_cache is None or self.cache_config.max_encoder_cache < 0: @@ -1709,7 +1722,8 @@ def check(self): "XGrammar", "auto", "off", - ], f"Only support xgrammar、auto guided decoding backend, but got {self.structured_outputs_config.guided_decoding_backend}." + "guidance", + ], f"Only support [auto, xgrammar, guidance, off] guided decoding backend, but got {self.structured_outputs_config.guided_decoding_backend}." if self.structured_outputs_config.guided_decoding_backend != "off": # TODO: speculative decoding support guided_decoding diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 23030d6a80b..3b7830dafcd 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -161,6 +161,8 @@ "FD_ENGINE_TASK_QUEUE_WITH_SHM": lambda: int(os.getenv("FD_ENGINE_TASK_QUEUE_WITH_SHM", "0")), "FD_FILL_BITMASK_BATCH": lambda: int(os.getenv("FD_FILL_BITMASK_BATCH", "4")), "FD_ENABLE_PDL": lambda: int(os.getenv("FD_ENABLE_PDL", "1")), + "FD_GUIDANCE_DISABLE_ADDITIONAL": lambda: bool(int(os.getenv("FD_GUIDANCE_DISABLE_ADDITIONAL", "1"))), + "FD_LLGUIDANCE_LOG_LEVEL": lambda: int(os.getenv("FD_LLGUIDANCE_LOG_LEVEL", "0")), } diff --git a/fastdeploy/lazy_loader.py b/fastdeploy/lazy_loader.py new file mode 100644 index 00000000000..c3bc3ec43aa --- /dev/null +++ b/fastdeploy/lazy_loader.py @@ -0,0 +1,70 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A LazyLoader class.""" + +import importlib +import sys +import types +from typing import Any + + +class LazyLoader(types.ModuleType): + """ + LazyLoader module borrowed from Tensorflow + https://github.com/tensorflow/tensorflow/blob/main/tensorflow/python/util/lazy_loader.py + with an addition of "module caching". + + Lazily import a module, mainly to avoid pulling in large dependencies. + Modules such as `xgrammar` might do additional side effects, so we + only want to use this when it is needed, delaying all eager effects + """ + + def __init__( + self, + local_name: str, + parent_module_globals: dict[str, Any], + name: str, + ): + self._local_name = local_name + self._parent_module_globals = parent_module_globals + self._module: types.ModuleType | None = None + + super().__init__(str(name)) + + def _load(self) -> types.ModuleType: + # Import the target module and insert it into the parent's namespace + try: + module = importlib.import_module(self.__name__) + self._parent_module_globals[self._local_name] = module + # The additional add to sys.modules + # ensures library is actually loaded. + sys.modules[self._local_name] = module + except ModuleNotFoundError as err: + raise err from None + + # Update this object's dict so that if someone keeps a + # reference to the LazyLoader, lookups are efficient + # (__getattr__ is only called on lookups that fail). + self.__dict__.update(module.__dict__) + return module + + def __getattr__(self, item: Any) -> Any: + if self._module is None: + self._module = self._load() + return getattr(self._module, item) + + def __dir__(self) -> list[str]: + if self._module is None: + self._module = self._load() + return dir(self._module) diff --git a/fastdeploy/model_executor/guided_decoding/__init__.py b/fastdeploy/model_executor/guided_decoding/__init__.py index dbfc70215d1..989ce0fd63f 100644 --- a/fastdeploy/model_executor/guided_decoding/__init__.py +++ b/fastdeploy/model_executor/guided_decoding/__init__.py @@ -50,6 +50,15 @@ def get_guided_backend( fd_config=fd_config, **kwargs, ) + elif fd_config.structured_outputs_config.guided_decoding_backend.lower() == "guidance": + from fastdeploy.model_executor.guided_decoding.guidance_backend import ( + LLGuidanceBackend, + ) + + return LLGuidanceBackend( + fd_config=fd_config, + **kwargs, + ) else: raise ValueError( f"Get unsupported backend {fd_config.structured_outputs_config.guided_decoding_backend}," @@ -77,5 +86,11 @@ def schema_checker(backend_name: str, **kwargs): ) return XGrammarChecker(**kwargs) + elif backend_name.lower() == "guidance": + from fastdeploy.model_executor.guided_decoding.guidance_backend import ( + LLGuidanceChecker, + ) + + return LLGuidanceChecker(**kwargs) else: raise ValueError(f"Get unsupported backend {backend_name}, please check your configuration.") diff --git a/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py b/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py index c4e235afc36..717cafdcdf8 100644 --- a/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py +++ b/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py @@ -294,7 +294,12 @@ def _get_tokenizer_hf(self): """ try: architectures = self.fd_config.model_config.architectures - if not ErnieArchitectures.contains_ernie_arch(architectures): + is_guidance_backend = ( + self.fd_config.structured_outputs_config is not None + and self.fd_config.structured_outputs_config.guided_decoding_backend is not None + and self.fd_config.structured_outputs_config.guided_decoding_backend == "guidance" + ) + if not ErnieArchitectures.contains_ernie_arch(architectures) or is_guidance_backend: from transformers import AutoTokenizer, PreTrainedTokenizerFast tokenizer = AutoTokenizer.from_pretrained( diff --git a/fastdeploy/model_executor/guided_decoding/guidance_backend.py b/fastdeploy/model_executor/guided_decoding/guidance_backend.py new file mode 100644 index 00000000000..3b0a22e87b1 --- /dev/null +++ b/fastdeploy/model_executor/guided_decoding/guidance_backend.py @@ -0,0 +1,314 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import copy +import json +import traceback +from typing import Any, Optional, Tuple, Union + +from fastdeploy.config import FDConfig +from fastdeploy.engine.request import Request +from fastdeploy.envs import FD_GUIDANCE_DISABLE_ADDITIONAL, FD_LLGUIDANCE_LOG_LEVEL +from fastdeploy.lazy_loader import LazyLoader +from fastdeploy.model_executor.guided_decoding import ( + BackendBase, + BaseChecker, + LogitsProcessorBase, +) +from fastdeploy.utils import llm_logger + +torch = LazyLoader("torch", globals(), "torch") +llguidance = LazyLoader("llguidance", globals(), "llguidance") +llguidance_hf = LazyLoader("llguidance.hf", globals(), "llguidance.hf") +llguidance_torch = LazyLoader("llguidance.torch", globals(), "llguidance.torch") + + +class LLGuidanceProcessor(LogitsProcessorBase): + """ + LLGuidance-specific implementation of LogitsProcessorBase. + + This processor enforces grammar constraints during token generation using llguidance. + It manages the grammar matching state and applies token masks to logits. + """ + + def __init__( + self, + ll_matcher: llguidance.LLMatcher, + ll_tokenizer: llguidance.LLTokenizer, + serialized_grammar: str, + vocab_size: int, + batch_size: int, + enable_thinking: bool = False, + ): + super().__init__(enable_reasoning=enable_thinking) + self.matcher = ll_matcher + self.ll_tokenizer = ll_tokenizer + self.serialized_grammar = serialized_grammar + self.vocab_size = vocab_size + self.batch_size = batch_size + self.is_terminated: bool = False + self._printed_error: bool = False + + def _check_error(self): + """Checks for and logs any errors from the LLMatcher.""" + if not self._printed_error: + err = self.matcher.get_error() + if err: + self._printed_error = True + llm_logger.warning(f"LLGuidance Matcher error: {err}") + + def allocate_token_bitmask(self) -> torch.Tensor: + """ + Allocate a token bitmask tensor for grammar constraints. + """ + return llguidance_torch.allocate_token_bitmask(self.batch_size, self.vocab_size) + + def fill_token_bitmask(self, token_bitmask: torch.Tensor, idx: int) -> None: + """ + Fill the token bitmask with allowed tokens for the given index. + This will automatically provide an EOS mask if the matcher is stopped. + """ + llguidance_torch.fill_next_token_bitmask(self.matcher, token_bitmask, idx) + self._check_error() + + def reset(self) -> None: + """ + Reset the grammar matcher state to initial conditions. + """ + self.matcher.reset() + self.is_terminated = False + self._printed_error = False + self._check_error() + + def accept_token(self, token: int) -> bool: + """ + Validate and accept a generated token against the grammar constraints. + Returns True if the token is accepted, False otherwise. + """ + if self.is_terminated: + return False + if self.ll_tokenizer.eos_token == token: + self.is_terminated = True + return True + + result = self.matcher.consume_tokens([token]) + self._check_error() + + return result + + +class LLGuidanceBackend(BackendBase): + """ + LLGuidance-specific implementation of BackendBase. + + This backend handles the compilation of various schema types (JSON, regex, etc.) + into LLGuidance processors. + """ + + def __init__(self, fd_config: FDConfig, **kwargs): + super().__init__(fd_config=fd_config) + self.vocab_size = fd_config.model_config.vocab_size + self.batch_size = fd_config.scheduler_config.max_num_seqs + self.any_whitespace = not fd_config.structured_outputs_config.disable_any_whitespace + + llm_logger.info(f"LLGuidanceBackend vocab_size={self.vocab_size} batch_size={self.batch_size}") + try: + self.ll_tokenizer = llguidance_hf.from_tokenizer(self.hf_tokenizer, self.vocab_size) + except Exception as e: + import traceback + + raise RuntimeError( + f"Failed to initialize llguidance tokenizer from HuggingFace tokenizer: {e} {traceback.format_exc()}" + ) + + def _create_processor( + self, + compiled_grammar: str, + enable_thinking: bool = False, + ) -> Optional[LLGuidanceProcessor]: + """ + Create a logits processor instance for the given grammar schemata. + """ + try: + + ll_matcher = llguidance.LLMatcher( + self.ll_tokenizer, + compiled_grammar, + log_level=FD_LLGUIDANCE_LOG_LEVEL, + ) + + return LLGuidanceProcessor( + ll_matcher=ll_matcher, + ll_tokenizer=self.ll_tokenizer, + serialized_grammar=compiled_grammar, + vocab_size=self.vocab_size, + batch_size=self.batch_size, + enable_thinking=enable_thinking, + ) + except Exception as e: + llm_logger.error(f"Failed to create llguidance processor: {e}, {str(traceback.format_exc())}") + return None + + def _json_processor(self, compiled_grammar: str, enable_thinking: bool = False) -> Optional[LLGuidanceProcessor]: + return self._create_processor(compiled_grammar, enable_thinking) + + def _regex_processor(self, compiled_grammar: str, enable_thinking: bool = False) -> Optional[LLGuidanceProcessor]: + return self._create_processor(compiled_grammar, enable_thinking) + + def _grammar_processor( + self, compiled_grammar: str, enable_thinking: bool = False + ) -> Optional[LLGuidanceProcessor]: + return self._create_processor(compiled_grammar, enable_thinking) + + def _structural_tag_processor( + self, compiled_grammar: str, enable_thinking: bool = False + ) -> Optional[LLGuidanceProcessor]: + return self._create_processor(compiled_grammar, enable_thinking) + + +def _walk_json_for_additional_properties(data: object): + if isinstance(data, dict): + for value in data.values(): + _walk_json_for_additional_properties(value) + if "additionalProperties" not in data and ("properties" in data or "patternProperties" in data): + data["additionalProperties"] = False + elif isinstance(data, list): + for item in data: + _walk_json_for_additional_properties(item) + + +def process_for_additional_properties(guide_json: Union[str, dict[str, Any]]) -> dict[str, Any]: + if isinstance(guide_json, str): + guide_json_obj = json.loads(guide_json) + else: + # copy for modifications + guide_json_obj = copy.deepcopy(guide_json) + _walk_json_for_additional_properties(guide_json_obj) + return guide_json_obj + + +class LLGuidanceChecker(BaseChecker): + """ + LLGuidance-specific implementation of BaseChecker. + + This checker validates various schema types for compatibility with the + llguidance library before processing. + """ + + def __init__(self, **kwargs): + super().__init__() + # Although the backend handles serialization, we can perform a quick + # static check here without a full tokenizer. + self.any_whitespace = not kwargs.get("disable_any_whitespace", False) + self.disable_additional_properties = FD_GUIDANCE_DISABLE_ADDITIONAL + """If `True`, the `guidance` backend will not use `additionalProperties` + in the JSON schema. This is only supported for the `guidance` backend and + is used to better align its behaviour with `outlines` and `xgrammar`.""" + + def serialize_guidance_grammar(self, request: Request): + def _process_schema( + grammar_spec: Union[str, dict[str, Any]], + ) -> str: + if self.disable_additional_properties: + grammar_spec = process_for_additional_properties(grammar_spec) + return llguidance.LLMatcher.grammar_from_json_schema( + grammar_spec, + defaults={ + "whitespace_flexible": self.any_whitespace, + }, + ) + + if request.guided_json: + if isinstance(request.guided_json, dict): + guided_json = json.dumps(request.guided_json) + else: + guided_json = request.guided_json + return _process_schema(guided_json) + elif request.guided_json_object: + return llguidance.LLMatcher.grammar_from_json_schema( + '{"type": "object"}', + defaults={ + "whitespace_flexible": self.any_whitespace, + }, + ) + + if request.structural_tag: + if isinstance(request.structural_tag, str): + s_tag = json.loads(request.structural_tag) + else: + s_tag = request.structural_tag + triggers: list[str] = s_tag["triggers"] + tags: list[llguidance.StructTag] = [] + for s in s_tag["structures"]: + begin: str = s["begin"] + trig = next((t for t in triggers if begin.startswith(t)), None) + if trig is None: + raise ValueError(f"Trigger {begin} not found in triggers {triggers}") + tags.append( + llguidance.StructTag( + trigger=trig, + begin=s["begin"], + grammar=_process_schema(s["schema"]), + end=s["end"], + ) + ) + if not tags: + raise ValueError("No structural tags found in the grammar spec.") + return llguidance.StructTag.to_grammar(tags) + + if request.guided_regex: + tp = "regex" + grammar_spec = request.guided_regex + elif request.guided_choice: + tp = "choice" + grammar_spec = request.guided_choice + elif request.guided_grammar: + tp = "grammar" + grammar_spec = request.guided_grammar + else: + llm_logger.error("Validation should have already occurred. " "Please file an issue.") + raise ValueError("grammar is not of valid supported types. ") + return llguidance.grammar_from(tp, grammar_spec) + + def schema_format(self, request: Request) -> Tuple[Request, Optional[str]]: + """ + Validates and formats the schema for the LLGuidance backend. + """ + try: + guidance_grm = self.serialize_guidance_grammar(request) + err = llguidance.LLMatcher.validate_grammar(guidance_grm, None) + if err: + raise ValueError(f"Grammar error: {err}") + else: + llm_logger.info(f"valid schema_format {guidance_grm} {request}") + if request.guided_regex: + request.guided_regex = guidance_grm + elif request.guided_choice: + request.guided_choice = guidance_grm + elif request.guided_grammar: + request.guided_grammar = guidance_grm + elif request.guided_json: + request.guided_json = guidance_grm + + except (ValueError, TypeError, json.JSONDecodeError) as e: + err_msg = f"Invalid format for guided decoding: {e!s} request={request}" + return request, err_msg + + except Exception as e: + err_msg = f"An unexpected error occurred during schema validation: {e!s}" + return request, err_msg + + return request, None diff --git a/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py b/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py index e0f195e47df..8ef67ef6abb 100644 --- a/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py +++ b/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py @@ -73,7 +73,6 @@ def __init__( enable_thinking: bool = False, ): super().__init__(enable_reasoning=enable_thinking) - self.max_rollback_tokens = 200 self.vocab_size = vocab_size self.batch_size = batch_size self.compiled_grammar = compiled_grammar @@ -82,7 +81,6 @@ def __init__( self.matcher = GrammarMatcher( compiled_grammar=compiled_grammar, - max_rollback_tokens=self.max_rollback_tokens, terminate_without_stop_token=terminate_without_stop_token, override_stop_tokens=override_stop_tokens, ) diff --git a/tests/layers/test_guidance_checker.py b/tests/layers/test_guidance_checker.py new file mode 100644 index 00000000000..286d72ed804 --- /dev/null +++ b/tests/layers/test_guidance_checker.py @@ -0,0 +1,562 @@ +import json +import unittest +from unittest.mock import MagicMock, patch + +import pytest + +from fastdeploy.model_executor.guided_decoding.guidance_backend import LLGuidanceChecker + +# 检查是否可以导入llguidance +HAS_LLGUIDANCE = False +try: + import llguidance + + llguidance + HAS_LLGUIDANCE = True +except ImportError: + pass + + +@pytest.fixture +def llguidance_checker(): + """返回一个LLGuidanceChecker实例供测试使用""" + return LLGuidanceChecker() + + +@pytest.fixture +def llguidance_checker_with_options(): + """返回一个配置了特定选项的LLGuidanceChecker实例""" + return LLGuidanceChecker(disable_any_whitespace=True) + + +def MockRequest(): + request = MagicMock() + request.guided_json = None + request.guided_json_object = None + request.structural_tag = None + request.guided_regex = None + request.guided_choice = None + request.guided_grammar = None + return request + + +class TestLLGuidanceCheckerMocked: + """使用Mock测试LLGuidanceChecker,适用于没有llguidance库的环境""" + + @patch("llguidance.LLMatcher.grammar_from_json_schema") + @patch("llguidance.LLMatcher.validate_grammar") + def test_serialize_guided_json_as_string(self, mock_validate, mock_from_schema, llguidance_checker): + """测试处理guided_json字符串类型""" + mock_from_schema.return_value = "serialized_grammar" + mock_validate.return_value = None + + request = MockRequest() + request.guided_json = '{"type": "object", "properties": {"name": {"type": "string"}}}' + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + mock_from_schema.assert_called_once() + assert grammar == "serialized_grammar" + + @patch("llguidance.LLMatcher.grammar_from_json_schema") + @patch("llguidance.LLMatcher.validate_grammar") + def test_serialize_guided_json_as_dict(self, mock_validate, mock_from_schema, llguidance_checker): + """测试处理guided_json字典类型""" + mock_from_schema.return_value = "serialized_grammar" + mock_validate.return_value = None + + request = MockRequest() + request.guided_json = {"type": "object", "properties": {"name": {"type": "string"}}} + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + mock_from_schema.assert_called_once() + assert isinstance(request.guided_json, dict) # 验证字典已转换为字符串 + assert grammar == "serialized_grammar" + + @patch("llguidance.LLMatcher.grammar_from_json_schema") + @patch("llguidance.LLMatcher.validate_grammar") + def test_serialize_guided_json_object(self, mock_validate, mock_from_schema, llguidance_checker): + """测试处理guided_json_object""" + mock_from_schema.return_value = "serialized_grammar" + mock_validate.return_value = None + + request = MockRequest() + request.guided_json_object = True + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + mock_from_schema.assert_called_once() + assert request.guided_json_object + assert grammar == "serialized_grammar" + + @patch("llguidance.grammar_from") + @patch("llguidance.LLMatcher.validate_grammar") + def test_serialize_guided_regex(self, mock_validate, mock_grammar_from, llguidance_checker): + """测试处理guided_regex""" + mock_grammar_from.return_value = "serialized_regex_grammar" + mock_validate.return_value = None + + request = MockRequest() + request.guided_regex = "[a-zA-Z]+" + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + mock_grammar_from.assert_called_once_with("regex", "[a-zA-Z]+") + assert grammar == "serialized_regex_grammar" + + @patch("llguidance.grammar_from") + @patch("llguidance.LLMatcher.validate_grammar") + def test_serialize_guided_choice(self, mock_validate, mock_grammar_from, llguidance_checker): + """测试处理guided_choice""" + mock_grammar_from.return_value = "serialized_choice_grammar" + mock_validate.return_value = None + + request = MockRequest() + request.guided_choice = ["option1", "option2"] + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + mock_grammar_from.assert_called_once_with("choice", ["option1", "option2"]) + assert grammar == "serialized_choice_grammar" + + @patch("llguidance.grammar_from") + @patch("llguidance.LLMatcher.validate_grammar") + def test_serialize_guided_grammar(self, mock_validate, mock_grammar_from, llguidance_checker): + """测试处理guided_grammar""" + mock_grammar_from.return_value = "serialized_grammar_spec" + mock_validate.return_value = None + + request = MockRequest() + request.guided_grammar = "grammar specification" + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + mock_grammar_from.assert_called_once_with("grammar", "grammar specification") + assert grammar == "serialized_grammar_spec" + + @patch("llguidance.StructTag") + @patch("llguidance.LLMatcher.grammar_from_json_schema") + def test_serialize_structural_tag(self, mock_from_schema, mock_struct_tag, llguidance_checker): + """测试处理structural_tag""" + # 配置mock对象 + mock_from_schema.return_value = "serialized_schema" + mock_struct_tag.to_grammar.return_value = "serialized_structural_grammar" + struct_tag_instance = MagicMock() + mock_struct_tag.return_value = struct_tag_instance + + request = MockRequest() + request.structural_tag = { + "triggers": [""], + "structures": [{"begin": "", "schema": {"type": "object"}, "end": ""}], + } + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + mock_from_schema.assert_called_once() + mock_struct_tag.assert_called_once() + mock_struct_tag.to_grammar.assert_called_once() + assert grammar == "serialized_structural_grammar" + + @patch("llguidance.StructTag") + def test_serialize_structural_tag_missing_trigger(self, mock_struct_tag, llguidance_checker): + """测试处理structural_tag中缺少触发器的情况""" + request = MockRequest() + request.structural_tag = { + "triggers": [""], + "structures": [{"begin": "", "schema": {"type": "object"}, "end": ""}], + } + + with pytest.raises(ValueError, match="Trigger .* not found in triggers"): + llguidance_checker.serialize_guidance_grammar(request) + + @patch("llguidance.StructTag") + def test_serialize_structural_tag_empty_structures(self, mock_struct_tag, llguidance_checker): + """测试处理structural_tag中结构为空的情况""" + request = MockRequest() + request.structural_tag = {"triggers": [""], "structures": []} + + with pytest.raises(ValueError, match="No structural tags found in the grammar spec"): + llguidance_checker.serialize_guidance_grammar(request) + + def test_serialize_invalid_grammar_type(self, llguidance_checker): + """测试处理无效的语法类型""" + request = MockRequest() + # 没有设置任何语法类型 + + with pytest.raises(ValueError, match="grammar is not of valid supported types"): + llguidance_checker.serialize_guidance_grammar(request) + + @patch("llguidance.LLMatcher.grammar_from_json_schema") + @patch("llguidance.LLMatcher.validate_grammar") + def test_schema_format_valid_json(self, mock_validate, mock_from_schema, llguidance_checker): + """测试schema_format方法处理有效的JSON""" + mock_from_schema.return_value = "serialized_grammar" + mock_validate.return_value = None + + request = MockRequest() + request.guided_json = '{"type": "object"}' + + result_request, error = llguidance_checker.schema_format(request) + + assert error is None + assert result_request is request + + @patch("llguidance.LLMatcher.grammar_from_json_schema") + @patch("llguidance.LLMatcher.validate_grammar") + def test_schema_format_invalid_grammar(self, mock_validate, mock_from_schema, llguidance_checker): + """测试schema_format方法处理无效的语法""" + mock_from_schema.return_value = "serialized_grammar" + mock_validate.return_value = "Invalid grammar" + + request = MockRequest() + request.guided_json = '{"type": "object"}' + + result_request, error = llguidance_checker.schema_format(request) + + assert error is not None + assert "Grammar error: Invalid grammar" in error + + @patch("llguidance.LLMatcher.grammar_from_json_schema") + def test_schema_format_json_decode_error(self, mock_from_schema, llguidance_checker): + """测试schema_format方法处理JSON解码错误""" + mock_from_schema.side_effect = json.JSONDecodeError("Invalid JSON", "", 0) + + request = MockRequest() + request.guided_json = "{invalid json}" + + result_request, error = llguidance_checker.schema_format(request) + + assert error is not None + assert "Invalid format for guided decoding" in error + + @patch("llguidance.LLMatcher.grammar_from_json_schema") + def test_schema_format_unexpected_error(self, mock_from_schema, llguidance_checker): + """测试schema_format方法处理意外错误""" + mock_from_schema.side_effect = Exception("Unexpected error") + + request = MockRequest() + request.guided_json = '{"type": "object"}' + + result_request, error = llguidance_checker.schema_format(request) + + assert error is not None + assert "An unexpected error occurred during schema validation" in error + + def test_init_with_disable_whitespace(self, llguidance_checker_with_options): + """测试初始化时设置disable_any_whitespace选项""" + assert llguidance_checker_with_options.any_whitespace is False + assert llguidance_checker_with_options.disable_additional_properties is True + assert LLGuidanceChecker(disable_any_whitespace=True).any_whitespace is False + assert LLGuidanceChecker(disable_any_whitespace=False).any_whitespace is True + + # default check + from fastdeploy.envs import FD_GUIDANCE_DISABLE_ADDITIONAL + + assert FD_GUIDANCE_DISABLE_ADDITIONAL + + assert LLGuidanceChecker().disable_additional_properties is True + with patch("fastdeploy.model_executor.guided_decoding.guidance_backend.FD_GUIDANCE_DISABLE_ADDITIONAL", False): + assert LLGuidanceChecker().disable_additional_properties is False + + +@pytest.mark.skipif(not HAS_LLGUIDANCE, reason="llguidance库未安装,跳过实际依赖测试") +class TestLLGuidanceCheckerReal: + """使用实际的llguidance库进行测试,适用于开发环境""" + + def test_serialize_guided_json_string_real(self, llguidance_checker): + """使用实际库测试处理guided_json字符串""" + request = MockRequest() + request.guided_json = '{"type": "object", "properties": {"name": {"type": "string"}}}' + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + # 验证返回的grammar是否是一个有效的字符串 + assert isinstance(grammar, str) + assert len(grammar) > 0 + print("grammar", grammar) + + def test_serialize_guided_json_dict_real(self, llguidance_checker): + """使用实际库测试处理guided_json字典""" + request = MockRequest() + request.guided_json = {"type": "object", "properties": {"name": {"type": "string"}}} + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + assert isinstance(request.guided_json, dict) + assert isinstance(grammar, str) + assert len(grammar) > 0 + + def test_serialize_guided_json_object_real(self, llguidance_checker): + """使用实际库测试处理guided_json_object""" + request = MockRequest() + request.guided_json_object = True + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + assert request.guided_json_object + assert isinstance(grammar, str) + assert len(grammar) > 0 + + def test_serialize_guided_regex_real(self, llguidance_checker): + """使用实际库测试处理guided_regex""" + request = MockRequest() + request.guided_regex = "[a-zA-Z]+" + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + assert isinstance(grammar, str) + assert len(grammar) > 0 + + def test_serialize_guided_choice_real(self, llguidance_checker): + """使用实际库测试处理guided_choice""" + request = MockRequest() + request.guided_choice = ["option1", "option2"] + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + assert isinstance(grammar, str) + assert len(grammar) > 0 + + def test_serialize_guided_grammar_real(self, llguidance_checker): + """使用实际库测试处理guided_grammar""" + request = MockRequest() + # 使用简单的CFG文法示例 + request.guided_grammar = """ + root ::= greeting name + greeting ::= "Hello" | "Hi" + name ::= "world" | "everyone" + """ + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + assert isinstance(grammar, str) + assert len(grammar) > 0 + + def test_serialize_structural_tag_real(self, llguidance_checker): + """使用实际库测试处理structural_tag""" + request = MockRequest() + request.structural_tag = { + "triggers": [""], + "structures": [{"begin": "", "schema": {"type": "object"}, "end": ""}], + } + + grammar = llguidance_checker.serialize_guidance_grammar(request) + + assert isinstance(grammar, str) + assert len(grammar) > 0 + + def test_schema_format_valid_json_real(self, llguidance_checker): + """使用实际库测试schema_format方法处理有效的JSON""" + request = MockRequest() + request.guided_json = '{"type": "object", "properties": {"name": {"type": "string"}}}' + + result_request, error = llguidance_checker.schema_format(request) + + assert error is None + assert result_request is request + assert result_request.guided_json != '{"type": "object", "properties": {"name": {"type": "string"}}}' + + def test_schema_format_invalid_json_real(self, llguidance_checker): + """使用实际库测试schema_format方法处理无效的JSON""" + request = MockRequest() + request.guided_json = "{invalid json}" + + result_request, error = llguidance_checker.schema_format(request) + + assert error is not None + assert "Invalid format for guided decoding" in error + + def test_whitespace_flexibility_option_real(self): + """使用实际库测试whitespace灵活性选项的影响""" + # 创建两个不同配置的实例 + flexible = LLGuidanceChecker(disable_any_whitespace=False) + strict = LLGuidanceChecker(disable_any_whitespace=True) + + request_flexible = MockRequest() + request_flexible.guided_json = '{"type": "object"}' + + request_strict = MockRequest() + request_strict.guided_json = '{"type": "object"}' + + grammar_flexible = flexible.serialize_guidance_grammar(request_flexible) + grammar_strict = strict.serialize_guidance_grammar(request_strict) + print("grammar_flexible", grammar_flexible) + print("grammar_strict", grammar_strict) + + # 预期两种配置生成的语法应该不同 + assert grammar_flexible != grammar_strict + + def test_schema_format_guided_json_object_real(self, llguidance_checker): + """测试schema_format处理guided_json_object""" + request = MockRequest() + request.guided_json_object = True + + result_request, error = llguidance_checker.schema_format(request) + + assert error is None + assert result_request is request + + def test_schema_format_guided_regex_real(self, llguidance_checker): + """测试schema_format处理有效的正则表达式""" + request = MockRequest() + request.guided_regex = r"[a-zA-Z0-9]+" + + result_request, error = llguidance_checker.schema_format(request) + + assert error is None + assert result_request is request + assert result_request.guided_regex != r"[a-zA-Z0-9]+" # 应该被转换为grammar格式 + + def test_schema_format_invalid_guided_regex_real(self, llguidance_checker): + """测试schema_format处理无效的正则表达式""" + request = MockRequest() + request.guided_regex = r"[" # 无效的正则表达式 + + result_request, error = llguidance_checker.schema_format(request) + + assert error is not None + assert "Invalid format for guided decoding" in error + + def test_schema_format_guided_choice_real(self, llguidance_checker): + """测试schema_format处理guided_choice""" + request = MockRequest() + request.guided_choice = ["option1", "option2", "option3"] + + result_request, error = llguidance_checker.schema_format(request) + + assert error is None + assert result_request is request + assert result_request.guided_choice != ["option1", "option2", "option3"] # 应该被转换为grammar格式 + + def test_schema_format_guided_grammar_real(self, llguidance_checker): + """测试schema_format处理guided_grammar""" + request = MockRequest() + # 使用LLGuidance支持的正确语法格式 + request.guided_grammar = """ + start: number + number: DIGIT+ + DIGIT: "0"|"1"|"2"|"3"|"4"|"5"|"6"|"7"|"8"|"9" + """ + + result_request, error = llguidance_checker.schema_format(request) + + assert error is None + assert result_request is request + assert isinstance(result_request.guided_grammar, str) + + def test_schema_format_structural_tag_real(self, llguidance_checker): + """测试schema_format处理structural_tag""" + request = MockRequest() + request.structural_tag = { + "triggers": ["```json"], + "structures": [ + { + "begin": "```json", + "schema": {"type": "object", "properties": {"name": {"type": "string"}}}, + "end": "```", + } + ], + } + + result_request, error = llguidance_checker.schema_format(request) + + assert error is None + assert result_request is request + + def test_schema_format_structural_tag_string_real(self, llguidance_checker): + """测试schema_format处理字符串形式的structural_tag""" + request = MockRequest() + request.structural_tag = json.dumps( + { + "triggers": ["```json"], + "structures": [ + { + "begin": "```json", + "schema": {"type": "object", "properties": {"name": {"type": "string"}}}, + "end": "```", + } + ], + } + ) + + result_request, error = llguidance_checker.schema_format(request) + + assert error is None + assert result_request is request + + def test_schema_format_structural_tag_invalid_trigger_real(self, llguidance_checker): + """测试schema_format处理trigger无效的structural_tag""" + request = MockRequest() + request.structural_tag = { + "triggers": ["```xml"], # 触发器与begin不匹配 + "structures": [ + {"begin": "```json", "schema": {"type": "object"}, "end": "```"} # 这里不包含任何triggers中的前缀 + ], + } + + result_request, error = llguidance_checker.schema_format(request) + + assert error is not None + assert "Invalid format for guided decoding" in error + + def test_schema_format_structural_tag_empty_structures_real(self, llguidance_checker): + """测试schema_format处理空structures的structural_tag""" + request = MockRequest() + request.structural_tag = {"triggers": ["```json"], "structures": []} # 空结构 + + result_request, error = llguidance_checker.schema_format(request) + + assert error is not None + assert "Invalid format for guided decoding" in error + + def test_schema_format_json_dict_real(self, llguidance_checker): + """测试schema_format处理字典形式的guided_json""" + request = MockRequest() + request.guided_json = {"type": "object", "properties": {"name": {"type": "string"}}} + + result_request, error = llguidance_checker.schema_format(request) + + assert error is None + assert result_request is request + + def test_schema_format_disable_additional_properties_real(self): + """测试schema_format处理disable_additional_properties参数""" + checker = LLGuidanceChecker(disable_additional_properties=True) + request = MockRequest() + request.guided_json = {"type": "object", "properties": {"name": {"type": "string"}}} + + result_request, error = checker.schema_format(request) + + assert error is None + assert result_request is request + + def test_schema_format_unexpected_error_real(self, monkeypatch, llguidance_checker): + """测试schema_format处理意外错误""" + request = MockRequest() + request.guided_json = '{"type": "object"}' + + # 模拟意外异常 + def mock_serialize_grammar(*args, **kwargs): + raise Exception("Unexpected error") + + monkeypatch.setattr(llguidance_checker, "serialize_guidance_grammar", mock_serialize_grammar) + + result_request, error = llguidance_checker.schema_format(request) + + assert error is not None + assert "An unexpected error occurred during schema validation" in error + + def test_schema_format_no_valid_grammar_real(self, llguidance_checker): + """测试schema_format处理没有有效语法的请求""" + request = MockRequest() + # 没有设置任何语法相关的属性 + + with pytest.raises(ValueError, match="grammar is not of valid supported types"): + llguidance_checker.serialize_guidance_grammar(request) + result_request, error = llguidance_checker.schema_format(request) + assert error is not None + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/layers/test_guidance_processor.py b/tests/layers/test_guidance_processor.py new file mode 100644 index 00000000000..2ec96c67a00 --- /dev/null +++ b/tests/layers/test_guidance_processor.py @@ -0,0 +1,172 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import sys +import unittest +from unittest.mock import MagicMock, patch + +# --- Mocking Setup --- +# 优先模拟这些懒加载的模块,以便在未安装这些库的环境中进行测试。 +mock_torch = MagicMock() +mock_llguidance = MagicMock() +mock_llguidance_hf = MagicMock() +mock_llguidance_torch = MagicMock() + +mock_torch.__spec__ = MagicMock() +mock_torch.distributed = MagicMock() + +sys.modules["torch"] = mock_torch +sys.modules["llguidance"] = mock_llguidance +sys.modules["llguidance.hf"] = mock_llguidance_hf +sys.modules["llguidance.torch"] = mock_llguidance_torch + +# 模拟设置完成后,再导入需要测试的模块 +from fastdeploy.model_executor.guided_decoding.guidance_backend import ( + LLGuidanceProcessor, +) + + +def MockFDConfig(): + """创建一个用于测试的FDConfig模拟对象""" + config = MagicMock() + # --- 修复点 1: 显式设置 model 为字符串,通过 HF 的验证 --- + config.model_config.model = "test-model-path" + config.model_config.architectures = [] # 设置为空列表,防止迭代 Mock 出错 + + config.model_config.vocab_size = 1000 + config.scheduler_config.max_num_seqs = 4 + config.structured_outputs_config.disable_any_whitespace = False + # 确保 backend 检查逻辑能通过 + config.structured_outputs_config.guided_decoding_backend = "guidance" + return config + + +def MockHFTokenizer(): + """创建一个用于测试的Hugging Face Tokenizer模拟对象""" + return MagicMock() + + +class TestLLGuidanceProcessorMocked(unittest.TestCase): + """ + 使用Mock对LLGuidanceProcessor进行单元测试。 + 这个测试类适用于没有安装llguidance库的环境。 + """ + + def setUp(self): + """为每个测试用例设置一个新的LLGuidanceProcessor实例""" + self.mock_matcher = MagicMock() + self.mock_tokenizer = MagicMock() + self.mock_tokenizer.eos_token = 2 # 示例EOS token ID + self.processor = LLGuidanceProcessor( + ll_matcher=self.mock_matcher, + ll_tokenizer=self.mock_tokenizer, + serialized_grammar="test_grammar", + vocab_size=1000, + batch_size=4, + enable_thinking=False, + ) + + def test_init(self): + """测试LLGuidanceProcessor的构造函数""" + self.assertIs(self.processor.matcher, self.mock_matcher) + self.assertEqual(self.processor.vocab_size, 1000) + self.assertEqual(self.processor.batch_size, 4) + self.assertFalse(self.processor.is_terminated) + + @patch("fastdeploy.utils.llm_logger.warning") + def test_check_error_logs_warning_once(self, mock_log_warning): + """测试_check_error方法在匹配器出错时能记录警告,且只记录一次""" + self.mock_matcher.get_error.return_value = "A test error." + + # 第一次调用,应该打印日志 + self.processor._check_error() + mock_log_warning.assert_called_once_with("LLGuidance Matcher error: A test error.") + + # 第二次调用,不应该重复打印 + self.processor._check_error() + mock_log_warning.assert_called_once() + + @patch("fastdeploy.model_executor.guided_decoding.guidance_backend.llguidance_torch") + def test_allocate_token_bitmask(self, mock_backend_torch): + """ + 测试token bitmask的分配。 + 注意:这里Patch的是guidance_backend模块中导入的llguidance_torch变量, + 而不是sys.modules里的全局mock,以解决LazyLoader导致的引用不一致问题。 + """ + mock_backend_torch.allocate_token_bitmask.return_value = "fake_bitmask_tensor" + + result = self.processor.allocate_token_bitmask() + + mock_backend_torch.allocate_token_bitmask.assert_called_once_with(4, 1000) + self.assertEqual(result, "fake_bitmask_tensor") + + @patch("fastdeploy.model_executor.guided_decoding.guidance_backend.llguidance_torch") + def test_fill_token_bitmask(self, mock_backend_torch): + """测试token bitmask的填充""" + mock_bitmask = MagicMock() + + self.processor.fill_token_bitmask(mock_bitmask, idx=2) + + mock_backend_torch.fill_next_token_bitmask.assert_called_once_with(self.mock_matcher, mock_bitmask, 2) + self.mock_matcher.get_error.assert_called_once() + + def test_reset(self): + """测试处理器的状态重置""" + self.processor.is_terminated = True + self.processor._printed_error = True + self.mock_matcher.get_error.return_value = "" + + self.processor.reset() + + self.mock_matcher.reset.assert_called_once() + self.assertFalse(self.processor.is_terminated) + self.assertFalse(self.processor._printed_error) + + def test_accept_token_when_terminated(self): + """测试当状态为is_terminated时,accept_token直接返回False""" + self.processor.is_terminated = True + self.assertFalse(self.processor.accept_token(123)) + + def test_accept_token_when_matcher_stopped(self): + """测试当匹配器停止时,accept_token返回False并更新状态""" + self.mock_matcher.is_stopped.return_value = True + self.assertFalse(self.processor.accept_token(123)) + self.assertTrue(self.processor.is_terminated) + + def test_accept_token_is_eos(self): + """测试接收到EOS token时的行为""" + self.mock_matcher.is_stopped.return_value = False + self.assertTrue(self.processor.accept_token(self.mock_tokenizer.eos_token)) + self.assertTrue(self.processor.is_terminated) + + def test_accept_token_consumes_and_succeeds(self): + """测试成功消费一个token""" + self.mock_matcher.is_stopped.return_value = False + self.mock_matcher.consume_tokens.return_value = True + self.assertTrue(self.processor.accept_token(123)) + self.mock_matcher.consume_tokens.assert_called_once_with([123]) + self.mock_matcher.get_error.assert_called_once() + + def test_accept_token_consumes_and_fails(self): + """测试消费一个token失败""" + self.mock_matcher.is_stopped.return_value = False + self.mock_matcher.consume_tokens.return_value = False + self.assertFalse(self.processor.accept_token(123)) + self.mock_matcher.consume_tokens.assert_called_once_with([123]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/model_executor/test_gidance_backend.py b/tests/model_executor/test_gidance_backend.py new file mode 100644 index 00000000000..c7af2a1be50 --- /dev/null +++ b/tests/model_executor/test_gidance_backend.py @@ -0,0 +1,213 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import unittest +from unittest.mock import MagicMock, patch + +from fastdeploy.model_executor.guided_decoding import BackendBase + +# 导入要测试的模块 +from fastdeploy.model_executor.guided_decoding.guidance_backend import ( + LLGuidanceBackend, + LLGuidanceChecker, + LLGuidanceProcessor, + process_for_additional_properties, +) + + +class TestProcessForAdditionalProperties(unittest.TestCase): + def test_process_json_string(self): + # 测试字符串输入 + json_str = '{"type": "object", "properties": {"name": {"type": "string"}}}' + result = process_for_additional_properties(json_str) + self.assertFalse(result["additionalProperties"]) + + def test_process_json_dict(self): + # 测试字典输入 + json_dict = {"type": "object", "properties": {"name": {"type": "string"}}} + result = process_for_additional_properties(json_dict) + self.assertFalse(result["additionalProperties"]) + # 确保原始字典没有被修改 + self.assertNotIn("additionalProperties", json_dict) + + def test_nested_objects(self): + # 测试嵌套对象 + json_dict = { + "type": "object", + "properties": {"person": {"type": "object", "properties": {"name": {"type": "string"}}}}, + } + result = process_for_additional_properties(json_dict) + self.assertFalse(result["additionalProperties"]) + self.assertFalse(result["properties"]["person"]["additionalProperties"]) + + +@patch("llguidance.LLMatcher") +@patch("llguidance.LLTokenizer") +class TestLLGuidanceProcessor(unittest.TestCase): + def setUp(self): + self.vocab_size = 100 + self.batch_size = 2 + + def test_initialization(self, mock_tokenizer, mock_matcher): + # 测试初始化 + processor = LLGuidanceProcessor( + ll_matcher=mock_matcher, + ll_tokenizer=mock_tokenizer, + serialized_grammar="test_grammar", + vocab_size=self.vocab_size, + batch_size=self.batch_size, + ) + + self.assertEqual(processor.vocab_size, self.vocab_size) + self.assertEqual(processor.batch_size, self.batch_size) + self.assertFalse(processor.is_terminated) + + def test_reset(self, mock_tokenizer, mock_matcher): + # 测试重置功能 + processor = LLGuidanceProcessor( + ll_matcher=mock_matcher, + ll_tokenizer=mock_tokenizer, + serialized_grammar="test_grammar", + vocab_size=self.vocab_size, + batch_size=self.batch_size, + ) + + processor.is_terminated = True + processor.reset() + + mock_matcher.reset.assert_called_once() + self.assertFalse(processor.is_terminated) + + def test_accept_token(self, mock_tokenizer, mock_matcher): + # 测试接受token功能 + mock_matcher.is_stopped.return_value = False + mock_matcher.consume_tokens.return_value = True + mock_tokenizer.eos_token = 1 + + processor = LLGuidanceProcessor( + ll_matcher=mock_matcher, + ll_tokenizer=mock_tokenizer, + serialized_grammar="test_grammar", + vocab_size=self.vocab_size, + batch_size=self.batch_size, + ) + + # 正常token + result = processor.accept_token(0) + self.assertTrue(result) + mock_matcher.consume_tokens.assert_called_with([0]) + + # EOS token + result = processor.accept_token(1) + self.assertTrue(result) + self.assertTrue(processor.is_terminated) + + +@patch("llguidance.LLMatcher") +@patch("llguidance.hf.from_tokenizer") +class TestLLGuidanceBackend(unittest.TestCase): + def setUp(self): + # 创建一个模拟的FDConfig + self.fd_config = MagicMock() + self.fd_config.model_config.vocab_size = 100 + self.fd_config.scheduler_config.max_num_seqs = 2 + self.fd_config.structured_outputs_config.disable_any_whitespace = False + self.fd_config.structured_outputs_config.disable_additional_properties = False + + def test_initialization(self, mock_from_tokenizer, mock_matcher): + # 测试后端初始化 + with patch.object(BackendBase, "__init__", return_value=None): + backend = LLGuidanceBackend(fd_config=self.fd_config) + + self.assertEqual(backend.vocab_size, 100) + self.assertEqual(backend.batch_size, 2) + self.assertTrue(backend.any_whitespace) + self.assertFalse(backend.disable_additional_properties) + + @patch("llguidance.LLMatcher") + def test_create_processor(self, mock_matcher_class, mock_from_tokenizer, mock_matcher): + # 测试创建处理器 + with patch.object(LLGuidanceBackend, "__init__", return_value=None): + backend = LLGuidanceBackend(fd_config=None) # 参数不重要,因为 __init__ 被模拟了 + + # 手动设置所有需要的属性 + backend.hf_tokenizer = MagicMock() + backend.ll_tokenizer = MagicMock() + backend.vocab_size = 100 + backend.batch_size = 2 + backend.any_whitespace = True + backend.disable_additional_properties = False + + mock_matcher = MagicMock() + mock_matcher_class.return_value = mock_matcher + + processor = backend._create_processor("test_grammar") + + self.assertIsInstance(processor, LLGuidanceProcessor) + self.assertEqual(processor.vocab_size, 100) + self.assertEqual(processor.batch_size, 2) + + +@patch("llguidance.LLMatcher") +class TestLLGuidanceChecker(unittest.TestCase): + def test_schema_format_valid_json(self, mock_matcher): + # 设置mock + mock_matcher.grammar_from_json_schema.return_value = "compiled_grammar" + mock_matcher.validate_grammar.return_value = None + + # 创建checker和请求 + checker = LLGuidanceChecker() + request = MagicMock() + request.guided_json = '{"type": "object", "properties": {"name": {"type": "string"}}}' + + # 测试有效的JSON schema + result, error = checker.schema_format(request) + self.assertIsNone(error) + self.assertEqual(result.guided_json, '{"type": "object", "properties": {"name": {"type": "string"}}}') + + def test_schema_format_valid_regex(self, mock_matcher): + # 设置mock + mock_matcher.validate_grammar.return_value = None + + # 模拟llguidance.grammar_from + with patch("llguidance.grammar_from", return_value="compiled_regex"): + # 创建checker和请求 + checker = LLGuidanceChecker() + request = MagicMock() + request.guided_regex = "[a-z]+" + + # 测试有效的regex + result, error = checker.schema_format(request) + self.assertIsNone(error) + self.assertEqual(result.guided_regex, "compiled_regex") + + def test_schema_format_invalid(self, mock_matcher): + # 设置mock使验证失败 + mock_matcher.grammar_from_json_schema.side_effect = ValueError("Invalid schema") + + # 创建checker和请求 + checker = LLGuidanceChecker() + request = MagicMock() + request.guided_json = '{"invalid": "schema"}' + + # 测试无效的schema + result, error = checker.schema_format(request) + self.assertIsNotNone(error) + self.assertIn("Invalid format for guided decoding", error) + + +if __name__ == "__main__": + unittest.main() From 09fbf690866e88ec3b09bc40e8c9810217c672f4 Mon Sep 17 00:00:00 2001 From: ST-XX <15625257+ST-XX@users.noreply.github.com> Date: Wed, 19 Nov 2025 14:33:29 +0800 Subject: [PATCH 2/6] add requirements_guided_decoding.txt and doc --- docs/features/structured_outputs.md | 1 + docs/parameters.md | 2 +- docs/zh/features/structured_outputs.md | 1 + docs/zh/parameters.md | 2 +- requirements_guided_decoding.txt | 3 +++ 5 files changed, 7 insertions(+), 2 deletions(-) create mode 100644 requirements_guided_decoding.txt diff --git a/docs/features/structured_outputs.md b/docs/features/structured_outputs.md index 7c81bf3e188..f58a6b1393f 100644 --- a/docs/features/structured_outputs.md +++ b/docs/features/structured_outputs.md @@ -7,6 +7,7 @@ Structured Outputs refer to predefined format constraints that force large language models to generate content strictly following specified structures. This feature significantly improves output controllability and is suitable for scenarios requiring precise format outputs (such as API calls, data parsing, code generation, etc.), while supporting dynamic grammar extensions to balance flexibility and standardization. FastDeploy supports using the [XGrammar](https://xgrammar.mlc.ai/docs/) backend to generate structured outputs. +FastDeploy supports using the [LLguidance](https://github.com/guidance-ai/llguidance) backend to generate structured outputs. Supported output formats: diff --git a/docs/parameters.md b/docs/parameters.md index c26e471625c..3a104f65482 100644 --- a/docs/parameters.md +++ b/docs/parameters.md @@ -44,7 +44,7 @@ When using FastDeploy to deploy models (including offline inference and service | ```disable_sequence_parallel_moe``` | `bool` | Disable sequence parallel moe, default: False | | ```splitwise_role``` | `str` | Whether to enable splitwise inference, default value: mixed, supported parameters: ["mixed", "decode", "prefill"] | | ```innode_prefill_ports``` | `str` | Internal engine startup ports for prefill instances (only required for single-machine PD separation), default: None | -| ```guided_decoding_backend``` | `str` | Specify the guided decoding backend to use, supports `auto`, `xgrammar`, `off`, default: `off` | +| ```guided_decoding_backend``` | `str` | Specify the guided decoding backend to use, supports `auto`, `xgrammar`, `guidance`, `off`, default: `off` | | ```guided_decoding_disable_any_whitespace``` | `bool` | Whether to disable whitespace generation during guided decoding, default: False | | ```speculative_config``` | `dict[str]` | Speculative decoding configuration, only supports standard format JSON string, default: None | | ```dynamic_load_weight``` | `int` | Whether to enable dynamic weight loading, default: 0 | diff --git a/docs/zh/features/structured_outputs.md b/docs/zh/features/structured_outputs.md index 50c010c1498..f7503c3a6e0 100644 --- a/docs/zh/features/structured_outputs.md +++ b/docs/zh/features/structured_outputs.md @@ -7,6 +7,7 @@ Structured Outputs 是指通过预定义格式约束,使大模型生成内容严格遵循指定结构。该功能可显著提升生成结果的可控性,适用于需要精确格式输出的场景(如API调用、数据解析、代码生成等),同时支持动态语法扩展,平衡灵活性与规范性。 FastDeploy 支持使用 [XGrammar](https://xgrammar.mlc.ai/docs/) 后端生成结构化输出。 +FastDeploy 支持使用 [LLguidance](https://github.com/guidance-ai/llguidance) 后端生成结构化输出。 支持输出格式 diff --git a/docs/zh/parameters.md b/docs/zh/parameters.md index 841441ab2d5..6d745d668c2 100644 --- a/docs/zh/parameters.md +++ b/docs/zh/parameters.md @@ -42,7 +42,7 @@ | ```disable_sequence_parallel_moe``` | `bool` | 禁止在TP+EP中使用序列并行优化, default: False | | ```splitwise_role``` | `str` | 是否开启splitwise推理,默认值mixed, 支持参数为["mixed", "decode", "prefill"] | | ```innode_prefill_ports``` | `str` | prefill 实例内部引擎启动端口 (仅单机PD分离需要),默认值None | -| ```guided_decoding_backend``` | `str` | 指定要使用的guided decoding后端,支持 `auto`、`xgrammar`、`off`, 默认为 `off` | +| ```guided_decoding_backend``` | `str` | 指定要使用的guided decoding后端,支持 `auto`、`xgrammar`、 `guidance`、`off`, 默认为 `off` | | ```guided_decoding_disable_any_whitespace``` | `bool` | guided decoding期间是否禁止生成空格,默认False | | ```speculative_config``` | `dict[str]` | 投机解码配置,仅支持标准格式json字符串,默认为None | | ```dynamic_load_weight``` | `int` | 是否动态加载权重,默认0 | diff --git a/requirements_guided_decoding.txt b/requirements_guided_decoding.txt new file mode 100644 index 00000000000..627ecc25e1a --- /dev/null +++ b/requirements_guided_decoding.txt @@ -0,0 +1,3 @@ +xgrammar==0.1.25 +llguidance==1.3.0 +torch==2.8.0 From 4dde8e7049f3bbaf5c2e57d2af8dad48978b2706 Mon Sep 17 00:00:00 2001 From: ST-XX <15625257+ST-XX@users.noreply.github.com> Date: Thu, 20 Nov 2025 11:00:45 +0800 Subject: [PATCH 3/6] fix test_guidance_*.py --- fastdeploy/config.py | 23 +++++++++++---------- tests/layers/test_guidance_checker.py | 27 ++++++++++++++++++++++--- tests/layers/test_guidance_processor.py | 4 ++-- 3 files changed, 38 insertions(+), 16 deletions(-) diff --git a/fastdeploy/config.py b/fastdeploy/config.py index e9d537c190b..6473a5cc392 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1601,18 +1601,19 @@ def postprocess(self): elif self.structured_outputs_config.guided_decoding_backend in ["auto", "xgrammar"]: self.structured_outputs_config.guided_decoding_backend = "xgrammar" elif self.structured_outputs_config.guided_decoding_backend == "guidance": - pass - # try: - # import llguidance - # import llguidance.hf - # import llguidance.torch - # except ImportError: - # raise ImportError( - # "The 'xgrammar' package is required for using xgrammar as the guided decoding backend. " - # "Please install it with 'pip install xgrammar'." - # ) + try: + import llguidance.torch + + llguidance.torch + except ImportError: + raise ImportError( + "The 'llguidance' package is required for using guidance as the guided decoding backend. " + "Please install it via the appropriate method." + ) else: - raise NotImplementedError(f"{self.structured_outputs_config.guided_decoding_backend}") + raise NotImplementedError( + f"Guided decoding backend '{self.structured_outputs_config.guided_decoding_backend}' is not implemented. [auto, xgrammar, guidance, off]" + ) if self.model_config.enable_mm: if self.cache_config.max_encoder_cache is None or self.cache_config.max_encoder_cache < 0: diff --git a/tests/layers/test_guidance_checker.py b/tests/layers/test_guidance_checker.py index 286d72ed804..7ea58d2a267 100644 --- a/tests/layers/test_guidance_checker.py +++ b/tests/layers/test_guidance_checker.py @@ -1,11 +1,26 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + import json +import sys import unittest from unittest.mock import MagicMock, patch import pytest -from fastdeploy.model_executor.guided_decoding.guidance_backend import LLGuidanceChecker - # 检查是否可以导入llguidance HAS_LLGUIDANCE = False try: @@ -14,7 +29,10 @@ llguidance HAS_LLGUIDANCE = True except ImportError: - pass + mock_llguidance = MagicMock() + mock_torch = MagicMock() + sys.modules["llguidance"] = mock_llguidance + sys.modules["torch"] = mock_torch @pytest.fixture @@ -29,6 +47,9 @@ def llguidance_checker_with_options(): return LLGuidanceChecker(disable_any_whitespace=True) +from fastdeploy.model_executor.guided_decoding.guidance_backend import LLGuidanceChecker + + def MockRequest(): request = MagicMock() request.guided_json = None diff --git a/tests/layers/test_guidance_processor.py b/tests/layers/test_guidance_processor.py index 2ec96c67a00..5cd61249a0f 100644 --- a/tests/layers/test_guidance_processor.py +++ b/tests/layers/test_guidance_processor.py @@ -143,8 +143,8 @@ def test_accept_token_when_terminated(self): def test_accept_token_when_matcher_stopped(self): """测试当匹配器停止时,accept_token返回False并更新状态""" self.mock_matcher.is_stopped.return_value = True - self.assertFalse(self.processor.accept_token(123)) - self.assertTrue(self.processor.is_terminated) + self.assertTrue(self.processor.accept_token(123)) + self.assertFalse(self.processor.is_terminated) def test_accept_token_is_eos(self): """测试接收到EOS token时的行为""" From fea0e38428d22119c4563640af99245974e89041 Mon Sep 17 00:00:00 2001 From: ST-XX <15625257+ST-XX@users.noreply.github.com> Date: Thu, 20 Nov 2025 19:02:58 +0800 Subject: [PATCH 4/6] fix test_guidance_*.py && mv --- .../test_guidance_backend.py} | 62 ++++--------------- .../guided_decoding}/test_guidance_checker.py | 0 .../test_guidance_processor.py | 0 3 files changed, 11 insertions(+), 51 deletions(-) rename tests/model_executor/{test_gidance_backend.py => guided_decoding/test_guidance_backend.py} (74%) rename tests/{layers => model_executor/guided_decoding}/test_guidance_checker.py (100%) rename tests/{layers => model_executor/guided_decoding}/test_guidance_processor.py (100%) diff --git a/tests/model_executor/test_gidance_backend.py b/tests/model_executor/guided_decoding/test_guidance_backend.py similarity index 74% rename from tests/model_executor/test_gidance_backend.py rename to tests/model_executor/guided_decoding/test_guidance_backend.py index c7af2a1be50..134e6901ff2 100644 --- a/tests/model_executor/test_gidance_backend.py +++ b/tests/model_executor/guided_decoding/test_guidance_backend.py @@ -14,15 +14,22 @@ # limitations under the License. """ +import sys import unittest from unittest.mock import MagicMock, patch from fastdeploy.model_executor.guided_decoding import BackendBase +mock_llguidance = MagicMock() +mock_llguidancehf = MagicMock() +mock_torch = MagicMock() +sys.modules["llguidance"] = mock_llguidance +sys.modules["llguidance.hf"] = mock_llguidancehf +sys.modules["torch"] = mock_torch + # 导入要测试的模块 from fastdeploy.model_executor.guided_decoding.guidance_backend import ( LLGuidanceBackend, - LLGuidanceChecker, LLGuidanceProcessor, process_for_additional_properties, ) @@ -126,16 +133,17 @@ def setUp(self): self.fd_config.scheduler_config.max_num_seqs = 2 self.fd_config.structured_outputs_config.disable_any_whitespace = False self.fd_config.structured_outputs_config.disable_additional_properties = False + self.fd_config.structured_outputs_config.reasoning_parser = None def test_initialization(self, mock_from_tokenizer, mock_matcher): # 测试后端初始化 - with patch.object(BackendBase, "__init__", return_value=None): + mock_tokenizer = MagicMock() + with patch.object(BackendBase, "_get_tokenizer_hf", return_value=mock_tokenizer): backend = LLGuidanceBackend(fd_config=self.fd_config) self.assertEqual(backend.vocab_size, 100) self.assertEqual(backend.batch_size, 2) self.assertTrue(backend.any_whitespace) - self.assertFalse(backend.disable_additional_properties) @patch("llguidance.LLMatcher") def test_create_processor(self, mock_matcher_class, mock_from_tokenizer, mock_matcher): @@ -161,53 +169,5 @@ def test_create_processor(self, mock_matcher_class, mock_from_tokenizer, mock_ma self.assertEqual(processor.batch_size, 2) -@patch("llguidance.LLMatcher") -class TestLLGuidanceChecker(unittest.TestCase): - def test_schema_format_valid_json(self, mock_matcher): - # 设置mock - mock_matcher.grammar_from_json_schema.return_value = "compiled_grammar" - mock_matcher.validate_grammar.return_value = None - - # 创建checker和请求 - checker = LLGuidanceChecker() - request = MagicMock() - request.guided_json = '{"type": "object", "properties": {"name": {"type": "string"}}}' - - # 测试有效的JSON schema - result, error = checker.schema_format(request) - self.assertIsNone(error) - self.assertEqual(result.guided_json, '{"type": "object", "properties": {"name": {"type": "string"}}}') - - def test_schema_format_valid_regex(self, mock_matcher): - # 设置mock - mock_matcher.validate_grammar.return_value = None - - # 模拟llguidance.grammar_from - with patch("llguidance.grammar_from", return_value="compiled_regex"): - # 创建checker和请求 - checker = LLGuidanceChecker() - request = MagicMock() - request.guided_regex = "[a-z]+" - - # 测试有效的regex - result, error = checker.schema_format(request) - self.assertIsNone(error) - self.assertEqual(result.guided_regex, "compiled_regex") - - def test_schema_format_invalid(self, mock_matcher): - # 设置mock使验证失败 - mock_matcher.grammar_from_json_schema.side_effect = ValueError("Invalid schema") - - # 创建checker和请求 - checker = LLGuidanceChecker() - request = MagicMock() - request.guided_json = '{"invalid": "schema"}' - - # 测试无效的schema - result, error = checker.schema_format(request) - self.assertIsNotNone(error) - self.assertIn("Invalid format for guided decoding", error) - - if __name__ == "__main__": unittest.main() diff --git a/tests/layers/test_guidance_checker.py b/tests/model_executor/guided_decoding/test_guidance_checker.py similarity index 100% rename from tests/layers/test_guidance_checker.py rename to tests/model_executor/guided_decoding/test_guidance_checker.py diff --git a/tests/layers/test_guidance_processor.py b/tests/model_executor/guided_decoding/test_guidance_processor.py similarity index 100% rename from tests/layers/test_guidance_processor.py rename to tests/model_executor/guided_decoding/test_guidance_processor.py From 9711ede8aef62c3ac788fe48140fb1a0e1370b05 Mon Sep 17 00:00:00 2001 From: ST-XX <15625257+ST-XX@users.noreply.github.com> Date: Fri, 21 Nov 2025 11:31:39 +0800 Subject: [PATCH 5/6] fix llguidance choice --- fastdeploy/model_executor/guided_decoding/guidance_backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fastdeploy/model_executor/guided_decoding/guidance_backend.py b/fastdeploy/model_executor/guided_decoding/guidance_backend.py index 3b0a22e87b1..613256e81cc 100644 --- a/fastdeploy/model_executor/guided_decoding/guidance_backend.py +++ b/fastdeploy/model_executor/guided_decoding/guidance_backend.py @@ -297,7 +297,8 @@ def schema_format(self, request: Request) -> Tuple[Request, Optional[str]]: if request.guided_regex: request.guided_regex = guidance_grm elif request.guided_choice: - request.guided_choice = guidance_grm + request.guided_grammar = guidance_grm + request.guided_choice = None elif request.guided_grammar: request.guided_grammar = guidance_grm elif request.guided_json: From 19a423d074f8ea8d2d3e302697eadcc9e8618f13 Mon Sep 17 00:00:00 2001 From: ST-XX <15625257+ST-XX@users.noreply.github.com> Date: Wed, 26 Nov 2025 11:31:25 +0800 Subject: [PATCH 6/6] test_guidance_* --- .../guided_decoding/test_guidance_backend.py | 28 ++-- .../guided_decoding/test_guidance_checker.py | 128 ++++++++++-------- .../test_guidance_processor.py | 50 +++---- 3 files changed, 107 insertions(+), 99 deletions(-) diff --git a/tests/model_executor/guided_decoding/test_guidance_backend.py b/tests/model_executor/guided_decoding/test_guidance_backend.py index 134e6901ff2..a545aba800e 100644 --- a/tests/model_executor/guided_decoding/test_guidance_backend.py +++ b/tests/model_executor/guided_decoding/test_guidance_backend.py @@ -27,7 +27,7 @@ sys.modules["llguidance.hf"] = mock_llguidancehf sys.modules["torch"] = mock_torch -# 导入要测试的模块 +# Import the module to be tested from fastdeploy.model_executor.guided_decoding.guidance_backend import ( LLGuidanceBackend, LLGuidanceProcessor, @@ -37,21 +37,21 @@ class TestProcessForAdditionalProperties(unittest.TestCase): def test_process_json_string(self): - # 测试字符串输入 + # Test string input json_str = '{"type": "object", "properties": {"name": {"type": "string"}}}' result = process_for_additional_properties(json_str) self.assertFalse(result["additionalProperties"]) def test_process_json_dict(self): - # 测试字典输入 + # Test dictionary input json_dict = {"type": "object", "properties": {"name": {"type": "string"}}} result = process_for_additional_properties(json_dict) self.assertFalse(result["additionalProperties"]) - # 确保原始字典没有被修改 + # Ensure the original dictionary is not modified self.assertNotIn("additionalProperties", json_dict) def test_nested_objects(self): - # 测试嵌套对象 + # Test nested objects json_dict = { "type": "object", "properties": {"person": {"type": "object", "properties": {"name": {"type": "string"}}}}, @@ -69,7 +69,7 @@ def setUp(self): self.batch_size = 2 def test_initialization(self, mock_tokenizer, mock_matcher): - # 测试初始化 + # Test initialization processor = LLGuidanceProcessor( ll_matcher=mock_matcher, ll_tokenizer=mock_tokenizer, @@ -83,7 +83,7 @@ def test_initialization(self, mock_tokenizer, mock_matcher): self.assertFalse(processor.is_terminated) def test_reset(self, mock_tokenizer, mock_matcher): - # 测试重置功能 + # Test reset functionality processor = LLGuidanceProcessor( ll_matcher=mock_matcher, ll_tokenizer=mock_tokenizer, @@ -99,7 +99,7 @@ def test_reset(self, mock_tokenizer, mock_matcher): self.assertFalse(processor.is_terminated) def test_accept_token(self, mock_tokenizer, mock_matcher): - # 测试接受token功能 + # Test accept_token functionality mock_matcher.is_stopped.return_value = False mock_matcher.consume_tokens.return_value = True mock_tokenizer.eos_token = 1 @@ -112,7 +112,7 @@ def test_accept_token(self, mock_tokenizer, mock_matcher): batch_size=self.batch_size, ) - # 正常token + # Normal token result = processor.accept_token(0) self.assertTrue(result) mock_matcher.consume_tokens.assert_called_with([0]) @@ -127,7 +127,7 @@ def test_accept_token(self, mock_tokenizer, mock_matcher): @patch("llguidance.hf.from_tokenizer") class TestLLGuidanceBackend(unittest.TestCase): def setUp(self): - # 创建一个模拟的FDConfig + # Create a mock FDConfig self.fd_config = MagicMock() self.fd_config.model_config.vocab_size = 100 self.fd_config.scheduler_config.max_num_seqs = 2 @@ -136,7 +136,7 @@ def setUp(self): self.fd_config.structured_outputs_config.reasoning_parser = None def test_initialization(self, mock_from_tokenizer, mock_matcher): - # 测试后端初始化 + # Test backend initialization mock_tokenizer = MagicMock() with patch.object(BackendBase, "_get_tokenizer_hf", return_value=mock_tokenizer): backend = LLGuidanceBackend(fd_config=self.fd_config) @@ -147,11 +147,11 @@ def test_initialization(self, mock_from_tokenizer, mock_matcher): @patch("llguidance.LLMatcher") def test_create_processor(self, mock_matcher_class, mock_from_tokenizer, mock_matcher): - # 测试创建处理器 + # Test creating a processor with patch.object(LLGuidanceBackend, "__init__", return_value=None): - backend = LLGuidanceBackend(fd_config=None) # 参数不重要,因为 __init__ 被模拟了 + backend = LLGuidanceBackend(fd_config=None) # Arguments are not important because __init__ is mocked - # 手动设置所有需要的属性 + # Manually set all required attributes backend.hf_tokenizer = MagicMock() backend.ll_tokenizer = MagicMock() backend.vocab_size = 100 diff --git a/tests/model_executor/guided_decoding/test_guidance_checker.py b/tests/model_executor/guided_decoding/test_guidance_checker.py index 7ea58d2a267..d1dd09c61dc 100644 --- a/tests/model_executor/guided_decoding/test_guidance_checker.py +++ b/tests/model_executor/guided_decoding/test_guidance_checker.py @@ -21,7 +21,7 @@ import pytest -# 检查是否可以导入llguidance +# Check if llguidance can be imported HAS_LLGUIDANCE = False try: import llguidance @@ -37,13 +37,13 @@ @pytest.fixture def llguidance_checker(): - """返回一个LLGuidanceChecker实例供测试使用""" + """Return an LLGuidanceChecker instance for testing.""" return LLGuidanceChecker() @pytest.fixture def llguidance_checker_with_options(): - """返回一个配置了特定选项的LLGuidanceChecker实例""" + """Return an LLGuidanceChecker instance configured with specific options.""" return LLGuidanceChecker(disable_any_whitespace=True) @@ -62,12 +62,12 @@ def MockRequest(): class TestLLGuidanceCheckerMocked: - """使用Mock测试LLGuidanceChecker,适用于没有llguidance库的环境""" + """Test LLGuidanceChecker using Mock, suitable for environments without the llguidance library.""" @patch("llguidance.LLMatcher.grammar_from_json_schema") @patch("llguidance.LLMatcher.validate_grammar") def test_serialize_guided_json_as_string(self, mock_validate, mock_from_schema, llguidance_checker): - """测试处理guided_json字符串类型""" + """Test processing guided_json string type.""" mock_from_schema.return_value = "serialized_grammar" mock_validate.return_value = None @@ -82,7 +82,7 @@ def test_serialize_guided_json_as_string(self, mock_validate, mock_from_schema, @patch("llguidance.LLMatcher.grammar_from_json_schema") @patch("llguidance.LLMatcher.validate_grammar") def test_serialize_guided_json_as_dict(self, mock_validate, mock_from_schema, llguidance_checker): - """测试处理guided_json字典类型""" + """Test processing guided_json dictionary type.""" mock_from_schema.return_value = "serialized_grammar" mock_validate.return_value = None @@ -92,13 +92,13 @@ def test_serialize_guided_json_as_dict(self, mock_validate, mock_from_schema, ll grammar = llguidance_checker.serialize_guidance_grammar(request) mock_from_schema.assert_called_once() - assert isinstance(request.guided_json, dict) # 验证字典已转换为字符串 + assert isinstance(request.guided_json, dict) # Verify that the dictionary has been converted to a string assert grammar == "serialized_grammar" @patch("llguidance.LLMatcher.grammar_from_json_schema") @patch("llguidance.LLMatcher.validate_grammar") def test_serialize_guided_json_object(self, mock_validate, mock_from_schema, llguidance_checker): - """测试处理guided_json_object""" + """Test processing guided_json_object.""" mock_from_schema.return_value = "serialized_grammar" mock_validate.return_value = None @@ -114,7 +114,7 @@ def test_serialize_guided_json_object(self, mock_validate, mock_from_schema, llg @patch("llguidance.grammar_from") @patch("llguidance.LLMatcher.validate_grammar") def test_serialize_guided_regex(self, mock_validate, mock_grammar_from, llguidance_checker): - """测试处理guided_regex""" + """Test processing guided_regex.""" mock_grammar_from.return_value = "serialized_regex_grammar" mock_validate.return_value = None @@ -129,7 +129,7 @@ def test_serialize_guided_regex(self, mock_validate, mock_grammar_from, llguidan @patch("llguidance.grammar_from") @patch("llguidance.LLMatcher.validate_grammar") def test_serialize_guided_choice(self, mock_validate, mock_grammar_from, llguidance_checker): - """测试处理guided_choice""" + """Test processing guided_choice.""" mock_grammar_from.return_value = "serialized_choice_grammar" mock_validate.return_value = None @@ -144,7 +144,7 @@ def test_serialize_guided_choice(self, mock_validate, mock_grammar_from, llguida @patch("llguidance.grammar_from") @patch("llguidance.LLMatcher.validate_grammar") def test_serialize_guided_grammar(self, mock_validate, mock_grammar_from, llguidance_checker): - """测试处理guided_grammar""" + """Test processing guided_grammar.""" mock_grammar_from.return_value = "serialized_grammar_spec" mock_validate.return_value = None @@ -159,8 +159,8 @@ def test_serialize_guided_grammar(self, mock_validate, mock_grammar_from, llguid @patch("llguidance.StructTag") @patch("llguidance.LLMatcher.grammar_from_json_schema") def test_serialize_structural_tag(self, mock_from_schema, mock_struct_tag, llguidance_checker): - """测试处理structural_tag""" - # 配置mock对象 + """Test processing structural_tag.""" + # Configure mock objects mock_from_schema.return_value = "serialized_schema" mock_struct_tag.to_grammar.return_value = "serialized_structural_grammar" struct_tag_instance = MagicMock() @@ -181,7 +181,7 @@ def test_serialize_structural_tag(self, mock_from_schema, mock_struct_tag, llgui @patch("llguidance.StructTag") def test_serialize_structural_tag_missing_trigger(self, mock_struct_tag, llguidance_checker): - """测试处理structural_tag中缺少触发器的情况""" + """Test processing structural_tag when a trigger is missing.""" request = MockRequest() request.structural_tag = { "triggers": [""], @@ -193,7 +193,7 @@ def test_serialize_structural_tag_missing_trigger(self, mock_struct_tag, llguida @patch("llguidance.StructTag") def test_serialize_structural_tag_empty_structures(self, mock_struct_tag, llguidance_checker): - """测试处理structural_tag中结构为空的情况""" + """Test processing structural_tag when structures are empty.""" request = MockRequest() request.structural_tag = {"triggers": [""], "structures": []} @@ -201,9 +201,9 @@ def test_serialize_structural_tag_empty_structures(self, mock_struct_tag, llguid llguidance_checker.serialize_guidance_grammar(request) def test_serialize_invalid_grammar_type(self, llguidance_checker): - """测试处理无效的语法类型""" + """Test processing invalid grammar types.""" request = MockRequest() - # 没有设置任何语法类型 + # No grammar type set with pytest.raises(ValueError, match="grammar is not of valid supported types"): llguidance_checker.serialize_guidance_grammar(request) @@ -211,7 +211,7 @@ def test_serialize_invalid_grammar_type(self, llguidance_checker): @patch("llguidance.LLMatcher.grammar_from_json_schema") @patch("llguidance.LLMatcher.validate_grammar") def test_schema_format_valid_json(self, mock_validate, mock_from_schema, llguidance_checker): - """测试schema_format方法处理有效的JSON""" + """Test schema_format method processing valid JSON.""" mock_from_schema.return_value = "serialized_grammar" mock_validate.return_value = None @@ -226,7 +226,7 @@ def test_schema_format_valid_json(self, mock_validate, mock_from_schema, llguida @patch("llguidance.LLMatcher.grammar_from_json_schema") @patch("llguidance.LLMatcher.validate_grammar") def test_schema_format_invalid_grammar(self, mock_validate, mock_from_schema, llguidance_checker): - """测试schema_format方法处理无效的语法""" + """Test schema_format method processing invalid grammar.""" mock_from_schema.return_value = "serialized_grammar" mock_validate.return_value = "Invalid grammar" @@ -240,7 +240,7 @@ def test_schema_format_invalid_grammar(self, mock_validate, mock_from_schema, ll @patch("llguidance.LLMatcher.grammar_from_json_schema") def test_schema_format_json_decode_error(self, mock_from_schema, llguidance_checker): - """测试schema_format方法处理JSON解码错误""" + """Test schema_format method processing JSON decode error.""" mock_from_schema.side_effect = json.JSONDecodeError("Invalid JSON", "", 0) request = MockRequest() @@ -253,7 +253,7 @@ def test_schema_format_json_decode_error(self, mock_from_schema, llguidance_chec @patch("llguidance.LLMatcher.grammar_from_json_schema") def test_schema_format_unexpected_error(self, mock_from_schema, llguidance_checker): - """测试schema_format方法处理意外错误""" + """Test schema_format method processing unexpected errors.""" mock_from_schema.side_effect = Exception("Unexpected error") request = MockRequest() @@ -265,7 +265,7 @@ def test_schema_format_unexpected_error(self, mock_from_schema, llguidance_check assert "An unexpected error occurred during schema validation" in error def test_init_with_disable_whitespace(self, llguidance_checker_with_options): - """测试初始化时设置disable_any_whitespace选项""" + """Test setting the disable_any_whitespace option during initialization.""" assert llguidance_checker_with_options.any_whitespace is False assert llguidance_checker_with_options.disable_additional_properties is True assert LLGuidanceChecker(disable_any_whitespace=True).any_whitespace is False @@ -281,24 +281,24 @@ def test_init_with_disable_whitespace(self, llguidance_checker_with_options): assert LLGuidanceChecker().disable_additional_properties is False -@pytest.mark.skipif(not HAS_LLGUIDANCE, reason="llguidance库未安装,跳过实际依赖测试") +@pytest.mark.skipif(not HAS_LLGUIDANCE, reason="llguidance library not installed, skipping actual dependency tests") class TestLLGuidanceCheckerReal: - """使用实际的llguidance库进行测试,适用于开发环境""" + """Test using the actual llguidance library, suitable for development environments.""" def test_serialize_guided_json_string_real(self, llguidance_checker): - """使用实际库测试处理guided_json字符串""" + """Test processing guided_json string using the actual library.""" request = MockRequest() request.guided_json = '{"type": "object", "properties": {"name": {"type": "string"}}}' grammar = llguidance_checker.serialize_guidance_grammar(request) - # 验证返回的grammar是否是一个有效的字符串 + # Verify if the returned grammar is a valid string assert isinstance(grammar, str) assert len(grammar) > 0 print("grammar", grammar) def test_serialize_guided_json_dict_real(self, llguidance_checker): - """使用实际库测试处理guided_json字典""" + """Test processing guided_json dictionary using the actual library.""" request = MockRequest() request.guided_json = {"type": "object", "properties": {"name": {"type": "string"}}} @@ -309,7 +309,7 @@ def test_serialize_guided_json_dict_real(self, llguidance_checker): assert len(grammar) > 0 def test_serialize_guided_json_object_real(self, llguidance_checker): - """使用实际库测试处理guided_json_object""" + """Test processing guided_json_object using the actual library.""" request = MockRequest() request.guided_json_object = True @@ -320,7 +320,7 @@ def test_serialize_guided_json_object_real(self, llguidance_checker): assert len(grammar) > 0 def test_serialize_guided_regex_real(self, llguidance_checker): - """使用实际库测试处理guided_regex""" + """Test processing guided_regex using the actual library.""" request = MockRequest() request.guided_regex = "[a-zA-Z]+" @@ -330,7 +330,7 @@ def test_serialize_guided_regex_real(self, llguidance_checker): assert len(grammar) > 0 def test_serialize_guided_choice_real(self, llguidance_checker): - """使用实际库测试处理guided_choice""" + """Test processing guided_choice using the actual library.""" request = MockRequest() request.guided_choice = ["option1", "option2"] @@ -340,9 +340,9 @@ def test_serialize_guided_choice_real(self, llguidance_checker): assert len(grammar) > 0 def test_serialize_guided_grammar_real(self, llguidance_checker): - """使用实际库测试处理guided_grammar""" + """Test processing guided_grammar using the actual library.""" request = MockRequest() - # 使用简单的CFG文法示例 + # Use a simple CFG grammar example request.guided_grammar = """ root ::= greeting name greeting ::= "Hello" | "Hi" @@ -355,7 +355,7 @@ def test_serialize_guided_grammar_real(self, llguidance_checker): assert len(grammar) > 0 def test_serialize_structural_tag_real(self, llguidance_checker): - """使用实际库测试处理structural_tag""" + """Test processing structural_tag using the actual library.""" request = MockRequest() request.structural_tag = { "triggers": [""], @@ -368,7 +368,7 @@ def test_serialize_structural_tag_real(self, llguidance_checker): assert len(grammar) > 0 def test_schema_format_valid_json_real(self, llguidance_checker): - """使用实际库测试schema_format方法处理有效的JSON""" + """Test schema_format method processing valid JSON using the actual library.""" request = MockRequest() request.guided_json = '{"type": "object", "properties": {"name": {"type": "string"}}}' @@ -379,7 +379,7 @@ def test_schema_format_valid_json_real(self, llguidance_checker): assert result_request.guided_json != '{"type": "object", "properties": {"name": {"type": "string"}}}' def test_schema_format_invalid_json_real(self, llguidance_checker): - """使用实际库测试schema_format方法处理无效的JSON""" + """Test schema_format method processing invalid JSON using the actual library.""" request = MockRequest() request.guided_json = "{invalid json}" @@ -389,8 +389,8 @@ def test_schema_format_invalid_json_real(self, llguidance_checker): assert "Invalid format for guided decoding" in error def test_whitespace_flexibility_option_real(self): - """使用实际库测试whitespace灵活性选项的影响""" - # 创建两个不同配置的实例 + """Test the impact of the whitespace flexibility option using the actual library.""" + # Create two instances with different configurations flexible = LLGuidanceChecker(disable_any_whitespace=False) strict = LLGuidanceChecker(disable_any_whitespace=True) @@ -405,11 +405,11 @@ def test_whitespace_flexibility_option_real(self): print("grammar_flexible", grammar_flexible) print("grammar_strict", grammar_strict) - # 预期两种配置生成的语法应该不同 + # Expect grammars generated by the two configurations to be different assert grammar_flexible != grammar_strict def test_schema_format_guided_json_object_real(self, llguidance_checker): - """测试schema_format处理guided_json_object""" + """Test schema_format processing guided_json_object.""" request = MockRequest() request.guided_json_object = True @@ -419,7 +419,7 @@ def test_schema_format_guided_json_object_real(self, llguidance_checker): assert result_request is request def test_schema_format_guided_regex_real(self, llguidance_checker): - """测试schema_format处理有效的正则表达式""" + """Test schema_format processing valid regular expressions.""" request = MockRequest() request.guided_regex = r"[a-zA-Z0-9]+" @@ -427,12 +427,12 @@ def test_schema_format_guided_regex_real(self, llguidance_checker): assert error is None assert result_request is request - assert result_request.guided_regex != r"[a-zA-Z0-9]+" # 应该被转换为grammar格式 + assert result_request.guided_regex != r"[a-zA-Z0-9]+" # Should be converted to grammar format def test_schema_format_invalid_guided_regex_real(self, llguidance_checker): - """测试schema_format处理无效的正则表达式""" + """Test schema_format processing invalid regular expressions.""" request = MockRequest() - request.guided_regex = r"[" # 无效的正则表达式 + request.guided_regex = r"[" # Invalid regular expression result_request, error = llguidance_checker.schema_format(request) @@ -440,7 +440,7 @@ def test_schema_format_invalid_guided_regex_real(self, llguidance_checker): assert "Invalid format for guided decoding" in error def test_schema_format_guided_choice_real(self, llguidance_checker): - """测试schema_format处理guided_choice""" + """Test schema_format processing guided_choice.""" request = MockRequest() request.guided_choice = ["option1", "option2", "option3"] @@ -448,12 +448,16 @@ def test_schema_format_guided_choice_real(self, llguidance_checker): assert error is None assert result_request is request - assert result_request.guided_choice != ["option1", "option2", "option3"] # 应该被转换为grammar格式 + assert result_request.guided_choice != [ + "option1", + "option2", + "option3", + ] # Should be converted to grammar format def test_schema_format_guided_grammar_real(self, llguidance_checker): - """测试schema_format处理guided_grammar""" + """Test schema_format processing guided_grammar.""" request = MockRequest() - # 使用LLGuidance支持的正确语法格式 + # Use the correct grammar format supported by LLGuidance request.guided_grammar = """ start: number number: DIGIT+ @@ -467,7 +471,7 @@ def test_schema_format_guided_grammar_real(self, llguidance_checker): assert isinstance(result_request.guided_grammar, str) def test_schema_format_structural_tag_real(self, llguidance_checker): - """测试schema_format处理structural_tag""" + """Test schema_format processing structural_tag.""" request = MockRequest() request.structural_tag = { "triggers": ["```json"], @@ -486,7 +490,7 @@ def test_schema_format_structural_tag_real(self, llguidance_checker): assert result_request is request def test_schema_format_structural_tag_string_real(self, llguidance_checker): - """测试schema_format处理字符串形式的structural_tag""" + """Test schema_format processing structural_tag in string format.""" request = MockRequest() request.structural_tag = json.dumps( { @@ -507,12 +511,16 @@ def test_schema_format_structural_tag_string_real(self, llguidance_checker): assert result_request is request def test_schema_format_structural_tag_invalid_trigger_real(self, llguidance_checker): - """测试schema_format处理trigger无效的structural_tag""" + """Test schema_format processing structural_tag with invalid triggers.""" request = MockRequest() request.structural_tag = { - "triggers": ["```xml"], # 触发器与begin不匹配 + "triggers": ["```xml"], # Trigger does not match begin "structures": [ - {"begin": "```json", "schema": {"type": "object"}, "end": "```"} # 这里不包含任何triggers中的前缀 + { + "begin": "```json", + "schema": {"type": "object"}, + "end": "```", + } # Does not contain any prefix from triggers here ], } @@ -522,9 +530,9 @@ def test_schema_format_structural_tag_invalid_trigger_real(self, llguidance_chec assert "Invalid format for guided decoding" in error def test_schema_format_structural_tag_empty_structures_real(self, llguidance_checker): - """测试schema_format处理空structures的structural_tag""" + """Test schema_format processing structural_tag with empty structures.""" request = MockRequest() - request.structural_tag = {"triggers": ["```json"], "structures": []} # 空结构 + request.structural_tag = {"triggers": ["```json"], "structures": []} # Empty structure result_request, error = llguidance_checker.schema_format(request) @@ -532,7 +540,7 @@ def test_schema_format_structural_tag_empty_structures_real(self, llguidance_che assert "Invalid format for guided decoding" in error def test_schema_format_json_dict_real(self, llguidance_checker): - """测试schema_format处理字典形式的guided_json""" + """Test schema_format processing guided_json in dictionary format.""" request = MockRequest() request.guided_json = {"type": "object", "properties": {"name": {"type": "string"}}} @@ -542,7 +550,7 @@ def test_schema_format_json_dict_real(self, llguidance_checker): assert result_request is request def test_schema_format_disable_additional_properties_real(self): - """测试schema_format处理disable_additional_properties参数""" + """Test schema_format processing disable_additional_properties parameter.""" checker = LLGuidanceChecker(disable_additional_properties=True) request = MockRequest() request.guided_json = {"type": "object", "properties": {"name": {"type": "string"}}} @@ -553,11 +561,11 @@ def test_schema_format_disable_additional_properties_real(self): assert result_request is request def test_schema_format_unexpected_error_real(self, monkeypatch, llguidance_checker): - """测试schema_format处理意外错误""" + """Test schema_format processing unexpected errors.""" request = MockRequest() request.guided_json = '{"type": "object"}' - # 模拟意外异常 + # Mock unexpected exception def mock_serialize_grammar(*args, **kwargs): raise Exception("Unexpected error") @@ -569,9 +577,9 @@ def mock_serialize_grammar(*args, **kwargs): assert "An unexpected error occurred during schema validation" in error def test_schema_format_no_valid_grammar_real(self, llguidance_checker): - """测试schema_format处理没有有效语法的请求""" + """Test schema_format processing requests without valid grammar.""" request = MockRequest() - # 没有设置任何语法相关的属性 + # No grammar-related attributes set with pytest.raises(ValueError, match="grammar is not of valid supported types"): llguidance_checker.serialize_guidance_grammar(request) diff --git a/tests/model_executor/guided_decoding/test_guidance_processor.py b/tests/model_executor/guided_decoding/test_guidance_processor.py index 5cd61249a0f..cdeb91e378d 100644 --- a/tests/model_executor/guided_decoding/test_guidance_processor.py +++ b/tests/model_executor/guided_decoding/test_guidance_processor.py @@ -19,7 +19,7 @@ from unittest.mock import MagicMock, patch # --- Mocking Setup --- -# 优先模拟这些懒加载的模块,以便在未安装这些库的环境中进行测试。 +# Prioritize mocking these lazy-loaded modules to facilitate testing in environments where these libraries are not installed. mock_torch = MagicMock() mock_llguidance = MagicMock() mock_llguidance_hf = MagicMock() @@ -33,43 +33,43 @@ sys.modules["llguidance.hf"] = mock_llguidance_hf sys.modules["llguidance.torch"] = mock_llguidance_torch -# 模拟设置完成后,再导入需要测试的模块 +# Import the module to be tested after the mock setup is complete from fastdeploy.model_executor.guided_decoding.guidance_backend import ( LLGuidanceProcessor, ) def MockFDConfig(): - """创建一个用于测试的FDConfig模拟对象""" + """Create a mock FDConfig object for testing""" config = MagicMock() - # --- 修复点 1: 显式设置 model 为字符串,通过 HF 的验证 --- + # --- Fix point 1: Explicitly set model as a string to pass HF validation --- config.model_config.model = "test-model-path" - config.model_config.architectures = [] # 设置为空列表,防止迭代 Mock 出错 + config.model_config.architectures = [] # Set to empty list to prevent errors when iterating over the Mock config.model_config.vocab_size = 1000 config.scheduler_config.max_num_seqs = 4 config.structured_outputs_config.disable_any_whitespace = False - # 确保 backend 检查逻辑能通过 + # Ensure the backend check logic passes config.structured_outputs_config.guided_decoding_backend = "guidance" return config def MockHFTokenizer(): - """创建一个用于测试的Hugging Face Tokenizer模拟对象""" + """Create a mock Hugging Face Tokenizer object for testing""" return MagicMock() class TestLLGuidanceProcessorMocked(unittest.TestCase): """ - 使用Mock对LLGuidanceProcessor进行单元测试。 - 这个测试类适用于没有安装llguidance库的环境。 + Unit tests for LLGuidanceProcessor using Mock. + This test class is suitable for environments where the llguidance library is not installed. """ def setUp(self): - """为每个测试用例设置一个新的LLGuidanceProcessor实例""" + """Set up a new LLGuidanceProcessor instance for each test case""" self.mock_matcher = MagicMock() self.mock_tokenizer = MagicMock() - self.mock_tokenizer.eos_token = 2 # 示例EOS token ID + self.mock_tokenizer.eos_token = 2 # Example EOS token ID self.processor = LLGuidanceProcessor( ll_matcher=self.mock_matcher, ll_tokenizer=self.mock_tokenizer, @@ -80,7 +80,7 @@ def setUp(self): ) def test_init(self): - """测试LLGuidanceProcessor的构造函数""" + """Test the constructor of LLGuidanceProcessor""" self.assertIs(self.processor.matcher, self.mock_matcher) self.assertEqual(self.processor.vocab_size, 1000) self.assertEqual(self.processor.batch_size, 4) @@ -88,23 +88,23 @@ def test_init(self): @patch("fastdeploy.utils.llm_logger.warning") def test_check_error_logs_warning_once(self, mock_log_warning): - """测试_check_error方法在匹配器出错时能记录警告,且只记录一次""" + """Test that the _check_error method logs a warning when the matcher errors, and only logs it once""" self.mock_matcher.get_error.return_value = "A test error." - # 第一次调用,应该打印日志 + # First call, should log the message self.processor._check_error() mock_log_warning.assert_called_once_with("LLGuidance Matcher error: A test error.") - # 第二次调用,不应该重复打印 + # Second call, should not log repeatedly self.processor._check_error() mock_log_warning.assert_called_once() @patch("fastdeploy.model_executor.guided_decoding.guidance_backend.llguidance_torch") def test_allocate_token_bitmask(self, mock_backend_torch): """ - 测试token bitmask的分配。 - 注意:这里Patch的是guidance_backend模块中导入的llguidance_torch变量, - 而不是sys.modules里的全局mock,以解决LazyLoader导致的引用不一致问题。 + Test the allocation of token bitmask. + Note: We patch the llguidance_torch variable imported in the guidance_backend module here, + instead of the global mock in sys.modules, to resolve inconsistent references caused by LazyLoader. """ mock_backend_torch.allocate_token_bitmask.return_value = "fake_bitmask_tensor" @@ -115,7 +115,7 @@ def test_allocate_token_bitmask(self, mock_backend_torch): @patch("fastdeploy.model_executor.guided_decoding.guidance_backend.llguidance_torch") def test_fill_token_bitmask(self, mock_backend_torch): - """测试token bitmask的填充""" + """Test the filling of token bitmask""" mock_bitmask = MagicMock() self.processor.fill_token_bitmask(mock_bitmask, idx=2) @@ -124,7 +124,7 @@ def test_fill_token_bitmask(self, mock_backend_torch): self.mock_matcher.get_error.assert_called_once() def test_reset(self): - """测试处理器的状态重置""" + """Test the state reset of the processor""" self.processor.is_terminated = True self.processor._printed_error = True self.mock_matcher.get_error.return_value = "" @@ -136,24 +136,24 @@ def test_reset(self): self.assertFalse(self.processor._printed_error) def test_accept_token_when_terminated(self): - """测试当状态为is_terminated时,accept_token直接返回False""" + """Test that accept_token returns False immediately when status is is_terminated""" self.processor.is_terminated = True self.assertFalse(self.processor.accept_token(123)) def test_accept_token_when_matcher_stopped(self): - """测试当匹配器停止时,accept_token返回False并更新状态""" + """Test that accept_token returns False and updates status when the matcher is stopped""" self.mock_matcher.is_stopped.return_value = True self.assertTrue(self.processor.accept_token(123)) self.assertFalse(self.processor.is_terminated) def test_accept_token_is_eos(self): - """测试接收到EOS token时的行为""" + """Test the behavior when an EOS token is received""" self.mock_matcher.is_stopped.return_value = False self.assertTrue(self.processor.accept_token(self.mock_tokenizer.eos_token)) self.assertTrue(self.processor.is_terminated) def test_accept_token_consumes_and_succeeds(self): - """测试成功消费一个token""" + """Test successfully consuming a token""" self.mock_matcher.is_stopped.return_value = False self.mock_matcher.consume_tokens.return_value = True self.assertTrue(self.processor.accept_token(123)) @@ -161,7 +161,7 @@ def test_accept_token_consumes_and_succeeds(self): self.mock_matcher.get_error.assert_called_once() def test_accept_token_consumes_and_fails(self): - """测试消费一个token失败""" + """Test failing to consume a token""" self.mock_matcher.is_stopped.return_value = False self.mock_matcher.consume_tokens.return_value = False self.assertFalse(self.processor.accept_token(123))