@@ -388,10 +388,15 @@ def forward_cuda(
         # mamba2_metadata contains metadata necessary for the mamba2 triton
         # kernels to operate in continuous batching and in chunked prefill
         # modes; they are computed at top-level model forward since they
-        # are the same and reused for all mamba layers in the same iteration
+        # stay the same and reused for all mamba layers in the same iteration
         attn_metadata: AttentionMetadata = get_forward_context().attn_metadata

-        seq_len, _ = hidden_states.shape
+        num_prefills = attn_metadata.num_prefills  # request count
+        num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
+        num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
+        has_prefill = num_prefills > 0
+        has_decode = num_decodes > 0
+
         groups_time_state_size = self.n_groups * self.ssm_state_size

         # 1. Gated MLP's linear projection
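A minimal sketch (not part of the patch) of how the new counters relate, assuming the flattened token batch lays out all prefill tokens first, followed by one token per decode request; the request counts, lengths, and sizes below are made up for illustration.

import torch

# Hypothetical batch: 2 prefill requests of lengths 5 and 3, then 4 decodes.
num_prefills = 2
num_prefill_tokens = 5 + 3
num_decodes = 4  # one token per decode request

hidden_size = 16
hidden_states = torch.randn(num_prefill_tokens + num_decodes, hidden_size)

# The same token-dimension split the patch applies to hidden_states_B_C / dt.
prefill_tokens, decode_tokens = torch.split(
    hidden_states, [num_prefill_tokens, num_decodes], dim=0)
assert prefill_tokens.shape == (num_prefill_tokens, hidden_size)
assert decode_tokens.shape == (num_decodes, hidden_size)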
@@ -406,44 +411,32 @@ def forward_cuda(
             dim=-1,
         )

-        # 2. Convolution sequence transformation
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                                self.conv1d.weight.size(2))

-        if mamba2_metadata.has_prefill:
-            # |---------- N-1 iteration --------|
-            # |---------------- N iteration ---------------------|
-            # |- tokenA -|......................|-- newTokens ---|
-            # |---------- context_len ----------|
-            # |-------------------- seq_len ---------------------|
-            #                                   |-- query_len ---|
-
-            # - "cache_indices" updates the conv_state cache in positions
-            #   pointed to by "mamba_cache_params.state_indices_tensor"
-            hidden_states_B_C = causal_conv1d_fn(
-                hidden_states_B_C.transpose(0, 1),
-                conv_weights,
-                self.conv1d.bias,
-                activation=self.activation,
-                conv_states=mamba_cache_params.conv_state,
-                has_initial_state=mamba2_metadata.has_initial_states,
-                cache_indices=mamba_cache_params.state_indices_tensor,
-                query_start_loc=attn_metadata.query_start_loc).transpose(
-                    0, 1)[:seq_len]
-
-            # TODO: Why is this needed?
-            hidden_states_B_C = hidden_states_B_C.contiguous()
-        else:
-            hidden_states_B_C = causal_conv1d_update(
-                hidden_states_B_C,
-                mamba_cache_params.conv_state,
-                conv_weights,
-                self.conv1d.bias,
-                self.activation,
-                conv_state_indices=mamba_cache_params.state_indices_tensor)
+        # Separate prefill and decode by splitting varlen input
+        # Split along token dimension
+        hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
+            hidden_states_B_C,
+            [num_prefill_tokens, num_decodes],
+            dim=0,
+        )
+        dt_p, dt_d = torch.split(
+            dt,
+            [num_prefill_tokens, num_decodes],
+            dim=0,
+        )
+        # Split along batch dimension
+        state_indices_tensor_p, state_indices_tensor_d = torch.split(
+            mamba_cache_params.state_indices_tensor,
+            [num_prefills, num_decodes],
+            dim=0,
+        )
+        query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
+                             if has_prefill else None)

         # - get hidden_states, B and C after depthwise convolution.
-        hidden_states, B, C = torch.split(
+        split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
             hidden_states_B_C,
             [
                 self.intermediate_size // self.tp_size,
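A side note on query_start_loc_p (again a sketch, not part of the patch): assuming query_start_loc holds cumulative query lengths for all requests with the prefill requests ordered first, the prefill-only varlen kernels need only its first num_prefills + 1 entries. Values below are made up.

import torch

# 2 prefill requests of lengths 5 and 3, followed by 4 single-token decodes.
num_prefills = 2
query_start_loc = torch.tensor([0, 5, 8, 9, 10, 11, 12])

# Slice used for the prefill convolution (and as cu_seqlens further down).
query_start_loc_p = query_start_loc[:num_prefills + 1]
assert query_start_loc_p.tolist() == [0, 5, 8]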
@@ -453,32 +446,56 @@ def forward_cuda(
             dim=-1,
         )

-        # 3. State Space Model sequence transformation
-        if mamba2_metadata.has_prefill:
+        ssd_output_list = []
+
+        # Process prefill requests
+        if has_prefill:
+            # 2. Convolution sequence transformation
+            # - "cache_indices" updates the conv_state cache in positions
+            #   pointed to by "mamba_cache_params.state_indices_tensor"
+            hidden_states_B_C_p = causal_conv1d_fn(
+                hidden_states_B_C_p.transpose(0, 1),
+                conv_weights,
+                self.conv1d.bias,
+                activation=self.activation,
+                conv_states=mamba_cache_params.conv_state,
+                has_initial_state=mamba2_metadata.has_initial_states,
+                cache_indices=state_indices_tensor_p,
+                query_start_loc=query_start_loc_p).transpose(
+                    0, 1)[:num_prefill_tokens]
+
+            # TODO: Why is this needed?
+            hidden_states_B_C_p = hidden_states_B_C_p.contiguous()
+            hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
+                hidden_states_B_C_p)
+
+            # 3. State Space Model sequence transformation
             initial_states = None
             if (mamba2_metadata.has_initial_states is not None
                     and mamba2_metadata.prep_initial_states):
                 # making a copy of the states
                 initial_states = torch.where(
                     mamba2_metadata.has_initial_states[:, None, None, None],
-                    mamba_cache_params.ssm_state[
-                        mamba_cache_params.state_indices_tensor], 0)
+                    mamba_cache_params.ssm_state[state_indices_tensor_p], 0)

             scan_output, varlen_state = mamba_chunk_scan_combined(
-                hidden_states.view(1, seq_len, self.num_heads // self.tp_size,
-                                   self.head_dim),
-                dt.unsqueeze(0),
+                hidden_states_p.view(1, num_prefill_tokens,
+                                     self.num_heads // self.tp_size,
+                                     self.head_dim),
+                dt_p.unsqueeze(0),
                 self.A,
-                B.view(1, seq_len, self.n_groups // self.tp_size, -1),
-                C.view(1, seq_len, self.n_groups // self.tp_size, -1),
+                B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
+                         -1),
+                C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
+                         -1),
                 chunk_size=mamba2_metadata.chunk_size,
                 D=self.D,
                 z=None,
                 dt_bias=self.dt_bias,
                 seq_idx=mamba2_metadata.seq_idx,
                 chunk_indices=mamba2_metadata.chunk_indices,
                 chunk_offsets=mamba2_metadata.chunk_offsets,
-                cu_seqlens=attn_metadata.query_start_loc,
+                cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1],
                 initial_states=initial_states,
                 return_varlen_states=True,
                 return_final_states=False,
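For reference, a standalone sketch (not part of the patch; tp_size assumed to be 1 and sizes made up) of what split_hidden_states_B_C_fn does with the convolution output:

import torch

intermediate_size, n_groups, ssm_state_size = 8, 2, 4
groups_time_state_size = n_groups * ssm_state_size
num_tokens = 6

hidden_states_B_C = torch.randn(
    num_tokens, intermediate_size + 2 * groups_time_state_size)

# Same channel-dimension split, applied to both the prefill and decode halves.
hidden_states, B, C = torch.split(
    hidden_states_B_C,
    [intermediate_size, groups_time_state_size, groups_time_state_size],
    dim=-1,
)
assert hidden_states.shape == (num_tokens, intermediate_size)
assert B.shape == C.shape == (num_tokens, groups_time_state_size)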
@@ -487,52 +504,65 @@ def forward_cuda(
             )

             # update ssm states
-            # - varlen state is a (batch, nheads, headdim, dstate) tensor
-            mamba_cache_params.ssm_state[
-                mamba_cache_params.state_indices_tensor] = varlen_state
+            # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
+            mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state

             # - reshape
-            hidden_states = scan_output.view(seq_len, -1)
-        else:
+            ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))

+        # Process decode requests
+        if has_decode:
+            # 2. Convolution sequence transformation
+            hidden_states_B_C_d = causal_conv1d_update(
+                hidden_states_B_C_d,
+                mamba_cache_params.conv_state,
+                conv_weights,
+                self.conv1d.bias,
+                self.activation,
+                conv_state_indices=state_indices_tensor_d)
+
+            hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(
+                hidden_states_B_C_d)
+
+            # 3. State Space Model sequence transformation
             n_groups = self.n_groups // self.tp_size
-            A = self.A[:, None, ...][:, :, None].expand(
+            A_d = self.A[:, None, ...][:, :, None].expand(
                 -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
-            dt = dt[:, :, None].expand(-1, -1, self.head_dim)
+            dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
             dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
-            D = self.D[:, None, ...].expand(-1, self.head_dim)
-            B = B.view(-1, n_groups, B.shape[1] // n_groups)
-            C = C.view(-1, n_groups, C.shape[1] // n_groups)
-            hidden_states_reshaped = hidden_states.view(
+            D_d = self.D[:, None, ...].expand(-1, self.head_dim)
+            B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
+            C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
+            hidden_states_d = hidden_states_d.view(
                 -1, self.num_heads // self.tp_size, self.head_dim)

-            # - the hidden is reshaped into number of current batches
-            # - in this case there is no more prefill, so the batches gen
-            #   1 token at a time
-            # - thus hidden will be (bs, num_heads, head_dim)
+            # - the hidden is reshaped into (bs, num_heads, head_dim)
             # - mamba_cache_params.ssm_state's slots will be selected
-            #   using "mamba_cache_params.state_indices_tensor", just as
-            #   above in the prefill case
+            #   using state_indices_tensor_d

-            hidden_states = selective_state_update(
+            hidden_states_d = selective_state_update(
                 mamba_cache_params.ssm_state,
-                hidden_states_reshaped,
-                dt,
-                A,
-                B,
-                C,
-                D,
+                hidden_states_d,
+                dt_d,
+                A_d,
+                B_d,
+                C_d,
+                D_d,
                 z=None,
                 dt_bias=dt_bias,
                 dt_softplus=True,
-                state_batch_indices=mamba_cache_params.state_indices_tensor,
+                state_batch_indices=state_indices_tensor_d,
             )
-            hidden_states = hidden_states.view(
-                -1, (self.num_heads // self.tp_size) * self.head_dim)
+            ssd_output_list.append(
+                hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
+                                     self.head_dim))
+
+        # Merge prefill and decode outputs before passing to gated MLP
+        hidden_states = torch.vstack(ssd_output_list)

-        # # 4. gated MLP
+        # 4. gated MLP
         hidden_states = self.norm(hidden_states, gate)

-        # # 5. Final linear projection
+        # 5. Final linear projection
         out, _ = self.out_proj(hidden_states)
         return out
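Overall, the patch replaces the old all-prefill-or-all-decode branching with a split / process-separately / merge pattern. A minimal sketch of that pattern (not the actual kernels; the per-path transforms below are stand-ins and the sizes are made up):

import torch

num_prefill_tokens, num_decodes, hidden_size = 8, 4, 16
x = torch.randn(num_prefill_tokens + num_decodes, hidden_size)

x_p, x_d = torch.split(x, [num_prefill_tokens, num_decodes], dim=0)

outputs = []
if num_prefill_tokens > 0:
    outputs.append(x_p * 2.0)  # stand-in for the chunked prefill path
if num_decodes > 0:
    outputs.append(x_d + 1.0)  # stand-in for the single-token decode path

# vstack restores the original layout: prefill tokens first, then decodes.
merged = torch.vstack(outputs)
assert merged.shape == x.shape

Because the outputs are appended in prefill-then-decode order, the vstack at the end of the patch preserves the token ordering the downstream norm and output projection expect, assuming the input batch was packed in that same order.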