Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion server/ast/create_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,19 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme
// Grab the general information that we'll need to create the function
tableName := node.Name.ToTableName()
retType := pgtypes.Void
if len(node.RetType) == 1 {
if len(node.RetType) == 1 { // Return types may specify "trigger", but this doesn't apply elsewhere
switch typ := node.RetType[0].Type.(type) {
case *types.T:
retType = pgtypes.NewUnresolvedDoltgresType("", strings.ToLower(typ.Name()))
case *tree.UnresolvedObjectName:
if typ.NumParts == 1 && typ.SQLString() == "trigger" {
retType = pgtypes.Trigger
} else {
_, retType, err = nodeResolvableTypeReference(ctx, typ)
if err != nil {
return nil, err
}
}
default:
sqlString := strings.ToLower(typ.SQLString())
if sqlString == "trigger" {
Expand All @@ -59,6 +68,11 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme
switch argType := arg.Type.(type) {
case *types.T:
paramTypes[i] = pgtypes.NewUnresolvedDoltgresType("", strings.ToLower(argType.Name()))
case *tree.UnresolvedObjectName:
_, paramTypes[i], err = nodeResolvableTypeReference(ctx, argType)
if err != nil {
return nil, err
}
default:
paramTypes[i] = pgtypes.NewUnresolvedDoltgresType("", strings.ToLower(argType.SQLString()))
}
Expand Down
8 changes: 2 additions & 6 deletions server/expression/array.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,7 @@ func (array *Array) Eval(ctx *sql.Context, row sql.Row) (any, error) {
// We always cast the element, as there may be parameter restrictions in place
castFunc := framework.GetImplicitCast(doltgresType, resultTyp)
if castFunc == nil {
if doltgresType.ID == pgtypes.Unknown.ID {
castFunc = framework.UnknownLiteralCast
} else {
return nil, errors.Errorf("cannot find cast function from %s to %s", doltgresType.String(), resultTyp.String())
}
return nil, errors.Errorf("cannot find cast function from %s to %s", doltgresType.String(), resultTyp.String())
}

values[i], err = castFunc(ctx, val, resultTyp)
Expand Down Expand Up @@ -175,7 +171,7 @@ func (array *Array) getTargetType(children ...sql.Expression) (*pgtypes.Doltgres
childrenTypes = append(childrenTypes, childType)
}
}
targetType, err := framework.FindCommonType(childrenTypes)
targetType, _, err := framework.FindCommonType(childrenTypes)
if err != nil {
return nil, errors.Errorf("ARRAY %s", err.Error())
}
Expand Down
8 changes: 2 additions & 6 deletions server/expression/assignment_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,8 @@ func (ac *AssignmentCast) Eval(ctx *sql.Context, row sql.Row) (any, error) {
}
castFunc := framework.GetAssignmentCast(ac.fromType, ac.toType)
if castFunc == nil {
if ac.fromType.ID == pgtypes.Unknown.ID {
castFunc = framework.UnknownLiteralCast
} else {
return nil, errors.Errorf("ASSIGNMENT_CAST: target is of type %s but expression is of type %s: %s",
ac.toType.String(), ac.fromType.String(), ac.expr.String())
}
return nil, errors.Errorf("ASSIGNMENT_CAST: target is of type %s but expression is of type %s: %s",
ac.toType.String(), ac.fromType.String(), ac.expr.String())
}
return castFunc(ctx, val, ac.toType)
}
Expand Down
52 changes: 2 additions & 50 deletions server/expression/explicit_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"github.com/dolthub/go-mysql-server/sql/expression"
vitess "github.com/dolthub/vitess/go/vt/sqlparser"

"github.com/dolthub/doltgresql/core"
"github.com/dolthub/doltgresql/server/functions/framework"
pgtypes "github.com/dolthub/doltgresql/server/types"
)
Expand Down Expand Up @@ -97,55 +96,8 @@ func (c *ExplicitCast) Eval(ctx *sql.Context, row sql.Row) (any, error) {
baseCastToType := checkForDomainType(c.castToType)
castFunction := framework.GetExplicitCast(fromType, baseCastToType)
if castFunction == nil {
if fromType.ID == pgtypes.Unknown.ID {
castFunction = framework.UnknownLiteralCast
} else if fromType.IsRecordType() && c.castToType.IsCompositeType() { // TODO: should this only be in explicit, or assignment and implicit too?
// Casting to a record type will always work for any composite type.
// TODO: is the above statement true for all cases?
// When casting to a composite type, then we must match the arity and have valid casts for every position.
if c.castToType.IsRecordType() {
castFunction = framework.IdentityCast
} else {
castFunction = func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) {
vals, ok := val.([]pgtypes.RecordValue)
if !ok {
// TODO: better error message
return nil, errors.New("casting input error from record type")
}
if len(targetType.CompositeAttrs) != len(vals) {
return nil, errors.Newf("cannot cast type %s to %s", "", targetType.Name())
}
typeCollection, err := core.GetTypesCollectionFromContext(ctx)
if err != nil {
return nil, err
}
outputVals := make([]pgtypes.RecordValue, len(vals))
for i := range vals {
valType, ok := vals[i].Type.(*pgtypes.DoltgresType)
if !ok {
// TODO: if this is a GMS type, then we should cast to a Doltgres type here
return nil, errors.New("cannot cast record containing GMS type")
}
outputVals[i].Type, err = typeCollection.GetType(ctx, targetType.CompositeAttrs[i].TypeID)
if err != nil {
return nil, err
}
innerExplicit := ExplicitCast{
sqlChild: NewUnsafeLiteral(vals[i].Value, valType),
castToType: outputVals[i].Type.(*pgtypes.DoltgresType),
}
outputVals[i].Value, err = innerExplicit.Eval(ctx, nil)
if err != nil {
return nil, err
}
}
return outputVals, nil
}
}
} else {
return nil, errors.Errorf("EXPLICIT CAST: cast from `%s` to `%s` does not exist: %s",
fromType.String(), c.castToType.String(), c.sqlChild.String())
}
return nil, errors.Errorf("EXPLICIT CAST: cast from `%s` to `%s` does not exist: %s",
fromType.String(), c.castToType.String(), c.sqlChild.String())
}
castResult, err := castFunction(ctx, val, c.castToType)
if err != nil {
Expand Down
77 changes: 77 additions & 0 deletions server/functions/framework/cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/cockroachdb/errors"
"github.com/dolthub/go-mysql-server/sql"

"github.com/dolthub/doltgresql/core"
"github.com/dolthub/doltgresql/core/id"
pgtypes "github.com/dolthub/doltgresql/server/types"
)
Expand Down Expand Up @@ -134,6 +135,9 @@ func GetExplicitCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresTyp
if cast := getSizingOrIdentityCast(fromType, toType, true); cast != nil {
return cast
}
if recordCast := getRecordCast(fromType, toType, GetExplicitCast); recordCast != nil {
return recordCast
}
// All types have a built-in explicit cast from string types: https://www.postgresql.org/docs/15/sql-createcast.html
if fromType.TypCategory == pgtypes.TypeCategory_StringTypes {
return func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) {
Expand All @@ -159,6 +163,10 @@ func GetExplicitCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresTyp
return targetType.IoInput(ctx, str)
}
}
// It is always valid to convert from the `unknown` type
if fromType.ID == pgtypes.Unknown.ID {
return UnknownLiteralCast
}
return nil
}

