[Kernel] Unified Triton kernel that doesn't distinguish between prefill + decode #16828
Conversation
Signed-off-by: Thomas Parnell <[email protected]>
Re: the weirdness with torch.compile on MI300x, I followed the suggestion of @robertgshaw2-redhat and re-ran everything inside the latest container.
Nice work! Overall looks pretty good! Left a few comments
```python
S = apply_softcap(S, softcap)

S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask,
             S, float("-inf"))
```
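For context, logit softcapping smoothly bounds the raw attention scores before the softmax. A minimal reference sketch, assuming the common tanh-based definition (the `apply_softcap` below is illustrative, not the kernel's actual helper):

```python
import torch

def apply_softcap(S: torch.Tensor, softcap: float) -> torch.Tensor:
    # Squash the raw attention scores smoothly into (-softcap, softcap),
    # as popularized by models such as Gemma 2.
    return softcap * torch.tanh(S / softcap)
```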
Can we add support for non-causal attention too? Could be a future PR, but it's useful for cascade attention and MLA.
Sure, I don't think that would be too hard.
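To illustrate why it shouldn't be hard (a hedged sketch, not code from this PR; `score_mask` is a hypothetical helper): the non-causal case keeps the sequence-length bound and simply drops the causal comparison.

```python
import torch

def score_mask(q_pos: torch.Tensor, k_pos: torch.Tensor, seq_len: int,
               causal: bool = True) -> torch.Tensor:
    # Keys must lie inside the sequence; the causal case additionally
    # requires key position <= query position.
    valid = k_pos[None, :] < seq_len
    if causal:
        valid = valid & (k_pos[None, :] <= q_pos[:, None])
    return valid
```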
- review comments (Signed-off-by: Lucas Wilkinson <[email protected]>)
- review comments + make unit tests pass (Signed-off-by: Lucas Wilkinson <[email protected]>)
- fix assert (Signed-off-by: Lucas Wilkinson <[email protected]>)
Force-pushed from 19c5657 to 13c1c87
…d we have questions around cache keys Signed-off-by: Lucas Wilkinson <[email protected]>
Head branch was pushed to by a user without write access
* [Model] Add GraniteMoeHybrid 4.0 model (vllm-project#17497)
* [easy] Fix logspam on PiecewiseBackend errors (vllm-project#17138)
* [Bugfix] Fixed prompt length for random dataset (vllm-project#17408)
* [Doc] Update notes for H2O-VL and Gemma3 (vllm-project#17219)
* [Misc] Fix ScalarType float4 naming (vllm-project#17690)
* Fix `dockerfilegraph` pre-commit hook (vllm-project#17698)
* [Bugfix] Fix triton import with local TritonPlaceholder (vllm-project#17446)
* [V1] Enable TPU V1 backend by default (vllm-project#17673)
* [V1][PP] Support PP for MultiprocExecutor (vllm-project#14219)
* [v1] AttentionMetadata for each layer (vllm-project#17394)
* [Feat] Add deprecated=True to CLI args (vllm-project#17426)
* [Docs] Use gh-file to add links to tool_calling.md (vllm-project#17709)
* [v1] Introduce KVCacheBlocks as interface between Scheduler and KVCacheManager (vllm-project#17479)
* [doc] Add RAG Integration example (vllm-project#17692)
* [Bugfix] Fix modality limits in vision language example (vllm-project#17721)
* Make right sidebar more readable in "Supported Models" (vllm-project#17723)
* [TPU] Increase block size and reset block shapes (vllm-project#16458)
* [Misc] Add Next Edit Prediction (NEP) datasets support in `benchmark_serving.py` (vllm-project#16839)
* [Bugfix] Fix for the condition to accept empty encoder inputs for mllama (vllm-project#17732)
* [Kernel] Unified Triton kernel that doesn't distinguish between prefill + decode (vllm-project#16828)
…ll + decode (vllm-project#16828) Signed-off-by: Thomas Parnell <[email protected]> Signed-off-by: Lucas Wilkinson <[email protected]> Co-authored-by: Lucas Wilkinson <[email protected]>
In this PR we add a new unified attention kernel, `triton_unified_attention`, that works like `flash_attn_varlen_func` and can handle arbitrary query length. The kernel does GQA "packing" along the query dimension to ensure the Tensor Cores are well used.
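As a rough illustration of the packing idea (shapes and layout assumed; this is not the kernel's code): the query heads that share a KV head are folded into the query-token dimension, so each `tl.dot` sees a taller tile and keeps the tensor cores busy even when the query length is 1 (decode).

```python
import torch

q_len, num_q_heads, num_kv_heads, head_dim = 5, 32, 8, 128
group = num_q_heads // num_kv_heads  # query heads per KV head

q = torch.randn(q_len, num_q_heads, head_dim)
# Fold the per-KV-head query group into the token dimension:
# [q_len, num_q_heads, d] -> [q_len * group, num_kv_heads, d]
packed = (q.view(q_len, num_kv_heads, group, head_dim)
           .transpose(1, 2)
           .reshape(q_len * group, num_kv_heads, head_dim))
print(packed.shape)  # torch.Size([20, 8, 128])
```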
Best performance is obtained when using the JIT cache decorator from the `triton_dejavu` package; if #16606 is merged we could use it directly from vLLM.

Note that the unit tests don't currently pass when I enable the JIT cache, but they all pass if it is disabled. This is because we test different combinations of numbers of heads etc., which the decorator assumes to be constant. We probably need to think of a good testing strategy for kernels with this decorator (cc @bringlein).
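A hypothetical sketch of the JIT-cache idea (the `jitcache` import path, decorator name, and `check_keys` parameter are assumptions; check the triton_dejavu docs for the real API): the decorator caches the compiled kernel under a small explicit key, so repeated launches skip Triton's per-call argument processing as long as the listed arguments stay constant.

```python
import triton
import triton.language as tl
from triton_dejavu import jitcache  # assumed import path

@jitcache(check_keys=["num_heads"])  # assumed signature: args the cache key depends on
@triton.jit
def scale_kernel(x_ptr, num_heads, BLOCK: tl.constexpr):
    # Toy kernel standing in for kernel_unified_attention:
    # scale one block of values by num_heads.
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    tl.store(x_ptr + offs, x * num_heads)
```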
Initial benchmarking

Here are some initial benchmarking results on H100 for `llama3.1-8b`. Note that with these changes, the Triton backend significantly outperforms the FlashAttention backend on an H100 GPU for this workload.
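For anyone wanting to try it, a minimal way to force the Triton backend (the backend name string is an assumption; check `vllm/envs.py` for the exact value on your version):

```python
import os
# Must be set before vLLM is imported; backend name assumed.
os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")
out = llm.generate(["Hello, world"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)
```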
Correctness

@bringlein the correctness check only looks good if I use branch `tpa-grid-copy` from triton-dejavu; if I use main, it fails. You should be able to reproduce on H100.

Further benchmarking
I've run the following scenario across both H100 and MI300x using different backends from main, as well as using the changes from this PR:
Here are the results:

Main takeaways:
cc @robertgshaw2-redhat @tlrmchlsmth @SageMoore @bringlein @jvlunteren @LucasWilkinson