scuffle_signal/
lib.rs

//! A crate designed to provide a more user-friendly interface to
//! `tokio::signal`.
3//!
4//! ## Why do we need this?
5//!
//! The `tokio::signal` module provides a way for us to wait for a signal to be
//! received in a non-blocking way. This crate builds on it with a more ergonomic
//! interface that lets you listen for multiple signals concurrently.
9//!
10//! ## Example
11//!
12//! ```rust
13//! # #[cfg(unix)]
14//! # {
15//! use scuffle_signal::SignalHandler;
16//! use tokio::signal::unix::SignalKind;
17//!
18//! # tokio_test::block_on(async {
//! let handler = SignalHandler::new()
//!     .with_signal(SignalKind::interrupt())
//!     .with_signal(SignalKind::terminate());
//!
//! # // SAFETY: This is a test, and we control the process.
//! # unsafe {
//! #    libc::raise(SignalKind::interrupt().as_raw_value());
//! # }
//! // Wait for a signal to be received
//! let signal = handler.await;
//!
//! // Handle the signal
//! if signal == SignalKind::interrupt() {
//!     // Handle SIGINT
//!     println!("received SIGINT");
//! } else if signal == SignalKind::terminate() {
//!     // Handle SIGTERM
//!     println!("received SIGTERM");
//! }
//! # });
//! # }
//! ```
43//! # });
44//! # }
45//! ```
46//!
47//! ## Status
48//!
49//! This crate is currently under development and is not yet stable.
50//!
51//! Unit tests are not yet fully implemented. Use at your own risk.
52//!
53//! ## License
54//!
55//! This project is licensed under the [MIT](./LICENSE.MIT) or
56//! [Apache-2.0](./LICENSE.Apache-2.0) license. You can choose between one of
57//! them if you use this work.
58//!
59//! `SPDX-License-Identifier: MIT OR Apache-2.0`
60#![cfg_attr(all(coverage_nightly, test), feature(coverage_attribute))]
61#![deny(missing_docs)]
62#![deny(unreachable_pub)]
63#![deny(clippy::undocumented_unsafe_blocks)]
64#![deny(clippy::multiple_unsafe_ops_per_block)]
65
66use std::pin::Pin;
67use std::task::{Context, Poll};
68
69#[cfg(unix)]
70use tokio::signal::unix;
71#[cfg(unix)]
72pub use tokio::signal::unix::SignalKind as UnixSignalKind;
73
74#[cfg(feature = "bootstrap")]
75mod bootstrap;
76
77#[cfg(feature = "bootstrap")]
78pub use bootstrap::{SignalConfig, SignalSvc};
79
80/// The type of signal to listen for.
81#[derive(Debug, Clone, Copy, Eq)]
82pub enum SignalKind {
83    /// Represents the interrupt signal, which is `SIGINT` on Unix and `Ctrl-C` on Windows.
84    Interrupt,
85    /// Represents the terminate signal, which is `SIGTERM` on Unix and `Ctrl-Close` on Windows.
86    Terminate,
87    /// Represents a Windows-specific signal kind, as defined in `WindowsSignalKind`.
88    #[cfg(windows)]
89    Windows(WindowsSignalKind),
90    /// Represents a Unix-specific signal kind, wrapping `tokio::signal::unix::SignalKind`.
91    #[cfg(unix)]
92    Unix(UnixSignalKind),
93}
94
95impl PartialEq for SignalKind {
96    fn eq(&self, other: &Self) -> bool {
97        #[cfg(unix)]
98        const INTERRUPT: UnixSignalKind = UnixSignalKind::interrupt();
99        #[cfg(unix)]
100        const TERMINATE: UnixSignalKind = UnixSignalKind::terminate();
101
102        match (self, other) {
103            #[cfg(windows)]
104            (
105                Self::Interrupt | Self::Windows(WindowsSignalKind::CtrlC),
106                Self::Interrupt | Self::Windows(WindowsSignalKind::CtrlC),
107            ) => true,
108            #[cfg(windows)]
109            (
110                Self::Terminate | Self::Windows(WindowsSignalKind::CtrlClose),
111                Self::Terminate | Self::Windows(WindowsSignalKind::CtrlClose),
112            ) => true,
113            #[cfg(windows)]
114            (Self::Windows(a), Self::Windows(b)) => a == b,
115            #[cfg(unix)]
116            (Self::Interrupt | Self::Unix(INTERRUPT), Self::Interrupt | Self::Unix(INTERRUPT)) => true,
117            #[cfg(unix)]
118            (Self::Terminate | Self::Unix(TERMINATE), Self::Terminate | Self::Unix(TERMINATE)) => true,
119            #[cfg(unix)]
120            (Self::Unix(a), Self::Unix(b)) => a == b,
121            _ => false,
122        }
123    }
124}
125
126#[cfg(unix)]
127impl From<UnixSignalKind> for SignalKind {
128    fn from(value: UnixSignalKind) -> Self {
129        match value {
130            kind if kind == UnixSignalKind::interrupt() => Self::Interrupt,
131            kind if kind == UnixSignalKind::terminate() => Self::Terminate,
132            kind => Self::Unix(kind),
133        }
134    }
135}
136
137#[cfg(unix)]
138impl PartialEq<UnixSignalKind> for SignalKind {
139    fn eq(&self, other: &UnixSignalKind) -> bool {
140        match self {
141            Self::Interrupt => other == &UnixSignalKind::interrupt(),
142            Self::Terminate => other == &UnixSignalKind::terminate(),
143            Self::Unix(kind) => kind == other,
144        }
145    }
146}
147
/// Console control events that a Windows process can listen for.
///
/// `CtrlC` and `CtrlClose` are also reachable through the portable
/// [`SignalKind::Interrupt`] and [`SignalKind::Terminate`] variants.
#[cfg(windows)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WindowsSignalKind {
    /// The `Ctrl-Break` signal.
    CtrlBreak,
    /// The `Ctrl-C` signal.
    CtrlC,
    /// The `Ctrl-Close` signal.
    CtrlClose,
    /// The `Ctrl-Logoff` signal.
    CtrlLogoff,
    /// The `Ctrl-Shutdown` signal.
    CtrlShutdown,
}
163
#[cfg(windows)]
impl From<WindowsSignalKind> for SignalKind {
    /// Converts a Windows console event, collapsing `Ctrl-C` and `Ctrl-Close`
    /// onto the portable [`SignalKind::Interrupt`] and [`SignalKind::Terminate`]
    /// variants; the remaining events are wrapped in [`SignalKind::Windows`].
    fn from(value: WindowsSignalKind) -> Self {
        match value {
            WindowsSignalKind::CtrlC => Self::Interrupt,
            WindowsSignalKind::CtrlClose => Self::Terminate,
            WindowsSignalKind::CtrlBreak | WindowsSignalKind::CtrlLogoff | WindowsSignalKind::CtrlShutdown => {
                Self::Windows(value)
            }
        }
    }
}
176
#[cfg(windows)]
impl PartialEq<WindowsSignalKind> for SignalKind {
    /// Compares against a raw Windows console event, treating the portable
    /// variants as `Ctrl-C` / `Ctrl-Close`.
    fn eq(&self, other: &WindowsSignalKind) -> bool {
        let this = match *self {
            Self::Interrupt => WindowsSignalKind::CtrlC,
            Self::Terminate => WindowsSignalKind::CtrlClose,
            Self::Windows(kind) => kind,
        };
        this == *other
    }
}
187
/// The live tokio listener backing one registered Windows [`SignalKind`].
#[cfg(windows)]
#[derive(Debug)]
enum WindowsSignalValue {
    /// Listener for `Ctrl-Break`.
    CtrlBreak(tokio::signal::windows::CtrlBreak),
    /// Listener for `Ctrl-C`.
    CtrlC(tokio::signal::windows::CtrlC),
    /// Listener for `Ctrl-Close`.
    CtrlClose(tokio::signal::windows::CtrlClose),
    /// Listener for `Ctrl-Logoff`.
    CtrlLogoff(tokio::signal::windows::CtrlLogoff),
    /// Listener for `Ctrl-Shutdown`.
    CtrlShutdown(tokio::signal::windows::CtrlShutdown),
    /// Test-only stand-in: the kind it answers to, plus a stream of mocked raises.
    #[cfg(test)]
    Mock(SignalKind, Pin<Box<tokio_stream::wrappers::BroadcastStream<SignalKind>>>),
}
199
#[cfg(windows)]
impl WindowsSignalValue {
    /// Polls the wrapped listener for its next notification.
    ///
    /// Returns whatever the underlying tokio listener returns: `Ready(Some(()))`
    /// for a notification, `Ready(None)` if its stream has closed, or `Pending`.
    fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> {
        #[cfg(test)]
        use futures::Stream;

        match self {
            Self::CtrlBreak(signal) => signal.poll_recv(cx),
            Self::CtrlC(signal) => signal.poll_recv(cx),
            Self::CtrlClose(signal) => signal.poll_recv(cx),
            Self::CtrlLogoff(signal) => signal.poll_recv(cx),
            Self::CtrlShutdown(signal) => signal.poll_recv(cx),
            #[cfg(test)]
            Self::Mock(kind, receiver) => loop {
                match receiver.as_mut().poll_next(cx) {
                    // A mocked raise of the kind this listener stands in for.
                    Poll::Ready(Some(Ok(raised))) if raised == *kind => break Poll::Ready(Some(())),
                    // A raise meant for a different listener: discard it and look at
                    // the next buffered message instead of waking ourselves and spinning.
                    Poll::Ready(Some(Ok(_))) => continue,
                    Poll::Ready(v) => unreachable!("receiver should always have a value: {:?}", v),
                    // `poll_next` has already registered `cx`'s waker with the
                    // broadcast receiver, so no self-wake is needed here.
                    Poll::Pending => break Poll::Pending,
                }
            },
        }
    }
}
228
/// The listener type that [`SignalHandler`] stores for each registered
/// [`SignalKind`] on Unix: tokio's Unix signal stream.
#[cfg(unix)]
type Signal = unix::Signal;

