diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 3d101c62e12..0a3e57d34c6 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -2281,7 +2281,6 @@ class Ints(enum.IntEnum): self.assertEqual(Union[Literal[1], Literal[Ints.B], Literal[True]].__args__, (Literal[1], Literal[Ints.B], Literal[True])) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: types.UnionType[int, str] | float != types.UnionType[int, str, float] def test_allow_non_types_in_or(self): # gh-140348: Test that using | with a Union object allows things that are # not allowed by is_unionable(). diff --git a/crates/vm/src/builtins/type.rs b/crates/vm/src/builtins/type.rs index 110e50c374e..fd46f9058c0 100644 --- a/crates/vm/src/builtins/type.rs +++ b/crates/vm/src/builtins/type.rs @@ -2048,12 +2048,7 @@ pub(crate) fn call_slot_new( } pub(crate) fn or_(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - if !union_::is_unionable(zelf.clone(), vm) || !union_::is_unionable(other.clone(), vm) { - return Ok(vm.ctx.not_implemented()); - } - - let tuple = PyTuple::new_ref(vec![zelf, other], &vm.ctx); - union_::make_union(&tuple, vm) + union_::or_op(zelf, other, vm) } fn take_next_base(bases: &mut [Vec]) -> Option { diff --git a/crates/vm/src/builtins/union.rs b/crates/vm/src/builtins/union.rs index 9856235ecf4..907383639bd 100644 --- a/crates/vm/src/builtins/union.rs +++ b/crates/vm/src/builtins/union.rs @@ -8,7 +8,7 @@ use crate::{ convert::ToPyObject, function::PyComparisonValue, protocol::{PyMappingMethods, PyNumberMethods}, - stdlib::typing::TypeAliasType, + stdlib::typing::{TypeAliasType, call_typing_func_object}, types::{AsMapping, AsNumber, Comparable, GetAttr, Hashable, PyComparisonOp, Representable}, }; use alloc::fmt; @@ -193,7 +193,7 @@ impl PyUnion { } } -pub fn is_unionable(obj: PyObjectRef, vm: &VirtualMachine) -> bool { +fn is_unionable(obj: PyObjectRef, vm: &VirtualMachine) -> bool { let cls = obj.class(); cls.is(vm.ctx.types.none_type) || obj.downcastable::() @@ -202,6 +202,36 @@ pub fn is_unionable(obj: PyObjectRef, vm: &VirtualMachine) -> bool { || obj.downcast_ref::().is_some() } +fn type_check(arg: PyObjectRef, vm: &VirtualMachine) -> PyResult { + // Fast path to avoid calling into typing.py + if is_unionable(arg.clone(), vm) { + return Ok(arg); + } + let message_str: PyObjectRef = vm + .ctx + .new_str("Union[arg, ...]: each arg must be a type.") + .into(); + call_typing_func_object(vm, "_type_check", (arg, message_str)) +} + +fn has_union_operands(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> bool { + let union_type = vm.ctx.types.union_type; + a.class().is(union_type) || b.class().is(union_type) +} + +pub fn or_op(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + if !has_union_operands(zelf.clone(), other.clone(), vm) + && (!is_unionable(zelf.clone(), vm) || !is_unionable(other.clone(), vm)) + { + return Ok(vm.ctx.not_implemented()); + } + + let left = type_check(zelf, vm)?; + let right = type_check(other, vm)?; + let tuple = PyTuple::new_ref(vec![left, right], &vm.ctx); + make_union(&tuple, vm) +} + fn make_parameters(args: &Py, vm: &VirtualMachine) -> PyResult { let parameters = genericalias::make_parameters(args, vm); let result = dedup_and_flatten_args(¶meters, vm)?; diff --git a/crates/vm/src/stdlib/typevar.rs b/crates/vm/src/stdlib/typevar.rs index 36f2b170023..d1be1118a2e 100644 --- a/crates/vm/src/stdlib/typevar.rs +++ b/crates/vm/src/stdlib/typevar.rs @@ -6,30 +6,21 @@ pub use typevar::*; pub(crate) mod typevar { use crate::{ AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, - builtins::{PyTuple, PyTupleRef, PyType, PyTypeRef, make_union, pystr::AsPyStr}, + builtins::{PyTuple, PyTupleRef, PyType, PyTypeRef, make_union}, common::lock::PyMutex, - function::{FuncArgs, IntoFuncArgs, PyComparisonValue}, + function::{FuncArgs, PyComparisonValue}, protocol::PyNumberMethods, + stdlib::typing::call_typing_func_object, types::{AsNumber, Comparable, Constructor, Iterable, PyComparisonOp, Representable}, }; - pub(crate) fn _call_typing_func_object<'a>( - vm: &VirtualMachine, - func_name: impl AsPyStr<'a>, - args: impl IntoFuncArgs, - ) -> PyResult { - let module = vm.import("typing", 0)?; - let func = module.get_attr(func_name.as_pystr(&vm.ctx), vm)?; - func.call(args, vm) - } - fn type_check(arg: PyObjectRef, msg: &str, vm: &VirtualMachine) -> PyResult { // Calling typing.py here leads to bootstrapping problems if vm.is_none(&arg) { return Ok(arg.class().to_owned().into()); } let message_str: PyObjectRef = vm.ctx.new_str(msg).into(); - _call_typing_func_object(vm, "_type_check", (arg, message_str)) + call_typing_func_object(vm, "_type_check", (arg, message_str)) } /// Get the module of the caller frame, similar to CPython's caller() function. @@ -169,7 +160,7 @@ pub(crate) mod typevar { vm: &VirtualMachine, ) -> PyResult { let self_obj: PyObjectRef = zelf.into(); - _call_typing_func_object(vm, "_typevar_subst", (self_obj, arg)) + call_typing_func_object(vm, "_typevar_subst", (self_obj, arg)) } #[pymethod] @@ -514,7 +505,7 @@ pub(crate) mod typevar { vm: &VirtualMachine, ) -> PyResult { let self_obj: PyObjectRef = zelf.into(); - _call_typing_func_object(vm, "_paramspec_subst", (self_obj, arg)) + call_typing_func_object(vm, "_paramspec_subst", (self_obj, arg)) } #[pymethod] @@ -525,7 +516,7 @@ pub(crate) mod typevar { vm: &VirtualMachine, ) -> PyResult { let self_obj: PyObjectRef = zelf.into(); - _call_typing_func_object(vm, "_paramspec_prepare_subst", (self_obj, alias, args)) + call_typing_func_object(vm, "_paramspec_prepare_subst", (self_obj, alias, args)) } } @@ -711,7 +702,7 @@ pub(crate) mod typevar { vm: &VirtualMachine, ) -> PyResult { let self_obj: PyObjectRef = zelf.into(); - _call_typing_func_object(vm, "_typevartuple_prepare_subst", (self_obj, alias, args)) + call_typing_func_object(vm, "_typevartuple_prepare_subst", (self_obj, alias, args)) } } diff --git a/crates/vm/src/stdlib/typing.rs b/crates/vm/src/stdlib/typing.rs index 6938bca8bbb..94b014c62fa 100644 --- a/crates/vm/src/stdlib/typing.rs +++ b/crates/vm/src/stdlib/typing.rs @@ -1,5 +1,8 @@ // spell-checker:ignore typevarobject funcobj -use crate::{Context, class::PyClassImpl}; +use crate::{ + Context, PyResult, VirtualMachine, builtins::pystr::AsPyStr, class::PyClassImpl, + function::IntoFuncArgs, +}; pub use crate::stdlib::typevar::{ Generic, ParamSpec, ParamSpecArgs, ParamSpecKwargs, TypeVar, TypeVarTuple, @@ -13,26 +16,26 @@ pub fn init(ctx: &Context) { NoDefault::extend_class(ctx, ctx.types.typing_no_default_type); } +pub fn call_typing_func_object<'a>( + vm: &VirtualMachine, + func_name: impl AsPyStr<'a>, + args: impl IntoFuncArgs, +) -> PyResult { + let module = vm.import("typing", 0)?; + let func = module.get_attr(func_name.as_pystr(&vm.ctx), vm)?; + func.call(args, vm) +} + #[pymodule(name = "_typing", with(super::typevar::typevar))] pub(crate) mod decl { use crate::{ Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, - builtins::{PyStrRef, PyTupleRef, PyType, PyTypeRef, pystr::AsPyStr, type_}, - function::{FuncArgs, IntoFuncArgs}, + builtins::{PyStrRef, PyTupleRef, PyType, PyTypeRef, type_}, + function::FuncArgs, protocol::PyNumberMethods, types::{AsNumber, Constructor, Representable}, }; - pub(crate) fn _call_typing_func_object<'a>( - vm: &VirtualMachine, - func_name: impl AsPyStr<'a>, - args: impl IntoFuncArgs, - ) -> PyResult { - let module = vm.import("typing", 0)?; - let func = module.get_attr(func_name.as_pystr(&vm.ctx), vm)?; - func.call(args, vm) - } - #[pyfunction] pub(crate) fn _idfunc(args: FuncArgs, _vm: &VirtualMachine) -> PyObjectRef { args.args[0].clone()