Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions superset/sql/dialects/pinot.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,36 @@ class Generator(MySQL.Generator):
e.args.get("start"),
e.args.get("length"),
),
exp.StrPosition: lambda self, e: self.func(
"STRPOS",
e.this,
e.args.get("substr"),
e.args.get("position"),
),
exp.StartsWith: lambda self, e: self.func(
"STARTSWITH",
e.this,
e.args.get("expression"),
),
exp.Chr: lambda self, e: self.func(
"CHR",
*e.args.get("expressions", []),
),
exp.Mod: lambda self, e: self.func(
"MOD",
e.this,
e.args.get("expression"),
),
exp.ArrayAgg: lambda self, e: self.func(
"ARRAY_AGG",
e.this,
),
exp.JSONExtractScalar: lambda self, e: self.func(
"JSON_EXTRACT_SCALAR",
e.this,
e.args.get("expression"),
e.args.get("variant"),
),
}
# Remove DATE_TRUNC transformation - Pinot supports standard SQL DATE_TRUNC
TRANSFORMS.pop(exp.DateTrunc, None)
Expand Down
267 changes: 267 additions & 0 deletions tests/unit_tests/sql/dialects/pinot_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,3 +612,270 @@ def test_substr_cross_dialect_generation() -> None:
mysql_output = parsed.sql(dialect="mysql")
assert "SUBSTRING(" in mysql_output
assert pinot_output != mysql_output # They should be different


