1use crate::error::{Error, Result};
6use crate::protocol;
7
8use fuchsia_sync::Mutex as SyncMutex;
9use std::collections::VecDeque;
10use std::sync::Arc;
11use std::task::{Context, Poll, Waker, ready};
12
/// Granularity, in bytes, used when trimming a reader's buffer after data is consumed.
///
/// After a read, the buffer's capacity is only ever shrunk to a multiple of this value (and never
/// below it), which keeps small fluctuations in buffered data from triggering reallocations.
const BUFFER_TRIM_GRANULARITY: usize = 1 << 20;

// Compile-time check that the trim granularity is a power of two.
const _: () = assert!(BUFFER_TRIM_GRANULARITY.is_power_of_two());
22
/// Whether a stream is still open and, if it is not, why it was closed.
#[derive(Debug, Clone)]
enum Status {
    /// The stream is open.
    Open,
    /// The stream is closed, optionally with a human-readable reason.
    Closed(Option<String>),
}

impl Status {
    /// Returns `true` if the stream has been closed.
    fn is_closed(&self) -> bool {
        matches!(self, Status::Closed(_))
    }

    /// Returns the recorded closure reason, or `None` if the stream is open or was closed
    /// without one.
    fn reason(&self) -> Option<String> {
        if let Status::Closed(why) = self { why.clone() } else { None }
    }

