diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index e828b9bb9b..0212352d53 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -41,6 +41,21 @@ def _build_datediff(args: t.List) -> exp.Expression: ) +def _build_dateadd(args: t.List) -> exp.Expression: + expression = seq_get(args, 1) + + if len(args) == 2: + # DATE_ADD(startDate, numDays INTEGER) + # https://docs.databricks.com/en/sql/language-manual/functions/date_add.html + return exp.TsOrDsAdd( + this=seq_get(args, 0), expression=expression, unit=exp.Literal.string("DAY") + ) + + # DATE_ADD / DATEADD / TIMESTAMPADD(unit, value integer, expr) + # https://docs.databricks.com/en/sql/language-manual/functions/date_add3.html + return exp.TimestampAdd(this=seq_get(args, 2), expression=expression, unit=seq_get(args, 0)) + + def _normalize_partition(e: exp.Expression) -> exp.Expression: """Normalize the expressions in PARTITION BY (, , ...)""" if isinstance(e, str): @@ -50,6 +65,30 @@ def _normalize_partition(e: exp.Expression) -> exp.Expression: return e +def _dateadd_sql(self: Spark.Generator, expression: exp.TsOrDsAdd | exp.TimestampAdd) -> str: + if not expression.unit or ( + isinstance(expression, exp.TsOrDsAdd) and expression.text("unit").upper() == "DAY" + ): + # Coming from Hive/Spark2 DATE_ADD or roundtripping the 2-arg version of Spark3/DB + return self.func("DATE_ADD", expression.this, expression.expression) + + this = self.func( + "DATE_ADD", + unit_to_var(expression), + expression.expression, + expression.this, + ) + + if isinstance(expression, exp.TsOrDsAdd): + # The 3 arg version of DATE_ADD produces a timestamp in Spark3/DB but possibly not + # in other dialects + return_type = expression.return_type + if not return_type.is_type(exp.DataType.Type.TIMESTAMP, exp.DataType.Type.DATETIME): + this = f"CAST({this} AS {return_type})" + + return this + + class Spark(Spark2): class Tokenizer(Spark2.Tokenizer): RAW_STRINGS = [ @@ -62,6 +101,9 @@ class Parser(Spark2.Parser): FUNCTIONS = { **Spark2.Parser.FUNCTIONS, "ANY_VALUE": _build_with_ignore_nulls(exp.AnyValue), + "DATE_ADD": _build_dateadd, + "DATEADD": _build_dateadd, + "TIMESTAMPADD": _build_dateadd, "DATEDIFF": _build_datediff, "TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"), "TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"), @@ -111,9 +153,8 @@ class Generator(Spark2.Generator): exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}", exp.StartsWith: rename_func("STARTSWITH"), - exp.TimestampAdd: lambda self, e: self.func( - "DATEADD", unit_to_var(e), e.expression, e.this - ), + exp.TsOrDsAdd: _dateadd_sql, + exp.TimestampAdd: _dateadd_sql, exp.TryCast: lambda self, e: ( self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e) ), diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index bfaf0091b5..888071de7e 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -619,9 +619,9 @@ def test_bigquery(self): 'SELECT TIMESTAMP_ADD(TIMESTAMP "2008-12-25 15:30:00+00", INTERVAL 10 MINUTE)', write={ "bigquery": "SELECT TIMESTAMP_ADD(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL 10 MINUTE)", - "databricks": "SELECT DATEADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))", + "databricks": "SELECT DATE_ADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))", "mysql": "SELECT DATE_ADD(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL 10 MINUTE)", - "spark": "SELECT DATEADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))", + "spark": "SELECT DATE_ADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))", }, ) self.validate_all( diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 844fe461d4..6e1830fb92 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -281,6 +281,9 @@ def test_redshift(self): "redshift": "SELECT DATEADD(MONTH, 18, '2008-02-28')", "snowflake": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS TIMESTAMP))", "tsql": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS DATETIME2))", + "spark": "SELECT DATE_ADD(MONTH, 18, '2008-02-28')", + "spark2": "SELECT ADD_MONTHS('2008-02-28', 18)", + "databricks": "SELECT DATE_ADD(MONTH, 18, '2008-02-28')", }, ) self.validate_all( diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index ecc152f249..bff91bf8fd 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -563,6 +563,7 @@ def test_spark(self): "SELECT DATE_ADD(my_date_column, 1)", write={ "spark": "SELECT DATE_ADD(my_date_column, 1)", + "spark2": "SELECT DATE_ADD(my_date_column, 1)", "bigquery": "SELECT DATE_ADD(CAST(CAST(my_date_column AS DATETIME) AS DATE), INTERVAL 1 DAY)", }, ) @@ -675,6 +676,16 @@ def test_spark(self): "spark": "SELECT ARRAY_SORT(x)", }, ) + self.validate_all( + "SELECT DATE_ADD(MONTH, 20, col)", + read={ + "spark": "SELECT TIMESTAMPADD(MONTH, 20, col)", + }, + write={ + "spark": "SELECT DATE_ADD(MONTH, 20, col)", + "databricks": "SELECT DATE_ADD(MONTH, 20, col)", + }, + ) def test_bool_or(self): self.validate_all(