Skip to content

Commit

Permalink
Feat(snowflake): add support for a a couple of missing clauses in PIV…
Browse files Browse the repository at this point in the history
…OT clause (#3867)

* Feat(snowflake): complete PIVOT parsing

* Fixups

* Remove unused arg

* Typo
  • Loading branch information
georgesittas committed Aug 2, 2024
1 parent 659b8bf commit 734f54b
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 5 deletions.
7 changes: 7 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3859,6 +3859,7 @@ class Pivot(Expression):
"group": False,
"columns": False,
"include_nulls": False,
"default_on_null": False,
}

@property
Expand Down Expand Up @@ -4536,6 +4537,12 @@ class PivotAlias(Alias):
pass


# Represents Snowflake's ANY [ ORDER BY ... ] syntax
# https://docs.snowflake.com/en/sql-reference/constructs/pivot
class PivotAny(Expression):
arg_types = {"this": False}


class Aliases(Expression):
arg_types = {"this": True, "expressions": True}

Expand Down
10 changes: 9 additions & 1 deletion sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ class Generator(metaclass=_Generator):
exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
exp.OutputModelProperty: lambda self, e: f"OUTPUT{self.sql(e, 'this')}",
exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
exp.PivotAny: lambda self, e: f"ANY{self.sql(e, 'this')}",
exp.ProjectionPolicyColumnConstraint: lambda self,
e: f"PROJECTION POLICY {self.sql(e, 'this')}",
exp.RemoteWithConnectionModelProperty: lambda self,
Expand Down Expand Up @@ -1830,13 +1831,20 @@ def pivot_sql(self, expression: exp.Pivot) -> str:
alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else ""
direction = self.seg("UNPIVOT" if expression.unpivot else "PIVOT")

field = self.sql(expression, "field")
if field and isinstance(expression.args.get("field"), exp.PivotAny):
field = f"IN ({field})"

include_nulls = expression.args.get("include_nulls")
if include_nulls is not None:
nulls = " INCLUDE NULLS " if include_nulls else " EXCLUDE NULLS "
else:
nulls = ""
return f"{direction}{nulls}({expressions} FOR {field}){alias}"

default_on_null = self.sql(expression, "default_on_null")
default_on_null = f" DEFAULT ON NULL ({default_on_null})" if default_on_null else ""
return f"{direction}{nulls}({expressions} FOR {field}{default_on_null}){alias}"

def version_sql(self, expression: exp.Version) -> str:
this = f"FOR {expression.name}"
Expand Down
16 changes: 12 additions & 4 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3725,9 +3725,9 @@ def _parse_on() -> t.Optional[exp.Expression]:
exp.Pivot, this=this, expressions=expressions, using=using, group=group
)

def _parse_pivot_in(self) -> exp.In:
def _parse_pivot_in(self) -> exp.In | exp.PivotAny:
def _parse_aliased_expression() -> t.Optional[exp.Expression]:
this = self._parse_assignment()
this = self._parse_select_or_expression()

self._match(TokenType.ALIAS)
alias = self._parse_field()
Expand All @@ -3741,10 +3741,14 @@ def _parse_aliased_expression() -> t.Optional[exp.Expression]:
if not self._match_pair(TokenType.IN, TokenType.L_PAREN):
self.raise_error("Expecting IN (")

aliased_expressions = self._parse_csv(_parse_aliased_expression)
if self._match(TokenType.ANY):
expr: exp.PivotAny | exp.In = self.expression(exp.PivotAny, this=self._parse_order())
else:
aliased_expressions = self._parse_csv(_parse_aliased_expression)
expr = self.expression(exp.In, this=value, expressions=aliased_expressions)

self._match_r_paren()
return self.expression(exp.In, this=value, expressions=aliased_expressions)
return expr

def _parse_pivot(self) -> t.Optional[exp.Pivot]:
index = self._index
Expand Down Expand Up @@ -3781,6 +3785,9 @@ def _parse_pivot(self) -> t.Optional[exp.Pivot]:
self.raise_error("Expecting FOR")

field = self._parse_pivot_in()
default_on_null = self._match_text_seq("DEFAULT", "ON", "NULL") and self._parse_wrapped(
self._parse_bitwise
)

self._match_r_paren()

Expand All @@ -3790,6 +3797,7 @@ def _parse_pivot(self) -> t.Optional[exp.Pivot]:
field=field,
unpivot=unpivot,
include_nulls=include_nulls,
default_on_null=default_on_null,
)

if not self._match_set((TokenType.PIVOT, TokenType.UNPIVOT), advance=False):
Expand Down
12 changes: 12 additions & 0 deletions tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,18 @@ def test_snowflake(self):
self.validate_identity("ALTER TABLE a SWAP WITH b")
self.validate_identity("SELECT MATCH_CONDITION")
self.validate_identity("SELECT * REPLACE (CAST(col AS TEXT) AS scol) FROM t")
self.validate_identity(
"SELECT * FROM quarterly_sales PIVOT(SUM(amount) FOR quarter IN ('2023_Q1', '2023_Q2', '2023_Q3', '2023_Q4', '2024_Q1') DEFAULT ON NULL (0)) ORDER BY empid"
)
self.validate_identity(
"SELECT * FROM quarterly_sales PIVOT(SUM(amount) FOR quarter IN (SELECT DISTINCT quarter FROM ad_campaign_types_by_quarter WHERE television = TRUE ORDER BY quarter)) ORDER BY empid"
)
self.validate_identity(
"SELECT * FROM quarterly_sales PIVOT(SUM(amount) FOR IN (ANY ORDER BY quarter)) ORDER BY empid"
)
self.validate_identity(
"SELECT * FROM quarterly_sales PIVOT(SUM(amount) FOR IN (ANY)) ORDER BY empid"
)
self.validate_identity(
"MERGE INTO my_db AS ids USING (SELECT new_id FROM my_model WHERE NOT col IS NULL) AS new_ids ON ids.type = new_ids.type AND ids.source = new_ids.source WHEN NOT MATCHED THEN INSERT VALUES (new_ids.new_id)"
)
Expand Down

0 comments on commit 734f54b

Please sign in to comment.