From ca8524bff8709421e39ee2215fea2a946f504b97 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 6 Sep 2023 13:06:34 +0200 Subject: [PATCH 1/4] refactor: move update logic to sql_builder.go Signed-off-by: Andres Taylor --- go/vt/sqlparser/ast_funcs.go | 16 +++ .../planbuilder/operator_transformers.go | 73 ++++++------ .../planbuilder/operators/SQL_builder.go | 107 +++++++++++------- 3 files changed, 122 insertions(+), 74 deletions(-) diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 3d8d027f12a..c4d0bdd3199 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -1150,6 +1150,22 @@ func (node *Update) AddWhere(expr Expr) { } } +// AddWhere adds the boolean expression to the +// WHERE clause as an AND condition. +func (node *Delete) AddWhere(expr Expr) { + if node.Where == nil { + node.Where = &Where{ + Type: WhereClause, + Expr: expr, + } + return + } + node.Where.Expr = &AndExpr{ + Left: node.Where.Expr, + Right: expr, + } +} + // AddOrder adds an order by element func (node *Union) AddOrder(order *Order) { node.OrderBy = append(node.OrderBy, order) diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index 283ee147ba4..469afc231fa 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -404,17 +404,28 @@ func transformRoutePlan(ctx *plancontext.PlanningContext, op *operators.Route) ( switch src := op.Source.(type) { case *operators.Insert: return transformInsertPlan(ctx, op, src) - case *operators.Update: - return transformUpdatePlan(ctx, op, src) case *operators.Delete: return transformDeletePlan(ctx, op, src) } - condition := getVindexPredicate(ctx, op) - sel, err := operators.ToSQL(ctx, op.Source) + stmt, dmlOp, err := operators.ToSQL(ctx, op.Source) if err != nil { return nil, err } - replaceSubQuery(ctx, sel) + + switch stmt := stmt.(type) { + case *sqlparser.Update: + replaceSubQuery(ctx, stmt) + return buildUpdateLogicalPlan(ctx, op, dmlOp, stmt) + case sqlparser.SelectStatement: + replaceSubQuery(ctx, stmt) + return buildRouteLogicalPlan(ctx, op, stmt) + default: + panic(fmt.Sprintf("dont know how to %T", stmt)) + } +} + +func buildRouteLogicalPlan(ctx *plancontext.PlanningContext, op *operators.Route, stmt sqlparser.SelectStatement) (logicalPlan, error) { + condition := getVindexPredicate(ctx, op) eroute, err := routeToEngineRoute(ctx, op) for _, order := range op.Ordering { typ, collation, _ := ctx.SemTable.TypeForExpr(order.AST) @@ -431,11 +442,35 @@ func transformRoutePlan(ctx *plancontext.PlanningContext, op *operators.Route) ( } return &route{ eroute: eroute, - Select: sel, + Select: stmt, tables: operators.TableID(op), condition: condition, }, nil +} + +func buildUpdateLogicalPlan(ctx *plancontext.PlanningContext, op *operators.Route, dmlOp ops.Operator, stmt sqlparser.Statement) (logicalPlan, error) { + upd := dmlOp.(*operators.Update) + rp := newRoutingParams(ctx, op.Routing.OpCode()) + err := op.Routing.UpdateRoutingParams(ctx, rp) + if err != nil { + return nil, err + } + edml := &engine.DML{ + Query: generateQuery(stmt), + TableNames: []string{upd.VTable.Name.String()}, + Vindexes: upd.VTable.ColumnVindexes, + OwnedVindexQuery: upd.OwnedVindexQuery, + RoutingParameters: rp, + } + + transformDMLPlan(upd.VTable, edml, op.Routing, len(upd.ChangedVindexValues) > 0) + + e := &engine.Update{ + ChangedVindexValues: upd.ChangedVindexValues, + DML: edml, + } + return &primitiveWrapper{prim: e}, nil } func transformInsertPlan(ctx *plancontext.PlanningContext, op *operators.Route, ins *operators.Insert) (i *insert, err error) { @@ -541,32 +576,6 @@ func dmlFormatter(buf *sqlparser.TrackedBuffer, node sqlparser.SQLNode) { node.Format(buf) } -func transformUpdatePlan(ctx *plancontext.PlanningContext, op *operators.Route, upd *operators.Update) (logicalPlan, error) { - ast := upd.AST - replaceSubQuery(ctx, ast) - rp := newRoutingParams(ctx, op.Routing.OpCode()) - err := op.Routing.UpdateRoutingParams(ctx, rp) - if err != nil { - return nil, err - } - edml := &engine.DML{ - Query: generateQuery(ast), - TableNames: []string{upd.VTable.Name.String()}, - Vindexes: upd.VTable.ColumnVindexes, - OwnedVindexQuery: upd.OwnedVindexQuery, - RoutingParameters: rp, - } - - transformDMLPlan(upd.VTable, edml, op.Routing, len(upd.ChangedVindexValues) > 0) - - e := &engine.Update{ - ChangedVindexValues: upd.ChangedVindexValues, - DML: edml, - } - - return &primitiveWrapper{prim: e}, nil -} - func transformDeletePlan(ctx *plancontext.PlanningContext, op *operators.Route, del *operators.Delete) (logicalPlan, error) { ast := del.AST replaceSubQuery(ctx, ast) diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index 6ea6dd43ff0..65109c38be2 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -30,20 +30,25 @@ import ( type ( queryBuilder struct { - ctx *plancontext.PlanningContext - sel sqlparser.SelectStatement - tableNames []string + ctx *plancontext.PlanningContext + stmt sqlparser.Statement + tableNames []string + dmlOperator ops.Operator } ) -func ToSQL(ctx *plancontext.PlanningContext, op ops.Operator) (sqlparser.SelectStatement, error) { +func (qb *queryBuilder) asSelectStatement() sqlparser.SelectStatement { + return qb.stmt.(sqlparser.SelectStatement) +} + +func ToSQL(ctx *plancontext.PlanningContext, op ops.Operator) (sqlparser.Statement, ops.Operator, error) { q := &queryBuilder{ctx: ctx} err := buildQuery(op, q) if err != nil { - return nil, err + return nil, nil, err } q.sortTables() - return q.sel, nil + return q.stmt, q.dmlOperator, nil } func (qb *queryBuilder) addTable(db, tableName, alias string, tableID semantics.TableSet, hints sqlparser.IndexHints) { @@ -61,10 +66,10 @@ func (qb *queryBuilder) addTableExpr( hints sqlparser.IndexHints, columnAliases sqlparser.Columns, ) { - if qb.sel == nil { - qb.sel = &sqlparser.Select{} + if qb.stmt == nil { + qb.stmt = &sqlparser.Select{} } - sel := qb.sel.(*sqlparser.Select) + sel := qb.stmt.(*sqlparser.Select) elems := &sqlparser.AliasedTableExpr{ Expr: tblExpr, Partitions: nil, @@ -74,7 +79,7 @@ func (qb *queryBuilder) addTableExpr( } qb.ctx.SemTable.ReplaceTableSetFor(tableID, elems) sel.From = append(sel.From, elems) - qb.sel = sel + qb.stmt = sel qb.tableNames = append(qb.tableNames, tableName) } @@ -85,34 +90,43 @@ func (qb *queryBuilder) addPredicate(expr sqlparser.Expr) { return } - sel := qb.sel.(*sqlparser.Select) _, isSubQuery := expr.(*sqlparser.ExtractedSubquery) var addPred func(sqlparser.Expr) - if sqlparser.ContainsAggregation(expr) && !isSubQuery { - addPred = sel.AddHaving - } else { - addPred = sel.AddWhere + switch stmt := qb.stmt.(type) { + case *sqlparser.Select: + if sqlparser.ContainsAggregation(expr) && !isSubQuery { + addPred = stmt.AddHaving + } else { + addPred = stmt.AddWhere + } + case *sqlparser.Update: + addPred = stmt.AddWhere + case *sqlparser.Delete: + addPred = stmt.AddWhere + default: + panic(fmt.Sprintf("cant add WHERE to %T", qb.stmt)) } + for _, exp := range sqlparser.SplitAndExpression(nil, expr) { addPred(exp) } } func (qb *queryBuilder) addGroupBy(original sqlparser.Expr) { - sel := qb.sel.(*sqlparser.Select) + sel := qb.stmt.(*sqlparser.Select) sel.GroupBy = append(sel.GroupBy, original) } func (qb *queryBuilder) addProjection(projection *sqlparser.AliasedExpr) error { - switch stmt := qb.sel.(type) { + switch stmt := qb.stmt.(type) { case *sqlparser.Select: stmt.SelectExprs = append(stmt.SelectExprs, projection) return nil case *sqlparser.Union: switch expr := projection.Expr.(type) { case *sqlparser.ColName: - return checkUnionColumnByName(expr, qb.sel) + return checkUnionColumnByName(expr, stmt) default: // if there is more than just column names, we'll just push the UNION // inside a derived table and then recurse into this method again @@ -121,13 +135,14 @@ func (qb *queryBuilder) addProjection(projection *sqlparser.AliasedExpr) error { } } - return vterrors.VT13001(fmt.Sprintf("unknown select statement type: %T", qb.sel)) + return vterrors.VT13001(fmt.Sprintf("unknown select statement type: %T", qb.stmt)) } func (qb *queryBuilder) pushUnionInsideDerived() { + selStmt := qb.asSelectStatement() dt := &sqlparser.DerivedTable{ Lateral: false, - Select: qb.sel, + Select: selStmt, } sel := &sqlparser.Select{ From: []sqlparser.TableExpr{&sqlparser.AliasedTableExpr{ @@ -135,8 +150,8 @@ func (qb *queryBuilder) pushUnionInsideDerived() { As: sqlparser.NewIdentifierCS("dt"), }}, } - sel.SelectExprs = unionSelects(sqlparser.GetFirstSelect(qb.sel).SelectExprs) - qb.sel = sel + sel.SelectExprs = unionSelects(sqlparser.GetFirstSelect(selStmt).SelectExprs) + qb.stmt = sel } func unionSelects(exprs sqlparser.SelectExprs) (selectExprs sqlparser.SelectExprs) { @@ -172,7 +187,7 @@ func checkUnionColumnByName(column *sqlparser.ColName, sel sqlparser.SelectState } func (qb *queryBuilder) clearProjections() { - sel, isSel := qb.sel.(*sqlparser.Select) + sel, isSel := qb.stmt.(*sqlparser.Select) if !isSel { return } @@ -180,16 +195,16 @@ func (qb *queryBuilder) clearProjections() { } func (qb *queryBuilder) unionWith(other *queryBuilder, distinct bool) { - qb.sel = &sqlparser.Union{ - Left: qb.sel, - Right: other.sel, + qb.stmt = &sqlparser.Union{ + Left: qb.asSelectStatement(), + Right: other.asSelectStatement(), Distinct: distinct, } } func (qb *queryBuilder) joinInnerWith(other *queryBuilder, onCondition sqlparser.Expr) { - sel := qb.sel.(*sqlparser.Select) - otherSel := other.sel.(*sqlparser.Select) + sel := qb.stmt.(*sqlparser.Select) + otherSel := other.stmt.(*sqlparser.Select) sel.From = append(sel.From, otherSel.From...) sel.SelectExprs = append(sel.SelectExprs, otherSel.SelectExprs...) @@ -210,8 +225,8 @@ func (qb *queryBuilder) joinInnerWith(other *queryBuilder, onCondition sqlparser } func (qb *queryBuilder) joinOuterWith(other *queryBuilder, onCondition sqlparser.Expr) { - sel := qb.sel.(*sqlparser.Select) - otherSel := other.sel.(*sqlparser.Select) + sel := qb.stmt.(*sqlparser.Select) + otherSel := other.stmt.(*sqlparser.Select) var lhs sqlparser.TableExpr if len(sel.From) == 1 { lhs = sel.From[0] @@ -258,7 +273,7 @@ func (qb *queryBuilder) sortTables() { } sort.Sort(ts) return true, nil - }, qb.sel) + }, qb.stmt) } @@ -370,7 +385,10 @@ func buildQuery(op ops.Operator, qb *queryBuilder) error { if err != nil { return err } - qb.sel.MakeDistinct() + qb.asSelectStatement().MakeDistinct() + return nil + case *Update: + buildUpdate(op, qb) return nil default: return vterrors.VT13001(fmt.Sprintf("unknown operator to convert to SQL: %T", op)) @@ -378,6 +396,11 @@ func buildQuery(op ops.Operator, qb *queryBuilder) error { return nil } +func buildUpdate(op *Update, qb *queryBuilder) { + qb.stmt = op.AST + qb.dmlOperator = op +} + func buildAggregation(op *Aggregator, qb *queryBuilder) error { err := buildQuery(op.Source, qb) if err != nil { @@ -415,7 +438,7 @@ func buildOrdering(op *Ordering, qb *queryBuilder) error { } for _, order := range op.Order { - qb.sel.AddOrder(order.Inner) + qb.asSelectStatement().AddOrder(order.Inner) } return nil } @@ -425,7 +448,7 @@ func buildLimit(op *Limit, qb *queryBuilder) error { if err != nil { return err } - qb.sel.SetLimit(op.AST) + qb.asSelectStatement().SetLimit(op.AST) return nil } @@ -453,7 +476,7 @@ func buildProjection(op *Projection, qb *queryBuilder) error { return err } - _, isSel := qb.sel.(*sqlparser.Select) + _, isSel := qb.stmt.(*sqlparser.Select) if isSel { qb.clearProjections() @@ -468,10 +491,10 @@ func buildProjection(op *Projection, qb *queryBuilder) error { // if the projection is on derived table, we use the select we have // created above and transform it into a derived table if op.TableID != nil { - sel := qb.sel - qb.sel = nil + sel := qb.stmt + qb.stmt = nil qb.addTableExpr(op.Alias, op.Alias, TableID(op), &sqlparser.DerivedTable{ - Select: sel, + Select: sel.(sqlparser.SelectStatement), }, nil, nil) } @@ -553,8 +576,8 @@ func buildDerived(op *Horizon, qb *queryBuilder) error { } sqlparser.RemoveKeyspace(op.Query) - stmt := qb.sel - qb.sel = nil + stmt := qb.stmt + qb.stmt = nil switch sel := stmt.(type) { case *sqlparser.Select: return buildDerivedSelect(op, qb, sel) @@ -610,7 +633,7 @@ func buildHorizon(op *Horizon, qb *queryBuilder) error { return err } - err = stripDownQuery(op.Query, qb.sel) + err = stripDownQuery(op.Query, qb.asSelectStatement()) if err != nil { return err } @@ -619,7 +642,7 @@ func buildHorizon(op *Horizon, qb *queryBuilder) error { removeKeyspaceFromSelectExpr(aliasedExpr) } return true, nil - }, qb.sel) + }, qb.stmt) return nil } From 583e05b339dc18702c3c48649d4cd7ebd9b11ca4 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 6 Sep 2023 13:31:13 +0200 Subject: [PATCH 2/4] refactor: move delete to sql_builder Signed-off-by: Andres Taylor --- .../planbuilder/operator_transformers.go | 75 +++++++++++-------- .../planbuilder/operators/SQL_builder.go | 7 ++ 2 files changed, 49 insertions(+), 33 deletions(-) diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index 469afc231fa..8fab3bf9ea3 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -404,21 +404,21 @@ func transformRoutePlan(ctx *plancontext.PlanningContext, op *operators.Route) ( switch src := op.Source.(type) { case *operators.Insert: return transformInsertPlan(ctx, op, src) - case *operators.Delete: - return transformDeletePlan(ctx, op, src) } stmt, dmlOp, err := operators.ToSQL(ctx, op.Source) if err != nil { return nil, err } + replaceSubQuery(ctx, stmt) + switch stmt := stmt.(type) { case *sqlparser.Update: - replaceSubQuery(ctx, stmt) return buildUpdateLogicalPlan(ctx, op, dmlOp, stmt) case sqlparser.SelectStatement: - replaceSubQuery(ctx, stmt) return buildRouteLogicalPlan(ctx, op, stmt) + case *sqlparser.Delete: + return buildDeleteLogicalPlan(ctx, op, dmlOp, stmt) default: panic(fmt.Sprintf("dont know how to %T", stmt)) } @@ -448,31 +448,6 @@ func buildRouteLogicalPlan(ctx *plancontext.PlanningContext, op *operators.Route }, nil } -func buildUpdateLogicalPlan(ctx *plancontext.PlanningContext, op *operators.Route, dmlOp ops.Operator, stmt sqlparser.Statement) (logicalPlan, error) { - upd := dmlOp.(*operators.Update) - rp := newRoutingParams(ctx, op.Routing.OpCode()) - err := op.Routing.UpdateRoutingParams(ctx, rp) - if err != nil { - return nil, err - } - edml := &engine.DML{ - Query: generateQuery(stmt), - TableNames: []string{upd.VTable.Name.String()}, - Vindexes: upd.VTable.ColumnVindexes, - OwnedVindexQuery: upd.OwnedVindexQuery, - RoutingParameters: rp, - } - - transformDMLPlan(upd.VTable, edml, op.Routing, len(upd.ChangedVindexValues) > 0) - - e := &engine.Update{ - ChangedVindexValues: upd.ChangedVindexValues, - DML: edml, - } - - return &primitiveWrapper{prim: e}, nil -} - func transformInsertPlan(ctx *plancontext.PlanningContext, op *operators.Route, ins *operators.Insert) (i *insert, err error) { eins := &engine.Insert{ Opcode: mapToInsertOpCode(op.Routing.OpCode(), ins.Input != nil), @@ -576,14 +551,48 @@ func dmlFormatter(buf *sqlparser.TrackedBuffer, node sqlparser.SQLNode) { node.Format(buf) } -func transformDeletePlan(ctx *plancontext.PlanningContext, op *operators.Route, del *operators.Delete) (logicalPlan, error) { - ast := del.AST - replaceSubQuery(ctx, ast) +func buildUpdateLogicalPlan( + ctx *plancontext.PlanningContext, + op *operators.Route, + dmlOp ops.Operator, + stmt *sqlparser.Update, +) (logicalPlan, error) { + upd := dmlOp.(*operators.Update) rp := newRoutingParams(ctx, op.Routing.OpCode()) err := op.Routing.UpdateRoutingParams(ctx, rp) if err != nil { return nil, err } + edml := &engine.DML{ + Query: generateQuery(stmt), + TableNames: []string{upd.VTable.Name.String()}, + Vindexes: upd.VTable.ColumnVindexes, + OwnedVindexQuery: upd.OwnedVindexQuery, + RoutingParameters: rp, + } + + transformDMLPlan(upd.VTable, edml, op.Routing, len(upd.ChangedVindexValues) > 0) + + e := &engine.Update{ + ChangedVindexValues: upd.ChangedVindexValues, + DML: edml, + } + + return &primitiveWrapper{prim: e}, nil +} + +func buildDeleteLogicalPlan( + ctx *plancontext.PlanningContext, + rb *operators.Route, + dmlOp ops.Operator, + ast *sqlparser.Delete, +) (logicalPlan, error) { + del := dmlOp.(*operators.Delete) + rp := newRoutingParams(ctx, rb.Routing.OpCode()) + err := rb.Routing.UpdateRoutingParams(ctx, rp) + if err != nil { + return nil, err + } edml := &engine.DML{ Query: generateQuery(ast), TableNames: []string{del.VTable.Name.String()}, @@ -592,7 +601,7 @@ func transformDeletePlan(ctx *plancontext.PlanningContext, op *operators.Route, RoutingParameters: rp, } - transformDMLPlan(del.VTable, edml, op.Routing, del.OwnedVindexQuery != "") + transformDMLPlan(del.VTable, edml, rb.Routing, del.OwnedVindexQuery != "") e := &engine.Delete{ DML: edml, diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index 65109c38be2..a3db9f0326d 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -390,12 +390,19 @@ func buildQuery(op ops.Operator, qb *queryBuilder) error { case *Update: buildUpdate(op, qb) return nil + case *Delete: + buildDelete(op, qb) default: return vterrors.VT13001(fmt.Sprintf("unknown operator to convert to SQL: %T", op)) } return nil } +func buildDelete(op *Delete, qb *queryBuilder) { + qb.stmt = op.AST + qb.dmlOperator = op +} + func buildUpdate(op *Update, qb *queryBuilder) { qb.stmt = op.AST qb.dmlOperator = op From b5ec7e6583fb978e413900453054afe80ca51c2b Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 6 Sep 2023 13:48:07 +0200 Subject: [PATCH 3/4] refactor: move INSERT to sql_builder Signed-off-by: Andres Taylor --- .../planbuilder/operator_transformers.go | 28 +++++++++---------- .../planbuilder/operators/SQL_builder.go | 22 +++++++-------- go/vt/vtgate/planbuilder/operators/delete.go | 4 +++ go/vt/vtgate/planbuilder/operators/insert.go | 4 +++ go/vt/vtgate/planbuilder/operators/update.go | 4 +++ 5 files changed, 37 insertions(+), 25 deletions(-) diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index 8fab3bf9ea3..74f374dac97 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -401,10 +401,6 @@ func newRoutingParams(ctx *plancontext.PlanningContext, opCode engine.Opcode) *e } func transformRoutePlan(ctx *plancontext.PlanningContext, op *operators.Route) (logicalPlan, error) { - switch src := op.Source.(type) { - case *operators.Insert: - return transformInsertPlan(ctx, op, src) - } stmt, dmlOp, err := operators.ToSQL(ctx, op.Source) if err != nil { return nil, err @@ -413,12 +409,14 @@ func transformRoutePlan(ctx *plancontext.PlanningContext, op *operators.Route) ( replaceSubQuery(ctx, stmt) switch stmt := stmt.(type) { - case *sqlparser.Update: - return buildUpdateLogicalPlan(ctx, op, dmlOp, stmt) case sqlparser.SelectStatement: return buildRouteLogicalPlan(ctx, op, stmt) + case *sqlparser.Update: + return buildUpdateLogicalPlan(ctx, op, dmlOp, stmt) case *sqlparser.Delete: return buildDeleteLogicalPlan(ctx, op, dmlOp, stmt) + case *sqlparser.Insert: + return buildInsertLogicalPlan(ctx, op, dmlOp, stmt) default: panic(fmt.Sprintf("dont know how to %T", stmt)) } @@ -448,10 +446,11 @@ func buildRouteLogicalPlan(ctx *plancontext.PlanningContext, op *operators.Route }, nil } -func transformInsertPlan(ctx *plancontext.PlanningContext, op *operators.Route, ins *operators.Insert) (i *insert, err error) { +func buildInsertLogicalPlan(ctx *plancontext.PlanningContext, rb *operators.Route, op ops.Operator, stmt *sqlparser.Insert) (logicalPlan, error) { + ins := op.(*operators.Insert) eins := &engine.Insert{ - Opcode: mapToInsertOpCode(op.Routing.OpCode(), ins.Input != nil), - Keyspace: op.Routing.Keyspace(), + Opcode: mapToInsertOpCode(rb.Routing.OpCode(), ins.Input != nil), + Keyspace: rb.Routing.Keyspace(), TableName: ins.VTable.Name.String(), Ignore: ins.Ignore, ForceNonStreaming: ins.ForceNonStreaming, @@ -460,7 +459,7 @@ func transformInsertPlan(ctx *plancontext.PlanningContext, op *operators.Route, VindexValues: ins.VindexValues, VindexValueOffset: ins.VindexValueOffset, } - i = &insert{eInsert: eins} + lp := &insert{eInsert: eins} // we would need to generate the query on the fly. The only exception here is // when unsharded query with autoincrement for that there is no input operator. @@ -469,15 +468,16 @@ func transformInsertPlan(ctx *plancontext.PlanningContext, op *operators.Route, } if ins.Input == nil { - eins.Query = generateQuery(ins.AST) + eins.Query = generateQuery(stmt) } else { - i.source, err = transformToLogicalPlan(ctx, ins.Input) + newSrc, err := transformToLogicalPlan(ctx, ins.Input) if err != nil { - return + return nil, err } + lp.source = newSrc } - return + return lp, nil } func mapToInsertOpCode(code engine.Opcode, insertSelect bool) engine.InsertOpcode { diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index a3db9f0326d..15e1833703c 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -386,25 +386,25 @@ func buildQuery(op ops.Operator, qb *queryBuilder) error { return err } qb.asSelectStatement().MakeDistinct() - return nil case *Update: - buildUpdate(op, qb) - return nil + buildDML(op, qb) case *Delete: - buildDelete(op, qb) + buildDML(op, qb) + case *Insert: + buildDML(op, qb) default: return vterrors.VT13001(fmt.Sprintf("unknown operator to convert to SQL: %T", op)) } return nil } -func buildDelete(op *Delete, qb *queryBuilder) { - qb.stmt = op.AST - qb.dmlOperator = op +type OpWithAST interface { + ops.Operator + Statement() sqlparser.Statement } -func buildUpdate(op *Update, qb *queryBuilder) { - qb.stmt = op.AST +func buildDML(op OpWithAST, qb *queryBuilder) { + qb.stmt = op.Statement() qb.dmlOperator = op } @@ -498,10 +498,10 @@ func buildProjection(op *Projection, qb *queryBuilder) error { // if the projection is on derived table, we use the select we have // created above and transform it into a derived table if op.TableID != nil { - sel := qb.stmt + sel := qb.asSelectStatement() qb.stmt = nil qb.addTableExpr(op.Alias, op.Alias, TableID(op), &sqlparser.DerivedTable{ - Select: sel.(sqlparser.SelectStatement), + Select: sel, }, nil, nil) } diff --git a/go/vt/vtgate/planbuilder/operators/delete.go b/go/vt/vtgate/planbuilder/operators/delete.go index c24ab9f5065..01b3ab11520 100644 --- a/go/vt/vtgate/planbuilder/operators/delete.go +++ b/go/vt/vtgate/planbuilder/operators/delete.go @@ -65,3 +65,7 @@ func (d *Delete) GetOrdering() ([]ops.OrderBy, error) { func (d *Delete) ShortDescription() string { return fmt.Sprintf("%s.%s %s", d.VTable.Keyspace.Name, d.VTable.Name.String(), sqlparser.String(d.AST.Where)) } + +func (d *Delete) Statement() sqlparser.Statement { + return d.AST +} diff --git a/go/vt/vtgate/planbuilder/operators/insert.go b/go/vt/vtgate/planbuilder/operators/insert.go index 3fc70ed8998..78ae6cc133e 100644 --- a/go/vt/vtgate/planbuilder/operators/insert.go +++ b/go/vt/vtgate/planbuilder/operators/insert.go @@ -117,3 +117,7 @@ func (i *Insert) Clone(inputs []ops.Operator) ops.Operator { func (i *Insert) TablesUsed() []string { return SingleQualifiedIdentifier(i.VTable.Keyspace, i.VTable.Name) } + +func (i *Insert) Statement() sqlparser.Statement { + return i.AST +} diff --git a/go/vt/vtgate/planbuilder/operators/update.go b/go/vt/vtgate/planbuilder/operators/update.go index 0627f07734e..f523643a84e 100644 --- a/go/vt/vtgate/planbuilder/operators/update.go +++ b/go/vt/vtgate/planbuilder/operators/update.go @@ -68,3 +68,7 @@ func (u *Update) TablesUsed() []string { func (u *Update) ShortDescription() string { return u.VTable.String() } + +func (u *Update) Statement() sqlparser.Statement { + return u.AST +} From ca307e8a860712a2cf4ec63fa4abdf30379577b7 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Fri, 8 Sep 2023 07:47:43 +0200 Subject: [PATCH 4/4] use error instead of panic Signed-off-by: Andres Taylor --- go/vt/vtgate/planbuilder/operator_transformers.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index 74f374dac97..e17548c68b1 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -418,7 +418,7 @@ func transformRoutePlan(ctx *plancontext.PlanningContext, op *operators.Route) ( case *sqlparser.Insert: return buildInsertLogicalPlan(ctx, op, dmlOp, stmt) default: - panic(fmt.Sprintf("dont know how to %T", stmt)) + return nil, vterrors.VT13001(fmt.Sprintf("dont know how to %T", stmt)) } }