1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
// SPDX-License-Identifier: AGPL-3.0-or-later
//! Contains the [Transmitter] type.

use std::{
    future::{ready, Future, Ready},
    io,
    marker::PhantomData,
    net::{IpAddr, Ipv6Addr, SocketAddr},
    sync::Arc,
};

use btlib::{bterr, crypto::Creds};
use bytes::BytesMut;
use futures::SinkExt;
use quinn::{Connection, Endpoint, RecvStream, SendStream};
use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::Mutex;
use tokio_util::codec::{Framed, FramedParts};

use crate::{
    receiver::{Envelope, ReplyEnvelope},
    serialization::{CallbackFramed, MsgEncoder},
    tls::{client_config, CertResolver},
    BlockAddr, CallMsg, DeserCallback, Result, SendMsg,
};

/// A type which can be used to transmit messages over the network to a [crate::Receiver].
///
/// Every message is sent on a freshly opened bidirectional stream of a single [Connection].
pub struct Transmitter {
    /// The address of the receiver that this instance transmits to.
    addr: Arc<BlockAddr>,
    /// The connection on which a new stream is opened for each message.
    connection: Connection,
    /// Framing state reused between sends. It is `None` before the first send, and also after a
    /// send is cancelled while the parts are checked out.
    send_parts: Mutex<Option<FramedParts<SendStream, MsgEncoder>>>,
    /// Buffer reused for reading replies. It is taken out of the mutex while a call is reading.
    recv_buf: Mutex<Option<BytesMut>>,
}

/// Unwraps `$result`. On error, it first stores `$parts` back into the slot behind `$guard`,
/// then returns the error (converted with `Into`) from the enclosing function.
///
/// This lets code that has taken reusable state out of a lock put it back before propagating a
/// failure, so the state is not lost.
macro_rules! cleanup_on_err {
    ($result:expr, $guard:ident, $parts:ident) => {
        match $result {
            Err(error) => {
                *$guard = Some($parts);
                return Err(error.into());
            }
            Ok(unwrapped) => unwrapped,
        }
    };
}

impl Transmitter {
    pub async fn new<C: 'static + Creds + Send + Sync>(
        addr: Arc<BlockAddr>,
        creds: Arc<C>,
    ) -> Result<Transmitter> {
        let resolver = Arc::new(CertResolver::new(creds)?);
        let endpoint = Endpoint::client(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0))?;
        Transmitter::from_endpoint(endpoint, addr, resolver).await
    }

    pub(crate) async fn from_endpoint(
        endpoint: Endpoint,
        addr: Arc<BlockAddr>,
        resolver: Arc<CertResolver>,
    ) -> Result<Self> {
        let socket_addr = addr.socket_addr()?;
        let connecting = endpoint.connect_with(
            client_config(addr.path().clone(), resolver)?,
            socket_addr,
            // The ServerCertVerifier ensures we connect to the correct path.
            "UNIMPORTANT",
        )?;
        let connection = connecting.await?;
        let send_parts = Mutex::new(None);
        let recv_buf = Mutex::new(Some(BytesMut::new()));
        Ok(Self {
            addr,
            connection,
            send_parts,
            recv_buf,
        })
    }

    async fn transmit<T: Serialize>(&self, envelope: Envelope<T>) -> Result<RecvStream> {
        let mut guard = self.send_parts.lock().await;
        let (send_stream, recv_stream) = self.connection.open_bi().await?;
        let parts = match guard.take() {
            Some(mut parts) => {
                parts.io = send_stream;
                parts
            }
            None => FramedParts::new::<Envelope<T>>(send_stream, MsgEncoder::new()),
        };
        let mut sink = Framed::from_parts(parts);
        let result = sink.send(envelope).await;
        let parts = sink.into_parts();
        cleanup_on_err!(result, guard, parts);
        *guard = Some(parts);
        Ok(recv_stream)
    }

    /// Returns the address that this instance is transmitting to.
    pub fn addr(&self) -> &Arc<BlockAddr> {
        &self.addr
    }

    /// Transmit a message to the connected [crate::Receiver] without waiting for a reply.
    pub async fn send<'ser, 'de, T>(&self, msg: T) -> Result<()>
    where
        T: 'ser + SendMsg<'de>,
    {
        self.transmit(Envelope::send(msg)).await?;
        Ok(())
    }

    /// Transmit a message to the connected [crate::Receiver], waits for a reply, then calls the given
    /// [DeserCallback] with the deserialized reply.
    ///
    /// ## WARNING
    /// The callback must be such that `F::Arg<'a> = T::Reply<'a>` for any `'a`. If this
    /// is violated, then a deserilization error will occur at runtime.
    ///
    /// ## TODO
    /// This issue needs to be fixed. Due to the fact that
    /// `F::Arg` is a Generic Associated Type (GAT) I have been unable to express this constraint in
    /// the where clause of this method. I'm not sure if the errors I've encountered are due to a
    /// lack of understanding on my part or due to the current limitations of the borrow checker in
    /// its handling of GATs.
    pub async fn call<'ser, 'de, T, F>(&self, msg: T, callback: F) -> Result<F::Return>
    where
        T: 'ser + CallMsg<'de>,
        F: 'static + Send + DeserCallback,
    {
        let recv_stream = self.transmit(Envelope::call(msg)).await?;
        let mut guard = self.recv_buf.lock().await;
        let buffer = guard.take().unwrap();
        let mut callback_framed = CallbackFramed::from_parts(recv_stream, buffer);
        let result = callback_framed
            .next(ReplyCallback::new(callback))
            .await
            .ok_or_else(|| bterr!("server hung up before sending reply"));
        let (_, buffer) = callback_framed.into_parts();
        let output = cleanup_on_err!(result, guard, buffer);
        *guard = Some(buffer);
        output?
    }

    /// Transmits a message to the connected [crate::Receiver], waits for a reply, then passes back the
    /// the reply to the caller. This only works for messages whose reply doesn't borrow any data,
    /// otherwise the `call` method must be used.
    pub async fn call_through<'ser, T>(&self, msg: T) -> Result<T::Reply<'static>>
    where
        // TODO: CallMsg must take a static lifetime until this issue is resolved:
        //     https://github.com/rust-lang/rust/issues/103532
        T: 'ser + CallMsg<'static>,
        T::Reply<'static>: 'static + Send + Sync + DeserializeOwned,
    {
        self.call(msg, Passthrough::new()).await
    }
}

