Skip to content

Commit d783fdf

Browse files
committed
improve traverse
Signed-off-by: Ashwin Naren <arihant2math@gmail.com>
1 parent 6589c9b commit d783fdf

5 files changed

Lines changed: 106 additions & 20 deletions

File tree

vm/src/object/core.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -464,14 +464,17 @@ impl<T: PyObjectPayload> PyInner<T> {
464464
}
465465
}
466466

467+
static CURRENT_TAG: PyAtomic<usize> = Radium::new(0);
468+
467469
/// The `PyObjectRef` is one of the most used types. It is a reference to a
468470
/// python object. A single python object can have multiple references, and
469471
/// this reference counting is accounted for by this type. Use the `.clone()`
470-
/// method to create a new reference and increment the amount of references
472+
/// method to create a new reference and increment the number of references
471473
/// to the python object by 1.
472-
#[repr(transparent)]
473474
pub struct PyObjectRef {
474475
pub(crate) ptr: NonNull<PyObject>,
476+
#[cfg(feature = "gc")]
477+
pub(crate) tag: PyAtomic<usize>
475478
}
476479

477480
impl Clone for PyObjectRef {
@@ -507,6 +510,7 @@ impl ToOwned for PyObject {
507510
self.0.ref_count.inc();
508511
PyObjectRef {
509512
ptr: NonNull::from(self),
513+
tag: Radium::new(0),
510514
}
511515
}
512516
}
@@ -528,6 +532,7 @@ impl PyObjectRef {
528532
pub unsafe fn from_raw(ptr: *const PyObject) -> Self {
529533
Self {
530534
ptr: NonNull::new_unchecked(ptr as *mut PyObject),
535+
tag: Radium::new(0),
531536
}
532537
}
533538

@@ -1085,7 +1090,10 @@ where
10851090
#[inline]
10861091
fn from(value: PyRef<T>) -> Self {
10871092
let me = ManuallyDrop::new(value);
1088-
PyObjectRef { ptr: me.ptr.cast() }
1093+
PyObjectRef {
1094+
ptr: me.ptr.cast(),
1095+
tag: Radium::new(0),
1096+
}
10891097
}
10901098
}
10911099

