diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index d299bf9801..aa3f122a09 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -5167,6 +5167,59 @@ CREATE TABLE tab3 ( }, }, }, + { + Name: "update with left join with some missing rows", + SetUpScript: []string{ + `create table joinparent ( + id int not null auto_increment, + name varchar(128) not null, + archived int default 0 not null, + archived_at datetime null, + primary key (id) + );`, + `insert into joinparent (name) values + ('first'), + ('second'), + ('third'), + ('fourth'), + ('fifth');`, + `create index joinparent_archived on joinparent (archived, archived_at);`, + `create table joinchild ( + id int not null auto_increment, + name varchar(128) not null, + parent_id int not null, + archived int default 0 not null, + archived_at datetime null, + primary key (id), + constraint joinchild_parent unique (parent_id, id, archived));`, + `insert into joinchild (name, parent_id) values + ('first', 4), + ('second', 3), + ('third', 2);`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: `update joinparent as jp + left join joinchild as jc on jc.parent_id = jp.id + set jp.archived = jp.id, jp.archived_at = now(), + jc.archived = jc.id, jc.archived_at = now() + where jp.id > 0 and jp.name != "never" + order by jp.name + limit 100`, + Expected: []sql.Row{{types.OkResult{RowsAffected: 8, Info: plan.UpdateInfo{Matched: 10, Updated: 8}}}}, + }, + // do without limit to use `plan.Sort` instead of `plan.TopN` + { + Query: `update joinparent as jp + left join joinchild as jc on jc.parent_id = jp.id + set jp.archived = 0, jp.archived_at = null, + jc.archived = 0, jc.archived_at = null + where jp.id > 0 and jp.name != "never" + order by jp.name`, + Expected: []sql.Row{{types.OkResult{RowsAffected: 8, Info: plan.UpdateInfo{Matched: 10, Updated: 8}}}}, + }, + }, + }, { Name: "count distinct decimals", SetUpScript: []string{ diff --git a/sql/rowexec/update.go b/sql/rowexec/update.go index fb79f10c71..dc494a116d 100644 --- a/sql/rowexec/update.go +++ b/sql/rowexec/update.go @@ -256,9 +256,50 @@ func (u *updateJoinIter) Next(ctx *sql.Context) (sql.Row, error) { } } +func toJoinNode(node sql.Node) *plan.JoinNode { + switch n := node.(type) { + case *plan.JoinNode: + return n + case *plan.TopN: + return toJoinNode(n.Child) + case *plan.Filter: + return toJoinNode(n.Child) + case *plan.Project: + return toJoinNode(n.Child) + case *plan.Limit: + return toJoinNode(n.Child) + case *plan.Offset: + return toJoinNode(n.Child) + case *plan.Sort: + return toJoinNode(n.Child) + case *plan.Distinct: + return toJoinNode(n.Child) + case *plan.Having: + return toJoinNode(n.Child) + case *plan.Window: + return toJoinNode(n.Child) + default: + return nil + } +} + +func isIndexedAccess(node sql.Node) bool { + switch n := node.(type) { + case *plan.Filter: + return isIndexedAccess(n.Child) + case *plan.TableAlias: + return isIndexedAccess(n.Child) + case *plan.JoinNode: + return isIndexedAccess(n.Left()) + case *plan.IndexedTableAccess: + return true + } + return false +} + func isRightOrLeftJoin(node sql.Node) bool { - jn, ok := node.(*plan.JoinNode) - if !ok { + jn := toJoinNode(node) + if jn == nil { return false } return jn.JoinType().IsLeftOuter() @@ -269,8 +310,8 @@ func isRightOrLeftJoin(node sql.Node) bool { // the left or right side of the join (given the direction). A row of all nils that does not pass condition 1 must not // be part of the update operation. This is follows the logic as established in the joinIter. func (u *updateJoinIter) shouldUpdateDirectionalJoin(ctx *sql.Context, joinRow, tableRow sql.Row) (bool, error) { - jn := u.joinNode.(*plan.JoinNode) - if !jn.JoinType().IsLeftOuter() { + jn := toJoinNode(u.joinNode) + if jn == nil || !jn.JoinType().IsLeftOuter() { return true, fmt.Errorf("expected left join") } @@ -279,7 +320,7 @@ func (u *updateJoinIter) shouldUpdateDirectionalJoin(ctx *sql.Context, joinRow, if err != nil { return true, err } - if v, ok := val.(bool); ok && v { + if v, ok := val.(bool); ok && v && !isIndexedAccess(jn) { return true, nil }