Skip to content

feat: Add MERGE DDL #429

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sqlalchemy_bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
TIME,
TIMESTAMP,
)
from .merge import Merge # noqa

__all__ = [
"ARRAY",
Expand All @@ -56,6 +57,7 @@
"FLOAT64",
"INT64",
"INTEGER",
"Merge",
"NUMERIC",
"RECORD",
"STRING",
Expand Down
231 changes: 231 additions & 0 deletions sqlalchemy_bigquery/merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
# Copyright 2021 Google LLC
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

import textwrap
from copy import copy
from dataclasses import dataclass
from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union

import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.compiler import SQLCompiler
from sqlalchemy.sql.expression import ClauseElement, ColumnElement
from sqlalchemy.sql.selectable import Subquery # noqa
from sqlalchemy.sql.selectable import Alias, Select, TableClause

MergeConditionType = Optional[ColumnElement[sa.Boolean]]


class Merge(ClauseElement):
Copy link

@pykenny pykenny Nov 18, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ClauseElement instance is not executable, so passing a Merge instance to Connection.execute() doesn't work.

updates = {"id": source.id, "name": source.name}
stmt = (
  Merge(target, source, on=target.id == source.id)
  .when_matched()
  .then_update(updates)
  .when_not_matched_by_target()
  .then_insert(updates)
)
connection.execute(stmt) # Ooops

How about inheriting UpdateBase instead?

def __init__(
self,
target: Union[Alias, TableClause],
source: Union[Alias, TableClause, Subquery, Select],
*,
on: ColumnElement[sa.Boolean],
):
super().__init__()
if not (
isinstance(target, TableClause)
or (
isinstance(target, Alias)
and isinstance(getattr(target, "element"), TableClause)
)
):
raise Exception(
"Parameter `target` must be a table, or an aliased table,"
f" instead received:\n{repr(target)}"
)
if not isinstance(target, (Alias, TableClause, Subquery, Select)):
raise Exception(
"Parameter `source` must be a table, alias, subquery, or selectable,"
f" instead received:\n{repr(target)}"
)

self.when = tuple()
self.target = target
self.source = source
self.on = on
self.when: Tuple[WhenBase] = tuple()

def when_matched(self, condition: MergeConditionType = None):
return MergeWhenMatched(self, condition)

def when_matched_not_matched_by_target(self, condition: MergeConditionType = None):
return MergeWhenNotMatchedByTarget(self, condition)

def when_matched_not_matched_by_source(self, condition: MergeConditionType = None):
return MergeWhenNotMatchedBySource(self, condition)
Comment on lines +57 to +61
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't they be when_not_matched_by_target and when_not_matched_by_source? 🤔


def _add_when(self, when: "WhenBase"):
cloned = copy(self)
cloned.when = tuple(cloned.when) + (when,)
return cloned


class ThenBase:
pass


@dataclass(frozen=True)
class ThenUpdate(ThenBase):
fields: Dict[str, ColumnElement[Any]]


@dataclass(frozen=True)
class ThenInsert(ThenBase):
fields: Dict[str, ColumnElement[Any]]


@dataclass(frozen=True)
class ThenDelete(ThenBase):
pass


class WhenBase:
then: ThenBase
condition: MergeConditionType


@dataclass(frozen=True)
class WhenMatched(WhenBase):
then: Union[ThenUpdate, ThenDelete]
condition: MergeConditionType = None


@dataclass(frozen=True)
class WhenNotMatchedByTarget(WhenBase):
then: ThenInsert
condition: MergeConditionType = None


@dataclass(frozen=True)
class WhenNotMatchedBySource(WhenBase):
then: Union[ThenUpdate, ThenDelete]
condition: MergeConditionType = None


class MergeThenUpdateDeleteBase:
merge_cls: ClassVar[Union[Type[WhenMatched], Type[WhenNotMatchedBySource]]]

def __init__(self, stmt: "Merge", condition: MergeConditionType = None) -> None:
super().__init__()
self.stmt = stmt
self.condition = condition

def then_update(self, fields: Dict[str, ColumnElement[Any]]):
return self.stmt._add_when(self.merge_cls(ThenUpdate(fields), self.condition))

def then_delete(self):
return self.stmt._add_when(self.merge_cls(ThenDelete(), self.condition))


class MergeWhenMatched(MergeThenUpdateDeleteBase):
merge_cls = WhenMatched


class MergeWhenNotMatchedBySource(MergeThenUpdateDeleteBase):
merge_cls = WhenNotMatchedBySource


class MergeWhenNotMatchedByTarget:
def __init__(self, stmt: "Merge", condition: MergeConditionType = None) -> None:
super().__init__()
self.stmt = stmt
self.condition = condition

def then_insert(self, fields: Dict[str, ColumnElement[Any]]):
return self.stmt._add_when(
WhenNotMatchedByTarget(ThenInsert(fields), self.condition)
)


@compiles(Merge)
def _compile_merge(self: Merge, compiler: SQLCompiler, **kwargs):
def _compile_select(value):
if isinstance(value, (TableClause, Alias, Subquery)):
select = sa.select([]).select_from(value)
code = compiler.process(select, **kwargs).split("FROM", 1)[1]
else:
code = compiler.process(value, **kwargs)
code = code.strip()
if "\n" not in code:
return " " + code
return "\n" + _indent(code, 2)

target = _compile_select(self.target)
source = _compile_select(self.source)
on = compiler.process(self.on, **kwargs)

code = f"MERGE\n INTO{target}"
code += f"\n USING{source}"
code += f"\n ON {on}"
for when in self.when:
code += "\n" + compiler.process(when, **kwargs)
return code.rstrip() + ";"


@compiles(WhenMatched)
def _compile_when_matched(self: WhenMatched, compiler: SQLCompiler, **kwargs):
return "WHEN MATCHED" + _compile_when(self, compiler, **kwargs)


@compiles(WhenNotMatchedByTarget)
def _compile_when_not_matched_by_target(
self: WhenNotMatchedByTarget, compiler: SQLCompiler, **kwargs
):
return "WHEN NOT MATCHED BY TARGET" + _compile_when(self, compiler, **kwargs)


@compiles(WhenNotMatchedBySource)
def _compile_when_not_matched_by_source(
self: WhenNotMatchedBySource, compiler: SQLCompiler, **kwargs
):
return "WHEN MATCHED NOT MATCHED BY SOURCE" + _compile_when(
Copy link

@pykenny pykenny Sep 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WHEN MATCHED NOT MATCHED BY SOURCE -> WHEN NOT MATCHED BY SOURCE

self, compiler, **kwargs
)


def _compile_when(when: WhenBase, compiler: SQLCompiler, **kwargs):
code = ""
if when.condition is not None:
code += "\n AND "
code += _indent(compiler.process(when.condition, **kwargs)).strip()
code += "\n THEN "
code += _indent(compiler.process(when.then, **kwargs)).strip()
return code


@compiles(ThenUpdate)
def _compile_then_update(self: ThenUpdate, compiler: SQLCompiler, **kwargs):
code = "UPDATE SET"
code += ",".join(
"\n " + _indent(f"{key} = {compiler.process(value, **kwargs)}", 2).strip()
for key, value in self.fields.items()
)
return code


@compiles(ThenInsert)
def _compile_then_insert(self: ThenInsert, compiler: SQLCompiler, **kwargs):
code = "INSERT ("
code += ",".join(f"\n {key}" for key in self.fields.keys())
code += "\n) VALUES ("
code += ",".join(
"\n " + _indent(compiler.process(value, **kwargs), 2).strip()
for value in self.fields.values()
)
code += "\n)"
return code


@compiles(ThenDelete)
def _compile_then_delete(self: ThenDelete, compiler: SQLCompiler, **kwargs):
return "DELETE"


def _indent(text: str, amount: int = 1) -> str:
return textwrap.indent(text, prefix=" " * amount)
107 changes: 107 additions & 0 deletions tests/unit/test_merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright 2021 Google LLC
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

import re

import pytest
import sqlalchemy as sa

from sqlalchemy_bigquery import BigQueryDialect
from sqlalchemy_bigquery.merge import Merge


def test_merge():
target = sa.table("dest", sa.Column("b", sa.TEXT)).alias("a")
source = (
sa.select((sa.column("a") * 2).label("b"))
.select_from(sa.table("world"))
.alias("b")
)

stmt = Merge(target=target, source=source, on=target.c.b == source.c.b)
stmt = stmt.when_matched(source.c.b > 10).then_update(
{"b": target.c.b * source.c.b * sa.literal(123)}
)
stmt = stmt.when_matched_not_matched_by_target().then_insert({"b": source.c.b})
stmt = stmt.when_matched_not_matched_by_target().then_insert({"b": source.c.b})
stmt = stmt.when_matched().then_delete()

stmt_compiled = stmt.compile(
dialect=BigQueryDialect(), compile_kwargs={"literal_binds": True}
)
expected_sql = """
MERGE
INTO `dest` AS `a`
USING
(SELECT `a` * 2 AS `b`
FROM `world`) AS `b`
ON `a`.`b` = `b`.`b`
WHEN MATCHED
AND `b`.`b` > 10
THEN UPDATE SET
b = `a`.`b` * `b`.`b` * 123
WHEN NOT MATCHED BY TARGET
THEN INSERT (
b
) VALUES (
`b`.`b`
)
WHEN NOT MATCHED BY TARGET
THEN INSERT (
b
) VALUES (
`b`.`b`
)
WHEN MATCHED
THEN DELETE;
"""
assert remove_ws(stmt_compiled) == remove_ws(expected_sql)


def test_bad_parameter():
target = sa.table("dest", sa.Column("b", sa.TEXT)).alias("a")
source = (
sa.select((sa.column("a") * 2).label("b"))
.select_from(sa.table("world"))
.alias("b")
)

with pytest.raises(TypeError):
# Maybe we can help the developer prevent this gotchya?
stmt = Merge(target=target, source=source, on=target.c.b == source.c.b)
stmt = stmt.when_matched_not_matched_by_target().then_delete() # type: ignore


def test_then_delete():
target = sa.table("dest", sa.Column("b", sa.TEXT)).alias("a")
source = (
sa.select((sa.column("a") * 2).label("b"))
.select_from(sa.table("world"))
.alias("b")
)

stmt = Merge(target=target, source=source, on=target.c.b == source.c.b)
stmt = stmt.when_matched_not_matched_by_source().then_delete()

stmt_compiled = stmt.compile(
dialect=BigQueryDialect(), compile_kwargs={"literal_binds": True}
)
print(stmt_compiled)
expected_sql = """
MERGE
INTO `dest` AS `a`
USING
(SELECT `a` * 2 AS `b`
FROM `world`) AS `b`
ON `a`.`b` = `b`.`b`
WHEN MATCHED NOT MATCHED BY SOURCE
THEN DELETE;
"""
assert remove_ws(stmt_compiled) == remove_ws(expected_sql)


def remove_ws(text: str) -> str:
return re.sub(r"\s+", " ", str(text)).strip()