use std::{
any::Any,
future::Future,
io,
net::IpAddr,
sync::{Arc, Mutex as StdMutex},
};
use btlib::{bterr, crypto::Creds, error::DisplayErr, BlockPath, Writecap};
use futures::{FutureExt, SinkExt};
use log::{debug, error};
use quinn::{Connection, ConnectionError, Endpoint, RecvStream, SendStream};
use serde::{Deserialize, Serialize};
use tokio::{
select,
sync::{broadcast, Mutex},
task::JoinHandle,
};
use tokio_util::codec::FramedWrite;
use crate::{
serialization::{CallbackFramed, MsgEncoder},
tls::{server_config, CertResolver},
BlockAddr, CallMsg, DeserCallback, Result, Transmitter,
};
// Evaluates to the `Ok` value of `$res`. On `Err`, the error is handed to `$report` and then
// `$flow` (typically `return` or `continue`) is executed, so the `Err` arm never yields a value.
macro_rules! handle_err {
    ($res:expr, $report:expr, $flow:expr) => {
        match $res {
            Err(e) => {
                $report(e);
                $flow;
            }
            Ok(val) => val,
        }
    };
}
// Evaluates to the `Ok` value of `$res`; on `Err`, reports the error with `$report` and
// returns from the enclosing function. Without `$report`, the error is logged with `error!`.
macro_rules! unwrap_or_return {
    ($res:expr, $report:expr) => {
        handle_err!($res, $report, return)
    };
    ($res:expr) => {
        unwrap_or_return!($res, |err| error!("{err}"))
    };
}
// Evaluates to the `Ok` value of `$res`; on `Err`, reports the error with `$report` and
// skips to the next iteration of the enclosing loop. Without `$report`, the error is logged
// with `error!`.
macro_rules! unwrap_or_continue {
    ($res:expr, $report:expr) => {
        handle_err!($res, $report, continue)
    };
    ($res:expr) => {
        unwrap_or_continue!($res, |err| error!("{err}"))
    };
}
// Races `$fut` against `$stop`. If `$stop` completes first (with any output), this `break`s
// out of the enclosing loop. Otherwise it evaluates to the value inside `$fut`'s `Some(..)`.
// A `None` from `$fut` disables that branch, leaving only `$stop` to wait on. `select!` is
// unbiased by default, so the order of the branches below has no effect.
macro_rules! await_or_stop {
    ($fut:expr, $stop:expr) => {
        select! {
            _ = $stop => break,
            Some(value) = $fut => value,
        }
    };
}
/// A QUIC server that accepts connections and dispatches every received message to a
/// user-supplied [`MsgCallback`].
pub struct Receiver {
    /// The address this receiver listens on, including its bind path.
    recv_addr: Arc<BlockAddr>,
    /// Broadcasts the stop signal to the server loop and every connection task.
    stop_tx: broadcast::Sender<()>,
    /// The QUIC endpoint; it is also reused to create outgoing transmitters.
    endpoint: Endpoint,
    /// Certificate resolver for this receiver; shared with transmitters it creates.
    resolver: Arc<CertResolver>,
    /// Handle of the server loop task, present until taken by `complete`.
    join_handle: StdMutex<Option<JoinHandle<()>>>,
}
impl Receiver {
    /// Starts a QUIC server on `ip_addr` that dispatches every received message to `callback`.
    ///
    /// The receiver's address combines `ip_addr` with the bind path of the writecap held by
    /// `creds`. The endpoint's TLS configuration is built from a [`CertResolver`] over `creds`.
    /// The accept loop runs on a spawned Tokio task; its handle can be taken with
    /// [`Receiver::complete`].
    ///
    /// # Errors
    /// Returns an error if `creds` has no writecap, if the socket address cannot be derived,
    /// or if the certificate resolver, server configuration, or endpoint cannot be created.
    ///
    /// # Panics
    /// Panics if called outside a Tokio runtime, because it uses `tokio::spawn`.
    pub fn new<C: 'static + Creds + Send + Sync, F: 'static + MsgCallback>(
        ip_addr: IpAddr,
        creds: Arc<C>,
        callback: F,
    ) -> Result<Receiver> {
        let writecap = creds.writecap().ok_or(btlib::BlockError::MissingWritecap)?;
        let recv_addr = Arc::new(BlockAddr::new(ip_addr, Arc::new(writecap.bind_path())));
        log::info!("starting Receiver with address {}", recv_addr);
        let socket_addr = recv_addr.socket_addr()?;
        let resolver = Arc::new(CertResolver::new(creds)?);
        let endpoint = Endpoint::server(server_config(resolver.clone())?, socket_addr)?;
        let (stop_tx, stop_rx) = broadcast::channel(1);
        let join_handle = tokio::spawn(Self::server_loop(endpoint.clone(), callback, stop_rx));
        Ok(Self {
            recv_addr,
            stop_tx,
            endpoint,
            resolver,
            join_handle: StdMutex::new(Some(join_handle)),
        })
    }

