scuffle_amf0/
decoder.rs

1//! AMF0 decoder
2
3use std::io;
4
5use byteorder::{BigEndian, ReadBytesExt};
6use num_traits::FromPrimitive;
7use scuffle_bytes_util::StringCow;
8use scuffle_bytes_util::zero_copy::ZeroCopyReader;
9
10use crate::{Amf0Array, Amf0Error, Amf0Marker, Amf0Object, Amf0Value};
11
12/// AMF0 decoder.
13///
14/// Provides various functions to decode different types of AMF0 values.
15#[derive(Debug, Clone)]
16pub struct Amf0Decoder<R> {
17    pub(crate) reader: R,
18    pub(crate) next_marker: Option<Amf0Marker>,
19}
20
21#[derive(Debug, Clone, PartialEq, Eq)]
22pub(crate) enum ObjectHeader<'a> {
23    Object,
24    TypedObject { name: StringCow<'a> },
25    EcmaArray { size: u32 },
26}
27
28impl<B> Amf0Decoder<scuffle_bytes_util::zero_copy::BytesBuf<B>>
29where
30    B: bytes::Buf,
31{
32    /// Create a new deserializer from a buffer implementing [`bytes::Buf`].
33    pub fn from_buf(buf: B) -> Self {
34        Self {
35            reader: buf.into(),
36            next_marker: None,
37        }
38    }
39}
40
41impl<R> Amf0Decoder<scuffle_bytes_util::zero_copy::IoRead<R>>
42where
43    R: std::io::Read,
44{
45    /// Create a new deserializer from a reader implementing [`std::io::Read`].
46    pub fn from_reader(reader: R) -> Self {
47        Self {
48            reader: reader.into(),
49            next_marker: None,
50        }
51    }
52}
53
54impl<'a> Amf0Decoder<scuffle_bytes_util::zero_copy::Slice<'a>> {
55    /// Create a new deserializer from a byte slice.
56    pub fn from_slice(slice: &'a [u8]) -> Amf0Decoder<scuffle_bytes_util::zero_copy::Slice<'a>> {
57        Self {
58            reader: slice.into(),
59            next_marker: None,
60        }
61    }
62}
63
64impl<'a, R> Amf0Decoder<R>
65where
66    R: ZeroCopyReader<'a>,
67{
68    /// Decode a [`Amf0Value`] from the buffer.
69    pub fn decode_value(&mut self) -> Result<Amf0Value<'a>, Amf0Error> {
70        let marker = self.peek_marker()?;
71
72        match marker {
73            Amf0Marker::Boolean => self.decode_boolean().map(Into::into),
74            Amf0Marker::Number | Amf0Marker::Date => self.decode_number().map(Into::into),
75            Amf0Marker::String | Amf0Marker::LongString | Amf0Marker::XmlDocument => self.decode_string().map(Into::into),
76            Amf0Marker::Null | Amf0Marker::Undefined => self.decode_null().map(|_| Amf0Value::Null),
77            Amf0Marker::Object | Amf0Marker::TypedObject | Amf0Marker::EcmaArray => self.decode_object().map(Into::into),
78            Amf0Marker::StrictArray => self.decode_strict_array().map(Into::into),
79            _ => Err(Amf0Error::UnsupportedMarker(marker)),
80        }
81    }
82
83    /// Decode all values from the buffer until the end.
84    pub fn decode_all(&mut self) -> Result<Vec<Amf0Value<'a>>, Amf0Error> {
85        let mut values = Vec::new();
86
87        while self.has_remaining()? {
88            values.push(self.decode_value()?);
89        }
90
91        Ok(values)
92    }
93
94    /// Convert the decoder into an iterator over the values in the buffer.
95    pub fn stream(&mut self) -> Amf0DecoderStream<'_, 'a, R> {
96        Amf0DecoderStream {
97            decoder: self,
98            _marker: std::marker::PhantomData,
99        }
100    }
101
102    /// Check if there are any values left in the buffer.
103    pub fn has_remaining(&mut self) -> Result<bool, Amf0Error> {
104        match self.peek_marker() {
105            Ok(_) => Ok(true),
106            Err(Amf0Error::Io(e)) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(false),
107            Err(err) => Err(err),
108        }
109    }
110
111    /// Peek the next marker in the buffer without consuming it.
112    pub fn peek_marker(&mut self) -> Result<Amf0Marker, Amf0Error> {
113        let marker = self.read_marker()?;
114        // Buffer the marker for the next read
115        self.next_marker = Some(marker);
116
117        Ok(marker)
118    }
119
120    fn read_marker(&mut self) -> Result<Amf0Marker, Amf0Error> {
121        if let Some(marker) = self.next_marker.take() {
122            return Ok(marker);
123        }
124
125        let marker = self.reader.as_std().read_u8()?;
126        let marker = Amf0Marker::from_u8(marker).ok_or(Amf0Error::UnknownMarker(marker))?;
127        Ok(marker)
128    }
129
130    fn expect_marker(&mut self, expect: &'static [Amf0Marker]) -> Result<Amf0Marker, Amf0Error> {
131        let marker = self.read_marker()?;
132
133        if !expect.contains(&marker) {
134            Err(Amf0Error::UnexpectedType {
135                expected: expect,
136                got: marker,
137            })
138        } else {
139            Ok(marker)
140        }
141    }
142
143    /// Decode a number from the buffer.
144    pub fn decode_number(&mut self) -> Result<f64, Amf0Error> {
145        let marker = self.expect_marker(&[Amf0Marker::Number, Amf0Marker::Date])?;
146
147        let number = self.reader.as_std().read_f64::<BigEndian>()?;
148
149        if marker == Amf0Marker::Date {
150            // Skip the timezone
151            self.reader.as_std().read_i16::<BigEndian>()?;
152        }
153
154        Ok(number)
155    }
156
157    /// Decode a boolean from the buffer.
158    pub fn decode_boolean(&mut self) -> Result<bool, Amf0Error> {
159        self.expect_marker(&[Amf0Marker::Boolean])?;
160        let value = self.reader.as_std().read_u8()?;
161        Ok(value != 0)
162    }
163
164    pub(crate) fn decode_normal_string(&mut self) -> Result<StringCow<'a>, Amf0Error> {
165        let len = self.reader.as_std().read_u16::<BigEndian>()? as usize;
166
167        let bytes = self.reader.try_read(len)?;
168        Ok(StringCow::from_bytes(bytes.into_bytes().try_into()?))
169    }
170
171    /// Decode a string from the buffer.
172    ///
173    /// This function can decode both normal strings and long strings.
174    pub fn decode_string(&mut self) -> Result<StringCow<'a>, Amf0Error> {
175        let marker = self.expect_marker(&[Amf0Marker::String, Amf0Marker::LongString, Amf0Marker::XmlDocument])?;
176
177        let len = if marker == Amf0Marker::String {
178            self.reader.as_std().read_u16::<BigEndian>()? as usize
179        } else {
180            // LongString or XmlDocument
181            self.reader.as_std().read_u32::<BigEndian>()? as usize
182        };
183
184        let bytes = self.reader.try_read(len)?;
185        Ok(StringCow::from_bytes(bytes.into_bytes().try_into()?))
186    }
187
188    /// Decode a null value from the buffer.
189    ///
190    /// This function can also decode undefined values.
191    pub fn decode_null(&mut self) -> Result<(), Amf0Error> {
192        self.expect_marker(&[Amf0Marker::Null, Amf0Marker::Undefined])?;
193        Ok(())
194    }
195
196    /// Deserialize a value from the buffer using [serde].
197    #[cfg(feature = "serde")]
198    #[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
199    pub fn deserialize<T>(&mut self) -> Result<T, Amf0Error>
200    where
201        T: serde::de::Deserialize<'a>,
202    {
203        T::deserialize(self)
204    }
205
206    /// Deserialize a stream of values from the buffer using [serde].
207    #[cfg(feature = "serde")]
208    #[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
209    pub fn deserialize_stream<T>(&mut self) -> crate::de::Amf0DeserializerStream<'_, R, T>
210    where
211        T: serde::de::Deserialize<'a>,
212    {
213        crate::de::Amf0DeserializerStream::new(self)
214    }
215
216    // --- Object and Ecma array ---
217
218    pub(crate) fn decode_object_header(&mut self) -> Result<ObjectHeader<'a>, Amf0Error> {
219        let marker = self.expect_marker(&[Amf0Marker::Object, Amf0Marker::TypedObject, Amf0Marker::EcmaArray])?;
220
221        if marker == Amf0Marker::Object {
222            Ok(ObjectHeader::Object)
223        } else if marker == Amf0Marker::TypedObject {
224            let name = self.decode_normal_string()?;
225            Ok(ObjectHeader::TypedObject { name })
226        } else {
227            // EcmaArray
228            let size = self.reader.as_std().read_u32::<BigEndian>()?;
229            Ok(ObjectHeader::EcmaArray { size })
230        }
231    }
232
233    pub(crate) fn decode_object_key(&mut self) -> Result<Option<StringCow<'a>>, Amf0Error> {
234        // Object keys are not preceeded with a marker and are always normal strings
235        let key = self.decode_normal_string()?;
236
237        // The object end marker is preceeded by an empty string
238        if key.as_str().is_empty() {
239            // Check if the next marker is an object end marker
240            if self.peek_marker()? == Amf0Marker::ObjectEnd {
241                // Clear the next marker buffer
242                self.next_marker = None;
243
244                return Ok(None);
245            }
246        }
247
248        Ok(Some(key))
249    }
250
251    /// Decode an object from the buffer.
252    ///
253    /// This function can decode normal objects, typed objects and ECMA arrays.
254    pub fn decode_object(&mut self) -> Result<Amf0Object<'a>, Amf0Error> {
255        let header = self.decode_object_header()?;
256
257        match header {
258            ObjectHeader::Object | ObjectHeader::TypedObject { .. } => {
259                let mut object = Amf0Object::new();
260
261                while let Some(key) = self.decode_object_key()? {
262                    let value = self.decode_value()?;
263                    object.insert(key, value);
264                }
265
266                Ok(object)
267            }
268            ObjectHeader::EcmaArray { size } => {
269                let mut object = Amf0Object::with_capacity(size as usize);
270
271                for _ in 0..size {
272                    // Object keys are not preceeded with a marker and are always normal strings
273                    let key = self.decode_normal_string()?;
274                    let value = self.decode_value()?;
275                    object.insert(key, value);
276                }
277
278                // There might be an object end marker after the last key
279                if self.has_remaining()? && self.peek_marker()? == Amf0Marker::ObjectEnd {
280                    // Clear the next marker buffer
281                    self.next_marker = None;
282                }
283
284                Ok(object)
285            }
286        }
287    }
288
289    // --- Strict array ---
290
291    pub(crate) fn decode_strict_array_header(&mut self) -> Result<u32, Amf0Error> {
292        self.expect_marker(&[Amf0Marker::StrictArray])?;
293        let size = self.reader.as_std().read_u32::<BigEndian>()?;
294
295        Ok(size)
296    }
297
298    /// Decode a strict array from the buffer.
299    pub fn decode_strict_array(&mut self) -> Result<Amf0Array<'a>, Amf0Error> {
300        let size = self.decode_strict_array_header()? as usize;
301
302        let mut array = Vec::with_capacity(size);
303
304        for _ in 0..size {
305            let value = self.decode_value()?;
306            array.push(value);
307        }
308
309        Ok(Amf0Array::from(array))
310    }
311}
312
313/// An iterator over the values in the buffer.
314///
315/// Yields values of type [`Amf0Value`] until the end of the buffer is reached.
316#[must_use = "Iterators are lazy and do nothing unless consumed"]
317pub struct Amf0DecoderStream<'a, 'de, R> {
318    decoder: &'a mut Amf0Decoder<R>,
319    _marker: std::marker::PhantomData<&'de ()>,
320}
321
322impl<'de, R: ZeroCopyReader<'de>> Iterator for Amf0DecoderStream<'_, 'de, R> {
323    type Item = Result<Amf0Value<'de>, Amf0Error>;
324
325    fn next(&mut self) -> Option<Self::Item> {
326        match self.decoder.has_remaining() {
327            Ok(true) => Some(self.decoder.decode_value()),
328            Ok(false) => None,
329            Err(err) => Some(Err(err)),
330        }
331    }
332}
333
334impl<'de, R> std::iter::FusedIterator for Amf0DecoderStream<'_, 'de, R> where R: ZeroCopyReader<'de> {}
335
#[cfg(test)]
#[cfg_attr(all(test, coverage_nightly), coverage(off))]
mod tests {
    use super::Amf0Decoder;
    use crate::{Amf0Marker, Amf0Value};

