netstack3_tcp/
congestion.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
// Copyright 2022 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! Implements loss-based congestion control algorithms.
//!
//! The currently implemented algorithms are CUBIC from [RFC 8312] and RENO
//! style fast retransmit and fast recovery from [RFC 5681].
//!
//! [RFC 8312]: https://www.rfc-editor.org/rfc/rfc8312
//! [RFC 5681]: https://www.rfc-editor.org/rfc/rfc5681

mod cubic;

use core::cmp::Ordering;
use core::num::{NonZeroU32, NonZeroU8};
use core::time::Duration;

use netstack3_base::{Instant, Mss, SeqNum, WindowSize};

/// Number of consecutive duplicate ACKs that is treated as a loss signal.
///
/// Per [RFC 5681 section 3.2](https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
///
/// > The fast retransmit algorithm uses the arrival of 3 duplicate ACKs (...)
/// > as an indication that a segment has been lost.
const DUP_ACK_THRESHOLD: u8 = 3;

/// Holds the parameters of congestion control that are common to algorithms.
///
/// The loss-based algorithm and the fast recovery state both receive a
/// mutable reference to this struct and update it in place.
#[derive(Debug)]
struct CongestionControlParams {
    /// Slow start threshold, in bytes.
    ///
    /// Starts out as `u32::MAX`; the connection is considered to be in slow
    /// start while `cwnd` is below this value.
    ssthresh: u32,
    /// Congestion control window size, in bytes.
    ///
    /// This value is not necessarily a multiple of `mss`; see
    /// `rounded_cwnd` for the MSS-aligned view.
    cwnd: u32,
    /// Sender MSS.
    mss: Mss,
}

impl CongestionControlParams {
    /// Builds the initial parameters for a connection whose sender MSS is
    /// `mss`.
    ///
    /// The slow start threshold starts at `u32::MAX` and the congestion window
    /// starts at the initial window (IW) allowed for the given MSS.
    fn with_mss(mss: Mss) -> Self {
        let smss = u32::from(mss);
        // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#page-5):
        //   IW, the initial value of cwnd, MUST be set using the following
        //   guidelines as an upper bound.
        //   If SMSS > 2190 bytes:
        //       IW = 2 * SMSS bytes and MUST NOT be more than 2 segments
        //   If (SMSS > 1095 bytes) and (SMSS <= 2190 bytes):
        //       IW = 3 * SMSS bytes and MUST NOT be more than 3 segments
        //   if SMSS <= 1095 bytes:
        //       IW = 4 * SMSS bytes and MUST NOT be more than 4 segments
        let initial_segments: u32 = match smss {
            0..=1095 => 4,
            1096..=2190 => 3,
            _ => 2,
        };
        Self { ssthresh: u32::MAX, cwnd: smss * initial_segments, mss }
    }

    /// Returns the congestion window rounded down to a whole number of
    /// MSS-sized segments.
    ///
    /// Falls back to `WindowSize::MAX` when the rounded value cannot be
    /// represented as a `WindowSize`.
    fn rounded_cwnd(&self) -> WindowSize {
        let smss = u32::from(self.mss);
        // Dropping the remainder is the same as dividing and multiplying back.
        let whole_segments = self.cwnd - self.cwnd % smss;
        WindowSize::from_u32(whole_segments).unwrap_or(WindowSize::MAX)
    }
}

/// Congestion control with four intertwined algorithms, as described in
/// [RFC 5681 section 3](https://datatracker.ietf.org/doc/html/rfc5681#section-3):
///
/// - Slow start
/// - Congestion avoidance from a loss-based algorithm
/// - Fast retransmit
/// - Fast recovery
///
/// `I` is the instant type handed to the loss-based algorithm on every ACK.
#[derive(Debug)]
pub(super) struct CongestionControl<I> {
    /// Parameters (`cwnd`, `ssthresh`, MSS) shared between the loss-based
    /// algorithm and the fast recovery logic.
    params: CongestionControlParams,
    /// The loss-based algorithm driving slow start and congestion avoidance.
    algorithm: LossBasedAlgorithm<I>,
    /// The connection is in fast recovery when this field is a [`Some`].
    fast_recovery: Option<FastRecovery>,
}

