netstack3_tcp/congestion.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312
// Copyright 2022 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//! Implements loss-based congestion control algorithms.
//!
//! The currently implemented algorithms are CUBIC from [RFC 8312] and
//! Reno-style fast retransmit and fast recovery from [RFC 5681].
//!
//! [RFC 8312]: https://www.rfc-editor.org/rfc/rfc8312
//! [RFC 5681]: https://www.rfc-editor.org/rfc/rfc5681
mod cubic;
use core::cmp::Ordering;
use core::num::{NonZeroU32, NonZeroU8};
use core::time::Duration;
use netstack3_base::{Instant, Mss, SeqNum, WindowSize};
/// Number of duplicate ACKs at which loss is declared and fast retransmit
/// is triggered.
///
/// Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
///   The fast retransmit algorithm uses the arrival of 3 duplicate ACKs (...)
///   as an indication that a segment has been lost.
const DUP_ACK_THRESHOLD: u8 = 3;
/// Holds the parameters of congestion control that are common to algorithms.
///
/// Both the loss-based algorithm and the fast recovery logic read and update
/// these values in place.
#[derive(Debug)]
struct CongestionControlParams {
    /// Slow start threshold, in bytes.
    ///
    /// The connection is considered to be in slow start while
    /// `cwnd < ssthresh`.
    ssthresh: u32,
    /// Congestion control window size, in bytes.
    cwnd: u32,
    /// Sender maximum segment size (SMSS).
    mss: Mss,
}
impl CongestionControlParams {
    /// Builds the parameters for a new connection whose sender MSS is `mss`.
    ///
    /// `ssthresh` starts at `u32::MAX`, so the connection begins in slow
    /// start. `cwnd` starts at the RFC 5681 initial window, which is 2, 3 or
    /// 4 segments depending on how large `mss` is.
    fn with_mss(mss: Mss) -> Self {
        let segment_size = u32::from(mss);
        // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#page-5):
        //   IW, the initial value of cwnd, MUST be set using the following
        //   guidelines as an upper bound.
        //   If SMSS > 2190 bytes:
        //     IW = 2 * SMSS bytes and MUST NOT be more than 2 segments
        //   If (SMSS > 1095 bytes) and (SMSS <= 2190 bytes):
        //     IW = 3 * SMSS bytes and MUST NOT be more than 3 segments
        //   if SMSS <= 1095 bytes:
        //     IW = 4 * SMSS bytes and MUST NOT be more than 4 segments
        let initial_segments: u32 = match segment_size {
            size if size > 2190 => 2,
            size if size > 1095 => 3,
            _ => 4,
        };
        Self { ssthresh: u32::MAX, cwnd: segment_size * initial_segments, mss }
    }

