@@ -105,6 +105,7 @@ class LocalAttentionMetadata:
         local_block_table: torch.Tensor
         local_max_query_len: int
         local_max_seq_len: int
+        local_scheduler_metadata: Optional[torch.Tensor]
 
     local_attn_metadata: Optional[LocalAttentionMetadata] = None
 
@@ -282,7 +283,9 @@ def __init__(self, runner: "GPUModelRunner"):
 
         self.runner = runner
         self.aot_schedule = (get_flash_attn_version() == 3)
-        self.num_heads = model_config.get_num_attention_heads(
+        self.num_heads_q = model_config.get_num_attention_heads(
+            runner.parallel_config)
+        self.num_heads_kv = model_config.get_num_kv_heads(
             runner.parallel_config)
         self.headdim = model_config.get_head_size()
         self.page_size = self.runner.block_size
@@ -304,6 +307,23 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             self.runner.device, non_blocking=True).long()
 
+        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
+                     max_seq_len, causal):
+            if self.aot_schedule:
+                return get_scheduler_metadata(
+                    batch_size=batch_size,
+                    max_seqlen_q=max_query_len,
+                    max_seqlen_k=max_seq_len,
+                    cache_seqlens=seqlens,
+                    num_heads_q=self.num_heads_q,
+                    num_heads_kv=self.num_heads_kv,
+                    headdim=self.headdim,
+                    page_size=self.page_size,
+                    cu_seqlens_q=cu_query_lens,
+                    causal=causal,
+                )
+            return None
+
         # for local attention
         local_attn_metadata = None
         if self.runner.attention_chunk_size is not None:
@@ -315,36 +335,31 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
                 block_table,
                 self.runner.block_size,
             )
+            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
+                self.runner.device, non_blocking=True)
+            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
+                self.runner.device, non_blocking=True)
+            local_max_query_len = seqlens_q_local_np.max()
+            local_max_seq_len = virt_k_seqlens_np.max()
+            local_scheduler_metadata = schedule(
+                batch_size=local_query_start_loc.shape[0] - 1,
+                cu_query_lens=local_query_start_loc,
+                max_query_len=local_max_query_len,
+                seqlens=local_seqused_k,
+                max_seq_len=local_max_seq_len,
+                causal=True)
+
             local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
-                local_query_start_loc=torch.from_numpy(
-                    virt_q_cu_seqlens_np).to(self.runner.device,
-                                             non_blocking=True),
-                local_seqused_k=torch.from_numpy(virt_k_seqlens_np).to(
-                    self.runner.device, non_blocking=True),
+                local_query_start_loc=local_query_start_loc,
+                local_seqused_k=local_seqused_k,
                 local_block_table=virt_block_table,
-                local_max_query_len=seqlens_q_local_np.max(),
-                local_max_seq_len=virt_k_seqlens_np.max(),
+                local_max_query_len=local_max_query_len,
+                local_max_seq_len=local_max_seq_len,
+                local_scheduler_metadata=local_scheduler_metadata,
             )
 
         use_cascade = common_prefix_len > 0
 
-        def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
-                     causal):
-            if self.aot_schedule:
-                return get_scheduler_metadata(
-                    batch_size=num_reqs,
-                    max_seqlen_q=max_query_len,
-                    max_seqlen_k=max_seq_len,
-                    cache_seqlens=seqlens,
-                    num_heads_q=self.num_heads,
-                    num_heads_kv=self.num_heads,
-                    headdim=self.headdim,
-                    page_size=self.page_size,
-                    cu_seqlens_q=cu_query_lens,
-                    causal=causal,
-                )
-            return None
-
         if use_cascade:
             cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                 dtype=torch.int32,
@@ -357,12 +372,14 @@ def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
             suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
                 self.runner.device)
             prefix_scheduler_metadata = schedule(
+                batch_size=num_reqs,
                 cu_query_lens=cu_prefix_query_lens,
                 max_query_len=num_actual_tokens,
                 seqlens=prefix_kv_lens,
                 max_seq_len=common_prefix_len,
                 causal=False)
-            scheduler_metadata = schedule(cu_query_lens=query_start_loc,
+            scheduler_metadata = schedule(batch_size=num_reqs,
+                                          cu_query_lens=query_start_loc,
                                           max_query_len=max_query_len,
                                           seqlens=suffix_kv_lens,
                                           max_seq_len=max_seq_len -
@@ -373,7 +390,8 @@ def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
             prefix_kv_lens = None
             suffix_kv_lens = None
             prefix_scheduler_metadata = None
-            scheduler_metadata = schedule(cu_query_lens=query_start_loc,
+            scheduler_metadata = schedule(batch_size=num_reqs,
+                                          cu_query_lens=query_start_loc,
                                           max_query_len=max_query_len,
                                           seqlens=seq_lens,
                                           max_seq_len=max_seq_len,
@@ -540,12 +558,14 @@ def forward(
                 max_seqlen_q = local_metadata.local_max_query_len
                 max_seqlen_k = local_metadata.local_max_seq_len
                 block_table = local_metadata.local_block_table
+                scheduler_metadata = local_metadata.local_scheduler_metadata
             else:
                 cu_seqlens_q = attn_metadata.query_start_loc
                 seqused_k = attn_metadata.seq_lens
                 max_seqlen_q = attn_metadata.max_query_len
                 max_seqlen_k = attn_metadata.max_seq_len
                 block_table = attn_metadata.block_table
+                scheduler_metadata = attn_metadata.scheduler_metadata
 
             descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
 
@@ -564,7 +584,7 @@ def forward(
                 window_size=self.sliding_window,
                 block_table=block_table,
                 softcap=self.logits_soft_cap,
-                scheduler_metadata=attn_metadata.scheduler_metadata,
+                scheduler_metadata=scheduler_metadata,
                 fa_version=self.vllm_flash_attn_version,
                 q_descale=layer._q_scale.expand(descale_shape),
                 k_descale=layer._k_scale.expand(descale_shape),
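
Taken together, the diff hoists the FlashAttention-3 ahead-of-time (AOT) scheduling helper above the local-attention branch, makes its batch size a parameter, splits the query/KV head counts (previously `num_heads` was passed for both, which is incorrect for GQA models), and stores per-branch scheduler metadata so `forward()` no longer hands the kernel the global `attn_metadata.scheduler_metadata` when chunked local attention has rewritten the batch layout. Below is a minimal, self-contained sketch of the resulting pattern; `get_scheduler_metadata` here is a stub standing in for the real entry point in vLLM's flash-attn build, and the class name, constructor fields, and example values are hypothetical:

```python
from typing import Optional

import torch


def get_scheduler_metadata(**kwargs) -> torch.Tensor:
    # Stub for the AOT scheduler entry point used in the diff; the real
    # function plans the FA3 kernel's work distribution ahead of time and
    # returns an opaque metadata tensor.
    return torch.empty(0, dtype=torch.int32)


class MetadataBuilderSketch:
    """Hypothetical reduction of the metadata builder touched by this diff."""

    def __init__(self, num_heads_q: int, num_heads_kv: int, headdim: int,
                 page_size: int, aot_schedule: bool) -> None:
        # Q and KV head counts are tracked separately so GQA/MQA models
        # report the true KV head count to the scheduler.
        self.num_heads_q = num_heads_q
        self.num_heads_kv = num_heads_kv
        self.headdim = headdim
        self.page_size = page_size
        self.aot_schedule = aot_schedule

    def schedule(self, batch_size: int, cu_query_lens: torch.Tensor,
                 max_query_len: int, seqlens: torch.Tensor,
                 max_seq_len: int, causal: bool) -> Optional[torch.Tensor]:
        # One helper now serves the local, cascade-prefix, and global paths;
        # batch_size is a parameter because the local path's virtual batch
        # size differs from num_reqs.
        if self.aot_schedule:
            return get_scheduler_metadata(
                batch_size=batch_size,
                max_seqlen_q=max_query_len,
                max_seqlen_k=max_seq_len,
                cache_seqlens=seqlens,
                num_heads_q=self.num_heads_q,
                num_heads_kv=self.num_heads_kv,
                headdim=self.headdim,
                page_size=self.page_size,
                cu_seqlens_q=cu_query_lens,
                causal=causal,
            )
        return None


if __name__ == "__main__":
    builder = MetadataBuilderSketch(num_heads_q=32, num_heads_kv=8,
                                    headdim=128, page_size=16,
                                    aot_schedule=True)
    # Local-attention call site: the virtual batch size is derived from the
    # local cumulative query lengths, mirroring the diff.
    local_query_start_loc = torch.tensor([0, 3, 7], dtype=torch.int32)
    local_seqused_k = torch.tensor([3, 4], dtype=torch.int32)
    meta = builder.schedule(
        batch_size=local_query_start_loc.shape[0] - 1,
        cu_query_lens=local_query_start_loc,
        max_query_len=4,
        seqlens=local_seqused_k,
        max_seq_len=4,
        causal=True)
    print(meta is not None)
```

At the attention call, `forward()` then picks whichever metadata matches the branch it took (the local virtual-batch metadata or the global one), which is exactly the `scheduler_metadata = ...` selection the diff adds to both sides of the `if`/`else`.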