Skip to content

Commit

Permalink
Merge pull request #1375 from Kaggle/lh/short
Browse files Browse the repository at this point in the history
[b/331681978] Add Short URL matching to KaggleFileResolver
  • Loading branch information
lucyhe authored Apr 2, 2024
2 parents 052d182 + 1caa6b0 commit 5afea0a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
9 changes: 5 additions & 4 deletions patches/kaggle_module_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@

from tensorflow_hub import resolver

url_pattern = re.compile(r"https?://([a-z]+\.)?kaggle.com/models/(?P<owner>[^\\/]+)/(?P<model>[^\\/]+)/frameworks/(?P<framework>[^\\/]+)/variations/(?P<variation>[^\\/]+)/versions/(?P<version>[0-9]+)$")
short_url_pattern = re.compile(r"https?://([a-z]+\.)?kaggle.com/models/(?P<owner>[^\\/]+)/(?P<model>[^\\/]+)/(?P<framework>[^\\/]+)/(?P<variation>[^\\/]+)/(?P<version>[0-9]+)$")
long_url_pattern = re.compile(r"https?://([a-z]+\.)?kaggle.com/models/(?P<owner>[^\\/]+)/(?P<model>[^\\/]+)/frameworks/(?P<framework>[^\\/]+)/variations/(?P<variation>[^\\/]+)/versions/(?P<version>[0-9]+)$")

def _is_on_kaggle_notebook():
return os.getenv("KAGGLE_KERNEL_RUN_TYPE") != None and os.getenv("KAGGLE_USER_SECRETS_TOKEN") != None

def _is_kaggle_handle(handle):
return url_pattern.match(handle) != None
return long_url_pattern.match(handle) != None or short_url_pattern.match(handle) != None

class KaggleFileResolver(resolver.HttpResolverBase):
def is_supported(self, handle):
return _is_on_kaggle_notebook() and _is_kaggle_handle(handle)

def __call__(self, handle):
m = url_pattern.match(handle)
return kagglehub.model_download(f"{m.group('owner')}/{m.group('model')}/{m.group('framework').lower()}/{m.group('variation')}/{m.group('version')}")
m = long_url_pattern.match(handle) or short_url_pattern.match(handle)
return kagglehub.model_download(f"{m.group('owner')}/{m.group('model')}/{m.group('framework').lower()}/{m.group('variation')}/{m.group('version')}")
18 changes: 16 additions & 2 deletions tests/test_kaggle_module_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,25 @@ def do_POST(self):
self.wfile.write(bytes(f"Unhandled path: {self.path}", "utf-8"))

class TestKaggleModuleResolver(unittest.TestCase):
def test_kaggle_resolver_succeeds(self):
def test_kaggle_resolver_long_url_succeeds(self):
model_url = "https://kaggle.com/models/foo/foomodule/frameworks/TensorFlow2/variations/barvar/versions/2"
with create_test_server(KaggleJwtHandler) as addr:
test_inputs = tf.ones([1,4])
layer = hub.KerasLayer("https://kaggle.com/models/foo/foomodule/frameworks/TensorFlow2/variations/barvar/versions/2")
layer = hub.KerasLayer(model_url)
self.assertEqual([1, 1], layer(test_inputs).shape)
# Delete the files that were created in KaggleJwtHandler's do_POST method
os.unlink(os.path.join(MOUNT_PATH, "foomodule/tensorflow2/barvar/2"))
os.rmdir(os.path.dirname(os.path.join(MOUNT_PATH, "foomodule/tensorflow2/barvar/2")))

def test_kaggle_resolver_short_url_succeeds(self):
model_url = "https://kaggle.com/models/foo/foomodule/TensorFlow2/barvar/2"
with create_test_server(KaggleJwtHandler) as addr:
test_inputs = tf.ones([1,4])
layer = hub.KerasLayer(model_url)
self.assertEqual([1, 1], layer(test_inputs).shape)
# Delete the files that were created in KaggleJwtHandler's do_POST method
os.unlink(os.path.join(MOUNT_PATH, "foomodule/tensorflow2/barvar/2"))
os.rmdir(os.path.dirname(os.path.join(MOUNT_PATH, "foomodule/tensorflow2/barvar/2")))

def test_kaggle_resolver_not_attached_throws(self):
with create_test_server(KaggleJwtHandler) as addr:
Expand Down

0 comments on commit 5afea0a

Please sign in to comment.