Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support arbitrary X in data #100

Merged
merged 21 commits into from
Apr 24, 2024
Merged

Conversation

wiseodd
Copy link
Collaborator

@wiseodd wiseodd commented Apr 2, 2024

Closes #86

@f-dangel @runame please give feedback and answer the question below.

Assumption:

data: Union[Iterable[Tuple[Tensor, Tensor]], Iterable[Tuple[UserDict, Tensor]], Iterable[Tuple[dict, Tensor]]]

and there is an additional parameter in _base._LinearOperator:

batch_size_fn: Optional[Callable[[Any], int]] = None

where it must be non-None whenever X is not a torch.Tensor.

This also fits well with Huggingface, although one must replace HF's default dataloader to outputs (data, data['labels']) instead of just data. Let me know if this should be considered.

Code for testing the functionality:

https://gist.github.com/wiseodd/426061afae24199446e60bfabc00e26e
I use laplace-torch there (two birds one stone), so just remove it if you don't want to install it. If you want to test laplace-torch, install via

pip install git+https://github.com/aleximmer/Laplace.git@mc-subset2

@wiseodd
Copy link
Collaborator Author

wiseodd commented Apr 2, 2024

Just to link this to aleximmer/Laplace#144

wiseodd added a commit to aleximmer/Laplace that referenced this pull request Apr 2, 2024
@coveralls
Copy link

coveralls commented Apr 5, 2024

Pull Request Test Coverage Report for Build 8791845160

Warning: This coverage report may be inaccurate.

This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

Details

  • 71 of 82 (86.59%) changed or added relevant lines in 8 files are covered.
  • 26 unchanged lines in 3 files lost coverage.
  • Overall coverage increased (+0.2%) to 88.59%

Changes Missing Coverage Covered Lines Changed/Added Lines %
curvlinops/fisher.py 1 2 50.0%
curvlinops/examples/functorch.py 17 21 80.95%
curvlinops/jacobian.py 36 42 85.71%
Files with Coverage Reduction New Missed Lines %
curvlinops/fisher.py 1 28.57%
curvlinops/examples/functorch.py 1 88.3%
curvlinops/kfac.py 24 92.08%
Totals Coverage Status
Change from base Build 8757107974: 0.2%
Covered Lines: 1219
Relevant Lines: 1376

💛 - Coveralls

@wiseodd wiseodd marked this pull request as ready for review April 6, 2024 17:26
@wiseodd
Copy link
Collaborator Author

wiseodd commented Apr 6, 2024

@f-dangel @runame I added (i) an example on the usage with HuggingFace transformers in the docs, and (ii) unit tests.

Please review.

@wiseodd
Copy link
Collaborator Author

wiseodd commented Apr 7, 2024

Here's an example use case.

https://curvlinops--100.org.readthedocs.build/en/100/basic_usage/example_huggingface.html

The load for the users is minimal and the choice of UserDict is compatible with HF. Note that the users can do whatever they want in terms of their X's since the only requirements are they are dict/UserDict and they handle their "preprocessing" like .to(device) inside the forward function of the model.

Copy link
Owner

@f-dangel f-dangel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did a first pass and added some minor suggestions, mainly missing documentation.
One main suggestion is to avoid having to duplicate the test functions, simply by appending the test cases using a dict to CASES_NO_DEVICE and instead make the cases fixture return the batch_size_fn.

curvlinops/_base.py Outdated Show resolved Hide resolved
curvlinops/_base.py Outdated Show resolved Hide resolved
curvlinops/_base.py Outdated Show resolved Hide resolved
curvlinops/_base.py Outdated Show resolved Hide resolved
curvlinops/_base.py Outdated Show resolved Hide resolved
docs/examples/basic_usage/example_huggingface.py Outdated Show resolved Hide resolved
docs/examples/basic_usage/example_huggingface.py Outdated Show resolved Hide resolved
docs/examples/basic_usage/example_huggingface.py Outdated Show resolved Hide resolved
test/cases.py Outdated Show resolved Hide resolved
test/cases.py Outdated Show resolved Hide resolved
@runame runame added the enhancement New feature or request label Apr 10, 2024
@runame runame self-requested a review April 10, 2024 11:46
Copy link
Collaborator

@runame runame left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While for many settings the model/data loader still has to be adjusted, this seems like a nice usability improvement!

See my comments and formatting and linting have to be fixed as well.

curvlinops/_base.py Outdated Show resolved Hide resolved
curvlinops/fisher.py Outdated Show resolved Hide resolved
curvlinops/fisher.py Outdated Show resolved Hide resolved
curvlinops/kfac.py Outdated Show resolved Hide resolved
curvlinops/kfac.py Outdated Show resolved Hide resolved
test/test_fisher.py Outdated Show resolved Hide resolved
test/test_ggn.py Outdated Show resolved Hide resolved
test/test_gradient_moments.py Outdated Show resolved Hide resolved
test/test_hessian.py Outdated Show resolved Hide resolved
test/test_kfac.py Outdated Show resolved Hide resolved
@runame
Copy link
Collaborator

runame commented Apr 10, 2024

One more thing, the type hints for the data also have to be modified in all files that don't require other changes, e.g. ggn.py.

@wiseodd
Copy link
Collaborator Author

wiseodd commented Apr 18, 2024

I'm finally done.

  1. examples/functorch.py is now aware of dict-like inputs
  2. Update all the test cases to accommodate batch_size_fn

@f-dangel @runame ready for another check.

wiseodd added 3 commits April 18, 2024 18:05
                 _                 _
                (_) ___  ___  _ __| |_
                | |/ _/ / _ \/ '__  _/
                | |\__ \/\_\/| |  | |_
                |_|\___/\___/\_/   \_/

      isort your imports, so you don't have to.

                    VERSION 5.13.2

Nothing to do: no files or paths have have been passed in!

Try one of the following:

    `isort .` - sort all Python files, starting from the current directory, recursively.
    `isort . --interactive` - Do the same, but ask before making any changes.
    `isort . --check --diff` - Check to see if imports are correctly sorted within this project.
    `isort --help` - In-depth information about isort's available command-line options.

Visit https://pycqa.github.io/isort/ for complete information about how to use isort.
Copy link
Owner

@f-dangel f-dangel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly nits. Please make sure you try running make test on a compute infrastructure with access to a GPU to make sure there are no device-related issues as GH actions can only check on CPU (I'll do it anyways before merging).

curvlinops/_base.py Outdated Show resolved Hide resolved
curvlinops/examples/functorch.py Outdated Show resolved Hide resolved
curvlinops/fisher.py Outdated Show resolved Hide resolved
curvlinops/fisher.py Outdated Show resolved Hide resolved
curvlinops/jacobian.py Outdated Show resolved Hide resolved
test/test_jacobian.py Outdated Show resolved Hide resolved
test/test_jacobian.py Outdated Show resolved Hide resolved
test/test_jacobian.py Outdated Show resolved Hide resolved
test/test_submatrix_on_curvatures.py Outdated Show resolved Hide resolved
test/test_submatrix_on_curvatures.py Show resolved Hide resolved
@wiseodd
Copy link
Collaborator Author

wiseodd commented Apr 19, 2024

@f-dangel I'm done. Currently running the test on a GPU-enabled env. It takes so long but you can continue reviewing. PEP8 and other linters' issues should also be resolved.

@wiseodd
Copy link
Collaborator Author

wiseodd commented Apr 19, 2024

Alright, confirmed that all tests passed on GPU!

@f-dangel
Copy link
Owner

@runame I'm done with my second pass, could you take a quick second look and merge if everything looks good?

Copy link
Collaborator

@runame runame left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@f-dangel f-dangel merged commit 30c77b4 into f-dangel:main Apr 24, 2024
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Make less/no assumption about the data
4 participants