diff --git a/Include/internal/pycore_global_strings.h b/Include/internal/pycore_global_strings.h index aa597bc8281a5a..ddabd541e83c9e 100644 --- a/Include/internal/pycore_global_strings.h +++ b/Include/internal/pycore_global_strings.h @@ -204,6 +204,7 @@ struct _Py_global_strings { STRUCT_FOR_ID(_initializing) STRUCT_FOR_ID(_is_text_encoding) STRUCT_FOR_ID(_lock_unlock_module) + STRUCT_FOR_ID(_registrycleared) STRUCT_FOR_ID(_showwarnmsg) STRUCT_FOR_ID(_shutdown) STRUCT_FOR_ID(_slotnames) diff --git a/Include/internal/pycore_runtime_init.h b/Include/internal/pycore_runtime_init.h index 04c1e671235eae..3023b635156034 100644 --- a/Include/internal/pycore_runtime_init.h +++ b/Include/internal/pycore_runtime_init.h @@ -819,6 +819,7 @@ extern "C" { INIT_ID(_initializing), \ INIT_ID(_is_text_encoding), \ INIT_ID(_lock_unlock_module), \ + INIT_ID(_registrycleared), \ INIT_ID(_showwarnmsg), \ INIT_ID(_shutdown), \ INIT_ID(_slotnames), \ diff --git a/Lib/test/test_warnings/__init__.py b/Lib/test/test_warnings/__init__.py index 4b1b4e193cb165..e91339f79c5001 100644 --- a/Lib/test/test_warnings/__init__.py +++ b/Lib/test/test_warnings/__init__.py @@ -1125,6 +1125,15 @@ def test_check_warnings(self): with warnings_helper.check_warnings(('foo', RuntimeWarning)): wmod.warn("foo") + def test_check_warnings_restore_registries(self): + global __warningregistry__ + wmod = self.module + orig_registry = __warningregistry__ = {} + with wmod.catch_warnings(module=wmod): + wmod.warn("foo") + assert len(__warningregistry__) != 0 + assert len(__warningregistry__) == 0 + class CCatchWarningTests(CatchWarningTests, unittest.TestCase): module = c_warnings diff --git a/Lib/warnings.py b/Lib/warnings.py index 691ccddfa450ad..908e1b2bddc0c0 100644 --- a/Lib/warnings.py +++ b/Lib/warnings.py @@ -127,6 +127,9 @@ def _formatwarnmsg(msg): msg.filename, msg.lineno, msg.line) return _formatwarnmsg_impl(msg) +def _registrycleared(registry): + """Hook that notifies when a module warning registry is cleared.""" + def filterwarnings(action, message="", category=Warning, module="", lineno=0, append=False): """Insert an entry into the list of warnings filters (at the front). @@ -335,6 +338,7 @@ def warn_explicit(message, category, filename, lineno, if registry is None: registry = {} if registry.get('version', 0) != _filters_version: + _registrycleared(registry) registry.clear() registry['version'] = _filters_version if isinstance(message, Warning): @@ -445,6 +449,7 @@ def __init__(self, *, record=False, module=None): self._record = record self._module = sys.modules['warnings'] if module is None else module self._entered = False + self._old_registries = [] def __repr__(self): args = [] @@ -461,7 +466,9 @@ def __enter__(self): self._entered = True self._filters = self._module.filters self._module.filters = self._filters[:] - self._module._filters_mutated() + self._orig_registrycleared = self._module._registrycleared + self._module._registrycleared = self._registrycleared + self._filters_version = self._module._filters_mutated() self._showwarning = self._module.showwarning self._showwarnmsg_impl = self._module._showwarnmsg_impl if self._record: @@ -478,10 +485,17 @@ def __exit__(self, *exc_info): if not self._entered: raise RuntimeError("Cannot exit %r without entering first" % self) self._module.filters = self._filters - self._module._filters_mutated() + self._module._registrycleared = self._orig_registrycleared + self._module._set_filters_version(self._filters_version) + for registry, registry_copy in self._old_registries: + registry.clear() + registry.update(registry_copy) self._module.showwarning = self._showwarning self._module._showwarnmsg_impl = self._showwarnmsg_impl + def _registrycleared(self, registry): + self._old_registries.append((registry, registry.copy())) + # Private utility function called by _PyErr_WarnUnawaitedCoroutine def _warn_unawaited_coroutine(coro): @@ -516,7 +530,8 @@ def extract(): # If either if the compiled regexs are None, match anything. try: from _warnings import (filters, _defaultaction, _onceregistry, - warn, warn_explicit, _filters_mutated) + warn, warn_explicit, _filters_mutated, + _set_filters_version) defaultaction = _defaultaction onceregistry = _onceregistry _warnings_defaults = True @@ -529,7 +544,12 @@ def extract(): def _filters_mutated(): global _filters_version + old_filters_version = _filters_version _filters_version += 1 + return old_filters_version + + def _set_filters_version(filters_version): + _filters_version = filters_version _warnings_defaults = False diff --git a/Python/_warnings.c b/Python/_warnings.c index 03e6ffcee0ac24..6bc409b758b119 100644 --- a/Python/_warnings.c +++ b/Python/_warnings.c @@ -392,6 +392,39 @@ get_filter(PyInterpreterState *interp, PyObject *category, } +static int +call_registrycleared(PyInterpreterState *interp, PyObject *registry) +{ + PyObject *_registrycleared, *res; + + _registrycleared = GET_WARNINGS_ATTR(interp, _registrycleared, 0); + if (_registrycleared == NULL) { + if (PyErr_Occurred()) + return -1; + return 0; + } + + if (!PyCallable_Check(_registrycleared)) { + PyErr_SetString(PyExc_TypeError, + "warnings._registrycleared() must be set to a callable"); + goto error; + } + + res = PyObject_CallFunctionObjArgs(_registrycleared, registry, NULL); + Py_DECREF(_registrycleared); + + if (res == NULL); + return -1; + + Py_DECREF(res); + return 0; + +error: + Py_XDECREF(_registrycleared); + return -1; +} + + static int already_warned(PyInterpreterState *interp, PyObject *registry, PyObject *key, int should_set) @@ -413,6 +446,7 @@ already_warned(PyInterpreterState *interp, PyObject *registry, PyObject *key, if (PyErr_Occurred()) { return -1; } + call_registrycleared(interp, registry); PyDict_Clear(registry); version_obj = PyLong_FromLong(st->filters_version); if (version_obj == NULL) @@ -1107,7 +1141,27 @@ warnings_filters_mutated(PyObject *self, PyObject *args) if (st == NULL) { return NULL; } - st->filters_version++; + return PyLong_FromLong(st->filters_version++); +} + +static PyObject * +warnings_set_filters_version(PyObject *self, PyObject *args) +{ + long filters_version; + + if (!PyArg_ParseTuple(args, "l:_set_filters_version", &filters_version)) { + return NULL; + } + + PyInterpreterState *interp = get_current_interp(); + if (interp == NULL) { + return NULL; + } + WarningsState *st = warnings_get_state(interp); + if (st == NULL) { + return NULL; + } + st->filters_version = filters_version; Py_RETURN_NONE; } @@ -1376,6 +1430,8 @@ static PyMethodDef warnings_functions[] = { METH_VARARGS | METH_KEYWORDS, warn_explicit_doc}, {"_filters_mutated", (PyCFunction)warnings_filters_mutated, METH_NOARGS, NULL}, + {"_set_filters_version", (PyCFunction)warnings_set_filters_version, + METH_VARARGS, NULL}, /* XXX(brett.cannon): add showwarning? */ /* XXX(brett.cannon): Reasonable to add formatwarning? */ {NULL, NULL} /* sentinel */ diff --git a/Tools/scripts/generate_global_objects.py b/Tools/scripts/generate_global_objects.py index bad7865f1ff83b..35ace34cb54c62 100644 --- a/Tools/scripts/generate_global_objects.py +++ b/Tools/scripts/generate_global_objects.py @@ -28,6 +28,7 @@ # from GET_WARNINGS_ATTR() in Python/_warnings.c 'WarningMessage', + '_registrycleared', '_showwarnmsg', '_warn_unawaited_coroutine', 'defaultaction',