use crate::{
bterr,
crypto::{AsymKeyPub, BitLen, HashKind, RsaSsaPss, Sha2_256, Sha2_512, Sign, Signature},
BlockPath, Epoch, Principal, Principaled, Result, Writecap, WritecapBody,
};
use bcder::{
decode::{BytesSource, Constructed, DecodeError, SliceSource},
encode::{PrimitiveContent, Values},
BitString, Captured, Integer, Mode, OctetString, Oid, Tag, Utf8String,
};
use bytes::{BufMut, Bytes, BytesMut};
use chrono::{offset::Utc, TimeZone};
use std::ops::Deref;
use x509_certificate::{
asn1time::{Time, UtcTime},
certificate::X509Certificate,
rfc3280::{AttributeValue, Name},
rfc5280::{
AlgorithmIdentifier, AlgorithmParameter, Certificate, CertificateSerialNumber, Extension,
Extensions, SubjectPublicKeyInfo, TbsCertificate, Validity, Version,
},
};
mod private {
use super::*;
/// Wraps the static DER content octets of an object identifier in a [`bcder::Oid`]
/// without copying them.
fn oid(slice: &'static [u8]) -> Oid {
    let content = Bytes::from_static(slice);
    Oid(content)
}
// Builds a `BitString` with zero unused trailing bits from any value convertible into `Bytes`.
macro_rules! bit_string {
    ($data:expr) => {{
        let data = Bytes::from($data);
        BitString::new(0, data)
    }};
}
impl Sha2_256 {
    /// DER content octets of the SHA-256 object identifier (2.16.840.1.101.3.4.2.1).
    const OID: &[u8] = &[
        0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01,
    ];
}
impl Sha2_512 {
    /// DER content octets of the SHA-512 object identifier (2.16.840.1.101.3.4.2.3).
    const OID: &[u8] = &[
        0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03,
    ];
}
impl HashKind {
    /// Returns the DER content octets of the OID that identifies this hash function.
    const fn oid(&self) -> &'static [u8] {
        match *self {
            Self::Sha2_256 => Sha2_256::OID,
            Self::Sha2_512 => Sha2_512::OID,
        }
    }

    /// Maps the DER content octets of a hash-function OID back to its [`HashKind`].
    ///
    /// # Errors
    /// Returns an error when `slice` matches none of the supported hash OIDs.
    fn from_oid(slice: &[u8]) -> Result<Self> {
        [Self::Sha2_256, Self::Sha2_512]
            .iter()
            .copied()
            .find(|kind| kind.oid() == slice)
            .ok_or_else(|| bterr!("unrecognized OID"))
    }
}
/// DER content octets of the MGF1 mask-generation-function OID (1.2.840.113549.1.1.8).
const MGF1_OID: &[u8] = &[0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x08];

/// An algorithm identifier whose single parameter is itself an OID. It is used to describe
/// MGF1 parameterized by a hash function.
///
/// NOTE(review): RFC 4055 specifies the MGF1 parameter as a full `AlgorithmIdentifier`
/// (a SEQUENCE), whereas this type encodes a bare OID. Confirm that interoperability with
/// other PSS implementations is not required.
struct SingleParamAlgoId {
    algorithm: Oid,
    parameters: Oid,
}
impl SingleParamAlgoId {
    /// Creates an MGF1 identifier whose parameter is the OID of `hash_kind`.
    fn new(hash_kind: HashKind) -> Self {
        Self {
            algorithm: Oid(Bytes::from_static(MGF1_OID)),
            parameters: Oid(Bytes::from_static(hash_kind.oid())),
        }
    }

