1use crate::error::{Error, Result};
6use crate::protocol;
7
8use std::collections::VecDeque;
9use std::sync::{Arc, Mutex as SyncMutex};
10use std::task::{ready, Context, Poll, Waker};
11
/// Granularity, in bytes, used when trimming the read buffer: after a read, the target
/// capacity is the buffered length rounded up to a multiple of this value (and never less
/// than one unit of it). 1 MiB.
const BUFFER_TRIM_GRANULARITY: usize = 1 << 20;

// `next_multiple_of`-based rounding is only meaningful for a sane unit; enforce a power of two
// at compile time.
const _: () = assert!(BUFFER_TRIM_GRANULARITY.is_power_of_two());
21
/// Lifecycle of a stream: still open, or closed with an optional human-readable reason.
#[derive(Debug, Clone)]
enum Status {
    /// Neither end has closed the stream yet.
    Open,
    /// The stream has been closed; the payload is the recorded reason, if any.
    Closed(Option<String>),
}

impl Status {
    /// Returns `true` once the stream has been closed, with or without a reason.
    fn is_closed(&self) -> bool {
        matches!(self, Status::Closed(_))
    }

    /// Returns a copy of the recorded close reason.
    ///
    /// Yields `None` both while the stream is open and when it was closed without a reason.
    fn reason(&self) -> Option<String> {
        if let Status::Closed(why) = self {
            why.clone()
        } else {
            None
        }
    }

