Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions superset/sql/dialects/pinot.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@

from __future__ import annotations

from sqlglot import exp
from sqlglot.dialects.mysql import MySQL
from sqlglot.tokens import TokenType


class Pinot(MySQL):
Expand All @@ -41,3 +43,55 @@ class Tokenizer(MySQL.Tokenizer):
QUOTES = ["'"] # Only single quotes for strings
IDENTIFIERS = ['"', "`"] # Backticks and double quotes for identifiers
STRING_ESCAPES = ["'", "\\"] # Remove double quote from string escapes
KEYWORDS = {
**MySQL.Tokenizer.KEYWORDS,
"STRING": TokenType.TEXT,
"LONG": TokenType.BIGINT,
"BYTES": TokenType.VARBINARY,
}

class Generator(MySQL.Generator):
    """
    SQL generator for the Pinot dialect, built on the MySQL generator.

    The type tables and ``datatype_sql`` are overridden so that data types
    are rendered through ``TYPE_MAPPING`` rather than MySQL's own rules.
    """

    # Rendered name for each sqlglot data type. Entries listed here take
    # precedence over the ones inherited from the MySQL generator.
    TYPE_MAPPING = {
        **MySQL.Generator.TYPE_MAPPING,
        exp.DataType.Type.TINYINT: "INT",
        exp.DataType.Type.SMALLINT: "INT",
        exp.DataType.Type.INT: "INT",
        exp.DataType.Type.BIGINT: "LONG",
        exp.DataType.Type.FLOAT: "FLOAT",
        exp.DataType.Type.DOUBLE: "DOUBLE",
        exp.DataType.Type.BOOLEAN: "BOOLEAN",
        exp.DataType.Type.TIMESTAMP: "TIMESTAMP",
        exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
        exp.DataType.Type.VARCHAR: "STRING",
        exp.DataType.Type.CHAR: "STRING",
        exp.DataType.Type.TEXT: "STRING",
        exp.DataType.Type.BINARY: "BYTES",
        exp.DataType.Type.VARBINARY: "BYTES",
        exp.DataType.Type.JSON: "JSON",
    }

    # Override MySQL's CAST_MAPPING - don't convert integer or string types
    CAST_MAPPING = {
        exp.DataType.Type.LONGBLOB: exp.DataType.Type.VARBINARY,
        exp.DataType.Type.MEDIUMBLOB: exp.DataType.Type.VARBINARY,
        exp.DataType.Type.TINYBLOB: exp.DataType.Type.VARBINARY,
        exp.DataType.Type.UBIGINT: "UNSIGNED",
    }

    def datatype_sql(self, expression: exp.DataType) -> str:
        """
        Render a data type expression as SQL.

        The type name is looked up in ``TYPE_MAPPING`` (falling back to the
        enum value), any type arguments are appended in parentheses, and
        types listed in ``UNSIGNED_TYPE_MAPPING`` get a trailing
        ``UNSIGNED`` modifier.

        :param expression: the data type node to render
        :returns: the SQL text for the data type
        """
        # Don't use MySQL's VARCHAR size requirement logic
        # Just use TYPE_MAPPING for all types
        type_value = expression.this
        type_sql = (
            self.TYPE_MAPPING.get(type_value, type_value.value)
            if isinstance(type_value, exp.DataType.Type)
            else type_value
        )

        interior = self.expressions(expression, flat=True)
        nested = f"({interior})" if interior else ""
        type_sql = f"{type_sql}{nested}"

        # The UNSIGNED modifier follows the type arguments, e.g.
        # "DECIMAL(10, 2) UNSIGNED" rather than "DECIMAL UNSIGNED(10, 2)".
        if expression.this in self.UNSIGNED_TYPE_MAPPING:
            return f"{type_sql} UNSIGNED"

        return type_sql
75 changes: 75 additions & 0 deletions tests/unit_tests/sql/dialects/pinot_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,78 @@ def test_distinct() -> None:
FROM "products"
""".strip()
)


def test_cast_to_string() -> None:
    """
    Test that a CAST to STRING is generated as STRING and never as CHAR.
    """
    tree = sqlglot.parse_one(
        "SELECT CAST(cohort_size AS STRING) FROM table",
        Pinot,
    )
    output = Pinot().generate(expression=tree)

    assert "STRING" in output
    assert "CHAR" not in output


def test_concat_with_cast_string() -> None:
    """
    Test that a CAST to STRING used as a CONCAT argument keeps its type.
    """
    query = """
SELECT concat(a, cast(b AS string), ' - ')
FROM "default".c"""
    output = Pinot().generate(expression=sqlglot.parse_one(query, Pinot))

    # STRING must survive in any letter case; CHAR must not appear at all
    assert "string" in output.lower()
    assert "CHAR" not in output


@pytest.mark.parametrize(
    "cast_type, expected_type",
    [
        ("INT", "INT"),
        ("TINYINT", "INT"),
        ("SMALLINT", "INT"),
        ("BIGINT", "LONG"),
        ("LONG", "LONG"),
        ("FLOAT", "FLOAT"),
        ("DOUBLE", "DOUBLE"),
        ("BOOLEAN", "BOOLEAN"),
        ("TIMESTAMP", "TIMESTAMP"),
        ("STRING", "STRING"),
        ("VARCHAR", "STRING"),
        ("CHAR", "STRING"),
        ("TEXT", "STRING"),
        ("BYTES", "BYTES"),
        ("BINARY", "BYTES"),
        ("VARBINARY", "BYTES"),
        ("JSON", "JSON"),
    ],
)
def test_type_mappings(cast_type: str, expected_type: str) -> None:
    """
    Test that Pinot type mappings work correctly for all basic types.

    The expected type must appear in the generated SQL as a whole word.
    """
    import re

    sql = f"SELECT CAST(col AS {cast_type}) FROM table"  # noqa: S608
    ast = sqlglot.parse_one(sql, Pinot)
    generated = Pinot().generate(expression=ast)

    # Match whole words only: a plain substring check would accept an
    # unmapped "TINYINT" or "SMALLINT" as evidence of "INT".
    assert re.search(rf"\b{re.escape(expected_type)}\b", generated)


def test_unsigned_type() -> None:
    """
    Test the UNSIGNED_TYPE_MAPPING branch of the datatype_sql method.

    An unsigned integer type must render with both its base type name and
    the UNSIGNED modifier.
    """
    from sqlglot import exp

    # UBIGINT is one of the types listed in UNSIGNED_TYPE_MAPPING
    unsigned_bigint = exp.DataType(this=exp.DataType.Type.UBIGINT)
    rendered = Pinot.Generator().datatype_sql(unsigned_bigint)

    assert "UNSIGNED" in rendered
    assert "BIGINT" in rendered
Loading