Skip to content

Added CPU offloading #3452

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
pre_export_lowering,
)
from torch_tensorrt.dynamo.utils import (
CPU_DEVICE,
get_flat_args_with_check,
get_output_metadata,
parse_graph_io,
Expand Down Expand Up @@ -421,6 +422,7 @@ def compile(
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
Expand Down Expand Up @@ -498,6 +500,7 @@ def compile(
enable_weight_streaming (bool): Enable weight streaming.
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
Expand Down Expand Up @@ -550,15 +553,6 @@ def compile(
"`immutable_weights` must be False when `refit_identical_engine_weights` is True."
)

if (
not immutable_weights
and not refit_identical_engine_weights
and enable_weight_streaming
):
raise ValueError(
"TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. This issue was reported on https://github.com/pytorch/TensorRT/issues/3305"
)

if (
"enable_cross_compile_for_windows" in kwargs.keys()
and kwargs["enable_cross_compile_for_windows"]
Expand Down Expand Up @@ -674,6 +668,7 @@ def compile(
"enable_weight_streaming": enable_weight_streaming,
"tiling_optimization_level": tiling_optimization_level,
"l2_limit_for_tiling": l2_limit_for_tiling,
"offload_module_to_cpu": offload_module_to_cpu,
}

settings = CompilationSettings(**compilation_options)
Expand All @@ -690,6 +685,18 @@ def compile(
gm = post_lowering(gm, settings)
logger.debug("Lowered Input graph: " + str(gm.graph))

# Move the weights in the state_dict to CPU
if offload_module_to_cpu:
exported_program.module().to(CPU_DEVICE)
logger.info(
"The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
)
else:
remaining_memory, total_memory = torch.cuda.mem_get_info()
if remaining_memory < total_memory // 2:
logger.warning(
"Remaining GPU memory may not be enough to compile the TensorRT engine for this model resulting in an OOM error, Consider setting offload_module_to_cpu=True"
)
trt_gm = compile_module(
gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
)
Expand Down Expand Up @@ -820,6 +827,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
trt_modules = {}
# Iterate over all components that can be accelerated
# Generate the corresponding TRT Module for those

for name, _ in partitioned_module.named_children():
submodule = getattr(partitioned_module, name)
# filter on the GraphModule
Expand All @@ -833,6 +841,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
str(name),
str(submodule.graph),
)
submodule.to(to_torch_device(settings.device))
continue

if name not in submodule_node_dict:
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
TILING_OPTIMIZATION_LEVEL = "none"
L2_LIMIT_FOR_TILING = -1
USE_DISTRIBUTED_MODE_TRACE = False
OFFLOAD_MODULE_TO_CPU = False


def default_device() -> Device:
Expand Down
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
MAX_AUX_STREAMS,
MIN_BLOCK_SIZE,
NUM_AVG_TIMING_ITERS,
OFFLOAD_MODULE_TO_CPU,
OPTIMIZATION_LEVEL,
PASS_THROUGH_BUILD_FAILURES,
REFIT_IDENTICAL_ENGINE_WEIGHTS,
Expand Down Expand Up @@ -140,6 +141,7 @@ class CompilationSettings:
tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL
l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU


_SETTINGS_TO_BE_ENGINE_INVARIANT = (
Expand Down
5 changes: 3 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
get_trt_tensor,
to_torch,
)
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, delete_module, to_torch_device
from torch_tensorrt.fx.observer import Observer
from torch_tensorrt.logging import TRT_LOGGER

Expand Down Expand Up @@ -736,7 +736,8 @@ def run(
self._create_timing_cache(
builder_config, self.compilation_settings.timing_cache_path
)

if self.compilation_settings.offload_module_to_cpu:
delete_module(self.module)
serialized_engine = self.builder.build_serialized_network(
self.ctx.net, builder_config
)
Expand Down
114 changes: 113 additions & 1 deletion tests/py/dynamo/models/test_export_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
import torch
import torch_tensorrt as torchtrt
import torchvision.models as models
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
from torch_tensorrt.dynamo.utils import (
COSINE_THRESHOLD,
cosine_similarity,
get_model_device,
)

assertions = unittest.TestCase()

Expand Down Expand Up @@ -283,6 +287,53 @@ def test_resnet18(ir):
)


@pytest.mark.unit
def test_resnet18_cpu_offload(ir):
"""
This tests export save and load functionality on Resnet18 model
"""
model = models.resnet18().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")

compile_spec = {
"inputs": [
torchtrt.Input(
input.shape, dtype=torch.float, format=torch.contiguous_format
)
],
"ir": ir,
"min_block_size": 1,
"cache_built_engines": False,
"reuse_cached_engines": False,
"offload_module_to_cpu": True,
}

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
assertions.assertTrue(
get_model_device(model).type == "cpu",
msg="Model should be offloaded to CPU",
)
model.cuda()
torchtrt.save(trt_module, trt_ep_path)

deser_trt_module = torchtrt.load(trt_ep_path).module()
outputs_pyt = model(input)
outputs_trt = trt_module(input)
cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

outputs_trt_deser = deser_trt_module(input)
cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser[0])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_resnet18 deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


@pytest.mark.unit
def test_resnet18_dynamic(ir):
"""
Expand Down Expand Up @@ -381,6 +432,67 @@ def forward(self, x):
)


@pytest.mark.unit
def test_hybrid_conv_fallback_cpu_offload(ir):
"""
This tests export save and load functionality on a hybrid
model where a conv (a weighted layer) has been forced to fallback to Pytorch.
"""

class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
self.relu = torch.nn.ReLU()

def forward(self, x):
conv = self.conv(x)
relu = self.relu(conv)
mul = relu * 0.5
return mul

model = MyModule().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")

compile_spec = {
"inputs": [
torchtrt.Input(
input.shape, dtype=torch.float, format=torch.contiguous_format
)
],
"ir": ir,
"min_block_size": 1,
"torch_executed_ops": {"torch.ops.aten.convolution.default"},
"cache_built_engines": False,
"reuse_cached_engines": False,
"offload_module_to_cpu": True,
}

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
model.cuda()
torchtrt.save(trt_module, trt_ep_path)

deser_trt_module = torchtrt.load(trt_ep_path).module()
outputs_pyt = model(input)
outputs_trt = trt_module(input)

for idx in range(len(outputs_pyt)):
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_hybrid_conv_fallback TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

outputs_trt_deser = deser_trt_module(input)
for idx in range(len(outputs_pyt)):
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_hybrid_conv_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


@pytest.mark.unit
def test_arange_export(ir):
"""
Expand Down
68 changes: 68 additions & 0 deletions tests/py/dynamo/models/test_model_refit.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,74 @@ def forward(self, x):
torch._dynamo.reset()


@pytest.mark.unit
def test_refit_multiple_engine_with_weightmap_cpu_offload():
class net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 12, 3, padding=1)
self.bn = nn.BatchNorm2d(12)
self.conv2 = nn.Conv2d(12, 12, 3, padding=1)
self.fc1 = nn.Linear(12 * 56 * 56, 10)

def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.bn(x)
x = F.max_pool2d(x, (2, 2))
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, (2, 2))
x = torch.flatten(x, 1)
return self.fc1(x)

model = net().eval().to("cuda")
model2 = net().eval().to("cuda")

inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
enabled_precisions = {torch.float}
debug = False
min_block_size = 1
use_python_runtime = False

exp_program = torch.export.export(model, tuple(inputs))
exp_program2 = torch.export.export(model2, tuple(inputs))

torch_executed_ops = {"torch.ops.aten.convolution.default"}
trt_gm = torchtrt.dynamo.compile(
exp_program,
tuple(inputs),
use_python_runtime=use_python_runtime,
enabled_precisions=enabled_precisions,
debug=debug,
min_block_size=min_block_size,
immutable_weights=False,
torch_executed_ops=torch_executed_ops,
reuse_cached_engines=False,
offload_module_to_cpu=True,
)

new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
new_weight_module=exp_program2,
arg_inputs=inputs,
use_weight_map_cache=True,
)
model2.cuda()
# Check the output
expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
*inputs
)
for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
assertions.assertTrue(
torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
"Refit Result is not correct. Refit failed",
)
# Clean up model env

torch._dynamo.reset()


@unittest.skipIf(
not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
"TorchScript Frontend is not available",
Expand Down
Loading