
Commit 5cda748

ruisearch42 authored and liuzijing2014 committed
[V1][PP] Optimization: continue scheduling prefill chunks (vllm-project#17080)
Signed-off-by: Rui Qiao <[email protected]>
1 parent 2b7679c commit 5cda748

File tree: 5 files changed, +128 -74 lines
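The gist of the change: the V1 scheduler used to track in-flight requests in a `scheduled_req_ids` set and skipped any request already in that set, so with pipeline parallelism the next prefill chunk of a long prompt could not be scheduled until the batch carrying the previous chunk had finished. The set is removed; the scheduler instead relies on `num_computed_tokens`, which is advanced as soon as a chunk is scheduled, and skips a request only when it has no new tokens left to schedule. The following is a minimal sketch of that idea only, not the vLLM code; `Req` and `schedule_running` are made-up names.

```python
from dataclasses import dataclass


@dataclass
class Req:
    """Hypothetical stand-in for vllm.v1.request.Request."""
    num_tokens_with_spec: int
    num_computed_tokens: int = 0


def schedule_running(running: list[Req], token_budget: int) -> list[tuple[Req, int]]:
    scheduled = []
    for request in running:
        # Remaining tokens, counting chunks already handed to the executor.
        num_new_tokens = min(
            request.num_tokens_with_spec - request.num_computed_tokens,
            token_budget)
        if num_new_tokens == 0:
            # Nothing left to schedule right now, e.g. PP>1 and every prompt
            # token is already in flight but not yet finished.
            continue
        scheduled.append((request, num_new_tokens))
        # Advance immediately, so the next scheduling pass can pick up the
        # following prefill chunk before this batch has finished executing.
        request.num_computed_tokens += num_new_tokens
        token_budget -= num_new_tokens
    return scheduled


# Two scheduling passes over one 10-token prompt with a budget of 6 tokens:
req = Req(num_tokens_with_spec=10)
print(schedule_running([req], 6))  # chunk 1: 6 tokens
print(schedule_running([req], 6))  # chunk 2: 4 tokens, even while chunk 1 is in flight
```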

tests/v1/core/test_scheduler.py

Lines changed: 0 additions & 4 deletions
@@ -437,7 +437,6 @@ def test_stop_via_update_from_output():
     req.num_computed_tokens = req.num_tokens
     scheduler.requests[req.request_id] = req
     scheduler.running.append(req)
-    scheduler.scheduled_req_ids.add(req.request_id)

     scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                        scheduled_cached_reqs=[],
@@ -489,7 +488,6 @@ def test_stop_via_update_from_output():
     req.num_computed_tokens = req.num_tokens
     scheduler.requests[req.request_id] = req
     scheduler.running.append(req)
-    scheduler.scheduled_req_ids.add(req.request_id)

     scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                        scheduled_cached_reqs=[],
@@ -539,7 +537,6 @@ def test_stop_via_update_from_output():
     req.num_computed_tokens = req.num_tokens
     scheduler.requests[req.request_id] = req
     scheduler.running.append(req)
-    scheduler.scheduled_req_ids.add(req.request_id)

     scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                        scheduled_cached_reqs=[],
@@ -589,7 +586,6 @@ def test_stop_via_update_from_output():
     requests[0].num_computed_tokens = requests[0].num_tokens
     scheduler.requests[requests[0].request_id] = requests[0]
     scheduler.running.append(requests[0])
-    scheduler.scheduled_req_ids.add(requests[0].request_id)

     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],

tests/v1/engine/test_engine_core.py

Lines changed: 84 additions & 22 deletions
@@ -1,10 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0

 import copy
-import threading
 import time
 import uuid
-from concurrent.futures import Future
+from concurrent.futures import Future, ThreadPoolExecutor

 import pytest
 from transformers import AutoTokenizer
@@ -244,33 +243,33 @@ def initialize_from_config(
                 self, kv_cache_configs: list[KVCacheConfig]) -> None:
             super().initialize_from_config(kv_cache_configs)

-            # This executor actually can only run 1 batch at a time
-            self.semaphore = threading.Semaphore(1)
+            # Create a thread pool with a single worker
+            self.thread_pool = ThreadPoolExecutor(max_workers=1)

         def execute_model(
             self,
             scheduler_output,
         ) -> Future[ModelRunnerOutput]:
             """Make execute_model non-blocking."""
-            future: Future[ModelRunnerOutput] = Future()

-            def _thread_wrapper(scheduler_output, future):
-                with self.semaphore:
-                    output = self.collective_rpc("execute_model",
-                                                 args=(scheduler_output, ))
-                    # Make a copy because output[0] may be reused
-                    # by the next batch.
-                    output = copy.deepcopy(output[0])
-                    future.set_result(output)
+            def _execute():
+                output = self.collective_rpc("execute_model",
+                                             args=(scheduler_output, ))
+                # Make a copy because output[0] may be reused
+                # by the next batch.
+                return copy.deepcopy(output[0])

-            threading.Thread(target=_thread_wrapper,
-                             args=(scheduler_output, future)).start()
-            return future
+            # Use the thread pool instead of creating a new thread
+            return self.thread_pool.submit(_execute)

         @property
         def max_concurrent_batches(self) -> int:
             return 2

+        def shutdown(self):
+            if hasattr(self, 'thread_pool'):
+                self.thread_pool.shutdown(wait=False)
+
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")

@@ -299,14 +298,77 @@ def max_concurrent_batches(self) -> int:
     # Schedule Batch 1: (10, req0)
     assert engine_core.step_with_batch_queue() is None
     assert engine_core.batch_queue.qsize() == 1
+    scheduler_output = engine_core.batch_queue.queue[-1][1]
+    assert scheduler_output.num_scheduled_tokens[0] == 10
+    # num_computed_tokens should have been updated immediately.
+    assert engine_core.scheduler.requests[
+        req0.request_id].num_computed_tokens == 10
+
+    # Schedule Batch 2: (2, req0), (8, req1)
     assert engine_core.step_with_batch_queue() is None
     assert engine_core.batch_queue.qsize() == 2
+    scheduler_output = engine_core.batch_queue.queue[-1][1]
+    assert scheduler_output.num_scheduled_tokens[0] == 2
+    assert scheduler_output.num_scheduled_tokens[1] == 8
+    # num_computed_tokens should have been updated immediately.
+    assert engine_core.scheduler.requests[0].num_computed_tokens == 12
+    assert engine_core.scheduler.requests[1].num_computed_tokens == 8
+
     assert engine_core.scheduler.get_num_unfinished_requests() == 2

-    # Loop through both requests.
-    while engine_core.scheduler.get_num_unfinished_requests() == 2:
-        engine_core.step_with_batch_queue()
+    # Batch queue is full. Finish Batch 1.
+    engine_core.step_with_batch_queue()
+
+    # Schedule Batch 3: (4, req1). Note that req0 cannot be scheduled
+    # because it is in the decoding stage now.
+    engine_core.step_with_batch_queue()
+    assert engine_core.batch_queue.qsize() == 2
+    scheduler_output = engine_core.batch_queue.queue[-1][1]
+    assert scheduler_output.num_scheduled_tokens[1] == 4

-    # Reaching here when got the result of the first request.
-    while engine_core.scheduler.get_num_unfinished_requests() == 1:
-        engine_core.step_with_batch_queue()
+    # Batch queue is full. Finish Batch 2. Get first token of req0.
+    output = engine_core.step_with_batch_queue()
+    assert output is not None
+    assert len(output.outputs) == 1
+    assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13
+
+    # Schedule Batch 4: (1, req0).
+    engine_core.step_with_batch_queue()
+    assert engine_core.batch_queue.qsize() == 2
+    scheduler_output = engine_core.batch_queue.queue[-1][1]
+    assert scheduler_output.num_scheduled_tokens[0] == 1
+
+    # Batch queue is full. Finish Batch 3. Get first token of req1.
+    output = engine_core.step_with_batch_queue()
+    assert output is not None
+    assert len(output.outputs) == 1
+    assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13
+
+    # Schedule Batch 5: (1, req1).
+    engine_core.step_with_batch_queue()
+    assert engine_core.batch_queue.qsize() == 2
+    scheduler_output = engine_core.batch_queue.queue[-1][1]
+    assert scheduler_output.num_scheduled_tokens[1] == 1
+
+    # Loop until req0 is finished.
+    step = 0
+    req_id = 0
+    expected_num_tokens = [
+        engine_core.scheduler.requests[0].num_tokens + 1,
+        engine_core.scheduler.requests[1].num_tokens + 1,
+    ]
+    while engine_core.scheduler.get_num_unfinished_requests() == 2:
+        output = engine_core.step_with_batch_queue()
+        if step % 2 == 0:
+            # Even steps consumes an output.
+            assert output is not None
+            assert len(output.outputs) == 1
+            if req_id in engine_core.scheduler.requests:
+                assert engine_core.scheduler.requests[
+                    req_id].num_tokens == expected_num_tokens[req_id]
+            expected_num_tokens[req_id] += 1
+            req_id = (req_id + 1) % 2
+        else:
+            # Odd steps schedules a new batch.
+            assert output is None
+        step += 1
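For context, the test's dummy executor makes `execute_model` non-blocking by submitting work to a single-worker `ThreadPoolExecutor` and returning the resulting `Future`, which is what lets the engine keep a second batch queued (`max_concurrent_batches == 2`) while the first is still running. Below is a standalone sketch of that pattern only, with the engine and model-runner details replaced by placeholders; `NonBlockingExecutor` and `_run_batch` are hypothetical names, not vLLM APIs.

```python
import copy
import time
from concurrent.futures import Future, ThreadPoolExecutor


class NonBlockingExecutor:
    """Toy executor: batches run one at a time, but callers are not blocked."""

    def __init__(self) -> None:
        # One worker thread, so at most one batch executes at any moment.
        self._pool = ThreadPoolExecutor(max_workers=1)

    def _run_batch(self, scheduler_output: dict) -> dict:
        time.sleep(0.01)  # stand-in for the real forward pass
        output = {"num_tokens": scheduler_output["num_scheduled_tokens"]}
        # Copy, because a real model runner may reuse its output buffers.
        return copy.deepcopy(output)

    def execute_model(self, scheduler_output: dict) -> Future:
        # Returning a Future lets the caller schedule the next batch
        # (e.g. the next prefill chunk) while this one is still running.
        return self._pool.submit(self._run_batch, scheduler_output)

    def shutdown(self) -> None:
        self._pool.shutdown(wait=False)


executor = NonBlockingExecutor()
f1 = executor.execute_model({"num_scheduled_tokens": {0: 10}})
f2 = executor.execute_model({"num_scheduled_tokens": {0: 2, 1: 8}})  # queued behind f1
print(f1.result(), f2.result())
executor.shutdown()
```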

vllm/v1/core/sched/interface.py

Lines changed: 0 additions & 5 deletions
@@ -117,11 +117,6 @@ def has_requests(self) -> bool:
         not yet returned in SchedulerOutputs."""
         return self.has_unfinished_requests() or self.has_finished_requests()

-    @abstractmethod
-    def get_num_unscheduled_requests(self) -> int:
-        """Number of requests that are not being processed by the executor."""
-        raise NotImplementedError
-
     @abstractmethod
     def reset_prefix_cache(self) -> bool:
         """Reset the prefix cache for KV cache.

vllm/v1/core/sched/scheduler.py

Lines changed: 32 additions & 35 deletions
@@ -3,7 +3,7 @@
 from __future__ import annotations

 import time
-from collections import deque
+from collections import defaultdict, deque
 from collections.abc import Iterable
 from typing import Optional, Union

@@ -88,9 +88,6 @@ def __init__(
         # Priority queues for requests.
         self.waiting: deque[Request] = deque()
         self.running: list[Request] = []
-        # The requests that have been scheduled and are being executed
-        # by the executor.
-        self.scheduled_req_ids: set[str] = set()

         # The request IDs that are finished in between the previous and the
         # current steps. This is used to notify the workers about the finished
@@ -100,8 +97,9 @@ def __init__(

         # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
         # them at each scheduling step.
-        # Request id -> CachedRequestData
-        self._cached_reqs_data: dict[str, CachedRequestData] = {}
+        # Request id -> deque of CachedRequestData
+        self._cached_reqs_data: dict[
+            str, deque[CachedRequestData]] = defaultdict(deque)

         # Encoder-related.
         # Calculate encoder cache size if applicable
@@ -171,10 +169,6 @@ def schedule(self) -> SchedulerOutput:
         req_index = 0
         while req_index < len(self.running) and token_budget > 0:
             request = self.running[req_index]
-            if request.request_id in self.scheduled_req_ids:
-                # This request has already been scheduled.
-                req_index += 1
-                continue

             num_new_tokens = (request.num_tokens_with_spec -
                               request.num_computed_tokens)
@@ -183,33 +177,35 @@
                 num_new_tokens = (
                     self.scheduler_config.long_prefill_token_threshold)
             num_new_tokens = min(num_new_tokens, token_budget)
-            assert num_new_tokens > 0

             # Make sure the input position does not exceed the max model len.
             # This is necessary when using spec decoding.
             num_new_tokens = min(
                 num_new_tokens,
                 self.max_model_len - request.num_computed_tokens)
-            assert num_new_tokens > 0

             # Schedule encoder inputs.
+            encoder_inputs_to_schedule = None
+            new_encoder_budget = encoder_budget
             if request.has_encoder_inputs:
                 (encoder_inputs_to_schedule, num_new_tokens,
                  new_encoder_budget) = self._try_schedule_encoder_inputs(
                      request, request.num_computed_tokens, num_new_tokens,
                      encoder_budget)
-                if num_new_tokens == 0:
-                    # The request cannot be scheduled because the encoder budget
-                    # or the encoder cache is exhausted.
-                    # NOTE(woosuk): By using `continue` instead of `break` here,
-                    # we intentionally relax the strict FCFS scheduling policy
-                    # to allow lower-priority requests to be scheduled when a
-                    # higher-priority request is blocked by encoder constraints.
-                    req_index += 1
-                    continue
-            else:
-                encoder_inputs_to_schedule = None
-                new_encoder_budget = encoder_budget
+
+            if num_new_tokens == 0:
+                # The request cannot be scheduled because one of the following
+                # reasons:
+                # 1. No new tokens to schedule. This may happen when PP>1 and
+                #    we have already scheduled all prompt tokens but they are
+                #    not finished yet.
+                # 2. The encoder budget is exhausted.
+                # 3. The encoder cache is exhausted.
+                # NOTE(woosuk): Here, by doing `continue` instead of `break`,
+                # we do not strictly follow the FCFS scheduling policy and
+                # allow the lower-priority requests to be scheduled.
+                req_index += 1
+                continue

             while True:
                 new_blocks = self.kv_cache_manager.allocate_slots(
@@ -243,7 +239,6 @@ def schedule(self) -> SchedulerOutput:

             # Schedule the request.
             scheduled_running_reqs.append(request)
-            self.scheduled_req_ids.add(request.request_id)
             if request.use_structured_output:
                 # PERF: in case of chunked prefill,
                 # request might not include any new tokens.
@@ -382,7 +377,6 @@ def schedule(self) -> SchedulerOutput:
                     request.request_id] = req_index
                 req_index += 1
             self.running.append(request)
-            self.scheduled_req_ids.add(request.request_id)
             if self.log_stats:
                 request.record_event(EngineCoreEventType.SCHEDULED,
                                      scheduled_timestamp)
@@ -521,18 +515,21 @@ def _make_cached_request_data(
         num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
         new_token_ids = request.all_token_ids[
             num_computed_tokens:num_computed_tokens + num_regular_tokens]
-        req_data = self._cached_reqs_data.get(request.request_id)
-        if req_data is not None:
+
+        req_data_queue = self._cached_reqs_data.get(request.request_id)
+        if req_data_queue:
+            req_data = req_data_queue.popleft()
             req_data.resumed_from_preemption = resumed_from_preemption
             req_data.new_token_ids = new_token_ids
             req_data.new_block_ids = new_block_ids
             req_data.num_computed_tokens = num_computed_tokens
         else:
+            # No cached request data, or all cached request data has been
+            # used by the scheduled requests.
             req_data = CachedRequestData.from_request(request,
                                                       resumed_from_preemption,
                                                       new_token_ids,
                                                       new_block_ids)
-            self._cached_reqs_data[request.request_id] = req_data
         return req_data

     def _try_schedule_encoder_inputs(
@@ -561,6 +558,8 @@ def _try_schedule_encoder_inputs(
         Note that num_computed_tokens includes both locally cached
         blocks and externally cached blocks (via KVConnector).
         """
+        if num_new_tokens == 0 or not request.has_encoder_inputs:
+            return [], num_new_tokens, encoder_budget
         encoder_inputs_to_schedule: list[int] = []
         mm_positions = request.mm_positions
         assert mm_positions is not None
@@ -728,10 +727,13 @@ def update_from_output(
                 # Invariant: EngineCore returns no partial prefill outputs.
                 assert not prompt_logprobs_tensors

-            self.scheduled_req_ids.remove(req_id)
             if not stopped:
                 new_running.append(request)

+        # Return the cached request data to the queue so they can be reused.
+        for req_data in scheduler_output.scheduled_cached_reqs:
+            self._cached_reqs_data[req_data.req_id].append(req_data)
+
         self.running = new_running
         engine_core_outputs = EngineCoreOutputs(
             outputs=outputs,
@@ -774,7 +776,6 @@ def finish_requests(

         if request.status == RequestStatus.RUNNING:
             self.running.remove(request)
-            self.scheduled_req_ids.discard(request.request_id)
         else:
             self.waiting.remove(request)
         request.status = finished_status
@@ -795,10 +796,6 @@ def get_num_unfinished_requests(self) -> int:
     def has_finished_requests(self) -> bool:
         return len(self.finished_req_ids) > 0

-    def get_num_unscheduled_requests(self) -> int:
-        """Number of requests that are not being processed by the executor."""
-        return self.get_num_unfinished_requests() - len(self.scheduled_req_ids)
-
     def reset_prefix_cache(self) -> bool:
         return self.kv_cache_manager.reset_prefix_cache()

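Because the same request can now appear in several in-flight batches, a single cached `CachedRequestData` per request id is no longer safe to reuse, so the scheduler keeps a per-request deque: `_make_cached_request_data` pops an entry when one is available, and `update_from_output` appends the objects from the finished batch back. The following is a minimal sketch of that object-pool pattern under simplified assumptions; `CachedData` and `CachedDataPool` are stand-in names, not vLLM classes.

```python
from collections import defaultdict, deque
from dataclasses import dataclass, field


@dataclass
class CachedData:
    """Simplified stand-in for CachedRequestData."""
    req_id: str
    new_token_ids: list[int] = field(default_factory=list)


class CachedDataPool:
    def __init__(self) -> None:
        # Request id -> deque of reusable objects. With pipeline parallelism
        # the same request can sit in several in-flight batches at once, so
        # one cached object per request id is no longer enough.
        self._cache: dict[str, deque[CachedData]] = defaultdict(deque)

    def make(self, req_id: str, new_token_ids: list[int]) -> CachedData:
        queue = self._cache.get(req_id)
        if queue:
            data = queue.popleft()          # reuse an object returned earlier
            data.new_token_ids = new_token_ids
            return data
        # Every cached object for this request is still owned by some batch.
        return CachedData(req_id, new_token_ids)

    def give_back(self, data: CachedData) -> None:
        # Called once the batch that used `data` has finished executing.
        self._cache[data.req_id].append(data)


pool = CachedDataPool()
a = pool.make("req0", [1, 2])   # fresh object for the first prefill chunk
b = pool.make("req0", [3, 4])   # second in-flight chunk gets its own object
pool.give_back(a)               # the first batch finished
c = pool.make("req0", [5])      # reuses `a`
print(c is a, b is not a)       # True True
```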