internet_checksum/lib.rs
1// Copyright 2019 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! RFC 1071 "internet checksum" computation.
6//!
7//! This crate implements the "internet checksum" defined in [RFC 1071] and
8//! updated in [RFC 1141] and [RFC 1624], which is used by many different
9//! protocols' packet formats. The checksum operates by computing the 1s
10//! complement of the 1s complement sum of successive 16-bit words of the input.
11//!
12//! [RFC 1071]: https://tools.ietf.org/html/rfc1071
13//! [RFC 1141]: https://tools.ietf.org/html/rfc1141
14//! [RFC 1624]: https://tools.ietf.org/html/rfc1624
15
16// Optimizations applied:
17//
18// 0. Byteorder independence: as described in RFC 1071 section 2.(B)
19// The sum of 16-bit integers can be computed in either byte order,
20// so this actually saves us from the unnecessary byte swapping on
21// an LE machine. As perfed on a gLinux workstation, that swapping
22// can account for ~20% of the runtime.
23//
24// 1. Widen the accumulator: doing so enables us to process a bigger
// chunk of data at a time, achieving some kind of poor man's
26// SIMD. Currently a u128 counter is used on x86-64 and a u64 is
27// used conservatively on other architectures.
28//
29// 2. Process more at a time: the old implementation uses a u32 accumulator
30// but it only adds one u16 each time to implement deferred carry. In
// the current implementation we are processing a u128 at a time
32// on x86-64, which is 8 u16's. On other platforms, we are processing
33// a u64 at a time, which is 4 u16's.
34//
35// 3. Induce the compiler to produce `adc` instruction: this is a very
36// useful instruction to implement 1's complement addition and available
37// on both x86 and ARM. The functions `adc_uXX` are for this use.
38//
39// 4. Eliminate branching as much as possible: the old implementation has
40// if statements for detecting overflow of the u32 accumulator which
41// is not needed when we can access the carry flag with `adc`. The old
42// `normalize` function used to have a while loop to fold the u32,
// however, we can unroll that loop because we know ahead of time how
// many additions we need.
45//
46// 5. In the loop of `add_bytes`, the `adc_u64` is not used, instead,
47// the `overflowing_add` is directly used. `adc_u64`'s carry flag
48// comes from the current number being added while the slightly
49// convoluted version in `add_bytes`, adding each number depends on
50// the carry flag of the previous computation. I checked under release
51// mode this issues 3 instructions instead of 4 for x86 and it should
52// theoretically be beneficial, however, measurement showed me that it
53// helps only a little. So this trick is not used for `update`.
54//
55// Results:
56//
57// Micro-benchmarks are run on an x86-64 gLinux workstation. In summary,
// compared to baseline 0, which is prior to the byteorder independence
59// patch, there is a ~4x speedup.
60//
61// TODO: run this optimization on other platforms. I would expect
62// the situation on ARM a bit different because I am not sure
63// how much penalty there will be for misaligned read on ARM, or
64// whether it is even supported (On x86 there is generally no
65// penalty for misaligned read). If there will be penalties, we
66// should consider alignment as an optimization opportunity on ARM.
67
68// TODO(joshlf): Right-justify the columns above
69
70// TODO(joshlf):
71// - Investigate optimizations proposed in RFC 1071 Section 2. The most
72// promising on modern hardware is probably (C) Parallel Summation, although
73// that needs to be balanced against (1) Deferred Carries. Benchmarks will
74// need to be performed to determine which is faster in practice, and under
75// what scenarios.
76
77/// Compute the checksum of "bytes".
78///
79/// `checksum(bytes)` is shorthand for:
80///
81/// ```rust
82/// # use internet_checksum::Checksum;
83/// # let bytes = &[];
84/// # let _ = {
85/// let mut c = Checksum::new();
86/// c.add_bytes(bytes);
87/// c.checksum()
88/// # };
89/// ```
90#[inline]
91pub fn checksum(bytes: &[u8]) -> [u8; 2] {
92 let mut c = Checksum::new();
93 c.add_bytes(bytes);
94 c.checksum()
95}
96
// The accumulator is the unit of work in `add_bytes`: a wider accumulator lets
// each addition consume more input bytes (see optimization notes 1 and 2
// above). A u128 is used on 64-bit x86/ARM; other architectures conservatively
// get a u64.
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
type Accumulator = u128;
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
type Accumulator = u64;
101
102/// Updates bytes in an existing checksum.
103///
104/// `update` updates a checksum to reflect that the already-checksummed bytes
105/// `old` have been updated to contain the values in `new`. It implements the
106/// algorithm described in Equation 3 in [RFC 1624]. The first byte must be at
107/// an even number offset in the original input. If an odd number offset byte
108/// needs to be updated, the caller should simply include the preceding byte as
109/// well. If an odd number of bytes is given, it is assumed that these are the
110/// last bytes of the input. If an odd number of bytes in the middle of the
111/// input needs to be updated, the preceding or following byte of the input
112/// should be added to make an even number of bytes.
113///
114/// # Panics
115///
116/// `update` panics if `old.len() != new.len()`.
117///
118/// [RFC 1624]: https://tools.ietf.org/html/rfc1624
119#[inline]
120pub fn update(checksum: [u8; 2], old: &[u8], new: &[u8]) -> [u8; 2] {
121 assert_eq!(old.len(), new.len());
122 // We compute on the sum, not the one's complement of the sum. checksum
123 // is the one's complement of the sum, so we need to get back to the
124 // sum. Thus, we negate checksum.
125 // HC' = ~HC
126 let mut sum = !u16::from_ne_bytes(checksum) as Accumulator;
127
128 // Let's reuse `Checksum::add_bytes` to update our checksum
129 // so that we can get the speedup for free. Using
130 // [RFC 1071 Eqn. 3], we can efficiently update our new checksum.
131 let mut c1 = Checksum::new();
132 let mut c2 = Checksum::new();
133 c1.add_bytes(old);
134 c2.add_bytes(new);
135
136 // Note, `c1.checksum_inner()` is actually ~m in [Eqn. 3]
137 // `c2.checksum_inner()` is actually ~m' in [Eqn. 3]
138 // so we have to negate `c2.checksum_inner()` first to get m'.
139 // HC' += ~m, c1.checksum_inner() == ~m.
140 sum = adc_accumulator(sum, c1.checksum_inner() as Accumulator);
141 // HC' += m', c2.checksum_inner() == ~m'.
142 sum = adc_accumulator(sum, !c2.checksum_inner() as Accumulator);
143 // HC' = ~HC.
144 (!normalize(sum)).to_ne_bytes()
145}
146
/// RFC 1071 "internet checksum" computation.
///
/// `Checksum` implements the "internet checksum" defined in [RFC 1071] and
/// updated in [RFC 1141] and [RFC 1624], which is used by many different
/// protocols' packet formats. The checksum operates by computing the 1s
/// complement of the 1s complement sum of successive 16-bit words of the input.
///
/// [RFC 1071]: https://tools.ietf.org/html/rfc1071
/// [RFC 1141]: https://tools.ietf.org/html/rfc1141
/// [RFC 1624]: https://tools.ietf.org/html/rfc1624
#[derive(Default)]
pub struct Checksum {
    // Running sum with deferred carries; only folded down to 16 bits (and
    // complemented) when a checksum is actually requested.
    sum: Accumulator,
    // Since odd-length inputs are treated specially, we store the trailing byte
    // for use in future calls to add_bytes(), and only treat it as a true
    // trailing byte in checksum().
    trailing_byte: Option<u8>,
}
165
166impl Checksum {
167 /// Initialize a new checksum.
168 #[inline]
169 pub const fn new() -> Self {
170 Checksum { sum: 0, trailing_byte: None }
171 }
172
173 /// Add bytes to the checksum.
174 ///
175 /// If `bytes` does not contain an even number of bytes, a single zero byte
176 /// will be added to the end before updating the checksum.
177 ///
178 /// Note that `add_bytes` has some fixed overhead regardless of the size of
179 /// `bytes`. Where performance is a concern, prefer fewer calls to
180 /// `add_bytes` with larger input over more calls with smaller input.
181 #[inline]
182 pub fn add_bytes(&mut self, mut bytes: &[u8]) {
183 if bytes.is_empty() {
184 return;
185 }
186
187 let mut sum = self.sum;
188 let mut carry = false;
189
190 // We are not using `adc_uXX` functions here, instead, we manually track
191 // the carry flag. This is because in `adc_uXX` functions, the carry
192 // flag depends on addition itself. So the assembly for that function
193 // reads as follows:
194 //
195 // mov %rdi, %rcx
196 // mov %rsi, %rax
197 // add %rcx, %rsi -- waste! only used to generate CF.
198 // adc %rdi, $rax -- the real useful instruction.
199 //
200 // So we had better to make us depend on the CF generated by the
201 // addition of the previous 16-bit word. The ideal assembly should look
202 // like:
203 //
204 // add 0(%rdi), %rax
205 // adc 8(%rdi), %rax
206 // adc 16(%rdi), %rax
207 // .... and so on ...
208 //
209 // Sadly, there are too many instructions that can affect the carry
210 // flag, and LLVM is not that optimized to find out the pattern and let
211 // all these adc instructions not interleaved. However, doing so results
212 // in 3 instructions instead of the original 4 instructions (the two
213 // mov's are still there) and it makes a difference on input size like
214 // 1023.
215 macro_rules! update_sum_carry {
216 ($ty: ident, $chunk: expr) => {
217 let (s, c) = sum.overflowing_add($ty::from_ne_bytes($chunk) as Accumulator);
218 sum = s.wrapping_add(carry as Accumulator);
219 carry = c;
220 };
221 }
222
223 const ACCUMULATOR_BYTES: usize = (Accumulator::BITS / 8) as usize;
224 while let Some(chunk) = bytes.first_chunk::<ACCUMULATOR_BYTES>() {
225 update_sum_carry!(Accumulator, *chunk);
226 bytes = &bytes[ACCUMULATOR_BYTES..];
227 }
228
229 // Handle the tail.
230 if let Some(chunk) = bytes.first_chunk::<8>() {
231 update_sum_carry!(u64, *chunk);
232 bytes = &bytes[8..];
233 }
234 if let Some(chunk) = bytes.first_chunk::<4>() {
235 update_sum_carry!(u32, *chunk);
236 bytes = &bytes[4..];
237 }
238 if let Some(chunk) = bytes.first_chunk::<2>() {
239 update_sum_carry!(u16, *chunk);
240 bytes = &bytes[2..];
241 }
242 if bytes.len() == 1 {
243 if let Some(existing) = self.trailing_byte.take() {
244 // We already had a trailing byte. Deal with them both.
245 update_sum_carry!(u16, [existing, bytes[0]]);
246 } else {
247 // Otherwise, stash the trailing byte.
248 self.trailing_byte = Some(bytes[0])
249 }
250 }
251
252 self.sum = sum + (carry as Accumulator);
253 }
254
255 /// Computes the checksum, but in big endian byte order.
256 fn checksum_inner(&self) -> u16 {
257 let mut sum = self.sum;
258 if let Some(byte) = self.trailing_byte {
259 sum = adc_accumulator(sum, u16::from_ne_bytes([byte, 0]) as Accumulator);
260 }
261 !normalize(sum)
262 }
263
264 /// Computes the checksum, and returns the array representation.
265 ///
266 /// `checksum` returns the checksum of all data added using `add_bytes` so
267 /// far. Calling `checksum` does *not* reset the checksum. More bytes may be
268 /// added after calling `checksum`, and they will be added to the checksum
269 /// as expected.
270 ///
271 /// If an odd number of bytes have been added so far, the checksum will be
272 /// computed as though a single 0 byte had been added at the end in order to
273 /// even out the length of the input.
274 #[inline]
275 pub fn checksum(&self) -> [u8; 2] {
276 self.checksum_inner().to_ne_bytes()
277 }
278}
279
macro_rules! impl_adc {
    ($name: ident, $t: ty) => {
        /// implements 1's complement addition for $t,
        /// exploiting the carry flag on a 2's complement machine.
        /// In practice, the adc instruction will be generated.
        fn $name(a: $t, b: $t) -> $t {
            let (s, c) = a.overflowing_add(b);
            // `s + 1` cannot overflow here: if the addition wrapped (c ==
            // true), then since a, b <= MAX, the wrapped result s is at most
            // MAX - 1.
            s + (c as $t)
        }
    };
}

