diff --git a/enginetest/queries/trigger_queries.go b/enginetest/queries/trigger_queries.go index ff66cff884..118aa9b41e 100644 --- a/enginetest/queries/trigger_queries.go +++ b/enginetest/queries/trigger_queries.go @@ -506,6 +506,91 @@ var TriggerTests = []ScriptTest{ }, }, }, + { + Name: "insert trigger with missing column default value", + SetUpScript: []string{ + "CREATE TABLE t (i INT PRIMARY KEY, j INT NOT NULL);", + ` +CREATE TRIGGER trig BEFORE INSERT ON t +FOR EACH ROW +BEGIN + SET new.j = 10; +END;`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "INSERT INTO t (i) VALUES (1);", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1}}, + }, + }, + { + Query: "INSERT INTO t (i, j) VALUES (2, null);", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 1}}, + }, + }, + { + Query: "SELECT * FROM t;", + Expected: []sql.Row{ + {1, 10}, + {2, 10}, + }, + }, + }, + }, + { + Name: "not null column with trigger that sets null should error", + SetUpScript: []string{ + "CREATE TABLE t (i INT PRIMARY KEY, j INT NOT NULL);", + ` +CREATE TRIGGER trig BEFORE INSERT ON t +FOR EACH ROW +BEGIN + SET new.j = null; +END;`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "INSERT INTO t (i) VALUES (1);", + ExpectedErr: sql.ErrInsertIntoNonNullableDefaultNullColumn, + }, + { + Query: "INSERT INTO t (i, j) VALUES (1, 2);", + ExpectedErr: sql.ErrInsertIntoNonNullableProvidedNull, + }, + }, + }, + { + Name: "not null column with before insert trigger should error", + SetUpScript: []string{ + "CREATE TABLE t (i INT PRIMARY KEY, j INT NOT NULL);", + ` +CREATE TRIGGER trig BEFORE INSERT ON t +FOR EACH ROW +BEGIN + SET new.i = 10 * new.i; +END;`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "INSERT INTO t (i) VALUES (1);", + ExpectedErr: sql.ErrInsertIntoNonNullableDefaultNullColumn, + }, + { + Query: "INSERT INTO t (i, j) VALUES (1, 2);", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "SELECT * FROM t;", + Expected: []sql.Row{ + {10, 2}, + }, + }, + }, + }, // UPDATE triggers { diff --git a/sql/analyzer/inserts.go b/sql/analyzer/inserts.go index 3c108bcd57..503db27608 100644 --- a/sql/analyzer/inserts.go +++ b/sql/analyzer/inserts.go @@ -81,12 +81,23 @@ func resolveInsertRows(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Sc } // The schema of the destination node and the underlying table differ subtly in terms of defaults - project, firstGeneratedAutoIncRowIdx, err := wrapRowSource(ctx, source, insertable, insert.Destination.Schema(), columnNames) + var deferredDefaults sql.FastIntSet + project, firstGeneratedAutoIncRowIdx, deferredDefaults, err := wrapRowSource( + ctx, + source, + insertable, + insert.Destination.Schema(), + columnNames, + ) if err != nil { return nil, transform.SameTree, err } - return insert.WithSource(project).WithAutoIncrementIdx(firstGeneratedAutoIncRowIdx), transform.NewTree, nil + return insert.WithSource(project). + WithAutoIncrementIdx(firstGeneratedAutoIncRowIdx). + WithDeferredDefaults(deferredDefaults), + transform.NewTree, + nil }) } @@ -117,8 +128,9 @@ func findColIdx(colName string, colNames []string) int { // wrapRowSource returns a projection that wraps the original row source so that its schema matches the full schema of // the underlying table in the same order. Also, returns an integer value that indicates when this row source will // result in an automatically generated value for an auto_increment column. -func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, schema sql.Schema, columnNames []string) (sql.Node, int, error) { +func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, schema sql.Schema, columnNames []string) (sql.Node, int, sql.FastIntSet, error) { projExprs := make([]sql.Expression, len(schema)) + deferredDefaults := sql.NewFastIntSet() firstGeneratedAutoIncRowIdx := -1 for i, col := range schema { @@ -130,7 +142,7 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s defaultExpr = col.Generated } if !col.Nullable && defaultExpr == nil && !col.AutoIncrement { - return nil, -1, sql.ErrInsertIntoNonNullableDefaultNullColumn.New(col.Name) + deferredDefaults.Add(i) } var err error @@ -151,7 +163,7 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s } }) if err != nil { - return nil, -1, err + return nil, -1, sql.FastIntSet{}, err } projExprs[i] = def } else { @@ -163,7 +175,7 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s // wrap it in an AutoIncrement expression. ai, err := expression.NewAutoIncrement(ctx, destTbl, projExprs[i]) if err != nil { - return nil, -1, err + return nil, -1, sql.FastIntSet{}, err } projExprs[i] = ai @@ -206,7 +218,7 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s // ColumnDefaultValue to create the UUID), then update the project to include the AutoUuid expression. newExpr, identity, err := insertAutoUuidExpression(ctx, columnDefaultValue, autoUuidCol) if err != nil { - return nil, -1, err + return nil, -1, sql.FastIntSet{}, err } if identity == transform.NewTree { projExprs[autoUuidColIdx] = newExpr @@ -217,12 +229,12 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s // the AutoUuid expression to it. err := wrapAutoUuidInValuesTuples(ctx, autoUuidCol, insertSource, columnNames) if err != nil { - return nil, -1, err + return nil, -1, sql.FastIntSet{}, err } } } - return plan.NewProject(projExprs, insertSource), firstGeneratedAutoIncRowIdx, nil + return plan.NewProject(projExprs, insertSource), firstGeneratedAutoIncRowIdx, deferredDefaults, nil } // isZero returns true if the specified literal value |lit| has a value equal to 0. diff --git a/sql/plan/insert.go b/sql/plan/insert.go index c52acf843d..293d5483e2 100644 --- a/sql/plan/insert.go +++ b/sql/plan/insert.go @@ -72,6 +72,9 @@ type InsertInto struct { // FirstGenerateAutoIncRowIdx is the index of the first row inserted that increments last_insert_id. FirstGeneratedAutoIncRowIdx int + + // DeferredDefaults marks which columns in the destination schema are expected to have default values. + DeferredDefaults sql.FastIntSet } var _ sql.Databaser = (*InsertInto)(nil) @@ -201,6 +204,14 @@ func (ii *InsertInto) WithAutoIncrementIdx(firstGeneratedAutoIncRowIdx int) *Ins return &np } +// WithDeferredDefaults sets the flags for the insert destination columns, which mark which of the columns are expected +// to be filled with the DEFAULT or GENERATED value. +func (ii *InsertInto) WithDeferredDefaults(deferredDefaults sql.FastIntSet) *InsertInto { + np := *ii + np.DeferredDefaults = deferredDefaults + return &np +} + // String implements the fmt.Stringer interface. func (ii *InsertInto) String() string { pr := sql.NewTreePrinter() diff --git a/sql/rowexec/dml.go b/sql/rowexec/dml.go index 944183d16c..99b148a435 100644 --- a/sql/rowexec/dml.go +++ b/sql/rowexec/dml.go @@ -90,6 +90,7 @@ func (b *BaseBuilder) buildInsertInto(ctx *sql.Context, ii *plan.InsertInto, row ctx: ctx, ignore: ii.Ignore, firstGeneratedAutoIncRowIdx: ii.FirstGeneratedAutoIncRowIdx, + deferredDefaults: ii.DeferredDefaults, } var ed sql.EditOpenerCloser diff --git a/sql/rowexec/insert.go b/sql/rowexec/insert.go index 573c5fc7f7..a23f67f1f5 100644 --- a/sql/rowexec/insert.go +++ b/sql/rowexec/insert.go @@ -45,6 +45,8 @@ type insertIter struct { ignore bool firstGeneratedAutoIncRowIdx int + + deferredDefaults sql.FastIntSet } func getInsertExpressions(values sql.Node) []sql.Expression { @@ -395,12 +397,14 @@ func (i *insertIter) validateNullability(ctx *sql.Context, dstSchema sql.Schema, for count, col := range dstSchema { if !col.Nullable && row[count] == nil { // In the case of an IGNORE we set the nil value to a default and add a warning - if i.ignore { - row[count] = col.Type.Zero() - _ = warnOnIgnorableError(ctx, row, sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name)) // will always return nil - } else { + if !i.ignore { + if i.deferredDefaults.Contains(count) { + return sql.ErrInsertIntoNonNullableDefaultNullColumn.New(col.Name) + } return sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name) } + row[count] = col.Type.Zero() + _ = warnOnIgnorableError(ctx, row, sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name)) // will always return nil } } return nil