// internet_checksum/lib.rs

1// Copyright 2019 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! RFC 1071 "internet checksum" computation.
6//!
7//! This crate implements the "internet checksum" defined in [RFC 1071] and
8//! updated in [RFC 1141] and [RFC 1624], which is used by many different
9//! protocols' packet formats. The checksum operates by computing the 1s
10//! complement of the 1s complement sum of successive 16-bit words of the input.
11//!
12//! [RFC 1071]: https://tools.ietf.org/html/rfc1071
13//! [RFC 1141]: https://tools.ietf.org/html/rfc1141
14//! [RFC 1624]: https://tools.ietf.org/html/rfc1624
15
16// Optimizations applied:
17//
18// 0. Byteorder independence: as described in RFC 1071 section 2.(B)
19//    The sum of 16-bit integers can be computed in either byte order,
20//    so this actually saves us from the unnecessary byte swapping on
21//    an LE machine. As perfed on a gLinux workstation, that swapping
22//    can account for ~20% of the runtime.
23//
// 1. Widen the accumulator: doing so enables us to process a bigger
//    chunk of data at a time, achieving some kind of poor man's
//    SIMD. Currently a u128 accumulator is used on x86-64 and AArch64,
//    and a u64 is used conservatively on other architectures.
28//
// 2. Process more at a time: the old implementation uses a u32 accumulator
//    but it only adds one u16 each time to implement deferred carry. In
//    the current implementation we are processing a u128 at a time
//    on x86-64 and AArch64, which is 8 u16's. On other platforms, we are
//    processing a u64 at a time, which is 4 u16's.
34//
35// 3. Induce the compiler to produce `adc` instruction: this is a very
36//    useful instruction to implement 1's complement addition and available
37//    on both x86 and ARM. The functions `adc_uXX` are for this use.
38//
39// 4. Eliminate branching as much as possible: the old implementation has
40//    if statements for detecting overflow of the u32 accumulator which
41//    is not needed when we can access the carry flag with `adc`. The old
42//    `normalize` function used to have a while loop to fold the u32,
//    however, we can unroll that loop because we know ahead of time how
//    many additions we need.
45//
46// 5. In the loop of `add_bytes`, the `adc_u64` is not used, instead,
47//    the `overflowing_add` is directly used. `adc_u64`'s carry flag
48//    comes from the current number being added while the slightly
49//    convoluted version in `add_bytes`, adding each number depends on
50//    the carry flag of the previous computation. I checked under release
51//    mode this issues 3 instructions instead of 4 for x86 and it should
52//    theoretically be beneficial, however, measurement showed me that it
53//    helps only a little. So this trick is not used for `update`.
54//
55// Results:
56//
// Micro-benchmarks are run on an x86-64 gLinux workstation. In summary,
// compared to baseline 0, which is prior to the byteorder independence
// patch, there is a ~4x speedup.
60//
61// TODO: run this optimization on other platforms. I would expect
62// the situation on ARM a bit different because I am not sure
63// how much penalty there will be for misaligned read on ARM, or
64// whether it is even supported (On x86 there is generally no
65// penalty for misaligned read). If there will be penalties, we
66// should consider alignment as an optimization opportunity on ARM.
67
68// TODO(joshlf): Right-justify the columns above
69
70// TODO(joshlf):
71// - Investigate optimizations proposed in RFC 1071 Section 2. The most
72//   promising on modern hardware is probably (C) Parallel Summation, although
73//   that needs to be balanced against (1) Deferred Carries. Benchmarks will
74//   need to be performed to determine which is faster in practice, and under
75//   what scenarios.
76
77/// Compute the checksum of "bytes".
78///
79/// `checksum(bytes)` is shorthand for:
80///
81/// ```rust
82/// # use internet_checksum::Checksum;
83/// # let bytes = &[];
84/// # let _ = {
85/// let mut c = Checksum::new();
86/// c.add_bytes(bytes);
87/// c.checksum()
88/// # };
89/// ```
90#[inline]
91pub fn checksum(bytes: &[u8]) -> [u8; 2] {
92    let mut c = Checksum::new();
93    c.add_bytes(bytes);
94    c.checksum()
95}
96
// The running-sum type used by `Checksum`. A wider accumulator lets
// `add_bytes` consume more bytes per loop iteration (see optimization
// notes 1 and 2 above): a u128 on x86-64/AArch64, and a more conservative
// u64 everywhere else.
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
type Accumulator = u128;
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
type Accumulator = u64;
101
102/// Updates bytes in an existing checksum.
103///
104/// `update` updates a checksum to reflect that the already-checksummed bytes
105/// `old` have been updated to contain the values in `new`. It implements the
106/// algorithm described in Equation 3 in [RFC 1624]. The first byte must be at
107/// an even number offset in the original input. If an odd number offset byte
108/// needs to be updated, the caller should simply include the preceding byte as
109/// well. If an odd number of bytes is given, it is assumed that these are the
110/// last bytes of the input. If an odd number of bytes in the middle of the
111/// input needs to be updated, the preceding or following byte of the input
112/// should be added to make an even number of bytes.
113///
114/// # Panics
115///
116/// `update` panics if `old.len() != new.len()`.
117///
118/// [RFC 1624]: https://tools.ietf.org/html/rfc1624
119#[inline]
120pub fn update(checksum: [u8; 2], old: &[u8], new: &[u8]) -> [u8; 2] {
121    assert_eq!(old.len(), new.len());
122    // We compute on the sum, not the one's complement of the sum. checksum
123    // is the one's complement of the sum, so we need to get back to the
124    // sum. Thus, we negate checksum.
125    // HC' = ~HC
126    let mut sum = !u16::from_ne_bytes(checksum) as Accumulator;
127
128    // Let's reuse `Checksum::add_bytes` to update our checksum
129    // so that we can get the speedup for free. Using
130    // [RFC 1071 Eqn. 3], we can efficiently update our new checksum.
131    let mut c1 = Checksum::new();
132    let mut c2 = Checksum::new();
133    c1.add_bytes(old);
134    c2.add_bytes(new);
135
136    // Note, `c1.checksum_inner()` is actually ~m in [Eqn. 3]
137    // `c2.checksum_inner()` is actually ~m' in [Eqn. 3]
138    // so we have to negate `c2.checksum_inner()` first to get m'.
139    // HC' += ~m, c1.checksum_inner() == ~m.
140    sum = adc_accumulator(sum, c1.checksum_inner() as Accumulator);
141    // HC' += m', c2.checksum_inner() == ~m'.
142    sum = adc_accumulator(sum, !c2.checksum_inner() as Accumulator);
143    // HC' = ~HC.
144    (!normalize(sum)).to_ne_bytes()
145}
146
/// RFC 1071 "internet checksum" computation.
///
/// `Checksum` implements the "internet checksum" defined in [RFC 1071] and
/// updated in [RFC 1141] and [RFC 1624], which is used by many different
/// protocols' packet formats. The checksum operates by computing the 1s
/// complement of the 1s complement sum of successive 16-bit words of the input.
///
/// [RFC 1071]: https://tools.ietf.org/html/rfc1071
/// [RFC 1141]: https://tools.ietf.org/html/rfc1141
/// [RFC 1624]: https://tools.ietf.org/html/rfc1624
#[derive(Default)]
pub struct Checksum {
    // Running sum with deferred carries; folded down to a u16 by `normalize`
    // when the checksum is requested.
    sum: Accumulator,
    // Since odd-length inputs are treated specially, we store the trailing byte
    // for use in future calls to add_bytes(), and only treat it as a true
    // trailing byte in checksum().
    trailing_byte: Option<u8>,
}
165
impl Checksum {
    /// Initialize a new checksum.
    #[inline]
    pub const fn new() -> Self {
        Checksum { sum: 0, trailing_byte: None }
    }

