1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
// SPDX-License-Identifier: AGPL-3.0-or-later
//! Types used for receiving messages over the network. Chiefly, the [Receiver] type.

use std::{
    any::Any,
    future::Future,
    io,
    net::IpAddr,
    sync::{Arc, Mutex as StdMutex},
};

use btlib::{bterr, crypto::Creds, error::DisplayErr, BlockPath, Writecap};
use futures::{FutureExt, SinkExt};
use log::{debug, error};
use quinn::{Connection, ConnectionError, Endpoint, RecvStream, SendStream};
use serde::{Deserialize, Serialize};
use tokio::{
    select,
    sync::{broadcast, Mutex},
    task::JoinHandle,
};
use tokio_util::codec::FramedWrite;

use crate::{
    serialization::{CallbackFramed, MsgEncoder},
    tls::{server_config, CertResolver},
    BlockAddr, CallMsg, DeserCallback, Result, Transmitter,
};

/// Evaluates to the `Ok` value of `$result`. On `Err`, the error is passed to `$on_err` and then
/// `$control_flow` is evaluated; it is expected to diverge (e.g. `return` or `continue`).
macro_rules! handle_err {
    ($result:expr, $on_err:expr, $control_flow:expr) => {
        match $result {
            Err(failure) => {
                $on_err(failure);
                $control_flow;
            }
            Ok(value) => value,
        }
    };
}

/// Evaluates to the `Ok` value of `$result`. On `Err`, the error is handed to `$on_err` and the
/// enclosing function returns. When `$on_err` is omitted, the error is logged with `error!`.
macro_rules! unwrap_or_return {
    ($result:expr, $on_err:expr) => {
        handle_err!($result, $on_err, return)
    };
    ($result:expr) => {
        handle_err!($result, |err| error!("{err}"), return)
    };
}

/// Evaluates to the `Ok` value of `$result`. On `Err`, the error is handed to `$on_err` and the
/// enclosing loop moves on to its next iteration. When `$on_err` is omitted, the error is logged
/// with `error!`.
macro_rules! unwrap_or_continue {
    ($result:expr, $on_err:expr) => {
        handle_err!($result, $on_err, continue)
    };
    ($result:expr) => {
        handle_err!($result, |err| error!("{err}"), continue)
    };
}

/// Awaits its first argument, unless interrupted by its second argument, in which case the
/// enclosing loop is exited with `break`. The first argument must resolve to an [Option]; the
/// macro evaluates to the value inside `Some`.
///
/// The second argument needs to be cancel safe, but the first need not be if it is discarded once
/// the loop exits (because losing messages from the first argument doesn't matter in this case).
///
/// If the first argument resolves to `None`, its `select!` branch is disabled and the macro keeps
/// waiting on the second argument alone.
macro_rules! await_or_stop {
    ($future:expr, $stop_fut:expr) => {
        select! {
            Some(value) = $future => value,
            _ = $stop_fut => break,
        }
    };
}

/// Receives messages sent over the network by [Transmitter]s.
pub struct Receiver {
    /// Address (IP address plus bind path) at which messages are received.
    recv_addr: Arc<BlockAddr>,
    /// Broadcasts the stop signal to the server loop and to every connection task.
    stop_tx: broadcast::Sender<()>,
    /// QUIC endpoint used by the server loop and shared with [Transmitter]s made by this receiver.
    endpoint: Endpoint,
    /// Certificate resolver, also handed to [Transmitter]s made by this receiver.
    resolver: Arc<CertResolver>,
    /// Handle of the server loop task; taken (at most once) by [Receiver::complete].
    join_handle: StdMutex<Option<JoinHandle<()>>>,
}

