Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Factor out fix_up_extension_for_interpreter().
  • Loading branch information
ericsnowcurrently committed Apr 23, 2024
commit d87d2128422ab3a37563d6d6501a395724fb53f5
1 change: 1 addition & 0 deletions Include/internal/pycore_import.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ extern int _PyImport_FixupBuiltin(
const char *name, /* UTF-8 encoded string */
PyObject *modules
);
// We could probably drop this:
extern int _PyImport_FixupExtensionObject(PyObject*, PyObject *,
PyObject *, PyObject *);

Expand Down
100 changes: 71 additions & 29 deletions Python/import.c
Original file line number Diff line number Diff line change
Expand Up @@ -200,16 +200,22 @@ _PyImport_ClearModules(PyInterpreterState *interp)
Py_SETREF(MODULES(interp), NULL);
}

PyObject *
PyImport_GetModuleDict(void)
static inline PyObject *
get_modules_dict(PyInterpreterState *interp)
{
PyInterpreterState *interp = _PyInterpreterState_GET();
if (MODULES(interp) == NULL) {
Py_FatalError("interpreter has no modules dictionary");
}
return MODULES(interp);
}

PyObject *
PyImport_GetModuleDict(void)
{
PyInterpreterState *interp = _PyInterpreterState_GET();
return get_modules_dict(interp);
}

int
_PyImport_SetModule(PyObject *name, PyObject *m)
{
Expand Down Expand Up @@ -894,7 +900,7 @@ extensions_lock_release(void)
(module name, module name) (for built-in modules) or by
(filename, module name) (for dynamically loaded modules), containing these
modules. A copy of the module's dictionary is stored by calling
_PyImport_FixupExtensionObject() immediately after the module initialization
fix_up_extension() immediately after the module initialization
function succeeds. A copy can be retrieved from there by calling
import_find_extension().

Expand Down Expand Up @@ -1159,28 +1165,57 @@ is_core_module(PyInterpreterState *interp, PyObject *name, PyObject *path)
}

static int
fix_up_extension(PyObject *mod, PyObject *name, PyObject *path)
fix_up_extension_for_interpreter(PyInterpreterState *interp,
PyObject *mod, PyModuleDef *def,
PyObject *name, PyObject *path,
PyObject *modules)
{
if (mod == NULL || !PyModule_Check(mod)) {
PyErr_BadInternalCall();
assert(mod != NULL && PyModule_Check(mod));
assert(def == PyModule_GetDef(mod));

if (modules == NULL) {
modules = get_modules_dict(interp);
}
if (PyObject_SetItem(modules, name, mod) < 0) {
return -1;
}

struct PyModuleDef *def = PyModule_GetDef(mod);
if (!def) {
PyErr_BadInternalCall();
return -1;
if (_modules_by_index_set(interp, def, mod) < 0) {
goto error;
}

PyThreadState *tstate = _PyThreadState_GET();
if (_modules_by_index_set(tstate->interp, def, mod) < 0) {
return 0;

error:
PyMapping_DelItem(modules, name);
return -1;
}


static int
fix_up_extension(PyObject *mod, PyModuleDef *def,
PyObject *name, PyObject *path,
PyObject *modules)
{
PyInterpreterState *interp = _PyInterpreterState_GET();
if (def == NULL) {
def = PyModule_GetDef(mod);
if (def == NULL) {
PyErr_BadInternalCall();
return -1;
}
}

if (fix_up_extension_for_interpreter(
interp, mod, def, name, path, modules) < 0)
{
return -1;
}

// bpo-44050: Extensions and def->m_base.m_copy can be updated
// when the extension module doesn't support sub-interpreters.
if (def->m_size == -1) {
if (!is_core_module(tstate->interp, name, path)) {
if (!is_core_module(interp, name, path)) {
assert(PyUnicode_CompareWithASCIIString(name, "sys") != 0);
assert(PyUnicode_CompareWithASCIIString(name, "builtins") != 0);
if (def->m_base.m_copy) {
Expand All @@ -1191,34 +1226,43 @@ fix_up_extension(PyObject *mod, PyObject *name, PyObject *path)
}
PyObject *dict = PyModule_GetDict(mod);
if (dict == NULL) {
return -1;
goto error;
}
def->m_base.m_copy = PyDict_Copy(dict);
if (def->m_base.m_copy == NULL) {
return -1;
goto error;
}
}
}

// XXX Why special-case the main interpreter?
if (_Py_IsMainInterpreter(tstate->interp) || def->m_size == -1) {
if (_Py_IsMainInterpreter(interp) || def->m_size == -1) {
#ifndef NDEBUG
PyModuleDef *cached = _extensions_cache_get(path, name);
assert(cached == NULL || cached == def);
#endif
if (_extensions_cache_set(path, name, def) < 0) {
return -1;
goto error;
}
}

return 0;


error:
PyMapping_DelItem(modules, name);
return -1;
}

int
_PyImport_FixupExtensionObject(PyObject *mod, PyObject *name,
PyObject *filename, PyObject *modules)
{
if (PyObject_SetItem(modules, name, mod) < 0) {
if (mod == NULL || !PyModule_Check(mod)) {
PyErr_BadInternalCall();
return -1;
}
if (fix_up_extension(mod, name, filename) < 0) {
PyMapping_DelItem(modules, name);
if (fix_up_extension(mod, NULL, name, filename, modules) < 0) {
return -1;
}
return 0;
Expand Down Expand Up @@ -1350,11 +1394,7 @@ _PyImport_FixupBuiltin(PyObject *mod, const char *name, PyObject *modules)
goto finally;
}

if (PyObject_SetItem(modules, nameobj, mod) < 0) {
goto finally;
}
if (fix_up_extension(mod, nameobj, nameobj) < 0) {
PyMapping_DelItem(modules, nameobj);
if (fix_up_extension(mod, def, nameobj, nameobj, modules) < 0) {
goto finally;
}

Expand Down Expand Up @@ -1391,7 +1431,6 @@ create_builtin(PyThreadState *tstate, PyObject *name, PyObject *spec)
return mod;
}

PyObject *modules = MODULES(tstate->interp);
struct _inittab *found = NULL;
for (struct _inittab *p = INITTAB; p->name != NULL; p++) {
if (_PyUnicode_EqualToASCIIString(name, p->name)) {
Expand Down Expand Up @@ -1419,14 +1458,17 @@ create_builtin(PyThreadState *tstate, PyObject *name, PyObject *spec)
return PyModule_FromDefAndSpec((PyModuleDef*)mod, spec);
}
else {
/* Remember pointer to module init function. */
assert(PyModule_Check(mod));
PyModuleDef *def = PyModule_GetDef(mod);
if (def == NULL) {
return NULL;
}

/* Remember pointer to module init function. */
def->m_base.m_init = p0;
if (_PyImport_FixupExtensionObject(mod, name, name, modules) < 0) {

PyObject *modules = MODULES(tstate->interp);
if (fix_up_extension(mod, def, name, name, modules) < 0) {
return NULL;
}
return mod;
Expand Down