scuffle_amf0/
value.rs

1//! AMF0 value types.
2
3use std::borrow::Cow;
4use std::collections::HashMap;
5use std::io;
6
7use scuffle_bytes_util::StringCow;
8
9use crate::Amf0Error;
10use crate::encoder::Amf0Encoder;
11
12/// Represents any AMF0 object.
13pub type Amf0Object<'a> = HashMap<StringCow<'a>, Amf0Value<'a>>;
14/// Represents any AMF0 array.
15pub type Amf0Array<'a> = Cow<'a, [Amf0Value<'a>]>;
16
17/// Represents any AMF0 value.
18#[derive(Debug, PartialEq, Clone)]
19pub enum Amf0Value<'a> {
20    /// AMF0 Number.
21    Number(f64),
22    /// AMF0 Boolean.
23    Boolean(bool),
24    /// AMF0 String.
25    String(StringCow<'a>),
26    /// AMF0 Object.
27    Object(Amf0Object<'a>),
28    /// AMF0 Null.
29    Null,
30    /// AMF0 Array.
31    Array(Amf0Array<'a>),
32}
33
34impl Amf0Value<'_> {
35    /// Converts this AMF0 value into an owned version (static lifetime).
36    pub fn into_owned(self) -> Amf0Value<'static> {
37        match self {
38            Amf0Value::Number(v) => Amf0Value::Number(v),
39            Amf0Value::Boolean(v) => Amf0Value::Boolean(v),
40            Amf0Value::String(v) => Amf0Value::String(v.into_owned()),
41            Amf0Value::Object(v) => {
42                Amf0Value::Object(v.into_iter().map(|(k, v)| (k.into_owned(), v.into_owned())).collect())
43            }
44            Amf0Value::Null => Amf0Value::Null,
45            Amf0Value::Array(v) => Amf0Value::Array(v.into_owned().into_iter().map(|v| v.into_owned()).collect()),
46        }
47    }
48
49    /// Encode this AMF0 value with the given encoder.
50    pub fn encode<W: io::Write>(&self, encoder: &mut Amf0Encoder<W>) -> Result<(), Amf0Error> {
51        match self {
52            Amf0Value::Number(v) => encoder.encode_number(*v),
53            Amf0Value::Boolean(v) => encoder.encode_boolean(*v),
54            Amf0Value::String(v) => encoder.encode_string(v.as_str()),
55            Amf0Value::Object(v) => encoder.encode_object(v),
56            Amf0Value::Null => encoder.encode_null(),
57            Amf0Value::Array(v) => encoder.encode_array(v),
58        }
59    }
60}
61
62impl From<f64> for Amf0Value<'_> {
63    fn from(value: f64) -> Self {
64        Amf0Value::Number(value)
65    }
66}
67
68impl From<bool> for Amf0Value<'_> {
69    fn from(value: bool) -> Self {
70        Amf0Value::Boolean(value)
71    }
72}
73
74impl<'a> From<StringCow<'a>> for Amf0Value<'a> {
75    fn from(value: StringCow<'a>) -> Self {
76        Amf0Value::String(value)
77    }
78}
79
80// object
81impl<'a> From<Amf0Object<'a>> for Amf0Value<'a> {
82    fn from(value: Amf0Object<'a>) -> Self {
83        Amf0Value::Object(value)
84    }
85}
86
87// owned array
88impl<'a> From<Vec<Amf0Value<'a>>> for Amf0Value<'a> {
89    fn from(value: Vec<Amf0Value<'a>>) -> Self {
90        Amf0Value::Array(Cow::Owned(value))
91    }
92}
93
94// borrowed array
95impl<'a> From<&'a [Amf0Value<'a>]> for Amf0Value<'a> {
96    fn from(value: &'a [Amf0Value<'a>]) -> Self {
97        Amf0Value::Array(Cow::Borrowed(value))
98    }
99}
100
101// cow array
102impl<'a> From<Amf0Array<'a>> for Amf0Value<'a> {
103    fn from(value: Amf0Array<'a>) -> Self {
104        Amf0Value::Array(value)
105    }
106}
107
108impl<'a> FromIterator<Amf0Value<'a>> for Amf0Value<'a> {
109    fn from_iter<T: IntoIterator<Item = Amf0Value<'a>>>(iter: T) -> Self {
110        Amf0Value::Array(Cow::Owned(iter.into_iter().collect()))
111    }
112}
113
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
impl<'de> serde::de::Deserialize<'de> for Amf0Value<'de> {
    /// Deserializes any self-describing value into the closest AMF0 equivalent.
    ///
    /// Mapping:
    /// - booleans become `Boolean`;
    /// - all integers (up to 128 bits) and floats become `Number`, converted with
    ///   `as f64`, so very large integers lose precision;
    /// - strings become `String`, borrowed when the deserializer hands out `&'de str`;
    /// - unit and `None` become `Null`;
    /// - sequences become `Array`, and maps become `Object`;
    /// - byte buffers become an `Array` of `Number`s, one per byte.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        /// Upper bound, in bytes, on memory preallocated from a deserializer's size hint.
        const MAX_PREALLOC_BYTES: usize = 1024 * 1024;

        /// Turns an (untrusted) size hint into an initial capacity for a collection of `T`.
        ///
        /// The result is capped so that a bogus hint cannot trigger a huge allocation.
        fn prealloc_capacity<T>(hint: Option<usize>) -> usize {
            let max_elements = MAX_PREALLOC_BYTES / std::mem::size_of::<T>().max(1);
            hint.unwrap_or(0).min(max_elements)
        }

        struct Amf0ValueVisitor;

        impl<'de> serde::de::Visitor<'de> for Amf0ValueVisitor {
            type Value = Amf0Value<'de>;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("an AMF0 value")
            }

            #[inline]
            fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Ok(Amf0Value::Boolean(v))
            }

            #[inline]
            fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                self.visit_f64(v as f64)
            }

            #[inline]
            fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                self.visit_f64(v as f64)
            }

            // serde's default rejects 128-bit integers; treat them like the 64-bit ones.
            #[inline]
            fn visit_i128<E>(self, v: i128) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                self.visit_f64(v as f64)
            }

            #[inline]
            fn visit_u128<E>(self, v: u128) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                self.visit_f64(v as f64)
            }

            #[inline]
            fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Ok(Amf0Value::Number(v))
            }

            #[inline]
            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                // The data does not outlive this call, so it must be copied.
                self.visit_string(v.to_owned())
            }

            #[inline]
            fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Ok(StringCow::from(v).into())
            }

            #[inline]
            fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                // Borrow directly from the input; no allocation.
                Ok(StringCow::from(v).into())
            }

            #[inline]
            fn visit_unit<E>(self) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Ok(Amf0Value::Null)
            }

            #[inline]
            fn visit_none<E>(self) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Ok(Amf0Value::Null)
            }

            #[inline]
            fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                serde::Deserialize::deserialize(deserializer)
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: serde::de::SeqAccess<'de>,
            {
                let mut vec = Vec::with_capacity(prealloc_capacity::<Amf0Value<'de>>(seq.size_hint()));

                while let Some(value) = seq.next_element()? {
                    vec.push(value);
                }

                Ok(vec.into())
            }

            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
            where
                A: serde::de::MapAccess<'de>,
            {
                let mut object =
                    HashMap::with_capacity(prealloc_capacity::<(StringCow<'de>, Amf0Value<'de>)>(map.size_hint()));

                // Later duplicates of a key overwrite earlier ones.
                while let Some((key, value)) = map.next_entry()? {
                    object.insert(key, value);
                }

                Ok(object.into())
            }

            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                let array = v.iter().map(|b| Amf0Value::Number(*b as f64)).collect();
                Ok(Amf0Value::Array(array))
            }
        }

        deserializer.deserialize_any(Amf0ValueVisitor)
    }
}
248
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
impl serde::ser::Serialize for Amf0Value<'_> {
    /// Serializes the value using serde's data model:
    /// - numbers become `f64` and booleans become `bool`;
    /// - strings use their own `Serialize` impl;
    /// - `Null` becomes `none`;
    /// - objects become maps and arrays become sequences, each with an exact length hint.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::{SerializeMap, SerializeSeq};

        match self {
            Self::Number(number) => serializer.serialize_f64(*number),
            Self::Boolean(flag) => serializer.serialize_bool(*flag),
            Self::String(text) => text.serialize(serializer),
            Self::Object(object) => {
                let mut state = serializer.serialize_map(Some(object.len()))?;
                object
                    .iter()
                    .try_for_each(|(key, value)| state.serialize_entry(key, value))?;
                state.end()
            }
            Self::Null => serializer.serialize_none(),
            Self::Array(items) => {
                let mut state = serializer.serialize_seq(Some(items.len()))?;
                items.iter().try_for_each(|item| state.serialize_element(item))?;
                state.end()
            }
        }
    }
}
282
#[cfg(test)]
#[cfg_attr(all(test, coverage_nightly), coverage(off))]
mod tests {
    use std::borrow::Cow;

