@@ -412,71 +412,13 @@ def forward_cuda(
             dim=-1,
         )
 
-        # 2. Convolution sequence transformation
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                                self.conv1d.weight.size(2))
 
-        # causal_conv1d_fn deals with both prefill and decode if input
-        # has prefill requests.
-        if has_prefill:
-            # |---------- N-1 iteration --------|
-            # |---------------- N iteration ---------------------|
-            # |- tokenA -|......................|-- newTokens ---|
-            # |---------- context_len ----------|
-            # |-------------------- seq_len ---------------------|
-            #                                   |-- query_len ---|
-
-            # - "cache_indices" updates the conv_state cache in positions
-            #   pointed to by "mamba_cache_params.state_indices_tensor"
-            hidden_states_B_C = causal_conv1d_fn(
-                hidden_states_B_C.transpose(0, 1),
-                conv_weights,
-                self.conv1d.bias,
-                activation=self.activation,
-                conv_states=mamba_cache_params.conv_state,
-                has_initial_state=mamba2_metadata.has_initial_states,
-                cache_indices=mamba_cache_params.state_indices_tensor,
-                query_start_loc=attn_metadata.query_start_loc).transpose(
-                    0, 1)[:seq_len]
-
-            # TODO: Why is this needed?
-            hidden_states_B_C = hidden_states_B_C.contiguous()
-        else:
-            hidden_states_B_C = causal_conv1d_update(
-                hidden_states_B_C,
-                mamba_cache_params.conv_state,
-                conv_weights,
-                self.conv1d.bias,
-                self.activation,
-                conv_state_indices=mamba_cache_params.state_indices_tensor)
-
-        # - get hidden_states, B and C after depthwise convolution.
-        hidden_states, B, C = torch.split(
-            hidden_states_B_C,
-            [
-                self.intermediate_size // self.tp_size,
-                groups_time_state_size // self.tp_size,
-                groups_time_state_size // self.tp_size,
-            ],
-            dim=-1,
-        )
-
-        # 3. State Space Model sequence transformation
-
         # Separate prefill and decode by splitting varlen input
         # Split along token dimension
-        hidden_states_p, hidden_states_d = torch.split(
-            hidden_states,
-            [num_prefill_tokens, num_decodes],
-            dim=0,
-        )
-        B_p, B_d = torch.split(
-            B,
-            [num_prefill_tokens, num_decodes],
-            dim=0,
-        )
-        C_p, C_d = torch.split(
-            C,
+        hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
+            hidden_states_B_C,
             [num_prefill_tokens, num_decodes],
             dim=0,
         )
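Note: this hunk splits the fused hidden_states_B_C projection into prefill and decode token groups once, before the convolution, instead of convolving the mixed batch and then splitting hidden_states, B, and C with three separate calls. A minimal sketch of the token-dimension split below; all sizes are made-up illustrations, not the model's real dimensions:

import torch

# Illustrative sizes only; the real tensor comes from the in_proj split above.
num_prefill_tokens, num_decodes, conv_dim = 7, 3, 16
hidden_states_B_C = torch.randn(num_prefill_tokens + num_decodes, conv_dim)

# Varlen layout assumed by this code path: all prefill tokens packed first,
# then one token per decoding sequence.
hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
    hidden_states_B_C, [num_prefill_tokens, num_decodes], dim=0)
assert hidden_states_B_C_p.shape == (num_prefill_tokens, conv_dim)
assert hidden_states_B_C_d.shape == (num_decodes, conv_dim)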
@@ -491,18 +433,50 @@ def forward_cuda(
             [num_prefills, num_decodes],
             dim=0,
         )
+        query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
+                             if has_prefill else None)
 
-        hidden_states_list = []
+        # - get hidden_states, B and C after depthwise convolution.
+        split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
+            hidden_states_B_C,
+            [
+                self.intermediate_size // self.tp_size,
+                groups_time_state_size // self.tp_size,
+                groups_time_state_size // self.tp_size,
+            ],
+            dim=-1,
+        )
+
+        ssd_output_list = []
 
         # Process prefill requests
         if has_prefill:
+            # 2. Convolution sequence transformation
+            # - "cache_indices" updates the conv_state cache in positions
+            #   pointed to by "mamba_cache_params.state_indices_tensor"
+            hidden_states_B_C_p = causal_conv1d_fn(
+                hidden_states_B_C_p.transpose(0, 1),
+                conv_weights,
+                self.conv1d.bias,
+                activation=self.activation,
+                conv_states=mamba_cache_params.conv_state,
+                has_initial_state=mamba2_metadata.has_initial_states,
+                cache_indices=state_indices_tensor_p,
+                query_start_loc=query_start_loc_p).transpose(
+                    0, 1)[:num_prefill_tokens]
+
+            # TODO: Why is this needed?
+            hidden_states_B_C_p = hidden_states_B_C_p.contiguous()
+            hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
+                hidden_states_B_C_p)
+
+            # 3. State Space Model sequence transformation
             initial_states = None
             if (mamba2_metadata.has_initial_states is not None
                     and mamba2_metadata.prep_initial_states):
                 # making a copy of the states
                 initial_states = torch.where(
-                    mamba2_metadata.has_initial_states[:num_prefills, None,
-                                                       None, None],
+                    mamba2_metadata.has_initial_states[:, None, None, None],
                     mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
 
             scan_output, varlen_state = mamba_chunk_scan_combined(
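Note: split_hidden_states_B_C_fn factors out the channel-wise split that both branches now need, since the conv output packs the SSM input together with B and C along the last dimension; the [:, None, None, None] indexing in the torch.where broadcasts the per-sequence boolean over the state dimensions so sequences without cached history start from zeros. A small sketch of the channel split, with assumed per-rank sizes (already divided by tp_size):

import torch

intermediate_size, groups_time_state_size = 32, 8  # illustrative assumptions
fused = torch.randn(10, intermediate_size + 2 * groups_time_state_size)

# The same split that split_hidden_states_B_C_fn applies to either partition.
hidden_states, B, C = torch.split(
    fused,
    [intermediate_size, groups_time_state_size, groups_time_state_size],
    dim=-1,
)
assert hidden_states.shape[-1] == intermediate_size
assert B.shape == C.shape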
@@ -535,10 +509,23 @@ def forward_cuda(
             mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state
 
             # - reshape
-            hidden_states_list.append(scan_output.view(num_prefill_tokens, -1))
+            ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
 
         # Process decode requests
         if has_decode:
+            # 2. Convolution sequence transformation
+            hidden_states_B_C_d = causal_conv1d_update(
+                hidden_states_B_C_d,
+                mamba_cache_params.conv_state,
+                conv_weights,
+                self.conv1d.bias,
+                self.activation,
+                conv_state_indices=state_indices_tensor_d)
+
+            hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(
+                hidden_states_B_C_d)
+
+            # 3. State Space Model sequence transformation
             n_groups = self.n_groups // self.tp_size
             A_d = self.A[:, None, ...][:, :, None].expand(
                 -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
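Note: on the decode path each sequence contributes exactly one token, so causal_conv1d_update only has to advance the cached convolution window by one position. A pure-PyTorch sketch of the per-token semantics; the helper name and all shapes are hypothetical, based on the kernel's reference behavior rather than its actual signature (the real kernel also scatters into the paged cache via conv_state_indices):

import torch
import torch.nn.functional as F

def conv1d_update_sketch(x, conv_state, weight, bias):
    # x:          (batch, dim)              one new token per sequence
    # conv_state: (batch, dim, kernel_size) rolling window of recent inputs
    # weight:     (dim, kernel_size)        depthwise filter per channel
    conv_state.copy_(conv_state.roll(shifts=-1, dims=-1))  # drop oldest column
    conv_state[:, :, -1] = x                               # append newest token
    out = torch.einsum("bdk,dk->bd", conv_state, weight) + bias
    return F.silu(out)  # assuming the "silu" activation used by this layer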
@@ -567,12 +554,12 @@ def forward_cuda(
                 dt_softplus=True,
                 state_batch_indices=state_indices_tensor_d,
             )
-            hidden_states_list.append(
+            ssd_output_list.append(
                 hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
                                      self.head_dim))
 
         # Merge prefill and decode outputs before passing to gated MLP
-        hidden_states = torch.vstack(hidden_states_list)
+        hidden_states = torch.vstack(ssd_output_list)
 
         # 4. gated MLP
         hidden_states = self.norm(hidden_states, gate)
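Note: because the prefill output is appended to ssd_output_list before the decode output, the final torch.vstack restores the token order of the incoming varlen batch (prefill tokens first, then one row per decode), giving the gated MLP a single contiguous (num_prefill_tokens + num_decodes, hidden) tensor. A toy illustration with assumed sizes:

import torch

prefill_out = torch.randn(7, 64)  # num_prefill_tokens x hidden (illustrative)
decode_out = torch.randn(3, 64)   # num_decodes x hidden (illustrative)
merged = torch.vstack([prefill_out, decode_out])
assert merged.shape == (10, 64)   # same token order as the varlen input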