1use std::sync::{Arc, Mutex};
2
3use bytes::{Buf, BytesMut};
4
5#[derive(Debug, Clone)]
9pub struct ChannelCompat<T: Send> {
10 inner: Arc<Mutex<T>>,
14 buffer: BytesMut,
15}
16
17impl<T: Send> ChannelCompat<T> {
18 pub fn new(inner: T) -> Self {
20 Self {
21 inner: Arc::new(Mutex::new(inner)),
22 buffer: BytesMut::new(),
23 }
24 }
25}
26
27pub trait ChannelCompatRecv: Send {
29 type Data: AsRef<[u8]>;
31
32 fn channel_recv(&mut self) -> Option<Self::Data>;
34
35 fn try_channel_recv(&mut self) -> Option<Self::Data>;
37
38 fn into_compat(self) -> ChannelCompat<Self>
40 where
41 Self: Sized,
42 {
43 ChannelCompat::new(self)
44 }
45}
46
47pub trait ChannelCompatSend: Send {
49 type Data: From<Vec<u8>>;
51
52 fn channel_send(&mut self, data: Self::Data) -> bool;
54
55 fn into_compat(self) -> ChannelCompat<Self>
57 where
58 Self: Sized,
59 {
60 ChannelCompat::new(self)
61 }
62}
63
// Bounded tokio mpsc receiver used as a source of byte chunks.
#[cfg(feature = "tokio-channel")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-channel")))]
impl<D> ChannelCompatRecv for tokio::sync::mpsc::Receiver<D>
where
    D: AsRef<[u8]> + Send,
{
    type Data = D;

    /// Delegates to the receiver's `blocking_recv`.
    fn channel_recv(&mut self) -> Option<Self::Data> {
        self.blocking_recv()
    }

    /// Delegates to `try_recv`; any error is reported as `None`.
    fn try_channel_recv(&mut self) -> Option<Self::Data> {
        let attempt = self.try_recv();
        attempt.ok()
    }
}
77
// Bounded tokio mpsc sender used as a sink for byte chunks.
#[cfg(feature = "tokio-channel")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-channel")))]
impl<D> ChannelCompatSend for tokio::sync::mpsc::Sender<D>
where
    D: From<Vec<u8>> + Send,
{
    type Data = D;

    /// Delegates to `blocking_send`; returns whether the send succeeded.
    fn channel_send(&mut self, data: Self::Data) -> bool {
        let outcome = self.blocking_send(data);
        outcome.is_ok()
    }
}
87
// Unbounded tokio mpsc receiver used as a source of byte chunks.
#[cfg(feature = "tokio-channel")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-channel")))]
impl<D> ChannelCompatRecv for tokio::sync::mpsc::UnboundedReceiver<D>
where
    D: AsRef<[u8]> + Send,
{
    type Data = D;

    /// Delegates to the receiver's `blocking_recv`.
    fn channel_recv(&mut self) -> Option<Self::Data> {
        self.blocking_recv()
    }

    /// Delegates to `try_recv`; any error is reported as `None`.
    fn try_channel_recv(&mut self) -> Option<Self::Data> {
        let attempt = self.try_recv();
        attempt.ok()
    }
}
101
// Unbounded tokio mpsc sender used as a sink for byte chunks.
#[cfg(feature = "tokio-channel")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-channel")))]
impl<D> ChannelCompatSend for tokio::sync::mpsc::UnboundedSender<D>
where
    D: From<Vec<u8>> + Send,
{
    type Data = D;

    /// Delegates to `send`; returns whether the send succeeded.
    fn channel_send(&mut self, data: Self::Data) -> bool {
        let outcome = self.send(data);
        outcome.is_ok()
    }
}
111
// Tokio broadcast receiver used as a source of byte chunks.
//
// A broadcast receiver can report `Lagged(n)`, meaning it fell behind and `n`
// messages were dropped for it. That is not end-of-stream: collapsing it to
// `None` would make `ChannelCompat`'s `Read` impl return `Ok(0)` (EOF) while
// the channel is still open. Both methods therefore retry after a lag and only
// return `None` for the remaining error cases. The dropped messages are
// unrecoverable, so the resulting byte stream has a gap.
#[cfg(feature = "tokio-channel")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-channel")))]
impl<D: AsRef<[u8]> + Clone + Send> ChannelCompatRecv for tokio::sync::broadcast::Receiver<D> {
    type Data = D;

