Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Normalize python code before parsing #1918

Merged
merged 20 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/databricks/labs/ucx/source_code/linters/dbfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ def lint(self, code: str) -> Iterable[Advice]:
"""
Lints the code looking for file system paths that are deprecated
"""
code = Tree.convert_magic_lines_to_magic_commands(code)
tree = Tree.parse(code)
tree = Tree.normalize_and_parse(code)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like it's addressing comments from other pr :)

visitor = DetectDbfsVisitor(self._session_state)
visitor.visit(tree.node)
yield from visitor.get_advices()
Expand Down
3 changes: 1 addition & 2 deletions src/databricks/labs/ucx/source_code/linters/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,7 @@ def __init__(self, session_state: CurrentSessionState):
self._session_state = session_state

def lint(self, code: str) -> Iterable[Advice]:
code = Tree.convert_magic_lines_to_magic_commands(code)
tree = Tree.parse(code)
tree = Tree.normalize_and_parse(code)
nodes = self.list_dbutils_notebook_run_calls(tree)
for node in nodes:
yield from self._raise_advice_if_unresolved(node.node, self._session_state)
Expand Down
3 changes: 1 addition & 2 deletions src/databricks/labs/ucx/source_code/linters/pyspark.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,7 @@ def name(self) -> str:

def lint(self, code: str) -> Iterable[Advice]:
try:
code = Tree.convert_magic_lines_to_magic_commands(code)
tree = Tree.parse(code)
tree = Tree.normalize_and_parse(code)
except AstroidSyntaxError as e:
yield Failure('syntax-error', str(e), 0, 0, 0, 0)
return
Expand Down
45 changes: 42 additions & 3 deletions src/databricks/labs/ucx/source_code/linters/python_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from abc import ABC
import logging
import re
from collections.abc import Iterable, Iterator, Generator
from typing import Any, TypeVar

Expand All @@ -28,13 +29,51 @@ def parse(code: str):
return Tree(root)

@classmethod
def convert_magic_lines_to_magic_commands(cls, python_code: str):
def normalize_and_parse(cls, code: str):
code = cls.normalize(code)
root = parse(code)
return Tree(root)

@classmethod
def normalize(cls, code: str):
code = cls._normalize_indents(code)
code = cls._convert_magic_lines_to_magic_commands(code)
return code

@classmethod
def _normalize_indents(cls, python_code: str):
lines = python_code.split("\n")
for line in lines:
# skip leading ws and comments
if len(line.strip()) == 0 or line.startswith('#'):
continue
if not line.startswith(' '):
# first line of code is correctly indented
return python_code
# first line of code is indented when it shouldn't
prefix_count = len(line) - len(line.lstrip(' '))
prefix_str = ' ' * prefix_count
for i, line_to_fix in enumerate(lines):
if line_to_fix.startswith(prefix_str):
lines[i] = line_to_fix[prefix_count:]
return "\n".join(lines)
return python_code

@classmethod
def _convert_magic_lines_to_magic_commands(cls, python_code: str):
lines = python_code.split("\n")
magic_markers = {"%", "!"}
in_multi_line_comment = False
pattern = re.compile('"""')
for i, line in enumerate(lines):
if len(line) == 0 or line[0] not in magic_markers:
if len(line) == 0:
continue
if not in_multi_line_comment and line[0] in magic_markers:
lines[i] = f"magic_command({line.encode()!r})"
continue
lines[i] = f"magic_command({line.encode()!r})"
matches = re.findall(pattern, line)
if len(matches) & 1:
in_multi_line_comment = not in_multi_line_comment
return "\n".join(lines)

def __init__(self, node: NodeNG):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,6 @@ def __init__(self, is_serverless: bool = False):
]