    /// Encodes this identifier as a universal SEQUENCE.
    fn encode_ref(&self) -> impl Values + '_ {
        self.encode_ref_as(Tag::SEQUENCE)
    }

    /// Encodes this identifier as a constructed value tagged with `tag`, containing the
    /// algorithm OID followed by the parameter OID.
    fn encode_ref_as(&self, tag: Tag) -> impl Values + '_ {
        let fields = (self.algorithm.encode_ref(), self.parameters.encode_ref());
        bcder::encode::sequence_as(tag, fields)
    }

    /// Decodes a SEQUENCE holding exactly the algorithm OID and then the parameter OID.
    fn take_from<S: bcder::decode::Source>(
        cons: &mut Constructed<'_, S>,
    ) -> std::result::Result<Self, DecodeError<S::Error>> {
        cons.take_sequence(|seq| {
            let algorithm = Oid::take_from(seq)?;
            let parameters = Oid::take_from(seq)?;
            Ok(Self {
                algorithm,
                parameters,
            })
        })
    }
}
/// Parameters for the RSASSA-PSS signature algorithm, modeled on RFC 4055 `RSASSA-PSS-params`.
///
/// NOTE(review): `hash_algorithm` is encoded as a bare OID, whereas RFC 4055 uses a full
/// `AlgorithmIdentifier`. Confirm before relying on this encoding for interoperability.
struct RsaSsaPssParams {
    /// OID of the message digest, wrapped in an explicit `[0]` tag.
    hash_algorithm: Oid,
    /// Mask generation function, wrapped in an explicit `[1]` tag.
    mask_gen_algorithm: SingleParamAlgoId,
    /// Salt length, wrapped in an explicit `[2]` tag.
    salt_length: Integer,
    /// Optional trailer field, wrapped in an explicit `[3]` tag when present.
    trailer_field: Option<Integer>,
}
impl RsaSsaPssParams {
    /// Builds parameters that use `hash_kind` for both the digest and MGF1, with a salt as
    /// long as the digest output and no trailer field.
    fn new(hash_kind: HashKind) -> Self {
        Self {
            hash_algorithm: Oid(Bytes::from(hash_kind.oid())),
            mask_gen_algorithm: SingleParamAlgoId::new(hash_kind),
            salt_length: Integer::from(hash_kind.len() as u64),
            trailer_field: None,
        }
    }

    /// Encodes the parameters as a universal SEQUENCE.
    fn encode_ref(&self) -> impl bcder::encode::Values + '_ {
        self.encode_ref_as(bcder::Tag::SEQUENCE)
    }

    /// Encodes the parameters as a constructed value tagged with `tag`. Each field is
    /// explicitly tagged `[0]` through `[3]`, and the trailer is omitted when absent.
    fn encode_ref_as(&self, tag: Tag) -> impl bcder::encode::Values + '_ {
        use bcder::encode::Constructed as Explicit;
        let hash = Explicit::new(Tag::CTX_0, self.hash_algorithm.encode_ref());
        let mask_gen = Explicit::new(Tag::CTX_1, self.mask_gen_algorithm.encode_ref());
        let salt = Explicit::new(Tag::CTX_2, self.salt_length.encode());
        let trailer = self
            .trailer_field
            .as_ref()
            .map(|value| Explicit::new(Tag::CTX_3, value.encode()));
        bcder::encode::sequence_as(tag, (hash, mask_gen, salt, trailer))
    }

    /// Decodes either a parameter SEQUENCE (`Some`) or an explicit NULL (`None`).
    ///
    /// A value that is neither a SEQUENCE nor a NULL is reported as a decode error.
    fn take_from<S: bcder::decode::Source>(
        cons: &mut Constructed<'_, S>,
    ) -> std::result::Result<Option<Self>, DecodeError<S::Error>> {
        let decoded = cons.take_opt_value_if(Tag::SEQUENCE, |content| {
            let seq = content.as_constructed()?;
            let hash_algorithm = seq.take_constructed_if(Tag::CTX_0, Oid::take_from)?;
            let mask_gen_algorithm =
                seq.take_constructed_if(Tag::CTX_1, SingleParamAlgoId::take_from)?;
            let salt_length = seq.take_constructed_if(Tag::CTX_2, Integer::take_from)?;
            let trailer_field = seq.take_opt_constructed_if(Tag::CTX_3, Integer::take_from)?;
            Ok(RsaSsaPssParams {
                hash_algorithm,
                mask_gen_algorithm,
                salt_length,
                trailer_field,
            })
        })?;
        match decoded {
            Some(params) => Ok(Some(params)),
            None => {
                // No SEQUENCE was present, so the parameters must be an explicit NULL.
                cons.take_null()?;
                Ok(None)
            }
        }
    }
}
impl RsaSsaPss {
    /// When `false`, certificates label keys and signatures with the generic RSA OID
    /// (`RSA_ES_OID`) and carry no algorithm parameters. When `true`, they use the PSS OID
    /// together with encoded [`RsaSsaPssParams`].
    const USE_PSS_OID: bool = false;
    /// DER content octets of `id-RSASSA-PSS` (1.2.840.113549.1.1.10).
    const RSA_PSS_OID: &[u8] = &[0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0A];
    /// DER content octets of `rsaEncryption` (1.2.840.113549.1.1.1).
    const RSA_ES_OID: &[u8] = &[0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01];

