1#![cfg_attr(all(coverage_nightly, test), feature(coverage_attribute))]
38#![deny(missing_docs)]
39#![deny(unsafe_code)]
40#![deny(unreachable_pub)]
41
42use std::sync::Arc;
43use std::sync::atomic::{AtomicBool, AtomicUsize};
44
45use tokio_util::sync::CancellationToken;
46
47mod ext;
49
50pub use ext::*;
51
52#[derive(Debug)]
54struct ContextTracker(Arc<ContextTrackerInner>);
55
56impl Drop for ContextTracker {
57 fn drop(&mut self) {
58 let prev_active_count = self.0.active_count.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
59 if prev_active_count == 1 && self.0.stopped.load(std::sync::atomic::Ordering::Relaxed) {
62 self.0.notify.notify_waiters();
63 }
64 }
65}
66
67#[derive(Debug)]
68struct ContextTrackerInner {
69 stopped: AtomicBool,
70 active_count: AtomicUsize,
73 notify: tokio::sync::Notify,
74}
75
76impl ContextTrackerInner {
77 fn new() -> Arc<Self> {
78 Arc::new(Self {
79 stopped: AtomicBool::new(false),
80 active_count: AtomicUsize::new(0),
81 notify: tokio::sync::Notify::new(),
82 })
83 }
84
85 fn child(self: &Arc<Self>) -> ContextTracker {
87 self.active_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
88 ContextTracker(Arc::clone(self))
89 }
90
91 fn stop(&self) {
93 self.stopped.store(true, std::sync::atomic::Ordering::Relaxed);
94 }
95
96 async fn wait(&self) {
99 let notify = self.notify.notified();
100
101 if self.active_count.load(std::sync::atomic::Ordering::Relaxed) == 0 {
103 return;
104 }
105
106 notify.await;
107 }
108}
109
110#[derive(Debug)]
122pub struct Context {
123 token: CancellationToken,
124 tracker: ContextTracker,
125}
126
127impl Clone for Context {
128 fn clone(&self) -> Self {
129 Self {
130 token: self.token.clone(),
131 tracker: self.tracker.0.child(),
132 }
133 }
134}
135
136impl Context {
137 #[must_use]
138 pub fn new() -> (Self, Handler) {
141 Handler::global().new_child()
142 }
143
144 #[must_use]
145 pub fn new_child(&self) -> (Self, Handler) {
157 let token = self.token.child_token();
158 let tracker = ContextTrackerInner::new();
159
160 (
161 Self {
162 tracker: tracker.child(),
163 token: token.clone(),
164 },
165 Handler {
166 token: Arc::new(TokenDropGuard(token)),
167 tracker,
168 },
169 )
170 }
171
172 #[must_use]
173 pub fn global() -> Self {
175 Handler::global().context()
176 }
177
178 pub async fn done(&self) {
180 self.token.cancelled().await;
181 }
182
183 pub async fn into_done(self) {
185 self.done().await;
186 }
187
188 #[must_use]
190 pub fn is_done(&self) -> bool {
191 self.token.is_cancelled()
192 }
193}
194
195#[derive(Debug)]
198struct TokenDropGuard(CancellationToken);
199
200impl TokenDropGuard {
201 #[must_use]
202 fn child(&self) -> CancellationToken {
203 self.0.child_token()
204 }
205
206 fn cancel(&self) {
207 self.0.cancel();
208 }
209}
210
211impl Drop for TokenDropGuard {
212 fn drop(&mut self) {
213 self.cancel();
214 }
215}
216
217#[derive(Debug, Clone)]
219pub struct Handler {
220 token: Arc<TokenDropGuard>,
221 tracker: Arc<ContextTrackerInner>,
222}
223
224impl Default for Handler {
225 fn default() -> Self {
226 Self::new()
227 }
228}
229
230impl Handler {
231 #[must_use]
232 pub fn new() -> Handler {
234 let token = CancellationToken::new();
235 let tracker = ContextTrackerInner::new();
236
237 Handler {
238 token: Arc::new(TokenDropGuard(token)),
239 tracker,
240 }
241 }
242
243 #[must_use]
244 pub fn global() -> &'static Self {
246 static GLOBAL: std::sync::OnceLock<Handler> = std::sync::OnceLock::new();
247
248 GLOBAL.get_or_init(Handler::new)
249 }
250
251 pub async fn shutdown(&self) {
253 self.cancel();
254 self.done().await;
255 }
256
257 pub async fn done(&self) {
259 self.token.0.cancelled().await;
260 self.wait().await;
261 }
262
263 pub async fn wait(&self) {
267 self.tracker.wait().await;
268 }
269
270 #[must_use]
271 pub fn context(&self) -> Context {
273 Context {
274 token: self.token.child(),
275 tracker: self.tracker.child(),
276 }
277 }
278
279 #[must_use]
280 pub fn new_child(&self) -> (Context, Handler) {
282 self.context().new_child()
283 }
284
285 pub fn cancel(&self) {
287 self.tracker.stop();
288 self.token.cancel();
289 }
290
291 pub fn is_done(&self) -> bool {
293 self.token.0.is_cancelled()
294 }
295}
296
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
#[cfg(test)]
mod tests {
    //! Unit tests for `Context` / `Handler`.
    //!
    //! Only `global_handler` touches the process-wide global handler (directly
    //! or via `Context::new` / `Context::global`). It cancels that handler, and
    //! the state is shared by every test running in this binary. Any other test
    //! deriving from the global could therefore see it cancelled and fail
    //! nondeterministically under a multi-threaded runner. The remaining tests
    //! use private root handlers instead.
    //!
    //! NOTE(review): tests elsewhere in this crate (e.g. in `ext`) that use the
    //! global handler can still race with `global_handler` — confirm.