    /// Marks the stream closed without a reason.
    ///
    /// A status that is already closed keeps its existing reason untouched.
    fn close(&mut self) {
        if matches!(self, Status::Open) {
            *self = Status::Closed(None);
        }
    }
}
52
/// State shared (behind a mutex) by a connected `Reader`/`Writer` pair.
#[derive(Debug)]
struct State {
    /// Byte storage. The first `readable` bytes are unread data; any bytes after that are
    /// spare space that `Writer::write` may fill.
    deque: VecDeque<u8>,
    /// Number of unread bytes at the front of `deque`.
    readable: usize,
    /// Waker of a pending read, paired with the number of readable bytes that read needs
    /// before it is worth waking.
    notify_readable: Option<(Waker, usize)>,
    /// Whether the stream has been closed, and the recorded reason if any.
    closed: Status,
}
72
73pub struct Reader(Arc<SyncMutex<State>>);
75
76impl Reader {
77 pub fn inspect_shutdown(&self) -> String {
79 let lock = self.0.lock().unwrap();
80 if lock.closed.is_closed() {
81 lock.closed.reason().unwrap_or_else(|| "No epitaph".to_owned())
82 } else {
83 "Not closed".to_owned()
84 }
85 }
86
87 pub async fn read<F, U>(&self, mut size: usize, mut f: F) -> Result<U>
118 where
119 F: FnMut(&[u8]) -> Result<(U, usize)>,
120 {
121 let mut f = move |_: &mut Context<'_>, b: &'_ [u8]| Poll::Ready(f(b));
122 futures::future::poll_fn(move |ctx| self.poll_read(ctx, &mut size, &mut f)).await
123 }
124
125 pub fn poll_read<F, U>(
133 &self,
134 ctx: &mut Context<'_>,
135 size: &mut usize,
136 mut f: F,
137 ) -> Poll<Result<U>>
138 where
139 F: FnMut(&mut Context<'_>, &[u8]) -> Poll<Result<(U, usize)>>,
140 {
141 let mut state = self.0.lock().unwrap();
142
143 if let Status::Closed(reason) = &state.closed && *size == 0 {
144 return Poll::Ready(Err(Error::ConnectionClosed(reason.clone())));
145 }
146
147 if state.readable >= *size {
148 let (first, _) = state.deque.as_slices();
149
150 let first = if first.len() >= *size {
151 first
152 } else {
153 state.deque.make_contiguous();
154 state.deque.as_slices().0
155 };
156
157 debug_assert!(first.len() >= *size);
158
159 let first = &first[..std::cmp::min(first.len(), state.readable)];
160 let (ret, consumed) = match ready!(f(ctx, first)) {
161 Err(Error::BufferTooShort(s)) => {
162 if s < first.len() {
163 return Poll::Ready(Err(Error::CallbackRejectedBuffer(s, first.len())));
164 }
165
166 *size = s;
167 ctx.waker().wake_by_ref();
168 return Poll::Pending;
169 }
170 other => other?,
171 };
172
173 if consumed > first.len() {
174 panic!("Read claimed to consume more bytes than it was given!");
175 }
176
177 state.readable -= consumed;
178 state.deque.drain(..consumed);
179 let target_capacity = std::cmp::max(
180 state.deque.len().next_multiple_of(BUFFER_TRIM_GRANULARITY),
181 BUFFER_TRIM_GRANULARITY,
182 );
183
184 if target_capacity <= state.deque.capacity() / 2 {
185 state.deque.shrink_to(target_capacity);
186 }
187 return Poll::Ready(Ok(ret));
188 }
189
190 if let Status::Closed(reason) = &state.closed {
191 if state.readable > 0 {
192 return Poll::Ready(Err(Error::BufferTooShort(*size)));
193 } else {
194 return Poll::Ready(Err(Error::ConnectionClosed(reason.clone())));
195 }
196 }
197
198 state.notify_readable = Some((ctx.waker().clone(), *size));
199 Poll::Pending
200 }
201
202 pub async fn read_protocol_message<P: protocol::ProtocolMessage>(&self) -> Result<P> {
205 self.read(P::MIN_SIZE, P::try_from_bytes).await
206 }
207
208 pub(crate) fn push_back_protocol_message<P: protocol::ProtocolMessage>(
211 &self,
212 message: &P,
213 ) -> Result<()> {
214 let size = message.byte_size();
215 let mut state = self.0.lock().unwrap();
216 let readable = state.readable;
217 state.deque.resize(readable + size, 0);
218 state.deque.rotate_right(size);
219 let (first, _) = state.deque.as_mut_slices();
220
221 let mut first = if first.len() >= size {
222 first
223 } else {
224 state.deque.make_contiguous();
225 state.deque.as_mut_slices().0
226 };
227
228 let got = message.write_bytes(&mut first)?;
229 debug_assert!(got == size);
230 state.readable += size;
231
232 if let Some((waker, size)) = state.notify_readable.take() {
233 if size <= state.readable {
234 waker.wake();
235 } else {
236 state.notify_readable = Some((waker, size));
237 }
238 }
239
240 Ok(())
241 }
242
243 pub fn is_closed(&self) -> bool {
246 let state = self.0.lock().unwrap();
247 state.closed.is_closed() && state.readable == 0
248 }
249
250 pub fn closed_reason(&self) -> Option<String> {
253 let state = self.0.lock().unwrap();
254 state.closed.reason()
255 }
256
257 pub fn close(self, reason: String) {
259 let mut state = self.0.lock().unwrap();
260 match &state.closed {
261 Status::Closed(Some(_)) => (),
262 _ => state.closed = Status::Closed(Some(reason)),
263 }
264 }
265}
266
267impl std::fmt::Debug for Reader {
268 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
269 write!(f, "Reader({:?})", Arc::as_ptr(&self.0))
270 }
271}
272
273impl Drop for Reader {
274 fn drop(&mut self) {
275 let mut state = self.0.lock().unwrap();
276 state.closed.close();
277 }
278}
279
280pub struct Writer(Arc<SyncMutex<State>>);
282
283impl Writer {
284 pub fn write<F>(&self, size: usize, f: F) -> Result<()>
296 where
297 F: FnOnce(&mut [u8]) -> Result<usize>,
298 {
299 let mut state = self.0.lock().unwrap();
300
301 if let Status::Closed(reason) = &state.closed {
302 return Err(Error::ConnectionClosed(reason.clone()));
303 }
304
305 let total_size = state.readable + size;
306
307 if state.deque.len() < total_size {
308 let total_size = std::cmp::max(total_size, state.deque.capacity());
309 state.deque.resize(total_size, 0);
310 }
311
312 let readable = state.readable;
313 let (first, second) = state.deque.as_mut_slices();
314
315 let slice = if first.len() > readable {
316 &mut first[readable..]
317 } else {
318 &mut second[(readable - first.len())..]
319 };
320
321 let slice = if slice.len() >= size {
322 slice
323 } else {
324 state.deque.make_contiguous();
325 &mut state.deque.as_mut_slices().0[readable..]
326 };
327
328 debug_assert!(slice.len() >= size);
329 let size = f(slice)?;
330
331 if size > slice.len() {
332 panic!("Write claimed to produce more bytes than buffer had space for!");
333 }
334
335 state.readable += size;
336
337 if let Some((waker, size)) = state.notify_readable.take() {
338 if size <= state.readable {
339 waker.wake();
340 } else {
341 state.notify_readable = Some((waker, size));
342 }
343 }
344
345 Ok(())
346 }
347
348 pub fn write_protocol_message<P: protocol::ProtocolMessage>(&self, message: &P) -> Result<()> {
351 self.write(message.byte_size(), |mut buf| message.write_bytes(&mut buf))
352 }
353
354 pub fn close(self, reason: String) {
356 self.0.lock().unwrap().closed = Status::Closed(Some(reason))
357 }
358
359 pub fn is_closed(&self) -> bool {
362 let state = self.0.lock().unwrap();
363 state.closed.is_closed() && state.readable == 0
364 }
365
366 pub fn closed_reason(&self) -> Option<String> {
369 let state = self.0.lock().unwrap();
370 state.closed.reason()
371 }
372}
373
374impl std::fmt::Debug for Writer {
375 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
376 write!(f, "Writer({:?})", Arc::as_ptr(&self.0))
377 }
378}
379
380impl Drop for Writer {
381 fn drop(&mut self) {
382 let mut state = self.0.lock().unwrap();
383 state.closed.close();
384
385 if let Some((waker, _)) = state.notify_readable.take() {
386 waker.wake();
387 }
388 }
389}
390
391pub fn stream() -> (Reader, Writer) {
396 let reader = Arc::new(SyncMutex::new(State {
397 deque: VecDeque::new(),
398 readable: 0,
399 notify_readable: None,
400 closed: Status::Open,
401 }));
402 let writer = Arc::clone(&reader);
403
404 (Reader(reader), Writer(writer))
405}
406
#[cfg(test)]
mod test {
    use futures::channel::oneshot;
    use futures::task::noop_waker;
    use futures::FutureExt;
    use std::future::Future;
    use std::pin::pin;
    use std::task::{Context, Poll};