    /// Returns the algorithm OID selected by [`Self::USE_PSS_OID`].
    const fn oid() -> &'static [u8] {
        match Self::USE_PSS_OID {
            true => Self::RSA_PSS_OID,
            false => Self::RSA_ES_OID,
        }
    }

    /// Returns the DER-captured algorithm parameters, or `None` when the generic RSA OID is used.
    fn params(&self) -> Result<Option<AlgorithmParameter>> {
        if !Self::USE_PSS_OID {
            return Ok(None);
        }
        let pss = RsaSsaPssParams::new(self.hash_kind);
        let captured = Captured::from_values(Mode::Der, pss.encode_ref());
        Ok(Some(AlgorithmParameter::from_captured(captured)))
    }

    /// Reconstructs a scheme from a signature length (in octets) and optional DER parameters.
    ///
    /// The hash kind comes from the decoded parameters when they are a PSS parameter SEQUENCE.
    /// It falls back to `HashKind::default()` when parameters are absent or an explicit NULL.
    ///
    /// # Errors
    /// Fails if the length is not a valid [`BitLen`], if the parameters cannot be decoded,
    /// or if the hash OID is unrecognized.
    fn from_params(sig_octet_len: u32, params: Option<&[u8]>) -> Result<Self> {
        let key_bits = BitLen::try_from(sig_octet_len)?;
        let decoded = match params {
            None => None,
            Some(der) => Constructed::decode(
                SliceSource::new(der),
                Mode::Der,
                RsaSsaPssParams::take_from,
            )?,
        };
        let hash_kind = match decoded {
            Some(pss) => HashKind::from_oid(pss.hash_algorithm.as_ref())?,
            None => HashKind::default(),
        };
        Ok(Self {
            key_bits,
            hash_kind,
        })
    }
}
impl Sign {
    /// Returns the DER content octets of this scheme's algorithm OID.
    const fn oid(&self) -> &'static [u8] {
        match *self {
            Self::RsaSsaPss(_) => RsaSsaPss::oid(),
        }
    }

    /// Returns this scheme's DER-captured algorithm parameters, if it has any.
    fn params(&self) -> Result<Option<AlgorithmParameter>> {
        match self {
            Self::RsaSsaPss(scheme) => scheme.params(),
        }
    }

    /// Builds a scheme from an algorithm OID and optional DER parameters.
    ///
    /// # Errors
    /// Fails when `oid` does not belong to any [`Sign`] variant, or when the variant
    /// rejects `sig_octet_len` or `params`.
    fn from_der(sig_octet_len: u32, oid: &[u8], params: Option<&[u8]>) -> Result<Sign> {
        if oid != RsaSsaPss::oid() {
            return Err(bterr!("OID does not match a Sign variant"));
        }
        let scheme = RsaSsaPss::from_params(sig_octet_len, params)?;
        Ok(Self::RsaSsaPss(scheme))
    }

    /// Converts this scheme into an X.509 [`AlgorithmIdentifier`].
    fn to_algo_id(self) -> Result<AlgorithmIdentifier> {
        let algorithm = oid(self.oid());
        let parameters = self.params()?;
        Ok(AlgorithmIdentifier {
            algorithm,
            parameters,
        })
    }

    /// Converts an X.509 [`AlgorithmIdentifier`] back into a scheme.
    fn from_algo_id(sig_octet_len: u32, algo_id: &AlgorithmIdentifier) -> Result<Sign> {
        let algorithm: &[u8] = algo_id.algorithm.as_ref();
        let params: Option<&[u8]> = algo_id.parameters.as_ref().map(|param| param.as_ref());
        Self::from_der(sig_octet_len, algorithm, params)
    }
}
/// Certificate extension that carries the [`BlockPath`] a writecap applies to.
///
/// It is identified by the `id-ce-subjectAltName` OID (2.5.29.17), but its value is encoded
/// as `SEQUENCE { UTF8String }` rather than the RFC 5280 `GeneralNames` structure.
/// NOTE(review): third-party X.509 parsers that interpret this extension may reject it; confirm.
struct SubjectAltName {
    block_path: Utf8String,
}
impl SubjectAltName {
    /// DER content octets of `id-ce-subjectAltName` (2.5.29.17).
    const OID: &[u8] = &[0x55, 0x1D, 0x11];