impl_adc!(adc_u16, u16);
impl_adc!(adc_u32, u32);
// `adc_u64` is only needed by `normalize` when the accumulator is a u128.
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
impl_adc!(adc_u64, u64);
impl_adc!(adc_accumulator, Accumulator);
297
/// Normalizes the accumulator by mopping up the
/// overflow until it fits in a `u16`.
///
/// Each fold is an end-around-carry addition (`adc_*`), which is the halving
/// step of the 1s complement sum.
fn normalize(a: Accumulator) -> u16 {
    // On x86-64/aarch64 the accumulator is a u128 (see the `Accumulator`
    // alias), so the high 64 bits are folded into the low 64 bits first.
    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    return normalize_64(adc_u64(a as u64, (a >> 64) as u64));
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    return normalize_64(a);
}
306
307fn normalize_64(a: u64) -> u16 {
308 let t = adc_u32(a as u32, (a >> 32) as u32);
309 adc_u16(t as u16, (t >> 16) as u16)
310}
311
#[cfg(test)]
mod tests {
    use rand::{Rng, SeedableRng};

    use rand_xorshift::XorShiftRng;

    use super::*;

    /// Create a new deterministic RNG from a seed.
    fn new_rng(mut seed: u128) -> XorShiftRng {
        if seed == 0 {
            // XorShiftRng can't take 0 seeds
            seed = 1;
        }
        XorShiftRng::from_seed(seed.to_ne_bytes())
    }

