
Commit 13cd0bd

Add async tp pass
Signed-off-by: cascade812 <[email protected]>
1 parent 246e3e0 commit 13cd0bd

File tree

7 files changed: 361 additions & 4 deletions


.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions
@@ -309,6 +309,7 @@ steps:
   - pytest -v -s compile/test_pass_manager.py
   - pytest -v -s compile/test_fusion.py
   - pytest -v -s compile/test_sequence_parallelism.py
+  - pytest -v -s compile/test_async_tp.py

 - label: PyTorch Fullgraph Smoke Test # 9min
   mirror_hardwares: [amdexperimental, amdproduction]

tests/compile/test_async_tp.py

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

import vllm.envs as envs
from vllm.compilation.collective_fusion import AsyncTPPass
from vllm.compilation.fx_utils import (find_specified_fn,
                                       find_specified_fn_maybe)
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
                         PassConfig, VllmConfig)
from vllm.distributed import (tensor_model_parallel_all_gather,
                              tensor_model_parallel_reduce_scatter)
from vllm.distributed.parallel_state import (init_distributed_environment,
                                             initialize_model_parallel)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables

from ..utils import multi_gpu_test
from .backend import TestBackend

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]


class TestMMRSModel(torch.nn.Module):

    def __init__(self, hidden_size=16):
        super().__init__()
        self.hidden_size = hidden_size
        self.gate_proj = torch.nn.Parameter(torch.empty(
            (self.hidden_size * 2, hidden_size)),
                                            requires_grad=False)
        # Initialize weights
        torch.nn.init.normal_(self.gate_proj, std=0.02)

    def forward(self, hidden_states):
        """Forward pass implementing the mm + reduce scatter in the FX graph"""
        # Reshape input
        view = hidden_states.reshape(-1, self.hidden_size)

        # matrix multiplication
        permute = self.gate_proj.permute(1, 0)
        mm = torch.mm(view, permute)
        reduce_scatter = tensor_model_parallel_reduce_scatter(mm, dim=0)
        return reduce_scatter

    def ops_in_model_before(self):
        return [torch.ops.vllm.reduce_scatter.default]

    def ops_in_model_after(self):
        return [torch.ops.symm_mem.fused_matmul_reduce_scatter.default]


class TestAGMMModel(torch.nn.Module):

    def __init__(self, hidden_size=16):
        super().__init__()
        self.hidden_size = hidden_size
        self.weight = torch.nn.Parameter(torch.empty(
            (hidden_size, hidden_size)),
                                         requires_grad=False)
        # Initialize weights
        torch.nn.init.normal_(self.weight, std=0.02)

    def forward(self, hidden_states):
        """Forward pass implementing the mm + all gather in the FX graph"""
        # Reshape input
        view = hidden_states.reshape(-1, self.hidden_size)
        all_gather = tensor_model_parallel_all_gather(view, dim=0)
        permute = self.weight.permute(1, 0)
        mm = torch.mm(all_gather, permute)
        return mm

    def ops_in_model_before(self):
        return [torch.ops.vllm.all_gather.default]

    def ops_in_model_after(self):
        return [torch.ops.symm_mem.fused_all_gather_matmul.default]


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("test_model", ["TestMMRSModel", "TestAGMMModel"])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
                    reason="Only test on CUDA")
def test_async_tp_pass(test_model: str, batch_size: int, seq_len: int,
                       hidden_size: int, dtype: torch.dtype):
    num_processes = 2

    def run_torch_spawn(fn, nprocs):
        # need to use torch.mp.spawn otherwise will have problems with
        # torch.distributed and cuda
        torch.multiprocessing.spawn(fn,
                                    args=(num_processes, test_model,
                                          batch_size, seq_len, hidden_size,
                                          dtype),
                                    nprocs=nprocs)

    run_torch_spawn(async_tp_pass_on_test_model, num_processes)


def async_tp_pass_on_test_model(local_rank: int, world_size: int,
                                test_model: str, batch_size: int,
                                seq_len: int, dtype: torch.dtype):
    current_platform.seed_everything(0)

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': '12345',
    })

    # initialize distributed
    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # configure vllm config for AsyncTPPass
    vllm_config = VllmConfig()
    vllm_config.compilation_config = CompilationConfig(
        pass_config=PassConfig(enable_async_tp=True))
    vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))

    # this is a fake model name to construct the model config
    # in the vllm_config, it's not really used.
    model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
    vllm_config.model_config = ModelConfig(model=model,
                                           task="auto",
                                           tokenizer=model,
                                           tokenizer_mode="auto",
                                           trust_remote_code=True,
                                           dtype=dtype,
                                           seed=42)

    async_tp_pass = AsyncTPPass(vllm_config)
    backend = TestBackend(async_tp_pass)

    if test_model == "TestMMRSModel":
        model = TestMMRSModel(hidden_size)
    elif test_model == "TestAGMMModel":
        model = TestAGMMModel(hidden_size)
    else:
        raise ValueError(f"Unknown model: {test_model}")

    hidden_states = torch.randn((batch_size * seq_len, hidden_size),
                                dtype=dtype,
                                requires_grad=False)

    compiled_model = torch.compile(model, backend=backend)
    compiled_model(hidden_states)

    # Check that the substitution worked
    pre_nodes = backend.graph_pre_pass.nodes
    post_nodes = backend.graph_post_pass.nodes

    # In the pre-pass graph the unfused collective (all_gather or
    # reduce_scatter) should exist, and the fused symm_mem op should not
    for op in model.ops_in_model_before():
        find_specified_fn(pre_nodes, op)
    for op in model.ops_in_model_after():
        assert find_specified_fn_maybe(pre_nodes, op) is None

    # In the post-pass graph, fused_matmul_reduce_scatter or
    # fused_all_gather_matmul should exist
    for op in model.ops_in_model_after():
        find_specified_fn(post_nodes, op)
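
Note: find_specified_fn and find_specified_fn_maybe, imported above from vllm.compilation.fx_utils, locate a node with a given call target in the traced FX graph. A minimal sketch of what the maybe-variant is assumed to do, for readers without the vLLM source at hand (illustrative, not the actual implementation):

from typing import Iterable, Optional

import torch.fx as fx


def find_specified_fn_maybe(nodes: Iterable[fx.Node],
                            op) -> Optional[fx.Node]:
    # Return the first call_function node whose target is `op`, else None.
    for node in nodes:
        if node.op == "call_function" and node.target == op:
            return node
    return None

find_specified_fn would then be the strict variant, expected to fail the test when no such node exists.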

vllm/compilation/collective_fusion.py

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple

import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

from vllm.config import VllmConfig
from vllm.distributed import get_tp_group
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.logger import init_logger

from .vllm_inductor_pass import VllmInductorPass

logger = init_logger(__name__)


class BasePattern:

    def __init__(self, dtype: torch.dtype, device: str):
        self.dtype = dtype
        self.device = device
        self.tp = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()


class GEMMReduceScatterPattern(BasePattern):

    def get_inputs(self):
        mul = torch.empty([16, 4], device=self.device, dtype=self.dtype)
        mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        return [mul, mm_weight]

    def register(self, pm_pass: PatternMatcherPass):

        def pattern(mul: torch.Tensor, mm_weight: torch.Tensor):
            mm = torch.ops.aten.mm.default(mul, mm_weight)
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                mm,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name)
            return reduce_scatter

        def replacement(mul: torch.Tensor, mm_weight: torch.Tensor):
            gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
                mul,
                mm_weight,
                "avg",
                scatter_dim=0,
                group_name=self.tp.device_group.group_name,
            )
            return gemm_rs

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)


class AllGatherGEMMPattern(BasePattern):

    def get_inputs(self):
        x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        return [x, weight]

    def register(self, pm_pass: PatternMatcherPass):

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            all_gather = torch.ops.vllm.all_gather.default(
                x,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name)

            return torch.ops.aten.mm.default(all_gather, weight)

        def replacement(
                x: torch.Tensor,
                weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
                x,
                [weight],
                gather_dim=0,
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)


class AsyncTPPass(VllmInductorPass):

    def __init__(self, config: VllmConfig):
        super().__init__(config)

        # Enable symmetric memory for the TP process group
        enable_symm_mem_for_group(get_tp_group().device_group.group_name)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="async_tp_pass")
        GEMMReduceScatterPattern(self.dtype,
                                 self.device).register(self.patterns)

        AllGatherGEMMPattern(self.dtype, self.device).register(self.patterns)

    def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
        # only do replace for specific shapes
        tp_size = get_tensor_model_parallel_world_size()
        return shape is not None and shape % tp_size == 0

    def __call__(self, graph: fx.Graph):
        self.begin()
        self.dump_graph(graph, "before_async_tp_pass")
        count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", count)
        self.dump_graph(graph, "after_async_tp_pass")
        self.end_and_log()
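
The pass is built on Inductor's pattern matcher: pm.register_replacement traces the pattern function with pm.fwd_only over the example inputs from get_inputs(), and PatternMatcherPass.apply then rewrites every structural match in the graph to the traced replacement. A self-contained toy sketch of the same mechanism on plain aten ops (fusing mm + add into addmm; the names here are illustrative, not part of this commit):

import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.fx.experimental.proxy_tensor import make_fx

toy_patterns = PatternMatcherPass(pass_name="toy_fusion_pass")


def pattern(x, w, b):
    return torch.ops.aten.add.Tensor(torch.ops.aten.mm.default(x, w), b)


def replacement(x, w, b):
    # addmm(b, x, w) computes b + x @ w in a single op
    return torch.ops.aten.addmm.default(b, x, w)


pm.register_replacement(
    pattern, replacement,
    [torch.empty(4, 4), torch.empty(4, 4), torch.empty(4)],  # example inputs
    pm.fwd_only, toy_patterns)


def fn(x, w, b):
    return torch.mm(x, w) + b


gm = make_fx(fn)(torch.randn(4, 4), torch.randn(4, 4), torch.randn(4))
count = toy_patterns.apply(gm.graph)  # rewrites the mm + add pair
gm.recompile()

AsyncTPPass does the same thing, except its patterns span a matmul plus a vLLM collective op, and the replacements are the torch.ops.symm_mem fused kernels that overlap the communication with the matmul.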

vllm/compilation/pass_manager.py

Lines changed: 3 additions & 0 deletions
@@ -8,6 +8,7 @@
 from vllm.logger import init_logger

 from .activation_quant_fusion import ActivationQuantFusionPass
+from .collective_fusion import AsyncTPPass
 from .fix_functionalization import FixFunctionalizationPass
 from .fusion import FusionPass
 from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
@@ -56,6 +57,8 @@ def configure(self, config: VllmConfig):

         if self.pass_config.enable_sequence_parallelism:
             self.passes += [SequenceParallelismPass(config)]
+        if self.pass_config.enable_async_tp:
+            self.passes += [AsyncTPPass(config)]

         self.fix_functionalization = FixFunctionalizationPass(config)
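
For end users, the new pass is toggled through PassConfig.enable_async_tp, the flag referenced in the hunk above and set by the test. A hedged usage sketch; the model name and the exact LLM keyword arguments are assumptions based on vLLM's public API, not part of this diff:

from vllm import LLM
from vllm.config import CompilationConfig, PassConfig

# tensor_parallel_size must be > 1: the fused symm_mem ops only exist
# across tensor-parallel ranks, and is_applicable_for_shape further
# restricts fusion to shapes divisible by the TP size.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
          tensor_parallel_size=2,
          compilation_config=CompilationConfig(
              pass_config=PassConfig(enable_async_tp=True)))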

vllm/compilation/sequence_parallelism.py

Lines changed: 2 additions & 1 deletion
@@ -255,12 +255,13 @@ def __init__(self, config: VllmConfig):
         torch._inductor.pattern_matcher._seen_patterns.clear()

     def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
-        # only do replace for specific shapes
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0

     def __call__(self, graph: fx.Graph):
+        self.begin()
         self.dump_graph(graph, "before_sequence_parallelism_pass")
         count = self.patterns.apply(graph)
         logger.debug("Replaced %s patterns", count)
         self.dump_graph(graph, "after_sequence_parallelism_pass")
+        self.end_and_log()
