
Commit a6fed02

[V1][PP] Support PP for MultiprocExecutor (#14219)
Signed-off-by: jiang1.li <[email protected]>
Signed-off-by: jiang.li <[email protected]>
1 parent d419aa5 · commit a6fed02

5 files changed: +98 additions, -28 deletions

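For context: with this commit, pipeline parallelism runs on the default multiprocessing executor in both the V0 and V1 engines (previously V1 required Ray). A minimal usage sketch, assuming a placeholder model and the VLLM_USE_V1 toggle used by the updated tests below; treat this as an illustration, not official documentation:

import os
from vllm import LLM, SamplingParams

os.environ["VLLM_USE_V1"] = "1"      # opt into the V1 engine, as the updated tests do

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",   # placeholder model
    tensor_parallel_size=2,                     # TP x PP = world size of 4 workers
    pipeline_parallel_size=2,
    distributed_executor_backend="mp",          # multiprocessing executor
)
print(llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16)))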
tests/distributed/test_pipeline_parallel.py

Lines changed: 7 additions & 3 deletions
@@ -100,9 +100,8 @@ def detailed(
                          eager_mode=True,
                          chunked_prefill=False),
         ],
-        # only ray is supported for V1
-        distributed_backends=["mp", "ray", "ray"],
-        vllm_major_versions=["0", "0", "1"],
+        distributed_backends=["mp", "mp", "ray", "ray"],
+        vllm_major_versions=["0", "1", "0", "1"],
         task=task,
         test_options=PPTestOptions(multi_node_only=multi_node_only,
                                    load_format=load_format),
@@ -350,6 +349,11 @@ def _compare_tp(
         # Temporary. Currently when zeromq + SPMD is used, it does not properly
         # terminate because of a Ray Compiled Graph issue.
         common_args.append("--disable-frontend-multiprocessing")
+    elif distributed_backend == "mp":
+        # Both V0/V1 of multiprocessing executor support PP
+        pp_env = {
+            "VLLM_USE_V1": vllm_major_version,
+        }
     else:
         pp_env = None

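In the hunk above, the engine version for the "mp" backend is selected purely through the environment handed to the test subprocess. A minimal sketch of that pattern; the run_pp_test helper and the child command are made up for illustration:

import os
import subprocess
import sys

def run_pp_test(distributed_backend: str, vllm_major_version: str) -> None:
    # Mirrors the test change: for the "mp" backend, both V0 ("0") and V1 ("1")
    # are exercised by toggling VLLM_USE_V1 in the child process environment.
    pp_env = {"VLLM_USE_V1": vllm_major_version} if distributed_backend == "mp" else None
    env = os.environ.copy()
    if pp_env is not None:
        env.update(pp_env)
    subprocess.run(
        [sys.executable, "-c",
         "import os; print('VLLM_USE_V1 =', os.getenv('VLLM_USE_V1'))"],
        env=env, check=True)

run_pp_test("mp", "1")   # prints: VLLM_USE_V1 = 1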
vllm/engine/arg_utils.py

Lines changed: 3 additions & 4 deletions
@@ -1338,11 +1338,10 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                 and _warn_or_fallback("Engine in background thread")):
             return False

-        # PP is supported on V1 with Ray distributed executor,
-        # but off for MP distributed executor for now.
         if (self.pipeline_parallel_size > 1
-                and self.distributed_executor_backend != "ray"):
-            name = "Pipeline Parallelism without Ray distributed executor"
+                and self.distributed_executor_backend not in ["ray", "mp"]):
+            name = "Pipeline Parallelism without Ray distributed executor " \
+                   "or multiprocessing executor"
             _raise_or_fallback(feature_name=name, recommend_to_remove=False)
             return False

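The oracle change above stops PP from forcing a fallback off V1 when the multiprocessing executor is selected; only backends other than "ray" and "mp" still trigger it. A standalone sketch of just that predicate (an illustration, not the real _is_v1_supported_oracle method):

def pp_supported_on_v1(pipeline_parallel_size: int,
                       distributed_executor_backend: str) -> bool:
    # Mirrors the updated condition: PP on V1 is allowed with the Ray or
    # multiprocessing executors; anything else takes the fallback path.
    return not (pipeline_parallel_size > 1
                and distributed_executor_backend not in ["ray", "mp"])

assert pp_supported_on_v1(2, "mp")                      # newly allowed by this commit
assert pp_supported_on_v1(2, "ray")
assert not pp_supported_on_v1(2, "external_launcher")   # still falls back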
vllm/v1/executor/multiproc_executor.py

Lines changed: 69 additions & 18 deletions
@@ -8,7 +8,7 @@
 import time
 import traceback
 import weakref
-from concurrent.futures import Future
+from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
@@ -53,10 +53,11 @@ def _init_executor(self) -> None:

         self.world_size = self.parallel_config.world_size
         tensor_parallel_size = self.parallel_config.tensor_parallel_size
-        assert self.world_size == tensor_parallel_size, (
+        pp_parallel_size = self.parallel_config.pipeline_parallel_size
+        assert self.world_size == tensor_parallel_size * pp_parallel_size, (
             f"world_size ({self.world_size}) must be equal to the "
-            f"tensor_parallel_size ({tensor_parallel_size}). "
-            f"Pipeline parallelism is not yet implemented in v1")
+            f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
+            f"_parallel_size ({pp_parallel_size}). ")

         # Set multiprocessing envs that are common to V0 and V1
         set_multiprocessing_worker_envs(self.parallel_config)
@@ -104,6 +105,17 @@ def _init_executor(self) -> None:
             self._ensure_worker_termination(
                 [w.proc for w in unready_workers])

+        # For pipeline parallel, we use a thread pool for asynchronous
+        # execute_model.
+        self.io_thread_pool: Optional[ThreadPoolExecutor] = None
+        if self.max_concurrent_batches > 1:
+            # Note: must use only 1 IO thread to keep dequeue sequence
+            # from the response queue
+            self.io_thread_pool = ThreadPoolExecutor(
+                max_workers=1, thread_name_prefix="mp_exec_io")
+
+        self.output_rank = self._get_output_rank()
+
     def start_worker_monitor(self):
         workers = self.workers
         self_ref = weakref.ref(self)
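The note above about using only one IO thread is what keeps responses coming back in the order the batches were submitted. A tiny standalone demonstration of that property; slow_dequeue is a made-up stand-in for worker_response_mq.dequeue():

import time
from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="demo_io")
completion_order = []

def slow_dequeue(batch_id: int) -> int:
    # Later batches sleep less, so with more than one worker they could finish out of order.
    time.sleep(0.01 * (5 - batch_id))
    completion_order.append(batch_id)
    return batch_id

futures = [pool.submit(slow_dequeue, i) for i in range(5)]
[f.result() for f in futures]
print(completion_order)   # [0, 1, 2, 3, 4]: a single IO thread preserves submission order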
@@ -145,7 +157,9 @@ def execute_model(
     ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
         (output, ) = self.collective_rpc("execute_model",
                                          args=(scheduler_output, ),
-                                         rank0_reply_only=True,
+                                         unique_reply_rank=self.output_rank,
+                                         non_block=self.max_concurrent_batches
+                                         > 1,
                                          timeout=EXECUTE_MODEL_TIMEOUT_S)
         return output

@@ -154,7 +168,8 @@ def collective_rpc(self,
                        timeout: Optional[float] = None,
                        args: tuple = (),
                        kwargs: Optional[dict] = None,
-                       rank0_reply_only: bool = False) -> list[Any]:
+                       non_block: bool = False,
+                       unique_reply_rank: Optional[int] = None) -> list[Any]:
         if self.is_failed:
             raise RuntimeError("Executor failed.")

@@ -171,22 +186,35 @@ def collective_rpc(self,
             send_method = cloudpickle.dumps(
                 method, protocol=pickle.HIGHEST_PROTOCOL)
             self.rpc_broadcast_mq.enqueue(
-                (send_method, args, kwargs, rank0_reply_only))
+                (send_method, args, kwargs, unique_reply_rank))

-            workers = (self.workers[0], ) if rank0_reply_only else self.workers
-            responses = [None] * len(workers)
-            for w in workers:
-                dequeue_timeout = None if deadline is None else (
-                    deadline - time.monotonic())
+            workers = (self.workers[unique_reply_rank],
+                       ) if unique_reply_rank is not None else self.workers
+            responses = []
+
+            def get_response(w: WorkerProcHandle,
+                             dequeue_timeout: Optional[float] = None,
+                             cancel_event: Optional[threading.Event] = None):
                 status, result = w.worker_response_mq.dequeue(
-                    timeout=dequeue_timeout, cancel=self.shutdown_event)
+                    timeout=dequeue_timeout, cancel=cancel_event)

                 if status != WorkerProc.ResponseStatus.SUCCESS:
                     raise RuntimeError(
                         f"Worker failed with error '{result}', please check the"
                         " stack trace above for the root cause")
+                return result

-                responses[w.rank] = result
+            for w in workers:
+                dequeue_timeout = None if deadline is None else (
+                    deadline - time.monotonic())
+
+                if non_block:
+                    result = self.io_thread_pool.submit(  # type: ignore
+                        get_response, w, dequeue_timeout, self.shutdown_event)
+                else:
+                    result = get_response(w, dequeue_timeout)
+
+                responses.append(result)

             return responses
         except TimeoutError as e:
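The collective_rpc changes above implement one request broadcast to every worker, with replies collected either from all ranks or from a single designated rank, and either blocking or as Futures served by the single IO thread. A simplified, self-contained sketch of that pattern; plain queue.Queue objects stand in for the shared-memory message queues, and names like fake_worker are invented for illustration:

import queue
from concurrent.futures import ThreadPoolExecutor
from threading import Thread
from typing import Any, Optional

WORLD_SIZE = 4
broadcast_qs = [queue.Queue() for _ in range(WORLD_SIZE)]   # executor -> each worker
response_qs = [queue.Queue() for _ in range(WORLD_SIZE)]    # each worker -> executor
io_pool = ThreadPoolExecutor(max_workers=1)                 # single IO thread, as above

def fake_worker(rank: int) -> None:
    # Handle one request; reply only if this rank is the designated reply rank.
    method, reply_rank = broadcast_qs[rank].get()
    if reply_rank is None or rank == reply_rank:
        response_qs[rank].put(f"{method} done on rank {rank}")

def collective_rpc(method: str, non_block: bool = False,
                   unique_reply_rank: Optional[int] = None) -> list[Any]:
    for q in broadcast_qs:                                  # broadcast the request
        q.put((method, unique_reply_rank))
    reply_ranks = ([unique_reply_rank] if unique_reply_rank is not None
                   else range(WORLD_SIZE))
    # Blocking mode returns results directly; non-blocking mode returns Futures.
    return [io_pool.submit(response_qs[r].get) if non_block
            else response_qs[r].get() for r in reply_ranks]

workers = [Thread(target=fake_worker, args=(r,)) for r in range(WORLD_SIZE)]
for t in workers:
    t.start()
future, = collective_rpc("execute_model", non_block=True, unique_reply_rank=3)
print(future.result())   # execute_model done on rank 3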
@@ -225,6 +253,11 @@ def shutdown(self):
         if not getattr(self, 'shutting_down', False):
             self.shutting_down = True
             self.shutdown_event.set()
+
+            if self.io_thread_pool is not None:
+                self.io_thread_pool.shutdown(wait=False, cancel_futures=True)
+                self.io_thread_pool = None
+
             for w in self.workers:
                 w.worker_response_mq = None
             self._ensure_worker_termination([w.proc for w in self.workers])
@@ -235,6 +268,22 @@ def check_health(self) -> None:
         self.collective_rpc("check_health", timeout=10)
         return

+    @property
+    def max_concurrent_batches(self) -> int:
+        return self.parallel_config.pipeline_parallel_size
+
+    def _get_output_rank(self) -> int:
+        # Only returns ModelRunnerOutput from TP rank=0 and PP rank=-1
+        # (the first TP worker of the last PP stage).
+        # Example:
+        # Assuming TP=8, PP=4, then the world_size=32
+        # 0-7, PP rank 0
+        # 8-15, PP rank 1
+        # 16-23, PP rank 2
+        # 24-31, PP rank 3
+        # so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
+        return self.world_size - self.parallel_config.tensor_parallel_size
+

 @dataclass
 class UnreadyWorkerProcHandle:
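The _get_output_rank comment above works through the TP=8, PP=4 layout; together with the driver-worker rule added further below (rank % tensor_parallel_size == 0), it pins down the single rank that hands ModelRunnerOutput back to the executor. A few lines reproducing that arithmetic (illustrative only):

tp_size, pp_size = 8, 4                 # the example from the comment above
world_size = tp_size * pp_size          # 32

drivers = [r for r in range(world_size) if r % tp_size == 0]   # first TP worker per PP stage
output_rank = world_size - tp_size                             # first TP worker, last PP stage

print(drivers)       # [0, 8, 16, 24]: one driver worker per PP stage
print(output_rank)   # 24
assert output_rank // tp_size == pp_size - 1                   # it sits in the last PP stage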
@@ -280,12 +329,14 @@ def __init__(
         all_kwargs: list[dict] = [
             {} for _ in range(vllm_config.parallel_config.world_size)
         ]
+        is_driver_worker = (
+            rank % vllm_config.parallel_config.tensor_parallel_size == 0)
         all_kwargs[rank] = {
             "vllm_config": vllm_config,
             "local_rank": local_rank,
             "rank": rank,
             "distributed_init_method": distributed_init_method,
-            "is_driver_worker": rank == 0,
+            "is_driver_worker": is_driver_worker,
         }
         wrapper.init_worker(all_kwargs)
         self.worker = wrapper
@@ -455,7 +506,7 @@ class ResponseStatus(Enum):
     def worker_busy_loop(self):
         """Main busy loop for Multiprocessing Workers"""
         while True:
-            method, args, kwargs, rank0_only = self.rpc_broadcast_mq.dequeue()
+            method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()

             try:
                 if isinstance(method, str):
@@ -470,11 +521,11 @@ def worker_busy_loop(self):
                 logger.exception("WorkerProc hit an exception.")
                 # exception might not be serializable, so we convert it to
                 # string, only for logging purpose.
-                if not rank0_only or self.rank == 0:
+                if output_rank is None or self.rank == output_rank:
                     self.worker_response_mq.enqueue(
                         (WorkerProc.ResponseStatus.FAILURE, str(e)))
                 continue

-            if not rank0_only or self.rank == 0:
+            if output_rank is None or self.rank == output_rank:
                 self.worker_response_mq.enqueue(
                     (WorkerProc.ResponseStatus.SUCCESS, output))

vllm/v1/worker/gpu_model_runner.py

Lines changed: 1 addition & 1 deletion
@@ -1016,7 +1016,7 @@ def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> Union[ModelRunnerOutput, torch.Tensor]:
+    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
         # Update KVConnector with the KVConnector metadata forward().
         if has_kv_transfer_group():
             get_kv_transfer_group().bind_connector_metadata(
get_kv_transfer_group().bind_connector_metadata(

vllm/v1/worker/gpu_worker.py

Lines changed: 18 additions & 2 deletions
@@ -15,11 +15,12 @@
                               init_distributed_environment,
                               set_custom_all_reduce)
 from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
-from vllm.distributed.parallel_state import get_pp_group
+from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
+from vllm.sequence import IntermediateTensors
 from vllm.utils import GiB_bytes
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
@@ -266,7 +267,22 @@ def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
     ) -> Optional[ModelRunnerOutput]:
-        output = self.model_runner.execute_model(scheduler_output)
+        intermediate_tensors = None
+        if not get_pp_group().is_first_rank:
+            intermediate_tensors = IntermediateTensors(
+                get_pp_group().recv_tensor_dict(
+                    all_gather_group=get_tp_group()))
+
+        output = self.model_runner.execute_model(scheduler_output,
+                                                 intermediate_tensors)
+
+        if not get_pp_group().is_last_rank:
+            assert isinstance(output, IntermediateTensors)
+            get_pp_group().send_tensor_dict(output.tensors,
+                                            all_gather_group=get_tp_group())
+            return None
+
+        assert isinstance(output, ModelRunnerOutput)
         return output if self.is_driver_worker else None

     def profile(self, is_start: bool = True):

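The gpu_worker.py hunk above is the per-stage hand-off: every stage except the first receives IntermediateTensors, every stage except the last forwards its activations and returns nothing, and only the last stage returns sampled output. A minimal standalone sketch of that control flow; plain dicts and a Python list stand in for IntermediateTensors and the pipeline-parallel process group, and all names here are illustrative:

from typing import Optional

def stage_execute(pp_rank: int, pp_size: int, recv, run, send) -> Optional[dict]:
    # Mirrors the control flow of Worker.execute_model above.
    intermediate = None
    if pp_rank != 0:                  # not get_pp_group().is_first_rank
        intermediate = recv()         # receive activations from the previous stage
    output = run(intermediate)        # model_runner.execute_model(...)
    if pp_rank != pp_size - 1:        # not get_pp_group().is_last_rank
        send(output)                  # forward activations to the next stage
        return None                   # nothing is handed back to the executor
    return output                     # only the last stage returns real output

# Tiny in-process "pipeline": each stage adds 1, only the last stage returns a result.
pp_size, channel = 3, []
result = None
for pp_rank in range(pp_size):
    result = stage_execute(
        pp_rank, pp_size,
        recv=channel.pop,
        run=lambda x: {"acts": (0 if x is None else x["acts"]) + 1},
        send=channel.append)
print(result)   # {'acts': 3}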