
Commit 34c40e4

charlifuwuisawesome authored and committed
[ROCm][Misc] Follow-ups for Skinny Gemms on ROCm. (vllm-project#17011)
Signed-off-by: charlifu <[email protected]>
1 parent 4dd29a6 · commit 34c40e4

File tree

4 files changed, +18 -15 lines changed


vllm/model_executor/layers/quantization/utils/w8a8_utils.py

Lines changed: 4 additions & 3 deletions
@@ -155,8 +155,9 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                    scale_b: torch.Tensor, bias: torch.Tensor,
                                    input_2d: torch.Tensor,
                                    output_shape: List) -> torch.Tensor:
-    if envs.VLLM_ROCM_USE_SKINNY_GEMM and qinput.shape[
-            0] == 1 and qinput.shape[1] % 16 == 0:
+    from vllm.platforms.rocm import on_mi250_mi300
+    if envs.VLLM_ROCM_USE_SKINNY_GEMM and not on_mi250_mi300(
+    ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
         output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
                                current_platform.get_cu_count())
     else:
@@ -371,7 +372,7 @@ def apply(
 
     return w8a8_scaled_mm_func(qinput=qinput,
                                weight=weight,
-                               out_dtype=input.dtype,
+                               out_dtype=out_dtype,
                                scale_a=x_scale,
                                scale_b=weight_scale,
                                bias=bias,
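Note on the two hunks above: the first now routes to the ops.wvSplitKQ skinny-GEMM kernel only when the new architecture check passes in addition to the existing shape checks (a single input row and K divisible by 16); the second passes the already-resolved out_dtype through instead of re-deriving it from input.dtype. A minimal standalone sketch of the gating predicate follows; it is not vLLM's code, and use_skinny_gemm and on_mi250_mi300 are plain booleans standing in for envs.VLLM_ROCM_USE_SKINNY_GEMM and the helper added in vllm/platforms/rocm.py.

import torch


def should_use_wvsplitkq(qinput: torch.Tensor,
                         use_skinny_gemm: bool,
                         on_mi250_mi300: bool) -> bool:
    # Mirrors the condition in the diff: env flag set, architecture check
    # passes, single-row quantized input, reduction dim divisible by 16.
    return (use_skinny_gemm and not on_mi250_mi300
            and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0)


# A 1 x 4096 int8 activation row qualifies for the skinny path.
qinput = torch.zeros(1, 4096, dtype=torch.int8)
print(should_use_wvsplitkq(qinput, use_skinny_gemm=True, on_mi250_mi300=False))  # True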

vllm/model_executor/layers/utils.py

Lines changed: 4 additions & 3 deletions
@@ -70,8 +70,9 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
 def rocm_unquantized_gemm(x: torch.Tensor,
                           weight: torch.Tensor,
                           bias: Optional[torch.Tensor] = None):
+    from vllm.platforms.rocm import on_mi250_mi300
     k = weight.shape[1]
-    use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and \
+    use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300() and \
                   x.dtype in [torch.float16, torch.bfloat16] \
                   and k % 8 == 0 and bias is None)
@@ -83,11 +84,11 @@ def rocm_unquantized_gemm(x: torch.Tensor,
     m = weight.shape[0]
     cu_count = current_platform.get_cu_count()
 
-    if m > 8 and n < 4:
+    if m > 8 and 0 < n < 4:
         out = ops.wvSplitK(weight, x_view, cu_count)
         return out.view(*x.shape[:-1], weight.shape[0])
     elif m % 4 == 0 and n == 1 and k <= 8192:
-        out = ops.LLMM1(weight, x_view, out, 4)
+        out = ops.LLMM1(weight, x_view, 4)
         return out.view(*x.shape[:-1], weight.shape[0])
     return torch.nn.functional.linear(x, weight, bias)
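The dispatch in rocm_unquantized_gemm now reads: ops.wvSplitK when the weight has more than 8 output rows and the flattened input has 1 to 3 rows, ops.LLMM1 when the output dimension is a multiple of 4, there is exactly one input row, and K is at most 8192, otherwise plain torch.nn.functional.linear. Below is a self-contained sketch of that decision; it only reports the choice rather than calling the ROCm kernels, and it assumes n counts the flattened input rows produced by the x_view reshape, which sits outside this hunk.

from typing import Optional

import torch


def pick_unquantized_kernel(x: torch.Tensor, weight: torch.Tensor,
                            bias: Optional[torch.Tensor],
                            use_skinny_gemm: bool) -> str:
    m, k = weight.shape          # output features, reduction dim
    n = x.numel() // k           # flattened input rows (assumed x_view semantics)
    skinny_ok = (use_skinny_gemm and x.dtype in (torch.float16, torch.bfloat16)
                 and k % 8 == 0 and bias is None)
    if skinny_ok and m > 8 and 0 < n < 4:
        return "ops.wvSplitK"
    if skinny_ok and m % 4 == 0 and n == 1 and k <= 8192:
        return "ops.LLMM1"
    return "torch.nn.functional.linear"


x = torch.randn(1, 4096, dtype=torch.float16)      # a single decode-time row
w = torch.randn(12288, 4096, dtype=torch.float16)  # e.g. a fused QKV projection
print(pick_unquantized_kernel(x, w, bias=None, use_skinny_gemm=True))  # ops.wvSplitK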

vllm/model_executor/layers/vocab_parallel_embedding.py

Lines changed: 2 additions & 1 deletion
@@ -12,6 +12,7 @@
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
+from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
 from vllm.model_executor.parameter import BasevLLMParameter
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
@@ -40,7 +41,7 @@ def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        return F.linear(x, layer.weight, bias)
+        return dispatch_unquantized_gemm()(x, layer.weight, bias)
 
     def embedding(self, layer: torch.nn.Module,
                   input_: torch.Tensor) -> torch.Tensor:
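With this change the unquantized embedding method's apply() goes through dispatch_unquantized_gemm() instead of calling F.linear directly, so the vocab-parallel projection (for example the lm_head matmul) can presumably reuse the ROCm skinny-GEMM path from the previous file. The dispatcher's body is not part of this diff; the sketch below is an illustration only, and the torch.version.hip probe and the ROCm branch are assumptions.

from typing import Callable

import torch
import torch.nn.functional as F


def dispatch_unquantized_gemm_sketch() -> Callable[..., torch.Tensor]:
    # Assumption: on a ROCm build the dispatcher would return
    # rocm_unquantized_gemm (the function patched in the previous file);
    # anywhere else it falls back to plain F.linear with the same
    # (x, weight, bias) signature.
    try:
        if torch.version.hip is not None:
            from vllm.model_executor.layers.utils import rocm_unquantized_gemm
            return rocm_unquantized_gemm
    except Exception:
        pass
    return lambda x, weight, bias=None: F.linear(x, weight, bias)


gemm = dispatch_unquantized_gemm_sketch()
x = torch.randn(2, 8)
w = torch.randn(16, 8)
print(gemm(x, w, None).shape)  # torch.Size([2, 16])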

vllm/platforms/rocm.py

Lines changed: 8 additions & 8 deletions
@@ -98,22 +98,22 @@ def device_id_to_physical_device_id(device_id: int) -> int:
         return device_id
 
 
+def on_mi250_mi300() -> bool:
+    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
+    return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])
+
+
 @cache
 def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                     block_size: int, gqa_ratio: int,
                                     max_seq_len: int,
                                     sliding_window: int) -> bool:
 
-    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
-    ON_NAVI = "gfx1" in GPU_ARCH
-    ON_MI250_MI300 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])
-
-    # rocm custom page attention not support on navi (gfx1*)
+    # rocm custom page attention not support on gfx1*
     # custom paged attn always supported on V0. On V1, requires sliding window
     # disabled due to observed numerical discrepancy.
-    return (ON_MI250_MI300 and not ON_NAVI
-            and (not envs.VLLM_USE_V1 or sliding_window == 0
-                 or sliding_window == (-1, -1))
+    return (on_mi250_mi300() and (not envs.VLLM_USE_V1 or sliding_window == 0
+                                  or sliding_window == (-1, -1))
             and (qtype == torch.half or qtype == torch.bfloat16)
             and (head_size == 64 or head_size == 128)
             and (block_size == 16 or block_size == 32)
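The new on_mi250_mi300() helper centralizes the gcnArchName probe that use_rocm_custom_paged_attention previously did inline; gfx90a and gfx942 are the MI250 and MI300 architectures, while gfx1* is the Navi family named in the old comment. Below is a small standalone sketch of the same detection pattern, guarded so it also runs on machines without a visible GPU; gcnArchName is read with getattr because the attribute may not exist on non-ROCm builds.

import torch


def detect_rocm_family() -> str:
    # Reports which branch the architecture check above would take here.
    if not torch.cuda.is_available():
        return "no GPU visible"
    arch = getattr(torch.cuda.get_device_properties("cuda"), "gcnArchName", "")
    if any(a in arch for a in ("gfx90a", "gfx942")):
        return f"MI250/MI300 ({arch})"
    if "gfx1" in arch:
        return f"Navi ({arch})"
    return arch or "not a ROCm device"


print(detect_rocm_family())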
