scuffle_http/
body.rs

1//! Types for working with HTTP bodies.
2
3use std::fmt::Debug;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use bytes::{Buf, Bytes};
8use http_body::Frame;
9
10/// An error that can occur when reading the body of an incoming request.
11#[derive(thiserror::Error, Debug)]
12pub enum IncomingBodyError {
13    /// An error that occurred while reading a hyper body.
14    #[error("hyper error: {0}")]
15    #[cfg(any(feature = "http1", feature = "http2"))]
16    #[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
17    Hyper(#[from] hyper::Error),
18    /// An error that occurred while reading a quic body.
19    #[error("quic error: {0}")]
20    #[cfg(feature = "http3")]
21    #[cfg_attr(docsrs, doc(cfg(feature = "http3")))]
22    Quic(#[from] h3::Error),
23}
24
/// The body of an incoming request, independent of the HTTP version it arrived over.
///
/// HTTP/1 and HTTP/2 requests carry a hyper body, while HTTP/3 requests carry a QUIC body.
/// Both are exposed through this single type, which implements the [`http_body::Body`] trait.
pub enum IncomingBody {
    /// Body received through hyper (HTTP/1 or HTTP/2).
    #[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
    #[cfg(any(feature = "http1", feature = "http2"))]
    Hyper(hyper::body::Incoming),
    /// Body received through h3 (HTTP/3).
    #[cfg_attr(docsrs, doc(cfg(feature = "http3")))]
    #[cfg(feature = "http3")]
    Quic(crate::backend::h3::body::QuicIncomingBody<h3_quinn::RecvStream>),
}
39
/// Wraps a hyper request body in the [`IncomingBody::Hyper`] variant.
#[cfg(any(feature = "http1", feature = "http2"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
impl From<hyper::body::Incoming> for IncomingBody {
    fn from(body: hyper::body::Incoming) -> Self {
        Self::Hyper(body)
    }
}
47
/// Wraps an h3 request body in the [`IncomingBody::Quic`] variant.
#[cfg(feature = "http3")]
#[cfg_attr(docsrs, doc(cfg(feature = "http3")))]
impl From<crate::backend::h3::body::QuicIncomingBody<h3_quinn::RecvStream>> for IncomingBody {
    fn from(body: crate::backend::h3::body::QuicIncomingBody<h3_quinn::RecvStream>) -> Self {
        Self::Quic(body)
    }
}
55
56impl http_body::Body for IncomingBody {
57    type Data = Bytes;
58    type Error = IncomingBodyError;
59
60    fn is_end_stream(&self) -> bool {
61        match self {
62            #[cfg(any(feature = "http1", feature = "http2"))]
63            IncomingBody::Hyper(body) => body.is_end_stream(),
64            #[cfg(feature = "http3")]
65            IncomingBody::Quic(body) => body.is_end_stream(),
66            #[cfg(not(any(feature = "http1", feature = "http2", feature = "http3")))]
67            _ => false,
68        }
69    }
70
71    fn poll_frame(
72        self: std::pin::Pin<&mut Self>,
73        _cx: &mut std::task::Context<'_>,
74    ) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
75        match self.get_mut() {
76            #[cfg(any(feature = "http1", feature = "http2"))]
77            IncomingBody::Hyper(body) => std::pin::Pin::new(body).poll_frame(_cx).map_err(Into::into),
78            #[cfg(feature = "http3")]
79            IncomingBody::Quic(body) => std::pin::Pin::new(body).poll_frame(_cx).map_err(Into::into),
80            #[cfg(not(any(feature = "http1", feature = "http2", feature = "http3")))]
81            _ => std::task::Poll::Ready(None),
82        }
83    }
84
85    fn size_hint(&self) -> http_body::SizeHint {
86        match self {
87            #[cfg(any(feature = "http1", feature = "http2"))]
88            IncomingBody::Hyper(body) => body.size_hint(),
89            #[cfg(feature = "http3")]
90            IncomingBody::Quic(body) => body.size_hint(),
91            #[cfg(not(any(feature = "http1", feature = "http2", feature = "http3")))]
92            _ => http_body::SizeHint::default(),
93        }
94    }
95}
96
pin_project_lite::pin_project! {
    /// A wrapper around an HTTP body that tracks the size of the data that is read from it.
    ///
    /// Each data frame produced by the inner body is reported to the tracker via
    /// [`Tracker::on_data`] before it is returned. Non-data frames (e.g. trailers) are
    /// passed through without being reported.
    pub struct TrackedBody<B, T> {
        // The wrapped body; structurally pinned so it can be polled through `Pin<&mut Self>`.
        #[pin]
        body: B,
        // Receives the size of every data frame read from `body`; not pinned.
        tracker: T,
    }
}
105
106impl<B, T> TrackedBody<B, T> {
107    /// Create a new [`TrackedBody`] with the given body and tracker.
108    pub fn new(body: B, tracker: T) -> Self {
109        Self { body, tracker }
110    }
111}
112
113/// An error that can occur when tracking the body of an incoming request.
114#[derive(thiserror::Error)]
115pub enum TrackedBodyError<B, T>
116where
117    B: http_body::Body,
118    T: Tracker,
119{
120    /// An error that occurred while reading the body.
121    #[error("body error: {0}")]
122    Body(B::Error),
123    /// An error that occurred while calling [`Tracker::on_data`].
124    #[error("tracker error: {0}")]
125    Tracker(T::Error),
126}
127
128impl<B, T> Debug for TrackedBodyError<B, T>
129where
130    B: http_body::Body,
131    B::Error: Debug,
132    T: Tracker,
133    T::Error: Debug,
134{
135    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136        match self {
137            TrackedBodyError::Body(err) => f.debug_tuple("TrackedBodyError::Body").field(err).finish(),
138            TrackedBodyError::Tracker(err) => f.debug_tuple("TrackedBodyError::Tracker").field(err).finish(),
139        }
140    }
141}
142
/// A trait for tracking the size of the data that is read from an HTTP body.
///
/// A tracker is attached to a body with [`TrackedBody::new`]. Every time the wrapped body
/// yields a data frame, [`Tracker::on_data`] is invoked with the size of that frame before
/// the frame is handed to the caller.
pub trait Tracker: Send + Sync + 'static {
    /// The error type that can occur when [`Tracker::on_data`] is called.
    type Error;

    /// Called when a data frame is read from the body.
    ///
    /// `size` is the number of bytes in the data frame that was just read. It is *not* the
    /// number of bytes still remaining in the body. Non-data frames (e.g. trailers) do not
    /// trigger this call.
    ///
    /// The default implementation ignores `size` and always returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returning an error makes the [`TrackedBody`] yield [`TrackedBodyError::Tracker`]
    /// instead of the frame that was just read.
    fn on_data(&self, size: usize) -> Result<(), Self::Error> {
        // Explicitly discard the argument so the default impl keeps the public name `size`.
        let _ = size;
        Ok(())
    }
}
156
157impl<B, T> http_body::Body for TrackedBody<B, T>
158where
159    B: http_body::Body,
160    T: Tracker,
161{
162    type Data = B::Data;
163    type Error = TrackedBodyError<B, T>;
164
165    fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
166        let this = self.project();
167
168        match this.body.poll_frame(cx) {
169            Poll::Pending => Poll::Pending,
170            Poll::Ready(frame) => {
171                if let Some(Ok(frame)) = &frame {
172                    if let Some(data) = frame.data_ref() {
173                        if let Err(err) = this.tracker.on_data(data.remaining()) {
174                            return Poll::Ready(Some(Err(TrackedBodyError::Tracker(err))));
175                        }
176                    }
177                }
178
179                Poll::Ready(frame.transpose().map_err(TrackedBodyError::Body).transpose())
180            }
181        }
182    }
183
184    fn is_end_stream(&self) -> bool {
185        self.body.is_end_stream()
186    }
187
188    fn size_hint(&self) -> http_body::SizeHint {
189        self.body.size_hint()
190    }
191}
192
#[cfg(test)]
#[cfg_attr(all(test, coverage_nightly), coverage(off))]
mod tests {
    use std::convert::Infallible;
    use std::pin::Pin;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::task::{Context, Poll, Wake, Waker};

