diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index acd16a32f0..50c57329d5 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -37,6 +37,7 @@
     pre_export_lowering,
 )
 from torch_tensorrt.dynamo.utils import (
+    CPU_DEVICE,
     get_flat_args_with_check,
     get_output_metadata,
     parse_graph_io,
@@ -421,6 +422,7 @@ def compile(
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
+    offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -498,6 +500,7 @@ def compile(
         enable_weight_streaming (bool): Enable weight streaming.
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
+        offload_module_to_cpu (bool): Offload the PyTorch module to CPU during compilation so that GPU memory is available for the TensorRT engine build. This is useful when GPU memory is limited.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -550,15 +553,6 @@ def compile(
             "`immutable_weights` must be False when `refit_identical_engine_weights` is True."
         )
 
-    if (
-        not immutable_weights
-        and not refit_identical_engine_weights
-        and enable_weight_streaming
-    ):
-        raise ValueError(
-            "TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. This issue was reported on https://github.com/pytorch/TensorRT/issues/3305"
-        )
-
     if (
         "enable_cross_compile_for_windows" in kwargs.keys()
         and kwargs["enable_cross_compile_for_windows"]
@@ -674,6 +668,7 @@ def compile(
         "enable_weight_streaming": enable_weight_streaming,
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
+        "offload_module_to_cpu": offload_module_to_cpu,
     }
 
     settings = CompilationSettings(**compilation_options)
@@ -690,6 +685,18 @@ def compile(
     gm = post_lowering(gm, settings)
     logger.debug("Lowered Input graph: " + str(gm.graph))
 
+    # Move the weights in the state_dict to CPU
+    if offload_module_to_cpu:
+        exported_program.module().to(CPU_DEVICE)
+        logger.info(
+            "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
+        )
+    else:
+        remaining_memory, total_memory = torch.cuda.mem_get_info()
+        if remaining_memory < total_memory // 2:
+            logger.warning(
+                "Remaining GPU memory may not be enough to compile the TensorRT engine for this model, which could result in an OOM error. Consider setting offload_module_to_cpu=True"
+            )
     trt_gm = compile_module(
         gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
     )
@@ -820,6 +827,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
     trt_modules = {}
     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those
+
     for name, _ in partitioned_module.named_children():
         submodule = getattr(partitioned_module, name)
         # filter on the GraphModule
@@ -833,6 +841,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
                 str(name),
                 str(submodule.graph),
             )
+            submodule.to(to_torch_device(settings.device))
             continue
 
         if name not in submodule_node_dict:
diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index 379a196e2e..aafd1072f4 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -49,6 +49,7 @@
 TILING_OPTIMIZATION_LEVEL = "none"
 L2_LIMIT_FOR_TILING = -1
 USE_DISTRIBUTED_MODE_TRACE = False
+OFFLOAD_MODULE_TO_CPU = False
 
 
 def default_device() -> Device:
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index d9b0e05e4d..97c02f34fb 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -25,6 +25,7 @@
     MAX_AUX_STREAMS,
     MIN_BLOCK_SIZE,
     NUM_AVG_TIMING_ITERS,
+    OFFLOAD_MODULE_TO_CPU,
     OPTIMIZATION_LEVEL,
     PASS_THROUGH_BUILD_FAILURES,
     REFIT_IDENTICAL_ENGINE_WEIGHTS,
@@ -140,6 +141,7 @@ class CompilationSettings:
     tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL
     l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
     use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
+    offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
 
 
 _SETTINGS_TO_BE_ENGINE_INVARIANT = (
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index fde07bf1f5..0f87d26250 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -45,7 +45,7 @@
     get_trt_tensor,
     to_torch,
 )
-from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, delete_module, to_torch_device
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER
 
@@ -736,7 +736,8 @@ def run(
         self._create_timing_cache(
             builder_config, self.compilation_settings.timing_cache_path
         )
-
+        if self.compilation_settings.offload_module_to_cpu:
+            delete_module(self.module)
         serialized_engine = self.builder.build_serialized_network(
             self.ctx.net, builder_config
         )
diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py
index 4c0b9c6d06..52e5eefb63 100644
--- a/tests/py/dynamo/models/test_export_serde.py
+++ b/tests/py/dynamo/models/test_export_serde.py
@@ -6,7 +6,11 @@
 import torch
 import torch_tensorrt as torchtrt
 import torchvision.models as models
-from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
+from torch_tensorrt.dynamo.utils import (
+    COSINE_THRESHOLD,
+    cosine_similarity,
+    get_model_device,
+)
 
 assertions = unittest.TestCase()
 
@@ -283,6 +287,53 @@ def test_resnet18(ir):
     )
 
 
+@pytest.mark.unit
+def test_resnet18_cpu_offload(ir):
+    """
+    This tests export save and load functionality on a Resnet18 model with offload_module_to_cpu enabled
+    """
+    model = models.resnet18().eval().cuda()
+    input = torch.randn((1, 3, 224, 224)).to("cuda")
+
+    compile_spec = {
+        "inputs": [
+            torchtrt.Input(
+                input.shape, dtype=torch.float, format=torch.contiguous_format
+            )
+        ],
+        "ir": ir,
+        "min_block_size": 1,
+        "cache_built_engines": False,
+        "reuse_cached_engines": False,
+        "offload_module_to_cpu": True,
+    }
+
+    exp_program = torchtrt.dynamo.trace(model, **compile_spec)
+    trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
+    assertions.assertTrue(
+        get_model_device(model).type == "cpu",
+        msg="Model should be offloaded to CPU",
+    )
+    model.cuda()
+    torchtrt.save(trt_module, trt_ep_path)
+
+    deser_trt_module = torchtrt.load(trt_ep_path).module()
+    outputs_pyt = model(input)
+    outputs_trt = trt_module(input)
+    cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0])
+    assertions.assertTrue(
+        cos_sim > COSINE_THRESHOLD,
+        msg=f"test_resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+    )
+
+    outputs_trt_deser = deser_trt_module(input)
+    cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser[0])
+    assertions.assertTrue(
+        cos_sim > COSINE_THRESHOLD,
+        msg=f"test_resnet18 deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+    )
+
+
 @pytest.mark.unit
 def test_resnet18_dynamic(ir):
     """
@@ -381,6 +432,67 @@ def forward(self, x):
     )
 
 
+@pytest.mark.unit
+def test_hybrid_conv_fallback_cpu_offload(ir):
+    """
+    This tests export save and load functionality on a hybrid model where a conv
+    (a weighted layer) has been forced to fall back to PyTorch, with offload_module_to_cpu enabled.
+    """
+
+    class MyModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
+            self.relu = torch.nn.ReLU()
+
+        def forward(self, x):
+            conv = self.conv(x)
+            relu = self.relu(conv)
+            mul = relu * 0.5
+            return mul
+
+    model = MyModule().eval().cuda()
+    input = torch.randn((1, 3, 224, 224)).to("cuda")
+
+    compile_spec = {
+        "inputs": [
+            torchtrt.Input(
+                input.shape, dtype=torch.float, format=torch.contiguous_format
+            )
+        ],
+        "ir": ir,
+        "min_block_size": 1,
+        "torch_executed_ops": {"torch.ops.aten.convolution.default"},
+        "cache_built_engines": False,
+        "reuse_cached_engines": False,
+        "offload_module_to_cpu": True,
+    }
+
+    exp_program = torchtrt.dynamo.trace(model, **compile_spec)
+    trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
+    model.cuda()
+    torchtrt.save(trt_module, trt_ep_path)
+
+    deser_trt_module = torchtrt.load(trt_ep_path).module()
+    outputs_pyt = model(input)
+    outputs_trt = trt_module(input)
+
+    for idx in range(len(outputs_pyt)):
+        cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
+        assertions.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"test_hybrid_conv_fallback TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+    outputs_trt_deser = deser_trt_module(input)
+    for idx in range(len(outputs_pyt)):
+        cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
+        assertions.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"test_hybrid_conv_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+
 @pytest.mark.unit
 def test_arange_export(ir):
     """
diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py
index d71091b04e..068ba81473 100644
--- a/tests/py/dynamo/models/test_model_refit.py
+++ b/tests/py/dynamo/models/test_model_refit.py
@@ -492,6 +492,74 @@ def forward(self, x):
     torch._dynamo.reset()
 
 
+@pytest.mark.unit
+def test_refit_multiple_engine_with_weightmap_cpu_offload():
+    class net(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = nn.Conv2d(3, 12, 3, padding=1)
+            self.bn = nn.BatchNorm2d(12)
+            self.conv2 = nn.Conv2d(12, 12, 3, padding=1)
+            self.fc1 = nn.Linear(12 * 56 * 56, 10)
+
+        def forward(self, x):
+            x = self.conv1(x)
+            x = F.relu(x)
+            x = self.bn(x)
+            x = F.max_pool2d(x, (2, 2))
+            x = self.conv2(x)
+            x = F.relu(x)
+            x = F.max_pool2d(x, (2, 2))
+            x = torch.flatten(x, 1)
+            return self.fc1(x)
+
+    model = net().eval().to("cuda")
+    model2 = net().eval().to("cuda")
+
+    inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
+    enabled_precisions = {torch.float}
+    debug = False
+    min_block_size = 1
+    use_python_runtime = False
+
+    exp_program = torch.export.export(model, tuple(inputs))
+    exp_program2 = torch.export.export(model2, tuple(inputs))
+
+    torch_executed_ops = {"torch.ops.aten.convolution.default"}
+    trt_gm = torchtrt.dynamo.compile(
+        exp_program,
+        tuple(inputs),
+        use_python_runtime=use_python_runtime,
+        enabled_precisions=enabled_precisions,
+        debug=debug,
+        min_block_size=min_block_size,
+        immutable_weights=False,
+        torch_executed_ops=torch_executed_ops,
+        reuse_cached_engines=False,
+        offload_module_to_cpu=True,
+    )
+
+    new_trt_gm = refit_module_weights(
+        compiled_module=trt_gm,
+        new_weight_module=exp_program2,
+        arg_inputs=inputs,
+        use_weight_map_cache=True,
+    )
+    model2.cuda()
+    # Check the output
+    expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
+        *inputs
+    )
+    for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
+        assertions.assertTrue(
+            torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
+            "Refit Result is not correct. Refit failed",
+        )
+    # Clean up model env
+
+    torch._dynamo.reset()
+
+
 @unittest.skipIf(
     not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
     "TorchScript Frontend is not available",