1#![cfg_attr(all(coverage_nightly, test), feature(coverage_attribute))]
61#![deny(missing_docs)]
62#![deny(unreachable_pub)]
63#![deny(clippy::undocumented_unsafe_blocks)]
64#![deny(clippy::multiple_unsafe_ops_per_block)]
65
66use std::pin::Pin;
67use std::task::{Context, Poll};
68
69#[cfg(unix)]
70use tokio::signal::unix;
71#[cfg(unix)]
72pub use tokio::signal::unix::SignalKind as UnixSignalKind;
73
74#[cfg(feature = "bootstrap")]
75mod bootstrap;
76
77#[cfg(feature = "bootstrap")]
78pub use bootstrap::{SignalConfig, SignalSvc};
79
80#[derive(Debug, Clone, Copy, Eq)]
82pub enum SignalKind {
83 Interrupt,
85 Terminate,
87 #[cfg(windows)]
89 Windows(WindowsSignalKind),
90 #[cfg(unix)]
92 Unix(UnixSignalKind),
93}
94
95impl PartialEq for SignalKind {
96 fn eq(&self, other: &Self) -> bool {
97 #[cfg(unix)]
98 const INTERRUPT: UnixSignalKind = UnixSignalKind::interrupt();
99 #[cfg(unix)]
100 const TERMINATE: UnixSignalKind = UnixSignalKind::terminate();
101
102 match (self, other) {
103 #[cfg(windows)]
104 (
105 Self::Interrupt | Self::Windows(WindowsSignalKind::CtrlC),
106 Self::Interrupt | Self::Windows(WindowsSignalKind::CtrlC),
107 ) => true,
108 #[cfg(windows)]
109 (
110 Self::Terminate | Self::Windows(WindowsSignalKind::CtrlClose),
111 Self::Terminate | Self::Windows(WindowsSignalKind::CtrlClose),
112 ) => true,
113 #[cfg(windows)]
114 (Self::Windows(a), Self::Windows(b)) => a == b,
115 #[cfg(unix)]
116 (Self::Interrupt | Self::Unix(INTERRUPT), Self::Interrupt | Self::Unix(INTERRUPT)) => true,
117 #[cfg(unix)]
118 (Self::Terminate | Self::Unix(TERMINATE), Self::Terminate | Self::Unix(TERMINATE)) => true,
119 #[cfg(unix)]
120 (Self::Unix(a), Self::Unix(b)) => a == b,
121 _ => false,
122 }
123 }
124}
125
126#[cfg(unix)]
127impl From<UnixSignalKind> for SignalKind {
128 fn from(value: UnixSignalKind) -> Self {
129 match value {
130 kind if kind == UnixSignalKind::interrupt() => Self::Interrupt,
131 kind if kind == UnixSignalKind::terminate() => Self::Terminate,
132 kind => Self::Unix(kind),
133 }
134 }
135}
136
137#[cfg(unix)]
138impl PartialEq<UnixSignalKind> for SignalKind {
139 fn eq(&self, other: &UnixSignalKind) -> bool {
140 match self {
141 Self::Interrupt => other == &UnixSignalKind::interrupt(),
142 Self::Terminate => other == &UnixSignalKind::terminate(),
143 Self::Unix(kind) => kind == other,
144 }
145 }
146}
147
/// A Windows console event that a [`SignalHandler`] can listen for.
#[cfg(windows)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WindowsSignalKind {
    /// Ctrl-Break, listened for via `tokio::signal::windows::ctrl_break`.
    CtrlBreak,
    /// Ctrl-C, listened for via `tokio::signal::windows::ctrl_c`; the same listener
    /// [`SignalKind::Interrupt`] uses on Windows.
    CtrlC,
    /// Console close, listened for via `tokio::signal::windows::ctrl_close`; the same listener
    /// [`SignalKind::Terminate`] uses on Windows.
    CtrlClose,
    /// User logoff, listened for via `tokio::signal::windows::ctrl_logoff`.
    CtrlLogoff,
    /// System shutdown, listened for via `tokio::signal::windows::ctrl_shutdown`.
    CtrlShutdown,
}
163
#[cfg(windows)]
impl From<WindowsSignalKind> for SignalKind {
    /// Converts a Windows console event, normalising Ctrl-C and console close to the portable
    /// [`SignalKind::Interrupt`] / [`SignalKind::Terminate`] variants.
    fn from(value: WindowsSignalKind) -> Self {
        match value {
            WindowsSignalKind::CtrlC => Self::Interrupt,
            WindowsSignalKind::CtrlClose => Self::Terminate,
            other @ (WindowsSignalKind::CtrlBreak
            | WindowsSignalKind::CtrlLogoff
            | WindowsSignalKind::CtrlShutdown) => Self::Windows(other),
        }
    }
}
176
#[cfg(windows)]
impl PartialEq<WindowsSignalKind> for SignalKind {
    /// True when this kind listens for exactly the Windows console event `other`.
    fn eq(&self, other: &WindowsSignalKind) -> bool {
        let native = match *self {
            Self::Interrupt => WindowsSignalKind::CtrlC,
            Self::Terminate => WindowsSignalKind::CtrlClose,
            Self::Windows(kind) => kind,
        };
        native == *other
    }
}
187
/// A live tokio listener for a single Windows console event.
#[cfg(windows)]
#[derive(Debug)]
enum WindowsSignalValue {
    /// Listener from `tokio::signal::windows::ctrl_c`.
    CtrlC(tokio::signal::windows::CtrlC),
    /// Listener from `tokio::signal::windows::ctrl_close`.
    CtrlClose(tokio::signal::windows::CtrlClose),
    /// Listener from `tokio::signal::windows::ctrl_break`.
    CtrlBreak(tokio::signal::windows::CtrlBreak),
    /// Listener from `tokio::signal::windows::ctrl_logoff`.
    CtrlLogoff(tokio::signal::windows::CtrlLogoff),
    /// Listener from `tokio::signal::windows::ctrl_shutdown`.
    CtrlShutdown(tokio::signal::windows::CtrlShutdown),
    /// Test-only listener: the kind it waits for, plus the broadcast stream of mock events.
    #[cfg(test)]
    Mock(SignalKind, Pin<Box<tokio_stream::wrappers::BroadcastStream<SignalKind>>>),
}
199
#[cfg(windows)]
impl WindowsSignalValue {
    /// Polls the wrapped listener for its next event.
    ///
    /// Real listeners forward tokio's `poll_recv` result unchanged. The test-only `Mock`
    /// listener resolves to `Ready(Some(()))` once an event for its own kind arrives.
    fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> {
        #[cfg(test)]
        use futures::Stream;

        match self {
            Self::CtrlBreak(signal) => signal.poll_recv(cx),
            Self::CtrlC(signal) => signal.poll_recv(cx),
            Self::CtrlClose(signal) => signal.poll_recv(cx),
            Self::CtrlLogoff(signal) => signal.poll_recv(cx),
            Self::CtrlShutdown(signal) => signal.poll_recv(cx),
            #[cfg(test)]
            Self::Mock(kind, receiver) => loop {
                match receiver.as_mut().poll_next(cx) {
                    Poll::Ready(Some(Ok(raised))) if raised == *kind => break Poll::Ready(Some(())),
                    // Events for other kinds, and `Lagged` notifications (the channel overflowed
                    // and dropped messages), are not for this listener. Keep draining rather than
                    // self-waking and rather than treating a lag as impossible.
                    Poll::Ready(Some(_)) => continue,
                    // The thread-local mocker owns the sender, so the stream should never end.
                    Poll::Ready(None) => unreachable!("receiver should always have a value: None"),
                    // `BroadcastStream` has registered `cx`'s waker, so no manual wake is needed.
                    Poll::Pending => break Poll::Pending,
                }
            },
        }
    }
}
228
229#[cfg(unix)]
230type Signal = unix::Signal;
231
232#[cfg(windows)]
233type Signal = WindowsSignalValue;
234
235impl SignalKind {
236 #[cfg(unix)]
237 fn listen(&self) -> Result<Signal, std::io::Error> {
238 match self {
239 Self::Interrupt => tokio::signal::unix::signal(UnixSignalKind::interrupt()),
240 Self::Terminate => tokio::signal::unix::signal(UnixSignalKind::terminate()),
241 Self::Unix(kind) => tokio::signal::unix::signal(*kind),
242 }
243 }
244
245 #[cfg(windows)]
246 fn listen(&self) -> Result<Signal, std::io::Error> {
247 #[cfg(test)]
248 if cfg!(test) {
249 return Ok(WindowsSignalValue::Mock(
250 *self,
251 Box::pin(tokio_stream::wrappers::BroadcastStream::new(test::SignalMocker::subscribe())),
252 ));
253 }
254
255 match self {
256 Self::Interrupt | Self::Windows(WindowsSignalKind::CtrlC) => {
258 Ok(WindowsSignalValue::CtrlC(tokio::signal::windows::ctrl_c()?))
259 }
260 Self::Terminate | Self::Windows(WindowsSignalKind::CtrlClose) => {
262 Ok(WindowsSignalValue::CtrlClose(tokio::signal::windows::ctrl_close()?))
263 }
264 Self::Windows(WindowsSignalKind::CtrlBreak) => {
265 Ok(WindowsSignalValue::CtrlBreak(tokio::signal::windows::ctrl_break()?))
266 }
267 Self::Windows(WindowsSignalKind::CtrlLogoff) => {
268 Ok(WindowsSignalValue::CtrlLogoff(tokio::signal::windows::ctrl_logoff()?))
269 }
270 Self::Windows(WindowsSignalKind::CtrlShutdown) => {
271 Ok(WindowsSignalValue::CtrlShutdown(tokio::signal::windows::ctrl_shutdown()?))
272 }
273 }
274 }
275}
276
277#[derive(Debug)]
325#[must_use = "signal handlers must be used to wait for signals"]
326pub struct SignalHandler {
327 signals: Vec<(SignalKind, Signal)>,
328}
329
330impl Default for SignalHandler {
331 fn default() -> Self {
332 Self::new()
333 }
334}
335
336impl SignalHandler {
337 pub const fn new() -> Self {
339 Self { signals: Vec::new() }
340 }
341
342 pub fn with_signals<T: Into<SignalKind>>(signals: impl IntoIterator<Item = T>) -> Self {
344 let mut handler = Self::new();
345
346 for signal in signals {
347 handler = handler.with_signal(signal.into());
348 }
349
350 handler
351 }
352
353 pub fn with_signal(mut self, kind: impl Into<SignalKind>) -> Self {
357 self.add_signal(kind);
358 self
359 }
360
361 pub fn add_signal(&mut self, kind: impl Into<SignalKind>) -> &mut Self {
365 let kind = kind.into();
366 if self.signals.iter().any(|(k, _)| k == &kind) {
367 return self;
368 }
369
370 let signal = kind.listen().expect("failed to create signal");
371
372 self.signals.push((kind, signal));
373
374 self
375 }
376
377 pub async fn recv(&mut self) -> SignalKind {
381 self.await
382 }
383
384 pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<SignalKind> {
387 for (kind, signal) in self.signals.iter_mut() {
388 if signal.poll_recv(cx).is_ready() {
389 return Poll::Ready(*kind);
390 }
391 }
392
393 Poll::Pending
394 }
395}
396
397impl std::future::Future for SignalHandler {
398 type Output = SignalKind;
399
400 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
401 self.poll_recv(cx)
402 }
403}
404
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod test {
    use std::time::Duration;

    use scuffle_future_ext::FutureExt;

    use crate::{SignalHandler, SignalKind};

    /// Source of fake Windows console events. In test builds, `SignalKind::listen` subscribes
    /// to this instead of creating real tokio listeners.
    #[cfg(windows)]
    pub(crate) struct SignalMocker(tokio::sync::broadcast::Sender<SignalKind>);

    #[cfg(windows)]
    impl SignalMocker {
        /// Creates a mocker backed by a broadcast channel with room for 100 pending events.
        fn new() -> Self {
            let (sender, _) = tokio::sync::broadcast::channel(100);
            Self(sender)
        }

        /// Broadcasts `kind` to every mock listener subscribed on this thread.
        ///
        /// Panics if nothing is subscribed, because the broadcast `send` then fails.
        fn raise(kind: SignalKind) {
            SIGNAL_MOCKER.with(|local| local.0.send(kind).unwrap());
        }

        /// Returns a new receiver for events raised on this thread's mocker.
        pub(crate) fn subscribe() -> tokio::sync::broadcast::Receiver<SignalKind> {
            SIGNAL_MOCKER.with(|local| local.0.subscribe())
        }
    }

    #[cfg(windows)]
    thread_local! {
        // One mocker per thread: a test sees the events it raises itself, assuming its
        // listeners and `raise_signal` calls run on the same thread.
        static SIGNAL_MOCKER: SignalMocker = SignalMocker::new();
    }

    /// Raises a mock console event for `kind`.
    #[cfg(windows)]
    pub(crate) async fn raise_signal(kind: SignalKind) {
        SignalMocker::raise(kind);
    }

    /// Sends the Unix signal `kind` stands for to the current process.
    ///
    /// Callers should register a listener for the signal first. Otherwise the signal's default
    /// disposition (for example, terminating the process) applies.
    #[cfg(unix)]
    pub(crate) async fn raise_signal(kind: SignalKind) {
        let signal = match kind {
            SignalKind::Interrupt => libc::SIGINT,
            SignalKind::Terminate => libc::SIGTERM,
            SignalKind::Unix(kind) => kind.as_raw_value(),
        };

        // SAFETY: `raise` takes only a plain signal number; it has no pointer arguments and no
        // memory-safety preconditions. An unhandled signal may end the process, but that is a
        // test failure, not undefined behaviour.
        let ret = unsafe { libc::raise(signal) };
        assert_eq!(ret, 0, "failed to raise signal {signal}");
    }

    #[cfg(windows)]
    #[tokio::test]
    async fn signal_handler() {
        use crate::WindowsSignalKind;

        let mut handler = SignalHandler::with_signals([WindowsSignalKind::CtrlC, WindowsSignalKind::CtrlBreak]);

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        raise_signal(SignalKind::Windows(WindowsSignalKind::CtrlC)).await;

        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, WindowsSignalKind::CtrlC, "expected CtrlC");

        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await;
        assert!(recv.is_err(), "expected timeout");

        raise_signal(SignalKind::Windows(WindowsSignalKind::CtrlBreak)).await;

        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, WindowsSignalKind::CtrlBreak, "expected CtrlBreak");
    }

    #[cfg(windows)]
    #[tokio::test]
    async fn add_signal() {
        use crate::WindowsSignalKind;

        let mut handler = SignalHandler::new();

        handler
            .add_signal(WindowsSignalKind::CtrlC)
            .add_signal(WindowsSignalKind::CtrlBreak)
            .add_signal(WindowsSignalKind::CtrlC);

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        raise_signal(SignalKind::Windows(WindowsSignalKind::CtrlC)).await;

        let recv = handler.recv().with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, WindowsSignalKind::CtrlC, "expected CtrlC");

        raise_signal(SignalKind::Windows(WindowsSignalKind::CtrlBreak)).await;

        let recv = handler.recv().with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, WindowsSignalKind::CtrlBreak, "expected CtrlBreak");
    }

    #[cfg(all(not(valgrind), unix))]
    #[tokio::test]
    async fn signal_handler() {
        use crate::UnixSignalKind;

        let mut handler = SignalHandler::with_signals([UnixSignalKind::user_defined1()])
            .with_signal(UnixSignalKind::user_defined2())
            .with_signal(UnixSignalKind::user_defined1());

        raise_signal(SignalKind::Unix(UnixSignalKind::user_defined1())).await;

        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, SignalKind::Unix(UnixSignalKind::user_defined1()), "expected SIGUSR1");

        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await;

        assert!(recv.is_err(), "expected timeout");

        raise_signal(SignalKind::Unix(UnixSignalKind::user_defined2())).await;

        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, UnixSignalKind::user_defined2(), "expected SIGUSR2");
    }

    #[cfg(all(not(valgrind), unix))]
    #[tokio::test]
    async fn add_signal() {
        use crate::UnixSignalKind;

        let mut handler = SignalHandler::new();

        handler
            .add_signal(UnixSignalKind::user_defined1())
            .add_signal(UnixSignalKind::user_defined2())
            .add_signal(UnixSignalKind::user_defined2());

        raise_signal(SignalKind::Unix(UnixSignalKind::user_defined1())).await;

        let recv = handler.recv().with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, UnixSignalKind::user_defined1(), "expected SIGUSR1");

        raise_signal(SignalKind::Unix(UnixSignalKind::user_defined2())).await;

        let recv = handler.recv().with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, UnixSignalKind::user_defined2(), "expected SIGUSR2");
    }

    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn no_signals() {
        let mut handler = SignalHandler::default();

        assert!(handler.recv().with_timeout(Duration::from_millis(500)).await.is_err());
    }

    #[cfg(windows)]
    #[test]
    fn signal_kind_eq() {
        use crate::WindowsSignalKind;

        assert_eq!(SignalKind::Interrupt, SignalKind::Windows(WindowsSignalKind::CtrlC));
        assert_eq!(SignalKind::Terminate, SignalKind::Windows(WindowsSignalKind::CtrlClose));
        assert_eq!(SignalKind::Windows(WindowsSignalKind::CtrlC), SignalKind::Interrupt);
        assert_eq!(SignalKind::Windows(WindowsSignalKind::CtrlClose), SignalKind::Terminate);
        assert_ne!(SignalKind::Interrupt, SignalKind::Terminate);
        assert_eq!(
            SignalKind::Windows(WindowsSignalKind::CtrlBreak),
            SignalKind::Windows(WindowsSignalKind::CtrlBreak)
        );
    }

    #[cfg(unix)]
    #[test]
    fn signal_kind_eq() {
        use crate::UnixSignalKind;

        assert_eq!(SignalKind::Interrupt, SignalKind::Unix(UnixSignalKind::interrupt()));
        assert_eq!(SignalKind::Terminate, SignalKind::Unix(UnixSignalKind::terminate()));
        assert_eq!(SignalKind::Unix(UnixSignalKind::interrupt()), SignalKind::Interrupt);
        assert_eq!(SignalKind::Unix(UnixSignalKind::terminate()), SignalKind::Terminate);
        assert_ne!(SignalKind::Interrupt, SignalKind::Terminate);
        assert_eq!(
            SignalKind::Unix(UnixSignalKind::user_defined1()),
            SignalKind::Unix(UnixSignalKind::user_defined1())
        );
    }
}