Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Use inspect.getattr_static instead of _safe_get_attributes
This uses builtins.dir, which will trigger custom __getattibute__ if any, and will trigger __get__ on __dict__ descriptor.
  • Loading branch information
Bobronium committed Oct 5, 2024
commit c68b4cc08f3048703c92422071afa8977eae0a47
39 changes: 17 additions & 22 deletions Lib/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1237,25 +1237,6 @@ def _update_func_cell_for__class__(f, oldcls, newcls):
return False


def _safe_get_attributes(obj):
# we should avoid triggering any user-defined code
# when inspecting attributes if possible

# look for __slots__ descriptors
type_dict = object.__getattribute__(type(obj), "__dict__")
for value in type_dict.values():
if isinstance(value, types.MemberDescriptorType):
yield value.__get__(obj)

instance_dict_descriptor = type_dict.get("__dict__", None)
if not isinstance(instance_dict_descriptor, types.GetSetDescriptorType):
# __dict__ is either not present, or redefined by user
# as custom descriptor, either way, we're done here
return

yield from instance_dict_descriptor.__get__(obj).values()


def _find_inner_functions(obj, seen=None, depth=0):
if seen is None:
seen = set()
Expand All @@ -1271,11 +1252,25 @@ def _find_inner_functions(obj, seen=None, depth=0):
if depth > 2:
return None

for value in _safe_get_attributes(obj):
for attribute in dir(obj):
try:
value = inspect.getattr_static(obj, attribute)
except AttributeError:
continue
builtin_value = inspect.getattr_static(object, attribute, None)
if value is builtin_value:
# don't waste time on builtin objects
continue
if (
# isinstance() would trigger `value.__getattribute__("__class__")`
type(value) is types.MemberDescriptorType
and type not in inspect._static_getmro(type(obj))
):
value = value.__get__(obj)
if isinstance(value, types.FunctionType):
yield inspect.unwrap(value)
return
yield from _find_inner_functions(value, seen, depth)
else:
yield from _find_inner_functions(value, seen, depth)


def _create_slots(defined_fields, inherited_slots, field_names, weakref_slot):
Expand Down
39 changes: 39 additions & 0 deletions Lib/test/test_dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5181,6 +5181,45 @@ def foo(cls):
)
self.assertEqual(context.exception.args, (expected_error_message,))

def test_user_defined_code_execution(self):
class CustomDescriptor:
def __init__(self, f):
self._wrapper = partial(f, value="bar")

def __get__(self, instance, owner):
return object.__getattribute__(self, "_wrapper")(instance)

def __getattribute__(self, name):
if name in {
# these are the bare minimum for the feature to work
"__class__", # accessed on `isinstance(value, Field)`
"__wrapped__", # accessed by unwrap
"__get__", # is required for the descriptor protocol
"__dict__", # is accessed by dir() to work
}:
return object.__getattribute__(self, name)
raise RuntimeError(f"Never should be accessed: {name}")

class B:
def foo(self, value):
return value

@dataclass(slots=True)
class A(B):
@CustomDescriptor
def foo(self, value):
return super().foo(value)

self.assertEqual(A().foo, "bar")

@dataclass(slots=True)
class A(B):
@CustomDescriptor
def foo(self, value):
return super().foo(value)

self.assertEqual(A().foo, "bar")

def test_remembered_class(self):
# Apply the dataclass decorator manually (not when the class
# is created), so that we can keep a reference to the
Expand Down