scuffle_context/
lib.rs

//! A crate designed to provide the ability to cancel futures using a Go-like
//! context approach, allowing for graceful shutdowns and cancellations.
3//!
4//! ## Why do we need this?
5//!
//! It's often useful to wait for all the futures to shut down or to cancel them
//! when we no longer care about the results. This crate provides an interface
//! to cancel all futures associated with a context or wait for them to finish
//! before shutting down, allowing for graceful shutdowns and cancellations.
10//!
11//! ## Usage
12//!
13//! Here is an example of how to use the `Context` to cancel a spawned task.
14//!
15//! ```rust
16//! # use scuffle_context::{Context, ContextFutExt};
17//! # tokio_test::block_on(async {
18//! let (ctx, handler) = Context::new();
19//!
20//! tokio::spawn(async {
21//!     // Do some work
22//!     tokio::time::sleep(std::time::Duration::from_secs(10)).await;
23//! }.with_context(ctx));
24//!
25//! // Will stop the spawned task and cancel all associated futures.
26//! handler.cancel();
27//! # });
28//! ```
29//!
30//! ## License
31//!
32//! This project is licensed under the [MIT](./LICENSE.MIT) or
33//! [Apache-2.0](./LICENSE.Apache-2.0) license. You can choose between one of
34//! them if you use this work.
35//!
36//! `SPDX-License-Identifier: MIT OR Apache-2.0`
37#![cfg_attr(all(coverage_nightly, test), feature(coverage_attribute))]
38#![deny(missing_docs)]
39#![deny(unsafe_code)]
40#![deny(unreachable_pub)]
41
42use std::sync::Arc;
43use std::sync::atomic::{AtomicBool, AtomicUsize};
44
45use tokio_util::sync::CancellationToken;
46
47/// For extending types.
48mod ext;
49
50pub use ext::*;
51
52/// Create by calling [`ContextTrackerInner::child`].
53#[derive(Debug)]
54struct ContextTracker(Arc<ContextTrackerInner>);
55
56impl Drop for ContextTracker {
57    fn drop(&mut self) {
58        let prev_active_count = self.0.active_count.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
59        // If this was the last active `ContextTracker` and the context has been
60        // stopped, then notify the waiters
61        if prev_active_count == 1 && self.0.stopped.load(std::sync::atomic::Ordering::Relaxed) {
62            self.0.notify.notify_waiters();
63        }
64    }
65}
66
67#[derive(Debug)]
68struct ContextTrackerInner {
69    stopped: AtomicBool,
70    /// This count keeps track of the number of `ContextTrackers` that exist for
71    /// this `ContextTrackerInner`.
72    active_count: AtomicUsize,
73    notify: tokio::sync::Notify,
74}
75
76impl ContextTrackerInner {
77    fn new() -> Arc<Self> {
78        Arc::new(Self {
79            stopped: AtomicBool::new(false),
80            active_count: AtomicUsize::new(0),
81            notify: tokio::sync::Notify::new(),
82        })
83    }
84
85    /// Create a new `ContextTracker` from an `Arc<ContextTrackerInner>`.
86    fn child(self: &Arc<Self>) -> ContextTracker {
87        self.active_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
88        ContextTracker(Arc::clone(self))
89    }
90
91    /// Mark this `ContextTrackerInner` as stopped.
92    fn stop(&self) {
93        self.stopped.store(true, std::sync::atomic::Ordering::Relaxed);
94    }
95
96    /// Wait for this `ContextTrackerInner` to be stopped and all associated
97    /// `ContextTracker`s to be dropped.
98    async fn wait(&self) {
99        let notify = self.notify.notified();
100
101        // If there are no active children, then the notify will never be called
102        if self.active_count.load(std::sync::atomic::Ordering::Relaxed) == 0 {
103            return;
104        }
105
106        notify.await;
107    }
108}
109
/// A context for cancelling futures and waiting for shutdown.
///
/// A context can be created from a handler by calling [`Handler::context`] or
/// from another context by calling [`Context::new_child`], which makes it
/// possible to build a hierarchy of contexts.
///
/// Contexts can then be attached to futures or streams in order to
/// automatically cancel them when the context is done, which happens when
/// [`Handler::cancel`] is invoked.
/// The [`Handler::shutdown`] method cancels the handler and then waits for its
/// contexts, allowing for a graceful shutdown.
#[derive(Debug)]
pub struct Context {
    /// Cancellation token observed by [`Context::done`] and
    /// [`Context::is_done`].
    token: CancellationToken,
    /// Registration of this context with its handler's tracker; released when
    /// the context is dropped.
    tracker: ContextTracker,
}
126
127impl Clone for Context {
128    fn clone(&self) -> Self {
129        Self {
130            token: self.token.clone(),
131            tracker: self.tracker.0.child(),
132        }
133    }
134}
135
136impl Context {
137    #[must_use]
138    /// Create a new context using the global handler.
139    /// Returns a child context and child handler of the global handler.
140    pub fn new() -> (Self, Handler) {
141        Handler::global().new_child()
142    }
143
144    #[must_use]
145    /// Create a new child context from this context.
146    /// Returns a new child context and child handler of this context.
147    ///
148    /// # Example
149    ///
150    /// ```rust
151    /// use scuffle_context::Context;
152    ///
153    /// let (parent, parent_handler) = Context::new();
154    /// let (child, child_handler) = parent.new_child();
155    /// ```
156    pub fn new_child(&self) -> (Self, Handler) {
157        let token = self.token.child_token();
158        let tracker = ContextTrackerInner::new();
159
160        (
161            Self {
162                tracker: tracker.child(),
163                token: token.clone(),
164            },
165            Handler {
166                token: Arc::new(TokenDropGuard(token)),
167                tracker,
168            },
169        )
170    }
171
172    #[must_use]
173    /// Returns the global context
174    pub fn global() -> Self {
175        Handler::global().context()
176    }
177
178    /// Wait for the context to be done (the handler to be shutdown).
179    pub async fn done(&self) {
180        self.token.cancelled().await;
181    }
182
183    /// The same as [`Context::done`] but takes ownership of the context.
184    pub async fn into_done(self) {
185        self.done().await;
186    }
187
188    /// Returns true if the context is done.
189    #[must_use]
190    pub fn is_done(&self) -> bool {
191        self.token.is_cancelled()
192    }
193}
194
195/// A wrapper type around [`CancellationToken`] that will cancel the token as
196/// soon as it is dropped.
197#[derive(Debug)]
198struct TokenDropGuard(CancellationToken);
199
200impl TokenDropGuard {
201    #[must_use]
202    fn child(&self) -> CancellationToken {
203        self.0.child_token()
204    }
205
206    fn cancel(&self) {
207        self.0.cancel();
208    }
209}
210
211impl Drop for TokenDropGuard {
212    fn drop(&mut self) {
213        self.cancel();
214    }
215}
216
/// A handler is used to manage contexts and to cancel them.
///
/// Cloning a handler is cheap: both fields are reference-counted and the
/// clones share the same token and tracker.
#[derive(Debug, Clone)]
pub struct Handler {
    /// Shared guard around the handler's cancellation token; the guard cancels
    /// the token when it is dropped.
    token: Arc<TokenDropGuard>,
    /// Shared tracker counting the contexts created from this handler.
    tracker: Arc<ContextTrackerInner>,
}
223
224impl Default for Handler {
225    fn default() -> Self {
226        Self::new()
227    }
228}
229
230impl Handler {
231    #[must_use]
232    /// Create a new handler.
233    pub fn new() -> Handler {
234        let token = CancellationToken::new();
235        let tracker = ContextTrackerInner::new();
236
237        Handler {
238            token: Arc::new(TokenDropGuard(token)),
239            tracker,
240        }
241    }
242
243    #[must_use]
244    /// Returns the global handler.
245    pub fn global() -> &'static Self {
246        static GLOBAL: std::sync::OnceLock<Handler> = std::sync::OnceLock::new();
247
248        GLOBAL.get_or_init(Handler::new)
249    }
250
251    /// Shutdown the handler and wait for all contexts to be done.
252    pub async fn shutdown(&self) {
253        self.cancel();
254        self.done().await;
255    }
256
257    /// Waits for the handler to be done (waiting for all contexts to be done).
258    pub async fn done(&self) {
259        self.token.0.cancelled().await;
260        self.wait().await;
261    }
262
263    /// Waits for the handler to be done (waiting for all contexts to be done).
264    /// Returns once all contexts are done, even if the handler is not done and
265    /// contexts can be created after this call.
266    pub async fn wait(&self) {
267        self.tracker.wait().await;
268    }
269
270    #[must_use]
271    /// Create a new context from this handler.
272    pub fn context(&self) -> Context {
273        Context {
274            token: self.token.child(),
275            tracker: self.tracker.child(),
276        }
277    }
278
279    #[must_use]
280    /// Create a new child context from this handler
281    pub fn new_child(&self) -> (Context, Handler) {
282        self.context().new_child()
283    }
284
285    /// Cancel the handler.
286    pub fn cancel(&self) {
287        self.tracker.stop();
288        self.token.cancel();
289    }
290
291    /// Returns true if the handler is done.
292    pub fn is_done(&self) -> bool {
293        self.token.0.is_cancelled()
294    }
295}
296
// Unit tests for `Context` / `Handler` creation, cancellation and shutdown.
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
#[cfg(test)]
mod tests {
    use scuffle_future_ext::FutureExt;

