use alloc::{vec, vec::Vec};
use core::{
cmp,
fmt::Debug,
num::{NonZeroUsize, TryFromIntError},
ops::Range,
};
use either::Either;
use packet::InnerPacketBuilder;
use crate::transport::tcp::{
segment::Payload,
seqnum::{SeqNum, WindowSize},
state::Takeable,
BufferSizes,
};
/// Operations common to TCP send and receive buffers.
pub trait Buffer: Takeable + Debug + Sized {
    /// Returns the buffer's current occupancy and capacity.
    fn limits(&self) -> BufferLimits;
    /// Returns the capacity the buffer is trying to reach; may differ from
    /// the current capacity while a resize is still pending.
    fn target_capacity(&self) -> usize;
    /// Requests that the buffer be resized to hold `size` bytes.
    fn request_capacity(&mut self, size: usize);
}
/// A buffer that accumulates bytes received from the peer.
pub trait ReceiveBuffer: Buffer {
    /// Writes `data` starting `offset` bytes past the current writable
    /// start, returning the number of bytes actually written.
    fn write_at<P: Payload>(&mut self, offset: usize, data: &P) -> usize;
    /// Marks `count` previously-written bytes as readable.
    fn make_readable(&mut self, count: usize);
}
/// A buffer that holds bytes queued for transmission to the peer.
pub trait SendBuffer: Buffer {
    /// Discards `count` bytes from the front of the buffer.
    fn mark_read(&mut self, count: usize);
    /// Calls `f` with the readable bytes starting at `offset`, without
    /// consuming them.
    fn peek_with<'a, F, R>(&'a mut self, offset: usize, f: F) -> R
    where
        F: FnOnce(SendPayload<'a>) -> R;
}
/// A snapshot of a buffer's occupancy.
pub struct BufferLimits {
    /// Maximum number of bytes the buffer can currently hold.
    pub capacity: usize,
    /// Number of bytes currently stored.
    pub len: usize,
}
/// A borrowed view of a send buffer's readable bytes. Because the backing
/// ring buffer may wrap, the bytes come in at most two slices.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum SendPayload<'a> {
    /// All bytes are in one contiguous slice.
    Contiguous(&'a [u8]),
    /// The bytes straddle the ring boundary: first slice, then second.
    Straddle(&'a [u8], &'a [u8]),
}
impl Payload for SendPayload<'_> {
    fn len(&self) -> usize {
        // Total number of bytes across both fragments.
        match self {
            SendPayload::Contiguous(p) => p.len(),
            SendPayload::Straddle(p1, p2) => p1.len() + p2.len(),
        }
    }

    /// Returns the sub-payload covering `range` (byte offsets into the
    /// logical payload). Like slice indexing, panics when the bounds cannot
    /// be represented as `usize` on this platform.
    fn slice(self, range: Range<u32>) -> Self {
        match self {
            SendPayload::Contiguous(p) => SendPayload::Contiguous(p.slice(range)),
            SendPayload::Straddle(p1, p2) => {
                let Range { start, end } = range;
                let start = usize::try_from(start).unwrap_or_else(|TryFromIntError { .. }| {
                    panic!(
                        "range start index {} out of range for slice of length {}",
                        start,
                        self.len()
                    )
                });
                let end = usize::try_from(end).unwrap_or_else(|TryFromIntError { .. }| {
                    panic!(
                        "range end index {} out of range for slice of length {}",
                        end,
                        self.len()
                    )
                });
                assert!(start <= end);
                let first_len = p1.len();
                if start < first_len && end > first_len {
                    // The range crosses the fragment boundary, so the result
                    // still straddles.
                    SendPayload::Straddle(&p1[start..first_len], &p2[0..end - first_len])
                } else if start >= first_len {
                    // Entirely within the second fragment.
                    SendPayload::Contiguous(&p2[start - first_len..end - first_len])
                } else {
                    // Entirely within the first fragment (end <= first_len).
                    SendPayload::Contiguous(&p1[start..end])
                }
            }
        }
    }

    fn partial_copy(&self, offset: usize, dst: &mut [u8]) {
        match self {
            SendPayload::Contiguous(p) => p.partial_copy(offset, dst),
            SendPayload::Straddle(p1, p2) => {
                if offset < p1.len() {
                    // Copy from the first fragment, spilling into the second
                    // if `dst` still has room.
                    let first_len = dst.len().min(p1.len() - offset);
                    p1.partial_copy(offset, &mut dst[..first_len]);
                    if dst.len() > first_len {
                        p2.partial_copy(0, &mut dst[first_len..]);
                    }
                } else {
                    // The copy starts past the first fragment entirely.
                    p2.partial_copy(offset - p1.len(), dst);
                }
            }
        }
    }
}
impl InnerPacketBuilder for SendPayload<'_> {
    /// Total length of the payload in bytes.
    fn bytes_len(&self) -> usize {
        match self {
            SendPayload::Contiguous(part) => part.len(),
            SendPayload::Straddle(front, back) => front.len() + back.len(),
        }
    }

    /// Copies the entire payload into `buffer`, starting at the beginning.
    fn serialize(&self, buffer: &mut [u8]) {
        self.partial_copy(0, buffer);
    }
}
/// A circular byte buffer used as the default send/receive buffer.
#[cfg_attr(any(test, feature = "testutils"), derive(Clone, PartialEq, Eq))]
pub struct RingBuffer {
    // Backing storage; its length is the allocated capacity.
    storage: Vec<u8>,
    // Index of the first readable byte.
    head: usize,
    // Number of readable bytes starting at `head` (the range may wrap).
    len: usize,
    // In-progress capacity reduction, if any.
    shrink: Option<PendingShrink>,
}
impl Debug for RingBuffer {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Summarize storage by its sizes rather than dumping the contents.
        f.debug_struct("RingBuffer")
            .field("storage (len, cap)", &(self.storage.len(), self.storage.capacity()))
            .field("head", &self.head)
            .field("len", &self.len)
            .field("shrink", &self.shrink)
            .finish()
    }
}
/// Tracks an in-progress reduction of a [`RingBuffer`]'s capacity.
#[derive(Debug)]
#[cfg_attr(any(test, feature = "testutils"), derive(Copy, Clone, Eq, PartialEq))]
struct PendingShrink {
    // Total number of bytes the buffer should shrink by.
    target: NonZeroUsize,
    // Bytes reserved (withheld from the writable window) so far.
    current: usize,
}
impl Default for RingBuffer {
    fn default() -> Self {
        // Size the buffer to the default TCP window size.
        Self::new(WindowSize::DEFAULT.into())
    }
}
impl RingBuffer {
    /// Creates a new `RingBuffer` with `capacity` zeroed bytes of storage.
    pub fn new(capacity: usize) -> Self {
        Self { storage: vec![0; capacity], head: 0, len: 0, shrink: None }
    }

