Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions enginetest/queries/trigger_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,91 @@ var TriggerTests = []ScriptTest{
},
},
},
{
Name: "insert trigger with missing column default value",
SetUpScript: []string{
"CREATE TABLE t (i INT PRIMARY KEY, j INT NOT NULL);",
`
CREATE TRIGGER trig BEFORE INSERT ON t
FOR EACH ROW
BEGIN
SET new.j = 10;
END;`,
},
Assertions: []ScriptTestAssertion{
{
Query: "INSERT INTO t (i) VALUES (1);",
Expected: []sql.Row{
{types.OkResult{RowsAffected: 1}},
},
},
{
Query: "INSERT INTO t (i, j) VALUES (2, null);",
Expected: []sql.Row{
{types.OkResult{RowsAffected: 1}},
},
},
{
Query: "SELECT * FROM t;",
Expected: []sql.Row{
{1, 10},
{2, 10},
},
},
},
},
{
Name: "not null column with trigger that sets null should error",
SetUpScript: []string{
"CREATE TABLE t (i INT PRIMARY KEY, j INT NOT NULL);",
`
CREATE TRIGGER trig BEFORE INSERT ON t
FOR EACH ROW
BEGIN
SET new.j = null;
END;`,
},
Assertions: []ScriptTestAssertion{
{
Query: "INSERT INTO t (i) VALUES (1);",
ExpectedErr: sql.ErrInsertIntoNonNullableDefaultNullColumn,
},
{
Query: "INSERT INTO t (i, j) VALUES (1, 2);",
ExpectedErr: sql.ErrInsertIntoNonNullableProvidedNull,
},
},
},
{
Name: "not null column with before insert trigger should error",
SetUpScript: []string{
"CREATE TABLE t (i INT PRIMARY KEY, j INT NOT NULL);",
`
CREATE TRIGGER trig BEFORE INSERT ON t
FOR EACH ROW
BEGIN
SET new.i = 10 * new.i;
END;`,
},
Assertions: []ScriptTestAssertion{
{
Query: "INSERT INTO t (i) VALUES (1);",
ExpectedErr: sql.ErrInsertIntoNonNullableDefaultNullColumn,
},
{
Query: "INSERT INTO t (i, j) VALUES (1, 2);",
Expected: []sql.Row{
{types.NewOkResult(1)},
},
},
{
Query: "SELECT * FROM t;",
Expected: []sql.Row{
{10, 2},
},
},
},
},

// UPDATE triggers
{
Expand Down
30 changes: 21 additions & 9 deletions sql/analyzer/inserts.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,23 @@ func resolveInsertRows(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Sc
}

// The schema of the destination node and the underlying table differ subtly in terms of defaults
project, firstGeneratedAutoIncRowIdx, err := wrapRowSource(ctx, source, insertable, insert.Destination.Schema(), columnNames)
var deferredDefaults sql.FastIntSet
project, firstGeneratedAutoIncRowIdx, deferredDefaults, err := wrapRowSource(
ctx,
source,
insertable,
insert.Destination.Schema(),
columnNames,
)
if err != nil {
return nil, transform.SameTree, err
}

return insert.WithSource(project).WithAutoIncrementIdx(firstGeneratedAutoIncRowIdx), transform.NewTree, nil
return insert.WithSource(project).
WithAutoIncrementIdx(firstGeneratedAutoIncRowIdx).
WithDeferredDefaults(deferredDefaults),
transform.NewTree,
nil
})
}

