Skip to content

[TPU] Increase block size and reset block shapes #16458

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 6, 2025

Conversation

bythew3i
Copy link
Contributor

@bythew3i bythew3i commented Apr 11, 2025

Increase kv cache block size and reset kernel block shapes based on autotuned results from kernel.
But still need to retune the kernel block shapes in kernel.

Note: we should wait for pytorch/xla#9041 to be checkin and update new torch_xla version in requirements.txt

Benchmarked without cache:

v6e-1 (single chip): 7.87 -> 8.37 req / sec

Benchmarking script:

VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct  --disable-log-requests --port 8003 --gpu-memory-utilization 0.95 --max-num-batched-tokens 512 --tensor-parallel-size 1 --max-model-len 2048 --max_num_seqs 512 &> /tmp/serve.log &

python benchmarks/benchmark_serving.py --backend vllm --model meta-llama/Llama-3.1-8B-Instruct --dataset-name random --port=8003 --random-input-len 1800 --random-output-len 128

Before:

============ Serving Benchmark Result ============
Successful requests:                     987       
Benchmark duration (s):                  125.47    
Total input tokens:                      1776600   
Total generated tokens:                  118669    
Request throughput (req/s):              7.87      
Output token throughput (tok/s):         945.80    
Total Token throughput (tok/s):          15105.47  
---------------Time to First Token----------------
Mean TTFT (ms):                          61168.93  
Median TTFT (ms):                        60913.25  
P99 TTFT (ms):                           121404.06 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          31.70     
Median TPOT (ms):                        31.98     
P99 TPOT (ms):                           32.56     
---------------Inter-token Latency----------------
Mean ITL (ms):                           31.69     
Median ITL (ms):                         31.93     
P99 ITL (ms):                            33.08     
==================================================

After:

============ Serving Benchmark Result ============
Successful requests:                     987       
Benchmark duration (s):                  117.96    
Total input tokens:                      1776600   
Total generated tokens:                  118669    
Request throughput (req/s):              8.37      
Output token throughput (tok/s):         1006.00   
Total Token throughput (tok/s):          16066.94  
---------------Time to First Token----------------
Mean TTFT (ms):                          57649.43  
Median TTFT (ms):                        57545.37  
P99 TTFT (ms):                           113943.35 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          29.78     
Median TPOT (ms):                        29.96     
P99 TPOT (ms):                           30.49     
---------------Inter-token Latency----------------
Mean ITL (ms):                           29.76     
Median ITL (ms):                         29.97     
P99 ITL (ms):                            30.93     
==================================================

v6e-8 (multi chip): 4.92 -> 5.42 req / sec

VLLM_USE_V1=1 vllm serve "meta-llama/Llama-3.1-70B" --download_dir "/root/.cache" --disable-log-requests --tensor_parallel_size=8 --max-model-len=2048 --gpu-memory-utilization 0.95 --max-num-batched-tokens 512 --max_num_seqs 512  &> /tmp/serve.log &


python benchmarks/benchmark_serving.py --backend vllm --model meta-llama/Llama-3.1-70B --dataset-name random --port=8003 --random-input-len 1800 --random-output-len 128

Before:

============ Serving Benchmark Result ============
Successful requests:                     987       
Benchmark duration (s):                  200.42    
Total input tokens:                      1776600   
Total generated tokens:                  111817    
Request throughput (req/s):              4.92      
Output token throughput (tok/s):         557.91    
Total Token throughput (tok/s):          9422.22   
---------------Time to First Token----------------
Mean TTFT (ms):                          98990.46  
Median TTFT (ms):                        99069.74  
P99 TTFT (ms):                           195396.73 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          51.12     
Median TPOT (ms):                        51.44     
P99 TPOT (ms):                           52.66     
---------------Inter-token Latency----------------
Mean ITL (ms):                           51.27     
Median ITL (ms):                         51.38     
P99 ITL (ms):                            53.50     
==================================================

After

============ Serving Benchmark Result ============
Successful requests:                     987       
Benchmark duration (s):                  182.04    
Total input tokens:                      1776600   
Total generated tokens:                  111445    
Request throughput (req/s):              5.42      
Output token throughput (tok/s):         612.19    
Total Token throughput (tok/s):          10371.49  
---------------Time to First Token----------------
Mean TTFT (ms):                          89566.91  
Median TTFT (ms):                        89369.13  
P99 TTFT (ms):                           177012.56 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          46.24     
Median TPOT (ms):                        46.64     
P99 TPOT (ms):                           47.39     
---------------Inter-token Latency----------------
Mean ITL (ms):                           46.45     
Median ITL (ms):                         46.59     
P99 ITL (ms):                            48.17     
==================================================

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added v1 tpu Related to Google TPUs labels Apr 11, 2025
@alexm-redhat
Copy link
Collaborator

@bythew3i which model did you test on single and multi-chip setups?

if cache_config and cache_config.block_size is None:
cache_config.block_size = 16
if cache_config:
cache_config.block_size = 256
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

