Large refactor and complete testing #38

Merged: 76 commits from experimental into master on Mar 25, 2022

Conversation

gpauloski (Owner) commented Mar 25, 2022

Mass K-FAC refactor and repository changes

DevOps changes

  • kfac requires torch>=1.8 and Python >=3.7
  • tox used for testing environments and automation
  • pre-commit updated. Major changes include a preference for single quotes, mypy, and flake8 plugins
  • Switch to setup.cfg for package metadata and tox/flake8/mypy/coverage configuration
  • Add requirement-dev.txt that contains all dependencies needed to run the test suite

Code quality and testing

  • Complete type annotations for all code
    • Passes mypy
  • Separated testing utilities and unit tests into testing/ and tests/ respectively
  • Expansive unit testing suite that achieves 100% code coverage
  • New testing utilities include wrappers for simulating distributed environments and small test models
  • Added end-to-end training tests
    • small unit test (included in pytest) that checks that the loss decreases when training with K-FAC (a minimal sketch follows this list)
    • MNIST integration test (not run with pytest) that verifies training with K-FAC achieves higher accuracy
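
As a rough illustration of the loss-decrease check mentioned above, the sketch below trains a tiny model for a few steps and asserts the loss goes down. The model, data, and the exact `KFACPreconditioner` import path and constructor arguments are simplified assumptions, not the test suite's actual code.

```python
import torch

# Assumed import path; the actual location in the package may differ.
from kfac.preconditioner import KFACPreconditioner


def test_loss_decreases_with_kfac() -> None:
    """Minimal sketch: loss should decrease when training with K-FAC."""
    torch.manual_seed(42)
    model = torch.nn.Sequential(
        torch.nn.Linear(10, 20),
        torch.nn.ReLU(),
        torch.nn.Linear(20, 1),
    )
    x = torch.randn(128, 10)
    y = torch.randn(128, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    # Hypothetical default arguments; the real constructor may require more.
    preconditioner = KFACPreconditioner(model)
    criterion = torch.nn.MSELoss()

    losses = []
    for _ in range(20):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        preconditioner.step()  # precondition gradients before the optimizer step
        optimizer.step()
        losses.append(loss.item())

    assert losses[-1] < losses[0]
```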

kfac package improvements

  • KFAC layers separated from PyTorch module wrappers
    • KFACBaseLayer handles general K-FAC computations and communications for an arbitrary layer
    • ModuleHelper implementations provide a unified interface for interacting with supported PyTorch modules
      • Provides methods that return the sizes of the layer's factors so those sizes can be determined prior to training (see the ModuleHelper sketch after this list)
      • Provides methods for getting the current gradients, updating the gradients, and computing the factors from the intermediate data
    • Each KFACBaseLayer instance is passed a ModuleHelper instance corresponding to the module in the model being preconditioned
  • Removed broken LSTM/RNN/Embedding layer support
  • Module registration utilities moved out of the preconditioner class and into the kfac.layers.register module
  • Replaced the comm module with the distributed module, which provides a more exhaustive set of distributed communication utilities
    • All communication ops now return futures for their results to allow more aggressive asynchronous communication
    • Added allreduce bucketing for factor allreduce (closes Add Tensor Communication Bucketing #32)
    • Added get_rank and get_world_size methods to enable K-FAC training when torch.distributed is not initialized
  • Enum types moved to enums module for convenience with type annotations
  • KFACBaseLayer is now agnostic of its placement
    • I.e., the KFACBaseLayer expects some other object to correctly execute its operations according to some placement strategy.
    • This change was made to allow other preconditioner implementations to use the math/communication operations provided by the KFACBaseLayer without being beholden to a particular placement strategy.
  • Created the BaseKFACPreconditioner which provides the minimal set of functionality for preconditioning with K-FAC
    • Provides state dict saving/loading, a step() method, hook registration to KFACBaseLayer, and some small bookkeeping functionality
    • The BaseKFACPreconditioner takes as input already registered KFACBaseLayers and an initialized WorkAssignment object.
    • This change was made to factor out the strategy-specific details from the core preconditioning functions, with the goal of having preconditioner implementations that interact more closely with other frameworks such as DeepSpeed
    • Added reset_batch() to clear the staged factors for the batch in the case of a bad batch of data (e.g., if the gradients overflowed)
    • memory_usage() includes the intermediate factors accumulated for the current batch
    • state_dict now includes K-FAC hyperparameters and steps in addition to factors
  • Added KFACPreconditioner, a subclass of BaseKFACPreconditioner, that implements the full functionality described in the KAISA paper.
  • New WorkAssignment interface that provides a schematic for the methods needed by BaseKFACPreconditioner to determine where to perform computations and communications
    • Added the KAISAAssignment implementation that provides the KAISA gradient worker fraction-based strategy
  • K-FAC hyperparameter schedule changes
    • The old, inflexible KFACParamScheduler was replaced with a LambdaParamScheduler modeled on PyTorch's LambdaLR scheduler
    • BaseKFACPreconditioner can be passed functions that return the current K-FAC hyperparameters rather than static float values
  • All printing is done via logging, and BaseKFACPreconditioner takes an optional loglevel parameter (closes KFAC verbose should use logging instead of print #33)
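
To make the ModuleHelper idea above concrete, here is a simplified sketch of what such a per-module helper could look like for torch.nn.Linear. The class and method names are illustrative assumptions and do not mirror the actual kfac.layers code.

```python
from typing import Tuple

import torch


class LinearModuleHelper:
    """Illustrative helper exposing factor shapes, factor computation, and
    gradient access for a torch.nn.Linear module (not the real implementation)."""

    def __init__(self, module: torch.nn.Linear) -> None:
        self.module = module
        self.has_bias = module.bias is not None

    def a_factor_shape(self) -> Tuple[int, int]:
        # Input dimension (+1 for the bias column) determines A's size, so
        # factor memory can be estimated before training starts.
        d = self.module.in_features + int(self.has_bias)
        return (d, d)

    def g_factor_shape(self) -> Tuple[int, int]:
        d = self.module.out_features
        return (d, d)

    def compute_a_factor(self, a: torch.Tensor) -> torch.Tensor:
        # a: saved forward-pass input with shape [batch, in_features].
        if self.has_bias:
            a = torch.cat([a, a.new_ones(a.shape[0], 1)], dim=1)
        return a.t() @ a / a.shape[0]

    def compute_g_factor(self, g: torch.Tensor) -> torch.Tensor:
        # g: gradient w.r.t. the module output with shape [batch, out_features].
        return g.t() @ g / g.shape[0]

    def get_grad(self) -> torch.Tensor:
        # Weight gradient with the bias gradient appended as an extra column.
        grad = self.module.weight.grad
        if self.has_bias:
            grad = torch.cat([grad, self.module.bias.grad.view(-1, 1)], dim=1)
        return grad
```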
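
The allreduce bucketing noted above (closing #32) can be pictured roughly as follows: small factor tensors are flattened into shared buffers until a size threshold is reached, one allreduce is issued per buffer, and the averaged results are copied back. This is only a schematic of the idea; the function name and bucket-size parameter are invented, and the real kfac.distributed module additionally returns futures.

```python
import torch
import torch.distributed as dist


def bucketed_allreduce(tensors, bucket_cap_mb: float = 25.0) -> None:
    """Allreduce many small tensors using a few large, flattened buckets."""
    cap = int(bucket_cap_mb * 1024 * 1024)
    buckets, bucket, size = [], [], 0
    for t in tensors:  # assumes all tensors share a dtype
        nbytes = t.numel() * t.element_size()
        if bucket and size + nbytes > cap:
            buckets.append(bucket)
            bucket, size = [], 0
        bucket.append(t)
        size += nbytes
    if bucket:
        buckets.append(bucket)

    world_size = dist.get_world_size()
    for group in buckets:
        # One communication call per bucket instead of one per tensor.
        flat = torch.cat([t.flatten() for t in group])
        dist.all_reduce(flat)
        flat.div_(world_size)
        offset = 0
        for t in group:
            t.copy_(flat[offset:offset + t.numel()].view_as(t))
            offset += t.numel()
```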

Example script changes

  • Added examples/requirements.txt
  • Usage instructions for examples moved to examples/README.md
  • Update examples to use new kfac API
  • Examples are now properly type annotated
  • Removed non-working language model example
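
For reference, a training loop against the new API follows roughly the pattern below. The KFACPreconditioner constructor arguments shown (factor_update_steps, inv_update_steps) are assumptions based on the renamed parameters described earlier; check the updated examples for the exact signature.

```python
import torch

# Assumed import path for the new API.
from kfac.preconditioner import KFACPreconditioner


def train_epoch(model, loader, device: str = 'cuda') -> None:
    # Assumes torch.distributed has already been initialized when running
    # with multiple workers; per this PR, K-FAC also works without it.
    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    preconditioner = KFACPreconditioner(
        model,
        factor_update_steps=10,   # hypothetical values
        inv_update_steps=100,
    )
    criterion = torch.nn.CrossEntropyLoss()

    for x, y in loader:
        optimizer.zero_grad()
        loss = criterion(model(x.to(device)), y.to(device))
        loss.backward()
        preconditioner.step()  # precondition gradients with K-FAC
        optimizer.step()
```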

Other changes + future goals

  • Removed a lot of content from the README that should eventually be moved to a wiki
    • Previously, the README was quite verbose and made it difficult to find the important content
  • Updated README examples, publications, and development instructions
  • Future changes include:
    • GitHub Actions for running code formatting, unit tests, and integration tests
    • Issue/PR templates
    • Badges in the README
    • A wiki

This commit makes a number of large changes including updates to support
PyTorch 1.8-1.10 features, a large refactoring of the KFAC codebase,
removal of Horovod support, and more.

General Changes:
- Updated to KFAC 0.4.0
- Add `Makefile` for running code formatting (with black), linting (with
  flake8), and unit tests (with pytest).
  - All code has been formatted to conform with PEP 8 (with some small
    modifications).
- Removed Horovod examples
- Updated distributed launch scripts to work with
  `torch.distributed.launch` (removing the need for auxiliary
  launch-on-node scripts).
- Update `scripts/*.sh` to infer distributed environment from
  environment variables.
- Cleaned up README

KFAC Changes:
- PyTorch 1.9/1.10 compatibility
  - Updated grad hooks to be compatible with new torch.nn.Module hooks
  - Update inverse and eigen decomp methods to no longer use deprecated
    functions. Additionally, inverse/eigen decomp results are written
    directly into the buffer.
- Intermediate values in forward/backward passes are no longer accumulated as
  separate values but summed in a tensor.
  - The number of accumulation steps passed to `KFAC()` is used to
    average the summed intermediate values.
  - Intermediate values are only saved if the model is in training mode.
    E.g., with `model.train()`.
- Added enum types for `AssignmentStrategy`, `ComputeMethod`, and
  `DistributedStrategy`.
- Updated `KFAC()` parameters:
  - Added `accumulation_steps`
  - Added `symmetry_aware`
  - Renamed `*_update_freq` to `*_update_steps`
  - Renamed `compute_eigen_outer_product`
  - Renamed `colocate_factors`
- Added `@property` methods for all params in the `KFAC()` state dict.
- `KFAC()` now launches communication operations for factors, inverses,
  and gradients as soon as they are computed to overlap communication
  with computation.
- `KFAC()` computes inverses and gradients in reverse order of layers
  since G factors are only made available in the backward pass.
- New `TorchDistributedCommunicator()` class
  - Communication operations now return PyTorch futures,
    and the `KFACBaseLayer` classes have sync methods to wait on futures.
  - Removed Horovod backend support
- Rewrote `WorkerAllocator`
  - Merged load balancer with `WorkerAllocator`
  - Update variable and function names to be consistent with names used
    in KAISA paper.
- Removed `state` attribute of `KFACBaseLayer` classes so tensors are
  stored directly as class attributes.
- Separated the eigen decomp and inverse method layer classes into
  separate classes that inherit from `KFACBaseLayer`.
- KFAC `get_*_factors` for specific Torch modules are now separate
  `ModuleHelper` classes that are stored as an attribute of
  `KFACBaseLayer`.
- Removed support for Embedding and LSTM/RNN layers (not that they were
  really supported anyway).
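
The futures-based communication pattern described in this list can be illustrated with plain torch.distributed primitives: launch the allreduce asynchronously as soon as a factor is ready, keep the future, and wait on it only when the result is needed. The wrapper below is a simplified sketch, not the actual TorchDistributedCommunicator.

```python
import torch
import torch.distributed as dist


class AsyncAllreducer:
    """Sketch of launching allreduces eagerly and synchronizing later."""

    def __init__(self) -> None:
        self._futures = []

    def allreduce(self, tensor: torch.Tensor) -> torch.futures.Future:
        # async_op=True returns a work handle; its future resolves once the
        # reduction finishes, letting communication overlap with computation.
        work = dist.all_reduce(tensor, async_op=True)
        future = work.get_future()
        self._futures.append(future)
        return future

    def synchronize(self) -> None:
        # Wait on all outstanding communication before using the results.
        torch.futures.wait_all(self._futures)
        self._futures.clear()
```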
Refactoring:
- Changed the `kfac` package to use absolute imports everywhere.
- Renamed `grad_comm_*` to `grad_receiver_*` in `KFACBaseLayer` to
  align the terminology with that used in `WorkerAllocator`.
- Use `torch.outer()` for computing outer products instead of unsqueezes
  + matrix multiplication.
- Renamed `kfac/comm.py` to `kfac/preconditioner.py`.

Changes:
- Use multiplication by the inverse of the world size instead of matrix
  division after allreduces.
- Eigenvalues are clamped to be non-negative (as it was in KFAC 0.3).
- Removed unnecessary `.view()` calls in `kfac/layers/modules.py`.
- Add early function exit from communication functions in
  `KFACBaseLayer`.
  Note that `KFAC.step()` would already correctly skip these based on the
  distribution strategy.
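
For example, the world-size scaling and the non-negative eigenvalue clamp above amount to something like the following sketch (torch.linalg.eigh stands in for whichever decomposition the layer actually uses):

```python
import torch
import torch.distributed as dist


def average_and_decompose(factor: torch.Tensor):
    """Illustrative only; assumes torch.distributed is initialized."""
    # Scale by the inverse world size after the allreduce (a single cheap
    # multiplication) instead of dividing the matrix.
    dist.all_reduce(factor)
    factor.mul_(1.0 / dist.get_world_size())

    # Clamp eigenvalues to be non-negative (as in KFAC 0.3) to guard against
    # small negative values caused by numerical error.
    eigvals, eigvecs = torch.linalg.eigh(factor)
    eigvals.clamp_(min=0.0)
    return eigvals, eigvecs
```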

Bug Fixes:
- Fixed an instance of using `Future.value()` instead of `Future.wait()` in
  a sync function.
- Fixed an issue on the first iteration where the A and G factors were
  correctly initialized to zero matrices with ones on the diagonals, but the
  running average of the initialized matrices with the new values of A and G
  was skipped, causing poor convergence.
- When accumulating the inputs for the A factor, PyTorch raised an error that
  variables needed for the gradient computation were modified by in-place
  operations. This was because, on the first accumulation step, a reference to
  the intermediate tensor was saved. The fix saves a reference to the tensor's
  data rather than the tensor itself, as only the tensor carries an autograd
  history.
- Refactored script argument names to be in line with the parameter names to KFAC.
- Added prefetching to dataloader
- Removed reference to language modeling scripts in README
- Restored the reshapes in the computation of the factors for the linear
  module; they were removed previously but are necessary for linear modules
  whose inputs have more than two dimensions.
- Mini-step tracking assumed that forward/backward steps were never
  overlapped, so a single counter was used to determine where a gradient
  accumulation boundary was. This assumption breaks for more advanced layer
  pipelining schemes such as those used by DeepSpeed. The fix counts
  mini-steps on a per-module basis (a sketch follows this list).
- Added parameters to KFAC and KFACLayer classes
- Changed self.comm to self.tdc for clarity
- Updated CNN example to explicitly define bucket cap
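
The per-module mini-step fix above can be pictured as follows: instead of one global counter, each module keeps its own count of backward mini-steps, and a module is only considered to be at a gradient accumulation boundary when its own counter is a multiple of accumulation_steps. This is a schematic, not the actual kfac bookkeeping code.

```python
from collections import defaultdict


class MiniStepTracker:
    """Track gradient accumulation boundaries per module rather than globally,
    so overlapped forward/backward schedules (e.g., pipelined layers) do not
    confuse boundary detection."""

    def __init__(self, accumulation_steps: int) -> None:
        self.accumulation_steps = accumulation_steps
        self._mini_steps = defaultdict(int)

    def record_backward(self, module_name: str) -> bool:
        """Record one mini-step for a module; return True at a boundary."""
        self._mini_steps[module_name] += 1
        return self._mini_steps[module_name] % self.accumulation_steps == 0
```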
gpauloski self-assigned this on Mar 25, 2022
gpauloski marked this pull request as ready for review on March 25, 2022 at 15:47
gpauloski merged commit 491ffa8 into master on Mar 25, 2022
gpauloski deleted the experimental branch on March 25, 2022 at 15:48