
Vulkan: Support fp32 accumulator in quantized matmul to fix GLM4-32B incoherence #13607


Merged: 1 commit merged into master on May 19, 2025

Conversation

@0cc4m (Collaborator) commented May 17, 2025

Currently the Vulkan backend only supports fp16 accumulators in quantized mul_mat, which leads to numerical issues with GLM4-32B. Some of these have been addressed with the model precision parameter, but that parameter was being ignored by the Vulkan backend.

However, this is a draft because it solves only part of the problem. It seems there are still tensors that fail with fp16 accumulators and have not been set to GGML_PREC_F32 in llama.cpp. I'm not sure why this only affects the Vulkan backend. Forcing all tensors to run with fp32 accumulation resolves the incoherence.

I don't have a good way of finding all of these tensors. The problem seems to be infinity values in the result. Using the internal results checker of the Vulkan backend gives me the first of the problematic tensors: blk.1.attn_output.weight.

The trouble with GLM4 seems to be unusually large values. I'm opening this in the hopes that someone can help me figure out the rest of the problem, especially why this doesn't affect CUDA or Metal.

ERROR: Invalid value in MUL_MAT i3=0 i2=0 i1=295 i0=3321 result=-inf correct=-68706.3 avg_err=0.0281893
tensor=0x7c0f1a8ac4c0 tensor->name=node_66 tensor->type: f32 ne0=6144 nb0=4 ne1=512 nb1=24576 ne2=1 nb2=12582912 ne3=1 nb3=12582912 offset=0
src0=0x5bf32bb72160 src0->name=blk.1.attn_output.weight op=NONE type=q4_0 ne0=6144 nb0=18 ne1=6144 nb1=3456 ne2=1 nb2=21233664 ne3=1 nb3=21233664 offset=0
src1=0x7c0f1a8ac350 src1->name=kqv_out-1 op=CONT type=f32 ne0=6144 nb0=4 ne1=512 nb1=24576 ne2=1 nb2=12582912 ne3=1 nb3=12582912 offset=0
First error: result=-141.625 correct=-78.4177 i3=0 i2=0 i1=0 i0=114

Result:
             290     291     292     293     294     295     296     297     298     299
   3316: -564.50  3044.00  2774.00   37.53 -856.00  3146.00 -652.50  2886.00   25.59  292.25
   3317:  7000.00 -6492.00 -13120.00  6288.00  6472.00 -11536.00  7540.00 -9336.00  6292.00 -1033.00
   3318: -896.50 -4044.00  7040.00 -1930.00  1650.00  607.50 -1560.00 -857.00 -2019.00  6472.00
   3319:  1666.00 -7552.00  2744.00  133.62  4588.00 -4332.00  1145.00 -4972.00   35.28  7120.00
   3320:  1283.00  5496.00 -1505.00  2428.00 -968.50  2184.00  1628.00  3756.00  2494.00 -2484.00
   3321:  20160.00 -42592.00 -36096.00  16096.00  31488.00    -inf  22592.00 -47328.00  16320.00  11688.00
   3322:  6868.00  2464.00 -1387.00  7196.00  5392.00 -2434.00  7032.00  1039.00  7232.00  3626.00
   3323: -19840.00 -4280.00 -4082.00 -19856.00 -17744.00  5640.00 -19696.00 -3396.00 -19824.00 -17040.00
   3324:  5404.00 -16976.00 -6708.00  2666.00  9808.00 -15056.00  5020.00 -15328.00  2508.00  7612.00
   3325:  4688.00 -7400.00 -10216.00  3648.00  5172.00 -10080.00  5036.00 -9320.00  3646.00   22.81

Correct:
             290     291     292     293     294     295     296     297     298     299
   3316: -570.66  3041.39  2781.36   34.57 -853.04  3156.60 -657.01  2892.20   24.63  292.58
   3317:  6978.54 -6484.46 -13017.67  6266.93  6449.19 -11507.74  7519.72 -9342.18  6274.81 -1018.87
   3318: -889.91 -4043.65  7033.03 -1927.52  1673.27  614.18 -1548.47 -865.19 -2015.27  6494.03
   3319:  1666.02 -7551.79  2731.42  132.06  4592.75 -4339.87  1142.80 -4980.06   32.58  7129.88
   3320:  1278.52  5510.90 -1512.11  2424.03 -973.02  2192.46  1616.08  3763.04  2492.45 -2482.38
   3321:  20265.07 -42728.43 -36189.09  15949.43  31408.91 -68706.27  22610.84 -47450.00  16149.71  11702.63
   3322:  6874.31  2487.35 -1370.15  7227.34  5416.84 -2413.33  7039.71  1047.16  7246.70  3627.93
   3323: -20007.28 -4222.13 -4072.23 -20024.71 -17799.77  5710.86 -19857.95 -3403.51 -20001.98 -17118.29
   3324:  5380.83 -17033.93 -6779.77  2656.03  9755.27 -15092.67  4998.03 -15351.77  2491.30  7547.54
   3325:  4684.53 -7444.94 -10239.35  3656.67  5151.60 -10132.43  5022.94 -9403.43  3652.60   24.86

MUL_MAT gpu=0
 NONE gpu=0
 CONT gpu=0
  PERMUTE gpu=0
   MUL_MAT gpu=0
    VIEW gpu=0
     NONE gpu=0
    SOFT_MAX gpu=0
     MUL_MAT gpu=0
      VIEW gpu=0
       NONE gpu=0
      PERMUTE gpu=0
       ROPE gpu=0
        RESHAPE gpu=0
         MUL_MAT gpu=0
          NONE gpu=0
          MUL gpu=0
        NONE gpu=0
     NONE gpu=0

@0cc4m requested a review from jeffbolznv on May 17, 2025 15:25
@github-actions bot added the Vulkan (Issues specific to the Vulkan backend) and ggml (changes relating to the ggml tensor library for machine learning) labels on May 17, 2025
@jeffbolznv (Collaborator)

I can try to repro this on Monday. I wonder if differences in split_k/stream_k might explain the difference vs other backends? If the partial results are all in range then the resolve can accumulate them at f32.
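
As a rough illustration of the split_k point (a sketch under stated assumptions, not the actual shader: the function name and data layout are invented), the resolve step that combines per-slice partial sums can accumulate in fp32 even if each slice used fp16 internally:

```cpp
// Sketch only: each of the K slices contributes a partial dot product for the
// same output element; the resolve pass sums them. If each partial is in fp16
// range, doing this final sum in fp32 avoids overflowing the total.
#include <vector>

float resolve_split_k(const std::vector<float> & partials) {
    float acc = 0.0f;            // fp32 resolve accumulator
    for (float p : partials) {
        acc += p;
    }
    return acc;
}
```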

@0cc4m (Collaborator, Author) commented May 18, 2025

@JohannesGaessler @ggerganov Did you see any similar problems with CUDA/Metal? GLM4 seems to produce values outside of the range of 16-bit floats, so 32-bit accumulators are needed, even in tensors that have not yet been set to that precision mode.

That they have not yet been set to F32 precision suggests that either the issues weren't noticed (and are covered up by some feature in your backends, like clamping?), or that they don't occur in CUDA/Metal.
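
For reference, a tiny standalone demonstration of the failure mode (assumptions: a compiler with the _Float16 extension, e.g. recent GCC/Clang; the partial sums are made-up values of the same magnitude as the failing element in the log above):

```cpp
#include <cstdio>

int main() {
    _Float16 acc16 = 0.0f;  // fp16 accumulator, max finite value 65504
    float    acc32 = 0.0f;  // fp32 accumulator

    // made-up partial sums in the same ballpark as the -68706.27 element above
    const float partials[] = { -20000.0f, -25000.0f, -23706.27f };
    for (float p : partials) {
        acc16 += (_Float16) p;
        acc32 += p;
    }

    printf("fp16 accumulator: %f\n", (float) acc16); // -inf: the total exceeds fp16 range
    printf("fp32 accumulator: %f\n", acc32);         // about -68706
    return 0;
}
```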

@JohannesGaessler (Collaborator)

In CUDA the only place where FP16 accumulators are potentially used is cuBLAS GEMM, and only if the precision is GGML_PREC_DEFAULT. The custom matrix multiplication kernels that I wrote for ggml all use FP32 accumulators. For FlashAttention, KQ either unconditionally uses FP32 accumulators or uses FP16 accumulators only after checking the precision. There are kernels which unconditionally use FP16 accumulators for VKQ, but this seems to be much less of an issue than for KQ.

@jeffbolznv (Collaborator) commented May 18, 2025

Isn't CUDA using int8 MMA for most of these, then converting to fp32 for the accumulators?

fp16*fp16->f32 MMA is slower than fp16*fp16->fp16 on Geforce (in addition to increasing register usage, which may require some retuning), so this change as-is is likely to cause perf regressions.

@0cc4m (Collaborator, Author) commented May 18, 2025

fp16*fp16->f32 MMA is slower than fp16*fp16->fp16 on Geforce (in addition to increasing register usage, which may require some retuning), so this change as-is is likely to cause perf regressions.

This change just allows f32 accumulation if requested, it doesn't enforce it.

@jeffbolznv (Collaborator)

I was concerned that we'd be using the f32 path much more often. But I looked at a few models and it seems like a small number of layers use f32, the ones that do are relatively small, and they go away with FA enabled. So I'm less concerned about performance now.

So far I haven't been able to repro the remaining issues with fp16 accumulators. Which exact model are you using and what command line?

@JohannesGaessler (Collaborator)

Isn't CUDA using int8 MMA for most of these, then converting to fp32 for the accumulators?

Integer accumulators are used within a block of quantized data, then the floating point scales are applied to convert the integers to FP32, then those FP32 partial results are added to the FP32 accumulators of the output tiles.
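
A scalar sketch of that accumulation order (an illustration only, not the CUDA kernel; the block layout is a simplified stand-in for ggml's q4_0 paired with a q8-quantized activation row):

```cpp
#include <cstdint>

struct block_q4_0_like {          // simplified stand-in for ggml's block_q4_0
    float   d;                    // per-block fp32 scale
    uint8_t qs[16];               // 32 4-bit weights, two per byte, stored with +8 offset
};

// One output element: integer math inside each block, one scale multiply per
// block, then fp32 accumulation across blocks (the pattern described above).
float dot_row(const block_q4_0_like * x, const int8_t * y_q, const float * y_d, int n_blocks) {
    float acc = 0.0f;                          // fp32 accumulator of the output tile
    for (int b = 0; b < n_blocks; ++b) {
        int32_t isum = 0;                      // integer accumulator within the block
        for (int i = 0; i < 16; ++i) {
            const int x0 = (x[b].qs[i] & 0x0F) - 8;
            const int x1 = (x[b].qs[i] >>   4) - 8;
            isum += x0 * y_q[b*32 + i] + x1 * y_q[b*32 + i + 16];
        }
        // apply the floating point scales, producing an fp32 partial result
        acc += x[b].d * y_d[b] * (float) isum;
    }
    return acc;
}
```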

@0cc4m (Collaborator, Author) commented May 18, 2025

I've been using a q4_0 quant of GLM4-32B. I didn't set any special parameters apart from offloading fully. I tested with and without flash attention, and the output was always GGGGGGGGGGG when the prompt was large enough to trigger the matmul shaders.

@jeffbolznv (Collaborator)

With your change applied, I haven't been able to repro any corruption using the Q4_0 or Q4_K model. I've tried scalar/coopmat1/coopmat2 paths. My command line is: llama-cli -no-cnv -p "The Peninsular War (1807–1814) was fought in the Iberian Peninsula by Portugal, Spain and the United Kingdom against the invading and occupying forces of the First French Empire during the Napoleonic Wars." -c 2048 -n 150 --ignore-eos -ngl 99 --seed 0 -m C:\models\GLM-4-32B-0414-Q4_0.gguf

@0cc4m (Collaborator, Author) commented May 18, 2025

The prompt is too short. Not sure why exactly, but it only triggers with long prompts.

@0cc4m (Collaborator, Author) commented May 18, 2025

For me it starts somewhere between 300 and 380 tokens.

@jeffbolznv (Collaborator)

I was able to repro with a longer prompt. IMO the right thing to do about it is to mark the mul_mat as f32 for this model. It's the mul_mat at the top of llm_graph_context::build_lora_mm.

I think Metal always uses fp32 accumulators, and cuda uses the int8 path with fp32 accumulators, so I think it's expected that this is only affecting Vulkan?

@0cc4m (Collaborator, Author) commented May 18, 2025

Generally, setting the first mul_mat in build_lora_mm to F32 works. This is similar to #13101, but much broader. Is there a way to tell whether all of them are required or not? I'm not really familiar with how models are set up and what the specific functions in llama-graph do.
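
A hedged sketch of that marking (the helper name and signature are illustrative, not the actual llama.cpp diff; ggml_mul_mat_set_prec and GGML_PREC_F32 are the existing ggml API referenced in this thread):

```cpp
#include "ggml.h"

// Build the matmul as usual, then request fp32 accumulation for that node.
// With this PR, the Vulkan quantized mul_mat path honors the request instead
// of always using fp16 accumulators.
static struct ggml_tensor * mul_mat_f32_acc(
        struct ggml_context * ctx,
        struct ggml_tensor  * w,      // quantized weight, e.g. attn_output.weight
        struct ggml_tensor  * cur) {  // fp32 activations
    struct ggml_tensor * out = ggml_mul_mat(ctx, w, cur);
    ggml_mul_mat_set_prec(out, GGML_PREC_F32);
    return out;
}
```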

@jeffbolznv (Collaborator) left a comment

This change LGTM. We can try to change the precision for GLM4 in a separate change.

@ggerganov (Member) commented May 19, 2025

I think the correct way to think about it is that the CPU implementation is the reference, and it uses F32 accumulators for all operations. So the default for the backends would be to also use F32 accumulators. We should probably add an option in ggml to optionally set F16 precision/range by extending enum ggml_prec. But this is not very urgent, because we can see that using F16 accumulators breaks a lot of models.

@0cc4m (Collaborator, Author) commented May 19, 2025

That is not a good idea for performance reasons. CPUs prefer FP32 calculations, GPUs do not. As Jeff Bolz said:

fp16*fp16->f32 MMA is slower than fp16*fp16->fp16 on Geforce (in addition to increasing register usage, which may require some retuning), so this change as-is is likely to cause perf regressions.

Isn't that how the current precision selection was created? GGML_PREC_DEFAULT for "lower precision is fine" and GGML_PREC_F32 for when 32-bit float precision is needed. That seems sensible: backends that always use FP32 accumulators can just ignore the selector, and others have to take it into consideration.

@0cc4m marked this pull request as ready for review on May 19, 2025 06:53
@ggerganov (Member) commented May 19, 2025

Isn't that how the current precision selection was created?

If I remember correctly, yes it was. But I guess the primary concern should be to have correct results. Performance considerations come after that.

That is not a good idea for performance reasons. CPUs prefer FP32 calculations, GPUs do not. As Jeff Bolz said:

fp16*fp16->f32 MMA is slower than fp16*fp16->fp16 on Geforce (in addition to increasing register usage, which may require some retuning), so this change as-is is likely to cause perf regressions.

Indeed, even the Metal backend can be noticeably faster with F16 accumulators (see #10220). But this performance comes with some quality loss as can be seen by the PPL figures there.

That seems sensible, backends that always use FP32 accumulators can just ignore the selector and others have to take it into consideration.

Thinking some more about this, if an LLM was trained with BF16 range then the best thing to do technically is to run all computations with BF16. So maybe the correct logic is:

  • All ggml backends work with FP32 accumulators by default
  • Backends can support different accumulators (e.g. FP16, BF16, FP32)
  • User code (for example llama.cpp) can recommend an FP type based on the model / user settings (e.g. ggml_..._set_prec(FP16/BF16/...))
  • If the backend supports the recommended FP type, it uses it. Otherwise it falls back to another compatible type or to FP32

Btw, I'm just thinking out loud here - not sure if this makes complete sense, so feel free to object.
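
To make the proposal concrete, here is a hypothetical sketch of the extension (none of these names exist in ggml today beyond the spirit of GGML_PREC_DEFAULT/GGML_PREC_F32; it only illustrates the "recommend a type, fall back to FP32" logic):

```cpp
// Hypothetical extension of ggml_prec, for illustration only.
enum ggml_prec_ext {
    GGML_PREC_EXT_DEFAULT,   // backend decides; per the proposal, fp32 accumulators
    GGML_PREC_EXT_F16,       // fp16 accumulators acceptable for this node
    GGML_PREC_EXT_BF16,      // bf16 range required / bf16 accumulators acceptable
    GGML_PREC_EXT_F32,       // full fp32 accumulation required
};

// Backend-side selection: honor the recommendation if supported, otherwise
// fall back to a compatible type or to fp32.
static enum ggml_prec_ext pick_accumulator(enum ggml_prec_ext requested,
                                           bool has_f16_acc, bool has_bf16_acc) {
    switch (requested) {
        case GGML_PREC_EXT_F16:  return has_f16_acc  ? GGML_PREC_EXT_F16  : GGML_PREC_EXT_F32;
        case GGML_PREC_EXT_BF16: return has_bf16_acc ? GGML_PREC_EXT_BF16 : GGML_PREC_EXT_F32;
        default:                 return GGML_PREC_EXT_F32;
    }
}
```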

@LostRuins (Collaborator)

I wonder if it's possible to know the dynamic range of the values for each model beforehand.

Model might be trained with bf16, but if it turns out everything fits within -100.0 to 100.0 then it'll look kinda silly to enforce bf16.

Practical example: TAESD weight values all fit within the range of -5.0 to 5.0.
I can use fp8 e3m4 for them to great effect and with minimal degradation.
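
A small sketch of that idea (purely illustrative; not an existing ggml/llama.cpp feature): scan a dequantized weight tensor once and record whether its values fit in fp16 range. Note that in this thread the overflow comes from intermediate matmul results rather than the weights themselves, so a weight-only scan would be a heuristic at best.

```cpp
#include <cmath>
#include <cstddef>

struct range_info {
    float max_abs      = 0.0f;
    bool  fits_in_fp16 = true;   // max finite fp16 value is 65504
};

// Scan a dequantized tensor and report its dynamic range.
range_info scan_range(const float * data, size_t n) {
    range_info info;
    for (size_t i = 0; i < n; ++i) {
        const float a = std::fabs(data[i]);
        if (a > info.max_abs) info.max_abs = a;
    }
    info.fits_in_fp16 = info.max_abs <= 65504.0f;
    return info;
}
```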

@jeffbolznv (Collaborator)

fp16 seems to be working pretty well - other than attention, GLM4 is the only model that needs to force fp32 accumulators? IMO it's unjustified to alter the defaults due to one model.

The slowdown for using fp32 accumulators on Geforce is 2x (not even counting increase in register usage), much larger than you saw in Metal. If we had to default to fp32, we'd probably need to switch to int8 MMA like CUDA and throw away all existing investment/tuning in the current shaders. And other backends would all be faced with similar questions.

@0cc4m (Collaborator, Author) commented May 19, 2025

I haven't seen noticeable effects from small drops in perplexity, so I think performance is more important. As long as we properly mark the tensors that require F32 accumulation, I think the current system is fine.

@0cc4m merged commit 8960efd into master on May 19, 2025
45 of 46 checks passed
@0cc4m deleted the 0cc4m/vulkan-mmq-fp32-acc-glm4 branch on May 19, 2025 15:54