use crate::{
bterr,
crypto::{Error, HashKind, Result, SymKey, SymParams},
Decompose, Sectored, Split, TryCompose, EMPTY_SLICE,
};
use openssl::symm::{Crypter, Mode};
use positioned_io::Size;
use std::io::{self, Read, Seek, SeekFrom, Write};
pub use private::SecretStream;
mod private {
use super::*;
/// Capacity of `SecretStream::iv_buf`: the output length of SHA2-512, the longer of the two
/// digests `sym_params!` may use when deriving a per-sector IV.
const IV_BUF_LEN: usize = HashKind::Sha2_512.len();

/// A sector-oriented stream that encrypts every plaintext sector written to it and decrypts
/// every sector read from it, keeping only ciphertext in the wrapped stream `T`.
pub struct SecretStream<T> {
    /// The wrapped stream that stores ciphertext.
    inner: T,
    /// Key used to encrypt and decrypt each sector.
    key: SymKey,
    /// Sector size seen by users of this stream (plaintext bytes per sector).
    sect_sz: usize,
    /// Sector size of `inner` (ciphertext bytes per sector).
    inner_sect_sz: usize,
    /// Scratch space for ciphertext.
    ct_buf: Vec<u8>,
    /// Scratch space for plaintext.
    pt_buf: Vec<u8>,
    /// Scratch space for the digest from which each sector's IV is taken.
    iv_buf: [u8; IV_BUF_LEN],
}
impl<T> SecretStream<T> {
    /// Returns a shared reference to the wrapped stream.
    pub fn get_ref(&self) -> &T {
        &self.inner
    }

    /// Returns a mutable reference to the wrapped stream.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.inner
    }

    /// Re-expresses `offset`, measured in sectors of `from` bytes, in sectors of `to` bytes.
    /// The sector index is kept and the position within the sector is carried over as-is.
    fn remap_offset(offset: u64, from: usize, to: usize) -> u64 {
        let (from, to) = (from as u64, to as u64);
        let within_sector = offset % from;
        let sector_index = offset / from;
        within_sector + sector_index * to
    }

    /// Translates a plaintext (outer) offset into the matching ciphertext (inner) offset.
    fn inner_offset(&self, outer_offset: u64) -> u64 {
        Self::remap_offset(outer_offset, self.sect_sz, self.inner_sect_sz)
    }

