diff --git a/serde_codegen/src/bound.rs b/serde_codegen/src/bound.rs index ea7da8e9..0f421959 100644 --- a/serde_codegen/src/bound.rs +++ b/serde_codegen/src/bound.rs @@ -1,7 +1,8 @@ +use std::collections::HashSet; + use aster::AstBuilder; use syntax::ast; -use syntax::ptr::P; use syntax::visit; use internals::ast::Item; @@ -47,6 +48,17 @@ pub fn with_where_predicates_from_fields( .build() } +// Puts the given bound on any generic type parameters that are used in fields +// for which filter returns true. +// +// For example, the following struct needs the bound `A: Serialize, B: Serialize`. +// +// struct S<'b, A, B: 'b, C> { +// a: A, +// b: Option<&'b B> +// #[serde(skip_serializing)] +// c: C, +// } pub fn with_bound( builder: &AstBuilder, item: &Item, @@ -56,95 +68,53 @@ pub fn with_bound( ) -> ast::Generics where F: Fn(&attr::Field) -> bool, { + struct FindTyParams { + // Set of all generic type parameters on the current struct (A, B, C in + // the example). Initialized up front. + all_ty_params: HashSet, + // Set of generic type parameters used in fields for which filter + // returns true (A and B in the example). Filled in as the visitor sees + // them. + relevant_ty_params: HashSet, + } + impl visit::Visitor for FindTyParams { + fn visit_path(&mut self, path: &ast::Path, _id: ast::NodeId) { + if !path.global && path.segments.len() == 1 { + let id = path.segments[0].identifier.name; + if self.all_ty_params.contains(&id) { + self.relevant_ty_params.insert(id); + } + } + visit::walk_path(self, path); + } + } + + let all_ty_params: HashSet<_> = generics.ty_params.iter() + .map(|ty_param| ty_param.ident.name) + .collect(); + + let relevant_tys = item.body.all_fields() + .filter(|&field| filter(&field.attrs)) + .map(|field| &field.ty); + + let mut visitor = FindTyParams { + all_ty_params: all_ty_params, + relevant_ty_params: HashSet::new(), + }; + for ty in relevant_tys { + visit::walk_ty(&mut visitor, ty); + } + builder.from_generics(generics.clone()) .with_predicates( - item.body.all_fields() - .filter(|&field| filter(&field.attrs)) - .map(|field| &field.ty) - .filter(|ty| !contains_recursion(ty, item.ident)) - .map(|ty| strip_reference(ty)) - .map(|ty| builder.where_predicate() - // the type that is being bounded e.g. T - .bound().build(ty.clone()) + generics.ty_params.iter() + .map(|ty_param| ty_param.ident.name) + .filter(|id| visitor.relevant_ty_params.contains(id)) + .map(|id| builder.where_predicate() + // the type parameter that is being bounded e.g. T + .bound().build(builder.ty().id(id)) // the bound e.g. Serialize .bound().trait_(bound.clone()).build() .build())) .build() } - -// We do not attempt to generate any bounds based on field types that are -// directly recursive, as in: -// -// struct Test { -// next: Box>, -// } -// -// This does not catch field types that are mutually recursive with some other -// type. For those, we require bounds to be specified by a `bound` attribute if -// the inferred ones are not correct. -// -// struct Test { -// #[serde(bound="D: Serialize + Deserialize")] -// next: Box>, -// } -// struct Other { -// #[serde(bound="D: Serialize + Deserialize")] -// next: Box>, -// } -fn contains_recursion(ty: &ast::Ty, ident: ast::Ident) -> bool { - struct FindRecursion { - ident: ast::Ident, - found_recursion: bool, - } - impl visit::Visitor for FindRecursion { - fn visit_path(&mut self, path: &ast::Path, _id: ast::NodeId) { - if !path.global - && path.segments.len() == 1 - && path.segments[0].identifier == self.ident { - self.found_recursion = true; - } else { - visit::walk_path(self, path); - } - } - } - - let mut visitor = FindRecursion { - ident: ident, - found_recursion: false, - }; - visit::walk_ty(&mut visitor, ty); - visitor.found_recursion -} - -// This is required to handle types that use both a reference and a value of -// the same type, as in: -// -// enum Test<'a, T> where T: 'a { -// Lifetime(&'a T), -// NoLifetime(T), -// } -// -// Preserving references, we would generate an impl like: -// -// impl<'a, T> Serialize for Test<'a, T> -// where &'a T: Serialize, -// T: Serialize { ... } -// -// And taking a reference to one of the elements would fail with: -// -// error: cannot infer an appropriate lifetime for pattern due -// to conflicting requirements [E0495] -// Test::NoLifetime(ref v) => { ... } -// ^~~~~ -// -// Instead, we strip references before adding `T: Serialize` bounds in order to -// generate: -// -// impl<'a, T> Serialize for Test<'a, T> -// where T: Serialize { ... } -fn strip_reference(mut ty: &P) -> &P { - while let ast::TyKind::Rptr(_, ref mut_ty) = ty.node { - ty = &mut_ty.ty; - } - ty -} diff --git a/testing/tests/test_gen.rs b/testing/tests/test_gen.rs index 80b85f56..32b7ad32 100644 --- a/testing/tests/test_gen.rs +++ b/testing/tests/test_gen.rs @@ -99,6 +99,35 @@ struct ListNode { next: Box>, } +#[derive(Serialize, Deserialize)] +struct RecursiveA { + b: Box, +} + +#[derive(Serialize, Deserialize)] +enum RecursiveB { + A(RecursiveA), +} + +#[derive(Serialize, Deserialize)] +struct RecursiveGenericA { + t: T, + b: Box>, +} + +#[derive(Serialize, Deserialize)] +enum RecursiveGenericB { + T(T), + A(RecursiveGenericA), +} + +#[derive(Serialize)] +#[allow(dead_code)] +struct OptionStatic<'a> { + a: Option<&'a str>, + b: Option<&'static str>, +} + #[derive(Serialize, Deserialize)] #[serde(bound="D: SerializeWith + DeserializeWith")] struct WithTraits1 { @@ -139,4 +168,3 @@ trait DeserializeWith: Sized { struct X; fn ser_x(_: &X, _: &mut S) -> Result<(), S::Error> { panic!() } fn de_x(_: &mut D) -> Result { panic!() } -