    use scuffle_bytes_util::StringCow;

    use super::Amf0Value;
    use crate::{Amf0Array, Amf0Decoder, Amf0Encoder, Amf0Error, Amf0Marker, Amf0Object};

    /// Every `From` / `FromIterator` conversion yields the expected variant.
    #[test]
    fn from() {
        let value: Amf0Value = 1.0.into();
        assert_eq!(value, Amf0Value::Number(1.0));

        let value: Amf0Value = true.into();
        assert_eq!(value, Amf0Value::Boolean(true));

        let value: Amf0Value = StringCow::from("abc").into();
        assert_eq!(value, Amf0Value::String("abc".into()));

        let object: Amf0Object = [("a".into(), Amf0Value::Boolean(true))].into_iter().collect();
        let value: Amf0Value = object.clone().into();
        assert_eq!(value, Amf0Value::Object(object));

        let array: Vec<Amf0Value> = vec![Amf0Value::Boolean(true)];
        let value: Amf0Value = array.clone().into();
        assert_eq!(value, Amf0Value::Array(Cow::Owned(array)));

        let array: &[Amf0Value] = &[Amf0Value::Boolean(true)];
        let value: Amf0Value = array.into();
        assert_eq!(value, Amf0Value::Array(Cow::Borrowed(array)));

        let array: Amf0Array = Cow::Borrowed(&[Amf0Value::Boolean(true)]);
        let value: Amf0Value = array.clone().into();
        assert_eq!(value, Amf0Value::Array(array));

        let iter = std::iter::once(Amf0Value::Boolean(true));
        let value: Amf0Value = iter.collect();
        assert_eq!(value, Amf0Value::Array(Cow::Owned(vec![Amf0Value::Boolean(true)])));
    }

