Skip to content

Commit

Permalink
feat(postgres): Support DIV() func for integer division (#3602)
Browse files Browse the repository at this point in the history
* feat(postgres): Support DIV() func for integer division

* Add type builder instead of is_type
  • Loading branch information
VaggelisD authored Jun 7, 2024
1 parent d6cfb41 commit 4b30b87
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
15 changes: 15 additions & 0 deletions sqlglot/dialects/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Dialect,
JSON_EXTRACT_TYPE,
any_value_to_max_sql,
binary_from_function,
bool_xor_sql,
datestrtodate_sql,
build_formatted_time,
Expand Down Expand Up @@ -329,6 +330,7 @@ class Tokenizer(tokens.Tokenizer):
"REGTYPE": TokenType.OBJECT_IDENTIFIER,
"FLOAT": TokenType.DOUBLE,
}
KEYWORDS.pop("DIV")

SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,
Expand All @@ -347,6 +349,9 @@ class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE_TRUNC": build_timestamp_trunc,
"DIV": lambda args: exp.cast(
binary_from_function(exp.IntDiv)(args), exp.DataType.Type.DECIMAL
),
"GENERATE_SERIES": _build_generate_series,
"JSON_EXTRACT_PATH": build_json_extract_path(exp.JSONExtract),
"JSON_EXTRACT_PATH_TEXT": build_json_extract_path(exp.JSONExtractScalar),
Expand Down Expand Up @@ -494,6 +499,7 @@ class Generator(generator.Generator):
exp.DateSub: _date_add_sql("-"),
exp.Explode: rename_func("UNNEST"),
exp.GroupConcat: _string_agg_sql,
exp.IntDiv: rename_func("DIV"),
exp.JSONExtract: _json_extract_sql("JSON_EXTRACT_PATH", "->"),
exp.JSONExtractScalar: _json_extract_sql("JSON_EXTRACT_PATH_TEXT", "->>"),
exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
Expand Down Expand Up @@ -621,3 +627,12 @@ def datatype_sql(self, expression: exp.DataType) -> str:
return f"{self.expressions(expression, flat=True)}[{values}]"
return "ARRAY"
return super().datatype_sql(expression)

def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
this = expression.this

# Postgres casts DIV() to decimal for transpilation but when roundtripping it's superfluous
if isinstance(this, exp.IntDiv) and expression.to == exp.DataType.build("decimal"):
return self.sql(this)

return super().cast_sql(expression, safe_prefix=safe_prefix)
22 changes: 22 additions & 0 deletions tests/dialects/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,28 @@ def test_postgres(self):
self.validate_identity("cast(a as FLOAT8)", "CAST(a AS DOUBLE PRECISION)")
self.validate_identity("cast(a as FLOAT4)", "CAST(a AS REAL)")

self.validate_all(
"1 / DIV(4, 2)",
read={
"postgres": "1 / DIV(4, 2)",
},
write={
"sqlite": "1 / CAST(CAST(CAST(4 AS REAL) / 2 AS INTEGER) AS REAL)",
"duckdb": "1 / CAST(4 // 2 AS DECIMAL)",
"bigquery": "1 / CAST(DIV(4, 2) AS NUMERIC)",
},
)
self.validate_all(
"CAST(DIV(4, 2) AS DECIMAL(5, 3))",
read={
"duckdb": "CAST(4 // 2 AS DECIMAL(5, 3))",
},
write={
"duckdb": "CAST(CAST(4 // 2 AS DECIMAL) AS DECIMAL(5, 3))",
"postgres": "CAST(DIV(4, 2) AS DECIMAL(5, 3))",
},
)

def test_ddl(self):
# Checks that user-defined types are parsed into DataType instead of Identifier
self.parse_one("CREATE TABLE t (a udt)").this.expressions[0].args["kind"].assert_is(
Expand Down

0 comments on commit 4b30b87

Please sign in to comment.