Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 215 additions & 0 deletions enginetest/queries/script_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -1993,6 +1993,221 @@ CREATE TABLE tab3 (
},
},
},
{
Name: "last_insert_id(default) behavior",
SetUpScript: []string{
"create table t (pk int primary key auto_increment, i int default 0)",
},
Assertions: []ScriptTestAssertion{
{
Query: "insert into t(pk) values (default);",
Expected: []sql.Row{{types.OkResult{RowsAffected: 1, InsertID: 1}}},
},
{
Query: "select last_insert_id()",
Expected: []sql.Row{
{uint64(1)},
},
},
{
Query: "select * from t",
Expected: []sql.Row{
{1, 0},
},
},

{
Query: "insert into t(pk) values (default), (default), (default), (default), (default);",
Expected: []sql.Row{{types.OkResult{RowsAffected: 5, InsertID: 2}}},
},
{
Query: "select last_insert_id()",
Expected: []sql.Row{
{uint64(2)},
},
},
{
Query: "select * from t",
Expected: []sql.Row{
{1, 0},
{2, 0},
{3, 0},
{4, 0},
{5, 0},
{6, 0},
},
},

{
Query: "insert into t(pk) values (10), (default);",
Expected: []sql.Row{{types.OkResult{RowsAffected: 2, InsertID: 10}}},
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(minor) should InsertID be 11? I couldn't see where MySQL was returning this field, but it seems like it would be the same as last_insert_id? (all I saw MySQL return was the RowsAffected)

Not a big deal, and I don't think we need to change it here, since it isn't really user facing, but figured I'd mention it.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It definitely seems like it should be 11.
Fixing it breaks a bunch of tests, and I'm not sure if they are wrong or not.
Will look further into this in a future PR.

},
{
Query: "select last_insert_id()",
Expected: []sql.Row{
{uint64(11)},
},
},
{
Query: "select * from t",
Expected: []sql.Row{
{1, 0},
{2, 0},
{3, 0},
{4, 0},
{5, 0},
{6, 0},
{10, 0},
{11, 0},
},
},

{
Query: "insert into t(pk) values (20), (default), (default);",
Expected: []sql.Row{{types.OkResult{RowsAffected: 3, InsertID: 20}}},
},
{
Query: "select last_insert_id()",
Expected: []sql.Row{
{uint64(21)},
},
},
{
Query: "select * from t",
Expected: []sql.Row{
{1, 0},
{2, 0},
{3, 0},
{4, 0},
{5, 0},
{6, 0},
{10, 0},
{11, 0},
{20, 0},
{21, 0},
{22, 0},
},
},

{
Query: "insert into t(i) values (100);",
Expected: []sql.Row{{types.OkResult{RowsAffected: 1, InsertID: 23}}},
},
{
Query: "select last_insert_id()",
Expected: []sql.Row{
{uint64(23)},
},
},
{
Query: "select * from t",
Expected: []sql.Row{
{1, 0},
{2, 0},
{3, 0},
{4, 0},
{5, 0},
{6, 0},
{10, 0},
{11, 0},
{20, 0},
{21, 0},
{22, 0},
{23, 100},
},
},

{
Query: "insert into t(i, pk) values (200, default);",
Expected: []sql.Row{{types.OkResult{RowsAffected: 1, InsertID: 24}}},
},
{
Query: "select last_insert_id()",
Expected: []sql.Row{
{uint64(24)},
},
},
{
Query: "select * from t",
Expected: []sql.Row{
{1, 0},
{2, 0},
{3, 0},
{4, 0},
{5, 0},
{6, 0},
{10, 0},
{11, 0},
{20, 0},
{21, 0},
{22, 0},
{23, 100},
{24, 200},
},
},

{
Query: "insert into t(pk) values (null);",
Expected: []sql.Row{{types.OkResult{RowsAffected: 1, InsertID: 25}}},
},
{
Query: "select last_insert_id()",
Expected: []sql.Row{
{uint64(25)},
},
},
{
Query: "select * from t",
Expected: []sql.Row{
{1, 0},
{2, 0},
{3, 0},
{4, 0},
{5, 0},
{6, 0},
{10, 0},
{11, 0},
{20, 0},
{21, 0},
{22, 0},
{23, 100},
{24, 200},
{25, 0},
},
},

{
Query: "insert into t values ();",
Expected: []sql.Row{{types.OkResult{RowsAffected: 1, InsertID: 26}}},
},
{
Query: "select last_insert_id()",
Expected: []sql.Row{
{uint64(26)},
},
},
{
Query: "select * from t",
Expected: []sql.Row{
{1, 0},
{2, 0},
{3, 0},
{4, 0},
{5, 0},
{6, 0},
{10, 0},
{11, 0},
{20, 0},
{21, 0},
{22, 0},
{23, 100},
{24, 200},
{25, 0},
{26, 0},
},
},
},
},
{
Name: "row_count() behavior",
SetUpScript: []string{
Expand Down
94 changes: 61 additions & 33 deletions sql/analyzer/inserts.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/dolthub/go-mysql-server/sql/expression/function"
"github.com/dolthub/go-mysql-server/sql/plan"
"github.com/dolthub/go-mysql-server/sql/transform"
"github.com/dolthub/go-mysql-server/sql/types"
)

func resolveInsertRows(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
Expand Down Expand Up @@ -82,12 +83,12 @@ func resolveInsertRows(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Sc
}

// The schema of the destination node and the underlying table differ subtly in terms of defaults
project, autoAutoIncrement, err := wrapRowSource(ctx, source, insertable, insert.Destination.Schema(), columnNames)
project, firstGeneratedAutoIncRowIdx, err := wrapRowSource(ctx, source, insertable, insert.Destination.Schema(), columnNames)
if err != nil {
return nil, transform.SameTree, err
}

return insert.WithSource(project).WithUnspecifiedAutoIncrement(autoAutoIncrement), transform.NewTree, nil
return insert.WithSource(project).WithAutoIncrementIdx(firstGeneratedAutoIncRowIdx), transform.NewTree, nil
})
}

