diff --git a/serde_derive/src/bound.rs b/serde_derive/src/bound.rs index 424c604a..714a9963 100644 --- a/serde_derive/src/bound.rs +++ b/serde_derive/src/bound.rs @@ -8,8 +8,9 @@ use std::collections::HashSet; -use syn::{self, visit}; +use syn; use syn::punctuated::{Punctuated, Pair}; +use syn::visit::{self, Visit}; use internals::ast::{Data, Container}; use internals::attr; @@ -50,46 +51,23 @@ pub fn with_where_predicates( generics } -pub fn with_where_predicates_from_fields( +pub fn with_where_predicates_from_fields( cont: &Container, generics: &syn::Generics, - trait_bound: &syn::Path, from_field: F, - gen_bound_where: W, ) -> syn::Generics where F: Fn(&attr::Field) -> Option<&[syn::WherePredicate]>, - W: Fn(&attr::Field) -> bool, { - let type_params = generics.type_params() - .map(|param| param.ident) - .collect::>(); - - let predicates_from_associated_types = cont.data - .all_fields() - .filter(|field| gen_bound_where(&field.attrs)) - .filter_map(|field| { - if let syn::Type::Path(ref ty) = *field.ty { - if let Some(Pair::Punctuated(ref t, _)) = ty.path.segments.first() { - if type_params.contains(&t.ident) { - return Some(parse_quote!(#ty: #trait_bound)); - } - } - } - None:: - }); - - let predicates_from_field_attrs = cont.data + let predicates = cont.data .all_fields() .flat_map(|field| from_field(&field.attrs)) .flat_map(|predicates| predicates.to_vec()); let mut generics = generics.clone(); - { - let predicates = &mut generics.make_where_clause().predicates; - predicates.extend(predicates_from_associated_types); - predicates.extend(predicates_from_field_attrs); - } + generics.make_where_clause() + .predicates + .extend(predicates); generics } @@ -113,17 +91,33 @@ pub fn with_bound( where F: Fn(&attr::Field, Option<&attr::Variant>) -> bool, { - struct FindTyParams { + struct FindTyParams<'ast> { // Set of all generic type parameters on the current struct (A, B, C in // the example). Initialized up front. all_type_params: HashSet, + // Set of generic type parameters used in fields for which filter // returns true (A and B in the example). Filled in as the visitor sees // them. relevant_type_params: HashSet, + + // Fields whose type is an associated type of one of the generic type + // parameters. + associated_type_usage: Vec<&'ast syn::TypePath>, } - impl<'ast> visit::Visit<'ast> for FindTyParams { - fn visit_path(&mut self, path: &syn::Path) { + impl<'ast> Visit<'ast> for FindTyParams<'ast> { + fn visit_field(&mut self, field: &'ast syn::Field) { + if let syn::Type::Path(ref ty) = field.ty { + if let Some(Pair::Punctuated(ref t, _)) = ty.path.segments.first() { + if self.all_type_params.contains(&t.ident) { + self.associated_type_usage.push(ty); + } + } + } + self.visit_type(&field.ty); + } + + fn visit_path(&mut self, path: &'ast syn::Path) { if let Some(seg) = path.segments.last() { if seg.into_value().ident == "PhantomData" { // Hardcoded exception, because PhantomData implements @@ -146,7 +140,7 @@ where // mac: T!(), // marker: PhantomData, // } - fn visit_macro(&mut self, _mac: &syn::Macro) {} + fn visit_macro(&mut self, _mac: &'ast syn::Macro) {} } let all_type_params = generics.type_params() @@ -156,6 +150,7 @@ where let mut visitor = FindTyParams { all_type_params: all_type_params, relevant_type_params: HashSet::new(), + associated_type_usage: Vec::new(), }; match cont.data { Data::Enum(ref variants) => for variant in variants.iter() { @@ -164,27 +159,28 @@ where .iter() .filter(|field| filter(&field.attrs, Some(&variant.attrs))); for field in relevant_fields { - visit::visit_type(&mut visitor, field.ty); + visitor.visit_field(field.original); } }, Data::Struct(_, ref fields) => { for field in fields.iter().filter(|field| filter(&field.attrs, None)) { - visit::visit_type(&mut visitor, field.ty); + visitor.visit_field(field.original); } } } + let relevant_type_params = visitor.relevant_type_params; + let associated_type_usage = visitor.associated_type_usage; let new_predicates = generics.type_params() .map(|param| param.ident) - .filter(|id| visitor.relevant_type_params.contains(id)) - .map(|id| { + .filter(|id| relevant_type_params.contains(id)) + .map(|id| syn::TypePath { qself: None, path: id.into() }) + .chain(associated_type_usage.into_iter().cloned()) + .map(|bounded_ty| { syn::WherePredicate::Type(syn::PredicateType { lifetimes: None, // the type parameter that is being bounded e.g. T - bounded_ty: syn::Type::Path(syn::TypePath { - qself: None, - path: id.into(), - }), + bounded_ty: syn::Type::Path(bounded_ty), colon_token: Default::default(), // the bound e.g. Serialize bounds: vec![ diff --git a/serde_derive/src/de.rs b/serde_derive/src/de.rs index f929397c..941c624a 100644 --- a/serde_derive/src/de.rs +++ b/serde_derive/src/de.rs @@ -124,15 +124,7 @@ impl Parameters { fn build_generics(cont: &Container, borrowed: &BorrowedLifetimes) -> syn::Generics { let generics = bound::without_defaults(cont.generics); - let delife = borrowed.de_lifetime(); - let de_bound = parse_quote!(_serde::Deserialize<#delife>); - let generics = bound::with_where_predicates_from_fields( - cont, - &generics, - &de_bound, - attr::Field::de_bound, - |field| field.deserialize_with().is_none() && !field.skip_deserializing() - ); + let generics = bound::with_where_predicates_from_fields(cont, &generics, attr::Field::de_bound); match cont.attrs.de_bound() { Some(predicates) => bound::with_where_predicates(&generics, predicates), @@ -144,11 +136,12 @@ fn build_generics(cont: &Container, borrowed: &BorrowedLifetimes) -> syn::Generi attr::Default::None | attr::Default::Path(_) => generics, }; + let delife = borrowed.de_lifetime(); let generics = bound::with_bound( cont, &generics, needs_deserialize_bound, - &de_bound, + &parse_quote!(_serde::Deserialize<#delife>), ); bound::with_bound( diff --git a/serde_derive/src/ser.rs b/serde_derive/src/ser.rs index 7e37ca6b..8bcacd6a 100644 --- a/serde_derive/src/ser.rs +++ b/serde_derive/src/ser.rs @@ -130,14 +130,8 @@ impl Parameters { fn build_generics(cont: &Container) -> syn::Generics { let generics = bound::without_defaults(cont.generics); - let trait_bound = parse_quote!(_serde::Serialize); - let generics = bound::with_where_predicates_from_fields( - cont, - &generics, - &trait_bound, - attr::Field::ser_bound, - |field| field.serialize_with().is_none() && !field.skip_serializing() - ); + let generics = + bound::with_where_predicates_from_fields(cont, &generics, attr::Field::ser_bound); match cont.attrs.ser_bound() { Some(predicates) => bound::with_where_predicates(&generics, predicates), @@ -145,7 +139,7 @@ fn build_generics(cont: &Container) -> syn::Generics { cont, &generics, needs_serialize_bound, - &trait_bound + &parse_quote!(_serde::Serialize), ), } }