1use std::fmt::Debug;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use bytes::{Buf, Bytes};
8use http_body::Frame;
9
10#[derive(thiserror::Error, Debug)]
12pub enum IncomingBodyError {
13 #[error("hyper error: {0}")]
15 #[cfg(any(feature = "http1", feature = "http2"))]
16 #[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
17 Hyper(#[from] hyper::Error),
18 #[error("quic error: {0}")]
20 #[cfg(feature = "http3")]
21 #[cfg_attr(docsrs, doc(cfg(feature = "http3")))]
22 Quic(#[from] h3::Error),
23}
24
/// Request/response body received from one of the enabled backends.
///
/// The enum implements [`http_body::Body`] by delegating to the wrapped body,
/// so callers can read it without knowing which backend produced it. A variant
/// only exists when its crate feature is enabled.
pub enum IncomingBody {
    /// Body received through `hyper` (enabled by the `http1` or `http2` feature).
    #[cfg(any(feature = "http1", feature = "http2"))]
    #[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
    Hyper(hyper::body::Incoming),
    /// Body received through the `h3` backend (enabled by the `http3` feature).
    #[cfg(feature = "http3")]
    #[cfg_attr(docsrs, doc(cfg(feature = "http3")))]
    Quic(crate::backend::h3::body::QuicIncomingBody<h3_quinn::RecvStream>),
}
39
/// Wraps a `hyper` body in the [`IncomingBody::Hyper`] variant.
#[cfg(any(feature = "http1", feature = "http2"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
impl From<hyper::body::Incoming> for IncomingBody {
    fn from(incoming: hyper::body::Incoming) -> Self {
        Self::Hyper(incoming)
    }
}
47
/// Wraps an `h3` body in the [`IncomingBody::Quic`] variant.
#[cfg(feature = "http3")]
#[cfg_attr(docsrs, doc(cfg(feature = "http3")))]
impl From<crate::backend::h3::body::QuicIncomingBody<h3_quinn::RecvStream>> for IncomingBody {
    fn from(incoming: crate::backend::h3::body::QuicIncomingBody<h3_quinn::RecvStream>) -> Self {
        Self::Quic(incoming)
    }
}
55
56impl http_body::Body for IncomingBody {
57 type Data = Bytes;
58 type Error = IncomingBodyError;
59
60 fn is_end_stream(&self) -> bool {
61 match self {
62 #[cfg(any(feature = "http1", feature = "http2"))]
63 IncomingBody::Hyper(body) => body.is_end_stream(),
64 #[cfg(feature = "http3")]
65 IncomingBody::Quic(body) => body.is_end_stream(),
66 #[cfg(not(any(feature = "http1", feature = "http2", feature = "http3")))]
67 _ => false,
68 }
69 }
70
71 fn poll_frame(
72 self: std::pin::Pin<&mut Self>,
73 _cx: &mut std::task::Context<'_>,
74 ) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
75 match self.get_mut() {
76 #[cfg(any(feature = "http1", feature = "http2"))]
77 IncomingBody::Hyper(body) => std::pin::Pin::new(body).poll_frame(_cx).map_err(Into::into),
78 #[cfg(feature = "http3")]
79 IncomingBody::Quic(body) => std::pin::Pin::new(body).poll_frame(_cx).map_err(Into::into),
80 #[cfg(not(any(feature = "http1", feature = "http2", feature = "http3")))]
81 _ => std::task::Poll::Ready(None),
82 }
83 }
84
85 fn size_hint(&self) -> http_body::SizeHint {
86 match self {
87 #[cfg(any(feature = "http1", feature = "http2"))]
88 IncomingBody::Hyper(body) => body.size_hint(),
89 #[cfg(feature = "http3")]
90 IncomingBody::Quic(body) => body.size_hint(),
91 #[cfg(not(any(feature = "http1", feature = "http2", feature = "http3")))]
92 _ => http_body::SizeHint::default(),
93 }
94 }
95}
96
97pin_project_lite::pin_project! {
98 pub struct TrackedBody<B, T> {
100 #[pin]
101 body: B,
102 tracker: T,
103 }
104}
105
106impl<B, T> TrackedBody<B, T> {
107 pub fn new(body: B, tracker: T) -> Self {
109 Self { body, tracker }
110 }
111}
112
113#[derive(thiserror::Error)]
115pub enum TrackedBodyError<B, T>
116where
117 B: http_body::Body,
118 T: Tracker,
119{
120 #[error("body error: {0}")]
122 Body(B::Error),
123 #[error("tracker error: {0}")]
125 Tracker(T::Error),
126}
127
128impl<B, T> Debug for TrackedBodyError<B, T>
129where
130 B: http_body::Body,
131 B::Error: Debug,
132 T: Tracker,
133 T::Error: Debug,
134{
135 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136 match self {
137 TrackedBodyError::Body(err) => f.debug_tuple("TrackedBodyError::Body").field(err).finish(),
138 TrackedBodyError::Tracker(err) => f.debug_tuple("TrackedBodyError::Tracker").field(err).finish(),
139 }
140 }
141}
142
/// Observer that is told how much data has been read from a body.
///
/// [`TrackedBody`] calls [`Tracker::on_data`] for every data frame it reads and
/// turns an `Err` into [`TrackedBodyError::Tracker`].
pub trait Tracker: Send + Sync + 'static {
    /// Error returned when the tracker rejects a data frame.
    type Error;

    /// Called with the number of bytes in a data frame.
    ///
    /// The default implementation ignores `size` and always returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Implementations may return `Err` to reject the data.
    fn on_data(&self, size: usize) -> Result<(), Self::Error> {
        // Explicitly discard the argument; the default tracker accepts everything.
        let _ignored = size;
        Ok(())
    }
}
156
157impl<B, T> http_body::Body for TrackedBody<B, T>
158where
159 B: http_body::Body,
160 T: Tracker,
161{
162 type Data = B::Data;
163 type Error = TrackedBodyError<B, T>;
164
165 fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
166 let this = self.project();
167
168 match this.body.poll_frame(cx) {
169 Poll::Pending => Poll::Pending,
170 Poll::Ready(frame) => {
171 if let Some(Ok(frame)) = &frame {
172 if let Some(data) = frame.data_ref() {
173 if let Err(err) = this.tracker.on_data(data.remaining()) {
174 return Poll::Ready(Some(Err(TrackedBodyError::Tracker(err))));
175 }
176 }
177 }
178
179 Poll::Ready(frame.transpose().map_err(TrackedBodyError::Body).transpose())
180 }
181 }
182 }
183
184 fn is_end_stream(&self) -> bool {
185 self.body.is_end_stream()
186 }
187
188 fn size_hint(&self) -> http_body::SizeHint {
189 self.body.size_hint()
190 }
191}
192
#[cfg(test)]
#[cfg_attr(all(test, coverage_nightly), coverage(off))]
mod tests {
    use std::convert::Infallible;

    use crate::body::TrackedBodyError;

    /// The hand-written `Debug` impl renders `TrackedBodyError::<Variant>(<inner>)`.
    #[test]
    fn tracked_body_error_debug() {
        use std::pin::Pin;
        use std::task::{Context, Poll};

        // Tracker relying entirely on the trait's default `on_data`.
        struct NoopTracker;

        impl super::Tracker for NoopTracker {
            type Error = Infallible;
        }

        // Body that is immediately finished; only its associated types matter here.
        struct EmptyBody;

        impl http_body::Body for EmptyBody {
            type Data = bytes::Bytes;
            type Error = ();

            fn poll_frame(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
                Poll::Ready(None)
            }
        }

        let rendered = format!("{:?}", TrackedBodyError::<EmptyBody, NoopTracker>::Body(()));
        assert_eq!(rendered, "TrackedBodyError::Body(())");
    }
}