Skip to content

Commit

Permalink
Add Command Execution backend which uses Command Execution API on a c…
Browse files Browse the repository at this point in the history
…luster (#95)
  • Loading branch information
nkvuong authored May 24, 2024
1 parent 867c2a8 commit 6db13bc
Show file tree
Hide file tree
Showing 2 changed files with 289 additions and 11 deletions.
64 changes: 53 additions & 11 deletions src/databricks/labs/lsql/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from types import UnionType
from typing import Any, ClassVar, Protocol, TypeVar

from databricks.labs.blueprint.commands import CommandExecutor
from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import (
BadRequest,
Expand All @@ -16,6 +17,7 @@
PermissionDenied,
Unknown,
)
from databricks.sdk.service.compute import Language

from databricks.labs.lsql.core import Row, StatementExecutionExt

Expand Down Expand Up @@ -129,21 +131,20 @@ def _api_error_from_message(error_message: str) -> DatabricksError:
return Unknown(error_message)


class StatementExecutionBackend(SqlBackend):
def __init__(self, ws: WorkspaceClient, warehouse_id, *, max_records_per_batch: int = 1000, **kwargs):
self._sql = StatementExecutionExt(ws, warehouse_id=warehouse_id, **kwargs)
class ExecutionBackend(SqlBackend):
"""Abstract base class for Statement & Command Execution backends.
This class defines the save_table method that is used to save data to tables."""

def __init__(self, max_records_per_batch: int = 1000):
self._max_records_per_batch = max_records_per_batch
debug_truncate_bytes = ws.config.debug_truncate_bytes
# while unit-testing, this value will contain a mock
self._debug_truncate_bytes = debug_truncate_bytes if isinstance(debug_truncate_bytes, int) else 96

@abstractmethod
def execute(self, sql: str, *, catalog: str | None = None, schema: str | None = None) -> None:
logger.debug(f"[api][execute] {self._only_n_bytes(sql, self._debug_truncate_bytes)}")
self._sql.execute(sql, catalog=catalog, schema=schema)
raise NotImplementedError

def fetch(self, sql: str, *, catalog: str | None = None, schema: str | None = None) -> Iterator[Row]:
logger.debug(f"[api][fetch] {self._only_n_bytes(sql, self._debug_truncate_bytes)}")
return self._sql.fetch_all(sql, catalog=catalog, schema=schema)
@abstractmethod
def fetch(self, sql: str, *, catalog: str | None = None, schema: str | None = None) -> Iterator[Any]:
raise NotImplementedError

def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode="append"):
rows = self._filter_none_rows(rows, klass)
Expand Down Expand Up @@ -183,6 +184,47 @@ def _row_to_sql(row: DataclassInstance, fields: tuple[dataclasses.Field[Any], ..
return ", ".join(data)


class StatementExecutionBackend(ExecutionBackend):
    """Execution backend that delegates SQL to a ``StatementExecutionExt`` client.

    The client is created for the supplied workspace client and warehouse id;
    any extra keyword arguments are passed straight through to it.
    """

    def __init__(self, ws: WorkspaceClient, warehouse_id, *, max_records_per_batch: int = 1000, **kwargs):
        """Build the statement execution client and pick the debug-log truncation width."""
        self._sql = StatementExecutionExt(ws, warehouse_id=warehouse_id, **kwargs)
        configured_limit = ws.config.debug_truncate_bytes
        if isinstance(configured_limit, int):
            self._debug_truncate_bytes = configured_limit
        else:
            # while unit-testing, this value will contain a mock
            self._debug_truncate_bytes = 96
        super().__init__(max_records_per_batch)

    def execute(self, sql: str, *, catalog: str | None = None, schema: str | None = None) -> None:
        """Log a truncated copy of ``sql`` and run it, discarding any result."""
        logger.debug(f"[api][execute] {self._only_n_bytes(sql, self._debug_truncate_bytes)}")
        self._sql.execute(sql, catalog=catalog, schema=schema)

    def fetch(self, sql: str, *, catalog: str | None = None, schema: str | None = None) -> Iterator[Row]:
        """Log a truncated copy of ``sql`` and return whatever ``fetch_all`` yields for it."""
        logger.debug(f"[api][fetch] {self._only_n_bytes(sql, self._debug_truncate_bytes)}")
        fetched = self._sql.fetch_all(sql, catalog=catalog, schema=schema)
        return fetched


class CommandExecutionBackend(ExecutionBackend):
    """Execution backend that runs SQL on a cluster through a ``CommandExecutor``.

    The executor is created with ``Language.SQL`` and always resolves to the
    cluster id given at construction time.
    """

    def __init__(self, ws: WorkspaceClient, cluster_id, *, max_records_per_batch: int = 1000):
        """Create a backend bound to ``cluster_id``.

        Args:
            ws: Workspace client providing the ``clusters`` and ``command_execution`` APIs.
            cluster_id: Identifier of the cluster the commands are sent to.
            max_records_per_batch: Passed on to ``ExecutionBackend.__init__``.
        """
        self._sql = CommandExecutor(ws.clusters, ws.command_execution, lambda: cluster_id, language=Language.SQL)
        debug_truncate_bytes = ws.config.debug_truncate_bytes
        # while unit-testing, this value will contain a mock, hence the fallback to 96
        self._debug_truncate_bytes = debug_truncate_bytes if isinstance(debug_truncate_bytes, int) else 96
        super().__init__(max_records_per_batch)

    def _use_catalog_and_schema(self, catalog: str | None, schema: str | None) -> None:
        """Issue ``USE CATALOG`` / ``USE SCHEMA`` for whichever overrides are set.

        Shared by ``execute`` and ``fetch`` so both build the statements the same way.
        """
        if catalog:
            self._sql.run(f"USE CATALOG {catalog}")
        if schema:
            self._sql.run(f"USE SCHEMA {schema}")

    def execute(self, sql: str, *, catalog: str | None = None, schema: str | None = None) -> None:
        """Run ``sql`` on the cluster, discarding any result.

        Args:
            sql: Statement to run.
            catalog: If truthy, ``USE CATALOG <catalog>`` is run first.
            schema: If truthy, ``USE SCHEMA <schema>`` is run before ``sql``.
        """
        logger.debug(f"[api][execute] {self._only_n_bytes(sql, self._debug_truncate_bytes)}")
        self._use_catalog_and_schema(catalog, schema)
        self._sql.run(sql)

    def fetch(self, sql: str, *, catalog: str | None = None, schema: str | None = None) -> Iterator[Row]:
        """Run ``sql`` on the cluster and return the executor's JSON result.

        Args:
            sql: Query to run.
            catalog: If truthy, ``USE CATALOG <catalog>`` is run first.
            schema: If truthy, ``USE SCHEMA <schema>`` is run before ``sql``.

        Returns:
            Whatever ``CommandExecutor.run(sql, result_as_json=True)`` returns.
        """
        logger.debug(f"[api][fetch] {self._only_n_bytes(sql, self._debug_truncate_bytes)}")
        # Previously built "USE SCHEMA{schema}" (no space), producing an invalid statement.
        self._use_catalog_and_schema(catalog, schema)
        return self._sql.run(sql, result_as_json=True)


class _SparkBackend(SqlBackend):
def __init__(self, spark, debug_truncate_bytes):
self._spark = spark
Expand Down
236 changes: 236 additions & 0 deletions tests/unit/test_command_execution_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
from dataclasses import dataclass
from unittest import mock
from unittest.mock import call, create_autospec

from databricks.sdk import WorkspaceClient
from databricks.sdk.service._internal import Wait
from databricks.sdk.service.compute import (
CommandStatus,
CommandStatusResponse,
ContextStatusResponse,
Language,
Results,
ResultType,
)

from databricks.labs.lsql.backends import CommandExecutionBackend


# Two-field record (string + boolean) passed as the row class to save_table
# in the tests below.
@dataclass
class Foo:
    first: str
    second: bool


# Record with one required string and one optional string defaulting to None;
# used by the overwrite test below, which expects the unset field as NULL.
@dataclass
class Baz:
    first: str
    second: str | None = None


# Three-field record (string, boolean, float) used by the empty-records test
# below to check the generated CREATE TABLE column list.
@dataclass
class Bar:
    first: str
    second: bool
    third: float


def test_command_context_backend_execute_happy():
    """``execute`` without overrides: the last command sent is the statement itself."""

    def context_waiter(callback, timeout):
        return ContextStatusResponse(id="abc")

    def command_waiter(callback, timeout):
        return CommandStatusResponse(results=Results(data="success"), status=CommandStatus.FINISHED)

    ws = create_autospec(WorkspaceClient)
    ws.command_execution.create.return_value = Wait[ContextStatusResponse](waiter=context_waiter)
    ws.command_execution.execute.return_value = Wait[CommandStatusResponse](waiter=command_waiter)

    backend = CommandExecutionBackend(ws, "abc")
    backend.execute("CREATE TABLE foo")

    expected_kwargs = {
        "cluster_id": "abc",
        "language": Language.SQL,
        "context_id": "abc",
        "command": "CREATE TABLE foo",
    }
    ws.command_execution.execute.assert_called_with(**expected_kwargs)


def test_command_context_backend_with_overrides():
    """``execute`` with catalog and schema sends the two USE statements before the SQL."""

    def context_waiter(callback, timeout):
        return ContextStatusResponse(id="abc")

    def command_waiter(callback, timeout):
        return CommandStatusResponse(results=Results(data="success"), status=CommandStatus.FINISHED)

    def sql_call(command):
        # every expected call differs only in the command text
        return call(cluster_id="abc", language=Language.SQL, context_id="abc", command=command)

    ws = create_autospec(WorkspaceClient)
    ws.command_execution.create.return_value = Wait[ContextStatusResponse](waiter=context_waiter)
    ws.command_execution.execute.return_value = Wait[CommandStatusResponse](waiter=command_waiter)

    backend = CommandExecutionBackend(ws, "abc")
    backend.execute("CREATE TABLE foo", catalog="foo", schema="bar")

    expected_commands = ("USE CATALOG foo", "USE SCHEMA bar", "CREATE TABLE foo")
    ws.command_execution.execute.assert_has_calls([sql_call(command) for command in expected_commands])


def test_command_context_backend_fetch_happy():
    """``fetch`` hands back the table data carried by the command result."""

    def context_waiter(callback, timeout):
        return ContextStatusResponse(id="abc")

    def command_waiter(callback, timeout):
        table = Results(
            data=[["1"], ["2"], ["3"]],
            result_type=ResultType.TABLE,
            schema=[{"name": "id", "type": '"int"', "metadata": "{}"}],
        )
        return CommandStatusResponse(results=table, status=CommandStatus.FINISHED)

    ws = create_autospec(WorkspaceClient)
    ws.command_execution.create.return_value = Wait[ContextStatusResponse](waiter=context_waiter)
    ws.command_execution.execute.return_value = Wait[CommandStatusResponse](waiter=command_waiter)

    backend = CommandExecutionBackend(ws, "abc")
    fetched = list(backend.fetch("SELECT id FROM range(3)"))

    expected_rows = [["1"], ["2"], ["3"]]
    assert expected_rows == fetched


def test_command_context_backend_save_table_overwrite_empty_table():
    """``save_table`` in overwrite mode issues CREATE, TRUNCATE, then INSERT."""

    def context_waiter(callback, timeout):
        return ContextStatusResponse(id="abc")

    def command_waiter(callback, timeout):
        return CommandStatusResponse(results=Results(data="success"), status=CommandStatus.FINISHED)

    def sql_call(command):
        # every expected call differs only in the command text
        return mock.call(cluster_id="abc", language=Language.SQL, context_id="abc", command=command)

    ws = create_autospec(WorkspaceClient)
    ws.command_execution.create.return_value = Wait[ContextStatusResponse](waiter=context_waiter)
    ws.command_execution.execute.return_value = Wait[CommandStatusResponse](waiter=command_waiter)

    backend = CommandExecutionBackend(ws, "abc")
    backend.save_table("a.b.c", [Baz("1")], Baz, mode="overwrite")

    expected_commands = (
        "CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second STRING) USING DELTA",
        "TRUNCATE TABLE a.b.c",
        "INSERT INTO a.b.c (first, second) VALUES ('1', NULL)",
    )
    ws.command_execution.execute.assert_has_calls([sql_call(command) for command in expected_commands])


def test_command_context_backend_save_table_empty_records():
    """``save_table`` with no rows: the last command sent is the CREATE TABLE statement."""

    def context_waiter(callback, timeout):
        return ContextStatusResponse(id="abc")

    def command_waiter(callback, timeout):
        return CommandStatusResponse(results=Results(data="success"), status=CommandStatus.FINISHED)

    ws = create_autospec(WorkspaceClient)
    ws.command_execution.create.return_value = Wait[ContextStatusResponse](waiter=context_waiter)
    ws.command_execution.execute.return_value = Wait[CommandStatusResponse](waiter=command_waiter)

    backend = CommandExecutionBackend(ws, "abc")
    backend.save_table("a.b.c", [], Bar)

    create_statement = (
        "CREATE TABLE IF NOT EXISTS a.b.c "
        "(first STRING NOT NULL, second BOOLEAN NOT NULL, third FLOAT NOT NULL) USING DELTA"
    )
    ws.command_execution.execute.assert_called_with(
        cluster_id="abc",
        language=Language.SQL,
        context_id="abc",
        command=create_statement,
    )


def test_command_context_backend_save_table_two_records():
    """``save_table`` with two rows issues CREATE followed by one multi-row INSERT."""

    def context_waiter(callback, timeout):
        return ContextStatusResponse(id="abc")

    def command_waiter(callback, timeout):
        return CommandStatusResponse(results=Results(data="success"), status=CommandStatus.FINISHED)

    def sql_call(command):
        # every expected call differs only in the command text
        return mock.call(cluster_id="abc", language=Language.SQL, context_id="abc", command=command)

    ws = create_autospec(WorkspaceClient)
    ws.command_execution.create.return_value = Wait[ContextStatusResponse](waiter=context_waiter)
    ws.command_execution.execute.return_value = Wait[CommandStatusResponse](waiter=command_waiter)

    backend = CommandExecutionBackend(ws, "abc")
    records = [Foo("aaa", True), Foo("bbb", False)]
    backend.save_table("a.b.c", records, Foo)

    expected_commands = (
        "CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA",
        "INSERT INTO a.b.c (first, second) VALUES ('aaa', TRUE), ('bbb', FALSE)",
    )
    ws.command_execution.execute.assert_has_calls([sql_call(command) for command in expected_commands])


def test_command_context_backend_save_table_in_batches_of_two():
    """``save_table`` with ``max_records_per_batch=2`` splits three rows into two INSERTs.

    The unused ``mocker`` fixture argument was dropped: nothing in the test
    referenced it, so it only added a needless dependency on pytest-mock.
    """
    ws = create_autospec(WorkspaceClient)
    ws.command_execution.create.return_value = Wait[ContextStatusResponse](
        waiter=lambda callback, timeout: ContextStatusResponse(id="abc")
    )
    ws.command_execution.execute.return_value = Wait[CommandStatusResponse](
        waiter=lambda callback, timeout: CommandStatusResponse(
            results=Results(data="success"), status=CommandStatus.FINISHED
        )
    )

    ceb = CommandExecutionBackend(ws, "abc", max_records_per_batch=2)

    ceb.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False), Foo("ccc", True)], Foo)

    # CREATE first, then one INSERT per batch of at most two rows
    ws.command_execution.execute.assert_has_calls(
        [
            mock.call(
                cluster_id="abc",
                language=Language.SQL,
                context_id="abc",
                command="CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA",
            ),
            mock.call(
                cluster_id="abc",
                language=Language.SQL,
                context_id="abc",
                command="INSERT INTO a.b.c (first, second) VALUES ('aaa', TRUE), ('bbb', FALSE)",
            ),
            mock.call(
                cluster_id="abc",
                language=Language.SQL,
                context_id="abc",
                command="INSERT INTO a.b.c (first, second) VALUES ('ccc', TRUE)",
            ),
        ]
    )

0 comments on commit 6db13bc

Please sign in to comment.