From 419f34c6657ca6badc456451d20317aa7c738414 Mon Sep 17 00:00:00 2001 From: Erick Tryzelaar Date: Sun, 8 Mar 2015 18:46:41 -0700 Subject: [PATCH] Fix #[derive_deserialize] for generic structs --- serde2/serde2_macros/src/lib.rs | 59 ++++++++++++++++++++++++++++----- serde2/tests/test_macros.rs | 49 ++++++++++++++++++++++----- 2 files changed, 91 insertions(+), 17 deletions(-) diff --git a/serde2/serde2_macros/src/lib.rs b/serde2/serde2_macros/src/lib.rs index 8dfa8209..c891e75c 100644 --- a/serde2/serde2_macros/src/lib.rs +++ b/serde2/serde2_macros/src/lib.rs @@ -729,7 +729,8 @@ fn deserialize_struct( struct_path, &fields, state, - struct_def) + struct_def, + generics) } } } @@ -919,14 +920,56 @@ fn deserialize_struct_named_fields( fields: &[(Ident, Span)], state: P, struct_def: &StructDef, + generics: &ast::Generics, ) -> P { - let struct_name = builder.expr().str(struct_ident); + let visitor_impl_generics = builder.from_generics(generics.clone()) + .add_ty_param_bound( + builder.path().global().ids(&["serde2", "de", "Deserialize"]).build() + ) + .build(); // Create the field names for the fields. let field_names: Vec = (0 .. fields.len()) .map(|i| token::str_to_ident(&format!("__field{}", i))) .collect(); + // Build `__Visitor(PhantomData, PhantomData, ...)` + let (visitor_struct, visitor_expr) = if generics.ty_params.is_empty() { + ( + builder.item().tuple_struct("__Visitor") + .build(), + builder.expr().id("__Visitor"), + ) + } else { + ( + builder.item().tuple_struct("__Visitor") + .generics().with(generics.clone()).build() + .with_tys( + generics.ty_params.iter().map(|ty_param| { + builder.ty().phantom_data().id(ty_param.ident) + }) + ) + .build(), + builder.expr().call().id("__Visitor") + .with_args( + generics.ty_params.iter().map(|_| { + builder.expr().phantom_data() + }) + ) + .build(), + ) + }; + + let struct_name = builder.expr().str(struct_ident); + + let visitor_ty = builder.ty().path() + .segment("__Visitor").with_generics(generics.clone()).build() + .build(); + + let value_ty = builder.ty().path() + .segment(type_ident).with_generics(generics.clone()).build() + .build(); + let field_deserializer = declare_map_field_deserializer( cx, span, @@ -948,13 +991,13 @@ fn deserialize_struct_named_fields( quote_expr!(cx, { $field_deserializer - struct __Visitor; + $visitor_struct; - impl ::serde2::de::Visitor for __Visitor { - type Value = $type_ident; + impl $visitor_impl_generics ::serde2::de::Visitor for $visitor_ty { + type Value = $value_ty; #[inline] - fn visit_map<__V>(&mut self, mut visitor: __V) -> Result<$type_ident, __V::Error> + fn visit_map<__V>(&mut self, mut visitor: __V) -> Result<$value_ty, __V::Error> where __V: ::serde2::de::MapVisitor, { $visit_map_expr @@ -963,7 +1006,7 @@ fn deserialize_struct_named_fields( #[inline] fn visit_named_map<__V>(&mut self, name: &str, - visitor: __V) -> Result<$type_ident, __V::Error> + visitor: __V) -> Result<$value_ty, __V::Error> where __V: ::serde2::de::MapVisitor, { if name == $struct_name { @@ -974,7 +1017,7 @@ fn deserialize_struct_named_fields( } } - $state.visit(__Visitor) + $state.visit($visitor_expr) }) } diff --git a/serde2/tests/test_macros.rs b/serde2/tests/test_macros.rs index 30a73e66..6c8770c1 100644 --- a/serde2/tests/test_macros.rs +++ b/serde2/tests/test_macros.rs @@ -30,14 +30,6 @@ trait Trait { #[derive_deserialize] struct NamedUnit; -#[derive(Debug, PartialEq)] -#[derive_serialize] -struct NamedMap<'a, 'b, A: 'a, B: 'b, C> { - a: &'a A, - b: &'b mut B, - c: C, -} - #[derive(Debug, PartialEq)] #[derive_serialize] //#[derive_deserialize] @@ -126,7 +118,15 @@ fn test_de_named_tuple() { } #[test] -fn test_named_map() { +fn test_ser_named_map() { + #[derive(Debug, PartialEq)] + #[derive_serialize] + struct NamedMap<'a, 'b, A: 'a, B: 'b, C> { + a: &'a A, + b: &'b mut B, + c: C, + } + let a = 5; let mut b = 6; let c = 7; @@ -151,6 +151,37 @@ fn test_named_map() { ); } +#[test] +fn test_de_named_map() { + #[derive(Debug, PartialEq)] + #[derive_deserialize] + struct NamedMap { + a: A, + b: B, + c: C, + } + + let v = NamedMap { + a: 5, + b: 6, + c: 7, + }; + + assert_eq!( + json::from_str("{\"a\":5,\"b\":6,\"c\":7}").unwrap(), + v + ); + + assert_eq!( + json::from_value(Value::Object(btreemap![ + "a".to_string() => Value::I64(5), + "b".to_string() => Value::I64(6), + "c".to_string() => Value::I64(7) + ])).unwrap(), + v + ); +} + #[test] fn test_enum_unit() { assert_eq!(