    /// Marks the stream closed without a reason.
    ///
    /// Does nothing if the stream is already closed, so any existing reason is preserved.
    fn close(&mut self) {
        if !self.is_closed() {
            *self = Status::Closed(None);
        }
    }
}
53
/// State shared between the [`Reader`] and [`Writer`] halves of one stream.
#[derive(Debug)]
struct State {
    /// Buffered bytes. Only the first `readable` bytes hold stream data; anything after that is
    /// spare space the writer has reserved for future writes (contents unspecified).
    deque: VecDeque<u8>,
    /// Number of bytes at the front of `deque` that are available to the reader.
    readable: usize,
    /// Waker for a reader blocked in `poll_read`, paired with the number of readable bytes it
    /// needs before waking it is useful.
    notify_readable: Option<(Waker, usize)>,
    /// Whether the stream has been closed, and why.
    closed: Status,
}
73
/// The reading half of a stream created by [`stream`].
pub struct Reader(Arc<SyncMutex<State>>);
76
77impl Reader {
78 pub fn inspect_shutdown(&self) -> String {
80 let lock = self.0.lock();
81 if lock.closed.is_closed() {
82 lock.closed.reason().unwrap_or_else(|| "No epitaph".to_owned())
83 } else {
84 "Not closed".to_owned()
85 }
86 }
87
88 pub async fn read<F, U>(&self, mut size: usize, mut f: F) -> Result<U>
119 where
120 F: FnMut(&[u8]) -> Result<(U, usize)>,
121 {
122 let mut f = move |_: &mut Context<'_>, b: &'_ [u8]| Poll::Ready(f(b));
123 futures::future::poll_fn(move |ctx| self.poll_read(ctx, &mut size, &mut f)).await
124 }
125
126 pub fn poll_read<F, U>(
134 &self,
135 ctx: &mut Context<'_>,
136 size: &mut usize,
137 mut f: F,
138 ) -> Poll<Result<U>>
139 where
140 F: FnMut(&mut Context<'_>, &[u8]) -> Poll<Result<(U, usize)>>,
141 {
142 let mut state = self.0.lock();
143
144 if let Status::Closed(reason) = &state.closed
145 && *size == 0
146 {
147 return Poll::Ready(Err(Error::ConnectionClosed(reason.clone())));
148 }
149
150 if state.readable >= *size {
151 let (first, _) = state.deque.as_slices();
152
153 let first = if first.len() >= *size {
154 first
155 } else {
156 state.deque.make_contiguous();
157 state.deque.as_slices().0
158 };
159
160 debug_assert!(first.len() >= *size);
161
162 let first = &first[..std::cmp::min(first.len(), state.readable)];
163 let (ret, consumed) = match ready!(f(ctx, first)) {
164 Err(Error::BufferTooShort(s)) => {
165 if s < first.len() {
166 return Poll::Ready(Err(Error::CallbackRejectedBuffer(s, first.len())));
167 }
168
169 *size = s;
170 ctx.waker().wake_by_ref();
171 return Poll::Pending;
172 }
173 other => other?,
174 };
175
176 if consumed > first.len() {
177 panic!("Read claimed to consume more bytes than it was given!");
178 }
179
180 state.readable -= consumed;
181 state.deque.drain(..consumed);
182 let target_capacity = std::cmp::max(
183 state.deque.len().next_multiple_of(BUFFER_TRIM_GRANULARITY),
184 BUFFER_TRIM_GRANULARITY,
185 );
186
187 if target_capacity <= state.deque.capacity() / 2 {
188 state.deque.shrink_to(target_capacity);
189 }
190 return Poll::Ready(Ok(ret));
191 }
192
193 if let Status::Closed(reason) = &state.closed {
194 if state.readable > 0 {
195 return Poll::Ready(Err(Error::BufferTooShort(*size)));
196 } else {
197 return Poll::Ready(Err(Error::ConnectionClosed(reason.clone())));
198 }
199 }
200
201 state.notify_readable = Some((ctx.waker().clone(), *size));
202 Poll::Pending
203 }
204
205 pub async fn read_protocol_message<P: protocol::ProtocolMessage>(&self) -> Result<P> {
208 self.read(P::MIN_SIZE, P::try_from_bytes).await
209 }
210
211 pub(crate) fn push_back_protocol_message<P: protocol::ProtocolMessage>(
214 &self,
215 message: &P,
216 ) -> Result<()> {
217 let size = message.byte_size();
218 let mut state = self.0.lock();
219 let readable = state.readable;
220 state.deque.resize(readable + size, 0);
221 state.deque.rotate_right(size);
222 let (first, _) = state.deque.as_mut_slices();
223
224 let mut first = if first.len() >= size {
225 first
226 } else {
227 state.deque.make_contiguous();
228 state.deque.as_mut_slices().0
229 };
230
231 let got = message.write_bytes(&mut first)?;
232 debug_assert!(got == size);
233 state.readable += size;
234
235 if let Some((waker, size)) = state.notify_readable.take() {
236 if size <= state.readable {
237 waker.wake();
238 } else {
239 state.notify_readable = Some((waker, size));
240 }
241 }
242
243 Ok(())
244 }
245
246 pub fn is_closed(&self) -> bool {
249 let state = self.0.lock();
250 state.closed.is_closed() && state.readable == 0
251 }
252
253 pub fn closed_reason(&self) -> Option<String> {
256 let state = self.0.lock();
257 state.closed.reason()
258 }
259
260 pub fn close(self, reason: String) {
262 let mut state = self.0.lock();
263 match &state.closed {
264 Status::Closed(Some(_)) => (),
265 _ => state.closed = Status::Closed(Some(reason)),
266 }
267 }
268}
269
270impl std::fmt::Debug for Reader {
271 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272 write!(f, "Reader({:?})", Arc::as_ptr(&self.0))
273 }
274}
275
276impl Drop for Reader {
277 fn drop(&mut self) {
278 let mut state = self.0.lock();
279 state.closed.close();
280 }
281}
282
/// The writing half of a stream created by [`stream`].
pub struct Writer(Arc<SyncMutex<State>>);
285
286impl Writer {
287 pub fn write<F>(&self, size: usize, f: F) -> Result<()>
299 where
300 F: FnOnce(&mut [u8]) -> Result<usize>,
301 {
302 let mut state = self.0.lock();
303
304 if let Status::Closed(reason) = &state.closed {
305 return Err(Error::ConnectionClosed(reason.clone()));
306 }
307
308 let total_size = state.readable + size;
309
310 if state.deque.len() < total_size {
311 let total_size = std::cmp::max(total_size, state.deque.capacity());
312 state.deque.resize(total_size, 0);
313 }
314
315 let readable = state.readable;
316 let (first, second) = state.deque.as_mut_slices();
317
318 let slice = if first.len() > readable {
319 &mut first[readable..]
320 } else {
321 &mut second[(readable - first.len())..]
322 };
323
324 let slice = if slice.len() >= size {
325 slice
326 } else {
327 state.deque.make_contiguous();
328 &mut state.deque.as_mut_slices().0[readable..]
329 };
330
331 debug_assert!(slice.len() >= size);
332 let size = f(slice)?;
333
334 if size > slice.len() {
335 panic!("Write claimed to produce more bytes than buffer had space for!");
336 }
337
338 state.readable += size;
339
340 if let Some((waker, size)) = state.notify_readable.take() {
341 if size <= state.readable {
342 waker.wake();
343 } else {
344 state.notify_readable = Some((waker, size));
345 }
346 }
347
348 Ok(())
349 }
350
351 pub fn write_protocol_message<P: protocol::ProtocolMessage>(&self, message: &P) -> Result<()> {
354 self.write(message.byte_size(), |mut buf| message.write_bytes(&mut buf))
355 }
356
357 pub fn close(self, reason: String) {
359 self.0.lock().closed = Status::Closed(Some(reason))
360 }
361
362 pub fn is_closed(&self) -> bool {
365 let state = self.0.lock();
366 state.closed.is_closed() && state.readable == 0
367 }
368
369 pub fn closed_reason(&self) -> Option<String> {
372 let state = self.0.lock();
373 state.closed.reason()
374 }
375}
376
377impl std::fmt::Debug for Writer {
378 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
379 write!(f, "Writer({:?})", Arc::as_ptr(&self.0))
380 }
381}
382
383impl Drop for Writer {
384 fn drop(&mut self) {
385 let mut state = self.0.lock();
386 state.closed.close();
387
388 if let Some((waker, _)) = state.notify_readable.take() {
389 waker.wake();
390 }
391 }
392}
393
394pub fn stream() -> (Reader, Writer) {
399 let reader = Arc::new(SyncMutex::new(State {
400 deque: VecDeque::new(),
401 readable: 0,
402 notify_readable: None,
403 closed: Status::Open,
404 }));
405 let writer = Arc::clone(&reader);
406
407 (Reader(reader), Writer(writer))
408}
409
#[cfg(test)]
mod test {
    use futures::FutureExt;
    use futures::channel::oneshot;
    use futures::task::noop_waker;
    use std::future::Future;
    use std::pin::pin;
    use std::task::{Context, Poll};

