@@ -174,6 +174,7 @@ struct cmd_params {
174
174
std::vector<llama_split_mode> split_mode;
175
175
std::vector<int > main_gpu;
176
176
std::vector<bool > no_kv_offload;
177
+ std::vector<bool > flash_attn;
177
178
std::vector<std::vector<float >> tensor_split;
178
179
std::vector<bool > use_mmap;
179
180
std::vector<bool > embeddings;
@@ -195,6 +196,7 @@ static const cmd_params cmd_params_defaults = {
195
196
/* split_mode */ {LLAMA_SPLIT_MODE_LAYER},
196
197
/* main_gpu */ {0 },
197
198
/* no_kv_offload */ {false },
199
+ /* flash_attn */ {false },
198
200
/* tensor_split */ {std::vector<float >(llama_max_devices (), 0.0f )},
199
201
/* use_mmap */ {true },
200
202
/* embeddings */ {false },
@@ -220,6 +222,7 @@ static void print_usage(int /* argc */, char ** argv) {
220
222
printf (" -sm, --split-mode <none|layer|row> (default: %s)\n " , join (transform_to_str (cmd_params_defaults.split_mode , split_mode_str), " ," ).c_str ());
221
223
printf (" -mg, --main-gpu <i> (default: %s)\n " , join (cmd_params_defaults.main_gpu , " ," ).c_str ());
222
224
printf (" -nkvo, --no-kv-offload <0|1> (default: %s)\n " , join (cmd_params_defaults.no_kv_offload , " ," ).c_str ());
225
+ printf (" -fa, --flash-attn <0|1> (default: %s)\n " , join (cmd_params_defaults.flash_attn , " ," ).c_str ());
223
226
printf (" -mmp, --mmap <0|1> (default: %s)\n " , join (cmd_params_defaults.use_mmap , " ," ).c_str ());
224
227
printf (" -embd, --embeddings <0|1> (default: %s)\n " , join (cmd_params_defaults.embeddings , " ," ).c_str ());
225
228
printf (" -ts, --tensor-split <ts0/ts1/..> (default: 0)\n " );
@@ -393,6 +396,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
393
396
}
394
397
auto p = split<bool >(argv[i], split_delim);
395
398
params.no_kv_offload .insert (params.no_kv_offload .end (), p.begin (), p.end ());
399
+ } else if (arg == " -fa" || arg == " --flash-attn" ) {
400
+ if (++i >= argc) {
401
+ invalid_param = true ;
402
+ break ;
403
+ }
404
+ auto p = split<bool >(argv[i], split_delim);
405
+ params.flash_attn .insert (params.flash_attn .end (), p.begin (), p.end ());
396
406
} else if (arg == " -mmp" || arg == " --mmap" ) {
397
407
if (++i >= argc) {
398
408
invalid_param = true ;
@@ -477,6 +487,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
477
487
if (params.split_mode .empty ()) { params.split_mode = cmd_params_defaults.split_mode ; }
478
488
if (params.main_gpu .empty ()) { params.main_gpu = cmd_params_defaults.main_gpu ; }
479
489
if (params.no_kv_offload .empty ()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload ; }
490
+ if (params.flash_attn .empty ()) { params.flash_attn = cmd_params_defaults.flash_attn ; }
480
491
if (params.tensor_split .empty ()) { params.tensor_split = cmd_params_defaults.tensor_split ; }
481
492
if (params.use_mmap .empty ()) { params.use_mmap = cmd_params_defaults.use_mmap ; }
482
493
if (params.embeddings .empty ()) { params.embeddings = cmd_params_defaults.embeddings ; }
@@ -498,6 +509,7 @@ struct cmd_params_instance {
498
509
llama_split_mode split_mode;
499
510
int main_gpu;
500
511
bool no_kv_offload;
512
+ bool flash_attn;
501
513
std::vector<float > tensor_split;
502
514
bool use_mmap;
503
515
bool embeddings;
@@ -532,6 +544,7 @@ struct cmd_params_instance {
532
544
cparams.type_k = type_k;
533
545
cparams.type_v = type_v;
534
546
cparams.offload_kqv = !no_kv_offload;
547
+ cparams.flash_attn = flash_attn;
535
548
cparams.embeddings = embeddings;
536
549
537
550
return cparams;
@@ -554,6 +567,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
554
567
for (const auto & tk : params.type_k )
555
568
for (const auto & tv : params.type_v )
556
569
for (const auto & nkvo : params.no_kv_offload )
570
+ for (const auto & fa : params.flash_attn )
557
571
for (const auto & nt : params.n_threads ) {
558
572
for (const auto & n_prompt : params.n_prompt ) {
559
573
if (n_prompt == 0 ) {
@@ -572,6 +586,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
572
586
/* .split_mode = */ sm,
573
587
/* .main_gpu = */ mg,
574
588
/* .no_kv_offload= */ nkvo,
589
+ /* .flash_attn = */ fa,
575
590
/* .tensor_split = */ ts,
576
591
/* .use_mmap = */ mmp,
577
592
/* .embeddings = */ embd,
@@ -596,6 +611,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
596
611
/* .split_mode = */ sm,
597
612
/* .main_gpu = */ mg,
598
613
/* .no_kv_offload= */ nkvo,
614
+ /* .flash_attn = */ fa,
599
615
/* .tensor_split = */ ts,
600
616
/* .use_mmap = */ mmp,
601
617
/* .embeddings = */ embd,
@@ -633,6 +649,7 @@ struct test {
633
649
llama_split_mode split_mode;
634
650
int main_gpu;
635
651
bool no_kv_offload;
652
+ bool flash_attn;
636
653
std::vector<float > tensor_split;
637
654
bool use_mmap;
638
655
bool embeddings;
@@ -657,6 +674,7 @@ struct test {
657
674
split_mode = inst.split_mode ;
658
675
main_gpu = inst.main_gpu ;
659
676
no_kv_offload = inst.no_kv_offload ;
677
+ flash_attn = inst.flash_attn ;
660
678
tensor_split = inst.tensor_split ;
661
679
use_mmap = inst.use_mmap ;
662
680
embeddings = inst.embeddings ;
@@ -731,7 +749,7 @@ struct test {
731
749
" n_batch" , " n_ubatch" ,
732
750
" n_threads" , " type_k" , " type_v" ,
733
751
" n_gpu_layers" , " split_mode" ,
734
- " main_gpu" , " no_kv_offload" ,
752
+ " main_gpu" , " no_kv_offload" , " flash_attn" ,
735
753
" tensor_split" , " use_mmap" , " embeddings" ,
736
754
" n_prompt" , " n_gen" , " test_time" ,
737
755
" avg_ns" , " stddev_ns" ,
@@ -753,7 +771,7 @@ struct test {
753
771
}
754
772
if (field == " cuda" || field == " opencl" || field == " vulkan" || field == " kompute" || field == " metal" ||
755
773
field == " gpu_blas" || field == " blas" || field == " sycl" ||field == " f16_kv" || field == " no_kv_offload" ||
756
- field == " use_mmap" || field == " embeddings" ) {
774
+ field == " flash_attn" || field == " use_mmap" || field == " embeddings" ) {
757
775
return BOOL;
758
776
}
759
777
if (field == " avg_ts" || field == " stddev_ts" ) {
@@ -787,7 +805,7 @@ struct test {
787
805
std::to_string (n_batch), std::to_string (n_ubatch),
788
806
std::to_string (n_threads), ggml_type_name (type_k), ggml_type_name (type_v),
789
807
std::to_string (n_gpu_layers), split_mode_str (split_mode),
790
- std::to_string (main_gpu), std::to_string (no_kv_offload),
808
+ std::to_string (main_gpu), std::to_string (no_kv_offload), std::to_string (flash_attn),
791
809
tensor_split_str, std::to_string (use_mmap), std::to_string (embeddings),
792
810
std::to_string (n_prompt), std::to_string (n_gen), test_time,
793
811
std::to_string (avg_ns ()), std::to_string (stdev_ns ()),
@@ -955,6 +973,9 @@ struct markdown_printer : public printer {
955
973
if (field == " no_kv_offload" ) {
956
974
return " nkvo" ;
957
975
}
976
+ if (field == " flash_attn" ) {
977
+ return " fa" ;
978
+ }
958
979
if (field == " use_mmap" ) {
959
980
return " mmap" ;
960
981
}
@@ -1001,6 +1022,9 @@ struct markdown_printer : public printer {
1001
1022
if (params.no_kv_offload .size () > 1 || params.no_kv_offload != cmd_params_defaults.no_kv_offload ) {
1002
1023
fields.emplace_back (" no_kv_offload" );
1003
1024
}
1025
+ if (params.flash_attn .size () > 1 || params.flash_attn != cmd_params_defaults.flash_attn ) {
1026
+ fields.emplace_back (" flash_attn" );
1027
+ }
1004
1028
if (params.tensor_split .size () > 1 || params.tensor_split != cmd_params_defaults.tensor_split ) {
1005
1029
fields.emplace_back (" tensor_split" );
1006
1030
}
0 commit comments