Skip to content

[v1] Introduce KVCacheBlocks as interface between Scheduler and KVCacheManager #17479

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 6, 2025

Conversation

heheda12345
Copy link
Collaborator

@heheda12345 heheda12345 commented Apr 30, 2025

In the future hybrid allocator, the KVCacheManager output would be list[list[list[KVCacheBlocks]], which is much more complex than the current list[KVCacheBlocks].

To hide the complexity, this pr introduces KVCacheBlocks to save the KVCacheManager output so that scheduler do not need to parse the internal structure of KVCacheManager output.

Splitted from #16101

Signed-off-by: Chen Zhang <[email protected]>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Apr 30, 2025
Copy link

mergify bot commented Apr 30, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @heheda12345.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 30, 2025
@mergify mergify bot removed the needs-rebase label May 1, 2025
Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Left minor comments.


def allocate_slots(
self,
request: Request,
num_tokens: int,
new_computed_blocks: Optional[list[KVCacheBlock]] = None,
new_computed_blocks_obj: Optional[KVCacheBlocks] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: Why rename this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Becase it is with type KVCacheBlocks and we need to unwrap it to list[KVCacheBlock] before using. Prefer to call the object with type list[KVCacheBlock] "new_computed_blocks" so I rename this object to "new_computed_blocks_obj"

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But when would we ever use list[KVCacheBlock] directly? The _obj suffix feels a bit awkward to me.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

list[KVCacheBlock] is used inside allocate_slots

Copy link
Collaborator

@WoosukKwon WoosukKwon May 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought we kinda decided to blur the distinction when introducing the KVCacheBlocks class. I don't understand why this particular variable name is a concern for you.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise I need to change new_computed_blocks in this function to another name.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I personally prefer that, because I really feel new_computed_blocks_obj awkward. But I'm fine if you want to stick to this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. I can change it later.

Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: Chen Zhang <[email protected]>
@heheda12345
Copy link
Collaborator Author

@WoosukKwon I've updated this PR. Can you take another look?

@WoosukKwon WoosukKwon added the ready ONLY add when PR is ready to merge/full CI is needed label May 6, 2025
Signed-off-by: Chen Zhang <[email protected]>
@@ -109,7 +128,7 @@ def get_computed_blocks(
"""
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update the doc string as well about the return type?

@WoosukKwon WoosukKwon merged commit aabcd2c into vllm-project:main May 6, 2025
49 checks passed
robertgshaw2-redhat added a commit to neuralmagic/vllm that referenced this pull request May 6, 2025
* [Model] Add GraniteMoeHybrid 4.0 model (vllm-project#17497)

Signed-off-by: Thomas Ortner <[email protected]>
Signed-off-by: Stanislaw Wozniak <[email protected]>
Co-authored-by: Thomas Ortner <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>

* [easy] Fix logspam on PiecewiseBackend errors (vllm-project#17138)

Signed-off-by: rzou <[email protected]>

* [Bugfix] Fixed prompt length for random dataset (vllm-project#17408)

Signed-off-by: Mikhail Podvitskii <[email protected]>

* [Doc] Update notes for H2O-VL and Gemma3 (vllm-project#17219)

Signed-off-by: DarkLight1337 <[email protected]>

* [Misc] Fix ScalarType float4 naming  (vllm-project#17690)

Signed-off-by: Lucas Wilkinson <[email protected]>

* Fix `dockerfilegraph` pre-commit hook (vllm-project#17698)

Signed-off-by: Harry Mellor <[email protected]>

* [Bugfix] Fix triton import with local TritonPlaceholder (vllm-project#17446)

Signed-off-by: Mengqing Cao <[email protected]>

* [V1] Enable TPU V1 backend by default (vllm-project#17673)

Signed-off-by: mgoin <[email protected]>

* [V1][PP] Support PP for MultiprocExecutor (vllm-project#14219)

Signed-off-by: jiang1.li <[email protected]>
Signed-off-by: jiang.li <[email protected]>

* [v1] AttentionMetadata for each layer (vllm-project#17394)

Signed-off-by: Chen Zhang <[email protected]>

* [Feat] Add deprecated=True to CLI args (vllm-project#17426)

Signed-off-by: Aaron Pham <[email protected]>

* [Docs] Use gh-file to add links to tool_calling.md (vllm-project#17709)

Signed-off-by: windsonsea <[email protected]>

* [v1] Introduce KVCacheBlocks as interface between Scheduler and KVCacheManager (vllm-project#17479)

Signed-off-by: Chen Zhang <[email protected]>

* [doc] Add RAG Integration example (vllm-project#17692)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>

* [Bugfix] Fix modality limits in vision language example (vllm-project#17721)

Signed-off-by: DarkLight1337 <[email protected]>

* Make right sidebar more readable in "Supported Models" (vllm-project#17723)

Signed-off-by: Harry Mellor <[email protected]>

* [TPU] Increase block size and reset block shapes (vllm-project#16458)

* [Misc] Add Next Edit Prediction (NEP) datasets support in `benchmark_serving.py` (vllm-project#16839)

Signed-off-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
Signed-off-by: dtransposed <>
Co-authored-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>

* [Bugfix] Fix for the condition to accept empty encoder inputs for mllama (vllm-project#17732)

Signed-off-by: Gregory Shtrasberg <[email protected]>

* [Kernel] Unified Triton kernel that doesn't distinguish between prefill + decode (vllm-project#16828)

Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>

---------

Signed-off-by: Thomas Ortner <[email protected]>
Signed-off-by: Stanislaw Wozniak <[email protected]>
Signed-off-by: rzou <[email protected]>
Signed-off-by: Mikhail Podvitskii <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Mengqing Cao <[email protected]>
Signed-off-by: mgoin <[email protected]>
Signed-off-by: jiang1.li <[email protected]>
Signed-off-by: jiang.li <[email protected]>
Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: Aaron Pham <[email protected]>
Signed-off-by: windsonsea <[email protected]>
Signed-off-by: reidliu41 <[email protected]>
Signed-off-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
Signed-off-by: dtransposed <>
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: [email protected] <[email protected]>
Co-authored-by: Stan Wozniak <[email protected]>
Co-authored-by: Thomas Ortner <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Richard Zou <[email protected]>
Co-authored-by: Mikhail Podvitskii <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
Co-authored-by: Mengqing Cao <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Li, Jiang <[email protected]>
Co-authored-by: Chen Zhang <[email protected]>
Co-authored-by: Aaron Pham <[email protected]>
Co-authored-by: Michael Yao <[email protected]>
Co-authored-by: Reid <[email protected]>
Co-authored-by: reidliu41 <[email protected]>
Co-authored-by: Jevin Jiang <[email protected]>
Co-authored-by: d.transposed <[email protected]>
Co-authored-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
Co-authored-by: Gregory Shtrasberg <[email protected]>
Co-authored-by: Thomas Parnell <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request May 14, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants