
Commit 40ab70d

che-sh authored and facebook-github-bot committed
Add context manager to use next batch context for postprocs (#2939)
Summary:
Pull Request resolved: #2939

Small refactor to reduce code repetition when setting the pipelined postprocs' context to the next batch's context and then reverting it.

Reviewed By: TroyGarden

Differential Revision: D73824600
1 parent 8cda1a4 commit 40ab70d
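In essence, the change lifts the "save contexts, point postprocs at the next batch's context, run data dist, restore" dance into a single contextlib context manager. Below is a minimal, self-contained sketch of that pattern; FakePostproc and the plain-string contexts are hypothetical stand-ins for torchrec's PipelinedPostproc and TrainPipelineContext, which need a real sharded model to construct.

import contextlib
from typing import Generator, List


class FakePostproc:
    # Hypothetical stand-in for PipelinedPostproc; only get_context/set_context matter here.
    def __init__(self, context: str) -> None:
        self._context = context

    def get_context(self) -> str:
        return self._context

    def set_context(self, context: str) -> None:
        self._context = context


@contextlib.contextmanager
def use_context_for_postprocs(
    postprocs: List[FakePostproc], next_batch_context: str
) -> Generator[None, None, None]:
    # Save the contexts used for the current batch's model forward.
    originals = [p.get_context() for p in postprocs]
    # Point every postproc at the next batch's context while work runs inside the block.
    for p in postprocs:
        p.set_context(next_batch_context)
    yield
    # Restore the original contexts for the current batch's model forward.
    for p, ctx in zip(postprocs, originals):
        p.set_context(ctx)


postprocs = [FakePostproc("batch_0"), FakePostproc("batch_0")]
with use_context_for_postprocs(postprocs, "batch_1"):
    # Inside the block, e.g. data dist populates the postproc cache for the next batch.
    assert all(p.get_context() == "batch_1" for p in postprocs)
assert all(p.get_context() == "batch_0" for p in postprocs)

As in the previous inline code, the restore loop runs only on the normal exit path; wrapping the yield in try/finally would be required for contexts to be restored if the wrapped work raised.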

2 files changed: 38 additions, 28 deletions

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 12 additions & 28 deletions
@@ -58,6 +58,7 @@
     StageOut,
     StageOutputWithEvent,
     TrainPipelineContext,
+    use_context_for_postprocs,
 )
 from torchrec.distributed.types import Awaitable
 from torchrec.pt2.checks import is_torchdynamo_compiling

@@ -791,19 +792,9 @@ def start_sparse_data_dist(
             with self._stream_context(self._data_dist_stream):
                 _wait_for_batch(batch, self._memcpy_stream)

-                original_contexts = [p.get_context() for p in self._pipelined_postprocs]
-
                 # Temporarily set context for next iter to populate cache
-                for postproc_mod in self._pipelined_postprocs:
-                    postproc_mod.set_context(context)
-
-                _start_data_dist(self._pipelined_modules, batch, context)
-
-                # Restore context for model fwd
-                for module, context in zip(
-                    self._pipelined_postprocs, original_contexts
-                ):
-                    module.set_context(context)
+                with use_context_for_postprocs(self._pipelined_postprocs, context):
+                    _start_data_dist(self._pipelined_modules, batch, context)

     def wait_sparse_data_dist(self, context: TrainPipelineContext) -> None:
         """

@@ -1324,22 +1315,15 @@ def start_sparse_data_dist(
             return

         # Temporarily set context for next iter to populate cache
-        original_contexts = [p.get_context() for p in self._pipelined_postprocs]
-        for postproc_mod in self._pipelined_postprocs:
-            postproc_mod.set_context(context)
-
-        with record_function(f"## start_sparse_data_dist {context.index} ##"):
-            with self._stream_context(self._data_dist_stream):
-                _wait_for_events(batch, context, self._data_dist_stream)
-                model_input = self.extract_model_input_from_batch(batch)
-                _start_data_dist(self._pipelined_modules, model_input, context)
-                event = torch.get_device_module(self._device).Event()
-                event.record()
-                context.events.append(event)
-
-        # Restore context for model forward
-        for module, context in zip(self._pipelined_postprocs, original_contexts):
-            module.set_context(context)
+        with use_context_for_postprocs(self._pipelined_postprocs, context):
+            with record_function(f"## start_sparse_data_dist {context.index} ##"):
+                with self._stream_context(self._data_dist_stream):
+                    _wait_for_events(batch, context, self._data_dist_stream)
+                    model_input = self.extract_model_input_from_batch(batch)
+                    _start_data_dist(self._pipelined_modules, model_input, context)
+                    event = torch.get_device_module(self._device).Event()
+                    event.record()
+                    context.events.append(event)

     def start_embedding_lookup(
         self,

torchrec/distributed/train_pipeline/utils.py

Lines changed: 26 additions & 0 deletions
@@ -7,6 +7,8 @@

 # pyre-strict
 import abc
+
+import contextlib
 import copy
 import itertools
 import logging

@@ -21,6 +23,7 @@
     Callable,
     cast,
     Dict,
+    Generator,
     Generic,
     Iterable,
     Iterator,

@@ -1834,6 +1837,28 @@ def _prefetch_embeddings(
     return data_per_sharded_module


+@contextlib.contextmanager
+def use_context_for_postprocs(
+    pipelined_postprocs: List[PipelinedPostproc],
+    next_batch_context: TrainPipelineContext,
+) -> Generator[None, None, None]:
+    """
+    Temporarily set pipelined postproc context for next iter to populate cache.
+    """
+    # Save original context for model fwd
+    original_contexts = [p.get_context() for p in pipelined_postprocs]
+
+    # Temporarily set context for next iter to populate cache
+    for postproc_mod in pipelined_postprocs:
+        postproc_mod.set_context(next_batch_context)
+
+    yield
+
+    # Restore context for model fwd
+    for module, context in zip(pipelined_postprocs, original_contexts):
+        module.set_context(context)
+
+
 class SparseDataDistUtil(Generic[In]):
     """
     Helper class exposing methods for sparse data dist and prefetch pipelining.

@@ -1845,6 +1870,7 @@ class SparseDataDistUtil(Generic[In]):
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
         prefetch_stream (Optional[torch.cuda.Stream]): Stream on which model prefetch runs
             Defaults to `None`. This needs to be passed in to enable prefetch pipelining.
+        pipeline_postproc (bool): whether to pipeline postproc modules. Defaults to `False`.

     Example::
         sdd = SparseDataDistUtil(
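For reference, a hedged sketch of how a caller outside these pipeline classes could reuse the new helper. It assumes PipelinedPostproc, TrainPipelineContext, and use_context_for_postprocs are all importable from torchrec.distributed.train_pipeline.utils (as the diff above suggests), and the run_data_dist callable is a hypothetical hook standing in for the pipeline's own data-dist call.

from typing import Callable, List

from torchrec.distributed.train_pipeline.utils import (
    PipelinedPostproc,
    TrainPipelineContext,
    use_context_for_postprocs,
)


def warm_postprocs_for_next_batch(
    postprocs: List[PipelinedPostproc],
    next_context: TrainPipelineContext,
    run_data_dist: Callable[[], None],  # hypothetical hook wrapping the data-dist work
) -> None:
    # Postprocs temporarily write into next_context while data dist runs; their
    # original contexts are restored before the current batch's model forward.
    with use_context_for_postprocs(postprocs, next_context):
        run_data_dist()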
