Skip to content

Commit

Permalink
fix(redshift)!: Normalize time units in their full singular form (#3652)
Browse files Browse the repository at this point in the history
* fix(redshift): Normalize time units in their full singular form

* fix make style

* PR Feedback 1

* Move DATE_PART_MAPPING to Dialect

* Add EPOCH test case
  • Loading branch information
VaggelisD authored Jun 14, 2024
1 parent e8cab58 commit d331e56
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 94 deletions.
118 changes: 118 additions & 0 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,11 @@ class Dialect(metaclass=_Dialect):
) SELECT c FROM y;
"""

COPY_PARAMS_ARE_CSV = True
"""
Whether COPY statement parameters are separated by comma or whitespace
"""

# --- Autofilled ---

tokenizer_class = Tokenizer
Expand Down Expand Up @@ -347,6 +352,100 @@ class Dialect(metaclass=_Dialect):
UNICODE_START: t.Optional[str] = None
UNICODE_END: t.Optional[str] = None

DATE_PART_MAPPING = {
"Y": "YEAR",
"YY": "YEAR",
"YYY": "YEAR",
"YYYY": "YEAR",
"YR": "YEAR",
"YEARS": "YEAR",
"YRS": "YEAR",
"MM": "MONTH",
"MON": "MONTH",
"MONS": "MONTH",
"MONTHS": "MONTH",
"D": "DAY",
"DD": "DAY",
"DAYS": "DAY",
"DAYOFMONTH": "DAY",
"DAY OF WEEK": "DAYOFWEEK",
"WEEKDAY": "DAYOFWEEK",
"DOW": "DAYOFWEEK",
"DW": "DAYOFWEEK",
"WEEKDAY_ISO": "DAYOFWEEKISO",
"DOW_ISO": "DAYOFWEEKISO",
"DW_ISO": "DAYOFWEEKISO",
"DAY OF YEAR": "DAYOFYEAR",
"DOY": "DAYOFYEAR",
"DY": "DAYOFYEAR",
"W": "WEEK",
"WK": "WEEK",
"WEEKOFYEAR": "WEEK",
"WOY": "WEEK",
"WY": "WEEK",
"WEEK_ISO": "WEEKISO",
"WEEKOFYEARISO": "WEEKISO",
"WEEKOFYEAR_ISO": "WEEKISO",
"Q": "QUARTER",
"QTR": "QUARTER",
"QTRS": "QUARTER",
"QUARTERS": "QUARTER",
"H": "HOUR",
"HH": "HOUR",
"HR": "HOUR",
"HOURS": "HOUR",
"HRS": "HOUR",
"M": "MINUTE",
"MI": "MINUTE",
"MIN": "MINUTE",
"MINUTES": "MINUTE",
"MINS": "MINUTE",
"S": "SECOND",
"SEC": "SECOND",
"SECONDS": "SECOND",
"SECS": "SECOND",
"MS": "MILLISECOND",
"MSEC": "MILLISECOND",
"MSECS": "MILLISECOND",
"MSECOND": "MILLISECOND",
"MSECONDS": "MILLISECOND",
"MILLISEC": "MILLISECOND",
"MILLISECS": "MILLISECOND",
"MILLISECON": "MILLISECOND",
"MILLISECONDS": "MILLISECOND",
"US": "MICROSECOND",
"USEC": "MICROSECOND",
"USECS": "MICROSECOND",
"MICROSEC": "MICROSECOND",
"MICROSECS": "MICROSECOND",
"USECOND": "MICROSECOND",
"USECONDS": "MICROSECOND",
"MICROSECONDS": "MICROSECOND",
"NS": "NANOSECOND",
"NSEC": "NANOSECOND",
"NANOSEC": "NANOSECOND",
"NSECOND": "NANOSECOND",
"NSECONDS": "NANOSECOND",
"NANOSECS": "NANOSECOND",
"EPOCH_SECOND": "EPOCH",
"EPOCH_SECONDS": "EPOCH",
"EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
"EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
"EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
"TZH": "TIMEZONE_HOUR",
"TZM": "TIMEZONE_MINUTE",
"DEC": "DECADE",
"DECS": "DECADE",
"DECADES": "DECADE",
"MIL": "MILLENIUM",
"MILS": "MILLENIUM",
"MILLENIA": "MILLENIUM",
"C": "CENTURY",
"CENT": "CENTURY",
"CENTS": "CENTURY",
"CENTURIES": "CENTURY",
}

@classmethod
def get_or_raise(cls, dialect: DialectType) -> Dialect:
"""
Expand Down Expand Up @@ -1062,6 +1161,25 @@ def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[
return exp.Var(this=default) if default else None


@t.overload
def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var:
pass


@t.overload
def map_date_part(
part: t.Optional[exp.Expression], dialect: DialectType = Dialect
) -> t.Optional[exp.Expression]:
pass


def map_date_part(part, dialect: DialectType = Dialect):
mapped = (
Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None
)
return exp.var(mapped) if mapped else part


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
trunc_curr_date = exp.func("date_trunc", "month", expression.this)
plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
Expand Down
7 changes: 6 additions & 1 deletion sqlglot/dialects/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
json_extract_segments,
no_tablesample_sql,
rename_func,
map_date_part,
)
from sqlglot.dialects.postgres import Postgres
from sqlglot.helper import seq_get
Expand All @@ -23,7 +24,11 @@