    use crate::{Context, Handler};

    // Freshly created contexts and handlers are not done.
    #[tokio::test]
    async fn new() {
        let (ctx, handler) = Context::new();
        assert!(!handler.is_done());
        assert!(!ctx.is_done());

        let handler = Handler::default();
        assert!(!handler.is_done());
    }

    // Cancelling a handler marks its context, clones of that context, and
    // child contexts/handlers as done.
    #[tokio::test]
    async fn cancel() {
        let (ctx, handler) = Context::new();
        let (child_ctx, child_handler) = ctx.new_child();
        let child_ctx2 = ctx.clone();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        assert!(!child_handler.is_done());
        assert!(!child_ctx.is_done());
        assert!(!child_ctx2.is_done());

        handler.cancel();

        assert!(handler.is_done());
        assert!(ctx.is_done());
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
        assert!(child_ctx2.is_done());
    }

    // Cancelling a child handler leaves the parent handler and context
    // untouched.
    #[tokio::test]
    async fn cancel_child() {
        let (ctx, handler) = Context::new();
        let (child_ctx, child_handler) = ctx.new_child();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        assert!(!child_handler.is_done());
        assert!(!child_ctx.is_done());

        child_handler.cancel();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
    }

    // `shutdown` does not complete while a context is still alive, and
    // completes once that context has been consumed.
    #[tokio::test]
    async fn shutdown() {
        let (ctx, handler) = Context::new();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());

