diff --git a/enginetest/queries/derived_table_outer_scope_visibility_queries.go b/enginetest/queries/derived_table_outer_scope_visibility_queries.go index 2cb1624ff9..9e1ea67b18 100644 --- a/enginetest/queries/derived_table_outer_scope_visibility_queries.go +++ b/enginetest/queries/derived_table_outer_scope_visibility_queries.go @@ -125,7 +125,7 @@ var DerivedTableOuterScopeVisibilityQueries = []ScriptTest{ }, }, { - Name: "https://github.com/dolthub/go-mysql-server/issues/1282", + Name: "github.com/dolthub/go-mysql-server/issues/1282", SetUpScript: []string{ "CREATE TABLE `dcim_rackgroup` (`id` char(32) NOT NULL, `lft` int unsigned NOT NULL, `rght` int unsigned NOT NULL, `tree_id` int unsigned NOT NULL, `level` int unsigned NOT NULL, `parent_id` char(32), PRIMARY KEY (`id`), KEY `dcim_rackgroup_tree_id_9c2ad6f4` (`tree_id`), CONSTRAINT `dcim_rackgroup_parent_id_cc315105_fk_dcim_rackgroup_id` FOREIGN KEY (`parent_id`) REFERENCES `dcim_rackgroup` (`id`));", "CREATE TABLE `dcim_rack` (`id` char(32) NOT NULL, `group_id` char(32), PRIMARY KEY (`id`), KEY `dcim_rack_group_id_44e90ea9` (`group_id`), CONSTRAINT `dcim_rack_group_id_44e90ea9_fk_dcim_rackgroup_id` FOREIGN KEY (`group_id`) REFERENCES `dcim_rackgroup` (`id`));", diff --git a/enginetest/queries/join_queries.go b/enginetest/queries/join_queries.go index a15eaa3053..a94542a3a4 100644 --- a/enginetest/queries/join_queries.go +++ b/enginetest/queries/join_queries.go @@ -1641,4 +1641,83 @@ LATERAL ( }, }, }, + { + Name: "nested lateral joins", + SetUpScript: []string{ + "CREATE table ab (a int primary key, b int);", + "insert into ab values (0,3), (1,2), (2,1), (3,0);", + "create table three_pk (pk1 tinyint, pk2 tinyint, pk3 tinyint, col tinyint, primary key (pk1, pk2))", + "insert into three_pk values (0,0,0,100), (0,1,1,101), (1,0,1,110), (1,1,0,111)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: `select * from ab ab1 join lateral (select * from ab ab2 join lateral (select * from three_pk where pk1 = 
ab1.a and pk2 = ab2.a) inner1) inner2;`, + Expected: []sql.Row{ + {0, 3, 0, 3, 0, 0, 0, 100}, + {0, 3, 1, 2, 0, 1, 1, 101}, + {1, 2, 0, 3, 1, 0, 1, 110}, + {1, 2, 1, 2, 1, 1, 0, 111}, + }, + }, + { + Query: `select * from ab ab1 join lateral (select * from ab ab2 join lateral (select * from ab ab3 join lateral (select * from three_pk where pk1 = ab1.a and pk2 = ab2.a and pk3 = ab3.a) inner1) inner2) inner3;`, + Expected: []sql.Row{ + {0, 3, 0, 3, 0, 3, 0, 0, 0, 100}, + {0, 3, 1, 2, 1, 2, 0, 1, 1, 101}, + {1, 2, 0, 3, 1, 2, 1, 0, 1, 110}, + {1, 2, 1, 2, 0, 3, 1, 1, 0, 111}, + }, + }, + { + Query: `select * from ab ab1 where exists (select * from ab ab2 where exists (select * from three_pk where pk1 = ab1.a and pk2 = ab2.a));`, + Expected: []sql.Row{ + {0, 3}, + {1, 2}, + }, + }, + { + Query: `select * from ab ab1 where exists (select * from ab ab2 where exists (select * from ab ab3 where exists (select * from three_pk where pk1 = ab1.a and pk2 = ab2.a and pk3 = ab3.a)));`, + Expected: []sql.Row{ + {0, 3}, + {1, 2}, + }, + }, + }, + }, + { + Name: "non-lateral joins inside inside lateral join", + SetUpScript: []string{ + "CREATE table ab (a int primary key);", + "insert into ab values (0), (1), (2);", + "create table three_pk (pk1 tinyint, pk2 tinyint, pk3 tinyint, col tinyint, primary key (pk1, pk2))", + "insert into three_pk values (0,0,0,100), (0,1,1,101), (1,0,1,110), (1,1,0,111)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: ` +select * +from three_pk outer_table join lateral ( + select /*+ JOIN_ORDER(inner1, inner2, inner3) */ * + from three_pk inner1 join (three_pk inner2 join three_pk inner3 + on outer_table.pk2 = inner2.pk1 and outer_table.pk2 = inner2.pk2 + and outer_table.pk3 = inner3.pk1 and outer_table.pk3 = inner3.pk2) + on outer_table.pk1 = inner1.pk1 and outer_table.pk1 = inner1.pk2 +) inner_join;`, + Expected: []sql.Row{ + {0, 0, 0, 100, 0, 0, 0, 100, 0, 0, 0, 100, 0, 0, 0, 100}, + {0, 1, 1, 101, 0, 0, 0, 100, 1, 1, 0, 111, 1, 1, 0, 111}, + 
{1, 0, 1, 110, 1, 1, 0, 111, 0, 0, 0, 100, 1, 1, 0, 111}, + {1, 1, 0, 111, 1, 1, 0, 111, 1, 1, 0, 111, 0, 0, 0, 100}, + }, + }, + { + Query: `select ab1.a, a2 from ab ab1 join lateral (select ab2.a as a2, ab3.a as a3 from ab ab2 full outer join ab ab3 on ab2.a = ab1.a) inner1 where a3 is null;`, + Expected: []sql.Row{ + {0, 1}, {0, 2}, + {1, 0}, {1, 2}, + {2, 0}, {2, 1}, + }, + }, + }, + }, } diff --git a/enginetest/queries/query_plans.go b/enginetest/queries/query_plans.go index a6b99d970f..2e80df0da2 100644 --- a/enginetest/queries/query_plans.go +++ b/enginetest/queries/query_plans.go @@ -4758,6 +4758,202 @@ Select * from ( " └─ keys: ab.a, uv.u\n" + "", }, + { + Query: `select * from ab ab1 where exists (select * from ab ab2 join lateral (select * from two_pk where pk1 = ab1.a and pk2 = ab2.a) inner1)`, + Skip: true, + ExpectedPlan: "Project\n" + + " ├─ columns: [ab1.a:0!null, ab1.b:1]\n" + + " └─ LateralCrossJoin\n" + + " ├─ TableAlias(ab1)\n" + + " │ └─ ProcessTable\n" + + " │ └─ Table\n" + + " │ ├─ name: ab\n" + + " │ └─ columns: [a b]\n" + + " └─ Limit(1)\n" + + " └─ LateralCrossJoin\n" + + " ├─ TableAlias(ab2)\n" + + " │ └─ Table\n" + + " │ ├─ name: ab\n" + + " │ ├─ columns: [a b]\n" + + " │ ├─ colSet: (3,4)\n" + + " │ └─ tableId: 2\n" + + " └─ SubqueryAlias\n" + + " ├─ name: inner1\n" + + " ├─ outerVisibility: true\n" + + " ├─ isLateral: true\n" + + " ├─ cacheable: false\n" + + " ├─ colSet: (12-18)\n" + + " ├─ tableId: 4\n" + + " └─ Filter\n" + + " ├─ AND\n" + + " │ ├─ Eq\n" + + " │ │ ├─ two_pk.pk1:4!null\n" + + " │ │ └─ ab1.a:0!null\n" + + " │ └─ Eq\n" + + " │ ├─ two_pk.pk2:5!null\n" + + " │ └─ ab2.a:2!null\n" + + " └─ IndexedTableAccess(two_pk)\n" + + " ├─ index: [two_pk.pk1,two_pk.pk2]\n" + + " ├─ keys: [ab1.a:0!null ab2.a:2!null]\n" + + " ├─ colSet: (5-11)\n" + + " ├─ tableId: 3\n" + + " └─ Table\n" + + " ├─ name: two_pk\n" + + " └─ columns: [pk1 pk2 c1 c2 c3 c4 c5]\n" + + "", + ExpectedEstimates: "Project\n" + + " ├─ columns: [ab1.a, ab1.b]\n" + + " 
└─ LateralCrossJoin (estimated cost=126249.000 rows=156)\n" + + " ├─ TableAlias(ab1)\n" + + " │ └─ Table\n" + + " │ └─ name: ab\n" + + " └─ Limit(1)\n" + + " └─ LateralCrossJoin\n" + + " ├─ TableAlias(ab2)\n" + + " │ └─ Table\n" + + " │ └─ name: ab\n" + + " └─ SubqueryAlias\n" + + " ├─ name: inner1\n" + + " ├─ outerVisibility: true\n" + + " ├─ isLateral: true\n" + + " ├─ cacheable: false\n" + + " ├─ colSet: (12-18)\n" + + " ├─ tableId: 4\n" + + " └─ Filter\n" + + " ├─ ((two_pk.pk1 = ab1.a) AND (two_pk.pk2 = ab2.a))\n" + + " └─ IndexedTableAccess(two_pk)\n" + + " ├─ index: [two_pk.pk1,two_pk.pk2]\n" + + " ├─ columns: [pk1 pk2 c1 c2 c3 c4 c5]\n" + + " └─ keys: ab1.a, ab2.a\n" + + "", + ExpectedAnalysis: "Project\n" + + " ├─ columns: [ab1.a, ab1.b]\n" + + " └─ LateralCrossJoin (estimated cost=126249.000 rows=156) (actual rows=2 loops=1)\n" + + " ├─ TableAlias(ab1)\n" + + " │ └─ Table\n" + + " │ └─ name: ab\n" + + " └─ Limit(1)\n" + + " └─ LateralCrossJoin\n" + + " ├─ TableAlias(ab2)\n" + + " │ └─ Table\n" + + " │ └─ name: ab\n" + + " └─ SubqueryAlias\n" + + " ├─ name: inner1\n" + + " ├─ outerVisibility: true\n" + + " ├─ isLateral: true\n" + + " ├─ cacheable: false\n" + + " ├─ colSet: (12-18)\n" + + " ├─ tableId: 4\n" + + " └─ Filter\n" + + " ├─ ((two_pk.pk1 = ab1.a) AND (two_pk.pk2 = ab2.a))\n" + + " └─ IndexedTableAccess(two_pk)\n" + + " ├─ index: [two_pk.pk1,two_pk.pk2]\n" + + " ├─ columns: [pk1 pk2 c1 c2 c3 c4 c5]\n" + + " └─ keys: ab1.a, ab2.a\n" + + "", + }, + { + Query: `select * from ab ab1 join lateral (select * from ab ab2 join lateral (select * from two_pk where pk1 = ab1.a and pk2 = ab2.a) inner1) inner2;`, + Skip: true, + ExpectedPlan: "LateralCrossJoin\n" + + " ├─ TableAlias(ab1)\n" + + " │ └─ ProcessTable\n" + + " │ └─ Table\n" + + " │ ├─ name: ab\n" + + " │ └─ columns: [a b]\n" + + " └─ SubqueryAlias\n" + + " ├─ name: inner2\n" + + " ├─ outerVisibility: false\n" + + " ├─ isLateral: true\n" + + " ├─ cacheable: false\n" + + " ├─ colSet: (19-27)\n" + + " 
├─ tableId: 5\n" + + " └─ LateralCrossJoin\n" + + " ├─ TableAlias(ab2)\n" + + " │ └─ Table\n" + + " │ ├─ name: ab\n" + + " │ ├─ columns: [a b]\n" + + " │ ├─ colSet: (3,4)\n" + + " │ └─ tableId: 2\n" + + " └─ SubqueryAlias\n" + + " ├─ name: inner1\n" + + " ├─ outerVisibility: false\n" + + " ├─ isLateral: true\n" + + " ├─ cacheable: false\n" + + " ├─ colSet: (12-18)\n" + + " ├─ tableId: 4\n" + + " └─ Filter\n" + + " ├─ AND\n" + + " │ ├─ Eq\n" + + " │ │ ├─ two_pk.pk1:4!null\n" + + " │ │ └─ ab1.a:0!null\n" + + " │ └─ Eq\n" + + " │ ├─ two_pk.pk2:5!null\n" + + " │ └─ ab2.a:2!null\n" + + " └─ IndexedTableAccess(two_pk)\n" + + " ├─ index: [two_pk.pk1,two_pk.pk2]\n" + + " ├─ columns: [pk1 pk2 c1 c2 c3 c4 c5]\n" + + " └─ keys: ab1.a, ab2.a\n" + + "", + ExpectedEstimates: "LateralCrossJoin (estimated cost=100999.000 rows=125)\n" + + " ├─ TableAlias(ab1)\n" + + " │ └─ Table\n" + + " │ └─ name: ab\n" + + " └─ SubqueryAlias\n" + + " ├─ name: inner2\n" + + " ├─ outerVisibility: false\n" + + " ├─ isLateral: true\n" + + " ├─ cacheable: false\n" + + " ├─ colSet: (19-27)\n" + + " ├─ tableId: 5\n" + + " └─ LateralCrossJoin (estimated cost=100999.000 rows=125)\n" + + " ├─ TableAlias(ab2)\n" + + " │ └─ Table\n" + + " │ └─ name: ab\n" + + " └─ SubqueryAlias\n" + + " ├─ name: inner1\n" + + " ├─ outerVisibility: false\n" + + " ├─ isLateral: true\n" + + " ├─ cacheable: false\n" + + " ├─ colSet: (12-18)\n" + + " ├─ tableId: 4\n" + + " └─ Filter\n" + + " ├─ ((two_pk.pk1 = ab1.a) AND (two_pk.pk2 = ab2.a))\n" + + " └─ IndexedTableAccess(two_pk)\n" + + " ├─ index: [two_pk.pk1,two_pk.pk2]\n" + + " ├─ columns: [pk1 pk2 c1 c2 c3 c4 c5]\n" + + " └─ keys: ab1.a, ab2.a\n" + + "", + ExpectedAnalysis: "LateralCrossJoin (estimated cost=100999.000 rows=125) (actual rows=4 loops=1)\n" + + " ├─ TableAlias(ab1)\n" + + " │ └─ Table\n" + + " │ └─ name: ab\n" + + " └─ SubqueryAlias\n" + + " ├─ name: inner2\n" + + " ├─ outerVisibility: false\n" + + " ├─ isLateral: true\n" + + " ├─ cacheable: false\n" + + " ├─ 
colSet: (19-27)\n" + + " ├─ tableId: 5\n" + + " └─ LateralCrossJoin (estimated cost=100999.000 rows=125)\n" + + " ├─ TableAlias(ab2)\n" + + " │ └─ Table\n" + + " │ └─ name: ab\n" + + " └─ SubqueryAlias\n" + + " ├─ name: inner1\n" + + " ├─ outerVisibility: false\n" + + " ├─ isLateral: true\n" + + " ├─ cacheable: false\n" + + " ├─ colSet: (12-18)\n" + + " ├─ tableId: 4\n" + + " └─ Filter\n" + + " ├─ ((two_pk.pk1 = ab1.a) AND (two_pk.pk2 = ab2.a))\n" + + " └─ IndexedTableAccess(two_pk)\n" + + " ├─ index: [two_pk.pk1,two_pk.pk2]\n" + + " ├─ columns: [pk1 pk2 c1 c2 c3 c4 c5]\n" + + " └─ keys: ab1.a, ab2.a\n" + + "", + }, { Query: `select * from ab s where exists (select * from ab where a = 1 or s.a = 1)`, ExpectedPlan: "SemiJoin\n" + diff --git a/sql/analyzer/apply_indexes_from_outer_scope.go b/sql/analyzer/apply_indexes_from_outer_scope.go index 0e15f7051a..085f6bb0a5 100644 --- a/sql/analyzer/apply_indexes_from_outer_scope.go +++ b/sql/analyzer/apply_indexes_from_outer_scope.go @@ -256,6 +256,11 @@ func getSubqueryIndexes( result := make(map[string]sql.Index) // For every predicate involving a table in the outer scope, see if there's an index lookup possible on its comparands // (the tables in this scope) + // TODO: As written, this only looks at a single outer scope. 
+ // This prevents optimization when we need to get predicates from multiple outer scopes like: + // select * from ab ab1 where exists + // (select * from ab ab2 join lateral + // (select * from two_pk where pk1 = ab1.a and pk2 = ab2.a) inner1)` for _, scopeTable := range tablesInScope { indexCols := exprsByTable[scopeTable] if indexCols != nil { diff --git a/sql/analyzer/fix_exec_indexes.go b/sql/analyzer/fix_exec_indexes.go index 011d7c1f05..92f5a7115a 100644 --- a/sql/analyzer/fix_exec_indexes.go +++ b/sql/analyzer/fix_exec_indexes.go @@ -661,13 +661,6 @@ func (s *idxScope) finalizeSelf(n sql.Node) (sql.Node, error) { return n, nil } n = jn.WithScopeLen(scopeLen) - n, err = n.WithChildren( - plan.NewStripRowNode(jn.Left(), scopeLen), - plan.NewStripRowNode(jn.Right(), scopeLen), - ) - if err != nil { - return nil, err - } } return n, nil } diff --git a/sql/analyzer/resolve_subqueries.go b/sql/analyzer/resolve_subqueries.go index c84bf1f940..2a936eb19a 100644 --- a/sql/analyzer/resolve_subqueries.go +++ b/sql/analyzer/resolve_subqueries.go @@ -170,7 +170,7 @@ func finalizeSubqueriesHelper(ctx *sql.Context, a *Analyzer, node sql.Node, scop var subScope *plan.Scope = scope for _, joinParent := range joinParents { if sqa.OuterScopeVisibility && joinParent != nil { - if stripChild, ok := joinParent.Right().(*plan.StripRowNode); ok && stripChild.Child == sqa { + if joinParent.Right() == sqa { subScope = scope.NewScopeInJoin(joinParent.Children()[0]) subScope.SetLateralJoin(joinParent.Op.IsLateral()) } else { diff --git a/sql/analyzer/symbol_resolution.go b/sql/analyzer/symbol_resolution.go index 455cc170d0..089bc72c61 100644 --- a/sql/analyzer/symbol_resolution.go +++ b/sql/analyzer/symbol_resolution.go @@ -105,7 +105,7 @@ func pruneTables(ctx *sql.Context, a *Analyzer, n sql.Node, s *plan.Scope, sel R } case *plan.Filter, *plan.Distinct, *plan.GroupBy, *plan.Project, *plan.TableAlias, *plan.Window, *plan.Sort, *plan.Limit, *plan.RecursiveCte, - *plan.RecursiveTable, 
*plan.TopN, *plan.Offset, *plan.StripRowNode: + *plan.RecursiveTable, *plan.TopN, *plan.Offset: default: return n, transform.SameTree, nil } diff --git a/sql/memo/join_order_builder.go b/sql/memo/join_order_builder.go index 7d6d835f80..ad26853329 100644 --- a/sql/memo/join_order_builder.go +++ b/sql/memo/join_order_builder.go @@ -238,8 +238,6 @@ func (j *joinOrderBuilder) populateSubgraph(n sql.Node) (vertexSet, edgeSet, *Ex group = j.buildJoinLeaf(n) case sql.NameableNode: group = j.buildJoinLeaf(n.(plan.TableIdNode)) - case *plan.StripRowNode: - return j.populateSubgraph(n.Child) case *plan.CachedResults: return j.populateSubgraph(n.Child) default: diff --git a/sql/plan/scope.go b/sql/plan/scope.go index e9b550aebb..15efe669a3 100644 --- a/sql/plan/scope.go +++ b/sql/plan/scope.go @@ -99,18 +99,6 @@ func (s *Scope) NewScopeFromSubqueryExpression(node sql.Node, corr sql.ColSet) * // NewScopeFromSubqueryExpression returns a new subscope created from a subquery expression contained by the specified // node. 
func (s *Scope) NewScopeInJoin(node sql.Node) *Scope { - for { - var done bool - switch n := node.(type) { - case *StripRowNode: - node = n.Child - default: - done = true - } - if done { - break - } - } if s == nil { return &Scope{joinSiblings: []sql.Node{node}} } diff --git a/sql/plan/subquery.go b/sql/plan/subquery.go index 296ef0ed9f..85013683f4 100644 --- a/sql/plan/subquery.go +++ b/sql/plan/subquery.go @@ -68,49 +68,6 @@ var _ sql.NonDeterministicExpression = (*Subquery)(nil) var _ sql.ExpressionWithNodes = (*Subquery)(nil) var _ sql.CollationCoercible = (*Subquery)(nil) -type StripRowNode struct { - UnaryNode - NumCols int -} - -var _ sql.Node = (*StripRowNode)(nil) -var _ sql.CollationCoercible = (*StripRowNode)(nil) - -func NewStripRowNode(child sql.Node, numCols int) sql.Node { - return &StripRowNode{UnaryNode: UnaryNode{child}, NumCols: numCols} -} - -// Describe implements the sql.Describable interface -func (srn *StripRowNode) Describe(options sql.DescribeOptions) string { - return sql.Describe(srn.Child, options) -} - -// String implements the fmt.Stringer interface -func (srn *StripRowNode) String() string { - return srn.Child.String() -} - -func (srn *StripRowNode) IsReadOnly() bool { - return srn.Child.IsReadOnly() -} - -// DebugString implements the sql.DebugStringer interface -func (srn *StripRowNode) DebugString() string { - return sql.DebugString(srn.Child) -} - -func (srn *StripRowNode) WithChildren(children ...sql.Node) (sql.Node, error) { - if len(children) != 1 { - return nil, sql.ErrInvalidChildrenNumber.New(srn, len(children), 1) - } - return NewStripRowNode(children[0], srn.NumCols), nil -} - -// CollationCoercibility implements the interface sql.CollationCoercible. 
-func (srn *StripRowNode) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { - return sql.GetCoercibility(ctx, srn.Child) -} - // PrependNode wraps its child by prepending column values onto any result rows type PrependNode struct { UnaryNode diff --git a/sql/rowexec/builder_gen_test.go b/sql/rowexec/builder_gen_test.go index 1014ae6538..ef211e91b3 100644 --- a/sql/rowexec/builder_gen_test.go +++ b/sql/rowexec/builder_gen_test.go @@ -159,7 +159,6 @@ func TestGenBuilder(t *testing.T) { "SignalName": "*plan.SignalName", "Sort": "*plan.Sort", "TopN": "*plan.TopN", - "StripRowNode": "*plan.StripRowNode", "prependNode": "*plan.prependNode", "Max1Row": "*plan.Max1Row", "SubqueryAlias": "*plan.SubqueryAlias", diff --git a/sql/rowexec/join_iters.go b/sql/rowexec/join_iters.go index 7f987a978f..1c9daca744 100644 --- a/sql/rowexec/join_iters.go +++ b/sql/rowexec/join_iters.go @@ -1,4 +1,4 @@ -// Copyright 2020-2021 Dolthub, Inc. +// Copyright 2020-2025 Dolthub, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -29,153 +29,329 @@ import ( "github.com/dolthub/go-mysql-server/sql/transform" ) -// joinIter is an iterator that iterates over every row in the primary table and performs an index lookup in -// the secondary table for each value -type joinIter struct { - b sql.NodeExecBuilder - cond sql.Expression - primary sql.RowIter +// joinState is the common state for all join iterators. +// This type encapsulates accesses to the underlying iterators and handles things like managing outer scopes. +// Various join iters wrap a joinState value and handle behavior specific to that join type. 
+// +// The general usage pattern looks like this: +// +// while there are rows in the primary/left child iterator: +// advance the primary iterator and store the yielded row, stripping columns that refer to outer scopes +// build a new child iterator for the right/secondary child, using the new values from the primary (and from outer scopes) as the parent row. +// while there are rows in the secondary/right child iterator: +// advance the secondary iterator and store the yielded row, stripping columns that refer to outer scopes +// check whether the current state satisfies any join conditions +// potentially yield a new row containing values from the outer scope, and from both children +// +// All row iterators, including join iterators, currently obey the following invariant: +// - When constructed, they take a `parentRow` parameter containing all values defined outside of the node. +// This includes, in order: +// -- Values from outer and lateral scopes +// -- Values from parent join nodes +// - When yielding rows, the row contains, in order: +// -- Values from outer and lateral scopes +// -- Values defined by the node +// +// Yielding values defined in outer scopes is necessary because a parent node may need to use that value in an expression; +// prepending these values to iterator rows is how we expose them. +// Notably, join iterators do *not* yield values defined in parent nodes unless those values constitute an outer or lateral +// scope (such as in the case of lateralJoinIterator). This is important, because it allows a join iterator to not care +// whether or not its children are also join iterators: both join and non-join nodes yield values in the same format. +// +// Q: Why do we only copy the last rows returned by the child iterators? Why can't we just use the result of primaryRowIter.Next() as primaryRow? +// A: There is a subtle correctness issue if we do that, because the child could be a cached subquery. 
We cache subqueries if they don't reference +// any columns in their outer scope, but we still pass in those columns when building the iterator, and the iterator still returns values +// for those columns in its results. Thus for values corresponding to outer scopes, it is possible for the values returned by the child iterator +// to differ from the values in the join's parentRow, and the values returned by the iterator should be discarded. +// +// TODO: This is dangerous and there may be existing correctness bugs because of this. We should fix this by moving to +// an implementation where parent scope values are not returned by iterators at all. +type joinState struct { + builder sql.NodeExecBuilder + joinType plan.JoinType + + // scopeLen is the number of columns inherited from outer scopes. These are additional columns that are prepended + // to every child iterator, allowing child iterators to resolve references to these outer scopes. + scopeLen int + // parentLen is the number of columns inherited from parent nodes, including both outer scopes and parent nodes + // within the same scope (such as parent join nodes in a many-table-join.) This value is always greater than or + // equal to the value of |scopeLen|. + parentLen int + // leftLen and rightLen are the number of columns in the left/primary and right/secondary child node schemas. + leftLen int + rightLen int + + // join nodes and their children obey the following invariants: + // - rows returned by primaryRowIter contain the outer scope rows, followed by the left child rows. + // Thus, they are always of length scopeLen + leftLen. + // - rows returned by secondaryRowIter contain the outer scope rows, followed by the right child rows. + // For non-lateral joins they are always of length scopeLen + rightLen. + // Lateral joins make this slightly more complicated. 
+ + primaryRowIter sql.RowIter + secondaryRowIter sql.RowIter + + // primaryRow is the row that will get passed to the builder when building the secondary child iterator. + // It is always of length parentLen + leftLen + primaryRow sql.Row + + // fullRow is the row that will get passed to any join conditions. It is always of length fullRowSize() (parentLen + leftLen + rightLen) + fullRow sql.Row + + // secondaryProvider is a node from which secondaryRowIter can be constructed. The iterator is usually rebuilt once + // for each row pulled from primaryRowIter secondaryProvider sql.Node - secondary sql.RowIter - primaryRow sql.Row - rowSize int - scopeLen int - parentLen int - joinType plan.JoinType - loadPrimaryRow bool - foundMatch bool + + // cond is the join condition, if any + cond sql.Expression + + // foundMatch indicates whether the iterator has returned a result for the current primaryRow. It is + // needed for left outer joins and full outer joins. + foundMatch bool } -func newJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, row sql.Row) (sql.RowIter, error) { - var leftName, rightName string +// fullRowSize is the total number of columns visible in the join. It includes columns from the outer scope, +// columns from parent join nodes, and columns from both children. +func (i *joinState) fullRowSize() int { + return i.parentLen + i.leftLen + i.rightLen +} + +// resultRowSize is the size of the rows produced by the join iterator. +func (i *joinState) resultRowSize() int { + return i.scopeLen + i.leftLen + i.rightLen +} + +// makeResultRow creates a new sql.Row computed from the most recently visited children. +func (i *joinState) makeResultRow() sql.Row { + resultRow := make(sql.Row, i.resultRowSize()) + copy(resultRow, i.fullRow[:i.scopeLen]) + copy(resultRow[i.scopeLen:], i.fullRow[i.parentLen:]) + return resultRow +} + +// scopeColumns returns the values defined in outer scopes that are visible to this join. +// It is a subset of parentColumns. 
+func (i *joinState) scopeColumns() sql.Row { + return i.fullRow[:i.scopeLen] +} + +// parentColumns returns the values defined in all parent nodes that are visible to this join. +// It is a superset of scopeColumns, but also includes parent nodes in the same scope, such as parent join nodes. +func (i *joinState) parentColumns() sql.Row { + return i.fullRow[:i.parentLen] +} + +// leftColumns returns the values most recently yielded from the primary/left child node. +func (i *joinState) leftColumns() sql.Row { + return i.fullRow[i.parentLen : i.parentLen+i.leftLen] +} + +// rightColumns returns the values most recently yielded from the secondary/right child node. +func (i *joinState) rightColumns() sql.Row { + return i.fullRow[i.parentLen+i.leftLen : i.parentLen+i.leftLen+i.rightLen] +} + +// makeLeftOuterNonMatchingResult returns a new sql.Row representing a row from a LEFT OUTER join where no match was made with the right child. +func (i *joinState) makeLeftOuterNonMatchingResult() sql.Row { + resultRow := make(sql.Row, i.resultRowSize()) + copy(resultRow, i.scopeColumns()) + copy(resultRow[i.scopeLen:], i.leftColumns()) + return resultRow +} + +// makeRightOuterNonMatchingResult returns a new sql.Row representing a row from a RIGHT OUTER join where no match was made with the left child. 
+func (i *joinState) makeRightOuterNonMatchingResult() sql.Row { + resultRow := make(sql.Row, i.resultRowSize()) + copy(resultRow, i.scopeColumns()) + copy(resultRow[i.scopeLen+i.leftLen:], i.rightColumns()) + return resultRow +} + +// makeSemiJoinResult returns a new sql.Row representing a row from a SemiJoin or ExistsIter +func (i *joinState) makeSemiJoinResult() sql.Row { + resultRow := make(sql.Row, i.scopeLen+i.leftLen) + copy(resultRow, i.scopeColumns()) + copy(resultRow[i.scopeLen:], i.leftColumns()) + return resultRow +} + +func newJoinState(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, parentRow sql.Row, opName string) (joinState, trace.Span, error) { + var left, right string if leftTable, ok := j.Left().(sql.Nameable); ok { - leftName = leftTable.Name() + left = leftTable.Name() } else { - leftName = reflect.TypeOf(j.Left()).String() + left = reflect.TypeOf(j.Left()).String() } - if rightTable, ok := j.Right().(sql.Nameable); ok { - rightName = rightTable.Name() + right = rightTable.Name() } else { - rightName = reflect.TypeOf(j.Right()).String() + right = reflect.TypeOf(j.Right()).String() } - span, ctx := ctx.Span("plan.joinIter", trace.WithAttributes( - attribute.String("left", leftName), - attribute.String("right", rightName), + span, ctx := ctx.Span(opName, trace.WithAttributes( + attribute.String("left", left), + attribute.String("right", right), )) - l, err := b.Build(ctx, j.Left(), row) + parentLen := len(parentRow) + scopeLen := j.ScopeLen + leftLen := len(j.Left().Schema()) + rightLen := len(j.Right().Schema()) + + primaryRow := make(sql.Row, parentLen+leftLen) + copy(primaryRow, parentRow) + + resultRow := make(sql.Row, scopeLen+leftLen+rightLen) + copy(resultRow, parentRow[:scopeLen]) + + fullRow := make(sql.Row, parentLen+leftLen+rightLen) + copy(fullRow, parentRow[:parentLen]) + + primaryRowIter, err := b.Build(ctx, j.Left(), parentRow) if err != nil { span.End() - return nil, err + return joinState{}, nil, err } - parentLen := 
len(row) - - primaryRow := make(sql.Row, parentLen+len(j.Left().Schema())) - copy(primaryRow, row) - - return sql.NewSpanIter(span, &joinIter{ - b: b, + return joinState{ + builder: b, joinType: j.Op, - cond: j.Filter, - primary: l, - primaryRow: primaryRow, - loadPrimaryRow: true, + scopeLen: scopeLen, + parentLen: parentLen, + leftLen: leftLen, + rightLen: rightLen, + primaryRowIter: primaryRowIter, + primaryRow: primaryRow, + fullRow: fullRow, secondaryProvider: j.Right(), + secondaryRowIter: nil, - rowSize: parentLen + len(j.Left().Schema()) + len(j.Right().Schema()), - scopeLen: j.ScopeLen, - parentLen: parentLen, - }), nil + cond: j.Filter, + }, span, nil } -func (i *joinIter) loadPrimary(ctx *sql.Context) error { - if i.loadPrimaryRow { - r, err := i.primary.Next(ctx) +// loadPrimary advances the primary iterator and updates internal state. +func (i *joinState) loadPrimary(ctx *sql.Context) error { + childRow, err := i.primaryRowIter.Next(ctx) + if err != nil { + return err + } + i.foundMatch = false + // the child iter begins with rows from the outer scope; strip those away + rowsFromChild := childRow[len(childRow)-i.leftLen:] + copy(i.primaryRow[i.parentLen:], rowsFromChild) + copy(i.fullRow[i.parentLen:], rowsFromChild) + return nil +} + +// loadSecondary advances the secondary iterator and updates internal state. +// If the secondary iterator is exhausted, close and remove it. 
+func (i *joinState) loadSecondary(ctx *sql.Context) error { + childRow, err := i.secondaryRowIter.Next(ctx) + if err == io.EOF { + err = i.secondaryRowIter.Close(ctx) if err != nil { return err } - copy(i.primaryRow[i.parentLen:], r) - i.foundMatch = false - i.loadPrimaryRow = false + i.secondaryRowIter = nil + return io.EOF + } else if err != nil { + return err } + + // the child iter begins with rows from the outer scope; strip those away + rowsFromChild := childRow[len(childRow)-i.rightLen:] + copy(i.fullRow[i.parentLen+i.leftLen:], rowsFromChild) return nil } -func (i *joinIter) loadSecondary(ctx *sql.Context) (sql.Row, error) { - if i.secondary == nil { - rowIter, err := i.b.Build(ctx, i.secondaryProvider, i.primaryRow) - if err != nil { - return nil, err - } - if plan.IsEmptyIter(rowIter) { - return nil, plan.ErrEmptyCachedResult - } - i.secondary = rowIter +// resetSecondaryIter closes and removes the secondary iterator. +func (i *joinState) resetSecondaryIter(ctx *sql.Context) (err error) { + if i.secondaryRowIter != nil { + err = i.secondaryRowIter.Close(ctx) + i.secondaryRowIter = nil } + return err +} - secondaryRow, err := i.secondary.Next(ctx) - if err != nil { - if err == io.EOF { - err = i.secondary.Close(ctx) - i.secondary = nil - if err != nil { - return nil, err +// Close cleans up the iterator by recursively closing the children iterators. 
+func (i *joinState) Close(ctx *sql.Context) (err error) { + if i.primaryRowIter != nil { + if err = i.primaryRowIter.Close(ctx); err != nil { + if i.secondaryRowIter != nil { + _ = i.secondaryRowIter.Close(ctx) } - i.loadPrimaryRow = true - return nil, io.EOF + return err } + } + + if i.secondaryRowIter != nil { + err = i.secondaryRowIter.Close(ctx) + i.secondaryRowIter = nil + } + + return err +} + +// joinIter is an iterator that iterates over every row in the primary table and performs an index lookup in +// the secondary table for each value +type joinIter struct { + joinState +} + +func newJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, row sql.Row) (sql.RowIter, error) { + js, span, err := newJoinState(ctx, b, j, row, "plan.joinIter") + if err != nil { return nil, err } - return secondaryRow, nil + return sql.NewSpanIter(span, &joinIter{ + joinState: js, + }), nil } func (i *joinIter) Next(ctx *sql.Context) (sql.Row, error) { for { - if err := i.loadPrimary(ctx); err != nil { - return nil, err + if i.secondaryRowIter == nil { + if err := i.loadPrimary(ctx); err != nil { + return nil, err + } + + rowIter, err := i.builder.Build(ctx, i.secondaryProvider, i.primaryRow) + if err != nil { + return nil, err + } + if plan.IsEmptyIter(rowIter) { + if !i.foundMatch && i.joinType.IsLeftOuter() { + return i.makeLeftOuterNonMatchingResult(), nil + } + return nil, io.EOF + } + i.secondaryRowIter = rowIter } - primary := i.primaryRow - secondary, err := i.loadSecondary(ctx) + err := i.loadSecondary(ctx) if err != nil { if errors.Is(err, io.EOF) { if !i.foundMatch && i.joinType.IsLeftOuter() { - i.loadPrimaryRow = true - row := i.buildRow(primary, nil) - return i.removeParentRow(row), nil + return i.makeLeftOuterNonMatchingResult(), nil } continue } - if errors.Is(err, plan.ErrEmptyCachedResult) { - if !i.foundMatch && i.joinType.IsLeftOuter() { - i.loadPrimaryRow = true - row := i.buildRow(primary, nil) - return i.removeParentRow(row), nil - } - return 
nil, io.EOF - } return nil, err } - row := i.buildRow(primary, secondary) - res, err := sql.EvaluateCondition(ctx, i.cond, row) + res, err := sql.EvaluateCondition(ctx, i.cond, i.fullRow) if err != nil { return nil, err } if res == nil && i.joinType.IsExcludeNulls() { - err = i.secondary.Close(ctx) - i.secondary = nil - if err != nil { + if err := i.resetSecondaryIter(ctx); err != nil { return nil, err } - i.loadPrimaryRow = true continue } @@ -184,83 +360,24 @@ func (i *joinIter) Next(ctx *sql.Context) (sql.Row, error) { } i.foundMatch = true - return i.removeParentRow(row), nil - } -} - -func (i *joinIter) removeParentRow(r sql.Row) sql.Row { - copy(r[i.scopeLen:], r[i.parentLen:]) - r = r[:len(r)-i.parentLen+i.scopeLen] - return r -} - -// buildRow builds the result set row using the rows from the primary and secondary tables -func (i *joinIter) buildRow(primary, secondary sql.Row) sql.Row { - row := make(sql.Row, i.rowSize) - copy(row, primary) - copy(row[len(primary):], secondary) - return row -} - -func (i *joinIter) Close(ctx *sql.Context) (err error) { - if i.primary != nil { - if err = i.primary.Close(ctx); err != nil { - if i.secondary != nil { - _ = i.secondary.Close(ctx) - } - return err - } - } - - if i.secondary != nil { - err = i.secondary.Close(ctx) - i.secondary = nil + return i.makeResultRow(), nil } - - return err } func newExistsIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, row sql.Row) (sql.RowIter, error) { - leftIter, err := b.Build(ctx, j.Left(), row) + + js, span, err := newJoinState(ctx, b, j, row, "plan.existsIter") if err != nil { return nil, err } - parentLen := len(row) - - rowSize := parentLen + len(j.Left().Schema()) + len(j.Right().Schema()) - fullRow := make(sql.Row, rowSize) - copy(fullRow, row) - - primaryRow := make(sql.Row, parentLen+len(j.Left().Schema())) - copy(primaryRow, row) - - return &existsIter{ - b: b, - typ: j.Op, - primary: leftIter, - primaryRow: primaryRow, - fullRow: fullRow, - parentLen: 
parentLen, - secondaryProvider: j.Right(), - cond: j.Filter, - scopeLen: j.ScopeLen, - rowSize: rowSize, - }, nil + return sql.NewSpanIter(span, &existsIter{ + joinState: js, + }), nil } type existsIter struct { - b sql.NodeExecBuilder - cond sql.Expression - primary sql.RowIter - secondaryProvider sql.Node - primaryRow sql.Row - fullRow sql.Row - parentLen int - scopeLen int - rowSize int - typ plan.JoinType - rightIterNonEmpty bool + joinState } type existsState uint8 @@ -275,8 +392,6 @@ const ( ) func (i *existsIter) Next(ctx *sql.Context) (sql.Row, error) { - var right sql.Row - var rIter sql.RowIter var err error // the common sequence is: LOAD_LEFT -> LOAD_RIGHT -> COMPARE -> RET @@ -290,52 +405,43 @@ func (i *existsIter) Next(ctx *sql.Context) (sql.Row, error) { for { switch nextState { case esIncLeft: - r, err := i.primary.Next(ctx) - if err != nil { + if err := i.loadPrimary(ctx); err != nil { return nil, err } - copy(i.primaryRow[i.parentLen:], r) - rIter, err = i.b.Build(ctx, i.secondaryProvider, i.primaryRow) + i.secondaryRowIter, err = i.builder.Build(ctx, i.secondaryProvider, i.primaryRow) if err != nil { return nil, err } - if plan.IsEmptyIter(rIter) { + if plan.IsEmptyIter(i.secondaryRowIter) { nextState = esRightIterEOF } else { nextState = esIncRight } case esIncRight: - right, err = rIter.Next(ctx) + err := i.loadSecondary(ctx) if err != nil { - iterErr := rIter.Close(ctx) - if iterErr != nil { - return nil, fmt.Errorf("%w; error on close: %s", err, iterErr) - } if errors.Is(err, io.EOF) { nextState = esRightIterEOF } else { return nil, err } } else { - i.rightIterNonEmpty = true nextState = esCompare } case esRightIterEOF: - if i.typ.IsSemi() { + if i.joinType.IsSemi() { // reset iter, no match nextState = esIncLeft } else { nextState = esRet } case esCompare: - copy(i.fullRow[i.parentLen:], i.primaryRow[i.parentLen:]) - copy(i.fullRow[len(i.primaryRow):], right) res, err := sql.EvaluateCondition(ctx, i.cond, i.fullRow) if err != nil { return 
nil, err } - if res == nil && i.typ.IsExcludeNulls() { + if res == nil && i.joinType.IsExcludeNulls() { nextState = esRejectNull continue } @@ -343,11 +449,10 @@ func (i *existsIter) Next(ctx *sql.Context) (sql.Row, error) { if !sql.IsTrue(res) { nextState = esIncRight } else { - err = rIter.Close(ctx) - if err != nil { + if err = i.resetSecondaryIter(ctx); err != nil { return nil, err } - if i.typ.IsAnti() { + if i.joinType.IsAnti() { // reset iter, found match -> no return row nextState = esIncLeft } else { @@ -355,79 +460,40 @@ func (i *existsIter) Next(ctx *sql.Context) (sql.Row, error) { } } case esRejectNull: - if i.typ.IsAnti() { + if i.joinType.IsAnti() { nextState = esIncLeft } else { nextState = esIncRight } case esRet: - return i.removeParentRow(i.primaryRow.Copy()), nil + return i.makeSemiJoinResult(), nil default: return nil, fmt.Errorf("invalid exists join state") } } } -func (i *existsIter) removeParentRow(r sql.Row) sql.Row { - copy(r[i.scopeLen:], r[i.parentLen:]) - r = r[:len(r)-i.parentLen+i.scopeLen] - return r -} - -// buildRow builds the result set row using the rows from the primary and secondary tables -func (i *existsIter) buildRow(primary, secondary sql.Row) sql.Row { - row := make(sql.Row, i.rowSize) - copy(row, primary) - copy(row[len(primary):], secondary) - return row -} - -func (i *existsIter) Close(ctx *sql.Context) (err error) { - if i.primary != nil { - if err = i.primary.Close(ctx); err != nil { - return err - } - } - return err -} - func newFullJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, row sql.Row) (sql.RowIter, error) { - leftIter, err := b.Build(ctx, j.Left(), row) + js, span, err := newJoinState(ctx, b, j, row, "plan.fullJoinIter") if err != nil { return nil, err } - return &fullJoinIter{ + return sql.NewSpanIter(span, &fullJoinIter{ + joinState: js, parentRow: row, - l: leftIter, - rp: j.Right(), - cond: j.Filter, - scopeLen: j.ScopeLen, - rowSize: len(row) + len(j.Left().Schema()) + 
len(j.Right().Schema()), seenLeft: make(map[uint64]struct{}), seenRight: make(map[uint64]struct{}), - leftLen: len(j.Left().Schema()), - rightLen: len(j.Right().Schema()), - b: b, - }, nil + }), nil } // fullJoinIter implements full join as a union of left and right join: // FJ(A,B) => U(LJ(A,B), RJ(A,B)). The current algorithm will have a // runtime and memory complexity O(m+n). type fullJoinIter struct { - rp sql.Node - b sql.NodeExecBuilder - r sql.RowIter - cond sql.Expression - l sql.RowIter + joinState seenLeft map[uint64]struct{} seenRight map[uint64]struct{} - leftRow sql.Row parentRow sql.Row - rowSize int - leftLen int - rightLen int - scopeLen int leftDone bool } @@ -436,93 +502,79 @@ func (i *fullJoinIter) Next(ctx *sql.Context) (sql.Row, error) { if i.leftDone { break } - if i.leftRow == nil { - r, err := i.l.Next(ctx) + if i.secondaryRowIter == nil { + err := i.loadPrimary(ctx) if errors.Is(err, io.EOF) { i.leftDone = true - i.l = nil - i.r = nil + i.primaryRowIter = nil continue - } - if err != nil { + } else if err != nil { return nil, err } - i.leftRow = r - } - - if i.r == nil { - iter, err := i.b.Build(ctx, i.rp, i.leftRow) + iter, err := i.builder.Build(ctx, i.secondaryProvider, i.primaryRow) if err != nil { return nil, err } - i.r = iter + i.secondaryRowIter = iter } - rightRow, err := i.r.Next(ctx) + err := i.loadSecondary(ctx) if err == io.EOF { - key, err := hash.HashOf(ctx, nil, i.leftRow) + key, err := hash.HashOf(ctx, nil, i.leftColumns()) if err != nil { return nil, err } if _, ok := i.seenLeft[key]; !ok { // (left, null) only if we haven't matched left - ret := i.buildRow(i.leftRow, make(sql.Row, i.rightLen)) - i.r = nil - i.leftRow = nil - return i.removeParentRow(ret), nil + ret := i.makeLeftOuterNonMatchingResult() + err := i.resetSecondaryIter(ctx) + return ret, err } - i.r = nil - i.leftRow = nil + i.secondaryRowIter = nil continue } if err != nil { return nil, err } - row := i.buildRow(i.leftRow, rightRow) - matches, err := 
sql.EvaluateCondition(ctx, i.cond, row) + matches, err := sql.EvaluateCondition(ctx, i.cond, i.fullRow) if err != nil { return nil, err } if !sql.IsTrue(matches) { continue } - rkey, err := hash.HashOf(ctx, nil, rightRow) + rkey, err := hash.HashOf(ctx, nil, i.rightColumns()) if err != nil { return nil, err } i.seenRight[rkey] = struct{}{} - lKey, err := hash.HashOf(ctx, nil, i.leftRow) + lKey, err := hash.HashOf(ctx, nil, i.leftColumns()) if err != nil { return nil, err } i.seenLeft[lKey] = struct{}{} - return i.removeParentRow(row), nil + return i.makeResultRow(), nil } for { - if i.r == nil { + if i.secondaryRowIter == nil { // Phase 2 of FULL OUTER JOIN: return unmatched right rows as (null, rightRow). // Use parentRow instead of leftRow since leftRow is nil when left side is empty. - iter, err := i.b.Build(ctx, i.rp, i.parentRow) + iter, err := i.builder.Build(ctx, i.secondaryProvider, i.parentRow) if err != nil { return nil, err } - i.r = iter + i.secondaryRowIter = iter } - rightRow, err := i.r.Next(ctx) - if errors.Is(err, io.EOF) { - err := i.r.Close(ctx) - if err != nil { - return nil, err - } - return nil, io.EOF + if err := i.loadSecondary(ctx); err != nil { + return nil, err } - key, err := hash.HashOf(ctx, nil, rightRow) + key, err := hash.HashOf(ctx, nil, i.rightColumns()) if err != nil { return nil, err } @@ -530,152 +582,50 @@ func (i *fullJoinIter) Next(ctx *sql.Context) (sql.Row, error) { continue } // (null, right) only if we haven't matched right - ret := make(sql.Row, i.rowSize) - copy(ret[i.leftLen:], rightRow) - return i.removeParentRow(ret), nil + return i.makeRightOuterNonMatchingResult(), nil } } -func (i *fullJoinIter) removeParentRow(r sql.Row) sql.Row { - copy(r[i.scopeLen:], r[len(i.parentRow):]) - r = r[:len(r)-len(i.parentRow)+i.scopeLen] - return r -} - -// buildRow builds the result set row using the rows from the primary and secondary tables -func (i *fullJoinIter) buildRow(primary, secondary sql.Row) sql.Row { - row := 
make(sql.Row, i.rowSize) - copy(row, i.parentRow) - copy(row[len(i.parentRow):], primary) - copy(row[len(i.parentRow)+len(primary):], secondary) - return row -} - -func (i *fullJoinIter) Close(ctx *sql.Context) (err error) { - if i.l != nil { - err = i.l.Close(ctx) - } - - if i.r != nil { - if err == nil { - err = i.r.Close(ctx) - } else { - i.r.Close(ctx) - } - } - - return err -} - type crossJoinIterator struct { - l sql.RowIter - r sql.RowIter - rp sql.Node - b sql.NodeExecBuilder - - primaryRow sql.Row - - rowSize int - scopeLen int - parentLen int + joinState } func newCrossJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, row sql.Row) (sql.RowIter, error) { - var left, right string - if leftTable, ok := j.Left().(sql.Nameable); ok { - left = leftTable.Name() - } else { - left = reflect.TypeOf(j.Left()).String() - } - - if rightTable, ok := j.Right().(sql.Nameable); ok { - right = rightTable.Name() - } else { - right = reflect.TypeOf(j.Right()).String() - } - - span, ctx := ctx.Span("plan.CrossJoin", trace.WithAttributes( - attribute.String("left", left), - attribute.String("right", right), - )) - - l, err := b.Build(ctx, j.Left(), row) + js, span, err := newJoinState(ctx, b, j, row, "plan.crossJoinIter") if err != nil { - span.End() return nil, err } - parentLen := len(row) - - primaryRow := make(sql.Row, parentLen+len(j.Left().Schema())) - copy(primaryRow, row) - return sql.NewSpanIter(span, &crossJoinIterator{ - b: b, - l: l, - rp: j.Right(), - - primaryRow: primaryRow, - - rowSize: len(row) + len(j.Left().Schema()) + len(j.Right().Schema()), - scopeLen: j.ScopeLen, - parentLen: parentLen, + joinState: js, }), nil } func (i *crossJoinIterator) Next(ctx *sql.Context) (sql.Row, error) { for { - if i.r == nil { - r, err := i.l.Next(ctx) - if err != nil { + if i.secondaryRowIter == nil { + if err := i.loadPrimary(ctx); err != nil { return nil, err } - copy(i.primaryRow[i.parentLen:], r) - iter, err := i.b.Build(ctx, i.rp, i.primaryRow) + iter, 
err := i.builder.Build(ctx, i.secondaryProvider, i.primaryRow) if err != nil { return nil, err } - i.r = iter + i.secondaryRowIter = iter } - rightRow, err := i.r.Next(ctx) + err := i.loadSecondary(ctx) if err == io.EOF { - i.r = nil continue - } - - if err != nil { + } else if err != nil { return nil, err } - row := make(sql.Row, i.rowSize) - copy(row, i.primaryRow) - copy(row[len(i.primaryRow):], rightRow) - return i.removeParentRow(row), nil + return i.makeResultRow(), nil } } -func (i *crossJoinIterator) removeParentRow(r sql.Row) sql.Row { - copy(r[i.scopeLen:], r[i.parentLen:]) - r = r[:len(r)-i.parentLen+i.scopeLen] - return r -} - -func (i *crossJoinIterator) Close(ctx *sql.Context) (err error) { - if i.l != nil { - err = i.l.Close(ctx) - } - if i.r != nil { - if err == nil { - err = i.r.Close(ctx) - } else { - i.r.Close(ctx) - } - } - return err -} - // lateralJoinIterator is an iterator that performs a lateral join. // A LateralJoin is a join where the right side is a subquery that can reference the left side, like through a filter. // MySQL Docs: https://dev.mysql.com/doc/refman/8.0/en/lateral-derived-tables.html @@ -704,123 +654,38 @@ func (i *crossJoinIterator) Close(ctx *sql.Context) (err error) { // +---+---+ // cond is passed to the filter iter to be evaluated. type lateralJoinIterator struct { - primary sql.RowIter - secondary sql.RowIter - secondaryNode sql.Node - cond sql.Expression - b sql.NodeExecBuilder - // primaryRow contains the parent row concatenated with the current row from the primary child, - // and is used to build the secondary child iter. - primaryRow sql.Row - // secondaryRow contains the current row from the secondary child. 
- secondaryRow sql.Row - rowSize int - scopeLen int - parentLen int - jType plan.JoinType - foundMatch bool + joinState } func newLateralJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, parentRow sql.Row) (sql.RowIter, error) { - var left, right string - if leftTable, ok := j.Left().(sql.Nameable); ok { - left = leftTable.Name() - } else { - left = reflect.TypeOf(j.Left()).String() - } - if rightTable, ok := j.Right().(sql.Nameable); ok { - right = rightTable.Name() - } else { - right = reflect.TypeOf(j.Right()).String() - } - - span, ctx := ctx.Span("plan.LateralJoin", trace.WithAttributes( - attribute.String("left", left), - attribute.String("right", right), - )) - l, err := b.Build(ctx, j.Left(), parentRow) + js, span, err := newJoinState(ctx, b, j, parentRow, "plan.lateralJoinIter") if err != nil { - span.End() return nil, err } - parentLen := len(parentRow) - - primaryRow := make(sql.Row, parentLen+len(j.Left().Schema())) - copy(primaryRow, parentRow) - return sql.NewSpanIter(span, &lateralJoinIterator{ - primaryRow: primaryRow, - parentLen: len(parentRow), - primary: l, - secondaryNode: j.Right(), - cond: j.Filter, - jType: j.Op, - rowSize: len(parentRow) + len(j.Left().Schema()) + len(j.Right().Schema()), - scopeLen: j.ScopeLen, - b: b, + joinState: js, }), nil } -func (i *lateralJoinIterator) loadPrimary(ctx *sql.Context) error { - lRow, err := i.primary.Next(ctx) - if err != nil { - return err - } - copy(i.primaryRow[i.parentLen:], lRow) - i.foundMatch = false - return nil -} - func (i *lateralJoinIterator) buildSecondary(ctx *sql.Context) error { - prepended, _, err := transform.Node(i.secondaryNode, plan.PrependRowInPlan(i.primaryRow, true)) + prepended, _, err := transform.Node(i.secondaryProvider, plan.PrependRowInPlan(i.primaryRow[i.parentLen:], true)) if err != nil { return err } - iter, err := i.b.Build(ctx, prepended, i.primaryRow) - if err != nil { - return err - } - i.secondary = iter - return nil -} - -func (i 
*lateralJoinIterator) loadSecondary(ctx *sql.Context) error { - sRow, err := i.secondary.Next(ctx) + iter, err := i.builder.Build(ctx, prepended, i.primaryRow) if err != nil { return err } - i.secondaryRow = sRow[len(i.primaryRow):] + i.secondaryRowIter = iter return nil } -func (i *lateralJoinIterator) buildRow(primaryRow, secondaryRow sql.Row) sql.Row { - row := make(sql.Row, i.rowSize) - copy(row, primaryRow) - copy(row[len(primaryRow):], secondaryRow) - return row -} - -func (i *lateralJoinIterator) removeParentRow(r sql.Row) sql.Row { - copy(r[i.scopeLen:], r[i.parentLen:]) - r = r[:len(r)-i.parentLen+i.scopeLen] - return r -} - -func (i *lateralJoinIterator) reset(ctx *sql.Context) (err error) { - if i.secondary != nil { - err = i.secondary.Close(ctx) - i.secondary = nil - } - i.secondaryRow = nil - return -} - func (i *lateralJoinIterator) Next(ctx *sql.Context) (sql.Row, error) { for { // secondary being nil means we've exhausted all secondary rows for the current primary. - if i.secondary == nil { + if i.secondaryRowIter == nil { if err := i.loadPrimary(ctx); err != nil { return nil, err } @@ -830,21 +695,22 @@ func (i *lateralJoinIterator) Next(ctx *sql.Context) (sql.Row, error) { } if err := i.loadSecondary(ctx); err != nil { if errors.Is(err, io.EOF) { - if !i.foundMatch && i.jType == plan.JoinTypeLateralLeft { - res := i.buildRow(i.primaryRow, nil) - if resetErr := i.reset(ctx); resetErr != nil { + if !i.foundMatch && i.joinType == plan.JoinTypeLateralLeft { + res := make(sql.Row, i.fullRowSize()) + copy(res, i.primaryRow) + if resetErr := i.resetSecondaryIter(ctx); resetErr != nil { return nil, resetErr } - return i.removeParentRow(res), nil + return res, nil } - if resetErr := i.reset(ctx); resetErr != nil { + if resetErr := i.resetSecondaryIter(ctx); resetErr != nil { return nil, resetErr } continue } return nil, err } - row := i.buildRow(i.primaryRow, i.secondaryRow) + row := i.fullRow if i.cond != nil { if res, err := sql.EvaluateCondition(ctx, 
i.cond, row); err != nil { return nil, err @@ -854,23 +720,6 @@ func (i *lateralJoinIterator) Next(ctx *sql.Context) (sql.Row, error) { } i.foundMatch = true - return i.removeParentRow(row), nil - } -} - -func (i *lateralJoinIterator) Close(ctx *sql.Context) error { - var pErr, sErr error - if i.primary != nil { - pErr = i.primary.Close(ctx) + return row.Copy(), nil } - if i.secondary != nil { - sErr = i.secondary.Close(ctx) - } - if pErr != nil { - return pErr - } - if sErr != nil { - return sErr - } - return nil } diff --git a/sql/rowexec/merge_join.go b/sql/rowexec/merge_join.go index 53e99ac46e..3f0d791029 100644 --- a/sql/rowexec/merge_join.go +++ b/sql/rowexec/merge_join.go @@ -376,7 +376,7 @@ func (i *mergeJoinIter) incMatch(ctx *sql.Context) error { if !i.leftDone { // rightBuf has already been validated, we don't need compare - copySubslice(i.fullRow, i.rightBuf[i.bufI], i.scopeLen+i.parentLen+i.leftRowLen) + copy(i.fullRow[i.scopeLen+i.parentLen+i.leftRowLen:], i.rightBuf[i.bufI]) i.bufI++ return nil } @@ -476,8 +476,11 @@ func (i *mergeJoinIter) peekMatch(ctx *sql.Context, iter sql.RowIter) (bool, sql return false, nil, err } + // strip outer scope rows from peek + peek = peek[i.scopeLen:] + // check if lookahead valid - copySubslice(i.fullRow, peek, off) + copy(i.fullRow[off:], peek) res, err := i.cmp.Compare(ctx, i.fullRow) if expression.ErrNilOperand.Is(err) { // revert change to output row if no match @@ -487,7 +490,7 @@ func (i *mergeJoinIter) peekMatch(ctx *sql.Context, iter sql.RowIter) (bool, sql } if res != 0 { // revert change to output row if no match - copySubslice(i.fullRow, restore, off) + copy(i.fullRow[off:], restore) } return res == 0, peek, nil } @@ -499,9 +502,7 @@ func (i *mergeJoinIter) exhausted() bool { // copySubslice copies |src| into |dst| starting at index |off| func copySubslice(dst, src sql.Row, off int) { - for i, v := range src { - dst[off+i] = v - } + copy(dst[off:], src) } // incLeft updates |i.fullRow|'s left row @@ 
-520,12 +521,11 @@ func (i *mergeJoinIter) incLeft(ctx *sql.Context) error { } else if err != nil { return err } + // strip outer scope rows from row + row = row[i.scopeLen:] } - off := i.scopeLen + i.parentLen - for j, v := range row { - i.fullRow[off+j] = v - } + copy(i.fullRow[i.scopeLen+i.parentLen:], row) return nil } @@ -545,12 +545,11 @@ func (i *mergeJoinIter) incRight(ctx *sql.Context) error { } else if err != nil { return err } + // strip outer scope rows from row + row = row[i.scopeLen:] } - off := i.scopeLen + i.parentLen + i.leftRowLen - for j, v := range row { - i.fullRow[off+j] = v - } + copy(i.fullRow[i.scopeLen+i.parentLen+i.leftRowLen:], row) return nil } diff --git a/sql/rowexec/node_builder.gen.go b/sql/rowexec/node_builder.gen.go index de22144b89..8277816f17 100644 --- a/sql/rowexec/node_builder.gen.go +++ b/sql/rowexec/node_builder.gen.go @@ -326,8 +326,6 @@ func (b *BaseBuilder) buildNodeExecNoAnalyze(ctx *sql.Context, n sql.Node, row s return b.buildLoadData(ctx, n, row) case *plan.ShowCharset: return b.buildShowCharset(ctx, n, row) - case *plan.StripRowNode: - return b.buildStripRowNode(ctx, n, row) case *plan.DropConstraint: return b.buildDropConstraint(ctx, n, row) case *plan.FlushPrivileges: diff --git a/sql/rowexec/other.go b/sql/rowexec/other.go index 5eacc932d3..e715835873 100644 --- a/sql/rowexec/other.go +++ b/sql/rowexec/other.go @@ -24,18 +24,6 @@ import ( "github.com/dolthub/go-mysql-server/sql/types" ) -func (b *BaseBuilder) buildStripRowNode(ctx *sql.Context, n *plan.StripRowNode, row sql.Row) (sql.RowIter, error) { - childIter, err := b.buildNodeExec(ctx, n.Child, row) - if err != nil { - return nil, err - } - - return &stripRowIter{ - childIter, - n.NumCols, - }, nil -} - func (b *BaseBuilder) buildConcat(ctx *sql.Context, n *plan.Concat, row sql.Row) (sql.RowIter, error) { span, ctx := ctx.Span("plan.Concat") li, err := b.buildNodeExec(ctx, n.Left(), row) diff --git a/sql/rowexec/other_iters.go b/sql/rowexec/other_iters.go 
index 0df2369196..543b237d4c 100644 --- a/sql/rowexec/other_iters.go +++ b/sql/rowexec/other_iters.go @@ -368,20 +368,3 @@ func (ci *concatIter) Close(ctx *sql.Context) error { return nil } } - -type stripRowIter struct { - sql.RowIter - numCols int -} - -func (sri *stripRowIter) Next(ctx *sql.Context) (sql.Row, error) { - r, err := sri.RowIter.Next(ctx) - if err != nil { - return nil, err - } - return r[sri.numCols:], nil -} - -func (sri *stripRowIter) Close(ctx *sql.Context) error { - return sri.RowIter.Close(ctx) -}