130 changes: 65 additions & 65 deletions captum/attr/_core/feature_ablation.py
@@ -555,73 +555,9 @@ def attribute_future(
]
)

def eval_fut_to_ablated_out_fut(
# pyre-ignore Invalid type parameters [24]
eval_futs: Future[List[Future[List[object]]]],
current_inputs: Tuple[Tensor, ...],
current_mask: Tensor,
i: int,
perturbations_per_eval: int,
num_examples: int,
formatted_inputs: Tuple[Tensor, ...],
) -> Tuple[List[Tensor], List[Tensor]]:
try:
modified_eval = cast(Tensor, eval_futs.value()[1].value())
initial_eval_tuple = cast(
Tuple[
List[Tensor],
List[Tensor],
Tensor,
Tensor,
int,
dtype,
],
eval_futs.value()[0].value(),
)
if len(initial_eval_tuple) != 6:
raise AssertionError(
"eval_fut_to_ablated_out_fut: "
"initial_eval_tuple should have 6 elements: "
"total_attrib, weights, initial_eval, "
"flattened_initial_eval, n_outputs, attrib_type "
)
if not isinstance(modified_eval, Tensor):
raise AssertionError(
"eval_fut_to_ablated_out_fut: "
"modified eval should be a Tensor"
)
(
total_attrib,
weights,
initial_eval,
flattened_initial_eval,
n_outputs,
attrib_type,
) = initial_eval_tuple
result = self._process_ablated_out( # type: ignore # noqa: E501 line too long
modified_eval=modified_eval,
current_inputs=current_inputs,
current_mask=current_mask,
perturbations_per_eval=perturbations_per_eval,
num_examples=num_examples,
initial_eval=initial_eval,
flattened_initial_eval=flattened_initial_eval,
inputs=formatted_inputs,
n_outputs=n_outputs,
total_attrib=total_attrib,
weights=weights,
i=i,
attrib_type=attrib_type,
)
except FeatureAblationFutureError as e:
raise FeatureAblationFutureError(
"eval_fut_to_ablated_out_fut func failed)"
) from e
return result

ablated_out_fut: Future[Tuple[List[Tensor], List[Tensor]]] = (
eval_futs.then(
- lambda eval_futs, current_inputs=current_inputs, current_mask=current_mask, i=i: eval_fut_to_ablated_out_fut( # type: ignore # noqa: E501 line too long
+ lambda eval_futs, current_inputs=current_inputs, current_mask=current_mask, i=i: self._eval_fut_to_ablated_out_fut( # type: ignore # noqa: E501 line too long
eval_futs=eval_futs,
current_inputs=current_inputs,
current_mask=current_mask,
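
The call site in this hunk chains the ablated-output computation onto the collected evaluation futures with `Future.then`, binding the loop variables (`current_inputs`, `current_mask`, `i`) through lambda default arguments. A minimal, self-contained sketch of that pattern, assuming toy futures and an illustrative `scale` variable that are not part of Captum:

```python
# Minimal sketch of the Future.then + default-argument-capture pattern used
# above; the futures and the `scale` variable are toy stand-ins, not Captum code.
from typing import List

import torch
from torch.futures import Future, collect_all


def make_eval_fut(value: float) -> Future:
    # Pretend this is an asynchronous forward pass returning a Tensor.
    fut: Future = Future()
    fut.set_result(torch.full((3,), value))
    return fut


out_futs: List[Future] = []
for i, scale in enumerate([1.0, 2.0, 3.0]):
    eval_futs = collect_all([make_eval_fut(0.5), make_eval_fut(scale)])
    out_futs.append(
        eval_futs.then(
            # Default arguments freeze i/scale per iteration, mirroring the
            # `current_inputs=current_inputs, ..., i=i` capture in the diff.
            lambda futs, i=i, scale=scale: (i, futs.value()[1].value() * scale)
        )
    )

print([fut.wait() for fut in out_futs])
```

Without the default arguments, every callback would close over the same loop variables and see only their final values once the futures complete.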
@@ -660,6 +596,70 @@ def _attribute_progress_setup(
)
return attr_progress

def _eval_fut_to_ablated_out_fut(
self,
# pyre-ignore Invalid type parameters [24]
eval_futs: Future[List[Future[List[object]]]],
current_inputs: Tuple[Tensor, ...],
current_mask: Tensor,
i: int,
perturbations_per_eval: int,
num_examples: int,
formatted_inputs: Tuple[Tensor, ...],
) -> Tuple[List[Tensor], List[Tensor]]:
try:
modified_eval = cast(Tensor, eval_futs.value()[1].value())
initial_eval_tuple = cast(
Tuple[
List[Tensor],
List[Tensor],
Tensor,
Tensor,
int,
dtype,
],
eval_futs.value()[0].value(),
)
if len(initial_eval_tuple) != 6:
raise AssertionError(
"eval_fut_to_ablated_out_fut: "
"initial_eval_tuple should have 6 elements: "
"total_attrib, weights, initial_eval, "
"flattened_initial_eval, n_outputs, attrib_type "
)
if not isinstance(modified_eval, Tensor):
raise AssertionError(
"eval_fut_to_ablated_out_fut: " "modified eval should be a Tensor"
)
(
total_attrib,
weights,
initial_eval,
flattened_initial_eval,
n_outputs,
attrib_type,
) = initial_eval_tuple
result = self._process_ablated_out( # type: ignore # noqa: E501 line too long
modified_eval=modified_eval,
current_inputs=current_inputs,
current_mask=current_mask,
perturbations_per_eval=perturbations_per_eval,
num_examples=num_examples,
initial_eval=initial_eval,
flattened_initial_eval=flattened_initial_eval,
inputs=formatted_inputs,
n_outputs=n_outputs,
total_attrib=total_attrib,
weights=weights,
i=i,
attrib_type=attrib_type,
)
except FeatureAblationFutureError as e:
raise FeatureAblationFutureError(
"eval_fut_to_ablated_out_fut func failed)"
) from e
return result

# pyre-fixme[3]: Return type must be specified as type that does not contain `Any`
def _ith_input_ablation_generator(
self,
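
The extracted `_eval_fut_to_ablated_out_fut` wraps any `FeatureAblationFutureError` raised while unpacking and processing the collected futures, re-raising with `raise ... from e` so the original cause stays attached. A small sketch of that chaining behavior, using a stand-in error class rather than Captum's:

```python
# Sketch of the `raise ... from e` wrapping used in _eval_fut_to_ablated_out_fut;
# AblationFutureError is a stand-in for Captum's FeatureAblationFutureError.
class AblationFutureError(RuntimeError):
    pass


def process_ablated_out() -> None:
    # Stand-in for the body that validates and unpacks the future results.
    raise AblationFutureError("initial_eval_tuple should have 6 elements")


try:
    try:
        process_ablated_out()
    except AblationFutureError as e:
        raise AblationFutureError("eval_fut_to_ablated_out_fut func failed") from e
except AblationFutureError as wrapped:
    # The original failure stays reachable on __cause__ for whoever awaits
    # the resulting future, e.g. via Future.wait().
    print(wrapped.__cause__)  # -> initial_eval_tuple should have 6 elements
```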