use core::future::Future;
use core::pin::Pin;
use core::task::{Context, Poll};
use std::sync::{Arc, Mutex};
use crate::protocol::lockers::Lockers;
use crate::protocol::{decode_header, encode_header, ProtocolError, Transport};
use crate::{Encode, EncodeError, EncoderExt};
use super::lockers::LockerError;
/// State shared between a [`ClientSender`], its clones, and the [`ResponseFuture`]s it creates.
struct Shared<T>
where
    T: Transport,
{
    /// Mutex-guarded table of response slots, addressed by the index handed out by `alloc`.
    responses: Mutex<Lockers<T::RecvBuffer>>,
}
impl<T: Transport> Shared<T> {
    /// Creates shared state whose response table is a fresh `Lockers::new()`.
    fn new() -> Self {
        let responses = Mutex::new(Lockers::new());
        Self { responses }
    }
}
/// Sending half of a client: encodes messages and hands them to the transport.
///
/// Cloning produces another handle that uses the same shared response table.
pub struct ClientSender<T>
where
    T: Transport,
{
    /// Response bookkeeping shared with every clone and with the owning `Client`.
    shared: Arc<Shared<T>>,
    /// The transport's sender handle.
    sender: T::Sender,
}
impl<T: Transport> ClientSender<T> {
    /// Closes the transport by delegating to `T::close` on the underlying sender.
    pub fn close(&self) {
        T::close(&self.sender);
    }

    /// Encodes `request` and sends it under `ordinal` without reserving a response slot.
    ///
    /// The message header carries transaction id `0`.
    ///
    /// # Errors
    ///
    /// Returns an [`EncodeError`] if encoding the header or the request fails.
    pub fn send_one_way<M>(
        &self,
        ordinal: u64,
        request: &mut M,
    ) -> Result<T::SendFuture<'_>, EncodeError>
    where
        M: for<'a> Encode<T::Encoder<'a>>,
    {
        self.send_message(0, ordinal, request)
    }

    /// Encodes `request`, sends it under `ordinal`, and returns a future for the response.
    ///
    /// A response slot is allocated first; the message is sent with transaction id
    /// `index + 1`, so two-way messages never use the id `0` that one-way messages carry.
    ///
    /// # Errors
    ///
    /// Returns an [`EncodeError`] if encoding fails. The allocated slot is freed before the
    /// error is returned.
    ///
    /// # Panics
    ///
    /// Panics if the response table's mutex is poisoned.
    pub fn send_two_way<M>(
        &self,
        ordinal: u64,
        request: &mut M,
    ) -> Result<ResponseFuture<'_, T>, EncodeError>
    where
        M: for<'a> Encode<T::Encoder<'a>>,
    {
        let index = self.shared.responses.lock().unwrap().alloc(ordinal);

        let sending = match self.send_message(index + 1, ordinal, request) {
            Ok(sending) => sending,
            Err(error) => {
                // Nothing was sent, so the slot reserved above is released again.
                self.shared.responses.lock().unwrap().free(index);
                return Err(error);
            }
        };

        Ok(ResponseFuture {
            shared: &self.shared,
            index,
            state: ResponseFutureState::Sending(sending),
        })
    }

