diff --git a/src/databricks/labs/ucx/source_code/notebook.py b/src/databricks/labs/ucx/source_code/notebook.py
index ab81da2f21..7173d6cce6 100644
--- a/src/databricks/labs/ucx/source_code/notebook.py
+++ b/src/databricks/labs/ucx/source_code/notebook.py
@@ -1,5 +1,7 @@
 from __future__ import annotations  # for type hints
 
+import ast
+import logging
 from abc import ABC, abstractmethod
 from ast import parse as parse_python
 from collections.abc import Callable
@@ -9,20 +11,38 @@
 from sqlglot import parse as parse_sql
 
 from databricks.sdk.service.workspace import Language
 
-NOTEBOOK_HEADER = " Databricks notebook source"
-CELL_SEPARATOR = " COMMAND ----------"
-MAGIC_PREFIX = ' MAGIC'
-LANGUAGE_PREFIX = ' %'
+from databricks.labs.ucx.source_code.python_linter import ASTLinter, PythonLinter
+
+
+logger = logging.getLogger(__name__)
+# use a specific logger for sqlglot warnings so we can disable them selectively
+sqlglot_logger = logging.getLogger(f"{__name__}.sqlglot")
+
+NOTEBOOK_HEADER = "Databricks notebook source"
+CELL_SEPARATOR = "COMMAND ----------"
+MAGIC_PREFIX = 'MAGIC'
+LANGUAGE_PREFIX = '%'
+LANGUAGE_PI = 'LANGUAGE'
+COMMENT_PI = 'COMMENT'
 
 
 class Cell(ABC):
 
     def __init__(self, source: str):
         self._original_code = source
+        self._migrated_code = source
+
+    @property
+    def original_code(self):
+        return self._original_code
 
     @property
     def migrated_code(self):
-        return self._original_code  # for now since we're not doing any migration yet
+        return self._migrated_code
+
+    @migrated_code.setter
+    def migrated_code(self, value: str):
+        self._migrated_code = value
 
     @property
     @abstractmethod
@@ -46,15 +66,20 @@ def language(self):
 
     def is_runnable(self) -> bool:
         try:
-            ast = parse_python(self._original_code)
-            return ast is not None
+            tree = parse_python(self._original_code)
+            return tree is not None
         except SyntaxError:
-            return False
+            return True
 
     def build_dependency_graph(self, parent: DependencyGraph):
-        # TODO https://github.com/databrickslabs/ucx/issues/1200
         # TODO https://github.com/databrickslabs/ucx/issues/1202
-        pass
+        linter = ASTLinter.parse(self._original_code)
+        nodes = linter.locate(ast.Call, [("run", ast.Attribute), ("notebook", ast.Attribute), ("dbutils", ast.Name)])
+        for node in nodes:
+            assert isinstance(node, ast.Call)
+            path = PythonLinter.get_dbutils_notebook_run_path_arg(node)
+            if isinstance(path, ast.Constant):
+                parent.register_dependency(path.value.strip("'").strip('"'))
 
 
 class RCell(Cell):
@@ -93,8 +118,9 @@ def is_runnable(self) -> bool:
         try:
             statements = parse_sql(self._original_code)
             return len(statements) > 0
-        except SQLParseError:
-            return False
+        except SQLParseError as e:
+            sqlglot_logger.warning(f"Failed to parse SQL using 'sqlglot': {self._original_code}", exc_info=e)
+            return True
 
     def build_dependency_graph(self, parent: DependencyGraph):
         pass  # not in scope
@@ -123,7 +149,7 @@ def is_runnable(self) -> bool:
         return True  # TODO
 
     def build_dependency_graph(self, parent: DependencyGraph):
-        command = f'{LANGUAGE_PREFIX}{self.language.magic_name}'.strip()
+        command = f'{LANGUAGE_PREFIX}{self.language.magic_name}'
         lines = self._original_code.split('\n')
         for line in lines:
             start = line.index(command)
@@ -133,22 +159,28 @@
                 return
         raise ValueError("Missing notebook path in %run command")
 
+    def migrate_notebook_path(self):
+        pass
+
 
 class CellLanguage(Enum):
     # long magic_names must come first to avoid shorter ones being matched
-    PYTHON = Language.PYTHON, 'python', '#', PythonCell
-    SCALA = Language.SCALA, 'scala', '//', ScalaCell
-    SQL = Language.SQL, 'sql', '--', SQLCell
-    RUN = None, 'run', None, RunCell
-    MARKDOWN = None, 'md', None, MarkdownCell
-    R = Language.R, 'r', '#', RCell
+    PYTHON = Language.PYTHON, 'python', '#', True, PythonCell
+    SCALA = Language.SCALA, 'scala', '//', True, ScalaCell
+    SQL = Language.SQL, 'sql', '--', True, SQLCell
+    RUN = None, 'run', '', False, RunCell
+    # see https://spec.commonmark.org/0.31.2/#html-comment
+    MARKDOWN = None, 'md', '', False, MarkdownCell
+    R = Language.R, 'r', '#', True, RCell
 
     def __init__(self, *args):
         super().__init__()
         self._language = args[0]
         self._magic_name = args[1]
         self._comment_prefix = args[2]
-        self._new_cell = args[3]
+        # PI stands for Processing Instruction
+        self._requires_isolated_pi = args[3]
+        self._new_cell = args[4]
 
     @property
     def language(self) -> Language:
@@ -162,6 +194,10 @@ def magic_name(self) -> str:
     def comment_prefix(self) -> str:
         return self._comment_prefix
 
+    @property
+    def requires_isolated_pi(self) -> bool:
+        return self._requires_isolated_pi
+
     @classmethod
     def of_language(cls, language: Language) -> CellLanguage:
         return next((cl for cl in CellLanguage if cl.language == language))
@@ -171,7 +207,7 @@ def of_magic_name(cls, magic_name: str) -> CellLanguage | None:
         return next((cl for cl in CellLanguage if magic_name.startswith(cl.magic_name)), None)
 
     def read_cell_language(self, lines: list[str]) -> CellLanguage | None:
-        magic_prefix = f'{self.comment_prefix}{MAGIC_PREFIX}'
+        magic_prefix = f'{self.comment_prefix} {MAGIC_PREFIX} '
         magic_language_prefix = f'{magic_prefix}{LANGUAGE_PREFIX}'
         for line in lines:
             # if we find a non-comment then we're done
@@ -191,26 +227,28 @@ def new_cell(self, source: str) -> Cell:
 
     def extract_cells(self, source: str) -> list[Cell] | None:
         lines = source.split('\n')
-        header = f"{self.comment_prefix}{NOTEBOOK_HEADER}"
+        header = f"{self.comment_prefix} {NOTEBOOK_HEADER}"
         if not lines[0].startswith(header):
             raise ValueError("Not a Databricks notebook source!")
 
-        def make_cell(lines_: list[str]):
+        def make_cell(cell_lines: list[str]):
             # trim leading blank lines
-            while len(lines_) > 0 and len(lines_[0]) == 0:
-                lines_.pop(0)
+            while len(cell_lines) > 0 and len(cell_lines[0]) == 0:
+                cell_lines.pop(0)
             # trim trailing blank lines
-            while len(lines_) > 0 and len(lines_[-1]) == 0:
-                lines_.pop(-1)
-            cell_language = self.read_cell_language(lines_)
+            while len(cell_lines) > 0 and len(cell_lines[-1]) == 0:
+                cell_lines.pop(-1)
+            cell_language = self.read_cell_language(cell_lines)
             if cell_language is None:
                 cell_language = self
-            cell_source = '\n'.join(lines_)
+            else:
+                self._remove_magic_wrapper(cell_lines, cell_language)
+            cell_source = '\n'.join(cell_lines)
             return cell_language.new_cell(cell_source)
 
         cells = []
         cell_lines: list[str] = []
-        separator = f"{self.comment_prefix}{CELL_SEPARATOR}"
+        separator = f"{self.comment_prefix} {CELL_SEPARATOR}"
         for i in range(1, len(lines)):
             line = lines[i].strip()
             if line.startswith(separator):
@@ -225,6 +263,39 @@ def make_cell(lines_: list[str]):
 
         return cells
 
+    def _remove_magic_wrapper(self, lines: list[str], cell_language: CellLanguage):
+        prefix = f"{self.comment_prefix} {MAGIC_PREFIX} "
+        prefix_len = len(prefix)
+        for i, line in enumerate(lines):
+            if line.startswith(prefix):
+                line = line[prefix_len:]
+                if cell_language.requires_isolated_pi and line.startswith(LANGUAGE_PREFIX):
+                    line = f"{cell_language.comment_prefix} {LANGUAGE_PI}"
+                lines[i] = line
+                continue
+            if line.startswith(self.comment_prefix):
+                line = f"{cell_language.comment_prefix} {COMMENT_PI}{line}"
+                lines[i] = line
+
+    def wrap_with_magic(self, code: str, cell_language: CellLanguage) -> str:
+        language_pi_prefix = f"{cell_language.comment_prefix} {LANGUAGE_PI}"
+        comment_pi_prefix = f"{cell_language.comment_prefix} {COMMENT_PI}"
+        comment_pi_prefix_len = len(comment_pi_prefix)
+        lines = code.split('\n')
+        for i, line in enumerate(lines):
+            if line.startswith(language_pi_prefix):
+                line = f"{self.comment_prefix} {MAGIC_PREFIX} {LANGUAGE_PREFIX}{cell_language.magic_name}"
+                lines[i] = line
+                continue
+            if line.startswith(comment_pi_prefix):
+                lines[i] = line[comment_pi_prefix_len:]
+                continue
+            line = f"{self.comment_prefix} {MAGIC_PREFIX} {line}"
+            lines[i] = line
+        if code.endswith('./'):
+            lines.append('\n')
+        return "\n".join(lines)
+
 
 class DependencyGraph:
 
@@ -238,7 +309,7 @@ def __init__(self, path: str, parent: DependencyGraph | None, locator: Callable[
     def path(self):
         return self._path
 
-    def register_dependency(self, path: str) -> DependencyGraph:
+    def register_dependency(self, path: str) -> DependencyGraph | None:
         # already registered ?
         child_graph = self.locate_dependency(path)
         if child_graph is not None:
@@ -248,6 +319,8 @@ def register_dependency(self, path: str) -> DependencyGraph:
         child_graph = DependencyGraph(path, self, self._locator)
         self._dependencies[path] = child_graph
         notebook = self._locator(path)
+        if not notebook:
+            return None
         notebook.build_dependency_graph(child_graph)
         return child_graph
 
@@ -296,30 +369,44 @@ def visit(self, visit_node: Callable[[DependencyGraph], bool | None]) -> bool |
 
 
 class Notebook:
 
     @staticmethod
-    def parse(path: str, source: str, default_language: Language) -> Notebook | None:
+    def parse(path: str, source: str, default_language: Language) -> Notebook:
         default_cell_language = CellLanguage.of_language(default_language)
         cells = default_cell_language.extract_cells(source)
-        return None if cells is None else Notebook(path, default_language, cells, source.endswith('\n'))
+        if cells is None:
+            raise ValueError(f"Could not parse Notebook: {path}")
+        return Notebook(path, source, default_language, cells, source.endswith('\n'))
 
-    def __init__(self, path: str, language: Language, cells: list[Cell], ends_with_lf):
+    def __init__(self, path: str, source: str, language: Language, cells: list[Cell], ends_with_lf):
         self._path = path
+        self._source = source
         self._language = language
         self._cells = cells
         self._ends_with_lf = ends_with_lf
 
+    @property
+    def path(self) -> str:
+        return self._path
+
     @property
     def cells(self) -> list[Cell]:
         return self._cells
 
+    @property
+    def original_code(self) -> str:
+        return self._source
+
     def to_migrated_code(self):
         default_language = CellLanguage.of_language(self._language)
-        header = f"{default_language.comment_prefix}{NOTEBOOK_HEADER}"
+        header = f"{default_language.comment_prefix} {NOTEBOOK_HEADER}"
         sources = [header]
         for i, cell in enumerate(self._cells):
-            sources.append(cell.migrated_code)
+            migrated_code = cell.migrated_code
+            if cell.language is not default_language:
+                migrated_code = default_language.wrap_with_magic(migrated_code, cell.language)
+            sources.append(migrated_code)
            if i < len(self._cells) - 1:
                 sources.append('')
-                sources.append(f'{default_language.comment_prefix}{CELL_SEPARATOR}')
+                sources.append(f'{default_language.comment_prefix} {CELL_SEPARATOR}')
                 sources.append('')
         if self._ends_with_lf:
             sources.append('')  # following join will append lf
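Note for reviewers: to make the LANGUAGE/COMMENT processing-instruction round-trip above concrete, here is a minimal sketch; the notebook source and query are invented for illustration, everything else is the API added in this file:

    from databricks.sdk.service.workspace import Language
    from databricks.labs.ucx.source_code.notebook import Notebook

    source = "\n".join([
        "# Databricks notebook source",
        "# MAGIC %sql",
        "# MAGIC SELECT * FROM stuff",
    ])
    notebook = Notebook.parse("sample.py", source, Language.PYTHON)
    sql_cell = notebook.cells[0]
    # _remove_magic_wrapper() strips '# MAGIC ' and rewrites '%sql' into the
    # '-- LANGUAGE' processing instruction, leaving plain SQL in the cell
    print(sql_cell.original_code)  # '-- LANGUAGE\nSELECT * FROM stuff'
    # to_migrated_code() goes through wrap_with_magic() to restore the magic form
    print(notebook.to_migrated_code())
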
diff --git a/src/databricks/labs/ucx/source_code/notebook_migrator.py b/src/databricks/labs/ucx/source_code/notebook_migrator.py
index b58c6024b9..4f765c0bb5 100644
--- a/src/databricks/labs/ucx/source_code/notebook_migrator.py
+++ b/src/databricks/labs/ucx/source_code/notebook_migrator.py
@@ -1,7 +1,8 @@
 from databricks.sdk import WorkspaceClient
-from databricks.sdk.service.workspace import ExportFormat, ObjectInfo
+from databricks.sdk.service.workspace import ExportFormat, ObjectInfo, ObjectType
 
 from databricks.labs.ucx.source_code.languages import Languages
+from databricks.labs.ucx.source_code.notebook import DependencyGraph, Notebook, RunCell
 
 
 class NotebookMigrator:
@@ -14,19 +15,64 @@ def revert(self, object_info: ObjectInfo):
             return False
         with self._ws.workspace.download(object_info.path + ".bak", format=ExportFormat.SOURCE) as f:
             code = f.read().decode("utf-8")
-            self._ws.workspace.upload(object_info.path, code.encode("utf-8")) 
+            self._ws.workspace.upload(object_info.path, code.encode("utf-8"))
             return True
 
     def apply(self, object_info: ObjectInfo) -> bool:
-        if not object_info.language or not object_info.path:
-            return False
-        if not self._languages.is_supported(object_info.language):
+        if not object_info.path or not object_info.language or object_info.object_type is not ObjectType.NOTEBOOK:
             return False
+        notebook = self._load_notebook(object_info)
+        return self._apply(notebook)
+
+    def build_dependency_graph(self, object_info: ObjectInfo) -> DependencyGraph:
+        if not object_info.path or not object_info.language or object_info.object_type is not ObjectType.NOTEBOOK:
+            raise ValueError("Not a valid Notebook")
+        notebook = self._load_notebook(object_info)
+        dependencies = DependencyGraph(object_info.path, None, self._load_notebook_from_path)
+        notebook.build_dependency_graph(dependencies)
+        return dependencies
+
+    def _apply(self, notebook: Notebook) -> bool:
+        changed = False
+        for cell in notebook.cells:
+            # %run is not a supported language, so this needs to come first
+            if isinstance(cell, RunCell):
+                # TODO: data on what to change the path to?
+                if cell.migrate_notebook_path():
+                    changed = True
+                continue
+            if not self._languages.is_supported(cell.language.language):
+                continue
+            migrated_code = self._languages.apply_fixes(cell.language.language, cell.original_code)
+            if migrated_code != cell.original_code:
+                cell.migrated_code = migrated_code
+                changed = True
+        if changed:
+            self._ws.workspace.upload(notebook.path + ".bak", notebook.original_code.encode("utf-8"))
+            self._ws.workspace.upload(notebook.path, notebook.to_migrated_code().encode("utf-8"))
+            # TODO https://github.com/databrickslabs/ucx/issues/1327 store 'migrated' status
+        return changed
+
+    def _load_notebook_from_path(self, path: str) -> Notebook:
+        object_info = self._load_object(path)
+        if object_info.object_type is not ObjectType.NOTEBOOK:
+            raise ValueError(f"Not a Notebook: {path}")
+        return self._load_notebook(object_info)
+
+    def _load_object(self, path: str) -> ObjectInfo:
+        result = self._ws.workspace.list(path)
+        object_info = next((oi for oi in result), None)
+        if object_info is None:
+            raise ValueError(f"Could not locate object at '{path}'")
+        return object_info
+
+    def _load_notebook(self, object_info: ObjectInfo) -> Notebook:
+        assert object_info is not None and object_info.path is not None and object_info.language is not None
+        source = self._load_source(object_info)
+        return Notebook.parse(object_info.path, source, object_info.language)
+
+    def _load_source(self, object_info: ObjectInfo) -> str:
+        if not object_info.language or not object_info.path:
+            raise ValueError(f"Invalid ObjectInfo: {object_info}")
         with self._ws.workspace.download(object_info.path, format=ExportFormat.SOURCE) as f:
-            original_code = f.read().decode("utf-8")
-            new_code = self._languages.apply_fixes(object_info.language, original_code)
-            if new_code == original_code:
-                return False
-            self._ws.workspace.upload(object_info.path + ".bak", original_code.encode("utf-8"))
-            self._ws.workspace.upload(object_info.path, new_code.encode("utf-8"))
-            return True
+            return f.read().decode("utf-8")
diff --git a/src/databricks/labs/ucx/source_code/python_linter.py b/src/databricks/labs/ucx/source_code/python_linter.py
new file mode 100644
index 0000000000..628252b176
--- /dev/null
+++ b/src/databricks/labs/ucx/source_code/python_linter.py
@@ -0,0 +1,108 @@
+from __future__ import annotations
+
+import ast
+import logging
+from collections.abc import Iterable
+
+from databricks.labs.ucx.source_code.base import Linter, Advice, Advisory
+
+
+logger = logging.getLogger(__name__)
+
+
+class MatchingVisitor(ast.NodeVisitor):
+
+    def __init__(self, node_type: type, match_nodes: list[tuple[str, type]]):
+        self._matched_nodes: list[ast.AST] = []
+        self._node_type = node_type
+        self._match_nodes = match_nodes
+
+    @property
+    def matched_nodes(self):
+        return self._matched_nodes
+
+    # visit_Call follows NodeVisitor requirements, which clash with python naming conventions
+    # pylint: disable=invalid-name
+    def visit_Call(self, node: ast.Call):
+        if self._node_type is not ast.Call:
+            return
+        try:
+            if self._matches(node.func, 0):
+                self._matched_nodes.append(node)
+        except NotImplementedError as e:
+            logger.warning(f"Missing implementation: {e.args[0]}")
+
+    def _matches(self, node: ast.AST, depth: int):
+        if depth >= len(self._match_nodes):
+            return False
+        pair = self._match_nodes[depth]
+        if not isinstance(node, pair[1]):
+            return False
+        next_node: ast.AST | None = None
+        if isinstance(node, ast.Attribute):
+            if node.attr != pair[0]:
+                return False
+            next_node = node.value
+        elif isinstance(node, ast.Name):
+            if node.id != pair[0]:
+                return False
+        else:
+            raise NotImplementedError(str(type(node)))
+        if next_node is None:
+            # is this the last node to match?
+            return len(self._match_nodes) - 1 == depth
+        return self._matches(next_node, depth + 1)
+
+
+# disclaimer: this class is NOT thread-safe
+class ASTLinter:
+
+    @staticmethod
+    def parse(code: str):
+        root = ast.parse(code)
+        return ASTLinter(root)
+
+    def __init__(self, root: ast.AST):
+        self._root: ast.AST = root
+
+    def locate(self, node_type: type, match_nodes: list[tuple[str, type]]) -> list[ast.AST]:
+        visitor = MatchingVisitor(node_type, match_nodes)
+        visitor.visit(self._root)
+        return visitor.matched_nodes
+
+
+class PythonLinter(Linter):
+
+    def lint(self, code: str) -> Iterable[Advice]:
+        linter = ASTLinter.parse(code)
+        nodes = linter.locate(ast.Call, [("run", ast.Attribute), ("notebook", ast.Attribute), ("dbutils", ast.Name)])
+        return [self._convert_dbutils_notebook_run_to_advice(node) for node in nodes]
+
+    @classmethod
+    def _convert_dbutils_notebook_run_to_advice(cls, node: ast.AST) -> Advisory:
+        assert isinstance(node, ast.Call)
+        path = cls.get_dbutils_notebook_run_path_arg(node)
+        if isinstance(path, ast.Constant):
+            return Advisory(
+                'dbutils-notebook-run-literal',
+                "Call to 'dbutils.notebook.run' will be migrated automatically",
+                node.lineno,
+                node.col_offset,
+                node.end_lineno or 0,
+                node.end_col_offset or 0,
+            )
+        return Advisory(
+            'dbutils-notebook-run-dynamic',
+            "Path for 'dbutils.notebook.run' is not a constant and requires adjusting the notebook path",
+            node.lineno,
+            node.col_offset,
+            node.end_lineno or 0,
+            node.end_col_offset or 0,
+        )
+
+    @staticmethod
+    def get_dbutils_notebook_run_path_arg(node: ast.Call):
+        if len(node.args) > 0:
+            return node.args[0]
+        arg = next((kw for kw in node.keywords if kw.arg == "path"), None)
+        return arg.value if arg is not None else None
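Note for reviewers: a quick sketch of how the new ASTLinter/PythonLinter pair is driven; the code string is a made-up example, and the match list mirrors the one used by PythonCell.build_dependency_graph:

    import ast
    from databricks.labs.ucx.source_code.python_linter import ASTLinter, PythonLinter

    code = 'dbutils.notebook.run("./leaf3.py.txt")'
    linter = ASTLinter.parse(code)
    # MatchingVisitor matches the attribute chain outermost-first:
    # 'run' (Attribute) -> 'notebook' (Attribute) -> 'dbutils' (Name)
    nodes = linter.locate(ast.Call, [("run", ast.Attribute), ("notebook", ast.Attribute), ("dbutils", ast.Name)])
    for node in nodes:
        assert isinstance(node, ast.Call)
        path = PythonLinter.get_dbutils_notebook_run_path_arg(node)
        if isinstance(path, ast.Constant):
            print(path.value)  # './leaf3.py.txt', a literal path that can be migrated
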
diff --git a/tests/unit/source_code/notebooks/leaf3.py.txt b/tests/unit/source_code/notebooks/leaf3.py.txt
new file mode 100644
index 0000000000..39f7b60069
--- /dev/null
+++ b/tests/unit/source_code/notebooks/leaf3.py.txt
@@ -0,0 +1,12 @@
+# Databricks notebook source
+# time horizon
+days = 300
+
+# volatility
+sigma = 0.04
+
+# drift (average growth rate)
+mu = 0.05
+
+# initial starting price
+start_price = 10
diff --git a/tests/unit/source_code/notebooks/root4.py.txt b/tests/unit/source_code/notebooks/root4.py.txt
new file mode 100644
index 0000000000..bb004e10b9
--- /dev/null
+++ b/tests/unit/source_code/notebooks/root4.py.txt
@@ -0,0 +1,2 @@
+# Databricks notebook source
+dbutils.notebook.run("./leaf3.py.txt")
diff --git a/tests/unit/source_code/notebooks/run_notebooks.py.txt b/tests/unit/source_code/notebooks/run_notebooks.py.txt
new file mode 100644
index 0000000000..3bd59fde10
--- /dev/null
+++ b/tests/unit/source_code/notebooks/run_notebooks.py.txt
@@ -0,0 +1,18 @@
+import datetime
+# Updated List of Notebooks
+notebooks_list = [
+    '/Production/data_solutions/accounts/companyA_10011111/de/companyA_de_report',
+    '/Production/data_solutions/accounts/companyB_10022222/de/companyB_de_report',
+    '/Production/data_solutions/accounts/companyC_10033333/de/companyC_de_report',
+    '/Production/data_solutions/accounts/companyD_10044444/de/companyD_de_report',
+    ]
+# Execution:
+for notebook in notebooks_list:
+    try:
+        start_time = datetime.datetime.now()
+        print("Running the report of " + str(notebook).split('/')[len(str(notebook).split('/'))-1])
+        status = dbutils.notebook.run(notebook,100000)
+        end_time = datetime.datetime.now()
+        print("Finished, time taken: " + str(start_time-end_time))
+    except:
+        print("The notebook {0} failed to run".format(notebook))
diff --git a/tests/unit/source_code/test_notebook.py b/tests/unit/source_code/test_notebook.py
index d1ae01758b..f3ef36df0e 100644
--- a/tests/unit/source_code/test_notebook.py
+++ b/tests/unit/source_code/test_notebook.py
@@ -3,10 +3,14 @@
 import pytest
 from databricks.sdk.service.workspace import Language
 
+from databricks.labs.ucx.source_code.base import Advisory
 from databricks.labs.ucx.source_code.notebook import Notebook, DependencyGraph
+from databricks.labs.ucx.source_code.python_linter import PythonLinter
 from tests.unit import _load_sources
 
 # fmt: off
+# the following are real samples from https://github.com/databricks-industry-solutions
+# please keep them untouched, we want our unit tests to run against genuinely representative data
 PYTHON_NOTEBOOK_SAMPLE = (
     "00_var_context.py.txt",
     Language.PYTHON,
@@ -82,7 +86,6 @@ def test_notebook_rebuilds_same_code(source: tuple[str, Language, list[str]]):
     assert actual_purified == expected_purified
 
 
-@pytest.mark.skip("for now")
 @pytest.mark.parametrize(
     "source",
     [
@@ -179,3 +182,70 @@ def test_notebook_builds_cyclical_dependency_graph():
     notebook.build_dependency_graph(graph)
     actual = {path[2:] if path.startswith('./') else path for path in graph.paths}
     assert actual == set(paths)
+
+
+def test_notebook_builds_python_dependency_graph():
+    paths = ["root4.py.txt", "leaf3.py.txt"]
+    sources: list[str] = _load_sources(Notebook, *paths)
+    languages = [Language.PYTHON] * len(paths)
+    locator = notebook_locator(paths, sources, languages)
+    notebook = locator(paths[0])
+    graph = DependencyGraph(paths[0], None, locator)
+    notebook.build_dependency_graph(graph)
+    actual = {path[2:] if path.startswith('./') else path for path in graph.paths}
+    assert actual == set(paths)
+
+
+def test_detects_manual_migration_in_dbutils_notebook_run_in_python_code_():
+    sources: list[str] = _load_sources(Notebook, "run_notebooks.py.txt")
+    linter = PythonLinter()
+    advices = list(linter.lint(sources[0]))
+    assert [
+        Advisory(
+            code='dbutils-notebook-run-dynamic',
+            message="Path for 'dbutils.notebook.run' is not a constant and requires adjusting the notebook path",
+            start_line=14,
+            start_col=13,
+            end_line=14,
+            end_col=50,
+        )
+    ] == advices
+
+
+def test_detects_automatic_migration_in_dbutils_notebook_run_in_python_code_():
+    sources: list[str] = _load_sources(Notebook, "root4.py.txt")
+    linter = PythonLinter()
+    advices = list(linter.lint(sources[0]))
+    assert [
+        Advisory(
+            code='dbutils-notebook-run-literal',
+            message="Call to 'dbutils.notebook.run' will be migrated automatically",
+            start_line=2,
+            start_col=0,
+            end_line=2,
+            end_col=38,
+        )
+    ] == advices
+
+
+def test_detects_multiple_calls_to_dbutils_notebook_run_in_python_code_():
+    source = """
+import stuff
+do_something_with_stuff(stuff)
+stuff2 = dbutils.notebook.run("where is notebook 1?")
+stuff3 = dbutils.notebook.run("where is notebook 2?")
+"""
+    linter = PythonLinter()
+    advices = list(linter.lint(source))
+    assert len(advices) == 2
+
+
+def test_does_not_detect_partial_call_to_dbutils_notebook_run_in_python_code_():
+    source = """
+import stuff
+do_something_with_stuff(stuff)
+stuff2 = notebook.run("where is notebook 1?")
+"""
+    linter = PythonLinter()
+    advices = list(linter.lint(source))
+    assert len(advices) == 0
diff --git a/tests/unit/source_code/test_notebook_migrator.py b/tests/unit/source_code/test_notebook_migrator.py
index c3e50e934f..32112fc5d1 100644
--- a/tests/unit/source_code/test_notebook_migrator.py
+++ b/tests/unit/source_code/test_notebook_migrator.py
@@ -1,52 +1,140 @@
+from typing import BinaryIO
 from unittest.mock import create_autospec
 
+import pytest
 from databricks.sdk import WorkspaceClient
-from databricks.sdk.service.workspace import ExportFormat, Language, ObjectInfo
+from databricks.sdk.service.workspace import ExportFormat, Language, ObjectInfo, ObjectType
 
+from databricks.labs.ucx.hive_metastore.table_migrate import MigrationIndex
 from databricks.labs.ucx.source_code.languages import Languages
+from databricks.labs.ucx.source_code.notebook import Notebook
 from databricks.labs.ucx.source_code.notebook_migrator import NotebookMigrator
+from tests.unit import _load_sources
+
+
+def test_apply_invalid_object_fails():
+    ws = create_autospec(WorkspaceClient)
+    languages = create_autospec(Languages)
+    migrator = NotebookMigrator(ws, languages)
+    object_info = ObjectInfo(language=Language.PYTHON)
+    assert not migrator.apply(object_info)
+
+
+def test_revert_invalid_object_fails():
+    ws = create_autospec(WorkspaceClient)
+    languages = create_autospec(Languages)
+    migrator = NotebookMigrator(ws, languages)
+    object_info = ObjectInfo(language=Language.PYTHON)
+    assert not migrator.revert(object_info)
 
 
 def test_revert_restores_original_code():
     ws = create_autospec(WorkspaceClient)
     ws.workspace.download.return_value.__enter__.return_value.read.return_value = b'original_code'
     languages = create_autospec(Languages)
-    notebooks = NotebookMigrator(ws, languages)
+    migrator = NotebookMigrator(ws, languages)
     object_info = ObjectInfo(path='path', language=Language.PYTHON)
-    notebooks.revert(object_info)
+    migrator.revert(object_info)
     ws.workspace.download.assert_called_with('path.bak', format=ExportFormat.SOURCE)
     ws.workspace.upload.assert_called_with('path', b'original_code')
 
 
 def test_apply_returns_false_when_language_not_supported():
+    notebook_code = """# Databricks notebook source
+# MAGIC %r
+# // original code
+"""
     ws = create_autospec(WorkspaceClient)
+    ws.workspace.download.return_value.__enter__.return_value.read.return_value = notebook_code.encode("utf-8")
     languages = create_autospec(Languages)
     languages.is_supported.return_value = False
-    notebooks = NotebookMigrator(ws, languages)
-    object_info = ObjectInfo(path='path', language=Language.R)
-    result = notebooks.apply(object_info)
+    migrator = NotebookMigrator(ws, languages)
+    object_info = ObjectInfo(path='path', language=Language.R, object_type=ObjectType.NOTEBOOK)
+    result = migrator.apply(object_info)
     assert not result
 
 
 def test_apply_returns_false_when_no_fixes_applied():
+    notebook_code = """# Databricks notebook source
+# original code
+"""
     ws = create_autospec(WorkspaceClient)
-    ws.workspace.download.return_value.__enter__.return_value.read.return_value = b'original_code'
+    ws.workspace.download.return_value.__enter__.return_value.read.return_value = notebook_code.encode("utf-8")
     languages = create_autospec(Languages)
     languages.is_supported.return_value = True
-    languages.apply_fixes.return_value = 'original_code'
-    notebooks = NotebookMigrator(ws, languages)
-    object_info = ObjectInfo(path='path', language=Language.PYTHON)
-    assert not notebooks.apply(object_info)
+    languages.apply_fixes.return_value = "# original code"  # cell code
+    migrator = NotebookMigrator(ws, languages)
+    object_info = ObjectInfo(path='path', language=Language.PYTHON, object_type=ObjectType.NOTEBOOK)
+    assert not migrator.apply(object_info)
 
 
 def test_apply_returns_true_and_changes_code_when_fixes_applied():
+    original_code = """# Databricks notebook source
+# original code
+"""
+    migrated_cell_code = '# migrated code'
+    migrated_code = """# Databricks notebook source
+# migrated code
+"""
     ws = create_autospec(WorkspaceClient)
-    ws.workspace.download.return_value.__enter__.return_value.read.return_value = b'original_code'
+    ws.workspace.download.return_value.__enter__.return_value.read.return_value = original_code.encode("utf-8")
     languages = create_autospec(Languages)
     languages.is_supported.return_value = True
-    languages.apply_fixes.return_value = 'new_code'
-    notebooks = NotebookMigrator(ws, languages)
-    object_info = ObjectInfo(path='path', language=Language.PYTHON)
-    assert notebooks.apply(object_info)
-    ws.workspace.upload.assert_any_call('path.bak', 'original_code'.encode("utf-8"))
-    ws.workspace.upload.assert_any_call('path', 'new_code'.encode("utf-8"))
+    languages.apply_fixes.return_value = migrated_cell_code
+    migrator = NotebookMigrator(ws, languages)
+    object_info = ObjectInfo(path='path', language=Language.PYTHON, object_type=ObjectType.NOTEBOOK)
+    assert migrator.apply(object_info)
+    ws.workspace.upload.assert_any_call('path.bak', original_code.encode("utf-8"))
+    ws.workspace.upload.assert_any_call('path', migrated_code.encode("utf-8"))
+
+
+def test_build_dependency_graph_visits_dependencies():
+    paths = ["root3.run.py.txt", "root1.run.py.txt", "leaf1.py.txt", "leaf2.py.txt"]
+    sources: dict[str, str] = dict(zip(paths, _load_sources(Notebook, *paths)))
+    visited: dict[str, bool] = {}
+
+    # can't remove **kwargs because it receives format=xxx
+    # pylint: disable=unused-argument
+    def download_side_effect(*args, **kwargs):
+        filename = args[0]
+        if filename.startswith('./'):
+            filename = filename[2:]
+        visited[filename] = True
+        result = create_autospec(BinaryIO)
+        result.__enter__.return_value.read.return_value = sources[filename].encode("utf-8")
+        return result
+
+    def list_side_effect(*args):
+        path = args[0]
+        return [ObjectInfo(path=path, language=Language.PYTHON, object_type=ObjectType.NOTEBOOK)]
+
+    ws = create_autospec(WorkspaceClient)
+    ws.workspace.download.side_effect = download_side_effect
+    ws.workspace.list.side_effect = list_side_effect
+    migrator = NotebookMigrator(ws, Languages(create_autospec(MigrationIndex)))
+    object_info = ObjectInfo(path="root3.run.py.txt", language=Language.PYTHON, object_type=ObjectType.NOTEBOOK)
+    migrator.build_dependency_graph(object_info)
+    assert len(visited) == len(paths)
+
+
+def test_build_dependency_graph_fails_with_unfound_dependency():
+    paths = ["root1.run.py.txt", "leaf1.py.txt", "leaf2.py.txt"]
+    sources: dict[str, str] = dict(zip(paths, _load_sources(Notebook, *paths)))
+
+    # can't remove **kwargs because it receives format=xxx
+    # pylint: disable=unused-argument
+    def download_side_effect(*args, **kwargs):
+        filename = args[0]
+        if filename.startswith('./'):
+            filename = filename[2:]
+        result = create_autospec(BinaryIO)
+        result.__enter__.return_value.read.return_value = sources[filename].encode("utf-8")
+        return result
+
+    ws = create_autospec(WorkspaceClient)
+    ws.workspace.download.side_effect = download_side_effect
+    ws.workspace.list.return_value = []
+    migrator = NotebookMigrator(ws, Languages(create_autospec(MigrationIndex)))
+    object_info = ObjectInfo(path="root1.run.py.txt", language=Language.PYTHON, object_type=ObjectType.NOTEBOOK)
+    with pytest.raises(ValueError):
+        migrator.build_dependency_graph(object_info)
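
Note for reviewers: a rough, self-contained sketch of how the dependency-graph pieces compose end to end; the in-memory sources dict and locator are invented for illustration (NotebookMigrator resolves paths through the WorkspaceClient instead):

    from databricks.sdk.service.workspace import Language
    from databricks.labs.ucx.source_code.notebook import DependencyGraph, Notebook

    sources = {
        "root4.py.txt": '# Databricks notebook source\ndbutils.notebook.run("./leaf3.py.txt")',
        "leaf3.py.txt": '# Databricks notebook source\ndays = 300',
    }

    def locator(path: str) -> Notebook | None:
        path = path[2:] if path.startswith('./') else path
        source = sources.get(path)
        # register_dependency() returns None when the locator yields no notebook
        return None if source is None else Notebook.parse(path, source, Language.PYTHON)

    graph = DependencyGraph("root4.py.txt", None, locator)
    locator("root4.py.txt").build_dependency_graph(graph)
    print(graph.paths)  # root4.py.txt plus the registered ./leaf3.py.txt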