Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added linting for DBFS usage #1341

Merged
merged 9 commits into from
Apr 10, 2024
74 changes: 74 additions & 0 deletions src/databricks/labs/ucx/source_code/dbfs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import ast
from collections.abc import Iterable

from databricks.labs.ucx.source_code.base import Advice, Linter, Advisory, Deprecation


class DetectDbfsVisitor(ast.NodeVisitor):
    """
    Visitor that detects file system paths in Python code and checks them
    against a list of known deprecated path prefixes.

    A matching string literal passed as a positional call argument is
    recorded as a Deprecation; any other matching string literal is recorded
    as an Advisory. Each literal is reported at most once.
    """

    def __init__(self):
        self._advices: list[Advice] = []
        self._fs_prefixes = ["/dbfs/mnt", "dbfs:/", "/mnt/"]
        # (lineno, col_offset) of literals already reported by visit_Call, so
        # that the later visit_Constant pass does not report them a second time.
        self._reported_locations: set[tuple[int, int]] = set()

    def _is_deprecated_path(self, value: str) -> bool:
        """Return True when ``value`` starts with one of the known prefixes."""
        return value.startswith(tuple(self._fs_prefixes))

    def visit_Call(self, node):  # pylint: disable=invalid-name
        """
        Record a Deprecation for every positional string-literal argument of
        the call that starts with a known prefix, then descend into the call.
        """
        for arg in node.args:
            # ast.Constant/.value replace ast.Str/.s, which are deprecated
            # since Python 3.8 and removed in Python 3.14.
            if not isinstance(arg, ast.Constant) or not isinstance(arg.value, str):
                continue
            if not self._is_deprecated_path(arg.value):
                continue
            self._advices.append(
                Deprecation(
                    code='dbfs-usage',
                    message=f"Deprecated file system path in call to: {arg.value}",
                    start_line=arg.lineno,
                    start_col=arg.col_offset,
                    end_line=arg.lineno,
                    end_col=arg.col_offset + len(arg.value),
                )
            )
            # Must be recorded before generic_visit below reaches this literal.
            self._reported_locations.add((arg.lineno, arg.col_offset))
        self.generic_visit(node)

    def visit_Constant(self, node):  # pylint: disable=invalid-name
        """Check string constants; non-string constants are ignored."""
        # Constant strings yield Advisories
        if isinstance(node.value, str):
            self._check_str_constant(node)

    def _check_str_constant(self, node):
        """
        Record an Advisory for a string constant that starts with a known
        prefix, unless visit_Call already reported that exact literal.
        """
        if (node.lineno, node.col_offset) in self._reported_locations:
            return
        if not self._is_deprecated_path(node.value):
            return
        self._advices.append(
            Advisory(
                code='dbfs-usage',
                message=f"Possible deprecated file system path: {node.value}",
                start_line=node.lineno,
                start_col=node.col_offset,
                end_line=node.lineno,
                end_col=node.col_offset + len(node.value),
            )
        )
        # No generic_visit here: a Constant node has no child AST nodes.

    def get_advices(self) -> Iterable[Advice]:
        """Yield the advices collected so far, in the order they were recorded."""
        yield from self._advices


class DBFSUsageLinter(Linter):
    """
    Linter that runs DetectDbfsVisitor over Python source code.
    """

    def __init__(self):
        """This linter keeps no state of its own."""

    @staticmethod
    def name() -> str:
        """
        Name of this linter, as used for reporting
        """
        return 'dbfs-usage'

    def lint(self, code: str) -> Iterable[Advice]:
        """
        Parses the code as Python and returns the advices that
        DetectDbfsVisitor collects while walking the resulting tree
        """
        module = ast.parse(code)
        detector = DetectDbfsVisitor()
        detector.visit(module)
        return detector.get_advices()
3 changes: 2 additions & 1 deletion src/databricks/labs/ucx/source_code/languages.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from databricks.labs.ucx.source_code.base import Fixer, Linter, SequentialLinter
from databricks.labs.ucx.source_code.pyspark import SparkSql
from databricks.labs.ucx.source_code.queries import FromTable
from databricks.labs.ucx.source_code.dbfs import DBFSUsageLinter


class Languages:
def __init__(self, index: MigrationIndex):
self._index = index
from_table = FromTable(index)
self._linters = {
Language.PYTHON: SequentialLinter([SparkSql(from_table, index)]),
Language.PYTHON: SequentialLinter([SparkSql(from_table, index), DBFSUsageLinter()]),
Language.SQL: SequentialLinter([from_table]),
}
self._fixers: dict[Language, list[Fixer]] = {
Expand Down
53 changes: 53 additions & 0 deletions tests/unit/source_code/test_dbfs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest

from databricks.labs.ucx.source_code.base import Deprecation, Advisory
from databricks.labs.ucx.source_code.dbfs import DBFSUsageLinter


class TestDetectDBFS:
    """Unit tests for DBFSUsageLinter."""

    @pytest.mark.parametrize(
        "code, expected",
        [
            ('"/dbfs/mnt"', 1),
            ('"not a file system path"', 0),
            ('"/dbfs/mnt", "dbfs:/", "/mnt/"', 3),
            ('# "/dbfs/mnt"', 0),
            ('SOME_CONSTANT = "/dbfs/mnt"', 1),
            ('SOME_CONSTANT = "/dbfs/mnt"; load_data(SOME_CONSTANT)', 1),
        ],
    )
    def test_detects_dbfs_str_const_paths(self, code, expected):
        """Every advice for these snippets is an Advisory, and the total matches."""
        found = list(DBFSUsageLinter().lint(code))
        assert all(isinstance(advice, Advisory) for advice in found)
        assert len(found) == expected

    @pytest.mark.parametrize(
        "code, expected",
        [
            ("load_data('/dbfs/mnt/data')", 1),
            ("load_data('/data')", 0),
            ("load_data('/dbfs/mnt/data', '/data')", 1),
            ("# load_data('/dbfs/mnt/data', '/data')", 0),
            ('spark.read.parquet("/mnt/foo/bar")', 1),
            ('spark.read.parquet("dbfs:/mnt/foo/bar")', 1),
            ('spark.read.parquet("dbfs://mnt/foo/bar")', 1),
            # Would need a stateful linter to detect this next one
            ('DBFS="dbfs:/mnt/foo/bar"; spark.read.parquet(DBFS)', 0),
        ],
    )
    def test_dbfs_usage_linter(self, code, expected):
        """Only Deprecation advices are counted; any other advice kind is ignored."""
        deprecations = [
            advice
            for advice in DBFSUsageLinter().lint(code)
            if isinstance(advice, Deprecation)
        ]
        assert len(deprecations) == expected

    def test_dbfs_name(self):
        """The linter reports itself under the name 'dbfs-usage'."""
        assert DBFSUsageLinter().name() == "dbfs-usage"
Loading