
Commit 9e96f56

Allocate kv_cache with stride order (#16605)
Signed-off-by: shuw <[email protected]>
1 parent b278911 commit 9e96f56

6 files changed: +119 -50 lines

csrc/cache_kernels.cu

Lines changed: 21 additions & 18 deletions
@@ -270,9 +270,10 @@ __global__ void reshape_and_cache_flash_kernel(
     cache_t* __restrict__ value_cache,  // [num_blocks, block_size, num_heads,
                                         //  head_size]
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
-    const int block_stride, const int key_stride, const int value_stride,
-    const int num_heads, const int head_size, const int block_size,
-    const float* k_scale, const float* v_scale) {
+    const int64_t block_stride, const int64_t page_stride,
+    const int64_t head_stride, const int64_t key_stride,
+    const int64_t value_stride, const int num_heads, const int head_size,
+    const int block_size, const float* k_scale, const float* v_scale) {
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   // NOTE: slot_idx can be -1 if the token is padded
@@ -288,8 +289,8 @@ __global__ void reshape_and_cache_flash_kernel(
     const int head_idx = i / head_size;
     const int head_offset = i % head_size;
     const int64_t tgt_key_value_idx = block_idx * block_stride +
-                                      block_offset * num_heads * head_size +
-                                      head_idx * head_size + head_offset;
+                                      block_offset * page_stride +
+                                      head_idx * head_stride + head_offset;
     scalar_t tgt_key = key[src_key_idx];
     scalar_t tgt_value = value[src_value_idx];
     if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
@@ -396,16 +397,16 @@ void reshape_and_cache(
 // KV_T is the data type of key and value tensors.
 // CACHE_T is the stored data type of kv-cache.
 // KV_DTYPE is the real data type of kv-cache.
-#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE)        \
-  vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE>      \
-      <<<grid, block, 0, stream>>>(                                  \
-          reinterpret_cast<KV_T*>(key.data_ptr()),                   \
-          reinterpret_cast<KV_T*>(value.data_ptr()),                 \
-          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),          \
-          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),        \
-          slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,  \
-          value_stride, num_heads, head_size, block_size,              \
-          reinterpret_cast<const float*>(k_scale.data_ptr()),          \
+#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE)          \
+  vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE>        \
+      <<<grid, block, 0, stream>>>(                                    \
+          reinterpret_cast<KV_T*>(key.data_ptr()),                     \
+          reinterpret_cast<KV_T*>(value.data_ptr()),                   \
+          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),            \
+          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),          \
+          slot_mapping.data_ptr<int64_t>(), block_stride, page_stride,   \
+          head_stride, key_stride, value_stride, num_heads, head_size,   \
+          block_size, reinterpret_cast<const float*>(k_scale.data_ptr()), \
           reinterpret_cast<const float*>(v_scale.data_ptr()));
@@ -432,9 +433,11 @@ void reshape_and_cache_flash(
   int head_size = key.size(2);
   int block_size = key_cache.size(1);
 
-  int key_stride = key.stride(0);
-  int value_stride = value.stride(0);
-  int block_stride = key_cache.stride(0);
+  int64_t key_stride = key.stride(0);
+  int64_t value_stride = value.stride(0);
+  int64_t block_stride = key_cache.stride(0);
+  int64_t page_stride = key_cache.stride(1);
+  int64_t head_stride = key_cache.stride(2);
   TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
 
   dim3 grid(num_tokens);
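
Note: the kernel change above derives the write offset purely from the cache's strides, so one formula covers both the NHD and HND layouts. A minimal PyTorch sketch of that invariant (illustrative sizes, not vLLM code):

import torch

num_blocks, block_size, num_heads, head_size = 4, 16, 8, 64

def make_cache(layout: str) -> torch.Tensor:
    # Logical shape is always [num_blocks, block_size, num_heads, head_size];
    # HND only changes the physical allocation order.
    if layout == "NHD":
        return torch.randn(num_blocks, block_size, num_heads, head_size)
    return torch.randn(num_blocks, num_heads, block_size,
                       head_size).permute(0, 2, 1, 3)

for layout in ("NHD", "HND"):
    cache = make_cache(layout)
    block_stride, page_stride, head_stride, _ = cache.stride()
    b, p, h, d = 1, 3, 5, 7
    # Same arithmetic as tgt_key_value_idx in reshape_and_cache_flash_kernel.
    flat_idx = b * block_stride + p * page_stride + h * head_stride + d
    flat_view = torch.as_strided(cache, (cache.numel(),), (1,))
    assert flat_view[flat_idx] == cache[b, p, h, d]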

tests/kernels/attention/test_cache.py

Lines changed: 41 additions & 19 deletions
@@ -16,6 +16,7 @@
 NUM_HEADS = [8]  # Arbitrary values for testing
 HEAD_SIZES = [64, 80, 120, 256]
 BLOCK_SIZES = [8, 16, 32]
+CACHE_LAYOUTS = ["NHD", "HND"]
 
 # Parameters for MLA tests.
 KV_LORA_RANKS = [512]
@@ -220,6 +221,7 @@ def test_reshape_and_cache(
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
+@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
 @torch.inference_mode()
 def test_reshape_and_cache_flash(
     kv_cache_factory_flashinfer,
@@ -232,17 +234,21 @@ def test_reshape_and_cache_flash(
     seed: int,
     device: str,
     kv_cache_dtype: str,
+    kv_cache_layout: str,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
 
+    # fp8 conversion requires a contiguous memory buffer. Reduce the number of
+    # blocks and tokens to consume less memory.
+    num_tokens = num_tokens // 2
+    num_blocks = num_blocks // 2
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping_lst = random.sample(range(num_slots), num_tokens)
     slot_mapping = torch.tensor(slot_mapping_lst,
                                 dtype=torch.long,
                                 device=device)
-
     qkv = torch.randn(num_tokens,
                       3,
                       num_heads,
@@ -261,44 +267,56 @@ def test_reshape_and_cache_flash(
         kv_cache_dtype,
         dtype,
         device=device,
+        cache_layout=kv_cache_layout,
     )
-    key_cache, value_cache = key_caches[0].contiguous(
-    ), value_caches[0].contiguous()
+    key_cache, value_cache = key_caches[0], value_caches[0]
     del key_caches
     del value_caches
 
     k_scale = (key.amax() / 64.0).to(torch.float32)
     v_scale = (value.amax() / 64.0).to(torch.float32)
 
+    def permute_and_compact(x):
+        y = x if kv_cache_layout == "NHD" else x.permute(0, 2, 1, 3)
+        return y.contiguous()
+
+    key_cache_compact = permute_and_compact(key_cache)
+    value_cache_compact = permute_and_compact(value_cache)
+
     # Clone the KV caches.
     if kv_cache_dtype == "fp8":
-        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item(),
-                        kv_cache_dtype)
-        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
-        ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item(),
+        cloned_key_cache = torch.empty_like(key_cache_compact,
+                                            dtype=torch.float16)
+        ops.convert_fp8(cloned_key_cache, key_cache_compact, k_scale.item(),
                         kv_cache_dtype)
+        cloned_value_cache = torch.empty_like(value_cache_compact,
+                                              dtype=torch.float16)
+        ops.convert_fp8(cloned_value_cache, value_cache_compact,
+                        v_scale.item(), kv_cache_dtype)
     else:
-        cloned_key_cache = key_cache.clone()
-        cloned_value_cache = value_cache.clone()
-
+        cloned_key_cache = key_cache_compact.clone()
+        cloned_value_cache = value_cache_compact.clone()
     # Call the reshape_and_cache kernel.
     opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
             (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
              k_scale, v_scale),
            cond=(head_size == HEAD_SIZES[0]))
    ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                slot_mapping, kv_cache_dtype, k_scale, v_scale)
+    key_cache_compact = permute_and_compact(key_cache)
+    value_cache_compact = permute_and_compact(value_cache)
 
     if kv_cache_dtype == "fp8":
-        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
+        result_key_cache = torch.empty_like(key_cache_compact,
+                                            dtype=torch.float16)
         ops.convert_fp8(result_key_cache,
-                        key_cache,
+                        key_cache_compact,
                         k_scale.item(),
                         kv_dtype=kv_cache_dtype)
-        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
+        result_value_cache = torch.empty_like(value_cache_compact,
+                                              dtype=torch.float16)
        ops.convert_fp8(result_value_cache,
-                        value_cache,
+                        value_cache_compact,
                        v_scale.item(),
                        kv_dtype=kv_cache_dtype)
 
@@ -310,8 +328,12 @@ def test_reshape_and_cache_flash(
     for i in range(num_tokens):
         block_idx = block_indicies_lst[i]
         block_offset = block_offsets_lst[i]
-        cloned_key_cache[block_idx, block_offset, :, :] = key[i]
-        cloned_value_cache[block_idx, block_offset, :, :] = value[i]
+        if kv_cache_layout == "NHD":
+            cloned_key_cache[block_idx, block_offset, :, :] = key[i]
+            cloned_value_cache[block_idx, block_offset, :, :] = value[i]
+        else:
+            cloned_key_cache[block_idx, :, block_offset, :] = key[i]
+            cloned_value_cache[block_idx, :, block_offset, :] = value[i]
 
     if kv_cache_dtype == "fp8":
         torch.testing.assert_close(result_key_cache,
@@ -323,8 +345,8 @@ def test_reshape_and_cache_flash(
                                    atol=0.001,
                                    rtol=0.1)
     else:
-        torch.testing.assert_close(key_cache, cloned_key_cache)
-        torch.testing.assert_close(value_cache, cloned_value_cache)
+        torch.testing.assert_close(key_cache_compact, cloned_key_cache)
+        torch.testing.assert_close(value_cache_compact, cloned_value_cache)
 
 
 @pytest.mark.parametrize("direction", COPYING_DIRECTION)

vllm/attention/backends/abstract.py

Lines changed: 4 additions & 0 deletions
@@ -77,6 +77,10 @@ def get_kv_cache_shape(
     ) -> Tuple[int, ...]:
         raise NotImplementedError
 
+    @staticmethod
+    def get_kv_cache_stride_order() -> Tuple[int, ...]:
+        raise NotImplementedError
+
     @staticmethod
     @abstractmethod
     def swap_blocks(

vllm/attention/backends/flashinfer.py

Lines changed: 25 additions & 5 deletions
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import dataclasses
+import os
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
@@ -48,6 +49,9 @@
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                           ModelInputForGPUWithSamplingMetadata)
 
+FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT",
+                                            "NHD").upper()
+
 
 class FlashInferBackend(AttentionBackend):
 
@@ -80,6 +84,14 @@ def get_kv_cache_shape(
     ) -> Tuple[int, ...]:
         return (num_blocks, 2, block_size, num_kv_heads, head_size)
 
+    @staticmethod
+    def get_kv_cache_stride_order() -> Tuple[int, ...]:
+        cache_layout = FLASHINFER_KV_CACHE_LAYOUT
+        assert (cache_layout in ("NHD", "HND"))
+        stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3,
+                                                                      2, 4)
+        return stride_order
+
     @staticmethod
     def swap_blocks(
         src_kv_cache: torch.Tensor,
@@ -188,6 +200,7 @@ def __init__(self, runner):
         self.global_hyperparameters: Optional[PerLayerParameters] = None
 
         self.vllm_config = self.runner.vllm_config
+        self._kv_cache_layout = None
 
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
@@ -197,10 +210,15 @@ def _get_workspace_buffer(self):
                 device=self.runner.device)
         return self._workspace_buffer
 
+    def get_kv_cache_layout(self):
+        if self._kv_cache_layout is None:
+            self._kv_cache_layout = FLASHINFER_KV_CACHE_LAYOUT
+        return self._kv_cache_layout
+
     def _get_prefill_wrapper(self):
         if self._prefill_wrapper is None:
             self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                self._get_workspace_buffer(), "NHD")
+                self._get_workspace_buffer(), self.get_kv_cache_layout())
         return self._prefill_wrapper
 
     def _get_decode_wrapper(self):
@@ -213,7 +231,7 @@ def _get_decode_wrapper(self):
                                 num_qo_heads // num_kv_heads > 4)
             self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 self._get_workspace_buffer(),
-                "NHD",
+                self.get_kv_cache_layout(),
                 use_tensor_cores=use_tensor_cores)
         return self._decode_wrapper
 
@@ -274,7 +292,8 @@ def graph_capture_get_metadata_for_batch(
             self._graph_decode_wrapper = \
                 CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
                 self._graph_decode_workspace_buffer, _indptr_buffer,
-                self._graph_indices_buffer, _last_page_len_buffer, "NHD",
+                self._graph_indices_buffer, _last_page_len_buffer,
+                self.get_kv_cache_layout(),
                 use_tensor_cores)
         if self.runner.kv_cache_dtype.startswith("fp8"):
             kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
@@ -1005,6 +1024,7 @@ def forward(
 
         prefill_output: Optional[torch.Tensor] = None
         decode_output: Optional[torch.Tensor] = None
+        stride_order = FlashInferBackend.get_kv_cache_stride_order()
         if prefill_meta := attn_metadata.prefill_metadata:
             # We will use flash attention for prefill
             # when kv_cache is not provided.
@@ -1036,7 +1056,7 @@ def forward(
 
             prefill_output = prefill_meta.prefill_wrapper.run(
                 query,
-                kv_cache,
+                kv_cache.permute(*stride_order),
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
             )
@@ -1051,7 +1071,7 @@ def forward(
 
             decode_output = decode_meta.decode_wrapper.run(
                 decode_query,
-                kv_cache,
+                kv_cache.permute(*stride_order),
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
             )
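
Note: the layout is chosen once at import time via the FLASHINFER_KV_CACHE_LAYOUT environment variable (default "NHD"). Because the HND stride order (0, 1, 3, 2, 4) simply swaps axes 2 and 3, permuting by it twice is the identity, which is why forward() can hand kv_cache.permute(*stride_order) to the wrappers and recover a contiguous HND tensor. A small sketch under those assumptions (illustrative sizes, not part of the diff):

import torch

num_blocks, block_size, num_kv_heads, head_size = 2, 16, 4, 64
logical_shape = (num_blocks, 2, block_size, num_kv_heads, head_size)
stride_order = (0, 1, 3, 2, 4)  # HND order from get_kv_cache_stride_order()

# Allocate heads-major, then expose the generic logical shape via permute,
# mirroring what the cache engine does at allocation time.
alloc_shape = tuple(logical_shape[i] for i in stride_order)
kv_cache = torch.zeros(alloc_shape).permute(*stride_order)
assert tuple(kv_cache.shape) == logical_shape  # logical shape is unchanged
assert not kv_cache.is_contiguous()            # physical layout is HND

# What forward() passes to the FlashInfer wrappers built with "HND":
assert kv_cache.permute(*stride_order).is_contiguous()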

vllm/utils.py

Lines changed: 10 additions & 3 deletions
@@ -765,21 +765,28 @@ def create_kv_caches_with_random_flash(
     model_dtype: Optional[Union[str, torch.dtype]] = None,
     seed: Optional[int] = None,
     device: Optional[str] = "cuda",
+    cache_layout: Optional[str] = "NHD",
 ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
     from vllm.platforms import current_platform
     current_platform.seed_everything(seed)
 
     torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
-    key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
+    generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
+    assert cache_layout in ("NHD", "HND")
+    stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2,
+                                                                  4)
+
+    kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i]
+                                      for i in stride_order)
     scale = head_size**-0.5
 
     key_caches: list[torch.Tensor] = []
     value_caches: list[torch.Tensor] = []
 
     for _ in range(num_layers):
-        key_value_cache = torch.empty(size=key_value_cache_shape,
+        key_value_cache = torch.empty(size=kv_cache_allocation_shape,
                                       dtype=torch_dtype,
-                                      device=device)
+                                      device=device).permute(*stride_order)
         if cache_dtype in ["auto", "half", "bfloat16", "float"]:
             key_value_cache.uniform_(-scale, scale)
         elif cache_dtype == 'fp8':

vllm/worker/cache_engine.py

Lines changed: 18 additions & 5 deletions
@@ -71,19 +71,32 @@ def _allocate_kv_cache(
         device: str,
     ) -> List[torch.Tensor]:
         """Allocates KV cache on the specified device."""
-        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
+        kv_cache_generic_shape = self.attn_backend.get_kv_cache_shape(
             num_blocks, self.block_size, self.num_kv_heads, self.head_size)
         pin_memory = is_pin_memory_available() if device == "cpu" else False
         kv_cache: List[torch.Tensor] = []
+        try:
+            kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order(
+            )
+        except (AttributeError, NotImplementedError):
+            kv_cache_stride_order = tuple(range(len(kv_cache_generic_shape)))
+
+        # The allocation respects the backend-defined stride order to ensure
+        # the semantics remain consistent for each backend. We first obtain the
+        # generic kv cache shape and then permute it according to the stride
+        # order, which could result in a non-contiguous tensor.
+        kv_cache_allocation_shape = tuple(kv_cache_generic_shape[i]
+                                          for i in kv_cache_stride_order)
 
         for _ in range(self.num_attention_layers):
             # null block in CpuGpuBlockAllocator requires at least that
             # block to be zeroed-out.
             # We zero-out everything for simplicity.
-            layer_kv_cache = torch.zeros(kv_cache_shape,
-                                         dtype=self.dtype,
-                                         pin_memory=pin_memory,
-                                         device=device)
+            layer_kv_cache = torch.zeros(
+                kv_cache_allocation_shape,
+                dtype=self.dtype,
+                pin_memory=pin_memory,
+                device=device).permute(*kv_cache_stride_order)
 
             # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
             # when entry_shape is higher than 1D
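
Note: to make the "non-contiguous tensor" mentioned in the comment concrete, here is what the stride order changes for the FlashInfer cache shape (num_blocks, 2, block_size, num_kv_heads, head_size). A quick check with illustrative sizes (not part of the diff):

import torch

shape = (2, 2, 16, 4, 64)   # (num_blocks, 2, block_size, num_kv_heads, head_size)
hnd_order = (0, 1, 3, 2, 4)

nhd = torch.zeros(shape)
hnd = torch.zeros(tuple(shape[i] for i in hnd_order)).permute(*hnd_order)

assert tuple(hnd.shape) == shape  # same logical shape, different memory order
print(nhd.stride())  # (8192, 4096, 256, 64, 1): heads interleaved per token
print(hnd.stride())  # (8192, 4096, 64, 1024, 1): a head's block of tokens is contiguous

With HND, the block-offset (page) stride drops to head_size, so each head's block_size x head_size slice sits in one contiguous chunk, which is the layout FlashInfer's HND kernels expect.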
