Infer values across notebook cells (#1968)
## Changes
When linting Python code, infer values using not only the code from the current cell but also the code from previous cells.
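
For example, with cross-cell inference a value defined in an earlier cell can now be resolved while linting a later cell (hypothetical notebook cells; the names are illustrative only):

```python
# Cell 1
table_name = "hive_metastore.sales.transactions"

# Cell 2 -- previously the value of `table_name` could not be computed here;
# with cross-cell inference it resolves to the literal defined in Cell 1
df = spark.table(table_name)
```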

### Linked issues
Progresses #1912
Progresses #1205 

### Functionality 
None

### Tests
- [x] manually tested
- [x] added unit tests

Resolved 60 out of 891 "cannot be computed" advices when running `make solacc`.

---------

Co-authored-by: Eric Vergnaud <[email protected]>
ericvergnaud authored Jul 5, 2024
1 parent 6198a28 commit eae95ad
Showing 15 changed files with 608 additions and 398 deletions.
32 changes: 31 additions & 1 deletion src/databricks/labs/ucx/source_code/base.py
@@ -6,10 +6,11 @@
from dataclasses import dataclass
from pathlib import Path

from astroid import NodeNG # type: ignore
from astroid import AstroidSyntaxError, NodeNG # type: ignore

from databricks.sdk.service import compute

from databricks.labs.ucx.source_code.linters.python_ast import Tree

# Code mapping between LSP, PyLint, and our own diagnostics:
# | LSP | PyLint | Our |
@@ -130,6 +131,16 @@ class Linter:
    def lint(self, code: str) -> Iterable[Advice]: ...


class PythonLinter(Linter):

    def lint(self, code: str) -> Iterable[Advice]:
        tree = Tree.normalize_and_parse(code)
        yield from self.lint_tree(tree)

    @abstractmethod
    def lint_tree(self, tree: Tree) -> Iterable[Advice]: ...


class Fixer:
    @abstractmethod
    def name(self) -> str: ...
@@ -170,3 +181,22 @@ def __init__(self, linters: list[Linter]):
    def lint(self, code: str) -> Iterable[Advice]:
        for linter in self._linters:
            yield from linter.lint(code)


class PythonSequentialLinter(Linter):

    def __init__(self, linters: list[PythonLinter]):
        self._linters = linters
        self._tree: Tree | None = None

    def lint(self, code: str) -> Iterable[Advice]:
        try:
            tree = Tree.normalize_and_parse(code)
            if self._tree is None:
                self._tree = tree
            else:
                tree = self._tree.append_statements(tree)
            for linter in self._linters:
                yield from linter.lint_tree(tree)
        except AstroidSyntaxError as e:
            yield Failure('syntax-error', str(e), 0, 0, 0, 0)
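
A minimal usage sketch of the new `PythonSequentialLinter` (assuming `lint()` is called once per notebook cell, in order; the cell contents and the choice of `DBFSUsageLinter` are illustrative, and the advices actually produced depend on each linter's rules):

```python
from databricks.labs.ucx.source_code.base import CurrentSessionState, PythonSequentialLinter
from databricks.labs.ucx.source_code.linters.dbfs import DBFSUsageLinter

session_state = CurrentSessionState()
linter = PythonSequentialLinter([DBFSUsageLinter(session_state)])

# Cell 1: only defines a value; its statements are kept in the linter's tree
list(linter.lint('path = "/dbfs/mnt/landing/input.csv"'))

# Cell 2: cell 1's statements were appended to the shared tree, so the
# linter can infer the value of `path` when it visits this call
advices = list(linter.lint("spark.read.csv(path)"))
```
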
61 changes: 58 additions & 3 deletions src/databricks/labs/ucx/source_code/known.json
@@ -1265,7 +1265,7 @@
"code": "dbfs-usage",
"message": "Deprecated file system path: dbfs:/"
},
{
{
"code": "table-migrate",
"message": "The default format changed in Databricks Runtime 8.0, from Parquet to Delta"
}
@@ -2572,6 +2572,14 @@
"dockerpycreds.utils": [],
"dockerpycreds.version": []
},
"docstring-to-markdown": {
"docstring_to_markdown": [],
"docstring_to_markdown._utils": [],
"docstring_to_markdown.cpython": [],
"docstring_to_markdown.google": [],
"docstring_to_markdown.plain": [],
"docstring_to_markdown.rst": []
},
"entrypoints": {
"entrypoints": []
},
@@ -21782,6 +21790,53 @@
"python-dateutil": {
"dateutil": []
},
"python-lsp-jsonrpc": {
"pylsp_jsonrpc": [],
"pylsp_jsonrpc._version": [],
"pylsp_jsonrpc.dispatchers": [],
"pylsp_jsonrpc.endpoint": [],
"pylsp_jsonrpc.exceptions": [],
"pylsp_jsonrpc.streams": []
},
"python-lsp-server": {
"pylsp": [],
"pylsp._utils": [],
"pylsp._version": [],
"pylsp.config": [],
"pylsp.config.config": [],
"pylsp.config.flake8_conf": [],
"pylsp.config.pycodestyle_conf": [],
"pylsp.config.source": [],
"pylsp.hookspecs": [],
"pylsp.lsp": [],
"pylsp.plugins": [],
"pylsp.plugins._resolvers": [],
"pylsp.plugins._rope_task_handle": [],
"pylsp.plugins.autopep8_format": [],
"pylsp.plugins.definition": [],
"pylsp.plugins.flake8_lint": [],
"pylsp.plugins.folding": [],
"pylsp.plugins.highlight": [],
"pylsp.plugins.hover": [],
"pylsp.plugins.jedi_completion": [],
"pylsp.plugins.jedi_rename": [],
"pylsp.plugins.mccabe_lint": [],
"pylsp.plugins.preload_imports": [],
"pylsp.plugins.pycodestyle_lint": [],
"pylsp.plugins.pydocstyle_lint": [],
"pylsp.plugins.pyflakes_lint": [],
"pylsp.plugins.pylint_lint": [],
"pylsp.plugins.references": [],
"pylsp.plugins.rope_autoimport": [],
"pylsp.plugins.rope_completion": [],
"pylsp.plugins.signature": [],
"pylsp.plugins.symbols": [],
"pylsp.plugins.yapf_format": [],
"pylsp.python_lsp": [],
"pylsp.text_edit": [],
"pylsp.uris": [],
"pylsp.workspace": []
},
"pytz": {
"pytz": []
},
@@ -25156,6 +25211,7 @@
"tzdata": {
"tzdata": []
},
"ujson": {},
"umap": {
"umap": [],
"umap.get": []
@@ -25957,5 +26013,4 @@
"zipp.compat.py310": [],
"zipp.glob": []
}
}

}
23 changes: 17 additions & 6 deletions src/databricks/labs/ucx/source_code/linters/context.py
@@ -1,7 +1,16 @@
from typing import cast

from databricks.sdk.service.workspace import Language

from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex
from databricks.labs.ucx.source_code.base import Fixer, Linter, SequentialLinter, CurrentSessionState
from databricks.labs.ucx.source_code.base import (
    Fixer,
    Linter,
    SequentialLinter,
    CurrentSessionState,
    PythonSequentialLinter,
    PythonLinter,
)
from databricks.labs.ucx.source_code.linters.dbfs import FromDbfsFolder, DBFSUsageLinter
from databricks.labs.ucx.source_code.linters.imports import DbutilsLinter

@@ -16,7 +25,7 @@ def __init__(self, index: MigrationIndex | None = None, session_state: CurrentSe
        self._index = index
        session_state = CurrentSessionState() if not session_state else session_state

        python_linters: list[Linter] = []
        python_linters: list[PythonLinter] = []
        python_fixers: list[Fixer] = []

        sql_linters: list[Linter] = []
@@ -38,9 +47,9 @@ def __init__(self, index: MigrationIndex | None = None, session_state: CurrentSe
        ]
        sql_linters.append(FromDbfsFolder())

        self._linters = {
            Language.PYTHON: SequentialLinter(python_linters),
            Language.SQL: SequentialLinter(sql_linters),
        self._linters: dict[Language, list[Linter] | list[PythonLinter]] = {
            Language.PYTHON: python_linters,
            Language.SQL: sql_linters,
        }
        self._fixers: dict[Language, list[Fixer]] = {
            Language.PYTHON: python_fixers,
@@ -53,7 +62,9 @@ def is_supported(self, language: Language) -> bool:
    def linter(self, language: Language) -> Linter:
        if language not in self._linters:
            raise ValueError(f"Unsupported language: {language}")
        return self._linters[language]
        if language is Language.PYTHON:
            return PythonSequentialLinter(cast(list[PythonLinter], self._linters[language]))
        return SequentialLinter(cast(list[Linter], self._linters[language]))

    def fixer(self, language: Language, diagnostic_code: str) -> Fixer | None:
        if language not in self._fixers:
12 changes: 6 additions & 6 deletions src/databricks/labs/ucx/source_code/linters/dbfs.py
@@ -5,8 +5,9 @@
from sqlglot import Expression, parse as parse_sql, ParseError as SqlParseError
from sqlglot.expressions import Table

from databricks.labs.ucx.source_code.base import Advice, Linter, Deprecation, CurrentSessionState, Failure
from databricks.labs.ucx.source_code.linters.python_ast import Tree, TreeVisitor, InferredValue
from databricks.labs.ucx.source_code.base import Advice, Linter, Deprecation, CurrentSessionState, Failure, PythonLinter
from databricks.labs.ucx.source_code.linters.python_ast import Tree, TreeVisitor
from databricks.labs.ucx.source_code.linters.python_infer import InferredValue

logger = logging.getLogger(__name__)

@@ -29,7 +30,7 @@ def visit_call(self, node: Call):

    def _visit_arg(self, arg: NodeNG):
        try:
            for inferred in Tree(arg).infer_values(self._session_state):
            for inferred in InferredValue.infer_from_node(arg, self._session_state):
                if not inferred.is_inferred():
                    logger.debug(f"Could not infer value of {arg.as_string()}")
                    continue
@@ -64,7 +65,7 @@ def get_advices(self) -> Iterable[Advice]:
        yield from self._advices


class DBFSUsageLinter(Linter):
class DBFSUsageLinter(PythonLinter):

    def __init__(self, session_state: CurrentSessionState):
        self._session_state = session_state
@@ -76,11 +77,10 @@ def name() -> str:
        """
        return 'dbfs-usage'

    def lint(self, code: str) -> Iterable[Advice]:
    def lint_tree(self, tree: Tree) -> Iterable[Advice]:
        """
        Lints the code looking for file system paths that are deprecated
        """
        tree = Tree.normalize_and_parse(code)
        visitor = DetectDbfsVisitor(self._session_state)
        visitor.visit(tree.node)
        yield from visitor.get_advices()
14 changes: 7 additions & 7 deletions src/databricks/labs/ucx/source_code/linters/imports.py
@@ -16,8 +16,9 @@
    NodeNG,
)

from databricks.labs.ucx.source_code.base import Linter, Advice, Advisory, CurrentSessionState
from databricks.labs.ucx.source_code.linters.python_ast import Tree, NodeBase, TreeVisitor, InferredValue
from databricks.labs.ucx.source_code.base import Advice, Advisory, CurrentSessionState, PythonLinter
from databricks.labs.ucx.source_code.linters.python_ast import Tree, NodeBase, TreeVisitor
from databricks.labs.ucx.source_code.linters.python_infer import InferredValue

logger = logging.getLogger(__name__)

@@ -90,7 +91,7 @@ def get_notebook_paths(self, session_state: CurrentSessionState) -> tuple[bool,
        """
        arg = DbutilsLinter.get_dbutils_notebook_run_path_arg(self.node)
        try:
            all_inferred = Tree(arg).infer_values(session_state)
            all_inferred = InferredValue.infer_from_node(arg, session_state)
            return self._get_notebook_paths(all_inferred)
        except InferenceError:
            logger.debug(f"Can't infer value(s) of {arg.as_string()}")
@@ -110,13 +111,12 @@ def _get_notebook_paths(cls, all_inferred: Iterable[InferredValue]) -> tuple[boo
        return has_unresolved, paths


class DbutilsLinter(Linter):
class DbutilsLinter(PythonLinter):

    def __init__(self, session_state: CurrentSessionState):
        self._session_state = session_state

    def lint(self, code: str) -> Iterable[Advice]:
        tree = Tree.normalize_and_parse(code)
    def lint_tree(self, tree: Tree) -> Iterable[Advice]:
        nodes = self.list_dbutils_notebook_run_calls(tree)
        for node in nodes:
            yield from self._raise_advice_if_unresolved(node.node, self._session_state)
@@ -229,7 +229,7 @@ def visit_call(self, node: Call):
            relative = True
            changed = changed.args[0]
        try:
            for inferred in Tree(changed).infer_values(self._session_state):
            for inferred in InferredValue.infer_from_node(changed, self._session_state):
                self._visit_inferred(changed, inferred, relative, is_append)
        except InferenceError:
            self.sys_path_changes.append(UnresolvedPath(changed, changed.as_string(), is_append))
21 changes: 8 additions & 13 deletions src/databricks/labs/ucx/source_code/linters/pyspark.py
@@ -2,19 +2,19 @@
from collections.abc import Iterable, Iterator
from dataclasses import dataclass

from astroid import Attribute, Call, Const, InferenceError, NodeNG, AstroidSyntaxError # type: ignore
from astroid import Attribute, Call, Const, InferenceError, NodeNG # type: ignore
from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex
from databricks.labs.ucx.source_code.base import (
    Advice,
    Advisory,
    Deprecation,
    Fixer,
    Linter,
    Failure,
    CurrentSessionState,
    PythonLinter,
)
from databricks.labs.ucx.source_code.linters.python_infer import InferredValue
from databricks.labs.ucx.source_code.queries import FromTable
from databricks.labs.ucx.source_code.linters.python_ast import Tree, InferredValue
from databricks.labs.ucx.source_code.linters.python_ast import Tree


@dataclass
@@ -78,7 +78,7 @@ def lint(
        table_arg = self._get_table_arg(node)
        if table_arg:
            try:
                for inferred in Tree(table_arg).infer_values(self.session_state):
                for inferred in InferredValue.infer_from_node(table_arg, self.session_state):
                    yield from self._lint_table_arg(from_table, node, inferred)
            except InferenceError:
                yield Advisory.from_node(
@@ -114,7 +114,7 @@ def lint(
    ) -> Iterator[Advice]:
        table_arg = self._get_table_arg(node)
        table_name = table_arg.as_string().strip("'").strip('"')
        for inferred in Tree(table_arg).infer_values(session_state):
        for inferred in InferredValue.infer_from_node(table_arg, session_state):
            if not inferred.is_inferred():
                yield Advisory.from_node(
                    code='table-migrate-cannot-compute-value',
@@ -315,7 +315,7 @@ def matchers(self):
        return self._matchers


class SparkSql(Linter, Fixer):
class SparkSql(PythonLinter, Fixer):

    _spark_matchers = SparkMatchers()

@@ -328,12 +328,7 @@ def name(self) -> str:
        # this is the same fixer, just in a different language context
        return self._from_table.name()

    def lint(self, code: str) -> Iterable[Advice]:
        try:
            tree = Tree.normalize_and_parse(code)
        except AstroidSyntaxError as e:
            yield Failure('syntax-error', str(e), 0, 0, 0, 0)
            return
    def lint_tree(self, tree: Tree) -> Iterable[Advice]:
        for node in tree.walk():
            matcher = self._find_matcher(node)
            if matcher is None:
(Diffs for the remaining changed files are not shown.)