//! Deserialization of Python objects into Rust values via serde's `Deserializer` interface.

use pyo3::{types::*, Bound}; use serde::de::{self, IntoDeserializer}; use serde::Deserialize; use crate::error::{ErrorImpl, PythonizeError, Result}; /// Attempt to convert a Python object to an instance of `T` pub fn depythonize<'a, 'py, T>(obj: &'a Bound<'py, PyAny>) -> Result where T: Deserialize<'a>, { T::deserialize(&mut Depythonizer::from_object(obj)) } /// A structure that deserializes Python objects into Rust values pub struct Depythonizer<'a, 'py> { input: &'a Bound<'py, PyAny>, } impl<'a, 'py> Depythonizer<'a, 'py> { /// Create a deserializer from a Python object pub fn from_object(input: &'a Bound<'py, PyAny>) -> Self { Depythonizer { input } } fn sequence_access(&self, expected_len: Option) -> Result> { let seq = self.input.downcast::()?; let len = self.input.len()?; match expected_len { Some(expected) if expected != len => { Err(PythonizeError::incorrect_sequence_length(expected, len)) } _ => Ok(PySequenceAccess::new(seq, len)), } } fn set_access(&self) -> Result> { match self.input.downcast::() { Ok(set) => Ok(PySetAsSequence::from_set(set)), Err(e) => { if let Ok(f) = self.input.downcast::() { Ok(PySetAsSequence::from_frozenset(f)) } else { Err(e.into()) } } } } fn dict_access(&self) -> Result> { PyMappingAccess::new(self.input.downcast()?) 
} fn deserialize_any_int<'de, V>(&self, int: &Bound<'_, PyInt>, visitor: V) -> Result<:value> where V: de::Visitor<'de>, { if let Ok(x) = int.extract::() { if let Ok(x) = u8::try_from(x) { visitor.visit_u8(x) } else if let Ok(x) = u16::try_from(x) { visitor.visit_u16(x) } else if let Ok(x) = u32::try_from(x) { visitor.visit_u32(x) } else if let Ok(x) = u64::try_from(x) { visitor.visit_u64(x) } else { visitor.visit_u128(x) } } else { let x: i128 = int.extract()?; if let Ok(x) = i8::try_from(x) { visitor.visit_i8(x) } else if let Ok(x) = i16::try_from(x) { visitor.visit_i16(x) } else if let Ok(x) = i32::try_from(x) { visitor.visit_i32(x) } else if let Ok(x) = i64::try_from(x) { visitor.visit_i64(x) } else { visitor.visit_i128(x) } } } } macro_rules! deserialize_type { ($method:ident => $visit:ident) => { fn $method(self, visitor: V) -> Result<:value> where V: de::Visitor<'de>, { visitor.$visit(self.input.extract()?) } }; } impl<'de> de::Deserializer<'de> for &'_ mut Depythonizer<'_, '_> { type Error = PythonizeError; fn deserialize_any(self, visitor: V) -> Result<:value> where V: de::Visitor<'de>, { let obj = self.input; // First check for cases which are cheap to check due to pointer // comparison or bitflag checks if obj.is_none() { self.deserialize_unit(visitor) } else if obj.is_instance_of::() { self.deserialize_bool(visitor) } else if let Ok(x) = obj.downcast::() { self.deserialize_any_int(x, visitor) } else if obj.is_instance_of::() || obj.is_instance_of::() { self.deserialize_tuple(obj.len()?, visitor) } else if obj.is_instance_of::() { self.deserialize_map(visitor) } else if obj.is_instance_of::() { self.deserialize_str(visitor) } // Continue with cases which are slower to check because they go // throuh `isinstance` machinery else if obj.is_instance_of::() || obj.is_instance_of::() { self.deserialize_bytes(visitor) } else if obj.is_instance_of::() { self.deserialize_f64(visitor) } else if obj.is_instance_of::() || obj.is_instance_of::() { 
self.deserialize_seq(visitor) } else if obj.downcast::().is_ok() { self.deserialize_tuple(obj.len()?, visitor) } else if obj.downcast::().is_ok() { self.deserialize_map(visitor) } else { Err(obj.get_type().qualname().map_or_else( |_| PythonizeError::unsupported_type("unknown"), PythonizeError::unsupported_type, )) } } fn deserialize_bool(self, visitor: V) -> Result<:value> where V: de::Visitor<'de>, { visitor.visit_bool(self.input.is_truthy()?) } fn deserialize_char(self, visitor: V) -> Result<:value> where V: de::Visitor<'de>, { let s = self.input.downcast::()?.to_cow()?; if s.len() != 1 { return Err(PythonizeError::invalid_length_char()); } visitor.visit_char(s.chars().next().unwrap()) } deserialize_type!(deserialize_i8 => visit_i8); deserialize_type!(deserialize_i16 => visit_i16); deserialize_type!(deserialize_i32 => visit_i32); deserialize_type!(deserialize_i64 => visit_i64); deserialize_type!(deserialize_i128 => visit_i128); deserialize_type!(deserialize_u8 => visit_u8); deserialize_type!(deserialize_u16 => visit_u16); deserialize_type!(deserialize_u32 => visit_u32); deserialize_type!(deserialize_u64 => visit_u64); deserialize_type!(deserialize_u128 => visit_u128); deserialize_type!(deserialize_f32 => visit_f32); deserialize_type!(deserialize_f64 => visit_f64); fn deserialize_str(self, visitor: V) -> Result<:value> where V: de::Visitor<'de>, { let s = self.input.downcast::()?; visitor.visit_str(&s.to_cow()?) 
} fn deserialize_string(self, visitor: V) -> Result<:value> where V: de::Visitor<'de>, { self.deserialize_str(visitor) } fn deserialize_bytes(self, visitor: V) -> Result<:value> where V: de::Visitor<'de>, { let b = self.input.downcast::()?; visitor.visit_bytes(b.as_bytes()) } fn deserialize_byte_buf(self, visitor: V) -> Result<:value> where V: de::Visitor<'de>, { self.deserialize_bytes(visitor) } fn deserialize_option(self, visitor: V) -> Result<:value> where V: de::Visitor<'de>, { if self.input.is_none() { visitor.visit_none() } else { visitor.visit_some(self) } } fn deserialize_unit(self, visitor: V) -> Result<:value> where V: de::Visitor<'de>, { if self.input.is_none() { visitor.visit_unit() } else { Err(PythonizeError::msg("expected None")) } } fn deserialize_unit_struct(self, _name: &'static str, visitor: V) -> Result<:value> where V: de::Visitor<'de>, { self.deserialize_unit(visitor) } fn deserialize_newtype_struct(self, _name: &'static str, visitor: V) -> Result<:value> where V: de::Visitor<'de>, { visitor.visit_newtype_struct(self) } fn deserialize_seq(self, visitor: V) -> Result<:value> where V: de::Visitor<'de>, { match self.sequence_access(None) { Ok(seq) => visitor.visit_seq(seq), Err(e) => { // we allow sets to be deserialized as sequences, so try that if matches!(*e.inner, ErrorImpl::UnexpectedType(_)) { if let Ok(set) = self.set_access() { return visitor.visit_seq(set); } } Err(e) } } } fn deserialize_tuple(self, len: usize, visitor: V) -> Result<:value> where V: de::Visitor<'de>, { visitor.visit_seq(self.sequence_access(Some(len))?) } fn deserialize_tuple_struct( self, _name: &'static str, len: usize, visitor: V, ) -> Result<:value> where V: de::Visitor<'de>, { visitor.visit_seq(self.sequence_access(Some(len))?) } fn deserialize_map(self, visitor: V) -> Result<:value> where V: de::Visitor<'de>, { visitor.visit_map(self.dict_access()?) 
} fn deserialize_struct( self, _name: &'static str, _fields: &'static [&'static str], visitor: V, ) -> Result<:value> where V: de::Visitor<'de>, { self.deserialize_map(visitor) } fn deserialize_enum( self, _name: &'static str, _variants: &'static [&'static str], visitor: V, ) -> Result<:value> where V: de::Visitor<'de>, { let item = &self.input; if let Ok(s) = item.downcast::() { visitor.visit_enum(s.to_cow()?.into_deserializer()) } else if let Ok(m) = item.downcast::() { // Get the enum variant from the mapping key if m.len()? != 1 { return Err(PythonizeError::invalid_length_enum()); } let variant: Bound = m .keys()? .get_item(0)? .downcast_into::() .map_err(|_| PythonizeError::dict_key_not_string())?; let value = m.get_item(&variant)?; visitor.visit_enum(PyEnumAccess::new(&value, variant)) } else { Err(PythonizeError::invalid_enum_type()) } } fn deserialize_identifier(self, visitor: V) -> Result<:value> where V: de::Visitor<'de>, { let s = self .input .downcast::() .map_err(|_| PythonizeError::dict_key_not_string())?; visitor.visit_str(&s.to_cow()?) } fn deserialize_ignored_any(self, visitor: V) -> Result<:value> where V: de::Visitor<'de>, { visitor.visit_unit() } } struct PySequenceAccess<'a, 'py> { seq: &'a Bound<'py, PySequence>, index: usize, len: usize, } impl<'a, 'py> PySequenceAccess<'a, 'py> { fn new(seq: &'a Bound<'py, PySequence>, len: usize) -> Self { Self { seq, index: 0, len } } } impl<'de> de::SeqAccess<'de> for PySequenceAccess<'_, '_> { type Error = PythonizeError; fn next_element_seed(&mut self, seed: T) -> Result