netstack3_tcp/congestion.rs
1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Implements loss-based congestion control algorithms.
6//!
7//! The currently implemented algorithms are CUBIC from [RFC 8312] and RENO
8//! style fast retransmit and fast recovery from [RFC 5681].
9//!
10//! [RFC 8312]: https://www.rfc-editor.org/rfc/rfc8312
11//! [RFC 5681]: https://www.rfc-editor.org/rfc/rfc5681
12
13mod cubic;
14
15use core::cmp::Ordering;
16use core::num::{NonZeroU32, NonZeroU8};
17use core::time::Duration;
18
19use netstack3_base::{Instant, Mss, SeqNum, WindowSize};
20
/// Number of consecutive duplicate ACKs that triggers fast retransmit.
///
/// Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
/// The fast retransmit algorithm uses the arrival of 3 duplicate ACKs (...)
/// as an indication that a segment has been lost.
const DUP_ACK_THRESHOLD: u8 = 3;
25
/// Holds the parameters of congestion control that are common to algorithms.
///
/// Both the loss-based algorithm and the fast recovery logic receive this
/// struct by `&mut` reference and update it in place.
#[derive(Debug)]
struct CongestionControlParams {
    /// Slow start threshold, in bytes.
    ///
    /// Initialized to `u32::MAX` by `with_mss`.
    ssthresh: u32,
    /// Congestion control window size, in bytes.
    cwnd: u32,
    /// Sender MSS.
    mss: Mss,
}
36
37impl CongestionControlParams {
38 fn with_mss(mss: Mss) -> Self {
39 let mss_u32 = u32::from(mss);
40 // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#page-5):
41 // IW, the initial value of cwnd, MUST be set using the following
42 // guidelines as an upper bound.
43 // If SMSS > 2190 bytes:
44 // IW = 2 * SMSS bytes and MUST NOT be more than 2 segments
45 // If (SMSS > 1095 bytes) and (SMSS <= 2190 bytes):
46 // IW = 3 * SMSS bytes and MUST NOT be more than 3 segments
47 // if SMSS <= 1095 bytes:
48 // IW = 4 * SMSS bytes and MUST NOT be more than 4 segments
49 let cwnd = if mss_u32 > 2190 {
50 mss_u32 * 2
51 } else if mss_u32 > 1095 {
52 mss_u32 * 3
53 } else {
54 mss_u32 * 4
55 };
56 Self { cwnd, ssthresh: u32::MAX, mss }
57 }
58
59 fn rounded_cwnd(&self) -> WindowSize {
60 let mss_u32 = u32::from(self.mss);
61 WindowSize::from_u32(self.cwnd / mss_u32 * mss_u32).unwrap_or(WindowSize::MAX)
62 }
63}
64
/// Congestion control with four intertwined algorithms.
///
/// - Slow start
/// - Congestion avoidance from a loss-based algorithm
/// - Fast retransmit
/// - Fast recovery: https://datatracker.ietf.org/doc/html/rfc5681#section-3
#[derive(Debug)]
pub(super) struct CongestionControl<I> {
    /// Window, slow start threshold and MSS, shared with the loss-based
    /// algorithm and the fast recovery logic.
    params: CongestionControlParams,
    /// The loss-based algorithm that reacts to ACKs, detected losses and
    /// retransmission timeouts.
    algorithm: LossBasedAlgorithm<I>,
    /// The connection is in fast recovery when this field is a [`Some`].
    ///
    /// It becomes `Some` on the first duplicate ACK. It is reset to `None`
    /// when new data is acknowledged or when the retransmission timer fires.
    fast_recovery: Option<FastRecovery>,
}
78
/// Available congestion control algorithms.
///
/// CUBIC is currently the only variant. It is instantiated with its
/// `FAST_CONVERGENCE` const parameter set to `true`.
#[derive(Debug)]
enum LossBasedAlgorithm<I> {
    Cubic(cubic::Cubic<I, true /* FAST_CONVERGENCE */>),
}
84
85impl<I: Instant> LossBasedAlgorithm<I> {
86 /// Called when there is a loss detected.
87 ///
88 /// Specifically, packet loss means
89 /// - either when the retransmission timer fired;
90 /// - or when we have received a certain amount of duplicate acks.
91 fn on_loss_detected(&mut self, params: &mut CongestionControlParams) {
92 match self {
93 LossBasedAlgorithm::Cubic(cubic) => cubic.on_loss_detected(params),
94 }
95 }
96
97 /// Called when we recovered from packet loss when receiving an ACK that
98 /// acknowledges new data.
99 fn on_loss_recovered(&mut self, params: &mut CongestionControlParams) {
100 // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
101 // When the next ACK arrives that acknowledges previously
102 // unacknowledged data, a TCP MUST set cwnd to ssthresh (the value
103 // set in step 2). This is termed "deflating" the window.
104 params.cwnd = params.ssthresh;
105 }
106
107 fn on_ack(
108 &mut self,
109 params: &mut CongestionControlParams,
110 bytes_acked: NonZeroU32,
111 now: I,
112 rtt: Duration,
113 ) {
114 match self {
115 LossBasedAlgorithm::Cubic(cubic) => cubic.on_ack(params, bytes_acked, now, rtt),
116 }
117 }
118
119 fn on_retransmission_timeout(&mut self, params: &mut CongestionControlParams) {
120 match self {
121 LossBasedAlgorithm::Cubic(cubic) => cubic.on_retransmission_timeout(params),
122 }
123 }
124}
125
126impl<I: Instant> CongestionControl<I> {
127 /// Called when there are previously unacknowledged bytes being acked.
128 pub(super) fn on_ack(&mut self, bytes_acked: NonZeroU32, now: I, rtt: Duration) {
129 let Self { params, algorithm, fast_recovery } = self;
130 // Exit fast recovery since there is an ACK that acknowledges new data.
131 if let Some(fast_recovery) = fast_recovery.take() {
132 if fast_recovery.dup_acks.get() >= DUP_ACK_THRESHOLD {
133 algorithm.on_loss_recovered(params);
134 }
135 };
136 algorithm.on_ack(params, bytes_acked, now, rtt);
137 }
138
139 /// Called when a duplicate ack is arrived.
140 ///
141 /// Returns `true` if fast recovery was initiated as a result of this ACK.
142 pub(super) fn on_dup_ack(&mut self, seg_ack: SeqNum) -> bool {
143 let Self { params, algorithm, fast_recovery } = self;
144 match fast_recovery {
145 None => {
146 *fast_recovery = Some(FastRecovery::new());
147 true
148 }
149 Some(fast_recovery) => {
150 fast_recovery.on_dup_ack(params, algorithm, seg_ack);
151 false
152 }
153 }
154 }
155
156 /// Called upon a retransmission timeout.
157 pub(super) fn on_retransmission_timeout(&mut self) {
158 let Self { params, algorithm, fast_recovery } = self;
159 *fast_recovery = None;
160 algorithm.on_retransmission_timeout(params);
161 }
162
163 /// Gets the current congestion window size in bytes.
164 ///
165 /// This normally just returns whatever value the loss-based algorithm tells
166 /// us, with the exception that in limited transmit case, the cwnd is
167 /// inflated by dup_ack_cnt * mss, to allow unsent data packets to enter the
168 /// network and trigger more duplicate ACKs to enter fast retransmit. Note
169 /// that this still conforms to the RFC because we don't change the cwnd of
170 /// our algorithm, the algorithm is not aware of this "inflation".
171 pub(super) fn cwnd(&self) -> WindowSize {
172 let Self { params, algorithm: _, fast_recovery } = self;
173 let cwnd = params.rounded_cwnd();
174 if let Some(fast_recovery) = fast_recovery {
175 // Per RFC 3042 (https://www.rfc-editor.org/rfc/rfc3042#section-2):
176 // ... the Limited Transmit algorithm, which calls for a TCP
177 // sender to transmit new data upon the arrival of the first two
178 // consecutive duplicate ACKs ...
179 // The amount of outstanding data would remain less than or equal
180 // to the congestion window plus 2 segments. In other words, the
181 // sender can only send two segments beyond the congestion window
182 // (cwnd).
183 // Note: We don't directly change cwnd in the loss-based algorithm
184 // because the RFC says one MUST NOT do that. We follow the
185 // requirement here by not changing the cwnd of the algorithm - if
186 // a new ACK is received after the two dup acks, the loss-based
187 // algorithm will continue to operate the same way as if the 2 SMSS
188 // is never added to cwnd.
189 if fast_recovery.dup_acks.get() < DUP_ACK_THRESHOLD {
190 return cwnd.saturating_add(
191 u32::from(fast_recovery.dup_acks.get()) * u32::from(params.mss),
192 );
193 }
194 }
195 cwnd
196 }
197
198 pub(super) fn slow_start_threshold(&self) -> u32 {
199 self.params.ssthresh
200 }
201
202 /// Returns the starting sequence number of the segment that needs to be
203 /// retransmitted, if any.
204 pub(super) fn fast_retransmit(&mut self) -> Option<SeqNum> {
205 self.fast_recovery.as_mut().and_then(|r| r.fast_retransmit.take())
206 }
207
208 pub(super) fn cubic_with_mss(mss: Mss) -> Self {
209 Self {
210 params: CongestionControlParams::with_mss(mss),
211 algorithm: LossBasedAlgorithm::Cubic(Default::default()),
212 fast_recovery: None,
213 }
214 }
215
216 pub(super) fn mss(&self) -> Mss {
217 self.params.mss
218 }
219
220 /// Returns true if this [`CongestionControl`] is in fast recovery.
221 pub(super) fn in_fast_recovery(&self) -> bool {
222 self.fast_recovery.is_some()
223 }
224
225 /// Returns true if this [`CongestionControl`] is in slow start.
226 pub(super) fn in_slow_start(&self) -> bool {
227 self.params.cwnd < self.params.ssthresh
228 }
229}
230
231/// Reno style Fast Recovery algorithm as described in
232/// [RFC 5681](https://tools.ietf.org/html/rfc5681).
233#[derive(Debug)]
234pub struct FastRecovery {
235 /// Holds the sequence number of the segment to fast retransmit, if any.
236 fast_retransmit: Option<SeqNum>,
237 /// The running count of consecutive duplicate ACKs we have received so far.
238 ///
239 /// Here we limit the maximum number of duplicate ACKS we track to 255, as
240 /// per a note in the RFC:
241 ///
242 /// Note: [SCWA99] discusses a receiver-based attack whereby many
243 /// bogus duplicate ACKs are sent to the data sender in order to
244 /// artificially inflate cwnd and cause a higher than appropriate
245 /// sending rate to be used. A TCP MAY therefore limit the number of
246 /// times cwnd is artificially inflated during loss recovery to the
247 /// number of outstanding segments (or, an approximation thereof).
248 ///
249 /// [SCWA99]: https://homes.cs.washington.edu/~tom/pubs/CCR99.pdf
250 dup_acks: NonZeroU8,
251}
252
253impl FastRecovery {
254 fn new() -> Self {
255 Self { dup_acks: NonZeroU8::new(1).unwrap(), fast_retransmit: None }
256 }
257
258 fn on_dup_ack<I: Instant>(
259 &mut self,
260 params: &mut CongestionControlParams,
261 loss_based: &mut LossBasedAlgorithm<I>,
262 seg_ack: SeqNum,
263 ) {
264 self.dup_acks = self.dup_acks.saturating_add(1);
265
266 match self.dup_acks.get().cmp(&DUP_ACK_THRESHOLD) {
267 Ordering::Less => {}
268 Ordering::Equal => {
269 loss_based.on_loss_detected(params);
270 // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
271 // The lost segment starting at SND.UNA MUST be retransmitted
272 // and cwnd set to ssthresh plus 3*SMSS. This artificially
273 // "inflates" the congestion window by the number of segments
274 // (three) that have left the network and which the receiver
275 // has buffered.
276 self.fast_retransmit = Some(seg_ack);
277 params.cwnd =
278 params.ssthresh + u32::from(DUP_ACK_THRESHOLD) * u32::from(params.mss);
279 }
280 Ordering::Greater => {
281 // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
282 // For each additional duplicate ACK received (after the third),
283 // cwnd MUST be incremented by SMSS. This artificially inflates
284 // the congestion window in order to reflect the additional
285 // segment that has left the network.
286 params.cwnd += u32::from(params.mss);
287 }
288 }
289 }
290}
291
#[cfg(test)]
mod test {
    use netstack3_base::testutil::FakeInstant;

    use super::*;
    use crate::internal::base::testutil::DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE;

    /// A single duplicate ACK followed by a new ACK must not be treated as
    /// recovery from a loss.
    #[test]
    fn no_recovery_before_reaching_threshold() {
        let mut cc = CongestionControl::cubic_with_mss(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE);
        let cwnd_before = cc.params.cwnd;
        assert_eq!(cc.params.ssthresh, u32::MAX);

        // The first duplicate ACK begins tracking fast recovery.
        let started = cc.on_dup_ack(SeqNum::new(0));
        assert!(started);

        let one_byte = NonZeroU32::new(1).unwrap();
        let now = FakeInstant::from(Duration::from_secs(0));
        let rtt = Duration::from_secs(1);
        cc.on_ack(one_byte, now, rtt);

        // Only one duplicate ACK was seen, so the new ACK must not deflate
        // cwnd to the initial ssthresh (u32::MAX), which would then overflow.
        assert_eq!(cc.params.cwnd, cwnd_before + 1);
    }
}