use core::future::Future;
use core::pin::Pin;
use core::sync::atomic::{AtomicBool, Ordering};
use core::task::{Context, Poll};
use std::sync::{Arc, Mutex};
use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
use futures::StreamExt as _;
use crate::protocol::lockers::Lockers;
use crate::protocol::{encode_buffer, MessageBuffer, ProtocolError, Transport};
use crate::{Decoder, Encode, EncodeError, Encoder};
/// Splits `transport` into the three halves of a client connection.
///
/// Returns:
/// - a [`Dispatcher`] that must be run to receive messages,
/// - a [`Client`] used to send requests and transactions,
/// - an [`Events`] stream yielding messages that are not responses to a transaction.
///
/// The dispatcher and client share the same transaction table.
pub fn make_client<T: Transport>(transport: T) -> (Dispatcher<T>, Client<T>, Events<T>) {
    let (event_tx, event_rx) = unbounded();
    let (transport_tx, transport_rx) = transport.split();
    let shared = Arc::new(Shared::new());

    let dispatcher = Dispatcher {
        shared: Arc::clone(&shared),
        receiver: transport_rx,
        sender: event_tx,
    };
    let client = Client { shared, sender: transport_tx };
    let events = Events { receiver: event_rx };

    (dispatcher, client, events)
}
/// State shared between a [`Dispatcher`] and the [`Client`] created alongside it.
struct Shared<T: Transport> {
    /// Slots for in-flight transactions; slot `i` is addressed on the wire as
    /// transaction ID `i + 1`.
    transactions: Mutex<Lockers<MessageBuffer<T>>>,
    /// Set once the dispatcher's receive loop has exited.
    is_stopped: AtomicBool,
}

impl<T: Transport> Shared<T> {
    /// Creates an empty transaction table in the "running" state.
    fn new() -> Self {
        let transactions = Mutex::new(Lockers::new());
        let is_stopped = AtomicBool::new(false);
        Self { transactions, is_stopped }
    }
}
/// Receives messages from the transport and routes them either to the pending transaction
/// they answer or to the [`Events`] stream.
///
/// Transaction responses are only delivered while [`Dispatcher::run`] is being polled.
pub struct Dispatcher<T: Transport> {
    // Transaction table and stop flag shared with the `Client`.
    shared: Arc<Shared<T>>,
    // Receiving half of the split transport.
    receiver: T::Receiver,
    // Forwards messages with transaction ID 0, and a terminal error, to `Events`.
    sender: UnboundedSender<Result<MessageBuffer<T>, ProtocolError<T::Error>>>,
}
impl<T: Transport> Dispatcher<T> {
    /// Runs the receive loop until the transport closes or an error occurs.
    ///
    /// A terminal error is forwarded to the [`Events`] stream. Afterwards, the shared
    /// state is marked as stopped and every transaction slot's waker is woken.
    pub async fn run(&mut self)
    where
        for<'a> T::Decoder<'a>: Decoder<'a>,
    {
        let outcome = self.try_run().await;
        if let Err(error) = outcome {
            // Best effort: the `Events` receiver may already have been dropped.
            let _ = self.sender.unbounded_send(Err(error));
        }

        self.shared.is_stopped.store(true, Ordering::Relaxed);
        self.shared.transactions.lock().unwrap().wake_all();
    }

    /// Receives and routes messages until the transport reports end-of-stream.
    ///
    /// # Errors
    ///
    /// Returns a transport error, a header parse error, or
    /// [`ProtocolError::UnrequestedResponse`] for a response whose transaction ID has no
    /// matching slot or whose slot rejects the write.
    async fn try_run(&mut self) -> Result<(), ProtocolError<T::Error>>
    where
        for<'a> T::Decoder<'a>: Decoder<'a>,
    {
        loop {
            let received =
                T::recv(&mut self.receiver).await.map_err(ProtocolError::TransportError)?;
            let Some(raw) = received else {
                return Ok(());
            };

            let (txid, message) = MessageBuffer::parse_header(raw)?;
            match txid {
                // Transaction ID 0 is not a response; hand it to the event stream.
                0 => {
                    let _ = self.sender.unbounded_send(Ok(message));
                }
                _ => {
                    // Slot indices are offset by one from transaction IDs.
                    let index = txid - 1;
                    let mut transactions = self.shared.transactions.lock().unwrap();
                    let slot = transactions
                        .get(index)
                        .ok_or_else(|| ProtocolError::UnrequestedResponse(txid))?;
                    let free_slot = slot
                        .write(message)
                        .map_err(|_| ProtocolError::UnrequestedResponse(txid))?;
                    if free_slot {
                        transactions.free(index);
                    }
                }
            }
        }
    }
}
/// A handle for sending requests and transactions over a transport.
///
/// Cloneable whenever the transport's sender is; clones share the same transaction table.
pub struct Client<T: Transport> {
    shared: Arc<Shared<T>>,
    sender: T::Sender,
}

