@@ -58,6 +58,7 @@
     StageOut,
     StageOutputWithEvent,
     TrainPipelineContext,
+    use_context_for_postprocs,
 )
 from torchrec.distributed.types import Awaitable
 from torchrec.pt2.checks import is_torchdynamo_compiling
@@ -792,19 +793,9 @@ def start_sparse_data_dist(
             with self._stream_context(self._data_dist_stream):
                 _wait_for_batch(batch, self._memcpy_stream)

-                original_contexts = [p.get_context() for p in self._pipelined_postprocs]
-
                 # Temporarily set context for next iter to populate cache
-                for postproc_mod in self._pipelined_postprocs:
-                    postproc_mod.set_context(context)
-
-                _start_data_dist(self._pipelined_modules, batch, context)
-
-                # Restore context for model fwd
-                for module, context in zip(
-                    self._pipelined_postprocs, original_contexts
-                ):
-                    module.set_context(context)
+                with use_context_for_postprocs(self._pipelined_postprocs, context):
+                    _start_data_dist(self._pipelined_modules, batch, context)

     def wait_sparse_data_dist(self, context: TrainPipelineContext) -> None:
         """
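The new `use_context_for_postprocs` helper is imported above, but its definition is not part of this diff. Judging from the save/set/restore code it replaces, a minimal sketch could look like the following (the `PipelinedPostproc` type name and the exact signature are assumptions, not confirmed by this diff):

    from contextlib import contextmanager
    from typing import Generator, List

    @contextmanager
    def use_context_for_postprocs(
        pipelined_postprocs: List["PipelinedPostproc"],
        next_batch_context: "TrainPipelineContext",
    ) -> Generator[None, None, None]:
        """Temporarily point each pipelined postproc at a new context,
        restoring the original contexts on exit."""
        # Save the contexts currently attached to the postprocs.
        original_contexts = [p.get_context() for p in pipelined_postprocs]
        # Set the next iteration's context so data dist populates its cache.
        for postproc_mod in pipelined_postprocs:
            postproc_mod.set_context(next_batch_context)
        try:
            yield
        finally:
            # Restore the original contexts for the model forward pass.
            for module, context in zip(pipelined_postprocs, original_contexts):
                module.set_context(context)

Besides removing the duplicated boilerplate, the `try`/`finally` restores the contexts even if data dist raises, which the inlined version did not guarantee.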
@@ -1325,22 +1316,15 @@ def start_sparse_data_dist(
             return

         # Temporarily set context for next iter to populate cache
-        original_contexts = [p.get_context() for p in self._pipelined_postprocs]
-        for postproc_mod in self._pipelined_postprocs:
-            postproc_mod.set_context(context)
-
-        with record_function(f"## start_sparse_data_dist {context.index} ##"):
-            with self._stream_context(self._data_dist_stream):
-                _wait_for_events(batch, context, self._data_dist_stream)
-                model_input = self.extract_model_input_from_batch(batch)
-                _start_data_dist(self._pipelined_modules, model_input, context)
-                event = torch.get_device_module(self._device).Event()
-                event.record()
-                context.events.append(event)
-
-        # Restore context for model forward
-        for module, context in zip(self._pipelined_postprocs, original_contexts):
-            module.set_context(context)
+        with use_context_for_postprocs(self._pipelined_postprocs, context):
+            with record_function(f"## start_sparse_data_dist {context.index} ##"):
+                with self._stream_context(self._data_dist_stream):
+                    _wait_for_events(batch, context, self._data_dist_stream)
+                    model_input = self.extract_model_input_from_batch(batch)
+                    _start_data_dist(self._pipelined_modules, model_input, context)
+                    event = torch.get_device_module(self._device).Event()
+                    event.record()
+                    context.events.append(event)

     def start_embedding_lookup(
         self,
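Both hunks collapse the same temporarily-set-then-restore pattern into the one context manager. As a quick sanity check of the expected behavior, using the sketch above with hypothetical stand-in classes rather than real torchrec types:

    class _FakeContext:
        def __init__(self, index: int) -> None:
            self.index = index

    class _FakePostproc:
        def __init__(self, ctx: _FakeContext) -> None:
            self._ctx = ctx

        def get_context(self) -> _FakeContext:
            return self._ctx

        def set_context(self, ctx: _FakeContext) -> None:
            self._ctx = ctx

    current = _FakeContext(index=0)
    postprocs = [_FakePostproc(current), _FakePostproc(current)]
    next_ctx = _FakeContext(index=1)

    with use_context_for_postprocs(postprocs, next_ctx):
        # Inside the block, every postproc sees the next iteration's context.
        assert all(p.get_context() is next_ctx for p in postprocs)

    # On exit, the original contexts are back in place for the model forward.
    assert all(p.get_context() is current for p in postprocs)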