[Kernel] Unified Triton kernel that doesn't distinguish between prefill + decode #16828


Merged: 15 commits into vllm-project:main from tpa-unified2 on May 6, 2025

Conversation

@tdoublep tdoublep (Member) commented Apr 18, 2025

In this PR we add:

  • A new Triton kernel (triton_unified_attention) that works like flash_attn_varlen_func and can handle arbitrary query lengths. The kernel does GQA "packing" along the query dimension to ensure the Tensor Cores are well utilized (see the sketch after this list).
  • A new unit test based on the unit tests for flash_attn_varlen_func.
  • An updated V1 Triton attention backend that uses this kernel. Note that the memory layout for the key cache also changes to match what the FlashAttention backend does.
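
To make the GQA "packing" idea concrete, here is a minimal plain-PyTorch sketch (not the Triton kernel itself; the shapes and names below are illustrative assumptions). Queries from the query heads that share a KV head are stacked along the row dimension, so each per-KV-head Q·K^T matmul has more rows and keeps the Tensor Cores busy even for single-token decode steps.

# Illustrative sketch only (plain PyTorch, not triton_unified_attention).
import torch

num_tokens, num_query_heads, num_kv_heads, head_dim = 4, 32, 8, 128
group_size = num_query_heads // num_kv_heads   # query heads per KV head
seq_len = 1024                                 # cached context length

q = torch.randn(num_tokens, num_query_heads, head_dim)
k = torch.randn(seq_len, num_kv_heads, head_dim)

# Pack the query heads that share a KV head along the row dimension:
# each KV head now sees a (num_tokens * group_size) x head_dim query block.
q_packed = (
    q.view(num_tokens, num_kv_heads, group_size, head_dim)
     .permute(1, 0, 2, 3)                      # (kv_heads, tokens, group, d)
     .reshape(num_kv_heads, num_tokens * group_size, head_dim)
)

k_t = k.permute(1, 0, 2)                       # (kv_heads, seq_len, d)
scores = torch.einsum("gmd,gsd->gms", q_packed, k_t)
print(scores.shape)                            # (8, 16, 1024)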

Best performance is obtained when using the jit cache decorator from triton_dejavu. In this code I'm using the decorator from the triton_dejavu package, but if #16606 is merged we could use it directly from vLLM.

Note that the unit tests don't currently pass when the jit cache is enabled, but they all pass when it is disabled. This is because we are testing different combinations of numbers of heads etc., which the decorator assumes to be constant (see the conceptual sketch below). We probably need to think of a good testing strategy for kernels with this decorator (cc @bringlein).
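
For context on why that happens, here is a conceptual Python sketch of a JIT cache keyed on arguments assumed constant. This is not the triton_dejavu API; the decorator below is hypothetical and only illustrates how a parametrized test that sweeps num_heads trips the constant assumption.

# Hypothetical illustration only; triton_dejavu's real decorator differs.
import functools

def jit_cache(assume_const):
    """Cache a compiled kernel, assuming the named kwargs never change."""
    def wrap(fn):
        pinned = {}
        @functools.wraps(fn)
        def inner(*args, **kwargs):
            key = tuple(kwargs[name] for name in assume_const)
            if not pinned:
                pinned["key"] = key      # first call pins the "constants"
            elif pinned["key"] != key:
                # A parametrized unit test sweeping num_heads ends up here.
                raise ValueError(f"assumed-constant args changed: {pinned['key']} -> {key}")
            return fn(*args, **kwargs)
        return inner
    return wrap

@jit_cache(assume_const=["num_heads"])
def fake_kernel(x, *, num_heads):
    return x * num_heads

fake_kernel(1.0, num_heads=32)    # ok: pins num_heads=32
fake_kernel(2.0, num_heads=32)    # ok: same constant
# fake_kernel(1.0, num_heads=8)   # would raise: constant assumption violated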

Initial benchmarking

Here are some initial benchmarking results on H100 for Llama-3.1-8B using:

python benchmark_serving.py \
    --model meta-llama/Llama-3.1-8B-Instruct  \
    --dataset-name sharegpt \
    --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json

[image: ShareGPT serving benchmark results for Llama-3.1-8B on H100]

Note that with these changes, the Triton backend significantly outperforms the FlashAttention backend on an H100 GPU for this workload.

Correctness

$ VLLM_USE_V1=1 lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.800|±  |0.0179|
|     |       |strict-match    |     5|exact_match|↑  |0.784|±  |0.0184|

@bringlein the correctness check only looks good if I use the tpa-grid-copy branch from triton-dejavu; if I use main, it fails. You should be able to reproduce this on H100.

Further benchmarking

I've run the following scenario across both H100 and MI300x using different backends from main, as well as using the changes from this PR:

MODEL=mistralai/Mistral-Small-24B-Instruct-2501
REQUEST_RATES=(1 5 7 9)
TOTAL_SECONDS=120

for REQUEST_RATE in "${REQUEST_RATES[@]}";
do
    NUM_PROMPTS=$(($TOTAL_SECONDS * $REQUEST_RATE))

    echo ""
    echo "===== RUNNING $MODEL FOR $NUM_PROMPTS PROMPTS WITH $REQUEST_RATE QPS ====="
    echo ""

    python3 benchmark_serving.py \
        --model $MODEL \
        --dataset-name random \
        --request-rate $REQUEST_RATE \
        --random-input-len 1000 \
        --random-output-len 100 \
        --num-prompts $NUM_PROMPTS \
        --ignore-eos --seed $REQUEST_RATE
done

Here are the results:
[image: serving benchmark results across H100 and MI300x at varying request rates]

Main takeaways:

  • The TritonBackend from this PR outperforms the FlashAttentionBackend on H100.
  • The changes from this PR significantly improve performance on MI300x at high QPS.
  • For some reason, using torch.compile seems to make the results worse on MI300x (something that @SageMoore had mentioned to me previously). Do we have an issue to track that somewhere?

cc @robertgshaw2-redhat @tlrmchlsmth @SageMoore @bringlein @jvlunteren @LucasWilkinson



tdoublep added 10 commits April 14, 2025 09:42 (Signed-off-by: Thomas Parnell <[email protected]>)

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Apr 18, 2025
@tdoublep tdoublep (Member, Author) commented

Re: the weirdness with torch.compile on MI300x, I followed the suggestion of @robertgshaw2-redhat and re-ran everything inside the latest rocm/vllm-dev:nightly image, which uses a newer version of torch (2.7.0a0+git295f2ed). I now see better results without needing to use --enforce-eager on MI300x. I also ran the baseline V0 performance on MI300x using the ROCmFlashAttentionBackend.

[image: MI300x benchmark results with the rocm/vllm-dev:nightly image, including the V0 ROCmFlashAttentionBackend baseline]

@LucasWilkinson LucasWilkinson (Collaborator) left a comment

Nice work! Overall looks pretty good! Left a few comments

S = apply_softcap(S, softcap)

S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask,
             S, float("-inf"))
Collaborator


Can we add support for non-causal attention too? Could be a future PR, but it's useful for cascade attention and MLA.

Member Author


Sure, I don't think that would be too hard.
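
For reference, here is a rough plain-PyTorch sketch of the masking difference being discussed (this is not the Triton code; positions and masks are illustrative, and apply_softcap is assumed here to be the usual tanh-based logit soft-capping). A causal path adds a position condition on top of the validity masks, while a non-causal path would keep only the validity masks.

# Rough sketch only; assumptions noted above.
import torch

def apply_softcap(scores, softcap):
    # Assumed standard logit soft-capping; the kernel's helper may differ.
    return softcap * torch.tanh(scores / softcap)

num_queries, num_keys, context_len = 3, 8, 5
scores = apply_softcap(torch.randn(num_queries, num_keys), softcap=30.0)

q_pos = torch.arange(num_queries)[:, None] + context_len     # absolute query positions
k_pos = torch.arange(num_keys)[None, :]
valid = torch.ones(num_queries, num_keys, dtype=torch.bool)  # stand-in for query/seq masks

causal = valid & (k_pos <= q_pos)   # causal condition (presumably folded into seq_mask in the kernel)
non_causal = valid                  # a non-causal path would keep only the validity masks

causal_scores = scores.masked_fill(~causal, float("-inf"))
non_causal_scores = scores.masked_fill(~non_causal, float("-inf"))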

Signed-off-by: Lucas Wilkinson <[email protected]>

review comments

Signed-off-by: Lucas Wilkinson <[email protected]>

review comments + make unit tests pass

Signed-off-by: Lucas Wilkinson <[email protected]>

fix assert

Signed-off-by: Lucas Wilkinson <[email protected]>
…d we have questions around cache keys

Signed-off-by: Lucas Wilkinson <[email protected]>
@LucasWilkinson LucasWilkinson enabled auto-merge (squash) May 1, 2025 18:34
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label May 1, 2025
auto-merge was automatically disabled May 6, 2025 15:52

Head branch was pushed to by a user without write access

Signed-off-by: Thomas Parnell <[email protected]>
@LucasWilkinson LucasWilkinson merged commit 2f925e5 into vllm-project:main May 6, 2025
51 checks passed
@tdoublep tdoublep deleted the tpa-unified2 branch May 6, 2025 22:25
robertgshaw2-redhat added a commit to neuralmagic/vllm that referenced this pull request May 6, 2025
* [Model] Add GraniteMoeHybrid 4.0 model (vllm-project#17497)

Signed-off-by: Thomas Ortner <[email protected]>
Signed-off-by: Stanislaw Wozniak <[email protected]>
Co-authored-by: Thomas Ortner <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>

* [easy] Fix logspam on PiecewiseBackend errors (vllm-project#17138)

Signed-off-by: rzou <[email protected]>

* [Bugfix] Fixed prompt length for random dataset (vllm-project#17408)

Signed-off-by: Mikhail Podvitskii <[email protected]>

* [Doc] Update notes for H2O-VL and Gemma3 (vllm-project#17219)

Signed-off-by: DarkLight1337 <[email protected]>

* [Misc] Fix ScalarType float4 naming  (vllm-project#17690)

Signed-off-by: Lucas Wilkinson <[email protected]>

* Fix `dockerfilegraph` pre-commit hook (vllm-project#17698)

Signed-off-by: Harry Mellor <[email protected]>

* [Bugfix] Fix triton import with local TritonPlaceholder (vllm-project#17446)

Signed-off-by: Mengqing Cao <[email protected]>

* [V1] Enable TPU V1 backend by default (vllm-project#17673)

Signed-off-by: mgoin <[email protected]>

* [V1][PP] Support PP for MultiprocExecutor (vllm-project#14219)

Signed-off-by: jiang1.li <[email protected]>
Signed-off-by: jiang.li <[email protected]>

* [v1] AttentionMetadata for each layer (vllm-project#17394)

Signed-off-by: Chen Zhang <[email protected]>

* [Feat] Add deprecated=True to CLI args (vllm-project#17426)

Signed-off-by: Aaron Pham <[email protected]>

* [Docs] Use gh-file to add links to tool_calling.md (vllm-project#17709)

Signed-off-by: windsonsea <[email protected]>

* [v1] Introduce KVCacheBlocks as interface between Scheduler and KVCacheManager (vllm-project#17479)

Signed-off-by: Chen Zhang <[email protected]>

* [doc] Add RAG Integration example (vllm-project#17692)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>

* [Bugfix] Fix modality limits in vision language example (vllm-project#17721)

Signed-off-by: DarkLight1337 <[email protected]>

* Make right sidebar more readable in "Supported Models" (vllm-project#17723)

Signed-off-by: Harry Mellor <[email protected]>

* [TPU] Increase block size and reset block shapes (vllm-project#16458)

* [Misc] Add Next Edit Prediction (NEP) datasets support in `benchmark_serving.py` (vllm-project#16839)

Signed-off-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
Signed-off-by: dtransposed <>
Co-authored-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>

* [Bugfix] Fix for the condition to accept empty encoder inputs for mllama (vllm-project#17732)

Signed-off-by: Gregory Shtrasberg <[email protected]>

* [Kernel] Unified Triton kernel that doesn't distinguish between prefill + decode (vllm-project#16828)

Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>

---------

Signed-off-by: Thomas Ortner <[email protected]>
Signed-off-by: Stanislaw Wozniak <[email protected]>
Signed-off-by: rzou <[email protected]>
Signed-off-by: Mikhail Podvitskii <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Mengqing Cao <[email protected]>
Signed-off-by: mgoin <[email protected]>
Signed-off-by: jiang1.li <[email protected]>
Signed-off-by: jiang.li <[email protected]>
Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: Aaron Pham <[email protected]>
Signed-off-by: windsonsea <[email protected]>
Signed-off-by: reidliu41 <[email protected]>
Signed-off-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
Signed-off-by: dtransposed <>
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: [email protected] <[email protected]>
Co-authored-by: Stan Wozniak <[email protected]>
Co-authored-by: Thomas Ortner <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Richard Zou <[email protected]>
Co-authored-by: Mikhail Podvitskii <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
Co-authored-by: Mengqing Cao <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Li, Jiang <[email protected]>
Co-authored-by: Chen Zhang <[email protected]>
Co-authored-by: Aaron Pham <[email protected]>
Co-authored-by: Michael Yao <[email protected]>
Co-authored-by: Reid <[email protected]>
Co-authored-by: reidliu41 <[email protected]>
Co-authored-by: Jevin Jiang <[email protected]>
Co-authored-by: d.transposed <[email protected]>
Co-authored-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
Co-authored-by: Gregory Shtrasberg <[email protected]>
Co-authored-by: Thomas Parnell <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
…ll + decode (vllm-project#16828)

Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
@tdoublep tdoublep restored the tpa-unified2 branch May 13, 2025 18:20
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request May 14, 2025
…ll + decode (vllm-project#16828)

Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Labels: ready, v1