Support arbitrary X in data #100

Merged · 21 commits · Apr 24, 2024
35 changes: 30 additions & 5 deletions curvlinops/_base.py
@@ -1,5 +1,6 @@
"""Contains functionality to analyze Hessian & GGN via matrix-free multiplication."""

from collections.abc import MutableMapping
from typing import Callable, Iterable, List, Optional, Tuple, Union
from warnings import warn

@@ -33,12 +34,13 @@ def __init__(
model_func: Callable[[Tensor], Tensor],
loss_func: Union[Callable[[Tensor, Tensor], Tensor], None],
params: List[Parameter],
data: Iterable[Tuple[Tensor, Tensor]],
data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
progressbar: bool = False,
check_deterministic: bool = True,
shape: Optional[Tuple[int, int]] = None,
num_data: Optional[int] = None,
block_sizes: Optional[List[int]] = None,
batch_size_fn: Optional[Callable[[MutableMapping], int]] = None,
):
"""Linear operator for DNN matrices.

@@ -55,7 +57,11 @@ def __init__(
represented matrix is independent of the loss function.
params: List of differentiable parameters used by the prediction function.
data: Source from which mini-batches can be drawn, for instance a list of
mini-batches ``[(X, y), ...]`` or a torch ``DataLoader``.
mini-batches ``[(X, y), ...]`` or a torch ``DataLoader``. Note that ``X``
could be a ``dict`` or ``UserDict``; this is useful for custom models.
In this case, you must (i) specify the ``batch_size_fn`` argument, and
(ii) take care of preprocessing like ``X.to(device)`` inside of your
``model.forward()`` function.
progressbar: Show a progressbar during matrix-multiplication.
Default: ``False``.
check_deterministic: Probe that model and data are deterministic, i.e.
@@ -73,14 +79,23 @@ def __init__(
For instance ``[len(params)]`` considers the full matrix, while
``[1, 1, ...]`` corresponds to a block diagonal approximation where
each parameter forms its own block.
batch_size_fn: If the ``X``'s in ``data`` are not ``torch.Tensor``s, this
needs to be specified. The intended behavior is to consume the first
entry of the iterates from ``data`` (i.e. ``X``) and return its batch size.

Raises:
RuntimeError: If the check for deterministic behavior fails.
ValueError: If ``block_sizes`` is specified but the linear operator does not
support blocks.
ValueError: If the sum of blocks does not equal the number of parameters.
ValueError: If any block size is not positive.
ValueError: If ``X`` is not a tensor and ``batch_size_fn`` is not specified.
"""
if isinstance(next(iter(data))[0], MutableMapping) and batch_size_fn is None:
raise ValueError(
"When using dict-like custom data, `batch_size_fn` is required."
)

if shape is None:
dim = sum(p.numel() for p in params)
shape = (dim, dim)
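As a usage sketch of the dict-valued inputs this hunk validates (the ``features`` key, the toy ``DictModel``, and the assumption that a concrete subclass such as ``GGNLinearOperator`` forwards ``batch_size_fn`` to this base class are illustrative, not taken from the PR):

from curvlinops import GGNLinearOperator
from torch import nn, rand, randint

class DictModel(nn.Module):
    """Toy model that reads its input tensor out of a dict-valued batch."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, X: dict):
        # device placement of X would also happen here, see the docstring above
        return self.linear(X["features"])

model, loss_func = DictModel(), nn.CrossEntropyLoss()
params = [p for p in model.parameters() if p.requires_grad]
data = [({"features": rand(32, 10)}, randint(0, 2, (32,))) for _ in range(4)]

GGN = GGNLinearOperator(
    model,
    loss_func,
    params,
    data,
    # required because X is not a tensor
    batch_size_fn=lambda X: X["features"].shape[0],
)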
@@ -103,9 +118,15 @@ def __init__(
self._data = data
self._device = self._infer_device(self._params)
self._progressbar = progressbar
self._batch_size_fn = (
(lambda X: X.shape[0]) if batch_size_fn is None else batch_size_fn
)

self._N_data = (
sum(X.shape[0] for (X, _) in self._loop_over_data(desc="_N_data"))
sum(
self._batch_size_fn(X)
for (X, _) in self._loop_over_data(desc="_N_data")
)
if num_data is None
else num_data
)
@@ -328,7 +349,11 @@ def _loop_over_data(
data_iter = tqdm(data_iter, desc=desc)

for X, y in data_iter:
X, y = X.to(self._device), y.to(self._device)
# Assume everything is handled by the model
# if `X` is a custom data format
if isinstance(X, Tensor):
X = X.to(self._device)
y = y.to(self._device)
yield (X, y)

def gradient_and_loss(self) -> Tuple[List[Tensor], Tensor]:
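Since non-tensor ``X`` is now passed through ``_loop_over_data`` untouched, the model itself has to place the batch on the right device. A minimal sketch of such a forward (the ``features`` key and ``DeviceAwareModel`` are made-up names):

from torch import nn

class DeviceAwareModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(10, 2)

    def forward(self, X: dict):
        # the linear operator no longer calls X.to(device) for dict-valued X,
        # so move the batch onto the model's own device here
        device = next(self.parameters()).device
        return self.net(X["features"].to(device))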
@@ -368,7 +393,7 @@ def _get_normalization_factor(self, X: Tensor, y: Tensor) -> float:
Returns:
Normalization factor
"""
return {"sum": 1.0, "mean": X.shape[0] / self._N_data}[
return {"sum": 1.0, "mean": self._batch_size_fn(X) / self._N_data}[
self._loss_func.reduction
]
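As a quick sanity check of the new line: with ``reduction="mean"``, a batch of 64 samples and ``self._N_data = 1024``, the factor is ``64 / 1024 = 0.0625``; with ``reduction="sum"`` it remains ``1.0``.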

84 changes: 61 additions & 23 deletions curvlinops/examples/functorch.py
@@ -1,8 +1,10 @@
"""Contains functorch functionality for the examples."""

from collections.abc import MutableMapping
from math import sqrt
from typing import Dict, Iterable, List, Tuple
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import Tensor, cat, einsum
from torch.func import functional_call, grad, hessian, jacrev, jvp, vmap
from torch.nn import Module
@@ -35,7 +37,8 @@ def functorch_hessian(
model_func: Module,
loss_func: Module,
params: List[Tensor],
data: Iterable[Tuple[Tensor, Tensor]],
data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
input_key: Optional[str] = None,
) -> Tensor:
"""Compute the Hessian with functorch.

@@ -47,16 +50,18 @@
params: List of differentiable parameters used by the prediction function.
data: Source from which mini-batches can be drawn, for instance a list of
mini-batches ``[(X, y), ...]`` or a torch ``DataLoader``.
input_key: Key to obtain the input tensor when ``X`` is a dict-like object.

Returns:
Square matrix containing the Hessian.
"""
(dev,) = {p.device for p in params}
X, y = _concatenate_batches(data)
X, y = X.to(dev), y.to(dev)
X, y = _concatenate_batches(data, input_key, device=dev)
params_dict = _make_params_dict(model_func, params)

def loss(X: Tensor, y: Tensor, params_dict: Dict[str, Tensor]) -> Tensor:
def loss(
X: Union[Tensor, MutableMapping], y: Tensor, params_dict: Dict[str, Tensor]
) -> Tensor:
"""Compute the loss given a mini-batch and the neural network parameters.

# noqa: DAR101
@@ -75,7 +80,8 @@ def functorch_ggn(
model_func: Module,
loss_func: Module,
params: List[Tensor],
data: Iterable[Tuple[Tensor, Tensor]],
data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
input_key: Optional[str] = None,
) -> Tensor:
"""Compute the GGN with functorch.

@@ -89,17 +95,19 @@
params: List of differentiable parameters used by the prediction function.
data: Source from which mini-batches can be drawn, for instance a list of
mini-batches ``[(X, y), ...]`` or a torch ``DataLoader``.
input_key: Key to obtain the input tensor when ``X`` is a dict-like object.

Returns:
Square matrix containing the GGN.
"""
(dev,) = {p.device for p in params}
X, y = _concatenate_batches(data)
X, y = X.to(dev), y.to(dev)
X, y = _concatenate_batches(data, input_key, device=dev)
params_dict = _make_params_dict(model_func, params)

def linearized_model(
anchor_dict: Dict[str, Tensor], params_dict: Dict[str, Tensor], X: Tensor
anchor_dict: Dict[str, Tensor],
params_dict: Dict[str, Tensor],
X: Union[Tensor, MutableMapping],
) -> Tensor:
"""Evaluate the model at params, using its linearization around anchor.

@@ -118,7 +126,7 @@ def model_fn_params_only(params_dict: Dict[str, Tensor]) -> Tensor:
return model_at_anchor + jvp_diff

def linearized_loss(
X: Tensor,
X: Union[Tensor, MutableMapping],
y: Tensor,
anchor_dict: Dict[str, Tensor],
params_dict: Dict[str, Tensor],
@@ -146,7 +154,8 @@ def functorch_gradient(
model_func: Module,
loss_func: Module,
params: List[Tensor],
data: Iterable[Tuple[Tensor, Tensor]],
data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
input_key: Optional[str] = None,
) -> Tuple[Tensor]:
"""Compute the gradient with functorch.

@@ -158,14 +167,18 @@
params: List of differentiable parameters used by the prediction function.
data: Source from which mini-batches can be drawn, for instance a list of
mini-batches ``[(X, y), ...]`` or a torch ``DataLoader``.
input_key: Key to obtain the input tensor when ``X`` is a dict-like object.

Returns:
Gradient in same format as the parameters.
"""
X, y = _concatenate_batches(data)
(dev,) = {p.device for p in params}
X, y = _concatenate_batches(data, input_key, device=dev)
params_dict = _make_params_dict(model_func, params)

def loss(X: Tensor, y: Tensor, params_dict: Dict[str, Tensor]) -> Tensor:
def loss(
X: Union[Tensor, MutableMapping], y: Tensor, params_dict: Dict[str, Tensor]
) -> Tensor:
"""Compute the loss given a mini-batch and the neural network parameters.

# noqa: DAR101
@@ -184,7 +197,9 @@ def functorch_empirical_fisher(
model_func: Module,
loss_func: Module,
params: List[Tensor],
data: Iterable[Tuple[Tensor, Tensor]],
data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
batch_size_fn: Optional[Callable[[MutableMapping], int]] = None,
input_key: Optional[str] = None,
) -> Tensor:
"""Compute the empirical Fisher with functorch.

@@ -196,6 +211,9 @@
params: List of differentiable parameters used by the prediction function.
data: Source from which mini-batches can be drawn, for instance a list of
mini-batches ``[(X, y), ...]`` or a torch ``DataLoader``.
batch_size_fn: Given an input ``X``, tells the batch size. When ``None``,
defaults to ``lambda X: X.shape[0]``.
input_key: Key to obtain the input tensor when ``X`` is a dict-like object.

Returns:
Square matrix containing the empirical Fisher.
@@ -204,12 +222,13 @@
ValueError: If the loss function's reduction cannot be determined.
"""
(dev,) = {p.device for p in params}
X, y = _concatenate_batches(data)
X, y = X.to(dev), y.to(dev)
X, y = _concatenate_batches(data, input_key, device=dev)
params_dict = _make_params_dict(model_func, params)

# compute batched gradients
def loss_n(X_n: Tensor, y_n: Tensor, params_dict: Dict[str, Tensor]) -> Tensor:
def loss_n(
X_n: Union[Tensor, MutableMapping], y_n: Tensor, params_dict: Dict[str, Tensor]
) -> Tensor:
"""Compute the gradient for a single sample.

# noqa: DAR101
Expand All @@ -220,8 +239,8 @@ def loss_n(X_n: Tensor, y_n: Tensor, params_dict: Dict[str, Tensor]) -> Tensor:

params_argnum = 2
batch_grad_fn = vmap(grad(loss_n, argnums=params_argnum))
N = X.shape[0] if batch_size_fn is None else batch_size_fn(X)

N = X.shape[0]
params_replicated_dict = {
name: p.unsqueeze(0).expand(N, *(p.dim() * [-1]))
for name, p in params_dict.items()
@@ -243,7 +262,8 @@ def loss_n(X_n: Tensor, y_n: Tensor, params_dict: Dict[str, Tensor]) -> Tensor:
def functorch_jacobian(
model_func: Module,
params: List[Tensor],
data: Iterable[Tuple[Tensor, Tensor]],
data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
input_key: Optional[str] = None,
) -> Tensor:
"""Compute the Jacobian with functorch.

Expand All @@ -253,15 +273,15 @@ def functorch_jacobian(
params: List of differentiable parameters used by the prediction function.
data: Source from which mini-batches can be drawn, for instance a list of
mini-batches ``[(X, y), ...]`` or a torch ``DataLoader``.
input_key: Key to obtain the input tensor when ``X`` is a dict-like object.

Returns:
Matrix containing the Jacobian. Has shape ``[N * C, D]`` where ``D`` is the
total number of parameters, ``N`` the total number of data points, and ``C``
the model's output space dimension.
"""
(dev,) = {p.device for p in params}
X, _ = _concatenate_batches(data)
X = X.to(dev)
X, _ = _concatenate_batches(data, input_key, device=dev)
params_dict = _make_params_dict(model_func, params)

def model_fn_params_only(params_dict: Dict[str, Tensor]) -> Tensor:
@@ -279,19 +299,37 @@ def model_fn_params_only(params_dict: Dict[str, Tensor]) -> Tensor:


def _concatenate_batches(
data: Iterable[Tuple[Tensor, Tensor]]
data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
input_key: Optional[str] = None,
device: Optional[torch.device] = None,
) -> Tuple[Tensor, Tensor]:
"""Concatenate all batches in the dataset along the batch dimension.

Args:
data: A dataloader or iterable of batches.
input_key: Key to obtain the input tensor when ``X`` is a dict-like object.
device: The device the data should live in.

Returns:
Concatenated model inputs.
Concatenated targets.

Raises:
ValueError: If ``X`` in ``data`` is a dict-like object and ``input_key`` is
not provided.
"""
X, y = list(zip(*list(data)))
return cat(X), cat(y)
device = y[0].device if device is None else device
y = cat(y).to(device)

if isinstance(X[0], MutableMapping) and input_key is None:
raise ValueError("input_key must be provided for dict-like X!")

if isinstance(X[0], Tensor):
return cat(X).to(device), y
else:
X = {input_key: cat([d[input_key] for d in X]).to(device)}
return X, y


def _make_params_dict(model_func: Module, params: List[Tensor]) -> Dict[str, Tensor]:
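To illustrate the new branch in ``_concatenate_batches`` (the ``features`` key is a made-up name; the helper is the private function from ``curvlinops/examples/functorch.py`` shown above): two dict-valued batches collapse into a single dict holding one concatenated tensor, next to the concatenated targets.

from torch import rand, randint

from curvlinops.examples.functorch import _concatenate_batches

data = [
    ({"features": rand(8, 10)}, randint(0, 2, (8,))),
    ({"features": rand(8, 10)}, randint(0, 2, (8,))),
]
X, y = _concatenate_batches(data, input_key="features")
# X == {"features": <tensor of shape [16, 10]>}, y has shape [16]

The ``functorch_*`` helpers above consume ``data`` in the same format and only need the additional ``input_key`` (and, for the empirical Fisher, ``batch_size_fn``) on top of their previous arguments.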