// Copyright 2019 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! RFC 1071 "internet checksum" computation.
//!
//! This crate implements the "internet checksum" defined in [RFC 1071] and
//! updated in [RFC 1141] and [RFC 1624], which is used by many different
//! protocols' packet formats. The checksum is the one's complement of the
//! one's complement sum of successive 16-bit words of the input.
//!
//! [RFC 1071]: https://tools.ietf.org/html/rfc1071
//! [RFC 1141]: https://tools.ietf.org/html/rfc1141
//! [RFC 1624]: https://tools.ietf.org/html/rfc1624

// Optimizations applied:
//
// 0. Byte order independence: as described in RFC 1071 section 2.(B),
//    the sum of 16-bit integers can be computed in either byte order,
//    which saves us from unnecessary byte swapping on a little-endian
//    machine. As profiled on a gLinux workstation, that swapping can
//    account for ~20% of the runtime.
//
// 1. Widen the accumulator: doing so lets us process a bigger chunk of
//    data at a time, achieving a kind of poor man's SIMD. Currently a
//    u128 accumulator is used.
//
// 2. Process more at a time: we add in increments of u64 rather than u16.
//
// 3. Induce the compiler to produce the `adc` instruction: this is a very
//    useful instruction for implementing one's complement addition and is
//    available on both x86 and ARM. The `adc_uXX` functions exist for this
//    purpose.

34/// Compute the checksum of "bytes".
35///
36/// `checksum(bytes)` is shorthand for:
37///
38/// ```rust
39/// # use internet_checksum::Checksum;
40/// # let bytes = &[];
41/// # let _ = {
42/// let mut c = Checksum::new();
43/// c.add_bytes(bytes);
44/// c.checksum()
45/// # };
46/// ```
47#[inline]
48pub fn checksum(bytes: &[u8]) -> [u8; 2] {
49    let mut c = Checksum::new();
50    c.add_bytes(bytes);
51    c.checksum()
52}
53
54/// Updates bytes in an existing checksum.
55///
56/// `update` updates a checksum to reflect that the already-checksummed bytes
57/// `old` have been updated to contain the values in `new`. It implements the
58/// algorithm described in Equation 3 in [RFC 1624]. The first byte must be at
59/// an even number offset in the original input. If an odd number offset byte
60/// needs to be updated, the caller should simply include the preceding byte as
61/// well. If an odd number of bytes is given, it is assumed that these are the
62/// last bytes of the input. If an odd number of bytes in the middle of the
63/// input needs to be updated, the preceding or following byte of the input
64/// should be added to make an even number of bytes.
65///
66/// # Panics
67///
68/// `update` panics if `old.len() != new.len()`.
69///
70/// [RFC 1624]: https://tools.ietf.org/html/rfc1624
71#[inline]
72pub fn update(checksum: [u8; 2], old: &[u8], new: &[u8]) -> [u8; 2] {
73    assert_eq!(old.len(), new.len());
74    // We compute on the sum, not the one's complement of the sum. checksum
75    // is the one's complement of the sum, so we need to get back to the
76    // sum. Thus, we negate checksum.
77    // HC' = ~HC
78    let mut sum = !u16::from_ne_bytes(checksum);
79
80    // Let's reuse `Checksum::add_bytes` to update our checksum
81    // so that we can get the speedup for free. Using
82    // [RFC 1071 Eqn. 3], we can efficiently update our new checksum.
83    let mut c1 = Checksum::new();
84    let mut c2 = Checksum::new();
85    c1.add_bytes(old);
86    c2.add_bytes(new);
87
88    // Note, `c1.checksum_inner()` is actually ~m in [Eqn. 3]
89    // `c2.checksum_inner()` is actually ~m' in [Eqn. 3]
90    // so we have to negate `c2.checksum_inner()` first to get m'.
91    // HC' += ~m, c1.checksum_inner() == ~m.
92    sum = adc_u16(sum, c1.checksum_inner());
93    // HC' += m', c2.checksum_inner() == ~m'.
94    sum = adc_u16(sum, !c2.checksum_inner());
95    // HC' = ~HC.
96    (!sum).to_ne_bytes()
97}
98
/// RFC 1071 "internet checksum" computation.
///
/// `Checksum` implements the "internet checksum" defined in [RFC 1071] and
/// updated in [RFC 1141] and [RFC 1624], which is used by many different
/// protocols' packet formats. The checksum is the one's complement of the
/// one's complement sum of successive 16-bit words of the input.
///
/// Cloning a `Checksum` snapshots its accumulated state, so a common prefix
/// can be summed once and then extended in several different ways.
///
/// [RFC 1071]: https://tools.ietf.org/html/rfc1071
/// [RFC 1141]: https://tools.ietf.org/html/rfc1141
/// [RFC 1624]: https://tools.ietf.org/html/rfc1624
#[derive(Debug, Clone, Default)]
pub struct Checksum {
    // Accumulate the sum into a u128, despite the fact that the `Checksum`
    // implementation adds 8-byte or smaller chunks at a time. This effectively
    // allows us to ignore overflow, which has been demonstrated to improve
    // performance.
    //
    // Adding an 8-byte chunk to a u128 can be done safely without overflow up
    // to u64::MAX times. Thus, we need not worry about overflow unless we were
    // to checksum more than 8 * 2^64 bytes, or ~147 exabytes. We ignore this
    // possibility.
    sum: u128,
    // Since odd-length inputs are treated specially, we store the trailing byte
    // for use in future calls to add_bytes(), and only treat it as a true
    // trailing byte in checksum().
    trailing_byte: Option<u8>,
}