    /// Acquires a send buffer, writes the header and `message` into it, and starts the send.
    ///
    /// Returns the transport's send future, or the first encoding error encountered.
    fn send_message<M>(
        &self,
        txid: u32,
        ordinal: u64,
        message: &mut M,
    ) -> Result<T::SendFuture<'_>, EncodeError>
    where
        M: for<'a> Encode<T::Encoder<'a>>,
    {
        let sender = &self.sender;
        let mut buffer = T::acquire(sender);

        encode_header::<T>(&mut buffer, txid, ordinal)?;
        T::encoder(&mut buffer).encode_next(message)?;

        let sending = T::send(sender, buffer);
        Ok(sending)
    }
}
impl<T: Transport> Clone for ClientSender<T> {
fn clone(&self) -> Self {
Self { shared: self.shared.clone(), sender: self.sender.clone() }
}
}
/// Progress of a [`ResponseFuture`].
enum ResponseFutureState<'a, T>
where
    T: Transport + 'a,
{
    /// The request's send future has not completed yet.
    Sending(T::SendFuture<'a>),
    /// The send completed successfully; the future is waiting on its response slot.
    Receiving,
    /// The future has produced its output (a response or a send error).
    Completed,
}
/// Future returned by `ClientSender::send_two_way`.
///
/// Resolves to the response buffer, or to the transport error if sending fails.
pub struct ResponseFuture<'a, T>
where
    T: Transport,
{
    /// Shared response table of the sender that created this future.
    shared: &'a Shared<T>,
    /// Index of this request's slot in the response table.
    index: u32,
    /// Whether the request is still being sent, awaiting its response, or finished.
    state: ResponseFutureState<'a, T>,
}
impl<T: Transport> Drop for ResponseFuture<'_, T> {
    /// Releases this future's claim on its response slot.
    ///
    /// * `Sending`: the slot is freed directly.
    /// * `Receiving`: the slot's `cancel()` is called, and the slot is freed here only when
    ///   that returns `true`.
    /// * `Completed`: nothing to do; every path that sets `Completed` frees the slot first.
    ///
    /// If the response table's mutex is poisoned, cleanup is skipped instead of panicking.
    fn drop(&mut self) {
        // A poisoned mutex means a panic occurred while the lock was held. This destructor
        // may itself be running as part of that unwind, where a second panic would abort
        // the process, so give up on cleanup rather than unwrapping the lock result.
        let Ok(mut responses) = self.shared.responses.lock() else {
            return;
        };

        match self.state {
            ResponseFutureState::Sending(_) => responses.free(self.index),
            ResponseFutureState::Receiving => {
                // When `cancel()` returns `false` the slot is left allocated; presumably the
                // receive loop frees it once the response is written — confirm in `Lockers`.
                if responses.get(self.index).unwrap().cancel() {
                    responses.free(self.index);
                }
            }
            ResponseFutureState::Completed => (),
        }
    }
}
impl<T: Transport> ResponseFuture<'_, T> {
    /// Polls this future's response slot.
    ///
    /// Passes `cx.waker()` to the slot's `read`. If a buffer is returned, the slot is freed,
    /// the state becomes `Completed`, and the buffer is yielded; otherwise returns `Pending`.
    ///
    /// # Panics
    ///
    /// Panics if the response table's mutex is poisoned or the slot does not exist.
    fn poll_receiving(&mut self, cx: &mut Context<'_>) -> Poll<<Self as Future>::Output> {
        let mut responses = self.shared.responses.lock().unwrap();

        let Some(buffer) = responses.get(self.index).unwrap().read(cx.waker()) else {
            return Poll::Pending;
        };

        // The slot is released while the lock is still held, before reporting completion.
        responses.free(self.index);
        self.state = ResponseFutureState::Completed;
        Poll::Ready(Ok(buffer))
    }
}
impl<T: Transport> Future for ResponseFuture<'_, T> {
    /// The response buffer on success, or the transport error if the send failed.
    type Output = Result<T::RecvBuffer, T::Error>;

    /// Drives the send to completion, then waits for the response.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has already returned `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: nothing below moves the send future out of `this.state`; it is only polled
        // in place or dropped in place when the state is overwritten.
        let this = unsafe { Pin::into_inner_unchecked(self) };

        let send_result = match &mut this.state {
            ResponseFutureState::Receiving => return this.poll_receiving(cx),
            // Only reachable by polling after completion.
            ResponseFutureState::Completed => unreachable!(),
            ResponseFutureState::Sending(future) => {
                // SAFETY: `future` lives inside `this`, which was pinned by the caller, and it
                // is not moved while the state remains `Sending`.
                let pinned = unsafe { Pin::new_unchecked(future) };
                match pinned.poll(cx) {
                    // The send is still in flight; stay in `Sending`.
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(send_result) => send_result,
                }
            }
        };

        match send_result {
            Ok(()) => {
                // Sent successfully: switch to receiving and poll the slot right away.
                this.state = ResponseFutureState::Receiving;
                this.poll_receiving(cx)
            }
            Err(error) => {
                // The send failed: release the slot and finish with the error.
                this.shared.responses.lock().unwrap().free(this.index);
                this.state = ResponseFutureState::Completed;
                Poll::Ready(Err(error))
            }
        }
    }
}
/// Callbacks invoked by `Client::run` for received messages that are not responses.
pub trait ClientHandler<T>
where
    T: Transport,
{
    /// Called for each received message whose transaction id is `0`.
    ///
    /// `sender` is the client's sender, `ordinal` is the value decoded from the message
    /// header, and `buffer` is the receive buffer after the header has been decoded.
    fn on_event(&mut self, sender: &ClientSender<T>, ordinal: u64, buffer: T::RecvBuffer);
}
/// A protocol client: a sender plus the transport's receiving half.
pub struct Client<T>
where
    T: Transport,
{
    /// Sender handle; also holds the shared response table.
    sender: ClientSender<T>,
    /// Receiving half of the transport, read by `run`.
    receiver: T::Receiver,
}
impl<T: Transport> Client<T> {
    /// Splits `transport` into its sender and receiver and builds a client around them
    /// with a fresh, empty response table.
    pub fn new(transport: T) -> Self {
        let (sender, receiver) = transport.split();
        let sender = ClientSender { shared: Arc::new(Shared::new()), sender };
        Self { sender, receiver }
    }