    /// Calls `f` with the `len` readable bytes that start at `start`, split
    /// into two slices when the range wraps past the end of `storage`.
    fn with_readable<'a, F, R>(storage: &'a Vec<u8>, start: usize, len: usize, f: F) -> R
    where
        F: for<'b> FnOnce(&'b [&'a [u8]]) -> R,
    {
        let end = start + len;
        if end > storage.len() {
            // The readable range wraps: tail of storage, then its prefix.
            let first_part = &storage[start..storage.len()];
            let second_part = &storage[0..len - first_part.len()];
            f(&[first_part, second_part][..])
        } else {
            let all_bytes = &storage[start..end];
            f(&[all_bytes][..])
        }
    }

    /// Advances any pending shrink by up to `max_shrink_by` newly-freed
    /// bytes; once the target is reached, reallocates `storage` at the
    /// reduced size, preserving the readable bytes and their order.
    fn maybe_shrink(&mut self, max_shrink_by: usize) {
        let Self { storage, head, len: _, shrink } = self;
        let PendingShrink { target, current } = match shrink {
            Some(x) => x,
            None => return,
        };
        let target = target.get();
        // Reserve freed bytes toward the target, never past it.
        *current = core::cmp::min(target, *current + max_shrink_by);
        if target == *current {
            let mut new_storage = Vec::new();
            new_storage.reserve_exact(storage.len() - target);
            if let Some(writable_end) = (*head).checked_sub(target) {
                // Drop the `target` free bytes immediately before `head`;
                // the bytes after `head` (and any wrapped prefix) keep
                // their order, with `head` moved back by `target`.
                new_storage.extend_from_slice(&storage[..writable_end]);
                new_storage.extend_from_slice(&storage[(*head)..]);
                *head = writable_end;
            } else {
                // `head < target`: with `len` bounded by the shrunken
                // capacity the readable bytes cannot wrap here, so copy
                // them contiguously to the front.
                let unreserved_len = storage.len() - target;
                new_storage.extend_from_slice(&storage[*head..(*head + unreserved_len)]);
                *head = 0;
            }
            *storage = new_storage;
            *shrink = None;
            return;
        }
    }

    /// Calls `f` with the readable bytes and consumes however many bytes
    /// `f` reports having read, returning that count.
    pub fn read_with<F>(&mut self, f: F) -> usize
    where
        F: for<'a, 'b> FnOnce(&'b [&'a [u8]]) -> usize,
    {
        let Self { storage, head, len, shrink: _ } = self;
        if storage.len() == 0 {
            return f(&[&[]]);
        }
        let nread = RingBuffer::with_readable(storage, *head, *len, f);
        assert!(nread <= *len);
        *len -= nread;
        *head = (*head + nread) % storage.len();
        // The freed bytes may make progress on a pending shrink.
        self.maybe_shrink(nread);
        nread
    }

    /// Returns the writable (free) regions of the buffer in write order;
    /// there are two regions when the free space wraps around.
    pub fn writable_regions(&mut self) -> impl IntoIterator<Item = &mut [u8]> {
        let BufferLimits { capacity, len } = self.limits();
        let available = capacity - len;
        let Self { storage, head, len, shrink: _ } = self;
        let mut write_start = *head + *len;
        if write_start >= storage.len() {
            write_start -= storage.len()
        }
        let write_end = write_start + available;
        if write_end <= storage.len() {
            Either::Left([&mut self.storage[write_start..write_end]].into_iter())
        } else {
            // Free space wraps: from `write_start` to the end of storage,
            // then from the start for the remaining `available` bytes.
            let (b1, b2) = self.storage[..].split_at_mut(write_start);
            let b2_len = b2.len();
            Either::Right([b2, &mut b1[..(available - b2_len)]].into_iter())
        }
    }

    /// Sets the desired capacity. Growing reallocates immediately and
    /// compacts the data so `head == 0`; shrinking is deferred until enough
    /// bytes have been freed (see [`RingBuffer::maybe_shrink`]).
    pub fn set_target_size(&mut self, new_capacity: usize) {
        let Self { ref mut shrink, head, len: _, storage } = self;
        if let Some(extend_by) = new_capacity.checked_sub(storage.len()) {
            // Growing (or a no-op resize) cancels any pending shrink.
            let old_shrink = shrink.take();
            if extend_by != 0 {
                let reserved_len = old_shrink.map_or(0, |r| r.current);
                let mut new_storage = Vec::new();
                new_storage.reserve_exact(new_capacity);
                if *head <= reserved_len {
                    new_storage
                        .extend_from_slice(&storage[*head..][..(storage.len() - reserved_len)])
                } else {
                    new_storage.extend_from_slice(&storage[*head..]);
                    new_storage.extend_from_slice(&storage[..(*head - reserved_len)]);
                }
                new_storage.resize(new_capacity, 0);
                *storage = new_storage;
                *head = 0;
            }
        } else {
            // new_capacity < storage.len(): record (or tighten) a pending
            // shrink, carrying over progress, and try to advance it now.
            let target = NonZeroUsize::new(storage.len() - new_capacity).unwrap();
            match shrink.take() {
                None => *shrink = Some(PendingShrink { target, current: 0 }),
                Some(PendingShrink { target: _, current }) => {
                    let current = core::cmp::min(current, target.get());
                    *shrink = Some(PendingShrink { target, current });
                    self.maybe_shrink(0)
                }
            }
        }
    }
}
impl Buffer for RingBuffer {
    fn limits(&self) -> BufferLimits {
        // Capacity excludes bytes already reserved by an in-progress shrink.
        let reserved = self.shrink.as_ref().map_or(0, |s| s.current);
        BufferLimits { len: self.len, capacity: self.storage.len() - reserved }
    }

    fn target_capacity(&self) -> usize {
        // The capacity this buffer will have once any pending shrink lands.
        let pending = self.shrink.as_ref().map_or(0, |s| s.target.get());
        self.storage.len() - pending
    }