    #[test]
    fn test_checksum() {
        for buf in IPV4_HEADERS {
            // compute the checksum as normal; a valid IPv4 header (with its
            // checksum field included) must checksum to 0
            let mut c = Checksum::new();
            c.add_bytes(&buf);
            assert_eq!(c.checksum(), [0u8; 2]);
            // compute the checksum one byte at a time to make sure our
            // trailing_byte logic works
            let mut c = Checksum::new();
            for byte in *buf {
                c.add_bytes(&[*byte]);
            }
            assert_eq!(c.checksum(), [0u8; 2]);

            // Make sure that it works even if we overflow u32. Performing this
            // loop 2 * 2^16 times is guaranteed to cause such an overflow
            // because 0xFFFF + 0xFFFF > 2^16, and we're effectively adding
            // (0xFFFF + 0xFFFF) 2^16 times. We verify the overflow as well by
            // making sure that, at least once, the sum gets smaller from one
            // loop iteration to the next.
            let mut c = Checksum::new();
            c.add_bytes(&[0xFF, 0xFF]);
            for _ in 0..((2 * (1 << 16)) - 1) {
                c.add_bytes(&[0xFF, 0xFF]);
            }
            // The 1s complement sum of all-0xFFFF words is 0xFFFF, whose
            // complement is 0.
            assert_eq!(c.checksum(), [0u8; 2]);
        }
    }

