diff --git a/serde_derive/src/bound.rs b/serde_derive/src/bound.rs index ecdca7cf..e9efc22a 100644 --- a/serde_derive/src/bound.rs +++ b/serde_derive/src/bound.rs @@ -65,6 +65,28 @@ pub fn with_where_predicates_from_fields( generics } +pub fn with_where_predicates_from_variants( + cont: &Container, + generics: &syn::Generics, + from_variant: fn(&attr::Variant) -> Option<&[syn::WherePredicate]>, +) -> syn::Generics { + let variants = match cont.data { + Data::Enum(ref variants) => variants, + Data::Struct(_, _) => { + return generics.clone(); + } + }; + + let predicates = variants + .iter() + .flat_map(|variant| from_variant(&variant.attrs)) + .flat_map(|predicates| predicates.to_vec()); + + let mut generics = generics.clone(); + generics.make_where_clause().predicates.extend(predicates); + generics +} + // Puts the given bound on any generic type parameters that are used in fields // for which filter returns true. // diff --git a/serde_derive/src/de.rs b/serde_derive/src/de.rs index a2dbebf3..ed8b9438 100644 --- a/serde_derive/src/de.rs +++ b/serde_derive/src/de.rs @@ -163,6 +163,9 @@ fn build_generics(cont: &Container, borrowed: &BorrowedLifetimes) -> syn::Generi let generics = bound::with_where_predicates_from_fields(cont, &generics, attr::Field::de_bound); + let generics = + bound::with_where_predicates_from_variants(cont, &generics, attr::Variant::de_bound); + match cont.attrs.de_bound() { Some(predicates) => bound::with_where_predicates(&generics, predicates), None => { @@ -201,6 +204,7 @@ fn needs_deserialize_bound(field: &attr::Field, variant: Option<&attr::Variant>) !field.skip_deserializing() && field.deserialize_with().is_none() && field.de_bound().is_none() && variant.map_or(true, |variant| { !variant.skip_deserializing() && variant.deserialize_with().is_none() + && variant.de_bound().is_none() }) } diff --git a/serde_derive/src/internals/attr.rs b/serde_derive/src/internals/attr.rs index e8c34d3d..b00eb99d 100644 --- a/serde_derive/src/internals/attr.rs +++ b/serde_derive/src/internals/attr.rs @@ -527,6 +527,8 @@ pub struct Variant { ser_renamed: bool, de_renamed: bool, rename_all: RenameRule, + ser_bound: Option>, + de_bound: Option>, skip_deserializing: bool, skip_serializing: bool, other: bool, @@ -542,6 +544,8 @@ impl Variant { let mut skip_deserializing = BoolAttr::none(cx, "skip_deserializing"); let mut skip_serializing = BoolAttr::none(cx, "skip_serializing"); let mut rename_all = Attr::none(cx, "rename_all"); + let mut ser_bound = Attr::none(cx, "bound"); + let mut de_bound = Attr::none(cx, "bound"); let mut other = BoolAttr::none(cx, "other"); let mut serialize_with = Attr::none(cx, "serialize_with"); let mut deserialize_with = Attr::none(cx, "deserialize_with"); @@ -601,6 +605,24 @@ impl Variant { other.set_true(); } + // Parse `#[serde(bound = "D: Serialize")]` + Meta(NameValue(ref m)) if m.ident == "bound" => { + if let Ok(where_predicates) = + parse_lit_into_where(cx, m.ident.as_ref(), m.ident.as_ref(), &m.lit) + { + ser_bound.set(where_predicates.clone()); + de_bound.set(where_predicates); + } + } + + // Parse `#[serde(bound(serialize = "D: Serialize", deserialize = "D: Deserialize"))]` + Meta(List(ref m)) if m.ident == "bound" => { + if let Ok((ser, de)) = get_where_predicates(cx, &m.nested) { + ser_bound.set_opt(ser); + de_bound.set_opt(de); + } + } + // Parse `#[serde(with = "...")]` Meta(NameValue(ref m)) if m.ident == "with" => { if let Ok(path) = parse_lit_into_expr_path(cx, m.ident.as_ref(), &m.lit) { @@ -669,6 +691,8 @@ impl Variant { ser_renamed: ser_renamed, de_renamed: de_renamed, rename_all: rename_all.get().unwrap_or(RenameRule::None), + ser_bound: ser_bound.get(), + de_bound: de_bound.get(), skip_deserializing: skip_deserializing.get(), skip_serializing: skip_serializing.get(), other: other.get(), @@ -695,6 +719,14 @@ impl Variant { &self.rename_all } + pub fn ser_bound(&self) -> Option<&[syn::WherePredicate]> { + self.ser_bound.as_ref().map(|vec| &vec[..]) + } + + pub fn de_bound(&self) -> Option<&[syn::WherePredicate]> { + self.de_bound.as_ref().map(|vec| &vec[..]) + } + pub fn skip_deserializing(&self) -> bool { self.skip_deserializing } diff --git a/serde_derive/src/ser.rs b/serde_derive/src/ser.rs index 465491b3..2fb35a38 100644 --- a/serde_derive/src/ser.rs +++ b/serde_derive/src/ser.rs @@ -138,6 +138,9 @@ fn build_generics(cont: &Container) -> syn::Generics { let generics = bound::with_where_predicates_from_fields(cont, &generics, attr::Field::ser_bound); + let generics = + bound::with_where_predicates_from_variants(cont, &generics, attr::Variant::ser_bound); + match cont.attrs.ser_bound() { Some(predicates) => bound::with_where_predicates(&generics, predicates), None => bound::with_bound( @@ -158,6 +161,7 @@ fn needs_serialize_bound(field: &attr::Field, variant: Option<&attr::Variant>) - !field.skip_serializing() && field.serialize_with().is_none() && field.ser_bound().is_none() && variant.map_or(true, |variant| { !variant.skip_serializing() && variant.serialize_with().is_none() + && variant.ser_bound().is_none() }) } diff --git a/test_suite/tests/test_gen.rs b/test_suite/tests/test_gen.rs index 3b9d2a5e..c69d43fb 100644 --- a/test_suite/tests/test_gen.rs +++ b/test_suite/tests/test_gen.rs @@ -216,6 +216,42 @@ fn test_gen() { } assert::>(); + #[derive(Serialize, Deserialize)] + #[serde(bound = "D: SerializeWith + DeserializeWith")] + enum VariantWithTraits1 { + #[serde( + serialize_with = "SerializeWith::serialize_with", + deserialize_with = "DeserializeWith::deserialize_with" + )] + D(D), + #[serde( + serialize_with = "SerializeWith::serialize_with", + deserialize_with = "DeserializeWith::deserialize_with", + bound = "E: SerializeWith + DeserializeWith" + )] + E(E), + } + assert::>(); + + #[derive(Serialize, Deserialize)] + #[serde(bound(serialize = "D: SerializeWith", deserialize = "D: DeserializeWith"))] + enum VariantWithTraits2 { + #[serde( + serialize_with = "SerializeWith::serialize_with", + deserialize_with = "DeserializeWith::deserialize_with" + )] + D(D), + #[serde( + serialize_with = "SerializeWith::serialize_with", bound(serialize = "E: SerializeWith") + )] + #[serde( + deserialize_with = "DeserializeWith::deserialize_with", + bound(deserialize = "E: DeserializeWith") + )] + E(E), + } + assert::>(); + #[derive(Serialize, Deserialize)] struct CowStr<'a>(Cow<'a, str>); assert::();