256 block size seems a bit aggressive. Maybe you can try the sharegpt (it has average short prompts) benchmark and see if you don't see a regression.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And usually the page_size should not be larger than max_model_len.

Copy link
Contributor Author

@bythew3i bythew3i Apr 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

re: @alexm-redhat I modified the code to calculate the block size based on the max-model-len. PTAL.

BTW, can you please share the cmds used for sharegpt benchmarking?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

re: @yaochengji now the way how we choose page size should handle this. PTAL get_page_size in pallas.py

@bythew3i
Copy link
Contributor Author

@bythew3i which model did you test on single and multi-chip setups?

Hi @alexm-redhat, sorry for the late reply. I benchmarked meta-llama/Llama-3.1-8B-Instruct and meta-llama/Llama-3.1-70B. I also updated benchmarking cmds in the PR description.

@bythew3i
Copy link
Contributor Author

CC: @yarongmu-google @bvrockwell

Copy link
Collaborator

@yaochengji yaochengji left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, when CI is green. Thanks for the contribution.

@bythew3i
Copy link
Contributor Author

QQ: what is the cmd to format the code?

@yaochengji
Copy link
Collaborator

yaochengji commented Apr 26, 2025

QQ: what is the cmd to format the code?

You can use pre-commit run --all-files, from https://docs.vllm.ai/en/stable/contributing/overview.html#testing

Sometimes there's still some lines not formatted, you can install a ruff plungin in vscode and format selected

@bythew3i bythew3i force-pushed the ragged-jevinjiang branch 2 times, most recently from 8d890ed to 677fc5f Compare April 30, 2025 22:51
@mergify mergify bot added the ci/build label Apr 30, 2025
@bythew3i
Copy link
Contributor Author

bythew3i commented May 1, 2025

@WoosukKwon @alexm-redhat PTAL! Thanks!

@lsy323
Copy link
Collaborator

lsy323 commented May 1, 2025

cc @mgoin

@@ -65,6 +65,22 @@ def get_min_page_size(vllm_config: VllmConfig) -> int:
min_page_size = 1 << (min_page_size - 1).bit_length()
return min_page_size

# TPU only has 32 SREGs (scalar registers), if page_size is too small, we
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, not a review, but this is interesting information, is there anywhere I can find it online?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Akshat, thanks for asking! I can not find any TPU's SREGs number documented anywhere publicly. So I think it is better to not mention this in the comments.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh ok thanks!

