Skip to content

Commit 3e83c21

Browse files
committed
Implement Windows SemLock in _multiprocessing module
Add SemLock class using Windows semaphore APIs (CreateSemaphoreW, WaitForSingleObjectEx, ReleaseSemaphore) so test_multiprocessing suites are no longer skipped with "lacks a functioning sem_open". Also add sem_unlink as no-op and flags dict for Windows.
1 parent 66f97c9 commit 3e83c21

1 file changed

Lines changed: 365 additions & 1 deletion

File tree

crates/stdlib/src/multiprocessing.rs

Lines changed: 365 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,372 @@ pub(crate) use _multiprocessing::module_def;
33
#[cfg(windows)]
44
#[pymodule]
55
mod _multiprocessing {
6-
use crate::vm::{PyResult, VirtualMachine, function::ArgBytesLike};
6+
use crate::vm::{
7+
Context, FromArgs, Py, PyPayload, PyRef, PyResult, VirtualMachine,
8+
builtins::{PyDict, PyType, PyTypeRef},
9+
function::{ArgBytesLike, FuncArgs, KwArgs},
10+
types::Constructor,
11+
};
12+
use core::sync::atomic::{AtomicI32, AtomicU32, Ordering};
13+
use windows_sys::Win32::Foundation::{
14+
CloseHandle, HANDLE, INVALID_HANDLE_VALUE, WAIT_EVENT, WAIT_OBJECT_0,
15+
};
716
use windows_sys::Win32::Networking::WinSock::{self, SOCKET};
17+
use windows_sys::Win32::System::Threading::{
18+
CreateSemaphoreW, GetCurrentThreadId, ReleaseSemaphore, WaitForSingleObjectEx,
19+
};
20+
21+
const INFINITE: u32 = 0xFFFFFFFF;
22+
const WAIT_TIMEOUT: WAIT_EVENT = 258; // 0x102
23+
const WAIT_FAILED: WAIT_EVENT = 0xFFFFFFFF;
24+
const ERROR_TOO_MANY_POSTS: u32 = 298;
25+
26+
// These match the values in Lib/multiprocessing/synchronize.py
27+
const RECURSIVE_MUTEX: i32 = 0;
28+
const SEMAPHORE: i32 = 1;
29+
30+
macro_rules! ismine {
31+
($self:expr) => {
32+
$self.count.load(Ordering::Acquire) > 0
33+
&& $self.last_tid.load(Ordering::Acquire) == unsafe { GetCurrentThreadId() }
34+
};
35+
}
36+
37+
#[derive(FromArgs)]
38+
struct SemLockNewArgs {
39+
#[pyarg(positional)]
40+
kind: i32,
41+
#[pyarg(positional)]
42+
value: i32,
43+
#[pyarg(positional)]
44+
maxvalue: i32,
45+
#[pyarg(positional)]
46+
name: String,
47+
#[pyarg(positional)]
48+
unlink: bool,
49+
}
50+
51+
#[pyattr]
52+
#[pyclass(name = "SemLock", module = "_multiprocessing")]
53+
#[derive(Debug, PyPayload)]
54+
struct SemLock {
55+
handle: SemHandle,
56+
kind: i32,
57+
maxvalue: i32,
58+
name: Option<String>,
59+
last_tid: AtomicU32,
60+
count: AtomicI32,
61+
}
62+
63+
#[derive(Debug)]
64+
struct SemHandle {
65+
raw: HANDLE,
66+
}
67+
68+
unsafe impl Send for SemHandle {}
69+
unsafe impl Sync for SemHandle {}
70+
71+
impl SemHandle {
72+
fn create(value: i32, maxvalue: i32, vm: &VirtualMachine) -> PyResult<Self> {
73+
let handle = unsafe {
74+
CreateSemaphoreW(
75+
core::ptr::null(),
76+
value,
77+
maxvalue,
78+
core::ptr::null(),
79+
)
80+
};
81+
if handle == 0 as HANDLE {
82+
return Err(vm.new_last_os_error());
83+
}
84+
// Check ERROR_ALREADY_EXISTS
85+
let last_err = unsafe { windows_sys::Win32::Foundation::GetLastError() };
86+
if last_err != 0 {
87+
unsafe { CloseHandle(handle) };
88+
return Err(vm.new_last_os_error());
89+
}
90+
Ok(SemHandle { raw: handle })
91+
}
92+
93+
#[inline]
94+
fn as_raw(&self) -> HANDLE {
95+
self.raw
96+
}
97+
}
98+
99+
impl Drop for SemHandle {
100+
fn drop(&mut self) {
101+
if self.raw != 0 as HANDLE && self.raw != INVALID_HANDLE_VALUE {
102+
unsafe {
103+
CloseHandle(self.raw);
104+
}
105+
}
106+
}
107+
}
108+
109+
/// _GetSemaphoreValue - get value of semaphore by briefly acquiring and releasing
110+
fn get_semaphore_value(handle: HANDLE) -> Result<i32, ()> {
111+
match unsafe { WaitForSingleObjectEx(handle, 0, 0) } {
112+
WAIT_OBJECT_0 => {
113+
let mut previous: i32 = 0;
114+
if unsafe { ReleaseSemaphore(handle, 1, &mut previous) } == 0 {
115+
return Err(());
116+
}
117+
Ok(previous + 1)
118+
}
119+
WAIT_TIMEOUT => Ok(0),
120+
_ => Err(()),
121+
}
122+
}
123+
124+
#[pyclass(with(Constructor), flags(BASETYPE))]
125+
impl SemLock {
126+
#[pygetset]
127+
fn handle(&self) -> isize {
128+
self.handle.as_raw() as isize
129+
}
130+
131+
#[pygetset]
132+
fn kind(&self) -> i32 {
133+
self.kind
134+
}
135+
136+
#[pygetset]
137+
fn maxvalue(&self) -> i32 {
138+
self.maxvalue
139+
}
140+
141+
#[pygetset]
142+
fn name(&self) -> Option<String> {
143+
self.name.clone()
144+
}
145+
146+
#[pymethod]
147+
fn acquire(&self, args: FuncArgs, vm: &VirtualMachine) -> PyResult<bool> {
148+
let blocking: bool = args
149+
.kwargs
150+
.get("block")
151+
.or_else(|| args.args.first())
152+
.map(|o| o.clone().try_to_bool(vm))
153+
.transpose()?
154+
.unwrap_or(true);
155+
156+
let timeout_obj = args
157+
.kwargs
158+
.get("timeout")
159+
.or_else(|| args.args.get(1))
160+
.cloned();
161+
162+
// Calculate timeout in milliseconds
163+
let full_msecs: u32 = if !blocking {
164+
0
165+
} else if timeout_obj.as_ref().is_none_or(|o| vm.is_none(o)) {
166+
INFINITE
167+
} else {
168+
let timeout: f64 = timeout_obj.unwrap().try_float(vm)?.to_f64();
169+
let timeout = timeout * 1000.0; // convert to ms
170+
if timeout < 0.0 {
171+
0
172+
} else if timeout >= 0.5 * INFINITE as f64 {
173+
return Err(
174+
vm.new_overflow_error("timeout is too large".to_owned())
175+
);
176+
} else {
177+
(timeout + 0.5) as u32
178+
}
179+
};
180+
181+
// Check whether we already own the lock
182+
if self.kind == RECURSIVE_MUTEX && ismine!(self) {
183+
self.count.fetch_add(1, Ordering::Release);
184+
return Ok(true);
185+
}
186+
187+
// Check whether we can acquire without blocking
188+
if unsafe { WaitForSingleObjectEx(self.handle.as_raw(), 0, 0) }
189+
== WAIT_OBJECT_0
190+
{
191+
self.last_tid
192+
.store(unsafe { GetCurrentThreadId() }, Ordering::Release);
193+
self.count.fetch_add(1, Ordering::Release);
194+
return Ok(true);
195+
}
196+
197+
// Do the wait
198+
let res =
199+
unsafe { WaitForSingleObjectEx(self.handle.as_raw(), full_msecs, 0) };
200+
201+
match res {
202+
WAIT_TIMEOUT => Ok(false),
203+
WAIT_OBJECT_0 => {
204+
self.last_tid
205+
.store(unsafe { GetCurrentThreadId() }, Ordering::Release);
206+
self.count.fetch_add(1, Ordering::Release);
207+
Ok(true)
208+
}
209+
WAIT_FAILED => Err(vm.new_last_os_error()),
210+
_ => Err(vm.new_runtime_error(format!(
211+
"WaitForSingleObject() gave unrecognized value {res}"
212+
))),
213+
}
214+
}
215+
216+
#[pymethod]
217+
fn release(&self, vm: &VirtualMachine) -> PyResult<()> {
218+
if self.kind == RECURSIVE_MUTEX {
219+
if !ismine!(self) {
220+
return Err(vm.new_exception_msg(
221+
vm.ctx.exceptions.assertion_error.to_owned(),
222+
"attempt to release recursive lock not owned by thread".to_owned(),
223+
));
224+
}
225+
if self.count.load(Ordering::Acquire) > 1 {
226+
self.count.fetch_sub(1, Ordering::Release);
227+
return Ok(());
228+
}
229+
}
230+
231+
if unsafe { ReleaseSemaphore(self.handle.as_raw(), 1, core::ptr::null_mut()) }
232+
== 0
233+
{
234+
let err = unsafe { windows_sys::Win32::Foundation::GetLastError() };
235+
if err == ERROR_TOO_MANY_POSTS {
236+
return Err(vm.new_value_error(
237+
"semaphore or lock released too many times".to_owned(),
238+
));
239+
}
240+
return Err(vm.new_last_os_error());
241+
}
242+
243+
self.count.fetch_sub(1, Ordering::Release);
244+
Ok(())
245+
}
246+
247+
#[pymethod(name = "__enter__")]
248+
fn enter(&self, vm: &VirtualMachine) -> PyResult<bool> {
249+
self.acquire(
250+
FuncArgs::new::<Vec<_>, KwArgs>(
251+
vec![vm.ctx.new_bool(true).into()],
252+
KwArgs::default(),
253+
),
254+
vm,
255+
)
256+
}
257+
258+
#[pymethod]
259+
fn __exit__(&self, _args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> {
260+
self.release(vm)
261+
}
262+
263+
#[pyclassmethod(name = "_rebuild")]
264+
fn rebuild(
265+
cls: PyTypeRef,
266+
handle: isize,
267+
kind: i32,
268+
maxvalue: i32,
269+
name: Option<String>,
270+
vm: &VirtualMachine,
271+
) -> PyResult {
272+
// On Windows, _rebuild receives the handle directly (no sem_open)
273+
let zelf = SemLock {
274+
handle: SemHandle {
275+
raw: handle as HANDLE,
276+
},
277+
kind,
278+
maxvalue,
279+
name,
280+
last_tid: AtomicU32::new(0),
281+
count: AtomicI32::new(0),
282+
};
283+
zelf.into_ref_with_type(vm, cls).map(Into::into)
284+
}
285+
286+
#[pymethod]
287+
fn _after_fork(&self) {
288+
self.count.store(0, Ordering::Release);
289+
self.last_tid.store(0, Ordering::Release);
290+
}
291+
292+
#[pymethod]
293+
fn __reduce__(&self, vm: &VirtualMachine) -> PyResult {
294+
Err(vm.new_type_error("cannot pickle 'SemLock' object".to_owned()))
295+
}
296+
297+
#[pymethod]
298+
fn _count(&self) -> i32 {
299+
self.count.load(Ordering::Acquire)
300+
}
301+
302+
#[pymethod]
303+
fn _is_mine(&self) -> bool {
304+
ismine!(self)
305+
}
306+
307+
#[pymethod]
308+
fn _get_value(&self, vm: &VirtualMachine) -> PyResult<i32> {
309+
get_semaphore_value(self.handle.as_raw())
310+
.map_err(|_| vm.new_last_os_error())
311+
}
312+
313+
#[pymethod]
314+
fn _is_zero(&self, vm: &VirtualMachine) -> PyResult<bool> {
315+
let val = get_semaphore_value(self.handle.as_raw())
316+
.map_err(|_| vm.new_last_os_error())?;
317+
Ok(val == 0)
318+
}
319+
320+
#[extend_class]
321+
fn extend_class(ctx: &Context, class: &Py<PyType>) {
322+
class.set_attr(
323+
ctx.intern_str("RECURSIVE_MUTEX"),
324+
ctx.new_int(RECURSIVE_MUTEX).into(),
325+
);
326+
class.set_attr(ctx.intern_str("SEMAPHORE"), ctx.new_int(SEMAPHORE).into());
327+
class.set_attr(
328+
ctx.intern_str("SEM_VALUE_MAX"),
329+
ctx.new_int(i32::MAX).into(),
330+
);
331+
}
332+
}
333+
334+
impl Constructor for SemLock {
335+
type Args = SemLockNewArgs;
336+
337+
fn py_new(_cls: &Py<PyType>, args: Self::Args, vm: &VirtualMachine) -> PyResult<Self> {
338+
if args.kind != RECURSIVE_MUTEX && args.kind != SEMAPHORE {
339+
return Err(vm.new_value_error("unrecognized kind".to_owned()));
340+
}
341+
if args.value < 0 || args.value > args.maxvalue {
342+
return Err(vm.new_value_error("invalid value".to_owned()));
343+
}
344+
345+
let handle = SemHandle::create(args.value, args.maxvalue, vm)?;
346+
let name = if args.unlink {
347+
None
348+
} else {
349+
Some(args.name)
350+
};
351+
352+
Ok(SemLock {
353+
handle,
354+
kind: args.kind,
355+
maxvalue: args.maxvalue,
356+
name,
357+
last_tid: AtomicU32::new(0),
358+
count: AtomicI32::new(0),
359+
})
360+
}
361+
}
362+
363+
// On Windows, sem_unlink is a no-op
364+
#[pyfunction]
365+
fn sem_unlink(_name: String) {}
366+
367+
#[pyattr]
368+
fn flags(vm: &VirtualMachine) -> PyRef<PyDict> {
369+
// On Windows, no HAVE_SEM_OPEN / HAVE_SEM_TIMEDWAIT / HAVE_BROKEN_SEM_GETVALUE
370+
vm.ctx.new_dict()
371+
}
8372

9373
#[pyfunction]
10374
fn closesocket(socket: usize, vm: &VirtualMachine) -> PyResult<()> {

0 commit comments

Comments
 (0)