diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9ab9f5127b..43d931a3c6 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -238,6 +238,24 @@ There are exceptions, as some statements we do not yet support, and cannot suppo In these cases, we must add a `//TODO:` comment stating what is missing and why it isn't an error. This will at least allow us to track all such instances where we deviate from the expected behavior, which we can also document elsewhere for users of DoltgreSQL. +### `server/functions` + +The `functions` package contains the functions, along with an implementation to approximate the function overloading structure (and type coercion). + +The function overloading structure is defined in all files that have the `zinternal_` prefix. +Although not preferable, this was chosen as Go does not allow cyclical references between packages. +Rather than have half of the implementation in `functions`, and the other half in another package, the decision was made to include both in the `functions` package with the added prefix for distinction. + +There's an `init` function in `server/functions/zinternal_catalog.go` (this is included in `server/listener.go`) that removes any conflicting GMS function names, and replaces them with the PostgreSQL equivalents. +This means that the functions that we've added behave as expected, and for others to have _some_ sort of implementation rather than outright failing. +We will eventually remove all GMS functions once all PostgreSQL functions have been implemented. +The other internal files all contribute to the generation of functions, along with their proper handling. + +Each function (and all overloads) are contained in a single file. +Overloads are named according to their parameters, and prefixed by their target function name. +The set of overloads are then added to the `Catalog` within `server/functions/zinternal_catalog.go`. +To add a new function, it is as simple as creating the `Function`, adding the overloads, and adding it to the `Catalog`. + ### `testing/bats` All Bats tests must follow this general structure: diff --git a/server/ast/expr.go b/server/ast/expr.go index 49777c797c..b5867443f8 100644 --- a/server/ast/expr.go +++ b/server/ast/expr.go @@ -189,10 +189,8 @@ func nodeExpr(node tree.Expr) (vitess.Expr, error) { } switch node.SyntaxMode { - case tree.CastExplicit: - // only acceptable cast type - case tree.CastShort: - return nil, fmt.Errorf("TYPECAST is not yet supported") + case tree.CastExplicit, tree.CastShort: + // Both of these are acceptable case tree.CastPrepend: return nil, fmt.Errorf("typed literals are not yet supported") default: diff --git a/server/ast/resolvable_type_reference.go b/server/ast/resolvable_type_reference.go index 677e25497b..a9f7f82f08 100755 --- a/server/ast/resolvable_type_reference.go +++ b/server/ast/resolvable_type_reference.go @@ -46,6 +46,7 @@ func nodeResolvableTypeReference(typ tree.ResolvableTypeReference) (*vitess.Conv columnTypeName = columnType.SQLStandardName() switch columnType.Family() { case types.DecimalFamily: + columnTypeName = "decimal" columnTypeLength = vitess.NewIntVal([]byte(strconv.Itoa(int(columnType.Precision())))) columnTypeScale = vitess.NewIntVal([]byte(strconv.Itoa(int(columnType.Scale())))) case types.JsonFamily: diff --git a/server/functions/cbrt.go b/server/functions/cbrt.go new file mode 100644 index 0000000000..6558fd8363 --- /dev/null +++ b/server/functions/cbrt.go @@ -0,0 +1,37 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package functions + +import ( + "fmt" + "math" +) + +// cbrt represents the PostgreSQL function of the same name. +var cbrt = Function{ + Name: "cbrt", + Overloads: []interface{}{cbrt_float}, +} + +// cbrt_float is one of the overloads of cbrt. +func cbrt_float(num FloatType) (FloatType, error) { + if num.IsNull { + return FloatType{IsNull: true}, nil + } + if num.OriginalType == ParameterType_String { + return FloatType{}, fmt.Errorf("function cbrt(%s) does not exist", ParameterType_String.String()) + } + return FloatType{Value: math.Cbrt(num.Value)}, nil +} diff --git a/server/functions/gcd.go b/server/functions/gcd.go new file mode 100644 index 0000000000..fe7f5c8680 --- /dev/null +++ b/server/functions/gcd.go @@ -0,0 +1,44 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package functions + +import ( + "fmt" + + "github.com/dolthub/doltgresql/utils" +) + +// gcd represents the PostgreSQL function of the same name. +var gcd = Function{ + Name: "gcd", + Overloads: []interface{}{gcd_int_int}, +} + +// gcd_int_int is one of the overloads of gcd. +func gcd_int_int(num1 IntegerType, num2 IntegerType) (IntegerType, error) { + if num1.IsNull || num2.IsNull { + return IntegerType{IsNull: true}, nil + } + if num1.OriginalType == ParameterType_String || num2.OriginalType == ParameterType_String { + return IntegerType{}, fmt.Errorf("function gcd(%s, %s) does not exist", + num1.OriginalType.String(), num2.OriginalType.String()) + } + for num2.Value != 0 { + temp := num2.Value + num2.Value = num1.Value % num2.Value + num1.Value = temp + } + return IntegerType{Value: utils.Abs(num1.Value)}, nil +} diff --git a/server/functions/lcm.go b/server/functions/lcm.go new file mode 100644 index 0000000000..29919d0342 --- /dev/null +++ b/server/functions/lcm.go @@ -0,0 +1,46 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package functions + +import ( + "fmt" + + "github.com/dolthub/doltgresql/utils" +) + +// lcm represents the PostgreSQL function of the same name. +var lcm = Function{ + Name: "lcm", + Overloads: []interface{}{lcm1_int_int}, +} + +// lcm1 is one of the overloads of lcm. +func lcm1_int_int(num1 IntegerType, num2 IntegerType) (IntegerType, error) { + if num1.IsNull || num2.IsNull { + return IntegerType{IsNull: true}, nil + } + if num1.OriginalType == ParameterType_String || num2.OriginalType == ParameterType_String { + return IntegerType{}, fmt.Errorf("function lcm(%s, %s) does not exist", + num1.OriginalType.String(), num2.OriginalType.String()) + } + gcdResult, err := gcd_int_int(num1, num2) + if err != nil { + return IntegerType{}, err + } + if gcdResult.Value == 0 { + return IntegerType{Value: 0}, nil + } + return IntegerType{Value: utils.Abs((num1.Value * num2.Value) / gcdResult.Value)}, nil +} diff --git a/server/functions/round.go b/server/functions/round.go new file mode 100644 index 0000000000..1b441d808c --- /dev/null +++ b/server/functions/round.go @@ -0,0 +1,48 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package functions + +import "math" + +// round represents the PostgreSQL function of the same name. +var round = Function{ + Name: "round", + Overloads: []interface{}{round_num, round_float, round_num_dec}, +} + +// round1 is one of the overloads of round. +func round_num(num NumericType) (NumericType, error) { + if num.IsNull { + return NumericType{IsNull: true}, nil + } + return NumericType{Value: math.Round(num.Value)}, nil +} + +// round2 is one of the overloads of round. +func round_float(num FloatType) (FloatType, error) { + if num.IsNull { + return FloatType{IsNull: true}, nil + } + return FloatType{Value: math.RoundToEven(num.Value)}, nil +} + +// round3 is one of the overloads of round. +func round_num_dec(num NumericType, decimalPlaces IntegerType) (NumericType, error) { + if num.IsNull || decimalPlaces.IsNull { + return NumericType{IsNull: true}, nil + } + ratio := math.Pow10(int(decimalPlaces.Value)) + return NumericType{Value: math.Round(num.Value*ratio) / ratio}, nil +} diff --git a/server/functions/zinternal_catalog.go b/server/functions/zinternal_catalog.go new file mode 100644 index 0000000000..d6d6dac5c5 --- /dev/null +++ b/server/functions/zinternal_catalog.go @@ -0,0 +1,121 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package functions + +import ( + "fmt" + "reflect" + "strings" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression/function" +) + +// Function is a name, along with a collection of functions, that represent a single PostgreSQL function with all of its +// overloads. +type Function struct { + Name string + Overloads []any +} + +// Catalog contains all of the PostgreSQL functions. If a new function is added, make sure to add it to the catalog here. +var Catalog = []Function{ + cbrt, + gcd, + lcm, + round, +} + +// init handles the initialization of the catalog by overwriting the built-in GMS functions, since they do not apply to +// PostgreSQL (and functions of the same name often have different behavior). +func init() { + catalogMap := make(map[string]struct{}) + for _, f := range Catalog { + catalogMap[strings.ToLower(f.Name)] = struct{}{} + } + var newBuiltIns []sql.Function + for _, f := range function.BuiltIns { + if _, ok := catalogMap[strings.ToLower(f.FunctionName())]; !ok { + newBuiltIns = append(newBuiltIns, f) + } + } + function.BuiltIns = newBuiltIns + + allNames := make(map[string]struct{}) + for _, catalogItem := range Catalog { + funcName := strings.ToLower(catalogItem.Name) + if _, ok := allNames[funcName]; ok { + panic("duplicate name: " + catalogItem.Name) + } + allNames[funcName] = struct{}{} + + baseOverload := &OverloadDeduction{} + for _, functionOverload := range catalogItem.Overloads { + // For each function overload, we first need to ensure that it has an acceptable signature + funcVal := reflect.ValueOf(functionOverload) + if !funcVal.IsValid() || funcVal.IsNil() { + panic(fmt.Errorf("function `%s` has an invalid item", catalogItem.Name)) + } + if funcVal.Kind() != reflect.Func { + panic(fmt.Errorf("function `%s` has a non-function item", catalogItem.Name)) + } + if funcVal.Type().NumOut() != 2 { + panic(fmt.Errorf("function `%s` has an overload that does not return two values", catalogItem.Name)) + } + if funcVal.Type().Out(1) != reflect.TypeOf((*error)(nil)).Elem() { + panic(fmt.Errorf("function `%s` has an overload that does not return an error", catalogItem.Name)) + } + returnValType, returnSqlType, ok := ParameterTypeFromReflection(funcVal.Type().Out(0)) + if !ok { + panic(fmt.Errorf("function `%s` has an overload that returns as invalid type (`%s`)", + catalogItem.Name, funcVal.Type().Out(0).String())) + } + + // Loop through all of the parameters to ensure uniqueness, then store it + currentOverload := baseOverload + for i := 0; i < funcVal.Type().NumIn(); i++ { + paramValType, _, ok := ParameterTypeFromReflection(funcVal.Type().In(i)) + if !ok { + panic(fmt.Errorf("function `%s` has an overload with an invalid parameter type (`%s`)", + catalogItem.Name, funcVal.Type().In(i).String())) + } + nextOverload := currentOverload.Parameter[paramValType] + if nextOverload == nil { + nextOverload = &OverloadDeduction{} + currentOverload.Parameter[paramValType] = nextOverload + } + currentOverload = nextOverload + } + if currentOverload.Function.IsValid() && !currentOverload.Function.IsNil() { + panic(fmt.Errorf("function `%s` has duplicate overloads", catalogItem.Name)) + } + currentOverload.Function = funcVal + currentOverload.ReturnValType = returnValType + currentOverload.ReturnSqlType = returnSqlType + } + + // Store the compiled function into the engine's built-in functions + function.BuiltIns = append(function.BuiltIns, sql.FunctionN{ + Name: funcName, + Fn: func(params ...sql.Expression) (sql.Expression, error) { + return &CompiledFunction{ + Name: catalogItem.Name, + Parameters: params, + Functions: baseOverload, + }, nil + }, + }) + } +} diff --git a/server/functions/zinternal_compiled_function.go b/server/functions/zinternal_compiled_function.go new file mode 100644 index 0000000000..0a3cb0ddb2 --- /dev/null +++ b/server/functions/zinternal_compiled_function.go @@ -0,0 +1,273 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package functions + +import ( + "fmt" + "reflect" + "strings" + "time" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/types" + "github.com/dolthub/vitess/go/vt/proto/query" + "github.com/shopspring/decimal" +) + +// CompiledFunction is an expression that represents a fully-analyzed PostgreSQL function. +type CompiledFunction struct { + Name string + Parameters []sql.Expression + Functions *OverloadDeduction +} + +var _ sql.FunctionExpression = (*CompiledFunction)(nil) + +// FunctionName implements the interface sql.Expression. +func (c *CompiledFunction) FunctionName() string { + return c.Name +} + +// Description implements the interface sql.Expression. +func (c *CompiledFunction) Description() string { + return fmt.Sprintf("The PostgreSQL function `%s`", c.Name) +} + +// Resolved implements the interface sql.Expression. +func (c *CompiledFunction) Resolved() bool { + for _, param := range c.Parameters { + if !param.Resolved() { + return false + } + } + return true +} + +// String implements the interface sql.Expression. +func (c *CompiledFunction) String() string { + sb := strings.Builder{} + sb.WriteString(c.Name + "(") + for i, param := range c.Parameters { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(param.String()) + } + sb.WriteString(")") + return sb.String() +} + +// OverloadString returns the name of the function represented by the given overload. +func (c *CompiledFunction) OverloadString(types []IntermediateParameter) string { + sb := strings.Builder{} + sb.WriteString(c.Name + "(") + for i, t := range types { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(t.CurrentType.String()) + } + sb.WriteString(")") + return sb.String() +} + +// Type implements the interface sql.Expression. +func (c *CompiledFunction) Type() sql.Type { + if resolvedFunction, _ := c.Functions.ResolveByType(c.possibleParameterTypes()); resolvedFunction != nil { + return resolvedFunction.ReturnSqlType + } + // We can't resolve to a function before evaluation in this case, so we'll return something arbitrary + return types.LongText +} + +// IsNullable implements the interface sql.Expression. +func (c *CompiledFunction) IsNullable() bool { + // We'll always return true, since it does not seem to have a truly negative impact if we return true for a function + // that will never return NULL, however there is a negative impact for returning false when a function does return + // NULL. + return true +} + +// Eval implements the interface sql.Expression. +func (c *CompiledFunction) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + // First we'll evaluate all of the parameters. + parameters, err := c.evalParameters(ctx, row) + if err != nil { + return nil, err + } + // Next we'll resolve the overload based on the parameters given. + overload, err := c.Functions.Resolve(parameters) + if err != nil { + return nil, err + } + // If we do not receive an overload, then the parameters given did not result in a valid match + if overload == nil { + return nil, fmt.Errorf("function %s does not exist", c.OverloadString(parameters)) + } + // Convert the intermediate parameters into their concrete types, then pass them to the function + concreteParameters := make([]reflect.Value, len(parameters)) + for i := range parameters { + concreteParameters[i] = parameters[i].ToValue() + } + result := overload.Function.Call(concreteParameters) + if !result[1].IsNil() { + return nil, result[1].Interface().(error) + } + // Unpack the resulting value, returning it to the caller + switch overload.ReturnValType { + case ParameterType_Integer: + resultVal := result[0].Interface().(IntegerType) + if resultVal.IsNull { + return nil, nil + } + return resultVal.Value, nil + case ParameterType_Float: + resultVal := result[0].Interface().(FloatType) + if resultVal.IsNull { + return nil, nil + } + return resultVal.Value, nil + case ParameterType_Numeric: + resultVal := result[0].Interface().(NumericType) + if resultVal.IsNull { + return nil, nil + } + return resultVal.Value, nil + case ParameterType_String: + resultVal := result[0].Interface().(StringType) + if resultVal.IsNull { + return nil, nil + } + return resultVal.Value, nil + case ParameterType_Timestamp: + resultVal := result[0].Interface().(TimestampType) + if resultVal.IsNull { + return nil, nil + } + return resultVal.Value, nil + default: + return nil, fmt.Errorf("unhandled parameter type in %T::Eval (%d)", c, overload.ReturnValType) + } +} + +// Children implements the interface sql.Expression. +func (c *CompiledFunction) Children() []sql.Expression { + return c.Parameters +} + +// WithChildren implements the interface sql.Expression. +func (c *CompiledFunction) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return &CompiledFunction{ + Name: c.Name, + Parameters: children, + Functions: c.Functions, + }, nil +} + +// evalParameters evaluates the parameters within an Eval call. +func (c *CompiledFunction) evalParameters(ctx *sql.Context, row sql.Row) ([]IntermediateParameter, error) { + parameters := make([]IntermediateParameter, len(c.Parameters)) + for i, param := range c.Parameters { + evaluatedParam, err := param.Eval(ctx, row) + if err != nil { + return nil, err + } + parameters[i].Source = c.determineSource(param) + switch evaluatedParam := evaluatedParam.(type) { + case int8: + parameters[i].Value = int64(evaluatedParam) + parameters[i].OriginalType = ParameterType_Integer + case int16: + parameters[i].Value = int64(evaluatedParam) + parameters[i].OriginalType = ParameterType_Integer + case int32: + parameters[i].Value = int64(evaluatedParam) + parameters[i].OriginalType = ParameterType_Integer + case int64: + parameters[i].Value = evaluatedParam + parameters[i].OriginalType = ParameterType_Integer + case float32: + parameters[i].Value = float64(evaluatedParam) + parameters[i].OriginalType = ParameterType_Float + case float64: + parameters[i].Value = evaluatedParam + parameters[i].OriginalType = ParameterType_Float + case decimal.Decimal: + //TODO: properly handle decimal types + asFloat, _ := evaluatedParam.Float64() + parameters[i].Value = asFloat + parameters[i].OriginalType = ParameterType_Numeric + case string: + parameters[i].Value = evaluatedParam + parameters[i].OriginalType = ParameterType_String + case time.Time: + parameters[i].Value = evaluatedParam + parameters[i].OriginalType = ParameterType_Timestamp + case nil: + parameters[i].IsNull = true + parameters[i].OriginalType = ParameterType_Null + default: + return nil, fmt.Errorf("PostgreSQL functions do not yet support parameters of type `%T`", evaluatedParam) + } + parameters[i].CurrentType = parameters[i].OriginalType + } + return parameters, nil +} + +// determineSource determines what the source is, based on the expression given. +func (c *CompiledFunction) determineSource(expr sql.Expression) Source { + switch expr := expr.(type) { + case *expression.Alias: + return c.determineSource(expr.Child) + case *expression.GetField: + return Source_Column + case *expression.Literal: + return Source_Constant + default: + return Source_Expression + } +} + +// possibleParameterTypes returns the parameter types of all of the expressions by guessing the return value from the +// type that each expression declares it will return. This is not guaranteed to be correct. +func (c *CompiledFunction) possibleParameterTypes() []ParameterType { + possibleParamTypes := make([]ParameterType, len(c.Parameters)) + for i, param := range c.Parameters { + switch param.Type().Type() { + case query.Type_INT8, query.Type_INT16, query.Type_INT24, query.Type_INT32, query.Type_INT64: + possibleParamTypes[i] = ParameterType_Integer + case query.Type_UINT8, query.Type_UINT16, query.Type_UINT24, query.Type_UINT32, query.Type_UINT64: + possibleParamTypes[i] = ParameterType_Integer + case query.Type_YEAR: + possibleParamTypes[i] = ParameterType_Integer + case query.Type_FLOAT32, query.Type_FLOAT64: + possibleParamTypes[i] = ParameterType_Float + case query.Type_DECIMAL: + //TODO: properly handle decimal types + possibleParamTypes[i] = ParameterType_Float + case query.Type_DATE, query.Type_DATETIME, query.Type_TIMESTAMP: + possibleParamTypes[i] = ParameterType_Timestamp + case query.Type_CHAR, query.Type_VARCHAR, query.Type_TEXT: + possibleParamTypes[i] = ParameterType_String + case query.Type_ENUM, query.Type_SET: + possibleParamTypes[i] = ParameterType_Integer + default: + // We'll just use NULL for now, since we've got incomplete coverage of PostgreSQL types anyway + possibleParamTypes[i] = ParameterType_Null + } + } + return possibleParamTypes +} diff --git a/server/functions/zinternal_overload_deduction.go b/server/functions/zinternal_overload_deduction.go new file mode 100644 index 0000000000..e3d6080868 --- /dev/null +++ b/server/functions/zinternal_overload_deduction.go @@ -0,0 +1,193 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package functions + +import ( + "fmt" + "reflect" + "strconv" + "time" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" +) + +// OverloadDeduction handles resolving which function to call by iterating over the parameter expressions. This also +// handles casting between types if an exact function match is not found. +type OverloadDeduction struct { + Function reflect.Value + ReturnSqlType sql.Type + ReturnValType ParameterType + Parameter [ParameterType_Length]*OverloadDeduction +} + +// Resolve returns an overload that either matches the given parameters exactly, or is a viable match after casting. +// This will modify the parameter slice in-place. Returns a nil OverloadDeduction if a viable match is not found. +func (overload *OverloadDeduction) Resolve(parameters []IntermediateParameter) (*OverloadDeduction, error) { + parameterTypes := make([]ParameterType, len(parameters)) + for i := range parameters { + parameterTypes[i] = parameters[i].OriginalType + } + resultOverload, resultTypes := overload.ResolveByType(parameterTypes) + // If we receive a nil overload, then no valid overloads were found + if resultOverload == nil { + return nil, nil + } + // If any of the result types are different from their originals, then we need to cast them to their resulting types + // if it's possible. + for i, t := range resultTypes { + parameters[i].CurrentType = t + if parameters[i].OriginalType == t { + continue + } + + var err error + switch parameters[i].OriginalType { + case ParameterType_Null: + // Since nulls are typeless, we pretend that the current type was also the original type + parameters[i].OriginalType = t + switch t { + case ParameterType_Integer: + parameters[i].Value = int64(0) + case ParameterType_Float: + parameters[i].Value = float64(0) + case ParameterType_Numeric: + //TODO: properly handle decimal types + parameters[i].Value = float64(0) + case ParameterType_String: + parameters[i].Value = "" + case ParameterType_Timestamp: + parameters[i].Value = time.Time{} + default: + return nil, fmt.Errorf("invalid `%s` cast to `%s`", parameters[i].OriginalType.String(), t.String()) + } + case ParameterType_Integer: + switch t { + case ParameterType_Float: + parameters[i].Value = float64(parameters[i].Value.(int64)) + case ParameterType_Numeric: + //TODO: properly handle decimal types + parameters[i].Value = float64(parameters[i].Value.(int64)) + case ParameterType_String: + parameters[i].Value = strconv.FormatInt(parameters[i].Value.(int64), 10) + default: + return nil, fmt.Errorf("invalid `%s` cast to `%s`", parameters[i].OriginalType.String(), t.String()) + } + case ParameterType_Float: + switch t { + case ParameterType_Numeric: + //TODO: properly handle decimal types, this is a redundant assignment but serves as a reminder + parameters[i].Value = parameters[i].Value.(float64) + case ParameterType_String: + parameters[i].Value = strconv.FormatFloat(parameters[i].Value.(float64), 'f', -1, 64) + default: + return nil, fmt.Errorf("invalid `%s` cast to `%s`", parameters[i].OriginalType.String(), t.String()) + } + case ParameterType_Numeric: + switch t { + case ParameterType_String: + //TODO: properly handle decimal types + parameters[i].Value = strconv.FormatFloat(parameters[i].Value.(float64), 'f', -1, 64) + default: + return nil, fmt.Errorf("invalid `%s` cast to `%s`", parameters[i].OriginalType.String(), t.String()) + } + case ParameterType_String: + switch t { + case ParameterType_Integer: + parameters[i].Value, err = strconv.ParseInt(parameters[i].Value.(string), 10, 64) + if err != nil { + return nil, fmt.Errorf("cannot cast `%s` to type `%s`", parameters[i].Value.(string), t.String()) + } + // It looks like string constants are treated as native integer types, so we'll mimic this here + if parameters[i].Source == Source_Constant { + parameters[i].OriginalType = ParameterType_Integer + } + case ParameterType_Float: + parameters[i].Value, err = strconv.ParseFloat(parameters[i].Value.(string), 64) + if err != nil { + return nil, fmt.Errorf("cannot cast `%s` to type `%s`", parameters[i].Value.(string), t.String()) + } + // It looks like string constants are treated as native float types, so we'll mimic this here + if parameters[i].Source == Source_Constant { + parameters[i].OriginalType = ParameterType_Float + } + case ParameterType_Numeric: + //TODO: properly handle decimal types + parameters[i].Value, err = strconv.ParseFloat(parameters[i].Value.(string), 64) + if err != nil { + return nil, fmt.Errorf("cannot cast `%s` to type `%s`", parameters[i].Value.(string), t.String()) + } + // It looks like string constants are treated as native numeric types, so we'll mimic this here + if parameters[i].Source == Source_Constant { + parameters[i].OriginalType = ParameterType_Numeric + } + case ParameterType_Timestamp: + //TODO: properly handle timestamps + parameters[i].Value, _, err = types.Datetime.Convert(parameters[i].Value) + if err != nil { + return nil, fmt.Errorf("cannot cast `%s` to type `%s`", parameters[i].Value.(string), t.String()) + } + default: + return nil, fmt.Errorf("invalid `%s` cast to `%s`", parameters[i].OriginalType.String(), t.String()) + } + case ParameterType_Timestamp: + return nil, fmt.Errorf("invalid `%s` cast to `%s`", parameters[i].OriginalType.String(), t.String()) + default: + return nil, fmt.Errorf("unhandled parameter type in %T::Resolve", overload) + } + } + return resultOverload, nil +} + +// ResolveByType returns the best matching overload for the given types. The returned types represent the actual types +// used by the overload, which may differ from the calling types. It is up to the caller to cast the parameters to match +// the types expected by the returned overload. Returns a nil OverloadDeduction if a viable match is not found. +func (overload *OverloadDeduction) ResolveByType(originalTypes []ParameterType) (*OverloadDeduction, []ParameterType) { + resultTypes := make([]ParameterType, len(originalTypes)) + copy(resultTypes, originalTypes) + return overload.resolveByType(originalTypes, resultTypes), resultTypes +} + +// resolveByType is the recursive implementation of ResolveByType. +func (overload *OverloadDeduction) resolveByType(originalTypes []ParameterType, resultTypes []ParameterType) *OverloadDeduction { + if overload == nil { + return nil + } + if len(originalTypes) == 0 { + if overload.Function.IsValid() && !overload.Function.IsNil() { + return overload + } + return nil + } + + // Check if we're able to resolve the original type + t := originalTypes[0] + resultOverload := overload.Parameter[t].resolveByType(originalTypes[1:], resultTypes[1:]) + if resultOverload != nil { + resultTypes[0] = t + return resultOverload + } + + // We did not find a resolution for the original type, so we'll look through each cast + for _, cast := range t.PotentialCasts() { + resultOverload = overload.Parameter[cast].resolveByType(originalTypes[1:], resultTypes[1:]) + if resultOverload != nil { + resultTypes[0] = cast + return resultOverload + } + } + // We did not find any potential matches, so we'll return nil + return nil +} diff --git a/server/functions/zinternal_parameters.go b/server/functions/zinternal_parameters.go new file mode 100644 index 0000000000..2929d8fbed --- /dev/null +++ b/server/functions/zinternal_parameters.go @@ -0,0 +1,192 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package functions + +import ( + "fmt" + "reflect" + "time" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" +) + +// ParameterType represents the type of a parameter. +type ParameterType uint8 + +const ( + ParameterType_Null ParameterType = iota // The parameter is a NULL value, and is therefore typeless + ParameterType_Integer // The parameter is an IntegerType type + ParameterType_Float // The parameter is a FloatType type + ParameterType_Numeric // The parameter is a NumericType type + ParameterType_String // The parameter is a StringType type + ParameterType_Timestamp // The parameter is a TimestampType type + + ParameterType_Length // The number of parameters. This should always be last in the enum declaration. +) + +// ptCasts contains an array of all potential casts for each parameter type +var ptCasts [ParameterType_Length][]ParameterType + +func init() { + ptCasts[ParameterType_Null] = []ParameterType{ParameterType_Integer, ParameterType_Float, ParameterType_Numeric, ParameterType_String, ParameterType_Timestamp} + ptCasts[ParameterType_Integer] = []ParameterType{ParameterType_Float, ParameterType_Numeric, ParameterType_String} + ptCasts[ParameterType_Float] = []ParameterType{ParameterType_Numeric, ParameterType_String} + ptCasts[ParameterType_Numeric] = []ParameterType{ParameterType_String} + ptCasts[ParameterType_String] = []ParameterType{ParameterType_Integer, ParameterType_Float, ParameterType_Numeric, ParameterType_Timestamp} + ptCasts[ParameterType_Timestamp] = []ParameterType{} +} + +// PotentialCasts returns all potential casts for the current type. For example, an IntegerType may be cast to a FloatType. +// Casts may be bidirectional, as a StringType may cast to an IntegerType, and an IntegerType may cast to a StringType. +func (pt ParameterType) PotentialCasts() []ParameterType { + return ptCasts[pt] +} + +// PotentialCasts returns all potential casts for the current type. For example, an IntegerType may be cast to a FloatType. +// Casts may be bidirectional, as a StringType may cast to an IntegerType, and an IntegerType may cast to a StringType. +func (pt ParameterType) String() string { + switch pt { + case ParameterType_Null: + return "null" + case ParameterType_Integer: + return "integer" + case ParameterType_Float: + return "double precision" + case ParameterType_Numeric: + return "numeric" + case ParameterType_String: + return "character varying" + case ParameterType_Timestamp: + return "timestamp" + default: + panic(fmt.Errorf("unhandled type in ParameterType::String (%d)", int(pt))) + } +} + +// ParameterTypeFromReflection returns the ParameterType and equivalent sql.Type from the given reflection type. If the +// given type does not match a ParameterType, then this returns false. +func ParameterTypeFromReflection(t reflect.Type) (ParameterType, sql.Type, bool) { + switch t { + case reflect.TypeOf(IntegerType{}): + return ParameterType_Integer, types.Int64, true + case reflect.TypeOf(FloatType{}): + return ParameterType_Float, types.Float64, true + case reflect.TypeOf(NumericType{}): + //TODO: properly handle decimal types + return ParameterType_Numeric, types.Float64, true + case reflect.TypeOf(StringType{}): + return ParameterType_String, types.LongText, true + case reflect.TypeOf(TimestampType{}): + return ParameterType_Timestamp, types.Datetime, true + default: + return ParameterType_Null, types.Null, false + } +} + +// IntermediateParameter is a parameter before it has been finalized. +type IntermediateParameter struct { + Value interface{} + IsNull bool + OriginalType ParameterType + CurrentType ParameterType + Source Source +} + +// IntegerType is an integer type (all integer types are upcast to int64). +type IntegerType struct { + Value int64 + IsNull bool + OriginalType ParameterType + Source Source +} + +// FloatType is a floating point type (float32 is upcast to float64). +type FloatType struct { + Value float64 + IsNull bool + OriginalType ParameterType + Source Source +} + +// NumericType is a decimal type (all integer and float types are upcast to decimal). +type NumericType struct { + Value float64 //TODO: should be decimal, but our engine support isn't quite there yet + IsNull bool + OriginalType ParameterType + Source Source +} + +// StringType is a string type. +type StringType struct { + Value string + IsNull bool + OriginalType ParameterType + Source Source +} + +// TimestampType is a timestamp type. +type TimestampType struct { + Value time.Time + IsNull bool + OriginalType ParameterType + Source Source +} + +// ToValue converts the intermediate parameter into a concrete parameter type (IntegerType, FloatType, etc.) and returns +// it as a reflect.Value, which may be passed to the matched function. +func (ip IntermediateParameter) ToValue() reflect.Value { + switch ip.CurrentType { + case ParameterType_Null: + panic(fmt.Errorf("a NULL parameter type was not erased before the call to %T::ToValue", ip)) + case ParameterType_Integer: + return reflect.ValueOf(IntegerType{ + Value: ip.Value.(int64), + IsNull: ip.IsNull, + OriginalType: ip.OriginalType, + Source: ip.Source, + }) + case ParameterType_Float: + return reflect.ValueOf(FloatType{ + Value: ip.Value.(float64), + IsNull: ip.IsNull, + OriginalType: ip.OriginalType, + Source: ip.Source, + }) + case ParameterType_Numeric: + return reflect.ValueOf(NumericType{ + Value: ip.Value.(float64), + IsNull: ip.IsNull, + OriginalType: ip.OriginalType, + Source: ip.Source, + }) + case ParameterType_String: + return reflect.ValueOf(StringType{ + Value: ip.Value.(string), + IsNull: ip.IsNull, + OriginalType: ip.OriginalType, + Source: ip.Source, + }) + case ParameterType_Timestamp: + return reflect.ValueOf(TimestampType{ + Value: ip.Value.(time.Time), + IsNull: ip.IsNull, + OriginalType: ip.OriginalType, + Source: ip.Source, + }) + default: + panic(fmt.Errorf("unhandled type in %T::ToValue", ip)) + } +} diff --git a/server/functions/zinternal_source.go b/server/functions/zinternal_source.go new file mode 100644 index 0000000000..24780677ff --- /dev/null +++ b/server/functions/zinternal_source.go @@ -0,0 +1,25 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package functions + +// Source defines what kind of expression generated the given value, as some functions are context-dependent, and we +// need to approximate the context. +type Source uint8 + +const ( + Source_Expression Source = iota // The source is some expression. This may change as more sources are added. + Source_Constant // The source is a constant value + Source_Column // The source is a column +) diff --git a/server/listener.go b/server/listener.go index 8ca3a15671..8f2d96adcf 100644 --- a/server/listener.go +++ b/server/listener.go @@ -35,6 +35,7 @@ import ( "github.com/dolthub/doltgresql/postgres/messages" "github.com/dolthub/doltgresql/postgres/parser/parser" "github.com/dolthub/doltgresql/server/ast" + _ "github.com/dolthub/doltgresql/server/functions" ) var ( diff --git a/testing/go/functions_test.go b/testing/go/functions_test.go new file mode 100644 index 0000000000..03306f9a93 --- /dev/null +++ b/testing/go/functions_test.go @@ -0,0 +1,155 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package _go + +import ( + "testing" + + "github.com/dolthub/go-mysql-server/sql" +) + +// https://www.postgresql.org/docs/15/functions-math.html +func TestFunctionsMath(t *testing.T) { + RunScripts(t, []ScriptTest{ + { + Name: "cbrt", + SetUpScript: []string{ + `CREATE TABLE test (pk INT primary key, v1 INT, v2 FLOAT4, v3 FLOAT8, v4 VARCHAR(255));`, + `INSERT INTO test VALUES (1, -1, -2, -3, '-5'), (2, 7, 11, 13, '17'), (3, 19, -23, 29, '-31');`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: `SELECT cbrt(v1), cbrt(v2), cbrt(v3) FROM test ORDER BY pk;`, + Skip: true, // Our values are slightly different + Expected: []sql.Row{ + {-1.0, -1.259921049894873, -1.4422495703074083}, + {1.9129311827723892, 2.2239800905693157, 2.3513346877207573}, + {2.668401648721945, -2.8438669798515654, 3.0723168256858475}, + }, + }, + { + Query: `SELECT round(cbrt(v1)::numeric, 10), round(cbrt(v2)::numeric, 10), round(cbrt(v3)::numeric, 10) FROM test ORDER BY pk;`, + Expected: []sql.Row{ + {-1.0000000000, -1.2599210499, -1.4422495703}, + {1.9129311828, 2.2239800906, 2.3513346877}, + {2.6684016487, -2.8438669799, 3.0723168257}, + }, + }, + { + Query: `SELECT cbrt(v4) FROM test ORDER BY pk;`, + ExpectedErr: true, + }, + { + Query: `SELECT cbrt('64');`, + Expected: []sql.Row{ + {4.0}, + }, + }, + { + Query: `SELECT round(cbrt('64'));`, + Expected: []sql.Row{ + {4.0}, + }, + }, + }, + }, + { + Name: "gcd", + SetUpScript: []string{ + `CREATE TABLE test (pk INT primary key, v1 INT4, v2 INT8, v3 FLOAT8, v4 VARCHAR(255));`, + `INSERT INTO test VALUES (1, -2, -4, -6, '-8'), (2, 10, 12, 14.14, '16.16'), (3, 18, -20, 22.22, '-24.24');`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: `SELECT gcd(v1, 10), gcd(v2, 20) FROM test ORDER BY pk;`, + Expected: []sql.Row{ + {2, 4}, + {10, 4}, + {2, 20}, + }, + }, + { + Query: `SELECT gcd(v3, 10) FROM test ORDER BY pk;`, + ExpectedErr: true, + }, + { + Query: `SELECT gcd(v4, 10) FROM test ORDER BY pk;`, + ExpectedErr: true, + }, + { + Query: `SELECT gcd(36, '48');`, + Expected: []sql.Row{ + {12}, + }, + }, + { + Query: `SELECT gcd('36', 48);`, + Expected: []sql.Row{ + {12}, + }, + }, + { + Query: `SELECT gcd(1, 0), gcd(0, 1), gcd(0, 0);`, + Expected: []sql.Row{ + {1, 1, 0}, + }, + }, + }, + }, + { + Name: "lcm", + SetUpScript: []string{ + `CREATE TABLE test (pk INT primary key, v1 INT4, v2 INT8, v3 FLOAT8, v4 VARCHAR(255));`, + `INSERT INTO test VALUES (1, -2, -4, -6, '-8'), (2, 10, 12, 14.14, '16.16'), (3, 18, -20, 22.22, '-24.24');`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: `SELECT lcm(v1, 10), lcm(v2, 20) FROM test ORDER BY pk;`, + Expected: []sql.Row{ + {10, 20}, + {10, 60}, + {90, 20}, + }, + }, + { + Query: `SELECT lcm(v3, 10) FROM test ORDER BY pk;`, + ExpectedErr: true, + }, + { + Query: `SELECT lcm(v4, 10) FROM test ORDER BY pk;`, + ExpectedErr: true, + }, + { + Query: `SELECT lcm(36, '48');`, + Expected: []sql.Row{ + {144}, + }, + }, + { + Query: `SELECT lcm('36', 48);`, + Expected: []sql.Row{ + {144}, + }, + }, + { + Query: `SELECT lcm(1, 0), lcm(0, 1), lcm(0, 0);`, + Expected: []sql.Row{ + {0, 0, 0}, + }, + }, + }, + }, + }) +} diff --git a/utils/abs.go b/utils/abs.go new file mode 100644 index 0000000000..69a2f22ec7 --- /dev/null +++ b/utils/abs.go @@ -0,0 +1,27 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "golang.org/x/exp/constraints" +) + +// Abs returns the absolute value of the given number. +func Abs[T constraints.Integer | constraints.Float](val T) T { + if val < 0 { + return -val + } + return val +}