    /// Add bytes to the checksum.
    ///
    /// If `bytes` does not contain an even number of bytes, a single zero byte
    /// will be added to the end before updating the checksum.
    ///
    /// Note that `add_bytes` has some fixed overhead regardless of the size of
    /// `bytes`. Where performance is a concern, prefer fewer calls to
    /// `add_bytes` with larger input over more calls with smaller input.
    #[inline]
    pub fn add_bytes(&mut self, mut bytes: &[u8]) {
        if bytes.is_empty() {
            return;
        }

        // Track the running sum and the deferred carry in locals; they are
        // written back to `self` exactly once at the end.
        let mut sum = self.sum;
        let mut carry = false;

        // We are not using `adc_uXX` functions here, instead, we manually track
        // the carry flag. This is because in `adc_uXX` functions, the carry
        // flag depends on the addition itself. So the assembly for that
        // function reads as follows:
        //
        // mov %rdi, %rcx
        // mov %rsi, %rax
        // add %rcx, %rsi -- waste! only used to generate CF.
        // adc %rdi, $rax -- the real useful instruction.
        //
        // So we would rather make each addition depend on the CF generated by
        // the addition of the previous 16-bit word. The ideal assembly should
        // look like:
        //
        // add 0(%rdi), %rax
        // adc 8(%rdi), %rax
        // adc 16(%rdi), %rax
        // .... and so on ...
        //
        // Sadly, there are too many instructions that can affect the carry
        // flag, and LLVM is not that optimized to find out the pattern and let
        // all these adc instructions not interleaved. However, doing so results
        // in 3 instructions instead of the original 4 instructions (the two
        // mov's are still there) and it makes a difference on input size like
        // 1023.

        // Deferred-carry step: add one `$ty`-sized chunk (read in native byte
        // order, which is valid per RFC 1071 section 2.(B)) into `sum`, fold in
        // the carry from the *previous* addition, and remember this addition's
        // carry for the next round.
        macro_rules! update_sum_carry {
            ($ty: ident, $chunk: expr) => {
                let (s, c) = sum.overflowing_add($ty::from_ne_bytes($chunk) as Accumulator);
                sum = s.wrapping_add(carry as Accumulator);
                carry = c;
            };
        }

        // Main loop: consume one full accumulator-width chunk (16 bytes on
        // x86-64/AArch64, 8 bytes elsewhere) per iteration.
        const ACCUMULATOR_BYTES: usize = (Accumulator::BITS / 8) as usize;
        while let Some(chunk) = bytes.first_chunk::<ACCUMULATOR_BYTES>() {
            update_sum_carry!(Accumulator, *chunk);
            bytes = &bytes[ACCUMULATOR_BYTES..];
        }

        // Handle the tail with progressively narrower reads. After the loop
        // above, fewer than ACCUMULATOR_BYTES bytes remain, so each of these
        // branches fires at most once (the u64 branch only when the
        // accumulator is a u128).
        if let Some(chunk) = bytes.first_chunk::<8>() {
            update_sum_carry!(u64, *chunk);
            bytes = &bytes[8..];
        }
        if let Some(chunk) = bytes.first_chunk::<4>() {
            update_sum_carry!(u32, *chunk);
            bytes = &bytes[4..];
        }
        if let Some(chunk) = bytes.first_chunk::<2>() {
            update_sum_carry!(u16, *chunk);
            bytes = &bytes[2..];
        }
        if bytes.len() == 1 {
            if let Some(existing) = self.trailing_byte.take() {
                // We already had a trailing byte. Deal with them both.
                update_sum_carry!(u16, [existing, bytes[0]]);
            } else {
                // Otherwise, stash the trailing byte.
                self.trailing_byte = Some(bytes[0])
            }
        }

        // Fold the final deferred carry back in.
        // NOTE(review): this relies on `sum + carry` not overflowing; the
        // fuzz-regression test at the bottom of this file covers an input that
        // previously crashed near here — confirm the invariant before
        // changing this line.
        self.sum = sum + (carry as Accumulator);
    }