    fn request_capacity(&mut self, size: usize) {
        self.set_target_size(size)
    }
}
impl ReceiveBuffer for RingBuffer {
    fn write_at<P: Payload>(&mut self, offset: usize, data: &P) -> usize {
        let BufferLimits { capacity, len } = self.limits();
        let available = capacity - len;
        let Self { storage, head, len, shrink: _ } = self;
        if storage.len() == 0 {
            return 0;
        }
        if offset > available {
            return 0;
        }
        // First byte to write for this offset, wrapping around storage.
        let start_at = (*head + *len + offset) % storage.len();
        // NOTE(review): `to_write` is capped at `available`, not at
        // `available - offset`; this assumes callers keep
        // `offset + data.len()` within the window — confirm.
        let to_write = cmp::min(data.len(), available);
        // Copy up to the end of storage, then wrap to the front if needed.
        let first_len = cmp::min(to_write, storage.len() - start_at);
        data.partial_copy(0, &mut storage[start_at..start_at + first_len]);
        if to_write > first_len {
            data.partial_copy(first_len, &mut storage[0..to_write - first_len]);
        }
        to_write
    }

    fn make_readable(&mut self, count: usize) {
        let BufferLimits { capacity, len } = self.limits();
        debug_assert!(count <= capacity - len);
        // Bytes previously written via `write_at` become readable.
        self.len += count;
    }
}
impl SendBuffer for RingBuffer {
    fn mark_read(&mut self, count: usize) {
        let Self { storage, head, len, shrink: _ } = self;
        assert!(count <= *len);
        // Drop `count` bytes from the front; freed space may make progress
        // on a pending shrink.
        *len -= count;
        *head = (*head + count) % storage.len();
        self.maybe_shrink(count);
    }

    fn peek_with<'a, F, R>(&'a mut self, offset: usize, f: F) -> R
    where
        F: FnOnce(SendPayload<'a>) -> R,
    {
        let Self { storage, head, len, shrink: _ } = self;
        if storage.len() == 0 {
            return f(SendPayload::Contiguous(&[]));
        }
        assert!(offset <= *len);
        // A ring buffer's readable range yields at most two slices.
        RingBuffer::with_readable(
            storage,
            (*head + offset) % storage.len(),
            *len - offset,
            |readable| match readable.len() {
                1 => f(SendPayload::Contiguous(readable[0])),
                2 => f(SendPayload::Straddle(readable[0], readable[1])),
                x => unreachable!(
                    "the ring buffer cannot have more than 2 fragments, got {} fragments ({:?})",
                    x, readable
                ),
            },
        )
    }
}
/// Tracks out-of-order received segments so the in-order edge (`nxt`) can
/// be advanced as gaps are filled.
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq, Eq))]
pub(super) struct Assembler {
    // The next in-order sequence number expected.
    nxt: SeqNum,
    // Sorted, disjoint ranges received beyond `nxt`.
    outstanding: Vec<Range<SeqNum>>,
}
impl Assembler {
    /// Creates a new assembler expecting `nxt` as the next in-order byte.
    pub(super) fn new(nxt: SeqNum) -> Self {
        Self { outstanding: Vec::new(), nxt }
    }

    /// The next in-order sequence number.
    pub(super) fn nxt(&self) -> SeqNum {
        self.nxt
    }

    /// Whether any out-of-order segments are currently being held.
    pub(super) fn has_out_of_order(&self) -> bool {
        !self.outstanding.is_empty()
    }

    /// Records receipt of `start..end` and returns how many bytes the
    /// in-order edge `nxt` advanced as a result (0 if a gap remains).
    pub(super) fn insert(&mut self, Range { start, end }: Range<SeqNum>) -> usize {
        assert!(!start.after(end));
        assert!(!start.before(self.nxt));
        if start == end {
            return 0;
        }
        self.insert_inner(start..end);
        let Self { outstanding, nxt } = self;
        // If the first tracked range now begins at `nxt`, that prefix is
        // in-order: consume it and advance `nxt`.
        if outstanding[0].start == *nxt {
            let advanced = outstanding.remove(0);
            *nxt = advanced.end;
            usize::try_from(advanced.end - advanced.start).unwrap()
        } else {
            0
        }
    }

