diff --git a/src/databricks/labs/ucx/source_code/base.py b/src/databricks/labs/ucx/source_code/base.py index bc79231f58..ccd1a1fa7d 100644 --- a/src/databricks/labs/ucx/source_code/base.py +++ b/src/databricks/labs/ucx/source_code/base.py @@ -6,10 +6,11 @@ from dataclasses import dataclass from pathlib import Path -from astroid import NodeNG # type: ignore +from astroid import AstroidSyntaxError, NodeNG # type: ignore from databricks.sdk.service import compute +from databricks.labs.ucx.source_code.linters.python_ast import Tree # Code mapping between LSP, PyLint, and our own diagnostics: # | LSP | PyLint | Our | @@ -130,6 +131,16 @@ class Linter: def lint(self, code: str) -> Iterable[Advice]: ... +class PythonLinter(Linter): + + def lint(self, code: str) -> Iterable[Advice]: + tree = Tree.normalize_and_parse(code) + yield from self.lint_tree(tree) + + @abstractmethod + def lint_tree(self, tree: Tree) -> Iterable[Advice]: ... + + class Fixer: @abstractmethod def name(self) -> str: ... @@ -170,3 +181,22 @@ def __init__(self, linters: list[Linter]): def lint(self, code: str) -> Iterable[Advice]: for linter in self._linters: yield from linter.lint(code) + + +class PythonSequentialLinter(Linter): + + def __init__(self, linters: list[PythonLinter]): + self._linters = linters + self._tree: Tree | None = None + + def lint(self, code: str) -> Iterable[Advice]: + try: + tree = Tree.normalize_and_parse(code) + if self._tree is None: + self._tree = tree + else: + tree = self._tree.append_statements(tree) + for linter in self._linters: + yield from linter.lint_tree(tree) + except AstroidSyntaxError as e: + yield Failure('syntax-error', str(e), 0, 0, 0, 0) diff --git a/src/databricks/labs/ucx/source_code/known.json b/src/databricks/labs/ucx/source_code/known.json index 6a66138446..61e5b151f6 100644 --- a/src/databricks/labs/ucx/source_code/known.json +++ b/src/databricks/labs/ucx/source_code/known.json @@ -1265,7 +1265,7 @@ "code": "dbfs-usage", "message": "Deprecated file system path: dbfs:/" }, - { + { "code": "table-migrate", "message": "The default format changed in Databricks Runtime 8.0, from Parquet to Delta" } @@ -2572,6 +2572,14 @@ "dockerpycreds.utils": [], "dockerpycreds.version": [] }, + "docstring-to-markdown": { + "docstring_to_markdown": [], + "docstring_to_markdown._utils": [], + "docstring_to_markdown.cpython": [], + "docstring_to_markdown.google": [], + "docstring_to_markdown.plain": [], + "docstring_to_markdown.rst": [] + }, "entrypoints": { "entrypoints": [] }, @@ -21782,6 +21790,53 @@ "python-dateutil": { "dateutil": [] }, + "python-lsp-jsonrpc": { + "pylsp_jsonrpc": [], + "pylsp_jsonrpc._version": [], + "pylsp_jsonrpc.dispatchers": [], + "pylsp_jsonrpc.endpoint": [], + "pylsp_jsonrpc.exceptions": [], + "pylsp_jsonrpc.streams": [] + }, + "python-lsp-server": { + "pylsp": [], + "pylsp._utils": [], + "pylsp._version": [], + "pylsp.config": [], + "pylsp.config.config": [], + "pylsp.config.flake8_conf": [], + "pylsp.config.pycodestyle_conf": [], + "pylsp.config.source": [], + "pylsp.hookspecs": [], + "pylsp.lsp": [], + "pylsp.plugins": [], + "pylsp.plugins._resolvers": [], + "pylsp.plugins._rope_task_handle": [], + "pylsp.plugins.autopep8_format": [], + "pylsp.plugins.definition": [], + "pylsp.plugins.flake8_lint": [], + "pylsp.plugins.folding": [], + "pylsp.plugins.highlight": [], + "pylsp.plugins.hover": [], + "pylsp.plugins.jedi_completion": [], + "pylsp.plugins.jedi_rename": [], + "pylsp.plugins.mccabe_lint": [], + "pylsp.plugins.preload_imports": [], + "pylsp.plugins.pycodestyle_lint": [], + "pylsp.plugins.pydocstyle_lint": [], + "pylsp.plugins.pyflakes_lint": [], + "pylsp.plugins.pylint_lint": [], + "pylsp.plugins.references": [], + "pylsp.plugins.rope_autoimport": [], + "pylsp.plugins.rope_completion": [], + "pylsp.plugins.signature": [], + "pylsp.plugins.symbols": [], + "pylsp.plugins.yapf_format": [], + "pylsp.python_lsp": [], + "pylsp.text_edit": [], + "pylsp.uris": [], + "pylsp.workspace": [] + }, "pytz": { "pytz": [] }, @@ -25156,6 +25211,7 @@ "tzdata": { "tzdata": [] }, + "ujson": {}, "umap": { "umap": [], "umap.get": [] @@ -25957,5 +26013,4 @@ "zipp.compat.py310": [], "zipp.glob": [] } -} - +} \ No newline at end of file diff --git a/src/databricks/labs/ucx/source_code/linters/context.py b/src/databricks/labs/ucx/source_code/linters/context.py index d36776e46e..fad60107f9 100644 --- a/src/databricks/labs/ucx/source_code/linters/context.py +++ b/src/databricks/labs/ucx/source_code/linters/context.py @@ -1,7 +1,16 @@ +from typing import cast + from databricks.sdk.service.workspace import Language from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex -from databricks.labs.ucx.source_code.base import Fixer, Linter, SequentialLinter, CurrentSessionState +from databricks.labs.ucx.source_code.base import ( + Fixer, + Linter, + SequentialLinter, + CurrentSessionState, + PythonSequentialLinter, + PythonLinter, +) from databricks.labs.ucx.source_code.linters.dbfs import FromDbfsFolder, DBFSUsageLinter from databricks.labs.ucx.source_code.linters.imports import DbutilsLinter @@ -16,7 +25,7 @@ def __init__(self, index: MigrationIndex | None = None, session_state: CurrentSe self._index = index session_state = CurrentSessionState() if not session_state else session_state - python_linters: list[Linter] = [] + python_linters: list[PythonLinter] = [] python_fixers: list[Fixer] = [] sql_linters: list[Linter] = [] @@ -38,9 +47,9 @@ def __init__(self, index: MigrationIndex | None = None, session_state: CurrentSe ] sql_linters.append(FromDbfsFolder()) - self._linters = { - Language.PYTHON: SequentialLinter(python_linters), - Language.SQL: SequentialLinter(sql_linters), + self._linters: dict[Language, list[Linter] | list[PythonLinter]] = { + Language.PYTHON: python_linters, + Language.SQL: sql_linters, } self._fixers: dict[Language, list[Fixer]] = { Language.PYTHON: python_fixers, @@ -53,7 +62,9 @@ def is_supported(self, language: Language) -> bool: def linter(self, language: Language) -> Linter: if language not in self._linters: raise ValueError(f"Unsupported language: {language}") - return self._linters[language] + if language is Language.PYTHON: + return PythonSequentialLinter(cast(list[PythonLinter], self._linters[language])) + return SequentialLinter(cast(list[Linter], self._linters[language])) def fixer(self, language: Language, diagnostic_code: str) -> Fixer | None: if language not in self._fixers: diff --git a/src/databricks/labs/ucx/source_code/linters/dbfs.py b/src/databricks/labs/ucx/source_code/linters/dbfs.py index b33c25698a..e5c7989793 100644 --- a/src/databricks/labs/ucx/source_code/linters/dbfs.py +++ b/src/databricks/labs/ucx/source_code/linters/dbfs.py @@ -5,8 +5,9 @@ from sqlglot import Expression, parse as parse_sql, ParseError as SqlParseError from sqlglot.expressions import Table -from databricks.labs.ucx.source_code.base import Advice, Linter, Deprecation, CurrentSessionState, Failure -from databricks.labs.ucx.source_code.linters.python_ast import Tree, TreeVisitor, InferredValue +from databricks.labs.ucx.source_code.base import Advice, Linter, Deprecation, CurrentSessionState, Failure, PythonLinter +from databricks.labs.ucx.source_code.linters.python_ast import Tree, TreeVisitor +from databricks.labs.ucx.source_code.linters.python_infer import InferredValue logger = logging.getLogger(__name__) @@ -29,7 +30,7 @@ def visit_call(self, node: Call): def _visit_arg(self, arg: NodeNG): try: - for inferred in Tree(arg).infer_values(self._session_state): + for inferred in InferredValue.infer_from_node(arg, self._session_state): if not inferred.is_inferred(): logger.debug(f"Could not infer value of {arg.as_string()}") continue @@ -64,7 +65,7 @@ def get_advices(self) -> Iterable[Advice]: yield from self._advices -class DBFSUsageLinter(Linter): +class DBFSUsageLinter(PythonLinter): def __init__(self, session_state: CurrentSessionState): self._session_state = session_state @@ -76,11 +77,10 @@ def name() -> str: """ return 'dbfs-usage' - def lint(self, code: str) -> Iterable[Advice]: + def lint_tree(self, tree: Tree) -> Iterable[Advice]: """ Lints the code looking for file system paths that are deprecated """ - tree = Tree.normalize_and_parse(code) visitor = DetectDbfsVisitor(self._session_state) visitor.visit(tree.node) yield from visitor.get_advices() diff --git a/src/databricks/labs/ucx/source_code/linters/imports.py b/src/databricks/labs/ucx/source_code/linters/imports.py index 5e408e2176..1ef187fad7 100644 --- a/src/databricks/labs/ucx/source_code/linters/imports.py +++ b/src/databricks/labs/ucx/source_code/linters/imports.py @@ -16,8 +16,9 @@ NodeNG, ) -from databricks.labs.ucx.source_code.base import Linter, Advice, Advisory, CurrentSessionState -from databricks.labs.ucx.source_code.linters.python_ast import Tree, NodeBase, TreeVisitor, InferredValue +from databricks.labs.ucx.source_code.base import Advice, Advisory, CurrentSessionState, PythonLinter +from databricks.labs.ucx.source_code.linters.python_ast import Tree, NodeBase, TreeVisitor +from databricks.labs.ucx.source_code.linters.python_infer import InferredValue logger = logging.getLogger(__name__) @@ -90,7 +91,7 @@ def get_notebook_paths(self, session_state: CurrentSessionState) -> tuple[bool, """ arg = DbutilsLinter.get_dbutils_notebook_run_path_arg(self.node) try: - all_inferred = Tree(arg).infer_values(session_state) + all_inferred = InferredValue.infer_from_node(arg, session_state) return self._get_notebook_paths(all_inferred) except InferenceError: logger.debug(f"Can't infer value(s) of {arg.as_string()}") @@ -110,13 +111,12 @@ def _get_notebook_paths(cls, all_inferred: Iterable[InferredValue]) -> tuple[boo return has_unresolved, paths -class DbutilsLinter(Linter): +class DbutilsLinter(PythonLinter): def __init__(self, session_state: CurrentSessionState): self._session_state = session_state - def lint(self, code: str) -> Iterable[Advice]: - tree = Tree.normalize_and_parse(code) + def lint_tree(self, tree: Tree) -> Iterable[Advice]: nodes = self.list_dbutils_notebook_run_calls(tree) for node in nodes: yield from self._raise_advice_if_unresolved(node.node, self._session_state) @@ -229,7 +229,7 @@ def visit_call(self, node: Call): relative = True changed = changed.args[0] try: - for inferred in Tree(changed).infer_values(self._session_state): + for inferred in InferredValue.infer_from_node(changed, self._session_state): self._visit_inferred(changed, inferred, relative, is_append) except InferenceError: self.sys_path_changes.append(UnresolvedPath(changed, changed.as_string(), is_append)) diff --git a/src/databricks/labs/ucx/source_code/linters/pyspark.py b/src/databricks/labs/ucx/source_code/linters/pyspark.py index d78267ebcb..b56bdc544d 100644 --- a/src/databricks/labs/ucx/source_code/linters/pyspark.py +++ b/src/databricks/labs/ucx/source_code/linters/pyspark.py @@ -2,19 +2,19 @@ from collections.abc import Iterable, Iterator from dataclasses import dataclass -from astroid import Attribute, Call, Const, InferenceError, NodeNG, AstroidSyntaxError # type: ignore +from astroid import Attribute, Call, Const, InferenceError, NodeNG # type: ignore from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex from databricks.labs.ucx.source_code.base import ( Advice, Advisory, Deprecation, Fixer, - Linter, - Failure, CurrentSessionState, + PythonLinter, ) +from databricks.labs.ucx.source_code.linters.python_infer import InferredValue from databricks.labs.ucx.source_code.queries import FromTable -from databricks.labs.ucx.source_code.linters.python_ast import Tree, InferredValue +from databricks.labs.ucx.source_code.linters.python_ast import Tree @dataclass @@ -78,7 +78,7 @@ def lint( table_arg = self._get_table_arg(node) if table_arg: try: - for inferred in Tree(table_arg).infer_values(self.session_state): + for inferred in InferredValue.infer_from_node(table_arg, self.session_state): yield from self._lint_table_arg(from_table, node, inferred) except InferenceError: yield Advisory.from_node( @@ -114,7 +114,7 @@ def lint( ) -> Iterator[Advice]: table_arg = self._get_table_arg(node) table_name = table_arg.as_string().strip("'").strip('"') - for inferred in Tree(table_arg).infer_values(session_state): + for inferred in InferredValue.infer_from_node(table_arg, session_state): if not inferred.is_inferred(): yield Advisory.from_node( code='table-migrate-cannot-compute-value', @@ -315,7 +315,7 @@ def matchers(self): return self._matchers -class SparkSql(Linter, Fixer): +class SparkSql(PythonLinter, Fixer): _spark_matchers = SparkMatchers() @@ -328,12 +328,7 @@ def name(self) -> str: # this is the same fixer, just in a different language context return self._from_table.name() - def lint(self, code: str) -> Iterable[Advice]: - try: - tree = Tree.normalize_and_parse(code) - except AstroidSyntaxError as e: - yield Failure('syntax-error', str(e), 0, 0, 0, 0) - return + def lint_tree(self, tree: Tree) -> Iterable[Advice]: for node in tree.walk(): matcher = self._find_matcher(node) if matcher is None: diff --git a/src/databricks/labs/ucx/source_code/linters/python_ast.py b/src/databricks/labs/ucx/source_code/linters/python_ast.py index 3a9ca4553f..97959c3687 100644 --- a/src/databricks/labs/ucx/source_code/linters/python_ast.py +++ b/src/databricks/labs/ucx/source_code/linters/python_ast.py @@ -3,15 +3,21 @@ from abc import ABC import logging import re -from collections.abc import Iterable, Iterator, Generator -from typing import Any, TypeVar - -from astroid import Assign, Attribute, Call, Const, decorators, Dict, FormattedValue, Import, ImportFrom, JoinedStr, Module, Name, NodeNG, parse, Uninferable # type: ignore -from astroid.context import InferenceContext, InferenceResult, CallContext # type: ignore -from astroid.typing import InferenceErrorInfo # type: ignore -from astroid.exceptions import InferenceError # type: ignore - -from databricks.labs.ucx.source_code.base import CurrentSessionState +from collections.abc import Iterable +from typing import TypeVar, cast + +from astroid import ( # type: ignore + Assign, + Attribute, + Call, + Const, + Import, + ImportFrom, + Module, + Name, + NodeNG, + parse, +) logger = logging.getLogger(__name__) @@ -192,191 +198,20 @@ def _get_attribute_value(cls, node: Attribute): logger.debug(f"Missing handler for {name}") return None - def infer_values(self, state: CurrentSessionState | None = None) -> Iterable[InferredValue]: - self._contextualize(state) - for inferred_atoms in self._infer_values(): - yield InferredValue(inferred_atoms) - - def _contextualize(self, state: CurrentSessionState | None): - if state is None or state.named_parameters is None or len(state.named_parameters) == 0: - return - self._contextualize_dbutils_widgets_get(state) - self._contextualize_dbutils_widgets_get_all(state) - - def _contextualize_dbutils_widgets_get(self, state: CurrentSessionState): - calls = Tree(self.root).locate(Call, [("get", Attribute), ("widgets", Attribute), ("dbutils", Name)]) - for call in calls: - call.func = _DbUtilsWidgetsGetCall(state, call) - - def _contextualize_dbutils_widgets_get_all(self, state: CurrentSessionState): - calls = Tree(self.root).locate(Call, [("getAll", Attribute), ("widgets", Attribute), ("dbutils", Name)]) - for call in calls: - call.func = _DbUtilsWidgetsGetAllCall(state, call) - - def _infer_values(self) -> Iterator[Iterable[NodeNG]]: - # deal with node types that don't implement 'inferred()' - if self._node is Uninferable or isinstance(self._node, Const): - yield [self._node] - elif isinstance(self._node, JoinedStr): - yield from self._infer_values_from_joined_string() - elif isinstance(self._node, FormattedValue): - yield from _LocalTree(self._node.value).do_infer_values() - else: - yield from self._infer_internal() - - def _infer_internal(self): - try: - for inferred in self._node.inferred(): - # work around infinite recursion of empty lists - if inferred == self._node: - continue - yield from _LocalTree(inferred).do_infer_values() - except InferenceError as e: - logger.debug(f"When inferring {self._node}", exc_info=e) - yield [Uninferable] - - def _infer_values_from_joined_string(self) -> Iterator[Iterable[NodeNG]]: - assert isinstance(self._node, JoinedStr) - yield from self._infer_values_from_joined_values(self._node.values) - - @classmethod - def _infer_values_from_joined_values(cls, nodes: list[NodeNG]) -> Iterator[Iterable[NodeNG]]: - if len(nodes) == 1: - yield from _LocalTree(nodes[0]).do_infer_values() - return - for firsts in _LocalTree(nodes[0]).do_infer_values(): - for remains in cls._infer_values_from_joined_values(nodes[1:]): - yield list(firsts) + list(remains) - - -class _LocalTree(Tree): - """class that avoids pylint W0212 protected-access warning""" - - def do_infer_values(self): - return self._infer_values() - - -class _DbUtilsWidgetsGetCall(NodeNG): - - def __init__(self, session_state: CurrentSessionState, node: NodeNG): - super().__init__( - lineno=node.lineno, - col_offset=node.col_offset, - end_lineno=node.end_lineno, - end_col_offset=node.end_col_offset, - parent=node.parent, - ) - self._session_state = session_state - - @decorators.raise_if_nothing_inferred - def _infer( - self, context: InferenceContext | None = None, **kwargs: Any - ) -> Generator[InferenceResult, None, InferenceErrorInfo | None]: - yield self - return InferenceErrorInfo(node=self, context=context) - - def infer_call_result(self, context: InferenceContext | None = None, **_): # caller needs unused kwargs - call_context = getattr(context, "callcontext", None) - if not isinstance(call_context, CallContext): - yield Uninferable - return - arg = call_context.args[0] - for inferred in Tree(arg).infer_values(self._session_state): - if not inferred.is_inferred(): - yield Uninferable - continue - name = inferred.as_string() - named_parameters = self._session_state.named_parameters - if not named_parameters or name not in named_parameters: - yield Uninferable - continue - value = named_parameters[name] - yield Const( - value, - lineno=self.lineno, - col_offset=self.col_offset, - end_lineno=self.end_lineno, - end_col_offset=self.end_col_offset, - parent=self, - ) - - -class _DbUtilsWidgetsGetAllCall(NodeNG): - - def __init__(self, session_state: CurrentSessionState, node: NodeNG): - super().__init__( - lineno=node.lineno, - col_offset=node.col_offset, - end_lineno=node.end_lineno, - end_col_offset=node.end_col_offset, - parent=node.parent, - ) - self._session_state = session_state - - @decorators.raise_if_nothing_inferred - def _infer( - self, context: InferenceContext | None = None, **kwargs: Any - ) -> Generator[InferenceResult, None, InferenceErrorInfo | None]: - yield self - return InferenceErrorInfo(node=self, context=context) - - def infer_call_result(self, **_): # caller needs unused kwargs - named_parameters = self._session_state.named_parameters - if not named_parameters: - yield Uninferable - return - items = self._populate_items(named_parameters) - result = Dict( - lineno=self.lineno, - col_offset=self.col_offset, - end_lineno=self.end_lineno, - end_col_offset=self.end_col_offset, - parent=self, - ) - result.postinit(items) - yield result - - def _populate_items(self, values: dict[str, str]): - items: list[tuple[InferenceResult, InferenceResult]] = [] - for key, value in values.items(): - item_key = Const( - key, - lineno=self.lineno, - col_offset=self.col_offset, - end_lineno=self.end_lineno, - end_col_offset=self.end_col_offset, - parent=self, - ) - item_value = Const( - value, - lineno=self.lineno, - col_offset=self.col_offset, - end_lineno=self.end_lineno, - end_col_offset=self.end_col_offset, - parent=self, - ) - items.append((item_key, item_value)) - return items - - -class InferredValue: - """Represents 1 or more nodes that together represent the value. - The list of nodes typically holds one Const element, but for f-strings it - can hold multiple ones, including Uninferable nodes.""" - - def __init__(self, atoms: Iterable[NodeNG]): - self._atoms = list(atoms) - - @property - def nodes(self): - return self._atoms - - def is_inferred(self): - return all(atom is not Uninferable for atom in self._atoms) - - def as_string(self): - strings = [str(const.value) for const in filter(lambda atom: isinstance(atom, Const), self._atoms)] - return "".join(strings) + def append_statements(self, tree: Tree) -> Tree: + if not isinstance(tree.node, Module): + raise NotImplementedError(f"Can't append statements from {type(tree.node).__name__}") + tree_module: Module = cast(Module, tree.node) + if not isinstance(self.node, Module): + raise NotImplementedError(f"Can't append statements to {type(self.node).__name__}") + self_module: Module = cast(Module, self.node) + for stmt in tree_module.body: + stmt.parent = self_module + self_module.body.append(stmt) + for name, value in tree_module.globals.items(): + self_module.globals[name] = value + # the following may seem strange but it's actually ok to use the original module as tree root + return tree class TreeVisitor: diff --git a/src/databricks/labs/ucx/source_code/linters/python_infer.py b/src/databricks/labs/ucx/source_code/linters/python_infer.py new file mode 100644 index 0000000000..073ab362a6 --- /dev/null +++ b/src/databricks/labs/ucx/source_code/linters/python_infer.py @@ -0,0 +1,221 @@ +from __future__ import annotations + +import logging +from collections.abc import Iterable, Iterator, Generator +from typing import Any + +from astroid import ( # type: ignore + Attribute, + Call, + Const, + decorators, + Dict, + FormattedValue, + JoinedStr, + Name, + NodeNG, + Uninferable, +) +from astroid.context import InferenceContext, InferenceResult, CallContext # type: ignore +from astroid.typing import InferenceErrorInfo # type: ignore +from astroid.exceptions import InferenceError # type: ignore + +from databricks.labs.ucx.source_code.base import CurrentSessionState +from databricks.labs.ucx.source_code.linters.python_ast import Tree + +logger = logging.getLogger(__name__) + + +class InferredValue: + """Represents 1 or more nodes that together represent the value. + The list of nodes typically holds one Const element, but for f-strings it + can hold multiple ones, including Uninferable nodes.""" + + @classmethod + def infer_from_node(cls, node: NodeNG, state: CurrentSessionState | None = None) -> Iterable[InferredValue]: + cls._contextualize(node, state) + for inferred_atoms in cls._infer_values(node): + yield InferredValue(inferred_atoms) + + @classmethod + def _contextualize(cls, node: NodeNG, state: CurrentSessionState | None): + if state is None or state.named_parameters is None or len(state.named_parameters) == 0: + return + cls._contextualize_dbutils_widgets_get(node, state) + cls._contextualize_dbutils_widgets_get_all(node, state) + + @classmethod + def _contextualize_dbutils_widgets_get(cls, node: NodeNG, state: CurrentSessionState): + root = Tree(node).root + calls = Tree(root).locate(Call, [("get", Attribute), ("widgets", Attribute), ("dbutils", Name)]) + for call in calls: + call.func = _DbUtilsWidgetsGetCall(state, call) + + @classmethod + def _contextualize_dbutils_widgets_get_all(cls, node: NodeNG, state: CurrentSessionState): + root = Tree(node).root + calls = Tree(root).locate(Call, [("getAll", Attribute), ("widgets", Attribute), ("dbutils", Name)]) + for call in calls: + call.func = _DbUtilsWidgetsGetAllCall(state, call) + + @classmethod + def _infer_values(cls, node: NodeNG) -> Iterator[Iterable[NodeNG]]: + # deal with node types that don't implement 'inferred()' + if node is Uninferable or isinstance(node, Const): + yield [node] + elif isinstance(node, JoinedStr): + yield from cls._infer_values_from_joined_string(node) + elif isinstance(node, FormattedValue): + yield from _LocalInferredValue.do_infer_values(node.value) + else: + yield from cls._infer_internal(node) + + @classmethod + def _infer_internal(cls, node: NodeNG): + try: + for inferred in node.inferred(): + # work around infinite recursion of empty lists + if inferred == node: + continue + yield from _LocalInferredValue.do_infer_values(inferred) + except InferenceError as e: + logger.debug(f"When inferring {node}", exc_info=e) + yield [Uninferable] + + @classmethod + def _infer_values_from_joined_string(cls, node: NodeNG) -> Iterator[Iterable[NodeNG]]: + assert isinstance(node, JoinedStr) + yield from cls._infer_values_from_joined_values(node.values) + + @classmethod + def _infer_values_from_joined_values(cls, nodes: list[NodeNG]) -> Iterator[Iterable[NodeNG]]: + if len(nodes) == 1: + yield from _LocalInferredValue.do_infer_values(nodes[0]) + return + for firsts in _LocalInferredValue.do_infer_values(nodes[0]): + for remains in cls._infer_values_from_joined_values(nodes[1:]): + yield list(firsts) + list(remains) + + def __init__(self, atoms: Iterable[NodeNG]): + self._atoms = list(atoms) + + @property + def nodes(self): + return self._atoms + + def is_inferred(self): + return all(atom is not Uninferable for atom in self._atoms) + + def as_string(self): + strings = [str(const.value) for const in filter(lambda atom: isinstance(atom, Const), self._atoms)] + return "".join(strings) + + +class _DbUtilsWidgetsGetCall(NodeNG): + + def __init__(self, session_state: CurrentSessionState, node: NodeNG): + super().__init__( + lineno=node.lineno, + col_offset=node.col_offset, + end_lineno=node.end_lineno, + end_col_offset=node.end_col_offset, + parent=node.parent, + ) + self._session_state = session_state + + @decorators.raise_if_nothing_inferred + def _infer( + self, context: InferenceContext | None = None, **kwargs: Any + ) -> Generator[InferenceResult, None, InferenceErrorInfo | None]: + yield self + return InferenceErrorInfo(node=self, context=context) + + def infer_call_result(self, context: InferenceContext | None = None, **_): # caller needs unused kwargs + call_context = getattr(context, "callcontext", None) + if not isinstance(call_context, CallContext): + yield Uninferable + return + arg = call_context.args[0] + for inferred in InferredValue.infer_from_node(arg, self._session_state): + if not inferred.is_inferred(): + yield Uninferable + continue + name = inferred.as_string() + named_parameters = self._session_state.named_parameters + if not named_parameters or name not in named_parameters: + yield Uninferable + continue + value = named_parameters[name] + yield Const( + value, + lineno=self.lineno, + col_offset=self.col_offset, + end_lineno=self.end_lineno, + end_col_offset=self.end_col_offset, + parent=self, + ) + + +class _LocalInferredValue(InferredValue): + + @classmethod + def do_infer_values(cls, node: NodeNG): + yield from cls._infer_values(node) + + +class _DbUtilsWidgetsGetAllCall(NodeNG): + + def __init__(self, session_state: CurrentSessionState, node: NodeNG): + super().__init__( + lineno=node.lineno, + col_offset=node.col_offset, + end_lineno=node.end_lineno, + end_col_offset=node.end_col_offset, + parent=node.parent, + ) + self._session_state = session_state + + @decorators.raise_if_nothing_inferred + def _infer( + self, context: InferenceContext | None = None, **kwargs: Any + ) -> Generator[InferenceResult, None, InferenceErrorInfo | None]: + yield self + return InferenceErrorInfo(node=self, context=context) + + def infer_call_result(self, **_): # caller needs unused kwargs + named_parameters = self._session_state.named_parameters + if not named_parameters: + yield Uninferable + return + items = self._populate_items(named_parameters) + result = Dict( + lineno=self.lineno, + col_offset=self.col_offset, + end_lineno=self.end_lineno, + end_col_offset=self.end_col_offset, + parent=self, + ) + result.postinit(items) + yield result + + def _populate_items(self, values: dict[str, str]): + items: list[tuple[InferenceResult, InferenceResult]] = [] + for key, value in values.items(): + item_key = Const( + key, + lineno=self.lineno, + col_offset=self.col_offset, + end_lineno=self.end_lineno, + end_col_offset=self.end_col_offset, + parent=self, + ) + item_value = Const( + value, + lineno=self.lineno, + col_offset=self.col_offset, + end_lineno=self.end_lineno, + end_col_offset=self.end_col_offset, + parent=self, + ) + items.append((item_key, item_value)) + return items diff --git a/src/databricks/labs/ucx/source_code/linters/spark_connect.py b/src/databricks/labs/ucx/source_code/linters/spark_connect.py index 6b3b2f0d28..962057d2c8 100644 --- a/src/databricks/labs/ucx/source_code/linters/spark_connect.py +++ b/src/databricks/labs/ucx/source_code/linters/spark_connect.py @@ -6,7 +6,7 @@ from databricks.labs.ucx.source_code.base import ( Advice, Failure, - Linter, + PythonLinter, ) from databricks.labs.ucx.source_code.linters.python_ast import Tree @@ -172,7 +172,7 @@ def _match_jvm_log(self, node: NodeNG) -> Iterator[Advice]: ) -class SparkConnectLinter(Linter): +class SparkConnectLinter(PythonLinter): def __init__(self, is_serverless: bool = False): self._matchers = [ JvmAccessMatcher(is_serverless=is_serverless), @@ -181,7 +181,6 @@ def __init__(self, is_serverless: bool = False): LoggingMatcher(is_serverless=is_serverless), ] - def lint(self, code: str) -> Iterator[Advice]: - tree = Tree.normalize_and_parse(code) + def lint_tree(self, tree: Tree) -> Iterator[Advice]: for matcher in self._matchers: yield from matcher.lint_tree(tree.node) diff --git a/src/databricks/labs/ucx/source_code/linters/table_creation.py b/src/databricks/labs/ucx/source_code/linters/table_creation.py index 5d2b27044a..147c72e30c 100644 --- a/src/databricks/labs/ucx/source_code/linters/table_creation.py +++ b/src/databricks/labs/ucx/source_code/linters/table_creation.py @@ -7,7 +7,7 @@ from databricks.labs.ucx.source_code.base import ( Advice, - Linter, + PythonLinter, ) from databricks.labs.ucx.source_code.linters.python_ast import Tree @@ -92,7 +92,7 @@ def lint(self, node: NodeNG) -> Iterator[Advice]: ) -class DBRv8d0Linter(Linter): +class DBRv8d0Linter(PythonLinter): """Performs Python linting for backwards incompatible changes in DBR version 8.0. Specifically, it yields advice for table-creation with implicit format. """ @@ -111,9 +111,8 @@ def __init__(self, dbr_version: tuple[int, int] | None): ] ) - def lint(self, code: str) -> Iterable[Advice]: + def lint_tree(self, tree: Tree) -> Iterable[Advice]: if self._skip_dbr: return - tree = Tree.normalize_and_parse(code) for node in tree.walk(): yield from self._linter.lint(node) diff --git a/src/databricks/labs/ucx/source_code/notebooks/sources.py b/src/databricks/labs/ucx/source_code/notebooks/sources.py index 42177876af..1da4afaf73 100644 --- a/src/databricks/labs/ucx/source_code/notebooks/sources.py +++ b/src/databricks/labs/ucx/source_code/notebooks/sources.py @@ -9,7 +9,7 @@ from databricks.sdk.service.workspace import Language from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex -from databricks.labs.ucx.source_code.base import Advice, Failure +from databricks.labs.ucx.source_code.base import Advice, Failure, Linter from databricks.labs.ucx.source_code.graph import SourceContainer, DependencyGraph, DependencyProblem from databricks.labs.ucx.source_code.linters.context import LinterContext @@ -87,6 +87,8 @@ class NotebookLinter: def __init__(self, langs: LinterContext, notebook: Notebook): self._languages: LinterContext = langs self._notebook: Notebook = notebook + # reuse Python linter, which accumulates statements for improved inference + self._python_linter = langs.linter(Language.PYTHON) @classmethod def from_source(cls, index: MigrationIndex, source: str, default_language: Language) -> 'NotebookLinter': @@ -99,13 +101,18 @@ def lint(self) -> Iterable[Advice]: for cell in self._notebook.cells: if not self._languages.is_supported(cell.language.language): continue - linter = self._languages.linter(cell.language.language) + linter = self._linter(cell.language.language) for advice in linter.lint(cell.original_code): yield advice.replace( start_line=advice.start_line + cell.original_offset, end_line=advice.end_line + cell.original_offset, ) + def _linter(self, language: Language) -> Linter: + if language is Language.PYTHON: + return self._python_linter + return self._languages.linter(language) + @staticmethod def name() -> str: return "notebook-linter" diff --git a/tests/unit/source_code/linters/test_python_ast.py b/tests/unit/source_code/linters/test_python_ast.py index 51e0f7ed78..47569c7321 100644 --- a/tests/unit/source_code/linters/test_python_ast.py +++ b/tests/unit/source_code/linters/test_python_ast.py @@ -1,8 +1,8 @@ import pytest from astroid import Assign, AstroidSyntaxError, Attribute, Call, Const, Expr # type: ignore -from databricks.labs.ucx.source_code.base import CurrentSessionState from databricks.labs.ucx.source_code.linters.python_ast import Tree +from databricks.labs.ucx.source_code.linters.python_infer import InferredValue def test_extracts_root(): @@ -99,163 +99,6 @@ def test_tree_walks_nodes_once(): assert len(nodes) == count -def test_infers_empty_list(): - tree = Tree.parse("a=[]") - nodes = tree.locate(Assign, []) - tree = Tree(nodes[0].value) - values = list(tree.infer_values()) - assert not values - - -def test_infers_empty_tuple(): - tree = Tree.parse("a=tuple()") - nodes = tree.locate(Assign, []) - tree = Tree(nodes[0].value) - values = list(tree.infer_values()) - assert not values - - -def test_infers_empty_set(): - tree = Tree.parse("a={}") - nodes = tree.locate(Assign, []) - tree = Tree(nodes[0].value) - values = list(tree.infer_values()) - assert not values - - -def test_infers_fstring_value(): - source = """ -value = "abc" -fstring = f"Hello {value}!" -""" - tree = Tree.parse(source) - nodes = tree.locate(Assign, []) - tree = Tree(nodes[1].value) # value of fstring = ... - values = list(tree.infer_values()) - assert all(value.is_inferred() for value in values) - strings = list(value.as_string() for value in values) - assert strings == ["Hello abc!"] - - -def test_infers_string_format_value(): - source = """ -value = "abc" -fstring = "Hello {0}!".format(value) -""" - tree = Tree.parse(source) - nodes = tree.locate(Assign, []) - tree = Tree(nodes[1].value) # value of fstring = ... - values = list(tree.infer_values()) - assert all(value.is_inferred() for value in values) - strings = list(value.as_string() for value in values) - assert strings == ["Hello abc!"] - - -def test_infers_fstring_values(): - source = """ -values_1 = ["abc", "def"] -for value1 in values_1: - values_2 = ["ghi", "jkl"] - for value2 in values_2: - fstring = f"Hello {value1}, {value2}!" -""" - tree = Tree.parse(source) - nodes = tree.locate(Assign, []) - tree = Tree(nodes[2].value) # value of fstring = ... - values = list(tree.infer_values()) - assert all(value.is_inferred() for value in values) - strings = list(value.as_string() for value in values) - assert strings == ["Hello abc, ghi!", "Hello abc, jkl!", "Hello def, ghi!", "Hello def, jkl!"] - - -def test_fails_to_infer_cascading_fstring_values(): - # The purpose of this test is to detect a change in astroid support for f-strings - source = """ -value1 = "John" -value2 = f"Hello {value1}" -value3 = f"{value2}, how are you today?" -""" - tree = Tree.parse(source) - nodes = tree.locate(Assign, []) - tree = Tree(nodes[2].value) # value of value3 = ... - values = list(tree.infer_values()) - # for now, we simply check failure to infer! - assert any(not value.is_inferred() for value in values) - # the expected value would be ["Hello John, how are you today?"] - - -def test_infers_externally_defined_value(): - state = CurrentSessionState() - state.named_parameters = {"my-widget": "my-value"} - source = """ -name = "my-widget" -value = dbutils.widgets.get(name) -""" - tree = Tree.parse(source) - nodes = tree.locate(Assign, []) - tree = Tree(nodes[1].value) # value of value = ... - values = list(tree.infer_values(state)) - strings = list(value.as_string() for value in values) - assert strings == ["my-value"] - - -def test_infers_externally_defined_values(): - state = CurrentSessionState() - state.named_parameters = {"my-widget-1": "my-value-1", "my-widget-2": "my-value-2"} - source = """ -for name in ["my-widget-1", "my-widget-2"]: - value = dbutils.widgets.get(name) -""" - tree = Tree.parse(source) - nodes = tree.locate(Assign, []) - tree = Tree(nodes[0].value) # value of value = ... - values = list(tree.infer_values(state)) - strings = list(value.as_string() for value in values) - assert strings == ["my-value-1", "my-value-2"] - - -def test_fails_to_infer_missing_externally_defined_value(): - state = CurrentSessionState() - state.named_parameters = {"my-widget-1": "my-value-1", "my-widget-2": "my-value-2"} - source = """ -name = "my-widget" -value = dbutils.widgets.get(name) -""" - tree = Tree.parse(source) - nodes = tree.locate(Assign, []) - tree = Tree(nodes[1].value) # value of value = ... - values = tree.infer_values(state) - assert all(not value.is_inferred() for value in values) - - -def test_survives_absence_of_externally_defined_values(): - source = """ - name = "my-widget" - value = dbutils.widgets.get(name) - """ - tree = Tree.parse(source) - nodes = tree.locate(Assign, []) - tree = Tree(nodes[1].value) # value of value = ... - values = tree.infer_values(CurrentSessionState()) - assert all(not value.is_inferred() for value in values) - - -def test_infers_externally_defined_value_set(): - state = CurrentSessionState() - state.named_parameters = {"my-widget": "my-value"} - source = """ -values = dbutils.widgets.getAll() -name = "my-widget" -value = values[name] -""" - tree = Tree.parse(source) - nodes = tree.locate(Assign, []) - tree = Tree(nodes[2].value) # value of value = ... - values = list(tree.infer_values(state)) - strings = list(value.as_string() for value in values) - assert strings == ["my-value"] - - def test_parses_incorrectly_indented_code(): source = """# DBTITLE 1,Get Sales Data for Analysis sales = ( @@ -291,3 +134,16 @@ def test_ignores_magic_marker_in_multiline_comment(): """ Tree.normalize_and_parse(source) assert True + + +def test_appends_statements(): + source_1 = "a = 'John'" + tree_1 = Tree.normalize_and_parse(source_1) + source_2 = 'b = f"Hello {a}!"' + tree_2 = Tree.normalize_and_parse(source_2) + tree_3 = tree_1.append_statements(tree_2) + nodes = tree_3.locate(Assign, []) + tree = Tree(nodes[0].value) # tree_3 only contains tree_2 statements + values = list(InferredValue.infer_from_node(tree.node)) + strings = list(value.as_string() for value in values) + assert strings == ["Hello John!"] diff --git a/tests/unit/source_code/linters/test_python_infer.py b/tests/unit/source_code/linters/test_python_infer.py new file mode 100644 index 0000000000..f838cb7154 --- /dev/null +++ b/tests/unit/source_code/linters/test_python_infer.py @@ -0,0 +1,176 @@ +from astroid import Assign # type: ignore + +from databricks.labs.ucx.source_code.base import CurrentSessionState +from databricks.labs.ucx.source_code.linters.python_ast import Tree +from databricks.labs.ucx.source_code.linters.python_infer import InferredValue + + +def test_infers_empty_list(): + tree = Tree.parse("a=[]") + nodes = tree.locate(Assign, []) + tree = Tree(nodes[0].value) + values = list(InferredValue.infer_from_node(tree.node)) + assert not values + + +def test_infers_empty_tuple(): + tree = Tree.parse("a=tuple()") + nodes = tree.locate(Assign, []) + tree = Tree(nodes[0].value) + values = list(InferredValue.infer_from_node(tree.node)) + assert not values + + +def test_infers_empty_set(): + tree = Tree.parse("a={}") + nodes = tree.locate(Assign, []) + tree = Tree(nodes[0].value) + values = list(InferredValue.infer_from_node(tree.node)) + assert not values + + +def test_infers_fstring_value(): + source = """ +value = "abc" +fstring = f"Hello {value}!" +""" + tree = Tree.parse(source) + nodes = tree.locate(Assign, []) + tree = Tree(nodes[1].value) # value of fstring = ... + values = list(InferredValue.infer_from_node(tree.node)) + assert all(value.is_inferred() for value in values) + strings = list(value.as_string() for value in values) + assert strings == ["Hello abc!"] + + +def test_infers_fstring_dict_value(): + source = """ +value = { "abc": 123 } +fstring = f"Hello {value['abc']}!" +""" + tree = Tree.parse(source) + nodes = tree.locate(Assign, []) + tree = Tree(nodes[1].value) # value of fstring = ... + values = list(InferredValue.infer_from_node(tree.node)) + assert all(value.is_inferred() for value in values) + strings = list(value.as_string() for value in values) + assert strings == ["Hello 123!"] + + +def test_infers_string_format_value(): + source = """ +value = "abc" +fstring = "Hello {0}!".format(value) +""" + tree = Tree.parse(source) + nodes = tree.locate(Assign, []) + tree = Tree(nodes[1].value) # value of fstring = ... + values = list(InferredValue.infer_from_node(tree.node)) + assert all(value.is_inferred() for value in values) + strings = list(value.as_string() for value in values) + assert strings == ["Hello abc!"] + + +def test_infers_fstring_values(): + source = """ +values_1 = ["abc", "def"] +for value1 in values_1: + values_2 = ["ghi", "jkl"] + for value2 in values_2: + fstring = f"Hello {value1}, {value2}!" +""" + tree = Tree.parse(source) + nodes = tree.locate(Assign, []) + tree = Tree(nodes[2].value) # value of fstring = ... + values = list(InferredValue.infer_from_node(tree.node)) + assert all(value.is_inferred() for value in values) + strings = list(value.as_string() for value in values) + assert strings == ["Hello abc, ghi!", "Hello abc, jkl!", "Hello def, ghi!", "Hello def, jkl!"] + + +def test_fails_to_infer_cascading_fstring_values(): + # The purpose of this test is to detect a change in astroid support for f-strings + source = """ +value1 = "John" +value2 = f"Hello {value1}" +value3 = f"{value2}, how are you today?" +""" + tree = Tree.parse(source) + nodes = tree.locate(Assign, []) + tree = Tree(nodes[2].value) # value of value3 = ... + values = list(InferredValue.infer_from_node(tree.node)) + # for now, we simply check failure to infer! + assert any(not value.is_inferred() for value in values) + # the expected value would be ["Hello John, how are you today?"] + + +def test_infers_externally_defined_value(): + state = CurrentSessionState() + state.named_parameters = {"my-widget": "my-value"} + source = """ +name = "my-widget" +value = dbutils.widgets.get(name) +""" + tree = Tree.parse(source) + nodes = tree.locate(Assign, []) + tree = Tree(nodes[1].value) # value of value = ... + values = list(InferredValue.infer_from_node(tree.node, state)) + strings = list(value.as_string() for value in values) + assert strings == ["my-value"] + + +def test_infers_externally_defined_values(): + state = CurrentSessionState() + state.named_parameters = {"my-widget-1": "my-value-1", "my-widget-2": "my-value-2"} + source = """ +for name in ["my-widget-1", "my-widget-2"]: + value = dbutils.widgets.get(name) +""" + tree = Tree.parse(source) + nodes = tree.locate(Assign, []) + tree = Tree(nodes[0].value) # value of value = ... + values = list(InferredValue.infer_from_node(tree.node, state)) + strings = list(value.as_string() for value in values) + assert strings == ["my-value-1", "my-value-2"] + + +def test_fails_to_infer_missing_externally_defined_value(): + state = CurrentSessionState() + state.named_parameters = {"my-widget-1": "my-value-1", "my-widget-2": "my-value-2"} + source = """ +name = "my-widget" +value = dbutils.widgets.get(name) +""" + tree = Tree.parse(source) + nodes = tree.locate(Assign, []) + tree = Tree(nodes[1].value) # value of value = ... + values = InferredValue.infer_from_node(tree.node, state) + assert all(not value.is_inferred() for value in values) + + +def test_survives_absence_of_externally_defined_values(): + source = """ + name = "my-widget" + value = dbutils.widgets.get(name) + """ + tree = Tree.parse(source) + nodes = tree.locate(Assign, []) + tree = Tree(nodes[1].value) # value of value = ... + values = InferredValue.infer_from_node(tree.node, CurrentSessionState()) + assert all(not value.is_inferred() for value in values) + + +def test_infers_externally_defined_value_set(): + state = CurrentSessionState() + state.named_parameters = {"my-widget": "my-value"} + source = """ +values = dbutils.widgets.getAll() +name = "my-widget" +value = values[name] +""" + tree = Tree.parse(source) + nodes = tree.locate(Assign, []) + tree = Tree(nodes[2].value) # value of value = ... + values = list(InferredValue.infer_from_node(tree.node, state)) + strings = list(value.as_string() for value in values) + assert strings == ["my-value"] diff --git a/tests/unit/source_code/samples/values_across_cells.py b/tests/unit/source_code/samples/values_across_cells.py new file mode 100644 index 0000000000..fe7366eeb5 --- /dev/null +++ b/tests/unit/source_code/samples/values_across_cells.py @@ -0,0 +1,6 @@ +# Databricks notebook source +a = 12 + +# COMMAND ---------- + +spark.table(f"{a}") diff --git a/tests/unit/source_code/test_notebook_linter.py b/tests/unit/source_code/test_notebook_linter.py index d507f03b38..18b4956cc6 100644 --- a/tests/unit/source_code/test_notebook_linter.py +++ b/tests/unit/source_code/test_notebook_linter.py @@ -1,3 +1,5 @@ +from pathlib import Path + import pytest from databricks.sdk.service.workspace import Language @@ -547,3 +549,21 @@ def test_notebook_linter_tracks_use(extended_test_index, lang, source, expected) assert linter is not None advices = list(linter.lint()) assert advices == expected + + +def test_computes_values_across_cells(extended_test_index, mock_path_lookup): + path = mock_path_lookup.resolve(Path("values_across_cells.py")) + source = path.read_text() + linter = NotebookLinter.from_source(extended_test_index, source, Language.PYTHON) + advices = list(linter.lint()) + expected = [ + Advice( + code='table-migrate', + message='The default format changed in Databricks Runtime 8.0, from Parquet to Delta', + start_line=5, + start_col=0, + end_line=5, + end_col=19, + ) + ] + assert advices == expected