From ca772a14f9a8f6e7fcbafe1509546e1c0ef4b4d9 Mon Sep 17 00:00:00 2001 From: Mingun Date: Sat, 27 Feb 2021 18:52:32 +0500 Subject: [PATCH 1/2] Get rid of useless DeserializeWith wrapper --- serde_derive/src/de.rs | 47 ++++++++++++++++++++++++++++-------------- 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/serde_derive/src/de.rs b/serde_derive/src/de.rs index 5d0c00f3..2dc814a0 100644 --- a/serde_derive/src/de.rs +++ b/serde_derive/src/de.rs @@ -1728,6 +1728,15 @@ fn deserialize_externally_tagged_variant( } } +/// Generates significant part of the `Visitor::visit_seq` and `Visitor::visit_map` +/// bodies of visitors for deserializing variants of internally tagged enum: +/// +/// ```ignore +/// #[serde(tag = "...")] +/// enum InternallyTagged { +/// ... +/// } +/// ``` fn deserialize_internally_tagged_variant( params: &Parameters, variant: &Variant, @@ -1779,11 +1788,9 @@ fn deserialize_untagged_variant( deserializer: TokenStream, ) -> Fragment { if let Some(path) = variant.attrs.deserialize_with() { - let (wrapper, wrapper_ty, unwrap_fn) = wrap_deserialize_variant_with(params, variant, path); + let unwrap_fn = unwrap_to_variant_closure(params, variant, quote!(__wrap)); return quote_block! { - #wrapper - _serde::__private::Result::map( - <#wrapper_ty as _serde::Deserialize>::deserialize(#deserializer), #unwrap_fn) + _serde::__private::Result::map(#path(#deserializer), #unwrap_fn) }; } @@ -2883,44 +2890,54 @@ fn wrap_deserialize_variant_with( variant: &Variant, deserialize_with: &syn::ExprPath, ) -> (TokenStream, TokenStream, TokenStream) { - let this = ¶ms.this; - let variant_ident = &variant.ident; - let field_tys = variant.fields.iter().map(|field| field.ty); let (wrapper, wrapper_ty) = wrap_deserialize_with(params, "e!((#(#field_tys),*)), deserialize_with); + let unwrap_fn = unwrap_to_variant_closure(params, variant, quote!(__wrap.value)); + + (wrapper, wrapper_ty, unwrap_fn) +} + +/// Generates closure that converts single input parameter to the final value +fn unwrap_to_variant_closure( + params: &Parameters, + variant: &Variant, + wrapper: TokenStream, +) -> TokenStream { + let this = ¶ms.this; + let variant_ident = &variant.ident; + let field_access = (0..variant.fields.len()).map(|n| { Member::Unnamed(Index { index: n as u32, span: Span::call_site(), }) }); - let unwrap_fn = match variant.style { + + match variant.style { Style::Struct if variant.fields.len() == 1 => { let member = &variant.fields[0].member; quote! { - |__wrap| #this::#variant_ident { #member: __wrap.value } + |__wrap| #this::#variant_ident { #member: #wrapper } } } Style::Struct => { let members = variant.fields.iter().map(|field| &field.member); quote! { - |__wrap| #this::#variant_ident { #(#members: __wrap.value.#field_access),* } + |__wrap| #this::#variant_ident { #(#members: #wrapper.#field_access),* } } } Style::Tuple => quote! { - |__wrap| #this::#variant_ident(#(__wrap.value.#field_access),*) + |__wrap| #this::#variant_ident(#(#wrapper.#field_access),*) }, Style::Newtype => quote! { - |__wrap| #this::#variant_ident(__wrap.value) + |__wrap| #this::#variant_ident(#wrapper) }, Style::Unit => quote! { |__wrap| #this::#variant_ident }, - }; - - (wrapper, wrapper_ty, unwrap_fn) + } } fn expr_is_missing(field: &Field, cattrs: &attr::Container) -> Fragment { From e0fc46783d36286e7537ea5fb00eae4a01e85a35 Mon Sep 17 00:00:00 2001 From: Mingun Date: Sun, 28 Feb 2021 17:43:54 +0500 Subject: [PATCH 2/2] Add test with generic deserialize_with function --- serde_derive/src/de.rs | 31 +++++++++++++++++++++++-------- test_suite/tests/test_gen.rs | 23 +++++++++++++++++++++++ 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/serde_derive/src/de.rs b/serde_derive/src/de.rs index 2dc814a0..8a413708 100644 --- a/serde_derive/src/de.rs +++ b/serde_derive/src/de.rs @@ -1788,7 +1788,7 @@ fn deserialize_untagged_variant( deserializer: TokenStream, ) -> Fragment { if let Some(path) = variant.attrs.deserialize_with() { - let unwrap_fn = unwrap_to_variant_closure(params, variant, quote!(__wrap)); + let unwrap_fn = unwrap_to_variant_closure(params, variant, false); return quote_block! { _serde::__private::Result::map(#path(#deserializer), #unwrap_fn) }; @@ -2894,7 +2894,7 @@ fn wrap_deserialize_variant_with( let (wrapper, wrapper_ty) = wrap_deserialize_with(params, "e!((#(#field_tys),*)), deserialize_with); - let unwrap_fn = unwrap_to_variant_closure(params, variant, quote!(__wrap.value)); + let unwrap_fn = unwrap_to_variant_closure(params, variant, true); (wrapper, wrapper_ty, unwrap_fn) } @@ -2903,11 +2903,26 @@ fn wrap_deserialize_variant_with( fn unwrap_to_variant_closure( params: &Parameters, variant: &Variant, - wrapper: TokenStream, + with_wrapper: bool, ) -> TokenStream { let this = ¶ms.this; let variant_ident = &variant.ident; + let (arg, wrapper) = if with_wrapper { + ( + quote!{ __wrap }, + quote!{ __wrap.value }, + ) + } else { + let field_tys = variant.fields.iter().map(|field| field.ty); + ( + // When additional DeserializeWith wrapper is not used, Rust + // is not able to infer types, so we help him + quote!{ __wrap: (#(#field_tys),*) }, + quote!{ __wrap }, + ) + }; + let field_access = (0..variant.fields.len()).map(|n| { Member::Unnamed(Index { index: n as u32, @@ -2919,23 +2934,23 @@ fn unwrap_to_variant_closure( Style::Struct if variant.fields.len() == 1 => { let member = &variant.fields[0].member; quote! { - |__wrap| #this::#variant_ident { #member: #wrapper } + |#arg| #this::#variant_ident { #member: #wrapper } } } Style::Struct => { let members = variant.fields.iter().map(|field| &field.member); quote! { - |__wrap| #this::#variant_ident { #(#members: #wrapper.#field_access),* } + |#arg| #this::#variant_ident { #(#members: #wrapper.#field_access),* } } } Style::Tuple => quote! { - |__wrap| #this::#variant_ident(#(#wrapper.#field_access),*) + |#arg| #this::#variant_ident(#(#wrapper.#field_access),*) }, Style::Newtype => quote! { - |__wrap| #this::#variant_ident(#wrapper) + |#arg| #this::#variant_ident(#wrapper) }, Style::Unit => quote! { - |__wrap| #this::#variant_ident + |#arg| #this::#variant_ident }, } } diff --git a/test_suite/tests/test_gen.rs b/test_suite/tests/test_gen.rs index 309b164b..d55a99b6 100644 --- a/test_suite/tests/test_gen.rs +++ b/test_suite/tests/test_gen.rs @@ -821,3 +821,26 @@ where { vec.first().serialize(serializer) } + +////////////////////////////////////////////////////////////////////////// + +#[derive(Debug, PartialEq, Deserialize)] +#[serde(tag = "tag")] +enum InternallyTagged { + #[serde(deserialize_with = "deserialize_generic")] + Unit, + + #[serde(deserialize_with = "deserialize_generic")] + Newtype(i32), + + #[serde(deserialize_with = "deserialize_generic")] + Struct { f1: String, f2: u8 }, +} + +fn deserialize_generic<'de, T, D>(deserializer: D) -> StdResult +where + T: Deserialize<'de>, + D: Deserializer<'de>, +{ + T::deserialize(deserializer) +}