Add Command Execution backend which uses Command Execution API on a cluster #95

Merged: 3 commits, May 24, 2024
64 changes: 53 additions & 11 deletions src/databricks/labs/lsql/backends.py
@@ -7,6 +7,7 @@
from types import UnionType
from typing import Any, ClassVar, Protocol, TypeVar

from databricks.labs.blueprint.commands import CommandExecutor
from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import (
BadRequest,
@@ -16,6 +17,7 @@
PermissionDenied,
Unknown,
)
from databricks.sdk.service.compute import Language

from databricks.labs.lsql.core import Row, StatementExecutionExt

@@ -129,21 +131,20 @@ def _api_error_from_message(error_message: str) -> DatabricksError:
return Unknown(error_message)


class StatementExecutionBackend(SqlBackend):
def __init__(self, ws: WorkspaceClient, warehouse_id, *, max_records_per_batch: int = 1000, **kwargs):
self._sql = StatementExecutionExt(ws, warehouse_id=warehouse_id, **kwargs)
class ExecutionBackend(SqlBackend):
"""Abstract base class for Statement & Command Execution backends.
This class defines the save_table method that is used to save data to tables."""

def __init__(self, max_records_per_batch: int = 1000):
self._max_records_per_batch = max_records_per_batch
debug_truncate_bytes = ws.config.debug_truncate_bytes
# while unit-testing, this value will contain a mock
self._debug_truncate_bytes = debug_truncate_bytes if isinstance(debug_truncate_bytes, int) else 96

@abstractmethod
def execute(self, sql: str, *, catalog: str | None = None, schema: str | None = None) -> None:
logger.debug(f"[api][execute] {self._only_n_bytes(sql, self._debug_truncate_bytes)}")
self._sql.execute(sql, catalog=catalog, schema=schema)
raise NotImplementedError

def fetch(self, sql: str, *, catalog: str | None = None, schema: str | None = None) -> Iterator[Row]:
logger.debug(f"[api][fetch] {self._only_n_bytes(sql, self._debug_truncate_bytes)}")
return self._sql.fetch_all(sql, catalog=catalog, schema=schema)
@abstractmethod
def fetch(self, sql: str, *, catalog: str | None = None, schema: str | None = None) -> Iterator[Any]:
raise NotImplementedError

def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode="append"):
rows = self._filter_none_rows(rows, klass)
@@ -183,6 +184,47 @@ def _row_to_sql(row: DataclassInstance, fields: tuple[dataclasses.Field[Any], ...
return ", ".join(data)


class StatementExecutionBackend(ExecutionBackend):
def __init__(self, ws: WorkspaceClient, warehouse_id, *, max_records_per_batch: int = 1000, **kwargs):
self._sql = StatementExecutionExt(ws, warehouse_id=warehouse_id, **kwargs)
debug_truncate_bytes = ws.config.debug_truncate_bytes
# while unit-testing, this value will contain a mock
self._debug_truncate_bytes = debug_truncate_bytes if isinstance(debug_truncate_bytes, int) else 96
super().__init__(max_records_per_batch)

def execute(self, sql: str, *, catalog: str | None = None, schema: str | None = None) -> None:
logger.debug(f"[api][execute] {self._only_n_bytes(sql, self._debug_truncate_bytes)}")
self._sql.execute(sql, catalog=catalog, schema=schema)

def fetch(self, sql: str, *, catalog: str | None = None, schema: str | None = None) -> Iterator[Row]:
logger.debug(f"[api][fetch] {self._only_n_bytes(sql, self._debug_truncate_bytes)}")
return self._sql.fetch_all(sql, catalog=catalog, schema=schema)


class CommandExecutionBackend(ExecutionBackend):
def __init__(self, ws: WorkspaceClient, cluster_id, *, max_records_per_batch: int = 1000):
self._sql = CommandExecutor(ws.clusters, ws.command_execution, lambda: cluster_id, language=Language.SQL)
debug_truncate_bytes = ws.config.debug_truncate_bytes
self._debug_truncate_bytes = debug_truncate_bytes if isinstance(debug_truncate_bytes, int) else 96
super().__init__(max_records_per_batch)

def execute(self, sql: str, *, catalog: str | None = None, schema: str | None = None) -> None:
logger.debug(f"[api][execute] {self._only_n_bytes(sql, self._debug_truncate_bytes)}")
if catalog:
self._sql.run(f"USE CATALOG {catalog}")
if schema:
self._sql.run(f"USE SCHEMA {schema}")
self._sql.run(sql)

def fetch(self, sql: str, *, catalog: str | None = None, schema: str | None = None) -> Iterator[Row]:
logger.debug(f"[api][fetch] {self._only_n_bytes(sql, self._debug_truncate_bytes)}")
if catalog:
self._sql.run(f"USE CATALOG {catalog}")
if schema:
self._sql.run(f"USE SCHEMA{schema}")
return self._sql.run(sql, result_as_json=True)


class _SparkBackend(SqlBackend):
def __init__(self, spark, debug_truncate_bytes):
self._spark = spark
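For context, a minimal usage sketch of the new backend (not part of this diff; the cluster ID, catalog, and schema names below are placeholders, and authentication is assumed to come from the standard SDK config):

from databricks.sdk import WorkspaceClient

from databricks.labs.lsql.backends import CommandExecutionBackend

# WorkspaceClient picks up credentials from the environment or ~/.databrickscfg (assumption).
ws = WorkspaceClient()

# Runs SQL through the Command Execution API on an all-purpose cluster ("0123-456789-abcdef" is a placeholder ID).
backend = CommandExecutionBackend(ws, "0123-456789-abcdef")
backend.execute("CREATE TABLE IF NOT EXISTS foo (id INT) USING DELTA", catalog="main", schema="default")
for row in backend.fetch("SELECT id FROM foo", catalog="main", schema="default"):
    print(row)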
236 changes: 236 additions & 0 deletions tests/unit/test_command_execution_backend.py
@@ -0,0 +1,236 @@
from dataclasses import dataclass
from unittest import mock
from unittest.mock import call, create_autospec

from databricks.sdk import WorkspaceClient
from databricks.sdk.service._internal import Wait
from databricks.sdk.service.compute import (
CommandStatus,
CommandStatusResponse,
ContextStatusResponse,
Language,
Results,
ResultType,
)

from databricks.labs.lsql.backends import CommandExecutionBackend


@dataclass
class Foo:
first: str
second: bool


@dataclass
class Baz:
first: str
second: str | None = None


@dataclass
class Bar:
first: str
second: bool
third: float


def test_command_context_backend_execute_happy():
ws = create_autospec(WorkspaceClient)
ws.command_execution.create.return_value = Wait[ContextStatusResponse](
waiter=lambda callback, timeout: ContextStatusResponse(id="abc")
)
ws.command_execution.execute.return_value = Wait[CommandStatusResponse](
waiter=lambda callback, timeout: CommandStatusResponse(
results=Results(data="success"), status=CommandStatus.FINISHED
)
)

ceb = CommandExecutionBackend(ws, "abc")

ceb.execute("CREATE TABLE foo")

ws.command_execution.execute.assert_called_with(
cluster_id="abc", language=Language.SQL, context_id="abc", command="CREATE TABLE foo"
)


def test_command_context_backend_with_overrides():
ws = create_autospec(WorkspaceClient)
ws.command_execution.create.return_value = Wait[ContextStatusResponse](
waiter=lambda callback, timeout: ContextStatusResponse(id="abc")
)
ws.command_execution.execute.return_value = Wait[CommandStatusResponse](
waiter=lambda callback, timeout: CommandStatusResponse(
results=Results(data="success"), status=CommandStatus.FINISHED
)
)

ceb = CommandExecutionBackend(ws, "abc")

ceb.execute("CREATE TABLE foo", catalog="foo", schema="bar")

ws.command_execution.execute.assert_has_calls(
[
call(cluster_id="abc", language=Language.SQL, context_id="abc", command="USE CATALOG foo"),
call(cluster_id="abc", language=Language.SQL, context_id="abc", command="USE SCHEMA bar"),
call(cluster_id="abc", language=Language.SQL, context_id="abc", command="CREATE TABLE foo"),
]
)


def test_command_context_backend_fetch_happy():
ws = create_autospec(WorkspaceClient)
ws.command_execution.create.return_value = Wait[ContextStatusResponse](
waiter=lambda callback, timeout: ContextStatusResponse(id="abc")
)
ws.command_execution.execute.return_value = Wait[CommandStatusResponse](
waiter=lambda callback, timeout: CommandStatusResponse(
results=Results(
data=[["1"], ["2"], ["3"]],
result_type=ResultType.TABLE,
schema=[{"name": "id", "type": '"int"', "metadata": "{}"}],
),
status=CommandStatus.FINISHED,
)
)

ceb = CommandExecutionBackend(ws, "abc")

result = list(ceb.fetch("SELECT id FROM range(3)"))

assert [["1"], ["2"], ["3"]] == result


def test_command_context_backend_save_table_overwrite_empty_table():
ws = create_autospec(WorkspaceClient)
ws.command_execution.create.return_value = Wait[ContextStatusResponse](
waiter=lambda callback, timeout: ContextStatusResponse(id="abc")
)
ws.command_execution.execute.return_value = Wait[CommandStatusResponse](
waiter=lambda callback, timeout: CommandStatusResponse(
results=Results(data="success"), status=CommandStatus.FINISHED
)
)

ceb = CommandExecutionBackend(ws, "abc")
ceb.save_table("a.b.c", [Baz("1")], Baz, mode="overwrite")

ws.command_execution.execute.assert_has_calls(
[
mock.call(
cluster_id="abc",
language=Language.SQL,
context_id="abc",
command="CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second STRING) USING DELTA",
),
mock.call(
cluster_id="abc",
language=Language.SQL,
context_id="abc",
command="TRUNCATE TABLE a.b.c",
),
mock.call(
cluster_id="abc",
language=Language.SQL,
context_id="abc",
command="INSERT INTO a.b.c (first, second) VALUES ('1', NULL)",
),
]
)


def test_command_context_backend_save_table_empty_records():
ws = create_autospec(WorkspaceClient)
ws.command_execution.create.return_value = Wait[ContextStatusResponse](
waiter=lambda callback, timeout: ContextStatusResponse(id="abc")
)
ws.command_execution.execute.return_value = Wait[CommandStatusResponse](
waiter=lambda callback, timeout: CommandStatusResponse(
results=Results(data="success"), status=CommandStatus.FINISHED
)
)

ceb = CommandExecutionBackend(ws, "abc")

ceb.save_table("a.b.c", [], Bar)

ws.command_execution.execute.assert_called_with(
cluster_id="abc",
language=Language.SQL,
context_id="abc",
command="CREATE TABLE IF NOT EXISTS a.b.c "
"(first STRING NOT NULL, second BOOLEAN NOT NULL, third FLOAT NOT NULL) USING DELTA",
)


def test_command_context_backend_save_table_two_records():
ws = create_autospec(WorkspaceClient)
ws.command_execution.create.return_value = Wait[ContextStatusResponse](
waiter=lambda callback, timeout: ContextStatusResponse(id="abc")
)
ws.command_execution.execute.return_value = Wait[CommandStatusResponse](
waiter=lambda callback, timeout: CommandStatusResponse(
results=Results(data="success"), status=CommandStatus.FINISHED
)
)

ceb = CommandExecutionBackend(ws, "abc")

ceb.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo)

ws.command_execution.execute.assert_has_calls(
[
mock.call(
cluster_id="abc",
language=Language.SQL,
context_id="abc",
command="CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA",
),
mock.call(
cluster_id="abc",
language=Language.SQL,
context_id="abc",
command="INSERT INTO a.b.c (first, second) VALUES ('aaa', TRUE), ('bbb', FALSE)",
),
]
)


def test_command_context_backend_save_table_in_batches_of_two(mocker):
ws = create_autospec(WorkspaceClient)
ws.command_execution.create.return_value = Wait[ContextStatusResponse](
waiter=lambda callback, timeout: ContextStatusResponse(id="abc")
)
ws.command_execution.execute.return_value = Wait[CommandStatusResponse](
waiter=lambda callback, timeout: CommandStatusResponse(
results=Results(data="success"), status=CommandStatus.FINISHED
)
)

ceb = CommandExecutionBackend(ws, "abc", max_records_per_batch=2)

ceb.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False), Foo("ccc", True)], Foo)

ws.command_execution.execute.assert_has_calls(
[
mock.call(
cluster_id="abc",
language=Language.SQL,
context_id="abc",
command="CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA",
),
mock.call(
cluster_id="abc",
language=Language.SQL,
context_id="abc",
command="INSERT INTO a.b.c (first, second) VALUES ('aaa', TRUE), ('bbb', FALSE)",
),
mock.call(
cluster_id="abc",
language=Language.SQL,
context_id="abc",
command="INSERT INTO a.b.c (first, second) VALUES ('ccc', TRUE)",
),
]
)