1use fuchsia_async::{self as fasync, ReadableHandle, ReadableState};
6
7use futures::Stream;
8use std::pin::Pin;
9use std::task::{ready, Context, Poll};
10use thiserror::Error;
11
/// Delimiter byte separating messages in the socket byte stream (ASCII line feed, `0x0A`).
const NEWLINE: u8 = 0x0A;
13
14pub struct NewlineChunker {
22 socket: fasync::Socket,
23 buffer: Vec<u8>,
24 is_terminated: bool,
25 max_message_size: usize,
26 trim_newlines: bool,
27}
28
29impl NewlineChunker {
30 pub fn new(socket: fasync::Socket, max_message_size: usize) -> Self {
32 Self { socket, buffer: vec![], is_terminated: false, max_message_size, trim_newlines: true }
33 }
34
35 pub fn new_with_newlines(socket: fasync::Socket, max_message_size: usize) -> Self {
37 Self {
38 socket,
39 buffer: vec![],
40 is_terminated: false,
41 max_message_size,
42 trim_newlines: false,
43 }
44 }
45
46 fn next_chunk_from_buffer(&mut self) -> Option<Vec<u8>> {
49 let new_tail_start =
50 if let Some(mut newline_pos) = self.buffer.iter().position(|&b| b == NEWLINE) {
51 while let Some(&NEWLINE) = self.buffer.get(newline_pos + 1) {
53 newline_pos += 1;
54 }
55 newline_pos + 1
56 } else if self.buffer.len() >= self.max_message_size {
57 self.max_message_size
61 } else {
62 return None;
64 };
65
66 let new_tail = self.buffer.split_off(new_tail_start);
68 let mut next_chunk = std::mem::replace(&mut self.buffer, new_tail);
69
70 if self.trim_newlines {
71 while let Some(&NEWLINE) = next_chunk.last() {
73 next_chunk.pop();
74 }
75 }
76
77 Some(next_chunk)
78 }
79
80 fn end_of_stream(&mut self) -> Poll<Option<Vec<u8>>> {
81 if !self.buffer.is_empty() {
82 Poll::Ready(Some(std::mem::replace(&mut self.buffer, vec![])))
84 } else {
85 self.is_terminated = true;
87 Poll::Ready(None)
88 }
89 }
90}
91
92impl Stream for NewlineChunker {
93 type Item = Result<Vec<u8>, NewlineChunkerError>;
94
95 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
96 let this = self.get_mut();
97
98 if this.is_terminated {
99 return Poll::Ready(None);
100 }
101
102 if let Some(chunk) = this.next_chunk_from_buffer() {
104 return Poll::Ready(Some(Ok(chunk)));
105 }
106
107 loop {
108 let readable_state = futures::ready!(this.socket.poll_readable(cx))
110 .map_err(NewlineChunkerError::PollReadable)?;
111
112 let bytes_in_socket = this
114 .socket
115 .as_ref()
116 .outstanding_read_bytes()
117 .map_err(NewlineChunkerError::OutstandingReadBytes)?;
118 if bytes_in_socket == 0 {
119 if readable_state == ReadableState::MaybeReadableAndClosed {
120 return this.end_of_stream().map(|buf| buf.map(Ok));
121 }
122 ready!(this.socket.need_readable(cx).map_err(NewlineChunkerError::NeedReadable)?);
124 continue;
125 }
126
127 let bytes_to_read = std::cmp::min(bytes_in_socket, this.max_message_size);
129 let prev_len = this.buffer.len();
130
131 this.buffer.resize(prev_len + bytes_to_read, 0);
134
135 let bytes_read = match this.socket.as_ref().read(&mut this.buffer[prev_len..]) {
136 Ok(b) => b,
137 Err(zx::Status::PEER_CLOSED) => return this.end_of_stream().map(|buf| buf.map(Ok)),
138 Err(zx::Status::SHOULD_WAIT) => {
139 this.buffer.truncate(prev_len);
141 return Poll::Ready(Some(Err(NewlineChunkerError::ShouldWait)));
142 }
143 Err(status) => {
144 this.buffer.truncate(prev_len);
146 return Poll::Ready(Some(Err(NewlineChunkerError::ReadSocket(status))));
147 }
148 };
149
150 this.buffer.truncate(prev_len + bytes_read);
152
153 if let Some(chunk) = this.next_chunk_from_buffer() {
155 return Poll::Ready(Some(Ok(chunk)));
157 } else {
158 ready!(this.socket.need_readable(cx).map_err(NewlineChunkerError::NeedReadable)?);
160 }
161 }
162 }
163}
164
165#[derive(Debug, Error)]
166pub enum NewlineChunkerError {
167 #[error("got SHOULD_WAIT from socket read after confirming outstanding_read_bytes > 0")]
168 ShouldWait,
169
170 #[error("failed to read from socket")]
171 ReadSocket(#[source] zx::Status),
172
173 #[error("failed to get readable state for socket")]
174 PollReadable(#[source] zx::Status),
175
176 #[error("failed to register readable signal for socket")]
177 NeedReadable(#[source] zx::Status),
178
179 #[error("failed to get number of outstanding readable bytes in socket")]
180 OutstandingReadBytes(#[source] zx::Status),
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// Creates a connected stream-socket pair and wraps the reading end with `constructor`.
    /// Returns the chunker together with the peer end used for writing.
    fn connected_chunker(
        constructor: fn(fasync::Socket, usize) -> NewlineChunker,
        max_message_size: usize,
    ) -> (NewlineChunker, zx::Socket) {
        let (read_end, write_end) = zx::Socket::create_stream();
        let read_end = fasync::Socket::from_socket(read_end);
        (constructor(read_end, max_message_size), write_end)
    }

    /// Awaits the next stream item, unwrapping both the `Option` and the `Result`.
    async fn next_chunk(chunker: &mut NewlineChunker) -> Vec<u8> {
        chunker.next().await.unwrap().unwrap()
    }

    #[fuchsia::test]
    async fn parse_bytes_with_newline() {
        let (mut chunker, peer) = connected_chunker(NewlineChunker::new, 100);
        peer.write(b"test\n").expect("Failed to write");
        assert_eq!(next_chunk(&mut chunker).await, b"test");
    }

    #[fuchsia::test]
    async fn parse_bytes_with_many_newlines() {
        let (mut chunker, peer) = connected_chunker(NewlineChunker::new, 100);
        peer.write(b"test1\ntest2\ntest3\n").expect("Failed to write");
        let expected: [&[u8]; 3] = [b"test1", b"test2", b"test3"];
        for line in expected {
            assert_eq!(next_chunk(&mut chunker).await, line);
        }
        drop(peer);
        assert!(chunker.next().await.is_none());
    }

    #[fuchsia::test]
    async fn parse_bytes_with_newlines_included() {
        let (mut chunker, peer) = connected_chunker(NewlineChunker::new_with_newlines, 100);
        peer.write(b"test1\ntest2\ntest3\n").expect("Failed to write");
        let expected: [&[u8]; 3] = [b"test1\n", b"test2\n", b"test3\n"];
        for line in expected {
            assert_eq!(next_chunk(&mut chunker).await, line);
        }
    }

    #[fuchsia::test]
    async fn max_message_size() {
        let (mut chunker, peer) = connected_chunker(NewlineChunker::new, 2);
        peer.write(b"test\n").expect("Failed to write");
        assert_eq!(next_chunk(&mut chunker).await, b"te");
        assert_eq!(next_chunk(&mut chunker).await, b"st");
    }
}