Skip to content

Commit a7a4561

Browse files
committed
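Minor improvement: split state_indices_tensor into prefill and decode parts once via torch.split and reuse them; tidy section comments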
Signed-off-by: Chih-Chieh-Yang <[email protected]>
1 parent 41e4ddc commit a7a4561

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 12 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -453,7 +453,8 @@ def forward_cuda(
453453
dim=-1,
454454
)
455455

456-
# Separate prefill and decode by slicing hidden_states
456+
# 3. State Space Model sequence transformation
457+
# Separate prefill and decode by slicing varlen input
457458
num_prefills = mamba2_metadata.num_prefills # requests
458459
num_decodes = mamba2_metadata.num_decodes # requests (also tokens)
459460
num_prefill_tokens = attn_metadata.num_prefill_tokens # tokens
@@ -477,10 +478,15 @@ def forward_cuda(
477478
[num_prefill_tokens, num_decodes],
478479
dim=0,
479480
)
481+
state_indices_tensor_p, state_indices_tensor_d = torch.split(
482+
mamba_cache_params.state_indices_tensor,
483+
[num_prefills, num_decodes],
484+
dim=0,
485+
)
480486

481487
hidden_states_list = []
482488

483-
# Process Prefills
489+
# Process prefill requests
484490
if num_prefills > 0:
485491
initial_states = None
486492
if (mamba2_metadata.has_initial_states is not None
@@ -489,9 +495,7 @@ def forward_cuda(
489495
initial_states = torch.where(
490496
mamba2_metadata.has_initial_states[:num_prefills, None,
491497
None, None],
492-
mamba_cache_params.ssm_state[
493-
mamba_cache_params.
494-
state_indices_tensor[:num_prefills]], 0)
498+
mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
495499

496500
scan_output, varlen_state = mamba_chunk_scan_combined(
497501
hidden_states_p.view(1, num_prefill_tokens,
@@ -520,13 +524,12 @@ def forward_cuda(
520524

521525
# update ssm states
522526
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
523-
mamba_cache_params.ssm_state[
524-
mamba_cache_params.
525-
state_indices_tensor[:num_prefills]] = varlen_state
527+
mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state
526528

527529
# - reshape
528530
hidden_states_list.append(scan_output.view(num_prefill_tokens, -1))
529531

532+
# Process decode requests
530533
if num_decodes > 0:
531534
n_groups = self.n_groups // self.tp_size
532535
A_d = self.A[:, None, ...][:, :, None].expand(
@@ -558,8 +561,7 @@ def forward_cuda(
558561
z=None,
559562
dt_bias=dt_bias,
560563
dt_softplus=True,
561-
state_batch_indices=mamba_cache_params.
562-
state_indices_tensor[num_prefills:], # take decodes only
564+
state_batch_indices=state_indices_tensor_d,
563565
)
564566
hidden_states_list.append(
565567
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *

0 commit comments

Comments
 (0)