
[ROCm] (Deprecated) Enable AITER Tkw1 kernel #16418


Closed
wants to merge 23 commits

Conversation

tjtanaa
Contributor

@tjtanaa tjtanaa commented Apr 10, 2025

NOTE: This PR is deprecated, as it is being broken down into two PRs; the first PR below has to be closed first:

  1. [ROCm] Add aiter tkw1 kernel for Llama4 fp8 #16727
  2. [FEAT] [ROCm]: AITER Fused MOE V1 Support #16752

Description

This is a PR to enable "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8" on ROCm.

The incoherent-output issue described below has been resolved.

Progress

  • V0 eager mode
    • The output of the model is incoherent. Find the steps to reproduce below.
  • V0 hipgraph mode
    • The output of the model is incoherent. Find the steps to reproduce below.
  • V1 torch.compile eager mode
    • The output of the model is incoherent. Find the steps to reproduce below.
  • V1 torch.compile graph mode
    • The output of the model is incoherent. Find the steps to reproduce below.

Steps to reproduce

  1. Install AITER at commit fd04da:
git clone --recursive https://github.com/ROCm/aiter.git
cd aiter
git checkout fd04da
python3 setup.py develop
# for quicker compilation of the AITER kernels
# you can run the following aiter unittests first
python3 op_tests/test_moe_sorting.py
python3 op_tests/test_moe_tkw1.py
python3 op_tests/test_moe.py

  2. Use this test script:
#example.py

from vllm import LLM, SamplingParams

def test():

    prompts = [
        "The color of the sky is blue but sometimes it can also be",
        "The capital of France is",
        "What is batch inference?"
    ]
    sampling_params = SamplingParams(temperature=0.6,
                                     top_p=0.1,
                                     max_tokens=256)
    llm = LLM(
        model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        tensor_parallel_size=8,
        max_model_len=1024,
        gpu_memory_utilization=0.7,
        enforce_eager=True,
    )

    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    test()
  3. Command to run the script:
#!/bin/bash
HF_TOKEN=<your-hf-token> \
VLLM_USE_V1=1 \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
VLLM_ROCM_FP8_PADDING=0 \
VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_USE_AITER_MOE=1 \
VLLM_ROCM_USE_AITER_FP8_CHANNEL_SCALED_MOE=1  \
VLLM_ROCM_USE_AITER_RMSNORM=0 \
VLLM_ROCM_USE_AITER_LINEAR=0 \
SAFETENSORS_FAST_GPU=1 \
python example.py > log6.txt
  4. Example output from the model:
Prompt: 'The color of the sky is blue but sometimes it can also be', 
Generated text: ' seen in different colors. What is the reason behind this phenomenon?\nThe color of the sky is primarily determined by the scattering of light by tiny particles in the atmosphere, known as aerosols. The most common color of the sky is blue, which is due to the scattering of Rayleigh blue light. However, the color of the sky can change depending on various factors.\n\n1. **Atmospheric Conditions**: The presence and concentration of aerosols, such as dust, water vapor, and pollutants, can affect the color of the sky. For example, a high concentration of aerosols can cause the sky to appear more gray or brown.\n2. **Sunlight and Angle of View**: The angle at which the sun is viewed and the amount of sunlight can also influence the color of the sky. During sunrise and sunset, the sky can appear orange or red due to the scattering of longer-wavelength light.\n3. **Weather Conditions**: Weather phenomena like clouds, fog, and haze can alter the color of the sky. For instance, a cloudy sky can appear white or gray, while a foggy sky can appear brown or yellow.\n4. **Altitude and Location**: The color of the sky can vary with altitude and location. At higher altitudes, the sky can appear bluer due to the'

Prompt: 'The capital of France is', 
Generated text: ' is located in the midst of the country, North Africa -IELA- is called in short. BisA. Paris B. Paris C. London D. Ottowa\nWASHINGTON, P. 5. Which of the following best parapines the best the correct option or combination of the following Paragraph, : the question and choose the correct read more and nean of the followingm the given options 1. Which of the following is synory of the following 1. The capital of France is located in is called - -- Paris 2. The capital of France is located in is called in short 3. The capital of France is Paris 4. The capital of France is is located in the midst of France 5. The capital of France is located in is called in short IELTS Prep. com_PADDING[5] 3. The capital of France is Paris 4. The capital of France is is located in the midst of France 5. The capital of jUMBES 1 and 3 2 and 3 3 and 4 1 and 4 2 and 4 1 and 3 2 and 4 1 and 2\nThe even paratheses of the following and group 4'

Prompt: 'What is batch inference?', 
Generated text: " | Definition and new  - a5Marketingiaja\n داخ 1 - copy\nWhat is batch processing? Batch processing is aavosep? \nBatch processing is a exchange of information about a group of data sets. In this article we'll signing, we'll explore the concept of batch processing, its benefits, and how it can be applied in various industries.\nWhat is batch processing?\nBatch processing refers to the processing of a group of data, such as a batch of transactions or a set of records, in a single operation. This approach is often used in data processing, where a large dataset is divided into smaller batches, and each batch is processed separately.\nBenefits of batch processing\nBatch processing is an efficient way to process large data sets, as it reduces the need for manual processing and minimizes the risk of errors. Here are some benefits of batch processing:\n1. **Improved efficiency**: Batch processing allows for the processing of large data sets in a single operation, reducing the time and effort required for data processing.\n he or she is\n**Batch processing vs. real-time processing**\nBatch processing is often compared to real-time processing, where data is processed as servers. While batch processing is suitable for processing large data sets, real-time/batch processing is ideal for applications that require immediate"

Updates 12 Apr 2025

Running the V1 engine with HIP graph mode, torch.compile, and the full 1-million-token context length. The output of the model is incoherent. Find the steps to reproduce below.

#example.py

from vllm import LLM, SamplingParams

def test():

    prompts = [
        "The color of the sky is blue but sometimes it can also be",
        "The capital of France is",
        "What is batch inference?"
    ]
    sampling_params = SamplingParams(temperature=0.6,
                                     top_p=0.1,
                                     max_tokens=256)
    llm = LLM(
        model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        tensor_parallel_size=8,
        gpu_memory_utilization=0.9,
        enforce_eager=False,
    )

    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    test()

Command:

#!/bin/bash
HF_TOKEN=<Your-HF-Token> \
VLLM_USE_V1=1 \
VLLM_USE_TRITON_FLASH_ATTN=1 \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
VLLM_ROCM_FP8_PADDING=1 \
VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_USE_AITER_MOE=1 \
VLLM_ROCM_USE_AITER_FP8_CHANNEL_SCALED_MOE=1  \
VLLM_ROCM_USE_AITER_RMSNORM=0 \
VLLM_ROCM_USE_AITER_LINEAR=0 \
SAFETENSORS_FAST_GPU=1 \
python example.py > aiterfp8_1_v1.txt

Output

Prompt: 'The color of the sky is blue but sometimes it can also be', 
Generated text: ', for example, red during a sunset. What can we conclude about the color of the sky? (a) It is sometimes blue. (b) It is always blue. (c) It is always either blue or red. (d) It is either blue or red. (e) It is sometimes either blue or red.\nThe best answer is (d) It is either blue or red.\nThe color of the sky is blue but sometimes it can also be, for example, red during a sunset. What can we conclude about the color of the sky? (a) It is sometimes blue. (b) It is always blue. (c) It is always either blue or red. (d) It is either blue or red. (e) It is sometimes either blue or red. The best answer is (d) It is either blue or red. \nThe best answer is (d) It is either blue or red. The statement "The color of the sky is blue but sometimes it can also be, for example, red during a sunset" implies that the sky can be blue and it can also be red at certain times (like during a sunset). This indicates that the color of the sky is not limited to just being blue; it'

Prompt: 'The capital of France is', 
Generated text: ' Paris. The capital of Germany is Berlin. The capital of Italy is Rome. The capital of Spain is Madrid. The capital of Portugal is Lisbon. The capital of the United Kingdom is London. The capital of the United States is Washington. The capital of Canada is Ottawa. The capital of Australia is Sydney. The capital of New Zealand is Auckland. The capital of South Africa is Pretoria. The capital of China is Beijing. The capital of Japan is Tokyo. The capital of India is New Delhi. The capital of Pakistan is Islamabad. The capital of Bangladesh is Dhaka. The capital of Sri Lanka is Colombo. The capital of Thailand is Bangkok. The capital of Vietnam is Hanoi. The capital of Cambodia is Phnompheny. The capital of Laos is Luangpou. The capital of Myanmar is Yangpou. The capital of Malaysia is Kuala Lumpur. The capital of Singapore is Singapore. The capital of Brunei is Brunei. The capital of Indonesia is Jakarta. The capital of Philippines is Manila. The capital of Hong Kong is Hong Kong. The capital of Macau is Macau. The capital of Taiwan is Taipei. The capital of North Korea is Pyongpou. The capital of South Korea is Seoul. The capital of Russia is Moscow. The capital of Ukraine,'

Prompt: 'What is batch inference?', 
Generated text: ' | 000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000057747000000000576470000576470qmsw000000000000000000000000000000000000000000000000000576949000000000000000000000000000576470000000000000000576709000576709000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000576709000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000'

llama4-fp8

2025-04-19:15:41:04 INFO [loggers.evaluation_tracker:272] Output path not provided, skipping saving results aggregated
vllm (pretrained=meta-llama/Llama-4-Maverick-17B-128E-Instruct,tensor_parallel_size=8,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

| Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr   |
|-------|---------|------------------|--------|-------------|--------|----------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.9303 | ± 0.0070 |
| gsm8k | 3       | strict-match     | 5      | exact_match | 0.9318 | ± 0.0069 |
VLLM_USE_V1=1 \
VLLM_USE_TRITON_FLASH_ATTN=1 \
VLLM_ROCM_USE_AITER_FP8_TKW1_MOE=1 \
VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_USE_AITER_LINEAR=0 \
VLLM_ROCM_USE_AITER_RMSNORM=0 \
lm_eval --model vllm --model_args pretrained=meta-llama/Llama-4-Maverick-17B-128E-Instruct,tensor_parallel_size=8,max_model_len=10000 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto \
> gsm8k-meta-llama_Llama-4-Maverick-17B-128E-Instruct-v1-llama4-fp8.log 2>&1

llama4-fp8

vllm (pretrained=meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8,tensor_parallel_size=8,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

| Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr   |
|-------|---------|------------------|--------|-------------|--------|----------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.9257 | ± 0.0072 |
| gsm8k | 3       | strict-match     | 5      | exact_match | 0.9295 | ± 0.0071 |
VLLM_USE_V1=1 \
VLLM_USE_TRITON_FLASH_ATTN=1 \
VLLM_ROCM_USE_AITER_FP8_TKW1_MOE=1 \
VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_USE_AITER_LINEAR=0 \
VLLM_ROCM_USE_AITER_RMSNORM=0 \
lm_eval --model vllm --model_args pretrained=meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8,tensor_parallel_size=8,max_model_len=10000 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto \
> gsm8k-meta-llama_Llama-4-Maverick-17B-128E-Instruct-FP8-v1-llama4-fp8.log 2>&1

ShareGPT Dataset

| Metric                          | V1 no AITER eager | V1 no AITER torch.compile | % Gain vs Eager | V1 AITER torch.compile | % Gain vs Eager |
|---------------------------------|-------------------|---------------------------|-----------------|------------------------|-----------------|
| Benchmark duration (s)          | 85.89             | 54.61                     | 36.42%          | 46.01                  | 46.44%          |
| Request throughput (req/s)      | 11.64             | 18.31                     | 57.30%          | 21.74                  | 86.76%          |
| Request goodput (req/s)         | 11.1              | 18.11                     | 63.24%          | 21.54                  | 94.14%          |
| Output token throughput (tok/s) | 2255.57           | 3373.53                   | 49.52%          | 4018.06                | 78.11%          |
| Total token throughput (tok/s)  | 4764.04           | 7319.07                   | 53.65%          | 8701.11                | 82.59%          |
| Mean TTFT (ms)                  | 118.07            | 117.3                     | 0.65%           | 101.97                 | 13.63%          |
| Median TTFT (ms)                | 75.39             | 64.56                     | 14.37%          | 59.12                  | 21.59%          |
| P99 TTFT (ms)                   | 681.25            | 907.69                    | −33.23%         | 752.99                 | −10.55%         |
| Mean TPOT (ms)                  | 39.58             | 27.68                     | 30.05%          | 22.81                  | 42.37%          |
| Median TPOT (ms)                | 36.32             | 27.1                      | 25.39%          | 22.29                  | 38.64%          |
| P99 TPOT (ms)                   | 128.1             | 54.38                     | 57.56%          | 44.54                  | 65.22%          |
| Mean ITL (ms)                   | 36.96             | 26.25                     | 28.97%          | 21.7                   | 41.27%          |
| Median ITL (ms)                 | 33.76             | 21.4                      | 36.61%          | 18.33                  | 45.70%          |
| P99 ITL (ms)                    | 48.12             | 59.73                     | −24.17%         | 53.88                  | −11.97%         |


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀


mergify bot commented Apr 10, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @tjtanaa.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@hongxiayang
Collaborator

cc @houseroad: can you help verify the correctness of the integration?

@zjing14

zjing14 commented Apr 10, 2025

@tjtanaa Looks good to me.

Comment on lines 477 to 494
elif use_fp8_w8a8:
    return torch.ops.vllm.rocm_aiter_asm_moe(hidden_states=hidden_states,
                                             w1=w1,
                                             w2=w2,
                                             topk_weight=topk_weights,
                                             topk_ids=topk_ids,
                                             fc1_scale=w1_scale,
                                             fc2_scale=w2_scale,
                                             fc1_smooth_scale=None,
                                             fc2_smooth_scale=None,
                                             a16=False,
                                             activation=activation)

return torch.ops.vllm.rocm_aiter_ck_moe(hidden_states=hidden_states,
                                        w1=w1,
                                        w2=w2,
                                        topk_weights=topk_weights,
                                        topk_ids=topk_ids)
@sijiac (Contributor) commented Apr 10, 2025

For all branches except rocm_aiter_asm_moe_tkw1, topk_weights should be applied to each token beforehand, and a dummy topk_weights should be passed to the kernel. For example:

hidden_states = hidden_states * topk_weights

aiter_xxx_moe(...,
              topk_weights=torch.ones_like(topk_weights))

Since _tkw1 is a customized kernel, you can directly pass the actual top-k weights.
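Putting the suggestion together, a minimal sketch of this dispatch, assuming the hypothetical names apply_router_weights_before_moe and use_tkw1 (the actual call site lives in vLLM's fused-MoE path):

import torch

def apply_router_weights_before_moe(
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        use_tkw1: bool) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (hidden_states, topk_weights) to feed into the MoE kernel."""
    if use_tkw1:
        # rocm_aiter_asm_moe_tkw1 applies the router weights internally,
        # so pass the real top-k weights through untouched.
        return hidden_states, topk_weights
    # All other branches: fold the router weights into the activations
    # per token and hand the kernel a dummy all-ones weight tensor.
    # Note the broadcast only works for topk == 1 ([num_tokens, 1]).
    hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
    return hidden_states, torch.ones_like(topk_weights)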

Contributor:

You can reference this line from the nv branch:

if apply_router_weight_on_input:
    assert topk == 1, \
        "apply_router_weight_on_input is only implemented for topk=1"
    # TODO: this only works for topK=1, will need to update for topK>1
    a = a * topk_weights.to(out_dtype)

@tjtanaa (Contributor, Author) commented Apr 11, 2025

@sijiac (I will apply your approach for BF16)

In our current setup, the other branches are never used; only rocm_aiter_asm_moe_tkw1 is invoked. The output is still incoherent and garbled.

Our environment variable setup:

#!/bin/bash
HF_TOKEN=<your-hf-token> \
VLLM_USE_V1=1 \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
VLLM_ROCM_FP8_PADDING=0 \
VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_USE_AITER_MOE=1 \
VLLM_ROCM_USE_AITER_FP8_CHANNEL_SCALED_MOE=1 \
VLLM_ROCM_USE_AITER_RMSNORM=0 \
VLLM_ROCM_USE_AITER_LINEAR=0 \
SAFETENSORS_FAST_GPU=1 \
python example.py > log6.txt

Contributor:

For the BF16 MoE layer, which branch/kernel do you use? Do you go with the Triton fmoe path?

Contributor (Author):

@sijiac For the BF16 MoE layer, we are using the Triton fmoe path.

Contributor:

Thanks @tjtanaa for the update! The use of bf16 topk_weights can cause a silent numeric collapse; I noticed this issue previously in the BF16 PR as well. To avoid it, we should add a dtype assertion in the sorting or fmoe kernel on the AITER side.

Are we now good to proceed with FP8 checkpoint support? After resolving this numeric issue, do we anticipate any other blockers for FP8 routed experts?

Contributor (Author):

@sijiac The topk_weights are cast to bfloat16, the same dtype as the hidden_states, as shown in https://github.com/EmbeddedLLM/vllm/blob/449bdaf5a2ad4fbe0087fc69a939250478bf79b4/vllm/model_executor/models/llama4.py#L54

Since all the AITER MoE kernels expect topk_weights to be float32, we will cast it explicitly in rocm_aiter_fused_moe.py. This logic should be compatible with all other models.

We think we can proceed with FP8 checkpoint support. If the FP8 routed experts are going to use the TKW1 kernel, we don't think there are any other blockers after adding FP8 checkpoint support.
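A minimal sketch of that explicit cast; the helper name _ensure_fp32_topk_weights and its exact placement inside rocm_aiter_fused_moe.py are assumptions:

import torch

def _ensure_fp32_topk_weights(topk_weights: torch.Tensor) -> torch.Tensor:
    # Llama4's routing casts topk_weights to the hidden_states dtype
    # (bfloat16), but the AITER MoE kernels expect float32, so cast
    # back before dispatching to any AITER kernel.
    if topk_weights.dtype != torch.float32:
        topk_weights = topk_weights.to(torch.float32)
    return topk_weights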

Contributor:

Since all the AITER MoE kernels expect topk_weights to be float32, we will cast it explicitly in rocm_aiter_fused_moe.py. This logic should be compatible with all other models.

The cast in rocm_aiter_fused_moe is OK. I mean the kernel should not take bf16 input and give a wrong result; it should trigger an assertion failure if bf16 is not supported.
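A minimal sketch of such a guard on the AITER side; the function name and its placement in the moe-sorting/fmoe entry point are assumptions:

import torch

def _check_topk_weights_dtype(topk_weights: torch.Tensor) -> None:
    # Fail loudly instead of silently computing garbage with bf16 weights.
    assert topk_weights.dtype == torch.float32, (
        "AITER fused MoE expects float32 topk_weights, got "
        f"{topk_weights.dtype}")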

Contributor (Author):

@sijiac Does "FP8 routed experts" refer to Llama4MoE.custom_routing_function or to the rocm_aiter_fused_moe feature? Is it about whether AITER can support topk > 1 (num_experts_per_tok > 1) in rocm_aiter_fused_moe?

Contributor:

It should be rocm_aiter_fused_moe.

Is it about whether AITER can support topk > 1 (num_experts_per_tok > 1) in rocm_aiter_fused_moe?

It doesn't matter. We don't have a use case for topk > 1 in Llama models at the moment.

tjtanaa added 3 commits April 11, 2025 05:13
Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: tjtanaa <[email protected]>
@hongxiayang
Collaborator

#16727

@tjtanaa tjtanaa changed the title [ROCm] (WIP) Enable AITER Tkw1 kernel [ROCm] (Deprecated) Enable AITER Tkw1 kernel Apr 17, 2025

mergify bot commented Apr 19, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @tjtanaa.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 19, 2025
@tjtanaa
Contributor Author

tjtanaa commented Apr 28, 2025

Closed as completed by #16727 and #16752.

@tjtanaa tjtanaa closed this Apr 28, 2025