diff --git a/Lib/test/test_ast/test_ast.py b/Lib/test/test_ast/test_ast.py index f59090dec7c..6c592b6d706 100644 --- a/Lib/test/test_ast/test_ast.py +++ b/Lib/test/test_ast/test_ast.py @@ -1480,7 +1480,6 @@ def test_parse_in_error(self): ast.literal_eval(r"'\U'") self.assertIsNotNone(e.exception.__context__) - @unittest.expectedFailure # TODO: RUSTPYTHON; + Module(body=[Expr(value=Call(func=Name(id='spam', ctx=Load()), args=[Name(id='eggs', ctx=Load()), Constant(value='and cheese')]))]) def test_dump(self): node = ast.parse('spam(eggs, "and cheese")') self.assertEqual(ast.dump(node), @@ -1501,7 +1500,6 @@ def test_dump(self): "lineno=1, col_offset=0, end_lineno=1, end_col_offset=24)])" ) - @unittest.expectedFailure # TODO: RUSTPYTHON; - type_ignores=[]) def test_dump_indent(self): node = ast.parse('spam(eggs, "and cheese")') self.assertEqual(ast.dump(node, indent=3), """\ @@ -1557,7 +1555,6 @@ def test_dump_indent(self): end_lineno=1, end_col_offset=24)])""") - @unittest.expectedFailure # TODO: RUSTPYTHON; + Raise() def test_dump_incomplete(self): node = ast.Raise(lineno=3, col_offset=4) self.assertEqual(ast.dump(node), @@ -1622,7 +1619,6 @@ def test_dump_incomplete(self): "ClassDef('T', [], [keyword('a', Constant(None))], [], [Name('dataclass', Load())])", ) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_dump_show_empty(self): def check_node(node, empty, full, **kwargs): with self.subTest(show_empty=False): @@ -1743,7 +1739,6 @@ def test_copy_location(self): self.assertEqual(new.lineno, 1) self.assertEqual(new.col_offset, 1) - @unittest.expectedFailure # TODO: RUSTPYTHON; + Module(body=[Expr(value=Call(func=Name(id='write', ctx=Load(), lineno=1, col_offset=0, end_lineno=1, end_col_offset=5), args=[Constant(value='spam', lineno=1, col_offset=6, end_lineno=1, end_col_offset=12)], lineno=1, col_offset=0, end_lineno=1, end_col_offset=13), lineno=1, col_offset=0, end_lineno=1, end_col_offset=13), Expr(value=Call(func=Name(id='spam', ctx=Load(), lineno=1, col_offset=0, end_lineno=1, end_col_offset=0), args=[Constant(value='eggs', lineno=1, col_offset=0, end_lineno=1, end_col_offset=0)], lineno=1, col_offset=0, end_lineno=1, end_col_offset=0), lineno=1, col_offset=0, end_lineno=1, end_col_offset=0)]) def test_fix_missing_locations(self): src = ast.parse('write("spam")') src.body.append(ast.Expr(ast.Call(ast.Name('spam', ast.Load()), @@ -1807,7 +1802,6 @@ def test_iter_fields(self): self.assertEqual(d.pop('func').id, 'foo') self.assertEqual(d, {'keywords': [], 'args': []}) - @unittest.expectedFailure # TODO: RUSTPYTHON; + keyword(arg='eggs', value=Constant(value='leek')) def test_iter_child_nodes(self): node = ast.parse("spam(23, 42, eggs='leek')", mode='eval') self.assertEqual(len(list(ast.iter_child_nodes(node.body))), 4) @@ -3149,7 +3143,6 @@ def assertASTTransformation(self, transformer_class, self.assertASTEqual(result_ast, expected_ast) - @unittest.expectedFailure # TODO: RUSTPYTHON; is not def test_node_remove_single(self): code = 'def func(arg) -> SomeType: ...' expected = 'def func(arg): ...' @@ -3376,7 +3369,6 @@ class BadFields(ast.AST): with self.assertWarnsRegex(DeprecationWarning, r"Field b'\\xff\\xff.*' .*"): obj = BadFields() - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: None != [] def test_complete_field_types(self): class _AllFieldTypes(ast.AST): _fields = ('a', 'b') @@ -3569,7 +3561,6 @@ def test_single_mode_flag(self): with self.subTest(flag=flag): self.check_output(source, expect, flag) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_eval_mode_flag(self): # test 'python -m ast -m/--mode eval' source = 'print(1, 2, 3)' @@ -3604,7 +3595,6 @@ def test_func_type_mode_flag(self): with self.subTest(flag=flag): self.check_output(source, expect, flag) - @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: type object '_ast.Module' has no attribute '_field_types' def test_no_type_comments_flag(self): # test 'python -m ast --no-type-comments' source = 'x: bool = 1 # type: ignore[assignment]' @@ -3619,7 +3609,6 @@ def test_no_type_comments_flag(self): ''' self.check_output(source, expect, '--no-type-comments') - @unittest.expectedFailure # TODO: RUSTPYTHON def test_include_attributes_flag(self): # test 'python -m ast -a/--include-attributes' source = 'pass' @@ -3636,7 +3625,6 @@ def test_include_attributes_flag(self): with self.subTest(flag=flag): self.check_output(source, expect, flag) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_indent_flag(self): # test 'python -m ast -i/--indent 0' source = 'pass' @@ -3673,7 +3661,6 @@ def test_feature_version_flag(self): with self.assertRaises(SyntaxError): self.invoke_ast('--feature-version=3.9') - @unittest.expectedFailure # TODO: RUSTPYTHON def test_no_optimize_flag(self): # test 'python -m ast -O/--optimize -1/0' source = ''' @@ -3724,7 +3711,6 @@ def test_optimize_flag(self): with self.subTest(flag=flag): self.check_output(source, expect, flag) - @unittest.expectedFailure # TODO: RUSTPYTHON; type_ignores=[]) def test_show_empty_flag(self): # test 'python -m ast --show-empty' source = 'print(1, 2, 3)' diff --git a/crates/vm/src/stdlib/ast/pyast.rs b/crates/vm/src/stdlib/ast/pyast.rs index d6f995f6f72..a32385a3e87 100644 --- a/crates/vm/src/stdlib/ast/pyast.rs +++ b/crates/vm/src/stdlib/ast/pyast.rs @@ -1,7 +1,9 @@ #![allow(clippy::all)] use super::*; +use crate::builtins::{PyGenericAlias, PyTuple, PyTypeRef, make_union}; use crate::common::ascii; +use crate::convert::ToPyObject; use crate::function::FuncArgs; use crate::types::Initializer; @@ -926,6 +928,529 @@ impl_node!( attributes: ["lineno", "col_offset", "end_lineno", "end_col_offset"], ); +/// Marker for how to resolve an ASDL field type into a Python type object. +#[derive(Clone, Copy)] +enum FieldType { + /// AST node type reference (e.g. "expr", "stmt") + Node(&'static str), + /// Built-in type reference (e.g. "str", "int", "object") + Builtin(&'static str), + /// list[NodeType] — Py_GenericAlias(list, node_type) + ListOf(&'static str), + /// list[BuiltinType] — Py_GenericAlias(list, builtin_type) + ListOfBuiltin(&'static str), + /// NodeType | None — Union[node_type, None] + Optional(&'static str), + /// BuiltinType | None — Union[builtin_type, None] + OptionalBuiltin(&'static str), +} + +/// Field type annotations for all concrete AST node classes. +/// Derived from add_ast_annotations() in Python-ast.c. +const FIELD_TYPES: &[(&str, &[(&str, FieldType)])] = &[ + // -- mod -- + ( + "Module", + &[ + ("body", FieldType::ListOf("stmt")), + ("type_ignores", FieldType::ListOf("type_ignore")), + ], + ), + ("Interactive", &[("body", FieldType::ListOf("stmt"))]), + ("Expression", &[("body", FieldType::Node("expr"))]), + ( + "FunctionType", + &[ + ("argtypes", FieldType::ListOf("expr")), + ("returns", FieldType::Node("expr")), + ], + ), + // -- stmt -- + ( + "FunctionDef", + &[ + ("name", FieldType::Builtin("str")), + ("args", FieldType::Node("arguments")), + ("body", FieldType::ListOf("stmt")), + ("decorator_list", FieldType::ListOf("expr")), + ("returns", FieldType::Optional("expr")), + ("type_comment", FieldType::OptionalBuiltin("str")), + ("type_params", FieldType::ListOf("type_param")), + ], + ), + ( + "AsyncFunctionDef", + &[ + ("name", FieldType::Builtin("str")), + ("args", FieldType::Node("arguments")), + ("body", FieldType::ListOf("stmt")), + ("decorator_list", FieldType::ListOf("expr")), + ("returns", FieldType::Optional("expr")), + ("type_comment", FieldType::OptionalBuiltin("str")), + ("type_params", FieldType::ListOf("type_param")), + ], + ), + ( + "ClassDef", + &[ + ("name", FieldType::Builtin("str")), + ("bases", FieldType::ListOf("expr")), + ("keywords", FieldType::ListOf("keyword")), + ("body", FieldType::ListOf("stmt")), + ("decorator_list", FieldType::ListOf("expr")), + ("type_params", FieldType::ListOf("type_param")), + ], + ), + ("Return", &[("value", FieldType::Optional("expr"))]), + ("Delete", &[("targets", FieldType::ListOf("expr"))]), + ( + "Assign", + &[ + ("targets", FieldType::ListOf("expr")), + ("value", FieldType::Node("expr")), + ("type_comment", FieldType::OptionalBuiltin("str")), + ], + ), + ( + "TypeAlias", + &[ + ("name", FieldType::Node("expr")), + ("type_params", FieldType::ListOf("type_param")), + ("value", FieldType::Node("expr")), + ], + ), + ( + "AugAssign", + &[ + ("target", FieldType::Node("expr")), + ("op", FieldType::Node("operator")), + ("value", FieldType::Node("expr")), + ], + ), + ( + "AnnAssign", + &[ + ("target", FieldType::Node("expr")), + ("annotation", FieldType::Node("expr")), + ("value", FieldType::Optional("expr")), + ("simple", FieldType::Builtin("int")), + ], + ), + ( + "For", + &[ + ("target", FieldType::Node("expr")), + ("iter", FieldType::Node("expr")), + ("body", FieldType::ListOf("stmt")), + ("orelse", FieldType::ListOf("stmt")), + ("type_comment", FieldType::OptionalBuiltin("str")), + ], + ), + ( + "AsyncFor", + &[ + ("target", FieldType::Node("expr")), + ("iter", FieldType::Node("expr")), + ("body", FieldType::ListOf("stmt")), + ("orelse", FieldType::ListOf("stmt")), + ("type_comment", FieldType::OptionalBuiltin("str")), + ], + ), + ( + "While", + &[ + ("test", FieldType::Node("expr")), + ("body", FieldType::ListOf("stmt")), + ("orelse", FieldType::ListOf("stmt")), + ], + ), + ( + "If", + &[ + ("test", FieldType::Node("expr")), + ("body", FieldType::ListOf("stmt")), + ("orelse", FieldType::ListOf("stmt")), + ], + ), + ( + "With", + &[ + ("items", FieldType::ListOf("withitem")), + ("body", FieldType::ListOf("stmt")), + ("type_comment", FieldType::OptionalBuiltin("str")), + ], + ), + ( + "AsyncWith", + &[ + ("items", FieldType::ListOf("withitem")), + ("body", FieldType::ListOf("stmt")), + ("type_comment", FieldType::OptionalBuiltin("str")), + ], + ), + ( + "Match", + &[ + ("subject", FieldType::Node("expr")), + ("cases", FieldType::ListOf("match_case")), + ], + ), + ( + "Raise", + &[ + ("exc", FieldType::Optional("expr")), + ("cause", FieldType::Optional("expr")), + ], + ), + ( + "Try", + &[ + ("body", FieldType::ListOf("stmt")), + ("handlers", FieldType::ListOf("excepthandler")), + ("orelse", FieldType::ListOf("stmt")), + ("finalbody", FieldType::ListOf("stmt")), + ], + ), + ( + "TryStar", + &[ + ("body", FieldType::ListOf("stmt")), + ("handlers", FieldType::ListOf("excepthandler")), + ("orelse", FieldType::ListOf("stmt")), + ("finalbody", FieldType::ListOf("stmt")), + ], + ), + ( + "Assert", + &[ + ("test", FieldType::Node("expr")), + ("msg", FieldType::Optional("expr")), + ], + ), + ("Import", &[("names", FieldType::ListOf("alias"))]), + ( + "ImportFrom", + &[ + ("module", FieldType::OptionalBuiltin("str")), + ("names", FieldType::ListOf("alias")), + ("level", FieldType::OptionalBuiltin("int")), + ], + ), + ("Global", &[("names", FieldType::ListOfBuiltin("str"))]), + ("Nonlocal", &[("names", FieldType::ListOfBuiltin("str"))]), + ("Expr", &[("value", FieldType::Node("expr"))]), + // -- expr -- + ( + "BoolOp", + &[ + ("op", FieldType::Node("boolop")), + ("values", FieldType::ListOf("expr")), + ], + ), + ( + "NamedExpr", + &[ + ("target", FieldType::Node("expr")), + ("value", FieldType::Node("expr")), + ], + ), + ( + "BinOp", + &[ + ("left", FieldType::Node("expr")), + ("op", FieldType::Node("operator")), + ("right", FieldType::Node("expr")), + ], + ), + ( + "UnaryOp", + &[ + ("op", FieldType::Node("unaryop")), + ("operand", FieldType::Node("expr")), + ], + ), + ( + "Lambda", + &[ + ("args", FieldType::Node("arguments")), + ("body", FieldType::Node("expr")), + ], + ), + ( + "IfExp", + &[ + ("test", FieldType::Node("expr")), + ("body", FieldType::Node("expr")), + ("orelse", FieldType::Node("expr")), + ], + ), + ( + "Dict", + &[ + ("keys", FieldType::ListOf("expr")), + ("values", FieldType::ListOf("expr")), + ], + ), + ("Set", &[("elts", FieldType::ListOf("expr"))]), + ( + "ListComp", + &[ + ("elt", FieldType::Node("expr")), + ("generators", FieldType::ListOf("comprehension")), + ], + ), + ( + "SetComp", + &[ + ("elt", FieldType::Node("expr")), + ("generators", FieldType::ListOf("comprehension")), + ], + ), + ( + "DictComp", + &[ + ("key", FieldType::Node("expr")), + ("value", FieldType::Node("expr")), + ("generators", FieldType::ListOf("comprehension")), + ], + ), + ( + "GeneratorExp", + &[ + ("elt", FieldType::Node("expr")), + ("generators", FieldType::ListOf("comprehension")), + ], + ), + ("Await", &[("value", FieldType::Node("expr"))]), + ("Yield", &[("value", FieldType::Optional("expr"))]), + ("YieldFrom", &[("value", FieldType::Node("expr"))]), + ( + "Compare", + &[ + ("left", FieldType::Node("expr")), + ("ops", FieldType::ListOf("cmpop")), + ("comparators", FieldType::ListOf("expr")), + ], + ), + ( + "Call", + &[ + ("func", FieldType::Node("expr")), + ("args", FieldType::ListOf("expr")), + ("keywords", FieldType::ListOf("keyword")), + ], + ), + ( + "FormattedValue", + &[ + ("value", FieldType::Node("expr")), + ("conversion", FieldType::Builtin("int")), + ("format_spec", FieldType::Optional("expr")), + ], + ), + ("JoinedStr", &[("values", FieldType::ListOf("expr"))]), + ("TemplateStr", &[("values", FieldType::ListOf("expr"))]), + ( + "Interpolation", + &[ + ("value", FieldType::Node("expr")), + ("str", FieldType::Builtin("object")), + ("conversion", FieldType::Builtin("int")), + ("format_spec", FieldType::Optional("expr")), + ], + ), + ( + "Constant", + &[ + ("value", FieldType::Builtin("object")), + ("kind", FieldType::OptionalBuiltin("str")), + ], + ), + ( + "Attribute", + &[ + ("value", FieldType::Node("expr")), + ("attr", FieldType::Builtin("str")), + ("ctx", FieldType::Node("expr_context")), + ], + ), + ( + "Subscript", + &[ + ("value", FieldType::Node("expr")), + ("slice", FieldType::Node("expr")), + ("ctx", FieldType::Node("expr_context")), + ], + ), + ( + "Starred", + &[ + ("value", FieldType::Node("expr")), + ("ctx", FieldType::Node("expr_context")), + ], + ), + ( + "Name", + &[ + ("id", FieldType::Builtin("str")), + ("ctx", FieldType::Node("expr_context")), + ], + ), + ( + "List", + &[ + ("elts", FieldType::ListOf("expr")), + ("ctx", FieldType::Node("expr_context")), + ], + ), + ( + "Tuple", + &[ + ("elts", FieldType::ListOf("expr")), + ("ctx", FieldType::Node("expr_context")), + ], + ), + ( + "Slice", + &[ + ("lower", FieldType::Optional("expr")), + ("upper", FieldType::Optional("expr")), + ("step", FieldType::Optional("expr")), + ], + ), + // -- misc -- + ( + "comprehension", + &[ + ("target", FieldType::Node("expr")), + ("iter", FieldType::Node("expr")), + ("ifs", FieldType::ListOf("expr")), + ("is_async", FieldType::Builtin("int")), + ], + ), + ( + "ExceptHandler", + &[ + ("type", FieldType::Optional("expr")), + ("name", FieldType::OptionalBuiltin("str")), + ("body", FieldType::ListOf("stmt")), + ], + ), + ( + "arguments", + &[ + ("posonlyargs", FieldType::ListOf("arg")), + ("args", FieldType::ListOf("arg")), + ("vararg", FieldType::Optional("arg")), + ("kwonlyargs", FieldType::ListOf("arg")), + ("kw_defaults", FieldType::ListOf("expr")), + ("kwarg", FieldType::Optional("arg")), + ("defaults", FieldType::ListOf("expr")), + ], + ), + ( + "arg", + &[ + ("arg", FieldType::Builtin("str")), + ("annotation", FieldType::Optional("expr")), + ("type_comment", FieldType::OptionalBuiltin("str")), + ], + ), + ( + "keyword", + &[ + ("arg", FieldType::OptionalBuiltin("str")), + ("value", FieldType::Node("expr")), + ], + ), + ( + "alias", + &[ + ("name", FieldType::Builtin("str")), + ("asname", FieldType::OptionalBuiltin("str")), + ], + ), + ( + "withitem", + &[ + ("context_expr", FieldType::Node("expr")), + ("optional_vars", FieldType::Optional("expr")), + ], + ), + ( + "match_case", + &[ + ("pattern", FieldType::Node("pattern")), + ("guard", FieldType::Optional("expr")), + ("body", FieldType::ListOf("stmt")), + ], + ), + // -- pattern -- + ("MatchValue", &[("value", FieldType::Node("expr"))]), + ("MatchSingleton", &[("value", FieldType::Builtin("object"))]), + ( + "MatchSequence", + &[("patterns", FieldType::ListOf("pattern"))], + ), + ( + "MatchMapping", + &[ + ("keys", FieldType::ListOf("expr")), + ("patterns", FieldType::ListOf("pattern")), + ("rest", FieldType::OptionalBuiltin("str")), + ], + ), + ( + "MatchClass", + &[ + ("cls", FieldType::Node("expr")), + ("patterns", FieldType::ListOf("pattern")), + ("kwd_attrs", FieldType::ListOfBuiltin("str")), + ("kwd_patterns", FieldType::ListOf("pattern")), + ], + ), + ("MatchStar", &[("name", FieldType::OptionalBuiltin("str"))]), + ( + "MatchAs", + &[ + ("pattern", FieldType::Optional("pattern")), + ("name", FieldType::OptionalBuiltin("str")), + ], + ), + ("MatchOr", &[("patterns", FieldType::ListOf("pattern"))]), + // -- type_ignore -- + ( + "TypeIgnore", + &[ + ("lineno", FieldType::Builtin("int")), + ("tag", FieldType::Builtin("str")), + ], + ), + // -- type_param -- + ( + "TypeVar", + &[ + ("name", FieldType::Builtin("str")), + ("bound", FieldType::Optional("expr")), + ("default_value", FieldType::Optional("expr")), + ], + ), + ( + "ParamSpec", + &[ + ("name", FieldType::Builtin("str")), + ("default_value", FieldType::Optional("expr")), + ], + ), + ( + "TypeVarTuple", + &[ + ("name", FieldType::Builtin("str")), + ("default_value", FieldType::Optional("expr")), + ], + ), +]; + pub fn extend_module_nodes(vm: &VirtualMachine, module: &Py) { extend_module!(vm, module, { "mod" => NodeMod::make_class(&vm.ctx), @@ -1053,5 +1578,93 @@ pub fn extend_module_nodes(vm: &VirtualMachine, module: &Py) { "TypeVar" => NodeTypeParamTypeVar::make_class(&vm.ctx), "ParamSpec" => NodeTypeParamParamSpec::make_class(&vm.ctx), "TypeVarTuple" => NodeTypeParamTypeVarTuple::make_class(&vm.ctx), - }) + }); + + // Populate _field_types with real Python type objects + populate_field_types(vm, module); +} + +fn populate_field_types(vm: &VirtualMachine, module: &Py) { + let list_type: PyTypeRef = vm.ctx.types.list_type.to_owned(); + let none_type: PyObjectRef = vm.ctx.types.none_type.to_owned().into(); + + // Resolve a builtin type name to a Python type object + let resolve_builtin = |name: &str| -> PyObjectRef { + let ty: &Py = match name { + "str" => vm.ctx.types.str_type, + "int" => vm.ctx.types.int_type, + "object" => vm.ctx.types.object_type, + "bool" => vm.ctx.types.bool_type, + _ => unreachable!("unknown builtin type: {name}"), + }; + ty.to_owned().into() + }; + + // Resolve an AST node type name by looking it up from the module + let resolve_node = |name: &str| -> PyObjectRef { + module + .get_attr(vm.ctx.intern_str(name), vm) + .unwrap_or_else(|_| panic!("AST node type '{name}' not found in module")) + }; + + for &(class_name, fields) in FIELD_TYPES { + if fields.is_empty() { + continue; + } + + let class = module + .get_attr(class_name, vm) + .unwrap_or_else(|_| panic!("AST class '{class_name}' not found in module")); + let dict = vm.ctx.new_dict(); + + for &(field_name, ref field_type) in fields { + let type_obj = match field_type { + FieldType::Node(name) => resolve_node(name), + FieldType::Builtin(name) => resolve_builtin(name), + FieldType::ListOf(name) => { + let elem = resolve_node(name); + let args = PyTuple::new_ref(vec![elem], &vm.ctx); + PyGenericAlias::new(list_type.clone(), args, false, vm).to_pyobject(vm) + } + FieldType::ListOfBuiltin(name) => { + let elem = resolve_builtin(name); + let args = PyTuple::new_ref(vec![elem], &vm.ctx); + PyGenericAlias::new(list_type.clone(), args, false, vm).to_pyobject(vm) + } + FieldType::Optional(name) => { + let base = resolve_node(name); + let union_args = PyTuple::new_ref(vec![base, none_type.clone()], &vm.ctx); + make_union(&union_args, vm).expect("failed to create union type") + } + FieldType::OptionalBuiltin(name) => { + let base = resolve_builtin(name); + let union_args = PyTuple::new_ref(vec![base, none_type.clone()], &vm.ctx); + make_union(&union_args, vm).expect("failed to create union type") + } + }; + dict.set_item(vm.ctx.intern_str(field_name), type_obj, vm) + .expect("failed to set field type"); + } + + let dict_obj: PyObjectRef = dict.into(); + if let Some(type_obj) = class.downcast_ref::() { + type_obj.set_attr(vm.ctx.intern_str("_field_types"), dict_obj); + // NOTE: CPython also sets __annotations__ = _field_types, but + // RustPython AST types are not heap types so __annotations__ + // is not accessible via the type descriptor. + + // Set None as class-level default for optional fields. + // When ast_type_init skips optional fields, the instance + // inherits None from the class (init_types in Python-ast.c). + let none = vm.ctx.none(); + for &(field_name, ref field_type) in fields { + if matches!( + field_type, + FieldType::Optional(_) | FieldType::OptionalBuiltin(_) + ) { + type_obj.set_attr(vm.ctx.intern_str(field_name), none.clone()); + } + } + } + } } diff --git a/crates/vm/src/stdlib/ast/python.rs b/crates/vm/src/stdlib/ast/python.rs index 17062c99a0d..a2993ef1c10 100644 --- a/crates/vm/src/stdlib/ast/python.rs +++ b/crates/vm/src/stdlib/ast/python.rs @@ -5,6 +5,7 @@ pub(crate) mod _ast { use crate::{ AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::{PyStrRef, PyTupleRef, PyType, PyTypeRef}, + class::PyClassImpl, function::FuncArgs, types::{Constructor, Initializer}, }; @@ -87,73 +88,36 @@ pub(crate) mod _ast { zelf.set_attr(vm.ctx.intern_str(key), value, vm)?; } - // Set default values only for built-in AST nodes (_field_types present). - // Custom AST subclasses without _field_types do NOT get automatic defaults. - let has_field_types = zelf - .class() - .get_attr(vm.ctx.intern_str("_field_types")) - .is_some(); - if has_field_types { - // ASDL list fields (type*) default to empty list, - // optional/required fields default to None. - // Fields that are always list-typed regardless of node class. - const LIST_FIELDS: &[&str] = &[ - "argtypes", - "bases", - "cases", - "comparators", - "decorator_list", - "defaults", - "elts", - "finalbody", - "generators", - "handlers", - "ifs", - "items", - "keys", - "kw_defaults", - "kwd_attrs", - "kwd_patterns", - "keywords", - "kwonlyargs", - "names", - "ops", - "patterns", - "posonlyargs", - "targets", - "type_ignores", - "type_params", - "values", - ]; - - let class_name = zelf.class().name().to_string(); + // Use _field_types to determine defaults for unset fields. + // Only built-in AST node classes have _field_types populated. + let field_types = zelf.class().get_attr(vm.ctx.intern_str("_field_types")); + if let Some(Ok(ft_dict)) = + field_types.map(|ft| ft.downcast::()) + { + let expr_ctx_type: PyObjectRef = + super::super::pyast::NodeExprContext::make_class(&vm.ctx).into(); for field in &fields { - if !set_fields.contains(field.as_str()) { - let field_name = field.as_str(); - // Some field names have different ASDL types depending on the node. - // For example, "args" is `expr*` in Call but `arguments` in Lambda. - // "body" and "orelse" are `stmt*` in most nodes but `expr` in IfExp. - let is_list_field = if field_name == "args" { - class_name == "Call" || class_name == "arguments" - } else if field_name == "body" || field_name == "orelse" { - !matches!(class_name.as_str(), "Lambda" | "Expression" | "IfExp") - } else { - LIST_FIELDS.contains(&field_name) - }; - - let default: PyObjectRef = if is_list_field { - vm.ctx.new_list(vec![]).into() - } else { - vm.ctx.none() - }; - zelf.set_attr(vm.ctx.intern_str(field_name), default, vm)?; + if set_fields.contains(field.as_str()) { + continue; + } + if let Some(ftype) = ft_dict.get_item_opt::(field.as_str(), vm)? { + if ftype.fast_isinstance(vm.ctx.types.union_type) { + // Optional field (T | None) — no default + } else if ftype.fast_isinstance(vm.ctx.types.generic_alias_type) { + // List field (list[T]) — default to [] + let empty_list: PyObjectRef = vm.ctx.new_list(vec![]).into(); + zelf.set_attr(vm.ctx.intern_str(field.as_str()), empty_list, vm)?; + } else if ftype.is(&expr_ctx_type) { + // expr_context — default to Load() + let load_type = + super::super::pyast::NodeExprContextLoad::make_class(&vm.ctx); + let load_instance = + vm.ctx.new_base_object(load_type, Some(vm.ctx.new_dict())); + zelf.set_attr(vm.ctx.intern_str(field.as_str()), load_instance, vm)?; + } + // else: required field, no default set } - } - - // Special defaults that are not None or empty list - if class_name == "ImportFrom" && !set_fields.contains("level") { - zelf.set_attr("level", vm.ctx.new_int(0), vm)?; } }