use btlib::error::BoxInIoErr;
use btserde::{from_slice, read_from, write_to};
use bytes::{BufMut, BytesMut};
use futures::Future;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncReadExt};
use tokio_util::codec::Encoder;
use crate::Result;
/// Reads length-prefixed frames from an I/O object and passes each deserialized
/// message to a [`DeserCallback`].
pub(crate) struct CallbackFramed<I> {
/// The wrapped I/O object that frame bytes are read from.
io: I,
/// Bytes read from `io` which have not yet been consumed as part of a frame.
buffer: BytesMut,
}
impl<I> CallbackFramed<I> {
const INIT_CAPACITY: usize = 4096;
const FRAME_LEN_SZ: usize = std::mem::size_of::<u64>();
pub fn new(inner: I) -> Self {
Self {
io: inner,
buffer: BytesMut::with_capacity(Self::INIT_CAPACITY),
}
}
pub fn into_parts(self) -> (I, BytesMut) {
(self.io, self.buffer)
}
pub fn from_parts(io: I, mut buffer: BytesMut) -> Self {
if buffer.capacity() < Self::INIT_CAPACITY {
buffer.reserve(Self::INIT_CAPACITY - buffer.capacity());
}
Self { io, buffer }
}
async fn decode(mut slice: &[u8]) -> Result<DecodeStatus> {
let payload_len: u64 = match read_from(&mut slice) {
Ok(payload_len) => payload_len,
Err(err) => {
return match err {
btserde::Error::Eof => Ok(DecodeStatus::None),
btserde::Error::Io(ref io_err) => match io_err.kind() {
std::io::ErrorKind::UnexpectedEof => Ok(DecodeStatus::None),
_ => Err(err.into()),
},
_ => Err(err.into()),
}
}
};
let payload_len: usize = payload_len.try_into().box_err()?;
if slice.len() < payload_len {
return Ok(DecodeStatus::Reserve(payload_len - slice.len()));
}
Ok(DecodeStatus::Consume(Self::FRAME_LEN_SZ + payload_len))
}
}
// Evaluates a `Result` expression. On `Ok` the macro yields the contained value;
// on `Err` it returns `Some(Err(..))` from the enclosing function, converting the
// error with `Into`. Intended for functions which return `Option<Result<_>>`.
macro_rules! attempt {
($result:expr) => {
match $result {
Ok(value) => value,
Err(err) => return Some(Err(err.into())),
}
};
}
impl<S: AsyncRead + Unpin> CallbackFramed<S> {
    /// Reads the next frame, deserializes it and passes the message to `callback`.
    ///
    /// Returns:
    /// * `Some(Ok(ret))` with the value produced by awaiting the callback,
    /// * `Some(Err(_))` if reading from the I/O object, decoding the frame header
    ///   or deserializing the message fails,
    /// * `None` if the I/O object reports end of input (a read of zero bytes)
    ///   before a complete frame is available.
    ///
    /// Bytes already held in the buffer are examined before any read is issued:
    /// these may be left over from an earlier call which read more than one
    /// frame, or may have been supplied through `from_parts`. A frame which is
    /// already buffered is therefore delivered without touching the I/O object.
    pub(crate) async fn next<F: DeserCallback>(
        &mut self,
        mut callback: F,
    ) -> Option<Result<F::Return>> {
        loop {
            // Always decode the whole buffer, not just the bytes read in this
            // call, so leftovers from previous calls are taken into account.
            match attempt!(Self::decode(&self.buffer[..]).await) {
                DecodeStatus::Consume(consume) => {
                    let start = self.buffer.split_to(consume);
                    let arg: F::Arg<'_> = attempt!(from_slice(&start[Self::FRAME_LEN_SZ..]));
                    let returned = callback.call(arg).await;
                    return Some(Ok(returned));
                }
                // Make room for the rest of the payload before reading again.
                DecodeStatus::Reserve(count) => self.buffer.reserve(count),
                DecodeStatus::None => {}
            }
            if self.buffer.capacity() == self.buffer.len() {
                self.buffer.reserve(1);
            }
            let read_ct = attempt!(self.io.read_buf(&mut self.buffer).await);
            if 0 == read_ct {
                return None;
            }
        }
    }
}
/// Outcome of inspecting the bytes received so far for a frame.
enum DecodeStatus {
/// The length prefix could not be read yet; more input is needed.
None,
/// The length prefix was read, but this many payload bytes are still missing.
Reserve(usize),
/// A complete frame is available; this is its total size in bytes, including
/// the length prefix.
Consume(usize),
}
/// A callback which is handed a deserialized message and produces a future.
pub trait DeserCallback {
/// The message type passed to the callback. It is deserialized with lifetime
/// `'de`, so it may borrow from the data it was deserialized from.
type Arg<'de>: 'de + Deserialize<'de> + Send
where
Self: 'de;
/// The value produced by awaiting the future returned from `call`.
type Return;
/// The future returned from `call`.
type CallFut<'de>: 'de + Future<Output = Self::Return> + Send
where
Self: 'de;
/// Invokes the callback with the deserialized message `arg`.
fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de>;
}
/// Lets a mutable reference to a callback be used wherever a callback is
/// expected, by forwarding every associated type and `call` to `T`.
impl<'a, T: DeserCallback> DeserCallback for &'a mut T {
    type Arg<'de> = T::Arg<'de> where T: 'de, 'a: 'de;
    type Return = T::Return;
    type CallFut<'de> = T::CallFut<'de> where T: 'de, 'a: 'de;

    fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
        // Reborrow the inner `T` explicitly so the call is dispatched to `T`'s
        // implementation rather than back to this one.
        T::call(&mut **self, arg)
    }
}
/// Stateless encoder which turns serializable messages into frames.
#[derive(Debug)]
pub(crate) struct MsgEncoder;

impl MsgEncoder {
    /// Creates a new encoder. The type is a unit struct and carries no state.
    pub(crate) fn new() -> Self {
        MsgEncoder
    }
}
impl<T: Serialize> Encoder<T> for MsgEncoder {
    type Error = btlib::Error;

    /// Appends one frame to `dst`: a `u64` length prefix followed by the
    /// serialized form of `item`.
    ///
    /// The frame is appended after whatever `dst` already contains, so frames
    /// which have been encoded but not yet flushed are left intact, and no
    /// minimum capacity is required of `dst`.
    ///
    /// # Errors
    ///
    /// Returns an error if serializing `item` or the length prefix fails. In
    /// that case `dst` is truncated back to its original length so that no
    /// partial frame is left behind.
    fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<()> {
        const U64_LEN: usize = std::mem::size_of::<u64>();

        // Reserve space for the prefix; it is filled in once the payload's
        // length is known.
        let frame_start = dst.len();
        dst.put_slice(&[0u8; U64_LEN]);
        let payload_start = dst.len();

        let payload_result = {
            let mut writer = (&mut *dst).writer();
            write_to(&item, &mut writer)
        };
        if let Err(err) = payload_result {
            dst.truncate(frame_start);
            return Err(err.into());
        }

        // `usize` is never wider than 64 bits, so this cast cannot truncate.
        let payload_len = (dst.len() - payload_start) as u64;
        let prefix_result = {
            // NOTE(review): assumes `btserde` writes a `u64` as exactly
            // `U64_LEN` bytes, which the decoder also relies upon.
            let mut prefix = &mut dst[frame_start..payload_start];
            write_to(&payload_len, &mut prefix)
        };
        if let Err(err) = prefix_result {
            dst.truncate(frame_start);
            return Err(err.into());
        }
        Ok(())
    }
}
// Round-trip tests: messages are written with `MsgEncoder` and read back with
// `CallbackFramed`.
#[cfg(test)]
mod tests {
use super::*;
use futures::{future::Ready, SinkExt};
use serde::Serialize;
use std::{
future::ready,
io::{Cursor, Seek},
task::Poll,
};
use tokio_util::codec::FramedWrite;
/// Test message wrapping a borrowed byte slice.
#[derive(Serialize, Deserialize)]
struct Msg<'a>(&'a [u8]);
/// Writes a single message into an in-memory cursor, rewinds it and checks that
/// the callback receives the same payload.
#[tokio::test]
async fn read_single_message() {
// A macro rather than a constant so the literal can be used both as the
// message payload and in the comparison inside the callback.
macro_rules! test_data {
() => {
b"fulcrum"
};
}
/// Callback which reports whether the received payload equals the test data.
#[derive(Clone)]
struct TestCb;
impl DeserCallback for TestCb {
type Arg<'de> = Msg<'de> where Self: 'de;
type Return = bool;
type CallFut<'de> = Ready<Self::Return>;
fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
futures::future::ready(arg.0 == test_data!())
}
}
let mut write = FramedWrite::new(Cursor::new(Vec::<u8>::new()), MsgEncoder);
write.send(Msg(test_data!())).await.unwrap();
let mut io = write.into_inner();
// Move back to the start so the bytes just written can be read.
io.rewind().unwrap();
let mut read = CallbackFramed::new(io);
let matched = read.next(TestCb).await.unwrap().unwrap();
assert!(matched)
}
/// An `AsyncRead` over an in-memory buffer which hands out at most `window_sz`
/// bytes per call to `poll_read`.
struct WindowedCursor {
// Maximum number of bytes produced by one `poll_read` call.
window_sz: usize,
// Offset into `buf` of the next byte to hand out.
pos: usize,
// The data to be read.
buf: Vec<u8>,
}
impl WindowedCursor {
/// Creates a cursor positioned at the start of `data`.
fn new(data: Vec<u8>, window_sz: usize) -> Self {
WindowedCursor {
window_sz,
pos: 0,
buf: data,
}
}
}
impl AsyncRead for WindowedCursor {
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
// Copy out the next window, clamped to the end of the data, then advance.
// Once all the data has been handed out the window is empty, so nothing is
// written into `buf`.
let end = self.buf.len().min(self.pos + self.window_sz);
let window = &self.buf[self.pos..end];
buf.put_slice(window);
self.as_mut().pos += window.len();
Poll::Ready(Ok(()))
}
}
/// Callback which returns an owned copy of the received payload.
struct CopyCallback;
impl DeserCallback for CopyCallback {
type Arg<'de> = Msg<'de>;
type Return = Vec<u8>;
type CallFut<'de> = std::future::Ready<Self::Return>;
fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
ready(arg.0.to_owned())
}
}
/// Reads a message through a `WindowedCursor` whose window is half the payload
/// length, so the frame cannot arrive in a single read.
#[tokio::test]
async fn read_in_multiple_parts() {
const EXPECTED: &[u8] = b"We live in the most interesting of times.";
let mut write = FramedWrite::new(Cursor::new(Vec::<u8>::new()), MsgEncoder);
write.send(Msg(EXPECTED)).await.unwrap();
let data = write.into_inner().into_inner();
let io = WindowedCursor::new(data, EXPECTED.len() / 2);
let mut read = CallbackFramed::new(io);
let actual = read.next(CopyCallback).await.unwrap().unwrap();
assert_eq!(EXPECTED, &actual);
}
}