Skip to content

Commit

Permalink
Migrate Python linters from ast (standard library) to astroid package (#1835)
Browse files Browse the repository at this point in the history

## Changes
Migrate the Python linters from the standard-library `ast` module to the `astroid` package.
Implement minimal inference.

### Linked issues
Progresses #1205 

### Functionality 

- [ ] added relevant user documentation
- [ ] added new CLI command
- [ ] modified existing command: `databricks labs ucx ...`
- [ ] added a new workflow
- [ ] modified existing workflow: `...`
- [ ] added a new table
- [ ] modified existing table: `...`

### Tests
- [ ] manually tested
- [x] added unit tests
- [ ] added integration tests
- [ ] verified on staging environment (screenshot attached)

---------

Co-authored-by: Eric Vergnaud <[email protected]>
  • Loading branch information
ericvergnaud and ericvergnaud authored Jun 6, 2024
1 parent 20474c3 commit f346f0a
Show file tree
Hide file tree
Showing 22 changed files with 526 additions and 449 deletions.
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ dependencies = ["databricks-sdk>=0.27,<0.29",
"databricks-labs-lsql~=0.4.0",
"databricks-labs-blueprint>=0.6.0",
"PyYAML>=6.0.0,<7.0.0",
"sqlglot>=23.9,<24.2"]
"sqlglot>=23.9,<24.2",
"astroid>=3.2.2"]

