use std::{
future::{ready, Future, Ready},
io,
marker::PhantomData,
net::{IpAddr, Ipv6Addr, SocketAddr},
sync::Arc,
};
use btlib::{bterr, crypto::Creds};
use bytes::BytesMut;
use futures::SinkExt;
use quinn::{Connection, Endpoint, RecvStream, SendStream};
use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::Mutex;
use tokio_util::codec::{Framed, FramedParts};
use crate::{
receiver::{Envelope, ReplyEnvelope},
serialization::{CallbackFramed, MsgEncoder},
tls::{client_config, CertResolver},
BlockAddr, CallMsg, DeserCallback, Result, SendMsg,
};
/// A client which sends messages to the peer at a [`BlockAddr`] over a single QUIC
/// [`Connection`].
///
/// Every message is written on a newly opened bidirectional stream. Framing state and the
/// receive buffer are cached between messages so their allocations can be reused.
pub struct Transmitter {
    /// The address this transmitter connected to.
    addr: Arc<BlockAddr>,
    /// The connection on which a new bidirectional stream is opened for each message.
    connection: Connection,
    /// Framing state reused across sends. It is `None` before the first send. It is also `None`
    /// if a send was cancelled while the parts were checked out of this slot.
    send_parts: Mutex<Option<FramedParts<SendStream, MsgEncoder>>>,
    /// Buffer reused for reading replies. It is taken out while a reply is being read and put
    /// back afterwards.
    recv_buf: Mutex<Option<BytesMut>>,
}
/// Evaluates to the `Ok` value of `$result`.
///
/// If `$result` is an `Err`, the macro does two things before leaving the enclosing function:
/// it moves `$parts` back into the `Option` behind `$guard`, then it returns the error
/// (converted with `Into`).
///
/// A caller that took reusable state out of a lock can therefore return that state before
/// propagating a failure.
macro_rules! cleanup_on_err {
    ($result:expr, $guard:ident, $parts:ident) => {{
        match $result {
            Err(err) => {
                // Hand the borrowed state back so the next user of the lock can reuse it.
                *$guard = Some($parts);
                return Err(err.into());
            }
            Ok(value) => value,
        }
    }};
}
impl Transmitter {
pub async fn new<C: 'static + Creds + Send + Sync>(
addr: Arc<BlockAddr>,
creds: Arc<C>,
) -> Result<Transmitter> {
let resolver = Arc::new(CertResolver::new(creds)?);
let endpoint = Endpoint::client(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0))?;
Transmitter::from_endpoint(endpoint, addr, resolver).await
}
pub(crate) async fn from_endpoint(
endpoint: Endpoint,
addr: Arc<BlockAddr>,
resolver: Arc<CertResolver>,
) -> Result<Self> {
let socket_addr = addr.socket_addr()?;
let connecting = endpoint.connect_with(
client_config(addr.path().clone(), resolver)?,
socket_addr,
"UNIMPORTANT",
)?;
let connection = connecting.await?;
let send_parts = Mutex::new(None);
let recv_buf = Mutex::new(Some(BytesMut::new()));
Ok(Self {
addr,
connection,
send_parts,
recv_buf,
})
}
async fn transmit<T: Serialize>(&self, envelope: Envelope<T>) -> Result<RecvStream> {
let mut guard = self.send_parts.lock().await;
let (send_stream, recv_stream) = self.connection.open_bi().await?;
let parts = match guard.take() {
Some(mut parts) => {
parts.io = send_stream;
parts
}
None => FramedParts::new::<Envelope<T>>(send_stream, MsgEncoder::new()),
};
let mut sink = Framed::from_parts(parts);
let result = sink.send(envelope).await;
let parts = sink.into_parts();
cleanup_on_err!(result, guard, parts);
*guard = Some(parts);
Ok(recv_stream)
}
pub fn addr(&self) -> &Arc<BlockAddr> {
&self.addr
}
pub async fn send<'ser, 'de, T>(&self, msg: T) -> Result<()>
where
T: 'ser + SendMsg<'de>,
{
self.transmit(Envelope::send(msg)).await?;
Ok(())
}
pub async fn call<'ser, 'de, T, F>(&self, msg: T, callback: F) -> Result<F::Return>
where
T: 'ser + CallMsg<'de>,
F: 'static + Send + DeserCallback,
{
let recv_stream = self.transmit(Envelope::call(msg)).await?;
let mut guard = self.recv_buf.lock().await;
let buffer = guard.take().unwrap();
let mut callback_framed = CallbackFramed::from_parts(recv_stream, buffer);
let result = callback_framed
.next(ReplyCallback::new(callback))
.await
.ok_or_else(|| bterr!("server hung up before sending reply"));
let (_, buffer) = callback_framed.into_parts();
let output = cleanup_on_err!(result, guard, buffer);
*guard = Some(buffer);
output?
}
pub async fn call_through<'ser, T>(&self, msg: T) -> Result<T::Reply<'static>>
where
T: 'ser + CallMsg<'static>,
T::Reply<'static>: 'static + Send + Sync + DeserializeOwned,
{
self.call(msg, Passthrough::new()).await
}
}
struct ReplyCallback<F> {
inner: F,
}
impl<F> ReplyCallback<F> {
fn new(inner: F) -> Self {
Self { inner }
}
}
impl<F: 'static + Send + DeserCallback> DeserCallback for ReplyCallback<F> {
type Arg<'de> = ReplyEnvelope<F::Arg<'de>>;
type Return = Result<F::Return>;
type CallFut<'de> = impl 'de + Future<Output = Self::Return> + Send;
fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
async move {
match arg {
ReplyEnvelope::Ok(msg) => Ok(self.inner.call(msg).await),
ReplyEnvelope::Err { message, os_code } => {
if let Some(os_code) = os_code {
let err = bterr!(io::Error::from_raw_os_error(os_code)).context(message);
Err(err)
} else {
Err(bterr!(message))
}
}
}
}
}
}
/// A callback which returns the deserialized reply unchanged.
///
/// `T` is only a type marker and no value of type `T` is stored. For that reason `Clone`,
/// `Default` and `Debug` are implemented by hand, without requiring those traits of `T`.
pub struct Passthrough<T> {
    phantom: PhantomData<T>,
}

impl<T> Passthrough<T> {
    /// Creates a new passthrough callback.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            phantom: PhantomData,
        }
    }
}

impl<T> Default for Passthrough<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T> Clone for Passthrough<T> {
    fn clone(&self) -> Self {
        Self::new()
    }
}

// Written by hand because `#[derive(Debug)]` would needlessly require `T: Debug`.
impl<T> std::fmt::Debug for Passthrough<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Passthrough").finish()
    }
}
/// Hands the deserialized value straight back as an already-completed future.
impl<T: 'static + Send + DeserializeOwned> DeserCallback for Passthrough<T> {
    type Return = T;
    type Arg<'de> = T;
    type CallFut<'de> = Ready<T>;

    fn call<'de>(&'de mut self, value: T) -> Ready<T> {
        ready(value)
    }
}