1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
// SPDX-License-Identifier: AGPL-3.0-or-later
//! Types for serializing and deserializing messages.

use btlib::error::BoxInIoErr;
use btserde::{from_slice, read_from, write_to};
use bytes::{BufMut, BytesMut};
use futures::Future;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncReadExt};
use tokio_util::codec::Encoder;

use crate::Result;

/// Reads length-prefixed frames from an I/O object and hands each deserialized message to a
/// [DeserCallback].
///
/// Every frame starts with a `u64` length prefix, which is followed by that many payload bytes.
pub(crate) struct CallbackFramed<I> {
    /// The I/O object frames are read from.
    io: I,
    /// Bytes which have been read from `io` but have not yet been consumed as part of a frame.
    buffer: BytesMut,
}

impl<I> CallbackFramed<I> {
    /// The minimum capacity, in bytes, of the internal buffer when an instance is created.
    const INIT_CAPACITY: usize = 4096;
    /// The number of bytes used to encode the length of each frame.
    const FRAME_LEN_SZ: usize = std::mem::size_of::<u64>();

    /// Creates a new instance which reads frames from `inner`, starting with an empty buffer.
    pub fn new(inner: I) -> Self {
        Self {
            io: inner,
            buffer: BytesMut::with_capacity(Self::INIT_CAPACITY),
        }
    }

    /// Consumes this instance and returns the I/O object along with the internal buffer, which
    /// may contain bytes that were read but not yet consumed as a frame.
    pub fn into_parts(self) -> (I, BytesMut) {
        (self.io, self.buffer)
    }

    /// Builds an instance from an I/O object and a buffer, such as the pair returned by
    /// [CallbackFramed::into_parts]. The buffer's contents are kept, and its capacity is grown
    /// to at least [CallbackFramed::INIT_CAPACITY] if it is smaller than that.
    pub fn from_parts(io: I, mut buffer: BytesMut) -> Self {
        if buffer.capacity() < Self::INIT_CAPACITY {
            // `reserve` counts additional bytes beyond the current *length*, so the shortfall has
            // to be measured from `len()` for the total capacity to reach `INIT_CAPACITY`.
            // Because `len() <= capacity() < INIT_CAPACITY` this subtraction cannot underflow.
            buffer.reserve(Self::INIT_CAPACITY - buffer.len());
        }
        Self { io, buffer }
    }

    /// Inspects `slice`, which must begin at the start of a frame, and reports what is needed to
    /// obtain a complete frame from it.
    ///
    /// Returns
    /// * [DecodeStatus::None] when there are too few bytes to decode the length prefix,
    /// * [DecodeStatus::Reserve] with the number of payload bytes still missing, or
    /// * [DecodeStatus::Consume] with the total size of the frame, including its length prefix.
    ///
    /// # Errors
    /// An error is returned if the length prefix can't be decoded for a reason other than
    /// reaching the end of `slice`, or if the decoded length does not fit in a `usize`.
    async fn decode(mut slice: &[u8]) -> Result<DecodeStatus> {
        // `read_from` advances `slice` past the bytes it reads, so afterwards `slice` refers to
        // the payload bytes which have been received so far.
        let payload_len: u64 = match read_from(&mut slice) {
            Ok(payload_len) => payload_len,
            Err(err) => {
                return match err {
                    // Running out of input just means the length prefix hasn't fully arrived.
                    btserde::Error::Eof => Ok(DecodeStatus::None),
                    btserde::Error::Io(ref io_err) => match io_err.kind() {
                        std::io::ErrorKind::UnexpectedEof => Ok(DecodeStatus::None),
                        _ => Err(err.into()),
                    },
                    _ => Err(err.into()),
                }
            }
        };
        let payload_len: usize = payload_len.try_into().box_err()?;
        // NOTE(review): the length is taken from the input without an upper bound, so the
        // `Reserve` count can be arbitrarily large. Consider enforcing a maximum frame size if
        // the input is untrusted.
        if slice.len() < payload_len {
            return Ok(DecodeStatus::Reserve(payload_len - slice.len()));
        }
        Ok(DecodeStatus::Consume(Self::FRAME_LEN_SZ + payload_len))
    }
}

/// Evaluates to the success value of the given `Result`. If the result is an error, the
/// enclosing function returns `Some(Err(_))` containing the converted error instead.
macro_rules! attempt {
    ($outcome:expr) => {
        match $outcome {
            Err(failure) => return Some(Err(failure.into())),
            Ok(success) => success,
        }
    };
}

