Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(redshift)!: Normalize time units in their full singular form #3652

Merged
merged 5 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,11 @@ class Dialect(metaclass=_Dialect):
) SELECT c FROM y;
"""

COPY_PARAMS_ARE_CSV = True
"""
Whether COPY statement parameters are separated by commas (True) or by whitespace (False)
"""

# --- Autofilled ---

tokenizer_class = Tokenizer
Expand Down Expand Up @@ -347,6 +352,100 @@ class Dialect(metaclass=_Dialect):
UNICODE_START: t.Optional[str] = None
UNICODE_END: t.Optional[str] = None

# Lookup table from upper-cased date/time part aliases (abbreviations, plural
# forms and alternate spellings) to a single canonical part name.
# `map_date_part` queries it with `part.name.upper()`, so keys must be upper
# case; names that are absent from the table are left untouched by that helper.
# NOTE(review): the "MILLISECON" key and the "MILLENIUM" value look like typos,
# but they are runtime strings — presumably they mirror spellings accepted or
# expected elsewhere; confirm before changing them.
DATE_PART_MAPPING = {
"Y": "YEAR",
"YY": "YEAR",
"YYY": "YEAR",
"YYYY": "YEAR",
"YR": "YEAR",
"YEARS": "YEAR",
"YRS": "YEAR",
"MM": "MONTH",
"MON": "MONTH",
"MONS": "MONTH",
"MONTHS": "MONTH",
"D": "DAY",
"DD": "DAY",
"DAYS": "DAY",
"DAYOFMONTH": "DAY",
"DAY OF WEEK": "DAYOFWEEK",
"WEEKDAY": "DAYOFWEEK",
"DOW": "DAYOFWEEK",
"DW": "DAYOFWEEK",
"WEEKDAY_ISO": "DAYOFWEEKISO",
"DOW_ISO": "DAYOFWEEKISO",
"DW_ISO": "DAYOFWEEKISO",
"DAY OF YEAR": "DAYOFYEAR",
"DOY": "DAYOFYEAR",
"DY": "DAYOFYEAR",
"W": "WEEK",
"WK": "WEEK",
"WEEKOFYEAR": "WEEK",
"WOY": "WEEK",
"WY": "WEEK",
"WEEK_ISO": "WEEKISO",
"WEEKOFYEARISO": "WEEKISO",
"WEEKOFYEAR_ISO": "WEEKISO",
"Q": "QUARTER",
"QTR": "QUARTER",
"QTRS": "QUARTER",
"QUARTERS": "QUARTER",
"H": "HOUR",
"HH": "HOUR",
"HR": "HOUR",
"HOURS": "HOUR",
"HRS": "HOUR",
"M": "MINUTE",
"MI": "MINUTE",
"MIN": "MINUTE",
"MINUTES": "MINUTE",
"MINS": "MINUTE",
"S": "SECOND",
"SEC": "SECOND",
"SECONDS": "SECOND",
"SECS": "SECOND",
"MS": "MILLISECOND",
"MSEC": "MILLISECOND",
"MSECS": "MILLISECOND",
"MSECOND": "MILLISECOND",
"MSECONDS": "MILLISECOND",
"MILLISEC": "MILLISECOND",
"MILLISECS": "MILLISECOND",
"MILLISECON": "MILLISECOND",
"MILLISECONDS": "MILLISECOND",
"US": "MICROSECOND",
"USEC": "MICROSECOND",
"USECS": "MICROSECOND",
"MICROSEC": "MICROSECOND",
"MICROSECS": "MICROSECOND",
"USECOND": "MICROSECOND",
"USECONDS": "MICROSECOND",
"MICROSECONDS": "MICROSECOND",
"NS": "NANOSECOND",
"NSEC": "NANOSECOND",
"NANOSEC": "NANOSECOND",
"NSECOND": "NANOSECOND",
"NSECONDS": "NANOSECOND",
"NANOSECS": "NANOSECOND",
"EPOCH_SECOND": "EPOCH",
"EPOCH_SECONDS": "EPOCH",
"EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
"EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
"EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
"TZH": "TIMEZONE_HOUR",
"TZM": "TIMEZONE_MINUTE",
"DEC": "DECADE",
"DECS": "DECADE",
"DECADES": "DECADE",
"MIL": "MILLENIUM",
"MILS": "MILLENIUM",
"MILLENIA": "MILLENIUM",
"C": "CENTURY",
"CENT": "CENTURY",
"CENTS": "CENTURY",
"CENTURIES": "CENTURY",
}

@classmethod
def get_or_raise(cls, dialect: DialectType) -> Dialect:
"""
Expand Down Expand Up @@ -1062,6 +1161,25 @@ def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[
return exp.Var(this=default) if default else None


@t.overload
def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var: ...


@t.overload
def map_date_part(
    part: t.Optional[exp.Expression], dialect: DialectType = Dialect
) -> t.Optional[exp.Expression]: ...


def map_date_part(part, dialect: DialectType = Dialect):
    """
    Normalize a date/time part expression to its canonical name.

    Args:
        part: The expression naming the date part; may be None.
        dialect: The dialect whose `DATE_PART_MAPPING` is consulted.

    Returns:
        A new `exp.var` holding the canonical name when the upper-cased
        `part.name` is a key of the dialect's `DATE_PART_MAPPING`; otherwise
        `part` itself, unchanged (including when `part` is falsy).
    """
    # Nothing to look up; the dialect is deliberately not resolved in this case.
    if not part:
        return part

    aliases = Dialect.get_or_raise(dialect).DATE_PART_MAPPING
    canonical = aliases.get(part.name.upper())
    if not canonical:
        return part

    return exp.var(canonical)


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
trunc_curr_date = exp.func("date_trunc", "month", expression.this)
plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
Expand Down
7 changes: 6 additions & 1 deletion sqlglot/dialects/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
json_extract_segments,
no_tablesample_sql,
rename_func,
map_date_part,
)
from sqlglot.dialects.postgres import Postgres
from sqlglot.helper import seq_get
Expand All @@ -23,7 +24,11 @@

