Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions enginetest/queries/script_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -5167,6 +5167,59 @@ CREATE TABLE tab3 (
},
},
},
{
Name: "update with left join with some missing rows",
SetUpScript: []string{
`create table joinparent (
id int not null auto_increment,
name varchar(128) not null,
archived int default 0 not null,
archived_at datetime null,
primary key (id)
);`,
`insert into joinparent (name) values
('first'),
('second'),
('third'),
('fourth'),
('fifth');`,
`create index joinparent_archived on joinparent (archived, archived_at);`,
`create table joinchild (
id int not null auto_increment,
name varchar(128) not null,
parent_id int not null,
archived int default 0 not null,
archived_at datetime null,
primary key (id),
constraint joinchild_parent unique (parent_id, id, archived));`,
`insert into joinchild (name, parent_id) values
('first', 4),
('second', 3),
('third', 2);`,
},
Assertions: []ScriptTestAssertion{
{
Query: `update joinparent as jp
left join joinchild as jc on jc.parent_id = jp.id
set jp.archived = jp.id, jp.archived_at = now(),
jc.archived = jc.id, jc.archived_at = now()
where jp.id > 0 and jp.name != "never"
order by jp.name
limit 100`,
Expected: []sql.Row{{types.OkResult{RowsAffected: 8, Info: plan.UpdateInfo{Matched: 10, Updated: 8}}}},
},
// do without limit to use `plan.Sort` instead of `plan.TopN`
{
Query: `update joinparent as jp
left join joinchild as jc on jc.parent_id = jp.id
set jp.archived = 0, jp.archived_at = null,
jc.archived = 0, jc.archived_at = null
where jp.id > 0 and jp.name != "never"
order by jp.name`,
Expected: []sql.Row{{types.OkResult{RowsAffected: 8, Info: plan.UpdateInfo{Matched: 10, Updated: 8}}}},
},
},
},
{
Name: "count distinct decimals",
SetUpScript: []string{
Expand Down
51 changes: 46 additions & 5 deletions sql/rowexec/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,50 @@ func (u *updateJoinIter) Next(ctx *sql.Context) (sql.Row, error) {
}
}

func toJoinNode(node sql.Node) *plan.JoinNode {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could use switch n := node.(type) instead of having to recast every time
Example:

func toJoinNode(node sql.Node) *plan.JoinNode {
	switch n := node.(type) {
	case *plan.JoinNode:
		return n
	case *plan.Filter:
		return toJoinNode(n.Child)
	case *plan.Project:
		return toJoinNode(n.Child)
	default:
		return nil
	}
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, this is what I've changed it to with all the other node types mentioned above added. I will hopefully be able to push the updates sometime this week.

switch n := node.(type) {
case *plan.JoinNode:
return n
case *plan.TopN:
return toJoinNode(n.Child)
case *plan.Filter:
return toJoinNode(n.Child)
case *plan.Project:
return toJoinNode(n.Child)
case *plan.Limit:
return toJoinNode(n.Child)
case *plan.Offset:
return toJoinNode(n.Child)
case *plan.Sort:
return toJoinNode(n.Child)
case *plan.Distinct:
return toJoinNode(n.Child)
case *plan.Having:
return toJoinNode(n.Child)
case *plan.Window:
return toJoinNode(n.Child)
default:
return nil
}
}

func isIndexedAccess(node sql.Node) bool {
switch n := node.(type) {
case *plan.Filter:
return isIndexedAccess(n.Child)
case *plan.TableAlias:
return isIndexedAccess(n.Child)
case *plan.JoinNode:
return isIndexedAccess(n.Left())
case *plan.IndexedTableAccess:
return true
}
return false
}

func isRightOrLeftJoin(node sql.Node) bool {
jn, ok := node.(*plan.JoinNode)
if !ok {
jn := toJoinNode(node)
if jn == nil {
return false
}
return jn.JoinType().IsLeftOuter()
Expand All @@ -269,8 +310,8 @@ func isRightOrLeftJoin(node sql.Node) bool {
// the left or right side of the join (given the direction). A row of all nils that does not pass condition 1 must not
// be part of the update operation. This is follows the logic as established in the joinIter.
func (u *updateJoinIter) shouldUpdateDirectionalJoin(ctx *sql.Context, joinRow, tableRow sql.Row) (bool, error) {
jn := u.joinNode.(*plan.JoinNode)
if !jn.JoinType().IsLeftOuter() {
jn := toJoinNode(u.joinNode)
if jn == nil || !jn.JoinType().IsLeftOuter() {
return true, fmt.Errorf("expected left join")
}

Expand All @@ -279,7 +320,7 @@ func (u *updateJoinIter) shouldUpdateDirectionalJoin(ctx *sql.Context, joinRow,
if err != nil {
return true, err
}
if v, ok := val.(bool); ok && v {
if v, ok := val.(bool); ok && v && !isIndexedAccess(jn) {
return true, nil
}

Expand Down