diff --git a/go.mod b/go.mod index b491bbc0b1..67e360dc57 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20241119094239-f4e529af734d github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 github.com/dolthub/go-icu-regex v0.0.0-20250303123116-549b8d7cad00 - github.com/dolthub/go-mysql-server v0.19.1-0.20250317234108-fa97159ff8ce + github.com/dolthub/go-mysql-server v0.19.1-0.20250318170829-6ad8521aefac github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 github.com/dolthub/vitess v0.0.0-20250304211657-920ca9ec2b9a github.com/fatih/color v1.13.0 diff --git a/go.sum b/go.sum index 9d2db11df4..70cd23c1a4 100644 --- a/go.sum +++ b/go.sum @@ -266,8 +266,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20250303123116-549b8d7cad00 h1:rh2ij2yTYKJWlX+c8XRg4H5OzqPewbU1lPK8pcfVmx8= github.com/dolthub/go-icu-regex v0.0.0-20250303123116-549b8d7cad00/go.mod h1:ylU4XjUpsMcvl/BKeRRMXSH7e7WBrPXdSLvnRJYrxEA= -github.com/dolthub/go-mysql-server v0.19.1-0.20250317234108-fa97159ff8ce h1:ICPJPB1YINTv9YX0DFSVG/qi5NajMWU5vK7pW8cd650= -github.com/dolthub/go-mysql-server v0.19.1-0.20250317234108-fa97159ff8ce/go.mod h1:yr+Vv47/YLOKMgiEY+QxHTlbIVpTuiVtkEZ5l+xruY4= +github.com/dolthub/go-mysql-server v0.19.1-0.20250318170829-6ad8521aefac h1:v9MYsGeu+aqq3aQR0MpY6zHHfwc9vkM7ysJRqmFcsgE= +github.com/dolthub/go-mysql-server v0.19.1-0.20250318170829-6ad8521aefac/go.mod h1:yr+Vv47/YLOKMgiEY+QxHTlbIVpTuiVtkEZ5l+xruY4= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE= diff --git a/server/ast/insert.go b/server/ast/insert.go index e05b995119..fa7a6f8e47 100644 --- a/server/ast/insert.go +++ b/server/ast/insert.go @@ -24,15 +24,20 @@ import ( ) // nodeInsert handles *tree.Insert nodes. -func nodeInsert(ctx *Context, node *tree.Insert) (*vitess.Insert, error) { +func nodeInsert(ctx *Context, node *tree.Insert) (insert *vitess.Insert, err error) { if node == nil { return nil, nil } ctx.Auth().PushAuthType(auth.AuthType_INSERT) defer ctx.Auth().PopAuthType() - if _, ok := node.Returning.(*tree.NoReturningClause); !ok { - return nil, errors.Errorf("RETURNING is not yet supported") + var returningExprs vitess.SelectExprs + if returning, ok := node.Returning.(*tree.ReturningExprs); ok { + // TODO: PostgreSQL will apply any triggers before returning the value; need to test this. + returningExprs, err = nodeSelectExprs(ctx, tree.SelectExprs(*returning)) + if err != nil { + return nil, err + } } var ignore string var onDuplicate vitess.OnDup @@ -102,13 +107,14 @@ func nodeInsert(ctx *Context, node *tree.Insert) (*vitess.Insert, error) { } } return &vitess.Insert{ - Action: vitess.InsertStr, - Ignore: ignore, - Table: tableName, - With: with, - Columns: columns, - Rows: rows, - OnDup: onDuplicate, + Action: vitess.InsertStr, + Ignore: ignore, + Table: tableName, + Returning: returningExprs, + With: with, + Columns: columns, + Rows: rows, + OnDup: onDuplicate, Auth: vitess.AuthInformation{ AuthType: auth.AuthType_INSERT, TargetType: auth.AuthTargetType_TableIdentifiers, diff --git a/server/connection_handler.go b/server/connection_handler.go index ff6b844c47..bf17c90771 100644 --- a/server/connection_handler.go +++ b/server/connection_handler.go @@ -516,7 +516,7 @@ func (h *ConnectionHandler) handleParse(message *pgproto3.Parse) error { func (h *ConnectionHandler) handleDescribe(message *pgproto3.Describe) error { var fields []pgproto3.FieldDescription var bindvarTypes []uint32 - var tag string + var query ConvertedQuery h.waitForSync = true if message.ObjectType == 'S' { @@ -527,7 +527,7 @@ func (h *ConnectionHandler) handleDescribe(message *pgproto3.Describe) error { fields = preparedStatementData.ReturnFields bindvarTypes = preparedStatementData.BindVarTypes - tag = preparedStatementData.Query.StatementTag + query = preparedStatementData.Query } else { portalData, ok := h.portals[message.Name] if !ok { @@ -535,10 +535,10 @@ func (h *ConnectionHandler) handleDescribe(message *pgproto3.Describe) error { } fields = portalData.Fields - tag = portalData.Query.StatementTag + query = portalData.Query } - return h.sendDescribeResponse(fields, bindvarTypes, tag) + return h.sendDescribeResponse(fields, bindvarTypes, query) } // handleBind handles a bind message, returning any error that occurs @@ -615,7 +615,7 @@ func (h *ConnectionHandler) handleExecute(message *pgproto3.Execute) error { // |rowsAffected| gets altered by the callback below rowsAffected := int32(0) - callback := h.spoolRowsCallback(query.StatementTag, &rowsAffected, true) + callback := h.spoolRowsCallback(query, &rowsAffected, true) err = h.doltgresHandler.ComExecuteBound(context.Background(), h.mysqlConn, query.String, portalData.BoundPlan, callback) if err != nil { return err @@ -916,7 +916,7 @@ func (h *ConnectionHandler) query(query ConvertedQuery) error { // |rowsAffected| gets altered by the callback below rowsAffected := int32(0) - callback := h.spoolRowsCallback(query.StatementTag, &rowsAffected, false) + callback := h.spoolRowsCallback(query, &rowsAffected, false) err := h.doltgresHandler.ComQuery(context.Background(), h.mysqlConn, query.String, query.AST, callback) if err != nil { if strings.HasPrefix(err.Error(), "syntax error at position") { @@ -930,9 +930,9 @@ func (h *ConnectionHandler) query(query ConvertedQuery) error { // spoolRowsCallback returns a callback function that will send RowDescription message, // then a DataRow message for each row in the result set. -func (h *ConnectionHandler) spoolRowsCallback(tag string, rows *int32, isExecute bool) func(ctx *sql.Context, res *Result) error { +func (h *ConnectionHandler) spoolRowsCallback(query ConvertedQuery, rows *int32, isExecute bool) func(ctx *sql.Context, res *Result) error { // IsIUD returns whether the query is either an INSERT, UPDATE, or DELETE query. - isIUD := tag == "INSERT" || tag == "UPDATE" || tag == "DELETE" + isIUD := query.StatementTag == "INSERT" || query.StatementTag == "UPDATE" || query.StatementTag == "DELETE" return func(ctx *sql.Context, res *Result) error { sess := dsess.DSessFromSess(ctx.Session) for _, notice := range sess.Notices() { @@ -947,7 +947,7 @@ func (h *ConnectionHandler) spoolRowsCallback(tag string, rows *int32, isExecute } sess.ClearNotices() - if returnsRow(tag) { + if returnsRow(query) { // EXECUTE does not send RowDescription; instead it should be sent from DESCRIBE prior to it if !isExecute { if err := h.send(&pgproto3.RowDescription{ @@ -977,7 +977,7 @@ func (h *ConnectionHandler) spoolRowsCallback(tag string, rows *int32, isExecute } // sendDescribeResponse sends a response message for a Describe message -func (h *ConnectionHandler) sendDescribeResponse(fields []pgproto3.FieldDescription, types []uint32, tag string) error { +func (h *ConnectionHandler) sendDescribeResponse(fields []pgproto3.FieldDescription, types []uint32, query ConvertedQuery) error { // The prepared statement variant of the describe command returns the OIDs of the parameters. if types != nil { if err := h.send(&pgproto3.ParameterDescription{ @@ -987,7 +987,7 @@ func (h *ConnectionHandler) sendDescribeResponse(fields []pgproto3.FieldDescript } } - if returnsRow(tag) { + if returnsRow(query) { // Both variants finish with a row description. return h.send(&pgproto3.RowDescription{ Fields: fields, @@ -1182,10 +1182,24 @@ func (h *ConnectionHandler) send(message pgproto3.BackendMessage) error { } // returnsRow returns whether the query returns set of rows such as SELECT and FETCH statements. -func returnsRow(tag string) bool { - switch tag { +func returnsRow(query ConvertedQuery) bool { + switch query.StatementTag { case "SELECT", "SHOW", "FETCH", "EXPLAIN", "SHOW TABLES": return true + case "INSERT": + hasReturningClause := false + sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + switch node := node.(type) { + case *sqlparser.Insert: + if len(node.Returning) > 0 { + hasReturningClause = true + } + return false, nil + } + // this should be impossible, but just in case + return true, nil + }, query.AST) + return hasReturningClause default: return false } diff --git a/testing/generation/command_docs/output/framework_test.go b/testing/generation/command_docs/output/framework_test.go index ce55c18504..59d460d0a3 100644 --- a/testing/generation/command_docs/output/framework_test.go +++ b/testing/generation/command_docs/output/framework_test.go @@ -119,7 +119,7 @@ func RunTests(t *testing.T, tests []QueryParses) { }() if !test.ShouldConvert() { if err == nil && vitessAST != nil { - t.Fatal("Query now converts, please upgrade the type to `Converts`") + t.Fatalf("Query %s now converts, please upgrade the type to `Converts`", test.String()) } return } diff --git a/testing/generation/command_docs/output/insert_test.go b/testing/generation/command_docs/output/insert_test.go index 7e02351b92..1430eba30f 100644 --- a/testing/generation/command_docs/output/insert_test.go +++ b/testing/generation/command_docs/output/insert_test.go @@ -1424,7 +1424,7 @@ func TestInsert(t *testing.T) { Unimplemented("INSERT INTO table_name AS alias ( column_name , column_name ) VALUES ( expression ) , ( expression , expression ) ON CONFLICT ( index_column_name , index_column_name ) WHERE index_predicate DO UPDATE SET ( column_name ) = ROW ( DEFAULT , expression ) RETURNING colname"), Unimplemented("WITH RECURSIVE queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name ( column_name ) VALUES ( DEFAULT , DEFAULT ) , ( expression ) ON CONFLICT ( index_column_name opclass , ( index_expression ) COLLATE en_US opclass ) WHERE index_predicate DO UPDATE SET ( column_name , column_name ) = ROW ( DEFAULT , expression ) RETURNING colname"), Unimplemented("WITH queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name VALUES ( DEFAULT ) , ( DEFAULT ) ON CONFLICT ( ( index_expression ) COLLATE en_US , ( index_expression ) ) DO UPDATE SET ( column_name ) = ( expression , DEFAULT ) RETURNING colname"), - Parses("INSERT INTO table_name ( column_name , column_name ) VALUES ( DEFAULT , expression ) , ( expression , DEFAULT ) ON CONFLICT ( index_column_name , index_column_name ) DO UPDATE SET ( column_name ) = ( SELECT 1 ) RETURNING colname"), + Converts("INSERT INTO table_name ( column_name , column_name ) VALUES ( DEFAULT , expression ) , ( expression , DEFAULT ) ON CONFLICT ( index_column_name , index_column_name ) DO UPDATE SET ( column_name ) = ( SELECT 1 ) RETURNING colname"), Unimplemented("WITH queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name ( column_name ) VALUES ( DEFAULT , DEFAULT ) , ( DEFAULT , expression ) ON CONFLICT ( ( index_expression ) opclass , ( index_expression ) ) WHERE index_predicate DO UPDATE SET ( column_name ) = ( SELECT 1 ) RETURNING colname"), Unimplemented("WITH queryname AS ( select ) INSERT INTO table_name VALUES ( DEFAULT , DEFAULT ) , ( expression , expression ) ON CONFLICT ( index_column_name COLLATE en_US , ( index_expression ) COLLATE en_US opclass ) WHERE index_predicate DO UPDATE SET ( column_name ) = ( expression ) , column_name = expression RETURNING colname"), Unimplemented("WITH RECURSIVE queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name AS alias ( column_name ) VALUES ( expression , DEFAULT ) ON CONFLICT ( ( index_expression ) COLLATE en_US , ( index_expression ) opclass ) DO UPDATE SET ( column_name ) = ROW ( DEFAULT ) , column_name = expression RETURNING colname"), @@ -2817,7 +2817,7 @@ func TestInsert(t *testing.T) { Unimplemented("WITH RECURSIVE queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name AS alias ( column_name ) VALUES ( DEFAULT , expression ) , ( DEFAULT , DEFAULT ) ON CONFLICT ( index_column_name COLLATE en_US , ( index_expression ) COLLATE en_US ) WHERE index_predicate DO UPDATE SET ( column_name ) = ROW ( expression , DEFAULT ) , ( column_name , column_name ) = ROW ( DEFAULT ) RETURNING colname AS output_name"), Unimplemented("WITH RECURSIVE queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name ( column_name ) VALUES ( expression ) ON CONFLICT ( ( index_expression ) opclass , index_column_name opclass ) DO UPDATE SET ( column_name ) = ROW ( DEFAULT , DEFAULT ) , ( column_name , column_name ) = ROW ( DEFAULT ) RETURNING colname AS output_name"), Unimplemented("INSERT INTO table_name ( column_name , column_name ) VALUES ( expression , expression ) , ( DEFAULT , DEFAULT ) ON CONFLICT ( ( index_expression ) , index_column_name opclass ) WHERE index_predicate DO UPDATE SET ( column_name ) = ( SELECT 1 ) , ( column_name , column_name ) = ROW ( DEFAULT ) RETURNING colname AS output_name"), - Parses("INSERT INTO table_name ( column_name ) VALUES ( DEFAULT , expression ) , ( DEFAULT , DEFAULT ) ON CONFLICT ( index_column_name , index_column_name ) DO UPDATE SET ( column_name , column_name ) = ( expression ) , ( column_name ) = ( expression , expression ) RETURNING colname AS output_name"), + Converts("INSERT INTO table_name ( column_name ) VALUES ( DEFAULT , expression ) , ( DEFAULT , DEFAULT ) ON CONFLICT ( index_column_name , index_column_name ) DO UPDATE SET ( column_name , column_name ) = ( expression ) , ( column_name ) = ( expression , expression ) RETURNING colname AS output_name"), Unimplemented("WITH queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name ( column_name ) VALUES ( expression , expression ) ON CONFLICT ( index_column_name , index_column_name opclass ) WHERE index_predicate DO UPDATE SET ( column_name , column_name ) = ROW ( expression ) , ( column_name ) = ( expression , expression ) RETURNING colname AS output_name"), Unimplemented("WITH queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name AS alias ( column_name , column_name ) SELECT 1 ON CONFLICT ( index_column_name opclass , index_column_name COLLATE en_US ) DO UPDATE SET ( column_name ) = ( expression , expression ) , ( column_name ) = ( expression , expression ) RETURNING colname AS output_name"), Unimplemented("WITH RECURSIVE queryname AS ( select ) INSERT INTO table_name AS alias ( column_name , column_name ) VALUES ( expression ) , ( DEFAULT , DEFAULT ) ON CONFLICT ( ( index_expression ) opclass , ( index_expression ) COLLATE en_US ) DO UPDATE SET ( column_name , column_name ) = ( expression , expression ) , ( column_name ) = ( expression , expression ) RETURNING colname AS output_name"), @@ -5325,7 +5325,7 @@ func TestInsert(t *testing.T) { Unimplemented("INSERT INTO table_name ( column_name , column_name ) VALUES ( DEFAULT , expression ) , ( expression , expression ) ON CONFLICT ( index_column_name COLLATE en_US opclass , index_column_name opclass ) DO UPDATE SET ( column_name , column_name ) = ( expression , expression ) RETURNING colname , colname output_name"), Unimplemented("WITH queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name ( column_name ) VALUES ( expression ) , ( expression ) ON CONFLICT ( index_column_name opclass , index_column_name ) WHERE index_predicate DO UPDATE SET ( column_name , column_name ) = ROW ( expression , expression ) RETURNING colname , colname output_name"), Unimplemented("WITH queryname AS ( select ) INSERT INTO table_name AS alias ( column_name , column_name ) VALUES ( DEFAULT , DEFAULT ) , ( DEFAULT , DEFAULT ) ON CONFLICT ( ( index_expression ) opclass , ( index_expression ) opclass ) WHERE index_predicate DO UPDATE SET ( column_name , column_name ) = ROW ( expression , expression ) RETURNING colname , colname output_name"), - Parses("INSERT INTO table_name ( column_name , column_name ) VALUES ( expression , expression ) ON CONFLICT ( index_column_name , index_column_name ) DO UPDATE SET ( column_name ) = ( DEFAULT , expression ) RETURNING colname , colname output_name"), + Converts("INSERT INTO table_name ( column_name , column_name ) VALUES ( expression , expression ) ON CONFLICT ( index_column_name , index_column_name ) DO UPDATE SET ( column_name ) = ( DEFAULT , expression ) RETURNING colname , colname output_name"), Unimplemented("INSERT INTO table_name VALUES ( expression , DEFAULT ) , ( expression , DEFAULT ) ON CONFLICT ( index_column_name opclass , index_column_name COLLATE en_US ) DO UPDATE SET ( column_name ) = ROW ( DEFAULT , expression ) RETURNING colname , colname output_name"), Unimplemented("WITH queryname AS ( select ) INSERT INTO table_name VALUES ( DEFAULT ) , ( DEFAULT , DEFAULT ) ON CONFLICT ( index_column_name COLLATE en_US ) DO UPDATE SET ( column_name , column_name ) = ROW ( DEFAULT , expression ) RETURNING colname , colname output_name"), Unimplemented("WITH queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name AS alias ( column_name , column_name ) VALUES ( expression , DEFAULT ) , ( DEFAULT , expression ) ON CONFLICT ( index_column_name COLLATE en_US opclass ) WHERE index_predicate DO UPDATE SET ( column_name ) = ( expression , DEFAULT ) RETURNING colname , colname output_name"), @@ -6043,7 +6043,7 @@ func TestInsert(t *testing.T) { Unimplemented("WITH queryname AS ( select ) INSERT INTO table_name ( column_name , column_name ) VALUES ( DEFAULT , DEFAULT ) , ( DEFAULT ) ON CONFLICT ( ( index_expression ) COLLATE en_US , index_column_name COLLATE en_US opclass ) WHERE index_predicate DO UPDATE SET ( column_name , column_name ) = ROW ( expression , expression ) , ( column_name , column_name ) = ( expression ) RETURNING colname output_name , colname output_name"), Unimplemented("WITH queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name AS alias VALUES ( expression , expression ) , ( expression , expression ) ON CONFLICT ( index_column_name COLLATE en_US , index_column_name COLLATE en_US ) DO UPDATE SET ( column_name , column_name ) = ( DEFAULT , expression ) , ( column_name , column_name ) = ( expression ) RETURNING colname output_name , colname output_name"), Unimplemented("WITH RECURSIVE queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name VALUES ( DEFAULT ) , ( DEFAULT , DEFAULT ) ON CONFLICT ( index_column_name COLLATE en_US opclass , index_column_name COLLATE en_US opclass ) WHERE index_predicate DO UPDATE SET ( column_name , column_name ) = ( DEFAULT , expression ) , ( column_name , column_name ) = ( expression ) RETURNING colname output_name , colname output_name"), - Parses("INSERT INTO table_name VALUES ( expression , expression ) , ( expression , DEFAULT ) ON CONFLICT ( index_column_name ) DO UPDATE SET ( column_name , column_name ) = ( expression , DEFAULT ) , ( column_name , column_name ) = ( expression ) RETURNING colname output_name , colname output_name"), + Converts("INSERT INTO table_name VALUES ( expression , expression ) , ( expression , DEFAULT ) ON CONFLICT ( index_column_name ) DO UPDATE SET ( column_name , column_name ) = ( expression , DEFAULT ) , ( column_name , column_name ) = ( expression ) RETURNING colname output_name , colname output_name"), Unimplemented("WITH RECURSIVE queryname AS ( select ) INSERT INTO table_name ( column_name , column_name ) VALUES ( DEFAULT , DEFAULT ) ON CONFLICT ( ( index_expression ) opclass , index_column_name opclass ) WHERE index_predicate DO UPDATE SET column_name = DEFAULT , ( column_name ) = ROW ( expression ) RETURNING colname output_name , colname output_name"), Unimplemented("WITH queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name AS alias VALUES ( expression ) ON CONFLICT ( ( index_expression ) COLLATE en_US opclass , ( index_expression ) COLLATE en_US opclass ) WHERE index_predicate DO UPDATE SET column_name = DEFAULT , ( column_name ) = ROW ( expression ) RETURNING colname output_name , colname output_name"), Unimplemented("WITH queryname AS ( select ) INSERT INTO table_name ( column_name ) VALUES ( expression , expression ) , ( DEFAULT , DEFAULT ) ON CONFLICT ( ( index_expression ) COLLATE en_US opclass , index_column_name COLLATE en_US ) DO UPDATE SET ( column_name ) = ( expression ) , ( column_name ) = ROW ( expression ) RETURNING colname output_name , colname output_name"), diff --git a/testing/go/enginetest/doltgres_engine_test.go b/testing/go/enginetest/doltgres_engine_test.go index f3ea2c5d3d..f5ef408598 100755 --- a/testing/go/enginetest/doltgres_engine_test.go +++ b/testing/go/enginetest/doltgres_engine_test.go @@ -170,40 +170,7 @@ func (dcv *doltCommitValidator) CommitHash(val interface{}) (bool, string) { func TestSingleScript(t *testing.T) { t.Skip() - var scripts = []queries.ScriptTest{ - { - Name: "strings vs decimals with trailing 0s in IN exprs", - SetUpScript: []string{ - "create table t (v varchar(100));", - "insert into t values ('0'), ('0.0'), ('123'), ('123.0');", - "create table t_idx (v varchar(100));", - "create index idx on t_idx(v);", - "insert into t_idx values ('0'), ('0.0'), ('123'), ('123.0');", - }, - Assertions: []queries.ScriptTestAssertion{ - { - Skip: true, - Query: "select * from t where (v in (0.0, 123));", - Expected: []sql.Row{ - {"0"}, - {"0.0"}, - {"123"}, - {"123.0"}, - }, - }, - { - Skip: true, - Query: "select * from t_idx where (v in (0.0, 123));", - Expected: []sql.Row{ - {"0"}, - {"0.0"}, - {"123"}, - {"123.0"}, - }, - }, - }, - }, - } + var scripts = []queries.ScriptTest{} for _, script := range scripts { func() { @@ -215,8 +182,8 @@ func TestSingleScript(t *testing.T) { if err != nil { panic(err) } - // engine.EngineAnalyzer().Debug = true - // engine.EngineAnalyzer().Verbose = true + engine.EngineAnalyzer().Debug = true + engine.EngineAnalyzer().Verbose = true enginetest.TestScriptWithEngine(t, engine, harness, script) }() diff --git a/testing/go/enginetest/doltgres_harness_test.go b/testing/go/enginetest/doltgres_harness_test.go index a003916b5d..313eca00f3 100644 --- a/testing/go/enginetest/doltgres_harness_test.go +++ b/testing/go/enginetest/doltgres_harness_test.go @@ -701,6 +701,8 @@ func getDmlResult(rows pgx.Rows) (sql.Row, bool) { switch true { case tag.Insert(): + // The engine tests are currently all MySQL based, which doesn't support the RETURNING clause. If we decide + // to support this in the future, we will have to do so here. return sql.NewRow(gmstypes.NewOkResult(int(tag.RowsAffected()))), true case tag.Update(): return sql.NewRow(gmstypes.NewOkResult(int(tag.RowsAffected()))), true diff --git a/testing/go/insert_test.go b/testing/go/insert_test.go index 18a933b357..5b9f02839d 100755 --- a/testing/go/insert_test.go +++ b/testing/go/insert_test.go @@ -185,5 +185,66 @@ func TestInsert(t *testing.T) { }, }, }, + { + Name: "insert returning", + SetUpScript: []string{ + "CREATE TABLE t (i serial, j INT)", + "CREATE TABLE u (u uuid DEFAULT 'ac1f3e2d-1e4b-4d3e-8b1f-2b7f1e7f0e3d', j INT)", + "CREATE TABLE s (v1 varchar DEFAULT 'hello', v2 varchar DEFAULT 'world')", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "INSERT INTO t (j) VALUES (5), (6), (7) RETURNING i", + Expected: []sql.Row{ + {1}, {2}, {3}, + }, + }, + { + Query: "INSERT INTO t (j) VALUES (5), (6), (7) RETURNING i+3", + Expected: []sql.Row{ + {7}, {8}, {9}, + }, + }, + { + Query: "INSERT INTO t (j) VALUES (5), (6), (7) RETURNING i+j, j-3*i", + Expected: []sql.Row{ + {12, -16}, {14, -18}, {16, -20}, + }, + }, + { + Query: "INSERT INTO u (j) VALUES (5), (6), (7) RETURNING u", + Expected: []sql.Row{ + {"ac1f3e2d-1e4b-4d3e-8b1f-2b7f1e7f0e3d"}, {"ac1f3e2d-1e4b-4d3e-8b1f-2b7f1e7f0e3d"}, {"ac1f3e2d-1e4b-4d3e-8b1f-2b7f1e7f0e3d"}, + }, + }, + { + Query: "INSERT INTO s (v2) VALUES (' a') RETURNING concat(v1, v2)", + Expected: []sql.Row{ + {"hello a"}, + }, + }, + { + Query: "INSERT INTO s (v1) VALUES ('sup ') RETURNING concat(v1, v2)", + Expected: []sql.Row{ + {"sup world"}, + }, + }, + { + Query: "INSERT INTO s (v2, v1) VALUES ('def', 'abc'), ('xyz', 'uvw') RETURNING concat(v1, v2), concat(v2, v1), 100", + Expected: []sql.Row{ + {"abcdef", "defabc", 100}, + {"uvwxyz", "xyzuvw", 100}, + }, + }, + { + Query: "INSERT INTO t (j) VALUES (5), (6), (7) RETURNING i, doesnotexist", + ExpectedErr: "could not be found", + }, + { + Query: "INSERT INTO t (j) VALUES (5), (6), (7) RETURNING i, doesnotexist(j)", + ExpectedErr: "function: 'doesnotexist' not found", + }, + }, + }, }) }