/// Wraps a [DeserCallback] so that it can be driven by a [ReplyEnvelope].
///
/// A successful reply payload is forwarded to the wrapped callback. An error reply is turned into
/// an `Err` without calling it.
struct ReplyCallback<F> {
    /// The callback that receives the reply payload.
    inner: F,
}

impl<F> ReplyCallback<F> {
    /// Wraps `inner` so it can be used to handle a reply envelope.
    fn new(inner: F) -> Self {
        ReplyCallback { inner }
    }
}

impl<F: 'static + Send + DeserCallback> DeserCallback for ReplyCallback<F> {
    type Arg<'de> = ReplyEnvelope<F::Arg<'de>>;
    type Return = Result<F::Return>;
    type CallFut<'de> = impl 'de + Future<Output = Self::Return> + Send;
    fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
        async move {
            match arg {
                ReplyEnvelope::Ok(msg) => Ok(self.inner.call(msg).await),
                ReplyEnvelope::Err { message, os_code } => {
                    if let Some(os_code) = os_code {
                        let err = bterr!(io::Error::from_raw_os_error(os_code)).context(message);
                        Err(err)
                    } else {
                        Err(bterr!(message))
                    }
                }
            }
        }
    }
}

/// A [DeserCallback] which hands the deserialized reply back to the caller unchanged.
///
/// It stores no data; the type parameter only records the reply type `T`.
pub struct Passthrough<T> {
    phantom: PhantomData<T>,
}

impl<T> Passthrough<T> {
    /// Creates a passthrough callback for replies of type `T`.
    pub fn new() -> Self {
        Self::default()
    }
}

// `Default` and `Clone` are written by hand because deriving them would add `T: Default` and
// `T: Clone` bounds, which are not needed for a type that holds no `T`.
impl<T> Default for Passthrough<T> {
    fn default() -> Self {
        Passthrough {
            phantom: PhantomData,
        }
    }
}

impl<T> Clone for Passthrough<T> {
    fn clone(&self) -> Self {
        Self::default()
    }
}

/// Returns the deserialized reply exactly as it was received.
impl<T: 'static + Send + DeserializeOwned> DeserCallback for Passthrough<T> {
    /// The reply is deserialized as an owned `T`, so it borrows nothing from its input.
    type Arg<'de> = T;
    /// The callback yields the very value it was given.
    type Return = T;
    /// No asynchronous work is needed, so the returned future is immediately ready.
    type CallFut<'de> = Ready<T>;

    fn call<'de>(&'de mut self, reply: Self::Arg<'de>) -> Self::CallFut<'de> {
        ready(reply)
    }
}