/// Available congestion control algorithms.
///
/// Every method on this type dispatches to the selected algorithm; CUBIC is
/// currently the only one.
#[derive(Debug)]
enum LossBasedAlgorithm<I> {
    /// CUBIC, instantiated with its fast convergence option enabled.
    Cubic(cubic::Cubic<I, true /* FAST_CONVERGENCE */>),
}

impl<I: Instant> LossBasedAlgorithm<I> {
    /// Called when there is a loss detected.
    ///
    /// Specifically, packet loss means
    /// - either when the retransmission timer fired;
    /// - or when we have received a certain amount of duplicate acks.
    fn on_loss_detected(&mut self, params: &mut CongestionControlParams) {
        match self {
            LossBasedAlgorithm::Cubic(cubic) => cubic.on_loss_detected(params),
        }
    }

    /// Called when we recovered from packet loss when receiving an ACK that
    /// acknowledges new data.
    fn on_loss_recovered(&mut self, params: &mut CongestionControlParams) {
        // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
        //   When the next ACK arrives that acknowledges previously
        //   unacknowledged data, a TCP MUST set cwnd to ssthresh (the value
        //   set in step 2).  This is termed "deflating" the window.
        params.cwnd = params.ssthresh;
    }

    fn on_ack(
        &mut self,
        params: &mut CongestionControlParams,
        bytes_acked: NonZeroU32,
        now: I,
        rtt: Duration,
    ) {
        match self {
            LossBasedAlgorithm::Cubic(cubic) => cubic.on_ack(params, bytes_acked, now, rtt),
        }
    }

    fn on_retransmission_timeout(&mut self, params: &mut CongestionControlParams) {
        match self {
            LossBasedAlgorithm::Cubic(cubic) => cubic.on_retransmission_timeout(params),
        }
    }
}

impl<I: Instant> CongestionControl<I> {
    /// Handles an ACK that acknowledges `bytes_acked` previously
    /// unacknowledged bytes.
    ///
    /// Any fast recovery state is dropped. If the duplicate ACK threshold had
    /// been reached, the window is deflated before the ACK is handed to the
    /// loss-based algorithm.
    pub(super) fn on_ack(&mut self, bytes_acked: NonZeroU32, now: I, rtt: Duration) {
        // An ACK for new data always ends fast recovery, but the loss only
        // needs "recovering" from if it was actually declared.
        let loss_was_declared = matches!(
            self.fast_recovery.take(),
            Some(recovery) if recovery.dup_acks.get() >= DUP_ACK_THRESHOLD
        );
        if loss_was_declared {
            self.algorithm.on_loss_recovered(&mut self.params);
        }
        self.algorithm.on_ack(&mut self.params, bytes_acked, now, rtt);
    }

    /// Handles the arrival of a duplicate ACK.
    ///
    /// Returns `true` if fast recovery was initiated as a result of this ACK.
    pub(super) fn on_dup_ack(&mut self, seg_ack: SeqNum) -> bool {
        if let Some(recovery) = self.fast_recovery.as_mut() {
            recovery.on_dup_ack(&mut self.params, &mut self.algorithm, seg_ack);
            return false;
        }
        // First duplicate ACK: start tracking fast recovery state.
        self.fast_recovery = Some(FastRecovery::new());
        true
    }

    /// Handles a retransmission timeout.
    ///
    /// Fast recovery state is discarded before the loss-based algorithm is
    /// notified.
    pub(super) fn on_retransmission_timeout(&mut self) {
        self.fast_recovery = None;
        self.algorithm.on_retransmission_timeout(&mut self.params);
    }