def lint(self, code: str) -> Iterator[Advice]:
code = Tree.convert_magic_lines_to_magic_commands(code)
tree = Tree.parse(code)
tree = Tree.normalize_and_parse(code)
for matcher in self._matchers:
yield from matcher.lint_tree(tree.node)
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ def __init__(self, dbr_version: tuple[int, int] | None):
def lint(self, code: str) -> Iterable[Advice]:
if self._skip_dbr:
return
code = Tree.convert_magic_lines_to_magic_commands(code)
tree = Tree.parse(code)
tree = Tree.normalize_and_parse(code)
for node in tree.walk():
yield from self._linter.lint(node)
3 changes: 1 addition & 2 deletions src/databricks/labs/ucx/source_code/notebooks/cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,8 +393,7 @@ def build_graph_from_python_source(self, python_code: str) -> list[DependencyPro
"""
problems: list[DependencyProblem] = []
try:
python_code = Tree.convert_magic_lines_to_magic_commands(python_code)
tree = Tree.parse(python_code)
tree = Tree.normalize_and_parse(python_code)
except Exception as e: # pylint: disable=broad-except
problems.append(DependencyProblem('parse-error', f"Could not parse Python code: {e}"))
return problems
Expand Down
14 changes: 12 additions & 2 deletions src/databricks/labs/ucx/source_code/notebooks/sources.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import codecs
import locale
import os
from collections.abc import Iterable
from functools import cached_property
from pathlib import Path
Expand Down Expand Up @@ -167,8 +169,16 @@ def __init__(self, ctx: LinterContext, path: Path, content: str | None = None):

@cached_property
def _source_code(self) -> str:
encoding = locale.getpreferredencoding(False)
return self._path.read_text(encoding) if self._content is None else self._content
return self._path.read_text(self._guess_encoding()) if self._content is None else self._content

def _guess_encoding(self):
path = self._path.as_posix()
count = min(32, os.path.getsize(path))
with open(path, 'rb') as _file:
raw = _file.read(count)
if raw.startswith(codecs.BOM_UTF8):
return 'utf-8-sig'
return locale.getpreferredencoding(False)

def _file_language(self):
return SUPPORTED_EXTENSION_LANGUAGES.get(self._path.suffix.lower())
Expand Down
39 changes: 38 additions & 1 deletion tests/unit/source_code/linters/test_python_ast.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from astroid import Assign, Attribute, Call, Const, Expr # type: ignore
from astroid import Assign, AstroidSyntaxError, Attribute, Call, Const, Expr # type: ignore

from databricks.labs.ucx.source_code.base import CurrentSessionState
from databricks.labs.ucx.source_code.linters.python_ast import Tree
Expand Down Expand Up @@ -254,3 +254,40 @@ def test_infers_externally_defined_value_set():
values = list(tree.infer_values(state))
strings = list(value.as_string() for value in values)
assert strings == ["my-value"]


def test_parses_incorrectly_indented_code():
source = """# DBTITLE 1,Get Sales Data for Analysis
sales = (
spark
.table('retail_sales')
.join( # limit data to CY 2021 and 2022
spark.table('date').select('dateKey','date','year').filter('year between 2021 and 2022'),
on='dateKey'
)
.join( # get product fields needed for analysis
spark.table('product').select('productKey','brandValue','packSizeValueUS'),
on='productKey'
)
.join( # get brand fields needed for analysis
spark.table('brand_name_mapping').select('brandValue','brandName'),
on='brandValue'
)
)
"""
# ensure it would fail if not normalized
with pytest.raises(AstroidSyntaxError):
Tree.parse(source)
Tree.normalize_and_parse(source)
assert True


def test_ignores_magic_marker_in_multiline_comment():
source = """message_unformatted = u\"""
%s is only supported in Python %s and above.\"""
name="name"
version="version"
formatted=message_unformatted % (name, version)
"""
Tree.normalize_and_parse(source)
assert True
20 changes: 3 additions & 17 deletions tests/unit/source_code/notebooks/test_cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from databricks.labs.ucx.source_code.graph import Dependency, DependencyGraph, DependencyResolver, DependencyProblem
from databricks.labs.ucx.source_code.linters.files import FileLoader, ImportFileResolver
from databricks.labs.ucx.source_code.linters.python_ast import Tree
from databricks.labs.ucx.source_code.notebooks.cells import CellLanguage, PipCell, PythonCell, PipMagic, MagicCommand
from databricks.labs.ucx.source_code.notebooks.cells import CellLanguage, PipCell, PipMagic, MagicCommand
from databricks.labs.ucx.source_code.notebooks.loaders import (
NotebookResolver,
NotebookLoader,
Expand Down Expand Up @@ -167,19 +167,6 @@ def test_pip_cell_build_dependency_graph_handles_multiline_code():
graph.register_library.assert_called_once_with("databricks")


def test_parses_python_cell_with_magic_commands(simple_dependency_resolver, mock_path_lookup):
code = """
a = 'something'
%pip install databricks
b = 'else'
"""
cell = PythonCell(code, original_offset=1)
dependency = Dependency(FileLoader(), Path(""))
graph = DependencyGraph(dependency, None, simple_dependency_resolver, mock_path_lookup, CurrentSessionState())
problems = cell.build_dependency_graph(graph)
assert not problems


@pytest.mark.parametrize(
"code,split",
[
Expand All @@ -205,16 +192,15 @@ def test_parses_python_cell_with_magic_commands(simple_dependency_resolver, mock
),
],
)
def test_pip_command_split(code, split):
def test_pip_magic_split(code, split):
assert PipMagic._split(code) == split # pylint: disable=protected-access


def test_unsupported_magic_raises_problem(simple_dependency_resolver, mock_path_lookup):
source = """
%unsupported stuff '"%#@!
"""
converted = Tree.convert_magic_lines_to_magic_commands(source)
tree = Tree.parse(converted)
tree = Tree.normalize_and_parse(source)
commands, _ = MagicCommand.extract_from_tree(tree, DependencyProblem.from_node)
dependency = Dependency(FileLoader(), Path(""))
graph = DependencyGraph(dependency, None, simple_dependency_resolver, mock_path_lookup, CurrentSessionState())
Expand Down
Loading