From 801692ce74b310473fdf86a8034f3d2614ca8b81 Mon Sep 17 00:00:00 2001 From: Daylon Wilkins Date: Tue, 13 Feb 2024 06:19:13 -0800 Subject: [PATCH] Added UUID type --- postgres/parser/uuid/codec.go | 13 ++ server/ast/resolvable_type_reference.go | 2 + server/types/serialization.go | 5 + server/types/uuid.go | 193 ++++++++++++++++++++++++ testing/go/types_test.go | 3 +- 5 files changed, 214 insertions(+), 2 deletions(-) create mode 100644 server/types/uuid.go diff --git a/postgres/parser/uuid/codec.go b/postgres/parser/uuid/codec.go index f01334bd7b..152153c40a 100644 --- a/postgres/parser/uuid/codec.go +++ b/postgres/parser/uuid/codec.go @@ -122,6 +122,19 @@ func (u *UUID) UnmarshalText(text []byte) error { case 41, 45: return u.decodeURN(text) default: + // Postgres allows for an extended dash placement scheme. We'll eventually support those explicitly, but for + // now, we'll just "support" them all by removing the dashes so that the UUID is in a dash-less state. + shortenedText := bytes.ReplaceAll(text, []byte{'-'}, nil) + switch len(shortenedText) { + case 32: + return u.decodeHashLike(shortenedText) + case 34, 38: + return u.decodeBraced(shortenedText) + case 36: + return u.decodeCanonical(shortenedText) + case 41, 45: + return u.decodeURN(shortenedText) + } return fmt.Errorf("uuid: incorrect UUID length: %s", text) } } diff --git a/server/ast/resolvable_type_reference.go b/server/ast/resolvable_type_reference.go index b48a63fd2b..8caf3c51f3 100755 --- a/server/ast/resolvable_type_reference.go +++ b/server/ast/resolvable_type_reference.go @@ -66,6 +66,8 @@ func nodeResolvableTypeReference(typ tree.ResolvableTypeReference) (*vitess.Conv columnTypeLength = vitess.NewIntVal([]byte(strconv.Itoa(int(columnType.Width())))) case types.TimestampFamily: columnTypeName = columnType.Name() + case types.UuidFamily: + resolvedType = pgtypes.Uuid } } diff --git a/server/types/serialization.go b/server/types/serialization.go index fa8fc71ee9..a007b87557 100644 --- a/server/types/serialization.go +++ b/server/types/serialization.go @@ -25,6 +25,7 @@ type serializationID byte const ( serializationID_Bool serializationID = 1 serializationID_BoolArray serializationID = 2 + serializationID_Uuid serializationID = 3 ) // init sets the serialization and deserialization functions. @@ -40,6 +41,8 @@ func SerializeType(extendedType types.ExtendedType) ([]byte, error) { return []byte{byte(serializationID_Bool)}, nil case BoolArrayType: return []byte{byte(serializationID_BoolArray)}, nil + case UuidType: + return []byte{byte(serializationID_Uuid)}, nil default: return nil, fmt.Errorf("unknown type to serialize") } @@ -66,6 +69,8 @@ func DeserializeType(serializedType []byte) (types.ExtendedType, error) { return Bool, nil case serializationID_BoolArray: return BoolArray, nil + case serializationID_Uuid: + return Uuid, nil default: return nil, fmt.Errorf("unknown type to deserialize") } diff --git a/server/types/uuid.go b/server/types/uuid.go new file mode 100644 index 0000000000..f5fd13809b --- /dev/null +++ b/server/types/uuid.go @@ -0,0 +1,193 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "bytes" + "reflect" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" + "github.com/dolthub/vitess/go/sqltypes" + "github.com/dolthub/vitess/go/vt/proto/query" + + "github.com/dolthub/doltgresql/postgres/parser/uuid" +) + +// Uuid is the UUID type. +var Uuid = UuidType{} + +// UuidType is the extended type implementation of the PostgreSQL UUID. +type UuidType struct{} + +var _ types.ExtendedType = UuidType{} + +// CollationCoercibility implements the types.ExtendedType interface. +func (b UuidType) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { + return sql.Collation_binary, 5 +} + +// Compare implements the types.ExtendedType interface. +func (b UuidType) Compare(v1 any, v2 any) (int, error) { + if v1 == nil && v2 == nil { + return 0, nil + } else if v1 != nil && v2 == nil { + return 1, nil + } else if v1 == nil && v2 != nil { + return -1, nil + } + + ac, _, err := b.Convert(v1) + if err != nil { + return 0, err + } + bc, _, err := b.Convert(v2) + if err != nil { + return 0, err + } + + ab := ac.(uuid.UUID) + bb := bc.(uuid.UUID) + return bytes.Compare(ab.GetBytesMut(), bb.GetBytesMut()), nil +} + +// Convert implements the types.ExtendedType interface. +func (b UuidType) Convert(val any) (any, sql.ConvertInRange, error) { + if val == nil { + return nil, sql.InRange, nil + } + + switch val := val.(type) { + case string: + uuidVal, err := uuid.FromString(val) + if err != nil { + return nil, sql.OutOfRange, err + } + return uuidVal, sql.InRange, nil + case uuid.UUID: + return val, sql.InRange, nil + default: + return nil, sql.OutOfRange, sql.ErrInvalidType.New(b) + } +} + +// Equals implements the types.ExtendedType interface. +func (b UuidType) Equals(otherType sql.Type) bool { + if otherExtendedType, ok := otherType.(types.ExtendedType); ok { + return bytes.Equal(MustSerializeType(b), MustSerializeType(otherExtendedType)) + } + return false +} + +// FormatSerializedValue implements the types.ExtendedType interface. +func (b UuidType) FormatSerializedValue(val []byte) (string, error) { + deserialized, err := b.DeserializeValue(val) + if err != nil { + return "", err + } + return b.FormatValue(deserialized) +} + +// FormatValue implements the types.ExtendedType interface. +func (b UuidType) FormatValue(val any) (string, error) { + if val == nil { + return "", nil + } + converted, _, err := b.Convert(val) + if err != nil { + return "", err + } + return converted.(uuid.UUID).String(), nil +} + +// MaxSerializedWidth implements the types.ExtendedType interface. +func (b UuidType) MaxSerializedWidth() types.ExtendedTypeSerializedWidth { + return types.ExtendedTypeSerializedWidth_64K +} + +// MaxTextResponseByteLength implements the types.ExtendedType interface. +func (b UuidType) MaxTextResponseByteLength(ctx *sql.Context) uint32 { + return 16 +} + +// Promote implements the types.ExtendedType interface. +func (b UuidType) Promote() sql.Type { + return b +} + +// SerializedCompare implements the types.ExtendedType interface. +func (b UuidType) SerializedCompare(v1 []byte, v2 []byte) (int, error) { + if len(v1) == 0 && len(v2) == 0 { + return 0, nil + } else if len(v1) > 0 && len(v2) == 0 { + return 1, nil + } else if len(v1) == 0 && len(v2) > 0 { + return -1, nil + } + + return bytes.Compare(v1, v2), nil +} + +// SQL implements the types.ExtendedType interface. +func (b UuidType) SQL(ctx *sql.Context, dest []byte, v any) (sqltypes.Value, error) { + if v == nil { + return sqltypes.NULL, nil + } + value, _, err := b.Convert(v) + if err != nil { + return sqltypes.Value{}, err + } + return sqltypes.MakeTrusted(b.Type(), types.AppendAndSliceBytes(dest, []byte(value.(uuid.UUID).String()))), nil +} + +// String implements the types.ExtendedType interface. +func (b UuidType) String() string { + return "uuid" +} + +// Type implements the types.ExtendedType interface. +func (b UuidType) Type() query.Type { + return sqltypes.Text +} + +// ValueType implements the types.ExtendedType interface. +func (b UuidType) ValueType() reflect.Type { + return reflect.TypeOf(uuid.UUID{}) +} + +// Zero implements the types.ExtendedType interface. +func (b UuidType) Zero() any { + return uuid.UUID{} +} + +// SerializeValue implements the types.ExtendedType interface. +func (b UuidType) SerializeValue(val any) ([]byte, error) { + if val == nil { + return nil, nil + } + converted, _, err := b.Convert(val) + if err != nil { + return nil, err + } + return converted.(uuid.UUID).GetBytes(), nil +} + +// DeserializeValue implements the types.ExtendedType interface. +func (b UuidType) DeserializeValue(val []byte) (any, error) { + if len(val) == 0 { + return nil, nil + } + return uuid.FromBytes(val) +} diff --git a/testing/go/types_test.go b/testing/go/types_test.go index 0ad17fbe03..2a93ac7801 100644 --- a/testing/go/types_test.go +++ b/testing/go/types_test.go @@ -760,10 +760,9 @@ var typesTests = []ScriptTest{ }, { Name: "Uuid type", - Skip: true, SetUpScript: []string{ "CREATE TABLE t_uuid (id INTEGER primary key, v1 UUID);", - "INSERT INTO t_uuid VALUES (1, 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'), (2, 'f47ac10b-58cc-4372-a567-0e02b2c3d479');", + "INSERT INTO t_uuid VALUES (1, 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'), (2, 'f47ac10b58cc4372a567-0e02b2c3d479');", }, Assertions: []ScriptTestAssertion{ {