    /// Gets the current congestion window size in bytes.
    ///
    /// This normally just returns whatever value the loss-based algorithm tells
    /// us, with the exception that in limited transmit case, the cwnd is
    /// inflated by dup_ack_cnt * mss, to allow unsent data packets to enter the
    /// network and trigger more duplicate ACKs to enter fast retransmit. Note
    /// that this still conforms to the RFC because we don't change the cwnd of
    /// our algorithm, the algorithm is not aware of this "inflation".
    pub(super) fn cwnd(&self) -> WindowSize {
        let base = self.params.rounded_cwnd();
        match &self.fast_recovery {
            // Per RFC 3042 (https://www.rfc-editor.org/rfc/rfc3042#section-2):
            //   ... the Limited Transmit algorithm, which calls for a TCP
            //   sender to transmit new data upon the arrival of the first two
            //   consecutive duplicate ACKs ...
            //   The amount of outstanding data would remain less than or equal
            //   to the congestion window plus 2 segments.  In other words, the
            //   sender can only send two segments beyond the congestion window
            //   (cwnd).
            // Note: We don't directly change cwnd in the loss-based algorithm
            // because the RFC says one MUST NOT do that. We follow the
            // requirement here by not changing the cwnd of the algorithm - if
            // a new ACK is received after the two dup acks, the loss-based
            // algorithm will continue to operate the same way as if the 2 SMSS
            // is never added to cwnd.
            Some(recovery) if recovery.dup_acks.get() < DUP_ACK_THRESHOLD => {
                let dup_acks = u32::from(recovery.dup_acks.get());
                base.saturating_add(dup_acks * u32::from(self.params.mss))
            }
            Some(_) | None => base,
        }
    }

    /// Returns the starting sequence number of the segment that needs to be
    /// retransmitted, if any.
    ///
    /// The pending retransmission is consumed: a second call returns `None`
    /// until another fast retransmit is scheduled.
    pub(super) fn fast_retransmit(&mut self) -> Option<SeqNum> {
        let recovery = self.fast_recovery.as_mut()?;
        recovery.fast_retransmit.take()
    }

    /// Creates a CUBIC-based congestion control state for the given `mss`,
    /// not in fast recovery.
    pub(super) fn cubic_with_mss(mss: Mss) -> Self {
        let params = CongestionControlParams::with_mss(mss);
        let algorithm = LossBasedAlgorithm::Cubic(Default::default());
        Self { params, algorithm, fast_recovery: None }
    }

    /// Returns the sender MSS used by this [`CongestionControl`].
    pub(super) fn mss(&self) -> Mss {
        let Self { params, .. } = self;
        params.mss
    }

    /// Returns true if this [`CongestionControl`] is in fast recovery.
    pub(super) fn in_fast_recovery(&self) -> bool {
        let Self { fast_recovery, .. } = self;
        fast_recovery.is_some()
    }

    /// Returns true if this [`CongestionControl`] is in slow start.
    pub(super) fn in_slow_start(&self) -> bool {
        let params = &self.params;
        params.ssthresh > params.cwnd
    }
}

/// Reno style Fast Recovery algorithm as described in
/// [RFC 5681](https://tools.ietf.org/html/rfc5681).
///
/// The state is created on the first duplicate ACK, so `dup_acks` is never
/// zero.
#[derive(Debug)]
pub struct FastRecovery {
    /// Holds the sequence number of the segment to fast retransmit, if any.
    fast_retransmit: Option<SeqNum>,
    /// The running count of consecutive duplicate ACKs we have received so far.
    ///
    /// The counter saturates, so the maximum number of duplicate ACKs we
    /// track is 255. The motivation is this note in the RFC:
    ///
    /// > Note: [SCWA99] discusses a receiver-based attack whereby many
    /// > bogus duplicate ACKs are sent to the data sender in order to
    /// > artificially inflate cwnd and cause a higher than appropriate
    /// > sending rate to be used.  A TCP MAY therefore limit the number of
    /// > times cwnd is artificially inflated during loss recovery to the
    /// > number of outstanding segments (or, an approximation thereof).
    ///
    /// [SCWA99]: https://homes.cs.washington.edu/~tom/pubs/CCR99.pdf
    dup_acks: NonZeroU8,
}

