From 74095ebefc274889a062f51830491649dd2bb743 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Wed, 12 Jun 2024 11:15:33 +0200 Subject: [PATCH 01/23] rename tree.root to tree.node --- .../labs/ucx/source_code/linters/dbfs.py | 2 +- .../labs/ucx/source_code/linters/imports.py | 4 +- .../labs/ucx/source_code/linters/pyspark.py | 2 +- .../ucx/source_code/linters/python_ast.py | 43 +++++++++++-------- .../ucx/source_code/linters/spark_connect.py | 2 +- .../source_code/linters/test_python_ast.py | 11 +++-- 6 files changed, 36 insertions(+), 28 deletions(-) diff --git a/src/databricks/labs/ucx/source_code/linters/dbfs.py b/src/databricks/labs/ucx/source_code/linters/dbfs.py index e34ce3c189..4348067603 100644 --- a/src/databricks/labs/ucx/source_code/linters/dbfs.py +++ b/src/databricks/labs/ucx/source_code/linters/dbfs.py @@ -80,7 +80,7 @@ def lint(self, code: str) -> Iterable[Advice]: """ tree = Tree.parse(code) visitor = DetectDbfsVisitor() - visitor.visit(tree.root) + visitor.visit(tree.node) yield from visitor.get_advices() diff --git a/src/databricks/labs/ucx/source_code/linters/imports.py b/src/databricks/labs/ucx/source_code/linters/imports.py index a9a7e60c90..20d52d2f65 100644 --- a/src/databricks/labs/ucx/source_code/linters/imports.py +++ b/src/databricks/labs/ucx/source_code/linters/imports.py @@ -44,7 +44,7 @@ def extract_from_tree(cls, tree: Tree, problem_factory: ProblemFactory) -> tuple sources.append(source) return sources, problems except Exception as e: # pylint: disable=broad-except - problem = problem_factory('internal-error', f"While checking imports: {e}", tree.root) + problem = problem_factory('internal-error', f"While checking imports: {e}", tree.node) problems.append(problem) return [], problems @@ -148,7 +148,7 @@ class SysPathChange(NodeBase, abc.ABC): @staticmethod def extract_from_tree(tree: Tree) -> list[SysPathChange]: visitor = SysPathChangesVisitor() - visitor.visit(tree.root) + visitor.visit(tree.node) return visitor.sys_path_changes 
def __init__(self, node: NodeNG, path: str, is_append: bool): diff --git a/src/databricks/labs/ucx/source_code/linters/pyspark.py b/src/databricks/labs/ucx/source_code/linters/pyspark.py index 5820d8ffb3..0f7143ce00 100644 --- a/src/databricks/labs/ucx/source_code/linters/pyspark.py +++ b/src/databricks/labs/ucx/source_code/linters/pyspark.py @@ -339,7 +339,7 @@ def apply(self, code: str) -> str: continue assert isinstance(node, Call) matcher.apply(self._from_table, self._index, node) - return tree.root.as_string() + return tree.node.as_string() def _find_matcher(self, node: NodeNG): if not isinstance(node, Call): diff --git a/src/databricks/labs/ucx/source_code/linters/python_ast.py b/src/databricks/labs/ucx/source_code/linters/python_ast.py index 5bffe50acc..64f9a6f8ba 100644 --- a/src/databricks/labs/ucx/source_code/linters/python_ast.py +++ b/src/databricks/labs/ucx/source_code/linters/python_ast.py @@ -22,15 +22,15 @@ def parse(code: str): root = parse(code) return Tree(root) - def __init__(self, root: NodeNG): - self._root: NodeNG = root + def __init__(self, node: NodeNG): + self._node: NodeNG = node @property - def root(self): - return self._root + def node(self): + return self._node def walk(self) -> Iterable[NodeNG]: - yield from self._walk(self._root) + yield from self._walk(self._node) @classmethod def _walk(cls, node: NodeNG) -> Iterable[NodeNG]: @@ -40,12 +40,12 @@ def _walk(cls, node: NodeNG) -> Iterable[NodeNG]: def locate(self, node_type: type[T], match_nodes: list[tuple[str, type]]) -> list[T]: visitor = MatchingVisitor(node_type, match_nodes) - visitor.visit(self._root) + visitor.visit(self._node) return visitor.matched_nodes def first_statement(self): - if isinstance(self._root, Module): - return self._root.body[0] + if isinstance(self._node, Module): + return self._node.body[0] return None @classmethod @@ -95,7 +95,7 @@ def is_none(cls, node: NodeNG) -> bool: def __repr__(self): truncate_after = 32 - code = repr(self._root) + code = 
repr(self._node) if len(code) > truncate_after: code = code[0:truncate_after] + "..." return f"" @@ -130,28 +130,33 @@ def _get_attribute_value(cls, node: Attribute): logger.debug(f"Missing handler for {name}") return None - def infer_values(self) -> Iterable[InferredValue]: + def infer_values(self): # , ctx: RuntimeContext | None = None) -> Iterable[InferredValue]: + # if ctx is not None: + # self.contextualize(ctx) for inferred_atoms in self._infer_values(): yield InferredValue(inferred_atoms) + # def contextualize(self, ctx: RuntimeContext): + # pass # parent = self.parent + def _infer_values(self) -> Iterator[Iterable[NodeNG]]: # deal with node types that don't implement 'inferred()' - if self._root is Uninferable or isinstance(self._root, Const): - yield [self._root] - elif isinstance(self._root, JoinedStr): + if self._node is Uninferable or isinstance(self._node, Const): + yield [self._node] + elif isinstance(self._node, JoinedStr): yield from self._infer_values_from_joined_string() - elif isinstance(self._root, FormattedValue): - yield from _LocalTree(self._root.value).do_infer_values() + elif isinstance(self._node, FormattedValue): + yield from _LocalTree(self._node.value).do_infer_values() else: - for inferred in self._root.inferred(): + for inferred in self._node.inferred(): # work around infinite recursion of empty lists - if inferred == self._root: + if inferred == self._node: continue yield from _LocalTree(inferred).do_infer_values() def _infer_values_from_joined_string(self) -> Iterator[Iterable[NodeNG]]: - assert isinstance(self._root, JoinedStr) - yield from self._infer_values_from_joined_values(self._root.values) + assert isinstance(self._node, JoinedStr) + yield from self._infer_values_from_joined_values(self._node.values) @classmethod def _infer_values_from_joined_values(cls, nodes: list[NodeNG]) -> Iterator[Iterable[NodeNG]]: diff --git a/src/databricks/labs/ucx/source_code/linters/spark_connect.py 
b/src/databricks/labs/ucx/source_code/linters/spark_connect.py index 064e72d8b3..35f5230a7a 100644 --- a/src/databricks/labs/ucx/source_code/linters/spark_connect.py +++ b/src/databricks/labs/ucx/source_code/linters/spark_connect.py @@ -184,4 +184,4 @@ def __init__(self, is_serverless: bool = False): def lint(self, code: str) -> Iterator[Advice]: tree = Tree.parse(code) for matcher in self._matchers: - yield from matcher.lint_tree(tree.root) + yield from matcher.lint_tree(tree.node) diff --git a/tests/unit/source_code/linters/test_python_ast.py b/tests/unit/source_code/linters/test_python_ast.py index 71c0c80209..841456035e 100644 --- a/tests/unit/source_code/linters/test_python_ast.py +++ b/tests/unit/source_code/linters/test_python_ast.py @@ -1,6 +1,7 @@ import pytest from astroid import Assign, Attribute, Call, Const, Expr # type: ignore +from databricks.labs.ucx.contexts.application import GlobalContext from databricks.labs.ucx.source_code.linters.python_ast import Tree @@ -143,10 +144,10 @@ def test_infers_fstring_values(): def test_fails_to_infer_cascading_fstring_values(): # The purpose of this test s to detect a change in astroid support for f-strings source = """ - value1 = "John" - value2 = f"Hello {value1}" - value3 = f"{value2}, how are you today?" - """ +value1 = "John" +value2 = f"Hello {value1}" +value3 = f"{value2}, how are you today?" +""" tree = Tree.parse(source) nodes = tree.locate(Assign, []) tree = Tree(nodes[2].value) # value of value3 = ... @@ -154,3 +155,5 @@ def test_fails_to_infer_cascading_fstring_values(): # for now, we simply check failure to infer! 
assert any(not value.is_inferred() for value in values) # the expected value would be ["Hello John, how are you today?"] + + From 62cb196fc999e58f65ecaa1891758db3b354df03 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Wed, 12 Jun 2024 11:36:57 +0200 Subject: [PATCH 02/23] add root property to tree --- src/databricks/labs/ucx/source_code/linters/python_ast.py | 7 +++++++ tests/unit/source_code/linters/test_python_ast.py | 8 ++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/databricks/labs/ucx/source_code/linters/python_ast.py b/src/databricks/labs/ucx/source_code/linters/python_ast.py index 64f9a6f8ba..f64ecbf591 100644 --- a/src/databricks/labs/ucx/source_code/linters/python_ast.py +++ b/src/databricks/labs/ucx/source_code/linters/python_ast.py @@ -29,6 +29,13 @@ def __init__(self, node: NodeNG): def node(self): return self._node + @property + def root(self): + node = self._node + while node.parent: + node = node.parent + return node + def walk(self) -> Iterable[NodeNG]: yield from self._walk(self._node) diff --git a/tests/unit/source_code/linters/test_python_ast.py b/tests/unit/source_code/linters/test_python_ast.py index 841456035e..004feeeb9d 100644 --- a/tests/unit/source_code/linters/test_python_ast.py +++ b/tests/unit/source_code/linters/test_python_ast.py @@ -1,10 +1,15 @@ import pytest from astroid import Assign, Attribute, Call, Const, Expr # type: ignore -from databricks.labs.ucx.contexts.application import GlobalContext from databricks.labs.ucx.source_code.linters.python_ast import Tree +def test_extracts_root(): + tree = Tree.parse("o.m1().m2().m3()") + stmt = tree.first_statement() + root = Tree(stmt).root + assert root == tree.node + def test_extract_call_by_name(): tree = Tree.parse("o.m1().m2().m3()") stmt = tree.first_statement() @@ -156,4 +161,3 @@ def test_fails_to_infer_cascading_fstring_values(): assert any(not value.is_inferred() for value in values) # the expected value would be ["Hello John, how are you today?"] 
- From 36d6dbb38af2bf04eeea15eb49ad098f9cf4dafd Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Wed, 12 Jun 2024 15:31:32 +0200 Subject: [PATCH 03/23] successful unit tests --- .../labs/ucx/contexts/application.py | 17 +---- src/databricks/labs/ucx/contexts/base.py | 18 +++++ .../ucx/source_code/linters/python_ast.py | 70 ++++++++++++++++--- .../source_code/linters/test_python_ast.py | 46 +++++++++++- 4 files changed, 125 insertions(+), 26 deletions(-) create mode 100644 src/databricks/labs/ucx/contexts/base.py diff --git a/src/databricks/labs/ucx/contexts/application.py b/src/databricks/labs/ucx/contexts/application.py index d21c096a72..bbf585dcb5 100644 --- a/src/databricks/labs/ucx/contexts/application.py +++ b/src/databricks/labs/ucx/contexts/application.py @@ -10,6 +10,7 @@ from databricks.labs.blueprint.wheels import ProductInfo, WheelsV2 from databricks.labs.lsql.backends import SqlBackend +from databricks.labs.ucx.contexts.base import BaseContext from databricks.labs.ucx.recon.data_comparator import StandardDataComparator from databricks.labs.ucx.recon.data_profiler import StandardDataProfiler from databricks.labs.ucx.recon.metadata_retriever import DatabricksTableMetadataRetriever @@ -69,17 +70,7 @@ logger = logging.getLogger(__name__) -class GlobalContext(abc.ABC): - def __init__(self, named_parameters: dict[str, str] | None = None): - if not named_parameters: - named_parameters = {} - self._named_parameters = named_parameters - - def replace(self, **kwargs): - """Replace cached properties for unit testing purposes.""" - for key, value in kwargs.items(): - self.__dict__[key] = value - return self +class GlobalContext(BaseContext, abc.ABC): @cached_property def workspace_client(self) -> WorkspaceClient: @@ -93,10 +84,6 @@ def sql_backend(self) -> SqlBackend: def account_client(self) -> AccountClient: raise ValueError("Account client not set") - @cached_property - def named_parameters(self) -> dict[str, str]: - return self._named_parameters - 
@cached_property def product_info(self): return ProductInfo.from_class(WorkspaceConfig) diff --git a/src/databricks/labs/ucx/contexts/base.py b/src/databricks/labs/ucx/contexts/base.py new file mode 100644 index 0000000000..04dc3e34b9 --- /dev/null +++ b/src/databricks/labs/ucx/contexts/base.py @@ -0,0 +1,18 @@ +from functools import cached_property + + +class BaseContext: + def __init__(self, named_parameters: dict[str, str] | None = None): + if not named_parameters: + named_parameters = {} + self._named_parameters = named_parameters + + def replace(self, **kwargs): + """Replace cached properties for unit testing purposes.""" + for key, value in kwargs.items(): + self.__dict__[key] = value + return self + + @cached_property + def named_parameters(self) -> dict[str, str]: + return self._named_parameters diff --git a/src/databricks/labs/ucx/source_code/linters/python_ast.py b/src/databricks/labs/ucx/source_code/linters/python_ast.py index f64ecbf591..9f43af496f 100644 --- a/src/databricks/labs/ucx/source_code/linters/python_ast.py +++ b/src/databricks/labs/ucx/source_code/linters/python_ast.py @@ -1,11 +1,15 @@ from __future__ import annotations -import abc +from abc import ABC import logging -from collections.abc import Iterable, Iterator -from typing import TypeVar +from collections.abc import Iterable, Iterator, Generator +from typing import Any, TypeVar -from astroid import Assign, Attribute, Call, Const, FormattedValue, Import, ImportFrom, JoinedStr, Module, Name, NodeNG, parse, Uninferable # type: ignore +from astroid import Assign, Attribute, Call, Const, decorators, FormattedValue, Import, ImportFrom, JoinedStr, Module, Name, NodeNG, parse, Uninferable # type: ignore +from astroid.context import InferenceContext, InferenceResult, CallContext # type: ignore +from astroid.typing import InferenceErrorInfo # type: ignore + +from databricks.labs.ucx.contexts.base import BaseContext logger = logging.getLogger(__name__) @@ -137,14 +141,16 @@ def 
_get_attribute_value(cls, node: Attribute): logger.debug(f"Missing handler for {name}") return None - def infer_values(self): # , ctx: RuntimeContext | None = None) -> Iterable[InferredValue]: - # if ctx is not None: - # self.contextualize(ctx) + def infer_values(self, ctx: BaseContext | None = None) -> Iterable[InferredValue]: + if ctx is not None: + self.contextualize(ctx) for inferred_atoms in self._infer_values(): yield InferredValue(inferred_atoms) - # def contextualize(self, ctx: RuntimeContext): - # pass # parent = self.parent + def contextualize(self, ctx: BaseContext): + calls = self.locate(Call, [("get", Attribute), ("widgets", Attribute), ("dbutils", Name)]) + for call in calls: + call.func = _ContextualCall(ctx, call) def _infer_values(self) -> Iterator[Iterable[NodeNG]]: # deal with node types that don't implement 'inferred()' @@ -182,6 +188,50 @@ def do_infer_values(self): return self._infer_values() +class _ContextualCall(NodeNG): + + def __init__(self, ctx: BaseContext, node: NodeNG): + super().__init__( + lineno=node.lineno, + col_offset=node.col_offset, + end_lineno=node.end_lineno, + end_col_offset=node.end_col_offset, + parent=node.parent, + ) + self._ctx = ctx + + @decorators.raise_if_nothing_inferred + def _infer( + self, context: InferenceContext | None = None, **kwargs: Any + ) -> Generator[InferenceResult, None, InferenceErrorInfo | None]: + yield self + return InferenceErrorInfo(node=self, context=context) + + def infer_call_result(self, context: InferenceContext | None = None, **_): # caller needs unused kwargs + call_context = getattr(context, "callcontext", None) + if not isinstance(call_context, CallContext): + yield Uninferable + return + arg = call_context.args[0] + for inferred in Tree(arg).infer_values(): + if not inferred.is_inferred(): + yield Uninferable + continue + name = inferred.as_string() + if name not in self._ctx.named_parameters: + yield Uninferable + continue + value = self._ctx.named_parameters[name] + yield Const( + 
value, + lineno=self.lineno, + col_offset=self.col_offset, + end_lineno=self.end_lineno, + end_col_offset=self.end_col_offset, + parent=self, + ) + + class InferredValue: """Represents 1 or more nodes that together represent the value. The list of nodes typically holds one Const element, but for f-strings it @@ -278,7 +328,7 @@ def _matches(self, node: NodeNG, depth: int): return self._matches(next_node, depth + 1) -class NodeBase(abc.ABC): +class NodeBase(ABC): def __init__(self, node: NodeNG): self._node = node diff --git a/tests/unit/source_code/linters/test_python_ast.py b/tests/unit/source_code/linters/test_python_ast.py index 004feeeb9d..374b6b8ddd 100644 --- a/tests/unit/source_code/linters/test_python_ast.py +++ b/tests/unit/source_code/linters/test_python_ast.py @@ -1,6 +1,7 @@ import pytest from astroid import Assign, Attribute, Call, Const, Expr # type: ignore +from databricks.labs.ucx.contexts.application import GlobalContext from databricks.labs.ucx.source_code.linters.python_ast import Tree @@ -10,6 +11,7 @@ def test_extracts_root(): root = Tree(stmt).root assert root == tree.node + def test_extract_call_by_name(): tree = Tree.parse("o.m1().m2().m3()") stmt = tree.first_statement() @@ -147,7 +149,7 @@ def test_infers_fstring_values(): def test_fails_to_infer_cascading_fstring_values(): - # The purpose of this test s to detect a change in astroid support for f-strings + # The purpose of this test is to detect a change in astroid support for f-strings source = """ value1 = "John" value2 = f"Hello {value1}" @@ -161,3 +163,45 @@ def test_fails_to_infer_cascading_fstring_values(): assert any(not value.is_inferred() for value in values) # the expected value would be ["Hello John, how are you today?"] + +def test_infers_externally_defined_value(): + ctx = GlobalContext() + ctx.named_parameters["my-widget"] = "my-value" + source = """ +name = "my-widget" +value = dbutils.widgets.get(name) +""" + tree = Tree.parse(source) + nodes = tree.locate(Assign, []) + 
tree = Tree(nodes[1].value) # value of value = ... + values = list(tree.infer_values(ctx)) + strings = list(value.as_string() for value in values) + assert strings == ["my-value"] + + +def test_infers_externally_defined_values(): + ctx = GlobalContext() + ctx.named_parameters["my-widget-1"] = "my-value-1" + ctx.named_parameters["my-widget-2"] = "my-value-2" + source = """ +for name in ["my-widget-1", "my-widget-2"]: + value = dbutils.widgets.get(name) +""" + tree = Tree.parse(source) + nodes = tree.locate(Assign, []) + tree = Tree(nodes[0].value) # value of value = ... + values = list(tree.infer_values(ctx)) + strings = list(value.as_string() for value in values) + assert strings == ["my-value-1", "my-value-2"] + + +def test_fails_to_infer_missing_externally_defined_value(): + source = """ +name = "my-widget" +value = dbutils.widgets.get(name) +""" + tree = Tree.parse(source) + nodes = tree.locate(Assign, []) + tree = Tree(nodes[1].value) # value of value = ... + values = tree.infer_values(GlobalContext()) + assert all(not value.is_inferred() for value in values) From b3b51409debced627a43ea05d4bd86da47f2352d Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Wed, 12 Jun 2024 15:48:58 +0200 Subject: [PATCH 04/23] use named parameters from sate, not from context --- .../labs/ucx/contexts/application.py | 17 +++++++++++++-- .../ucx/source_code/linters/python_ast.py | 21 ++++++++++--------- .../source_code/linters/test_python_ast.py | 17 +++++++-------- 3 files changed, 34 insertions(+), 21 deletions(-) diff --git a/src/databricks/labs/ucx/contexts/application.py b/src/databricks/labs/ucx/contexts/application.py index bbf585dcb5..d21c096a72 100644 --- a/src/databricks/labs/ucx/contexts/application.py +++ b/src/databricks/labs/ucx/contexts/application.py @@ -10,7 +10,6 @@ from databricks.labs.blueprint.wheels import ProductInfo, WheelsV2 from databricks.labs.lsql.backends import SqlBackend -from databricks.labs.ucx.contexts.base import BaseContext from 
databricks.labs.ucx.recon.data_comparator import StandardDataComparator from databricks.labs.ucx.recon.data_profiler import StandardDataProfiler from databricks.labs.ucx.recon.metadata_retriever import DatabricksTableMetadataRetriever @@ -70,7 +69,17 @@ logger = logging.getLogger(__name__) -class GlobalContext(BaseContext, abc.ABC): +class GlobalContext(abc.ABC): + def __init__(self, named_parameters: dict[str, str] | None = None): + if not named_parameters: + named_parameters = {} + self._named_parameters = named_parameters + + def replace(self, **kwargs): + """Replace cached properties for unit testing purposes.""" + for key, value in kwargs.items(): + self.__dict__[key] = value + return self @cached_property def workspace_client(self) -> WorkspaceClient: @@ -84,6 +93,10 @@ def sql_backend(self) -> SqlBackend: def account_client(self) -> AccountClient: raise ValueError("Account client not set") + @cached_property + def named_parameters(self) -> dict[str, str]: + return self._named_parameters + @cached_property def product_info(self): return ProductInfo.from_class(WorkspaceConfig) diff --git a/src/databricks/labs/ucx/source_code/linters/python_ast.py b/src/databricks/labs/ucx/source_code/linters/python_ast.py index 9f43af496f..d6836411a4 100644 --- a/src/databricks/labs/ucx/source_code/linters/python_ast.py +++ b/src/databricks/labs/ucx/source_code/linters/python_ast.py @@ -9,7 +9,7 @@ from astroid.context import InferenceContext, InferenceResult, CallContext # type: ignore from astroid.typing import InferenceErrorInfo # type: ignore -from databricks.labs.ucx.contexts.base import BaseContext +from databricks.labs.ucx.source_code.base import CurrentSessionState logger = logging.getLogger(__name__) @@ -141,16 +141,16 @@ def _get_attribute_value(cls, node: Attribute): logger.debug(f"Missing handler for {name}") return None - def infer_values(self, ctx: BaseContext | None = None) -> Iterable[InferredValue]: - if ctx is not None: - self.contextualize(ctx) + def 
infer_values(self, state: CurrentSessionState | None = None) -> Iterable[InferredValue]: + if state is not None: + self.contextualize(state) for inferred_atoms in self._infer_values(): yield InferredValue(inferred_atoms) - def contextualize(self, ctx: BaseContext): + def contextualize(self, state: CurrentSessionState): calls = self.locate(Call, [("get", Attribute), ("widgets", Attribute), ("dbutils", Name)]) for call in calls: - call.func = _ContextualCall(ctx, call) + call.func = _ContextualCall(state, call) def _infer_values(self) -> Iterator[Iterable[NodeNG]]: # deal with node types that don't implement 'inferred()' @@ -190,7 +190,7 @@ def do_infer_values(self): class _ContextualCall(NodeNG): - def __init__(self, ctx: BaseContext, node: NodeNG): + def __init__(self, state: CurrentSessionState, node: NodeNG): super().__init__( lineno=node.lineno, col_offset=node.col_offset, @@ -198,7 +198,7 @@ def __init__(self, ctx: BaseContext, node: NodeNG): end_col_offset=node.end_col_offset, parent=node.parent, ) - self._ctx = ctx + self._state = state @decorators.raise_if_nothing_inferred def _infer( @@ -218,10 +218,11 @@ def infer_call_result(self, context: InferenceContext | None = None, **_): # ca yield Uninferable continue name = inferred.as_string() - if name not in self._ctx.named_parameters: + named_parameters = self._state.named_parameters + if not named_parameters or name not in named_parameters: yield Uninferable continue - value = self._ctx.named_parameters[name] + value = named_parameters[name] yield Const( value, lineno=self.lineno, diff --git a/tests/unit/source_code/linters/test_python_ast.py b/tests/unit/source_code/linters/test_python_ast.py index 374b6b8ddd..ab19c1a784 100644 --- a/tests/unit/source_code/linters/test_python_ast.py +++ b/tests/unit/source_code/linters/test_python_ast.py @@ -1,7 +1,7 @@ import pytest from astroid import Assign, Attribute, Call, Const, Expr # type: ignore -from databricks.labs.ucx.contexts.application import GlobalContext 
+from databricks.labs.ucx.source_code.base import CurrentSessionState from databricks.labs.ucx.source_code.linters.python_ast import Tree @@ -165,8 +165,8 @@ def test_fails_to_infer_cascading_fstring_values(): def test_infers_externally_defined_value(): - ctx = GlobalContext() - ctx.named_parameters["my-widget"] = "my-value" + state = CurrentSessionState() + state.named_parameters = {"my-widget": "my-value"} source = """ name = "my-widget" value = dbutils.widgets.get(name) @@ -174,15 +174,14 @@ def test_infers_externally_defined_value(): tree = Tree.parse(source) nodes = tree.locate(Assign, []) tree = Tree(nodes[1].value) # value of value = ... - values = list(tree.infer_values(ctx)) + values = list(tree.infer_values(state)) strings = list(value.as_string() for value in values) assert strings == ["my-value"] def test_infers_externally_defined_values(): - ctx = GlobalContext() - ctx.named_parameters["my-widget-1"] = "my-value-1" - ctx.named_parameters["my-widget-2"] = "my-value-2" + state = CurrentSessionState() + state.named_parameters = {"my-widget-1": "my-value-1", "my-widget-2": "my-value-2"} source = """ for name in ["my-widget-1", "my-widget-2"]: value = dbutils.widgets.get(name) @@ -190,7 +189,7 @@ def test_infers_externally_defined_values(): tree = Tree.parse(source) nodes = tree.locate(Assign, []) tree = Tree(nodes[0].value) # value of value = ... - values = list(tree.infer_values(ctx)) + values = list(tree.infer_values(state)) strings = list(value.as_string() for value in values) assert strings == ["my-value-1", "my-value-2"] @@ -203,5 +202,5 @@ def test_fails_to_infer_missing_externally_defined_value(): tree = Tree.parse(source) nodes = tree.locate(Assign, []) tree = Tree(nodes[1].value) # value of value = ... 
- values = tree.infer_values(GlobalContext()) + values = tree.infer_values(CurrentSessionState()) assert all(not value.is_inferred() for value in values) From 0f081be6ecfaa74c13ebc53ccda057971a1277b5 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Wed, 12 Jun 2024 16:57:23 +0200 Subject: [PATCH 05/23] integrate session state to linters --- .../labs/ucx/contexts/application.py | 9 +++++++- .../labs/ucx/contexts/workspace_cli.py | 1 + src/databricks/labs/ucx/source_code/graph.py | 14 ++++++++----- src/databricks/labs/ucx/source_code/jobs.py | 14 ++++++------- .../labs/ucx/source_code/linters/context.py | 2 +- .../labs/ucx/source_code/linters/dbfs.py | 13 ++++++------ .../labs/ucx/source_code/linters/files.py | 6 ++++-- .../labs/ucx/source_code/linters/imports.py | 11 +++++----- .../labs/ucx/source_code/linters/pyspark.py | 4 +++- .../ucx/source_code/linters/python_ast.py | 8 +++---- tests/integration/source_code/test_jobs.py | 3 ++- tests/unit/source_code/conftest.py | 5 ++++- tests/unit/source_code/linters/test_dbfs.py | 8 +++---- tests/unit/source_code/linters/test_files.py | 18 +++++++++++++--- .../linters/test_python_imports.py | 19 +++++++++-------- .../unit/source_code/notebooks/test_cells.py | 13 ++++++++---- tests/unit/source_code/test_dependencies.py | 16 +++++++++----- tests/unit/source_code/test_graph.py | 9 ++++++-- tests/unit/source_code/test_jobs.py | 4 +++- tests/unit/source_code/test_notebook.py | 21 +++++++++++-------- .../test_path_lookup_simulation.py | 12 +++++++---- tests/unit/source_code/test_s3fs.py | 9 ++++++-- 22 files changed, 142 insertions(+), 77 deletions(-) diff --git a/src/databricks/labs/ucx/contexts/application.py b/src/databricks/labs/ucx/contexts/application.py index d21c096a72..1e66ce4f8b 100644 --- a/src/databricks/labs/ucx/contexts/application.py +++ b/src/databricks/labs/ucx/contexts/application.py @@ -15,6 +15,7 @@ from databricks.labs.ucx.recon.metadata_retriever import DatabricksTableMetadataRetriever from 
databricks.labs.ucx.recon.migration_recon import MigrationRecon from databricks.labs.ucx.recon.schema_comparator import StandardSchemaComparator +from databricks.labs.ucx.source_code.base import CurrentSessionState from databricks.labs.ucx.source_code.python_libraries import PythonLibraryResolver from databricks.sdk import AccountClient, WorkspaceClient, core from databricks.sdk.errors import ResourceDoesNotExist @@ -375,6 +376,10 @@ def path_lookup(self): # TODO find a solution to enable a different cwd per job/task (maybe it's not necessary or possible?) return PathLookup.from_sys_path(Path.cwd()) + @cached_property + def session_state(self): + return CurrentSessionState() + @cached_property def file_loader(self): return FileLoader() @@ -393,7 +398,9 @@ def file_resolver(self): @cached_property def dependency_resolver(self): - return DependencyResolver(self.pip_resolver, self.notebook_resolver, self.file_resolver, self.path_lookup) + return DependencyResolver( + self.pip_resolver, self.notebook_resolver, self.file_resolver, self.path_lookup, self.session_state + ) @cached_property def workflow_linter(self): diff --git a/src/databricks/labs/ucx/contexts/workspace_cli.py b/src/databricks/labs/ucx/contexts/workspace_cli.py index 3c63223d8c..fc24eb2f9d 100644 --- a/src/databricks/labs/ucx/contexts/workspace_cli.py +++ b/src/databricks/labs/ucx/contexts/workspace_cli.py @@ -194,6 +194,7 @@ def local_code_linter(self): self.file_loader, self.folder_loader, self.path_lookup, + self.session_state, self.dependency_resolver, self.linter_context_factory, ) diff --git a/src/databricks/labs/ucx/source_code/graph.py b/src/databricks/labs/ucx/source_code/graph.py index f5f3691bcd..75401b3cfb 100644 --- a/src/databricks/labs/ucx/source_code/graph.py +++ b/src/databricks/labs/ucx/source_code/graph.py @@ -8,7 +8,7 @@ from astroid import ImportFrom, NodeNG # type: ignore -from databricks.labs.ucx.source_code.base import Advisory +from databricks.labs.ucx.source_code.base import 
Advisory, CurrentSessionState from databricks.labs.ucx.source_code.linters.imports import ( DbutilsLinter, ImportSource, @@ -30,11 +30,13 @@ def __init__( parent: DependencyGraph | None, resolver: DependencyResolver, path_lookup: PathLookup, + session_state: CurrentSessionState, ): self._dependency = dependency self._parent = parent self._resolver = resolver self._path_lookup = path_lookup.change_directory(dependency.path.parent) + self._session_state = session_state self._dependencies: dict[Dependency, DependencyGraph] = {} @property @@ -78,7 +80,7 @@ def register_dependency(self, dependency: Dependency) -> MaybeGraph: self._dependencies[dependency] = maybe.graph return maybe # nay, create the child graph and populate it - child_graph = DependencyGraph(dependency, self, self._resolver, self._path_lookup) + child_graph = DependencyGraph(dependency, self, self._resolver, self._path_lookup, self._session_state) self._dependencies[dependency] = child_graph container = dependency.load(self.path_lookup) # TODO: Return either (child) graph OR problems @@ -177,7 +179,7 @@ def build_graph_from_python_source(self, python_code: str) -> list[DependencyPro except Exception as e: # pylint: disable=broad-except problems.append(DependencyProblem('parse-error', f"Could not parse Python code: {e}")) return problems - syspath_changes = SysPathChange.extract_from_tree(tree) + syspath_changes = SysPathChange.extract_from_tree(self._session_state, tree) run_calls = DbutilsLinter.list_dbutils_notebook_run_calls(tree) import_sources: list[ImportSource] import_problems: list[DependencyProblem] @@ -342,11 +344,13 @@ def __init__( notebook_resolver: BaseNotebookResolver, import_resolver: BaseImportResolver, path_lookup: PathLookup, + session_state: CurrentSessionState, ): self._library_resolver = library_resolver self._notebook_resolver = notebook_resolver self._import_resolver = import_resolver self._path_lookup = path_lookup + self._session_state = session_state def resolve_notebook(self, 
path_lookup: PathLookup, path: Path) -> MaybeDependency: return self._notebook_resolver.resolve_notebook(path_lookup, path) @@ -367,7 +371,7 @@ def build_local_file_dependency_graph(self, path: Path) -> MaybeGraph: maybe = resolver.resolve_local_file(self._path_lookup, path) if not maybe.dependency: return MaybeGraph(None, self._make_relative_paths(maybe.problems, path)) - graph = DependencyGraph(maybe.dependency, None, self, self._path_lookup) + graph = DependencyGraph(maybe.dependency, None, self, self._path_lookup, self._session_state) container = maybe.dependency.load(graph.path_lookup) if container is None: problem = DependencyProblem('cannot-load-file', f"Could not load file {path}") @@ -389,7 +393,7 @@ def build_notebook_dependency_graph(self, path: Path) -> MaybeGraph: maybe = self._notebook_resolver.resolve_notebook(self._path_lookup, path) if not maybe.dependency: return MaybeGraph(None, self._make_relative_paths(maybe.problems, path)) - graph = DependencyGraph(maybe.dependency, None, self, self._path_lookup) + graph = DependencyGraph(maybe.dependency, None, self, self._path_lookup, self._session_state) container = maybe.dependency.load(graph.path_lookup) if container is None: problem = DependencyProblem('cannot-load-notebook', f"Could not load notebook {path}") diff --git a/src/databricks/labs/ucx/source_code/jobs.py b/src/databricks/labs/ucx/source_code/jobs.py index b8ce3eb6b5..8f88642e82 100644 --- a/src/databricks/labs/ucx/source_code/jobs.py +++ b/src/databricks/labs/ucx/source_code/jobs.py @@ -343,21 +343,21 @@ def _lint_job(self, job: jobs.Job) -> list[JobProblem]: def _lint_task(self, task: jobs.Task, job: jobs.Job): dependency: Dependency = WorkflowTask(self._ws, task, job) - graph = DependencyGraph(dependency, None, self._resolver, self._path_lookup) + # we can load it without further preparation since the WorkflowTask is merely a wrapper container = dependency.load(self._path_lookup) - assert container is not None # because we just created it 
assert isinstance(container, WorkflowTaskContainer) + session_state = CurrentSessionState( + data_security_mode=container.data_security_mode, + named_parameters=container.named_parameters, + spark_conf=container.spark_conf, + ) + graph = DependencyGraph(dependency, None, self._resolver, self._path_lookup, session_state) problems = container.build_dependency_graph(graph) if problems: for problem in problems: source_path = self._UNKNOWN if problem.is_path_missing() else problem.source_path yield source_path, problem return - session_state = CurrentSessionState( - data_security_mode=container.data_security_mode, - named_parameters=container.named_parameters, - spark_conf=container.spark_conf, - ) ctx = LinterContext(self._migration_index, session_state) for dependency in graph.all_dependencies: logger.info(f'Linting {task.task_key} dependency: {dependency}') diff --git a/src/databricks/labs/ucx/source_code/linters/context.py b/src/databricks/labs/ucx/source_code/linters/context.py index 9fbe3e8bd1..5660b877cf 100644 --- a/src/databricks/labs/ucx/source_code/linters/context.py +++ b/src/databricks/labs/ucx/source_code/linters/context.py @@ -21,7 +21,7 @@ def __init__(self, index: MigrationIndex, session_state: CurrentSessionState | N Language.PYTHON: SequentialLinter( [ SparkSql(from_table, index), - DBFSUsageLinter(), + DBFSUsageLinter(session_state), DBRv8d0Linter(dbr_version=None), SparkConnectLinter(is_serverless=False), DbutilsLinter(), diff --git a/src/databricks/labs/ucx/source_code/linters/dbfs.py b/src/databricks/labs/ucx/source_code/linters/dbfs.py index 4348067603..fcf8aa5000 100644 --- a/src/databricks/labs/ucx/source_code/linters/dbfs.py +++ b/src/databricks/labs/ucx/source_code/linters/dbfs.py @@ -5,7 +5,7 @@ import sqlglot from sqlglot.expressions import Table -from databricks.labs.ucx.source_code.base import Advice, Linter, Deprecation +from databricks.labs.ucx.source_code.base import Advice, Linter, Deprecation, CurrentSessionState from 
databricks.labs.ucx.source_code.linters.python_ast import Tree, TreeVisitor, InferredValue logger = logging.getLogger(__name__) @@ -17,7 +17,8 @@ class DetectDbfsVisitor(TreeVisitor): against a list of known deprecated paths. """ - def __init__(self) -> None: + def __init__(self, session_state: CurrentSessionState) -> None: + self._session_state = session_state self._advices: list[Advice] = [] self._fs_prefixes = ["/dbfs/mnt", "dbfs:/", "/mnt/"] self._reported_locations: set[tuple[int, int]] = set() # Set to store reported locations; astroid coordinates! @@ -28,7 +29,7 @@ def visit_call(self, node: Call): def _visit_arg(self, arg: NodeNG): try: - for inferred in Tree(arg).infer_values(): + for inferred in Tree(arg).infer_values(self._session_state): if not inferred.is_inferred(): logger.debug(f"Could not infer value of {arg.as_string()}") continue @@ -64,8 +65,8 @@ def get_advices(self) -> Iterable[Advice]: class DBFSUsageLinter(Linter): - def __init__(self): - pass + def __init__(self, session_state: CurrentSessionState): + self._session_state = session_state @staticmethod def name() -> str: @@ -79,7 +80,7 @@ def lint(self, code: str) -> Iterable[Advice]: Lints the code looking for file system paths that are deprecated """ tree = Tree.parse(code) - visitor = DetectDbfsVisitor() + visitor = DetectDbfsVisitor(self._session_state) visitor.visit(tree.node) yield from visitor.get_advices() diff --git a/src/databricks/labs/ucx/source_code/linters/files.py b/src/databricks/labs/ucx/source_code/linters/files.py index 71c07a580b..cd8fc05411 100644 --- a/src/databricks/labs/ucx/source_code/linters/files.py +++ b/src/databricks/labs/ucx/source_code/linters/files.py @@ -6,7 +6,7 @@ import sys from typing import TextIO -from databricks.labs.ucx.source_code.base import LocatedAdvice +from databricks.labs.ucx.source_code.base import LocatedAdvice, CurrentSessionState from databricks.labs.ucx.source_code.notebooks.sources import FileLinter, SUPPORTED_EXTENSION_LANGUAGES from 
databricks.labs.ucx.source_code.path_lookup import PathLookup from databricks.labs.ucx.source_code.known import Whitelist @@ -88,12 +88,14 @@ def __init__( file_loader: FileLoader, folder_loader: FolderLoader, path_lookup: PathLookup, + session_state: CurrentSessionState, dependency_resolver: DependencyResolver, languages_factory: Callable[[], LinterContext], ) -> None: self._file_loader = file_loader self._folder_loader = folder_loader self._path_lookup = path_lookup + self._session_state = session_state self._dependency_resolver = dependency_resolver self._extensions = {".py": Language.PYTHON, ".sql": Language.SQL} self._new_linter_context = languages_factory @@ -116,7 +118,7 @@ def lint(self, prompts: Prompts, path: Path | None, stdout: TextIO = sys.stdout) def lint_path(self, path: Path) -> Iterable[LocatedAdvice]: loader = self._folder_loader if path.is_dir() else self._file_loader dependency = Dependency(loader, path) - graph = DependencyGraph(dependency, None, self._dependency_resolver, self._path_lookup) + graph = DependencyGraph(dependency, None, self._dependency_resolver, self._path_lookup, self._session_state) container = dependency.load(self._path_lookup) assert container is not None # because we just created it problems = container.build_dependency_graph(graph) diff --git a/src/databricks/labs/ucx/source_code/linters/imports.py b/src/databricks/labs/ucx/source_code/linters/imports.py index 20d52d2f65..23658ed809 100644 --- a/src/databricks/labs/ucx/source_code/linters/imports.py +++ b/src/databricks/labs/ucx/source_code/linters/imports.py @@ -16,7 +16,7 @@ NodeNG, ) -from databricks.labs.ucx.source_code.base import Linter, Advice, Advisory +from databricks.labs.ucx.source_code.base import Linter, Advice, Advisory, CurrentSessionState from databricks.labs.ucx.source_code.linters.python_ast import Tree, NodeBase, TreeVisitor, InferredValue logger = logging.getLogger(__name__) @@ -146,8 +146,8 @@ def list_dbutils_notebook_run_calls(tree: Tree) -> 
list[NotebookRunCall]: class SysPathChange(NodeBase, abc.ABC): @staticmethod - def extract_from_tree(tree: Tree) -> list[SysPathChange]: - visitor = SysPathChangesVisitor() + def extract_from_tree(session_state: CurrentSessionState, tree: Tree) -> list[SysPathChange]: + visitor = SysPathChangesVisitor(session_state) visitor.visit(tree.node) return visitor.sys_path_changes @@ -186,8 +186,9 @@ class UnresolvedPath(SysPathChange): class SysPathChangesVisitor(TreeVisitor): - def __init__(self) -> None: + def __init__(self, session_state: CurrentSessionState) -> None: super() + self._session_state = session_state self._aliases: dict[str, str] = {} self.sys_path_changes: list[SysPathChange] = [] @@ -223,7 +224,7 @@ def visit_call(self, node: Call): relative = True changed = changed.args[0] try: - for inferred in Tree(changed).infer_values(): + for inferred in Tree(changed).infer_values(self._session_state): self._visit_inferred(changed, inferred, relative, is_append) except InferenceError: self.sys_path_changes.append(UnresolvedPath(changed, changed.as_string(), is_append)) diff --git a/src/databricks/labs/ucx/source_code/linters/pyspark.py b/src/databricks/labs/ucx/source_code/linters/pyspark.py index 0f7143ce00..b5441a41ab 100644 --- a/src/databricks/labs/ucx/source_code/linters/pyspark.py +++ b/src/databricks/labs/ucx/source_code/linters/pyspark.py @@ -11,6 +11,7 @@ Fixer, Linter, Failure, + CurrentSessionState, ) from databricks.labs.ucx.source_code.queries import FromTable from databricks.labs.ucx.source_code.linters.python_ast import Tree, InferredValue @@ -24,6 +25,7 @@ class Matcher(ABC): table_arg_index: int table_arg_name: str | None = None call_context: dict[str, set[str]] | None = None + session_state: CurrentSessionState | None = None def matches(self, node: NodeNG): return isinstance(node, Call) and isinstance(node.func, Attribute) and self._get_table_arg(node) is not None @@ -72,7 +74,7 @@ def lint(self, from_table: FromTable, index: MigrationIndex, node: 
Call) -> Iter table_arg = self._get_table_arg(node) if table_arg: try: - for inferred in Tree(table_arg).infer_values(): + for inferred in Tree(table_arg).infer_values(self.session_state): yield from self._lint_table_arg(from_table, node, inferred) except InferenceError: yield Advisory.from_node( diff --git a/src/databricks/labs/ucx/source_code/linters/python_ast.py b/src/databricks/labs/ucx/source_code/linters/python_ast.py index d6836411a4..35b8c99991 100644 --- a/src/databricks/labs/ucx/source_code/linters/python_ast.py +++ b/src/databricks/labs/ucx/source_code/linters/python_ast.py @@ -190,7 +190,7 @@ def do_infer_values(self): class _ContextualCall(NodeNG): - def __init__(self, state: CurrentSessionState, node: NodeNG): + def __init__(self, session_state: CurrentSessionState, node: NodeNG): super().__init__( lineno=node.lineno, col_offset=node.col_offset, @@ -198,7 +198,7 @@ def __init__(self, state: CurrentSessionState, node: NodeNG): end_col_offset=node.end_col_offset, parent=node.parent, ) - self._state = state + self._session_state = session_state @decorators.raise_if_nothing_inferred def _infer( @@ -213,12 +213,12 @@ def infer_call_result(self, context: InferenceContext | None = None, **_): # ca yield Uninferable return arg = call_context.args[0] - for inferred in Tree(arg).infer_values(): + for inferred in Tree(arg).infer_values(self._session_state): if not inferred.is_inferred(): yield Uninferable continue name = inferred.as_string() - named_parameters = self._state.named_parameters + named_parameters = self._session_state.named_parameters if not named_parameters or name not in named_parameters: yield Uninferable continue diff --git a/tests/integration/source_code/test_jobs.py b/tests/integration/source_code/test_jobs.py index d8533da51e..839e3d0152 100644 --- a/tests/integration/source_code/test_jobs.py +++ b/tests/integration/source_code/test_jobs.py @@ -178,7 +178,7 @@ def test_workflow_linter_lints_job_with_import_pypi_library( def 
test_lint_local_code(simple_ctx): # no need to connect - linter_context = LinterContext(MigrationIndex([])) + linter_context = LinterContext(MigrationIndex([]), simple_ctx.session_state) light_ctx = simple_ctx ucx_path = Path(__file__).parent.parent.parent.parent path_to_scan = Path(ucx_path, "src") @@ -187,6 +187,7 @@ def test_lint_local_code(simple_ctx): light_ctx.file_loader, light_ctx.folder_loader, light_ctx.path_lookup, + light_ctx.session_state, light_ctx.dependency_resolver, lambda: linter_context, ) diff --git a/tests/unit/source_code/conftest.py b/tests/unit/source_code/conftest.py index 1ad3c4ac69..97ff5cb5ec 100644 --- a/tests/unit/source_code/conftest.py +++ b/tests/unit/source_code/conftest.py @@ -4,6 +4,7 @@ MigrationStatus, ) from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex +from databricks.labs.ucx.source_code.base import CurrentSessionState from databricks.labs.ucx.source_code.graph import DependencyResolver from databricks.labs.ucx.source_code.known import Whitelist from databricks.labs.ucx.source_code.linters.files import ImportFileResolver, FileLoader @@ -53,4 +54,6 @@ def simple_dependency_resolver(mock_path_lookup): library_resolver = PythonLibraryResolver(whitelist) notebook_resolver = NotebookResolver(NotebookLoader()) import_resolver = ImportFileResolver(FileLoader(), whitelist) - return DependencyResolver(library_resolver, notebook_resolver, import_resolver, mock_path_lookup) + return DependencyResolver( + library_resolver, notebook_resolver, import_resolver, mock_path_lookup, CurrentSessionState() + ) diff --git a/tests/unit/source_code/linters/test_dbfs.py b/tests/unit/source_code/linters/test_dbfs.py index 88c8b8f805..9ce795eee7 100644 --- a/tests/unit/source_code/linters/test_dbfs.py +++ b/tests/unit/source_code/linters/test_dbfs.py @@ -1,6 +1,6 @@ import pytest -from databricks.labs.ucx.source_code.base import Deprecation, Advisory +from databricks.labs.ucx.source_code.base import Deprecation, Advisory, 
CurrentSessionState from databricks.labs.ucx.source_code.linters.dbfs import DBFSUsageLinter, FromDbfsFolder @@ -17,7 +17,7 @@ class TestDetectDBFS: ], ) def test_detects_dbfs_paths(self, code, expected): - linter = DBFSUsageLinter() + linter = DBFSUsageLinter(CurrentSessionState()) advices = list(linter.lint(code)) for advice in advices: assert isinstance(advice, Advisory) @@ -46,7 +46,7 @@ def test_detects_dbfs_paths(self, code, expected): ], ) def test_dbfs_usage_linter(self, code, expected): - linter = DBFSUsageLinter() + linter = DBFSUsageLinter(CurrentSessionState()) advices = linter.lint(code) count = 0 for advice in advices: @@ -55,7 +55,7 @@ def test_dbfs_usage_linter(self, code, expected): assert count == expected def test_dbfs_name(self): - linter = DBFSUsageLinter() + linter = DBFSUsageLinter(CurrentSessionState()) assert linter.name() == "dbfs-usage" diff --git a/tests/unit/source_code/linters/test_files.py b/tests/unit/source_code/linters/test_files.py index 881929859f..1b77587a15 100644 --- a/tests/unit/source_code/linters/test_files.py +++ b/tests/unit/source_code/linters/test_files.py @@ -4,6 +4,8 @@ import pytest from databricks.labs.blueprint.tui import MockPrompts + +from databricks.labs.ucx.source_code.base import CurrentSessionState from databricks.labs.ucx.source_code.graph import DependencyResolver, SourceContainer from databricks.labs.ucx.source_code.notebooks.loaders import NotebookResolver, NotebookLoader from databricks.labs.ucx.source_code.python_libraries import PythonLibraryResolver @@ -95,16 +97,18 @@ def test_linter_walks_directory(mock_path_lookup, migration_index): folder_loader = FolderLoader(file_loader) whitelist = Whitelist() pip_resolver = PythonLibraryResolver(whitelist) + session_state = CurrentSessionState() resolver = DependencyResolver( pip_resolver, NotebookResolver(NotebookLoader()), ImportFileResolver(file_loader, whitelist), mock_path_lookup, + session_state, ) path = Path(Path(__file__).parent, "../samples", 
"simulate-sys-path") prompts = MockPrompts({"Which file or directory do you want to lint ?": path.as_posix()}) linter = LocalCodeLinter( - file_loader, folder_loader, mock_path_lookup, resolver, lambda: LinterContext(migration_index) + file_loader, folder_loader, mock_path_lookup, session_state, resolver, lambda: LinterContext(migration_index) ) advices = linter.lint(prompts, None) assert not advices @@ -149,10 +153,18 @@ def test_known_issues(path: Path, migration_index): file_loader = FileLoader() folder_loader = FolderLoader(file_loader) path_lookup = PathLookup.from_sys_path(Path.cwd()) + session_state = CurrentSessionState() whitelist = Whitelist() notebook_resolver = NotebookResolver(NotebookLoader()) import_resolver = ImportFileResolver(file_loader, whitelist) pip_resolver = PythonLibraryResolver(whitelist) - resolver = DependencyResolver(pip_resolver, notebook_resolver, import_resolver, path_lookup) - linter = LocalCodeLinter(file_loader, folder_loader, path_lookup, resolver, lambda: LinterContext(migration_index)) + resolver = DependencyResolver(pip_resolver, notebook_resolver, import_resolver, path_lookup, session_state) + linter = LocalCodeLinter( + file_loader, + folder_loader, + path_lookup, + session_state, + resolver, + lambda: LinterContext(migration_index, session_state), + ) linter.lint(MockPrompts({}), path) diff --git a/tests/unit/source_code/linters/test_python_imports.py b/tests/unit/source_code/linters/test_python_imports.py index e73e4a25e5..8bd054423a 100644 --- a/tests/unit/source_code/linters/test_python_imports.py +++ b/tests/unit/source_code/linters/test_python_imports.py @@ -2,6 +2,7 @@ import pytest +from databricks.labs.ucx.source_code.base import CurrentSessionState from databricks.labs.ucx.source_code.graph import DependencyProblem from databricks.labs.ucx.source_code.linters.imports import DbutilsLinter, ImportSource, SysPathChange @@ -56,7 +57,7 @@ def test_linter_returns_appended_absolute_paths(): 
sys.path.append("absolute_path_2") """ tree = Tree.parse(code) - appended = SysPathChange.extract_from_tree(tree) + appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree) assert ["absolute_path_1", "absolute_path_2"] == [p.path for p in appended] @@ -67,7 +68,7 @@ def test_linter_returns_appended_absolute_paths_with_sys_alias(): stuff.path.append("absolute_path_2") """ tree = Tree.parse(code) - appended = SysPathChange.extract_from_tree(tree) + appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree) assert ["absolute_path_1", "absolute_path_2"] == [p.path for p in appended] @@ -77,7 +78,7 @@ def test_linter_returns_appended_absolute_paths_with_sys_path_alias(): stuff.append("absolute_path") """ tree = Tree.parse(code) - appended = SysPathChange.extract_from_tree(tree) + appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree) assert "absolute_path" in [p.path for p in appended] @@ -88,7 +89,7 @@ def test_linter_returns_appended_relative_paths(): sys.path.append(os.path.abspath("relative_path")) """ tree = Tree.parse(code) - appended = SysPathChange.extract_from_tree(tree) + appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree) assert "relative_path" in [p.path for p in appended] @@ -99,7 +100,7 @@ def test_linter_returns_appended_relative_paths_with_os_alias(): sys.path.append(stuff.path.abspath("relative_path")) """ tree = Tree.parse(code) - appended = SysPathChange.extract_from_tree(tree) + appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree) assert "relative_path" in [p.path for p in appended] @@ -110,7 +111,7 @@ def test_linter_returns_appended_relative_paths_with_os_path_alias(): sys.path.append(stuff.abspath("relative_path")) """ tree = Tree.parse(code) - appended = SysPathChange.extract_from_tree(tree) + appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree) assert "relative_path" in [p.path for p in appended] @@ -121,7 +122,7 @@ def 
test_linter_returns_appended_relative_paths_with_os_path_abspath_import(): sys.path.append(abspath("relative_path")) """ tree = Tree.parse(code) - appended = SysPathChange.extract_from_tree(tree) + appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree) assert "relative_path" in [p.path for p in appended] @@ -132,7 +133,7 @@ def test_linter_returns_appended_relative_paths_with_os_path_abspath_alias(): sys.path.append(stuff("relative_path")) """ tree = Tree.parse(code) - appended = SysPathChange.extract_from_tree(tree) + appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree) assert "relative_path" in [p.path for p in appended] @@ -143,7 +144,7 @@ def test_linter_returns_inferred_paths(): sys.path.append(path) """ tree = Tree.parse(code) - appended = SysPathChange.extract_from_tree(tree) + appended = SysPathChange.extract_from_tree(CurrentSessionState(), tree) assert ["absolute_path_1"] == [p.path for p in appended] diff --git a/tests/unit/source_code/notebooks/test_cells.py b/tests/unit/source_code/notebooks/test_cells.py index 89e2805dbd..56c8835168 100644 --- a/tests/unit/source_code/notebooks/test_cells.py +++ b/tests/unit/source_code/notebooks/test_cells.py @@ -4,6 +4,7 @@ import pytest +from databricks.labs.ucx.source_code.base import CurrentSessionState from databricks.labs.ucx.source_code.graph import Dependency, DependencyGraph, DependencyResolver from databricks.labs.ucx.source_code.linters.files import FileLoader, ImportFileResolver from databricks.labs.ucx.source_code.notebooks.cells import CellLanguage, PipCell @@ -117,8 +118,10 @@ def test_pip_cell_build_dependency_graph_reports_unknown_library(mock_path_looku notebook_loader = NotebookLoader() notebook_resolver = NotebookResolver(notebook_loader) pip_resolver = PythonLibraryResolver(Whitelist()) - dependency_resolver = DependencyResolver(pip_resolver, notebook_resolver, [], mock_path_lookup) - graph = DependencyGraph(dependency, None, dependency_resolver, 
mock_path_lookup) + dependency_resolver = DependencyResolver( + pip_resolver, notebook_resolver, [], mock_path_lookup, CurrentSessionState() + ) + graph = DependencyGraph(dependency, None, dependency_resolver, mock_path_lookup, CurrentSessionState()) code = "%pip install unknown-library-name" cell = PipCell(code, original_offset=1) @@ -138,8 +141,10 @@ def test_pip_cell_build_dependency_graph_resolves_installed_library(mock_path_lo file_loader = FileLoader() pip_resolver = PythonLibraryResolver(whitelist) import_resolver = ImportFileResolver(file_loader, whitelist) - dependency_resolver = DependencyResolver(pip_resolver, notebook_resolver, import_resolver, mock_path_lookup) - graph = DependencyGraph(dependency, None, dependency_resolver, mock_path_lookup) + dependency_resolver = DependencyResolver( + pip_resolver, notebook_resolver, import_resolver, mock_path_lookup, CurrentSessionState() + ) + graph = DependencyGraph(dependency, None, dependency_resolver, mock_path_lookup, CurrentSessionState()) whl = Path(__file__).parent / '../samples/distribution/dist/thingy-0.0.1-py2.py3-none-any.whl' diff --git a/tests/unit/source_code/test_dependencies.py b/tests/unit/source_code/test_dependencies.py index 66f27711b7..1250ff243b 100644 --- a/tests/unit/source_code/test_dependencies.py +++ b/tests/unit/source_code/test_dependencies.py @@ -1,7 +1,7 @@ from pathlib import Path from unittest.mock import create_autospec - +from databricks.labs.ucx.source_code.base import CurrentSessionState from databricks.labs.ucx.source_code.graph import ( SourceContainer, DependencyResolver, @@ -121,7 +121,9 @@ def test_dependency_resolver_terminates_at_known_libraries(empty_index, mock_not file_loader = FileLoader() import_resolver = ImportFileResolver(file_loader, Whitelist()) library_resolver = PythonLibraryResolver(Whitelist()) - resolver = DependencyResolver(library_resolver, mock_notebook_resolver, import_resolver, lookup) + resolver = DependencyResolver( + library_resolver, 
mock_notebook_resolver, import_resolver, lookup, CurrentSessionState() + ) maybe = resolver.build_local_file_dependency_graph(Path("import-site-package.py")) assert not maybe.failed graph = maybe.graph @@ -154,7 +156,9 @@ def load_dependency(self, path_lookup: PathLookup, dependency: Dependency) -> So whitelist = Whitelist() import_resolver = ImportFileResolver(file_loader, whitelist) pip_resolver = PythonLibraryResolver(whitelist) - resolver = DependencyResolver(pip_resolver, mock_notebook_resolver, import_resolver, mock_path_lookup) + resolver = DependencyResolver( + pip_resolver, mock_notebook_resolver, import_resolver, mock_path_lookup, CurrentSessionState() + ) maybe = resolver.build_local_file_dependency_graph(Path("import-sub-site-package.py")) assert list(maybe.problems) == [ DependencyProblem( @@ -172,7 +176,7 @@ def load_dependency(self, path_lookup: PathLookup, dependency: Dependency) -> So notebook_loader = FailingNotebookLoader() notebook_resolver = NotebookResolver(notebook_loader) pip_resolver = PythonLibraryResolver(Whitelist()) - resolver = DependencyResolver(pip_resolver, notebook_resolver, [], mock_path_lookup) + resolver = DependencyResolver(pip_resolver, notebook_resolver, [], mock_path_lookup, CurrentSessionState()) maybe = resolver.build_notebook_dependency_graph(Path("root5.py")) assert list(maybe.problems) == [ DependencyProblem('cannot-load-notebook', 'Could not load notebook root5.py', Path('')) @@ -183,7 +187,9 @@ def test_dependency_resolver_raises_problem_with_missing_file_loader(mock_notebo library_resolver = PythonLibraryResolver(Whitelist()) import_resolver = create_autospec(BaseImportResolver) import_resolver.resolve_import.return_value = None - resolver = DependencyResolver(library_resolver, mock_notebook_resolver, import_resolver, mock_path_lookup) + resolver = DependencyResolver( + library_resolver, mock_notebook_resolver, import_resolver, mock_path_lookup, CurrentSessionState() + ) maybe = 
resolver.build_local_file_dependency_graph(Path("import-sub-site-package.py")) assert list(maybe.problems) == [ DependencyProblem('missing-file-resolver', 'Missing resolver for local files', Path('')) diff --git a/tests/unit/source_code/test_graph.py b/tests/unit/source_code/test_graph.py index 583b54d5b5..c06c51bb54 100644 --- a/tests/unit/source_code/test_graph.py +++ b/tests/unit/source_code/test_graph.py @@ -1,5 +1,6 @@ from pathlib import Path +from databricks.labs.ucx.source_code.base import CurrentSessionState from databricks.labs.ucx.source_code.linters.files import FileLoader, ImportFileResolver, FolderLoader from databricks.labs.ucx.source_code.graph import Dependency, DependencyGraph, DependencyResolver from databricks.labs.ucx.source_code.notebooks.loaders import NotebookResolver, NotebookLoader @@ -11,13 +12,15 @@ def test_dependency_graph_registers_library(mock_path_lookup): dependency = Dependency(FileLoader(), Path("test")) file_loader = FileLoader() whitelist = Whitelist() + session_state = CurrentSessionState() dependency_resolver = DependencyResolver( PythonLibraryResolver(whitelist), NotebookResolver(NotebookLoader()), ImportFileResolver(file_loader, whitelist), mock_path_lookup, + session_state, ) - graph = DependencyGraph(dependency, None, dependency_resolver, mock_path_lookup) + graph = DependencyGraph(dependency, None, dependency_resolver, mock_path_lookup, session_state) problems = graph.register_library("demo-egg") # installs pkgdir @@ -29,14 +32,16 @@ def test_folder_loads_content(mock_path_lookup): path = Path(Path(__file__).parent, "samples") file_loader = FileLoader() whitelist = Whitelist() + session_state = CurrentSessionState() dependency_resolver = DependencyResolver( PythonLibraryResolver(whitelist), NotebookResolver(NotebookLoader()), ImportFileResolver(file_loader, whitelist), mock_path_lookup, + session_state, ) dependency = Dependency(FolderLoader(file_loader), path) - graph = DependencyGraph(dependency, None, 
dependency_resolver, mock_path_lookup) + graph = DependencyGraph(dependency, None, dependency_resolver, mock_path_lookup, session_state) container = dependency.load(mock_path_lookup) container.build_dependency_graph(graph) assert len(graph.all_paths) > 1 diff --git a/tests/unit/source_code/test_jobs.py b/tests/unit/source_code/test_jobs.py index ec1ee7cb7e..3a5b46e2d6 100644 --- a/tests/unit/source_code/test_jobs.py +++ b/tests/unit/source_code/test_jobs.py @@ -7,6 +7,7 @@ from databricks.sdk.service.jobs import Job from databricks.sdk.service.pipelines import NotebookLibrary, GetPipelineResponse, PipelineLibrary, FileLibrary +from databricks.labs.ucx.source_code.base import CurrentSessionState from databricks.labs.ucx.source_code.python_libraries import PythonLibraryResolver from databricks.labs.ucx.source_code.known import Whitelist from databricks.sdk import WorkspaceClient @@ -38,6 +39,7 @@ def dependency_resolver(mock_path_lookup) -> DependencyResolver: NotebookResolver(NotebookLoader()), ImportFileResolver(file_loader, whitelist), mock_path_lookup, + CurrentSessionState(), ) return resolver @@ -45,7 +47,7 @@ def dependency_resolver(mock_path_lookup) -> DependencyResolver: @pytest.fixture def graph(mock_path_lookup, dependency_resolver) -> DependencyGraph: dependency = Dependency(FileLoader(), Path("test")) - dependency_graph = DependencyGraph(dependency, None, dependency_resolver, mock_path_lookup) + dependency_graph = DependencyGraph(dependency, None, dependency_resolver, mock_path_lookup, CurrentSessionState()) return dependency_graph diff --git a/tests/unit/source_code/test_notebook.py b/tests/unit/source_code/test_notebook.py index e110afeb6b..8c1272f3b4 100644 --- a/tests/unit/source_code/test_notebook.py +++ b/tests/unit/source_code/test_notebook.py @@ -4,6 +4,7 @@ import pytest from databricks.sdk.service.workspace import Language, ObjectType, ObjectInfo +from databricks.labs.ucx.source_code.base import CurrentSessionState from 
databricks.labs.ucx.source_code.graph import DependencyGraph, SourceContainer, DependencyResolver from databricks.labs.ucx.source_code.known import Whitelist from databricks.labs.ucx.source_code.linters.files import ImportFileResolver, FileLoader @@ -133,14 +134,16 @@ def dependency_resolver(mock_path_lookup) -> DependencyResolver: notebook_resolver = NotebookResolver(notebook_loader) library_resolver = PythonLibraryResolver(Whitelist()) import_resolver = ImportFileResolver(FileLoader(), Whitelist()) - return DependencyResolver(library_resolver, notebook_resolver, import_resolver, mock_path_lookup) + return DependencyResolver( + library_resolver, notebook_resolver, import_resolver, mock_path_lookup, CurrentSessionState() + ) def test_notebook_builds_leaf_dependency_graph(mock_path_lookup) -> None: resolver = dependency_resolver(mock_path_lookup) maybe = resolver.resolve_notebook(mock_path_lookup, Path("leaf1.py")) assert maybe.dependency is not None - graph = DependencyGraph(maybe.dependency, None, resolver, mock_path_lookup) + graph = DependencyGraph(maybe.dependency, None, resolver, mock_path_lookup, CurrentSessionState()) container = maybe.dependency.load(mock_path_lookup) assert container is not None problems = container.build_dependency_graph(graph) @@ -158,7 +161,7 @@ def test_notebook_builds_depth1_dependency_graph(mock_path_lookup) -> None: resolver = dependency_resolver(mock_path_lookup) maybe = resolver.resolve_notebook(mock_path_lookup, Path(paths[0])) assert maybe.dependency is not None - graph = DependencyGraph(maybe.dependency, None, resolver, mock_path_lookup) + graph = DependencyGraph(maybe.dependency, None, resolver, mock_path_lookup, CurrentSessionState()) container = maybe.dependency.load(mock_path_lookup) assert container is not None problems = container.build_dependency_graph(graph) @@ -171,7 +174,7 @@ def test_notebook_builds_depth2_dependency_graph(mock_path_lookup) -> None: resolver = dependency_resolver(mock_path_lookup) maybe = 
resolver.resolve_notebook(mock_path_lookup, Path(paths[0])) assert maybe.dependency is not None - graph = DependencyGraph(maybe.dependency, None, resolver, mock_path_lookup) + graph = DependencyGraph(maybe.dependency, None, resolver, mock_path_lookup, CurrentSessionState()) container = maybe.dependency.load(mock_path_lookup) assert container is not None problems = container.build_dependency_graph(graph) @@ -184,7 +187,7 @@ def test_notebook_builds_dependency_graph_avoiding_duplicates(mock_path_lookup) resolver = dependency_resolver(mock_path_lookup) maybe = resolver.resolve_notebook(mock_path_lookup, Path(paths[0])) assert maybe.dependency is not None - graph = DependencyGraph(maybe.dependency, None, resolver, mock_path_lookup) + graph = DependencyGraph(maybe.dependency, None, resolver, mock_path_lookup, CurrentSessionState()) container = maybe.dependency.load(mock_path_lookup) assert container is not None problems = container.build_dependency_graph(graph) @@ -198,7 +201,7 @@ def test_notebook_builds_cyclical_dependency_graph(mock_path_lookup) -> None: resolver = dependency_resolver(mock_path_lookup) maybe = resolver.resolve_notebook(mock_path_lookup, Path(paths[0])) assert maybe.dependency is not None - graph = DependencyGraph(maybe.dependency, None, resolver, mock_path_lookup) + graph = DependencyGraph(maybe.dependency, None, resolver, mock_path_lookup, CurrentSessionState()) container = maybe.dependency.load(mock_path_lookup) assert container is not None problems = container.build_dependency_graph(graph) @@ -211,7 +214,7 @@ def test_notebook_builds_python_dependency_graph(mock_path_lookup) -> None: resolver = dependency_resolver(mock_path_lookup) maybe = resolver.resolve_notebook(mock_path_lookup, Path(paths[0])) assert maybe.dependency is not None - graph = DependencyGraph(maybe.dependency, None, resolver, mock_path_lookup) + graph = DependencyGraph(maybe.dependency, None, resolver, mock_path_lookup, CurrentSessionState()) container = 
maybe.dependency.load(mock_path_lookup) assert container is not None problems = container.build_dependency_graph(graph) @@ -224,7 +227,7 @@ def test_notebook_builds_python_dependency_graph_with_loop(mock_path_lookup) -> resolver = dependency_resolver(mock_path_lookup) maybe = resolver.resolve_notebook(mock_path_lookup, Path(path)) assert maybe.dependency is not None - graph = DependencyGraph(maybe.dependency, None, resolver, mock_path_lookup) + graph = DependencyGraph(maybe.dependency, None, resolver, mock_path_lookup, CurrentSessionState()) container = maybe.dependency.load(mock_path_lookup) assert container is not None container.build_dependency_graph(graph) @@ -237,7 +240,7 @@ def test_notebook_builds_python_dependency_graph_with_fstring_loop(mock_path_loo resolver = dependency_resolver(mock_path_lookup) maybe = resolver.resolve_notebook(mock_path_lookup, Path(path)) assert maybe.dependency is not None - graph = DependencyGraph(maybe.dependency, None, resolver, mock_path_lookup) + graph = DependencyGraph(maybe.dependency, None, resolver, mock_path_lookup, CurrentSessionState()) container = maybe.dependency.load(mock_path_lookup) assert container is not None container.build_dependency_graph(graph) diff --git a/tests/unit/source_code/test_path_lookup_simulation.py b/tests/unit/source_code/test_path_lookup_simulation.py index 38c89cbd76..f93d1e5bb1 100644 --- a/tests/unit/source_code/test_path_lookup_simulation.py +++ b/tests/unit/source_code/test_path_lookup_simulation.py @@ -3,6 +3,8 @@ from unittest.mock import create_autospec import pytest + +from databricks.labs.ucx.source_code.base import CurrentSessionState from databricks.labs.ucx.source_code.linters.files import ImportFileResolver, FileLoader from databricks.labs.ucx.source_code.path_lookup import PathLookup from databricks.labs.ucx.source_code.graph import SourceContainer, DependencyResolver @@ -46,7 +48,9 @@ def test_locates_notebooks(source: list[str], expected: int, mock_path_lookup): whitelist = 
Whitelist() import_resolver = ImportFileResolver(file_loader, whitelist) pip_resolver = PythonLibraryResolver(whitelist) - dependency_resolver = DependencyResolver(pip_resolver, notebook_resolver, import_resolver, mock_path_lookup) + dependency_resolver = DependencyResolver( + pip_resolver, notebook_resolver, import_resolver, mock_path_lookup, CurrentSessionState() + ) maybe = dependency_resolver.build_notebook_dependency_graph(notebook_path) assert not maybe.problems assert maybe.graph is not None @@ -72,7 +76,7 @@ def test_locates_files(source: list[str], expected: int): notebook_resolver = NotebookResolver(notebook_loader) import_resolver = ImportFileResolver(file_loader, whitelist) pip_resolver = PythonLibraryResolver(whitelist) - resolver = DependencyResolver(pip_resolver, notebook_resolver, import_resolver, lookup) + resolver = DependencyResolver(pip_resolver, notebook_resolver, import_resolver, lookup, CurrentSessionState()) maybe = resolver.build_local_file_dependency_graph(file_path) assert not maybe.problems assert maybe.graph is not None @@ -111,7 +115,7 @@ def test_locates_notebooks_with_absolute_path(): whitelist = Whitelist() import_resolver = ImportFileResolver(file_loader, whitelist) pip_resolver = PythonLibraryResolver(whitelist) - resolver = DependencyResolver(pip_resolver, notebook_resolver, import_resolver, lookup) + resolver = DependencyResolver(pip_resolver, notebook_resolver, import_resolver, lookup, CurrentSessionState()) maybe = resolver.build_notebook_dependency_graph(parent_file_path) assert not maybe.problems assert maybe.graph is not None @@ -150,7 +154,7 @@ def func(): file_loader = FileLoader() import_resolver = ImportFileResolver(file_loader, whitelist) pip_resolver = PythonLibraryResolver(whitelist) - resolver = DependencyResolver(pip_resolver, notebook_resolver, import_resolver, lookup) + resolver = DependencyResolver(pip_resolver, notebook_resolver, import_resolver, lookup, CurrentSessionState()) maybe = 
resolver.build_notebook_dependency_graph(parent_file_path) assert not maybe.problems assert maybe.graph is not None diff --git a/tests/unit/source_code/test_s3fs.py b/tests/unit/source_code/test_s3fs.py index 9f4650dcbc..9f02b50305 100644 --- a/tests/unit/source_code/test_s3fs.py +++ b/tests/unit/source_code/test_s3fs.py @@ -2,6 +2,7 @@ import pytest +from databricks.labs.ucx.source_code.base import CurrentSessionState from databricks.labs.ucx.source_code.graph import ( DependencyResolver, DependencyProblem, @@ -121,7 +122,9 @@ def test_detect_s3fs_import(empty_index, source: str, expected: list[DependencyP notebook_resolver = NotebookResolver(notebook_loader) import_resolver = ImportFileResolver(file_loader, whitelist) pip_resolver = PythonLibraryResolver(whitelist) - dependency_resolver = DependencyResolver(pip_resolver, notebook_resolver, import_resolver, mock_path_lookup) + dependency_resolver = DependencyResolver( + pip_resolver, notebook_resolver, import_resolver, mock_path_lookup, CurrentSessionState() + ) maybe = dependency_resolver.build_local_file_dependency_graph(sample) assert maybe.problems == [_.replace(source_path=sample) for _ in expected] @@ -151,7 +154,9 @@ def test_detect_s3fs_import_in_dependencies( whitelist = Whitelist() import_resolver = ImportFileResolver(file_loader, whitelist) pip_resolver = PythonLibraryResolver(whitelist) - dependency_resolver = DependencyResolver(pip_resolver, mock_notebook_resolver, import_resolver, mock_path_lookup) + dependency_resolver = DependencyResolver( + pip_resolver, mock_notebook_resolver, import_resolver, mock_path_lookup, CurrentSessionState() + ) sample = mock_path_lookup.cwd / "root9.py" maybe = dependency_resolver.build_local_file_dependency_graph(sample) assert maybe.problems == expected From 84e13a95ab255ccc1ae6b80df1e6f35c21e5d3f1 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Wed, 12 Jun 2024 17:06:41 +0200 Subject: [PATCH 06/23] drop obsolete file --- src/databricks/labs/ucx/contexts/base.py | 
18 ------------------ 1 file changed, 18 deletions(-) delete mode 100644 src/databricks/labs/ucx/contexts/base.py diff --git a/src/databricks/labs/ucx/contexts/base.py b/src/databricks/labs/ucx/contexts/base.py deleted file mode 100644 index 04dc3e34b9..0000000000 --- a/src/databricks/labs/ucx/contexts/base.py +++ /dev/null @@ -1,18 +0,0 @@ -from functools import cached_property - - -class BaseContext: - def __init__(self, named_parameters: dict[str, str] | None = None): - if not named_parameters: - named_parameters = {} - self._named_parameters = named_parameters - - def replace(self, **kwargs): - """Replace cached properties for unit testing purposes.""" - for key, value in kwargs.items(): - self.__dict__[key] = value - return self - - @cached_property - def named_parameters(self) -> dict[str, str]: - return self._named_parameters From d0117eb933391b8b8465013c48c743bd78415a25 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Wed, 12 Jun 2024 19:03:27 +0200 Subject: [PATCH 07/23] integrate session state to linters and add functional tests --- .../labs/ucx/contexts/workspace_cli.py | 1 - src/databricks/labs/ucx/source_code/base.py | 6 ++-- src/databricks/labs/ucx/source_code/graph.py | 2 +- src/databricks/labs/ucx/source_code/jobs.py | 12 +++---- src/databricks/labs/ucx/source_code/known.py | 3 +- .../labs/ucx/source_code/linters/context.py | 6 ++-- .../labs/ucx/source_code/linters/dbfs.py | 11 +++--- .../labs/ucx/source_code/linters/files.py | 36 +++++++++++-------- .../labs/ucx/source_code/linters/imports.py | 28 ++++++++------- .../labs/ucx/source_code/linters/pyspark.py | 32 +++++++++++------ .../ucx/source_code/linters/python_ast.py | 3 +- .../ucx/source_code/linters/spark_connect.py | 3 +- .../ucx/source_code/linters/table_creation.py | 3 +- src/databricks/labs/ucx/source_code/lsp.py | 3 +- .../ucx/source_code/notebooks/migrator.py | 8 ++--- .../labs/ucx/source_code/notebooks/sources.py | 20 +++++------ .../labs/ucx/source_code/queries.py | 2 +-
tests/integration/source_code/solacc.py | 4 ++- tests/integration/source_code/test_jobs.py | 3 +- tests/unit/source_code/linters/test_dbfs.py | 14 ++++---- tests/unit/source_code/linters/test_files.py | 7 ++-- .../unit/source_code/linters/test_pyspark.py | 30 +++++++++------- .../linters/test_python_imports.py | 2 +- .../source_code/linters/test_spark_connect.py | 18 +++++----- .../linters/test_table_creation.py | 4 +-- .../source_code/notebooks/test_sources.py | 9 ++--- .../source_code/samples/functional/widgets.py | 5 +++ tests/unit/source_code/test_functional.py | 12 ++++--- tests/unit/source_code/test_notebook.py | 7 ++-- .../unit/source_code/test_notebook_linter.py | 6 ++-- tests/unit/source_code/test_queries.py | 6 ++-- 31 files changed, 172 insertions(+), 134 deletions(-) create mode 100644 tests/unit/source_code/samples/functional/widgets.py diff --git a/src/databricks/labs/ucx/contexts/workspace_cli.py b/src/databricks/labs/ucx/contexts/workspace_cli.py index fc24eb2f9d..3c63223d8c 100644 --- a/src/databricks/labs/ucx/contexts/workspace_cli.py +++ b/src/databricks/labs/ucx/contexts/workspace_cli.py @@ -194,7 +194,6 @@ def local_code_linter(self): self.file_loader, self.folder_loader, self.path_lookup, - self.session_state, self.dependency_resolver, self.linter_context_factory, ) diff --git a/src/databricks/labs/ucx/source_code/base.py b/src/databricks/labs/ucx/source_code/base.py index 77c1133e53..ca2cd015ea 100644 --- a/src/databricks/labs/ucx/source_code/base.py +++ b/src/databricks/labs/ucx/source_code/base.py @@ -127,7 +127,7 @@ class Convention(Advice): class Linter: @abstractmethod - def lint(self, code: str) -> Iterable[Advice]: ... + def lint(self, code: str, session_state: CurrentSessionState) -> Iterable[Advice]: ... 
class Fixer: @@ -167,6 +167,6 @@ class SequentialLinter(Linter): def __init__(self, linters: list[Linter]): self._linters = linters - def lint(self, code: str) -> Iterable[Advice]: + def lint(self, code: str, session_state: CurrentSessionState) -> Iterable[Advice]: for linter in self._linters: - yield from linter.lint(code) + yield from linter.lint(code, session_state) diff --git a/src/databricks/labs/ucx/source_code/graph.py b/src/databricks/labs/ucx/source_code/graph.py index 75401b3cfb..821a24ec3d 100644 --- a/src/databricks/labs/ucx/source_code/graph.py +++ b/src/databricks/labs/ucx/source_code/graph.py @@ -217,7 +217,7 @@ def _register_import(self, base_node: ImportSource): yield from self.register_import(prefix + name) def _register_notebook(self, base_node: NotebookRunCall): - has_unresolved, paths = base_node.get_notebook_paths() + has_unresolved, paths = base_node.get_notebook_paths(self._session_state) if has_unresolved: yield DependencyProblem( 'dependency-cannot-compute', diff --git a/src/databricks/labs/ucx/source_code/jobs.py b/src/databricks/labs/ucx/source_code/jobs.py index 8f88642e82..cea6eb6707 100644 --- a/src/databricks/labs/ucx/source_code/jobs.py +++ b/src/databricks/labs/ucx/source_code/jobs.py @@ -365,16 +365,16 @@ def _lint_task(self, task: jobs.Task, job: jobs.Job): if not container: continue if isinstance(container, Notebook): - yield from self._lint_notebook(container, ctx) + yield from self._lint_notebook(container, ctx, session_state) if isinstance(container, LocalFile): - yield from self._lint_file(container, ctx) + yield from self._lint_file(container, ctx, session_state) - def _lint_file(self, file: LocalFile, ctx: LinterContext): + def _lint_file(self, file: LocalFile, ctx: LinterContext, session_state: CurrentSessionState): linter = FileLinter(ctx, file.path) - for advice in linter.lint(): + for advice in linter.lint(session_state): yield file.path, advice - def _lint_notebook(self, notebook: Notebook, ctx: LinterContext): + def 
_lint_notebook(self, notebook: Notebook, ctx: LinterContext, session_state: CurrentSessionState): linter = NotebookLinter(ctx, notebook) - for advice in linter.lint(): + for advice in linter.lint(session_state): yield notebook.path, advice diff --git a/src/databricks/labs/ucx/source_code/known.py b/src/databricks/labs/ucx/source_code/known.py index c5cb10e8c8..ec4e464213 100644 --- a/src/databricks/labs/ucx/source_code/known.py +++ b/src/databricks/labs/ucx/source_code/known.py @@ -14,6 +14,7 @@ from databricks.labs.blueprint.entrypoint import get_logger from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex +from databricks.labs.ucx.source_code.base import CurrentSessionState from databricks.labs.ucx.source_code.graph import DependencyProblem from databricks.labs.ucx.source_code.linters.context import LinterContext from databricks.labs.ucx.source_code.notebooks.sources import FileLinter @@ -149,7 +150,7 @@ def _analyze_file(cls, known_distributions, library_root, dist_info, module_path ctx = LinterContext(empty_index) linter = FileLinter(ctx, module_path) known_problems = set() - for problem in linter.lint(): + for problem in linter.lint(CurrentSessionState()): known_problems.add(KnownProblem(problem.code, problem.message)) problems = [_.as_dict() for _ in sorted(known_problems)] known_distributions[dist_info.name][module_ref] = problems diff --git a/src/databricks/labs/ucx/source_code/linters/context.py b/src/databricks/labs/ucx/source_code/linters/context.py index 5660b877cf..85019af5da 100644 --- a/src/databricks/labs/ucx/source_code/linters/context.py +++ b/src/databricks/labs/ucx/source_code/linters/context.py @@ -21,7 +21,7 @@ def __init__(self, index: MigrationIndex, session_state: CurrentSessionState | N Language.PYTHON: SequentialLinter( [ SparkSql(from_table, index), - DBFSUsageLinter(session_state), + DBFSUsageLinter(), DBRv8d0Linter(dbr_version=None), SparkConnectLinter(is_serverless=False), DbutilsLinter(), @@ -50,9 +50,9 @@ 
def fixer(self, language: Language, diagnostic_code: str) -> Fixer | None: return fixer return None - def apply_fixes(self, language: Language, code: str) -> str: + def apply_fixes(self, language: Language, code: str, session_state: CurrentSessionState) -> str: linter = self.linter(language) - for advice in linter.lint(code): + for advice in linter.lint(code, session_state): fixer = self.fixer(language, advice.code) if fixer: code = fixer.apply(code) diff --git a/src/databricks/labs/ucx/source_code/linters/dbfs.py b/src/databricks/labs/ucx/source_code/linters/dbfs.py index fcf8aa5000..3ee63761e5 100644 --- a/src/databricks/labs/ucx/source_code/linters/dbfs.py +++ b/src/databricks/labs/ucx/source_code/linters/dbfs.py @@ -54,7 +54,7 @@ def _check_str_constant(self, source_node, inferred: InferredValue): def _already_reported(self, source_node: NodeNG, inferred: InferredValue): all_nodes = [source_node] - all_nodes.extend(inferred.nodes()) + all_nodes.extend(inferred.nodes) reported = any((node.lineno, node.col_offset) in self._reported_locations for node in all_nodes) for node in all_nodes: self._reported_locations.add((node.lineno, node.col_offset)) @@ -65,9 +65,6 @@ def get_advices(self) -> Iterable[Advice]: class DBFSUsageLinter(Linter): - def __init__(self, session_state: CurrentSessionState): - self._session_state = session_state - @staticmethod def name() -> str: """ @@ -75,12 +72,12 @@ def name() -> str: """ return 'dbfs-usage' - def lint(self, code: str) -> Iterable[Advice]: + def lint(self, code: str, session_state: CurrentSessionState) -> Iterable[Advice]: """ Lints the code looking for file system paths that are deprecated """ tree = Tree.parse(code) - visitor = DetectDbfsVisitor(self._session_state) + visitor = DetectDbfsVisitor(session_state) visitor.visit(tree.node) yield from visitor.get_advices() @@ -93,7 +90,7 @@ def __init__(self): def name() -> str: return 'dbfs-query' - def lint(self, code: str) -> Iterable[Advice]: + def lint(self, code: str, 
session_state: CurrentSessionState) -> Iterable[Advice]: for statement in sqlglot.parse(code, read='databricks'): if not statement: continue diff --git a/src/databricks/labs/ucx/source_code/linters/files.py b/src/databricks/labs/ucx/source_code/linters/files.py index cd8fc05411..5fca181a12 100644 --- a/src/databricks/labs/ucx/source_code/linters/files.py +++ b/src/databricks/labs/ucx/source_code/linters/files.py @@ -88,19 +88,23 @@ def __init__( file_loader: FileLoader, folder_loader: FolderLoader, path_lookup: PathLookup, - session_state: CurrentSessionState, dependency_resolver: DependencyResolver, languages_factory: Callable[[], LinterContext], ) -> None: self._file_loader = file_loader self._folder_loader = folder_loader self._path_lookup = path_lookup - self._session_state = session_state self._dependency_resolver = dependency_resolver self._extensions = {".py": Language.PYTHON, ".sql": Language.SQL} self._new_linter_context = languages_factory - def lint(self, prompts: Prompts, path: Path | None, stdout: TextIO = sys.stdout) -> list[LocatedAdvice]: + def lint( + self, + prompts: Prompts, + path: Path | None, + session_state: CurrentSessionState | None = None, + stdout: TextIO = sys.stdout, + ) -> list[LocatedAdvice]: """Lint local code files looking for problems in notebooks and python files.""" if path is None: response = prompts.question( @@ -109,16 +113,18 @@ def lint(self, prompts: Prompts, path: Path | None, stdout: TextIO = sys.stdout) validate=lambda p_: Path(p_).exists(), ) path = Path(response) - located_advices = list(self.lint_path(path)) + if session_state is None: + session_state = CurrentSessionState() + located_advices = list(self.lint_path(path, session_state)) for located in located_advices: message = located.message_relative_to(path) stdout.write(f"{message}\n") return located_advices - def lint_path(self, path: Path) -> Iterable[LocatedAdvice]: + def lint_path(self, path: Path, session_state: CurrentSessionState) -> Iterable[LocatedAdvice]: 
loader = self._folder_loader if path.is_dir() else self._file_loader dependency = Dependency(loader, path) - graph = DependencyGraph(dependency, None, self._dependency_resolver, self._path_lookup, self._session_state) + graph = DependencyGraph(dependency, None, self._dependency_resolver, self._path_lookup, session_state) container = dependency.load(self._path_lookup) assert container is not None # because we just created it problems = container.build_dependency_graph(graph) @@ -126,14 +132,14 @@ def lint_path(self, path: Path) -> Iterable[LocatedAdvice]: problem_path = Path('UNKNOWN') if problem.is_path_missing() else problem.source_path.absolute() yield problem.as_advisory().for_path(problem_path) for child_path in graph.all_paths: - yield from self._lint_one(child_path) + yield from self._lint_one(child_path, session_state) - def _lint_one(self, path: Path) -> Iterable[LocatedAdvice]: + def _lint_one(self, path: Path, session_state: CurrentSessionState) -> Iterable[LocatedAdvice]: if path.is_dir(): return [] ctx = self._new_linter_context() linter = FileLinter(ctx, path) - return [advice.for_path(path) for advice in linter.lint()] + return [advice.for_path(path) for advice in linter.lint(session_state)] class LocalFileMigrator: @@ -143,14 +149,16 @@ def __init__(self, languages_factory: Callable[[], LinterContext]): self._extensions = {".py": Language.PYTHON, ".sql": Language.SQL} self._languages_factory = languages_factory - def apply(self, path: Path) -> bool: + def apply(self, path: Path, session_state: CurrentSessionState | None = None) -> bool: + if session_state is None: + session_state = CurrentSessionState() if path.is_dir(): for child_path in path.iterdir(): - self.apply(child_path) + self.apply(child_path, session_state) return True - return self._apply_file_fix(path) + return self._apply_file_fix(path, session_state) - def _apply_file_fix(self, path): + def _apply_file_fix(self, path: Path, session_state: CurrentSessionState): """ The fix method reads 
a file, lints it, applies fixes, and writes the fixed code back to the file. """ @@ -175,7 +183,7 @@ def _apply_file_fix(self, path): return False applied = False # Lint the code and apply fixes - for advice in linter.lint(code): + for advice in linter.lint(code, session_state): logger.info(f"Found: {advice}") fixer = languages.fixer(language, advice.code) if not fixer: diff --git a/src/databricks/labs/ucx/source_code/linters/imports.py b/src/databricks/labs/ucx/source_code/linters/imports.py index 23658ed809..17b48ce39c 100644 --- a/src/databricks/labs/ucx/source_code/linters/imports.py +++ b/src/databricks/labs/ucx/source_code/linters/imports.py @@ -82,45 +82,47 @@ class NotebookRunCall(NodeBase): def __init__(self, node: Call): super().__init__(node) - def get_notebook_paths(self) -> tuple[bool, list[str]]: + def get_notebook_paths(self, session_state: CurrentSessionState) -> tuple[bool, list[str]]: """we return multiple paths because astroid can infer them in scenarios such as: paths = ["p1", "p2"] for path in paths: dbutils.notebook.run(path) """ - node = DbutilsLinter.get_dbutils_notebook_run_path_arg(self.node) + arg = DbutilsLinter.get_dbutils_notebook_run_path_arg(self.node) try: - return self._get_notebook_paths(node.infer()) + all_inferred = Tree(arg).infer_values(session_state) + return self._get_notebook_paths(all_inferred) except InferenceError: - logger.debug(f"Can't infer value(s) of {node.as_string()}") + logger.debug(f"Can't infer value(s) of {arg.as_string()}") return True, [] @classmethod - def _get_notebook_paths(cls, nodes: Iterable[NodeNG]) -> tuple[bool, list[str]]: + def _get_notebook_paths(cls, all_inferred: Iterable[InferredValue]) -> tuple[bool, list[str]]: has_unresolved = False paths: list[str] = [] - for node in nodes: - if isinstance(node, Const): - paths.append(node.as_string().strip("'").strip('"')) + for inferred in all_inferred: + if inferred.is_inferred(): + paths.append(inferred.as_string().strip("'").strip('"')) continue - 
logger.debug(f"Can't compute {type(node).__name__}") + typenames = [type(node).__name__ for node in inferred.nodes] + logger.debug(f"Can't compute nodes [{','.join(typenames)}]") has_unresolved = True return has_unresolved, paths class DbutilsLinter(Linter): - def lint(self, code: str) -> Iterable[Advice]: + def lint(self, code: str, session_state: CurrentSessionState) -> Iterable[Advice]: tree = Tree.parse(code) nodes = self.list_dbutils_notebook_run_calls(tree) for node in nodes: - yield from self._raise_advice_if_unresolved(node.node) + yield from self._raise_advice_if_unresolved(node.node, session_state) @classmethod - def _raise_advice_if_unresolved(cls, node: NodeNG) -> Iterable[Advice]: + def _raise_advice_if_unresolved(cls, node: NodeNG, session_state: CurrentSessionState) -> Iterable[Advice]: assert isinstance(node, Call) call = NotebookRunCall(cast(Call, node)) - has_unresolved, _ = call.get_notebook_paths() + has_unresolved, _ = call.get_notebook_paths(session_state) if has_unresolved: yield from [ Advisory.from_node( diff --git a/src/databricks/labs/ucx/source_code/linters/pyspark.py b/src/databricks/labs/ucx/source_code/linters/pyspark.py index b5441a41ab..f7ba908413 100644 --- a/src/databricks/labs/ucx/source_code/linters/pyspark.py +++ b/src/databricks/labs/ucx/source_code/linters/pyspark.py @@ -31,7 +31,9 @@ def matches(self, node: NodeNG): return isinstance(node, Call) and isinstance(node.func, Attribute) and self._get_table_arg(node) is not None @abstractmethod - def lint(self, from_table: FromTable, index: MigrationIndex, node: Call) -> Iterator[Advice]: + def lint( + self, from_table: FromTable, index: MigrationIndex, session_state: CurrentSessionState, node: Call + ) -> Iterator[Advice]: """raises Advices by linting the code""" @abstractmethod @@ -70,12 +72,14 @@ def _check_call_context(self, node: Call) -> bool: @dataclass class QueryMatcher(Matcher): - def lint(self, from_table: FromTable, index: MigrationIndex, node: Call) -> 
Iterator[Advice]: + def lint( + self, from_table: FromTable, index: MigrationIndex, session_state: CurrentSessionState, node: Call + ) -> Iterator[Advice]: table_arg = self._get_table_arg(node) if table_arg: try: for inferred in Tree(table_arg).infer_values(self.session_state): - yield from self._lint_table_arg(from_table, node, inferred) + yield from self._lint_table_arg(from_table, node, inferred, session_state) except InferenceError: yield Advisory.from_node( code='table-migrate', @@ -84,9 +88,11 @@ def lint(self, from_table: FromTable, index: MigrationIndex, node: Call) -> Iter ) @classmethod - def _lint_table_arg(cls, from_table: FromTable, call_node: NodeNG, inferred: InferredValue): + def _lint_table_arg( + cls, from_table: FromTable, call_node: NodeNG, inferred: InferredValue, session_state: CurrentSessionState + ): if inferred.is_inferred(): - for advice in from_table.lint(inferred.as_string()): + for advice in from_table.lint(inferred.as_string(), session_state): yield advice.replace_from_node(call_node) else: yield Advisory.from_node( @@ -105,7 +111,9 @@ def apply(self, from_table: FromTable, index: MigrationIndex, node: Call) -> Non @dataclass class TableNameMatcher(Matcher): - def lint(self, from_table: FromTable, index: MigrationIndex, node: Call) -> Iterator[Advice]: + def lint( + self, from_table: FromTable, index: MigrationIndex, session_state: CurrentSessionState, node: Call + ) -> Iterator[Advice]: table_arg = self._get_table_arg(node) if not isinstance(table_arg, Const): @@ -150,7 +158,9 @@ class ReturnValueMatcher(Matcher): def matches(self, node: NodeNG): return isinstance(node, Call) and isinstance(node.func, Attribute) - def lint(self, from_table: FromTable, index: MigrationIndex, node: Call) -> Iterator[Advice]: + def lint( + self, from_table: FromTable, index: MigrationIndex, session_state: CurrentSessionState, node: Call + ) -> Iterator[Advice]: assert isinstance(node.func, Attribute) # always true, avoids a pylint warning yield 
Advisory.from_node( code='table-migrate', @@ -181,7 +191,9 @@ class DirectFilesystemAccessMatcher(Matcher): def matches(self, node: NodeNG): return isinstance(node, Call) and isinstance(node.func, Attribute) and self._get_table_arg(node) is not None - def lint(self, from_table: FromTable, index: MigrationIndex, node: NodeNG) -> Iterator[Advice]: + def lint( + self, from_table: FromTable, index: MigrationIndex, session_state: CurrentSessionState, node: NodeNG + ) -> Iterator[Advice]: table_arg = self._get_table_arg(node) if not isinstance(table_arg, Const): return @@ -319,7 +331,7 @@ def name(self) -> str: # this is the same fixer, just in a different language context return self._from_table.name() - def lint(self, code: str) -> Iterable[Advice]: + def lint(self, code: str, session_state: CurrentSessionState) -> Iterable[Advice]: try: tree = Tree.parse(code) except AstroidSyntaxError as e: @@ -330,7 +342,7 @@ def lint(self, code: str) -> Iterable[Advice]: if matcher is None: continue assert isinstance(node, Call) - yield from matcher.lint(self._from_table, self._index, node) + yield from matcher.lint(self._from_table, self._index, session_state, node) def apply(self, code: str) -> str: tree = Tree.parse(code) diff --git a/src/databricks/labs/ucx/source_code/linters/python_ast.py b/src/databricks/labs/ucx/source_code/linters/python_ast.py index 35b8c99991..efb74aa8cc 100644 --- a/src/databricks/labs/ucx/source_code/linters/python_ast.py +++ b/src/databricks/labs/ucx/source_code/linters/python_ast.py @@ -148,7 +148,7 @@ def infer_values(self, state: CurrentSessionState | None = None) -> Iterable[Inf yield InferredValue(inferred_atoms) def contextualize(self, state: CurrentSessionState): - calls = self.locate(Call, [("get", Attribute), ("widgets", Attribute), ("dbutils", Name)]) + calls = Tree(self.root).locate(Call, [("get", Attribute), ("widgets", Attribute), ("dbutils", Name)]) for call in calls: call.func = _ContextualCall(state, call) @@ -241,6 +241,7 @@ class 
InferredValue: def __init__(self, atoms: Iterable[NodeNG]): self._atoms = list(atoms) + @property def nodes(self): return self._atoms diff --git a/src/databricks/labs/ucx/source_code/linters/spark_connect.py b/src/databricks/labs/ucx/source_code/linters/spark_connect.py index 35f5230a7a..d344e76fce 100644 --- a/src/databricks/labs/ucx/source_code/linters/spark_connect.py +++ b/src/databricks/labs/ucx/source_code/linters/spark_connect.py @@ -7,6 +7,7 @@ Advice, Failure, Linter, + CurrentSessionState, ) from databricks.labs.ucx.source_code.linters.python_ast import Tree @@ -181,7 +182,7 @@ def __init__(self, is_serverless: bool = False): LoggingMatcher(is_serverless=is_serverless), ] - def lint(self, code: str) -> Iterator[Advice]: + def lint(self, code: str, session_state: CurrentSessionState) -> Iterator[Advice]: tree = Tree.parse(code) for matcher in self._matchers: yield from matcher.lint_tree(tree.node) diff --git a/src/databricks/labs/ucx/source_code/linters/table_creation.py b/src/databricks/labs/ucx/source_code/linters/table_creation.py index 95f00903f8..7e104195a1 100644 --- a/src/databricks/labs/ucx/source_code/linters/table_creation.py +++ b/src/databricks/labs/ucx/source_code/linters/table_creation.py @@ -8,6 +8,7 @@ from databricks.labs.ucx.source_code.base import ( Advice, Linter, + CurrentSessionState, ) from databricks.labs.ucx.source_code.linters.python_ast import Tree @@ -111,7 +112,7 @@ def __init__(self, dbr_version: tuple[int, int] | None): ] ) - def lint(self, code: str) -> Iterable[Advice]: + def lint(self, code: str, session_state: CurrentSessionState) -> Iterable[Advice]: if self._skip_dbr: return diff --git a/src/databricks/labs/ucx/source_code/lsp.py b/src/databricks/labs/ucx/source_code/lsp.py index 05bd3a7559..df19dece39 100644 --- a/src/databricks/labs/ucx/source_code/lsp.py +++ b/src/databricks/labs/ucx/source_code/lsp.py @@ -22,6 +22,7 @@ Convention, Deprecation, Failure, + CurrentSessionState, ) from 
databricks.labs.ucx.source_code.linters.context import LinterContext @@ -243,7 +244,7 @@ def _read(self, file_uri: str): def lint(self, file_uri: str): code, language = self._read(file_uri) analyser = self._languages.linter(language) - diagnostics = [Diagnostic.from_advice(_) for _ in analyser.lint(code)] + diagnostics = [Diagnostic.from_advice(_) for _ in analyser.lint(code, CurrentSessionState())] return AnalyseResponse(diagnostics) def quickfix(self, file_uri: str, code_range: Range, diagnostic_code: str): diff --git a/src/databricks/labs/ucx/source_code/notebooks/migrator.py b/src/databricks/labs/ucx/source_code/notebooks/migrator.py index 49e614a75c..a56e8acc4a 100644 --- a/src/databricks/labs/ucx/source_code/notebooks/migrator.py +++ b/src/databricks/labs/ucx/source_code/notebooks/migrator.py @@ -2,7 +2,7 @@ from pathlib import Path - +from databricks.labs.ucx.source_code.base import CurrentSessionState from databricks.labs.ucx.source_code.graph import Dependency from databricks.labs.ucx.source_code.linters.context import LinterContext from databricks.labs.ucx.source_code.notebooks.cells import RunCell @@ -30,9 +30,9 @@ def apply(self, path: Path) -> bool: lookup = PathLookup.from_sys_path(Path.cwd()) container = dependency.load(lookup) assert isinstance(container, Notebook) - return self._apply(container) + return self._apply(container, CurrentSessionState()) - def _apply(self, notebook: Notebook) -> bool: + def _apply(self, notebook: Notebook, session_state) -> bool: changed = False for cell in notebook.cells: # %run is not a supported language, so this needs to come first @@ -43,7 +43,7 @@ def _apply(self, notebook: Notebook) -> bool: continue if not self._languages.is_supported(cell.language.language): continue - migrated_code = self._languages.apply_fixes(cell.language.language, cell.original_code) + migrated_code = self._languages.apply_fixes(cell.language.language, cell.original_code, session_state) if migrated_code != cell.original_code: 
cell.migrated_code = migrated_code changed = True diff --git a/src/databricks/labs/ucx/source_code/notebooks/sources.py b/src/databricks/labs/ucx/source_code/notebooks/sources.py index 1e8fcf4f76..9f6e050f4f 100644 --- a/src/databricks/labs/ucx/source_code/notebooks/sources.py +++ b/src/databricks/labs/ucx/source_code/notebooks/sources.py @@ -8,7 +8,7 @@ from databricks.sdk.service.workspace import Language from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex -from databricks.labs.ucx.source_code.base import Advice, Failure +from databricks.labs.ucx.source_code.base import Advice, Failure, CurrentSessionState from databricks.labs.ucx.source_code.graph import SourceContainer, DependencyGraph, DependencyProblem from databricks.labs.ucx.source_code.linters.context import LinterContext @@ -94,12 +94,12 @@ def from_source(cls, index: MigrationIndex, source: str, default_language: Langu assert notebook is not None return cls(ctx, notebook) - def lint(self) -> Iterable[Advice]: + def lint(self, session_state: CurrentSessionState) -> Iterable[Advice]: for cell in self._notebook.cells: if not self._languages.is_supported(cell.language.language): continue linter = self._languages.linter(cell.language.language) - for advice in linter.lint(cell.original_code): + for advice in linter.lint(cell.original_code, session_state): yield advice.replace( start_line=advice.start_line + cell.original_offset, end_line=advice.end_line + cell.original_offset, @@ -179,7 +179,7 @@ def _is_notebook(self): return False return self._source_code.startswith(CellLanguage.of_language(language).file_magic_header) - def lint(self) -> Iterable[Advice]: + def lint(self, session_state: CurrentSessionState) -> Iterable[Advice]: encoding = locale.getpreferredencoding(False) try: is_notebook = self._is_notebook() @@ -189,11 +189,11 @@ def lint(self) -> Iterable[Advice]: return if is_notebook: - yield from self._lint_notebook() + yield from self._lint_notebook(session_state) else: - 
yield from self._lint_file() + yield from self._lint_file(session_state) - def _lint_file(self): + def _lint_file(self, session_state: CurrentSessionState): language = self._file_language() if not language: suffix = self._path.suffix.lower() @@ -206,13 +206,13 @@ def _lint_file(self): else: try: linter = self._ctx.linter(language) - yield from linter.lint(self._source_code) + yield from linter.lint(self._source_code, session_state) except ValueError as err: yield Failure( "unsupported-content", f"Error while parsing content of {self._path.as_posix()}: {err}", 0, 0, 1, 1 ) - def _lint_notebook(self): + def _lint_notebook(self, session_state: CurrentSessionState): notebook = Notebook.parse(self._path, self._source_code, self._file_language()) notebook_linter = NotebookLinter(self._ctx, notebook) - yield from notebook_linter.lint() + yield from notebook_linter.lint(session_state) diff --git a/src/databricks/labs/ucx/source_code/queries.py b/src/databricks/labs/ucx/source_code/queries.py index b2b0b9cdd1..cd223bd43f 100644 --- a/src/databricks/labs/ucx/source_code/queries.py +++ b/src/databricks/labs/ucx/source_code/queries.py @@ -42,7 +42,7 @@ def name(self) -> str: def schema(self): return self._session_state.schema - def lint(self, code: str) -> Iterable[Advice]: + def lint(self, code: str, session_state: CurrentSessionState) -> Iterable[Advice]: for statement in sqlglot.parse(code, read='databricks'): if not statement: continue diff --git a/tests/integration/source_code/solacc.py b/tests/integration/source_code/solacc.py index 20ead2a203..30637e0538 100644 --- a/tests/integration/source_code/solacc.py +++ b/tests/integration/source_code/solacc.py @@ -9,6 +9,7 @@ from databricks.labs.ucx.contexts.workspace_cli import LocalCheckoutContext from databricks.labs.ucx.framework.utils import run_command from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex +from databricks.labs.ucx.source_code.base import CurrentSessionState from 
databricks.labs.ucx.source_code.linters.context import LinterContext logger = logging.getLogger("verify-accelerators") @@ -50,10 +51,11 @@ def lint_all(): ctx = LocalCheckoutContext(ws).replace(linter_context_factory=lambda: LinterContext(MigrationIndex([]))) parseable = 0 missing_imports = 0 + session_state = CurrentSessionState() all_files = list(dist.glob('**/*.py')) for file in all_files: try: - for located_advice in ctx.local_code_linter.lint_path(file): + for located_advice in ctx.local_code_linter.lint_path(file, session_state): if located_advice.advice.code == 'import-not-found': missing_imports += 1 message = located_advice.message_relative_to(dist.parent, default=file) diff --git a/tests/integration/source_code/test_jobs.py b/tests/integration/source_code/test_jobs.py index 839e3d0152..fc994ce916 100644 --- a/tests/integration/source_code/test_jobs.py +++ b/tests/integration/source_code/test_jobs.py @@ -187,11 +187,10 @@ def test_lint_local_code(simple_ctx): light_ctx.file_loader, light_ctx.folder_loader, light_ctx.path_lookup, - light_ctx.session_state, light_ctx.dependency_resolver, lambda: linter_context, ) - problems = linter.lint(Prompts(), path_to_scan, StringIO()) + problems = linter.lint(Prompts(), path_to_scan, light_ctx.session_state, StringIO()) assert len(problems) > 0 diff --git a/tests/unit/source_code/linters/test_dbfs.py b/tests/unit/source_code/linters/test_dbfs.py index 9ce795eee7..651ca48652 100644 --- a/tests/unit/source_code/linters/test_dbfs.py +++ b/tests/unit/source_code/linters/test_dbfs.py @@ -17,8 +17,8 @@ class TestDetectDBFS: ], ) def test_detects_dbfs_paths(self, code, expected): - linter = DBFSUsageLinter(CurrentSessionState()) - advices = list(linter.lint(code)) + linter = DBFSUsageLinter() + advices = list(linter.lint(code, CurrentSessionState())) for advice in advices: assert isinstance(advice, Advisory) assert len(advices) == expected @@ -46,8 +46,8 @@ def test_detects_dbfs_paths(self, code, expected): ], ) def 
test_dbfs_usage_linter(self, code, expected): - linter = DBFSUsageLinter(CurrentSessionState()) - advices = linter.lint(code) + linter = DBFSUsageLinter() + advices = linter.lint(code, CurrentSessionState()) count = 0 for advice in advices: if isinstance(advice, Deprecation): @@ -55,7 +55,7 @@ def test_dbfs_usage_linter(self, code, expected): assert count == expected def test_dbfs_name(self): - linter = DBFSUsageLinter(CurrentSessionState()) + linter = DBFSUsageLinter() assert linter.name() == "dbfs-usage" @@ -73,7 +73,7 @@ def test_dbfs_name(self): ) def test_non_dbfs_trigger_nothing(query): ftf = FromDbfsFolder() - assert not list(ftf.lint(query)) + assert not list(ftf.lint(query, CurrentSessionState())) @pytest.mark.parametrize( @@ -100,7 +100,7 @@ def test_dbfs_tables_trigger_messages_param(query: str, table: str): end_line=0, end_col=1024, ), - ] == list(ftf.lint(query)) + ] == list(ftf.lint(query, CurrentSessionState())) def test_dbfs_queries_name(): diff --git a/tests/unit/source_code/linters/test_files.py b/tests/unit/source_code/linters/test_files.py index 1b77587a15..3d2c4e096e 100644 --- a/tests/unit/source_code/linters/test_files.py +++ b/tests/unit/source_code/linters/test_files.py @@ -108,9 +108,9 @@ def test_linter_walks_directory(mock_path_lookup, migration_index): path = Path(Path(__file__).parent, "../samples", "simulate-sys-path") prompts = MockPrompts({"Which file or directory do you want to lint ?": path.as_posix()}) linter = LocalCodeLinter( - file_loader, folder_loader, mock_path_lookup, session_state, resolver, lambda: LinterContext(migration_index) + file_loader, folder_loader, mock_path_lookup, resolver, lambda: LinterContext(migration_index) ) - advices = linter.lint(prompts, None) + advices = linter.lint(prompts, None, session_state) assert not advices @@ -163,8 +163,7 @@ def test_known_issues(path: Path, migration_index): file_loader, folder_loader, path_lookup, - session_state, resolver, lambda: LinterContext(migration_index, 
session_state), ) - linter.lint(MockPrompts({}), path) + linter.lint(MockPrompts({}), path, session_state) diff --git a/tests/unit/source_code/linters/test_pyspark.py b/tests/unit/source_code/linters/test_pyspark.py index e5555ab059..7d045150b8 100644 --- a/tests/unit/source_code/linters/test_pyspark.py +++ b/tests/unit/source_code/linters/test_pyspark.py @@ -9,14 +9,16 @@ def test_spark_no_sql(empty_index): - ftf = FromTable(empty_index, CurrentSessionState()) + session_state = CurrentSessionState() + ftf = FromTable(empty_index, session_state) sqf = SparkSql(ftf, empty_index) - assert not list(sqf.lint("print(1)")) + assert not list(sqf.lint("print(1)", session_state)) def test_spark_sql_no_match(empty_index): - ftf = FromTable(empty_index, CurrentSessionState()) + session_state = CurrentSessionState() + ftf = FromTable(empty_index, session_state) sqf = SparkSql(ftf, empty_index) old_code = """ @@ -25,11 +27,12 @@ def test_spark_sql_no_match(empty_index): print(len(result)) """ - assert not list(sqf.lint(old_code)) + assert not list(sqf.lint(old_code, session_state)) def test_spark_sql_match(migration_index): - ftf = FromTable(migration_index, CurrentSessionState()) + session_state = CurrentSessionState() + ftf = FromTable(migration_index, session_state) sqf = SparkSql(ftf, migration_index) old_code = """ @@ -38,7 +41,7 @@ def test_spark_sql_match(migration_index): result = spark.sql("SELECT * FROM old.things").collect() print(len(result)) """ - assert list(sqf.lint(old_code)) == [ + assert list(sqf.lint(old_code, session_state)) == [ Deprecation( code='direct-filesystem-access', message='The use of direct filesystem references is deprecated: s3://bucket/path', @@ -59,7 +62,8 @@ def test_spark_sql_match(migration_index): def test_spark_sql_match_named(migration_index): - ftf = FromTable(migration_index, CurrentSessionState()) + session_state = CurrentSessionState() + ftf = FromTable(migration_index, session_state) sqf = SparkSql(ftf, migration_index) old_code = 
""" @@ -68,7 +72,7 @@ def test_spark_sql_match_named(migration_index): result = spark.sql(args=[1], sqlQuery = "SELECT * FROM old.things").collect() print(len(result)) """ - assert list(sqf.lint(old_code)) == [ + assert list(sqf.lint(old_code, session_state)) == [ Deprecation( code='direct-filesystem-access', message='The use of direct filesystem references is deprecated: ' 's3://bucket/path', @@ -522,9 +526,10 @@ def test_spark_sql_fix(migration_index): ], ) def test_spark_cloud_direct_access(empty_index, code, expected): - ftf = FromTable(empty_index, CurrentSessionState()) + session_state = CurrentSessionState() + ftf = FromTable(empty_index, session_state) sqf = SparkSql(ftf, empty_index) - advisories = list(sqf.lint(code)) + advisories = list(sqf.lint(code, session_state)) assert advisories == expected @@ -541,11 +546,12 @@ def test_spark_cloud_direct_access(empty_index, code, expected): @pytest.mark.parametrize("fs_function", FS_FUNCTIONS) def test_direct_cloud_access_reports_nothing(empty_index, fs_function): - ftf = FromTable(empty_index, CurrentSessionState()) + session_state = CurrentSessionState() + ftf = FromTable(empty_index, session_state) sqf = SparkSql(ftf, empty_index) # ls function calls have to be from dbutils.fs, or we ignore them code = f"""spark.{fs_function}("/bucket/path")""" - advisories = list(sqf.lint(code)) + advisories = list(sqf.lint(code, session_state)) assert not advisories diff --git a/tests/unit/source_code/linters/test_python_imports.py b/tests/unit/source_code/linters/test_python_imports.py index 8bd054423a..2ef072994d 100644 --- a/tests/unit/source_code/linters/test_python_imports.py +++ b/tests/unit/source_code/linters/test_python_imports.py @@ -188,6 +188,6 @@ def test_infers_dbutils_notebook_run_dynamic_value(code, expected): calls = DbutilsLinter.list_dbutils_notebook_run_calls(tree) all_paths: list[str] = [] for call in calls: - _, paths = call.get_notebook_paths() + _, paths = 
call.get_notebook_paths(CurrentSessionState()) all_paths.extend(paths) assert all_paths == expected diff --git a/tests/unit/source_code/linters/test_spark_connect.py b/tests/unit/source_code/linters/test_spark_connect.py index d33a64117f..3401f4bbc2 100644 --- a/tests/unit/source_code/linters/test_spark_connect.py +++ b/tests/unit/source_code/linters/test_spark_connect.py @@ -1,6 +1,6 @@ from itertools import chain -from databricks.labs.ucx.source_code.base import Failure +from databricks.labs.ucx.source_code.base import Failure, CurrentSessionState from databricks.labs.ucx.source_code.linters.python_ast import Tree from databricks.labs.ucx.source_code.linters.spark_connect import LoggingMatcher, SparkConnectLinter @@ -21,7 +21,7 @@ def test_jvm_access_match_shared(): end_col=18, ), ] - actual = list(linter.lint(code)) + actual = list(linter.lint(code, CurrentSessionState())) assert actual == expected @@ -42,7 +42,7 @@ def test_jvm_access_match_serverless(): end_col=18, ), ] - actual = list(linter.lint(code)) + actual = list(linter.lint(code, CurrentSessionState())) assert actual == expected @@ -86,7 +86,7 @@ def test_rdd_context_match_shared(): end_col=40, ), ] - actual = list(linter.lint(code)) + actual = list(linter.lint(code, CurrentSessionState())) assert actual == expected @@ -129,7 +129,7 @@ def test_rdd_context_match_serverless(): end_line=2, end_col=40, ), - ] == list(linter.lint(code)) + ] == list(linter.lint(code, CurrentSessionState())) def test_rdd_map_partitions(): @@ -148,7 +148,7 @@ def test_rdd_map_partitions(): end_col=27, ), ] - actual = list(linter.lint(code)) + actual = list(linter.lint(code, CurrentSessionState())) assert actual == expected @@ -164,7 +164,7 @@ def test_conf_shared(): end_line=0, end_col=23, ), - ] == list(linter.lint(code)) + ] == list(linter.lint(code, CurrentSessionState())) def test_conf_serverless(): @@ -180,7 +180,7 @@ def test_conf_serverless(): end_col=8, ), ] - actual = list(linter.lint(code)) + actual = 
list(linter.lint(code, CurrentSessionState())) assert actual == expected @@ -260,4 +260,4 @@ def test_valid_code(): df = spark.range(10) df.collect() """ - assert not list(linter.lint(code)) + assert not list(linter.lint(code, CurrentSessionState())) diff --git a/tests/unit/source_code/linters/test_table_creation.py b/tests/unit/source_code/linters/test_table_creation.py index e1b1c3756d..0d014f3c48 100644 --- a/tests/unit/source_code/linters/test_table_creation.py +++ b/tests/unit/source_code/linters/test_table_creation.py @@ -2,7 +2,7 @@ import pytest -from databricks.labs.ucx.source_code.base import Advice +from databricks.labs.ucx.source_code.base import Advice, CurrentSessionState from databricks.labs.ucx.source_code.linters.table_creation import DBRv8d0Linter @@ -60,7 +60,7 @@ def lint( dbr_version: tuple[int, int] | None = (7, 9), ) -> list[Advice]: """Invoke linting for the given dbr version""" - return list(DBRv8d0Linter(dbr_version).lint(code)) + return list(DBRv8d0Linter(dbr_version).lint(code, CurrentSessionState())) @pytest.mark.parametrize("method_name", METHOD_NAMES) diff --git a/tests/unit/source_code/notebooks/test_sources.py b/tests/unit/source_code/notebooks/test_sources.py index 9cd77b0c06..65afe8fb32 100644 --- a/tests/unit/source_code/notebooks/test_sources.py +++ b/tests/unit/source_code/notebooks/test_sources.py @@ -3,6 +3,7 @@ import pytest +from databricks.labs.ucx.source_code.base import CurrentSessionState from databricks.labs.ucx.source_code.linters.context import LinterContext from databricks.labs.ucx.source_code.notebooks.sources import FileLinter @@ -10,14 +11,14 @@ @pytest.mark.parametrize("path, content", [("xyz.py", "a = 3"), ("xyz.sql", "select * from dual")]) def test_file_linter_lints_supported_language(path, content, migration_index): linter = FileLinter(LinterContext(migration_index), Path(path), content) - advices = list(linter.lint()) + advices = list(linter.lint(CurrentSessionState())) assert not advices 
@pytest.mark.parametrize("path", ["xyz.scala", "xyz.r", "xyz.sh"]) def test_file_linter_lints_not_yet_supported_language(path, migration_index): linter = FileLinter(LinterContext(migration_index), Path(path), "") - advices = list(linter.lint()) + advices = list(linter.lint(CurrentSessionState())) assert [advice.code for advice in advices] == ["unsupported-language"] @@ -42,7 +43,7 @@ def test_file_linter_lints_not_yet_supported_language(path, migration_index): ) def test_file_linter_lints_ignorable_language(path, migration_index): linter = FileLinter(LinterContext(migration_index), Path(path), "") - advices = list(linter.lint()) + advices = list(linter.lint(CurrentSessionState())) assert not advices @@ -51,7 +52,7 @@ def test_file_linter_lints_non_ascii_encoded_file(migration_index): non_ascii_encoded_file = Path(__file__).parent.parent / "samples" / "nonascii.py" linter = FileLinter(LinterContext(migration_index), non_ascii_encoded_file) - advices = list(linter.lint()) + advices = list(linter.lint(CurrentSessionState())) assert len(advices) == 1 assert advices[0].code == "unsupported-file-encoding" diff --git a/tests/unit/source_code/samples/functional/widgets.py b/tests/unit/source_code/samples/functional/widgets.py new file mode 100644 index 0000000000..6897cd6398 --- /dev/null +++ b/tests/unit/source_code/samples/functional/widgets.py @@ -0,0 +1,5 @@ +path = dbutils.widgets.get("my-widget") +dbutils.notebook.run(path) +path = dbutils.widgets.get("no-widget") +# ucx[dbutils-notebook-run-dynamic:+1:0:+1:26] Path for 'dbutils.notebook.run' cannot be computed and requires adjusting the notebook path(s) +dbutils.notebook.run(path) diff --git a/tests/unit/source_code/test_functional.py b/tests/unit/source_code/test_functional.py index 30b7c4be70..707694e7dd 100644 --- a/tests/unit/source_code/test_functional.py +++ b/tests/unit/source_code/test_functional.py @@ -10,7 +10,7 @@ import pytest from databricks.labs.ucx.hive_metastore.migration_status import 
MigrationIndex, MigrationStatus -from databricks.labs.ucx.source_code.base import Advice +from databricks.labs.ucx.source_code.base import Advice, CurrentSessionState from databricks.labs.ucx.source_code.linters.context import LinterContext from databricks.labs.ucx.source_code.notebooks.sources import FileLinter @@ -75,9 +75,9 @@ def __init__(self, path: Path) -> None: def verify(self) -> None: expected_problems = list(self._expected_problems()) - actual_advice = list(self._lint()) + actual_advices = list(self._lint()) # Convert the actual problems to the same type as our expected problems for easier comparison. - actual_problems = [Expectation.from_advice(advice) for advice in actual_advice] + actual_problems = [Expectation.from_advice(advice) for advice in actual_advices] # Fail the test if the comments don't match reality. expected_but_missing = sorted(set(expected_problems).difference(actual_problems)) @@ -102,9 +102,11 @@ def _lint(self) -> Iterable[Advice]: MigrationStatus('other', 'matters', dst_catalog='some', dst_schema='certain', dst_table='issues'), ] ) - ctx = LinterContext(migration_index) + session_state = CurrentSessionState() + session_state.named_parameters = {"my-widget": "my-path.py"} + ctx = LinterContext(migration_index, session_state) linter = FileLinter(ctx, self.path) - return linter.lint() + return linter.lint(session_state) def _expected_problems(self) -> Generator[Expectation, None, None]: with self.path.open('rb') as f: diff --git a/tests/unit/source_code/test_notebook.py b/tests/unit/source_code/test_notebook.py index 8c1272f3b4..852082c97e 100644 --- a/tests/unit/source_code/test_notebook.py +++ b/tests/unit/source_code/test_notebook.py @@ -275,10 +275,11 @@ def test_does_not_detect_partial_call_to_dbutils_notebook_run_in_python_code_() def test_raises_advice_when_dbutils_notebook_run_is_too_complex() -> None: source = """ -name = "xyz" -dbutils.notebook.run(f"Hey {name}") +name1 = "John" +name2 = f"{name1}" +dbutils.notebook.run(f"Hey 
{name2}") """ linter = DbutilsLinter() - advices = list(linter.lint(source)) + advices = list(linter.lint(source, CurrentSessionState())) assert len(advices) == 1 assert advices[0].code == "dbutils-notebook-run-dynamic" diff --git a/tests/unit/source_code/test_notebook_linter.py b/tests/unit/source_code/test_notebook_linter.py index d507f03b38..eceecb1567 100644 --- a/tests/unit/source_code/test_notebook_linter.py +++ b/tests/unit/source_code/test_notebook_linter.py @@ -2,7 +2,7 @@ from databricks.sdk.service.workspace import Language from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex -from databricks.labs.ucx.source_code.base import Deprecation, Advice +from databricks.labs.ucx.source_code.base import Deprecation, Advice, CurrentSessionState from databricks.labs.ucx.source_code.notebooks.sources import NotebookLinter index = MigrationIndex([]) @@ -317,7 +317,7 @@ def test_notebook_linter(lang, source, expected): # over multiple lines. linter = NotebookLinter.from_source(index, source, lang) assert linter is not None - gathered = list(linter.lint()) + gathered = list(linter.lint(CurrentSessionState())) assert gathered == expected @@ -545,5 +545,5 @@ def test_notebook_linter_name(): def test_notebook_linter_tracks_use(extended_test_index, lang, source, expected): linter = NotebookLinter.from_source(extended_test_index, source, lang) assert linter is not None - advices = list(linter.lint()) + advices = list(linter.lint(CurrentSessionState())) assert advices == expected diff --git a/tests/unit/source_code/test_queries.py b/tests/unit/source_code/test_queries.py index 0316170516..d327eb53e8 100644 --- a/tests/unit/source_code/test_queries.py +++ b/tests/unit/source_code/test_queries.py @@ -7,7 +7,7 @@ def test_not_migrated_tables_trigger_nothing(empty_index): old_query = "SELECT * FROM old.things LEFT JOIN hive_metastore.other.matters USING (x) WHERE state > 1 LIMIT 10" - assert not list(ftf.lint(old_query)) + assert not 
list(ftf.lint(old_query, CurrentSessionState())) def test_migrated_tables_trigger_messages(migration_index): @@ -32,7 +32,7 @@ def test_migrated_tables_trigger_messages(migration_index): end_line=0, end_col=1024, ), - ] == list(ftf.lint(old_query)) + ] == list(ftf.lint(old_query, CurrentSessionState())) def test_fully_migrated_queries_match(migration_index): @@ -61,7 +61,7 @@ def test_use_database_change(migration_index): USE newcatalog; SELECT * FROM things LEFT JOIN hive_metastore.other.matters USING (x) WHERE state > 1 LIMIT 10""" - _ = list(ftf.lint(query)) + _ = list(ftf.lint(query, CurrentSessionState())) assert ftf.schema == "newcatalog" From cd993e71b2e092bab1e562cd4bdf407bf1746223 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Wed, 12 Jun 2024 19:13:45 +0200 Subject: [PATCH 08/23] improve test coverage --- src/databricks/labs/ucx/source_code/linters/python_ast.py | 1 - tests/unit/source_code/linters/test_python_ast.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/labs/ucx/source_code/linters/python_ast.py b/src/databricks/labs/ucx/source_code/linters/python_ast.py index efb74aa8cc..eb0d08c1ea 100644 --- a/src/databricks/labs/ucx/source_code/linters/python_ast.py +++ b/src/databricks/labs/ucx/source_code/linters/python_ast.py @@ -57,7 +57,6 @@ def locate(self, node_type: type[T], match_nodes: list[tuple[str, type]]) -> lis def first_statement(self): if isinstance(self._node, Module): return self._node.body[0] - return None @classmethod def extract_call_by_name(cls, call: Call, name: str) -> Call | None: diff --git a/tests/unit/source_code/linters/test_python_ast.py b/tests/unit/source_code/linters/test_python_ast.py index ab19c1a784..1536c2b7bb 100644 --- a/tests/unit/source_code/linters/test_python_ast.py +++ b/tests/unit/source_code/linters/test_python_ast.py @@ -10,6 +10,7 @@ def test_extracts_root(): stmt = tree.first_statement() root = Tree(stmt).root assert root == tree.node + assert repr(tree) # for test 
coverqge def test_extract_call_by_name(): From efb7188250d5c6692961b85cbe4694039fda226d Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Wed, 12 Jun 2024 19:21:10 +0200 Subject: [PATCH 09/23] improve test coverage nd formatting --- src/databricks/labs/ucx/source_code/linters/python_ast.py | 4 +++- tests/unit/source_code/linters/test_python_ast.py | 7 ++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/databricks/labs/ucx/source_code/linters/python_ast.py b/src/databricks/labs/ucx/source_code/linters/python_ast.py index eb0d08c1ea..857d18753d 100644 --- a/src/databricks/labs/ucx/source_code/linters/python_ast.py +++ b/src/databricks/labs/ucx/source_code/linters/python_ast.py @@ -56,7 +56,9 @@ def locate(self, node_type: type[T], match_nodes: list[tuple[str, type]]) -> lis def first_statement(self): if isinstance(self._node, Module): - return self._node.body[0] + if len(self._node.body) > 0: + return self._node.body[0] + return None @classmethod def extract_call_by_name(cls, call: Call, name: str) -> Call | None: diff --git a/tests/unit/source_code/linters/test_python_ast.py b/tests/unit/source_code/linters/test_python_ast.py index 1536c2b7bb..4dd4809c93 100644 --- a/tests/unit/source_code/linters/test_python_ast.py +++ b/tests/unit/source_code/linters/test_python_ast.py @@ -10,7 +10,12 @@ def test_extracts_root(): stmt = tree.first_statement() root = Tree(stmt).root assert root == tree.node - assert repr(tree) # for test coverqge + assert repr(tree) # for test coverage + + +def test_no_first_statement(): + tree = Tree.parse("") + assert not tree.first_statement() def test_extract_call_by_name(): From 1c7aa67a5bbb59549b4cded9c23672408f13d5d4 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Wed, 12 Jun 2024 19:41:06 +0200 Subject: [PATCH 10/23] improve test coverage --- tests/unit/source_code/linters/test_files.py | 25 ++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git 
a/tests/unit/source_code/linters/test_files.py b/tests/unit/source_code/linters/test_files.py index 3d2c4e096e..5b08c3406c 100644 --- a/tests/unit/source_code/linters/test_files.py +++ b/tests/unit/source_code/linters/test_files.py @@ -8,6 +8,7 @@ from databricks.labs.ucx.source_code.base import CurrentSessionState from databricks.labs.ucx.source_code.graph import DependencyResolver, SourceContainer from databricks.labs.ucx.source_code.notebooks.loaders import NotebookResolver, NotebookLoader +from databricks.labs.ucx.source_code.notebooks.migrator import NotebookMigrator from databricks.labs.ucx.source_code.python_libraries import PythonLibraryResolver from databricks.labs.ucx.source_code.known import Whitelist @@ -27,14 +28,21 @@ from tests.unit import locate_site_packages, _samples_path -def test_migrator_fix_ignores_unsupported_extensions(): +def test_notebook_migrator_ignores_unsupported_extensions(): + languages = LinterContext(MigrationIndex([])) + migrator = NotebookMigrator(languages) + path = Path('unsupported.ext') + assert not migrator.apply(path) + + +def test_file_migrator_fix_ignores_unsupported_extensions(): languages = LinterContext(MigrationIndex([])) migrator = LocalFileMigrator(lambda: languages) path = Path('unsupported.ext') assert not migrator.apply(path) -def test_migrator_fix_ignores_unsupported_language(): +def test_file_migrator_fix_ignores_unsupported_language(): languages = LinterContext(MigrationIndex([])) migrator = LocalFileMigrator(lambda: languages) migrator._extensions[".py"] = None # pylint: disable=protected-access @@ -42,14 +50,14 @@ def test_migrator_fix_ignores_unsupported_language(): assert not migrator.apply(path) -def test_migrator_fix_reads_supported_extensions(migration_index): +def test_file_migrator_fix_reads_supported_extensions(migration_index): languages = LinterContext(migration_index) migrator = LocalFileMigrator(lambda: languages) path = Path(__file__) assert not migrator.apply(path) -def 
test_migrator_supported_language_no_diagnostics(): +def test_file_migrator_supported_language_no_diagnostics(): languages = create_autospec(LinterContext) languages.linter(Language.PYTHON).lint.return_value = [] migrator = LocalFileMigrator(lambda: languages) @@ -58,6 +66,15 @@ def test_migrator_supported_language_no_diagnostics(): languages.fixer.assert_not_called() +def test_notebook_migrator_supported_language_no_diagnostics(simple_dependency_resolver, mock_path_lookup): + paths = ["root1.run.py"] + resolver = simple_dependency_resolver + maybe = resolver.resolve_notebook(mock_path_lookup, Path(paths[0])) + languages = LinterContext(MigrationIndex([])) + migrator = NotebookMigrator(languages) + assert not migrator.apply(maybe.dependency.path) + + def test_migrator_supported_language_no_fixer(): languages = create_autospec(LinterContext) languages.linter(Language.PYTHON).lint.return_value = [Mock(code='some-code')] From 521c4160e0ce9a7a8d67b4a4cbcc8296b9df9d01 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Wed, 12 Jun 2024 19:45:28 +0200 Subject: [PATCH 11/23] simplify --- tests/unit/source_code/linters/test_files.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/unit/source_code/linters/test_files.py b/tests/unit/source_code/linters/test_files.py index 5b08c3406c..f5aa3dc589 100644 --- a/tests/unit/source_code/linters/test_files.py +++ b/tests/unit/source_code/linters/test_files.py @@ -66,13 +66,11 @@ def test_file_migrator_supported_language_no_diagnostics(): languages.fixer.assert_not_called() -def test_notebook_migrator_supported_language_no_diagnostics(simple_dependency_resolver, mock_path_lookup): - paths = ["root1.run.py"] - resolver = simple_dependency_resolver - maybe = resolver.resolve_notebook(mock_path_lookup, Path(paths[0])) +def test_notebook_migrator_supported_language_no_diagnostics(mock_path_lookup): languages = LinterContext(MigrationIndex([])) migrator = NotebookMigrator(languages) - assert not 
migrator.apply(maybe.dependency.path) + path = mock_path_lookup.resolve(Path("root1.run.py")) + assert not migrator.apply(path) def test_migrator_supported_language_no_fixer(): From fdbdbedc7be0442400e06efa68dc1613bf1ef5e8 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Thu, 13 Jun 2024 08:19:18 +0200 Subject: [PATCH 12/23] refactor API --- .../labs/ucx/contexts/application.py | 9 +--- .../labs/ucx/contexts/workspace_cli.py | 9 ++-- src/databricks/labs/ucx/source_code/base.py | 6 +-- src/databricks/labs/ucx/source_code/graph.py | 10 ++-- src/databricks/labs/ucx/source_code/jobs.py | 12 ++--- src/databricks/labs/ucx/source_code/known.py | 3 +- .../labs/ucx/source_code/linters/context.py | 12 ++--- .../labs/ucx/source_code/linters/dbfs.py | 10 ++-- .../labs/ucx/source_code/linters/files.py | 29 +++++------ .../labs/ucx/source_code/linters/imports.py | 7 ++- .../labs/ucx/source_code/linters/pyspark.py | 15 +++--- .../ucx/source_code/linters/python_ast.py | 11 ++-- .../ucx/source_code/linters/spark_connect.py | 3 +- .../ucx/source_code/linters/table_creation.py | 3 +- src/databricks/labs/ucx/source_code/lsp.py | 3 +- .../ucx/source_code/notebooks/migrator.py | 7 ++- .../labs/ucx/source_code/notebooks/sources.py | 20 ++++---- .../labs/ucx/source_code/queries.py | 2 +- tests/integration/source_code/solacc.py | 4 +- tests/integration/source_code/test_jobs.py | 3 +- tests/unit/source_code/conftest.py | 5 +- tests/unit/source_code/linters/test_dbfs.py | 14 +++--- tests/unit/source_code/linters/test_files.py | 10 ++-- .../unit/source_code/linters/test_pyspark.py | 34 +++++++------ .../source_code/linters/test_spark_connect.py | 18 +++---- .../linters/test_table_creation.py | 4 +- .../unit/source_code/notebooks/test_cells.py | 8 +-- .../source_code/notebooks/test_sources.py | 9 ++-- tests/unit/source_code/test_dependencies.py | 50 ++++++++----------- tests/unit/source_code/test_functional.py | 2 +- tests/unit/source_code/test_graph.py | 2 - 
tests/unit/source_code/test_jobs.py | 1 - tests/unit/source_code/test_notebook.py | 12 ++--- .../unit/source_code/test_notebook_linter.py | 6 +-- .../test_path_lookup_simulation.py | 18 +++---- tests/unit/source_code/test_queries.py | 6 +-- tests/unit/source_code/test_s3fs.py | 12 ++--- 37 files changed, 179 insertions(+), 210 deletions(-) diff --git a/src/databricks/labs/ucx/contexts/application.py b/src/databricks/labs/ucx/contexts/application.py index 1e66ce4f8b..d21c096a72 100644 --- a/src/databricks/labs/ucx/contexts/application.py +++ b/src/databricks/labs/ucx/contexts/application.py @@ -15,7 +15,6 @@ from databricks.labs.ucx.recon.metadata_retriever import DatabricksTableMetadataRetriever from databricks.labs.ucx.recon.migration_recon import MigrationRecon from databricks.labs.ucx.recon.schema_comparator import StandardSchemaComparator -from databricks.labs.ucx.source_code.base import CurrentSessionState from databricks.labs.ucx.source_code.python_libraries import PythonLibraryResolver from databricks.sdk import AccountClient, WorkspaceClient, core from databricks.sdk.errors import ResourceDoesNotExist @@ -376,10 +375,6 @@ def path_lookup(self): # TODO find a solution to enable a different cwd per job/task (maybe it's not necessary or possible?) 
return PathLookup.from_sys_path(Path.cwd()) - @cached_property - def session_state(self): - return CurrentSessionState() - @cached_property def file_loader(self): return FileLoader() @@ -398,9 +393,7 @@ def file_resolver(self): @cached_property def dependency_resolver(self): - return DependencyResolver( - self.pip_resolver, self.notebook_resolver, self.file_resolver, self.path_lookup, self.session_state - ) + return DependencyResolver(self.pip_resolver, self.notebook_resolver, self.file_resolver, self.path_lookup) @cached_property def workflow_linter(self): diff --git a/src/databricks/labs/ucx/contexts/workspace_cli.py b/src/databricks/labs/ucx/contexts/workspace_cli.py index 3c63223d8c..937aacbbf6 100644 --- a/src/databricks/labs/ucx/contexts/workspace_cli.py +++ b/src/databricks/labs/ucx/contexts/workspace_cli.py @@ -179,9 +179,10 @@ class LocalCheckoutContext(WorkspaceContext): """Local context extends Workspace context to provide extra properties for running local operations.""" - def linter_context_factory(self): + def linter_context_factory(self, session_state: CurrentSessionState | None = None): index = self.tables_migrator.index() - session_state = CurrentSessionState() + if session_state is None: + session_state = CurrentSessionState() return LinterContext(index, session_state) @cached_property @@ -190,10 +191,12 @@ def local_file_migrator(self): @cached_property def local_code_linter(self): + session_state = CurrentSessionState() return LocalCodeLinter( self.file_loader, self.folder_loader, self.path_lookup, + session_state, self.dependency_resolver, - self.linter_context_factory, + lambda: self.linter_context_factory(session_state), ) diff --git a/src/databricks/labs/ucx/source_code/base.py b/src/databricks/labs/ucx/source_code/base.py index ca2cd015ea..77c1133e53 100644 --- a/src/databricks/labs/ucx/source_code/base.py +++ b/src/databricks/labs/ucx/source_code/base.py @@ -127,7 +127,7 @@ class Convention(Advice): class Linter: @abstractmethod - def 
lint(self, code: str, session_state: CurrentSessionState) -> Iterable[Advice]: ... + def lint(self, code: str) -> Iterable[Advice]: ... class Fixer: @@ -167,6 +167,6 @@ class SequentialLinter(Linter): def __init__(self, linters: list[Linter]): self._linters = linters - def lint(self, code: str, session_state: CurrentSessionState) -> Iterable[Advice]: + def lint(self, code: str) -> Iterable[Advice]: for linter in self._linters: - yield from linter.lint(code, session_state) + yield from linter.lint(code) diff --git a/src/databricks/labs/ucx/source_code/graph.py b/src/databricks/labs/ucx/source_code/graph.py index 821a24ec3d..9fadc42c32 100644 --- a/src/databricks/labs/ucx/source_code/graph.py +++ b/src/databricks/labs/ucx/source_code/graph.py @@ -344,13 +344,11 @@ def __init__( notebook_resolver: BaseNotebookResolver, import_resolver: BaseImportResolver, path_lookup: PathLookup, - session_state: CurrentSessionState, ): self._library_resolver = library_resolver self._notebook_resolver = notebook_resolver self._import_resolver = import_resolver self._path_lookup = path_lookup - self._session_state = session_state def resolve_notebook(self, path_lookup: PathLookup, path: Path) -> MaybeDependency: return self._notebook_resolver.resolve_notebook(path_lookup, path) @@ -361,7 +359,7 @@ def resolve_import(self, path_lookup: PathLookup, name: str) -> MaybeDependency: def register_library(self, path_lookup: PathLookup, *libraries: str) -> list[DependencyProblem]: return self._library_resolver.register_library(path_lookup, *libraries) - def build_local_file_dependency_graph(self, path: Path) -> MaybeGraph: + def build_local_file_dependency_graph(self, path: Path, session_state: CurrentSessionState) -> MaybeGraph: """Builds a dependency graph starting from a file. This method is mainly intended for testing purposes. 
In case of problems, the paths in the problems will be relative to the starting path lookup.""" resolver = self._local_file_resolver @@ -371,7 +369,7 @@ def build_local_file_dependency_graph(self, path: Path) -> MaybeGraph: maybe = resolver.resolve_local_file(self._path_lookup, path) if not maybe.dependency: return MaybeGraph(None, self._make_relative_paths(maybe.problems, path)) - graph = DependencyGraph(maybe.dependency, None, self, self._path_lookup, self._session_state) + graph = DependencyGraph(maybe.dependency, None, self, self._path_lookup, session_state) container = maybe.dependency.load(graph.path_lookup) if container is None: problem = DependencyProblem('cannot-load-file', f"Could not load file {path}") @@ -387,13 +385,13 @@ def _local_file_resolver(self) -> BaseFileResolver | None: return self._import_resolver return None - def build_notebook_dependency_graph(self, path: Path) -> MaybeGraph: + def build_notebook_dependency_graph(self, path: Path, session_state: CurrentSessionState) -> MaybeGraph: """Builds a dependency graph starting from a notebook. This method is mainly intended for testing purposes. 
In case of problems, the paths in the problems will be relative to the starting path lookup.""" maybe = self._notebook_resolver.resolve_notebook(self._path_lookup, path) if not maybe.dependency: return MaybeGraph(None, self._make_relative_paths(maybe.problems, path)) - graph = DependencyGraph(maybe.dependency, None, self, self._path_lookup, self._session_state) + graph = DependencyGraph(maybe.dependency, None, self, self._path_lookup, session_state) container = maybe.dependency.load(graph.path_lookup) if container is None: problem = DependencyProblem('cannot-load-notebook', f"Could not load notebook {path}") diff --git a/src/databricks/labs/ucx/source_code/jobs.py b/src/databricks/labs/ucx/source_code/jobs.py index cea6eb6707..8f88642e82 100644 --- a/src/databricks/labs/ucx/source_code/jobs.py +++ b/src/databricks/labs/ucx/source_code/jobs.py @@ -365,16 +365,16 @@ def _lint_task(self, task: jobs.Task, job: jobs.Job): if not container: continue if isinstance(container, Notebook): - yield from self._lint_notebook(container, ctx, session_state) + yield from self._lint_notebook(container, ctx) if isinstance(container, LocalFile): - yield from self._lint_file(container, ctx, session_state) + yield from self._lint_file(container, ctx) - def _lint_file(self, file: LocalFile, ctx: LinterContext, session_state: CurrentSessionState): + def _lint_file(self, file: LocalFile, ctx: LinterContext): linter = FileLinter(ctx, file.path) - for advice in linter.lint(session_state): + for advice in linter.lint(): yield file.path, advice - def _lint_notebook(self, notebook: Notebook, ctx: LinterContext, session_state: CurrentSessionState): + def _lint_notebook(self, notebook: Notebook, ctx: LinterContext): linter = NotebookLinter(ctx, notebook) - for advice in linter.lint(session_state): + for advice in linter.lint(): yield notebook.path, advice diff --git a/src/databricks/labs/ucx/source_code/known.py b/src/databricks/labs/ucx/source_code/known.py index ec4e464213..c5cb10e8c8 100644 
--- a/src/databricks/labs/ucx/source_code/known.py +++ b/src/databricks/labs/ucx/source_code/known.py @@ -14,7 +14,6 @@ from databricks.labs.blueprint.entrypoint import get_logger from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex -from databricks.labs.ucx.source_code.base import CurrentSessionState from databricks.labs.ucx.source_code.graph import DependencyProblem from databricks.labs.ucx.source_code.linters.context import LinterContext from databricks.labs.ucx.source_code.notebooks.sources import FileLinter @@ -150,7 +149,7 @@ def _analyze_file(cls, known_distributions, library_root, dist_info, module_path ctx = LinterContext(empty_index) linter = FileLinter(ctx, module_path) known_problems = set() - for problem in linter.lint(CurrentSessionState()): + for problem in linter.lint(): known_problems.add(KnownProblem(problem.code, problem.message)) problems = [_.as_dict() for _ in sorted(known_problems)] known_distributions[dist_info.name][module_ref] = problems diff --git a/src/databricks/labs/ucx/source_code/linters/context.py b/src/databricks/labs/ucx/source_code/linters/context.py index 85019af5da..2ce502e62a 100644 --- a/src/databricks/labs/ucx/source_code/linters/context.py +++ b/src/databricks/labs/ucx/source_code/linters/context.py @@ -20,17 +20,17 @@ def __init__(self, index: MigrationIndex, session_state: CurrentSessionState | N self._linters = { Language.PYTHON: SequentialLinter( [ - SparkSql(from_table, index), - DBFSUsageLinter(), + SparkSql(from_table, index, session_state), + DBFSUsageLinter(session_state), DBRv8d0Linter(dbr_version=None), SparkConnectLinter(is_serverless=False), - DbutilsLinter(), + DbutilsLinter(session_state), ] ), Language.SQL: SequentialLinter([from_table, dbfs_from_folder]), } self._fixers: dict[Language, list[Fixer]] = { - Language.PYTHON: [SparkSql(from_table, index)], + Language.PYTHON: [SparkSql(from_table, index, session_state)], Language.SQL: [from_table], } @@ -50,9 +50,9 @@ def fixer(self, 
language: Language, diagnostic_code: str) -> Fixer | None: return fixer return None - def apply_fixes(self, language: Language, code: str, session_state: CurrentSessionState) -> str: + def apply_fixes(self, language: Language, code: str) -> str: linter = self.linter(language) - for advice in linter.lint(code, session_state): + for advice in linter.lint(code): fixer = self.fixer(language, advice.code) if fixer: code = fixer.apply(code) diff --git a/src/databricks/labs/ucx/source_code/linters/dbfs.py b/src/databricks/labs/ucx/source_code/linters/dbfs.py index 3ee63761e5..21c7f012c0 100644 --- a/src/databricks/labs/ucx/source_code/linters/dbfs.py +++ b/src/databricks/labs/ucx/source_code/linters/dbfs.py @@ -65,6 +65,10 @@ def get_advices(self) -> Iterable[Advice]: class DBFSUsageLinter(Linter): + + def __init__(self, session_state: CurrentSessionState): + self._session_state = session_state + @staticmethod def name() -> str: """ @@ -72,12 +76,12 @@ def name() -> str: """ return 'dbfs-usage' - def lint(self, code: str, session_state: CurrentSessionState) -> Iterable[Advice]: + def lint(self, code: str) -> Iterable[Advice]: """ Lints the code looking for file system paths that are deprecated """ tree = Tree.parse(code) - visitor = DetectDbfsVisitor(session_state) + visitor = DetectDbfsVisitor(self._session_state) visitor.visit(tree.node) yield from visitor.get_advices() @@ -90,7 +94,7 @@ def __init__(self): def name() -> str: return 'dbfs-query' - def lint(self, code: str, session_state: CurrentSessionState) -> Iterable[Advice]: + def lint(self, code: str) -> Iterable[Advice]: for statement in sqlglot.parse(code, read='databricks'): if not statement: continue diff --git a/src/databricks/labs/ucx/source_code/linters/files.py b/src/databricks/labs/ucx/source_code/linters/files.py index 5fca181a12..da05636dbe 100644 --- a/src/databricks/labs/ucx/source_code/linters/files.py +++ b/src/databricks/labs/ucx/source_code/linters/files.py @@ -88,12 +88,14 @@ def __init__( 
file_loader: FileLoader, folder_loader: FolderLoader, path_lookup: PathLookup, + session_state: CurrentSessionState, dependency_resolver: DependencyResolver, languages_factory: Callable[[], LinterContext], ) -> None: self._file_loader = file_loader self._folder_loader = folder_loader self._path_lookup = path_lookup + self._session_state = session_state self._dependency_resolver = dependency_resolver self._extensions = {".py": Language.PYTHON, ".sql": Language.SQL} self._new_linter_context = languages_factory @@ -102,7 +104,6 @@ def lint( self, prompts: Prompts, path: Path | None, - session_state: CurrentSessionState | None = None, stdout: TextIO = sys.stdout, ) -> list[LocatedAdvice]: """Lint local code files looking for problems in notebooks and python files.""" @@ -113,18 +114,16 @@ def lint( validate=lambda p_: Path(p_).exists(), ) path = Path(response) - if session_state is None: - session_state = CurrentSessionState() - located_advices = list(self.lint_path(path, session_state)) + located_advices = list(self.lint_path(path)) for located in located_advices: message = located.message_relative_to(path) stdout.write(f"{message}\n") return located_advices - def lint_path(self, path: Path, session_state: CurrentSessionState) -> Iterable[LocatedAdvice]: + def lint_path(self, path: Path) -> Iterable[LocatedAdvice]: loader = self._folder_loader if path.is_dir() else self._file_loader dependency = Dependency(loader, path) - graph = DependencyGraph(dependency, None, self._dependency_resolver, self._path_lookup, session_state) + graph = DependencyGraph(dependency, None, self._dependency_resolver, self._path_lookup, self._session_state) container = dependency.load(self._path_lookup) assert container is not None # because we just created it problems = container.build_dependency_graph(graph) @@ -132,14 +131,14 @@ def lint_path(self, path: Path, session_state: CurrentSessionState) -> Iterable[ problem_path = Path('UNKNOWN') if problem.is_path_missing() else 
problem.source_path.absolute() yield problem.as_advisory().for_path(problem_path) for child_path in graph.all_paths: - yield from self._lint_one(child_path, session_state) + yield from self._lint_one(child_path) - def _lint_one(self, path: Path, session_state: CurrentSessionState) -> Iterable[LocatedAdvice]: + def _lint_one(self, path: Path) -> Iterable[LocatedAdvice]: if path.is_dir(): return [] ctx = self._new_linter_context() linter = FileLinter(ctx, path) - return [advice.for_path(path) for advice in linter.lint(session_state)] + return [advice.for_path(path) for advice in linter.lint()] class LocalFileMigrator: @@ -149,16 +148,14 @@ def __init__(self, languages_factory: Callable[[], LinterContext]): self._extensions = {".py": Language.PYTHON, ".sql": Language.SQL} self._languages_factory = languages_factory - def apply(self, path: Path, session_state: CurrentSessionState | None = None) -> bool: - if session_state is None: - session_state = CurrentSessionState() + def apply(self, path: Path) -> bool: if path.is_dir(): for child_path in path.iterdir(): - self.apply(child_path, session_state) + self.apply(child_path) return True - return self._apply_file_fix(path, session_state) + return self._apply_file_fix(path) - def _apply_file_fix(self, path: Path, session_state: CurrentSessionState): + def _apply_file_fix(self, path: Path): """ The fix method reads a file, lints it, applies fixes, and writes the fixed code back to the file. 
""" @@ -183,7 +180,7 @@ def _apply_file_fix(self, path: Path, session_state: CurrentSessionState): return False applied = False # Lint the code and apply fixes - for advice in linter.lint(code, session_state): + for advice in linter.lint(code): logger.info(f"Found: {advice}") fixer = languages.fixer(language, advice.code) if not fixer: diff --git a/src/databricks/labs/ucx/source_code/linters/imports.py b/src/databricks/labs/ucx/source_code/linters/imports.py index 17b48ce39c..0dc2addfd0 100644 --- a/src/databricks/labs/ucx/source_code/linters/imports.py +++ b/src/databricks/labs/ucx/source_code/linters/imports.py @@ -112,11 +112,14 @@ def _get_notebook_paths(cls, all_inferred: Iterable[InferredValue]) -> tuple[boo class DbutilsLinter(Linter): - def lint(self, code: str, session_state: CurrentSessionState) -> Iterable[Advice]: + def __init__(self, session_state: CurrentSessionState): + self._session_state = session_state + + def lint(self, code: str) -> Iterable[Advice]: tree = Tree.parse(code) nodes = self.list_dbutils_notebook_run_calls(tree) for node in nodes: - yield from self._raise_advice_if_unresolved(node.node, session_state) + yield from self._raise_advice_if_unresolved(node.node, self._session_state) @classmethod def _raise_advice_if_unresolved(cls, node: NodeNG, session_state: CurrentSessionState) -> Iterable[Advice]: diff --git a/src/databricks/labs/ucx/source_code/linters/pyspark.py b/src/databricks/labs/ucx/source_code/linters/pyspark.py index f7ba908413..12201fd488 100644 --- a/src/databricks/labs/ucx/source_code/linters/pyspark.py +++ b/src/databricks/labs/ucx/source_code/linters/pyspark.py @@ -79,7 +79,7 @@ def lint( if table_arg: try: for inferred in Tree(table_arg).infer_values(self.session_state): - yield from self._lint_table_arg(from_table, node, inferred, session_state) + yield from self._lint_table_arg(from_table, node, inferred) except InferenceError: yield Advisory.from_node( code='table-migrate', @@ -88,11 +88,9 @@ def lint( ) @classmethod 
- def _lint_table_arg( - cls, from_table: FromTable, call_node: NodeNG, inferred: InferredValue, session_state: CurrentSessionState - ): + def _lint_table_arg(cls, from_table: FromTable, call_node: NodeNG, inferred: InferredValue): if inferred.is_inferred(): - for advice in from_table.lint(inferred.as_string(), session_state): + for advice in from_table.lint(inferred.as_string()): yield advice.replace_from_node(call_node) else: yield Advisory.from_node( @@ -323,15 +321,16 @@ class SparkSql(Linter, Fixer): _spark_matchers = SparkMatchers() - def __init__(self, from_table: FromTable, index: MigrationIndex): + def __init__(self, from_table: FromTable, index: MigrationIndex, session_state): self._from_table = from_table self._index = index + self._session_state = session_state def name(self) -> str: # this is the same fixer, just in a different language context return self._from_table.name() - def lint(self, code: str, session_state: CurrentSessionState) -> Iterable[Advice]: + def lint(self, code: str) -> Iterable[Advice]: try: tree = Tree.parse(code) except AstroidSyntaxError as e: @@ -342,7 +341,7 @@ def lint(self, code: str, session_state: CurrentSessionState) -> Iterable[Advice if matcher is None: continue assert isinstance(node, Call) - yield from matcher.lint(self._from_table, self._index, session_state, node) + yield from matcher.lint(self._from_table, self._index, self._session_state, node) def apply(self, code: str) -> str: tree = Tree.parse(code) diff --git a/src/databricks/labs/ucx/source_code/linters/python_ast.py b/src/databricks/labs/ucx/source_code/linters/python_ast.py index 857d18753d..6584de3d42 100644 --- a/src/databricks/labs/ucx/source_code/linters/python_ast.py +++ b/src/databricks/labs/ucx/source_code/linters/python_ast.py @@ -143,15 +143,16 @@ def _get_attribute_value(cls, node: Attribute): return None def infer_values(self, state: CurrentSessionState | None = None) -> Iterable[InferredValue]: - if state is not None: - self.contextualize(state) 
+ self._contextualize(state) for inferred_atoms in self._infer_values(): yield InferredValue(inferred_atoms) - def contextualize(self, state: CurrentSessionState): + def _contextualize(self, state: CurrentSessionState | None): + if state is None or state.named_parameters is None or len(state.named_parameters) == 0: + return calls = Tree(self.root).locate(Call, [("get", Attribute), ("widgets", Attribute), ("dbutils", Name)]) for call in calls: - call.func = _ContextualCall(state, call) + call.func = _GetWidgetValueCall(state, call) def _infer_values(self) -> Iterator[Iterable[NodeNG]]: # deal with node types that don't implement 'inferred()' @@ -189,7 +190,7 @@ def do_infer_values(self): return self._infer_values() -class _ContextualCall(NodeNG): +class _GetWidgetValueCall(NodeNG): def __init__(self, session_state: CurrentSessionState, node: NodeNG): super().__init__( diff --git a/src/databricks/labs/ucx/source_code/linters/spark_connect.py b/src/databricks/labs/ucx/source_code/linters/spark_connect.py index d344e76fce..35f5230a7a 100644 --- a/src/databricks/labs/ucx/source_code/linters/spark_connect.py +++ b/src/databricks/labs/ucx/source_code/linters/spark_connect.py @@ -7,7 +7,6 @@ Advice, Failure, Linter, - CurrentSessionState, ) from databricks.labs.ucx.source_code.linters.python_ast import Tree @@ -182,7 +181,7 @@ def __init__(self, is_serverless: bool = False): LoggingMatcher(is_serverless=is_serverless), ] - def lint(self, code: str, session_state: CurrentSessionState) -> Iterator[Advice]: + def lint(self, code: str) -> Iterator[Advice]: tree = Tree.parse(code) for matcher in self._matchers: yield from matcher.lint_tree(tree.node) diff --git a/src/databricks/labs/ucx/source_code/linters/table_creation.py b/src/databricks/labs/ucx/source_code/linters/table_creation.py index 7e104195a1..95f00903f8 100644 --- a/src/databricks/labs/ucx/source_code/linters/table_creation.py +++ b/src/databricks/labs/ucx/source_code/linters/table_creation.py @@ -8,7 +8,6 @@ from 
databricks.labs.ucx.source_code.base import ( Advice, Linter, - CurrentSessionState, ) from databricks.labs.ucx.source_code.linters.python_ast import Tree @@ -112,7 +111,7 @@ def __init__(self, dbr_version: tuple[int, int] | None): ] ) - def lint(self, code: str, session_state: CurrentSessionState) -> Iterable[Advice]: + def lint(self, code: str) -> Iterable[Advice]: if self._skip_dbr: return diff --git a/src/databricks/labs/ucx/source_code/lsp.py b/src/databricks/labs/ucx/source_code/lsp.py index df19dece39..05bd3a7559 100644 --- a/src/databricks/labs/ucx/source_code/lsp.py +++ b/src/databricks/labs/ucx/source_code/lsp.py @@ -22,7 +22,6 @@ Convention, Deprecation, Failure, - CurrentSessionState, ) from databricks.labs.ucx.source_code.linters.context import LinterContext @@ -244,7 +243,7 @@ def _read(self, file_uri: str): def lint(self, file_uri: str): code, language = self._read(file_uri) analyser = self._languages.linter(language) - diagnostics = [Diagnostic.from_advice(_) for _ in analyser.lint(code, CurrentSessionState())] + diagnostics = [Diagnostic.from_advice(_) for _ in analyser.lint(code)] return AnalyseResponse(diagnostics) def quickfix(self, file_uri: str, code_range: Range, diagnostic_code: str): diff --git a/src/databricks/labs/ucx/source_code/notebooks/migrator.py b/src/databricks/labs/ucx/source_code/notebooks/migrator.py index a56e8acc4a..bdffb86c22 100644 --- a/src/databricks/labs/ucx/source_code/notebooks/migrator.py +++ b/src/databricks/labs/ucx/source_code/notebooks/migrator.py @@ -2,7 +2,6 @@ from pathlib import Path -from databricks.labs.ucx.source_code.base import CurrentSessionState from databricks.labs.ucx.source_code.graph import Dependency from databricks.labs.ucx.source_code.linters.context import LinterContext from databricks.labs.ucx.source_code.notebooks.cells import RunCell @@ -30,9 +29,9 @@ def apply(self, path: Path) -> bool: lookup = PathLookup.from_sys_path(Path.cwd()) container = dependency.load(lookup) assert 
isinstance(container, Notebook) - return self._apply(container, CurrentSessionState()) + return self._apply(container) - def _apply(self, notebook: Notebook, session_state) -> bool: + def _apply(self, notebook: Notebook) -> bool: changed = False for cell in notebook.cells: # %run is not a supported language, so this needs to come first @@ -43,7 +42,7 @@ def _apply(self, notebook: Notebook, session_state) -> bool: continue if not self._languages.is_supported(cell.language.language): continue - migrated_code = self._languages.apply_fixes(cell.language.language, cell.original_code, session_state) + migrated_code = self._languages.apply_fixes(cell.language.language, cell.original_code) if migrated_code != cell.original_code: cell.migrated_code = migrated_code changed = True diff --git a/src/databricks/labs/ucx/source_code/notebooks/sources.py b/src/databricks/labs/ucx/source_code/notebooks/sources.py index 9f6e050f4f..1e8fcf4f76 100644 --- a/src/databricks/labs/ucx/source_code/notebooks/sources.py +++ b/src/databricks/labs/ucx/source_code/notebooks/sources.py @@ -8,7 +8,7 @@ from databricks.sdk.service.workspace import Language from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex -from databricks.labs.ucx.source_code.base import Advice, Failure, CurrentSessionState +from databricks.labs.ucx.source_code.base import Advice, Failure from databricks.labs.ucx.source_code.graph import SourceContainer, DependencyGraph, DependencyProblem from databricks.labs.ucx.source_code.linters.context import LinterContext @@ -94,12 +94,12 @@ def from_source(cls, index: MigrationIndex, source: str, default_language: Langu assert notebook is not None return cls(ctx, notebook) - def lint(self, session_state: CurrentSessionState) -> Iterable[Advice]: + def lint(self) -> Iterable[Advice]: for cell in self._notebook.cells: if not self._languages.is_supported(cell.language.language): continue linter = self._languages.linter(cell.language.language) - for advice in 
linter.lint(cell.original_code, session_state): + for advice in linter.lint(cell.original_code): yield advice.replace( start_line=advice.start_line + cell.original_offset, end_line=advice.end_line + cell.original_offset, @@ -179,7 +179,7 @@ def _is_notebook(self): return False return self._source_code.startswith(CellLanguage.of_language(language).file_magic_header) - def lint(self, session_state: CurrentSessionState) -> Iterable[Advice]: + def lint(self) -> Iterable[Advice]: encoding = locale.getpreferredencoding(False) try: is_notebook = self._is_notebook() @@ -189,11 +189,11 @@ def lint(self, session_state: CurrentSessionState) -> Iterable[Advice]: return if is_notebook: - yield from self._lint_notebook(session_state) + yield from self._lint_notebook() else: - yield from self._lint_file(session_state) + yield from self._lint_file() - def _lint_file(self, session_state: CurrentSessionState): + def _lint_file(self): language = self._file_language() if not language: suffix = self._path.suffix.lower() @@ -206,13 +206,13 @@ def _lint_file(self, session_state: CurrentSessionState): else: try: linter = self._ctx.linter(language) - yield from linter.lint(self._source_code, session_state) + yield from linter.lint(self._source_code) except ValueError as err: yield Failure( "unsupported-content", f"Error while parsing content of {self._path.as_posix()}: {err}", 0, 0, 1, 1 ) - def _lint_notebook(self, session_state: CurrentSessionState): + def _lint_notebook(self): notebook = Notebook.parse(self._path, self._source_code, self._file_language()) notebook_linter = NotebookLinter(self._ctx, notebook) - yield from notebook_linter.lint(session_state) + yield from notebook_linter.lint() diff --git a/src/databricks/labs/ucx/source_code/queries.py b/src/databricks/labs/ucx/source_code/queries.py index cd223bd43f..b2b0b9cdd1 100644 --- a/src/databricks/labs/ucx/source_code/queries.py +++ b/src/databricks/labs/ucx/source_code/queries.py @@ -42,7 +42,7 @@ def name(self) -> str: def 
schema(self): return self._session_state.schema - def lint(self, code: str, session_state: CurrentSessionState) -> Iterable[Advice]: + def lint(self, code: str) -> Iterable[Advice]: for statement in sqlglot.parse(code, read='databricks'): if not statement: continue diff --git a/tests/integration/source_code/solacc.py b/tests/integration/source_code/solacc.py index 30637e0538..20ead2a203 100644 --- a/tests/integration/source_code/solacc.py +++ b/tests/integration/source_code/solacc.py @@ -9,7 +9,6 @@ from databricks.labs.ucx.contexts.workspace_cli import LocalCheckoutContext from databricks.labs.ucx.framework.utils import run_command from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex -from databricks.labs.ucx.source_code.base import CurrentSessionState from databricks.labs.ucx.source_code.linters.context import LinterContext logger = logging.getLogger("verify-accelerators") @@ -51,11 +50,10 @@ def lint_all(): ctx = LocalCheckoutContext(ws).replace(linter_context_factory=lambda: LinterContext(MigrationIndex([]))) parseable = 0 missing_imports = 0 - session_state = CurrentSessionState() all_files = list(dist.glob('**/*.py')) for file in all_files: try: - for located_advice in ctx.local_code_linter.lint_path(file, session_state): + for located_advice in ctx.local_code_linter.lint_path(file): if located_advice.advice.code == 'import-not-found': missing_imports += 1 message = located_advice.message_relative_to(dist.parent, default=file) diff --git a/tests/integration/source_code/test_jobs.py b/tests/integration/source_code/test_jobs.py index fc994ce916..839e3d0152 100644 --- a/tests/integration/source_code/test_jobs.py +++ b/tests/integration/source_code/test_jobs.py @@ -187,10 +187,11 @@ def test_lint_local_code(simple_ctx): light_ctx.file_loader, light_ctx.folder_loader, light_ctx.path_lookup, + light_ctx.session_state, light_ctx.dependency_resolver, lambda: linter_context, ) - problems = linter.lint(Prompts(), path_to_scan, 
light_ctx.session_state, StringIO()) + problems = linter.lint(Prompts(), path_to_scan, StringIO()) assert len(problems) > 0 diff --git a/tests/unit/source_code/conftest.py b/tests/unit/source_code/conftest.py index 97ff5cb5ec..1ad3c4ac69 100644 --- a/tests/unit/source_code/conftest.py +++ b/tests/unit/source_code/conftest.py @@ -4,7 +4,6 @@ MigrationStatus, ) from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex -from databricks.labs.ucx.source_code.base import CurrentSessionState from databricks.labs.ucx.source_code.graph import DependencyResolver from databricks.labs.ucx.source_code.known import Whitelist from databricks.labs.ucx.source_code.linters.files import ImportFileResolver, FileLoader @@ -54,6 +53,4 @@ def simple_dependency_resolver(mock_path_lookup): library_resolver = PythonLibraryResolver(whitelist) notebook_resolver = NotebookResolver(NotebookLoader()) import_resolver = ImportFileResolver(FileLoader(), whitelist) - return DependencyResolver( - library_resolver, notebook_resolver, import_resolver, mock_path_lookup, CurrentSessionState() - ) + return DependencyResolver(library_resolver, notebook_resolver, import_resolver, mock_path_lookup) diff --git a/tests/unit/source_code/linters/test_dbfs.py b/tests/unit/source_code/linters/test_dbfs.py index 651ca48652..9ce795eee7 100644 --- a/tests/unit/source_code/linters/test_dbfs.py +++ b/tests/unit/source_code/linters/test_dbfs.py @@ -17,8 +17,8 @@ class TestDetectDBFS: ], ) def test_detects_dbfs_paths(self, code, expected): - linter = DBFSUsageLinter() - advices = list(linter.lint(code, CurrentSessionState())) + linter = DBFSUsageLinter(CurrentSessionState()) + advices = list(linter.lint(code)) for advice in advices: assert isinstance(advice, Advisory) assert len(advices) == expected @@ -46,8 +46,8 @@ def test_detects_dbfs_paths(self, code, expected): ], ) def test_dbfs_usage_linter(self, code, expected): - linter = DBFSUsageLinter() - advices = linter.lint(code, 
CurrentSessionState()) + linter = DBFSUsageLinter(CurrentSessionState()) + advices = linter.lint(code) count = 0 for advice in advices: if isinstance(advice, Deprecation): @@ -55,7 +55,7 @@ def test_dbfs_usage_linter(self, code, expected): assert count == expected def test_dbfs_name(self): - linter = DBFSUsageLinter() + linter = DBFSUsageLinter(CurrentSessionState()) assert linter.name() == "dbfs-usage" @@ -73,7 +73,7 @@ def test_dbfs_name(self): ) def test_non_dbfs_trigger_nothing(query): ftf = FromDbfsFolder() - assert not list(ftf.lint(query, CurrentSessionState())) + assert not list(ftf.lint(query)) @pytest.mark.parametrize( @@ -100,7 +100,7 @@ def test_dbfs_tables_trigger_messages_param(query: str, table: str): end_line=0, end_col=1024, ), - ] == list(ftf.lint(query, CurrentSessionState())) + ] == list(ftf.lint(query)) def test_dbfs_queries_name(): diff --git a/tests/unit/source_code/linters/test_files.py b/tests/unit/source_code/linters/test_files.py index f5aa3dc589..fe7b28c8c0 100644 --- a/tests/unit/source_code/linters/test_files.py +++ b/tests/unit/source_code/linters/test_files.py @@ -118,14 +118,13 @@ def test_linter_walks_directory(mock_path_lookup, migration_index): NotebookResolver(NotebookLoader()), ImportFileResolver(file_loader, whitelist), mock_path_lookup, - session_state, ) path = Path(Path(__file__).parent, "../samples", "simulate-sys-path") prompts = MockPrompts({"Which file or directory do you want to lint ?": path.as_posix()}) linter = LocalCodeLinter( - file_loader, folder_loader, mock_path_lookup, resolver, lambda: LinterContext(migration_index) + file_loader, folder_loader, mock_path_lookup, session_state, resolver, lambda: LinterContext(migration_index) ) - advices = linter.lint(prompts, None, session_state) + advices = linter.lint(prompts, None) assert not advices @@ -173,12 +172,13 @@ def test_known_issues(path: Path, migration_index): notebook_resolver = NotebookResolver(NotebookLoader()) import_resolver = 
ImportFileResolver(file_loader, whitelist) pip_resolver = PythonLibraryResolver(whitelist) - resolver = DependencyResolver(pip_resolver, notebook_resolver, import_resolver, path_lookup, session_state) + resolver = DependencyResolver(pip_resolver, notebook_resolver, import_resolver, path_lookup) linter = LocalCodeLinter( file_loader, folder_loader, path_lookup, + session_state, resolver, lambda: LinterContext(migration_index, session_state), ) - linter.lint(MockPrompts({}), path, session_state) + linter.lint(MockPrompts({}), path) diff --git a/tests/unit/source_code/linters/test_pyspark.py b/tests/unit/source_code/linters/test_pyspark.py index 7d045150b8..6f5bfdd6b3 100644 --- a/tests/unit/source_code/linters/test_pyspark.py +++ b/tests/unit/source_code/linters/test_pyspark.py @@ -11,15 +11,15 @@ def test_spark_no_sql(empty_index): session_state = CurrentSessionState() ftf = FromTable(empty_index, session_state) - sqf = SparkSql(ftf, empty_index) + sqf = SparkSql(ftf, empty_index, session_state) - assert not list(sqf.lint("print(1)", session_state)) + assert not list(sqf.lint("print(1)")) def test_spark_sql_no_match(empty_index): session_state = CurrentSessionState() ftf = FromTable(empty_index, session_state) - sqf = SparkSql(ftf, empty_index) + sqf = SparkSql(ftf, empty_index, session_state) old_code = """ for i in range(10): @@ -27,13 +27,13 @@ def test_spark_sql_no_match(empty_index): print(len(result)) """ - assert not list(sqf.lint(old_code, session_state)) + assert not list(sqf.lint(old_code)) def test_spark_sql_match(migration_index): session_state = CurrentSessionState() ftf = FromTable(migration_index, session_state) - sqf = SparkSql(ftf, migration_index) + sqf = SparkSql(ftf, migration_index, session_state) old_code = """ spark.read.csv("s3://bucket/path") @@ -41,7 +41,7 @@ def test_spark_sql_match(migration_index): result = spark.sql("SELECT * FROM old.things").collect() print(len(result)) """ - assert list(sqf.lint(old_code, session_state)) == [ + 
assert list(sqf.lint(old_code)) == [ Deprecation( code='direct-filesystem-access', message='The use of direct filesystem references is deprecated: s3://bucket/path', @@ -64,7 +64,7 @@ def test_spark_sql_match(migration_index): def test_spark_sql_match_named(migration_index): session_state = CurrentSessionState() ftf = FromTable(migration_index, session_state) - sqf = SparkSql(ftf, migration_index) + sqf = SparkSql(ftf, migration_index, session_state) old_code = """ spark.read.csv("s3://bucket/path") @@ -72,7 +72,7 @@ def test_spark_sql_match_named(migration_index): result = spark.sql(args=[1], sqlQuery = "SELECT * FROM old.things").collect() print(len(result)) """ - assert list(sqf.lint(old_code, session_state)) == [ + assert list(sqf.lint(old_code)) == [ Deprecation( code='direct-filesystem-access', message='The use of direct filesystem references is deprecated: ' 's3://bucket/path', @@ -93,8 +93,9 @@ def test_spark_sql_match_named(migration_index): def test_spark_table_return_value_apply(migration_index): - ftf = FromTable(migration_index, CurrentSessionState()) - sqf = SparkSql(ftf, migration_index) + session_state = CurrentSessionState() + ftf = FromTable(migration_index, session_state) + sqf = SparkSql(ftf, migration_index, session_state) old_code = """spark.read.csv('s3://bucket/path') for table in spark.catalog.listTables(): do_stuff_with_table(table)""" @@ -104,8 +105,9 @@ def test_spark_table_return_value_apply(migration_index): def test_spark_sql_fix(migration_index): - ftf = FromTable(migration_index, CurrentSessionState()) - sqf = SparkSql(ftf, migration_index) + session_state = CurrentSessionState() + ftf = FromTable(migration_index, session_state) + sqf = SparkSql(ftf, migration_index, session_state) old_code = """spark.read.csv("s3://bucket/path") for i in range(10): @@ -528,8 +530,8 @@ def test_spark_sql_fix(migration_index): def test_spark_cloud_direct_access(empty_index, code, expected): session_state = CurrentSessionState() ftf = 
FromTable(empty_index, session_state) - sqf = SparkSql(ftf, empty_index) - advisories = list(sqf.lint(code, session_state)) + sqf = SparkSql(ftf, empty_index, session_state) + advisories = list(sqf.lint(code)) assert advisories == expected @@ -548,10 +550,10 @@ def test_spark_cloud_direct_access(empty_index, code, expected): def test_direct_cloud_access_reports_nothing(empty_index, fs_function): session_state = CurrentSessionState() ftf = FromTable(empty_index, session_state) - sqf = SparkSql(ftf, empty_index) + sqf = SparkSql(ftf, empty_index, session_state) # ls function calls have to be from dbutils.fs, or we ignore them code = f"""spark.{fs_function}("/bucket/path")""" - advisories = list(sqf.lint(code, session_state)) + advisories = list(sqf.lint(code)) assert not advisories diff --git a/tests/unit/source_code/linters/test_spark_connect.py b/tests/unit/source_code/linters/test_spark_connect.py index 3401f4bbc2..d33a64117f 100644 --- a/tests/unit/source_code/linters/test_spark_connect.py +++ b/tests/unit/source_code/linters/test_spark_connect.py @@ -1,6 +1,6 @@ from itertools import chain -from databricks.labs.ucx.source_code.base import Failure, CurrentSessionState +from databricks.labs.ucx.source_code.base import Failure from databricks.labs.ucx.source_code.linters.python_ast import Tree from databricks.labs.ucx.source_code.linters.spark_connect import LoggingMatcher, SparkConnectLinter @@ -21,7 +21,7 @@ def test_jvm_access_match_shared(): end_col=18, ), ] - actual = list(linter.lint(code, CurrentSessionState())) + actual = list(linter.lint(code)) assert actual == expected @@ -42,7 +42,7 @@ def test_jvm_access_match_serverless(): end_col=18, ), ] - actual = list(linter.lint(code, CurrentSessionState())) + actual = list(linter.lint(code)) assert actual == expected @@ -86,7 +86,7 @@ def test_rdd_context_match_shared(): end_col=40, ), ] - actual = list(linter.lint(code, CurrentSessionState())) + actual = list(linter.lint(code)) assert actual == expected @@ 
-129,7 +129,7 @@ def test_rdd_context_match_serverless(): end_line=2, end_col=40, ), - ] == list(linter.lint(code, CurrentSessionState())) + ] == list(linter.lint(code)) def test_rdd_map_partitions(): @@ -148,7 +148,7 @@ def test_rdd_map_partitions(): end_col=27, ), ] - actual = list(linter.lint(code, CurrentSessionState())) + actual = list(linter.lint(code)) assert actual == expected @@ -164,7 +164,7 @@ def test_conf_shared(): end_line=0, end_col=23, ), - ] == list(linter.lint(code, CurrentSessionState())) + ] == list(linter.lint(code)) def test_conf_serverless(): @@ -180,7 +180,7 @@ def test_conf_serverless(): end_col=8, ), ] - actual = list(linter.lint(code, CurrentSessionState())) + actual = list(linter.lint(code)) assert actual == expected @@ -260,4 +260,4 @@ def test_valid_code(): df = spark.range(10) df.collect() """ - assert not list(linter.lint(code, CurrentSessionState())) + assert not list(linter.lint(code)) diff --git a/tests/unit/source_code/linters/test_table_creation.py b/tests/unit/source_code/linters/test_table_creation.py index 0d014f3c48..e1b1c3756d 100644 --- a/tests/unit/source_code/linters/test_table_creation.py +++ b/tests/unit/source_code/linters/test_table_creation.py @@ -2,7 +2,7 @@ import pytest -from databricks.labs.ucx.source_code.base import Advice, CurrentSessionState +from databricks.labs.ucx.source_code.base import Advice from databricks.labs.ucx.source_code.linters.table_creation import DBRv8d0Linter @@ -60,7 +60,7 @@ def lint( dbr_version: tuple[int, int] | None = (7, 9), ) -> list[Advice]: """Invoke linting for the given dbr version""" - return list(DBRv8d0Linter(dbr_version).lint(code, CurrentSessionState())) + return list(DBRv8d0Linter(dbr_version).lint(code)) @pytest.mark.parametrize("method_name", METHOD_NAMES) diff --git a/tests/unit/source_code/notebooks/test_cells.py b/tests/unit/source_code/notebooks/test_cells.py index 56c8835168..65f53109c3 100644 --- a/tests/unit/source_code/notebooks/test_cells.py +++ 
b/tests/unit/source_code/notebooks/test_cells.py @@ -118,9 +118,7 @@ def test_pip_cell_build_dependency_graph_reports_unknown_library(mock_path_looku notebook_loader = NotebookLoader() notebook_resolver = NotebookResolver(notebook_loader) pip_resolver = PythonLibraryResolver(Whitelist()) - dependency_resolver = DependencyResolver( - pip_resolver, notebook_resolver, [], mock_path_lookup, CurrentSessionState() - ) + dependency_resolver = DependencyResolver(pip_resolver, notebook_resolver, [], mock_path_lookup) graph = DependencyGraph(dependency, None, dependency_resolver, mock_path_lookup, CurrentSessionState()) code = "%pip install unknown-library-name" @@ -141,9 +139,7 @@ def test_pip_cell_build_dependency_graph_resolves_installed_library(mock_path_lo file_loader = FileLoader() pip_resolver = PythonLibraryResolver(whitelist) import_resolver = ImportFileResolver(file_loader, whitelist) - dependency_resolver = DependencyResolver( - pip_resolver, notebook_resolver, import_resolver, mock_path_lookup, CurrentSessionState() - ) + dependency_resolver = DependencyResolver(pip_resolver, notebook_resolver, import_resolver, mock_path_lookup) graph = DependencyGraph(dependency, None, dependency_resolver, mock_path_lookup, CurrentSessionState()) whl = Path(__file__).parent / '../samples/distribution/dist/thingy-0.0.1-py2.py3-none-any.whl' diff --git a/tests/unit/source_code/notebooks/test_sources.py b/tests/unit/source_code/notebooks/test_sources.py index 65afe8fb32..9cd77b0c06 100644 --- a/tests/unit/source_code/notebooks/test_sources.py +++ b/tests/unit/source_code/notebooks/test_sources.py @@ -3,7 +3,6 @@ import pytest -from databricks.labs.ucx.source_code.base import CurrentSessionState from databricks.labs.ucx.source_code.linters.context import LinterContext from databricks.labs.ucx.source_code.notebooks.sources import FileLinter @@ -11,14 +10,14 @@ @pytest.mark.parametrize("path, content", [("xyz.py", "a = 3"), ("xyz.sql", "select * from dual")]) def 
test_file_linter_lints_supported_language(path, content, migration_index): linter = FileLinter(LinterContext(migration_index), Path(path), content) - advices = list(linter.lint(CurrentSessionState())) + advices = list(linter.lint()) assert not advices @pytest.mark.parametrize("path", ["xyz.scala", "xyz.r", "xyz.sh"]) def test_file_linter_lints_not_yet_supported_language(path, migration_index): linter = FileLinter(LinterContext(migration_index), Path(path), "") - advices = list(linter.lint(CurrentSessionState())) + advices = list(linter.lint()) assert [advice.code for advice in advices] == ["unsupported-language"] @@ -43,7 +42,7 @@ def test_file_linter_lints_not_yet_supported_language(path, migration_index): ) def test_file_linter_lints_ignorable_language(path, migration_index): linter = FileLinter(LinterContext(migration_index), Path(path), "") - advices = list(linter.lint(CurrentSessionState())) + advices = list(linter.lint()) assert not advices @@ -52,7 +51,7 @@ def test_file_linter_lints_non_ascii_encoded_file(migration_index): non_ascii_encoded_file = Path(__file__).parent.parent / "samples" / "nonascii.py" linter = FileLinter(LinterContext(migration_index), non_ascii_encoded_file) - advices = list(linter.lint(CurrentSessionState())) + advices = list(linter.lint()) assert len(advices) == 1 assert advices[0].code == "unsupported-file-encoding" diff --git a/tests/unit/source_code/test_dependencies.py b/tests/unit/source_code/test_dependencies.py index 1250ff243b..c7fc821247 100644 --- a/tests/unit/source_code/test_dependencies.py +++ b/tests/unit/source_code/test_dependencies.py @@ -29,25 +29,25 @@ def test_dependency_resolver_repr(simple_dependency_resolver): def test_dependency_resolver_visits_workspace_notebook_dependencies(simple_dependency_resolver): - maybe = simple_dependency_resolver.build_notebook_dependency_graph(Path("root3.run.py")) + maybe = simple_dependency_resolver.build_notebook_dependency_graph(Path("root3.run.py"), CurrentSessionState()) assert 
not maybe.failed assert maybe.graph.all_relative_names() == {"root3.run.py", "root1.run.py", "leaf1.py", "leaf2.py"} def test_dependency_resolver_visits_local_notebook_dependencies(simple_dependency_resolver): - maybe = simple_dependency_resolver.build_notebook_dependency_graph(Path("root4.py")) + maybe = simple_dependency_resolver.build_notebook_dependency_graph(Path("root4.py"), CurrentSessionState()) assert not maybe.failed assert maybe.graph.all_relative_names() == {"root4.py", "leaf3.py"} def test_dependency_resolver_visits_workspace_file_dependencies(simple_dependency_resolver): - maybe = simple_dependency_resolver.build_local_file_dependency_graph(Path('./root8.py')) + maybe = simple_dependency_resolver.build_local_file_dependency_graph(Path('./root8.py'), CurrentSessionState()) assert not maybe.failed assert maybe.graph.all_relative_names() == {'leaf1.py', 'leaf2.py', 'root8.py'} def test_dependency_resolver_raises_problem_with_unfound_workspace_notebook_dependency(simple_dependency_resolver): - maybe = simple_dependency_resolver.build_notebook_dependency_graph(Path("root1-no-leaf.run.py")) + maybe = simple_dependency_resolver.build_notebook_dependency_graph(Path("root1-no-leaf.run.py"), CurrentSessionState()) assert list(maybe.problems) == [ DependencyProblem( 'notebook-not-found', @@ -62,7 +62,7 @@ def test_dependency_resolver_raises_problem_with_unfound_workspace_notebook_depe def test_dependency_resolver_raises_problem_with_unfound_local_notebook_dependency(simple_dependency_resolver): - maybe = simple_dependency_resolver.build_notebook_dependency_graph(Path("root4-no-leaf.py")) + maybe = simple_dependency_resolver.build_notebook_dependency_graph(Path("root4-no-leaf.py"), CurrentSessionState()) assert list(maybe.problems) == [ DependencyProblem( 'notebook-not-found', 'Notebook not found: __NO_LEAF__', Path('root4-no-leaf.py'), 1, 0, 1, 37 @@ -71,33 +71,33 @@ def test_dependency_resolver_raises_problem_with_unfound_local_notebook_dependen def 
test_dependency_resolver_raises_problem_with_invalid_run_cell(simple_dependency_resolver): - maybe = simple_dependency_resolver.build_notebook_dependency_graph(Path('leaf6.py')) + maybe = simple_dependency_resolver.build_notebook_dependency_graph(Path('leaf6.py'), CurrentSessionState()) assert list(maybe.problems) == [ DependencyProblem('invalid-run-cell', 'Missing notebook path in %run command', Path('leaf6.py'), 5, 0, 5, 4) ] def test_dependency_resolver_visits_recursive_file_dependencies(simple_dependency_resolver): - maybe = simple_dependency_resolver.build_local_file_dependency_graph(Path("root6.py")) + maybe = simple_dependency_resolver.build_local_file_dependency_graph(Path("root6.py"), CurrentSessionState()) assert not maybe.failed assert maybe.graph.all_relative_names() == {"root6.py", "root5.py", "leaf4.py"} def test_dependency_resolver_raises_problem_with_unresolved_import(simple_dependency_resolver): - maybe = simple_dependency_resolver.build_local_file_dependency_graph(Path('root7.py')) + maybe = simple_dependency_resolver.build_local_file_dependency_graph(Path('root7.py'), CurrentSessionState()) assert list(maybe.problems) == [ DependencyProblem('import-not-found', 'Could not locate import: some_library', Path("root7.py"), 0, 0, 0, 19) ] def test_dependency_resolver_visits_file_dependencies(simple_dependency_resolver): - maybe = simple_dependency_resolver.build_local_file_dependency_graph(Path("root5.py")) + maybe = simple_dependency_resolver.build_local_file_dependency_graph(Path("root5.py"), CurrentSessionState()) assert not maybe.failed assert maybe.graph.all_relative_names() == {"root5.py", "leaf4.py"} def test_dependency_resolver_skips_builtin_dependencies(simple_dependency_resolver): - maybe = simple_dependency_resolver.build_local_file_dependency_graph(Path("python_builtins.py")) + maybe = simple_dependency_resolver.build_local_file_dependency_graph(Path("python_builtins.py"), CurrentSessionState()) assert not maybe.failed graph = maybe.graph 
maybe = graph.locate_dependency(Path("os")) @@ -107,7 +107,7 @@ def test_dependency_resolver_skips_builtin_dependencies(simple_dependency_resolv def test_dependency_resolver_ignores_known_dependencies(simple_dependency_resolver): - maybe = simple_dependency_resolver.build_local_file_dependency_graph(Path("python_builtins.py")) + maybe = simple_dependency_resolver.build_local_file_dependency_graph(Path("python_builtins.py"), CurrentSessionState()) assert maybe.graph graph = maybe.graph maybe_graph = graph.locate_dependency(Path("databricks")) @@ -121,10 +121,8 @@ def test_dependency_resolver_terminates_at_known_libraries(empty_index, mock_not file_loader = FileLoader() import_resolver = ImportFileResolver(file_loader, Whitelist()) library_resolver = PythonLibraryResolver(Whitelist()) - resolver = DependencyResolver( - library_resolver, mock_notebook_resolver, import_resolver, lookup, CurrentSessionState() - ) - maybe = resolver.build_local_file_dependency_graph(Path("import-site-package.py")) + resolver = DependencyResolver(library_resolver, mock_notebook_resolver, import_resolver, lookup) + maybe = resolver.build_local_file_dependency_graph(Path("import-site-package.py"), CurrentSessionState()) assert not maybe.failed graph = maybe.graph maybe = graph.locate_dependency(Path(site_packages_path, "certifi", "core.py")) @@ -133,14 +131,14 @@ def test_dependency_resolver_terminates_at_known_libraries(empty_index, mock_not def test_dependency_resolver_raises_problem_with_unfound_root_file(simple_dependency_resolver): - maybe = simple_dependency_resolver.build_local_file_dependency_graph(Path("non-existing.py")) + maybe = simple_dependency_resolver.build_local_file_dependency_graph(Path("non-existing.py"), CurrentSessionState()) assert list(maybe.problems) == [ DependencyProblem('file-not-found', 'File not found: non-existing.py', Path("non-existing.py")) ] def test_dependency_resolver_raises_problem_with_unfound_root_notebook(simple_dependency_resolver): - maybe = 
simple_dependency_resolver.build_notebook_dependency_graph(Path("unknown_notebook")) + maybe = simple_dependency_resolver.build_notebook_dependency_graph(Path("unknown_notebook"), CurrentSessionState()) assert list(maybe.problems) == [ DependencyProblem('notebook-not-found', 'Notebook not found: unknown_notebook', Path("unknown_notebook")) ] @@ -156,10 +154,8 @@ def load_dependency(self, path_lookup: PathLookup, dependency: Dependency) -> So whitelist = Whitelist() import_resolver = ImportFileResolver(file_loader, whitelist) pip_resolver = PythonLibraryResolver(whitelist) - resolver = DependencyResolver( - pip_resolver, mock_notebook_resolver, import_resolver, mock_path_lookup, CurrentSessionState() - ) - maybe = resolver.build_local_file_dependency_graph(Path("import-sub-site-package.py")) + resolver = DependencyResolver(pip_resolver, mock_notebook_resolver, import_resolver, mock_path_lookup) + maybe = resolver.build_local_file_dependency_graph(Path("import-sub-site-package.py"), CurrentSessionState()) assert list(maybe.problems) == [ DependencyProblem( 'cannot-load-file', 'Could not load file import-sub-site-package.py', Path('') @@ -176,8 +172,8 @@ def load_dependency(self, path_lookup: PathLookup, dependency: Dependency) -> So notebook_loader = FailingNotebookLoader() notebook_resolver = NotebookResolver(notebook_loader) pip_resolver = PythonLibraryResolver(Whitelist()) - resolver = DependencyResolver(pip_resolver, notebook_resolver, [], mock_path_lookup, CurrentSessionState()) - maybe = resolver.build_notebook_dependency_graph(Path("root5.py")) + resolver = DependencyResolver(pip_resolver, notebook_resolver, [], mock_path_lookup) + maybe = resolver.build_notebook_dependency_graph(Path("root5.py"), CurrentSessionState()) assert list(maybe.problems) == [ DependencyProblem('cannot-load-notebook', 'Could not load notebook root5.py', Path('')) ] @@ -187,17 +183,15 @@ def test_dependency_resolver_raises_problem_with_missing_file_loader(mock_notebo library_resolver = 
PythonLibraryResolver(Whitelist()) import_resolver = create_autospec(BaseImportResolver) import_resolver.resolve_import.return_value = None - resolver = DependencyResolver( - library_resolver, mock_notebook_resolver, import_resolver, mock_path_lookup, CurrentSessionState() - ) - maybe = resolver.build_local_file_dependency_graph(Path("import-sub-site-package.py")) + resolver = DependencyResolver(library_resolver, mock_notebook_resolver, import_resolver, mock_path_lookup) + maybe = resolver.build_local_file_dependency_graph(Path("import-sub-site-package.py"), CurrentSessionState()) assert list(maybe.problems) == [ DependencyProblem('missing-file-resolver', 'Missing resolver for local files', Path('')) ] def test_dependency_resolver_raises_problem_for_non_inferable_sys_path(simple_dependency_resolver): - maybe = simple_dependency_resolver.build_local_file_dependency_graph(Path("sys-path-with-fstring.py")) + maybe = simple_dependency_resolver.build_local_file_dependency_graph(Path("sys-path-with-fstring.py"), CurrentSessionState()) assert list(maybe.problems) == [ DependencyProblem( code='sys-path-cannot-compute', diff --git a/tests/unit/source_code/test_functional.py b/tests/unit/source_code/test_functional.py index 707694e7dd..a81ca8be45 100644 --- a/tests/unit/source_code/test_functional.py +++ b/tests/unit/source_code/test_functional.py @@ -106,7 +106,7 @@ def _lint(self) -> Iterable[Advice]: session_state.named_parameters = {"my-widget": "my-path.py"} ctx = LinterContext(migration_index, session_state) linter = FileLinter(ctx, self.path) - return linter.lint(session_state) + return linter.lint() def _expected_problems(self) -> Generator[Expectation, None, None]: with self.path.open('rb') as f: diff --git a/tests/unit/source_code/test_graph.py b/tests/unit/source_code/test_graph.py index c06c51bb54..a1234b6887 100644 --- a/tests/unit/source_code/test_graph.py +++ b/tests/unit/source_code/test_graph.py @@ -18,7 +18,6 @@ def 
test_dependency_graph_registers_library(mock_path_lookup): NotebookResolver(NotebookLoader()), ImportFileResolver(file_loader, whitelist), mock_path_lookup, - session_state, ) graph = DependencyGraph(dependency, None, dependency_resolver, mock_path_lookup, session_state) @@ -38,7 +37,6 @@ def test_folder_loads_content(mock_path_lookup): NotebookResolver(NotebookLoader()), ImportFileResolver(file_loader, whitelist), mock_path_lookup, - session_state, ) dependency = Dependency(FolderLoader(file_loader), path) graph = DependencyGraph(dependency, None, dependency_resolver, mock_path_lookup, session_state) diff --git a/tests/unit/source_code/test_jobs.py b/tests/unit/source_code/test_jobs.py index 3a5b46e2d6..1912e43757 100644 --- a/tests/unit/source_code/test_jobs.py +++ b/tests/unit/source_code/test_jobs.py @@ -39,7 +39,6 @@ def dependency_resolver(mock_path_lookup) -> DependencyResolver: NotebookResolver(NotebookLoader()), ImportFileResolver(file_loader, whitelist), mock_path_lookup, - CurrentSessionState(), ) return resolver diff --git a/tests/unit/source_code/test_notebook.py b/tests/unit/source_code/test_notebook.py index 852082c97e..d1a410c942 100644 --- a/tests/unit/source_code/test_notebook.py +++ b/tests/unit/source_code/test_notebook.py @@ -134,9 +134,7 @@ def dependency_resolver(mock_path_lookup) -> DependencyResolver: notebook_resolver = NotebookResolver(notebook_loader) library_resolver = PythonLibraryResolver(Whitelist()) import_resolver = ImportFileResolver(FileLoader(), Whitelist()) - return DependencyResolver( - library_resolver, notebook_resolver, import_resolver, mock_path_lookup, CurrentSessionState() - ) + return DependencyResolver(library_resolver, notebook_resolver, import_resolver, mock_path_lookup) def test_notebook_builds_leaf_dependency_graph(mock_path_lookup) -> None: @@ -255,7 +253,7 @@ def test_detects_multiple_calls_to_dbutils_notebook_run_in_python_code() -> None stuff2 = dbutils.notebook.run("where is notebook 1?") stuff3 = 
dbutils.notebook.run("where is notebook 2?") """ - linter = DbutilsLinter() + linter = DbutilsLinter(CurrentSessionState()) tree = Tree.parse(source) nodes = linter.list_dbutils_notebook_run_calls(tree) assert len(nodes) == 2 @@ -267,7 +265,7 @@ def test_does_not_detect_partial_call_to_dbutils_notebook_run_in_python_code_() do_something_with_stuff(stuff) stuff2 = notebook.run("where is notebook 1?") """ - linter = DbutilsLinter() + linter = DbutilsLinter(CurrentSessionState()) tree = Tree.parse(source) nodes = linter.list_dbutils_notebook_run_calls(tree) assert len(nodes) == 0 @@ -279,7 +277,7 @@ def test_raises_advice_when_dbutils_notebook_run_is_too_complex() -> None: name2 = f"{name1}" dbutils.notebook.run(f"Hey {name2}") """ - linter = DbutilsLinter() - advices = list(linter.lint(source, CurrentSessionState())) + linter = DbutilsLinter(CurrentSessionState()) + advices = list(linter.lint(source)) assert len(advices) == 1 assert advices[0].code == "dbutils-notebook-run-dynamic" diff --git a/tests/unit/source_code/test_notebook_linter.py b/tests/unit/source_code/test_notebook_linter.py index eceecb1567..d507f03b38 100644 --- a/tests/unit/source_code/test_notebook_linter.py +++ b/tests/unit/source_code/test_notebook_linter.py @@ -2,7 +2,7 @@ from databricks.sdk.service.workspace import Language from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex -from databricks.labs.ucx.source_code.base import Deprecation, Advice, CurrentSessionState +from databricks.labs.ucx.source_code.base import Deprecation, Advice from databricks.labs.ucx.source_code.notebooks.sources import NotebookLinter index = MigrationIndex([]) @@ -317,7 +317,7 @@ def test_notebook_linter(lang, source, expected): # over multiple lines. 
linter = NotebookLinter.from_source(index, source, lang) assert linter is not None - gathered = list(linter.lint(CurrentSessionState())) + gathered = list(linter.lint()) assert gathered == expected @@ -545,5 +545,5 @@ def test_notebook_linter_name(): def test_notebook_linter_tracks_use(extended_test_index, lang, source, expected): linter = NotebookLinter.from_source(extended_test_index, source, lang) assert linter is not None - advices = list(linter.lint(CurrentSessionState())) + advices = list(linter.lint()) assert advices == expected diff --git a/tests/unit/source_code/test_path_lookup_simulation.py b/tests/unit/source_code/test_path_lookup_simulation.py index f93d1e5bb1..b655b2c643 100644 --- a/tests/unit/source_code/test_path_lookup_simulation.py +++ b/tests/unit/source_code/test_path_lookup_simulation.py @@ -48,10 +48,8 @@ def test_locates_notebooks(source: list[str], expected: int, mock_path_lookup): whitelist = Whitelist() import_resolver = ImportFileResolver(file_loader, whitelist) pip_resolver = PythonLibraryResolver(whitelist) - dependency_resolver = DependencyResolver( - pip_resolver, notebook_resolver, import_resolver, mock_path_lookup, CurrentSessionState() - ) - maybe = dependency_resolver.build_notebook_dependency_graph(notebook_path) + dependency_resolver = DependencyResolver(pip_resolver, notebook_resolver, import_resolver, mock_path_lookup) + maybe = dependency_resolver.build_notebook_dependency_graph(notebook_path, CurrentSessionState()) assert not maybe.problems assert maybe.graph is not None assert len(maybe.graph.all_paths) == expected @@ -76,8 +74,8 @@ def test_locates_files(source: list[str], expected: int): notebook_resolver = NotebookResolver(notebook_loader) import_resolver = ImportFileResolver(file_loader, whitelist) pip_resolver = PythonLibraryResolver(whitelist) - resolver = DependencyResolver(pip_resolver, notebook_resolver, import_resolver, lookup, CurrentSessionState()) - maybe = resolver.build_local_file_dependency_graph(file_path) 
+ resolver = DependencyResolver(pip_resolver, notebook_resolver, import_resolver, lookup) + maybe = resolver.build_local_file_dependency_graph(file_path, CurrentSessionState()) assert not maybe.problems assert maybe.graph is not None assert len(maybe.graph.all_dependencies) == expected @@ -115,8 +113,8 @@ def test_locates_notebooks_with_absolute_path(): whitelist = Whitelist() import_resolver = ImportFileResolver(file_loader, whitelist) pip_resolver = PythonLibraryResolver(whitelist) - resolver = DependencyResolver(pip_resolver, notebook_resolver, import_resolver, lookup, CurrentSessionState()) - maybe = resolver.build_notebook_dependency_graph(parent_file_path) + resolver = DependencyResolver(pip_resolver, notebook_resolver, import_resolver, lookup) + maybe = resolver.build_notebook_dependency_graph(parent_file_path, CurrentSessionState()) assert not maybe.problems assert maybe.graph is not None assert len(maybe.graph.all_paths) == 2 @@ -154,8 +152,8 @@ def func(): file_loader = FileLoader() import_resolver = ImportFileResolver(file_loader, whitelist) pip_resolver = PythonLibraryResolver(whitelist) - resolver = DependencyResolver(pip_resolver, notebook_resolver, import_resolver, lookup, CurrentSessionState()) - maybe = resolver.build_notebook_dependency_graph(parent_file_path) + resolver = DependencyResolver(pip_resolver, notebook_resolver, import_resolver, lookup) + maybe = resolver.build_notebook_dependency_graph(parent_file_path, CurrentSessionState()) assert not maybe.problems assert maybe.graph is not None assert maybe.graph.all_relative_names() == {"some_file.py", "import_file.py"} diff --git a/tests/unit/source_code/test_queries.py b/tests/unit/source_code/test_queries.py index d327eb53e8..0316170516 100644 --- a/tests/unit/source_code/test_queries.py +++ b/tests/unit/source_code/test_queries.py @@ -7,7 +7,7 @@ def test_not_migrated_tables_trigger_nothing(empty_index): old_query = "SELECT * FROM old.things LEFT JOIN hive_metastore.other.matters USING (x) 
WHERE state > 1 LIMIT 10" - assert not list(ftf.lint(old_query, CurrentSessionState())) + assert not list(ftf.lint(old_query)) def test_migrated_tables_trigger_messages(migration_index): @@ -32,7 +32,7 @@ def test_migrated_tables_trigger_messages(migration_index): end_line=0, end_col=1024, ), - ] == list(ftf.lint(old_query, CurrentSessionState())) + ] == list(ftf.lint(old_query)) def test_fully_migrated_queries_match(migration_index): @@ -61,7 +61,7 @@ def test_use_database_change(migration_index): USE newcatalog; SELECT * FROM things LEFT JOIN hive_metastore.other.matters USING (x) WHERE state > 1 LIMIT 10""" - _ = list(ftf.lint(query, CurrentSessionState())) + _ = list(ftf.lint(query)) assert ftf.schema == "newcatalog" diff --git a/tests/unit/source_code/test_s3fs.py b/tests/unit/source_code/test_s3fs.py index 9f02b50305..f2bfaf0734 100644 --- a/tests/unit/source_code/test_s3fs.py +++ b/tests/unit/source_code/test_s3fs.py @@ -122,10 +122,8 @@ def test_detect_s3fs_import(empty_index, source: str, expected: list[DependencyP notebook_resolver = NotebookResolver(notebook_loader) import_resolver = ImportFileResolver(file_loader, whitelist) pip_resolver = PythonLibraryResolver(whitelist) - dependency_resolver = DependencyResolver( - pip_resolver, notebook_resolver, import_resolver, mock_path_lookup, CurrentSessionState() - ) - maybe = dependency_resolver.build_local_file_dependency_graph(sample) + dependency_resolver = DependencyResolver(pip_resolver, notebook_resolver, import_resolver, mock_path_lookup) + maybe = dependency_resolver.build_local_file_dependency_graph(sample, CurrentSessionState()) assert maybe.problems == [_.replace(source_path=sample) for _ in expected] @@ -154,9 +152,7 @@ def test_detect_s3fs_import_in_dependencies( whitelist = Whitelist() import_resolver = ImportFileResolver(file_loader, whitelist) pip_resolver = PythonLibraryResolver(whitelist) - dependency_resolver = DependencyResolver( - pip_resolver, mock_notebook_resolver, import_resolver, 
mock_path_lookup, CurrentSessionState() - ) + dependency_resolver = DependencyResolver(pip_resolver, mock_notebook_resolver, import_resolver, mock_path_lookup) sample = mock_path_lookup.cwd / "root9.py" - maybe = dependency_resolver.build_local_file_dependency_graph(sample) + maybe = dependency_resolver.build_local_file_dependency_graph(sample, CurrentSessionState()) assert maybe.problems == expected From 43cf134b3e9d924092def14d27ca034d57fcdeb1 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Thu, 13 Jun 2024 08:26:44 +0200 Subject: [PATCH 13/23] formatting --- tests/unit/source_code/test_dependencies.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/unit/source_code/test_dependencies.py b/tests/unit/source_code/test_dependencies.py index c7fc821247..621e8d22e7 100644 --- a/tests/unit/source_code/test_dependencies.py +++ b/tests/unit/source_code/test_dependencies.py @@ -47,7 +47,9 @@ def test_dependency_resolver_visits_workspace_file_dependencies(simple_dependenc def test_dependency_resolver_raises_problem_with_unfound_workspace_notebook_dependency(simple_dependency_resolver): - maybe = simple_dependency_resolver.build_notebook_dependency_graph(Path("root1-no-leaf.run.py"), CurrentSessionState()) + maybe = simple_dependency_resolver.build_notebook_dependency_graph( + Path("root1-no-leaf.run.py"), CurrentSessionState() + ) assert list(maybe.problems) == [ DependencyProblem( 'notebook-not-found', @@ -97,7 +99,9 @@ def test_dependency_resolver_visits_file_dependencies(simple_dependency_resolver def test_dependency_resolver_skips_builtin_dependencies(simple_dependency_resolver): - maybe = simple_dependency_resolver.build_local_file_dependency_graph(Path("python_builtins.py"), CurrentSessionState()) + maybe = simple_dependency_resolver.build_local_file_dependency_graph( + Path("python_builtins.py"), CurrentSessionState() + ) assert not maybe.failed graph = maybe.graph maybe = graph.locate_dependency(Path("os")) @@ -107,7 
+111,9 @@ def test_dependency_resolver_skips_builtin_dependencies(simple_dependency_resolv def test_dependency_resolver_ignores_known_dependencies(simple_dependency_resolver): - maybe = simple_dependency_resolver.build_local_file_dependency_graph(Path("python_builtins.py"), CurrentSessionState()) + maybe = simple_dependency_resolver.build_local_file_dependency_graph( + Path("python_builtins.py"), CurrentSessionState() + ) assert maybe.graph graph = maybe.graph maybe_graph = graph.locate_dependency(Path("databricks")) @@ -191,7 +197,9 @@ def test_dependency_resolver_raises_problem_with_missing_file_loader(mock_notebo def test_dependency_resolver_raises_problem_for_non_inferable_sys_path(simple_dependency_resolver): - maybe = simple_dependency_resolver.build_local_file_dependency_graph(Path("sys-path-with-fstring.py"), CurrentSessionState()) + maybe = simple_dependency_resolver.build_local_file_dependency_graph( + Path("sys-path-with-fstring.py"), CurrentSessionState() + ) assert list(maybe.problems) == [ DependencyProblem( code='sys-path-cannot-compute', From 8b6a0e878196eda94dc760319660bbfea350b1a2 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Thu, 13 Jun 2024 08:46:45 +0200 Subject: [PATCH 14/23] fix failing test --- .../labs/ucx/source_code/linters/python_ast.py | 8 ++++++++ tests/unit/source_code/linters/test_python_ast.py | 14 ++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/src/databricks/labs/ucx/source_code/linters/python_ast.py b/src/databricks/labs/ucx/source_code/linters/python_ast.py index 6584de3d42..9b11880a5e 100644 --- a/src/databricks/labs/ucx/source_code/linters/python_ast.py +++ b/src/databricks/labs/ucx/source_code/linters/python_ast.py @@ -8,6 +8,7 @@ from astroid import Assign, Attribute, Call, Const, decorators, FormattedValue, Import, ImportFrom, JoinedStr, Module, Name, NodeNG, parse, Uninferable # type: ignore from astroid.context import InferenceContext, InferenceResult, CallContext # type: ignore from astroid.typing 
import InferenceErrorInfo # type: ignore +from astroid.exceptions import InferenceError # type: ignore from databricks.labs.ucx.source_code.base import CurrentSessionState @@ -163,11 +164,18 @@ def _infer_values(self) -> Iterator[Iterable[NodeNG]]: elif isinstance(self._node, FormattedValue): yield from _LocalTree(self._node.value).do_infer_values() else: + yield from self._infer_internal() + + def _infer_internal(self): + try: for inferred in self._node.inferred(): # work around infinite recursion of empty lists if inferred == self._node: continue yield from _LocalTree(inferred).do_infer_values() + except InferenceError as e: + logger.debug(f"When inferring {self._node}", exc_info=e) + yield [Uninferable] def _infer_values_from_joined_string(self) -> Iterator[Iterable[NodeNG]]: assert isinstance(self._node, JoinedStr) diff --git a/tests/unit/source_code/linters/test_python_ast.py b/tests/unit/source_code/linters/test_python_ast.py index 4dd4809c93..caf83face1 100644 --- a/tests/unit/source_code/linters/test_python_ast.py +++ b/tests/unit/source_code/linters/test_python_ast.py @@ -201,10 +201,24 @@ def test_infers_externally_defined_values(): def test_fails_to_infer_missing_externally_defined_value(): + state = CurrentSessionState() + state.named_parameters = {"my-widget-1": "my-value-1", "my-widget-2": "my-value-2"} source = """ name = "my-widget" value = dbutils.widgets.get(name) """ + tree = Tree.parse(source) + nodes = tree.locate(Assign, []) + tree = Tree(nodes[1].value) # value of value = ... + values = tree.infer_values(state) + assert all(not value.is_inferred() for value in values) + + +def test_survives_absence_of_externally_defined_values(): + source = """ + name = "my-widget" + value = dbutils.widgets.get(name) + """ tree = Tree.parse(source) nodes = tree.locate(Assign, []) tree = Tree(nodes[1].value) # value of value = ... 
From 4759117443205fa592e631d83651a3bcf6ce9bbf Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Thu, 13 Jun 2024 08:58:35 +0200 Subject: [PATCH 15/23] fix failing tests --- tests/integration/source_code/test_libraries.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/integration/source_code/test_libraries.py b/tests/integration/source_code/test_libraries.py index 9eb1a45cd1..9cf7e44f66 100644 --- a/tests/integration/source_code/test_libraries.py +++ b/tests/integration/source_code/test_libraries.py @@ -7,6 +7,7 @@ import pytest +from databricks.labs.ucx.source_code.base import CurrentSessionState from tests.unit.conftest import MockPathLookup @@ -22,7 +23,7 @@ def test_build_notebook_dependency_graphs_installs_wheel_with_pip_cell_in_notebook(simple_ctx, notebook): ctx = simple_ctx.replace(path_lookup=MockPathLookup()) - maybe = ctx.dependency_resolver.build_notebook_dependency_graph(Path(notebook)) + maybe = ctx.dependency_resolver.build_notebook_dependency_graph(Path(notebook), CurrentSessionState()) assert not maybe.problems assert maybe.graph.all_relative_names() == {f"{notebook}.py", "thingy/__init__.py"} @@ -30,13 +31,17 @@ def test_build_notebook_dependency_graphs_installs_wheel_with_pip_cell_in_notebo def test_build_notebook_dependency_graphs_installs_pytest_from_index_url(simple_ctx): ctx = simple_ctx.replace(path_lookup=MockPathLookup()) - maybe = ctx.dependency_resolver.build_notebook_dependency_graph(Path("pip_install_pytest_with_index_url")) + maybe = ctx.dependency_resolver.build_notebook_dependency_graph( + Path("pip_install_pytest_with_index_url"), CurrentSessionState() + ) assert not maybe.problems def test_build_notebook_dependency_graphs_installs_pypi_packages(simple_ctx): ctx = simple_ctx.replace(path_lookup=MockPathLookup()) - maybe = ctx.dependency_resolver.build_notebook_dependency_graph(Path("pip_install_multiple_packages")) + maybe = ctx.dependency_resolver.build_notebook_dependency_graph( + 
Path("pip_install_multiple_packages"), CurrentSessionState() + ) assert not maybe.problems assert maybe.graph.path_lookup.resolve(Path("splink")) assert maybe.graph.path_lookup.resolve(Path("mlflow")) @@ -48,7 +53,7 @@ def test_build_notebook_dependency_graphs_installs_pypi_packages(simple_ctx): def test_build_notebook_dependency_graphs_fails_installing_when_spaces(simple_ctx, notebook): ctx = simple_ctx.replace(path_lookup=MockPathLookup()) - maybe = ctx.dependency_resolver.build_notebook_dependency_graph(Path(notebook)) + maybe = ctx.dependency_resolver.build_notebook_dependency_graph(Path(notebook), CurrentSessionState()) assert not maybe.problems assert maybe.graph.all_relative_names() == {f"{notebook}.py", "thingy/__init__.py"} From 3ff9ba63cf58bd0d6825fabcf227f2323d9534a3 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Thu, 13 Jun 2024 12:17:48 +0200 Subject: [PATCH 16/23] add support for dbutils.widgets.getAll --- .../ucx/source_code/linters/python_ast.py | 73 ++++++++++++++++++- .../source_code/linters/test_python_ast.py | 16 ++++ .../source_code/samples/functional/widgets.py | 6 ++ 3 files changed, 92 insertions(+), 3 deletions(-) diff --git a/src/databricks/labs/ucx/source_code/linters/python_ast.py b/src/databricks/labs/ucx/source_code/linters/python_ast.py index 9b11880a5e..40766e6290 100644 --- a/src/databricks/labs/ucx/source_code/linters/python_ast.py +++ b/src/databricks/labs/ucx/source_code/linters/python_ast.py @@ -5,7 +5,7 @@ from collections.abc import Iterable, Iterator, Generator from typing import Any, TypeVar -from astroid import Assign, Attribute, Call, Const, decorators, FormattedValue, Import, ImportFrom, JoinedStr, Module, Name, NodeNG, parse, Uninferable # type: ignore +from astroid import Assign, Attribute, Call, Const, decorators, Dict, FormattedValue, Import, ImportFrom, JoinedStr, Module, Name, NodeNG, parse, Uninferable # type: ignore from astroid.context import InferenceContext, InferenceResult, CallContext # type: ignore from 
astroid.typing import InferenceErrorInfo # type: ignore from astroid.exceptions import InferenceError # type: ignore @@ -151,9 +151,18 @@ def infer_values(self, state: CurrentSessionState | None = None) -> Iterable[Inf def _contextualize(self, state: CurrentSessionState | None): if state is None or state.named_parameters is None or len(state.named_parameters) == 0: return + self._contextualize_dbutils_widgets_get(state) + self._contextualize_dbutils_widgets_get_all(state) + + def _contextualize_dbutils_widgets_get(self, state: CurrentSessionState): calls = Tree(self.root).locate(Call, [("get", Attribute), ("widgets", Attribute), ("dbutils", Name)]) for call in calls: - call.func = _GetWidgetValueCall(state, call) + call.func = _DbUtilsWidgetsGetCall(state, call) + + def _contextualize_dbutils_widgets_get_all(self, state: CurrentSessionState): + calls = Tree(self.root).locate(Call, [("getAll", Attribute), ("widgets", Attribute), ("dbutils", Name)]) + for call in calls: + call.func = _DbUtilsWidgetsGetAllCall(state, call) def _infer_values(self) -> Iterator[Iterable[NodeNG]]: # deal with node types that don't implement 'inferred()' @@ -198,7 +207,7 @@ def do_infer_values(self): return self._infer_values() -class _GetWidgetValueCall(NodeNG): +class _DbUtilsWidgetsGetCall(NodeNG): def __init__(self, session_state: CurrentSessionState, node: NodeNG): super().__init__( @@ -243,6 +252,64 @@ def infer_call_result(self, context: InferenceContext | None = None, **_): # ca ) +class _DbUtilsWidgetsGetAllCall(NodeNG): + + def __init__(self, session_state: CurrentSessionState, node: NodeNG): + super().__init__( + lineno=node.lineno, + col_offset=node.col_offset, + end_lineno=node.end_lineno, + end_col_offset=node.end_col_offset, + parent=node.parent, + ) + self._session_state = session_state + + @decorators.raise_if_nothing_inferred + def _infer( + self, context: InferenceContext | None = None, **kwargs: Any + ) -> Generator[InferenceResult, None, InferenceErrorInfo | None]: + 
yield self + return InferenceErrorInfo(node=self, context=context) + + def infer_call_result(self, **_): # caller needs unused kwargs + named_parameters = self._session_state.named_parameters + if not named_parameters: + yield Uninferable + return + items = self._populate_items(named_parameters) + result = Dict( + lineno=self.lineno, + col_offset=self.col_offset, + end_lineno=self.end_lineno, + end_col_offset=self.end_col_offset, + parent=self, + ) + result.postinit(items) + yield result + + def _populate_items(self, values: dict[str, str]): + items: list[tuple[InferenceResult, InferenceResult]] = [] + for key, value in values.items(): + item_key = Const( + key, + lineno=self.lineno, + col_offset=self.col_offset, + end_lineno=self.end_lineno, + end_col_offset=self.end_col_offset, + parent=self, + ) + item_value = Const( + value, + lineno=self.lineno, + col_offset=self.col_offset, + end_lineno=self.end_lineno, + end_col_offset=self.end_col_offset, + parent=self, + ) + items.append((item_key, item_value)) + return items + + class InferredValue: """Represents 1 or more nodes that together represent the value. The list of nodes typically holds one Const element, but for f-strings it diff --git a/tests/unit/source_code/linters/test_python_ast.py b/tests/unit/source_code/linters/test_python_ast.py index caf83face1..b25bd744de 100644 --- a/tests/unit/source_code/linters/test_python_ast.py +++ b/tests/unit/source_code/linters/test_python_ast.py @@ -224,3 +224,19 @@ def test_survives_absence_of_externally_defined_values(): tree = Tree(nodes[1].value) # value of value = ... 
values = tree.infer_values(CurrentSessionState()) assert all(not value.is_inferred() for value in values) + + +def test_infers_externally_defined_value_set(): + state = CurrentSessionState() + state.named_parameters = {"my-widget": "my-value"} + source = """ +values = dbutils.widgets.getAll() +name = "my-widget" +value = values[name] +""" + tree = Tree.parse(source) + nodes = tree.locate(Assign, []) + tree = Tree(nodes[2].value) # value of value = ... + values = list(tree.infer_values(state)) + strings = list(value.as_string() for value in values) + assert strings == ["my-value"] diff --git a/tests/unit/source_code/samples/functional/widgets.py b/tests/unit/source_code/samples/functional/widgets.py index 6897cd6398..fe6e20f754 100644 --- a/tests/unit/source_code/samples/functional/widgets.py +++ b/tests/unit/source_code/samples/functional/widgets.py @@ -3,3 +3,9 @@ path = dbutils.widgets.get("no-widget") # ucx[dbutils-notebook-run-dynamic:+1:0:+1:26] Path for 'dbutils.notebook.run' cannot be computed and requires adjusting the notebook path(s) dbutils.notebook.run(path) +values = dbutils.widgets.getAll() +path = values["my-widget"] +dbutils.notebook.run(path) +path = values["no-widget"] +# ucx[dbutils-notebook-run-dynamic:+1:0:+1:26] Path for 'dbutils.notebook.run' cannot be computed and requires adjusting the notebook path(s) +dbutils.notebook.run(path) From 939b42f544313f624c90860fa850dec37b8d1aa8 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Thu, 13 Jun 2024 14:02:05 +0200 Subject: [PATCH 17/23] fix crasher when running solacc --- src/databricks/labs/ucx/contexts/workspace_cli.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/databricks/labs/ucx/contexts/workspace_cli.py b/src/databricks/labs/ucx/contexts/workspace_cli.py index 937aacbbf6..6631410451 100644 --- a/src/databricks/labs/ucx/contexts/workspace_cli.py +++ b/src/databricks/labs/ucx/contexts/workspace_cli.py @@ -179,7 +179,7 @@ class 
LocalCheckoutContext(WorkspaceContext): """Local context extends Workspace context to provide extra properties for running local operations.""" - def linter_context_factory(self, session_state: CurrentSessionState | None = None): + def _linter_context_factory(self, session_state: CurrentSessionState | None = None): index = self.tables_migrator.index() if session_state is None: session_state = CurrentSessionState() @@ -187,7 +187,7 @@ def linter_context_factory(self, session_state: CurrentSessionState | None = Non @cached_property def local_file_migrator(self): - return LocalFileMigrator(self.linter_context_factory) + return LocalFileMigrator(lambda: self._linter_context_factory(CurrentSessionState())) @cached_property def local_code_linter(self): @@ -198,5 +198,5 @@ def local_code_linter(self): self.path_lookup, session_state, self.dependency_resolver, - lambda: self.linter_context_factory(session_state), + lambda: self._linter_context_factory(session_state), ) From 6e63ee2bccd40a79448d75ec90f47153e930a80f Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Thu, 13 Jun 2024 14:20:54 +0200 Subject: [PATCH 18/23] fix the fix --- src/databricks/labs/ucx/contexts/workspace_cli.py | 6 +++--- tests/integration/source_code/solacc.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/databricks/labs/ucx/contexts/workspace_cli.py b/src/databricks/labs/ucx/contexts/workspace_cli.py index 6631410451..db726573d1 100644 --- a/src/databricks/labs/ucx/contexts/workspace_cli.py +++ b/src/databricks/labs/ucx/contexts/workspace_cli.py @@ -179,7 +179,7 @@ class LocalCheckoutContext(WorkspaceContext): """Local context extends Workspace context to provide extra properties for running local operations.""" - def _linter_context_factory(self, session_state: CurrentSessionState | None = None): + def linter_context_factory(self, session_state: CurrentSessionState | None = None): index = self.tables_migrator.index() if session_state is None: session_state = 
CurrentSessionState() @@ -187,7 +187,7 @@ def _linter_context_factory(self, session_state: CurrentSessionState | None = No @cached_property def local_file_migrator(self): - return LocalFileMigrator(lambda: self._linter_context_factory(CurrentSessionState())) + return LocalFileMigrator(lambda: self.linter_context_factory(CurrentSessionState())) @cached_property def local_code_linter(self): @@ -198,5 +198,5 @@ def local_code_linter(self): self.path_lookup, session_state, self.dependency_resolver, - lambda: self._linter_context_factory(session_state), + lambda: self.linter_context_factory(session_state), ) diff --git a/tests/integration/source_code/solacc.py b/tests/integration/source_code/solacc.py index 20ead2a203..8998d4a3b9 100644 --- a/tests/integration/source_code/solacc.py +++ b/tests/integration/source_code/solacc.py @@ -47,7 +47,7 @@ def clone_all(): def lint_all(): # pylint: disable=too-many-nested-blocks ws = WorkspaceClient(host='...', token='...') - ctx = LocalCheckoutContext(ws).replace(linter_context_factory=lambda: LinterContext(MigrationIndex([]))) + ctx = LocalCheckoutContext(ws).replace(linter_context_factory=lambda session_state: LinterContext(MigrationIndex([]), session_state)) parseable = 0 missing_imports = 0 all_files = list(dist.glob('**/*.py')) @@ -61,7 +61,7 @@ def lint_all(): parseable += 1 except Exception as e: # pylint: disable=broad-except # here we're most likely catching astroid & sqlglot errors - logger.error(f"Error during parsing of {file}: {e}".replace("\n", " ")) + logger.error(f"Error during parsing of {file}: {e}".replace("\n", " "), exc_info=e) parseable_pct = int(parseable / len(all_files) * 100) logger.info(f"Parseable: {parseable_pct}% ({parseable}/{len(all_files)}), missing imports: {missing_imports}") if parseable_pct < 100: From b59cbd34f1128d0b534988565ede8cd6757aef61 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Thu, 13 Jun 2024 14:52:22 +0200 Subject: [PATCH 19/23] fix issue when linting sql statement 'CREATE 
SCHEMA xxx' --- src/databricks/labs/ucx/source_code/queries.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/databricks/labs/ucx/source_code/queries.py b/src/databricks/labs/ucx/source_code/queries.py index b2b0b9cdd1..d465368889 100644 --- a/src/databricks/labs/ucx/source_code/queries.py +++ b/src/databricks/labs/ucx/source_code/queries.py @@ -2,7 +2,7 @@ import logging import sqlglot -from sqlglot.expressions import Table, Expression, Use +from sqlglot.expressions import Table, Expression, Use, Create from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex from databricks.labs.ucx.source_code.base import Advice, Deprecation, Fixer, Linter, CurrentSessionState @@ -52,6 +52,11 @@ def lint(self, code: str) -> Iterable[Advice]: # the schema as the table name. self._session_state.schema = table.name continue + if isinstance(statement, Create) and statement.kind == "SCHEMA": + # Sqlglot captures the schema name in the Create statement as a Table, with + # the schema as the db name. 
+ self._session_state.schema = table.db + continue # we only migrate tables in the hive_metastore catalog if self._catalog(table) != 'hive_metastore': From 2f0e0282b574688816880e2fb38a6f51d64e6385 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Thu, 13 Jun 2024 15:08:17 +0200 Subject: [PATCH 20/23] Infer table names from f-strings and improve Advice messages --- .../labs/ucx/source_code/linters/pyspark.py | 32 +++++++++---------- .../unit/source_code/linters/test_pyspark.py | 11 +++++++ .../catalog/spark-catalog-cache-table.py | 4 +-- .../spark-catalog-create-external-table.py | 4 +-- .../catalog/spark-catalog-create-table.py | 4 +-- .../catalog/spark-catalog-get-table.py | 4 +-- .../catalog/spark-catalog-is-cached.py | 4 +-- .../catalog/spark-catalog-list-columns.py | 4 +-- .../catalog/spark-catalog-list-tables.py | 4 +-- .../spark-catalog-recover-partitions.py | 4 +-- .../catalog/spark-catalog-refresh-table.py | 4 +-- .../catalog/spark-catalog-uncache-table.py | 4 +-- .../pyspark/dataframe-write-insert-into.py | 4 +-- .../pyspark/dataframe-write-save-as-table.py | 4 +-- .../samples/functional/pyspark/spark-table.py | 4 +-- 15 files changed, 52 insertions(+), 43 deletions(-) diff --git a/src/databricks/labs/ucx/source_code/linters/pyspark.py b/src/databricks/labs/ucx/source_code/linters/pyspark.py index 12201fd488..c29df6786e 100644 --- a/src/databricks/labs/ucx/source_code/linters/pyspark.py +++ b/src/databricks/labs/ucx/source_code/linters/pyspark.py @@ -113,26 +113,24 @@ def lint( self, from_table: FromTable, index: MigrationIndex, session_state: CurrentSessionState, node: Call ) -> Iterator[Advice]: table_arg = self._get_table_arg(node) - - if not isinstance(table_arg, Const): - assert isinstance(node.func, Attribute) # always true, avoids a pylint warning - yield Advisory.from_node( + table_name = table_arg.as_string().strip("'").strip('"') + for inferred in Tree(table_arg).infer_values(session_state): + if not inferred.is_inferred(): + yield 
Advisory.from_node( + code='table-migrate', + message=f"Can't migrate '{node.as_string()}' because its table name argument cannot be computed", + node=node, + ) + continue + dst = self._find_dest(index, inferred.as_string(), from_table.schema) + if dst is None: + continue + yield Deprecation.from_node( code='table-migrate', - message=f"Can't migrate '{node.func.attrname}' because its table name argument is not a constant", + message=f"Table {table_name} is migrated to {dst.destination()} in Unity Catalog", + # SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159 node=node, ) - return - - dst = self._find_dest(index, table_arg.value, from_table.schema) - if dst is None: - return - - yield Deprecation.from_node( - code='table-migrate', - message=f"Table {table_arg.value} is migrated to {dst.destination()} in Unity Catalog", - # SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159 - node=node, - ) def apply(self, from_table: FromTable, index: MigrationIndex, node: Call) -> None: table_arg = self._get_table_arg(node) diff --git a/tests/unit/source_code/linters/test_pyspark.py b/tests/unit/source_code/linters/test_pyspark.py index 6f5bfdd6b3..6e5ceec468 100644 --- a/tests/unit/source_code/linters/test_pyspark.py +++ b/tests/unit/source_code/linters/test_pyspark.py @@ -16,6 +16,17 @@ def test_spark_no_sql(empty_index): assert not list(sqf.lint("print(1)")) +def test_spark_dynamic_sql(empty_index): + source = """ +schema="some_schema" +df4.write.saveAsTable(f"{schema}.member_measure") +""" + session_state = CurrentSessionState() + ftf = FromTable(empty_index, session_state) + sqf = SparkSql(ftf, empty_index, session_state) + assert not list(sqf.lint(source)) + + def test_spark_sql_no_match(empty_index): session_state = CurrentSessionState() ftf = FromTable(empty_index, session_state) diff --git a/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-cache-table.py 
b/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-cache-table.py index f4d544db71..22708aa45e 100644 --- a/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-cache-table.py +++ b/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-cache-table.py @@ -13,9 +13,9 @@ spark.catalog.cacheTable(storageLevel=None, tableName="old.things") ## Some calls that use a variable whose value is unknown: they could potentially reference a migrated table. -# ucx[table-migrate:+1:0:+1:30] Can't migrate 'cacheTable' because its table name argument is not a constant +# ucx[table-migrate:+1:0:+1:30] Can't migrate 'spark.catalog.cacheTable(name)' because its table name argument cannot be computed spark.catalog.cacheTable(name) -# ucx[table-migrate:+1:0:+1:40] Can't migrate 'cacheTable' because its table name argument is not a constant +# ucx[table-migrate:+1:0:+1:40] Can't migrate 'spark.catalog.cacheTable(f'boop{stuff}')' because its table name argument cannot be computed spark.catalog.cacheTable(f"boop{stuff}") ## Some trivial references to the method or table in unrelated contexts that should not trigger warnigns. diff --git a/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-create-external-table.py b/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-create-external-table.py index cacf7afb31..d57ada90f5 100644 --- a/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-create-external-table.py +++ b/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-create-external-table.py @@ -27,10 +27,10 @@ do_stuff_with(df) ## Some calls that use a variable whose value is unknown: they could potentially reference a migrated table. 
- # ucx[table-migrate:+1:9:+1:48] Can't migrate 'createExternalTable' because its table name argument is not a constant + # ucx[table-migrate:+1:9:+1:48] Can't migrate 'spark.catalog.createExternalTable(name)' because its table name argument cannot be computed df = spark.catalog.createExternalTable(name) do_stuff_with(df) - # ucx[table-migrate:+1:9:+1:58] Can't migrate 'createExternalTable' because its table name argument is not a constant + # ucx[table-migrate:+1:9:+1:58] Can't migrate 'spark.catalog.createExternalTable(f'boop{stuff}')' because its table name argument cannot be computed df = spark.catalog.createExternalTable(f"boop{stuff}") do_stuff_with(df) diff --git a/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-create-table.py b/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-create-table.py index b25affaa9b..16fda6951f 100644 --- a/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-create-table.py +++ b/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-create-table.py @@ -27,10 +27,10 @@ do_stuff_with(df) ## Some calls that use a variable whose value is unknown: they could potentially reference a migrated table. 
- # ucx[table-migrate:+1:9:+1:40] Can't migrate 'createTable' because its table name argument is not a constant + # ucx[table-migrate:+1:9:+1:40] Can't migrate 'spark.catalog.createTable(name)' because its table name argument cannot be computed df = spark.catalog.createTable(name) do_stuff_with(df) - # ucx[table-migrate:+1:9:+1:50] Can't migrate 'createTable' because its table name argument is not a constant + # ucx[table-migrate:+1:9:+1:50] Can't migrate 'spark.catalog.createTable(f'boop{stuff}')' because its table name argument cannot be computed df = spark.catalog.createTable(f"boop{stuff}") do_stuff_with(df) diff --git a/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-get-table.py b/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-get-table.py index cbc576cbb8..ce552204fd 100644 --- a/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-get-table.py +++ b/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-get-table.py @@ -16,10 +16,10 @@ do_stuff_with(table) ## Some calls that use a variable whose value is unknown: they could potentially reference a migrated table. 
- # ucx[table-migrate:+1:12:+1:40] Can't migrate 'getTable' because its table name argument is not a constant + # ucx[table-migrate:+1:12:+1:40] Can't migrate 'spark.catalog.getTable(name)' because its table name argument cannot be computed table = spark.catalog.getTable(name) do_stuff_with(table) - # ucx[table-migrate:+1:12:+1:50] Can't migrate 'getTable' because its table name argument is not a constant + # ucx[table-migrate:+1:12:+1:50] Can't migrate 'spark.catalog.getTable(f'boop{stuff}')' because its table name argument cannot be computed table = spark.catalog.getTable(f"boop{stuff}") do_stuff_with(table) diff --git a/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-is-cached.py b/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-is-cached.py index 1a4273d2e4..92f870961c 100644 --- a/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-is-cached.py +++ b/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-is-cached.py @@ -13,9 +13,9 @@ cached_previously = spark.catalog.isCached("old.things", "extra-argument") ## Some calls that use a variable whose value is unknown: they could potentially reference a migrated table. - # ucx[table-migrate:+1:24:+1:52] Can't migrate 'isCached' because its table name argument is not a constant + # ucx[table-migrate:+1:24:+1:52] Can't migrate 'spark.catalog.isCached(name)' because its table name argument cannot be computed cached_previously = spark.catalog.isCached(name) - # ucx[table-migrate:+1:24:+1:62] Can't migrate 'isCached' because its table name argument is not a constant + # ucx[table-migrate:+1:24:+1:62] Can't migrate 'spark.catalog.isCached(f'boop{stuff}')' because its table name argument cannot be computed cached_previously = spark.catalog.isCached(f"boop{stuff}") ## Some trivial references to the method or table in unrelated contexts that should not trigger warnigns. 
diff --git a/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-list-columns.py b/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-list-columns.py index 4f669852b7..11d5152545 100644 --- a/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-list-columns.py +++ b/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-list-columns.py @@ -23,9 +23,9 @@ columns = spark.catalog.listColumns(dbName="old", name="things") ## Some calls that use a variable whose value is unknown: they could potentially reference a migrated table. - # ucx[table-migrate:+1:14:+1:45] Can't migrate 'listColumns' because its table name argument is not a constant + # ucx[table-migrate:+1:14:+1:45] Can't migrate 'spark.catalog.listColumns(name)' because its table name argument cannot be computed columns = spark.catalog.listColumns(name) - # ucx[table-migrate:+1:14:+1:55] Can't migrate 'listColumns' because its table name argument is not a constant + # ucx[table-migrate:+1:14:+1:55] Can't migrate 'spark.catalog.listColumns(f'boop{stuff}')' because its table name argument cannot be computed columns = spark.catalog.listColumns(f"boop{stuff}") ## Some trivial references to the method or table in unrelated contexts that should not trigger warnigns. diff --git a/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-list-tables.py b/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-list-tables.py index d1242ceaca..715a4f74bd 100644 --- a/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-list-tables.py +++ b/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-list-tables.py @@ -33,12 +33,12 @@ ## Some calls that use a variable whose value is unknown: they could potentially reference a migrated database. # ucx[table-migrate:+3:13:+3:43] Call to 'listTables' will return a list of .. instead of .
. ## TODO: The following isn't yet implemented: -## ucx[table-migrate:+1:13:+1:0] Can't migrate 'listTables' because its database name argument is not a constant +## ucx[table-migrate:+1:13:+1:0] Can't migrate 'listTables' because its database name argument cannot be computed for table in spark.catalog.listTables(name): do_stuff_with_table(table) # ucx[table-migrate:+3:13:+3:53] Call to 'listTables' will return a list of ..
instead of .
. ## TODO: The following isn't yet implemented: -## ucx[table-migrate:+1:13:+1:0] Can't migrate 'listTables' because its database name argument is not a constant +## ucx[table-migrate:+1:13:+1:0] Can't migrate 'listTables' because its database name argument cannot be computed for table in spark.catalog.listTables(f"boop{stuff}"): do_stuff_with_table(table) diff --git a/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-recover-partitions.py b/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-recover-partitions.py index bef212fe56..5089597d2d 100644 --- a/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-recover-partitions.py +++ b/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-recover-partitions.py @@ -13,9 +13,9 @@ spark.catalog.recoverPartitions("old.things", "extra-argument") ## Some calls that use a variable whose value is unknown: they could potentially reference a migrated table. - # ucx[table-migrate:+1:4:+1:41] Can't migrate 'recoverPartitions' because its table name argument is not a constant + # ucx[table-migrate:+1:4:+1:41] Can't migrate 'spark.catalog.recoverPartitions(name)' because its table name argument cannot be computed spark.catalog.recoverPartitions(name) - # ucx[table-migrate:+1:4:+1:51] Can't migrate 'recoverPartitions' because its table name argument is not a constant + # ucx[table-migrate:+1:4:+1:51] Can't migrate 'spark.catalog.recoverPartitions(f'boop{stuff}')' because its table name argument cannot be computed spark.catalog.recoverPartitions(f"boop{stuff}") ## Some trivial references to the method or table in unrelated contexts that should not trigger warnigns. 
diff --git a/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-refresh-table.py b/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-refresh-table.py index 9974aa9bac..8a568cd44f 100644 --- a/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-refresh-table.py +++ b/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-refresh-table.py @@ -13,9 +13,9 @@ spark.catalog.refreshTable("old.things", "extra-argument") ## Some calls that use a variable whose value is unknown: they could potentially reference a migrated table. - # ucx[table-migrate:+1:4:+1:36] Can't migrate 'refreshTable' because its table name argument is not a constant + # ucx[table-migrate:+1:4:+1:36] Can't migrate 'spark.catalog.refreshTable(name)' because its table name argument cannot be computed spark.catalog.refreshTable(name) - # ucx[table-migrate:+1:4:+1:46] Can't migrate 'refreshTable' because its table name argument is not a constant + # ucx[table-migrate:+1:4:+1:46] Can't migrate 'spark.catalog.refreshTable(f'boop{stuff}')' because its table name argument cannot be computed spark.catalog.refreshTable(f"boop{stuff}") ## Some trivial references to the method or table in unrelated contexts that should not trigger warnigns. diff --git a/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-uncache-table.py b/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-uncache-table.py index ac3cc41ad7..4d65ef1c81 100644 --- a/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-uncache-table.py +++ b/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-uncache-table.py @@ -13,9 +13,9 @@ spark.catalog.uncacheTable("old.things", "extra-argument") ## Some calls that use a variable whose value is unknown: they could potentially reference a migrated table. 
- # ucx[table-migrate:+1:4:+1:36] Can't migrate 'uncacheTable' because its table name argument is not a constant + # ucx[table-migrate:+1:4:+1:36] Can't migrate 'spark.catalog.uncacheTable(name)' because its table name argument cannot be computed spark.catalog.uncacheTable(name) - # ucx[table-migrate:+1:4:+1:46] Can't migrate 'uncacheTable' because its table name argument is not a constant + # ucx[table-migrate:+1:4:+1:46] Can't migrate 'spark.catalog.uncacheTable(f'boop{stuff}')' because its table name argument cannot be computed spark.catalog.uncacheTable(f"boop{stuff}") ## Some trivial references to the method or table in unrelated contexts that should not trigger warnigns. diff --git a/tests/unit/source_code/samples/functional/pyspark/dataframe-write-insert-into.py b/tests/unit/source_code/samples/functional/pyspark/dataframe-write-insert-into.py index edf8ce75e7..6cca2c37b5 100644 --- a/tests/unit/source_code/samples/functional/pyspark/dataframe-write-insert-into.py +++ b/tests/unit/source_code/samples/functional/pyspark/dataframe-write-insert-into.py @@ -17,9 +17,9 @@ df.write.insertInto(overwrite=None, tableName="old.things") ## Some calls that use a variable whose value is unknown: they could potentially reference a migrated table. - # ucx[table-migrate:+1:4:+1:29] Can't migrate 'insertInto' because its table name argument is not a constant + # ucx[table-migrate:+1:4:+1:29] Can't migrate 'df.write.insertInto(name)' because its table name argument cannot be computed df.write.insertInto(name) - # ucx[table-migrate:+1:4:+1:39] Can't migrate 'insertInto' because its table name argument is not a constant + # ucx[table-migrate:+1:4:+1:39] Can't migrate 'df.write.insertInto(f'boop{stuff}')' because its table name argument cannot be computed df.write.insertInto(f"boop{stuff}") ## Some trivial references to the method or table in unrelated contexts that should not trigger warnigns. 
diff --git a/tests/unit/source_code/samples/functional/pyspark/dataframe-write-save-as-table.py b/tests/unit/source_code/samples/functional/pyspark/dataframe-write-save-as-table.py index 7ff0411ced..4a26db53c4 100644 --- a/tests/unit/source_code/samples/functional/pyspark/dataframe-write-save-as-table.py +++ b/tests/unit/source_code/samples/functional/pyspark/dataframe-write-save-as-table.py @@ -17,9 +17,9 @@ df.write.saveAsTable(format="xyz", name="old.things") ## Some calls that use a variable whose value is unknown: they could potentially reference a migrated table. - # ucx[table-migrate:+1:4:+1:46] Can't migrate 'saveAsTable' because its table name argument is not a constant + # ucx[table-migrate:+1:4:+1:46] Can't migrate 'df.write.format('delta').saveAsTable(name)' because its table name argument cannot be computed df.write.format("delta").saveAsTable(name) - # ucx[table-migrate:+1:4:+1:56] Can't migrate 'saveAsTable' because its table name argument is not a constant + # ucx[table-migrate:+1:4:+1:56] Can't migrate 'df.write.format('delta').saveAsTable(f'boop{stuff}')' because its table name argument cannot be computed df.write.format("delta").saveAsTable(f"boop{stuff}") ## Some trivial references to the method or table in unrelated contexts that should not trigger warnigns. diff --git a/tests/unit/source_code/samples/functional/pyspark/spark-table.py b/tests/unit/source_code/samples/functional/pyspark/spark-table.py index 3f8e204894..b4d96d52b0 100644 --- a/tests/unit/source_code/samples/functional/pyspark/spark-table.py +++ b/tests/unit/source_code/samples/functional/pyspark/spark-table.py @@ -20,12 +20,12 @@ do_stuff_with(df) ## Some calls that use a variable whose value is unknown: they could potentially reference a migrated table. 
- # ucx[table-migrate:+3:9:+3:26] Can't migrate 'table' because its table name argument is not a constant + # ucx[table-migrate:+3:9:+3:26] Can't migrate 'spark.table(name)' because its table name argument cannot be computed # TODO: Fix false positive: # ucx[table-migrate:+1:9:+1:26] The default format changed in Databricks Runtime 8.0, from Parquet to Delta df = spark.table(name) do_stuff_with(df) - # ucx[table-migrate:+3:9:+3:36] Can't migrate 'table' because its table name argument is not a constant + # ucx[table-migrate:+3:9:+3:36] Can't migrate 'spark.table(f'boop{stuff}')' because its table name argument cannot be computed # TODO: Fix false positive: # ucx[table-migrate:+1:9:+1:36] The default format changed in Databricks Runtime 8.0, from Parquet to Delta df = spark.table(f"boop{stuff}") From 4602ac12e7b85793bad25a9c023869d67772893b Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Thu, 13 Jun 2024 15:08:53 +0200 Subject: [PATCH 21/23] Infer table names from f-strings and improve Advice messages --- .../functional/pyspark/catalog/spark-catalog-table-exists.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-table-exists.py b/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-table-exists.py index c357492c84..56ec5b4c00 100644 --- a/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-table-exists.py +++ b/tests/unit/source_code/samples/functional/pyspark/catalog/spark-catalog-table-exists.py @@ -26,10 +26,10 @@ pass ## Some calls that use a variable whose value is unknown: they could potentially reference a migrated table. 
- # ucx[table-migrate:+1:7:+1:38] Can't migrate 'tableExists' because its table name argument is not a constant + # ucx[table-migrate:+1:7:+1:38] Can't migrate 'spark.catalog.tableExists(name)' because its table name argument cannot be computed if spark.catalog.tableExists(name): pass - # ucx[table-migrate:+1:7:+1:48] Can't migrate 'tableExists' because its table name argument is not a constant + # ucx[table-migrate:+1:7:+1:48] Can't migrate 'spark.catalog.tableExists(f'boot{stuff}')' because its table name argument cannot be computed if spark.catalog.tableExists(f"boot{stuff}"): pass From f8f9b214d97ff9cf0b307740c8b916bb1927513d Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Thu, 13 Jun 2024 15:29:57 +0200 Subject: [PATCH 22/23] formatting --- tests/integration/source_code/solacc.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/integration/source_code/solacc.py b/tests/integration/source_code/solacc.py index 8998d4a3b9..98ca0d33a9 100644 --- a/tests/integration/source_code/solacc.py +++ b/tests/integration/source_code/solacc.py @@ -47,7 +47,9 @@ def clone_all(): def lint_all(): # pylint: disable=too-many-nested-blocks ws = WorkspaceClient(host='...', token='...') - ctx = LocalCheckoutContext(ws).replace(linter_context_factory=lambda session_state: LinterContext(MigrationIndex([]), session_state)) + ctx = LocalCheckoutContext(ws).replace( + linter_context_factory=lambda session_state: LinterContext(MigrationIndex([]), session_state) + ) parseable = 0 missing_imports = 0 all_files = list(dist.glob('**/*.py')) From aa63ece8cfbf33ca58dffccc795366d8854c8d20 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Fri, 14 Jun 2024 09:48:05 +0200 Subject: [PATCH 23/23] increase test coverage --- tests/unit/source_code/test_queries.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/unit/source_code/test_queries.py b/tests/unit/source_code/test_queries.py index 0316170516..0e5fc628d5 100644 --- 
a/tests/unit/source_code/test_queries.py
+++ b/tests/unit/source_code/test_queries.py
@@ -77,3 +77,11 @@ def test_use_database_stops_migration(migration_index):
     )
     transformed_query = ftf.apply(old_query)
     assert transformed_query == new_query
+
+
+def test_parses_create_schema(migration_index):
+    """Linting a CREATE SCHEMA statement must produce no advice."""
+    state = CurrentSessionState(schema="old")
+    linter = FromTable(migration_index, session_state=state)
+    findings = list(linter.lint("CREATE SCHEMA xyz"))
+    assert not findings