diff --git a/kvpress/presses/expected_attention_press.py b/kvpress/presses/expected_attention_press.py index 3b1695e7..c9ff0636 100644 --- a/kvpress/presses/expected_attention_press.py +++ b/kvpress/presses/expected_attention_press.py @@ -46,7 +46,7 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): if hasattr(module, "q_proj"): Wq = module.q_proj.weight elif hasattr(module, "qkv_proj"): - Wq = module.qkv_proj.weight[: n * d] + Wq = module.qkv_proj.weight[: n * d] # type: ignore[index] else: raise NotImplementedError(f"ExpectedAttentionPress not yet implemented for {module.__class__}.") diff --git a/kvpress/presses/per_layer_compression_press.py b/kvpress/presses/per_layer_compression_press.py index 80c6db77..31e87d61 100644 --- a/kvpress/presses/per_layer_compression_press.py +++ b/kvpress/presses/per_layer_compression_press.py @@ -35,8 +35,8 @@ def __post_init__(self): assert isinstance(self.press, ScorerPress), "PerLayerCompressionPress requires a ScorerPress as input" def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list): - original_compression_ratio = self.press.compression_ratio # type:ignore[attr-defined] - self.press.compression_ratio = self.compression_ratios[module.layer_idx] # type:ignore[attr-defined] + original_compression_ratio = self.press.compression_ratio # type:ignore[index] + self.press.compression_ratio = self.compression_ratios[module.layer_idx] # type:ignore[index] output = self.press.forward_hook(module, input, kwargs, output) self.press.compression_ratio = original_compression_ratio # type:ignore[attr-defined] return output diff --git a/pyproject.toml b/pyproject.toml index 922dd042..d2f56828 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ pytest = "^7.0.0" flake8 = "^7.0.0" isort = "^5.13.2" black = "^24.8.0" -mypy = "^1.11.2" +mypy = "^1.13.0" pytest-cov = "^5.0.0" pytest-dependency = "^0.6.0" pytest-html = ">=4.1.1, <5.0.0" @@ -44,7 +44,7 @@ build-backend = "poetry.core.masonry.api" [tool.black] line-length = 120 target_version = ["py310"] -exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|venv|doc-venv|.svn|_build|buck-out|build|dist|notebooks|tools|tmp|bundles)" +exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|venv|.venv|doc-venv|.svn|_build|buck-out|build|dist|notebooks|tools|tmp|bundles)" [tool.isort] multi_line_output = 3 @@ -53,9 +53,11 @@ force_grid_wrap = 0 use_parentheses = true ensure_newline_before_comments = true line_length = 120 +skip = ["venv", ".venv"] [tool.mypy] ignore_missing_imports = true allow_redefinition = true strict_optional = false -exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|venv|doc-venv|.svn|_build|buck-out|build|dist|notebooks|tools|tmp|tests|bundles)" +exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|venv|.venv|doc-venv|.svn|_build|buck-out|build|dist|notebooks|tools|tmp|tests|bundles)" +disable_error_code = ["union-attr", "operator", "call-overload", "arg-type"] \ No newline at end of file