Skip to content

Commit

Permalink
feat: Implement image validation in BuiltinPackageLoader
Browse files Browse the repository at this point in the history
  • Loading branch information
theduke committed Jul 30, 2024
1 parent 803c0be commit 7ddf96c
Showing 1 changed file with 223 additions and 1 deletion.
224 changes: 223 additions & 1 deletion lib/wasix/src/runtime/package_loader/builtin_loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,20 @@ pub struct BuiltinPackageLoader {
cache: Option<FileSystemCache>,
/// A mapping from hostnames to tokens
tokens: HashMap<String, String>,

hash_validation: HashIntegrityValidationMode,
}

/// Selects whether a package image is checked against its expected hash and
/// what happens when the two disagree.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HashIntegrityValidationMode {
    /// Skip validation entirely.
    ///
    /// No hash is computed, which makes this the cheapest option.
    NoValidate,
    /// Hash the image and log a trace warning if it differs from the
    /// expected hash.
    WarnOnHashMismatch,
    /// Hash the image and return an error if it differs from the expected
    /// hash.
    FailOnHashMismatch,
}

impl BuiltinPackageLoader {
Expand All @@ -42,10 +56,19 @@ impl BuiltinPackageLoader {
in_memory: InMemoryCache::default(),
client: Arc::new(crate::http::default_http_client().unwrap()),
cache: None,
hash_validation: HashIntegrityValidationMode::NoValidate,
tokens: HashMap::new(),
}
}

/// Choose how a freshly obtained image is checked against its expected hash.
///
/// Consumes the loader and returns it with the new mode applied.
/// See [`HashIntegrityValidationMode`] for the available modes.
pub fn with_hash_validation_mode(self, mode: HashIntegrityValidationMode) -> Self {
    let mut loader = self;
    loader.hash_validation = mode;
    loader
}

