From f346f0a47482294cc55a682efd43bcc2ec48ffb9 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Thu, 6 Jun 2024 11:53:33 +0200 Subject: [PATCH] Migrate Python linters from `ast` (standard library) to `astroid` package (#1835) ## Changes Migrate Python linters from ast to astroid Implement minimal inference ### Linked issues Progresses #1205 ### Functionality - [ ] added relevant user documentation - [ ] added new CLI command - [ ] modified existing command: `databricks labs ucx ...` - [ ] added a new workflow - [ ] modified existing workflow: `...` - [ ] added a new table - [ ] modified existing table: `...` ### Tests - [ ] manually tested - [x] added unit tests - [ ] added integration tests - [ ] verified on staging environment (screenshot attached) --------- Co-authored-by: Eric Vergnaud --- pyproject.toml | 5 +- src/databricks/labs/ucx/source_code/graph.py | 6 +- .../ucx/source_code/linters/ast_helpers.py | 53 ++-- .../labs/ucx/source_code/linters/context.py | 5 +- .../labs/ucx/source_code/linters/dbfs.py | 33 +- .../labs/ucx/source_code/linters/imports.py | 288 +++++++++++------- .../labs/ucx/source_code/linters/pyspark.py | 91 +++--- .../ucx/source_code/linters/spark_connect.py | 71 ++--- .../ucx/source_code/linters/table_creation.py | 32 +- .../labs/ucx/source_code/python_libraries.py | 2 +- tests/integration/conftest.py | 1 + .../integration/source_code/test_libraries.py | 2 +- tests/unit/source_code/linters/test_dbfs.py | 13 +- .../unit/source_code/linters/test_pyspark.py | 97 +++--- .../source_code/linters/test_python_linter.py | 151 ++++----- .../source_code/linters/test_spark_connect.py | 37 ++- .../linters/test_table_creation.py | 3 +- tests/unit/source_code/samples/root10.py | 3 - tests/unit/source_code/test_dependencies.py | 44 +-- tests/unit/source_code/test_functional.py | 2 +- tests/unit/source_code/test_notebook.py | 2 +- .../unit/source_code/test_notebook_linter.py | 34 ++- 22 files changed, 526 insertions(+), 449 deletions(-) delete mode 100644 
tests/unit/source_code/samples/root10.py diff --git a/pyproject.toml b/pyproject.toml index 38113c579b..ae7d04ec5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,8 @@ dependencies = ["databricks-sdk>=0.27,<0.29", "databricks-labs-lsql~=0.4.0", "databricks-labs-blueprint>=0.6.0", "PyYAML>=6.0.0,<7.0.0", - "sqlglot>=23.9,<24.2"] + "sqlglot>=23.9,<24.2", + "astroid>=3.2.2"] [project.entry-points.databricks] runtime = "databricks.labs.ucx.runtime:main" @@ -65,7 +66,7 @@ dependencies = [ "black~=24.3.0", "coverage[toml]~=7.4.4", "mypy~=1.9.0", - "pylint~=3.1.0", + "pylint~=3.2.2", "pylint-pytest==2.0.0a0", "databricks-labs-pylint~=0.4.0", "pytest~=8.1.0", diff --git a/src/databricks/labs/ucx/source_code/graph.py b/src/databricks/labs/ucx/source_code/graph.py index 944873da7b..cf51aea9fb 100644 --- a/src/databricks/labs/ucx/source_code/graph.py +++ b/src/databricks/labs/ucx/source_code/graph.py @@ -11,10 +11,10 @@ from databricks.labs.ucx.source_code.linters.imports import ( ASTLinter, DbutilsLinter, - SysPathChange, - NotebookRunCall, ImportSource, NodeBase, + NotebookRunCall, + SysPathChange, ) from databricks.labs.ucx.source_code.path_lookup import PathLookup @@ -186,7 +186,7 @@ def _process_node(self, base_node: NodeBase): if isinstance(base_node, SysPathChange): self._mutate_path_lookup(base_node) if isinstance(base_node, NotebookRunCall): - strpath = base_node.get_constant_path() + strpath = base_node.get_notebook_path() if strpath is None: yield DependencyProblem('dependency-not-constant', "Can't check dependency not provided as a constant") else: diff --git a/src/databricks/labs/ucx/source_code/linters/ast_helpers.py b/src/databricks/labs/ucx/source_code/linters/ast_helpers.py index 5993c1ea7a..4324860e89 100644 --- a/src/databricks/labs/ucx/source_code/linters/ast_helpers.py +++ b/src/databricks/labs/ucx/source_code/linters/ast_helpers.py @@ -1,30 +1,39 @@ -import ast +import logging +from astroid import Attribute, Call, Name # type: ignore -class 
AstHelper: - @staticmethod - def get_full_attribute_name(node: ast.Attribute) -> str: - return AstHelper._get_value(node) +logger = logging.getLogger(__file__) - @staticmethod - def get_full_function_name(node: ast.Call) -> str | None: - if isinstance(node.func, ast.Attribute): - return AstHelper._get_value(node.func) +missing_handlers: set[str] = set() - if isinstance(node.func, ast.Name): - return node.func.id +class AstHelper: + @classmethod + def get_full_attribute_name(cls, node: Attribute) -> str: + return cls._get_attribute_value(node) + + @classmethod + def get_full_function_name(cls, node: Call) -> str | None: + if not isinstance(node, Call): + return None + if isinstance(node.func, Attribute): + return cls._get_attribute_value(node.func) + if isinstance(node.func, Name): + return node.func.name return None - @staticmethod - def _get_value(node: ast.Attribute): - if isinstance(node.value, ast.Name): - return node.value.id + '.' + node.attr - - if isinstance(node.value, ast.Attribute): - value = AstHelper._get_value(node.value) - if not value: - return None - return value + '.' + node.attr - + @classmethod + def _get_attribute_value(cls, node: Attribute): + if isinstance(node.expr, Name): + return node.expr.name + '.' + node.attrname + if isinstance(node.expr, Attribute): + parent = cls._get_attribute_value(node.expr) + return node.attrname if parent is None else parent + '.' + node.attrname + if isinstance(node.expr, Call): + name = cls.get_full_function_name(node.expr) + return node.attrname if name is None else name + '.' 
+ node.attrname + name = type(node.expr).__name__ + if name not in missing_handlers: + missing_handlers.add(name) + logger.debug(f"Missing handler for {name}") return None diff --git a/src/databricks/labs/ucx/source_code/linters/context.py b/src/databricks/labs/ucx/source_code/linters/context.py index 5ab0eb1088..9fbe3e8bd1 100644 --- a/src/databricks/labs/ucx/source_code/linters/context.py +++ b/src/databricks/labs/ucx/source_code/linters/context.py @@ -1,13 +1,14 @@ from databricks.sdk.service.workspace import Language from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex -from databricks.labs.ucx.source_code.base import CurrentSessionState, SequentialLinter, Fixer, Linter +from databricks.labs.ucx.source_code.base import Fixer, Linter, SequentialLinter, CurrentSessionState from databricks.labs.ucx.source_code.linters.dbfs import FromDbfsFolder, DBFSUsageLinter from databricks.labs.ucx.source_code.linters.imports import DbutilsLinter + from databricks.labs.ucx.source_code.linters.pyspark import SparkSql -from databricks.labs.ucx.source_code.queries import FromTable from databricks.labs.ucx.source_code.linters.spark_connect import SparkConnectLinter from databricks.labs.ucx.source_code.linters.table_creation import DBRv8d0Linter +from databricks.labs.ucx.source_code.queries import FromTable class LinterContext: diff --git a/src/databricks/labs/ucx/source_code/linters/dbfs.py b/src/databricks/labs/ucx/source_code/linters/dbfs.py index 374cb4952a..2686e44c89 100644 --- a/src/databricks/labs/ucx/source_code/linters/dbfs.py +++ b/src/databricks/labs/ucx/source_code/linters/dbfs.py @@ -1,13 +1,14 @@ -import ast from collections.abc import Iterable +from astroid import Call, Const # type: ignore import sqlglot from sqlglot.expressions import Table from databricks.labs.ucx.source_code.base import Advice, Linter, Advisory, Deprecation +from databricks.labs.ucx.source_code.linters.imports import Visitor, ASTLinter -class 
DetectDbfsVisitor(ast.NodeVisitor): +class DetectDbfsVisitor(Visitor): """ Visitor that detects file system paths in Python code and checks them against a list of known deprecated paths. @@ -18,44 +19,44 @@ def __init__(self): self._fs_prefixes = ["/dbfs/mnt", "dbfs:/", "/mnt/"] self._reported_locations = set() # Set to store reported locations - def visit_Call(self, node): + def visit_call(self, node: Call): for arg in node.args: - if isinstance(arg, (ast.Str, ast.Constant)) and isinstance(arg.s, str): - if any(arg.s.startswith(prefix) for prefix in self._fs_prefixes): + if isinstance(arg, Const) and isinstance(arg.value, str): + value = arg.value + if any(value.startswith(prefix) for prefix in self._fs_prefixes): self._advices.append( Deprecation( code='dbfs-usage', - message=f"Deprecated file system path in call to: {arg.s}", + message=f"Deprecated file system path in call to: {value}", start_line=arg.lineno, start_col=arg.col_offset, end_line=arg.lineno, - end_col=arg.col_offset + len(arg.s), + end_col=arg.col_offset + len(value), ) ) # Record the location of the reported constant, so we do not double report self._reported_locations.add((arg.lineno, arg.col_offset)) - self.generic_visit(node) - def visit_Constant(self, node): + def visit_const(self, node: Const): # Constant strings yield Advisories if isinstance(node.value, str): self._check_str_constant(node) - def _check_str_constant(self, node): + def _check_str_constant(self, node: Const): # Check if the location has been reported before if (node.lineno, node.col_offset) not in self._reported_locations: - if any(node.s.startswith(prefix) for prefix in self._fs_prefixes): + value = node.value + if any(value.startswith(prefix) for prefix in self._fs_prefixes): self._advices.append( Advisory( code='dbfs-usage', - message=f"Possible deprecated file system path: {node.s}", + message=f"Possible deprecated file system path: {value}", start_line=node.lineno, start_col=node.col_offset, end_line=node.lineno, - 
end_col=node.col_offset + len(node.s), + end_col=node.col_offset + len(value), ) ) - self.generic_visit(node) def get_advices(self) -> Iterable[Advice]: yield from self._advices @@ -76,9 +77,9 @@ def lint(self, code: str) -> Iterable[Advice]: """ Lints the code looking for file system paths that are deprecated """ - tree = ast.parse(code) + linter = ASTLinter.parse(code) visitor = DetectDbfsVisitor() - visitor.visit(tree) + visitor.visit(linter.root) yield from visitor.get_advices() diff --git a/src/databricks/labs/ucx/source_code/linters/imports.py b/src/databricks/labs/ucx/source_code/linters/imports.py index 845430fee4..5056ed5c04 100644 --- a/src/databricks/labs/ucx/source_code/linters/imports.py +++ b/src/databricks/labs/ucx/source_code/linters/imports.py @@ -1,20 +1,60 @@ from __future__ import annotations import abc -import ast import logging from collections.abc import Iterable, Callable from typing import TypeVar, Generic, cast +from astroid import ( # type: ignore + parse, + Attribute, + Call, + Const, + Import, + ImportFrom, + Module, + Name, + NodeNG, +) + from databricks.labs.ucx.source_code.base import Linter, Advice, Advisory logger = logging.getLogger(__name__) -class MatchingVisitor(ast.NodeVisitor): +class Visitor: + + def visit(self, node: NodeNG): + self._visit_specific(node) + for child in node.get_children(): + self.visit(child) + + def _visit_specific(self, node: NodeNG): + method_name = "visit_" + type(node).__name__.lower() + method_slot = getattr(self, method_name, None) + if callable(method_slot): + method_slot(node) + else: + self.visit_nodeng(node) + + def visit_nodeng(self, node: NodeNG): + pass + + +class TreeWalker: + + @classmethod + def walk(cls, node: NodeNG) -> Iterable[NodeNG]: + yield node + for child in node.get_children(): + yield from cls.walk(child) + + +class MatchingVisitor(Visitor): def __init__(self, node_type: type, match_nodes: list[tuple[str, type]]): - self._matched_nodes: list[ast.AST] = [] + super() + 
self._matched_nodes: list[NodeNG] = [] self._node_type = node_type self._match_nodes = match_nodes @@ -22,8 +62,8 @@ def __init__(self, node_type: type, match_nodes: list[tuple[str, type]]): def matched_nodes(self): return self._matched_nodes - def visit_Call(self, node: ast.Call): - if self._node_type is not ast.Call: + def visit_call(self, node: Call): + if self._node_type is not Call: return try: if self._matches(node.func, 0): @@ -31,29 +71,29 @@ def visit_Call(self, node: ast.Call): except NotImplementedError as e: logger.warning(f"Missing implementation: {e.args[0]}") - def visit_Import(self, node: ast.Import): - if self._node_type is not ast.Import: + def visit_import(self, node: Import): + if self._node_type is not Import: return self._matched_nodes.append(node) - def visit_ImportFrom(self, node: ast.ImportFrom): - if self._node_type is not ast.ImportFrom: + def visit_importfrom(self, node: ImportFrom): + if self._node_type is not ImportFrom: return self._matched_nodes.append(node) - def _matches(self, node: ast.AST, depth: int): + def _matches(self, node: NodeNG, depth: int): if depth >= len(self._match_nodes): return False - pair = self._match_nodes[depth] - if not isinstance(node, pair[1]): + name, match_node = self._match_nodes[depth] + if not isinstance(node, match_node): return False - next_node: ast.AST | None = None - if isinstance(node, ast.Attribute): - if node.attr != pair[0]: + next_node: NodeNG | None = None + if isinstance(node, Attribute): + if node.attrname != name: return False - next_node = node.value - elif isinstance(node, ast.Name): - if node.id != pair[0]: + next_node = node.expr + elif isinstance(node, Name): + if node.name != name: return False else: raise NotImplementedError(str(type(node))) @@ -65,7 +105,7 @@ def _matches(self, node: ast.AST, depth: int): class NodeBase(abc.ABC): - def __init__(self, node: ast.AST): + def __init__(self, node: NodeNG): self._node = node @property @@ -73,12 +113,12 @@ def node(self): return 
self._node def __repr__(self): - return f"<{self.__class__.__name__}: {ast.unparse(self._node)}>" + return f"<{self.__class__.__name__}: {repr(self._node)}>" class SysPathChange(NodeBase, abc.ABC): - def __init__(self, node: ast.AST, path: str, is_append: bool): + def __init__(self, node: NodeNG, path: str, is_append: bool): super().__init__(node) self._path = path self._is_append = is_append @@ -106,9 +146,10 @@ class RelativePath(SysPathChange): pass -class SysPathVisitor(ast.NodeVisitor): +class SysPathVisitor(Visitor): def __init__(self): + super() self._aliases: dict[str, str] = {} self._syspath_changes: list[SysPathChange] = [] @@ -116,58 +157,59 @@ def __init__(self): def syspath_changes(self): return self._syspath_changes - def visit_Import(self, node: ast.Import): - for alias in node.names: - if alias.name in {"sys", "os"}: - self._aliases[alias.name] = alias.asname or alias.name + def visit_import(self, node: Import): + for name, alias in node.names: + if alias is None or name not in {"sys", "os"}: + continue + self._aliases[name] = alias - def visit_ImportFrom(self, node: ast.ImportFrom): + def visit_importfrom(self, node: ImportFrom): interesting_aliases = [("sys", "path"), ("os", "path"), ("os.path", "abspath")] - interesting_alias = next((t for t in interesting_aliases if t[0] == node.module), None) + interesting_alias = next((t for t in interesting_aliases if t[0] == node.modname), None) if interesting_alias is None: return - for alias in node.names: - if alias.name == interesting_alias[1]: - self._aliases[f"{node.module}.{interesting_alias[1]}"] = alias.asname or alias.name + for name, alias in node.names: + if name == interesting_alias[1]: + self._aliases[f"{node.modname}.{interesting_alias[1]}"] = alias or name break - def visit_Call(self, node: ast.Call): - func = cast(ast.Attribute, node.func) + def visit_call(self, node: Call): + func = cast(Attribute, node.func) # check for 'sys.path.append' if not ( self._match_aliases(func, ["sys", "path", 
"append"]) or self._match_aliases(func, ["sys", "path", "insert"]) ): return - is_append = func.attr == "append" + is_append = func.attrname == "append" changed = node.args[0] if is_append else node.args[1] - if isinstance(changed, ast.Constant): + if isinstance(changed, Const): self._syspath_changes.append(AbsolutePath(node, changed.value, is_append)) - elif isinstance(changed, ast.Call): + elif isinstance(changed, Call): self._visit_relative_path(changed, is_append) - def _match_aliases(self, node: ast.AST, names: list[str]): - if isinstance(node, ast.Attribute): - if node.attr != names[-1]: + def _match_aliases(self, node: NodeNG, names: list[str]): + if isinstance(node, Attribute): + if node.attrname != names[-1]: return False if len(names) == 1: return True - return self._match_aliases(node.value, names[0 : len(names) - 1]) - if isinstance(node, ast.Name): + return self._match_aliases(node.expr, names[0 : len(names) - 1]) + if isinstance(node, Name): full_name = ".".join(names) alias = self._aliases.get(full_name, full_name) - return node.id == alias + return node.name == alias return False - def _visit_relative_path(self, node: ast.Call, is_append: bool): + def _visit_relative_path(self, node: NodeNG, is_append: bool): # check for 'os.path.abspath' if not self._match_aliases(node.func, ["os", "path", "abspath"]): return changed = node.args[0] - if isinstance(changed, ast.Constant): + if isinstance(changed, Const): self._syspath_changes.append(RelativePath(changed, changed.value, is_append)) -T = TypeVar("T", bound=ast.AST) +T = TypeVar("T", bound=NodeNG) # disclaimer this class is NOT thread-safe @@ -175,11 +217,15 @@ class ASTLinter(Generic[T]): @staticmethod def parse(code: str): - root = ast.parse(code) + root = parse(code) return ASTLinter(root) - def __init__(self, root: ast.AST): - self._root: ast.AST = root + def __init__(self, root: Module): + self._root: Module = root + + @property + def root(self): + return self._root def locate(self, node_type: 
type[T], match_nodes: list[tuple[str, type]]) -> list[T]: visitor = MatchingVisitor(node_type, match_nodes) @@ -191,61 +237,57 @@ def collect_sys_paths_changes(self): visitor.visit(self._root) return visitor.syspath_changes - def extract_callchain(self) -> ast.Call | None: - """If 'node' is an assignment or expression, extract its full call-chain (if it has one)""" - call = None - if isinstance(self._root, ast.Assign): - call = self._root.value - elif isinstance(self._root, ast.Expr): - call = self._root.value - if not isinstance(call, ast.Call): - call = None - return call - - def extract_call_by_name(self, name: str) -> ast.Call | None: + def first_statement(self): + return self._root.body[0] + + @classmethod + def extract_call_by_name(cls, call: Call, name: str) -> Call | None: """Given a call-chain, extract its sub-call by method name (if it has one)""" - assert isinstance(self._root, ast.Call) - node = self._root + assert isinstance(call, Call) + node = call while True: func = node.func - if not isinstance(func, ast.Attribute): + if not isinstance(func, Attribute): return None - if func.attr == name: + if func.attrname == name: return node - if not isinstance(func.value, ast.Call): + if not isinstance(func.expr, Call): return None - node = func.value + node = func.expr - def args_count(self) -> int: + @classmethod + def args_count(cls, node: Call) -> int: """Count the number of arguments (positionals + keywords)""" - assert isinstance(self._root, ast.Call) - return len(self._root.args) + len(self._root.keywords) + assert isinstance(node, Call) + return len(node.args) + len(node.keywords) + @classmethod def get_arg( - self, + cls, + node: Call, arg_index: int | None, arg_name: str | None, - ) -> ast.expr | None: + ) -> NodeNG | None: """Extract the call argument identified by an optional position or name (if it has one)""" - assert isinstance(self._root, ast.Call) - if arg_index is not None and len(self._root.args) > arg_index: - return 
self._root.args[arg_index] + assert isinstance(node, Call) + if arg_index is not None and len(node.args) > arg_index: + return node.args[arg_index] if arg_name is not None: - arg = [kw.value for kw in self._root.keywords if kw.arg == arg_name] + arg = [kw.value for kw in node.keywords if kw.arg == arg_name] if len(arg) == 1: return arg[0] return None - def is_none(self) -> bool: + @classmethod + def is_none(cls, node: NodeNG) -> bool: """Check if the given AST expression is the None constant""" - assert isinstance(self._root, ast.expr) - if not isinstance(self._root, ast.Constant): + if not isinstance(node, Const): return False - return self._root.value is None + return node.value is None def __repr__(self): truncate_after = 32 - code = ast.unparse(self._root) + code = repr(self._root) if len(code) > truncate_after: code = code[0:truncate_after] + "..." return f"" @@ -253,20 +295,21 @@ def __repr__(self): class ImportSource(NodeBase): - def __init__(self, node: ast.AST, name: str): + def __init__(self, node: NodeNG, name: str): super().__init__(node) self.name = name class NotebookRunCall(NodeBase): - def __init__(self, node: ast.Call): + def __init__(self, node: Call): super().__init__(node) - def get_constant_path(self) -> str | None: - path = DbutilsLinter.get_dbutils_notebook_run_path_arg(cast(ast.Call, self.node)) - if isinstance(path, ast.Constant): - return path.value.strip().strip("'").strip('"') + def get_notebook_path(self) -> str | None: + node = DbutilsLinter.get_dbutils_notebook_run_path_arg(cast(Call, self.node)) + inferred = next(node.infer(), None) + if isinstance(inferred, Const): + return inferred.value.strip().strip("'").strip('"') return None @@ -281,10 +324,10 @@ def lint(self, code: str) -> Iterable[Advice]: return [self._convert_dbutils_notebook_run_to_advice(node.node) for node in nodes] @classmethod - def _convert_dbutils_notebook_run_to_advice(cls, node: ast.AST) -> Advisory: - assert isinstance(node, ast.Call) + def 
_convert_dbutils_notebook_run_to_advice(cls, node: NodeNG) -> Advisory: + assert isinstance(node, Call) path = cls.get_dbutils_notebook_run_path_arg(node) - if isinstance(path, ast.Constant): + if isinstance(path, Const): return Advisory( 'dbutils-notebook-run-literal', "Call to 'dbutils.notebook.run' will be migrated automatically", @@ -303,7 +346,7 @@ def _convert_dbutils_notebook_run_to_advice(cls, node: ast.AST) -> Advisory: ) @staticmethod - def get_dbutils_notebook_run_path_arg(node: ast.Call): + def get_dbutils_notebook_run_path_arg(node: Call): if len(node.args) > 0: return node.args[0] arg = next(kw for kw in node.keywords if kw.arg == "path") @@ -311,32 +354,24 @@ def get_dbutils_notebook_run_path_arg(node: ast.Call): @staticmethod def list_dbutils_notebook_run_calls(linter: ASTLinter) -> list[NotebookRunCall]: - calls = linter.locate(ast.Call, [("run", ast.Attribute), ("notebook", ast.Attribute), ("dbutils", ast.Name)]) + calls = linter.locate(Call, [("run", Attribute), ("notebook", Attribute), ("dbutils", Name)]) return [NotebookRunCall(call) for call in calls] - @staticmethod - def list_import_sources(linter: ASTLinter, problem_type: P) -> tuple[list[ImportSource], list[P]]: + @classmethod + def list_import_sources(cls, linter: ASTLinter, problem_type: P) -> tuple[list[ImportSource], list[P]]: problems: list[P] = [] + sources: list[ImportSource] = [] try: # pylint: disable=too-many-try-statements - nodes = linter.locate(ast.Import, []) - sources = [ImportSource(node, alias.name) for node in nodes for alias in node.names] - nodes = linter.locate(ast.ImportFrom, []) - sources.extend(ImportSource(node, node.module) for node in nodes) - nodes = linter.locate(ast.Call, [("import_module", ast.Attribute), ("importlib", ast.Name)]) - nodes.extend(linter.locate(ast.Call, [("__import__", ast.Attribute), ("importlib", ast.Name)])) - for node in nodes: - if isinstance(node.args[0], ast.Constant): - sources.append(ImportSource(node, node.args[0].value)) - continue 
- problem = problem_type( - 'dependency-not-constant', - "Can't check dependency not provided as a constant", - start_line=node.lineno, - start_col=node.col_offset, - end_line=node.end_lineno or 0, - end_col=node.end_col_offset or 0, - ) - problems.append(problem) + nodes = linter.locate(Import, []) + for source in cls._make_sources_for_import_nodes(nodes): + sources.append(source) + nodes = linter.locate(ImportFrom, []) + for source in cls._make_sources_for_import_from_nodes(nodes): + sources.append(source) + nodes = linter.locate(Call, [("import_module", Attribute), ("importlib", Name)]) + nodes.extend(linter.locate(Call, [("__import__", Attribute), ("importlib", Name)])) + for source in cls._make_sources_for_import_call_nodes(nodes, problem_type, problems): + sources.append(source) return sources, problems except Exception as e: # pylint: disable=broad-except problem = problem_type('internal-error', f"While linter {linter} was checking imports: {e}") @@ -346,3 +381,32 @@ def list_import_sources(linter: ASTLinter, problem_type: P) -> tuple[list[Import @staticmethod def list_sys_path_changes(linter: ASTLinter) -> list[SysPathChange]: return linter.collect_sys_paths_changes() + + @classmethod + def _make_sources_for_import_nodes(cls, nodes: list[Import]) -> Iterable[ImportSource]: + for node in nodes: + for name, _ in node.names: + if name is not None: + yield ImportSource(node, name) + + @classmethod + def _make_sources_for_import_from_nodes(cls, nodes: list[ImportFrom]) -> Iterable[ImportSource]: + for node in nodes: + yield ImportSource(node, node.modname) + + @classmethod + def _make_sources_for_import_call_nodes(cls, nodes: list[Call], problem_type: P, problems: list[P]): + for node in nodes: + arg = node.args[0] + if isinstance(arg, Const): + yield ImportSource(node, arg.value) + continue + problem = problem_type( + 'dependency-not-constant', + "Can't check dependency not provided as a constant", + start_line=node.lineno, + start_col=node.col_offset, + 
end_line=node.end_lineno or 0, + end_col=node.end_col_offset or 0, + ) + problems.append(problem) diff --git a/src/databricks/labs/ucx/source_code/linters/pyspark.py b/src/databricks/labs/ucx/source_code/linters/pyspark.py index 910f9383e2..7f841d1df8 100644 --- a/src/databricks/labs/ucx/source_code/linters/pyspark.py +++ b/src/databricks/labs/ucx/source_code/linters/pyspark.py @@ -1,8 +1,8 @@ -import ast from abc import ABC, abstractmethod from collections.abc import Iterable, Iterator from dataclasses import dataclass +from astroid import Attribute, Call, Const, NodeNG # type: ignore from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex from databricks.labs.ucx.source_code.base import ( Advice, @@ -11,6 +11,7 @@ Fixer, Linter, ) +from databricks.labs.ucx.source_code.linters.imports import ASTLinter, TreeWalker from databricks.labs.ucx.source_code.queries import FromTable from databricks.labs.ucx.source_code.linters.ast_helpers import AstHelper @@ -24,22 +25,18 @@ class Matcher(ABC): table_arg_name: str | None = None call_context: dict[str, set[str]] | None = None - def matches(self, node: ast.AST): - return ( - isinstance(node, ast.Call) - and isinstance(node.func, ast.Attribute) - and self._get_table_arg(node) is not None - ) + def matches(self, node: NodeNG): + return isinstance(node, Call) and isinstance(node.func, Attribute) and self._get_table_arg(node) is not None @abstractmethod - def lint(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> Iterator[Advice]: + def lint(self, from_table: FromTable, index: MigrationIndex, node: Call) -> Iterator[Advice]: """raises Advices by linting the code""" @abstractmethod - def apply(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> None: + def apply(self, from_table: FromTable, index: MigrationIndex, node: Call) -> None: """applies recommendations""" - def _get_table_arg(self, node: ast.Call): + def _get_table_arg(self, node: Call): node_argc = 
len(node.args) if self.min_args <= node_argc <= self.max_args and self.table_arg_index < node_argc: return node.args[self.table_arg_index] @@ -52,9 +49,9 @@ def _get_table_arg(self, node: ast.Call): return keyword.value return None - def _check_call_context(self, node: ast.Call) -> bool: - assert isinstance(node.func, ast.Attribute) # Avoid linter warning - func_name = node.func.attr + def _check_call_context(self, node: Call) -> bool: + assert isinstance(node.func, Attribute) # Avoid linter warning + func_name = node.func.attrname qualified_name = AstHelper.get_full_function_name(node) # Check if the call_context is None as that means all calls are checked @@ -71,9 +68,9 @@ def _check_call_context(self, node: ast.Call) -> bool: @dataclass class QueryMatcher(Matcher): - def lint(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> Iterator[Advice]: + def lint(self, from_table: FromTable, index: MigrationIndex, node: Call) -> Iterator[Advice]: table_arg = self._get_table_arg(node) - if isinstance(table_arg, ast.Constant): + if isinstance(table_arg, Const): for advice in from_table.lint(table_arg.value): yield advice.replace( start_line=node.lineno, @@ -91,9 +88,9 @@ def lint(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> end_col=node.end_col_offset or 0, ) - def apply(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> None: + def apply(self, from_table: FromTable, index: MigrationIndex, node: Call) -> None: table_arg = self._get_table_arg(node) - assert isinstance(table_arg, ast.Constant) + assert isinstance(table_arg, Const) new_query = from_table.apply(table_arg.value) table_arg.value = new_query @@ -101,14 +98,14 @@ def apply(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> @dataclass class TableNameMatcher(Matcher): - def lint(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> Iterator[Advice]: + def lint(self, from_table: FromTable, index: MigrationIndex, 
node: Call) -> Iterator[Advice]: table_arg = self._get_table_arg(node) - if not isinstance(table_arg, ast.Constant): - assert isinstance(node.func, ast.Attribute) # always true, avoids a pylint warning + if not isinstance(table_arg, Const): + assert isinstance(node.func, Attribute) # always true, avoids a pylint warning yield Advisory( code='table-migrate', - message=f"Can't migrate '{node.func.attr}' because its table name argument is not a constant", + message=f"Can't migrate '{node.func.attrname}' because its table name argument is not a constant", start_line=node.lineno, start_col=node.col_offset, end_line=node.end_lineno or 0, @@ -130,9 +127,9 @@ def lint(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> end_col=node.end_col_offset or 0, ) - def apply(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> None: + def apply(self, from_table: FromTable, index: MigrationIndex, node: Call) -> None: table_arg = self._get_table_arg(node) - assert isinstance(table_arg, ast.Constant) + assert isinstance(table_arg, Const) dst = self._find_dest(index, table_arg.value, from_table.schema) if dst is not None: table_arg.value = dst.destination() @@ -149,21 +146,21 @@ def _find_dest(index: MigrationIndex, value: str, schema: str): @dataclass class ReturnValueMatcher(Matcher): - def matches(self, node: ast.AST): - return isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute) + def matches(self, node: NodeNG): + return isinstance(node, Call) and isinstance(node.func, Attribute) - def lint(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> Iterator[Advice]: - assert isinstance(node.func, ast.Attribute) # always true, avoids a pylint warning + def lint(self, from_table: FromTable, index: MigrationIndex, node: Call) -> Iterator[Advice]: + assert isinstance(node.func, Attribute) # always true, avoids a pylint warning yield Advisory( code='table-migrate', - message=f"Call to '{node.func.attr}' will return a list 
of <catalog>.<database>.<table> instead of <database>.<table>
.", + message=f"Call to '{node.func.attrname}' will return a list of <catalog>.<database>.<table>
instead of <database>.<table>
.", start_line=node.lineno, start_col=node.col_offset, end_line=node.end_lineno or 0, end_col=node.end_col_offset or 0, ) - def apply(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> None: + def apply(self, from_table: FromTable, index: MigrationIndex, node: Call) -> None: # No transformations to apply return @@ -183,16 +180,12 @@ class DirectFilesystemAccessMatcher(Matcher): "file:/", } - def matches(self, node: ast.AST): - return ( - isinstance(node, ast.Call) - and isinstance(node.func, ast.Attribute) - and self._get_table_arg(node) is not None - ) + def matches(self, node: NodeNG): + return isinstance(node, Call) and isinstance(node.func, Attribute) and self._get_table_arg(node) is not None - def lint(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> Iterator[Advice]: + def lint(self, from_table: FromTable, index: MigrationIndex, node: NodeNG) -> Iterator[Advice]: table_arg = self._get_table_arg(node) - if not isinstance(table_arg, ast.Constant): + if not isinstance(table_arg, Const): return if not table_arg.value: return @@ -218,7 +211,7 @@ def lint(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> end_col=node.end_col_offset or 0, ) - def apply(self, from_table: FromTable, index: MigrationIndex, node: ast.Call) -> None: + def apply(self, from_table: FromTable, index: MigrationIndex, node: Call) -> None: # No transformations to apply return @@ -335,31 +328,31 @@ def name(self) -> str: return self._from_table.name() def lint(self, code: str) -> Iterable[Advice]: - tree = ast.parse(code) - for node in ast.walk(tree): + linter = ASTLinter.parse(code) + for node in TreeWalker.walk(linter.root): matcher = self._find_matcher(node) if matcher is None: continue - assert isinstance(node, ast.Call) + assert isinstance(node, Call) yield from matcher.lint(self._from_table, self._index, node) def apply(self, code: str) -> str: - tree = ast.parse(code) + linter = ASTLinter.parse(code) # we won't be doing it 
like this in production, but for the sake of the example - for node in ast.walk(tree): + for node in TreeWalker.walk(linter.root): matcher = self._find_matcher(node) if matcher is None: continue - assert isinstance(node, ast.Call) + assert isinstance(node, Call) matcher.apply(self._from_table, self._index, node) - return ast.unparse(tree) + return linter.root.as_string() - def _find_matcher(self, node: ast.AST): - if not isinstance(node, ast.Call): + def _find_matcher(self, node: NodeNG): + if not isinstance(node, Call): return None - if not isinstance(node.func, ast.Attribute): + if not isinstance(node.func, Attribute): return None - matcher = self._spark_matchers.matchers.get(node.func.attr, None) + matcher = self._spark_matchers.matchers.get(node.func.attrname, None) if matcher is None: return None return matcher if matcher.matches(node) else None diff --git a/src/databricks/labs/ucx/source_code/linters/spark_connect.py b/src/databricks/labs/ucx/source_code/linters/spark_connect.py index 15047d8146..2f8f125a22 100644 --- a/src/databricks/labs/ucx/source_code/linters/spark_connect.py +++ b/src/databricks/labs/ucx/source_code/linters/spark_connect.py @@ -1,14 +1,15 @@ -import ast from abc import abstractmethod from collections.abc import Iterator from dataclasses import dataclass +from astroid import Attribute, Call, Name, NodeNG # type: ignore from databricks.labs.ucx.source_code.base import ( Advice, Failure, Linter, ) from databricks.labs.ucx.source_code.linters.ast_helpers import AstHelper +from databricks.labs.ucx.source_code.linters.imports import ASTLinter, TreeWalker @dataclass @@ -19,12 +20,12 @@ def _cluster_type_str(self) -> str: return 'UC Shared Clusters' if not self.is_serverless else 'Serverless Compute' @abstractmethod - def lint(self, node: ast.AST) -> Iterator[Advice]: + def lint(self, node: NodeNG) -> Iterator[Advice]: pass - def lint_tree(self, tree: ast.AST) -> Iterator[Advice]: + def lint_tree(self, tree: NodeNG) -> Iterator[Advice]: 
reported_locations = set() - for node in ast.walk(tree): + for node in TreeWalker.walk(tree): for advice in self.lint(node): loc = (advice.start_line, advice.start_col) if loc not in reported_locations: @@ -41,10 +42,10 @@ class JvmAccessMatcher(SharedClusterMatcher): "_jsparkSession", ] - def lint(self, node: ast.AST) -> Iterator[Advice]: - if not isinstance(node, ast.Attribute): + def lint(self, node: NodeNG) -> Iterator[Advice]: + if not isinstance(node, Attribute): return - if node.attr not in JvmAccessMatcher._FIELDS: + if node.attrname not in JvmAccessMatcher._FIELDS: return yield Failure( code='jvm-access-in-shared-clusters', @@ -76,16 +77,16 @@ class RDDApiMatcher(SharedClusterMatcher): "wholeTextFiles", ] - def lint(self, node: ast.AST) -> Iterator[Advice]: + def lint(self, node: NodeNG) -> Iterator[Advice]: yield from self._lint_sc(node) yield from self._lint_rdd_use(node) - def _lint_rdd_use(self, node: ast.AST) -> Iterator[Advice]: - if isinstance(node, ast.Attribute): - if node.attr == 'rdd': + def _lint_rdd_use(self, node: NodeNG) -> Iterator[Advice]: + if isinstance(node, Attribute): + if node.attrname == 'rdd': yield self._rdd_failure(node) return - if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute) and node.func.attr == 'mapPartitions': + if isinstance(node, Call) and isinstance(node.func, Attribute) and node.func.attrname == 'mapPartitions': yield Failure( code='rdd-in-shared-clusters', message=f'RDD APIs are not supported on {self._cluster_type_str()}. 
' @@ -96,17 +97,17 @@ def _lint_rdd_use(self, node: ast.AST) -> Iterator[Advice]: end_col=node.end_col_offset or 0, ) - def _lint_sc(self, node: ast.AST) -> Iterator[Advice]: - if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Attribute): + def _lint_sc(self, node: NodeNG) -> Iterator[Advice]: + if not isinstance(node, Call) or not isinstance(node.func, Attribute): return - if node.func.attr not in self._SC_METHODS: + if node.func.attrname not in self._SC_METHODS: return function_name = AstHelper.get_full_function_name(node) - if not function_name or not function_name.endswith(f"sc.{node.func.attr}"): + if not function_name or not function_name.endswith(f"sc.{node.func.attrname}"): return yield self._rdd_failure(node) - def _rdd_failure(self, node: ast.AST) -> Advice: + def _rdd_failure(self, node: NodeNG) -> Advice: return Failure( code='rdd-in-shared-clusters', message=f'RDD APIs are not supported on {self._cluster_type_str()}. Rewrite it using DataFrame API', @@ -121,22 +122,22 @@ class SparkSqlContextMatcher(SharedClusterMatcher): _ATTRIBUTES = ["sc", "sqlContext", "sparkContext"] _KNOWN_REPLACEMENTS = {"getConf": "conf", "_conf": "conf"} - def lint(self, node: ast.AST) -> Iterator[Advice]: - if not isinstance(node, ast.Attribute): + def lint(self, node: NodeNG) -> Iterator[Advice]: + if not isinstance(node, Attribute): return - if isinstance(node.value, ast.Name) and node.value.id in SparkSqlContextMatcher._ATTRIBUTES: - yield self._get_advice(node, node.value.id) + if isinstance(node.expr, Name) and node.expr.name in SparkSqlContextMatcher._ATTRIBUTES: + yield self._get_advice(node, node.expr.name) # sparkContext can be an attribute as in df.sparkContext.getConf() - if isinstance(node.value, ast.Attribute) and node.value.attr == 'sparkContext': - yield self._get_advice(node, node.value.attr) + if isinstance(node.expr, Attribute) and node.expr.attrname == 'sparkContext': + yield self._get_advice(node, node.expr.attrname) - def 
_get_advice(self, node: ast.Attribute, name: str) -> Advice: - if node.attr in SparkSqlContextMatcher._KNOWN_REPLACEMENTS: - replacement = SparkSqlContextMatcher._KNOWN_REPLACEMENTS[node.attr] + def _get_advice(self, node: Attribute, name: str) -> Advice: + if node.attrname in self._KNOWN_REPLACEMENTS: + replacement = self._KNOWN_REPLACEMENTS[node.attrname] return Failure( code='legacy-context-in-shared-clusters', - message=f'{name} and {node.attr} are not supported on {self._cluster_type_str()}. ' + message=f'{name} and {node.attrname} are not supported on {self._cluster_type_str()}. ' f'Rewrite it using spark.{replacement}', start_line=node.lineno, start_col=node.col_offset, @@ -154,14 +155,14 @@ def _get_advice(self, node: ast.Attribute, name: str) -> Advice: class LoggingMatcher(SharedClusterMatcher): - def lint(self, node: ast.AST) -> Iterator[Advice]: + def lint(self, node: NodeNG) -> Iterator[Advice]: yield from self._match_sc_set_log_level(node) yield from self._match_jvm_log(node) - def _match_sc_set_log_level(self, node: ast.AST) -> Iterator[Advice]: - if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Attribute): + def _match_sc_set_log_level(self, node: NodeNG) -> Iterator[Advice]: + if not isinstance(node, Call) or not isinstance(node.func, Attribute): return - if node.func.attr != 'setLogLevel': + if node.func.attrname != 'setLogLevel': return function_name = AstHelper.get_full_function_name(node) if not function_name or not function_name.endswith('sc.setLogLevel'): @@ -177,8 +178,8 @@ def _match_sc_set_log_level(self, node: ast.AST) -> Iterator[Advice]: end_col=node.end_col_offset or 0, ) - def _match_jvm_log(self, node: ast.AST) -> Iterator[Advice]: - if not isinstance(node, ast.Attribute): + def _match_jvm_log(self, node: NodeNG) -> Iterator[Advice]: + if not isinstance(node, Attribute): return attribute_name = AstHelper.get_full_attribute_name(node) if attribute_name and attribute_name.endswith('org.apache.log4j'): @@ -203,6 +204,6 
@@ def __init__(self, is_serverless: bool = False): ] def lint(self, code: str) -> Iterator[Advice]: - tree = ast.parse(code) + linter = ASTLinter.parse(code) for matcher in self._matchers: - yield from matcher.lint_tree(tree) + yield from matcher.lint_tree(linter.root) diff --git a/src/databricks/labs/ucx/source_code/linters/table_creation.py b/src/databricks/labs/ucx/source_code/linters/table_creation.py index b339f6bedc..d13191b839 100644 --- a/src/databricks/labs/ucx/source_code/linters/table_creation.py +++ b/src/databricks/labs/ucx/source_code/linters/table_creation.py @@ -1,15 +1,15 @@ from __future__ import annotations -import ast from collections.abc import Iterable, Iterator from dataclasses import dataclass +from astroid import Attribute, Call, NodeNG # type: ignore -from databricks.labs.ucx.source_code.linters.imports import ASTLinter from databricks.labs.ucx.source_code.base import ( Advice, Linter, ) +from databricks.labs.ucx.source_code.linters.imports import ASTLinter, TreeWalker @dataclass @@ -37,38 +37,36 @@ class NoFormatPythonMatcher: format_arg_index: int | None = None format_arg_name: str | None = None - def get_advice_span(self, node: ast.AST) -> Range | None: - # Check 1: retrieve full callchain: - callchain = ASTLinter(node).extract_callchain() - if callchain is None: + def get_advice_span(self, node: NodeNG) -> Range | None: + # Check 1: check Call: + if not isinstance(node, Call): return None # Check 2: check presence of the table-creating method call: - call = ASTLinter(callchain).extract_call_by_name(self.method_name) - if call is None: + if not isinstance(node.func, Attribute) or node.func.attrname != self.method_name: return None - call_args_count = ASTLinter(call).args_count() + call_args_count = ASTLinter.args_count(node) if call_args_count < self.min_args or call_args_count > self.max_args: return None # Check 3: check presence of the format specifier: # Option A: format specifier may be given as a direct parameter to the 
table-creating call # >>> df.saveToTable("c.db.table", format="csv") - format_arg = ASTLinter(call).get_arg(self.format_arg_index, self.format_arg_name) - if format_arg is not None and not ASTLinter(format_arg).is_none(): + format_arg = ASTLinter.get_arg(node, self.format_arg_index, self.format_arg_name) + if format_arg is not None and not ASTLinter.is_none(format_arg): # i.e., found an explicit "format" argument, and its value is not None. return None # Option B. format specifier may be a separate ".format(...)" call in this callchain # >>> df.format("csv").saveToTable("c.db.table") - format_call = ASTLinter(callchain).extract_call_by_name("format") + format_call = ASTLinter.extract_call_by_name(node, "format") if format_call is not None: # i.e., found an explicit ".format(...)" call in this chain. return None # Finally: matched the need for advice, so return the corresponding source range: return Range( - Position(call.lineno, call.col_offset), - Position(call.end_lineno or 0, call.end_col_offset or 0), + Position(node.lineno, node.col_offset), + Position(node.end_lineno or 0, node.end_col_offset or 0), ) @@ -78,7 +76,7 @@ class NoFormatPythonLinter: def __init__(self, matchers: list[NoFormatPythonMatcher]): self._matchers = matchers - def lint(self, node: ast.AST) -> Iterator[Advice]: + def lint(self, node: NodeNG) -> Iterator[Advice]: for matcher in self._matchers: span = matcher.get_advice_span(node) if span is not None: @@ -115,6 +113,6 @@ def lint(self, code: str) -> Iterable[Advice]: if self._skip_dbr: return - tree = ast.parse(code) - for node in ast.walk(tree): + linter = ASTLinter.parse(code) + for node in TreeWalker.walk(linter.root): yield from self._linter.lint(node) diff --git a/src/databricks/labs/ucx/source_code/python_libraries.py b/src/databricks/labs/ucx/source_code/python_libraries.py index 5b5fb9cc42..d1d1efd2c8 100644 --- a/src/databricks/labs/ucx/source_code/python_libraries.py +++ b/src/databricks/labs/ucx/source_code/python_libraries.py @@ 
-44,7 +44,7 @@ def _temporary_virtual_environment(self): # environment. If we don't have a virtual environment, create a temporary one. # simulate notebook-scoped virtual environment lib_install_folder = tempfile.mkdtemp(prefix='ucx-') - return Path(lib_install_folder) + return Path(lib_install_folder).resolve() def _install_library(self, path_lookup: PathLookup, library: Path) -> list[DependencyProblem]: """Pip install library and augment path look-up to resolve the library at import""" diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 7e1ae758ba..b06fbf1160 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -495,6 +495,7 @@ def config(self) -> WorkspaceConfig: ) def save_locations(self): + locations: list[ExternalLocation] = [] if self.workspace_client.config.is_azure: locations = [ExternalLocation("abfss://things@labsazurethings.dfs.core.windows.net/a", 1)] if self.workspace_client.config.is_aws: diff --git a/tests/integration/source_code/test_libraries.py b/tests/integration/source_code/test_libraries.py index 267ce7f864..0a51e0db9e 100644 --- a/tests/integration/source_code/test_libraries.py +++ b/tests/integration/source_code/test_libraries.py @@ -12,4 +12,4 @@ def test_loads_pip_library_from_notebook(simple_ctx): maybe = ctx.dependency_resolver.build_notebook_dependency_graph(Path('install_demo_wheel')) assert not maybe.problems - assert maybe.graph.all_relative_names() == {'install_demo_wheel.py'} + assert maybe.graph.all_relative_names() == {'install_demo_wheel.py', 'thingy/__init__.py'} diff --git a/tests/unit/source_code/linters/test_dbfs.py b/tests/unit/source_code/linters/test_dbfs.py index 1e469cf768..fc2f562fd2 100644 --- a/tests/unit/source_code/linters/test_dbfs.py +++ b/tests/unit/source_code/linters/test_dbfs.py @@ -8,9 +8,8 @@ class TestDetectDBFS: @pytest.mark.parametrize( "code, expected", [ - ('"/dbfs/mnt"', 1), - ('"not a file system path"', 0), - ('"/dbfs/mnt", "dbfs:/", 
"/mnt/"', 3), + ('SOME_CONSTANT = "not a file system path"', 0), + ('SOME_CONSTANT = ("/dbfs/mnt", "dbfs:/", "/mnt/")', 3), ('# "/dbfs/mnt"', 0), ('SOME_CONSTANT = "/dbfs/mnt"', 1), ('SOME_CONSTANT = "/dbfs/mnt"; load_data(SOME_CONSTANT)', 1), @@ -18,13 +17,11 @@ class TestDetectDBFS: ], ) def test_detects_dbfs_str_const_paths(self, code, expected): - finder = DBFSUsageLinter() - advices = finder.lint(code) - count = 0 + linter = DBFSUsageLinter() + advices = list(linter.lint(code)) for advice in advices: assert isinstance(advice, Advisory) - count += 1 - assert count == expected + assert len(advices) == expected @pytest.mark.parametrize( "code, expected", diff --git a/tests/unit/source_code/linters/test_pyspark.py b/tests/unit/source_code/linters/test_pyspark.py index 842cbe1de4..220e60271b 100644 --- a/tests/unit/source_code/linters/test_pyspark.py +++ b/tests/unit/source_code/linters/test_pyspark.py @@ -1,9 +1,10 @@ -import ast - import pytest +from astroid import Call, Const, Expr # type: ignore + from databricks.labs.ucx.source_code.base import Deprecation, CurrentSessionState -from databricks.labs.ucx.source_code.linters.pyspark import SparkSql, AstHelper, TableNameMatcher +from databricks.labs.ucx.source_code.linters.imports import ASTLinter +from databricks.labs.ucx.source_code.linters.pyspark import AstHelper, TableNameMatcher, SparkSql from databricks.labs.ucx.source_code.queries import FromTable @@ -95,7 +96,7 @@ def test_spark_table_return_value_apply(migration_index): do_stuff_with_table(table)""" fixed_code = sqf.apply(old_code) # no transformations to apply, only lint messages - assert fixed_code == old_code + assert fixed_code.rstrip() == old_code.rstrip() def test_spark_sql_fix(migration_index): @@ -109,7 +110,7 @@ def test_spark_sql_fix(migration_index): """ fixed_code = sqf.apply(old_code) assert ( - fixed_code + fixed_code.rstrip() == """spark.read.csv('s3://bucket/path') for i in range(10): result = spark.sql('SELECT * FROM 
brand.new.stuff').collect() @@ -548,53 +549,67 @@ def test_direct_cloud_access_reports_nothing(empty_index, fs_function): assert not advisories -def test_get_full_function_name(): +def test_get_full_function_name_for_member_function(): + linter = ASTLinter.parse("value.attr()") + node = linter.first_statement() + assert isinstance(node, Expr) + assert isinstance(node.value, Call) + assert AstHelper.get_full_function_name(node.value) == 'value.attr' + + +def test_get_full_function_name_for_member_member_function(): + linter = ASTLinter.parse("value1.value2.attr()") + node = linter.first_statement() + assert isinstance(node, Expr) + assert isinstance(node.value, Call) + assert AstHelper.get_full_function_name(node.value) == 'value1.value2.attr' - # Test when node.func is an instance of ast.Attribute - node = ast.Call(func=ast.Attribute(value=ast.Name(id='value'), attr='attr')) - # noinspection PyProtectedMember - assert AstHelper.get_full_function_name(node) == 'value.attr' - # Test when node.func is an instance of ast.Name - node = ast.Call(func=ast.Name(id='name')) - # noinspection PyProtectedMember - assert AstHelper.get_full_function_name(node) == 'name' +def test_get_full_function_name_for_chained_function(): + linter = ASTLinter.parse("value.attr1().attr2()") + node = linter.first_statement() + assert isinstance(node, Expr) + assert isinstance(node.value, Call) + assert AstHelper.get_full_function_name(node.value) == 'value.attr1.attr2' - # Test when node.func is neither ast.Attribute nor ast.Name - node = ast.Call(func=ast.Constant(value='constant')) - # noinspection PyProtectedMember - assert AstHelper.get_full_function_name(node) is None - # Test when next_node in _get_value is an instance of ast.Name - node = ast.Call(func=ast.Attribute(value=ast.Name(id='name'), attr='attr')) - # noinspection PyProtectedMember - assert AstHelper.get_full_function_name(node) == 'name.attr' +def test_get_full_function_name_for_global_function(): + linter = 
ASTLinter.parse("name()") + node = linter.first_statement() + assert isinstance(node, Expr) + assert isinstance(node.value, Call) + assert AstHelper.get_full_function_name(node.value) == 'name' - # Test when next_node in _get_value is an instance of ast.Attribute - node = ast.Call(func=ast.Attribute(value=ast.Attribute(value=ast.Name(id='value'), attr='attr'), attr='attr')) - # noinspection PyProtectedMember - assert AstHelper.get_full_function_name(node) == 'value.attr.attr' - # Test when next_node in _get_value is neither ast.Name nor ast.Attribute - node = ast.Call(func=ast.Attribute(value=ast.Constant(value='constant'), attr='attr')) - # noinspection PyProtectedMember - assert AstHelper.get_full_function_name(node) is None +def test_get_full_function_name_for_non_method(): + linter = ASTLinter.parse("not_a_function") + node = linter.first_statement() + assert isinstance(node, Expr) + assert AstHelper.get_full_function_name(node.value) is None -def test_apply_table_name_matcher(migration_index): +def test_apply_table_name_matcher_with_missing_constant(migration_index): from_table = FromTable(migration_index, CurrentSessionState('old')) matcher = TableNameMatcher('things', 1, 1, 0) - # Test when table_arg is an instance of ast.Constant but the destination does not exist in the index - node = ast.Call(args=[ast.Constant(value='some.things')]) - matcher.apply(from_table, migration_index, node) - table_constant = node.args[0] - assert isinstance(table_constant, ast.Constant) + linter = ASTLinter.parse("call('some.things')") + node = linter.first_statement() + assert isinstance(node, Expr) + assert isinstance(node.value, Call) + matcher.apply(from_table, migration_index, node.value) + table_constant = node.value.args[0] + assert isinstance(table_constant, Const) assert table_constant.value == 'some.things' - # Test when table_arg is an instance of ast.Constant and the destination exists in the index - node = ast.Call(args=[ast.Constant(value='old.things')]) - 
matcher.apply(from_table, migration_index, node) - table_constant = node.args[0] - assert isinstance(table_constant, ast.Constant) + +def test_apply_table_name_matcher_with_existing_constant(migration_index): + from_table = FromTable(migration_index, CurrentSessionState('old')) + matcher = TableNameMatcher('things', 1, 1, 0) + linter = ASTLinter.parse("call('old.things')") + node = linter.first_statement() + assert isinstance(node, Expr) + assert isinstance(node.value, Call) + matcher.apply(from_table, migration_index, node.value) + table_constant = node.value.args[0] + assert isinstance(table_constant, Const) assert table_constant.value == 'brand.new.stuff' diff --git a/tests/unit/source_code/linters/test_python_linter.py b/tests/unit/source_code/linters/test_python_linter.py index 457a2195ca..d58f9c3d45 100644 --- a/tests/unit/source_code/linters/test_python_linter.py +++ b/tests/unit/source_code/linters/test_python_linter.py @@ -1,10 +1,11 @@ from __future__ import annotations -import ast + import pytest +from astroid import Attribute, Call, Const, Expr # type: ignore from databricks.labs.ucx.source_code.graph import DependencyProblem -from databricks.labs.ucx.source_code.linters.imports import ASTLinter, DbutilsLinter +from databricks.labs.ucx.source_code.linters.imports import ASTLinter, DbutilsLinter, TreeWalker def test_linter_returns_empty_list_of_dbutils_notebook_run_calls(): @@ -25,7 +26,7 @@ def test_linter_returns_list_of_dbutils_notebook_run_calls(): def test_linter_returns_empty_list_of_imports(): linter = ASTLinter.parse('') - assert [] == DbutilsLinter.list_import_sources(linter, DependencyProblem)[0] + assert not DbutilsLinter.list_import_sources(linter, DependencyProblem)[0] def test_linter_returns_import(): @@ -135,96 +136,100 @@ def test_linter_returns_appended_relative_paths_with_os_path_abspath_alias(): assert "relative_path" in [p.path for p in appended] -def get_statement_node(stmt: str) -> ast.stmt: - node = ast.parse(stmt) - return 
node.body[0] - - -@pytest.mark.parametrize("stmt", ["o.m1().m2().m3()", "a = o.m1().m2().m3()"]) -def test_extract_callchain(migration_index, stmt): - node = get_statement_node(stmt) - act = ASTLinter(node).extract_callchain() - assert isinstance(act, ast.Call) - assert isinstance(act.func, ast.Attribute) - assert act.func.attr == "m3" - - -@pytest.mark.parametrize("stmt", ["a = 3", "[x+1 for x in xs]"]) -def test_extract_callchain_none(migration_index, stmt): - node = get_statement_node(stmt) - act = ASTLinter(node).extract_callchain() - assert act is None - - -def test_extract_call_by_name(migration_index): - callchain = get_statement_node("o.m1().m2().m3()").value - act = ASTLinter(callchain).extract_call_by_name("m2") - assert isinstance(act, ast.Call) - assert isinstance(act.func, ast.Attribute) - assert act.func.attr == "m2" +def test_extract_call_by_name(): + linter = ASTLinter.parse("o.m1().m2().m3()") + stmt = linter.first_statement() + assert isinstance(stmt, Expr) + act = ASTLinter.extract_call_by_name(stmt.value, "m2") + assert isinstance(act, Call) + assert isinstance(act.func, Attribute) + assert act.func.attrname == "m2" -def test_extract_call_by_name_none(migration_index): - callchain = get_statement_node("o.m1().m2().m3()").value - act = ASTLinter(callchain).extract_call_by_name("m5000") +def test_extract_call_by_name_none(): + linter = ASTLinter.parse("o.m1().m2().m3()") + stmt = linter.first_statement() + assert isinstance(stmt, Expr) + assert isinstance(stmt.value, Call) + act = ASTLinter.extract_call_by_name(stmt.value, "m5000") assert act is None @pytest.mark.parametrize( - "param", + "code, arg_index, arg_name, expected", [ - {"stmt": "o.m1()", "arg_index": 1, "arg_name": "second", "expected": None}, - {"stmt": "o.m1(3)", "arg_index": 1, "arg_name": "second", "expected": None}, - {"stmt": "o.m1(first=3)", "arg_index": 1, "arg_name": "second", "expected": None}, - {"stmt": "o.m1(4, 3)", "arg_index": None, "arg_name": None, "expected": None}, - 
{"stmt": "o.m1(4, 3)", "arg_index": None, "arg_name": "second", "expected": None}, - {"stmt": "o.m1(4, 3)", "arg_index": 1, "arg_name": "second", "expected": 3}, - {"stmt": "o.m1(4, 3)", "arg_index": 1, "arg_name": None, "expected": 3}, - {"stmt": "o.m1(first=4, second=3)", "arg_index": 1, "arg_name": "second", "expected": 3}, - {"stmt": "o.m1(second=3, first=4)", "arg_index": 1, "arg_name": "second", "expected": 3}, - {"stmt": "o.m1(second=3, first=4)", "arg_index": None, "arg_name": "second", "expected": 3}, - {"stmt": "o.m1(second=3)", "arg_index": 1, "arg_name": "second", "expected": 3}, - {"stmt": "o.m1(4, 3, 2)", "arg_index": 1, "arg_name": "second", "expected": 3}, + ("o.m1()", 1, "second", None), + ("o.m1(3)", 1, "second", None), + ("o.m1(first=3)", 1, "second", None), + ("o.m1(4, 3)", None, None, None), + ("o.m1(4, 3)", None, "second", None), + ("o.m1(4, 3)", 1, "second", 3), + ("o.m1(4, 3)", 1, None, 3), + ("o.m1(first=4, second=3)", 1, "second", 3), + ("o.m1(second=3, first=4)", 1, "second", 3), + ("o.m1(second=3, first=4)", None, "second", 3), + ("o.m1(second=3)", 1, "second", 3), + ("o.m1(4, 3, 2)", 1, "second", 3), ], ) -def test_get_arg(migration_index, param): - call = get_statement_node(param["stmt"]).value - act = ASTLinter(call).get_arg(param["arg_index"], param["arg_name"]) - if param["expected"] is None: +def test_linter_gets_arg(code, arg_index, arg_name, expected): + linter = ASTLinter.parse(code) + stmt = linter.first_statement() + assert isinstance(stmt, Expr) + assert isinstance(stmt.value, Call) + act = ASTLinter.get_arg(stmt.value, arg_index, arg_name) + if expected is None: assert act is None else: - assert isinstance(act, ast.Constant) - assert act.value == param["expected"] + assert isinstance(act, Const) + assert act.value == expected @pytest.mark.parametrize( - "param", + "code, expected", [ - {"stmt": "o.m1()", "expected": 0}, - {"stmt": "o.m1(3)", "expected": 1}, - {"stmt": "o.m1(first=3)", "expected": 1}, - {"stmt": "o.m1(3, 3)", 
"expected": 2}, - {"stmt": "o.m1(first=3, second=3)", "expected": 2}, - {"stmt": "o.m1(3, second=3)", "expected": 2}, - {"stmt": "o.m1(3, *b, **c, second=3)", "expected": 4}, + ("o.m1()", 0), + ("o.m1(3)", 1), + ("o.m1(first=3)", 1), + ("o.m1(3, 3)", 2), + ("o.m1(first=3, second=3)", 2), + ("o.m1(3, second=3)", 2), + ("o.m1(3, *b, **c, second=3)", 4), ], ) -def test_args_count(migration_index, param): - call = get_statement_node(param["stmt"]).value - act = ASTLinter(call).args_count() - assert param["expected"] == act +def test_args_count(code, expected): + linter = ASTLinter.parse(code) + stmt = linter.first_statement() + assert isinstance(stmt, Expr) + assert isinstance(stmt.value, Call) + act = ASTLinter.args_count(stmt.value) + assert act == expected @pytest.mark.parametrize( - "param", + "code, expected", [ - {"stmt": "a = x", "expected": False}, - {"stmt": "a = 3", "expected": False}, - {"stmt": "a = 'None'", "expected": False}, - {"stmt": "a = None", "expected": True}, + ( + """ +name = "xyz" +dbutils.notebook.run(name) +""", + "xyz", + ) ], ) -def test_is_none(migration_index, param): - val = get_statement_node(param["stmt"]).value - act = ASTLinter(val).is_none() - assert param["expected"] == act +def test_infers_string_variable_value(code, expected): + linter = ASTLinter.parse(code) + calls = DbutilsLinter.list_dbutils_notebook_run_calls(linter) + actual = list(call.get_notebook_path() for call in calls) + assert [expected] == actual + + +def test_tree_walker_walks_nodes_once(): + nodes = set() + count = 0 + linter = ASTLinter.parse("o.m1().m2().m3()") + for node in TreeWalker.walk(linter.root): + nodes.add(node) + count += 1 + assert len(nodes) == count diff --git a/tests/unit/source_code/linters/test_spark_connect.py b/tests/unit/source_code/linters/test_spark_connect.py index 907cd09de6..8d6a8bfe70 100644 --- a/tests/unit/source_code/linters/test_spark_connect.py +++ b/tests/unit/source_code/linters/test_spark_connect.py @@ -1,7 +1,7 @@ -import ast 
from itertools import chain from databricks.labs.ucx.source_code.base import Failure +from databricks.labs.ucx.source_code.linters.imports import TreeWalker, ASTLinter from databricks.labs.ucx.source_code.linters.spark_connect import LoggingMatcher, SparkConnectLinter @@ -11,8 +11,7 @@ def test_jvm_access_match_shared(): spark.range(10).collect() spark._jspark._jvm.com.my.custom.Name() """ - - assert [ + expected = [ Failure( code="jvm-access-in-shared-clusters", message='Cannot access Spark Driver JVM on UC Shared Clusters', @@ -21,7 +20,9 @@ def test_jvm_access_match_shared(): end_line=3, end_col=18, ), - ] == list(linter.lint(code)) + ] + actual = list(linter.lint(code)) + assert actual == expected def test_jvm_access_match_serverless(): @@ -31,7 +32,7 @@ def test_jvm_access_match_serverless(): spark._jspark._jvm.com.my.custom.Name() """ - assert [ + expected = [ Failure( code="jvm-access-in-shared-clusters", message='Cannot access Spark Driver JVM on Serverless Compute', @@ -40,7 +41,9 @@ def test_jvm_access_match_serverless(): end_line=3, end_col=18, ), - ] == list(linter.lint(code)) + ] + actual = list(linter.lint(code)) + assert actual == expected def test_rdd_context_match_shared(): @@ -49,7 +52,7 @@ def test_rdd_context_match_shared(): rdd1 = sc.parallelize([1, 2, 3]) rdd2 = spark.createDataFrame(sc.emptyRDD(), schema) """ - assert [ + expected = [ Failure( code="rdd-in-shared-clusters", message='RDD APIs are not supported on UC Shared Clusters. Rewrite it using DataFrame API', @@ -82,7 +85,9 @@ def test_rdd_context_match_shared(): end_line=3, end_col=40, ), - ] == list(linter.lint(code)) + ] + actual = list(linter.lint(code)) + assert actual == expected def test_rdd_context_match_serverless(): @@ -133,7 +138,7 @@ def test_rdd_map_partitions(): df = spark.createDataFrame([]) df.rdd.mapPartitions(myUdf) """ - assert [ + expected = [ Failure( code="rdd-in-shared-clusters", message='RDD APIs are not supported on UC Shared Clusters. 
Use mapInArrow() or Pandas UDFs instead', @@ -142,7 +147,9 @@ def test_rdd_map_partitions(): end_line=3, end_col=27, ), - ] == list(linter.lint(code)) + ] + actual = list(linter.lint(code)) + assert actual == expected def test_conf_shared(): @@ -163,7 +170,7 @@ def test_conf_shared(): def test_conf_serverless(): linter = SparkConnectLinter(is_serverless=True) code = """sc._conf().get('spark.my.conf')""" - assert [ + expected = [ Failure( code='legacy-context-in-shared-clusters', message='sc and _conf are not supported on Serverless Compute. Rewrite it using spark.conf', @@ -172,7 +179,9 @@ def test_conf_serverless(): end_line=1, end_col=8, ), - ] == list(linter.lint(code)) + ] + actual = list(linter.lint(code)) + assert actual == expected def test_logging_shared(): @@ -213,7 +222,7 @@ def test_logging_shared(): end_line=7, end_col=24, ), - ] == list(chain.from_iterable([logging_matcher.lint(node) for node in ast.walk(ast.parse(code))])) + ] == list(chain.from_iterable([logging_matcher.lint(node) for node in TreeWalker.walk(ASTLinter.parse(code).root)])) def test_logging_serverless(): @@ -242,7 +251,7 @@ def test_logging_serverless(): end_line=3, end_col=38, ), - ] == list(chain.from_iterable([logging_matcher.lint(node) for node in ast.walk(ast.parse(code))])) + ] == list(chain.from_iterable([logging_matcher.lint(node) for node in TreeWalker.walk(ASTLinter.parse(code).root)])) def test_valid_code(): diff --git a/tests/unit/source_code/linters/test_table_creation.py b/tests/unit/source_code/linters/test_table_creation.py index 34ee80558d..a90c6e9d31 100644 --- a/tests/unit/source_code/linters/test_table_creation.py +++ b/tests/unit/source_code/linters/test_table_creation.py @@ -120,4 +120,5 @@ def test_dbr_version_filter(migration_index, dbr_version): """Tests the DBR version cutoff filter""" old_code = get_code(False, 'spark.foo().bar().table("catalog.db.table").baz()') expected = [] if dbr_version["suppress"] else [get_advice(False, 'table', 18)] - assert expected 
== lint(old_code, dbr_version["version"]) + actual = lint(old_code, dbr_version["version"]) + assert actual == expected diff --git a/tests/unit/source_code/samples/root10.py b/tests/unit/source_code/samples/root10.py deleted file mode 100644 index 421b9a878f..0000000000 --- a/tests/unit/source_code/samples/root10.py +++ /dev/null @@ -1,3 +0,0 @@ -# Databricks notebook source -some_notebook = "./leaf3.py" -dbutils.notebook.run(some_notebook) diff --git a/tests/unit/source_code/test_dependencies.py b/tests/unit/source_code/test_dependencies.py index fdab6a7655..1e5352de0b 100644 --- a/tests/unit/source_code/test_dependencies.py +++ b/tests/unit/source_code/test_dependencies.py @@ -6,14 +6,14 @@ DependencyProblem, Dependency, ) +from databricks.labs.ucx.source_code.linters.files import FileLoader, ImportFileResolver from databricks.labs.ucx.source_code.notebooks.loaders import ( NotebookResolver, NotebookLoader, ) -from databricks.labs.ucx.source_code.linters.files import FileLoader, ImportFileResolver from databricks.labs.ucx.source_code.path_lookup import PathLookup -from databricks.labs.ucx.source_code.python_libraries import PythonLibraryResolver from databricks.labs.ucx.source_code.known import Whitelist +from databricks.labs.ucx.source_code.python_libraries import PythonLibraryResolver from tests.unit import ( locate_site_packages, ) @@ -94,25 +94,6 @@ def test_dependency_resolver_raises_problem_with_unfound_local_notebook_dependen ] -def test_dependency_resolver_raises_problem_with_non_constant_local_notebook_dependency(mock_path_lookup): - notebook_loader = NotebookLoader() - notebook_resolver = NotebookResolver(notebook_loader) - pip_resolver = PythonLibraryResolver(Whitelist()) - dependency_resolver = DependencyResolver(pip_resolver, notebook_resolver, [], mock_path_lookup) - maybe = dependency_resolver.build_notebook_dependency_graph(Path('root10.py')) - assert list(maybe.problems) == [ - DependencyProblem( - 'dependency-not-constant', - "Can't check 
dependency not provided as a constant", - Path('root10.py'), - 2, - 0, - 2, - 35, - ) - ] - - def test_dependency_resolver_raises_problem_with_invalid_run_cell(mock_path_lookup): notebook_loader = NotebookLoader() notebook_resolver = NotebookResolver(notebook_loader) @@ -147,27 +128,6 @@ def test_dependency_resolver_raises_problem_with_unresolved_import(mock_path_loo ] -def test_dependency_resolver_raises_problem_with_non_constant_notebook_argument(mock_path_lookup): - notebook_loader = NotebookLoader() - notebook_resolver = NotebookResolver(notebook_loader) - whitelist = Whitelist() - import_resolver = ImportFileResolver(FileLoader(), whitelist) - pip_resolver = PythonLibraryResolver(whitelist) - dependency_resolver = DependencyResolver(pip_resolver, notebook_resolver, import_resolver, mock_path_lookup) - maybe = dependency_resolver.build_local_file_dependency_graph(Path("run_notebooks.py")) - assert list(maybe.problems) == [ - DependencyProblem( - 'dependency-not-constant', - "Can't check dependency not provided as a constant", - Path("run_notebooks.py"), - 14, - 13, - 14, - 50, - ) - ] - - def test_dependency_resolver_visits_file_dependencies(mock_path_lookup): notebook_loader = NotebookLoader() notebook_resolver = NotebookResolver(notebook_loader) diff --git a/tests/unit/source_code/test_functional.py b/tests/unit/source_code/test_functional.py index f251a0cc28..a256b43850 100644 --- a/tests/unit/source_code/test_functional.py +++ b/tests/unit/source_code/test_functional.py @@ -50,7 +50,7 @@ def verify(self): actual_problems = sorted(list(self._lint()), key=lambda a: (a.start_line, a.start_col)) high_level_expected = [f'{p.code}:{p.message}' for p in expected_problems] high_level_actual = [f'{p.code}:{p.message}' for p in actual_problems] - assert high_level_expected == high_level_actual + assert high_level_actual == high_level_expected # TODO: match start/end lines/columns as well. 
At the moment notebook parsing has a bug that makes it impossible # TODO: output annotated file with comments for quick fixing diff --git a/tests/unit/source_code/test_notebook.py b/tests/unit/source_code/test_notebook.py index a1c52834c6..de3b00c309 100644 --- a/tests/unit/source_code/test_notebook.py +++ b/tests/unit/source_code/test_notebook.py @@ -7,13 +7,13 @@ from databricks.labs.ucx.source_code.base import Advisory from databricks.labs.ucx.source_code.graph import DependencyGraph, SourceContainer, DependencyResolver from databricks.labs.ucx.source_code.known import Whitelist +from databricks.labs.ucx.source_code.linters.imports import DbutilsLinter from databricks.labs.ucx.source_code.notebooks.sources import Notebook from databricks.labs.ucx.source_code.notebooks.loaders import ( NotebookResolver, NotebookLoader, ) from databricks.labs.ucx.source_code.python_libraries import PythonLibraryResolver -from databricks.labs.ucx.source_code.linters.imports import DbutilsLinter from tests.unit import _load_sources # fmt: off diff --git a/tests/unit/source_code/test_notebook_linter.py b/tests/unit/source_code/test_notebook_linter.py index 7ac234f6bb..fa6f7c3ee5 100644 --- a/tests/unit/source_code/test_notebook_linter.py +++ b/tests/unit/source_code/test_notebook_linter.py @@ -476,7 +476,15 @@ def test_notebook_linter_name(): [ Deprecation( code='table-migrate', - message='Table people is migrated to cata4.nondefault.newpeople ' 'in Unity Catalog', + message='Table people is migrated to cata4.nondefault.newpeople in Unity Catalog', + start_line=6, + start_col=8, + end_line=6, + end_col=29, + ), + Advice( + code='table-migrate', + message='The default format changed in Databricks Runtime 8.0, from Parquet to Delta', start_line=6, start_col=8, end_line=6, @@ -484,7 +492,15 @@ def test_notebook_linter_name(): ), Deprecation( code='table-migrate', - message='Table persons is migrated to cata4.newsomething.persons ' 'in Unity Catalog', + message='Table persons is migrated 
to cata4.newsomething.persons in Unity Catalog', + start_line=14, + start_col=8, + end_line=14, + end_col=30, + ), + Advice( + code='table-migrate', + message='The default format changed in Databricks Runtime 8.0, from Parquet to Delta', start_line=14, start_col=8, end_line=14, @@ -492,7 +508,15 @@ def test_notebook_linter_name(): ), Deprecation( code='table-migrate', - message='Table kittens is migrated to cata4.felines.toms in Unity ' 'Catalog', + message='Table kittens is migrated to cata4.felines.toms in Unity Catalog', + start_line=22, + start_col=8, + end_line=22, + end_col=30, + ), + Advice( + code='table-migrate', + message='The default format changed in Databricks Runtime 8.0, from Parquet to Delta', start_line=22, start_col=8, end_line=22, @@ -500,7 +524,7 @@ def test_notebook_linter_name(): ), Deprecation( code='table-migrate', - message='Table numbers is migrated to cata4.counting.numbers in ' 'Unity Catalog', + message='Table numbers is migrated to cata4.counting.numbers in Unity Catalog', start_line=26, start_col=0, end_line=26, @@ -508,7 +532,7 @@ def test_notebook_linter_name(): ), Advice( code='table-migrate', - message='The default format changed in Databricks Runtime 8.0, from ' 'Parquet to Delta', + message='The default format changed in Databricks Runtime 8.0, from Parquet to Delta', start_line=26, start_col=0, end_line=26,