@@ -46,20 +46,12 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
     return x.flatten(-2)


-def _apply_rotary_emb(
+def _apply_rotary_emb_torch(
     x: torch.Tensor,
     cos: torch.Tensor,
     sin: torch.Tensor,
     is_neox_style: bool,
 ) -> torch.Tensor:
-    """
-    Args:
-        x: [num_tokens, num_heads, head_size]
-        cos: [num_tokens, head_size // 2]
-        sin: [num_tokens, head_size // 2]
-        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
-            positional embeddings.
-    """
     cos = cos.unsqueeze(-2).to(x.dtype)
     sin = sin.unsqueeze(-2).to(x.dtype)
     if is_neox_style:
@@ -75,6 +67,26 @@ def _apply_rotary_emb(
     return torch.stack((o1, o2), dim=-1).flatten(-2)


+def _apply_rotary_emb(x: torch.Tensor,
+                      cos: torch.Tensor,
+                      sin: torch.Tensor,
+                      is_neox_style: bool) -> torch.Tensor:
+    """
+    Args:
+        x: [num_tokens, num_heads, head_size]
+        cos: [num_tokens, head_size // 2]
+        sin: [num_tokens, head_size // 2]
+        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
+            positional embeddings.
+    """
+    if current_platform.is_cuda_alike():
+        from vllm_flash_attn.layers.rotary import apply_rotary_emb
+        return apply_rotary_emb(x.unsqueeze(0), cos, sin,
+                                not is_neox_style).squeeze(0)
+    else:
+        return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)
+
+
 @CustomOp.register("rotary_embedding")
 class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""
@@ -141,14 +153,14 @@ def forward_native(
         query = query.view(num_tokens, -1, self.head_size)
         query_rot = query[..., :self.rotary_dim]
         query_pass = query[..., self.rotary_dim:]
-        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
+        query_rot = _apply_rotary_emb_torch(query_rot, cos, sin, self.is_neox_style)
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

         key_shape = key.shape
         key = key.view(num_tokens, -1, self.head_size)
         key_rot = key[..., :self.rotary_dim]
         key_pass = key[..., self.rotary_dim:]
-        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
+        key_rot = _apply_rotary_emb_torch(key_rot, cos, sin, self.is_neox_style)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key

@@ -309,9 +321,9 @@ def _apply_rotary_emb_neuron(
         key = key.view(num_tokens, -1, self.head_size)

         if self.rotary_dim == self.head_size:
-            query = _apply_rotary_emb(query, cos, sin, self.is_neox_style)
+            query = _apply_rotary_emb_torch(query, cos, sin, self.is_neox_style)
             query = query.reshape(query_shape)
-            key = _apply_rotary_emb(key, cos, sin, self.is_neox_style)
+            key = _apply_rotary_emb_torch(key, cos, sin, self.is_neox_style)
             key = key.reshape(key_shape)
         else:
             head_size = query.shape[-1]
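These last two hunks pin the existing pure-PyTorch call sites (`forward_native` and the Neuron helper) to `_apply_rotary_emb_torch`, so only the generic `_apply_rotary_emb` wrapper picks up the CUDA fast path. Below is a small smoke-test sketch of the fallback using the `_apply_rotary_emb_torch_sketch` helper from the earlier block; the tensor shapes and the base-10000 frequency schedule are made up for illustration (vLLM builds its own cos/sin cache internally).

```python
import torch

# Hypothetical shapes; head_size must be even for RoPE.
num_tokens, num_heads, head_size = 8, 4, 64
x = torch.randn(num_tokens, num_heads, head_size)

# Standard RoPE frequency schedule (base 10000), purely illustrative.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_size, 2).float() / head_size))
angles = torch.arange(num_tokens).float()[:, None] * inv_freq[None, :]
cos, sin = angles.cos(), angles.sin()  # each [num_tokens, head_size // 2]

out = _apply_rotary_emb_torch_sketch(x, cos, sin, is_neox_style=True)
assert out.shape == x.shape

# On a CUDA-alike platform, _apply_rotary_emb would instead call
# vllm_flash_attn.layers.rotary.apply_rotary_emb(x.unsqueeze(0), cos, sin, False)
# and squeeze the batch dimension back out; the two paths should agree up to
# kernel-level rounding.
```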