    #[test]
    fn test_update() {
        for b in IPV4_HEADERS {
            let mut buf = Vec::new();
            buf.extend_from_slice(b);

            let mut c = Checksum::new();
            c.add_bytes(&buf);
            assert_eq!(c.checksum(), [0u8; 2]);

            // replace the destination IP with the loopback address
            let old = [buf[16], buf[17], buf[18], buf[19]];
            (&mut buf[16..20]).copy_from_slice(&[127, 0, 0, 1]);
            // the incrementally-updated checksum must match a checksum
            // computed from scratch over the modified header
            let updated = update(c.checksum(), &old, &[127, 0, 0, 1]);
            let from_scratch = {
                let mut c = Checksum::new();
                c.add_bytes(&buf);
                c.checksum()
            };
            assert_eq!(updated, from_scratch);
        }
    }

    #[test]
    fn test_update_noop() {
        for b in IPV4_HEADERS {
            let mut buf = Vec::new();
            buf.extend_from_slice(b);

            let mut c = Checksum::new();
            c.add_bytes(&buf);
            assert_eq!(c.checksum(), [0u8; 2]);

            // Replace the destination IP with the same address. I.e. this
            // update should be a no-op.
            let old = [buf[16], buf[17], buf[18], buf[19]];
            let updated = update(c.checksum(), &old, &old);
            let from_scratch = {
                let mut c = Checksum::new();
                c.add_bytes(&buf);
                c.checksum()
            };
            assert_eq!(updated, from_scratch);
        }
    }

