1use crate::error::{Error, Result};
6use crate::protocol;
7
8use std::collections::VecDeque;
9use std::sync::{Arc, Mutex as SyncMutex};
10use tokio::sync::oneshot;
11
/// Granularity, in bytes (1 MiB), used when trimming the reader's buffer after a
/// read: the shrink target is the buffered length rounded up to a multiple of this,
/// and never less than one multiple.
const BUFFER_TRIM_GRANULARITY: usize = 1 << 20;

// Compile-time guarantee that the trim granularity stays a power of two.
const _: () = assert!(BUFFER_TRIM_GRANULARITY.is_power_of_two());
21
/// Whether a stream is still open, and if not, why it was closed.
#[derive(Debug, Clone)]
enum Status {
    /// The stream has not been closed.
    Open,
    /// The stream has been closed, optionally with a human-readable reason.
    Closed(Option<String>),
}

impl Status {
    /// Returns `true` once the stream has been closed for any reason.
    fn is_closed(&self) -> bool {
        matches!(self, Status::Closed(_))
    }

    /// Returns a copy of the close reason.
    ///
    /// Yields `None` both while the stream is open and when it was closed
    /// without a reason.
    fn reason(&self) -> Option<String> {
        if let Status::Closed(why) = self {
            why.clone()
        } else {
            None
        }
    }

