Skip to content

Commit

Permalink
Pyright type fixes (partial) (#4411)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssbarnea authored Nov 16, 2024
1 parent 2e5af42 commit e62ef52
Show file tree
Hide file tree
Showing 12 changed files with 57 additions and 48 deletions.
2 changes: 1 addition & 1 deletion plugins/modules/fake_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

def main() -> None:
"""Return the module instance."""
return AnsibleModule(
AnsibleModule(
argument_spec={
"data": {"default": None},
"path": {"default": None},
Expand Down
4 changes: 2 additions & 2 deletions src/ansiblelint/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from typing import TYPE_CHECKING, Any, TextIO

from ansible_compat.prerun import get_cache_dir
from filelock import FileLock, Timeout
from filelock import BaseFileLock, FileLock, Timeout
from rich.markup import escape

from ansiblelint.constants import RC, SKIP_SCHEMA_UPDATE
Expand Down Expand Up @@ -119,7 +119,7 @@ def initialize_logger(level: int = 0) -> None:
_logger.debug("Logging initialized to level %s", logging_level)


def initialize_options(arguments: list[str] | None = None) -> None | FileLock:
def initialize_options(arguments: list[str] | None = None) -> None | BaseFileLock:
"""Load config options and store them inside options module."""
cache_dir_lock = None
new_options = cli.get_config(arguments or [])
Expand Down
4 changes: 2 additions & 2 deletions src/ansiblelint/formatters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, TypeVar

import rich
from rich.markup import escape

from ansiblelint.config import options
from ansiblelint.version import __version__
Expand Down Expand Up @@ -58,7 +58,7 @@ def apply(self, match: MatchError) -> str:
@staticmethod
def escape(text: str) -> str:
"""Escapes a string to avoid processing it as markup."""
return rich.markup.escape(text)
return escape(text)


class Formatter(BaseFormatter): # type: ignore[type-arg]
Expand Down
2 changes: 1 addition & 1 deletion src/ansiblelint/rules/no_tabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,4 @@ def test_no_tabs_rule(default_rules_collection: RulesCollection) -> None:
assert len(results) >= i + 1
assert results[i].lineno == expected[0]
assert results[i].message == expected[1]
assert len(results) == len(expected), results
assert len(results) == len(expected_results), results
4 changes: 3 additions & 1 deletion src/ansiblelint/rules/var_naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,10 +329,12 @@ def matchyaml(self, file: Lintable) -> list[MatchError]:
"""Return matches for variables defined in vars files."""
results: list[MatchError] = []
raw_results: list[MatchError] = []
meta_data: dict[AnsibleUnicode, Any] = {}

if str(file.kind) == "vars" and file.data:
meta_data = parse_yaml_from_file(str(file.path))
if not isinstance(meta_data, dict):
msg = f"Content if vars file {file} is not a dictionary."
raise TypeError(msg)
for key in meta_data:
prefix = Prefix(file.role) if file.role else Prefix()
match_error = self.get_var_naming_matcherror(
Expand Down
15 changes: 9 additions & 6 deletions src/ansiblelint/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@
from ansiblelint.app import App, get_app
from ansiblelint.constants import States
from ansiblelint.errors import LintWarning, MatchError, WarnSource
from ansiblelint.file_utils import Lintable, expand_dirs_in_lintables
from ansiblelint.file_utils import (
Lintable,
expand_dirs_in_lintables,
expand_paths_vars,
normpath,
)
from ansiblelint.logger import timed_info
from ansiblelint.rules.syntax_check import OUTPUT_PATTERNS
from ansiblelint.text import strip_ansi_escape
Expand Down Expand Up @@ -111,7 +116,7 @@ def __init__(
def _update_exclude_paths(self, exclude_paths: list[str]) -> None:
if exclude_paths:
# These will be (potentially) relative paths
paths = ansiblelint.file_utils.expand_paths_vars(exclude_paths)
paths = expand_paths_vars(exclude_paths)
# Since ansiblelint.utils.find_children returns absolute paths,
# and the list of files we create in `Runner.run` can contain both
# relative and absolute paths, we need to cover both bases.
Expand Down Expand Up @@ -169,7 +174,6 @@ def run(self) -> list[MatchError]:
if warn.category is DeprecationWarning:
continue
if warn.category is LintWarning:
filename: None | Lintable = None
if isinstance(warn.source, WarnSource):
match = MatchError(
message=warn.source.message or warn.category.__name__,
Expand All @@ -179,13 +183,12 @@ def run(self) -> list[MatchError]:
lineno=warn.source.lineno,
)
else:
filename = warn.source
match = MatchError(
message=(
warn.message if isinstance(warn.message, str) else "?"
),
rule=self.rules["warning"],
lintable=Lintable(str(filename)),
lintable=Lintable(str(warn.source)),
)
matches.append(match)
continue
Expand Down Expand Up @@ -277,7 +280,7 @@ def worker(lintable: Lintable) -> list[MatchError]:
continue
_logger.debug(
"Examining %s of type %s",
ansiblelint.file_utils.normpath(file.path),
normpath(file.path),
file.kind,
)

Expand Down
6 changes: 2 additions & 4 deletions src/ansiblelint/schemas/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from http.client import HTTPException
from pathlib import Path
from typing import Any
from urllib.error import HTTPError
from urllib.request import Request

_logger = logging.getLogger(__package__)
Expand Down Expand Up @@ -94,10 +95,7 @@ def refresh_schemas(min_age_seconds: int = 3600 * 24) -> int:
if kind in _schema_cache: # pragma: no cover
del _schema_cache[kind]
except (ConnectionError, OSError, HTTPException) as exc:
if (
isinstance(exc, urllib.error.HTTPError)
and getattr(exc, "code", None) == 304
):
if isinstance(exc, HTTPError) and getattr(exc, "code", None) == 304:
_logger.debug("Schema %s is not modified", url)
continue
# In case of networking issues, we just stop and use last-known good
Expand Down
21 changes: 12 additions & 9 deletions src/ansiblelint/schemas/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
import logging
import re
import typing
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import jsonschema
import yaml
from jsonschema.exceptions import ValidationError
from jsonschema.validators import validator_for

from ansiblelint.loaders import yaml_load_safe
from ansiblelint.schemas.__main__ import JSON_SCHEMAS, _schema_cache
Expand All @@ -22,13 +22,13 @@


def find_best_deep_match(
errors: jsonschema.ValidationError,
) -> jsonschema.ValidationError:
errors: ValidationError,
) -> ValidationError:
"""Return the deepest schema validation error."""

def iter_validation_error(
err: jsonschema.ValidationError,
) -> typing.Iterator[jsonschema.ValidationError]:
err: ValidationError,
) -> typing.Iterator[ValidationError]:
if err.context:
for e in err.context:
yield e
Expand All @@ -39,7 +39,7 @@ def iter_validation_error(

def validate_file_schema(file: Lintable) -> list[str]:
"""Return list of JSON validation errors found."""
schema = {}
schema: dict[Any, Any] = {}
if file.kind not in JSON_SCHEMAS:
return [f"Unable to find JSON Schema '{file.kind}' for '{file.path}' file."]
try:
Expand All @@ -48,7 +48,7 @@ def validate_file_schema(file: Lintable) -> list[str]:
json_data = json.loads(json.dumps(yaml_data))
schema = _schema_cache[file.kind]

validator = jsonschema.validators.validator_for(schema)
validator = validator_for(schema)
v = validator(schema)
try:
error = next(v.iter_errors(json_data))
Expand All @@ -57,6 +57,9 @@ def validate_file_schema(file: Lintable) -> list[str]:
if error.context:
error = find_best_deep_match(error)
# determine if we want to use our own messages embedded into schemas inside title/markdownDescription fields
if not hasattr(error, "schema") or not isinstance(error.schema, dict):
msg = "error object does not have schema attribute"
raise TypeError(msg)
if "not" in error.schema and len(error.schema["not"]) == 0:
message = error.schema["title"]
schema = error.schema
Expand Down Expand Up @@ -106,7 +109,7 @@ def validate_file_schema(file: Lintable) -> list[str]:
return [message]


def _deep_match_relevance(error: jsonschema.ValidationError) -> tuple[bool | int, ...]:
def _deep_match_relevance(error: ValidationError) -> tuple[bool | int, ...]:
validator = error.validator
return (
validator not in ("anyOf", "oneOf"), # type: ignore[comparison-overlap]
Expand Down
36 changes: 19 additions & 17 deletions src/ansiblelint/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,7 @@ def task_to_str(task: dict[str, Any]) -> str:
return f"{action['__ansible_module__']} {' '.join(args)}"


# pylint: disable=too-many-nested-blocks
def extract_from_list(
blocks: AnsibleBaseYAMLObject,
candidates: list[str],
Expand All @@ -737,23 +738,24 @@ def extract_from_list(
) -> list[Any]:
"""Get action tasks from block structures."""
results = []
for block in blocks:
for candidate in candidates:
if isinstance(block, dict) and candidate in block:
if isinstance(block[candidate], list):
subresults = add_action_type(block[candidate], candidate)
if recursive:
subresults.extend(
extract_from_list(
subresults,
candidates,
recursive=recursive,
),
)
results.extend(subresults)
elif block[candidate] is not None:
msg = f"Key '{candidate}' defined, but bad value: '{block[candidate]!s}'"
raise RuntimeError(msg)
if isinstance(blocks, Iterable):
for block in blocks:
for candidate in candidates:
if isinstance(block, dict) and candidate in block:
if isinstance(block[candidate], list):
subresults = add_action_type(block[candidate], candidate)
if recursive:
subresults.extend(
extract_from_list(
subresults,
candidates,
recursive=recursive,
),
)
results.extend(subresults)
elif block[candidate] is not None:
msg = f"Key '{candidate}' defined, but bad value: '{block[candidate]!s}'"
raise RuntimeError(msg)
return results


Expand Down
3 changes: 1 addition & 2 deletions src/ansiblelint/yaml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
if TYPE_CHECKING:
# noinspection PyProtectedMember
from ruamel.yaml.comments import LineCol
from ruamel.yaml.compat import StreamTextType
from ruamel.yaml.nodes import ScalarNode
from ruamel.yaml.representer import RoundTripRepresenter
from ruamel.yaml.tokens import CommentToken
Expand Down Expand Up @@ -1026,7 +1025,7 @@ def version(self, value: tuple[int, int] | None) -> None:
self._yaml_version = self._yaml_version_default
# We do nothing if the object did not have a previous default version defined

def load(self, stream: Path | StreamTextType) -> Any:
def load(self, stream: Path | Any) -> Any:
"""Load YAML content from a string while avoiding known ruamel.yaml issues."""
if not isinstance(stream, str):
msg = f"expected a str but got {type(stream)}"
Expand Down
6 changes: 4 additions & 2 deletions test/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)


def urlopen_side_effect(*_args: Any, **kwargs: Any) -> DEFAULT:
def urlopen_side_effect(*_args: Any, **kwargs: Any) -> Any:
"""Actual test that timeout parameter is defined."""
assert "timeout" in kwargs
assert kwargs["timeout"] > 0
Expand All @@ -47,7 +47,9 @@ def test_request_timeouterror_handling(
) -> None:
"""Test that schema refresh can handle time out errors."""
error_msg = "Simulating handshake operation time out."
mock_request.urlopen.side_effect = urllib.error.URLError(TimeoutError(error_msg))
mock_request.urlopen.side_effect = urllib.error.URLError(
TimeoutError(error_msg)
) # pyright: reportAttributeAccessIssue=false
with caplog.at_level(logging.DEBUG):
assert refresh_schemas(min_age_seconds=0) == 0
mock_request.urlopen.assert_called()
Expand Down
2 changes: 1 addition & 1 deletion test/test_yaml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,7 +1007,7 @@ def test_deannotate(

def test_yamllint_incompatible_config() -> None:
"""Ensure we can detect incompatible yamllint settings."""
with (cwd(Path("examples/yamllint/incompatible-config")),):
with cwd(Path("examples/yamllint/incompatible-config")):
config = ansiblelint.yaml_utils.load_yamllint_config()
assert config.incompatible

Expand Down

0 comments on commit e62ef52

Please sign in to comment.