diff --git a/patches/kaggle_module_resolver.py b/patches/kaggle_module_resolver.py index 5e0c5d8a..430cb980 100644 --- a/patches/kaggle_module_resolver.py +++ b/patches/kaggle_module_resolver.py @@ -4,18 +4,19 @@ from tensorflow_hub import resolver -url_pattern = re.compile(r"https?://([a-z]+\.)?kaggle.com/models/(?P[^\\/]+)/(?P[^\\/]+)/frameworks/(?P[^\\/]+)/variations/(?P[^\\/]+)/versions/(?P[0-9]+)$") +short_url_pattern = re.compile(r"https?://([a-z]+\.)?kaggle.com/models/(?P[^\\/]+)/(?P[^\\/]+)/(?P[^\\/]+)/(?P[^\\/]+)/(?P[0-9]+)$") +long_url_pattern = re.compile(r"https?://([a-z]+\.)?kaggle.com/models/(?P[^\\/]+)/(?P[^\\/]+)/frameworks/(?P[^\\/]+)/variations/(?P[^\\/]+)/versions/(?P[0-9]+)$") def _is_on_kaggle_notebook(): return os.getenv("KAGGLE_KERNEL_RUN_TYPE") != None and os.getenv("KAGGLE_USER_SECRETS_TOKEN") != None def _is_kaggle_handle(handle): - return url_pattern.match(handle) != None + return long_url_pattern.match(handle) != None or short_url_pattern.match(handle) != None class KaggleFileResolver(resolver.HttpResolverBase): def is_supported(self, handle): return _is_on_kaggle_notebook() and _is_kaggle_handle(handle) def __call__(self, handle): - m = url_pattern.match(handle) - return kagglehub.model_download(f"{m.group('owner')}/{m.group('model')}/{m.group('framework').lower()}/{m.group('variation')}/{m.group('version')}") \ No newline at end of file + m = long_url_pattern.match(handle) or short_url_pattern.match(handle) + return kagglehub.model_download(f"{m.group('owner')}/{m.group('model')}/{m.group('framework').lower()}/{m.group('variation')}/{m.group('version')}") diff --git a/tests/test_kaggle_module_resolver.py b/tests/test_kaggle_module_resolver.py index 9c88c563..bb1f243e 100644 --- a/tests/test_kaggle_module_resolver.py +++ b/tests/test_kaggle_module_resolver.py @@ -79,11 +79,25 @@ def do_POST(self): self.wfile.write(bytes(f"Unhandled path: {self.path}", "utf-8")) class TestKaggleModuleResolver(unittest.TestCase): - def test_kaggle_resolver_succeeds(self): + def test_kaggle_resolver_long_url_succeeds(self): + model_url = "https://kaggle.com/models/foo/foomodule/frameworks/TensorFlow2/variations/barvar/versions/2" with create_test_server(KaggleJwtHandler) as addr: test_inputs = tf.ones([1,4]) - layer = hub.KerasLayer("https://kaggle.com/models/foo/foomodule/frameworks/TensorFlow2/variations/barvar/versions/2") + layer = hub.KerasLayer(model_url) self.assertEqual([1, 1], layer(test_inputs).shape) + # Delete the files that were created in KaggleJwtHandler's do_POST method + os.unlink(os.path.join(MOUNT_PATH, "foomodule/tensorflow2/barvar/2")) + os.rmdir(os.path.dirname(os.path.join(MOUNT_PATH, "foomodule/tensorflow2/barvar/2"))) + + def test_kaggle_resolver_short_url_succeeds(self): + model_url = "https://kaggle.com/models/foo/foomodule/TensorFlow2/barvar/2" + with create_test_server(KaggleJwtHandler) as addr: + test_inputs = tf.ones([1,4]) + layer = hub.KerasLayer(model_url) + self.assertEqual([1, 1], layer(test_inputs).shape) + # Delete the files that were created in KaggleJwtHandler's do_POST method + os.unlink(os.path.join(MOUNT_PATH, "foomodule/tensorflow2/barvar/2")) + os.rmdir(os.path.dirname(os.path.join(MOUNT_PATH, "foomodule/tensorflow2/barvar/2"))) def test_kaggle_resolver_not_attached_throws(self): with create_test_server(KaggleJwtHandler) as addr: