[b/307966230] Implement force_download option for model_download over HTTP #44

Merged
merged 15 commits into from
Dec 15, 2023

Conversation

Contributor

@lucyhe lucyhe commented Dec 14, 2023

Previously, after a model or file was downloaded with model_download, there was no way to programmatically force a new download. This PR enables forced downloads.

More Details

  • This feature can only be used outside of Kaggle notebooks. Kaggle notebooks rely on a different caching mechanism, for which force_download is not yet supported.
  • force_download defaults to False.
  • If a single file is force-downloaded, the other files in the same folder are not re-downloaded.

New flag usage:

kagglehub.model_download(model_handle, force_download=True)
kagglehub.model_download(model_handle, path=path, force_download=True)
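
The semantics listed above can be sketched with a small local stand-in. Nothing here is kagglehub's actual implementation: `REMOTE`, `CACHE`, `fetch`, and `download` are hypothetical names, and the "remote" is just a dictionary. The sketch only illustrates the stated behaviour: a cached file is reused unless force_download is set, and forcing a single file leaves its siblings untouched.

```python
from typing import Dict, List, Optional

# Hypothetical stand-ins: a fake remote store and an in-memory cache.
REMOTE: Dict[str, Dict[str, str]] = {
    "owner/model/framework/variation/1": {"config.json": "{}", "weights.bin": "0101"},
}
CACHE: Dict[str, Dict[str, str]] = {}
FETCH_LOG: List[str] = []  # records every simulated network fetch


def fetch(handle: str, name: str) -> str:
    """Simulate an HTTP download of one file."""
    FETCH_LOG.append(name)
    return REMOTE[handle][name]


def download(handle: str, path: Optional[str] = None, *, force_download: bool = False) -> List[str]:
    """Fetch missing (or forced) files and return the names now held in the cache."""
    cached = CACHE.setdefault(handle, {})
    wanted = [path] if path is not None else list(REMOTE[handle])
    for name in wanted:
        if force_download or name not in cached:
            cached[name] = fetch(handle, name)
    return sorted(cached)


handle = "owner/model/framework/variation/1"
download(handle)                                            # cold cache: fetches both files
download(handle)                                            # warm cache: fetches nothing
download(handle, path="config.json", force_download=True)   # refetches that one file only
print(FETCH_LOG)  # ['config.json', 'weights.bin', 'config.json']
```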

@@ -26,6 +41,7 @@

class TestCache(BaseTestCase):
    def test_load_from_cache_miss(self):
        # Why is this ModelHandle needed?
Contributor Author

Was curious about this! Will remove this comment.

Contributor

Looks like I left this by mistake, probably when I moved it over to a "constant" at the top 🤦. You can remove it.

            os.makedirs(cache_path)
            Path(os.path.join(cache_path, TEST_FILEPATH)).touch()  # Create file

            self.assertEqual(None, load_from_cache(TEST_MODEL_HANDLE, path=TEST_FILEPATH))

    def test_load_from_cache_with_complete_marker_no_files_miss(self):
        with create_test_cache():
            # Why is this line needed?
Contributor Author

Was curious about this too! Will remove this comment.

Contributor

You can remove it, and the same goes for line #88 below. Another 🤦. Thanks for cleaning this up :)

Contributor Author

lucyhe commented Dec 14, 2023

I refactored some tests so that the new force_download tests wouldn't feel too verbose. Please let me know if you disagree with any of the refactoring, and I'm happy to revert!

@lucyhe lucyhe requested a review from rosbo December 14, 2023 01:33
Contributor

@rosbo rosbo left a comment

Good work! Just a few minor suggestions/comments.

    def test_delete_from_cache(self):
        with create_test_cache() as d:
            cache_path = get_cached_path(TEST_MODEL_HANDLE)
            os.makedirs(cache_path)
Contributor

Can you add a few files with Path(..).touch(), including files in subdirectories, and then assert below that all of the files have indeed been deleted?
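
A test along the lines requested might look like the following sketch. `delete_tree` is a hypothetical stand-in for delete_from_cache (whose implementation is not shown in this thread), and the file names are made up. The point is only the pattern: touch files at several depths, delete, then assert that none survive.

```python
import os
import shutil
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory


def delete_tree(root: str) -> str:
    """Hypothetical stand-in for a cache-deletion helper: removes root recursively."""
    shutil.rmtree(root)
    return root


class TestDeleteTree(unittest.TestCase):
    def test_deletes_nested_files(self) -> None:
        with TemporaryDirectory() as d:
            root = os.path.join(d, "model")
            nested = os.path.join(root, "sub", "deeper")
            os.makedirs(nested)
            files = [
                Path(root, "config.json"),
                Path(root, "sub", "weights.bin"),
                Path(nested, "vocab.txt"),
            ]
            for f in files:
                f.touch()  # create empty files at several depths

            deleted = delete_tree(root)

            self.assertEqual(root, deleted)
            self.assertFalse(os.path.exists(root))
            for f in files:
                self.assertFalse(f.exists())


def run_tests() -> unittest.TestResult:
    """Run the test case above and return the result object."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestDeleteTree)
    return unittest.TextTestRunner(verbosity=0).run(suite)
```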


deleted_path = delete_from_cache(TEST_MODEL_HANDLE, path=TEST_FILEPATH)

self.assertEqual(
Contributor

Can you assert that the file has been deleted?

    def test_versioned_model_download_force_download_raises(self):
        with create_test_jwt_http_server(KaggleJwtHandler):
            with self.assertRaises(ValueError):
                kagglehub.model_download(VERSIONED_MODEL_HANDLE, force_download="hiksjdhf")
Contributor

Why use "hiksjdhf" instead of True?

Maybe add a test case checking that it still works when force_download=False.

Contributor Author

Oops, "hiksjdhf" was a debugging typo, thanks for catching! Added the test case!

Contributor

Thanks for adding exhaustive tests & refactoring the assertions!

Contributor

rosbo commented Dec 14, 2023

And also, looks like shutil.rmtree is failing for some versions of Python based on whether files are there or not. See the tests run from CI.
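
For context, shutil.rmtree raises FileNotFoundError when its target does not exist, and typically raises another OSError when handed a regular file instead of a directory, so a deletion helper that may receive either has to branch. Whether this is what caused the CI failures is not stated in the thread; the following is only a defensive sketch, and `remove_path` is a hypothetical name.

```python
import os
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional


def remove_path(path: str) -> Optional[str]:
    """Delete a file or directory tree; return the path if something was removed, else None."""
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)  # real directory: remove recursively
        return path
    if os.path.lexists(path):
        os.remove(path)  # regular file or symlink
        return path
    return None  # nothing to delete


with TemporaryDirectory() as d:
    tree = os.path.join(d, "model")
    os.makedirs(os.path.join(tree, "sub"))
    Path(tree, "sub", "weights.bin").touch()
    single = os.path.join(d, "config.json")
    Path(single).touch()

    results = [
        remove_path(tree) == tree,      # directory tree
        remove_path(single) == single,  # single file
        remove_path(tree) is None,      # already gone, no exception
    ]
    leftovers = os.listdir(d)
```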

Contributor Author

lucyhe commented Dec 14, 2023

> And also, looks like shutil.rmtree is failing for some versions of Python based on whether files are there or not. See the tests run from CI.

Thanks for the quick review! Are the CI checks identical except for the Python version? I want to make sure I've narrowed down what could be happening.

@@ -132,3 +130,85 @@ def test_model_archive_path(self):
),
archive_path,
)

    def _download_test_model_to_cache(self):
        cache_path = get_cached_path(TEST_MODEL_HANDLE)
Contributor

@rosbo rosbo Dec 14, 2023

You are missing the with create_test_cache(): block (see the other tests above).

This test utility function creates a temporary directory for reading and writing cache files and sets the environment variable that makes the cache use that temporary directory:

kagglehub/tests/utils.py

Lines 25 to 28 in c1b6e00

def create_test_cache():
    with TemporaryDirectory() as d:
        with mock.patch.dict(os.environ, {CACHE_FOLDER_ENV_VAR_NAME: d}):
            yield d

Contributor Author

Thanks for the pointer!! I think I misunderstood along the way. While testing, I thought I might not need that because I call this helper from within a with create_test_cache(): block.

Contributor

My bad. Disregard my comment. I didn't realize this was a helper function and not a test.

If this is always called within the create_test_cache block, then you are good.
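
The conclusion here, that a helper needs no cache block of its own when every caller already holds one, follows from how mock.patch.dict behaves: the patched environment stays in effect for everything executed inside the with block, including nested calls. A minimal sketch, in which the variable name `DEMO_CACHE_FOLDER` and the functions `create_demo_cache` and `current_cache_dir` are placeholders rather than kagglehub's real names:

```python
import os
from contextlib import contextmanager
from tempfile import TemporaryDirectory
from typing import Iterator, Optional
from unittest import mock

CACHE_ENV_VAR = "DEMO_CACHE_FOLDER"  # placeholder, not kagglehub's variable name


@contextmanager
def create_demo_cache() -> Iterator[str]:
    """Point CACHE_ENV_VAR at a fresh temporary directory for the duration of the block."""
    with TemporaryDirectory() as d:
        with mock.patch.dict(os.environ, {CACHE_ENV_VAR: d}):
            yield d


def current_cache_dir() -> Optional[str]:
    """Helper that reads the cache location; it has no with-block of its own."""
    return os.environ.get(CACHE_ENV_VAR)


before = current_cache_dir()
with create_demo_cache() as tmp_dir:
    inside_matches = current_cache_dir() == tmp_dir  # helper sees the patched value
after = current_cache_dir()  # original environment is restored on exit
```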

    def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str:
        if force_download:
            msg = "Invalid input: Cannot force download in a Kaggle notebook"
            raise ValueError(msg)
Contributor

Actually, we should simply log a warning here.

The reason is that we want the same code to run regardless of whether users are in a Kaggle notebook or outside.
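
One way that suggestion could look is to warn and continue with the normal resolution path instead of raising. The class name `NotebookResolverSketch`, the logger name, and the returned path are all invented for illustration; only the `__call__` signature mirrors the snippet under review, with a plain str standing in for ModelHandle.

```python
import logging
from typing import Optional

logger = logging.getLogger("force_download_sketch")  # hypothetical logger name


class NotebookResolverSketch:
    """Illustrative stand-in for a resolver that cannot honor force_download."""

    def __call__(self, h: str, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str:
        if force_download:
            # Warn rather than raise so the same user code runs in and out of notebooks.
            logger.warning("Ignoring force_download: not supported in this environment.")
        base = f"/hypothetical/mount/{h}"  # made-up location, not a real Kaggle path
        return f"{base}/{path}" if path else base


resolver = NotebookResolverSketch()
resolved = resolver("owner/model", path="config.json", force_download=True)
```

With this shape, a script that passes force_download=True produces a warning in the unsupported environment but still returns a usable path.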

@lucyhe lucyhe merged commit 8345bb3 into main Dec 15, 2023
6 checks passed
@lucyhe lucyhe deleted the lh/cache branch December 15, 2023 01:16
@lucyhe lucyhe changed the title Implement force_download option for model_download over HTTP [b/307966230] Implement force_download option for model_download over HTTP Dec 15, 2023
2 participants