diff --git a/server/ast/expr.go b/server/ast/expr.go index 2899eccfaf..f4d8075877 100644 --- a/server/ast/expr.go +++ b/server/ast/expr.go @@ -607,7 +607,26 @@ func nodeExpr(ctx *Context, node tree.Expr) (vitess.Expr, error) { // TODO: figure out if I can delete this return nil, errors.Errorf("this should probably be deleted (internal error, IndexedVar)") case *tree.IndirectionExpr: - return nil, errors.Errorf("subscripts are not yet supported") + childExpr, err := nodeExpr(ctx, node.Expr) + if err != nil { + return nil, err + } + + if len(node.Indirection) > 1 { + return nil, errors.Errorf("multi dimensional array subscripts are not yet supported") + } else if node.Indirection[0].Slice { + return nil, errors.Errorf("slice subscripts are not yet supported") + } + + indexExpr, err := nodeExpr(ctx, node.Indirection[0].Begin) + if err != nil { + return nil, err + } + + return vitess.InjectedExpr{ + Expression: &pgexprs.Subscript{}, + Children: vitess.Exprs{childExpr, indexExpr}, + }, nil case *tree.IsNotNullExpr: expr, err := nodeExpr(ctx, node.Expr) if err != nil { diff --git a/server/expression/subscript.go b/server/expression/subscript.go new file mode 100755 index 0000000000..08961018d8 --- /dev/null +++ b/server/expression/subscript.go @@ -0,0 +1,134 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "fmt" + + "github.com/dolthub/go-mysql-server/sql" + vitess "github.com/dolthub/vitess/go/vt/sqlparser" + + "github.com/dolthub/doltgresql/server/types" +) + +// Subscript represents a subscript expression, e.g. `a[1]`. +type Subscript struct { + Child sql.Expression + Index sql.Expression +} + +var _ vitess.Injectable = (*Subscript)(nil) +var _ sql.Expression = (*Subscript)(nil) + +// NewSubscript creates a new Subscript expression. +func NewSubscript(child, index sql.Expression) *Subscript { + return &Subscript{ + Child: child, + Index: index, + } +} + +// Resolved implements the sql.Expression interface. +func (s Subscript) Resolved() bool { + return s.Child.Resolved() && s.Index.Resolved() +} + +// String implements the sql.Expression interface. +func (s Subscript) String() string { + return fmt.Sprintf("%s[%s]", s.Child, s.Index) +} + +// Type implements the sql.Expression interface. +func (s Subscript) Type() sql.Type { + dt, ok := s.Child.Type().(*types.DoltgresType) + if !ok { + panic(fmt.Sprintf("unexpected type %T for subscript", s.Child.Type())) + } + return dt.ArrayBaseType() +} + +// IsNullable implements the sql.Expression interface. +func (s Subscript) IsNullable() bool { + return true +} + +// Eval implements the sql.Expression interface. +func (s Subscript) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + childVal, err := s.Child.Eval(ctx, row) + if err != nil { + return nil, err + } + if childVal == nil { + return nil, nil + } + + indexVal, err := s.Index.Eval(ctx, row) + if err != nil { + return nil, err + } + if indexVal == nil { + return nil, nil + } + + switch child := childVal.(type) { + case []interface{}: + index, ok := indexVal.(int32) + if !ok { + converted, _, err := types.Int32.Convert(ctx, indexVal) + if err != nil { + return nil, err + } + index = converted.(int32) + } + + // subscripts are 1-based + if index < 1 || int(index) > len(child) { + return nil, nil + } + return child[index-1], nil + default: + return nil, fmt.Errorf("unsupported type %T for subscript", child) + } +} + +// Children implements the sql.Expression interface. +func (s Subscript) Children() []sql.Expression { + return []sql.Expression{s.Child, s.Index} +} + +// WithChildren implements the sql.Expression interface. +func (s Subscript) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 2 { + return nil, fmt.Errorf("expected 2 children, got %d", len(children)) + } + return NewSubscript(children[0], children[1]), nil +} + +// WithResolvedChildren implements the vitess.Injectable interface. +func (s Subscript) WithResolvedChildren(children []any) (any, error) { + if len(children) != 2 { + return nil, fmt.Errorf("expected 2 children, got %d", len(children)) + } + child, ok := children[0].(sql.Expression) + if !ok { + return nil, fmt.Errorf("expected child to be an expression but has type `%T`", children[0]) + } + index, ok := children[1].(sql.Expression) + if !ok { + return nil, fmt.Errorf("expected index to be an expression but has type `%T`", children[1]) + } + + return NewSubscript(child, index), nil +} diff --git a/testing/go/expressions_test.go b/testing/go/expressions_test.go index 894a08d59e..2e2f50add5 100644 --- a/testing/go/expressions_test.go +++ b/testing/go/expressions_test.go @@ -21,7 +21,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" ) -func TestExpressions(t *testing.T) { +func TestIn(t *testing.T) { RunScriptsWithoutNormalization(t, []ScriptTest{ anyTests("ANY"), anyTests("SOME"), @@ -292,3 +292,78 @@ func TestBinaryLogic(t *testing.T) { }, }) } + +// Note that our parser is more forgiving of array subscripts than the actual Postgres parser. +// We can handle this: SELECT ARRAY[1, 2, 3][1] +// But postgres requires: SELECT (ARRAY[1, 2, 3])[1] +func TestSubscript(t *testing.T) { + RunScripts(t, []ScriptTest{ + { + Name: "array literal", + Assertions: []ScriptTestAssertion{ + { + Query: `SELECT ARRAY[1, 2, 3][1];`, + Expected: []sql.Row{{1}}, + }, + { + Query: `SELECT (ARRAY[1, 2, 3])[3];`, + Expected: []sql.Row{{3}}, + }, + { + Query: `SELECT (ARRAY[1, 2, 3])[1+1];`, + Expected: []sql.Row{{2}}, + }, + { + Query: `SELECT ARRAY[1, 2, 3][0];`, + Expected: []sql.Row{{nil}}, + }, + { + Query: `SELECT ARRAY[1, 2, 3][4];`, + Expected: []sql.Row{{nil}}, + }, + { + Query: `SELECT ARRAY[1, 2, 3][null];`, + Expected: []sql.Row{{nil}}, + }, + { + Query: `SELECT ARRAY['a', 'b', 'c'][2];`, + Expected: []sql.Row{{"b"}}, + }, + { + Query: `SELECT ARRAY[1, 2, 3][1:3];`, + ExpectedErr: "not yet supported", + }, + { + Query: `SELECT ARRAY[1, 2, 3]['abc'];`, + ExpectedErr: "integer: unhandled type: string", + }, + }, + }, + { + Name: "array column", + SetUpScript: []string{ + `CREATE TABLE test (id INT, arr INT[]);`, + `INSERT INTO test VALUES (1, ARRAY[1, 2, 3]), (2, ARRAY[4, 5, 6]);`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: `SELECT arr[2] FROM test order by 1;`, + Expected: []sql.Row{{2}, {5}}, + }, + }, + }, + { + Name: "array subquery", + SetUpScript: []string{ + "CREATE TABLE test (id INT);", + "INSERT INTO test VALUES (1), (2), (3);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: `SELECT (array(select id from test order by 1))[2]`, + Expected: []sql.Row{{2}}, + }, + }, + }, + }) +} diff --git a/testing/go/functions_test.go b/testing/go/functions_test.go index eef6284020..70aa685919 100644 --- a/testing/go/functions_test.go +++ b/testing/go/functions_test.go @@ -2310,5 +2310,39 @@ func TestSelectFromFunctions(t *testing.T) { }, }, }, + { + Name: "test select from dolt_ functions", + Skip: true, // need a way for single-row functions to declare a schema like table functions do, maybe just by modeling them as table functions in the first place + SetUpScript: []string{ + "CREATE TABLE test (pk INT primary key, v1 INT, v2 TEXT);", + "INSERT INTO test VALUES (1, 1, 'a'), (2, 2, 'b'), (3, 3, 'c'), (4, 4, 'd'), (5, 5, 'e');", + "call dolt_commit('-Am', 'first table');", + }, + Assertions: []ScriptTestAssertion{ + { + Query: `select * from dolt_branch('newBranch')`, + Expected: []sql.Row{{0}}, + }, + { + Query: `select status from dolt_checkout('newBranch')`, + Expected: []sql.Row{{0}}, + }, + { + Query: `insert into test values (6, 6, 'f')`, + }, + { + Query: `select length(commit_hash) > 0 from (select commit_hash from dolt_commit('-Am', 'added f') as result)`, + Expected: []sql.Row{{"t"}}, + }, + { + Query: "select dolt_checkout('main')", + Expected: []sql.Row{{0}}, + }, + { + Query: `select fast_forward, conflicts from dolt_merge('newBranch')`, + Expected: []sql.Row{{"t", 0}}, + }, + }, + }, }) }