diff --git a/equimo/io.py b/equimo/io.py index 08e3ae6..8d29bee 100644 --- a/equimo/io.py +++ b/equimo/io.py @@ -102,6 +102,7 @@ def download(identifier: str, repository: str) -> Path: url = f"{repository}/{identifier}.tar.lz4" path = Path(f"~/.cache/equimo/{identifier}.tar.lz4").expanduser() + path.parent.mkdir(parents=True, exist_ok=True) if path.exists(): logger.info("Archive already downloaded, using cached file.") diff --git a/tests/test_models.py b/tests/test_models.py index 16d662f..c8171a5 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -120,3 +120,16 @@ def test_save_load_model_uncompressed(): loaded_output = loaded_model.features(x, key=key) assert jnp.allclose(original_output, loaded_output, atol=1e-5) + + +def test_load_pretrained_model(): + """Test loading a pretrained model from the repository.""" + key = jr.PRNGKey(42) + model = load_model(cls="vit", identifier="dinov2_vits14_reg") + + # Test inference + x = jr.normal(key, (3, 224, 224)) + features = model.features(x, key=key) + + assert features.shape[-1] == 384 # DINOv2-S has embedding dimension of 384 + assert jnp.all(jnp.isfinite(features)) # Check for NaN/Inf values diff --git a/uv.lock b/uv.lock index 151efe8..c3f1b19 100644 --- a/uv.lock +++ b/uv.lock @@ -84,7 +84,7 @@ wheels = [ [[package]] name = "equimo" -version = "0.2.0" +version = "0.2.1" source = { virtual = "." } dependencies = [ { name = "einops" },