diff --git a/serde_codegen/src/attr.rs b/serde_codegen/src/attr.rs index 1ef63530..975ca303 100644 --- a/serde_codegen/src/attr.rs +++ b/serde_codegen/src/attr.rs @@ -171,6 +171,7 @@ pub struct FieldAttrs { skip_serializing_field_if_none: bool, default_expr_if_missing: Option>, serialize_with: Option>, + deserialize_with: Option>, } impl FieldAttrs { @@ -196,6 +197,7 @@ impl FieldAttrs { skip_serializing_field_if_none: false, default_expr_if_missing: None, serialize_with: None, + deserialize_with: None, }; for meta_items in field.node.attrs.iter().filter_map(get_serde_meta_items) { @@ -271,6 +273,18 @@ impl FieldAttrs { field_attrs.serialize_with = Some(expr); } + // Parse `#[serde(deserialize_with="...")]` + ast::MetaItemKind::NameValue(ref name, ref lit) if name == &"deserialize_with" => { + let expr = wrap_deserialize_with( + cx, + &field.node.ty, + generics, + try!(parse_lit_into_expr(cx, name, lit)), + ); + + field_attrs.deserialize_with = Some(expr); + } + _ => { cx.span_err( meta_item.span, @@ -342,6 +356,10 @@ impl FieldAttrs { pub fn serialize_with(&self) -> Option<&P> { self.serialize_with.as_ref() } + + pub fn deserialize_with(&self) -> Option<&P> { + self.deserialize_with.as_ref() + } } @@ -515,3 +533,58 @@ fn wrap_serialize_with(cx: &ExtCtxt, } }) } + +/// This function wraps the expression in `#[serde(deserialize_with="...")]` in a trait to prevent +/// it from accessing the internal `Deserialize` state. +fn wrap_deserialize_with(cx: &ExtCtxt, + field_ty: &P, + generics: &ast::Generics, + expr: P) -> P { + let builder = AstBuilder::new(); + + let fn_generics = builder.from_generics(generics.clone()) + .ty_param("__D") + .bound() + .trait_( + builder.path() + .global() + .ids(&["serde", "de", "Deserializer"]) + .build() + ) + .build() + .build() + .build(); + + // Quasi-quoting doesn't do a great job of expanding generics into paths, so manually build it. + let ty_path = AstBuilder::new().path() + .segment("__SerdeDeserializeWithStruct") + .with_generics(generics.clone()) + .build() + .build(); + + let fn_where_clause = &fn_generics.where_clause; + let where_clause = &generics.where_clause; + + quote_expr!(cx, { + fn __serde_deserialize_with $fn_generics(deserializer: &mut __D) + -> Result<$field_ty, __D::Error> $fn_where_clause { + $expr + } + + struct __SerdeDeserializeWithStruct $generics $where_clause { + value: $field_ty, + } + + impl $generics ::serde::de::Deserialize for $ty_path $where_clause { + fn deserialize(deserializer: &mut D) -> Result + where D: ::serde::de::Deserializer + { + let value = try!(__serde_deserialize_with(deserializer)); + Ok(__SerdeDeserializeWithStruct { value: value }) + } + } + + let value: $ty_path = try!(visitor.visit_value()); + Ok(value.value) + }) +} diff --git a/serde_codegen/src/de.rs b/serde_codegen/src/de.rs index e95fe2dc..14174542 100644 --- a/serde_codegen/src/de.rs +++ b/serde_codegen/src/de.rs @@ -991,6 +991,12 @@ fn deserialize_map( .map(|i| builder.id(format!("__field{}", i))) .collect(); + let field_attrs: Vec<_> = try!( + fields.iter() + .map(|field| attr::FieldAttrs::from_field(cx, container_ty, generics, field)) + .collect() + ); + // Declare each field. let let_values: Vec = field_names.iter() .map(|field_name| quote_stmt!(cx, let mut $field_name = None;).unwrap()) @@ -1007,26 +1013,24 @@ fn deserialize_map( }; // Match arms to extract a value for a field. - let value_arms: Vec = field_names.iter() - .map(|field_name| { + let value_arms = field_attrs.iter().zip(field_names.iter()) + .map(|(field_attr, field_name)| { + let expr = match field_attr.deserialize_with() { + Some(expr) => expr.clone(), + None => quote_expr!(cx, visitor.visit_value()), + }; + quote_arm!(cx, __Field::$field_name => { - $field_name = Some(try!(visitor.visit_value())); + $field_name = Some(try!($expr)); } ) }) .chain(ignored_arm.into_iter()) - .collect(); + .collect::>(); - let extract_values = fields.iter() - .zip(field_names.iter()) - .map(|(field, field_name)| { - let field_attr = try!( - attr::FieldAttrs::from_field(cx, - container_ty, - generics, - field) - ); + let extract_values = field_attrs.iter().zip(field_names.iter()) + .map(|(field_attr, field_name)| { let missing_expr = field_attr.expr_is_missing(); Ok(quote_stmt!(cx, diff --git a/serde_tests/tests/test_annotations.rs b/serde_tests/tests/test_annotations.rs index bd640a62..5521c1d0 100644 --- a/serde_tests/tests/test_annotations.rs +++ b/serde_tests/tests/test_annotations.rs @@ -1,5 +1,5 @@ use std::default::Default; -use serde::{Serialize, Serializer}; +use serde::{Serialize, Serializer, Deserialize, Deserializer}; use token::{ Error, @@ -10,13 +10,16 @@ use token::{ assert_de_tokens_error }; -trait Trait { +trait Trait: Sized { fn my_default() -> Self; fn should_skip(&self) -> bool; fn serialize_with(&self, ser: &mut S) -> Result<(), S::Error> where S: Serializer; + + fn deserialize_with(de: &mut D) -> Result + where D: Deserializer; } impl Trait for i32 { @@ -33,6 +36,16 @@ impl Trait for i32 { false.serialize(ser) } } + + fn deserialize_with(de: &mut D) -> Result + where D: Deserializer + { + if try!(Deserialize::deserialize(de)) { + Ok(123) + } else { + Ok(2) + } + } } #[derive(Debug, PartialEq, Serialize, Deserialize)] @@ -622,3 +635,105 @@ fn test_serialize_with_enum() { ] ); } + +#[derive(Debug, PartialEq, Deserialize)] +struct DeserializeWithStruct where B: Trait { + a: i8, + #[serde(deserialize_with="Trait::deserialize_with(deserializer)")] + b: B, +} + +#[test] +fn test_deserialize_with_struct() { + assert_de_tokens( + &DeserializeWithStruct { + a: 1, + b: 2, + }, + vec![ + Token::StructStart("DeserializeWithStruct", Some(2)), + + Token::StructSep, + Token::Str("a"), + Token::I8(1), + + Token::StructSep, + Token::Str("b"), + Token::Bool(false), + + Token::StructEnd, + ] + ); + + assert_de_tokens( + &DeserializeWithStruct { + a: 1, + b: 123, + }, + vec![ + Token::StructStart("DeserializeWithStruct", Some(2)), + + Token::StructSep, + Token::Str("a"), + Token::I8(1), + + Token::StructSep, + Token::Str("b"), + Token::Bool(true), + + Token::StructEnd, + ] + ); +} + +#[derive(Debug, PartialEq, Deserialize)] +enum DeserializeWithEnum where B: Trait { + Struct { + a: i8, + #[serde(deserialize_with="Trait::deserialize_with(deserializer)")] + b: B, + } +} + +#[test] +fn test_deserialize_with_enum() { + assert_de_tokens( + &DeserializeWithEnum::Struct { + a: 1, + b: 2, + }, + vec![ + Token::EnumMapStart("DeserializeWithEnum", "Struct", Some(2)), + + Token::EnumMapSep, + Token::Str("a"), + Token::I8(1), + + Token::EnumMapSep, + Token::Str("b"), + Token::Bool(false), + + Token::EnumMapEnd, + ] + ); + + assert_de_tokens( + &DeserializeWithEnum::Struct { + a: 1, + b: 123, + }, + vec![ + Token::EnumMapStart("DeserializeWithEnum", "Struct", Some(2)), + + Token::EnumMapSep, + Token::Str("a"), + Token::I8(1), + + Token::EnumMapSep, + Token::Str("b"), + Token::Bool(true), + + Token::EnumMapEnd, + ] + ); +}