 # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
 ###############################################################################
 
-import os
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Type
 
 import torch
+import vllm_hpu_extension.kernels as kernels
 import vllm_hpu_extension.ops as ops
-from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
-                                      VLLMKVCache)
+from vllm_hpu_extension.flags import enabled_flags
+from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
@@ -126,7 +126,15 @@ def __init__(
         self.block2batch_matmul = Matmul()
         self.k_cache = VLLMKVCache()
         self.v_cache = VLLMKVCache()
-        ops.pa_impl = ops.pa
+        self.fused_scaled_dot_product_attention = kernels.fsdpa()
+
+        self.prefill_impl = 'naive'
+        if "flex_attention" in enabled_flags():
+            self.prefill_impl = 'flex'
+        if "fsdpa" in enabled_flags():
+            assert alibi_slopes is None, \
+                'Prefill with FusedSDPA not supported with alibi slopes!'
+            self.prefill_impl = 'fsdpa'
 
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
         self.sliding_window = sliding_window
@@ -138,27 +146,18 @@ def __init__(
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
-                                              '0').lower() in ['1', 'true']
-        self.fused_scaled_dot_product_attention = None
-        if self.prefill_usefusedsdpa:
+        if self.prefill_impl == 'fsdpa':
             assert alibi_slopes is None, \
                 'Prefill with FusedSDPA not supported with alibi slopes!'
-            try:
-                from habana_frameworks.torch.hpex.kernels import FusedSDPA
-                self.fused_scaled_dot_product_attention = ModuleFusedSDPA(
-                    FusedSDPA)
-            except ImportError:
-                logger.warning("Could not import HPU FusedSDPA kernel. "
-                               "vLLM will use native implementation.")
 
         supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
         if head_size not in supported_head_sizes:
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {supported_head_sizes}.")
 
-        if attn_type != AttentionType.DECODER:
+        self.attn_type = attn_type
+        if self.attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
                                       "encoder/decoder cross-attention "
                                       "are not implemented for "
@@ -192,15 +191,18 @@ def forward(
         batch_size, seq_len, hidden_size = query.shape
         _, seq_len_kv, _ = key.shape
 
-        query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)
         block_indices = attn_metadata.block_indices
         block_offsets = attn_metadata.block_offsets
-        if attn_metadata.is_prompt:
+        key_cache = None
+        value_cache = None
+        if attn_metadata.is_prompt and self.attn_type \
+            is not AttentionType.ENCODER_ONLY \
+            and attn_metadata.block_list is None:
             key = key.unflatten(0, (block_indices.size(0), -1))
             value = value.unflatten(0, (block_indices.size(0), -1))
-        if kv_cache is not None:
+        if kv_cache is not None and isinstance(kv_cache, tuple):
             key_cache, value_cache = HPUPagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)
 
@@ -214,36 +216,28 @@ def forward(
 
         if attn_metadata.is_prompt:
             # Prompt run.
-            if not self.prefill_usefusedsdpa:
-                # TODO: move this outside of model
-                assert attn_metadata.attn_bias is not None, \
-                    'attn_bias must be set before calling model.forward!'
-                attn_bias = attn_metadata.attn_bias
-                if self.alibi_slopes is not None:
-                    position_bias = _make_alibi_bias(self.alibi_slopes,
-                                                     self.num_kv_heads,
-                                                     attn_bias.dtype,
-                                                     attn_bias.shape[-1])
-                    attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
-                    attn_bias.add_(position_bias)
-            else:
-                attn_bias = None
-
             query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
             kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
                         self.head_size)
+
+            attn_bias = attn_metadata.attn_bias
+            if attn_bias is not None and self.alibi_slopes is not None:
+                position_bias = _make_alibi_bias(self.alibi_slopes,
+                                                 self.num_kv_heads,
+                                                 attn_bias.dtype,
+                                                 attn_bias.shape[-1])
+                attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
+                attn_bias.add_(position_bias)
+
             out = ops.prompt_attention(
-                query.view(query_shape),
-                key.view(kv_shape),
-                value.view(kv_shape),
+                impl=self.prefill_impl,
+                query=query.view(query_shape),
+                key=key.view(kv_shape),
+                value=value.view(kv_shape),
+                is_causal=True,
                 attn_bias=attn_bias,
-                p=0.0,
-                scale=self.scale,
-                matmul_qk_op=self.matmul_qk,
-                softmax_op=self.softmax,
-                matmul_av_op=self.matmul_av,
-                fsdpa_op=self.fused_scaled_dot_product_attention,
-            )
+                valid_seq_lengths=attn_metadata.seq_lens_tensor,
+                **self.common_attention_args())
             output = out.reshape(batch_size, seq_len, hidden_size)
         else:
             # Decoding run.
@@ -254,18 +248,26 @@ def forward(
                 block_list=attn_metadata.block_list,
                 block_mapping=attn_metadata.block_mapping,
                 block_bias=attn_metadata.attn_bias,
-                block_scales=attn_metadata.block_scales,
                 block_groups=attn_metadata.block_groups,
-                scale=self.scale,
-                matmul_qk_op=self.matmul_qk,
-                matmul_av_op=self.matmul_av,
-                batch2block_matmul_op=self.batch2block_matmul,
-                block2batch_matmul_op=self.block2batch_matmul,
-                keys_fetch_func=self.k_cache.fetch_from_cache,
-                values_fetch_func=self.v_cache.fetch_from_cache)
+                **self.common_attention_args())
         # Reshape the output tensor.
         return output.view(batch_size, seq_len, hidden_size)
 
+    def common_attention_args(self):
+        fsdpa_op = self.fused_scaled_dot_product_attention.apply \
+            if self.fused_scaled_dot_product_attention is not None else None
+        return {
+            'scale': self.scale,
+            'matmul_qk_op': self.matmul_qk,
+            'matmul_av_op': self.matmul_av,
+            'batch2block_matmul_op': self.batch2block_matmul,
+            'block2batch_matmul_op': self.block2batch_matmul,
+            'fsdpa_op': fsdpa_op,
+            'keys_fetch_func': self.k_cache.fetch_from_cache,
+            'values_fetch_func': self.v_cache.fetch_from_cache,
+            'softmax_op': self.softmax,
+        }
+
 
 def _make_alibi_bias(
     alibi_slopes: torch.Tensor,
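Reviewer note: the diff replaces the old VLLM_PROMPT_USE_FUSEDSDPA environment check with a prefill-implementation choice driven by enabled_flags() from vllm_hpu_extension.flags. The standalone sketch below only illustrates that selection pattern; the enabled_flags() function here is a hard-coded stand-in for demonstration, not the real extension API, and select_prefill_impl is a hypothetical helper.

# Minimal sketch of the prefill-implementation dispatch introduced above.
# Assumption: enabled_flags() is stubbed; the real vllm_hpu_extension.flags
# derives the flag set from platform/kernel capabilities.
def enabled_flags():
    # Hard-coded purely for illustration.
    return {"fsdpa"}

def select_prefill_impl(alibi_slopes=None):
    prefill_impl = 'naive'        # default prompt-attention path
    if "flex_attention" in enabled_flags():
        prefill_impl = 'flex'     # use flex attention when flagged
    if "fsdpa" in enabled_flags():
        # FusedSDPA takes precedence, but the diff asserts it cannot be
        # combined with ALiBi slopes.
        assert alibi_slopes is None, \
            'Prefill with FusedSDPA not supported with alibi slopes!'
        prefill_impl = 'fsdpa'
    return prefill_impl

print(select_prefill_impl())  # -> 'fsdpa' with the stub above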