From 78137ee3a4332d3e657a9685debb4f07936c393f Mon Sep 17 00:00:00 2001 From: Erick Tryzelaar Date: Sun, 15 Mar 2015 22:04:17 -0700 Subject: [PATCH] Rewrite enum deserialization to not require allocations --- benches/bench_enum.rs | 90 ++++++++++++++++++++++++++++++++------- serde_macros/src/de.rs | 95 +++++++++++++++++++----------------------- src/de.rs | 72 +++++++++++++++++++------------- src/json/de.rs | 70 ++++++++++++++++++------------- src/json/value.rs | 44 ++++++++++++++++--- tests/test_de.rs | 47 +++++++++++++++------ tests/test_json.rs | 2 +- 7 files changed, 273 insertions(+), 147 deletions(-) diff --git a/benches/bench_enum.rs b/benches/bench_enum.rs index b1a97897..a3ec8f27 100644 --- a/benches/bench_enum.rs +++ b/benches/bench_enum.rs @@ -245,9 +245,11 @@ mod deserializer { use serde::de; + #[derive(Debug)] enum State { AnimalState(Animal), IsizeState(isize), + StrState(&'static str), StringState(String), UnitState, } @@ -273,29 +275,51 @@ mod deserializer { where V: de::Visitor, { match self.stack.pop() { - Some(State::AnimalState(Animal::Dog)) => { - self.stack.push(State::UnitState); - visitor.visit_enum("Animal", "Dog", DogVisitor { - de: self, - }) - } - Some(State::AnimalState(Animal::Frog(x0, x1))) => { - self.stack.push(State::IsizeState(x1)); - self.stack.push(State::StringState(x0)); - visitor.visit_enum("Animal", "Frog", FrogVisitor { - de: self, - state: 0, - }) - } Some(State::IsizeState(value)) => { visitor.visit_isize(value) } Some(State::StringState(value)) => { visitor.visit_string(value) } + Some(State::StrState(value)) => { + visitor.visit_str(value) + } Some(State::UnitState) => { visitor.visit_unit() } + Some(_) => { + Err(Error::SyntaxError) + } + None => { + Err(Error::EndOfStreamError) + } + } + } + + #[inline] + fn visit_enum(&mut self, _name: &str, mut visitor: V) -> Result + where V: de::EnumVisitor, + { + match self.stack.pop() { + Some(State::AnimalState(Animal::Dog)) => { + self.stack.push(State::UnitState); + self.stack.push(State::StrState("Dog")); + visitor.visit(DogVisitor { + de: self, + }) + } + Some(State::AnimalState(Animal::Frog(x0, x1))) => { + self.stack.push(State::IsizeState(x1)); + self.stack.push(State::StringState(x0)); + self.stack.push(State::StrState("Frog")); + visitor.visit(FrogVisitor { + de: self, + state: 0, + }) + } + Some(_) => { + Err(Error::SyntaxError) + } None => { Err(Error::EndOfStreamError) } @@ -307,12 +331,30 @@ mod deserializer { de: &'a mut AnimalDeserializer, } - impl<'a> de::EnumVisitor for DogVisitor<'a> { + impl<'a> de::VariantVisitor for DogVisitor<'a> { type Error = Error; + fn visit_variant(&mut self) -> Result + where V: de::Deserialize + { + de::Deserialize::deserialize(self.de) + } + fn visit_unit(&mut self) -> Result<(), Error> { de::Deserialize::deserialize(self.de) } + + fn visit_seq(&mut self, _visitor: V) -> Result + where V: de::EnumSeqVisitor + { + Err(de::Error::syntax_error()) + } + + fn visit_map(&mut self, _visitor: V) -> Result + where V: de::EnumMapVisitor + { + Err(de::Error::syntax_error()) + } } struct FrogVisitor<'a> { @@ -320,14 +362,30 @@ mod deserializer { state: usize, } - impl<'a> de::EnumVisitor for FrogVisitor<'a> { + impl<'a> de::VariantVisitor for FrogVisitor<'a> { type Error = Error; + fn visit_variant(&mut self) -> Result + where V: de::Deserialize + { + de::Deserialize::deserialize(self.de) + } + + fn visit_unit(&mut self) -> Result<(), Error> { + Err(de::Error::syntax_error()) + } + fn visit_seq(&mut self, mut visitor: V) -> Result where V: de::EnumSeqVisitor, { visitor.visit(self) } + + fn visit_map(&mut self, _visitor: V) -> Result + where V: de::EnumMapVisitor + { + Err(de::Error::syntax_error()) + } } impl<'a> de::SeqVisitor for FrogVisitor<'a> { diff --git a/serde_macros/src/de.rs b/serde_macros/src/de.rs index 756a22c7..f01d00f2 100644 --- a/serde_macros/src/de.rs +++ b/serde_macros/src/de.rs @@ -386,17 +386,32 @@ fn deserialize_item_enum( let type_name = builder.expr().str(type_ident); + let variant_visitor = deserialize_field_visitor( + cx, + builder, + enum_def.variants.iter() + .map(|variant| builder.expr().str(variant.node.name)) + .collect() + ); + // Match arms to extract a variant from a string let variant_arms: Vec<_> = enum_def.variants.iter() - .map(|variant| { - deserialize_variant( + .enumerate() + .map(|(i, variant)| { + let variant_name = builder.expr().path() + .id("__Field").id(format!("__field{}", i)) + .build(); + + let expr = deserialize_variant( cx, builder, type_ident, impl_generics, ty.clone(), variant, - ) + ); + + quote_arm!(cx, $variant_name => { $expr }) }) .collect(); @@ -406,37 +421,23 @@ fn deserialize_item_enum( ); quote_expr!(cx, { + $variant_visitor + $visitor_item - impl $impl_generics ::serde::de::Visitor for $visitor_ty $where_clause { + impl $impl_generics ::serde::de::EnumVisitor for $visitor_ty $where_clause { type Value = $ty; - fn visit_enum<__V>(&mut self, - name: &str, - variant: &str, - visitor: __V) -> Result<$ty, __V::Error> - where __V: ::serde::de::EnumVisitor, + fn visit<__V>(&mut self, mut visitor: __V) -> Result<$ty, __V::Error> + where __V: ::serde::de::VariantVisitor, { - if name == $type_name { - self.visit_variant(variant, visitor) - } else { - Err(::serde::de::Error::syntax_error()) - } - } - - fn visit_variant<__V>(&mut self, - name: &str, - mut visitor: __V) -> Result<$ty, __V::Error> - where __V: ::serde::de::EnumVisitor - { - match name { + match try!(visitor.visit_variant()) { $variant_arms - _ => Err(::serde::de::Error::syntax_error()), } } } - deserializer.visit_enum($visitor_expr) + deserializer.visit_enum($type_name, $visitor_expr) }) } @@ -447,21 +448,18 @@ fn deserialize_variant( generics: &ast::Generics, ty: P, variant: &ast::Variant, -) -> ast::Arm { +) -> P { let variant_ident = variant.node.name; - let variant_name = builder.expr().str(variant_ident); match variant.node.kind { ast::TupleVariantKind(ref args) if args.is_empty() => { - quote_arm!(cx, - $variant_name => { - try!(visitor.visit_unit()); - Ok($type_ident::$variant_ident) - } - ) + quote_expr!(cx, { + try!(visitor.visit_unit()); + Ok($type_ident::$variant_ident) + }) } ast::TupleVariantKind(ref args) => { - let expr = deserialize_tuple_variant( + deserialize_tuple_variant( cx, builder, type_ident, @@ -469,12 +467,10 @@ fn deserialize_variant( generics, ty, args.len(), - ); - - quote_arm!(cx, $variant_name => { $expr }) + ) } ast::StructVariantKind(ref struct_def) => { - let expr = deserialize_struct_variant( + deserialize_struct_variant( cx, builder, type_ident, @@ -482,9 +478,7 @@ fn deserialize_variant( generics, ty, struct_def, - ); - - quote_arm!(cx, $variant_name => { $expr }) + ) } } } @@ -574,10 +568,10 @@ fn deserialize_struct_variant( fn deserialize_field_visitor( cx: &ExtCtxt, builder: &aster::AstBuilder, - struct_def: &StructDef, + field_exprs: Vec>, ) -> Vec> { // Create the field names for the fields. - let field_names: Vec = (0 .. struct_def.fields.len()) + let field_idents: Vec = (0 .. field_exprs.len()) .map(|i| builder.id(format!("__field{}", i))) .collect(); @@ -585,20 +579,17 @@ fn deserialize_field_visitor( .attr().allow(&["non_camel_case_types"]) .enum_("__Field") .with_variants( - field_names.iter().map(|field| { - builder.variant(field).tuple().build() + field_idents.iter().map(|field_ident| { + builder.variant(field_ident).tuple().build() }) ) .build(); - // Get aliases - let aliases = field::struct_field_strs(cx, builder, struct_def); - // Match arms to extract a field from a string - let field_arms: Vec = aliases.iter() - .zip(field_names.iter()) - .map(|(alias, field_name)| { - quote_arm!(cx, $alias => { Ok(__Field::$field_name) }) + let field_arms: Vec<_> = field_idents.iter() + .zip(field_exprs.into_iter()) + .map(|(field_ident, field_expr)| { + quote_arm!(cx, $field_expr => { Ok(__Field::$field_ident) }) }) .collect(); @@ -642,7 +633,7 @@ fn deserialize_struct_visitor( let field_visitor = deserialize_field_visitor( cx, builder, - struct_def, + field::struct_field_strs(cx, builder, struct_def), ); let visit_map_expr = deserialize_map( diff --git a/src/de.rs b/src/de.rs index 500bab4b..700cb4f8 100644 --- a/src/de.rs +++ b/src/de.rs @@ -43,10 +43,10 @@ pub trait Deserializer { /// deserializers that provide a custom enumeration serialization to /// properly deserialize the type. #[inline] - fn visit_enum(&mut self, visitor: V) -> Result - where V: Visitor, + fn visit_enum(&mut self, _enum: &str, _visitor: V) -> Result + where V: EnumVisitor, { - self.visit(visitor) + Err(Error::syntax_error()) } } @@ -204,23 +204,6 @@ pub trait Visitor { { self.visit_map(visitor) } - - #[inline] - fn visit_enum(&mut self, - _name: &str, - _variant: &str, - _visitor: V) -> Result - where V: EnumVisitor, - { - Err(Error::syntax_error()) - } - - #[inline] - fn visit_variant(&mut self, _name: &str, _visitor: V) -> Result - where V: EnumVisitor, - { - Err(Error::syntax_error()) - } } /////////////////////////////////////////////////////////////////////////////// @@ -338,22 +321,53 @@ impl<'a, V_> MapVisitor for &'a mut V_ where V_: MapVisitor { /////////////////////////////////////////////////////////////////////////////// pub trait EnumVisitor { + type Value; + + fn visit(&mut self, visitor: V) -> Result + where V: VariantVisitor; +} + +/////////////////////////////////////////////////////////////////////////////// + +pub trait VariantVisitor { type Error: Error; - fn visit_unit(&mut self) -> Result<(), Self::Error> { - Err(Error::syntax_error()) - } + fn visit_variant(&mut self) -> Result + where V: Deserialize; + + fn visit_unit(&mut self) -> Result<(), Self::Error>; fn visit_seq(&mut self, _visitor: V) -> Result - where V: EnumSeqVisitor, - { - Err(Error::syntax_error()) - } + where V: EnumSeqVisitor; fn visit_map(&mut self, _visitor: V) -> Result - where V: EnumMapVisitor, + where V: EnumMapVisitor; +} + +impl<'a, T> VariantVisitor for &'a mut T where T: VariantVisitor { + type Error = T::Error; + + fn visit_variant(&mut self) -> Result + where V: Deserialize { - Err(Error::syntax_error()) + (**self).visit_variant() + } + + { + fn visit_unit(&mut self) -> Result<(), T::Error> { + (**self).visit_unit() + } + + fn visit_seq(&mut self, visitor: V) -> Result + where V: EnumSeqVisitor + { + (**self).visit_seq(visitor) + } + + fn visit_map(&mut self, visitor: V) -> Result + where V: EnumMapVisitor + { + (**self).visit_map(visitor) } } diff --git a/src/json/de.rs b/src/json/de.rs index 8c3d8983..a584598e 100644 --- a/src/json/de.rs +++ b/src/json/de.rs @@ -409,8 +409,8 @@ impl de::Deserializer for Deserializer } #[inline] - fn visit_enum(&mut self, mut visitor: V) -> Result - where V: de::Visitor, + fn visit_enum(&mut self, _name: &str, mut visitor: V) -> Result + where V: de::EnumVisitor, { self.parse_whitespace(); @@ -418,14 +418,9 @@ impl de::Deserializer for Deserializer self.bump(); self.parse_whitespace(); - try!(self.parse_string()); - try!(self.parse_object_colon()); - - let variant = str::from_utf8(&self.buf).unwrap().to_string(); - - let value = try!(visitor.visit_variant(&variant, EnumVisitor { - de: self, - })); + let value = { + try!(visitor.visit(&mut *self)) + }; self.parse_whitespace(); @@ -433,7 +428,7 @@ impl de::Deserializer for Deserializer self.bump(); Ok(value) } else { - return Err(self.error(ErrorCode::ExpectedSomeValue)); + Err(self.error(ErrorCode::ExpectedSomeValue)) } } else { Err(self.error(ErrorCode::ExpectedSomeValue)) @@ -541,7 +536,6 @@ impl<'a, Iter> de::MapVisitor for MapVisitor<'a, Iter> } if self.de.eof() { - println!("here3"); return Err(self.de.error(ErrorCode::EOFWhileParsingValue)); } @@ -599,42 +593,58 @@ impl<'a, Iter> de::MapVisitor for MapVisitor<'a, Iter> } } -struct EnumVisitor<'a, Iter: 'a> { - de: &'a mut Deserializer, -} - -impl<'a, Iter> de::EnumVisitor for EnumVisitor<'a, Iter> +impl de::VariantVisitor for Deserializer where Iter: Iterator, { type Error = Error; + fn visit_variant(&mut self) -> Result + where V: de::Deserialize + { + de::Deserialize::deserialize(self) + } + + /* + fn visit_value(&mut self) -> Result + where V: de::Deserialize + { + de::Deserialize::deserialize(self) + } + */ + fn visit_unit(&mut self) -> Result<(), Error> { - de::Deserialize::deserialize(self.de) + try!(self.parse_object_colon()); + + de::Deserialize::deserialize(self) } fn visit_seq(&mut self, mut visitor: V) -> Result - where V: de::EnumSeqVisitor, + where V: de::EnumSeqVisitor { - self.de.parse_whitespace(); + try!(self.parse_object_colon()); - if self.de.ch_is(b'[') { - self.de.bump(); - visitor.visit(SeqVisitor::new(self.de)) + self.parse_whitespace(); + + if self.ch_is(b'[') { + self.bump(); + visitor.visit(SeqVisitor::new(self)) } else { - Err(self.de.error(ErrorCode::ExpectedSomeValue)) + Err(self.error(ErrorCode::ExpectedSomeValue)) } } fn visit_map(&mut self, mut visitor: V) -> Result - where V: de::EnumMapVisitor, + where V: de::EnumMapVisitor { - self.de.parse_whitespace(); + try!(self.parse_object_colon()); - if self.de.ch_is(b'{') { - self.de.bump(); - visitor.visit(MapVisitor::new(self.de)) + self.parse_whitespace(); + + if self.ch_is(b'{') { + self.bump(); + visitor.visit(MapVisitor::new(self)) } else { - Err(self.de.error(ErrorCode::ExpectedSomeValue)) + Err(self.error(ErrorCode::ExpectedSomeValue)) } } } diff --git a/src/json/value.rs b/src/json/value.rs index 55caca98..66bd6be4 100644 --- a/src/json/value.rs +++ b/src/json/value.rs @@ -420,8 +420,8 @@ impl de::Deserializer for Deserializer { } #[inline] - fn visit_enum(&mut self, mut visitor: V) -> Result - where V: de::Visitor, + fn visit_enum(&mut self, _name: &str, mut visitor: V) -> Result + where V: de::EnumVisitor, { let value = match self.value.take() { Some(Value::Object(value)) => value, @@ -433,16 +433,20 @@ impl de::Deserializer for Deserializer { let value = match iter.next() { Some((variant, Value::Array(fields))) => { + self.value = Some(Value::String(variant)); + let len = fields.len(); - try!(visitor.visit_variant(&variant, SeqDeserializer { + try!(visitor.visit(SeqDeserializer { de: self, iter: fields.into_iter(), len: len, })) } Some((variant, Value::Object(fields))) => { + self.value = Some(Value::String(variant)); + let len = fields.len(); - try!(visitor.visit_variant(&variant, MapDeserializer { + try!(visitor.visit(MapDeserializer { de: self, iter: fields.into_iter(), value: None, @@ -495,9 +499,15 @@ impl<'a> de::SeqVisitor for SeqDeserializer<'a> { } } -impl<'a> de::EnumVisitor for SeqDeserializer<'a> { +impl<'a> de::VariantVisitor for SeqDeserializer<'a> { type Error = Error; + fn visit_variant(&mut self) -> Result + where V: de::Deserialize, + { + de::Deserialize::deserialize(self.de) + } + fn visit_unit(&mut self) -> Result<(), Error> { if self.len == 0 { Ok(()) @@ -511,6 +521,12 @@ impl<'a> de::EnumVisitor for SeqDeserializer<'a> { { visitor.visit(self) } + + fn visit_map(&mut self, _visitor: V) -> Result + where V: de::EnumMapVisitor + { + Err(de::Error::syntax_error()) + } } struct MapDeserializer<'a> { @@ -583,9 +599,25 @@ impl<'a> de::MapVisitor for MapDeserializer<'a> { } } -impl<'a> de::EnumVisitor for MapDeserializer<'a> { +impl<'a> de::VariantVisitor for MapDeserializer<'a> { type Error = Error; + fn visit_variant(&mut self) -> Result + where V: de::Deserialize, + { + de::Deserialize::deserialize(self.de) + } + + fn visit_unit(&mut self) -> Result<(), Error> { + Err(de::Error::syntax_error()) + } + + fn visit_seq(&mut self, _visitor: V) -> Result + where V: de::EnumSeqVisitor + { + Err(de::Error::syntax_error()) + } + fn visit_map(&mut self, mut visitor: V) -> Result where V: de::EnumMapVisitor, { diff --git a/tests/test_de.rs b/tests/test_de.rs index 0ae59097..8194a00c 100644 --- a/tests/test_de.rs +++ b/tests/test_de.rs @@ -44,7 +44,7 @@ enum Token<'a> { MapSep(bool), MapEnd, - EnumStart(&'a str, &'a str), + EnumStart(&'a str), EnumEnd, } @@ -132,11 +132,6 @@ impl<'a> Deserializer for TokenDeserializer<'a> { first: true, }) } - Some(Token::EnumStart(name, variant)) => { - visitor.visit_enum(name, variant, TokenDeserializerEnumVisitor { - de: self, - }) - } Some(_) => Err(Error::SyntaxError), None => Err(Error::EndOfStreamError), } @@ -144,7 +139,6 @@ impl<'a> Deserializer for TokenDeserializer<'a> { /// Hook into `Option` deserializing so we can treat `Unit` as a /// `None`, or a regular value as `Some(value)`. - #[inline] fn visit_option(&mut self, mut visitor: V) -> Result where V: Visitor, { @@ -165,6 +159,24 @@ impl<'a> Deserializer for TokenDeserializer<'a> { None => Err(Error::EndOfStreamError), } } + + fn visit_enum(&mut self, name: &str, mut visitor: V) -> Result + where V: de::EnumVisitor, + { + match self.tokens.next() { + Some(Token::EnumStart(n)) => { + if name == n { + visitor.visit(TokenDeserializerVariantVisitor { + de: self, + }) + } else { + Err(Error::SyntaxError) + } + } + Some(_) => Err(Error::SyntaxError), + None => Err(Error::EndOfStreamError), + } + } } ////////////////////////////////////////////////////////////////////////// @@ -263,13 +275,19 @@ impl<'a, 'b> de::MapVisitor for TokenDeserializerMapVisitor<'a, 'b> { ////////////////////////////////////////////////////////////////////////// -struct TokenDeserializerEnumVisitor<'a, 'b: 'a> { +struct TokenDeserializerVariantVisitor<'a, 'b: 'a> { de: &'a mut TokenDeserializer<'b>, } -impl<'a, 'b> de::EnumVisitor for TokenDeserializerEnumVisitor<'a, 'b> { +impl<'a, 'b> de::VariantVisitor for TokenDeserializerVariantVisitor<'a, 'b> { type Error = Error; + fn visit_kind(&mut self) -> Result + where V: de::Deserialize, + { + de::Deserialize::deserialize(self.de) + } + fn visit_unit(&mut self) -> Result<(), Error> { let value = try!(Deserialize::deserialize(self.de)); @@ -611,16 +629,18 @@ declare_tests! { Token::MapEnd, ], } - test_enum { + test_enum_unit { Enum::Unit => vec![ - Token::EnumStart("Enum", "Unit"), + Token::EnumStart("Enum"), + Token::Str("Unit"), Token::Unit, Token::EnumEnd, ], } test_enum_seq { Enum::Seq(1, 2, 3) => vec![ - Token::EnumStart("Enum", "Seq"), + Token::EnumStart("Enum"), + Token::Str("Seq"), Token::SeqStart(3), Token::SeqSep(true), Token::I32(1), @@ -636,7 +656,8 @@ declare_tests! { } test_enum_map { Enum::Map { a: 1, b: 2, c: 3 } => vec![ - Token::EnumStart("Enum", "Map"), + Token::EnumStart("Enum"), + Token::Str("Map"), Token::MapStart(3), Token::MapSep(true), Token::Str("a"), diff --git a/tests/test_json.rs b/tests/test_json.rs index 6fdc1a4c..28205ceb 100644 --- a/tests/test_json.rs +++ b/tests/test_json.rs @@ -907,7 +907,7 @@ fn test_parse_option() { #[test] fn test_parse_enum() { test_parse_err::(&[ - ("{}", Error::SyntaxError(ErrorCode::EOFWhileParsingString, 1, 3)), + ("{}", Error::SyntaxError(ErrorCode::ExpectedSomeValue, 1, 2)), ("{\"unknown\":[]}", Error::SyntaxError(ErrorCode::ExpectedSomeValue, 0, 0)), ("{\"Dog\":{}}", Error::SyntaxError(ErrorCode::ExpectedSomeValue, 0, 0)), ("{\"Frog\":{}}", Error::SyntaxError(ErrorCode::ExpectedSomeValue, 1, 9)),