
Commit f9bc5a0

[Bugfix] Fix triton import with local TritonPlaceholder (#17446)
Signed-off-by: Mengqing Cao <[email protected]>
1 parent 05e1f96 commit f9bc5a0

30 files changed: +171 -81 lines
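
The diff below swaps every direct "import triton" / "import triton.language as tl" for an import from vllm.triton_utils, so these modules keep importing on systems where Triton is not installed. The placeholder objects themselves live in vllm/triton_utils/importing.py, which is not shown in this view; the following is a minimal sketch of the idea, inferred from the new tests in tests/test_triton_utils.py further down, and may differ from the actual implementation.

# Sketch only: inferred from tests/test_triton_utils.py in this commit;
# the real vllm/triton_utils/importing.py may differ in details.
import types


class TritonLanguagePlaceholder(types.ModuleType):
    """Stand-in for triton.language when Triton is not installed."""

    def __init__(self):
        super().__init__("triton.language")
        # Attributes commonly used in kernel signatures and annotations.
        self.constexpr = None
        self.dtype = None
        self.int64 = None


class TritonPlaceholder(types.ModuleType):
    """Stand-in for the top-level triton module."""

    def __init__(self):
        super().__init__("triton")
        self.language = TritonLanguagePlaceholder()
        # Kernel decorators become no-ops so that module-level
        # @triton.jit / @triton.autotune / @triton.heuristics usage,
        # bare or parameterized, still imports cleanly.
        self.jit = self._no_op_decorator
        self.autotune = self._no_op_decorator
        self.heuristics = self._no_op_decorator

    @staticmethod
    def _no_op_decorator(*args, **kwargs):
        if args and callable(args[0]):
            # Bare form: @triton.jit applied directly to the function.
            return args[0]
        # Parameterized form: @triton.jit(...) returns a pass-through.
        return lambda f: f


try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    # Triton is unavailable: expose placeholders so that importing kernel
    # modules (and decorating their kernels) still works.
    triton = TritonPlaceholder()
    tl = TritonLanguagePlaceholder()
    HAS_TRITON = False

With this in place the placeholder decorators never execute a kernel body; they only keep import and decoration from failing when Triton is absent.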

benchmarks/kernels/benchmark_moe.py

Lines changed: 1 addition & 1 deletion
@@ -10,12 +10,12 @@

 import ray
 import torch
-import triton
 from ray.experimental.tqdm_ray import tqdm
 from transformers import AutoConfig

 from vllm.model_executor.layers.fused_moe.fused_moe import *
 from vllm.platforms import current_platform
+from vllm.triton_utils import triton
 from vllm.utils import FlexibleArgumentParser

 FP8_DTYPE = current_platform.fp8_dtype()

benchmarks/kernels/benchmark_rmsnorm.py

Lines changed: 1 addition & 1 deletion
@@ -4,11 +4,11 @@
 from typing import Optional, Union

 import torch
-import triton
 from flashinfer.norm import fused_add_rmsnorm, rmsnorm
 from torch import nn

 from vllm import _custom_ops as vllm_ops
+from vllm.triton_utils import triton


 class HuggingFaceRMSNorm(nn.Module):

benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py

Lines changed: 1 addition & 1 deletion
@@ -6,13 +6,13 @@
 # Import DeepGEMM functions
 import deep_gemm
 import torch
-import triton
 from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor

 # Import vLLM functions
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8, w8a8_block_fp8_matmul)
+from vllm.triton_utils import triton


 # Copied from

tests/kernels/attention/test_flashmla.py

Lines changed: 1 addition & 1 deletion
@@ -5,11 +5,11 @@

 import pytest
 import torch
-import triton

 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
+from vllm.triton_utils import triton


 def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:

tests/test_triton_utils.py

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import sys
+import types
+from unittest import mock
+
+from vllm.triton_utils.importing import (TritonLanguagePlaceholder,
+                                         TritonPlaceholder)
+
+
+def test_triton_placeholder_is_module():
+    triton = TritonPlaceholder()
+    assert isinstance(triton, types.ModuleType)
+    assert triton.__name__ == "triton"
+
+
+def test_triton_language_placeholder_is_module():
+    triton_language = TritonLanguagePlaceholder()
+    assert isinstance(triton_language, types.ModuleType)
+    assert triton_language.__name__ == "triton.language"
+
+
+def test_triton_placeholder_decorators():
+    triton = TritonPlaceholder()
+
+    @triton.jit
+    def foo(x):
+        return x
+
+    @triton.autotune
+    def bar(x):
+        return x
+
+    @triton.heuristics
+    def baz(x):
+        return x
+
+    assert foo(1) == 1
+    assert bar(2) == 2
+    assert baz(3) == 3
+
+
+def test_triton_placeholder_decorators_with_args():
+    triton = TritonPlaceholder()
+
+    @triton.jit(debug=True)
+    def foo(x):
+        return x
+
+    @triton.autotune(configs=[], key="x")
+    def bar(x):
+        return x
+
+    @triton.heuristics(
+        {"BLOCK_SIZE": lambda args: 128 if args["x"] > 1024 else 64})
+    def baz(x):
+        return x
+
+    assert foo(1) == 1
+    assert bar(2) == 2
+    assert baz(3) == 3
+
+
+def test_triton_placeholder_language():
+    lang = TritonLanguagePlaceholder()
+    assert isinstance(lang, types.ModuleType)
+    assert lang.__name__ == "triton.language"
+    assert lang.constexpr is None
+    assert lang.dtype is None
+    assert lang.int64 is None
+
+
+def test_triton_placeholder_language_from_parent():
+    triton = TritonPlaceholder()
+    lang = triton.language
+    assert isinstance(lang, TritonLanguagePlaceholder)
+
+
+def test_no_triton_fallback():
+    # clear existing triton modules
+    sys.modules.pop("triton", None)
+    sys.modules.pop("triton.language", None)
+    sys.modules.pop("vllm.triton_utils", None)
+    sys.modules.pop("vllm.triton_utils.importing", None)
+
+    # mock triton not being installed
+    with mock.patch.dict(sys.modules, {"triton": None}):
+        from vllm.triton_utils import HAS_TRITON, tl, triton
+        assert HAS_TRITON is False
+        assert triton.__class__.__name__ == "TritonPlaceholder"
+        assert triton.language.__class__.__name__ == "TritonLanguagePlaceholder"
+        assert tl.__class__.__name__ == "TritonLanguagePlaceholder"
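
The rest of the diff applies the same mechanical change to every kernel module: drop the direct Triton imports and pull triton / tl from vllm.triton_utils instead. Because decoration happens at import time, the no-op decorators and the None-valued tl attributes sketched above are what keep such modules loadable without Triton. A tiny, purely illustrative example (the kernel name here is hypothetical, not taken from the diff):

# Hypothetical kernel, for illustration only: with real Triton this is
# JIT-compiled as usual; with the placeholder, @triton.jit is a no-op and
# the tl.constexpr annotation simply evaluates to None at import time.
from vllm.triton_utils import tl, triton


@triton.jit
def _noop_kernel(x_ptr, BLOCK_SIZE: tl.constexpr):
    pass

Modules that additionally need real-Triton-only behavior keep guarding it behind HAS_TRITON, as in the mamba_ssm.py hunk further down.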

vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 def blocksparse_flash_attn_varlen_fwd(

vllm/attention/ops/blocksparse_attention/utils.py

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,8 @@

 import numpy as np
 import torch
-import triton
+
+from vllm.triton_utils import triton


 class csr_matrix:

vllm/attention/ops/chunked_prefill_paged_decode.py

Lines changed: 1 addition & 2 deletions
@@ -7,11 +7,10 @@
 # - Thomas Parnell <[email protected]>

 import torch
-import triton
-import triton.language as tl

 from vllm import _custom_ops as ops
 from vllm.platforms.rocm import use_rocm_custom_paged_attention
+from vllm.triton_utils import tl, triton

 from .prefix_prefill import context_attention_fwd

vllm/attention/ops/prefix_prefill.py

Lines changed: 1 addition & 2 deletions
@@ -4,10 +4,9 @@
 # https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py

 import torch
-import triton
-import triton.language as tl

 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton

 # Static kernels parameters
 BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64

vllm/attention/ops/triton_decode_attention.py

Lines changed: 1 addition & 3 deletions
@@ -30,10 +30,8 @@

 import logging

-import triton
-import triton.language as tl
-
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton

 is_hip_ = current_platform.is_rocm()

vllm/attention/ops/triton_flash_attention.py

Lines changed: 1 addition & 2 deletions
@@ -25,11 +25,10 @@
 from typing import Optional

 import torch
-import triton
-import triton.language as tl

 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton

 SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd']

vllm/attention/ops/triton_merge_attn_states.py

Lines changed: 2 additions & 2 deletions
@@ -2,8 +2,8 @@
 from typing import Optional

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 # Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005

vllm/lora/ops/triton_ops/kernel_utils.py

Lines changed: 1 addition & 2 deletions
@@ -2,8 +2,7 @@
 """
 Utilities for Punica kernel construction.
 """
-import triton
-import triton.language as tl
+from vllm.triton_utils import tl, triton


 @triton.jit

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 1 addition & 2 deletions
@@ -6,8 +6,6 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple

 import torch
-import triton
-import triton.language as tl

 import vllm.envs as envs
 from vllm import _custom_ops as ops
@@ -21,6 +19,7 @@
 from vllm.model_executor.layers.quantization.utils.int8_utils import (
     per_token_group_quant_int8, per_token_quant_int8)
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op

 from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled

vllm/model_executor/layers/fused_moe/moe_align_block_size.py

Lines changed: 1 addition & 2 deletions
@@ -2,11 +2,10 @@
 from typing import Optional, Tuple

 import torch
-import triton
-import triton.language as tl

 import vllm.envs as envs
 from vllm import _custom_ops as ops
+from vllm.triton_utils import tl, triton
 from vllm.utils import round_up

vllm/model_executor/layers/lightning_attn.py

Lines changed: 2 additions & 2 deletions
@@ -1,9 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 import torch
-import triton
-import triton.language as tl
 from einops import rearrange

+from vllm.triton_utils import tl, triton
+

 @triton.jit
 def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n,

vllm/model_executor/layers/mamba/ops/mamba_ssm.py

Lines changed: 1 addition & 3 deletions
@@ -4,13 +4,11 @@
 # Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py

 import torch
-import triton
-import triton.language as tl
 from packaging import version

 from vllm import _custom_ops as ops
 from vllm.attention.backends.utils import PAD_SLOT_ID
-from vllm.triton_utils import HAS_TRITON
+from vllm.triton_utils import HAS_TRITON, tl, triton

 TRITON3 = HAS_TRITON and (version.parse(triton.__version__)
                           >= version.parse("3.0.0"))

vllm/model_executor/layers/mamba/ops/ssd_bmm.py

Lines changed: 2 additions & 2 deletions
@@ -8,8 +8,8 @@
 import math

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 @triton.autotune(

vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py

Lines changed: 2 additions & 2 deletions
@@ -6,10 +6,10 @@
 # ruff: noqa: E501,SIM102

 import torch
-import triton
-import triton.language as tl
 from packaging import version

+from vllm.triton_utils import tl, triton
+
 TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')

vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py

Lines changed: 2 additions & 2 deletions
@@ -8,8 +8,8 @@
 import math

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton

 from .mamba_ssm import softplus

vllm/model_executor/layers/mamba/ops/ssd_combined.py

Lines changed: 2 additions & 1 deletion
@@ -6,10 +6,11 @@
 # ruff: noqa: E501

 import torch
-import triton
 from einops import rearrange
 from packaging import version

+from vllm.triton_utils import triton
+
 from .ssd_bmm import _bmm_chunk_fwd
 from .ssd_chunk_scan import _chunk_scan_fwd
 from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd,

vllm/model_executor/layers/mamba/ops/ssd_state_passing.py

Lines changed: 2 additions & 2 deletions
@@ -6,8 +6,8 @@
 # ruff: noqa: E501

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 @triton.autotune(

vllm/model_executor/layers/quantization/awq_triton.py

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton

 AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py

Lines changed: 2 additions & 2 deletions
@@ -3,8 +3,8 @@
 from typing import Optional, Type

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 def is_weak_contiguous(x: torch.Tensor):

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 1 addition & 2 deletions
@@ -7,8 +7,6 @@
 from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
-import triton
-import triton.language as tl

 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
@@ -17,6 +15,7 @@
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED)
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op

 logger = init_logger(__name__)
