
Commit 5391139

che-sh authored and facebook-github-bot committed
Add context manager to use next batch context for postprocs (#2939)
Summary:
Pull Request resolved: #2939

Small refactor to reduce code repetition of setting and reverting the pipelined postprocs context to the next batch's context.

Differential Revision: D73824600
1 parent 54723ee commit 5391139
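
The core of the change is the save/swap/restore pattern applied to each pipelined postproc's context. Below is a minimal, self-contained sketch of that pattern; `_Context` and `_Postproc` are toy stand-ins introduced here purely for illustration, while the real `use_context_for_postprocs` operates on `PipelinedPostproc` and `TrainPipelineContext` exactly as shown in the utils.py diff further down.

import contextlib
from typing import Generator, List


# Toy stand-ins for illustration only; the real types are PipelinedPostproc
# and TrainPipelineContext in torchrec.distributed.train_pipeline.
class _Context:
    def __init__(self, index: int) -> None:
        self.index = index


class _Postproc:
    def __init__(self, ctx: _Context) -> None:
        self._ctx = ctx

    def get_context(self) -> _Context:
        return self._ctx

    def set_context(self, ctx: _Context) -> None:
        self._ctx = ctx


@contextlib.contextmanager
def use_context_for_postprocs(
    pipelined_postprocs: List[_Postproc],
    next_batch_context: _Context,
) -> Generator[None, None, None]:
    # Save each postproc's current context, swap in the next batch's context
    # for the duration of the block, then restore the originals (mirrors the
    # utils.py hunk below).
    original_contexts = [p.get_context() for p in pipelined_postprocs]
    for postproc_mod in pipelined_postprocs:
        postproc_mod.set_context(next_batch_context)
    yield
    for module, context in zip(pipelined_postprocs, original_contexts):
        module.set_context(context)


postprocs = [_Postproc(_Context(0)), _Postproc(_Context(0))]
with use_context_for_postprocs(postprocs, _Context(1)):
    # Inside the block the postprocs see the next batch's context.
    assert all(p.get_context().index == 1 for p in postprocs)
# After the block the original contexts are back for the model forward pass.
assert all(p.get_context().index == 0 for p in postprocs)

With this helper, the two call sites in train_pipelines.py collapse to a single `with use_context_for_postprocs(...):` block, as the first diff below shows.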

File tree: 2 files changed (+37, −28 lines)

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 12 additions & 28 deletions
@@ -58,6 +58,7 @@
     StageOut,
     StageOutputWithEvent,
     TrainPipelineContext,
+    use_context_for_postprocs,
 )
 from torchrec.distributed.types import Awaitable
 from torchrec.pt2.checks import is_torchdynamo_compiling
@@ -792,19 +793,9 @@ def start_sparse_data_dist(
             with self._stream_context(self._data_dist_stream):
                 _wait_for_batch(batch, self._memcpy_stream)
 
-                original_contexts = [p.get_context() for p in self._pipelined_postprocs]
-
                 # Temporarily set context for next iter to populate cache
-                for postproc_mod in self._pipelined_postprocs:
-                    postproc_mod.set_context(context)
-
-                _start_data_dist(self._pipelined_modules, batch, context)
-
-                # Restore context for model fwd
-                for module, context in zip(
-                    self._pipelined_postprocs, original_contexts
-                ):
-                    module.set_context(context)
+                with use_context_for_postprocs(self._pipelined_postprocs, context):
+                    _start_data_dist(self._pipelined_modules, batch, context)
 
     def wait_sparse_data_dist(self, context: TrainPipelineContext) -> None:
         """
@@ -1325,22 +1316,15 @@ def start_sparse_data_dist(
             return
 
         # Temporarily set context for next iter to populate cache
-        original_contexts = [p.get_context() for p in self._pipelined_postprocs]
-        for postproc_mod in self._pipelined_postprocs:
-            postproc_mod.set_context(context)
-
-        with record_function(f"## start_sparse_data_dist {context.index} ##"):
-            with self._stream_context(self._data_dist_stream):
-                _wait_for_events(batch, context, self._data_dist_stream)
-                model_input = self.extract_model_input_from_batch(batch)
-                _start_data_dist(self._pipelined_modules, model_input, context)
-                event = torch.get_device_module(self._device).Event()
-                event.record()
-                context.events.append(event)
-
-        # Restore context for model forward
-        for module, context in zip(self._pipelined_postprocs, original_contexts):
-            module.set_context(context)
+        with use_context_for_postprocs(self._pipelined_postprocs, context):
+            with record_function(f"## start_sparse_data_dist {context.index} ##"):
+                with self._stream_context(self._data_dist_stream):
+                    _wait_for_events(batch, context, self._data_dist_stream)
+                    model_input = self.extract_model_input_from_batch(batch)
+                    _start_data_dist(self._pipelined_modules, model_input, context)
+                    event = torch.get_device_module(self._device).Event()
+                    event.record()
+                    context.events.append(event)
 
     def start_embedding_lookup(
         self,

torchrec/distributed/train_pipeline/utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
# pyre-strict
99

10+
import contextlib
1011
import copy
1112
import itertools
1213
import logging
@@ -21,6 +22,7 @@
2122
Callable,
2223
cast,
2324
Dict,
25+
Generator,
2426
Generic,
2527
Iterable,
2628
Iterator,
@@ -1797,6 +1799,28 @@ def _prefetch_embeddings(
17971799
return data_per_sharded_module
17981800

17991801

1802+
@contextlib.contextmanager
1803+
def use_context_for_postprocs(
1804+
pipelined_postprocs: List[PipelinedPostproc],
1805+
next_batch_context: TrainPipelineContext,
1806+
) -> Generator[None, None, None]:
1807+
"""
1808+
Temporarily set pipelined postproc context for next iter to populate cache.
1809+
"""
1810+
# Save original context for model fwd
1811+
original_contexts = [p.get_context() for p in pipelined_postprocs]
1812+
1813+
# Temporarily set context for next iter to populate cache
1814+
for postproc_mod in pipelined_postprocs:
1815+
postproc_mod.set_context(next_batch_context)
1816+
1817+
yield
1818+
1819+
# Restore context for model fwd
1820+
for module, context in zip(pipelined_postprocs, original_contexts):
1821+
module.set_context(context)
1822+
1823+
18001824
class SparseDataDistUtil(Generic[In]):
18011825
"""
18021826
Helper class exposing methods for sparse data dist and prefetch pipelining.
@@ -1808,6 +1832,7 @@ class SparseDataDistUtil(Generic[In]):
18081832
apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
18091833
prefetch_stream (Optional[torch.cuda.Stream]): Stream on which model prefetch runs
18101834
Defaults to `None`. This needs to be passed in to enable prefetch pipelining.
1835+
pipeline_postproc (bool): whether to pipeline postproc modules. Defaults to `False`.
18111836
18121837
Example::
18131838
sdd = SparseDataDistUtil(

0 commit comments