    /// Returns `cwnd` rounded down to a whole number of MSS-sized segments.
    ///
    /// If `WindowSize::from_u32` rejects the rounded value (presumably
    /// because it exceeds the largest representable window), this returns
    /// `WindowSize::MAX` instead.
    fn rounded_cwnd(&self) -> WindowSize {
        let segment_size = u32::from(self.mss);
        // Drop the partial trailing segment.
        let whole_segment_bytes = self.cwnd - self.cwnd % segment_size;
        WindowSize::from_u32(whole_segment_bytes).unwrap_or(WindowSize::MAX)
    }
}
/// Congestion control with four intertwined algorithms.
///
/// - Slow start
/// - Congestion avoidance from a loss-based algorithm
/// - Fast retransmit
/// - Fast recovery
///
/// See RFC 5681 section 3:
/// https://datatracker.ietf.org/doc/html/rfc5681#section-3
#[derive(Debug)]
pub(super) struct CongestionControl<I> {
    /// Window parameters (cwnd, ssthresh, MSS) shared by all algorithms.
    params: CongestionControlParams,
    /// The loss-based algorithm that reacts to new ACKs, detected losses and
    /// retransmission timeouts.
    algorithm: LossBasedAlgorithm<I>,
    /// The connection is in fast recovery when this field is a [`Some`].
    ///
    /// It becomes `Some` on the first duplicate ACK. This means it also
    /// covers the limited transmit phase before the duplicate ACK threshold
    /// is reached.
    fast_recovery: Option<FastRecovery>,
}
/// Available congestion control algorithms.
#[derive(Debug)]
enum LossBasedAlgorithm<I> {
    /// CUBIC ([RFC 8312]), instantiated with the fast convergence heuristic
    /// enabled.
    ///
    /// [RFC 8312]: https://www.rfc-editor.org/rfc/rfc8312
    Cubic(cubic::Cubic<I, true /* FAST_CONVERGENCE */>),
}
impl<I: Instant> LossBasedAlgorithm<I> {
/// Called when there is a loss detected.
///
/// Specifically, packet loss means
/// - either when the retransmission timer fired;
/// - or when we have received a certain amount of duplicate acks.
fn on_loss_detected(&mut self, params: &mut CongestionControlParams) {
match self {
LossBasedAlgorithm::Cubic(cubic) => cubic.on_loss_detected(params),
}
}
/// Called when we recovered from packet loss when receiving an ACK that
/// acknowledges new data.
fn on_loss_recovered(&mut self, params: &mut CongestionControlParams) {
// Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
// When the next ACK arrives that acknowledges previously
// unacknowledged data, a TCP MUST set cwnd to ssthresh (the value
// set in step 2). This is termed "deflating" the window.
params.cwnd = params.ssthresh;
}
fn on_ack(
&mut self,
params: &mut CongestionControlParams,
bytes_acked: NonZeroU32,
now: I,
rtt: Duration,
) {
match self {
LossBasedAlgorithm::Cubic(cubic) => cubic.on_ack(params, bytes_acked, now, rtt),
}
}
fn on_retransmission_timeout(&mut self, params: &mut CongestionControlParams) {
match self {
LossBasedAlgorithm::Cubic(cubic) => cubic.on_retransmission_timeout(params),
}
}
}
impl<I: Instant> CongestionControl<I> {
    /// Handles an ACK covering `bytes_acked` previously unacknowledged bytes.
    ///
    /// An ACK of new data always ends fast recovery. The window is deflated
    /// only if loss had actually been declared, i.e. the duplicate ACK count
    /// had reached [`DUP_ACK_THRESHOLD`]. Afterwards the loss-based algorithm
    /// processes the ACK as usual.
    pub(super) fn on_ack(&mut self, bytes_acked: NonZeroU32, now: I, rtt: Duration) {
        let loss_was_declared = matches!(
            self.fast_recovery.take(),
            Some(recovery) if recovery.dup_acks.get() >= DUP_ACK_THRESHOLD
        );
        if loss_was_declared {
            self.algorithm.on_loss_recovered(&mut self.params);
        }
        self.algorithm.on_ack(&mut self.params, bytes_acked, now, rtt);
    }

    /// Handles the arrival of a duplicate ACK for `seg_ack`.
    ///
    /// Returns `true` if fast recovery was initiated as a result of this ACK,
    /// which happens when this is the first duplicate ACK.
    pub(super) fn on_dup_ack(&mut self, seg_ack: SeqNum) -> bool {
        if let Some(recovery) = self.fast_recovery.as_mut() {
            recovery.on_dup_ack(&mut self.params, &mut self.algorithm, seg_ack);
            return false;
        }
        self.fast_recovery = Some(FastRecovery::new());
        true
    }

    /// Handles a retransmission timeout.
    ///
    /// Any fast recovery state is discarded before the loss-based algorithm
    /// is notified.
    pub(super) fn on_retransmission_timeout(&mut self) {
        self.fast_recovery = None;
        self.algorithm.on_retransmission_timeout(&mut self.params);
    }

