
mtmd : add **vision** support for Mistral Small 3.1 #13231


Merged
merged 6 commits into ggml-org:master on May 1, 2025

Conversation

@ngxson (Collaborator) commented May 1, 2025

Tested with @unslothai dynamic quant, works fine:

llama-mtmd-cli -m ../models/Mistral-Small-3.1-24B-Instruct-2503-UD-IQ2_M.gguf --mmproj ../models/Mistral-Small-3.1-24B-Instruct-2503/mmproj-model.gguf --image ../models/lenna.png -p "How many people do you see in the image? What is the composition of this image?" --chat-template mistral-v7

I see one person in the image. The composition of this image includes a woman who is the main subject. She is wearing a large, stylish hat adorned with feathers and decorative elements. The hat has a unique, rounded shape reminiscent of a bucket or a large bowler hat. The woman has long hair and is looking directly at the camera with a serious expression. The background is blurred, focusing attention on her, and there are hints of other people and elements that are not clearly distinguishable due to the focus on the woman in the foreground. The overall color tone of the image is warm, with a vintage feel, suggesting it might be an older photograph.


Prequant mmproj: https://huggingface.co/ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF

llama-mtmd-cli -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF --chat-template mistral-v7

How it works (under the hood)

Mistral Small uses the same vision tower as Pixtral, but with GELU activation instead of SiLU.

The main difference is in the projector. Mistral Small uses a technique they call "patch merger" to reduce the number of output tokens. For example, for a 1024x1024 image, Pixtral needs 64x64 = 4096 tokens, while Mistral Small reduces the number of tokens by a factor of 2 in each dimension, giving (64/2)x(64/2) = 1024 tokens.

This is done simply by using ggml_im2col to rearrange tokens so that each row contains the info from multiple patches; the rearranged tensor is then projected via a matrix that actually "merges" this info together. In simple terms, it is a modified version of Conv2D.
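To make this concrete, here is a minimal sketch of the patch merger as a standalone graph-building helper. The function and weight names (build_patch_merger, merger_w) are illustrative rather than the exact ones in clip.cpp, and it assumes cur holds the patch embeddings laid out as a [W, H, hidden] grid and that merger_w is the learned [n_merge*n_merge*hidden, hidden] projection:

```cpp
#include "ggml.h"

// Illustrative sketch only; names and layout assumptions are not the exact
// ones used in clip.cpp.
static struct ggml_tensor * build_patch_merger(
        struct ggml_context * ctx0,
        struct ggml_tensor  * cur,       // patch embeddings as a [W, H, hidden] grid
        struct ggml_tensor  * merger_w,  // learned [n_merge*n_merge*hidden, hidden] projection
        int                   n_merge) { // 2 for Mistral Small
    // ggml_im2col only reads the *shape* of its kernel argument, so a zero-copy
    // view of the input stands in for an n_merge x n_merge kernel
    struct ggml_tensor * kernel = ggml_view_3d(ctx0, cur, n_merge, n_merge, cur->ne[2], 0, 0, 0);

    // rearrange: each output position now holds the concatenated embeddings of one
    // n_merge x n_merge block of patches -> [n_merge*n_merge*hidden, W/n_merge, H/n_merge]
    cur = ggml_im2col(ctx0, kernel, cur, n_merge, n_merge, 0, 0, 1, 1, true, GGML_TYPE_F32);

    // flatten the spatial dims and "merge" each block with a single projection
    // -> [hidden, (W/n_merge)*(H/n_merge)], i.e. 4x fewer tokens for n_merge = 2
    cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
    return ggml_mul_mat(ctx0, merger_w, cur);
}
```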


Test results:

OK:   llama-mtmd-cli ggml-org/SmolVLM-500M-Instruct-GGUF:Q8_0
OK:   llama-mtmd-cli ggml-org/SmolVLM2-2.2B-Instruct-GGUF:Q4_K_M
OK:   llama-mtmd-cli ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0
OK:   llama-mtmd-cli ggml-org/gemma-3-4b-it-GGUF:Q4_K_M
OK:   llama-mtmd-cli guinmoon/MobileVLM-3B-GGUF:Q4_K_M
OK:   llama-mtmd-cli THUDM/glm-edge-v-5b-gguf:Q4_K_M
OK:   llama-mtmd-cli second-state/Llava-v1.5-7B-GGUF:Q2_K
OK:   llama-mtmd-cli cjpais/llava-1.6-mistral-7b-gguf:Q3_K
OK:   llama-mtmd-cli ibm-research/granite-vision-3.2-2b-GGUF:Q4_K_M
OK:   llama-mtmd-cli second-state/MiniCPM-Llama3-V-2_5-GGUF:Q2_K
OK:   llama-mtmd-cli openbmb/MiniCPM-V-2_6-gguf:Q2_K
OK:   llama-mtmd-cli openbmb/MiniCPM-o-2_6-gguf:Q4_0
OK:   llama-mtmd-cli bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M
OK:   llama-mtmd-cli ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M
OK:   llama-mtmd-cli ggml-org/pixtral-12b-GGUF:Q4_K_M
OK:   llama-mtmd-cli ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF

github-actions bot added the examples and python labels on May 1, 2025
@ngxson (Collaborator, Author) commented May 1, 2025

@bartowski1182 are you ready to cook some quants? :)

ngxson requested a review from ggerganov on May 1, 2025 at 11:22
Comment on lines +760 to +761
ggml_tensor * kernel = ggml_view_3d(ctx0, cur, n_merge, n_merge, cur->ne[2], 0, 0, 0);
cur = ggml_im2col(ctx0, kernel, cur, n_merge, n_merge, 0, 0, 1, 1, true, inp->type);
@ngxson (Collaborator, Author) commented:

@ggerganov @slaren I just want to double-check with you here: I suppose that ggml_im2col only cares about the shape of the kernel, not the actual data inside. Does this look OK to you? (Or maybe there is another way?) Thanks!

Reply from a Member:

This should be good. We can move this to ggml.h as a ggml_unfold() call for convenience, but first let's gather some feedback that this works correctly.
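For reference, the ggml.h-level convenience call could look something like the declaration below. This is purely a sketch of the idea; ggml_unfold does not exist in ggml.h at the time of writing and the name and signature are guesses:

```cpp
// Hypothetical wrapper around the "dummy kernel view + ggml_im2col" pattern
// from the diff above; not an existing ggml API.
GGML_API struct ggml_tensor * ggml_unfold(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,   // input data, e.g. [W, H, C, N]
        int                   k0,  // window width
        int                   k1); // window height
```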

ngxson merged commit 8936784 into ggml-org:master on May 1, 2025
51 checks passed
@bartowski1182 (Contributor):

So with this, @ngxson, do I have to remake the whole thing from scratch, or do I just add the new mmproj?

@ngxson (Collaborator, Author) commented May 1, 2025

You just need to add the mmproj; the text model is the same. I used the prequant from unsloth without any modifications.

gabe-l-hart added a commit to gabe-l-hart/llama.cpp that referenced this pull request May 1, 2025
* origin/master:
sync : ggml
whisper : add check that target name exists (whisper/3103)
ggml : suppress Windows compiler warnings (whisper/3075)
mtmd : add **vision** support for Mistral Small 3.1 (ggml-org#13231)
arg : remove CURLINFO_EFFECTIVE_METHOD (ggml-org#13228)
llama-model : fix the reported size class for nomic-embed-text-v2-moe (ggml-org#13223)
sync : ggml
ggml : fix ggml_gallocr_ptr type (ggml/1205)
cuda : fix unused variable compile warning (whisper/0)
CUDA: batched+noncont MMQ, refactor bs>1 MoE code (ggml-org#13199)
arg : -hf do not fail if url mismatch (ggml-org#13219)
fix typo: `n_ctx_pre_seq` -> `n_ctx_per_seq` (ggml-org#13221)
convert : improve model arch handling (ggml-org#13122)
llava : remove duplicate include (ggml-org#13207)
common : add -jf / --json-schema-file flag (ggml-org#12011)
@Dampfinchen commented May 1, 2025

Fantastic news, I was waiting for vision support for Small 3.1 in llama.cpp! IMO, Mistral's best model. Thanks a lot!

@stduhpf (Contributor) commented May 1, 2025

Seems to work fairly well, but for some reason the compute buffer for the mmproj is over 9GB. This makes it unusable on Vulkan without --no-mmproj-offload (9GB is way over the allocation limit for most drivers). Is it to be expected or is that a bug?

@BugReporterZ:

> Seems to work fairly well, but for some reason the compute buffer for the mmproj is over 9GB.

Same issue on CUDA.

...
load_hparams: projector:          pixtral
load_hparams: has_llava_proj:     0
load_hparams: minicpmv_version:   0
load_hparams: proj_scale_factor:  0
load_hparams: n_wa_pattern:       0
load_hparams: use_silu:           0
load_hparams: use_gelu:           1
load_hparams: model size:         837.36 MiB
load_hparams: metadata size:      0.08 MiB
ggml_backend_cuda_buffer_type_alloc_buffer: allocating 9248.06 MiB on device 0: cudaMalloc failed: out of memory
ggml_gallocr_reserve_n: failed to allocate CUDA0 buffer of size 9697292288
main: loading model: /home/user/llm/Mistral-Small-3.1-24B-Instruct-2503.Q6_K.gguf

@ngxson (Collaborator, Author) commented May 2, 2025

Did you also try the Q8_0 adapter to see if it uses less memory? IMO the vision encoder is quite big this time; I assume this is the biggest vision encoder that we have ever supported.

@BugReporterZ:

The Q8_0 adapter also wants to allocate 9248.06 MB of memory in addition to that of the text-only weights (18GB).

The expectation was that 24GB of VRAM would have been enough for 18GB + 0.8GB + some context memory.

@ngxson (Collaborator, Author) commented May 2, 2025

Hmm, OK. The problem is that we always allocate enough memory for the worst case, which is a 1024x1024 image (even though not everyone uses an image that big; input images with different aspect ratios will still be resized so that the longest dimension is at most 1024).

Another idea is to allow reducing the max image size, so it will allocate less memory while also improving speed on big images, since they will then be resized to a smaller size.
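To put rough numbers on that (reusing the illustrative 1024x1024 figures from the PR description, so treat them as an example rather than exact model parameters): at 1024x1024 the encoder sees 64x64 = 4096 patches and emits 1024 tokens after merging, while capping the longest side at 512 would give 32x32 = 1024 patches and 256 tokens. Since the attention scratch space grows with the square of the patch count, the compute buffer should shrink by roughly 16x.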

@ggerganov (Member):

The KQ tensor is very large: F32 [num_patches, num_patches, n_head]. Using flash attention should reduce the compute buffer dramatically:

diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index 7607d4e3a..61cb502f9 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -699,17 +699,15 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
             struct ggml_tensor * V = ggml_mul_mat(ctx0, model.layers[il].v_w, cur);
 
             V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
-            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
-
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-            KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
+            V = ggml_permute(ctx0, V, 0, 2, 1, 3);
 
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
-            KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
-            KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+            K = ggml_cast(ctx0, K, GGML_TYPE_F16);
+            V = ggml_cast(ctx0, V, GGML_TYPE_F16);
 
-            cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
+            cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, 1.0f / sqrtf((float)d_head), 0.0f, 0.0f);
+            ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
 
+            cur = ggml_reshape_2d(ctx0, cur, hidden_size, num_patches);
             cur = ggml_mul_mat(ctx0, model.layers[il].o_w, cur);
         }
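For a sense of scale (illustrative only, reusing the 4096-patch figure from the PR description): each head's KQ slice for a 1024x1024 image is already 4096 x 4096 x 4 bytes = 64 MiB, and the full [num_patches, num_patches, n_head] tensor is n_head times that per layer, so it dominates the compute buffer; flash attention never materializes this tensor, which is where the savings come from.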

However, the Metal backend currently requires the number of patches to be a multiple of 32 in order to use Flash Attention. It's something that we should fix now that we have #12850. As a temporary workaround, we can add:

diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index d92392edb..52b1511d5 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -1659,6 +1659,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                 // TODO: not sure if it is worth adding kernels for this size
                 return false;
             }
+            if (op->src[1]->ne[1] % 32 != 0) {
+                // TODO: temporary requirement, can be avoided by dynamically padding the inputs
+                return false;
+            }
             if (op->src[0]->ne[0] == 576) {
                 // DeepSeek sizes
                 // TODO: disabled for now, until optmized

ngxson mentioned this pull request on May 6, 2025
@ddh0 (Contributor) commented May 22, 2025

A Q6_K quantization of mistralai/Mistral-Small-3.1-24B-Instruct-2503, using the f16 mmproj, tries to allocate over 9GB to process a single image (as already mentioned). I am also noticing that the image is being converted to ~3000 tokens instead of the expected 1024 at most. Maybe this could be a hint as to what's going wrong?

Edit to add: I am on the latest commit as of a few hours ago.

@ddh0 (Contributor) commented May 22, 2025

Here are the full verbose logs from a session with Mistral Small 2503; hopefully they can be helpful for debugging what's happening.
ms2503_cuda_log_output.8e186ef0.txt

The crash at the end is ultimately because of OOM

Labels
examples, python
7 participants