hyper/
upgrade.rs

1//! HTTP Upgrades
2//!
3//! This module deals with managing [HTTP Upgrades][mdn] in hyper. Since
4//! several concepts in HTTP allow for first talking HTTP, and then converting
5//! to a different protocol, this module conflates them into a single API.
6//! Those include:
7//!
8//! - HTTP/1.1 Upgrades
9//! - HTTP `CONNECT`
10//!
//! You are responsible for any other prerequisites to establish an upgrade,
//! such as sending the appropriate headers, methods, and status codes. You can
//! then use [`on`][] to get a `Future` which will resolve to the upgraded
//! connection object, or an error if the upgrade fails.
15//!
16//! [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Protocol_upgrade_mechanism
17//!
18//! # Client
19//!
//! Sending an HTTP upgrade from the [`client`](super::client) involves setting
//! either the appropriate method, if wanting to `CONNECT`, or headers such as
//! `Upgrade` and `Connection`, on the `http::Request`. Once the
//! `http::Response` is received, you must check that the server agreed to the
//! upgrade (for example, with a `101` status code), and then get the `Future`
//! from the `Response` by passing it to `on()`.
26//!
27//! # Server
28//!
//! Receiving upgrade requests in a server requires you to check the relevant
//! headers in a `Request`, and if an upgrade should be done, you then send the
//! corresponding headers in a response. To wait for hyper to finish the
//! upgrade, call `on()` with the `Request`, and then spawn a task that awaits
//! the returned `Future`.
34//!
35//! # Example
36//!
37//! See [this example][example] showing how upgrades work with both
38//! Clients and Servers.
39//!
40//! [example]: https://github.com/hyperium/hyper/blob/master/examples/upgrades.rs
41
42use std::any::TypeId;
43use std::error::Error as StdError;
44use std::fmt;
45use std::io;
46use std::marker::Unpin;
47
48use bytes::Bytes;
49use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
50use tokio::sync::oneshot;
51#[cfg(any(feature = "http1", feature = "http2"))]
52use tracing::trace;
53
54use crate::common::io::Rewind;
55use crate::common::{task, Future, Pin, Poll};
56
57/// An upgraded HTTP connection.
58///
59/// This type holds a trait object internally of the original IO that
60/// was used to speak HTTP before the upgrade. It can be used directly
61/// as a `Read` or `Write` for convenience.
62///
63/// Alternatively, if the exact type is known, this can be deconstructed
64/// into its parts.
65pub struct Upgraded {
66    io: Rewind<Box<dyn Io + Send>>,
67}
68
69/// A future for a possible HTTP upgrade.
70///
71/// If no upgrade was available, or it doesn't succeed, yields an `Error`.
72pub struct OnUpgrade {
73    rx: Option<oneshot::Receiver<crate::Result<Upgraded>>>,
74}
75
76/// The deconstructed parts of an [`Upgraded`](Upgraded) type.
77///
78/// Includes the original IO type, and a read buffer of bytes that the
79/// HTTP state machine may have already read before completing an upgrade.
80#[derive(Debug)]
81pub struct Parts<T> {
82    /// The original IO object used before the upgrade.
83    pub io: T,
84    /// A buffer of bytes that have been read but not processed as HTTP.
85    ///
86    /// For instance, if the `Connection` is used for an HTTP upgrade request,
87    /// it is possible the server sent back the first bytes of the new protocol
88    /// along with the response upgrade.
89    ///
90    /// You will want to check for any existing bytes if you plan to continue
91    /// communicating on the IO object.
92    pub read_buf: Bytes,
93    _inner: (),
94}
95
96/// Gets a pending HTTP upgrade from this message.
97///
98/// This can be called on the following types:
99///
100/// - `http::Request<B>`
101/// - `http::Response<B>`
102/// - `&mut http::Request<B>`
103/// - `&mut http::Response<B>`
104pub fn on<T: sealed::CanUpgrade>(msg: T) -> OnUpgrade {
105    msg.on_upgrade()
106}
107
108#[cfg(any(feature = "http1", feature = "http2"))]
109pub(super) struct Pending {
110    tx: oneshot::Sender<crate::Result<Upgraded>>,
111}
112
113#[cfg(any(feature = "http1", feature = "http2"))]
114pub(super) fn pending() -> (Pending, OnUpgrade) {
115    let (tx, rx) = oneshot::channel();
116    (Pending { tx }, OnUpgrade { rx: Some(rx) })
117}
118
119// ===== impl Upgraded =====
120
121impl Upgraded {
122    #[cfg(any(feature = "http1", feature = "http2", test))]
123    pub(super) fn new<T>(io: T, read_buf: Bytes) -> Self
124    where
125        T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
126    {
127        Upgraded {
128            io: Rewind::new_buffered(Box::new(io), read_buf),
129        }
130    }
131
132    /// Tries to downcast the internal trait object to the type passed.
133    ///
134    /// On success, returns the downcasted parts. On error, returns the
135    /// `Upgraded` back.
136    pub fn downcast<T: AsyncRead + AsyncWrite + Unpin + 'static>(self) -> Result<Parts<T>, Self> {
137        let (io, buf) = self.io.into_inner();
138        match io.__hyper_downcast() {
139            Ok(t) => Ok(Parts {
140                io: *t,
141                read_buf: buf,
142                _inner: (),
143            }),
144            Err(io) => Err(Upgraded {
145                io: Rewind::new_buffered(io, buf),
146            }),
147        }
148    }
149}
150
151impl AsyncRead for Upgraded {
152    fn poll_read(
153        mut self: Pin<&mut Self>,
154        cx: &mut task::Context<'_>,
155        buf: &mut ReadBuf<'_>,
156    ) -> Poll<io::Result<()>> {
157        Pin::new(&mut self.io).poll_read(cx, buf)
158    }
159}
160
161impl AsyncWrite for Upgraded {
162    fn poll_write(
163        mut self: Pin<&mut Self>,
164        cx: &mut task::Context<'_>,
165        buf: &[u8],
166    ) -> Poll<io::Result<usize>> {
167        Pin::new(&mut self.io).poll_write(cx, buf)
168    }
169
170    fn poll_write_vectored(
171        mut self: Pin<&mut Self>,
172        cx: &mut task::Context<'_>,
173        bufs: &[io::IoSlice<'_>],
174    ) -> Poll<io::Result<usize>> {
175        Pin::new(&mut self.io).poll_write_vectored(cx, bufs)
176    }
177
178    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
179        Pin::new(&mut self.io).poll_flush(cx)
180    }
181
182    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
183        Pin::new(&mut self.io).poll_shutdown(cx)
184    }
185
186    fn is_write_vectored(&self) -> bool {
187        self.io.is_write_vectored()
188    }
189}
190
191impl fmt::Debug for Upgraded {
192    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193        f.debug_struct("Upgraded").finish()
194    }
195}
196
197// ===== impl OnUpgrade =====
198
199impl OnUpgrade {
200    pub(super) fn none() -> Self {
201        OnUpgrade { rx: None }
202    }
203
204    #[cfg(feature = "http1")]
205    pub(super) fn is_none(&self) -> bool {
206        self.rx.is_none()
207    }
208}
209
210impl Future for OnUpgrade {
211    type Output = Result<Upgraded, crate::Error>;
212
213    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
214        match self.rx {
215            Some(ref mut rx) => Pin::new(rx).poll(cx).map(|res| match res {
216                Ok(Ok(upgraded)) => Ok(upgraded),
217                Ok(Err(err)) => Err(err),
218                Err(_oneshot_canceled) => Err(crate::Error::new_canceled().with(UpgradeExpected)),
219            }),
220            None => Poll::Ready(Err(crate::Error::new_user_no_upgrade())),
221        }
222    }
223}
224
225impl fmt::Debug for OnUpgrade {
226    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
227        f.debug_struct("OnUpgrade").finish()
228    }
229}
230
231// ===== impl Pending =====
232
233#[cfg(any(feature = "http1", feature = "http2"))]
234impl Pending {
235    pub(super) fn fulfill(self, upgraded: Upgraded) {
236        trace!("pending upgrade fulfill");
237        let _ = self.tx.send(Ok(upgraded));
238    }
239
240    #[cfg(feature = "http1")]
241    /// Don't fulfill the pending Upgrade, but instead signal that
242    /// upgrades are handled manually.
243    pub(super) fn manual(self) {
244        trace!("pending upgrade handled manually");
245        let _ = self.tx.send(Err(crate::Error::new_user_manual_upgrade()));
246    }
247}
248
249// ===== impl UpgradeExpected =====
250
251/// Error cause returned when an upgrade was expected but canceled
252/// for whatever reason.
253///
254/// This likely means the actual `Conn` future wasn't polled and upgraded.
255#[derive(Debug)]
256struct UpgradeExpected;
257
258impl fmt::Display for UpgradeExpected {
259    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
260        f.write_str("upgrade expected but not completed")
261    }
262}
263
264impl StdError for UpgradeExpected {}
265
266// ===== impl Io =====
267
268pub(super) trait Io: AsyncRead + AsyncWrite + Unpin + 'static {
269    fn __hyper_type_id(&self) -> TypeId {
270        TypeId::of::<Self>()
271    }
272}
273
274impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Io for T {}
275
276impl dyn Io + Send {
277    fn __hyper_is<T: Io>(&self) -> bool {
278        let t = TypeId::of::<T>();
279        self.__hyper_type_id() == t
280    }
281
282    fn __hyper_downcast<T: Io>(self: Box<Self>) -> Result<Box<T>, Box<Self>> {
283        if self.__hyper_is::<T>() {
284            // Taken from `std::error::Error::downcast()`.
285            unsafe {
286                let raw: *mut dyn Io = Box::into_raw(self);
287                Ok(Box::from_raw(raw as *mut T))
288            }
289        } else {
290            Err(self)
291        }
292    }
293}
294
295mod sealed {
296    use super::OnUpgrade;
297
298    pub trait CanUpgrade {
299        fn on_upgrade(self) -> OnUpgrade;
300    }
301
302    impl<B> CanUpgrade for http::Request<B> {
303        fn on_upgrade(mut self) -> OnUpgrade {
304            self.extensions_mut()
305                .remove::<OnUpgrade>()
306                .unwrap_or_else(OnUpgrade::none)
307        }
308    }
309
310    impl<B> CanUpgrade for &'_ mut http::Request<B> {
311        fn on_upgrade(self) -> OnUpgrade {
312            self.extensions_mut()
313                .remove::<OnUpgrade>()
314                .unwrap_or_else(OnUpgrade::none)
315        }
316    }
317
318    impl<B> CanUpgrade for http::Response<B> {
319        fn on_upgrade(mut self) -> OnUpgrade {
320            self.extensions_mut()
321                .remove::<OnUpgrade>()
322                .unwrap_or_else(OnUpgrade::none)
323        }
324    }
325
326    impl<B> CanUpgrade for &'_ mut http::Response<B> {
327        fn on_upgrade(self) -> OnUpgrade {
328            self.extensions_mut()
329                .remove::<OnUpgrade>()
330                .unwrap_or_else(OnUpgrade::none)
331        }
332    }
333}
334
335#[cfg(test)]
336mod tests {
337    use super::*;
338
339    #[test]
340    fn upgraded_downcast() {
341        let upgraded = Upgraded::new(Mock, Bytes::new());
342
343        let upgraded = upgraded.downcast::<std::io::Cursor<Vec<u8>>>().unwrap_err();
344
345        upgraded.downcast::<Mock>().unwrap();
346    }
347
348    // TODO: replace with tokio_test::io when it can test write_buf
349    struct Mock;
350
351    impl AsyncRead for Mock {
352        fn poll_read(
353            self: Pin<&mut Self>,
354            _cx: &mut task::Context<'_>,
355            _buf: &mut ReadBuf<'_>,
356        ) -> Poll<io::Result<()>> {
357            unreachable!("Mock::poll_read")
358        }
359    }
360
361    impl AsyncWrite for Mock {
362        fn poll_write(
363            self: Pin<&mut Self>,
364            _: &mut task::Context<'_>,
365            buf: &[u8],
366        ) -> Poll<io::Result<usize>> {
367            // panic!("poll_write shouldn't be called");
368            Poll::Ready(Ok(buf.len()))
369        }
370
371        fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
372            unreachable!("Mock::poll_flush")
373        }
374
375        fn poll_shutdown(
376            self: Pin<&mut Self>,
377            _cx: &mut task::Context<'_>,
378        ) -> Poll<io::Result<()>> {
379            unreachable!("Mock::poll_shutdown")
380        }
381    }
382}