diff --git a/serde_codegen/src/attr.rs b/serde_codegen/src/attr.rs index 2c019382..720d16c1 100644 --- a/serde_codegen/src/attr.rs +++ b/serde_codegen/src/attr.rs @@ -377,7 +377,7 @@ fn get_renames(cx: &ExtCtxt, Ok((ser_name, de_name)) } -fn get_serde_meta_items(attr: &ast::Attribute) -> Option<&[P]> { +pub fn get_serde_meta_items(attr: &ast::Attribute) -> Option<&[P]> { match attr.node.value.node { ast::MetaItemKind::List(ref name, ref items) if name == &"serde" => { attr::mark_used(&attr); diff --git a/serde_codegen/src/ser.rs b/serde_codegen/src/ser.rs index 08288f35..955a7165 100644 --- a/serde_codegen/src/ser.rs +++ b/serde_codegen/src/ser.rs @@ -1,3 +1,5 @@ +use std::collections::HashSet; + use aster; use syntax::ast::{ @@ -10,6 +12,7 @@ use syntax::codemap::Span; use syntax::ext::base::{Annotatable, ExtCtxt}; use syntax::ext::build::AstBuilder; use syntax::ptr::P; +use syntax::visit; use attr; use error::Error; @@ -60,11 +63,7 @@ fn serialize_item( } }; - let impl_generics = builder.from_generics(generics.clone()) - .add_ty_param_bound( - builder.path().global().ids(&["serde", "ser", "Serialize"]).build() - ) - .build(); + let impl_generics = build_impl_generics(cx, builder, item, generics); let ty = builder.ty().path() .segment(item.ident).with_generics(impl_generics.clone()).build() @@ -89,6 +88,157 @@ fn serialize_item( ).unwrap()) } +// All the generics in the input, plus a bound `T: Serialize` for each generic +// field type that will be serialized by us. +fn build_impl_generics( + cx: &ExtCtxt, + builder: &aster::AstBuilder, + item: &Item, + generics: &ast::Generics, +) -> ast::Generics { + let serialize_path = builder.path() + .global() + .ids(&["serde", "ser", "Serialize"]) + .build(); + + builder.from_generics(generics.clone()) + .with_predicates( + all_variants(cx, item).iter() + .flat_map(|variant_data| all_struct_fields(variant_data)) + .filter(|field| serialized_by_us(field)) + .map(|field| &field.ty) + // TODO this filter can be removed later, see comment on function + .filter(|ty| contains_generic(ty, generics)) + .map(|ty| strip_reference(ty)) + .map(|ty| builder.where_predicate() + // the type that is being bounded i.e. T + .bound().build(ty.clone()) + // the bound i.e. Serialize + .bound().trait_(serialize_path.clone()).build() + .build())) + .build() +} + +fn all_variants<'a>(cx: &ExtCtxt, item: &'a Item) -> Vec<&'a ast::VariantData> { + match item.node { + ast::ItemKind::Struct(ref variant_data, _) => { + vec![variant_data] + } + ast::ItemKind::Enum(ref enum_def, _) => { + enum_def.variants.iter() + .map(|variant| &variant.node.data) + .collect() + } + _ => { + cx.span_bug(item.span, + "expected Item to be Struct or Enum in #[derive(Serialize)]"); + } + } +} + +fn all_struct_fields(variant_data: &ast::VariantData) -> &[ast::StructField] { + match *variant_data { + ast::VariantData::Struct(ref fields, _) | + ast::VariantData::Tuple(ref fields, _) => { + fields + } + ast::VariantData::Unit(_) => { + &[] + } + } +} + +// Fields with a `skip_serializing` or `serialize_with` attribute are not +// serialized by us. All other fields may need a `T: Serialize` bound where T is +// the type of the field. +fn serialized_by_us(field: &ast::StructField) -> bool { + for meta_items in field.attrs.iter().filter_map(attr::get_serde_meta_items) { + for meta_item in meta_items { + match meta_item.node { + ast::MetaItemKind::Word(ref name) if name == &"skip_serializing" => { + return false + } + ast::MetaItemKind::NameValue(ref name, _) if name == &"serialize_with" => { + return false + } + _ => {} + } + } + } + true +} + +// Rust <1.7 enforces that `where` clauses involve generic type parameters. The +// corresponding compiler error is E0193. It is no longer enforced in Rust >=1.7 +// so this filtering can be removed in the future when we stop supporting <1.7. +// +// E0193 means we must not generate a `where` clause like `i32: Serialize` +// because even though i32 implements Serialize, i32 is not a generic type +// parameter. Clauses like `T: Serialize` and `Option: Serialize` are okay. +// This function decides whether a given type references any of the generic type +// parameters in the input `Generics`. +fn contains_generic(ty: &ast::Ty, generics: &ast::Generics) -> bool { + struct FindGeneric<'a> { + generic_names: &'a HashSet, + found_generic: bool, + } + impl<'a, 'v> visit::Visitor<'v> for FindGeneric<'a> { + fn visit_path(&mut self, path: &'v ast::Path, _id: ast::NodeId) { + if !path.global + && path.segments.len() == 1 + && self.generic_names.contains(&path.segments[0].identifier.name) { + self.found_generic = true; + } else { + visit::walk_path(self, path); + } + } + } + + let generic_names: HashSet<_> = generics.ty_params.iter() + .map(|ty_param| ty_param.ident.name) + .collect(); + + let mut visitor = FindGeneric { + generic_names: &generic_names, + found_generic: false, + }; + visit::walk_ty(&mut visitor, ty); + visitor.found_generic +} + +// This is required to handle types that use both a reference and a value of +// the same type, as in: +// +// enum Test<'a, T> where T: 'a { +// Lifetime(&'a T), +// NoLifetime(T), +// } +// +// Preserving references, we would generate an impl like: +// +// impl<'a, T> Serialize for Test<'a, T> +// where &'a T: Serialize, +// T: Serialize { ... } +// +// And taking a reference to one of the elements would fail with: +// +// error: cannot infer an appropriate lifetime for pattern due +// to conflicting requirements [E0495] +// Test::NoLifetime(ref v) => { ... } +// ^~~~~ +// +// Instead, we strip references before adding `T: Serialize` bounds in order to +// generate: +// +// impl<'a, T> Serialize for Test<'a, T> +// where T: Serialize { ... } +fn strip_reference(ty: &P) -> &P { + match ty.node { + ast::TyKind::Rptr(_, ref mut_ty) => &mut_ty.ty, + _ => ty + } +} + fn serialize_body( cx: &ExtCtxt, builder: &aster::AstBuilder, diff --git a/serde_tests/tests/test_annotations.rs b/serde_tests/tests/test_annotations.rs index 1608e85e..a1f27bf8 100644 --- a/serde_tests/tests/test_annotations.rs +++ b/serde_tests/tests/test_annotations.rs @@ -10,23 +10,33 @@ use token::{ assert_de_tokens_error }; -trait Trait: Sized { +trait MyDefault: Sized { fn my_default() -> Self; +} +trait ShouldSkip: Sized { fn should_skip(&self) -> bool; +} +trait SerializeWith: Sized { fn serialize_with(&self, ser: &mut S) -> Result<(), S::Error> where S: Serializer; +} +trait DeserializeWith: Sized { fn deserialize_with(de: &mut D) -> Result where D: Deserializer; } -impl Trait for i32 { +impl MyDefault for i32 { fn my_default() -> Self { 123 } +} +impl ShouldSkip for i32 { fn should_skip(&self) -> bool { *self == 123 } +} +impl SerializeWith for i32 { fn serialize_with(&self, ser: &mut S) -> Result<(), S::Error> where S: Serializer { @@ -36,7 +46,9 @@ impl Trait for i32 { false.serialize(ser) } } +} +impl DeserializeWith for i32 { fn deserialize_with(de: &mut D) -> Result where D: Deserializer { @@ -49,11 +61,11 @@ impl Trait for i32 { } #[derive(Debug, PartialEq, Serialize, Deserialize)] -struct DefaultStruct where C: Trait { +struct DefaultStruct where C: MyDefault { a1: A, #[serde(default)] a2: B, - #[serde(default="Trait::my_default")] + #[serde(default="MyDefault::my_default")] a3: C, } @@ -95,12 +107,12 @@ fn test_default_struct() { } #[derive(Debug, PartialEq, Serialize, Deserialize)] -enum DefaultEnum where C: Trait { +enum DefaultEnum where C: MyDefault { Struct { a1: A, #[serde(default)] a2: B, - #[serde(default="Trait::my_default")] + #[serde(default="MyDefault::my_default")] a3: C, } } @@ -389,11 +401,11 @@ fn test_rename_enum() { } #[derive(Debug, PartialEq, Serialize)] -struct SkipSerializingStruct<'a, B, C> where C: Trait { +struct SkipSerializingStruct<'a, B, C> where C: ShouldSkip { a: &'a i8, #[serde(skip_serializing)] b: B, - #[serde(skip_serializing_if="Trait::should_skip")] + #[serde(skip_serializing_if="ShouldSkip::should_skip")] c: C, } @@ -440,12 +452,12 @@ fn test_skip_serializing_struct() { } #[derive(Debug, PartialEq, Serialize)] -enum SkipSerializingEnum<'a, B, C> where C: Trait { +enum SkipSerializingEnum<'a, B, C> where C: ShouldSkip { Struct { a: &'a i8, #[serde(skip_serializing)] _b: B, - #[serde(skip_serializing_if="Trait::should_skip")] + #[serde(skip_serializing_if="ShouldSkip::should_skip")] c: C, } } @@ -492,10 +504,62 @@ fn test_skip_serializing_enum() { ); } +#[derive(Debug, PartialEq)] +struct NotSerializeStruct(i8); + +#[derive(Debug, PartialEq)] +enum NotSerializeEnum { Trouble } + +impl SerializeWith for NotSerializeEnum { + fn serialize_with(&self, ser: &mut S) -> Result<(), S::Error> + where S: Serializer + { + "trouble".serialize(ser) + } +} + #[derive(Debug, PartialEq, Serialize)] -struct SerializeWithStruct<'a, B> where B: Trait { +struct ContainsNotSerialize<'a, B, C, D> where B: 'a, D: SerializeWith { + a: &'a Option, + #[serde(skip_serializing)] + b: &'a B, + #[serde(skip_serializing)] + c: Option, + #[serde(serialize_with="SerializeWith::serialize_with")] + d: D, +} + +#[test] +fn test_elt_not_serialize() { + let a = 1; + assert_ser_tokens( + &ContainsNotSerialize { + a: &Some(a), + b: &NotSerializeStruct(2), + c: Some(NotSerializeEnum::Trouble), + d: NotSerializeEnum::Trouble, + }, + &[ + Token::StructStart("ContainsNotSerialize", Some(2)), + + Token::StructSep, + Token::Str("a"), + Token::Option(true), + Token::I8(1), + + Token::StructSep, + Token::Str("d"), + Token::Str("trouble"), + + Token::StructEnd, + ] + ); +} + +#[derive(Debug, PartialEq, Serialize)] +struct SerializeWithStruct<'a, B> where B: SerializeWith { a: &'a i8, - #[serde(serialize_with="Trait::serialize_with")] + #[serde(serialize_with="SerializeWith::serialize_with")] b: B, } @@ -544,10 +608,10 @@ fn test_serialize_with_struct() { } #[derive(Debug, PartialEq, Serialize)] -enum SerializeWithEnum<'a, B> where B: Trait { +enum SerializeWithEnum<'a, B> where B: SerializeWith { Struct { a: &'a i8, - #[serde(serialize_with="Trait::serialize_with")] + #[serde(serialize_with="SerializeWith::serialize_with")] b: B, } } @@ -597,9 +661,9 @@ fn test_serialize_with_enum() { } #[derive(Debug, PartialEq, Deserialize)] -struct DeserializeWithStruct where B: Trait { +struct DeserializeWithStruct where B: DeserializeWith { a: i8, - #[serde(deserialize_with="Trait::deserialize_with")] + #[serde(deserialize_with="DeserializeWith::deserialize_with")] b: B, } @@ -647,10 +711,10 @@ fn test_deserialize_with_struct() { } #[derive(Debug, PartialEq, Deserialize)] -enum DeserializeWithEnum where B: Trait { +enum DeserializeWithEnum where B: DeserializeWith { Struct { a: i8, - #[serde(deserialize_with="Trait::deserialize_with")] + #[serde(deserialize_with="DeserializeWith::deserialize_with")] b: B, } }