    /// Translates a ciphertext (inner) offset into the matching plaintext (outer) offset.
    fn outer_offset(&self, inner_offset: u64) -> u64 {
        Self::remap_offset(inner_offset, self.inner_sect_sz, self.sect_sz)
    }
}
// Computes the `SymParams` for the sector at the inner stream's current position.
//
// The per-sector IV is a digest over the little-endian inner offset followed by the key's own
// IV, truncated to the key's IV length, so sectors at different positions get different IVs.
// Evaluates to `io::Result<SymParams>` whose `key` borrows `$self.key` and whose `iv` borrows
// `$self.iv_buf`.
// NOTE(review): presumably a macro (not a method) so those borrows stay field-disjoint and the
// caller can still mutably borrow `$self.inner`/`$self.ct_buf`/`$self.pt_buf`; confirm before
// refactoring into a function.
// The `?` operators inside the expansion return early from the *calling* function.
macro_rules! sym_params {
($self:expr) => {{
// The IV is tied to where the sector lives in the inner stream.
let inner_offset = $self.inner.stream_position()?;
let SymParams { cipher, key, iv } = $self.key.params();
let iv = iv.ok_or_else(|| bterr!("no IV was present in block key"))?;
// Pick the shorter digest whenever its output is long enough to cover the IV.
let kind = if iv.len() <= HashKind::Sha2_256.len() {
HashKind::Sha2_256
} else {
HashKind::Sha2_512
};
// Checked only in debug builds; in release an IV longer than the digest would at the
// latest panic at the `iv_buf` slice below.
debug_assert!(iv.len() <= kind.len());
kind.digest(
&mut $self.iv_buf,
[inner_offset.to_le_bytes().as_slice(), iv].into_iter(),
)?;
// Shadow the key's IV with the derived one, keeping the original IV length.
let iv = &$self.iv_buf[..iv.len()];
Ok::<_, io::Error>(SymParams {
cipher,
key,
iv: Some(iv),
})
}};
}
impl SecretStream<()> {
    /// Creates a stream that holds only `key`: no inner stream, zero sector sizes and empty
    /// buffers. It becomes usable once composed over an inner stream with `try_compose`.
    pub fn new(key: SymKey) -> SecretStream<()> {
        let iv_buf = [0u8; IV_BUF_LEN];
        Self {
            key,
            inner: (),
            sect_sz: 0,
            inner_sect_sz: 0,
            ct_buf: Vec::new(),
            pt_buf: Vec::new(),
            iv_buf,
        }
    }
}
impl<T> Split<SecretStream<&'static [u8]>, T> for SecretStream<T> {
    /// Detaches the inner stream. Returns a stream over `EMPTY_SLICE` that keeps this stream's
    /// key, sector sizes and buffers, together with the detached inner stream.
    fn split(self) -> (SecretStream<&'static [u8]>, T) {
        let SecretStream {
            inner,
            inner_sect_sz,
            sect_sz,
            key,
            ct_buf,
            pt_buf,
            ..
        } = self;
        let detached = SecretStream {
            inner: EMPTY_SLICE,
            inner_sect_sz,
            sect_sz,
            key,
            ct_buf,
            pt_buf,
            iv_buf: [0u8; IV_BUF_LEN],
        };
        (detached, inner)
    }

    /// Reattaches `right` as the inner stream. The key, sector sizes and buffers come from
    /// `left`; the IV scratch space starts zeroed.
    fn combine(left: SecretStream<&'static [u8]>, right: T) -> Self {
        let SecretStream {
            inner_sect_sz,
            sect_sz,
            key,
            ct_buf,
            pt_buf,
            ..
        } = left;
        SecretStream {
            inner: right,
            inner_sect_sz,
            sect_sz,
            key,
            ct_buf,
            pt_buf,
            iv_buf: [0u8; IV_BUF_LEN],
        }
    }
}
impl<T> Decompose<T> for SecretStream<T> {
fn into_inner(self) -> T {
self.inner
}
}
impl<T, U: Sectored> TryCompose<U, SecretStream<U>> for SecretStream<T> {
    type Error = crate::Error;

