Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Lib/test/test_json/test_recursion.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def default(self, o):
self.fail("didn't raise ValueError on default recursion")


@unittest.skip("TODO: RUSTPYTHON; crashes")
@support.skip_if_unlimited_stack_size
@support.skip_emscripten_stack_overflow()
@support.skip_wasi_stack_overflow()
Expand Down
195 changes: 102 additions & 93 deletions crates/stdlib/src/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -513,107 +513,116 @@ mod _json {
memo: &mut HashMap<String, PyStrRef>,
vm: &VirtualMachine,
) -> PyResult<(PyObjectRef, usize, usize)> {
let bytes = pystr.as_bytes();
let wtf8 = pystr.as_wtf8();
let s = pystr.as_str();

let first_byte = match bytes.get(byte_idx) {
Some(&b) => b,
None => return Err(self.make_decode_error("Expecting value", pystr, char_idx, vm)),
};
// Recursion guard: parse_object/parse_array recurse into call_scan_once
// for each child value. Without this, a deeply-nested input like
// `'[' * 50000 + ']' * 50000` overflows the native Rust stack and
// crashes the process with SIGSEGV. Matches CPython's
// _Py_EnterRecursiveCall in Modules/_json.c.
vm.with_recursion("while decoding a JSON object from a string", || {
let bytes = pystr.as_bytes();
let wtf8 = pystr.as_wtf8();
let s = pystr.as_str();

let first_byte = match bytes.get(byte_idx) {
Some(&b) => b,
None => {
return Err(self.make_decode_error("Expecting value", pystr, char_idx, vm));
}
};

match first_byte {
b'"' => {
// String - pass slice starting after the quote
let (wtf8_result, chars_consumed, bytes_consumed) =
machinery::scanstring(&wtf8[byte_idx + 1..], char_idx + 1, self.strict)
.map_err(|e| py_decode_error(e, pystr.clone().into_wtf8(), vm))?;
let py_str = vm.ctx.new_str(wtf8_result.to_string());
Ok((
py_str.into(),
char_idx + 1 + chars_consumed,
byte_idx + 1 + bytes_consumed,
))
}
b'{' => {
// Object
self.parse_object(pystr, char_idx + 1, byte_idx + 1, scan_once, memo, vm)
}
b'[' => {
// Array
self.parse_array(pystr, char_idx + 1, byte_idx + 1, scan_once, memo, vm)
}
b'n' if starts_with_bytes(&bytes[byte_idx..], b"null") => {
// null
Ok((vm.ctx.none(), char_idx + 4, byte_idx + 4))
}
b't' if starts_with_bytes(&bytes[byte_idx..], b"true") => {
// true
Ok((vm.ctx.new_bool(true).into(), char_idx + 4, byte_idx + 4))
}
b'f' if starts_with_bytes(&bytes[byte_idx..], b"false") => {
// false
Ok((vm.ctx.new_bool(false).into(), char_idx + 5, byte_idx + 5))
}
b'N' if starts_with_bytes(&bytes[byte_idx..], b"NaN") => {
// NaN
let result = self.parse_constant.call(("NaN",), vm)?;
Ok((result, char_idx + 3, byte_idx + 3))
}
b'I' if starts_with_bytes(&bytes[byte_idx..], b"Infinity") => {
// Infinity
let result = self.parse_constant.call(("Infinity",), vm)?;
Ok((result, char_idx + 8, byte_idx + 8))
}
b'-' => {
// -Infinity or negative number
if starts_with_bytes(&bytes[byte_idx..], b"-Infinity") {
let result = self.parse_constant.call(("-Infinity",), vm)?;
return Ok((result, char_idx + 9, byte_idx + 9));
match first_byte {
b'"' => {
// String - pass slice starting after the quote
let (wtf8_result, chars_consumed, bytes_consumed) =
machinery::scanstring(&wtf8[byte_idx + 1..], char_idx + 1, self.strict)
.map_err(|e| py_decode_error(e, pystr.clone().into_wtf8(), vm))?;
let py_str = vm.ctx.new_str(wtf8_result.to_string());
Ok((
py_str.into(),
char_idx + 1 + chars_consumed,
byte_idx + 1 + bytes_consumed,
))
}
// Negative number - numbers are ASCII so len == bytes
if let Some((result, len)) = self.parse_number(&s[byte_idx..], vm) {
return Ok((result?, char_idx + len, byte_idx + len));
b'{' => {
// Object
self.parse_object(pystr, char_idx + 1, byte_idx + 1, scan_once, memo, vm)
}
Err(self.make_decode_error("Expecting value", pystr, char_idx, vm))
}
b'0'..=b'9' => {
// Positive number - numbers are ASCII so len == bytes
if let Some((result, len)) = self.parse_number(&s[byte_idx..], vm) {
return Ok((result?, char_idx + len, byte_idx + len));
b'[' => {
// Array
self.parse_array(pystr, char_idx + 1, byte_idx + 1, scan_once, memo, vm)
}
Err(self.make_decode_error("Expecting value", pystr, char_idx, vm))
}
_ => {
// Fall back to scan_once for unrecognized input
// Note: This path requires char_idx for Python compatibility
let result = scan_once.call((pystr.clone(), char_idx as isize), vm);

match result {
Ok(tuple) => {
use crate::vm::builtins::PyTupleRef;
let tuple: PyTupleRef = tuple.try_into_value(vm)?;
if tuple.len() != 2 {
return Err(vm.new_value_error("scan_once must return 2-tuple"));
}
let value = tuple.as_slice()[0].clone();
let end_char_idx: isize = tuple.as_slice()[1].try_to_value(vm)?;
// For fallback, we need to calculate byte_idx from char_idx
// This is expensive but fallback should be rare
let end_byte_idx = s
.char_indices()
.nth(end_char_idx as usize)
.map(|(i, _)| i)
.unwrap_or(s.len());
Ok((value, end_char_idx as usize, end_byte_idx))
b'n' if starts_with_bytes(&bytes[byte_idx..], b"null") => {
// null
Ok((vm.ctx.none(), char_idx + 4, byte_idx + 4))
}
b't' if starts_with_bytes(&bytes[byte_idx..], b"true") => {
// true
Ok((vm.ctx.new_bool(true).into(), char_idx + 4, byte_idx + 4))
}
b'f' if starts_with_bytes(&bytes[byte_idx..], b"false") => {
// false
Ok((vm.ctx.new_bool(false).into(), char_idx + 5, byte_idx + 5))
}
b'N' if starts_with_bytes(&bytes[byte_idx..], b"NaN") => {
// NaN
let result = self.parse_constant.call(("NaN",), vm)?;
Ok((result, char_idx + 3, byte_idx + 3))
}
b'I' if starts_with_bytes(&bytes[byte_idx..], b"Infinity") => {
// Infinity
let result = self.parse_constant.call(("Infinity",), vm)?;
Ok((result, char_idx + 8, byte_idx + 8))
}
b'-' => {
// -Infinity or negative number
if starts_with_bytes(&bytes[byte_idx..], b"-Infinity") {
let result = self.parse_constant.call(("-Infinity",), vm)?;
return Ok((result, char_idx + 9, byte_idx + 9));
}
// Negative number - numbers are ASCII so len == bytes
if let Some((result, len)) = self.parse_number(&s[byte_idx..], vm) {
return Ok((result?, char_idx + len, byte_idx + len));
}
Err(err) if err.fast_isinstance(vm.ctx.exceptions.stop_iteration) => {
Err(self.make_decode_error("Expecting value", pystr, char_idx, vm))
Err(self.make_decode_error("Expecting value", pystr, char_idx, vm))
}
b'0'..=b'9' => {
// Positive number - numbers are ASCII so len == bytes
if let Some((result, len)) = self.parse_number(&s[byte_idx..], vm) {
return Ok((result?, char_idx + len, byte_idx + len));
}
Err(self.make_decode_error("Expecting value", pystr, char_idx, vm))
}
_ => {
// Fall back to scan_once for unrecognized input
// Note: This path requires char_idx for Python compatibility
let result = scan_once.call((pystr.clone(), char_idx as isize), vm);

match result {
Ok(tuple) => {
use crate::vm::builtins::PyTupleRef;
let tuple: PyTupleRef = tuple.try_into_value(vm)?;
if tuple.len() != 2 {
return Err(vm.new_value_error("scan_once must return 2-tuple"));
}
let value = tuple.as_slice()[0].clone();
let end_char_idx: isize = tuple.as_slice()[1].try_to_value(vm)?;
// For fallback, we need to calculate byte_idx from char_idx
// This is expensive but fallback should be rare
let end_byte_idx = s
.char_indices()
.nth(end_char_idx as usize)
.map(|(i, _)| i)
.unwrap_or(s.len());
Ok((value, end_char_idx as usize, end_byte_idx))
}
Err(err) if err.fast_isinstance(vm.ctx.exceptions.stop_iteration) => {
Err(self.make_decode_error("Expecting value", pystr, char_idx, vm))
}
Err(err) => Err(err),
}
Err(err) => Err(err),
}
}
}
})
}

/// Create a decode error.
Expand Down
22 changes: 22 additions & 0 deletions extra_tests/snippets/stdlib_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,25 @@ class Dict(dict):
assert json.dumps(i) == str(i)

assert json.decoder.scanstring('✨x"', 1) == ("x", 3)


# Recursion guard regression check: json.loads on deeply-nested input has to
# raise RecursionError rather than overflow the native stack (SIGSEGV).
# Mirrors CPython's _Py_EnterRecursiveCall in Modules/_json.c.

_deep = 100_000  # well above the ~45k native-stack crash threshold

for _nested_doc in (
    # Array nesting
    "[" * _deep + "]" * _deep,
    # Object nesting
    '{"a":' * _deep + "1" + "}" * _deep,
    # Alternating array/object nesting
    ('[{"x":' * _deep) + "1" + ("}]" * _deep),
):
    # The lambda is invoked by assert_raises within this same iteration, so
    # it always sees the payload of the current case.
    assert_raises(RecursionError, lambda: json.loads(_nested_doc))
Loading