127impl Checksum {
128    /// Initialize a new checksum.
129    #[inline]
130    pub const fn new() -> Self {
131        Checksum { sum: 0, trailing_byte: None }
132    }
133
134    /// Add bytes to the checksum.
135    ///
136    /// If `bytes` does not contain an even number of bytes, a single zero byte
137    /// will be added to the end before updating the checksum.
138    ///
139    /// Note that `add_bytes` has some fixed overhead regardless of the size of
140    /// `bytes`. Where performance is a concern, prefer fewer calls to
141    /// `add_bytes` with larger input over more calls with smaller input.
142    #[inline]
143    pub fn add_bytes(&mut self, mut bytes: &[u8]) {
144        if bytes.is_empty() {
145            return;
146        }
147
148        let mut sum = self.sum;
149
150        // Deal with previous trailing byte, if we have one.
151        // NB: Don't use `if let Some(t) = self.trailing_byte.take()`. It slows
152        // down the fast path (i.e. the `None` case).
153        if self.trailing_byte.is_some() {
154            let trailing = self.trailing_byte.take().unwrap();
155            sum += u16::from_ne_bytes([trailing, bytes[0]]) as u128;
156            bytes = &bytes[1..];
157        }
158
159        // NB: Even though our accumulator is 16 bytes, summing in 8 byte chunks
160        // (rather than 16 byte chunks) leads to better optimized machine code
161        // on 64 bit platforms.
162        while let Some(chunk) = bytes.first_chunk::<8>() {
163            sum += u64::from_ne_bytes(*chunk) as u128;
164            bytes = &bytes[8..];
165        }
166
167        // Handle the tail.
168        if let Some(chunk) = bytes.first_chunk::<4>() {
169            sum += u32::from_ne_bytes(*chunk) as u128;
170            bytes = &bytes[4..];
171        }
172        if let Some(chunk) = bytes.first_chunk::<2>() {
173            sum += u16::from_ne_bytes(*chunk) as u128;
174            bytes = &bytes[2..];
175        }
176        if bytes.len() == 1 {
177            // Stash the trailing byte.
178            self.trailing_byte = Some(bytes[0]);
179        }
180
181        self.sum = sum;
182    }
183
184    /// Computes the checksum, but in big endian byte order.
185    fn checksum_inner(&self) -> u16 {
186        let mut sum = self.sum;
187        if let Some(byte) = self.trailing_byte {
188            sum += u16::from_ne_bytes([byte, 0]) as u128;
189        }
190        !normalize(sum)
191    }
192
193    /// Computes the one's complement sum and returns the array representation.
194    ///
195    /// `partial_checksum` returns the one's complement sum of all data added
196    /// using `add_bytes` so far. Calling `partial_checksum` does *not* reset
197    /// the checksum. More bytes may be added after calling `partial_checksum`,
198    /// and they will be added to the checksum as expected.
199    ///
200    /// `partial_checksum` will return `None` if an odd number of bytes have
201    /// been added so far.
202    pub fn partial_checksum(&self) -> Option<[u8; 2]> {
203        if self.trailing_byte.is_some() {
204            return None;
205        }
206        Some(normalize(self.sum).to_ne_bytes())
207    }
208
209    /// Computes the checksum, and returns the array representation.
210    ///
211    /// `checksum` returns the checksum of all data added using `add_bytes` so
212    /// far. Calling `checksum` does *not* reset the checksum. More bytes may be
213    /// added after calling `checksum`, and they will be added to the checksum
214    /// as expected.
215    ///
216    /// If an odd number of bytes have been added so far, the checksum will be
217    /// computed as though a single 0 byte had been added at the end in order to
218    /// even out the length of the input.
219    #[inline]
220    pub fn checksum(&self) -> [u8; 2] {
221        self.checksum_inner().to_ne_bytes()
222    }
223}
224
/// Defines `fn $name(a: $t, b: $t) -> $t` performing one's complement
/// (end-around-carry) addition on the unsigned integer type `$t`.
macro_rules! impl_adc {
    ($name: ident, $t: ty) => {
        // Metavariables are not substituted inside `///` comments, so build
        // the doc string with `concat!`/`stringify!` to name the real type.
        #[doc = concat!(
            "Implements one's complement addition for `",
            stringify!($t),
            "`, exploiting the carry flag on a two's complement machine: the \
             carry out of the wrapping add is added back into the low bit. \
             This shape is intended to let the compiler emit an `adc` \
             instruction."
        )]
        #[inline]
        fn $name(a: $t, b: $t) -> $t {
            let (s, c) = a.overflowing_add(b);
            // Cannot overflow: if `c` is set, `s == a + b - 2^BITS`, which is
            // at most `2^BITS - 2`, so adding 1 stays in range.
            s + (c as $t)
        }
    };
}

