diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 037916162c..ce9912a155 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -265,6 +265,30 @@ def _versioned_anyvalue_sql(self: Postgres.Generator, expression: exp.AnyValue) return rename_func("ANY_VALUE")(self, expression) +def _to_decimal(self: Postgres.Generator, expression: exp.Expression) -> exp.Expression: + if not expression.type: + from sqlglot.optimizer.annotate_types import annotate_types + + annotate_types(expression, dialect=self.dialect) + + if expression.type and expression.type == exp.DataType.build("DOUBLE"): + return exp.cast(expression, to=exp.DataType.build("DECIMAL")) + return expression + + +def _round(self: Postgres.Generator, expression: exp.Round) -> str: + # ROUND(double precision, integer) is not permitted in Postgres + # so it's necessary to cast double precision to decimal before rounding. + + this = self.sql(expression, "this") + decimals = self.sql(expression, "decimals") + + if not decimals: + return self.func("ROUND", this) + + return self.func("ROUND", _to_decimal(self, expression.this), decimals) + + class Postgres(Dialect): INDEX_OFFSET = 1 TYPED_DIVISION = True @@ -613,6 +637,7 @@ class Generator(generator.Generator): exp.Rand: rename_func("RANDOM"), exp.RegexpLike: lambda self, e: self.binary(e, "~"), exp.RegexpILike: lambda self, e: self.binary(e, "~*"), + exp.Round: _round, exp.Select: transforms.preprocess( [ transforms.eliminate_semi_and_anti_joins, diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 54e65dd62f..80d0cf6b5d 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -1460,3 +1460,24 @@ def _validate_udt(sql: str): _validate_udt('CAST(5 AS "MySchema"."MyType")') _validate_udt('CAST(5 AS MySchema."MyType")') _validate_udt('CAST(5 AS "MyCatalog"."MySchema"."MyType")') + + def test_round(self): + self.validate_identity("ROUND(x)") + self.validate_identity("ROUND(CAST(x AS DOUBLE PRECISION))") + self.validate_identity("ROUND(CAST(x AS DECIMAL), 4)") + self.validate_identity("ROUND(CAST(x AS INT), 4)") + self.validate_identity( + "ROUND(x, y)", + "ROUND(CAST(x AS DECIMAL), y)", + ) + self.validate_all( + "ROUND(CAST(CAST(x AS DOUBLE PRECISION) AS DECIMAL), 4)", + read={ + "postgres": "ROUND(x::DOUBLE, 4)", + "hive": "ROUND(x::DOUBLE, 4)", + "bigquery": "ROUND(x::DOUBLE, 4)", + }, + ) + self.validate_all( + "ROUND(CAST(x AS DECIMAL(18, 3)), 4)", read={"duckdb": "ROUND(x::DECIMAL, 4)"} + )