Skip to content

Commit

Permalink
feat(spark, databricks)!: Support for DATE_ADD functions (#3609)
Browse files Browse the repository at this point in the history
* feat(spark, databricks): Support for DATE_ADD functions

* PR Feedback 1

* PR Feedback 2
  • Loading branch information
VaggelisD authored Jun 7, 2024
1 parent 664ae5c commit d6cfb41
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 5 deletions.
47 changes: 44 additions & 3 deletions sqlglot/dialects/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,21 @@ def _build_datediff(args: t.List) -> exp.Expression:
)


def _build_dateadd(args: t.List) -> exp.Expression:
    """Build a date-add expression from parsed args, distinguishing the 2-arg
    and 3-arg Spark3/Databricks forms by argument count."""
    value = seq_get(args, 1)

    if len(args) != 2:
        # DATE_ADD / DATEADD / TIMESTAMPADD(unit, value integer, expr)
        # https://docs.databricks.com/en/sql/language-manual/functions/date_add3.html
        return exp.TimestampAdd(this=seq_get(args, 2), expression=value, unit=seq_get(args, 0))

    # DATE_ADD(startDate, numDays INTEGER)
    # https://docs.databricks.com/en/sql/language-manual/functions/date_add.html
    return exp.TsOrDsAdd(
        this=seq_get(args, 0), expression=value, unit=exp.Literal.string("DAY")
    )


def _normalize_partition(e: exp.Expression) -> exp.Expression:
"""Normalize the expressions in PARTITION BY (<expression>, <expression>, ...)"""
if isinstance(e, str):
Expand All @@ -50,6 +65,30 @@ def _normalize_partition(e: exp.Expression) -> exp.Expression:
return e


def _dateadd_sql(self: Spark.Generator, expression: exp.TsOrDsAdd | exp.TimestampAdd) -> str:
    """Generate Spark3/Databricks DATE_ADD SQL for TsOrDsAdd / TimestampAdd nodes."""
    is_ts_or_ds = isinstance(expression, exp.TsOrDsAdd)

    # Coming from Hive/Spark2 DATE_ADD or roundtripping the 2-arg version of Spark3/DB
    if not expression.unit or (is_ts_or_ds and expression.text("unit").upper() == "DAY"):
        return self.func("DATE_ADD", expression.this, expression.expression)

    sql = self.func(
        "DATE_ADD",
        unit_to_var(expression),
        expression.expression,
        expression.this,
    )

    if is_ts_or_ds:
        # The 3 arg version of DATE_ADD produces a timestamp in Spark3/DB but possibly not
        # in other dialects, so cast back to the node's return type when it differs
        return_type = expression.return_type
        if not return_type.is_type(exp.DataType.Type.TIMESTAMP, exp.DataType.Type.DATETIME):
            sql = f"CAST({sql} AS {return_type})"

    return sql


class Spark(Spark2):
class Tokenizer(Spark2.Tokenizer):
RAW_STRINGS = [
Expand All @@ -62,6 +101,9 @@ class Parser(Spark2.Parser):
FUNCTIONS = {
**Spark2.Parser.FUNCTIONS,
"ANY_VALUE": _build_with_ignore_nulls(exp.AnyValue),
"DATE_ADD": _build_dateadd,
"DATEADD": _build_dateadd,
"TIMESTAMPADD": _build_dateadd,
"DATEDIFF": _build_datediff,
"TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"),
"TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"),
Expand Down Expand Up @@ -111,9 +153,8 @@ class Generator(Spark2.Generator):
exp.PartitionedByProperty: lambda self,
e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}",
exp.StartsWith: rename_func("STARTSWITH"),
exp.TimestampAdd: lambda self, e: self.func(
"DATEADD", unit_to_var(e), e.expression, e.this
),
exp.TsOrDsAdd: _dateadd_sql,
exp.TimestampAdd: _dateadd_sql,
exp.TryCast: lambda self, e: (
self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e)
),
Expand Down
4 changes: 2 additions & 2 deletions tests/dialects/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,9 +619,9 @@ def test_bigquery(self):
'SELECT TIMESTAMP_ADD(TIMESTAMP "2008-12-25 15:30:00+00", INTERVAL 10 MINUTE)',
write={
"bigquery": "SELECT TIMESTAMP_ADD(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL 10 MINUTE)",
"databricks": "SELECT DATEADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
"databricks": "SELECT DATE_ADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
"mysql": "SELECT DATE_ADD(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL 10 MINUTE)",
"spark": "SELECT DATEADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
"spark": "SELECT DATE_ADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
},
)
self.validate_all(
Expand Down
3 changes: 3 additions & 0 deletions tests/dialects/test_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,9 @@ def test_redshift(self):
"redshift": "SELECT DATEADD(MONTH, 18, '2008-02-28')",
"snowflake": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS TIMESTAMP))",
"tsql": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS DATETIME2))",
"spark": "SELECT DATE_ADD(MONTH, 18, '2008-02-28')",
"spark2": "SELECT ADD_MONTHS('2008-02-28', 18)",
"databricks": "SELECT DATE_ADD(MONTH, 18, '2008-02-28')",
},
)
self.validate_all(
Expand Down
11 changes: 11 additions & 0 deletions tests/dialects/test_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,7 @@ def test_spark(self):
"SELECT DATE_ADD(my_date_column, 1)",
write={
"spark": "SELECT DATE_ADD(my_date_column, 1)",
"spark2": "SELECT DATE_ADD(my_date_column, 1)",
"bigquery": "SELECT DATE_ADD(CAST(CAST(my_date_column AS DATETIME) AS DATE), INTERVAL 1 DAY)",
},
)
Expand Down Expand Up @@ -675,6 +676,16 @@ def test_spark(self):
"spark": "SELECT ARRAY_SORT(x)",
},
)
self.validate_all(
"SELECT DATE_ADD(MONTH, 20, col)",
read={
"spark": "SELECT TIMESTAMPADD(MONTH, 20, col)",
},
write={
"spark": "SELECT DATE_ADD(MONTH, 20, col)",
"databricks": "SELECT DATE_ADD(MONTH, 20, col)",
},
)

def test_bool_or(self):
self.validate_all(
Expand Down

0 comments on commit d6cfb41

Please sign in to comment.