use anyhow::anyhow;
use log::error;
use nix::{
sys::signal::{self, Signal},
unistd::Pid,
};
use std::{
error::Error,
io,
path::PathBuf,
process::{Child, Command, ExitStatus, Stdio},
str::FromStr,
sync::{
atomic::{AtomicU16, Ordering},
mpsc::{channel, Receiver, TryRecvError},
},
time::{Duration, SystemTime},
};
use tempdir::TempDir;
use tss_esapi::{
tcti_ldr::{TabrmdConfig, TctiNameConf},
Context,
};
/// Test fixture that runs a private software TPM (`swtpm`) fronted by a
/// `tpm2-abrmd` resource manager on the D-Bus session bus.
///
/// Dropping the harness kills both processes; `dir` (a `TempDir`) is then dropped,
/// removing the temporary directory.
pub struct SwtpmHarness {
    // Field order is also drop order; keep `dir` (and its cleanup) as declared.
    /// Temporary directory holding the swtpm setup config, TPM state, log and pid file.
    dir: TempDir,
    /// Path of `tpm_cred_store.state` inside `dir`. Nothing in this file explicitly
    /// creates it. NOTE(review): presumably callers store credentials there — confirm.
    state_path: PathBuf,
    /// File the swtpm daemon writes its pid into (`--pid file=...`); read on drop.
    pid_path: PathBuf,
    /// The spawned `tpm2-abrmd` child process.
    tabrmd: Child,
    /// TCTI configuration string (`bus_name=...,bus_type=session`) for reaching `tabrmd`.
    tabrmd_config: String,
}
impl SwtpmHarness {
const HOST: &'static str = "127.0.0.1";
fn dbus_name(port: u16) -> String {
let port_str: String = port
.to_string()
.chars()
.map(|e| ((e as u8) + 17) as char)
.collect();
format!("com.intel.tss2.Tabrmd.{port_str}")
}
pub fn new() -> anyhow::Result<SwtpmHarness> {
static PORT: AtomicU16 = AtomicU16::new(21901);
let port = PORT.fetch_add(2, Ordering::SeqCst);
let ctrl_port = port + 1;
let dir = TempDir::new(format!("swtpm_harness.{port}").as_str())?;
let dir_path = dir.path();
let dir_path_display = dir_path.display();
let conf_path = dir_path.join("swtpm_setup.conf");
let state_path = dir_path.join("tpm_cred_store.state");
let pid_path = dir_path.join("swtpm.pid");
let dbus_name = Self::dbus_name(port);
let addr = Self::HOST;
std::fs::write(
&conf_path,
r#"# Program invoked for creating certificates
#create_certs_tool= /usr/bin/swtpm_localca
# Comma-separated list (no spaces) of PCR banks to activate by default
active_pcr_banks = sha256
"#,
)?;
Command::new("swtpm_setup")
.stdout(Stdio::null())
.args([
"--tpm2",
"--config",
conf_path.to_str().unwrap(),
"--tpm-state",
format!("{dir_path_display}").as_str(),
])
.status()?
.success_or_err()?;
Command::new("swtpm")
.args([
"socket",
"--daemon",
"--tpm2",
"--server",
format!("type=tcp,port={port},bindaddr={addr}").as_str(),
"--ctrl",
format!("type=tcp,port={ctrl_port},bindaddr={addr}").as_str(),
"--log",
format!("file={dir_path_display}/log.txt,level=5").as_str(),
"--flags",
"not-need-init,startup-clear",
"--tpmstate",
format!("dir={dir_path_display}").as_str(),
"--pid",
format!("file={}", pid_path.display()).as_str(),
])
.status()?
.success_or_err()
.map_err(|err| {
anyhow!("swtpm {err}. This usually indicates an instance of swtpm is still running. You can rectify this with `killall swtpm`.")
})?;
let mut blocker = DbusBlocker::new_session(dbus_name.clone())?;
let tabrmd = Command::new("tpm2-abrmd")
.args([
format!("--tcti=swtpm:host=127.0.0.1,port={port}").as_str(),
"--dbus-name",
dbus_name.as_str(),
"--session",
])
.spawn()?;
blocker.block(Duration::from_secs(5))?;
Ok(SwtpmHarness {
dir,
state_path,
pid_path,
tabrmd,
tabrmd_config: format!("bus_name={},bus_type=session", Self::dbus_name(port)),
})
}
pub fn tabrmd_config(&self) -> &str {
&self.tabrmd_config
}
pub fn context(&self) -> io::Result<Context> {
let config = TabrmdConfig::from_str(self.tabrmd_config()).box_err()?;
Context::new(TctiNameConf::Tabrmd(config)).box_err()
}
pub fn dir_path(&self) -> &std::path::Path {
self.dir.path()
}
pub fn state_path(&self) -> &std::path::Path {
&self.state_path
}
}
impl Drop for SwtpmHarness {
fn drop(&mut self) {
if let Err(err) = self.tabrmd.kill() {
error!("failed to kill tpm2-abrmd: {err}");
}
let pid_str = std::fs::read_to_string(&self.pid_path).unwrap();
let pid_int = pid_str.parse::<i32>().unwrap();
let pid = Pid::from_raw(pid_int);
signal::kill(pid, Signal::SIGKILL).unwrap();
}
}
/// Converts a process [`ExitStatus`] into an `anyhow::Result`.
trait ExitStatusExt {
    /// Returns `Ok(())` for exit code 0.
    ///
    /// Otherwise returns an error naming the non-zero code, or noting that there
    /// was no code at all (on Unix, the case when the process was ended by a signal).
    fn success_or_err(&self) -> anyhow::Result<()>;
}
impl ExitStatusExt for ExitStatus {
    fn success_or_err(&self) -> anyhow::Result<()> {
        let exit_code = self
            .code()
            .ok_or_else(|| anyhow!("ExitCode was None"))?;
        if exit_code == 0 {
            Ok(())
        } else {
            Err(anyhow!("ExitCode was non-zero: {exit_code}"))
        }
    }
}
/// Arguments of the `org.freedesktop.DBus.NameOwnerChanged` signal.
///
/// The fields are the bus name whose ownership changed, its previous owner and
/// its new owner, declared in wire order.
struct NameOwnerChanged {
    name: String,
    old_owner: String,
    new_owner: String,
}
impl dbus::arg::AppendAll for NameOwnerChanged {
    /// Serializes the three string fields in signal-argument order.
    fn append(&self, iter: &mut dbus::arg::IterAppend) {
        for field in [&self.name, &self.old_owner, &self.new_owner] {
            dbus::arg::RefArg::append(field, iter);
        }
    }
}
impl dbus::arg::ReadAll for NameOwnerChanged {
    /// Deserializes the three string arguments in signal-argument order.
    fn read(iter: &mut dbus::arg::Iter) -> std::result::Result<Self, dbus::arg::TypeMismatchError> {
        let name = iter.read()?;
        let old_owner = iter.read()?;
        let new_owner = iter.read()?;
        Ok(Self {
            name,
            old_owner,
            new_owner,
        })
    }
}
impl dbus::message::SignalArgs for NameOwnerChanged {
    // Identifies the signal as `org.freedesktop.DBus.NameOwnerChanged`.
    const NAME: &'static str = "NameOwnerChanged";
    const INTERFACE: &'static str = "org.freedesktop.DBus";
}
struct DbusBlocker {
receiver: Receiver<()>,
conn: dbus::blocking::Connection,
}
impl DbusBlocker {
fn new_session(name: String) -> io::Result<DbusBlocker> {
use dbus::{blocking::Connection, Message};
const DEST: &str = "org.freedesktop.DBus";
let (sender, receiver) = channel();
let conn = Connection::new_session().box_err()?;
let proxy = conn.with_proxy(DEST, "/org/freedesktop/DBus", Duration::from_secs(1));
let _ = proxy.match_signal(move |h: NameOwnerChanged, _: &Connection, _: &Message| {
let name_appeared = h.name == name;
if name_appeared {
if let Err(err) = sender.send(()) {
error!("failed to send unblocking signal: {err}");
}
}
#[allow(clippy::let_and_return)]
let remove_match = !name_appeared;
remove_match
});
Ok(DbusBlocker { receiver, conn })
}
fn block(&mut self, timeout: Duration) -> io::Result<()> {
let time_limit = SystemTime::now() + timeout;
loop {
self.conn.process(Duration::from_millis(100)).box_err()?;
match self.receiver.try_recv() {
Ok(_) => break,
Err(err) => match err {
TryRecvError::Empty => (),
_ => return Err(io::Error::custom(err)),
},
}
if SystemTime::now() > time_limit {
return Err(io::Error::new(
io::ErrorKind::TimedOut,
"timed out waiting for DBUS message",
));
}
}
Ok(())
}
}
/// Adds a constructor for `io::Error`s of kind `Other` that wrap an arbitrary error.
trait IoErrorExt {
    /// Wraps `err` in an [`io::Error`] whose kind is [`io::ErrorKind::Other`].
    fn custom<E>(err: E) -> io::Error
    where
        E: Into<Box<dyn Error + Send + Sync>>,
    {
        let kind = io::ErrorKind::Other;
        io::Error::new(kind, err)
    }
}
impl IoErrorExt for io::Error {}
/// Converts a `Result` with a boxable error type into an `io::Result`.
trait ResultExt<T, E> {
    /// Wraps the `Err` value as an `Other`-kind `io::Error`; `Ok` passes through.
    fn box_err(self) -> Result<T, io::Error>;
}
impl<T, E> ResultExt<T, E> for Result<T, E>
where
    E: Into<Box<dyn Error + Send + Sync>>,
{
    fn box_err(self) -> Result<T, io::Error> {
        match self {
            Ok(value) => Ok(value),
            Err(err) => Err(<io::Error as IoErrorExt>::custom(err)),
        }
    }
}