@@ -3369,25 +3369,18 @@ static struct ggml_cgraph * llm_build_baichaun(
3369
3369
3370
3370
static struct ggml_cgraph * llm_build_refact (
3371
3371
llama_context & lctx,
3372
- const llama_token * tokens,
3373
- const float * embd,
3374
- int n_tokens,
3375
- int n_past) {
3376
-
3377
- GGML_ASSERT ((!tokens && embd) || (tokens && !embd)); // NOLINT
3378
-
3379
- const int N = n_tokens;
3380
-
3372
+ const llama_batch & batch) {
3381
3373
const auto & model = lctx.model ;
3382
3374
const auto & hparams = model.hparams ;
3375
+ const auto & cparams = lctx.cparams ;
3383
3376
3384
3377
const auto & kv_self = lctx.kv_self ;
3385
3378
3386
3379
GGML_ASSERT (!!kv_self.ctx );
3387
3380
3388
3381
const int64_t n_embd = hparams.n_embd ;
3389
3382
const int64_t n_layer = hparams.n_layer ;
3390
- const int64_t n_ctx = hparams .n_ctx ;
3383
+ const int64_t n_ctx = cparams .n_ctx ;
3391
3384
const int64_t n_head = hparams.n_head ;
3392
3385
const int64_t n_head_kv = hparams.n_head_kv ;
3393
3386
const int64_t n_embd_head = hparams.n_embd_head ();
@@ -3397,6 +3390,12 @@ static struct ggml_cgraph * llm_build_refact(
3397
3390
3398
3391
const int n_gpu_layers = model.n_gpu_layers ;
3399
3392
3393
+ const int32_t n_tokens = batch.n_tokens ;
3394
+ const int32_t n_kv = ggml_allocr_is_measure (lctx.alloc ) ? n_ctx : kv_self.n ;
3395
+ const int32_t kv_head = ggml_allocr_is_measure (lctx.alloc ) ? n_ctx - n_tokens : kv_self.head ;
3396
+
3397
+ // printf("n_kv = %d\n", n_kv);
3398
+
3400
3399
auto & buf_compute = lctx.buf_compute ;
3401
3400
3402
3401
struct ggml_init_params params = {
@@ -3414,12 +3413,12 @@ static struct ggml_cgraph * llm_build_refact(
3414
3413
struct ggml_tensor * cur;
3415
3414
struct ggml_tensor * inpL;
3416
3415
3417
- if (tokens ) {
3418
- struct ggml_tensor * inp_tokens = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, N );
3416
+ if (batch. token ) {
3417
+ struct ggml_tensor * inp_tokens = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, n_tokens );
3419
3418
3420
3419
ggml_allocr_alloc (lctx.alloc , inp_tokens);
3421
3420
if (!ggml_allocr_is_measure (lctx.alloc )) {
3422
- memcpy (inp_tokens->data , tokens, N *ggml_element_size (inp_tokens));
3421
+ memcpy (inp_tokens->data , batch. token , n_tokens *ggml_element_size (inp_tokens));
3423
3422
}
3424
3423
ggml_set_name (inp_tokens, " inp_tokens" );
3425
3424
@@ -3429,11 +3428,11 @@ static struct ggml_cgraph * llm_build_refact(
3429
3428
GGML_ASSERT (false && " not implemented" );
3430
3429
#endif
3431
3430
3432
- inpL = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_embd, N );
3431
+ inpL = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_embd, n_tokens );
3433
3432
3434
3433
ggml_allocr_alloc (lctx.alloc , inpL);
3435
3434
if (!ggml_allocr_is_measure (lctx.alloc )) {
3436
- memcpy (inpL->data , embd, N * n_embd * ggml_element_size (inpL));
3435
+ memcpy (inpL->data , batch. embd , n_tokens * n_embd * ggml_element_size (inpL));
3437
3436
}
3438
3437
}
3439
3438
@@ -3442,9 +3441,6 @@ static struct ggml_cgraph * llm_build_refact(
3442
3441
3443
3442
// offload functions set the tensor output backend to GPU
3444
3443
// tensors are GPU-accelerated if any input or the output has been offloaded
3445
- //
3446
- // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal
3447
- // in that case ggml_cuda_assign_buffers has no effect
3448
3444
offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
3449
3445
offload_func_t offload_func_kq = llama_nop;
3450
3446
offload_func_t offload_func_v = llama_nop;
@@ -3461,12 +3457,36 @@ static struct ggml_cgraph * llm_build_refact(
3461
3457
}
3462
3458
#endif // GGML_USE_CUBLAS
3463
3459
3460
+ // KQ_scale
3464
3461
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d (ctx0, GGML_TYPE_F32, 1 );
3462
+ ggml_set_name (KQ_scale, " 1/sqrt(n_embd_head)" );
3465
3463
ggml_allocr_alloc (lctx.alloc , KQ_scale);
3466
3464
if (!ggml_allocr_is_measure (lctx.alloc )) {
3467
- ggml_set_f32 (KQ_scale, 1 .0f /sqrtf (float (n_embd)/n_head));
3465
+ ggml_set_f32 (KQ_scale, 1 .0f /sqrtf (float (n_embd_head)));
3466
+ }
3467
+
3468
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
3469
+ struct ggml_tensor * KQ_mask = ggml_new_tensor_3d (ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1 );
3470
+ offload_func_kq (KQ_mask);
3471
+ ggml_set_name (KQ_mask, " KQ_mask" );
3472
+ ggml_allocr_alloc (lctx.alloc , KQ_mask);
3473
+ if (!ggml_allocr_is_measure (lctx.alloc )) {
3474
+ float * data = (float *) KQ_mask->data ;
3475
+ memset (data, 0 , ggml_nbytes (KQ_mask));
3476
+
3477
+ for (int h = 0 ; h < 1 ; ++h) {
3478
+ for (int j = 0 ; j < n_tokens; ++j) {
3479
+ const llama_pos pos = batch.pos [j];
3480
+ const llama_seq_id seq_id = batch.seq_id [j];
3481
+
3482
+ for (int i = 0 ; i < n_kv; ++i) {
3483
+ if (!kv_self.cells [i].has_seq_id (seq_id) || kv_self.cells [i].pos > pos) {
3484
+ data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
3485
+ }
3486
+ }
3487
+ }
3488
+ }
3468
3489
}
3469
- ggml_set_name (KQ_scale, " 1/sqrt(n_embd_head)" );
3470
3490
3471
3491
for (int il = 0 ; il < n_layer; ++il) {
3472
3492
ggml_format_name (inpL, " layer_inp_%d" , il);
@@ -3504,36 +3524,33 @@ static struct ggml_cgraph * llm_build_refact(
3504
3524
offload_func_kq (tmpq);
3505
3525
ggml_set_name (tmpq, " tmpq" );
3506
3526
3507
- struct ggml_tensor * Kcur;
3508
- struct ggml_tensor * Qcur;
3509
- Kcur = ggml_reshape_3d (ctx0, tmpk, n_embd_head, n_head_kv, N);
3510
- Qcur = ggml_reshape_3d (ctx0, tmpq, n_embd_head, n_head, N);
3511
-
3527
+ struct ggml_tensor * Kcur = ggml_reshape_3d (ctx0, tmpk, n_embd_head, n_head_kv, n_tokens);
3512
3528
offload_func_kq (Kcur);
3513
3529
ggml_set_name (Kcur, " Kcur" );
3514
3530
3531
+ struct ggml_tensor * Qcur = ggml_reshape_3d (ctx0, tmpq, n_embd_head, n_head, n_tokens);
3515
3532
offload_func_kq (Qcur);
3516
3533
ggml_set_name (Qcur, " Qcur" );
3517
3534
3518
3535
// store key and value to memory
3519
3536
{
3520
- // compute the transposed [N , n_embd] V matrix
3537
+ // compute the transposed [n_tokens , n_embd] V matrix
3521
3538
3522
3539
struct ggml_tensor * tmpv = ggml_mul_mat (ctx0, model.layers [il].wv , cur);
3523
3540
offload_func_v (tmpv);
3524
3541
ggml_set_name (tmpv, " tmpv" );
3525
3542
3526
- struct ggml_tensor * Vcur = ggml_transpose (ctx0, ggml_reshape_2d (ctx0, tmpv, n_embd_gqa, N ));
3543
+ struct ggml_tensor * Vcur = ggml_transpose (ctx0, ggml_reshape_2d (ctx0, tmpv, n_embd_gqa, n_tokens ));
3527
3544
offload_func_v (Vcur);
3528
3545
ggml_set_name (Vcur, " Vcur" );
3529
3546
3530
- struct ggml_tensor * k = ggml_view_1d (ctx0, kv_self.k , N *n_embd_gqa, (ggml_element_size (kv_self.k )*n_embd_gqa)*(il*n_ctx + n_past ));
3547
+ struct ggml_tensor * k = ggml_view_1d (ctx0, kv_self.k , n_tokens *n_embd_gqa, (ggml_element_size (kv_self.k )*n_embd_gqa)*(il*n_ctx + kv_head ));
3531
3548
offload_func_kq (k);
3532
3549
ggml_set_name (k, " k" );
3533
3550
3534
- struct ggml_tensor * v = ggml_view_2d (ctx0, kv_self.v , N , n_embd_gqa,
3551
+ struct ggml_tensor * v = ggml_view_2d (ctx0, kv_self.v , n_tokens , n_embd_gqa,
3535
3552
( n_ctx)*ggml_element_size (kv_self.v ),
3536
- (il*n_ctx)*ggml_element_size (kv_self.v )*n_embd_gqa + n_past *ggml_element_size (kv_self.v ));
3553
+ (il*n_ctx)*ggml_element_size (kv_self.v )*n_embd_gqa + kv_head *ggml_element_size (kv_self.v ));
3537
3554
offload_func_v (v);
3538
3555
ggml_set_name (v, " v" );
3539
3556
@@ -3547,7 +3564,7 @@ static struct ggml_cgraph * llm_build_refact(
3547
3564
3548
3565
struct ggml_tensor * K =
3549
3566
ggml_view_3d (ctx0, kv_self.k ,
3550
- n_embd_head, n_past + N , n_head_kv,
3567
+ n_embd_head, n_kv , n_head_kv,
3551
3568
ggml_element_size (kv_self.k )*n_embd_gqa,
3552
3569
ggml_element_size (kv_self.k )*n_embd_head,
3553
3570
ggml_element_size (kv_self.k )*n_embd_gqa*n_ctx*il);
@@ -3560,25 +3577,28 @@ static struct ggml_cgraph * llm_build_refact(
3560
3577
ggml_set_name (KQ, " KQ" );
3561
3578
3562
3579
// KQ_scaled = KQ / sqrt(n_embd_head)
3563
- // KQ_scaled shape [n_past + N, N , n_head, 1]
3564
- struct ggml_tensor * KQ_scaled = ggml_scale_inplace (ctx0, KQ, KQ_scale);
3580
+ // KQ_scaled shape [n_kv, n_tokens , n_head, 1]
3581
+ struct ggml_tensor * KQ_scaled = ggml_scale (ctx0, KQ, KQ_scale);
3565
3582
offload_func_kq (KQ_scaled);
3566
3583
ggml_set_name (KQ_scaled, " KQ_scaled" );
3567
3584
3568
- struct ggml_tensor * KQ_masked;
3569
- struct ggml_tensor * KQ_scaled_alibi;
3570
-
3571
- KQ_scaled_alibi =ggml_alibi (ctx0, KQ_scaled, n_past, n_head, 8 );
3585
+ // KQ_masked = mask_past(KQ_scaled)
3586
+ struct ggml_tensor * KQ_scaled_alibi = ggml_alibi (ctx0, KQ_scaled, /* n_past*/ 0 , n_head, 8 );
3572
3587
ggml_set_name (KQ_scaled_alibi, " KQ_scaled_alibi" );
3573
- KQ_masked = ggml_diag_mask_inf (ctx0, KQ_scaled_alibi, n_past);
3574
- struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace (ctx0, KQ_masked);
3588
+
3589
+ struct ggml_tensor * KQ_masked = ggml_add (ctx0, KQ_scaled_alibi, KQ_mask);
3590
+ offload_func_kq (KQ_masked);
3591
+ ggml_set_name (KQ_masked, " KQ_masked" );
3592
+
3593
+ // KQ = soft_max(KQ_masked)
3594
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max (ctx0, KQ_masked);
3575
3595
offload_func_v (KQ_soft_max);
3576
3596
ggml_set_name (KQ_soft_max, " KQ_soft_max" );
3577
3597
3578
3598
// split cached V into n_head heads
3579
3599
struct ggml_tensor * V =
3580
3600
ggml_view_3d (ctx0, kv_self.v ,
3581
- n_past + N , n_embd_head, n_head_kv,
3601
+ n_kv , n_embd_head, n_head_kv,
3582
3602
ggml_element_size (kv_self.v )*n_ctx,
3583
3603
ggml_element_size (kv_self.v )*n_ctx*n_embd_head,
3584
3604
ggml_element_size (kv_self.v )*n_ctx*n_embd_gqa*il);
@@ -3593,7 +3613,7 @@ static struct ggml_cgraph * llm_build_refact(
3593
3613
// make V contiguous in memory to speed up the matmul, however we waste time on the copy
3594
3614
// on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
3595
3615
// is there a better way?
3596
- struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N , n_embd_head, n_head));
3616
+ struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_ctx , n_embd_head, n_head));
3597
3617
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
3598
3618
#endif
3599
3619
@@ -3602,10 +3622,8 @@ static struct ggml_cgraph * llm_build_refact(
3602
3622
offload_func_v (KQV_merged);
3603
3623
ggml_set_name (KQV_merged, " KQV_merged" );
3604
3624
3605
- // cur = KQV_merged.contiguous().view(n_embd, N)
3606
- cur = ggml_cpy (ctx0,
3607
- KQV_merged,
3608
- ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_embd, N));
3625
+ // cur = KQV_merged.contiguous().view(n_embd, n_tokens)
3626
+ cur = ggml_cont_2d (ctx0, KQV_merged, n_embd, n_tokens);
3609
3627
offload_func_v (cur);
3610
3628
ggml_set_name (cur, " KQV_merged_contiguous" );
3611
3629
@@ -4338,7 +4356,7 @@ static struct ggml_cgraph * llama_build_graph(
4338
4356
} break ;
4339
4357
case LLM_ARCH_REFACT:
4340
4358
{
4341
- result = llm_build_refact (lctx, tokens, embd, n_tokens, n_past );
4359
+ result = llm_build_refact (lctx, batch );
4342
4360
} break ;
4343
4361
default :
4344
4362
GGML_ASSERT (false );
0 commit comments