diff --git a/core/typecollection/typecollection.go b/core/typecollection/typecollection.go index e5bdc899ac..c185bf30f8 100644 --- a/core/typecollection/typecollection.go +++ b/core/typecollection/typecollection.go @@ -32,6 +32,13 @@ import ( pgtypes "github.com/dolthub/doltgresql/server/types" ) +// anonymousCompositePrefix is the prefix for anonymous composite type names. These types are not stored on +// disk, but instead are created dynamically as needed. +const anonymousCompositePrefix = "table(" + +// anonymousCompositeSuffix is the suffix for anonymous composite type names. +const anonymousCompositeSuffix = ")" + // TypeCollection is a collection of all types (both built-in and user defined). type TypeCollection struct { accessedMap map[id.Type]*pgtypes.DoltgresType @@ -165,6 +172,11 @@ func (pgs *TypeCollection) GetType(ctx context.Context, name id.Type) (*pgtypes. return nil, err } if h.IsEmpty() { + // If this is an anonymous composite type, create it dynamically + if isAnonymousCompositeType(name) { + return createAnonymousCompositeType(ctx, name) + } + // If it's not a built-in type or created type, then check if it's a composite table row type sqlCtx, ok := ctx.(*sql.Context) if !ok { @@ -189,6 +201,34 @@ func (pgs *TypeCollection) GetType(ctx context.Context, name id.Type) (*pgtypes. return pgt, nil } +// isAnonymousCompositeType return true if |returnType| represents an anonymous composite return type +// for a function (i.e. the function was declared as "RETURNS TABLE(...)"). +func isAnonymousCompositeType(returnType id.Type) bool { + typeName := returnType.TypeName() + return strings.HasPrefix(typeName, anonymousCompositePrefix) && + strings.HasSuffix(typeName, anonymousCompositeSuffix) +} + +// createAnonymousCompositeType creates a new DoltgresType for the anonymous composite return type for a function, +// as represented by |returnType|. +func createAnonymousCompositeType(ctx context.Context, returnType id.Type) (*pgtypes.DoltgresType, error) { + typeName := returnType.TypeName() + attributeTypes := typeName[len(anonymousCompositePrefix) : len(typeName)-len(anonymousCompositeSuffix)] + attributeTypesSlice := strings.Split(attributeTypes, ",") + + attrs := make([]pgtypes.CompositeAttribute, len(attributeTypesSlice)) + for i, attributeNameAndType := range attributeTypesSlice { + split := strings.Split(attributeNameAndType, ":") + if len(split) != 2 { + return nil, fmt.Errorf("unexpected anonymous composite type attribute syntax: %s", attributeNameAndType) + } + + typeId := id.NewType("", split[1]) + attrs[i] = pgtypes.NewCompositeAttribute(nil, id.Null, split[0], typeId, int16(i), "") + } + return pgtypes.NewCompositeType(ctx, id.Null, id.NullType, returnType, attrs), nil +} + // HasType checks if a type exists with given schema and type name. func (pgs *TypeCollection) HasType(ctx context.Context, name id.Type) bool { // We can check the built-in types first diff --git a/postgres/parser/parser/sql.y b/postgres/parser/parser/sql.y index f7f6cc8fe2..9698f94145 100644 --- a/postgres/parser/parser/sql.y +++ b/postgres/parser/parser/sql.y @@ -4313,11 +4313,11 @@ create_function_stmt: } | CREATE FUNCTION routine_name opt_routine_arg_with_default_list RETURNS SETOF typename create_function_option_list { - $$.val = &tree.CreateFunction{Name: $3.unresolvedObjectName(), Args: $4.routineArgs(), SetOf: true, RetType: []tree.SimpleColumnDef{tree.SimpleColumnDef{Type: $7.typeReference()}}, Options: $8.routineOptions()} + $$.val = &tree.CreateFunction{Name: $3.unresolvedObjectName(), Args: $4.routineArgs(), ReturnsSetOf: true, RetType: []tree.SimpleColumnDef{tree.SimpleColumnDef{Type: $7.typeReference()}}, Options: $8.routineOptions()} } | CREATE FUNCTION routine_name opt_routine_arg_with_default_list RETURNS TABLE '(' opt_returns_table_col_def_list ')' create_function_option_list { - $$.val = &tree.CreateFunction{Name: $3.unresolvedObjectName(), Args: $4.routineArgs(), RetType: $8.simpleColumnDefs(), Options: $10.routineOptions()} + $$.val = &tree.CreateFunction{Name: $3.unresolvedObjectName(), Args: $4.routineArgs(), ReturnsTable: true, RetType: $8.simpleColumnDefs(), Options: $10.routineOptions()} } | CREATE OR REPLACE FUNCTION routine_name opt_routine_arg_with_default_list create_function_option_list { @@ -4329,11 +4329,11 @@ create_function_stmt: } | CREATE OR REPLACE FUNCTION routine_name opt_routine_arg_with_default_list RETURNS SETOF typename create_function_option_list { - $$.val = &tree.CreateFunction{Name: $5.unresolvedObjectName(), Replace: true, Args: $6.routineArgs(), SetOf: true, RetType: []tree.SimpleColumnDef{tree.SimpleColumnDef{Type: $9.typeReference()}}, Options: $10.routineOptions()} + $$.val = &tree.CreateFunction{Name: $5.unresolvedObjectName(), Replace: true, Args: $6.routineArgs(), ReturnsSetOf: true, RetType: []tree.SimpleColumnDef{tree.SimpleColumnDef{Type: $9.typeReference()}}, Options: $10.routineOptions()} } | CREATE OR REPLACE FUNCTION routine_name opt_routine_arg_with_default_list RETURNS TABLE '(' opt_returns_table_col_def_list ')' create_function_option_list { - $$.val = &tree.CreateFunction{Name: $5.unresolvedObjectName(), Replace: true, Args: $6.routineArgs(), RetType: $10.simpleColumnDefs(), Options: $12.routineOptions()} + $$.val = &tree.CreateFunction{Name: $5.unresolvedObjectName(), Replace: true, Args: $6.routineArgs(), ReturnsTable: true, RetType: $10.simpleColumnDefs(), Options: $12.routineOptions()} } opt_returns_table_col_def_list: diff --git a/postgres/parser/sem/tree/create_function.go b/postgres/parser/sem/tree/create_function.go index 5742fcab52..9b9a1f1d31 100644 --- a/postgres/parser/sem/tree/create_function.go +++ b/postgres/parser/sem/tree/create_function.go @@ -25,12 +25,13 @@ var _ Statement = &CreateFunction{} // CreateFunction represents a CREATE FUNCTION statement. type CreateFunction struct { - Name *UnresolvedObjectName - Replace bool - Args RoutineArgs - SetOf bool - RetType []SimpleColumnDef - Options []RoutineOption + Name *UnresolvedObjectName + Replace bool + Args RoutineArgs + ReturnsSetOf bool + ReturnsTable bool + RetType []SimpleColumnDef + Options []RoutineOption } // Format implements the NodeFormatter interface. @@ -47,9 +48,9 @@ func (node *CreateFunction) Format(ctx *FmtCtx) { ctx.WriteString(" )") } if node.RetType != nil { - if len(node.RetType) == 1 && node.RetType[0].Name == "" { + if !node.ReturnsTable { ctx.WriteString("RETURNS ") - if node.SetOf { + if node.ReturnsSetOf { ctx.WriteString("SETOF ") } ctx.WriteString(node.RetType[0].Type.SQLString()) diff --git a/server/ast/create_function.go b/server/ast/create_function.go index 5dc4069d1e..2839cf2443 100644 --- a/server/ast/create_function.go +++ b/server/ast/create_function.go @@ -15,12 +15,14 @@ package ast import ( + "context" "fmt" "strings" "github.com/cockroachdb/errors" vitess "github.com/dolthub/vitess/go/vt/sqlparser" + "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/postgres/parser/parser" "github.com/dolthub/doltgresql/postgres/parser/sem/tree" "github.com/dolthub/doltgresql/postgres/parser/types" @@ -38,20 +40,29 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme } // Grab the general information that we'll need to create the function tableName := node.Name.ToTableName() - retType := pgtypes.Void - if len(node.RetType) == 1 { + var retType *pgtypes.DoltgresType + if len(node.RetType) == 0 { + retType = pgtypes.Void + } else if !node.ReturnsTable { // Return types may specify "trigger", but this doesn't apply elsewhere switch typ := node.RetType[0].Type.(type) { case *types.T: retType = pgtypes.NewUnresolvedDoltgresType("", strings.ToLower(typ.Name())) - default: - sqlString := strings.ToLower(typ.SQLString()) - if sqlString == "trigger" { + case *tree.UnresolvedObjectName: + if typ.NumParts == 1 && typ.SQLString() == "trigger" { retType = pgtypes.Trigger } else { - retType = pgtypes.NewUnresolvedDoltgresType("", sqlString) + _, retType, err = nodeResolvableTypeReference(ctx, typ) + if err != nil { + return nil, err + } } + default: + return nil, fmt.Errorf("unsupported ResolvableTypeReference type: %T", typ) } + } else { + retType = createAnonymousCompositeType(node.RetType) } + paramNames := make([]string, len(node.Args)) paramTypes := make([]*pgtypes.DoltgresType, len(node.Args)) for i, arg := range node.Args { @@ -59,6 +70,11 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme switch argType := arg.Type.(type) { case *types.T: paramTypes[i] = pgtypes.NewUnresolvedDoltgresType("", strings.ToLower(argType.Name())) + case *tree.UnresolvedObjectName: + _, paramTypes[i], err = nodeResolvableTypeReference(ctx, argType) + if err != nil { + return nil, err + } default: paramTypes[i] = pgtypes.NewUnresolvedDoltgresType("", strings.ToLower(argType.SQLString())) } @@ -121,11 +137,38 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme parsedBody, sqlDef, sqlDefParsed, - node.SetOf, + node.ReturnsSetOf, ), }, nil } +// createAnonymousCompositeType creates a new DoltgresType for the anonymous composite return +// type for a function, as represented by the |fieldTypes| that were specified in the function +// definition. +func createAnonymousCompositeType(fieldTypes []tree.SimpleColumnDef) *pgtypes.DoltgresType { + attrs := make([]pgtypes.CompositeAttribute, len(fieldTypes)) + for i, fieldType := range fieldTypes { + attrs[i] = pgtypes.NewCompositeAttribute(nil, id.Null, fieldType.Name.String(), + id.NewType("", fieldType.Type.SQLString()), int16(i), "") + } + + typeIdString := "table(" + for i, attr := range attrs { + if i > 0 { + typeIdString += "," + } + typeIdString += attr.Name + typeIdString += ":" + typeIdString += attr.TypeID.TypeName() + } + typeIdString += ")" + + // NOTE: there is no schema needed, since these types are anonymous and can't be directly referenced + typeId := id.NewType("", typeIdString) + + return pgtypes.NewCompositeType(context.Background(), id.Null, id.NullType, typeId, attrs) +} + // handleLanguageSQL handles parsing SQL definition strings in both CREATE FUNCTION and CREATE PROCEDURE. func handleLanguageSQL(definition string, paramNames []string, paramTypes []*pgtypes.DoltgresType) (string, vitess.Statement, error) { stmt, err := parser.ParseOne(definition) diff --git a/server/expression/array.go b/server/expression/array.go index ea53d68f4e..0a91fe70bc 100644 --- a/server/expression/array.go +++ b/server/expression/array.go @@ -82,11 +82,7 @@ func (array *Array) Eval(ctx *sql.Context, row sql.Row) (any, error) { // We always cast the element, as there may be parameter restrictions in place castFunc := framework.GetImplicitCast(doltgresType, resultTyp) if castFunc == nil { - if doltgresType.ID == pgtypes.Unknown.ID { - castFunc = framework.UnknownLiteralCast - } else { - return nil, errors.Errorf("cannot find cast function from %s to %s", doltgresType.String(), resultTyp.String()) - } + return nil, errors.Errorf("cannot find cast function from %s to %s", doltgresType.String(), resultTyp.String()) } values[i], err = castFunc(ctx, val, resultTyp) @@ -175,7 +171,7 @@ func (array *Array) getTargetType(children ...sql.Expression) (*pgtypes.Doltgres childrenTypes = append(childrenTypes, childType) } } - targetType, err := framework.FindCommonType(childrenTypes) + targetType, _, err := framework.FindCommonType(childrenTypes) if err != nil { return nil, errors.Errorf("ARRAY %s", err.Error()) } diff --git a/server/expression/assignment_cast.go b/server/expression/assignment_cast.go index 832316f88a..e897192f43 100644 --- a/server/expression/assignment_cast.go +++ b/server/expression/assignment_cast.go @@ -56,12 +56,8 @@ func (ac *AssignmentCast) Eval(ctx *sql.Context, row sql.Row) (any, error) { } castFunc := framework.GetAssignmentCast(ac.fromType, ac.toType) if castFunc == nil { - if ac.fromType.ID == pgtypes.Unknown.ID { - castFunc = framework.UnknownLiteralCast - } else { - return nil, errors.Errorf("ASSIGNMENT_CAST: target is of type %s but expression is of type %s: %s", - ac.toType.String(), ac.fromType.String(), ac.expr.String()) - } + return nil, errors.Errorf("ASSIGNMENT_CAST: target is of type %s but expression is of type %s: %s", + ac.toType.String(), ac.fromType.String(), ac.expr.String()) } return castFunc(ctx, val, ac.toType) } diff --git a/server/expression/explicit_cast.go b/server/expression/explicit_cast.go index 9b4ac18502..36186cf18a 100644 --- a/server/expression/explicit_cast.go +++ b/server/expression/explicit_cast.go @@ -23,7 +23,6 @@ import ( "github.com/dolthub/go-mysql-server/sql/expression" vitess "github.com/dolthub/vitess/go/vt/sqlparser" - "github.com/dolthub/doltgresql/core" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -97,55 +96,8 @@ func (c *ExplicitCast) Eval(ctx *sql.Context, row sql.Row) (any, error) { baseCastToType := checkForDomainType(c.castToType) castFunction := framework.GetExplicitCast(fromType, baseCastToType) if castFunction == nil { - if fromType.ID == pgtypes.Unknown.ID { - castFunction = framework.UnknownLiteralCast - } else if fromType.IsRecordType() && c.castToType.IsCompositeType() { // TODO: should this only be in explicit, or assignment and implicit too? - // Casting to a record type will always work for any composite type. - // TODO: is the above statement true for all cases? - // When casting to a composite type, then we must match the arity and have valid casts for every position. - if c.castToType.IsRecordType() { - castFunction = framework.IdentityCast - } else { - castFunction = func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - vals, ok := val.([]pgtypes.RecordValue) - if !ok { - // TODO: better error message - return nil, errors.New("casting input error from record type") - } - if len(targetType.CompositeAttrs) != len(vals) { - return nil, errors.Newf("cannot cast type %s to %s", "", targetType.Name()) - } - typeCollection, err := core.GetTypesCollectionFromContext(ctx) - if err != nil { - return nil, err - } - outputVals := make([]pgtypes.RecordValue, len(vals)) - for i := range vals { - valType, ok := vals[i].Type.(*pgtypes.DoltgresType) - if !ok { - // TODO: if this is a GMS type, then we should cast to a Doltgres type here - return nil, errors.New("cannot cast record containing GMS type") - } - outputVals[i].Type, err = typeCollection.GetType(ctx, targetType.CompositeAttrs[i].TypeID) - if err != nil { - return nil, err - } - innerExplicit := ExplicitCast{ - sqlChild: NewUnsafeLiteral(vals[i].Value, valType), - castToType: outputVals[i].Type.(*pgtypes.DoltgresType), - } - outputVals[i].Value, err = innerExplicit.Eval(ctx, nil) - if err != nil { - return nil, err - } - } - return outputVals, nil - } - } - } else { - return nil, errors.Errorf("EXPLICIT CAST: cast from `%s` to `%s` does not exist: %s", - fromType.String(), c.castToType.String(), c.sqlChild.String()) - } + return nil, errors.Errorf("EXPLICIT CAST: cast from `%s` to `%s` does not exist: %s", + fromType.String(), c.castToType.String(), c.sqlChild.String()) } castResult, err := castFunction(ctx, val, c.castToType) if err != nil { diff --git a/server/functions/framework/cast.go b/server/functions/framework/cast.go index 2cf13cd1b5..1150ab931f 100644 --- a/server/functions/framework/cast.go +++ b/server/functions/framework/cast.go @@ -20,6 +20,7 @@ import ( "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/core" "github.com/dolthub/doltgresql/core/id" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -134,6 +135,9 @@ func GetExplicitCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresTyp if cast := getSizingOrIdentityCast(fromType, toType, true); cast != nil { return cast } + if recordCast := getRecordCast(fromType, toType, GetExplicitCast); recordCast != nil { + return recordCast + } // All types have a built-in explicit cast from string types: https://www.postgresql.org/docs/15/sql-createcast.html if fromType.TypCategory == pgtypes.TypeCategory_StringTypes { return func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { @@ -159,6 +163,10 @@ func GetExplicitCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresTyp return targetType.IoInput(ctx, str) } } + // It is always valid to convert from the `unknown` type + if fromType.ID == pgtypes.Unknown.ID { + return UnknownLiteralCast + } return nil } @@ -174,6 +182,10 @@ func GetAssignmentCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresT if cast := getSizingOrIdentityCast(fromType, toType, false); cast != nil { return cast } + // We then check for a record to composite cast + if recordCast := getRecordCast(fromType, toType, GetAssignmentCast); recordCast != nil { + return recordCast + } // All types have a built-in assignment cast to string types: https://www.postgresql.org/docs/15/sql-createcast.html if toType.TypCategory == pgtypes.TypeCategory_StringTypes { return func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { @@ -187,6 +199,10 @@ func GetAssignmentCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresT return targetType.IoInput(ctx, str) } } + // It is always valid to convert from the `unknown` type + if fromType.ID == pgtypes.Unknown.ID { + return UnknownLiteralCast + } return nil } @@ -200,6 +216,14 @@ func GetImplicitCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresTyp if cast := getSizingOrIdentityCast(fromType, toType, false); cast != nil { return cast } + // We then check for a record to composite cast + if recordCast := getRecordCast(fromType, toType, GetImplicitCast); recordCast != nil { + return recordCast + } + // It is always valid to convert from the `unknown` type + if fromType.ID == pgtypes.Unknown.ID { + return UnknownLiteralCast + } return nil } @@ -312,6 +336,59 @@ func getSizingOrIdentityCast(fromType *pgtypes.DoltgresType, toType *pgtypes.Dol return IdentityCast } +// getRecordCast handles casting from a record type to a composite type (if applicable). Returns nil if not applicable. +func getRecordCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresType, passthrough func(*pgtypes.DoltgresType, *pgtypes.DoltgresType) pgtypes.TypeCastFunction) pgtypes.TypeCastFunction { + // TODO: does casting to a record type always work for any composite type? + // https://www.postgresql.org/docs/15/sql-expressions.html#SQL-SYNTAX-ROW-CONSTRUCTORS seems to suggest so + // Also not sure if we should use the passthrough, or if we always default to implicit, assignment, or explicit + if fromType.IsRecordType() && toType.IsCompositeType() { + // When casting to a composite type, then we must match the arity and have valid casts for every position. + if toType.IsRecordType() { + return IdentityCast + } else { + return func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + vals, ok := val.([]pgtypes.RecordValue) + if !ok { + return nil, errors.New("casting input error from record type") + } + if len(targetType.CompositeAttrs) != len(vals) { + // TODO: these should go in DETAIL depending on the size + // Input has too few columns. + // Input has too many columns. + return nil, errors.Newf("cannot cast type %s to %s", fromType.Name(), targetType.Name()) + } + typeCollection, err := core.GetTypesCollectionFromContext(ctx) + if err != nil { + return nil, err + } + outputVals := make([]pgtypes.RecordValue, len(vals)) + for i := range vals { + valType, ok := vals[i].Type.(*pgtypes.DoltgresType) + if !ok { + return nil, errors.New("cannot cast record containing GMS type") + } + outputType, err := typeCollection.GetType(ctx, targetType.CompositeAttrs[i].TypeID) + if err != nil { + return nil, err + } + outputVals[i].Type = outputType + positionCast := passthrough(valType, outputType) + if positionCast == nil { + // TODO: this should be the DETAIL, with the actual error being "cannot cast type to " + return nil, errors.Newf("Cannot cast type %s to %s in column %d", valType.Name(), outputType.Name(), i+1) + } + outputVals[i].Value, err = positionCast(ctx, vals[i].Value, outputType) + if err != nil { + return nil, err + } + } + return outputVals, nil + } + } + } + return nil +} + // IdentityCast returns the input value. func IdentityCast(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { return val, nil diff --git a/server/functions/framework/common_type.go b/server/functions/framework/common_type.go index f06f22f90a..fc69c9a693 100644 --- a/server/functions/framework/common_type.go +++ b/server/functions/framework/common_type.go @@ -20,11 +20,12 @@ import ( pgtypes "github.com/dolthub/doltgresql/server/types" ) -// FindCommonType returns the common type that given types can convert to. +// FindCommonType returns the common type that given types can convert to. Returns false if no implicit casts are needed +// to resolve the given types as the returned common type. // https://www.postgresql.org/docs/15/typeconv-union-case.html -func FindCommonType(types []*pgtypes.DoltgresType) (*pgtypes.DoltgresType, error) { - var candidateType = pgtypes.Unknown - var fail = false +func FindCommonType(types []*pgtypes.DoltgresType) (_ *pgtypes.DoltgresType, requiresCasts bool, err error) { + candidateType := pgtypes.Unknown + differentTypes := false for _, typ := range types { if typ.ID == candidateType.ID { continue @@ -32,40 +33,49 @@ func FindCommonType(types []*pgtypes.DoltgresType) (*pgtypes.DoltgresType, error candidateType = typ } else { candidateType = pgtypes.Unknown - fail = true + differentTypes = true } } - if !fail { + if !differentTypes { if candidateType.ID == pgtypes.Unknown.ID { - return pgtypes.Text, nil + // We require implicit casts from `unknown` to `text` + return pgtypes.Text, true, nil } - return candidateType, nil + return candidateType, false, nil } + // We have different types if we've made it this far, so we're guaranteed to require implicit casts + requiresCasts = true for _, typ := range types { if candidateType.ID == pgtypes.Unknown.ID { candidateType = typ } if typ.ID != pgtypes.Unknown.ID && candidateType.TypCategory != typ.TypCategory { - return nil, errors.Errorf("types %s and %s cannot be matched", candidateType.String(), typ.String()) + return nil, false, errors.Errorf("types %s and %s cannot be matched", candidateType.String(), typ.String()) } } - - var preferredTypeFound = false + // Attempt to find the most general type (or the preferred type in the type category) for _, typ := range types { - if typ.ID == pgtypes.Unknown.ID { + if typ.ID == pgtypes.Unknown.ID || typ.ID == candidateType.ID { continue } else if GetImplicitCast(typ, candidateType) != nil { + // typ can convert to the candidate type, so the candidate type is at least as general continue - } else if GetImplicitCast(candidateType, typ) == nil { - return nil, errors.Errorf("cannot find implicit cast function from %s to %s", candidateType.String(), typ.String()) - } else if !preferredTypeFound { + } else if GetImplicitCast(candidateType, typ) != nil { + // the candidate type can convert to typ, but not vice versa, so typ is likely more general + candidateType = typ if candidateType.IsPreferred { - candidateType = typ - preferredTypeFound = true + // We stop considering more types once we've found a preferred type + break } - } else { - return nil, errors.Errorf("found another preferred candidate type") } } - return candidateType, nil + // Verify that all types have an implicit conversion to the candidate type + for _, typ := range types { + if typ.ID == pgtypes.Unknown.ID || typ.ID == candidateType.ID { + continue + } else if GetImplicitCast(typ, candidateType) == nil { + return nil, false, errors.Errorf("cannot find implicit cast function from %s to %s", candidateType.String(), typ.String()) + } + } + return candidateType, requiresCasts, nil } diff --git a/server/functions/framework/compiled_function.go b/server/functions/framework/compiled_function.go index 19e84f8c5c..7311c43a90 100644 --- a/server/functions/framework/compiled_function.go +++ b/server/functions/framework/compiled_function.go @@ -582,12 +582,8 @@ func (c *CompiledFunction) typeCompatibleOverloads(fnOverloads []Overload, argTy polymorphicTargets = append(polymorphicTargets, argTypes[i]) } else { if overloadCasts[i] = GetImplicitCast(argTypes[i], paramType); overloadCasts[i] == nil { - if argTypes[i].ID == pgtypes.Unknown.ID { - overloadCasts[i] = UnknownLiteralCast - } else { - isConvertible = false - break - } + isConvertible = false + break } } } diff --git a/server/functions/framework/provider.go b/server/functions/framework/provider.go index bf7143047e..e4dc388b6e 100644 --- a/server/functions/framework/provider.go +++ b/server/functions/framework/provider.go @@ -17,9 +17,8 @@ package framework import ( "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/doltgresql/core/extensions" - "github.com/dolthub/doltgresql/core" + "github.com/dolthub/doltgresql/core/extensions" "github.com/dolthub/doltgresql/core/id" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -71,6 +70,7 @@ func (fp *FunctionProvider) Function(ctx *sql.Context, name string) (sql.Functio if err != nil || returnType == nil { return nil, false } + paramTypes := make([]*pgtypes.DoltgresType, len(overload.ParameterTypes)) for i, paramType := range overload.ParameterTypes { paramTypes[i], err = typesCollection.GetType(ctx, paramType) diff --git a/server/plpgsql/json_convert.go b/server/plpgsql/json_convert.go index 6775a1545d..a735f4289f 100644 --- a/server/plpgsql/json_convert.go +++ b/server/plpgsql/json_convert.go @@ -27,22 +27,39 @@ func jsonConvert(jsonBlock plpgSQL_block) (Block, error) { TriggerOld: jsonBlock.OldVariableNumber, Label: jsonBlock.Action.StmtBlock.Label, } + lowestRecordNumber := int32(2147483647) + // We do a first loop to determine the offset for the first record + for _, v := range jsonBlock.Datums { + switch { + case v.Record != nil: + if v.Record.DatumNumber < lowestRecordNumber { + lowestRecordNumber = v.Record.DatumNumber + } + } + } + offset := int32(0) - lowestRecordNumber + // Then we do a second loop that actually adds all of the datums to the block for _, v := range jsonBlock.Datums { switch { case v.Record != nil: // TODO: support normal record types - if int(v.Record.DatumNumber) > len(block.Records) { + datumNumber := v.Record.DatumNumber + offset + if int(datumNumber) >= len(block.Records) { oldRecords := block.Records - block.Records = make([]Record, v.Record.DatumNumber) + block.Records = make([]Record, datumNumber+1) copy(block.Records, oldRecords) } - block.Records[v.Record.DatumNumber-1].Name = v.Record.RefName + + if v.Record.DatumNumber > 0 { + block.Records[datumNumber].Name = v.Record.RefName + } case v.RecordField != nil: - if int(v.RecordField.RecordParentNumber) > len(block.Records) { + recordParentNumber := v.RecordField.RecordParentNumber + offset + if int(recordParentNumber) >= len(block.Records) { return Block{}, errors.New("invalid record parent number") } - block.Records[v.RecordField.RecordParentNumber-1].Fields = append( - block.Records[v.RecordField.RecordParentNumber-1].Fields, v.RecordField.FieldName) + block.Records[recordParentNumber].Fields = append( + block.Records[recordParentNumber].Fields, v.RecordField.FieldName) case v.Row != nil: if len(v.Row.Fields) != 1 { return Block{}, errors.New("record types are not yet supported") diff --git a/server/types/composite.go b/server/types/composite.go index ff67b8db0f..1a3abf6e5c 100644 --- a/server/types/composite.go +++ b/server/types/composite.go @@ -15,13 +15,15 @@ package types import ( + "context" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/doltgresql/core/id" ) // NewCompositeType creates new instance of composite DoltgresType. -func NewCompositeType(ctx *sql.Context, relID id.Id, arrayID, typeID id.Type, attrs []CompositeAttribute) *DoltgresType { +func NewCompositeType(_ context.Context, relID id.Id, arrayID, typeID id.Type, attrs []CompositeAttribute) *DoltgresType { return &DoltgresType{ ID: typeID, TypLength: -1, diff --git a/testing/go/create_function_plpgsql_test.go b/testing/go/create_function_plpgsql_test.go index cc948346d5..d0633b5cfc 100644 --- a/testing/go/create_function_plpgsql_test.go +++ b/testing/go/create_function_plpgsql_test.go @@ -467,6 +467,40 @@ $$ LANGUAGE plpgsql;`}, }, }, }, + { + Name: "RETURNS SETOF with type from other schema", + SetUpScript: []string{ + `CREATE SCHEMA sch1;`, + `CREATE TYPE sch1.user_summary AS ( + user_id integer, + username text, + is_active boolean);`, + `CREATE OR REPLACE FUNCTION func2() RETURNS SETOF sch1.user_summary + LANGUAGE plpgsql + AS $$ + BEGIN + RETURN QUERY SELECT 1, 'username', true; + RETURN QUERY SELECT 2, 'another', false; + END; + $$;`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT func2();", + Expected: []sql.Row{ + {"(1,username,t)"}, + {"(2,another,f)"}, + }, + }, + { + Query: "SELECT func2(), func2();", + Expected: []sql.Row{ + {"(1,username,t)", "(1,username,t)"}, + {"(2,another,f)", "(2,another,f)"}, + }, + }, + }, + }, { Name: "RETURNS SETOF with param", SetUpScript: []string{ @@ -500,6 +534,126 @@ $$ LANGUAGE plpgsql;`}, }, }, }, + { + Name: "RETURNS TABLE", + SetUpScript: []string{ + `CREATE FUNCTION func2() RETURNS TABLE(user_id integer, username text, is_active boolean) + LANGUAGE plpgsql + AS $$ + BEGIN + RETURN QUERY SELECT 1, 'username', true; + RETURN QUERY SELECT 2, 'another', false; + END; + $$;`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT func2();", + Expected: []sql.Row{ + {"(1,username,t)"}, + {"(2,another,f)"}, + }, + }, + { + Query: "SELECT func2(), func2();", + Expected: []sql.Row{ + {"(1,username,t)", "(1,username,t)"}, + {"(2,another,f)", "(2,another,f)"}, + }, + }, + }, + }, + { + Name: "RETURNS TABLE with single field", + SetUpScript: []string{ + `CREATE FUNCTION func2() RETURNS TABLE(username text) + LANGUAGE plpgsql + AS $$ + BEGIN + RETURN QUERY SELECT 'username1'; + RETURN QUERY SELECT 'username2'; + END; + $$;`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT func2();", + Expected: []sql.Row{ + {"(username1)"}, + {"(username2)"}, + }, + }, + { + Query: "SELECT func2(), func2();", + Expected: []sql.Row{ + {"(username1)", "(username1)"}, + {"(username2)", "(username2)"}, + }, + }, + }, + }, + { + Name: "RETURNS TABLE with types from other schema", + SetUpScript: []string{ + `CREATE SCHEMA sch1;`, + `CREATE TYPE sch1.mytype AS ( + user_id integer, + username text);`, + `CREATE FUNCTION func2() RETURNS TABLE(foo sch1.mytype) + LANGUAGE plpgsql + AS $$ + BEGIN + RETURN QUERY SELECT 1, 'username1'; + RETURN QUERY SELECT 2, 'username2'; + END; + $$;`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT func2();", + Expected: []sql.Row{ + {"(1,username1)"}, + {"(2,username2)"}, + }, + }, + { + Query: "SELECT func2(), func2();", + Expected: []sql.Row{ + {"(1,username1)", "(1,username1)"}, + {"(2,username2)", "(2,username2)"}, + }, + }, + }, + }, + { + Name: "RETURNS TABLE with param", + SetUpScript: []string{ + `CREATE OR REPLACE FUNCTION func3(user_id integer) RETURNS TABLE(user_id integer, username text, is_active boolean) + LANGUAGE plpgsql + AS $$ + BEGIN + RETURN QUERY SELECT user_id, 'username', true; + RETURN QUERY SELECT user_id, 'another', false; + END; + $$;`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT func3(111);", + Expected: []sql.Row{ + {"(111,username,t)"}, + {"(111,another,f)"}, + }, + }, + { + Query: "SELECT func3(111), func3(222);", + Expected: []sql.Row{ + {"(111,username,t)", "(222,username,t)"}, + {"(111,another,f)", "(222,another,f)"}, + }, + }, + }, + }, { Name: "RETURNS SETOF with composite param", SetUpScript: []string{ @@ -1108,5 +1262,119 @@ $$;`, }, }, }, + { + Name: "AlexTransit_venderctl import dump", + SetUpScript: []string{ + `CREATE TYPE public.tax_job_state AS ENUM ( + 'sched', + 'busy', + 'final', + 'help' +);`, + `CREATE TABLE public.catalog ( + vmid integer NOT NULL, + code text NOT NULL, + name text NOT NULL +);`, + `CREATE SEQUENCE public.tax_job_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1;`, + `CREATE TABLE public.tax_job ( + id bigint NOT NULL, + state public.tax_job_state NOT NULL, + created timestamp with time zone NOT NULL, + modified timestamp with time zone NOT NULL, + scheduled timestamp with time zone, + worker text, + processor text, + ext_id text, + data jsonb, + gross integer, + notes text[], + ops jsonb +);`, + `CREATE TABLE public.trans ( + vmid integer NOT NULL, + vmtime timestamp with time zone, + received timestamp with time zone NOT NULL, + menu_code text NOT NULL, + options integer[], + price integer NOT NULL, + method integer NOT NULL, + tax_job_id bigint, + executer bigint, + exeputer_type integer, + executer_str text +);`, + `ALTER TABLE ONLY public.tax_job ALTER COLUMN id SET DEFAULT nextval('public.tax_job_id_seq'::regclass);`, + `INSERT INTO public.trans VALUES (1, '2023-04-05 06:07:08', '2023-05-06 07:08:09', 'test', ARRAY[5,7], 44, 1, NULL, 1, 1, '');`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: `CREATE FUNCTION public.tax_job_trans(t public.trans) RETURNS public.tax_job + LANGUAGE plpgsql + AS ' + # print_strict_params ON +DECLARE + tjd jsonb; + ops jsonb; + tj tax_job; + name text; +BEGIN + -- lock trans row + PERFORM + 1 + FROM + trans + WHERE (vmid, vmtime) = (t.vmid, + t.vmtime) +LIMIT 1 +FOR UPDATE; + -- if trans already has tax_job assigned, just return it + IF t.tax_job_id IS NOT NULL THEN + SELECT + * INTO STRICT tj + FROM + tax_job + WHERE + id = t.tax_job_id; + RETURN tj; + END IF; + -- op code to human friendly name via catalog + SELECT + catalog.name INTO name + FROM + catalog + WHERE (vmid, code) = (t.vmid, + t.menu_code); + IF NOT found THEN + name := ''#'' || t.menu_code; + END IF; + ops := jsonb_build_array (jsonb_build_object(''vmid'', t.vmid, ''time'', t.vmtime, ''name'', name, ''code'', t.menu_code, ''amount'', 1, ''price'', t.price, ''method'', t.method)); + INSERT INTO tax_job (state, created, modified, scheduled, processor, ops, gross) + VALUES (''sched'', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, ''ru2019'', ops, t.price) + RETURNING + * INTO STRICT tj; + UPDATE + trans + SET + tax_job_id = tj.id + WHERE (vmid, vmtime) = (t.vmid, + t.vmtime); + RETURN tj; +END; +';`, + Expected: []sql.Row{}, + }, + { + Query: `SELECT public.tax_job_trans(trans.*) FROM public.trans;`, + Skip: true, // TODO: implement table.* syntax + Expected: []sql.Row{{`(1,sched,"2026-01-23 14:06:32.794817+00","2026-01-23 14:06:32.794817+00","2026-01-23 14:06:32.794817+00",,ru2019,,,44,,"[{""code"": ""test"", ""name"": ""#test"", ""time"": ""2023-04-05T06:07:08+00:00"", ""vmid"": 1, ""price"": 44, ""amount"": 1, ""method"": 1}]")`}}, + }, + }, + }, }) } diff --git a/testing/go/issues_test.go b/testing/go/issues_test.go index f1a9f12ce2..ae80326f3d 100644 --- a/testing/go/issues_test.go +++ b/testing/go/issues_test.go @@ -196,5 +196,38 @@ limit 1`, }, }, }, + { + Name: "Issue #2197 Part 1", + SetUpScript: []string{ + `CREATE TABLE t1 (a INT, b VARCHAR(3));`, + `CREATE TABLE t2(id SERIAL, t1 t1);`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: `INSERT INTO t2(t1) VALUES (ROW(1, 'abc'));`, + Expected: []sql.Row{}, + }, + { + Query: `SELECT * FROM t2;`, + Expected: []sql.Row{{1, "(1,abc)"}}, + }, + { + Query: `INSERT INTO t2(t1) VALUES (ROW('a', 'def'));`, + ExpectedErr: "invalid input syntax for type", + }, + { + Query: `INSERT INTO t2(t1) VALUES (ROW(true, 'def'));`, + ExpectedErr: "Cannot cast type", + }, + { + Query: `INSERT INTO t2(t1) VALUES (ROW(2, 'def', 'ghi'));`, + ExpectedErr: "cannot cast type", + }, + { + Query: `INSERT INTO t2(t1) VALUES (ROW(2));`, + ExpectedErr: "cannot cast type", + }, + }, + }, }) }