    /// Computes the checksum, but in big endian byte order.
    ///
    /// Because the sum is accumulated over native-endian 16-bit words (see
    /// optimization note 0 at the top of the file), the returned `u16` is the
    /// value whose native-endian byte representation is the checksum in
    /// network byte order — i.e. on a little-endian machine the numeric value
    /// is byte-swapped.
    fn checksum_inner(&self) -> u16 {
        let mut sum = self.sum;
        if let Some(byte) = self.trailing_byte {
            // Treat the stashed odd trailing byte as if it were followed by a
            // zero pad byte, as the checksum definition requires for
            // odd-length input.
            sum = adc_accumulator(sum, u16::from_ne_bytes([byte, 0]) as Accumulator);
        }
        !normalize(sum)
    }

    /// Computes the checksum, and returns the array representation.
    ///
    /// `checksum` returns the checksum of all data added using `add_bytes` so
    /// far. Calling `checksum` does *not* reset the checksum. More bytes may be
    /// added after calling `checksum`, and they will be added to the checksum
    /// as expected.
    ///
    /// If an odd number of bytes have been added so far, the checksum will be
    /// computed as though a single 0 byte had been added at the end in order to
    /// even out the length of the input.
    #[inline]
    pub fn checksum(&self) -> [u8; 2] {
        self.checksum_inner().to_ne_bytes()
    }
}
279
// Generates a 1's complement addition function for an unsigned integer type.
//
// The implementation exploits the carry flag of a 2's complement machine:
// a wrapping add followed by adding the carry back in. In practice the
// compiler lowers this to an `adc` instruction.
macro_rules! impl_adc {
    ($fn_name: ident, $int: ty) => {
        /// 1's complement addition, generated by `impl_adc!`.
        fn $fn_name(a: $int, b: $int) -> $int {
            let (partial, carried) = a.overflowing_add(b);
            partial + (carried as $int)
        }
    };
}
291
impl_adc!(adc_u16, u16);
impl_adc!(adc_u32, u32);
// `adc_u64` is only referenced by `normalize` on targets where `Accumulator`
// is a u128, so it is gated behind the same cfg — presumably to avoid an
// unused-function warning on other targets.
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
impl_adc!(adc_u64, u64);
impl_adc!(adc_accumulator, Accumulator);
297
298/// Normalizes the accumulator by mopping up the
299/// overflow until it fits in a `u16`.
300fn normalize(a: Accumulator) -> u16 {
301    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
302    return normalize_64(adc_u64(a as u64, (a >> 64) as u64));
303    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
304    return normalize_64(a);
305}
306
307fn normalize_64(a: u64) -> u16 {
308    let t = adc_u32(a as u32, (a >> 32) as u32);
309    adc_u16(t as u16, (t >> 16) as u16)
310}
311
#[cfg(test)]
mod tests {
    use rand::{Rng, SeedableRng};