    use super::*;

    // Minimal `ProtocolMessage` for the push-back test: a fixed four-byte payload copied
    // verbatim to and from the buffer.
    impl protocol::ProtocolMessage for [u8; 4] {
        const MIN_SIZE: usize = 4;
        fn byte_size(&self) -> usize {
            4
        }

        fn write_bytes<W: std::io::Write>(&self, out: &mut W) -> Result<usize> {
            out.write_all(self)?;
            Ok(4)
        }

        fn try_from_bytes(bytes: &[u8]) -> Result<(Self, usize)> {
            if bytes.len() < 4 {
                return Err(Error::BufferTooShort(4));
            }

            Ok((bytes[..4].try_into().unwrap(), 4))
        }
    }

    // Data written in two batches is read back in order, and one read can span both batches.
    #[fuchsia::test]
    async fn stream_test() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        writer
            .write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            })
            .unwrap();

        let got = reader.read(6, |buf| Ok((buf[..6].to_vec(), 6))).await.unwrap();

        assert_eq!(vec![5, 6, 7, 8, 9, 10], got);
    }

    // A pushed-back message is read ahead of data already buffered and data written afterwards.
    #[fuchsia::test]
    async fn push_back_test() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        reader.push_back_protocol_message(&[4, 3, 2, 1]).unwrap();

        writer
            .write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            })
            .unwrap();

        let got = reader.read(10, |buf| Ok((buf[..10].to_vec(), 6))).await.unwrap();

        assert_eq!(vec![4, 3, 2, 1, 5, 6, 7, 8, 9, 10], got);
    }

    // Once the reader is dropped, further writes fail with `ConnectionClosed(None)`.
    #[fuchsia::test]
    async fn writer_sees_close() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        std::mem::drop(reader);

        assert!(matches!(
            writer.write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            }),
            Err(Error::ConnectionClosed(None))
        ));
    }

    // After the writer is dropped, the reader can still drain buffered data. Reads asking for
    // more than remains fail with `BufferTooShort`. Once the buffer is empty, reads fail with
    // `ConnectionClosed(None)`.
    #[fuchsia::test]
    async fn reader_sees_closed() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        writer
            .write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            })
            .unwrap();

        std::mem::drop(writer);

        assert!(matches!(reader.read(7, |_| Ok(((), 1))).await, Err(Error::BufferTooShort(7))));

        let got = reader.read(6, |buf| Ok((buf[..6].to_vec(), 6))).await.unwrap();

        assert_eq!(vec![5, 6, 7, 8, 9, 10], got);
        assert!(matches!(
            reader.read(1, |_| Ok(((), 1))).await,
            Err(Error::ConnectionClosed(None))
        ));
    }

    // A read that is already pending when the writer is dropped completes with
    // `ConnectionClosed(None)` on its next poll.
    #[fuchsia::test]
    async fn reader_sees_closed_when_polling() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(8, |buf| Ok((buf[..8].to_vec(), 8))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], got);

        let fut = reader
            .read(1, |_| -> Result<((), usize)> { panic!("This read should never succeed!") });
        let mut fut = std::pin::pin!(fut);

        assert!(fut.poll_unpin(&mut Context::from_waker(&noop_waker())).is_pending());

        std::mem::drop(writer);

        assert!(matches!(
            fut.poll_unpin(&mut Context::from_waker(&noop_waker())),
            Poll::Ready(Err(Error::ConnectionClosed(None)))
        ));
    }

    // The writer is dropped inside the reading task's first poll, right after the read registers
    // its waker. The task can only finish if the writer's `Drop` wakes it.
    #[fuchsia::test]
    async fn reader_sees_closed_separate_task() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(8, |buf| Ok((buf[..8].to_vec(), 8))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], got);

        let (sender, receiver) = oneshot::channel();
        let task = fuchsia_async::Task::spawn(async move {
            let fut = reader.read(1, |_| Ok(((), 1)));
            let mut fut = std::pin::pin!(fut);
            let mut writer = Some(writer);
            let fut = futures::future::poll_fn(move |cx| {
                let ret = fut.as_mut().poll(cx);

                // On the first poll only: take the writer out and drop it.
                if writer.take().is_some() {
                    assert!(matches!(ret, Poll::Pending));
                }

                ret
            });
            assert!(matches!(fut.await, Err(Error::ConnectionClosed(None))));
            sender.send(()).unwrap();
        });

        receiver.await.unwrap();
        task.await;
    }

    // A callback returning `BufferTooShort(4)` makes the read wait until four bytes are
    // available, then retry with the larger buffer.
    #[fuchsia::test]
    async fn reader_buffer_too_short() {
        let (reader, writer) = stream();
        let (sender, receiver) = oneshot::channel();
        let mut sender = Some(sender);

        let reader_task = async move {
            let got = reader
                .read(1, |buf| {
                    if buf.len() != 4 {
                        sender.take().unwrap().send(buf.len()).unwrap();
                        Err(Error::BufferTooShort(4))
                    } else {
                        Ok((buf[..4].to_vec(), 4))
                    }
                })
                .await
                .unwrap();
            assert_eq!(vec![1, 2, 3, 4], got);
        };

        let writer_task = async move {
            writer
                .write(2, |buf| {
                    buf[..2].copy_from_slice(&[1, 2]);
                    Ok(2)
                })
                .unwrap();

            // The reader first sees the two-byte buffer and rejects it.
            assert_eq!(2, receiver.await.unwrap());

            writer
                .write(2, |buf| {
                    buf[..2].copy_from_slice(&[3, 4]);
                    Ok(2)
                })
                .unwrap();
        };

        futures::future::join(pin!(reader_task), pin!(writer_task)).await;
    }
}