diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6577ac3581..6ecfe9dd79 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -76,6 +76,8 @@ These are automatically included when releasing a new version of DoltgreSQL, so 3. `scripts`: Contains all of the non-build scripts and tools. 4. `server`: The interface between the `postgres` directory and the underlying [GMS](https://github.com/dolthub/go-mysql-server) and [Dolt](https://github.com/dolthub/dolt) backends. 1. `ast`: Specifically houses all transformations from the `postgres` AST to the [`vitess`](https://github.com/dolthub/vitess) AST. + 2. `functions`: Contains all implementations of [PostgreSQL functions](https://www.postgresql.org/docs/15/functions.html). + 3. `types`: Contains the implementations of all PostgreSQL types. 5. `testing`: Contains all integration tests, and all things related to testing. This will not contain _all_ tests within the repository, as functions within other directories may declare their own unit tests. 1. `bats`: Contains all of our [Bats](https://github.com/bats-core/bats-core) tests. diff --git a/go.mod b/go.mod index f396154577..b46c93b9c4 100644 --- a/go.mod +++ b/go.mod @@ -6,11 +6,11 @@ require ( github.com/PuerkitoBio/goquery v1.8.1 github.com/cockroachdb/apd/v2 v2.0.3-0.20200518165714-d020e156310a github.com/cockroachdb/errors v1.7.5 - github.com/dolthub/dolt/go v0.40.5-0.20240118214900-3cbb73cafa3c - github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20231213233028-64c353bf920f - github.com/dolthub/go-mysql-server v0.17.1-0.20240118213933-3c0fb56900df + github.com/dolthub/dolt/go v0.40.5-0.20240207142355-cbe6b0ce7f01 + github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20240205125942-fc7c3429f29c + github.com/dolthub/go-mysql-server v0.17.1-0.20240207124505-c0f397a6aaca github.com/dolthub/sqllogictest/go v0.0.0-20240118211725-a52e3f5697e3 - github.com/dolthub/vitess v0.0.0-20240117231546-55b8c7b39462 + github.com/dolthub/vitess v0.0.0-20240207121055-c057d2347007 github.com/fatih/color v1.13.0 github.com/gogo/protobuf v1.3.2 github.com/golang/geo v0.0.0-20200730024412-e86565bf3f35 @@ -61,7 +61,7 @@ require ( github.com/dolthub/fslock v0.0.3 // indirect github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e // indirect github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488 // indirect - github.com/dolthub/jsonpath v0.0.2-0.20230525180605-8dc13778fd72 // indirect + github.com/dolthub/jsonpath v0.0.2-0.20240201003050-392940944c15 // indirect github.com/dolthub/maphash v0.0.0-20221220182448-74e1e1ea1577 // indirect github.com/dolthub/swiss v0.1.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect diff --git a/go.sum b/go.sum index fb03697043..fd90dc921b 100644 --- a/go.sum +++ b/go.sum @@ -214,30 +214,30 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZm github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec= github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= -github.com/dolthub/dolt/go v0.40.5-0.20240118214900-3cbb73cafa3c h1:gEvvX3cUMEOW0UyIO3klXaohXzJmJ0ls0jZKAgTdWJE= -github.com/dolthub/dolt/go v0.40.5-0.20240118214900-3cbb73cafa3c/go.mod h1:n4qCXkCLlIFbR8PuXB0WG1JV5s8SZLK4sa/dEVx420o= -github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20231213233028-64c353bf920f h1:f250FTgZ/OaCql9G6WJt46l9VOIBF1mI81hW9cnmBNM= -github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20231213233028-64c353bf920f/go.mod h1:gHeHIDGU7em40EhFTliq62pExFcc1hxDTIZ9g5UqXYM= +github.com/dolthub/dolt/go v0.40.5-0.20240207142355-cbe6b0ce7f01 h1:r/1CAcbUpAHR0jG15YqMgoribkfOURNqH60M8lITN2M= +github.com/dolthub/dolt/go v0.40.5-0.20240207142355-cbe6b0ce7f01/go.mod h1:F/oS2i85PyQgoKG4ay4joe/iv2HR9njLbyYaqJ2FSyU= +github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20240205125942-fc7c3429f29c h1:AVAqyKKv6UVOcKr9anIe1VQItJEIcZENMtU0FF8bycM= +github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20240205125942-fc7c3429f29c/go.mod h1:gHeHIDGU7em40EhFTliq62pExFcc1hxDTIZ9g5UqXYM= github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww= github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2/go.mod h1:mIEZOHnFx4ZMQeawhw9rhsj+0zwQj7adVsnBX7t+eKY= github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e h1:kPsT4a47cw1+y/N5SSCkma7FhAPw7KeGmD6c9PBZW9Y= github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e/go.mod h1:KPUcpx070QOfJK1gNe0zx4pA5sicIK1GMikIGLKC168= -github.com/dolthub/go-mysql-server v0.17.1-0.20240118213933-3c0fb56900df h1:OmR6U3UvCMEguh1UaXCiK4qasA/tHH3+Ls2NRiEQfjU= -github.com/dolthub/go-mysql-server v0.17.1-0.20240118213933-3c0fb56900df/go.mod h1:hS8Snuzg+nyTDjv4NI9jiXQ2lJJOd3O0ylhVPQlHySw= +github.com/dolthub/go-mysql-server v0.17.1-0.20240207124505-c0f397a6aaca h1:tI3X4fIUTOT0N8n+GYkPNa384WlJoOBcztK5c5mBzjU= +github.com/dolthub/go-mysql-server v0.17.1-0.20240207124505-c0f397a6aaca/go.mod h1:ANK0a6tyjrZ2cOzDJT3nFsDp80xksI4UfeijFlvnjwE= github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488 h1:0HHu0GWJH0N6a6keStrHhUAK5/o9LVfkh44pvsV4514= github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488/go.mod h1:ehexgi1mPxRTk0Mok/pADALuHbvATulTh6gzr7NzZto= -github.com/dolthub/jsonpath v0.0.2-0.20230525180605-8dc13778fd72 h1:NfWmngMi1CYUWU4Ix8wM+USEhjc+mhPlT9JUR/anvbQ= -github.com/dolthub/jsonpath v0.0.2-0.20230525180605-8dc13778fd72/go.mod h1:ZWUdY4iszqRQ8OcoXClkxiAVAoWoK3cq0Hvv4ddGRuM= +github.com/dolthub/jsonpath v0.0.2-0.20240201003050-392940944c15 h1:sfTETOpsrNJPDn2KydiCtDgVu6Xopq8k3JP8PjFT22s= +github.com/dolthub/jsonpath v0.0.2-0.20240201003050-392940944c15/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI= github.com/dolthub/maphash v0.0.0-20221220182448-74e1e1ea1577 h1:SegEguMxToBn045KRHLIUlF2/jR7Y2qD6fF+3tdOfvI= github.com/dolthub/maphash v0.0.0-20221220182448-74e1e1ea1577/go.mod h1:gkg4Ch4CdCDu5h6PMriVLawB7koZ+5ijb9puGMV50a4= github.com/dolthub/sqllogictest/go v0.0.0-20240118211725-a52e3f5697e3 h1:+eDpuEJ9t8aag943P27VS5PFNhp5l+6NIJ/Rc3b164o= github.com/dolthub/sqllogictest/go v0.0.0-20240118211725-a52e3f5697e3/go.mod h1:e/FIZVvT2IR53HBCAo41NjqgtEnjMJGKca3Y/dAmZaA= github.com/dolthub/swiss v0.1.0 h1:EaGQct3AqeP/MjASHLiH6i4TAmgbG/c4rA6a1bzCOPc= github.com/dolthub/swiss v0.1.0/go.mod h1:BeucyB08Vb1G9tumVN3Vp/pyY4AMUnr9p7Rz7wJ7kAQ= -github.com/dolthub/vitess v0.0.0-20240117231546-55b8c7b39462 h1:So1KO202cb047yWg5X27xRso6tkSYmU0Yu96JIVsaEU= -github.com/dolthub/vitess v0.0.0-20240117231546-55b8c7b39462/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw= +github.com/dolthub/vitess v0.0.0-20240207121055-c057d2347007 h1:MvFoe0FnHhxQLyp4Ldw0HRj1yu83YErbtbr7XxhaIFk= +github.com/dolthub/vitess v0.0.0-20240207121055-c057d2347007/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= diff --git a/server/ast/column_table_def.go b/server/ast/column_table_def.go index be984d8dd9..8b2c75aa85 100644 --- a/server/ast/column_table_def.go +++ b/server/ast/column_table_def.go @@ -39,7 +39,7 @@ func nodeColumnTableDef(node *tree.ColumnTableDef) (_ *vitess.ColumnDefinition, return nil, fmt.Errorf("FAMILY is not yet supported") } - convertType, err := nodeResolvableTypeReference(node.Type) + convertType, resolvedType, err := nodeResolvableTypeReference(node.Type) if err != nil { return nil, err } @@ -105,6 +105,7 @@ func nodeColumnTableDef(node *tree.ColumnTableDef) (_ *vitess.ColumnDefinition, Name: vitess.NewColIdent(string(node.Name)), Type: vitess.ColumnType{ Type: convertType.Type, + ResolvedType: resolvedType, Null: isNull, NotNull: isNotNull, Autoincrement: vitess.BoolVal(node.IsSerial), diff --git a/server/ast/expr.go b/server/ast/expr.go index 8bf9485c12..681a22b2c3 100644 --- a/server/ast/expr.go +++ b/server/ast/expr.go @@ -19,10 +19,13 @@ import ( "go/constant" "strings" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" vitess "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/dolthub/doltgresql/postgres/parser/sem/tree" + "github.com/dolthub/doltgresql/postgres/parser/types" + pgtypes "github.com/dolthub/doltgresql/server/types" ) // nodeExprs handles tree.Exprs nodes. @@ -96,7 +99,26 @@ func nodeExpr(node tree.Expr) (vitess.Expr, error) { case *tree.AnnotateTypeExpr: return nil, fmt.Errorf("ANNOTATE_TYPE is not yet supported") case *tree.Array: - return nil, fmt.Errorf("arrays are not yet supported") + //TODO: right now, this only works with boolean array values for the sake of demonstration + var gmsExpr sql.Expression + if len(node.Exprs) == 0 { + if node.ResolvedType().Family() == types.ArrayFamily && node.ResolvedType().ArrayContents().Family() == types.BoolFamily { + gmsExpr = expression.NewLiteral([]bool{}, pgtypes.BoolArray) + } else { + return nil, fmt.Errorf("arrays are generally not yet supported") + } + } else { + vals := make([]bool, len(node.Exprs)) + for i, arrayExpr := range node.Exprs { + if arrayVal, ok := arrayExpr.(*tree.DBool); ok && arrayVal != nil { + vals[i] = bool(*arrayVal) + } else { + return nil, fmt.Errorf("array value is not yet supported") + } + } + gmsExpr = expression.NewLiteral(vals, pgtypes.BoolArray) + } + return vitess.InjectedExpr{Expression: gmsExpr}, nil case *tree.ArrayFlatten: return nil, fmt.Errorf("flattening arrays is not yet supported") case *tree.BinaryExpr: @@ -199,7 +221,7 @@ func nodeExpr(node tree.Expr) (vitess.Expr, error) { return nil, fmt.Errorf("unknown cast syntax") } - convertType, err := nodeResolvableTypeReference(node.Type) + convertType, _, err := nodeResolvableTypeReference(node.Type) if err != nil { return nil, err } diff --git a/server/ast/resolvable_type_reference.go b/server/ast/resolvable_type_reference.go index a9f7f82f08..b48a63fd2b 100755 --- a/server/ast/resolvable_type_reference.go +++ b/server/ast/resolvable_type_reference.go @@ -18,33 +18,44 @@ import ( "fmt" "strconv" + "github.com/dolthub/go-mysql-server/sql" vitess "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/dolthub/doltgresql/postgres/parser/sem/tree" "github.com/dolthub/doltgresql/postgres/parser/types" + pgtypes "github.com/dolthub/doltgresql/server/types" ) // nodeResolvableTypeReference handles tree.ResolvableTypeReference nodes. -func nodeResolvableTypeReference(typ tree.ResolvableTypeReference) (*vitess.ConvertType, error) { +func nodeResolvableTypeReference(typ tree.ResolvableTypeReference) (*vitess.ConvertType, sql.Type, error) { if typ == nil { - return nil, nil + return nil, nil, nil } var columnTypeName string var columnTypeLength *vitess.SQLVal var columnTypeScale *vitess.SQLVal + var resolvedType sql.Type switch columnType := typ.(type) { case *tree.ArrayTypeReference: - return nil, fmt.Errorf("array types are not yet supported") + return nil, nil, fmt.Errorf("the given array type is not yet supported") case *tree.OIDTypeReference: - return nil, fmt.Errorf("referencing types by their OID is not yet supported") + return nil, nil, fmt.Errorf("referencing types by their OID is not yet supported") case *tree.UnresolvedObjectName: - return nil, fmt.Errorf("type declaration format is not yet supported") + return nil, nil, fmt.Errorf("type declaration format is not yet supported") case *types.GeoMetadata: - return nil, fmt.Errorf("geometry types are not yet supported") + return nil, nil, fmt.Errorf("geometry types are not yet supported") case *types.T: columnTypeName = columnType.SQLStandardName() switch columnType.Family() { + case types.ArrayFamily: + if columnType.ArrayContents().Family() == types.BoolFamily { + resolvedType = pgtypes.BoolArray + } else { + return nil, nil, fmt.Errorf("the given array type is not yet supported") + } + case types.BoolFamily: + resolvedType = pgtypes.Bool case types.DecimalFamily: columnTypeName = "decimal" columnTypeLength = vitess.NewIntVal([]byte(strconv.Itoa(int(columnType.Precision())))) @@ -63,5 +74,5 @@ func nodeResolvableTypeReference(typ tree.ResolvableTypeReference) (*vitess.Conv Length: columnTypeLength, Scale: columnTypeScale, Charset: "", // TODO - }, nil + }, resolvedType, nil } diff --git a/server/types/bool.go b/server/types/bool.go new file mode 100644 index 0000000000..4b15a52673 --- /dev/null +++ b/server/types/bool.go @@ -0,0 +1,257 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "bytes" + "fmt" + "reflect" + "strings" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" + "github.com/dolthub/vitess/go/sqltypes" + "github.com/dolthub/vitess/go/vt/proto/query" + "github.com/shopspring/decimal" +) + +// Bool is the standard boolean. +var Bool = BoolType{} + +// BoolType is the extended type implementation of the PostgreSQL boolean. +type BoolType struct{} + +var _ types.ExtendedType = BoolType{} + +// CollationCoercibility implements the types.ExtendedType interface. +func (b BoolType) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { + return sql.Collation_binary, 5 +} + +// Compare implements the types.ExtendedType interface. +func (b BoolType) Compare(v1 any, v2 any) (int, error) { + if v1 == nil && v2 == nil { + return 0, nil + } else if v1 != nil && v2 == nil { + return 1, nil + } else if v1 == nil && v2 != nil { + return -1, nil + } + + ac, _, err := b.Convert(v1) + if err != nil { + return 0, err + } + bc, _, err := b.Convert(v2) + if err != nil { + return 0, err + } + + ab := ac.(bool) + bb := bc.(bool) + if ab == bb { + return 0, nil + } else if !ab { + return -1, nil + } else { + return 1, nil + } +} + +// Convert implements the types.ExtendedType interface. +func (b BoolType) Convert(val any) (any, sql.ConvertInRange, error) { + if val == nil { + return nil, sql.InRange, nil + } + + switch val := val.(type) { + case bool: + return val, sql.InRange, nil + case int: + return val != 0, sql.InRange, nil + case uint: + return val != 0, sql.InRange, nil + case int8: + return val != 0, sql.InRange, nil + case uint8: + return val != 0, sql.InRange, nil + case int16: + return val != 0, sql.InRange, nil + case uint16: + return val != 0, sql.InRange, nil + case int32: + return val != 0, sql.InRange, nil + case uint32: + return val != 0, sql.InRange, nil + case int64: + return val != 0, sql.InRange, nil + case uint64: + return val != 0, sql.InRange, nil + case float32: + return val != 0, sql.InRange, nil + case float64: + return val != 0, sql.InRange, nil + case decimal.NullDecimal: + if !val.Valid { + return nil, sql.InRange, nil + } + return b.Convert(val.Decimal) + case decimal.Decimal: + return !val.Equal(decimal.NewFromInt(0)), sql.InRange, nil + case string: + val = strings.TrimSpace(strings.ToLower(val)) + if val == "true" || val == "yes" || val == "on" || val == "1" { + return true, sql.InRange, nil + } else if val == "false" || val == "no" || val == "off" || val == "0" { + return false, sql.InRange, nil + } else { + return nil, sql.OutOfRange, fmt.Errorf("invalid string value for boolean: %q", val) + } + case []byte: + return b.Convert(string(val)) + default: + return nil, sql.OutOfRange, sql.ErrInvalidType.New(b) + } +} + +// Equals implements the types.ExtendedType interface. +func (b BoolType) Equals(otherType sql.Type) bool { + if otherExtendedType, ok := otherType.(types.ExtendedType); ok { + return bytes.Equal(MustSerializeType(b), MustSerializeType(otherExtendedType)) + } + return false +} + +// FormatSerializedValue implements the types.ExtendedType interface. +func (b BoolType) FormatSerializedValue(val []byte) (string, error) { + deserialized, err := b.DeserializeValue(val) + if err != nil { + return "", err + } + return b.FormatValue(deserialized) +} + +// FormatValue implements the types.ExtendedType interface. +func (b BoolType) FormatValue(val any) (string, error) { + if val == nil { + return "", nil + } + converted, _, err := b.Convert(val) + if err != nil { + return "", err + } + if converted.(bool) { + return "true", nil + } else { + return "false", nil + } +} + +// MaxSerializedWidth implements the types.ExtendedType interface. +func (b BoolType) MaxSerializedWidth() types.ExtendedTypeSerializedWidth { + return types.ExtendedTypeSerializedWidth_64K +} + +// MaxTextResponseByteLength implements the types.ExtendedType interface. +func (b BoolType) MaxTextResponseByteLength(ctx *sql.Context) uint32 { + return 1 +} + +// Promote implements the types.ExtendedType interface. +func (b BoolType) Promote() sql.Type { + return b +} + +// SerializedCompare implements the types.ExtendedType interface. +func (b BoolType) SerializedCompare(v1 []byte, v2 []byte) (int, error) { + if len(v1) == 0 && len(v2) == 0 { + return 0, nil + } else if len(v1) > 0 && len(v2) == 0 { + return 1, nil + } else if len(v1) == 0 && len(v2) > 0 { + return -1, nil + } + + if v1[0] == v2[0] { + return 0, nil + } else if v1[0] == 0 { + return -1, nil + } else { + return 1, nil + } +} + +// SQL implements the types.ExtendedType interface. +func (b BoolType) SQL(ctx *sql.Context, dest []byte, v any) (sqltypes.Value, error) { + if v == nil { + return sqltypes.NULL, nil + } + value, _, err := b.Convert(v) + if err != nil { + return sqltypes.Value{}, err + } + var valBytes []byte + if value.(bool) { + //TODO: use Wireshark and check whether we're returning these strings or something else + valBytes = types.AppendAndSliceBytes(dest, []byte{'t'}) + } else { + valBytes = types.AppendAndSliceBytes(dest, []byte{'f'}) + } + return sqltypes.MakeTrusted(b.Type(), valBytes), nil +} + +// String implements the types.ExtendedType interface. +func (b BoolType) String() string { + return "boolean" +} + +// Type implements the types.ExtendedType interface. +func (b BoolType) Type() query.Type { + return sqltypes.Text +} + +// ValueType implements the types.ExtendedType interface. +func (b BoolType) ValueType() reflect.Type { + return reflect.TypeOf(bool(false)) +} + +// Zero implements the types.ExtendedType interface. +func (b BoolType) Zero() any { + return false +} + +// SerializeValue implements the types.ExtendedType interface. +func (b BoolType) SerializeValue(val any) ([]byte, error) { + if val == nil { + return nil, nil + } + converted, _, err := b.Convert(val) + if err != nil { + return nil, err + } + if converted.(bool) { + return []byte{1}, nil + } else { + return []byte{0}, nil + } +} + +// DeserializeValue implements the types.ExtendedType interface. +func (b BoolType) DeserializeValue(val []byte) (any, error) { + if len(val) == 0 { + return nil, nil + } + return val[0] != 0, nil +} diff --git a/server/types/bool_array.go b/server/types/bool_array.go new file mode 100644 index 0000000000..d709477d9b --- /dev/null +++ b/server/types/bool_array.go @@ -0,0 +1,265 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "bytes" + "math" + "reflect" + "strings" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" + "github.com/dolthub/vitess/go/sqltypes" + "github.com/dolthub/vitess/go/vt/proto/query" + + "github.com/dolthub/doltgresql/utils" +) + +// BoolArray is the standard boolean array. +var BoolArray = BoolArrayType{} + +// BoolArrayType is the extended type implementation of the PostgreSQL boolean. +type BoolArrayType struct{} + +var _ types.ExtendedType = BoolArrayType{} + +// CollationCoercibility implements the types.ExtendedType interface. +func (b BoolArrayType) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { + return sql.Collation_binary, 5 +} + +// Compare implements the types.ExtendedType interface. +func (b BoolArrayType) Compare(v1 any, v2 any) (int, error) { + if v1 == nil && v2 == nil { + return 0, nil + } else if v1 != nil && v2 == nil { + return 1, nil + } else if v1 == nil && v2 != nil { + return -1, nil + } + + ac, _, err := b.Convert(v1) + if err != nil { + return 0, err + } + bc, _, err := b.Convert(v2) + if err != nil { + return 0, err + } + + ab := ac.([]bool) + bb := bc.([]bool) + minLength := utils.Min(len(ab), len(bb)) + for i := 0; i < minLength; i++ { + if ab[i] == bb[i] { + continue + } else if !ab[i] { + return -1, nil + } else { + return 1, nil + } + } + if len(ab) == len(bb) { + return 0, nil + } else if len(ab) < len(bb) { + return -1, nil + } else { + return 1, nil + } +} + +// Convert implements the types.ExtendedType interface. +func (b BoolArrayType) Convert(val any) (any, sql.ConvertInRange, error) { + if val == nil { + return nil, sql.InRange, nil + } + + switch val := val.(type) { + case []bool: + return val, sql.InRange, nil + default: + return nil, sql.OutOfRange, sql.ErrInvalidType.New(b) + } +} + +// Equals implements the types.ExtendedType interface. +func (b BoolArrayType) Equals(otherType sql.Type) bool { + if otherExtendedType, ok := otherType.(types.ExtendedType); ok { + return bytes.Equal(MustSerializeType(b), MustSerializeType(otherExtendedType)) + } + return false +} + +// FormatSerializedValue implements the types.ExtendedType interface. +func (b BoolArrayType) FormatSerializedValue(val []byte) (string, error) { + deserialized, err := b.DeserializeValue(val) + if err != nil { + return "", err + } + return b.FormatValue(deserialized) +} + +// FormatValue implements the types.ExtendedType interface. +func (b BoolArrayType) FormatValue(val any) (string, error) { + if val == nil { + return "", nil + } + converted, _, err := b.Convert(val) + if err != nil { + return "", err + } + sb := strings.Builder{} + for i, v := range converted.([]bool) { + if i > 0 { + sb.WriteString(", ") + } + if v { + return "true", nil + } else { + return "false", nil + } + } + return sb.String(), nil +} + +// MaxSerializedWidth implements the types.ExtendedType interface. +func (b BoolArrayType) MaxSerializedWidth() types.ExtendedTypeSerializedWidth { + return types.ExtendedTypeSerializedWidth_Unbounded +} + +// MaxTextResponseByteLength implements the types.ExtendedType interface. +func (b BoolArrayType) MaxTextResponseByteLength(ctx *sql.Context) uint32 { + return math.MaxUint32 +} + +// Promote implements the types.ExtendedType interface. +func (b BoolArrayType) Promote() sql.Type { + return b +} + +// SerializedCompare implements the types.ExtendedType interface. +func (b BoolArrayType) SerializedCompare(v1 []byte, v2 []byte) (int, error) { + if v1 == nil && v2 == nil { + return 0, nil + } else if v1 != nil && v2 == nil { + return 1, nil + } else if v1 == nil && v2 != nil { + return -1, nil + } + + minLength := utils.Min(len(v1), len(v2)) + for i := 0; i < minLength; i++ { + if v1[i] == v2[i] { + continue + } else if v1[i] == 0 { + return -1, nil + } else { + return 1, nil + } + } + if len(v1) == len(v2) { + return 0, nil + } else if len(v1) < len(v2) { + return -1, nil + } else { + return 1, nil + } +} + +// SQL implements the types.ExtendedType interface. +func (b BoolArrayType) SQL(ctx *sql.Context, dest []byte, v any) (sqltypes.Value, error) { + if v == nil { + return sqltypes.NULL, nil + } + valueAny, _, err := b.Convert(v) + if err != nil { + return sqltypes.Value{}, err + } + value := valueAny.([]bool) + if len(value) == 0 { + return sqltypes.MakeTrusted(b.Type(), types.AppendAndSliceBytes(dest, []byte{'{', '}'})), nil + } + valBytes := make([]byte, 2+len(value)+(len(value)-1)) // {t,f,t} | we're including the brackets and commas + valBytes[0] = '{' + valBytes[len(valBytes)-1] = '}' + valBytesIndex := 1 + for valueIndex := range value { + if valueIndex > 0 { + valBytes[valBytesIndex] = ',' + valBytesIndex++ + } + if value[valueIndex] { + valBytes[valBytesIndex] = 't' + } else { + valBytes[valBytesIndex] = 'f' + } + valBytesIndex++ + } + return sqltypes.MakeTrusted(b.Type(), types.AppendAndSliceBytes(dest, valBytes)), nil +} + +// String implements the types.ExtendedType interface. +func (b BoolArrayType) String() string { + return "boolean[]" +} + +// Type implements the types.ExtendedType interface. +func (b BoolArrayType) Type() query.Type { + return sqltypes.Text +} + +// ValueType implements the types.ExtendedType interface. +func (b BoolArrayType) ValueType() reflect.Type { + return reflect.TypeOf([]bool{}) +} + +// Zero implements the types.ExtendedType interface. +func (b BoolArrayType) Zero() any { + return []bool{} +} + +// SerializeValue implements the types.ExtendedType interface. +func (b BoolArrayType) SerializeValue(val any) ([]byte, error) { + if val == nil { + return nil, nil + } + convertedAny, _, err := b.Convert(val) + if err != nil { + return nil, err + } + converted := convertedAny.([]bool) + output := make([]byte, len(converted)) + for i := range converted { + if converted[i] { + output[i] = 1 + } else { + output[i] = 0 + } + } + return output, nil +} + +// DeserializeValue implements the types.ExtendedType interface. +func (b BoolArrayType) DeserializeValue(val []byte) (any, error) { + if val == nil { + return nil, nil + } + output := make([]bool, len(val)) + for i := range val { + output[i] = val[i] != 0 + } + return output, nil +} diff --git a/server/types/serialization.go b/server/types/serialization.go new file mode 100644 index 0000000000..fa8fc71ee9 --- /dev/null +++ b/server/types/serialization.go @@ -0,0 +1,72 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "fmt" + + "github.com/dolthub/go-mysql-server/sql/types" +) + +type serializationID byte + +const ( + serializationID_Bool serializationID = 1 + serializationID_BoolArray serializationID = 2 +) + +// init sets the serialization and deserialization functions. +func init() { + types.SetExtendedTypeSerializers(SerializeType, DeserializeType) +} + +// SerializeType is able to serialize the given extended type into a byte slice. All extended types will be defined +// by DoltgreSQL. +func SerializeType(extendedType types.ExtendedType) ([]byte, error) { + switch extendedType.(type) { + case BoolType: + return []byte{byte(serializationID_Bool)}, nil + case BoolArrayType: + return []byte{byte(serializationID_BoolArray)}, nil + default: + return nil, fmt.Errorf("unknown type to serialize") + } +} + +// MustSerializeType internally calls SerializeType and panics on error. In general, panics should only occur when a +// type has not yet had its serialization implemented yet. +func MustSerializeType(extendedType types.ExtendedType) []byte { + serializedType, err := SerializeType(extendedType) + if err != nil { + panic(err) + } + return serializedType +} + +// DeserializeType is able to deserialize the given serialized type into an appropriate extended type. All extended +// types will be defined by DoltgreSQL. +func DeserializeType(serializedType []byte) (types.ExtendedType, error) { + if len(serializedType) == 0 { + return nil, fmt.Errorf("cannot deserialize an empty type") + } + switch serializationID(serializedType[0]) { + case serializationID_Bool: + return Bool, nil + case serializationID_BoolArray: + return BoolArray, nil + default: + return nil, fmt.Errorf("unknown type to deserialize") + } +} diff --git a/testing/bats/types.bats b/testing/bats/types.bats new file mode 100644 index 0000000000..aaeb6a80b3 --- /dev/null +++ b/testing/bats/types.bats @@ -0,0 +1,40 @@ +#!/usr/bin/env bats +load $BATS_TEST_DIRNAME/setup/common.bash + +setup() { + setup_common + start_sql_server + +} + +teardown() { + teardown_common +} + +@test 'types: boolean type' { + query_server < maximum { + maximum = vals[i] + } + } + return maximum +} diff --git a/utils/min.go b/utils/min.go new file mode 100644 index 0000000000..6899ee4f3a --- /dev/null +++ b/utils/min.go @@ -0,0 +1,33 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "golang.org/x/exp/constraints" +) + +// Min returns the smallest value of the given parameters. +func Min[T constraints.Ordered](vals ...T) (minimum T) { + if len(vals) == 0 { + return minimum + } + minimum = vals[0] + for i := 1; i < len(vals); i++ { + if vals[i] < minimum { + minimum = vals[i] + } + } + return minimum +}