    /// Creates the extension value from the string form of `block_path`.
    fn new(block_path: &BlockPath) -> Result<Self> {
        let text = block_path.to_string();
        match Utf8String::from_string(text) {
            Ok(block_path) => Ok(Self { block_path }),
            Err(err) => Err(bterr!("{:?}", err)),
        }
    }

    /// Encodes the value as a universal SEQUENCE.
    fn encode(&self) -> impl Values + '_ {
        self.encode_as(Tag::SEQUENCE)
    }

    /// Encodes the value as a constructed value tagged with `tag`, wrapping the UTF8String path.
    fn encode_as(&self, tag: Tag) -> impl Values + '_ {
        let path = self.block_path.encode_ref();
        bcder::encode::sequence_as(tag, path)
    }

    /// Decodes a SEQUENCE holding exactly one UTF8String.
    fn take_from<S: bcder::decode::Source>(
        cons: &mut Constructed<'_, S>,
    ) -> std::result::Result<Self, DecodeError<S::Error>> {
        cons.take_sequence(|seq| {
            let block_path = Utf8String::take_from(seq)?;
            Ok(Self { block_path })
        })
    }

    /// Returns the DER encoding of this value.
    fn encode_der(&self) -> Result<Bytes> {
        let mut sink = BytesMut::new().writer();
        self.encode().write_encoded(Mode::Der, &mut sink)?;
        Ok(sink.into_inner().freeze())
    }

    /// Parses a value from its DER encoding.
    fn decode_der<B: AsRef<[u8]>>(bytes: B) -> Result<Self> {
        let source = SliceSource::new(bytes.as_ref());
        let decoded = Constructed::decode(source, Mode::Der, Self::take_from)?;
        Ok(decoded)
    }
}
impl AsymKeyPub<Sign> {
    /// Builds an X.509 SubjectPublicKeyInfo. The algorithm identifier comes from
    /// `self.scheme`, and the key bits come from the DER form of `self.pkey`.
    fn subject_public_key_info(&self) -> Result<SubjectPublicKeyInfo> {
        let algorithm = self.scheme.to_algo_id()?;
        let subject_public_key = self.to_bit_string()?;
        Ok(SubjectPublicKeyInfo {
            algorithm,
            subject_public_key,
        })
    }