pub fn with_cache_dir(self, cache_dir: impl Into<PathBuf>) -> Self {
BuiltinPackageLoader {
cache: Some(FileSystemCache {
Expand All @@ -55,6 +78,44 @@ impl BuiltinPackageLoader {
}
}

pub fn validate_cache(
&self,
mode: CacheValidationMode,
) -> Result<Vec<ImageHashMismatchError>, anyhow::Error> {
let cache = self
.cache
.as_ref()
.context("can not validate cache - no cache configured")?;

let items = cache.validate_hashes()?;
let mut errors = Vec::new();
for (path, error) in items {
match mode {
CacheValidationMode::WarnOnMismatch => {
tracing::warn!(?error, "hash mismatch in cached image file");
}
CacheValidationMode::PruneOnMismatch => {
tracing::warn!(?error, "deleting cached image file due to hash mismatch");
match std::fs::remove_file(&path) {
Ok(()) => {}
Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
Err(fs_err) => {
tracing::error!(
path=%error.source,
?fs_err,
"could not delete cached image file with hash mismatch"
);
}
}
}
}

errors.push(error);
}

Ok(errors)
}

/// Use `client` for HTTP requests.
///
/// The client is wrapped in an [`Arc`] and handed to
/// `with_shared_http_client`.
pub fn with_http_client(self, client: impl HttpClient + Send + Sync + 'static) -> Self {
    let shared_client = Arc::new(client);
    self.with_shared_http_client(shared_client)
}
Expand Down Expand Up @@ -110,6 +171,40 @@ impl BuiltinPackageLoader {
Ok(None)
}

/// Check `image` against the hash recorded in `info`, as dictated by `mode`.
///
/// * `NoValidate` - returns `Ok(())` without hashing anything.
/// * `WarnOnHashMismatch` - logs a warning on a mismatch and returns `Ok(())`.
/// * `FailOnHashMismatch` - returns an [`ImageHashMismatchError`] on a mismatch.
fn validate_hash(
    image: &[u8],
    mode: HashIntegrityValidationMode,
    info: &DistributionInfo,
) -> Result<(), anyhow::Error> {
    // Resolve the mode first so the hash is only computed when it is needed.
    let fail_on_mismatch = match mode {
        HashIntegrityValidationMode::NoValidate => return Ok(()),
        HashIntegrityValidationMode::WarnOnHashMismatch => false,
        HashIntegrityValidationMode::FailOnHashMismatch => true,
    };

    let actual_hash = WebcHash::sha256(image);
    if actual_hash == info.webc_sha256 {
        return Ok(());
    }

    if fail_on_mismatch {
        let mismatch = ImageHashMismatchError {
            source: info.webc.to_string(),
            actual_hash,
            expected_hash: info.webc_sha256,
        };
        return Err(mismatch.into());
    }

    tracing::warn!(%info.webc_sha256, %actual_hash, "image hash mismatch - actual image hash does not match the expected hash!");
    Ok(())
}

#[tracing::instrument(level = "debug", skip_all, fields(%dist.webc, %dist.webc_sha256))]
async fn download(&self, dist: &DistributionInfo) -> Result<Bytes, Error> {
if dist.webc.scheme() == "file" {
Expand All @@ -121,6 +216,9 @@ impl BuiltinPackageLoader {
})
.await?
.with_context(|| format!("Unable to read \"{}\"", path.display()))?;

Self::validate_hash(&bytes, self.hash_validation, dist)?;

return Ok(bytes.into());
}
Err(e) => {
Expand Down Expand Up @@ -167,6 +265,8 @@ impl BuiltinPackageLoader {
let body = response.body.context("package download failed")?;
tracing::debug!(%url, "package_download_succeeded");

Self::validate_hash(&body, self.hash_validation, dist)?;

Ok(body.into())
}

Expand Down Expand Up @@ -267,6 +367,35 @@ impl PackageLoader for BuiltinPackageLoader {
}
}

/// Error describing an image whose computed hash differs from the hash it
/// was expected to have.
#[derive(Debug, Clone)]
pub struct ImageHashMismatchError {
    /// Where the image came from, rendered as a string (a URL or a file path).
    source: String,
    /// The hash the image was expected to have.
    expected_hash: WebcHash,
    /// The hash computed from the image contents.
    actual_hash: WebcHash,
}

/// Human-readable rendering: expected hash, computed hash, then the source.
impl std::fmt::Display for ImageHashMismatchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            source,
            expected_hash,
            actual_hash,
        } = self;

        write!(
            f,
            "image hash mismatch! expected hash '{}', but the computed hash is '{}' (source '{}')",
            expected_hash, actual_hash, source,
        )
    }
}

// Marker impl with no overridden methods; it makes the type usable as a
// standard error (e.g. convertible into `anyhow::Error` via `.into()`).
impl std::error::Error for ImageHashMismatchError {}

/// Selects what happens to a cached image whose file name does not match
/// the hash of its contents.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheValidationMode {
    /// Only log a warning for each image whose file name doesn't match
    /// the hash of its contents.
    WarnOnMismatch,
    /// Delete each image whose file name doesn't match the hash of its
    /// contents.
    PruneOnMismatch,
}

// FIXME: This implementation will block the async runtime and should use
// some sort of spawn_blocking() call to run it in the background.
#[derive(Debug)]
Expand All @@ -275,6 +404,66 @@ struct FileSystemCache {
}

impl FileSystemCache {
const FILE_SUFFIX: &'static str = ".bin";

/// Validate that the cached image file names correspond to their actual
/// file content hashes.
///
/// Only regular files named exactly `<hex hash>` followed by
/// [`Self::FILE_SUFFIX`] are inspected. Anything else in the cache
/// directory (sub-directories, names that are not valid UTF-8, names with
/// trailing text after the suffix, names whose stem is not a parsable hash)
/// is skipped.
///
/// Returns one `(path, error)` pair per file whose content hash differs
/// from the hash in its name. A missing cache directory yields an empty
/// list.
///
/// # Errors
///
/// Fails if the cache directory or one of its entries cannot be read, or
/// if hashing a candidate file fails.
fn validate_hashes(&self) -> Result<Vec<(PathBuf, ImageHashMismatchError)>, anyhow::Error> {
    let iter = match std::fs::read_dir(&self.cache_dir) {
        Ok(v) => v,
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
            // Cache dir does not exist, so nothing to validate.
            return Ok(Vec::new());
        }
        Err(err) => {
            return Err(err).with_context(|| {
                format!(
                    "Could not read image cache dir: '{}'",
                    self.cache_dir.display()
                )
            });
        }
    };

