
Commit d0da99f

[BugFix] llama4 fa3 fix - RuntimeError: scheduler_metadata must have shape (metadata_size) (#16998)
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent b2f195c commit d0da99f
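
This commit fixes a FlashAttention-3 (FA3) crash on Llama 4 models that use chunked local attention. The ahead-of-time scheduler metadata was built once per batch, with `num_reqs` as the batch size and the query head count passed for both `num_heads_q` and `num_heads_kv`; the local-attention path then fed that metadata to a kernel launch over a different (virtual) batch, and FA3 rejected it with `RuntimeError: scheduler_metadata must have shape (metadata_size)`. The fix tracks query and KV head counts separately, makes the `schedule()` helper take an explicit `batch_size`, and gives the local-attention metadata its own `local_scheduler_metadata`.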

File tree: 1 file changed, 48 additions, 28 deletions

vllm/v1/attention/backends/flash_attn.py

```diff
@@ -105,6 +105,7 @@ class LocalAttentionMetadata:
         local_block_table: torch.Tensor
         local_max_query_len: int
         local_max_seq_len: int
+        local_scheduler_metadata: Optional[torch.Tensor]

     local_attn_metadata: Optional[LocalAttentionMetadata] = None
```

```diff
@@ -282,7 +283,9 @@ def __init__(self, runner: "GPUModelRunner"):

         self.runner = runner
         self.aot_schedule = (get_flash_attn_version() == 3)
-        self.num_heads = model_config.get_num_attention_heads(
+        self.num_heads_q = model_config.get_num_attention_heads(
+            runner.parallel_config)
+        self.num_heads_kv = model_config.get_num_kv_heads(
             runner.parallel_config)
         self.headdim = model_config.get_head_size()
         self.page_size = self.runner.block_size
```
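
For context on the `num_heads_q`/`num_heads_kv` split above: a minimal sketch, not vLLM code, of why the two counts differ for a grouped-query-attention (GQA) model under tensor parallelism. All config numbers below are assumptions chosen for illustration.

```python
# Hypothetical GQA configuration (illustration only, not Llama 4's real config).
total_q_heads = 40   # assumed query head count
total_kv_heads = 8   # assumed KV head count; GQA means fewer than query heads
tp_size = 4          # assumed tensor-parallel world size

num_heads_q = total_q_heads // tp_size    # 10 query heads per rank
num_heads_kv = total_kv_heads // tp_size  # 2 KV heads per rank

# The old code passed the query head count for both arguments of
# get_scheduler_metadata, so FA3 sized its metadata for the wrong KV head
# count on any GQA model.
assert num_heads_q != num_heads_kv
```
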
```diff
@@ -304,6 +307,23 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             self.runner.device, non_blocking=True).long()

+        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
+                     max_seq_len, causal):
+            if self.aot_schedule:
+                return get_scheduler_metadata(
+                    batch_size=batch_size,
+                    max_seqlen_q=max_query_len,
+                    max_seqlen_k=max_seq_len,
+                    cache_seqlens=seqlens,
+                    num_heads_q=self.num_heads_q,
+                    num_heads_kv=self.num_heads_kv,
+                    headdim=self.headdim,
+                    page_size=self.page_size,
+                    cu_seqlens_q=cu_query_lens,
+                    causal=causal,
+                )
+            return None
+
         # for local attention
         local_attn_metadata = None
         if self.runner.attention_chunk_size is not None:
```
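
The helper also gained an explicit `batch_size` parameter; the version it replaces (deleted in the next hunk) closed over `num_reqs`, which is wrong for chunked local attention, where each request is split into per-chunk virtual sequences. Below is a simplified numpy sketch of why the virtual batch differs; treating the split as a pure query-length chunking is an assumption made for brevity, and the chunk size and request lengths are invented.

```python
import numpy as np

chunk = 4                   # assumed attention_chunk_size
q_lens = np.array([10, 3])  # assumed per-request query lengths

# Each request splits into ceil(q_len / chunk) virtual sequences.
virt_q_lens = np.concatenate(
    [np.diff(np.arange(0, q, chunk), append=q) for q in q_lens])
virt_q_cu_seqlens = np.concatenate(([0], np.cumsum(virt_q_lens)))

num_reqs = len(q_lens)                       # 2 real requests
virt_batch = virt_q_cu_seqlens.shape[0] - 1  # 4 virtual sequences
# Scheduler metadata planned for num_reqs cannot describe a launch over
# virt_batch sequences, hence the shape error this commit fixes.
assert virt_batch != num_reqs
```
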
```diff
@@ -315,36 +335,31 @@
                 block_table,
                 self.runner.block_size,
             )
+            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
+                self.runner.device, non_blocking=True)
+            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
+                self.runner.device, non_blocking=True)
+            local_max_query_len = seqlens_q_local_np.max()
+            local_max_seq_len = virt_k_seqlens_np.max()
+            local_scheduler_metadata = schedule(
+                batch_size=local_query_start_loc.shape[0] - 1,
+                cu_query_lens=local_query_start_loc,
+                max_query_len=local_max_query_len,
+                seqlens=local_seqused_k,
+                max_seq_len=local_max_seq_len,
+                causal=True)
+
             local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
-                local_query_start_loc=torch.from_numpy(
-                    virt_q_cu_seqlens_np).to(self.runner.device,
-                                             non_blocking=True),
-                local_seqused_k=torch.from_numpy(virt_k_seqlens_np).to(
-                    self.runner.device, non_blocking=True),
+                local_query_start_loc=local_query_start_loc,
+                local_seqused_k=local_seqused_k,
                 local_block_table=virt_block_table,
-                local_max_query_len=seqlens_q_local_np.max(),
-                local_max_seq_len=virt_k_seqlens_np.max(),
+                local_max_query_len=local_max_query_len,
+                local_max_seq_len=local_max_seq_len,
+                local_scheduler_metadata=local_scheduler_metadata,
             )

         use_cascade = common_prefix_len > 0

-        def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
-                     causal):
-            if self.aot_schedule:
-                return get_scheduler_metadata(
-                    batch_size=num_reqs,
-                    max_seqlen_q=max_query_len,
-                    max_seqlen_k=max_seq_len,
-                    cache_seqlens=seqlens,
-                    num_heads_q=self.num_heads,
-                    num_heads_kv=self.num_heads,
-                    headdim=self.headdim,
-                    page_size=self.page_size,
-                    cu_seqlens_q=cu_query_lens,
-                    causal=causal,
-                )
-            return None
-
         if use_cascade:
             cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                 dtype=torch.int32,
```
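
The hunk above also deletes the old `schedule()` closure, which hard-coded `batch_size=num_reqs` and passed `self.num_heads` as both `num_heads_q` and `num_heads_kv`. The local path now sizes its metadata from the virtual batch (`local_query_start_loc.shape[0] - 1` sequences) and stores the result in the new `local_scheduler_metadata` field.
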
```diff
@@ -357,12 +372,14 @@ def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
             suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
                 self.runner.device)
             prefix_scheduler_metadata = schedule(
+                batch_size=num_reqs,
                 cu_query_lens=cu_prefix_query_lens,
                 max_query_len=num_actual_tokens,
                 seqlens=prefix_kv_lens,
                 max_seq_len=common_prefix_len,
                 causal=False)
-            scheduler_metadata = schedule(cu_query_lens=query_start_loc,
+            scheduler_metadata = schedule(batch_size=num_reqs,
+                                          cu_query_lens=query_start_loc,
                                           max_query_len=max_query_len,
                                           seqlens=suffix_kv_lens,
                                           max_seq_len=max_seq_len -
```

```diff
@@ -373,7 +390,8 @@ def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
             prefix_kv_lens = None
             suffix_kv_lens = None
             prefix_scheduler_metadata = None
-            scheduler_metadata = schedule(cu_query_lens=query_start_loc,
+            scheduler_metadata = schedule(batch_size=num_reqs,
+                                          cu_query_lens=query_start_loc,
                                           max_query_len=max_query_len,
                                           seqlens=seq_lens,
                                           max_seq_len=max_seq_len,
```

```diff
@@ -540,12 +558,14 @@ def forward(
             max_seqlen_q = local_metadata.local_max_query_len
             max_seqlen_k = local_metadata.local_max_seq_len
             block_table = local_metadata.local_block_table
+            scheduler_metadata = local_metadata.local_scheduler_metadata
         else:
             cu_seqlens_q = attn_metadata.query_start_loc
             seqused_k = attn_metadata.seq_lens
             max_seqlen_q = attn_metadata.max_query_len
             max_seqlen_k = attn_metadata.max_seq_len
             block_table = attn_metadata.block_table
+            scheduler_metadata = attn_metadata.scheduler_metadata

         descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
```

```diff
@@ -564,7 +584,7 @@ def forward(
                 window_size=self.sliding_window,
                 block_table=block_table,
                 softcap=self.logits_soft_cap,
-                scheduler_metadata=attn_metadata.scheduler_metadata,
+                scheduler_metadata=scheduler_metadata,
                 fa_version=self.vllm_flash_attn_version,
                 q_descale=layer._q_scale.expand(descale_shape),
                 k_descale=layer._k_scale.expand(descale_shape),
```
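
Finally, `forward` selects the scheduler metadata that matches the tensors it dispatches: the virtual-batch metadata on the chunked-local path, the global metadata otherwise, rather than always passing `attn_metadata.scheduler_metadata`. A minimal sketch of that selection; the attribute names mirror the diff, but the helper function and objects here are stand-ins, not real vLLM types.

```python
# Hypothetical helper mirroring the branch added in forward() above.
def pick_scheduler_metadata(attn_metadata, use_local_attn: bool):
    if use_local_attn:
        local = attn_metadata.local_attn_metadata
        # Planned for the virtual batch the local kernel actually launches.
        return local.local_scheduler_metadata
    # Planned for the global batch of num_reqs sequences.
    return attn_metadata.scheduler_metadata
```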