Copy link
Member

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In trying to run a quick test I accidentally ran V0. It seems this PR breaks V0 by not specifying the block_size in that flow

  File "/home/mgoin/code/vllm/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
    answer = run_method(self.driver_worker, method, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mgoin/code/vllm/vllm/utils.py", line 2463, in run_method
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/mgoin/code/vllm/vllm/worker/worker_base.py", line 594, in init_worker
    self.worker = worker_class(**kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mgoin/code/vllm/vllm/worker/tpu_worker.py", line 51, in __init__
    self.model_runner: TPUModelRunner = TPUModelRunner(
                                        ^^^^^^^^^^^^^^^
  File "/home/mgoin/code/vllm/vllm/worker/tpu_model_runner.py", line 111, in __init__
    self.max_num_blocks_per_seq = (self.model_config.max_model_len //
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: unsupported operand type(s) for //: 'int' and 'NoneType'
        self.max_num_blocks_per_seq = (self.model_config.max_model_len //
                                       self.block_size)

@mergify mergify bot added the documentation Improvements or additions to documentation label May 2, 2025
@bythew3i bythew3i force-pushed the ragged-jevinjiang branch from 66696a6 to 284c605 Compare May 2, 2025 05:36
@bythew3i
Copy link
Contributor Author

bythew3i commented May 2, 2025

@mgoin PTAL! Thanks!

@bythew3i bythew3i force-pushed the ragged-jevinjiang branch from 284c605 to b6e7d0b Compare May 2, 2025 05:46
@bythew3i bythew3i requested a review from mgoin May 2, 2025 05:50
@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label May 2, 2025
Copy link
Member

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM thank you!

@bythew3i
Copy link
Contributor Author

bythew3i commented May 5, 2025

@mgoin Can you please help merge this PR?

@mgoin
Copy link
Member

mgoin commented May 5, 2025

@bythew3i the TPU V1 sampler test is failing https://buildkite.com/vllm/ci/builds/19212#01969083-3aed-45d6-9738-f4d601113fd5/6-1707
Can you see if you can reproduce locally? It is concerning that the outputs are all "!"

@bythew3i
Copy link
Contributor Author

bythew3i commented May 6, 2025

@bythew3i the TPU V1 sampler test is failing https://buildkite.com/vllm/ci/builds/19212#01969083-3aed-45d6-9738-f4d601113fd5/6-1707 Can you see if you can reproduce locally? It is concerning that the outputs are all "!"

@mgoin Is this the right cmd to test?

VLLM_USE_V1=1 pytest tests/v1/tpu/test_sampler.py

I tested on local... it also failed at main branch... let me pull the latest change to see if the failure still exist

@bythew3i
Copy link
Contributor Author

bythew3i commented May 6, 2025

@bythew3i the TPU V1 sampler test is failing https://buildkite.com/vllm/ci/builds/19212#01969083-3aed-45d6-9738-f4d601113fd5/6-1707 Can you see if you can reproduce locally? It is concerning that the outputs are all "!"

The error seems not related to this PR... It fails at HEAD on main branch @mgoin

@mgoin mgoin merged commit 621ca2c into vllm-project:main May 6, 2025
71 checks passed
robertgshaw2-redhat added a commit to neuralmagic/vllm that referenced this pull request May 6, 2025
* [Model] Add GraniteMoeHybrid 4.0 model (vllm-project#17497)

Signed-off-by: Thomas Ortner <[email protected]>
Signed-off-by: Stanislaw Wozniak <[email protected]>
Co-authored-by: Thomas Ortner <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>

* [easy] Fix logspam on PiecewiseBackend errors (vllm-project#17138)

Signed-off-by: rzou <[email protected]>

* [Bugfix] Fixed prompt length for random dataset (vllm-project#17408)

Signed-off-by: Mikhail Podvitskii <[email protected]>

* [Doc] Update notes for H2O-VL and Gemma3 (vllm-project#17219)

Signed-off-by: DarkLight1337 <[email protected]>

* [Misc] Fix ScalarType float4 naming  (vllm-project#17690)

Signed-off-by: Lucas Wilkinson <[email protected]>

* Fix `dockerfilegraph` pre-commit hook (vllm-project#17698)

Signed-off-by: Harry Mellor <[email protected]>

* [Bugfix] Fix triton import with local TritonPlaceholder (vllm-project#17446)

Signed-off-by: Mengqing Cao <[email protected]>

* [V1] Enable TPU V1 backend by default (vllm-project#17673)

Signed-off-by: mgoin <[email protected]>

* [V1][PP] Support PP for MultiprocExecutor (vllm-project#14219)

Signed-off-by: jiang1.li <[email protected]>
Signed-off-by: jiang.li <[email protected]>

* [v1] AttentionMetadata for each layer (vllm-project#17394)

Signed-off-by: Chen Zhang <[email protected]>

* [Feat] Add deprecated=True to CLI args (vllm-project#17426)

Signed-off-by: Aaron Pham <[email protected]>

* [Docs] Use gh-file to add links to tool_calling.md (vllm-project#17709)

Signed-off-by: windsonsea <[email protected]>

* [v1] Introduce KVCacheBlocks as interface between Scheduler and KVCacheManager (vllm-project#17479)

Signed-off-by: Chen Zhang <[email protected]>

* [doc] Add RAG Integration example (vllm-project#17692)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>

* [Bugfix] Fix modality limits in vision language example (vllm-project#17721)

Signed-off-by: DarkLight1337 <[email protected]>

* Make right sidebar more readable in "Supported Models" (vllm-project#17723)

Signed-off-by: Harry Mellor <[email protected]>

* [TPU] Increase block size and reset block shapes (vllm-project#16458)

* [Misc] Add Next Edit Prediction (NEP) datasets support in `benchmark_serving.py` (vllm-project#16839)

Signed-off-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
Signed-off-by: dtransposed <>
Co-authored-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>

* [Bugfix] Fix for the condition to accept empty encoder inputs for mllama (vllm-project#17732)

Signed-off-by: Gregory Shtrasberg <[email protected]>

* [Kernel] Unified Triton kernel that doesn't distinguish between prefill + decode (vllm-project#16828)

Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>

---------

Signed-off-by: Thomas Ortner <[email protected]>
Signed-off-by: Stanislaw Wozniak <[email protected]>
Signed-off-by: rzou <[email protected]>
Signed-off-by: Mikhail Podvitskii <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Mengqing Cao <[email protected]>
Signed-off-by: mgoin <[email protected]>
Signed-off-by: jiang1.li <[email protected]>
Signed-off-by: jiang.li <[email protected]>
Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: Aaron Pham <[email protected]>
Signed-off-by: windsonsea <[email protected]>
Signed-off-by: reidliu41 <[email protected]>
Signed-off-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
Signed-off-by: dtransposed <>
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: [email protected] <[email protected]>
Co-authored-by: Stan Wozniak <[email protected]>
Co-authored-by: Thomas Ortner <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Richard Zou <[email protected]>
Co-authored-by: Mikhail Podvitskii <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
Co-authored-by: Mengqing Cao <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Li, Jiang <[email protected]>
Co-authored-by: Chen Zhang <[email protected]>
Co-authored-by: Aaron Pham <[email protected]>
Co-authored-by: Michael Yao <[email protected]>
Co-authored-by: Reid <[email protected]>
Co-authored-by: reidliu41 <[email protected]>
Co-authored-by: Jevin Jiang <[email protected]>
Co-authored-by: d.transposed <[email protected]>
Co-authored-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
Co-authored-by: Gregory Shtrasberg <[email protected]>
Co-authored-by: Thomas Parnell <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ci/build documentation Improvements or additions to documentation ready ONLY add when PR is ready to merge/full CI is needed tpu Related to Google TPUs v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants