import time
import traceback
import weakref
-from concurrent.futures import Future
+from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
@@ -53,10 +53,11 @@ def _init_executor(self) -> None:

        self.world_size = self.parallel_config.world_size
        tensor_parallel_size = self.parallel_config.tensor_parallel_size
-        assert self.world_size == tensor_parallel_size, (
+        pp_parallel_size = self.parallel_config.pipeline_parallel_size
+        assert self.world_size == tensor_parallel_size * pp_parallel_size, (
            f"world_size ({self.world_size}) must be equal to the "
-            f"tensor_parallel_size ({tensor_parallel_size}). "
-            f"Pipeline parallelism is not yet implemented in v1 ")
+            f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
+            f"_parallel_size ({pp_parallel_size}). ")

        # Set multiprocessing envs that are common to V0 and V1
        set_multiprocessing_worker_envs(self.parallel_config)
@@ -104,6 +105,17 @@ def _init_executor(self) -> None:
                self._ensure_worker_termination(
                    [w.proc for w in unready_workers])

+        # For pipeline parallel, we use a thread pool for asynchronous
+        # execute_model.
+        self.io_thread_pool: Optional[ThreadPoolExecutor] = None
+        if self.max_concurrent_batches > 1:
+            # Note: must use only 1 IO thread to keep dequeue sequence
+            # from the response queue
+            self.io_thread_pool = ThreadPoolExecutor(
+                max_workers=1, thread_name_prefix="mp_exec_io")
+
+        self.output_rank = self._get_output_rank()
+
    def start_worker_monitor(self):
        workers = self.workers
        self_ref = weakref.ref(self)
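The single-worker pool added above is what keeps responses ordered: a `ThreadPoolExecutor` with `max_workers=1` runs submitted jobs strictly in submission order. A minimal, self-contained sketch of that property (the names here are illustrative, not part of the diff):

```python
from concurrent.futures import ThreadPoolExecutor
import time

pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mp_exec_io")

def fake_dequeue(i: int) -> int:
    time.sleep(0.01)  # stand-in for worker_response_mq.dequeue()
    return i

# Jobs are queued to the single IO thread and executed one at a time,
# so their results come back in exactly the order they were submitted.
futures = [pool.submit(fake_dequeue, i) for i in range(4)]
print([f.result() for f in futures])  # [0, 1, 2, 3]: FIFO order preserved
pool.shutdown()
```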
@@ -145,7 +157,9 @@ def execute_model(
    ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
        (output, ) = self.collective_rpc("execute_model",
                                         args=(scheduler_output, ),
-                                         rank0_reply_only=True,
+                                         unique_reply_rank=self.output_rank,
+                                         non_block=self.max_concurrent_batches
+                                         > 1,
                                         timeout=EXECUTE_MODEL_TIMEOUT_S)
        return output

@@ -154,7 +168,8 @@ def collective_rpc(self,
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict] = None,
-                       rank0_reply_only: bool = False) -> list[Any]:
+                       non_block: bool = False,
+                       unique_reply_rank: Optional[int] = None) -> list[Any]:
        if self.is_failed:
            raise RuntimeError("Executor failed.")

@@ -171,22 +186,35 @@ def collective_rpc(self,
            send_method = cloudpickle.dumps(
                method, protocol=pickle.HIGHEST_PROTOCOL)
            self.rpc_broadcast_mq.enqueue(
-                (send_method, args, kwargs, rank0_reply_only))
+                (send_method, args, kwargs, unique_reply_rank))

-            workers = (self.workers[0], ) if rank0_reply_only else self.workers
-            responses = [None] * len(workers)
-            for w in workers:
-                dequeue_timeout = None if deadline is None else (
-                    deadline - time.monotonic())
+            workers = (self.workers[unique_reply_rank],
+                       ) if unique_reply_rank is not None else self.workers
+            responses = []
+
+            def get_response(w: WorkerProcHandle,
+                             dequeue_timeout: Optional[float] = None,
+                             cancel_event: Optional[threading.Event] = None):
                status, result = w.worker_response_mq.dequeue(
-                    timeout=dequeue_timeout, cancel=self.shutdown_event)
+                    timeout=dequeue_timeout, cancel=cancel_event)

                if status != WorkerProc.ResponseStatus.SUCCESS:
                    raise RuntimeError(
                        f"Worker failed with error '{result}', please check the"
                        " stack trace above for the root cause")
+                return result

-                responses[w.rank] = result
+            for w in workers:
+                dequeue_timeout = None if deadline is None else (
+                    deadline - time.monotonic())
+
+                if non_block:
+                    result = self.io_thread_pool.submit(  # type: ignore
+                        get_response, w, dequeue_timeout, self.shutdown_event)
+                else:
+                    result = get_response(w, dequeue_timeout)
+
+                responses.append(result)

            return responses
        except TimeoutError as e:
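The loop above either resolves each worker's response inline or defers it to the IO thread, in which case the returned list holds `Future`s instead of results. A minimal sketch of that dispatch pattern, using stand-in names (`io_pool`, `get_response`, `rpc`) rather than the executor's real attributes:

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Union

io_pool = ThreadPoolExecutor(max_workers=1)

def get_response(worker_id: int) -> str:
    # stand-in for dequeuing one worker's response
    return f"output from worker {worker_id}"

def rpc(worker_id: int, non_block: bool) -> Union[str, Future]:
    if non_block:
        # hand the dequeue to the IO thread; the caller gets a Future
        return io_pool.submit(get_response, worker_id)
    # blocking path: the caller gets the result directly
    return get_response(worker_id)

print(rpc(0, non_block=False))            # direct result
print(rpc(0, non_block=True).result())    # Future, resolved by the caller
io_pool.shutdown()
```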
@@ -225,6 +253,11 @@ def shutdown(self):
        if not getattr(self, 'shutting_down', False):
            self.shutting_down = True
            self.shutdown_event.set()
+
+            if self.io_thread_pool is not None:
+                self.io_thread_pool.shutdown(wait=False, cancel_futures=True)
+                self.io_thread_pool = None
+
            for w in self.workers:
                w.worker_response_mq = None
            self._ensure_worker_termination([w.proc for w in self.workers])
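For reference, `ThreadPoolExecutor.shutdown(wait=False, cancel_futures=True)` drops queued jobs without blocking on work already in flight. A small illustrative sketch (not part of the diff):

```python
from concurrent.futures import ThreadPoolExecutor
import time

pool = ThreadPoolExecutor(max_workers=1)
first = pool.submit(time.sleep, 0.2)   # picked up by the IO thread
second = pool.submit(time.sleep, 0.2)  # waits in the queue
time.sleep(0.05)                       # let the first job start
pool.shutdown(wait=False, cancel_futures=True)
print(second.cancelled())  # True: queued work is dropped
print(first.cancelled())   # False: already-running work is not interrupted
```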
@@ -235,6 +268,22 @@ def check_health(self) -> None:
        self.collective_rpc("check_health", timeout=10)
        return

+    @property
+    def max_concurrent_batches(self) -> int:
+        return self.parallel_config.pipeline_parallel_size
+
+    def _get_output_rank(self) -> int:
+        # Only returns ModelRunnerOutput from TP rank=0 and PP rank=-1
+        # (the first TP worker of the last PP stage).
+        # Example:
+        # Assuming TP=8, PP=4, then the world_size=32
+        # 0-7, PP rank 0
+        # 8-15, PP rank 1
+        # 16-23, PP rank 2
+        # 24-31, PP rank 3
+        # so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
+        return self.world_size - self.parallel_config.tensor_parallel_size
+

@dataclass
class UnreadyWorkerProcHandle:
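A quick check of the arithmetic in `_get_output_rank()` under the TP=8, PP=4 layout described in its comment (plain Python, not part of the diff):

```python
tp_size, pp_size = 8, 4
world_size = tp_size * pp_size                 # 32
output_rank = world_size - tp_size             # 24: first rank of the last PP stage
assert output_rank % tp_size == 0              # it is a TP rank 0 worker
assert output_rank // tp_size == pp_size - 1   # and it sits in PP stage 3 (the last)
print(output_rank)  # 24
```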
@@ -280,12 +329,14 @@ def __init__(
        all_kwargs: list[dict] = [
            {} for _ in range(vllm_config.parallel_config.world_size)
        ]
+        is_driver_worker = (
+            rank % vllm_config.parallel_config.tensor_parallel_size == 0)
        all_kwargs[rank] = {
            "vllm_config": vllm_config,
            "local_rank": local_rank,
            "rank": rank,
            "distributed_init_method": distributed_init_method,
-            "is_driver_worker": rank == 0,
+            "is_driver_worker": is_driver_worker,
        }
        wrapper.init_worker(all_kwargs)
        self.worker = wrapper
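The `rank % tensor_parallel_size == 0` rule above marks the first TP rank of every pipeline stage as that stage's driver, instead of only global rank 0. A tiny illustrative check with hypothetical sizes:

```python
tp_size, pp_size = 2, 2
world_size = tp_size * pp_size
# One driver per pipeline stage: the first TP rank of each stage.
drivers = [rank for rank in range(world_size) if rank % tp_size == 0]
print(drivers)  # [0, 2]
```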
@@ -455,7 +506,7 @@ class ResponseStatus(Enum):
    def worker_busy_loop(self):
        """Main busy loop for Multiprocessing Workers"""
        while True:
-            method, args, kwargs, rank0_only = self.rpc_broadcast_mq.dequeue()
+            method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()

            try:
                if isinstance(method, str):
@@ -470,11 +521,11 @@ def worker_busy_loop(self):
                logger.exception("WorkerProc hit an exception.")
                # exception might not be serializable, so we convert it to
                # string, only for logging purpose.
-                if not rank0_only or self.rank == 0:
+                if output_rank is None or self.rank == output_rank:
                    self.worker_response_mq.enqueue(
                        (WorkerProc.ResponseStatus.FAILURE, str(e)))
                continue

-            if not rank0_only or self.rank == 0:
+            if output_rank is None or self.rank == output_rank:
                self.worker_response_mq.enqueue(
                    (WorkerProc.ResponseStatus.SUCCESS, output))