feat(postgres): Support DIV() func for integer division #3602

Merged 2 commits on Jun 7, 2024
15 changes: 15 additions & 0 deletions sqlglot/dialects/postgres.py
@@ -8,6 +8,7 @@
    Dialect,
    JSON_EXTRACT_TYPE,
    any_value_to_max_sql,
    binary_from_function,
    bool_xor_sql,
    datestrtodate_sql,
    build_formatted_time,
@@ -329,6 +330,7 @@ class Tokenizer(tokens.Tokenizer):
            "REGTYPE": TokenType.OBJECT_IDENTIFIER,
            "FLOAT": TokenType.DOUBLE,
        }
        KEYWORDS.pop("DIV")

        SINGLE_TOKENS = {
            **tokens.Tokenizer.SINGLE_TOKENS,
@@ -347,6 +349,9 @@ class Parser(parser.Parser):
        FUNCTIONS = {
            **parser.Parser.FUNCTIONS,
            "DATE_TRUNC": build_timestamp_trunc,
            "DIV": lambda args: exp.cast(
                binary_from_function(exp.IntDiv)(args), exp.DataType.Type.DECIMAL
            ),
            "GENERATE_SERIES": _build_generate_series,
            "JSON_EXTRACT_PATH": build_json_extract_path(exp.JSONExtract),
            "JSON_EXTRACT_PATH_TEXT": build_json_extract_path(exp.JSONExtractScalar),
@@ -494,6 +499,7 @@ class Generator(generator.Generator):
            exp.DateSub: _date_add_sql("-"),
            exp.Explode: rename_func("UNNEST"),
            exp.GroupConcat: _string_agg_sql,
            exp.IntDiv: rename_func("DIV"),
            exp.JSONExtract: _json_extract_sql("JSON_EXTRACT_PATH", "->"),
            exp.JSONExtractScalar: _json_extract_sql("JSON_EXTRACT_PATH_TEXT", "->>"),
            exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
@@ -621,3 +627,12 @@ def datatype_sql(self, expression: exp.DataType) -> str:
                    return f"{self.expressions(expression, flat=True)}[{values}]"
                return "ARRAY"
            return super().datatype_sql(expression)

        def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
            this = expression.this

            # Postgres casts DIV() to decimal for transpilation but when roundtripping it's superfluous
            if isinstance(this, exp.IntDiv) and expression.to.is_type(exp.DataType.Type.DECIMAL):
                return self.sql(this)

            return super().cast_sql(expression, safe_prefix=safe_prefix)
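
The parse-then-unwrap pattern in this diff can be illustrated without sqlglot. The following is only a sketch under stated assumptions: `Lit`, `IntDiv`, `Cast`, `parse_div`, `generate`, and the `"other"` target are hypothetical stand-ins invented for illustration, not sqlglot's classes or API. It shows the idea that the parser wraps DIV() in a DECIMAL cast so other dialects see the result type, while the Postgres generator drops that cast again when roundtripping.

```python
from dataclasses import dataclass
from typing import Union


# Hypothetical toy nodes; the real PR works with sqlglot's exp.IntDiv and exp.Cast.
@dataclass
class Lit:
    value: str


@dataclass
class IntDiv:
    left: "Node"
    right: "Node"


@dataclass
class Cast:
    this: "Node"
    to: str


Node = Union[Lit, IntDiv, Cast]


def parse_div(args):
    """Mimic the parser entry: DIV(a, b) becomes CAST(IntDiv(a, b) AS DECIMAL)."""
    left, right = args
    return Cast(IntDiv(left, right), "DECIMAL")


def generate(node, dialect):
    """Render a node; only the 'postgres' target drops a DECIMAL cast around IntDiv."""
    if isinstance(node, Lit):
        return node.value
    if isinstance(node, IntDiv):
        left = generate(node.left, dialect)
        right = generate(node.right, dialect)
        # The toy 'postgres' target spells integer division as a function,
        # the toy 'other' target as an operator.
        return f"DIV({left}, {right})" if dialect == "postgres" else f"{left} // {right}"
    if isinstance(node, Cast):
        redundant = isinstance(node.this, IntDiv) and node.to == "DECIMAL"
        if dialect == "postgres" and redundant:
            # The cast only exists to carry type information to other dialects.
            return generate(node.this, dialect)
        return f"CAST({generate(node.this, dialect)} AS {node.to})"
    raise TypeError(f"unsupported node: {node!r}")


tree = parse_div([Lit("4"), Lit("2")])
print(generate(tree, "postgres"))  # DIV(4, 2)
print(generate(tree, "other"))     # CAST(4 // 2 AS DECIMAL)
```

In this toy version, casts that do not wrap an `IntDiv` are still emitted for the `"postgres"` target, mirroring the fall-through to `super().cast_sql(...)` in the diff.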
12 changes: 12 additions & 0 deletions tests/dialects/test_postgres.py
@@ -724,6 +724,18 @@ def test_postgres(self):
        self.validate_identity("cast(a as FLOAT8)", "CAST(a AS DOUBLE PRECISION)")
        self.validate_identity("cast(a as FLOAT4)", "CAST(a AS REAL)")

        self.validate_all(
            "1 / DIV(4, 2)",
            read={
                "postgres": "1 / DIV(4, 2)",
            },
            write={
                "sqlite": "1 / CAST(CAST(CAST(4 AS REAL) / 2 AS INTEGER) AS REAL)",

Review comment (Collaborator):

The CAST(4 AS REAL) here looks a bit weird, I don't think it should've been added. The root cause is that the IntDiv generator doesn't properly populate the typed and safe args, disregarding the dialect's semantics. It seems like that method was added a long time ago, so this is most likely an oversight.

Review comment (Collaborator, georgesittas, Jun 7, 2024):

> I don't think it should've been added

This is because Postgres uses typed division: since we're generating a Div as part of IntDiv's generation, the typed division semantics should be respected. See, for example, how we don't add a cast to REAL here:

>>> import sqlglot
>>> sqlglot.transpile("1/2", "postgres", "sqlite")
['1 / 2']

Review comment (Collaborator):

It may be best to address this in a follow-up PR after all. I played around a bit, and it seems we'll need to think more about how to parse and generate IntDiv.

                "duckdb": "1 / CAST(4 // 2 AS DECIMAL)",
                "bigquery": "1 / CAST(DIV(4, 2) AS NUMERIC)",
            },
        )

    def test_ddl(self):
        # Checks that user-defined types are parsed into DataType instead of Identifier
        self.parse_one("CREATE TABLE t (a udt)").this.expressions[0].args["kind"].assert_is(
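
As background for the expected outputs in the test above: the PostgreSQL documentation describes `div(y, x)` as the integer quotient of `y / x`, truncated toward zero and returned as `numeric`. That is presumably why the targets wrap the integer division in a DECIMAL, NUMERIC, or REAL cast, so that `1 / DIV(4, 2)` does not collapse to an integer result. A minimal stdlib sketch of those semantics follows; `pg_div` is a hypothetical helper written for illustration, not part of sqlglot or Postgres.

```python
from decimal import Decimal, ROUND_DOWN


def pg_div(y, x):
    """Hypothetical emulation of Postgres div(y, x) for int or str inputs.

    Returns the quotient truncated toward zero as a Decimal, standing in
    for the numeric result type.
    """
    quotient = Decimal(y) / Decimal(x)
    # ROUND_DOWN rounds toward zero, which is truncation.
    return quotient.to_integral_value(rounding=ROUND_DOWN)


print(pg_div(9, 4))      # 2
print(pg_div(-9, 4))     # -2, truncated toward zero
print(-9 // 4)           # -3, Python's // floors instead
print(1 / pg_div(4, 2))  # 0.5, the outer division is not integer division
```

The negative case is where truncation and flooring differ, so an integer-division operator in a target dialect is only a faithful translation if it also truncates toward zero.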