5 files changed: +35 -11 lines

examples/offline_inference

@@ -22,7 +22,8 @@ def main():
     # In real workloads, `enforce_eager` should be `False`.
     llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
               max_num_batched_tokens=64,
-              max_num_seqs=4)
+              max_num_seqs=4,
+              max_model_len=128)
     outputs = llm.generate(prompts, sampling_params)
     print("-" * 50)
     for output, answer in zip(outputs, answers):

@@ -18,9 +18,9 @@ setuptools==78.1.0
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch==2.8.0.dev20250408
-torchvision==0.22.0.dev20250408
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch==2.8.0.dev20250430
+torchvision==0.22.0.dev20250430
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
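
This hunk only bumps the pinned torch / torchvision / torch_xla nightlies from the 20250408 builds to the 20250430 builds. A quick post-install sanity check, assuming a standard TPU VM setup where these requirements have been installed with pip (the exact local version suffixes may differ):

```python
# Hypothetical check that the 20250430 nightlies resolved correctly.
import torch
import torchvision
import torch_xla

print(torch.__version__)        # expect 2.8.0.dev20250430 (plus a local suffix)
print(torchvision.__version__)  # expect 0.22.0.dev20250430 (plus a local suffix)
print(torch_xla.__version__)    # expect a matching 2.8.0.dev20250430 build
```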

@@ -73,9 +73,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         from vllm.config import CompilationLevel

         cache_config = vllm_config.cache_config
+        # For v0, the default block size is 16.
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
-
         compilation_config = vllm_config.compilation_config

         # TPU only supports DYNAMO_ONCE compilation level
@@ -98,16 +98,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if envs.VLLM_USE_V1:
             from vllm.v1.attention.backends.pallas import (
                 PallasAttentionBackend)
+            cache_config.block_size = PallasAttentionBackend.get_page_size(
+                vllm_config)
             min_page_size = PallasAttentionBackend.get_min_page_size(
                 vllm_config)
-            if min_page_size > vllm_config.cache_config.block_size:
+            if min_page_size > cache_config.block_size:
                 logger.warning(
                     "Increase the page size from %s to %s to make sure there's"
                     "no SMEM OOM",
-                    vllm_config.cache_config.block_size,
+                    cache_config.block_size,
                     min_page_size,
                 )
-                vllm_config.cache_config.block_size = min_page_size
+                cache_config.block_size = min_page_size

         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
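
Taken together, the v1 path in `check_and_update_config` now derives the KV-cache block size from the model-length heuristic (`PallasAttentionBackend.get_page_size`, added in the Pallas backend hunk further down) and then still enforces the SMEM-derived minimum. A minimal sketch of the resulting control flow, as a standalone paraphrase with a hypothetical function name, not the actual vLLM code:

```python
# Sketch of the effective TPU block-size selection after this change.
def resolve_block_size(use_v1: bool, heuristic_page_size: int,
                       min_page_size: int) -> int:
    # v0 path: the default block size of 16 is kept.
    block_size = 16
    if use_v1:
        # v1 now always starts from the model-length heuristic
        # (get_page_size in the Pallas backend below)...
        block_size = heuristic_page_size
        # ...and is still raised when the SMEM-derived minimum is larger.
        if min_page_size > block_size:
            block_size = min_page_size
    return block_size


print(resolve_block_size(True, heuristic_page_size=16, min_page_size=32))   # 32
print(resolve_block_size(True, heuristic_page_size=128, min_page_size=32))  # 128
```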

@@ -704,6 +704,13 @@ def cdiv(a: int, b: int) -> int:
     return -(a // -b)


+def next_power_of_2(n) -> int:
+    """The next power of 2 (inclusive)"""
+    if n < 1:
+        return 1
+    return 1 << (n - 1).bit_length()
+
+
 def round_up(x: int, y: int) -> int:
     return ((x + y - 1) // y) * y
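
The new helper rounds up to the nearest power of two and is "inclusive" in the sense that exact powers of two are returned unchanged, with values below 1 falling back to 1. A few illustrative calls, using the import path the Pallas backend hunk below relies on:

```python
from vllm.utils import next_power_of_2

print(next_power_of_2(5))    # 8
print(next_power_of_2(8))    # 8  -- exact powers of two are returned as-is
print(next_power_of_2(129))  # 256
print(next_power_of_2(0))    # 1  -- guard for n < 1
```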

@@ -12,7 +12,7 @@
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.utils import cdiv
+from vllm.utils import cdiv, next_power_of_2

 logger = init_logger(__name__)

@@ -65,6 +65,20 @@ def get_min_page_size(vllm_config: VllmConfig) -> int:
         min_page_size = 1 << (min_page_size - 1).bit_length()
         return min_page_size

+    # TPU has limited SREGs (scalar registers): if the page size is too small,
+    # we can spill SREGs easily, which leads to bad performance. The strategy
+    # here is to split max-model-len into 16 pages, which makes spills less
+    # likely. Meanwhile, we make sure the page size stays within [16, 256].
+    @staticmethod
+    def get_page_size(vllm_config: VllmConfig) -> int:
+        page_size = next_power_of_2(
+            vllm_config.model_config.max_model_len) // 16
+        if page_size <= 16:
+            return 16
+        if page_size >= 256:
+            return 256
+        return page_size
+

 @dataclass
 class PallasMetadata:
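
To make the "split max-model-len into 16 pages" strategy concrete, here is the same arithmetic applied standalone to a few `max_model_len` values (`page_size_for` is a hypothetical helper that copies the logic above, shown only for illustration):

```python
# Worked example of the page-size heuristic; the real method is
# PallasAttentionBackend.get_page_size.
def page_size_for(max_model_len: int) -> int:
    n = 1 if max_model_len < 1 else 1 << (max_model_len - 1).bit_length()
    return min(max(n // 16, 16), 256)

for max_model_len in (128, 1024, 2048, 16384):
    p = page_size_for(max_model_len)
    print(max_model_len, p, -(-max_model_len // p))  # len, page size, pages
# 128   -> 16  (8 pages, clamped up to the 16 minimum)
# 1024  -> 64  (16 pages)
# 2048  -> 128 (16 pages)
# 16384 -> 256 (64 pages, clamped down to the 256 maximum)
```

With the `max_model_len=128` set in the offline-inference example above, the heuristic therefore keeps the block size at the 16 minimum.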