    /// Decoding a marker with no `Amf0Value` representation reports it as unsupported.
    #[test]
    fn unsupported_marker() {
        let bytes = [Amf0Marker::MovieClipMarker as u8];

        let err = Amf0Decoder::from_slice(&bytes).decode_value().unwrap_err();
        assert!(matches!(err, Amf0Error::UnsupportedMarker(Amf0Marker::MovieClipMarker)));
    }

    #[test]
    fn string() {
        #[rustfmt::skip]
        let bytes = [
            Amf0Marker::String as u8,
            0, 3, // length
            b'a', b'b', b'c',
        ];

        let value = Amf0Decoder::from_slice(&bytes).decode_value().unwrap();
        assert_eq!(value, Amf0Value::String("abc".into()));
    }

    #[test]
    fn bool() {
        let bytes = [Amf0Marker::Boolean as u8, 0];

        let value = Amf0Decoder::from_slice(&bytes).decode_value().unwrap();
        assert_eq!(value, Amf0Value::Boolean(false));
    }

    #[test]
    fn object() {
        #[rustfmt::skip]
        let bytes = [
            Amf0Marker::Object as u8,
            0, 1,
            b'a',
            Amf0Marker::Boolean as u8,
            1,
            0, 0, Amf0Marker::ObjectEnd as u8,
        ];

        let value = Amf0Decoder::from_slice(&bytes).decode_value().unwrap();
        assert_eq!(
            value,
            Amf0Value::Object([("a".into(), Amf0Value::Boolean(true))].into_iter().collect())
        );
    }

