Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions enginetest/engine_only_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -664,10 +664,13 @@ func TestTableFunctions(t *testing.T) {
memory.NormalDistTable{})

engine := enginetest.NewEngineWithProvider(t, harness, testDatabaseProvider)
harness = harness.WithProvider(engine.Analyzer.Catalog.DbProvider)

engine.EngineAnalyzer().ExecBuilder = rowexec.DefaultBuilder

engine, err := enginetest.RunSetupScripts(harness.NewContext(), engine, setup.MydbData, true)
require.NoError(t, err)
_ = harness.NewSession()

for _, test := range queries.TableFunctionScriptTests {
enginetest.TestScriptWithEngine(t, engine, harness, test)
Expand Down
14 changes: 10 additions & 4 deletions enginetest/memory_harness.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type MemoryHarness struct {
parallelism int
numTablePartitions int
readonly bool
provider *memory.DbProvider
provider sql.DatabaseProvider
indexDriverInitializer IndexDriverInitializer
driver sql.IndexDriver
nativeIndexSupport bool
Expand Down Expand Up @@ -282,7 +282,7 @@ func (m *MemoryHarness) IndexDriver(dbs []sql.Database) sql.IndexDriver {
func (m *MemoryHarness) newDatabase(name string) sql.Database {
ctx := m.NewContext()

err := m.getProvider().CreateDatabase(ctx, name)
err := m.getProvider().(*memory.DbProvider).CreateDatabase(ctx, name)
if err != nil {
panic(err)
}
Expand All @@ -291,7 +291,13 @@ func (m *MemoryHarness) newDatabase(name string) sql.Database {
return db
}

func (m *MemoryHarness) getProvider() *memory.DbProvider {
// WithProvider returns a shallow copy of this harness whose database
// provider has been replaced by the one given. The receiver is left
// unmodified, so callers can fork a harness per test.
// NOTE(review): the shallow copy duplicates every harness field by value,
// including any mutex embedded directly in the struct — confirm this is
// safe (go vet copylocks) against the full MemoryHarness definition.
func (m *MemoryHarness) WithProvider(provider sql.DatabaseProvider) *MemoryHarness {
	clone := *m
	clone.provider = provider
	return &clone
}

func (m *MemoryHarness) getProvider() sql.DatabaseProvider {
m.mu.Lock()
defer m.mu.Unlock()

Expand All @@ -309,7 +315,7 @@ func (m *MemoryHarness) NewDatabaseProvider() sql.MutableDatabaseProvider {
}

// Provider returns the harness's database provider as the concrete
// *memory.DbProvider type.
//
// Because WithProvider accepts any sql.DatabaseProvider, the stored
// provider may no longer be a *memory.DbProvider. In that case we panic
// with a descriptive message rather than letting the bare type assertion
// produce an opaque runtime panic, which makes test failures easier to
// diagnose.
func (m *MemoryHarness) Provider() *memory.DbProvider {
	prov := m.getProvider()
	memProv, ok := prov.(*memory.DbProvider)
	if !ok {
		panic("MemoryHarness.Provider: harness provider is not a *memory.DbProvider; was it replaced via WithProvider?")
	}
	return memProv
}

func (m *MemoryHarness) NewDatabases(names ...string) []sql.Database {
Expand Down
10 changes: 10 additions & 0 deletions enginetest/queries/table_func_scripts.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@ import (
)

var TableFunctionScriptTests = []ScriptTest{
{
SetUpScript: []string{
"create database if not exists mydb",
"use mydb",
"create table xy (x int primary key, y int)",
"insert into xy values (0,1), (1,2), (2,3)",
},
Query: "select y from table_func('z',2) join xy t on y = z",
Expected: []sql.Row{{2}},
},
{
Query: "select * from sequence_table('y',2) seq1 where y in (select SEQ2.x from table_func('x', 1) seq2)",
Expected: []sql.Row{{1}},
Expand Down
127 changes: 75 additions & 52 deletions sql/analyzer/fix_exec_indexes.go
Original file line number Diff line number Diff line change
Expand Up @@ -413,58 +413,7 @@ func (s *idxScope) finalizeSelf(n sql.Node) (sql.Node, error) {
nn.OnDupExprs = s.expressions
return nn.WithChecks(s.checks), nil
default:

// fill in column ids
switch n := n.(type) {
case sql.Projector:
for _, e := range n.ProjectedExprs() {
if ide, ok := e.(sql.IdExpression); ok {
s.ids = append(s.ids, ide.Id())
} else {
s.ids = append(s.ids, 0)
}
}
case *plan.ResolvedTable, *plan.IndexedTableAccess, *plan.SubqueryAlias, *plan.RecursiveTable, *plan.RecursiveCte, *plan.SetOp, *plan.ValueDerivedTable, *plan.JSONTable:
if rt, ok := n.(*plan.ResolvedTable); ok && plan.IsDualTable(rt.Table) {
s.ids = append(s.ids, 0)
break
}
cols := n.(plan.TableIdNode).Columns()
if tn, ok := n.(sql.TableNode); ok {
if pkt, ok := tn.UnderlyingTable().(sql.PrimaryKeyTable); ok && len(pkt.PrimaryKeySchema().Schema) != len(n.Schema()) {
firstcol, _ := cols.Next(1)
for _, c := range n.Schema() {
ord := pkt.PrimaryKeySchema().IndexOfColName(c.Name)
colId := firstcol + sql.ColumnId(ord)
s.ids = append(s.ids, colId)
}
break
}
}
cols.ForEach(func(col sql.ColumnId) {
s.ids = append(s.ids, col)
})

case *plan.TableCountLookup:
s.ids = append(s.ids, n.Id())
case *plan.JoinNode:
if n.Op.IsPartial() {
s.ids = append(s.ids, s.childScopes[0].ids...)
} else {
s.ids = append(s.ids, s.childScopes[0].ids...)
s.ids = append(s.ids, s.childScopes[1].ids...)
}
case *plan.ShowStatus:
for i := range n.Schema() {
s.ids = append(s.ids, sql.ColumnId(i+1))
}
case *plan.Concat:
s.ids = append(s.ids, s.childScopes[0].ids...)
default:
for _, cs := range s.childScopes {
s.ids = append(s.ids, cs.ids...)
}
}
s.ids = columnIdsForNode(n)

s.addSchema(n.Schema())
var err error
Expand Down Expand Up @@ -505,6 +454,80 @@ func (s *idxScope) finalizeSelf(n sql.Node) (sql.Node, error) {
}
}

// columnIdsForNode collects the column ids of a node's return schema.
// Projector nodes can return a subset of the full sql.PrimaryTableSchema.
// todo: pruning projections should update plan.TableIdNode .Columns()
// to avoid schema/column discontinuities.
// columnIdsForNode collects the column ids of a node's return schema.
// Projector nodes can return a subset of the full sql.PrimaryTableSchema.
// A zero id is used as a sentinel for positions that have no column id.
// todo: pruning projections should update plan.TableIdNode .Columns()
// to avoid schema/column discontinuities.
func columnIdsForNode(n sql.Node) []sql.ColumnId {
	var ret []sql.ColumnId
	switch n := n.(type) {
	case sql.Projector:
		// A projector's output ids come from its projection expressions;
		// expressions without an id contribute the zero sentinel.
		for _, e := range n.ProjectedExprs() {
			if ide, ok := e.(sql.IdExpression); ok {
				ret = append(ret, ide.Id())
			} else {
				ret = append(ret, 0)
			}
		}
	case *plan.TableCountLookup:
		ret = append(ret, n.Id())
	case *plan.TableAlias:
		// Table alias's child either exposes 1) child ids or 2) is custom
		// table function. We currently do not update table columns in response
		// to table pruning, so we need to manually distinguish these cases.
		// todo: prune columns should update column ids and table alias ids
		switch n.Child.(type) {
		case sql.TableFunction:
			// todo: table functions that implement sql.Projector are not going
			// to work. Need to fix prune.
			n.Columns().ForEach(func(col sql.ColumnId) {
				ret = append(ret, col)
			})
		default:
			ret = append(ret, columnIdsForNode(n.Child)...)
		}
	case plan.TableIdNode:
		// The dual table has no real columns; emit the zero sentinel.
		if rt, ok := n.(*plan.ResolvedTable); ok && plan.IsDualTable(rt.Table) {
			ret = append(ret, 0)
			break
		}

		// n is already a plan.TableIdNode in this arm, so no re-assertion
		// is needed to reach Columns().
		cols := n.Columns()
		if tn, ok := n.(sql.TableNode); ok {
			// When the visible schema was pruned relative to the primary key
			// schema, map each remaining column name back to its original
			// ordinal so ids stay aligned with the unpruned table.
			if pkt, ok := tn.UnderlyingTable().(sql.PrimaryKeyTable); ok && len(pkt.PrimaryKeySchema().Schema) != len(n.Schema()) {
				firstcol, _ := cols.Next(1)
				for _, c := range n.Schema() {
					ord := pkt.PrimaryKeySchema().IndexOfColName(c.Name)
					colId := firstcol + sql.ColumnId(ord)
					ret = append(ret, colId)
				}
				break
			}
		}
		cols.ForEach(func(col sql.ColumnId) {
			ret = append(ret, col)
		})
	case *plan.JoinNode:
		// Every join returns the left side's columns; partial (semi/anti)
		// joins return only those, so the right side is conditional.
		ret = append(ret, columnIdsForNode(n.Left())...)
		if !n.Op.IsPartial() {
			ret = append(ret, columnIdsForNode(n.Right())...)
		}
	case *plan.ShowStatus:
		// ShowStatus has no table-backed ids; synthesize 1-based ordinals.
		for i := range n.Schema() {
			ret = append(ret, sql.ColumnId(i+1))
		}
	case *plan.Concat:
		// Both sides of a concat share column ids; the left side suffices.
		ret = append(ret, columnIdsForNode(n.Left())...)
	default:
		// Fall back to the concatenation of all child ids.
		for _, c := range n.Children() {
			ret = append(ret, columnIdsForNode(c)...)
		}
	}
	return ret
}

func fixExprToScope(e sql.Expression, scopes ...*idxScope) sql.Expression {
newScope := &idxScope{}
for _, s := range scopes {
Expand Down