Skip to content

Commit 5eb5498

Browse files
vivekmigfacebook-github-bot
authored andcommitted
Fix Black version and OSS Failures (#1241)
Summary: Currently, OSS GitHub Actions tests are failing due to failing test, lint and typing issues. This updates the black version used externally (and corresponding python version to support latest black) to match the internal updates in D54447730 and also updates flake8 settings to avoid incompatibilities. Typing issues are also resolved and imports from torch._tensor are removed, since these are not supported for previous torch versions. Pull Request resolved: #1241 Reviewed By: cyrjano Differential Revision: D54901754 Pulled By: vivekmig fbshipit-source-id: 2b94bf36488b11b6c145175cfe10fc5433b014fe
1 parent 837168f commit 5eb5498

27 files changed

+78
-65
lines changed

.github/workflows/lint.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
1414
with:
1515
runner: linux.12xlarge
16-
docker-image: cimg/python:3.6
16+
docker-image: cimg/python:3.9
1717
repository: pytorch/captum
1818
script: |
1919
sudo chmod -R 777 .

captum/_utils/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from enum import Enum
44
from functools import reduce
55
from inspect import signature
6-
from typing import Any, Callable, cast, Dict, List, overload, Tuple, Union
6+
from typing import Any, Callable, cast, Dict, List, overload, Sequence, Tuple, Union
77

88
import numpy as np
99
import torch
@@ -683,7 +683,7 @@ def _extract_device(
683683

684684

685685
def _reduce_list(
686-
val_list: List[TupleOrTensorOrBoolGeneric],
686+
val_list: Sequence[TupleOrTensorOrBoolGeneric],
687687
red_func: Callable[[List], Any] = torch.cat,
688688
) -> TupleOrTensorOrBoolGeneric:
689689
"""

captum/influence/_core/tracincp_fast_rand_proj.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ class TracInCPFast(TracInCPBase):
8282
def __init__(
8383
self,
8484
model: Module,
85-
final_fc_layer: Union[Module, str],
85+
final_fc_layer: Module,
8686
train_dataset: Union[Dataset, DataLoader],
8787
checkpoints: Union[str, List[str], Iterator],
8888
checkpoints_load_func: Callable = _load_flexible_state_dict,
@@ -96,11 +96,9 @@ def __init__(
9696
9797
model (torch.nn.Module): An instance of pytorch model. This model should
9898
define all of its layers as attributes of the model.
99-
final_fc_layer (torch.nn.Module or str): The last fully connected layer in
99+
final_fc_layer (torch.nn.Module): The last fully connected layer in
100100
the network for which gradients will be approximated via fast random
101-
projection method. Can be either the layer module itself, or the
102-
fully qualified name of the layer if it is a defined attribute of
103-
the passed `model`.
101+
projection method.
104102
train_dataset (torch.utils.data.Dataset or torch.utils.data.DataLoader):
105103
In the `influence` method, we compute the influence score of
106104
training examples on examples in a test batch.
@@ -869,7 +867,7 @@ class TracInCPFastRandProj(TracInCPFast):
869867
def __init__(
870868
self,
871869
model: Module,
872-
final_fc_layer: Union[Module, str],
870+
final_fc_layer: Module,
873871
train_dataset: Union[Dataset, DataLoader],
874872
checkpoints: Union[str, List[str], Iterator],
875873
checkpoints_load_func: Callable = _load_flexible_state_dict,
@@ -886,11 +884,9 @@ def __init__(
886884
887885
model (torch.nn.Module): An instance of pytorch model. This model should
888886
define all of its layers as attributes of the model.
889-
final_fc_layer (torch.nn.Module or str): The last fully connected layer in
887+
final_fc_layer (torch.nn.Module): The last fully connected layer in
890888
the network for which gradients will be approximated via fast random
891-
projection method. Can be either the layer module itself, or the
892-
fully qualified name of the layer if it is a defined attribute of
893-
the passed `model`.
889+
projection method.
894890
train_dataset (torch.utils.data.Dataset or torch.utils.data.DataLoader):
895891
In the `influence` method, we compute the influence score of
896892
training examples on examples in a test batch.

captum/insights/attr_vis/attribution_calculation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def calculate_attribution(
131131
)
132132
if "baselines" in inspect.signature(attribution_method.attribute).parameters:
133133
attribution_arguments["baselines"] = baseline
134-
attr = attribution_method.attribute.__wrapped__(
134+
attr = attribution_method.attribute.__wrapped__( # type: ignore
135135
attribution_method, # self
136136
data,
137137
additional_forward_args=additional_forward_args,

captum/insights/attr_vis/features.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from captum._utils.common import safe_div
99
from captum.attr._utils import visualization as viz
1010
from captum.insights.attr_vis._utils.transforms import format_transforms
11-
from torch._tensor import Tensor
11+
from torch import Tensor
1212

1313
FeatureOutput = namedtuple("FeatureOutput", "name base modified type contribution")
1414

captum/insights/attr_vis/server.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import socket
55
import threading
66
from time import sleep
7-
from typing import Optional
7+
from typing import cast, Dict, Optional
88

99
from captum.log import log_usage
1010
from flask import Flask, jsonify, render_template, request
@@ -41,10 +41,10 @@ def namedtuple_to_dict(obj):
4141
def attribute() -> Response:
4242
# force=True needed for Colab notebooks, which doesn't use the correct
4343
# Content-Type header when forwarding requests through the Colab proxy
44-
r = request.get_json(force=True)
44+
r = cast(Dict, request.get_json(force=True))
4545
return jsonify(
4646
namedtuple_to_dict(
47-
visualizer._calculate_attribution_from_cache(
47+
visualizer._calculate_attribution_from_cache( # type: ignore
4848
r["inputIndex"], r["modelIndex"], r["labelIndex"]
4949
)
5050
)
@@ -54,15 +54,15 @@ def attribute() -> Response:
5454
@app.route("/fetch", methods=["POST"])
5555
def fetch() -> Response:
5656
# force=True needed, see comment for "/attribute" route above
57-
visualizer._update_config(request.get_json(force=True))
58-
visualizer_output = visualizer.visualize()
57+
visualizer._update_config(request.get_json(force=True)) # type: ignore
58+
visualizer_output = visualizer.visualize() # type: ignore
5959
clean_output = namedtuple_to_dict(visualizer_output)
6060
return jsonify(clean_output)
6161

6262

6363
@app.route("/init")
6464
def init() -> Response:
65-
return jsonify(visualizer.get_insights_config())
65+
return jsonify(visualizer.get_insights_config()) # type: ignore
6666

6767

6868
@app.route("/")

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[flake8]
22
# E203: black and flake8 disagree on whitespace before ':'
33
# W503: black and flake8 disagree on how to place operators
4-
ignore = E203, W503
4+
ignore = E203, W503, E704
55
max-line-length = 88
66
exclude =
77
build, dist, tutorials, website

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def report(*args):
6767
INSIGHTS_REQUIRES
6868
+ TEST_REQUIRES
6969
+ [
70-
"black==22.3.0",
70+
"black",
7171
"flake8",
7272
"sphinx",
7373
"sphinx-autodoc-typehints",

tests/attr/helpers/conductance_reference.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python3
2-
from typing import Optional, Tuple
2+
from typing import cast, Tuple, Union
33

44
import numpy as np
55
import torch
@@ -10,7 +10,8 @@
1010
from captum.attr._utils.approximation_methods import approximation_parameters
1111
from captum.attr._utils.attribution import LayerAttribution
1212
from captum.attr._utils.common import _reshape_and_sum
13-
from torch._tensor import Tensor
13+
from torch import Tensor
14+
from torch.utils.hooks import RemovableHandle
1415

1516
"""
1617
Note: This implementation of conductance follows the procedure described in the original
@@ -55,7 +56,7 @@ def forward_hook(module, inp, out):
5556
# The hidden layer tensor is assumed to have dimension (num_hidden, ...)
5657
# where the product of the dimensions >= 1 correspond to the total
5758
# number of hidden neurons in the layer.
58-
layer_size = tuple(saved_tensor.size())[1:]
59+
layer_size = tuple(cast(Tensor, saved_tensor).size())[1:]
5960
layer_units = int(np.prod(layer_size))
6061

6162
# Remove unnecessary forward hook.
@@ -101,12 +102,12 @@ def forward_hook_register_back(module, inp, out):
101102
input_grads = torch.autograd.grad(torch.unbind(output), expanded_input)
102103

103104
# Remove backwards hook
104-
back_hook.remove()
105+
cast(RemovableHandle, back_hook).remove()
105106

106107
# Remove duplicates in gradient with respect to hidden layer,
107108
# choose one for each layer_units indices.
108109
output_mid_grads = torch.index_select(
109-
saved_grads,
110+
cast(Tensor, saved_grads),
110111
0,
111112
torch.tensor(range(0, input_grads[0].shape[0], layer_units)),
112113
)
@@ -115,7 +116,7 @@ def forward_hook_register_back(module, inp, out):
115116
def attribute(
116117
self,
117118
inputs,
118-
baselines: Optional[int] = None,
119+
baselines: Union[None, int, Tensor] = None,
119120
target=None,
120121
n_steps: int = 500,
121122
method: str = "riemann_trapezoid",

tests/attr/layer/test_layer_lrp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#!/usr/bin/env python3
2+
# mypy: ignore-errors
23

34
from typing import Any, Tuple
45

@@ -9,7 +10,7 @@
910

1011
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
1112
from tests.helpers.basic_models import BasicModel_ConvNet_One_Conv, SimpleLRPModel
12-
from torch._tensor import Tensor
13+
from torch import Tensor
1314

1415

1516
def _get_basic_config() -> Tuple[BasicModel_ConvNet_One_Conv, Tensor]:

0 commit comments

Comments
 (0)