Commit cde384c

[Model] support MiniMax-VL-01 model (#16328)
Signed-off-by: qingjun <[email protected]>
1 parent 96e06e3 commit cde384c

File tree

11 files changed: +954 -19 lines changed


tests/models/decoder_only/vision_language/test_models.py

Lines changed: 13 additions & 0 deletions
@@ -446,6 +446,19 @@
             hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
             patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner,
         ),
+    "minimax_vl_01": VLMTestInfo(
+        models=["MiniMaxAI/MiniMax-VL-01"],
+        prompt_formatter=lambda img_prompt: f"<beginning_of_sentence>user: {img_prompt} assistant:<end_of_sentence>",  # noqa: E501
+        img_idx_to_prompt=lambda _: "<image>",
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        max_model_len=8192,
+        max_num_seqs=4,
+        dtype="bfloat16",
+        hf_output_post_proc=model_utils.minimax_vl_01_hf_output,
+        patch_hf_runner=model_utils.minimax_vl_01_patch_hf_runner,
+        auto_cls=AutoModelForImageTextToText,
+        marks=[large_gpu_mark(min_gb=80)],
+    ),
     "molmo": VLMTestInfo(
         models=["allenai/Molmo-7B-D-0924"],
         test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
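
As a quick sketch of what this entry asks the runners to generate against, the two lambdas above compose a single-image prompt as follows; the question text is a hypothetical placeholder and the exact joining of image token and question is simplified here, not copied from the test harness.

# Sketch of how prompt_formatter and img_idx_to_prompt compose for one image.
prompt_formatter = lambda img_prompt: (
    f"<beginning_of_sentence>user: {img_prompt} assistant:<end_of_sentence>")
img_idx_to_prompt = lambda _: "<image>"

img_prompt = img_idx_to_prompt(0) + "\nDescribe the image."  # placeholder question
print(prompt_formatter(img_prompt))
# <beginning_of_sentence>user: <image>
# Describe the image. assistant:<end_of_sentence>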

tests/models/decoder_only/vision_language/vlm_utils/model_utils.py

Lines changed: 19 additions & 0 deletions
@@ -229,6 +229,14 @@ def minicpmv_trunc_hf_output(hf_output: RunnerOutput,
     return output_ids, output_str, out_logprobs


+def minimax_vl_01_hf_output(hf_output: RunnerOutput,
+                            model: str) -> RunnerOutput:
+    output_ids, output_str, out_logprobs = hf_output
+    if output_str.endswith("<end_of_sentence>"):
+        output_str = output_str.split("<end_of_sentence>")[0]
+    return output_ids, output_str, out_logprobs
+
+
 ####### Functions for converting image assets to embeddings
 def get_llava_embeddings(image_assets: _ImageAssets):
     return [asset.image_embeds for asset in image_assets]

@@ -627,6 +635,17 @@ def _generate(self, *args, image_sizes=None, **kwargs):
     return hf_model


+def minimax_vl_01_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+    orig_generate = hf_model.model.generate
+
+    def _generate(self, *args, image_sizes=None, **kwargs):
+        return orig_generate(*args, decode_text=False, **kwargs)
+
+    hf_model.model.generate = types.MethodType(_generate, hf_model.model)
+
+    return hf_model
+
+
 def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     """Patches and returns an instance of the HfRunner to use for Molmo."""
     hf_processor = hf_model.processor
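
To make the post-processing above concrete, here is a small self-contained sketch of the trimming step; the decoded string and token ids are made-up example values, not real test output.

# Hypothetical HF runner output tuple; only the string matters for the trim.
output_ids, output_str, out_logprobs = ([101, 102], "A cat on a mat<end_of_sentence>", None)
if output_str.endswith("<end_of_sentence>"):
    output_str = output_str.split("<end_of_sentence>")[0]
assert output_str == "A cat on a mat"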
Lines changed: 99 additions & 0 deletions (new file)

# SPDX-License-Identifier: Apache-2.0

import pytest
from PIL import Image

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.parse import ImageSize
from vllm.multimodal.processing import BaseMultiModalProcessor

from ....conftest import _ImageAssets
from ...utils import build_model_context


@pytest.mark.parametrize("model_id", ["MiniMaxAI/MiniMax-VL-01"])
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(
    image_assets: _ImageAssets,
    model_id: str,
    num_imgs: int,
):
    ctx = build_model_context(
        model_id,
        mm_processor_kwargs=None,
        limit_mm_per_prompt={"image": num_imgs},
    )
    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
    prompt = "<image>" * num_imgs
    image = Image.new("RGB", size=(364, 364))
    mm_data = {"image": [image] * num_imgs}

    processed_inputs = processor.apply(prompt, mm_data, {})
    image_placeholders = processed_inputs["mm_placeholders"]["image"]

    assert len(image_placeholders) == num_imgs


def _validate_image_prompt_replacements_one(
    processor: BaseMultiModalProcessor,
    num_imgs: int,
    failed_size_excs: list[tuple[ImageSize, Exception]],
    image_size: ImageSize,
) -> None:
    prompt = "<image>" * num_imgs
    image = Image.new("RGB", size=image_size)
    mm_data = {"image": [image] * num_imgs}

    try:
        processed_inputs = processor.apply(prompt, mm_data, {})

        image_placeholders = processed_inputs["mm_placeholders"]["image"]
        assert len(image_placeholders) == num_imgs

    except Exception as exc:
        failed_size_excs.append((image_size, exc))


def _test_image_prompt_replacements(
    processor,
    *,
    num_imgs: int,
    image_sizes: list[ImageSize],
) -> None:

    failed_size_excs = list[tuple[ImageSize, Exception]]()

    for size in image_sizes:
        _validate_image_prompt_replacements_one(processor, num_imgs,
                                                failed_size_excs, size)

    if failed_size_excs:
        msg = "Found failing image sizes:" \
            + "\n========\n".join(f"[{size}]\n{exc}"
                                  for size, exc in failed_size_excs)
        raise AssertionError(msg)


@pytest.mark.parametrize("model_id", ["MiniMaxAI/MiniMax-VL-01"])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_prompt_replacements_regression(model_id, num_imgs):
    ctx = build_model_context(
        model_id,
        mm_processor_kwargs=None,
        limit_mm_per_prompt={"image": num_imgs},
    )
    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)

    image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
                    (488, 183), (2560, 1669)]
    image_sizes = [
        size for w, h in image_ratios
        for size in [ImageSize(w, h), ImageSize(h, w)]
    ]

    _test_image_prompt_replacements(
        processor,
        num_imgs=num_imgs,
        image_sizes=image_sizes,
    )
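
Beyond the processor tests, the model can be exercised end to end through vLLM's standard offline-inference API. The sketch below assumes that API plus the prompt format used by the test entry above; the image path, question text, and tensor_parallel_size are placeholders, and the model is far too large for a single GPU (the VLM test above is gated on an 80 GB device).

from PIL import Image

from vllm import LLM, SamplingParams

# Placeholder inputs; the prompt format follows the test's prompt_formatter.
image = Image.open("example.jpg").convert("RGB")
prompt = ("<beginning_of_sentence>user: <image>\n"
          "Describe the image. assistant:<end_of_sentence>")

llm = LLM(
    model="MiniMaxAI/MiniMax-VL-01",
    trust_remote_code=True,
    max_model_len=8192,
    limit_mm_per_prompt={"image": 1},
    tensor_parallel_size=8,  # placeholder; size this to the available GPUs
)

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=128),
)
print(outputs[0].outputs[0].text)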

tests/models/registry.py

Lines changed: 2 additions & 0 deletions
@@ -337,6 +337,8 @@ def check_available_online(
     "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
                                 extras={"2.6": "openbmb/MiniCPM-V-2_6"},  # noqa: E501
                                 trust_remote_code=True),
+    "MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo("MiniMaxAI/MiniMax-VL-01",  # noqa: E501
+                                                           trust_remote_code=True),
     "Mistral3ForConditionalGeneration": _HfExamplesInfo("mistralai/Mistral-Small-3.1-24B-Instruct-2503",  # noqa: E501
                                                         extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}),  # noqa: E501
     "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",

vllm/model_executor/models/minimax_text_01.py

Lines changed: 53 additions & 14 deletions
@@ -3,7 +3,7 @@
 import copy
 import math
 import re
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import Dict, Iterable, List, Optional, Set, Tuple, Union

 import torch
 import torch.distributed

@@ -110,7 +110,17 @@ def _forward(
             variance = tensor_model_parallel_all_reduce(
                 variance) / self.tp_world
         x = x * torch.rsqrt(variance + self.variance_epsilon)
-        x = x.to(orig_dtype) * self.weight
+
+        weight = self.weight
+        if x.size(-1) != self.weight.size(0):
+            if self.weight.size(0) < x.size(-1):
+                repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1)
+                full_weight = self.weight.repeat(repeat_count)
+                weight = full_weight[:x.size(-1)]
+            else:
+                weight = self.weight[:x.size(-1)]
+
+        x = x.to(orig_dtype) * weight
         return x

     def forward(

@@ -421,6 +431,10 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
                                attn_metadata):
         hidden = []
         for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
+            if _prefill_idx >= len(attn_metadata.query_start_loc):
+                break
+            if _prefill_idx >= len(state_indices_tensor):
+                break
             _start = attn_metadata.query_start_loc[_prefill_idx]
             _end = attn_metadata.query_start_loc[_prefill_idx + 1]
             slot_id = state_indices_tensor[_prefill_idx]

@@ -443,6 +457,10 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
             hidden.append(
                 self._decode_infer(q, k, v, kv_cache, state_indices_tensor,
                                    attn_metadata))
+
+        if not hidden:
+            return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
+
         hidden = torch.concat(hidden, dim=0).contiguous()
         return hidden

@@ -663,6 +681,9 @@ def __init__(
         self.shared_moe = False

         shared_intermediate = getattr(config, 'shared_intermediate_size', 0)
+        if isinstance(shared_intermediate, list):
+            shared_intermediate = shared_intermediate[
+                layer_id] if layer_id < len(shared_intermediate) else 0
         if shared_intermediate > 0:
             self.shared_moe = True
             self.shared_mlp = MiniMaxText01MLP(

@@ -875,6 +896,8 @@ def _clear_prefill_cache(self, attn_metadata,

         slots_to_clear = []
         for _prefill_id in range(getattr(attn_metadata, "num_prefills", 0)):
+            if _prefill_id >= len(seq_id_map):
+                break
             seq_id = seq_id_map[_prefill_id]
             if attn_metadata.context_lens_tensor[
                     _prefill_id] == 0 and seq_id in seq_to_slot_maps:

@@ -886,13 +909,18 @@ def _clear_prefill_cache(self, attn_metadata,
                                         dtype=torch.long)
             minimax_cache_tensors[:, slots_tensor, ...] = 0

+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(self,
                 input_ids: Optional[torch.Tensor],
                 positions: torch.Tensor,
-                kv_caches: List[torch.Tensor],
-                intermediate_tensors=None,
+                intermediate_tensors: Optional[IntermediateTensors] = None,
                 inputs_embeds: Optional[torch.Tensor] = None,
-                **kwargs) -> torch.Tensor:
+                **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
         forward_context = get_forward_context()
         attn_metadata = forward_context.attn_metadata
         if attn_metadata is None:

@@ -901,6 +929,7 @@ def forward(self,
             kwargs["request_ids_to_seq_ids"] = {}
         if "finished_requests_ids" not in kwargs:
             kwargs["finished_requests_ids"] = []
+
         (
             minimax_cache_tensors,
             state_indices_tensor,

@@ -922,15 +951,11 @@ def forward(self,
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        kv_cache_index = 0
         minimax_cache_index = 0
         attn_metadata.rotary_emb = self.rotary_emb
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             _caches = None
-            if isinstance(layer.self_attn, MiniMaxText01Attention):
-                _caches = kv_caches[kv_cache_index]
-                kv_cache_index += 1
             if isinstance(layer.self_attn, MiniMaxText01LinearAttention):
                 current_state_layer = minimax_cache_index
                 _caches = minimax_cache_params.at_layer_idx(

@@ -1009,15 +1034,20 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
         return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(
             batch_size)

+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(self,
                 input_ids: torch.Tensor,
                 positions: torch.Tensor,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, self.kv_cache,
-                                   intermediate_tensors, inputs_embeds,
-                                   **kwargs)
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds, **kwargs)

         return hidden_states

@@ -1043,8 +1073,9 @@ def make_empty_intermediate_tensors(
             })

     def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> None:
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()

         def which_layer(name: str) -> int:
             if "layers" in name:

@@ -1108,6 +1139,7 @@ def load_sparse_moe_weight(name: str, loaded_weight: torch.Tensor,
                         weight_name,
                         expert_id=expert_id,
                         shard_id=shard_id)
+                    loaded_params.add(name)
                     break
             else:
                 if is_pp_missing_parameter(name, self):

@@ -1117,6 +1149,7 @@ def load_sparse_moe_weight(name: str, loaded_weight: torch.Tensor,
                                         default_weight_loader)
                 weight_loader = weight_loader_with_alias(name)(weight_loader)
                 weight_loader(param, loaded_weight)
+                loaded_params.add(name)
             return

         def is_shared_mlp_weight(name: str) -> bool:

@@ -1154,6 +1187,7 @@ def load_shared_mlp_weight(name: str, loaded_weight: torch.Tensor,
                 else:
                     raise AssertionError(
                         "MLP weight not in [gate_up_proj, down_proj]")
+                loaded_params.add(name)
             return

         def is_mha_weight(name: str) -> bool:

@@ -1170,6 +1204,7 @@ def load_linear_attn_weight(name: str, loaded_weight: torch.Tensor,
                 MiniMaxText01LinearAttention.weight_direct_load)
             weight_loader = weight_loader_with_alias(name)(weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
             return

         def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor,

@@ -1194,6 +1229,7 @@ def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor,
                                             default_weight_loader)
                     weight_loader = weight_loader_with_alias(name)(weight_loader)
                     weight_loader(param, loaded_weight, shard_id)
+                    loaded_params.add(name)
                     break
             else:
                 if is_pp_missing_parameter(name, self):

@@ -1204,6 +1240,7 @@ def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor,
                                         default_weight_loader)
                 weight_loader = weight_loader_with_alias(name)(weight_loader)
                 weight_loader(param, loaded_weight)
+                loaded_params.add(name)
             return

         def is_layer_norm_weight(name: str) -> bool:

@@ -1219,6 +1256,7 @@ def load_layer_norm_weight(name: str, loaded_weight: torch.Tensor,
                                     default_weight_loader)
             weight_loader = weight_loader_with_alias(name)(weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
             return

         def load_basic_weight(name: str, loaded_weight: torch.Tensor,

@@ -1230,6 +1268,7 @@ def load_basic_weight(name: str, loaded_weight: torch.Tensor,
                                     default_weight_loader)
             weight_loader = weight_loader_with_alias(name)(weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
             return

         for name, loaded_weight in weights:

@@ -1258,4 +1297,4 @@ def load_basic_weight(name: str, loaded_weight: torch.Tensor,
                 continue

             load_basic_weight(name, loaded_weight, self)
-        return
+        return loaded_params
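
One note on the RMSNorm change near the top of this file: the new branch only runs when the norm weight's length differs from the activation's last dimension, tiling or truncating the weight to match. The standalone sketch below illustrates that intent with plain ceiling-division tiling; the helper name and the exact arithmetic are illustrative, not lifted from the commit.

import torch


def match_norm_weight(weight: torch.Tensor, hidden_size: int) -> torch.Tensor:
    # Illustrative helper (not from the commit): align a 1-D norm weight with
    # the activation's last dimension by tiling it and cutting to size, or by
    # truncating it when it is already too long.
    if weight.numel() == hidden_size:
        return weight
    if weight.numel() < hidden_size:
        repeats = -(-hidden_size // weight.numel())  # ceiling division
        return weight.repeat(repeats)[:hidden_size]
    return weight[:hidden_size]


x = torch.randn(2, 6)
w = torch.full((3,), 0.5)
print((x * match_norm_weight(w, x.size(-1))).shape)  # torch.Size([2, 6])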
