// netstack3_tcp/congestion/cubic.rs
// Copyright 2022 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! The CUBIC congestion control algorithm as described in
//! [RFC 8312](https://tools.ietf.org/html/rfc8312).
//!
//! Note: This module uses floating point arithmetic, assuming the TCP stack is
//! in user space, as it is on Fuchsia. By not restricting ourselves, it is more
//! straightforward to implement and easier to understand. We don't need to care
//! about overflows and we get better precision. However, if this algorithm ever
//! needs to be run in kernel space, especially when fp arithmetic is not
//! allowed when the kernel deems saving fp registers too expensive, we should
//! use fixed point arithmetic. Casts from u32 to f32 are always fine as f32 can
//! represent a much bigger value range than u32; on the other hand, f32 to u32
//! casts are also fine because Rust guarantees rounding towards zero (+inf is
//! converted to u32::MAX), which aligns with our intention well.
//!
//! Reference: https://doc.rust-lang.org/reference/expressions/operator-expr.html#type-cast-expressions

use core::num::NonZeroU32;
use core::time::Duration;

use netstack3_base::{EffectiveMss, Instant};

use crate::internal::congestion::CongestionControlParams;

/// Multiplicative window decrease factor applied on a congestion event.
///
/// Per RFC 8312 (https://tools.ietf.org/html/rfc8312#section-4.5):
///   Parameter beta_cubic SHOULD be set to 0.7.
const CUBIC_BETA: f32 = 0.7;
/// Scaling constant of the cubic growth function (in segments; scaled to
/// bytes by [`Cubic::cubic_c`]).
///
/// Per RFC 8312 (https://tools.ietf.org/html/rfc8312#section-5):
///   Therefore, C SHOULD be set to 0.4.
const CUBIC_C: f32 = 0.4;

/// The CUBIC algorithm state variables.
///
/// `I` is the [`Instant`] type used to timestamp the start of a congestion
/// avoidance epoch. `FAST_CONVERGENCE` enables the fast convergence heuristic
/// from RFC 8312 section 4.6 (see [`Cubic::on_loss_detected`]).
#[derive(Debug, Clone, Copy, PartialEq, derivative::Derivative)]
#[derivative(Default(bound = ""))]
pub(super) struct Cubic<I, const FAST_CONVERGENCE: bool> {
    /// The start of the current congestion avoidance epoch.
    epoch_start: Option<I>,
    /// Coefficient for the cubic term of time into the current congestion
    /// avoidance epoch.
    k: f32,
    /// The window size when the last congestion event occurred, in bytes.
    w_max: u32,
    /// The running count of acked bytes during congestion avoidance.
    bytes_acked: u32,
}

impl<I: Instant, const FAST_CONVERGENCE: bool> Cubic<I, FAST_CONVERGENCE> {
    /// Returns the window size governed by the cubic growth function, in bytes.
    ///
    /// This function is responsible for the concave/convex regions described
    /// in the RFC.
    fn cubic_window(&self, t: Duration, mss: EffectiveMss) -> u32 {
        // Per RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-4.1):
        //   W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1)
        let x = t.as_secs_f32() - self.k;
        let w_cubic = (self.cubic_c(mss) * f32::powi(x, 3)) + self.w_max as f32;
        // The f32-to-u32 cast rounds towards zero and saturates (+inf becomes
        // u32::MAX); see the module-level documentation.
        w_cubic as u32
    }

    /// Returns the window size for standard TCP, in bytes.
    ///
    /// Used for the TCP-friendly region check in [`Cubic::on_ack`]: cwnd never
    /// falls below what standard TCP would achieve over the same interval.
    fn standard_tcp_window(&self, t: Duration, rtt: Duration, mss: EffectiveMss) -> u32 {
        // Per RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-4.2):
        //   W_est(t) = W_max*beta_cubic +
        //                [3*(1-beta_cubic)/(1+beta_cubic)] * (t/RTT) (Eq. 4)
        let round_trips = t.as_secs_f32() / rtt.as_secs_f32();
        let w_tcp = self.w_max as f32 * CUBIC_BETA
            + (3.0 * (1.0 - CUBIC_BETA) / (1.0 + CUBIC_BETA)) * round_trips * u32::from(mss) as f32;
        w_tcp as u32
    }

