@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import time
-from collections import deque
+from collections import defaultdict, deque
 from collections.abc import Iterable
 from typing import Optional, Union
 
@@ -88,9 +88,6 @@ def __init__(
         # Priority queues for requests.
         self.waiting: deque[Request] = deque()
         self.running: list[Request] = []
-        # The requests that have been scheduled and are being executed
-        # by the executor.
-        self.scheduled_req_ids: set[str] = set()
 
         # The request IDs that are finished in between the previous and the
         # current steps. This is used to notify the workers about the finished
@@ -100,8 +97,9 @@ def __init__(
 
         # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
         # them at each scheduling step.
-        # Request id -> CachedRequestData
-        self._cached_reqs_data: dict[str, CachedRequestData] = {}
+        # Request id -> deque of CachedRequestData
+        self._cached_reqs_data: dict[
+            str, deque[CachedRequestData]] = defaultdict(deque)
 
         # Encoder-related.
         # Calculate encoder cache size if applicable
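The old cache could hold at most one `CachedRequestData` per request, which assumed a request is never part of more than one in-flight step. Switching the value type to a deque drops that assumption. A tiny illustration of the difference, with plain ints standing in for the cached objects (not vLLM's real types):

```python
from collections import defaultdict, deque

# Old shape: one slot per request id, a second entry overwrites the first.
single: dict[str, int] = {}
single["req-0"] = 1
single["req-0"] = 2
assert single["req-0"] == 2          # object 1 is lost

# New shape: a FIFO per request id, both entries stay available for reuse.
pooled: dict[str, deque[int]] = defaultdict(deque)
pooled["req-0"].append(1)
pooled["req-0"].append(2)
assert list(pooled["req-0"]) == [1, 2]
```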
@@ -171,10 +169,6 @@ def schedule(self) -> SchedulerOutput:
         req_index = 0
         while req_index < len(self.running) and token_budget > 0:
             request = self.running[req_index]
-            if request.request_id in self.scheduled_req_ids:
-                # This request has already been scheduled.
-                req_index += 1
-                continue
 
             num_new_tokens = (request.num_tokens_with_spec -
                               request.num_computed_tokens)
@@ -183,33 +177,35 @@ def schedule(self) -> SchedulerOutput:
                 num_new_tokens = (
                     self.scheduler_config.long_prefill_token_threshold)
             num_new_tokens = min(num_new_tokens, token_budget)
-            assert num_new_tokens > 0
 
             # Make sure the input position does not exceed the max model len.
             # This is necessary when using spec decoding.
             num_new_tokens = min(
                 num_new_tokens,
                 self.max_model_len - request.num_computed_tokens)
-            assert num_new_tokens > 0
 
             # Schedule encoder inputs.
+            encoder_inputs_to_schedule = None
+            new_encoder_budget = encoder_budget
             if request.has_encoder_inputs:
                 (encoder_inputs_to_schedule, num_new_tokens,
                  new_encoder_budget) = self._try_schedule_encoder_inputs(
                      request, request.num_computed_tokens, num_new_tokens,
                      encoder_budget)
-                if num_new_tokens == 0:
-                    # The request cannot be scheduled because the encoder budget
-                    # or the encoder cache is exhausted.
-                    # NOTE(woosuk): By using `continue` instead of `break` here,
-                    # we intentionally relax the strict FCFS scheduling policy
-                    # to allow lower-priority requests to be scheduled when a
-                    # higher-priority request is blocked by encoder constraints.
-                    req_index += 1
-                    continue
-            else:
-                encoder_inputs_to_schedule = None
-                new_encoder_budget = encoder_budget
+
+            if num_new_tokens == 0:
+                # The request cannot be scheduled because one of the following
+                # reasons:
+                # 1. No new tokens to schedule. This may happen when PP>1 and
+                #    we have already scheduled all prompt tokens but they are
+                #    not finished yet.
+                # 2. The encoder budget is exhausted.
+                # 3. The encoder cache is exhausted.
+                # NOTE(woosuk): Here, by doing `continue` instead of `break`,
+                # we do not strictly follow the FCFS scheduling policy and
+                # allow the lower-priority requests to be scheduled.
+                req_index += 1
+                continue
 
             while True:
                 new_blocks = self.kv_cache_manager.allocate_slots(
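Case 1 in the new comment is the one introduced by letting a request stay schedulable while an earlier step is still in flight: once every prompt token has been handed out, the difference computed above is zero even though the request is not finished. A small worked example with made-up numbers:

```python
# Hypothetical numbers, only to illustrate case 1 of the comment above:
# all prompt tokens were handed to an earlier step that has not reported
# back yet (PP > 1 / more than one step in flight).
num_tokens_with_spec = 512   # prompt plus speculative tokens known so far
num_computed_tokens = 512    # tokens the scheduler has already accounted for

num_new_tokens = num_tokens_with_spec - num_computed_tokens
assert num_new_tokens == 0   # nothing left to schedule; skip via `continue`
```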
@@ -243,7 +239,6 @@ def schedule(self) -> SchedulerOutput:
 
             # Schedule the request.
             scheduled_running_reqs.append(request)
-            self.scheduled_req_ids.add(request.request_id)
             if request.use_structured_output:
                 # PERF: in case of chunked prefill,
                 # request might not include any new tokens.
@@ -382,7 +377,6 @@ def schedule(self) -> SchedulerOutput:
                         request.request_id] = req_index
                 req_index += 1
                 self.running.append(request)
-                self.scheduled_req_ids.add(request.request_id)
                 if self.log_stats:
                     request.record_event(EngineCoreEventType.SCHEDULED,
                                          scheduled_timestamp)
@@ -521,18 +515,21 @@ def _make_cached_request_data(
         num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
         new_token_ids = request.all_token_ids[
             num_computed_tokens:num_computed_tokens + num_regular_tokens]
-        req_data = self._cached_reqs_data.get(request.request_id)
-        if req_data is not None:
+
+        req_data_queue = self._cached_reqs_data.get(request.request_id)
+        if req_data_queue:
+            req_data = req_data_queue.popleft()
             req_data.resumed_from_preemption = resumed_from_preemption
             req_data.new_token_ids = new_token_ids
             req_data.new_block_ids = new_block_ids
             req_data.num_computed_tokens = num_computed_tokens
         else:
+            # No cached request data, or all cached request data has been
+            # used by the scheduled requests.
             req_data = CachedRequestData.from_request(request,
                                                       resumed_from_preemption,
                                                       new_token_ids,
                                                       new_block_ids)
-            self._cached_reqs_data[request.request_id] = req_data
         return req_data
 
     def _try_schedule_encoder_inputs(
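A small point about the new lookup: `.get()` keeps the probe free of side effects (plain indexing on a `defaultdict` would insert an empty deque for every miss), and the `if req_data_queue:` truthiness check covers both a missing key (`None`) and a present-but-empty deque. A short demonstration of both behaviours, with illustrative key names:

```python
from collections import defaultdict, deque

cache: dict[str, deque[int]] = defaultdict(deque)

# .get() probes without mutating: the key stays absent after the lookup.
assert cache.get("req-A") is None
assert "req-A" not in cache

# Plain indexing would create an empty deque as a side effect of the miss.
_ = cache["req-B"]
assert "req-B" in cache and not cache["req-B"]

# Truthiness treats "no entry" and "empty deque" the same way.
for probe in (cache.get("req-A"), cache.get("req-B")):
    assert not probe
```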
@@ -561,6 +558,8 @@ def _try_schedule_encoder_inputs(
         Note that num_computed_tokens includes both locally cached
         blocks and externally cached blocks (via KVConnector).
         """
+        if num_new_tokens == 0 or not request.has_encoder_inputs:
+            return [], num_new_tokens, encoder_budget
         encoder_inputs_to_schedule: list[int] = []
         mm_positions = request.mm_positions
         assert mm_positions is not None
@@ -728,10 +727,13 @@ def update_from_output(
                 # Invariant: EngineCore returns no partial prefill outputs.
                 assert not prompt_logprobs_tensors
 
-            self.scheduled_req_ids.remove(req_id)
             if not stopped:
                 new_running.append(request)
 
+        # Return the cached request data to the queue so they can be reused.
+        for req_data in scheduler_output.scheduled_cached_reqs:
+            self._cached_reqs_data[req_data.req_id].append(req_data)
+
         self.running = new_running
         engine_core_outputs = EngineCoreOutputs(
             outputs=outputs,
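Taken together with `_make_cached_request_data`, each cached object now follows a check-out/check-in cycle: popped from its per-request deque when a step is built, and appended back here once that step's output has been processed. A self-contained sketch of the cycle, using a stand-in `FakeCachedData` class and hypothetical `build_step`/`finish_step` helpers rather than the scheduler's real methods:

```python
from collections import defaultdict, deque
from dataclasses import dataclass


@dataclass
class FakeCachedData:                 # stand-in for CachedRequestData
    req_id: str


_pool: dict[str, deque[FakeCachedData]] = defaultdict(deque)


def build_step(req_id: str) -> FakeCachedData:
    """Consumer side, mirroring _make_cached_request_data."""
    queue = _pool.get(req_id)
    if queue:
        return queue.popleft()        # reuse a returned object
    return FakeCachedData(req_id)     # pool empty: allocate a fresh one


def finish_step(scheduled: list[FakeCachedData]) -> None:
    """Producer side, mirroring the loop added in update_from_output."""
    for data in scheduled:
        _pool[data.req_id].append(data)


# Two steps are built before either finishes (e.g. pipeline parallelism),
# so two distinct objects exist for the same request...
step1 = [build_step("req-0")]
step2 = [build_step("req-0")]
assert step1[0] is not step2[0]

# ...and once both steps report back, later steps reuse them in FIFO order.
finish_step(step1)
finish_step(step2)
assert build_step("req-0") is step1[0]
```

In steady state the pool grows to the maximum number of concurrently in-flight steps per request and then stops allocating, which is the point of the optimization comment in `__init__`.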
@@ -774,7 +776,6 @@ def finish_requests(
 
             if request.status == RequestStatus.RUNNING:
                 self.running.remove(request)
-                self.scheduled_req_ids.discard(request.request_id)
             else:
                 self.waiting.remove(request)
             request.status = finished_status
@@ -795,10 +796,6 @@ def get_num_unfinished_requests(self) -> int:
     def has_finished_requests(self) -> bool:
         return len(self.finished_req_ids) > 0
 
-    def get_num_unscheduled_requests(self) -> int:
-        """Number of requests that are not being processed by the executor."""
-        return self.get_num_unfinished_requests() - len(self.scheduled_req_ids)
-
     def reset_prefix_cache(self) -> bool:
         return self.kv_cache_manager.reset_prefix_cache()
 