Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions go/vt/vtgate/semantics/early_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,14 @@ func (r *earlyRewriter) up(cursor *sqlparser.Cursor) error {
// this rewriting is done in the `up` phase, because we need the vindex hints to have been
// processed while collecting the tables.
return removeVindexHints(node)
case *sqlparser.Select:
// we are interested in queries with a HAVING or ORDER BY, that have a non-empty GROUP BY
if (node.Having == nil && len(node.OrderBy) == 0) ||
(len(node.GroupBy) == 0) {
return nil
}

return r.addMissingColumnsToGroupBy(node)
}
return nil
}
Expand Down Expand Up @@ -944,3 +952,105 @@ func (e *expanderState) storeExpandInfo(tbl TableInfo, tblName sqlparser.TableNa
}
e.expandedColumns[tblName] = append(e.expandedColumns[tblName], colName)
}

// addMissingColumnsToGroupBy walks the HAVING and ORDER BY clauses of sel,
// gathers every column reference that is not nested inside an aggregate
// function, and offers each of them to a grouper, which decides whether the
// column should be appended to the GROUP BY. On success sel.GroupBy is
// replaced with the grouper's (possibly extended) list; on error sel is left
// untouched and the error is returned.
func (r *earlyRewriter) addMissingColumnsToGroupBy(sel *sqlparser.Select) error {
	var wanted []*sqlparser.ColName

	// pre-visit: returning false stops the walk from descending into aggregates
	skipAggregates := func(c *sqlparser.Cursor) bool {
		if _, ok := c.Node().(sqlparser.AggrFunc); ok {
			return false
		}
		return true
	}
	// post-visit: remember every column name that the walk reaches
	collectColumns := func(c *sqlparser.Cursor) bool {
		if col, ok := c.Node().(*sqlparser.ColName); ok {
			wanted = append(wanted, col)
		}
		return true
	}

	if sel.Having != nil {
		sqlparser.Rewrite(sel.Having, skipAggregates, collectColumns)
	}
	if len(sel.OrderBy) > 0 {
		sqlparser.Rewrite(sel.OrderBy, skipAggregates, collectColumns)
	}

	// nothing outside of aggregates was referenced, so there is nothing to add
	if len(wanted) == 0 {
		return nil
	}

	grp := newGrouper(sel.GroupBy, r.binder.org, r.binder.tc)
	for _, col := range wanted {
		_, deps, _ := r.binder.org.depsForExpr(col)
		if err := grp.addToGroupByIfNeeded(deps, col); err != nil {
			return err
		}
	}

	sel.GroupBy = grp.gb
	return nil
}

// newGrouper builds a grouper around the given GROUP BY list. The lookup maps
// are left nil on purpose: getGroupByColumns fills them in on first use.
func newGrouper(gb sqlparser.GroupBy, org originable, tc *tableCollector) *grouper {
	g := new(grouper)
	g.gb = gb
	g.org = org
	g.tc = tc
	return g
}

type grouper struct {
gbCols map[TableSet]sqlparser.Columns
groupByPK map[TableSet]any
org originable
gb sqlparser.GroupBy
tc *tableCollector
}