impl Receiver {
    /// Returns a [Receiver] bound to the given [IpAddr] which receives messages at the bind path of
    /// the [Writecap] in the given credentials. The returned type can be used to make
    /// [Transmitter]s for any path.
    pub fn new<C: 'static + Creds + Send + Sync, F: 'static + MsgCallback>(
        ip_addr: IpAddr,
        creds: Arc<C>,
        callback: F,
    ) -> Result<Receiver> {
        let writecap = creds.writecap().ok_or(btlib::BlockError::MissingWritecap)?;
        let recv_addr = Arc::new(BlockAddr::new(ip_addr, Arc::new(writecap.bind_path())));
        log::info!("starting Receiver with address {}", recv_addr);
        let socket_addr = recv_addr.socket_addr()?;
        let resolver = Arc::new(CertResolver::new(creds)?);
        let endpoint = Endpoint::server(server_config(resolver.clone())?, socket_addr)?;
        let (stop_tx, stop_rx) = broadcast::channel(1);
        let join_handle = tokio::spawn(Self::server_loop(endpoint.clone(), callback, stop_rx));
        Ok(Self {
            recv_addr,
            stop_tx,
            endpoint,
            resolver,
            join_handle: StdMutex::new(Some(join_handle)),
        })
    }

    async fn server_loop<F: 'static + MsgCallback>(
        endpoint: Endpoint,
        callback: F,
        mut stop_rx: broadcast::Receiver<()>,
    ) {
        loop {
            let connecting = await_or_stop!(endpoint.accept(), stop_rx.recv());
            let connection = unwrap_or_continue!(connecting.await, |err| error!(
                "error accepting QUIC connection: {err}"
            ));
            tokio::spawn(Self::handle_connection(
                connection,
                callback.clone(),
                stop_rx.resubscribe(),
            ));
        }
    }

    async fn handle_connection<F: 'static + MsgCallback>(
        connection: Connection,
        callback: F,
        mut stop_rx: broadcast::Receiver<()>,
    ) {
        let client_path = unwrap_or_return!(
            Self::client_path(connection.peer_identity()),
            |err| error!("failed to get client path from peer identity: {err}")
        );
        loop {
            let result = await_or_stop!(connection.accept_bi().map(Some), stop_rx.recv());
            let (send_stream, recv_stream) = match result {
                Ok(pair) => pair,
                Err(err) => match err {
                    ConnectionError::ApplicationClosed(app) => {
                        debug!("connection closed: {app}");
                        return;
                    }
                    _ => {
                        error!("error accepting stream: {err}");
                        continue;
                    }
                },
            };
            let client_path = client_path.clone();
            let callback = callback.clone();
            tokio::task::spawn(Self::handle_message(
                client_path,
                send_stream,
                recv_stream,
                callback,
            ));
        }
    }

    async fn handle_message<F: 'static + MsgCallback>(
        client_path: Arc<BlockPath>,
        send_stream: SendStream,
        recv_stream: RecvStream,
        callback: F,
    ) {
        let framed_msg = Arc::new(Mutex::new(FramedWrite::new(send_stream, MsgEncoder::new())));
        let callback =
            MsgRecvdCallback::new(client_path.clone(), framed_msg.clone(), callback.clone());
        let mut msg_stream = CallbackFramed::new(recv_stream);
        let result = msg_stream
            .next(callback)
            .await
            .ok_or_else(|| bterr!("client closed stream before sending a message"));
        match unwrap_or_return!(result) {
            Err(err) => error!("msg_stream produced an error: {err}"),
            Ok(result) => {
                if let Err(err) = result {
                    error!("callback returned an error: {err}");
                }
            }
        }
    }

    /// Returns the path the client is bound to.
    fn client_path(peer_identity: Option<Box<dyn Any>>) -> Result<Arc<BlockPath>> {
        let peer_identity =
            peer_identity.ok_or_else(|| bterr!("connection did not contain a peer identity"))?;
        let client_certs = peer_identity
            .downcast::<Vec<rustls::Certificate>>()
            .map_err(|_| bterr!("failed to downcast peer_identity to certificate chain"))?;
        let first = client_certs
            .first()
            .ok_or_else(|| bterr!("no certificates were presented by the client"))?;
        let (writecap, ..) = Writecap::from_cert_chain(first, &client_certs[1..])?;
        Ok(Arc::new(writecap.bind_path()))
    }

    /// The address at which messages will be received.
    pub fn addr(&self) -> &Arc<BlockAddr> {
        &self.recv_addr
    }

    /// Creates a [Transmitter] which is connected to the given address.
    pub async fn transmitter(&self, addr: Arc<BlockAddr>) -> Result<Transmitter> {
        Transmitter::from_endpoint(self.endpoint.clone(), addr, self.resolver.clone()).await
    }

    /// Returns a future which completes when this [Receiver] has completed
    /// (which may be never).
    pub fn complete(&self) -> Result<JoinHandle<()>> {
        let mut guard = self.join_handle.lock().display_err()?;
        let handle = guard
            .take()
            .ok_or_else(|| bterr!("join handle has already been taken"))?;
        Ok(handle)
    }

    /// Sends a signal indicating that the task running the server loop should return.
    pub fn stop(&self) -> Result<()> {
        self.stop_tx.send(()).map(|_| ()).map_err(|err| err.into())
    }
}