    use super::*;

    // Minimal fixed-size (4-byte) protocol message, used to exercise
    // `push_back_protocol_message`.
    impl protocol::ProtocolMessage for [u8; 4] {
        const MIN_SIZE: usize = 4;
        fn byte_size(&self) -> usize {
            4
        }

        fn write_bytes<W: std::io::Write>(&self, out: &mut W) -> Result<usize> {
            out.write_all(self)?;
            Ok(4)
        }

        fn try_from_bytes(bytes: &[u8]) -> Result<(Self, usize)> {
            if bytes.len() < 4 {
                return Err(Error::BufferTooShort(4));
            }

            Ok((bytes[..4].try_into().unwrap(), 4))
        }
    }

    // Bytes come out in write order, including a read that spans two separate writes.
    #[fuchsia::test]
    async fn stream_test() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        writer
            .write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            })
            .unwrap();

        // 4 leftover bytes from the first write plus 2 from the second.
        let got = reader.read(6, |buf| Ok((buf[..6].to_vec(), 6))).await.unwrap();

        assert_eq!(vec![5, 6, 7, 8, 9, 10], got);
    }

    // Pushed-back bytes are read before both the remaining unread data and later writes.
    #[fuchsia::test]
    async fn push_back_test() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        reader.push_back_protocol_message(&[4, 3, 2, 1]).unwrap();

        writer
            .write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            })
            .unwrap();

        // The callback copies all 10 bytes but reports consuming only 6, so 4 bytes stay
        // buffered; the test does not inspect them.
        let got = reader.read(10, |buf| Ok((buf[..10].to_vec(), 6))).await.unwrap();

        assert_eq!(vec![4, 3, 2, 1, 5, 6, 7, 8, 9, 10], got);
    }

    // Once the reader is dropped, writes fail with `ConnectionClosed(None)`.
    #[fuchsia::test]
    async fn writer_sees_close() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        std::mem::drop(reader);

        assert!(matches!(
            writer.write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            }),
            Err(Error::ConnectionClosed(None))
        ));
    }

    // After the writer is dropped, buffered data can still be read. An oversized read fails
    // with `BufferTooShort` while data remains, and once the buffer is drained reads fail
    // with `ConnectionClosed(None)`.
    #[fuchsia::test]
    async fn reader_sees_closed() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        writer
            .write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            })
            .unwrap();

        std::mem::drop(writer);

        // Only 6 bytes remain and no more can arrive.
        assert!(matches!(reader.read(7, |_| Ok(((), 1))).await, Err(Error::BufferTooShort(7))));

        let got = reader.read(6, |buf| Ok((buf[..6].to_vec(), 6))).await.unwrap();

        assert_eq!(vec![5, 6, 7, 8, 9, 10], got);
        assert!(matches!(
            reader.read(1, |_| Ok(((), 1))).await,
            Err(Error::ConnectionClosed(None))
        ));
    }

    // A read already pending on an empty buffer resolves to `ConnectionClosed(None)` after
    // the writer is dropped. The future is polled by hand with a no-op waker.
    #[fuchsia::test]
    async fn reader_sees_closed_when_polling() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(8, |buf| Ok((buf[..8].to_vec(), 8))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], got);

        let fut = reader
            .read(1, |_| -> Result<((), usize)> { panic!("This read should never succeed!") });
        let mut fut = std::pin::pin!(fut);

        assert!(fut.poll_unpin(&mut Context::from_waker(&noop_waker())).is_pending());

        std::mem::drop(writer);

        assert!(matches!(
            fut.poll_unpin(&mut Context::from_waker(&noop_waker())),
            Poll::Ready(Err(Error::ConnectionClosed(None)))
        ));
    }

    // Same as above, but driven by a real executor task. This relies on `Writer`'s `Drop`
    // waking the reader's registered waker.
    #[fuchsia::test]
    async fn reader_sees_closed_separate_task() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(8, |buf| Ok((buf[..8].to_vec(), 8))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], got);

        let (sender, receiver) = oneshot::channel();
        let task = fuchsia_async::Task::spawn(async move {
            let fut = reader.read(1, |_| Ok(((), 1)));
            let mut fut = std::pin::pin!(fut);
            let mut writer = Some(writer);
            let fut = futures::future::poll_fn(move |cx| {
                let ret = fut.as_mut().poll(cx);

                // On the first poll only, `take()` moves the writer out and the temporary
                // is dropped right away, closing the stream after the read has registered
                // its waker.
                if writer.take().is_some() {
                    assert!(matches!(ret, Poll::Pending));
                }

                ret
            });
            assert!(matches!(fut.await, Err(Error::ConnectionClosed(None))));
            sender.send(()).unwrap();
        });

        receiver.await.unwrap();
        task.await;
    }

    // The callback rejects a 2-byte buffer by asking for 4 bytes via `BufferTooShort(4)`.
    // The read then waits and succeeds once the writer has supplied 4 bytes in total.
    #[fuchsia::test]
    async fn reader_buffer_too_short() {
        let (reader, writer) = stream();
        let (sender, receiver) = oneshot::channel();
        let mut sender = Some(sender);

        let reader_task = async move {
            let got = reader
                .read(1, |buf| {
                    if buf.len() != 4 {
                        // Report the short buffer length to the writer task (only once).
                        sender.take().unwrap().send(buf.len()).unwrap();
                        Err(Error::BufferTooShort(4))
                    } else {
                        Ok((buf[..4].to_vec(), 4))
                    }
                })
                .await
                .unwrap();
            assert_eq!(vec![1, 2, 3, 4], got);
        };

        let writer_task = async move {
            writer
                .write(2, |buf| {
                    buf[..2].copy_from_slice(&[1, 2]);
                    Ok(2)
                })
                .unwrap();

            // Wait until the reader has seen (and rejected) the 2-byte buffer.
            assert_eq!(2, receiver.await.unwrap());

            writer
                .write(2, |buf| {
                    buf[..2].copy_from_slice(&[3, 4]);
                    Ok(2)
                })
                .unwrap();
        };

        futures::future::join(pin!(reader_task), pin!(writer_task)).await;
    }
}