    /// Gets the current congestion window size in bytes.
    ///
    /// Normally this is the algorithm's cwnd rounded down to whole segments.
    /// The exception is the limited transmit case: before the duplicate ACK
    /// threshold is reached, the returned value is inflated by
    /// `dup_acks * mss`. This lets unsent data packets enter the network and
    /// trigger more duplicate ACKs to enter fast retransmit. This still
    /// conforms to the RFC because the algorithm's own cwnd is never changed,
    /// so the algorithm is not aware of this "inflation".
    pub(super) fn cwnd(&self) -> WindowSize {
        let rounded = self.params.rounded_cwnd();
        // Per RFC 3042 (https://www.rfc-editor.org/rfc/rfc3042#section-2):
        //   ... the Limited Transmit algorithm, which calls for a TCP
        //   sender to transmit new data upon the arrival of the first two
        //   consecutive duplicate ACKs ...
        //   The amount of outstanding data would remain less than or equal
        //   to the congestion window plus 2 segments. In other words, the
        //   sender can only send two segments beyond the congestion window
        //   (cwnd).
        // Note: We don't directly change cwnd in the loss-based algorithm
        // because the RFC says one MUST NOT do that. We follow the
        // requirement here by not changing the cwnd of the algorithm - if
        // a new ACK is received after the two dup acks, the loss-based
        // algorithm will continue to operate the same way as if the 2 SMSS
        // is never added to cwnd.
        match &self.fast_recovery {
            Some(recovery) if recovery.dup_acks.get() < DUP_ACK_THRESHOLD => {
                let extra_segments = u32::from(recovery.dup_acks.get());
                rounded.saturating_add(extra_segments * u32::from(self.params.mss))
            }
            _ => rounded,
        }
    }

    /// Takes the starting sequence number of the segment that needs to be
    /// fast retransmitted, if any.
    ///
    /// Returns `None` when not in fast recovery or when no retransmission is
    /// pending.
    pub(super) fn fast_retransmit(&mut self) -> Option<SeqNum> {
        self.fast_recovery.as_mut()?.fast_retransmit.take()
    }

    /// Creates a CUBIC-based congestion controller for sender MSS `mss`.
    pub(super) fn cubic_with_mss(mss: Mss) -> Self {
        let params = CongestionControlParams::with_mss(mss);
        Self { algorithm: LossBasedAlgorithm::Cubic(Default::default()), fast_recovery: None, params }
    }

    /// Returns the sender MSS.
    pub(super) fn mss(&self) -> Mss {
        self.params.mss
    }

    /// Returns true if this [`CongestionControl`] is in fast recovery.
    pub(super) fn in_fast_recovery(&self) -> bool {
        matches!(self.fast_recovery, Some(_))
    }

    /// Returns true if this [`CongestionControl`] is in slow start.
    pub(super) fn in_slow_start(&self) -> bool {
        self.params.ssthresh > self.params.cwnd
    }
}
/// Reno-style Fast Recovery algorithm as described in
/// [RFC 5681](https://tools.ietf.org/html/rfc5681).
#[derive(Debug)]
pub struct FastRecovery {
    /// Holds the sequence number of the segment to fast retransmit, if any.
    ///
    /// Set when the duplicate ACK count reaches the threshold. Cleared
    /// (taken) when the pending retransmission is handed out by
    /// `CongestionControl::fast_retransmit`.
    fast_retransmit: Option<SeqNum>,
    /// The running count of consecutive duplicate ACKs we have received so far.
    ///
    /// Starts at 1, because this state is created on the first duplicate ACK.
    ///
    /// Here we limit the maximum number of duplicate ACKs we track to 255, as
    /// per a note in the RFC:
    ///
    ///   Note: [SCWA99] discusses a receiver-based attack whereby many
    ///   bogus duplicate ACKs are sent to the data sender in order to
    ///   artificially inflate cwnd and cause a higher than appropriate
    ///   sending rate to be used. A TCP MAY therefore limit the number of
    ///   times cwnd is artificially inflated during loss recovery to the
    ///   number of outstanding segments (or, an approximation thereof).
    ///
    /// [SCWA99]: https://homes.cs.washington.edu/~tom/pubs/CCR99.pdf
    dup_acks: NonZeroU8,
}
impl FastRecovery {
    /// Creates fast recovery state for the first duplicate ACK of an episode.
    fn new() -> Self {
        Self { dup_acks: NonZeroU8::new(1).unwrap(), fast_retransmit: None }
    }