    /// Processes an acknowledgment of `bytes_acked` new bytes received at
    /// `now`, growing `cwnd` in place.
    ///
    /// While `cwnd < ssthresh` this performs RFC 5681 slow start; otherwise it
    /// applies the CUBIC window growth function (RFC 8312) using the time
    /// elapsed since the start of the current congestion avoidance epoch and
    /// the smoothed `rtt`.
    pub(super) fn on_ack(
        &mut self,
        CongestionControlParams { cwnd, ssthresh, mss }: &mut CongestionControlParams,
        mut bytes_acked: NonZeroU32,
        now: I,
        rtt: Duration,
    ) {
        if *cwnd < *ssthresh {
            // Slow start, Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#page-6):
            //   we RECOMMEND that TCP implementations increase cwnd, per:
            //     cwnd += min (N, SMSS) (2)
            *cwnd = cwnd.saturating_add(u32::min(bytes_acked.get(), u32::from(*mss)));
            if *cwnd <= *ssthresh {
                return;
            }
            // Now that we are moving out of slow start, we need to treat the
            // extra bytes differently, set the cwnd back to ssthresh and then
            // backtrack the portion of bytes that should be processed in
            // congestion avoidance.
            match cwnd.checked_sub(*ssthresh).and_then(NonZeroU32::new) {
                None => return,
                Some(diff) => bytes_acked = diff,
            }
            *cwnd = *ssthresh;
        }

        // Congestion avoidance.
        let epoch_start = match self.epoch_start {
            Some(epoch_start) => epoch_start,
            None => {
                // Setup the parameters for the current congestion avoidance epoch.
                if let Some(w_max_diff_cwnd) = self.w_max.checked_sub(*cwnd) {
                    // K is the time period that the above function takes to
                    // increase the current window size to W_max if there are no
                    // further congestion events and is calculated using the
                    // following equation:
                    //   K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2)
                    // Note: W_max*(1-beta_cubic) equals W_max - cwnd here,
                    // since cwnd was reduced to W_max*beta_cubic on the last
                    // congestion event.
                    self.k = (w_max_diff_cwnd as f32 / self.cubic_c(*mss)).cbrt();
                } else {
                    // Per RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-4.8):
                    //   In the case when CUBIC runs the hybrid slow start [HR08],
                    //   it may exit the first slow start without incurring any
                    //   packet loss and thus W_max is undefined. In this special
                    //   case, CUBIC switches to congestion avoidance and increases
                    //   its congestion window size using Eq. 1, where t is the
                    //   elapsed time since the beginning of the current congestion
                    //   avoidance, K is set to 0, and W_max is set to the
                    //   congestion window size at the beginning of the current
                    //   congestion avoidance.
                    self.k = 0.0;
                    self.w_max = *cwnd;
                }
                self.epoch_start = Some(now);
                now
            }
        };

        // Per RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-4.7):
        //   Upon receiving an ACK during congestion avoidance, CUBIC computes
        //   the window increase rate during the next RTT period using Eq. 1.
        //   It sets W_cubic(t+RTT) as the candidate target value of the
        //   congestion window, where RTT is the weighted average RTT calculated
        //   by Standard TCP.
        let t = now.saturating_duration_since(epoch_start);
        let target = self.cubic_window(t + rtt, *mss);

        // In a *very* rare case, we might overflow the counter if the acks
        // keep coming in and we can't increase our congestion window. Use
        // saturating add here as a defense so that we don't lose ack counts
        // by accident.
        self.bytes_acked = self.bytes_acked.saturating_add(bytes_acked.get());

        // Per RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-4.3):
        //   cwnd MUST be incremented by (W_cubic(t+RTT) - cwnd)/cwnd for each
        //   received ACK.
        // Note: Here we use a similar approach as in appropriate byte counting
        // (RFC 3465) - We count how many bytes are now acked, then we use Eq. 1
        // to calculate how many acked bytes are needed before we can increase
        // our cwnd by 1 MSS. The increase rate is (target - cwnd)/cwnd segments
        // per ACK which is the same as 1 segment per cwnd/(target - cwnd) ACKs.
        // Because our cubic function is a monotonically increasing function,
        // this method is slightly more aggressive - if we need N acks to
        // increase our window by 1 MSS, then it would take the RFC method at
        // least N acks to increase the same amount. This method is used in the
        // original CUBIC paper[1], and it eliminates the need to use f32 for
        // cwnd, which is a bit awkward especially because our unit is in bytes
        // and it doesn't make much sense to have byte number not to be a whole
        // number.
        // [1]: (https://www.cs.princeton.edu/courses/archive/fall16/cos561/papers/Cubic08.pdf)

        {
            let mss = u32::from(*mss);
            // `saturating_add` avoids overflow in `cwnd`. See https://fxbug.dev/327628809.
            let increased_cwnd = cwnd.saturating_add(mss);
            if target >= increased_cwnd {
                // Ensure the divisor is at least `mss` in case `target` and `cwnd`
                // are both u32::MAX to avoid divide-by-zero.
                let divisor = (target - *cwnd).max(mss);
                let to_subtract_from_bytes_acked = *cwnd / divisor * mss;
                // And the # of acked bytes is at least the required amount of bytes for
                // increasing 1 MSS.
                if self.bytes_acked >= to_subtract_from_bytes_acked {
                    self.bytes_acked -= to_subtract_from_bytes_acked;
                    *cwnd = increased_cwnd;
                }
            }
        }

        // Per RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-4.2):
        //   CUBIC checks whether W_cubic(t) is less than W_est(t). If so,
        //   CUBIC is in the TCP-friendly region and cwnd SHOULD be set to
        //   W_est(t) at each reception of an ACK.
        let w_tcp = self.standard_tcp_window(t, rtt, *mss);
        if *cwnd < w_tcp {
            *cwnd = w_tcp;
        }
    }

