From c1259fbc87379ae65605a5c2a09efb355e2349e4 Mon Sep 17 00:00:00 2001 From: David Tolnay Date: Wed, 10 May 2017 19:56:05 -0700 Subject: [PATCH] Support borrowing within internally tagged enum --- serde/src/private/de.rs | 284 +++++++++++++++++++------------- test_suite/tests/test_macros.rs | 24 +++ 2 files changed, 195 insertions(+), 113 deletions(-) diff --git a/serde/src/private/de.rs b/serde/src/private/de.rs index 45b383fb..2c5648cd 100644 --- a/serde/src/private/de.rs +++ b/serde/src/private/de.rs @@ -233,7 +233,7 @@ mod content { /// /// Not public API. Use serde-value instead. #[derive(Debug)] - pub enum Content { + pub enum Content<'de> { Bool(bool), U8(u8), @@ -251,18 +251,20 @@ mod content { Char(char), String(String), - Bytes(Vec), + Str(&'de str), + ByteBuf(Vec), + Bytes(&'de [u8]), None, - Some(Box), + Some(Box>), Unit, - Newtype(Box), - Seq(Vec), - Map(Vec<(Content, Content)>), + Newtype(Box>), + Seq(Vec>), + Map(Vec<(Content<'de>, Content<'de>)>), } - impl Content { + impl<'de> Content<'de> { fn unexpected(&self) -> Unexpected { match *self { Content::Bool(b) => Unexpected::Bool(b), @@ -278,7 +280,9 @@ mod content { Content::F64(f) => Unexpected::Float(f), Content::Char(c) => Unexpected::Char(c), Content::String(ref s) => Unexpected::Str(s), - Content::Bytes(ref b) => Unexpected::Bytes(b), + Content::Str(s) => Unexpected::Str(s), + Content::ByteBuf(ref b) => Unexpected::Bytes(b), + Content::Bytes(b) => Unexpected::Bytes(b), Content::None | Content::Some(_) => Unexpected::Option, Content::Unit => Unexpected::Unit, Content::Newtype(_) => Unexpected::NewtypeStruct, @@ -288,21 +292,30 @@ mod content { } } - impl<'de> Deserialize<'de> for Content { + impl<'de> Deserialize<'de> for Content<'de> { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { // Untagged and internally tagged enums are only supported in // self-describing formats. - deserializer.deserialize_any(ContentVisitor) + let visitor = ContentVisitor { value: PhantomData }; + deserializer.deserialize_any(visitor) } } - struct ContentVisitor; + struct ContentVisitor<'de> { + value: PhantomData>, + } - impl<'de> Visitor<'de> for ContentVisitor { - type Value = Content; + impl<'de> ContentVisitor<'de> { + fn new() -> Self { + ContentVisitor { value: PhantomData } + } + } + + impl<'de> Visitor<'de> for ContentVisitor<'de> { + type Value = Content<'de>; fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result { fmt.write_str("any value") @@ -399,6 +412,13 @@ mod content { Ok(Content::String(value.into())) } + fn visit_borrowed_str(self, value: &'de str) -> Result + where + F: de::Error, + { + Ok(Content::Str(value)) + } + fn visit_string(self, value: String) -> Result where F: de::Error, @@ -410,14 +430,21 @@ mod content { where F: de::Error, { - Ok(Content::Bytes(value.into())) + Ok(Content::ByteBuf(value.into())) + } + + fn visit_borrowed_bytes(self, value: &'de [u8]) -> Result + where + F: de::Error, + { + Ok(Content::Bytes(value)) } fn visit_byte_buf(self, value: Vec) -> Result where F: de::Error, { - Ok(Content::Bytes(value)) + Ok(Content::ByteBuf(value)) } fn visit_unit(self) -> Result @@ -481,23 +508,24 @@ mod content { /// This is the type of the map keys in an internally tagged enum. /// /// Not public API. - pub enum TagOrContent { + pub enum TagOrContent<'de> { Tag, - Content(Content), + Content(Content<'de>), } - struct TagOrContentVisitor { + struct TagOrContentVisitor<'de> { name: &'static str, + value: PhantomData>, } - impl TagOrContentVisitor { + impl<'de> TagOrContentVisitor<'de> { fn new(name: &'static str) -> Self { - TagOrContentVisitor { name: name } + TagOrContentVisitor { name: name, value: PhantomData } } } - impl<'de> DeserializeSeed<'de> for TagOrContentVisitor { - type Value = TagOrContent; + impl<'de> DeserializeSeed<'de> for TagOrContentVisitor<'de> { + type Value = TagOrContent<'de>; fn deserialize(self, deserializer: D) -> Result where @@ -509,8 +537,8 @@ mod content { } } - impl<'de> Visitor<'de> for TagOrContentVisitor { - type Value = TagOrContent; + impl<'de> Visitor<'de> for TagOrContentVisitor<'de> { + type Value = TagOrContent<'de>; fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result { write!(fmt, "a type tag `{}` or any other value", self.name) @@ -520,7 +548,7 @@ mod content { where F: de::Error, { - ContentVisitor + ContentVisitor::new() .visit_bool(value) .map(TagOrContent::Content) } @@ -529,14 +557,14 @@ mod content { where F: de::Error, { - ContentVisitor.visit_i8(value).map(TagOrContent::Content) + ContentVisitor::new().visit_i8(value).map(TagOrContent::Content) } fn visit_i16(self, value: i16) -> Result where F: de::Error, { - ContentVisitor + ContentVisitor::new() .visit_i16(value) .map(TagOrContent::Content) } @@ -545,7 +573,7 @@ mod content { where F: de::Error, { - ContentVisitor + ContentVisitor::new() .visit_i32(value) .map(TagOrContent::Content) } @@ -554,7 +582,7 @@ mod content { where F: de::Error, { - ContentVisitor + ContentVisitor::new() .visit_i64(value) .map(TagOrContent::Content) } @@ -563,14 +591,14 @@ mod content { where F: de::Error, { - ContentVisitor.visit_u8(value).map(TagOrContent::Content) + ContentVisitor::new().visit_u8(value).map(TagOrContent::Content) } fn visit_u16(self, value: u16) -> Result where F: de::Error, { - ContentVisitor + ContentVisitor::new() .visit_u16(value) .map(TagOrContent::Content) } @@ -579,7 +607,7 @@ mod content { where F: de::Error, { - ContentVisitor + ContentVisitor::new() .visit_u32(value) .map(TagOrContent::Content) } @@ -588,7 +616,7 @@ mod content { where F: de::Error, { - ContentVisitor + ContentVisitor::new() .visit_u64(value) .map(TagOrContent::Content) } @@ -597,7 +625,7 @@ mod content { where F: de::Error, { - ContentVisitor + ContentVisitor::new() .visit_f32(value) .map(TagOrContent::Content) } @@ -606,7 +634,7 @@ mod content { where F: de::Error, { - ContentVisitor + ContentVisitor::new() .visit_f64(value) .map(TagOrContent::Content) } @@ -615,7 +643,7 @@ mod content { where F: de::Error, { - ContentVisitor + ContentVisitor::new() .visit_char(value) .map(TagOrContent::Content) } @@ -627,12 +655,25 @@ mod content { if value == self.name { Ok(TagOrContent::Tag) } else { - ContentVisitor + ContentVisitor::new() .visit_str(value) .map(TagOrContent::Content) } } + fn visit_borrowed_str(self, value: &'de str) -> Result + where + F: de::Error, + { + if value == self.name { + Ok(TagOrContent::Tag) + } else { + ContentVisitor::new() + .visit_borrowed_str(value) + .map(TagOrContent::Content) + } + } + fn visit_string(self, value: String) -> Result where F: de::Error, @@ -640,7 +681,7 @@ mod content { if value == self.name { Ok(TagOrContent::Tag) } else { - ContentVisitor + ContentVisitor::new() .visit_string(value) .map(TagOrContent::Content) } @@ -653,12 +694,25 @@ mod content { if value == self.name.as_bytes() { Ok(TagOrContent::Tag) } else { - ContentVisitor + ContentVisitor::new() .visit_bytes(value) .map(TagOrContent::Content) } } + fn visit_borrowed_bytes(self, value: &'de [u8]) -> Result + where + F: de::Error, + { + if value == self.name.as_bytes() { + Ok(TagOrContent::Tag) + } else { + ContentVisitor::new() + .visit_borrowed_bytes(value) + .map(TagOrContent::Content) + } + } + fn visit_byte_buf(self, value: Vec) -> Result where F: de::Error, @@ -666,7 +720,7 @@ mod content { if value == self.name.as_bytes() { Ok(TagOrContent::Tag) } else { - ContentVisitor + ContentVisitor::new() .visit_byte_buf(value) .map(TagOrContent::Content) } @@ -676,21 +730,21 @@ mod content { where F: de::Error, { - ContentVisitor.visit_unit().map(TagOrContent::Content) + ContentVisitor::new().visit_unit().map(TagOrContent::Content) } fn visit_none(self) -> Result where F: de::Error, { - ContentVisitor.visit_none().map(TagOrContent::Content) + ContentVisitor::new().visit_none().map(TagOrContent::Content) } fn visit_some(self, deserializer: D) -> Result where D: Deserializer<'de>, { - ContentVisitor + ContentVisitor::new() .visit_some(deserializer) .map(TagOrContent::Content) } @@ -699,7 +753,7 @@ mod content { where D: Deserializer<'de>, { - ContentVisitor + ContentVisitor::new() .visit_newtype_struct(deserializer) .map(TagOrContent::Content) } @@ -708,7 +762,7 @@ mod content { where V: SeqAccess<'de>, { - ContentVisitor + ContentVisitor::new() .visit_seq(visitor) .map(TagOrContent::Content) } @@ -717,7 +771,7 @@ mod content { where V: MapAccess<'de>, { - ContentVisitor + ContentVisitor::new() .visit_map(visitor) .map(TagOrContent::Content) } @@ -726,7 +780,7 @@ mod content { where V: EnumAccess<'de>, { - ContentVisitor + ContentVisitor::new() .visit_enum(visitor) .map(TagOrContent::Content) } @@ -735,33 +789,33 @@ mod content { /// Used by generated code to deserialize an internally tagged enum. /// /// Not public API. - pub struct TaggedContent { + pub struct TaggedContent<'de, T> { pub tag: T, - pub content: Content, + pub content: Content<'de>, } /// Not public API. - pub struct TaggedContentVisitor { + pub struct TaggedContentVisitor<'de, T> { tag_name: &'static str, - tag: PhantomData, + value: PhantomData>, } - impl TaggedContentVisitor { + impl<'de, T> TaggedContentVisitor<'de, T> { /// Visitor for the content of an internally tagged enum with the given tag /// name. pub fn new(name: &'static str) -> Self { TaggedContentVisitor { tag_name: name, - tag: PhantomData, + value: PhantomData, } } } - impl<'de, T> DeserializeSeed<'de> for TaggedContentVisitor + impl<'de, T> DeserializeSeed<'de> for TaggedContentVisitor<'de, T> where T: Deserialize<'de>, { - type Value = TaggedContent; + type Value = TaggedContent<'de, T>; fn deserialize(self, deserializer: D) -> Result where @@ -773,11 +827,11 @@ mod content { } } - impl<'de, T> Visitor<'de> for TaggedContentVisitor + impl<'de, T> Visitor<'de> for TaggedContentVisitor<'de, T> where T: Deserialize<'de>, { - type Value = TaggedContent; + type Value = TaggedContent<'de, T>; fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result { fmt.write_str("any value") @@ -913,14 +967,14 @@ mod content { } /// Not public API - pub struct ContentDeserializer { - content: Content, + pub struct ContentDeserializer<'de, E> { + content: Content<'de>, err: PhantomData, } /// Used when deserializing an internally tagged enum because the content will /// be used exactly once. - impl<'de, E> Deserializer<'de> for ContentDeserializer + impl<'de, E> Deserializer<'de> for ContentDeserializer<'de, E> where E: de::Error, { @@ -944,6 +998,9 @@ mod content { Content::F64(v) => visitor.visit_f64(v), Content::Char(v) => visitor.visit_char(v), Content::String(v) => visitor.visit_string(v), + Content::Str(v) => visitor.visit_borrowed_str(v), + Content::ByteBuf(v) => visitor.visit_byte_buf(v), + Content::Bytes(v) => visitor.visit_borrowed_bytes(v), Content::Unit => visitor.visit_unit(), Content::None => visitor.visit_none(), Content::Some(v) => visitor.visit_some(ContentDeserializer::new(*v)), @@ -965,7 +1022,6 @@ mod content { try!(map_visitor.end()); Ok(value) } - Content::Bytes(v) => visitor.visit_byte_buf(v), } } @@ -1048,9 +1104,9 @@ mod content { } } - impl ContentDeserializer { + impl<'de, E> ContentDeserializer<'de, E> { /// private API, don't use - pub fn new(content: Content) -> Self { + pub fn new(content: Content<'de>) -> Self { ContentDeserializer { content: content, err: PhantomData, @@ -1058,21 +1114,21 @@ mod content { } } - struct EnumDeserializer + struct EnumDeserializer<'de, E> where E: de::Error, { - variant: Content, - value: Option, + variant: Content<'de>, + value: Option>, err: PhantomData, } - impl<'de, E> de::EnumAccess<'de> for EnumDeserializer + impl<'de, E> de::EnumAccess<'de> for EnumDeserializer<'de, E> where E: de::Error, { type Error = E; - type Variant = VariantDeserializer; + type Variant = VariantDeserializer<'de, Self::Error>; fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant), E> where @@ -1087,15 +1143,15 @@ mod content { } } - struct VariantDeserializer + struct VariantDeserializer<'de, E> where E: de::Error, { - value: Option, + value: Option>, err: PhantomData, } - impl<'de, E> de::VariantAccess<'de> for VariantDeserializer + impl<'de, E> de::VariantAccess<'de> for VariantDeserializer<'de, E> where E: de::Error, { @@ -1151,19 +1207,19 @@ mod content { } } - struct SeqDeserializer + struct SeqDeserializer<'de, E> where E: de::Error, { - iter: as IntoIterator>::IntoIter, + iter: > as IntoIterator>::IntoIter, err: PhantomData, } - impl SeqDeserializer + impl<'de, E> SeqDeserializer<'de, E> where E: de::Error, { - fn new(vec: Vec) -> Self { + fn new(vec: Vec>) -> Self { SeqDeserializer { iter: vec.into_iter(), err: PhantomData, @@ -1171,7 +1227,7 @@ mod content { } } - impl<'de, E> de::Deserializer<'de> for SeqDeserializer + impl<'de, E> de::Deserializer<'de> for SeqDeserializer<'de, E> where E: de::Error, { @@ -1203,7 +1259,7 @@ mod content { } } - impl<'de, E> de::SeqAccess<'de> for SeqDeserializer + impl<'de, E> de::SeqAccess<'de> for SeqDeserializer<'de, E> where E: de::Error, { @@ -1227,20 +1283,20 @@ mod content { } } - struct MapDeserializer + struct MapDeserializer<'de, E> where E: de::Error, { - iter: as IntoIterator>::IntoIter, - value: Option, + iter: , Content<'de>)> as IntoIterator>::IntoIter, + value: Option>, err: PhantomData, } - impl MapDeserializer + impl<'de, E> MapDeserializer<'de, E> where E: de::Error, { - fn new(map: Vec<(Content, Content)>) -> Self { + fn new(map: Vec<(Content<'de>, Content<'de>)>) -> Self { MapDeserializer { iter: map.into_iter(), value: None, @@ -1249,7 +1305,7 @@ mod content { } } - impl<'de, E> de::MapAccess<'de> for MapDeserializer + impl<'de, E> de::MapAccess<'de> for MapDeserializer<'de, E> where E: de::Error, { @@ -1283,7 +1339,7 @@ mod content { } } - impl<'de, E> de::Deserializer<'de> for MapDeserializer + impl<'de, E> de::Deserializer<'de> for MapDeserializer<'de, E> where E: de::Error, { @@ -1305,14 +1361,14 @@ mod content { } /// Not public API. - pub struct ContentRefDeserializer<'a, E> { - content: &'a Content, + pub struct ContentRefDeserializer<'a, 'de: 'a, E> { + content: &'a Content<'de>, err: PhantomData, } /// Used when deserializing an untagged enum because the content may need to be /// used more than once. - impl<'de, 'a, E> Deserializer<'de> for ContentRefDeserializer<'a, E> + impl<'de, 'a, E> Deserializer<'de> for ContentRefDeserializer<'a, 'de, E> where E: de::Error, { @@ -1336,6 +1392,9 @@ mod content { Content::F64(v) => visitor.visit_f64(v), Content::Char(v) => visitor.visit_char(v), Content::String(ref v) => visitor.visit_str(v), + Content::Str(v) => visitor.visit_borrowed_str(v), + Content::ByteBuf(ref v) => visitor.visit_bytes(v), + Content::Bytes(v) => visitor.visit_borrowed_bytes(v), Content::Unit => visitor.visit_unit(), Content::None => visitor.visit_none(), Content::Some(ref v) => visitor.visit_some(ContentRefDeserializer::new(v)), @@ -1361,7 +1420,6 @@ mod content { try!(map_visitor.end()); Ok(value) } - Content::Bytes(ref v) => visitor.visit_bytes(v), } } @@ -1440,9 +1498,9 @@ mod content { } } - impl<'a, E> ContentRefDeserializer<'a, E> { + impl<'a, 'de, E> ContentRefDeserializer<'a, 'de, E> { /// private API, don't use - pub fn new(content: &'a Content) -> Self { + pub fn new(content: &'a Content<'de>) -> Self { ContentRefDeserializer { content: content, err: PhantomData, @@ -1450,21 +1508,21 @@ mod content { } } - struct EnumRefDeserializer<'a, E> + struct EnumRefDeserializer<'a, 'de: 'a, E> where E: de::Error, { - variant: &'a Content, - value: Option<&'a Content>, + variant: &'a Content<'de>, + value: Option<&'a Content<'de>>, err: PhantomData, } - impl<'de, 'a, E> de::EnumAccess<'de> for EnumRefDeserializer<'a, E> + impl<'de, 'a, E> de::EnumAccess<'de> for EnumRefDeserializer<'a, 'de, E> where E: de::Error, { type Error = E; - type Variant = VariantRefDeserializer<'a, Self::Error>; + type Variant = VariantRefDeserializer<'a, 'de, Self::Error>; fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> where @@ -1479,15 +1537,15 @@ mod content { } } - struct VariantRefDeserializer<'a, E> + struct VariantRefDeserializer<'a, 'de: 'a, E> where E: de::Error, { - value: Option<&'a Content>, + value: Option<&'a Content<'de>>, err: PhantomData, } - impl<'de, 'a, E> de::VariantAccess<'de> for VariantRefDeserializer<'a, E> + impl<'de, 'a, E> de::VariantAccess<'de> for VariantRefDeserializer<'a, 'de, E> where E: de::Error, { @@ -1543,19 +1601,19 @@ mod content { } } - struct SeqRefDeserializer<'a, E> + struct SeqRefDeserializer<'a, 'de: 'a, E> where E: de::Error, { - iter: <&'a [Content] as IntoIterator>::IntoIter, + iter: <&'a [Content<'de>] as IntoIterator>::IntoIter, err: PhantomData, } - impl<'a, E> SeqRefDeserializer<'a, E> + impl<'a, 'de, E> SeqRefDeserializer<'a, 'de, E> where E: de::Error, { - fn new(vec: &'a [Content]) -> Self { + fn new(vec: &'a [Content<'de>]) -> Self { SeqRefDeserializer { iter: vec.into_iter(), err: PhantomData, @@ -1563,7 +1621,7 @@ mod content { } } - impl<'de, 'a, E> de::Deserializer<'de> for SeqRefDeserializer<'a, E> + impl<'de, 'a, E> de::Deserializer<'de> for SeqRefDeserializer<'a, 'de, E> where E: de::Error, { @@ -1595,7 +1653,7 @@ mod content { } } - impl<'de, 'a, E> de::SeqAccess<'de> for SeqRefDeserializer<'a, E> + impl<'de, 'a, E> de::SeqAccess<'de> for SeqRefDeserializer<'a, 'de, E> where E: de::Error, { @@ -1619,20 +1677,20 @@ mod content { } } - struct MapRefDeserializer<'a, E> + struct MapRefDeserializer<'a, 'de: 'a, E> where E: de::Error, { - iter: <&'a [(Content, Content)] as IntoIterator>::IntoIter, - value: Option<&'a Content>, + iter: <&'a [(Content<'de>, Content<'de>)] as IntoIterator>::IntoIter, + value: Option<&'a Content<'de>>, err: PhantomData, } - impl<'a, E> MapRefDeserializer<'a, E> + impl<'a, 'de, E> MapRefDeserializer<'a, 'de, E> where E: de::Error, { - fn new(map: &'a [(Content, Content)]) -> Self { + fn new(map: &'a [(Content<'de>, Content<'de>)]) -> Self { MapRefDeserializer { iter: map.into_iter(), value: None, @@ -1641,7 +1699,7 @@ mod content { } } - impl<'de, 'a, E> de::MapAccess<'de> for MapRefDeserializer<'a, E> + impl<'de, 'a, E> de::MapAccess<'de> for MapRefDeserializer<'a, 'de, E> where E: de::Error, { @@ -1676,7 +1734,7 @@ mod content { } } - impl<'de, 'a, E> de::Deserializer<'de> for MapRefDeserializer<'a, E> + impl<'de, 'a, E> de::Deserializer<'de> for MapRefDeserializer<'a, 'de, E> where E: de::Error, { @@ -1697,7 +1755,7 @@ mod content { } } - impl<'de, E> de::IntoDeserializer<'de, E> for ContentDeserializer + impl<'de, E> de::IntoDeserializer<'de, E> for ContentDeserializer<'de, E> where E: de::Error, { @@ -1708,7 +1766,7 @@ mod content { } } - impl<'de, 'a, E> de::IntoDeserializer<'de, E> for ContentRefDeserializer<'a, E> + impl<'de, 'a, E> de::IntoDeserializer<'de, E> for ContentRefDeserializer<'a, 'de, E> where E: de::Error, { diff --git a/test_suite/tests/test_macros.rs b/test_suite/tests/test_macros.rs index 6c7e4b2a..651bcfea 100644 --- a/test_suite/tests/test_macros.rs +++ b/test_suite/tests/test_macros.rs @@ -695,6 +695,30 @@ fn test_internally_tagged_enum() { ); } +#[test] +fn test_internally_tagged_borrow() { + #[derive(Debug, PartialEq, Serialize, Deserialize)] + #[serde(tag = "type")] + pub enum Input<'a> { + Package { name: &'a str }, + } + + assert_tokens( + &Input::Package { name: "borrowed" }, + &[ + Token::Struct { name: "Input", len: 2 }, + + Token::BorrowedStr("type"), + Token::BorrowedStr("Package"), + + Token::BorrowedStr("name"), + Token::BorrowedStr("borrowed"), + + Token::StructEnd, + ], + ); +} + #[test] fn test_adjacently_tagged_enum() { #[derive(Debug, PartialEq, Serialize, Deserialize)]