diff --git a/serde/src/de/mod.rs b/serde/src/de/mod.rs index 5484a7b7..2f3c6a30 100644 --- a/serde/src/de/mod.rs +++ b/serde/src/de/mod.rs @@ -571,6 +571,11 @@ pub trait VariantVisitor { Err(Error::syntax_error()) } + /// `visit_simple` is called when deserializing a variant with a single value. + fn visit_simple(&mut self) -> Result { + Err(Error::syntax_error()) + } + /// `visit_seq` is called when deserializing a tuple-like variant. fn visit_seq(&mut self, _visitor: V) -> Result where V: Visitor @@ -601,6 +606,10 @@ impl<'a, T> VariantVisitor for &'a mut T where T: VariantVisitor { (**self).visit_unit() } + fn visit_simple(&mut self) -> Result { + (**self).visit_simple() + } + fn visit_seq(&mut self, visitor: V) -> Result where V: Visitor, { diff --git a/serde_codegen/src/de.rs b/serde_codegen/src/de.rs index 3f949df0..90f7d7ff 100644 --- a/serde_codegen/src/de.rs +++ b/serde_codegen/src/de.rs @@ -555,6 +555,12 @@ fn deserialize_variant( Ok($type_ident::$variant_ident) }) } + ast::TupleVariantKind(ref args) if args.len() == 1 => { + quote_expr!(cx, { + let val = try!(visitor.visit_simple()); + Ok($type_ident::$variant_ident(val)) + }) + } ast::TupleVariantKind(ref args) => { deserialize_tuple_variant( cx, diff --git a/serde_json/src/de.rs b/serde_json/src/de.rs index 8bb33931..efdd53df 100644 --- a/serde_json/src/de.rs +++ b/serde_json/src/de.rs @@ -632,20 +632,22 @@ impl de::VariantVisitor for Deserializer fn visit_variant(&mut self) -> Result where V: de::Deserialize { - de::Deserialize::deserialize(self) + let val = try!(de::Deserialize::deserialize(self)); + try!(self.parse_object_colon()); + Ok(val) } fn visit_unit(&mut self) -> Result<(), Error> { - try!(self.parse_object_colon()); + de::Deserialize::deserialize(self) + } + fn visit_simple(&mut self) -> Result { de::Deserialize::deserialize(self) } fn visit_seq(&mut self, visitor: V) -> Result where V: de::Visitor, { - try!(self.parse_object_colon()); - de::Deserializer::visit(self, visitor) } @@ -654,8 +656,6 @@ impl de::VariantVisitor for Deserializer visitor: V) -> Result where V: de::Visitor, { - try!(self.parse_object_colon()); - de::Deserializer::visit(self, visitor) } } diff --git a/serde_json/src/value.rs b/serde_json/src/value.rs index 45cb7785..2cee4427 100644 --- a/serde_json/src/value.rs +++ b/serde_json/src/value.rs @@ -687,33 +687,19 @@ impl de::Deserializer for Deserializer { let mut iter = value.into_iter(); - let value = match iter.next() { - Some((variant, Value::Array(fields))) => { - self.value = Some(Value::String(variant)); - - let len = fields.len(); - try!(visitor.visit(SeqDeserializer { - de: self, - iter: fields.into_iter(), - len: len, - })) - } - Some((variant, Value::Object(fields))) => { - let len = fields.len(); - try!(visitor.visit(MapDeserializer { - de: self, - iter: fields.into_iter(), - value: Some(Value::String(variant)), - len: len, - })) - } - Some(_) => { return Err(de::Error::syntax_error()); } - None => { return Err(de::Error::syntax_error()); } + let (variant, value) = match iter.next() { + Some(v) => v, + None => return Err(de::Error::syntax_error()), }; + // enums are encoded in json as maps with a single key:value pair match iter.next() { Some(_) => Err(de::Error::syntax_error()), - None => Ok(value) + None => visitor.visit(VariantDeserializer { + de: self, + val: Some(value), + variant: Some(Value::String(variant)), + }), } } @@ -723,6 +709,67 @@ impl de::Deserializer for Deserializer { } } +struct VariantDeserializer<'a> { + de: &'a mut Deserializer, + val: Option, + variant: Option, +} + +impl<'a> de::VariantVisitor for VariantDeserializer<'a> { + type Error = Error; + + fn visit_variant(&mut self) -> Result + where V: de::Deserialize, + { + de::Deserialize::deserialize(&mut Deserializer::new(self.variant.take().unwrap())) + } + + fn visit_unit(&mut self) -> Result<(), Error> + { + de::Deserialize::deserialize(&mut Deserializer::new(self.val.take().unwrap())) + } + + fn visit_simple(&mut self) -> Result + { + de::Deserialize::deserialize(&mut Deserializer::new(self.val.take().unwrap())) + } + + fn visit_seq(&mut self, visitor: V) -> Result + where V: de::Visitor, + { + if let Value::Array(fields) = self.val.take().unwrap() { + de::Deserializer::visit( + &mut SeqDeserializer { + de: self.de, + len: fields.len(), + iter: fields.into_iter(), + }, + visitor, + ) + } else { + Err(de::Error::syntax_error()) + } + } + + fn visit_map(&mut self, visitor: V) -> Result + where V: de::Visitor, + { + if let Value::Object(fields) = self.val.take().unwrap() { + de::Deserializer::visit( + &mut MapDeserializer { + de: self.de, + len: fields.len(), + iter: fields.into_iter(), + value: None, + }, + visitor, + ) + } else { + Err(de::Error::syntax_error()) + } + } +} + struct SeqDeserializer<'a> { de: &'a mut Deserializer, iter: vec::IntoIter, @@ -773,35 +820,6 @@ impl<'a> de::SeqVisitor for SeqDeserializer<'a> { } } -impl<'a> de::VariantVisitor for SeqDeserializer<'a> { - type Error = Error; - - fn visit_variant(&mut self) -> Result - where V: de::Deserialize, - { - de::Deserialize::deserialize(self.de) - } - - fn visit_unit(&mut self) -> Result<(), Error> - { - de::Deserialize::deserialize(self) - } - - fn visit_seq(&mut self, visitor: V) -> Result - where V: de::Visitor, - { - de::Deserializer::visit(self, visitor) - } - - fn visit_map(&mut self, - _fields: &'static [&'static str], - visitor: V) -> Result - where V: de::Visitor, - { - de::Deserializer::visit(self, visitor) - } -} - struct MapDeserializer<'a> { de: &'a mut Deserializer, iter: btree_map::IntoIter, @@ -884,35 +902,6 @@ impl<'a> de::Deserializer for MapDeserializer<'a> { } } -impl<'a> de::VariantVisitor for MapDeserializer<'a> { - type Error = Error; - - fn visit_variant(&mut self) -> Result - where V: de::Deserialize, - { - self.de.value = self.value.take(); - de::Deserialize::deserialize(self.de) - } - - fn visit_unit(&mut self) -> Result<(), Error> { - de::Deserialize::deserialize(self) - } - - fn visit_seq(&mut self, visitor: V) -> Result - where V: de::Visitor, - { - de::Deserializer::visit(self, visitor) - } - - fn visit_map(&mut self, - _fields: &'static [&'static str], - visitor: V) -> Result - where V: de::Visitor, - { - de::Deserializer::visit(self, visitor) - } -} - /// Shortcut function to encode a `T` into a JSON `Value` pub fn to_value(value: &T) -> Value where T: ser::Serialize