impl_adc!(adc_u16, u16);
impl_adc!(adc_u32, u32);
impl_adc!(adc_u64, u64);

241/// Normalizes the accumulator by mopping up the
242/// overflow until it fits in a `u16`.
243fn normalize(a: u128) -> u16 {
244    let t = adc_u64(a as u64, (a >> 64) as u64);
245    let t = adc_u32(t as u32, (t >> 32) as u32);
246    adc_u16(t as u16, (t >> 16) as u16)
247}
248
#[cfg(test)]
mod tests {
    use rand::{Rng, SeedableRng};

    use rand_xorshift::XorShiftRng;

    use super::*;

    /// Create a new deterministic RNG from a seed.
    fn new_rng(mut seed: u128) -> XorShiftRng {
        if seed == 0 {
            // XorShiftRng can't take 0 seeds
            seed = 1;
        }
        XorShiftRng::from_seed(seed.to_ne_bytes())
    }

    #[test]
    fn test_checksum() {
        for buf in IPV4_HEADERS {
            // compute the checksum as normal
            let mut c = Checksum::new();
            c.add_bytes(buf);
            assert_eq!(c.checksum(), [0u8; 2]);
            // compute the checksum one byte at a time to make sure our
            // trailing_byte logic works
            let mut c = Checksum::new();
            for byte in buf.iter() {
                c.add_bytes(&[*byte]);
            }
            assert_eq!(c.checksum(), [0u8; 2]);
        }
    }

    // Make sure normalization stays correct once the accumulated sum no longer
    // fits in a u32. Adding the word 0xFFFF 2 * 2^16 times accumulates
    // 0xFFFF * 2^17 = 0x1_FFFE_0000 > u32::MAX. That total is a non-zero
    // multiple of 0xFFFF, so it folds to 0xFFFF and the checksum is 0.
    #[test]
    fn test_checksum_sum_exceeds_u32() {
        let mut c = Checksum::new();
        for _ in 0..(2 * (1 << 16)) {
            c.add_bytes(&[0xFF, 0xFF]);
        }
        assert_eq!(c.checksum(), [0u8; 2]);
    }