    /// Handles a detected loss (e.g. via duplicate ACKs), performing the
    /// multiplicative decrease.
    ///
    /// Ends the current congestion avoidance epoch, updates `w_max`
    /// (applying fast convergence when `FAST_CONVERGENCE` is enabled), and
    /// reduces `ssthresh`/`cwnd` by `CUBIC_BETA`, with a floor of 2 MSS on
    /// `ssthresh`.
    pub(super) fn on_loss_detected(
        &mut self,
        CongestionControlParams { cwnd, ssthresh, mss }: &mut CongestionControlParams,
    ) {
        // End the current congestion avoidance epoch.
        self.epoch_start = None;
        // Per RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-4.6):
        //   With fast convergence, when a congestion event occurs, before the
        //   window reduction of the congestion window, a flow remembers the last
        //   value of W_max before it updates W_max for the current congestion
        //   event. Let us call the last value of W_max to be W_last_max.
        //     if (W_max < W_last_max){ // should we make room for others
        //       W_last_max = W_max; // remember the last W_max
        //       W_max = W_max*(1.0+beta_cubic)/2.0; // further reduce W_max
        //     } else {
        //       W_last_max = W_max // remember the last W_max
        //     }
        // Note: Here the code is slightly different from the RFC because there
        // is an order to update the variables so that we do not need to store
        // an extra variable (W_last_max). i.e. instead of assigning cwnd to
        // W_max first, we compare it to W_last_max, that is the W_max before
        // updating.
        if FAST_CONVERGENCE && *cwnd < self.w_max {
            self.w_max = (*cwnd as f32 * (1.0 + CUBIC_BETA) / 2.0) as u32;
        } else {
            self.w_max = *cwnd;
        }
        // Per RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-4.7):
        //   In case of timeout, CUBIC follows Standard TCP to reduce cwnd
        //   [RFC5681], but sets ssthresh using beta_cubic (same as in
        //   Section 4.5) that is different from Standard TCP [RFC5681].
        *ssthresh = u32::max((*cwnd as f32 * CUBIC_BETA) as u32, 2 * u32::from(*mss));
        *cwnd = *ssthresh;
        // Reset our running count of the acked bytes.
        self.bytes_acked = 0;
    }

    /// Handles a retransmission timeout: applies the loss reduction and then
    /// collapses `cwnd` to the loss window of one full-sized segment.
    pub(super) fn on_retransmission_timeout(&mut self, params: &mut CongestionControlParams) {
        self.on_loss_detected(params);
        // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#page-8):
        //   Furthermore, upon a timeout (as specified in [RFC2988]) cwnd MUST be
        //   set to no more than the loss window, LW, which equals 1 full-sized
        //   segment (regardless of the value of IW).
        params.cwnd = u32::from(params.mss);
    }

    /// Returns the CUBIC constant `C` scaled to byte units for the given MSS.
    fn cubic_c(&self, mss: EffectiveMss) -> f32 {
        // Note: cwnd and w_max are in unit of bytes as opposed to segments in
        // RFC, so C should be CUBIC_C * mss for our implementation.
        CUBIC_C * u32::from(mss) as f32
    }
}