impl<S: AsyncRead + Unpin> CallbackFramed<S> {
    /// Obtains the next frame, deserializes the message it contains and passes it to `callback`.
    ///
    /// Returns `Some(Ok(_))` holding the value produced by the callback, or `Some(Err(_))` if
    /// reading, decoding or deserializing fails. `None` is returned when the I/O object reports
    /// EOF before a complete frame is available; any partial frame is left in the buffer.
    pub(crate) async fn next<F: DeserCallback>(
        &mut self,
        mut callback: F,
    ) -> Option<Result<F::Return>> {
        loop {
            // A single read can deliver more than one frame, and `from_parts` accepts a buffer
            // which already holds data, so the buffered bytes are always examined before any
            // further I/O is attempted. The whole buffer is passed to `decode` because every
            // byte in it belongs to the current frame or to the frames which follow it.
            if !self.buffer.is_empty() {
                match attempt!(Self::decode(&self.buffer[..]).await) {
                    // The length prefix is incomplete, more bytes have to be read.
                    DecodeStatus::None => {}
                    // Make room for the rest of the payload before reading it.
                    DecodeStatus::Reserve(count) => self.buffer.reserve(count),
                    DecodeStatus::Consume(consume) => {
                        // Detach the frame so that the bytes after it stay buffered for the
                        // next call.
                        let start = self.buffer.split_to(consume);
                        let arg: F::Arg<'_> = attempt!(from_slice(&start[Self::FRAME_LEN_SZ..]));
                        let returned = callback.call(arg).await;
                        return Some(Ok(returned));
                    }
                }
            }
            if self.buffer.capacity() == self.buffer.len() {
                // If there is no space left in the buffer we reserve additional bytes to ensure
                // read_buf doesn't return 0 unless we're at EOF.
                self.buffer.reserve(1);
            }
            let read_ct = attempt!(self.io.read_buf(&mut self.buffer).await);
            if 0 == read_ct {
                return None;
            }
        }
    }
}

/// The outcome of trying to decode a frame from the bytes which have been buffered so far.
enum DecodeStatus {
    /// There are not enough bytes to decode the frame's length prefix.
    None,
    /// The length prefix was decoded, but this many payload bytes are still missing.
    Reserve(usize),
    /// A complete frame is available. The value is its total size in bytes, including the length
    /// prefix.
    Consume(usize),
}

/// A trait for types which can be called to asynchronously handle deserialization. This trait is
/// what enables zero-copy handling of messages which support borrowing data during deserialization.
pub trait DeserCallback {
    /// The type of message passed to [DeserCallback::call]. It is deserialized with the lifetime
    /// `'de` and so may borrow from the bytes it was deserialized from.
    type Arg<'de>: 'de + Deserialize<'de> + Send
    where
        Self: 'de;
    /// The type of value produced by handling a message.
    type Return;
    /// The type of future returned by [DeserCallback::call]. Its output is
    /// [DeserCallback::Return].
    type CallFut<'de>: 'de + Future<Output = Self::Return> + Send
    where
        Self: 'de;
    /// Handles the deserialized message `arg`, returning a future which resolves to the result.
    fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de>;
}

/// Forwards to the referenced callback, so a callback can be lent out by mutable reference
/// instead of being moved.
impl<'a, T: DeserCallback> DeserCallback for &'a mut T {
    type Arg<'de> = T::Arg<'de> where T: 'de, 'a: 'de;
    type Return = T::Return;
    type CallFut<'de> = T::CallFut<'de> where T: 'de, 'a: 'de;
    fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
        // Reborrow the inner reference and dispatch to `T`'s implementation explicitly.
        T::call(&mut **self, arg)
    }
}

/// A stateless encoder which serializes messages using [btserde].
#[derive(Debug)]
pub(crate) struct MsgEncoder;

impl MsgEncoder {
    /// Returns a new encoder. The type carries no state, so all instances are interchangeable.
    pub(crate) fn new() -> Self {
        MsgEncoder
    }
}

impl<T: Serialize> Encoder<T> for MsgEncoder {
    type Error = btlib::Error;

