netstack3_tcp/
congestion.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Implements loss-based congestion control algorithms.
6//!
//! The currently implemented algorithms are CUBIC from [RFC 8312] and Reno-style
//! fast retransmit and fast recovery from [RFC 5681].
9//!
10//! [RFC 8312]: https://www.rfc-editor.org/rfc/rfc8312
11//! [RFC 5681]: https://www.rfc-editor.org/rfc/rfc5681
12
13mod cubic;
14
15use core::cmp::Ordering;
16use core::num::{NonZeroU32, NonZeroU8};
17use core::time::Duration;
18
19use netstack3_base::{Instant, Mss, SeqNum, WindowSize};
20
/// Number of duplicate ACKs that triggers fast retransmit.
///
/// Per RFC 5681 (<https://www.rfc-editor.org/rfc/rfc5681#section-3.2>):
///
/// > The fast retransmit algorithm uses the arrival of 3 duplicate ACKs (...)
/// > as an indication that a segment has been lost.
const DUP_ACK_THRESHOLD: u8 = 3;
25
26/// Holds the parameters of congestion control that are common to algorithms.
27#[derive(Debug)]
28struct CongestionControlParams {
29    /// Slow start threshold.
30    ssthresh: u32,
31    /// Congestion control window size, in bytes.
32    cwnd: u32,
33    /// Sender MSS.
34    mss: Mss,
35}
36
37impl CongestionControlParams {
38    fn with_mss(mss: Mss) -> Self {
39        let mss_u32 = u32::from(mss);
40        // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#page-5):
41        //   IW, the initial value of cwnd, MUST be set using the following
42        //   guidelines as an upper bound.
43        //   If SMSS > 2190 bytes:
44        //       IW = 2 * SMSS bytes and MUST NOT be more than 2 segments
45        //   If (SMSS > 1095 bytes) and (SMSS <= 2190 bytes):
46        //       IW = 3 * SMSS bytes and MUST NOT be more than 3 segments
47        //   if SMSS <= 1095 bytes:
48        //       IW = 4 * SMSS bytes and MUST NOT be more than 4 segments
49        let cwnd = if mss_u32 > 2190 {
50            mss_u32 * 2
51        } else if mss_u32 > 1095 {
52            mss_u32 * 3
53        } else {
54            mss_u32 * 4
55        };
56        Self { cwnd, ssthresh: u32::MAX, mss }
57    }
58
59    fn rounded_cwnd(&self) -> WindowSize {
60        let mss_u32 = u32::from(self.mss);
61        WindowSize::from_u32(self.cwnd / mss_u32 * mss_u32).unwrap_or(WindowSize::MAX)
62    }
63}
64
65/// Congestion control with four intertwined algorithms.
66///
67/// - Slow start
68/// - Congestion avoidance from a loss-based algorithm
69/// - Fast retransmit
70/// - Fast recovery: https://datatracker.ietf.org/doc/html/rfc5681#section-3
71#[derive(Debug)]
72pub(super) struct CongestionControl<I> {
73    params: CongestionControlParams,
74    algorithm: LossBasedAlgorithm<I>,
75    /// The connection is in fast recovery when this field is a [`Some`].
76    fast_recovery: Option<FastRecovery>,
77}
78
79/// Available congestion control algorithms.
80#[derive(Debug)]
81enum LossBasedAlgorithm<I> {
82    Cubic(cubic::Cubic<I, true /* FAST_CONVERGENCE */>),
83}
84
85impl<I: Instant> LossBasedAlgorithm<I> {
86    /// Called when there is a loss detected.
87    ///
88    /// Specifically, packet loss means
89    /// - either when the retransmission timer fired;
90    /// - or when we have received a certain amount of duplicate acks.
91    fn on_loss_detected(&mut self, params: &mut CongestionControlParams) {
92        match self {
93            LossBasedAlgorithm::Cubic(cubic) => cubic.on_loss_detected(params),
94        }
95    }
96
97    /// Called when we recovered from packet loss when receiving an ACK that
98    /// acknowledges new data.
99    fn on_loss_recovered(&mut self, params: &mut CongestionControlParams) {
100        // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
101        //   When the next ACK arrives that acknowledges previously
102        //   unacknowledged data, a TCP MUST set cwnd to ssthresh (the value
103        //   set in step 2).  This is termed "deflating" the window.
104        params.cwnd = params.ssthresh;
105    }
106
107    fn on_ack(
108        &mut self,
109        params: &mut CongestionControlParams,
110        bytes_acked: NonZeroU32,
111        now: I,
112        rtt: Duration,
113    ) {
114        match self {
115            LossBasedAlgorithm::Cubic(cubic) => cubic.on_ack(params, bytes_acked, now, rtt),
116        }
117    }
118
119    fn on_retransmission_timeout(&mut self, params: &mut CongestionControlParams) {
120        match self {
121            LossBasedAlgorithm::Cubic(cubic) => cubic.on_retransmission_timeout(params),
122        }
123    }
124}
125
126impl<I: Instant> CongestionControl<I> {
127    /// Called when there are previously unacknowledged bytes being acked.
128    pub(super) fn on_ack(&mut self, bytes_acked: NonZeroU32, now: I, rtt: Duration) {
129        let Self { params, algorithm, fast_recovery } = self;
130        // Exit fast recovery since there is an ACK that acknowledges new data.
131        if let Some(fast_recovery) = fast_recovery.take() {
132            if fast_recovery.dup_acks.get() >= DUP_ACK_THRESHOLD {
133                algorithm.on_loss_recovered(params);
134            }
135        };
136        algorithm.on_ack(params, bytes_acked, now, rtt);
137    }
138
139    /// Called when a duplicate ack is arrived.
140    ///
141    /// Returns `true` if fast recovery was initiated as a result of this ACK.
142    pub(super) fn on_dup_ack(&mut self, seg_ack: SeqNum) -> bool {
143        let Self { params, algorithm, fast_recovery } = self;
144        match fast_recovery {
145            None => {
146                *fast_recovery = Some(FastRecovery::new());
147                true
148            }
149            Some(fast_recovery) => {
150                fast_recovery.on_dup_ack(params, algorithm, seg_ack);
151                false
152            }
153        }
154    }
155
156    /// Called upon a retransmission timeout.
157    pub(super) fn on_retransmission_timeout(&mut self) {
158        let Self { params, algorithm, fast_recovery } = self;
159        *fast_recovery = None;
160        algorithm.on_retransmission_timeout(params);
161    }
162
163    /// Gets the current congestion window size in bytes.
164    ///
165    /// This normally just returns whatever value the loss-based algorithm tells
166    /// us, with the exception that in limited transmit case, the cwnd is
167    /// inflated by dup_ack_cnt * mss, to allow unsent data packets to enter the
168    /// network and trigger more duplicate ACKs to enter fast retransmit. Note
169    /// that this still conforms to the RFC because we don't change the cwnd of
170    /// our algorithm, the algorithm is not aware of this "inflation".
171    pub(super) fn cwnd(&self) -> WindowSize {
172        let Self { params, algorithm: _, fast_recovery } = self;
173        let cwnd = params.rounded_cwnd();
174        if let Some(fast_recovery) = fast_recovery {
175            // Per RFC 3042 (https://www.rfc-editor.org/rfc/rfc3042#section-2):
176            //   ... the Limited Transmit algorithm, which calls for a TCP
177            //   sender to transmit new data upon the arrival of the first two
178            //   consecutive duplicate ACKs ...
179            //   The amount of outstanding data would remain less than or equal
180            //   to the congestion window plus 2 segments.  In other words, the
181            //   sender can only send two segments beyond the congestion window
182            //   (cwnd).
183            // Note: We don't directly change cwnd in the loss-based algorithm
184            // because the RFC says one MUST NOT do that. We follow the
185            // requirement here by not changing the cwnd of the algorithm - if
186            // a new ACK is received after the two dup acks, the loss-based
187            // algorithm will continue to operate the same way as if the 2 SMSS
188            // is never added to cwnd.
189            if fast_recovery.dup_acks.get() < DUP_ACK_THRESHOLD {
190                return cwnd.saturating_add(
191                    u32::from(fast_recovery.dup_acks.get()) * u32::from(params.mss),
192                );
193            }
194        }
195        cwnd
196    }
197
198    pub(super) fn slow_start_threshold(&self) -> u32 {
199        self.params.ssthresh
200    }
201
202    /// Returns the starting sequence number of the segment that needs to be
203    /// retransmitted, if any.
204    pub(super) fn fast_retransmit(&mut self) -> Option<SeqNum> {
205        self.fast_recovery.as_mut().and_then(|r| r.fast_retransmit.take())
206    }
207
208    pub(super) fn cubic_with_mss(mss: Mss) -> Self {
209        Self {
210            params: CongestionControlParams::with_mss(mss),
211            algorithm: LossBasedAlgorithm::Cubic(Default::default()),
212            fast_recovery: None,
213        }
214    }
215
216    pub(super) fn mss(&self) -> Mss {
217        self.params.mss
218    }
219
220    /// Returns true if this [`CongestionControl`] is in fast recovery.
221    pub(super) fn in_fast_recovery(&self) -> bool {
222        self.fast_recovery.is_some()
223    }
224
225    /// Returns true if this [`CongestionControl`] is in slow start.
226    pub(super) fn in_slow_start(&self) -> bool {
227        self.params.cwnd < self.params.ssthresh
228    }
229}
230
231/// Reno style Fast Recovery algorithm as described in
232/// [RFC 5681](https://tools.ietf.org/html/rfc5681).
233#[derive(Debug)]
234pub struct FastRecovery {
235    /// Holds the sequence number of the segment to fast retransmit, if any.
236    fast_retransmit: Option<SeqNum>,
237    /// The running count of consecutive duplicate ACKs we have received so far.
238    ///
239    /// Here we limit the maximum number of duplicate ACKS we track to 255, as
240    /// per a note in the RFC:
241    ///
242    /// Note: [SCWA99] discusses a receiver-based attack whereby many
243    /// bogus duplicate ACKs are sent to the data sender in order to
244    /// artificially inflate cwnd and cause a higher than appropriate
245    /// sending rate to be used.  A TCP MAY therefore limit the number of
246    /// times cwnd is artificially inflated during loss recovery to the
247    /// number of outstanding segments (or, an approximation thereof).
248    ///
249    /// [SCWA99]: https://homes.cs.washington.edu/~tom/pubs/CCR99.pdf
250    dup_acks: NonZeroU8,
251}
252
253impl FastRecovery {
254    fn new() -> Self {
255        Self { dup_acks: NonZeroU8::new(1).unwrap(), fast_retransmit: None }
256    }
257
258    fn on_dup_ack<I: Instant>(
259        &mut self,
260        params: &mut CongestionControlParams,
261        loss_based: &mut LossBasedAlgorithm<I>,
262        seg_ack: SeqNum,
263    ) {
264        self.dup_acks = self.dup_acks.saturating_add(1);
265
266        match self.dup_acks.get().cmp(&DUP_ACK_THRESHOLD) {
267            Ordering::Less => {}
268            Ordering::Equal => {
269                loss_based.on_loss_detected(params);
270                // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
271                //   The lost segment starting at SND.UNA MUST be retransmitted
272                //   and cwnd set to ssthresh plus 3*SMSS.  This artificially
273                //   "inflates" the congestion window by the number of segments
274                //   (three) that have left the network and which the receiver
275                //   has buffered.
276                self.fast_retransmit = Some(seg_ack);
277                params.cwnd =
278                    params.ssthresh + u32::from(DUP_ACK_THRESHOLD) * u32::from(params.mss);
279            }
280            Ordering::Greater => {
281                // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
282                //   For each additional duplicate ACK received (after the third),
283                //   cwnd MUST be incremented by SMSS. This artificially inflates
284                //   the congestion window in order to reflect the additional
285                //   segment that has left the network.
286                params.cwnd += u32::from(params.mss);
287            }
288        }
289    }
290}
291
#[cfg(test)]
mod test {
    use netstack3_base::testutil::FakeInstant;

    use super::*;
    use crate::internal::base::testutil::DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE;

    /// A single duplicate ACK followed by an ACK of new data must not count as
    /// loss recovery. Otherwise cwnd would be deflated to the initial
    /// ssthresh (`u32::MAX`) and then overflow.
    #[test]
    fn no_recovery_before_reaching_threshold() {
        let mut cc = CongestionControl::cubic_with_mss(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE);
        let initial_cwnd = cc.params.cwnd;
        assert_eq!(cc.params.ssthresh, u32::MAX);

        // The first duplicate ACK starts fast recovery.
        assert!(cc.on_dup_ack(SeqNum::new(0)));

        let now = FakeInstant::from(Duration::from_secs(0));
        let rtt = Duration::from_secs(1);
        cc.on_ack(NonZeroU32::new(1).unwrap(), now, rtt);

        // Only the single acked byte is added; no deflation to ssthresh took
        // place.
        assert_eq!(cc.params.cwnd, initial_cwnd + 1);
    }
}