From 58550787039661f495ee9de78b5cca17616bdeaf Mon Sep 17 00:00:00 2001 From: David Tolnay Date: Sat, 21 Jan 2017 14:38:33 -0800 Subject: [PATCH] Fix backward invalid_length messages in value deserializer --- serde/src/de/value.rs | 220 +++++++++++++++++++++++++----------------- 1 file changed, 132 insertions(+), 88 deletions(-) diff --git a/serde/src/de/value.rs b/serde/src/de/value.rs index dd3dfa7e..f34fb087 100644 --- a/serde/src/de/value.rs +++ b/serde/src/de/value.rs @@ -43,9 +43,10 @@ use std::error; use error; use core::fmt::{self, Display}; +use core::iter::{self, Iterator}; use core::marker::PhantomData; -use de::{self, SeqVisitor}; +use de::{self, Expected, SeqVisitor}; use bytes; /////////////////////////////////////////////////////////////////////////////// @@ -378,22 +379,37 @@ impl<'a, E> de::EnumVisitor for CowStrDeserializer<'a, E> /// A helper deserializer that deserializes a sequence. pub struct SeqDeserializer { - iter: I, - len: usize, + iter: iter::Fuse, + count: usize, marker: PhantomData, } impl SeqDeserializer - where E: de::Error, + where I: Iterator, + E: de::Error, { /// Construct a new `SeqDeserializer`. - pub fn new(iter: I, len: usize) -> Self { + pub fn new(iter: I) -> Self { SeqDeserializer { - iter: iter, - len: len, + iter: iter.fuse(), + count: 0, marker: PhantomData, } } + + fn end(&mut self) -> Result<(), E> { + let mut remaining = 0; + while self.iter.next().is_some() { + remaining += 1; + } + if remaining == 0 { + Ok(()) + } else { + // First argument is the number of elements in the data, second + // argument is the number of elements expected by the Deserialize. + Err(de::Error::invalid_length(self.count + remaining, &ExpectedInSeq(self.count))) + } + } } impl de::Deserializer for SeqDeserializer @@ -407,11 +423,8 @@ impl de::Deserializer for SeqDeserializer where V: de::Visitor, { let v = try!(visitor.visit_seq(&mut self)); - if self.len == 0 { - Ok(v) - } else { - Err(de::Error::invalid_length(self.len, &"TODO")) - } + try!(self.end()); + Ok(v) } forward_to_deserialize! { @@ -433,7 +446,7 @@ impl de::SeqVisitor for SeqDeserializer { match self.iter.next() { Some(value) => { - self.len -= 1; + self.count += 1; seed.deserialize(value.into_deserializer()).map(Some) } None => Ok(None), @@ -441,7 +454,19 @@ impl de::SeqVisitor for SeqDeserializer } fn size_hint(&self) -> (usize, Option) { - (self.len, Some(self.len)) + self.iter.size_hint() + } +} + +struct ExpectedInSeq(usize); + +impl Expected for ExpectedInSeq { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + if self.0 == 1 { + write!(formatter, "1 element in sequence") + } else { + write!(formatter, "{} elements in sequence", self.0) + } } } @@ -455,8 +480,7 @@ impl ValueDeserializer for Vec type Deserializer = SeqDeserializer, E>; fn into_deserializer(self) -> Self::Deserializer { - let len = self.len(); - SeqDeserializer::new(self.into_iter(), len) + SeqDeserializer::new(self.into_iter()) } } @@ -468,8 +492,7 @@ impl ValueDeserializer for BTreeSet type Deserializer = SeqDeserializer, E>; fn into_deserializer(self) -> Self::Deserializer { - let len = self.len(); - SeqDeserializer::new(self.into_iter(), len) + SeqDeserializer::new(self.into_iter()) } } @@ -481,8 +504,7 @@ impl ValueDeserializer for HashSet type Deserializer = SeqDeserializer, E>; fn into_deserializer(self) -> Self::Deserializer { - let len = self.len(); - SeqDeserializer::new(self.into_iter(), len) + SeqDeserializer::new(self.into_iter()) } } @@ -527,66 +549,66 @@ impl de::Deserializer for SeqVisitorDeserializer /////////////////////////////////////////////////////////////////////////////// /// A helper deserializer that deserializes a map. -pub struct MapDeserializer - where I: Iterator, - K: ValueDeserializer, - V: ValueDeserializer, +pub struct MapDeserializer + where I: Iterator, + I::Item: private::Pair, + ::First: ValueDeserializer, + ::Second: ValueDeserializer, E: de::Error, { - iter: I, - value: Option, - len: Option, + iter: iter::Fuse, + value: Option<::Second>, + count: usize, marker: PhantomData, } -impl MapDeserializer - where I: Iterator, - K: ValueDeserializer, - V: ValueDeserializer, +impl MapDeserializer + where I: Iterator, + I::Item: private::Pair, + ::First: ValueDeserializer, + ::Second: ValueDeserializer, E: de::Error, { - /// Construct a new `MapDeserializer` with a specific length. - pub fn new(iter: I, len: usize) -> Self { + /// Construct a new `MapDeserializer`. + pub fn new(iter: I) -> Self { MapDeserializer { - iter: iter, + iter: iter.fuse(), value: None, - len: Some(len), + count: 0, marker: PhantomData, } } - /// Construct a new `MapDeserializer` that is not bounded - /// by a specific length and that delegates to `iter` for its size hint. - pub fn unbounded(iter: I) -> Self { - MapDeserializer { - iter: iter, - value: None, - len: None, - marker: PhantomData, - } - } - - fn next(&mut self) -> Option<(K, V)> { - self.iter.next().map(|(k, v)| { - if let Some(len) = self.len.as_mut() { - *len -= 1; + fn next(&mut self) -> Option<(::First, ::Second)> { + match self.iter.next() { + Some(kv) => { + self.count += 1; + Some(private::Pair::split(kv)) } - (k, v) - }) + None => None, + } } fn end(&mut self) -> Result<(), E> { - match self.len { - Some(len) if len > 0 => Err(de::Error::invalid_length(len, &"TODO")), - _ => Ok(()) + let mut remaining = 0; + while self.iter.next().is_some() { + remaining += 1; + } + if remaining == 0 { + Ok(()) + } else { + // First argument is the number of elements in the data, second + // argument is the number of elements expected by the Deserialize. + Err(de::Error::invalid_length(self.count + remaining, &ExpectedInMap(self.count))) } } } -impl de::Deserializer for MapDeserializer - where I: Iterator, - K: ValueDeserializer, - V: ValueDeserializer, +impl de::Deserializer for MapDeserializer + where I: Iterator, + I::Item: private::Pair, + ::First: ValueDeserializer, + ::Second: ValueDeserializer, E: de::Error, { type Error = E; @@ -607,17 +629,10 @@ impl de::Deserializer for MapDeserializer Ok(value) } - fn deserialize_seq_fixed_size(mut self, len: usize, visitor: V_) -> Result + fn deserialize_seq_fixed_size(self, _len: usize, visitor: V_) -> Result where V_: de::Visitor, { - match self.len { - Some(map_len) if map_len != len => Err(de::Error::invalid_length(len, &"TODO")), - _ => { - let value = try!(visitor.visit_seq(&mut self)); - try!(self.end()); - Ok(value) - } - } + self.deserialize_seq(visitor) } forward_to_deserialize! { @@ -627,10 +642,11 @@ impl de::Deserializer for MapDeserializer } } -impl de::MapVisitor for MapDeserializer - where I: Iterator, - K: ValueDeserializer, - V: ValueDeserializer, +impl de::MapVisitor for MapDeserializer + where I: Iterator, + I::Item: private::Pair, + ::First: ValueDeserializer, + ::Second: ValueDeserializer, E: de::Error, { type Error = E; @@ -672,16 +688,15 @@ impl de::MapVisitor for MapDeserializer } fn size_hint(&self) -> (usize, Option) { - self.len.map_or_else( - || self.iter.size_hint(), - |len| (len, Some(len))) + self.iter.size_hint() } } -impl de::SeqVisitor for MapDeserializer - where I: Iterator, - K: ValueDeserializer, - V: ValueDeserializer, +impl de::SeqVisitor for MapDeserializer + where I: Iterator, + I::Item: private::Pair, + ::First: ValueDeserializer, + ::Second: ValueDeserializer, E: de::Error, { type Error = E; @@ -699,7 +714,7 @@ impl de::SeqVisitor for MapDeserializer } fn size_hint(&self) -> (usize, Option) { - de::MapVisitor::size_hint(self) + self.iter.size_hint() } } @@ -734,7 +749,10 @@ impl de::Deserializer for PairDeserializer if pair_visitor.1.is_none() { Ok(pair) } else { - Err(de::Error::invalid_length(pair_visitor.size_hint().0, &"TODO")) + let remaining = pair_visitor.size_hint().0; + // First argument is the number of elements in the data, second + // argument is the number of elements expected by the Deserialize. + Err(de::Error::invalid_length(2, &ExpectedInSeq(2 - remaining))) } } @@ -744,7 +762,9 @@ impl de::Deserializer for PairDeserializer if len == 2 { self.deserialize_seq(visitor) } else { - Err(de::Error::invalid_length(len, &"TODO")) + // First argument is the number of elements in the data, second + // argument is the number of elements expected by the Deserialize. + Err(de::Error::invalid_length(2, &ExpectedInSeq(len))) } } } @@ -782,6 +802,18 @@ impl de::SeqVisitor for PairVisitor } } +struct ExpectedInMap(usize); + +impl Expected for ExpectedInMap { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + if self.0 == 1 { + write!(formatter, "1 element in map") + } else { + write!(formatter, "{} elements in map", self.0) + } + } +} + /////////////////////////////////////////////////////////////////////////////// #[cfg(any(feature = "std", feature = "collections"))] @@ -790,11 +822,10 @@ impl ValueDeserializer for BTreeMap V: ValueDeserializer, E: de::Error, { - type Deserializer = MapDeserializer, K, V, E>; + type Deserializer = MapDeserializer, E>; fn into_deserializer(self) -> Self::Deserializer { - let len = self.len(); - MapDeserializer::new(self.into_iter(), len) + MapDeserializer::new(self.into_iter()) } } @@ -804,11 +835,10 @@ impl ValueDeserializer for HashMap V: ValueDeserializer, E: de::Error, { - type Deserializer = MapDeserializer, K, V, E>; + type Deserializer = MapDeserializer, E>; fn into_deserializer(self) -> Self::Deserializer { - let len = self.len(); - MapDeserializer::new(self.into_iter(), len) + MapDeserializer::new(self.into_iter()) } } @@ -962,4 +992,18 @@ mod private { Err(de::Error::invalid_type(Unexpected::UnitVariant, &"struct variant")) } } + + /// Avoid having to restate the generic types on MapDeserializer. The + /// Iterator::Item contains enough information to figure out K and V. + pub trait Pair { + type First; + type Second; + fn split(self) -> (Self::First, Self::Second); + } + + impl Pair for (A, B) { + type First = A; + type Second = B; + fn split(self) -> (A, B) { self } + } }