    /// Merges `start..end` into `outstanding`, keeping the ranges sorted
    /// and disjoint.
    fn insert_inner(&mut self, Range { mut start, mut end }: Range<SeqNum>) {
        let Self { outstanding, nxt: _ } = self;
        if start == end {
            return;
        }
        if outstanding.is_empty() {
            outstanding.push(Range { start, end });
            return;
        }
        // Index of the first stored range whose start is after `start`.
        let first_after = {
            let mut cur = 0;
            while cur < outstanding.len() {
                if start.before(outstanding[cur].start) {
                    break;
                }
                cur += 1;
            }
            cur
        };
        // Count ranges at/after `first_after` the new range reaches,
        // extending `end` over them as needed.
        let mut merge_right = 0;
        for range in &outstanding[first_after..outstanding.len()] {
            if end.before(range.start) {
                break;
            }
            merge_right += 1;
            if end.before(range.end) {
                end = range.end;
                break;
            }
        }
        // Count ranges before `first_after` (scanning backwards) the new
        // range touches, extending `start` (and possibly `end`).
        let mut merge_left = 0;
        for range in (&outstanding[0..first_after]).iter().rev() {
            if start.after(range.end) {
                break;
            }
            if end.before(range.end) {
                end = range.end;
            }
            merge_left += 1;
            if start.after(range.start) {
                start = range.start;
                break;
            }
        }
        if merge_left == 0 && merge_right == 0 {
            // No overlap with any stored range: plain insertion.
            outstanding.insert(first_after, Range { start, end });
        } else {
            // Replace the merged ranges with the single combined range and
            // shift the remainder left to close the gap.
            let left_edge = first_after - merge_left;
            let right_edge = first_after + merge_right;
            outstanding[left_edge] = Range { start, end };
            for i in right_edge..outstanding.len() {
                outstanding[i - merge_left - merge_right + 1] = outstanding[i].clone();
            }
            outstanding.truncate(outstanding.len() - merge_left - merge_right + 1);
        }
    }
}
/// Conversion from socket-creation context into a pair of buffers.
pub trait IntoBuffers<R: ReceiveBuffer, S: SendBuffer> {
    /// Consumes `self` and produces the receive and send buffers, sized
    /// according to `buffer_sizes`.
    fn into_buffers(self, buffer_sizes: BufferSizes) -> (R, S);
}
#[cfg(any(test, feature = "testutils"))]
impl<R: Default + ReceiveBuffer, S: Default + SendBuffer> IntoBuffers<R, S> for () {
    fn into_buffers(self, buffer_sizes: BufferSizes) -> (R, S) {
        // The requested sizes are ignored; each buffer uses its default.
        let BufferSizes { send: _, receive: _ } = buffer_sizes;
        (R::default(), S::default())
    }
}
#[cfg(any(test, feature = "testutils"))]
pub(crate) mod testutil {
use super::*;
use alloc::sync::Arc;
use crate::sync::Mutex;
use crate::transport::tcp::socket::ListenerNotifier;
impl RingBuffer {
    /// Appends `data` to the readable portion of the buffer, returning how
    /// many bytes fit.
    pub(crate) fn enqueue_data(&mut self, data: &[u8]) -> usize {
        let written = self.write_at(0, &data);
        self.make_readable(written);
        written
    }
}
/// Delegating impl so a shared, locked `RingBuffer` can be used directly
/// as a buffer in tests.
impl Buffer for Arc<Mutex<RingBuffer>> {
    fn limits(&self) -> BufferLimits {
        self.lock().limits()
    }
    fn target_capacity(&self) -> usize {
        self.lock().target_capacity()
    }
    fn request_capacity(&mut self, size: usize) {
        self.lock().set_target_size(size)
    }
}
/// Delegating impl: forwards receive-buffer operations to the locked
/// `RingBuffer`.
impl ReceiveBuffer for Arc<Mutex<RingBuffer>> {
    fn write_at<P: Payload>(&mut self, offset: usize, data: &P) -> usize {
        self.lock().write_at(offset, data)
    }
    fn make_readable(&mut self, count: usize) {
        self.lock().make_readable(count)
    }
}
/// A send buffer for tests that chains a growable `Vec` (a fake
/// application byte stream) in front of a `RingBuffer`.
#[derive(Debug, Default)]
pub struct TestSendBuffer {
    // Bytes produced by the test but not yet moved into `ring`.
    fake_stream: Arc<Mutex<Vec<u8>>>,
    ring: RingBuffer,
}
impl TestSendBuffer {
    /// Creates a `TestSendBuffer` from its two parts.
    pub fn new(fake_stream: Arc<Mutex<Vec<u8>>>, ring: RingBuffer) -> TestSendBuffer {
        Self { fake_stream, ring }
    }
}
impl Buffer for TestSendBuffer {
    fn limits(&self) -> BufferLimits {
        // Combined limits: the ring buffer plus whatever is still queued in
        // the fake stream.
        let BufferLimits { capacity, len } = self.ring.limits();
        let stream = self.fake_stream.lock();
        BufferLimits { len: len + stream.len(), capacity: capacity + stream.capacity() }
    }

    fn target_capacity(&self) -> usize {
        // Only the ring participates in capacity targeting.
        self.ring.target_capacity()
    }

    fn request_capacity(&mut self, size: usize) {
        self.ring.set_target_size(size)
    }
}
impl SendBuffer for TestSendBuffer {
    fn mark_read(&mut self, count: usize) {
        let Self { fake_stream: _, ring } = self;
        ring.mark_read(count)
    }