    /// Appends a frame containing `item` to `dst`. The frame consists of the length of the
    /// serialized item, written as a `u64`, followed by the serialized item itself.
    ///
    /// Bytes which are already in `dst` are left untouched, so several frames may be encoded
    /// into the same buffer before it is flushed.
    ///
    /// # Errors
    /// Returns an error if serializing `item` or its length fails.
    fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<()> {
        const U64_LEN: usize = std::mem::size_of::<u64>();
        // The length is only known after the payload has been serialized, so a slot for it is
        // set aside right after the bytes already in `dst` and the payload is written into the
        // buffer space which follows that slot. Splitting relative to `dst.len()` keeps any
        // earlier, unflushed frames intact, and reserving first ensures the split index never
        // exceeds the capacity, which would make `split_off` panic.
        dst.reserve(U64_LEN);
        let slot_end = dst.len() + U64_LEN;
        let payload = dst.split_off(slot_end);
        let mut writer = payload.writer();
        write_to(&item, &mut writer)?;
        let payload = writer.into_inner();
        let payload_len = payload.len() as u64;
        // `dst` now has exactly `U64_LEN` bytes of spare capacity, which the length fills.
        let mut writer = dst.writer();
        write_to(&payload_len, &mut writer)?;
        let dst = writer.into_inner();
        // Rejoin the two halves so that the length prefix is directly followed by the payload.
        dst.unsplit(payload);
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use futures::{future::Ready, SinkExt};
    use serde::Serialize;
    use std::{
        future::ready,
        io::{Cursor, Seek},
        task::Poll,
    };
    use tokio_util::codec::FramedWrite;

    /// A message which borrows its contents from the buffer it was deserialized from.
    #[derive(Serialize, Deserialize)]
    struct Msg<'a>(&'a [u8]);

    /// Writes a single message with [MsgEncoder] and checks that [CallbackFramed] passes the
    /// same bytes to the callback.
    #[tokio::test]
    async fn read_single_message() {
        macro_rules! test_data {
            () => {
                b"fulcrum"
            };
        }

        /// A callback which reports whether the message it received contains the test data.
        #[derive(Clone)]
        struct TestCb;

        impl DeserCallback for TestCb {
            type Arg<'de> = Msg<'de> where Self: 'de;
            type Return = bool;
            type CallFut<'de> = Ready<Self::Return>;

            fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
                futures::future::ready(arg.0 == test_data!())
            }
        }

        let mut write = FramedWrite::new(Cursor::new(Vec::<u8>::new()), MsgEncoder);
        write.send(Msg(test_data!())).await.unwrap();
        let mut io = write.into_inner();
        // Move back to the start so the bytes which were just written can be read.
        io.rewind().unwrap();
        let mut read = CallbackFramed::new(io);

        let matched = read.next(TestCb).await.unwrap().unwrap();

        assert!(matched)
    }

    /// An [AsyncRead] over an in-memory buffer which yields at most `window_sz` bytes per read.
    struct WindowedCursor {
        /// The maximum number of bytes produced by a single read.
        window_sz: usize,
        /// The index in `buf` of the next byte to produce.
        pos: usize,
        /// The data to be read.
        buf: Vec<u8>,
    }

    impl WindowedCursor {
        /// Creates a cursor positioned at the start of `data`.
        fn new(data: Vec<u8>, window_sz: usize) -> Self {
            WindowedCursor {
                window_sz,
                pos: 0,
                buf: data,
            }
        }
    }

    impl AsyncRead for WindowedCursor {
        /// Copies the next window of data into `buf`. The read is always immediately ready, and
        /// no bytes are copied once all of the data has been produced, which signals EOF.
        fn poll_read(
            mut self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
            buf: &mut tokio::io::ReadBuf<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            // `put_slice` panics when given more bytes than the destination has room for, so
            // the window is clamped to the space remaining in `buf`.
            let limit = self.window_sz.min(buf.remaining());
            let end = self.buf.len().min(self.pos + limit);
            let window = &self.buf[self.pos..end];
            buf.put_slice(window);
            self.as_mut().pos += window.len();
            Poll::Ready(Ok(()))
        }
    }

    /// A callback which returns an owned copy of the bytes in the message it receives.
    struct CopyCallback;

    impl DeserCallback for CopyCallback {
        type Arg<'de> = Msg<'de>;
        type Return = Vec<u8>;
        type CallFut<'de> = std::future::Ready<Self::Return>;
        fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
            ready(arg.0.to_owned())
        }
    }

    /// Checks that a message is reassembled correctly when it arrives over several reads.
    #[tokio::test]
    async fn read_in_multiple_parts() {
        const EXPECTED: &[u8] = b"We live in the most interesting of times.";
        let mut write = FramedWrite::new(Cursor::new(Vec::<u8>::new()), MsgEncoder);
        write.send(Msg(EXPECTED)).await.unwrap();
        let data = write.into_inner().into_inner();
        // This will force the CallbackFramed to read the message in multiple iterations.
        let io = WindowedCursor::new(data, EXPECTED.len() / 2);
        let mut read = CallbackFramed::new(io);

        let actual = read.next(CopyCallback).await.unwrap().unwrap();

        assert_eq!(EXPECTED, &actual);
    }
}