def _build_date_delta(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
def _builder(args: t.List) -> E:
expr = expr_type(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
expr = expr_type(
this=seq_get(args, 2),
expression=seq_get(args, 1),
unit=map_date_part(seq_get(args, 0)),
)
if expr_type is exp.TsOrDsAdd:
expr.set("return_type", exp.DataType.build("TIMESTAMP"))

Expand Down
99 changes: 6 additions & 93 deletions sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
timestamptrunc_sql,
timestrtotime_sql,
var_map_sql,
map_date_part,
)
from sqlglot.helper import flatten, is_float, is_int, seq_get
from sqlglot.tokens import TokenType
Expand Down Expand Up @@ -75,7 +76,7 @@ def _build_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:

def _build_datediff(args: t.List) -> exp.DateDiff:
return exp.DateDiff(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=_map_date_part(seq_get(args, 0))
this=seq_get(args, 2), expression=seq_get(args, 1), unit=map_date_part(seq_get(args, 0))
)


Expand All @@ -84,7 +85,7 @@ def _builder(args: t.List) -> E:
return expr_type(
this=seq_get(args, 2),
expression=seq_get(args, 1),
unit=_map_date_part(seq_get(args, 0)),
unit=map_date_part(seq_get(args, 0)),
)

return _builder
Expand Down Expand Up @@ -143,97 +144,9 @@ def _parse(self: Snowflake.Parser) -> exp.Show:
return _parse


DATE_PART_MAPPING = {
"Y": "YEAR",
"YY": "YEAR",
"YYY": "YEAR",
"YYYY": "YEAR",
"YR": "YEAR",
"YEARS": "YEAR",
"YRS": "YEAR",
"MM": "MONTH",
"MON": "MONTH",
"MONS": "MONTH",
"MONTHS": "MONTH",
"D": "DAY",
"DD": "DAY",
"DAYS": "DAY",
"DAYOFMONTH": "DAY",
"WEEKDAY": "DAYOFWEEK",
"DOW": "DAYOFWEEK",
"DW": "DAYOFWEEK",
"WEEKDAY_ISO": "DAYOFWEEKISO",
"DOW_ISO": "DAYOFWEEKISO",
"DW_ISO": "DAYOFWEEKISO",
"YEARDAY": "DAYOFYEAR",
"DOY": "DAYOFYEAR",
"DY": "DAYOFYEAR",
"W": "WEEK",
"WK": "WEEK",
"WEEKOFYEAR": "WEEK",
"WOY": "WEEK",
"WY": "WEEK",
"WEEK_ISO": "WEEKISO",
"WEEKOFYEARISO": "WEEKISO",
"WEEKOFYEAR_ISO": "WEEKISO",
"Q": "QUARTER",
"QTR": "QUARTER",
"QTRS": "QUARTER",
"QUARTERS": "QUARTER",
"H": "HOUR",
"HH": "HOUR",
"HR": "HOUR",
"HOURS": "HOUR",
"HRS": "HOUR",
"M": "MINUTE",
"MI": "MINUTE",
"MIN": "MINUTE",
"MINUTES": "MINUTE",
"MINS": "MINUTE",
"S": "SECOND",
"SEC": "SECOND",
"SECONDS": "SECOND",
"SECS": "SECOND",
"MS": "MILLISECOND",
"MSEC": "MILLISECOND",
"MILLISECONDS": "MILLISECOND",
"US": "MICROSECOND",
"USEC": "MICROSECOND",
"MICROSECONDS": "MICROSECOND",
"NS": "NANOSECOND",
"NSEC": "NANOSECOND",
"NANOSEC": "NANOSECOND",
"NSECOND": "NANOSECOND",
"NSECONDS": "NANOSECOND",
"NANOSECS": "NANOSECOND",
"EPOCH": "EPOCH_SECOND",
"EPOCH_SECONDS": "EPOCH_SECOND",
"EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
"EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
"EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
"TZH": "TIMEZONE_HOUR",
"TZM": "TIMEZONE_MINUTE",
}


@t.overload
def _map_date_part(part: exp.Expression) -> exp.Var:
pass


@t.overload
def _map_date_part(part: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
pass


def _map_date_part(part):
mapped = DATE_PART_MAPPING.get(part.name.upper()) if part else None
return exp.var(mapped) if mapped else part


def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
trunc = date_trunc_to_time(args)
trunc.set("unit", _map_date_part(trunc.args["unit"]))
trunc.set("unit", map_date_part(trunc.args["unit"]))
return trunc


Expand Down Expand Up @@ -367,7 +280,7 @@ class Parser(parser.Parser):
),
"IFF": exp.If.from_arg_list,
"LAST_DAY": lambda args: exp.LastDay(
this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1))
this=seq_get(args, 0), unit=map_date_part(seq_get(args, 1))
),
"LISTAGG": exp.GroupConcat.from_arg_list,
"MEDIAN": lambda args: exp.PercentileCont(
Expand Down Expand Up @@ -541,7 +454,7 @@ def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]:

self._match(TokenType.COMMA)
expression = self._parse_bitwise()
this = _map_date_part(this)
this = map_date_part(this)
name = this.name.upper()

if name.startswith("EPOCH"):
Expand Down
14 changes: 14 additions & 0 deletions tests/dialects/test_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,12 @@ def test_redshift(self):
"postgres": "COALESCE(a, b, c, d)",
},
)

self.validate_identity(
"DATEDIFF(days, a, b)",
"DATEDIFF(DAY, a, b)",
)

self.validate_all(
"DATEDIFF('day', a, b)",
write={
Expand Down Expand Up @@ -300,6 +306,14 @@ def test_redshift(self):
},
)

self.validate_all(
"SELECT EXTRACT(EPOCH FROM CURRENT_DATE)",
write={
"snowflake": "SELECT DATE_PART(EPOCH, CURRENT_DATE)",
"redshift": "SELECT EXTRACT(EPOCH FROM CURRENT_DATE)",
},
)

def test_identity(self):
self.validate_identity("LISTAGG(DISTINCT foo, ', ')")
self.validate_identity("CREATE MATERIALIZED VIEW orders AUTO REFRESH YES AS SELECT 1")
Expand Down

0 comments on commit d331e56

Please sign in to comment.