[ROCm] (Deprecated) Enable AITER Tkw1 kernel #16418
Conversation
Co-authored-by: Hongxia Yang <[email protected]> Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: kliuae <[email protected]>
Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: vllmellm <[email protected]>
Signed-off-by: vllmellm <[email protected]>
Signed-off-by: vllmellm <[email protected]>
Signed-off-by: vllmellm <[email protected]>
Co-authored-by: tjtanaa [email protected] Signed-off-by: vllmellm <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a subset of checks runs. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀
This pull request has merge conflicts that must be resolved before it can be merged.
cc @houseroad: can you help verify the correctness of the integration?
Signed-off-by: tjtanaa <[email protected]>
@tjtanaa Looks good to me.
```python
elif use_fp8_w8a8:
    return torch.ops.vllm.rocm_aiter_asm_moe(hidden_states=hidden_states,
                                             w1=w1,
                                             w2=w2,
                                             topk_weight=topk_weights,
                                             topk_ids=topk_ids,
                                             fc1_scale=w1_scale,
                                             fc2_scale=w2_scale,
                                             fc1_smooth_scale=None,
                                             fc2_smooth_scale=None,
                                             a16=False,
                                             activation=activation)

return torch.ops.vllm.rocm_aiter_ck_moe(hidden_states=hidden_states,
                                        w1=w1,
                                        w2=w2,
                                        topk_weights=topk_weights,
                                        topk_ids=topk_ids)
```
For all branches except rocm_aiter_asm_moe_tkw1, topk_weights should be applied to each token, and a dummy topk_weights input passed to the kernel instead. For example:

```python
hidden_states = hidden_states * topk_weights
aiter_xxx_moe(...,
              topk_weights=torch.ones_like(topk_weights),
              )
```

Since _tkw1 is a customized kernel, you can directly pass the actual top-k weights.
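If it helps to picture that suggestion, here is a minimal sketch of the pre-scaling for the non-tkw1 branches; the helper name and its placement are illustrative assumptions, not code from this PR:

```python
import torch


def apply_router_weight_on_input(hidden_states: torch.Tensor,
                                 topk_weights: torch.Tensor):
    # Illustrative: pre-scale each token by its routing weight (valid for
    # topk == 1) and hand the MoE kernel all-ones weights so the weighting
    # is not applied a second time inside the kernel.
    assert topk_weights.shape[-1] == 1, \
        "pre-applying router weights assumes topk == 1"
    hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
    return hidden_states, torch.ones_like(topk_weights)
```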
You can reference these lines from the nv branch:
vllm/vllm/model_executor/layers/fused_moe/cutlass_moe.py, lines 103 to 107 at cb391d8:

```python
if apply_router_weight_on_input:
    assert topk == 1, \
        "apply_router_weight_on_input is only implemented for topk=1"
    # TODO: this only works for topK=1, will need to update for topK>1
    a = a * topk_weights.to(out_dtype)
```
@sijiac (I will apply your approach for BF16.)
In our current setup, the other branches are never used; only rocm_aiter_asm_moe_tkw1 is being invoked. The output is still incoherent and garbled.
Our environment variable setup:

```bash
#!/bin/bash
HF_TOKEN=<your-hf-token> \
VLLM_USE_V1=1 \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
VLLM_ROCM_FP8_PADDING=0 \
VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_USE_AITER_MOE=1 \
VLLM_ROCM_USE_AITER_FP8_CHANNEL_SCALED_MOE=1 \
VLLM_ROCM_USE_AITER_RMSNORM=0 \
VLLM_ROCM_USE_AITER_LINEAR=0 \
SAFETENSORS_FAST_GPU=1 \
python example.py > log6.txt
```
For the BF16 MoE layer, which branch/kernel do you use? Do you go with the Triton Fmoe path?
@sijiac For the BF16 MoE layer, we are using the Triton Fmoe path.
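For readers following the thread, a rough sketch of the gating being described, using the flag names from the environment setup above; the function and its exact logic are illustrative assumptions, not the actual vLLM dispatcher:

```python
import os


def use_aiter_tkw1_moe(is_fp8_w8a8_per_channel: bool) -> bool:
    # Illustrative: the tkw1 asm kernel is only considered for FP8
    # per-channel MoE when the AITER flags are enabled; the BF16 MoE layer
    # stays on the Triton fused-MoE path.
    aiter_on = os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"
    aiter_moe_on = os.environ.get("VLLM_ROCM_USE_AITER_MOE", "0") == "1"
    tkw1_on = os.environ.get(
        "VLLM_ROCM_USE_AITER_FP8_CHANNEL_SCALED_MOE", "0") == "1"
    return aiter_on and aiter_moe_on and tkw1_on and is_fp8_w8a8_per_channel
```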
Thanks @tjtanaa for the update! The use of topk_weights in bf16 can cause silent numeric collapse; I noticed this issue previously in the BF16 PR as well. To avoid this, we should add a dtype assert check in the sorting or fmoe kernel on the AITER side.
Are we now good to proceed with FP8 checkpoint support? After resolving this numeric issue, do we anticipate any other blockers for FP8 routed experts?
@sijiac The datatype is cast to bfloat16 (the same datatype as the hidden_states) here, as shown in https://github.com/EmbeddedLLM/vllm/blob/449bdaf5a2ad4fbe0087fc69a939250478bf79b4/vllm/model_executor/models/llama4.py#L54
So, since all the AITER MoE kernels expect topk_weights to be float32, we will cast it explicitly in rocm_aiter_fused_moe.py. This logic should be compatible with all other models.
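A minimal sketch of that explicit cast, assuming it sits near the entry of the dispatch in rocm_aiter_fused_moe.py (the helper name and exact placement are illustrative):

```python
import torch


def ensure_fp32_topk_weights(topk_weights: torch.Tensor) -> torch.Tensor:
    # The AITER MoE kernels expect float32 routing weights; cast explicitly
    # so callers that produce bfloat16 weights (e.g. Llama 4's routing)
    # still get correct results.
    if topk_weights.dtype != torch.float32:
        topk_weights = topk_weights.to(torch.float32)
    return topk_weights
```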
We think we can proceed with FP8 checkpoint support. If the FP8 routed experts are going to use the TKW1 kernel, we don't think there are any other blockers after adding this FP8 checkpoint support.
> So, since all the AITER MoE kernels expect topk_weights to be float32, we will cast it explicitly in rocm_aiter_fused_moe.py. This logic should be compatible with all other models.

The cast in rocm_aiter_fused_moe is OK. I mean the kernel should not take bf16 input and give a wrong result; it should trigger an assertion failure if bf16 is not supported.
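A minimal sketch of the kind of guard being asked for, placed at a Python entry point in front of the kernel; the wrapper name and location are assumptions:

```python
import torch


def check_topk_weights_dtype(topk_weights: torch.Tensor) -> None:
    # Fail loudly instead of silently producing garbage when the routing
    # weights arrive in an unsupported dtype such as bfloat16.
    assert topk_weights.dtype == torch.float32, (
        f"AITER MoE expects float32 topk_weights, got {topk_weights.dtype}")
```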
@sijiac Does FP8 routed experts refer to Llama4MoE.custom_routing_function or to the rocm_aiter_fused_moe feature? Is it about whether AITER can support topk > 1 (num_experts_per_tok > 1) in rocm_aiter_fused_moe?
It should be rocm_aiter_fused_moe.

> Is it about whether AITER can support topk > 1 (num_experts_per_tok > 1) in rocm_aiter_fused_moe?

It doesn't matter. We don't have a use case for topk > 1 in Llama models at this moment.
Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: vllmellm <[email protected]>
Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: tjtanaa <[email protected]>
This pull request has merge conflicts that must be resolved before it can be merged.
NOTE: This PR is deprecated, as it is going to be broken down into two PRs; the first PR has to be closed first:
Description
This is a PR to enable "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8".
Issue has been resolved
Progress
The output of the model is incoherent. Find the steps to reproduce below.
Steps to reproduce
fd04da
Updates 12 Apr 2025
Running V1 Engine, HipGraph, torch.compile, full 1 million context length. The output of the model is incoherent. Find the steps to reproduce below.
Command:
Output
llama4-fp8
2025-04-19:15:41:04 INFO [loggers.evaluation_tracker:272] Output path not provided, skipping saving results aggregated
vllm (pretrained=meta-llama/Llama-4-Maverick-17B-128E-Instruct,tensor_parallel_size=8,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
llama4-fp8
vllm (pretrained=meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8,tensor_parallel_size=8,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
ShareGPT Dataset