Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions runpod/serverless/utils/rp_model_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Utility function for transforming HuggingFace repositories into model-cache paths"""

import typing
from runpod.serverless.modules.rp_logger import RunPodLogger

# Module-level logger; used below to warn about repository strings that cannot be resolved.
log = RunPodLogger()


def resolve_model_cache_path_from_hugginface_repository(
    huggingface_repository: str,
    /,
    path_template: str = "/runpod/cache/{model}/{revision}",  # TODO: Should we just hardcode this?
) -> typing.Union[str, None]:
    """
    Resolve the model-cache path for a HuggingFace model from its repository string.

    The repository string is split on its LAST colon: everything before it is
    the model name, everything after it is the revision. A missing or empty
    revision falls back to "main".

    Args:
        huggingface_repository (str): Repository string in the format
            "model_name:revision" or "org/model_name:revision". For example:
                - "runwayml/stable-diffusion-v1-5:experimental"
                - "runwayml/stable-diffusion-v1-5" (uses "main" revision)
                - "runwayml/stable-diffusion-v1-5:" (uses "main" revision)
                - "stable-diffusion-v1-5:main"
        path_template (str, optional): Template string for the cache path. It is
            filled in with ``str.format`` using the ``{model}`` and ``{revision}``
            placeholders. Defaults to "/runpod/cache/{model}/{revision}".

    Returns:
        str | None: The path produced by filling in ``path_template``, or None
        (after logging a warning) if no model name could be extracted.

    Examples:
        >>> resolve_model_cache_path_from_hugginface_repository("runwayml/stable-diffusion-v1-5:experimental")
        '/runpod/cache/runwayml/stable-diffusion-v1-5/experimental'
        >>> resolve_model_cache_path_from_hugginface_repository("runwayml/stable-diffusion-v1-5")
        '/runpod/cache/runwayml/stable-diffusion-v1-5/main'
        >>> resolve_model_cache_path_from_hugginface_repository("runwayml/stable-diffusion-v1-5:")
        '/runpod/cache/runwayml/stable-diffusion-v1-5/main'

        A repository string without a model name, such as ":experimental",
        yields None.
    """
    # Split on the last colon only, so colons inside the model name are preserved.
    model, *revision_parts = huggingface_repository.rsplit(":", 1)
    if not model:
        # We could throw an exception here but returning None allows us to filter a list of
        # repositories without needing a try/except block
        log.warn(  # type: ignore in strict mode the typechecker complains about this method being partially unknown
            f'Unable to resolve the model-cache path for "{huggingface_repository}"'
        )
        return None

    # A trailing colon ("model:") produces an empty revision; treat it the same
    # as no revision at all rather than emitting a path with an empty segment.
    revision = revision_parts[0] if revision_parts and revision_parts[0] else "main"
    return path_template.format(model=model, revision=revision)
53 changes: 53 additions & 0 deletions tests/test_serverless/test_utils/test_model_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import unittest

from runpod.serverless.utils.rp_model_cache import (
resolve_model_cache_path_from_hugginface_repository,
)


class TestModelCache(unittest.TestCase):
    """Unit tests for the rp_model_cache path resolver."""

    def test_with_revision(self):
        """An explicit revision becomes the final path segment."""
        repository = "runwayml/stable-diffusion-v1-5:experimental"
        expected = "/runpod/cache/runwayml/stable-diffusion-v1-5/experimental"
        resolved = resolve_model_cache_path_from_hugginface_repository(repository)
        self.assertEqual(resolved, expected)

    def test_without_revision(self):
        """Omitting the revision resolves to the "main" revision."""
        repository = "runwayml/stable-diffusion-v1-5"
        expected = "/runpod/cache/runwayml/stable-diffusion-v1-5/main"
        resolved = resolve_model_cache_path_from_hugginface_repository(repository)
        self.assertEqual(resolved, expected)

    def test_with_multiple_colons(self):
        """Only the last colon separates the model name from the revision."""
        repository = "runwayml/stable-diffusion:v1-5:experimental"
        expected = "/runpod/cache/runwayml/stable-diffusion:v1-5/experimental"
        resolved = resolve_model_cache_path_from_hugginface_repository(repository)
        self.assertEqual(resolved, expected)

    def test_with_custom_path_template(self):
        """A caller-supplied template replaces the default cache layout."""
        repository = "runwayml/stable-diffusion-v1-5:experimental"
        template = "/my-custom-model-cache/{model}/{revision}"
        expected = "/my-custom-model-cache/runwayml/stable-diffusion-v1-5/experimental"
        resolved = resolve_model_cache_path_from_hugginface_repository(
            repository, template
        )
        self.assertEqual(resolved, expected)

    def test_with_missing_model_name(self):
        """A repository string with no model name resolves to None."""
        resolved = resolve_model_cache_path_from_hugginface_repository(":experimental")
        self.assertIsNone(resolved)


# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Loading