diff --git a/runpod/serverless/utils/rp_model_cache.py b/runpod/serverless/utils/rp_model_cache.py new file mode 100644 index 00000000..b16694ca --- /dev/null +++ b/runpod/serverless/utils/rp_model_cache.py @@ -0,0 +1,47 @@ +"""Utility function for transforming HuggingFace repositories into model-cache paths""" + +import typing +from runpod.serverless.modules.rp_logger import RunPodLogger + +log = RunPodLogger() + + +def resolve_model_cache_path_from_hugginface_repository( + huggingface_repository: str, + /, + path_template: str = "/runpod/cache/{model}/{revision}", # TODO: Should we just hardcode this? +) -> typing.Union[str, None]: + """ + Resolves the model-cache path for a HuggingFace model based on its repository string. + + Args: + huggingface_repository (str): Repository string in format "model_name:revision" or + "org/model_name:revision". If no revision is specified, + "main" is used. For example: + - "runwayml/stable-diffusion-v1-5:experimental" + - "runwayml/stable-diffusion-v1-5" (uses "main" revision) + - "stable-diffusion-v1-5:main" + path_template (str, optional): Template string for the cache path. Must contain {model} + and {revision} placeholders. Defaults to "/runpod/cache/{model}/{revision}". + + Returns: + str | None: Absolute path where the model is cached, following the template provided in path_template. Returns None if no model name could be extracted. + + Examples: + >>> resolve_model_cache_path_from_hugginface_repository("runwayml/stable-diffusion-v1-5:experimental") + "/runpod/cache/runwayml/stable-diffusion-v1-5/experimental" + >>> resolve_model_cache_path_from_hugginface_repository("runwayml/stable-diffusion-v1-5") + "/runpod/cache/runwayml/stable-diffusion-v1-5/main" + >>> resolve_model_cache_path_from_hugginface_repository(":experimental") + None + """ + model, *revision = huggingface_repository.rsplit(":", 1) + if not model: + # We could throw an exception here but returning None allows us to filter a list of repositories without needing a try/except block + log.warn( # type: ignore in strict mode the typechecker complains about this method being partially unknown + f'Unable to resolve the model-cache path for "{huggingface_repository}"' + ) + return None + return path_template.format( + model=model, revision=revision[0] if revision else "main" + ) diff --git a/tests/test_serverless/test_utils/test_model_cache.py b/tests/test_serverless/test_utils/test_model_cache.py new file mode 100644 index 00000000..28cea0ab --- /dev/null +++ b/tests/test_serverless/test_utils/test_model_cache.py @@ -0,0 +1,53 @@ +import unittest + +from runpod.serverless.utils.rp_model_cache import ( + resolve_model_cache_path_from_hugginface_repository, +) + + +class TestModelCache(unittest.TestCase): + """Tests for rp_model_cache""" + + def test_with_revision(self): + """Test with a revision""" + path = resolve_model_cache_path_from_hugginface_repository( + "runwayml/stable-diffusion-v1-5:experimental" + ) + self.assertEqual( + path, "/runpod/cache/runwayml/stable-diffusion-v1-5/experimental" + ) + + def test_without_revision(self): + """Test without a revision""" + path = resolve_model_cache_path_from_hugginface_repository( + "runwayml/stable-diffusion-v1-5" + ) + self.assertEqual(path, "/runpod/cache/runwayml/stable-diffusion-v1-5/main") + + def test_with_multiple_colons(self): + """Test with multiple colons""" + path = resolve_model_cache_path_from_hugginface_repository( + "runwayml/stable-diffusion:v1-5:experimental" + ) + self.assertEqual( + path, "/runpod/cache/runwayml/stable-diffusion:v1-5/experimental" + ) + + def test_with_custom_path_template(self): + """Test with a custom path template""" + path = resolve_model_cache_path_from_hugginface_repository( + "runwayml/stable-diffusion-v1-5:experimental", + "/my-custom-model-cache/{model}/{revision}", + ) + self.assertEqual( + path, "/my-custom-model-cache/runwayml/stable-diffusion-v1-5/experimental" + ) + + def test_with_missing_model_name(self): + """Test with a missing model name""" + path = resolve_model_cache_path_from_hugginface_repository(":experimental") + self.assertIsNone(path) + + +if __name__ == "__main__": + unittest.main()