Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infer values across notebook cells #1968

Merged
merged 7 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion src/databricks/labs/ucx/source_code/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
from dataclasses import dataclass
from pathlib import Path

from astroid import NodeNG # type: ignore
from astroid import AstroidSyntaxError, NodeNG # type: ignore

from databricks.sdk.service import compute

from databricks.labs.ucx.source_code.linters.python_ast import Tree

# Code mapping between LSP, PyLint, and our own diagnostics:
# | LSP | PyLint | Our |
Expand Down Expand Up @@ -130,6 +131,16 @@ class Linter:
def lint(self, code: str) -> Iterable[Advice]: ...


class PythonLinter(Linter):
    """Base class for linters that work on a parsed Python ``Tree``.

    Subclasses implement :meth:`lint_tree`; :meth:`lint` takes raw source,
    parses it and delegates to ``lint_tree``.
    """

    def lint(self, code: str) -> Iterable[Advice]:
        """Parse ``code`` and yield the advices produced by ``lint_tree``.

        Source that cannot be parsed is reported as a single
        ``syntax-error`` failure rather than raising, matching what
        ``PythonSequentialLinter.lint`` does for the same error.
        """
        # Only the parse is guarded: errors raised while linting the tree
        # are left to propagate to the caller.
        try:
            tree = Tree.normalize_and_parse(code)
        except AstroidSyntaxError as e:
            yield Failure('syntax-error', str(e), 0, 0, 0, 0)
            return
        yield from self.lint_tree(tree)

    @abstractmethod
    def lint_tree(self, tree: Tree) -> Iterable[Advice]:
        """Yield advices for an already-parsed ``tree``."""


class Fixer:
@abstractmethod
def name(self) -> str: ...
Expand Down Expand Up @@ -170,3 +181,22 @@ def __init__(self, linters: list[Linter]):
def lint(self, code: str) -> Iterable[Advice]:
for linter in self._linters:
yield from linter.lint(code)


class PythonSequentialLinter(Linter):
    """Runs several ``PythonLinter`` instances over one shared parsed tree.

    The tree from the first ``lint`` call is retained; the tree parsed by
    each later call is passed to ``append_statements`` on the retained tree,
    and the linters run on whatever that returns.
    """

    def __init__(self, linters: list[PythonLinter]):
        self._linters = linters
        # Tree retained from the first successful parse; None until then.
        self._tree: Tree | None = None

    def lint(self, code: str) -> Iterable[Advice]:
        """Parse ``code``, combine it with the retained tree and lint it.

        An ``AstroidSyntaxError`` raised while parsing or linting is
        reported as a single ``syntax-error`` failure.
        """
        try:
            parsed = Tree.normalize_and_parse(code)
            previous = self._tree
            if previous is not None:
                # Later call: lint the result of appending to the retained tree.
                combined = previous.append_statements(parsed)
            else:
                # First call: retain this tree for subsequent calls.
                self._tree = parsed
                combined = parsed
            for python_linter in self._linters:
                yield from python_linter.lint_tree(combined)
        except AstroidSyntaxError as error:
            yield Failure('syntax-error', str(error), 0, 0, 0, 0)
10 changes: 8 additions & 2 deletions src/databricks/labs/ucx/source_code/linters/context.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from databricks.sdk.service.workspace import Language

from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex
from databricks.labs.ucx.source_code.base import Fixer, Linter, SequentialLinter, CurrentSessionState
from databricks.labs.ucx.source_code.base import (
Fixer,
Linter,
SequentialLinter,
CurrentSessionState,
PythonSequentialLinter,
)
from databricks.labs.ucx.source_code.linters.dbfs import FromDbfsFolder, DBFSUsageLinter
from databricks.labs.ucx.source_code.linters.imports import DbutilsLinter