// Implemented by hand: `#[derive(Clone)]` would add a spurious `T: Clone` bound, making
// `Client<T>` non-cloneable for transports that are not themselves `Clone` even though
// only the sender needs to be cloned.
impl<T: Transport> Clone for Client<T>
where
    T::Sender: Clone,
{
    fn clone(&self) -> Self {
        Self { shared: Arc::clone(&self.shared), sender: self.sender.clone() }
    }
}
impl<T: Transport> Client<T> {
    /// Encodes `request` with `ordinal` and starts sending it with transaction ID 0,
    /// so no response slot is reserved for it.
    ///
    /// # Errors
    ///
    /// Returns an [`EncodeError`] if the message could not be encoded.
    pub fn send_request<'s, M>(
        &'s self,
        ordinal: u64,
        request: &mut M,
    ) -> Result<T::SendFuture<'s>, EncodeError>
    where
        for<'a> T::Encoder<'a>: Encoder,
        M: for<'a> Encode<T::Encoder<'a>>,
    {
        let untracked_txid = 0;
        Self::send_message(&self.sender, untracked_txid, ordinal, request)
    }

    /// Starts a two-way transaction.
    ///
    /// Reserves a response slot, then encodes `transaction` with transaction ID
    /// `slot + 1` and starts sending it. The returned [`TransactionFuture`] finishes
    /// the send and then waits for the matching response.
    ///
    /// # Errors
    ///
    /// Returns an [`EncodeError`] if encoding fails; the reserved slot is released first.
    pub fn send_transaction<'s, M>(
        &'s self,
        ordinal: u64,
        transaction: &mut M,
    ) -> Result<TransactionFuture<'s, T>, EncodeError>
    where
        for<'a> T::Encoder<'a>: Encoder,
        M: for<'a> Encode<T::Encoder<'a>>,
    {
        let index = self.shared.transactions.lock().unwrap().alloc();
        let txid = index + 1;

        let send_future = match Self::send_message(&self.sender, txid, ordinal, transaction) {
            Ok(send_future) => send_future,
            Err(error) => {
                // Nothing will ever answer this slot, so give it back immediately.
                self.shared.transactions.lock().unwrap().free(index);
                return Err(error);
            }
        };

        Ok(TransactionFuture {
            shared: &self.shared,
            index,
            ordinal,
            state: TransactionFutureState::Sending(send_future),
        })
    }

    /// Acquires a buffer from `sender`, encodes the header and `message` into it, and
    /// hands it to the transport for sending.
    fn send_message<'s, M>(
        sender: &'s T::Sender,
        txid: u32,
        ordinal: u64,
        message: &mut M,
    ) -> Result<T::SendFuture<'s>, EncodeError>
    where
        for<'a> T::Encoder<'a>: Encoder,
        M: for<'a> Encode<T::Encoder<'a>>,
    {
        let mut outgoing = T::acquire(sender);
        encode_buffer(&mut outgoing, txid, ordinal, message)?;
        let send = T::send(sender, outgoing);
        Ok(send)
    }
}
/// The phase a [`TransactionFuture`] is in.
enum TransactionFutureState<'a, T: 'a + Transport> {
    /// The request is still being sent; holds the transport's send future.
    Sending(T::SendFuture<'a>),
    /// The request has been sent; waiting for the dispatcher to deliver the response.
    Receiving,
    /// The future has produced its output and its slot has already been freed, so `Drop`
    /// has nothing left to release.
    Completed,
}
/// A future that finishes sending a transaction request and then resolves to its response.
///
/// Returned by [`Client::send_transaction`].
pub struct TransactionFuture<'a, T: Transport> {
    // Transaction table and stop flag shared with the dispatcher.
    shared: &'a Shared<T>,
    // Slot index in `shared.transactions`; the transaction ID sent on the wire is `index + 1`.
    index: u32,
    // Ordinal the response is required to carry.
    ordinal: u64,
    // Current phase; `Sending` structurally pins the transport's send future.
    state: TransactionFutureState<'a, T>,
}
impl<T: Transport> Drop for TransactionFuture<'_, T> {
    /// Releases this transaction's slot if it has not been released already.
    ///
    /// A slot that is still waiting for a response is canceled first and only freed
    /// here when `cancel` returns `true`.
    fn drop(&mut self) {
        let mut transactions = self.shared.transactions.lock().unwrap();
        let release = match &self.state {
            TransactionFutureState::Sending(_) => true,
            TransactionFutureState::Receiving => transactions.get(self.index).unwrap().cancel(),
            TransactionFutureState::Completed => false,
        };
        if release {
            transactions.free(self.index);
        }
    }
}
impl<T: Transport> TransactionFuture<'_, T> {
fn poll_receiving(&mut self, cx: &mut Context<'_>) -> Poll<<Self as Future>::Output> {
let mut transactions = self.shared.transactions.lock().unwrap();
if let Some(ready) = transactions.get(self.index).unwrap().read(cx.waker()) {
transactions.free(self.index);
self.state = TransactionFutureState::Completed;
if ready.ordinal() != self.ordinal {
return Poll::Ready(Err(ProtocolError::InvalidResponseOrdinal {
expected: self.ordinal,
actual: ready.ordinal(),
}));
}
Poll::Ready(Ok(ready))
} else {
Poll::Pending
}
}
}
impl<T: Transport> Future for TransactionFuture<'_, T> {
    /// The response message, or the protocol/transport error that ended the transaction.
    type Output = Result<MessageBuffer<T>, ProtocolError<T::Error>>;
    /// Drives the send future to completion, then waits for the response.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has completed.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: nothing is ever moved out of `this`. The only field that may be `!Unpin`
        // is the send future inside `state`; it is re-pinned below and is otherwise only
        // dropped in place by reassigning `this.state`, which `Pin` permits. `Drop` for this
        // type only inspects `state` by reference.
        let this = unsafe { Pin::into_inner_unchecked(self) };
        match &mut this.state {
            TransactionFutureState::Sending(future) => {
                // SAFETY: `future` lives inside `*this`, which the caller has pinned, and it
                // stays in place until it is dropped by one of the state assignments below.
                let pinned = unsafe { Pin::new_unchecked(future) };
                match pinned.poll(cx) {
                    Poll::Pending => {
                        // Best-effort early exit once the dispatcher has stopped. This runs only
                        // when the send future wakes us. NOTE(review): no waker has been
                        // registered with the transaction slot yet in this state (that happens
                        // in `poll_receiving`), so presumably the dispatcher's `wake_all` cannot
                        // wake this future.
                        if this.shared.is_stopped.load(Ordering::Relaxed) {
                            return Poll::Ready(Err(ProtocolError::DispatcherStopped));
                        }
                        Poll::Pending
                    }
                    Poll::Ready(Ok(())) => {
                        // Sent: drop the send future in place and start waiting for the reply.
                        this.state = TransactionFutureState::Receiving;
                        this.poll_receiving(cx)
                    }
                    Poll::Ready(Err(e)) => {
                        // Free the slot here and mark `Completed` so `Drop` does not free it
                        // a second time.
                        this.shared.transactions.lock().unwrap().free(this.index);
                        this.state = TransactionFutureState::Completed;
                        Poll::Ready(Err(ProtocolError::TransportError(e)))
                    }
                }
            }
            TransactionFutureState::Receiving => this.poll_receiving(cx),
            // Polling a future after it returned `Ready` is a caller bug.
            TransactionFutureState::Completed => unreachable!(),
        }
    }
}
/// Stream of messages that arrive with transaction ID 0, plus the dispatcher's terminal error.
pub struct Events<T: Transport> {
    receiver: UnboundedReceiver<Result<MessageBuffer<T>, ProtocolError<T::Error>>>,
}

impl<T: Transport> Events<T> {
    /// Waits for the next event.
    ///
    /// Returns `Ok(Some(message))` for an event and `Ok(None)` once the stream has ended.
    ///
    /// # Errors
    ///
    /// Returns the error that stopped the dispatcher, if it forwarded one.
    pub async fn next(&mut self) -> Result<Option<MessageBuffer<T>>, ProtocolError<T::Error>> {
        match self.receiver.next().await {
            Some(Ok(message)) => Ok(Some(message)),
            Some(Err(error)) => Err(error),
            None => Ok(None),
        }
    }
}