diff --git a/sql/analyzer/process_truncate.go b/sql/analyzer/process_truncate.go index dd1ec8eb38..5440b10129 100644 --- a/sql/analyzer/process_truncate.go +++ b/sql/analyzer/process_truncate.go @@ -33,7 +33,9 @@ func processTruncate(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.S switch n := node.(type) { case *plan.DeleteFrom: - if !n.Resolved() { + // If there are any returning expressions, then we can't convert to a Truncate operation, + // since we need to process all rows and return results. + if !n.Resolved() || len(n.Returning) > 0 { return n, transform.SameTree, nil } return deleteToTruncate(ctx, a, n) diff --git a/sql/plan/delete.go b/sql/plan/delete.go index aa62ffb880..0040b6e13d 100644 --- a/sql/plan/delete.go +++ b/sql/plan/delete.go @@ -20,6 +20,7 @@ import ( "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/transform" ) var ErrDeleteFromNotSupported = errors.NewKind("table doesn't support DELETE FROM") @@ -33,6 +34,10 @@ type DeleteFrom struct { explicitTargets []sql.Node RefsSingleRel bool IsProcNested bool + + // Returning is a list of expressions to return after the delete operation. This feature is not + // supported in MySQL's syntax, but is exposed through PostgreSQL's syntax. + Returning []sql.Expression } var _ sql.Databaseable = (*DeleteFrom)(nil) @@ -73,6 +78,24 @@ func (p *DeleteFrom) GetDeleteTargets() []sql.Node { } } +// Schema implements the sql.Node interface. +func (p *DeleteFrom) Schema() sql.Schema { + // Postgres allows the returned values of the delete statement to be controlled, so if returning + // expressions were specified, then we return a different schema. + if p.Returning != nil { + // We know that returning exprs are resolved here, because you can't call Schema() + // safely until Resolved() is true. + returningSchema := sql.Schema{} + for _, expr := range p.Returning { + returningSchema = append(returningSchema, transform.ExpressionToColumn(expr, "")) + } + + return returningSchema + } + + return p.Child.Schema() +} + // Resolved implements the sql.Resolvable interface. func (p *DeleteFrom) Resolved() bool { if p.Child.Resolved() == false { @@ -85,9 +108,31 @@ func (p *DeleteFrom) Resolved() bool { } } + for _, expr := range p.Returning { + if expr.Resolved() == false { + return false + } + } + return true } +// Expressions implements the sql.Expressioner interface. +func (p *DeleteFrom) Expressions() []sql.Expression { + return p.Returning +} + +// WithExpressions implements the sql.Expressioner interface. +func (p *DeleteFrom) WithExpressions(newExprs ...sql.Expression) (sql.Node, error) { + if len(newExprs) != len(p.Returning) { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(newExprs), len(p.Returning)) + } + + copy := *p + copy.Returning = newExprs + return ©, nil +} + func (p *DeleteFrom) IsReadOnly() bool { return false } @@ -111,7 +156,9 @@ func (p *DeleteFrom) WithChildren(children ...sql.Node) (sql.Node, error) { return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) } - return NewDeleteFrom(children[0], p.explicitTargets), nil + deleteFrom := NewDeleteFrom(children[0], p.explicitTargets) + deleteFrom.Returning = p.Returning + return deleteFrom, nil } func GetDeletable(node sql.Node) (sql.DeletableTable, error) { diff --git a/sql/planbuilder/dml.go b/sql/planbuilder/dml.go index 1b5ea656f2..f49997e125 100644 --- a/sql/planbuilder/dml.go +++ b/sql/planbuilder/dml.go @@ -492,6 +492,11 @@ func (b *Builder) buildDelete(inScope *scope, d *ast.Delete) (outScope *scope) { del.RefsSingleRel = !outScope.refsSubquery del.IsProcNested = b.ProcCtx().DbName != "" outScope.node = del + + if len(d.Returning) > 0 { + del.Returning = b.analyzeSelectList(outScope, outScope, d.Returning) + } + return } diff --git a/sql/rowexec/delete.go b/sql/rowexec/delete.go index 03bc341d3d..f70c10a803 100644 --- a/sql/rowexec/delete.go +++ b/sql/rowexec/delete.go @@ -64,10 +64,12 @@ func findSourcePosition(schema sql.Schema, name string) (uint, uint, error) { // but in more complex scenarios when there are columns contributed by outer scopes and for DELETE FROM JOIN statements // the child iterator will return a row that is composed of rows from multiple table sources. type deleteIter struct { - deleters []schemaPositionDeleter - schema sql.Schema - childIter sql.RowIter - closed bool + deleters []schemaPositionDeleter + schema sql.Schema + childIter sql.RowIter + closed bool + returnExprs []sql.Expression + returnSchema sql.Schema } func (d *deleteIter) Next(ctx *sql.Context) (sql.Row, error) { @@ -98,6 +100,18 @@ func (d *deleteIter) Next(ctx *sql.Context) (sql.Row, error) { } } + if len(d.returnExprs) > 0 { + var retExprRow sql.Row + for _, returnExpr := range d.returnExprs { + result, err := returnExpr.Eval(ctx, row) + if err != nil { + return nil, err + } + retExprRow = append(retExprRow, result) + } + return retExprRow, nil + } + return row, nil } @@ -125,14 +139,16 @@ func (d *deleteIter) Close(ctx *sql.Context) error { return nil } -func newDeleteIter(childIter sql.RowIter, schema sql.Schema, deleters ...schemaPositionDeleter) sql.RowIter { +func newDeleteIter(childIter sql.RowIter, schema sql.Schema, deleters []schemaPositionDeleter, returnExprs []sql.Expression, returnSchema sql.Schema) sql.RowIter { openerClosers := make([]sql.EditOpenerCloser, len(deleters)) for i, ds := range deleters { openerClosers[i] = ds.deleter } return plan.NewTableEditorIter(&deleteIter{ - deleters: deleters, - childIter: childIter, - schema: schema, + deleters: deleters, + childIter: childIter, + schema: schema, + returnExprs: returnExprs, + returnSchema: returnSchema, }, openerClosers...) } diff --git a/sql/rowexec/dml.go b/sql/rowexec/dml.go index e4347c70a8..412bc66b6d 100644 --- a/sql/rowexec/dml.go +++ b/sql/rowexec/dml.go @@ -149,7 +149,7 @@ func (b *BaseBuilder) buildDeleteFrom(ctx *sql.Context, n *plan.DeleteFrom, row } schemaPositionDeleters[i] = schemaPositionDeleter{deleter, int(start), int(end)} } - return newDeleteIter(iter, schema, schemaPositionDeleters...), nil + return newDeleteIter(iter, schema, schemaPositionDeleters, n.Returning, n.Schema()), nil } func (b *BaseBuilder) buildForeignKeyHandler(ctx *sql.Context, n *plan.ForeignKeyHandler, row sql.Row) (sql.RowIter, error) { diff --git a/sql/rowexec/dml_iters.go b/sql/rowexec/dml_iters.go index abc78bfbfa..583c7c2dcf 100644 --- a/sql/rowexec/dml_iters.go +++ b/sql/rowexec/dml_iters.go @@ -619,6 +619,10 @@ func AddAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Sc if len(innerIter.returnExprs) > 0 { return innerIter, innerIter.returnSchema } + case *deleteIter: + if len(innerIter.returnExprs) > 0 { + return innerIter, innerIter.returnSchema + } } return defaultAccumulatorIter(ctx, iter)