
Commit 810fd2b

adobrzynwuisawesome authored and committed
[Hardware][Intel-Gaudi] Update hpu-extension and update bucketing system for HPU device (vllm-project#17186)
Signed-off-by: Agata Dobrzyniewicz <[email protected]>
1 parent 26cf9d6 commit 810fd2b

File tree

6 files changed: +128 -335 lines changed


requirements/hpu.txt

Lines changed: 1 addition & 1 deletion
@@ -9,4 +9,4 @@ numpy==1.26.4
 tabulate
 setuptools>=61
 setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@4312768
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@f1f6624

vllm/attention/backends/hpu_attn.py

Lines changed: 54 additions & 52 deletions
@@ -4,14 +4,14 @@
 # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
 ###############################################################################

-import os
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Type

 import torch
+import vllm_hpu_extension.kernels as kernels
 import vllm_hpu_extension.ops as ops
-from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
-                                      VLLMKVCache)
+from vllm_hpu_extension.flags import enabled_flags
+from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
@@ -126,7 +126,15 @@ def __init__(
         self.block2batch_matmul = Matmul()
         self.k_cache = VLLMKVCache()
         self.v_cache = VLLMKVCache()
-        ops.pa_impl = ops.pa
+        self.fused_scaled_dot_product_attention = kernels.fsdpa()
+
+        self.prefill_impl = 'naive'
+        if "flex_attention" in enabled_flags():
+            self.prefill_impl = 'flex'
+        if "fsdpa" in enabled_flags():
+            assert alibi_slopes is None, \
+                'Prefill with FusedSDPA not supported with alibi slopes!'
+            self.prefill_impl = 'fsdpa'

         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
         self.sliding_window = sliding_window
@@ -138,27 +146,18 @@ def __init__(
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads

-        self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
-                                              '0').lower() in ['1', 'true']
-        self.fused_scaled_dot_product_attention = None
-        if self.prefill_usefusedsdpa:
+        if self.prefill_impl == 'fsdpa':
             assert alibi_slopes is None, \
                 'Prefill with FusedSDPA not supported with alibi slopes!'
-            try:
-                from habana_frameworks.torch.hpex.kernels import FusedSDPA
-                self.fused_scaled_dot_product_attention = ModuleFusedSDPA(
-                    FusedSDPA)
-            except ImportError:
-                logger.warning("Could not import HPU FusedSDPA kernel. "
-                               "vLLM will use native implementation.")

         supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
         if head_size not in supported_head_sizes:
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {supported_head_sizes}.")

-        if attn_type != AttentionType.DECODER:
+        self.attn_type = attn_type
+        if self.attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
                                       "encoder/decoder cross-attention "
                                       "are not implemented for "
@@ -192,15 +191,18 @@ def forward(
         batch_size, seq_len, hidden_size = query.shape
         _, seq_len_kv, _ = key.shape

-        query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)
         block_indices = attn_metadata.block_indices
         block_offsets = attn_metadata.block_offsets
-        if attn_metadata.is_prompt:
+        key_cache = None
+        value_cache = None
+        if attn_metadata.is_prompt and self.attn_type \
+            is not AttentionType.ENCODER_ONLY \
+                and attn_metadata.block_list is None:
             key = key.unflatten(0, (block_indices.size(0), -1))
             value = value.unflatten(0, (block_indices.size(0), -1))
-        if kv_cache is not None:
+        if kv_cache is not None and isinstance(kv_cache, tuple):
             key_cache, value_cache = HPUPagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)
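
key_cache and value_cache now start as None, and the paged cache is only split when kv_cache is a real tuple, so call sites that pass a placeholder (for example during profiling or warm-up runs) simply skip the cache write. A small sketch of that guard, where split_fn stands in for HPUPagedAttention.split_kv_cache; the helper name and the profiling interpretation are assumptions:

# Sketch of the guard above; `split_fn` is a stand-in, not the real call.
from typing import Any, Callable, Optional, Tuple

import torch


def maybe_split_kv_cache(
    kv_cache: Any,
    num_kv_heads: int,
    head_size: int,
    split_fn: Callable[..., Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    # Only a real (key, value) tuple is split; anything else (None or a
    # placeholder object) leaves both caches as None so nothing is written.
    if kv_cache is not None and isinstance(kv_cache, tuple):
        return split_fn(kv_cache, num_kv_heads, head_size)
    return None, None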

@@ -214,36 +216,28 @@

         if attn_metadata.is_prompt:
             # Prompt run.
-            if not self.prefill_usefusedsdpa:
-                # TODO: move this outside of model
-                assert attn_metadata.attn_bias is not None, \
-                    'attn_bias must be set before calling model.forward!'
-                attn_bias = attn_metadata.attn_bias
-                if self.alibi_slopes is not None:
-                    position_bias = _make_alibi_bias(self.alibi_slopes,
-                                                     self.num_kv_heads,
-                                                     attn_bias.dtype,
-                                                     attn_bias.shape[-1])
-                    attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
-                    attn_bias.add_(position_bias)
-            else:
-                attn_bias = None
-
             query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
             kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
                         self.head_size)
+
+            attn_bias = attn_metadata.attn_bias
+            if attn_bias is not None and self.alibi_slopes is not None:
+                position_bias = _make_alibi_bias(self.alibi_slopes,
+                                                 self.num_kv_heads,
+                                                 attn_bias.dtype,
+                                                 attn_bias.shape[-1])
+                attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
+                attn_bias.add_(position_bias)
+
             out = ops.prompt_attention(
-                query.view(query_shape),
-                key.view(kv_shape),
-                value.view(kv_shape),
+                impl=self.prefill_impl,
+                query=query.view(query_shape),
+                key=key.view(kv_shape),
+                value=value.view(kv_shape),
+                is_causal=True,
                 attn_bias=attn_bias,
-                p=0.0,
-                scale=self.scale,
-                matmul_qk_op=self.matmul_qk,
-                softmax_op=self.softmax,
-                matmul_av_op=self.matmul_av,
-                fsdpa_op=self.fused_scaled_dot_product_attention,
-            )
+                valid_seq_lengths=attn_metadata.seq_lens_tensor,
+                **self.common_attention_args())
             output = out.reshape(batch_size, seq_len, hidden_size)
         else:
             # Decoding run.
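
With the per-implementation branches gone, the prompt path always reads attn_metadata.attn_bias and, when ALiBi slopes are configured, tiles the bias across the KV heads and adds the position bias in place. For intuition only, a toy version of the ALiBi idea; this does not reproduce the batched, per-KV-head layout of _make_alibi_bias:

# Toy illustration of ALiBi only; not the layout used by _make_alibi_bias.
import torch


def toy_alibi_bias(slopes: torch.Tensor, seq_len: int) -> torch.Tensor:
    # slopes: (num_heads,). Each head gets a linear penalty that grows with
    # the distance between query position i and key position j (j <= i).
    pos = torch.arange(seq_len)
    distance = pos[None, :] - pos[:, None]          # (i, j) -> j - i
    return slopes[:, None, None] * distance[None]   # (num_heads, seq_len, seq_len)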
@@ -254,18 +248,26 @@ def forward(
                 block_list=attn_metadata.block_list,
                 block_mapping=attn_metadata.block_mapping,
                 block_bias=attn_metadata.attn_bias,
-                block_scales=attn_metadata.block_scales,
                 block_groups=attn_metadata.block_groups,
-                scale=self.scale,
-                matmul_qk_op=self.matmul_qk,
-                matmul_av_op=self.matmul_av,
-                batch2block_matmul_op=self.batch2block_matmul,
-                block2batch_matmul_op=self.block2batch_matmul,
-                keys_fetch_func=self.k_cache.fetch_from_cache,
-                values_fetch_func=self.v_cache.fetch_from_cache)
+                **self.common_attention_args())
         # Reshape the output tensor.
         return output.view(batch_size, seq_len, hidden_size)

+    def common_attention_args(self):
+        fsdpa_op = self.fused_scaled_dot_product_attention.apply \
+            if self.fused_scaled_dot_product_attention is not None else None
+        return {
+            'scale': self.scale,
+            'matmul_qk_op': self.matmul_qk,
+            'matmul_av_op': self.matmul_av,
+            'batch2block_matmul_op': self.batch2block_matmul,
+            'block2batch_matmul_op': self.block2batch_matmul,
+            'fsdpa_op': fsdpa_op,
+            'keys_fetch_func': self.k_cache.fetch_from_cache,
+            'values_fetch_func': self.v_cache.fetch_from_cache,
+            'softmax_op': self.softmax,
+        }
+

 def _make_alibi_bias(
     alibi_slopes: torch.Tensor,
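
Both call sites, ops.prompt_attention for prompts and the paged decode kernel, now unpack the same common_attention_args() dict, so the shared operator handles are kept in one place. A self-contained sketch of the pattern with toy names (TinyAttention and its methods are illustrative, not vLLM code):

# Standalone sketch of the shared-kwargs pattern used above.
import torch


class TinyAttention:
    def __init__(self, scale: float):
        self.scale = scale

    def common_attention_args(self) -> dict:
        # One place that every call site pulls its shared kwargs from.
        return {'scale': self.scale}

    def _attn(self, q, k, v, *, scale):
        return (q @ k.transpose(-2, -1) * scale).softmax(dim=-1) @ v

    def prompt(self, q, k, v):
        return self._attn(q, k, v, **self.common_attention_args())

    def decode(self, q, k, v):
        return self._attn(q, k, v, **self.common_attention_args())
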

vllm/attention/ops/hpu_paged_attn.py

Lines changed: 0 additions & 1 deletion
@@ -22,7 +22,6 @@ class HPUPagedAttentionMetadata:
     block_usage: Optional[torch.Tensor]
     block_indices: Optional[torch.Tensor]
     block_offsets: Optional[torch.Tensor]
-    block_scales: Optional[torch.Tensor]
     block_groups: Optional[torch.Tensor]
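
block_scales is dropped from the metadata because the decode path no longer passes it (its remaining shared kwargs come from common_attention_args()). For reference, a cut-down sketch built only from the fields visible in this hunk; the real HPUPagedAttentionMetadata carries additional fields not shown here:

# Simplified reconstruction from this hunk only; not the full class.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class PagedAttentionMetadataSketch:
    block_usage: Optional[torch.Tensor]
    block_indices: Optional[torch.Tensor]
    block_offsets: Optional[torch.Tensor]
    block_groups: Optional[torch.Tensor]   # block_scales is gone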

vllm/model_executor/layers/layernorm.py

Lines changed: 2 additions & 1 deletion
@@ -168,7 +168,8 @@ def forward_hpu(
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        from vllm_hpu_extension.ops import HPUFusedRMSNorm
+        from vllm_hpu_extension.kernels import rms_norm
+        HPUFusedRMSNorm = rms_norm()
         if HPUFusedRMSNorm is None:
             return self.forward_native(x, residual)
         if residual is not None:
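
The fused RMSNorm kernel is now obtained from vllm_hpu_extension.kernels.rms_norm(), which, judging by the guard that follows, returns None when the fused kernel is unavailable so the native path takes over. A minimal sketch of that lookup-with-fallback shape; the native function below is a plain-PyTorch stand-in rather than vLLM's forward_native, and the fused call shape is an assumption:

# Minimal sketch of "fetch kernel, fall back to native" as in forward_hpu above.
import torch


def native_rms_norm(x: torch.Tensor, weight: torch.Tensor,
                    eps: float = 1e-6) -> torch.Tensor:
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight


def rms_norm_hpu_or_native(x, weight, eps=1e-6):
    try:
        from vllm_hpu_extension.kernels import rms_norm
        fused = rms_norm()       # per the diff, None means no fused kernel
    except ImportError:
        fused = None
    if fused is None:
        return native_rms_norm(x, weight, eps)
    # Call shape of the fused kernel is assumed here for illustration only.
    return fused.apply(x, weight, eps)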
