1#![warn(
6 missing_docs,
7 unreachable_patterns,
8 clippy::useless_conversion,
9 clippy::redundant_clone,
10 clippy::precedence
11)]
12
13use std::collections::VecDeque;
21
22use thiserror::Error;
23
24#[derive(Error, Debug, PartialEq)]
26#[error("write size is larger than buffer size")]
27pub struct WriteTooLarge;
28
/// A fixed-capacity byte ring buffer that evicts its oldest data in whole chunks.
///
/// Data is appended through transactions. On commit, the chunk being filled is closed once it
/// holds at least `chunk_size` bytes. When a write needs room, the oldest chunks are dropped
/// entirely rather than partially overwritten.
#[derive(Debug)]
pub struct RingBuffer {
    /// Backing storage; its length is the buffer's capacity.
    buf: Vec<u8>,
    /// Size a chunk must reach (checked at commit time) before a new chunk is started.
    chunk_size: usize,
    /// Index at which the next write begins (one past the newest byte, modulo capacity).
    tail: usize,
    /// Chunk start indices, oldest first. The front marks the oldest live byte; the back is
    /// the start of the chunk currently being filled and may equal `tail`. Empty exactly when
    /// the buffer holds no data.
    boundary_indices: VecDeque<usize>,
}
52
53impl RingBuffer {
54 pub fn new(size: usize, chunk_size: usize) -> Self {
59 let size = if chunk_size == 0 { size } else { size.next_multiple_of(chunk_size) };
60 Self { buf: vec![0u8; size], chunk_size, tail: 0, boundary_indices: VecDeque::new() }
61 }
62
63 fn head(&self) -> usize {
65 *self.boundary_indices.front().unwrap_or(&self.tail)
66 }
67
68 pub fn len(&self) -> usize {
70 let head = self.head();
71 if self.tail == head && !self.boundary_indices.is_empty() {
72 self.buf.len()
73 } else {
74 self.bytes_between(head, self.tail)
75 }
76 }
77
78 fn bytes_between(&self, i: usize, j: usize) -> usize {
79 if j >= i { j - i } else { self.buf.len() - i + j }
80 }
81
82 pub fn get_view(&self) -> (&[u8], &[u8]) {
84 let head = self.head();
85 if self.tail > head || self.tail == head && self.boundary_indices.is_empty() {
86 (&self.buf[head..self.tail], &[])
87 } else {
88 (&self.buf[head..], &self.buf[..self.tail])
89 }
90 }
91
92 pub fn start_transaction(&mut self) -> Transaction<'_> {
94 Transaction::new(self)
95 }
96
97 fn write_inner(&mut self, slice: &[u8]) {
101 while self.len() + slice.len() > self.buf.len() {
102 let _ = self.boundary_indices.pop_front();
103 if self.boundary_indices.is_empty() {
104 break;
105 }
106 }
107 if self.boundary_indices.is_empty() {
108 self.boundary_indices.push_back(self.tail);
109 }
110
111 if self.tail + slice.len() >= self.buf.len() {
112 let remaining = self.buf.len() - self.tail;
113 assert!(remaining > 0);
114 self.buf[self.tail..self.tail + remaining].copy_from_slice(&slice[..remaining]);
115 let data_remaining = slice.len() - remaining;
116 if data_remaining > 0 {
117 self.buf[..data_remaining].copy_from_slice(&slice[remaining..]);
118 }
119 self.tail = data_remaining;
120 } else {
121 self.buf[self.tail..self.tail + slice.len()].copy_from_slice(&slice);
122 self.tail += slice.len();
123 }
124 }
125
126 fn maybe_chunk(&mut self) {
127 let Some(penultimate) = self.boundary_indices.back() else {
128 return;
129 };
130 if *penultimate == self.tail {
134 return;
135 }
136 if self.bytes_between(*penultimate, self.tail) >= self.chunk_size {
137 self.boundary_indices.push_back(self.tail);
138 }
139 }
140
141 pub fn write(&mut self, slice: &[u8]) -> Result<(), WriteTooLarge> {
143 let mut transaction = self.start_transaction();
144 transaction.write(slice)?;
145 transaction.commit();
146 Ok(())
147 }
148
149 fn rollback(&mut self, start: usize) {
150 if self.head() == start {
151 self.boundary_indices.clear();
152 }
153 self.tail = start;
154 }
155}
156
157pub struct Transaction<'a> {
160 buffer: &'a mut RingBuffer,
161 start: usize,
162 written: usize,
163 completed: bool,
164}
165
166impl<'a> Transaction<'a> {
167 pub fn new(buffer: &'a mut RingBuffer) -> Self {
169 let start = buffer.tail;
170 Self { buffer, start, written: 0, completed: false }
171 }
172
173 pub fn write(&mut self, bytes: &[u8]) -> Result<(), WriteTooLarge> {
175 if bytes.len() == 0 {
176 return Ok(());
177 }
178
179 if self.written + bytes.len() > self.buffer.buf.len() {
180 return Err(WriteTooLarge);
181 }
182
183 self.buffer.write_inner(bytes);
184 self.written += bytes.len();
185 Ok(())
186 }
187
188 pub fn commit(mut self) {
190 self.buffer.maybe_chunk();
191 self.completed = true;
192 }
193}
194
195impl<'a> Drop for Transaction<'a> {
196 fn drop(&mut self) {
197 if !self.completed {
198 self.buffer.rollback(self.start);
199 }
200 }
201}
202
203impl<'a> std::io::Write for Transaction<'a> {
204 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
205 self.write(buf).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
206 Ok(buf.len())
207 }
208 fn flush(&mut self) -> std::io::Result<()> {
209 Ok(())
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    use test_case::test_case;

    #[test]
    fn test_write_no_wrap() {
        const DATA: &[u8] = b"hello";
        let mut ring = RingBuffer::new(8, 4);
        ring.write(DATA).unwrap();

        // Five bytes reach the chunk size of 4, so a new, empty chunk begins at index 5.
        assert_eq!(ring.boundary_indices, VecDeque::from([0, 5]));
        let (front, back) = ring.get_view();
        assert_eq!(front, DATA);
        assert!(back.is_empty());
    }

    #[test]
    fn test_write_wrap() {
        let mut ring = RingBuffer::new(8, 4);

        let mut first = ring.start_transaction();
        for part in [b"foo", b"bar"] {
            first.write(part).unwrap();
        }
        first.commit();
        assert_eq!(ring.boundary_indices, VecDeque::from([0, 6]));

        // Six more bytes do not fit next to the first six, so the old data is evicted and the
        // new bytes wrap past the end of the storage.
        let mut second = ring.start_transaction();
        for part in [b"baz", b"qux"] {
            second.write(part).unwrap();
        }
        second.commit();
        assert_eq!(ring.boundary_indices, VecDeque::from([6, 4]));

        let (front, back) = ring.get_view();
        assert_eq!(front, b"ba");
        assert_eq!(back, b"zqux");
    }

    #[test_case(8; "chunk_size_equals_buffer_size")]
    #[test_case(4; "chunk_size_half_of_buffer_size")]
    fn test_write_exact_fill(chunk_size: usize) {
        let mut ring = RingBuffer::new(8, chunk_size);
        ring.write(b"12345678").unwrap();
        // Filling the buffer exactly wraps `tail` back to 0 and keeps a single chunk.
        assert_eq!(ring.boundary_indices, VecDeque::from([0]));
        assert_eq!(ring.tail, 0);

        let (front, back) = ring.get_view();
        assert_eq!(front, b"12345678");
        assert!(back.is_empty());

        // The next byte forces the full chunk out.
        ring.write(b"9").unwrap();
        let (front, back) = ring.get_view();
        assert_eq!(front, b"9");
        assert!(back.is_empty());
    }

    #[test]
    fn test_write_smaller_than_chunk_size() {
        const CHUNK_SIZE: usize = 4;
        let mut ring = RingBuffer::new(8, CHUNK_SIZE);
        // Three single-byte writes stay inside the first, still-open chunk.
        for byte in 1u8..=3 {
            ring.write(&[byte]).unwrap();
            assert_eq!(ring.boundary_indices, VecDeque::from([0]));
        }
        // The fourth byte brings the chunk up to CHUNK_SIZE and closes it.
        ring.write(&[4]).unwrap();
        assert_eq!(ring.boundary_indices, VecDeque::from([0, CHUNK_SIZE]));

        let (front, back) = ring.get_view();
        assert_eq!(front, &[1, 2, 3, 4]);
        assert!(back.is_empty());
    }

    #[test]
    fn test_zero_chunk_size_zero_writes() {
        let mut ring = RingBuffer::new(8, 0);
        ring.write(b"").unwrap();
        assert!(ring.boundary_indices.is_empty());
        let (front, back) = ring.get_view();
        assert!(front.is_empty());
        assert!(back.is_empty());
    }

    #[test]
    fn test_zero_chunk_size_fill_and_overwrite() {
        /// Both halves of the view joined into one vector.
        fn contents(ring: &RingBuffer) -> Vec<u8> {
            let (front, back) = ring.get_view();
            [front, back].concat()
        }

        const N: usize = 4;
        let mut ring = RingBuffer::new(N, 0);

        for byte in 1u8..=4 {
            ring.write(&[byte]).unwrap();
        }

        // With chunk_size 0 every write is its own chunk, so each overflow evicts one byte.
        ring.write(&[5]).unwrap();
        assert_eq!(contents(&ring), vec![2, 3, 4, 5]);

        ring.write(&[6]).unwrap();
        assert_eq!(contents(&ring), vec![3, 4, 5, 6]);
    }

    #[test]
    fn test_transaction_rollback_on_drop() {
        let mut ring = RingBuffer::new(8, 4);
        ring.write(b"ab").unwrap();

        let mut abandoned = ring.start_transaction();
        abandoned.write(b"cd").unwrap();
        abandoned.write(b"ef").unwrap();
        // Dropping without `commit` must discard "cdef".
        drop(abandoned);

        let (front, back) = ring.get_view();
        assert_eq!(front, b"ab");
        assert!(back.is_empty());
    }

    #[test]
    fn test_transaction_too_large() {
        let mut ring = RingBuffer::new(8, 4);
        let mut tx = ring.start_transaction();
        tx.write(b"12345678").unwrap();
        // The transaction already holds a full buffer's worth of bytes.
        assert_eq!(tx.write(b"9"), Err(WriteTooLarge));
    }
}