[project.entry-points.databricks]
runtime = "databricks.labs.ucx.runtime:main"
Expand All @@ -65,7 +66,7 @@ dependencies = [
"black~=24.3.0",
"coverage[toml]~=7.4.4",
"mypy~=1.9.0",
"pylint~=3.1.0",
"pylint~=3.2.2",
"pylint-pytest==2.0.0a0",
"databricks-labs-pylint~=0.4.0",
"pytest~=8.1.0",
Expand Down
6 changes: 3 additions & 3 deletions src/databricks/labs/ucx/source_code/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from databricks.labs.ucx.source_code.linters.imports import (
ASTLinter,
DbutilsLinter,
SysPathChange,
NotebookRunCall,
ImportSource,
NodeBase,
NotebookRunCall,
SysPathChange,
)
from databricks.labs.ucx.source_code.path_lookup import PathLookup

Expand Down Expand Up @@ -186,7 +186,7 @@ def _process_node(self, base_node: NodeBase):
if isinstance(base_node, SysPathChange):
self._mutate_path_lookup(base_node)
if isinstance(base_node, NotebookRunCall):
strpath = base_node.get_constant_path()
strpath = base_node.get_notebook_path()
if strpath is None:
yield DependencyProblem('dependency-not-constant', "Can't check dependency not provided as a constant")
else:
Expand Down
53 changes: 31 additions & 22 deletions src/databricks/labs/ucx/source_code/linters/ast_helpers.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,39 @@
import ast
import logging

from astroid import Attribute, Call, Name # type: ignore

class AstHelper:
@staticmethod
def get_full_attribute_name(node: ast.Attribute) -> str:
return AstHelper._get_value(node)
logger = logging.getLogger(__file__)

@staticmethod
def get_full_function_name(node: ast.Call) -> str | None:
if isinstance(node.func, ast.Attribute):
return AstHelper._get_value(node.func)
missing_handlers: set[str] = set()

if isinstance(node.func, ast.Name):
return node.func.id

class AstHelper:
@classmethod
def get_full_attribute_name(cls, node: Attribute) -> str:
return cls._get_attribute_value(node)

@classmethod
def get_full_function_name(cls, node: Call) -> str | None:
if not isinstance(node, Call):
return None
if isinstance(node.func, Attribute):
return cls._get_attribute_value(node.func)
if isinstance(node.func, Name):
return node.func.name
return None

@staticmethod
def _get_value(node: ast.Attribute):
if isinstance(node.value, ast.Name):
return node.value.id + '.' + node.attr

if isinstance(node.value, ast.Attribute):
value = AstHelper._get_value(node.value)
if not value:
return None
return value + '.' + node.attr

@classmethod
def _get_attribute_value(cls, node: Attribute):
if isinstance(node.expr, Name):
return node.expr.name + '.' + node.attrname
if isinstance(node.expr, Attribute):
parent = cls._get_attribute_value(node.expr)
return node.attrname if parent is None else parent + '.' + node.attrname
if isinstance(node.expr, Call):
name = cls.get_full_function_name(node.expr)
return node.attrname if name is None else name + '.' + node.attrname
name = type(node.expr).__name__
if name not in missing_handlers:
missing_handlers.add(name)
logger.debug(f"Missing handler for {name}")
return None
5 changes: 3 additions & 2 deletions src/databricks/labs/ucx/source_code/linters/context.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from databricks.sdk.service.workspace import Language

from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex
from databricks.labs.ucx.source_code.base import CurrentSessionState, SequentialLinter, Fixer, Linter
from databricks.labs.ucx.source_code.base import Fixer, Linter, SequentialLinter, CurrentSessionState
from databricks.labs.ucx.source_code.linters.dbfs import FromDbfsFolder, DBFSUsageLinter
from databricks.labs.ucx.source_code.linters.imports import DbutilsLinter

from databricks.labs.ucx.source_code.linters.pyspark import SparkSql
from databricks.labs.ucx.source_code.queries import FromTable
from databricks.labs.ucx.source_code.linters.spark_connect import SparkConnectLinter
from databricks.labs.ucx.source_code.linters.table_creation import DBRv8d0Linter
from databricks.labs.ucx.source_code.queries import FromTable


class LinterContext:
Expand Down
33 changes: 17 additions & 16 deletions src/databricks/labs/ucx/source_code/linters/dbfs.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import ast
from collections.abc import Iterable

from astroid import Call, Const # type: ignore
import sqlglot
from sqlglot.expressions import Table

from databricks.labs.ucx.source_code.base import Advice, Linter, Advisory, Deprecation
from databricks.labs.ucx.source_code.linters.imports import Visitor, ASTLinter


class DetectDbfsVisitor(ast.NodeVisitor):
class DetectDbfsVisitor(Visitor):
"""
Visitor that detects file system paths in Python code and checks them
against a list of known deprecated paths.
Expand All @@ -18,44 +19,44 @@ def __init__(self):
self._fs_prefixes = ["/dbfs/mnt", "dbfs:/", "/mnt/"]
self._reported_locations = set() # Set to store reported locations

def visit_Call(self, node):
def visit_call(self, node: Call):
for arg in node.args:
if isinstance(arg, (ast.Str, ast.Constant)) and isinstance(arg.s, str):
if any(arg.s.startswith(prefix) for prefix in self._fs_prefixes):
if isinstance(arg, Const) and isinstance(arg.value, str):
value = arg.value
if any(value.startswith(prefix) for prefix in self._fs_prefixes):
self._advices.append(
Deprecation(
code='dbfs-usage',
message=f"Deprecated file system path in call to: {arg.s}",
message=f"Deprecated file system path in call to: {value}",
start_line=arg.lineno,
start_col=arg.col_offset,
end_line=arg.lineno,
end_col=arg.col_offset + len(arg.s),
end_col=arg.col_offset + len(value),
)
)
# Record the location of the reported constant, so we do not double report
self._reported_locations.add((arg.lineno, arg.col_offset))
self.generic_visit(node)

def visit_Constant(self, node):
def visit_const(self, node: Const):
# Constant strings yield Advisories
if isinstance(node.value, str):
self._check_str_constant(node)

def _check_str_constant(self, node):
def _check_str_constant(self, node: Const):
# Check if the location has been reported before
if (node.lineno, node.col_offset) not in self._reported_locations:
if any(node.s.startswith(prefix) for prefix in self._fs_prefixes):
value = node.value
if any(value.startswith(prefix) for prefix in self._fs_prefixes):
self._advices.append(
Advisory(
code='dbfs-usage',
message=f"Possible deprecated file system path: {node.s}",
message=f"Possible deprecated file system path: {value}",
start_line=node.lineno,
start_col=node.col_offset,
end_line=node.lineno,
end_col=node.col_offset + len(node.s),
end_col=node.col_offset + len(value),
)
)
self.generic_visit(node)

def get_advices(self) -> Iterable[Advice]:
yield from self._advices
Expand All @@ -76,9 +77,9 @@ def lint(self, code: str) -> Iterable[Advice]:
"""
Lints the code looking for file system paths that are deprecated
"""
tree = ast.parse(code)
linter = ASTLinter.parse(code)
visitor = DetectDbfsVisitor()
visitor.visit(tree)
visitor.visit(linter.root)
yield from visitor.get_advices()


Expand Down
Loading

0 comments on commit f346f0a

Please sign in to comment.