    /// Wraps `inner`, deriving the plaintext sector size as `inner`'s sector size minus the
    /// key's ciphertext expansion, and sizing the scratch buffers for one sector.
    ///
    /// # Errors
    /// - If `inner`'s sector size is not larger than the key's expansion size, since no
    ///   plaintext would fit in a sector.
    /// - `Error::IndivisibleSize` if the plaintext sector size is not a multiple of the
    ///   cipher's block size.
    fn try_compose(mut self, inner: U) -> Result<SecretStream<U>> {
        let inner_sect_sz = inner.sector_sz();
        let expansion_sz = self.key.expansion_sz();
        // `checked_sub` turns an undersized inner sector into an error rather than an underflow
        // (panic in debug, huge wrapped size in release). A zero-sized plaintext sector is also
        // rejected: the offset translation divides by the sector size.
        let sect_sz = inner_sect_sz
            .checked_sub(expansion_sz)
            .filter(|sz| *sz > 0)
            .ok_or_else(|| bterr!("inner sector size is too small to hold an encrypted sector"))?;
        let block_sz = self.key.block_size();
        if sect_sz % block_sz != 0 {
            return Err(bterr!(Error::IndivisibleSize {
                divisor: block_sz,
                actual: sect_sz,
            }));
        }
        self.ct_buf.resize(inner_sect_sz, 0);
        // One block of headroom, matching what `read` resizes `pt_buf` to: OpenSSL's
        // `Crypter::update` requires the output to be at least `input.len() + block_size`.
        self.pt_buf.resize(inner_sect_sz + block_sz, 0);
        Ok(SecretStream {
            inner,
            inner_sect_sz,
            sect_sz,
            key: self.key,
            ct_buf: self.ct_buf,
            pt_buf: self.pt_buf,
            iv_buf: [0u8; IV_BUF_LEN],
        })
    }
}
impl<T> Sectored for SecretStream<T> {
    /// The plaintext sector size. `read` and `write` check their buffer lengths against it
    /// through `assert_sector_sz`.
    fn sector_sz(&self) -> usize {
        let Self { sect_sz, .. } = self;
        *sect_sz
    }
}
impl<T: Write + Seek> Write for SecretStream<T> {
    /// Encrypts one plaintext sector and writes the ciphertext to the inner stream.
    ///
    /// `buf` is validated with `assert_sector_sz`. The IV is derived from the inner stream's
    /// current position, so identical plaintext encrypts differently at different positions.
    /// On success all of `buf` is reported as written.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.assert_sector_sz(buf.len())?;
        let SymParams { cipher, key, iv } = sym_params!(self)?;
        self.ct_buf.resize(self.inner_sect_sz, 0);
        let mut crypter = Crypter::new(cipher, Mode::Encrypt, key, iv)?;
        let body_len = crypter.update(buf, &mut self.ct_buf)?;
        let tail_len = crypter.finalize(&mut self.ct_buf[body_len..])?;
        self.ct_buf.truncate(body_len + tail_len);
        self.inner.write_all(&self.ct_buf)?;
        Ok(buf.len())
    }

    /// Flushes the inner stream. Ciphertext is written straight through, so nothing is
    /// buffered here.
    fn flush(&mut self) -> io::Result<()> {
        self.inner.flush()
    }
}
impl<T: Read + Seek> Read for SecretStream<T> {
    /// Reads one ciphertext sector from the inner stream and decrypts it into `buf`.
    ///
    /// `buf` is validated with `assert_sector_sz`. The IV is derived from the inner stream's
    /// position before reading. Returns `Ok(0)` if the inner stream does not hold a full
    /// ciphertext sector at the current position.
    ///
    /// # Errors
    /// Returns `ErrorKind::InvalidData` if the decrypted plaintext is not exactly `buf.len()`
    /// bytes. Corrupt ciphertext can still decode with valid-looking padding but the wrong
    /// length. Previously this case panicked in `copy_from_slice`.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.assert_sector_sz(buf.len())?;
        let SymParams { cipher, key, iv } = sym_params!(self)?;
        // Always consume exactly one inner sector; don't depend on the length `write` last
        // truncated `ct_buf` to.
        self.ct_buf.resize(self.inner_sect_sz, 0);
        match self.inner.read_exact(&mut self.ct_buf) {
            Ok(()) => (),
            Err(err) if err.kind() == io::ErrorKind::UnexpectedEof => return Ok(0),
            Err(err) => return Err(err),
        }
        self.pt_buf
            .resize(self.inner_sect_sz + self.key.block_size(), 0);
        let mut decrypter = Crypter::new(cipher, Mode::Decrypt, key, iv)?;
        let mut count = decrypter.update(&self.ct_buf, &mut self.pt_buf)?;
        count += decrypter.finalize(&mut self.pt_buf[count..])?;
        self.pt_buf.truncate(count);
        if self.pt_buf.len() != buf.len() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "decrypted sector has an unexpected length",
            ));
        }
        buf.copy_from_slice(&self.pt_buf);
        Ok(buf.len())
    }
}
impl<T: Seek> Seek for SecretStream<T> {
    /// Seeks to a plaintext (outer) offset by translating it to the matching ciphertext (inner)
    /// offset and seeking the inner stream there. Returns the new outer offset.
    ///
    /// # Errors
    /// - `ErrorKind::Unsupported` for `SeekFrom::End`.
    /// - `ErrorKind::InvalidInput` for a relative seek that would land before offset 0 or
    ///   overflow `u64`, as `std::io::Seek` specifies. Previously this panicked in debug builds
    ///   and wrapped in release.
    /// - Any error from the inner stream.
    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
        let outer_offset = match pos {
            SeekFrom::Start(offset) => offset,
            SeekFrom::Current(delta) => {
                let inner_offset = self.inner.stream_position()?;
                let current = self.outer_offset(inner_offset);
                // `unsigned_abs` is exact even for `i64::MIN`, whose negation overflows.
                let target = if delta >= 0 {
                    current.checked_add(delta as u64)
                } else {
                    current.checked_sub(delta.unsigned_abs())
                };
                target.ok_or_else(|| {
                    io::Error::new(
                        io::ErrorKind::InvalidInput,
                        "invalid seek to a negative or overflowing position",
                    )
                })?
            }
            SeekFrom::End(_) => {
                return Err(io::Error::new(
                    io::ErrorKind::Unsupported,
                    "seeking from the end of the stream is not supported",
                ));
            }
        };
        let inner_offset = self.inner_offset(outer_offset);
        self.inner.seek(SeekFrom::Start(inner_offset))?;
        Ok(outer_offset)
    }
}
impl<U, T: AsRef<U>> AsRef<U> for SecretStream<T> {
fn as_ref(&self) -> &U {
self.inner.as_ref()
}
}
impl<U, T: AsMut<U>> AsMut<U> for SecretStream<T> {
fn as_mut(&mut self) -> &mut U {
self.inner.as_mut()
}
}
impl<T: Size> Size for SecretStream<T> {
fn size(&self) -> io::Result<Option<u64>> {
self.inner.size()
}
}
}
#[cfg(test)]
mod tests {
    use crate::{
        crypto::SymKeyKind,
        test_helpers::{Randomizer, SectoredCursor},
        SECTOR_SZ_DEFAULT,
    };

