netstack3_tcp/congestion/cubic.rs
// Copyright 2022 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! The CUBIC congestion control algorithm as described in
//! [RFC 8312](https://tools.ietf.org/html/rfc8312).
//!
//! Note: This module uses floating point arithmetic, assuming the TCP stack
//! runs in user space, as it does on Fuchsia. Not restricting ourselves to
//! integers makes the implementation more straightforward and easier to
//! understand: we don't need to care about overflows and we get better
//! precision. However, if this algorithm ever needs to run in kernel space,
//! especially where fp arithmetic is disallowed because the kernel deems
//! saving fp registers too expensive, we should switch to fixed point
//! arithmetic. Casts from u32 to f32 are always fine as f32 can represent a
//! much bigger value range than u32; on the other hand, f32 to u32 casts are
//! also fine because Rust guarantees rounding towards zero (+inf is converted
//! to u32::MAX), which aligns well with our intention.
//!
//! Reference: https://doc.rust-lang.org/reference/expressions/operator-expr.html#type-cast-expressions
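//!
//! For example, the cast semantics we rely on (guaranteed by the language,
//! not by this module):
//!
//! ```
//! assert_eq!(f32::INFINITY as u32, u32::MAX); // +inf saturates to the maximum
//! assert_eq!((-1.5f32) as u32, 0); // negative values saturate to zero
//! assert_eq!(2.9f32 as u32, 2); // fractional values truncate toward zero
//! ```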

use core::num::NonZeroU32;
use core::time::Duration;

use netstack3_base::{Instant, Mss};

use crate::internal::congestion::CongestionControlParams;

/// Per RFC 8312 (https://tools.ietf.org/html/rfc8312#section-4.5):
/// Parameter beta_cubic SHOULD be set to 0.7.
const CUBIC_BETA: f32 = 0.7;
/// Per RFC 8312 (https://tools.ietf.org/html/rfc8312#section-5):
/// Therefore, C SHOULD be set to 0.4.
const CUBIC_C: f32 = 0.4;

/// The CUBIC algorithm state variables.
#[derive(Debug, Clone, Copy, PartialEq, derivative::Derivative)]
#[derivative(Default(bound = ""))]
pub(super) struct Cubic<I, const FAST_CONVERGENCE: bool> {
    /// The start of the current congestion avoidance epoch.
    epoch_start: Option<I>,
    /// Coefficient for the cubic term of time into the current congestion
    /// avoidance epoch.
    k: f32,
    /// The window size when the last congestion event occurred, in bytes.
    w_max: u32,
    /// The running count of acked bytes during congestion avoidance.
    bytes_acked: u32,
}

impl<I: Instant, const FAST_CONVERGENCE: bool> Cubic<I, FAST_CONVERGENCE> {
    /// Returns the window size governed by the cubic growth function, in bytes.
    ///
    /// This function is responsible for the concave/convex regions described
    /// in the RFC.
    fn cubic_window(&self, t: Duration, mss: Mss) -> u32 {
        // Per RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-4.1):
        // W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1)
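        // For example (illustrative numbers, not from the RFC): with
        // mss = 536 bytes, C = 0.4 * 536 (see `cubic_c`), W_max = 5360 bytes
        // (10 segments) and K = 2.0s, an ACK arriving at t = 3.0s yields
        // 0.4 * 536 * (3.0 - 2.0)^3 + 5360 = 5574.4, truncated to 5574 bytes.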
        let x = t.as_secs_f32() - self.k;
        let w_cubic = (self.cubic_c(mss) * f32::powi(x, 3)) + self.w_max as f32;
        w_cubic as u32
    }

    /// Returns the window size for standard TCP, in bytes.
    fn standard_tcp_window(&self, t: Duration, rtt: Duration, mss: Mss) -> u32 {
        // Per RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-4.2):
        // W_est(t) = W_max*beta_cubic +
        //   [3*(1-beta_cubic)/(1+beta_cubic)] * (t/RTT) (Eq. 4)
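        // For example (illustrative numbers): with beta_cubic = 0.7 the
        // additive term is 3*(1-0.7)/(1+0.7) ~= 0.53 segments per round trip,
        // so with W_max = 5360 bytes, mss = 536 bytes, and t/RTT = 10 round
        // trips, W_est is about 5360*0.7 + 0.53*10*536 ~= 6590 bytes.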
        let round_trips = t.as_secs_f32() / rtt.as_secs_f32();
        let w_tcp = self.w_max as f32 * CUBIC_BETA
            + (3.0 * (1.0 - CUBIC_BETA) / (1.0 + CUBIC_BETA)) * round_trips * u32::from(mss) as f32;
        w_tcp as u32
    }

    pub(super) fn on_ack(
        &mut self,
        CongestionControlParams { cwnd, ssthresh, mss }: &mut CongestionControlParams,
        mut bytes_acked: NonZeroU32,
        now: I,
        rtt: Duration,
    ) {
        if *cwnd < *ssthresh {
            // Slow start. Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#page-6):
            // we RECOMMEND that TCP implementations increase cwnd, per:
            //   cwnd += min (N, SMSS) (2)
            *cwnd += u32::min(bytes_acked.get(), u32::from(*mss));
            if *cwnd <= *ssthresh {
                return;
            }
            // Now that we are moving out of slow start, we need to treat the
            // extra bytes differently: set cwnd back to ssthresh and then
            // backtrack the portion of bytes that should be processed in
            // congestion avoidance.
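            // For example (illustrative numbers): if ssthresh is 3001 bytes,
            // cwnd is 2900 bytes, and this ACK covers 536 bytes, cwnd first
            // grows to 3436 bytes; we then reset it to 3001 and account for
            // the remaining 435 bytes in congestion avoidance below.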
            match cwnd.checked_sub(*ssthresh).and_then(NonZeroU32::new) {
                None => return,
                Some(diff) => bytes_acked = diff,
            }
            *cwnd = *ssthresh;
        }

        // Congestion avoidance.
        let epoch_start = match self.epoch_start {
            Some(epoch_start) => epoch_start,
            None => {
                // Set up the parameters for the current congestion avoidance epoch.
                if let Some(w_max_diff_cwnd) = self.w_max.checked_sub(*cwnd) {
                    // K is the time period that the above function takes to
                    // increase the current window size to W_max if there are no
                    // further congestion events and is calculated using the
                    // following equation:
                    // K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2)
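                    // Note: because a congestion event typically sets cwnd to
                    // W_max*beta_cubic, the difference W_max - cwnd used here
                    // equals W_max*(1-beta_cubic), matching Eq. 2. For example
                    // (illustrative numbers): W_max - cwnd = 1608 bytes and
                    // C = 0.4 * 536 gives K = cbrt(7.5) ~= 1.96 seconds.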
                    self.k = (w_max_diff_cwnd as f32 / self.cubic_c(*mss)).cbrt();
                } else {
                    // Per RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-4.8):
                    // In the case when CUBIC runs the hybrid slow start [HR08],
                    // it may exit the first slow start without incurring any
                    // packet loss and thus W_max is undefined. In this special
                    // case, CUBIC switches to congestion avoidance and increases
                    // its congestion window size using Eq. 1, where t is the
                    // elapsed time since the beginning of the current congestion
                    // avoidance, K is set to 0, and W_max is set to the
                    // congestion window size at the beginning of the current
                    // congestion avoidance.
                    self.k = 0.0;
                    self.w_max = *cwnd;
                }
                self.epoch_start = Some(now);
                now
            }
        };

        // Per RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-4.1):
        // Upon receiving an ACK during congestion avoidance, CUBIC computes
        // the window increase rate during the next RTT period using Eq. 1.
        // It sets W_cubic(t+RTT) as the candidate target value of the
        // congestion window, where RTT is the weighted average RTT calculated
        // by Standard TCP.
        let t = now.saturating_duration_since(epoch_start);
        let target = self.cubic_window(t + rtt, *mss);

        // In a *very* rare case, we might overflow the counter if the acks
        // keep coming in and we can't increase our congestion window. Use
        // saturating add here as a defense so that we don't lose ack counts
        // by accident.
        self.bytes_acked = self.bytes_acked.saturating_add(bytes_acked.get());

        // Per RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-4.3):
        // cwnd MUST be incremented by (W_cubic(t+RTT) - cwnd)/cwnd for each
        // received ACK.
        // Note: Here we use a similar approach as in appropriate byte counting
        // (RFC 3465) - we count how many bytes have been acked, then use Eq. 1
        // to calculate how many acked bytes are needed before we can increase
        // our cwnd by 1 MSS. The increase rate is (target - cwnd)/cwnd segments
        // per ACK, which is the same as 1 segment per cwnd/(target - cwnd)
        // ACKs. Because our cubic function is monotonically increasing, this
        // method is slightly more aggressive - if we need N acks to increase
        // our window by 1 MSS, then the RFC method would take at least N acks
        // to increase the same amount. This method is used in the original
        // CUBIC paper[1], and it eliminates the need to use f32 for cwnd,
        // which is a bit awkward, especially because our unit is bytes and it
        // doesn't make much sense for a byte count not to be a whole number.
        // [1]: (https://www.cs.princeton.edu/courses/archive/fall16/cos561/papers/Cubic08.pdf)
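        // For example (illustrative numbers): with cwnd = 10 * mss and
        // target = 12 * mss, the divisor below is 2 * mss, so we wait until
        // cwnd / (2 * mss) * mss = 5 * mss bytes have been acked before
        // growing cwnd by one mss.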

        {
            let mss = u32::from(*mss);
            // `saturating_add` avoids overflow in `cwnd`. See https://fxbug.dev/327628809.
            let increased_cwnd = cwnd.saturating_add(mss);
            if target >= increased_cwnd {
                // Ensure the divisor is at least `mss` to avoid a
                // divide-by-zero when `target` and `cwnd` are both u32::MAX.
                let divisor = (target - *cwnd).max(mss);
                let to_subtract_from_bytes_acked = *cwnd / divisor * mss;
                // Only grow cwnd once the acked byte count reaches the amount
                // required for increasing by 1 MSS.
                if self.bytes_acked >= to_subtract_from_bytes_acked {
                    self.bytes_acked -= to_subtract_from_bytes_acked;
                    *cwnd = increased_cwnd;
                }
            }
        }

        // Per RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-4.2):
        // CUBIC checks whether W_cubic(t) is less than W_est(t). If so,
        // CUBIC is in the TCP-friendly region and cwnd SHOULD be set to
        // W_est(t) at each reception of an ACK.
        let w_tcp = self.standard_tcp_window(t, rtt, *mss);
        if *cwnd < w_tcp {
            *cwnd = w_tcp;
        }
    }

    pub(super) fn on_loss_detected(
        &mut self,
        CongestionControlParams { cwnd, ssthresh, mss }: &mut CongestionControlParams,
    ) {
        // End the current congestion avoidance epoch.
        self.epoch_start = None;
        // Per RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-4.6):
        // With fast convergence, when a congestion event occurs, before the
        // window reduction of the congestion window, a flow remembers the last
        // value of W_max before it updates W_max for the current congestion
        // event. Let us call the last value of W_max to be W_last_max.
        // if (W_max < W_last_max){ // should we make room for others
        //     W_last_max = W_max; // remember the last W_max
        //     W_max = W_max*(1.0+beta_cubic)/2.0; // further reduce W_max
        // } else {
        //     W_last_max = W_max // remember the last W_max
        // }
        // Note: The code here differs slightly from the RFC: we order the
        // updates so that we do not need to store an extra variable
        // (W_last_max). That is, instead of assigning cwnd to W_max first, we
        // compare cwnd against the W_max from before the update, which plays
        // the role of W_last_max.
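        // For example (illustrative numbers): if W_max = 100 * mss and the
        // window has only recovered to cwnd = 70 * mss when the next loss
        // hits, fast convergence further reduces W_max to
        // 70 * mss * (1 + 0.7) / 2 = 59.5 * mss, truncated by the cast.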
        if FAST_CONVERGENCE && *cwnd < self.w_max {
            self.w_max = (*cwnd as f32 * (1.0 + CUBIC_BETA) / 2.0) as u32;
        } else {
            self.w_max = *cwnd;
        }
        // Per RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-4.7):
        // In case of timeout, CUBIC follows Standard TCP to reduce cwnd
        // [RFC5681], but sets ssthresh using beta_cubic (same as in
        // Section 4.5) that is different from Standard TCP [RFC5681].
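        // For example (mirroring `cubic_example` below, assuming a 536-byte
        // MSS): cwnd = 8 * 536 = 4288 bytes gives
        // ssthresh = max((4288.0 * 0.7) as u32, 2 * 536) = 3001 bytes.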
        *ssthresh = u32::max((*cwnd as f32 * CUBIC_BETA) as u32, 2 * u32::from(*mss));
        *cwnd = *ssthresh;
        // Reset our running count of the acked bytes.
        self.bytes_acked = 0;
    }

    pub(super) fn on_retransmission_timeout(&mut self, params: &mut CongestionControlParams) {
        self.on_loss_detected(params);
        // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#page-8):
        // Furthermore, upon a timeout (as specified in [RFC2988]) cwnd MUST be
        // set to no more than the loss window, LW, which equals 1 full-sized
        // segment (regardless of the value of IW).
        params.cwnd = u32::from(params.mss);
    }

    fn cubic_c(&self, mss: Mss) -> f32 {
        // Note: cwnd and w_max are in units of bytes, as opposed to segments
        // in the RFC, so C should be CUBIC_C * mss for our implementation.
        CUBIC_C * u32::from(mss) as f32
    }
}

#[cfg(test)]
mod tests {
    use netstack3_base::testutil::FakeInstantCtx;
    use netstack3_base::InstantContext as _;
    use test_case::test_case;

    use super::*;
    use crate::internal::base::testutil::DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE;

    impl<I: Instant, const FAST_CONVERGENCE: bool> Cubic<I, FAST_CONVERGENCE> {
        // Test helper that takes a u32 instead of a NonZeroU32, as we know we
        // never pass 0 in the tests and it's a bit clumsy to convert a u32
        // into a NonZeroU32 every time.
        fn on_ack_u32(
            &mut self,
            params: &mut CongestionControlParams,
            bytes_acked: u32,
            now: I,
            rtt: Duration,
        ) {
            self.on_ack(params, NonZeroU32::new(bytes_acked).unwrap(), now, rtt)
        }
    }
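
    // A small sanity check of Eq. 1 (a sketch; the expected value assumes
    // DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE is 536 bytes, which also underlies
    // the 3001-byte ssthresh asserted in `cubic_example` below).
    #[test]
    fn cubic_window_example() {
        let clock = FakeInstantCtx::default();
        let mut cubic = Cubic::<_, true /* FAST_CONVERGENCE */>::default();
        // Pin the instant type parameter to the fake clock's instant type.
        cubic.epoch_start = Some(clock.now());
        cubic.w_max = 10 * u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE);
        cubic.k = 2.0;
        // W_cubic(3.0) = 0.4 * 536 * (3.0 - 2.0)^3 + 5360 = 5574.4, which
        // truncates to 5574 bytes.
        assert_eq!(
            cubic.cubic_window(Duration::from_secs(3), DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE),
            5574
        );
    }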

    // The following expectations are extracted from Table 1 and Table 2 in
    // RFC 8312 (https://www.rfc-editor.org/rfc/rfc8312#section-5.1). Note that
    // some numbers do not match as-is, but the error rate is acceptable (~2%).
    // This can be attributed to a few things: e.g., the way we simulate is
    // slightly different from the ideal process, as we start the first
    // congestion avoidance with the convex region, which grows pretty fast,
    // and the theoretical estimation is an approximation already. The
    // theoretical value is included in the name for each case.
    #[test_case(Duration::from_millis(100), 100 => 11; "rtt=0.1 p=0.01 Wavg=12")]
    #[test_case(Duration::from_millis(100), 1_000 => 38; "rtt=0.1 p=0.001 Wavg=38")]
    #[test_case(Duration::from_millis(100), 10_000 => 186; "rtt=0.1 p=0.0001 Wavg=187")]
    #[test_case(Duration::from_millis(100), 100_000 => 1078; "rtt=0.1 p=0.00001 Wavg=1054")]
    #[test_case(Duration::from_millis(10), 100 => 11; "rtt=0.01 p=0.01 Wavg=12")]
    #[test_case(Duration::from_millis(10), 1_000 => 37; "rtt=0.01 p=0.001 Wavg=38")]
    #[test_case(Duration::from_millis(10), 10_000 => 121; "rtt=0.01 p=0.0001 Wavg=120")]
    #[test_case(Duration::from_millis(10), 100_000 => 384; "rtt=0.01 p=0.00001 Wavg=379")]
    #[test_case(Duration::from_millis(10), 1_000_000 => 1276; "rtt=0.01 p=0.000001 Wavg=1200")]
    fn average_window_size(rtt: Duration, loss_rate_reciprocal: u32) -> u32 {
        const ROUND_TRIPS: u32 = 100_000;

        // The theoretical predictions do not consider fast convergence, so
        // disable it.
        let mut cubic = Cubic::<_, false /* FAST_CONVERGENCE */>::default();
        let mut params = CongestionControlParams::with_mss(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE);
        // The theoretical value is a prediction for congestion avoidance
        // only; set ssthresh to 1 so that we skip slow start, which can grow
        // the window size very quickly.
        params.ssthresh = 1;

        let mut clock = FakeInstantCtx::default();

        let mut avg_pkts = 0.0f32;
        let mut ack_cnt = 0;

        // We simulate a deterministic loss model, i.e., for loss rate p, we
        // drop one packet for every 1/p packets.
        for _ in 0..ROUND_TRIPS {
            let cwnd = u32::from(params.rounded_cwnd());
            if ack_cnt >= loss_rate_reciprocal {
                ack_cnt -= loss_rate_reciprocal;
                cubic.on_loss_detected(&mut params);
            } else {
                ack_cnt += cwnd / u32::from(params.mss);
                // We will get at least one ack for every two segments we send.
                for _ in 0..u32::max(cwnd / u32::from(params.mss) / 2, 1) {
                    let bytes_acked = 2 * u32::from(params.mss);
                    cubic.on_ack_u32(&mut params, bytes_acked, clock.now(), rtt);
                }
            }
            clock.sleep(rtt);
            avg_pkts += (cwnd / u32::from(params.mss)) as f32 / ROUND_TRIPS as f32;
        }
        avg_pkts as u32
    }

    #[test]
    fn cubic_example() {
        let mut clock = FakeInstantCtx::default();
        let mut cubic = Cubic::<_, true /* FAST_CONVERGENCE */>::default();
        let mut params = CongestionControlParams::with_mss(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE);
        const RTT: Duration = Duration::from_millis(100);

        // Assert we have the correct initial window.
        assert_eq!(params.cwnd, 4 * u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE));

        // Slow start.
        clock.sleep(RTT);
        for _seg in 0..params.cwnd / u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE) {
            cubic.on_ack_u32(
                &mut params,
                u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE),
                clock.now(),
                RTT,
            );
        }
        assert_eq!(params.cwnd, 8 * u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE));

        clock.sleep(RTT);
        cubic.on_retransmission_timeout(&mut params);
        assert_eq!(params.cwnd, u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE));

        // We are now back in slow start.
        clock.sleep(RTT);
        cubic.on_ack_u32(
            &mut params,
            u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE),
            clock.now(),
            RTT,
        );
        assert_eq!(params.cwnd, 2 * u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE));

        clock.sleep(RTT);
        for _ in 0..2 {
            cubic.on_ack_u32(
                &mut params,
                u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE),
                clock.now(),
                RTT,
            );
        }
        assert_eq!(params.cwnd, 4 * u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE));

        // In this round trip, we enter a new congestion avoidance epoch from
        // slow start. Both the cubic and standard TCP equations have t=0, so
        // the cwnd in this round trip will be ssthresh, which is 3001 bytes,
        // or 5 full-sized segments after rounding.
        clock.sleep(RTT);
        for _seg in 0..params.cwnd / u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE) {
            cubic.on_ack_u32(
                &mut params,
                u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE),
                clock.now(),
                RTT,
            );
        }
        assert_eq!(
            u32::from(params.rounded_cwnd()),
            5 * u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE)
        );

        // Now we are at `epoch_start+RTT`; the window size should grow by at
        // least 1 MSS (DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE) per RTT, as in
        // standard TCP.
        clock.sleep(RTT);
        for _seg in 0..params.cwnd / u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE) {
            cubic.on_ack_u32(
                &mut params,
                u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE),
                clock.now(),
                RTT,
            );
        }
        assert_eq!(
            u32::from(params.rounded_cwnd()),
            6 * u32::from(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE)
        );
    }

    // This is a regression test for https://fxbug.dev/327628809.
    #[test_case(u32::MAX; "cwnd is u32::MAX")]
    #[test_case(u32::MAX - 1; "cwnd is u32::MAX - 1")]
    fn repro_overflow_b327628809(cwnd: u32) {
        let clock = FakeInstantCtx::default();
        let mut cubic = Cubic::<_, true /* FAST_CONVERGENCE */>::default();
        let mut params =
            CongestionControlParams { ssthresh: 0, cwnd, mss: DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE };
        const RTT: Duration = Duration::from_millis(100);

        cubic.on_ack(&mut params, NonZeroU32::MIN, clock.now(), RTT);
    }
}
418}