impl Drop for Receiver {
    /// Signals the server loop and connection tasks to shut down.
    fn drop(&mut self) {
        // `send` only fails when no receiver is left, i.e. every task has already returned, so
        // the outcome is deliberately ignored.
        let _stop_outcome = self.stop_tx.send(());
    }
}

/// Trait for types which can be called to handle messages received over the network. The
/// server loop in [Receiver] uses a type that implements this trait to react to messages it
/// receives.
///
/// The callback is cloned for every accepted connection and for every message stream, so cloning
/// should be cheap.
pub trait MsgCallback: Clone + Send + Sync + Unpin {
    /// Type of the message payload this callback accepts, for a borrow of the callback that
    /// lasts `'de`.
    type Arg<'de>: CallMsg<'de>
    where
        Self: 'de;
    /// Future returned by [MsgCallback::call].
    type CallFut<'de>: Future<Output = Result<()>> + Send
    where
        Self: 'de;
    /// Handles a single received message. When the returned future resolves to an error, the
    /// receiver sends an error reply back to the client.
    fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de>;
}

/// A shared reference to a callback is itself a callback; calls are forwarded to the referent.
impl<T: MsgCallback> MsgCallback for &T {
    type Arg<'de> = T::Arg<'de> where Self: 'de;
    type CallFut<'de> = T::CallFut<'de> where Self: 'de;
    fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
        T::call(*self, arg)
    }
}

/// Adapter that receives deserialized [Envelope]s and forwards them to a [MsgCallback]. It attaches
/// the sender's path and a [Replier] for the stream the message arrived on.
struct MsgRecvdCallback<F> {
    /// Bind path of the client which sent the message.
    path: Arc<BlockPath>,
    /// Writes replies to the stream the message was received on.
    replier: Replier,
    /// The wrapped user callback.
    inner: F,
}

impl<F: MsgCallback> MsgRecvdCallback<F> {
    /// Wraps `inner` so that it handles messages from `path`. Replies are written to `framed_msg`.
    fn new(path: Arc<BlockPath>, framed_msg: Arc<Mutex<FramedMsg>>, inner: F) -> Self {
        let replier = Replier::new(framed_msg);
        Self {
            inner,
            replier,
            path,
        }
    }
}

impl<F: 'static + MsgCallback> DeserCallback for MsgRecvdCallback<F> {
    type Arg<'de> = Envelope<F::Arg<'de>> where Self: 'de;
    type Return = Result<()>;
    type CallFut<'de> = impl 'de + Future<Output = Self::Return> + Send where F: 'de, Self: 'de;

    /// Hands the message in `arg` to the inner callback. A [Replier] is only passed along for
    /// messages sent with `call`. If the inner callback fails, its error is reported to the client
    /// through the replier; when that error is an [io::Error], its raw OS error code is included.
    ///
    /// NOTE(review): the error reply is written even for `MsgKind::Send` messages, which are
    /// documented as expecting zero replies. Confirm the transmitter expects this.
    fn call<'de>(&'de mut self, arg: Envelope<F::Arg<'de>>) -> Self::CallFut<'de> {
        let replier = match arg.kind {
            MsgKind::Send => None,
            MsgKind::Call => Some(self.replier.clone()),
        };
        async move {
            let received = MsgReceived::new(self.path.clone(), arg, replier);
            let Err(err) = self.inner.call(received).await else {
                return Ok(());
            };
            let (message, os_code) = match err.downcast::<io::Error>() {
                Ok(io_err) => (io_err.to_string(), io_err.raw_os_error()),
                Err(other) => (other.to_string(), None),
            };
            self.replier.reply_err(message, os_code).await
        }
    }
}