#[cfg(test)]
mod tests {
    use netstack3_base::testutil::FakeInstantCtx;
    use netstack3_base::{EffectiveMss, InstantContext as _, Mss, MssSizeLimiters};
    use test_case::test_case;

    use super::*;

    // Default MSS used by all tests below; timestamps disabled so the full
    // MSS is available for payload.
    const DEFAULT_MSS: EffectiveMss =
        EffectiveMss::from_mss(Mss::DEFAULT_IPV4, MssSizeLimiters { timestamp_enabled: false });

    impl<I: Instant, const FAST_CONVERGENCE: bool> Cubic<I, FAST_CONVERGENCE> {
        // Helper function in test that takes a u32 instead of a NonZeroU32
        // as we know we never pass 0 in the test and it's a bit clumsy to
        // convert a u32 into a NonZeroU32 every time.
        fn on_ack_u32(
            &mut self,
            params: &mut CongestionControlParams,
            bytes_acked: u32,
            now: I,
            rtt: Duration,
        ) {
            self.on_ack(params, NonZeroU32::new(bytes_acked).unwrap(), now, rtt)
        }
    }

    // The following expectations are extracted from table. 1 and table. 2 in
    // RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-5.1). Note that
    // some numbers do not match as-is, but the error rate is acceptable (~2%),
    // this can be attributed to a few things, e.g., the way we simulate is
    // slightly different from the ideal process, as we start the first
    // congestion avoidance with the convex region which grows pretty fast, also
    // the theoretical estimation is an approximation already. The theoretical
    // value is included in the name for each case.
    #[test_case(Duration::from_millis(100), 100 => 11; "rtt=0.1 p=0.01 Wavg=12")]
    #[test_case(Duration::from_millis(100), 1_000 => 38; "rtt=0.1 p=0.001 Wavg=38")]
    #[test_case(Duration::from_millis(100), 10_000 => 186; "rtt=0.1 p=0.0001 Wavg=187")]
    #[test_case(Duration::from_millis(100), 100_000 => 1078; "rtt=0.1 p=0.00001 Wavg=1054")]
    #[test_case(Duration::from_millis(10), 100 => 11; "rtt=0.01 p=0.01 Wavg=12")]
    #[test_case(Duration::from_millis(10), 1_000 => 37; "rtt=0.01 p=0.001 Wavg=38")]
    #[test_case(Duration::from_millis(10), 10_000 => 121; "rtt=0.01 p=0.0001 Wavg=120")]
    #[test_case(Duration::from_millis(10), 100_000 => 384; "rtt=0.01 p=0.00001 Wavg=379")]
    #[test_case(Duration::from_millis(10), 1_000_000 => 1276; "rtt=0.01 p=0.000001 Wavg=1200")]
    fn average_window_size(rtt: Duration, loss_rate_reciprocal: u32) -> u32 {
        const ROUND_TRIPS: u32 = 100_000;

        // The theoretical predictions do not consider fast convergence,
        // disable it.
        let mut cubic = Cubic::<_, false /* FAST_CONVERGENCE */>::default();
        let mut params = CongestionControlParams::with_mss(DEFAULT_MSS);
        // The theoretical value is a prediction for the congestion avoidance
        // only, set ssthresh to 1 so that we skip slow start. Slow start can
        // grow the window size very quickly.
        params.ssthresh = 1;

        let mut clock = FakeInstantCtx::default();

        let mut avg_pkts = 0.0f32;
        let mut ack_cnt = 0;

        // We simulate a deterministic loss model, i.e., for loss_rate p, we
        // drop one packet for every 1/p packets.
        for _ in 0..ROUND_TRIPS {
            let cwnd = params.rounded_cwnd().cwnd();
            if ack_cnt >= loss_rate_reciprocal {
                ack_cnt -= loss_rate_reciprocal;
                cubic.on_loss_detected(&mut params);
            } else {
                ack_cnt += cwnd / u32::from(params.mss);
                // We will get at least one ack for every two segments we send.
                for _ in 0..u32::max(cwnd / u32::from(params.mss) / 2, 1) {
                    let bytes_acked = 2 * u32::from(params.mss);
                    cubic.on_ack_u32(&mut params, bytes_acked, clock.now(), rtt);
                }
            }
            clock.sleep(rtt);
            avg_pkts += (cwnd / u32::from(params.mss)) as f32 / ROUND_TRIPS as f32;
        }
        avg_pkts as u32
    }

