@@ -453,7 +453,8 @@ def forward_cuda(
             dim=-1,
         )
 
-        # Separate prefill and decode by slicing hidden_states
+        # 3. State Space Model sequence transformation
+        # Separate prefill and decode by slicing varlen input
         num_prefills = mamba2_metadata.num_prefills  # requests
         num_decodes = mamba2_metadata.num_decodes  # requests (also tokens)
         num_prefill_tokens = attn_metadata.num_prefill_tokens  # tokens
@@ -477,10 +478,15 @@ def forward_cuda(
             [num_prefill_tokens, num_decodes],
             dim=0,
         )
+        state_indices_tensor_p, state_indices_tensor_d = torch.split(
+            mamba_cache_params.state_indices_tensor,
+            [num_prefills, num_decodes],
+            dim=0,
+        )
 
         hidden_states_list = []
 
-        # Process Prefills
+        # Process prefill requests
         if num_prefills > 0:
             initial_states = None
             if (mamba2_metadata.has_initial_states is not None
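For reference, here is a minimal standalone sketch (not part of the diff; sizes and index values are invented) of what the new `torch.split` call does: it partitions the per-request cache-slot indices once into a prefill view and a decode view, relying on prefill requests being ordered before decode requests, so the repeated `[:num_prefills]` / `[num_prefills:]` slicing further down can be replaced.

```python
import torch

num_prefills, num_decodes = 2, 3
# hypothetical cache-slot index per request (prefill requests first)
state_indices_tensor = torch.tensor([7, 0, 4, 2, 5])

state_indices_tensor_p, state_indices_tensor_d = torch.split(
    state_indices_tensor, [num_prefills, num_decodes], dim=0)

# the two views are exactly the slices the old code computed inline
assert torch.equal(state_indices_tensor_p, state_indices_tensor[:num_prefills])
assert torch.equal(state_indices_tensor_d, state_indices_tensor[num_prefills:])
```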
@@ -489,9 +495,7 @@ def forward_cuda(
                 initial_states = torch.where(
                     mamba2_metadata.has_initial_states[:num_prefills, None,
                                                        None, None],
-                    mamba_cache_params.ssm_state[
-                        mamba_cache_params.
-                        state_indices_tensor[:num_prefills]], 0)
+                    mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
 
             scan_output, varlen_state = mamba_chunk_scan_combined(
                 hidden_states_p.view(1, num_prefill_tokens,
@@ -520,13 +524,12 @@ def forward_cuda(
 
             # update ssm states
             # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
-            mamba_cache_params.ssm_state[
-                mamba_cache_params.
-                state_indices_tensor[:num_prefills]] = varlen_state
+            mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state
 
             # - reshape
             hidden_states_list.append(scan_output.view(num_prefill_tokens, -1))
 
+        # Process decode requests
         if num_decodes > 0:
             n_groups = self.n_groups // self.tp_size
             A_d = self.A[:, None, ...][:, :, None].expand(
@@ -558,8 +561,7 @@ def forward_cuda(
                 z=None,
                 dt_bias=dt_bias,
                 dt_softplus=True,
-                state_batch_indices=mamba_cache_params.
-                state_indices_tensor[num_prefills:],  # take decodes only
+                state_batch_indices=state_indices_tensor_d,
             )
             hidden_states_list.append(
                 hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
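A second standalone sketch (again not part of the diff; all shapes are invented) of how the simplified prefill-side indexing behaves: `state_indices_tensor_p` gathers each prefill request's cached SSM state, the `has_initial_states` mask zeroes out requests with no prior state, and the same index tensor later scatters the updated states back into the cache.

```python
import torch

num_prefills, nheads, headdim, dstate = 2, 4, 8, 16
num_cache_slots = 6

ssm_state = torch.randn(num_cache_slots, nheads, headdim, dstate)
state_indices_tensor_p = torch.tensor([3, 1])     # cache slot per prefill request
has_initial_states = torch.tensor([True, False])  # one flag per prefill request

# gather cached states, zero them where no prior state exists;
# the mask broadcasts over the (nheads, headdim, dstate) dims
initial_states = torch.where(has_initial_states[:, None, None, None],
                             ssm_state[state_indices_tensor_p], 0)
assert initial_states.shape == (num_prefills, nheads, headdim, dstate)
assert torch.equal(initial_states[1], torch.zeros(nheads, headdim, dstate))

# after the chunked scan, writing back through the same index tensor
# updates only the cache slots owned by the prefill requests
varlen_state = torch.randn(num_prefills, nheads, headdim, dstate)
ssm_state[state_indices_tensor_p] = varlen_state
assert torch.equal(ssm_state[3], varlen_state[0])
```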