
[Bugfix] Enable torch.compile for 2 parts of model #14913


Closed
wants to merge 2 commits

Conversation

vadiklyutiy
Contributor

@vadiklyutiy vadiklyutiy commented Mar 17, 2025

Before this PR

Prior to this PR, we were unable to wrap two separate parts of a model with @support_torch_compile.

For instance, the example below fails:

import torch
import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile


@support_torch_compile
class FirstLinear(nn.Module):
    def __init__(self, input_size=10, output_size=20, *, vllm_config=None, prefix='', **kwargs):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

class ActivationLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.activation = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.activation(x)

@support_torch_compile
class SecondLinear(nn.Module):
    def __init__(self, input_size=20, output_size=5, *, vllm_config=None, prefix='', **kwargs):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

class SimpleModel(nn.Module):
    def __init__(self, input_size=10, hidden_size=20, output_size=5, *, vllm_config=None, prefix=''):
        super().__init__()
        self.first_linear = FirstLinear(
            input_size=input_size,
            output_size=hidden_size,
            vllm_config=vllm_config,
            prefix=f"{prefix}first_linear."
        )
        self.activation = ActivationLayer()
        self.second_linear = SecondLinear(
            input_size=hidden_size,
            output_size=output_size,
            vllm_config=vllm_config,
            prefix=f"{prefix}second_linear."
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.first_linear(x)
        x = self.activation(x)
        x = self.second_linear(x)
        return x

What This PR Fixes

This PR resolves several bugs, which enables wrapping multiple parts of a model with @support_torch_compile.


Why It's Needed

Consider a model with three parts:

part1()
part2()
part3()

part2() might not be traceable by Dynamo. If we attempt to run it with Dynamo, it will fail. In such cases, we can apply @support_torch_compile to part1() or part3(), but not both.
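
To make this concrete, below is a minimal pure-PyTorch sketch of the pattern (it uses torch.compile directly as a stand-in for @support_torch_compile, and the module names are illustrative): the middle part has data-dependent output shapes and stays in eager mode, while the parts around it are compiled separately.

import torch
import torch.nn as nn

class Part2(nn.Module):
    # Data-dependent output shape: how many rows survive depends on the
    # values in x, which full-graph Dynamo tracing cannot handle well.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mask = x.sum(dim=-1) > 0
        return x[mask]

part1 = torch.compile(nn.Linear(16, 16))   # compiled
part2 = Part2()                            # left in eager mode
part3 = torch.compile(nn.Linear(16, 4))    # compiled separately

out = part3(part2(part1(torch.randn(8, 16))))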

…ph and parameter shapes into the cache directory naming

Signed-off-by: Vadim Gimpelson <[email protected]>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of it by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Signed-off-by: Vadim Gimpelson <[email protected]>
@youkaichao
Member

what is the use case here?

@vadiklyutiy
Contributor Author

what is the use case here?

The main motivation is models that contain part(s) that cannot be traced by Dynamo.

If you are asking about a specific example, the motivating case is Qwen2.5-vl. It contains two parts: vision and language, and we spend more or less equal time in both. Right now we use @support_torch_compile only for the language part, so we compile only about half of the execution time. If we try to wrap the whole model with torch.compile, Dynamo tracing fails due to data dependence (the shapes of one tensor depend on the data in another tensor). This happens in the rotary embedding part. Rotary embedding is not performance critical (very little time is spent there), and it sits at the very beginning of the vision part, so we can apply @support_torch_compile a bit later/deeper in the call stack while still covering 99% of the vision part's execution time.

But right now we get failures if @support_torch_compile is used twice (not nested). This PR fixes that bug.

I added a simple test that shows the error caused by the lack of support for several @support_torch_compile usages.
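
For illustration only (this is not Qwen2.5-vl's actual rotary embedding code), here is a minimal sketch of the kind of data dependence mentioned above: the position indices are built from sequence lengths stored in a tensor, so the result's shape depends on tensor values rather than on static input shapes.

import torch

# Per-image sequence lengths live in a tensor; the shape of `positions`
# depends on these values, not on any static input shape.
seq_lens = torch.tensor([4, 7, 5])
positions = torch.cat([torch.arange(int(n)) for n in seq_lens])

# Wrapping code like this with torch.compile(fullgraph=True) typically fails
# to trace, because int(n) forces tensor data to be read at trace time.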

@youkaichao
Member

for Qwen2.5-vl, we do plan to compile two parts separately. we need to spend some time to design the config though. right now, the compilation config is only for text model, and you need to have another compilation config for the vision part.

@vadiklyutiy
Contributor Author

for Qwen2.5-vl, we do plan to compile two parts separately. we need to spend some time to design the config though. right now, the compilation config is only for text model, and you need to have another compilation config for the vision part.

Ok, let's assume we need a standalone compilation config for Qwen2.5-vl. But to "compile two parts separately", I think we still need the 2 changes made in this PR, no?

In the next comments I will provide some context for the code changes.

# Add fxgraph to cache path to avoid conflict with other
# @support_torch_compile caches
import hashlib
graph_code = graph.graph.python_code(root_module="self").src
Member


the text model can have different models, too, e.g. in the pipeline parallel case, but we want them to share the same cache directory.

i think we need to be explicit here, say let the upper level caller indicate a tag for the compilation, like compilation_config.tag = "text_tower"/"vision_tower"
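
A rough sketch of that idea (hypothetical: a compilation_config.tag field does not exist in this PR, and the class and function below are stand-ins, not vLLM's real API):

import os
from dataclasses import dataclass

@dataclass
class CompilationConfig:              # stand-in for vLLM's real CompilationConfig
    cache_dir: str = "/tmp/vllm_compile_cache"
    tag: str = ""                     # e.g. "text_tower" or "vision_tower"

def resolve_cache_dir(config: CompilationConfig) -> str:
    # All ranks/stages compiling the same tower pass the same tag and share a
    # cache directory; different towers get different directories.
    return os.path.join(config.cache_dir, config.tag) if config.tag else config.cache_dir

print(resolve_cache_dir(CompilationConfig(tag="vision_tower")))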

Contributor Author


the text model can have different models, too, e.g. in the pipeline parallel case, but we want them to share the same cache directory.

I think if the fxgraph IRs (graph.python_code().src) are different, then Inductor will produce different code to execute each fxgraph.

i think we need to be explicit here, say let the upper level caller indicate a tag for the compilation, like compilation_config.tag = "text_tower"/"vision_tower"

Did I understand correctly that you propose adding an explicit argument to @support_torch_compile that specifies a suffix for the cache (more generally, some identifier of the compiled piece that could potentially be used for other purposes as well)?
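
For reference, a small standalone example of the IR text being discussed, i.e. graph.python_code(root_module="self").src on a symbolically traced module; structurally different graphs yield different source strings.

import torch
import torch.fx

class Small(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

gm = torch.fx.symbolic_trace(Small())
# Prints the Python source generated for this FX graph; hashing this string
# distinguishes structurally different graphs.
print(gm.graph.python_code(root_module="self").src)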

Contributor Author

@vadiklyutiy vadiklyutiy left a comment


Below are my comments on the changes.

if isinstance(input_arg, torch.nn.parameter.Parameter):
    graph_code += f"\n{str(input_arg.shape)}"
graph_hash = hashlib.md5(graph_code.encode()).hexdigest()
cache_dir = os.path.join(cache_dir, f"fxgraph_{graph_hash}")
Contributor Author


If we "plan to compile two parts separately" we need to add fxgraph to hash. Otherwise 2 both fxgraph will access same computation_graph.py and the second will fail.

@@ -345,7 +345,6 @@ def configure_post_pass(self):
 # Config should automatically wrap all inductor passes
 assert isinstance(inductor_config[PASS_KEY], InductorPass)
 self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
-inductor_config[PASS_KEY] = self.post_grad_pass_manager
Contributor Author


This code is not correct, even formally: self.post_grad_pass_manager is not an instance of InductorPass, which is what is checked 2 lines above.

I checked the source and, to the best of my understanding, we only need the initially passed value of inductor_config[PASS_KEY]. After we call self.post_grad_pass_manager.add(...) here, there is no need to update inductor_config[PASS_KEY].
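
A toy illustration of the inconsistency (stand-in classes, not vLLM's real ones; the key name is illustrative): the assert checks the originally passed value, while the removed line then replaced it with an object of a different type.

class InductorPass:                       # stand-in
    pass

class PostGradPassManager:                # stand-in; not an InductorPass here
    def __init__(self):
        self.passes = []
    def add(self, p):
        self.passes.append(p)

PASS_KEY = "post_grad_custom_post_pass"   # illustrative key name
inductor_config = {PASS_KEY: InductorPass()}
manager = PostGradPassManager()

assert isinstance(inductor_config[PASS_KEY], InductorPass)  # holds for the original value
manager.add(inductor_config[PASS_KEY])                      # manager now owns the user's pass
# The removed line would do the following, leaving a value that no longer
# satisfies the isinstance check above:
# inductor_config[PASS_KEY] = manager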

@vadiklyutiy
Contributor Author

@youkaichao a kind reminder about this PR.

@vadiklyutiy
Contributor Author

Support for torch.compile of several models in the same run was implemented in #17211.
