netstack3_tcp/
congestion.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
//! Implements loss-based congestion control algorithms.
//!
//! The currently implemented algorithms are CUBIC from [RFC 8312], RENO
//! style fast retransmit and fast recovery from [RFC 5681], and SACK-based
//! loss recovery from [RFC 6675].
//!
//! [RFC 8312]: https://www.rfc-editor.org/rfc/rfc8312
//! [RFC 5681]: https://www.rfc-editor.org/rfc/rfc5681
//! [RFC 6675]: https://www.rfc-editor.org/rfc/rfc6675
12
13mod cubic;
14
15use core::cmp::Ordering;
16use core::num::{NonZeroU8, NonZeroU32};
17use core::time::Duration;
18
19use netstack3_base::{EffectiveMss, Instant, Mss, SackBlocks, SeqNum, WindowSize};
20
21use crate::internal::sack_scoreboard::SackScoreboard;
22
/// The number of duplicate ACKs that is treated as a loss signal.
///
/// Per [RFC 5681 section 3.2]:
///
/// > The fast retransmit algorithm uses the arrival of 3 duplicate ACKs (...)
/// > as an indication that a segment has been lost.
///
/// [RFC 5681 section 3.2]: https://www.rfc-editor.org/rfc/rfc5681#section-3.2
pub(crate) const DUP_ACK_THRESHOLD: u8 = 3;
27
28/// Holds the parameters of congestion control that are common to algorithms.
29#[derive(Debug)]
30struct CongestionControlParams {
31    /// Slow start threshold.
32    ssthresh: u32,
33    /// Congestion control window size, in bytes.
34    cwnd: u32,
35    /// Sender MSS.
36    mss: EffectiveMss,
37}
38
39impl CongestionControlParams {
40    fn with_mss(mss: EffectiveMss) -> Self {
41        let mss_u32 = u32::from(mss);
42        // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#page-5):
43        //   IW, the initial value of cwnd, MUST be set using the following
44        //   guidelines as an upper bound.
45        //   If SMSS > 2190 bytes:
46        //       IW = 2 * SMSS bytes and MUST NOT be more than 2 segments
47        //   If (SMSS > 1095 bytes) and (SMSS <= 2190 bytes):
48        //       IW = 3 * SMSS bytes and MUST NOT be more than 3 segments
49        //   if SMSS <= 1095 bytes:
50        //       IW = 4 * SMSS bytes and MUST NOT be more than 4 segments
51        let cwnd = if mss_u32 > 2190 {
52            mss_u32 * 2
53        } else if mss_u32 > 1095 {
54            mss_u32 * 3
55        } else {
56            mss_u32 * 4
57        };
58        Self { cwnd, ssthresh: u32::MAX, mss }
59    }
60
61    fn rounded_cwnd(&self) -> CongestionWindow {
62        CongestionWindow::new(self.cwnd, self.mss)
63    }
64}
65
66mod cwnd {
67    use super::*;
68    /// A witness type for a congestion window that is rounded to a multiple of
69    /// MSS.
70    ///
71    /// This type carries around the mss that was used to calculate it.
72    #[derive(Debug, Copy, Clone)]
73    #[cfg_attr(test, derive(Eq, PartialEq))]
74    pub(crate) struct CongestionWindow {
75        cwnd: u32,
76        mss: EffectiveMss,
77    }
78
79    impl CongestionWindow {
80        pub(super) fn new(cwnd: u32, mss: EffectiveMss) -> Self {
81            let mss_u32 = u32::from(mss);
82            Self { cwnd: cwnd / mss_u32 * mss_u32, mss }
83        }
84
85        pub(crate) fn cwnd(&self) -> u32 {
86            self.cwnd
87        }
88
89        pub(crate) fn mss(&self) -> EffectiveMss {
90            self.mss
91        }
92    }
93}
94pub(crate) use cwnd::CongestionWindow;
95
96/// Congestion control with five intertwined algorithms.
97///
98/// - Slow start
99/// - Congestion avoidance from a loss-based algorithm
100/// - Fast retransmit
101/// - Fast recovery: https://datatracker.ietf.org/doc/html/rfc5681#section-3
102/// - SACK recovery: https://datatracker.ietf.org/doc/html/rfc6675
103#[derive(Debug)]
104pub(crate) struct CongestionControl<I> {
105    params: CongestionControlParams,
106    sack_scoreboard: SackScoreboard,
107    algorithm: LossBasedAlgorithm<I>,
108    /// The connection is in loss recovery when this field is a [`Some`].
109    loss_recovery: Option<LossRecovery>,
110}
111
112/// Available congestion control algorithms.
113#[derive(Debug)]
114enum LossBasedAlgorithm<I> {
115    Cubic(cubic::Cubic<I, true /* FAST_CONVERGENCE */>),
116}
117
118impl<I: Instant> LossBasedAlgorithm<I> {
119    /// Called when there is a loss detected.
120    ///
121    /// Specifically, packet loss means
122    /// - either when the retransmission timer fired;
123    /// - or when we have received a certain amount of duplicate acks.
124    fn on_loss_detected(&mut self, params: &mut CongestionControlParams) {
125        match self {
126            LossBasedAlgorithm::Cubic(cubic) => cubic.on_loss_detected(params),
127        }
128    }
129
130    fn on_ack(
131        &mut self,
132        params: &mut CongestionControlParams,
133        bytes_acked: NonZeroU32,
134        now: I,
135        rtt: Duration,
136    ) {
137        match self {
138            LossBasedAlgorithm::Cubic(cubic) => cubic.on_ack(params, bytes_acked, now, rtt),
139        }
140    }
141
142    fn on_retransmission_timeout(&mut self, params: &mut CongestionControlParams) {
143        match self {
144            LossBasedAlgorithm::Cubic(cubic) => cubic.on_retransmission_timeout(params),
145        }
146    }
147}
148
149impl<I: Instant> CongestionControl<I> {
150    /// Preprocesses an ACK that may contain selective ack blocks.
151    ///
152    /// Returns `Some(true)` if this should be considered a duplicate ACK
153    /// according to the rules in [RFC 6675 section 2]. Returns `Some(false)`
154    /// otherwise.
155    ///
156    /// If the incoming ACK does not have SACK information, `None` is returned
157    /// and the caller should use the classic algorithm to determine if this is
158    /// a duplicate ACk.
159    ///
160    /// [RFC 6675 section 2]:
161    ///     https://datatracker.ietf.org/doc/html/rfc6675#section-2
162    pub(super) fn preprocess_ack(
163        &mut self,
164        seg_ack: SeqNum,
165        snd_nxt: SeqNum,
166        seg_sack_blocks: &SackBlocks,
167    ) -> Option<bool> {
168        let Self { params, algorithm: _, loss_recovery, sack_scoreboard } = self;
169        let high_rxt = loss_recovery.as_ref().and_then(|lr| match lr {
170            LossRecovery::FastRecovery(_) => None,
171            LossRecovery::SackRecovery(sack_recovery) => sack_recovery.high_rxt(),
172        });
173        let is_dup_ack =
174            sack_scoreboard.process_ack(seg_ack, snd_nxt, high_rxt, seg_sack_blocks, params.mss);
175        (!seg_sack_blocks.is_empty()).then_some(is_dup_ack)
176    }
177
178    /// Informs the congestion control algorithm that a segment of length
179    /// `seg_len` is being sent on the wire.
180    ///
181    /// This allows congestion control to keep the correct estimate of how many
182    /// bytes are in flight.
183    pub(super) fn on_will_send_segment(&mut self, seg_len: u32) {
184        let Self { params: _, sack_scoreboard, algorithm: _, loss_recovery: _ } = self;
185        // From RFC 6675:
186        //
187        //  (C.4) The estimate of the amount of data outstanding in the
188        //  network must be updated by incrementing pipe by the number of
189        //  octets transmitted in (C.1).
190        sack_scoreboard.increment_pipe(seg_len);
191    }
192
193    /// Called when there are previously unacknowledged bytes being acked.
194    ///
195    /// If a round-trip-time estimation is not available, `rtt` can be `None`,
196    /// but the loss-based algorithm is not updated in that case.
197    ///
198    /// Returns `true` if this ack signals a loss recovery.
199    pub(super) fn on_ack(
200        &mut self,
201        seg_ack: SeqNum,
202        bytes_acked: NonZeroU32,
203        now: I,
204        rtt: Option<Duration>,
205    ) -> bool {
206        let Self { params, algorithm, loss_recovery, sack_scoreboard: _ } = self;
207        // Exit fast recovery since there is an ACK that acknowledges new data.
208        let outcome = match loss_recovery {
209            None => LossRecoveryOnAckOutcome::None,
210            Some(LossRecovery::FastRecovery(fast_recovery)) => fast_recovery.on_ack(params),
211            Some(LossRecovery::SackRecovery(sack_recovery)) => sack_recovery.on_ack(seg_ack),
212        };
213
214        let recovered = match outcome {
215            LossRecoveryOnAckOutcome::None => false,
216            LossRecoveryOnAckOutcome::Discard { recovered } => {
217                *loss_recovery = None;
218                recovered
219            }
220        };
221
222        // It is possible, however unlikely, that we get here without an RTT
223        // estimation - in case the first data segment that we send out gets
224        // retransmitted. In that case, simply don't update the congestion
225        // parameters with the loss based algorithm which at worst causes slow
226        // start to take one extra step.
227        if let Some(rtt) = rtt {
228            algorithm.on_ack(params, bytes_acked, now, rtt);
229        }
230        recovered
231    }
232
233    /// Called when a duplicate ack is arrived.
234    ///
235    /// Returns `Some` if loss recovery was initiated as a result of this ACK,
236    /// informing which mode was triggered.
237    pub(super) fn on_dup_ack(
238        &mut self,
239        seg_ack: SeqNum,
240        snd_nxt: SeqNum,
241    ) -> Option<LossRecoveryMode> {
242        let Self { params, algorithm, loss_recovery, sack_scoreboard } = self;
243        match loss_recovery {
244            None => {
245                // If we have SACK information, prefer SACK recovery.
246                if sack_scoreboard.has_sack_info() {
247                    let mut sack_recovery = SackRecovery::new();
248                    let started_loss_recovery = sack_recovery
249                        .on_dup_ack(seg_ack, snd_nxt, sack_scoreboard)
250                        .apply(params, algorithm);
251                    *loss_recovery = Some(LossRecovery::SackRecovery(sack_recovery));
252                    started_loss_recovery.then_some(LossRecoveryMode::SackRecovery)
253                } else {
254                    *loss_recovery = Some(LossRecovery::FastRecovery(FastRecovery::new()));
255                    None
256                }
257            }
258            Some(LossRecovery::SackRecovery(sack_recovery)) => sack_recovery
259                .on_dup_ack(seg_ack, snd_nxt, sack_scoreboard)
260                .apply(params, algorithm)
261                .then_some(LossRecoveryMode::SackRecovery),
262            Some(LossRecovery::FastRecovery(fast_recovery)) => fast_recovery
263                .on_dup_ack(params, algorithm, seg_ack)
264                .then_some(LossRecoveryMode::FastRecovery),
265        }
266    }
267
268    /// Called upon a retransmission timeout.
269    ///
270    /// `snd_nxt` is the value of SND.NXT _before_ it is rewound to SND.UNA as
271    /// part of an RTO.
272    pub(super) fn on_retransmission_timeout(&mut self, snd_nxt: SeqNum) {
273        let Self { params, algorithm, loss_recovery, sack_scoreboard } = self;
274        sack_scoreboard.on_retransmission_timeout();
275        let discard_loss_recovery = match loss_recovery {
276            None | Some(LossRecovery::FastRecovery(_)) => true,
277            Some(LossRecovery::SackRecovery(sack_recovery)) => {
278                sack_recovery.on_retransmission_timeout(snd_nxt)
279            }
280        };
281        if discard_loss_recovery {
282            *loss_recovery = None;
283        }
284        algorithm.on_retransmission_timeout(params);
285    }
286
287    pub(super) fn slow_start_threshold(&self) -> u32 {
288        self.params.ssthresh
289    }
290
291    #[cfg(test)]
292    pub(super) fn pipe(&self) -> u32 {
293        self.sack_scoreboard.pipe()
294    }
295
296    /// Inflates the congestion window by `value` to facilitate testing.
297    #[cfg(test)]
298    pub(super) fn inflate_cwnd(&mut self, inflation: u32) {
299        self.params.cwnd += inflation;
300    }
301
302    pub(super) fn cubic_with_mss(mss: EffectiveMss) -> Self {
303        Self {
304            params: CongestionControlParams::with_mss(mss),
305            algorithm: LossBasedAlgorithm::Cubic(Default::default()),
306            loss_recovery: None,
307            sack_scoreboard: SackScoreboard::default(),
308        }
309    }
310
311    pub(super) fn mss(&self) -> EffectiveMss {
312        self.params.mss
313    }
314
315    pub(super) fn update_mss(&mut self, mss: Mss, snd_una: SeqNum, snd_nxt: SeqNum) {
316        let Self { params, sack_scoreboard, algorithm: _, loss_recovery } = self;
317        let orig = u32::from(params.mss);
318        params.mss.update_mss(mss);
319
320        // From [RFC 5681 section 3.1]:
321        //
322        //    When initial congestion windows of more than one segment are
323        //    implemented along with Path MTU Discovery [RFC1191], and the MSS
324        //    being used is found to be too large, the congestion window cwnd
325        //    SHOULD be reduced to prevent large bursts of smaller segments.
326        //    Specifically, cwnd SHOULD be reduced by the ratio of the old segment
327        //    size to the new segment size.
328        //
329        // [RFC 5681 section 3.1]: https://datatracker.ietf.org/doc/html/rfc5681#section-3.1
330        if params.ssthresh == u32::MAX {
331            params.cwnd = params.cwnd.saturating_div(orig).saturating_mul(u32::from(params.mss));
332        }
333
334        // Given we'll retransmit after receiving this, we need to update the
335        // SACK scoreboard so pipe is recalculated based on this value of
336        // snd_nxt and mss.
337        let high_rxt = loss_recovery.as_ref().and_then(|lr| match lr {
338            LossRecovery::FastRecovery(_) => None,
339            LossRecovery::SackRecovery(sack_recovery) => sack_recovery.high_rxt(),
340        });
341        sack_scoreboard.on_mss_update(snd_una, snd_nxt, high_rxt, params.mss);
342    }
343
344    /// Returns the rounded unmodified by loss recovery window size.
345    ///
346    /// This is meant to be used for inspection only. Congestion calculation for
347    /// sending should use [`CongestionControl::poll_send`].
348    pub(super) fn inspect_cwnd(&self) -> CongestionWindow {
349        self.params.rounded_cwnd()
350    }
351
352    /// Returns the current loss recovery mode, if any.
353    ///
354    /// This method returns the current loss recovery mode.
355    ///
356    /// *NOTE* It's possible for [`CongestionControl`] to return a
357    /// [`LossRecoveryMode`] here even if there was no congestion events. Rely
358    /// on the return values from [`CongestionControl::poll_send`],
359    /// [`CongestionControl::on_dup_ack`] to catch entering into loss recovery
360    /// mode or determining if segments originate from a specific algorithm.
361    pub(super) fn inspect_loss_recovery_mode(&self) -> Option<LossRecoveryMode> {
362        self.loss_recovery.as_ref().map(|lr| lr.mode())
363    }
364
365    /// Returns true if this [`CongestionControl`] is in slow start.
366    pub(super) fn in_slow_start(&self) -> bool {
367        self.params.cwnd < self.params.ssthresh
368    }
369
370    /// Polls congestion control for the next segment to be sent out.
371    ///
372    /// Receives pertinent parameters from the sender state machine to allow for
373    /// this decision:
374    ///
375    /// - `snd_una` is SND.UNA the highest unacknowledged sequence number.
376    /// - `snd_nxt` is SND.NXT, the next sequence number the sender would send
377    ///   without loss recovery.
378    /// - `snd_wnd` is the total send window, i.e. the allowable receiver window
379    ///   after `snd_una`.
380    /// - `available_bytes` is the total number of bytes in the send buffer,
381    ///   starting at `snd_una`.
382    ///
383    /// Returns `None` if no segment should be sent right now.
384    pub(super) fn poll_send(
385        &mut self,
386        snd_una: SeqNum,
387        snd_nxt: SeqNum,
388        snd_wnd: WindowSize,
389        available_bytes: usize,
390    ) -> Option<CongestionControlSendOutcome> {
391        let Self { params, algorithm: _, loss_recovery, sack_scoreboard } = self;
392        let cwnd = params.rounded_cwnd();
393
394        match loss_recovery {
395            None => {
396                let pipe = sack_scoreboard.pipe();
397                let congestion_window = cwnd.cwnd();
398                let available_window = congestion_window.saturating_sub(pipe);
399                let congestion_limit = available_window.min(cwnd.mss().into());
400                Some(CongestionControlSendOutcome {
401                    next_seg: snd_nxt,
402                    congestion_limit,
403                    congestion_window,
404                    loss_recovery: LossRecoverySegment::No,
405                })
406            }
407            Some(LossRecovery::FastRecovery(fast_recovery)) => {
408                Some(fast_recovery.poll_send(cwnd, sack_scoreboard.pipe(), snd_nxt))
409            }
410            Some(LossRecovery::SackRecovery(sack_recovery)) => sack_recovery.poll_send(
411                cwnd,
412                snd_una,
413                snd_nxt,
414                snd_wnd,
415                available_bytes,
416                sack_scoreboard,
417            ),
418        }
419    }
420}
421
422/// Indicates whether the segment yielded in [`CongestionControlSendOutcome`] is
423/// a loss recovery segment.
424#[derive(Debug)]
425#[cfg_attr(test, derive(Copy, Clone, Eq, PartialEq))]
426pub(super) enum LossRecoverySegment {
427    /// Indicates the segment is a loss recovery segment.
428    Yes {
429        /// If true, the retransmit timer should be rearmed due to this loss
430        /// recovery segment.
431        ///
432        /// This is used in SACK recovery to prevent RTOs during retransmission,
433        /// from [RFC 6675 section 6]:
434        ///
435        /// > Therefore, we give implementers the latitude to use the standard
436        /// > [RFC6298]-style RTO management or, optionally, a more careful
437        /// > variant that re-arms the RTO timer on each retransmission that is
438        /// > sent during recovery MAY be used.  This provides a more
439        /// > conservative timer than specified in [RFC6298], and so may not
440        /// > always be an attractive alternative.  However, in some cases it
441        /// > may prevent needless retransmissions, go-back-N transmission, and
442        /// > further reduction of the congestion window.
443        ///
444        /// [RFC 6675 section 6]: https://datatracker.ietf.org/doc/html/rfc6675#section-6
445        rearm_retransmit: bool,
446        /// The recovery mode that caused this loss recovery segment.
447        mode: LossRecoveryMode,
448    },
449    /// Indicates the segment is *not* a loss recovery segment.
450    No,
451}
452
453/// The outcome of [`CongestionControl::poll_send`].
454#[derive(Debug)]
455#[cfg_attr(test, derive(Eq, PartialEq))]
456pub(super) struct CongestionControlSendOutcome {
457    /// The next segment to be sent out.
458    pub next_seg: SeqNum,
459    /// The maximum number of bytes post next_seg that can be sent.
460    ///
461    /// This limit does not account for the unused/open window on the receiver.
462    ///
463    /// This is limited by the current congestion limit and the sender MSS.
464    pub congestion_limit: u32,
465    /// The congestion window used to calculate `congestion limit`.
466    ///
467    /// This is the estimated total congestion window, including loss
468    /// recovery-based inflation.
469    pub congestion_window: u32,
470    /// Whether this is a loss recovery segment.
471    pub loss_recovery: LossRecoverySegment,
472}
473
474/// The current loss recovery mode.
475#[derive(Debug)]
476pub enum LossRecovery {
477    FastRecovery(FastRecovery),
478    SackRecovery(SackRecovery),
479}
480
481impl LossRecovery {
482    fn mode(&self) -> LossRecoveryMode {
483        match self {
484            LossRecovery::FastRecovery(_) => LossRecoveryMode::FastRecovery,
485            LossRecovery::SackRecovery(_) => LossRecoveryMode::SackRecovery,
486        }
487    }
488}
489
/// Names a loss recovery mode without carrying any of its state; the
/// stateless counterpart of [`LossRecovery`].
#[derive(Debug)]
#[cfg_attr(test, derive(Copy, Clone, Eq, PartialEq))]
pub enum LossRecoveryMode {
    /// Reno style fast recovery.
    FastRecovery,
    /// SACK based recovery.
    SackRecovery,
}
498
/// What a loss recovery mode wants done with its state after an ACK for new
/// data.
#[derive(Debug)]
#[cfg_attr(test, derive(Eq, PartialEq))]
enum LossRecoveryOnAckOutcome {
    /// Keep the loss recovery state as it is.
    None,
    /// Drop the loss recovery state; `recovered` is reported to the caller of
    /// [`CongestionControl::on_ack`].
    Discard { recovered: bool },
}
505
506/// Reno style Fast Recovery algorithm as described in
507/// [RFC 5681](https://tools.ietf.org/html/rfc5681).
508#[derive(Debug)]
509pub struct FastRecovery {
510    /// Holds the sequence number of the segment to fast retransmit, if any.
511    fast_retransmit: Option<SeqNum>,
512    /// The running count of consecutive duplicate ACKs we have received so far.
513    ///
514    /// Here we limit the maximum number of duplicate ACKS we track to 255, as
515    /// per a note in the RFC:
516    ///
517    /// Note: [SCWA99] discusses a receiver-based attack whereby many
518    /// bogus duplicate ACKs are sent to the data sender in order to
519    /// artificially inflate cwnd and cause a higher than appropriate
520    /// sending rate to be used.  A TCP MAY therefore limit the number of
521    /// times cwnd is artificially inflated during loss recovery to the
522    /// number of outstanding segments (or, an approximation thereof).
523    ///
524    /// [SCWA99]: https://homes.cs.washington.edu/~tom/pubs/CCR99.pdf
525    dup_acks: NonZeroU8,
526}
527
528impl FastRecovery {
529    fn new() -> Self {
530        Self { dup_acks: NonZeroU8::new(1).unwrap(), fast_retransmit: None }
531    }
532
533    fn poll_send(
534        &mut self,
535        cwnd: CongestionWindow,
536        used_congestion_window: u32,
537        snd_nxt: SeqNum,
538    ) -> CongestionControlSendOutcome {
539        let Self { fast_retransmit, dup_acks } = self;
540        // Per RFC 3042 (https://www.rfc-editor.org/rfc/rfc3042#section-2): ...
541        // the Limited Transmit algorithm, which calls for a TCP sender to
542        //   transmit new data upon the arrival of the first two consecutive
543        //   duplicate ACKs ... The amount of outstanding data would remain less
544        //   than or equal to the congestion window plus 2 segments.  In other
545        //   words, the sender can only send two segments beyond the congestion
546        //   window (cwnd).
547        //
548        // Note: We don't directly change cwnd in the loss-based algorithm
549        // because the RFC says one MUST NOT do that. We follow the requirement
550        // here by not changing the cwnd of the algorithm - if a new ACK is
551        // received after the two dup acks, the loss-based algorithm will
552        // continue to operate the same way as if the 2 SMSS is never added to
553        // cwnd.
554        let congestion_window = if dup_acks.get() < DUP_ACK_THRESHOLD {
555            cwnd.cwnd().saturating_add(u32::from(dup_acks.get()) * u32::from(cwnd.mss()))
556        } else {
557            cwnd.cwnd()
558        };
559
560        // Elect fast retransmit sequence number or snd_nxt if we don't have
561        // one.
562        let (next_seg, loss_recovery, congestion_limit) = match fast_retransmit.take() {
563            // From RFC 5681:
564            //
565            //  3. The lost segment starting at SND.UNA MUST be retransmitted
566            //     [...].
567            //
568            // So we always set the congestion limit to be just the mss.
569            Some(f) => (
570                f,
571                LossRecoverySegment::Yes {
572                    rearm_retransmit: false,
573                    mode: LossRecoveryMode::FastRecovery,
574                },
575                cwnd.mss().into(),
576            ),
577            // There's no fast retransmit pending, use snd_nxt applying the used
578            // congestion window.
579            None => (
580                snd_nxt,
581                LossRecoverySegment::No,
582                congestion_window.saturating_sub(used_congestion_window).min(cwnd.mss().into()),
583            ),
584        };
585        CongestionControlSendOutcome {
586            next_seg,
587            congestion_limit,
588            congestion_window,
589            loss_recovery,
590        }
591    }
592
593    fn on_ack(&mut self, params: &mut CongestionControlParams) -> LossRecoveryOnAckOutcome {
594        let recovered = self.dup_acks.get() >= DUP_ACK_THRESHOLD;
595        if recovered {
596            // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
597            //   When the next ACK arrives that acknowledges previously
598            //   unacknowledged data, a TCP MUST set cwnd to ssthresh (the value
599            //   set in step 2).  This is termed "deflating" the window.
600            params.cwnd = params.ssthresh;
601        }
602        LossRecoveryOnAckOutcome::Discard { recovered }
603    }
604
605    /// Processes a duplicate ack with sequence number `seg_ack`.
606    ///
607    /// Returns `true` if loss recovery is triggered.
608    fn on_dup_ack<I: Instant>(
609        &mut self,
610        params: &mut CongestionControlParams,
611        loss_based: &mut LossBasedAlgorithm<I>,
612        seg_ack: SeqNum,
613    ) -> bool {
614        self.dup_acks = self.dup_acks.saturating_add(1);
615
616        match self.dup_acks.get().cmp(&DUP_ACK_THRESHOLD) {
617            Ordering::Less => false,
618            Ordering::Equal => {
619                loss_based.on_loss_detected(params);
620                // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
621                //   The lost segment starting at SND.UNA MUST be retransmitted
622                //   and cwnd set to ssthresh plus 3*SMSS.  This artificially
623                //   "inflates" the congestion window by the number of segments
624                //   (three) that have left the network and which the receiver
625                //   has buffered.
626                self.fast_retransmit = Some(seg_ack);
627                params.cwnd =
628                    params.ssthresh + u32::from(DUP_ACK_THRESHOLD) * u32::from(params.mss);
629                true
630            }
631            Ordering::Greater => {
632                // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
633                //   For each additional duplicate ACK received (after the third),
634                //   cwnd MUST be incremented by SMSS. This artificially inflates
635                //   the congestion window in order to reflect the additional
636                //   segment that has left the network.
637                params.cwnd = params.cwnd.saturating_add(u32::from(params.mss));
638                false
639            }
640        }
641    }
642}
643
644/// The state kept by [`SackRecovery`] indicating the recovery state.
645#[derive(Debug)]
646#[cfg_attr(test, derive(Eq, PartialEq, Copy, Clone))]
647enum SackRecoveryState {
648    /// SACK is currently in active recovery.
649    InRecovery(SackInRecoveryState),
650    /// SACK is holding off starting new recovery after an RTO.
651    PostRto { recovery_point: SeqNum },
652    /// SACK is not in active recovery.
653    NotInRecovery,
654}
655
656/// The state kept by [`SackInRecoveryState::InRecovery`].
657#[derive(Debug)]
658#[cfg_attr(test, derive(Eq, PartialEq, Copy, Clone))]
659struct SackInRecoveryState {
660    /// The sequence number that marks the end of the current loss recovery
661    /// phase.
662    recovery_point: SeqNum,
663    /// The highest retransmitted sequence number during the current loss
664    /// recovery phase.
665    ///
666    /// Tracks the "HighRxt" variable defined in [RFC 6675 section 2].
667    ///
668    /// [RFC 6675 section 2]: https://datatracker.ietf.org/doc/html/rfc6675#section-2
669    high_rxt: SeqNum,
670    /// The highest sequence number that has been optimistically retransmitted.
671    ///
672    /// Tracks the "RescureRxt" variable defined in [RFC 6675 section 2].
673    ///
674    /// [RFC 6675 section 2]: https://datatracker.ietf.org/doc/html/rfc6675#section-2
675    rescue_rxt: Option<SeqNum>,
676}
677
678/// Implements the SACK based recovery from [RFC 6675].
679///
680/// [RFC 6675]: https://datatracker.ietf.org/doc/html/rfc6675
681#[derive(Debug)]
682pub(crate) struct SackRecovery {
683    /// Keeps track of the number of duplicate ACKs received during SACK
684    /// recovery.
685    ///
686    /// Tracks the "DupAcks" variable defined in [RFC 6675 section 2].
687    ///
688    /// [RFC 6675 section 2]: https://datatracker.ietf.org/doc/html/rfc6675#section-2
689    dup_acks: u8,
690    /// Statekeeping for loss recovery.
691    ///
692    /// Set to `Some` when we're in recovery state.
693    recovery: SackRecoveryState,
694}
695
696impl SackRecovery {
697    fn new() -> Self {
698        Self {
699            // Unlike FastRecovery, we start with zero duplicate ACKs,
700            // congestion control calls on_dup_ack after creation.
701            dup_acks: 0,
702            recovery: SackRecoveryState::NotInRecovery,
703        }
704    }
705
706    fn high_rxt(&self) -> Option<SeqNum> {
707        match &self.recovery {
708            SackRecoveryState::InRecovery(SackInRecoveryState {
709                recovery_point: _,
710                high_rxt,
711                rescue_rxt: _,
712            }) => Some(*high_rxt),
713            SackRecoveryState::PostRto { recovery_point: _ } | SackRecoveryState::NotInRecovery => {
714                None
715            }
716        }
717    }
718
719    fn on_ack(&mut self, seg_ack: SeqNum) -> LossRecoveryOnAckOutcome {
720        let Self { dup_acks, recovery } = self;
721        match recovery {
722            SackRecoveryState::InRecovery(SackInRecoveryState {
723                recovery_point,
724                high_rxt: _,
725                rescue_rxt: _,
726            })
727            | SackRecoveryState::PostRto { recovery_point } => {
728                // From RFC 6675:
729                //  An incoming cumulative ACK for a sequence number greater than
730                //  RecoveryPoint signals the end of loss recovery, and the loss
731                //  recovery phase MUST be terminated.
732                if seg_ack.after_or_eq(*recovery_point) {
733                    LossRecoveryOnAckOutcome::Discard {
734                        recovered: matches!(recovery, SackRecoveryState::InRecovery(_)),
735                    }
736                } else {
737                    // From RFC 6675:
738                    //  If the incoming ACK is a cumulative acknowledgment, the
739                    //  TCP MUST reset DupAcks to zero.
740                    *dup_acks = 0;
741                    LossRecoveryOnAckOutcome::None
742                }
743            }
744            SackRecoveryState::NotInRecovery => {
745                // We're not in loss recovery, we seem to have moved things
746                // forward. Discard loss recovery information.
747                LossRecoveryOnAckOutcome::Discard { recovered: false }
748            }
749        }
750    }
751
752    /// Processes a duplicate acknowledgement.
753    fn on_dup_ack(
754        &mut self,
755        seq_ack: SeqNum,
756        snd_nxt: SeqNum,
757        sack_scoreboard: &SackScoreboard,
758    ) -> SackDupAckOutcome {
759        let Self { dup_acks, recovery } = self;
760        match recovery {
761            SackRecoveryState::InRecovery(_) | SackRecoveryState::PostRto { .. } => {
762                // Already in recovery mode, nothing to do.
763                return SackDupAckOutcome(false);
764            }
765            SackRecoveryState::NotInRecovery => (),
766        }
767        *dup_acks += 1;
768        // From RFC 6675:
769        //  (1) If DupAcks >= DupThresh, [...].
770        //  (2) If DupAcks < DupThresh but IsLost (HighACK + 1) returns true
771        //  [...]
772        if *dup_acks >= DUP_ACK_THRESHOLD || sack_scoreboard.is_first_hole_lost() {
773            // Enter loss recovery:
774            //  (4.1) RecoveryPoint = HighData
775            //  When the TCP sender receives a cumulative ACK for this data
776            //  octet, the loss recovery phase is terminated.
777            *recovery = SackRecoveryState::InRecovery(SackInRecoveryState {
778                recovery_point: snd_nxt,
779                high_rxt: seq_ack,
780                rescue_rxt: None,
781            });
782            SackDupAckOutcome(true)
783        } else {
784            SackDupAckOutcome(false)
785        }
786    }
787
788    /// Updates SACK recovery to account for a retransmission timeout during
789    /// recovery.
790    ///
791    /// From [RFC 6675 section 5.1]:
792    ///
793    /// > If an RTO occurs during loss recovery as specified in this document,
794    /// > RecoveryPoint MUST be set to HighData.  Further, the new value of
795    /// > RecoveryPoint MUST be preserved and the loss recovery algorithm
796    /// > outlined in this document MUST be terminated.  In addition, a new
797    /// > recovery phase (as described in Section 5) MUST NOT be initiated until
798    /// > HighACK is greater than or equal to the new value of RecoveryPoint.
799    ///
800    /// [RFC 6675 section 5.1]: https://datatracker.ietf.org/doc/html/rfc6675#section-5.1
801    ///
802    /// Returns `true` iff we can clear all recovery state due to the timeout.
803    pub(crate) fn on_retransmission_timeout(&mut self, snd_nxt: SeqNum) -> bool {
804        let Self { dup_acks: _, recovery } = self;
805        match recovery {
806            SackRecoveryState::InRecovery(SackInRecoveryState { .. }) => {
807                *recovery = SackRecoveryState::PostRto { recovery_point: snd_nxt };
808                false
809            }
810            SackRecoveryState::PostRto { recovery_point: _ } => {
811                // NB: The RFC is not exactly clear on what to do here, but the
812                // best interpretation is that we should maintain the old
813                // recovery point until we've hit that point and don't update to
814                // the new (assumedly rewound) snd_nxt.
815                false
816            }
817            SackRecoveryState::NotInRecovery => {
818                // Not in recovery we can reset our state.
819                true
820            }
821        }
822    }
823
824    /// SACK recovery based congestion control next segment selection.
825    ///
826    /// Argument semantics are the same as [`CongestionControl::poll_send`].
827    fn poll_send(
828        &mut self,
829        cwnd: CongestionWindow,
830        snd_una: SeqNum,
831        snd_nxt: SeqNum,
832        snd_wnd: WindowSize,
833        available_bytes: usize,
834        sack_scoreboard: &SackScoreboard,
835    ) -> Option<CongestionControlSendOutcome> {
836        let Self { dup_acks: _, recovery } = self;
837
838        let pipe = sack_scoreboard.pipe();
839        let congestion_window = cwnd.cwnd();
840        let available_window = congestion_window.saturating_sub(pipe);
841        // Don't send anything if we can't send at least full MSS, following the
842        // RFC. All outcomes require at least one MSS of available window:
843        //
844        // (3.3) If (cwnd - pipe) >= 1 SMSS [...]
845        // (C) If cwnd - pipe >= 1 SMSS [...]
846        if available_window < cwnd.mss().into() {
847            return None;
848        }
849        let congestion_limit = available_window.min(cwnd.mss().into());
850
851        // If we're not in recovery, use the regular congestion calculation,
852        // adjusting the congestion window with the pipe value.
853        //
854        // From RFC 6675:
855        //
856        //  (3.3) If (cwnd - pipe) >= 1 SMSS, there exists previously unsent
857        //  data, and the receiver's advertised window allows, transmit up
858        //  to 1 SMSS of data starting with the octet HighData+1 and update
859        //  HighData to reflect this transmission, then return to (3.2).
860        let SackInRecoveryState { recovery_point, high_rxt, rescue_rxt } = match recovery {
861            SackRecoveryState::InRecovery(sack_in_recovery_state) => sack_in_recovery_state,
862            SackRecoveryState::PostRto { recovery_point: _ } | SackRecoveryState::NotInRecovery => {
863                return Some(CongestionControlSendOutcome {
864                    next_seg: snd_nxt,
865                    congestion_limit,
866                    congestion_window,
867                    loss_recovery: LossRecoverySegment::No,
868                });
869            }
870        };
871
872        // From RFC 6675 section 6:
873        //
874        //  we give implementers the latitude to use the standard
875        //  [RFC6298]-style RTO management or, optionally, a more careful
876        //  variant that re-arms the RTO timer on each retransmission that is
877        //  sent during recovery MAY be used.  This provides a more conservative
878        //  timer than specified in [RFC6298].
879        //
880        // As a local decision, we only rearm the retransmit timer for rules 1
881        // and 3 (regular retransmissions) when the next segment trying to be
882        // sent out is _before_ the recovery point that initiated this loss
883        // recovery. Given the recovery algorithm greedily keeps sending more
884        // data as long as it's available to keep the ACK clock running, there's
885        // a catastrophic scenario where the data sent past the recovery point
886        // creates new holes in the sack scoreboard that are filled by rules 1
887        // and 3 and rearm the RTO, even if the retransmissions from holes
888        // before RecoveryPoint might be lost themselves. Hence, once the
889        // algorithm has moved past trying to fix things past the RecoveryPoint
890        // we stop rearming the RTO in case the ACK for RecoveryPoint never
891        // arrives.
892        //
893        // Note that rule 4 always rearms the retransmission timer because it
894        // sents only a single segment per entry into recovery.¡
895        let rearm_retransmit = |next_seg: SeqNum| next_seg.before(*recovery_point);
896
897        // run NextSeg() as defined in RFC 6675.
898
899        // (1) If there exists a smallest unSACKed sequence number 'S2' that
900        //   meets the following three criteria for determining loss, the
901        //   sequence range of one segment of up to SMSS octets starting
902        //   with S2 MUST be returned.
903        //
904        //   (1.a) S2 is greater than HighRxt.
905        //   (1.b) S2 is less than the highest octet covered by any received
906        //         SACK.
907        //   (1.c) IsLost (S2) returns true.
908
909        let first_unsacked_range =
910            sack_scoreboard.first_unsacked_range_from(snd_una.latest(*high_rxt));
911
912        if let Some(first_hole) = &first_unsacked_range {
913            // Meta is the IsLost value.
914            if *first_hole.meta() {
915                let hole_size = first_hole.len();
916                let congestion_limit = congestion_limit.min(hole_size);
917                *high_rxt = first_hole.start() + congestion_limit;
918
919                // If we haven't set RescueRxt yet, set it to prevent eager
920                // rescue. From RFC 6675:
921                //
922                //  Retransmit the first data segment presumed dropped --
923                //  the segment starting with sequence number HighACK + 1.
924                //  To prevent repeated retransmission of the same data or a
925                //  premature rescue retransmission, set both HighRxt and
926                //  RescueRxt to the highest sequence number in the
927                //  retransmitted segment.
928                if rescue_rxt.is_none() {
929                    *rescue_rxt = Some(*high_rxt);
930                }
931
932                return Some(CongestionControlSendOutcome {
933                    next_seg: first_hole.start(),
934                    congestion_limit,
935                    congestion_window,
936                    loss_recovery: LossRecoverySegment::Yes {
937                        rearm_retransmit: rearm_retransmit(first_hole.start()),
938                        mode: LossRecoveryMode::SackRecovery,
939                    },
940                });
941            }
942        }
943
944        // Run next rule, from RFC 6675:
945        //
946        // (2) If no sequence number 'S2' per rule (1) exists but there
947        // exists available unsent data and the receiver's advertised window
948        // allows, the sequence range of one segment of up to SMSS octets of
949        // previously unsent data starting with sequence number HighData+1
950        // MUST be returned.
951        let total_sent = u32::try_from(snd_nxt - snd_una).unwrap();
952        if available_bytes > usize::try_from(total_sent).unwrap() && u32::from(snd_wnd) > total_sent
953        {
954            return Some(CongestionControlSendOutcome {
955                next_seg: snd_nxt,
956                // We only need to send out the congestion limit, the window
957                // limit is applied by the sender state machine.
958                congestion_limit,
959                congestion_window,
960                // NB: even though we're sending new bytes, we're still
961                // signaling that we're in loss recovery. Our goal here is
962                // to keep the ACK clock running and prevent an RTO, so we
963                // don't want this segment to be delayed by anything.
964                loss_recovery: LossRecoverySegment::Yes {
965                    rearm_retransmit: false,
966                    mode: LossRecoveryMode::SackRecovery,
967                },
968            });
969        }
970
971        // Run next rule, from RFC 6675:
972        //
973        //  (3) If the conditions for rules (1) and (2) fail, but there
974        //  exists an unSACKed sequence number 'S3' that meets the criteria
975        //  for detecting loss given in steps (1.a) and (1.b) above
976        //  (specifically excluding step (1.c)), then one segment of up to
977        //  SMSS octets starting with S3 SHOULD be returned.
978        if let Some(first_hole) = first_unsacked_range {
979            let hole_size = first_hole.len();
980            let congestion_limit = congestion_limit.min(hole_size);
981            *high_rxt = first_hole.start() + congestion_limit;
982
983            return Some(CongestionControlSendOutcome {
984                next_seg: first_hole.start(),
985                congestion_limit,
986                congestion_window,
987                loss_recovery: LossRecoverySegment::Yes {
988                    rearm_retransmit: rearm_retransmit(first_hole.start()),
989                    mode: LossRecoveryMode::SackRecovery,
990                },
991            });
992        }
993
994        // Run next rule, from RFC 6675:
995        //
996        //  (4) If the conditions for (1), (2), and (3) fail, but there
997        //  exists outstanding unSACKed data, we provide the opportunity for
998        //  a single "rescue" retransmission per entry into loss recovery.
999        //  If HighACK is greater than RescueRxt (or RescueRxt is
1000        //  undefined), then one segment of up to SMSS octets that MUST
1001        //  include the highest outstanding unSACKed sequence number SHOULD
1002        //  be returned, and RescueRxt set to RecoveryPoint. HighRxt MUST
1003        //  NOT be updated.
1004        if rescue_rxt.is_none_or(|rescue_rxt| snd_una.after_or_eq(rescue_rxt)) {
1005            if let Some(right_edge) = sack_scoreboard.right_edge() {
1006                let left = right_edge.latest(snd_nxt - congestion_limit);
1007                // This can't send any new data, so figure out how much space we
1008                // have left. If SND.NXT got rewound and is now before the right
1009                // edge, unwrap the calculation to zero to avoid sending the
1010                // rescue segment.
1011                let congestion_limit = u32::try_from(snd_nxt - left).unwrap_or(0);
1012                if congestion_limit > 0 {
1013                    *rescue_rxt = Some(*recovery_point);
1014                    return Some(CongestionControlSendOutcome {
1015                        next_seg: left,
1016                        congestion_limit,
1017                        congestion_window,
1018                        // NB: Rescue retransmissions can only happen once in
1019                        // every recovery enter, so always rearm the RTO.
1020                        loss_recovery: LossRecoverySegment::Yes {
1021                            rearm_retransmit: true,
1022                            mode: LossRecoveryMode::SackRecovery,
1023                        },
1024                    });
1025                }
1026            }
1027        }
1028
1029        None
1030    }
1031}
1032
1033/// The value returned by [`SackRecovery::on_dup_ack`].
1034///
1035/// It contains a boolean indicating whether loss recovery started due to a
1036/// duplicate ACK. [`SackDupAckOutcome::apply`] is used to retrieve the boolean
1037/// and notify loss recovery algorithm as needed and update the congestion
1038/// parameters.
1039///
1040/// This is its own type so [`SackRecovery::on_dup_ack`] can be tested in
1041/// isolation from [`LossBasedAlgorithm`].
1042#[derive(Debug)]
1043#[cfg_attr(test, derive(Eq, PartialEq))]
1044struct SackDupAckOutcome(bool);
1045
1046impl SackDupAckOutcome {
1047    /// Consumes this outcome, notifying `algorithm` that loss was detected if
1048    /// needed.
1049    ///
1050    /// Returns the inner boolean indicating whether loss recovery started.
1051    fn apply<I: Instant>(
1052        self,
1053        params: &mut CongestionControlParams,
1054        algorithm: &mut LossBasedAlgorithm<I>,
1055    ) -> bool {
1056        let Self(loss_recovery) = self;
1057        if loss_recovery {
1058            algorithm.on_loss_detected(params);
1059        }
1060        loss_recovery
1061    }
1062}
1063
1064#[cfg(test)]
1065mod test {
1066    use core::ops::Range;
1067
1068    use assert_matches::assert_matches;
1069    use netstack3_base::testutil::FakeInstant;
1070    use netstack3_base::{EffectiveMss, MssSizeLimiters, SackBlock};
1071    use test_case::{test_case, test_matrix};
1072
1073    use super::*;
1074    use crate::internal::testutil;
1075
    /// Effective MSS derived from `Mss::DEFAULT_IPV4` with the timestamp
    /// option disabled.
    const MSS_1: EffectiveMss =
        EffectiveMss::from_mss(Mss::DEFAULT_IPV4, MssSizeLimiters { timestamp_enabled: false });
    /// Effective MSS derived from `Mss::DEFAULT_IPV6` with the timestamp
    /// option disabled.
    const MSS_2: EffectiveMss =
        EffectiveMss::from_mss(Mss::DEFAULT_IPV6, MssSizeLimiters { timestamp_enabled: false });
1080
1081    enum StartingAck {
1082        One,
1083        Wraparound,
1084        WraparoundAfter(u32),
1085    }
1086
1087    impl StartingAck {
1088        fn into_seqnum(self, mss: EffectiveMss) -> SeqNum {
1089            let mss = u32::from(mss);
1090            match self {
1091                StartingAck::One => SeqNum::new(1),
1092                StartingAck::Wraparound => SeqNum::new((mss / 2).wrapping_sub(mss)),
1093                StartingAck::WraparoundAfter(n) => SeqNum::new((mss / 2).wrapping_sub(n * mss)),
1094            }
1095        }
1096    }
1097
    impl SackRecovery {
        /// Test helper: returns the in-recovery state, panicking (through
        /// `assert_matches!`) if `self.recovery` is not
        /// `SackRecoveryState::InRecovery`.
        #[track_caller]
        fn assert_in_recovery(&mut self) -> &mut SackInRecoveryState {
            assert_matches!(&mut self.recovery, SackRecoveryState::InRecovery(s) => s)
        }
    }
1104
    impl<I> CongestionControl<I> {
        /// Test helper: returns the SACK recovery state, panicking (through
        /// `assert_matches!`) if `self.loss_recovery` is not
        /// `Some(LossRecovery::SackRecovery(_))`.
        #[track_caller]
        fn assert_sack_recovery(&mut self) -> &mut SackRecovery {
            assert_matches!(&mut self.loss_recovery, Some(LossRecovery::SackRecovery(s)) => s)
        }
    }
1111
1112    fn nth_segment_from(base: SeqNum, mss: EffectiveMss, n: u32) -> Range<SeqNum> {
1113        let mss = u32::from(mss);
1114        let start = base + n * mss;
1115        Range { start, end: start + mss }
1116    }
1117
1118    fn nth_range(base: SeqNum, mss: EffectiveMss, range: Range<u32>) -> Range<SeqNum> {
1119        let mss = u32::from(mss);
1120        let Range { start, end } = range;
1121        let start = base + start * mss;
1122        let end = base + end * mss;
1123        Range { start, end }
1124    }
1125
1126    #[test]
1127    fn no_recovery_before_reaching_threshold() {
1128        let mut congestion_control = CongestionControl::cubic_with_mss(MSS_1);
1129        let old_cwnd = congestion_control.params.cwnd;
1130        assert_eq!(congestion_control.params.ssthresh, u32::MAX);
1131        assert_eq!(congestion_control.on_dup_ack(SeqNum::new(0), SeqNum::new(1)), None);
1132        assert!(!congestion_control.on_ack(
1133            SeqNum::new(1),
1134            NonZeroU32::new(1).unwrap(),
1135            FakeInstant::from(Duration::from_secs(0)),
1136            Some(Duration::from_secs(1)),
1137        ));
1138        // We have only received one duplicate ack, receiving a new ACK should
1139        // not mean "loss recovery" - we should not bump our cwnd to initial
1140        // ssthresh (u32::MAX) and then overflow.
1141        assert_eq!(old_cwnd + 1, congestion_control.params.cwnd);
1142    }
1143
1144    #[test]
1145    fn preprocess_ack_result() {
1146        let ack = SeqNum::new(1);
1147        let snd_nxt = SeqNum::new(100);
1148        let mut congestion_control = CongestionControl::<FakeInstant>::cubic_with_mss(MSS_1);
1149        assert_eq!(congestion_control.preprocess_ack(ack, snd_nxt, &SackBlocks::EMPTY), None);
1150        assert_eq!(
1151            congestion_control.preprocess_ack(ack, snd_nxt, &testutil::sack_blocks([10..20])),
1152            Some(true)
1153        );
1154        assert_eq!(congestion_control.preprocess_ack(ack, snd_nxt, &SackBlocks::EMPTY), None);
1155        assert_eq!(
1156            congestion_control.preprocess_ack(ack, snd_nxt, &testutil::sack_blocks([10..20])),
1157            Some(false)
1158        );
1159        assert_eq!(
1160            congestion_control.preprocess_ack(
1161                ack,
1162                snd_nxt,
1163                &testutil::sack_blocks([10..20, 20..30])
1164            ),
1165            Some(true)
1166        );
1167    }
1168
1169    #[test_case(DUP_ACK_THRESHOLD-1; "no loss")]
1170    #[test_case(DUP_ACK_THRESHOLD; "exact threshold")]
1171    #[test_case(DUP_ACK_THRESHOLD+1; "over threshold")]
1172    fn sack_recovery_enter_exit_loss_dupacks(dup_acks: u8) {
1173        let mut congestion_control = CongestionControl::cubic_with_mss(MSS_1);
1174        let mss = congestion_control.mss();
1175
1176        let ack = SeqNum::new(1);
1177        let snd_nxt = nth_segment_from(ack, mss, 10).end;
1178
1179        let expect_recovery =
1180            SackInRecoveryState { recovery_point: snd_nxt, high_rxt: ack, rescue_rxt: None };
1181
1182        let mut sack = SackBlock::try_from(nth_segment_from(ack, mss, 1)).unwrap();
1183        for n in 1..=dup_acks {
1184            assert_eq!(
1185                congestion_control.preprocess_ack(ack, snd_nxt, &[sack].into_iter().collect()),
1186                Some(true)
1187            );
1188            assert_eq!(
1189                congestion_control.on_dup_ack(ack, snd_nxt),
1190                (n == DUP_ACK_THRESHOLD).then_some(LossRecoveryMode::SackRecovery)
1191            );
1192            let sack_recovery = congestion_control.assert_sack_recovery();
1193            // We stop counting duplicate acks after the threshold.
1194            assert_eq!(sack_recovery.dup_acks, n.min(DUP_ACK_THRESHOLD));
1195
1196            let expect_recovery = if n >= DUP_ACK_THRESHOLD {
1197                SackRecoveryState::InRecovery(expect_recovery.clone())
1198            } else {
1199                SackRecoveryState::NotInRecovery
1200            };
1201            assert_eq!(congestion_control.assert_sack_recovery().recovery, expect_recovery);
1202
1203            let (start, end) = sack.into_parts();
1204            // Don't increase by full MSS to prove that duplicate ACKs alone are
1205            // putting us in this state.
1206            sack = SackBlock::try_new(start, end + u32::from(mss) / 4).unwrap();
1207        }
1208
1209        let end = sack.right();
1210        let bytes_acked = NonZeroU32::new(u32::try_from(end - ack).unwrap()).unwrap();
1211        let ack = end;
1212        assert_eq!(congestion_control.preprocess_ack(ack, snd_nxt, &SackBlocks::EMPTY), None);
1213
1214        let now = FakeInstant::default();
1215        let rtt = Some(Duration::from_millis(1));
1216
1217        // A cumulative ACK not covering the recovery point arrives.
1218        assert_eq!(congestion_control.on_ack(ack, bytes_acked, now, rtt), false);
1219        if dup_acks >= DUP_ACK_THRESHOLD {
1220            assert_eq!(
1221                congestion_control.assert_sack_recovery().recovery,
1222                SackRecoveryState::InRecovery(expect_recovery)
1223            );
1224        } else {
1225            assert_matches!(congestion_control.loss_recovery, None);
1226        }
1227
1228        // A cumulative ACK covering the recovery point arrives.
1229        let bytes_acked = NonZeroU32::new(u32::try_from(snd_nxt - ack).unwrap()).unwrap();
1230        let ack = snd_nxt;
1231        assert_eq!(
1232            congestion_control.on_ack(ack, bytes_acked, now, rtt),
1233            dup_acks >= DUP_ACK_THRESHOLD
1234        );
1235        assert_matches!(congestion_control.loss_recovery, None);
1236
1237        // A later cumulative ACK arrives.
1238        let snd_nxt = snd_nxt + 20;
1239        let ack = ack + 10;
1240        assert_eq!(congestion_control.preprocess_ack(ack, snd_nxt, &SackBlocks::EMPTY), None);
1241        assert_eq!(congestion_control.on_ack(ack, bytes_acked, now, rtt), false);
1242        assert_matches!(congestion_control.loss_recovery, None);
1243    }
1244
1245    #[test]
1246    fn sack_recovery_enter_loss_single_dupack() {
1247        let mut congestion_control = CongestionControl::<FakeInstant>::cubic_with_mss(MSS_1);
1248
1249        // SACK can enter recovery after a *single* duplicate ACK provided
1250        // enough information is in the scoreboard:
1251        let snd_nxt = SeqNum::new(100);
1252        let ack = SeqNum::new(0);
1253        assert_eq!(
1254            congestion_control.preprocess_ack(
1255                ack,
1256                snd_nxt,
1257                &testutil::sack_blocks([5..15, 25..35, 45..55])
1258            ),
1259            Some(true)
1260        );
1261        assert_eq!(
1262            congestion_control.on_dup_ack(ack, snd_nxt),
1263            Some(LossRecoveryMode::SackRecovery)
1264        );
1265        assert_eq!(
1266            congestion_control.assert_sack_recovery().recovery,
1267            SackRecoveryState::InRecovery(SackInRecoveryState {
1268                recovery_point: snd_nxt,
1269                high_rxt: ack,
1270                rescue_rxt: None
1271            })
1272        );
1273    }
1274
1275    #[test]
1276    fn sack_recovery_poll_send_not_recovery() {
1277        let mut scoreboard = SackScoreboard::default();
1278        let mut recovery = SackRecovery::new();
1279        let mss = MSS_1;
1280        let cwnd_mss = 10u32;
1281        let cwnd = CongestionWindow::new(cwnd_mss * u32::from(mss), mss);
1282        let snd_una = SeqNum::new(1);
1283
1284        let in_flight = 5u32;
1285        let snd_nxt = nth_segment_from(snd_una, mss, in_flight).start;
1286
1287        // When not in recovery, we delegate all the receiver window calculation
1288        // out. Prove that that's the case by telling SACK there's nothing to
1289        // send.
1290        let snd_wnd = WindowSize::ZERO;
1291        let available_bytes = 0;
1292
1293        let sack_block = SackBlock::try_from(nth_segment_from(snd_una, mss, 1)).unwrap();
1294        assert!(scoreboard.process_ack(
1295            snd_una,
1296            snd_nxt,
1297            None,
1298            &[sack_block].into_iter().collect(),
1299            mss
1300        ));
1301
1302        // With 1 SACK block this is how much window we expect to have available
1303        // in multiples of mss.
1304        let wnd_used = in_flight - 1;
1305        assert_eq!(scoreboard.pipe(), wnd_used * u32::from(mss));
1306        let wnd_available = cwnd_mss - wnd_used;
1307
1308        for i in 0..wnd_available {
1309            let snd_nxt = snd_nxt + i * u32::from(mss);
1310            assert_eq!(
1311                recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1312                Some(CongestionControlSendOutcome {
1313                    next_seg: snd_nxt,
1314                    congestion_limit: mss.into(),
1315                    congestion_window: cwnd.cwnd(),
1316                    loss_recovery: LossRecoverySegment::No,
1317                })
1318            );
1319            scoreboard.increment_pipe(mss.into());
1320        }
1321
1322        // Used all of the window.
1323        assert_eq!(scoreboard.pipe(), cwnd.cwnd());
1324        // Poll send stops this round.
1325        assert_eq!(
1326            recovery.poll_send(
1327                cwnd,
1328                snd_una,
1329                snd_nxt + (wnd_used + 1) * u32::from(mss),
1330                snd_wnd,
1331                available_bytes,
1332                &scoreboard
1333            ),
1334            None
1335        );
1336    }
1337
1338    #[test_matrix(
1339        [MSS_1, MSS_2],
1340        [1, 3, 5],
1341        [StartingAck::One, StartingAck::Wraparound]
1342    )]
1343    fn sack_recovery_next_seg_rule_1(mss: EffectiveMss, lost_segments: u32, snd_una: StartingAck) {
1344        let mut scoreboard = SackScoreboard::default();
1345        let mut recovery = SackRecovery::new();
1346
1347        let snd_una = snd_una.into_seqnum(mss);
1348
1349        let sacked_segments = u32::from(DUP_ACK_THRESHOLD);
1350        let sacked_range = lost_segments..(lost_segments + sacked_segments);
1351        let in_flight = sacked_range.end + 5;
1352        let snd_nxt = nth_segment_from(snd_una, mss, in_flight).start;
1353
1354        // Define a congestion window that will only let us fill part of the
1355        // lost segments with rule 1.
1356        let cwnd_mss = in_flight - sacked_segments - 1;
1357        let cwnd = CongestionWindow::new(cwnd_mss * u32::from(mss), mss);
1358
1359        // Rule 1 should not care about available window size, since it's
1360        // retransmitting a lost segment.
1361        let snd_wnd = WindowSize::ZERO;
1362        let available_bytes = 0;
1363
1364        let sack_block = SackBlock::try_from(nth_range(snd_una, mss, sacked_range)).unwrap();
1365        assert!(scoreboard.process_ack(
1366            snd_una,
1367            snd_nxt,
1368            None,
1369            &[sack_block].into_iter().collect(),
1370            mss
1371        ));
1372
1373        // Verify that our set up math here is correct, we want recovery to be
1374        // able to fill only a part of the hole.
1375        assert_eq!(cwnd.cwnd() - scoreboard.pipe(), (lost_segments - 1) * u32::from(mss));
1376        // Enter recovery.
1377        assert_eq!(recovery.on_dup_ack(snd_una, snd_nxt, &scoreboard), SackDupAckOutcome(true));
1378
1379        for i in 0..(lost_segments - 1) {
1380            let next_seg = snd_una + i * u32::from(mss);
1381            assert_eq!(
1382                recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1383                Some(CongestionControlSendOutcome {
1384                    next_seg,
1385                    congestion_limit: mss.into(),
1386                    congestion_window: cwnd.cwnd(),
1387                    loss_recovery: LossRecoverySegment::Yes {
1388                        rearm_retransmit: true,
1389                        mode: LossRecoveryMode::SackRecovery,
1390                    },
1391                })
1392            );
1393            scoreboard.increment_pipe(mss.into());
1394            assert_eq!(
1395                recovery.recovery,
1396                SackRecoveryState::InRecovery(SackInRecoveryState {
1397                    recovery_point: snd_nxt,
1398                    high_rxt: nth_segment_from(snd_una, mss, i).end,
1399                    // RescueRxt is always set to the first retransmitted
1400                    // segment.
1401                    rescue_rxt: Some(snd_una + u32::from(mss)),
1402                })
1403            );
1404        }
1405        // Ran out of CWND.
1406        assert_eq!(
1407            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1408            None
1409        );
1410    }
1411
1412    #[test_matrix(
1413        [MSS_1, MSS_2],
1414        [1, 3, 5],
1415        [StartingAck::One, StartingAck::Wraparound]
1416    )]
1417    fn sack_recovery_next_seg_rule_2(mss: EffectiveMss, expect_send: u32, snd_una: StartingAck) {
1418        let mut scoreboard = SackScoreboard::default();
1419        let mut recovery = SackRecovery::new();
1420
1421        let snd_una = snd_una.into_seqnum(mss);
1422
1423        let lost_segments = 1;
1424        let sacked_segments = u32::from(DUP_ACK_THRESHOLD);
1425        let sacked_range = lost_segments..(lost_segments + sacked_segments);
1426        let in_flight = sacked_range.end + 5;
1427        let mut snd_nxt = nth_segment_from(snd_una, mss, in_flight).start;
1428
1429        let sack_block = SackBlock::try_from(nth_range(snd_una, mss, sacked_range)).unwrap();
1430        assert!(scoreboard.process_ack(
1431            snd_una,
1432            snd_nxt,
1433            None,
1434            &[sack_block].into_iter().collect(),
1435            mss
1436        ));
1437
1438        // Define a congestion window that will allow us to send only the
1439        // desired segments.
1440        let cwnd = CongestionWindow::new(scoreboard.pipe() + expect_send * u32::from(mss), mss);
1441        // Enter recovery.
1442        assert_eq!(recovery.on_dup_ack(snd_una, snd_nxt, &scoreboard), SackDupAckOutcome(true));
1443        // Force HighRxt to the end of the lost block to skip rules 1 and 3.
1444        let recovery_state = recovery.assert_in_recovery();
1445        recovery_state.high_rxt = nth_segment_from(snd_una, mss, lost_segments - 1).end;
1446        // Force RecoveryRxt to skip rule 4.
1447        recovery_state.rescue_rxt = Some(snd_nxt);
1448        let state_snapshot = recovery_state.clone();
1449
1450        // Available bytes is always counted from SND.UNA.
1451        let baseline = u32::try_from(snd_nxt - snd_una).unwrap();
1452        // If there is no window or nothing to send, return.
1453        for (snd_wnd, available_bytes) in [(0, 0), (1, 0), (0, 1)] {
1454            let snd_wnd = WindowSize::from_u32(baseline + snd_wnd).unwrap();
1455            let available_bytes = usize::try_from(baseline + available_bytes).unwrap();
1456            assert_eq!(
1457                recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1458                None
1459            );
1460            assert_eq!(recovery.recovery, SackRecoveryState::InRecovery(state_snapshot));
1461        }
1462
1463        let baseline = baseline + (expect_send - 1) * u32::from(mss) + 1;
1464        let snd_wnd = WindowSize::from_u32(baseline).unwrap();
1465        let available_bytes = usize::try_from(baseline).unwrap();
1466        for _ in 0..expect_send {
1467            assert_eq!(
1468                recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1469                Some(CongestionControlSendOutcome {
1470                    next_seg: snd_nxt,
1471                    congestion_limit: mss.into(),
1472                    congestion_window: cwnd.cwnd(),
1473                    loss_recovery: LossRecoverySegment::Yes {
1474                        rearm_retransmit: false,
1475                        mode: LossRecoveryMode::SackRecovery,
1476                    },
1477                })
1478            );
1479            assert_eq!(recovery.recovery, SackRecoveryState::InRecovery(state_snapshot));
1480            scoreboard.increment_pipe(mss.into());
1481            snd_nxt = snd_nxt + u32::from(mss);
1482        }
1483        // Ran out of CWND.
1484        let snd_wnd = WindowSize::MAX;
1485        let available_bytes = usize::MAX;
1486        assert_eq!(
1487            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1488            None
1489        );
1490        assert_eq!(recovery.recovery, SackRecoveryState::InRecovery(state_snapshot));
1491    }
1492
1493    #[test_matrix(
1494        [MSS_1, MSS_2],
1495        [1, 3, 5],
1496        [StartingAck::One, StartingAck::Wraparound]
1497    )]
1498    fn sack_recovery_next_seg_rule_3(
1499        mss: EffectiveMss,
1500        not_lost_segments: u32,
1501        snd_una: StartingAck,
1502    ) {
1503        let mut scoreboard = SackScoreboard::default();
1504        let mut recovery = SackRecovery::new();
1505
1506        let snd_una = snd_una.into_seqnum(mss);
1507
1508        let first_lost_block = 1;
1509        let first_sacked_segments = u32::from(DUP_ACK_THRESHOLD);
1510        let first_sacked_range = first_lost_block..(first_lost_block + first_sacked_segments);
1511
1512        // "not_lost_segments" segments will not be considered lost by the
1513        // scoreboard, but they will not be sacked.
1514        let sacked_segments = 1;
1515        let sacked_range_start = first_sacked_range.end + not_lost_segments;
1516        let sacked_range = sacked_range_start..(sacked_range_start + sacked_segments);
1517
1518        let in_flight = sacked_range.end + 5;
1519        let snd_nxt = nth_segment_from(snd_una, mss, in_flight).start;
1520
1521        let sack_block1 =
1522            SackBlock::try_from(nth_range(snd_una, mss, first_sacked_range.clone())).unwrap();
1523        let sack_block2 = SackBlock::try_from(nth_range(snd_una, mss, sacked_range)).unwrap();
1524        assert!(scoreboard.process_ack(
1525            snd_una,
1526            snd_nxt,
1527            None,
1528            &[sack_block1, sack_block2].into_iter().collect(),
1529            mss
1530        ));
1531
1532        // Define a congestion window that will only let us fill part of the
1533        // lost segments with rule 3.
1534        let expect_send = (not_lost_segments - 1).max(1);
1535        let cwnd_mss =
1536            in_flight - first_sacked_segments - sacked_segments - first_lost_block + expect_send;
1537        let cwnd = CongestionWindow::new(cwnd_mss * u32::from(mss), mss);
1538
1539        // Rule 3 is only hit if we don't have enough available data to send.
1540        let snd_wnd = WindowSize::ZERO;
1541        let available_bytes = 0;
1542
1543        // Verify that our set up math here is correct, we want recovery to be
1544        // able to fill only a part of the hole.
1545        assert_eq!(cwnd.cwnd() - scoreboard.pipe(), expect_send * u32::from(mss));
1546        // Enter recovery.
1547        assert_eq!(recovery.on_dup_ack(snd_una, snd_nxt, &scoreboard), SackDupAckOutcome(true));
1548        // Poll while we expect to hit rule 1. Don't increment pipe here because
1549        // we set up our congestion window to stop only rule 3.
1550        for i in 0..first_lost_block {
1551            assert_eq!(
1552                recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1553                Some(CongestionControlSendOutcome {
1554                    next_seg: nth_segment_from(snd_una, mss, i).start,
1555                    congestion_limit: mss.into(),
1556                    congestion_window: cwnd.cwnd(),
1557                    loss_recovery: LossRecoverySegment::Yes {
1558                        rearm_retransmit: true,
1559                        mode: LossRecoveryMode::SackRecovery,
1560                    },
1561                })
1562            );
1563        }
1564        let expect_recovery = SackInRecoveryState {
1565            recovery_point: snd_nxt,
1566            high_rxt: nth_segment_from(snd_una, mss, first_sacked_range.start).start,
1567            rescue_rxt: Some(nth_segment_from(snd_una, mss, 0).end),
1568        };
1569        assert_eq!(recovery.recovery, SackRecoveryState::InRecovery(expect_recovery));
1570
1571        for i in 0..expect_send {
1572            let next_seg = snd_una + (first_sacked_range.end + i) * u32::from(mss);
1573            assert_eq!(
1574                recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1575                Some(CongestionControlSendOutcome {
1576                    next_seg,
1577                    congestion_limit: mss.into(),
1578                    congestion_window: cwnd.cwnd(),
1579                    loss_recovery: LossRecoverySegment::Yes {
1580                        rearm_retransmit: true,
1581                        mode: LossRecoveryMode::SackRecovery,
1582                    },
1583                })
1584            );
1585            scoreboard.increment_pipe(mss.into());
1586            assert_eq!(
1587                recovery.recovery,
1588                SackRecoveryState::InRecovery(SackInRecoveryState {
1589                    high_rxt: next_seg + u32::from(mss),
1590                    ..expect_recovery
1591                })
1592            );
1593        }
1594        // Ran out of CWND.
1595        assert_eq!(
1596            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1597            None
1598        );
1599    }
1600
1601    #[test_matrix(
1602        [MSS_1, MSS_2],
1603        [0, 1, 3],
1604        [StartingAck::One, StartingAck::Wraparound]
1605    )]
1606    fn sack_recovery_next_seg_rule_4(
1607        mss: EffectiveMss,
1608        right_edge_segments: u32,
1609        snd_una: StartingAck,
1610    ) {
1611        let mut scoreboard = SackScoreboard::default();
1612        let mut recovery = SackRecovery::new();
1613
1614        let snd_una = snd_una.into_seqnum(mss);
1615
1616        let lost_segments = 1;
1617        let sacked_segments = u32::from(DUP_ACK_THRESHOLD);
1618        let sacked_range = lost_segments..(lost_segments + sacked_segments);
1619        let in_flight = sacked_range.end + right_edge_segments + 2;
1620        let snd_nxt = nth_segment_from(snd_una, mss, in_flight).start;
1621
1622        // Rule 4 should only be hit if we don't have available data to send.
1623        let snd_wnd = WindowSize::ZERO;
1624        let available_bytes = 0;
1625
1626        let sack_block =
1627            SackBlock::try_from(nth_range(snd_una, mss, sacked_range.clone())).unwrap();
1628        assert!(scoreboard.process_ack(
1629            snd_una,
1630            snd_nxt,
1631            None,
1632            &[sack_block].into_iter().collect(),
1633            mss
1634        ));
1635
1636        // Define a very large congestion window, given rule 4 should only
1637        // retransmit a single segment.
1638        let cwnd = CongestionWindow::new((in_flight + 500) * u32::from(mss), mss);
1639
1640        // Enter recovery.
1641        assert_eq!(recovery.on_dup_ack(snd_una, snd_nxt, &scoreboard), SackDupAckOutcome(true));
1642        // Send the segments that match rule 1. Don't increment pipe here, we
1643        // want to show that rule 4 stops even when cwnd is entirely open.
1644        for i in 0..lost_segments {
1645            let next_seg = snd_una + i * u32::from(mss);
1646            assert_eq!(
1647                recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1648                Some(CongestionControlSendOutcome {
1649                    next_seg,
1650                    congestion_limit: mss.into(),
1651                    congestion_window: cwnd.cwnd(),
1652                    loss_recovery: LossRecoverySegment::Yes {
1653                        rearm_retransmit: true,
1654                        mode: LossRecoveryMode::SackRecovery,
1655                    },
1656                })
1657            );
1658        }
1659        let expect_recovery = SackInRecoveryState {
1660            recovery_point: snd_nxt,
1661            high_rxt: nth_segment_from(snd_una, mss, lost_segments).start,
1662            // RescueRxt is always set to the first retransmitted
1663            // segment.
1664            rescue_rxt: Some(nth_segment_from(snd_una, mss, 0).end),
1665        };
1666        assert_eq!(recovery.recovery, SackRecoveryState::InRecovery(expect_recovery));
1667
1668        // Rule 4 should only hit after we receive an ACK past the first
1669        // RescueRxt value that was set.
1670        assert_eq!(
1671            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1672            None
1673        );
1674        // Acknowledge up to the sacked range, with one new sack block.
1675        let snd_una = nth_segment_from(snd_una, mss, sacked_range.end).start;
1676        let sack_block = SackBlock::try_from(nth_range(snd_una, mss, 1..2)).unwrap();
1677        assert!(scoreboard.process_ack(
1678            snd_una,
1679            snd_nxt,
1680            Some(expect_recovery.high_rxt),
1681            &[sack_block].into_iter().collect(),
1682            mss
1683        ));
1684        assert_eq!(recovery.on_ack(snd_una), LossRecoveryOnAckOutcome::None);
1685        assert_eq!(recovery.recovery, SackRecoveryState::InRecovery(expect_recovery));
1686        // Rule 3 will hit once here because we have a single not lost segment.
1687        assert_eq!(
1688            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1689            Some(CongestionControlSendOutcome {
1690                next_seg: snd_una,
1691                congestion_limit: mss.into(),
1692                congestion_window: cwnd.cwnd(),
1693                loss_recovery: LossRecoverySegment::Yes {
1694                    rearm_retransmit: true,
1695                    mode: LossRecoveryMode::SackRecovery,
1696                },
1697            })
1698        );
1699        let expect_recovery =
1700            SackInRecoveryState { high_rxt: snd_una + u32::from(mss), ..expect_recovery };
1701        assert_eq!(recovery.recovery, SackRecoveryState::InRecovery(expect_recovery));
1702
1703        // Now we should hit Rule 4, as long as we have unacknowledged data.
1704        if right_edge_segments > 0 {
1705            assert_eq!(
1706                recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1707                Some(CongestionControlSendOutcome {
1708                    next_seg: snd_nxt - u32::from(mss),
1709                    congestion_limit: mss.into(),
1710                    congestion_window: cwnd.cwnd(),
1711                    loss_recovery: LossRecoverySegment::Yes {
1712                        rearm_retransmit: true,
1713                        mode: LossRecoveryMode::SackRecovery,
1714                    },
1715                })
1716            );
1717            assert_eq!(
1718                recovery.recovery,
1719                SackRecoveryState::InRecovery(SackInRecoveryState {
1720                    rescue_rxt: Some(expect_recovery.recovery_point),
1721                    ..expect_recovery
1722                })
1723            );
1724        }
1725
1726        // Once we've done the rescue it can't happen again.
1727        assert_eq!(
1728            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1729            None
1730        );
1731    }
1732
1733    #[test_matrix(
1734        [MSS_1, MSS_2],
1735        [
1736            StartingAck::One,
1737            StartingAck::WraparoundAfter(1),
1738            StartingAck::WraparoundAfter(2),
1739            StartingAck::WraparoundAfter(3),
1740            StartingAck::WraparoundAfter(4)
1741        ]
1742    )]
1743    fn sack_recovery_all_rules(mss: EffectiveMss, snd_una: StartingAck) {
1744        let snd_una = snd_una.into_seqnum(mss);
1745
1746        // Set up the scoreboard so we have 1 hole considered lost, that is hit
1747        // by Rule 1, and another that is not lost, hit by Rule 3.
1748        let mut scoreboard = SackScoreboard::default();
1749        let first_sacked_range = 1..(u32::from(DUP_ACK_THRESHOLD) + 1);
1750        let first_sack_block =
1751            SackBlock::try_from(nth_range(snd_una, mss, first_sacked_range.clone())).unwrap();
1752
1753        let second_sacked_range = (first_sacked_range.end + 1)..(first_sacked_range.end + 2);
1754        let second_sack_block =
1755            SackBlock::try_from(nth_range(snd_una, mss, second_sacked_range.clone())).unwrap();
1756
1757        let snd_nxt = nth_segment_from(snd_una, mss, second_sacked_range.end + 1).start;
1758
1759        // To hit Rule 4 in one run, set up a recovery state that looks
1760        // like we've already tried to fill one hole with Rule 1 and
1761        // received an ack for it.
1762        let high_rxt = snd_una;
1763        let rescue_rxt = Some(snd_una);
1764
1765        assert!(scoreboard.process_ack(
1766            snd_una,
1767            snd_nxt,
1768            Some(high_rxt),
1769            &[first_sack_block, second_sack_block].into_iter().collect(),
1770            mss
1771        ));
1772
1773        // Create a situation where a single sequential round of calls to
1774        // poll_send will hit each rule.
1775        let recovery_state = SackInRecoveryState { recovery_point: snd_nxt, high_rxt, rescue_rxt };
1776        let mut recovery = SackRecovery {
1777            dup_acks: DUP_ACK_THRESHOLD,
1778            recovery: SackRecoveryState::InRecovery(recovery_state),
1779        };
1780
1781        // Define a congestion window that allows sending a single segment,
1782        // we'll not update the pipe variable at each call so we should never
1783        // hit the congestion limit.
1784        let cwnd = CongestionWindow::new(scoreboard.pipe() + u32::from(mss), mss);
1785
1786        // Make exactly one segment available in the receiver window and send
1787        // buffer so we hit Rule 2 exactly once.
1788        let available = u32::try_from(snd_nxt - snd_una).unwrap() + 1;
1789        let snd_wnd = WindowSize::from_u32(available).unwrap();
1790        let available_bytes = usize::try_from(available).unwrap();
1791
1792        // Hit Rule 1.
1793        assert_eq!(
1794            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1795            Some(CongestionControlSendOutcome {
1796                next_seg: snd_una,
1797                congestion_limit: u32::from(mss),
1798                congestion_window: cwnd.cwnd(),
1799                loss_recovery: LossRecoverySegment::Yes {
1800                    rearm_retransmit: true,
1801                    mode: LossRecoveryMode::SackRecovery,
1802                },
1803            })
1804        );
1805        let recovery_state =
1806            SackInRecoveryState { high_rxt: snd_una + u32::from(mss), ..recovery_state };
1807        assert_eq!(recovery.recovery, SackRecoveryState::InRecovery(recovery_state));
1808
1809        // Hit Rule 2.
1810        assert_eq!(
1811            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1812            Some(CongestionControlSendOutcome {
1813                next_seg: snd_nxt,
1814                congestion_limit: u32::from(mss),
1815                congestion_window: cwnd.cwnd(),
1816                loss_recovery: LossRecoverySegment::Yes {
1817                    rearm_retransmit: false,
1818                    mode: LossRecoveryMode::SackRecovery,
1819                },
1820            })
1821        );
1822        // snd_nxt should advance.
1823        let snd_nxt = snd_nxt + u32::from(mss);
1824        // No change to recovery state.
1825        assert_eq!(recovery.recovery, SackRecoveryState::InRecovery(recovery_state));
1826
1827        // Hit Rule 3.
1828        assert_eq!(
1829            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1830            Some(CongestionControlSendOutcome {
1831                next_seg: nth_segment_from(snd_una, mss, first_sacked_range.end).start,
1832                congestion_limit: u32::from(mss),
1833                congestion_window: cwnd.cwnd(),
1834                loss_recovery: LossRecoverySegment::Yes {
1835                    rearm_retransmit: true,
1836                    mode: LossRecoveryMode::SackRecovery,
1837                },
1838            })
1839        );
1840        let recovery_state = SackInRecoveryState {
1841            high_rxt: nth_segment_from(snd_una, mss, second_sacked_range.start).start,
1842            ..recovery_state
1843        };
1844        assert_eq!(recovery.recovery, SackRecoveryState::InRecovery(recovery_state));
1845
1846        // Hit Rule 4.
1847        assert_eq!(
1848            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1849            Some(CongestionControlSendOutcome {
1850                next_seg: snd_nxt - u32::from(mss),
1851                congestion_limit: u32::from(mss),
1852                congestion_window: cwnd.cwnd(),
1853                loss_recovery: LossRecoverySegment::Yes {
1854                    rearm_retransmit: true,
1855                    mode: LossRecoveryMode::SackRecovery,
1856                },
1857            })
1858        );
1859        let recovery_state = SackInRecoveryState {
1860            rescue_rxt: Some(recovery_state.recovery_point),
1861            ..recovery_state
1862        };
1863        assert_eq!(recovery.recovery, SackRecoveryState::InRecovery(recovery_state));
1864
1865        // Hit all the rules. Nothing to send even if we still have cwnd.
1866        assert_eq!(
1867            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1868            None
1869        );
1870        assert!(cwnd.cwnd() - scoreboard.pipe() >= u32::from(mss));
1871    }
1872
1873    #[test]
1874    fn sack_rto() {
1875        let mss = MSS_1;
1876        let mut congestion_control = CongestionControl::<FakeInstant>::cubic_with_mss(mss);
1877
1878        let rto_snd_nxt = SeqNum::new(50);
1879        // Set ourselves up not in recovery.
1880        congestion_control.loss_recovery = Some(LossRecovery::SackRecovery(SackRecovery {
1881            dup_acks: DUP_ACK_THRESHOLD - 1,
1882            recovery: SackRecoveryState::NotInRecovery,
1883        }));
1884        congestion_control.on_retransmission_timeout(rto_snd_nxt);
1885        assert_matches!(congestion_control.loss_recovery, None);
1886
1887        // Set ourselves up in loss recovery.
1888        congestion_control.loss_recovery = Some(LossRecovery::SackRecovery(SackRecovery {
1889            dup_acks: DUP_ACK_THRESHOLD,
1890            recovery: SackRecoveryState::InRecovery(SackInRecoveryState {
1891                recovery_point: SeqNum::new(10),
1892                high_rxt: SeqNum::new(0),
1893                rescue_rxt: None,
1894            }),
1895        }));
1896        congestion_control.on_retransmission_timeout(rto_snd_nxt);
1897        assert_eq!(
1898            congestion_control.assert_sack_recovery().recovery,
1899            SackRecoveryState::PostRto { recovery_point: rto_snd_nxt }
1900        );
1901
1902        let snd_una = SeqNum::new(0);
1903        let snd_nxt = SeqNum::new(10);
1904        // While in RTO held off state, we always send next data as if we were
1905        // not in recovery.
1906        assert_eq!(
1907            congestion_control.poll_send(snd_una, snd_nxt, WindowSize::ZERO, 0),
1908            Some(CongestionControlSendOutcome {
1909                next_seg: snd_nxt,
1910                congestion_limit: u32::from(mss),
1911                congestion_window: congestion_control.inspect_cwnd().cwnd(),
1912                loss_recovery: LossRecoverySegment::No,
1913            })
1914        );
1915        // Receiving duplicate acks does not enter recovery.
1916        for _ in 0..DUP_ACK_THRESHOLD {
1917            assert_eq!(congestion_control.on_dup_ack(snd_una, snd_nxt), None);
1918        }
1919
1920        let now = FakeInstant::default();
1921        let rtt = Some(Duration::from_millis(1));
1922
1923        // Receiving an ack before the RTO recovery point does not stop
1924        // recovery.
1925        let bytes_acked = NonZeroU32::new(u32::try_from(snd_nxt - snd_una).unwrap()).unwrap();
1926        let snd_una = snd_nxt;
1927        assert!(!congestion_control.on_ack(snd_una, bytes_acked, now, rtt));
1928        assert_eq!(
1929            congestion_control.assert_sack_recovery().recovery,
1930            SackRecoveryState::PostRto { recovery_point: rto_snd_nxt }
1931        );
1932
1933        // Covering the recovery point allows us to discard recovery state.
1934        let bytes_acked = NonZeroU32::new(u32::try_from(rto_snd_nxt - snd_una).unwrap()).unwrap();
1935        let snd_una = rto_snd_nxt;
1936
1937        // Not considered a recovery event since RTO is the thing that recovered
1938        // us.
1939        assert_eq!(congestion_control.on_ack(snd_una, bytes_acked, now, rtt), false);
1940        assert_matches!(congestion_control.loss_recovery, None);
1941    }
1942
1943    #[test]
1944    fn dont_rearm_rto_past_recovery_point() {
1945        let mut scoreboard = SackScoreboard::default();
1946        let mss = MSS_1;
1947        let snd_una = SeqNum::new(1);
1948
1949        let recovery_point = nth_segment_from(snd_una, mss, 100).start;
1950        let snd_nxt = recovery_point + 100 * u32::from(mss);
1951
1952        let mut recovery = SackRecovery {
1953            dup_acks: DUP_ACK_THRESHOLD,
1954            recovery: SackRecoveryState::InRecovery(SackInRecoveryState {
1955                recovery_point,
1956                high_rxt: recovery_point,
1957                rescue_rxt: Some(recovery_point),
1958            }),
1959        };
1960
1961        let block1 = nth_range(snd_una, mss, 101..110);
1962        let block2 = nth_range(snd_una, mss, 111..112);
1963        assert!(
1964            scoreboard.process_ack(
1965                snd_una,
1966                snd_nxt,
1967                recovery.high_rxt(),
1968                &[SackBlock::try_from(block1).unwrap(), SackBlock::try_from(block2).unwrap()]
1969                    .into_iter()
1970                    .collect(),
1971                mss,
1972            )
1973        );
1974
1975        let cwnd = CongestionWindow::new(u32::MAX, mss);
1976
1977        let snd_wnd = WindowSize::ZERO;
1978        let available_bytes = 0;
1979
1980        assert_eq!(
1981            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1982            Some(CongestionControlSendOutcome {
1983                next_seg: nth_segment_from(snd_una, mss, 100).start,
1984                congestion_limit: mss.into(),
1985                congestion_window: cwnd.cwnd(),
1986                loss_recovery: LossRecoverySegment::Yes {
1987                    rearm_retransmit: false,
1988                    mode: LossRecoveryMode::SackRecovery,
1989                }
1990            })
1991        );
1992        assert_eq!(
1993            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1994            Some(CongestionControlSendOutcome {
1995                next_seg: nth_segment_from(snd_una, mss, 110).start,
1996                congestion_limit: mss.into(),
1997                congestion_window: cwnd.cwnd(),
1998                loss_recovery: LossRecoverySegment::Yes {
1999                    rearm_retransmit: false,
2000                    mode: LossRecoveryMode::SackRecovery,
2001                }
2002            })
2003        );
2004    }
2005
2006    // Parts of the state machine may end up rewinding SND.NXT to SND.UNA.
2007    // Ensure that NextSeg rule 4 implementation in SackRecovery (which is
2008    // sensitive to SND.NXT) gracefully handles that.
2009    #[test]
2010    fn sack_snd_nxt_rewind() {
2011        let mut scoreboard = SackScoreboard::default();
2012        let mss = MSS_1;
2013        let snd_una = SeqNum::new(1);
2014
2015        let recovery_point = nth_segment_from(snd_una, mss, 100).start;
2016        let snd_nxt = nth_segment_from(recovery_point, mss, 100).start;
2017
2018        let mut recovery = SackRecovery {
2019            dup_acks: DUP_ACK_THRESHOLD,
2020            recovery: SackRecoveryState::InRecovery(SackInRecoveryState {
2021                recovery_point,
2022                high_rxt: recovery_point,
2023                rescue_rxt: None,
2024            }),
2025        };
2026        let sack_block = SackBlock::try_from(nth_range(snd_una, mss, 1..5)).unwrap();
2027        assert!(scoreboard.process_ack(
2028            snd_una,
2029            snd_nxt,
2030            Some(recovery_point),
2031            &[sack_block].into_iter().collect(),
2032            mss,
2033        ));
2034        // Rewind.
2035        let snd_nxt = snd_una;
2036
2037        let cwnd = CongestionWindow::new(u32::MAX, mss);
2038        let snd_wnd = WindowSize::ZERO;
2039        let available_bytes = 0;
2040
2041        assert_eq!(
2042            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
2043            None
2044        );
2045    }
2046}