Skip to content
36 changes: 26 additions & 10 deletions go/vt/vtgate/planbuilder/operators/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ func buildChildUpdOpForSetNull(
updateExprs := ctx.SemTable.GetUpdateExpressionsForFk(fk.String(updatedTable))
compExpr := nullSafeNotInComparison(ctx,
updatedTable,
updateExprs, fk, updatedTable.GetTableName(), nonLiteralUpdateInfo, false /* appendQualifier */)
updateExprs, fk, updatedTable.GetTableName(), fk.Table.GetTableName(), nonLiteralUpdateInfo, false /* appendQualifier */)
if compExpr != nil {
childWhereExpr = &sqlparser.AndExpr{
Left: childWhereExpr,
Expand Down Expand Up @@ -812,12 +812,20 @@ func createFKVerifyOp(
// and Child.c2 is not null and not ((Child.c1) <=> (Child.c2 + 1))
// limit 1
func createFkVerifyOpForParentFKForUpdate(ctx *plancontext.PlanningContext, updatedTable *vindexes.BaseTable, updStmt *sqlparser.Update, pFK vindexes.ParentFKInfo) Operator {
childTblExpr := updStmt.TableExprs[0].(*sqlparser.AliasedTableExpr)
// Alias the foreign key's parent table name
parentTblExpr := sqlparser.NewAliasedTableExpr(pFK.Table.GetTableName(), "parent")
parentTbl, err := parentTblExpr.TableName()
if err != nil {
panic(err)
}

// Alias the foreign key's child table name
childTblExpr := sqlparser.NewAliasedTableExpr(updatedTable.GetTableName(), "child")
childTbl, err := childTblExpr.TableName()
if err != nil {
panic(err)
}
parentTbl := pFK.Table.GetTableName()

var whereCond sqlparser.Expr
var joinCond sqlparser.Expr
var notEqualColNames sqlparser.ValTuple
Expand Down Expand Up @@ -894,7 +902,7 @@ func createFkVerifyOpForParentFKForUpdate(ctx *plancontext.PlanningContext, upda
sqlparser.NewJoinTableExpr(
childTblExpr,
sqlparser.LeftJoinType,
sqlparser.NewAliasedTableExpr(parentTbl, ""),
parentTblExpr,
sqlparser.NewJoinCondition(joinCond, nil)),
},
sqlparser.NewWhere(sqlparser.WhereClause, whereCond),
Expand Down Expand Up @@ -931,12 +939,20 @@ func createFkVerifyOpForChildFKForUpdate(ctx *plancontext.PlanningContext, updat
if !ctx.VerifyAllFKs {
panic(vterrors.VT12002(updatedTable.String(), cFk.Table.String()))
}
parentTblExpr := updStmt.TableExprs[0].(*sqlparser.AliasedTableExpr)

parentTblExpr := sqlparser.NewAliasedTableExpr(updatedTable.GetTableName(), "parent")
parentTbl, err := parentTblExpr.TableName()
if err != nil {
panic(err)
}
childTbl := cFk.Table.GetTableName()

// Alias the foreign key's child table name
childTblExpr := sqlparser.NewAliasedTableExpr(cFk.Table.GetTableName(), "child")
childTbl, err := childTblExpr.TableName()
if err != nil {
panic(err)
}

var joinCond sqlparser.Expr
for idx := range cFk.ParentColumns {
joinExpr := &sqlparser.ComparisonExpr{
Expand Down Expand Up @@ -967,7 +983,7 @@ func createFkVerifyOpForChildFKForUpdate(ctx *plancontext.PlanningContext, updat
// For example, if we are setting `update child cola = :v1 and colb = :v2`, then on the parent, the where condition would look something like this -
// `:v1 IS NULL OR :v2 IS NULL OR (cola, colb) NOT IN ((:v1,:v2))`
// So, if either of :v1 or :v2 is NULL, then the entire condition is true (which is the same as not having the condition when :v1 or :v2 is NULL).
compExpr := nullSafeNotInComparison(ctx, updatedTable, updStmt.Exprs, cFk, parentTbl, nil /* nonLiteralUpdateInfo */, true /* appendQualifier */)
compExpr := nullSafeNotInComparison(ctx, updatedTable, updStmt.Exprs, cFk, parentTbl, childTbl, nil /* nonLiteralUpdateInfo */, true /* appendQualifier */)
if compExpr != nil {
whereCond = sqlparser.AndExpressions(whereCond, compExpr)
}
Expand All @@ -978,7 +994,7 @@ func createFkVerifyOpForChildFKForUpdate(ctx *plancontext.PlanningContext, updat
sqlparser.NewJoinTableExpr(
parentTblExpr,
sqlparser.NormalJoinType,
sqlparser.NewAliasedTableExpr(childTbl, ""),
childTblExpr,
sqlparser.NewJoinCondition(joinCond, nil)),
},
sqlparser.NewWhere(sqlparser.WhereClause, whereCond),
Expand All @@ -992,7 +1008,7 @@ func createFkVerifyOpForChildFKForUpdate(ctx *plancontext.PlanningContext, updat
// `:v1 IS NULL OR :v2 IS NULL OR (cola, colb) NOT IN ((:v1,:v2))`
// So, if either of :v1 or :v2 is NULL, then the entire condition is true (which is the same as not having the condition when :v1 or :v2 is NULL)
// This expression is used in cascading SET NULLs and in verifying whether an update should be restricted.
func nullSafeNotInComparison(ctx *plancontext.PlanningContext, updatedTable *vindexes.BaseTable, updateExprs sqlparser.UpdateExprs, cFk vindexes.ChildFKInfo, parentTbl sqlparser.TableName, nonLiteralUpdateInfo []engine.NonLiteralUpdateInfo, appendQualifier bool) sqlparser.Expr {
func nullSafeNotInComparison(ctx *plancontext.PlanningContext, updatedTable *vindexes.BaseTable, updateExprs sqlparser.UpdateExprs, cFk vindexes.ChildFKInfo, parentTbl, childTbl sqlparser.TableName, nonLiteralUpdateInfo []engine.NonLiteralUpdateInfo, appendQualifier bool) sqlparser.Expr {
var valTuple sqlparser.ValTuple
var updateValues sqlparser.ValTuple
for idx, updateExpr := range updateExprs {
Expand All @@ -1007,7 +1023,7 @@ func nullSafeNotInComparison(ctx *plancontext.PlanningContext, updatedTable *vin
}
updateValues = append(updateValues, childUpdateExpr)
if appendQualifier {
valTuple = append(valTuple, sqlparser.NewColNameWithQualifier(cFk.ChildColumns[colIdx].String(), cFk.Table.GetTableName()))
valTuple = append(valTuple, sqlparser.NewColNameWithQualifier(cFk.ChildColumns[colIdx].String(), childTbl))
} else {
valTuple = append(valTuple, sqlparser.NewColName(cFk.ChildColumns[colIdx].String()))
}
Expand Down
26 changes: 15 additions & 11 deletions go/vt/vtgate/planbuilder/plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,16 +200,17 @@ func (s *planTestSuite) setFks(vschema *vindexes.VSchema) {
"multicol_tbl1", "multicol_tbl2", "tbl_auth", "tblrefDef", "tbl20"})
}
if vschema.Keyspaces["unsharded_fk_allow"] != nil {
// u_tbl2(col2) -> u_tbl1(col1) Cascade.
// u_tbl4(col41) -> u_tbl1(col14) Restrict.
// u_tbl9(col9) -> u_tbl1(col1) Cascade Null.
// u_tbl3(col2) -> u_tbl2(col2) Cascade Null.
// u_tbl4(col4) -> u_tbl3(col3) Restrict.
// u_tbl6(col6) -> u_tbl5(col5) Restrict.
// u_tbl8(col8) -> u_tbl9(col9) Null Null.
// u_tbl8(col8) -> u_tbl6(col6) Cascade Null.
// u_tbl4(col4) -> u_tbl7(col7) Cascade Cascade.
// u_tbl9(col9) -> u_tbl4(col4) Restrict Restrict.
// u_tbl2(col2) -> u_tbl1(col1) Cascade.
// u_tbl4(col41) -> u_tbl1(col14) Restrict.
// u_tbl9(col9) -> u_tbl1(col1) Cascade Null.
// u_tbl3(col2) -> u_tbl2(col2) Cascade Null.
// u_tbl4(col4) -> u_tbl3(col3) Restrict.
// u_tbl6(col6) -> u_tbl5(col5) Restrict.
// u_tbl8(col8) -> u_tbl9(col9) Null Null.
// u_tbl8(col8) -> u_tbl6(col6) Cascade Null.
// u_tbl4(col4) -> u_tbl7(col7) Cascade Cascade.
// u_tbl9(col9) -> u_tbl4(col4) Restrict Restrict.
// u_tbl12(parent_id) -> u_tbl12(id) Restrict Restrict.
// u_multicol_tbl2(cola, colb) -> u_multicol_tbl1(cola, colb) Null Null.
// u_multicol_tbl3(cola, colb) -> u_multicol_tbl2(cola, colb) Cascade Cascade.

Expand All @@ -236,7 +237,10 @@ func (s *planTestSuite) setFks(vschema *vindexes.VSchema) {
_ = vschema.AddUniqueKey("unsharded_fk_allow", "u_tbl9", []sqlparser.Expr{sqlparser.NewColName("bar"), sqlparser.NewColName("col9")})
_ = vschema.AddUniqueKey("unsharded_fk_allow", "u_tbl8", []sqlparser.Expr{sqlparser.NewColName("col8")})

s.addPKs(vschema, "unsharded_fk_allow", []string{"u_tbl1", "u_tbl2", "u_tbl3", "u_tbl4", "u_tbl5", "u_tbl6", "u_tbl7", "u_tbl8", "u_tbl9", "u_tbl10", "u_tbl11",
// FK from u_tbl12 that is self-referential.
_ = vschema.AddForeignKey("unsharded_fk_allow", "u_tbl12", createFkDefinition([]string{"parent_id"}, "u_tbl12", []string{"id"}, sqlparser.Restrict, sqlparser.Restrict))

s.addPKs(vschema, "unsharded_fk_allow", []string{"u_tbl1", "u_tbl2", "u_tbl3", "u_tbl4", "u_tbl5", "u_tbl6", "u_tbl7", "u_tbl8", "u_tbl9", "u_tbl10", "u_tbl11", "u_tbl12",
"u_multicol_tbl1", "u_multicol_tbl2", "u_multicol_tbl3"})
}
}
Expand Down
Loading
Loading