Skip to content

Commit

Permalink
Integrate detection of notebook dependencies (databrickslabs#1338)
Browse files Browse the repository at this point in the history
## Changes
 - integrates the dependency graph into NotebookMigrator

### Linked issues
databrickslabs#1204 
databrickslabs#1286 
databrickslabs#1326

---------

Co-authored-by: Cor <[email protected]>
  • Loading branch information
ericvergnaud and JCZuurmond authored Apr 10, 2024
1 parent fc1747f commit f704ac2
Show file tree
Hide file tree
Showing 8 changed files with 499 additions and 68 deletions.
161 changes: 124 additions & 37 deletions src/databricks/labs/ucx/source_code/notebook.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations # for type hints

import ast
import logging
from abc import ABC, abstractmethod
from ast import parse as parse_python
from collections.abc import Callable
Expand All @@ -9,20 +11,38 @@
from sqlglot import parse as parse_sql
from databricks.sdk.service.workspace import Language

NOTEBOOK_HEADER = " Databricks notebook source"
CELL_SEPARATOR = " COMMAND ----------"
MAGIC_PREFIX = ' MAGIC'
LANGUAGE_PREFIX = ' %'
from databricks.labs.ucx.source_code.python_linter import ASTLinter, PythonLinter


logger = logging.getLogger(__name__)
# use a specific logger for sqlglot warnings so we can disable them selectively
sqlglot_logger = logging.getLogger(f"{__name__}.sqlglot")

NOTEBOOK_HEADER = "Databricks notebook source"
CELL_SEPARATOR = "COMMAND ----------"
MAGIC_PREFIX = 'MAGIC'
LANGUAGE_PREFIX = '%'
LANGUAGE_PI = 'LANGUAGE'
COMMENT_PI = 'COMMENT'


class Cell(ABC):

def __init__(self, source: str):
self._original_code = source
self._migrated_code = source

@property
def original_code(self):
return self._original_code

@property
def migrated_code(self):
return self._original_code # for now since we're not doing any migration yet
return self._migrated_code # for now since we're not doing any migration yet

@migrated_code.setter
def migrated_code(self, value: str):
self._migrated_code = value

@property
@abstractmethod
Expand All @@ -46,15 +66,20 @@ def language(self):

def is_runnable(self) -> bool:
try:
ast = parse_python(self._original_code)
return ast is not None
tree = parse_python(self._original_code)
return tree is not None
except SyntaxError:
return False
return True

def build_dependency_graph(self, parent: DependencyGraph):
# TODO https://github.com/databrickslabs/ucx/issues/1200
# TODO https://github.com/databrickslabs/ucx/issues/1202
pass
linter = ASTLinter.parse(self._original_code)
nodes = linter.locate(ast.Call, [("run", ast.Attribute), ("notebook", ast.Attribute), ("dbutils", ast.Name)])
for node in nodes:
assert isinstance(node, ast.Call)
path = PythonLinter.get_dbutils_notebook_run_path_arg(node)
if isinstance(path, ast.Constant):
parent.register_dependency(path.value.strip("'").strip('"'))


class RCell(Cell):
Expand Down Expand Up @@ -93,8 +118,9 @@ def is_runnable(self) -> bool:
try:
statements = parse_sql(self._original_code)
return len(statements) > 0
except SQLParseError:
return False
except SQLParseError as e:
sqlglot_logger.warning(f"Failed to parse SQL using 'sqlglot': {self._original_code}", exc_info=e)
return True

def build_dependency_graph(self, parent: DependencyGraph):
pass # not in scope
Expand Down Expand Up @@ -123,7 +149,7 @@ def is_runnable(self) -> bool:
return True # TODO

def build_dependency_graph(self, parent: DependencyGraph):
command = f'{LANGUAGE_PREFIX}{self.language.magic_name}'.strip()
command = f'{LANGUAGE_PREFIX}{self.language.magic_name}'
lines = self._original_code.split('\n')
for line in lines:
start = line.index(command)
Expand All @@ -133,22 +159,28 @@ def build_dependency_graph(self, parent: DependencyGraph):
return
raise ValueError("Missing notebook path in %run command")

def migrate_notebook_path(self):
    # Placeholder: path migration for %run cells is not implemented yet, so
    # this always returns None (falsy) — a caller that treats the result as
    # a "changed" flag will never see True until this is filled in.
    pass


class CellLanguage(Enum):
# long magic_names must come first to avoid shorter ones being matched
PYTHON = Language.PYTHON, 'python', '#', PythonCell
SCALA = Language.SCALA, 'scala', '//', ScalaCell
SQL = Language.SQL, 'sql', '--', SQLCell
RUN = None, 'run', None, RunCell
MARKDOWN = None, 'md', None, MarkdownCell
R = Language.R, 'r', '#', RCell
PYTHON = Language.PYTHON, 'python', '#', True, PythonCell
SCALA = Language.SCALA, 'scala', '//', True, ScalaCell
SQL = Language.SQL, 'sql', '--', True, SQLCell
RUN = None, 'run', '', False, RunCell
# see https://spec.commonmark.org/0.31.2/#html-comment
MARKDOWN = None, 'md', "<!--->", False, MarkdownCell
R = Language.R, 'r', '#', True, RCell

def __init__(self, *args):
super().__init__()
self._language = args[0]
self._magic_name = args[1]
self._comment_prefix = args[2]
self._new_cell = args[3]
# PI stands for Processing Instruction
self._requires_isolated_pi = args[3]
self._new_cell = args[4]

@property
def language(self) -> Language:
Expand All @@ -162,6 +194,10 @@ def magic_name(self) -> str:
def comment_prefix(self) -> str:
return self._comment_prefix

@property
def requires_isolated_pi(self) -> str:
return self._requires_isolated_pi

@classmethod
def of_language(cls, language: Language) -> CellLanguage:
return next((cl for cl in CellLanguage if cl.language == language))
Expand All @@ -171,7 +207,7 @@ def of_magic_name(cls, magic_name: str) -> CellLanguage | None:
return next((cl for cl in CellLanguage if magic_name.startswith(cl.magic_name)), None)

def read_cell_language(self, lines: list[str]) -> CellLanguage | None:
magic_prefix = f'{self.comment_prefix}{MAGIC_PREFIX}'
magic_prefix = f'{self.comment_prefix} {MAGIC_PREFIX} '
magic_language_prefix = f'{magic_prefix}{LANGUAGE_PREFIX}'
for line in lines:
# if we find a non-comment then we're done
Expand All @@ -191,26 +227,28 @@ def new_cell(self, source: str) -> Cell:

def extract_cells(self, source: str) -> list[Cell] | None:
lines = source.split('\n')
header = f"{self.comment_prefix}{NOTEBOOK_HEADER}"
header = f"{self.comment_prefix} {NOTEBOOK_HEADER}"
if not lines[0].startswith(header):
raise ValueError("Not a Databricks notebook source!")

def make_cell(lines_: list[str]):
def make_cell(cell_lines: list[str]):
# trim leading blank lines
while len(lines_) > 0 and len(lines_[0]) == 0:
lines_.pop(0)
while len(cell_lines) > 0 and len(cell_lines[0]) == 0:
cell_lines.pop(0)
# trim trailing blank lines
while len(lines_) > 0 and len(lines_[-1]) == 0:
lines_.pop(-1)
cell_language = self.read_cell_language(lines_)
while len(cell_lines) > 0 and len(cell_lines[-1]) == 0:
cell_lines.pop(-1)
cell_language = self.read_cell_language(cell_lines)
if cell_language is None:
cell_language = self
cell_source = '\n'.join(lines_)
else:
self._remove_magic_wrapper(cell_lines, cell_language)
cell_source = '\n'.join(cell_lines)
return cell_language.new_cell(cell_source)

cells = []
cell_lines: list[str] = []
separator = f"{self.comment_prefix}{CELL_SEPARATOR}"
separator = f"{self.comment_prefix} {CELL_SEPARATOR}"
for i in range(1, len(lines)):
line = lines[i].strip()
if line.startswith(separator):
Expand All @@ -225,6 +263,39 @@ def make_cell(lines_: list[str]):

return cells

def _remove_magic_wrapper(self, lines: list[str], cell_language: CellLanguage):
    """Strip the notebook MAGIC wrapping from *lines*, mutating them in place.

    A line wrapped as ``<comment-prefix> MAGIC <text>`` loses that prefix; if
    the unwrapped text starts with a ``%`` language directive and the cell
    language requires an isolated processing instruction, it is replaced by a
    LANGUAGE processing-instruction comment. Remaining plain comment lines are
    tagged with a COMMENT processing instruction so they can be restored later.
    """
    magic_marker = f"{self.comment_prefix} {MAGIC_PREFIX} "
    marker_len = len(magic_marker)
    for index in range(len(lines)):
        current = lines[index]
        if current.startswith(magic_marker):
            unwrapped = current[marker_len:]
            if cell_language.requires_isolated_pi and unwrapped.startswith(LANGUAGE_PREFIX):
                unwrapped = f"{cell_language.comment_prefix} {LANGUAGE_PI}"
            lines[index] = unwrapped
        elif current.startswith(self.comment_prefix):
            lines[index] = f"{cell_language.comment_prefix} {COMMENT_PI}{current}"

def wrap_with_magic(self, code: str, cell_language: CellLanguage) -> str:
    """Return *code* re-wrapped as a MAGIC cell of this (default) language.

    Inverse of the unwrapping step: a LANGUAGE processing instruction becomes
    the ``%<language>`` magic line, a COMMENT processing instruction is
    unwrapped back to the plain comment it tagged, and every other line gains
    the ``<comment-prefix> MAGIC `` prefix.
    """
    language_pi = f"{cell_language.comment_prefix} {LANGUAGE_PI}"
    comment_pi = f"{cell_language.comment_prefix} {COMMENT_PI}"
    wrapped: list[str] = []
    for source_line in code.split('\n'):
        if source_line.startswith(language_pi):
            wrapped.append(f"{self.comment_prefix} {MAGIC_PREFIX} {LANGUAGE_PREFIX}{cell_language.magic_name}")
        elif source_line.startswith(comment_pi):
            wrapped.append(source_line[len(comment_pi):])
        else:
            wrapped.append(f"{self.comment_prefix} {MAGIC_PREFIX} {source_line}")
    # NOTE(review): the './' suffix test looks suspicious (perhaps '\n' was
    # intended) — preserved as-is; confirm intent against callers.
    if code.endswith('./'):
        wrapped.append('\n')
    return "\n".join(wrapped)


class DependencyGraph:

Expand All @@ -238,7 +309,7 @@ def __init__(self, path: str, parent: DependencyGraph | None, locator: Callable[
def path(self):
return self._path

def register_dependency(self, path: str) -> DependencyGraph:
def register_dependency(self, path: str) -> DependencyGraph | None:
# already registered ?
child_graph = self.locate_dependency(path)
if child_graph is not None:
Expand All @@ -248,6 +319,8 @@ def register_dependency(self, path: str) -> DependencyGraph:
child_graph = DependencyGraph(path, self, self._locator)
self._dependencies[path] = child_graph
notebook = self._locator(path)
if not notebook:
return None
notebook.build_dependency_graph(child_graph)
return child_graph

Expand Down Expand Up @@ -296,30 +369,44 @@ def visit(self, visit_node: Callable[[DependencyGraph], bool | None]) -> bool |
class Notebook:

@staticmethod
def parse(path: str, source: str, default_language: Language) -> Notebook | None:
def parse(path: str, source: str, default_language: Language) -> Notebook:
default_cell_language = CellLanguage.of_language(default_language)
cells = default_cell_language.extract_cells(source)
return None if cells is None else Notebook(path, default_language, cells, source.endswith('\n'))
if cells is None:
raise ValueError(f"Could not parse Notebook: {path}")
return Notebook(path, source, default_language, cells, source.endswith('\n'))

def __init__(self, path: str, language: Language, cells: list[Cell], ends_with_lf):
def __init__(self, path: str, source: str, language: Language, cells: list[Cell], ends_with_lf):
self._path = path
self._source = source
self._language = language
self._cells = cells
self._ends_with_lf = ends_with_lf

@property
def path(self) -> str:
    """Workspace path this notebook was loaded from."""
    return self._path

@property
def cells(self) -> list[Cell]:
    """The parsed cells of this notebook, in document order."""
    return self._cells

@property
def original_code(self) -> str:
    """The unmodified source the notebook was parsed from."""
    return self._source

def to_migrated_code(self):
default_language = CellLanguage.of_language(self._language)
header = f"{default_language.comment_prefix}{NOTEBOOK_HEADER}"
header = f"{default_language.comment_prefix} {NOTEBOOK_HEADER}"
sources = [header]
for i, cell in enumerate(self._cells):
sources.append(cell.migrated_code)
migrated_code = cell.migrated_code
if cell.language is not default_language:
migrated_code = default_language.wrap_with_magic(migrated_code, cell.language)
sources.append(migrated_code)
if i < len(self._cells) - 1:
sources.append('')
sources.append(f'{default_language.comment_prefix}{CELL_SEPARATOR}')
sources.append(f'{default_language.comment_prefix} {CELL_SEPARATOR}')
sources.append('')
if self._ends_with_lf:
sources.append('') # following join will append lf
Expand Down
70 changes: 58 additions & 12 deletions src/databricks/labs/ucx/source_code/notebook_migrator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.workspace import ExportFormat, ObjectInfo
from databricks.sdk.service.workspace import ExportFormat, ObjectInfo, ObjectType

from databricks.labs.ucx.source_code.languages import Languages
from databricks.labs.ucx.source_code.notebook import DependencyGraph, Notebook, RunCell


class NotebookMigrator:
Expand All @@ -14,19 +15,64 @@ def revert(self, object_info: ObjectInfo):
return False
with self._ws.workspace.download(object_info.path + ".bak", format=ExportFormat.SOURCE) as f:
code = f.read().decode("utf-8")
self._ws.workspace.upload(object_info.path, code.encode("utf-8"))
self._ws.workspace.upload(object_info.path, code.encode("utf-8"))
return True

def apply(self, object_info: ObjectInfo) -> bool:
if not object_info.language or not object_info.path:
return False
if not self._languages.is_supported(object_info.language):
if not object_info.path or not object_info.language or object_info.object_type is not ObjectType.NOTEBOOK:
return False
notebook = self._load_notebook(object_info)
return self._apply(notebook)

def build_dependency_graph(self, object_info: ObjectInfo) -> DependencyGraph:
    """Build and return the dependency graph rooted at the given notebook.

    Raises ValueError when *object_info* lacks a path or a language, or does
    not describe a workspace notebook.
    """
    is_notebook = object_info.object_type is ObjectType.NOTEBOOK
    if not (object_info.path and object_info.language and is_notebook):
        raise ValueError("Not a valid Notebook")
    root = DependencyGraph(object_info.path, None, self._load_notebook_from_path)
    self._load_notebook(object_info).build_dependency_graph(root)
    return root

def _apply(self, notebook: Notebook) -> bool:
    """Apply language fixers to each cell; upload the result when changed.

    Returns True when at least one cell was modified, in which case the
    original source is first preserved as ``<path>.bak`` before the migrated
    code is uploaded over the notebook.
    """
    changed = False
    for cell in notebook.cells:
        # %run is not a supported language, so it must be handled first
        if isinstance(cell, RunCell):
            # TODO data on what to change to ?
            changed = bool(cell.migrate_notebook_path()) or changed
            continue
        cell_language = cell.language.language
        if not self._languages.is_supported(cell_language):
            continue
        fixed = self._languages.apply_fixes(cell_language, cell.original_code)
        if fixed != cell.original_code:
            cell.migrated_code = fixed
            changed = True
    if not changed:
        return False
    # keep a backup of the original before overwriting the notebook
    self._ws.workspace.upload(notebook.path + ".bak", notebook.original_code.encode("utf-8"))
    self._ws.workspace.upload(notebook.path, notebook.to_migrated_code().encode("utf-8"))
    # TODO https://github.com/databrickslabs/ucx/issues/1327 store 'migrated' status
    return True

def _load_notebook_from_path(self, path: str) -> Notebook:
    """Resolve *path* in the workspace and load it as a Notebook.

    Raises ValueError when the object exists but is not a notebook.
    """
    info = self._load_object(path)
    if info.object_type is ObjectType.NOTEBOOK:
        return self._load_notebook(info)
    raise ValueError(f"Not a Notebook: {path}")

def _load_object(self, path: str) -> ObjectInfo:
    """Return the first workspace object listed at *path*.

    Raises ValueError when the listing yields no object at all.
    """
    listing = self._ws.workspace.list(path)
    first = next(iter(listing), None)
    if first is None:
        raise ValueError(f"Could not locate object at '{path}'")
    return first

def _load_notebook(self, object_info: ObjectInfo) -> Notebook:
    """Download and parse the notebook described by *object_info*.

    Validates its arguments with an explicit ValueError rather than `assert`:
    asserts are stripped under ``python -O``, which would let an invalid
    object through unchecked. The message matches the sibling validation in
    `_load_source` for consistency.
    """
    if object_info is None or object_info.path is None or object_info.language is None:
        raise ValueError(f"Invalid ObjectInfo: {object_info}")
    source = self._load_source(object_info)
    return Notebook.parse(object_info.path, source, object_info.language)

def _load_source(self, object_info: ObjectInfo) -> str:
if not object_info.language or not object_info.path:
raise ValueError(f"Invalid ObjectInfo: {object_info}")
with self._ws.workspace.download(object_info.path, format=ExportFormat.SOURCE) as f:
original_code = f.read().decode("utf-8")
new_code = self._languages.apply_fixes(object_info.language, original_code)
if new_code == original_code:
return False
self._ws.workspace.upload(object_info.path + ".bak", original_code.encode("utf-8"))
self._ws.workspace.upload(object_info.path, new_code.encode("utf-8"))
return True
return f.read().decode("utf-8")
Loading

0 comments on commit f704ac2

Please sign in to comment.