
Commit ba8beaf

split inputs also for causal_conv1d

Signed-off-by: Chih-Chieh-Yang <[email protected]>
Parent: 20452d3

File tree: 2 files changed, +65 -77 lines

vllm/model_executor/layers/mamba/mamba2_metadata.py (12 additions, 11 deletions)
```diff
@@ -69,24 +69,25 @@ def prepare_mamba2_metadata(
     num_prefills = attn_metadata.num_prefills
     num_prefill_tokens = attn_metadata.num_prefill_tokens
 
+    seq_idx = None
+    chunk_indices, chunk_offsets = None, None
     # Need flags to indicate if there are initial states
     # currently we really only support the FlashAttention backend
     has_initial_states = None
     prep_initial_states = False
-    if (isinstance(attn_metadata, (FlashAttentionMetadata, XFormersMetadata,
-                                   PlaceholderAttentionMetadata))
-            and attn_metadata.context_lens_tensor is not None):
-        # keeping flags for both prefill and decode causal_conv1d varlen
-        has_initial_states = attn_metadata.context_lens_tensor > 0  # [batch,]
-        # precompute flag to avoid device syncs later in mamba2 layer forwards
-        # prep is only needed for mamba2 ssd prefill processing
-        prep_initial_states = torch.any(
-            has_initial_states[:num_prefills]).item()
 
     # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
-    seq_idx = None
-    chunk_indices, chunk_offsets = None, None
     if num_prefills > 0:
+        if (isinstance(attn_metadata,
+                       (FlashAttentionMetadata, XFormersMetadata,
+                        PlaceholderAttentionMetadata))
+                and attn_metadata.context_lens_tensor is not None):
+            has_initial_states = \
+                attn_metadata.context_lens_tensor[:num_prefills] > 0  # [batch,]
+            # precompute flag to avoid device syncs in mamba2 layer forwards
+            # prep is only needed for mamba2 ssd prefill processing
+            prep_initial_states = torch.any(has_initial_states).item()
+
         query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1]
         seq_idx = torch.repeat_interleave(torch.arange(
            num_prefills, dtype=torch.int32, device=query_start_loc.device),
```
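
The interesting detail in this hunk is that `has_initial_states` is now computed only over the first `num_prefills` entries and only when prefills exist, while `seq_idx` labels every prefill token with the index of its sequence. A minimal sketch of the resulting metadata (toy sizes, and assuming the truncated `repeat_interleave` call takes per-sequence query lengths, i.e. the consecutive differences of `query_start_loc`, as its repeats):

```python
import torch

# Toy batch: 3 prefill requests with query lengths 4, 2 and 3.
num_prefills = 3
query_start_loc = torch.tensor([0, 4, 6, 9], dtype=torch.int32)
context_lens_tensor = torch.tensor([0, 7, 0], dtype=torch.int32)  # cached tokens per request

# seq_idx: one entry per prefill token, naming the sequence it belongs to.
seq_idx = torch.repeat_interleave(
    torch.arange(num_prefills, dtype=torch.int32),
    torch.diff(query_start_loc).long())
print(seq_idx)  # tensor([0, 0, 0, 0, 1, 1, 2, 2, 2], dtype=torch.int32)

# Prefill-only initial-state flags, mirroring the new logic above.
has_initial_states = context_lens_tensor[:num_prefills] > 0
prep_initial_states = torch.any(has_initial_states).item()
print(has_initial_states.tolist(), prep_initial_states)  # [False, True, False] True
```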

vllm/model_executor/layers/mamba/mamba_mixer2.py (53 additions, 66 deletions)
```diff
@@ -412,71 +412,13 @@ def forward_cuda(
             dim=-1,
         )
 
-        # 2. Convolution sequence transformation
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                                self.conv1d.weight.size(2))
 
-        # causal_conv1d_fn deals with both prefill and decode if input
-        # has prefill requests.
-        if has_prefill:
-            # |---------- N-1 iteration --------|
-            # |---------------- N iteration ---------------------|
-            # |- tokenA -|......................|-- newTokens ---|
-            # |---------- context_len ----------|
-            # |-------------------- seq_len ---------------------|
-            #                                   |-- query_len ---|
-
-            # - "cache_indices" updates the conv_state cache in positions
-            #   pointed to by "mamba_cache_params.state_indices_tensor"
-            hidden_states_B_C = causal_conv1d_fn(
-                hidden_states_B_C.transpose(0, 1),
-                conv_weights,
-                self.conv1d.bias,
-                activation=self.activation,
-                conv_states=mamba_cache_params.conv_state,
-                has_initial_state=mamba2_metadata.has_initial_states,
-                cache_indices=mamba_cache_params.state_indices_tensor,
-                query_start_loc=attn_metadata.query_start_loc).transpose(
-                    0, 1)[:seq_len]
-
-            # TODO: Why is this needed?
-            hidden_states_B_C = hidden_states_B_C.contiguous()
-        else:
-            hidden_states_B_C = causal_conv1d_update(
-                hidden_states_B_C,
-                mamba_cache_params.conv_state,
-                conv_weights,
-                self.conv1d.bias,
-                self.activation,
-                conv_state_indices=mamba_cache_params.state_indices_tensor)
-
-        # - get hidden_states, B and C after depthwise convolution.
-        hidden_states, B, C = torch.split(
-            hidden_states_B_C,
-            [
-                self.intermediate_size // self.tp_size,
-                groups_time_state_size // self.tp_size,
-                groups_time_state_size // self.tp_size,
-            ],
-            dim=-1,
-        )
-
-        # 3. State Space Model sequence transformation
-
         # Separate prefill and decode by splitting varlen input
         # Split along token dimension
-        hidden_states_p, hidden_states_d = torch.split(
-            hidden_states,
-            [num_prefill_tokens, num_decodes],
-            dim=0,
-        )
-        B_p, B_d = torch.split(
-            B,
-            [num_prefill_tokens, num_decodes],
-            dim=0,
-        )
-        C_p, C_d = torch.split(
-            C,
+        hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
+            hidden_states_B_C,
             [num_prefill_tokens, num_decodes],
             dim=0,
         )
```
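
This is the core of the change: rather than running the convolution over the whole varlen batch and splitting `hidden_states`, `B`, and `C` afterwards, the fused `hidden_states_B_C` tensor is split along the token dimension first, so each branch below convolves only its own tokens. A minimal sketch of the split, with made-up sizes:

```python
import torch

# Hypothetical flattened batch: 9 prefill tokens followed by 2 single-token
# decodes, laid out contiguously along dim 0 as in vLLM's varlen batching.
num_prefill_tokens, num_decodes, width = 9, 2, 16
hidden_states_B_C = torch.randn(num_prefill_tokens + num_decodes, width)

# Token-dimension split: prefill tokens first, then decode tokens.
hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
    hidden_states_B_C, [num_prefill_tokens, num_decodes], dim=0)

print(hidden_states_B_C_p.shape)  # torch.Size([9, 16])
print(hidden_states_B_C_d.shape)  # torch.Size([2, 16])
```

`torch.split` returns views, so the split itself copies nothing; each branch reads its slice in place.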
```diff
@@ -491,18 +433,50 @@ def forward_cuda(
             [num_prefills, num_decodes],
             dim=0,
         )
+        query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
+                             if has_prefill else None)
 
-        hidden_states_list = []
+        # - get hidden_states, B and C after depthwise convolution.
+        split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
+            hidden_states_B_C,
+            [
+                self.intermediate_size // self.tp_size,
+                groups_time_state_size // self.tp_size,
+                groups_time_state_size // self.tp_size,
+            ],
+            dim=-1,
+        )
+
+        ssd_output_list = []
 
         # Process prefill requests
         if has_prefill:
+            # 2. Convolution sequence transformation
+            # - "cache_indices" updates the conv_state cache in positions
+            #   pointed to by "mamba_cache_params.state_indices_tensor"
+            hidden_states_B_C_p = causal_conv1d_fn(
+                hidden_states_B_C_p.transpose(0, 1),
+                conv_weights,
+                self.conv1d.bias,
+                activation=self.activation,
+                conv_states=mamba_cache_params.conv_state,
+                has_initial_state=mamba2_metadata.has_initial_states,
+                cache_indices=state_indices_tensor_p,
+                query_start_loc=query_start_loc_p).transpose(
+                    0, 1)[:num_prefill_tokens]
+
+            # TODO: Why is this needed?
+            hidden_states_B_C_p = hidden_states_B_C_p.contiguous()
+            hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
+                hidden_states_B_C_p)
+
+            # 3. State Space Model sequence transformation
             initial_states = None
             if (mamba2_metadata.has_initial_states is not None
                     and mamba2_metadata.prep_initial_states):
                 # making a copy of the states
                 initial_states = torch.where(
-                    mamba2_metadata.has_initial_states[:num_prefills, None,
-                                                       None, None],
+                    mamba2_metadata.has_initial_states[:, None, None, None],
                     mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
 
             scan_output, varlen_state = mamba_chunk_scan_combined(
```
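
Two details here deserve unpacking: `split_hidden_states_B_C_fn` carves the post-convolution tensor into `hidden_states`, `B`, and `C` along the feature dimension, and the simplified `torch.where` zeroes the cached SSM state for any sequence without prior context (the `[:num_prefills]` slice is gone because `has_initial_states` is already prefill-only after the metadata change). A sketch with made-up sizes:

```python
import torch

# Feature-dim split, as split_hidden_states_B_C_fn does (toy widths).
intermediate, groups_time_state = 8, 4
x_p = torch.randn(9, intermediate + 2 * groups_time_state)
hidden_states, B, C = torch.split(
    x_p, [intermediate, groups_time_state, groups_time_state], dim=-1)
print(hidden_states.shape, B.shape, C.shape)  # [9, 8] [9, 4] [9, 4]

# Zero the initial SSM state where a sequence has no prior context:
# the [:, None, None, None] broadcast lets one bool per sequence select
# an entire [heads, head_dim, state] block (hypothetical sizes below).
has_initial_states = torch.tensor([False, True, False])
cached = torch.randn(3, 2, 4, 5)  # [batch, heads, head_dim, state]
initial_states = torch.where(
    has_initial_states[:, None, None, None], cached, 0)
print(initial_states[0].abs().sum().item())  # 0.0 -- state cleared
```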
```diff
@@ -535,10 +509,23 @@ def forward_cuda(
             mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state
 
             # - reshape
-            hidden_states_list.append(scan_output.view(num_prefill_tokens, -1))
+            ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
 
         # Process decode requests
         if has_decode:
+            # 2. Convolution sequence transformation
+            hidden_states_B_C_d = causal_conv1d_update(
+                hidden_states_B_C_d,
+                mamba_cache_params.conv_state,
+                conv_weights,
+                self.conv1d.bias,
+                self.activation,
+                conv_state_indices=state_indices_tensor_d)
+
+            hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(
+                hidden_states_B_C_d)
+
+            # 3. State Space Model sequence transformation
             n_groups = self.n_groups // self.tp_size
             A_d = self.A[:, None, ...][:, :, None].expand(
                 -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
```
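
For intuition about the decode branch, here is a pure-PyTorch rendering of what a single-token depthwise causal-conv update does: shift each sequence's rolling window, append the new token, and take a per-channel dot product with the filter. This is only a reference sketch under assumed shapes, not the real `causal_conv1d_update` kernel, which additionally handles the `conv_state_indices` cache indirection and fuses the activation:

```python
import torch
import torch.nn.functional as F

def conv1d_update_ref(x, conv_state, weight, bias):
    """x: [batch, dim] new token; conv_state: [batch, dim, width] rolling
    window, updated in place; weight: [dim, width] depthwise filters."""
    # Shift the window left by one slot and append the incoming token.
    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
    conv_state[:, :, -1] = x
    # Causal depthwise conv at the newest position = per-channel dot product.
    out = torch.einsum('bdw,dw->bd', conv_state, weight) + bias
    return F.silu(out)  # stand-in for self.activation

batch, dim, width = 2, 16, 4
out = conv1d_update_ref(torch.randn(batch, dim),
                        torch.zeros(batch, dim, width),
                        torch.randn(dim, width), torch.randn(dim))
print(out.shape)  # torch.Size([2, 16])
```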
```diff
@@ -567,12 +554,12 @@ def forward_cuda(
                 dt_softplus=True,
                 state_batch_indices=state_indices_tensor_d,
             )
-            hidden_states_list.append(
+            ssd_output_list.append(
                 hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
                                      self.head_dim))
 
         # Merge prefill and decode outputs before passing to gated MLP
-        hidden_states = torch.vstack(hidden_states_list)
+        hidden_states = torch.vstack(ssd_output_list)
 
         # 4. gated MLP
         hidden_states = self.norm(hidden_states, gate)
```
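
Finally, the renamed `ssd_output_list` is stacked back into one token-major tensor; because the prefill output is appended before the decode output, `vstack` restores the original varlen token order expected by the gated MLP. A tiny sketch:

```python
import torch

prefill_out = torch.randn(9, 16)  # hypothetical [num_prefill_tokens, heads * head_dim]
decode_out = torch.randn(2, 16)   # hypothetical [num_decodes, heads * head_dim]

ssd_output_list = [prefill_out, decode_out]  # prefill first, then decode
hidden_states = torch.vstack(ssd_output_list)
print(hidden_states.shape)  # torch.Size([11, 16]) -- matches the input layout
```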
