Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[b/331681978] Add Short URL matching to KaggleFileResolver #1375

Merged
merged 9 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions patches/kaggle_module_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@

from tensorflow_hub import resolver

url_pattern = re.compile(r"https?://([a-z]+\.)?kaggle.com/models/(?P<owner>[^\\/]+)/(?P<model>[^\\/]+)/frameworks/(?P<framework>[^\\/]+)/variations/(?P<variation>[^\\/]+)/versions/(?P<version>[0-9]+)$")
short_url_pattern = re.compile(r"https?://([a-z]+\.)?kaggle.com/models/(?P<owner>[^\\/]+)/(?P<model>[^\\/]+)/(?P<framework>[^\\/]+)/(?P<variation>[^\\/]+)/(?P<version>[0-9]+)$")
long_url_pattern = re.compile(r"https?://([a-z]+\.)?kaggle.com/models/(?P<owner>[^\\/]+)/(?P<model>[^\\/]+)/frameworks/(?P<framework>[^\\/]+)/variations/(?P<variation>[^\\/]+)/versions/(?P<version>[0-9]+)$")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was considering if we should load these via an ENV var that we can mutate, but I suppose it is really rare that we mutate this, and it might be fine if only new images can support changes like this going forward.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's a unified version of the two regex's if we wanted to still only need one:

https?://([a-z]+\.)?kaggle.com/models/(?P<owner>[^\\/]+)/(?P<model>[^\\/]+)(/frameworks)?/(?P<framework>[^\\/]+)(/variations)?/(?P<variation>[^\\/]+)(/versions)?/(?P<version>[0-9]+)$

https://regex101.com/r/JbWJNI/1

All I did was making the "extra" parts of the path optional, though I suppose this does make each part independently optional.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you mind if we keep short_url_pattern and long_url_pattern? I lean towards that because it feels more readable and avoids the "independently optional parts" issue you mention


def _is_on_kaggle_notebook():
return os.getenv("KAGGLE_KERNEL_RUN_TYPE") != None and os.getenv("KAGGLE_USER_SECRETS_TOKEN") != None

def _is_kaggle_handle(handle):
return url_pattern.match(handle) != None
return long_url_pattern.match(handle) != None or short_url_pattern.match(handle) != None

class KaggleFileResolver(resolver.HttpResolverBase):
def is_supported(self, handle):
return _is_on_kaggle_notebook() and _is_kaggle_handle(handle)

def __call__(self, handle):
m = url_pattern.match(handle)
return kagglehub.model_download(f"{m.group('owner')}/{m.group('model')}/{m.group('framework').lower()}/{m.group('variation')}/{m.group('version')}")
m = long_url_pattern.match(handle) or short_url_pattern.match(handle)
return kagglehub.model_download(f"{m.group('owner')}/{m.group('model')}/{m.group('framework').lower()}/{m.group('variation')}/{m.group('version')}")
18 changes: 16 additions & 2 deletions tests/test_kaggle_module_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,25 @@ def do_POST(self):
self.wfile.write(bytes(f"Unhandled path: {self.path}", "utf-8"))

class TestKaggleModuleResolver(unittest.TestCase):
def test_kaggle_resolver_succeeds(self):
def test_kaggle_resolver_long_url_succeeds(self):
model_url = "https://kaggle.com/models/foo/foomodule/frameworks/TensorFlow2/variations/barvar/versions/2"
with create_test_server(KaggleJwtHandler) as addr:
test_inputs = tf.ones([1,4])
layer = hub.KerasLayer("https://kaggle.com/models/foo/foomodule/frameworks/TensorFlow2/variations/barvar/versions/2")
layer = hub.KerasLayer(model_url)
self.assertEqual([1, 1], layer(test_inputs).shape)
# Delete the files that were created in KaggleJwtHandler's do_POST method
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was needed or else os.makedirs and os.symlink would fail when the new test tried to recreate the same files. (Alternative solution was to use a different model URL for the new test, but I thought it was cleaner to do the proper cleanup after each test.)

os.unlink(os.path.join(MOUNT_PATH, "foomodule/tensorflow2/barvar/2"))
os.rmdir(os.path.dirname(os.path.join(MOUNT_PATH, "foomodule/tensorflow2/barvar/2")))

def test_kaggle_resolver_short_url_succeeds(self):
model_url = "https://kaggle.com/models/foo/foomodule/TensorFlow2/barvar/2"
with create_test_server(KaggleJwtHandler) as addr:
test_inputs = tf.ones([1,4])
layer = hub.KerasLayer(model_url)
self.assertEqual([1, 1], layer(test_inputs).shape)
# Delete the files that were created in KaggleJwtHandler's do_POST method
os.unlink(os.path.join(MOUNT_PATH, "foomodule/tensorflow2/barvar/2"))
os.rmdir(os.path.dirname(os.path.join(MOUNT_PATH, "foomodule/tensorflow2/barvar/2")))

def test_kaggle_resolver_not_attached_throws(self):
with create_test_server(KaggleJwtHandler) as addr:
Expand Down