Expand All @@ -174,6 +182,10 @@ func GetAssignmentCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresT
if cast := getSizingOrIdentityCast(fromType, toType, false); cast != nil {
return cast
}
// We then check for a record to composite cast
if recordCast := getRecordCast(fromType, toType, GetAssignmentCast); recordCast != nil {
return recordCast
}
// All types have a built-in assignment cast to string types: https://www.postgresql.org/docs/15/sql-createcast.html
if toType.TypCategory == pgtypes.TypeCategory_StringTypes {
return func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) {
Expand All @@ -187,6 +199,10 @@ func GetAssignmentCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresT
return targetType.IoInput(ctx, str)
}
}
// It is always valid to convert from the `unknown` type
if fromType.ID == pgtypes.Unknown.ID {
return UnknownLiteralCast
}
return nil
}

Expand All @@ -200,6 +216,14 @@ func GetImplicitCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresTyp
if cast := getSizingOrIdentityCast(fromType, toType, false); cast != nil {
return cast
}
// We then check for a record to composite cast
if recordCast := getRecordCast(fromType, toType, GetImplicitCast); recordCast != nil {
return recordCast
}
// It is always valid to convert from the `unknown` type
if fromType.ID == pgtypes.Unknown.ID {
return UnknownLiteralCast
}
return nil
}

Expand Down Expand Up @@ -312,6 +336,59 @@ func getSizingOrIdentityCast(fromType *pgtypes.DoltgresType, toType *pgtypes.Dol
return IdentityCast
}

// getRecordCast handles casting from a record type to a composite type (if applicable). Returns nil if not applicable.
func getRecordCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresType, passthrough func(*pgtypes.DoltgresType, *pgtypes.DoltgresType) pgtypes.TypeCastFunction) pgtypes.TypeCastFunction {
// TODO: does casting to a record type always work for any composite type?
// https://www.postgresql.org/docs/15/sql-expressions.html#SQL-SYNTAX-ROW-CONSTRUCTORS seems to suggest so
// Also not sure if we should use the passthrough, or if we always default to implicit, assignment, or explicit
if fromType.IsRecordType() && toType.IsCompositeType() {
// When casting to a composite type, then we must match the arity and have valid casts for every position.
if toType.IsRecordType() {
return IdentityCast
} else {
return func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) {
vals, ok := val.([]pgtypes.RecordValue)
if !ok {
return nil, errors.New("casting input error from record type")
}
if len(targetType.CompositeAttrs) != len(vals) {
// TODO: these should go in DETAIL depending on the size
// Input has too few columns.
// Input has too many columns.
return nil, errors.Newf("cannot cast type %s to %s", fromType.Name(), targetType.Name())
}
typeCollection, err := core.GetTypesCollectionFromContext(ctx)
if err != nil {
return nil, err
}
outputVals := make([]pgtypes.RecordValue, len(vals))
for i := range vals {
valType, ok := vals[i].Type.(*pgtypes.DoltgresType)
if !ok {
return nil, errors.New("cannot cast record containing GMS type")
}
outputType, err := typeCollection.GetType(ctx, targetType.CompositeAttrs[i].TypeID)
if err != nil {
return nil, err
}
outputVals[i].Type = outputType
positionCast := passthrough(valType, outputType)
if positionCast == nil {
// TODO: this should be the DETAIL, with the actual error being "cannot cast type <FROM_TYPE> to <TO_TYPE>"
return nil, errors.Newf("Cannot cast type %s to %s in column %d", valType.Name(), outputType.Name(), i+1)
}
outputVals[i].Value, err = positionCast(ctx, vals[i].Value, outputType)
if err != nil {
return nil, err
}
}
return outputVals, nil
}
}
}
return nil
}

