Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 58 additions & 2 deletions enginetest/queries/script_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -1202,11 +1202,11 @@ CREATE TABLE tab3 (
Assertions: []ScriptTestAssertion{
{
Query: `SELECT IS_UUID(UUID())`,
Expected: []sql.Row{{int8(1)}},
Expected: []sql.Row{{true}},
},
{
Query: `SELECT IS_UUID(@uuid)`,
Expected: []sql.Row{{int8(1)}},
Expected: []sql.Row{{true}},
},
{
Query: `SELECT BIN_TO_UUID(UUID_TO_BIN(@uuid))`,
Expand Down Expand Up @@ -1354,6 +1354,62 @@ CREATE TABLE tab3 (
},
},
},
{
Name: "@@last_generated_uuid behavior",
SetUpScript: []string{
"create table a (x int primary key auto_increment, y varchar(100) default (UUID()))",
},
Assertions: []ScriptTestAssertion{
{
Query: "select @@last_generated_uuid",
Expected: []sql.Row{{nil}},
},
{
Query: "set @first_uuid = UUID();",
Expected: []sql.Row{{}},
},
{
Query: "select is_uuid(@@last_generated_uuid), @first_uuid = @@last_generated_uuid",
Expected: []sql.Row{{true, true}},
},
{
// Test an insert with an explicit call to UUID()
Query: "insert into a (x,y) values (1, UUID())",
Expected: []sql.Row{{types.OkResult{RowsAffected: 1, InsertID: 1}}},
},
{
Query: "select @@last_generated_uuid IS NOT NULL, is_uuid(@@last_generated_uuid), @@last_generated_uuid = (select y from a where x=1);",
Expected: []sql.Row{{true, true, true}},
},
{
// When UUID() is used in a nested expression, it still updates @@last_generated_uuid
Query: "select concat('foo-', UUID(), '-bar');",
SkipResultsCheck: true,
},
{
Query: "select @@last_generated_uuid IS NOT NULL, is_uuid(@@last_generated_uuid), @@last_generated_uuid != (select y from a where x=1);",
Expected: []sql.Row{{true, true, true}},
},
{
// @@last_generated_uuid should hold the value of the last row that implicitly used UUID() via the column default
Query: "insert into a (x) values (3), (4)",
Expected: []sql.Row{{types.OkResult{RowsAffected: 2, InsertID: 3}}},
},
{
Query: "select @@last_generated_uuid IS NOT NULL, is_uuid(@@last_generated_uuid), @@last_generated_uuid = (select y from a where x=4);",
Expected: []sql.Row{{true, true, true}},
},
{
Query: "insert into a values (5, 'five'), (6, 'six')",
Expected: []sql.Row{{types.OkResult{RowsAffected: 2, InsertID: 5}}},
},
{
// The above query doesn't invoke the UUID explicitly (or implicitly through a default), so last_generated_uuid is unchanged
Query: "select @@last_generated_uuid IS NOT NULL, is_uuid(@@last_generated_uuid), @@last_generated_uuid = (select y from a where x=4);",
Expected: []sql.Row{{true, true, true}},
},
},
},
{
Name: "last_insert_id() behavior",
SetUpScript: []string{
Expand Down
32 changes: 20 additions & 12 deletions sql/expression/function/uuid.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,6 @@ import (

type UUIDFunc struct{}

func (u UUIDFunc) IsNonDeterministic() bool {
return true
}

var _ sql.FunctionExpression = &UUIDFunc{}
var _ sql.CollationCoercible = &UUIDFunc{}

Expand All @@ -84,7 +80,14 @@ func (UUIDFunc) CollationCoercibility(ctx *sql.Context) (collation sql.Collation
}

func (u UUIDFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
return uuid.New().String(), nil
uuid := uuid.New().String()

err := ctx.SetSessionVariable(ctx, "last_generated_uuid", uuid)
if err != nil {
return nil, err
}

return uuid, nil
}

func (u UUIDFunc) WithChildren(children ...sql.Expression) (sql.Expression, error) {
Expand Down Expand Up @@ -113,6 +116,11 @@ func (u UUIDFunc) IsNullable() bool {
return false
}

// IsNonDeterministic implements the sql.NonDeterministicExpression interface
func (u UUIDFunc) IsNonDeterministic() bool {
return true
}

// IS_UUID(string_uuid)
//
// Returns 1 if the argument is a valid string-format UUID, 0 if the argument is not a valid UUID, and NULL if the
Expand Down Expand Up @@ -147,7 +155,7 @@ func (u IsUUID) String() string {
}

func (u IsUUID) Type() sql.Type {
return types.Int8
return types.Boolean
}

// CollationCoercibility implements the interface sql.CollationCoercible.
Expand All @@ -158,7 +166,7 @@ func (IsUUID) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID
func (u IsUUID) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
str, err := u.child.Eval(ctx, row)
if err != nil {
return 0, err
return false, err
}

if str == nil {
Expand All @@ -169,19 +177,19 @@ func (u IsUUID) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
case string:
_, err := uuid.Parse(str)
if err != nil {
return int8(0), nil
return false, nil
}

return int8(1), nil
return true, nil
case []byte:
_, err := uuid.ParseBytes(str)
if err != nil {
return int8(0), nil
return false, nil
}

return int8(1), nil
return true, nil
default:
return int8(0), nil
return false, nil
}
}

Expand Down
26 changes: 18 additions & 8 deletions sql/expression/function/uuid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,26 @@ func TestUUID(t *testing.T) {
// Generate a UUID and validate that is a legitimate uuid
uuidE := NewUUIDFunc()

// Assert that @@last_generated_uuid is initialized to NULL
lastGeneratedUuid, err := ctx.GetSessionVariable(ctx, "last_generated_uuid")
require.NoError(t, err)
require.Nil(t, lastGeneratedUuid)

result, err := uuidE.Eval(ctx, sql.Row{nil})
require.NoError(t, err)

myUUID := result.(string)
_, err = uuid.Parse(myUUID)
require.NoError(t, err)

// Assert that @@last_generated_uuid has been set
lastGeneratedUuid, err = ctx.GetSessionVariable(ctx, "last_generated_uuid")
require.NoError(t, err)
require.Equal(t, myUUID, lastGeneratedUuid)

// validate that generated uuid is legitimate for IsUUID
val := NewIsUUID(uuidE)
require.Equal(t, int8(1), eval(t, val, sql.Row{nil}))
require.Equal(t, true, eval(t, val, sql.Row{nil}))

// Use a UUID regex as a sanity check
re2 := regexp.MustCompile(`\b[0-9a-f]{8}\b-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-\b[0-9a-f]{12}\b`)
Expand All @@ -55,14 +65,14 @@ func TestIsUUID(t *testing.T) {
value interface{}
expected interface{}
}{
{"uuid form 1", types.LongText, "{12345678-1234-5678-1234-567812345678}", int8(1)},
{"uuid form 2", types.LongText, "12345678123456781234567812345678", int8(1)},
{"uuid form 3", types.LongText, "12345678-1234-5678-1234-567812345678", int8(1)},
{"uuid form 1", types.LongText, "{12345678-1234-5678-1234-567812345678}", true},
{"uuid form 2", types.LongText, "12345678123456781234567812345678", true},
{"uuid form 3", types.LongText, "12345678-1234-5678-1234-567812345678", true},
{"NULL", types.Null, nil, nil},
{"random int", types.Int8, 1, int8(0)},
{"random bool", types.Boolean, false, int8(0)},
{"random string", types.LongText, "12345678-dasd-fasdf8", int8(0)},
{"swapped uuid", types.LongText, "5678-1234-12345678-1234-567812345678", int8(0)},
{"random int", types.Int8, 1, false},
{"random bool", types.Boolean, false, false},
{"random string", types.LongText, "12345678-dasd-fasdf8", false},
{"swapped uuid", types.LongText, "5678-1234-12345678-1234-567812345678", false},
}

for _, tt := range testCases {
Expand Down
9 changes: 9 additions & 0 deletions sql/variables/system_variables.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"sync"
"time"

"github.com/dolthub/vitess/go/sqltypes"
"github.com/google/uuid"
"github.com/sirupsen/logrus"

Expand Down Expand Up @@ -1152,6 +1153,14 @@ var systemVars = map[string]sql.SystemVariable{
Type: types.NewSystemIntType("large_page_size", -9223372036854775808, 9223372036854775807, false),
Default: int64(0),
},
"last_generated_uuid": {
Name: "last_generated_uuid",
Scope: sql.SystemVariableScope_Session,
Dynamic: true,
SetVarHintApplies: false,
Type: types.MustCreateStringWithDefaults(sqltypes.VarChar, 36),
Default: nil,
},
"last_insert_id": {
Name: "last_insert_id",
Scope: sql.SystemVariableScope_Session,
Expand Down