    let mut items = Vec::new();

    for res in iter {
        let entry = res?;
        if !entry.file_type()?.is_file() {
            continue;
        }

        // Extract the hash from the filename.
        //
        // The suffix must be the very end of the name: `split_once` would
        // also accept names such as `<hash>.bin.tmp`, which are not cache
        // entries and must never be reported (or pruned) as such.
        let file_name = entry.file_name();
        let Some(expected_hash) = file_name
            .to_str()
            .and_then(|name| name.strip_suffix(Self::FILE_SUFFIX))
            .and_then(|raw_hash| WebcHash::parse_hex(raw_hash).ok())
        else {
            continue;
        };

        // Compute the actual hash.
        let path = entry.path();
        let actual_hash = WebcHash::for_file(&path)
            .with_context(|| format!("Could not hash cached image '{}'", path.display()))?;

        if actual_hash != expected_hash {
            let err = ImageHashMismatchError {
                source: path.to_string_lossy().into_owned(),
                actual_hash,
                expected_hash,
            };
            items.push((path, err));
        }
    }

    Ok(items)
}

async fn lookup(&self, hash: &WebcHash) -> Result<Option<Container>, Error> {
let path = self.path(hash);

Expand Down Expand Up @@ -357,7 +546,7 @@ impl FileSystemCache {
for b in hash {
write!(filename, "{b:02x}").unwrap();
}
filename.push_str(".bin");
filename.push_str(Self::FILE_SUFFIX);

self.cache_dir.join(filename)
}
Expand Down Expand Up @@ -484,3 +673,36 @@ mod tests {
cache_misses_will_trigger_a_download_internal().await
}
}

#[cfg(test)]
mod test {
    use super::*;

    /// A cached file whose name carries the wrong hash must be reported by
    /// `validate_cache`, kept in warn mode, and deleted in prune mode.
    // NOTE: must be a tokio test because the BuiltinPackageLoader::new()
    // constructor requires a runtime...
    #[tokio::test]
    async fn test_builtin_package_downloader_cache_validation() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path();

        // Store the contents under a file name whose hash does NOT match.
        let contents = "fail";
        let correct_hash = WebcHash::sha256(contents);
        let used_hash =
            WebcHash::parse_hex("0000a28ea38a000f3a3328cb7fabe330638d3258affe1a869e3f92986222d997")
                .unwrap();
        let filename = format!("{}{}", used_hash, FileSystemCache::FILE_SUFFIX);
        let file_path = path.join(filename);
        std::fs::write(&file_path, contents).unwrap();

        let dl = BuiltinPackageLoader::new().with_cache_dir(path);

        // Warn mode reports the mismatch but must leave the file in place.
        let errors = dl
            .validate_cache(CacheValidationMode::WarnOnMismatch)
            .unwrap();
        assert_eq!(errors.len(), 1);
        assert!(
            file_path.exists(),
            "warn mode must not delete the mismatching file"
        );

        // Prune mode reports the same mismatch and removes the file.
        let errors = dl
            .validate_cache(CacheValidationMode::PruneOnMismatch)
            .unwrap();
        assert_eq!(errors.len(), 1);
        assert_eq!(errors[0].actual_hash, correct_hash);
        assert_eq!(errors[0].expected_hash, used_hash);

        assert!(
            !file_path.exists(),
            "prune mode must delete the mismatching file"
        );
    }
}

0 comments on commit 7ddf96c

Please sign in to comment.