    /// Marks the stream closed without a reason.
    ///
    /// An already-closed status, including any reason it carries, is left as is.
    fn close(&mut self) {
        if !self.is_closed() {
            *self = Status::Closed(None);
        }
    }
}
52
/// Shared state behind a [`Reader`] / [`Writer`] pair, guarded by a mutex.
#[derive(Debug)]
struct State {
    /// Buffered bytes. Only the first `readable` bytes hold written data; the
    /// writer may zero-pad the deque past that point (up to its capacity).
    deque: VecDeque<u8>,
    /// Number of bytes at the front of `deque` that were written but not yet
    /// consumed by the reader.
    readable: usize,
    /// A reader waiting in `Reader::read`: the sender is used to wake it, and the
    /// `usize` is how many readable bytes it is waiting for.
    notify_readable: Option<(oneshot::Sender<()>, usize)>,
    /// Whether the stream has been closed, and the reason if one was given.
    closed: Status,
}
71
72pub struct Reader(Arc<SyncMutex<State>>);
74
75impl Reader {
76 pub fn inspect_shutdown(&self) -> String {
78 let lock = self.0.lock().unwrap();
79 if lock.closed.is_closed() {
80 lock.closed.reason().unwrap_or_else(|| "No epitaph".to_owned())
81 } else {
82 "Not closed".to_owned()
83 }
84 }
85
86 pub async fn read<F, U>(&self, mut size: usize, mut f: F) -> Result<U>
112 where
113 F: FnMut(&[u8]) -> Result<(U, usize)>,
114 {
115 loop {
116 let receiver = {
117 let mut state = self.0.lock().unwrap();
118
119 if let Status::Closed(reason) = &state.closed {
120 if size == 0 {
121 return Err(Error::ConnectionClosed(reason.clone()));
122 }
123 }
124
125 if state.readable >= size {
126 let (first, _) = state.deque.as_slices();
127
128 let first = if first.len() >= size {
129 first
130 } else {
131 state.deque.make_contiguous();
132 state.deque.as_slices().0
133 };
134
135 debug_assert!(first.len() >= size);
136
137 let first = &first[..std::cmp::min(first.len(), state.readable)];
138 let (ret, consumed) = match f(first) {
139 Err(Error::BufferTooShort(s)) => {
140 if s < first.len() {
141 return Err(Error::CallbackRejectedBuffer(s, first.len()));
142 }
143
144 size = s;
145 continue;
146 }
147 other => other?,
148 };
149
150 if consumed > first.len() {
151 panic!("Read claimed to consume more bytes than it was given!");
152 }
153
154 state.readable -= consumed;
155 state.deque.drain(..consumed);
156 let target_capacity = std::cmp::max(
157 state.deque.len().next_multiple_of(BUFFER_TRIM_GRANULARITY),
158 BUFFER_TRIM_GRANULARITY,
159 );
160
161 if target_capacity <= state.deque.capacity() / 2 {
162 state.deque.shrink_to(target_capacity);
163 }
164 return Ok(ret);
165 }
166
167 if let Status::Closed(reason) = &state.closed {
168 if state.readable > 0 {
169 return Err(Error::BufferTooShort(size));
170 } else {
171 return Err(Error::ConnectionClosed(reason.clone()));
172 }
173 }
174
175 let (sender, receiver) = oneshot::channel();
176 state.notify_readable = Some((sender, size));
177 receiver
178 };
179
180 let _ = receiver.await;
181 }
182 }
183
184 pub async fn read_protocol_message<P: protocol::ProtocolMessage>(&self) -> Result<P> {
187 self.read(P::MIN_SIZE, P::try_from_bytes).await
188 }
189
190 pub(crate) fn push_back_protocol_message<P: protocol::ProtocolMessage>(
193 &self,
194 message: &P,
195 ) -> Result<()> {
196 let size = message.byte_size();
197 let mut state = self.0.lock().unwrap();
198 let readable = state.readable;
199 state.deque.resize(readable + size, 0);
200 state.deque.rotate_right(size);
201 let (first, _) = state.deque.as_mut_slices();
202
203 let mut first = if first.len() >= size {
204 first
205 } else {
206 state.deque.make_contiguous();
207 state.deque.as_mut_slices().0
208 };
209
210 let got = message.write_bytes(&mut first)?;
211 debug_assert!(got == size);
212 state.readable += size;
213
214 if let Some((sender, size)) = state.notify_readable.take() {
215 if size <= state.readable {
216 let _ = sender.send(());
217 } else {
218 state.notify_readable = Some((sender, size));
219 }
220 }
221
222 Ok(())
223 }
224
225 pub fn is_closed(&self) -> bool {
228 let state = self.0.lock().unwrap();
229 state.closed.is_closed() && state.readable == 0
230 }
231
232 pub fn closed_reason(&self) -> Option<String> {
235 let state = self.0.lock().unwrap();
236 state.closed.reason()
237 }
238
239 pub fn close(self, reason: String) {
241 let mut state = self.0.lock().unwrap();
242 match &state.closed {
243 Status::Closed(Some(_)) => (),
244 _ => state.closed = Status::Closed(Some(reason)),
245 }
246 }
247}
248
249impl std::fmt::Debug for Reader {
250 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
251 write!(f, "Reader({:?})", Arc::as_ptr(&self.0))
252 }
253}
254
255impl Drop for Reader {
256 fn drop(&mut self) {
257 let mut state = self.0.lock().unwrap();
258 state.closed.close();
259 }
260}
261
262pub struct Writer(Arc<SyncMutex<State>>);
264
265impl Writer {
266 pub fn write<F>(&self, size: usize, f: F) -> Result<()>
278 where
279 F: FnOnce(&mut [u8]) -> Result<usize>,
280 {
281 let mut state = self.0.lock().unwrap();
282
283 if let Status::Closed(reason) = &state.closed {
284 return Err(Error::ConnectionClosed(reason.clone()));
285 }
286
287 let total_size = state.readable + size;
288
289 if state.deque.len() < total_size {
290 let total_size = std::cmp::max(total_size, state.deque.capacity());
291 state.deque.resize(total_size, 0);
292 }
293
294 let readable = state.readable;
295 let (first, second) = state.deque.as_mut_slices();
296
297 let slice = if first.len() > readable {
298 &mut first[readable..]
299 } else {
300 &mut second[(readable - first.len())..]
301 };
302
303 let slice = if slice.len() >= size {
304 slice
305 } else {
306 state.deque.make_contiguous();
307 &mut state.deque.as_mut_slices().0[readable..]
308 };
309
310 debug_assert!(slice.len() >= size);
311 let size = f(slice)?;
312
313 if size > slice.len() {
314 panic!("Write claimed to produce more bytes than buffer had space for!");
315 }
316
317 state.readable += size;
318
319 if let Some((sender, size)) = state.notify_readable.take() {
320 if size <= state.readable {
321 let _ = sender.send(());
322 } else {
323 state.notify_readable = Some((sender, size));
324 }
325 }
326
327 Ok(())
328 }
329
330 pub fn write_protocol_message<P: protocol::ProtocolMessage>(&self, message: &P) -> Result<()> {
333 self.write(message.byte_size(), |mut buf| message.write_bytes(&mut buf))
334 }
335
336 pub fn close(self, reason: String) {
338 self.0.lock().unwrap().closed = Status::Closed(Some(reason))
339 }
340
341 pub fn is_closed(&self) -> bool {
344 let state = self.0.lock().unwrap();
345 state.closed.is_closed() && state.readable == 0
346 }
347
348 pub fn closed_reason(&self) -> Option<String> {
351 let state = self.0.lock().unwrap();
352 state.closed.reason()
353 }
354}
355
356impl std::fmt::Debug for Writer {
357 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
358 write!(f, "Writer({:?})", Arc::as_ptr(&self.0))
359 }
360}
361
362impl Drop for Writer {
363 fn drop(&mut self) {
364 let Some(x) = ({
365 let mut state = self.0.lock().unwrap();
366 state.closed.close();
367
368 state.notify_readable.take()
369 }) else {
370 return;
371 };
372 let _ = x.0.send(());
373 }
374}
375
376pub fn stream() -> (Reader, Writer) {
381 let reader = Arc::new(SyncMutex::new(State {
382 deque: VecDeque::new(),
383 readable: 0,
384 notify_readable: None,
385 closed: Status::Open,
386 }));
387 let writer = Arc::clone(&reader);
388
389 (Reader(reader), Writer(writer))
390}
391
// Unit tests for the in-memory `stream()` reader/writer pair.
#[cfg(test)]
mod test {
    use futures::task::noop_waker;
    use futures::FutureExt;
    use std::future::Future;
    use std::pin::pin;
    use std::task::{Context, Poll};

    use super::*;

    // Minimal protocol message used by the tests: exactly four raw bytes.
    impl protocol::ProtocolMessage for [u8; 4] {
        const MIN_SIZE: usize = 4;
        fn byte_size(&self) -> usize {
            4
        }

        fn write_bytes<W: std::io::Write>(&self, out: &mut W) -> Result<usize> {
            out.write_all(self)?;
            Ok(4)
        }

        fn try_from_bytes(bytes: &[u8]) -> Result<(Self, usize)> {
            // Ask the stream to call again once four bytes are available.
            if bytes.len() < 4 {
                return Err(Error::BufferTooShort(4));
            }

            Ok((bytes[..4].try_into().unwrap(), 4))
        }
    }

    // Data written in several chunks is read back in order. The second read
    // spans bytes from both writes.
    #[fuchsia::test]
    async fn stream_test() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        writer
            .write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            })
            .unwrap();

        let got = reader.read(6, |buf| Ok((buf[..6].to_vec(), 6))).await.unwrap();

        assert_eq!(vec![5, 6, 7, 8, 9, 10], got);
    }

    // A pushed-back message is read before both the data that was already
    // buffered and data written after the push-back.
    #[fuchsia::test]
    async fn push_back_test() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        reader.push_back_protocol_message(&[4, 3, 2, 1]).unwrap();

        writer
            .write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            })
            .unwrap();

        // NOTE(review): the callback copies 10 bytes but reports consuming only 6,
        // so the last 4 bytes stay buffered. The test does not check them.
        let got = reader.read(10, |buf| Ok((buf[..10].to_vec(), 6))).await.unwrap();

        assert_eq!(vec![4, 3, 2, 1, 5, 6, 7, 8, 9, 10], got);
    }

    // Once the reader is dropped, further writes fail with ConnectionClosed(None).
    #[fuchsia::test]
    async fn writer_sees_close() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        std::mem::drop(reader);

        assert!(matches!(
            writer.write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            }),
            Err(Error::ConnectionClosed(None))
        ));
    }

    // After the writer is dropped:
    // - buffered data can still be read;
    // - asking for more than is buffered fails with BufferTooShort;
    // - once the buffer is drained, reads fail with ConnectionClosed(None).
    #[fuchsia::test]
    async fn reader_sees_closed() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        writer
            .write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            })
            .unwrap();

        std::mem::drop(writer);

        // Only 6 bytes remain buffered, so a 7-byte read cannot be satisfied.
        assert!(matches!(reader.read(7, |_| Ok(((), 1))).await, Err(Error::BufferTooShort(7))));

        let got = reader.read(6, |buf| Ok((buf[..6].to_vec(), 6))).await.unwrap();

        assert_eq!(vec![5, 6, 7, 8, 9, 10], got);
        assert!(matches!(
            reader.read(1, |_| Ok(((), 1))).await,
            Err(Error::ConnectionClosed(None))
        ));
    }

    // A read already pending on an empty buffer resolves to ConnectionClosed(None)
    // when the writer is dropped. The callback must never run.
    #[fuchsia::test]
    async fn reader_sees_closed_when_polling() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(8, |buf| Ok((buf[..8].to_vec(), 8))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], got);

        let fut = reader
            .read(1, |_| -> Result<((), usize)> { panic!("This read should never succeed!") });
        let mut fut = std::pin::pin!(fut);

        // The first poll registers the waiter and finds the buffer empty.
        assert!(fut.poll_unpin(&mut Context::from_waker(&noop_waker())).is_pending());

        std::mem::drop(writer);

        assert!(matches!(
            fut.poll_unpin(&mut Context::from_waker(&noop_waker())),
            Poll::Ready(Err(Error::ConnectionClosed(None)))
        ));
    }

    // Same scenario as above, but the read runs in a spawned task. The writer is
    // dropped inside that task, right after the first (pending) poll. The wakeup
    // from the writer's Drop must be what drives the read to completion.
    #[fuchsia::test]
    async fn reader_sees_closed_separate_task() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(8, |buf| Ok((buf[..8].to_vec(), 8))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], got);

        let (sender, receiver) = oneshot::channel();
        let task = fuchsia_async::Task::spawn(async move {
            let fut = reader.read(1, |_| Ok(((), 1)));
            let mut fut = std::pin::pin!(fut);
            let mut writer = Some(writer);
            let fut = futures::future::poll_fn(move |cx| {
                let ret = fut.as_mut().poll(cx);

                // On the first poll only: the read must still be pending. Taking
                // the writer out of the Option here drops it.
                if writer.take().is_some() {
                    assert!(matches!(ret, Poll::Pending));
                }

                ret
            });
            assert!(matches!(fut.await, Err(Error::ConnectionClosed(None))));
            sender.send(()).unwrap();
        });

        receiver.await.unwrap();
        task.await;
    }

    // When the callback returns BufferTooShort(4) on a 2-byte buffer, the read
    // waits until 4 bytes are available and then calls the callback again.
    #[fuchsia::test]
    async fn reader_buffer_too_short() {
        let (reader, writer) = stream();
        let (sender, receiver) = oneshot::channel();
        let mut sender = Some(sender);

        let reader_task = async move {
            let got = reader
                .read(1, |buf| {
                    if buf.len() != 4 {
                        // Report the short length to the writer task.
                        sender.take().unwrap().send(buf.len()).unwrap();
                        Err(Error::BufferTooShort(4))
                    } else {
                        Ok((buf[..4].to_vec(), 4))
                    }
                })
                .await
                .unwrap();
            assert_eq!(vec![1, 2, 3, 4], got);
        };

        let writer_task = async move {
            writer
                .write(2, |buf| {
                    buf[..2].copy_from_slice(&[1, 2]);
                    Ok(2)
                })
                .unwrap();

            // The reader saw only the first two bytes before asking for more.
            assert_eq!(2, receiver.await.unwrap());

            writer
                .write(2, |buf| {
                    buf[..2].copy_from_slice(&[3, 4]);
                    Ok(2)
                })
                .unwrap();
        };

        futures::future::join(pin!(reader_task), pin!(writer_task)).await;
    }
}