    #[test]
    fn test_partial_checksum() {
        for buf in IPV4_HEADERS {
            // Partial checksum should compute for even length slices.
            for i in (0..buf.len()).step_by(2) {
                let mut part = Checksum::new();
                part.add_bytes(&buf[..i]);

                let mut c = Checksum::new();
                c.add_bytes(
                    &part
                        .partial_checksum()
                        .expect("partial checksum should compute for even length slices"),
                );
                c.add_bytes(&buf[i..]);
                assert_eq!(c.checksum(), [0u8; 2]);
            }
            // Partial checksum should not compute for odd length slices.
            for i in (1..buf.len()).step_by(2) {
                let mut part = Checksum::new();
                part.add_bytes(&buf[..i]);
                assert_eq!(part.partial_checksum(), None);
            }
            // Partial checksum should be the complement of the checksum.
            let mut c = Checksum::new();
            c.add_bytes(buf);
            assert_eq!(c.partial_checksum(), Some([0xFF; 2]));
        }
    }

    #[test]
    fn test_update() {
        for b in IPV4_HEADERS {
            let mut buf = b.to_vec();

            let mut c = Checksum::new();
            c.add_bytes(&buf);
            assert_eq!(c.checksum(), [0u8; 2]);

            // replace the destination IP with the loopback address
            let old = [buf[16], buf[17], buf[18], buf[19]];
            buf[16..20].copy_from_slice(&[127, 0, 0, 1]);
            let updated = update(c.checksum(), &old, &[127, 0, 0, 1]);
            let from_scratch = {
                let mut c = Checksum::new();
                c.add_bytes(&buf);
                c.checksum()
            };
            assert_eq!(updated, from_scratch);
        }
    }

    #[test]
    fn test_update_noop() {
        for b in IPV4_HEADERS {
            let buf = b.to_vec();

            let mut c = Checksum::new();
            c.add_bytes(&buf);
            assert_eq!(c.checksum(), [0u8; 2]);

            // Replace the destination IP with the same address. I.e. this
            // update should be a no-op.
            let old = [buf[16], buf[17], buf[18], buf[19]];
            let updated = update(c.checksum(), &old, &old);
            let from_scratch = {
                let mut c = Checksum::new();
                c.add_bytes(&buf);
                c.checksum()
            };
            assert_eq!(updated, from_scratch);
        }
    }

    // One's complement arithmetic has two zeros (0x0000 and 0xFFFF), and an
    // incremental update cannot tell them apart. Pin this known limitation:
    // when the updated input is entirely zero bytes, `update` yields [0, 0]
    // whereas recomputing from scratch yields [0xFF, 0xFF].
    #[test]
    fn test_update_to_all_zero_input() {
        let old = [0x00, 0x01];
        let new = [0x00, 0x00];
        assert_eq!(update(checksum(&old), &old, &new), [0x00, 0x00]);
        assert_eq!(checksum(&new), [0xFF, 0xFF]);
    }

    #[test]
    fn test_smoke_update() {
        let mut rng = new_rng(70_812_476_915_813);

        for _ in 0..2048 {
            // use an odd length so we test the odd length logic
            const BUF_LEN: usize = 31;
            let buf: [u8; BUF_LEN] = rng.random();
            let mut c = Checksum::new();
            c.add_bytes(&buf);

            let (begin, end) = loop {
                let begin = rng.random_range(0..BUF_LEN);
                let end = begin + (rng.random_range(0..(BUF_LEN + 1 - begin)));
                // update requires that begin is even and end is either even or
                // the end of the input
                if begin % 2 == 0 && (end % 2 == 0 || end == BUF_LEN) {
                    break (begin, end);
                }
            };

            let mut new_buf = buf;
            for i in begin..end {
                new_buf[i] = rng.random();
            }
            let updated = update(c.checksum(), &buf[begin..end], &new_buf[begin..end]);
            let from_scratch = {
                let mut c = Checksum::new();
                c.add_bytes(&new_buf);
                c.checksum()
            };
            assert_eq!(updated, from_scratch);
        }
    }

