diff --git a/go/libraries/doltcore/sqle/dfunctions/init.go b/go/libraries/doltcore/sqle/dfunctions/init.go index c8cc1be1fcf..6fb05d32940 100644 --- a/go/libraries/doltcore/sqle/dfunctions/init.go +++ b/go/libraries/doltcore/sqle/dfunctions/init.go @@ -26,6 +26,7 @@ var DoltFunctions = []sql.Function{ sql.Function2{Name: HasAncestorFuncName, Fn: NewHasAncestor}, sql.Function1{Name: HashOfTableFuncName, Fn: NewHashOfTable}, sql.FunctionN{Name: HashOfDatabaseFuncName, Fn: NewHashOfDatabase}, + sql.Function1{Name: JoinCostFuncName, Fn: NewJoinCost}, } // DolthubApiFunctions are the DoltFunctions that get exposed to Dolthub Api. diff --git a/go/libraries/doltcore/sqle/dfunctions/join_cost.go b/go/libraries/doltcore/sqle/dfunctions/join_cost.go new file mode 100644 index 00000000000..6b4416acc64 --- /dev/null +++ b/go/libraries/doltcore/sqle/dfunctions/join_cost.go @@ -0,0 +1,131 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package dfunctions + +import ( + "fmt" + "strings" + + gms "github.com/dolthub/go-mysql-server" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/plan" + "github.com/dolthub/go-mysql-server/sql/planbuilder" + "github.com/dolthub/go-mysql-server/sql/types" + + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" +) + +const JoinCostFuncName = "dolt_join_cost" + +type JoinCost struct { + q sql.Expression +} + +var _ sql.FunctionExpression = (*JoinCost)(nil) +var _ sql.CollationCoercible = (*JoinCost)(nil) + +// NewJoinCost returns a new JoinCost expression. +func NewJoinCost(e sql.Expression) sql.Expression { + return &JoinCost{q: e} +} + +// FunctionName implements sql.FunctionExpression +func (c *JoinCost) FunctionName() string { + return "JoinCost" +} + +// Description implements sql.FunctionExpression +func (c *JoinCost) Description() string { + return "print the memo tree" +} + +// Type implements the Expression interface. +func (c *JoinCost) Type() sql.Type { return types.LongText } + +// CollationCoercibility implements the interface sql.CollationCoercible. +func (*JoinCost) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { + return ctx.GetCollation(), 4 +} + +// IsNullable implements the Expression interface. +func (c *JoinCost) IsNullable() bool { + return false +} + +func (c *JoinCost) String() string { + return fmt.Sprintf("%s(%s)", c.FunctionName(), c.q) +} + +// Eval implements the Expression interface. 
+//
+// The argument is evaluated to a query string, which is parsed and analyzed
+// (not executed) against a new engine built on the session's provider. The
+// result is the analysis scope's JoinTrees joined with newlines. Errors from
+// evaluating the argument or parsing the query are returned; an analysis
+// error is only logged at debug level.
+func (c *JoinCost) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
+	q, err := exprToStringLit(ctx, c.q)
+	if err != nil {
+		return nil, err
+	}
+
+	// A new engine is constructed on every call, on top of the provider
+	// owned by the current Dolt session.
+	dSess := dsess.DSessFromSess(ctx.Session)
+	eng := gms.NewDefault(dSess.Provider())
+
+	binder := planbuilder.New(ctx, eng.Analyzer.Catalog, eng.EventScheduler, eng.Parser)
+	parsed, _, _, qFlags, err := binder.Parse(q, nil, false)
+	if err != nil {
+		return nil, err
+	}
+
+	// The analyzer is handed this scope and scope.JoinTrees is read back
+	// afterwards. An analysis failure is deliberately not returned: whatever
+	// was recorded in the scope before the failure is still reported.
+	scope := plan.Scope{}
+	if _, err = eng.Analyzer.Analyze(ctx, parsed, &scope, qFlags); err != nil {
+		// Debugf rather than Debug: Debug would print the message and the
+		// error with no separator between them.
+		ctx.GetLogger().Debugf("join cost error: %v", err)
+	}
+
+	return strings.Join(scope.JoinTrees, "\n"), nil
+}
+
+// Resolved implements the Expression interface. It always reports true.
+func (c *JoinCost) Resolved() bool {
+	return true
+}
+
+// Children implements the Expression interface. The query argument is not
+// exposed as a child, so no children are reported.
+func (c *JoinCost) Children() []sql.Expression {
+	return nil
+}
+
+// WithChildren implements the Expression interface.
+//
+// JoinCost reports no children, so any non-empty children list is rejected
+// with ErrInvalidChildrenNumber (expected count 0); otherwise the receiver is
+// returned unchanged.
+func (c *JoinCost) WithChildren(children ...sql.Expression) (sql.Expression, error) {
+	if len(children) != 0 {
+		return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 0)
+	}
+	return c, nil
+}
+
+// exprToStringLit evaluates e with a nil row and returns the result as a
+// string with leading and trailing whitespace trimmed. It returns an error
+// if the evaluation fails or if the result is not a Go string.
+func exprToStringLit(ctx *sql.Context, e sql.Expression) (string, error) {
+	q, err := e.Eval(ctx, nil)
+	if err != nil {
+		return "", err
+	}
+	qStr, isStr := q.(string)
+	if !isStr {
+		return "", fmt.Errorf("query must be a string, not %T", q)
+	}
+	return strings.TrimSpace(qStr), nil
+}
diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_queries.go b/go/libraries/doltcore/sqle/enginetest/dolt_queries.go index 28b3a76bb00..601ab8ab0c3 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_queries.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_queries.go @@ -735,6 +735,33 @@ var DoltScripts = []queries.ScriptTest{ }, }, }, + { + Name: "dolt_join_cost tests", + SetUpScript: []string{ + "create table xy (x int primary key, y varchar(10))", + "create table ab (a int primary key, b varchar(10))", + "create table cd (c int primary key, d varchar(10))", + "insert into xy values (0,'0'), (1,'1'), (2,'2')", + "insert into ab values (0,'0'), (1,'1'), (2,'2')", + "insert into cd values (0,'0'), (1,'1'), (2,'2')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "select dolt_join_cost('select * from ab, cd, xy where a = c and b = d and y = d')", + Expected: []sql.Row{ + {`memo: +├── G1: (tablescan: ab 0.0)* +├── G2: (tablescan: cd 0.0)* +├── G3: (hashjoin 1 2 12.1) (hashjoin 2 1 12.1) (mergejoin 1 2 6.1)* (mergejoin 2 1 6.1)* (lookupjoin 1 2 9.9) (lookupjoin 2 1 9.9) (innerjoin 2 1 10.1) (innerjoin 1 2 10.1) +├── G4: (tablescan: xy 0.0)* +├── G5: (hashjoin 3 4 12.1) (hashjoin 1 7 12.1) (hashjoin 7 1 12.1) (hashjoin 2 6 12.1) (hashjoin 6 2 12.1) (hashjoin 4 3 12.1) (lookupjoin 7 1 9.9) (lookupjoin 6 2 9.9) (innerjoin 4 3 10.1)* (innerjoin 6 2 10.1) (innerjoin 2 6 10.1) (innerjoin 7 1 10.1) (innerjoin 1 7 10.1) (innerjoin 3 4 10.1)* +├── G6: 
(hashjoin 1 4 12.1) (hashjoin 4 1 12.1) (innerjoin 4 1 10.1)* (innerjoin 1 4 10.1)* +└── G7: (hashjoin 2 4 12.1) (hashjoin 4 2 12.1) (innerjoin 4 2 10.1)* (innerjoin 2 4 10.1)* +`}, + }, + }, + }, + }, { Name: "dolt_diff.from_commit test", SetUpScript: []string{