/// The listener type that [`SignalHandler`] stores for each registered
/// [`SignalKind`] on Windows: a wrapper over the per-event tokio listeners.
#[cfg(windows)]
type Signal = WindowsSignalValue;
234
235impl SignalKind {
236    #[cfg(unix)]
237    fn listen(&self) -> Result<Signal, std::io::Error> {
238        match self {
239            Self::Interrupt => tokio::signal::unix::signal(UnixSignalKind::interrupt()),
240            Self::Terminate => tokio::signal::unix::signal(UnixSignalKind::terminate()),
241            Self::Unix(kind) => tokio::signal::unix::signal(*kind),
242        }
243    }
244
245    #[cfg(windows)]
246    fn listen(&self) -> Result<Signal, std::io::Error> {
247        #[cfg(test)]
248        if cfg!(test) {
249            return Ok(WindowsSignalValue::Mock(
250                *self,
251                Box::pin(tokio_stream::wrappers::BroadcastStream::new(test::SignalMocker::subscribe())),
252            ));
253        }
254
255        match self {
256            // https://learn.microsoft.com/en-us/windows/console/ctrl-c-and-ctrl-break-signals
257            Self::Interrupt | Self::Windows(WindowsSignalKind::CtrlC) => {
258                Ok(WindowsSignalValue::CtrlC(tokio::signal::windows::ctrl_c()?))
259            }
260            // https://learn.microsoft.com/en-us/windows/console/ctrl-close-signal
261            Self::Terminate | Self::Windows(WindowsSignalKind::CtrlClose) => {
262                Ok(WindowsSignalValue::CtrlClose(tokio::signal::windows::ctrl_close()?))
263            }
264            Self::Windows(WindowsSignalKind::CtrlBreak) => {
265                Ok(WindowsSignalValue::CtrlBreak(tokio::signal::windows::ctrl_break()?))
266            }
267            Self::Windows(WindowsSignalKind::CtrlLogoff) => {
268                Ok(WindowsSignalValue::CtrlLogoff(tokio::signal::windows::ctrl_logoff()?))
269            }
270            Self::Windows(WindowsSignalKind::CtrlShutdown) => {
271                Ok(WindowsSignalValue::CtrlShutdown(tokio::signal::windows::ctrl_shutdown()?))
272            }
273        }
274    }
275}
276
277/// A handler for listening to multiple signals, and providing a future for
278/// receiving them.
279///
280/// This is useful for applications that need to listen for multiple signals,
281/// and want to react to them in a non-blocking way. Typically you would need to
282/// use a tokio::select{} to listen for multiple signals, but this provides a
283/// more ergonomic interface for doing so.
284///
285/// After a signal is received you can poll the handler again to wait for
286/// another signal. Dropping the handle will cancel the signal subscription
287///
288/// # Example
289///
290/// ```rust
291/// # #[cfg(unix)]
292/// # {
293/// use scuffle_signal::SignalHandler;
294/// use tokio::signal::unix::SignalKind;
295///
296/// # tokio_test::block_on(async {
297/// let mut handler = SignalHandler::new()
298///     .with_signal(SignalKind::interrupt())
299///     .with_signal(SignalKind::terminate());
300///
301/// # // Safety: This is a test, and we control the process.
302/// # unsafe {
303/// #    libc::raise(SignalKind::interrupt().as_raw_value());
304/// # }
305/// // Wait for a signal to be received
306/// let signal = handler.await;
307///
308/// // Handle the signal
309/// let interrupt = SignalKind::interrupt();
310/// let terminate = SignalKind::terminate();
311/// match signal {
312///     interrupt => {
313///         // Handle SIGINT
314///         println!("received SIGINT");
315///     },
316///     terminate => {
317///         // Handle SIGTERM
318///         println!("received SIGTERM");
319///     },
320/// }
321/// # });
322/// # }
323/// ```
324#[derive(Debug)]
325#[must_use = "signal handlers must be used to wait for signals"]
326pub struct SignalHandler {
327    signals: Vec<(SignalKind, Signal)>,
328}
329
330impl Default for SignalHandler {
331    fn default() -> Self {
332        Self::new()
333    }
334}
335
336impl SignalHandler {
337    /// Create a new `SignalHandler` with no signals.
338    pub const fn new() -> Self {
339        Self { signals: Vec::new() }
340    }
341
342    /// Create a new `SignalHandler` with the given signals.
343    pub fn with_signals<T: Into<SignalKind>>(signals: impl IntoIterator<Item = T>) -> Self {
344        let mut handler = Self::new();
345
346        for signal in signals {
347            handler = handler.with_signal(signal.into());
348        }
349
350        handler
351    }
352
353    /// Add a signal to the handler.
354    ///
355    /// If the signal is already in the handler, it will not be added again.
356    pub fn with_signal(mut self, kind: impl Into<SignalKind>) -> Self {
357        self.add_signal(kind);
358        self
359    }
360
361    /// Add a signal to the handler.
362    ///
363    /// If the signal is already in the handler, it will not be added again.
364    pub fn add_signal(&mut self, kind: impl Into<SignalKind>) -> &mut Self {
365        let kind = kind.into();
366        if self.signals.iter().any(|(k, _)| k == &kind) {
367            return self;
368        }
369
370        let signal = kind.listen().expect("failed to create signal");
371
372        self.signals.push((kind, signal));
373
374        self
375    }
376
377    /// Wait for a signal to be received.
378    /// This is equivilant to calling (&mut handler).await, but is more
379    /// ergonomic if you want to not take ownership of the handler.
380    pub async fn recv(&mut self) -> SignalKind {
381        self.await
382    }
383
384    /// Poll for a signal to be received.
385    /// Does not require pinning the handler.
386    pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<SignalKind> {
387        for (kind, signal) in self.signals.iter_mut() {
388            if signal.poll_recv(cx).is_ready() {
389                return Poll::Ready(*kind);
390            }
391        }
392
393        Poll::Pending
394    }
395}
396
397impl std::future::Future for SignalHandler {
398    type Output = SignalKind;
399
400    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
401        self.poll_recv(cx)
402    }
403}
404
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod test {
    use std::time::Duration;

    use scuffle_future_ext::FutureExt;

    use crate::{SignalHandler, SignalKind};

    /// Test-only stand-in for the Windows console control handlers.
    ///
    /// `raise` broadcasts a kind to every receiver handed out by `subscribe`;
    /// in test builds `SignalKind::listen` subscribes one receiver per listener.
    #[cfg(windows)]
    pub(crate) struct SignalMocker(tokio::sync::broadcast::Sender<SignalKind>);

    #[cfg(windows)]
    impl SignalMocker {
        fn new() -> Self {
            let (sender, _) = tokio::sync::broadcast::channel(100);
            Self(sender)
        }

        fn raise(kind: SignalKind) {
            SIGNAL_MOCKER.with(|local| local.0.send(kind).unwrap());
        }

        pub(crate) fn subscribe() -> tokio::sync::broadcast::Receiver<SignalKind> {
            SIGNAL_MOCKER.with(|local| local.0.subscribe())
        }
    }

    // Thread-local so tests running on different threads do not share a mocker
    // (assumes each test's runtime polls on its own test thread — the default
    // current-thread flavor of `#[tokio::test]`).
    #[cfg(windows)]
    thread_local! {
        static SIGNAL_MOCKER: SignalMocker = SignalMocker::new();
    }

    #[cfg(windows)]
    pub(crate) async fn raise_signal(kind: SignalKind) {
        SignalMocker::raise(kind);
    }

    #[cfg(unix)]
    pub(crate) async fn raise_signal(kind: SignalKind) {
        let raw = match kind {
            SignalKind::Interrupt => libc::SIGINT,
            SignalKind::Terminate => libc::SIGTERM,
            SignalKind::Unix(kind) => kind.as_raw_value(),
        };

        // SAFETY: This is a test, and we control the process.
        unsafe {
            libc::raise(raw);
        }
    }

    #[cfg(windows)]
    #[tokio::test]
    async fn signal_handler() {
        use crate::WindowsSignalKind;

        let mut handler = SignalHandler::with_signals([WindowsSignalKind::CtrlC, WindowsSignalKind::CtrlBreak]);

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        raise_signal(SignalKind::Windows(WindowsSignalKind::CtrlC)).await;

        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, WindowsSignalKind::CtrlC, "expected CtrlC");

        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await;
        assert!(recv.is_err(), "expected timeout");

        raise_signal(SignalKind::Windows(WindowsSignalKind::CtrlBreak)).await;

        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, WindowsSignalKind::CtrlBreak, "expected CtrlBreak");
    }

    #[cfg(windows)]
    #[tokio::test]
    async fn add_signal() {
        use crate::WindowsSignalKind;

        let mut handler = SignalHandler::new();

        handler
            .add_signal(WindowsSignalKind::CtrlC)
            .add_signal(WindowsSignalKind::CtrlBreak)
            .add_signal(WindowsSignalKind::CtrlC);

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        raise_signal(SignalKind::Windows(WindowsSignalKind::CtrlC)).await;

        let recv = handler.recv().with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, WindowsSignalKind::CtrlC, "expected CtrlC");

        raise_signal(SignalKind::Windows(WindowsSignalKind::CtrlBreak)).await;

        let recv = handler.recv().with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, WindowsSignalKind::CtrlBreak, "expected CtrlBreak");
    }

    #[cfg(all(not(valgrind), unix))] // test is time-sensitive
    #[tokio::test]
    async fn signal_handler() {
        use crate::UnixSignalKind;

        let mut handler = SignalHandler::with_signals([UnixSignalKind::user_defined1()])
            .with_signal(UnixSignalKind::user_defined2())
            .with_signal(UnixSignalKind::user_defined1());

        raise_signal(SignalKind::Unix(UnixSignalKind::user_defined1())).await;

        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, SignalKind::Unix(UnixSignalKind::user_defined1()), "expected SIGUSR1");

        // We already received the signal, so polling again should return Poll::Pending
        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await;

        assert!(recv.is_err(), "expected timeout");

        raise_signal(SignalKind::Unix(UnixSignalKind::user_defined2())).await;

        // We should be able to receive the signal again
        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, UnixSignalKind::user_defined2(), "expected SIGUSR2");
    }

    #[cfg(all(not(valgrind), unix))] // test is time-sensitive
    #[tokio::test]
    async fn add_signal() {
        use crate::UnixSignalKind;

        // Signal listeners are process-wide and tests run concurrently, so this
        // test uses signals that `signal_handler` neither raises nor listens for.
        // Sharing SIGUSR1/SIGUSR2 would let each test observe the other's raises
        // and break `signal_handler`'s "expected timeout" assertion.
        let mut handler = SignalHandler::new();

        handler
            .add_signal(UnixSignalKind::hangup())
            .add_signal(UnixSignalKind::alarm())
            .add_signal(UnixSignalKind::alarm());

        raise_signal(SignalKind::Unix(UnixSignalKind::hangup())).await;

        let recv = handler.recv().with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, UnixSignalKind::hangup(), "expected SIGHUP");

        raise_signal(SignalKind::Unix(UnixSignalKind::alarm())).await;

        let recv = handler.recv().with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, UnixSignalKind::alarm(), "expected SIGALRM");
    }

    #[cfg(not(valgrind))] // test is time-sensitive
    #[tokio::test]
    async fn no_signals() {
        let mut handler = SignalHandler::default();

        // Expected to timeout
        assert!(handler.recv().with_timeout(Duration::from_millis(500)).await.is_err());
    }

    #[cfg(windows)]
    #[test]
    fn signal_kind_eq() {
        use crate::WindowsSignalKind;

        assert_eq!(SignalKind::Interrupt, SignalKind::Windows(WindowsSignalKind::CtrlC));
        assert_eq!(SignalKind::Terminate, SignalKind::Windows(WindowsSignalKind::CtrlClose));
        assert_eq!(SignalKind::Windows(WindowsSignalKind::CtrlC), SignalKind::Interrupt);
        assert_eq!(SignalKind::Windows(WindowsSignalKind::CtrlClose), SignalKind::Terminate);
        assert_ne!(SignalKind::Interrupt, SignalKind::Terminate);
        assert_eq!(
            SignalKind::Windows(WindowsSignalKind::CtrlBreak),
            SignalKind::Windows(WindowsSignalKind::CtrlBreak)
        );
    }

    #[cfg(unix)]
    #[test]
    fn signal_kind_eq() {
        use crate::UnixSignalKind;

        assert_eq!(SignalKind::Interrupt, SignalKind::Unix(UnixSignalKind::interrupt()));
        assert_eq!(SignalKind::Terminate, SignalKind::Unix(UnixSignalKind::terminate()));
        assert_eq!(SignalKind::Unix(UnixSignalKind::interrupt()), SignalKind::Interrupt);
        assert_eq!(SignalKind::Unix(UnixSignalKind::terminate()), SignalKind::Terminate);
        assert_ne!(SignalKind::Interrupt, SignalKind::Terminate);
        assert_eq!(
            SignalKind::Unix(UnixSignalKind::user_defined1()),
            SignalKind::Unix(UnixSignalKind::user_defined1())
        );
    }
}