vm/src/object/gc/collect.rs

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ where
198198
/// need to be cleaned again.
199199
pub(super) fn mark_clean<T>(allocation: &PyInner<T>)
200200
where
201-
T: Traverse + Send + Sync + ?Sized,
201+
T: Traverse + Send + Sync + Sized,
202202
{
203203
DUMPSTER.with(|dumpster| {
204204
if dumpster
@@ -357,7 +357,7 @@ impl GarbageTruck {
357357
/// # Safety
358358
///
359359
/// `ptr` must have been created as a pointer to a `PyInner<T>`.
360-
unsafe fn dfs<T: Traverse + Send + Sync + ?Sized>(
360+
unsafe fn dfs<T: Traverse + Send + Sync>(
361361
ptr: Erased,
362362
ref_graph: &mut HashMap<AllocationId, AllocationInfo>,
363363
) {
@@ -384,8 +384,7 @@ unsafe fn dfs<T: Traverse + Send + Sync + ?Sized>(
384384
});
385385

386386
if box_ref
387-
.value
388-
.accept(&mut Dfs {
387+
.traverse(&mut Dfs {
389388
ref_graph,
390389
current_id: starting_id,
391390
})
@@ -412,13 +411,13 @@ struct Dfs<'a> {
412411
impl<'a> Visitor for Dfs<'a> {
413412
fn visit_sync<T>(&mut self, gc: &PyObjectRef)
414413
where
415-
T: Traverse + Send + Sync + ?Sized,
414+
T: Traverse + Send + Sync + Sized,
416415
{
417416
// must not use deref operators since we don't want to update the generation
418417
let ptr = unsafe {
419418
// SAFETY: This is the same as the deref implementation, but avoids
420419
// incrementing the generation count.
421-
(*gc.ptr.get()).unwrap()
420+
(*gc.ptr.read()).unwrap()
422421
};
423422
let box_ref = unsafe {
424423
// SAFETY: same as above.
@@ -509,7 +508,7 @@ fn mark(root: AllocationId, graph: &mut HashMap<AllocationId, AllocationInfo>) {
509508
/// # Safety
510509
///
511510
/// `ptr` must have been created from a pointer to a `PyInner<T>`.
512-
unsafe fn destroy_erased<T: Traverse + Send + Sync + ?Sized>(
511+
unsafe fn destroy_erased<T: Traverse + Send + Sync>(
513512
ptr: Erased,
514513
graph: &HashMap<AllocationId, AllocationInfo>,
515514
) {
@@ -523,7 +522,7 @@ unsafe fn destroy_erased<T: Traverse + Send + Sync + ?Sized>(
523522
impl Visitor for PrepareForDestruction<'_> {
524523
fn visit_sync<T>(&mut self, gc: &PyObjectRef)
525524
where
526-
T: Traverse + Send + Sync + ?Sized,
525+
T: Traverse + Send + Sync + Sized,
527526
{
528527
let id = AllocationId::from(unsafe {
529528
// SAFETY: This is the same as dereferencing the GC.
@@ -538,16 +537,15 @@ unsafe fn destroy_erased<T: Traverse + Send + Sync + ?Sized>(
538537
unsafe {
539538
// SAFETY: The GC is unreachable,
540539
// so the GC will never be dereferenced again.
541-
gc.ptr.get().write((*gc.ptr.get()).as_null());
540+
gc.ptr.write((*gc.ptr.read()).as_null());
542541
}
543542
}
544543
}
545544
}
546545

547546
let specified = ptr.specify::<PyInner<T>>().as_mut();
548547
specified
549-
.value
550-
.accept(&mut PrepareForDestruction { graph })
548+
.traverse(&mut PrepareForDestruction { graph })
551549
.expect("allocation assumed to be unreachable but somehow was accessed");
552550
let layout = Layout::for_value(specified);
553551
drop_in_place(specified);
@@ -560,7 +558,7 @@ unsafe fn destroy_erased<T: Traverse + Send + Sync + ?Sized>(
560558
/// # Safety
561559
///
562560
/// `ptr` must have been created as a pointer to a `PyInner<T>`.
563-
unsafe fn drop_weak_zero<T: Traverse + Send + Sync + ?Sized>(ptr: Erased) {
561+
unsafe fn drop_weak_zero<T: Traverse + Send + Sync>(ptr: Erased) {
564562
let mut specified = ptr.specify::<PyInner<T>>();
565563
assert_eq!(specified.as_ref().ref_count.weak.load(Ordering::Relaxed), 0);
566564
assert_eq!(specified.as_ref().ref_count.strong.load(Ordering::Relaxed), 0);
@@ -575,7 +573,7 @@ unsafe impl Sync for AllocationId {}
575573

576574
impl<T> From<&PyInner<T>> for AllocationId
577575
where
578-
T: Traverse + Send + Sync + ?Sized,
576+
T: Traverse + Send + Sync,
579577
{
580578
fn from(value: &PyInner<T>) -> Self {
581579
AllocationId(NonNull::from(value).cast())
@@ -584,7 +582,7 @@ where
584582

585583
impl<T> From<NonNull<PyInner<T>>> for AllocationId
586584
where
587-
T: Traverse + Send + Sync + ?Sized,
585+
T: Traverse + Send + Sync,
588586
{
589587
fn from(value: NonNull<PyInner<T>>) -> Self {
590588
AllocationId(value.cast())

vm/src/object/gc/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@ pub(crate) use dumpster::{default_collect_condition, CollectCondition, CollectIn
88
pub(crate) use visitor::Visitor;
99

1010
pub fn try_gc() {
11-
// TODO: Finish
11+
// TODO: conditionally collect
1212
dumpster::collect();
1313
}

vm/src/object/gc/visitor.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
use crate::object::Traverse;
1+
use crate::object::{Traverse, TraverseFn};
22
use crate::PyObjectRef;
33

44
pub trait Visitor {
55
/// Visit a synchronized garbage-collected pointer.
66
fn visit_sync<T>(&mut self, gc: &PyObjectRef)
77
where
8-
T: Traverse + Send + Sync + ?Sized;
8+
T: Traverse + Send + Sync + Sized;
99
}

vm/src/object/traverse.rs

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::ptr::NonNull;
33
use rustpython_common::lock::{PyMutex, PyRwLock};
44

55
use crate::{function::Either, object::PyObjectPayload, AsObject, PyObject, PyObjectRef, PyRef};
6+
use crate::object::gc::Visitor;
67

78
pub type TraverseFn<'a> = dyn FnMut(&PyObject) + 'a;
89

@@ -28,6 +29,7 @@ pub unsafe trait Traverse {
2829
///
2930
/// - _**DO NOT**_ clone a `PyObjectRef` or `PyRef<T>` in `traverse()`
3031
fn traverse(&self, traverse_fn: &mut TraverseFn);
32+
fn accept<V: Visitor>(&self, visitor: &mut V) -> Result<(), ()>;
3133
}
3234

3335
unsafe impl Traverse for PyObjectRef {
@@ -53,6 +55,15 @@ unsafe impl<T: Traverse> Traverse for Option<T> {
5355
v.traverse(traverse_fn);
5456
}
5557
}
58+
59+
#[inline]
60+
fn accept<V: Visitor>(&self, visitor: &mut V) -> Result<(), ()> {
61+
if let Some(v) = self {
62+
v.accept(visitor)
63+
} else {
64+
Ok(())
65+
}
66+
}
5667
}
5768

5869
unsafe impl<T> Traverse for [T]
@@ -65,6 +76,14 @@ where
6576
elem.traverse(traverse_fn);
6677
}
6778
}
79+
80+
#[inline]
81+
fn accept<V: Visitor>(&self, visitor: &mut V) -> Result<(), ()> {
82+
for elem in self {
83+
elem.accept(visitor)?;
84+
}
85+
Ok(())
86+
}
6887
}
6988

7089
unsafe impl<T> Traverse for Box<[T]>
@@ -77,6 +96,14 @@ where
7796
elem.traverse(traverse_fn);
7897
}
7998
}
99+
100+
#[inline]
101+
fn accept<V: Visitor>(&self, visitor: &mut V) -> Result<(), ()> {
102+
for elem in &**self {
103+
elem.accept(visitor)?;
104+
}
105+
Ok(())
106+
}
80107
}
81108

82109
unsafe impl<T> Traverse for Vec<T>
@@ -89,6 +116,14 @@ where
89116
elem.traverse(traverse_fn);
90117
}
91118
}
119+
120+
#[inline]
121+
fn accept<V: Visitor>(&self, visitor: &mut V) -> Result<(), ()> {
122+
for elem in self {
123+
elem.accept(visitor)?;
124+
}
125+
Ok(())
126+
}
92127
}
93128

94129
unsafe impl<T: Traverse> Traverse for PyRwLock<T> {
@@ -101,6 +136,15 @@ unsafe impl<T: Traverse> Traverse for PyRwLock<T> {
101136
inner.traverse(traverse_fn)
102137
}
103138
}
139+
140+
#[inline]
141+
fn accept<V: Visitor>(&self, visitor: &mut V) -> Result<(), ()> {
142+
if let Some(inner) = self.try_read_recursive() {
143+
inner.accept(visitor)
144+
} else {
145+
Ok(())
146+
}
147+
}
104148
}
105149

106150
/// Safety: We can't hold lock during traverse it's child because it may cause deadlock.
@@ -124,6 +168,24 @@ unsafe impl<T: Traverse> Traverse for PyMutex<T> {
124168
})
125169
.count();
126170
}
171+
172+
#[inline]
173+
fn accept<V: Visitor>(&self, visitor: &mut V) -> Result<(), ()> {
174+
let mut chs: Vec<NonNull<PyObject>> = Vec::new();
175+
if let Some(obj) = self.try_lock() {
176+
obj.accept(&mut |ch| {
177+
chs.push(NonNull::from(ch));
178+
})?;
179+
}
180+
chs.iter()
181+
.map(|ch| {
182+
// Safety: during gc, this should be fine, because nothing should write during gc's tracing?
183+
let ch = unsafe { ch.as_ref() };
184+
visitor.visit_sync::<PyObject>(&ch.to_owned());
185+
})
186+
.count();
187+
Ok(())
188+
}
127189
}
128190

129191
macro_rules! trace_tuple {
@@ -135,6 +197,12 @@ macro_rules! trace_tuple {
135197
self.$NUM.traverse(traverse_fn);
136198
)*
137199
}
200+
fn accept<V: Visitor>(&self, visitor: &mut V) -> Result<(), ()> {
201+
$(
202+
self.$NUM.accept(visitor)?;
203+
)*
204+
Ok(())
205+
}
138206
}
139207

140208
};
@@ -148,6 +216,14 @@ unsafe impl<A: Traverse, B: Traverse> Traverse for Either<A, B> {
148216
Either::B(b) => b.traverse(tracer_fn),
149217
}
150218
}
219+
220+
#[inline]
221+
fn accept<V: Visitor>(&self, visitor: &mut V) -> Result<(), ()> {
222+
match self {
223+
Either::A(a) => a.accept(visitor),
224+
Either::B(b) => b.accept(visitor),
225+
}
226+
}
151227
}
152228

153229
// only tuple with 12 elements or less is supported,
@@ -157,6 +233,10 @@ unsafe impl<A: Traverse> Traverse for (A,) {
157233
fn traverse(&self, tracer_fn: &mut TraverseFn) {
158234
self.0.traverse(tracer_fn);
159235
}
236+
237+
fn accept<V: Visitor>(&self, visitor: &mut V) -> Result<(), ()> {
238+
self.0.accept(visitor)
239+
}
160240
}
161241
trace_tuple!((A, 0), (B, 1));
162242
trace_tuple!((A, 0), (B, 1), (C, 2));

0 commit comments

Comments
 (0)