Commit 355fa4b

[Bugfix] Fix triton import with local TritonPlaceholder

Signed-off-by: Mengqing Cao <[email protected]>
1 parent: c777df7

29 files changed: +79 / -81 lines
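
Every kernel-facing module in this commit swaps its direct Triton imports for the names re-exported by vllm.triton_utils, so importing these modules no longer fails on platforms without Triton. A minimal before/after sketch of the pattern (the _toy_kernel below is hypothetical, not a kernel from this diff):

    # Before: importing the module raises ModuleNotFoundError when Triton is absent.
    # import triton
    # import triton.language as tl

    # After: resolves to real Triton when available, otherwise to the placeholders
    # defined in vllm/triton_utils (see the last two files in this diff).
    from vllm.triton_utils import tl, triton


    @triton.jit
    def _toy_kernel(x_ptr, BLOCK_SIZE: tl.constexpr):
        ...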

benchmarks/kernels/benchmark_moe.py

Lines changed: 1 addition & 1 deletion
@@ -10,12 +10,12 @@

 import ray
 import torch
-import triton
 from ray.experimental.tqdm_ray import tqdm
 from transformers import AutoConfig

 from vllm.model_executor.layers.fused_moe.fused_moe import *
 from vllm.platforms import current_platform
+from vllm.triton_utils import triton
 from vllm.utils import FlexibleArgumentParser

 FP8_DTYPE = current_platform.fp8_dtype()

benchmarks/kernels/benchmark_rmsnorm.py

Lines changed: 1 addition & 1 deletion
@@ -4,11 +4,11 @@
 from typing import Optional, Union

 import torch
-import triton
 from flashinfer.norm import fused_add_rmsnorm, rmsnorm
 from torch import nn

 from vllm import _custom_ops as vllm_ops
+from vllm.triton_utils import triton


 class HuggingFaceRMSNorm(nn.Module):

benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py

Lines changed: 1 addition & 1 deletion
@@ -6,13 +6,13 @@
 # Import DeepGEMM functions
 import deep_gemm
 import torch
-import triton
 from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor

 # Import vLLM functions
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8, w8a8_block_fp8_matmul)
+from vllm.triton_utils import triton


 # Copied from

tests/kernels/attention/test_flashmla.py

Lines changed: 1 addition & 1 deletion
@@ -5,11 +5,11 @@

 import pytest
 import torch
-import triton

 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
+from vllm.triton_utils import triton


 def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:

vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 def blocksparse_flash_attn_varlen_fwd(

vllm/attention/ops/blocksparse_attention/utils.py

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,8 @@

 import numpy as np
 import torch
-import triton
+
+from vllm.triton_utils import triton


 class csr_matrix:

vllm/attention/ops/chunked_prefill_paged_decode.py

Lines changed: 1 addition & 2 deletions
@@ -7,11 +7,10 @@
 # - Thomas Parnell <[email protected]>

 import torch
-import triton
-import triton.language as tl

 from vllm import _custom_ops as ops
 from vllm.platforms.rocm import use_rocm_custom_paged_attention
+from vllm.triton_utils import tl, triton

 from .prefix_prefill import context_attention_fwd

vllm/attention/ops/prefix_prefill.py

Lines changed: 1 addition & 2 deletions
@@ -4,10 +4,9 @@
 # https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py

 import torch
-import triton
-import triton.language as tl

 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton

 # Static kernels parameters
 BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64

vllm/attention/ops/triton_decode_attention.py

Lines changed: 1 addition & 3 deletions
@@ -30,10 +30,8 @@

 import logging

-import triton
-import triton.language as tl
-
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton

 is_hip_ = current_platform.is_rocm()

vllm/attention/ops/triton_flash_attention.py

Lines changed: 1 addition & 2 deletions
@@ -25,11 +25,10 @@
 from typing import Optional

 import torch
-import triton
-import triton.language as tl

 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton

 SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd']

vllm/attention/ops/triton_merge_attn_states.py

Lines changed: 2 additions & 2 deletions
@@ -2,8 +2,8 @@
 from typing import Optional

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 # Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005

vllm/lora/ops/triton_ops/kernel_utils.py

Lines changed: 1 addition & 2 deletions
@@ -2,8 +2,7 @@
 """
 Utilities for Punica kernel construction.
 """
-import triton
-import triton.language as tl
+from vllm.triton_utils import tl, triton


 @triton.jit

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 1 addition & 2 deletions
@@ -6,8 +6,6 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple

 import torch
-import triton
-import triton.language as tl

 import vllm.envs as envs
 from vllm import _custom_ops as ops
@@ -21,6 +19,7 @@
 from vllm.model_executor.layers.quantization.utils.int8_utils import (
     per_token_group_quant_int8, per_token_quant_int8)
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op

 from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled

vllm/model_executor/layers/fused_moe/moe_align_block_size.py

Lines changed: 1 addition & 2 deletions
@@ -2,11 +2,10 @@
 from typing import Optional, Tuple

 import torch
-import triton
-import triton.language as tl

 import vllm.envs as envs
 from vllm import _custom_ops as ops
+from vllm.triton_utils import tl, triton
 from vllm.utils import round_up


vllm/model_executor/layers/lightning_attn.py

Lines changed: 2 additions & 2 deletions
@@ -1,9 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 import torch
-import triton
-import triton.language as tl
 from einops import rearrange

+from vllm.triton_utils import tl, triton
+

 @triton.jit
 def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n,

vllm/model_executor/layers/mamba/ops/mamba_ssm.py

Lines changed: 1 addition & 3 deletions
@@ -4,13 +4,11 @@
 # Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py

 import torch
-import triton
-import triton.language as tl
 from packaging import version

 from vllm import _custom_ops as ops
 from vllm.attention.backends.utils import PAD_SLOT_ID
-from vllm.triton_utils import HAS_TRITON
+from vllm.triton_utils import HAS_TRITON, tl, triton

 TRITON3 = HAS_TRITON and (version.parse(triton.__version__)
                           >= version.parse("3.0.0"))

vllm/model_executor/layers/mamba/ops/ssd_bmm.py

Lines changed: 2 additions & 2 deletions
@@ -8,8 +8,8 @@
 import math

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 @triton.autotune(

vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py

Lines changed: 2 additions & 2 deletions
@@ -6,10 +6,10 @@
 # ruff: noqa: E501,SIM102

 import torch
-import triton
-import triton.language as tl
 from packaging import version

+from vllm.triton_utils import tl, triton
+
 TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')

vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py

Lines changed: 2 additions & 2 deletions
@@ -8,8 +8,8 @@
 import math

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton

 from .mamba_ssm import softplus

vllm/model_executor/layers/mamba/ops/ssd_combined.py

Lines changed: 2 additions & 1 deletion
@@ -6,10 +6,11 @@
 # ruff: noqa: E501

 import torch
-import triton
 from einops import rearrange
 from packaging import version

+from vllm.triton_utils import triton
+
 from .ssd_bmm import _bmm_chunk_fwd
 from .ssd_chunk_scan import _chunk_scan_fwd
 from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd,

vllm/model_executor/layers/mamba/ops/ssd_state_passing.py

Lines changed: 2 additions & 2 deletions
@@ -6,8 +6,8 @@
 # ruff: noqa: E501

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 @triton.autotune(

vllm/model_executor/layers/quantization/awq_triton.py

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton

 AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py

Lines changed: 2 additions & 2 deletions
@@ -3,8 +3,8 @@
 from typing import Optional, Type

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 def is_weak_contiguous(x: torch.Tensor):

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 1 addition & 2 deletions
@@ -7,8 +7,6 @@
 from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
-import triton
-import triton.language as tl

 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
@@ -17,6 +15,7 @@
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED)
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op

 logger = init_logger(__name__)

vllm/model_executor/layers/quantization/utils/int8_utils.py

Lines changed: 1 addition & 2 deletions
@@ -8,10 +8,9 @@
 from typing import Any, Dict, List, Optional, Tuple

 import torch
-import triton
-import triton.language as tl

 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton

 logger = logging.getLogger(__name__)

vllm/triton_utils/__init__.py

Lines changed: 10 additions & 2 deletions
@@ -1,5 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0

-from vllm.triton_utils.importing import HAS_TRITON
+from vllm.triton_utils.importing import (HAS_TRITON, TritonLanguagePlaceholder,
+                                         TritonPlaceholder)

-__all__ = ["HAS_TRITON"]
+if HAS_TRITON:
+    import triton
+    import triton.language as tl
+else:
+    triton = TritonPlaceholder()
+    tl = TritonLanguagePlaceholder()
+
+__all__ = ["HAS_TRITON", "triton", "tl"]
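
With this __init__.py, callers get a single import site whether or not Triton is present. A rough sketch of what the re-exported names resolve to on a machine without Triton installed (assuming HAS_TRITON evaluates to False there):

    from vllm.triton_utils import HAS_TRITON, tl, triton

    print(HAS_TRITON)    # False when Triton cannot be imported
    print(type(triton))  # <class 'vllm.triton_utils.importing.TritonPlaceholder'>
    print(tl.constexpr)  # None -- just enough for kernel signatures to evaluate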

vllm/triton_utils/importing.py

Lines changed: 31 additions & 29 deletions
@@ -16,32 +16,34 @@
 logger.info("Triton not installed or not compatible; certain GPU-related"
             " functions will not be available.")

-class TritonPlaceholder(types.ModuleType):
-
-    def __init__(self):
-        super().__init__("triton")
-        self.jit = self._dummy_decorator("jit")
-        self.autotune = self._dummy_decorator("autotune")
-        self.heuristics = self._dummy_decorator("heuristics")
-        self.language = TritonLanguagePlaceholder()
-        logger.warning_once(
-            "Triton is not installed. Using dummy decorators. "
-            "Install it via `pip install triton` to enable kernel"
-            "compilation.")
-
-    def _dummy_decorator(self, name):
-
-        def decorator(func=None, **kwargs):
-            if func is None:
-                return lambda f: f
-            return func
-
-        return decorator
-
-class TritonLanguagePlaceholder(types.ModuleType):
-
-    def __init__(self):
-        super().__init__("triton.language")
-        self.constexpr = None
-        self.dtype = None
-        self.int64 = None
+
+class TritonPlaceholder(types.ModuleType):
+
+    def __init__(self):
+        super().__init__("triton")
+        self.jit = self._dummy_decorator("jit")
+        self.autotune = self._dummy_decorator("autotune")
+        self.heuristics = self._dummy_decorator("heuristics")
+        self.language = TritonLanguagePlaceholder()
+        logger.warning_once(
+            "Triton is not installed. Using dummy decorators. "
+            "Install it via `pip install triton` to enable kernel"
+            " compilation.")
+
+    def _dummy_decorator(self, name):
+
+        def decorator(func=None, **kwargs):
+            if func is None:
+                return lambda f: f
+            return func
+
+        return decorator
+
+
+class TritonLanguagePlaceholder(types.ModuleType):
+
+    def __init__(self):
+        super().__init__("triton.language")
+        self.constexpr = None
+        self.dtype = None
+        self.int64 = None

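The placeholder classes exist so that modules full of @triton.jit kernels can still be imported when Triton is missing: the dummy decorator returns the function unchanged, and attributes like tl.constexpr resolve to None so annotations evaluate. A small sketch of that behavior (toy_kernel is hypothetical; actually launching such a kernel still requires a real Triton install):

    from vllm.triton_utils import tl, triton


    @triton.jit  # dummy decorator when Triton is absent: returns the function as-is
    def toy_kernel(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        pass


    # The module defining toy_kernel imports cleanly; only running it needs Triton.
    assert callable(toy_kernel)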