diff --git a/superset/sql/dialects/pinot.py b/superset/sql/dialects/pinot.py
index e8804b2ee8ae..05d32f004bfd 100644
--- a/superset/sql/dialects/pinot.py
+++ b/superset/sql/dialects/pinot.py
@@ -24,7 +24,9 @@
 
 from __future__ import annotations
 
+from sqlglot import exp
 from sqlglot.dialects.mysql import MySQL
+from sqlglot.tokens import TokenType
 
 
 class Pinot(MySQL):
@@ -41,3 +43,62 @@ class Tokenizer(MySQL.Tokenizer):
         QUOTES = ["'"]  # Only single quotes for strings
         IDENTIFIERS = ['"', "`"]  # Backticks and double quotes for identifiers
         STRING_ESCAPES = ["'", "\\"]  # Remove double quote from string escapes
+        # Pinot type names, tokenized as their closest sqlglot types
+        KEYWORDS = {
+            **MySQL.Tokenizer.KEYWORDS,
+            "STRING": TokenType.TEXT,
+            "LONG": TokenType.BIGINT,
+            "BYTES": TokenType.VARBINARY,
+        }
+
+    class Generator(MySQL.Generator):
+        # Type names emitted when generating SQL for Pinot
+        TYPE_MAPPING = {
+            **MySQL.Generator.TYPE_MAPPING,
+            exp.DataType.Type.TINYINT: "INT",
+            exp.DataType.Type.SMALLINT: "INT",
+            exp.DataType.Type.INT: "INT",
+            exp.DataType.Type.BIGINT: "LONG",
+            exp.DataType.Type.FLOAT: "FLOAT",
+            exp.DataType.Type.DOUBLE: "DOUBLE",
+            exp.DataType.Type.BOOLEAN: "BOOLEAN",
+            exp.DataType.Type.TIMESTAMP: "TIMESTAMP",
+            exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
+            exp.DataType.Type.VARCHAR: "STRING",
+            exp.DataType.Type.CHAR: "STRING",
+            exp.DataType.Type.TEXT: "STRING",
+            exp.DataType.Type.BINARY: "BYTES",
+            exp.DataType.Type.VARBINARY: "BYTES",
+            exp.DataType.Type.JSON: "JSON",
+        }
+
+        # Override MySQL's CAST_MAPPING - don't convert integer or string types
+        CAST_MAPPING = {
+            exp.DataType.Type.LONGBLOB: exp.DataType.Type.VARBINARY,
+            exp.DataType.Type.MEDIUMBLOB: exp.DataType.Type.VARBINARY,
+            exp.DataType.Type.TINYBLOB: exp.DataType.Type.VARBINARY,
+            exp.DataType.Type.UBIGINT: "UNSIGNED",
+        }
+
+        def datatype_sql(self, expression: exp.DataType) -> str:
+            """
+            Render a data type as SQL, looking its name up in TYPE_MAPPING.
+            """
+            # Don't use MySQL's VARCHAR size requirement logic
+            # Just use TYPE_MAPPING for all types
+            type_value = expression.this
+            type_sql = (
+                self.TYPE_MAPPING.get(type_value, type_value.value)
+                if isinstance(type_value, exp.DataType.Type)
+                else type_value
+            )
+
+            interior = self.expressions(expression, flat=True)
+            nested = f"({interior})" if interior else ""
+
+            if expression.this in self.UNSIGNED_TYPE_MAPPING:
+                # Type parameters go before the UNSIGNED attribute, e.g.
+                # "INT(11) UNSIGNED" rather than "INT UNSIGNED(11)".
+                return f"{type_sql}{nested} UNSIGNED"
+
+            return f"{type_sql}{nested}"
diff --git a/tests/unit_tests/sql/dialects/pinot_tests.py b/tests/unit_tests/sql/dialects/pinot_tests.py
index 4d7eed7154e1..f1a6cfb7298e 100644
--- a/tests/unit_tests/sql/dialects/pinot_tests.py
+++ b/tests/unit_tests/sql/dialects/pinot_tests.py
@@ -346,3 +346,93 @@ def test_distinct() -> None:
 FROM "products"
     """.strip()
     )
+
+
+def test_cast_to_string() -> None:
+    """
+    Test that CAST to STRING is preserved (not converted to CHAR).
+    """
+    sql = "SELECT CAST(cohort_size AS STRING) FROM table"
+    ast = sqlglot.parse_one(sql, Pinot)
+    generated = Pinot().generate(expression=ast)
+
+    assert "STRING" in generated
+    assert "CHAR" not in generated
+
+
+def test_concat_with_cast_string() -> None:
+    """
+    Test CONCAT with CAST to STRING - verifies the original issue is fixed.
+    """
+    sql = """
+SELECT concat(a, cast(b AS string), ' - ')
+FROM "default".c"""
+    ast = sqlglot.parse_one(sql, Pinot)
+    generated = Pinot().generate(expression=ast)
+
+    # Verify STRING type is preserved (not converted to CHAR)
+    assert "string" in generated.lower()
+    assert "CHAR" not in generated
+
+
+@pytest.mark.parametrize(
+    "cast_type, expected_type",
+    [
+        ("INT", "INT"),
+        ("TINYINT", "INT"),
+        ("SMALLINT", "INT"),
+        ("BIGINT", "LONG"),
+        ("LONG", "LONG"),
+        ("FLOAT", "FLOAT"),
+        ("DOUBLE", "DOUBLE"),
+        ("BOOLEAN", "BOOLEAN"),
+        ("TIMESTAMP", "TIMESTAMP"),
+        ("STRING", "STRING"),
+        ("VARCHAR", "STRING"),
+        ("CHAR", "STRING"),
+        ("TEXT", "STRING"),
+        ("BYTES", "BYTES"),
+        ("BINARY", "BYTES"),
+        ("VARBINARY", "BYTES"),
+        ("JSON", "JSON"),
+    ],
+)
+def test_type_mappings(cast_type: str, expected_type: str) -> None:
+    """
+    Test that Pinot type mappings work correctly for all basic types.
+    """
+    sql = f"SELECT CAST(col AS {cast_type}) FROM table"  # noqa: S608
+    ast = sqlglot.parse_one(sql, Pinot)
+    generated = Pinot().generate(expression=ast)
+
+    assert expected_type in generated
+
+
+def test_unsigned_type() -> None:
+    """
+    Test that unsigned integer types are handled correctly.
+    Tests the UNSIGNED_TYPE_MAPPING path in datatype_sql method.
+    """
+    from sqlglot import exp
+
+    # Create a UBIGINT DataType which is in UNSIGNED_TYPE_MAPPING
+    dt = exp.DataType(this=exp.DataType.Type.UBIGINT)
+    result = Pinot.Generator().datatype_sql(dt)
+
+    assert "UNSIGNED" in result
+    assert "BIGINT" in result
+
+
+def test_unsigned_type_with_parameters() -> None:
+    """
+    Test that type parameters are emitted before the UNSIGNED attribute.
+    """
+    from sqlglot import exp
+
+    dt = exp.DataType(
+        this=exp.DataType.Type.UINT,
+        expressions=[exp.DataTypeParam(this=exp.Literal.number(11))],
+    )
+    result = Pinot.Generator().datatype_sql(dt)
+
+    assert result == "INT(11) UNSIGNED"