diff --git a/serde_derive/src/de.rs b/serde_derive/src/de.rs index f1a37b85..e3197fcf 100644 --- a/serde_derive/src/de.rs +++ b/serde_derive/src/de.rs @@ -838,6 +838,10 @@ fn deserialize_struct( quote! { _serde::Deserializer::deserialize_any(#deserializer, #visitor_expr) } + } else if is_enum && cattrs.has_flatten() { + quote! { + _serde::de::VariantAccess::newtype_variant_seed(__variant, #visitor_expr) + } } else if is_enum { quote! { _serde::de::VariantAccess::struct_variant(__variant, FIELDS, #visitor_expr) @@ -875,6 +879,23 @@ fn deserialize_struct( _ => None, }; + let visitor_seed = if is_enum && cattrs.has_flatten() { + Some(quote! { + impl #de_impl_generics _serde::de::DeserializeSeed<#delife> for __Visitor #de_ty_generics #where_clause { + type Value = #this #ty_generics; + + fn deserialize<__D>(self, __deserializer: __D) -> _serde::export::Result + where + __D: _serde::Deserializer<'de>, + { + _serde::Deserializer::deserialize_map(__deserializer, self) + } + } + }) + } else { + None + }; + quote_block! { #field_visitor @@ -901,6 +922,8 @@ fn deserialize_struct( } } + #visitor_seed + #fields_stmt #dispatch @@ -1738,7 +1761,7 @@ fn deserialize_generated_identifier( let this = quote!(__Field); let field_idents: &Vec<_> = &fields.iter().map(|&(_, ref ident)| ident).collect(); - let (ignore_variant, fallthrough) = if cattrs.has_flatten() { + let (ignore_variant, fallthrough) = if !is_variant && cattrs.has_flatten() { let ignore_variant = quote!(__other(_serde::private::de::Content<'de>),); let fallthrough = quote!(_serde::export::Ok(__Field::__other(__value))); (Some(ignore_variant), Some(fallthrough)) @@ -1755,10 +1778,10 @@ fn deserialize_generated_identifier( fields, is_variant, fallthrough, - cattrs.has_flatten(), + !is_variant && cattrs.has_flatten(), )); - let lifetime = if cattrs.has_flatten() { + let lifetime = if !is_variant && cattrs.has_flatten() { Some(quote!(<'de>)) } else { None diff --git a/serde_derive/src/internals/check.rs b/serde_derive/src/internals/check.rs index 639f95e9..00188ac6 100644 --- a/serde_derive/src/internals/check.rs +++ b/serde_derive/src/internals/check.rs @@ -6,7 +6,7 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use internals::ast::{Container, Data, Style}; +use internals::ast::{Container, Data, Field, Style}; use internals::attr::{EnumTag, Identifier}; use internals::Ctxt; @@ -44,43 +44,49 @@ fn check_getter(cx: &Ctxt, cont: &Container) { /// Flattening has some restrictions we can test. fn check_flatten(cx: &Ctxt, cont: &Container) { match cont.data { - Data::Enum(_) => { - if cont.attrs.has_flatten() { - cx.error("#[serde(flatten)] cannot be used within enums"); - } - } - Data::Struct(style, _) => { - for field in cont.data.all_fields() { - if !field.attrs.flatten() { - continue; - } - match style { - Style::Tuple => { - cx.error("#[serde(flatten)] cannot be used on tuple structs"); - } - Style::Newtype => { - cx.error("#[serde(flatten)] cannot be used on newtype structs"); - } - _ => {} - } - if field.attrs.skip_serializing() { - cx.error( - "#[serde(flatten] can not be combined with \ - #[serde(skip_serializing)]", - ); - } else if field.attrs.skip_serializing_if().is_some() { - cx.error( - "#[serde(flatten] can not be combined with \ - #[serde(skip_serializing_if = \"...\")]", - ); - } else if field.attrs.skip_deserializing() { - cx.error( - "#[serde(flatten] can not be combined with \ - #[serde(skip_deserializing)]", - ); + Data::Enum(ref variants) => { + for variant in variants { + for field in &variant.fields { + check_flatten_field(cx, variant.style, field); } } } + Data::Struct(style, ref fields) => { + for field in fields { + check_flatten_field(cx, style, field); + } + } + } +} + +fn check_flatten_field(cx: &Ctxt, style: Style, field: &Field) { + if !field.attrs.flatten() { + return; + } + match style { + Style::Tuple => { + cx.error("#[serde(flatten)] cannot be used on tuple structs"); + } + Style::Newtype => { + cx.error("#[serde(flatten)] cannot be used on newtype structs"); + } + _ => {} + } + if field.attrs.skip_serializing() { + cx.error( + "#[serde(flatten] can not be combined with \ + #[serde(skip_serializing)]", + ); + } else if field.attrs.skip_serializing_if().is_some() { + cx.error( + "#[serde(flatten] can not be combined with \ + #[serde(skip_serializing_if = \"...\")]", + ); + } else if field.attrs.skip_deserializing() { + cx.error( + "#[serde(flatten] can not be combined with \ + #[serde(skip_deserializing)]", + ); } } diff --git a/serde_derive/src/ser.rs b/serde_derive/src/ser.rs index e40ed46d..987cc53a 100644 --- a/serde_derive/src/ser.rs +++ b/serde_derive/src/ser.rs @@ -789,6 +789,10 @@ fn serialize_struct_variant<'a>( fields: &[Field], name: &str, ) -> Fragment { + if fields.iter().any(|field| field.attrs.flatten()) { + return serialize_struct_variant_with_flatten(context, params, fields, name); + } + let struct_trait = match context { StructVariant::ExternallyTagged { .. } => (StructTrait::SerializeStructVariant), StructVariant::InternallyTagged { .. } | StructVariant::Untagged => { @@ -863,6 +867,92 @@ fn serialize_struct_variant<'a>( } } +fn serialize_struct_variant_with_flatten<'a>( + context: StructVariant<'a>, + params: &Parameters, + fields: &[Field], + name: &str, +) -> Fragment { + let struct_trait = StructTrait::SerializeMap; + let serialize_fields = serialize_struct_visitor(fields, params, true, &struct_trait); + + let mut serialized_fields = fields + .iter() + .filter(|&field| !field.attrs.skip_serializing()) + .peekable(); + + let let_mut = mut_if(serialized_fields.peek().is_some()); + + match context { + StructVariant::ExternallyTagged { + variant_index, + variant_name, + } => { + let this = ¶ms.this; + let fields_ty = fields.iter().map(|f| &f.ty); + let fields_ident = &fields.iter().map(|f| f.ident).collect::>(); + + let (_, ty_generics, where_clause) = params.generics.split_for_impl(); + let wrapper_generics = bound::with_lifetime_bound(¶ms.generics, "'__a"); + let (wrapper_impl_generics, wrapper_ty_generics, _) = wrapper_generics.split_for_impl(); + + quote_block! { + struct __EnumFlatten #wrapper_generics #where_clause { + data: (#(&'__a #fields_ty,)*), + phantom: _serde::export::PhantomData<#this #ty_generics>, + } + + impl #wrapper_impl_generics _serde::Serialize for __EnumFlatten #wrapper_ty_generics #where_clause { + fn serialize<__S>(&self, __serializer: __S) -> _serde::export::Result<__S::Ok, __S::Error> + where + __S: _serde::Serializer, + { + let (#(#fields_ident,)*) = self.data; + let #let_mut __serde_state = try!(_serde::Serializer::serialize_map( + __serializer, + _serde::export::None)); + #(#serialize_fields)* + _serde::ser::SerializeMap::end(__serde_state) + } + } + + _serde::Serializer::serialize_newtype_variant( + __serializer, + #name, + #variant_index, + #variant_name, + &__EnumFlatten { + data: (#(#fields_ident,)*), + phantom: _serde::export::PhantomData::<#this #ty_generics>, + }) + } + } + StructVariant::InternallyTagged { tag, variant_name } => { + quote_block! { + let #let_mut __serde_state = try!(_serde::Serializer::serialize_map( + __serializer, + _serde::export::None)); + try!(_serde::ser::SerializeMap::serialize_entry( + &mut __serde_state, + #tag, + #variant_name, + )); + #(#serialize_fields)* + _serde::ser::SerializeMap::end(__serde_state) + } + } + StructVariant::Untagged => { + quote_block! { + let #let_mut __serde_state = try!(_serde::Serializer::serialize_map( + __serializer, + _serde::export::None)); + #(#serialize_fields)* + _serde::ser::SerializeMap::end(__serde_state) + } + } + } +} + fn serialize_tuple_struct_visitor( fields: &[Field], params: &Parameters, diff --git a/test_suite/tests/compile-fail/conflict/flatten-within-enum.rs b/test_suite/tests/compile-fail/conflict/flatten-within-enum.rs deleted file mode 100644 index d2e007d2..00000000 --- a/test_suite/tests/compile-fail/conflict/flatten-within-enum.rs +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2018 Serde Developers -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#[macro_use] -extern crate serde_derive; - -#[derive(Serialize)] //~ ERROR: proc-macro derive panicked -//~^ HELP: #[serde(flatten)] cannot be used within enums -enum Foo { - A { - #[serde(flatten)] - fields: HashMap, - } -} - -fn main() {} diff --git a/test_suite/tests/test_annotations.rs b/test_suite/tests/test_annotations.rs index 815a47b6..d33ad93e 100644 --- a/test_suite/tests/test_annotations.rs +++ b/test_suite/tests/test_annotations.rs @@ -1839,3 +1839,158 @@ fn test_flatten_internally_tagged() { ], ); } + +#[test] +fn test_externally_tagged_enum_containing_flatten() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + enum Data { + A { + a: i32, + #[serde(flatten)] + flat: Flat, + } + } + + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Flat { + b: i32, + } + + let data = Data::A { + a: 0, + flat: Flat { + b: 0, + } + }; + + assert_tokens( + &data, + &[ + Token::NewtypeVariant { name: "Data", variant: "A" }, + Token::Map { len: None }, + Token::Str("a"), + Token::I32(0), + Token::Str("b"), + Token::I32(0), + Token::MapEnd, + ], + ); +} + +#[test] +fn test_internally_tagged_enum_containing_flatten() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + #[serde(tag = "t")] + enum Data { + A { + a: i32, + #[serde(flatten)] + flat: Flat, + } + } + + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Flat { + b: i32, + } + + let data = Data::A { + a: 0, + flat: Flat { + b: 0, + } + }; + + assert_tokens( + &data, + &[ + Token::Map { len: None }, + Token::Str("t"), + Token::Str("A"), + Token::Str("a"), + Token::I32(0), + Token::Str("b"), + Token::I32(0), + Token::MapEnd, + ], + ); +} + +#[test] +fn test_adjacently_tagged_enum_containing_flatten() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + #[serde(tag = "t", content = "c")] + enum Data { + A { + a: i32, + #[serde(flatten)] + flat: Flat, + } + } + + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Flat { + b: i32, + } + + let data = Data::A { + a: 0, + flat: Flat { + b: 0, + } + }; + + assert_tokens( + &data, + &[ + Token::Struct { name: "Data", len: 2 }, + Token::Str("t"), + Token::Str("A"), + Token::Str("c"), + Token::Map { len: None }, + Token::Str("a"), + Token::I32(0), + Token::Str("b"), + Token::I32(0), + Token::MapEnd, + Token::StructEnd, + ], + ); +} + +#[test] +fn test_untagged_enum_containing_flatten() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + #[serde(untagged)] + enum Data { + A { + a: i32, + #[serde(flatten)] + flat: Flat, + } + } + + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Flat { + b: i32, + } + + let data = Data::A { + a: 0, + flat: Flat { + b: 0, + } + }; + + assert_tokens( + &data, + &[ + Token::Map { len: None }, + Token::Str("a"), + Token::I32(0), + Token::Str("b"), + Token::I32(0), + Token::MapEnd, + ], + ); +}