    fn peek_with<'a, F, R>(&'a mut self, offset: usize, f: F) -> R
    where
        F: FnOnce(SendPayload<'a>) -> R,
    {
        let Self { fake_stream, ring } = self;
        let mut guard = fake_stream.lock();
        if !guard.is_empty() {
            // Move as many pending stream bytes as fit into the ring buffer
            // before peeking so the peek can observe them.
            let BufferLimits { capacity, len } = ring.limits();
            let len = (capacity - len).min(guard.len());
            let rest = guard.split_off(len);
            let first = core::mem::replace(&mut *guard, rest);
            assert_eq!(ring.enqueue_data(&first[..]), len);
        }
        ring.peek_with(offset, f)
    }
}
/// Compares the values behind two mutexes, with a pointer-equality fast
/// path that also sidesteps locking the same mutex twice.
fn arc_mutex_eq<T: PartialEq>(a: &Arc<Mutex<T>>, b: &Arc<Mutex<T>>) -> bool {
    Arc::ptr_eq(a, b) || *a.lock() == *b.lock()
}
/// The buffer halves handed to a test client: a shared receive ring and a
/// growable send byte stream.
#[derive(Clone, Debug, Default)]
pub struct ClientBuffers {
    pub receive: Arc<Mutex<RingBuffer>>,
    pub send: Arc<Mutex<Vec<u8>>>,
}
/// Equality compares the buffer contents behind the locks (with a
/// pointer-equality fast path via `arc_mutex_eq`).
impl PartialEq for ClientBuffers {
    fn eq(&self, ClientBuffers { receive: other_receive, send: other_send }: &Self) -> bool {
        let Self { receive, send } = self;
        arc_mutex_eq(receive, other_receive) && arc_mutex_eq(send, other_send)
    }
}
impl Eq for ClientBuffers {}
impl ClientBuffers {
    /// Creates fresh buffers sized according to `buffer_sizes`.
    pub fn new(buffer_sizes: BufferSizes) -> Self {
        let BufferSizes { send, receive } = buffer_sizes;
        Self {
            receive: Arc::new(Mutex::new(RingBuffer::new(receive))),
            send: Arc::new(Mutex::new(Vec::with_capacity(send))),
        }
    }
}
/// Buffers (if any) a test client registered to be written back when a
/// connection is established.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum ProvidedBuffers {
    Buffers(WriteBackClientBuffers),
    NoBuffers,
}
impl Default for ProvidedBuffers {
    fn default() -> Self {
        // By default no write-back handle is registered.
        Self::NoBuffers
    }
}
impl From<WriteBackClientBuffers> for ProvidedBuffers {
    fn from(buffers: WriteBackClientBuffers) -> Self {
        ProvidedBuffers::Buffers(buffers)
    }
}
impl From<ProvidedBuffers> for WriteBackClientBuffers {
    fn from(extra: ProvidedBuffers) -> Self {
        // `NoBuffers` converts to an empty (default) write-back handle.
        if let ProvidedBuffers::Buffers(buffers) = extra {
            buffers
        } else {
            Default::default()
        }
    }
}
impl From<ProvidedBuffers> for () {
    fn from(_: ProvidedBuffers) -> Self {
        ()
    }
}
impl From<()> for ProvidedBuffers {
    fn from(_: ()) -> Self {
        Default::default()
    }
}
/// A shared slot that receives the created [`ClientBuffers`] once a
/// connection is established.
#[derive(Debug, Default, Clone)]
pub struct WriteBackClientBuffers(pub Arc<Mutex<Option<ClientBuffers>>>);
/// Equality compares the slot contents (pointer-equality fast path via
/// `arc_mutex_eq`).
impl PartialEq for WriteBackClientBuffers {
    fn eq(&self, Self(other): &Self) -> bool {
        let Self(this) = self;
        arc_mutex_eq(this, other)
    }
}
impl Eq for WriteBackClientBuffers {}
impl IntoBuffers<Arc<Mutex<RingBuffer>>, TestSendBuffer> for ProvidedBuffers {
    fn into_buffers(
        self,
        buffer_sizes: BufferSizes,
    ) -> (Arc<Mutex<RingBuffer>>, TestSendBuffer) {
        let buffers = ClientBuffers::new(buffer_sizes);
        // Write the freshly-created buffers back to the client handle, if
        // one was provided.
        if let ProvidedBuffers::Buffers(b) = self {
            *b.0.as_ref().lock() = Some(buffers.clone());
        }
        let ClientBuffers { receive, send } = buffers;
        (receive, TestSendBuffer::new(send, Default::default()))
    }
}
impl ListenerNotifier for ProvidedBuffers {
    /// No-op: tests don't react to incoming-connection notifications.
    fn new_incoming_connections(&mut self, _: usize) {}
}
}
#[cfg(test)]
mod test {
use assert_matches::assert_matches;
use packet::{
Buf, FragmentedBytesMut, PacketBuilder, PacketConstraints, SerializeError, SerializeTarget,
Serializer,
};
use proptest::{
proptest,
strategy::{Just, Strategy},
test_runner::Config,
};
use proptest_support::failed_seeds;
use test_case::test_case;
use super::*;
use crate::transport::tcp::seqnum::WindowSize;
// 12-byte payload used by the send-payload proptests below.
const TEST_BYTES: &'static [u8] = "Hello World!".as_bytes();
proptest! {
    #![proptest_config(Config {
        failure_persistence: failed_seeds!(
            "cc f621ca7d3a2b108e0dc41f7169ad028f4329b79e90e73d5f68042519a9f63999",
            "cc c449aebed201b4ec4f137f3c224f20325f4cfee0b7fd596d9285176b6d811aa9"
        ),
        ..Config::default()
    })]

    // After arbitrary insertions `outstanding` must stay strictly sorted
    // and disjoint, spanning exactly [min_seq, max_seq].
    #[test]
    fn assembler_insertion(insertions in proptest::collection::vec(assembler::insertions(), 200)) {
        let mut assembler = Assembler::new(SeqNum::new(0));
        let mut num_insertions_performed = 0;
        let mut min_seq = SeqNum::new(WindowSize::MAX.into());
        let mut max_seq = SeqNum::new(0);
        for Range { start, end } in insertions {
            if min_seq.after(start) {
                min_seq = start;
            }
            if max_seq.before(end) {
                max_seq = end;
            }
            // Each insertion can add at most one new tracked range.
            assert!(assembler.outstanding.len() <= num_insertions_performed);
            assembler.insert_inner(start..end);
            num_insertions_performed += 1;
            for i in 1..assembler.outstanding.len() {
                assert!(assembler.outstanding[i-1].end.before(assembler.outstanding[i].start));
            }
        }
        assert_eq!(assembler.outstanding.first().unwrap().start, min_seq);
        assert_eq!(assembler.outstanding.last().unwrap().end, max_seq);
    }

    // `make_readable` only grows `len`; storage, head and shrink state are
    // untouched.
    #[test]
    fn ring_buffer_make_readable((mut rb, avail) in ring_buffer::with_written()) {
        let old_storage = rb.storage.clone();
        let old_head = rb.head;
        let old_len = rb.limits().len;
        let old_shrink = rb.shrink;
        rb.make_readable(avail);
        let RingBuffer { storage, head, len, shrink } = rb;
        assert_eq!(len, old_len + avail);
        assert_eq!(head, old_head);
        assert_eq!(storage, old_storage);
        assert_eq!(shrink, old_shrink);
    }

    // `write_at` writes exactly `data` at the expected (wrapped) positions
    // without changing head or readable length.
    #[test]
    fn ring_buffer_write_at((mut rb, offset, data) in ring_buffer::with_offset_data()) {
        let old_head = rb.head;
        let old_len = rb.limits().len;
        assert_eq!(rb.write_at(offset, &&data[..]), data.len());
        assert_eq!(rb.head, old_head);
        assert_eq!(rb.limits().len, old_len);
        for i in 0..data.len() {
            let masked = (rb.head + rb.len + offset + i) % rb.storage.len();
            assert_eq!(rb.storage[masked], data[i]);
            // Zero the byte back out so the final all-zero check below
            // verifies nothing else was touched.
            rb.storage[masked] = 0;
        }
        assert_eq!(rb.storage, vec![0; rb.storage.len()])
    }

    // `read_with` exposes the full readable contents and consumes exactly
    // what the closure reports.
    #[test]
    fn ring_buffer_read_with((mut rb, expected, consume) in ring_buffer::with_read_data()) {
        assert_eq!(rb.limits().len, expected.len());
        let nread = rb.read_with(|readable| {
            assert!(readable.len() == 1 || readable.len() == 2);
            let got = readable.concat();
            assert_eq!(got, expected);
            consume
        });
        assert_eq!(nread, consume);
        assert_eq!(rb.limits().len, expected.len() - consume);
    }

    // `mark_read` frees space at the front: previously-writable bytes stay
    // writable and the window only grows.
    #[test]
    fn ring_buffer_mark_read((mut rb, readable) in ring_buffer::with_readable()) {
        const BYTE_TO_WRITE: u8 = 0x42;
        let written = rb.writable_regions().into_iter().fold(0, |acc, slice| {
            slice.fill(BYTE_TO_WRITE);
            acc + slice.len()
        });
        let old_storage = rb.storage.clone();
        let old_head = rb.head;
        let old_len = rb.limits().len;
        let old_shrink = rb.shrink;
        rb.mark_read(readable);
        let new_writable = rb.writable_regions().into_iter().fold(Vec::new(), |mut acc, slice| {
            acc.extend_from_slice(slice);
            acc
        });
        for (i, x) in new_writable.iter().enumerate().take(written) {
            assert_eq!(*x, BYTE_TO_WRITE, "i={}, rb={:?}", i, rb);
        }
        assert!(new_writable.len() >= written);
        let RingBuffer { storage, head, len, shrink } = rb;
        assert_eq!(len, old_len - readable);
        // If the pending shrink completed, storage was reallocated and head
        // moved, so only check those when no shrink landed.
        let shrank = old_shrink.is_some() && shrink.is_none();
        if !shrank {
            assert_eq!(head, (old_head + readable) % old_storage.len());
            assert_eq!(storage, old_storage);
        }
    }

    // `peek_with` observes the bytes past `offset` without consuming them.
    #[test]
    fn ring_buffer_peek_with((mut rb, expected, offset) in ring_buffer::with_read_data()) {
        assert_eq!(rb.limits().len, expected.len());
        let () = rb.peek_with(offset, |readable| {
            assert_eq!(readable.to_vec(), &expected[offset..]);
        });
        assert_eq!(rb.limits().len, expected.len());
    }

    // The writable regions cover exactly the free space (capacity - len)
    // and never alias the readable bytes.
    #[test]
    fn ring_buffer_writable_regions(mut rb in ring_buffer::arb_ring_buffer()) {
        const BYTE_TO_WRITE: u8 = 0x42;
        let writable_len = rb.writable_regions().into_iter().fold(0, |acc, slice| {
            slice.fill(BYTE_TO_WRITE);
            acc + slice.len()
        });
        let BufferLimits { len, capacity } = rb.limits();
        assert_eq!(writable_len + len, capacity);
        for i in 0..capacity {
            let expected = if i < len {
                0
            } else {
                BYTE_TO_WRITE
            };
            let idx = (rb.head + i) % rb.storage.len();
            assert_eq!(rb.storage[idx], expected);
        }
    }

    #[test]
    fn send_payload_len((payload, _idx) in send_payload::with_index()) {
        assert_eq!(payload.len(), TEST_BYTES.len())
    }

    // Slicing on either side of `idx` yields the corresponding bytes.
    #[test]
    fn send_payload_slice((payload, idx) in send_payload::with_index()) {
        let idx_u32 = u32::try_from(idx).unwrap();
        let end = u32::try_from(TEST_BYTES.len()).unwrap();
        assert_eq!(payload.clone().slice(0..idx_u32).to_vec(), &TEST_BYTES[..idx]);
        assert_eq!(payload.clone().slice(idx_u32..end).to_vec(), &TEST_BYTES[idx..]);
    }

    #[test]
    fn send_payload_partial_copy((payload, offset, len) in send_payload::with_offset_and_length()) {
        let mut buffer = [0; TEST_BYTES.len()];
        payload.partial_copy(offset, &mut buffer[0..len]);
        assert_eq!(&buffer[0..len], &TEST_BYTES[offset..offset + len]);
    }

    // Resizing preserves the readable bytes and keeps previously-writable
    // bytes at the front of the writable window.
    #[test]
    fn set_target_size((mut rb, new_cap) in ring_buffer::with_new_target_size()) {
        const BYTE_TO_WRITE: u8 = 0x42;
        let written = rb.writable_regions().into_iter().fold(0, |acc, slice| {
            slice.fill(BYTE_TO_WRITE);
            acc + slice.len()
        });
        let old_len = rb.limits().len;
        rb.set_target_size(new_cap);
        assert_eq!(rb.limits().len, old_len);
        let new_writable = rb.writable_regions().into_iter().fold(Vec::new(), |mut acc, slice| {
            acc.extend_from_slice(slice);
            acc
        });
        let BufferLimits { len, capacity } = rb.limits();
        assert_eq!(new_writable.len() + len, capacity);
        assert!(new_writable.len() >= written);
        for (i, x) in new_writable.iter().enumerate() {
            let expected = (i < written).then_some(BYTE_TO_WRITE).unwrap_or(0);
            assert_eq!(*x, expected, "i={}, rb={:?}", i, rb);
        }
    }
}
/// A test `PacketBuilder` that frames the body with constant fill bytes.
#[derive(Debug)]
struct OuterBuilder(PacketConstraints);
impl OuterBuilder {
    /// Byte used to fill the header region.
    const HEADER_BYTE: u8 = b'H';
    /// Byte used to fill the footer region. Deliberately distinct from
    /// `HEADER_BYTE` (previously both were `b'H'`, which made the
    /// serializer tests unable to detect a header/footer mix-up).
    const FOOTER_BYTE: u8 = b'F';
}
impl PacketBuilder for OuterBuilder {
    fn constraints(&self) -> PacketConstraints {
        let Self(constraints) = self;
        constraints.clone()
    }
    /// Fills the header and footer with the fixed marker bytes; the body is
    /// left untouched.
    fn serialize(&self, target: &mut SerializeTarget<'_>, _body: FragmentedBytesMut<'_, '_>) {
        target.header.fill(Self::HEADER_BYTE);
        target.footer.fill(Self::FOOTER_BYTE);
    }
}
const EXAMPLE_DATA: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
// Serializing any fragmentation of EXAMPLE_DATA must produce
// header + body + footer with the builder's fill bytes.
#[test_case(SendPayload::Contiguous(&EXAMPLE_DATA); "contiguous")]
#[test_case(SendPayload::Straddle(&EXAMPLE_DATA[0..5], &EXAMPLE_DATA[5..]); "split")]
#[test_case(SendPayload::Straddle(&[], &EXAMPLE_DATA); "split empty front")]
#[test_case(SendPayload::Straddle(&EXAMPLE_DATA, &[]); "split empty back")]
fn send_payload_serializer_data(payload: SendPayload<'static>) {
    const HEADER_LEN: usize = 5;
    const FOOTER_LEN: usize = 6;
    let outer = OuterBuilder(PacketConstraints::new(HEADER_LEN, FOOTER_LEN, 0, usize::MAX));
    assert_eq!(
        payload
            .into_serializer()
            .encapsulate(outer)
            .serialize_vec_outer()
            .expect("should serialize")
            .unwrap_b(),
        Buf::new(
            [OuterBuilder::HEADER_BYTE; HEADER_LEN]
                .into_iter()
                .chain(EXAMPLE_DATA)
                .chain([OuterBuilder::FOOTER_BYTE; FOOTER_LEN])
                .collect(),
            ..
        )
    );
}
// A body larger than the builder's max body length must fail to serialize.
#[test]
fn send_payload_serializer_body_too_large() {
    let outer = OuterBuilder(PacketConstraints::new(0, 0, 0, EXAMPLE_DATA.len() - 1));
    let payload = SendPayload::Contiguous(&EXAMPLE_DATA);
    assert_matches!(
        payload.into_serializer().encapsulate(outer).serialize_vec_outer(),
        Err((SerializeError::SizeLimitExceeded, _))
    );
}
// A minimum body length above the payload size must be met by zero padding.
#[test]
fn send_payload_serializer_body_needs_padding() {
    const PADDING: usize = 3;
    let outer =
        OuterBuilder(PacketConstraints::new(0, 0, EXAMPLE_DATA.len() + PADDING, usize::MAX));
    let payload = SendPayload::Contiguous(&EXAMPLE_DATA);
    assert_eq!(
        payload
            .into_serializer()
            .encapsulate(outer)
            .serialize_vec_outer()
            .expect("can serialize")
            .unwrap_b(),
        Buf::new(EXAMPLE_DATA.into_iter().chain([0; PADDING]).collect(), ..)
    );
}
// Worked examples: each case lists insertions and the expected final
// assembler state (outstanding ranges and advanced `nxt`).
#[test_case([Range { start: 0, end: 0 }]
    => Assembler { outstanding: vec![], nxt: SeqNum::new(0) })]
