diff --git a/enginetest/enginetests.go b/enginetest/enginetests.go index 5881908e63..f1164e18c8 100644 --- a/enginetest/enginetests.go +++ b/enginetest/enginetests.go @@ -4243,6 +4243,73 @@ func TestCurrentTimestamp(t *testing.T, harness Harness) { } } +func TestOnUpdateExprScripts(t *testing.T, harness Harness) { + harness.Setup(setup.MydbData) + for _, script := range queries.OnUpdateExprScripts { + if sh, ok := harness.(SkippingHarness); ok { + if sh.SkipQueryTest(script.Name) { + t.Run(script.Name, func(t *testing.T) { + t.Skip(script.Name) + }) + continue + } + } + e := mustNewEngine(t, harness) + ctx := NewContext(harness) + err := CreateNewConnectionForServerEngine(ctx, e) + require.NoError(t, err, nil) + + t.Run(script.Name, func(t *testing.T) { + for _, statement := range script.SetUpScript { + sql.RunWithNowFunc(func() time.Time { return queries.Jan1Noon }, func() error { + ctx.WithQuery(statement) + ctx.SetQueryTime(queries.Jan1Noon) + RunQueryWithContext(t, e, harness, ctx, statement) + return nil + }) + } + + assertions := script.Assertions + if len(assertions) == 0 { + assertions = []queries.ScriptTestAssertion{ + { + Query: script.Query, + Expected: script.Expected, + ExpectedErr: script.ExpectedErr, + ExpectedIndexes: script.ExpectedIndexes, + }, + } + } + + for _, assertion := range script.Assertions { + t.Run(assertion.Query, func(t *testing.T) { + if assertion.Skip { + t.Skip() + } + sql.RunWithNowFunc(func() time.Time { return queries.Dec15_1_30 }, func() error { + ctx.SetQueryTime(queries.Dec15_1_30) + if assertion.ExpectedErr != nil { + AssertErr(t, e, harness, assertion.Query, assertion.ExpectedErr) + } else if assertion.ExpectedErrStr != "" { + AssertErr(t, e, harness, assertion.Query, nil, assertion.ExpectedErrStr) + } else { + var expected = assertion.Expected + if IsServerEngine(e) && assertion.SkipResultCheckOnServerEngine { + // TODO: remove this check in the future + expected = nil + } + TestQueryWithContext(t, ctx, e, harness, assertion.Query, expected, assertion.ExpectedColumns, assertion.Bindings) + } + return nil + }) + }) + } + }) + + e.Close() + } +} + func TestAddDropPks(t *testing.T, harness Harness) { for _, tt := range queries.AddDropPrimaryKeyScripts { TestScript(t, harness, tt) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index f5ebf91bf6..031acea9f9 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -560,6 +560,10 @@ func TestUpdateErrors(t *testing.T) { enginetest.TestUpdateErrors(t, enginetest.NewMemoryHarness("default", 1, testNumPartitions, true, mergableIndexDriver)) } +func TestOnUpdateExprScripts(t *testing.T) { + enginetest.TestOnUpdateExprScripts(t, enginetest.NewMemoryHarness("default", 1, testNumPartitions, true, mergableIndexDriver)) +} + func TestSpatialUpdate(t *testing.T) { enginetest.TestSpatialUpdate(t, enginetest.NewMemoryHarness("default", 1, testNumPartitions, true, mergableIndexDriver)) } diff --git a/enginetest/queries/update_queries.go b/enginetest/queries/update_queries.go index 0c9e873f8f..046c326427 100644 --- a/enginetest/queries/update_queries.go +++ b/enginetest/queries/update_queries.go @@ -15,12 +15,13 @@ package queries import ( - "github.com/dolthub/vitess/go/mysql" - - "github.com/dolthub/go-mysql-server/sql/types" + "time" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/plan" + "github.com/dolthub/go-mysql-server/sql/types" + + "github.com/dolthub/vitess/go/mysql" ) var UpdateTests = []WriteQueryTest{ @@ -760,3 +761,463 @@ var UpdateErrorScripts = []ScriptTest{ ExpectedErr: types.ErrLengthBeyondLimit, }, } + +var ZeroTime = time.Date(0000, time.January, 1, 0, 0, 0, 0, time.UTC) +var Jan1Noon = time.Date(2000, time.January, 1, 12, 0, 0, 0, time.UTC) +var Dec15_1_30 = time.Date(2023, time.December, 15, 1, 30, 0, 0, time.UTC) +var Oct2Midnight = time.Date(2020, time.October, 2, 0, 0, 0, 0, time.UTC) +var OnUpdateExprScripts = []ScriptTest{ + { + Name: "error cases", + SetUpScript: []string{ + "create table t (i int, ts timestamp);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "create table tt (i int, j int on update (5))", + ExpectedErrStr: "syntax error at position 42 near 'update'", + }, + { + Query: "create table tt (i int, j int on update current_timestamp)", + ExpectedErr: sql.ErrInvalidOnUpdate, + }, + { + Query: "create table tt (i int, d date on update current_timestamp)", + ExpectedErr: sql.ErrInvalidOnUpdate, + }, + { + Query: "alter table t modify column ts timestamp on update (5)", + ExpectedErrStr: "syntax error at position 53 near 'update'", + }, + { + Query: "alter table t modify column t int on update current_timestamp", + ExpectedErr: sql.ErrInvalidOnUpdate, + }, + { + Query: "alter table t modify column t date on update current_timestamp", + ExpectedErr: sql.ErrInvalidOnUpdate, + }, + }, + }, + { + Name: "basic case", + SetUpScript: []string{ + "create table t (i int, ts timestamp default 0 on update current_timestamp);", + "insert into t(i) values (1), (2), (3);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "show create table t", + Expected: []sql.Row{ + {"t", "CREATE TABLE `t` (\n" + + " `i` int,\n" + + " `ts` timestamp DEFAULT 0 ON UPDATE (CURRENT_TIMESTAMP())\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, + }, + }, + { + SkipResultCheckOnServerEngine: true, + Query: "select * from t order by i;", + Expected: []sql.Row{ + {1, ZeroTime}, + {2, ZeroTime}, + {3, ZeroTime}, + }, + }, + { + Query: "update t set i = 10 where i = 1;", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1, Info: plan.UpdateInfo{Matched: 1, Updated: 1}}}, + }, + }, + { + SkipResultCheckOnServerEngine: true, + Query: "select * from t order by i;", + Expected: []sql.Row{ + {2, ZeroTime}, + {3, ZeroTime}, + {10, Dec15_1_30}, + }, + }, + { + Query: "update t set i = 100", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 3, Info: plan.UpdateInfo{Matched: 3, Updated: 3}}}, + }, + }, + { + Query: "select * from t order by i;", + Expected: []sql.Row{ + {100, Dec15_1_30}, + {100, Dec15_1_30}, + {100, Dec15_1_30}, + }, + }, + { + // updating timestamp itself blocks on update + Query: "update t set ts = timestamp('2020-10-2')", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 3, Info: plan.UpdateInfo{Matched: 3, Updated: 3}}}, + }, + }, + { + Query: "select * from t;", + Expected: []sql.Row{ + {100, Oct2Midnight}, + {100, Oct2Midnight}, + {100, Oct2Midnight}, + }, + }, + }, + }, + { + Name: "default time is current time", + SetUpScript: []string{ + "create table t (i int, ts timestamp default current_timestamp on update current_timestamp);", + "insert into t(i) values (1), (2), (3);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "show create table t", + Expected: []sql.Row{ + {"t", "CREATE TABLE `t` (\n" + + " `i` int,\n" + + " `ts` timestamp DEFAULT (CURRENT_TIMESTAMP()) ON UPDATE (CURRENT_TIMESTAMP())\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, + }, + }, + { + Query: "select * from t order by i;", + Expected: []sql.Row{ + {1, Jan1Noon}, + {2, Jan1Noon}, + {3, Jan1Noon}, + }, + }, + { + Query: "update t set i = 10 where i = 1;", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1, Info: plan.UpdateInfo{Matched: 1, Updated: 1}}}, + }, + }, + { + Query: "select * from t order by i;", + Expected: []sql.Row{ + {2, Jan1Noon}, + {3, Jan1Noon}, + {10, Dec15_1_30}, + }, + }, + { + Query: "update t set i = 100", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 3, Info: plan.UpdateInfo{Matched: 3, Updated: 3}}}, + }, + }, + { + Query: "select * from t order by i;", + Expected: []sql.Row{ + {100, Dec15_1_30}, + {100, Dec15_1_30}, + {100, Dec15_1_30}, + }, + }, + { + // updating timestamp itself blocks on update + Query: "update t set ts = timestamp('2020-10-2')", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 3, Info: plan.UpdateInfo{Matched: 3, Updated: 3}}}, + }, + }, + { + Query: "select * from t;", + Expected: []sql.Row{ + {100, Oct2Midnight}, + {100, Oct2Midnight}, + {100, Oct2Midnight}, + }, + }, + }, + }, + { + Name: "alter table", + SetUpScript: []string{ + "create table t (i int, ts timestamp);", + "insert into t(i) values (1), (2), (3);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "show create table t", + Expected: []sql.Row{ + {"t", "CREATE TABLE `t` (\n" + + " `i` int,\n" + + " `ts` timestamp\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, + }, + }, + { + Query: "alter table t modify column ts timestamp default 0 on update current_timestamp;", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "show create table t", + Expected: []sql.Row{ + {"t", "CREATE TABLE `t` (\n" + + " `i` int,\n" + + " `ts` timestamp DEFAULT 0 ON UPDATE (CURRENT_TIMESTAMP())\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, + }, + }, + { + Query: "select * from t order by i;", + Expected: []sql.Row{ + {1, nil}, + {2, nil}, + {3, nil}, + }, + }, + { + Query: "update t set i = 10 where i = 1;", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1, Info: plan.UpdateInfo{Matched: 1, Updated: 1}}}, + }, + }, + { + Query: "select * from t order by i;", + Expected: []sql.Row{ + {2, nil}, + {3, nil}, + {10, Dec15_1_30}, + }, + }, + { + Query: "update t set i = 100", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 3, Info: plan.UpdateInfo{Matched: 3, Updated: 3}}}, + }, + }, + { + Query: "select * from t order by i;", + Expected: []sql.Row{ + {100, Dec15_1_30}, + {100, Dec15_1_30}, + {100, Dec15_1_30}, + }, + }, + { + Query: "update t set ts = timestamp('2020-10-2')", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 3, Info: plan.UpdateInfo{Matched: 3, Updated: 3}}}, + }, + }, + { + Query: "select * from t;", + Expected: []sql.Row{ + {100, Oct2Midnight}, + {100, Oct2Midnight}, + {100, Oct2Midnight}, + }, + }, + }, + }, + { + Name: "multiple columns case", + SetUpScript: []string{ + "create table t (i int primary key, ts timestamp default 0 on update current_timestamp, dt datetime default 0 on update current_timestamp);", + "insert into t(i) values (1), (2), (3);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "show create table t", + Expected: []sql.Row{ + {"t", "CREATE TABLE `t` (\n" + + " `i` int NOT NULL,\n" + + " `ts` timestamp DEFAULT 0 ON UPDATE (CURRENT_TIMESTAMP()),\n" + + " `dt` datetime DEFAULT 0 ON UPDATE (CURRENT_TIMESTAMP()),\n" + + " PRIMARY KEY (`i`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, + }, + }, + { + SkipResultCheckOnServerEngine: true, + Query: "select * from t order by i;", + Expected: []sql.Row{ + {1, ZeroTime, ZeroTime}, + {2, ZeroTime, ZeroTime}, + {3, ZeroTime, ZeroTime}, + }, + }, + { + Query: "update t set i = 10 where i = 1;", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1, Info: plan.UpdateInfo{Matched: 1, Updated: 1}}}, + }, + }, + { + SkipResultCheckOnServerEngine: true, + Query: "select * from t order by i;", + Expected: []sql.Row{ + {2, ZeroTime, ZeroTime}, + {3, ZeroTime, ZeroTime}, + {10, Dec15_1_30, Dec15_1_30}, + }, + }, + { + Query: "update t set ts = timestamp('2020-10-2') where i = 2", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1, Info: plan.UpdateInfo{Matched: 1, Updated: 1}}}, + }, + }, + { + SkipResultCheckOnServerEngine: true, + Query: "select * from t order by i;", + Expected: []sql.Row{ + {2, Oct2Midnight, Dec15_1_30}, + {3, ZeroTime, ZeroTime}, + {10, Dec15_1_30, Dec15_1_30}, + }, + }, + }, + }, + { + // before update triggers that update the timestamp column block the on update + Name: "before update trigger", + SetUpScript: []string{ + "create table t (i int primary key, ts timestamp default 0 on update current_timestamp, dt datetime default 0 on update current_timestamp);", + "create trigger trig before update on t for each row set new.ts = timestamp('2020-10-2');", + "insert into t(i) values (1), (2), (3);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "update t set i = 10 where i = 1;", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1, Info: plan.UpdateInfo{Matched: 1, Updated: 1}}}, + }, + }, + { + SkipResultCheckOnServerEngine: true, + Query: "select * from t order by i;", + Expected: []sql.Row{ + {2, ZeroTime, ZeroTime}, + {3, ZeroTime, ZeroTime}, + {10, Oct2Midnight, Dec15_1_30}, + }, + }, + }, + }, + { + // update triggers that update other tables do not block on update + Name: "after update trigger", + SetUpScript: []string{ + "create table a (i int primary key);", + "create table b (i int, ts timestamp default 0 on update current_timestamp, dt datetime default 0 on update current_timestamp);", + "create trigger trig after update on a for each row update b set i = i + 1;", + "insert into a values (0);", + "insert into b(i) values (0);", + }, + Assertions: []ScriptTestAssertion{ + { + SkipResultCheckOnServerEngine: true, + Query: "select * from b order by i;", + Expected: []sql.Row{ + {0, ZeroTime, ZeroTime}, + }, + }, + { + Query: "update a set i = 10;", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1, Info: plan.UpdateInfo{Matched: 1, Updated: 1}}}, + }, + }, + { + Query: "select * from b order by i;", + Expected: []sql.Row{ + {1, Dec15_1_30, Dec15_1_30}, + }, + }, + }, + }, + { + Name: "insert triggers", + SetUpScript: []string{ + "create table t (i int primary key);", + "create table a (i int, ts timestamp default 0 on update current_timestamp, dt datetime default 0 on update current_timestamp);", + "create table b (i int, ts timestamp default 0 on update current_timestamp, dt datetime default 0 on update current_timestamp);", + "create trigger trigA after insert on t for each row update a set i = i + 1;", + "create trigger trigB before insert on t for each row update b set i = i + 1;", + "insert into a(i) values (0);", + "insert into b(i) values (0);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "insert into t values (1);", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1}}, + }, + }, + { + Query: "select * from a order by i;", + Expected: []sql.Row{ + {1, Dec15_1_30, Dec15_1_30}, + }, + }, + { + Query: "select * from b order by i;", + Expected: []sql.Row{ + {1, Dec15_1_30, Dec15_1_30}, + }, + }, + }, + }, + { + // Foreign Key Cascade Update does NOT trigger on update on child table + Name: "foreign key tests", + SetUpScript: []string{ + "create table parent (i int primary key);", + "create table child (i int primary key, ts timestamp default 0 on update current_timestamp, foreign key (i) references parent(i) on update cascade);", + "insert into parent values (1);", + "insert into child(i) values (1);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "update parent set i = 10;", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1, Info: plan.UpdateInfo{Matched: 1, Updated: 1}}}, + }, + }, + { + SkipResultCheckOnServerEngine: true, + Query: "select * from child;", + Expected: []sql.Row{ + {10, ZeroTime}, + }, + }, + }, + }, + { + Name: "stored procedure tests", + SetUpScript: []string{ + "create table t (i int, ts timestamp default 0 on update current_timestamp);", + "insert into t(i) values (0);", + "create procedure p() update t set i = i + 1;", + }, + Assertions: []ScriptTestAssertion{ + { + // call depends on stored procedure stmt for whether to use 'query' or 'exec' from go sql driver. + SkipResultCheckOnServerEngine: true, + Query: "call p();", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1, Info: plan.UpdateInfo{Matched: 1, Updated: 1}}}, + }, + }, + { + Query: "select * from t;", + Expected: []sql.Row{ + {1, Dec15_1_30}, + }, + }, + }, + }, +} diff --git a/memory/table.go b/memory/table.go index ac1c3334d5..3c95aaa501 100644 --- a/memory/table.go +++ b/memory/table.go @@ -181,6 +181,20 @@ func NewPartitionedTableWithCollation(db *BaseDatabase, name string, schema sql. unrDef := sql.NewUnresolvedColumnDefaultValue(defStr) cCopy.Generated = unrDef } + if cCopy.OnUpdate != nil { + newDef, _, _ := transform.Expr(cCopy.OnUpdate, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { + switch e := e.(type) { + case *expression.GetField: + // strip table names + return expression.NewGetField(e.Index(), e.Type(), e.Name(), e.IsNullable()), transform.NewTree, nil + default: + } + return e, transform.SameTree, nil + }) + defStr := newDef.String() + unrDef := sql.NewUnresolvedColumnDefaultValue(defStr) + cCopy.OnUpdate = unrDef + } newSchema[i] = cCopy } diff --git a/sql/column.go b/sql/column.go index c7783f177b..e1887f8192 100644 --- a/sql/column.go +++ b/sql/column.go @@ -53,6 +53,8 @@ type Column struct { // Virtual column values will be provided for write operations, in case integrators need to use them to update // indexes, but must not be returned in rows from tables that include them. Virtual bool + // OnUpdate contains the on update value of the column or nil if it was not explicitly defined. + OnUpdate *ColumnDefaultValue } // Check ensures the value is correct for this column. diff --git a/sql/errors.go b/sql/errors.go index 1686df312e..c6dfa1bcc3 100644 --- a/sql/errors.go +++ b/sql/errors.go @@ -866,6 +866,8 @@ var ( // ErrGeneratedColumnWithDefault is returned when a column specifies both a default and a generated value ErrGeneratedColumnWithDefault = errors.NewKind("Incorrect usage of DEFAULT and generated column") + ErrInvalidOnUpdate = errors.NewKind("Invalid ON UPDATE clause for '%s' column") + ErrInsertIntoMismatchValueCount = errors.NewKind("number of values does not match number of columns provided") ErrInvalidTypeForLimit = errors.NewKind("invalid limit. expected %T, found %T") diff --git a/sql/planbuilder/ddl.go b/sql/planbuilder/ddl.go index 0b9f373be5..29a8e15a8d 100644 --- a/sql/planbuilder/ddl.go +++ b/sql/planbuilder/ddl.go @@ -1035,6 +1035,7 @@ func (b *Builder) tableSpecToSchema(inScope, outScope *scope, db sql.Database, t defaults := make([]ast.Expr, len(tableSpec.Columns)) generated := make([]ast.Expr, len(tableSpec.Columns)) + updates := make([]ast.Expr, len(tableSpec.Columns)) var schema sql.Schema for i, cd := range tableSpec.Columns { sqlType := cd.Type.SQLType() @@ -1046,6 +1047,7 @@ func (b *Builder) tableSpecToSchema(inScope, outScope *scope, db sql.Database, t } defaults[i] = cd.Type.Default generated[i] = cd.Type.GeneratedExpr + updates[i] = cd.Type.OnUpdate column := b.columnDefinitionToColumn(inScope, cd, tableSpec.Indexes) column.DatabaseSource = db.Name() @@ -1084,6 +1086,13 @@ func (b *Builder) tableSpecToSchema(inScope, outScope *scope, db sql.Database, t } } + for i, onUpdateExpr := range updates { + schema[i].OnUpdate = b.convertDefaultExpression(outScope, onUpdateExpr, schema[i].Type, schema[i].Nullable) + if schema[i].OnUpdate != nil && !(types.IsDatetimeType(schema[i].Type) || types.IsTimestampType(schema[i].Type)) { + b.handleErr(sql.ErrInvalidOnUpdate.New(schema[i].Name)) + } + } + return sql.NewPrimaryKeySchema(schema, getPkOrdinals(tableSpec)...), tableCollation } @@ -1233,6 +1242,7 @@ func (b *Builder) resolveSchemaDefaults(inScope *scope, schema sql.Schema) sql.S for _, col := range newSch { col.Default = b.resolveColumnDefaultExpression(inScope, col, col.Default) col.Generated = b.resolveColumnDefaultExpression(inScope, col, col.Generated) + col.OnUpdate = b.resolveColumnDefaultExpression(inScope, col, col.OnUpdate) } return newSch } diff --git a/sql/planbuilder/dml.go b/sql/planbuilder/dml.go index ccd11e1676..384605afc1 100644 --- a/sql/planbuilder/dml.go +++ b/sql/planbuilder/dml.go @@ -259,15 +259,23 @@ func (b *Builder) assignmentExprsToExpressions(inScope *scope, e ast.AssignmentE } } - // We need additional update expressions for any generated columns, since they won't be part of the update + // We need additional update expressions for any generated columns and on update expressions, since they won't be part of the update // expressions, but their value in the row must be updated before being passed to the integrator for storage. if len(tableSch) > 0 { tabId := inScope.tables[strings.ToLower(tableSch[0].Source)] for i, col := range tableSch { if col.Generated != nil { - colName := expression.NewGetFieldWithTable(i, int(tabId), col.Type, col.DatabaseSource, col.Source, col.Name, col.Nullable) + colGf := expression.NewGetFieldWithTable(i, int(tabId), col.Type, col.DatabaseSource, col.Source, col.Name, col.Nullable) generated := b.resolveColumnDefaultExpression(inScope, col, col.Generated) - updateExprs = append(updateExprs, expression.NewSetField(colName, assignColumnIndexes(generated, tableSch))) + updateExprs = append(updateExprs, expression.NewSetField(colGf, assignColumnIndexes(generated, tableSch))) + } + if col.OnUpdate != nil { + // don't add if column is already being updated + if !isColumnUpdated(col, updateExprs) { + colGf := expression.NewGetFieldWithTable(i, int(tabId), col.Type, col.DatabaseSource, col.Source, col.Name, col.Nullable) + onUpdate := b.resolveColumnDefaultExpression(inScope, col, col.OnUpdate) + updateExprs = append(updateExprs, expression.NewSetField(colGf, assignColumnIndexes(onUpdate, tableSch))) + } } } } @@ -275,6 +283,23 @@ func (b *Builder) assignmentExprsToExpressions(inScope *scope, e ast.AssignmentE return updateExprs } +func isColumnUpdated(col *sql.Column, updateExprs []sql.Expression) bool { + for _, expr := range updateExprs { + sf, ok := expr.(*expression.SetField) + if !ok { + continue + } + gf, ok := sf.Left.(*expression.GetField) + if !ok { + continue + } + if strings.EqualFold(gf.Name(), col.Name) { + return true + } + } + return false +} + func (b *Builder) buildOnDupUpdateExprs(combinedScope, destScope *scope, e ast.AssignmentExprs) []sql.Expression { b.insertActive = true defer func() { diff --git a/sql/rowexec/show_iters.go b/sql/rowexec/show_iters.go index 596edec3ea..07cea5a8da 100644 --- a/sql/rowexec/show_iters.go +++ b/sql/rowexec/show_iters.go @@ -379,12 +379,23 @@ func (i *showCreateTablesIter) produceCreateTableStatement(ctx *sql.Context, tab colDefaultStr = fmt.Sprintf("'%v'", v) } } + var onUpdateStr string + if col.OnUpdate != nil { + onUpdateStr = col.OnUpdate.String() + if onUpdateStr != "NULL" && col.OnUpdate.IsLiteral() && !types.IsTime(col.OnUpdate.Type()) && !types.IsText(col.OnUpdate.Type()) { + v, err := col.OnUpdate.Eval(ctx, nil) + if err != nil { + return "", err + } + onUpdateStr = fmt.Sprintf("'%v'", v) + } + } if col.PrimaryKey && len(pkSchema.Schema) == 0 { pkOrdinals = append(pkOrdinals, i) } - colStmts[i] = sql.GenerateCreateTableColumnDefinition(col, colDefaultStr, tableCollation) + colStmts[i] = sql.GenerateCreateTableColumnDefinition(col, colDefaultStr, onUpdateStr, tableCollation) } for _, i := range pkOrdinals { diff --git a/sql/sqlfmt.go b/sql/sqlfmt.go index 77d730cdca..5d9cd15735 100644 --- a/sql/sqlfmt.go +++ b/sql/sqlfmt.go @@ -38,7 +38,7 @@ func GenerateCreateTableStatement(tblName string, colStmts []string, tblCharsetN // GenerateCreateTableColumnDefinition returns column definition string for 'CREATE TABLE' statement for given column. // This part comes first in the 'CREATE TABLE' statement. -func GenerateCreateTableColumnDefinition(col *Column, colDefault string, tableCollation CollationID) string { +func GenerateCreateTableColumnDefinition(col *Column, colDefault, onUpdate string, tableCollation CollationID) string { var colTypeString string if collationType, ok := col.Type.(TypeWithCollation); ok { colTypeString = collationType.StringWithTableCollation(tableCollation) @@ -72,6 +72,10 @@ func GenerateCreateTableColumnDefinition(col *Column, colDefault string, tableCo stmt = fmt.Sprintf("%s DEFAULT %s", stmt, colDefault) } + if col.OnUpdate != nil { + stmt = fmt.Sprintf("%s ON UPDATE %s", stmt, onUpdate) + } + if col.Comment != "" { stmt = fmt.Sprintf("%s COMMENT '%s'", stmt, col.Comment) }