Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

thread context through to function creation #455

Merged
merged 4 commits into from
Jun 8, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func (e *Engine) QueryNodeWithBindings(
}

if len(bindings) > 0 {
parsed, err = plan.ApplyBindings(parsed, bindings)
parsed, err = plan.ApplyBindings(ctx, parsed, bindings)
if err != nil {
return nil, nil, err
}
Expand Down
2 changes: 2 additions & 0 deletions enginetest/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ func TestAnalyzer(t *testing.T) {
db, err := engine.Catalog.Database("foo")
require.NoError(t, err)
greatest, err := function.NewGreatest(
sql.NewEmptyContext(),
expression.NewLiteral("abc123", sql.LongText),
expression.NewLiteral("cde456", sql.LongText),
)
Expand All @@ -223,6 +224,7 @@ func TestAnalyzer(t *testing.T) {
db, err := engine.Catalog.Database("foo")
require.NoError(t, err)
timestamp, err := function.NewTimestamp(
sql.NewEmptyContext(),
expression.NewLiteral("20200101:120000Z", sql.LongText),
)
require.NoError(t, err)
Expand Down
4 changes: 2 additions & 2 deletions enginetest/enginetests.go
Original file line number Diff line number Diff line change
Expand Up @@ -2950,7 +2950,7 @@ func (c customFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
return int64(5), nil
}

func (c customFunc) WithChildren(children ...sql.Expression) (sql.Expression, error) {
func (c customFunc) WithChildren(ctx *sql.Context, children ...sql.Expression) (sql.Expression, error) {
return &customFunc{expression.UnaryExpression{children[0]}}, nil
}

Expand All @@ -2960,7 +2960,7 @@ func TestColumnDefaults(t *testing.T, harness Harness) {

err := e.Catalog.Register(sql.Function1{
Name: "customfunc",
Fn: func(e1 sql.Expression) sql.Expression {
Fn: func(ctx *sql.Context, e1 sql.Expression) sql.Expression {
return &customFunc{expression.UnaryExpression{e1}}
},
})
Expand Down
2 changes: 1 addition & 1 deletion memory/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ func (t *Table) addColumnToSchema(ctx *sql.Context, newCol *sql.Column, order *s
if i == newColIdx {
continue
}
newDefault, _ := expression.TransformUp(newSchCol.Default, func(expr sql.Expression) (sql.Expression, error) {
newDefault, _ := expression.TransformUp(ctx, newSchCol.Default, func(expr sql.Expression) (sql.Expression, error) {
if expr, ok := expr.(*expression.GetField); ok {
return expr.WithIndex(newSch.IndexOf(expr.Name(), t.name)), nil
}
Expand Down
16 changes: 8 additions & 8 deletions sql/analyzer/aggregations.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,21 @@ func flattenAggregationExpressions(ctx *sql.Context, a *Analyzer, n sql.Node, sc
return n, nil
}

return flattenedWindow(n.SelectExprs, n.Child)
return flattenedWindow(ctx, n.SelectExprs, n.Child)
case *plan.GroupBy:
if !hasHiddenAggregations(n.SelectedExprs) {
return n, nil
}

return flattenedGroupBy(n.SelectedExprs, n.GroupByExprs, n.Child)
return flattenedGroupBy(ctx, n.SelectedExprs, n.GroupByExprs, n.Child)
default:
return n, nil
}
})
}

func flattenedGroupBy(projection, grouping []sql.Expression, child sql.Node) (sql.Node, error) {
newProjection, newAggregates, err := replaceAggregatesWithGetFieldProjections(projection)
func flattenedGroupBy(ctx *sql.Context, projection, grouping []sql.Expression, child sql.Node) (sql.Node, error) {
newProjection, newAggregates, err := replaceAggregatesWithGetFieldProjections(ctx, projection)
if err != nil {
return nil, err
}
Expand All @@ -72,13 +72,13 @@ func flattenedGroupBy(projection, grouping []sql.Expression, child sql.Node) (sq
// new set of project expressions, and the new set of aggregations. The former always matches the size of the projection
// expressions passed in. The latter will have the size of the number of aggregate expressions contained in the input
// slice.
func replaceAggregatesWithGetFieldProjections(projection []sql.Expression) (projections, aggregations []sql.Expression, err error) {
func replaceAggregatesWithGetFieldProjections(ctx *sql.Context, projection []sql.Expression) (projections, aggregations []sql.Expression, err error) {
var newProjection = make([]sql.Expression, len(projection))
var newAggregates []sql.Expression

for i, p := range projection {
var transformed bool
e, err := expression.TransformUp(p, func(e sql.Expression) (sql.Expression, error) {
e, err := expression.TransformUp(ctx, p, func(e sql.Expression) (sql.Expression, error) {
switch e := e.(type) {
case sql.Aggregation, sql.WindowAggregation:
// continue on
Expand Down Expand Up @@ -110,8 +110,8 @@ func replaceAggregatesWithGetFieldProjections(projection []sql.Expression) (proj
return newProjection, newAggregates, nil
}

func flattenedWindow(projection []sql.Expression, child sql.Node) (sql.Node, error) {
newProjection, newAggregates, err := replaceAggregatesWithGetFieldProjections(projection)
func flattenedWindow(ctx *sql.Context, projection []sql.Expression, child sql.Node) (sql.Node, error) {
newProjection, newAggregates, err := replaceAggregatesWithGetFieldProjections(ctx, projection)
if err != nil {
return nil, err
}
Expand Down
9 changes: 9 additions & 0 deletions sql/analyzer/aggregations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (

func TestFlattenAggregationExprs(t *testing.T) {
require := require.New(t)
ctx := sql.NewEmptyContext()

table := memory.NewTable("foo", sql.Schema{
{Name: "a", Type: sql.Int64, Source: "foo"},
Expand All @@ -47,6 +48,7 @@ func TestFlattenAggregationExprs(t *testing.T) {
[]sql.Expression{
expression.NewArithmetic(
aggregation.NewSum(
ctx,
expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false),
),
expression.NewLiteral(int64(1), sql.Int64),
Expand All @@ -70,6 +72,7 @@ func TestFlattenAggregationExprs(t *testing.T) {
plan.NewGroupBy(
[]sql.Expression{
aggregation.NewSum(
ctx,
expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false),
),
},
Expand All @@ -87,6 +90,7 @@ func TestFlattenAggregationExprs(t *testing.T) {
expression.NewAlias("x",
expression.NewArithmetic(
aggregation.NewSum(
ctx,
expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false),
),
expression.NewLiteral(int64(1), sql.Int64),
Expand All @@ -110,6 +114,7 @@ func TestFlattenAggregationExprs(t *testing.T) {
plan.NewGroupBy(
[]sql.Expression{
aggregation.NewSum(
ctx,
expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false),
),
},
Expand All @@ -126,9 +131,11 @@ func TestFlattenAggregationExprs(t *testing.T) {
[]sql.Expression{
expression.NewArithmetic(
aggregation.NewSum(
ctx,
expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false),
),
aggregation.NewCount(
ctx,
expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false),
),
"/",
Expand All @@ -153,9 +160,11 @@ func TestFlattenAggregationExprs(t *testing.T) {
plan.NewGroupBy(
[]sql.Expression{
aggregation.NewSum(
ctx,
expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false),
),
aggregation.NewCount(
ctx,
expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false),
),
expression.NewGetFieldWithTable(1, sql.Int64, "foo", "b", false),
Expand Down
8 changes: 4 additions & 4 deletions sql/analyzer/aliases.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,11 @@ func aliasesDefinedInNode(n sql.Node) []string {
// normalizeExpressions returns the expressions given after normalizing them to replace table and expression aliases
// with their underlying names. This is necessary to match such expressions against those declared by implementors of
// various interfaces that declare expressions to handle, such as Index.Expressions(), FilteredTable, etc.
func normalizeExpressions(tableAliases TableAliases, expr ...sql.Expression) []sql.Expression {
func normalizeExpressions(ctx *sql.Context, tableAliases TableAliases, expr ...sql.Expression) []sql.Expression {
expressions := make([]sql.Expression, len(expr))

for i, e := range expr {
expressions[i] = normalizeExpression(tableAliases, e)
expressions[i] = normalizeExpression(ctx, tableAliases, e)
}

return expressions
Expand All @@ -178,9 +178,9 @@ func normalizeExpressions(tableAliases TableAliases, expr ...sql.Expression) []s
// normalizeExpression returns the expression given after normalizing it to replace table aliases with their underlying
// names. This is necessary to match such expressions against those declared by implementors of various interfaces that
// declare expressions to handle, such as Index.Expressions(), FilteredTable, etc.
func normalizeExpression(tableAliases TableAliases, e sql.Expression) sql.Expression {
func normalizeExpression(ctx *sql.Context, tableAliases TableAliases, e sql.Expression) sql.Expression {
// If the query has table aliases, use them to replace any table aliases in column expressions
normalized, _ := expression.TransformUp(e, func(e sql.Expression) (sql.Expression, error) {
normalized, _ := expression.TransformUp(ctx, e, func(e sql.Expression) (sql.Expression, error) {
if field, ok := e.(*expression.GetField); ok {
table := field.Table()
if rt, ok := tableAliases[table]; ok {
Expand Down
2 changes: 1 addition & 1 deletion sql/analyzer/apply_indexes_for_subquery_comparisons.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func getIndexedInSubqueryFilter(ctx *sql.Context, a *Analyzer, left, right sql.E
return nil
}
defer indexes.releaseUsedIndexes()
idx := indexes.IndexByExpression(ctx, ctx.GetCurrentDatabase(), normalizeExpressions(tableAliases, gf)...)
idx := indexes.IndexByExpression(ctx, ctx.GetCurrentDatabase(), normalizeExpressions(ctx, tableAliases, gf)...)
if idx == nil {
return nil
}
Expand Down
8 changes: 4 additions & 4 deletions sql/analyzer/apply_indexes_from_outer_scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func getOuterScopeIndexes(
for table, idx := range indexes {
if exprsByTable[table] != nil {
// creating a key expression can fail in some cases, just skip this table
keyExpr := createIndexKeyExpr(idx, exprsByTable[table], tableAliases)
keyExpr := createIndexKeyExpr(ctx, idx, exprsByTable[table], tableAliases)
if keyExpr == nil {
continue
}
Expand All @@ -168,14 +168,14 @@ func getOuterScopeIndexes(
}

// createIndexKeyExpr returns a slice of expressions to be used when creating an index lookup key for the table given.
func createIndexKeyExpr(idx sql.Index, joinExprs []*joinColExpr, tableAliases TableAliases) []sql.Expression {
func createIndexKeyExpr(ctx *sql.Context, idx sql.Index, joinExprs []*joinColExpr, tableAliases TableAliases) []sql.Expression {

keyExprs := make([]sql.Expression, len(idx.Expressions()))

IndexExpressions:
for i, idxExpr := range idx.Expressions() {
for j := range joinExprs {
if idxExpr == normalizeExpression(tableAliases, joinExprs[j].colExpr).String() {
if idxExpr == normalizeExpression(ctx, tableAliases, joinExprs[j].colExpr).String() {
keyExprs[i] = joinExprs[j].comparand
continue IndexExpressions
}
Expand Down Expand Up @@ -234,7 +234,7 @@ func getSubqueryIndexes(
indexCols := exprsByTable[table]
if indexCols != nil {
idx := ia.IndexByExpression(ctx, ctx.GetCurrentDatabase(),
normalizeExpressions(tableAliases, extractComparands(indexCols)...)...)
normalizeExpressions(ctx, tableAliases, extractComparands(indexCols)...)...)
if idx != nil {
result[indexCols[0].comparandCol.Table()] = idx
}
Expand Down
2 changes: 1 addition & 1 deletion sql/analyzer/check_constraints.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ func loadChecks(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (sql.No
transformedChecks := make(sql.CheckConstraints, len(checks))

for i, check := range checks {
newExpr, err := expression.TransformUp(check.Expr, func(e sql.Expression) (sql.Expression, error) {
newExpr, err := expression.TransformUp(ctx, check.Expr, func(e sql.Expression) (sql.Expression, error) {
if t, ok := e.(*expression.UnresolvedColumn); ok {
name := t.Name()
strings.Replace(name, "`", "", -1) // remove any preexisting backticks
Expand Down
5 changes: 3 additions & 2 deletions sql/analyzer/expand_stars_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
)

func TestExpandStars(t *testing.T) {
ctx := sql.NewEmptyContext()
f := getRule("expand_stars")

table := memory.NewTable("mytable", sql.Schema{
Expand Down Expand Up @@ -152,7 +153,7 @@ func TestExpandStars(t *testing.T) {
[]sql.Expression{
expression.NewStar(),
expression.NewGetFieldWithTable(1, sql.Int32, "mytable", "b", false),
mustExpr(window.NewRowNumber().(*window.RowNumber).WithWindow(
mustExpr(window.NewRowNumber(ctx).(*window.RowNumber).WithWindow(
sql.NewWindow(
[]sql.Expression{
expression.NewGetFieldWithTable(1, sql.Int32, "mytable", "b", false),
Expand All @@ -172,7 +173,7 @@ func TestExpandStars(t *testing.T) {
expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "a", false),
expression.NewGetFieldWithTable(1, sql.Int32, "mytable", "b", false),
expression.NewGetFieldWithTable(1, sql.Int32, "mytable", "b", false),
mustExpr(window.NewRowNumber().(*window.RowNumber).WithWindow(
mustExpr(window.NewRowNumber(ctx).(*window.RowNumber).WithWindow(
sql.NewWindow(
[]sql.Expression{
expression.NewGetFieldWithTable(1, sql.Int32, "mytable", "b", false),
Expand Down
12 changes: 6 additions & 6 deletions sql/analyzer/filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,19 +120,19 @@ func newFilterSet(filtersByTable filtersByTable, tableAliases TableAliases) *fil

// availableFiltersForTable returns the filters that are still available for the table given (not previously marked
// handled)
func (fs *filterSet) availableFiltersForTable(table string) []sql.Expression {
func (fs *filterSet) availableFiltersForTable(ctx *sql.Context, table string) []sql.Expression {
filters, ok := fs.filtersByTable[table]
if !ok {
return nil
}
return fs.subtractUsedIndexes(subtractExprSet(filters, fs.handledFilters))
return fs.subtractUsedIndexes(ctx, subtractExprSet(filters, fs.handledFilters))
}

// availableFilters returns the filters that are still available (not previously marked handled)
func (fs *filterSet) availableFilters() []sql.Expression {
func (fs *filterSet) availableFilters(ctx *sql.Context) []sql.Expression {
var available []sql.Expression
for _, es := range fs.filtersByTable {
available = append(available, fs.subtractUsedIndexes(subtractExprSet(es, fs.handledFilters))...)
available = append(available, fs.subtractUsedIndexes(ctx, subtractExprSet(es, fs.handledFilters))...)
}
return available
}
Expand Down Expand Up @@ -191,13 +191,13 @@ func subtractExprSet(all, toSubtract []sql.Expression) []sql.Expression {
}

// subtractUsedIndexes returns the filter expressions given with used indexes subtracted off.
func (fs *filterSet) subtractUsedIndexes(all []sql.Expression) []sql.Expression {
func (fs *filterSet) subtractUsedIndexes(ctx *sql.Context, all []sql.Expression) []sql.Expression {
var remainder []sql.Expression

// Careful: index expressions are always normalized (contain actual table names), whereas filter expressions can
// contain aliases for both expressions and table names. We want to normalize all expressions for comparison, but
// return the original expressions.
normalized := normalizeExpressions(fs.tableAliases, all...)
normalized := normalizeExpressions(ctx, fs.tableAliases, all...)

for i, e := range normalized {
var found bool
Expand Down
2 changes: 1 addition & 1 deletion sql/analyzer/filters_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ func TestExprToTableFilters(t *testing.T) {
assert.Error(t, err)

_, err = exprToTableFilters(expression.NewAnd(
expression.NewEquals(lit(1), mustExpr(function.NewRand())),
expression.NewEquals(lit(1), mustExpr(function.NewRand(sql.NewEmptyContext()))),
expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "f", false),
))
assert.Error(t, err)
Expand Down
Loading