// getGroupByColumns lazily computes, and afterwards returns from cache, two
// lookups derived from the grouper's GROUP BY list:
//
//   - the plain column names present in the GROUP BY, keyed by the table set
//     that depsForExpr reports for each column, and
//   - a set (only the keys matter) of table sets whose complete primary key
//     is present in the GROUP BY.
//
// GROUP BY expressions that are not plain column names are ignored.
//
// A table is only flagged once every column of its primary key has been seen.
// Matching a single column of a composite key is not enough: the remaining
// columns of the table are not functionally dependent on a partial key, so
// adding them to the GROUP BY would change the result of the query.
//
// The method panics if tableInfoFor fails for a column's table set.
// NOTE(review): the signature has no error return, so the panic is kept as-is;
// consider threading the error through to addToGroupByIfNeeded instead.
func (g *grouper) getGroupByColumns() (map[TableSet]sqlparser.Columns, map[TableSet]any) {
	if g.gbCols != nil {
		return g.gbCols, g.groupByPK
	}
	g.gbCols = map[TableSet]sqlparser.Columns{}
	g.groupByPK = map[TableSet]any{}
	for _, expr := range g.gb {
		col, isCol := expr.(*sqlparser.ColName)
		if !isCol {
			continue
		}
		_, deps, _ := g.org.depsForExpr(col)
		g.gbCols[deps] = append(g.gbCols[deps], col.Name)

		// once a table is known to be grouped by its full primary key there
		// is nothing more to learn from further columns of that table
		if _, covered := g.groupByPK[deps]; covered {
			continue
		}

		info, err := g.tc.tableInfoFor(deps)
		if err != nil {
			panic(err.Error())
		}
		// guard against table infos that carry no vindex table
		// (presumably derived tables and similar - TODO confirm)
		vtbl := info.GetVindexTable()
		if vtbl == nil || len(vtbl.PrimaryKey) == 0 {
			continue
		}

		// the table counts as grouped by its primary key only when every
		// primary key column is among the GROUP BY columns seen so far
		coversPK := true
		for _, pk := range vtbl.PrimaryKey {
			found := false
			for _, have := range g.gbCols[deps] {
				if pk.Equal(have) {
					found = true
					break
				}
			}
			if !found {
				coversPK = false
				break
			}
		}
		if coversPK {
			g.groupByPK[deps] = nil
		}
	}
	return g.gbCols, g.groupByPK
}

// addToGroupByIfNeeded appends col to the grouper's GROUP BY list when that is
// both necessary and safe:
//
//   - if a column with the same name from table set id is already in the
//     GROUP BY, nothing happens;
//   - if table set id is not flagged as grouped by its primary key, nothing
//     happens, since adding the column could change the query's semantics;
//   - otherwise a fresh column name with the same name and qualifier is
//     appended to the GROUP BY.
//
// id is the table set col depends on. The returned error is currently always
// nil.
func (g *grouper) addToGroupByIfNeeded(id TableSet, col *sqlparser.ColName) error {
	gbCols, pkCols := g.getGroupByColumns()

	// first we check if this column is already present in the GROUP BY
	for _, colName := range gbCols[id] {
		if colName.Equal(col.Name) {
			return nil
		}
	}

	// now we check if the primary key of this table is in the GROUP BY.
	// If it is not, we cannot add the column without changing the semantics
	if _, isThere := pkCols[id]; !isThere {
		return nil
	}

	newCol := sqlparser.NewColNameWithQualifier(col.Name.String(), col.Qualifier)
	g.gb = append(g.gb, newCol)

	// keep the cached lookup in sync with g.gb, so that a column referenced
	// several times in HAVING / ORDER BY is only added to the GROUP BY once
	gbCols[id] = append(gbCols[id], col.Name)
	return nil
}
15 changes: 14 additions & 1 deletion go/vt/vtgate/semantics/early_rewriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,12 @@ func TestOrderByGroupByLiteral(t *testing.T) {

func TestHavingAndOrderByColumnName(t *testing.T) {
schemaInfo := &FakeSI{
Tables: map[string]*vindexes.Table{},
Tables: map[string]*vindexes.Table{
"t1": {
Name: sqlparser.NewIdentifierCS("t1"),
PrimaryKey: sqlparser.Columns{sqlparser.NewIdentifierCI("id")},
},
},
}
cDB := "db"
tcases := []struct {
Expand Down Expand Up @@ -409,6 +414,14 @@ func TestHavingAndOrderByColumnName(t *testing.T) {
}, {
sql: "select id, id, count(*) from t1 order by id",
expSQL: "select id, id, count(*) from t1 order by id asc",
}, {
// we add any columns being used in the having clause that are
// functionally dependent on a column already in the group by clause
sql: "select t1.id, t1.foo-sum(t2.bar) as x from t1 join t2 group by t1.id having x > 5",
expSQL: "select t1.id, t1.foo - sum(t2.bar) as x from t1 join t2 group by t1.id, t1.foo having t1.foo - sum(t2.bar) > 5",
}, {
sql: "select t1.id, t1.foo-sum(t2.bar) as x from t1 join t2 having x > 5",
expSQL: "select t1.id, t1.foo - sum(t2.bar) as x from t1 join t2 having t1.foo - sum(t2.bar) > 5",
}}
for _, tcase := range tcases {
t.Run(tcase.sql, func(t *testing.T) {
Expand Down