scuffle_http/backend/hyper/
mod.rs

1//! Hyper backend.
2use std::fmt::Debug;
3use std::net::SocketAddr;
4
5use scuffle_context::ContextFutExt;
6#[cfg(feature = "tracing")]
7use tracing::Instrument;
8
9use crate::error::HttpError;
10use crate::service::{HttpService, HttpServiceFactory};
11
12mod handler;
13mod stream;
14mod utils;
15
16/// A backend that handles incoming HTTP connections using a hyper backend.
17///
18/// This is used internally by the [`HttpServer`](crate::server::HttpServer) but can be used directly if preferred.
19///
20/// Call [`run`](HyperBackend::run) to start the server.
21#[derive(Debug, Clone, bon::Builder)]
22pub struct HyperBackend<F> {
23    /// The [`scuffle_context::Context`] this server will live by.
24    #[builder(default = scuffle_context::Context::global())]
25    ctx: scuffle_context::Context,
26    /// The number of worker tasks to spawn for each server backend.
27    #[builder(default = 1)]
28    worker_tasks: usize,
29    /// The service factory that will be used to create new services.
30    service_factory: F,
31    /// The address to bind to.
32    ///
33    /// Use `[::]` for a dual-stack listener.
34    /// For example, use `[::]:80` to bind to port 80 on both IPv4 and IPv6.
35    bind: SocketAddr,
36    /// rustls config.
37    ///
38    /// Use this field to set the server into TLS mode.
39    /// It will only accept TLS connections when this is set.
40    #[cfg(feature = "tls-rustls")]
41    #[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))]
42    rustls_config: Option<rustls::ServerConfig>,
43    /// Enable HTTP/1.1.
44    #[cfg(feature = "http1")]
45    #[cfg_attr(docsrs, doc(cfg(feature = "http1")))]
46    #[builder(default = true)]
47    http1_enabled: bool,
48    /// Enable HTTP/2.
49    #[cfg(feature = "http2")]
50    #[cfg_attr(docsrs, doc(cfg(feature = "http2")))]
51    #[builder(default = true)]
52    http2_enabled: bool,
53}
54
55impl<F> HyperBackend<F>
56where
57    F: HttpServiceFactory + Clone + Send + 'static,
58    F::Error: std::error::Error + Send,
59    F::Service: Clone + Send + 'static,
60    <F::Service as HttpService>::Error: std::error::Error + Send + Sync,
61    <F::Service as HttpService>::ResBody: Send,
62    <<F::Service as HttpService>::ResBody as http_body::Body>::Data: Send,
63    <<F::Service as HttpService>::ResBody as http_body::Body>::Error: std::error::Error + Send + Sync,
64{
65    /// Run the HTTP server
66    ///
67    /// This function will bind to the address specified in `bind`, listen for incoming connections and handle requests.
68    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(bind = %self.bind)))]
69    #[allow(unused_mut)] // allow the unused `mut self`
70    pub async fn run(mut self) -> Result<(), HttpError<F>> {
71        #[cfg(feature = "tracing")]
72        tracing::debug!("starting server");
73
74        // reset to 0 because everything explodes if it's not
75        // https://github.com/hyperium/hyper/issues/3841
76        #[cfg(feature = "tls-rustls")]
77        if let Some(rustls_config) = self.rustls_config.as_mut() {
78            rustls_config.max_early_data_size = 0;
79        }
80
81        // We have to create an std listener first because the tokio listener isn't clonable
82        let listener = tokio::net::TcpListener::bind(self.bind).await?.into_std()?;
83
84        #[cfg(feature = "tls-rustls")]
85        let tls_acceptor = self
86            .rustls_config
87            .map(|c| tokio_rustls::TlsAcceptor::from(std::sync::Arc::new(c)));
88
89        // Create a child context for the workers so we can shut them down if one of them fails without shutting down the main context
90        let (worker_ctx, worker_handler) = self.ctx.new_child();
91
92        let workers = (0..self.worker_tasks)
93            .map(|_n| {
94                let service_factory = self.service_factory.clone();
95                let ctx = worker_ctx.clone();
96                let std_listener = listener.try_clone()?;
97                let listener = tokio::net::TcpListener::from_std(std_listener)?;
98                #[cfg(feature = "tls-rustls")]
99                let tls_acceptor = tls_acceptor.clone();
100
101                let worker_fut = async move {
102                    loop {
103                        #[cfg(feature = "tracing")]
104                        tracing::trace!("waiting for connections");
105
106                        let (mut stream, addr) = match listener.accept().with_context(ctx.clone()).await {
107                            Some(Ok((tcp_stream, addr))) => (stream::Stream::Tcp(tcp_stream), addr),
108                            Some(Err(e)) if utils::is_fatal_tcp_error(&e) => {
109                                #[cfg(feature = "tracing")]
110                                tracing::error!(err = %e, "failed to accept tcp connection");
111                                return Err(HttpError::<F>::from(e));
112                            }
113                            Some(Err(_)) => continue,
114                            None => {
115                                #[cfg(feature = "tracing")]
116                                tracing::trace!("context done, stopping listener");
117                                break;
118                            }
119                        };
120
121                        #[cfg(feature = "tracing")]
122                        tracing::trace!(addr = %addr, "accepted tcp connection");
123
124                        let ctx = ctx.clone();
125                        #[cfg(feature = "tls-rustls")]
126                        let tls_acceptor = tls_acceptor.clone();
127                        let mut service_factory = service_factory.clone();
128
129                        let connection_fut = async move {
130                            // Perform the TLS handshake if the acceptor is set
131                            #[cfg(feature = "tls-rustls")]
132                            if let Some(tls_acceptor) = tls_acceptor {
133                                #[cfg(feature = "tracing")]
134                                tracing::trace!("accepting tls connection");
135
136                                stream = match stream.try_accept_tls(&tls_acceptor).with_context(&ctx).await {
137                                    Some(Ok(stream)) => stream,
138                                    Some(Err(_err)) => {
139                                        #[cfg(feature = "tracing")]
140                                        tracing::warn!(err = %_err, "failed to accept tls connection");
141                                        return;
142                                    }
143                                    None => {
144                                        #[cfg(feature = "tracing")]
145                                        tracing::trace!("context done, stopping tls acceptor");
146                                        return;
147                                    }
148                                };
149
150                                #[cfg(feature = "tracing")]
151                                tracing::trace!("accepted tls connection");
152                            }
153
154                            // make a new service
155                            let http_service = match service_factory.new_service(addr).await {
156                                Ok(service) => service,
157                                Err(_e) => {
158                                    #[cfg(feature = "tracing")]
159                                    tracing::warn!(err = %_e, "failed to create service");
160                                    return;
161                                }
162                            };
163
164                            #[cfg(feature = "tracing")]
165                            tracing::trace!("handling connection");
166
167                            #[cfg(feature = "http1")]
168                            let http1 = self.http1_enabled;
169                            #[cfg(not(feature = "http1"))]
170                            let http1 = false;
171
172                            #[cfg(feature = "http2")]
173                            let http2 = self.http2_enabled;
174                            #[cfg(not(feature = "http2"))]
175                            let http2 = false;
176
177                            let _res = handler::handle_connection::<F, _, _>(ctx, http_service, stream, http1, http2).await;
178
179                            #[cfg(feature = "tracing")]
180                            if let Err(e) = _res {
181                                tracing::warn!(err = %e, "error handling connection");
182                            }
183
184                            #[cfg(feature = "tracing")]
185                            tracing::trace!("connection closed");
186                        };
187
188                        #[cfg(feature = "tracing")]
189                        let connection_fut = connection_fut.instrument(tracing::trace_span!("connection", addr = %addr));
190
191                        tokio::spawn(connection_fut);
192                    }
193
194                    #[cfg(feature = "tracing")]
195                    tracing::trace!("listener closed");
196
197                    Ok(())
198                };
199
200                #[cfg(feature = "tracing")]
201                let worker_fut = worker_fut.instrument(tracing::trace_span!("worker", n = _n));
202
203                Ok(tokio::spawn(worker_fut))
204            })
205            .collect::<std::io::Result<Vec<_>>>()?;
206
207        match futures::future::try_join_all(workers).await {
208            Ok(res) => {
209                for r in res {
210                    if let Err(e) = r {
211                        drop(worker_ctx);
212                        worker_handler.shutdown().await;
213                        return Err(e);
214                    }
215                }
216            }
217            Err(_e) => {
218                #[cfg(feature = "tracing")]
219                tracing::error!(err = %_e, "error running workers");
220            }
221        }
222
223        drop(worker_ctx);
224        worker_handler.shutdown().await;
225
226        #[cfg(feature = "tracing")]
227        tracing::debug!("all workers finished");
228
229        Ok(())
230    }
231}