def _build_date_delta(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
def _builder(args: t.List) -> E:
expr = expr_type(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
expr = expr_type(
this=seq_get(args, 2),
expression=seq_get(args, 1),
unit=map_date_part(seq_get(args, 0)),
)
if expr_type is exp.TsOrDsAdd:
expr.set("return_type", exp.DataType.build("TIMESTAMP"))

Expand Down
99 changes: 6 additions & 93 deletions sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
timestamptrunc_sql,
timestrtotime_sql,
var_map_sql,
map_date_part,
)
from sqlglot.helper import flatten, is_float, is_int, seq_get
from sqlglot.tokens import TokenType
Expand Down Expand Up @@ -75,7 +76,7 @@ def _build_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:

def _build_datediff(args: t.List) -> exp.DateDiff:
return exp.DateDiff(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=_map_date_part(seq_get(args, 0))
this=seq_get(args, 2), expression=seq_get(args, 1), unit=map_date_part(seq_get(args, 0))
)


Expand All @@ -84,7 +85,7 @@ def _builder(args: t.List) -> E:
return expr_type(
this=seq_get(args, 2),
expression=seq_get(args, 1),
unit=_map_date_part(seq_get(args, 0)),
unit=map_date_part(seq_get(args, 0)),
)

return _builder
Expand Down Expand Up @@ -143,97 +144,9 @@ def _parse(self: Snowflake.Parser) -> exp.Show:
return _parse


# Snowflake date/time part aliases, keyed by upper-cased name, mapped to the
# canonical part name. `_map_date_part` looks entries up with
# `part.name.upper()`, so keys must be upper case.
# Note that in this table the bare "EPOCH" canonicalizes to "EPOCH_SECOND".
DATE_PART_MAPPING = {
"Y": "YEAR",
"YY": "YEAR",
"YYY": "YEAR",
"YYYY": "YEAR",
"YR": "YEAR",
"YEARS": "YEAR",
"YRS": "YEAR",
"MM": "MONTH",
"MON": "MONTH",
"MONS": "MONTH",
"MONTHS": "MONTH",
"D": "DAY",
"DD": "DAY",
"DAYS": "DAY",
"DAYOFMONTH": "DAY",
"WEEKDAY": "DAYOFWEEK",
"DOW": "DAYOFWEEK",
"DW": "DAYOFWEEK",
"WEEKDAY_ISO": "DAYOFWEEKISO",
"DOW_ISO": "DAYOFWEEKISO",
"DW_ISO": "DAYOFWEEKISO",
"YEARDAY": "DAYOFYEAR",
"DOY": "DAYOFYEAR",
"DY": "DAYOFYEAR",
"W": "WEEK",
"WK": "WEEK",
"WEEKOFYEAR": "WEEK",
"WOY": "WEEK",
"WY": "WEEK",
"WEEK_ISO": "WEEKISO",
"WEEKOFYEARISO": "WEEKISO",
"WEEKOFYEAR_ISO": "WEEKISO",
"Q": "QUARTER",
"QTR": "QUARTER",
"QTRS": "QUARTER",
"QUARTERS": "QUARTER",
"H": "HOUR",
"HH": "HOUR",
"HR": "HOUR",
"HOURS": "HOUR",
"HRS": "HOUR",
"M": "MINUTE",
"MI": "MINUTE",
"MIN": "MINUTE",
"MINUTES": "MINUTE",
"MINS": "MINUTE",
"S": "SECOND",
"SEC": "SECOND",
"SECONDS": "SECOND",
"SECS": "SECOND",
"MS": "MILLISECOND",
"MSEC": "MILLISECOND",
"MILLISECONDS": "MILLISECOND",
"US": "MICROSECOND",
"USEC": "MICROSECOND",
"MICROSECONDS": "MICROSECOND",
"NS": "NANOSECOND",
"NSEC": "NANOSECOND",
"NANOSEC": "NANOSECOND",
"NSECOND": "NANOSECOND",
"NSECONDS": "NANOSECOND",
"NANOSECS": "NANOSECOND",
"EPOCH": "EPOCH_SECOND",
"EPOCH_SECONDS": "EPOCH_SECOND",
"EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
"EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
"EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
"TZH": "TIMEZONE_HOUR",
"TZM": "TIMEZONE_MINUTE",
}


@t.overload
def _map_date_part(part: exp.Expression) -> exp.Var: ...


@t.overload
def _map_date_part(part: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: ...


def _map_date_part(part):
    """
    Replace a date part alias with its canonical name.

    Looks up the upper-cased `part.name` in the module-level
    `DATE_PART_MAPPING`. Returns a new `exp.var` for the canonical name when
    an entry exists, and `part` unchanged otherwise (also when `part` is falsy).
    """
    if not part:
        return part

    canonical = DATE_PART_MAPPING.get(part.name.upper())
    if canonical:
        return exp.var(canonical)
    return part


def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
trunc = date_trunc_to_time(args)
trunc.set("unit", _map_date_part(trunc.args["unit"]))
trunc.set("unit", map_date_part(trunc.args["unit"]))
return trunc


Expand Down Expand Up @@ -367,7 +280,7 @@ class Parser(parser.Parser):
),
"IFF": exp.If.from_arg_list,
"LAST_DAY": lambda args: exp.LastDay(
this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1))
this=seq_get(args, 0), unit=map_date_part(seq_get(args, 1))
),
"LISTAGG": exp.GroupConcat.from_arg_list,
"MEDIAN": lambda args: exp.PercentileCont(
Expand Down Expand Up @@ -541,7 +454,7 @@ def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]:

self._match(TokenType.COMMA)
expression = self._parse_bitwise()
this = _map_date_part(this)
this = map_date_part(this)
name = this.name.upper()

if name.startswith("EPOCH"):
Expand Down
14 changes: 14 additions & 0 deletions tests/dialects/test_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,12 @@ def test_redshift(self):
"postgres": "COALESCE(a, b, c, d)",
},
)

self.validate_identity(
"DATEDIFF(days, a, b)",
"DATEDIFF(DAY, a, b)",
)

self.validate_all(
"DATEDIFF('day', a, b)",
write={
Expand Down Expand Up @@ -300,6 +306,14 @@ def test_redshift(self):
},
)

self.validate_all(
"SELECT EXTRACT(EPOCH FROM CURRENT_DATE)",
write={
"snowflake": "SELECT DATE_PART(EPOCH, CURRENT_DATE)",
"redshift": "SELECT EXTRACT(EPOCH FROM CURRENT_DATE)",
},
)

def test_identity(self):
self.validate_identity("LISTAGG(DISTINCT foo, ', ')")
self.validate_identity("CREATE MATERIALIZED VIEW orders AUTO REFRESH YES AS SELECT 1")
Expand Down
Loading