use pyo3::{types::*, Bound};
use serde::de::{self, IntoDeserializer};
use serde::Deserialize;
use crate::error::{ErrorImpl, PythonizeError, Result};
/// Attempt to convert a Python object to an instance of `T`.
///
/// A [`Depythonizer`] is created over `obj` and handed to `T`'s
/// [`Deserialize`] implementation.
///
/// # Errors
///
/// Returns whatever error `T::deserialize` produces while reading `obj`.
pub fn depythonize<'a, 'py, T>(obj: &'a Bound<'py, PyAny>) -> Result<T>
where
    T: Deserialize<'a>,
{
    T::deserialize(&mut Depythonizer::from_object(obj))
}
/// A serde deserializer that reads Rust values out of a borrowed Python
/// object.
pub struct Depythonizer<'a, 'py> {
    /// The Python object being deserialized.
    input: &'a Bound<'py, PyAny>,
}
impl<'a, 'py> Depythonizer<'a, 'py> {
/// Create a deserializer from a Python object
pub fn from_object(input: &'a Bound<'py, PyAny>) -> Self {
Depythonizer { input }
}
fn sequence_access(&self, expected_len: Option) -> Result> {
let seq = self.input.downcast::()?;
let len = self.input.len()?;
match expected_len {
Some(expected) if expected != len => {
Err(PythonizeError::incorrect_sequence_length(expected, len))
}
_ => Ok(PySequenceAccess::new(seq, len)),
}
}
fn set_access(&self) -> Result> {
match self.input.downcast::() {
Ok(set) => Ok(PySetAsSequence::from_set(set)),
Err(e) => {
if let Ok(f) = self.input.downcast::() {
Ok(PySetAsSequence::from_frozenset(f))
} else {
Err(e.into())
}
}
}
}
fn dict_access(&self) -> Result> {
PyMappingAccess::new(self.input.downcast()?)
}
fn deserialize_any_int<'de, V>(&self, int: &Bound<'_, PyInt>, visitor: V) -> Result<:value>
where
V: de::Visitor<'de>,
{
if let Ok(x) = int.extract::() {
if let Ok(x) = u8::try_from(x) {
visitor.visit_u8(x)
} else if let Ok(x) = u16::try_from(x) {
visitor.visit_u16(x)
} else if let Ok(x) = u32::try_from(x) {
visitor.visit_u32(x)
} else if let Ok(x) = u64::try_from(x) {
visitor.visit_u64(x)
} else {
visitor.visit_u128(x)
}
} else {
let x: i128 = int.extract()?;
if let Ok(x) = i8::try_from(x) {
visitor.visit_i8(x)
} else if let Ok(x) = i16::try_from(x) {
visitor.visit_i16(x)
} else if let Ok(x) = i32::try_from(x) {
visitor.visit_i32(x)
} else if let Ok(x) = i64::try_from(x) {
visitor.visit_i64(x)
} else {
visitor.visit_i128(x)
}
}
}
}
/// Generates a `Deserializer` method named `$method` that extracts the input
/// into whatever type `$visit` accepts and passes it to that visitor method.
///
/// The expansion refers to the `'de` lifetime, so it must be invoked inside
/// an `impl<'de> de::Deserializer<'de>` block.
macro_rules! deserialize_type {
    ($method:ident => $visit:ident) => {
        fn $method<V>(self, visitor: V) -> Result<V::Value>
        where
            V: de::Visitor<'de>,
        {
            visitor.$visit(self.input.extract()?)
        }
    };
}
impl<'de> de::Deserializer<'de> for &'_ mut Depythonizer<'_, '_> {
type Error = PythonizeError;
fn deserialize_any(self, visitor: V) -> Result<:value>
where
V: de::Visitor<'de>,
{
let obj = self.input;
// First check for cases which are cheap to check due to pointer
// comparison or bitflag checks
if obj.is_none() {
self.deserialize_unit(visitor)
} else if obj.is_instance_of::() {
self.deserialize_bool(visitor)
} else if let Ok(x) = obj.downcast::() {
self.deserialize_any_int(x, visitor)
} else if obj.is_instance_of::() || obj.is_instance_of::() {
self.deserialize_tuple(obj.len()?, visitor)
} else if obj.is_instance_of::() {
self.deserialize_map(visitor)
} else if obj.is_instance_of::() {
self.deserialize_str(visitor)
}
// Continue with cases which are slower to check because they go
// throuh `isinstance` machinery
else if obj.is_instance_of::() || obj.is_instance_of::() {
self.deserialize_bytes(visitor)
} else if obj.is_instance_of::() {
self.deserialize_f64(visitor)
} else if obj.is_instance_of::() || obj.is_instance_of::() {
self.deserialize_seq(visitor)
} else if obj.downcast::().is_ok() {
self.deserialize_tuple(obj.len()?, visitor)
} else if obj.downcast::().is_ok() {
self.deserialize_map(visitor)
} else {
Err(obj.get_type().qualname().map_or_else(
|_| PythonizeError::unsupported_type("unknown"),
PythonizeError::unsupported_type,
))
}
}
fn deserialize_bool(self, visitor: V) -> Result<:value>
where
V: de::Visitor<'de>,
{
visitor.visit_bool(self.input.is_truthy()?)
}
fn deserialize_char(self, visitor: V) -> Result<:value>
where
V: de::Visitor<'de>,
{
let s = self.input.downcast::()?.to_cow()?;
if s.len() != 1 {
return Err(PythonizeError::invalid_length_char());
}
visitor.visit_char(s.chars().next().unwrap())
}
deserialize_type!(deserialize_i8 => visit_i8);
deserialize_type!(deserialize_i16 => visit_i16);
deserialize_type!(deserialize_i32 => visit_i32);
deserialize_type!(deserialize_i64 => visit_i64);
deserialize_type!(deserialize_i128 => visit_i128);
deserialize_type!(deserialize_u8 => visit_u8);
deserialize_type!(deserialize_u16 => visit_u16);
deserialize_type!(deserialize_u32 => visit_u32);
deserialize_type!(deserialize_u64 => visit_u64);
deserialize_type!(deserialize_u128 => visit_u128);
deserialize_type!(deserialize_f32 => visit_f32);
deserialize_type!(deserialize_f64 => visit_f64);
fn deserialize_str(self, visitor: V) -> Result<:value>
where
V: de::Visitor<'de>,
{
let s = self.input.downcast::()?;
visitor.visit_str(&s.to_cow()?)
}
fn deserialize_string(self, visitor: V) -> Result<:value>
where
V: de::Visitor<'de>,
{
self.deserialize_str(visitor)
}
fn deserialize_bytes(self, visitor: V) -> Result<:value>
where
V: de::Visitor<'de>,
{
let b = self.input.downcast::()?;
visitor.visit_bytes(b.as_bytes())
}
fn deserialize_byte_buf(self, visitor: V) -> Result<:value>
where
V: de::Visitor<'de>,
{
self.deserialize_bytes(visitor)
}
fn deserialize_option(self, visitor: V) -> Result<:value>
where
V: de::Visitor<'de>,
{
if self.input.is_none() {
visitor.visit_none()
} else {
visitor.visit_some(self)
}
}
fn deserialize_unit(self, visitor: V) -> Result<:value>
where
V: de::Visitor<'de>,
{
if self.input.is_none() {
visitor.visit_unit()
} else {
Err(PythonizeError::msg("expected None"))
}
}
fn deserialize_unit_struct(self, _name: &'static str, visitor: V) -> Result<:value>
where
V: de::Visitor<'de>,
{
self.deserialize_unit(visitor)
}
fn deserialize_newtype_struct(self, _name: &'static str, visitor: V) -> Result<:value>
where
V: de::Visitor<'de>,
{
visitor.visit_newtype_struct(self)
}
fn deserialize_seq(self, visitor: V) -> Result<:value>
where
V: de::Visitor<'de>,
{
match self.sequence_access(None) {
Ok(seq) => visitor.visit_seq(seq),
Err(e) => {
// we allow sets to be deserialized as sequences, so try that
if matches!(*e.inner, ErrorImpl::UnexpectedType(_)) {
if let Ok(set) = self.set_access() {
return visitor.visit_seq(set);
}
}
Err(e)
}
}
}
fn deserialize_tuple(self, len: usize, visitor: V) -> Result<:value>
where
V: de::Visitor<'de>,
{
visitor.visit_seq(self.sequence_access(Some(len))?)
}
fn deserialize_tuple_struct(
self,
_name: &'static str,
len: usize,
visitor: V,
) -> Result<:value>
where
V: de::Visitor<'de>,
{
visitor.visit_seq(self.sequence_access(Some(len))?)
}
fn deserialize_map(self, visitor: V) -> Result<:value>
where
V: de::Visitor<'de>,
{
visitor.visit_map(self.dict_access()?)
}
fn deserialize_struct(
self,
_name: &'static str,
_fields: &'static [&'static str],
visitor: V,
) -> Result<:value>
where
V: de::Visitor<'de>,
{
self.deserialize_map(visitor)
}
fn deserialize_enum(
self,
_name: &'static str,
_variants: &'static [&'static str],
visitor: V,
) -> Result<:value>
where
V: de::Visitor<'de>,
{
let item = &self.input;
if let Ok(s) = item.downcast::() {
visitor.visit_enum(s.to_cow()?.into_deserializer())
} else if let Ok(m) = item.downcast::() {
// Get the enum variant from the mapping key
if m.len()? != 1 {
return Err(PythonizeError::invalid_length_enum());
}
let variant: Bound = m
.keys()?
.get_item(0)?
.downcast_into::()
.map_err(|_| PythonizeError::dict_key_not_string())?;
let value = m.get_item(&variant)?;
visitor.visit_enum(PyEnumAccess::new(&value, variant))
} else {
Err(PythonizeError::invalid_enum_type())
}
}
fn deserialize_identifier(self, visitor: V) -> Result<:value>
where
V: de::Visitor<'de>,
{
let s = self
.input
.downcast::()
.map_err(|_| PythonizeError::dict_key_not_string())?;
visitor.visit_str(&s.to_cow()?)
}
fn deserialize_ignored_any(self, visitor: V) -> Result<:value>
where
V: de::Visitor<'de>,
{
visitor.visit_unit()
}
}
/// State for reading the elements of a Python sequence one at a time.
struct PySequenceAccess<'a, 'py> {
    /// The sequence being read.
    seq: &'a Bound<'py, PySequence>,
    /// Starts at 0; presumably the position of the next element to read.
    index: usize,
    /// Length supplied when the accessor was created.
    len: usize,
}
impl<'a, 'py> PySequenceAccess<'a, 'py> {
fn new(seq: &'a Bound<'py, PySequence>, len: usize) -> Self {
Self { seq, index: 0, len }
}
}
impl<'de> de::SeqAccess<'de> for PySequenceAccess<'_, '_> {
type Error = PythonizeError;
fn next_element_seed(&mut self, seed: T) -> Result