
Commit 229f545

[Bugfix] Fix TritonPlaceholder conflicts with torch.compile
Signed-off-by: Mengqing Cao <[email protected]>
1 parent a44c4f1 commit 229f545

32 files changed: +302 / -158 lines
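Every hunk in this commit applies the same two-part pattern: the unconditional import triton (and import triton.language as tl) is moved behind the HAS_TRITON flag from vllm.triton_utils, and direct @triton.jit / @triton.heuristics decorators are replaced by the triton_jit_decorator / triton_heuristics_decorator shims, presumably so that torch.compile never has to interact with the TritonPlaceholder that vLLM substitutes when Triton is absent. The shim implementations are not part of the hunks shown below; the following is only a minimal sketch of how such fallback decorators could behave, assuming they should forward to Triton when it is installed and otherwise leave the function untouched.

# Illustrative sketch only -- not the actual vllm.triton_utils implementation.
# Assumption: forward to Triton when present, otherwise return the function
# unchanged so that module import (and torch.compile tracing) keeps working.
try:
    import triton
    HAS_TRITON = True
except ImportError:
    triton = None
    HAS_TRITON = False


def triton_jit_decorator(fn=None, **kwargs):
    # Supports both @triton_jit_decorator and @triton_jit_decorator(...).
    if HAS_TRITON:
        return triton.jit(fn, **kwargs) if fn is not None else triton.jit(**kwargs)
    return fn if fn is not None else (lambda f: f)


def triton_heuristics_decorator(values):
    # Mirrors @triton.heuristics({...}); a pass-through without Triton.
    if HAS_TRITON:
        return triton.heuristics(values)
    return lambda f: f

With shims of this shape, kernel modules keep their decorators at import time, and only the code that actually launches a kernel needs an explicit HAS_TRITON check.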

benchmarks/kernels/benchmark_moe.py

Lines changed: 6 additions & 1 deletion
@@ -10,7 +10,12 @@
 
 import ray
 import torch
-import triton
+
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    import triton
+
 from ray.experimental.tqdm_ray import tqdm
 from transformers import AutoConfig
 

benchmarks/kernels/benchmark_rmsnorm.py

Lines changed: 6 additions & 1 deletion
@@ -4,7 +4,12 @@
 from typing import Optional, Union
 
 import torch
-import triton
+
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    import triton
+
 from flashinfer.norm import fused_add_rmsnorm, rmsnorm
 from torch import nn
 

benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py

Lines changed: 6 additions & 1 deletion
@@ -6,7 +6,12 @@
 # Import DeepGEMM functions
 import deep_gemm
 import torch
-import triton
+
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    import triton
+
 from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor
 
 # Import vLLM functions
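The benchmark scripts above only guard the import itself; they still need Triton at run time. As a usage sketch (the main() wrapper below is hypothetical, not part of the diff), a script could bail out early when the flag is false:

# Hypothetical usage sketch: exit cleanly when Triton is unavailable.
import sys

from vllm.triton_utils import HAS_TRITON

if HAS_TRITON:
    import triton  # noqa: F401


def main() -> None:
    if not HAS_TRITON:
        print("Triton is not installed; skipping this benchmark.")
        sys.exit(0)
    # ... benchmark body that launches Triton kernels ...


if __name__ == "__main__":
    main()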

tests/kernels/attention/test_flashmla.py

Lines changed: 5 additions & 1 deletion
@@ -5,7 +5,11 @@
 
 import pytest
 import torch
-import triton
+
+from vllm.triton_utils.importing import HAS_TRITON
+
+if HAS_TRITON:
+    import triton
 
 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,

vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py

Lines changed: 11 additions & 5 deletions
@@ -1,8 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    import triton
+    import triton.language as tl
+
+from vllm.triton_utils import triton_heuristics_decorator, triton_jit_decorator
 
 
 def blocksparse_flash_attn_varlen_fwd(
@@ -122,7 +128,7 @@ def blocksparse_flash_attn_varlen_fwd(
     return out
 
 
-@triton.jit
+@triton_jit_decorator
 def _fwd_kernel_inner(
     acc,
     l_i,
@@ -227,11 +233,11 @@ def _fwd_kernel_inner(
     return acc, l_i, m_i
 
 
-@triton.heuristics({
+@triton_heuristics_decorator({
     "M_LT_N":
    lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"],
 })
-@triton.jit
+@triton_jit_decorator
 def _fwd_kernel_batch_inference(
     Q,
     K,

vllm/attention/ops/blocksparse_attention/utils.py

Lines changed: 5 additions & 1 deletion
@@ -8,7 +8,11 @@
 
 import numpy as np
 import torch
-import triton
+
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    import triton
 
 
 class csr_matrix:

vllm/attention/ops/chunked_prefill_paged_decode.py

Lines changed: 9 additions & 4 deletions
@@ -7,21 +7,26 @@
 # - Thomas Parnell <[email protected]>
 
 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    import triton
+    import triton.language as tl
 
 from vllm import _custom_ops as ops
 from vllm.platforms.rocm import use_rocm_custom_paged_attention
+from vllm.triton_utils import triton_jit_decorator
 
 from .prefix_prefill import context_attention_fwd
 
 
-@triton.jit
+@triton_jit_decorator
 def cdiv_fn(x, y):
     return (x + y - 1) // y
 
 
-@triton.jit
+@triton_jit_decorator
 def kernel_paged_attention_2d(
         output_ptr,  # [num_tokens, num_query_heads, head_size]
         query_ptr,  # [num_tokens, num_query_heads, head_size]

vllm/attention/ops/prefix_prefill.py

Lines changed: 10 additions & 5 deletions
@@ -4,10 +4,15 @@
 # https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
 
 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    import triton
+    import triton.language as tl
 
 from vllm.platforms import current_platform
+from vllm.triton_utils import triton_jit_decorator
 
 # Static kernels parameters
 BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
@@ -32,7 +37,7 @@
 # ],
 # key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"]
 # )
-@triton.jit
+@triton_jit_decorator
 def _fwd_kernel(Q,
                 K,
                 V,
@@ -280,7 +285,7 @@ def _fwd_kernel(Q,
     return
 
 
-@triton.jit
+@triton_jit_decorator
 def _fwd_kernel_flash_attn_v2(
     Q,
     K,
@@ -466,7 +471,7 @@ def _fwd_kernel_flash_attn_v2(
     return
 
 
-@triton.jit
+@triton_jit_decorator
 def _fwd_kernel_alibi(
     Q,
     K,

vllm/attention/ops/triton_decode_attention.py

Lines changed: 10 additions & 6 deletions
@@ -30,10 +30,14 @@
 
 import logging
 
-import triton
-import triton.language as tl
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    import triton
+    import triton.language as tl
 
 from vllm.platforms import current_platform
+from vllm.triton_utils import triton_jit_decorator
 
 is_hip_ = current_platform.is_rocm()
 
@@ -47,13 +51,13 @@
         "can be ignored.")
 
 
-@triton.jit
+@triton_jit_decorator
 def tanh(x):
     # Tanh is just a scaled sigmoid
     return 2 * tl.sigmoid(2 * x) - 1
 
 
-@triton.jit
+@triton_jit_decorator
 def _fwd_kernel_stage1(
     Q,
     K_Buffer,
@@ -229,7 +233,7 @@ def _decode_att_m_fwd(
     )
 
 
-@triton.jit
+@triton_jit_decorator
 def _fwd_grouped_kernel_stage1(
     Q,
     K_Buffer,
@@ -469,7 +473,7 @@ def _decode_grouped_att_m_fwd(
     )
 
 
-@triton.jit
+@triton_jit_decorator
 def _fwd_kernel_stage2(
     Mid_O,
     o,

vllm/attention/ops/triton_flash_attention.py

Lines changed: 14 additions & 9 deletions
@@ -25,11 +25,16 @@
 from typing import Optional
 
 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    import triton
+    import triton.language as tl
 
 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
+from vllm.triton_utils import triton_jit_decorator
 
 SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd']
 
@@ -234,19 +239,19 @@ def check_args(self, q, k, v, o):
         assert self.layout == 'thd' or not self.varlen
 
 
-@triton.jit
+@triton_jit_decorator
 def cdiv_fn(x, y):
     return (x + y - 1) // y
 
 
-@triton.jit
+@triton_jit_decorator
 def max_fn(x, y):
     return tl.math.max(x, y)
 
 
 # Convenience function to load with optional boundary checks.
 # "First" is the major dim, "second" is the minor dim.
-@triton.jit
+@triton_jit_decorator
 def masked_load(ptrs, offset_first, offset_second, boundary_first,
                 boundary_second):
     if offset_first is not None and offset_second is not None:
@@ -264,7 +269,7 @@ def masked_load(ptrs, offset_first, offset_second, boundary_first,
     return tensor
 
 
-@triton.jit
+@triton_jit_decorator
 def compute_alibi_block(alibi_slope,
                         seqlen_q,
                         seqlen_k,
@@ -318,14 +323,14 @@ def compute_alibi_tensor(alibi_slopes, seqlen_q, seqlen_k):
                -1) * relative_pos  # (Z, H, N_CTX_Q, N_CTX_K)
 
 
-@triton.jit
+@triton_jit_decorator
 def quant_fp8(x, scale):
     x *= scale
     x = tl.clamp(x, FP8_MIN, FP8_MAX)
     return x
 
 
-@triton.jit
+@triton_jit_decorator
 def _attn_fwd_inner(
     acc,
     l_i,
@@ -676,7 +681,7 @@ def get_autotune_configs():
     key=autotune_keys,
     use_cuda_graph=True,
 )
-@triton.jit
+@triton_jit_decorator
 def attn_fwd(
     Q,
     K,

vllm/attention/ops/triton_merge_attn_states.py

Lines changed: 9 additions & 3 deletions
@@ -2,8 +2,14 @@
 from typing import Optional
 
 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    import triton
+    import triton.language as tl
+
+from vllm.triton_utils import triton_jit_decorator
 
 
 # Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
@@ -35,7 +41,7 @@ def merge_attn_states(
     )
 
 
-@triton.jit
+@triton_jit_decorator
 def merge_attn_states_kernel(
     output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
     output_lse,  # [NUM_HEADS, NUM_TOKENS]

vllm/lora/ops/triton_ops/kernel_utils.py

Lines changed: 9 additions & 5 deletions
@@ -2,11 +2,15 @@
 """
 Utilities for Punica kernel construction.
 """
-import triton
-import triton.language as tl
+from vllm.triton_utils import HAS_TRITON
 
+if HAS_TRITON:
+    import triton.language as tl
 
-@triton.jit
+from vllm.triton_utils import triton_jit_decorator
+
+
+@triton_jit_decorator
 def mm_k(a_ptr, b_ptr, ak_stride, bk_stride, offset_k, K: tl.constexpr,
          BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
         EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, CAST_TYPE: tl.constexpr,
@@ -59,7 +63,7 @@ def mm_k(a_ptr, b_ptr, ak_stride, bk_stride, offset_k, K: tl.constexpr,
     return accumulator
 
 
-@triton.jit
+@triton_jit_decorator
 def do_expand_kernel(
     pid_n,
     lora_index,
@@ -161,7 +165,7 @@ def do_expand_kernel(
     tl.store(c_ptr, tiled_c, mask=c_mask)
 
 
-@triton.jit
+@triton_jit_decorator
 def do_shrink_kernel(
     pid_n,
     pid_sk,
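Note that the kernel signatures above still reference tl.constexpr, which Python evaluates when the module is imported, even when HAS_TRITON is false. So either these modules are only ever imported when Triton is present, or vllm.triton_utils has to resolve the name tl to something harmless in the Triton-less case. A minimal sketch of such a triton.language placeholder (names assumed, not taken from this diff) could look like:

# Assumed sketch of a triton.language placeholder -- not vLLM's actual code.
# It only needs to make annotations such as tl.constexpr resolvable when
# Triton is missing; kernels wrapped by the no-op decorator are never launched.
class _TritonLanguagePlaceholder:
    def __getattr__(self, name: str):
        # Any attribute (constexpr, dtype, ...) resolves to an inert marker.
        return object()


try:
    import triton.language as tl
except ImportError:
    tl = _TritonLanguagePlaceholder()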

vllm/lora/ops/triton_ops/lora_expand.py

Lines changed: 2 additions & 1 deletion
@@ -14,10 +14,11 @@
 
 from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
 from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
+from vllm.triton_utils import triton_jit_decorator
 from vllm.utils import direct_register_custom_op
 
 
-@triton.jit
+@triton_jit_decorator
 def _lora_expand_kernel(
     input_ptr,
     lora_ptr,
