Skip to content

Commit 4402b3c

Browse files
gh-76785: Minor Improvements to "interpreters" Module (gh-116328)
This includes adding pickle support to various classes, and small changes to improve the maintainability of the low-level _xxinterpqueues module.
1 parent bdba8ef commit 4402b3c

File tree

9 files changed

+337
-88
lines changed

9 files changed

+337
-88
lines changed

Lib/test/support/interpreters/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,14 @@ def __hash__(self):
129129
def __del__(self):
130130
self._decref()
131131

132+
# for pickling:
133+
def __getnewargs__(self):
134+
return (self._id,)
135+
136+
# for pickling:
137+
def __getstate__(self):
138+
return None
139+
132140
def _decref(self):
133141
if not self._ownsref:
134142
return

Lib/test/support/interpreters/channels.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,16 @@ class _ChannelEnd:
3838

3939
_end = None
4040

41-
def __init__(self, cid):
41+
def __new__(cls, cid):
42+
self = super().__new__(cls)
4243
if self._end == 'send':
4344
cid = _channels._channel_id(cid, send=True, force=True)
4445
elif self._end == 'recv':
4546
cid = _channels._channel_id(cid, recv=True, force=True)
4647
else:
4748
raise NotImplementedError(self._end)
4849
self._id = cid
50+
return self
4951

5052
def __repr__(self):
5153
return f'{type(self).__name__}(id={int(self._id)})'
@@ -61,6 +63,14 @@ def __eq__(self, other):
6163
return NotImplemented
6264
return other._id == self._id
6365

66+
# for pickling:
67+
def __getnewargs__(self):
68+
return (int(self._id),)
69+
70+
# for pickling:
71+
def __getstate__(self):
72+
return None
73+
6474
@property
6575
def id(self):
6676
return self._id

Lib/test/support/interpreters/queues.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@
1818
]
1919

2020

21-
class QueueEmpty(_queues.QueueEmpty, queue.Empty):
21+
class QueueEmpty(QueueError, queue.Empty):
2222
"""Raised from get_nowait() when the queue is empty.
2323
2424
It is also raised from get() if it times out.
2525
"""
2626

2727

28-
class QueueFull(_queues.QueueFull, queue.Full):
28+
class QueueFull(QueueError, queue.Full):
2929
"""Raised from put_nowait() when the queue is full.
3030
3131
It is also raised from put() if it times out.
@@ -66,7 +66,7 @@ def __new__(cls, id, /, *, _fmt=None):
6666
else:
6767
raise TypeError(f'id must be an int, got {id!r}')
6868
if _fmt is None:
69-
_fmt = _queues.get_default_fmt(id)
69+
_fmt, = _queues.get_queue_defaults(id)
7070
try:
7171
self = _known_queues[id]
7272
except KeyError:
@@ -93,6 +93,14 @@ def __repr__(self):
9393
def __hash__(self):
9494
return hash(self._id)
9595

96+
# for pickling:
97+
def __getnewargs__(self):
98+
return (self._id,)
99+
100+
# for pickling:
101+
def __getstate__(self):
102+
return None
103+
96104
@property
97105
def id(self):
98106
return self._id
@@ -159,9 +167,8 @@ def put(self, obj, timeout=None, *,
159167
while True:
160168
try:
161169
_queues.put(self._id, obj, fmt)
162-
except _queues.QueueFull as exc:
170+
except QueueFull as exc:
163171
if timeout is not None and time.time() >= end:
164-
exc.__class__ = QueueFull
165172
raise # re-raise
166173
time.sleep(_delay)
167174
else:
@@ -174,11 +181,7 @@ def put_nowait(self, obj, *, syncobj=None):
174181
fmt = _SHARED_ONLY if syncobj else _PICKLED
175182
if fmt is _PICKLED:
176183
obj = pickle.dumps(obj)
177-
try:
178-
_queues.put(self._id, obj, fmt)
179-
except _queues.QueueFull as exc:
180-
exc.__class__ = QueueFull
181-
raise # re-raise
184+
_queues.put(self._id, obj, fmt)
182185

183186
def get(self, timeout=None, *,
184187
_delay=10 / 1000, # 10 milliseconds
@@ -195,9 +198,8 @@ def get(self, timeout=None, *,
195198
while True:
196199
try:
197200
obj, fmt = _queues.get(self._id)
198-
except _queues.QueueEmpty as exc:
201+
except QueueEmpty as exc:
199202
if timeout is not None and time.time() >= end:
200-
exc.__class__ = QueueEmpty
201203
raise # re-raise
202204
time.sleep(_delay)
203205
else:
@@ -216,8 +218,7 @@ def get_nowait(self):
216218
"""
217219
try:
218220
obj, fmt = _queues.get(self._id)
219-
except _queues.QueueEmpty as exc:
220-
exc.__class__ = QueueEmpty
221+
except QueueEmpty as exc:
221222
raise # re-raise
222223
if fmt == _PICKLED:
223224
obj = pickle.loads(obj)
@@ -226,4 +227,4 @@ def get_nowait(self):
226227
return obj
227228

