From 88fb8c3c281c9e4a78e91a5a7f387d03c445f7ae Mon Sep 17 00:00:00 2001 From: Nick Tobey Date: Mon, 26 Aug 2024 17:57:50 -0700 Subject: [PATCH] Handle virtual columns when adding a column to a table. --- sql/rowexec/ddl_iters.go | 69 ++++++++++++++++++++++++++-------------- sql/rowexec/rel_iters.go | 7 ++++ 2 files changed, 53 insertions(+), 23 deletions(-) diff --git a/sql/rowexec/ddl_iters.go b/sql/rowexec/ddl_iters.go index e7cf95a0a7..fbda175151 100644 --- a/sql/rowexec/ddl_iters.go +++ b/sql/rowexec/ddl_iters.go @@ -1438,53 +1438,76 @@ func addColumnToSchema(schema sql.Schema, column *sql.Column, order *sql.ColumnO newSch := make(sql.Schema, 0, len(schema)+1) projections := make([]sql.Expression, len(schema)+1) + newGetField := func(i int) sql.Expression { + col := schema[i] + if col.Virtual { + return col.Generated + } else { + return expression.NewGetField(i, col.Type, col.Name, col.Nullable) + } + } + if idx >= 0 { newSch = append(newSch, schema[:idx]...) newSch = append(newSch, column) newSch = append(newSch, schema[idx:]...) - for i := range schema[:idx] { - projections[i] = expression.NewGetField(i, schema[i].Type, schema[i].Name, schema[i].Nullable) + for i := 0; i < idx; i++ { + projections[i] = newGetField(i) } projections[idx] = plan.ColDefaultExpression{column} - for i := range schema[idx:] { - schIdx := i + idx - projections[schIdx+1] = expression.NewGetField(schIdx, schema[schIdx].Type, schema[schIdx].Name, schema[schIdx].Nullable) + for i := idx; i < len(schema); i++ { + projections[i+1] = newGetField(i) } } else { // new column at end newSch = append(newSch, schema...) newSch = append(newSch, column) - for i := range schema { - projections[i] = expression.NewGetField(i, schema[i].Type, schema[i].Name, schema[i].Nullable) + for i, _ := range schema { + projections[i] = newGetField(i) } projections[len(schema)] = plan.ColDefaultExpression{column} } - // Alter the new default if it refers to other columns. The column indexes computed during analysis refer to the - // column indexes in the new result schema, which is not what we want here: we want the positions in the old - // (current) schema, since that is what we'll be evaluating when we rewrite the table. + // Alter old default expressions if they refer to other columns. The column indexes computed during analysis refer to the + // column indexes in the old result schema, which is not what we want here: we want the positions in the new + // schema, since that is what we'll be evaluating when we rewrite the table. + var updateFieldRefs transform.ExprFunc = func(s sql.Expression) (sql.Expression, transform.TreeIdentity, error) { + switch s := s.(type) { + case *expression.GetField: + idx := newSch.IndexOf(s.Name(), newSch[0].Source) + if idx < 0 { + return nil, transform.SameTree, sql.ErrTableColumnNotFound.New(schema[0].Source, s.Name()) + } + return s.WithIndex(idx), transform.NewTree, nil + default: + return s, transform.SameTree, nil + } + return s, transform.SameTree, nil + } for i := range projections { switch p := projections[i].(type) { + case *sql.ColumnDefaultValue: + newExpr, _, err := transform.Expr(p, updateFieldRefs) + if err != nil { + return nil, nil, err + } + projections[i] = newExpr + break case plan.ColDefaultExpression: if p.Column.Default != nil { - newExpr, _, err := transform.Expr(p.Column.Default.Expr, func(s sql.Expression) (sql.Expression, transform.TreeIdentity, error) { - switch s := s.(type) { - case *expression.GetField: - idx := schema.IndexOf(s.Name(), schema[0].Source) - if idx < 0 { - return nil, transform.SameTree, sql.ErrTableColumnNotFound.New(schema[0].Source, s.Name()) - } - return expression.NewGetFieldWithTable(idx, 0, s.Type(), s.Database(), s.Table(), s.Name(), s.IsNullable()), transform.NewTree, nil - default: - return s, transform.SameTree, nil - } - return s, transform.SameTree, nil - }) + newExpr, _, err := transform.Expr(p.Column.Default.Expr, updateFieldRefs) if err != nil { return nil, nil, err } p.Column.Default.Expr = newExpr projections[i] = p + } else if p.Column.Generated != nil { + newExpr, _, err := transform.Expr(p.Column.Generated.Expr, updateFieldRefs) + if err != nil { + return nil, nil, err + } + p.Column.Generated.Expr = newExpr + projections[i] = p } break } diff --git a/sql/rowexec/rel_iters.go b/sql/rowexec/rel_iters.go index 56a6eddddb..fdcb3f9fc9 100644 --- a/sql/rowexec/rel_iters.go +++ b/sql/rowexec/rel_iters.go @@ -565,6 +565,13 @@ func defaultValFromProjectExpr(e sql.Expression) (*sql.ColumnDefaultValue, bool) if defaultVal, ok := e.(*sql.ColumnDefaultValue); ok { return defaultVal, true } + if defaultExpr, ok := e.(plan.ColDefaultExpression); ok { + if defaultExpr.Column.Default != nil { + return defaultExpr.Column.Default, true + } else if defaultExpr.Column.Generated != nil { + return defaultExpr.Column.Generated, true + } + } return nil, false }