diff --git a/Lib/test/test_marshal.py b/Lib/test/test_marshal.py index b8cdc4febcf..6234b99c242 100644 --- a/Lib/test/test_marshal.py +++ b/Lib/test/test_marshal.py @@ -49,7 +49,6 @@ def test_ints(self): self.helper(expected) n = n >> 1 - @unittest.expectedFailure # TODO: RUSTPYTHON def test_int64(self): # Simulate int marshaling with TYPE_INT64. maxint64 = (1 << 63) - 1 @@ -141,7 +140,6 @@ def test_different_filenames(self): self.assertEqual(co1.co_filename, "f1") self.assertEqual(co2.co_filename, "f2") - @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Unexpected keyword argument allow_code def test_no_allow_code(self): data = {'a': [({0},)]} dump = marshal.dumps(data, allow_code=False) @@ -234,14 +232,12 @@ def test_bytearray(self): new = marshal.loads(marshal.dumps(b)) self.assertEqual(type(new), bytes) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_memoryview(self): b = memoryview(b"abc") self.helper(b) new = marshal.loads(marshal.dumps(b)) self.assertEqual(type(new), bytes) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_array(self): a = array.array('B', b"abc") new = marshal.loads(marshal.dumps(a)) @@ -274,7 +270,6 @@ def test_fuzz(self): except Exception: pass - @unittest.expectedFailure # TODO: RUSTPYTHON def test_loads_recursion(self): def run_tests(N, check): # (((...None...),),) @@ -295,7 +290,7 @@ def check(s): run_tests(2**20, check) @unittest.skipIf(support.is_android, "TODO: RUSTPYTHON; segfault") - @unittest.expectedFailure # TODO: RUSTPYTHON; segfault + @unittest.skipIf(os.name == 'nt', "TODO: RUSTPYTHON; write depth limit is 2000 not 1000") def test_recursion_limit(self): # Create a deeply nested structure. head = last = [] @@ -324,7 +319,6 @@ def test_recursion_limit(self): last.append([0]) self.assertRaises(ValueError, marshal.dumps, head) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_exact_type_match(self): # Former bug: # >>> class Int(int): pass @@ -348,7 +342,6 @@ def test_invalid_longs(self): invalid_string = b'l\x02\x00\x00\x00\x00\x00\x00\x00' self.assertRaises(ValueError, marshal.loads, invalid_string) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_multiple_dumps_and_loads(self): # Issue 12291: marshal.load() should be callable multiple times # with interleaved data written by non-marshal code @@ -532,66 +525,56 @@ def helper3(self, rsample, recursive=False, simple=False): else: self.assertGreaterEqual(len(s2), len(s3)) - @unittest.expectedFailure # TODO: RUSTPYTHON def testInt(self): intobj = 123321 self.helper(intobj) self.helper3(intobj, simple=True) - @unittest.expectedFailure # TODO: RUSTPYTHON def testFloat(self): floatobj = 1.2345 self.helper(floatobj) self.helper3(floatobj) - @unittest.expectedFailure # TODO: RUSTPYTHON def testStr(self): strobj = "abcde"*3 self.helper(strobj) self.helper3(strobj) - @unittest.expectedFailure # TODO: RUSTPYTHON def testBytes(self): bytesobj = b"abcde"*3 self.helper(bytesobj) self.helper3(bytesobj) - @unittest.expectedFailure # TODO: RUSTPYTHON def testList(self): for obj in self.keys: listobj = [obj, obj] self.helper(listobj) self.helper3(listobj) - @unittest.expectedFailure # TODO: RUSTPYTHON def testTuple(self): for obj in self.keys: tupleobj = (obj, obj) self.helper(tupleobj) self.helper3(tupleobj) - @unittest.expectedFailure # TODO: RUSTPYTHON def testSet(self): for obj in self.keys: setobj = {(obj, 1), (obj, 2)} self.helper(setobj) self.helper3(setobj) - @unittest.expectedFailure # TODO: RUSTPYTHON def testFrozenSet(self): for obj in self.keys: frozensetobj = frozenset({(obj, 1), (obj, 2)}) self.helper(frozensetobj) self.helper3(frozensetobj) - @unittest.expectedFailure # TODO: RUSTPYTHON def testDict(self): for obj in self.keys: dictobj = {"hello": obj, "goodbye": obj, obj: "hello"} self.helper(dictobj) self.helper3(dictobj) - @unittest.expectedFailure # TODO: RUSTPYTHON def testModule(self): with open(__file__, "rb") as f: code = f.read() @@ -651,7 +634,6 @@ def testNoIntern(self): self.assertNotEqual(id(s2), id(s)) class SliceTestCase(unittest.TestCase, HelperMixin): - @unittest.expectedFailure # TODO: RUSTPYTHON; NotImplementedError: TODO: not implemented yet or marshal unsupported type def test_slice(self): for obj in ( slice(None), slice(1), slice(1, 2), slice(1, 2, 3), diff --git a/crates/compiler-core/src/bytecode.rs b/crates/compiler-core/src/bytecode.rs index 15940b68d1b..cd23458bf99 100644 --- a/crates/compiler-core/src/bytecode.rs +++ b/crates/compiler-core/src/bytecode.rs @@ -3,7 +3,7 @@ use crate::{ marshal::MarshalError, - varint::{read_varint, read_varint_with_start, write_varint, write_varint_with_start}, + varint::{read_varint, read_varint_with_start, write_varint_be, write_varint_with_start}, {OneIndexed, SourceLocation}, }; use alloc::{borrow::ToOwned, boxed::Box, collections::BTreeSet, fmt, string::String, vec::Vec}; @@ -71,9 +71,9 @@ pub fn encode_exception_table(entries: &[ExceptionTableEntry]) -> alloc::boxed:: let depth_lasti = ((entry.depth as u32) << 1) | (entry.push_lasti as u32); write_varint_with_start(&mut data, entry.start); - write_varint(&mut data, size); - write_varint(&mut data, entry.target); - write_varint(&mut data, depth_lasti); + write_varint_be(&mut data, size); + write_varint_be(&mut data, entry.target); + write_varint_be(&mut data, depth_lasti); } data.into_boxed_slice() } @@ -204,7 +204,7 @@ impl PyCodeLocationInfoKind { } } -pub trait Constant: Sized { +pub trait Constant: Sized + Clone { type Name: AsRef; /// Transforms the given Constant to a BorrowedConstant @@ -567,6 +567,14 @@ impl Deref for CodeUnits { } impl CodeUnits { + /// Disable adaptive specialization by setting all counters to unreachable. + /// Used for CPython-compiled bytecode where specialization may not be safe. + pub fn disable_specialization(&self) { + for counter in self.adaptive_counters.iter() { + counter.store(UNREACHABLE_BACKOFF, Ordering::Relaxed); + } + } + /// Replace the opcode at `index` in-place without changing the arg byte. /// Uses atomic Release store to ensure prior cache writes are visible /// to threads that subsequently read the new opcode with Acquire. diff --git a/crates/compiler-core/src/bytecode/oparg.rs b/crates/compiler-core/src/bytecode/oparg.rs index 11f1a59c38b..2dd18fba963 100644 --- a/crates/compiler-core/src/bytecode/oparg.rs +++ b/crates/compiler-core/src/bytecode/oparg.rs @@ -382,6 +382,10 @@ oparg_enum!( ); bitflagset::bitflag! { + /// `SET_FUNCTION_ATTRIBUTE` flags. + /// Bitmask: Defaults=0x01, KwOnly=0x02, Annotations=0x04, + /// Closure=0x08, TypeParams=0x10, Annotate=0x20. + /// Stored as bit position (0-5) by `bitflag!` macro. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] #[repr(u8)] pub enum MakeFunctionFlag { @@ -426,20 +430,63 @@ impl From for u32 { impl OpArgType for MakeFunctionFlag {} -oparg_enum!( - /// The possible comparison operators. - #[derive(Debug, Copy, Clone, PartialEq, Eq)] - pub enum ComparisonOperator { - // be intentional with bits so that we can do eval_ord with just a bitwise and - // bits: | Equal | Greater | Less | - Less = 0b001, - Greater = 0b010, - NotEqual = 0b011, - Equal = 0b100, - LessOrEqual = 0b101, - GreaterOrEqual = 0b110, +/// `COMPARE_OP` arg is `(cmp_index << 5) | mask`. Only the upper +/// 3 bits identify the comparison; the lower 5 bits are an inline +/// cache mask for adaptive specialization. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum ComparisonOperator { + Less, + LessOrEqual, + Equal, + NotEqual, + Greater, + GreaterOrEqual, +} + +impl TryFrom for ComparisonOperator { + type Error = MarshalError; + fn try_from(value: u8) -> Result { + Self::try_from(value as u32) } -); +} + +impl TryFrom for ComparisonOperator { + type Error = MarshalError; + /// Decode from `COMPARE_OP` arg: `(cmp_index << 5) | mask`. + fn try_from(value: u32) -> Result { + match value >> 5 { + 0 => Ok(Self::Less), + 1 => Ok(Self::LessOrEqual), + 2 => Ok(Self::Equal), + 3 => Ok(Self::NotEqual), + 4 => Ok(Self::Greater), + 5 => Ok(Self::GreaterOrEqual), + _ => Err(MarshalError::InvalidBytecode), + } + } +} + +impl From for u8 { + /// Encode as `cmp_index << 5` (mask bits zero). + fn from(value: ComparisonOperator) -> Self { + match value { + ComparisonOperator::Less => 0, + ComparisonOperator::LessOrEqual => 1 << 5, + ComparisonOperator::Equal => 2 << 5, + ComparisonOperator::NotEqual => 3 << 5, + ComparisonOperator::Greater => 4 << 5, + ComparisonOperator::GreaterOrEqual => 5 << 5, + } + } +} + +impl From for u32 { + fn from(value: ComparisonOperator) -> Self { + Self::from(u8::from(value)) + } +} + +impl OpArgType for ComparisonOperator {} oparg_enum!( /// The possible Binary operators diff --git a/crates/compiler-core/src/marshal.rs b/crates/compiler-core/src/marshal.rs index c47d41f1233..0031fba29e4 100644 --- a/crates/compiler-core/src/marshal.rs +++ b/crates/compiler-core/src/marshal.rs @@ -1,5 +1,5 @@ use crate::{OneIndexed, SourceLocation, bytecode::*}; -use alloc::{boxed::Box, vec, vec::Vec}; +use alloc::{boxed::Box, vec::Vec}; use core::convert::Infallible; use malachite_bigint::{BigInt, Sign}; use num_complex::Complex64; @@ -46,70 +46,72 @@ type Result = core::result::Result; #[derive(Clone, Copy)] #[repr(u8)] enum Type { - // Null = b'0', + Null = b'0', None = b'N', False = b'F', True = b'T', StopIter = b'S', Ellipsis = b'.', Int = b'i', + Int64 = b'I', + Long = b'l', Float = b'g', + FloatStr = b'f', + ComplexStr = b'x', Complex = b'y', - // Long = b'l', // i32 - Bytes = b's', // = TYPE_STRING - // Interned = b't', - // Ref = b'r', + Bytes = b's', + Interned = b't', + Ref = b'r', Tuple = b'(', + SmallTuple = b')', List = b'[', Dict = b'{', Code = b'c', Unicode = b'u', - // Unknown = b'?', Set = b'<', FrozenSet = b'>', - Slice = b':', // Added in version 5 + Slice = b':', Ascii = b'a', - // AsciiInterned = b'A', - // SmallTuple = b')', - // ShortAscii = b'z', - // ShortAsciiInterned = b'Z', + AsciiInterned = b'A', + ShortAscii = b'z', + ShortAsciiInterned = b'Z', } -// const FLAG_REF: u8 = b'\x80'; impl TryFrom for Type { type Error = MarshalError; fn try_from(value: u8) -> Result { use Type::*; - Ok(match value { - // b'0' => Null, + b'0' => Null, b'N' => None, b'F' => False, b'T' => True, b'S' => StopIter, b'.' => Ellipsis, b'i' => Int, + b'I' => Int64, + b'l' => Long, + b'f' => FloatStr, b'g' => Float, + b'x' => ComplexStr, b'y' => Complex, - // b'l' => Long, b's' => Bytes, - // b't' => Interned, - // b'r' => Ref, + b't' => Interned, + b'r' => Ref, b'(' => Tuple, + b')' => SmallTuple, b'[' => List, b'{' => Dict, b'c' => Code, b'u' => Unicode, - // b'?' => Unknown, b'<' => Set, b'>' => FrozenSet, b':' => Slice, b'a' => Ascii, - // b'A' => AsciiInterned, - // b')' => SmallTuple, - // b'z' => ShortAscii, - // b'Z' => ShortAsciiInterned, + b'A' => AsciiInterned, + b'z' => ShortAscii, + b'Z' => ShortAsciiInterned, _ => return Err(MarshalError::BadType), }) } @@ -187,119 +189,68 @@ impl> Read for Cursor { } } +/// Deserialize a code object (CPython field order). pub fn deserialize_code( rdr: &mut R, bag: Bag, ) -> Result> { - let len = rdr.read_u32()?; - let raw_instructions = rdr.read_slice(len * 2)?; - let instructions = CodeUnits::try_from(raw_instructions)?; - - let len = rdr.read_u32()?; - let locations = (0..len) - .map(|_| { - let start = SourceLocation { - line: OneIndexed::new(rdr.read_u32()? as _).ok_or(MarshalError::InvalidLocation)?, - character_offset: OneIndexed::from_zero_indexed(rdr.read_u32()? as _), - }; - let end = SourceLocation { - line: OneIndexed::new(rdr.read_u32()? as _).ok_or(MarshalError::InvalidLocation)?, - character_offset: OneIndexed::from_zero_indexed(rdr.read_u32()? as _), - }; - Ok((start, end)) - }) - .collect::>>()?; - - let flags = CodeFlags::from_bits_truncate(rdr.read_u32()?); - - let posonlyarg_count = rdr.read_u32()?; + // 1–5: scalar fields let arg_count = rdr.read_u32()?; + let posonlyarg_count = rdr.read_u32()?; let kwonlyarg_count = rdr.read_u32()?; + let max_stackdepth = rdr.read_u32()?; + let flags = CodeFlags::from_bits_truncate(rdr.read_u32()?); - let len = rdr.read_u32()?; - let source_path = bag.make_name(rdr.read_str(len)?); + // 6: co_code + let code_bytes = read_marshal_bytes(rdr)?; - let first_line_number = OneIndexed::new(rdr.read_u32()? as _); - let max_stackdepth = rdr.read_u32()?; + // 7: co_consts + let constants = read_marshal_const_tuple(rdr, bag)?; - let len = rdr.read_u32()?; - let obj_name = bag.make_name(rdr.read_str(len)?); + // 8: co_names + let names = read_marshal_name_tuple(rdr, &bag)?; - let len = rdr.read_u32()?; - let qualname = bag.make_name(rdr.read_str(len)?); + // 9: co_localsplusnames + let localsplusnames = read_marshal_str_vec(rdr)?; - // Read and discard legacy cell2arg data for backwards compatibility - let cell2arg_len = rdr.read_u32()?; - for _ in 0..cell2arg_len { - let _ = rdr.read_u32()?; - } + // 10: co_localspluskinds + let localspluskinds = read_marshal_bytes(rdr)?; - let len = rdr.read_u32()?; - let constants = (0..len) - .map(|_| deserialize_value(rdr, bag)) - .collect::>()?; - - let mut read_names = || { - let len = rdr.read_u32()?; - (0..len) - .map(|_| { - let len = rdr.read_u32()?; - Ok(bag.make_name(rdr.read_str(len)?)) - }) - .collect::>>() + // 11–13: filename, name, qualname + let source_path = bag.make_name(&read_marshal_str(rdr)?); + let obj_name = bag.make_name(&read_marshal_str(rdr)?); + let qualname = bag.make_name(&read_marshal_str(rdr)?); + + // 14: co_firstlineno + let first_line_raw = rdr.read_u32()? as i32; + let first_line_number = if first_line_raw > 0 { + OneIndexed::new(first_line_raw as usize) + } else { + None }; - let names = read_names()?; - let varnames = read_names()?; - let cellvars = read_names()?; - let freevars = read_names()?; - - // Read linetable and exceptiontable - let linetable_len = rdr.read_u32()?; - let linetable = rdr.read_slice(linetable_len)?.to_vec().into_boxed_slice(); - - let exceptiontable_len = rdr.read_u32()?; - let exceptiontable = rdr - .read_slice(exceptiontable_len)? - .to_vec() - .into_boxed_slice(); - - // Build localspluskinds with cell-local merging - let localspluskinds = { - use crate::bytecode::*; - let nlocals = varnames.len(); - let ncells = cellvars.len(); - let nfrees = freevars.len(); - // Count merged cells (cellvar also in varnames) - let numdropped = cellvars + // 15–16: linetable, exceptiontable + let linetable = read_marshal_bytes(rdr)?.to_vec().into_boxed_slice(); + let exceptiontable = read_marshal_bytes(rdr)?.to_vec().into_boxed_slice(); + + // Split localsplusnames/kinds → varnames/cellvars/freevars + let lp = split_localplus( + &localsplusnames .iter() - .filter(|cv| varnames.iter().any(|v| v.as_ref() == cv.as_ref())) - .count(); - let nlocalsplus = nlocals + ncells - numdropped + nfrees; - let mut kinds = vec![0u8; nlocalsplus]; - // Mark locals - for kind in kinds.iter_mut().take(nlocals) { - *kind = CO_FAST_LOCAL; - } - // Build cellfixedoffsets and mark cells - let mut cell_numdropped = 0usize; - for (i, cv) in cellvars.iter().enumerate() { - let merged_idx = varnames.iter().position(|v| v.as_ref() == cv.as_ref()); - if let Some(local_idx) = merged_idx { - kinds[local_idx] |= CO_FAST_CELL; // merged: LOCAL | CELL - cell_numdropped += 1; - } else { - let idx = nlocals + i - cell_numdropped; - kinds[idx] = CO_FAST_CELL; - } - } - // Mark frees - let free_start = nlocals + ncells - numdropped; - for i in 0..nfrees { - kinds[free_start + i] = CO_FAST_FREE; - } - kinds.into_boxed_slice() - }; + .map(|s| s.as_str()) + .collect::>(), + &localspluskinds, + arg_count, + kwonlyarg_count, + flags, + )?; + + // Bytecode already uses flat localsplus indices (no translation needed) + let instructions = CodeUnits::try_from(code_bytes.as_slice())?; + let locations = linetable_to_locations(&linetable, first_line_raw, instructions.len()); + + // Use original localspluskinds from marshal data (preserves CO_FAST_HIDDEN etc.) + let localspluskinds = localspluskinds.into_boxed_slice(); Ok(CodeObject { instructions, @@ -315,17 +266,85 @@ pub fn deserialize_code( qualname, constants, names, - varnames, - cellvars, - freevars, + varnames: lp.varnames.iter().map(|s| bag.make_name(s)).collect(), + cellvars: lp.cellvars.iter().map(|s| bag.make_name(s)).collect(), + freevars: lp.freevars.iter().map(|s| bag.make_name(s)).collect(), localspluskinds, linetable, exceptiontable, }) } +/// Read a marshal bytes object (TYPE_STRING = b's'). +fn read_marshal_bytes(rdr: &mut R) -> Result> { + let type_byte = rdr.read_u8()? & !FLAG_REF; + if type_byte != Type::Bytes as u8 { + return Err(MarshalError::BadType); + } + let len = rdr.read_u32()?; + Ok(rdr.read_slice(len)?.to_vec()) +} + +/// Read a marshal string object. +fn read_marshal_str(rdr: &mut R) -> Result { + let type_byte = rdr.read_u8()? & !FLAG_REF; + let s = match type_byte { + b'u' | b't' | b'a' | b'A' => { + let len = rdr.read_u32()?; + rdr.read_str(len)? + } + b'z' | b'Z' => { + let len = rdr.read_u8()? as u32; + rdr.read_str(len)? + } + _ => return Err(MarshalError::BadType), + }; + Ok(alloc::string::String::from(s)) +} + +/// Read a marshal tuple of strings, returning owned Strings. +fn read_marshal_str_vec(rdr: &mut R) -> Result> { + let type_byte = rdr.read_u8()? & !FLAG_REF; + let n = match type_byte { + b'(' => rdr.read_u32()? as usize, + b')' => rdr.read_u8()? as usize, + _ => return Err(MarshalError::BadType), + }; + (0..n).map(|_| read_marshal_str(rdr)).collect() +} + +fn read_marshal_name_tuple( + rdr: &mut R, + bag: &Bag, +) -> Result::Name]>> { + let type_byte = rdr.read_u8()? & !FLAG_REF; + let n = match type_byte { + b'(' => rdr.read_u32()? as usize, + b')' => rdr.read_u8()? as usize, + _ => return Err(MarshalError::BadType), + }; + (0..n) + .map(|_| Ok(bag.make_name(&read_marshal_str(rdr)?))) + .collect::>>() + .map(Vec::into_boxed_slice) +} + +/// Read a marshal tuple of constants. +fn read_marshal_const_tuple( + rdr: &mut R, + bag: Bag, +) -> Result> { + let type_byte = rdr.read_u8()? & !FLAG_REF; + let n = match type_byte { + b'(' => rdr.read_u32()? as usize, + b')' => rdr.read_u8()? as usize, + _ => return Err(MarshalError::BadType), + }; + (0..n).map(|_| deserialize_value(rdr, bag)).collect() +} + pub trait MarshalBag: Copy { - type Value; + type Value: Clone; type ConstantBag: ConstantBag; fn make_bool(&self, value: bool) -> Self::Value; @@ -364,6 +383,15 @@ pub trait MarshalBag: Copy { it: impl Iterator, ) -> Result; + fn make_slice( + &self, + _start: Self::Value, + _stop: Self::Value, + _step: Self::Value, + ) -> Result { + Err(MarshalError::BadType) + } + fn constant_bag(self) -> Self::ConstantBag; } @@ -442,8 +470,63 @@ impl MarshalBag for Bag { } } +pub const MAX_MARSHAL_STACK_DEPTH: usize = 2000; + pub fn deserialize_value(rdr: &mut R, bag: Bag) -> Result { - let typ = Type::try_from(rdr.read_u8()?)?; + let mut refs: Vec> = Vec::new(); + deserialize_value_depth(rdr, bag, MAX_MARSHAL_STACK_DEPTH, &mut refs) +} + +fn deserialize_value_depth( + rdr: &mut R, + bag: Bag, + depth: usize, + refs: &mut Vec>, +) -> Result { + if depth == 0 { + return Err(MarshalError::InvalidBytecode); + } + let raw = rdr.read_u8()?; + let flag = raw & FLAG_REF != 0; + let type_code = raw & !FLAG_REF; + + // TYPE_REF: return previously stored object + if type_code == Type::Ref as u8 { + let idx = rdr.read_u32()? as usize; + return refs + .get(idx) + .and_then(|v| v.clone()) + .ok_or(MarshalError::InvalidBytecode); + } + + // Reserve ref slot before reading (matches write order) + let slot = if flag { + let idx = refs.len(); + refs.push(None); + Some(idx) + } else { + None + }; + + let typ = Type::try_from(type_code)?; + let value = deserialize_value_typed(rdr, bag, depth, refs, typ)?; + + if let Some(idx) = slot { + refs[idx] = Some(value.clone()); + } + Ok(value) +} + +fn deserialize_value_typed( + rdr: &mut R, + bag: Bag, + depth: usize, + refs: &mut Vec>, + typ: Type, +) -> Result { + if depth == 0 { + return Err(MarshalError::InvalidBytecode); + } let value = match typ { Type::True => bag.make_bool(true), Type::False => bag.make_bool(false), @@ -451,68 +534,127 @@ pub fn deserialize_value(rdr: &mut R, bag: Bag) -> Res Type::StopIter => bag.make_stop_iter()?, Type::Ellipsis => bag.make_ellipsis(), Type::Int => { - let len = rdr.read_u32()? as i32; - let sign = if len < 0 { Sign::Minus } else { Sign::Plus }; - let bytes = rdr.read_slice(len.unsigned_abs())?; - let int = BigInt::from_bytes_le(sign, bytes); - bag.make_int(int) + let val = rdr.read_u32()? as i32; + bag.make_int(BigInt::from(val)) + } + Type::Int64 => { + let lo = rdr.read_u32()? as u64; + let hi = rdr.read_u32()? as u64; + bag.make_int(BigInt::from(((hi << 32) | lo) as i64)) } + Type::Long => bag.make_int(read_pylong(rdr)?), + Type::FloatStr => bag.make_float(read_float_str(rdr)?), Type::Float => { let value = f64::from_bits(rdr.read_u64()?); bag.make_float(value) } + Type::ComplexStr => { + let re = read_float_str(rdr)?; + let im = read_float_str(rdr)?; + bag.make_complex(Complex64 { re, im }) + } Type::Complex => { let re = f64::from_bits(rdr.read_u64()?); let im = f64::from_bits(rdr.read_u64()?); let value = Complex64 { re, im }; bag.make_complex(value) } - Type::Ascii | Type::Unicode => { + Type::Ascii | Type::AsciiInterned | Type::Unicode | Type::Interned => { let len = rdr.read_u32()?; let value = rdr.read_wtf8(len)?; bag.make_str(value) } + Type::ShortAscii | Type::ShortAsciiInterned => { + let len = rdr.read_u8()? as u32; + let value = rdr.read_wtf8(len)?; + bag.make_str(value) + } + Type::SmallTuple => { + let len = rdr.read_u8()? as usize; + let d = depth - 1; + let it = (0..len).map(|_| deserialize_value_depth(rdr, bag, d, refs)); + itertools::process_results(it, |it| bag.make_tuple(it))? + } + Type::Null => { + return Err(MarshalError::BadType); + } + Type::Ref => { + // Handled in deserialize_value_depth before calling this function + return Err(MarshalError::BadType); + } Type::Tuple => { let len = rdr.read_u32()?; - let it = (0..len).map(|_| deserialize_value(rdr, bag)); + let d = depth - 1; + let it = (0..len).map(|_| deserialize_value_depth(rdr, bag, d, refs)); itertools::process_results(it, |it| bag.make_tuple(it))? } Type::List => { let len = rdr.read_u32()?; - let it = (0..len).map(|_| deserialize_value(rdr, bag)); + let d = depth - 1; + let it = (0..len).map(|_| deserialize_value_depth(rdr, bag, d, refs)); itertools::process_results(it, |it| bag.make_list(it))?? } Type::Set => { let len = rdr.read_u32()?; - let it = (0..len).map(|_| deserialize_value(rdr, bag)); + let d = depth - 1; + let it = (0..len).map(|_| deserialize_value_depth(rdr, bag, d, refs)); itertools::process_results(it, |it| bag.make_set(it))?? } Type::FrozenSet => { let len = rdr.read_u32()?; - let it = (0..len).map(|_| deserialize_value(rdr, bag)); + let d = depth - 1; + let it = (0..len).map(|_| deserialize_value_depth(rdr, bag, d, refs)); itertools::process_results(it, |it| bag.make_frozenset(it))?? } Type::Dict => { - let len = rdr.read_u32()?; - let it = (0..len).map(|_| { - let k = deserialize_value(rdr, bag)?; - let v = deserialize_value(rdr, bag)?; - Ok::<_, MarshalError>((k, v)) - }); - itertools::process_results(it, |it| bag.make_dict(it))?? + let d = depth - 1; + let mut pairs = Vec::new(); + loop { + let raw = rdr.read_u8()?; + let type_code = raw & !FLAG_REF; + if type_code == b'0' { + break; + } + // TYPE_REF for key + let k = if type_code == Type::Ref as u8 { + let idx = rdr.read_u32()? as usize; + refs.get(idx) + .and_then(|v| v.clone()) + .ok_or(MarshalError::InvalidBytecode)? + } else { + let flag = raw & FLAG_REF != 0; + let key_slot = if flag { + let idx = refs.len(); + refs.push(None); + Some(idx) + } else { + None + }; + let key_type = Type::try_from(type_code)?; + let k = deserialize_value_typed(rdr, bag, d, refs, key_type)?; + if let Some(idx) = key_slot { + refs[idx] = Some(k.clone()); + } + k + }; + let v = deserialize_value_depth(rdr, bag, d, refs)?; + pairs.push((k, v)); + } + bag.make_dict(pairs.into_iter())? } Type::Bytes => { - // Following CPython, after marshaling, byte arrays are converted into bytes. + // After marshaling, byte arrays are converted into bytes. let len = rdr.read_u32()?; let value = rdr.read_slice(len)?; bag.make_bytes(value) } Type::Code => bag.make_code(deserialize_code(rdr, bag.constant_bag())?), Type::Slice => { - // Slice constants are not yet supported in RustPython - // This would require adding a Slice variant to ConstantData enum - // For now, return an error if we encounter a slice in marshal data - return Err(MarshalError::BadType); + let d = depth - 1; + let start = deserialize_value_depth(rdr, bag, d, refs)?; + let stop = deserialize_value_depth(rdr, bag, d, refs)?; + let step = deserialize_value_depth(rdr, bag, d, refs)?; + bag.make_slice(start, stop, step)? } }; Ok(value) @@ -541,6 +683,7 @@ pub enum DumpableValue<'a, D: Dumpable> { Set(&'a [D]), Frozenset(&'a [D]), Dict(&'a [(D, D)]), + Slice(&'a D, &'a D, &'a D), } impl<'a, C: Constant> From> for DumpableValue<'a, C> { @@ -614,12 +757,37 @@ pub fn serialize_value( ) -> Result<(), D::Error> { match constant { DumpableValue::Integer(int) => { - buf.write_u8(Type::Int as u8); - let (sign, bytes) = int.to_bytes_le(); - let len: i32 = bytes.len().try_into().expect("too long to serialize"); - let len = if sign == Sign::Minus { -len } else { len }; - buf.write_u32(len as u32); - buf.write_slice(&bytes); + if let Ok(val) = i32::try_from(int) { + buf.write_u8(Type::Int as u8); // TYPE_INT: 4-byte LE i32 + buf.write_u32(val as u32); + } else { + buf.write_u8(Type::Long as u8); + let (sign, raw) = int.to_bytes_le(); + let mut digits = alloc::vec::Vec::new(); + let mut accum: u32 = 0; + let mut bits = 0u32; + for &byte in &raw { + accum |= (byte as u32) << bits; + bits += 8; + while bits >= 15 { + digits.push((accum & 0x7fff) as u16); + accum >>= 15; + bits -= 15; + } + } + if accum > 0 || digits.is_empty() { + digits.push(accum as u16); + } + while digits.len() > 1 && *digits.last().unwrap() == 0 { + digits.pop(); + } + let n = digits.len() as i32; + let n = if sign == Sign::Minus { -n } else { n }; + buf.write_u32(n as u32); + for d in &digits { + buf.write_u16(*d); + } + } } DumpableValue::Float(f) => { buf.write_u8(Type::Float as u8); @@ -684,64 +852,411 @@ pub fn serialize_value( } DumpableValue::Dict(d) => { buf.write_u8(Type::Dict as u8); - write_len(buf, d.len()); for (k, v) in d { k.with_dump(|val| serialize_value(buf, val))??; v.with_dump(|val| serialize_value(buf, val))??; } + buf.write_u8(b'0'); // TYPE_NULL + } + DumpableValue::Slice(start, stop, step) => { + buf.write_u8(Type::Slice as u8); + start.with_dump(|val| serialize_value(buf, val))??; + stop.with_dump(|val| serialize_value(buf, val))??; + step.with_dump(|val| serialize_value(buf, val))??; } } Ok(()) } +/// Serialize a code object in CPython field order. +/// +/// Split varnames/cellvars/freevars are reassembled into +/// co_localsplusnames/co_localspluskinds. pub fn serialize_code(buf: &mut W, code: &CodeObject) { - write_len(buf, code.instructions.len()); - let original = code.instructions.original_bytes(); - buf.write_slice(&original); + // 1–5: scalar fields + buf.write_u32(code.arg_count); + buf.write_u32(code.posonlyarg_count); + buf.write_u32(code.kwonlyarg_count); + buf.write_u32(code.max_stackdepth); + buf.write_u32(code.flags.bits()); - write_len(buf, code.locations.len()); - for (start, end) in &*code.locations { - buf.write_u32(start.line.get() as _); - buf.write_u32(start.character_offset.to_zero_indexed() as _); - buf.write_u32(end.line.get() as _); - buf.write_u32(end.character_offset.to_zero_indexed() as _); + // 6: co_code (TYPE_STRING) — bytecode already uses flat localsplus indices + let bytecode = code.instructions.original_bytes(); + buf.write_u8(Type::Bytes as u8); + write_vec(buf, &bytecode); + + // 7: co_consts (TYPE_TUPLE) + buf.write_u8(Type::Tuple as u8); + write_len(buf, code.constants.len()); + for constant in &*code.constants { + serialize_value(buf, constant.borrow_constant().into()).unwrap_or_else(|x| match x {}) } - buf.write_u32(code.flags.bits()); + // 8: co_names (tuple of strings) + write_marshal_name_tuple(buf, &code.names); + + // 9: co_localsplusnames — varnames ++ cell_only ++ freevars + let cell_only_names: Vec<&str> = code + .cellvars + .iter() + .filter(|cv| !code.varnames.iter().any(|v| v.as_ref() == cv.as_ref())) + .map(|cv| cv.as_ref()) + .collect(); + let total_lp_count = code.varnames.len() + cell_only_names.len() + code.freevars.len(); + buf.write_u8(Type::Tuple as u8); + write_len(buf, total_lp_count); + for n in code.varnames.iter() { + write_marshal_str(buf, n.as_ref()); + } + for &n in &cell_only_names { + write_marshal_str(buf, n); + } + for n in code.freevars.iter() { + write_marshal_str(buf, n.as_ref()); + } + // 10: co_localspluskinds — use the stored kinds directly + buf.write_u8(Type::Bytes as u8); + write_vec(buf, &code.localspluskinds); + + // 11: co_filename + write_marshal_str(buf, code.source_path.as_ref()); + // 12: co_name + write_marshal_str(buf, code.obj_name.as_ref()); + // 13: co_qualname + write_marshal_str(buf, code.qualname.as_ref()); + // 14: co_firstlineno + buf.write_u32(code.first_line_number.map_or(0, |x| x.get() as _)); + // 15: co_linetable + buf.write_u8(Type::Bytes as u8); + write_vec(buf, &code.linetable); + // 16: co_exceptiontable + buf.write_u8(Type::Bytes as u8); + write_vec(buf, &code.exceptiontable); +} - buf.write_u32(code.posonlyarg_count); - buf.write_u32(code.arg_count); - buf.write_u32(code.kwonlyarg_count); +fn write_marshal_str(buf: &mut W, s: &str) { + let bytes = s.as_bytes(); + if bytes.len() < 256 && bytes.is_ascii() { + buf.write_u8(b'z'); // TYPE_SHORT_ASCII + buf.write_u8(bytes.len() as u8); + } else { + buf.write_u8(Type::Unicode as u8); + write_len(buf, bytes.len()); + } + buf.write_slice(bytes); +} + +fn write_marshal_name_tuple>(buf: &mut W, names: &[N]) { + buf.write_u8(Type::Tuple as u8); + write_len(buf, names.len()); + for name in names { + write_marshal_str(buf, name.as_ref()); + } +} - write_vec(buf, code.source_path.as_ref().as_bytes()); +pub const FLAG_REF: u8 = 0x80; - buf.write_u32(code.first_line_number.map_or(0, |x| x.get() as _)); - buf.write_u32(code.max_stackdepth); +/// Read a signed 32-bit LE integer. +pub fn read_i32(rdr: &mut R) -> Result { + let bytes = rdr.read_array::<4>()?; + Ok(i32::from_le_bytes(*bytes)) +} + +/// Read a TYPE_LONG arbitrary-precision integer (base-2^15 digits). +pub fn read_pylong(rdr: &mut R) -> Result { + const MARSHAL_SHIFT: u32 = 15; + const MARSHAL_BASE: u32 = 1 << MARSHAL_SHIFT; + let n = read_i32(rdr)?; + if n == 0 { + return Ok(BigInt::from(0)); + } + let negative = n < 0; + let num_digits = n.unsigned_abs() as usize; + let mut accum = BigInt::from(0); + let mut last_digit = 0u32; + for i in 0..num_digits { + let d = rdr.read_u16()? as u32; + if d >= MARSHAL_BASE { + return Err(MarshalError::InvalidBytecode); + } + last_digit = d; + accum += BigInt::from(d) << (i as u32 * MARSHAL_SHIFT); + } + if num_digits > 0 && last_digit == 0 { + return Err(MarshalError::InvalidBytecode); + } + if negative { + accum = -accum; + } + Ok(accum) +} - write_vec(buf, code.obj_name.as_ref().as_bytes()); - write_vec(buf, code.qualname.as_ref().as_bytes()); +/// Read a text-encoded float (1-byte length + ASCII). +pub fn read_float_str(rdr: &mut R) -> Result { + let n = rdr.read_u8()? as u32; + let s = rdr.read_str(n)?; + s.parse::().map_err(|_| MarshalError::InvalidBytecode) +} - // Write empty cell2arg for backwards compatibility - write_len(buf, 0); +/// Read a 4-byte-length-prefixed byte string. +pub fn read_pstring(rdr: &mut R) -> Result<&[u8]> { + let n = read_i32(rdr)?; + if n < 0 { + return Err(MarshalError::InvalidBytecode); + } + rdr.read_slice(n as u32) +} - write_len(buf, code.constants.len()); - for constant in &*code.constants { - serialize_value(buf, constant.borrow_constant().into()).unwrap_or_else(|x| match x {}) +const CO_FAST_LOCAL: u8 = 0x20; +const CO_FAST_CELL: u8 = 0x40; +const CO_FAST_FREE: u8 = 0x80; + +pub struct LocalsPlusResult { + pub varnames: Vec, + pub cellvars: Vec, + pub freevars: Vec, + pub cell2arg: Option>, + pub deref_map: Vec, +} + +pub fn split_localplus( + names: &[S], + kinds: &[u8], + arg_count: u32, + kwonlyarg_count: u32, + flags: CodeFlags, +) -> Result> { + if names.len() != kinds.len() { + return Err(MarshalError::InvalidBytecode); + } + + let mut varnames = Vec::new(); + let mut cellvars = Vec::new(); + let mut freevars = Vec::new(); + + // First pass: collect varnames (LOCAL entries) and freevars + for (name, &kind) in names.iter().zip(kinds.iter()) { + if kind & CO_FAST_LOCAL != 0 { + varnames.push(name.clone()); + } + if kind & CO_FAST_FREE != 0 { + freevars.push(name.clone()); + } + } + + // Second pass: collect cellvars in localsplusnames order. + // CELL-only vars come from non-LOCAL CELL entries. + // LOCAL|CELL vars are also added to cellvars. + // This preserves the original ordering from localsplusnames. + let mut arg_cell_positions = Vec::new(); // (cell_idx, localplus_idx) + for (i, (name, &kind)) in names.iter().zip(kinds.iter()).enumerate() { + let is_local = kind & CO_FAST_LOCAL != 0; + let is_cell = kind & CO_FAST_CELL != 0; + if is_cell { + let cell_idx = cellvars.len(); + cellvars.push(name.clone()); + if is_local { + arg_cell_positions.push((cell_idx, i)); + } + } } - let mut write_names = |names: &[C::Name]| { - write_len(buf, names.len()); - for name in names { - write_vec(buf, name.as_ref().as_bytes()); + let total_args = { + let mut t = arg_count + kwonlyarg_count; + if flags.contains(CodeFlags::VARARGS) { + t += 1; } + if flags.contains(CodeFlags::VARKEYWORDS) { + t += 1; + } + t }; - write_names(&code.names); - write_names(&code.varnames); - write_names(&code.cellvars); - write_names(&code.freevars); + let cell2arg = if !cellvars.is_empty() { + let mut mapping = alloc::vec![-1i32; cellvars.len()]; + for &(cell_idx, localplus_idx) in &arg_cell_positions { + if (localplus_idx as u32) < total_args { + mapping[cell_idx] = localplus_idx as i32; + } + } + if mapping.iter().any(|&x| x >= 0) { + Some(mapping.into_boxed_slice()) + } else { + None + } + } else { + None + }; - // Serialize linetable and exceptiontable - write_vec(buf, &code.linetable); - write_vec(buf, &code.exceptiontable); + // Build deref_map: localsplusnames index → cellvar/freevar index + let mut deref_map = alloc::vec![u32::MAX; names.len()]; + let mut cell_idx = 0u32; + for (i, &kind) in kinds.iter().enumerate() { + if kind & CO_FAST_CELL != 0 { + deref_map[i] = cell_idx; + cell_idx += 1; + } + } + let ncells = cellvars.len(); + let mut free_idx = 0u32; + for (i, &kind) in kinds.iter().enumerate() { + if kind & CO_FAST_FREE != 0 { + deref_map[i] = ncells as u32 + free_idx; + free_idx += 1; + } + } + + Ok(LocalsPlusResult { + varnames, + cellvars, + freevars, + cell2arg, + deref_map, + }) +} + +pub fn linetable_to_locations( + linetable: &[u8], + first_line: i32, + num_instructions: usize, +) -> Box<[(SourceLocation, SourceLocation)]> { + let default_loc = || { + let line = if first_line > 0 { + OneIndexed::new(first_line as usize).unwrap_or(OneIndexed::MIN) + } else { + OneIndexed::MIN + }; + let loc = SourceLocation { + line, + character_offset: OneIndexed::from_zero_indexed(0), + }; + (loc, loc) + }; + if linetable.is_empty() { + return alloc::vec![default_loc(); num_instructions].into_boxed_slice(); + } + + let mut locations = Vec::with_capacity(num_instructions); + let mut pos = 0; + let mut line = first_line; + + while pos < linetable.len() && locations.len() < num_instructions { + let first_byte = linetable[pos]; + pos += 1; + if first_byte & 0x80 == 0 { + break; + } + let code = (first_byte >> 3) & 0x0f; + let length = ((first_byte & 0x07) + 1) as usize; + let kind = match PyCodeLocationInfoKind::from_code(code) { + Some(k) => k, + None => break, + }; + + let (line_delta, end_line_delta, col, end_col): (i32, i32, Option, Option) = + match kind { + PyCodeLocationInfoKind::None => (0, 0, None, None), + PyCodeLocationInfoKind::Long => { + let d = lt_read_signed_varint(linetable, &mut pos); + let ed = lt_read_varint(linetable, &mut pos) as i32; + let c = lt_read_varint(linetable, &mut pos); + let ec = lt_read_varint(linetable, &mut pos); + ( + d, + ed, + if c == 0 { None } else { Some(c - 1) }, + if ec == 0 { None } else { Some(ec - 1) }, + ) + } + PyCodeLocationInfoKind::NoColumns => { + (lt_read_signed_varint(linetable, &mut pos), 0, None, None) + } + PyCodeLocationInfoKind::OneLine0 + | PyCodeLocationInfoKind::OneLine1 + | PyCodeLocationInfoKind::OneLine2 => { + let c = lt_byte(linetable, &mut pos) as u32; + let ec = lt_byte(linetable, &mut pos) as u32; + (kind.one_line_delta().unwrap_or(0), 0, Some(c), Some(ec)) + } + _ if kind.is_short() => { + let d = lt_byte(linetable, &mut pos); + let g = kind.short_column_group().unwrap_or(0); + let c = ((g as u32) << 3) | ((d >> 4) as u32); + (0, 0, Some(c), Some(c + (d & 0x0f) as u32)) + } + _ => (0, 0, None, None), + }; + + line += line_delta; + for _ in 0..length { + if locations.len() >= num_instructions { + break; + } + if kind == PyCodeLocationInfoKind::None { + locations.push(default_loc()); + } else { + let mk = |l: i32| { + if l > 0 { + OneIndexed::new(l as usize).unwrap_or(OneIndexed::MIN) + } else { + OneIndexed::MIN + } + }; + locations.push(( + SourceLocation { + line: mk(line), + character_offset: OneIndexed::from_zero_indexed(col.unwrap_or(0) as usize), + }, + SourceLocation { + line: mk(line + end_line_delta), + character_offset: OneIndexed::from_zero_indexed( + end_col.unwrap_or(0) as usize + ), + }, + )); + } + } + } + while locations.len() < num_instructions { + locations.push(default_loc()); + } + locations.into_boxed_slice() +} + +fn lt_byte(data: &[u8], pos: &mut usize) -> u8 { + if *pos < data.len() { + let b = data[*pos]; + *pos += 1; + b + } else { + 0 + } +} + +/// Linetable uses little-endian varint. +fn lt_read_varint(data: &[u8], pos: &mut usize) -> u32 { + let mut result: u32 = 0; + let mut shift = 0; + loop { + if *pos >= data.len() { + break; + } + let b = data[*pos]; + *pos += 1; + result |= ((b & 0x3f) as u32) << shift; + shift += 6; + if b & 0x40 == 0 { + break; + } + } + result +} + +fn lt_read_signed_varint(data: &[u8], pos: &mut usize) -> i32 { + let val = lt_read_varint(data, pos); + if val & 1 != 0 { + -((val >> 1) as i32) + } else { + (val >> 1) as i32 + } } diff --git a/crates/compiler-core/src/varint.rs b/crates/compiler-core/src/varint.rs index f1ea6b17ec0..c07b8b58e6a 100644 --- a/crates/compiler-core/src/varint.rs +++ b/crates/compiler-core/src/varint.rs @@ -1,12 +1,14 @@ //! Variable-length integer encoding utilities. //! -//! Uses 6-bit chunks with a continuation bit (0x40) to encode integers. -//! Used for exception tables and line number tables. +//! Two encodings are used: +//! - **Little-endian** (low bits first): linetable +//! - **Big-endian** (high bits first): exception tables +//! +//! Both use 6-bit chunks with 0x40 as the continuation bit. use alloc::vec::Vec; -/// Write a variable-length unsigned integer using 6-bit chunks. -/// Returns the number of bytes written. +/// Write a little-endian varint (used by linetable). #[inline] pub fn write_varint(buf: &mut Vec, mut val: u32) -> usize { let start_len = buf.len(); @@ -18,12 +20,10 @@ pub fn write_varint(buf: &mut Vec, mut val: u32) -> usize { buf.len() - start_len } -/// Write a variable-length signed integer. -/// Returns the number of bytes written. +/// Write a little-endian signed varint. #[inline] pub fn write_signed_varint(buf: &mut Vec, val: i32) -> usize { let uval = if val < 0 { - // (0 - val as u32) handles INT_MIN correctly ((0u32.wrapping_sub(val as u32)) << 1) | 1 } else { (val as u32) << 1 @@ -31,70 +31,72 @@ pub fn write_signed_varint(buf: &mut Vec, val: i32) -> usize { write_varint(buf, uval) } -/// Write a variable-length unsigned integer with a start marker (0x80 bit). -/// Used for exception table entries where each entry starts with the marker. +/// Write a big-endian varint (used by exception tables). +pub fn write_varint_be(buf: &mut Vec, val: u32) -> usize { + let start_len = buf.len(); + if val >= 1 << 30 { + buf.push(0x40 | ((val >> 30) & 0x3f) as u8); + } + if val >= 1 << 24 { + buf.push(0x40 | ((val >> 24) & 0x3f) as u8); + } + if val >= 1 << 18 { + buf.push(0x40 | ((val >> 18) & 0x3f) as u8); + } + if val >= 1 << 12 { + buf.push(0x40 | ((val >> 12) & 0x3f) as u8); + } + if val >= 1 << 6 { + buf.push(0x40 | ((val >> 6) & 0x3f) as u8); + } + buf.push((val & 0x3f) as u8); + buf.len() - start_len +} + +/// Write a big-endian varint with the start marker (0x80) on the first byte. pub fn write_varint_with_start(data: &mut Vec, val: u32) { let start_pos = data.len(); - write_varint(data, val); - // Set start bit on first byte + write_varint_be(data, val); if let Some(first) = data.get_mut(start_pos) { *first |= 0x80; } } -/// Read a variable-length unsigned integer that starts with a start marker (0x80 bit). -/// Returns None if not at a valid start byte or end of data. +/// Read a big-endian varint with start marker (0x80). pub fn read_varint_with_start(data: &[u8], pos: &mut usize) -> Option { if *pos >= data.len() { return None; } let first = data[*pos]; if first & 0x80 == 0 { - return None; // Not a start byte + return None; } *pos += 1; - let mut val = (first & 0x3f) as u32; - let mut shift = 6; - let mut has_continuation = first & 0x40 != 0; - - while has_continuation && *pos < data.len() { - let byte = data[*pos]; - if byte & 0x80 != 0 { - break; // Next entry start - } + let mut cont = first & 0x40 != 0; + while cont && *pos < data.len() { + let b = data[*pos]; *pos += 1; - val |= ((byte & 0x3f) as u32) << shift; - shift += 6; - has_continuation = byte & 0x40 != 0; + val = (val << 6) | (b & 0x3f) as u32; + cont = b & 0x40 != 0; } Some(val) } -/// Read a variable-length unsigned integer. -/// Returns None if end of data or malformed. +/// Read a big-endian varint (no start marker). pub fn read_varint(data: &[u8], pos: &mut usize) -> Option { if *pos >= data.len() { return None; } - - let mut val = 0u32; - let mut shift = 0; - - loop { - if *pos >= data.len() { - return None; - } - let byte = data[*pos]; - if byte & 0x80 != 0 && shift > 0 { - break; // Next entry start - } + let first = data[*pos]; + *pos += 1; + let mut val = (first & 0x3f) as u32; + let mut cont = first & 0x40 != 0; + while cont && *pos < data.len() { + let b = data[*pos]; *pos += 1; - val |= ((byte & 0x3f) as u32) << shift; - shift += 6; - if byte & 0x40 == 0 { - break; - } + val = (val << 6) | (b & 0x3f) as u32; + cont = b & 0x40 != 0; } Some(val) } @@ -104,37 +106,39 @@ mod tests { use super::*; #[test] - fn test_write_read_varint() { + fn test_le_varint_roundtrip() { + // Little-endian is only used internally in linetable, + // no read function needed outside of linetable parsing. let mut buf = Vec::new(); write_varint(&mut buf, 0); write_varint(&mut buf, 63); write_varint(&mut buf, 64); write_varint(&mut buf, 4095); - - // Values: 0, 63, 64, 4095 assert_eq!(buf.len(), 1 + 1 + 2 + 2); } #[test] - fn test_write_read_signed_varint() { - let mut buf = Vec::new(); - write_signed_varint(&mut buf, 0); - write_signed_varint(&mut buf, 1); - write_signed_varint(&mut buf, -1); - write_signed_varint(&mut buf, i32::MIN); - - assert!(!buf.is_empty()); + fn test_be_varint_roundtrip() { + for &val in &[0u32, 1, 63, 64, 127, 128, 4095, 4096, 1_000_000] { + let mut buf = Vec::new(); + write_varint_be(&mut buf, val); + let mut pos = 0; + assert_eq!(read_varint(&buf, &mut pos), Some(val), "val={val}"); + assert_eq!(pos, buf.len()); + } } #[test] - fn test_varint_with_start() { + fn test_be_varint_with_start() { let mut buf = Vec::new(); write_varint_with_start(&mut buf, 42); write_varint_with_start(&mut buf, 100); + write_varint_with_start(&mut buf, 71); let mut pos = 0; assert_eq!(read_varint_with_start(&buf, &mut pos), Some(42)); assert_eq!(read_varint_with_start(&buf, &mut pos), Some(100)); + assert_eq!(read_varint_with_start(&buf, &mut pos), Some(71)); assert_eq!(read_varint_with_start(&buf, &mut pos), None); } } diff --git a/crates/vm/src/builtins/code.rs b/crates/vm/src/builtins/code.rs index 0bf193914c2..43a88de273e 100644 --- a/crates/vm/src/builtins/code.rs +++ b/crates/vm/src/builtins/code.rs @@ -194,6 +194,12 @@ impl From for PyObjectRef { } } +impl From for Literal { + fn from(obj: PyObjectRef) -> Self { + Literal(obj) + } +} + fn borrow_obj_constant(obj: &PyObject) -> BorrowedConstant<'_, Literal> { match_class!(match obj { ref i @ super::int::PyInt => { diff --git a/crates/vm/src/stdlib/marshal.rs b/crates/vm/src/stdlib/marshal.rs index 412d71f49e2..dace6bbf3e3 100644 --- a/crates/vm/src/stdlib/marshal.rs +++ b/crates/vm/src/stdlib/marshal.rs @@ -5,20 +5,19 @@ pub(crate) use decl::module_def; mod decl { use crate::builtins::code::{CodeObject, Literal, PyObjBag}; use crate::class::StaticType; + use crate::common::wtf8::Wtf8; use crate::{ PyObjectRef, PyResult, TryFromObject, VirtualMachine, builtins::{ PyBool, PyByteArray, PyBytes, PyCode, PyComplex, PyDict, PyEllipsis, PyFloat, PyFrozenSet, PyInt, PyList, PyNone, PySet, PyStopIteration, PyStr, PyTuple, }, - common::wtf8::Wtf8, convert::ToPyObject, function::{ArgBytesLike, OptionalArg}, object::{AsObject, PyPayload}, protocol::PyBuffer, }; use malachite_bigint::BigInt; - use num_complex::Complex64; use num_traits::Zero; use rustpython_compiler_core::marshal; @@ -91,34 +90,290 @@ mod decl { } } - #[pyfunction] - fn dumps( + #[derive(FromArgs)] + struct DumpsArgs { value: PyObjectRef, + #[pyarg(any, optional)] _version: OptionalArg, - vm: &VirtualMachine, - ) -> PyResult { - use marshal::Dumpable; + #[pyarg(named, default = true)] + allow_code: bool, + } + + #[pyfunction] + fn dumps(args: DumpsArgs, vm: &VirtualMachine) -> PyResult { + let DumpsArgs { + value, + allow_code, + _version, + } = args; + let version = _version.unwrap_or(marshal::FORMAT_VERSION as i32); + if !allow_code { + check_no_code(&value, vm)?; + } + check_exact_type(&value, vm)?; let mut buf = Vec::new(); - value - .with_dump(|val| marshal::serialize_value(&mut buf, val)) - .unwrap_or_else(Err) - .map_err(|DumpError| { - vm.new_not_implemented_error( - "TODO: not implemented yet or marshal unsupported type", - ) - })?; + let mut refs = if version >= 3 { + Some(WriterRefTable::new()) + } else { + None + }; + write_object(&mut buf, &value, &mut refs, version, vm)?; Ok(PyBytes::from(buf)) } - #[pyfunction] - fn dump( - value: PyObjectRef, - f: PyObjectRef, - version: OptionalArg, + struct WriterRefTable { + map: std::collections::HashMap, + next_idx: u32, + } + + impl WriterRefTable { + fn new() -> Self { + Self { + map: std::collections::HashMap::new(), + next_idx: 0, + } + } + fn try_ref(&mut self, buf: &mut Vec, obj: &PyObjectRef) -> bool { + use marshal::Write; + let id = obj.get_id(); + if let Some(&idx) = self.map.get(&id) { + buf.write_u8(b'r'); + buf.write_u32(idx); + true + } else { + false + } + } + fn reserve(&mut self, obj: &PyObjectRef) -> u32 { + let idx = self.next_idx; + self.map.insert(obj.get_id(), idx); + self.next_idx += 1; + idx + } + } + + fn write_object( + buf: &mut Vec, + obj: &PyObjectRef, + refs: &mut Option, + version: i32, vm: &VirtualMachine, ) -> PyResult<()> { - let dumped = dumps(value, version, vm)?; - vm.call_method(&f, "write", (dumped,))?; + write_object_depth( + buf, + obj, + refs, + version, + vm, + marshal::MAX_MARSHAL_STACK_DEPTH, + ) + } + + fn write_object_depth( + buf: &mut Vec, + obj: &PyObjectRef, + refs: &mut Option, + version: i32, + vm: &VirtualMachine, + depth: usize, + ) -> PyResult<()> { + use marshal::Write; + if depth == 0 { + return Err(vm.new_value_error("object too deeply nested to marshal".to_string())); + } + + // Singletons: no FLAG_REF needed + let is_singleton = vm.is_none(obj) + || obj.class().is(PyBool::static_type()) + || obj.is(PyStopIteration::static_type()) + || obj.downcast_ref::().is_some(); + + // FLAG_REF: check if already written, otherwise reserve slot + if !is_singleton + && let Some(rt) = refs.as_mut() + && rt.try_ref(buf, obj) + { + return Ok(()); + } + let type_pos = buf.len(); + let use_ref = refs.is_some() && !is_singleton; + if use_ref { + refs.as_mut().unwrap().reserve(obj); + } + + if vm.is_none(obj) { + buf.write_u8(b'N'); + } else if obj.is(PyStopIteration::static_type()) { + buf.write_u8(b'S'); + } else if obj.class().is(PyBool::static_type()) { + let val = obj + .downcast_ref::() + .is_some_and(|i| !i.as_bigint().is_zero()); + buf.write_u8(if val { b'T' } else { b'F' }); + } else if obj.downcast_ref::().is_some() { + buf.write_u8(b'.'); + } else if let Some(i) = obj.downcast_ref::() { + // TYPE_INT for i32 range, TYPE_LONG for larger + if let Ok(val) = i32::try_from(i.as_bigint()) { + buf.write_u8(b'i'); + buf.write_u32(val as u32); + } else { + buf.write_u8(b'l'); + let (sign, raw) = i.as_bigint().to_bytes_le(); + let mut digits = Vec::new(); + let mut accum: u32 = 0; + let mut bits = 0u32; + for &byte in &raw { + accum |= (byte as u32) << bits; + bits += 8; + while bits >= 15 { + digits.push((accum & 0x7fff) as u16); + accum >>= 15; + bits -= 15; + } + } + if accum > 0 || digits.is_empty() { + digits.push(accum as u16); + } + while digits.len() > 1 && *digits.last().unwrap() == 0 { + digits.pop(); + } + let n = digits.len() as i32; + let n = if sign == malachite_bigint::Sign::Minus { + -n + } else { + n + }; + buf.write_u32(n as u32); + for d in &digits { + buf.write_u16(*d); + } + } + } else if let Some(f) = obj.downcast_ref::() { + buf.write_u8(b'g'); + buf.write_u64(f.to_f64().to_bits()); + } else if let Some(c) = obj.downcast_ref::() { + buf.write_u8(b'y'); + let cv = c.to_complex64(); + buf.write_u64(cv.re.to_bits()); + buf.write_u64(cv.im.to_bits()); + } else if let Some(s) = obj.downcast_ref::() { + let bytes = s.as_wtf8().as_bytes(); + let interned = version >= 3; + if bytes.len() < 256 && bytes.is_ascii() { + buf.write_u8(if interned { b'Z' } else { b'z' }); + buf.write_u8(bytes.len() as u8); + } else { + buf.write_u8(if interned { b't' } else { b'u' }); + buf.write_u32(bytes.len() as u32); + } + buf.write_slice(bytes); + } else if let Some(b) = obj.downcast_ref::() { + buf.write_u8(b's'); + let data = b.as_bytes(); + buf.write_u32(data.len() as u32); + buf.write_slice(data); + } else if let Some(b) = obj.downcast_ref::() { + buf.write_u8(b's'); + let data = b.borrow_buf(); + buf.write_u32(data.len() as u32); + buf.write_slice(&data); + } else if let Some(t) = obj.downcast_ref::() { + buf.write_u8(b'('); + buf.write_u32(t.len() as u32); + for elem in t.as_slice() { + write_object_depth(buf, elem, refs, version, vm, depth - 1)?; + } + } else if let Some(l) = obj.downcast_ref::() { + buf.write_u8(b'['); + let items = l.borrow_vec(); + buf.write_u32(items.len() as u32); + for elem in items.iter() { + write_object_depth(buf, elem, refs, version, vm, depth - 1)?; + } + } else if let Some(d) = obj.downcast_ref::() { + buf.write_u8(b'{'); + for (k, v) in d.into_iter() { + write_object_depth(buf, &k, refs, version, vm, depth - 1)?; + write_object_depth(buf, &v, refs, version, vm, depth - 1)?; + } + buf.write_u8(b'0'); // TYPE_NULL terminator + } else if let Some(s) = obj.downcast_ref::() { + buf.write_u8(b'<'); + let elems = s.elements(); + buf.write_u32(elems.len() as u32); + for elem in &elems { + write_object_depth(buf, elem, refs, version, vm, depth - 1)?; + } + } else if let Some(s) = obj.downcast_ref::() { + buf.write_u8(b'>'); + let elems = s.elements(); + buf.write_u32(elems.len() as u32); + for elem in &elems { + write_object_depth(buf, elem, refs, version, vm, depth - 1)?; + } + } else if let Some(co) = obj.downcast_ref::() { + buf.write_u8(b'c'); + marshal::serialize_code(buf, &co.code); + } else if let Some(sl) = obj.downcast_ref::() { + if version < 5 { + return Err(vm.new_value_error("unmarshallable object".to_string())); + } + buf.write_u8(b':'); + let none: PyObjectRef = vm.ctx.none(); + write_object_depth( + buf, + sl.start.as_ref().unwrap_or(&none), + refs, + version, + vm, + depth - 1, + )?; + write_object_depth(buf, &sl.stop, refs, version, vm, depth - 1)?; + write_object_depth( + buf, + sl.step.as_ref().unwrap_or(&none), + refs, + version, + vm, + depth - 1, + )?; + } else if let Ok(bytes_like) = ArgBytesLike::try_from_object(vm, obj.clone()) { + buf.write_u8(b's'); + let data = bytes_like.borrow_buf(); + buf.write_u32(data.len() as u32); + buf.write_slice(&data); + } else { + return Err(vm.new_value_error("unmarshallable object".to_string())); + } + + if use_ref { + buf[type_pos] |= marshal::FLAG_REF; + } + Ok(()) + } + + #[derive(FromArgs)] + struct DumpArgs { + value: PyObjectRef, + f: PyObjectRef, + #[pyarg(any, optional)] + _version: OptionalArg, + #[pyarg(named, default = true)] + allow_code: bool, + } + + #[pyfunction] + fn dump(args: DumpArgs, vm: &VirtualMachine) -> PyResult<()> { + let dumped = dumps( + DumpsArgs { + value: args.value, + _version: args._version, + allow_code: args.allow_code, + }, + vm, + )?; + vm.call_method(&args.f, "write", (dumped,))?; Ok(()) } @@ -132,121 +387,219 @@ mod decl { fn make_bool(&self, value: bool) -> Self::Value { self.0.ctx.new_bool(value).into() } - fn make_none(&self) -> Self::Value { self.0.ctx.none() } - fn make_ellipsis(&self) -> Self::Value { self.0.ctx.ellipsis.clone().into() } - fn make_float(&self, value: f64) -> Self::Value { self.0.ctx.new_float(value).into() } - - fn make_complex(&self, value: Complex64) -> Self::Value { + fn make_complex(&self, value: num_complex::Complex64) -> Self::Value { self.0.ctx.new_complex(value).into() } - fn make_str(&self, value: &Wtf8) -> Self::Value { self.0.ctx.new_str(value).into() } - fn make_bytes(&self, value: &[u8]) -> Self::Value { self.0.ctx.new_bytes(value.to_vec()).into() } - fn make_int(&self, value: BigInt) -> Self::Value { self.0.ctx.new_int(value).into() } - fn make_tuple(&self, elements: impl Iterator) -> Self::Value { - let elements = elements.collect(); - self.0.ctx.new_tuple(elements).into() + self.0.ctx.new_tuple(elements.collect()).into() } - fn make_code(&self, code: CodeObject) -> Self::Value { self.0.ctx.new_code(code).into() } - fn make_stop_iter(&self) -> Result { Ok(self.0.ctx.exceptions.stop_iteration.to_owned().into()) } - fn make_list( &self, it: impl Iterator, ) -> Result { Ok(self.0.ctx.new_list(it.collect()).into()) } - fn make_set( &self, it: impl Iterator, ) -> Result { - let vm = self.0; - let set = PySet::default().into_ref(&vm.ctx); + let set = PySet::default().into_ref(&self.0.ctx); for elem in it { - set.add(elem, vm).unwrap() + set.add(elem, self.0).unwrap() } Ok(set.into()) } - fn make_frozenset( &self, it: impl Iterator, ) -> Result { - let vm = self.0; - Ok(PyFrozenSet::from_iter(vm, it).unwrap().to_pyobject(vm)) + Ok(PyFrozenSet::from_iter(self.0, it) + .unwrap() + .to_pyobject(self.0)) } - fn make_dict( &self, it: impl Iterator, ) -> Result { - let vm = self.0; - let dict = vm.ctx.new_dict(); + let dict = self.0.ctx.new_dict(); for (k, v) in it { - dict.set_item(&*k, v, vm).unwrap() + dict.set_item(&*k, v, self.0).unwrap() } Ok(dict.into()) } - + fn make_slice( + &self, + start: Self::Value, + stop: Self::Value, + step: Self::Value, + ) -> Result { + use crate::builtins::PySlice; + let vm = self.0; + Ok(PySlice { + start: if vm.is_none(&start) { + None + } else { + Some(start) + }, + stop, + step: if vm.is_none(&step) { None } else { Some(step) }, + } + .into_ref(&vm.ctx) + .into()) + } fn constant_bag(self) -> Self::ConstantBag { PyObjBag(&self.0.ctx) } } + #[derive(FromArgs)] + struct LoadsArgs { + #[pyarg(any)] + data: PyBuffer, + #[pyarg(named, default = true)] + allow_code: bool, + } + #[pyfunction] - fn loads(pybuffer: PyBuffer, vm: &VirtualMachine) -> PyResult { + fn loads(args: LoadsArgs, vm: &VirtualMachine) -> PyResult { + let LoadsArgs { + data: pybuffer, + allow_code, + } = args; let buf = pybuffer.as_contiguous().ok_or_else(|| { vm.new_buffer_error("Buffer provided to marshal.loads() is not contiguous") })?; - marshal::deserialize_value(&mut &buf[..], PyMarshalBag(vm)).map_err(|e| match e { - marshal::MarshalError::Eof => vm.new_exception_msg( - vm.ctx.exceptions.eof_error.to_owned(), - "marshal data too short".into(), - ), - marshal::MarshalError::InvalidBytecode => { - vm.new_value_error("Couldn't deserialize python bytecode") + + let result = + marshal::deserialize_value(&mut &buf[..], PyMarshalBag(vm)).map_err(|e| match e { + marshal::MarshalError::Eof => vm.new_exception_msg( + vm.ctx.exceptions.eof_error.to_owned(), + "marshal data too short".into(), + ), + _ => vm.new_value_error("bad marshal data"), + })?; + if !allow_code { + check_no_code(&result, vm)?; + } + Ok(result) + } + + #[derive(FromArgs)] + struct LoadArgs { + f: PyObjectRef, + #[pyarg(named, default = true)] + allow_code: bool, + } + + #[pyfunction] + fn load(args: LoadArgs, vm: &VirtualMachine) -> PyResult { + // Read from file object into a buffer, one object at a time. + // We read all available data, deserialize one object, then seek + // back to just after the consumed bytes. + let tell_before = vm + .call_method(&args.f, "tell", ())? + .try_into_value::(vm)?; + let read_res = vm.call_method(&args.f, "read", ())?; + let bytes = ArgBytesLike::try_from_object(vm, read_res)?; + let buf = bytes.borrow_buf(); + + let mut rdr: &[u8] = &buf; + let len_before = rdr.len(); + let result = + marshal::deserialize_value(&mut rdr, PyMarshalBag(vm)).map_err(|e| match e { + marshal::MarshalError::Eof => vm.new_exception_msg( + vm.ctx.exceptions.eof_error.to_owned(), + "marshal data too short".into(), + ), + _ => vm.new_value_error("bad marshal data"), + })?; + let consumed = len_before - rdr.len(); + + // Seek file to just after the consumed bytes + let new_pos = tell_before + consumed as i64; + vm.call_method(&args.f, "seek", (new_pos,))?; + + if !args.allow_code { + check_no_code(&result, vm)?; + } + Ok(result) + } + + /// Reject subclasses of marshallable types (int, float, complex, tuple, etc.). + /// Recursively check that no code objects are present. + fn check_no_code(obj: &PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + if obj.downcast_ref::().is_some() { + return Err(vm.new_value_error("unmarshalling code objects is disallowed".to_string())); + } + if let Some(tup) = obj.downcast_ref::() { + for elem in tup.as_slice() { + check_no_code(elem, vm)?; } - marshal::MarshalError::InvalidUtf8 => { - vm.new_value_error("invalid utf8 in marshalled string") + } else if let Some(list) = obj.downcast_ref::() { + for elem in list.borrow_vec().iter() { + check_no_code(elem, vm)?; } - marshal::MarshalError::InvalidLocation => { - vm.new_value_error("invalid location in marshalled object") + } else if let Some(set) = obj.downcast_ref::() { + for elem in set.elements() { + check_no_code(&elem, vm)?; } - marshal::MarshalError::BadType => { - vm.new_value_error("bad marshal data (unknown type code)") + } else if let Some(fset) = obj.downcast_ref::() { + for elem in fset.elements() { + check_no_code(&elem, vm)?; } - }) + } else if let Some(dict) = obj.downcast_ref::() { + for (k, v) in dict.into_iter() { + check_no_code(&k, vm)?; + check_no_code(&v, vm)?; + } + } + Ok(()) } - #[pyfunction] - fn load(f: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let read_res = vm.call_method(&f, "read", ())?; - let bytes = ArgBytesLike::try_from_object(vm, read_res)?; - loads(PyBuffer::from(bytes), vm) + fn check_exact_type(obj: &PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + let cls = obj.class(); + // bool is a subclass of int but is marshallable + if cls.is(PyBool::static_type()) { + return Ok(()); + } + for base in [ + PyInt::static_type(), + PyFloat::static_type(), + PyComplex::static_type(), + PyTuple::static_type(), + PyList::static_type(), + PyDict::static_type(), + PySet::static_type(), + PyFrozenSet::static_type(), + ] { + if cls.fast_issubclass(base) && !cls.is(base) { + return Err(vm.new_value_error("unmarshallable object".to_string())); + } + } + Ok(()) } } diff --git a/crates/vm/src/types/slot.rs b/crates/vm/src/types/slot.rs index c560e2e1f9e..b92f2658f13 100644 --- a/crates/vm/src/types/slot.rs +++ b/crates/vm/src/types/slot.rs @@ -1865,12 +1865,14 @@ impl PyComparisonOp { } pub fn eval_ord(self, ord: Ordering) -> bool { - let bit = match ord { - Ordering::Less => Self::Lt, - Ordering::Equal => Self::Eq, - Ordering::Greater => Self::Gt, - }; - u8::from(self.0) & u8::from(bit.0) != 0 + match self { + Self::Lt => ord == Ordering::Less, + Self::Le => ord != Ordering::Greater, + Self::Eq => ord == Ordering::Equal, + Self::Ne => ord != Ordering::Equal, + Self::Gt => ord == Ordering::Greater, + Self::Ge => ord != Ordering::Less, + } } pub const fn swapped(self) -> Self {