    use bytes::Bytes;
    use http_body::{Body, Frame};

    use crate::body::{TrackedBody, TrackedBodyError, Tracker};

    /// A body that yields at most one data frame and then ends.
    struct TestBody(Option<Bytes>);

    impl Body for TestBody {
        type Data = Bytes;
        type Error = ();

        fn poll_frame(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            Poll::Ready(self.0.take().map(|data| Ok(Frame::data(data))))
        }
    }

    /// Tracker that relies on the default `on_data`.
    struct TestTracker;

    impl Tracker for TestTracker {
        type Error = Infallible;
    }

    /// Tracker that sums every size reported to it.
    struct CountingTracker(AtomicUsize);

    impl Tracker for CountingTracker {
        type Error = Infallible;

        fn on_data(&self, size: usize) -> Result<(), Self::Error> {
            self.0.fetch_add(size, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Tracker that rejects every data frame.
    struct RejectingTracker;

    impl Tracker for RejectingTracker {
        type Error = &'static str;

        fn on_data(&self, _size: usize) -> Result<(), Self::Error> {
            Err("limit exceeded")
        }
    }

    /// Waker that does nothing; the test bodies never return `Pending`.
    struct NoopWaker;

    impl Wake for NoopWaker {
        fn wake(self: Arc<Self>) {}
    }

    fn noop_waker() -> Waker {
        Waker::from(Arc::new(NoopWaker))
    }

    #[test]
    fn tracked_body_error_debug() {
        let err = TrackedBodyError::<TestBody, TestTracker>::Body(());
        assert_eq!(format!("{err:?}"), "TrackedBodyError::Body(())");

        let err = TrackedBodyError::<TestBody, RejectingTracker>::Tracker("limit exceeded");
        assert_eq!(format!("{err:?}"), "TrackedBodyError::Tracker(\"limit exceeded\")");
    }

    #[test]
    fn tracked_body_reports_chunk_size() {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let mut body = TrackedBody::new(
            TestBody(Some(Bytes::from_static(b"hello"))),
            CountingTracker(AtomicUsize::new(0)),
        );

        match Pin::new(&mut body).poll_frame(&mut cx) {
            Poll::Ready(Some(Ok(frame))) => {
                assert_eq!(frame.data_ref(), Some(&Bytes::from_static(b"hello")));
            }
            _ => panic!("expected a data frame"),
        }
        assert_eq!(body.tracker.0.load(Ordering::SeqCst), 5);
        assert!(matches!(Pin::new(&mut body).poll_frame(&mut cx), Poll::Ready(None)));
    }

    #[test]
    fn tracked_body_surfaces_tracker_error() {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let mut body = TrackedBody::new(TestBody(Some(Bytes::from_static(b"hello"))), RejectingTracker);

        match Pin::new(&mut body).poll_frame(&mut cx) {
            Poll::Ready(Some(Err(TrackedBodyError::Tracker(err)))) => assert_eq!(err, "limit exceeded"),
            _ => panic!("expected a tracker error"),
        }
    }
}