Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 72 additions & 59 deletions go/vt/vttablet/tabletmanager/vreplication/controller_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,24 @@ package vreplication

import (
"fmt"
"strconv"

"vitess.io/vitess/go/vt/sqlparser"
)

// controllerPlan is the plan for vreplication control statements.
type controllerPlan struct {
opcode int
query string
// delCopySate is set for deletes.
delCopyState string
id int
opcode int

// numInserts is set for insertQuery.
numInserts int

// selector and applier are set for updateQuery and deleteQuery.
selector string
applier *sqlparser.ParsedQuery

// delCopyState is set of deletes.
delCopyState *sqlparser.ParsedQuery
}

const (
Expand All @@ -46,26 +52,31 @@ func buildControllerPlan(query string) (*controllerPlan, error) {
if err != nil {
return nil, err
}
var plan *controllerPlan
switch stmt := stmt.(type) {
case *sqlparser.Insert:
return buildInsertPlan(stmt)
plan, err = buildInsertPlan(stmt)
case *sqlparser.Update:
return buildUpdatePlan(stmt)
plan, err = buildUpdatePlan(stmt)
case *sqlparser.Delete:
return buildDeletePlan(stmt)
plan, err = buildDeletePlan(stmt)
case *sqlparser.Select:
return buildSelectPlan(stmt)
plan, err = buildSelectPlan(stmt)
default:
return nil, fmt.Errorf("unsupported construct: %s", sqlparser.String(stmt))
}
if err != nil {
return nil, err
}
plan.query = query
return plan, nil
}

func buildInsertPlan(ins *sqlparser.Insert) (*controllerPlan, error) {
switch sqlparser.String(ins.Table) {
case reshardingJournalTableName:
return &controllerPlan{
opcode: reshardingJournalQuery,
query: sqlparser.String(ins),
}, nil
case vreplicationTableName:
// no-op
Expand All @@ -88,15 +99,8 @@ func buildInsertPlan(ins *sqlparser.Insert) (*controllerPlan, error) {
if !ok {
return nil, fmt.Errorf("unsupported construct: %v", sqlparser.String(ins))
}
if len(rows) != 1 {
return nil, fmt.Errorf("unsupported construct: %v", sqlparser.String(ins))
}
row := rows[0]
idPos := 0
if len(ins.Columns) != 0 {
if len(ins.Columns) != len(row) {
return nil, fmt.Errorf("malformed statement: %v", sqlparser.String(ins))
}
idPos = -1
for i, col := range ins.Columns {
if col.EqualString("id") {
Expand All @@ -106,13 +110,18 @@ func buildInsertPlan(ins *sqlparser.Insert) (*controllerPlan, error) {
}
}
if idPos >= 0 {
if _, ok := row[idPos].(*sqlparser.NullVal); !ok {
return nil, fmt.Errorf("id should not have a value: %v", sqlparser.String(ins))
for _, row := range rows {
if idPos >= len(row) {
return nil, fmt.Errorf("malformed statement: %v", sqlparser.String(ins))
}
if _, ok := row[idPos].(*sqlparser.NullVal); !ok {
return nil, fmt.Errorf("id should not have a value: %v", sqlparser.String(ins))
}
}
}
return &controllerPlan{
opcode: insertQuery,
query: sqlparser.String(ins),
opcode: insertQuery,
numInserts: len(rows),
}, nil
}

Expand All @@ -121,7 +130,6 @@ func buildUpdatePlan(upd *sqlparser.Update) (*controllerPlan, error) {
case reshardingJournalTableName:
return &controllerPlan{
opcode: reshardingJournalQuery,
query: sqlparser.String(upd),
}, nil
case vreplicationTableName:
// no-op
Expand All @@ -137,15 +145,24 @@ func buildUpdatePlan(upd *sqlparser.Update) (*controllerPlan, error) {
}
}

id, err := extractID(upd.Where)
if err != nil {
return nil, err
buf1 := sqlparser.NewTrackedBuffer(nil)
buf1.Myprintf("select id from %s%v", vreplicationTableName, upd.Where)
upd.Where = &sqlparser.Where{
Type: sqlparser.WhereStr,
Expr: &sqlparser.ComparisonExpr{
Left: &sqlparser.ColName{Name: sqlparser.NewColIdent("id")},
Operator: sqlparser.InStr,
Right: sqlparser.ListArg("::ids"),
},
}

buf2 := sqlparser.NewTrackedBuffer(nil)
buf2.Myprintf("%v", upd)

return &controllerPlan{
opcode: updateQuery,
query: sqlparser.String(upd),
id: id,
opcode: updateQuery,
selector: buf1.String(),
applier: buf2.ParsedQuery(),
}, nil
}

Expand All @@ -154,7 +171,6 @@ func buildDeletePlan(del *sqlparser.Delete) (*controllerPlan, error) {
case reshardingJournalTableName:
return &controllerPlan{
opcode: reshardingJournalQuery,
query: sqlparser.String(del),
}, nil
case vreplicationTableName:
// no-op
Expand All @@ -171,49 +187,46 @@ func buildDeletePlan(del *sqlparser.Delete) (*controllerPlan, error) {
return nil, fmt.Errorf("unsupported construct: %v", sqlparser.String(del))
}

id, err := extractID(del.Where)
if err != nil {
return nil, err
buf1 := sqlparser.NewTrackedBuffer(nil)
buf1.Myprintf("select id from %s%v", vreplicationTableName, del.Where)
del.Where = &sqlparser.Where{
Type: sqlparser.WhereStr,
Expr: &sqlparser.ComparisonExpr{
Left: &sqlparser.ColName{Name: sqlparser.NewColIdent("id")},
Operator: sqlparser.InStr,
Right: sqlparser.ListArg("::ids"),
},
}

buf2 := sqlparser.NewTrackedBuffer(nil)
buf2.Myprintf("%v", del)

copyStateWhere := &sqlparser.Where{
Type: sqlparser.WhereStr,
Expr: &sqlparser.ComparisonExpr{
Left: &sqlparser.ColName{Name: sqlparser.NewColIdent("vrepl_id")},
Operator: sqlparser.InStr,
Right: sqlparser.ListArg("::ids"),
},
}
buf3 := sqlparser.NewTrackedBuffer(nil)
buf3.Myprintf("delete from %s%v", copyStateTableName, copyStateWhere)

return &controllerPlan{
opcode: deleteQuery,
query: sqlparser.String(del),
delCopyState: fmt.Sprintf("delete from %s where vrepl_id = %d", copySateTableName, id),
id: id,
selector: buf1.String(),
applier: buf2.ParsedQuery(),
delCopyState: buf3.ParsedQuery(),
}, nil
}

func buildSelectPlan(sel *sqlparser.Select) (*controllerPlan, error) {
switch sqlparser.String(sel.From) {
case vreplicationTableName, reshardingJournalTableName, copySateTableName:
case vreplicationTableName, reshardingJournalTableName, copyStateTableName:
return &controllerPlan{
opcode: selectQuery,
query: sqlparser.String(sel),
}, nil
default:
return nil, fmt.Errorf("invalid table name: %v", sqlparser.String(sel.From))
}
}

func extractID(where *sqlparser.Where) (int, error) {
if where == nil {
return 0, fmt.Errorf("invalid where clause:%v", sqlparser.String(where))
}
comp, ok := where.Expr.(*sqlparser.ComparisonExpr)
if !ok {
return 0, fmt.Errorf("invalid where clause:%v", sqlparser.String(where))
}
if sqlparser.String(comp.Left) != "id" {
return 0, fmt.Errorf("invalid where clause:%v", sqlparser.String(where))
}
if comp.Operator != sqlparser.EqualStr {
return 0, fmt.Errorf("invalid where clause:%v", sqlparser.String(where))
}

id, err := strconv.Atoi(sqlparser.String(comp.Right))
if err != nil {
return 0, fmt.Errorf("invalid where clause:%v", sqlparser.String(where))
}
return id, nil
}
Loading