1use std::any::TypeId;
43use std::error::Error as StdError;
44use std::fmt;
45use std::io;
46use std::marker::Unpin;
47
48use bytes::Bytes;
49use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
50use tokio::sync::oneshot;
51#[cfg(any(feature = "http1", feature = "http2"))]
52use tracing::trace;
53
54use crate::common::io::Rewind;
55use crate::common::{task, Future, Pin, Poll};
56
/// The type-erased IO object delivered when a pending upgrade is fulfilled.
///
/// Implements `AsyncRead` and `AsyncWrite` by forwarding to the wrapped IO, and can be
/// turned back into its concrete type with [`Upgraded::downcast`].
pub struct Upgraded {
    // Boxed so that any `AsyncRead + AsyncWrite + Unpin + Send + 'static` IO fits.
    // The `Rewind` wrapper also holds the `read_buf` passed to `Upgraded::new`
    // (presumably replayed before further reads from the IO — TODO confirm in `Rewind`).
    io: Rewind<Box<dyn Io + Send>>,
}
68
/// A future for a possible HTTP upgrade.
///
/// Resolves to the [`Upgraded`] IO once the paired sender delivers it, or to an error
/// when no upgrade is available or the sender goes away. Obtain one with [`on`].
pub struct OnUpgrade {
    // Receiving half of the channel created by `pending()`; `None` means there is no
    // upgrade to wait for (see `OnUpgrade::none`).
    rx: Option<oneshot::Receiver<crate::Result<Upgraded>>>,
}
75
/// The pieces of an [`Upgraded`] recovered by [`Upgraded::downcast`].
#[derive(Debug)]
pub struct Parts<T> {
    /// The original IO object, as its concrete type `T`.
    pub io: T,
    /// Bytes that were buffered alongside `io` inside the `Upgraded`.
    ///
    /// Presumably data already read from the connection but not yet consumed; callers
    /// likely need to handle it before reading from `io` again (confirm against `Rewind`).
    pub read_buf: Bytes,
    // Private field: code outside this module cannot build `Parts` with a struct literal.
    _inner: (),
}
95
96pub fn on<T: sealed::CanUpgrade>(msg: T) -> OnUpgrade {
105 msg.on_upgrade()
106}
107
/// The sending side of an upgrade; created together with its [`OnUpgrade`] by `pending()`.
#[cfg(any(feature = "http1", feature = "http2"))]
pub(super) struct Pending {
    // Consumed by `fulfill` or `manual`. If it is dropped without sending, the paired
    // `OnUpgrade` resolves to a canceled error carrying `UpgradeExpected`.
    tx: oneshot::Sender<crate::Result<Upgraded>>,
}
112
/// Creates a linked sender/receiver pair for one upgrade.
///
/// The `Pending` half is completed by the protocol code; the `OnUpgrade` half is the
/// future handed to the user.
#[cfg(any(feature = "http1", feature = "http2"))]
pub(super) fn pending() -> (Pending, OnUpgrade) {
    let (sender, receiver) = oneshot::channel();
    let pending = Pending { tx: sender };
    let on_upgrade = OnUpgrade { rx: Some(receiver) };
    (pending, on_upgrade)
}
118
119impl Upgraded {
122 #[cfg(any(feature = "http1", feature = "http2", test))]
123 pub(super) fn new<T>(io: T, read_buf: Bytes) -> Self
124 where
125 T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
126 {
127 Upgraded {
128 io: Rewind::new_buffered(Box::new(io), read_buf),
129 }
130 }
131
132 pub fn downcast<T: AsyncRead + AsyncWrite + Unpin + 'static>(self) -> Result<Parts<T>, Self> {
137 let (io, buf) = self.io.into_inner();
138 match io.__hyper_downcast() {
139 Ok(t) => Ok(Parts {
140 io: *t,
141 read_buf: buf,
142 _inner: (),
143 }),
144 Err(io) => Err(Upgraded {
145 io: Rewind::new_buffered(io, buf),
146 }),
147 }
148 }
149}
150
151impl AsyncRead for Upgraded {
152 fn poll_read(
153 mut self: Pin<&mut Self>,
154 cx: &mut task::Context<'_>,
155 buf: &mut ReadBuf<'_>,
156 ) -> Poll<io::Result<()>> {
157 Pin::new(&mut self.io).poll_read(cx, buf)
158 }
159}
160
161impl AsyncWrite for Upgraded {
162 fn poll_write(
163 mut self: Pin<&mut Self>,
164 cx: &mut task::Context<'_>,
165 buf: &[u8],
166 ) -> Poll<io::Result<usize>> {
167 Pin::new(&mut self.io).poll_write(cx, buf)
168 }
169
170 fn poll_write_vectored(
171 mut self: Pin<&mut Self>,
172 cx: &mut task::Context<'_>,
173 bufs: &[io::IoSlice<'_>],
174 ) -> Poll<io::Result<usize>> {
175 Pin::new(&mut self.io).poll_write_vectored(cx, bufs)
176 }
177
178 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
179 Pin::new(&mut self.io).poll_flush(cx)
180 }
181
182 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
183 Pin::new(&mut self.io).poll_shutdown(cx)
184 }
185
186 fn is_write_vectored(&self) -> bool {
187 self.io.is_write_vectored()
188 }
189}
190
191impl fmt::Debug for Upgraded {
192 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193 f.debug_struct("Upgraded").finish()
194 }
195}
196
197impl OnUpgrade {
200 pub(super) fn none() -> Self {
201 OnUpgrade { rx: None }
202 }
203
204 #[cfg(feature = "http1")]
205 pub(super) fn is_none(&self) -> bool {
206 self.rx.is_none()
207 }
208}
209
210impl Future for OnUpgrade {
211 type Output = Result<Upgraded, crate::Error>;
212
213 fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
214 match self.rx {
215 Some(ref mut rx) => Pin::new(rx).poll(cx).map(|res| match res {
216 Ok(Ok(upgraded)) => Ok(upgraded),
217 Ok(Err(err)) => Err(err),
218 Err(_oneshot_canceled) => Err(crate::Error::new_canceled().with(UpgradeExpected)),
219 }),
220 None => Poll::Ready(Err(crate::Error::new_user_no_upgrade())),
221 }
222 }
223}
224
225impl fmt::Debug for OnUpgrade {
226 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
227 f.debug_struct("OnUpgrade").finish()
228 }
229}
230
#[cfg(any(feature = "http1", feature = "http2"))]
impl Pending {
    /// Completes the upgrade by sending `upgraded` to the paired `OnUpgrade`.
    ///
    /// If the send fails (the receiving side is gone), the error is discarded.
    pub(super) fn fulfill(self, upgraded: Upgraded) {
        trace!("pending upgrade fulfill");
        let Pending { tx } = self;
        let _ = tx.send(Ok(upgraded));
    }