    use scuffle_future_ext::FutureExt;

    use crate::{Context, Handler};

    /// Timeout for the "does this future resolve?" checks below.
    const TIMEOUT: std::time::Duration = std::time::Duration::from_millis(200);

    #[tokio::test]
    async fn new() {
        let handler = Handler::new();
        let ctx = handler.context();
        assert!(!handler.is_done());
        assert!(!ctx.is_done());

        let handler = Handler::default();
        assert!(!handler.is_done());
    }

    #[tokio::test]
    async fn cancel() {
        let handler = Handler::new();
        let ctx = handler.context();
        let (child_ctx, child_handler) = ctx.new_child();
        let child_ctx2 = ctx.clone();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        assert!(!child_handler.is_done());
        assert!(!child_ctx.is_done());
        assert!(!child_ctx2.is_done());

        handler.cancel();

        assert!(handler.is_done());
        assert!(ctx.is_done());
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
        assert!(child_ctx2.is_done());
    }

    #[tokio::test]
    async fn cancel_child() {
        let handler = Handler::new();
        let ctx = handler.context();
        let (child_ctx, child_handler) = ctx.new_child();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        assert!(!child_handler.is_done());
        assert!(!child_ctx.is_done());

        child_handler.cancel();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
    }

    #[tokio::test]
    async fn shutdown() {
        let handler = Handler::new();
        let ctx = handler.context();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());

        // A live context keeps shutdown from completing.
        assert!(handler.shutdown().with_timeout(TIMEOUT).await.is_err());
        assert!(handler.is_done());
        assert!(ctx.is_done());
        assert!(ctx.into_done().with_timeout(TIMEOUT).await.is_ok());

        // With the last context dropped, every wait resolves.
        assert!(handler.shutdown().with_timeout(TIMEOUT).await.is_ok());
        assert!(handler.wait().with_timeout(TIMEOUT).await.is_ok());
        assert!(handler.done().with_timeout(TIMEOUT).await.is_ok());
        assert!(handler.is_done());
    }

    #[tokio::test]
    async fn global_handler() {
        // `Context::new` derives from the global handler, so it is covered here
        // (before the global is cancelled) rather than in `new`.
        let (ctx, child) = Context::new();
        assert!(!child.is_done());
        assert!(!ctx.is_done());

        let handler = Handler::global();

        assert!(!handler.is_done());

        handler.cancel();

        assert!(handler.is_done());
        assert!(Handler::global().is_done());
        assert!(Context::global().is_done());
        assert!(child.is_done());
        assert!(ctx.is_done());

        let (child_ctx, child_handler) = Handler::global().new_child();
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
    }
}