Expand Down Expand Up @@ -117,8 +128,9 @@ func findColIdx(colName string, colNames []string) int {
// wrapRowSource returns a projection that wraps the original row source so that its schema matches the full schema of
// the underlying table in the same order. Also, returns an integer value that indicates when this row source will
// result in an automatically generated value for an auto_increment column.
func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, schema sql.Schema, columnNames []string) (sql.Node, int, error) {
func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, schema sql.Schema, columnNames []string) (sql.Node, int, sql.FastIntSet, error) {
projExprs := make([]sql.Expression, len(schema))
deferredDefaults := sql.NewFastIntSet()
firstGeneratedAutoIncRowIdx := -1

for i, col := range schema {
Expand All @@ -130,7 +142,7 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s
defaultExpr = col.Generated
}
if !col.Nullable && defaultExpr == nil && !col.AutoIncrement {
return nil, -1, sql.ErrInsertIntoNonNullableDefaultNullColumn.New(col.Name)
deferredDefaults.Add(i)
}

var err error
Expand All @@ -151,7 +163,7 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s
}
})
if err != nil {
return nil, -1, err
return nil, -1, sql.FastIntSet{}, err
}
projExprs[i] = def
} else {
Expand All @@ -163,7 +175,7 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s
// wrap it in an AutoIncrement expression.
ai, err := expression.NewAutoIncrement(ctx, destTbl, projExprs[i])
if err != nil {
return nil, -1, err
return nil, -1, sql.FastIntSet{}, err
}
projExprs[i] = ai

Expand Down Expand Up @@ -206,7 +218,7 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s
// ColumnDefaultValue to create the UUID), then update the project to include the AutoUuid expression.
newExpr, identity, err := insertAutoUuidExpression(ctx, columnDefaultValue, autoUuidCol)
if err != nil {
return nil, -1, err
return nil, -1, sql.FastIntSet{}, err
}
if identity == transform.NewTree {
projExprs[autoUuidColIdx] = newExpr
Expand All @@ -217,12 +229,12 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s
// the AutoUuid expression to it.
err := wrapAutoUuidInValuesTuples(ctx, autoUuidCol, insertSource, columnNames)
if err != nil {
return nil, -1, err
return nil, -1, sql.FastIntSet{}, err
}
}
}

return plan.NewProject(projExprs, insertSource), firstGeneratedAutoIncRowIdx, nil
return plan.NewProject(projExprs, insertSource), firstGeneratedAutoIncRowIdx, deferredDefaults, nil
}

// isZero returns true if the specified literal value |lit| has a value equal to 0.
Expand Down
11 changes: 11 additions & 0 deletions sql/plan/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ type InsertInto struct {

// FirstGenerateAutoIncRowIdx is the index of the first row inserted that increments last_insert_id.
FirstGeneratedAutoIncRowIdx int

// DeferredDefaults marks which columns in the destination schema are expected to have default values.
DeferredDefaults sql.FastIntSet
}

var _ sql.Databaser = (*InsertInto)(nil)
Expand Down Expand Up @@ -201,6 +204,14 @@ func (ii *InsertInto) WithAutoIncrementIdx(firstGeneratedAutoIncRowIdx int) *Ins
return &np
}

// WithDeferredDefaults sets the flags for the insert destination columns, which mark which of the columns are expected
// to be filled with the DEFAULT or GENERATED value.
func (ii *InsertInto) WithDeferredDefaults(deferredDefaults sql.FastIntSet) *InsertInto {
np := *ii
np.DeferredDefaults = deferredDefaults
return &np
}

// String implements the fmt.Stringer interface.
func (ii *InsertInto) String() string {
pr := sql.NewTreePrinter()
Expand Down
1 change: 1 addition & 0 deletions sql/rowexec/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ func (b *BaseBuilder) buildInsertInto(ctx *sql.Context, ii *plan.InsertInto, row
ctx: ctx,
ignore: ii.Ignore,
firstGeneratedAutoIncRowIdx: ii.FirstGeneratedAutoIncRowIdx,
deferredDefaults: ii.DeferredDefaults,
}

var ed sql.EditOpenerCloser
Expand Down
12 changes: 8 additions & 4 deletions sql/rowexec/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ type insertIter struct {
ignore bool

firstGeneratedAutoIncRowIdx int

deferredDefaults sql.FastIntSet
}

func getInsertExpressions(values sql.Node) []sql.Expression {
Expand Down Expand Up @@ -395,12 +397,14 @@ func (i *insertIter) validateNullability(ctx *sql.Context, dstSchema sql.Schema,
for count, col := range dstSchema {
if !col.Nullable && row[count] == nil {
// In the case of an IGNORE we set the nil value to a default and add a warning
if i.ignore {
row[count] = col.Type.Zero()
_ = warnOnIgnorableError(ctx, row, sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name)) // will always return nil
} else {
if !i.ignore {
if i.deferredDefaults.Contains(count) {
return sql.ErrInsertIntoNonNullableDefaultNullColumn.New(col.Name)
}
return sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name)
}
row[count] = col.Type.Zero()
_ = warnOnIgnorableError(ctx, row, sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name)) // will always return nil
}
}
return nil
Expand Down