    /// Extracts the `subjectPublicKey` bit string from the SPKI DER produced by
    /// `self.pkey.public_key_to_der()`.
    fn to_bit_string(&self) -> Result<BitString> {
        let der = Bytes::from(self.pkey.public_key_to_der()?);
        let decoded = Constructed::decode(
            BytesSource::new(der),
            Mode::Der,
            SubjectPublicKeyInfo::take_from,
        )?;
        Ok(decoded.subject_public_key)
    }

    /// Reconstructs a public key from a decoded SubjectPublicKeyInfo.
    ///
    /// NOTE(review): the scheme's key size is derived from `sig_octet_len` (a signature
    /// length supplied by the caller), not from the key material itself.
    fn from_subject_public_key_info(
        sig_octet_len: u32,
        spki: &SubjectPublicKeyInfo,
    ) -> Result<Self> {
        let scheme = Sign::from_algo_id(sig_octet_len, &spki.algorithm)?;
        let mut der = Vec::new();
        spki.encode_ref().write_encoded(Mode::Der, &mut der)?;
        AsymKeyPub::new(scheme, der.as_slice())
    }

    /// Returns this key encoded as a DER SubjectPublicKeyInfo.
    pub fn to_der(&self) -> Result<Vec<u8>> {
        let mut buf = Vec::new();
        self.subject_public_key_info()?
            .encode_ref()
            .write_encoded(Mode::Der, &mut buf)?;
        Ok(buf)
    }
}
/// Convenience accessors for X.509 [`Name`]s.
trait NameExt {
    /// Returns the value of the first CommonName attribute.
    ///
    /// # Errors
    /// Fails if the name has no CommonName component.
    fn try_get_common_name(&self) -> Result<&AttributeValue>;
}
impl NameExt for Name {
    fn try_get_common_name(&self) -> Result<&AttributeValue> {
        match self.iter_common_name().next() {
            Some(attribute) => Ok(&attribute.value),
            None => Err(bterr!("no CommonName component in Name")),
        }
    }
}
/// Fallible view of a byte container as UTF-8 text.
trait TryAsStr {
    /// Borrows the bytes as a `&str`, failing if they are not valid UTF-8.
    fn try_as_str(&self) -> Result<&str>;
}
impl<T: ?Sized + AsRef<[u8]>> TryAsStr for T {
    fn try_as_str(&self) -> Result<&str> {
        let bytes: &[u8] = self.as_ref();
        Ok(std::str::from_utf8(bytes)?)
    }
}
/// Conversion from ASN.1 certificate times to [`Epoch`]s.
trait TimeExt {
    /// Converts this time to an [`Epoch`] (seconds since the Unix epoch).
    ///
    /// # Errors
    /// Fails for `GeneralizedTime` values, which are not supported, and for times before
    /// the Unix epoch, which an `Epoch` cannot represent.
    fn try_to_epoch(&self) -> Result<Epoch>;
}
impl TimeExt for Time {
    fn try_to_epoch(&self) -> Result<Epoch> {
        match self {
            Self::UtcTime(time) => {
                // UTCTime can encode dates from 1950 onward, so the timestamp may be negative.
                // An `as u64` cast would wrap such a value to a far-future epoch, which would
                // make a long-expired certificate look like it never expires. Reject it instead.
                let secs = u64::try_from(time.timestamp())
                    .map_err(|_| bterr!("time precedes the Unix epoch"))?;
                Ok(Epoch::from_value(secs))
            }
            Self::GeneralTime(..) => Err(bterr!("unsupported Time variant encountered")),
        }
    }
}
/// Lookup helpers for X.509 certificate extensions.
trait ExtensionsExt {
    /// Decodes the first extension whose OID is [`SubjectAltName::OID`].
    ///
    /// The extension's `critical` flag is not inspected.
    ///
    /// # Errors
    /// Fails if no such extension exists or its value cannot be decoded.
    fn find_subject_alt_name(&self) -> Result<SubjectAltName>;
}
impl ExtensionsExt for Extensions {
    fn find_subject_alt_name(&self) -> Result<SubjectAltName> {
        let all: &[Extension] = self.deref();
        let found = all
            .iter()
            .find(|extension| extension.id.as_ref() == SubjectAltName::OID)
            .ok_or_else(|| bterr!("SubjectAltName not found"))?;
        SubjectAltName::decode_der(found.value.to_bytes())
    }
}
impl Principal {
    /// Builds an X.509 [`Name`] whose CommonName is this principal's string form.
    fn to_name(&self) -> Result<Name> {
        let mut name = Name::default();
        if name
            .append_common_name_utf8_string(&self.to_string())
            .is_err()
        {
            return Err(bterr!("failed to create Name for Principal"));
        }
        Ok(name)
    }