228229

229-
_queues._register_queue_type(Queue)
230+
_queues._register_heap_types(Queue, QueueEmpty, QueueFull)

Lib/test/test_interpreters/test_api.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import pickle
23
import threading
34
from textwrap import dedent
45
import unittest
@@ -261,6 +262,12 @@ def test_equality(self):
261262
self.assertEqual(interp1, interp1)
262263
self.assertNotEqual(interp1, interp2)
263264

265+
def test_pickle(self):
266+
interp = interpreters.create()
267+
data = pickle.dumps(interp)
268+
unpickled = pickle.loads(data)
269+
self.assertEqual(unpickled, interp)
270+
264271

265272
class TestInterpreterIsRunning(TestBase):
266273

Lib/test/test_interpreters/test_channels.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import importlib
2+
import pickle
23
import threading
34
from textwrap import dedent
45
import unittest
@@ -100,6 +101,12 @@ def test_equality(self):
100101
self.assertEqual(ch1, ch1)
101102
self.assertNotEqual(ch1, ch2)
102103

104+
def test_pickle(self):
105+
ch, _ = channels.create()
106+
data = pickle.dumps(ch)
107+
unpickled = pickle.loads(data)
108+
self.assertEqual(unpickled, ch)
109+
103110

104111
class TestSendChannelAttrs(TestBase):
105112

@@ -125,6 +132,12 @@ def test_equality(self):
125132
self.assertEqual(ch1, ch1)
126133
self.assertNotEqual(ch1, ch2)
127134

135+
def test_pickle(self):
136+
_, ch = channels.create()
137+
data = pickle.dumps(ch)
138+
unpickled = pickle.loads(data)
139+
self.assertEqual(unpickled, ch)
140+
128141

129142
class TestSendRecv(TestBase):
130143

Lib/test/test_interpreters/test_queues.py

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,25 @@
11
import importlib
2+
import pickle
23
import threading
34
from textwrap import dedent
45
import unittest
56
import time
67

7-
from test.support import import_helper
8+
from test.support import import_helper, Py_DEBUG
89
# Raise SkipTest if subinterpreters not supported.
910
_queues = import_helper.import_module('_xxinterpqueues')
1011
from test.support import interpreters
1112
from test.support.interpreters import queues
12-
from .utils import _run_output, TestBase
13+
from .utils import _run_output, TestBase as _TestBase
1314

1415

15-
class TestBase(TestBase):
16+
def get_num_queues():
17+
return len(_queues.list_all())
18+
19+
20+
class TestBase(_TestBase):
1621
def tearDown(self):
17-
for qid in _queues.list_all():
22+
for qid, _ in _queues.list_all():
1823
try:
1924
_queues.destroy(qid)
2025
except Exception:
@@ -34,6 +39,58 @@ def test_highlevel_reloaded(self):
3439
# See gh-115490 (https://github.com/python/cpython/issues/115490).
3540
importlib.reload(queues)
3641

42+
def test_create_destroy(self):
43+
qid = _queues.create(2, 0)
44+
_queues.destroy(qid)
45+
self.assertEqual(get_num_queues(), 0)
46+
with self.assertRaises(queues.QueueNotFoundError):
47+
_queues.get(qid)
48+
with self.assertRaises(queues.QueueNotFoundError):
49+
_queues.destroy(qid)
50+
51+
def test_not_destroyed(self):
52+
# It should have cleaned up any remaining queues.
53+
stdout, stderr = self.assert_python_ok(
54+
'-c',
55+
dedent(f"""
56+
import {_queues.__name__} as _queues
57+
_queues.create(2, 0)
58+
"""),
59+
)
60+
self.assertEqual(stdout, '')
61+
if Py_DEBUG:
62+
self.assertNotEqual(stderr, '')
63+
else:
64+
self.assertEqual(stderr, '')
65+
66+
def test_bind_release(self):
67+
with self.subTest('typical'):
68+
qid = _queues.create(2, 0)
69+
_queues.bind(qid)
70+
_queues.release(qid)
71+
self.assertEqual(get_num_queues(), 0)
72+
73+
with self.subTest('bind too much'):
74+
qid = _queues.create(2, 0)
75+
_queues.bind(qid)
76+
_queues.bind(qid)
77+
_queues.release(qid)
78+
_queues.destroy(qid)
79+
self.assertEqual(get_num_queues(), 0)
80+
81+
with self.subTest('nested'):
82+
qid = _queues.create(2, 0)
83+
_queues.bind(qid)
84+
_queues.bind(qid)
85+
_queues.release(qid)
86+
_queues.release(qid)
87+
self.assertEqual(get_num_queues(), 0)
88+
89+
with self.subTest('release without binding'):
90+
qid = _queues.create(2, 0)
91+
with self.assertRaises(queues.QueueError):
92+
_queues.release(qid)
93+
3794

