netstack3_tcp/congestion/
cubic.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! The CUBIC congestion control algorithm as described in
6//! [RFC 8312](https://tools.ietf.org/html/rfc8312).
7//!
8//! Note: This module uses floating point arithmetics, assuming the TCP stack is
9//! in user space, as it is on Fuchsia. By not restricting ourselves, it is more
10//! straightforward to implement and easier to understand. We don't need to care
11//! about overflows and we get better precision. However, if this algorithm ever
12//! needs to be run in kernel space, especially when fp arithmetic is not
13//! allowed when the kernel deems saving fp registers too expensive, we should
14//! use fixed point arithmetic. Casts from u32 to f32 are always fine as f32 can
15//! represent a much bigger value range than u32; On the other hand, f32 to u32
16//! casts are also fine because Rust guarantees rounding towards zero (+inf is
17//! converted to u32::MAX), which aligns with our intention well.
18//!
19//! Reference: https://doc.rust-lang.org/reference/expressions/operator-expr.html#type-cast-expressions
20
21use core::num::NonZeroU32;
22use core::time::Duration;
23
24use netstack3_base::{Instant, Mss};
25
26use crate::internal::congestion::CongestionControlParams;
27
28/// Per RFC 8312 (https://tools.ietf.org/html/rfc8312#section-4.5):
29///  Parameter beta_cubic SHOULD be set to 0.7.
30const CUBIC_BETA: f32 = 0.7;
31/// Per RFC 8312 (https://tools.ietf.org/html/rfc8312#section-5):
32///  Therefore, C SHOULD be set to 0.4.
33const CUBIC_C: f32 = 0.4;
34
/// The CUBIC algorithm state variables.
///
/// `FAST_CONVERGENCE` selects the optional W_max reduction of RFC 8312
/// section 4.6, applied on congestion events in `on_loss_detected`.
#[derive(Debug, Clone, Copy, PartialEq, derivative::Derivative)]
#[derivative(Default(bound = ""))]
pub(super) struct Cubic<I, const FAST_CONVERGENCE: bool> {
    /// The start of the current congestion avoidance epoch.
    epoch_start: Option<I>,
    /// K from RFC 8312 Eq. 1/2: the time offset, in seconds into the current
    /// congestion avoidance epoch, at which the cubic function reaches
    /// `w_max` (assuming no further congestion events).
    k: f32,
    /// The window size when the last congestion event occurred, in bytes.
    w_max: u32,
    /// The running count of acked bytes during congestion avoidance.
    bytes_acked: u32,
}
49
50impl<I: Instant, const FAST_CONVERGENCE: bool> Cubic<I, FAST_CONVERGENCE> {
51    /// Returns the window size governed by the cubic growth function, in bytes.
52    ///
53    /// This function is responsible for the concave/convex regions described
54    /// in the RFC.
55    fn cubic_window(&self, t: Duration, mss: Mss) -> u32 {
56        // Per RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-4.1):
57        //   W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1)
58        let x = t.as_secs_f32() - self.k;
59        let w_cubic = (self.cubic_c(mss) * f32::powi(x, 3)) + self.w_max as f32;
60        w_cubic as u32
61    }
62
63    /// Returns the window size for standard TCP, in bytes.
64    fn standard_tcp_window(&self, t: Duration, rtt: Duration, mss: Mss) -> u32 {
65        // Per RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-4.2):
66        //   W_est(t) = W_max*beta_cubic +
67        //         [3*(1-beta_cubic)/(1+beta_cubic)] * (t/RTT) (Eq. 4)
68        let round_trips = t.as_secs_f32() / rtt.as_secs_f32();
69        let w_tcp = self.w_max as f32 * CUBIC_BETA
70            + (3.0 * (1.0 - CUBIC_BETA) / (1.0 + CUBIC_BETA)) * round_trips * u32::from(mss) as f32;
71        w_tcp as u32
72    }
73
    /// Processes an ACK of `bytes_acked` new bytes, observed at time `now`
    /// with smoothed round-trip time `rtt`, and grows `cwnd` accordingly.
    ///
    /// Below `ssthresh`, performs RFC 5681 slow start; at or above it,
    /// RFC 8312 CUBIC congestion avoidance. An ACK that carries `cwnd` across
    /// `ssthresh` is split so that only the excess bytes count toward
    /// congestion avoidance.
    pub(super) fn on_ack(
        &mut self,
        CongestionControlParams { cwnd, ssthresh, mss }: &mut CongestionControlParams,
        mut bytes_acked: NonZeroU32,
        now: I,
        rtt: Duration,
    ) {
        if *cwnd < *ssthresh {
            // Slow start, Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#page-6):
            // we RECOMMEND that TCP implementations increase cwnd, per:
            //   cwnd += min (N, SMSS)                      (2)
            *cwnd += u32::min(bytes_acked.get(), u32::from(*mss));
            if *cwnd <= *ssthresh {
                return;
            }
            // Now that we are moving out of slow start, we need to treat the
            // extra bytes differently, set the cwnd back to ssthresh and then
            // backtrack the portion of bytes that should be processed in
            // congestion avoidance.
            match cwnd.checked_sub(*ssthresh).and_then(NonZeroU32::new) {
                // The entire increase fit under ssthresh (cwnd == ssthresh);
                // nothing left for congestion avoidance this ACK.
                None => return,
                Some(diff) => bytes_acked = diff,
            }
            *cwnd = *ssthresh;
        }

        // Congestion avoidance.
        let epoch_start = match self.epoch_start {
            Some(epoch_start) => epoch_start,
            None => {
                // Set up the parameters for the current congestion avoidance
                // epoch.
                if let Some(w_max_diff_cwnd) = self.w_max.checked_sub(*cwnd) {
                    // K is the time period that the above function takes to
                    // increase the current window size to W_max if there are no
                    // further congestion events and is calculated using the
                    // following equation:
                    //   K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2)
                    // Note: `w_max_diff_cwnd` equals W_max*(1-beta_cubic)
                    // right after a congestion event, since the loss handler
                    // set cwnd to W_max*beta_cubic.
                    self.k = (w_max_diff_cwnd as f32 / self.cubic_c(*mss)).cbrt();
                } else {
                    // Per RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-4.8):
                    //   In the case when CUBIC runs the hybrid slow start [HR08],
                    //   it may exit the first slow start without incurring any
                    //   packet loss and thus W_max is undefined. In this special
                    //   case, CUBIC switches to congestion avoidance and increases
                    //   its congestion window size using Eq. 1, where t is the
                    //   elapsed time since the beginning of the current congestion
                    //   avoidance, K is set to 0, and W_max is set to the
                    //   congestion window size at the beginning of the current
                    //   congestion avoidance.
                    self.k = 0.0;
                    self.w_max = *cwnd;
                }
                self.epoch_start = Some(now);
                now
            }
        };

        // Per RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-4.7):
        //   Upon receiving an ACK during congestion avoidance, CUBIC computes
        //   the window increase rate during the next RTT period using Eq. 1.
        //   It sets W_cubic(t+RTT) as the candidate target value of the
        //   congestion window, where RTT is the weighted average RTT calculated
        //   by Standard TCP.
        let t = now.saturating_duration_since(epoch_start);
        let target = self.cubic_window(t + rtt, *mss);

        // In a *very* rare case, we might overflow the counter if the acks
        // keep coming in and we can't increase our congestion window. Use
        // saturating add here as a defense so that we don't lose ack counts
        // by accident.
        self.bytes_acked = self.bytes_acked.saturating_add(bytes_acked.get());

        // Per RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-4.3):
        //   cwnd MUST be incremented by (W_cubic(t+RTT) - cwnd)/cwnd for each
        //   received ACK.
        // Note: Here we use a similar approach as in appropriate byte counting
        // (RFC 3465) - We count how many bytes are now acked, then we use Eq. 1
        // to calculate how many acked bytes are needed before we can increase
        // our cwnd by 1 MSS. The increase rate is (target - cwnd)/cwnd segments
        // per ACK which is the same as 1 segment per cwnd/(target - cwnd) ACKs.
        // Because our cubic function is a monotonically increasing function,
        // this method is slightly more aggressive - if we need N acks to
        // increase our window by 1 MSS, then it would take the RFC method at
        // least N acks to increase the same amount. This method is used in the
        // original CUBIC paper[1], and it eliminates the need to use f32 for
        // cwnd, which is a bit awkward especially because our unit is in bytes
        // and it doesn't make much sense to have byte number not to be a whole
        // number.
        // [1]: (https://www.cs.princeton.edu/courses/archive/fall16/cos561/papers/Cubic08.pdf)

        {
            let mss = u32::from(*mss);
            // `saturating_add` avoids overflow in `cwnd`. See https://fxbug.dev/327628809.
            let increased_cwnd = cwnd.saturating_add(mss);
            // Only grow when the cubic target is at least one MSS ahead.
            if target >= increased_cwnd {
                // Ensure the divisor is at least `mss` in case `target` and `cwnd`
                // are both u32::MAX to avoid divide-by-zero.
                let divisor = (target - *cwnd).max(mss);
                let to_subtract_from_bytes_acked = *cwnd / divisor * mss;
                // And the # of acked bytes is at least the required amount of bytes for
                // increasing 1 MSS.
                if self.bytes_acked >= to_subtract_from_bytes_acked {
                    self.bytes_acked -= to_subtract_from_bytes_acked;
                    *cwnd = increased_cwnd;
                }
            }
        }

        // Per RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-4.2):
        //   CUBIC checks whether W_cubic(t) is less than W_est(t).  If so,
        //   CUBIC is in the TCP-friendly region and cwnd SHOULD be set to
        //   W_est(t) at each reception of an ACK.
        let w_tcp = self.standard_tcp_window(t, rtt, *mss);
        if *cwnd < w_tcp {
            *cwnd = w_tcp;
        }
    }
191
192    pub(super) fn on_loss_detected(
193        &mut self,
194        CongestionControlParams { cwnd, ssthresh, mss }: &mut CongestionControlParams,
195    ) {
196        // End the current congestion avoidance epoch.
197        self.epoch_start = None;
198        // Per RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-4.6):
199        //   With fast convergence, when a congestion event occurs, before the
200        //   window reduction of the congestion window, a flow remembers the last
201        //   value of W_max before it updates W_max for the current congestion
202        //   event.  Let us call the last value of W_max to be W_last_max.
203        //   if (W_max < W_last_max){ // should we make room for others
204        //     W_last_max = W_max;             // remember the last W_max
205        //     W_max = W_max*(1.0+beta_cubic)/2.0; // further reduce W_max
206        //   } else {
207        //     W_last_max = W_max              // remember the last W_max
208        //   }
209        // Note: Here the code is slightly different from the RFC because there
210        // is an order to update the variables so that we do not need to store
211        // an extra variable (W_last_max). i.e. instead of assigning cwnd to
212        // W_max first, we compare it to W_last_max, that is the W_max before
213        // updating.
214        if FAST_CONVERGENCE && *cwnd < self.w_max {
215            self.w_max = (*cwnd as f32 * (1.0 + CUBIC_BETA) / 2.0) as u32;
216        } else {
217            self.w_max = *cwnd;
218        }
219        // Per RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-4.7):
220        //   In case of timeout, CUBIC follows Standard TCP to reduce cwnd
221        //   [RFC5681], but sets ssthresh using beta_cubic (same as in
222        //   Section 4.5) that is different from Standard TCP [RFC5681].
223        *ssthresh = u32::max((*cwnd as f32 * CUBIC_BETA) as u32, 2 * u32::from(*mss));
224        *cwnd = *ssthresh;
225        // Reset our running count of the acked bytes.
226        self.bytes_acked = 0;
227    }
228
229    pub(super) fn on_retransmission_timeout(&mut self, params: &mut CongestionControlParams) {
230        self.on_loss_detected(params);
231        // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#page-8):
232        //   Furthermore, upon a timeout (as specified in [RFC2988]) cwnd MUST be
233        //   set to no more than the loss window, LW, which equals 1 full-sized
234        //   segment (regardless of the value of IW).
235        params.cwnd = u32::from(params.mss);
236    }
237
238    fn cubic_c(&self, mss: Mss) -> f32 {
239        // Note: cwnd and w_max are in unit of bytes as opposed to segments in
240        // RFC, so C should be CUBIC_C * mss for our implementation.
241        CUBIC_C * u32::from(mss) as f32
242    }
243}
244
245#[cfg(test)]
246mod tests {
247    use netstack3_base::testutil::FakeInstantCtx;
248    use netstack3_base::InstantContext as _;
249    use test_case::test_case;
250
251    use super::*;
252    use crate::internal::base::testutil::DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE;
253
254    impl<I: Instant, const FAST_CONVERGENCE: bool> Cubic<I, FAST_CONVERGENCE> {
255        // Helper function in test that takes a u32 instead of a NonZeroU32
256        // as we know we never pass 0 in the test and it's a bit clumsy to
257        // convert a u32 into a NonZeroU32 every time.
258        fn on_ack_u32(
259            &mut self,
260            params: &mut CongestionControlParams,
261            bytes_acked: u32,
262            now: I,
263            rtt: Duration,
264        ) {
265            self.on_ack(params, NonZeroU32::new(bytes_acked).unwrap(), now, rtt)
266        }
267    }
268
    // The following expectations are extracted from table. 1 and table. 2 in
    // RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-5.1). Note that
    // some numbers do not match as-is, but the error rate is acceptable (~2%),
    // this can be attributed to a few things, e.g., the way we simulate is
    // slightly different from the ideal process, as we start the first
    // congestion avoidance with the convex region which grows pretty fast, also
    // the theoretical estimation is an approximation already. The theoretical
    // value is included in the name for each case.
    #[test_case(Duration::from_millis(100), 100 => 11; "rtt=0.1 p=0.01 Wavg=12")]
    #[test_case(Duration::from_millis(100), 1_000 => 38; "rtt=0.1 p=0.001 Wavg=38")]
    #[test_case(Duration::from_millis(100), 10_000 => 186; "rtt=0.1 p=0.0001 Wavg=187")]
    #[test_case(Duration::from_millis(100), 100_000 => 1078; "rtt=0.1 p=0.00001 Wavg=1054")]
    #[test_case(Duration::from_millis(10), 100 => 11; "rtt=0.01 p=0.01 Wavg=12")]
    #[test_case(Duration::from_millis(10), 1_000 => 37; "rtt=0.01 p=0.001 Wavg=38")]
    #[test_case(Duration::from_millis(10), 10_000 => 121; "rtt=0.01 p=0.0001 Wavg=120")]
    #[test_case(Duration::from_millis(10), 100_000 => 384; "rtt=0.01 p=0.00001 Wavg=379")]
    #[test_case(Duration::from_millis(10), 1_000_000 => 1276; "rtt=0.01 p=0.000001 Wavg=1200")]
    // Simulates `ROUND_TRIPS` round trips with a deterministic loss model
    // (one loss per `loss_rate_reciprocal` segments) and returns the average
    // window size in segments, for comparison against the RFC's tables.
    fn average_window_size(rtt: Duration, loss_rate_reciprocal: u32) -> u32 {
        const ROUND_TRIPS: u32 = 100_000;

        // The theoretical predictions do not consider fast convergence,
        // disable it.
        let mut cubic = Cubic::<_, false /* FAST_CONVERGENCE */>::default();
        let mut params = CongestionControlParams::with_mss(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE);
        // The theoretical value is a prediction for the congestion avoidance
        // only, set ssthresh to 1 so that we skip slow start. Slow start can
        // grow the window size very quickly.
        params.ssthresh = 1;

        let mut clock = FakeInstantCtx::default();

        // Running average of the window size, in packets (segments).
        let mut avg_pkts = 0.0f32;
        // Count of segments sent since the last simulated loss.
        let mut ack_cnt = 0;

        // We simulate a deterministic loss model, i.e., for loss_rate p, we
        // drop one packet for every 1/p packets.
        for _ in 0..ROUND_TRIPS {
            let cwnd = u32::from(params.rounded_cwnd());
            if ack_cnt >= loss_rate_reciprocal {
                // Enough segments have been sent for one to be "lost".
                ack_cnt -= loss_rate_reciprocal;
                cubic.on_loss_detected(&mut params);
            } else {
                ack_cnt += cwnd / u32::from(params.mss);
                // We will get at least one ack for every two segments we send.
                for _ in 0..u32::max(cwnd / u32::from(params.mss) / 2, 1) {
                    let bytes_acked = 2 * u32::from(params.mss);
                    cubic.on_ack_u32(&mut params, bytes_acked, clock.now(), rtt);
                }
            }
            clock.sleep(rtt);
            // Accumulate this round trip's contribution to the average.
            avg_pkts += (cwnd / u32::from(params.mss)) as f32 / ROUND_TRIPS as f32;
        }
        avg_pkts as u32
    }
323
    // Walks through a small end-to-end scenario: slow start, a timeout, slow
    // start again, and finally two round trips of congestion avoidance,
    // checking cwnd at every step.
    #[test]
    fn cubic_example() {
        let mut clock = FakeInstantCtx::default();
        let mut cubic = Cubic::<_, true /* FAST_CONVERGENCE */>::default();
        let mut params = CongestionControlParams::with_mss(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE);
        const RTT: Duration = Duration::from_millis(100);

        // Assert we have the correct initial window.
        assert_eq!(params.cwnd, 4 * u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE));

        // Slow start: acking a full window doubles cwnd (1 MSS per acked
        // segment).
        clock.sleep(RTT);
        for _seg in 0..params.cwnd / u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE) {
            cubic.on_ack_u32(
                &mut params,
                u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE),
                clock.now(),
                RTT,
            );
        }
        assert_eq!(params.cwnd, 8 * u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE));

        // A retransmission timeout collapses cwnd to one segment.
        clock.sleep(RTT);
        cubic.on_retransmission_timeout(&mut params);
        assert_eq!(params.cwnd, u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE));

        // We are now back in slow start.
        clock.sleep(RTT);
        cubic.on_ack_u32(
            &mut params,
            u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE),
            clock.now(),
            RTT,
        );
        assert_eq!(params.cwnd, 2 * u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE));

        clock.sleep(RTT);
        for _ in 0..2 {
            cubic.on_ack_u32(
                &mut params,
                u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE),
                clock.now(),
                RTT,
            );
        }
        assert_eq!(params.cwnd, 4 * u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE));

        // In this roundtrip, we enter a new congestion epoch from slow start,
        // in this round trip, both cubic and tcp equations will have t=0, so
        // the cwnd in this round trip will be ssthresh, which is 3001 bytes,
        // or 5 full sized segments.
        clock.sleep(RTT);
        for _seg in 0..params.cwnd / u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE) {
            cubic.on_ack_u32(
                &mut params,
                u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE),
                clock.now(),
                RTT,
            );
        }
        assert_eq!(
            u32::from(params.rounded_cwnd()),
            5 * u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE)
        );

        // Now we are at `epoch_start+RTT`, the window size should grow by at
        // least 1 MSS per RTT (standard TCP).
        clock.sleep(RTT);
        for _seg in 0..params.cwnd / u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE) {
            cubic.on_ack_u32(
                &mut params,
                u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE),
                clock.now(),
                RTT,
            );
        }
        assert_eq!(
            u32::from(params.rounded_cwnd()),
            6 * u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE)
        );
    }
405
406    // This is a regression test for https://fxbug.dev/327628809.
407    #[test_case(u32::MAX ; "cwnd is u32::MAX")]
408    #[test_case(u32::MAX - 1; "cwnd is u32::MAX - 1")]
409    fn repro_overflow_b327628809(cwnd: u32) {
410        let clock = FakeInstantCtx::default();
411        let mut cubic = Cubic::<_, true /* FAST_CONVERGENCE */>::default();
412        let mut params =
413            CongestionControlParams { ssthresh: 0, cwnd, mss: DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE };
414        const RTT: Duration = Duration::from_millis(100);
415
416        cubic.on_ack(&mut params, NonZeroU32::MIN, clock.now(), RTT);
417    }
418}