    /// A strict array round-trips through decode and encode byte-for-byte.
    #[test]
    fn array() {
        #[rustfmt::skip]
        let bytes = [
            Amf0Marker::StrictArray as u8,
            0, 0, 0, 1,
            Amf0Marker::Boolean as u8,
            1,
        ];

        let value = Amf0Decoder::from_slice(&bytes).decode_value().unwrap();
        assert_eq!(value, Amf0Value::Array(Cow::Borrowed(&[Amf0Value::Boolean(true)])));

        let mut serialized = vec![];
        value.encode(&mut Amf0Encoder::new(&mut serialized)).unwrap();
        assert_eq!(serialized, bytes);
    }

    /// Null round-trips through decode and encode byte-for-byte.
    #[test]
    fn null() {
        let bytes = [Amf0Marker::Null as u8];

        let value = Amf0Decoder::from_slice(&bytes).decode_value().unwrap();
        assert_eq!(value, Amf0Value::Null);

        let mut serialized = vec![];
        value.encode(&mut Amf0Encoder::new(&mut serialized)).unwrap();
        assert_eq!(serialized, bytes);
    }

    /// `into_owned` preserves the value for every variant and yields a `'static` value.
    #[test]
    fn into_owned() {
        let value = Amf0Value::Number(1.0);
        let owned_value: Amf0Value<'static> = value.clone().into_owned();
        assert_eq!(owned_value, value);

        let value = Amf0Value::Boolean(true);
        let owned_value: Amf0Value<'static> = value.clone().into_owned();
        assert_eq!(owned_value, value);

        let value = Amf0Value::String("abc".into());
        let owned_value: Amf0Value<'static> = value.clone().into_owned();
        assert_eq!(owned_value, value);

        let value = Amf0Value::Object([("a".into(), Amf0Value::Boolean(true))].into_iter().collect());
        let owned_value: Amf0Value<'static> = value.clone().into_owned();
        assert_eq!(owned_value, value);

        let value = Amf0Value::Null;
        let owned_value: Amf0Value<'static> = value.clone().into_owned();
        assert_eq!(owned_value, value);

        let value = Amf0Value::Array(Cow::Borrowed(&[Amf0Value::Boolean(true)]));
        let owned_value: Amf0Value<'static> = value.clone().into_owned();
        assert_eq!(owned_value, value);
    }