#[test_case([Range { start: 0, end: 10 }]
    => Assembler { outstanding: vec![], nxt: SeqNum::new(10) })]
#[test_case([Range{ start: 10, end: 15 }, Range { start: 5, end: 10 }]
    => Assembler { outstanding: vec![Range { start: SeqNum::new(5), end: SeqNum::new(15) }], nxt: SeqNum::new(0)})]
#[test_case([Range{ start: 10, end: 15 }, Range { start: 0, end: 5 }, Range { start: 5, end: 10 }]
    => Assembler { outstanding: vec![], nxt: SeqNum::new(15) })]
#[test_case([Range{ start: 10, end: 15 }, Range { start: 5, end: 10 }, Range { start: 0, end: 5 }]
    => Assembler { outstanding: vec![], nxt: SeqNum::new(15) })]
fn assembler_examples(ops: impl IntoIterator<Item = Range<u32>>) -> Assembler {
    let mut assembler = Assembler::new(SeqNum::new(0));
    for Range { start, end } in ops.into_iter() {
        let _advanced = assembler.insert(SeqNum::new(start)..SeqNum::new(end));
    }
    assembler
}
// Exercises writing/reading across the ring boundary: the second enqueue
// wraps, producing a straddled payload; the third is contiguous again.
#[test]
fn ring_buffer_wrap_around() {
    const CAPACITY: usize = 16;
    let mut rb = RingBuffer::new(CAPACITY);
    const BUF_SIZE: usize = 10;
    assert_eq!(rb.enqueue_data(&[0xAA; BUF_SIZE]), BUF_SIZE);
    rb.peek_with(0, |payload| assert_eq!(payload, SendPayload::Contiguous(&[0xAA; BUF_SIZE])));
    rb.mark_read(BUF_SIZE);
    assert_eq!(rb.enqueue_data(&[0xBB; BUF_SIZE]), BUF_SIZE);
    rb.peek_with(0, |payload| {
        assert_eq!(
            payload,
            SendPayload::Straddle(
                &[0xBB; (CAPACITY - BUF_SIZE)],
                &[0xBB; (BUF_SIZE * 2 - CAPACITY)]
            )
        )
    });
    rb.mark_read(BUF_SIZE);
    assert_eq!(rb.enqueue_data(&[0xCC; BUF_SIZE]), BUF_SIZE);
    rb.peek_with(0, |payload| assert_eq!(payload, SendPayload::Contiguous(&[0xCC; BUF_SIZE])));
    let read = rb.read_with(|segments| {
        assert_eq!(segments, [[0xCC; BUF_SIZE]]);
        BUF_SIZE
    });
    assert_eq!(read, BUF_SIZE);
}
// End-to-end walkthrough: out-of-order writes, partial reads, a wrapped
// read, and peeks at offsets.
#[test]
fn ring_buffer_example() {
    let mut rb = RingBuffer::new(16);
    // Out-of-order write, then the in-order prefix.
    assert_eq!(rb.write_at(5, &"World".as_bytes()), 5);
    assert_eq!(rb.write_at(0, &"Hello".as_bytes()), 5);
    rb.make_readable(10);
    assert_eq!(
        rb.read_with(|readable| {
            assert_eq!(readable, &["HelloWorld".as_bytes()]);
            5
        }),
        5
    );
    assert_eq!(
        rb.read_with(|readable| {
            assert_eq!(readable, &["World".as_bytes()]);
            readable[0].len()
        }),
        5
    );
    assert_eq!(rb.write_at(0, &"HelloWorld".as_bytes()), 10);
    rb.make_readable(10);
    // This write wrapped, so reading yields two fragments.
    assert_eq!(
        rb.read_with(|readable| {
            assert_eq!(readable, &["HelloW".as_bytes(), "orld".as_bytes()]);
            6
        }),
        6
    );
    assert_eq!(rb.limits().len, 4);
    assert_eq!(
        rb.read_with(|readable| {
            assert_eq!(readable, &["orld".as_bytes()]);
            4
        }),
        4
    );
    assert_eq!(rb.limits().len, 0);
    assert_eq!(rb.enqueue_data("Hello".as_bytes()), 5);
    assert_eq!(rb.limits().len, 5);
    // Peeking does not consume.
    let () = rb.peek_with(3, |readable| {
        assert_eq!(readable.to_vec(), "lo".as_bytes());
    });
    rb.mark_read(2);
    let () = rb.peek_with(0, |readable| {
        assert_eq!(readable.to_vec(), "llo".as_bytes());
    });
}
mod assembler {
    use super::*;
    /// Strategy producing non-empty `start..end` ranges within the window
    /// (`start < end <= WindowSize::MAX`).
    pub(super) fn insertions() -> impl Strategy<Value = Range<SeqNum>> {
        (0..u32::from(WindowSize::MAX)).prop_flat_map(|start| {
            (start + 1..=u32::from(WindowSize::MAX)).prop_flat_map(move |end| {
                Just(Range { start: SeqNum::new(start), end: SeqNum::new(end) })
            })
        })
    }
}
mod ring_buffer {
    use super::*;
    const MAX_CAP: usize = 32;
    /// Strategy for `(capacity, head, len, shrink)` tuples that satisfy the
    /// ring buffer's invariants (`head < cap`,
    /// `len <= cap - shrink.current`).
    fn arb_ring_buffer_args(
    ) -> impl Strategy<Value = (usize, usize, usize, Option<PendingShrink>)> {
        fn arb_shrink_args(cap: usize) -> impl Strategy<Value = Option<PendingShrink>> {
            // target == 0 means no pending shrink; otherwise pick some
            // progress `current <= target`.
            (0..=cap).prop_flat_map(|target| match NonZeroUsize::new(target) {
                Some(target) => (Just(target), (0..=target.get()))
                    .prop_map(|(target, current)| Some(PendingShrink { target, current }))
                    .boxed(),
                None => Just(None).boxed(),
            })
        }
        (1..=MAX_CAP).prop_flat_map(|cap| {
            arb_shrink_args(cap).prop_flat_map(move |shrink| {
                let max_len = cap - shrink.as_ref().map_or(0, |r| r.current);
                (Just(cap), 0..cap, 0..=max_len, Just(shrink))
            })
        })
    }
    /// An arbitrary valid `RingBuffer` (zeroed storage).
    pub(super) fn arb_ring_buffer() -> impl Strategy<Value = RingBuffer> {
        arb_ring_buffer_args().prop_map(|(cap, head, len, shrink)| RingBuffer {
            storage: vec![0; cap],
            head,
            len,
            shrink,
        })
    }
    /// A buffer plus a readable count to mark as read.
    pub(super) fn with_readable() -> impl Strategy<Value = (RingBuffer, usize)> {
        arb_ring_buffer_args().prop_flat_map(|(cap, head, len, shrink)| {
            (Just(RingBuffer { storage: vec![0; cap], head, len, shrink }), 0..=len)
        })
    }
    /// A buffer plus a count of written-but-not-yet-readable bytes.
    pub(super) fn with_written() -> impl Strategy<Value = (RingBuffer, usize)> {
        arb_ring_buffer_args().prop_flat_map(|(cap, head, len, shrink)| {
            let rb = RingBuffer { storage: vec![0; cap], head, len, shrink };
            let max_written = cap - len - shrink.map_or(0, |r| r.current);
            (Just(rb), 0..=max_written)
        })
    }
    /// A buffer plus an in-window `(offset, data)` pair for `write_at`.
    pub(super) fn with_offset_data() -> impl Strategy<Value = (RingBuffer, usize, Vec<u8>)> {
        arb_ring_buffer_args().prop_flat_map(|(cap, head, len, shrink)| {
            let writable_len = cap - len - shrink.map_or(0, |r| r.current);
            (0..=writable_len).prop_flat_map(move |offset| {
                (0..=writable_len - offset).prop_flat_map(move |data_len| {
                    (
                        Just(RingBuffer { storage: vec![0; cap], head, len, shrink }),
                        Just(offset),
                        // Non-zero bytes so tests can tell data from the
                        // zeroed storage.
                        proptest::collection::vec(1..=u8::MAX, data_len),
                    )
                })
            })
        })
    }
    /// A buffer pre-filled with known data, plus a read/peek amount.
    pub(super) fn with_read_data() -> impl Strategy<Value = (RingBuffer, Vec<u8>, usize)> {
        arb_ring_buffer_args().prop_flat_map(|(cap, head, len, shrink)| {
            proptest::collection::vec(1..=u8::MAX, len).prop_flat_map(move |data| {
                let mut rb = RingBuffer { storage: vec![0; cap], head, len: 0, shrink };
                assert_eq!(rb.write_at(0, &&data[..]), len);
                rb.make_readable(len);
                (Just(rb), Just(data), 0..=len)
            })
        })
    }
    /// A buffer plus a new target size (possibly larger than MAX_CAP).
    pub(super) fn with_new_target_size() -> impl Strategy<Value = (RingBuffer, usize)> {
        arb_ring_buffer_args().prop_flat_map(|(cap, head, len, shrink)| {
            (0..MAX_CAP * 2).prop_map(move |target_size| {
                let rb = RingBuffer { storage: vec![0; cap], head, len, shrink };
                (rb, target_size)
            })
        })
    }
}
mod send_payload {
    use super::*;
    use alloc::borrow::ToOwned as _;
    /// A payload over TEST_BYTES (contiguous or split at an arbitrary
    /// point) together with an index into it.
    pub(super) fn with_index() -> impl Strategy<Value = (SendPayload<'static>, usize)> {
        proptest::prop_oneof![
            (Just(SendPayload::Contiguous(TEST_BYTES)), 0..TEST_BYTES.len()),
            (0..TEST_BYTES.len()).prop_flat_map(|split_at| {
                (
                    Just(SendPayload::Straddle(
                        &TEST_BYTES[..split_at],
                        &TEST_BYTES[split_at..],
                    )),
                    0..TEST_BYTES.len(),
                )
            })
        ]
    }
    /// Like [`with_index`], plus a length that keeps `offset + len` in
    /// bounds.
    pub(super) fn with_offset_and_length(
    ) -> impl Strategy<Value = (SendPayload<'static>, usize, usize)> {
        with_index().prop_flat_map(|(payload, index)| {
            (Just(payload), Just(index), 0..=TEST_BYTES.len() - index)
        })
    }
    impl SendPayload<'_> {
        /// Flattens the payload into a single owned `Vec<u8>`.
        pub(super) fn to_vec(self) -> Vec<u8> {
            match self {
                SendPayload::Contiguous(p) => p.to_owned(),
                SendPayload::Straddle(p1, p2) => [p1, p2].concat(),
            }
        }
    }
}
}