diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index d9bd4704d3..bc3f1ce42b 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -1202,11 +1202,11 @@ CREATE TABLE tab3 ( Assertions: []ScriptTestAssertion{ { Query: `SELECT IS_UUID(UUID())`, - Expected: []sql.Row{{int8(1)}}, + Expected: []sql.Row{{true}}, }, { Query: `SELECT IS_UUID(@uuid)`, - Expected: []sql.Row{{int8(1)}}, + Expected: []sql.Row{{true}}, }, { Query: `SELECT BIN_TO_UUID(UUID_TO_BIN(@uuid))`, @@ -1354,6 +1354,278 @@ CREATE TABLE tab3 ( }, }, }, + { + Name: "last_insert_uuid() behavior", + SetUpScript: []string{ + "create table varchar36 (pk varchar(36) primary key default (UUID()), i int);", + "create table char36 (pk char(36) primary key default (UUID()), i int);", + "create table varbinary16 (pk varbinary(16) primary key default (UUID_to_bin(UUID())), i int);", + "create table binary16 (pk binary(16) primary key default (UUID_to_bin(UUID())), i int);", + "create table binary16swap (pk binary(16) primary key default (UUID_to_bin(UUID(), true)), i int);", + "create table invalid (pk int primary key, c1 varchar(36) default (UUID()));", + "create table prepared (uuid char(36) default (UUID()), ai int auto_increment, c1 varchar(100), primary key (uuid, ai));", + }, + Assertions: []ScriptTestAssertion{ + // The initial value of last_insert_uuid() is an empty string + { + Query: "select last_insert_uuid()", + Expected: []sql.Row{{""}}, + }, + + // invalid table – UUID default is not a primary key, so last_insert_uuid() doesn't get udpated + { + Query: "insert into invalid values (1, DEFAULT);", + Expected: []sql.Row{{types.OkResult{RowsAffected: 1}}}, + }, + { + Query: "select last_insert_uuid()", + Expected: []sql.Row{{""}}, + }, + { + Query: "insert into invalid values (2, UUID());", + Expected: []sql.Row{{types.OkResult{RowsAffected: 1}}}, + }, + { + Query: "select last_insert_uuid()", + Expected: []sql.Row{{""}}, + }, + + // varchar(36) test cases... + { + Query: "insert into varchar36 values (DEFAULT, 1);", + Expected: []sql.Row{{types.OkResult{RowsAffected: 1}}}, + }, + { + Query: "select is_uuid(last_insert_uuid()), last_insert_uuid() = (select pk from varchar36 where i=1);", + Expected: []sql.Row{{true, true}}, + }, + { + Query: "insert into varchar36 values (UUID(), 2), (UUID(), 3);", + Expected: []sql.Row{{types.OkResult{RowsAffected: 2}}}, + }, + { + // last_insert_uuid() reports the first UUID() generated in the last insert statement + Query: "select is_uuid(last_insert_uuid()), last_insert_uuid() = (select pk from varchar36 where i=2);", + Expected: []sql.Row{{true, true}}, + }, + { + Query: "insert into varchar36 values ('notta-uuid', 4);", + Expected: []sql.Row{{types.OkResult{RowsAffected: 1}}}, + }, + { + // The previous insert didn't generate a UUID, so last_insert_uuid() doesn't get updated + Query: "select is_uuid(last_insert_uuid()), last_insert_uuid() = (select pk from varchar36 where i=2);", + Expected: []sql.Row{{true, true}}, + }, + + // char(36) test cases... + { + Query: "insert into char36 values (DEFAULT, 1);", + Expected: []sql.Row{{types.OkResult{RowsAffected: 1}}}, + }, + { + Query: "select is_uuid(last_insert_uuid()), last_insert_uuid() = (select pk from char36 where i=1);", + Expected: []sql.Row{{true, true}}, + }, + { + Query: "insert into char36 values (UUID(), 2), (UUID(), 3);", + Expected: []sql.Row{{types.OkResult{RowsAffected: 2}}}, + }, + { + // last_insert_uuid() reports the first UUID() generated in the last insert statement + Query: "select is_uuid(last_insert_uuid()), last_insert_uuid() = (select pk from char36 where i=2);", + Expected: []sql.Row{{true, true}}, + }, + { + Query: "insert into char36 values ('notta-uuid', 4);", + Expected: []sql.Row{{types.OkResult{RowsAffected: 1}}}, + }, + { + // The previous insert didn't generate a UUID, so last_insert_uuid() doesn't get updated + Query: "select is_uuid(last_insert_uuid()), last_insert_uuid() = (select pk from char36 where i=2);", + Expected: []sql.Row{{true, true}}, + }, + { + Query: "insert into char36 (i) values (5);", + Expected: []sql.Row{{types.OkResult{RowsAffected: 1}}}, + }, + { + Query: "select is_uuid(last_insert_uuid()), last_insert_uuid() = (select pk from char36 where i=5);", + Expected: []sql.Row{{true, true}}, + }, + + // varbinary(16) test cases... + { + Query: "insert into varbinary16 values (DEFAULT, 1);", + Expected: []sql.Row{{types.OkResult{RowsAffected: 1}}}, + }, + { + Query: "select is_uuid(last_insert_uuid()), last_insert_uuid() = (select bin_to_uuid(pk) from varbinary16 where i=1);", + Expected: []sql.Row{{true, true}}, + }, + { + Query: "insert into varbinary16 values (UUID_to_bin(UUID()), 2), (UUID_to_bin(UUID()), 3);", + Expected: []sql.Row{{types.OkResult{RowsAffected: 2}}}, + }, + { + // last_insert_uuid() reports the first UUID() generated in the last insert statement + Query: "select is_uuid(last_insert_uuid()), last_insert_uuid() = (select bin_to_uuid(pk) from varbinary16 where i=2);", + Expected: []sql.Row{{true, true}}, + }, + { + Query: "insert into varbinary16 values ('notta-uuid', 4);", + Expected: []sql.Row{{types.OkResult{RowsAffected: 1}}}, + }, + { + // The previous insert didn't generate a UUID, so last_insert_uuid() doesn't get updated + Query: "select is_uuid(last_insert_uuid()), last_insert_uuid() = (select bin_to_uuid(pk) from varbinary16 where i=2);", + Expected: []sql.Row{{true, true}}, + }, + + // binary(16) test cases... + { + Query: "insert into binary16 values (DEFAULT, 1);", + Expected: []sql.Row{{types.OkResult{RowsAffected: 1}}}, + }, + { + Query: "select is_uuid(last_insert_uuid()), last_insert_uuid() = (select bin_to_uuid(pk) from binary16 where i=1);", + Expected: []sql.Row{{true, true}}, + }, + { + Query: "insert into binary16 values (UUID_to_bin(UUID()), 2), (UUID_to_bin(UUID()), 3);", + Expected: []sql.Row{{types.OkResult{RowsAffected: 2}}}, + }, + { + // last_insert_uuid() reports the first UUID() generated in the last insert statement + Query: "select is_uuid(last_insert_uuid()), last_insert_uuid() = (select bin_to_uuid(pk) from binary16 where i=2);", + Expected: []sql.Row{{true, true}}, + }, + { + Query: "insert into binary16 values ('notta-uuid', 4);", + Expected: []sql.Row{{types.OkResult{RowsAffected: 1}}}, + }, + { + // The previous insert didn't generate a UUID, so last_insert_uuid() doesn't get updated + Query: "select is_uuid(last_insert_uuid()), last_insert_uuid() = (select bin_to_uuid(pk) from binary16 where i=2);", + Expected: []sql.Row{{true, true}}, + }, + { + Query: "insert into binary16 (i) values (5);", + Expected: []sql.Row{{types.OkResult{RowsAffected: 1}}}, + }, + { + Query: "select is_uuid(last_insert_uuid()), last_insert_uuid() = (select bin_to_uuid(pk) from binary16 where i=5);", + Expected: []sql.Row{{true, true}}, + }, + + // binary(16) with UUID_to_bin swap test cases... + { + Query: "insert into binary16swap values (DEFAULT, 1);", + Expected: []sql.Row{{types.OkResult{RowsAffected: 1}}}, + }, + { + Query: "select is_uuid(last_insert_uuid()), last_insert_uuid() = (select bin_to_uuid(pk, true) from binary16swap where i=1);", + Expected: []sql.Row{{true, true}}, + }, + { + Query: "insert into binary16swap values (UUID_to_bin(UUID(), true), 2), (UUID_to_bin(UUID(), true), 3);", + Expected: []sql.Row{{types.OkResult{RowsAffected: 2}}}, + }, + { + // last_insert_uuid() reports the first UUID() generated in the last insert statement + Query: "select is_uuid(last_insert_uuid()), last_insert_uuid() = (select bin_to_uuid(pk, true) from binary16swap where i=2);", + Expected: []sql.Row{{true, true}}, + }, + { + Query: "insert into binary16swap values ('notta-uuid', 4);", + Expected: []sql.Row{{types.OkResult{RowsAffected: 1}}}, + }, + { + // The previous insert didn't generate a UUID, so last_insert_uuid() doesn't get updated + Query: "select is_uuid(last_insert_uuid()), last_insert_uuid() = (select bin_to_uuid(pk, true) from binary16swap where i=2);", + Expected: []sql.Row{{true, true}}, + }, + { + Query: "insert into binary16swap (i) values (5);", + Expected: []sql.Row{{types.OkResult{RowsAffected: 1}}}, + }, + { + Query: "select is_uuid(last_insert_uuid()), last_insert_uuid() = (select bin_to_uuid(pk, true) from binary16swap where i=5);", + Expected: []sql.Row{{true, true}}, + }, + + // INSERT INTO ... SELECT ... Tests + { + // If we populate the UUID column (pk) with its implicit default, then it updates last_insert_uuid() + Query: "insert into varchar36 (i) select 42 from dual;", + Expected: []sql.Row{{types.OkResult{RowsAffected: 1}}}, + }, + { + Query: "select is_uuid(last_insert_uuid()), last_insert_uuid() = (select pk from varchar36 where i=42);", + Expected: []sql.Row{{true, true}}, + }, + { + // If all values come from another table, the auto_uuid value shouldn't be generated, so last_insert_uuid() doesn't change + Query: "insert into varchar36 (pk, i) (select 'one', 101 from dual union all select 'two', 202);", + Expected: []sql.Row{{types.OkResult{RowsAffected: 2}}}, + }, + { + Query: "select is_uuid(last_insert_uuid()), last_insert_uuid() = (select pk from varchar36 where i=42);", + Expected: []sql.Row{{true, true}}, + }, + + // Prepared statements + { + // Test with an insert statement that implicit uses the UUID column default + Query: `prepare stmt1 from "insert into prepared (c1) values ('odd'), ('even')";`, + Expected: []sql.Row{{types.OkResult{Info: plan.PrepareInfo{}}}}, + }, + { + Query: "execute stmt1;", + Expected: []sql.Row{{types.OkResult{RowsAffected: 2, InsertID: 1}}}, + SkipResultCheckOnServerEngine: true, // Server engine returns []sql.Row{} + }, + { + Query: "select is_uuid(last_insert_uuid()), last_insert_uuid() = (select uuid from prepared where ai=1), last_insert_id();", + Expected: []sql.Row{{true, true, uint64(1)}}, + }, + { + // Executing the prepared statement a second time should refresh last_insert_uuid() + Query: "execute stmt1;", + Expected: []sql.Row{{types.OkResult{RowsAffected: 2, InsertID: 3}}}, + SkipResultCheckOnServerEngine: true, // Server engine returns []sql.Row{} + }, + { + Query: "select is_uuid(last_insert_uuid()), last_insert_uuid() = (select uuid from prepared where ai=3), last_insert_id();", + Expected: []sql.Row{{true, true, uint64(3)}}, + }, + + { + // Test with an insert statement that explicitly uses the UUID column default + Query: `prepare stmt2 from "insert into prepared (uuid, c1) values (DEFAULT, 'more'), (DEFAULT, 'less')";`, + Expected: []sql.Row{{types.OkResult{Info: plan.PrepareInfo{}}}}, + }, + { + Query: "execute stmt2;", + Expected: []sql.Row{{types.OkResult{RowsAffected: 2, InsertID: 5}}}, + SkipResultCheckOnServerEngine: true, // Server engine returns []sql.Row{} + }, + { + Query: "select is_uuid(last_insert_uuid()), last_insert_uuid() = (select uuid from prepared where ai=5), last_insert_id();", + Expected: []sql.Row{{true, true, uint64(5)}}, + }, + { + // Executing the prepared statement a second time should refresh last_insert_uuid() + Query: "execute stmt2;", + Expected: []sql.Row{{types.OkResult{RowsAffected: 2, InsertID: 7}}}, + SkipResultCheckOnServerEngine: true, // Server engine returns []sql.Row{} + }, + { + Query: "select is_uuid(last_insert_uuid()), last_insert_uuid() = (select uuid from prepared where ai=7), last_insert_id();", + Expected: []sql.Row{{true, true, uint64(7)}}, + }, + }, + }, { Name: "last_insert_id() behavior", SetUpScript: []string{ diff --git a/sql/analyzer/inserts.go b/sql/analyzer/inserts.go index 48275761c9..b48e97cf1a 100644 --- a/sql/analyzer/inserts.go +++ b/sql/analyzer/inserts.go @@ -18,8 +18,11 @@ import ( "fmt" "strings" + "github.com/dolthub/vitess/go/sqltypes" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/expression/function" "github.com/dolthub/go-mysql-server/sql/plan" "github.com/dolthub/go-mysql-server/sql/transform" "github.com/dolthub/go-mysql-server/sql/types" @@ -167,6 +170,30 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s } } + // Handle auto UUID columns + autoUuidCol, autoUuidColIdx := findAutoUuidColumn(ctx, schema) + if autoUuidCol != nil { + if columnDefaultValue, ok := projExprs[autoUuidColIdx].(*sql.ColumnDefaultValue); ok { + // If the auto UUID column is being populated through the projection (i.e. it's projecting a + // ColumnDefaultValue to create the UUID), then update the project to include the AutoUuid expression. + newExpr, identity, err := insertAutoUuidExpression(ctx, columnDefaultValue, autoUuidCol) + if err != nil { + return nil, false, err + } + if identity == transform.NewTree { + projExprs[autoUuidColIdx] = newExpr + } + } else { + // Otherwise, if the auto UUID column is not getting populated through the projection, then we + // need to look through the tuples to look for the first DEFAULT or UUID() expression and apply + // the AutoUuid expression to it. + err := wrapAutoUuidInValuesTuples(ctx, autoUuidCol, insertSource, columnNames) + if err != nil { + return nil, false, err + } + } + } + err := validateRowSource(insertSource, projExprs) if err != nil { return nil, false, err @@ -175,6 +202,114 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s return plan.NewProject(projExprs, insertSource), autoAutoIncrement, nil } +// insertAutoUuidExpression transforms the specified |expr| for |autoUuidCol| and inserts an AutoUuid +// expression above the UUID() function call, so that the auto generated UUID value can be captured and +// saved to the session's query info. +func insertAutoUuidExpression(ctx *sql.Context, expr sql.Expression, autoUuidCol *sql.Column) (sql.Expression, transform.TreeIdentity, error) { + return transform.Expr(expr, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { + switch e := e.(type) { + case *function.UUIDFunc: + return expression.NewAutoUuid(ctx, autoUuidCol, e), transform.NewTree, nil + default: + return e, transform.SameTree, nil + } + }) +} + +// findAutoUuidColumn searches the specified |schema| for a column that meets the requirements of an auto UUID +// column, and if found, returns the column, as well as its index in the schema. See isAutoUuidColumn() for the +// requirements on what is considered an auto UUID column. +func findAutoUuidColumn(_ *sql.Context, schema sql.Schema) (autoUuidCol *sql.Column, autoUuidColIdx int) { + for i, col := range schema { + if isAutoUuidColumn(col) { + return col, i + } + } + + return nil, -1 +} + +// wrapAutoUuidInValuesTuples searches the tuples in the |insertSource| (if it is a *plan.Values) for the first +// tuple using a DEFAULT() or a UUID() function expression for the |autoUuidCol|, and wraps the UUID() function +// in an AutoUuid expression so that the generated UUID value can be captured and saved to the session's query info. +// After finding a first occurrence, this function returns, since only the first generated UUID needs to be saved. +// The caller must provide the |columnNames| for the insertSource so that this function can identify the index +// in the value tuples for the auto UUID column. +func wrapAutoUuidInValuesTuples(ctx *sql.Context, autoUuidCol *sql.Column, insertSource sql.Node, columnNames []string) error { + values, ok := insertSource.(*plan.Values) + if !ok { + // If the insert source isn't value tuples, then we don't need to do anything + return nil + } + + // Search the column names in the Values tuples to find the right tuple index + autoUuidColTupleIdx := -1 + for i, columnName := range columnNames { + if strings.ToLower(autoUuidCol.Name) == strings.ToLower(columnName) { + autoUuidColTupleIdx = i + } + } + if autoUuidColTupleIdx == -1 { + return nil + } + + for _, tuple := range values.ExpressionTuples { + expr := tuple[autoUuidColTupleIdx] + if wrapper, ok := expr.(*expression.Wrapper); ok { + expr = wrapper.Unwrap() + } + + switch expr.(type) { + case *sql.ColumnDefaultValue, *function.UUIDFunc, *function.UUIDToBin: + // Only ColumnDefaultValue, UUIDFunc, and UUIDToBin are valid to use in an auto UUID column + newExpr, identity, err := insertAutoUuidExpression(ctx, expr, autoUuidCol) + if err != nil { + return err + } + if identity == transform.NewTree { + tuple[autoUuidColTupleIdx] = newExpr + return nil + } + } + } + + return nil +} + +// isAutoUuidColumn returns true if the specified |col| meets the requirements of an auto generated UUID column. To +// be an auto UUID column, the column must be part of the primary key (it may be a composite primary key), and the +// type must be either varchar(36), char(36), varbinary(16), or binary(16). It must have a default value set to +// populate a UUID, either through the UUID() function (for char and varchar columns) or the UUID_TO_BIN(UUID()) +// function (for binary and varbinary columns). +func isAutoUuidColumn(col *sql.Column) bool { + if col.PrimaryKey == false { + return false + } + + switch col.Type.Type() { + case sqltypes.Char, sqltypes.VarChar: + stringType := col.Type.(sql.StringType) + if stringType.MaxCharacterLength() != 36 || col.Default == nil { + return false + } + if _, ok := col.Default.Expr.(*function.UUIDFunc); ok { + return true + } + case sqltypes.Binary, sqltypes.VarBinary: + stringType := col.Type.(sql.StringType) + if stringType.MaxByteLength() != 16 || col.Default == nil { + return false + } + if uuidToBinFunc, ok := col.Default.Expr.(*function.UUIDToBin); ok { + if _, ok := uuidToBinFunc.Children()[0].(*function.UUIDFunc); ok { + return true + } + } + } + + return false +} + // validGeneratedColumnValue returns true if the column is a generated column and the source node is not a values node. // Explicit default values (`DEFAULT`) are the only valid values to specify for a generated column func validGeneratedColumnValue(idx int, source sql.Node) bool { @@ -204,6 +339,7 @@ func assertCompatibleSchemas(projExprs []sql.Expression, schema sql.Schema) erro switch e := expr.(type) { case *expression.Literal, *expression.AutoIncrement, + *expression.AutoUuid, *sql.ColumnDefaultValue: continue case *expression.GetField: diff --git a/sql/base_session.go b/sql/base_session.go index bfcbd26dff..467900adde 100644 --- a/sql/base_session.go +++ b/sql/base_session.go @@ -15,6 +15,7 @@ package sql import ( + "fmt" "strings" "sync" "sync/atomic" @@ -46,7 +47,7 @@ type BaseSession struct { warncnt uint16 locks map[string]bool queriedDb string - lastQueryInfo map[string]int64 + lastQueryInfo map[string]any tx Transaction ignoreAutocommit bool @@ -437,16 +438,38 @@ func (s *BaseSession) SetViewRegistry(reg *ViewRegistry) { s.viewReg = reg } -func (s *BaseSession) SetLastQueryInfo(key string, value int64) { +func (s *BaseSession) SetLastQueryInfoInt(key string, value int64) { s.mu.Lock() defer s.mu.Unlock() s.lastQueryInfo[key] = value } -func (s *BaseSession) GetLastQueryInfo(key string) int64 { +func (s *BaseSession) GetLastQueryInfoInt(key string) int64 { s.mu.RLock() defer s.mu.RUnlock() - return s.lastQueryInfo[key] + + value, ok := s.lastQueryInfo[key].(int64) + if !ok { + panic(fmt.Sprintf("last query info value stored for %s is not an int64 value, but a %T", key, s.lastQueryInfo[key])) + } + return value +} + +func (s *BaseSession) SetLastQueryInfoString(key string, value string) { + s.mu.Lock() + defer s.mu.Unlock() + s.lastQueryInfo[key] = value +} + +func (s *BaseSession) GetLastQueryInfoString(key string) string { + s.mu.RLock() + defer s.mu.RUnlock() + + value, ok := s.lastQueryInfo[key].(string) + if !ok { + panic(fmt.Sprintf("last query info value stored for %s is not a string value, but a %T", key, s.lastQueryInfo[key])) + } + return value } func (s *BaseSession) GetTransaction() Transaction { diff --git a/sql/expression/auto_uuid.go b/sql/expression/auto_uuid.go new file mode 100644 index 0000000000..669985b8b9 --- /dev/null +++ b/sql/expression/auto_uuid.go @@ -0,0 +1,106 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "fmt" + + "github.com/dolthub/vitess/go/sqltypes" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" +) + +// AutoUuid is an expression that captures an automatically generated UUID value and stores it in the session for +// later retrieval. AutoUuid is intended to only be used directly on top of a UUID function. +type AutoUuid struct { + UnaryExpression + uuidCol *sql.Column + foundUuid bool +} + +var _ sql.Expression = (*AutoUuid)(nil) +var _ sql.CollationCoercible = (*AutoUuid)(nil) + +// NewAutoUuid creates a new AutoUuid expression. The |child| expression must be a function.UUIDFunc, but +// because of package import cycles, we can't enforce that directly here. +func NewAutoUuid(_ *sql.Context, col *sql.Column, child sql.Expression) *AutoUuid { + return &AutoUuid{ + UnaryExpression: UnaryExpression{Child: child}, + uuidCol: col, + } +} + +// IsNullable implements the Expression interface. +func (au *AutoUuid) IsNullable() bool { + return false +} + +// Type implements the Expression interface. +func (au *AutoUuid) Type() sql.Type { + return types.MustCreateString(sqltypes.Char, 36, sql.Collation_Default) +} + +// CollationCoercibility implements the interface sql.CollationCoercible. +func (au *AutoUuid) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { + return sql.GetCoercibility(ctx, au.Child) +} + +// Eval implements the Expression interface. +func (au *AutoUuid) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + childResult, err := au.Child.Eval(ctx, row) + if err != nil { + return nil, err + } + + if !au.foundUuid { + uuidValue, ok := childResult.(string) + if !ok { + // This should never happen – AutoUuid should only ever be placed directly above a UUID function, + // so the result from eval'ing its child should *always* be a string. + return nil, fmt.Errorf("unexpected type for UUID value: %T", childResult) + } + + // TODO: Setting this here means that another call to last_insert_uuid() in the same statement could + // read this value too early. We should verify this isn't how MySQL behaves, and then could fix + // by setting a PENDING_LAST_INSERT_UUID value in the session query info, then moving it to + // LAST_INSERT_UUID in the session query info at the end of execution. + ctx.Session.SetLastQueryInfoString(sql.LastInsertUuid, uuidValue) + au.foundUuid = true + } + + return childResult, nil +} + +func (au *AutoUuid) String() string { + return fmt.Sprintf("AutoUuid(%s)", au.Child.String()) +} + +// WithChildren implements the Expression interface. +func (au *AutoUuid) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(au, len(children), 1) + } + return &AutoUuid{ + UnaryExpression: UnaryExpression{Child: children[0]}, + uuidCol: au.uuidCol, + foundUuid: au.foundUuid, + }, nil +} + +// Children implements the Expression interface. +func (au *AutoUuid) Children() []sql.Expression { + return []sql.Expression{au.Child} +} diff --git a/sql/expression/function/queryinfo.go b/sql/expression/function/queryinfo.go index 47f9e7cd2e..f53e80f836 100644 --- a/sql/expression/function/queryinfo.go +++ b/sql/expression/function/queryinfo.go @@ -1,8 +1,24 @@ +// Copyright 2021-2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package function import ( "fmt" + "github.com/dolthub/vitess/go/sqltypes" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/types" @@ -54,7 +70,7 @@ func (r RowCount) IsNullable() bool { // Eval implements sql.Expression func (r RowCount) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - return ctx.GetLastQueryInfo(sql.RowCount), nil + return ctx.GetLastQueryInfoInt(sql.RowCount), nil } // Children implements sql.Expression @@ -72,7 +88,69 @@ func (r RowCount) FunctionName() string { return "row_count" } +// LastInsertUuid implements the LAST_INSERT_UUID() function. This function is +// NOT a standard function in MySQL, but is a useful analogue to LAST_INSERT_ID() +// if customers are inserting UUIDs into a table. +type LastInsertUuid struct{} + +var _ sql.FunctionExpression = LastInsertUuid{} +var _ sql.CollationCoercible = LastInsertUuid{} + +func NewLastInsertUuid(children ...sql.Expression) (sql.Expression, error) { + if len(children) > 0 { + return nil, sql.ErrInvalidChildrenNumber.New(LastInsertUuid{}.String(), len(children), 0) + } + + return &LastInsertUuid{}, nil +} + +func (l LastInsertUuid) CollationCoercibility(_ *sql.Context) (collation sql.CollationID, coercibility byte) { + return sql.Collation_binary, 5 +} + +func (l LastInsertUuid) Resolved() bool { + return true +} + +func (l LastInsertUuid) String() string { + return fmt.Sprintf("%s()", l.FunctionName()) +} + +func (l LastInsertUuid) Type() sql.Type { + return types.MustCreateStringWithDefaults(sqltypes.VarChar, 36) +} + +func (l LastInsertUuid) IsNullable() bool { + return false +} + +func (l LastInsertUuid) Eval(ctx *sql.Context, _ sql.Row) (interface{}, error) { + lastInsertUuid := ctx.GetLastQueryInfoString(sql.LastInsertUuid) + result, _, err := l.Type().Convert(lastInsertUuid) + if err != nil { + return nil, err + } + return result, nil +} + +func (l LastInsertUuid) Children() []sql.Expression { + return nil +} + +func (l LastInsertUuid) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewLastInsertUuid(children...) +} + +func (l LastInsertUuid) FunctionName() string { + return "last_insert_uuid" +} + +func (l LastInsertUuid) Description() string { + return "returns the first value of the UUID() function from the last INSERT statement." +} + // LastInsertId implements the LAST_INSERT_ID() function +// https://dev.mysql.com/doc/refman/8.0/en/information-functions.html#function_last-insert-id type LastInsertId struct { expression.UnaryExpression } @@ -129,7 +207,7 @@ func (r LastInsertId) IsNullable() bool { func (r LastInsertId) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // With no arguments, just return the last insert id for this session if len(r.Children()) == 0 { - lastInsertId := ctx.GetLastQueryInfo(sql.LastInsertId) + lastInsertId := ctx.GetLastQueryInfoInt(sql.LastInsertId) unsigned, _, err := types.Uint64.Convert(lastInsertId) if err != nil { return nil, err @@ -147,7 +225,7 @@ func (r LastInsertId) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - ctx.SetLastQueryInfo(sql.LastInsertId, id.(int64)) + ctx.SetLastQueryInfoInt(sql.LastInsertId, id.(int64)) return id, nil } @@ -220,7 +298,7 @@ func (r FoundRows) IsNullable() bool { // Eval implements sql.Expression func (r FoundRows) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - return ctx.GetLastQueryInfo(sql.FoundRows), nil + return ctx.GetLastQueryInfoInt(sql.FoundRows), nil } // Children implements sql.Expression diff --git a/sql/expression/function/registry.go b/sql/expression/function/registry.go index 0f98a7be6b..5420833f73 100644 --- a/sql/expression/function/registry.go +++ b/sql/expression/function/registry.go @@ -147,6 +147,7 @@ var BuiltIns = []sql.Function{ sql.FunctionN{Name: "lag", Fn: func(e ...sql.Expression) (sql.Expression, error) { return window.NewLag(e...) }}, sql.Function1{Name: "last", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewLast(e) }}, sql.FunctionN{Name: "last_insert_id", Fn: NewLastInsertId}, + sql.FunctionN{Name: "last_insert_uuid", Fn: NewLastInsertUuid}, sql.Function1{Name: "lcase", Fn: NewLower}, sql.FunctionN{Name: "lead", Fn: func(e ...sql.Expression) (sql.Expression, error) { return window.NewLead(e...) }}, sql.FunctionN{Name: "least", Fn: NewLeast}, diff --git a/sql/expression/function/uuid.go b/sql/expression/function/uuid.go index 915141b339..a59bff0eb0 100644 --- a/sql/expression/function/uuid.go +++ b/sql/expression/function/uuid.go @@ -54,15 +54,11 @@ import ( type UUIDFunc struct{} -func (u UUIDFunc) IsNonDeterministic() bool { - return true -} - var _ sql.FunctionExpression = &UUIDFunc{} var _ sql.CollationCoercible = &UUIDFunc{} func NewUUIDFunc() sql.Expression { - return UUIDFunc{} + return &UUIDFunc{} } // Description implements sql.FunctionExpression @@ -92,7 +88,7 @@ func (u UUIDFunc) WithChildren(children ...sql.Expression) (sql.Expression, erro return nil, sql.ErrInvalidChildrenNumber.New(u, len(children), 0) } - return UUIDFunc{}, nil + return &UUIDFunc{}, nil } func (u UUIDFunc) FunctionName() string { @@ -113,6 +109,10 @@ func (u UUIDFunc) IsNullable() bool { return false } +func (u UUIDFunc) IsNonDeterministic() bool { + return true +} + // IS_UUID(string_uuid) // // Returns 1 if the argument is a valid string-format UUID, 0 if the argument is not a valid UUID, and NULL if the @@ -129,7 +129,7 @@ var _ sql.FunctionExpression = &IsUUID{} var _ sql.CollationCoercible = &IsUUID{} func NewIsUUID(arg sql.Expression) sql.Expression { - return IsUUID{child: arg} + return &IsUUID{child: arg} } // FunctionName implements sql.FunctionExpression @@ -147,7 +147,7 @@ func (u IsUUID) String() string { } func (u IsUUID) Type() sql.Type { - return types.Int8 + return types.Boolean } // CollationCoercibility implements the interface sql.CollationCoercible. @@ -158,7 +158,7 @@ func (IsUUID) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID func (u IsUUID) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { str, err := u.child.Eval(ctx, row) if err != nil { - return 0, err + return nil, err } if str == nil { @@ -169,19 +169,19 @@ func (u IsUUID) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { case string: _, err := uuid.Parse(str) if err != nil { - return int8(0), nil + return false, nil } - return int8(1), nil + return true, nil case []byte: _, err := uuid.ParseBytes(str) if err != nil { - return int8(0), nil + return false, nil } - return int8(1), nil + return true, nil default: - return int8(0), nil + return false, nil } } @@ -190,7 +190,7 @@ func (u IsUUID) WithChildren(children ...sql.Expression) (sql.Expression, error) return nil, sql.ErrInvalidChildrenNumber.New(u, len(children), 1) } - return IsUUID{child: children[0]}, nil + return &IsUUID{child: children[0]}, nil } func (u IsUUID) Resolved() bool { @@ -241,9 +241,9 @@ var _ sql.CollationCoercible = (*UUIDToBin)(nil) func NewUUIDToBin(args ...sql.Expression) (sql.Expression, error) { switch len(args) { case 1: - return UUIDToBin{inputUUID: args[0]}, nil + return &UUIDToBin{inputUUID: args[0]}, nil case 2: - return UUIDToBin{inputUUID: args[0], swapFlag: args[1]}, nil + return &UUIDToBin{inputUUID: args[0], swapFlag: args[1]}, nil default: return nil, sql.ErrInvalidArgumentNumber.New("UUID_TO_BIN", "1 or 2", len(args)) } @@ -271,7 +271,7 @@ func (UUIDToBin) CollationCoercibility(ctx *sql.Context) (collation sql.Collatio return sql.Collation_binary, 4 } -func (ub UUIDToBin) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { +func (ub *UUIDToBin) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { str, err := ub.inputUUID.Eval(ctx, row) if err != nil { return 0, err @@ -304,7 +304,7 @@ func (ub UUIDToBin) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { if err != nil { return nil, err } - return string(bt), nil + return bt, nil } sf, err := ub.swapFlag.Eval(ctx, row) @@ -324,10 +324,10 @@ func (ub UUIDToBin) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - return string(bt), nil + return bt, nil } else if sf.(int8) == 1 { encoding := swapUUIDBytes(parsed) - return string(encoding), nil + return encoding, nil } else { return nil, fmt.Errorf("UUID_TO_BIN received invalid swap flag") } @@ -402,9 +402,9 @@ var _ sql.CollationCoercible = (*BinToUUID)(nil) func NewBinToUUID(args ...sql.Expression) (sql.Expression, error) { switch len(args) { case 1: - return BinToUUID{inputBinary: args[0]}, nil + return &BinToUUID{inputBinary: args[0]}, nil case 2: - return BinToUUID{inputBinary: args[0], swapFlag: args[1]}, nil + return &BinToUUID{inputBinary: args[0], swapFlag: args[1]}, nil default: return nil, sql.ErrInvalidArgumentNumber.New("BIN_TO_UUID", "1 or 2", len(args)) } diff --git a/sql/expression/function/uuid_test.go b/sql/expression/function/uuid_test.go index 7f9da8aa6b..bce11aea53 100644 --- a/sql/expression/function/uuid_test.go +++ b/sql/expression/function/uuid_test.go @@ -41,7 +41,7 @@ func TestUUID(t *testing.T) { // validate that generated uuid is legitimate for IsUUID val := NewIsUUID(uuidE) - require.Equal(t, int8(1), eval(t, val, sql.Row{nil})) + require.Equal(t, true, eval(t, val, sql.Row{nil})) // Use a UUID regex as a sanity check re2 := regexp.MustCompile(`\b[0-9a-f]{8}\b-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-\b[0-9a-f]{12}\b`) @@ -55,14 +55,14 @@ func TestIsUUID(t *testing.T) { value interface{} expected interface{} }{ - {"uuid form 1", types.LongText, "{12345678-1234-5678-1234-567812345678}", int8(1)}, - {"uuid form 2", types.LongText, "12345678123456781234567812345678", int8(1)}, - {"uuid form 3", types.LongText, "12345678-1234-5678-1234-567812345678", int8(1)}, + {"uuid form 1", types.LongText, "{12345678-1234-5678-1234-567812345678}", true}, + {"uuid form 2", types.LongText, "12345678123456781234567812345678", true}, + {"uuid form 3", types.LongText, "12345678-1234-5678-1234-567812345678", true}, {"NULL", types.Null, nil, nil}, - {"random int", types.Int8, 1, int8(0)}, - {"random bool", types.Boolean, false, int8(0)}, - {"random string", types.LongText, "12345678-dasd-fasdf8", int8(0)}, - {"swapped uuid", types.LongText, "5678-1234-12345678-1234-567812345678", int8(0)}, + {"random int", types.Int8, 1, false}, + {"random bool", types.Boolean, false, false}, + {"random string", types.LongText, "12345678-dasd-fasdf8", false}, + {"swapped uuid", types.LongText, "5678-1234-12345678-1234-567812345678", false}, } for _, tt := range testCases { diff --git a/sql/plan/process.go b/sql/plan/process.go index d5aaa2bb2e..94467f1320 100644 --- a/sql/plan/process.go +++ b/sql/plan/process.go @@ -380,9 +380,9 @@ func (i *trackedRowIter) Close(ctx *sql.Context) error { func (i *trackedRowIter) updateSessionVars(ctx *sql.Context) { switch i.QueryType { case QueryTypeSelect: - ctx.SetLastQueryInfo(sql.RowCount, -1) + ctx.SetLastQueryInfoInt(sql.RowCount, -1) case QueryTypeDdl: - ctx.SetLastQueryInfo(sql.RowCount, 0) + ctx.SetLastQueryInfoInt(sql.RowCount, 0) case QueryTypeUpdate: // This is handled by RowUpdateAccumulator default: @@ -390,7 +390,7 @@ func (i *trackedRowIter) updateSessionVars(ctx *sql.Context) { } if i.ShouldSetFoundRows { - ctx.SetLastQueryInfo(sql.FoundRows, i.numRows) + ctx.SetLastQueryInfoInt(sql.FoundRows, i.numRows) } } diff --git a/sql/rowexec/dml_iters.go b/sql/rowexec/dml_iters.go index 19c974e1e4..b1de809c9e 100644 --- a/sql/rowexec/dml_iters.go +++ b/sql/rowexec/dml_iters.go @@ -494,9 +494,9 @@ func (a *accumulatorIter) Next(ctx *sql.Context) (r sql.Row, err error) { return nil, io.EOF } - oldLastInsertId := ctx.Session.GetLastQueryInfo(sql.LastInsertId) + oldLastInsertId := ctx.Session.GetLastQueryInfoInt(sql.LastInsertId) if oldLastInsertId != 0 { - ctx.Session.SetLastQueryInfo(sql.LastInsertId, -1) + ctx.Session.SetLastQueryInfoInt(sql.LastInsertId, -1) } // We close our child iterator before returning any results. In @@ -526,12 +526,12 @@ func (a *accumulatorIter) Next(ctx *sql.Context) (r sql.Row, err error) { // UPDATE statements also set FoundRows to the number of rows that // matched the WHERE clause, same as a SELECT. if ma, ok := a.updateRowHandler.(matchingAccumulator); ok { - ctx.SetLastQueryInfo(sql.FoundRows, ma.RowsMatched()) + ctx.SetLastQueryInfoInt(sql.FoundRows, ma.RowsMatched()) } - newLastInsertId := ctx.Session.GetLastQueryInfo(sql.LastInsertId) + newLastInsertId := ctx.Session.GetLastQueryInfoInt(sql.LastInsertId) if newLastInsertId == -1 { - ctx.Session.SetLastQueryInfo(sql.LastInsertId, oldLastInsertId) + ctx.Session.SetLastQueryInfoInt(sql.LastInsertId, oldLastInsertId) } res := a.updateRowHandler.okResult() // TODO: Should add warnings here @@ -545,7 +545,7 @@ func (a *accumulatorIter) Next(ctx *sql.Context) (r sql.Row, err error) { } // By definition, ROW_COUNT() is equal to RowsAffected. - ctx.SetLastQueryInfo(sql.RowCount, int64(res.RowsAffected)) + ctx.SetLastQueryInfoInt(sql.RowCount, int64(res.RowsAffected)) return sql.NewRow(res), nil } else if isIg { diff --git a/sql/rowexec/insert.go b/sql/rowexec/insert.go index 700bb0b4b2..f2ad586f1d 100644 --- a/sql/rowexec/insert.go +++ b/sql/rowexec/insert.go @@ -291,23 +291,20 @@ func (i *insertIter) updateLastInsertId(ctx *sql.Context, row sql.Row) { return } - autoIncVal := i.getAutoIncVal(row) - if i.hasAutoAutoIncValue { - ctx.SetLastQueryInfo(sql.LastInsertId, autoIncVal) + autoIncVal := i.getAutoIncVal(row) + ctx.SetLastQueryInfoInt(sql.LastInsertId, autoIncVal) i.lastInsertIdUpdated = true } } func (i *insertIter) getAutoIncVal(row sql.Row) int64 { - var autoIncVal int64 for i, expr := range i.insertExprs { if _, ok := expr.(*expression.AutoIncrement); ok { - autoIncVal = toInt64(row[i]) - break + return toInt64(row[i]) } } - return autoIncVal + return 0 } func (i *insertIter) ignoreOrClose(ctx *sql.Context, row sql.Row, err error) error { diff --git a/sql/rowexec/rel_iters.go b/sql/rowexec/rel_iters.go index c23330a426..6739d09735 100644 --- a/sql/rowexec/rel_iters.go +++ b/sql/rowexec/rel_iters.go @@ -72,7 +72,7 @@ func (i *topRowsIter) Close(ctx *sql.Context) error { i.topRows = nil if i.calcFoundRows { - ctx.SetLastQueryInfo(sql.FoundRows, i.numFoundRows) + ctx.SetLastQueryInfoInt(sql.FoundRows, i.numFoundRows) } return i.childIter.Close(ctx) @@ -934,7 +934,7 @@ func (li *limitIter) Close(ctx *sql.Context) error { } if li.calcFoundRows { - ctx.SetLastQueryInfo(sql.FoundRows, li.currentPos) + ctx.SetLastQueryInfoInt(sql.FoundRows, li.currentPos) } return nil } diff --git a/sql/session.go b/sql/session.go index 172f6fca05..78a4dbac4b 100644 --- a/sql/session.go +++ b/sql/session.go @@ -100,10 +100,14 @@ type Session interface { DelLock(lockName string) error // IterLocks iterates through all locks owned by this user IterLocks(cb func(name string) error) error - // SetLastQueryInfo sets session-level query info for the key given, applying to the query just executed. - SetLastQueryInfo(key string, value int64) - // GetLastQueryInfo returns the session-level query info for the key given, for the query most recently executed. - GetLastQueryInfo(key string) int64 + // SetLastQueryInfoInt sets session-level query info for the key given, applying to the query just executed. + SetLastQueryInfoInt(key string, value int64) + // GetLastQueryInfoInt returns the session-level query info for the key given, for the query most recently executed. + GetLastQueryInfoInt(key string) int64 + // SetLastQueryInfoString sets session-level query info as a string for the key given, applying to the query just executed. + SetLastQueryInfoString(key string, value string) + // GetLastQueryInfoString returns the session-level query info as a string for the key given, for the query most recently executed. + GetLastQueryInfoString(key string) string // GetTransaction returns the active transaction, if any GetTransaction() Transaction // SetTransaction sets the session's transaction @@ -199,9 +203,10 @@ type ( ) const ( - RowCount = "row_count" - FoundRows = "found_rows" - LastInsertId = "last_insert_id" + RowCount = "row_count" + FoundRows = "found_rows" + LastInsertId = "last_insert_id" + LastInsertUuid = "last_insert_uuid" ) // Session ID 0 used as invalid SessionID @@ -626,11 +631,12 @@ func (i *spanIter) Close(ctx *Context) error { return i.iter.Close(ctx) } -func defaultLastQueryInfo() map[string]int64 { - return map[string]int64{ - RowCount: 0, - FoundRows: 1, // this is kind of a hack -- it handles the case of `select found_rows()` before any select statement is issued - LastInsertId: 0, +func defaultLastQueryInfo() map[string]any { + return map[string]any{ + RowCount: int64(0), + FoundRows: int64(1), // this is kind of a hack -- it handles the case of `select found_rows()` before any select statement is issued + LastInsertId: int64(0), + LastInsertUuid: "", } }