[b/307966230] Implement force_download option for model_download over HTTP #44
tests/test_cache.py
Outdated
@@ -26,6 +41,7 @@

class TestCache(BaseTestCase):
    def test_load_from_cache_miss(self):
        # Why is this ModelHandle needed?
Was curious about this! Will remove this comment.
Looks like I left this by mistake. Probably when I moved it over to a "constant" at the top 🤦 . You can remove.
tests/test_cache.py
Outdated
            os.makedirs(cache_path)
            Path(os.path.join(cache_path, TEST_FILEPATH)).touch()  # Create file

            self.assertEqual(None, load_from_cache(TEST_MODEL_HANDLE, path=TEST_FILEPATH))

    def test_load_from_cache_with_complete_marker_no_files_miss(self):
        with create_test_cache():
            # Why is this line needed?
Was curious about this too! Will remove this comment.
You can remove it, and the same goes for line #88 below. Another 🤦. Thanks for cleaning this up :)
I refactored some tests so that the new force_download tests wouldn't feel too verbose. Please let me know if you disagree with any refactoring, and I'm happy to revert!
Good work! Just a few minor suggestions/comments.
tests/test_cache.py
Outdated
    def test_delete_from_cache(self):
        with create_test_cache() as d:
            cache_path = get_cached_path(TEST_MODEL_HANDLE)
            os.makedirs(cache_path)
Can you add a few files with Path(..).touch(), including files in subdirectories, and then assert below that all the files have indeed been deleted?
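To make the request concrete, here is a minimal sketch of that test shape using only the standard library. It assumes nothing about the real `delete_from_cache`: `delete_tree` below is a hypothetical stand-in that removes a directory recursively, and the file names are invented for illustration.

```python
import os
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory


def delete_tree(root: str) -> str:
    """Hypothetical stand-in for a cache delete: remove `root` recursively."""
    shutil.rmtree(root)
    return root


def populate_and_delete() -> list:
    """Create nested files, delete the tree, and return whatever survived."""
    with TemporaryDirectory() as d:
        cache_path = os.path.join(d, "model")
        os.makedirs(os.path.join(cache_path, "sub", "nested"))

        # Files at the top level and in subdirectories (names are illustrative).
        files = [
            Path(cache_path, "a.txt"),
            Path(cache_path, "sub", "b.txt"),
            Path(cache_path, "sub", "nested", "c.txt"),
        ]
        for f in files:
            f.touch()  # Create an empty file
        assert all(f.exists() for f in files)

        delete_tree(cache_path)

        # Collect anything that still exists after the delete.
        survivors = [str(f) for f in files if f.exists()]
        if os.path.exists(cache_path):
            survivors.append(cache_path)
        return survivors
```

In the actual test, the same pattern would likely sit inside the existing `with create_test_cache()` block, with the survivor check expressed as `self.assertFalse(os.path.exists(...))` for each file.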
tests/test_cache.py
Outdated

            deleted_path = delete_from_cache(TEST_MODEL_HANDLE, path=TEST_FILEPATH)

            self.assertEqual(
Can you assert that the file has been deleted?
    def test_versioned_model_download_force_download_raises(self):
        with create_test_jwt_http_server(KaggleJwtHandler):
            with self.assertRaises(ValueError):
                kagglehub.model_download(VERSIONED_MODEL_HANDLE, force_download="hiksjdhf")
Why use "hiksjdhf" instead of True?
Maybe add a test case checking that it still works when force_download=False.
Oops, "hiksjdhf" was a debugging typo, thanks for catching! Added the test case!
tests/test_http_model_download.py
Outdated
Thanks for adding exhaustive tests & refactoring the assertions!
And also, looks like |
Thanks for the quick review! Are the CI checks identical except for the Python version? I want to make sure I've narrowed down what could be happening.
@@ -132,3 +130,85 @@ def test_model_archive_path(self):
            ),
            archive_path,
        )

    def _download_test_model_to_cache(self):
        cache_path = get_cached_path(TEST_MODEL_HANDLE)
You are missing the with create_test_cache(): block (see the other tests above).
This test util function creates a temporary directory to read/write cache files and sets the proper environment variable to ensure the cache uses that temporary directory:
Lines 25 to 28 in c1b6e00
def create_test_cache():
    with TemporaryDirectory() as d:
        with mock.patch.dict(os.environ, {CACHE_FOLDER_ENV_VAR_NAME: d}):
            yield d
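For anyone following along, here is a self-contained sketch of how a context manager like this scopes the cache location. It assumes the quoted function is wrapped with `contextlib.contextmanager` (the decorator falls outside the quoted lines), and the environment variable name, `cache_root`, and `helper_that_touches_cache` are all hypothetical placeholders rather than the library's own names.

```python
import os
from contextlib import contextmanager
from tempfile import TemporaryDirectory
from unittest import mock

# Assumption: placeholder value; the real constant is defined by the library.
CACHE_FOLDER_ENV_VAR_NAME = "EXAMPLE_CACHE_FOLDER"


@contextmanager
def create_test_cache():
    """Point the cache env var at a temporary directory for the block's duration."""
    with TemporaryDirectory() as d:
        with mock.patch.dict(os.environ, {CACHE_FOLDER_ENV_VAR_NAME: d}):
            yield d


def cache_root() -> str:
    """Hypothetical lookup: read the cache folder from the environment at call time."""
    return os.environ[CACHE_FOLDER_ENV_VAR_NAME]


def helper_that_touches_cache() -> str:
    """Hypothetical helper with no `with` block of its own."""
    path = os.path.join(cache_root(), "marker")
    open(path, "w").close()  # Create an empty file in the current cache root
    return path
```

Because the environment is patched for everything that runs inside the `with` block, a helper called from within it writes to the temporary directory, and both the directory and the patched variable are gone once the block exits.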
Thanks for the pointer!! I think I misunderstood along the way. While testing I thought I might not need that because I call this helper from within a with create_test_cache(): block
My bad. Disregard my comment. I didn't realize this was a helper function and not a test.
If this is always called within the create_test_cache block, then you are good.
    def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str:
        if force_download:
            msg = "Invalid input: Cannot force download in a Kaggle notebook"
            raise ValueError(msg)
Actually, we should simply log a warning here.
The reason is that we want the same code to run regardless of whether users are in a Kaggle notebook or outside.
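A rough sketch of what the warning-based variant could look like is below. This is not the resolver from the PR: `resolve_in_notebook`, the logger name, the warning text, and the returned path layout are all hypothetical, chosen only to show the flag being ignored with a warning rather than rejected.

```python
import logging
from typing import Optional

# Assumption: illustrative logger name, not the library's own logger.
logger = logging.getLogger("example_resolver")


def resolve_in_notebook(handle: str, path: Optional[str] = None, *, force_download: bool = False) -> str:
    """Hypothetical notebook resolver: warn on force_download instead of raising."""
    if force_download:
        # Same call keeps working inside and outside a notebook; the flag is just ignored here.
        logger.warning("Ignoring force_download: cannot force download in a Kaggle notebook")
    base = f"/hypothetical/input/{handle}"  # Placeholder location, not a real mount point
    return f"{base}/{path}" if path else base
```

With this shape, a script that passes `force_download=True` runs unchanged in both environments; in the notebook case the caller gets the existing path plus a log line explaining that the flag had no effect.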
Previously, after a model or file was downloaded with model_download, there was no way to programmatically force a new download. This PR enables forced downloads.
More Details
force_download defaults to False
New flag usage: