
[v1] Implement HybridKVCacheManager to support hybrid models with different KV cache type #16101


Draft
wants to merge 39 commits into main
Conversation

heheda12345 (Collaborator) commented Apr 5, 2025

This is the reference implementation of the hybrid allocator. I'm splitting it into smaller PRs and will do further cleanup in those smaller PRs.
Key differences from #13296 and #16178:
1. Only create one specialized manager for each type of attention. For instance, Gemma 3 uses 2 managers (one full-attention manager and one SWA manager) instead of 6 (one full-attention manager and five SWA managers). (Same as #16178, but with a different implementation.)
2. Hashing: compute the hash per block_size, instead of per kv cache group as in #13296 or only for the full attention layers as in #16178.
3. A general hybrid allocator, instead of the specialized one in #16178 or the two allocators (one for hybrid models and one for non-hybrid models) in #13296. A fast path for non-hybrid models can be added when necessary.
4. Introduce GroupedKVCacheBlock

# KVCacheBlocks for the same block of all kv cache groups that share the same
# kv cache spec (and belong to the same manager)
@dataclass
class GroupedKVCacheBlock:
    blocks: tuple[KVCacheBlock, ...]
to save the same block of all kv cache groups with the same kv cache spec (and belonging to the same manager). For example, a GroupedKVCacheBlock for Gemma 3 may contain 5 blocks for tokens [0-16] of the 5 SWA kv cache groups.
5. In block_pool, perform caching and eviction at the granularity of GroupedKVCacheBlock:
self.cached_block_hash_to_block: list[dict[BlockHashType, dict[
    int, GroupedKVCacheBlock]]] = [
        defaultdict(dict) for _ in range(num_specialized_managers)
    ]

so that we do not need to iterate over all groups to check whether every group has a cached block for a specific hash, as in this check from #16178:
if (cached_blocks and all(group_id in cached_blocks ...
6. Change the allocation result of KVCacheManager to
class KVCacheBlocks:
    blocks: list[list[GroupedKVCacheBlock]]
where blocks[i][j] is the GroupedKVCacheBlock of manager i for tokens [j * block_size, (j+1) * block_size]. With this data structure, each manager can work almost independently and does not need to iterate over all the groups it manages to update the allocation result (a short indexing sketch follows this list).
7. (Not included in this PR, but planned) Introduce a memory coordinator that serves as a middle layer between KVCacheManager and the SpecializedManagers, to simplify the logic of KVCacheManager.
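
To make the data structures above concrete, here is a minimal, self-contained sketch. It is not the PR's actual code: KVCacheBlock is reduced to a bare block_id and block_ids_for_manager is an illustrative helper. It shows how one specialized manager can read its own slice of the allocation result without touching the groups owned by other managers.

from dataclasses import dataclass


@dataclass
class KVCacheBlock:
    # Toy stand-in for vLLM's KVCacheBlock; only the block id is kept here.
    block_id: int


@dataclass
class GroupedKVCacheBlock:
    # The same logical block for every kv cache group handled by one manager.
    blocks: tuple[KVCacheBlock, ...]


@dataclass
class KVCacheBlocks:
    # blocks[manager_id][ith_block] -> GroupedKVCacheBlock
    blocks: list[list[GroupedKVCacheBlock]]


def block_ids_for_manager(result: KVCacheBlocks,
                          manager_id: int) -> list[tuple[int, ...]]:
    # Collect the physical block ids owned by a single manager, one tuple per
    # block index, without iterating over other managers' groups.
    return [
        tuple(block.block_id for block in grouped.blocks)
        for grouped in result.blocks[manager_id]
    ]


# Toy Gemma-3-like layout: manager 0 holds the single full-attention group,
# manager 1 holds the five SWA groups; one block allocated for each.
result = KVCacheBlocks(blocks=[
    [GroupedKVCacheBlock(blocks=(KVCacheBlock(0),))],
    [GroupedKVCacheBlock(blocks=tuple(KVCacheBlock(i) for i in range(1, 6)))],
])
print(block_ids_for_manager(result, manager_id=1))  # [(1, 2, 3, 4, 5)]

Because the outer index of blocks is the manager id, each specialized manager only ever reads and writes result.blocks[manager_id], which is what lets the managers work independently of one another.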

Hybrid allocator RFC #11382

github-actions bot commented Apr 5, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

mergify bot added the v1 and tpu (Related to Google TPUs) labels on Apr 5, 2025
WoosukKwon (Collaborator) left a comment:


Thanks for the PR. Great work!

A few high-level suggestions:

  1. Can we first focus on the cases where every layer has the same embedding size? I think we can support Mamba or other cases in a future PR.
  2. Can we have an architecture like this?
[screenshot: proposed architecture diagram]

@@ -1,4 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# type: ignore
Collaborator commented:


What is this for?

@@ -22,7 +22,7 @@
KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
from vllm.v1.worker.tpu_model_runner import TPUModelRunner # type: ignore
Collaborator commented:


What is this for?

mergify bot commented Apr 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @heheda12345.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label on Apr 9, 2025
mergify bot removed the needs-rebase label on Apr 23, 2025
mergify bot added the documentation (Improvements or additions to documentation) label on Apr 24, 2025
mergify bot commented Apr 26, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @heheda12345.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label on Apr 26, 2025
mergify bot removed the needs-rebase label on Apr 26, 2025
mergify bot commented Apr 29, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @heheda12345.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label on Apr 29, 2025
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request, RequestStatus

logger = init_logger(__name__)


@dataclass
class KVCacheBlocks:
    blocks: list[list[GroupedKVCacheBlock]]
heheda12345 (Collaborator, Author) commented Apr 29, 2025


Three dimensions: blocks[manager_id][ith_block][group_id_in_manager], where the last level is the blocks tuple inside GroupedKVCacheBlock.
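
As an illustrative reading of those three levels, reusing the toy classes and the result instance from the sketch in the PR description above (the index values are arbitrary):

# manager_id=1, ith_block=0 -> a GroupedKVCacheBlock
grouped = result.blocks[1][0]
# group_id_in_manager=3 -> the KVCacheBlock of the 4th group in that manager
kv_block = grouped.blocks[3]
print(kv_block.block_id)  # 4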

heheda12345 (Collaborator, Author) commented:

Also CC @comaniac

renjie0 commented May 6, 2025

Could you please provide a design doc for the community to review? How does it work with prefix caching? How will it maximize the KV cache hit rate for the global attention layers? How will it work with speculative decoding? What is the potential timeline for this support? It has been 6 months since the RFC.

heheda12345 (Collaborator, Author) commented:

You can find the design in the PR description of #13296 and this PR.

shan18 commented May 6, 2025

Hi, I tried running inference on gemma-3-12b-it using your branch, but I keep getting invalid responses from the model.

For example, when I host the model with the vllm server like this:

python3 -m vllm.entrypoints.openai.api_server \
    --model google/gemma-3-12b-it \
    --trust-remote-code \
    --seed 1 \
    --host "0.0.0.0" \
    --port 5000 \
    --served-model-name "test-model" \
    --tensor-parallel-size 8 \
    --max-model-len 65536 \
    --enforce-eager

And then, when I give it a prompt from the AIME24 dataset, I get a response like this:

\u9154 cudd\u0cbf\u0c82\u0ca6\u179a\u17bc\u1794breviinction Coy blossom\u5414\u0b9f\u0b95 \u0939\u093e\u092eClo obstructions\u054f anf Zin \u0aa5\u05de\u05d9\u05ea Meat\u8c5agente\u0644\u0627\u0646Twelves .....

For comparison, when I do the same with the main vllm branch (v0.8.3), the response to the same prompt is this:

The uncertainty principle states that the uncertainty in energy (ΔE) and the lifetime (τ) of a quantum state are related by ΔE ≈ ħ/τ, where ħ is the reduced Planc ...

Is there anything that I need to set up separately in order for it to work?

heheda12345 (Collaborator, Author) commented:

@shan18 #17574 You need to cherry-pick this PR.

# Use copy to avoid modifying the original block_hashes
block_hashes = [
    block_hashes_dict[g.kv_cache_spec.block_size].copy()
    for g in self.kv_cache_config.kv_cache_groups
]
Contributor commented:


Use the number of specialized managers instead of kv_cache_groups?
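
For reference, a hedged sketch of what that suggestion might look like, assuming a hypothetical self.specialized_managers list whose entries expose a block_size attribute (these names are illustrative and not from this PR):

# Hypothetical variant: one copied hash list per specialized manager rather
# than per kv cache group (attribute names are illustrative).
block_hashes = [
    block_hashes_dict[manager.block_size].copy()
    for manager in self.specialized_managers
]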



@dataclass
class KVCacheNewTensor(KVCacheTensorBase):
Contributor commented:


Some of the tests might need to be updated for this naming (KVCacheNewTensor), e.g. test_kv_cache_utils.py.

heheda12345 (Collaborator, Author) replied:


Thanks for pointing it out. This PR is just a POC, so I didn't fix the tests.

shan18 commented May 12, 2025

@shan18 #17574 You need to cherry-pick this PR.

@heheda12345, I tried with the fixes in the PR you shared, but I still don't get any valid responses from the model. Do you have any test scripts that you used with your PR that I could try out?

heheda12345 (Collaborator, Author) commented:

@shan18 If I remember correctly, this PR should pass tests/v1/e2e/test_correctness_sliding_window.py after cherry-picking #17574.
But this PR is just a prototype and I don't plan to maintain it. I think the final implementation will be finished very soon.
