Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sql/analyzer/process_truncate.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ func processTruncate(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.S

switch n := node.(type) {
case *plan.DeleteFrom:
if !n.Resolved() {
// If there are any returning expressions, then we can't convert to a Truncate operation,
// since we need to process all rows and return results.
if !n.Resolved() || len(n.Returning) > 0 {
return n, transform.SameTree, nil
}
return deleteToTruncate(ctx, a, n)
Expand Down
49 changes: 48 additions & 1 deletion sql/plan/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"gopkg.in/src-d/go-errors.v1"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/transform"
)

var ErrDeleteFromNotSupported = errors.NewKind("table doesn't support DELETE FROM")
Expand All @@ -33,6 +34,10 @@ type DeleteFrom struct {
explicitTargets []sql.Node
RefsSingleRel bool
IsProcNested bool

// Returning is a list of expressions to return after the delete operation. This feature is not
// supported in MySQL's syntax, but is exposed through PostgreSQL's syntax.
Returning []sql.Expression
}

var _ sql.Databaseable = (*DeleteFrom)(nil)
Expand Down Expand Up @@ -73,6 +78,24 @@ func (p *DeleteFrom) GetDeleteTargets() []sql.Node {
}
}

// Schema implements the sql.Node interface.
func (p *DeleteFrom) Schema() sql.Schema {
// Postgres allows the returned values of the delete statement to be controlled, so if returning
// expressions were specified, then we return a different schema.
if p.Returning != nil {
// We know that returning exprs are resolved here, because you can't call Schema()
// safely until Resolved() is true.
returningSchema := sql.Schema{}
for _, expr := range p.Returning {
returningSchema = append(returningSchema, transform.ExpressionToColumn(expr, ""))
}

return returningSchema
}

return p.Child.Schema()
}

// Resolved implements the sql.Resolvable interface.
func (p *DeleteFrom) Resolved() bool {
if p.Child.Resolved() == false {
Expand All @@ -85,9 +108,31 @@ func (p *DeleteFrom) Resolved() bool {
}
}

for _, expr := range p.Returning {
if expr.Resolved() == false {
return false
}
}

return true
}

// Expressions implements the sql.Expressioner interface.
func (p *DeleteFrom) Expressions() []sql.Expression {
return p.Returning
}

// WithExpressions implements the sql.Expressioner interface.
func (p *DeleteFrom) WithExpressions(newExprs ...sql.Expression) (sql.Node, error) {
if len(newExprs) != len(p.Returning) {
return nil, sql.ErrInvalidChildrenNumber.New(p, len(newExprs), len(p.Returning))
}

copy := *p
copy.Returning = newExprs
return &copy, nil
}

func (p *DeleteFrom) IsReadOnly() bool {
return false
}
Expand All @@ -111,7 +156,9 @@ func (p *DeleteFrom) WithChildren(children ...sql.Node) (sql.Node, error) {
return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1)
}

return NewDeleteFrom(children[0], p.explicitTargets), nil
deleteFrom := NewDeleteFrom(children[0], p.explicitTargets)
deleteFrom.Returning = p.Returning
return deleteFrom, nil
}

func GetDeletable(node sql.Node) (sql.DeletableTable, error) {
Expand Down
5 changes: 5 additions & 0 deletions sql/planbuilder/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,11 @@ func (b *Builder) buildDelete(inScope *scope, d *ast.Delete) (outScope *scope) {
del.RefsSingleRel = !outScope.refsSubquery
del.IsProcNested = b.ProcCtx().DbName != ""
outScope.node = del

if len(d.Returning) > 0 {
del.Returning = b.analyzeSelectList(outScope, outScope, d.Returning)
}

return
}

Expand Down
32 changes: 24 additions & 8 deletions sql/rowexec/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,12 @@ func findSourcePosition(schema sql.Schema, name string) (uint, uint, error) {
// but in more complex scenarios when there are columns contributed by outer scopes and for DELETE FROM JOIN statements
// the child iterator will return a row that is composed of rows from multiple table sources.
type deleteIter struct {
deleters []schemaPositionDeleter
schema sql.Schema
childIter sql.RowIter
closed bool
deleters []schemaPositionDeleter
schema sql.Schema
childIter sql.RowIter
closed bool
returnExprs []sql.Expression
returnSchema sql.Schema
}

func (d *deleteIter) Next(ctx *sql.Context) (sql.Row, error) {
Expand Down Expand Up @@ -98,6 +100,18 @@ func (d *deleteIter) Next(ctx *sql.Context) (sql.Row, error) {
}
}

if len(d.returnExprs) > 0 {
var retExprRow sql.Row
for _, returnExpr := range d.returnExprs {
result, err := returnExpr.Eval(ctx, row)
if err != nil {
return nil, err
}
retExprRow = append(retExprRow, result)
}
return retExprRow, nil
}

return row, nil
}

Expand Down Expand Up @@ -125,14 +139,16 @@ func (d *deleteIter) Close(ctx *sql.Context) error {
return nil
}

func newDeleteIter(childIter sql.RowIter, schema sql.Schema, deleters ...schemaPositionDeleter) sql.RowIter {
func newDeleteIter(childIter sql.RowIter, schema sql.Schema, deleters []schemaPositionDeleter, returnExprs []sql.Expression, returnSchema sql.Schema) sql.RowIter {
openerClosers := make([]sql.EditOpenerCloser, len(deleters))
for i, ds := range deleters {
openerClosers[i] = ds.deleter
}
return plan.NewTableEditorIter(&deleteIter{
deleters: deleters,
childIter: childIter,
schema: schema,
deleters: deleters,
childIter: childIter,
schema: schema,
returnExprs: returnExprs,
returnSchema: returnSchema,
}, openerClosers...)
}
2 changes: 1 addition & 1 deletion sql/rowexec/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func (b *BaseBuilder) buildDeleteFrom(ctx *sql.Context, n *plan.DeleteFrom, row
}
schemaPositionDeleters[i] = schemaPositionDeleter{deleter, int(start), int(end)}
}
return newDeleteIter(iter, schema, schemaPositionDeleters...), nil
return newDeleteIter(iter, schema, schemaPositionDeleters, n.Returning, n.Schema()), nil
}

func (b *BaseBuilder) buildForeignKeyHandler(ctx *sql.Context, n *plan.ForeignKeyHandler, row sql.Row) (sql.RowIter, error) {
Expand Down
4 changes: 4 additions & 0 deletions sql/rowexec/dml_iters.go
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,10 @@ func AddAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Sc
if len(innerIter.returnExprs) > 0 {
return innerIter, innerIter.returnSchema
}
case *deleteIter:
if len(innerIter.returnExprs) > 0 {
return innerIter, innerIter.returnSchema
}
}

return defaultAccumulatorIter(ctx, iter)
Expand Down
Loading