-
Notifications
You must be signed in to change notification settings - Fork 957
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[b/331681978] Add Short URL matching to KaggleFileResolver #1375
Changes from all commits
8700ad0
4f92929
5707223
2ad5299
2663eec
cf78793
fce7103
360163c
1caa6b0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,18 +4,19 @@ | |
|
||
from tensorflow_hub import resolver | ||
|
||
url_pattern = re.compile(r"https?://([a-z]+\.)?kaggle.com/models/(?P<owner>[^\\/]+)/(?P<model>[^\\/]+)/frameworks/(?P<framework>[^\\/]+)/variations/(?P<variation>[^\\/]+)/versions/(?P<version>[0-9]+)$") | ||
short_url_pattern = re.compile(r"https?://([a-z]+\.)?kaggle.com/models/(?P<owner>[^\\/]+)/(?P<model>[^\\/]+)/(?P<framework>[^\\/]+)/(?P<variation>[^\\/]+)/(?P<version>[0-9]+)$") | ||
long_url_pattern = re.compile(r"https?://([a-z]+\.)?kaggle.com/models/(?P<owner>[^\\/]+)/(?P<model>[^\\/]+)/frameworks/(?P<framework>[^\\/]+)/variations/(?P<variation>[^\\/]+)/versions/(?P<version>[0-9]+)$") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here's a unified version of the two regex's if we wanted to still only need one:
https://regex101.com/r/JbWJNI/1 All I did was making the "extra" parts of the path optional, though I suppose this does make each part independently optional. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would you mind if we keep |
||
|
||
def _is_on_kaggle_notebook(): | ||
return os.getenv("KAGGLE_KERNEL_RUN_TYPE") != None and os.getenv("KAGGLE_USER_SECRETS_TOKEN") != None | ||
|
||
def _is_kaggle_handle(handle): | ||
return url_pattern.match(handle) != None | ||
return long_url_pattern.match(handle) != None or short_url_pattern.match(handle) != None | ||
|
||
class KaggleFileResolver(resolver.HttpResolverBase): | ||
def is_supported(self, handle): | ||
return _is_on_kaggle_notebook() and _is_kaggle_handle(handle) | ||
|
||
def __call__(self, handle): | ||
m = url_pattern.match(handle) | ||
return kagglehub.model_download(f"{m.group('owner')}/{m.group('model')}/{m.group('framework').lower()}/{m.group('variation')}/{m.group('version')}") | ||
m = long_url_pattern.match(handle) or short_url_pattern.match(handle) | ||
return kagglehub.model_download(f"{m.group('owner')}/{m.group('model')}/{m.group('framework').lower()}/{m.group('variation')}/{m.group('version')}") |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -79,11 +79,25 @@ def do_POST(self): | |
self.wfile.write(bytes(f"Unhandled path: {self.path}", "utf-8")) | ||
|
||
class TestKaggleModuleResolver(unittest.TestCase): | ||
def test_kaggle_resolver_succeeds(self): | ||
def test_kaggle_resolver_long_url_succeeds(self): | ||
model_url = "https://kaggle.com/models/foo/foomodule/frameworks/TensorFlow2/variations/barvar/versions/2" | ||
with create_test_server(KaggleJwtHandler) as addr: | ||
test_inputs = tf.ones([1,4]) | ||
layer = hub.KerasLayer("https://kaggle.com/models/foo/foomodule/frameworks/TensorFlow2/variations/barvar/versions/2") | ||
layer = hub.KerasLayer(model_url) | ||
self.assertEqual([1, 1], layer(test_inputs).shape) | ||
# Delete the files that were created in KaggleJwtHandler's do_POST method | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was needed or else |
||
os.unlink(os.path.join(MOUNT_PATH, "foomodule/tensorflow2/barvar/2")) | ||
os.rmdir(os.path.dirname(os.path.join(MOUNT_PATH, "foomodule/tensorflow2/barvar/2"))) | ||
|
||
def test_kaggle_resolver_short_url_succeeds(self): | ||
model_url = "https://kaggle.com/models/foo/foomodule/TensorFlow2/barvar/2" | ||
with create_test_server(KaggleJwtHandler) as addr: | ||
test_inputs = tf.ones([1,4]) | ||
layer = hub.KerasLayer(model_url) | ||
self.assertEqual([1, 1], layer(test_inputs).shape) | ||
# Delete the files that were created in KaggleJwtHandler's do_POST method | ||
os.unlink(os.path.join(MOUNT_PATH, "foomodule/tensorflow2/barvar/2")) | ||
os.rmdir(os.path.dirname(os.path.join(MOUNT_PATH, "foomodule/tensorflow2/barvar/2"))) | ||
|
||
def test_kaggle_resolver_not_attached_throws(self): | ||
with create_test_server(KaggleJwtHandler) as addr: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was considering if we should load these via an ENV var that we can mutate, but I suppose it is really rare that we mutate this, and it might be fine if only new images can support changes like this going forward.