[Model] Refactor Mamba2 SSD to improve chunked prefill performance #16942
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀
Performance results
TL;DR: this PR gives a 1.53x throughput improvement when chunked prefill is ON.
Main (d9ac9e3) with chunked prefill ON
PR with chunked prefill ON
PR with chunked prefill OFF
Output quality
Bamba-9B
Main (d9ac9e3)
This PR
It makes the results slightly better..?
Zamba2-2.7B
Main
This PR
Mamba-Codestral-7B
Main
This PR
Unit test results
Passed after a fix: 182d4ad
Through some small experiments, I am aware that when chunked prefill is ON with Mamba2 models, the same input repeated across a batch can lead to varying generations under greedy decoding. However, when chunked prefill is OFF, the generations are consistent. Does this PR (or some other effort) plan to address this? Thanks!
@prannaykaul No, this PR does not attempt to fix what you described. However, the rerouting of the prefill and decode requests in the mamba2 layer may have an effect on that.
Using the above script, which should be a self-contained qualitative eval of this behaviour, and installing each branch [a928424, pr_mamba2_chunk_prefill_refactor, pr_mamba2_conv1d_refactor] (the first one containing none of your edits), I find the generations to be inconsistent in all 3 branches when chunked_prefill is enabled:
e.g. on pr_mamba2_conv1d_refactor:
whereas when chunked_prefill is disabled, the generations are consistent:
e.g. on pr_mamba2_conv1d_refactor:
Other models such as Bamba also demonstrate the same behavior but tend to require longer greedy generations to see the difference. In pure Mamba2 models (like the Codestral model used), the difference with chunked_prefill enabled tends to be immediate.
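For reference, a greedy-consistency check along these lines can be sketched as follows (this is not the original script; the model name, prompt, and batch size are illustrative assumptions):

# Minimal sketch of a greedy-consistency check under chunked prefill
# (illustrative; not the original eval script).
from vllm import LLM, SamplingParams

# Model name and settings are assumptions for the sketch; flip
# enable_chunked_prefill to False to compare the two modes.
llm = LLM(model="mistralai/Mamba-Codestral-7B-v0.1",
          enable_chunked_prefill=True)

prompt = "Write a short story about a robot learning to paint."
# Greedy decoding over the same prompt repeated across a batch should yield
# identical generations if the forward pass is batch-invariant.
params = SamplingParams(temperature=0.0, max_tokens=128)
outputs = llm.generate([prompt] * 8, params)

texts = [o.outputs[0].text for o in outputs]
print("consistent generations:", len(set(texts)) == 1)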
Thanks for the details @prannaykaul
Left a couple of small comments but LGTM
num_prefills = attn_metadata.num_prefills  # #requests
num_decodes = attn_metadata.num_decode_tokens  # #tokens==#requests
num_prefill_tokens = attn_metadata.num_prefill_tokens  # #tokens
Could you explain what is meant by the comments at the ends of these lines?
The comments indicate whether the corresponding variable counts "number of requests" or "number of tokens".
do you want me to change the comments?
Yeah, it would help clarity. Thanks!
Resolved
n_groups = self.n_groups // self.tp_size
A = self.A[:, None, ...][:, :, None].expand(
A_d = self.A[:, None, ...][:, :, None].expand(
What does the suffix _d mean in this code?
oh is it decode?
yes, _p means prefill and _d means decode
Signed-off-by: Chih-Chieh-Yang <[email protected]>
This pull request has merge conflicts that must be resolved before it can be merged.
@tlrmchlsmth I think you meant #17146. Closing.
We found that when chunked prefill is enabled, performance is poor for benchmark_serving.py with ShareGPTv3. After some analysis, we identified that chunk_scan_fwd_kernel latency increases linearly with the number of "chunks": while the kernel can process prefill chunks efficiently, each decode request in the mixed batch gives it one full chunk of work to process, even though a decode request contains only a single token.
In this PR, we modify the Mamba2 SSD control flow assuming vLLM v0, where the mixed input batch has prefill chunks that come before decode requests. When processing the input, we split the input tensors at the prefill-decode boundary and invoke the SSD processing functions on each part separately. This way the prefill kernels don't deal with decode requests and can run more efficiently; a rough sketch of the idea is included below.
For V1, Mamba2 SSD will likely require reordering of the batch for this logic to work, and will need some rewriting.
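As a rough illustration of the control-flow change, here is a minimal sketch of the prefill/decode split (the function and variable names are illustrative stand-ins, not the exact vLLM code):

# Minimal sketch of splitting a mixed batch at the prefill/decode boundary
# (illustrative stand-ins, not the exact vLLM implementation).
import torch

def prefill_path(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the chunked-scan path (chunk_scan_fwd) over prefill tokens.
    return x

def decode_path(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the single-token state-update path over decode requests.
    return x

def mamba2_ssd_forward(hidden_states: torch.Tensor,
                       num_prefill_tokens: int,
                       num_decode_tokens: int) -> torch.Tensor:
    # In vLLM v0 the flattened batch lays out all prefill tokens before the
    # single-token decode requests, so one split index separates the two
    # paths and decode requests no longer cost the chunked-scan kernel a
    # full chunk each.
    hs_p, hs_d = torch.split(
        hidden_states, [num_prefill_tokens, num_decode_tokens], dim=0)
    return torch.cat([prefill_path(hs_p), decode_path(hs_d)], dim=0)

# Example: prefill requests totalling 10 tokens plus 3 decode requests.
x = torch.randn(10 + 3, 4096)
y = mamba2_ssd_forward(x, num_prefill_tokens=10, num_decode_tokens=3)
assert y.shape == x.shape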
Known issue: