Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bugs in influence #358

Merged
merged 4 commits into from
Jun 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

- Fix adding valuation results with overlapping indices and different lengths
[PR #370](https://github.com/appliedAI-Initiative/pyDVL/pull/370)
- Fixed bugs in conjugate gradient and `solve_linear`
[PR #358](https://github.com/appliedAI-Initiative/pyDVL/pull/358)
- Major changes to IF interface and functionality
[PR #278](https://github.com/appliedAI-Initiative/pyDVL/pull/278)

Expand Down
17 changes: 9 additions & 8 deletions src/pydvl/influence/frameworks/torch_differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def solve_linear(
all_y = cat(all_y)
matrix = model.hessian(
all_x, all_y, progress=progress
) + hessian_perturbation * identity_tensor(model.num_params)
) + hessian_perturbation * identity_tensor(model.num_params, device=model.device)
return torch.linalg.solve(matrix, b.T).T


Expand Down Expand Up @@ -149,6 +149,9 @@ def solve_cg(
optimal = False

for k in range(maxiter):
if gamma < stopping_val:
optimal = True
break
Ap = hvp(p).squeeze()
alpha = gamma / torch.sum(torch.matmul(p, Ap)).item()
x += alpha * p
Expand All @@ -158,10 +161,6 @@ def solve_cg(
gamma = gamma_
p = r + beta * p

if gamma < stopping_val:
optimal = True
break

info = {"niter": k, "optimal": optimal}
return x, info

Expand Down Expand Up @@ -269,8 +268,8 @@ def einsum(equation, *operands) -> torch.Tensor:
return torch.einsum(equation, *operands)


def identity_tensor(dim: int) -> torch.Tensor:
return torch.eye(dim, dim)
def identity_tensor(dim: int, **kwargs) -> torch.Tensor:
return torch.eye(dim, dim, **kwargs)


def mvp(
Expand Down Expand Up @@ -312,7 +311,9 @@ def mvp(
return mvp.detach() # type: ignore


class TorchTwiceDifferentiable(TwiceDifferentiable[torch.Tensor, nn.Module]):
class TorchTwiceDifferentiable(
TwiceDifferentiable[torch.Tensor, nn.Module, torch.device]
):
def __init__(
self,
model: nn.Module,
Expand Down
9 changes: 7 additions & 2 deletions src/pydvl/influence/frameworks/twice_differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,22 @@

TensorType = TypeVar("TensorType", bound=Sequence)
ModelType = TypeVar("ModelType")
DeviceType = TypeVar("DeviceType")


class TwiceDifferentiable(ABC, Generic[TensorType, ModelType]):
class TwiceDifferentiable(ABC, Generic[TensorType, ModelType, DeviceType]):
"""
Wraps a differentiable model and loss and provides methods to compute the
second derivative of the loss wrt. the model parameters.
"""

def __init__(
self, model: ModelType, loss: Callable[[TensorType, TensorType], TensorType]
self,
model: ModelType,
loss: Callable[[TensorType, TensorType], TensorType],
device: DeviceType,
):
self.device = device
pass

@property
Expand Down
8 changes: 4 additions & 4 deletions src/pydvl/influence/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class InfluenceType(str, Enum):


def compute_influence_factors(
model: TwiceDifferentiable[TensorType, ModelType],
model: TwiceDifferentiable,
training_data: DataLoaderType,
test_data: DataLoaderType,
inversion_method: InversionMethod,
Expand Down Expand Up @@ -82,7 +82,7 @@ def compute_influence_factors(


def compute_influences_up(
model: TwiceDifferentiable[TensorType, ModelType],
model: TwiceDifferentiable,
input_data: DataLoaderType,
influence_factors: TensorType,
*,
Expand Down Expand Up @@ -115,7 +115,7 @@ def compute_influences_up(


def compute_influences_pert(
model: TwiceDifferentiable[TensorType, ModelType],
model: TwiceDifferentiable,
input_data: DataLoaderType,
influence_factors: TensorType,
*,
Expand Down Expand Up @@ -165,7 +165,7 @@ def compute_influences_pert(


def compute_influences(
differentiable_model: TwiceDifferentiable[TensorType, ModelType],
differentiable_model: TwiceDifferentiable,
training_data: DataLoaderType,
*,
test_data: Optional[DataLoaderType] = None,
Expand Down
2 changes: 1 addition & 1 deletion src/pydvl/influence/inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class InversionMethod(str, Enum):

def solve_hvp(
inversion_method: InversionMethod,
model: TwiceDifferentiable[TensorType, ModelType],
model: TwiceDifferentiable,
training_data: DataLoaderType,
b: TensorType,
*,
Expand Down