@pytest.mark.parametrize(
"function_name,sample_args",
[
# Math functions
("ABS", "-5"),
("CEIL", "3.14"),
("FLOOR", "3.14"),
("EXP", "2"),
("LN", "10"),
("SQRT", "16"),
("ROUNDDECIMAL", "3.14159, 2"),
("ADD", "1, 2, 3"),
("SUB", "10, 3"),
("MULT", "5, 4"),
("MOD", "10, 3"),
# String functions
("UPPER", "'hello'"),
("LOWER", "'HELLO'"),
("REVERSE", "'hello'"),
("SUBSTR", "'hello', 0, 3"),
("CONCAT", "'hello', ' ', 'world'"),
("TRIM", "' hello '"),
("LTRIM", "' hello'"),
("RTRIM", "'hello '"),
("LENGTH", "'hello'"),
("STRPOS", "'hello', 'l', 1"),
("STARTSWITH", "'hello', 'he'"),
("REPLACE", "'hello', 'l', 'r'"),
("RPAD", "'hello', 10, 'x'"),
("LPAD", "'hello', 10, 'x'"),
("CODEPOINT", "'A'"),
("CHR", "65"),
("regexpExtract", "'foo123bar', '[0-9]+'"),
("regexpReplace", "'hello', 'l', 'r'"),
("remove", "'hello', 'l'"),
("urlEncoding", "'hello world'"),
("urlDecoding", "'hello%20world'"),
("fromBase64", "'aGVsbG8='"),
("toUtf8", "'hello'"),
("isSubnetOf", "'192.168.1.1', '192.168.0.0/16'"),
# DateTime functions
("DATETRUNC", "'day', timestamp_col"),
("DATETIMECONVERT", "dt_col, '1:HOURS:EPOCH', '1:DAYS:EPOCH', '1:DAYS'"),
("TIMECONVERT", "timestamp_col, 'MILLISECONDS', 'SECONDS'"),
("NOW", ""),
("AGO", "'P1D'"),
("YEAR", "timestamp_col"),
("QUARTER", "timestamp_col"),
("MONTH", "timestamp_col"),
("WEEK", "timestamp_col"),
("DAY", "timestamp_col"),
("HOUR", "timestamp_col"),
("MINUTE", "timestamp_col"),
("SECOND", "timestamp_col"),
("MILLISECOND", "timestamp_col"),
("DAYOFWEEK", "timestamp_col"),
("DAYOFYEAR", "timestamp_col"),
("YEAROFWEEK", "timestamp_col"),
("toEpochSeconds", "timestamp_col"),
("toEpochMinutes", "timestamp_col"),
("toEpochHours", "timestamp_col"),
("toEpochDays", "timestamp_col"),
("fromEpochSeconds", "1234567890"),
("fromEpochMinutes", "20576131"),
("fromEpochHours", "342935"),
("fromEpochDays", "14288"),
("toDateTime", "timestamp_col, 'yyyy-MM-dd'"),
("fromDateTime", "'2024-01-01', 'yyyy-MM-dd'"),
("timezoneHour", "timestamp_col"),
("timezoneMinute", "timestamp_col"),
("DATE_ADD", "'day', 7, NOW()"),
("DATE_SUB", "'day', 7, NOW()"),
("TIMESTAMPADD", "'day', 7, timestamp_col"),
("TIMESTAMPDIFF", "'day', timestamp1, timestamp2"),
("dateTrunc", "'day', timestamp_col"),
("dateDiff", "'day', timestamp1, timestamp2"),
("dateAdd", "'day', 7, timestamp_col"),
("dateBin", "'day', timestamp_col, NOW()"),
("toIso8601", "timestamp_col"),
("fromIso8601", "'2024-01-01T00:00:00Z'"),
# Aggregation functions
("COUNT", "*"),
("SUM", "amount"),
("AVG", "value"),
("MIN", "value"),
("MAX", "value"),
("DISTINCTCOUNT", "user_id"),
("DISTINCTCOUNTBITMAP", "user_id"),
("DISTINCTCOUNTHLL", "user_id"),
("DISTINCTCOUNTRAWHLL", "user_id"),
("DISTINCTCOUNTHLLPLUS", "user_id"),
("DISTINCTCOUNTRAWHLLPLUS", "user_id"),
("DISTINCTCOUNTSMARTHLL", "user_id"),
("DISTINCTCOUNTCPCSKETCH", "user_id"),
("DISTINCTCOUNTRAWCPCSKETCH", "user_id"),
("DISTINCTCOUNTTHETASKETCH", "user_id"),
("DISTINCTCOUNTRAWTHETASKETCH", "user_id"),
("DISTINCTCOUNTTUPLESKETCH", "user_id"),
("DISTINCTCOUNTRAWINTEGERSUMTUPLESKETCH", "user_id"),
("DISTINCTCOUNTULL", "user_id"),
("DISTINCTCOUNTRAWULL", "user_id"),
("SEGMENTPARTITIONEDDISTINCTCOUNT", "user_id"),
("SUMVALUESINTEGERSUMTUPLESKETCH", "value"),
("PERCENTILE", "value, 95"),
("PERCENTILEEST", "value, 95"),
("PERCENTILETDIGEST", "value, 95"),
("PERCENTILESMARTTDIGEST", "value, 95"),
("PERCENTILEKLL", "value, 95"),
("PERCENTILEKLLRAW", "value, 95"),
("HISTOGRAM", "value, 10"),
("MODE", "category"),
("MINMAXRANGE", "value"),
("SUMPRECISION", "value, 10"),
("ARG_MIN", "value, id"),
("ARG_MAX", "value, id"),
("COVAR_POP", "x, y"),
("COVAR_SAMP", "x, y"),
("LASTWITHTIME", "value, timestamp_col, 'LONG'"),
("FIRSTWITHTIME", "value, timestamp_col, 'LONG'"),
("ARRAY_AGG", "value"),
# Multi-value functions
("COUNTMV", "tags"),
("MAXMV", "scores"),
("MINMV", "scores"),
("SUMMV", "scores"),
("AVGMV", "scores"),
("MINMAXRANGEMV", "scores"),
("PERCENTILEMV", "scores, 95"),
("PERCENTILEESTMV", "scores, 95"),
("PERCENTILETDIGESTMV", "scores, 95"),
("PERCENTILEKLLMV", "scores, 95"),
("DISTINCTCOUNTMV", "tags"),
("DISTINCTCOUNTBITMAPMV", "tags"),
("DISTINCTCOUNTHLLMV", "tags"),
("DISTINCTCOUNTRAWHLLMV", "tags"),
("DISTINCTCOUNTHLLPLUSMV", "tags"),
("DISTINCTCOUNTRAWHLLPLUSMV", "tags"),
("ARRAYLENGTH", "array_col"),
("MAP_VALUE", "map_col, 'key'"),
("VALUEIN", "value, 'val1', 'val2'"),
# JSON functions
("JSONEXTRACTSCALAR", "json_col, '$.name', 'STRING'"),
("JSONEXTRACTKEY", "json_col, '$.data'"),
("JSONFORMAT", "json_col"),
("JSONPATH", "json_col, '$.name'"),
("JSONPATHLONG", "json_col, '$.id'"),
("JSONPATHDOUBLE", "json_col, '$.price'"),
("JSONPATHSTRING", "json_col, '$.name'"),
("JSONPATHARRAY", "json_col, '$.items'"),
("JSONPATHARRAYDEFAULTEMPTY", "json_col, '$.items'"),
("TOJSONMAPSTR", "map_col"),
("JSON_MATCH", "json_col, '\"$.name\"=''value'''"),
("JSON_EXTRACT_SCALAR", "json_col, '$.name', 'STRING'"),
# Array functions
("arrayReverseInt", "int_array"),
("arrayReverseString", "string_array"),
("arraySortInt", "int_array"),
("arraySortString", "string_array"),
("arrayIndexOfInt", "int_array, 5"),
("arrayIndexOfString", "string_array, 'value'"),
("arrayContainsInt", "int_array, 5"),
("arrayContainsString", "string_array, 'value'"),
("arraySliceInt", "int_array, 0, 3"),
("arraySliceString", "string_array, 0, 3"),
("arrayDistinctInt", "int_array"),
("arrayDistinctString", "string_array"),
("arrayRemoveInt", "int_array, 5"),
("arrayRemoveString", "string_array, 'value'"),
("arrayUnionInt", "int_array1, int_array2"),
("arrayUnionString", "string_array1, string_array2"),
("arrayConcatInt", "int_array1, int_array2"),
("arrayConcatString", "string_array1, string_array2"),
("arrayElementAtInt", "int_array, 0"),
("arrayElementAtString", "string_array, 0"),
("arraySumInt", "int_array"),
("arrayValueConstructor", "1, 2, 3"),
("arrayToString", "array_col, ','"),
# Geospatial functions
("ST_DISTANCE", "point1, point2"),
("ST_CONTAINS", "polygon, point"),
("ST_AREA", "polygon"),
("ST_GEOMFROMTEXT", "'POINT(1 2)'"),
("ST_GEOMFROMWKB", "wkb_col"),
("ST_GEOGFROMWKB", "wkb_col"),
("ST_GEOGFROMTEXT", "'POINT(1 2)'"),
("ST_POINT", "1.0, 2.0"),
("ST_POLYGON", "'POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'"),
("ST_ASBINARY", "geom_col"),
("ST_ASTEXT", "geom_col"),
("ST_GEOMETRYTYPE", "geom_col"),
("ST_EQUALS", "geom1, geom2"),
("ST_WITHIN", "geom1, geom2"),
("ST_UNION", "geom1, geom2"),
("ST_GEOMFROMGEOJSON", '\'{"type":"Point","coordinates":[1,2]}\''),
("ST_GEOGFROMGEOJSON", '\'{"type":"Point","coordinates":[1,2]}\''),
("ST_ASGEOJSON", "geom_col"),
("toSphericalGeography", "geom_col"),
("toGeometry", "geog_col"),
# Binary/Hash functions
("SHA", "'hello'"),
("SHA256", "'hello'"),
("SHA512", "'hello'"),
("SHA224", "'hello'"),
("MD5", "'hello'"),
("MD2", "'hello'"),
("toBase64", "'hello'"),
("fromUtf8", "bytes_col"),
("MurmurHash2", "'hello'"),
("MurmurHash3Bit32", "'hello'"),
# Window functions
("ROW_NUMBER", ""),
("RANK", ""),
("DENSE_RANK", ""),
# Funnel analysis
("FunnelMaxStep", "event_col, 'step1', 'step2', 'step3'"),
("FunnelMatchStep", "event_col, 'step1', 'step2', 'step3'"),
("FunnelCompleteCount", "event_col, 'step1', 'step2', 'step3'"),
# Text search
("TEXT_MATCH", "text_col, 'search query'"),
# Vector functions
("VECTOR_SIMILARITY", "vector1, vector2"),
("l2_distance", "vector1, vector2"),
# Lookup
("LOOKUP", "'lookupTable', 'lookupColumn', 'keyColumn', keyValue"),
# URL functions
("urlProtocol", "'https://example.com/path'"),
("urlDomain", "'https://example.com/path'"),
("urlPath", "'https://example.com/path'"),
("urlPort", "'https://example.com:8080/path'"),
("urlEncode", "'hello world'"),
("urlDecode", "'hello%20world'"),
# Conditional
("COALESCE", "val1, val2, 'default'"),
("NULLIF", "val1, val2"),
("GREATEST", "1, 2, 3"),
("LEAST", "1, 2, 3"),
# Other
("REGEXP_LIKE", "'hello', 'h.*'"),
("GROOVY", "'{return arg0 + arg1}', col1, col2"),
],
)
def test_pinot_function_names_preserved(function_name: str, sample_args: str) -> None:
"""
Test that Pinot function names are preserved during parse/generate roundtrip.

This ensures that when we parse Pinot SQL and generate it back, the function
names remain unchanged. This is critical for maintaining compatibility with
Pinot's function library.
"""
# Special handling for window functions
if function_name in ["ROW_NUMBER", "RANK", "DENSE_RANK"]:
sql = f"SELECT {function_name}() OVER (ORDER BY col) FROM table" # noqa: S608
else:
sql = f"SELECT {function_name}({sample_args}) FROM table" # noqa: S608

# Parse with Pinot dialect
parsed = sqlglot.parse_one(sql, Pinot)

# Generate back to Pinot
generated = parsed.sql(dialect=Pinot)

# The function name should be preserved (case-insensitive check)
assert function_name.upper() in generated.upper(), (
f"Function {function_name} not preserved in output: {generated}"
)
Loading