1use std::{
2 io::{Error, ErrorKind, Read, Result, Seek, SeekFrom, Write},
3 vec::Vec,
4};
5
/// A write staged during a transaction: `data` is to be written to the inner
/// stream starting at byte `offset` when the transaction commits.
struct Op {
    /// Absolute position in the inner stream where `data` begins.
    offset: u64,
    /// The bytes staged for that position.
    data: Vec<u8>,
}

impl Op {
    /// Exclusive end position of the staged range, i.e. `offset + data.len()`.
    fn end(&self) -> u64 {
        let staged_len = self.data.len() as u64;
        staged_len + self.offset
    }
}
16
17pub struct TransactionManager<T> {
21 inner: T,
22 ops: Vec<Op>,
23 active: bool,
24}
25
26impl<T: Seek + Write> TransactionManager<T> {
27 pub fn new(inner: T) -> Self {
28 TransactionManager { inner, ops: Vec::new(), active: false }
29 }
30
31 #[allow(dead_code)]
32 pub fn into_inner(self) -> T {
33 self.inner
34 }
35
36 #[must_use]
37 pub fn begin_transaction(&mut self) -> bool {
38 if self.active {
39 false
40 } else {
41 self.active = true;
42 true
43 }
44 }
45
46 pub fn commit(&mut self) -> Result<()> {
47 assert!(self.active);
48 self.active = false;
49 for op in self.ops.drain(..) {
50 self.inner.seek(SeekFrom::Start(op.offset))?;
51 let mut buf = op.data.as_slice();
52 while !buf.is_empty() {
53 let done = self.inner.write(buf)?;
54 if done == 0 {
55 return Err(Error::new(ErrorKind::WriteZero, "Inner write failed"));
56 }
57 buf = &buf[done..];
58 }
59 }
60 Ok(())
61 }
62
63 pub fn revert(&mut self) {
64 assert!(self.active);
65 self.active = false;
66 self.ops.clear();
67 }
68
69 pub fn borrow_inner(&self) -> &T {
70 &self.inner
71 }
72}
73
74impl<T: Read + Seek> Read for TransactionManager<T> {
75 fn read(&mut self, mut buf: &mut [u8]) -> Result<usize> {
76 if !self.active {
77 return self.inner.read(buf);
78 }
79
80 let mut offset = self.inner.seek(SeekFrom::Current(0))?;
81 let mut i = 0;
82 while i < self.ops.len() && self.ops[i].end() <= offset {
83 i += 1;
84 }
85 let mut done = 0;
86 while !buf.is_empty() {
87 let mut to_do;
88 let next_offset = if i < self.ops.len() { self.ops[i].offset } else { u64::MAX };
89 if next_offset <= offset {
90 to_do = std::cmp::min(self.ops[i].end() - offset, buf.len() as u64) as usize;
92 let data_offset = (offset - next_offset) as usize;
93 buf[..to_do].copy_from_slice(&self.ops[i].data[data_offset..data_offset + to_do]);
94 i += 1;
95 self.inner.seek(SeekFrom::Current(to_do as i64))?;
96 } else {
97 to_do = std::cmp::min(next_offset - offset, buf.len() as u64) as usize;
99 to_do = self.inner.read(&mut buf[..to_do])?;
100 if to_do == 0 {
101 return Ok(done);
102 }
103 }
104 buf = &mut buf[to_do..];
105 offset += to_do as u64;
106 done += to_do;
107 }
108 return Ok(done);
109 }
110}
111
112impl<T: Seek + Write> Write for TransactionManager<T> {
113 fn write(&mut self, buf: &[u8]) -> Result<usize> {
114 if !self.active {
115 return self.inner.write(buf);
116 }
117
118 let offset = self.inner.seek(SeekFrom::Current(0))?;
119 let mut i = 0;
120 while i < self.ops.len() && self.ops[i].end() < offset {
121 i += 1;
122 }
123 if i >= self.ops.len() {
124 self.ops.push(Op { offset, data: buf.to_vec() });
126 } else if self.ops[i].end() == offset {
127 self.ops[i].data.extend_from_slice(buf);
129 } else if self.ops[i].offset < offset {
130 let data_offset = (offset - self.ops[i].offset) as usize;
132 if self.ops[i].end() >= offset + buf.len() as u64 {
133 self.ops[i].data[data_offset..data_offset + buf.len()].copy_from_slice(buf);
135 } else {
136 let to_do = self.ops[i].data.len() - data_offset;
138 self.ops[i].data[data_offset..].copy_from_slice(&buf[..to_do]);
139 self.ops[i].data.extend_from_slice(&buf[to_do..]);
140 }
141 } else {
142 self.ops.insert(i, Op { offset: offset, data: buf.to_vec() });
144 }
145 let end = self.ops[i].end();
147 i += 1;
148 while i < self.ops.len() && self.ops[i].offset < end {
149 if self.ops[i].end() <= end {
150 self.ops.remove(i);
151 } else {
152 let to_delete = end - self.ops[i].offset;
154 self.ops[i].offset += to_delete;
155 self.ops[i].data.drain(..to_delete as usize);
156 break;
157 }
158 }
159 self.inner.seek(SeekFrom::Current(buf.len() as i64))?;
160 Ok(buf.len())
161 }
162
163 fn flush(&mut self) -> Result<()> {
164 self.inner.flush()
165 }
166}
167
168impl<T: std::io::Seek> Seek for TransactionManager<T> {
169 fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
170 self.inner.seek(pos)
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::TransactionManager;
    use std::io::{Cursor, Read, Seek, SeekFrom, Write};

    /// The concrete manager every test drives: one over an in-memory buffer.
    type Manager = TransactionManager<Cursor<Vec<u8>>>;

    /// Creates a manager over `len` bytes, each initialised to `fill`.
    fn manager_over(fill: u8, len: usize) -> Manager {
        TransactionManager::new(Cursor::new(vec![fill; len]))
    }

    /// Seeks and checks the position the manager reports back.
    fn seek_expect(manager: &mut Manager, to: SeekFrom, expected: u64) {
        assert_eq!(manager.seek(to).expect("seek failed"), expected);
    }

    /// Writes `data` and checks that the whole slice was accepted.
    fn write_expect(manager: &mut Manager, data: &[u8]) {
        assert_eq!(manager.write(data).expect("write failed"), data.len());
    }

    /// Performs a single `read` of `len` bytes, checks that the buffer was
    /// filled, and returns the bytes.
    fn read_expect(manager: &mut Manager, len: usize) -> Vec<u8> {
        let mut out = vec![0; len];
        assert_eq!(manager.read(&mut out).expect("read failed"), len);
        out
    }

    #[test]
    fn test_read_fall_through_when_no_transaction() {
        let mut manager = manager_over(55, 100);
        assert_eq!(read_expect(&mut manager, 3), [55, 55, 55]);
    }

    #[test]
    fn test_write_fall_through_when_no_transaction() {
        let mut manager = manager_over(0, 4);
        write_expect(&mut manager, &[1, 2, 3]);
        assert_eq!(manager.into_inner().into_inner(), [1, 2, 3, 0]);
    }

    #[test]
    fn test_read_part_transaction_part_inner() {
        let mut manager = manager_over(55, 100);
        assert!(manager.begin_transaction());
        write_expect(&mut manager, &[1, 2, 3]);
        seek_expect(&mut manager, SeekFrom::Start(10), 10);
        write_expect(&mut manager, &[4, 5, 6]);
        seek_expect(&mut manager, SeekFrom::Start(7), 7);
        assert_eq!(read_expect(&mut manager, 7), [55, 55, 55, 4, 5, 6, 55]);
        seek_expect(&mut manager, SeekFrom::Current(0), 14);
    }

    #[test]
    fn test_write_extend_entry() {
        let mut manager = manager_over(55, 100);
        assert!(manager.begin_transaction());
        write_expect(&mut manager, &[1, 2, 3]);
        seek_expect(&mut manager, SeekFrom::Start(10), 10);
        write_expect(&mut manager, &[4, 5, 6]);
        write_expect(&mut manager, &[7, 8, 9]);
        seek_expect(&mut manager, SeekFrom::Start(11), 11);
        assert_eq!(read_expect(&mut manager, 4), [5, 6, 7, 8]);
    }

    #[test]
    fn test_write_existing_entry_encompasses_write() {
        let mut manager = manager_over(55, 100);
        assert!(manager.begin_transaction());
        write_expect(&mut manager, &[1, 2, 3]);
        seek_expect(&mut manager, SeekFrom::Start(10), 10);
        write_expect(&mut manager, &[4, 5, 6, 7, 8, 9]);
        seek_expect(&mut manager, SeekFrom::Start(12), 12);
        write_expect(&mut manager, &[99, 100]);
        seek_expect(&mut manager, SeekFrom::Start(11), 11);
        assert_eq!(read_expect(&mut manager, 4), [5, 99, 100, 8]);
    }

    #[test]
    fn test_write_partial_overlap_and_extension() {
        let mut manager = manager_over(55, 100);
        assert!(manager.begin_transaction());
        write_expect(&mut manager, &[1, 2, 3]);
        seek_expect(&mut manager, SeekFrom::Start(10), 10);
        write_expect(&mut manager, &[4, 5, 6, 7, 8, 9]);
        seek_expect(&mut manager, SeekFrom::Start(14), 14);
        write_expect(&mut manager, &[99, 100, 101, 102]);
        seek_expect(&mut manager, SeekFrom::Start(13), 13);
        assert_eq!(read_expect(&mut manager, 6), [7, 99, 100, 101, 102, 55]);
    }

    #[test]
    fn test_write_no_overlap() {
        let mut manager = manager_over(55, 100);
        assert!(manager.begin_transaction());
        write_expect(&mut manager, &[1, 2, 3]);
        seek_expect(&mut manager, SeekFrom::Start(8), 8);
        write_expect(&mut manager, &[4, 5, 6]);
        seek_expect(&mut manager, SeekFrom::Start(4), 4);
        write_expect(&mut manager, &[7, 8, 9]);
        seek_expect(&mut manager, SeekFrom::Start(2), 2);
        assert_eq!(read_expect(&mut manager, 10), [3, 55, 7, 8, 9, 55, 4, 5, 6, 55]);
    }

    #[test]
    fn test_trim_following_entries() {
        let mut manager = manager_over(55, 100);
        assert!(manager.begin_transaction());
        seek_expect(&mut manager, SeekFrom::Start(2), 2);
        write_expect(&mut manager, &[1, 2, 3]);
        seek_expect(&mut manager, SeekFrom::Current(1), 6);
        write_expect(&mut manager, &[4, 5, 6]);
        seek_expect(&mut manager, SeekFrom::Start(1), 1);
        write_expect(&mut manager, &[100, 101, 102, 103, 104, 105]);
        seek_expect(&mut manager, SeekFrom::Start(0), 0);
        assert_eq!(
            read_expect(&mut manager, 10),
            [55, 100, 101, 102, 103, 104, 105, 5, 6, 55]
        );
    }

    #[test]
    fn test_revert_transaction() {
        let mut manager: Manager = TransactionManager::new(Cursor::new(vec![1, 2, 3, 4, 5, 6]));
        assert!(manager.begin_transaction());
        write_expect(&mut manager, &[7, 8, 9]);
        manager.revert();
        seek_expect(&mut manager, SeekFrom::Start(1), 1);
        assert_eq!(read_expect(&mut manager, 4), [2, 3, 4, 5]);
    }

    #[test]
    fn test_commit_transaction() {
        let mut manager = manager_over(55, 10);
        assert!(manager.begin_transaction());
        write_expect(&mut manager, &[1, 2]);
        seek_expect(&mut manager, SeekFrom::Current(1), 3);
        write_expect(&mut manager, &[3, 4]);
        seek_expect(&mut manager, SeekFrom::Current(1), 6);
        write_expect(&mut manager, &[5, 6]);
        manager.commit().expect("commit failed");
        assert_eq!(
            manager.into_inner().into_inner(),
            [1, 2, 55, 3, 4, 55, 5, 6, 55, 55]
        );
    }
}