    /// Accepts incoming QUIC connections until a stop signal arrives on `stop_rx`.
    ///
    /// Each established connection is served on its own task with a clone of `callback` and
    /// its own subscription to the stop signal. A failed handshake is logged and skipped.
    async fn server_loop<F: 'static + MsgCallback>(
        endpoint: Endpoint,
        callback: F,
        mut stop_rx: broadcast::Receiver<()>,
    ) {
        loop {
            let connecting = await_or_stop!(endpoint.accept(), stop_rx.recv());
            let connection = unwrap_or_continue!(connecting.await, |err| error!(
                "error accepting QUIC connection: {err}"
            ));
            // NOTE(review): `resubscribe` starts at the channel's tail, so a stop signal sent
            // before this point is not observed by the new connection task.
            tokio::spawn(Self::handle_connection(
                connection,
                callback.clone(),
                stop_rx.resubscribe(),
            ));
        }
    }

    /// Serves one connection: accepts bidirectional streams and handles each on its own task.
    ///
    /// Returns when a stop signal arrives, when the peer identity cannot be turned into a
    /// client path, or when the connection is lost.
    async fn handle_connection<F: 'static + MsgCallback>(
        connection: Connection,
        callback: F,
        mut stop_rx: broadcast::Receiver<()>,
    ) {
        let client_path = unwrap_or_return!(
            Self::client_path(connection.peer_identity()),
            |err| error!("failed to get client path from peer identity: {err}")
        );
        loop {
            let result = await_or_stop!(connection.accept_bi().map(Some), stop_rx.recv());
            let (send_stream, recv_stream) = match result {
                Ok(pair) => pair,
                // Every `ConnectionError` means the connection is gone for good: quinn returns
                // the same stored error from every later `accept_bi` call. Retrying would spin
                // forever logging the same error, so stop serving this connection instead.
                Err(ConnectionError::ApplicationClosed(app)) => {
                    debug!("connection closed: {app}");
                    return;
                }
                Err(err) => {
                    error!("error accepting stream: {err}");
                    return;
                }
            };
            tokio::task::spawn(Self::handle_message(
                client_path.clone(),
                send_stream,
                recv_stream,
                callback.clone(),
            ));
        }
    }

    /// Reads one message from `recv_stream` and passes it to `callback`.
    ///
    /// Replies, including error replies, are written to `send_stream`. Failures are logged
    /// rather than propagated because this runs as a detached task.
    async fn handle_message<F: 'static + MsgCallback>(
        client_path: Arc<BlockPath>,
        send_stream: SendStream,
        recv_stream: RecvStream,
        callback: F,
    ) {
        let framed_msg = Arc::new(Mutex::new(FramedWrite::new(send_stream, MsgEncoder::new())));
        // All three arguments are owned and unused afterwards, so move them instead of cloning.
        let callback = MsgRecvdCallback::new(client_path, framed_msg, callback);
        let mut msg_stream = CallbackFramed::new(recv_stream);
        let result = msg_stream
            .next(callback)
            .await
            .ok_or_else(|| bterr!("client closed stream before sending a message"));
        match unwrap_or_return!(result) {
            Err(err) => error!("msg_stream produced an error: {err}"),
            Ok(result) => {
                if let Err(err) = result {
                    error!("callback returned an error: {err}");
                }
            }
        }
    }

    /// Extracts the bind path of the client's writecap from a connection's peer identity.
    ///
    /// The peer identity must be a `Vec<rustls::Certificate>`. Its first certificate, followed
    /// by the remaining certificates, is passed to `Writecap::from_cert_chain`.
    ///
    /// # Errors
    /// Returns an error if there is no peer identity, if it is not a certificate chain, if the
    /// chain is empty, or if no writecap can be recovered from it.
    fn client_path(peer_identity: Option<Box<dyn Any>>) -> Result<Arc<BlockPath>> {
        let peer_identity =
            peer_identity.ok_or_else(|| bterr!("connection did not contain a peer identity"))?;
        let client_certs = peer_identity
            .downcast::<Vec<rustls::Certificate>>()
            .map_err(|_| bterr!("failed to downcast peer_identity to certificate chain"))?;
        let (first, rest) = client_certs
            .split_first()
            .ok_or_else(|| bterr!("no certificates were presented by the client"))?;
        let (writecap, ..) = Writecap::from_cert_chain(first, rest)?;
        Ok(Arc::new(writecap.bind_path()))
    }

    /// Returns the address this receiver listens on.
    pub fn addr(&self) -> &Arc<BlockAddr> {
        &self.recv_addr
    }

    /// Creates a [`Transmitter`] to `addr` that shares this receiver's endpoint and
    /// certificate resolver.
    ///
    /// # Errors
    /// Returns any error produced by `Transmitter::from_endpoint`.
    pub async fn transmitter(&self, addr: Arc<BlockAddr>) -> Result<Transmitter> {
        Transmitter::from_endpoint(self.endpoint.clone(), addr, self.resolver.clone()).await
    }

    /// Takes the join handle of the server loop task so the caller can await its completion.
    ///
    /// # Errors
    /// Returns an error if the handle has already been taken or the mutex guarding it is
    /// poisoned.
    pub fn complete(&self) -> Result<JoinHandle<()>> {
        let mut guard = self.join_handle.lock().display_err()?;
        let handle = guard
            .take()
            .ok_or_else(|| bterr!("join handle has already been taken"))?;
        Ok(handle)
    }

    /// Broadcasts a stop signal to the server loop and all connection tasks.
    ///
    /// # Errors
    /// Returns an error if no task is currently subscribed to the stop signal.
    pub fn stop(&self) -> Result<()> {
        self.stop_tx.send(()).map(|_| ()).map_err(|err| err.into())
    }
}
/// Dropping a [`Receiver`] tells the server loop and connection tasks to shut down.
impl Drop for Receiver {
    fn drop(&mut self) {
        // Best effort: the only failure is that no task is listening anymore, which is fine
        // during teardown.
        let _outcome = self.stop_tx.send(());
    }
}
pub trait MsgCallback: Clone + Send + Sync + Unpin {
type Arg<'de>: CallMsg<'de>
where
Self: 'de;
type CallFut<'de>: Future<Output = Result<()>> + Send
where
Self: 'de;
fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de>;
}
/// Lets a shared reference to a callback be used as a callback by forwarding every call to
/// the referenced value.
impl<T: MsgCallback> MsgCallback for &T {
    type Arg<'de> = T::Arg<'de> where Self: 'de;
    type CallFut<'de> = T::CallFut<'de> where Self: 'de;

    fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
        let target: &'de T = self;
        T::call(target, arg)
    }
}
/// Adapter that lets a user [`MsgCallback`] be driven by the message deserializer.
///
/// It attaches the sender's path to each message and owns the [`Replier`] for the stream.
struct MsgRecvdCallback<F> {
    /// Bind path of the peer that sent the message.
    path: Arc<BlockPath>,
    /// Writes replies, including error replies, back onto the stream.
    replier: Replier,
    /// The user-supplied callback that handles the message.
    inner: F,
}
impl<F: MsgCallback> MsgRecvdCallback<F> {
    /// Wraps `inner` so messages from `path` can be dispatched to it, with replies written
    /// through `framed_msg`.
    fn new(path: Arc<BlockPath>, framed_msg: Arc<Mutex<FramedMsg>>, inner: F) -> Self {
        let replier = Replier::new(framed_msg);
        Self { path, replier, inner }
    }
}
/// Dispatches each deserialized [`Envelope`] to the wrapped callback.
///
/// If the callback fails, an error reply is written back to the peer. When the error is an
/// `io::Error`, its raw OS code is included. The result of writing that reply becomes the
/// result of the call, so a callback error is not returned when the reply is sent
/// successfully.
impl<F: 'static + MsgCallback> DeserCallback for MsgRecvdCallback<F> {
    type Arg<'de> = Envelope<F::Arg<'de>> where Self: 'de;
    type Return = Result<()>;
    type CallFut<'de> = impl 'de + Future<Output = Self::Return> + Send where F: 'de, Self: 'de;

    fn call<'de>(&'de mut self, arg: Envelope<F::Arg<'de>>) -> Self::CallFut<'de> {
        // Only `Call` messages give the user callback a way to answer.
        let replier = matches!(arg.kind, MsgKind::Call).then(|| self.replier.clone());
        async move {
            let err = match self
                .inner
                .call(MsgReceived::new(self.path.clone(), arg, replier))
                .await
            {
                Ok(()) => return Ok(()),
                Err(err) => err,
            };
            let (message, os_code) = match err.downcast::<io::Error>() {
                Ok(io_err) => (io_err.to_string(), io_err.raw_os_error()),
                Err(other) => (other.to_string(), None),
            };
            self.replier.reply_err(message, os_code).await
        }
    }
}
/// Distinguishes messages that expect a reply from one-way messages.
///
/// NOTE: serde's derived impls assign variant indices in declaration order, so reordering
/// these variants can change the encoded form for index-based formats.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
enum MsgKind {
    /// The sender expects a reply; the callback is given a `Replier`.
    Call,
    /// One-way message; the callback is given no `Replier`.
    Send,
}
/// Wrapper pairing a message with its [`MsgKind`]; this is the form in which incoming
/// messages are deserialized.
///
/// NOTE: serde's derived impls (de)serialize fields in declaration order; keep the order
/// stable so encoded envelopes stay compatible with peers.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
pub(crate) struct Envelope<T> {
    /// Whether the sender expects a reply.
    kind: MsgKind,
    /// The wrapped message.
    msg: T,
}
impl<T> Envelope<T> {
pub(crate) fn send(msg: T) -> Self {
Self {
msg,
kind: MsgKind::Send,
}
}
pub(crate) fn call(msg: T) -> Self {
Self {
msg,
kind: MsgKind::Call,
}
}
fn msg(&self) -> &T {
&self.msg
}
}
/// The reply written back to the sender of a message.
///
/// NOTE: serde's derived impls depend on variant and field declaration order; keep it stable
/// so encoded replies stay compatible with peers.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
pub(crate) enum ReplyEnvelope<T> {
    /// Handling succeeded and produced this value.
    Ok(T),
    /// Handling failed.
    Err {
        /// Display text of the error.
        message: String,
        /// Raw OS error code, set when the error was an `io::Error` that carried one.
        os_code: Option<i32>,
    },
}
impl<T> ReplyEnvelope<T> {
fn err(message: String, os_code: Option<i32>) -> Self {
Self::Err { message, os_code }
}
}
/// A message delivered to a [`MsgCallback`], together with its sender and, for calls, a
/// [`Replier`] for answering it.
pub struct MsgReceived<T> {
    /// Bind path of the sender.
    from: Arc<BlockPath>,
    /// The received envelope.
    msg: Envelope<T>,
    /// Present only when the message was a call and therefore expects a reply.
    replier: Option<Replier>,
}
impl<T> MsgReceived<T> {
fn new(from: Arc<BlockPath>, msg: Envelope<T>, replier: Option<Replier>) -> Self {
Self { from, msg, replier }
}
pub fn into_parts(self) -> (Arc<BlockPath>, T, Option<Replier>) {
(self.from, self.msg.msg, self.replier)
}
pub fn from(&self) -> &Arc<BlockPath> {
&self.from
}
pub fn body(&self) -> &T {
self.msg.msg()
}
pub fn needs_reply(&self) -> bool {
self.replier.is_some()
}
pub fn take_replier(&mut self) -> Option<Replier> {
self.replier.take()
}
}
/// Framed writer that encodes replies onto a stream's send half.
type FramedMsg = FramedWrite<SendStream, MsgEncoder>;
/// Shorthand for a value shared behind a Tokio (async) mutex.
type ArcMutex<T> = Arc<Mutex<T>>;
/// Handle for sending a reply back to the peer over the stream a message arrived on.
///
/// Clones share the same underlying writer; writes are serialized through an async mutex.
#[derive(Clone)]
pub struct Replier {
    stream: ArcMutex<FramedMsg>,
}
impl Replier {
fn new(stream: ArcMutex<FramedMsg>) -> Self {
Self { stream }
}
pub async fn reply<T: Serialize + Send>(&mut self, reply: T) -> Result<()> {
let mut guard = self.stream.lock().await;
guard.send(ReplyEnvelope::Ok(reply)).await?;
Ok(())
}
pub async fn reply_err(&mut self, err: String, os_code: Option<i32>) -> Result<()> {
let mut guard = self.stream.lock().await;
guard.send(ReplyEnvelope::<()>::err(err, os_code)).await?;
Ok(())
}
}