    use rand_xorshift::XorShiftRng;

    use super::*;

    /// Create a new deterministic RNG from a seed.
    fn new_rng(mut seed: u128) -> XorShiftRng {
        if seed == 0 {
            // XorShiftRng can't take 0 seeds
            seed = 1;
        }
        XorShiftRng::from_seed(seed.to_ne_bytes())
    }

    #[test]
    fn test_checksum() {
        for buf in IPV4_HEADERS {
            // Each captured header includes its own (valid) checksum field, so
            // checksumming the whole header must produce 0.
            // compute the checksum as normal
            let mut c = Checksum::new();
            c.add_bytes(&buf);
            assert_eq!(c.checksum(), [0u8; 2]);
            // compute the checksum one byte at a time to make sure our
            // trailing_byte logic works
            let mut c = Checksum::new();
            for byte in *buf {
                c.add_bytes(&[*byte]);
            }
            assert_eq!(c.checksum(), [0u8; 2]);

            // Make sure the checksum remains correct when the accumulated sum
            // grows very large. Adding 0xFFFF a total of 2 * 2^16 times would
            // have overflowed the u32 accumulator of an older implementation.
            // The 1s complement sum of any number of 0xFFFF words folds back
            // to 0xFFFF, so the complemented checksum must be 0.
            let mut c = Checksum::new();
            c.add_bytes(&[0xFF, 0xFF]);
            for _ in 0..((2 * (1 << 16)) - 1) {
                c.add_bytes(&[0xFF, 0xFF]);
            }
            assert_eq!(c.checksum(), [0u8; 2]);
        }
    }

