Skip to content

Commit e3cab4e

Browse files
committed
Side effect free & lazy multiprocessing context
1 parent 11c41b5 commit e3cab4e

File tree

7 files changed

+22
-8
lines changed

7 files changed

+22
-8
lines changed

distributed/nanny.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@
3636
from .utils import (
3737
TimeoutError,
3838
get_ip,
39+
get_mp_context,
3940
json_load_robust,
4041
log_errors,
41-
mp_context,
4242
parse_ports,
4343
silence_logging,
4444
)
@@ -662,6 +662,7 @@ async def start(self) -> Status:
662662
await self.running.wait()
663663
return self.status
664664

665+
mp_context = get_mp_context()
665666
self.init_result_q = init_q = mp_context.Queue()
666667
self.child_stop_q = mp_context.Queue()
667668
uid = uuid.uuid4().hex

distributed/process.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import dask
1313

14-
from .utils import TimeoutError, mp_context
14+
from .utils import TimeoutError, get_mp_context
1515

1616
logger = logging.getLogger(__name__)
1717

@@ -65,6 +65,7 @@ def __init__(self, loop=None, target=None, name=None, args=(), kwargs={}):
6565
# monitor from the child and exit when the parent goes away unexpectedly
6666
# (for example due to SIGKILL). This variable is otherwise unused except
6767
# for the assignment here.
68+
mp_context = get_mp_context()
6869
parent_alive_pipe, self._keep_child_alive = mp_context.Pipe(duplex=False)
6970

7071
self._process = mp_context.Process(

distributed/tests/test_asyncprocess.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from distributed.compatibility import WINDOWS
1717
from distributed.metrics import time
1818
from distributed.process import AsyncProcess
19-
from distributed.utils import mp_context
19+
from distributed.utils import get_mp_context
2020
from distributed.utils_test import gen_test, nodebug, pristine_loop
2121

2222

@@ -53,6 +53,7 @@ def threads_info(q):
5353
@nodebug
5454
@gen_test()
5555
async def test_simple():
56+
mp_context = get_mp_context()
5657
to_child = mp_context.Queue()
5758
from_child = mp_context.Queue()
5859

@@ -143,6 +144,7 @@ async def test_simple():
143144

144145
@gen_test()
145146
async def test_exitcode():
147+
mp_context = get_mp_context()
146148
q = mp_context.Queue()
147149

148150
proc = AsyncProcess(target=exit, kwargs={"q": q})
@@ -220,6 +222,7 @@ async def test_close():
220222

221223
@gen_test()
222224
async def test_exit_callback():
225+
mp_context = get_mp_context()
223226
to_child = mp_context.Queue()
224227
from_child = mp_context.Queue()
225228
evt = Event()
@@ -268,6 +271,7 @@ async def test_child_main_thread():
268271
"""
269272
The main thread in the child should be called "MainThread".
270273
"""
274+
mp_context = get_mp_context()
271275
q = mp_context.Queue()
272276
proc = AsyncProcess(target=threads_info, args=(q,))
273277
await proc.start()
@@ -339,6 +343,7 @@ def _parent_process(child_pipe):
339343
be used to determine if it exited correctly."""
340344

341345
async def parent_process_coroutine():
346+
mp_context = get_mp_context()
342347
worker_ready = mp_context.Event()
343348

344349
worker = AsyncProcess(target=_worker_process, args=(worker_ready, child_pipe))
@@ -377,6 +382,7 @@ def test_asyncprocess_child_teardown_on_parent_exit():
377382
\________ <-- child_pipe <-- ________/
378383
"""
379384
# When child_pipe is closed, the children_alive pipe unblocks.
385+
mp_context = get_mp_context()
380386
children_alive, child_pipe = mp_context.Pipe(duplex=False)
381387

382388
try:

distributed/tests/test_client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
Scheduler,
6969
)
7070
from distributed.sizeof import sizeof
71-
from distributed.utils import is_valid_xml, mp_context, sync, tmp_text
71+
from distributed.utils import get_mp_context, is_valid_xml, sync, tmp_text
7272
from distributed.utils_test import (
7373
TaskStateMetadataPlugin,
7474
_UnhashableCallable,
@@ -2201,7 +2201,9 @@ def long_running_client_connection(address):
22012201

22022202
@gen_cluster()
22032203
async def test_cleanup_after_broken_client_connection(s, a, b):
2204-
proc = mp_context.Process(target=long_running_client_connection, args=(s.address,))
2204+
proc = get_mp_context().Process(
2205+
target=long_running_client_connection, args=(s.address,)
2206+
)
22052207
proc.daemon = True
22062208
proc.start()
22072209

distributed/tests/test_diskutils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from distributed.diskutils import WorkSpace
1616
from distributed.metrics import time
17-
from distributed.utils import mp_context
17+
from distributed.utils import get_mp_context
1818
from distributed.utils_test import captured_logger
1919

2020

@@ -220,6 +220,7 @@ def test_workspace_concurrency(tmpdir, timeout, max_procs):
220220
deadlock happens.
221221
"""
222222
base_dir = str(tmpdir)
223+
mp_context = get_mp_context()
223224

224225
err_q = mp_context.Queue()
225226
purged_q = mp_context.Queue()

distributed/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def _initialize_mp_context():
9393
return ctx
9494

9595

96-
mp_context = _initialize_mp_context()
96+
get_mp_context = toolz.memoize(_initialize_mp_context)
9797

9898

9999
def has_arg(func, argname):

distributed/utils_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@
6464
_offload_executor,
6565
get_ip,
6666
get_ipv6,
67+
get_mp_context,
6768
iscoroutinefunction,
6869
log_errors,
69-
mp_context,
7070
reset_logger_locks,
7171
sync,
7272
thread_state,
@@ -626,6 +626,8 @@ def cluster(
626626
enable_proctitle_on_children()
627627

628628
with clean(timeout=active_rpc_timeout, threads=False) as loop:
629+
mp_context = get_mp_context()
630+
629631
if nanny:
630632
_run_worker = run_nanny
631633
else:
@@ -1528,6 +1530,7 @@ def check_thread_leak():
15281530

15291531
@contextmanager
15301532
def check_process_leak(check=True):
1533+
mp_context = get_mp_context()
15311534
for proc in mp_context.active_children():
15321535
proc.terminate()
15331536

0 commit comments

Comments
 (0)