    /// A strict array announcing two elements yields both of them in order.
    #[test]
    fn strict_array() {
        #[rustfmt::skip]
        let input = [
            Amf0Marker::StrictArray as u8,
            0, 0, 0, 2, // size
            Amf0Marker::String as u8,
            0, 3, b'v', b'a', b'l', // value
            Amf0Marker::Boolean as u8,
            1, // value
        ];

        let mut decoder = Amf0Decoder::from_slice(&input);
        let decoded = decoder.decode_strict_array().unwrap();

        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded[0], Amf0Value::String("val".into()));
        assert_eq!(decoded[1], Amf0Value::Boolean(true));
    }

    /// An ECMA array without a trailing object end marker decodes into an object.
    #[test]
    fn ecma_array() {
        #[rustfmt::skip]
        let input = [
            Amf0Marker::EcmaArray as u8,
            0, 0, 0, 2, // size
            0, 3, b'a', b'b', b'c', // key
            Amf0Marker::String as u8,
            0, 3, b'v', b'a', b'l', // value
            0, 4, b'd', b'e', b'f', b'g', // key
            Amf0Marker::Boolean as u8,
            1, // value
        ];

        let mut decoder = Amf0Decoder::from_slice(&input);
        let decoded = decoder.decode_object().unwrap();

        assert_eq!(decoded.len(), 2);

        let expected = [
            ("abc", Amf0Value::String("val".into())),
            ("defg", Amf0Value::Boolean(true)),
        ];
        for (key, value) in expected {
            assert_eq!(*decoded.get(&key.into()).unwrap(), value);
        }
    }

    /// The stream yields every value in the buffer and then ends.
    #[test]
    fn decoder_stream() {
        #[rustfmt::skip]
        let input = [
            Amf0Marker::Boolean as u8,
            1, // value
            Amf0Marker::String as u8,
            0, 3, b'a', b'b', b'c', // value
            Amf0Marker::Null as u8,
        ];

        let mut decoder = Amf0Decoder::from_slice(&input);
        // Collecting only finishes once the stream has returned `None`.
        let decoded: Vec<_> = decoder.stream().map(|item| item.unwrap()).collect();

        assert_eq!(
            decoded,
            [
                Amf0Value::Boolean(true),
                Amf0Value::String("abc".into()),
                Amf0Value::Null,
            ]
        );
    }
}