    /// Returns the client's sender.
    pub fn sender(&self) -> &ClientSender<T> {
        &self.sender
    }

    /// Runs the receive loop, dispatching events to `handler` and responses to their slots.
    ///
    /// When the loop ends, whether cleanly or with an error, `wake_all` is called on the
    /// response table before the loop's result is returned.
    ///
    /// # Errors
    ///
    /// Returns a [`ProtocolError`] if receiving fails, a header cannot be decoded, or a
    /// response does not match an outstanding request.
    ///
    /// # Panics
    ///
    /// Panics if the response table's mutex is poisoned.
    pub async fn run<H>(&mut self, mut handler: H) -> Result<(), ProtocolError<T::Error>>
    where
        H: ClientHandler<T>,
    {
        let outcome = self.run_to_completion(&mut handler).await;
        self.sender.shared.responses.lock().unwrap().wake_all();
        outcome
    }

    /// Receives and dispatches messages until the transport yields `None` or an error occurs.
    ///
    /// Messages with transaction id `0` go to `handler.on_event`. Any other id addresses
    /// response slot `txid - 1`, and the buffer is written into that slot.
    async fn run_to_completion<H>(&mut self, handler: &mut H) -> Result<(), ProtocolError<T::Error>>
    where
        H: ClientHandler<T>,
    {
        loop {
            let received =
                T::recv(&mut self.receiver).await.map_err(ProtocolError::TransportError)?;
            let Some(mut buffer) = received else {
                return Ok(());
            };

            let (txid, ordinal) =
                decode_header::<T>(&mut buffer).map_err(ProtocolError::InvalidMessageHeader)?;

            match txid {
                0 => handler.on_event(&self.sender, ordinal, buffer),
                _ => {
                    // The guard is confined to this arm so it is released before the next
                    // `recv` is awaited.
                    let slot = txid - 1;
                    let mut responses = self.sender.shared.responses.lock().unwrap();
                    let Some(locker) = responses.get(slot) else {
                        return Err(ProtocolError::UnrequestedResponse(txid));
                    };

                    let should_free = match locker.write(ordinal, buffer) {
                        Ok(should_free) => should_free,
                        Err(LockerError::NotWriteable) => {
                            return Err(ProtocolError::UnrequestedResponse(txid));
                        }
                        Err(LockerError::MismatchedOrdinal { expected, actual }) => {
                            return Err(ProtocolError::InvalidResponseOrdinal { expected, actual });
                        }
                    };

                    // `write` returning `Ok(true)` means this loop must free the slot itself.
                    if should_free {
                        responses.free(slot);
                    }
                }
            }
        }
    }
}