
Commit b6e7d0b

Set pagesize based on max-model-len
Signed-off-by: Jevin Jiang <[email protected]>
1 parent f192ca9 commit b6e7d0b

File tree (5 files changed: +35 -11 lines)

  examples/offline_inference/tpu.py
  requirements/tpu.txt
  vllm/platforms/tpu.py
  vllm/utils.py
  vllm/v1/attention/backends/pallas.py

examples/offline_inference/tpu.py

Lines changed: 2 additions & 1 deletion
@@ -22,7 +22,8 @@ def main():
     # In real workloads, `enforace_eager` should be `False`.
     llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
               max_num_batched_tokens=64,
-              max_num_seqs=4)
+              max_num_seqs=4,
+              max_model_len=128)
     outputs = llm.generate(prompts, sampling_params)
     print("-" * 50)
     for output, answer in zip(outputs, answers):

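With max_model_len=128 as set here, the page-size rule introduced later in this commit (see vllm/v1/attention/backends/pallas.py below) works out to next_power_of_2(128) // 16 = 8, which is then clamped up to the 16-token floor. A one-line check of that arithmetic, importing the helper this commit adds to vllm/utils.py:

from vllm.utils import next_power_of_2

# 128 is already a power of two; 128 // 16 = 8, which is below the floor, so 16 is used.
assert max(16, min(256, next_power_of_2(128) // 16)) == 16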
requirements/tpu.txt

Lines changed: 5 additions & 5 deletions
@@ -18,9 +18,9 @@ setuptools==78.1.0
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch==2.8.0.dev20250408
-torchvision==0.22.0.dev20250408
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch==2.8.0.dev20250430
+torchvision==0.22.0.dev20250430
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"

vllm/platforms/tpu.py

Lines changed: 6 additions & 4 deletions
@@ -73,9 +73,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         from vllm.config import CompilationLevel
 
         cache_config = vllm_config.cache_config
+        # For v0, the default block size is 16.
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
-
         compilation_config = vllm_config.compilation_config
 
         # TPU only supports DYNAMO_ONCE compilation level
@@ -98,16 +98,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if envs.VLLM_USE_V1:
             from vllm.v1.attention.backends.pallas import (
                 PallasAttentionBackend)
+            cache_config.block_size = PallasAttentionBackend.get_page_size(
+                vllm_config)
             min_page_size = PallasAttentionBackend.get_min_page_size(
                 vllm_config)
-            if min_page_size > vllm_config.cache_config.block_size:
+            if min_page_size > cache_config.block_size:
                 logger.warning(
                     "Increase the page size from %s to %s to make sure there's"
                     "no SMEM OOM",
-                    vllm_config.cache_config.block_size,
+                    cache_config.block_size,
                     min_page_size,
                 )
-                vllm_config.cache_config.block_size = min_page_size
+                cache_config.block_size = min_page_size
 
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config

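In effect, the v1 path now derives the block size from max-model-len first and only raises it when the SMEM-driven minimum is larger. Condensed into a standalone sketch (the function and argument names here are illustrative, not vLLM API):

def final_block_size(page_size_from_model_len: int, min_page_size_for_smem: int) -> int:
    # Condensed form of the v1 branch above: the max-model-len-derived page
    # size is the starting point, and the SMEM-driven minimum only overrides
    # it when it is larger.
    return max(page_size_from_model_len, min_page_size_for_smem)

assert final_block_size(128, 32) == 128  # derived size is already large enough
assert final_block_size(16, 32) == 32    # minimum wins; vLLM logs a warning in this case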
vllm/utils.py

Lines changed: 7 additions & 0 deletions
@@ -704,6 +704,13 @@ def cdiv(a: int, b: int) -> int:
     return -(a // -b)
 
 
+def next_power_of_2(n) -> int:
+    """The next power of 2 (inclusive)"""
+    if n < 1:
+        return 1
+    return 1 << (n - 1).bit_length()
+
+
 def round_up(x: int, y: int) -> int:
     return ((x + y - 1) // y) * y

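As a quick check of the new helper (a short sketch that only relies on the function body shown in the diff above):

from vllm.utils import next_power_of_2

# "Inclusive" means exact powers of two are returned unchanged.
assert next_power_of_2(0) == 1      # values below 1 clamp to 1
assert next_power_of_2(1) == 1
assert next_power_of_2(16) == 16
assert next_power_of_2(17) == 32
assert next_power_of_2(1000) == 1024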
vllm/v1/attention/backends/pallas.py

Lines changed: 15 additions & 1 deletion
@@ -12,7 +12,7 @@
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.utils import cdiv
+from vllm.utils import cdiv, next_power_of_2
 
 logger = init_logger(__name__)
 
@@ -65,6 +65,20 @@ def get_min_page_size(vllm_config: VllmConfig) -> int:
         min_page_size = 1 << (min_page_size - 1).bit_length()
         return min_page_size
 
+    # TPU has limited SREGs (scalar registers); if page_size is too small, we
+    # can spill SREGs easily, which leads to bad performance. The strategy we
+    # apply here is to split max-model-len into 16 pages, which makes spills
+    # less likely. Meanwhile, we make sure the page size is in [16, 256].
+    @staticmethod
+    def get_page_size(vllm_config: VllmConfig) -> int:
+        page_size = next_power_of_2(
+            vllm_config.model_config.max_model_len) // 16
+        if page_size <= 16:
+            return 16
+        if page_size >= 256:
+            return 256
+        return page_size
+
 
 @dataclass
 class PallasMetadata:

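For reference, the clamping in get_page_size yields the following page sizes for a few representative model lengths. This is a standalone sketch that mirrors the method above so it can run without a TPU or a VllmConfig; page_size_for is an illustrative name, not vLLM API:

def next_power_of_2(n: int) -> int:
    # Same behavior as the helper added to vllm/utils.py in this commit.
    return 1 if n < 1 else 1 << (n - 1).bit_length()

def page_size_for(max_model_len: int) -> int:
    # Mirrors PallasAttentionBackend.get_page_size: aim for about 16 pages
    # per max-model-len, clamped to the [16, 256] range.
    page_size = next_power_of_2(max_model_len) // 16
    return max(16, min(256, page_size))

assert page_size_for(512) == 32       # 512 / 16
assert page_size_for(2048) == 128     # 2048 / 16
assert page_size_for(8192) == 256     # 8192 / 16 = 512, capped at 256
assert page_size_for(100_000) == 256  # next power of 2 is 131072, still capped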