        // This is expected to timeout
        assert!(
            handler
                .shutdown()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_err()
        );
        // The timed-out `shutdown` still cancelled the handler.
        assert!(handler.is_done());
        assert!(ctx.is_done());
        // Consumes (and therefore drops) the only context.
        assert!(
            ctx.into_done()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_ok()
        );

        // With no contexts left, every waiting method returns promptly.
        assert!(
            handler
                .shutdown()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_ok()
        );
        assert!(
            handler
                .wait()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_ok()
        );
        assert!(
            handler
                .done()
                .with_timeout(std::time::Duration::from_millis(200))
                .await
                .is_ok()
        );
        assert!(handler.is_done());
    }

    // Cancelling the global handler is visible through every global accessor
    // and through children created afterwards.
    //
    // NOTE(review): this cancels the process-wide global handler, which the
    // other tests reach through `Context::new()`. Presumably the suite relies
    // on each test running in its own process; confirm, otherwise the
    // `!is_done()` assertions above depend on test execution order.
    #[tokio::test]
    async fn global_handler() {
        let handler = Handler::global();

        assert!(!handler.is_done());

        handler.cancel();

        assert!(handler.is_done());
        assert!(Handler::global().is_done());
        assert!(Context::global().is_done());

        let (child_ctx, child_handler) = Handler::global().new_child();
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
    }
}
417}