    /// Parses a principal from the first CommonName in `name`.
    fn from_name(name: &Name) -> Result<Self> {
        let common_name = name.try_get_common_name()?.to_string()?;
        let principal = common_name.as_str().try_into()?;
        Ok(principal)
    }

    /// Returns the DER encoding of [`Self::to_name`].
    pub fn to_name_der(&self) -> Result<Vec<u8>> {
        let mut buf = Vec::new();
        self.to_name()?
            .encode_ref()
            .write_encoded(Mode::Der, &mut buf)?;
        Ok(buf)
    }
}
impl Writecap {
    /// Encodes this writecap as a DER X.509 certificate whose subject key is `subject_key`.
    ///
    /// How the certificate fields are filled:
    /// - issuer: from `self.body.signing_key`;
    /// - subject: from `self.body.issued_to`;
    /// - `notAfter`: from `self.body.expires`, and `notBefore` is the current time;
    /// - path: stored in a [`SubjectAltName`] extension.
    ///
    /// The writecap's existing signature bytes are copied into the certificate verbatim;
    /// nothing is signed here.
    fn to_cert(&self, subject_key: &AsymKeyPub<Sign>) -> Result<Vec<u8>> {
        let version = Some(Version::V3);
        // NOTE(review): every certificate produced here carries serial number 1.
        let serial_number = CertificateSerialNumber::from(1);
        let signature_algorithm = self.body.signing_key.scheme.to_algo_id()?;
        let issuer = self.body.signing_key.principal().to_name()?;
        // Convert from whole seconds. Scaling to milliseconds first could overflow `i64` for
        // extreme epochs: that panics in debug builds and silently wraps in release builds.
        let expires = Utc
            .timestamp_opt(self.body.expires.to_unix(), 0)
            .single()
            .ok_or_else(|| {
                bterr!("failed to convert writecap expiration to chrono DateTime")
            })?;
        let validity = Validity {
            not_before: Time::UtcTime(UtcTime::now()),
            not_after: Time::from(expires),
        };
        let subject = self.body.issued_to.to_name()?;
        let subject_public_key_info = subject_key.subject_public_key_info()?;
        let mut extensions = Extensions::default();
        let san = SubjectAltName::new(&self.body.path)?;
        extensions.push(Extension {
            id: oid(SubjectAltName::OID),
            critical: Some(false),
            value: OctetString::new(san.encode_der()?),
        });
        let tbs_certificate = TbsCertificate {
            version,
            serial_number,
            signature: signature_algorithm.clone(),
            issuer,
            validity,
            subject,
            subject_public_key_info,
            issuer_unique_id: None,
            subject_unique_id: None,
            extensions: Some(extensions),
            raw_data: None,
        };
        let cert = Certificate {
            tbs_certificate,
            signature_algorithm,
            signature: bit_string!(self.signature.data.clone()),
        };
        let cert: X509Certificate = cert.into();
        cert.encode_der().map_err(|err| err.into())
    }