// IdentityCast returns the input value.
func IdentityCast(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) {
return val, nil
Expand Down
50 changes: 30 additions & 20 deletions server/functions/framework/common_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,52 +20,62 @@ import (
pgtypes "github.com/dolthub/doltgresql/server/types"
)

// FindCommonType returns the common type that given types can convert to.
// FindCommonType returns the common type that given types can convert to. Returns false if no implicit casts are needed
// to resolve the given types as the returned common type.
// https://www.postgresql.org/docs/15/typeconv-union-case.html
func FindCommonType(types []*pgtypes.DoltgresType) (*pgtypes.DoltgresType, error) {
var candidateType = pgtypes.Unknown
var fail = false
func FindCommonType(types []*pgtypes.DoltgresType) (_ *pgtypes.DoltgresType, requiresCasts bool, err error) {
candidateType := pgtypes.Unknown
differentTypes := false
for _, typ := range types {
if typ.ID == candidateType.ID {
continue
} else if candidateType.ID == pgtypes.Unknown.ID {
candidateType = typ
} else {
candidateType = pgtypes.Unknown
fail = true
differentTypes = true
}
}
if !fail {
if !differentTypes {
if candidateType.ID == pgtypes.Unknown.ID {
return pgtypes.Text, nil
// We require implicit casts from `unknown` to `text`
return pgtypes.Text, true, nil
}
return candidateType, nil
return candidateType, false, nil
}
// We have different types if we've made it this far, so we're guaranteed to require implicit casts
requiresCasts = true
for _, typ := range types {
if candidateType.ID == pgtypes.Unknown.ID {
candidateType = typ
}
if typ.ID != pgtypes.Unknown.ID && candidateType.TypCategory != typ.TypCategory {
return nil, errors.Errorf("types %s and %s cannot be matched", candidateType.String(), typ.String())
return nil, false, errors.Errorf("types %s and %s cannot be matched", candidateType.String(), typ.String())
}
}

var preferredTypeFound = false
// Attempt to find the most general type (or the preferred type in the type category)
for _, typ := range types {
if typ.ID == pgtypes.Unknown.ID {
if typ.ID == pgtypes.Unknown.ID || typ.ID == candidateType.ID {
continue
} else if GetImplicitCast(typ, candidateType) != nil {
// typ can convert to the candidate type, so the candidate type is at least as general
continue
} else if GetImplicitCast(candidateType, typ) == nil {
return nil, errors.Errorf("cannot find implicit cast function from %s to %s", candidateType.String(), typ.String())
} else if !preferredTypeFound {
} else if GetImplicitCast(candidateType, typ) != nil {
// the candidate type can convert to typ, but not vice versa, so typ is likely more general
candidateType = typ
if candidateType.IsPreferred {
candidateType = typ
preferredTypeFound = true
// We stop considering more types once we've found a preferred type
break
}
} else {
return nil, errors.Errorf("found another preferred candidate type")
}
}
return candidateType, nil
// Verify that all types have an implicit conversion to the candidate type
for _, typ := range types {
if typ.ID == pgtypes.Unknown.ID || typ.ID == candidateType.ID {
continue
} else if GetImplicitCast(typ, candidateType) == nil {
return nil, false, errors.Errorf("cannot find implicit cast function from %s to %s", candidateType.String(), typ.String())
}
}
return candidateType, requiresCasts, nil
}
8 changes: 2 additions & 6 deletions server/functions/framework/compiled_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -582,12 +582,8 @@ func (c *CompiledFunction) typeCompatibleOverloads(fnOverloads []Overload, argTy
polymorphicTargets = append(polymorphicTargets, argTypes[i])
} else {
if overloadCasts[i] = GetImplicitCast(argTypes[i], paramType); overloadCasts[i] == nil {
if argTypes[i].ID == pgtypes.Unknown.ID {
overloadCasts[i] = UnknownLiteralCast
} else {
isConvertible = false
break
}
isConvertible = false
break
}
}
}
Expand Down
Loading