    #[test]
    fn test_update() {
        for b in IPV4_HEADERS {
            let mut buf = Vec::new();
            buf.extend_from_slice(b);

            let mut c = Checksum::new();
            c.add_bytes(&buf);
            assert_eq!(c.checksum(), [0u8; 2]);

            // replace the destination IP with the loopback address
            let old = [buf[16], buf[17], buf[18], buf[19]];
            (&mut buf[16..20]).copy_from_slice(&[127, 0, 0, 1]);
            // The incrementally-updated checksum must match one computed from
            // scratch over the modified buffer.
            let updated = update(c.checksum(), &old, &[127, 0, 0, 1]);
            let from_scratch = {
                let mut c = Checksum::new();
                c.add_bytes(&buf);
                c.checksum()
            };
            assert_eq!(updated, from_scratch);
        }
    }

    #[test]
    fn test_update_noop() {
        for b in IPV4_HEADERS {
            let mut buf = Vec::new();
            buf.extend_from_slice(b);

            let mut c = Checksum::new();
            c.add_bytes(&buf);
            assert_eq!(c.checksum(), [0u8; 2]);

            // Replace the destination IP with the same address. I.e. this
            // update should be a no-op.
            let old = [buf[16], buf[17], buf[18], buf[19]];
            let updated = update(c.checksum(), &old, &old);
            let from_scratch = {
                let mut c = Checksum::new();
                c.add_bytes(&buf);
                c.checksum()
            };
            assert_eq!(updated, from_scratch);
        }
    }

    #[test]
    fn test_smoke_update() {
        // Randomized check of `update` against a from-scratch checksum, over
        // many random buffers and random (valid) update ranges.
        let mut rng = new_rng(70_812_476_915_813);

        for _ in 0..2048 {
            // use an odd length so we test the odd length logic
            const BUF_LEN: usize = 31;
            let buf: [u8; BUF_LEN] = rng.random();
            let mut c = Checksum::new();
            c.add_bytes(&buf);

            let (begin, end) = loop {
                let begin = rng.random_range(0..BUF_LEN);
                let end = begin + (rng.random_range(0..(BUF_LEN + 1 - begin)));
                // update requires that begin is even and end is either even or
                // the end of the input
                if begin % 2 == 0 && (end % 2 == 0 || end == BUF_LEN) {
                    break (begin, end);
                }
            };

            let mut new_buf = buf;
            for i in begin..end {
                new_buf[i] = rng.random();
            }
            let updated = update(c.checksum(), &buf[begin..end], &new_buf[begin..end]);
            let from_scratch = {
                let mut c = Checksum::new();
                c.add_bytes(&new_buf);
                c.checksum()
            };
            assert_eq!(updated, from_scratch);
        }
    }

    /// IPv4 headers.
    ///
    /// This data was obtained by capturing live network traffic.
    const IPV4_HEADERS: &[&[u8]] = &[
        &[
            0x45, 0x00, 0x00, 0x34, 0x00, 0x00, 0x40, 0x00, 0x40, 0x06, 0xae, 0xea, 0xc0, 0xa8,
            0x01, 0x0f, 0xc0, 0xb8, 0x09, 0x6a,
        ],
        &[
            0x45, 0x20, 0x00, 0x74, 0x5b, 0x6e, 0x40, 0x00, 0x37, 0x06, 0x5c, 0x1c, 0xc0, 0xb8,
            0x09, 0x6a, 0xc0, 0xa8, 0x01, 0x0f,
        ],
        &[
            0x45, 0x20, 0x02, 0x8f, 0x00, 0x00, 0x40, 0x00, 0x3b, 0x11, 0xc9, 0x3f, 0xac, 0xd9,
            0x05, 0x6e, 0xc0, 0xa8, 0x01, 0x0f,
        ],
    ];

    // This test checks that an input, found by a fuzzer, no longer causes a crash due to addition
    // overflow.
    #[test]
    fn test_large_buffer_addition_overflow() {
        let mut sum = Checksum { sum: 0, trailing_byte: None };
        let bytes = [
            0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
            255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        ];
        sum.add_bytes(&bytes[..]);
    }
}