    /// Records one more duplicate ACK (acknowledging `seg_ack`) and updates
    /// `params` accordingly.
    ///
    /// - Below [`DUP_ACK_THRESHOLD`], only the count changes.
    /// - On reaching the threshold, loss is reported to `loss_based`, the
    ///   segment starting at `seg_ack` is scheduled for fast retransmission,
    ///   and `cwnd` is set to `ssthresh + 3 * SMSS`.
    /// - Beyond the threshold, `cwnd` is inflated by one SMSS per duplicate
    ///   ACK.
    ///
    /// The duplicate ACK counter saturates at `u8::MAX`. Once it can no
    /// longer advance, further duplicate ACKs are ignored instead of
    /// inflating `cwnd` forever. This bounds the inflation a receiver can
    /// cause by sending bogus duplicate ACKs, as suggested by the RFC 5681
    /// note on [SCWA99] quoted on the `dup_acks` field. All `cwnd` arithmetic
    /// saturates, so a flood of duplicate ACKs can never overflow it.
    fn on_dup_ack<I: Instant>(
        &mut self,
        params: &mut CongestionControlParams,
        loss_based: &mut LossBasedAlgorithm<I>,
        seg_ack: SeqNum,
    ) {
        self.dup_acks = match self.dup_acks.checked_add(1) {
            Some(dup_acks) => dup_acks,
            // The counter is saturated. Stop inflating cwnd rather than
            // letting (possibly bogus) duplicate ACKs grow it without bound.
            None => return,
        };
        match self.dup_acks.get().cmp(&DUP_ACK_THRESHOLD) {
            Ordering::Less => {}
            Ordering::Equal => {
                loss_based.on_loss_detected(params);
                // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
                //   The lost segment starting at SND.UNA MUST be retransmitted
                //   and cwnd set to ssthresh plus 3*SMSS. This artificially
                //   "inflates" the congestion window by the number of segments
                //   (three) that have left the network and which the receiver
                //   has buffered.
                self.fast_retransmit = Some(seg_ack);
                params.cwnd = params
                    .ssthresh
                    .saturating_add(u32::from(DUP_ACK_THRESHOLD) * u32::from(params.mss));
            }
            Ordering::Greater => {
                // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
                //   For each additional duplicate ACK received (after the third),
                //   cwnd MUST be incremented by SMSS. This artificially inflates
                //   the congestion window in order to reflect the additional
                //   segment that has left the network.
                params.cwnd = params.cwnd.saturating_add(u32::from(params.mss));
            }
        }
    }
}
#[cfg(test)]
mod test {
    use netstack3_base::testutil::FakeInstant;

    use super::*;
    use crate::internal::base::testutil::DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE;

    /// A single duplicate ACK followed by an ACK of new data must not be
    /// treated as a recovery from loss.
    ///
    /// Loss was never declared, so cwnd must not be "deflated" to the
    /// initial ssthresh (`u32::MAX`). Doing so would overflow on the
    /// following growth step.
    #[test]
    fn no_recovery_before_reaching_threshold() {
        let mut cc = CongestionControl::cubic_with_mss(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE);
        let initial_cwnd = cc.params.cwnd;
        assert_eq!(cc.params.ssthresh, u32::MAX);

        // The first duplicate ACK starts fast recovery bookkeeping.
        assert!(cc.on_dup_ack(SeqNum::new(0)));

        let one_byte = NonZeroU32::new(1).unwrap();
        let now = FakeInstant::from(Duration::from_secs(0));
        let rtt = Duration::from_secs(1);
        cc.on_ack(one_byte, now, rtt);

        // Slow start grows cwnd by exactly the acked byte count.
        assert_eq!(cc.params.cwnd, initial_cwnd + 1);
    }
}