    /// Receives the next message via `blocking_recv`, retrying when the
    /// receiver reports that it lagged behind.
    ///
    /// Returns `None` for any other receive error.
    fn channel_recv(&mut self) -> Option<Self::Data> {
        use tokio::sync::broadcast::error::RecvError;

        loop {
            match self.blocking_recv() {
                Ok(data) => return Some(data),
                // Lag is not closure: try again for the next retained message.
                Err(RecvError::Lagged(_)) => continue,
                Err(_) => return None,
            }
        }
    }

    /// Polls for a message via `try_recv`, retrying when the receiver reports
    /// that it lagged behind.
    ///
    /// Returns `None` for any other receive error.
    fn try_channel_recv(&mut self) -> Option<Self::Data> {
        use tokio::sync::broadcast::error::TryRecvError;

        loop {
            match self.try_recv() {
                Ok(data) => return Some(data),
                // Lag is not closure: try again for the next retained message.
                Err(TryRecvError::Lagged(_)) => continue,
                Err(_) => return None,
            }
        }
    }
}
125
// Tokio broadcast sender used as a sink for byte chunks.
#[cfg(feature = "tokio-channel")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-channel")))]
impl<D> ChannelCompatSend for tokio::sync::broadcast::Sender<D>
where
    D: From<Vec<u8>> + Clone + Send,
{
    type Data = D;

    /// Delegates to `send`; returns whether the send succeeded.
    fn channel_send(&mut self, data: Self::Data) -> bool {
        let outcome = self.send(data);
        outcome.is_ok()
    }
}
135
// Crossbeam receiver used as a source of byte chunks.
#[cfg(feature = "crossbeam-channel")]
#[cfg_attr(docsrs, doc(cfg(feature = "crossbeam-channel")))]
impl<D> ChannelCompatRecv for crossbeam_channel::Receiver<D>
where
    D: AsRef<[u8]> + Send,
{
    type Data = D;

    /// Delegates to `recv`; any error is reported as `None`.
    fn channel_recv(&mut self) -> Option<Self::Data> {
        let outcome = self.recv();
        outcome.ok()
    }

    /// Delegates to `try_recv`; any error is reported as `None`.
    fn try_channel_recv(&mut self) -> Option<Self::Data> {
        let attempt = self.try_recv();
        attempt.ok()
    }
}
149
// Crossbeam sender used as a sink for byte chunks.
#[cfg(feature = "crossbeam-channel")]
#[cfg_attr(docsrs, doc(cfg(feature = "crossbeam-channel")))]
impl<D> ChannelCompatSend for crossbeam_channel::Sender<D>
where
    D: From<Vec<u8>> + Send,
{
    type Data = D;

    /// Delegates to `send`; returns whether the send succeeded.
    fn channel_send(&mut self, data: Self::Data) -> bool {
        let outcome = self.send(data);
        outcome.is_ok()
    }
}
159
160impl<D: AsRef<[u8]> + Send> ChannelCompatRecv for std::sync::mpsc::Receiver<D> {
161 type Data = D;
162
163 fn channel_recv(&mut self) -> Option<Self::Data> {
164 self.recv().ok()
165 }
166
167 fn try_channel_recv(&mut self) -> Option<Self::Data> {
168 self.try_recv().ok()
169 }
170}
171
172impl<D: From<Vec<u8>> + Send> ChannelCompatSend for std::sync::mpsc::Sender<D> {
173 type Data = D;
174
175 fn channel_send(&mut self, data: Self::Data) -> bool {
176 self.send(data).is_ok()
177 }
178}
179
180impl<D: From<Vec<u8>> + Send> ChannelCompatSend for std::sync::mpsc::SyncSender<D> {
181 type Data = D;
182
183 fn channel_send(&mut self, data: Self::Data) -> bool {
184 self.send(data).is_ok()
185 }
186}
187
188impl<T: ChannelCompatRecv> std::io::Read for ChannelCompat<T> {
189 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
190 if self.buffer.len() >= buf.len() {
191 buf.copy_from_slice(&self.buffer[..buf.len()]);
192 self.buffer.advance(buf.len());
193 return Ok(buf.len());
194 }
195
196 let mut inner = self.inner.lock().unwrap();
197
198 let mut total_read = 0;
199 if self.buffer.is_empty() {
200 let Some(data) = inner.channel_recv() else {
201 return Ok(0);
202 };
203
204 let data = data.as_ref();
205 let min = data.len().min(buf.len());
206
207 buf.copy_from_slice(&data[..min]);
208 self.buffer.extend_from_slice(&data[min..]);
209 total_read += min;
210 } else {
211 buf[..self.buffer.len()].copy_from_slice(&self.buffer);
212 total_read += self.buffer.len();
213 self.buffer.clear();
214 }
215
216 while let Some(Some(data)) = (total_read < buf.len()).then(|| inner.try_channel_recv()) {
217 let data = data.as_ref();
218 let min = data.len().min(buf.len() - total_read);
219 buf[total_read..total_read + min].copy_from_slice(&data[..min]);
220 self.buffer.extend_from_slice(&data[min..]);
221 total_read += min;
222 }
223
224 Ok(total_read)
225 }
226}
227
228impl<T: ChannelCompatSend> std::io::Write for ChannelCompat<T> {
229 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
230 if !self.inner.lock().unwrap().channel_send(buf.to_vec().into()) {
231 return Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "Unexpected EOF"));
232 }
233
234 Ok(buf.len())
235 }
236
237 fn flush(&mut self) -> std::io::Result<()> {
238 Ok(())
239 }
240}
241
// Unit tests for `ChannelCompat`. Each scenario is written once and stamped
// out for every listed channel flavour by the `make_test!` macro below;
// variants carrying a `cfg(...)` are only compiled when that cfg holds.
#[cfg(test)]
#[cfg_attr(all(test, coverage_nightly), coverage(off))]
mod tests {
    use std::io::{Read, Write};

