Skip to content

Commit

Permalink
Handle virtual columns when adding a column to a table.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicktobey committed Aug 27, 2024
1 parent eeebabe commit 88fb8c3
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 23 deletions.
69 changes: 46 additions & 23 deletions sql/rowexec/ddl_iters.go
Original file line number Diff line number Diff line change
Expand Up @@ -1438,53 +1438,76 @@ func addColumnToSchema(schema sql.Schema, column *sql.Column, order *sql.ColumnO
newSch := make(sql.Schema, 0, len(schema)+1)
projections := make([]sql.Expression, len(schema)+1)

newGetField := func(i int) sql.Expression {
col := schema[i]
if col.Virtual {
return col.Generated
} else {
return expression.NewGetField(i, col.Type, col.Name, col.Nullable)
}
}

if idx >= 0 {
newSch = append(newSch, schema[:idx]...)
newSch = append(newSch, column)
newSch = append(newSch, schema[idx:]...)

for i := range schema[:idx] {
projections[i] = expression.NewGetField(i, schema[i].Type, schema[i].Name, schema[i].Nullable)
for i := 0; i < idx; i++ {
projections[i] = newGetField(i)
}
projections[idx] = plan.ColDefaultExpression{column}
for i := range schema[idx:] {
schIdx := i + idx
projections[schIdx+1] = expression.NewGetField(schIdx, schema[schIdx].Type, schema[schIdx].Name, schema[schIdx].Nullable)
for i := idx; i < len(schema); i++ {
projections[i+1] = newGetField(i)
}
} else { // new column at end
newSch = append(newSch, schema...)
newSch = append(newSch, column)
for i := range schema {
projections[i] = expression.NewGetField(i, schema[i].Type, schema[i].Name, schema[i].Nullable)
for i, _ := range schema {
projections[i] = newGetField(i)
}
projections[len(schema)] = plan.ColDefaultExpression{column}
}

// Alter the new default if it refers to other columns. The column indexes computed during analysis refer to the
// column indexes in the new result schema, which is not what we want here: we want the positions in the old
// (current) schema, since that is what we'll be evaluating when we rewrite the table.
// Alter old default expressions if they refer to other columns. The column indexes computed during analysis refer to the
// column indexes in the old result schema, which is not what we want here: we want the positions in the new
// schema, since that is what we'll be evaluating when we rewrite the table.
var updateFieldRefs transform.ExprFunc = func(s sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
switch s := s.(type) {
case *expression.GetField:
idx := newSch.IndexOf(s.Name(), newSch[0].Source)
if idx < 0 {
return nil, transform.SameTree, sql.ErrTableColumnNotFound.New(schema[0].Source, s.Name())
}
return s.WithIndex(idx), transform.NewTree, nil
default:
return s, transform.SameTree, nil
}
return s, transform.SameTree, nil
}
for i := range projections {
switch p := projections[i].(type) {
case *sql.ColumnDefaultValue:
newExpr, _, err := transform.Expr(p, updateFieldRefs)
if err != nil {
return nil, nil, err
}
projections[i] = newExpr
break
case plan.ColDefaultExpression:
if p.Column.Default != nil {
newExpr, _, err := transform.Expr(p.Column.Default.Expr, func(s sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
switch s := s.(type) {
case *expression.GetField:
idx := schema.IndexOf(s.Name(), schema[0].Source)
if idx < 0 {
return nil, transform.SameTree, sql.ErrTableColumnNotFound.New(schema[0].Source, s.Name())
}
return expression.NewGetFieldWithTable(idx, 0, s.Type(), s.Database(), s.Table(), s.Name(), s.IsNullable()), transform.NewTree, nil
default:
return s, transform.SameTree, nil
}
return s, transform.SameTree, nil
})
newExpr, _, err := transform.Expr(p.Column.Default.Expr, updateFieldRefs)
if err != nil {
return nil, nil, err
}
p.Column.Default.Expr = newExpr
projections[i] = p
} else if p.Column.Generated != nil {
newExpr, _, err := transform.Expr(p.Column.Generated.Expr, updateFieldRefs)
if err != nil {
return nil, nil, err
}
p.Column.Generated.Expr = newExpr
projections[i] = p
}
break
}
Expand Down
7 changes: 7 additions & 0 deletions sql/rowexec/rel_iters.go
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,13 @@ func defaultValFromProjectExpr(e sql.Expression) (*sql.ColumnDefaultValue, bool)
if defaultVal, ok := e.(*sql.ColumnDefaultValue); ok {
return defaultVal, true
}
if defaultExpr, ok := e.(plan.ColDefaultExpression); ok {
if defaultExpr.Column.Default != nil {
return defaultExpr.Column.Default, true
} else if defaultExpr.Column.Generated != nil {
return defaultExpr.Column.Generated, true
}
}

return nil, false
}
Expand Down

0 comments on commit 88fb8c3

Please sign in to comment.