    /// Signals that the upgrade is being handled manually. The paired `OnUpgrade`
    /// resolves with the "manual upgrade" user error.
    ///
    /// A failed send is discarded, as in `fulfill`.
    #[cfg(feature = "http1")]
    pub(super) fn manual(self) {
        trace!("pending upgrade handled manually");
        let outcome = Err(crate::Error::new_user_manual_upgrade());
        let _ = self.tx.send(outcome);
    }
}
248
/// Error cause attached by `OnUpgrade`'s `Future` impl when the `Pending` sender was
/// dropped without ever sending a result.
#[derive(Debug)]
struct UpgradeExpected;

impl fmt::Display for UpgradeExpected {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "upgrade expected but not completed")
    }
}

// Relies on the default `source()`, which returns `None`.
impl StdError for UpgradeExpected {}
265
266pub(super) trait Io: AsyncRead + AsyncWrite + Unpin + 'static {
269 fn __hyper_type_id(&self) -> TypeId {
270 TypeId::of::<Self>()
271 }
272}
273
274impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Io for T {}
275
276impl dyn Io + Send {
277 fn __hyper_is<T: Io>(&self) -> bool {
278 let t = TypeId::of::<T>();
279 self.__hyper_type_id() == t
280 }
281
282 fn __hyper_downcast<T: Io>(self: Box<Self>) -> Result<Box<T>, Box<Self>> {
283 if self.__hyper_is::<T>() {
284 unsafe {
286 let raw: *mut dyn Io = Box::into_raw(self);
287 Ok(Box::from_raw(raw as *mut T))
288 }
289 } else {
290 Err(self)
291 }
292 }
293}
294
295mod sealed {
296 use super::OnUpgrade;
297
298 pub trait CanUpgrade {
299 fn on_upgrade(self) -> OnUpgrade;
300 }
301
302 impl<B> CanUpgrade for http::Request<B> {
303 fn on_upgrade(mut self) -> OnUpgrade {
304 self.extensions_mut()
305 .remove::<OnUpgrade>()
306 .unwrap_or_else(OnUpgrade::none)
307 }
308 }
309
310 impl<B> CanUpgrade for &'_ mut http::Request<B> {
311 fn on_upgrade(self) -> OnUpgrade {
312 self.extensions_mut()
313 .remove::<OnUpgrade>()
314 .unwrap_or_else(OnUpgrade::none)
315 }
316 }
317
318 impl<B> CanUpgrade for http::Response<B> {
319 fn on_upgrade(mut self) -> OnUpgrade {
320 self.extensions_mut()
321 .remove::<OnUpgrade>()
322 .unwrap_or_else(OnUpgrade::none)
323 }
324 }
325
326 impl<B> CanUpgrade for &'_ mut http::Response<B> {
327 fn on_upgrade(self) -> OnUpgrade {
328 self.extensions_mut()
329 .remove::<OnUpgrade>()
330 .unwrap_or_else(OnUpgrade::none)
331 }
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// Downcasting to the wrong type hands the `Upgraded` back; the right type succeeds.
    #[test]
    fn upgraded_downcast() {
        let original = Upgraded::new(Mock, Bytes::new());

        // Wrong concrete type: the `Upgraded` must come back in `Err`.
        let recovered = original
            .downcast::<std::io::Cursor<Vec<u8>>>()
            .unwrap_err();

        // Actual concrete type: the downcast must succeed.
        let _parts: Parts<Mock> = recovered.downcast::<Mock>().unwrap();
    }

    /// Minimal IO stand-in: writes report full success, every other operation panics.
    struct Mock;

    impl AsyncRead for Mock {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut task::Context<'_>,
            _buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_read")
        }
    }

    impl AsyncWrite for Mock {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut task::Context<'_>,
            data: &[u8],
        ) -> Poll<io::Result<usize>> {
            // Pretend the whole slice was written.
            let written = data.len();
            Poll::Ready(Ok(written))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_flush")
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            _cx: &mut task::Context<'_>,
        ) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_shutdown")
        }
    }
}