    /// IPv4 headers.
    ///
    /// This data was obtained by capturing live network traffic.
    const IPV4_HEADERS: &[&[u8]] = &[
        &[
            0x45, 0x00, 0x00, 0x34, 0x00, 0x00, 0x40, 0x00, 0x40, 0x06, 0xae, 0xea, 0xc0, 0xa8,
            0x01, 0x0f, 0xc0, 0xb8, 0x09, 0x6a,
        ],
        &[
            0x45, 0x20, 0x00, 0x74, 0x5b, 0x6e, 0x40, 0x00, 0x37, 0x06, 0x5c, 0x1c, 0xc0, 0xb8,
            0x09, 0x6a, 0xc0, 0xa8, 0x01, 0x0f,
        ],
        &[
            0x45, 0x20, 0x02, 0x8f, 0x00, 0x00, 0x40, 0x00, 0x3b, 0x11, 0xc9, 0x3f, 0xac, 0xd9,
            0x05, 0x6e, 0xc0, 0xa8, 0x01, 0x0f,
        ],
    ];

    // This test checks that an input, found by a fuzzer, no longer causes a crash due to addition
    // overflow.
    #[test]
    fn test_large_buffer_addition_overflow() {
        let mut sum = Checksum { sum: 0, trailing_byte: None };
        let bytes = [
            0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
            255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        ];
        sum.add_bytes(&bytes[..]);

        // Beyond not crashing, the result must agree with a byte-at-a-time
        // computation, which exercises entirely different code paths.
        let mut bytewise = Checksum::new();
        for b in bytes.iter() {
            bytewise.add_bytes(&[*b]);
        }
        assert_eq!(sum.checksum(), bytewise.checksum());
    }

    // Regression test for https://fxbug.dev/515774797.
    //
    // Verify that checksum calculations produce the same result, no matter if
    // the bytes are added at once, or in odd-length chunks.
    #[test]
    fn test_odd_length_checksum() {
        // Determine the expected value. Per RFC 1071, an odd length of bytes
        // should be padded at the end with a 0.
        let mut c = Checksum::new();
        c.add_bytes(&[1, 2, 3, 0]);
        let expected_checksum = c.checksum();

        // Add the bytes all at once.
        let mut c = Checksum::new();
        c.add_bytes(&[1, 2, 3]);
        assert_eq!(c.checksum(), expected_checksum);

        // Add the bytes in two passes (first pass uses an odd number of bytes).
        let mut c = Checksum::new();
        c.add_bytes(&[1]);
        c.add_bytes(&[2, 3]);
        assert_eq!(c.checksum(), expected_checksum);
    }

    // Verify that we properly perform bounds checks against the byte buffer.
    // Failure to do so would result in index-out-of-bounds panics.
    #[test]
    fn test_add_zero_bytes() {
        let mut c = Checksum::new();
        c.add_bytes(&[]);
        assert_eq!(c.checksum(), [255, 255]);

        // Try again, but this time set a trailing_byte.
        let mut c = Checksum::new();
        c.add_bytes(&[0]);
        c.add_bytes(&[]);
        assert_eq!(c.checksum(), [255, 255]);

        // Try once more, but now complete the trailing byte exactly (no remainder).
        let mut c = Checksum::new();
        c.add_bytes(&[0]);
        c.add_bytes(&[0]);
        assert_eq!(c.checksum(), [255, 255]);
    }

    // Regression test for https://fxbug.dev/515753165.
    //
    // The checksum implementation manually tracks the carry bit during
    // arithmetic overflows. Verify that we correctly handle the edge case where
    // adding the carry bit from a previous overflow causes a second overflow to
    // occur.
    #[test]
    fn test_carry_loss() {
        const MAX: [u8; 16] = u128::MAX.to_ne_bytes();
        const ONE: [u8; 16] = 1u128.to_ne_bytes();

        let mut c1 = Checksum::new();
        c1.add_bytes(&MAX);
        c1.add_bytes(&ONE);
        c1.add_bytes(&MAX);

        let mut c2 = Checksum::new();
        let bytes = [MAX, ONE, MAX].concat();
        c2.add_bytes(&bytes);

        assert_eq!(c1.checksum(), c2.checksum());
    }
}