Expand All @@ -18,7 +24,7 @@ def __init__(self, index: MigrationIndex, session_state: CurrentSessionState | N
from_table = FromTable(index, session_state=session_state)
dbfs_from_folder = FromDbfsFolder()
self._linters = {
Language.PYTHON: SequentialLinter(
Language.PYTHON: PythonSequentialLinter(
[
SparkSql(from_table, index, session_state),
DBFSUsageLinter(session_state),
Expand Down
12 changes: 6 additions & 6 deletions src/databricks/labs/ucx/source_code/linters/dbfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from sqlglot import Expression, parse as parse_sql, ParseError as SqlParseError
from sqlglot.expressions import Table

from databricks.labs.ucx.source_code.base import Advice, Linter, Deprecation, CurrentSessionState, Failure
from databricks.labs.ucx.source_code.linters.python_ast import Tree, TreeVisitor, InferredValue
from databricks.labs.ucx.source_code.base import Advice, Linter, Deprecation, CurrentSessionState, Failure, PythonLinter
from databricks.labs.ucx.source_code.linters.python_ast import Tree, TreeVisitor
from databricks.labs.ucx.source_code.linters.python_infer import InferredValue

logger = logging.getLogger(__name__)

Expand All @@ -29,7 +30,7 @@ def visit_call(self, node: Call):

def _visit_arg(self, arg: NodeNG):
try:
for inferred in Tree(arg).infer_values(self._session_state):
for inferred in InferredValue.infer_from_node(arg, self._session_state):
if not inferred.is_inferred():
logger.debug(f"Could not infer value of {arg.as_string()}")
continue
Expand Down Expand Up @@ -64,7 +65,7 @@ def get_advices(self) -> Iterable[Advice]:
yield from self._advices


class DBFSUsageLinter(Linter):
class DBFSUsageLinter(PythonLinter):

def __init__(self, session_state: CurrentSessionState):
self._session_state = session_state
Expand All @@ -76,11 +77,10 @@ def name() -> str:
"""
return 'dbfs-usage'

def lint(self, code: str) -> Iterable[Advice]:
def lint_tree(self, tree: Tree) -> Iterable[Advice]:
"""
Lints the code looking for file system paths that are deprecated
"""
tree = Tree.normalize_and_parse(code)
visitor = DetectDbfsVisitor(self._session_state)
visitor.visit(tree.node)
yield from visitor.get_advices()
Expand Down
14 changes: 7 additions & 7 deletions src/databricks/labs/ucx/source_code/linters/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
NodeNG,
)

from databricks.labs.ucx.source_code.base import Linter, Advice, Advisory, CurrentSessionState
from databricks.labs.ucx.source_code.linters.python_ast import Tree, NodeBase, TreeVisitor, InferredValue
from databricks.labs.ucx.source_code.base import Advice, Advisory, CurrentSessionState, PythonLinter
from databricks.labs.ucx.source_code.linters.python_ast import Tree, NodeBase, TreeVisitor
from databricks.labs.ucx.source_code.linters.python_infer import InferredValue

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -90,7 +91,7 @@ def get_notebook_paths(self, session_state: CurrentSessionState) -> tuple[bool,
"""
arg = DbutilsLinter.get_dbutils_notebook_run_path_arg(self.node)
try:
all_inferred = Tree(arg).infer_values(session_state)
all_inferred = InferredValue.infer_from_node(arg, session_state)
return self._get_notebook_paths(all_inferred)
except InferenceError:
logger.debug(f"Can't infer value(s) of {arg.as_string()}")
Expand All @@ -110,13 +111,12 @@ def _get_notebook_paths(cls, all_inferred: Iterable[InferredValue]) -> tuple[boo
return has_unresolved, paths


class DbutilsLinter(Linter):
class DbutilsLinter(PythonLinter):

def __init__(self, session_state: CurrentSessionState):
self._session_state = session_state

def lint(self, code: str) -> Iterable[Advice]:
tree = Tree.normalize_and_parse(code)
def lint_tree(self, tree: Tree) -> Iterable[Advice]:
nodes = self.list_dbutils_notebook_run_calls(tree)
for node in nodes:
yield from self._raise_advice_if_unresolved(node.node, self._session_state)
Expand Down Expand Up @@ -229,7 +229,7 @@ def visit_call(self, node: Call):
relative = True
changed = changed.args[0]
try:
for inferred in Tree(changed).infer_values(self._session_state):
for inferred in InferredValue.infer_from_node(changed, self._session_state):
self._visit_inferred(changed, inferred, relative, is_append)
except InferenceError:
self.sys_path_changes.append(UnresolvedPath(changed, changed.as_string(), is_append))
Expand Down
21 changes: 8 additions & 13 deletions src/databricks/labs/ucx/source_code/linters/pyspark.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@
from collections.abc import Iterable, Iterator
from dataclasses import dataclass

from astroid import Attribute, Call, Const, InferenceError, NodeNG, AstroidSyntaxError # type: ignore
from astroid import Attribute, Call, Const, InferenceError, NodeNG # type: ignore
from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex
from databricks.labs.ucx.source_code.base import (
Advice,
Advisory,
Deprecation,
Fixer,
Linter,
Failure,
CurrentSessionState,
PythonLinter,
)
from databricks.labs.ucx.source_code.linters.python_infer import InferredValue
from databricks.labs.ucx.source_code.queries import FromTable
from databricks.labs.ucx.source_code.linters.python_ast import Tree, InferredValue
from databricks.labs.ucx.source_code.linters.python_ast import Tree


@dataclass
Expand Down Expand Up @@ -78,7 +78,7 @@ def lint(
table_arg = self._get_table_arg(node)
if table_arg:
try:
for inferred in Tree(table_arg).infer_values(self.session_state):
for inferred in InferredValue.infer_from_node(table_arg, self.session_state):
yield from self._lint_table_arg(from_table, node, inferred)
except InferenceError:
yield Advisory.from_node(
Expand Down Expand Up @@ -114,7 +114,7 @@ def lint(
) -> Iterator[Advice]:
table_arg = self._get_table_arg(node)
table_name = table_arg.as_string().strip("'").strip('"')
for inferred in Tree(table_arg).infer_values(session_state):
for inferred in InferredValue.infer_from_node(table_arg, session_state):
if not inferred.is_inferred():
yield Advisory.from_node(
code='table-migrate',
Expand Down Expand Up @@ -315,7 +315,7 @@ def matchers(self):
return self._matchers


class SparkSql(Linter, Fixer):
class SparkSql(PythonLinter, Fixer):

_spark_matchers = SparkMatchers()

Expand All @@ -328,12 +328,7 @@ def name(self) -> str:
# this is the same fixer, just in a different language context
return self._from_table.name()

def lint(self, code: str) -> Iterable[Advice]:
try:
tree = Tree.normalize_and_parse(code)
except AstroidSyntaxError as e:
yield Failure('syntax-error', str(e), 0, 0, 0, 0)
return
def lint_tree(self, tree: Tree) -> Iterable[Advice]:
for node in tree.walk():
matcher = self._find_matcher(node)
if matcher is None:
Expand Down
Loading
Loading