    /// Builds this writecap's certificate chain ordered root-first, leaf-last.
    ///
    /// Each link's subject key is the signing key of the writecap below it. When this
    /// writecap has no `next`, a placeholder self-issued certificate for the signing key's
    /// principal is emitted as the root.
    fn to_cert_chain_impl(&self, subject_key: &AsymKeyPub<Sign>) -> Result<Vec<Vec<u8>>> {
        let mut chain = match self.next.as_ref() {
            Some(next) => next.as_ref().to_cert_chain_impl(&self.body.signing_key)?,
            None => {
                let mut vec = Vec::with_capacity(2);
                let root_principal = self.body.signing_key.principal();
                let path =
                    BlockPath::from_components(root_principal.clone(), std::iter::empty());
                // NOTE(review): this placeholder root reuses this writecap's signature (it is
                // not a genuine self-signature) and expires at `Epoch::now()`. It exists only
                // to convey the root signing key, and `from_cert_chain` discards it apart from
                // that key.
                let writecap = Writecap {
                    body: WritecapBody {
                        issued_to: root_principal,
                        expires: Epoch::now(),
                        path,
                        signing_key: self.body.signing_key.clone(),
                    },
                    signature: self.signature.clone(),
                    next: None,
                };
                vec.push(writecap.to_cert(&self.body.signing_key)?);
                vec
            }
        };
        chain.push(self.to_cert(subject_key)?);
        Ok(chain)
    }

    /// Encodes this writecap and its issuers as DER certificates.
    ///
    /// The chain is ordered leaf-first: the first entry is this writecap's certificate for
    /// `subject_key`, and the last entry is a placeholder root certificate.
    pub fn to_cert_chain(&self, subject_key: &AsymKeyPub<Sign>) -> Result<Vec<Vec<u8>>> {
        let mut chain = self.to_cert_chain_impl(subject_key)?;
        chain.reverse();
        Ok(chain)
    }

