http_sse/
server.rs

1// Copyright 2019 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::Event;
6use futures::channel::{mpsc, oneshot};
7use futures::future::Future;
8use futures::lock::Mutex;
9use futures::stream::Stream;
10use futures::task::{Context, Poll};
11use hyper::body::{Body, Bytes};
12use hyper::{Response, StatusCode};
13use std::mem::replace;
14use std::ops::DerefMut;
15use std::pin::Pin;
16use std::sync::Arc;
17
18pub struct SseResponseCreator {
19    buffer_size: usize,
20    clients: Arc<Mutex<Vec<Client>>>,
21}
22
23impl SseResponseCreator {
24    /// hyper `Response` `Body`s created by this `SseResponseCreator` will buffer
25    /// `buffer_size + 1` `Events` before the `Body` stream is closed for falling too far behind.
26    pub fn with_additional_buffer_size(buffer_size: usize) -> (Self, EventSender) {
27        let clients = Arc::new(Mutex::new(vec![]));
28        (Self { buffer_size, clients: Arc::clone(&clients) }, EventSender { clients })
29    }
30
31    /// Creates hyper `Response`s whose `Body`s receive `Event`s from the `EventSender` associated
32    /// with this `SseResponseCreator`.
33    pub async fn create(&self) -> Response<Body> {
34        let (abort_tx, abort_rx) = oneshot::channel();
35        let (chunk_tx, chunk_rx) = mpsc::channel(self.buffer_size);
36        self.clients.lock().await.push(Client { abort_tx, chunk_tx });
37        Response::builder()
38            .status(StatusCode::OK)
39            .header("content-type", "text/event-stream")
40            .body(Body::wrap_stream(BodyAbortStream { abort_rx, chunk_rx }))
41            .unwrap() // builder arguments are all statically determined, build will not fail
42    }
43}
44
45pub struct EventSender {
46    clients: Arc<Mutex<Vec<Client>>>,
47}
48
49impl EventSender {
50    /// Send an `Event` to each connected client. Clients that have fallen too far behind have
51    /// their connections closed.
52    pub async fn send(&self, event: &Event) {
53        let mut clients_guard = self.clients.lock().await;
54        let clients = replace(DerefMut::deref_mut(&mut clients_guard), vec![]);
55        let clients = clients
56            .into_iter()
57            .filter_map(|mut c| {
58                if c.try_send(event).is_ok() {
59                    Some(c)
60                } else {
61                    let _ = c.abort();
62                    None
63                }
64            })
65            .collect();
66        *clients_guard = clients;
67    }
68
69    /// Number of clients that `send` will attempt to communicate with.
70    pub async fn client_count(&self) -> usize {
71        self.clients.lock().await.len()
72    }
73
74    /// Drops all connected clients. Already existing `Response<Body>`s created by the
75    /// `SseResponseCreator` should return error on subsequent `poll_next`.
76    pub async fn drop_all_clients(&self) {
77        self.clients.lock().await.clear();
78    }
79}
80
81// reimplementation of the body created by hyper::body::body::channel() b/c hyper doesn't allow
82// specifying the buffer size and doesn't provide an abort channel.
83struct BodyAbortStream {
84    abort_rx: oneshot::Receiver<()>,
85    chunk_rx: mpsc::Receiver<Bytes>,
86}
87
88impl Stream for BodyAbortStream {
89    type Item = Result<Bytes, &'static str>;
90
91    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
92        if let Poll::Ready(_) = Pin::new(&mut self.abort_rx).poll(cx) {
93            return Poll::Ready(Some(Err("client dropped")));
94        }
95        match Pin::new(&mut self.chunk_rx).poll_next(cx) {
96            Poll::Ready(Some(chunk)) => Poll::Ready(Some(Ok(chunk))),
97            Poll::Ready(None) => Poll::Ready(None),
98            Poll::Pending => Poll::Pending,
99        }
100    }
101}
102
103// Reimplementation of hyper::body::Sender b/c in-tree hyper doesn't allow specifying the buffer
104// size and doesn't provide an abort channel.
105struct Client {
106    abort_tx: oneshot::Sender<()>,
107    chunk_tx: mpsc::Sender<Bytes>,
108}
109
110impl Client {
111    fn try_send(&mut self, event: &Event) -> Result<(), ()> {
112        self.chunk_tx.try_send(event.to_vec().into()).map_err(|_| ())
113    }
114    fn abort(self) {
115        let _ = self.abort_tx.send(());
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use fuchsia_async as fasync;
    use futures::StreamExt;

    #[fasync::run_singlethreaded(test)]
    async fn response_headers() {
        let (sse_response_creator, _) = SseResponseCreator::with_additional_buffer_size(0);
        let resp = sse_response_creator.create().await;

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get("content-type").map(|h| h.as_bytes()),
            Some(&b"text/event-stream"[..])
        );
    }

    #[fasync::run_singlethreaded(test)]
    async fn response_correct_body_single_event() {
        let event = Event::from_type_and_data("event_type", "data_contents").unwrap();
        let (sse_response_creator, event_sender) =
            SseResponseCreator::with_additional_buffer_size(0);
        let resp = sse_response_creator.create().await;

        event_sender.send(&event).await;
        let mut body_stream = resp.into_body();
        let body_bytes = body_stream.next().await;

        assert_eq!(body_bytes.unwrap().unwrap().to_vec(), event.to_vec());
    }

    #[fasync::run_singlethreaded(test)]
    async fn full_client_dropped_other_clients_continue_to_receive_events() {
        let event0 = Event::from_type_and_data("event_type0", "data_contents0").unwrap();
        let (sse_response_creator, event_sender) =
            SseResponseCreator::with_additional_buffer_size(0);
        assert_eq!(event_sender.client_count().await, 0);

        let mut body_stream0 = sse_response_creator.create().await.into_body();
        let mut body_stream1 = sse_response_creator.create().await.into_body();
        assert_eq!(event_sender.client_count().await, 2);

        event_sender.send(&event0).await;

        let body_bytes1 = body_stream1.next().await;

        assert_matches!(body_bytes1, Some(Ok(chunk)) if chunk.to_vec() == event0.to_vec());

        let event1 = Event::from_type_and_data("event_type1", "data_contents1").unwrap();
        event_sender.send(&event1).await;
        assert_eq!(event_sender.client_count().await, 1);

        let body_bytes0 = body_stream0.next().await;
        assert_matches!(body_bytes0, Some(Err(_)));

        let body_bytes1 = body_stream1.next().await;
        assert_eq!(body_bytes1.unwrap().unwrap().to_vec(), event1.to_vec());
    }

    #[fasync::run_singlethreaded(test)]
    async fn drop_all_clients() {
        let (sse_response_creator, event_sender) =
            SseResponseCreator::with_additional_buffer_size(0);
        let mut body_stream0 = sse_response_creator.create().await.into_body();
        let mut body_stream1 = sse_response_creator.create().await.into_body();
        assert_eq!(event_sender.client_count().await, 2);
        event_sender.send(&Event::from_type_and_data("event_type", "data_contents").unwrap()).await;

        event_sender.drop_all_clients().await;

        assert_eq!(event_sender.client_count().await, 0);
        assert_matches!(body_stream0.next().await, Some(Err(_)));
        assert_matches!(body_stream1.next().await, Some(Err(_)));
    }

    #[fasync::run_singlethreaded(test)]
    async fn disconnected_client_removed_on_next_send() {
        let (sse_response_creator, event_sender) =
            SseResponseCreator::with_additional_buffer_size(0);
        let resp = sse_response_creator.create().await;
        assert_eq!(event_sender.client_count().await, 1);

        // Dropping the response drops its body stream, which closes the client's chunk channel,
        // so the next delivery attempt fails and the client is pruned.
        drop(resp);
        event_sender.send(&Event::from_type_and_data("event_type", "data_contents").unwrap()).await;

        assert_eq!(event_sender.client_count().await, 0);
    }
}