From 584054b6765eb00a4f61cedb5162fbdeb3bcf980 Mon Sep 17 00:00:00 2001 From: Jason Fulghum Date: Thu, 29 Feb 2024 13:11:53 -0800 Subject: [PATCH 1/2] Adding a new @@last_generated_uuid system variable that is updated whenever the UUID() function is executed --- enginetest/queries/script_queries.go | 60 +++++++++++++++++++++++++++- sql/expression/function/uuid.go | 32 +++++++++------ sql/expression/function/uuid_test.go | 26 ++++++++---- sql/variables/system_variables.go | 9 +++++ 4 files changed, 105 insertions(+), 22 deletions(-) diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index d9bd4704d3..9d74bb2e4c 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -1202,11 +1202,11 @@ CREATE TABLE tab3 ( Assertions: []ScriptTestAssertion{ { Query: `SELECT IS_UUID(UUID())`, - Expected: []sql.Row{{int8(1)}}, + Expected: []sql.Row{{true}}, }, { Query: `SELECT IS_UUID(@uuid)`, - Expected: []sql.Row{{int8(1)}}, + Expected: []sql.Row{{true}}, }, { Query: `SELECT BIN_TO_UUID(UUID_TO_BIN(@uuid))`, @@ -1354,6 +1354,62 @@ CREATE TABLE tab3 ( }, }, }, + { + Name: "@@last_generated_uuid behavior", + SetUpScript: []string{ + "create table a (x int primary key auto_increment, y varchar(100) default (UUID()))", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select @@last_generated_uuid", + Expected: []sql.Row{{nil}}, + }, + { + Query: "set @first_uuid = UUID();", + Expected: []sql.Row{{}}, + }, + { + Query: "select is_uuid(@@last_generated_uuid), @first_uuid = @@last_generated_uuid", + Expected: []sql.Row{{true, true}}, + }, + { + // Test an insert with an explicit call to UUID() + Query: "insert into a (x,y) values (1, UUID())", + Expected: []sql.Row{{types.OkResult{RowsAffected: 1, InsertID: 1}}}, + }, + { + Query: "select @@last_generated_uuid IS NOT NULL, is_uuid(@@last_generated_uuid), @@last_generated_uuid = (select y from a where x=1);", + Expected: []sql.Row{{true, true, true}}, + }, + { + // When UUID() is used in a nested expression, it still updates @@last_generated_uuid + Query: "select concat('foo-', UUID(), '-bar');", + SkipResultsCheck: true, + }, + { + Query: "select @@last_generated_uuid IS NOT NULL, is_uuid(@@last_generated_uuid), @@last_generated_uuid != (select y from a where x=1);", + Expected: []sql.Row{{true, true, true}}, + }, + { + // @@last_generated_uuid should hold the value of the last row that implicitly used UUID() via the column default + Query: "insert into a (x) values (3), (4)", + Expected: []sql.Row{{types.OkResult{RowsAffected: 2, InsertID: 3}}}, + }, + { + Query: "select @@last_generated_uuid IS NOT NULL, is_uuid(@@last_generated_uuid), @@last_generated_uuid = (select y from a where x=4);", + Expected: []sql.Row{{true, true, true}}, + }, + { + Query: "insert into a values (5, 'five'), (6, 'six')", + Expected: []sql.Row{{types.OkResult{RowsAffected: 2, InsertID: 5}}}, + }, + { + // The above query doesn't invoke the UUID explicitly (or implicitly through a default), so last_generated_uuid is unchanged + Query: "select @@last_generated_uuid IS NOT NULL, is_uuid(@@last_generated_uuid), @@last_generated_uuid = (select y from a where x=4);", + Expected: []sql.Row{{true, true, true}}, + }, + }, + }, { Name: "last_insert_id() behavior", SetUpScript: []string{ diff --git a/sql/expression/function/uuid.go b/sql/expression/function/uuid.go index 915141b339..9c1377636c 100644 --- a/sql/expression/function/uuid.go +++ b/sql/expression/function/uuid.go @@ -54,10 +54,6 @@ import ( type UUIDFunc struct{} -func (u UUIDFunc) IsNonDeterministic() bool { - return true -} - var _ sql.FunctionExpression = &UUIDFunc{} var _ sql.CollationCoercible = &UUIDFunc{} @@ -84,7 +80,14 @@ func (UUIDFunc) CollationCoercibility(ctx *sql.Context) (collation sql.Collation } func (u UUIDFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - return uuid.New().String(), nil + uuid := uuid.New().String() + + err := ctx.SetSessionVariable(ctx, "last_generated_uuid", uuid) + if err != nil { + return nil, err + } + + return uuid, nil } func (u UUIDFunc) WithChildren(children ...sql.Expression) (sql.Expression, error) { @@ -113,6 +116,11 @@ func (u UUIDFunc) IsNullable() bool { return false } +// IsNonDeterministic implements the sql.NonDeterministicExpression interface +func (u UUIDFunc) IsNonDeterministic() bool { + return true +} + // IS_UUID(string_uuid) // // Returns 1 if the argument is a valid string-format UUID, 0 if the argument is not a valid UUID, and NULL if the @@ -147,7 +155,7 @@ func (u IsUUID) String() string { } func (u IsUUID) Type() sql.Type { - return types.Int8 + return types.Boolean } // CollationCoercibility implements the interface sql.CollationCoercible. @@ -158,7 +166,7 @@ func (IsUUID) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID func (u IsUUID) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { str, err := u.child.Eval(ctx, row) if err != nil { - return 0, err + return false, err } if str == nil { @@ -169,19 +177,19 @@ func (u IsUUID) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { case string: _, err := uuid.Parse(str) if err != nil { - return int8(0), nil + return false, nil } - return int8(1), nil + return true, nil case []byte: _, err := uuid.ParseBytes(str) if err != nil { - return int8(0), nil + return false, nil } - return int8(1), nil + return true, nil default: - return int8(0), nil + return false, nil } } diff --git a/sql/expression/function/uuid_test.go b/sql/expression/function/uuid_test.go index 7f9da8aa6b..7ff5d36942 100644 --- a/sql/expression/function/uuid_test.go +++ b/sql/expression/function/uuid_test.go @@ -32,6 +32,11 @@ func TestUUID(t *testing.T) { // Generate a UUID and validate that is a legitimate uuid uuidE := NewUUIDFunc() + // Assert that @@last_generated_uuid is initialized to NULL + lastGeneratedUuid, err := ctx.GetSessionVariable(ctx, "last_generated_uuid") + require.NoError(t, err) + require.Nil(t, lastGeneratedUuid) + result, err := uuidE.Eval(ctx, sql.Row{nil}) require.NoError(t, err) @@ -39,9 +44,14 @@ func TestUUID(t *testing.T) { _, err = uuid.Parse(myUUID) require.NoError(t, err) + // Assert that @@last_generated_uuid has been set + lastGeneratedUuid, err = ctx.GetSessionVariable(ctx, "last_generated_uuid") + require.NoError(t, err) + require.Equal(t, myUUID, lastGeneratedUuid) + // validate that generated uuid is legitimate for IsUUID val := NewIsUUID(uuidE) - require.Equal(t, int8(1), eval(t, val, sql.Row{nil})) + require.Equal(t, true, eval(t, val, sql.Row{nil})) // Use a UUID regex as a sanity check re2 := regexp.MustCompile(`\b[0-9a-f]{8}\b-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-\b[0-9a-f]{12}\b`) @@ -55,14 +65,14 @@ func TestIsUUID(t *testing.T) { value interface{} expected interface{} }{ - {"uuid form 1", types.LongText, "{12345678-1234-5678-1234-567812345678}", int8(1)}, - {"uuid form 2", types.LongText, "12345678123456781234567812345678", int8(1)}, - {"uuid form 3", types.LongText, "12345678-1234-5678-1234-567812345678", int8(1)}, + {"uuid form 1", types.LongText, "{12345678-1234-5678-1234-567812345678}", true}, + {"uuid form 2", types.LongText, "12345678123456781234567812345678", true}, + {"uuid form 3", types.LongText, "12345678-1234-5678-1234-567812345678", true}, {"NULL", types.Null, nil, nil}, - {"random int", types.Int8, 1, int8(0)}, - {"random bool", types.Boolean, false, int8(0)}, - {"random string", types.LongText, "12345678-dasd-fasdf8", int8(0)}, - {"swapped uuid", types.LongText, "5678-1234-12345678-1234-567812345678", int8(0)}, + {"random int", types.Int8, 1, false}, + {"random bool", types.Boolean, false, false}, + {"random string", types.LongText, "12345678-dasd-fasdf8", false}, + {"swapped uuid", types.LongText, "5678-1234-12345678-1234-567812345678", false}, } for _, tt := range testCases { diff --git a/sql/variables/system_variables.go b/sql/variables/system_variables.go index 64f368de83..95c5d41ec0 100644 --- a/sql/variables/system_variables.go +++ b/sql/variables/system_variables.go @@ -27,6 +27,7 @@ import ( gmstime "github.com/dolthub/go-mysql-server/internal/time" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" + "github.com/dolthub/vitess/go/sqltypes" ) // TODO: Add from the following sources because MySQL likes to not have every variable on a single page: @@ -1152,6 +1153,14 @@ var systemVars = map[string]sql.SystemVariable{ Type: types.NewSystemIntType("large_page_size", -9223372036854775808, 9223372036854775807, false), Default: int64(0), }, + "last_generated_uuid": { + Name: "last_generated_uuid", + Scope: sql.SystemVariableScope_Session, + Dynamic: true, + SetVarHintApplies: false, + Type: types.MustCreateStringWithDefaults(sqltypes.VarChar, 36), + Default: nil, + }, "last_insert_id": { Name: "last_insert_id", Scope: sql.SystemVariableScope_Session, From 2c049ee44f931670255b9279b7b0e5022da8e673 Mon Sep 17 00:00:00 2001 From: fulghum Date: Thu, 29 Feb 2024 21:22:39 +0000 Subject: [PATCH 2/2] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/variables/system_variables.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/variables/system_variables.go b/sql/variables/system_variables.go index 95c5d41ec0..6dabacd53a 100644 --- a/sql/variables/system_variables.go +++ b/sql/variables/system_variables.go @@ -21,13 +21,13 @@ import ( "sync" "time" + "github.com/dolthub/vitess/go/sqltypes" "github.com/google/uuid" "github.com/sirupsen/logrus" gmstime "github.com/dolthub/go-mysql-server/internal/time" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" - "github.com/dolthub/vitess/go/sqltypes" ) // TODO: Add from the following sources because MySQL likes to not have every variable on a single page: