Skip to content

Commit 9c0147b

Browse files
committed
Add support for PostgreSQL-style shorthand casts
Standard SQL supports static method calls on types via the :: operator. While this syntax is generally incompatible with PostgreSQL shorthand cast syntax, there's a subset of that can be safely repurposed to support that functionality. The SQL specification defines static method calls as: <static method invocation> ::= <path-resolved user-defined type name> <double colon> <method name> [ <SQL argument list> ] where <path-resolved user-defined type name> translates to: <user-defined type name> ::= [ <schema name> <period> ] <qualified identifier> To support casts, we need to extend the rule to support arbitrary expressions as the target of the invocation. To disambiguate a static method call from a cast, we distinguish between type-producing expressions and expressions that produce regular values. For the latter, if the method matches the name of a well-known type, we treat it as a cast. Otherwise, the expression assumed to be a static method call and fail the evaluation with a "not yet supported" error. One limitation is that casts are only supported for simple types. Parametric types are not yet supported, but it wouldn't be too hard to add. Types whose name don't match the syntax of SQL function calls are not supported either and will be harder to support. That will require introducing type-producing expressions into the language and making types first-class expressions.
1 parent 2ee916f commit 9c0147b

File tree

10 files changed

+362
-50
lines changed

10 files changed

+362
-50
lines changed

core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4

Lines changed: 47 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -586,66 +586,70 @@ valueExpression
586586
;
587587

