Fix ruff and mypy errors
gpauloski committed Jan 28, 2023
1 parent 7f2e4e1 commit cd95269
Showing 28 changed files with 88 additions and 65 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -7,7 +7,7 @@

K-FAC, Kronecker-factored Approximate Curvature, is a second-order optimization method based on an efficient approximation of the Fisher information matrix (see the [original paper](https://arxiv.org/abs/1503.05671)).
This repository provides a PyTorch implementation of K-FAC as a preconditioner to standard PyTorch optimizers with support for single-device or distributed training.
The distributed strategy is implemented using KAISA, a **K-FAC**-enabled, **A**daptable, **I**mproved, and **S**c**A**lable second-order optimizer framework, where the placement of the second-order computations and gradient preconditioning is controlled by the *gradient worker fraction* parameter (see the [paper](https://arxiv.org/abs/2107.01739) for more details).
The distributed strategy is implemented using KAISA, a K-FAC-enabled, Adaptable, Improved, and Scalable second-order optimizer framework, where the placement of the second-order computations and gradient preconditioning is controlled by the *gradient worker fraction* parameter (see the [paper](https://arxiv.org/abs/2107.01739) for more details).
KAISA has been shown to reduce time-to-convergence in [PyTorch distributed training](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) applications such as ResNet-50, Mask R-CNN, and BERT.

## Publications
@@ -41,7 +41,7 @@ $ pip install . # Use -e to install in development mode

## Usage

K-FAC requires minimial code to incorporate with existing training scripts.
K-FAC requires minimal code to incorporate with existing training scripts.
See the [K-FAC docstring](kfac/preconditioner.py) for a detailed list of K-FAC parameters.

```Python
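# NOTE: the original snippet in README.md is collapsed in this diff view, so
# the code below is an illustrative sketch of the integration pattern described
# above rather than the verbatim README example. The class and method names
# (KFACPreconditioner, preconditioner.step()) come from kfac/preconditioner.py
# as shown later in this commit; everything else here is assumed for the demo.
import torch

from kfac.preconditioner import KFACPreconditioner

model = torch.nn.Linear(10, 10)  # any torch.nn.Module (or a DDP-wrapped model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
preconditioner = KFACPreconditioner(model)

for _ in range(10):
    x, y = torch.randn(4, 10), torch.randn(4, 10)
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    preconditioner.step()  # precondition gradients with K-FAC
    optimizer.step()
```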
4 changes: 2 additions & 2 deletions examples/language/dataset.py
@@ -13,8 +13,8 @@
from torch.utils.data.distributed import DistributedSampler
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import PennTreebank
from torchtext.datasets import WikiText103
from torchtext.datasets import WikiText2
from torchtext.datasets import WikiText103
from torchtext.vocab import build_vocab_from_iterator
from torchtext.vocab import Vocab

@@ -112,7 +112,7 @@ def get_dataset(
seq_len (int): number of tokens in a training sequence.
batch_size (int): batch size.
cuda (bool): set as True if training with CUDA.
rank (int): optional rank of this worker for initalizing the
rank (int): optional rank of this worker for initializing the
distributed sampler.
world_size (int): optional world size if using distributed training.
3 changes: 2 additions & 1 deletion examples/language/transformer.py
@@ -105,6 +105,7 @@ def __init__(
div_term = torch.exp(
torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model),
)
self.pe: torch.Tensor
pe = torch.zeros(max_len, 1, d_model)
pe[:, 0, 0::2] = torch.sin(position * div_term)
pe[:, 0, 1::2] = torch.cos(position * div_term)
@@ -121,7 +122,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
about the relative or absolute position of the tokens in the
sequence.
"""
x = x + self.pe[: x.size(0)] # type: ignore
x = x + self.pe[: x.size(0)]
return self.dropout(x)


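For reference, the `div_term` and `pe` assignments in the hunk above compute the standard sinusoidal positional encoding: since `exp(-2i * ln(10000) / d_model) = 10000^(-2i / d_model)`, the stored table is

$$\mathrm{PE}(pos, 2i) = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad \mathrm{PE}(pos, 2i+1) = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right).$$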
2 changes: 1 addition & 1 deletion examples/torch_cifar10_resnet.py
@@ -10,7 +10,7 @@
import torch.distributed as dist
from torch.utils import collect_env
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary # type: ignore
from torchinfo import summary

import examples.vision.cifar_resnet as models
import examples.vision.datasets as datasets
2 changes: 1 addition & 1 deletion examples/torch_imagenet_resnet.py
@@ -9,7 +9,7 @@

import torch
import torch.distributed as dist
import torchvision.models as models # type: ignore
import torchvision.models as models
from torch.utils import collect_env
from torch.utils.tensorboard import SummaryWriter

8 changes: 4 additions & 4 deletions examples/utils.py
@@ -47,15 +47,15 @@ def __init__(self, smoothing: float = 0.0):

def forward(
self,
input: torch.Tensor,
input_: torch.Tensor,
target: torch.Tensor,
) -> torch.Tensor:
"""Forward pass."""
log_prob = log_softmax(input, dim=-1)
log_prob = log_softmax(input_, dim=-1)
weight = (
input.new_ones(input.size())
input_.new_ones(input_.size())
* self.smoothing
/ (input.size(-1) - 1.0)
/ (input_.size(-1) - 1.0)
)
weight.scatter_(-1, target.unsqueeze(-1), (1.0 - self.smoothing))
loss = (-weight * log_prob).sum(dim=-1).mean()
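The hunk above only renames the `input` argument to `input_` (avoiding the `input` builtin); the label-smoothing computation itself is unchanged. The following is a minimal, self-contained sketch of what that computation produces; the shapes and values are made up for illustration.

```python
import torch

smoothing = 0.1
logits = torch.randn(2, 4)      # stand-in for `input_`: 2 samples, 4 classes
target = torch.tensor([1, 3])

# Smoothed target: 1 - smoothing at the true class, smoothing / (K - 1) elsewhere.
weight = logits.new_ones(logits.size()) * smoothing / (logits.size(-1) - 1.0)
weight.scatter_(-1, target.unsqueeze(-1), 1.0 - smoothing)

log_prob = torch.nn.functional.log_softmax(logits, dim=-1)
loss = (-weight * log_prob).sum(dim=-1).mean()  # cross-entropy against the soft target
```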
2 changes: 1 addition & 1 deletion examples/vision/cifar_resnet.py
@@ -5,7 +5,7 @@
Moreover, most of the implementations on the web is copy-paste from
torchvision's resnet and has wrong number of params.
Proper ResNet-s for CIFAR10 (for fair comparision and etc.) has following
Proper ResNet-s for CIFAR10 (for fair comparison and etc.) has following
number of layers and parameters:
name | layers | params
12 changes: 7 additions & 5 deletions examples/vision/datasets.py
@@ -9,8 +9,8 @@
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets # type: ignore
from torchvision import transforms # type: ignore
from torchvision import datasets
from torchvision import transforms

T = tuple[torch.Tensor, torch.Tensor]

@@ -121,7 +121,9 @@ def make_sampler_and_loader(
]:
"""Create sampler and dataloader for train and val datasets."""
torch.set_num_threads(4)
kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
kwargs: dict[str, Any] = (
{'num_workers': 4, 'pin_memory': True} if args.cuda else {}
)
kwargs['prefetch_factor'] = 8
kwargs['persistent_workers'] = True

@@ -134,7 +136,7 @@
train_dataset,
batch_size=args.batch_size,
sampler=train_sampler,
**kwargs, # type: ignore
**kwargs,
)
val_sampler: DistributedSampler[T] = DistributedSampler(
val_dataset,
@@ -145,7 +147,7 @@
val_dataset,
batch_size=args.val_batch_size,
sampler=val_sampler,
**kwargs, # type: ignore
**kwargs,
)

return train_sampler, train_loader, val_sampler, val_loader
2 changes: 1 addition & 1 deletion examples/vision/engine.py
@@ -5,7 +5,7 @@
import math

import torch
from tqdm import tqdm # type: ignore
from tqdm import tqdm

import kfac
from examples.utils import accuracy
8 changes: 4 additions & 4 deletions kfac/base_preconditioner.py
@@ -73,7 +73,7 @@ def __init__(
accumulation_steps (int): number of forward/backward passes
between optimization steps (default: 1).
update_factors_in_hook (bool): If True, running average of factors
is updated in the module hook and the async commmunication is
is updated in the module hook and the async communication is
started. Otherwise, this will be performed at the start of
step() (default: True).
defaults (dict): dictionary of default values to include in the
@@ -315,7 +315,7 @@ def step(self) -> None:
Note:
Gradients must be averaged across ranks before calling `step()`.
This condition is guarenteed to be true if using the
This condition is guaranteed to be true if using the
`DistributedDataParallel` model wrapper as gradients are
communicated during `loss.backward()`.
"""
@@ -436,14 +436,14 @@ def _compute_grad_scale(self) -> float:
def _save_input(
self,
module: torch.nn.Module,
input: list[torch.Tensor],
input_: list[torch.Tensor],
) -> None:
"""Hook for saving the input during the forward pass of a module."""
if not module.training:
return
if self.steps % self.factor_update_steps == 0:
name, layer = self._layers[module]
layer.save_layer_input(input)
layer.save_layer_input(input_)
# Update mini_step here because forward pass should always
# happen before backward pass
self._mini_steps[name] += 1
6 changes: 3 additions & 3 deletions kfac/gpt_neox/layer.py
@@ -132,21 +132,21 @@ def reduce_g_factor(

def save_layer_input(
self,
input: list[torch.Tensor],
input_: list[torch.Tensor],
) -> None: # pragma: no cover
"""Override to gather input to primary rank."""
if self.primary_rank is None:
raise RuntimeError('primary rank has not been set yet.')
if self.parallelism == 'input':
a = gather_from_model_parallel_region(
input[0],
input_[0],
dst=self.primary_rank,
model_parallel_group=self.model_parallel_group,
)
if a is not None:
super().save_layer_input([a])
else:
super().save_layer_input(input)
super().save_layer_input(input_)

def save_layer_grad_output(
self,
10 changes: 6 additions & 4 deletions kfac/gpt_neox/mpu.py
@@ -19,18 +19,19 @@ def gather_from_model_parallel_region(
is an `all gather`.
Note:
The concatentation is done along the last axis. I.e., this is the
The concatenation is done along the last axis. I.e., this is the
inverse operation of mpu.scatter_to_model_parallel_region().
Args:
tensor (torch.Tensor): tensor parition to gather.
tensor (torch.Tensor): tensor partition to gather.
dst (rank): destination rank to gather full tensor on.
model_parallel_group (ProcessGroup): model parallel process group.
If None, model parallel region will be assumed to have size 1.
fp32_allreduce (bool): if True and tensor is bf16, the tensor will
be cast to float before communication. Note: this is to match
the functionality of megatron's
gather_from_model_parallel_region().
dim (int): dimension along which to concatenate tensors.
Returns:
Gathered tensor on rank `dst` else None.
@@ -101,17 +102,18 @@ def split_tensor_along_dim(
) -> tuple[torch.Tensor, ...]:
"""Split a tensor along its last dimension.
Source: https://github.com/EleutherAI/gpt-neox/blob/d7af1e7a8e3a816610b7d169456f81ca62d34ff7/megatron/mpu/utils.py # noqa: 501
Source: https://github.com/EleutherAI/gpt-neox/blob/d7af1e7a8e3a816610b7d169456f81ca62d34ff7/megatron/mpu/utils.py
Args:
tensor (torch.Tensor): input tensor
num_partitions (int): number of partitions to split the tensor
dim (int): dimension along which to split the tensor.
contiguous_split_chunks (bool): If True, make each chunk contiguous
in memory.
Returns:
tuple of tensors
"""
""" # noqa: E501
dim_size = tensor.size()[dim]

if dim_size % num_partitions != 0:
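The two helpers in this file operate across model-parallel ranks, but their relationship is easiest to see in a single-process analogue. Below is a minimal sketch using plain `torch` ops (not part of the commit): splitting along the last dimension and concatenating the chunks back along that dimension are inverse operations, mirroring `scatter_to_model_parallel_region()` and `gather_from_model_parallel_region()`.

```python
import torch

full = torch.arange(12.0).reshape(3, 4)  # the full, unpartitioned tensor
parts = torch.split(full, 2, dim=-1)     # analogue of scattering to 2 model-parallel ranks
rebuilt = torch.cat(parts, dim=-1)       # analogue of gathering along the last axis
assert torch.equal(full, rebuilt)
```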
2 changes: 1 addition & 1 deletion kfac/gpt_neox/preconditioner.py
@@ -143,7 +143,7 @@ def __init__(
recursively registering child modules of the parent.
Case-insensitive (default: []).
update_factors_in_hook (bool): If True, running average of factors
is updated in the module hook and the async commmunication is
is updated in the module hook and the async communication is
started. Otherwise, this will be performed at the start of
step() (default: True).
loglevel (int): logging level (default: logging.DEBUG).
4 changes: 2 additions & 2 deletions kfac/layers/base.py
@@ -341,12 +341,12 @@ def reset_batch(self) -> None:
self._g_batch = None
self._g_count = 0

def save_layer_input(self, input: list[torch.Tensor]) -> None:
def save_layer_input(self, input_: list[torch.Tensor]) -> None:
"""Save input for layer."""
# Note: the clone here is a fix for "RuntimeError: one of the variables
# needed for gradient computation has been modified by an inplace
# operation" in the ResNet50 + ImageNet example.
a = input[0].to(self.factor_dtype).clone()
a = input_[0].to(self.factor_dtype).clone()
a = self.module.get_a_factor(a)
if self._a_batch is None:
self._a_batch = a
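For context on the `clone()` comment in the hunk above, here is a generic reproduction of that class of autograd error and how cloning the saved tensor avoids it. This is an illustrative sketch, not the exact ResNet-50 + ImageNet case mentioned in the comment.

```python
import torch

x = torch.randn(3, requires_grad=True)
y = torch.sigmoid(x)     # sigmoid's backward pass needs its output `y`
saved = y                # saving a reference without cloning...
saved.add_(1.0)          # ...then modifying it in place corrupts the graph tensor
try:
    y.sum().backward()
except RuntimeError as err:
    print(err)           # "... has been modified by an inplace operation"

y2 = torch.sigmoid(x)
saved2 = y2.clone()      # clone before saving, as save_layer_input() does
saved2.add_(1.0)         # edits touch only the copy
y2.sum().backward()      # succeeds
```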
6 changes: 3 additions & 3 deletions kfac/layers/modules.py
@@ -54,7 +54,7 @@ def get_g_factor(self, g: torch.Tensor) -> torch.Tensor:
raise NotImplementedError

def get_grad(self) -> torch.Tensor:
"""Get formated gradients (weight and bias) of module.
"""Get formatted gradients (weight and bias) of module.
Returns:
gradient of shape If bias != None,
@@ -192,10 +192,10 @@ def get_g_factor(self, g: torch.Tensor) -> torch.Tensor:
return get_cov(g)

def get_grad(self) -> torch.Tensor:
"""Get formated gradients (weight and bias) of module."""
"""Get formmated gradients (weight and bias) of module."""
grad = cast(
torch.Tensor,
self.module.weight.grad.view(
self.module.weight.grad.view( # type: ignore
self.module.weight.grad.size(0), # type: ignore
-1,
),
8 changes: 4 additions & 4 deletions kfac/layers/utils.py
@@ -22,7 +22,7 @@ def get_cov(
"""Computes the empirical second moment of a 2D tensor.
Reference:
- https://github.com/tensorflow/kfac/blob/master/kfac/python/ops/fisher_factors.py#L220 # noqa: E501
- https://github.com/tensorflow/kfac/blob/master/kfac/python/ops/fisher_factors.py#L220
- https://arxiv.org/pdf/1602.01407.pdf#subsection.2.2
Args:
@@ -35,7 +35,7 @@
Returns:
square tensor representing the second moment of a.
"""
""" # noqa: E501
if len(a.shape) != 2:
raise ValueError(
'Input tensor must have 2 dimensions. Got tensor with shape '
@@ -69,12 +69,12 @@ def reshape_data(
data_list (list): list of tensors of equal, arbitrary shape where the
batch_dim is either 0 or 1 depending on self.batch_first.
batch_first (bool, optional): is batch dim first. (default: True)
collapse_dim (bool, optional): if True, collapse all but the last dim
collapse_dims (bool, optional): if True, collapse all but the last dim
together forming a 2D output tensor.
Returns:
single tensor with all tensors from data_list concatenated across
batch_dim. Guarenteed to be 2D if collapse_dims=True.
batch_dim. Guaranteed to be 2D if collapse_dims=True.
"""
d = torch.cat(data_list, dim=int(not batch_first))
if collapse_dims and len(d.shape) > 2:
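For context on the `get_cov` docstring above, here is a minimal sketch of the quantity it describes. The exact scaling and optional arguments of `get_cov` are not shown in this diff, so treat this as an approximation based on the cited fisher_factors reference.

```python
import torch

a = torch.randn(8, 3)                  # 8 samples, 3 features
second_moment = a.t() @ a / a.size(0)  # (3, 3) empirical second moment of the rows
```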
12 changes: 8 additions & 4 deletions kfac/preconditioner.py
@@ -27,6 +27,10 @@
logger = logging.getLogger(__name__)


def _mock_new_group(x: list[int]) -> None:
return None


class KFACPreconditioner(BaseKFACPreconditioner):
"""KFAC Distributed Gradient Preconditioner.
@@ -111,7 +115,7 @@ def __init__(
`AssignmentStrategy` for more details
(default: AssignmentStrategy.COMPUTE).
colocate_factors (bool): assign both factors for a single layer to
the same worker. Reccomended when num_layers < world_size
the same worker. Recommended when num_layers < world_size
(default: True).
compute_method (ComputeMethod, str): See `ComputeMethod` for more
details (default: ComputeMethod.EIGEN).
@@ -143,7 +147,7 @@ def __init__(
the layer to not be registered. The patterns will be applied
against the layer's name and class name.
update_factors_in_hook (bool): If True, running average of factors
is updated in the module hook and the async commmunication is
is updated in the module hook and the async communication is
started. Otherwise, this will be performed at the start of
step() (default: True).
loglevel (int): logging level (default: logging.DEBUG).
@@ -284,13 +288,13 @@ def __init__(
Callable[[List[int]], dist.ProcessGroup],
dist.new_group,
)
mock_new_group: Callable[[list[int]], None] = lambda x: None

assignment = KAISAAssignment(
work,
local_rank=get_rank(),
world_size=get_world_size(),
grad_worker_fraction=self.grad_worker_fraction,
group_func=new_group if dist.is_initialized() else mock_new_group,
group_func=new_group if dist.is_initialized() else _mock_new_group,
colocate_factors=self.colocate_factors,
)
logger.log(loglevel, f'KFAC layer assignments: {assignment}')
2 changes: 1 addition & 1 deletion kfac/tracing.py
@@ -87,7 +87,7 @@ def decorator(func: Callable[..., RT]) -> Callable[..., RT]:
"""Decorator for function execution time tracing."""

def func_timer(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
"""Time and executre function."""
"""Time and execute function."""
if sync:
torch.distributed.barrier()
t = time.time()