diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go
index ed1f30c670e..f2c642f4151 100644
--- a/go/vt/vtgate/semantics/early_rewriter.go
+++ b/go/vt/vtgate/semantics/early_rewriter.go
@@ -152,6 +152,14 @@ func (r *earlyRewriter) up(cursor *sqlparser.Cursor) error {
 		// this rewriting is done in the `up` phase, because we need the vindex hints to have been
 		// processed while collecting the tables.
 		return removeVindexHints(node)
+	case *sqlparser.Select:
+		// we are interested in queries with a HAVING or ORDER BY, that have a non-empty GROUP BY
+		if (node.Having == nil && len(node.OrderBy) == 0) ||
+			(len(node.GroupBy) == 0) {
+			return nil
+		}
+
+		return r.addMissingColumnsToGroupBy(node)
 	}
 	return nil
 }
@@ -944,3 +952,164 @@ func (e *expanderState) storeExpandInfo(tbl TableInfo, tblName sqlparser.TableNa
 	}
 	e.expandedColumns[tblName] = append(e.expandedColumns[tblName], colName)
 }
+
+// addMissingColumnsToGroupBy extends the GROUP BY of sel with columns that are
+// referenced outside of aggregate functions in the HAVING and ORDER BY clauses,
+// are not grouped on yet, and belong to a table whose primary key is already
+// fully part of the GROUP BY. Such columns are functionally dependent on the
+// existing grouping, so adding them does not change the result of the query.
+func (r *earlyRewriter) addMissingColumnsToGroupBy(sel *sqlparser.Select) error {
+	var columnsNeeded []*sqlparser.ColName
+	// pre-visitor: returning false on aggregate functions keeps the walk from
+	// descending into them, so columns used only as aggregation arguments are skipped
+	avoidAggr := func(cursor *sqlparser.Cursor) bool {
+		_, isAggr := cursor.Node().(sqlparser.AggrFunc)
+		return !isAggr
+	}
+	noteColumns := func(cursor *sqlparser.Cursor) bool {
+		if col, isCol := cursor.Node().(*sqlparser.ColName); isCol {
+			columnsNeeded = append(columnsNeeded, col)
+		}
+		return true
+	}
+	if sel.Having != nil {
+		sqlparser.Rewrite(sel.Having, avoidAggr, noteColumns)
+	}
+	if len(sel.OrderBy) > 0 {
+		sqlparser.Rewrite(sel.OrderBy, avoidAggr, noteColumns)
+	}
+
+	if len(columnsNeeded) == 0 {
+		return nil
+	}
+
+	g := newGrouper(sel.GroupBy, r.binder.org, r.binder.tc)
+	for _, colName := range columnsNeeded {
+		_, dep, _ := r.binder.org.depsForExpr(colName)
+		if err := g.addToGroupByIfNeeded(dep, colName); err != nil {
+			return err
+		}
+	}
+	sel.GroupBy = g.gb
+	return nil
+}
+
+// newGrouper creates a grouper for the given GROUP BY clause. The lookup maps
+// are left unset here and are built on the first call to getGroupByColumns.
+func newGrouper(gb sqlparser.GroupBy, org originable, tc *tableCollector) *grouper {
+	return &grouper{
+		gb:  gb,
+		org: org,
+		tc:  tc,
+	}
+}
+
+// grouper tracks the columns of a GROUP BY clause per table. gbCols holds the
+// plain columns that are grouped on, groupByPK is used as a set of the tables
+// whose complete primary key is grouped on, and gb is the GROUP BY clause
+// itself, including the columns appended by addToGroupByIfNeeded.
+type grouper struct {
+	gbCols    map[TableSet]sqlparser.Columns
+	groupByPK map[TableSet]any
+	org       originable
+	gb        sqlparser.GroupBy
+	tc        *tableCollector
+}
+
+// getGroupByColumns returns, per table, the plain columns found in the GROUP BY
+// and the set of tables whose primary key is fully covered by those columns.
+// Both maps are computed on the first successful call and cached afterwards.
+// GROUP BY expressions that are not plain columns are ignored. An error is
+// returned when the table information for a grouped column cannot be loaded.
+func (g *grouper) getGroupByColumns() (map[TableSet]sqlparser.Columns, map[TableSet]any, error) {
+	if g.gbCols != nil {
+		return g.gbCols, g.groupByPK, nil
+	}
+	gbCols := map[TableSet]sqlparser.Columns{}
+	primaryKeys := map[TableSet]sqlparser.Columns{}
+	for _, expr := range g.gb {
+		col, isCol := expr.(*sqlparser.ColName)
+		if !isCol {
+			continue
+		}
+		_, deps, _ := g.org.depsForExpr(col)
+		gbCols[deps] = append(gbCols[deps], col.Name)
+		if _, seen := primaryKeys[deps]; seen {
+			continue
+		}
+		info, err := g.tc.tableInfoFor(deps)
+		if err != nil {
+			// hand the problem to the caller instead of panicking
+			return nil, nil, err
+		}
+		var pk sqlparser.Columns
+		// NOTE(review): guards against table infos without a vindex table,
+		// which would otherwise be a nil dereference - confirm which ones can show up here
+		if tbl := info.GetVindexTable(); tbl != nil {
+			pk = tbl.PrimaryKey
+		}
+		primaryKeys[deps] = pk
+	}
+
+	// a table only counts when all the columns of its primary key are grouped on;
+	// a single column of a composite key does not make the other columns
+	// of the table functionally dependent on the GROUP BY
+	groupByPK := map[TableSet]any{}
+	for deps, pk := range primaryKeys {
+		if coversPrimaryKey(gbCols[deps], pk) {
+			groupByPK[deps] = nil
+		}
+	}
+	g.gbCols, g.groupByPK = gbCols, groupByPK
+	return g.gbCols, g.groupByPK, nil
+}
+
+// coversPrimaryKey reports whether pk is non-empty and every column in it is
+// also present in cols.
+func coversPrimaryKey(cols, pk sqlparser.Columns) bool {
+	if len(pk) == 0 {
+		return false
+	}
+	for _, keyCol := range pk {
+		found := false
+		for _, c := range cols {
+			if c.Equal(keyCol) {
+				found = true
+				break
+			}
+		}
+		if !found {
+			return false
+		}
+	}
+	return true
+}
+
+// addToGroupByIfNeeded appends col, which depends on the table id, to the
+// GROUP BY when it is not grouped on yet and the primary key of its table is
+// already fully part of the GROUP BY. Otherwise the GROUP BY is left as is.
+func (g *grouper) addToGroupByIfNeeded(id TableSet, col *sqlparser.ColName) error {
+	gbCols, pkCols, err := g.getGroupByColumns()
+	if err != nil {
+		return err
+	}
+
+	// first we check if this column is already present in the GROUP BY
+	for _, colName := range gbCols[id] {
+		if colName.Equal(col.Name) {
+			return nil
+		}
+	}
+
+	// now we check if the primary key of this table is in the GROUP BY.
+	// If it is, adding another column from this table to the GB will not change the semantics
+	if _, isThere := pkCols[id]; !isThere {
+		return nil
+	}
+
+	newCol := sqlparser.NewColNameWithQualifier(col.Name.String(), col.Qualifier)
+	g.gb = append(g.gb, newCol)
+	// remember the addition, so that a column referenced more than once is only added once
+	gbCols[id] = append(gbCols[id], col.Name)
+	return nil
+}
diff --git a/go/vt/vtgate/semantics/early_rewriter_test.go b/go/vt/vtgate/semantics/early_rewriter_test.go
index 3b7b30d5f39..0654835baf8 100644
--- a/go/vt/vtgate/semantics/early_rewriter_test.go
+++ b/go/vt/vtgate/semantics/early_rewriter_test.go
@@ -377,7 +377,12 @@ func TestOrderByGroupByLiteral(t *testing.T) {
 
 func TestHavingAndOrderByColumnName(t *testing.T) {
 	schemaInfo := &FakeSI{
-		Tables: map[string]*vindexes.Table{},
+		Tables: map[string]*vindexes.Table{
+			"t1": {
+				Name:       sqlparser.NewIdentifierCS("t1"),
+				PrimaryKey: sqlparser.Columns{sqlparser.NewIdentifierCI("id")},
+			},
+		},
 	}
 	cDB := "db"
 	tcases := []struct {
@@ -409,6 +414,14 @@ func TestHavingAndOrderByColumnName(t *testing.T) {
 	}, {
 		sql:    "select id, id, count(*) from t1 order by id",
 		expSQL: "select id, id, count(*) from t1 order by id asc",
+	}, {
+		// we add any columns being used in the having clause that are
+		// functionally dependent on a column already in the group by clause
+		sql:    "select t1.id, t1.foo-sum(t2.bar) as x from t1 join t2 group by t1.id having x > 5",
+		expSQL: "select t1.id, t1.foo - sum(t2.bar) as x from t1 join t2 group by t1.id, t1.foo having t1.foo - sum(t2.bar) > 5",
+	}, {
+		sql:    "select t1.id, t1.foo-sum(t2.bar) as x from t1 join t2 having x > 5",
+		expSQL: "select t1.id, t1.foo - sum(t2.bar) as x from t1 join t2 having t1.foo - sum(t2.bar) > 5",
 	}}
 	for _, tcase := range tcases {
 		t.Run(tcase.sql, func(t *testing.T) {