diff --git a/go.mod b/go.mod
index f4f51f63c5..0d4bb1c2ab 100644
--- a/go.mod
+++ b/go.mod
@@ -6,7 +6,7 @@ require (
 	github.com/dolthub/go-icu-regex v0.0.0-20250916051405-78a38d478790
 	github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71
 	github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81
-	github.com/dolthub/vitess v0.0.0-20260108222406-f8a2587c4954
+	github.com/dolthub/vitess v0.0.0-20260109110924-205efc8530f1
 	github.com/go-sql-driver/mysql v1.9.3
 	github.com/gocraft/dbr/v2 v2.7.2
 	github.com/google/uuid v1.3.0
diff --git a/go.sum b/go.sum
index b95d2364c1..b99ed09ff6 100644
--- a/go.sum
+++ b/go.sum
@@ -18,8 +18,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE
 github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI=
 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE=
 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY=
-github.com/dolthub/vitess v0.0.0-20260108222406-f8a2587c4954 h1:VN2ZjnYPyxcAN/XetvcdumFbI2Ad/Gb47Qwdo9REY3A=
-github.com/dolthub/vitess v0.0.0-20260108222406-f8a2587c4954/go.mod h1:FLWqdXsAeeBQyFwDjmBVu0GnbjI2MKeRf3tRVdJEKlI=
+github.com/dolthub/vitess v0.0.0-20260109110924-205efc8530f1 h1:souetbYNBRHrt9y990VGD1jkzCIQ0jC+gxMdFOEjL+g=
+github.com/dolthub/vitess v0.0.0-20260109110924-205efc8530f1/go.mod h1:FLWqdXsAeeBQyFwDjmBVu0GnbjI2MKeRf3tRVdJEKlI=
 github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
 github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
 github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
diff --git a/sql/errors.go b/sql/errors.go
index 806192c544..0be1672f38 100644
--- a/sql/errors.go
+++ b/sql/errors.go
@@ -964,6 +964,9 @@ var (
 
 	// ErrOnlyFDAndRBREventsAllowedInBinlogStatement is returned when an unsupported event type is used in a BINLOG statement.
 	ErrOnlyFDAndRBREventsAllowedInBinlogStatement = errors.NewKind("Only Format_description_log_event and row events are allowed in BINLOG statements (but %s was provided)")
+
+	// ErrDistinctOnMatchOrderBy is returned when DISTINCT ON does not match the initial ORDER BY expressions
+	ErrDistinctOnMatchOrderBy = errors.NewKind("SELECT DISTINCT ON expressions must match initial ORDER BY expressions")
 )
 
 // CastSQLError returns a *mysql.SQLError with the error code and in some cases, also a SQL state, populated for the
diff --git a/sql/iters/rel_iters.go b/sql/iters/rel_iters.go
index 24c581447d..1e259a6cea 100644
--- a/sql/iters/rel_iters.go
+++ b/sql/iters/rel_iters.go
@@ -551,14 +551,22 @@ func (i *sortIter) computeSortedRows(ctx *sql.Context) error {
 // result sets.
 type distinctIter struct {
 	childIter   sql.RowIter
+	hasher      DistinctHasher
 	seen        sql.KeyValueCache
 	DisposeFunc sql.DisposeFunc
 }
 
-func NewDistinctIter(ctx *sql.Context, child sql.RowIter) *distinctIter {
+// DistinctHasher computes the row hash that distinctIter uses for each row it reads from its child.
+type DistinctHasher interface {
+	HashOf(ctx *sql.Context, row sql.Row) (uint64, error)
+}
+
+// NewDistinctIter creates a distinctIter over child that hashes rows with hasher, backed by a history cache from ctx.Memory.
+func NewDistinctIter(ctx *sql.Context, child sql.RowIter, hasher DistinctHasher) *distinctIter {
 	cache, dispose := ctx.Memory.NewHistoryCache()
 	return &distinctIter{
 		childIter:   child,
+		hasher:      hasher,
 		seen:        cache,
 		DisposeFunc: dispose,
 	}
@@ -574,7 +582,7 @@ func (di *distinctIter) Next(ctx *sql.Context) (sql.Row, error) {
 			return nil, err
 		}
 
-		hash, err := hash.HashOf(ctx, nil, row)
+		hash, err := di.hasher.HashOf(ctx, row)
 		if err != nil {
 			return nil, err
 		}
diff --git a/sql/memo/exec_builder.go b/sql/memo/exec_builder.go
index 67addda622..2ef263b22d 100644
--- a/sql/memo/exec_builder.go
+++ b/sql/memo/exec_builder.go
@@ -22,7 +22,7 @@ func (b *ExecBuilder) buildRel(r RelExpr, children ...sql.Node) (sql.Node, error
 	}
 
 	// TODO: distinctOp doesn't seem to be propagated through all the time
-	return b.wrapInDistinct(n, r.Distinct())
+	return b.wrapInDistinct(n, r.Distinct(), r.DistinctOn())
 }
 
 func (b *ExecBuilder) buildInnerJoin(j *InnerJoin, children ...sql.Node) (sql.Node, error) {
@@ -73,7 +73,7 @@ func (b *ExecBuilder) buildRangeHeap(sr *RangeHeap, children ...sql.Node) (ret s
 	switch n := children[0].(type) {
 	case *plan.Distinct:
 		ret, err = b.buildRangeHeap(sr, n.Child)
-		ret = plan.NewDistinct(ret)
+		ret = plan.NewDistinct(ret, n.DistinctOn()...)
 	case *plan.OrderedDistinct:
 		ret, err = b.buildRangeHeap(sr, n.Child)
 		ret = plan.NewOrderedDistinct(ret)
@@ -233,7 +233,7 @@ func (b *ExecBuilder) buildIndexScan(i *IndexScan, children ...sql.Node) (sql.No
 		ret = i.Table
 	case *plan.Distinct:
 		ret, err = b.buildIndexScan(i, n.Child)
-		ret = plan.NewDistinct(ret)
+		ret = plan.NewDistinct(ret, n.DistinctOn()...)
 	case *plan.OrderedDistinct:
 		ret, err = b.buildIndexScan(i, n.Child)
 		ret = plan.NewOrderedDistinct(ret)
@@ -358,7 +358,7 @@ func (b *ExecBuilder) buildProject(r *Project, children ...sql.Node) (sql.Node,
 }
 
 func (b *ExecBuilder) buildDistinct(r *Distinct, children ...sql.Node) (sql.Node, error) {
-	return plan.NewDistinct(children[0]), nil
+	return plan.NewDistinct(children[0], r.distinctOn...), nil
 }
 
 func (b *ExecBuilder) buildFilter(r *Filter, children ...sql.Node) (sql.Node, error) {
@@ -366,10 +366,10 @@ func (b *ExecBuilder) buildFilter(r *Filter, children ...sql.Node) (sql.Node, er
 	return ret, nil
 }
 
-func (b *ExecBuilder) wrapInDistinct(n sql.Node, d distinctOp) (sql.Node, error) {
+func (b *ExecBuilder) wrapInDistinct(n sql.Node, d distinctOp, distinctOn []sql.Expression) (sql.Node, error) {
 	switch d {
 	case HashDistinctOp:
-		return plan.NewDistinct(n), nil
+		return plan.NewDistinct(n, distinctOn...), nil
 	case SortedDistinctOp:
 		return plan.NewOrderedDistinct(n), nil
 	case NoDistinctOp:
diff --git a/sql/memo/join_order_builder.go b/sql/memo/join_order_builder.go
index f52f6b62aa..34726ef4a5 100644
--- a/sql/memo/join_order_builder.go
+++ b/sql/memo/join_order_builder.go
@@ -226,6 +226,7 @@ func (j *joinOrderBuilder) populateSubgraph(n sql.Node) (vertexSet, edgeSet, *Ex
 	case *plan.Distinct:
 		_, _, group = j.populateSubgraph(n.Child)
 		group.RelProps.Distinct = HashDistinctOp
+		group.RelProps.DistinctOn = n.Expressions()
 	case *plan.Max1Row:
 		return j.buildMax1Row(n)
 	case *plan.JoinNode:
diff --git a/sql/memo/memo.go b/sql/memo/memo.go
index 93f621ac7a..e496ee67a6 100644
--- a/sql/memo/memo.go
+++ b/sql/memo/memo.go
@@ -469,11 +469,11 @@ func (m *Memo) optimizeMemoGroup(grp *ExprGroup) error {
 		}
 
 		if grp.RelProps.Distinct.IsHash() {
-			if sortedInputs(n) {
+			if sortedInputs(n) && len(grp.RelProps.DistinctOn) == 0 {
 				n.SetDistinct(SortedDistinctOp)
 				m.Tracer.Log("Plan %s: using sorted distinct", n)
 			} else {
-				n.SetDistinct(HashDistinctOp)
+				n.SetDistinct(HashDistinctOp, grp.RelProps.DistinctOn...)
 				d := &Distinct{Child: grp}
 				relCost += float64(m.statsForRel(m.Ctx, d).RowCount())
 				m.Tracer.Log("Plan %s: using hash distinct", n)
@@ -752,7 +752,8 @@ type RelExpr interface {
 	SetCost(c float64)
 	Cost() float64
 	Distinct() distinctOp
-	SetDistinct(distinctOp)
+	DistinctOn() []sql.Expression
+	SetDistinct(distinctOp, ...sql.Expression)
 }
 
 type relBase struct {
@@ -764,6 +765,8 @@ type relBase struct {
 	c float64
 	// d indicates a RelExpr should be checked for distinctness
 	d distinctOp
+	// distinctOn, when not empty, indicates the expressions that should be used for distinctness (otherwise it's the projections)
+	distinctOn []sql.Expression
 }
 
 // relKey is a quick identifier for avoiding duplicate work on the same
@@ -810,8 +813,13 @@ func (r *relBase) Distinct() distinctOp {
 	return r.d
 }
 
-func (r *relBase) SetDistinct(d distinctOp) {
+func (r *relBase) DistinctOn() []sql.Expression {
+	return r.distinctOn
+}
+
+func (r *relBase) SetDistinct(d distinctOp, on ...sql.Expression) {
 	r.d = d
+	r.distinctOn = on
 }
 
 func (r *relBase) Group() *ExprGroup {
diff --git a/sql/memo/rel_props.go b/sql/memo/rel_props.go
index 95d46e1475..c6c18fe1e5 100644
--- a/sql/memo/rel_props.go
+++ b/sql/memo/rel_props.go
@@ -41,6 +41,7 @@ type relProps struct {
 	tableNodes []plan.TableIdNode
 	sort       sql.SortFields
 	Distinct   distinctOp
+	DistinctOn []sql.Expression
 }
 
 func newRelProps(rel RelExpr) *relProps {
diff --git a/sql/plan/distinct.go b/sql/plan/distinct.go
index 05a0f2e6a9..2c0b643053 100644
--- a/sql/plan/distinct.go
+++ b/sql/plan/distinct.go
@@ -16,20 +16,24 @@ package plan
 
 import (
 	"github.com/dolthub/go-mysql-server/sql"
+	"github.com/dolthub/go-mysql-server/sql/hash"
 )
 
 // Distinct is a node that ensures all rows that come from it are unique.
 type Distinct struct {
 	UnaryNode
+	distinctOn []sql.Expression // If these have len() > 0, then these are used instead of the projection (Doltgres only)
 }
 
 var _ sql.Node = (*Distinct)(nil)
+var _ sql.Expressioner = (*Distinct)(nil)
 var _ sql.CollationCoercible = (*Distinct)(nil)
 
 // NewDistinct creates a new Distinct node.
-func NewDistinct(child sql.Node) *Distinct {
+func NewDistinct(child sql.Node, distinctOn ...sql.Expression) *Distinct {
 	return &Distinct{
-		UnaryNode: UnaryNode{Child: child},
+		UnaryNode:  UnaryNode{Child: child},
+		distinctOn: distinctOn,
 	}
 }
 
@@ -44,7 +48,22 @@ func (d *Distinct) WithChildren(children ...sql.Node) (sql.Node, error) {
 		return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 1)
 	}
 
-	return NewDistinct(children[0]), nil
+	return NewDistinct(children[0], d.distinctOn...), nil
+}
+
+// DistinctOn returns any DISTINCT ON expressions.
+func (d *Distinct) DistinctOn() []sql.Expression {
+	return d.distinctOn
+}
+
+// Expressions implements the interface sql.Expressioner.
+func (d *Distinct) Expressions() []sql.Expression {
+	return d.DistinctOn()
+}
+
+// WithExpressions implements the interface sql.Expressioner.
+func (d *Distinct) WithExpressions(exprs ...sql.Expression) (sql.Node, error) {
+	return NewDistinct(d.Child, exprs...), nil
 }
 
 func (d *Distinct) IsReadOnly() bool {
@@ -80,6 +99,39 @@ func (d Distinct) DebugString() string {
 	return p.String()
 }
 
+// Hasher returns a new DistinctHasher created from this Distinct node.
+func (d *Distinct) Hasher() DistinctHasher {
+	var hashingRow sql.Row
+	if len(d.distinctOn) > 0 {
+		hashingRow = make(sql.Row, len(d.distinctOn))
+	}
+	return DistinctHasher{
+		distinctOn: d.distinctOn,
+		hashingRow: hashingRow,
+	}
+}
+
+// DistinctHasher handles hashing for Distinct nodes, taking any DISTINCT ON expressions into account.
+type DistinctHasher struct {
+	distinctOn []sql.Expression
+	hashingRow sql.Row
+}
+
+// HashOf hashes the DISTINCT ON expressions evaluated against row (reusing hashingRow as scratch space), or the whole row when there are none.
+func (dh DistinctHasher) HashOf(ctx *sql.Context, row sql.Row) (uint64, error) {
+	if len(dh.distinctOn) > 0 {
+		var err error
+		for i, expr := range dh.distinctOn {
+			dh.hashingRow[i], err = expr.Eval(ctx, row)
+			if err != nil {
+				return 0, err
+			}
+		}
+		return hash.HashOf(ctx, nil, dh.hashingRow)
+	}
+	return hash.HashOf(ctx, nil, row)
+}
+
 // OrderedDistinct is a Distinct node optimized for sorted row sets.
 // It's 2 orders of magnitude faster and uses 2 orders of magnitude less memory.
 type OrderedDistinct struct {
diff --git a/sql/planbuilder/factory.go b/sql/planbuilder/factory.go
index 21f059c520..1bedb7df03 100644
--- a/sql/planbuilder/factory.go
+++ b/sql/planbuilder/factory.go
@@ -188,18 +188,39 @@ func (f *factory) buildTableAlias(name string, child sql.Node) (plan.TableIdNode
 }
 
 // buildDistinct will wrap the child node in a distinct node depending on the Sort nodes and Projections there.
-// if the sort fields are a subset of the projection fields
+// If the sort fields are a subset of the projection fields:
 //
-//	sort(project(table)) -> sort(distinct(project(table)))
+//	project(sort(table)) -> sort(distinct(project(table)))
 //
-// else
+// otherwise, it is:
 //
-//	sort(project(table)) -> distinct(sort(project(table)))
-func (f *factory) buildDistinct(child sql.Node, refsSubquery bool) (sql.Node, error) {
-	if proj, isProj := child.(*plan.Project); isProj {
+//	project(sort(table)) -> distinct(sort(project(table)))
+//
+// With DISTINCT ON columns, we may use columns that are not referenced by the projection, and therefore it must be
+// pushed under it. This also means that the sort node has to be pushed under distinct as well.
+func (f *factory) buildDistinct(child sql.Node, refsSubquery bool, distinctOn []sql.Expression) (sql.Node, error) {
+	if len(distinctOn) > 0 {
+		if proj, isProj := child.(*plan.Project); isProj {
+			if sort, isSort := proj.Child.(*plan.Sort); isSort {
+				dMap := make(map[string]struct{})
+				for _, expr := range distinctOn {
+					dMap[strings.ToLower(expr.String())] = struct{}{}
+				}
+				minMatching := min(len(distinctOn), len(sort.SortFields))
+				for i := 0; i < minMatching; i++ {
+					if _, ok := dMap[strings.ToLower(sort.SortFields[i].Column.String())]; !ok {
+						return nil, sql.ErrDistinctOnMatchOrderBy.New()
+					}
+				}
+			}
+			distinct := plan.NewDistinct(proj.Child, distinctOn...)
+			proj.Child = distinct
+			return proj, nil
+		}
+	} else if proj, isProj := child.(*plan.Project); isProj {
 		// TODO: if projection columns are just primary key, distinct is no-op
 		// TODO: distinct literals are just one row
 		if sort, isSort := proj.Child.(*plan.Sort); isSort {
 			projMap := make(map[string]struct{})
 			for _, p := range proj.Projections {
 				projMap[strings.ToLower(p.String())] = struct{}{}
@@ -218,12 +239,12 @@ func (f *factory) buildDistinct(child sql.Node, refsSubquery bool) (sql.Node, er
 				if err != nil {
 					return nil, err
 				}
-				sort.Child = plan.NewDistinct(proj)
+				sort.Child = plan.NewDistinct(proj, distinctOn...)
 				return sort, nil
 			}
 		}
 	}
-	return plan.NewDistinct(child), nil
+	return plan.NewDistinct(child, distinctOn...), nil
 }
 
 func (f *factory) buildSort(child sql.Node, exprs []sql.SortField, deps sql.ColSet, subquery bool) (sql.Node, error) {
diff --git a/sql/planbuilder/select.go b/sql/planbuilder/select.go
index c2d168fba6..4e9274c726 100644
--- a/sql/planbuilder/select.go
+++ b/sql/planbuilder/select.go
@@ -107,7 +107,7 @@ func (b *Builder) buildSelect(inScope *scope, s *ast.Select) (outScope *scope) {
 	b.buildProjection(outScope, projScope)
 	outScope = projScope
 
-	if err := b.buildDistinct(outScope, s.QueryOpts.Distinct); err != nil {
+	if err := b.buildDistinct(outScope, s.QueryOpts.Distinct, s.QueryOpts.DistinctOn); err != nil {
 		b.handleErr(err)
 	}
 
@@ -201,12 +201,16 @@ func (b *Builder) typeCoerceLiteral(e sql.Expression) sql.Expression {
 
 // buildDistinct creates a new plan.Distinct node if the query has a DISTINCT option.
 // If the query has both DISTINCT and ALL, an error is returned.
-func (b *Builder) buildDistinct(inScope *scope, distinct bool) error {
+func (b *Builder) buildDistinct(inScope *scope, distinct bool, distinctOn ast.Exprs) error {
 	if !distinct {
 		return nil
 	}
+	distinctOnExprs := make([]sql.Expression, len(distinctOn))
+	for i := range distinctOn {
+		distinctOnExprs[i] = b.buildScalar(inScope, distinctOn[i])
+	}
 	var err error
-	inScope.node, err = b.f.buildDistinct(inScope.node, inScope.refsSubquery)
+	inScope.node, err = b.f.buildDistinct(inScope.node, inScope.refsSubquery, distinctOnExprs)
 	return err
 }
 
diff --git a/sql/rowexec/rel.go b/sql/rowexec/rel.go
index a25f534de1..a3cb621996 100644
--- a/sql/rowexec/rel.go
+++ b/sql/rowexec/rel.go
@@ -772,7 +772,7 @@ func (b *BaseBuilder) buildDistinct(ctx *sql.Context, n *plan.Distinct, row sql.
 		return nil, err
 	}
 
-	return sql.NewSpanIter(span, iters.NewDistinctIter(ctx, it)), nil
+	return sql.NewSpanIter(span, iters.NewDistinctIter(ctx, it, n.Hasher())), nil
 }
 
 func (b *BaseBuilder) buildIndexedTableAccess(ctx *sql.Context, n *plan.IndexedTableAccess, row sql.Row) (sql.RowIter, error) {
@@ -841,11 +841,11 @@ func (b *BaseBuilder) buildSetOp(ctx *sql.Context, s *plan.SetOp, row sql.Row) (
 		return nil, err
 	}
 	if s.Distinct {
-		dIter := iters.NewDistinctIter(ctx, iter)
+		dIter := iters.NewDistinctIter(ctx, iter, plan.DistinctHasher{})
 		s.AddDispose(dIter.DisposeFunc)
 		iter = dIter
 
-		dIter2 := iters.NewDistinctIter(ctx, iter2)
+		dIter2 := iters.NewDistinctIter(ctx, iter2, plan.DistinctHasher{})
 		s.AddDispose(dIter2.DisposeFunc)
 		iter2 = dIter2
 	}
@@ -856,7 +856,7 @@ func (b *BaseBuilder) buildSetOp(ctx *sql.Context, s *plan.SetOp, row sql.Row) (
 	}
 
 	if s.Distinct && s.SetOpType != plan.ExceptType {
-		dIter := iters.NewDistinctIter(ctx, iter)
+		dIter := iters.NewDistinctIter(ctx, iter, plan.DistinctHasher{})
 		s.AddDispose(dIter.DisposeFunc)
 		iter = dIter
 	}