diff --git a/enginetest/queries/update_queries.go b/enginetest/queries/update_queries.go index 9e0ea6b305..a53e046549 100644 --- a/enginetest/queries/update_queries.go +++ b/enginetest/queries/update_queries.go @@ -482,10 +482,8 @@ var UpdateScriptTests = []ScriptTest{ }, Assertions: []ScriptTestAssertion{ { - // TODO: Foreign key constraints are not honored for UDPATE ... JOIN statements - Skip: true, Query: "UPDATE orders o JOIN customers c ON o.customer_id = c.id SET o.customer_id = 123 where o.customer_id != 1;", - ExpectedErr: sql.ErrCheckConstraintViolated, + ExpectedErr: sql.ErrForeignKeyChildViolation, }, { Query: "SELECT * FROM orders;", @@ -510,16 +508,12 @@ var UpdateScriptTests = []ScriptTest{ }, Assertions: []ScriptTestAssertion{ { - // TODO: Foreign key constraints are not honored for UDPATE ... JOIN statements - Skip: true, Query: `UPDATE child1 c1 JOIN child2 c2 ON c1.id = 10 AND c2.id = 20 SET c1.p1_id = 999, c2.p2_id = 3;`, ExpectedErr: sql.ErrForeignKeyChildViolation, }, { - // TODO: Foreign key constraints are not honored for UDPATE ... JOIN statements - Skip: true, Query: `UPDATE child1 c1 JOIN child2 c2 ON c1.id = 10 AND c2.id = 20 SET c1.p1_id = 3, c2.p2_id = 999;`, diff --git a/sql/analyzer/apply_foreign_keys.go b/sql/analyzer/apply_foreign_keys.go index 166888c8f1..2cdf1b9bcb 100644 --- a/sql/analyzer/apply_foreign_keys.go +++ b/sql/analyzer/apply_foreign_keys.go @@ -122,32 +122,35 @@ func applyForeignKeysToNodes(ctx *sql.Context, a *Analyzer, n sql.Node, cache *f if plan.IsEmptyTable(n.Child) { return n, transform.SameTree, nil } - // TODO: UPDATE JOIN can update multiple tables. Because updatableJoinTable does not implement - // sql.ForeignKeyTable, we do not currenly support FK checks for UPDATE JOIN statements. - updateDest, err := plan.GetUpdatable(n.Child) - if err != nil { - return nil, transform.SameTree, err - } - fkTbl, ok := updateDest.(sql.ForeignKeyTable) - // If foreign keys aren't supported then we return - if !ok { - return n, transform.SameTree, nil + if n.IsJoin { + uj := n.Child.(*plan.UpdateJoin) + updateTargets := uj.UpdateTargets + fkHandlerMap := make(map[string]sql.Node, len(updateTargets)) + for tableName, updateTarget := range updateTargets { + fkHandlerMap[tableName] = updateTarget + fkHandler, err := + getForeignKeyHandlerFromUpdateTarget(ctx, a, updateTarget, cache, fkChain) + if err != nil { + return nil, transform.SameTree, err + } + if fkHandler == nil { + fkHandlerMap[tableName] = updateTarget + } else { + fkHandlerMap[tableName] = fkHandler + } + } + uj = plan.NewUpdateJoin(fkHandlerMap, uj.Child) + nn, err := n.WithChildren(uj) + return nn, transform.NewTree, err } - - fkEditor, err := getForeignKeyEditor(ctx, a, fkTbl, cache, fkChain, false) + fkHandler, err := getForeignKeyHandlerFromUpdateTarget(ctx, a, n.Child, cache, fkChain) if err != nil { return nil, transform.SameTree, err } - if fkEditor == nil { + if fkHandler == nil { return n, transform.SameTree, nil } - nn, err := n.WithChildren(&plan.ForeignKeyHandler{ - Table: fkTbl, - Sch: updateDest.Schema(), - OriginalNode: n.Child, - Editor: fkEditor, - AllUpdaters: fkChain.GetUpdaters(), - }) + nn, err := n.WithChildren(fkHandler) return nn, transform.NewTree, err case *plan.DeleteFrom: if plan.IsEmptyTable(n.Child) { @@ -445,6 +448,36 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa return fkEditor, nil } +// getForeignKeyHandlerFromUpdateTarget creates a ForeignKeyHandler from a given update target Node. It is used for +// applying foreign key constraints to Update nodes +func getForeignKeyHandlerFromUpdateTarget(ctx *sql.Context, a *Analyzer, updateTarget sql.Node, + cache *foreignKeyCache, fkChain foreignKeyChain) (*plan.ForeignKeyHandler, error) { + updateDest, err := plan.GetUpdatable(updateTarget) + if err != nil { + return nil, err + } + fkTbl, ok := updateDest.(sql.ForeignKeyTable) + if !ok { + return nil, nil + } + + fkEditor, err := getForeignKeyEditor(ctx, a, fkTbl, cache, fkChain, false) + if err != nil { + return nil, err + } + if fkEditor == nil { + return nil, nil + } + + return &plan.ForeignKeyHandler{ + Table: fkTbl, + Sch: updateDest.Schema(), + OriginalNode: updateTarget, + Editor: fkEditor, + AllUpdaters: fkChain.GetUpdaters(), + }, nil +} + // resolveSchemaDefaults resolves the default values for the schema of |table|. This is primarily needed for column // default value expressions, since those don't get resolved during the planbuilder phase and assignExecIndexes // doesn't traverse through the ForeignKeyEditors and referential actions to find all of them. In addition to resolving diff --git a/sql/analyzer/assign_update_join.go b/sql/analyzer/assign_update_join.go index 982fee825e..a8d842220f 100644 --- a/sql/analyzer/assign_update_join.go +++ b/sql/analyzer/assign_update_join.go @@ -34,12 +34,13 @@ func modifyUpdateExprsForJoin(ctx *sql.Context, a *Analyzer, n sql.Node, scope * return n, transform.SameTree, nil } - updaters, err := rowUpdatersByTable(ctx, us, jn) + n.IsJoin = true + updateTargets, err := getUpdateTargetsByTable(us, jn) if err != nil { return nil, transform.SameTree, err } - uj := plan.NewUpdateJoin(updaters, us) + uj := plan.NewUpdateJoin(updateTargets, us) ret, err := n.WithChildren(uj) if err != nil { return nil, transform.SameTree, err @@ -51,12 +52,12 @@ func modifyUpdateExprsForJoin(ctx *sql.Context, a *Analyzer, n sql.Node, scope * return n, transform.SameTree, nil } -// rowUpdatersByTable maps a set of tables to their RowUpdater objects. -func rowUpdatersByTable(ctx *sql.Context, node sql.Node, ij sql.Node) (map[string]sql.RowUpdater, error) { +// getUpdateTargetsByTable maps a set of table names and aliases to their corresponding update target Node +func getUpdateTargetsByTable(node sql.Node, ij sql.Node) (map[string]sql.Node, error) { namesOfTableToBeUpdated := getTablesToBeUpdated(node) resolvedTables := getTablesByName(ij) - rowUpdatersByTable := make(map[string]sql.RowUpdater) + updateTargets := make(map[string]sql.Node) for tableToBeUpdated, _ := range namesOfTableToBeUpdated { resolvedTable, ok := resolvedTables[tableToBeUpdated] if !ok { @@ -76,10 +77,10 @@ func rowUpdatersByTable(ctx *sql.Context, node sql.Node, ij sql.Node) (map[strin return nil, sql.ErrUnsupportedFeature.New("error: keyless tables unsupported for UPDATE JOIN") } - rowUpdatersByTable[tableToBeUpdated] = updatable.Updater(ctx) + updateTargets[tableToBeUpdated] = resolvedTable } - return rowUpdatersByTable, nil + return updateTargets, nil } // getTablesToBeUpdated takes a node and looks for the tables to modified by a SetField. diff --git a/sql/plan/update_join.go b/sql/plan/update_join.go index d8da167fa8..da2ec1ca03 100644 --- a/sql/plan/update_join.go +++ b/sql/plan/update_join.go @@ -21,15 +21,15 @@ import ( ) type UpdateJoin struct { - Updaters map[string]sql.RowUpdater + UpdateTargets map[string]sql.Node UnaryNode } -// NewUpdateJoin returns an *UpdateJoin node. -func NewUpdateJoin(editorMap map[string]sql.RowUpdater, child sql.Node) *UpdateJoin { +// NewUpdateJoin returns a new *UpdateJoin node. +func NewUpdateJoin(updateTargets map[string]sql.Node, child sql.Node) *UpdateJoin { return &UpdateJoin{ - Updaters: editorMap, - UnaryNode: UnaryNode{Child: child}, + UpdateTargets: updateTargets, + UnaryNode: UnaryNode{Child: child}, } } @@ -54,14 +54,9 @@ func (u *UpdateJoin) DebugString() string { // GetUpdatable returns an updateJoinTable which implements sql.UpdatableTable. func (u *UpdateJoin) GetUpdatable() sql.UpdatableTable { - // TODO: UpdateJoin can update multiple tables, but this interface only allows for a single table. - // Additionally, updatableJoinTable doesn't implement interfaces that other parts of the code - // expect, so UpdateJoins don't always work correctly. For example, because updatableJoinTable - // doesn't implement ForeignKeyTable, UpdateJoin statements don't enforce foreign key checks. - // We should revamp this function so that we can communicate multiple tables being updated. return &updatableJoinTable{ - updaters: u.Updaters, - joinNode: u.Child.(*UpdateSource).Child, + updateTargets: u.UpdateTargets, + joinNode: u.Child.(*UpdateSource).Child, } } @@ -71,7 +66,7 @@ func (u *UpdateJoin) WithChildren(children ...sql.Node) (sql.Node, error) { return nil, sql.ErrInvalidChildrenNumber.New(u, len(children), 1) } - return NewUpdateJoin(u.Updaters, children[0]), nil + return NewUpdateJoin(u.UpdateTargets, children[0]), nil } func (u *UpdateJoin) IsReadOnly() bool { @@ -83,10 +78,26 @@ func (u *UpdateJoin) CollationCoercibility(ctx *sql.Context) (collation sql.Coll return sql.GetCoercibility(ctx, u.Child) } +func (u *UpdateJoin) GetUpdaters(ctx *sql.Context) (map[string]sql.RowUpdater, error) { + return getUpdaters(u.UpdateTargets, ctx) +} + +func getUpdaters(updateTargets map[string]sql.Node, ctx *sql.Context) (map[string]sql.RowUpdater, error) { + updaterMap := make(map[string]sql.RowUpdater) + for tableName, updateTarget := range updateTargets { + updatable, err := GetUpdatable(updateTarget) + if err != nil { + return nil, err + } + updaterMap[tableName] = updatable.Updater(ctx) + } + return updaterMap, nil +} + // updatableJoinTable manages the update of multiple tables. type updatableJoinTable struct { - updaters map[string]sql.RowUpdater - joinNode sql.Node + updateTargets map[string]sql.Node + joinNode sql.Node } var _ sql.UpdatableTable = (*updatableJoinTable)(nil) @@ -123,8 +134,9 @@ func (u *updatableJoinTable) Collation() sql.CollationID { // Updater implements the sql.UpdatableTable interface. func (u *updatableJoinTable) Updater(ctx *sql.Context) sql.RowUpdater { + updaters, _ := getUpdaters(u.updateTargets, ctx) return &updatableJoinUpdater{ - updaterMap: u.updaters, + updaterMap: updaters, schemaMap: RecreateTableSchemaFromJoinSchema(u.joinNode.Schema()), joinSchema: u.joinNode.Schema(), } diff --git a/sql/rowexec/dml.go b/sql/rowexec/dml.go index c2c779a362..e4347c70a8 100644 --- a/sql/rowexec/dml.go +++ b/sql/rowexec/dml.go @@ -416,10 +416,14 @@ func (b *BaseBuilder) buildUpdateJoin(ctx *sql.Context, n *plan.UpdateJoin, row return nil, err } + updaters, err := n.GetUpdaters(ctx) + if err != nil { + return nil, err + } return &updateJoinIter{ updateSourceIter: ji, joinSchema: n.Child.(*plan.UpdateSource).Child.Schema(), - updaters: n.Updaters, + updaters: updaters, caches: make(map[string]sql.KeyValueCache), disposals: make(map[string]sql.DisposeFunc), joinNode: n.Child.(*plan.UpdateSource).Child,