/// Indicates whether a message was sent using `call` or `send`.
// Variant order is kept as-is: it may determine the serialized representation.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
enum MsgKind {
    /// This message expects exactly one reply.
    Call,
    /// This message expects exactly zero replies.
    Send,
}

/// A message payload tagged with the [MsgKind] saying whether the sender expects a reply.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
pub(crate) struct Envelope<T> {
    /// Whether this message was sent with `call` or `send`.
    kind: MsgKind,
    /// The message payload.
    msg: T,
}

impl<T> Envelope<T> {
    /// Wraps `msg` in an envelope for a message that expects no reply.
    pub(crate) fn send(msg: T) -> Self {
        Self {
            kind: MsgKind::Send,
            msg,
        }
    }

    /// Wraps `msg` in an envelope for a message that expects exactly one reply.
    pub(crate) fn call(msg: T) -> Self {
        Self {
            kind: MsgKind::Call,
            msg,
        }
    }

    /// Borrows the payload carried by this envelope.
    fn msg(&self) -> &T {
        let Self { msg, .. } = self;
        msg
    }
}

/// The value written back to the client in response to a message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
pub(crate) enum ReplyEnvelope<T> {
    /// A successful reply carrying the value passed to [Replier::reply].
    Ok(T),
    /// An error reply.
    Err {
        /// Human-readable description of the error.
        message: String,
        /// Raw OS error code, if one is available (see [io::Error::raw_os_error]).
        os_code: Option<i32>,
    },
}

impl<T> ReplyEnvelope<T> {
    /// Builds the error variant from a message and an optional OS error code.
    fn err(message: String, os_code: Option<i32>) -> Self {
        Self::Err { os_code, message }
    }
}

/// A message tagged with the block path that it was sent from.
pub struct MsgReceived<T> {
    /// Bind path of the client which sent the message.
    from: Arc<BlockPath>,
    /// The received envelope, containing the payload.
    msg: Envelope<T>,
    /// Present only for messages sent with `call`, until taken with
    /// [MsgReceived::take_replier] or [MsgReceived::into_parts].
    replier: Option<Replier>,
}

impl<T> MsgReceived<T> {
    fn new(from: Arc<BlockPath>, msg: Envelope<T>, replier: Option<Replier>) -> Self {
        Self { from, msg, replier }
    }

    pub fn into_parts(self) -> (Arc<BlockPath>, T, Option<Replier>) {
        (self.from, self.msg.msg, self.replier)
    }

    /// The path from which this message was received.
    pub fn from(&self) -> &Arc<BlockPath> {
        &self.from
    }

    /// Payload contained in this message.
    pub fn body(&self) -> &T {
        self.msg.msg()
    }

    /// Returns true if and only if this messages needs to be replied to.
    pub fn needs_reply(&self) -> bool {
        self.replier.is_some()
    }

    /// Takes the replier out of this struct and returns it, if it has not previously been returned.
    pub fn take_replier(&mut self) -> Option<Replier> {
        self.replier.take()
    }
}

/// Sink which encodes outgoing replies with [MsgEncoder] onto a QUIC [SendStream].
type FramedMsg = FramedWrite<SendStream, MsgEncoder>;
/// Shorthand for a reference-counted value guarded by a Tokio [Mutex].
type ArcMutex<T> = Arc<Mutex<T>>;

/// A type for sending a reply to a message. Replies are sent over their own streams, so no two
/// messages can interfere with one another.
#[derive(Clone)]
pub struct Replier {
    /// Send half of the stream the message arrived on, shared by every clone of this replier.
    stream: ArcMutex<FramedMsg>,
}

impl Replier {
    fn new(stream: ArcMutex<FramedMsg>) -> Self {
        Self { stream }
    }

    pub async fn reply<T: Serialize + Send>(&mut self, reply: T) -> Result<()> {
        let mut guard = self.stream.lock().await;
        guard.send(ReplyEnvelope::Ok(reply)).await?;
        Ok(())
    }

    pub async fn reply_err(&mut self, err: String, os_code: Option<i32>) -> Result<()> {
        let mut guard = self.stream.lock().await;
        guard.send(ReplyEnvelope::<()>::err(err, os_code)).await?;
        Ok(())
    }
}