    /// Drives the `Deserialize` visitor through each `visit_*` entry point with a
    /// hand-rolled deserializer.
    #[cfg(feature = "serde")]
    #[test]
    fn deserialize() {
        use std::fmt::Display;

        use serde::Deserialize;
        use serde::de::{IntoDeserializer, MapAccess, SeqAccess};

        #[derive(Debug)]
        struct TestError;

        impl Display for TestError {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "Test error")
            }
        }

        impl std::error::Error for TestError {}

        impl serde::de::Error for TestError {
            fn custom<T: std::fmt::Display>(msg: T) -> Self {
                assert_eq!(msg.to_string(), "invalid type: Option value, expected a byte slice");
                Self
            }
        }

        /// Selects which `visit_*` call `TestDeserializer::deserialize_any` makes.
        enum Mode {
            Bool,
            I64,
            U64,
            F64,
            Str,
            String,
            BorrowedStr,
            Unit,
            None,
            Some,
            Seq,
            Map,
            Bytes,
            End,
        }

        struct TestDeserializer {
            mode: Mode,
        }

        // Yields exactly one element (an i64 of 1), then ends.
        impl<'de> SeqAccess<'de> for TestDeserializer {
            type Error = TestError;

            fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
            where
                T: serde::de::DeserializeSeed<'de>,
            {
                match self.mode {
                    Mode::Seq => {
                        self.mode = Mode::End;
                        Ok(Some(seed.deserialize(TestDeserializer { mode: Mode::I64 })?))
                    }
                    Mode::End => Ok(None),
                    _ => Err(TestError),
                }
            }
        }

        // Yields exactly one entry ("hello" => 1), then ends.
        impl<'de> MapAccess<'de> for TestDeserializer {
            type Error = TestError;

            fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
            where
                K: serde::de::DeserializeSeed<'de>,
            {
                match self.mode {
                    Mode::Map => Ok(Some(seed.deserialize(TestDeserializer { mode: Mode::Str })?)),
                    Mode::End => Ok(None),
                    _ => Err(TestError),
                }
            }

            fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
            where
                V: serde::de::DeserializeSeed<'de>,
            {
                match self.mode {
                    Mode::Map => {
                        self.mode = Mode::End;
                        Ok(seed.deserialize(TestDeserializer { mode: Mode::I64 })?)
                    }
                    _ => Err(TestError),
                }
            }
        }

        impl<'de> serde::Deserializer<'de> for TestDeserializer {
            type Error = TestError;

            serde::forward_to_deserialize_any! {
                bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes byte_buf
                option unit unit_struct newtype_struct seq tuple tuple_struct
                map struct enum identifier ignored_any
            }

            fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
            where
                V: serde::de::Visitor<'de>,
            {
                match self.mode {
                    Mode::Bool => visitor.visit_bool(true),
                    Mode::I64 => visitor.visit_i64(1),
                    Mode::U64 => visitor.visit_u64(1),
                    Mode::F64 => visitor.visit_f64(1.0),
                    Mode::Str => visitor.visit_str("hello"),
                    Mode::String => visitor.visit_string("hello".to_owned()),
                    Mode::BorrowedStr => visitor.visit_borrowed_str("hello"),
                    Mode::Unit => visitor.visit_unit(),
                    Mode::None => visitor.visit_none(),
                    Mode::Some => visitor.visit_some(1.into_deserializer()),
                    Mode::Seq => visitor.visit_seq(self),
                    Mode::Map => visitor.visit_map(self),
                    Mode::Bytes => visitor.visit_bytes(b"hello"),
                    Mode::End => unreachable!(),
                }
            }
        }

        fn test_de(mode: Mode, expected: Amf0Value) {
            let deserializer = TestDeserializer { mode };
            let deserialized_value: Amf0Value = Amf0Value::deserialize(deserializer).unwrap();
            assert_eq!(deserialized_value, expected);
        }

        test_de(Mode::Bool, Amf0Value::Boolean(true));
        test_de(Mode::I64, Amf0Value::Number(1.0));
        test_de(Mode::U64, Amf0Value::Number(1.0));
        test_de(Mode::F64, Amf0Value::Number(1.0));
        test_de(Mode::Str, Amf0Value::String("hello".into()));
        test_de(Mode::String, Amf0Value::String("hello".into()));
        test_de(Mode::BorrowedStr, Amf0Value::String("hello".into()));
        test_de(Mode::Unit, Amf0Value::Null);
        test_de(Mode::None, Amf0Value::Null);
        test_de(Mode::Some, Amf0Value::Number(1.0));
        test_de(Mode::Seq, Amf0Value::Array(Cow::Owned(vec![Amf0Value::Number(1.0)])));
        test_de(
            Mode::Map,
            Amf0Value::Object([("hello".into(), Amf0Value::Number(1.0))].into_iter().collect()),
        );
        test_de(
            Mode::Bytes,
            Amf0Value::Array(Cow::Owned(vec![
                Amf0Value::Number(104.0),
                Amf0Value::Number(101.0),
                Amf0Value::Number(108.0),
                Amf0Value::Number(108.0),
                Amf0Value::Number(111.0),
            ])),
        );
    }
}