Skip to content

Commit f4663ce

Browse files
aobo-y authored and facebook-github-bot committed
fix LLM attribution related mypy issues (#1200)
Summary: as title. Reviewed By: vivekmig. Differential Revision: D50715178.
1 parent 52ebc12 commit f4663ce

File tree

3 files changed

+41
-22
lines changed

3 files changed

+41
-22
lines changed

captum/attr/_core/llm_attr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def _format_model_input(self, model_input: Union[str, Tensor]):
153153
raw text and text token tensors
154154
"""
155155
# return tensor(1, n_tokens)
156-
if type(model_input) is str:
156+
if isinstance(model_input, str):
157157
return self.tokenizer.encode(model_input, return_tensors="pt").to(
158158
self.device
159159
)

captum/attr/_utils/interpretable_input.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -210,14 +210,24 @@ def __init__(
210210
if baselines is None:
211211
# default baseline is to remove the element
212212
baselines = [""] * len(values)
213-
elif dict_keys:
214-
assert isinstance(baselines, dict), (
215-
"if values is dict, the baselines must also be a dict, "
216-
f"received: {type(baselines)}"
217-
)
213+
elif not callable(baselines):
214+
if dict_keys:
215+
assert isinstance(baselines, dict), (
216+
"if values is a dict, the baselines must also be a dict "
217+
"or a callable which return a dict, "
218+
f"received: {type(baselines)}"
219+
)
218220

219-
# convert dict to list
220-
baselines = [baselines[k] for k in self.dict_keys]
221+
# convert dict to list
222+
baselines = [baselines[k] for k in dict_keys]
223+
else:
224+
assert isinstance(baselines, list), (
225+
"if values is a list, the baselines must also be a list "
226+
"or a callable which return a list, "
227+
f"received: {type(baselines)}"
228+
)
229+
230+
self.baselines = baselines
221231

222232
if mask is None:
223233
n_itp_features = n_features
@@ -247,14 +257,13 @@ def __init__(
247257
if isinstance(template, str):
248258
template = template.format
249259
else:
250-
assert isinstance(template, Callable), (
260+
assert callable(template), (
251261
"the template must be either a string or a callable, "
252262
f"received: {type(template)}"
253263
)
254264
template = template
255265
self.format_fn = template
256266

257-
self.baselines = baselines
258267
self.mask = mask
259268

260269
def to_tensor(self) -> torch.Tensor:
@@ -265,13 +274,23 @@ def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> str:
265274
values = list(self.values) # clone
266275

267276
if perturbed_tensor is not None:
268-
baselines = self.baselines
269-
if isinstance(baselines, Callable):
277+
if callable(self.baselines):
270278
# a placeholder for advanced baselines
271279
# TODO: support callable baselines
272280
baselines = self.baselines()
273281
if self.dict_keys:
282+
assert isinstance(baselines, dict), (
283+
"if values is a dict and the baselines is a callable"
284+
f"it must return a dict, received: {type(baselines)}"
285+
)
274286
baselines = [baselines[k] for k in self.dict_keys]
287+
else:
288+
assert isinstance(baselines, list), (
289+
"if values is a list and the baselines is a callable"
290+
f"it must return a list, received: {type(baselines)}"
291+
)
292+
else:
293+
baselines = self.baselines
275294

276295
for i in range(len(values)):
277296
itp_idx = i
@@ -284,8 +303,8 @@ def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> str:
284303
values[i] = baselines[i]
285304

286305
if self.dict_keys:
287-
values = dict(zip(self.dict_keys, values))
288-
input_str = self.format_fn(**values)
306+
dict_values = dict(zip(self.dict_keys, values))
307+
input_str = self.format_fn(**dict_values)
289308
else:
290309
input_str = self.format_fn(*values)
291310

tests/attr/test_interpretable_input.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ class TestTextTemplateInput(BaseTest):
1616
),
1717
]
1818
)
19-
def test_input(self, template, inputs) -> None:
20-
tt_input = TextTemplateInput(template, inputs)
19+
def test_input(self, template, values) -> None:
20+
tt_input = TextTemplateInput(template, values)
2121

2222
expected_tensor = torch.tensor([[1.0] * 4])
2323
assertTensorAlmostEqual(self, tt_input.to_tensor(), expected_tensor)
@@ -37,11 +37,11 @@ def test_input(self, template, inputs) -> None:
3737
),
3838
]
3939
)
40-
def test_input_with_baselines(self, template, inputs, baselines) -> None:
40+
def test_input_with_baselines(self, template, values, baselines) -> None:
4141
perturbed_tensor = torch.tensor([[1.0, 0.0, 1.0, 0.0]])
4242

4343
# single instance baselines
44-
tt_input = TextTemplateInput(template, inputs, baselines=baselines)
44+
tt_input = TextTemplateInput(template, values, baselines=baselines)
4545
self.assertEqual(tt_input.to_model_input(perturbed_tensor), "a b x d e z")
4646

4747
@parameterized.expand(
@@ -54,8 +54,8 @@ def test_input_with_baselines(self, template, inputs, baselines) -> None:
5454
),
5555
]
5656
)
57-
def test_input_with_mask(self, template, inputs, mask) -> None:
58-
tt_input = TextTemplateInput(template, inputs, mask=mask)
57+
def test_input_with_mask(self, template, values, mask) -> None:
58+
tt_input = TextTemplateInput(template, values, mask=mask)
5959

6060
expected_tensor = torch.tensor([[1.0] * 2])
6161
assertTensorAlmostEqual(self, tt_input.to_tensor(), expected_tensor)
@@ -75,8 +75,8 @@ def test_input_with_mask(self, template, inputs, mask) -> None:
7575
),
7676
]
7777
)
78-
def test_format_attr(self, template, inputs, mask) -> None:
79-
tt_input = TextTemplateInput(template, inputs, mask=mask)
78+
def test_format_attr(self, template, values, mask) -> None:
79+
tt_input = TextTemplateInput(template, values, mask=mask)
8080

8181
attr = torch.tensor([[0.1, 0.2]])
8282

0 commit comments

Comments
 (0)