Expand All @@ -106,42 +107,43 @@ func existsNonZeroValueCount(values sql.Node) bool {
return false
}

func findColIdx(colName string, colNames []string) int {
for i, name := range colNames {
if strings.EqualFold(name, colName) {
return i
}
}
return -1
}

// wrapRowSource returns a projection that wraps the original row source so that its schema matches the full schema of
// the underlying table, in the same order. Also returns a boolean value that indicates whether this row source will
// the underlying table in the same order. Also, returns an integer value that indicates when this row source will
// result in an automatically generated value for an auto_increment column.
func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, schema sql.Schema, columnNames []string) (sql.Node, bool, error) {
func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, schema sql.Schema, columnNames []string) (sql.Node, int, error) {
projExprs := make([]sql.Expression, len(schema))
autoAutoIncrement := false

for i, f := range schema {
columnExplicitlySpecified := false
for j, col := range columnNames {
if strings.EqualFold(f.Name, col) {
projExprs[i] = expression.NewGetField(j, f.Type, f.Name, f.Nullable)
columnExplicitlySpecified = true
break
}
}
firstGeneratedAutoIncRowIdx := -1

if !columnExplicitlySpecified {
defaultExpr := f.Default
for i, col := range schema {
colIdx := findColIdx(col.Name, columnNames)
// if column was not explicitly specified, try to substitute with default or generated value
if colIdx == -1 {
defaultExpr := col.Default
if defaultExpr == nil {
defaultExpr = f.Generated
defaultExpr = col.Generated
}

if !f.Nullable && defaultExpr == nil && !f.AutoIncrement {
return nil, false, sql.ErrInsertIntoNonNullableDefaultNullColumn.New(f.Name)
if !col.Nullable && defaultExpr == nil && !col.AutoIncrement {
return nil, -1, sql.ErrInsertIntoNonNullableDefaultNullColumn.New(col.Name)
}
var err error

colIdx := make(map[string]int)
var err error
colNameToIdx := make(map[string]int)
for i, c := range schema {
colIdx[fmt.Sprintf("%s.%s", strings.ToLower(c.Source), strings.ToLower(c.Name))] = i
colNameToIdx[fmt.Sprintf("%s.%s", strings.ToLower(c.Source), strings.ToLower(c.Name))] = i
}
def, _, err := transform.Expr(defaultExpr, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
switch e := e.(type) {
case *expression.GetField:
idx, ok := colIdx[strings.ToLower(e.WithTable(destTbl.Name()).String())]
idx, ok := colNameToIdx[strings.ToLower(e.WithTable(destTbl.Name()).String())]
if !ok {
return nil, transform.SameTree, fmt.Errorf("field not found: %s", e.String())
}
Expand All @@ -151,20 +153,46 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s
}
})
if err != nil {
return nil, false, err
return nil, -1, err
}
projExprs[i] = def
} else {
projExprs[i] = expression.NewGetField(colIdx, col.Type, col.Name, col.Nullable)
}

if f.AutoIncrement {
if col.AutoIncrement {
// Regardless of whether the column was explicitly specified, if it is an auto increment column, we need to
// wrap it in an AutoIncrement expression.
ai, err := expression.NewAutoIncrement(ctx, destTbl, projExprs[i])
if err != nil {
return nil, false, err
return nil, -1, err
}
projExprs[i] = ai

if !columnExplicitlySpecified {
autoAutoIncrement = true
if colIdx == -1 {
// Auto increment column was not specified explicitly, so we should increment last_insert_id immediately
firstGeneratedAutoIncRowIdx = 0
} else {
// Additionally, the first NULL, DEFAULT, or empty value is what the last_insert_id should be set to.
switch src := insertSource.(type) {
case *plan.Values:
for ii, tup := range src.ExpressionTuples {
expr := tup[colIdx]
if unwrap, ok := expr.(*expression.Wrapper); ok {
expr = unwrap.Unwrap()
}
if _, isDef := expr.(*sql.ColumnDefaultValue); isDef {
firstGeneratedAutoIncRowIdx = ii
break
}
if lit, isLit := expr.(*expression.Literal); isLit {
if types.Null.Equals(lit.Type()) {
firstGeneratedAutoIncRowIdx = ii
break
}
}
}
}
}
}
}
Expand All @@ -177,7 +205,7 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s
// ColumnDefaultValue to create the UUID), then update the project to include the AutoUuid expression.
newExpr, identity, err := insertAutoUuidExpression(ctx, columnDefaultValue, autoUuidCol)
if err != nil {
return nil, false, err
return nil, -1, err
}
if identity == transform.NewTree {
projExprs[autoUuidColIdx] = newExpr
Expand All @@ -188,12 +216,12 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s
// the AutoUuid expression to it.
err := wrapAutoUuidInValuesTuples(ctx, autoUuidCol, insertSource, columnNames)
if err != nil {
return nil, false, err
return nil, -1, err
}
}
}

return plan.NewProject(projExprs, insertSource), autoAutoIncrement, nil
return plan.NewProject(projExprs, insertSource), firstGeneratedAutoIncRowIdx, nil
}

// insertAutoUuidExpression transforms the specified |expr| for |autoUuidCol| and inserts an AutoUuid
Expand Down
1 change: 0 additions & 1 deletion sql/base_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,6 @@ func (s *BaseSession) SetLastQueryInfoInt(key string, value int64) {
}

func (s *BaseSession) GetLastQueryInfoInt(key string) int64 {

value, ok := s.lastQueryInfo[key].Load().(int64)
if !ok {
panic(fmt.Sprintf("last query info value stored for %s is not an int64 value, but a %T", key, s.lastQueryInfo[key]))
Expand Down
Loading