diff --git a/crates/derive-impl/src/pyclass.rs b/crates/derive-impl/src/pyclass.rs index 57cbf67de5a..e54c1a867fd 100644 --- a/crates/derive-impl/src/pyclass.rs +++ b/crates/derive-impl/src/pyclass.rs @@ -574,51 +574,80 @@ pub(crate) fn impl_pyclass(attr: PunctuatedNestedMeta, item: Item) -> Result) { + #try_clear_body } - assert_eq!(s, "manual"); - quote! {} - } else { - quote! {#[derive(Traverse)]} - }; - (maybe_trace_code, derive_trace) - } else { - ( - // a dummy impl, which do nothing - // #attrs - quote! { - impl ::rustpython_vm::object::MaybeTraverse for #ident { - fn try_traverse(&self, tracer_fn: &mut ::rustpython_vm::object::TraverseFn) { - // do nothing - } - } - }, - quote! {}, - ) + } } }; @@ -675,7 +704,7 @@ pub(crate) fn impl_pyclass(attr: PunctuatedNestedMeta, item: Item) -> Result) { self.0.try_traverse(traverse_fn) } + + fn try_clear(&mut self, out: &mut ::std::vec::Vec<::rustpython_vm::PyObjectRef>) { + self.0.try_clear(out) + } } // PySubclass for proper inheritance diff --git a/crates/derive-impl/src/util.rs b/crates/derive-impl/src/util.rs index 6be1fcdf7ad..b09ad9c93fe 100644 --- a/crates/derive-impl/src/util.rs +++ b/crates/derive-impl/src/util.rs @@ -372,6 +372,7 @@ impl ItemMeta for ClassItemMeta { "ctx", "impl", "traverse", + "clear", // tp_clear ]; fn from_inner(inner: ItemMetaInner) -> Self { diff --git a/crates/vm/src/builtins/dict.rs b/crates/vm/src/builtins/dict.rs index d1adb8a066d..fcb51c2ca0e 100644 --- a/crates/vm/src/builtins/dict.rs +++ b/crates/vm/src/builtins/dict.rs @@ -2,6 +2,7 @@ use super::{ IterStatus, PositionIterInternal, PyBaseExceptionRef, PyGenericAlias, PyMappingProxy, PySet, PyStr, PyStrRef, PyTupleRef, PyType, PyTypeRef, set::PySetInner, }; +use crate::object::{Traverse, TraverseFn}; use crate::{ AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyRefExact, PyResult, TryFromObject, atomic_func, @@ -29,13 +30,28 @@ use std::sync::LazyLock; pub type DictContentType = dict_inner::Dict; -#[pyclass(module = false, name = "dict", unhashable = true, traverse)] +#[pyclass(module = false, name = "dict", unhashable = true, traverse = "manual")] #[derive(Default)] pub struct PyDict { entries: DictContentType, } pub type PyDictRef = PyRef; +// SAFETY: Traverse properly visits all owned PyObjectRefs +unsafe impl Traverse for PyDict { + fn traverse(&self, traverse_fn: &mut TraverseFn<'_>) { + self.entries.traverse(traverse_fn); + } + + fn clear(&mut self, out: &mut Vec) { + // Pop all entries and collect both keys and values + for (key, value) in self.entries.drain_entries() { + out.push(key); + out.push(value); + } + } +} + impl fmt::Debug for PyDict { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // TODO: implement more detailed, non-recursive Debug formatter diff --git a/crates/vm/src/builtins/function.rs b/crates/vm/src/builtins/function.rs index 9297cf07201..632fd867d2e 100644 --- a/crates/vm/src/builtins/function.rs +++ b/crates/vm/src/builtins/function.rs @@ -51,6 +51,70 @@ unsafe impl Traverse for PyFunction { closure.as_untyped().traverse(tracer_fn); } self.defaults_and_kwdefaults.traverse(tracer_fn); + // Traverse additional fields that may contain references + self.type_params.lock().traverse(tracer_fn); + self.annotations.lock().traverse(tracer_fn); + self.module.lock().traverse(tracer_fn); + self.doc.lock().traverse(tracer_fn); + } + + fn clear(&mut self, out: &mut Vec) { + // Pop closure if present (equivalent to Py_CLEAR(func_closure)) + if let Some(closure) = self.closure.take() { + out.push(closure.into()); + } + + // Pop defaults and kwdefaults + if let Some(mut guard) = self.defaults_and_kwdefaults.try_lock() { + if let Some(defaults) = guard.0.take() { + out.push(defaults.into()); + } + if let Some(kwdefaults) = guard.1.take() { + out.push(kwdefaults.into()); + } + } + + // Clear annotations and annotate (Py_CLEAR) + if let Some(mut guard) = self.annotations.try_lock() + && let Some(annotations) = guard.take() + { + out.push(annotations.into()); + } + if let Some(mut guard) = self.annotate.try_lock() + && let Some(annotate) = guard.take() + { + out.push(annotate); + } + + // Clear module, doc, and type_params (Py_CLEAR) + if let Some(mut guard) = self.module.try_lock() { + let old_module = + std::mem::replace(&mut *guard, Context::genesis().none.to_owned().into()); + out.push(old_module); + } + if let Some(mut guard) = self.doc.try_lock() { + let old_doc = std::mem::replace(&mut *guard, Context::genesis().none.to_owned().into()); + out.push(old_doc); + } + if let Some(mut guard) = self.type_params.try_lock() { + let old_type_params = + std::mem::replace(&mut *guard, Context::genesis().empty_tuple.to_owned()); + out.push(old_type_params.into()); + } + + // Replace name and qualname with empty string to break potential str subclass cycles + // name and qualname could be str subclasses, so they could have reference cycles + if let Some(mut guard) = self.name.try_lock() { + let old_name = std::mem::replace(&mut *guard, Context::genesis().empty_str.to_owned()); + out.push(old_name.into()); + } + if let Some(mut guard) = self.qualname.try_lock() { + let old_qualname = + std::mem::replace(&mut *guard, Context::genesis().empty_str.to_owned()); + out.push(old_qualname.into()); + } + + // Note: globals, builtins, code are NOT cleared (required to be non-NULL) } } diff --git a/crates/vm/src/builtins/list.rs b/crates/vm/src/builtins/list.rs index 02475ee12b6..84825de7d3d 100644 --- a/crates/vm/src/builtins/list.rs +++ b/crates/vm/src/builtins/list.rs @@ -3,6 +3,7 @@ use crate::atomic_func; use crate::common::lock::{ PyMappedRwLockReadGuard, PyMutex, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard, }; +use crate::object::{Traverse, TraverseFn}; use crate::{ AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, class::PyClassImpl, @@ -23,7 +24,7 @@ use crate::{ use alloc::fmt; use core::ops::DerefMut; -#[pyclass(module = false, name = "list", unhashable = true, traverse)] +#[pyclass(module = false, name = "list", unhashable = true, traverse = "manual")] #[derive(Default)] pub struct PyList { elements: PyRwLock>, @@ -50,6 +51,22 @@ impl FromIterator for PyList { } } +// SAFETY: Traverse properly visits all owned PyObjectRefs +unsafe impl Traverse for PyList { + fn traverse(&self, traverse_fn: &mut TraverseFn<'_>) { + self.elements.traverse(traverse_fn); + } + + fn clear(&mut self, out: &mut Vec) { + // During GC, we use interior mutability to access elements. + // This is safe because during GC collection, the object is unreachable + // and no other code should be accessing it. + if let Some(mut guard) = self.elements.try_write() { + out.extend(guard.drain(..)); + } + } +} + impl PyPayload for PyList { #[inline] fn class(ctx: &Context) -> &'static Py { diff --git a/crates/vm/src/builtins/str.rs b/crates/vm/src/builtins/str.rs index 640778c8cb9..d765847c1ab 100644 --- a/crates/vm/src/builtins/str.rs +++ b/crates/vm/src/builtins/str.rs @@ -1924,9 +1924,16 @@ impl fmt::Display for PyUtf8Str { } impl MaybeTraverse for PyUtf8Str { + const HAS_TRAVERSE: bool = true; + const HAS_CLEAR: bool = false; + fn try_traverse(&self, traverse_fn: &mut TraverseFn<'_>) { self.0.try_traverse(traverse_fn); } + + fn try_clear(&mut self, _out: &mut Vec) { + // No clear needed for PyUtf8Str + } } impl PyPayload for PyUtf8Str { diff --git a/crates/vm/src/builtins/tuple.rs b/crates/vm/src/builtins/tuple.rs index f6eff5b91e5..ba296686c73 100644 --- a/crates/vm/src/builtins/tuple.rs +++ b/crates/vm/src/builtins/tuple.rs @@ -3,6 +3,7 @@ use crate::common::{ hash::{PyHash, PyUHash}, lock::PyMutex, }; +use crate::object::{Traverse, TraverseFn}; use crate::{ AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, atomic_func, @@ -24,7 +25,7 @@ use crate::{ use alloc::fmt; use std::sync::LazyLock; -#[pyclass(module = false, name = "tuple", traverse)] +#[pyclass(module = false, name = "tuple", traverse = "manual")] pub struct PyTuple { elements: Box<[R]>, } @@ -36,6 +37,19 @@ impl fmt::Debug for PyTuple { } } +// SAFETY: Traverse properly visits all owned PyObjectRefs +// Note: Only impl for PyTuple (the default) +unsafe impl Traverse for PyTuple { + fn traverse(&self, traverse_fn: &mut TraverseFn<'_>) { + self.elements.traverse(traverse_fn); + } + + fn clear(&mut self, out: &mut Vec) { + let elements = std::mem::take(&mut self.elements); + out.extend(elements.into_vec()); + } +} + impl PyPayload for PyTuple { #[inline] fn class(ctx: &Context) -> &'static Py { diff --git a/crates/vm/src/dict_inner.rs b/crates/vm/src/dict_inner.rs index 1d9fe8403ab..f2a379d99a5 100644 --- a/crates/vm/src/dict_inner.rs +++ b/crates/vm/src/dict_inner.rs @@ -724,6 +724,17 @@ impl Dict { + inner.indices.len() * size_of::() + inner.entries.len() * size_of::>() } + + /// Pop all entries from the dict, returning (key, value) pairs. + /// This is used for circular reference resolution in GC. + /// Requires &mut self to avoid lock contention. + pub fn drain_entries(&mut self) -> impl Iterator + '_ { + let inner = self.inner.get_mut(); + inner.used = 0; + inner.filled = 0; + inner.indices.iter_mut().for_each(|i| *i = IndexEntry::FREE); + inner.entries.drain(..).flatten().map(|e| (e.key, e.value)) + } } type LookupResult = (IndexEntry, IndexIndex); diff --git a/crates/vm/src/frame.rs b/crates/vm/src/frame.rs index d83be841275..09da89089b9 100644 --- a/crates/vm/src/frame.rs +++ b/crates/vm/src/frame.rs @@ -13,6 +13,7 @@ use crate::{ coroutine::Coro, exceptions::ExceptionCtor, function::{ArgMapping, Either, FuncArgs}, + object::{Traverse, TraverseFn}, protocol::{PyIter, PyIterReturn}, scope::Scope, stdlib::{builtins, typing}, @@ -66,7 +67,7 @@ type Lasti = atomic::AtomicU32; #[cfg(not(feature = "threading"))] type Lasti = core::cell::Cell; -#[pyclass(module = false, name = "frame")] +#[pyclass(module = false, name = "frame", traverse = "manual")] pub struct Frame { pub code: PyRef, pub func_obj: Option, @@ -97,6 +98,27 @@ impl PyPayload for Frame { } } +unsafe impl Traverse for FrameState { + fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) { + self.stack.traverse(tracer_fn); + } +} + +unsafe impl Traverse for Frame { + fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) { + self.code.traverse(tracer_fn); + self.func_obj.traverse(tracer_fn); + self.fastlocals.traverse(tracer_fn); + self.cells_frees.traverse(tracer_fn); + self.locals.traverse(tracer_fn); + self.globals.traverse(tracer_fn); + self.builtins.traverse(tracer_fn); + self.trace.traverse(tracer_fn); + self.state.traverse(tracer_fn); + self.temporary_refs.traverse(tracer_fn); + } +} + // Running a frame can result in one of the below: pub enum ExecutionResult { Return(PyObjectRef), diff --git a/crates/vm/src/object/core.rs b/crates/vm/src/object/core.rs index 4e51e296462..c949cae9053 100644 --- a/crates/vm/src/object/core.rs +++ b/crates/vm/src/object/core.rs @@ -93,7 +93,7 @@ pub(super) unsafe fn debug_obj( } /// Call `try_trace` on payload -pub(super) unsafe fn try_trace_obj(x: &PyObject, tracer_fn: &mut TraverseFn<'_>) { +pub(super) unsafe fn try_traverse_obj(x: &PyObject, tracer_fn: &mut TraverseFn<'_>) { let x = unsafe { &*(x as *const PyObject as *const PyInner) }; let payload = &x.payload; payload.try_traverse(tracer_fn) diff --git a/crates/vm/src/object/traverse.rs b/crates/vm/src/object/traverse.rs index 2ce0db41a5e..367076b78e3 100644 --- a/crates/vm/src/object/traverse.rs +++ b/crates/vm/src/object/traverse.rs @@ -12,9 +12,13 @@ pub type TraverseFn<'a> = dyn FnMut(&PyObject) + 'a; /// Every PyObjectPayload impl `MaybeTrace`, which may or may not be traceable pub trait MaybeTraverse { /// if is traceable, will be used by vtable to determine - const IS_TRACE: bool = false; + const HAS_TRAVERSE: bool = false; + /// if has clear implementation for circular reference resolution (tp_clear) + const HAS_CLEAR: bool = false; // if this type is traceable, then call with tracer_fn, default to do nothing fn try_traverse(&self, traverse_fn: &mut TraverseFn<'_>); + // if this type has clear, extract child refs for circular reference resolution (tp_clear) + fn try_clear(&mut self, _out: &mut Vec) {} } /// Type that need traverse it's children should impl [`Traverse`] (not [`MaybeTraverse`]) @@ -28,6 +32,11 @@ pub unsafe trait Traverse { /// /// - _**DO NOT**_ clone a [`PyObjectRef`] or [`PyRef`] in [`Traverse::traverse()`] fn traverse(&self, traverse_fn: &mut TraverseFn<'_>); + + /// Extract all owned child PyObjectRefs for circular reference resolution (tp_clear). + /// Called just before object deallocation to break circular references. + /// Default implementation does nothing. + fn clear(&mut self, _out: &mut Vec) {} } unsafe impl Traverse for PyObjectRef { diff --git a/crates/vm/src/object/traverse_object.rs b/crates/vm/src/object/traverse_object.rs index 7a66f0b35f0..840bbd42b39 100644 --- a/crates/vm/src/object/traverse_object.rs +++ b/crates/vm/src/object/traverse_object.rs @@ -5,7 +5,7 @@ use crate::{ PyObject, object::{ Erased, InstanceDict, MaybeTraverse, PyInner, PyObjectPayload, debug_obj, drop_dealloc_obj, - try_trace_obj, + try_traverse_obj, }, }; @@ -25,8 +25,8 @@ impl PyObjVTable { drop_dealloc: drop_dealloc_obj::, debug: debug_obj::, trace: const { - if T::IS_TRACE { - Some(try_trace_obj::) + if T::HAS_TRAVERSE { + Some(try_traverse_obj::) } else { None }