Rewrite enum deserialization to not require allocations

This commit is contained in:
Erick Tryzelaar
2015-03-15 22:04:17 -07:00
parent b40d8f7bac
commit 78137ee3a4
7 changed files with 273 additions and 147 deletions
+43 -52
View File
@@ -386,17 +386,32 @@ fn deserialize_item_enum(
let type_name = builder.expr().str(type_ident);
let variant_visitor = deserialize_field_visitor(
cx,
builder,
enum_def.variants.iter()
.map(|variant| builder.expr().str(variant.node.name))
.collect()
);
// Match arms to extract a variant from a string
let variant_arms: Vec<_> = enum_def.variants.iter()
.map(|variant| {
deserialize_variant(
.enumerate()
.map(|(i, variant)| {
let variant_name = builder.expr().path()
.id("__Field").id(format!("__field{}", i))
.build();
let expr = deserialize_variant(
cx,
builder,
type_ident,
impl_generics,
ty.clone(),
variant,
)
);
quote_arm!(cx, $variant_name => { $expr })
})
.collect();
@@ -406,37 +421,23 @@ fn deserialize_item_enum(
);
quote_expr!(cx, {
$variant_visitor
$visitor_item
impl $impl_generics ::serde::de::Visitor for $visitor_ty $where_clause {
impl $impl_generics ::serde::de::EnumVisitor for $visitor_ty $where_clause {
type Value = $ty;
fn visit_enum<__V>(&mut self,
name: &str,
variant: &str,
visitor: __V) -> Result<$ty, __V::Error>
where __V: ::serde::de::EnumVisitor,
fn visit<__V>(&mut self, mut visitor: __V) -> Result<$ty, __V::Error>
where __V: ::serde::de::VariantVisitor,
{
if name == $type_name {
self.visit_variant(variant, visitor)
} else {
Err(::serde::de::Error::syntax_error())
}
}
fn visit_variant<__V>(&mut self,
name: &str,
mut visitor: __V) -> Result<$ty, __V::Error>
where __V: ::serde::de::EnumVisitor
{
match name {
match try!(visitor.visit_variant()) {
$variant_arms
_ => Err(::serde::de::Error::syntax_error()),
}
}
}
deserializer.visit_enum($visitor_expr)
deserializer.visit_enum($type_name, $visitor_expr)
})
}
@@ -447,21 +448,18 @@ fn deserialize_variant(
generics: &ast::Generics,
ty: P<ast::Ty>,
variant: &ast::Variant,
) -> ast::Arm {
) -> P<ast::Expr> {
let variant_ident = variant.node.name;
let variant_name = builder.expr().str(variant_ident);
match variant.node.kind {
ast::TupleVariantKind(ref args) if args.is_empty() => {
quote_arm!(cx,
$variant_name => {
try!(visitor.visit_unit());
Ok($type_ident::$variant_ident)
}
)
quote_expr!(cx, {
try!(visitor.visit_unit());
Ok($type_ident::$variant_ident)
})
}
ast::TupleVariantKind(ref args) => {
let expr = deserialize_tuple_variant(
deserialize_tuple_variant(
cx,
builder,
type_ident,
@@ -469,12 +467,10 @@ fn deserialize_variant(
generics,
ty,
args.len(),
);
quote_arm!(cx, $variant_name => { $expr })
)
}
ast::StructVariantKind(ref struct_def) => {
let expr = deserialize_struct_variant(
deserialize_struct_variant(
cx,
builder,
type_ident,
@@ -482,9 +478,7 @@ fn deserialize_variant(
generics,
ty,
struct_def,
);
quote_arm!(cx, $variant_name => { $expr })
)
}
}
}
@@ -574,10 +568,10 @@ fn deserialize_struct_variant(
fn deserialize_field_visitor(
cx: &ExtCtxt,
builder: &aster::AstBuilder,
struct_def: &StructDef,
field_exprs: Vec<P<ast::Expr>>,
) -> Vec<P<ast::Item>> {
// Create the field names for the fields.
let field_names: Vec<ast::Ident> = (0 .. struct_def.fields.len())
let field_idents: Vec<ast::Ident> = (0 .. field_exprs.len())
.map(|i| builder.id(format!("__field{}", i)))
.collect();
@@ -585,20 +579,17 @@ fn deserialize_field_visitor(
.attr().allow(&["non_camel_case_types"])
.enum_("__Field")
.with_variants(
field_names.iter().map(|field| {
builder.variant(field).tuple().build()
field_idents.iter().map(|field_ident| {
builder.variant(field_ident).tuple().build()
})
)
.build();
// Get aliases
let aliases = field::struct_field_strs(cx, builder, struct_def);
// Match arms to extract a field from a string
let field_arms: Vec<ast::Arm> = aliases.iter()
.zip(field_names.iter())
.map(|(alias, field_name)| {
quote_arm!(cx, $alias => { Ok(__Field::$field_name) })
let field_arms: Vec<_> = field_idents.iter()
.zip(field_exprs.into_iter())
.map(|(field_ident, field_expr)| {
quote_arm!(cx, $field_expr => { Ok(__Field::$field_ident) })
})
.collect();
@@ -642,7 +633,7 @@ fn deserialize_struct_visitor(
let field_visitor = deserialize_field_visitor(
cx,
builder,
struct_def,
field::struct_field_strs(cx, builder, struct_def),
);
let visit_map_expr = deserialize_map(