
Commit 73d3ec3

[CI][UT] Compat with CUDA and NPU
Signed-off-by: jiangpeng <[email protected]>
1 parent db2f8d9 commit 73d3ec3

File tree

5 files changed: +36 −13 lines changed

  tests/conftest.py
  tests/v1/sample/test_sampler.py
  vllm/platforms/cuda.py
  vllm/platforms/interface.py
  vllm/worker/multi_step_model_runner.py

tests/conftest.py

Lines changed: 2 additions & 1 deletion

@@ -272,7 +272,8 @@ class HfRunner:
     def get_default_device(self):
         from vllm.platforms import current_platform

-        return ("cpu" if current_platform.is_cpu() else "cuda")
+        return ("cpu"
+                if current_platform.is_cpu() else current_platform.device_type)

     def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
         if x is None or isinstance(x, (bool, )):
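
This change makes HfRunner's default device follow the active platform instead of hardcoding "cuda". A minimal sketch of the resulting behavior; the "npu" value is an assumption about what an Ascend build reports as current_platform.device_type:

    from vllm.platforms import current_platform

    # "cpu" on a CPU build; otherwise whatever the platform reports,
    # e.g. "cuda" on a CUDA build or "npu" on an NPU build (assumed).
    device = ("cpu"
              if current_platform.is_cpu() else current_platform.device_type)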

tests/v1/sample/test_sampler.py

Lines changed: 12 additions & 10 deletions

@@ -6,14 +6,16 @@
 import pytest
 import torch

+from vllm.platforms import current_platform
 from vllm.utils import make_tensor_with_pad
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.sampler import Sampler

 VOCAB_SIZE = 1024
 NUM_OUTPUT_TOKENS = 20
-CUDA_DEVICES = [
-    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+TORCH_DEVICES = [
+    f"{current_platform.device_type}:{i}"
+    for i in range(1 if current_platform.get_device_count() == 1 else 2)
 ]
 MAX_NUM_PROMPT_TOKENS = 64

@@ -224,7 +226,7 @@ def _create_weighted_output_token_list(
     return output_token_ids, sorted_token_ids_in_output


-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", TORCH_DEVICES)
 @pytest.mark.parametrize("batch_size", [1, 2, 32])
 def test_sampler_min_tokens_penalty(device: str, batch_size: int):
     """
@@ -254,7 +256,7 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int):
             assert logits[batch_idx][token_id] != -float("inf")


-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", TORCH_DEVICES)
 @pytest.mark.parametrize("batch_size", [1, 2, 32])
 @pytest.mark.parametrize("presence_penalty", [-2.0, 2.0])
 def test_sampler_presence_penalty(device: str, batch_size: int,
@@ -299,7 +301,7 @@ def test_sampler_presence_penalty(device: str, batch_size: int,
         assert penalized_token_id not in output_token_ids[batch_idx]


-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", TORCH_DEVICES)
 @pytest.mark.parametrize("batch_size", [1, 2, 32])
 @pytest.mark.parametrize("frequency_penalty", [-2.0, 2.0])
 def test_sampler_frequency_penalty(device: str, batch_size: int,
@@ -352,7 +354,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
         assert penalized_token_id not in distinct_sorted_token_ids_in_output


-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", TORCH_DEVICES)
 @pytest.mark.parametrize("batch_size", [1, 2, 32])
 @pytest.mark.parametrize("repetition_penalty", [0.1, 1.9])
 def test_sampler_repetition_penalty(device: str, batch_size: int,
@@ -398,7 +400,7 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
                 or non_penalized_token_id in output_tokens)


-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", TORCH_DEVICES)
 @pytest.mark.parametrize("batch_size", [1, 2, 32])
 @pytest.mark.parametrize("min_p", [0.0, 0.1])
 def test_sampler_min_p(device: str, batch_size: int, min_p: float):
@@ -438,7 +440,7 @@ def test_sampler_min_p(device: str, batch_size: int, min_p: float):
             assert logits[batch_idx][token_id] != -float("inf")


-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", TORCH_DEVICES)
 @pytest.mark.parametrize("batch_size", [1, 2, 32])
 @pytest.mark.parametrize("bias_value", [-0.1, 1.2])
 def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
@@ -472,7 +474,7 @@ def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
         assert logits_for_req[token_id] == pytest.approx(1e-2)


-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", TORCH_DEVICES)
 @pytest.mark.parametrize("batch_size", [1, 2, 32])
 @pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2])
 def test_sampler_allowed_token_ids(device: str, batch_size: int,
@@ -513,7 +515,7 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int,
         assert logits_for_req[token_id] != -float("inf")


-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", TORCH_DEVICES)
 @pytest.mark.parametrize("batch_size", [1, 2, 32])
 @pytest.mark.parametrize("bad_words_lengths", [(1, ), (1, 3), (2, 2)])
 def test_sampler_bad_words(device: str, batch_size: int,
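
The renamed TORCH_DEVICES list derives both the device prefix and the visible device count from the active platform, so the same parametrized tests run unchanged on CUDA and NPU hosts. A sketch of what the list evaluates to under two assumed configurations:

    from vllm.platforms import current_platform

    # CUDA host with 2+ GPUs:  ["cuda:0", "cuda:1"]
    # Single-card NPU host (assumed device_type "npu"):  ["npu:0"]
    TORCH_DEVICES = [
        f"{current_platform.device_type}:{i}"
        for i in range(1 if current_platform.get_device_count() == 1 else 2)
    ]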

vllm/platforms/cuda.py

Lines changed: 8 additions & 0 deletions

@@ -86,6 +86,14 @@ def get_device_capability(cls,
     def get_device_name(cls, device_id: int = 0) -> str:
         raise NotImplementedError

+    @classmethod
+    def get_device_count(cls) -> int:
+        return torch.cuda.device_count()
+
+    @classmethod
+    def get_device_event(cls, blocking) -> torch.cuda.Event:
+        return torch.cuda.Event(blocking=blocking)
+
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         raise NotImplementedError
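
The two new classmethods give callers a platform-neutral way to ask how many devices are visible and to allocate a synchronization event; on CUDA they delegate directly to torch.cuda. A short usage sketch (record and synchronize are standard torch.cuda.Event methods):

    from vllm.platforms import current_platform

    if current_platform.get_device_count() > 0:
        event = current_platform.get_device_event(blocking=True)
        event.record()       # record the event on the current stream
        event.synchronize()  # block the host until the recorded work completes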

vllm/platforms/interface.py

Lines changed: 10 additions & 0 deletions

@@ -196,6 +196,16 @@ def get_device_name(cls, device_id: int = 0) -> str:
         """Get the name of a device."""
         raise NotImplementedError

+    @classmethod
+    def get_device_count(cls) -> int:
+        """Get the number of devices available on this platform."""
+        raise NotImplementedError
+
+    @classmethod
+    def get_device_event(cls, blocking):
+        """Get a device event for stream synchronization."""
+        raise NotImplementedError
+
     @classmethod
     def get_device_uuid(cls, device_id: int = 0) -> str:
         """Get the uuid of a device, e.g. the PCI bus ID."""

vllm/worker/multi_step_model_runner.py

Lines changed: 4 additions & 2 deletions

@@ -14,6 +14,7 @@
                                 SamplerOutput,
                                 SamplingMetadata, get_logprobs,
                                 get_pythonized_sample_results)
+from vllm.platforms import current_platform
 from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
                            Logprob, SequenceGroupMetadata, SequenceOutput)
 from vllm.utils import PyObjectCache, async_tensor_h2d, current_stream
@@ -158,8 +159,9 @@ class StatefulModelInput(BroadcastableModelInput):
     is_first_multi_step: bool = False
     base_output_proc_callback: Optional[Callable] = None
     # ping-pong data structures for multi-step to wait on the previous step
-    step_cuda_events: List[torch.cuda.Event] = field(
-        default_factory=lambda: [torch.cuda.Event(blocking=True)] * 2)
+    step_cuda_events: List = field(
+        default_factory=lambda:
+        [current_platform.get_device_event(blocking=True)] * 2)
     num_seqs: int = -1
     num_queries: int = -1
     num_single_step_prefills: int = 0
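
The two events implement the ping-pong named in the comment: each step records completion on one event while the next step waits on the other, so two events suffice no matter how many steps run. A minimal sketch of the pattern, independent of the runner itself:

    from vllm.platforms import current_platform

    events = [current_platform.get_device_event(blocking=True)
              for _ in range(2)]

    for step in range(4):
        if step > 0:
            events[(step - 1) % 2].synchronize()  # wait for the previous step
        # ... enqueue this step's device work on the current stream ...
        events[step % 2].record()  # mark the end of this step's work

Note that the dataclass default [event] * 2 stores the same event object in both slots; the sketch uses a comprehension so the two slots hold distinct events, which the ping-pong requires.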