    /// Decodes a leaf-first certificate chain produced by [`Self::to_cert_chain`].
    ///
    /// Returns the writecap encoded by `first` together with the subject public key of
    /// `first`. `rest` holds the issuer certificates in order; its final entry is the
    /// placeholder root, which contributes only its key.
    ///
    /// # Errors
    /// Fails if any certificate is malformed, lacks the path extension, uses an unsupported
    /// algorithm or time encoding, or has a signature whose bit length is not a multiple of 8.
    pub fn from_cert_chain<B: AsRef<[u8]>>(
        first: &B,
        rest: &[B],
    ) -> Result<(Writecap, AsymKeyPub<Sign>)> {
        // Decode the issuer's certificate first: its subject key is the key that signed `first`.
        let (next, signing_key) = match rest.split_first() {
            Some((head, tail)) => {
                let (writecap, signing_key) = Self::from_cert_chain(head, tail)?;
                // The last certificate is the placeholder root emitted by `to_cert_chain_impl`,
                // so it does not become a `next` writecap.
                let next = if tail.is_empty() {
                    None
                } else {
                    Some(writecap)
                };
                (next, Some(signing_key))
            }
            None => (None, None),
        };
        let x509_cert = X509Certificate::from_der(first)?;
        let cert: &Certificate = x509_cert.as_ref();
        if cert.signature.unused() > 0 {
            return Err(bterr!("signature length is not divisible by 8"));
        }
        let extensions = cert
            .tbs_certificate
            .extensions
            .as_ref()
            .ok_or_else(|| bterr!("no extensions present"))?;
        let san = extensions.find_subject_alt_name()?;
        let path = BlockPath::try_from(san.block_path.into_bytes().try_as_str()?)
            .map_err(|err| bterr!(err))?;
        let sig_octet_len: u32 = cert.signature.octet_len().try_into()?;
        // The signature was produced by the issuer, so its scheme comes from the certificate's
        // `signatureAlgorithm` field, not from the subject's public key algorithm.
        let scheme = Sign::from_algo_id(sig_octet_len, &cert.signature_algorithm)?;
        let signature = Signature::new(scheme, cert.signature.octet_bytes().into());
        let cert = &cert.tbs_certificate;
        // NOTE(review): the subject key's size is inferred from the issuer's signature length,
        // which is only correct when the subject and issuer keys are the same size.
        let subject_key = AsymKeyPub::from_subject_public_key_info(
            sig_octet_len,
            &cert.subject_public_key_info,
        )?;
        let issued_to = Principal::from_name(&cert.subject)?;
        let expires = cert.validity.not_after.try_to_epoch()?;
        // A certificate with no issuer in the chain is treated as self-issued.
        let signing_key = signing_key.unwrap_or_else(|| subject_key.clone());
        let writecap = Writecap {
            body: WritecapBody {
                issued_to,
                signing_key,
                path,
                expires,
            },
            signature,
            next: next.map(Box::new),
        };
        Ok((writecap, subject_key))
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use webpki::EndEntityCert;
    use crate::{
        crypto::{ConcreteCreds, Creds, CredsPriv, CredsPub},
        test_helpers::node_creds,
    };

    /// Writes the first (leaf) certificate of the node writecap's chain to `/tmp/cert.der`.
    // Manual debugging helper, not a `#[test]`; `dead_code` is allowed because nothing calls it.
    #[allow(dead_code)]
    fn save_first_writecap_der_to_file() {
        let node_creds = node_creds();
        let chain = node_creds
            .writecap
            .as_ref()
            .unwrap()
            .to_cert_chain(&node_creds.sign.public)
            .unwrap();
        let first = chain.first().unwrap();
        std::fs::write("/tmp/cert.der", first).unwrap();
    }

    #[test]
    fn node_writecap_to_cert_chain() {
        let node_creds = node_creds();
        // `unwrap` rather than `assert!(is_ok())` so a failure reports the underlying error.
        node_creds
            .writecap
            .as_ref()
            .unwrap()
            .to_cert_chain(&node_creds.sign.public)
            .unwrap();
    }

    #[test]
    fn node_writecap_to_cert_chain_end_cert_can_be_parsed() {
        let node_creds = node_creds();
        let chain = node_creds
            .writecap
            .as_ref()
            .unwrap()
            .to_cert_chain(&node_creds.sign.public)
            .unwrap();
        let der = chain.first().unwrap();
        let result = EndEntityCert::try_from(der.as_slice());
        result.unwrap();
    }

    #[test]
    fn round_trip_writecap() {
        let node_creds = node_creds();
        let expected_key = node_creds.public_sign();
        let expected_wc = node_creds.writecap().unwrap();
        let certs = expected_wc.to_cert_chain(expected_key).unwrap();
        let (actual_wc, actual_key) =
            Writecap::from_cert_chain(certs.first().unwrap(), &certs[1..]).unwrap();
        assert_eq!(expected_key, &actual_key);
        assert_eq!(expected_wc, &actual_wc);
        actual_wc.assert_valid_for(&expected_wc.body.path).unwrap();
    }

    #[test]
    fn round_trip_chain_of_length_two() {
        let node_creds = node_creds();
        let mut process_creds = ConcreteCreds::generate().unwrap();
        let writecap = node_creds
            .issue_writecap(
                process_creds.principal(),
                &mut ["console"].into_iter(),
                Epoch::now() + Duration::from_secs(3600),
            )
            .unwrap();
        process_creds.set_writecap(writecap).unwrap();
        let expected_key = process_creds.public_sign();
        let expected_wc = process_creds.writecap().unwrap();
        let certs = expected_wc.to_cert_chain(expected_key).unwrap();
        let (actual_wc, actual_key) =
            Writecap::from_cert_chain(certs.first().unwrap(), &certs[1..]).unwrap();
        assert_eq!(expected_key, &actual_key);
        assert_eq!(expected_wc, &actual_wc);
        actual_wc.assert_valid_for(&expected_wc.body.path).unwrap();
    }
}