588588
primaryExpression
589-
: literal #literals
590-
| QUESTION_MARK #parameter
591-
| POSITION '(' valueExpression IN valueExpression ')' #position
592-
| '(' expression (',' expression)+ ')' #rowConstructor
593-
| ROW '(' expression (',' expression)* ')' #rowConstructor
589+
: literal #literals
590+
| QUESTION_MARK #parameter
591+
| POSITION '(' valueExpression IN valueExpression ')' #position
592+
| '(' expression (',' expression)+ ')' #rowConstructor
593+
| ROW '(' expression (',' expression)* ')' #rowConstructor
594594
| name=LISTAGG '(' setQuantifier? expression (',' string)?
595595
(ON OVERFLOW listAggOverflowBehavior)? ')'
596596
(WITHIN GROUP '(' orderBy ')')
597-
filter? over? #listagg
597+
filter? over? #listagg
598598
| processingMode? qualifiedName '(' (label=identifier '.')? ASTERISK ')'
599-
filter? over? #functionCall
599+
filter? over? #functionCall
600600
| processingMode? qualifiedName '(' (setQuantifier? expression (',' expression)*)?
601-
orderBy? ')' filter? (nullTreatment? over)? #functionCall
602-
| identifier over #measure
603-
| identifier '->' expression #lambda
604-
| '(' (identifier (',' identifier)*)? ')' '->' expression #lambda
605-
| '(' query ')' #subqueryExpression
601+
orderBy? ')' filter? (nullTreatment? over)? #functionCall
602+
| identifier over #measure
603+
| identifier '->' expression #lambda
604+
| '(' (identifier (',' identifier)*)? ')' '->' expression #lambda
605+
| '(' query ')' #subqueryExpression
606606
// This is an extension to ANSI SQL, which considers EXISTS to be a <boolean expression>
607-
| EXISTS '(' query ')' #exists
608-
| CASE operand=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase
609-
| CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
610-
| CAST '(' expression AS type ')' #cast
611-
| TRY_CAST '(' expression AS type ')' #cast
612-
| ARRAY '[' (expression (',' expression)*)? ']' #arrayConstructor
613-
| '[' (expression (',' expression)*)? ']' #arrayConstructor
614-
| value=primaryExpression '[' index=valueExpression ']' #subscript
615-
| identifier #columnReference
616-
| base=primaryExpression '.' fieldName=identifier #dereference
617-
| name=CURRENT_DATE #currentDate
618-
| name=CURRENT_TIME ('(' precision=INTEGER_VALUE ')')? #currentTime
619-
| name=CURRENT_TIMESTAMP ('(' precision=INTEGER_VALUE ')')? #currentTimestamp
620-
| name=LOCALTIME ('(' precision=INTEGER_VALUE ')')? #localTime
621-
| name=LOCALTIMESTAMP ('(' precision=INTEGER_VALUE ')')? #localTimestamp
622-
| name=CURRENT_USER #currentUser
623-
| name=CURRENT_CATALOG #currentCatalog
624-
| name=CURRENT_SCHEMA #currentSchema
625-
| name=CURRENT_PATH #currentPath
626-
| TRIM '(' (trimsSpecification? trimChar=valueExpression? FROM)?
627-
trimSource=valueExpression ')' #trim
628-
| TRIM '(' trimSource=valueExpression ',' trimChar=valueExpression ')' #trim
629-
| SUBSTRING '(' valueExpression FROM valueExpression (FOR valueExpression)? ')' #substring
630-
| NORMALIZE '(' valueExpression (',' normalForm)? ')' #normalize
631-
| EXTRACT '(' identifier FROM valueExpression ')' #extract
632-
| '(' expression ')' #parenthesizedExpression
633-
| GROUPING '(' (qualifiedName (',' qualifiedName)*)? ')' #groupingOperation
634-
| JSON_EXISTS '(' jsonPathInvocation (jsonExistsErrorBehavior ON ERROR)? ')' #jsonExists
607+
| EXISTS '(' query ')' #exists
608+
| CASE operand=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase
609+
| CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
610+
| CAST '(' expression AS type ')' #cast
611+
| TRY_CAST '(' expression AS type ')' #cast
612+
// the target is a primaryExpression to support PostgreSQL-style casts
613+
// of the form <complex expression>::<type>, which are syntactically ambiguous with
614+
// static method calls defined by the SQL spec (and we reserve it for future use)
615+
| primaryExpression DOUBLE_COLON identifier ('(' (expression (',' expression)*)? ')')? #staticMethodCall
616+
| ARRAY '[' (expression (',' expression)*)? ']' #arrayConstructor
617+
| '[' (expression (',' expression)*)? ']' #arrayConstructor
618+
| value=primaryExpression '[' index=valueExpression ']' #subscript
619+
| identifier #columnReference
620+
| base=primaryExpression '.' fieldName=identifier #dereference
621+
| name=CURRENT_DATE #currentDate
622+
| name=CURRENT_TIME ('(' precision=INTEGER_VALUE ')')? #currentTime
623+
| name=CURRENT_TIMESTAMP ('(' precision=INTEGER_VALUE ')')? #currentTimestamp
624+
| name=LOCALTIME ('(' precision=INTEGER_VALUE ')')? #localTime
625+
| name=LOCALTIMESTAMP ('(' precision=INTEGER_VALUE ')')? #localTimestamp
626+
| name=CURRENT_USER #currentUser
627+
| name=CURRENT_CATALOG #currentCatalog
628+
| name=CURRENT_SCHEMA #currentSchema
629+
| name=CURRENT_PATH #currentPath
630+
| TRIM '(' (trimsSpecification? trimChar=valueExpression? FROM)?
631+
trimSource=valueExpression ')' #trim
632+
| TRIM '(' trimSource=valueExpression ',' trimChar=valueExpression ')' #trim
633+
| SUBSTRING '(' valueExpression FROM valueExpression (FOR valueExpression)? ')' #substring
634+
| NORMALIZE '(' valueExpression (',' normalForm)? ')' #normalize
635+
| EXTRACT '(' identifier FROM valueExpression ')' #extract
636+
| '(' expression ')' #parenthesizedExpression
637+
| GROUPING '(' (qualifiedName (',' qualifiedName)*)? ')' #groupingOperation
638+
| JSON_EXISTS '(' jsonPathInvocation (jsonExistsErrorBehavior ON ERROR)? ')' #jsonExists
635639
| JSON_VALUE '('
636640
jsonPathInvocation
637641
(RETURNING type)?
638642
(emptyBehavior=jsonValueBehavior ON EMPTY)?
639643
(errorBehavior=jsonValueBehavior ON ERROR)?
640-
')' #jsonValue
644+
')' #jsonValue
641645
| JSON_QUERY '('
642646
jsonPathInvocation
643647
(RETURNING type (FORMAT jsonRepresentation)?)?
644648
(jsonQueryWrapperBehavior WRAPPER)?
645649
((KEEP | OMIT) QUOTES (ON SCALAR TEXT_STRING)?)?
646650
(emptyBehavior=jsonQueryBehavior ON EMPTY)?
647651
(errorBehavior=jsonQueryBehavior ON ERROR)?
648-
')' #jsonQuery
652+
')' #jsonQuery
649653
| JSON_OBJECT '('
650654
(
651655
jsonObjectMember (',' jsonObjectMember)*
@@ -1123,6 +1127,7 @@ DISTINCT: 'DISTINCT';
11231127
DISTRIBUTED: 'DISTRIBUTED';
11241128
DO: 'DO';
11251129
DOUBLE: 'DOUBLE';
1130+
DOUBLE_COLON: '::';
11261131
DROP: 'DROP';
11271132
ELSE: 'ELSE';
11281133
EMPTY: 'EMPTY';

core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@
145145
import io.trino.sql.tree.SkipTo;
146146
import io.trino.sql.tree.SortItem;
147147
import io.trino.sql.tree.SortItem.Ordering;
148+
import io.trino.sql.tree.StaticMethodCall;
148149
import io.trino.sql.tree.StringLiteral;
149150
import io.trino.sql.tree.SubqueryExpression;
150151
import io.trino.sql.tree.SubscriptExpression;
@@ -1704,6 +1705,43 @@ private void analyzeFrameRangeOffset(Expression offsetValue, FrameBound.Type bou
17041705
frameBoundCalculations.put(NodeRef.of(offsetValue), function);
17051706
}
17061707

1708+
@Override
1709+
protected Type visitStaticMethodCall(StaticMethodCall node, Context context)
1710+
{
1711+
// PostgreSQL-style casts are syntactically ambiguous with static method calls. So, static method call semantics take precendence.
1712+
// A static method call is characterized by the target being an expression whose type is "type". This not yet supported
1713+
// as a first-class concept, so we fake it by analyzing the expression normally. If the analysis succeeds, we treat it as
1714+
// the target of a cast.
1715+
1716+
// Trino allows resolving column names that match type names, so we need to check explicitly
1717+
// if this is a type reference in the context of a static method call
1718+
if (node.getTarget() instanceof Identifier target) {
1719+
try {
1720+
plannerContext.getTypeManager().fromSqlType(target.getValue());
1721+
throw semanticException(NOT_SUPPORTED, node, "Static method calls are not supported");
1722+
}
1723+
catch (TypeNotFoundException typeException) {
1724+
// since the type is not found, this must be a normal value-producing expression. Treat it as a candidate for
1725+
// resolving the PostgreSQL-style cast, as explained above.
1726+
}
1727+
}
1728+
1729+
if (!node.getArguments().isEmpty()) {
1730+
throw semanticException(NOT_SUPPORTED, node, "Static method calls are not supported");
1731+
}
1732+
1733+
process(node.getTarget(), context);
1734+
1735+
// assume it's a PostgreSQL-style cast unless result type is not a known type
1736+
try {
1737+
Type type = plannerContext.getTypeManager().fromSqlType(node.getMethod().getValue());
1738+
return setExpressionType(node, type);
1739+
}
1740+
catch (Exception e) {
1741+
throw semanticException(NOT_SUPPORTED, node, "Static method calls are not supported");
1742+
}
1743+
}
1744+
17071745
@Override
17081746
protected Type visitWindowOperation(WindowOperation node, Context context)
17091747
{

core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@
109109
import io.trino.sql.tree.Row;
110110
import io.trino.sql.tree.SearchedCaseExpression;
111111
import io.trino.sql.tree.SimpleCaseExpression;
112+
import io.trino.sql.tree.StaticMethodCall;
112113
import io.trino.sql.tree.StringLiteral;
113114
import io.trino.sql.tree.SubscriptExpression;
114115
import io.trino.sql.tree.Trim;
@@ -316,6 +317,7 @@ private io.trino.sql.ir.Expression translate(Expression expr, boolean isRoot)
316317
case io.trino.sql.tree.FieldReference expression -> translate(expression);
317318
case Identifier expression -> translate(expression);
318319
case FunctionCall expression -> translate(expression);
320+
case StaticMethodCall expression -> translate(expression);
319321
case DereferenceExpression expression -> translate(expression);
320322
case Array expression -> translate(expression);
321323
case CurrentCatalog expression -> translate(expression);
@@ -663,6 +665,14 @@ private io.trino.sql.ir.Expression translate(FunctionCall expression)
663665
.collect(toImmutableList()));
664666
}
665667

