Untagged and internally tagged enums

This commit is contained in:
David Tolnay
2017-02-02 12:10:07 -08:00
parent faaa494579
commit ff21d557c7
16 changed files with 2020 additions and 531 deletions
+308 -21
View File
@@ -110,7 +110,8 @@ fn deserialize_body(
impl_generics,
ty,
fields,
&item.attrs)
&item.attrs,
None)
}
Body::Struct(Style::Tuple, ref fields) |
Body::Struct(Style::Newtype, ref fields) => {
@@ -124,7 +125,8 @@ fn deserialize_body(
impl_generics,
ty,
fields,
&item.attrs)
&item.attrs,
None)
}
Body::Struct(Style::Unit, _) => {
deserialize_unit_struct(
@@ -238,6 +240,7 @@ fn deserialize_tuple(
ty: syn::Ty,
fields: &[Field],
item_attrs: &attr::Item,
deserializer: Option<Tokens>,
) -> Tokens {
let where_clause = &impl_generics.where_clause;
@@ -274,7 +277,9 @@ fn deserialize_tuple(
false,
);
let dispatch = if is_enum {
let dispatch = if let Some(deserializer) = deserializer {
quote!(_serde::Deserializer::deserialize(#deserializer, #visitor_expr))
} else if is_enum {
quote!(_serde::de::VariantVisitor::visit_tuple(visitor, #nfields, #visitor_expr))
} else if nfields == 1 {
let type_name = item_attrs.name().deserialize_name();
@@ -424,7 +429,11 @@ fn deserialize_struct(
ty: syn::Ty,
fields: &[Field],
item_attrs: &attr::Item,
deserializer: Option<Tokens>,
) -> Tokens {
let is_enum = variant_ident.is_some();
let is_untagged = deserializer.is_some();
let where_clause = &impl_generics.where_clause;
let (visitor_item, visitor_ty, visitor_expr) = deserialize_visitor(impl_generics);
@@ -454,8 +463,11 @@ fn deserialize_struct(
item_attrs,
);
let is_enum = variant_ident.is_some();
let dispatch = if is_enum {
let dispatch = if let Some(deserializer) = deserializer {
quote! {
_serde::Deserializer::deserialize(#deserializer, #visitor_expr)
}
} else if is_enum {
quote! {
_serde::de::VariantVisitor::visit_struct(visitor, FIELDS, #visitor_expr)
}
@@ -473,6 +485,20 @@ fn deserialize_struct(
quote!(mut visitor)
};
let visit_seq = if is_untagged {
// untagged struct variants do not get a visit_seq method
None
} else {
Some(quote! {
#[inline]
fn visit_seq<__V>(self, #visitor_var: __V) -> _serde::export::Result<#ty, __V::Error>
where __V: _serde::de::SeqVisitor
{
#visit_seq
}
})
};
quote!({
#field_visitor
@@ -485,12 +511,7 @@ fn deserialize_struct(
_serde::export::fmt::Formatter::write_str(formatter, #expecting)
}
#[inline]
fn visit_seq<__V>(self, #visitor_var: __V) -> _serde::export::Result<#ty, __V::Error>
where __V: _serde::de::SeqVisitor
{
#visit_seq
}
#visit_seq
#[inline]
fn visit_map<__V>(self, mut visitor: __V) -> _serde::export::Result<#ty, __V::Error>
@@ -512,6 +533,45 @@ fn deserialize_item_enum(
ty: syn::Ty,
variants: &[Variant],
item_attrs: &attr::Item
) -> Tokens {
match *item_attrs.tag() {
attr::EnumTag::External => {
deserialize_externally_tagged_enum(
type_ident,
impl_generics,
ty,
variants,
item_attrs,
)
}
attr::EnumTag::Internal(ref tag) => {
deserialize_internally_tagged_enum(
type_ident,
impl_generics,
ty,
variants,
item_attrs,
tag,
)
}
attr::EnumTag::None => {
deserialize_untagged_enum(
type_ident,
impl_generics,
ty,
variants,
item_attrs,
)
}
}
}
fn deserialize_externally_tagged_enum(
type_ident: &syn::Ident,
impl_generics: &syn::Generics,
ty: syn::Ty,
variants: &[Variant],
item_attrs: &attr::Item,
) -> Tokens {
let where_clause = &impl_generics.where_clause;
@@ -545,7 +605,7 @@ fn deserialize_item_enum(
.map(|(i, variant)| {
let variant_name = field_i(i);
let block = deserialize_variant(
let block = deserialize_externally_tagged_variant(
type_ident,
impl_generics,
ty.clone(),
@@ -604,7 +664,111 @@ fn deserialize_item_enum(
})
}
fn deserialize_variant(
fn deserialize_internally_tagged_enum(
type_ident: &syn::Ident,
impl_generics: &syn::Generics,
ty: syn::Ty,
variants: &[Variant],
item_attrs: &attr::Item,
tag: &str,
) -> Tokens {
let variant_names_idents: Vec<_> = variants.iter()
.enumerate()
.filter(|&(_, variant)| !variant.attrs.skip_deserializing())
.map(|(i, variant)| (variant.attrs.name().deserialize_name(), field_i(i)))
.collect();
let variants_stmt = {
let variant_names = variant_names_idents.iter().map(|&(ref name, _)| name);
quote! {
const VARIANTS: &'static [&'static str] = &[ #(#variant_names),* ];
}
};
let variant_visitor = deserialize_field_visitor(
variant_names_idents,
item_attrs,
true,
);
// Match arms to extract a variant from a string
let variant_arms = variants.iter()
.enumerate()
.filter(|&(_, variant)| !variant.attrs.skip_deserializing())
.map(|(i, variant)| {
let variant_name = field_i(i);
let block = deserialize_internally_tagged_variant(
type_ident,
impl_generics,
ty.clone(),
variant,
item_attrs,
quote!(_tagged.content),
);
quote! {
__Field::#variant_name => #block
}
});
quote!({
#variant_visitor
#variants_stmt
let _tagged = try!(_serde::Deserializer::deserialize(
deserializer,
_serde::de::private::TaggedContentVisitor::<__Field, __D::Error>::new(#tag)));
match _tagged.tag {
#(#variant_arms)*
}
})
}
fn deserialize_untagged_enum(
type_ident: &syn::Ident,
impl_generics: &syn::Generics,
ty: syn::Ty,
variants: &[Variant],
item_attrs: &attr::Item,
) -> Tokens {
let attempts = variants.iter()
.filter(|variant| !variant.attrs.skip_deserializing())
.map(|variant| {
deserialize_untagged_variant(
type_ident,
impl_generics,
ty.clone(),
variant,
item_attrs,
quote!(&_content),
)
});
// TODO this message could be better by saving the errors from the failed
// attempts. The heuristic used by TOML was to count the number of fields
// processed before an error, and use the error that happened after the
// largest number of fields. I'm not sure I like that. Maybe it would be
// better to save all the errors and combine them into one message that
// explains why none of the variants matched.
let fallthrough_msg = format!("data did not match any variant of untagged enum {}", type_ident);
quote!({
let _content = try!(<_serde::de::private::Content<__D::Error> as _serde::Deserialize>::deserialize(deserializer));
#(
if let _serde::export::Ok(ok) = #attempts {
return _serde::export::Ok(ok);
}
)*
_serde::export::Err(_serde::de::Error::custom(#fallthrough_msg))
})
}
fn deserialize_externally_tagged_variant(
type_ident: &syn::Ident,
generics: &syn::Generics,
ty: syn::Ty,
@@ -621,7 +785,7 @@ fn deserialize_variant(
})
}
Style::Newtype => {
deserialize_newtype_variant(
deserialize_externally_tagged_newtype_variant(
type_ident,
variant_ident,
generics,
@@ -636,6 +800,7 @@ fn deserialize_variant(
ty,
&variant.fields,
item_attrs,
None,
)
}
Style::Struct => {
@@ -646,22 +811,115 @@ fn deserialize_variant(
ty,
&variant.fields,
item_attrs,
None,
)
}
}
}
fn deserialize_newtype_variant(
fn deserialize_internally_tagged_variant(
type_ident: &syn::Ident,
generics: &syn::Generics,
ty: syn::Ty,
variant: &Variant,
item_attrs: &attr::Item,
deserializer: Tokens,
) -> Tokens {
let variant_ident = &variant.ident;
match variant.style {
Style::Unit => {
let type_name = type_ident.as_ref();
let variant_name = variant.ident.as_ref();
quote!({
try!(_serde::Deserializer::deserialize(#deserializer, _serde::de::private::InternallyTaggedUnitVisitor::new(#type_name, #variant_name)));
_serde::export::Ok(#type_ident::#variant_ident)
})
}
Style::Newtype | Style::Struct => {
deserialize_untagged_variant(
type_ident,
generics,
ty,
variant,
item_attrs,
deserializer,
)
}
Style::Tuple => unreachable!("checked in serde_codegen_internals"),
}
}
fn deserialize_untagged_variant(
type_ident: &syn::Ident,
generics: &syn::Generics,
ty: syn::Ty,
variant: &Variant,
item_attrs: &attr::Item,
deserializer: Tokens,
) -> Tokens {
let variant_ident = &variant.ident;
match variant.style {
Style::Unit => {
let type_name = type_ident.as_ref();
let variant_name = variant.ident.as_ref();
quote! {
_serde::export::Result::map(
_serde::Deserializer::deserialize(
#deserializer,
_serde::de::private::UntaggedUnitVisitor::new(#type_name, #variant_name)
),
|()| #type_ident::#variant_ident)
}
}
Style::Newtype => {
deserialize_untagged_newtype_variant(
type_ident,
variant_ident,
generics,
&variant.fields[0],
deserializer,
)
}
Style::Tuple => {
deserialize_tuple(
type_ident,
Some(variant_ident),
generics,
ty,
&variant.fields,
item_attrs,
Some(deserializer),
)
}
Style::Struct => {
deserialize_struct(
type_ident,
Some(variant_ident),
generics,
ty,
&variant.fields,
item_attrs,
Some(deserializer),
)
}
}
}
fn deserialize_externally_tagged_newtype_variant(
type_ident: &syn::Ident,
variant_ident: &syn::Ident,
impl_generics: &syn::Generics,
field: &Field,
) -> Tokens {
let visit = match field.attrs.deserialize_with() {
match field.attrs.deserialize_with() {
None => {
let field_ty = &field.ty;
quote! {
try!(_serde::de::VariantVisitor::visit_newtype::<#field_ty>(visitor))
_serde::export::Result::map(
_serde::de::VariantVisitor::visit_newtype::<#field_ty>(visitor),
#type_ident::#variant_ident),
}
}
Some(path) => {
@@ -670,12 +928,41 @@ fn deserialize_newtype_variant(
quote!({
#wrapper
#wrapper_impl
try!(_serde::de::VariantVisitor::visit_newtype::<#wrapper_ty>(visitor)).value
_serde::export::Result::map(
_serde::de::VariantVisitor::visit_newtype::<#wrapper_ty>(visitor),
|_wrapper| #type_ident::#variant_ident(_wrapper.value))
})
}
}
}
fn deserialize_untagged_newtype_variant(
type_ident: &syn::Ident,
variant_ident: &syn::Ident,
impl_generics: &syn::Generics,
field: &Field,
deserializer: Tokens,
) -> Tokens {
match field.attrs.deserialize_with() {
None => {
let field_ty = &field.ty;
quote!({
_serde::export::Result::map(
<#field_ty as _serde::Deserialize>::deserialize(#deserializer),
#type_ident::#variant_ident)
})
}
Some(path) => {
let (wrapper, wrapper_impl, wrapper_ty) = wrap_deserialize_with(
type_ident, impl_generics, field.ty, path);
quote!({
#wrapper
#wrapper_impl
_serde::export::Result::map(
<#wrapper_ty as _serde::Deserialize>::deserialize(#deserializer),
|_wrapper| #type_ident::#variant_ident(_wrapper.value))
})
}
};
quote! {
_serde::export::Ok(#type_ident::#variant_ident(#visit)),
}
}