scuffle_http/backend/hyper/
mod.rs1use std::fmt::Debug;
3use std::net::SocketAddr;
4
5use scuffle_context::ContextFutExt;
6#[cfg(feature = "tracing")]
7use tracing::Instrument;
8
9use crate::error::HttpError;
10use crate::service::{HttpService, HttpServiceFactory};
11
12mod handler;
13mod stream;
14mod utils;
15
/// An HTTP server backend built on hyper.
///
/// Construct it with the builder generated by `bon::Builder`
/// (`HyperBackend::builder()`) and start it with `run`.
#[derive(Debug, Clone, bon::Builder)]
pub struct HyperBackend<F> {
    /// Context the server runs under. `run` derives its worker context from it
    /// (via `new_child`), and workers stop accepting once that derived context
    /// is done.
    ///
    /// Defaults to `scuffle_context::Context::global()`.
    #[builder(default = scuffle_context::Context::global())]
    ctx: scuffle_context::Context,
    /// Number of worker tasks that accept connections on the shared listener.
    ///
    /// Defaults to `1`.
    #[builder(default = 1)]
    worker_tasks: usize,
    /// Factory asked for a new service for every accepted connection
    /// (it receives the peer address).
    service_factory: F,
    /// Socket address the TCP listener binds to.
    bind: SocketAddr,
    /// Optional rustls server configuration. When set, every accepted TCP
    /// connection goes through a TLS handshake before it is served.
    ///
    /// `run` sets `max_early_data_size` to `0` on this config before use.
    #[cfg(feature = "tls-rustls")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))]
    rustls_config: Option<rustls::ServerConfig>,
    /// Whether HTTP/1 is enabled; passed to the connection handler as its
    /// HTTP/1 flag.
    ///
    /// Defaults to `true`.
    #[cfg(feature = "http1")]
    #[cfg_attr(docsrs, doc(cfg(feature = "http1")))]
    #[builder(default = true)]
    http1_enabled: bool,
    /// Whether HTTP/2 is enabled; passed to the connection handler as its
    /// HTTP/2 flag.
    ///
    /// Defaults to `true`.
    #[cfg(feature = "http2")]
    #[cfg_attr(docsrs, doc(cfg(feature = "http2")))]
    #[builder(default = true)]
    http2_enabled: bool,
}
54
55impl<F> HyperBackend<F>
56where
57 F: HttpServiceFactory + Clone + Send + 'static,
58 F::Error: std::error::Error + Send,
59 F::Service: Clone + Send + 'static,
60 <F::Service as HttpService>::Error: std::error::Error + Send + Sync,
61 <F::Service as HttpService>::ResBody: Send,
62 <<F::Service as HttpService>::ResBody as http_body::Body>::Data: Send,
63 <<F::Service as HttpService>::ResBody as http_body::Body>::Error: std::error::Error + Send + Sync,
64{
65 #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(bind = %self.bind)))]
69 #[allow(unused_mut)] pub async fn run(mut self) -> Result<(), HttpError<F>> {
71 #[cfg(feature = "tracing")]
72 tracing::debug!("starting server");
73
74 #[cfg(feature = "tls-rustls")]
77 if let Some(rustls_config) = self.rustls_config.as_mut() {
78 rustls_config.max_early_data_size = 0;
79 }
80
81 let listener = tokio::net::TcpListener::bind(self.bind).await?.into_std()?;
83
84 #[cfg(feature = "tls-rustls")]
85 let tls_acceptor = self
86 .rustls_config
87 .map(|c| tokio_rustls::TlsAcceptor::from(std::sync::Arc::new(c)));
88
89 let (worker_ctx, worker_handler) = self.ctx.new_child();
91
92 let workers = (0..self.worker_tasks)
93 .map(|_n| {
94 let service_factory = self.service_factory.clone();
95 let ctx = worker_ctx.clone();
96 let std_listener = listener.try_clone()?;
97 let listener = tokio::net::TcpListener::from_std(std_listener)?;
98 #[cfg(feature = "tls-rustls")]
99 let tls_acceptor = tls_acceptor.clone();
100
101 let worker_fut = async move {
102 loop {
103 #[cfg(feature = "tracing")]
104 tracing::trace!("waiting for connections");
105
106 let (mut stream, addr) = match listener.accept().with_context(ctx.clone()).await {
107 Some(Ok((tcp_stream, addr))) => (stream::Stream::Tcp(tcp_stream), addr),
108 Some(Err(e)) if utils::is_fatal_tcp_error(&e) => {
109 #[cfg(feature = "tracing")]
110 tracing::error!(err = %e, "failed to accept tcp connection");
111 return Err(HttpError::<F>::from(e));
112 }
113 Some(Err(_)) => continue,
114 None => {
115 #[cfg(feature = "tracing")]
116 tracing::trace!("context done, stopping listener");
117 break;
118 }
119 };
120
121 #[cfg(feature = "tracing")]
122 tracing::trace!(addr = %addr, "accepted tcp connection");
123
124 let ctx = ctx.clone();
125 #[cfg(feature = "tls-rustls")]
126 let tls_acceptor = tls_acceptor.clone();
127 let mut service_factory = service_factory.clone();
128
129 let connection_fut = async move {
130 #[cfg(feature = "tls-rustls")]
132 if let Some(tls_acceptor) = tls_acceptor {
133 #[cfg(feature = "tracing")]
134 tracing::trace!("accepting tls connection");
135
136 stream = match stream.try_accept_tls(&tls_acceptor).with_context(&ctx).await {
137 Some(Ok(stream)) => stream,
138 Some(Err(_err)) => {
139 #[cfg(feature = "tracing")]
140 tracing::warn!(err = %_err, "failed to accept tls connection");
141 return;
142 }
143 None => {
144 #[cfg(feature = "tracing")]
145 tracing::trace!("context done, stopping tls acceptor");
146 return;
147 }
148 };
149
150 #[cfg(feature = "tracing")]
151 tracing::trace!("accepted tls connection");
152 }
153
154 let http_service = match service_factory.new_service(addr).await {
156 Ok(service) => service,
157 Err(_e) => {
158 #[cfg(feature = "tracing")]
159 tracing::warn!(err = %_e, "failed to create service");
160 return;
161 }
162 };
163
164 #[cfg(feature = "tracing")]
165 tracing::trace!("handling connection");
166
167 #[cfg(feature = "http1")]
168 let http1 = self.http1_enabled;
169 #[cfg(not(feature = "http1"))]
170 let http1 = false;
171
172 #[cfg(feature = "http2")]
173 let http2 = self.http2_enabled;
174 #[cfg(not(feature = "http2"))]
175 let http2 = false;
176
177 let _res = handler::handle_connection::<F, _, _>(ctx, http_service, stream, http1, http2).await;
178
179 #[cfg(feature = "tracing")]
180 if let Err(e) = _res {
181 tracing::warn!(err = %e, "error handling connection");
182 }
183
184 #[cfg(feature = "tracing")]
185 tracing::trace!("connection closed");
186 };
187
188 #[cfg(feature = "tracing")]
189 let connection_fut = connection_fut.instrument(tracing::trace_span!("connection", addr = %addr));
190
191 tokio::spawn(connection_fut);
192 }
193
194 #[cfg(feature = "tracing")]
195 tracing::trace!("listener closed");
196
197 Ok(())
198 };
199
200 #[cfg(feature = "tracing")]
201 let worker_fut = worker_fut.instrument(tracing::trace_span!("worker", n = _n));
202
203 Ok(tokio::spawn(worker_fut))
204 })
205 .collect::<std::io::Result<Vec<_>>>()?;
206
207 match futures::future::try_join_all(workers).await {
208 Ok(res) => {
209 for r in res {
210 if let Err(e) = r {
211 drop(worker_ctx);
212 worker_handler.shutdown().await;
213 return Err(e);
214 }
215 }
216 }
217 Err(_e) => {
218 #[cfg(feature = "tracing")]
219 tracing::error!(err = %_e, "error running workers");
220 }
221 }
222
223 drop(worker_ctx);
224 worker_handler.shutdown().await;
225
226 #[cfg(feature = "tracing")]
227 tracing::debug!("all workers finished");
228
229 Ok(())
230 }
231}