668+
private io.trino.sql.ir.Expression translate(StaticMethodCall expression)
669+
{
670+
// Currently, only PostgreSQL-style cast shorthand expressions are supported
671+
return new io.trino.sql.ir.Cast(
672+
translateExpression(expression.getTarget()),
673+
analysis.getType(expression));
674+
}
675+
666676
private io.trino.sql.ir.Expression translate(DereferenceExpression expression)
667677
{
668678
if (analysis.isColumnReference(expression)) {
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.operator.scalar;
15+
16+
import io.trino.spi.type.DoubleType;
17+
import io.trino.spi.type.VarcharType;
18+
import io.trino.sql.query.QueryAssertions;
19+
import org.junit.jupiter.api.AfterAll;
20+
import org.junit.jupiter.api.BeforeAll;
21+
import org.junit.jupiter.api.Test;
22+
import org.junit.jupiter.api.TestInstance;
23+
import org.junit.jupiter.api.parallel.Execution;
24+
25+
import static org.assertj.core.api.Assertions.assertThat;
26+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
27+
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;
28+
import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT;
29+
30+
@TestInstance(PER_CLASS)
31+
@Execution(CONCURRENT)
32+
public class TestStaticMethodCall
33+
{
34+
private QueryAssertions assertions;
35+
36+
@BeforeAll
37+
public void init()
38+
{
39+
assertions = new QueryAssertions();
40+
}
41+
42+
@AfterAll
43+
public void teardown()
44+
{
45+
assertions.close();
46+
assertions = null;
47+
}
48+
49+
@Test
50+
void testPostgreSqlStyleCast()
51+
{
52+
assertThat(assertions.expression("1::double"))
53+
.hasType(DoubleType.DOUBLE)
54+
.isEqualTo(1.0);
55+
56+
assertThat(assertions.expression("1::varchar"))
57+
.hasType(VarcharType.VARCHAR)
58+
.isEqualTo("1");
59+
60+
assertThatThrownBy(() -> assertions.expression("1::varchar(100)").evaluate())
61+
.hasMessage("line 1:13: Static method calls are not supported");
62+
63+
assertThat(assertions.expression("(a + b)::double")
64+
.binding("a", "1")
65+
.binding("b", "2"))
66+
.hasType(DoubleType.DOUBLE)
67+
.isEqualTo(3.0);
68+
69+
assertThatThrownBy(() -> assertions.expression("1::decimal(3, 2)").evaluate())
70+
.hasMessage("line 1:13: Static method calls are not supported");
71+
}
72+
73+
@Test
74+
void testCall()
75+
{
76+
assertThatThrownBy(() -> assertions.expression("1::double(2)").evaluate())
77+
.hasMessage("line 1:13: Static method calls are not supported");
78+
79+
assertThatThrownBy(() -> assertions.expression("1::foo").evaluate())
80+
.hasMessage("line 1:13: Static method calls are not supported");
81+
82+
assertThatThrownBy(() -> assertions.expression("integer::foo").evaluate())
83+
.hasMessage("line 1:19: Static method calls are not supported");
84+
85+
assertThatThrownBy(() -> assertions.expression("integer::foo(1, 2)").evaluate())
86+
.hasMessage("line 1:19: Static method calls are not supported");
87+
88+
assertThat(assertions.query("SELECT bigint::real FROM (VALUES 1) AS t(bigint)"))
89+
.failure()
90+
.hasMessage("line 1:14: Static method calls are not supported");
91+
}
92+
}

core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
import io.trino.sql.tree.SimpleGroupBy;
9292
import io.trino.sql.tree.SkipTo;
9393
import io.trino.sql.tree.SortItem;
94+
import io.trino.sql.tree.StaticMethodCall;
9495
import io.trino.sql.tree.StringLiteral;
9596
import io.trino.sql.tree.SubqueryExpression;
9697
import io.trino.sql.tree.SubscriptExpression;
@@ -467,6 +468,24 @@ protected String visitFunctionCall(FunctionCall node, Void context)
467468
return builder.toString();
468469
}
469470

471+
@Override
472+
protected String visitStaticMethodCall(StaticMethodCall node, Void context)
473+
{
474+
StringBuilder builder = new StringBuilder();
475+
476+
builder.append(process(node.getTarget(), context))
477+
.append("::")
478+
.append(process(node.getMethod(), context));
479+
480+
if (!node.getArguments().isEmpty()) {
481+
builder.append('(')
482+
.append(joinExpressions(node.getArguments()))
483+
.append(')');
484+
}
485+
486+
return builder.toString();
487+
}
488+
470489
@Override
471490
protected String visitWindowOperation(WindowOperation node, Void context)
472491
{

core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@
278278
import io.trino.sql.tree.SortItem;
279279
import io.trino.sql.tree.StartTransaction;
280280
import io.trino.sql.tree.Statement;
281+
import io.trino.sql.tree.StaticMethodCall;
281282
import io.trino.sql.tree.StringLiteral;
282283
import io.trino.sql.tree.SubqueryExpression;
283284
import io.trino.sql.tree.SubscriptExpression;
@@ -3104,6 +3105,16 @@ else if (processingMode.FINAL() != null) {
31043105
arguments);
31053106
}
31063107

3108+
@Override
3109+
public Node visitStaticMethodCall(SqlBaseParser.StaticMethodCallContext context)
3110+
{
3111+
return new StaticMethodCall(
3112+
getLocation(context.DOUBLE_COLON()),
3113+
(Expression) visit(context.primaryExpression()),
3114+
(Identifier) visit(context.identifier()),
3115+
visit(context.expression(), Expression.class));
3116+
}
3117+
31073118
@Override
31083119
public Node visitMeasure(SqlBaseParser.MeasureContext context)
31093120
{

core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,11 @@ protected R visitFunctionCall(FunctionCall node, C context)
322322
return visitExpression(node, context);
323323
}
324324

325+
protected R visitStaticMethodCall(StaticMethodCall node, C context)
326+
{
327+
return visitExpression(node, context);
328+
}
329+
325330
protected R visitProcessingMode(ProcessingMode node, C context)
326331
{
327332
return visitNode(node, context);

0 commit comments

Comments
 (0)