2 changes: 0 additions & 2 deletions fastdeploy/engine/args_utils.py
@@ -524,8 +524,6 @@ def __post_init__(self):

         if not current_platform.is_cuda() and not current_platform.is_xpu():
             envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
-        if self.guided_decoding_backend != "off":
-            envs.ENABLE_V1_KVCACHE_SCHEDULER = 0

         if "PaddleOCR" in get_model_architecture(self.model, self.model_config_name):
             envs.FD_ENABLE_MAX_PREFILL = 1
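Reviewer note: with the two deleted lines, enabling a guided-decoding backend no longer forces the v1 KV-cache scheduler off; only the platform check does. A minimal sketch of the resulting gating, written as standalone Python rather than FastDeploy code (the helper name and the "xgrammar" backend value are illustrative assumptions):

```python
# Illustrative gating sketch; v1_kvcache_scheduler_enabled is a hypothetical helper,
# not part of FastDeploy.
def v1_kvcache_scheduler_enabled(is_cuda: bool, is_xpu: bool, guided_decoding_backend: str) -> bool:
    enabled = True
    if not is_cuda and not is_xpu:
        enabled = False  # unsupported platforms still disable the v1 scheduler
    # Previously: `if guided_decoding_backend != "off": enabled = False`
    # That rule is removed, so guided decoding and the v1 scheduler can coexist.
    return enabled

# Guided decoding enabled on CUDA now keeps the v1 scheduler on.
assert v1_kvcache_scheduler_enabled(is_cuda=True, is_xpu=False, guided_decoding_backend="xgrammar")
```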
38 changes: 38 additions & 0 deletions fastdeploy/worker/gpu_model_runner.py
@@ -521,7 +521,28 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
             if hasattr(request, "pooling_params") and request.pooling_params is not None:
                 batch_pooling_params.append(request.pooling_params)

+            logits_info = None
+            prefill_tokens = []
             if request.task_type.value == RequestType.PREFILL.value: # prefill task
+                # guided decoding
+                if (
+                    request.guided_json is not None
+                    or request.guided_regex is not None
+                    or request.structural_tag is not None
+                    or request.guided_grammar is not None
+                ):
+                    logits_info, schemata_key = self._init_logits_processor(request)
+                    request.schemata_key = schemata_key
+
+                if self.scheduler_config.splitwise_role == "decode":
+                    if (
+                        hasattr(request, "prefill_end_index")
+                        and hasattr(request, "prompt_token_ids")
+                        and request.prefill_end_index > len(request.prompt_token_ids)
+                    ):
+                        if hasattr(request, "output_token_ids"):
+                            prefill_tokens.extend(request.output_token_ids)
+
                 prefill_start_index = request.prefill_start_index
                 prefill_end_index = request.prefill_end_index
                 length = prefill_end_index - prefill_start_index
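Reviewer note: this hunk initializes a guided-decoding logits processor when a request carries guided_json / guided_regex / structural_tag / guided_grammar, and on a decode node it collects the tokens already emitted during prefill so the processor can be caught up before sampling. The toy below is a schematic of that replay idea only, not FastDeploy's processor API (ToyGuidedProcessor and its methods are invented for illustration):

```python
# Schematic only: a grammar-constrained logits processor is a state machine, so tokens
# already produced elsewhere (e.g. on the prefill node) must be replayed through it
# before it can mask logits for the next decode step.
import math

class ToyGuidedProcessor:
    def __init__(self, allowed_sequences):
        self.allowed = allowed_sequences   # token-id sequences permitted by the schema
        self.generated = []

    def accept_token(self, token_id):
        # Advance the "FSM" by one already-generated token.
        self.generated.append(token_id)

    def apply(self, logits):
        # Keep only tokens that extend a still-valid allowed sequence.
        prefix_len = len(self.generated)
        allowed_next = {seq[prefix_len] for seq in self.allowed
                        if seq[:prefix_len] == self.generated and len(seq) > prefix_len}
        return [x if i in allowed_next else -math.inf for i, x in enumerate(logits)]

proc = ToyGuidedProcessor(allowed_sequences=[[1, 2, 3], [1, 2, 4]])
for tok in [1, 2]:          # tokens that came back with the request from prefill
    proc.accept_token(tok)
print(proc.apply([0.1, 0.2, 0.3, 0.4, 0.5]))   # only ids 3 and 4 stay finite
```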
@@ -657,6 +678,8 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
             # For logits processors
             self.share_inputs["logits_processors_args"][idx] = request.get("logits_processors_args") or {}

+            self.sampler.apply_logits_processor(idx, logits_info, prefill_tokens)
+
         if len(multi_vision_inputs["images_lst"]) > 0:
             self.share_inputs["image_features"] = self.extract_vision_features(multi_vision_inputs)
@@ -2041,6 +2064,21 @@ def _get_p_done_idxs_gd(self, model_forward_batch: Optional[List[Request]], num_
                 if self.share_inputs["step_idx"][idx] == 0:
                     prefill_done_idxs.append(idx)

+        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            if model_forward_batch is None:
+                return prefill_done_idxs
+
+            for task in model_forward_batch:
+                if task.task_type.value != RequestType.PREFILL.value:
+                    continue
+                # in chunk prefill
+                if self.cache_config.enable_chunked_prefill:
+                    if hasattr(task, "prefill_end_index") and hasattr(task, "prompt_token_ids"):
+                        if len(task.prompt_token_ids) > task.prefill_end_index and idx in prefill_done_idxs:
+                            prefill_done_idxs.remove(idx)
+
+            return prefill_done_idxs
+
         if self.cache_config.enable_chunked_prefill:
             if model_forward_batch is not None:
                 for task in model_forward_batch:
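Reviewer note: the new ENABLE_V1_KVCACHE_SCHEDULER branch in _get_p_done_idxs_gd drops requests that are still mid chunked prefill (prompt longer than prefill_end_index) from the "prefill done" list. Below is a simplified, standalone restatement of that filter; FakeTask and filter_prefill_done are hypothetical, and the sketch uses the task's own slot index rather than the loop variable from the real method:

```python
# Simplified sketch of the chunked-prefill exclusion rule, not the actual method.
from dataclasses import dataclass
from typing import List

@dataclass
class FakeTask:                      # hypothetical stand-in for a PREFILL Request
    idx: int
    prompt_token_ids: List[int]
    prefill_end_index: int

def filter_prefill_done(prefill_done_idxs: List[int], batch: List[FakeTask]) -> List[int]:
    for task in batch:
        # More prompt chunks remain, so this request is not "prefill done" yet.
        if len(task.prompt_token_ids) > task.prefill_end_index and task.idx in prefill_done_idxs:
            prefill_done_idxs.remove(task.idx)
    return prefill_done_idxs

batch = [FakeTask(idx=0, prompt_token_ids=list(range(1024)), prefill_end_index=512),
         FakeTask(idx=1, prompt_token_ids=list(range(100)), prefill_end_index=100)]
print(filter_prefill_done([0, 1], batch))   # -> [1]; request 0 still has chunks left
```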
3 changes: 0 additions & 3 deletions fastdeploy/worker/worker_process.py
@@ -932,9 +932,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     if not current_platform.is_cuda() and not current_platform.is_xpu():
         logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported.")
         envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
-    if structured_outputs_config.guided_decoding_backend != "off":
-        logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported guided_decoding.")
-        envs.ENABLE_V1_KVCACHE_SCHEDULER = 0

     if envs.ENABLE_V1_KVCACHE_SCHEDULER and args.splitwise_role == "prefill":
         os.environ["PREFILL_NODE_ONE_STEP_STOP_V1"] = "1"
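Reviewer note: worker_process.py mirrors the args_utils.py change so both entry points agree. One visible consequence, sketched below with a hypothetical configure() helper (only the PREFILL_NODE_ONE_STEP_STOP_V1 variable comes from the surrounding code): a prefill node running with a guided-decoding backend still reaches the one-step-stop branch, because ENABLE_V1_KVCACHE_SCHEDULER is no longer zeroed.

```python
# Hypothetical restatement of the worker-side gating after this change;
# configure() is not a FastDeploy function.
import os

def configure(enable_v1_kvcache_scheduler: int, guided_decoding_backend: str,
              splitwise_role: str, is_cuda: bool = True, is_xpu: bool = False) -> int:
    if not is_cuda and not is_xpu:
        enable_v1_kvcache_scheduler = 0
    # Removed rule: if guided_decoding_backend != "off": enable_v1_kvcache_scheduler = 0
    if enable_v1_kvcache_scheduler and splitwise_role == "prefill":
        os.environ["PREFILL_NODE_ONE_STEP_STOP_V1"] = "1"
    return enable_v1_kvcache_scheduler

assert configure(1, "xgrammar", "prefill") == 1
assert os.environ.get("PREFILL_NODE_ONE_STEP_STOP_V1") == "1"
```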