    use super::*;

    /// Composes a `SecretStream` using `key` over a zero-filled in-memory cursor of `sect_ct`
    /// inner sectors, each `inner_sect_sz` bytes long.
    fn compose_stream(
        key: SymKey,
        inner_sect_sz: usize,
        sect_ct: usize,
    ) -> SecretStream<SectoredCursor<Vec<u8>>> {
        SecretStream::new(key)
            .try_compose(SectoredCursor::new(
                vec![0u8; inner_sect_sz * sect_ct],
                inner_sect_sz,
            ))
            .expect("compose failed")
    }

    /// Writes `sect_ct` sectors sequentially (sector `k` filled with byte `k`), rewinds, and
    /// checks that reading them back returns the same plaintext.
    fn secret_stream_sequential_test_case(key: SymKey, inner_sect_sz: usize, sect_ct: usize) {
        let mut stream = compose_stream(key, inner_sect_sz, sect_ct);
        let sect_sz = stream.sector_sz();
        for k in 0..sect_ct {
            stream
                .write_all(&vec![k as u8; sect_sz])
                .expect("write failed");
        }
        stream.seek(SeekFrom::Start(0)).expect("seek failed");
        let mut actual = vec![0u8; sect_sz];
        for k in 0..sect_ct {
            stream.read_exact(&mut actual).expect("read failed");
            assert_eq!(vec![k as u8; sect_sz], actual);
        }
    }

    fn secret_stream_sequential_test_suite(kind: SymKeyKind) {
        let key = SymKey::generate(kind).expect("key generation failed");
        secret_stream_sequential_test_case(key, SECTOR_SZ_DEFAULT, 16);
    }

    #[test]
    fn secret_stream_encrypt_decrypt_are_inverse_aes256cbc() {
        secret_stream_sequential_test_suite(SymKeyKind::Aes256Cbc)
    }

