Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support inferred values when linting sys path - saved #1863

Closed
51 changes: 38 additions & 13 deletions src/databricks/labs/ucx/source_code/graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import abc
import logging
from dataclasses import dataclass
from pathlib import Path
from collections.abc import Callable
Expand All @@ -13,10 +14,13 @@
ImportSource,
NotebookRunCall,
SysPathChange,
UnresolvedPath,
)
from databricks.labs.ucx.source_code.linters.python_ast import Tree, NodeBase
from databricks.labs.ucx.source_code.path_lookup import PathLookup

# Obtain the logger from the logging registry instead of instantiating Logger
# directly: a directly-built Logger has no parent, so handlers and levels
# configured on the root/package loggers never apply to it.
logger = logging.getLogger(__name__)


class DependencyGraph:

Expand Down Expand Up @@ -182,21 +186,42 @@ def build_graph_from_python_source(self, python_code: str) -> list[DependencyPro

def _process_node(self, base_node: NodeBase):
    """Dispatch one collected node to the handler for its type.

    Yields every item produced by the selected handler:
    ``SysPathChange`` -> ``_mutate_path_lookup``,
    ``NotebookRunCall`` -> ``_register_notebook``,
    ``ImportSource`` -> ``_register_import``.
    Any other node type yields nothing and is only logged.
    """
    if isinstance(base_node, SysPathChange):
        yield from self._mutate_path_lookup(base_node)
    elif isinstance(base_node, NotebookRunCall):
        yield from self._register_notebook(base_node)
    elif isinstance(base_node, ImportSource):
        yield from self._register_import(base_node)
    else:
        # Logged as a warning on purpose: error-level logs fail the workflow.
        logger.warning(f"Can't process {NodeBase.__name__} of type {type(base_node).__name__}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
logger.error(f"Can't process {NodeBase.__name__} of type {type(base_node).__name__}")
logger.warning(f"Can't process {NodeBase.__name__} of type {type(base_node).__name__}")

Error logs fail the workflow

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


def _register_import(self, base_node: ImportSource):
    """Register the import described by *base_node* and yield what registration yields.

    For an ``ImportFrom`` node with a non-None ``level``, the name is prefixed
    with one dot per level; a missing name is treated as an empty string.
    """
    node = base_node.node
    has_level = isinstance(node, ImportFrom) and node.level is not None
    dots = "." * node.level if has_level else ""
    module_name = base_node.name or ""
    yield from self.register_import(dots + module_name)

def _register_notebook(self, base_node: NotebookRunCall):
    """Register every notebook path obtained from *base_node*.

    String paths are registered one by one (yielding whatever registration
    yields). Non-string entries produce a single 'dependency-cannot-compute'
    problem for the whole call, no matter how many of them there are.
    """
    problem_reported = False
    for candidate in base_node.get_notebook_paths():
        if isinstance(candidate, str):
            yield from self.register_notebook(Path(candidate))
        elif not problem_reported:
            # Report the uncomputable expression only once per call.
            problem_reported = True
            yield DependencyProblem(
                'dependency-cannot-compute',
                f"Can't check dependency from {base_node.node.as_string()} because the expression cannot be computed",
            )

def _mutate_path_lookup(self, change: SysPathChange):
if isinstance(change, UnresolvedPath):
yield DependencyProblem(
'sys-path-cannot-compute',
f"Can't update sys.path from {change.node.as_string()} because the expression cannot be computed",
)
return
path = Path(change.path)
if not path.is_absolute():
path = self._path_lookup.cwd / path
Expand Down
75 changes: 51 additions & 24 deletions src/databricks/labs/ucx/source_code/linters/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
Attribute,
Call,
Const,
InferenceError,
Import,
ImportFrom,
Name,
NodeNG,
Uninferable,
)

from databricks.labs.ucx.source_code.base import Linter, Advice, Advisory
Expand Down Expand Up @@ -83,12 +85,23 @@ class NotebookRunCall(NodeBase):
def __init__(self, node: Call):
super().__init__(node)

def get_notebook_path(self) -> str | None:
    """Return the notebook path of this call, or None when it is not a constant.

    Only the first inferred value of the path argument is considered; its
    value is stripped of surrounding whitespace and quote characters.
    """
    path_arg = DbutilsLinter.get_dbutils_notebook_run_path_arg(cast(Call, self.node))
    first_inferred = next(path_arg.infer(), None)
    if not isinstance(first_inferred, Const):
        return None
    return first_inferred.value.strip().strip("'").strip('"')
def get_notebook_paths(self) -> list[str | None]:
    """Return one entry per inferred value of the notebook path argument.

    Entries are produced by ``_get_notebook_paths``. When inference raises
    ``InferenceError`` the failure is logged at debug level and ``[None]``
    is returned.
    """
    path_arg = DbutilsLinter.get_dbutils_notebook_run_path_arg(self.node)
    try:
        paths = self._get_notebook_paths(path_arg.infer())
    except InferenceError:
        logger.debug(f"Can't infer value(s) of {path_arg.as_string()}")
        paths = [None]
    return paths

@classmethod
def _get_notebook_paths(cls, nodes: Iterable[NodeNG]) -> list[str | None]:
    """Map inferred nodes to notebook paths, preserving order.

    A ``Const`` node becomes its source string with surrounding single and
    double quotes stripped; any other node becomes ``None``.
    """
    return [
        inferred.as_string().strip("'").strip('"') if isinstance(inferred, Const) else None
        for inferred in nodes
    ]


T = TypeVar("T", bound=Callable)
Expand All @@ -104,19 +117,20 @@ def lint(self, code: str) -> Iterable[Advice]:
@classmethod
def _convert_dbutils_notebook_run_to_advice(cls, node: NodeNG) -> Advisory:
assert isinstance(node, Call)
path = cls.get_dbutils_notebook_run_path_arg(node)
if isinstance(path, Const):
call = NotebookRunCall(cast(Call, node))
paths = call.get_notebook_paths()
if None in paths:
return Advisory(
'dbutils-notebook-run-literal',
"Call to 'dbutils.notebook.run' will be migrated automatically",
'dbutils-notebook-run-dynamic',
"Path for 'dbutils.notebook.run' is too complex and requires adjusting the notebook path(s)",
node.lineno,
node.col_offset,
node.end_lineno or 0,
node.end_col_offset or 0,
)
return Advisory(
'dbutils-notebook-run-dynamic',
"Path for 'dbutils.notebook.run' is not a constant and requires adjusting the notebook path",
'dbutils-notebook-run-literal',
"Call to 'dbutils.notebook.run' will be migrated automatically",
nfx marked this conversation as resolved.
Show resolved Hide resolved
node.lineno,
node.col_offset,
node.end_lineno or 0,
Expand Down Expand Up @@ -172,6 +186,11 @@ class RelativePath(SysPathChange):
pass


class UnresolvedPath(SysPathChange):
    """A path added to ``sys.path`` whose value cannot be inferred."""


class SysPathChangesVisitor(TreeVisitor):

def __init__(self):
Expand Down Expand Up @@ -204,10 +223,26 @@ def visit_call(self, node: Call):
return
is_append = func.attrname == "append"
changed = node.args[0] if is_append else node.args[1]
if isinstance(changed, Const):
self.sys_path_changes.append(AbsolutePath(node, changed.value, is_append))
elif isinstance(changed, Call):
self._visit_relative_path(changed, is_append)
relative = False
if isinstance(changed, Call):
if not self._match_aliases(changed.func, ["os", "path", "abspath"]):
return
relative = True
changed = changed.args[0]
try:
all_inferred = changed.inferred()
for inferred in all_inferred:
self._visit_inferred(changed, inferred, relative, is_append)
except InferenceError:
self.sys_path_changes.append(UnresolvedPath(changed, changed.as_string(), is_append))

def _visit_inferred(self, changed: NodeNG, inferred: NodeNG, is_relative: bool, is_append: bool):
    """Record one inferred value of a ``sys.path`` argument as a path change.

    A value that is ``Uninferable`` or not a ``Const`` is recorded as an
    ``UnresolvedPath`` carrying the source text of *changed*. Constant values
    are recorded as a ``RelativePath`` or ``AbsolutePath`` depending on
    *is_relative*.
    """
    if inferred is Uninferable or not isinstance(inferred, Const):
        self.sys_path_changes.append(UnresolvedPath(changed, changed.as_string(), is_append))
        # Stop here: falling through would record the non-constant value a
        # second time, as if it were a resolved path.
        return
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.sys_path_changes.append(UnresolvedPath(changed, changed.as_string(), is_append))
self.sys_path_changes.append(UnresolvedPath(changed, changed.as_string(), is_append))
return

Can we explicitly return, so that branches won't fall through accidentally?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

if is_relative:
self.sys_path_changes.append(RelativePath(changed, inferred.value, is_append))
else:
self.sys_path_changes.append(AbsolutePath(changed, inferred.value, is_append))

def _match_aliases(self, node: NodeNG, names: list[str]):
if isinstance(node, Attribute):
Expand All @@ -221,11 +256,3 @@ def _match_aliases(self, node: NodeNG, names: list[str]):
alias = self._aliases.get(full_name, full_name)
return node.name == alias
return False

def _visit_relative_path(self, node: Call, is_append: bool):
    """Record a ``RelativePath`` for an 'os.path.abspath' call on a constant.

    Nothing is recorded when the called function does not match
    'os.path.abspath' or when the first argument is not a ``Const``.
    """
    if self._match_aliases(node.func, ["os", "path", "abspath"]):
        target = node.args[0]
        if isinstance(target, Const):
            self.sys_path_changes.append(RelativePath(target, target.value, is_append))
15 changes: 9 additions & 6 deletions src/databricks/labs/ucx/source_code/notebooks/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
SourceContainer,
)
from databricks.labs.ucx.source_code.notebooks.cells import CellLanguage
from databricks.labs.ucx.source_code.notebooks.sources import Notebook
from databricks.labs.ucx.source_code.notebooks.sources import Notebook, SUPPORTED_EXTENSION_LANGUAGES
from databricks.labs.ucx.source_code.path_lookup import PathLookup

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -55,17 +55,20 @@ def load_dependency(self, path_lookup: PathLookup, dependency: Dependency) -> So
except NotFound:
logger.warning(f"Could not read notebook from workspace: {absolute_path}")
return None
language = self._detect_language(content)
language = self.detect_language(absolute_path, content)
if not language:
logger.warning(f"Could not detect language for {absolute_path}")
return None
return Notebook.parse(absolute_path, content, language)

@staticmethod
def _detect_language(content: str):
    """Return the language of the first ``CellLanguage`` whose file magic header starts *content*.

    Returns None when no header matches.
    """
    matching = (
        cell_language.language
        for cell_language in CellLanguage
        if content.startswith(cell_language.file_magic_header)
    )
    return next(matching, None)
def detect_language(path: Path, content: str):
    """Detect the language of a source from its file suffix, then from its content.

    The suffix of *path* is looked up in ``SUPPORTED_EXTENSION_LANGUAGES``
    first. If that yields nothing truthy, the language of the first
    ``CellLanguage`` whose file magic header starts *content* is returned.
    Returns None when neither matches.
    """
    from_suffix = SUPPORTED_EXTENSION_LANGUAGES.get(path.suffix)
    if from_suffix:
        return from_suffix
    from_header = (
        candidate.language
        for candidate in CellLanguage
        if content.startswith(candidate.file_magic_header)
    )
    return next(from_header, None)

@staticmethod
Expand Down
123 changes: 123 additions & 0 deletions tests/unit/source_code/linters/test_python_ast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import functools
import operator

import pytest
from astroid import Attribute, Call, Const, Expr # type: ignore

from databricks.labs.ucx.source_code.linters.imports import DbutilsLinter
from databricks.labs.ucx.source_code.linters.python_ast import Tree


def test_extract_call_by_name():
    """A call named ``m2`` is found inside a chained call expression."""
    statement = Tree.parse("o.m1().m2().m3()").first_statement()
    assert isinstance(statement, Expr)
    outer_call = statement.value
    assert isinstance(outer_call, Call)
    found = Tree.extract_call_by_name(outer_call, "m2")
    assert isinstance(found, Call)
    assert isinstance(found.func, Attribute)
    assert found.func.attrname == "m2"


def test_extract_call_by_name_none():
    """Looking up a call name absent from the chain returns None."""
    statement = Tree.parse("o.m1().m2().m3()").first_statement()
    assert isinstance(statement, Expr)
    outer_call = statement.value
    assert isinstance(outer_call, Call)
    assert Tree.extract_call_by_name(outer_call, "m5000") is None


@pytest.mark.parametrize(
    "code, arg_index, arg_name, expected",
    [
        ("o.m1()", 1, "second", None),
        ("o.m1(3)", 1, "second", None),
        ("o.m1(first=3)", 1, "second", None),
        ("o.m1(4, 3)", None, None, None),
        ("o.m1(4, 3)", None, "second", None),
        ("o.m1(4, 3)", 1, "second", 3),
        ("o.m1(4, 3)", 1, None, 3),
        ("o.m1(first=4, second=3)", 1, "second", 3),
        ("o.m1(second=3, first=4)", 1, "second", 3),
        ("o.m1(second=3, first=4)", None, "second", 3),
        ("o.m1(second=3)", 1, "second", 3),
        ("o.m1(4, 3, 2)", 1, "second", 3),
    ],
)
def test_linter_gets_arg(code, arg_index, arg_name, expected):
    """``Tree.get_arg`` returns the expected constant argument, or None when there is none."""
    statement = Tree.parse(code).first_statement()
    assert isinstance(statement, Expr)
    assert isinstance(statement.value, Call)
    actual = Tree.get_arg(statement.value, arg_index, arg_name)
    if expected is None:
        assert actual is None
        return
    assert isinstance(actual, Const)
    assert actual.value == expected


@pytest.mark.parametrize(
    "code, expected",
    [
        ("o.m1()", 0),
        ("o.m1(3)", 1),
        ("o.m1(first=3)", 1),
        ("o.m1(3, 3)", 2),
        ("o.m1(first=3, second=3)", 2),
        ("o.m1(3, second=3)", 2),
        ("o.m1(3, *b, **c, second=3)", 4),
    ],
)
def test_args_count(code, expected):
    """``Tree.args_count`` returns the expected argument count for each call in the table."""
    statement = Tree.parse(code).first_statement()
    assert isinstance(statement, Expr)
    call = statement.value
    assert isinstance(call, Call)
    assert Tree.args_count(call) == expected


def test_tree_walks_nodes_once():
    """Walking a tree never yields the same node twice."""
    walked = list(Tree.parse("o.m1().m2().m3()").walk())
    # Duplicates would collapse in the set and make it smaller than the walk.
    assert len(set(walked)) == len(walked)


@pytest.mark.parametrize(
    "code, expected",
    [
        (
            """
name = "xyz"
dbutils.notebook.run(name)
""",
            ["xyz"],
        ),
        (
            """
name = "xyz" + "-" + "abc"
dbutils.notebook.run(name)
""",
            ["xyz-abc"],
        ),
        (
            # The loop body must be indented, otherwise the snippet is not
            # valid Python and cannot be parsed.
            """
names = ["abc", "xyz"]
for name in names:
    dbutils.notebook.run(name)
""",
            ["abc", "xyz"],
        ),
    ],
)
def test_infers_dbutils_notebook_run_dynamic_value(code, expected):
    """Notebook paths passed through variables are inferred for every ``dbutils.notebook.run`` call."""
    tree = Tree.parse(code)
    calls = DbutilsLinter.list_dbutils_notebook_run_calls(tree)
    # Flatten the per-call path lists into one list, in call order.
    actual = functools.reduce(operator.iconcat, (call.get_notebook_paths() for call in calls), [])
    assert expected == actual
Loading