Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions enginetest/queries/update_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -482,10 +482,8 @@ var UpdateScriptTests = []ScriptTest{
},
Assertions: []ScriptTestAssertion{
{
// TODO: Foreign key constraints are not honored for UDPATE ... JOIN statements
Skip: true,
Query: "UPDATE orders o JOIN customers c ON o.customer_id = c.id SET o.customer_id = 123 where o.customer_id != 1;",
ExpectedErr: sql.ErrCheckConstraintViolated,
ExpectedErr: sql.ErrForeignKeyChildViolation,
},
{
Query: "SELECT * FROM orders;",
Expand All @@ -510,16 +508,12 @@ var UpdateScriptTests = []ScriptTest{
},
Assertions: []ScriptTestAssertion{
{
// TODO: Foreign key constraints are not honored for UDPATE ... JOIN statements
Skip: true,
Query: `UPDATE child1 c1
JOIN child2 c2 ON c1.id = 10 AND c2.id = 20
SET c1.p1_id = 999, c2.p2_id = 3;`,
ExpectedErr: sql.ErrForeignKeyChildViolation,
},
{
// TODO: Foreign key constraints are not honored for UDPATE ... JOIN statements
Skip: true,
Query: `UPDATE child1 c1
JOIN child2 c2 ON c1.id = 10 AND c2.id = 20
SET c1.p1_id = 3, c2.p2_id = 999;`,
Expand Down
73 changes: 53 additions & 20 deletions sql/analyzer/apply_foreign_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,32 +122,35 @@ func applyForeignKeysToNodes(ctx *sql.Context, a *Analyzer, n sql.Node, cache *f
if plan.IsEmptyTable(n.Child) {
return n, transform.SameTree, nil
}
// TODO: UPDATE JOIN can update multiple tables. Because updatableJoinTable does not implement
// sql.ForeignKeyTable, we do not currenly support FK checks for UPDATE JOIN statements.
updateDest, err := plan.GetUpdatable(n.Child)
if err != nil {
return nil, transform.SameTree, err
}
fkTbl, ok := updateDest.(sql.ForeignKeyTable)
// If foreign keys aren't supported then we return
if !ok {
return n, transform.SameTree, nil
if n.IsJoin {
uj := n.Child.(*plan.UpdateJoin)
updateTargets := uj.UpdateTargets
fkHandlerMap := make(map[string]sql.Node, len(updateTargets))
for tableName, updateTarget := range updateTargets {
fkHandlerMap[tableName] = updateTarget
fkHandler, err :=
getForeignKeyHandlerFromUpdateTarget(ctx, a, updateTarget, cache, fkChain)
if err != nil {
return nil, transform.SameTree, err
}
if fkHandler == nil {
fkHandlerMap[tableName] = updateTarget
} else {
fkHandlerMap[tableName] = fkHandler
}
}
uj = plan.NewUpdateJoin(fkHandlerMap, uj.Child)
nn, err := n.WithChildren(uj)
return nn, transform.NewTree, err
}

fkEditor, err := getForeignKeyEditor(ctx, a, fkTbl, cache, fkChain, false)
fkHandler, err := getForeignKeyHandlerFromUpdateTarget(ctx, a, n.Child, cache, fkChain)
if err != nil {
return nil, transform.SameTree, err
}
if fkEditor == nil {
if fkHandler == nil {
return n, transform.SameTree, nil
}
nn, err := n.WithChildren(&plan.ForeignKeyHandler{
Table: fkTbl,
Sch: updateDest.Schema(),
OriginalNode: n.Child,
Editor: fkEditor,
AllUpdaters: fkChain.GetUpdaters(),
})
nn, err := n.WithChildren(fkHandler)
return nn, transform.NewTree, err
case *plan.DeleteFrom:
if plan.IsEmptyTable(n.Child) {
Expand Down Expand Up @@ -445,6 +448,36 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
return fkEditor, nil
}

// getForeignKeyHandlerFromUpdateTarget creates a ForeignKeyHandler from a given update target Node. It is used for
// applying foreign key constraints to Update nodes
func getForeignKeyHandlerFromUpdateTarget(ctx *sql.Context, a *Analyzer, updateTarget sql.Node,
cache *foreignKeyCache, fkChain foreignKeyChain) (*plan.ForeignKeyHandler, error) {
updateDest, err := plan.GetUpdatable(updateTarget)
if err != nil {
return nil, err
}
fkTbl, ok := updateDest.(sql.ForeignKeyTable)
if !ok {
return nil, nil
}

fkEditor, err := getForeignKeyEditor(ctx, a, fkTbl, cache, fkChain, false)
if err != nil {
return nil, err
}
if fkEditor == nil {
return nil, nil
}

return &plan.ForeignKeyHandler{
Table: fkTbl,
Sch: updateDest.Schema(),
OriginalNode: updateTarget,
Editor: fkEditor,
AllUpdaters: fkChain.GetUpdaters(),
}, nil
}

// resolveSchemaDefaults resolves the default values for the schema of |table|. This is primarily needed for column
// default value expressions, since those don't get resolved during the planbuilder phase and assignExecIndexes
// doesn't traverse through the ForeignKeyEditors and referential actions to find all of them. In addition to resolving
Expand Down
15 changes: 8 additions & 7 deletions sql/analyzer/assign_update_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@ func modifyUpdateExprsForJoin(ctx *sql.Context, a *Analyzer, n sql.Node, scope *
return n, transform.SameTree, nil
}

updaters, err := rowUpdatersByTable(ctx, us, jn)
n.IsJoin = true
updateTargets, err := getUpdateTargetsByTable(us, jn)
if err != nil {
return nil, transform.SameTree, err
}

uj := plan.NewUpdateJoin(updaters, us)
uj := plan.NewUpdateJoin(updateTargets, us)
ret, err := n.WithChildren(uj)
if err != nil {
return nil, transform.SameTree, err
Expand All @@ -51,12 +52,12 @@ func modifyUpdateExprsForJoin(ctx *sql.Context, a *Analyzer, n sql.Node, scope *
return n, transform.SameTree, nil
}

// rowUpdatersByTable maps a set of tables to their RowUpdater objects.
func rowUpdatersByTable(ctx *sql.Context, node sql.Node, ij sql.Node) (map[string]sql.RowUpdater, error) {
// getUpdateTargetsByTable maps a set of table names and aliases to their corresponding update target Node
func getUpdateTargetsByTable(node sql.Node, ij sql.Node) (map[string]sql.Node, error) {
namesOfTableToBeUpdated := getTablesToBeUpdated(node)
resolvedTables := getTablesByName(ij)

rowUpdatersByTable := make(map[string]sql.RowUpdater)
updateTargets := make(map[string]sql.Node)
for tableToBeUpdated, _ := range namesOfTableToBeUpdated {
resolvedTable, ok := resolvedTables[tableToBeUpdated]
if !ok {
Expand All @@ -76,10 +77,10 @@ func rowUpdatersByTable(ctx *sql.Context, node sql.Node, ij sql.Node) (map[strin
return nil, sql.ErrUnsupportedFeature.New("error: keyless tables unsupported for UPDATE JOIN")
}

rowUpdatersByTable[tableToBeUpdated] = updatable.Updater(ctx)
updateTargets[tableToBeUpdated] = resolvedTable
}

return rowUpdatersByTable, nil
return updateTargets, nil
}

// getTablesToBeUpdated takes a node and looks for the tables to modified by a SetField.
Expand Down
44 changes: 28 additions & 16 deletions sql/plan/update_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ import (
)

type UpdateJoin struct {
Updaters map[string]sql.RowUpdater
UpdateTargets map[string]sql.Node
Comment thread
angelamayxie marked this conversation as resolved.
UnaryNode
}

// NewUpdateJoin returns an *UpdateJoin node.
func NewUpdateJoin(editorMap map[string]sql.RowUpdater, child sql.Node) *UpdateJoin {
// NewUpdateJoin returns a new *UpdateJoin node.
func NewUpdateJoin(updateTargets map[string]sql.Node, child sql.Node) *UpdateJoin {
return &UpdateJoin{
Updaters: editorMap,
UnaryNode: UnaryNode{Child: child},
UpdateTargets: updateTargets,
UnaryNode: UnaryNode{Child: child},
}
}

Expand All @@ -54,14 +54,9 @@ func (u *UpdateJoin) DebugString() string {

// GetUpdatable returns an updateJoinTable which implements sql.UpdatableTable.
func (u *UpdateJoin) GetUpdatable() sql.UpdatableTable {
// TODO: UpdateJoin can update multiple tables, but this interface only allows for a single table.
// Additionally, updatableJoinTable doesn't implement interfaces that other parts of the code
// expect, so UpdateJoins don't always work correctly. For example, because updatableJoinTable
// doesn't implement ForeignKeyTable, UpdateJoin statements don't enforce foreign key checks.
// We should revamp this function so that we can communicate multiple tables being updated.
return &updatableJoinTable{
updaters: u.Updaters,
joinNode: u.Child.(*UpdateSource).Child,
updateTargets: u.UpdateTargets,
joinNode: u.Child.(*UpdateSource).Child,
}
}

Expand All @@ -71,7 +66,7 @@ func (u *UpdateJoin) WithChildren(children ...sql.Node) (sql.Node, error) {
return nil, sql.ErrInvalidChildrenNumber.New(u, len(children), 1)
}

return NewUpdateJoin(u.Updaters, children[0]), nil
return NewUpdateJoin(u.UpdateTargets, children[0]), nil
}

func (u *UpdateJoin) IsReadOnly() bool {
Expand All @@ -83,10 +78,26 @@ func (u *UpdateJoin) CollationCoercibility(ctx *sql.Context) (collation sql.Coll
return sql.GetCoercibility(ctx, u.Child)
}

func (u *UpdateJoin) GetUpdaters(ctx *sql.Context) (map[string]sql.RowUpdater, error) {
return getUpdaters(u.UpdateTargets, ctx)
}

func getUpdaters(updateTargets map[string]sql.Node, ctx *sql.Context) (map[string]sql.RowUpdater, error) {
updaterMap := make(map[string]sql.RowUpdater)
for tableName, updateTarget := range updateTargets {
updatable, err := GetUpdatable(updateTarget)
if err != nil {
return nil, err
}
updaterMap[tableName] = updatable.Updater(ctx)
}
return updaterMap, nil
}

// updatableJoinTable manages the update of multiple tables.
type updatableJoinTable struct {
updaters map[string]sql.RowUpdater
joinNode sql.Node
updateTargets map[string]sql.Node
joinNode sql.Node
}

var _ sql.UpdatableTable = (*updatableJoinTable)(nil)
Expand Down Expand Up @@ -123,8 +134,9 @@ func (u *updatableJoinTable) Collation() sql.CollationID {

// Updater implements the sql.UpdatableTable interface.
func (u *updatableJoinTable) Updater(ctx *sql.Context) sql.RowUpdater {
updaters, _ := getUpdaters(u.updateTargets, ctx)
return &updatableJoinUpdater{
updaterMap: u.updaters,
updaterMap: updaters,
schemaMap: RecreateTableSchemaFromJoinSchema(u.joinNode.Schema()),
joinSchema: u.joinNode.Schema(),
}
Expand Down
6 changes: 5 additions & 1 deletion sql/rowexec/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -416,10 +416,14 @@ func (b *BaseBuilder) buildUpdateJoin(ctx *sql.Context, n *plan.UpdateJoin, row
return nil, err
}

updaters, err := n.GetUpdaters(ctx)
if err != nil {
return nil, err
}
return &updateJoinIter{
updateSourceIter: ji,
joinSchema: n.Child.(*plan.UpdateSource).Child.Schema(),
updaters: n.Updaters,
updaters: updaters,
caches: make(map[string]sql.KeyValueCache),
disposals: make(map[string]sql.DisposeFunc),
joinNode: n.Child.(*plan.UpdateSource).Child,
Expand Down