
Commit 18dd5e0

[Model] Mamba2 causal conv1d Refactor to Split Prefill and Decode Requests for Corresponding Kernels (#17146)
Signed-off-by: Chih-Chieh-Yang <[email protected]>
1 parent 6de3e13 commit 18dd5e0

8 files changed, +153 -125 lines changed

tests/kernels/mamba/test_mamba_ssm_ssd.py

Lines changed: 4 additions & 3 deletions

@@ -6,7 +6,7 @@
 from einops import rearrange, repeat
 
 from vllm.model_executor.layers.mamba.mamba2_metadata import (
-    _seq_idx_to_chunk_indices_offsets)
+    _query_start_loc_to_chunk_indices_offsets)
 from vllm.model_executor.layers.mamba.ops.ssd_combined import (
     mamba_chunk_scan_combined)
 from vllm.platforms import current_platform
@@ -274,8 +274,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
                                      last_taken, exhausted, n_heads,
                                      d_head, itype):
 
-    chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
-        seq_idx, chunk_size)
+    chunk_indices, chunk_offsets = \
+        _query_start_loc_to_chunk_indices_offsets(
+            cu_seqlens, chunk_size, cu_seqlens[-1])
 
     Y, new_states = mamba_chunk_scan_combined(
         X,
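
For context, a minimal standalone sketch (toy tensors, plain PyTorch, not part of the diff) of why the helper's inputs changed: the old helper had to recover sequence boundaries from a per-token seq_idx, while the new one reads them directly off query_start_loc, which the attention metadata already carries.

import torch

# three sequences of lengths 3, 2 and 4, flattened into one varlen batch
query_start_loc = torch.tensor([0, 3, 5, 9])
seq_idx = torch.tensor([[0, 0, 0, 1, 1, 2, 2, 2, 2]], dtype=torch.int32)

# old derivation (what _seq_idx_to_chunk_indices_offsets did internally)
_, boundaries = torch.where(seq_idx.diff())          # positions where the id changes
boundaries = boundaries + 1                          # -> tensor([3, 5])
cu_old = boundaries.tolist() + [seq_idx.shape[-1]]   # -> [3, 5, 9]

# new derivation (what _query_start_loc_to_chunk_indices_offsets uses)
cu_new = query_start_loc[1:]                         # -> tensor([3, 5, 9]), no scan needed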

vllm/model_executor/layers/mamba/mamba2_metadata.py

Lines changed: 45 additions & 43 deletions

@@ -13,7 +13,6 @@
 
 @dataclass
 class Mamba2Metadata:
-    has_prefill: bool
 
     has_initial_states: torch.Tensor
     prep_initial_states: bool
@@ -24,21 +23,23 @@ class Mamba2Metadata:
     chunk_offsets: torch.Tensor
 
 
-def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
+def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
+                                              chunk_size: int,
+                                              total_seqlens: int):
 
-    # convert seq_idx to chunk indices and offsets
-    # - derive the cu_seqlens
-    _, cu_seqlens = torch.where(seq_idx.diff())
-    cu_seqlens += 1
+    cu_seqlens = query_start_loc[1:]  # remove prepended 0
 
     # outputs will have length expansion of chunks that do not divide
    # chunk_size
-    N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size
-                                                     > 0).sum()
-    chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device)
-    chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device)
+    N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
+                                                 > 0).sum()
+    chunk_indices = torch.arange(N,
+                                 dtype=torch.int,
+                                 device=query_start_loc.device)
+    chunk_offsets = torch.zeros((N, ),
+                                dtype=torch.int,
+                                device=query_start_loc.device)
 
-    cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
     p = 0  # num of insertions
     for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
@@ -60,48 +61,49 @@ def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
 
 def prepare_mamba2_metadata(
     chunk_size: int,
-    input_ids: torch.Tensor,
     attn_metadata: AttentionMetadata,
 ) -> Mamba2Metadata:
 
+    # compute number of prefill and decode requests
+    # NOTE: in V0 we assume prefills are before decodes
+    num_prefills = attn_metadata.num_prefills
+    num_prefill_tokens = attn_metadata.num_prefill_tokens
+
+    seq_idx = None
+    chunk_indices, chunk_offsets = None, None
     # Need flags to indicate if there are initial states
     # currently we really only support the FlashAttention backend
     has_initial_states = None
     prep_initial_states = False
-    if (isinstance(attn_metadata, (FlashAttentionMetadata, XFormersMetadata,
-                                   PlaceholderAttentionMetadata))
-            and attn_metadata.context_lens_tensor is not None):
-        has_initial_states = attn_metadata.context_lens_tensor > 0
-        # precompute flag to avoid device syncs later in mamba2 forwards
-        prep_initial_states = torch.any(has_initial_states).item()
-
-    has_prefill = attn_metadata.num_prefills > 0
 
-    seq_idx = None
-    chunk_indices, chunk_offsets = None, None
-    if has_prefill:
-        seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
-        for i, (srt, end) in enumerate(
-                zip(
-                    attn_metadata.query_start_loc,
-                    attn_metadata.query_start_loc[1:],
-                )):
-            seq_idx[srt:end] = i
+    # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
+    if num_prefills > 0:
+        if (isinstance(attn_metadata,
+                       (FlashAttentionMetadata, XFormersMetadata,
+                        PlaceholderAttentionMetadata))
+                and attn_metadata.context_lens_tensor is not None):
+            has_initial_states = \
+                attn_metadata.context_lens_tensor[:num_prefills] > 0  # [batch,]
+            # precompute flag to avoid device syncs in mamba2 layer forwards
+            # prep is only needed for mamba2 ssd prefill processing
+            prep_initial_states = torch.any(has_initial_states).item()
+
+        query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1]
+        seq_idx = torch.repeat_interleave(torch.arange(
+            num_prefills, dtype=torch.int32, device=query_start_loc.device),
+                                          query_start_loc.diff(),
+                                          output_size=num_prefill_tokens)
         seq_idx.unsqueeze_(0)
 
-        # compute metadata for chunked prefill.
-        # actually this is only needed if there are initial states,
-        # but this is determinable only from attention metadata yet
-        # unavailable from the top-level model forward. Rather than
-        # complicating things to extract said metadata, we simply just
-        # compute them once at the top level model forward and reuse
-        # them in mamba layers. If not needed, they will be ignored
-        # inside mamba kernels.
-        chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
-            seq_idx, chunk_size)
-
-    return Mamba2Metadata(has_prefill=has_prefill,
-                          has_initial_states=has_initial_states,
+        # We compute metadata for chunked prefill once at the top level model
+        # forward and reuse them in mamba layers. If not needed, they will be
+        # ignored inside mamba kernels.
+        if prep_initial_states:
+            chunk_indices, chunk_offsets = \
+                _query_start_loc_to_chunk_indices_offsets(
+                    query_start_loc, chunk_size, num_prefill_tokens)
+
+    return Mamba2Metadata(has_initial_states=has_initial_states,
                           prep_initial_states=prep_initial_states,
                           chunk_size=chunk_size,
                           seq_idx=seq_idx,
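
As a quick illustration (toy values, plain PyTorch; in vLLM the inputs come from attn_metadata), the prefill-only seq_idx built above can be reproduced like this:

import torch

num_prefills = 3
query_start_loc = torch.tensor([0, 3, 5, 9], dtype=torch.int32)
num_prefill_tokens = int(query_start_loc[-1])        # 9 prefill tokens in total

# one sequence id per request, repeated once per token of that request
seq_idx = torch.repeat_interleave(
    torch.arange(num_prefills, dtype=torch.int32),
    query_start_loc.diff(),                           # per-request lengths [3, 2, 4]
    output_size=num_prefill_tokens)
# seq_idx -> tensor([0, 0, 0, 1, 1, 2, 2, 2, 2], dtype=torch.int32)
seq_idx.unsqueeze_(0)                                 # shape [1, num_prefill_tokens]

This replaces the removed host-side Python loop over query_start_loc pairs with a single vectorized call.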

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 104 additions & 74 deletions

@@ -388,10 +388,15 @@ def forward_cuda(
         # mamba2_metadata contains metadata necessary for the mamba2 triton
         # kernels to operate in continuous batching and in chunked prefill
         # modes; they are computed at top-level model forward since they
-        # are the same and reused for all mamba layers in the same iteration
+        # stay the same and reused for all mamba layers in the same iteration
         attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
 
-        seq_len, _ = hidden_states.shape
+        num_prefills = attn_metadata.num_prefills  # request count
+        num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
+        num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
+        has_prefill = num_prefills > 0
+        has_decode = num_decodes > 0
+
         groups_time_state_size = self.n_groups * self.ssm_state_size
 
         # 1. Gated MLP's linear projection
@@ -406,44 +411,32 @@ def forward_cuda(
             dim=-1,
         )
 
-        # 2. Convolution sequence transformation
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                                self.conv1d.weight.size(2))
 
-        if mamba2_metadata.has_prefill:
-            # |---------- N-1 iteration --------|
-            # |---------------- N iteration ---------------------|
-            # |- tokenA -|......................|-- newTokens ---|
-            # |---------- context_len ----------|
-            # |-------------------- seq_len ---------------------|
-            #                                   |-- query_len ---|
-
-            # - "cache_indices" updates the conv_state cache in positions
-            #   pointed to by "mamba_cache_params.state_indices_tensor"
-            hidden_states_B_C = causal_conv1d_fn(
-                hidden_states_B_C.transpose(0, 1),
-                conv_weights,
-                self.conv1d.bias,
-                activation=self.activation,
-                conv_states=mamba_cache_params.conv_state,
-                has_initial_state=mamba2_metadata.has_initial_states,
-                cache_indices=mamba_cache_params.state_indices_tensor,
-                query_start_loc=attn_metadata.query_start_loc).transpose(
-                    0, 1)[:seq_len]
-
-            # TODO: Why is this needed?
-            hidden_states_B_C = hidden_states_B_C.contiguous()
-        else:
-            hidden_states_B_C = causal_conv1d_update(
-                hidden_states_B_C,
-                mamba_cache_params.conv_state,
-                conv_weights,
-                self.conv1d.bias,
-                self.activation,
-                conv_state_indices=mamba_cache_params.state_indices_tensor)
+        # Separate prefill and decode by splitting varlen input
+        # Split along token dimension
+        hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
+            hidden_states_B_C,
+            [num_prefill_tokens, num_decodes],
+            dim=0,
+        )
+        dt_p, dt_d = torch.split(
+            dt,
+            [num_prefill_tokens, num_decodes],
+            dim=0,
+        )
+        # Split along batch dimension
+        state_indices_tensor_p, state_indices_tensor_d = torch.split(
+            mamba_cache_params.state_indices_tensor,
+            [num_prefills, num_decodes],
+            dim=0,
+        )
+        query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
+                             if has_prefill else None)
 
         # - get hidden_states, B and C after depthwise convolution.
-        hidden_states, B, C = torch.split(
+        split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
            hidden_states_B_C,
            [
                self.intermediate_size // self.tp_size,
@@ -453,32 +446,56 @@ def forward_cuda(
            dim=-1,
        )
 
-        # 3. State Space Model sequence transformation
-        if mamba2_metadata.has_prefill:
+        ssd_output_list = []
+
+        # Process prefill requests
+        if has_prefill:
+            # 2. Convolution sequence transformation
+            # - "cache_indices" updates the conv_state cache in positions
+            #   pointed to by "mamba_cache_params.state_indices_tensor"
+            hidden_states_B_C_p = causal_conv1d_fn(
+                hidden_states_B_C_p.transpose(0, 1),
+                conv_weights,
+                self.conv1d.bias,
+                activation=self.activation,
+                conv_states=mamba_cache_params.conv_state,
+                has_initial_state=mamba2_metadata.has_initial_states,
+                cache_indices=state_indices_tensor_p,
+                query_start_loc=query_start_loc_p).transpose(
+                    0, 1)[:num_prefill_tokens]
+
+            # TODO: Why is this needed?
+            hidden_states_B_C_p = hidden_states_B_C_p.contiguous()
+            hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
+                hidden_states_B_C_p)
+
+            # 3. State Space Model sequence transformation
             initial_states = None
             if (mamba2_metadata.has_initial_states is not None
                     and mamba2_metadata.prep_initial_states):
                 # making a copy of the states
                 initial_states = torch.where(
                     mamba2_metadata.has_initial_states[:, None, None, None],
-                    mamba_cache_params.ssm_state[
-                        mamba_cache_params.state_indices_tensor], 0)
+                    mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
 
             scan_output, varlen_state = mamba_chunk_scan_combined(
-                hidden_states.view(1, seq_len, self.num_heads // self.tp_size,
-                                   self.head_dim),
-                dt.unsqueeze(0),
+                hidden_states_p.view(1, num_prefill_tokens,
+                                     self.num_heads // self.tp_size,
+                                     self.head_dim),
+                dt_p.unsqueeze(0),
                 self.A,
-                B.view(1, seq_len, self.n_groups // self.tp_size, -1),
-                C.view(1, seq_len, self.n_groups // self.tp_size, -1),
+                B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
+                         -1),
+                C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
+                         -1),
                 chunk_size=mamba2_metadata.chunk_size,
                 D=self.D,
                 z=None,
                 dt_bias=self.dt_bias,
                 seq_idx=mamba2_metadata.seq_idx,
                 chunk_indices=mamba2_metadata.chunk_indices,
                 chunk_offsets=mamba2_metadata.chunk_offsets,
-                cu_seqlens=attn_metadata.query_start_loc,
+                cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1],
                 initial_states=initial_states,
                 return_varlen_states=True,
                 return_final_states=False,
@@ -487,52 +504,65 @@ def forward_cuda(
             )
 
             # update ssm states
-            # - varlen state is a (batch, nheads, headdim, dstate) tensor
-            mamba_cache_params.ssm_state[
-                mamba_cache_params.state_indices_tensor] = varlen_state
+            # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
+            mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state
 
             # - reshape
-            hidden_states = scan_output.view(seq_len, -1)
-        else:
+            ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
 
+        # Process decode requests
+        if has_decode:
+            # 2. Convolution sequence transformation
+            hidden_states_B_C_d = causal_conv1d_update(
+                hidden_states_B_C_d,
+                mamba_cache_params.conv_state,
+                conv_weights,
+                self.conv1d.bias,
+                self.activation,
+                conv_state_indices=state_indices_tensor_d)
+
+            hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(
+                hidden_states_B_C_d)
+
+            # 3. State Space Model sequence transformation
             n_groups = self.n_groups // self.tp_size
-            A = self.A[:, None, ...][:, :, None].expand(
+            A_d = self.A[:, None, ...][:, :, None].expand(
                 -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
-            dt = dt[:, :, None].expand(-1, -1, self.head_dim)
+            dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
             dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
-            D = self.D[:, None, ...].expand(-1, self.head_dim)
-            B = B.view(-1, n_groups, B.shape[1] // n_groups)
-            C = C.view(-1, n_groups, C.shape[1] // n_groups)
-            hidden_states_reshaped = hidden_states.view(
+            D_d = self.D[:, None, ...].expand(-1, self.head_dim)
+            B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
+            C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
+            hidden_states_d = hidden_states_d.view(
                 -1, self.num_heads // self.tp_size, self.head_dim)
 
-            # - the hidden is reshaped into number of current batches
-            # - in this case there is no more prefill, so the batches gen
-            #   1 token at a time
-            # - thus hidden will be (bs, num_heads, head_dim)
+            # - the hidden is reshaped into (bs, num_heads, head_dim)
             # - mamba_cache_params.ssm_state's slots will be selected
-            #   using "mamba_cache_params.state_indices_tensor", just as
-            #   above in the prefill case
+            #   using state_indices_tensor_d
 
-            hidden_states = selective_state_update(
+            hidden_states_d = selective_state_update(
                 mamba_cache_params.ssm_state,
-                hidden_states_reshaped,
-                dt,
-                A,
-                B,
-                C,
-                D,
+                hidden_states_d,
+                dt_d,
+                A_d,
+                B_d,
+                C_d,
+                D_d,
                 z=None,
                 dt_bias=dt_bias,
                 dt_softplus=True,
-                state_batch_indices=mamba_cache_params.state_indices_tensor,
+                state_batch_indices=state_indices_tensor_d,
            )
-            hidden_states = hidden_states.view(
-                -1, (self.num_heads // self.tp_size) * self.head_dim)
+            ssd_output_list.append(
+                hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
+                                     self.head_dim))
+
+        # Merge prefill and decode outputs before passing to gated MLP
+        hidden_states = torch.vstack(ssd_output_list)
 
-        # # 4. gated MLP
+        # 4. gated MLP
         hidden_states = self.norm(hidden_states, gate)
 
-        # # 5. Final linear projection
+        # 5. Final linear projection
         out, _ = self.out_proj(hidden_states)
         return out
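
A minimal sketch of the split-then-merge control flow introduced above (toy shapes, plain PyTorch; the stand-in arithmetic takes the place of the real conv1d and SSD kernels):

import torch

num_prefill_tokens, num_decodes, hidden = 9, 4, 16
# V0 ordering assumption: prefill tokens first, then one token per decode request
hidden_states = torch.randn(num_prefill_tokens + num_decodes, hidden)

hs_p, hs_d = torch.split(hidden_states, [num_prefill_tokens, num_decodes], dim=0)

ssd_output_list = []
if num_prefill_tokens > 0:
    ssd_output_list.append(hs_p * 2.0)  # stand-in for causal_conv1d_fn + mamba_chunk_scan_combined
if num_decodes > 0:
    ssd_output_list.append(hs_d * 2.0)  # stand-in for causal_conv1d_update + selective_state_update

# merged back in the original token order before the gated MLP / output projection
hidden_states = torch.vstack(ssd_output_list)
assert hidden_states.shape == (num_prefill_tokens + num_decodes, hidden)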

vllm/model_executor/layers/mamba/ops/ssd_combined.py

Lines changed: 0 additions & 1 deletion

@@ -40,7 +40,6 @@ def _mamba_chunk_scan_combined_fwd(x,
     _, _, ngroups, dstate = B.shape
     assert nheads % ngroups == 0
     assert B.shape == (batch, seqlen, ngroups, dstate)
-    assert x.shape == (batch, seqlen, nheads, headdim)
     assert dt.shape == (batch, seqlen, nheads)
     assert A.shape == (nheads, )
     assert C.shape == B.shape
