Skip to content

[Optim] Compute multimodal hash only once per item #17314

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions vllm/model_executor/models/deepseek_vl2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
BaseProcessingInfo, MultiModalHashes,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
Expand Down Expand Up @@ -279,24 +279,26 @@ def _cached_apply_hf_processor(
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]:
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
# The processor logic is different for len(images) <= 2 vs > 2
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
# perform caching for the most common case
if mm_data_items.get_count("image", strict=False) > 2:
# This code path corresponds to the cache being disabled
return self._apply_hf_processor_main(
return self._apply_hf_processor(
prompt=prompt,
mm_items=mm_data_items,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_update=True,
return_mm_hashes=return_mm_hashes,
)

return super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
return_mm_hashes=return_mm_hashes,
)


Expand Down
16 changes: 9 additions & 7 deletions vllm/model_executor/models/h2ovl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems)
from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.processing import (MultiModalHashes, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.transformers_utils.tokenizer import AnyTokenizer

from .intern_vit import InternVisionModel
Expand Down Expand Up @@ -488,24 +488,26 @@ def _cached_apply_hf_processor(
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]:
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
# The processor logic is different for len(images) <= 1 vs > 1
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
# perform caching for the most common case
if mm_data_items.get_count("image", strict=False) > 1:
# This code path corresponds to the cache being disabled
return self._apply_hf_processor_main(
return self._apply_hf_processor(
prompt=prompt,
mm_items=mm_data_items,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_update=True,
return_mm_hashes=return_mm_hashes,
)

return super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
return_mm_hashes=return_mm_hashes,
)


Expand Down
3 changes: 0 additions & 3 deletions vllm/model_executor/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,22 +396,19 @@ def _build_llava_or_pixtral_hf_processor(
dummy_inputs: BaseDummyInputsBuilder[_I],
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True,
) -> BaseMultiModalProcessor:
if isinstance(info, PixtralHFProcessingInfo):
return PixtralHFMultiModalProcessor(
info,
dummy_inputs, # type: ignore
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)

if isinstance(info, LlavaProcessingInfo):
return LlavaMultiModalProcessor(
info,
dummy_inputs, # type: ignore
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)

raise NotImplementedError(type(info))
Expand Down
2 changes: 0 additions & 2 deletions vllm/model_executor/models/mistral3.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,14 +312,12 @@ def _build_mistral3_processor(
dummy_inputs: BaseDummyInputsBuilder[_I],
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True,
) -> BaseMultiModalProcessor:
assert isinstance(info, Mistral3ProcessingInfo)
return Mistral3MultiModalProcessor(
info,
dummy_inputs, # type: ignore
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)


Expand Down
15 changes: 10 additions & 5 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
BaseProcessingInfo, MultiModalHashes,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
Expand Down Expand Up @@ -271,15 +272,19 @@ def _cached_apply_hf_processor(
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]:
prompt_ids, mm_kwargs, _ = super()._cached_apply_hf_processor(
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
prompt_ids, mm_kwargs, mm_hashes, _ = super(
)._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
return_mm_hashes=return_mm_hashes,
)

# NOTE: The tokens are already inserted by the chat template
return prompt_ids, mm_kwargs, True
return prompt_ids, mm_kwargs, mm_hashes, True


@MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor,
Expand Down
Loading