[V1][PP] Optimization: continue scheduling prefill chunks #17080

Merged · 2 commits · Apr 24, 2025

4 changes: 0 additions & 4 deletions tests/v1/core/test_scheduler.py
@@ -437,7 +437,6 @@ def test_stop_via_update_from_output():
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
scheduler.scheduled_req_ids.add(req.request_id)

scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
@@ -489,7 +488,6 @@ def test_stop_via_update_from_output():
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
scheduler.scheduled_req_ids.add(req.request_id)

scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
@@ -539,7 +537,6 @@ def test_stop_via_update_from_output():
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
scheduler.scheduled_req_ids.add(req.request_id)

scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
@@ -589,7 +586,6 @@ def test_stop_via_update_from_output():
requests[0].num_computed_tokens = requests[0].num_tokens
scheduler.requests[requests[0].request_id] = requests[0]
scheduler.running.append(requests[0])
scheduler.scheduled_req_ids.add(requests[0].request_id)

scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
106 changes: 84 additions & 22 deletions tests/v1/engine/test_engine_core.py
@@ -1,10 +1,9 @@
# SPDX-License-Identifier: Apache-2.0

import copy
import threading
import time
import uuid
from concurrent.futures import Future
from concurrent.futures import Future, ThreadPoolExecutor

import pytest
from transformers import AutoTokenizer
@@ -244,33 +243,33 @@ def initialize_from_config(
self, kv_cache_configs: list[KVCacheConfig]) -> None:
super().initialize_from_config(kv_cache_configs)

# This executor actually can only run 1 batch at a time
self.semaphore = threading.Semaphore(1)
# Create a thread pool with a single worker
self.thread_pool = ThreadPoolExecutor(max_workers=1)

def execute_model(
self,
scheduler_output,
) -> Future[ModelRunnerOutput]:
"""Make execute_model non-blocking."""
future: Future[ModelRunnerOutput] = Future()

def _thread_wrapper(scheduler_output, future):
with self.semaphore:
output = self.collective_rpc("execute_model",
args=(scheduler_output, ))
# Make a copy because output[0] may be reused
# by the next batch.
output = copy.deepcopy(output[0])
future.set_result(output)
def _execute():
output = self.collective_rpc("execute_model",
args=(scheduler_output, ))
# Make a copy because output[0] may be reused
# by the next batch.
return copy.deepcopy(output[0])

threading.Thread(target=_thread_wrapper,
args=(scheduler_output, future)).start()
return future
# Use the thread pool instead of creating a new thread
return self.thread_pool.submit(_execute)

@property
def max_concurrent_batches(self) -> int:
return 2

def shutdown(self):
if hasattr(self, 'thread_pool'):
self.thread_pool.shutdown(wait=False)

with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

@@ -299,14 +298,77 @@ def max_concurrent_batches(self) -> int:
# Schedule Batch 1: (10, req0)
assert engine_core.step_with_batch_queue() is None
assert engine_core.batch_queue.qsize() == 1
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert scheduler_output.num_scheduled_tokens[0] == 10
# num_computed_tokens should have been updated immediately.
assert engine_core.scheduler.requests[
req0.request_id].num_computed_tokens == 10

# Schedule Batch 2: (2, req0), (8, req1)
assert engine_core.step_with_batch_queue() is None
assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert scheduler_output.num_scheduled_tokens[0] == 2
assert scheduler_output.num_scheduled_tokens[1] == 8
# num_computed_tokens should have been updated immediately.
assert engine_core.scheduler.requests[0].num_computed_tokens == 12
assert engine_core.scheduler.requests[1].num_computed_tokens == 8

assert engine_core.scheduler.get_num_unfinished_requests() == 2

# Loop through both requests.
while engine_core.scheduler.get_num_unfinished_requests() == 2:
engine_core.step_with_batch_queue()
# Batch queue is full. Finish Batch 1.
engine_core.step_with_batch_queue()

# Schedule Batch 3: (4, req1). Note that req0 cannot be scheduled
# because it is in the decoding stage now.
engine_core.step_with_batch_queue()
assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert scheduler_output.num_scheduled_tokens[1] == 4

# Reaching here when got the result of the first request.
while engine_core.scheduler.get_num_unfinished_requests() == 1:
engine_core.step_with_batch_queue()
# Batch queue is full. Finish Batch 2. Get first token of req0.
output = engine_core.step_with_batch_queue()
assert output is not None
assert len(output.outputs) == 1
assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13

# Schedule Batch 4: (1, req0).
engine_core.step_with_batch_queue()
assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert scheduler_output.num_scheduled_tokens[0] == 1

# Batch queue is full. Finish Batch 3. Get first token of req1.
output = engine_core.step_with_batch_queue()
assert output is not None
assert len(output.outputs) == 1
assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13

# Schedule Batch 5: (1, req1).
engine_core.step_with_batch_queue()
assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert scheduler_output.num_scheduled_tokens[1] == 1

# Loop until req0 is finished.
step = 0
req_id = 0
expected_num_tokens = [
engine_core.scheduler.requests[0].num_tokens + 1,
engine_core.scheduler.requests[1].num_tokens + 1,
]
while engine_core.scheduler.get_num_unfinished_requests() == 2:
output = engine_core.step_with_batch_queue()
if step % 2 == 0:
# Even steps consumes an output.
assert output is not None
assert len(output.outputs) == 1
if req_id in engine_core.scheduler.requests:
assert engine_core.scheduler.requests[
req_id].num_tokens == expected_num_tokens[req_id]
expected_num_tokens[req_id] += 1
req_id = (req_id + 1) % 2
else:
# Odd steps schedules a new batch.
assert output is None
step += 1
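
As an aside for readers tracing the test above: a standalone, illustrative sketch (not vLLM code; run_batch, MAX_IN_FLIGHT, and the driver loop are invented here) of the general pattern a depth-two batch queue produces, where steps alternate between scheduling a new batch and blocking on the oldest in-flight one once the queue is full:

    from collections import deque
    from concurrent.futures import Future, ThreadPoolExecutor
    import time

    pool = ThreadPoolExecutor(max_workers=1)
    batch_queue: deque[Future] = deque()
    MAX_IN_FLIGHT = 2  # mirrors max_concurrent_batches == 2 in the test above
    next_batch = 0

    def run_batch(batch_id: int) -> int:
        time.sleep(0.01)  # stand-in for model execution
        return batch_id

    for step in range(6):
        if len(batch_queue) < MAX_IN_FLIGHT:
            # Room in the queue: schedule another batch without waiting.
            batch_queue.append(pool.submit(run_batch, next_batch))
            next_batch += 1
        else:
            # Queue is full: block on the oldest in-flight batch and
            # consume its output before scheduling anything else.
            finished = batch_queue.popleft().result()
            print(f"consumed output of batch {finished}")

    pool.shutdown()
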
5 changes: 0 additions & 5 deletions vllm/v1/core/sched/interface.py
@@ -117,11 +117,6 @@ def has_requests(self) -> bool:
not yet returned in SchedulerOutputs."""
return self.has_unfinished_requests() or self.has_finished_requests()

@abstractmethod
def get_num_unscheduled_requests(self) -> int:
"""Number of requests that are not being processed by the executor."""
raise NotImplementedError

@abstractmethod
def reset_prefix_cache(self) -> bool:
"""Reset the prefix cache for KV cache.
67 changes: 32 additions & 35 deletions vllm/v1/core/sched/scheduler.py
@@ -3,7 +3,7 @@
from __future__ import annotations

import time
from collections import deque
from collections import defaultdict, deque
from collections.abc import Iterable
from typing import Optional, Union

@@ -88,9 +88,6 @@ def __init__(
# Priority queues for requests.
self.waiting: deque[Request] = deque()
self.running: list[Request] = []
# The requests that have been scheduled and are being executed
# by the executor.
self.scheduled_req_ids: set[str] = set()

# The request IDs that are finished in between the previous and the
# current steps. This is used to notify the workers about the finished
@@ -100,8 +97,9 @@

# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# them at each scheduling step.
# Request id -> CachedRequestData
self._cached_reqs_data: dict[str, CachedRequestData] = {}
# Request id -> deque of CachedRequestData
self._cached_reqs_data: dict[
str, deque[CachedRequestData]] = defaultdict(deque)

# Encoder-related.
# Calculate encoder cache size if applicable
@@ -171,10 +169,6 @@ def schedule(self) -> SchedulerOutput:
req_index = 0
while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index]
if request.request_id in self.scheduled_req_ids:
# This request has already been scheduled.
req_index += 1
continue

num_new_tokens = (request.num_tokens_with_spec -
request.num_computed_tokens)
@@ -183,33 +177,35 @@
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0

# Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding.
num_new_tokens = min(
num_new_tokens,
self.max_model_len - request.num_computed_tokens)
assert num_new_tokens > 0

# Schedule encoder inputs.
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget) = self._try_schedule_encoder_inputs(
request, request.num_computed_tokens, num_new_tokens,
encoder_budget)
if num_new_tokens == 0:
# The request cannot be scheduled because the encoder budget
# or the encoder cache is exhausted.
# NOTE(woosuk): By using `continue` instead of `break` here,
# we intentionally relax the strict FCFS scheduling policy
# to allow lower-priority requests to be scheduled when a
# higher-priority request is blocked by encoder constraints.
req_index += 1
continue
else:
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget

if num_new_tokens == 0:
# The request cannot be scheduled because one of the following
# reasons:
# 1. No new tokens to schedule. This may happen when PP>1 and
# we have already scheduled all prompt tokens but they are
# not finished yet.
# 2. The encoder budget is exhausted.
# 3. The encoder cache is exhausted.
# NOTE(woosuk): Here, by doing `continue` instead of `break`,
# we do not strictly follow the FCFS scheduling policy and
# allow the lower-priority requests to be scheduled.
req_index += 1
continue

while True:
new_blocks = self.kv_cache_manager.allocate_slots(
@@ -243,7 +239,6 @@

# Schedule the request.
scheduled_running_reqs.append(request)
self.scheduled_req_ids.add(request.request_id)
if request.use_structured_output:
# PERF: in case of chunked prefill,
# request might not include any new tokens.
@@ -382,7 +377,6 @@ def schedule(self) -> SchedulerOutput:
request.request_id] = req_index
req_index += 1
self.running.append(request)
self.scheduled_req_ids.add(request.request_id)
if self.log_stats:
request.record_event(EngineCoreEventType.SCHEDULED,
scheduled_timestamp)
@@ -521,18 +515,21 @@ def _make_cached_request_data(
num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
new_token_ids = request.all_token_ids[
num_computed_tokens:num_computed_tokens + num_regular_tokens]
req_data = self._cached_reqs_data.get(request.request_id)
if req_data is not None:

req_data_queue = self._cached_reqs_data.get(request.request_id)
if req_data_queue:
req_data = req_data_queue.popleft()
req_data.resumed_from_preemption = resumed_from_preemption
req_data.new_token_ids = new_token_ids
req_data.new_block_ids = new_block_ids
req_data.num_computed_tokens = num_computed_tokens
else:
# No cached request data, or all cached request data has been
# used by the scheduled requests.
req_data = CachedRequestData.from_request(request,
resumed_from_preemption,
new_token_ids,
new_block_ids)
self._cached_reqs_data[request.request_id] = req_data
return req_data

def _try_schedule_encoder_inputs(
@@ -561,6 +558,8 @@ def _try_schedule_encoder_inputs(
Note that num_computed_tokens includes both locally cached
blocks and externally cached blocks (via KVConnector).
"""
if num_new_tokens == 0 or not request.has_encoder_inputs:
return [], num_new_tokens, encoder_budget
encoder_inputs_to_schedule: list[int] = []
mm_positions = request.mm_positions
assert mm_positions is not None
@@ -728,10 +727,13 @@ def update_from_output(
# Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors

self.scheduled_req_ids.remove(req_id)
if not stopped:
new_running.append(request)

# Return the cached request data to the queue so they can be reused.
for req_data in scheduler_output.scheduled_cached_reqs:
Review comment from robertgshaw2-redhat (Collaborator), May 1, 2025:

This line introduces a memory leak. We clear _cached_reqs_data in free_request (which is called above). We add this data back after the request has been freed.

A follow-up reply from a collaborator:
have a fix
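
A minimal sketch of one possible guard, assuming the fix is simply to skip requests the scheduler no longer tracks (the actual follow-up fix may differ):

    # Hypothetical guard (illustrative only): re-queue cached data only for
    # requests that are still tracked by the scheduler, so entries are not
    # added back to _cached_reqs_data after free_request() has cleared them.
    for req_data in scheduler_output.scheduled_cached_reqs:
        if req_data.req_id in self.requests:
            self._cached_reqs_data[req_data.req_id].append(req_data)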

self._cached_reqs_data[req_data.req_id].append(req_data)

self.running = new_running
engine_core_outputs = EngineCoreOutputs(
outputs=outputs,
@@ -774,7 +776,6 @@ def finish_requests(

if request.status == RequestStatus.RUNNING:
self.running.remove(request)
self.scheduled_req_ids.discard(request.request_id)
else:
self.waiting.remove(request)
request.status = finished_status
@@ -795,10 +796,6 @@ def get_num_unfinished_requests(self) -> int:
def has_finished_requests(self) -> bool:
return len(self.finished_req_ids) > 0

def get_num_unscheduled_requests(self) -> int:
"""Number of requests that are not being processed by the executor."""
return self.get_num_unfinished_requests() - len(self.scheduled_req_ids)

def reset_prefix_cache(self) -> bool:
return self.kv_cache_manager.reset_prefix_cache()
