1use crate::Event;
6use futures::channel::{mpsc, oneshot};
7use futures::future::Future;
8use futures::lock::Mutex;
9use futures::stream::Stream;
10use futures::task::{Context, Poll};
11use hyper::body::{Body, Bytes};
12use hyper::{Response, StatusCode};
13use std::mem::replace;
14use std::ops::DerefMut;
15use std::pin::Pin;
16use std::sync::Arc;
17
18pub struct SseResponseCreator {
19 buffer_size: usize,
20 clients: Arc<Mutex<Vec<Client>>>,
21}
22
23impl SseResponseCreator {
24 pub fn with_additional_buffer_size(buffer_size: usize) -> (Self, EventSender) {
27 let clients = Arc::new(Mutex::new(vec![]));
28 (Self { buffer_size, clients: Arc::clone(&clients) }, EventSender { clients })
29 }
30
31 pub async fn create(&self) -> Response<Body> {
34 let (abort_tx, abort_rx) = oneshot::channel();
35 let (chunk_tx, chunk_rx) = mpsc::channel(self.buffer_size);
36 self.clients.lock().await.push(Client { abort_tx, chunk_tx });
37 Response::builder()
38 .status(StatusCode::OK)
39 .header("content-type", "text/event-stream")
40 .body(Body::wrap_stream(BodyAbortStream { abort_rx, chunk_rx }))
41 .unwrap() }
43}
44
45pub struct EventSender {
46 clients: Arc<Mutex<Vec<Client>>>,
47}
48
49impl EventSender {
50 pub async fn send(&self, event: &Event) {
53 let mut clients_guard = self.clients.lock().await;
54 let clients = replace(DerefMut::deref_mut(&mut clients_guard), vec![]);
55 let clients = clients
56 .into_iter()
57 .filter_map(|mut c| {
58 if c.try_send(event).is_ok() {
59 Some(c)
60 } else {
61 let _ = c.abort();
62 None
63 }
64 })
65 .collect();
66 *clients_guard = clients;
67 }
68
69 pub async fn client_count(&self) -> usize {
71 self.clients.lock().await.len()
72 }
73
74 pub async fn drop_all_clients(&self) {
77 self.clients.lock().await.clear();
78 }
79}
80
81struct BodyAbortStream {
84 abort_rx: oneshot::Receiver<()>,
85 chunk_rx: mpsc::Receiver<Bytes>,
86}
87
88impl Stream for BodyAbortStream {
89 type Item = Result<Bytes, &'static str>;
90
91 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
92 if let Poll::Ready(_) = Pin::new(&mut self.abort_rx).poll(cx) {
93 return Poll::Ready(Some(Err("client dropped")));
94 }
95 match Pin::new(&mut self.chunk_rx).poll_next(cx) {
96 Poll::Ready(Some(chunk)) => Poll::Ready(Some(Ok(chunk))),
97 Poll::Ready(None) => Poll::Ready(None),
98 Poll::Pending => Poll::Pending,
99 }
100 }
101}
102
103struct Client {
106 abort_tx: oneshot::Sender<()>,
107 chunk_tx: mpsc::Sender<Bytes>,
108}
109
110impl Client {
111 fn try_send(&mut self, event: &Event) -> Result<(), ()> {
112 self.chunk_tx.try_send(event.to_vec().into()).map_err(|_| ())
113 }
114 fn abort(self) {
115 let _ = self.abort_tx.send(());
116 }
117}
118
// Tests for `SseResponseCreator` / `EventSender`. Several of them depend on the exact order of
// sends and reads relative to each client's channel capacity.
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use fuchsia_async::{self as fasync};
    use futures::StreamExt;

    // A created response is `200 OK` with the SSE content type.
    #[fasync::run_singlethreaded(test)]
    async fn response_headers() {
        let (sse_response_creator, _) = SseResponseCreator::with_additional_buffer_size(0);
        let resp = sse_response_creator.create().await;

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get("content-type").map(|h| h.as_bytes()),
            Some(&b"text/event-stream"[..])
        );
    }

    // An event sent before the body is read is delivered as the body's first chunk, byte-for-byte.
    #[fasync::run_singlethreaded(test)]
    async fn response_correct_body_single_event() {
        let event = Event::from_type_and_data("event_type", "data_contents").unwrap();
        let (sse_response_creator, event_sender) =
            SseResponseCreator::with_additional_buffer_size(0);
        let resp = sse_response_creator.create().await;

        event_sender.send(&event).await;
        let mut body_stream = resp.into_body();
        let body_bytes = body_stream.next().await;

        assert_eq!(body_bytes.unwrap().unwrap().to_vec(), event.to_vec());
    }

    // With an additional buffer size of 0, each client can still hold one undelivered event.
    // A client that never reads is dropped on the next send, while the others keep receiving.
    #[fasync::run_singlethreaded(test)]
    async fn full_client_dropped_other_clients_continue_to_receive_events() {
        let event0 = Event::from_type_and_data("event_type0", "data_contents0").unwrap();
        let (sse_response_creator, event_sender) =
            SseResponseCreator::with_additional_buffer_size(0);
        assert_eq!(event_sender.client_count().await, 0);

        let mut body_stream0 = sse_response_creator.create().await.into_body();
        let mut body_stream1 = sse_response_creator.create().await.into_body();
        assert_eq!(event_sender.client_count().await, 2);

        // event0 is queued for both clients.
        event_sender.send(&event0).await;

        // Only client 1 drains its queue; client 0's queue stays full.
        let body_bytes1 = body_stream1.next().await;

        assert_matches!(body_bytes1, Some(Ok(chunk)) if chunk.to_vec() == event0.to_vec());

        // Client 0 cannot accept event1, so it is aborted and removed; client 1 accepts it.
        let event1 = Event::from_type_and_data("event_type1", "data_contents1").unwrap();
        event_sender.send(&event1).await;
        assert_eq!(event_sender.client_count().await, 1);

        // The abort takes priority over event0, which is still queued for client 0.
        let body_bytes0 = body_stream0.next().await;
        assert_matches!(body_bytes0, Some(Err(_)));

        let body_bytes1 = body_stream1.next().await;
        assert_eq!(body_bytes1.unwrap().unwrap().to_vec(), event1.to_vec());
    }

    // Dropping all clients empties the registry. Each stream then yields an error even though an
    // unread event is still queued for it.
    #[fasync::run_singlethreaded(test)]
    async fn drop_all_clients() {
        let (sse_response_creator, event_sender) =
            SseResponseCreator::with_additional_buffer_size(0);
        let mut body_stream0 = sse_response_creator.create().await.into_body();
        let mut body_stream1 = sse_response_creator.create().await.into_body();
        assert_eq!(event_sender.client_count().await, 2);
        event_sender.send(&Event::from_type_and_data("event_type", "data_contents").unwrap()).await;

        event_sender.drop_all_clients().await;

        assert_eq!(event_sender.client_count().await, 0);
        assert_matches!(body_stream0.next().await, Some(Err(_)));
        assert_matches!(body_stream1.next().await, Some(Err(_)));
    }
}