3895
class QueueTests(TestBase):
3996

@@ -127,6 +184,12 @@ def test_equality(self):
127184
self.assertEqual(queue1, queue1)
128185
self.assertNotEqual(queue1, queue2)
129186

187+
def test_pickle(self):
188+
queue = queues.create()
189+
data = pickle.dumps(queue)
190+
unpickled = pickle.loads(data)
191+
self.assertEqual(unpickled, queue)
192+
130193

131194
class TestQueueOps(TestBase):
132195

Modules/_interpreters_common.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,11 @@ ensure_xid_class(PyTypeObject *cls, crossinterpdatafunc getdata)
1111
//assert(cls->tp_flags & Py_TPFLAGS_HEAPTYPE);
1212
return _PyCrossInterpreterData_RegisterClass(cls, getdata);
1313
}
14+
15+
#ifdef REGISTERS_HEAP_TYPES
16+
static int
17+
clear_xid_class(PyTypeObject *cls)
18+
{
19+
return _PyCrossInterpreterData_UnregisterClass(cls);
20+
}
21+
#endif

Modules/_xxinterpchannelsmodule.c

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
#include <sched.h> // sched_yield()
1818
#endif
1919

20+
#define REGISTERS_HEAP_TYPES
2021
#include "_interpreters_common.h"
22+
#undef REGISTERS_HEAP_TYPES
2123

2224

2325
/*
@@ -281,17 +283,17 @@ clear_xid_types(module_state *state)
281283
{
282284
/* external types */
283285
if (state->send_channel_type != NULL) {
284-
(void)_PyCrossInterpreterData_UnregisterClass(state->send_channel_type);
286+
(void)clear_xid_class(state->send_channel_type);
285287
Py_CLEAR(state->send_channel_type);
286288
}
287289
if (state->recv_channel_type != NULL) {
288-
(void)_PyCrossInterpreterData_UnregisterClass(state->recv_channel_type);
290+
(void)clear_xid_class(state->recv_channel_type);
289291
Py_CLEAR(state->recv_channel_type);
290292
}
291293

292294
/* heap types */
293295
if (state->ChannelIDType != NULL) {
294-
(void)_PyCrossInterpreterData_UnregisterClass(state->ChannelIDType);
296+
(void)clear_xid_class(state->ChannelIDType);
295297
Py_CLEAR(state->ChannelIDType);
296298
}
297299
}
@@ -2677,11 +2679,11 @@ set_channelend_types(PyObject *mod, PyTypeObject *send, PyTypeObject *recv)
26772679

26782680
// Clear the old values if the .py module was reloaded.
26792681
if (state->send_channel_type != NULL) {
2680-
(void)_PyCrossInterpreterData_UnregisterClass(state->send_channel_type);
2682+
(void)clear_xid_class(state->send_channel_type);
26812683
Py_CLEAR(state->send_channel_type);
26822684
}
26832685
if (state->recv_channel_type != NULL) {
2684-
(void)_PyCrossInterpreterData_UnregisterClass(state->recv_channel_type);
2686+
(void)clear_xid_class(state->recv_channel_type);
26852687
Py_CLEAR(state->recv_channel_type);
26862688
}
26872689

@@ -2694,7 +2696,7 @@ set_channelend_types(PyObject *mod, PyTypeObject *send, PyTypeObject *recv)
26942696
return -1;
26952697
}
26962698
if (ensure_xid_class(recv, _channelend_shared) < 0) {
2697-
(void)_PyCrossInterpreterData_UnregisterClass(state->send_channel_type);
2699+
(void)clear_xid_class(state->send_channel_type);
26982700
Py_CLEAR(state->send_channel_type);
26992701
Py_CLEAR(state->recv_channel_type);
27002702
return -1;

0 commit comments

Comments
 (0)