
Commit 4dca833

[Perf] Optimize rotary_emb implementation to use a Triton operator for improved inference performance
Signed-off-by: cynthieye <[email protected]>
Co-authored-by: MagnetoWang <[email protected]>
Parent: 99ef59c

File tree: 1 file changed (+25 −13 lines)


vllm/model_executor/layers/rotary_embedding.py

Lines changed: 25 additions & 13 deletions
@@ -46,20 +46,12 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
     return x.flatten(-2)
 
 
-def _apply_rotary_emb(
+def _apply_rotary_emb_torch(
     x: torch.Tensor,
     cos: torch.Tensor,
     sin: torch.Tensor,
     is_neox_style: bool,
 ) -> torch.Tensor:
-    """
-    Args:
-        x: [num_tokens, num_heads, head_size]
-        cos: [num_tokens, head_size // 2]
-        sin: [num_tokens, head_size // 2]
-        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
-            positional embeddings.
-    """
     cos = cos.unsqueeze(-2).to(x.dtype)
     sin = sin.unsqueeze(-2).to(x.dtype)
     if is_neox_style:
@@ -75,6 +67,26 @@ def _apply_rotary_emb(
     return torch.stack((o1, o2), dim=-1).flatten(-2)
 
 
+def _apply_rotary_emb(x: torch.Tensor,
+                      cos: torch.Tensor,
+                      sin: torch.Tensor,
+                      is_neox_style: bool) -> torch.Tensor:
+    """
+    Args:
+        x: [num_tokens, num_heads, head_size]
+        cos: [num_tokens, head_size // 2]
+        sin: [num_tokens, head_size // 2]
+        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
+            positional embeddings.
+    """
+    if current_platform.is_cuda_alike():
+        from vllm_flash_attn.layers.rotary import apply_rotary_emb
+        return apply_rotary_emb(x.unsqueeze(0), cos, sin,
+                                not is_neox_style).squeeze(0)
+    else:
+        return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)
+
+
 @CustomOp.register("rotary_embedding")
 class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""
@@ -141,14 +153,14 @@ def forward_native(
         query = query.view(num_tokens, -1, self.head_size)
         query_rot = query[..., :self.rotary_dim]
         query_pass = query[..., self.rotary_dim:]
-        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
+        query_rot = _apply_rotary_emb_torch(query_rot, cos, sin, self.is_neox_style)
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
 
         key_shape = key.shape
         key = key.view(num_tokens, -1, self.head_size)
         key_rot = key[..., :self.rotary_dim]
         key_pass = key[..., self.rotary_dim:]
-        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
+        key_rot = _apply_rotary_emb_torch(key_rot, cos, sin, self.is_neox_style)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
@@ -309,9 +321,9 @@ def _apply_rotary_emb_neuron(
         key = key.view(num_tokens, -1, self.head_size)
 
         if self.rotary_dim == self.head_size:
-            query = _apply_rotary_emb(query, cos, sin, self.is_neox_style)
+            query = _apply_rotary_emb_torch(query, cos, sin, self.is_neox_style)
            query = query.reshape(query_shape)
-            key = _apply_rotary_emb(key, cos, sin, self.is_neox_style)
+            key = _apply_rotary_emb_torch(key, cos, sin, self.is_neox_style)
             key = key.reshape(key_shape)
         else:
             head_size = query.shape[-1]
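
For orientation, a minimal usage sketch of the new dispatcher (not part of the commit): it assumes a local vLLM checkout containing this change, from which the module-private helpers are importable, and uses the tensor shapes documented in the docstring above. On a host where current_platform.is_cuda_alike() is False, _apply_rotary_emb falls back to _apply_rotary_emb_torch, so the two calls below should agree.

# Hypothetical usage sketch, not part of the commit. Assumes the helpers
# are importable from a vLLM checkout that includes this change.
import torch

from vllm.model_executor.layers.rotary_embedding import (
    _apply_rotary_emb, _apply_rotary_emb_torch)

num_tokens, num_heads, head_size = 4, 8, 64

# Query/key slice to rotate: [num_tokens, num_heads, head_size]
x = torch.randn(num_tokens, num_heads, head_size)

# cos/sin caches cover half of the head size: [num_tokens, head_size // 2]
angles = torch.randn(num_tokens, head_size // 2)
cos, sin = angles.cos(), angles.sin()

# Without a CUDA-like device the dispatcher takes the PyTorch fallback,
# so it should match _apply_rotary_emb_torch exactly.
out_dispatch = _apply_rotary_emb(x, cos, sin, is_neox_style=True)
out_torch = _apply_rotary_emb_torch(x, cos, sin, is_neox_style=True)
torch.testing.assert_close(out_dispatch, out_torch)
print(out_dispatch.shape)  # torch.Size([4, 8, 64])

On the CUDA path, x.unsqueeze(0) adds the batch dimension the flash-attn Triton rotary kernel expects, and not is_neox_style is passed because that API's flag selects the interleaved (GPT-J-style) rotation, so NeoX-style inputs map to False.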
