diff --git a/go.work.sum b/go.work.sum index 2b35677053d0..a45710d76d46 100644 --- a/go.work.sum +++ b/go.work.sum @@ -509,6 +509,7 @@ golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs= golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/mod v0.28.0 h1:gQBtGhjxykdjY9YhZpSlZIsbnaE2+PgjfLWUQTnoZ1U= golang.org/x/mod v0.28.0/go.mod h1:yfB/L0NOf/kmEbXjzCPOx1iK1fRutOydrCMsqRhEBxI= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -551,6 +552,7 @@ golang.org/x/tools v0.10.0 h1:tvDr/iQoUqNdohiYm0LmmKcBk+q86lb9EprIUFhHHGg= golang.org/x/tools v0.16.0/go.mod h1:kYVVN6I1mBNoB1OX+noeBjbRk4IUEPa7JJ+TJMEooJ0= golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0= +golang.org/x/tools v0.37.0 h1:DVSRzp7FwePZW356yEAChSdNcQo6Nsp+fex1SUW09lE= golang.org/x/tools v0.37.0/go.mod h1:MBN5QPQtLMHVdvsbtarmTNukZDdgwdwlO5qGacAzF0w= gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU= gonum.org/v1/gonum v0.15.1/go.mod h1:eZTZuRFrzu5pcyjN5wJhcIhnUdNijYxX1T2IcrOGY0o= diff --git a/spanner/spansql/parser.go b/spanner/spansql/parser.go index 8a16a26e0d1b..416ba81ac75b 100644 --- a/spanner/spansql/parser.go +++ b/spanner/spansql/parser.go @@ -90,7 +90,7 @@ func ParseDML(filename, s string) (*DML, error) { return dml, nil } -func parseStatements(stmts statements, filename string, s string) error { +func parseStatements(stmts statements, filename, s string) error { p := newParser(filename, s) stmts.setFilename(filename) @@ -4655,8 +4655,26 @@ func (p *parser) parseLit() (Expr, *parseError) { return BytesLiteral(tok.string), nil } - // Handle parenthesized expressions. + // Handle parenthesized expressions and scalar subqueries. + // NOTE: The opening "(" has already been consumed by p.next() above (line 4638). + // The parser is now positioned right after the "(", ready to parse the contents. if tok.value == "(" { + // Look ahead to see if this is a subquery like: (SELECT ...) + // p.sniff("SELECT") peeks at the next token without consuming it. + if p.sniff("SELECT") { + // Parse the subquery starting from the current position (after the "(") + q, err := p.parseQuery() + if err != nil { + return nil, err + } + if err := p.expect(")"); err != nil { + return nil, err + } + return ScalarSubquery{Query: q}, nil + } + + // Regular parenthesized expression like: (1 + 2) + // Parse the inner expression starting from the current position (after the "(") e, err := p.parseExpr() if err != nil { return nil, err diff --git a/spanner/spansql/parser_test.go b/spanner/spansql/parser_test.go index 27cac419466d..87dd311fdbdb 100644 --- a/spanner/spansql/parser_test.go +++ b/spanner/spansql/parser_test.go @@ -365,6 +365,65 @@ func TestParseQuery(t *testing.T) { }, }, }, + // Scalar subqueries in SELECT list - ensures parseLit handles (SELECT ...) correctly + { + `SELECT (SELECT MAX(id) FROM t1), name FROM t2`, + Query{ + Select: Select{ + List: []Expr{ + ScalarSubquery{ + Query: Query{ + Select: Select{ + List: []Expr{Func{Name: "MAX", Args: []Expr{ID("id")}}}, + From: []SelectFrom{SelectFromTable{Table: "t1"}}, + }, + }, + }, + ID("name"), + }, + From: []SelectFrom{SelectFromTable{Table: "t2"}}, + }, + }, + }, + // Scalar subquery in WHERE clause + { + `SELECT * FROM users WHERE age > (SELECT AVG(age) FROM users)`, + Query{ + Select: Select{ + List: []Expr{Star}, + From: []SelectFrom{SelectFromTable{Table: "users"}}, + Where: ComparisonOp{ + Op: Gt, + LHS: ID("age"), + RHS: ScalarSubquery{ + Query: Query{ + Select: Select{ + List: []Expr{Func{Name: "AVG", Args: []Expr{ID("age")}}}, + From: []SelectFrom{SelectFromTable{Table: "users"}}, + }, + }, + }, + }, + }, + }, + }, + // Parenthesized expression in SELECT list - ensures parseLit handles (expr) correctly + { + `SELECT (1 + 2) * 3, name FROM t`, + Query{ + Select: Select{ + List: []Expr{ + ArithOp{ + LHS: Paren{Expr: ArithOp{LHS: IntegerLiteral(1), Op: Add, RHS: IntegerLiteral(2)}}, + Op: Mul, + RHS: IntegerLiteral(3), + }, + ID("name"), + }, + From: []SelectFrom{SelectFromTable{Table: "t"}}, + }, + }, + }, } for _, test := range tests { got, err := ParseQuery(test.in) @@ -634,6 +693,53 @@ func TestParseExpr(t *testing.T) { // Reserved keywords. {`TRUE AND FALSE`, LogicalOp{LHS: True, Op: And, RHS: False}}, {`NULL`, Null}, + + // Parenthesized expressions - test that parseLit correctly handles the opening paren. + // The opening "(" is consumed by p.next(), then parseExpr is called for the contents. + {`(1)`, Paren{Expr: IntegerLiteral(1)}}, + {`(1 + 2)`, Paren{Expr: ArithOp{LHS: IntegerLiteral(1), Op: Add, RHS: IntegerLiteral(2)}}}, + {`((1 + 2))`, Paren{Expr: Paren{Expr: ArithOp{LHS: IntegerLiteral(1), Op: Add, RHS: IntegerLiteral(2)}}}}, + {`(1 + 2) * 3`, ArithOp{LHS: Paren{Expr: ArithOp{LHS: IntegerLiteral(1), Op: Add, RHS: IntegerLiteral(2)}}, Op: Mul, RHS: IntegerLiteral(3)}}, + {`((1 + 2) * 3)`, Paren{Expr: ArithOp{LHS: Paren{Expr: ArithOp{LHS: IntegerLiteral(1), Op: Add, RHS: IntegerLiteral(2)}}, Op: Mul, RHS: IntegerLiteral(3)}}}, + {`(A AND B) OR C`, LogicalOp{LHS: Paren{Expr: LogicalOp{LHS: ID("A"), Op: And, RHS: ID("B")}}, Op: Or, RHS: ID("C")}}, + {`(TRUE)`, Paren{Expr: True}}, + {`(NULL)`, Paren{Expr: Null}}, + + // Scalar subqueries - test that parseLit correctly distinguishes (SELECT ...) from (expr). + // When p.sniff("SELECT") is true after consuming "(", parseQuery is called instead of parseExpr. + { + `(SELECT 1)`, + ScalarSubquery{ + Query: Query{ + Select: Select{ + List: []Expr{IntegerLiteral(1)}, + }, + }, + }, + }, + { + `(SELECT MAX(x) FROM t)`, + ScalarSubquery{ + Query: Query{ + Select: Select{ + List: []Expr{Func{Name: "MAX", Args: []Expr{ID("x")}}}, + From: []SelectFrom{SelectFromTable{Table: "t"}}, + }, + }, + }, + }, + { + `(SELECT COUNT(*) FROM users WHERE active = TRUE)`, + ScalarSubquery{ + Query: Query{ + Select: Select{ + List: []Expr{Func{Name: "COUNT", Args: []Expr{Star}}}, + From: []SelectFrom{SelectFromTable{Table: "users"}}, + Where: ComparisonOp{LHS: ID("active"), Op: Eq, RHS: True}, + }, + }, + }, + }, } for _, test := range tests { p := newParser("test-file", test.in) @@ -2415,6 +2521,95 @@ func TestParseDDL(t *testing.T) { }, }, }, + { + `CREATE OR REPLACE VIEW Transaction SQL SECURITY INVOKER AS SELECT + ID, Name, Amount, AccountID, PaymentID + FROM + Transactions as t + JOIN Accounts as acc ON t.AccountID = acc.ID + JOIN Payment as p ON t.PaymentID = p.ID + AND p.EventSequence = ( + SELECT MAX(p2.EventSequence) + FROM Payment as p2 + WHERE p2.ID = p.ID + )`, + &DDL{ + Filename: "filename", + List: []DDLStmt{ + &CreateView{ + Name: "Transaction", + OrReplace: true, + SecurityType: Invoker, + Query: Query{ + Select: Select{ + List: []Expr{ID("ID"), ID("Name"), ID("Amount"), ID("AccountID"), ID("PaymentID")}, + From: []SelectFrom{ + SelectFromJoin{ + Type: InnerJoin, + LHS: SelectFromJoin{ + Type: InnerJoin, + LHS: SelectFromTable{ + Table: "Transactions", + Alias: "t", + }, + RHS: SelectFromTable{ + Table: "Accounts", + Alias: "acc", + }, + On: ComparisonOp{ + LHS: PathExp{"t", "AccountID"}, + Op: Eq, + RHS: PathExp{"acc", "ID"}, + }, + }, + RHS: SelectFromTable{ + Table: "Payment", + Alias: "p", + }, + On: LogicalOp{ + LHS: ComparisonOp{ + LHS: PathExp{"t", "PaymentID"}, + Op: Eq, + RHS: PathExp{"p", "ID"}, + }, + Op: And, + RHS: ComparisonOp{ + LHS: PathExp{"p", "EventSequence"}, + Op: Eq, + RHS: ScalarSubquery{ + Query: Query{ + Select: Select{ + List: []Expr{ + Func{ + Name: "MAX", + Args: []Expr{PathExp{"p2", "EventSequence"}}, + }, + }, + From: []SelectFrom{ + SelectFromTable{ + Table: "Payment", + Alias: "p2", + }, + }, + Where: ComparisonOp{ + LHS: PathExp{"p2", "ID"}, + Op: Eq, + RHS: PathExp{"p", "ID"}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + Position: line(1), + }, + }, + }, + }, } for _, test := range tests { got, err := ParseDDL("filename", test.in) diff --git a/spanner/spansql/sql.go b/spanner/spansql/sql.go index 79198d54dda9..bc528d42e6dd 100644 --- a/spanner/spansql/sql.go +++ b/spanner/spansql/sql.go @@ -1084,6 +1084,13 @@ func (eo ExistsOp) addSQL(sb *strings.Builder) { sb.WriteString(")") } +func (ss ScalarSubquery) SQL() string { return buildSQL(ss) } +func (ss ScalarSubquery) addSQL(sb *strings.Builder) { + sb.WriteString("(") + ss.Query.addSQL(sb) + sb.WriteString(")") +} + func (io InOp) SQL() string { return buildSQL(io) } func (io InOp) addSQL(sb *strings.Builder) { io.LHS.addSQL(sb) diff --git a/spanner/spansql/types.go b/spanner/spansql/types.go index e66d3c110709..0f939712145f 100644 --- a/spanner/spansql/types.go +++ b/spanner/spansql/types.go @@ -815,6 +815,12 @@ type ExistsOp struct { func (ExistsOp) isBoolExpr() {} // usually func (ExistsOp) isExpr() {} +type ScalarSubquery struct { + Query Query +} + +func (ScalarSubquery) isExpr() {} + type InOp struct { LHS Expr Neg bool