Add #[derive_deserialize] support for enums

This commit is contained in:
Erick Tryzelaar
2015-02-11 08:56:27 -08:00
parent 3fd42e616c
commit 1552eb72dc
2 changed files with 344 additions and 315 deletions
+328 -312
View File
@@ -22,7 +22,7 @@ use syntax::ext::deriving::generic::{
Named, Named,
StaticFields, StaticFields,
StaticStruct, StaticStruct,
//StaticEnum, StaticEnum,
Struct, Struct,
Substructure, Substructure,
TraitDef, TraitDef,
@@ -302,20 +302,19 @@ fn deserialize_substructure(cx: &ExtCtxt, span: Span, substr: &Substructure) ->
cx, cx,
span, span,
substr.type_ident, substr.type_ident,
substr.type_ident,
cx.path(span, vec![substr.type_ident]),
fields, fields,
state) state)
} }
/*
StaticEnum(_, ref fields) => { StaticEnum(_, ref fields) => {
deserialize_enum( deserialize_enum(
cx, cx,
span, span,
substr.type_ident, substr.type_ident,
&fields, &fields,
deserializer, state)
token)
} }
*/
_ => cx.bug("expected StaticEnum or StaticStruct in derive(Deserialize)") _ => cx.bug("expected StaticEnum or StaticStruct in derive(Deserialize)")
} }
} }
@@ -324,42 +323,141 @@ fn deserialize_struct(
cx: &ExtCtxt, cx: &ExtCtxt,
span: Span, span: Span,
type_ident: Ident, type_ident: Ident,
struct_ident: Ident,
struct_path: ast::Path,
fields: &StaticFields, fields: &StaticFields,
state: P<ast::Expr>, state: P<ast::Expr>,
) -> P<ast::Expr> { ) -> P<ast::Expr> {
match *fields { match *fields {
Unnamed(ref fields) => { Unnamed(ref fields) => {
deserialize_struct_unnamed_fields( if fields.is_empty() {
cx, deserialize_struct_empty_fields(
span, cx,
type_ident, span,
&fields[], type_ident,
state) struct_ident,
struct_path,
state)
} else {
deserialize_struct_unnamed_fields(
cx,
span,
type_ident,
struct_ident,
struct_path,
&fields[],
state)
}
} }
Named(ref fields) => { Named(ref fields) => {
deserialize_struct_named_fields( deserialize_struct_named_fields(
cx, cx,
span, span,
type_ident, type_ident,
struct_ident,
struct_path,
&fields[], &fields[],
state) state)
} }
} }
} }
fn deserialize_struct_empty_fields(
cx: &ExtCtxt,
span: Span,
type_ident: Ident,
struct_ident: Ident,
struct_path: ast::Path,
state: P<ast::Expr>,
) -> P<ast::Expr> {
let struct_name = cx.expr_str(span, token::get_ident(struct_ident));
let result = cx.expr_path(struct_path);
quote_expr!(cx, {
struct __Visitor;
impl ::serde2::de::Visitor for __Visitor {
type Value = $type_ident;
#[inline]
fn visit_unit<
E: ::serde2::de::Error,
>(&mut self) -> Result<$type_ident, E> {
Ok($result)
}
#[inline]
fn visit_named_unit<
E: ::serde2::de::Error,
>(&mut self, name: &str) -> Result<$type_ident, E> {
if name == $struct_name {
self.visit_unit()
} else {
Err(::serde2::de::Error::syntax_error())
}
}
}
$state.visit(&mut __Visitor)
})
}
fn deserialize_struct_unnamed_fields( fn deserialize_struct_unnamed_fields(
cx: &ExtCtxt, cx: &ExtCtxt,
span: Span, span: Span,
type_ident: Ident, type_ident: Ident,
struct_ident: Ident,
struct_path: ast::Path,
fields: &[Span], fields: &[Span],
state: P<ast::Expr>, state: P<ast::Expr>,
) -> P<ast::Expr> { ) -> P<ast::Expr> {
let type_name = cx.expr_str(span, token::get_ident(type_ident)); let struct_name = cx.expr_str(span, token::get_ident(struct_ident));
let field_names: Vec<ast::Ident> = (0 .. fields.len()) let field_names: Vec<ast::Ident> = (0 .. fields.len())
.map(|i| token::str_to_ident(&format!("__field{}", i))) .map(|i| token::str_to_ident(&format!("__field{}", i)))
.collect(); .collect();
let visit_seq_expr = declare_visit_seq(
cx,
span,
struct_path,
&field_names[],
);
quote_expr!(cx, {
struct __Visitor;
impl ::serde2::de::Visitor for __Visitor {
type Value = $type_ident;
fn visit_seq<
__V: ::serde2::de::SeqVisitor,
>(&mut self, mut visitor: __V) -> Result<$type_ident, __V::Error> {
$visit_seq_expr
}
fn visit_named_seq<
__V: ::serde2::de::SeqVisitor,
>(&mut self, name: &str, visitor: __V) -> Result<$type_ident, __V::Error> {
if name == $struct_name {
self.visit_seq(visitor)
} else {
Err(::serde2::de::Error::syntax_error())
}
}
}
$state.visit(&mut __Visitor)
})
}
fn declare_visit_seq(
cx: &ExtCtxt,
span: Span,
struct_path: ast::Path,
field_names: &[Ident],
) -> P<ast::Expr> {
let let_values: Vec<P<ast::Stmt>> = field_names.iter() let let_values: Vec<P<ast::Stmt>> = field_names.iter()
.map(|name| { .map(|name| {
quote_stmt!(cx, quote_stmt!(cx,
@@ -373,32 +471,72 @@ fn deserialize_struct_unnamed_fields(
}) })
.collect(); .collect();
let result = cx.expr_call_ident( let result = cx.expr_call(
span, span,
type_ident, cx.expr_path(struct_path),
field_names.iter().map(|name| cx.expr_ident(span, *name)).collect()); field_names.iter().map(|name| cx.expr_ident(span, *name)).collect());
quote_expr!(cx, { quote_expr!(cx, {
$let_values
try!(visitor.end());
Ok($result)
})
}
fn deserialize_struct_named_fields(
cx: &ExtCtxt,
span: Span,
type_ident: Ident,
struct_ident: Ident,
struct_path: ast::Path,
fields: &[(Ident, Span)],
state: P<ast::Expr>,
) -> P<ast::Expr> {
let struct_name = cx.expr_str(span, token::get_ident(struct_ident));
// Create the field names for the fields.
let field_names: Vec<ast::Ident> = (0 .. fields.len())
.map(|i| token::str_to_ident(&format!("__field{}", i)))
.collect();
let field_deserializer = declare_map_field_deserializer(
cx,
span,
&field_names[],
fields,
);
let visit_map_expr = declare_visit_map(
cx,
span,
struct_path,
&field_names[],
fields,
);
quote_expr!(cx, {
$field_deserializer
struct __Visitor; struct __Visitor;
impl ::serde2::de::Visitor for __Visitor { impl ::serde2::de::Visitor for __Visitor {
type Value = $type_ident; type Value = $type_ident;
fn visit_seq< #[inline]
__V: ::serde2::de::SeqVisitor, fn visit_map<
__V: ::serde2::de::MapVisitor,
>(&mut self, mut visitor: __V) -> Result<$type_ident, __V::Error> { >(&mut self, mut visitor: __V) -> Result<$type_ident, __V::Error> {
$let_values $visit_map_expr
try!(visitor.end());
Ok($result)
} }
fn visit_named_seq< #[inline]
__V: ::serde2::de::SeqVisitor, fn visit_named_map<
__V: ::serde2::de::MapVisitor,
>(&mut self, name: &str, visitor: __V) -> Result<$type_ident, __V::Error> { >(&mut self, name: &str, visitor: __V) -> Result<$type_ident, __V::Error> {
if name == $type_name { if name == $struct_name {
self.visit_seq(visitor) self.visit_map(visitor)
} else { } else {
Err(::serde2::de::Error::syntax_error()) Err(::serde2::de::Error::syntax_error())
} }
@@ -409,20 +547,12 @@ fn deserialize_struct_unnamed_fields(
}) })
} }
fn deserialize_struct_named_fields( fn declare_map_field_deserializer(
cx: &ExtCtxt, cx: &ExtCtxt,
span: Span, span: Span,
type_ident: Ident, field_names: &[ast::Ident],
fields: &[(Ident, Span)], fields: &[(Ident, Span)],
state: P<ast::Expr>, ) -> Vec<P<ast::Item>> {
) -> P<ast::Expr> {
let type_name = cx.expr_str(span, token::get_ident(type_ident));
// Create the field names for the fields.
let field_names: Vec<ast::Ident> = (0 .. fields.len())
.map(|i| token::str_to_ident(&format!("__field{}", i)))
.collect();
// Create the field names for the fields. // Create the field names for the fields.
let field_variants: Vec<P<ast::Variant>> = field_names.iter() let field_variants: Vec<P<ast::Variant>> = field_names.iter()
.map(|field| { .map(|field| {
@@ -453,6 +583,52 @@ fn deserialize_struct_named_fields(
}) })
.collect(); .collect();
vec![
quote_item!(cx,
#[allow(non_camel_case_types)]
$field_enum
).unwrap(),
quote_item!(cx,
struct __FieldVisitor;
).unwrap(),
quote_item!(cx,
impl ::serde2::de::Visitor for __FieldVisitor {
type Value = __Field;
fn visit_str<
E: ::serde2::de::Error,
>(&mut self, value: &str) -> Result<__Field, E> {
match value {
$field_arms
_ => Err(::serde2::de::Error::syntax_error()),
}
}
}
).unwrap(),
quote_item!(cx,
impl ::serde2::de::Deserialize for __Field {
#[inline]
fn deserialize<
__S: ::serde2::de::Deserializer,
>(state: &mut __S) -> Result<__Field, __S::Error> {
state.visit(&mut __FieldVisitor)
}
}
).unwrap(),
]
}
fn declare_visit_map(
cx: &ExtCtxt,
span: Span,
struct_path: ast::Path,
field_names: &[Ident],
fields: &[(Ident, Span)],
) -> P<ast::Expr> {
// Declare each field. // Declare each field.
let let_values: Vec<P<ast::Stmt>> = field_names.iter() let let_values: Vec<P<ast::Stmt>> = field_names.iter()
.map(|field| { .map(|field| {
@@ -484,9 +660,9 @@ fn deserialize_struct_named_fields(
}) })
.collect(); .collect();
let result = cx.expr_struct_ident( let result = cx.expr_struct(
span, span,
type_ident, struct_path,
fields.iter() fields.iter()
.zip(field_names.iter()) .zip(field_names.iter())
.map(|(&(name, span), field)| { .map(|(&(name, span), field)| {
@@ -496,230 +672,15 @@ fn deserialize_struct_named_fields(
); );
quote_expr!(cx, { quote_expr!(cx, {
#[allow(non_camel_case_types)] $let_values
$field_enum
struct __FieldVisitor; while let Some(key) = try!(visitor.visit_key()) {
match key {
impl ::serde2::de::Visitor for __FieldVisitor { $value_arms
type Value = __Field;
fn visit_str<
E: ::serde2::de::Error,
>(&mut self, value: &str) -> Result<__Field, E> {
match value {
$field_arms
_ => Err(::serde2::de::Error::syntax_error()),
}
} }
} }
impl ::serde2::de::Deserialize for __Field { $extract_values
#[inline]
fn deserialize<
__S: ::serde2::de::Deserializer,
>(state: &mut __S) -> Result<__Field, __S::Error> {
state.visit(&mut __FieldVisitor)
}
}
struct __Visitor;
impl ::serde2::de::Visitor for __Visitor {
type Value = $type_ident;
fn visit_map<
__V: ::serde2::de::MapVisitor,
>(&mut self, mut visitor: __V) -> Result<$type_ident, __V::Error> {
$let_values
while let Some(key) = try!(visitor.visit_key()) {
match key {
$value_arms
}
}
$extract_values
Ok($result)
}
fn visit_named_map<
__V: ::serde2::de::MapVisitor,
>(&mut self, name: &str, visitor: __V) -> Result<$type_ident, __V::Error> {
if name == $type_name {
self.visit_map(visitor)
} else {
Err(::serde2::de::Error::syntax_error())
}
}
}
$state.visit(&mut __Visitor)
})
}
/*
fn deserialize_struct(
cx: &ExtCtxt,
span: Span,
type_ident: Ident,
fields: &StaticFields,
deserializer: P<ast::Expr>,
token: P<ast::Expr>
) -> P<ast::Expr> {
/*
let struct_block = deserialize_struct_from_struct(
cx,
span,
type_ident,
fields,
deserializer
);
*/
let map_block = deserialize_struct_from_map(
cx,
span,
type_ident,
fields,
deserializer
);
quote_expr!(
cx,
match $token {
::serde2::de::StructStart(_, _) => $struct_block,
::serde2::de::MapStart(_) => $map_block,
token => {
let expected_tokens = [
::serde2::de::StructStartKind,
::serde2::de::MapStartKind,
];
Err($deserializer.syntax_error(token, expected_tokens))
}
}
)
}
/*
fn deserialize_struct_from_struct(
cx: &ExtCtxt,
span: Span,
type_ident: Ident,
fields: &StaticFields,
deserializer: P<ast::Expr>
) -> P<ast::Expr> {
let expect_struct_field = cx.ident_of("expect_struct_field");
let call = deserialize_static_fields(
cx,
span,
type_ident,
fields,
|cx, span, name| {
let name = cx.expr_str(span, name);
quote_expr!(
cx,
try!($deserializer.expect_struct_field($name))
)
}
);
quote_expr!(cx, {
let result = $call;
try!($deserializer.expect_struct_end());
Ok(result)
})
}
*/
fn deserialize_struct_from_map(
cx: &ExtCtxt,
span: Span,
type_ident: Ident,
fields: &StaticFields,
deserializer: P<ast::Expr>
) -> P<ast::Expr> {
let fields = match *fields {
Unnamed(_) => panic!(),
Named(ref fields) => &fields[],
};
// Declare each field.
let let_fields: Vec<P<ast::Stmt>> = fields.iter()
.map(|&(name, span)| {
quote_stmt!(cx, let mut $name = None)
})
.collect();
// Declare key arms.
let key_arms: Vec<ast::Arm> = fields.iter()
.map(|&(name, span)| {
let s = cx.expr_str(span, token::get_ident(name));
quote_arm!(cx,
$s => {
$name = Some(
try!(::serde2::de::Deserialize::deserialize($deserializer))
);
continue;
})
})
.collect();
let extract_fields: Vec<P<ast::Stmt>> = fields.iter()
.map(|&(name, span)| {
let name_str = cx.expr_str(span, token::get_ident(name));
quote_stmt!(cx,
let $name = match $name {
Some($name) => $name,
None => try!($deserializer.missing_field($name_str)),
};
)
})
.collect();
let result = cx.expr_struct_ident(
span,
type_ident,
fields.iter()
.map(|&(name, span)| {
cx.field_imm(span, name, cx.expr_ident(span, name))
})
.collect()
);
quote_expr!(cx, {
$let_fields
loop {
let token = match try!($deserializer.expect_token()) {
::serde2::de::End => { break; }
token => token,
};
{
let key = match token {
::serde2::de::Str(s) => s,
::serde2::de::String(ref s) => &s,
token => {
let expected_tokens = [
::serde2::de::StrKind,
::serde2::de::StringKind,
];
return Err($deserializer.syntax_error(token, expected_tokens));
}
};
match key {
$key_arms
_ => { }
}
}
try!($deserializer.ignore_field(token))
}
$extract_fields
Ok($result) Ok($result)
}) })
} }
@@ -729,89 +690,144 @@ fn deserialize_enum(
span: Span, span: Span,
type_ident: Ident, type_ident: Ident,
fields: &[(Ident, Span, StaticFields)], fields: &[(Ident, Span, StaticFields)],
deserializer: P<ast::Expr>, state: P<ast::Expr>,
token: P<ast::Expr>
) -> P<ast::Expr> { ) -> P<ast::Expr> {
let type_name = cx.expr_str(span, token::get_ident(type_ident)); let type_name = cx.expr_str(span, token::get_ident(type_ident));
let variants = fields.iter() // Match arms to extract a variant from a string
.map(|&(name, span, _)| { let variant_arms: Vec<ast::Arm> = fields.iter()
cx.expr_str(span, token::get_ident(name)) .map(|&(name, span, ref fields)| {
}) let value = deserialize_enum_variant(
.collect();
let variants = cx.expr_vec(span, variants);
let arms: Vec<ast::Arm> = fields.iter()
.enumerate()
.map(|(i, &(name, span, ref parts))| {
let call = deserialize_static_fields(
cx, cx,
span, span,
type_ident,
name, name,
parts, fields,
|cx, span, _| { cx.expr_ident(span, cx.ident_of("visitor")),
quote_expr!(cx, try!($deserializer.expect_enum_elt()))
}
); );
quote_arm!(cx, $i => $call,) let s = cx.expr_str(span, token::get_ident(name));
quote_arm!(cx, $s => $value,)
}) })
.collect(); .collect();
quote_expr!(cx, { quote_expr!(cx, {
let i = try!($deserializer.expect_enum_start($token, $type_name, $variants)); struct __Visitor;
let result = match i { impl ::serde2::de::Visitor for __Visitor {
$arms type Value = $type_ident;
_ => { unreachable!() }
};
try!($deserializer.expect_enum_end()); fn visit_enum<
__V: ::serde2::de::EnumVisitor,
>(&mut self, name: &str, variant: &str, mut visitor: __V) -> Result<$type_ident, __V::Error> {
if name == $type_name {
self.visit_variant(variant, visitor)
} else {
Err(::serde2::de::Error::syntax_error())
}
}
Ok(result) fn visit_variant<
__V: ::serde2::de::EnumVisitor,
>(&mut self, name: &str, mut visitor: __V) -> Result<$type_ident, __V::Error> {
match name {
$variant_arms
_ => Err(::serde2::de::Error::syntax_error()),
}
}
}
$state.visit(&mut __Visitor)
}) })
} }
/// Create a deserializer for a single enum variant/struct: fn deserialize_enum_variant(
/// - `outer_pat_ident` is the name of this enum variant/struct
/// - `getarg` should retrieve the `u32`-th field with name `&str`.
fn deserialize_static_fields(
cx: &ExtCtxt, cx: &ExtCtxt,
span: Span, span: Span,
outer_pat_ident: Ident, type_ident: Ident,
variant_ident: Ident,
fields: &StaticFields, fields: &StaticFields,
getarg: |&ExtCtxt, Span, token::InternedString| -> P<Expr> state: P<ast::Expr>,
) -> P<Expr> { ) -> P<ast::Expr> {
let variant_path = cx.path(span, vec![type_ident, variant_ident]);
match *fields { match *fields {
Unnamed(ref fields) => { Unnamed(ref fields) => {
if fields.is_empty() { if fields.is_empty() {
cx.expr_ident(span, outer_pat_ident) let result = cx.expr_path(variant_path);
} else {
let fields = fields.iter().enumerate().map(|(i, &span)| {
getarg(
cx,
span,
token::intern_and_get_ident(&format!("_field{}", i))
)
}).collect();
cx.expr_call_ident(span, outer_pat_ident, fields) quote_expr!(cx, {
try!($state.visit_unit());
Ok($result)
})
} else {
// Create the field names for the fields.
let field_names: Vec<ast::Ident> = (0 .. fields.len())
.map(|i| token::str_to_ident(&format!("__field{}", i)))
.collect();
let visit_seq_expr = declare_visit_seq(
cx,
span,
variant_path,
&field_names[],
);
quote_expr!(cx, {
struct __Visitor;
impl ::serde2::de::EnumSeqVisitor for __Visitor {
type Value = $type_ident;
fn visit<
V: ::serde2::de::SeqVisitor,
>(&mut self, mut visitor: V) -> Result<$type_ident, V::Error> {
$visit_seq_expr
}
}
$state.visit_seq(&mut __Visitor)
})
} }
} }
Named(ref fields) => { Named(ref fields) => {
// use the field's span to get nicer error messages. // Create the field names for the fields.
let fields = fields.iter().map(|&(name, span)| { let field_names: Vec<ast::Ident> = (0 .. fields.len())
let arg = getarg( .map(|i| token::str_to_ident(&format!("__field{}", i)))
cx, .collect();
span,
token::get_ident(name)
);
cx.field_imm(span, name, arg)
}).collect();
cx.expr_struct_ident(span, outer_pat_ident, fields) let field_deserializer = declare_map_field_deserializer(
cx,
span,
&field_names[],
fields,
);
let visit_map_expr = declare_visit_map(
cx,
span,
variant_path,
&field_names[],
fields,
);
quote_expr!(cx, {
$field_deserializer
struct __Visitor;
impl ::serde2::de::EnumMapVisitor for __Visitor {
type Value = $type_ident;
fn visit<
V: ::serde2::de::MapVisitor,
>(&mut self, mut visitor: V) -> Result<$type_ident, V::Error> {
$visit_map_expr
}
}
$state.visit_map(&mut __Visitor)
})
} }
} }
} }
*/
+16 -3
View File
@@ -201,6 +201,13 @@ pub trait Visitor {
>(&mut self, _name: &str, _variant: &str, _visitor: V) -> Result<Self::Value, V::Error> { >(&mut self, _name: &str, _variant: &str, _visitor: V) -> Result<Self::Value, V::Error> {
Err(Error::syntax_error()) Err(Error::syntax_error())
} }
#[inline]
fn visit_variant<
V: EnumVisitor,
>(&mut self, _name: &str, _visitor: V) -> Result<Self::Value, V::Error> {
Err(Error::syntax_error())
}
} }
pub trait SeqVisitor { pub trait SeqVisitor {
@@ -1177,11 +1184,17 @@ mod tests {
fn visit_enum< fn visit_enum<
V: super::EnumVisitor, V: super::EnumVisitor,
>(&mut self, name: &str, variant: &str, mut visitor: V) -> Result<Enum, V::Error> { >(&mut self, name: &str, variant: &str, mut visitor: V) -> Result<Enum, V::Error> {
if name != "Enum" { if name == "Enum" {
return Err(super::Error::syntax_error()); self.visit_variant(variant, visitor)
} else {
Err(super::Error::syntax_error());
} }
}
match variant { fn visit_variant<
V: super::EnumVisitor,
>(&mut self, name: &str, mut visitor: V) -> Result<Enum, V::Error> {
match name {
"Unit" => { "Unit" => {
try!(visitor.visit_unit()); try!(visitor.visit_unit());
Ok(Enum::Unit) Ok(Enum::Unit)