    use rand::Rng;
    use rand::distr::StandardUniform;

    use crate::io::channel::{ChannelCompat, ChannelCompatRecv, ChannelCompatSend};

    // Generates one `#[test]` function per `#[variant(name, channel_expr[, cfg(..)])]`
    // line. Each generated function evaluates `channel_expr`, binds the resulting
    // pair to the two identifiers named in `|tx, rx|`, and then runs the shared
    // body. An optional `cfg(...)` on a variant is applied to that test function.
    macro_rules! make_test {
        (
            $(
                $(
                    #[variant($name:ident, $channel:expr$(, cfg($($cfg_meta:meta)*))?)]
                )*
                |$tx:ident, $rx:ident| $body:block
            )*
        ) => {
            $(
                $(
                    #[test]
                    $(#[cfg($($cfg_meta)*)])?
                    fn $name() {
                        let ($tx, $rx) = $channel;
                        $body
                    }
                )*
            )*
        };
    }

    // A single message read back in one call with an exactly-sized buffer.
    make_test! {
        #[variant(
            test_read_std_mpsc,
            std::sync::mpsc::channel::<Vec<u8>>()
        )]
        #[variant(
            test_read_std_sync_mpsc,
            std::sync::mpsc::sync_channel::<Vec<u8>>(1)
        )]
        #[variant(
            test_read_tokio_mpsc,
            tokio::sync::mpsc::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_tokio_unbounded,
            tokio::sync::mpsc::unbounded_channel::<Vec<u8>>(),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_tokio_broadcast,
            tokio::sync::broadcast::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_crossbeam_unbounded,
            crossbeam_channel::unbounded::<Vec<u8>>(),
            cfg(feature = "crossbeam-channel")
        )]
        |tx, rx| {
            let mut reader = rx.into_compat();

            // 1000 random bytes sent as one message.
            let mut rng = rand::rng();
            let data: Vec<u8> = (0..1000).map(|_| rng.sample::<u8, _>(StandardUniform)).collect();

            // Rebind mutably: `channel_send` takes `&mut self`.
            let mut tx = tx;
            let write_result = tx.channel_send(data.clone());
            assert!(write_result);

            // The buffer is exactly as large as the message.
            let mut buffer = vec![0u8; 1000];
            let read_result = reader.read(&mut buffer);
            assert!(read_result.is_ok());
            assert_eq!(read_result.unwrap(), data.len());

            // The bytes read must equal the bytes sent.
            assert_eq!(buffer, data);
        }
    }

    // A single `write` call arrives on the raw receiver as one message.
    make_test! {
        #[variant(
            test_write_std_mpsc,
            std::sync::mpsc::channel::<Vec<u8>>()
        )]
        #[variant(
            test_write_std_sync_mpsc,
            std::sync::mpsc::sync_channel::<Vec<u8>>(1)
        )]
        #[variant(
            test_write_tokio_mpsc,
            tokio::sync::mpsc::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_write_tokio_unbounded,
            tokio::sync::mpsc::unbounded_channel::<Vec<u8>>(),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_write_tokio_broadcast,
            tokio::sync::broadcast::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_write_crossbeam_unbounded,
            crossbeam_channel::unbounded::<Vec<u8>>(),
            cfg(feature = "crossbeam-channel")
        )]
        |tx, rx| {
            let mut writer = tx.into_compat();

            // 1000 random bytes written in one call.
            let mut rng = rand::rng();
            let data: Vec<u8> = (0..1000).map(|_| rng.sample::<u8, _>(StandardUniform)).collect();

            let write_result = writer.write(&data);
            assert!(write_result.is_ok(), "Failed to write data to the channel");
            assert_eq!(write_result.unwrap(), data.len(), "Written byte count mismatch");

            // Receive directly from the raw receiver (no adapter on this side).
            let mut rx = rx;
            let read_result = rx.channel_recv();
            assert!(read_result.is_some(), "No data received from the channel");

            let received_data = read_result.unwrap();
            assert_eq!(received_data.len(), data.len(), "Received byte count mismatch");

            // The received message must equal what was written.
            assert_eq!(
                received_data, data,
                "Mismatch between written data and received data"
            );
        }
    }

    // One 15-byte message consumed by two reads (7 bytes, then 8 bytes).
    make_test! {
        #[variant(
            test_read_smaller_buffer_than_data_std_mpsc,
            std::sync::mpsc::channel::<Vec<u8>>()
        )]
        #[variant(
            test_read_smaller_buffer_than_data_std_sync_mpsc,
            std::sync::mpsc::sync_channel::<Vec<u8>>(1)
        )]
        #[variant(
            test_read_smaller_buffer_than_data_tokio_mpsc,
            tokio::sync::mpsc::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_smaller_buffer_than_data_tokio_unbounded,
            tokio::sync::mpsc::unbounded_channel::<Vec<u8>>(),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_smaller_buffer_than_data_tokio_broadcast,
            tokio::sync::broadcast::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_smaller_buffer_than_data_crossbeam_unbounded,
            crossbeam_channel::unbounded::<Vec<u8>>(),
            cfg(feature = "crossbeam-channel")
        )]
        |tx, rx| {
            let mut reader = ChannelCompat::new(rx);
            let data = b"PartialReadTest".to_vec();
            let mut tx = tx;
            let send_result = tx.channel_send(data);
            assert!(send_result);

            // First read: the buffer is smaller than the message, so it is filled
            // completely and the remainder stays inside the adapter.
            let mut buffer = vec![0u8; 7];
            let read_result = reader.read(&mut buffer);
            assert!(read_result.is_ok());
            assert_eq!(read_result.unwrap(), buffer.len());
            assert_eq!(&buffer, b"Partial");

            // Second read: served from the bytes retained by the first read.
            let mut buffer = vec![0u8; 8];
            let read_result = reader.read(&mut buffer);
            assert!(read_result.is_ok());
            assert_eq!(read_result.unwrap(), buffer.len());
            assert_eq!(&buffer, b"ReadTest");
        }
    }

    // Reading after the sending side has been dropped yields `Ok(0)`.
    make_test! {
        #[variant(
            test_read_no_data_std_mpsc,
            std::sync::mpsc::channel::<Vec<u8>>()
        )]
        #[variant(
            test_read_no_data_std_sync_mpsc,
            std::sync::mpsc::sync_channel::<Vec<u8>>(1)
        )]
        #[variant(
            test_read_no_data_tokio_mpsc,
            tokio::sync::mpsc::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_no_data_tokio_unbounded,
            tokio::sync::mpsc::unbounded_channel::<Vec<u8>>(),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_no_data_tokio_broadcast,
            tokio::sync::broadcast::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_no_data_crossbeam_unbounded,
            crossbeam_channel::unbounded::<Vec<u8>>(),
            cfg(feature = "crossbeam-channel")
        )]
        |tx, rx| {
            let mut reader = ChannelCompat::new(rx);

            // Drop the sender without sending anything.
            drop(tx);
            let mut buffer = vec![0u8; 10];
            let read_result = reader.read(&mut buffer);

            assert!(read_result.is_ok());
            assert_eq!(read_result.unwrap(), 0);
        }
    }

    // Exercises the branch of `read` taken when leftover bytes exist but are
    // fewer than the requested amount.
    make_test! {
        #[variant(
            test_read_else_case_std_mpsc,
            std::sync::mpsc::channel::<Vec<u8>>()
        )]
        #[variant(
            test_read_else_case_std_sync_mpsc,
            std::sync::mpsc::sync_channel::<Vec<u8>>(1)
        )]
        #[variant(
            test_read_else_case_tokio_mpsc,
            tokio::sync::mpsc::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_else_case_tokio_unbounded,
            tokio::sync::mpsc::unbounded_channel::<Vec<u8>>(),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_else_case_tokio_broadcast,
            tokio::sync::broadcast::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_else_case_crossbeam_unbounded,
            crossbeam_channel::unbounded::<Vec<u8>>(),
            cfg(feature = "crossbeam-channel")
        )]
        |tx, rx| {
            let mut reader = ChannelCompat::new(rx);
            let mut tx = tx;

            let data1 = b"FirstChunk".to_vec();
            let write_result1 = tx.channel_send(data1);
            assert!(write_result1, "Failed to send data1");

            // Read the first 5 bytes; the other 5 remain buffered in the adapter.
            let mut buffer = vec![0u8; 5];
            let read_result = reader.read(&mut buffer);
            assert!(read_result.is_ok(), "Failed to read the first chunk");
            let bytes_read = read_result.unwrap();
            assert_eq!(bytes_read, buffer.len(), "Mismatch in first chunk read size");
            assert_eq!(&buffer, b"First", "Buffer content mismatch for first part of FirstChunk");

            // Ask for 10 bytes while only 5 are buffered and nothing else has been
            // sent yet.
            let mut buffer = vec![0u8; 10];
            let read_result = reader.read(&mut buffer);
            assert!(read_result.is_ok(), "Failed to read the next 10 bytes");
            let bytes_read = read_result.unwrap();

            // Only the 5 leftover bytes are returned (a short read).
            assert_eq!(bytes_read, 5, "Unexpected read size for the next part");
            assert_eq!(&buffer[..bytes_read], b"Chunk", "Buffer content mismatch for combined reads");

            // Send a second message now that the leftover has been drained.
            let data2 = b"SecondChunk".to_vec();
            let write_result2 = tx.channel_send(data2);
            assert!(write_result2, "Failed to send data2");

            // This read has to obtain its bytes from the second message.
            let mut buffer = vec![0u8; 5];
            let read_result = reader.read(&mut buffer);
            assert!(read_result.is_ok(), "Failed to read leftover data from data2");
            let bytes_read = read_result.unwrap();
            assert!(bytes_read > 0, "No leftover data from data2 was available");
        }
    }

    // Exercises the top-up loop of `read`: leftover bytes are combined with a
    // message that is already queued on the channel.
    make_test! {
        #[variant(
            test_read_while_case_std_mpsc,
            std::sync::mpsc::channel::<Vec<u8>>()
        )]
        #[variant(
            test_read_while_case_std_sync_mpsc,
            std::sync::mpsc::sync_channel::<Vec<u8>>(1)
        )]
        #[variant(
            test_read_while_case_tokio_mpsc,
            tokio::sync::mpsc::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_while_case_tokio_unbounded,
            tokio::sync::mpsc::unbounded_channel::<Vec<u8>>(),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_while_case_tokio_broadcast,
            tokio::sync::broadcast::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_read_while_case_crossbeam_unbounded,
            crossbeam_channel::unbounded::<Vec<u8>>(),
            cfg(feature = "crossbeam-channel")
        )]
        |tx, rx| {
            let mut reader = ChannelCompat::new(rx);
            let mut tx = tx;

            let data1 = b"FirstChunk".to_vec();
            let write_result1 = tx.channel_send(data1);
            assert!(write_result1, "Failed to send data1");

            // Read the first 5 bytes; "Chunk" stays buffered in the adapter.
            let mut buffer = vec![0u8; 5];
            let read_result = reader.read(&mut buffer);
            assert!(read_result.is_ok(), "Failed to read the first chunk");
            let bytes_read = read_result.unwrap();
            assert_eq!(bytes_read, buffer.len(), "Mismatch in first chunk read size");
            assert_eq!(&buffer, b"First", "Buffer content mismatch for first part of FirstChunk");

            // Queue a second message before the leftover is consumed.
            let data2 = b"SecondChunk".to_vec();
            let write_result2 = tx.channel_send(data2);
            assert!(write_result2, "Failed to send data2");

            // 10 bytes requested: 5 leftover bytes plus the first 5 bytes of the
            // queued second message.
            let mut buffer = vec![0u8; 10];
            let read_result = reader.read(&mut buffer);
            assert!(read_result.is_ok(), "Failed to read the next chunk of data");
            let bytes_read = read_result.unwrap();
            assert!(bytes_read > 0, "No data was read");
            assert_eq!(&buffer[..bytes_read], b"ChunkSecon", "Buffer content mismatch");

            // The rest of the second message is served from the adapter's buffer.
            let mut buffer = vec![0u8; 6];
            let read_result = reader.read(&mut buffer);
            assert!(read_result.is_ok(), "Failed to read remaining data");
            let bytes_read = read_result.unwrap();
            assert!(bytes_read > 0, "No additional data was read");
            assert_eq!(&buffer[..bytes_read], b"dChunk", "Buffer content mismatch for remaining data");
        }
    }

    // Writing after the receiving side has been dropped fails with
    // `UnexpectedEof`.
    make_test! {
        #[variant(
            test_write_eof_error_std_mpsc,
            std::sync::mpsc::channel::<Vec<u8>>()
        )]
        #[variant(
            test_write_eof_error_std_sync_mpsc,
            std::sync::mpsc::sync_channel::<Vec<u8>>(1)
        )]
        #[variant(
            test_write_eof_error_tokio_mpsc,
            tokio::sync::mpsc::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_write_eof_error_tokio_unbounded,
            tokio::sync::mpsc::unbounded_channel::<Vec<u8>>(),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_write_eof_error_tokio_broadcast,
            tokio::sync::broadcast::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_write_eof_error_crossbeam_unbounded,
            crossbeam_channel::unbounded::<Vec<u8>>(),
            cfg(feature = "crossbeam-channel")
        )]
        |tx, rx| {
            let mut writer = ChannelCompat::new(tx);

            // Drop the receiver so the send performed by `write` fails.
            drop(rx);

            let data = vec![42u8; 100];
            let write_result = writer.write(&data);
            assert!(write_result.is_err());
            assert_eq!(write_result.unwrap_err().kind(), std::io::ErrorKind::UnexpectedEof);
        }
    }

    // `flush` succeeds; the receiver is bound as `_rx` and otherwise unused.
    make_test! {
        #[variant(
            test_flush_std_mpsc,
            std::sync::mpsc::channel::<Vec<u8>>()
        )]
        #[variant(
            test_flush_std_sync_mpsc,
            std::sync::mpsc::sync_channel::<Vec<u8>>(1)
        )]
        #[variant(
            test_flush_tokio_mpsc,
            tokio::sync::mpsc::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_flush_tokio_unbounded,
            tokio::sync::mpsc::unbounded_channel::<Vec<u8>>(),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_flush_tokio_broadcast,
            tokio::sync::broadcast::channel::<Vec<u8>>(1),
            cfg(feature = "tokio-channel")
        )]
        #[variant(
            test_flush_crossbeam_unbounded,
            crossbeam_channel::unbounded::<Vec<u8>>(),
            cfg(feature = "crossbeam-channel")
        )]
        |tx, _rx| {
            let mut writer = ChannelCompat::new(tx);

            let flush_result = writer.flush();
            assert!(flush_result.is_ok());
        }
    }
}