diff --git a/tests/v1/test_outputs.py b/tests/v1/test_outputs.py
index af9df844249e..89d551e344cf 100644
--- a/tests/v1/test_outputs.py
+++ b/tests/v1/test_outputs.py
@@ -43,7 +43,7 @@ def test_slice_without_cu_num_generated_tokens(self):
             cu_num_generated_tokens=None,
         )

-        sliced = logprobsLists.slice(1, 3)
+        sliced = logprobsLists.slice_request(1, num_positions=2)
         assert sliced.logprob_token_ids == [[2], [3]]
         assert sliced.logprobs == [[0.2], [0.3]]
         assert sliced.sampled_token_ranks == [2, 3]
@@ -51,7 +51,7 @@ def test_slice_without_cu_num_generated_tokens(self):

     def test_slice_from_start(self):
         """Test slicing from the start position"""
-        sliced = self.logprobsLists.slice(0, 2)
+        sliced = self.logprobsLists.slice_request(0, num_positions=5)
         assert len(sliced.logprob_token_ids) == 5
         assert sliced.logprob_token_ids == [
             [1, 2],
@@ -60,11 +60,11 @@ def test_slice_from_start(self):
             [7, 8],
             [9, 10],
         ]
-        assert sliced.cu_num_generated_tokens == [0, 2, 5]
+        assert sliced.cu_num_generated_tokens is None

     def test_slice_from_middle(self):
         """Test slicing from the middle position"""
-        sliced = self.logprobsLists.slice(1, 3)
+        sliced = self.logprobsLists.slice_request(1, num_positions=7)
         assert len(sliced.logprob_token_ids) == 7
         assert sliced.logprob_token_ids == [
             [5, 6],
@@ -75,27 +75,25 @@ def test_slice_from_middle(self):
             [15, 16],
             [17, 18],
         ]
-        assert sliced.cu_num_generated_tokens == [0, 3, 7]
+        assert sliced.cu_num_generated_tokens is None

     def test_slice_single_request(self):
         """Test slicing a single request"""
-        sliced = self.logprobsLists.slice(1, 2)
+        sliced = self.logprobsLists.slice_request(1, num_positions=3)
         assert len(sliced.logprob_token_ids) == 3
         assert sliced.logprob_token_ids == [[5, 6], [7, 8], [9, 10]]
-        assert sliced.cu_num_generated_tokens == [0, 3]
+        assert sliced.cu_num_generated_tokens is None

     def test_slice_last_request(self):
         """Test slicing the last request"""
-        sliced = self.logprobsLists.slice(2, 3)
+        sliced = self.logprobsLists.slice_request(2, num_positions=4)
         assert len(sliced.logprob_token_ids) == 4
         assert sliced.logprob_token_ids == [[11, 12], [13, 14], [15, 16], [17, 18]]
-        assert sliced.cu_num_generated_tokens == [0, 4]
+        assert sliced.cu_num_generated_tokens is None

     def test_slice_all_requests(self):
         """Test slicing all requests (full slice)"""
-        sliced = self.logprobsLists.slice(0, 3)
+        sliced = self.logprobsLists.slice_request(0, num_positions=9)
         assert len(sliced.logprob_token_ids) == 9  # All tokens
         assert sliced.logprob_token_ids == self.logprobsLists.logprob_token_ids
-        assert (
-            sliced.cu_num_generated_tokens == self.logprobsLists.cu_num_generated_tokens
-        )
+        assert sliced.cu_num_generated_tokens is None
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 0304a8ec48bf..e3ec8440a932 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -234,11 +234,15 @@ def schedule(self) -> SchedulerOutput:
                 num_new_tokens = self.scheduler_config.long_prefill_token_threshold
             num_new_tokens = min(num_new_tokens, token_budget)

-            # Make sure the input position does not exceed the max model len or
-            # request's max_tokens.
-            # This is necessary when using spec decoding and/or async scheduling.
+            num_spec_placeholders = max(0, request.num_output_placeholders - 1)
             max_total_tokens = min(
-                request.num_prompt_tokens + request.max_tokens, self.max_model_len
+                # Avoid scheduling tokens that we're sure won't be needed based on
+                # request.max_tokens. For this calculation we assume placeholder
+                # speculated output tokens are rejected.
+                request.num_prompt_tokens + request.max_tokens + num_spec_placeholders,
+                # Make sure the input position does not exceed the max model len.
+                # This is necessary when using spec decoding.
+                self.max_model_len,
             )
             num_new_tokens = min(
                 num_new_tokens, max_total_tokens - 1 - request.num_computed_tokens
@@ -1089,7 +1093,7 @@ def update_from_output(
                 and request.sampling_params.logprobs is not None
                 and logprobs
             ):
-                new_logprobs = logprobs.slice(req_index, req_index + 1)
+                new_logprobs = logprobs.slice_request(req_index, len(new_token_ids))

             if new_token_ids and self.structured_output_manager.should_advance(request):
                 struct_output_request = request.structured_output_request
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index e32d5bb608b1..8110deb5a610 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -29,27 +29,15 @@ class LogprobsLists(NamedTuple):
     # different for each request.
     cu_num_generated_tokens: list[int] | None = None

-    def slice(self, start_req_idx: int, end_req_idx: int):
-        if self.cu_num_generated_tokens:
-            start = self.cu_num_generated_tokens[start_req_idx]
-            end = self.cu_num_generated_tokens[end_req_idx]
-            # Recompute cumulative array starting from 0
-            cu_num_offset = self.cu_num_generated_tokens[start_req_idx]
-            sliced_cu_num_generated_tokens = [
-                cu_num - cu_num_offset
-                for cu_num in self.cu_num_generated_tokens[
-                    start_req_idx : end_req_idx + 1
-                ]
-            ]
-        else:
-            start = start_req_idx
-            end = end_req_idx
-            sliced_cu_num_generated_tokens = None
+    def slice_request(self, req_idx: int, num_positions: int):
+        if self.cu_num_generated_tokens is not None:
+            req_idx = self.cu_num_generated_tokens[req_idx]
+        end_idx = req_idx + num_positions
         return LogprobsLists(
-            self.logprob_token_ids[start:end],
-            self.logprobs[start:end],
-            self.sampled_token_ranks[start:end],
-            sliced_cu_num_generated_tokens,
+            self.logprob_token_ids[req_idx:end_idx],
+            self.logprobs[req_idx:end_idx],
+            self.sampled_token_ranks[req_idx:end_idx],
+            None,
         )
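
Note on the new slicing API (not part of the patch): the sketch below copies the slice_request body added in vllm/v1/outputs.py into a standalone class so it can run outside vLLM; the sample batch values are made up for illustration. cu_num_generated_tokens maps a request index to that request's offset in the flattened per-token lists, and the caller now supplies how many positions belong to the request, so the sliced result no longer carries a cumulative array of its own.

from typing import NamedTuple


class LogprobsLists(NamedTuple):
    # Per-position data, flattened across all requests in the batch.
    logprob_token_ids: list[list[int]]
    logprobs: list[list[float]]
    sampled_token_ranks: list[int]
    # Cumulative token counts per request; None when there is exactly one
    # position per request.
    cu_num_generated_tokens: list[int] | None = None

    def slice_request(self, req_idx: int, num_positions: int):
        # Same body as the + lines above: translate the request index into a
        # flat offset, then take num_positions entries from that offset.
        if self.cu_num_generated_tokens is not None:
            req_idx = self.cu_num_generated_tokens[req_idx]
        end_idx = req_idx + num_positions
        return LogprobsLists(
            self.logprob_token_ids[req_idx:end_idx],
            self.logprobs[req_idx:end_idx],
            self.sampled_token_ranks[req_idx:end_idx],
            None,
        )


# Illustrative batch: three requests that generated 2, 3, and 1 tokens.
lists = LogprobsLists(
    logprob_token_ids=[[1], [2], [3], [4], [5], [6]],
    logprobs=[[0.1], [0.2], [0.3], [0.4], [0.5], [0.6]],
    sampled_token_ranks=[1, 2, 3, 4, 5, 6],
    cu_num_generated_tokens=[0, 2, 5, 6],
)

# Request 1 starts at flat offset cu[1] == 2 and produced 3 tokens.
sliced = lists.slice_request(1, num_positions=3)
assert sliced.logprob_token_ids == [[3], [4], [5]]
assert sliced.sampled_token_ranks == [3, 4, 5]
assert sliced.cu_num_generated_tokens is None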