From be10a616fbd21565598ccba6e54bb40f19a249b3 Mon Sep 17 00:00:00 2001 From: Zach Carmichael Date: Thu, 30 Oct 2025 14:52:25 -0700 Subject: [PATCH 1/4] Bump minimum Python version to 3.10 (#1660) Summary: Python 3.9 is EOL. Bump minimum version of Python to 3.10. Reviewed By: sarahtranfb Differential Revision: D85569426 --- .conda/meta.yaml | 2 +- .github/workflows/test-conda-cpu.yml | 2 +- .github/workflows/test-pip-cpu.yml | 2 +- CONTRIBUTING.md | 4 ++-- README.md | 2 +- captum/_utils/typing.py | 18 ++---------------- captum/concept/_core/cav.py | 8 ++------ pyproject.toml | 2 +- setup.py | 2 +- 9 files changed, 12 insertions(+), 30 deletions(-) diff --git a/.conda/meta.yaml b/.conda/meta.yaml index 1b6b1cf5c4..070bc4a31b 100644 --- a/.conda/meta.yaml +++ b/.conda/meta.yaml @@ -13,7 +13,7 @@ build: requirements: host: - - python>=3.9 + - python>=3.10 - setuptools run: - numpy diff --git a/.github/workflows/test-conda-cpu.yml b/.github/workflows/test-conda-cpu.yml index e0da5e42e3..1211029694 100644 --- a/.github/workflows/test-conda-cpu.yml +++ b/.github/workflows/test-conda-cpu.yml @@ -15,7 +15,7 @@ jobs: tests: strategy: matrix: - python_version: ["3.9", "3.10", "3.11", "3.12"] + python_version: ["3.10", "3.11", "3.12"] fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: diff --git a/.github/workflows/test-pip-cpu.yml b/.github/workflows/test-pip-cpu.yml index 42bb3a7080..51075cc463 100644 --- a/.github/workflows/test-pip-cpu.yml +++ b/.github/workflows/test-pip-cpu.yml @@ -14,7 +14,7 @@ jobs: matrix: pytorch_args: ["-v 2.3.0", "-v 2.4.0", "-v 2.5.0", "-v 2.6.0", "-v 2.7.0"] transformers_args: ["-t 4.38.0", "-t 4.39.0", "-t 4.41.0", "-t 4.43.0", "-t 4.45.2"] - docker_img: ["cimg/python:3.9", "cimg/python:3.10", "cimg/python:3.11", "cimg/python:3.12"] + docker_img: ["cimg/python:3.10", "cimg/python:3.11", "cimg/python:3.12"] fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 04642bf1c6..a1f2be883a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -50,7 +50,7 @@ Github Actions will fail on your PR if it does not adhere to the ufmt or flake8 #### Type Hints -Captum is fully typed using Python 3.9+ +Captum is fully typed using Python 3.10+ [type hints](https://www.python.org/dev/peps/pep-0484/). We expect any contributions to also use proper type annotations, and we enforce consistency of these in our continuous integration tests. @@ -63,7 +63,7 @@ Then run this script from the repository root: ``` Note that we expect mypy to have version 0.760 or higher, and when type checking, use PyTorch 1.10 or higher due to fixes to the PyTorch type hints available. We also use the Literal feature which is -available only in Python 3.9 or above. +available only in Python 3.10 or above. We also use [pyre](https://pyre-check.org/) for type checking. For contributors, the nightly version of pyre is used which can be installed with pip `pip install pyre-check-nightly`. To run pyre, you can diff --git a/README.md b/README.md index d7d87fd6dc..c1a49b8691 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ Captum can also be used by application engineers who are using trained models in ## Installation **Installation Requirements** -- Python >= 3.9 +- Python >= 3.10 - PyTorch >= 2.3 diff --git a/captum/_utils/typing.py b/captum/_utils/typing.py index 512c910f08..f0f7427e96 100644 --- a/captum/_utils/typing.py +++ b/captum/_utils/typing.py @@ -3,17 +3,7 @@ # pyre-strict from collections import UserDict -from typing import ( - List, - Literal, - Optional, - overload, - Protocol, - Tuple, - TYPE_CHECKING, - TypeVar, - Union, -) +from typing import List, Literal, Optional, overload, Protocol, Tuple, TypeVar, Union from torch import Tensor from torch.nn import Module @@ -51,11 +41,7 @@ # pyre-ignore[24]: Generic type `slice` expects 3 type parameters. SliceIntType = slice # type: ignore -# Necessary for Python >=3.7 and <3.9! -if TYPE_CHECKING: - BatchEncodingType = UserDict[Union[int, str], object] -else: - BatchEncodingType = UserDict +BatchEncodingType = UserDict[Union[int, str], object] class TokenizerLike(Protocol): diff --git a/captum/concept/_core/cav.py b/captum/concept/_core/cav.py index 9cd5cc3137..8b96056d35 100644 --- a/captum/concept/_core/cav.py +++ b/captum/concept/_core/cav.py @@ -4,7 +4,7 @@ import os from contextlib import AbstractContextManager, nullcontext -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import Any, Dict, List, Optional import numpy as np import torch @@ -168,11 +168,7 @@ def load( cavs_path = CAV.assemble_save_path(cavs_path, model_id, concepts, layer) if os.path.exists(cavs_path): - # Necessary for Python >=3.7 and <3.9! - if TYPE_CHECKING: - ctx: AbstractContextManager[None, None] - else: - ctx: AbstractContextManager + ctx: AbstractContextManager[None, None] if hasattr(torch.serialization, "safe_globals"): safe_globals = [ # pyre-ignore[16]: Module `numpy.core.multiarray` has no attribute diff --git a/pyproject.toml b/pyproject.toml index 9608c4f127..672071496a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,4 +2,4 @@ first_party_detection = false [tool.black] -target-version = ['py39'] +target-version = ['py310'] diff --git a/setup.py b/setup.py index 32126301b1..10dbfb90c7 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ from setuptools import find_packages, setup REQUIRED_MAJOR = 3 -REQUIRED_MINOR = 9 +REQUIRED_MINOR = 10 # Check for python version if sys.version_info < (REQUIRED_MAJOR, REQUIRED_MINOR): From 5a7c79254ac5f316422883b84d604b4489cd167a Mon Sep 17 00:00:00 2001 From: Zach Carmichael Date: Thu, 30 Oct 2025 14:52:25 -0700 Subject: [PATCH 2/4] Attribution API refactor: Base LLMAttributionResult class + refactor (#1657) Summary: Refactor LLMAttributionResult into an abstract base object that is generic. Create LLMAttributionResult as a concrete child with aliases for captum.attr API supporting legacy use. Changes support the refactor and enable more generalized use beyond logprob-based attribution. Differential Revision: D84721127 --- captum/attr/_core/llm_attr.py | 263 +++++++++++++++++++++++----------- 1 file changed, 179 insertions(+), 84 deletions(-) diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py index c7422d8d92..29dd38647e 100644 --- a/captum/attr/_core/llm_attr.py +++ b/captum/attr/_core/llm_attr.py @@ -5,18 +5,20 @@ from abc import ABC from copy import copy from dataclasses import dataclass -from textwrap import dedent, shorten +from textwrap import shorten from typing import ( Any, Callable, cast, Dict, + Generic, List, Optional, Tuple, Type, TYPE_CHECKING, + TypeVar, Union, ) @@ -56,130 +58,138 @@ "temperature": None, "top_p": None, } +TInputValue = TypeVar("TInputValue") +TTargetValue = TypeVar("TTargetValue") -@dataclass -class LLMAttributionResult: +@dataclass(kw_only=True) +class BaseLLMAttributionResult(ABC, Generic[TInputValue, TTargetValue]): """ Data class for the return result of LLMAttribution, which includes the necessary properties of the attribution. It also provides utilities to help present and plot the result in different forms. """ - input_tokens: List[str] - output_tokens: List[str] - # pyre-ignore[13]: initialized via a property setter - _seq_attr: Tensor - _token_attr: Optional[Tensor] = None - _output_probs: Optional[Tensor] = None + input_values: List[TInputValue] # ablated values + target_names: List[str] # names of each target, e.g. judge name or tokens + _target_values: Optional[ + List[TTargetValue] + ] # value for each target name e.g. token prob + _aggregate_attr: Tensor # 1D [# input_values] + _element_attr: Optional[Tensor] = None # 2D [# target_names, # input_values] + aggregate_descriptor: str = "Aggregate" + element_descriptor: str = "Element" def __init__( self, *, - input_tokens: List[str], - output_tokens: List[str], - seq_attr: npt.ArrayLike, - token_attr: Optional[npt.ArrayLike] = None, - output_probs: Optional[npt.ArrayLike] = None, + input_values: List[TInputValue], + target_names: List[str], + target_values: Optional[npt.ArrayLike] = None, + aggregate_attr: npt.ArrayLike, + element_attr: Optional[npt.ArrayLike] = None, + aggregate_descriptor: str = "Aggregate", + element_descriptor: str = "Element", ) -> None: - self.input_tokens = input_tokens - self.output_tokens = output_tokens - self.seq_attr = seq_attr - self.token_attr = token_attr - self.output_probs = output_probs + self.input_values = input_values + self.target_names = target_names + self.target_values = target_values + self.aggregate_attr = aggregate_attr + self.element_attr = element_attr + self.aggregate_descriptor = aggregate_descriptor + self.element_descriptor = element_descriptor @property - def seq_attr(self) -> Tensor: - return self._seq_attr + def aggregate_attr(self) -> Tensor: + return self._aggregate_attr - @seq_attr.setter - def seq_attr(self, seq_attr: npt.ArrayLike) -> None: + @aggregate_attr.setter + def aggregate_attr(self, seq_attr: npt.ArrayLike) -> None: if isinstance(seq_attr, Tensor): - self._seq_attr = seq_attr + self._aggregate_attr = seq_attr else: - self._seq_attr = torch.tensor(seq_attr) + self._aggregate_attr = torch.tensor(seq_attr) # IDEA: in the future we might want to support higher dim seq_attr # (e.g. attention w.r.t. multiple layers, gradients w.r.t. different classes) - assert len(self._seq_attr.shape) == 1, "seq_attr must be a 1D tensor" + assert len(self._aggregate_attr.shape) == 1, "seq_attr must be a 1D tensor" assert ( - len(self.input_tokens) == self._seq_attr.shape[0] + len(self.input_values) == self._aggregate_attr.shape[0] ), "seq_attr and input_tokens must have the same length" @property - def token_attr(self) -> Optional[Tensor]: - return self._token_attr + def element_attr(self) -> Optional[Tensor]: + return self._element_attr - @token_attr.setter - def token_attr(self, token_attr: Optional[npt.ArrayLike]) -> None: + @element_attr.setter + def element_attr(self, token_attr: Optional[npt.ArrayLike]) -> None: if token_attr is None: - self._token_attr = None + self._element_attr = None elif isinstance(token_attr, Tensor): - self._token_attr = token_attr + self._element_attr = token_attr else: - self._token_attr = torch.tensor(token_attr) + self._element_attr = torch.tensor(token_attr) - if self._token_attr is not None: + if self._element_attr is not None: # IDEA: in the future we might want to support higher dim seq_attr - assert len(self._token_attr.shape) == 2, "token_attr must be a 2D tensor" - assert self._token_attr.shape == ( - len(self.output_tokens), - len(self.input_tokens), - ), dedent( - f"""\ - Expect token_attr to have shape - {len(self.output_tokens), len(self.input_tokens)}, - got {self._token_attr.shape} - """ + assert len(self._element_attr.shape) == 2, "token_attr must be a 2D tensor" + assert self._element_attr.shape == ( + len(self.target_names), + len(self.input_values), + ), ( + "Expect token_attr to have shape " + f"({len(self.target_names), len(self.input_values)}), " + f"got {self._element_attr.shape}" ) @property - def output_probs(self) -> Optional[Tensor]: - return self._output_probs - - @output_probs.setter - def output_probs(self, output_probs: Optional[npt.ArrayLike]) -> None: - if output_probs is None: - self._output_probs = None - elif isinstance(output_probs, Tensor): - self._output_probs = output_probs + def target_values(self) -> Optional[List[TTargetValue]]: + return self._target_values + + @target_values.setter + def target_values(self, target_values: Optional[npt.ArrayLike]) -> None: + if target_values is None: + self._target_values = None + elif isinstance(target_values, (Tensor, np.ndarray)): + self._target_values = target_values.tolist() else: - self._output_probs = torch.tensor(output_probs) + # pyre-ignore[6]: should be iterable + self._target_values = list(target_values) - if self._output_probs is not None: - assert ( - len(self._output_probs.shape) == 1 - ), "output_probs must be a 1D tensor" - assert ( - len(self.output_tokens) == self._output_probs.shape[0] - ), "seq_attr and input_tokens must have the same length" + if self._target_values is not None: + assert len(self._target_values) == len( + self.target_names + ), f"{len(self._target_values)=} and {len(self.target_names)=} must have the same length" @property - def seq_attr_dict(self) -> Dict[str, float]: - return {k: v for v, k in zip(self.seq_attr.cpu().tolist(), self.input_tokens)} + def aggregate_attr_dict(self) -> Dict[TInputValue, float]: + return { + k: v for v, k in zip(self.aggregate_attr.cpu().tolist(), self.input_values) + } - def plot_token_attr( + def plot_element_attr( self, show: bool = False ) -> Union[None, Tuple["Figure", "Axes"]]: """ Generate a matplotlib plot for visualising the attribution - of the output tokens. + of the output elements. Args: show (bool): whether to show the plot directly or return the figure and axis Default: False """ - if self.token_attr is None: + if self.element_attr is None: raise ValueError( - "token_attr is None (no token-level attribution was performed), please " - "use plot_seq_attr instead for the sequence-level attribution plot" + f"element_attr is None (no {self.element_descriptor.lower()}-level attribution was " + "performed), please use plot_aggregate_attr instead for the " + f"{self.aggregate_descriptor}-level attribution plot" ) - token_attr = self.token_attr.cpu() + element_attr = self.element_attr.cpu() # maximum absolute attribution value # used as the boundary of normalization # always keep 0 as the mid point to differentiate pos/neg attr - max_abs_attr_val = token_attr.abs().max().item() + max_abs_attr_val = element_attr.abs().max().item() import matplotlib.pyplot as plt @@ -189,7 +199,7 @@ def plot_token_attr( ax.grid(False) # Plot the heatmap - data = token_attr.numpy() + data = element_attr.numpy() fig.set_size_inches( max(data.shape[1] * 1.3, 6.4), max(data.shape[0] / 2.5, 4.8) @@ -219,17 +229,19 @@ def plot_token_attr( # Create colorbar cbar = fig.colorbar(im, ax=ax) # type: ignore - cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom") + cbar.ax.set_ylabel( + f"{self.element_descriptor} Attribution", rotation=-90, va="bottom" + ) # Show all ticks and label them with the respective list entries. - shortened_tokens = [ + shortened_values = [ shorten(repr(t)[1:-1], width=50, placeholder="...") - for t in self.input_tokens + for t in self.input_values ] - ax.set_xticks(np.arange(data.shape[1]), labels=shortened_tokens) + ax.set_xticks(np.arange(data.shape[1]), labels=shortened_values) ax.set_yticks( np.arange(data.shape[0]), - labels=[repr(token)[1:-1] for token in self.output_tokens], + labels=[repr(name)[1:-1] for name in self.target_names], ) # Let the horizontal axes labeling appear on top. @@ -259,10 +271,12 @@ def plot_token_attr( else: return fig, ax - def plot_seq_attr(self, show: bool = False) -> Union[None, Tuple["Figure", "Axes"]]: + def plot_aggregated_attr( + self, show: bool = False + ) -> Union[None, Tuple["Figure", "Axes"]]: """ Generate a matplotlib plot for visualising the attribution - of the output sequence. + of the aggregated output. Args: show (bool): whether to show the plot directly or return the figure and axis @@ -273,15 +287,15 @@ def plot_seq_attr(self, show: bool = False) -> Union[None, Tuple["Figure", "Axes fig, ax = plt.subplots() - data = self.seq_attr.cpu().numpy() + data = self.aggregate_attr.cpu().numpy() fig.set_size_inches(max(data.shape[0] / 2, 6.4), max(data.shape[0] / 4, 4.8)) - shortened_tokens = [ + shortened_values = [ shorten(repr(t)[1:-1], width=50, placeholder="...") - for t in self.input_tokens + for t in self.input_values ] - ax.set_xticks(range(data.shape[0]), labels=shortened_tokens) + ax.set_xticks(range(data.shape[0]), labels=shortened_values) ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False) @@ -309,7 +323,9 @@ def plot_seq_attr(self, show: bool = False) -> Union[None, Tuple["Figure", "Axes color="#d0365b", ) - ax.set_ylabel("Sequence Attribution", rotation=90, va="bottom") + ax.set_ylabel( + f"{self.aggregate_descriptor} Attribution", rotation=90, va="bottom" + ) if show: plt.show() @@ -317,6 +333,85 @@ def plot_seq_attr(self, show: bool = False) -> Union[None, Tuple["Figure", "Axes else: return fig, ax + # Aliases + + @property + def input_tokens(self) -> List[TInputValue]: + return self.input_values + + @input_tokens.setter + def input_tokens(self, input_tokens: List[TInputValue]) -> None: + self.input_values = input_tokens + + @property + def output_tokens(self) -> List[str]: + return self.target_names + + @output_tokens.setter + def output_tokens(self, output_tokens: List[str]) -> None: + self.target_names = output_tokens + + @property + def output_probs(self) -> Optional[List[TTargetValue]]: + return self.target_values + + @output_probs.setter + def output_probs(self, output_probs: Optional[npt.ArrayLike]) -> None: + self.target_values = output_probs + + @property + def seq_attr(self) -> Tensor: + return self.aggregate_attr + + @seq_attr.setter + def seq_attr(self, seq_attr: npt.ArrayLike) -> None: + self.aggregate_attr = seq_attr + + @property + def token_attr(self) -> Optional[Tensor]: + return self.element_attr + + @token_attr.setter + def token_attr(self, token_attr: Optional[npt.ArrayLike]) -> None: + self.element_attr = token_attr + + @property + def seq_attr_dict(self) -> Dict[TInputValue, float]: + return self.aggregate_attr_dict + + def plot_token_attr( + self, show: bool = False + ) -> Union[None, Tuple["Figure", "Axes"]]: + return self.plot_element_attr(show=show) + + def plot_seq_attr(self, show: bool = False) -> Union[None, Tuple["Figure", "Axes"]]: + return self.plot_aggregated_attr(show=show) + + +@dataclass(kw_only=True) +# pyre-ignore[13]: _aggregate_attr and _target_values initialized via setters +class LLMAttributionResult(BaseLLMAttributionResult[str, float]): + """LLM Attribution Result for the captum.attr API""" + + def __init__( + self, + *, + input_tokens: List[str], + output_tokens: List[str], + seq_attr: npt.ArrayLike, + token_attr: Optional[npt.ArrayLike] = None, + output_probs: Optional[npt.ArrayLike] = None, + ) -> None: + super().__init__( + input_values=input_tokens, + target_names=output_tokens, + target_values=output_probs, + aggregate_attr=seq_attr, + element_attr=token_attr, + aggregate_descriptor="Sequence", + element_descriptor="Token", + ) + def _clean_up_pretty_token(token: str) -> str: """Remove newlines and leading/trailing whitespace from token.""" From f8418df654434c56a5482c2133227cf92a45de7d Mon Sep 17 00:00:00 2001 From: Zach Carmichael Date: Thu, 30 Oct 2025 14:52:25 -0700 Subject: [PATCH 3/4] Attribution API refactor: Update LLM attr typing + minor naming (#1658) Summary: Update LLM attr definition to accommodate other typing considerations. Clean up some variable names as well. Reviewed By: jimshao1999 Differential Revision: D84721071 --- captum/attr/_core/llm_attr.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py index 29dd38647e..185483fa71 100644 --- a/captum/attr/_core/llm_attr.py +++ b/captum/attr/_core/llm_attr.py @@ -85,7 +85,7 @@ def __init__( *, input_values: List[TInputValue], target_names: List[str], - target_values: Optional[npt.ArrayLike] = None, + target_values: Optional[Union[npt.ArrayLike, List[TTargetValue]]] = None, aggregate_attr: npt.ArrayLike, element_attr: Optional[npt.ArrayLike] = None, aggregate_descriptor: str = "Aggregate", @@ -104,11 +104,11 @@ def aggregate_attr(self) -> Tensor: return self._aggregate_attr @aggregate_attr.setter - def aggregate_attr(self, seq_attr: npt.ArrayLike) -> None: - if isinstance(seq_attr, Tensor): - self._aggregate_attr = seq_attr + def aggregate_attr(self, aggregate_attr: npt.ArrayLike) -> None: + if isinstance(aggregate_attr, Tensor): + self._aggregate_attr = aggregate_attr else: - self._aggregate_attr = torch.tensor(seq_attr) + self._aggregate_attr = torch.tensor(aggregate_attr) # IDEA: in the future we might want to support higher dim seq_attr # (e.g. attention w.r.t. multiple layers, gradients w.r.t. different classes) assert len(self._aggregate_attr.shape) == 1, "seq_attr must be a 1D tensor" @@ -121,13 +121,13 @@ def element_attr(self) -> Optional[Tensor]: return self._element_attr @element_attr.setter - def element_attr(self, token_attr: Optional[npt.ArrayLike]) -> None: - if token_attr is None: + def element_attr(self, element_attr: Optional[npt.ArrayLike]) -> None: + if element_attr is None: self._element_attr = None - elif isinstance(token_attr, Tensor): - self._element_attr = token_attr + elif isinstance(element_attr, Tensor): + self._element_attr = element_attr else: - self._element_attr = torch.tensor(token_attr) + self._element_attr = torch.tensor(element_attr) if self._element_attr is not None: # IDEA: in the future we might want to support higher dim seq_attr @@ -146,7 +146,9 @@ def target_values(self) -> Optional[List[TTargetValue]]: return self._target_values @target_values.setter - def target_values(self, target_values: Optional[npt.ArrayLike]) -> None: + def target_values( + self, target_values: Optional[Union[npt.ArrayLike, List[TTargetValue]]] + ) -> None: if target_values is None: self._target_values = None elif isinstance(target_values, (Tensor, np.ndarray)): From a8da853f175642706b32c7ef894a862d82070215 Mon Sep 17 00:00:00 2001 From: Zach Carmichael Date: Thu, 30 Oct 2025 14:52:25 -0700 Subject: [PATCH 4/4] Attribution API refactor: Introduce an optional agg/element-wise variance to LLM attribution results (#1659) Summary: As title. It is possible for attr to be computed as an estimated amount over multiple samples of the response, so the estimate has variance. This adds an attribute to store this variance in the results, if we have it. Differential Revision: D84970183 --- captum/attr/_core/llm_attr.py | 95 +++++++++++++++++++++++++++++++---- 1 file changed, 85 insertions(+), 10 deletions(-) diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py index 185483fa71..b39c54fda7 100644 --- a/captum/attr/_core/llm_attr.py +++ b/captum/attr/_core/llm_attr.py @@ -14,7 +14,9 @@ Dict, Generic, List, + Literal, Optional, + overload, Tuple, Type, TYPE_CHECKING, @@ -62,6 +64,26 @@ TTargetValue = TypeVar("TTargetValue") +@overload +def _to_tensor( + name: str, arr: Optional[npt.ArrayLike], none_ok: Literal[True] = ... +) -> Optional[Tensor]: ... +@overload +def _to_tensor( + name: str, arr: Optional[npt.ArrayLike], none_ok: Literal[False] = ... +) -> Tensor: ... +def _to_tensor( + name: str, arr: Optional[npt.ArrayLike], none_ok: bool = False +) -> Optional[Tensor]: + if arr is None: + if none_ok: + return None + raise TypeError(f"Expected array-like for `{name}` but received None!") + if not isinstance(arr, Tensor): + arr = torch.tensor(arr) + return arr + + @dataclass(kw_only=True) class BaseLLMAttributionResult(ABC, Generic[TInputValue, TTargetValue]): """ @@ -77,6 +99,8 @@ class BaseLLMAttributionResult(ABC, Generic[TInputValue, TTargetValue]): ] # value for each target name e.g. token prob _aggregate_attr: Tensor # 1D [# input_values] _element_attr: Optional[Tensor] = None # 2D [# target_names, # input_values] + _aggregate_attr_var: Optional[Tensor] = None # 1D [# input_values] + _element_attr_var: Optional[Tensor] = None # 2D [# target_names, # input_values] aggregate_descriptor: str = "Aggregate" element_descriptor: str = "Element" @@ -88,6 +112,8 @@ def __init__( target_values: Optional[Union[npt.ArrayLike, List[TTargetValue]]] = None, aggregate_attr: npt.ArrayLike, element_attr: Optional[npt.ArrayLike] = None, + aggregate_attr_var: Optional[npt.ArrayLike] = None, + element_attr_var: Optional[npt.ArrayLike] = None, aggregate_descriptor: str = "Aggregate", element_descriptor: str = "Element", ) -> None: @@ -96,6 +122,8 @@ def __init__( self.target_values = target_values self.aggregate_attr = aggregate_attr self.element_attr = element_attr + self.aggregate_attr_var = aggregate_attr_var + self.element_attr_var = element_attr_var self.aggregate_descriptor = aggregate_descriptor self.element_descriptor = element_descriptor @@ -105,10 +133,9 @@ def aggregate_attr(self) -> Tensor: @aggregate_attr.setter def aggregate_attr(self, aggregate_attr: npt.ArrayLike) -> None: - if isinstance(aggregate_attr, Tensor): - self._aggregate_attr = aggregate_attr - else: - self._aggregate_attr = torch.tensor(aggregate_attr) + self._aggregate_attr = _to_tensor( + "aggregate_attr", aggregate_attr, none_ok=False + ) # IDEA: in the future we might want to support higher dim seq_attr # (e.g. attention w.r.t. multiple layers, gradients w.r.t. different classes) assert len(self._aggregate_attr.shape) == 1, "seq_attr must be a 1D tensor" @@ -122,12 +149,7 @@ def element_attr(self) -> Optional[Tensor]: @element_attr.setter def element_attr(self, element_attr: Optional[npt.ArrayLike]) -> None: - if element_attr is None: - self._element_attr = None - elif isinstance(element_attr, Tensor): - self._element_attr = element_attr - else: - self._element_attr = torch.tensor(element_attr) + self._element_attr = _to_tensor("element_attr", element_attr, none_ok=True) if self._element_attr is not None: # IDEA: in the future we might want to support higher dim seq_attr @@ -141,6 +163,39 @@ def element_attr(self, element_attr: Optional[npt.ArrayLike]) -> None: f"got {self._element_attr.shape}" ) + @property + def aggregate_attr_var(self) -> Optional[Tensor]: + return self._aggregate_attr_var + + @aggregate_attr_var.setter + def aggregate_attr_var(self, aggregate_attr_var: Optional[npt.ArrayLike]) -> None: + self._aggregate_attr_var = _to_tensor( + "aggregate_attr_var", aggregate_attr_var, none_ok=True + ) + if self._aggregate_attr_var is not None: + assert self._aggregate_attr_var.shape == self._aggregate_attr.shape, ( + f"aggregate_attr ({self._aggregate_attr.shape}) must have same shape " + f"as aggregate_attr_var ({self._aggregate_attr_var.shape})" + ) + + @property + def element_attr_var(self) -> Optional[Tensor]: + return self._element_attr_var + + @element_attr_var.setter + def element_attr_var(self, element_attr_var: Optional[npt.ArrayLike]) -> None: + self._element_attr_var = _to_tensor( + "element_attr_var", element_attr_var, none_ok=True + ) + if self._element_attr_var is not None: + assert ( + self._element_attr is not None + ), "element_attr must be set before setting element_attr_var" + assert self._element_attr_var.shape == self._element_attr.shape, ( + f"element_attr ({self._element_attr.shape}) must have same shape " + f"as element_attr_var ({self._element_attr_var.shape})" + ) + @property def target_values(self) -> Optional[List[TTargetValue]]: return self._target_values @@ -377,6 +432,22 @@ def token_attr(self) -> Optional[Tensor]: def token_attr(self, token_attr: Optional[npt.ArrayLike]) -> None: self.element_attr = token_attr + @property + def seq_attr_var(self) -> Optional[Tensor]: + return self.aggregate_attr_var + + @seq_attr_var.setter + def seq_attr_var(self, seq_attr_var: Optional[npt.ArrayLike]) -> None: + self.aggregate_attr_var = seq_attr_var + + @property + def token_attr_var(self) -> Optional[Tensor]: + return self.element_attr_var + + @token_attr_var.setter + def token_attr_var(self, token_attr_var: Optional[npt.ArrayLike]) -> None: + self.element_attr_var = token_attr_var + @property def seq_attr_dict(self) -> Dict[TInputValue, float]: return self.aggregate_attr_dict @@ -402,6 +473,8 @@ def __init__( output_tokens: List[str], seq_attr: npt.ArrayLike, token_attr: Optional[npt.ArrayLike] = None, + seq_attr_var: Optional[npt.ArrayLike] = None, + token_attr_var: Optional[npt.ArrayLike] = None, output_probs: Optional[npt.ArrayLike] = None, ) -> None: super().__init__( @@ -410,6 +483,8 @@ def __init__( target_values=output_probs, aggregate_attr=seq_attr, element_attr=token_attr, + aggregate_attr_var=seq_attr_var, + element_attr_var=token_attr_var, aggregate_descriptor="Sequence", element_descriptor="Token", )