Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ require (
github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20241119094239-f4e529af734d
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2
github.com/dolthub/go-icu-regex v0.0.0-20250303123116-549b8d7cad00
github.com/dolthub/go-mysql-server v0.19.1-0.20250317234108-fa97159ff8ce
github.com/dolthub/go-mysql-server v0.19.1-0.20250318170829-6ad8521aefac
github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216
github.com/dolthub/vitess v0.0.0-20250304211657-920ca9ec2b9a
github.com/fatih/color v1.13.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U=
github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
github.com/dolthub/go-icu-regex v0.0.0-20250303123116-549b8d7cad00 h1:rh2ij2yTYKJWlX+c8XRg4H5OzqPewbU1lPK8pcfVmx8=
github.com/dolthub/go-icu-regex v0.0.0-20250303123116-549b8d7cad00/go.mod h1:ylU4XjUpsMcvl/BKeRRMXSH7e7WBrPXdSLvnRJYrxEA=
github.com/dolthub/go-mysql-server v0.19.1-0.20250317234108-fa97159ff8ce h1:ICPJPB1YINTv9YX0DFSVG/qi5NajMWU5vK7pW8cd650=
github.com/dolthub/go-mysql-server v0.19.1-0.20250317234108-fa97159ff8ce/go.mod h1:yr+Vv47/YLOKMgiEY+QxHTlbIVpTuiVtkEZ5l+xruY4=
github.com/dolthub/go-mysql-server v0.19.1-0.20250318170829-6ad8521aefac h1:v9MYsGeu+aqq3aQR0MpY6zHHfwc9vkM7ysJRqmFcsgE=
github.com/dolthub/go-mysql-server v0.19.1-0.20250318170829-6ad8521aefac/go.mod h1:yr+Vv47/YLOKMgiEY+QxHTlbIVpTuiVtkEZ5l+xruY4=
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI=
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q=
github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE=
Expand Down
26 changes: 16 additions & 10 deletions server/ast/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,20 @@ import (
)

// nodeInsert handles *tree.Insert nodes.
func nodeInsert(ctx *Context, node *tree.Insert) (*vitess.Insert, error) {
func nodeInsert(ctx *Context, node *tree.Insert) (insert *vitess.Insert, err error) {
if node == nil {
return nil, nil
}
ctx.Auth().PushAuthType(auth.AuthType_INSERT)
defer ctx.Auth().PopAuthType()

if _, ok := node.Returning.(*tree.NoReturningClause); !ok {
return nil, errors.Errorf("RETURNING is not yet supported")
var returningExprs vitess.SelectExprs
if returning, ok := node.Returning.(*tree.ReturningExprs); ok {
// TODO: PostgreSQL will apply any triggers before returning the value; need to test this.
returningExprs, err = nodeSelectExprs(ctx, tree.SelectExprs(*returning))
if err != nil {
return nil, err
}
}
var ignore string
var onDuplicate vitess.OnDup
Expand Down Expand Up @@ -102,13 +107,14 @@ func nodeInsert(ctx *Context, node *tree.Insert) (*vitess.Insert, error) {
}
}
return &vitess.Insert{
Action: vitess.InsertStr,
Ignore: ignore,
Table: tableName,
With: with,
Columns: columns,
Rows: rows,
OnDup: onDuplicate,
Action: vitess.InsertStr,
Ignore: ignore,
Table: tableName,
Returning: returningExprs,
With: with,
Columns: columns,
Rows: rows,
OnDup: onDuplicate,
Auth: vitess.AuthInformation{
AuthType: auth.AuthType_INSERT,
TargetType: auth.AuthTargetType_TableIdentifiers,
Expand Down
40 changes: 27 additions & 13 deletions server/connection_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ func (h *ConnectionHandler) handleParse(message *pgproto3.Parse) error {
func (h *ConnectionHandler) handleDescribe(message *pgproto3.Describe) error {
var fields []pgproto3.FieldDescription
var bindvarTypes []uint32
var tag string
var query ConvertedQuery

h.waitForSync = true
if message.ObjectType == 'S' {
Expand All @@ -527,18 +527,18 @@ func (h *ConnectionHandler) handleDescribe(message *pgproto3.Describe) error {

fields = preparedStatementData.ReturnFields
bindvarTypes = preparedStatementData.BindVarTypes
tag = preparedStatementData.Query.StatementTag
query = preparedStatementData.Query
} else {
portalData, ok := h.portals[message.Name]
if !ok {
return errors.Errorf("portal %s does not exist", message.Name)
}

fields = portalData.Fields
tag = portalData.Query.StatementTag
query = portalData.Query
}

return h.sendDescribeResponse(fields, bindvarTypes, tag)
return h.sendDescribeResponse(fields, bindvarTypes, query)
}

// handleBind handles a bind message, returning any error that occurs
Expand Down Expand Up @@ -615,7 +615,7 @@ func (h *ConnectionHandler) handleExecute(message *pgproto3.Execute) error {
// |rowsAffected| gets altered by the callback below
rowsAffected := int32(0)

callback := h.spoolRowsCallback(query.StatementTag, &rowsAffected, true)
callback := h.spoolRowsCallback(query, &rowsAffected, true)
err = h.doltgresHandler.ComExecuteBound(context.Background(), h.mysqlConn, query.String, portalData.BoundPlan, callback)
if err != nil {
return err
Expand Down Expand Up @@ -916,7 +916,7 @@ func (h *ConnectionHandler) query(query ConvertedQuery) error {
// |rowsAffected| gets altered by the callback below
rowsAffected := int32(0)

callback := h.spoolRowsCallback(query.StatementTag, &rowsAffected, false)
callback := h.spoolRowsCallback(query, &rowsAffected, false)
err := h.doltgresHandler.ComQuery(context.Background(), h.mysqlConn, query.String, query.AST, callback)
if err != nil {
if strings.HasPrefix(err.Error(), "syntax error at position") {
Expand All @@ -930,9 +930,9 @@ func (h *ConnectionHandler) query(query ConvertedQuery) error {

// spoolRowsCallback returns a callback function that will send RowDescription message,
// then a DataRow message for each row in the result set.
func (h *ConnectionHandler) spoolRowsCallback(tag string, rows *int32, isExecute bool) func(ctx *sql.Context, res *Result) error {
func (h *ConnectionHandler) spoolRowsCallback(query ConvertedQuery, rows *int32, isExecute bool) func(ctx *sql.Context, res *Result) error {
// IsIUD returns whether the query is either an INSERT, UPDATE, or DELETE query.
isIUD := tag == "INSERT" || tag == "UPDATE" || tag == "DELETE"
isIUD := query.StatementTag == "INSERT" || query.StatementTag == "UPDATE" || query.StatementTag == "DELETE"
return func(ctx *sql.Context, res *Result) error {
sess := dsess.DSessFromSess(ctx.Session)
for _, notice := range sess.Notices() {
Expand All @@ -947,7 +947,7 @@ func (h *ConnectionHandler) spoolRowsCallback(tag string, rows *int32, isExecute
}
sess.ClearNotices()

if returnsRow(tag) {
if returnsRow(query) {
// EXECUTE does not send RowDescription; instead it should be sent from DESCRIBE prior to it
if !isExecute {
if err := h.send(&pgproto3.RowDescription{
Expand Down Expand Up @@ -977,7 +977,7 @@ func (h *ConnectionHandler) spoolRowsCallback(tag string, rows *int32, isExecute
}

// sendDescribeResponse sends a response message for a Describe message
func (h *ConnectionHandler) sendDescribeResponse(fields []pgproto3.FieldDescription, types []uint32, tag string) error {
func (h *ConnectionHandler) sendDescribeResponse(fields []pgproto3.FieldDescription, types []uint32, query ConvertedQuery) error {
// The prepared statement variant of the describe command returns the OIDs of the parameters.
if types != nil {
if err := h.send(&pgproto3.ParameterDescription{
Expand All @@ -987,7 +987,7 @@ func (h *ConnectionHandler) sendDescribeResponse(fields []pgproto3.FieldDescript
}
}

if returnsRow(tag) {
if returnsRow(query) {
// Both variants finish with a row description.
return h.send(&pgproto3.RowDescription{
Fields: fields,
Expand Down Expand Up @@ -1182,10 +1182,24 @@ func (h *ConnectionHandler) send(message pgproto3.BackendMessage) error {
}

// returnsRow returns whether the query returns set of rows such as SELECT and FETCH statements.
func returnsRow(tag string) bool {
switch tag {
func returnsRow(query ConvertedQuery) bool {
switch query.StatementTag {
case "SELECT", "SHOW", "FETCH", "EXPLAIN", "SHOW TABLES":
return true
case "INSERT":
hasReturningClause := false
sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch node := node.(type) {
case *sqlparser.Insert:
if len(node.Returning) > 0 {
hasReturningClause = true
}
return false, nil
}
// this should be impossible, but just in case
return true, nil
}, query.AST)
return hasReturningClause
default:
return false
}
Expand Down
2 changes: 1 addition & 1 deletion testing/generation/command_docs/output/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func RunTests(t *testing.T, tests []QueryParses) {
}()
if !test.ShouldConvert() {
if err == nil && vitessAST != nil {
t.Fatal("Query now converts, please upgrade the type to `Converts`")
t.Fatalf("Query %s now converts, please upgrade the type to `Converts`", test.String())
}
return
}
Expand Down
8 changes: 4 additions & 4 deletions testing/generation/command_docs/output/insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1424,7 +1424,7 @@ func TestInsert(t *testing.T) {
Unimplemented("INSERT INTO table_name AS alias ( column_name , column_name ) VALUES ( expression ) , ( expression , expression ) ON CONFLICT ( index_column_name , index_column_name ) WHERE index_predicate DO UPDATE SET ( column_name ) = ROW ( DEFAULT , expression ) RETURNING colname"),
Unimplemented("WITH RECURSIVE queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name ( column_name ) VALUES ( DEFAULT , DEFAULT ) , ( expression ) ON CONFLICT ( index_column_name opclass , ( index_expression ) COLLATE en_US opclass ) WHERE index_predicate DO UPDATE SET ( column_name , column_name ) = ROW ( DEFAULT , expression ) RETURNING colname"),
Unimplemented("WITH queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name VALUES ( DEFAULT ) , ( DEFAULT ) ON CONFLICT ( ( index_expression ) COLLATE en_US , ( index_expression ) ) DO UPDATE SET ( column_name ) = ( expression , DEFAULT ) RETURNING colname"),
Parses("INSERT INTO table_name ( column_name , column_name ) VALUES ( DEFAULT , expression ) , ( expression , DEFAULT ) ON CONFLICT ( index_column_name , index_column_name ) DO UPDATE SET ( column_name ) = ( SELECT 1 ) RETURNING colname"),
Converts("INSERT INTO table_name ( column_name , column_name ) VALUES ( DEFAULT , expression ) , ( expression , DEFAULT ) ON CONFLICT ( index_column_name , index_column_name ) DO UPDATE SET ( column_name ) = ( SELECT 1 ) RETURNING colname"),
Unimplemented("WITH queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name ( column_name ) VALUES ( DEFAULT , DEFAULT ) , ( DEFAULT , expression ) ON CONFLICT ( ( index_expression ) opclass , ( index_expression ) ) WHERE index_predicate DO UPDATE SET ( column_name ) = ( SELECT 1 ) RETURNING colname"),
Unimplemented("WITH queryname AS ( select ) INSERT INTO table_name VALUES ( DEFAULT , DEFAULT ) , ( expression , expression ) ON CONFLICT ( index_column_name COLLATE en_US , ( index_expression ) COLLATE en_US opclass ) WHERE index_predicate DO UPDATE SET ( column_name ) = ( expression ) , column_name = expression RETURNING colname"),
Unimplemented("WITH RECURSIVE queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name AS alias ( column_name ) VALUES ( expression , DEFAULT ) ON CONFLICT ( ( index_expression ) COLLATE en_US , ( index_expression ) opclass ) DO UPDATE SET ( column_name ) = ROW ( DEFAULT ) , column_name = expression RETURNING colname"),
Expand Down Expand Up @@ -2817,7 +2817,7 @@ func TestInsert(t *testing.T) {
Unimplemented("WITH RECURSIVE queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name AS alias ( column_name ) VALUES ( DEFAULT , expression ) , ( DEFAULT , DEFAULT ) ON CONFLICT ( index_column_name COLLATE en_US , ( index_expression ) COLLATE en_US ) WHERE index_predicate DO UPDATE SET ( column_name ) = ROW ( expression , DEFAULT ) , ( column_name , column_name ) = ROW ( DEFAULT ) RETURNING colname AS output_name"),
Unimplemented("WITH RECURSIVE queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name ( column_name ) VALUES ( expression ) ON CONFLICT ( ( index_expression ) opclass , index_column_name opclass ) DO UPDATE SET ( column_name ) = ROW ( DEFAULT , DEFAULT ) , ( column_name , column_name ) = ROW ( DEFAULT ) RETURNING colname AS output_name"),
Unimplemented("INSERT INTO table_name ( column_name , column_name ) VALUES ( expression , expression ) , ( DEFAULT , DEFAULT ) ON CONFLICT ( ( index_expression ) , index_column_name opclass ) WHERE index_predicate DO UPDATE SET ( column_name ) = ( SELECT 1 ) , ( column_name , column_name ) = ROW ( DEFAULT ) RETURNING colname AS output_name"),
Parses("INSERT INTO table_name ( column_name ) VALUES ( DEFAULT , expression ) , ( DEFAULT , DEFAULT ) ON CONFLICT ( index_column_name , index_column_name ) DO UPDATE SET ( column_name , column_name ) = ( expression ) , ( column_name ) = ( expression , expression ) RETURNING colname AS output_name"),
Converts("INSERT INTO table_name ( column_name ) VALUES ( DEFAULT , expression ) , ( DEFAULT , DEFAULT ) ON CONFLICT ( index_column_name , index_column_name ) DO UPDATE SET ( column_name , column_name ) = ( expression ) , ( column_name ) = ( expression , expression ) RETURNING colname AS output_name"),
Unimplemented("WITH queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name ( column_name ) VALUES ( expression , expression ) ON CONFLICT ( index_column_name , index_column_name opclass ) WHERE index_predicate DO UPDATE SET ( column_name , column_name ) = ROW ( expression ) , ( column_name ) = ( expression , expression ) RETURNING colname AS output_name"),
Unimplemented("WITH queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name AS alias ( column_name , column_name ) SELECT 1 ON CONFLICT ( index_column_name opclass , index_column_name COLLATE en_US ) DO UPDATE SET ( column_name ) = ( expression , expression ) , ( column_name ) = ( expression , expression ) RETURNING colname AS output_name"),
Unimplemented("WITH RECURSIVE queryname AS ( select ) INSERT INTO table_name AS alias ( column_name , column_name ) VALUES ( expression ) , ( DEFAULT , DEFAULT ) ON CONFLICT ( ( index_expression ) opclass , ( index_expression ) COLLATE en_US ) DO UPDATE SET ( column_name , column_name ) = ( expression , expression ) , ( column_name ) = ( expression , expression ) RETURNING colname AS output_name"),
Expand Down Expand Up @@ -5325,7 +5325,7 @@ func TestInsert(t *testing.T) {
Unimplemented("INSERT INTO table_name ( column_name , column_name ) VALUES ( DEFAULT , expression ) , ( expression , expression ) ON CONFLICT ( index_column_name COLLATE en_US opclass , index_column_name opclass ) DO UPDATE SET ( column_name , column_name ) = ( expression , expression ) RETURNING colname , colname output_name"),
Unimplemented("WITH queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name ( column_name ) VALUES ( expression ) , ( expression ) ON CONFLICT ( index_column_name opclass , index_column_name ) WHERE index_predicate DO UPDATE SET ( column_name , column_name ) = ROW ( expression , expression ) RETURNING colname , colname output_name"),
Unimplemented("WITH queryname AS ( select ) INSERT INTO table_name AS alias ( column_name , column_name ) VALUES ( DEFAULT , DEFAULT ) , ( DEFAULT , DEFAULT ) ON CONFLICT ( ( index_expression ) opclass , ( index_expression ) opclass ) WHERE index_predicate DO UPDATE SET ( column_name , column_name ) = ROW ( expression , expression ) RETURNING colname , colname output_name"),
Parses("INSERT INTO table_name ( column_name , column_name ) VALUES ( expression , expression ) ON CONFLICT ( index_column_name , index_column_name ) DO UPDATE SET ( column_name ) = ( DEFAULT , expression ) RETURNING colname , colname output_name"),
Converts("INSERT INTO table_name ( column_name , column_name ) VALUES ( expression , expression ) ON CONFLICT ( index_column_name , index_column_name ) DO UPDATE SET ( column_name ) = ( DEFAULT , expression ) RETURNING colname , colname output_name"),
Unimplemented("INSERT INTO table_name VALUES ( expression , DEFAULT ) , ( expression , DEFAULT ) ON CONFLICT ( index_column_name opclass , index_column_name COLLATE en_US ) DO UPDATE SET ( column_name ) = ROW ( DEFAULT , expression ) RETURNING colname , colname output_name"),
Unimplemented("WITH queryname AS ( select ) INSERT INTO table_name VALUES ( DEFAULT ) , ( DEFAULT , DEFAULT ) ON CONFLICT ( index_column_name COLLATE en_US ) DO UPDATE SET ( column_name , column_name ) = ROW ( DEFAULT , expression ) RETURNING colname , colname output_name"),
Unimplemented("WITH queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name AS alias ( column_name , column_name ) VALUES ( expression , DEFAULT ) , ( DEFAULT , expression ) ON CONFLICT ( index_column_name COLLATE en_US opclass ) WHERE index_predicate DO UPDATE SET ( column_name ) = ( expression , DEFAULT ) RETURNING colname , colname output_name"),
Expand Down Expand Up @@ -6043,7 +6043,7 @@ func TestInsert(t *testing.T) {
Unimplemented("WITH queryname AS ( select ) INSERT INTO table_name ( column_name , column_name ) VALUES ( DEFAULT , DEFAULT ) , ( DEFAULT ) ON CONFLICT ( ( index_expression ) COLLATE en_US , index_column_name COLLATE en_US opclass ) WHERE index_predicate DO UPDATE SET ( column_name , column_name ) = ROW ( expression , expression ) , ( column_name , column_name ) = ( expression ) RETURNING colname output_name , colname output_name"),
Unimplemented("WITH queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name AS alias VALUES ( expression , expression ) , ( expression , expression ) ON CONFLICT ( index_column_name COLLATE en_US , index_column_name COLLATE en_US ) DO UPDATE SET ( column_name , column_name ) = ( DEFAULT , expression ) , ( column_name , column_name ) = ( expression ) RETURNING colname output_name , colname output_name"),
Unimplemented("WITH RECURSIVE queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name VALUES ( DEFAULT ) , ( DEFAULT , DEFAULT ) ON CONFLICT ( index_column_name COLLATE en_US opclass , index_column_name COLLATE en_US opclass ) WHERE index_predicate DO UPDATE SET ( column_name , column_name ) = ( DEFAULT , expression ) , ( column_name , column_name ) = ( expression ) RETURNING colname output_name , colname output_name"),
Parses("INSERT INTO table_name VALUES ( expression , expression ) , ( expression , DEFAULT ) ON CONFLICT ( index_column_name ) DO UPDATE SET ( column_name , column_name ) = ( expression , DEFAULT ) , ( column_name , column_name ) = ( expression ) RETURNING colname output_name , colname output_name"),
Converts("INSERT INTO table_name VALUES ( expression , expression ) , ( expression , DEFAULT ) ON CONFLICT ( index_column_name ) DO UPDATE SET ( column_name , column_name ) = ( expression , DEFAULT ) , ( column_name , column_name ) = ( expression ) RETURNING colname output_name , colname output_name"),
Unimplemented("WITH RECURSIVE queryname AS ( select ) INSERT INTO table_name ( column_name , column_name ) VALUES ( DEFAULT , DEFAULT ) ON CONFLICT ( ( index_expression ) opclass , index_column_name opclass ) WHERE index_predicate DO UPDATE SET column_name = DEFAULT , ( column_name ) = ROW ( expression ) RETURNING colname output_name , colname output_name"),
Unimplemented("WITH queryname AS ( select ) , queryname AS ( select ) INSERT INTO table_name AS alias VALUES ( expression ) ON CONFLICT ( ( index_expression ) COLLATE en_US opclass , ( index_expression ) COLLATE en_US opclass ) WHERE index_predicate DO UPDATE SET column_name = DEFAULT , ( column_name ) = ROW ( expression ) RETURNING colname output_name , colname output_name"),
Unimplemented("WITH queryname AS ( select ) INSERT INTO table_name ( column_name ) VALUES ( expression , expression ) , ( DEFAULT , DEFAULT ) ON CONFLICT ( ( index_expression ) COLLATE en_US opclass , index_column_name COLLATE en_US ) DO UPDATE SET ( column_name ) = ( expression ) , ( column_name ) = ROW ( expression ) RETURNING colname output_name , colname output_name"),
Expand Down
Loading
Loading