    #[test]
    fn secret_stream_encrypt_decrypt_are_inverse_aes256ctr() {
        secret_stream_sequential_test_suite(SymKeyKind::Aes256Ctr)
    }

    /// Writes sectors in a pseudo-random order (sector `i` filled with byte `i as u8`), then
    /// seeks back to each written sector and checks its contents.
    fn secret_stream_random_access_test_case(
        rando: Randomizer,
        key: SymKey,
        inner_sect_sz: usize,
        sect_ct: usize,
    ) {
        let mut stream = compose_stream(key, inner_sect_sz, sect_ct);
        let sect_sz = stream.sector_sz();
        let indices: Vec<usize> = rando.take(sect_ct).map(|e| e % sect_ct).collect();
        for &index in &indices {
            stream
                .seek(SeekFrom::Start((index * sect_sz) as u64))
                .expect("seek to write failed");
            stream
                .write_all(&vec![index as u8; sect_sz])
                .expect("write failed");
        }
        for &index in &indices {
            stream
                .seek(SeekFrom::Start((index * sect_sz) as u64))
                .expect("seek to read failed");
            let mut actual = vec![0u8; sect_sz];
            stream.read_exact(&mut actual).expect("read failed");
            assert_eq!(vec![index as u8; sect_sz], actual);
        }
    }

    fn secret_stream_random_access_test_suite(kind: SymKeyKind) {
        const SEED: [u8; Randomizer::HASH.len()] = [3u8; Randomizer::HASH.len()];
        let key = SymKey::generate(kind).expect("key generation failed");
        // (inner sector size, sector count); the previous list ran (512, 200) twice.
        let cases = [
            (SECTOR_SZ_DEFAULT, 20),
            (SECTOR_SZ_DEFAULT, 800),
            (512, 200),
            (512, 20),
        ];
        for (inner_sect_sz, sect_ct) in cases {
            secret_stream_random_access_test_case(
                Randomizer::new(SEED),
                key.clone(),
                inner_sect_sz,
                sect_ct,
            );
        }
    }

    #[test]
    fn secret_stream_random_access() {
        secret_stream_random_access_test_suite(SymKeyKind::Aes256Cbc);
        secret_stream_random_access_test_suite(SymKeyKind::Aes256Ctr);
    }

    fn make_secret_stream(
        key_kind: SymKeyKind,
        num_sectors: usize,
    ) -> SecretStream<SectoredCursor<Vec<u8>>> {
        let key = SymKey::generate(key_kind).expect("key generation failed");
        compose_stream(key, SECTOR_SZ_DEFAULT, num_sectors)
    }

    #[test]
    fn secret_stream_seek_from_start() {
        let mut stream = make_secret_stream(SymKeyKind::Aes256Cbc, 3);
        let sector_sz = stream.sector_sz();
        let expected = vec![2u8; sector_sz];
        for k in 1..4 {
            stream
                .write_all(&vec![k as u8; sector_sz])
                .expect("writing to stream failed");
        }
        stream
            .seek(SeekFrom::Start(sector_sz as u64))
            .expect("seek failed");
        let mut actual = vec![0u8; sector_sz];
        stream
            .read_exact(&mut actual)
            .expect("reading from stream failed");
        assert_eq!(expected, actual);
    }

    #[test]
    fn secret_stream_seek_from_current() {
        let mut stream = make_secret_stream(SymKeyKind::Aes256Cbc, 3);
        let sector_sz = stream.sector_sz();
        let expected = vec![3u8; sector_sz];
        for k in 1..4 {
            stream
                .write_all(&vec![k as u8; sector_sz])
                .expect("writing to stream failed");
        }
        stream
            .seek(SeekFrom::Current(-(sector_sz as i64)))
            .expect("seek failed");
        let mut actual = vec![0u8; sector_sz];
        stream
            .read_exact(&mut actual)
            .expect("reading from stream failed");
        assert_eq!(expected, actual);
    }
}