Skip to content

Commit 39029a4

Browse files
committed
update embedding replace pattern
Signed-off-by: cascade812 <[email protected]>
1 parent 09caae6 commit 39029a4

File tree

1 file changed

+4 / -4 lines changed

1 file changed

+4
-4
lines changed

vllm/compilation/collective_fusion.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -65,20 +65,20 @@ def replace_with_embedding_reduce_scatter_rmsnorm(
6565
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
6666
where, dim=0, world_size=tp_size, group_name=tp.unique_name)
6767

68-
# rmsnorm_result = torch.empty_like(reduce_scatter)
68+
rmsnorm_result = torch.empty_like(reduce_scatter)
6969
rmsnorm = torch.ops.higher_order.auto_functionalized(
7070
torch.ops._C.rms_norm.default,
71-
result=permute,
71+
result=rmsnorm_result,
7272
input=reduce_scatter,
7373
weight=arg3_1,
7474
epsilon=1e-5)
7575

76-
all_gather = torch.ops.vllm.all_gather.default(reduce_scatter,
76+
all_gather = torch.ops.vllm.all_gather.default(rmsnorm[1],
7777
dim=0,
7878
world_size=tp_size,
7979
group_name=tp.unique_name)
8080

81-
return rmsnorm[1], all_gather
81+
return all_gather, reduce_scatter
8282

8383

8484
def search_gemm_allreduce_rmsnorm(

0 commit comments

Comments
 (0)