    #[test]
    fn cubic_example() {
        let mut clock = FakeInstantCtx::default();
        let mut cubic = Cubic::<_, true /* FAST_CONVERGENCE */>::default();
        let mut params = CongestionControlParams::with_mss(DEFAULT_MSS);
        const RTT: Duration = Duration::from_millis(100);

        // Assert we have the correct initial window.
        assert_eq!(params.cwnd, 4 * u32::from(DEFAULT_MSS));

        // Slow start.
        clock.sleep(RTT);
        for _seg in 0..params.cwnd / u32::from(DEFAULT_MSS) {
            cubic.on_ack_u32(&mut params, u32::from(DEFAULT_MSS), clock.now(), RTT);
        }
        assert_eq!(params.cwnd, 8 * u32::from(DEFAULT_MSS));

        clock.sleep(RTT);
        cubic.on_retransmission_timeout(&mut params);
        assert_eq!(params.cwnd, u32::from(DEFAULT_MSS));

        // We are now back in slow start.
        clock.sleep(RTT);
        cubic.on_ack_u32(&mut params, u32::from(DEFAULT_MSS), clock.now(), RTT);
        assert_eq!(params.cwnd, 2 * u32::from(DEFAULT_MSS));

        clock.sleep(RTT);
        for _ in 0..2 {
            cubic.on_ack_u32(&mut params, u32::from(DEFAULT_MSS), clock.now(), RTT);
        }
        assert_eq!(params.cwnd, 4 * u32::from(DEFAULT_MSS));

        // In this roundtrip, we enter a new congestion epoch from slow start,
        // in this round trip, both cubic and tcp equations will have t=0, so
        // the cwnd in this round trip will be ssthresh, which is 3001 bytes,
        // or 5 full sized segments.
        clock.sleep(RTT);
        for _seg in 0..params.cwnd / u32::from(DEFAULT_MSS) {
            cubic.on_ack_u32(&mut params, u32::from(DEFAULT_MSS), clock.now(), RTT);
        }
        assert_eq!(params.rounded_cwnd().cwnd(), 5 * u32::from(DEFAULT_MSS));

        // Now we are at `epoch_start+RTT`, the window size should grow by at
        // least 1 u32::from(DEFAULT_MSS) per RTT (standard TCP).
        clock.sleep(RTT);
        for _seg in 0..params.cwnd / u32::from(DEFAULT_MSS) {
            cubic.on_ack_u32(&mut params, u32::from(DEFAULT_MSS), clock.now(), RTT);
        }
        assert_eq!(params.rounded_cwnd().cwnd(), 6 * u32::from(DEFAULT_MSS));
    }

    // This is a regression test for https://fxbug.dev/327628809: an ack
    // arriving with cwnd near u32::MAX must not overflow when computing the
    // increased window.
    #[test_case(u32::MAX ; "cwnd is u32::MAX")]
    #[test_case(u32::MAX - 1; "cwnd is u32::MAX - 1")]
    fn repro_overflow_b327628809(cwnd: u32) {
        let clock = FakeInstantCtx::default();
        let mut cubic = Cubic::<_, true /* FAST_CONVERGENCE */>::default();
        let mut params = CongestionControlParams { ssthresh: 0, cwnd, mss: DEFAULT_MSS };
        const RTT: Duration = Duration::from_millis(100);

        cubic.on_ack(&mut params, NonZeroU32::MIN, clock.now(), RTT);
    }

    // This is a regression test for https://fxbug.dev/412748465: slow start
    // must saturate rather than overflow when cwnd approaches u32::MAX.
    #[test]
    fn repro_overflow_b412748465() {
        let clock = FakeInstantCtx::default();
        let mut cubic = Cubic::<_, true /* FAST_CONVERGENCE */>::default();
        // Setup the params in slow start with `cwnd` close to overflow.
        let mut params =
            CongestionControlParams { ssthresh: u32::MAX, cwnd: u32::MAX - 1, mss: DEFAULT_MSS };
        const RTT: Duration = Duration::from_millis(100);
        // Ack enough bytes to push cwnd over u32::MAX.
        cubic.on_ack(
            &mut params,
            NonZeroU32::new(2).unwrap(), /*bytes_acked*/
            clock.now(),
            RTT,
        );
    }
}
405}