
Commit 79064a2

BoyuanFeng authored and facebook-github-bot committed
enable cudagraph for AUCMetricComputation (#2951)
Summary:

Pull Request resolved: #2951

Previously, enabling CUDA graphs led to [errors](https://fb.workplace.com/groups/1075192433118967/permalink/1661777797793758/) in AUCMetricComputation. The root cause is that the model forward output tensors `predictions`, `labels`, `weights`, and `grouping_keys` were overwritten by the next CUDAGraph replay. This diff fixes the issue by cloning these tensors before they can be overwritten. The overhead should be small since these tensors only have shape `(n_task, n_examples)`.

Reviewed By: iamzainhuda, TroyGarden

Differential Revision: D74293458
1 parent 57deb6e commit 79064a2
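As context for the failure mode described above, here is a minimal, self-contained sketch (not part of this commit; the toy forward() and tensor shapes are invented for illustration) of how a CUDA graph replay overwrites graph-owned output storage in place, and why taking a clone before buffering preserves the values a metric already consumed:

import torch

def forward(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the model forward that produces `predictions` etc.
    return x * 2

x = torch.zeros(4, device="cuda")

# Warm up on a side stream, then capture one forward pass into a CUDA graph.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    forward(x)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    y = forward(x)  # y's storage is owned by the captured graph

x.fill_(1.0)
g.replay()
kept_reference = y      # what a metric would buffer without the fix
kept_clone = y.clone()  # what clone_update_inputs() does

x.fill_(5.0)
g.replay()              # overwrites y's storage in place
print(kept_reference)   # now shows the second replay's values (10.0)
print(kept_clone)       # still holds the first replay's values (2.0)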

File tree

3 files changed: +46 −0 lines changed


torchrec/metrics/metric_module.py

Lines changed: 1 addition & 0 deletions

@@ -370,6 +370,7 @@ def _generate_rec_metrics(
         kwargs = metric_def.arguments
 
         kwargs["enable_pt2_compile"] = metrics_config.enable_pt2_compile
+        kwargs["should_clone_update_inputs"] = metrics_config.should_clone_update_inputs
 
         rec_tasks: List[RecTaskInfo] = []
         if metric_def.rec_tasks and metric_def.rec_task_indices:

torchrec/metrics/metrics_config.py

Lines changed: 3 additions & 0 deletions

@@ -170,6 +170,8 @@ class MetricsConfig:
             update if the inputs are invalid. Invalid inputs include the case where all
             examples have 0 weights for a batch.
         enable_pt2_compile (bool): whether to enable PT2 compilation for metrics.
+        should_clone_update_inputs (bool): whether to clone the inputs of update(). This
+            prevents CUDAGraph errors from tensor outputs being overwritten by subsequent runs.
     """
 
     rec_tasks: List[RecTaskInfo] = field(default_factory=list)
@@ -184,6 +186,7 @@ class MetricsConfig:
     compute_on_all_ranks: bool = False
     should_validate_update: bool = False
     enable_pt2_compile: bool = False
+    should_clone_update_inputs: bool = False
 
 
 DefaultTaskInfo = RecTaskInfo(
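The flag is off by default. A hedged usage sketch of turning it on (field names are taken from this diff; MetricsConfig may have other arguments not shown here, and the task list is illustrative):

from torchrec.metrics.metrics_config import DefaultTaskInfo, MetricsConfig

# Illustrative config only: other MetricsConfig fields keep their defaults.
metrics_config = MetricsConfig(
    rec_tasks=[DefaultTaskInfo],
    enable_pt2_compile=True,
    # Clone update() inputs so the next CUDA graph replay cannot overwrite the
    # tensors the metric has buffered.
    should_clone_update_inputs=True,
)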

torchrec/metrics/rec_metric.py

Lines changed: 42 additions & 0 deletions

@@ -384,6 +384,14 @@ def __init__(
         if "enable_pt2_compile" in kwargs:
             del kwargs["enable_pt2_compile"]
 
+        # pyre-fixme[8]: Attribute has type `bool`; used as `Union[bool,
+        #  Dict[str, Any]]`.
+        self._should_clone_update_inputs: bool = kwargs.get(
+            "should_clone_update_inputs", False
+        )
+        if "should_clone_update_inputs" in kwargs:
+            del kwargs["should_clone_update_inputs"]
+
         if self._window_size < self._batch_size:
             raise ValueError(
                 f"Local window size must be larger than batch size. Got local window size {self._window_size} and batch size {self._batch_size}."
@@ -541,6 +549,35 @@ def _create_default_weights(self, predictions: torch.Tensor) -> torch.Tensor:
     def _check_nonempty_weights(self, weights: torch.Tensor) -> torch.Tensor:
         return torch.gt(torch.count_nonzero(weights, dim=-1), 0)
 
+    def clone_update_inputs(
+        self,
+        predictions: RecModelOutput,
+        labels: RecModelOutput,
+        weights: Optional[RecModelOutput],
+        **kwargs: Dict[str, Any],
+    ) -> tuple[
+        RecModelOutput, RecModelOutput, Optional[RecModelOutput], Dict[str, Any]
+    ]:
+        def clone_rec_model_output(
+            rec_model_output: RecModelOutput,
+        ) -> RecModelOutput:
+            if isinstance(rec_model_output, torch.Tensor):
+                return rec_model_output.clone()
+            else:
+                return {k: v.clone() for k, v in rec_model_output.items()}
+
+        predictions = clone_rec_model_output(predictions)
+        labels = clone_rec_model_output(labels)
+        if weights is not None:
+            weights = clone_rec_model_output(weights)
+
+        if "required_inputs" in kwargs:
+            kwargs["required_inputs"] = {
+                k: v.clone() for k, v in kwargs["required_inputs"].items()
+            }
+
+        return predictions, labels, weights, kwargs
+
     def _update(
         self,
         *,
@@ -550,6 +587,11 @@ def _update(
         **kwargs: Dict[str, Any],
     ) -> None:
         with torch.no_grad():
+            if self._should_clone_update_inputs:
+                predictions, labels, weights, kwargs = self.clone_update_inputs(
+                    predictions, labels, weights, **kwargs
+                )
+
             if self._compute_mode in [
                 RecComputeMode.FUSED_TASKS_COMPUTATION,
                 RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
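A standalone sketch of the cloning behavior added above (the names RecModelOutputLike and clone_output are illustrative, not part of the diff): RecModelOutput may be a plain Tensor or a Dict[str, Tensor], and cloning snapshots either form so a later in-place overwrite, such as the next CUDA graph replay, cannot change what the metric buffered:

from typing import Dict, Union

import torch

RecModelOutputLike = Union[torch.Tensor, Dict[str, torch.Tensor]]

def clone_output(output: RecModelOutputLike) -> RecModelOutputLike:
    # Mirror of the nested clone_rec_model_output helper: new storage for a
    # single tensor, or per-key clones for a task-keyed dict.
    if isinstance(output, torch.Tensor):
        return output.clone()
    return {k: v.clone() for k, v in output.items()}

preds = torch.zeros(1, 4)       # shape (n_task, n_examples)
snapshot = clone_output(preds)
preds.fill_(9.0)                # simulate the next replay rewriting the buffer
print(snapshot)                 # still zeros: the metric's copy is unaffected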
