
Commit bc4bba3

agray3 and slaren authored
Introduction of CUDA Graphs to LLama.cpp (#6766)
* DRAFT: Introduction of CUDA Graphs to LLama.cpp
* Fix issues raised in comments
* Tidied to now only use the CUDA runtime (not mixed with driver calls)
* Disable for multi-GPU and batch size > 1
* Disable CUDA graphs for old GPU arch and with env var
* Added missing CUDA_CHECKs
* Addressed comments
* Further addressed comments
* Limit to GGML_ALLOW_CUDA_GRAPHS defined in llama.cpp cmake
* Added more comprehensive graph node checking
* With mechanism to fall back if graph capture fails
* Revert "With mechanism to fall back if graph capture fails"
  This reverts commit eb9f15f.
* Fall back if graph capture fails and address other comments
* Further changes:
  - Renamed GGML_ALLOW_CUDA_GRAPHS to GGML_CUDA_USE_GRAPHS
  - Renamed the env variable that disables CUDA graphs to GGML_CUDA_DISABLE_GRAPHS
  - Updated the Makefile build to enable CUDA graphs
  - Removed graph capture failure checking in ggml_cuda_error: using a global variable to track this is not thread safe, and I am also not satisfied with checking an error by its string; if this is necessary to work around issues with graph capture (e.g. with cuBLAS), we can pass the ggml_backend_cuda_context to the error checking macro and store the result in the context
  - Fixed several resource leaks
  - Fixed issue with zero-node graphs
  - Changed fixed size arrays to vectors
  - Removed the count of evaluations before starting capture, and instead changed the capture mode to relaxed
  - Removed the check for multiple devices so that a single device can still be used; instead check for split buffers to disable CUDA graphs with -sm row
  - Changed the op used for checking batch size to GGML_OP_ADD, which should be more reliable than GGML_OP_SOFT_MAX
  - Code style fixes
  - Things to look into: VRAM usage of the cudaGraphExec_t (if it is significant we may need to make it optional), and the possibility of using cudaStreamBeginCaptureToGraph to keep track of which ggml graph nodes correspond to which CUDA graph nodes
* Fix build without CUDA graphs
* Remove outdated comment
* Replace minimum cc value with a constant

---------

Co-authored-by: slaren <[email protected]>
1 parent c12452c commit bc4bba3
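The commit message above describes capturing each token's CUDA kernel launches into a CUDA graph and replaying it on later tokens; the actual integration lives in ggml-cuda.cu, whose diff is not rendered below. As context only, here is a minimal, hypothetical sketch of the capture/instantiate/update/launch cycle the change is built around (launch_all_kernels and the stand-in CUDA_CHECK are illustrative, not part of this diff):

// Minimal sketch (CUDA 12 runtime API) of the capture/replay cycle, assuming a
// stream `stream` and a hypothetical launch_all_kernels() that enqueues the same
// kernel sequence the backend would normally launch for one token.
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

#ifndef CUDA_CHECK  // stand-in for the project's error-checking macro
#define CUDA_CHECK(call) do { cudaError_t err_ = (call); if (err_ != cudaSuccess) { \
    fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err_)); exit(1); } } while (0)
#endif

void launch_all_kernels(cudaStream_t stream);  // hypothetical

static cudaGraphExec_t instance = nullptr;     // cached executable graph (cf. ggml_cuda_graph::instance)

void eval_one_token(cudaStream_t stream) {
    // Record the token's kernel launches into a graph instead of executing them.
    // Relaxed capture mode (adopted in this commit) tolerates unrelated work,
    // e.g. from cuBLAS, occurring while capture is active.
    cudaGraph_t graph = nullptr;
    CUDA_CHECK(cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed));
    launch_all_kernels(stream);
    CUDA_CHECK(cudaStreamEndCapture(stream, &graph));

    if (instance == nullptr) {
        // First token: build the executable graph.
        CUDA_CHECK(cudaGraphInstantiate(&instance, graph, 0));
    } else {
        // Later tokens: usually only kernel parameters (pointers) change, so try to
        // update the executable graph in place and re-instantiate only if that fails.
        cudaGraphExecUpdateResultInfo info;
        if (cudaGraphExecUpdate(instance, graph, &info) != cudaSuccess) {
            (void) cudaGetLastError();  // clear the error state left by the failed update
            CUDA_CHECK(cudaGraphExecDestroy(instance));
            CUDA_CHECK(cudaGraphInstantiate(&instance, graph, 0));
        }
    }
    CUDA_CHECK(cudaGraphDestroy(graph));

    // Replay the whole token's work with a single launch.
    CUDA_CHECK(cudaGraphLaunch(instance, stream));
}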

File tree

11 files changed: +372 −44 lines


CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -405,6 +405,7 @@ if (LLAMA_CUDA)
     list(APPEND GGML_SOURCES_CUDA "ggml-cuda.cu")

     add_compile_definitions(GGML_USE_CUDA)
+    add_compile_definitions(GGML_CUDA_USE_GRAPHS)
     if (LLAMA_CUDA_FORCE_DMMV)
         add_compile_definitions(GGML_CUDA_FORCE_DMMV)
     endif()

Makefile

Lines changed: 1 addition & 1 deletion

@@ -433,7 +433,7 @@ ifdef LLAMA_CUDA
 else
 	CUDA_PATH ?= /usr/local/cuda
 endif
-	MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
+	MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include -DGGML_CUDA_USE_GRAPHS
 	MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib
 	OBJS += ggml-cuda.o
 	OBJS += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
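Both build systems now define GGML_CUDA_USE_GRAPHS; per the commit message, graphs can also be disabled at runtime with the GGML_CUDA_DISABLE_GRAPHS environment variable. A hypothetical sketch of how such a check typically looks (the real check is in the un-rendered ggml-cuda.cu):

// Hypothetical: CUDA graphs are usable only when compiled in, and can be
// switched off at runtime via the GGML_CUDA_DISABLE_GRAPHS env variable.
#include <cstdlib>

static bool cuda_graphs_enabled() {
#ifdef GGML_CUDA_USE_GRAPHS
    return std::getenv("GGML_CUDA_DISABLE_GRAPHS") == nullptr;
#else
    return false;
#endif
}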

ggml-cuda.cu

Lines changed: 286 additions & 14 deletions
Large diffs are not rendered by default.

ggml-cuda/clamp.cu

Lines changed: 0 additions & 1 deletion

@@ -31,5 +31,4 @@ void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&max, (float *) dst->op_params + 1, sizeof(float));

     clamp_f32_cuda(src0_d, dst_d, min, max, ggml_nelements(src0), stream);
-    CUDA_CHECK(cudaGetLastError());
 }

ggml-cuda/common.cuh

Lines changed: 40 additions & 0 deletions

@@ -19,6 +19,7 @@
 #include <cassert>
 #include <cfloat>
 #include <string>
+#include <vector>

 #if defined(GGML_USE_HIPBLAS)
 #include <hip/hip_runtime.h>
@@ -526,6 +527,43 @@ struct ggml_tensor_extra_gpu {
     cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs
 };

+
+#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)
+#define USE_CUDA_GRAPH
+#endif
+
+struct ggml_graph_node_properties {
+    void * node_address;
+    ggml_op node_op;
+    int64_t ne[GGML_MAX_DIMS];
+    size_t nb[GGML_MAX_DIMS];
+    void * src_address[GGML_MAX_SRC];
+};
+
+struct ggml_cuda_graph {
+#ifdef USE_CUDA_GRAPH
+    ~ggml_cuda_graph() {
+        if (instance != nullptr) {
+            CUDA_CHECK(cudaGraphExecDestroy(instance));
+        }
+        if (graph != nullptr) {
+            CUDA_CHECK(cudaGraphDestroy(graph));
+        }
+    }
+    cudaGraph_t graph = nullptr;
+    cudaGraphExec_t instance = nullptr;
+    size_t num_nodes = 0;
+    std::vector<cudaGraphNode_t> nodes;
+    std::vector<cudaKernelNodeParams> params;
+    bool disable_due_to_gpu_arch = false;
+    bool disable_due_to_too_many_updates = false;
+    bool disable_due_to_failed_graph_capture = false;
+    int number_consecutive_updates = 0;
+    std::vector<ggml_graph_node_properties> ggml_graph_properties;
+    std::vector<char **> updated_kernel_arg;
+#endif
+};
+
 struct ggml_backend_cuda_context {
     int device;
     std::string name;
@@ -534,6 +572,8 @@ struct ggml_backend_cuda_context {
     cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
     cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};

+    std::unique_ptr<ggml_cuda_graph> cuda_graph;
+
     explicit ggml_backend_cuda_context(int device) :
         device(device),
         name(GGML_CUDA_NAME + std::to_string(device)) {
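The new ggml_graph_node_properties entries record, per ggml graph node, the data pointer, op, shape/strides and source pointers. A plausible use (hypothetical here; the consuming logic is in the un-rendered ggml-cuda.cu) is to snapshot these after capture and compare them on the next evaluation, so a topology change forces re-capture while a mere pointer change only needs a graph update. A sketch, assuming the ggml headers are available:

// Hypothetical sketch: record a node's properties and later check whether a
// previously captured CUDA graph still matches the current ggml graph node.
static void set_ggml_graph_node_properties(const ggml_tensor * node, ggml_graph_node_properties * props) {
    props->node_address = node->data;
    props->node_op      = node->op;
    for (int i = 0; i < GGML_MAX_DIMS; i++) {
        props->ne[i] = node->ne[i];
        props->nb[i] = node->nb[i];
    }
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
    }
}

static bool ggml_graph_node_properties_match(const ggml_tensor * node, const ggml_graph_node_properties * props) {
    if (node->data != props->node_address || node->op != props->node_op) {
        return false;  // different tensor or different op -> graph layout changed
    }
    for (int i = 0; i < GGML_MAX_DIMS; i++) {
        if (node->ne[i] != props->ne[i] || node->nb[i] != props->nb[i]) {
            return false;  // shape or stride changed
        }
    }
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        void * src = node->src[i] ? node->src[i]->data : nullptr;
        if (src != props->src_address[i]) {
            return false;  // a source buffer moved
        }
    }
    return true;
}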

ggml-cuda/convert.cu

Lines changed: 1 addition & 3 deletions

@@ -727,7 +727,6 @@ static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict_
 }

 to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
-    int id;
     switch (type) {
         case GGML_TYPE_Q4_0:
             return dequantize_row_q4_0_cuda;
@@ -738,8 +737,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
         case GGML_TYPE_Q5_1:
             return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
        case GGML_TYPE_Q8_0:
-            CUDA_CHECK(cudaGetDevice(&id));
-            if (ggml_cuda_info().devices[id].cc >= CC_PASCAL) {
+            if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= CC_PASCAL) {
                 return dequantize_block_q8_0_f16_cuda;
             }
             return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
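Here, and in the mmq.cu and mmvq.cu hunks below, the repeated `int id; CUDA_CHECK(cudaGetDevice(&id));` pattern is replaced by a call to ggml_cuda_get_device(). That helper is defined in the un-rendered ggml-cuda.cu; presumably it is a thin checked wrapper along these lines (a sketch, not the actual definition):

// Assumed shape of the helper used by the refactored call sites.
int ggml_cuda_get_device() {
    int id;
    CUDA_CHECK(cudaGetDevice(&id));
    return id;
}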

ggml-cuda/cpy.cu

Lines changed: 29 additions & 0 deletions

@@ -459,3 +459,32 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     ggml_cuda_cpy(ctx, src0, dst);
 }
+
+void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
+    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_f32_f16<cpy_1_f32_f32>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
+        return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
+        return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+        return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_f32_f16<cpy_1_f16_f32>;
+    } else {
+        fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
+                ggml_type_name(src0->type), ggml_type_name(src1->type));
+        GGML_ASSERT(false);
+    }
+}
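ggml_cuda_cpy_fn exposes, for a given src/dst type pair, the same kernel pointer that ggml_cuda_cpy would launch. A plausible reason (the consuming code is in the un-rendered ggml-cuda.cu) is that during CUDA graph reuse the backend needs to recognise which captured kernel nodes are copy kernels so their pointer arguments can be refreshed without re-capture. A hypothetical sketch of that kind of lookup, using the ggml_cuda_graph struct from common.cuh above:

// Hypothetical sketch: walk the captured graph's kernel nodes and flag those whose
// function matches the copy kernel for a given tensor pair; their parameters could
// then be patched via cudaGraphExecKernelNodeSetParams on the executable graph.
#include <vector>
#include <cuda_runtime.h>

void mark_cpy_nodes(ggml_cuda_graph & g, const ggml_tensor * src, ggml_tensor * dst) {
    void * cpy_kernel = ggml_cuda_cpy_fn(src, dst);

    size_t num_nodes = 0;
    CUDA_CHECK(cudaGraphGetNodes(g.graph, nullptr, &num_nodes));
    std::vector<cudaGraphNode_t> nodes(num_nodes);
    CUDA_CHECK(cudaGraphGetNodes(g.graph, nodes.data(), &num_nodes));

    for (size_t i = 0; i < num_nodes; i++) {
        cudaGraphNodeType type;
        CUDA_CHECK(cudaGraphNodeGetType(nodes[i], &type));
        if (type != cudaGraphNodeTypeKernel) {
            continue;
        }
        cudaKernelNodeParams params = {};
        if (cudaGraphKernelNodeGetParams(nodes[i], &params) != cudaSuccess) {
            (void) cudaGetLastError();  // e.g. nodes captured from library kernels
            continue;
        }
        if (params.func == cpy_kernel) {
            // params.kernelParams points at this launch's argument buffer;
            // the destination pointer argument would be updated here.
        }
    }
}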

ggml-cuda/cpy.cuh

Lines changed: 2 additions & 0 deletions

@@ -5,3 +5,5 @@
 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);

 void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);

ggml-cuda/mmq.cu

Lines changed: 10 additions & 20 deletions

@@ -1735,8 +1735,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;

     int mmq_x, mmq_y, nwarps;
@@ -1780,8 +1779,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;

     int mmq_x, mmq_y, nwarps;
@@ -1825,8 +1823,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;

     int mmq_x, mmq_y, nwarps;
@@ -1870,8 +1867,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;

     int mmq_x, mmq_y, nwarps;
@@ -1915,8 +1911,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;

     int mmq_x, mmq_y, nwarps;
@@ -1960,8 +1955,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;

     int mmq_x, mmq_y, nwarps;
@@ -2007,8 +2001,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(

 #if QK_K == 256

-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;

     int mmq_x, mmq_y, nwarps;
@@ -2053,8 +2046,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;

     int mmq_x, mmq_y, nwarps;
@@ -2098,8 +2090,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;

     int mmq_x, mmq_y, nwarps;
@@ -2143,8 +2134,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;

     int mmq_x, mmq_y, nwarps;

ggml-cuda/mmvq.cu

Lines changed: 2 additions & 4 deletions

@@ -89,8 +89,7 @@ static void mul_mat_vec_q_cuda(
     GGML_ASSERT(ncols_x % qk == 0);
     GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);

-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();

     int64_t nwarps = 1;
     int64_t rows_per_cuda_block = 1;
@@ -328,8 +327,7 @@ void ggml_cuda_op_mul_mat_vec_q(

     const int64_t ne0 = dst->ne[0];

-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();

     // the main device has a larger memory buffer to hold the results from all GPUs
     // nrows_dst == nrows of the matrix that the kernel writes into

ggml-cuda/scale.cu

Lines changed: 0 additions & 1 deletion

@@ -28,5 +28,4 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&scale, dst->op_params, sizeof(float));

     scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
-    CUDA_CHECK(cudaGetLastError());
 }