    #[test]
    fn test_smoke_update() {
        // deterministic seed so failures are reproducible
        let mut rng = new_rng(70_812_476_915_813);

        for _ in 0..2048 {
            // use an odd length so we test the odd length logic
            const BUF_LEN: usize = 31;
            let buf: [u8; BUF_LEN] = rng.random();
            let mut c = Checksum::new();
            c.add_bytes(&buf);

            // pick a random range [begin, end) to overwrite, rejecting ranges
            // that `update` does not support
            let (begin, end) = loop {
                let begin = rng.random_range(0..BUF_LEN);
                let end = begin + (rng.random_range(0..(BUF_LEN + 1 - begin)));
                // update requires that begin is even and end is either even or
                // the end of the input
                if begin % 2 == 0 && (end % 2 == 0 || end == BUF_LEN) {
                    break (begin, end);
                }
            };

            let mut new_buf = buf;
            for i in begin..end {
                new_buf[i] = rng.random();
            }
            // the incrementally-updated checksum must match one computed from
            // scratch over the mutated buffer
            let updated = update(c.checksum(), &buf[begin..end], &new_buf[begin..end]);
            let from_scratch = {
                let mut c = Checksum::new();
                c.add_bytes(&new_buf);
                c.checksum()
            };
            assert_eq!(updated, from_scratch);
        }
    }

    /// IPv4 headers.
    ///
    /// This data was obtained by capturing live network traffic.
    const IPV4_HEADERS: &[&[u8]] = &[
        &[
            0x45, 0x00, 0x00, 0x34, 0x00, 0x00, 0x40, 0x00, 0x40, 0x06, 0xae, 0xea, 0xc0, 0xa8,
            0x01, 0x0f, 0xc0, 0xb8, 0x09, 0x6a,
        ],
        &[
            0x45, 0x20, 0x00, 0x74, 0x5b, 0x6e, 0x40, 0x00, 0x37, 0x06, 0x5c, 0x1c, 0xc0, 0xb8,
            0x09, 0x6a, 0xc0, 0xa8, 0x01, 0x0f,
        ],
        &[
            0x45, 0x20, 0x02, 0x8f, 0x00, 0x00, 0x40, 0x00, 0x3b, 0x11, 0xc9, 0x3f, 0xac, 0xd9,
            0x05, 0x6e, 0xc0, 0xa8, 0x01, 0x0f,
        ],
    ];

    // This test checks that an input, found by a fuzzer, no longer causes a crash due to addition
    // overflow.
    #[test]
    fn test_large_buffer_addition_overflow() {
        // Construct the accumulator state directly (bypassing `new`) to match
        // the fuzzer's reproduction exactly.
        let mut sum = Checksum { sum: 0, trailing_byte: None };
        let bytes = [
            0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
            255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        ];
        // Success is simply not panicking on the deferred-carry arithmetic.
        sum.add_bytes(&bytes[..]);
    }
}
470}