impl FastRecovery {
    /// Creates the state for a connection that has just observed its first
    /// duplicate ACK; no fast retransmit is pending yet.
    fn new() -> Self {
        Self { dup_acks: NonZeroU8::new(1).unwrap(), fast_retransmit: None }
    }

    /// Records one more duplicate ACK and adjusts `params` accordingly.
    ///
    /// - Below [`DUP_ACK_THRESHOLD`]: only the counter is updated.
    /// - Exactly at the threshold: `loss_based` is told about the loss, the
    ///   segment starting at `seg_ack` is scheduled for fast retransmit and
    ///   `cwnd` is set to `ssthresh` plus three segments.
    /// - Above the threshold: `cwnd` is inflated by one more segment.
    ///
    /// All window arithmetic saturates at `u32::MAX`. The duplicate ACK
    /// counter itself saturates at 255, but the window is still inflated for
    /// every further duplicate ACK, so a peer sending a long run of duplicate
    /// ACKs must not be able to overflow `cwnd`.
    fn on_dup_ack<I: Instant>(
        &mut self,
        params: &mut CongestionControlParams,
        loss_based: &mut LossBasedAlgorithm<I>,
        seg_ack: SeqNum,
    ) {
        self.dup_acks = self.dup_acks.saturating_add(1);

        match self.dup_acks.get().cmp(&DUP_ACK_THRESHOLD) {
            Ordering::Less => {}
            Ordering::Equal => {
                loss_based.on_loss_detected(params);
                // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
                //   The lost segment starting at SND.UNA MUST be retransmitted
                //   and cwnd set to ssthresh plus 3*SMSS.  This artificially
                //   "inflates" the congestion window by the number of segments
                //   (three) that have left the network and which the receiver
                //   has buffered.
                self.fast_retransmit = Some(seg_ack);
                let inflation =
                    u32::from(DUP_ACK_THRESHOLD).saturating_mul(u32::from(params.mss));
                // Saturate rather than overflow in case ssthresh is still
                // close to its initial value of `u32::MAX`.
                params.cwnd = params.ssthresh.saturating_add(inflation);
            }
            Ordering::Greater => {
                // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
                //   For each additional duplicate ACK received (after the third),
                //   cwnd MUST be incremented by SMSS. This artificially inflates
                //   the congestion window in order to reflect the additional
                //   segment that has left the network.
                //
                // This arm keeps running after `dup_acks` has saturated, so the
                // addition has to saturate as well.
                params.cwnd = params.cwnd.saturating_add(u32::from(params.mss));
            }
        }
    }
}

#[cfg(test)]
mod test {
    use netstack3_base::testutil::FakeInstant;

    use super::*;
    use crate::internal::base::testutil::DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE;

    /// A single duplicate ACK followed by an ACK for new data must not be
    /// treated as a recovered loss.
    #[test]
    fn no_recovery_before_reaching_threshold() {
        let mut cc = CongestionControl::cubic_with_mss(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE);
        let initial_cwnd = cc.params.cwnd;
        assert_eq!(cc.params.ssthresh, u32::MAX);

        // The very first duplicate ACK starts tracking fast recovery state.
        let started_tracking = cc.on_dup_ack(SeqNum::new(0));
        assert!(started_tracking);

        let bytes_acked = NonZeroU32::new(1).unwrap();
        let now = FakeInstant::from(Duration::from_secs(0));
        let rtt = Duration::from_secs(1);
        cc.on_ack(bytes_acked, now, rtt);

        // We have only received one duplicate ack, receiving a new ACK should
        // not mean "loss recovery" - we should not bump our cwnd to initial
        // ssthresh (u32::MAX) and then overflow.
        assert_eq!(initial_cwnd + 1, cc.params.cwnd);
    }
}