Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions core/typecollection/typecollection.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ import (
pgtypes "github.com/dolthub/doltgresql/server/types"
)

// anonymousCompositePrefix is the prefix for anonymous composite type names. These types are not stored on
// disk, but instead are created dynamically as needed.
const anonymousCompositePrefix = "table("

// anonymousCompositeSuffix is the suffix for anonymous composite type names.
const anonymousCompositeSuffix = ")"

// TypeCollection is a collection of all types (both built-in and user defined).
type TypeCollection struct {
accessedMap map[id.Type]*pgtypes.DoltgresType
Expand Down Expand Up @@ -165,6 +172,11 @@ func (pgs *TypeCollection) GetType(ctx context.Context, name id.Type) (*pgtypes.
return nil, err
}
if h.IsEmpty() {
// If this is an anonymous composite type, create it dynamically
if isAnonymousCompositeType(name) {
return createAnonymousCompositeType(ctx, name)
}

// If it's not a built-in type or created type, then check if it's a composite table row type
sqlCtx, ok := ctx.(*sql.Context)
if !ok {
Expand All @@ -189,6 +201,34 @@ func (pgs *TypeCollection) GetType(ctx context.Context, name id.Type) (*pgtypes.
return pgt, nil
}

// isAnonymousCompositeType return true if |returnType| represents an anonymous composite return type
// for a function (i.e. the function was declared as "RETURNS TABLE(...)").
func isAnonymousCompositeType(returnType id.Type) bool {
typeName := returnType.TypeName()
return strings.HasPrefix(typeName, anonymousCompositePrefix) &&
strings.HasSuffix(typeName, anonymousCompositeSuffix)
}

// createAnonymousCompositeType creates a new DoltgresType for the anonymous composite return type for a function,
// as represented by |returnType|.
func createAnonymousCompositeType(ctx context.Context, returnType id.Type) (*pgtypes.DoltgresType, error) {
typeName := returnType.TypeName()
attributeTypes := typeName[len(anonymousCompositePrefix) : len(typeName)-len(anonymousCompositeSuffix)]
attributeTypesSlice := strings.Split(attributeTypes, ",")

attrs := make([]pgtypes.CompositeAttribute, len(attributeTypesSlice))
for i, attributeNameAndType := range attributeTypesSlice {
split := strings.Split(attributeNameAndType, ":")
if len(split) != 2 {
return nil, fmt.Errorf("unexpected anonymous composite type attribute syntax: %s", attributeNameAndType)
}

typeId := id.NewType("", split[1])
attrs[i] = pgtypes.NewCompositeAttribute(nil, id.Null, split[0], typeId, int16(i), "")
}
return pgtypes.NewCompositeType(ctx, id.Null, id.NullType, returnType, attrs), nil
}

// HasType checks if a type exists with given schema and type name.
func (pgs *TypeCollection) HasType(ctx context.Context, name id.Type) bool {
// We can check the built-in types first
Expand Down
8 changes: 4 additions & 4 deletions postgres/parser/parser/sql.y
Original file line number Diff line number Diff line change
Expand Up @@ -4313,11 +4313,11 @@ create_function_stmt:
}
| CREATE FUNCTION routine_name opt_routine_arg_with_default_list RETURNS SETOF typename create_function_option_list
{
$$.val = &tree.CreateFunction{Name: $3.unresolvedObjectName(), Args: $4.routineArgs(), SetOf: true, RetType: []tree.SimpleColumnDef{tree.SimpleColumnDef{Type: $7.typeReference()}}, Options: $8.routineOptions()}
$$.val = &tree.CreateFunction{Name: $3.unresolvedObjectName(), Args: $4.routineArgs(), ReturnsSetOf: true, RetType: []tree.SimpleColumnDef{tree.SimpleColumnDef{Type: $7.typeReference()}}, Options: $8.routineOptions()}
}
| CREATE FUNCTION routine_name opt_routine_arg_with_default_list RETURNS TABLE '(' opt_returns_table_col_def_list ')' create_function_option_list
{
$$.val = &tree.CreateFunction{Name: $3.unresolvedObjectName(), Args: $4.routineArgs(), RetType: $8.simpleColumnDefs(), Options: $10.routineOptions()}
$$.val = &tree.CreateFunction{Name: $3.unresolvedObjectName(), Args: $4.routineArgs(), ReturnsTable: true, RetType: $8.simpleColumnDefs(), Options: $10.routineOptions()}
}
| CREATE OR REPLACE FUNCTION routine_name opt_routine_arg_with_default_list create_function_option_list
{
Expand All @@ -4329,11 +4329,11 @@ create_function_stmt:
}
| CREATE OR REPLACE FUNCTION routine_name opt_routine_arg_with_default_list RETURNS SETOF typename create_function_option_list
{
$$.val = &tree.CreateFunction{Name: $5.unresolvedObjectName(), Replace: true, Args: $6.routineArgs(), SetOf: true, RetType: []tree.SimpleColumnDef{tree.SimpleColumnDef{Type: $9.typeReference()}}, Options: $10.routineOptions()}
$$.val = &tree.CreateFunction{Name: $5.unresolvedObjectName(), Replace: true, Args: $6.routineArgs(), ReturnsSetOf: true, RetType: []tree.SimpleColumnDef{tree.SimpleColumnDef{Type: $9.typeReference()}}, Options: $10.routineOptions()}
}
| CREATE OR REPLACE FUNCTION routine_name opt_routine_arg_with_default_list RETURNS TABLE '(' opt_returns_table_col_def_list ')' create_function_option_list
{
$$.val = &tree.CreateFunction{Name: $5.unresolvedObjectName(), Replace: true, Args: $6.routineArgs(), RetType: $10.simpleColumnDefs(), Options: $12.routineOptions()}
$$.val = &tree.CreateFunction{Name: $5.unresolvedObjectName(), Replace: true, Args: $6.routineArgs(), ReturnsTable: true, RetType: $10.simpleColumnDefs(), Options: $12.routineOptions()}
}

opt_returns_table_col_def_list:
Expand Down
17 changes: 9 additions & 8 deletions postgres/parser/sem/tree/create_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@ var _ Statement = &CreateFunction{}

// CreateFunction represents a CREATE FUNCTION statement.
type CreateFunction struct {
Name *UnresolvedObjectName
Replace bool
Args RoutineArgs
SetOf bool
RetType []SimpleColumnDef
Options []RoutineOption
Name *UnresolvedObjectName
Replace bool
Args RoutineArgs
ReturnsSetOf bool
ReturnsTable bool
RetType []SimpleColumnDef
Options []RoutineOption
}

// Format implements the NodeFormatter interface.
Expand All @@ -47,9 +48,9 @@ func (node *CreateFunction) Format(ctx *FmtCtx) {
ctx.WriteString(" )")
}
if node.RetType != nil {
if len(node.RetType) == 1 && node.RetType[0].Name == "" {
if !node.ReturnsTable {
ctx.WriteString("RETURNS ")
if node.SetOf {
if node.ReturnsSetOf {
ctx.WriteString("SETOF ")
}
ctx.WriteString(node.RetType[0].Type.SQLString())
Expand Down
57 changes: 50 additions & 7 deletions server/ast/create_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
package ast

import (
"context"
"fmt"
"strings"

"github.com/cockroachdb/errors"
vitess "github.com/dolthub/vitess/go/vt/sqlparser"

"github.com/dolthub/doltgresql/core/id"
"github.com/dolthub/doltgresql/postgres/parser/parser"
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
"github.com/dolthub/doltgresql/postgres/parser/types"
Expand All @@ -38,27 +40,41 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme
}
// Grab the general information that we'll need to create the function
tableName := node.Name.ToTableName()
retType := pgtypes.Void
if len(node.RetType) == 1 {
var retType *pgtypes.DoltgresType
if len(node.RetType) == 0 {
retType = pgtypes.Void
} else if !node.ReturnsTable { // Return types may specify "trigger", but this doesn't apply elsewhere
switch typ := node.RetType[0].Type.(type) {
case *types.T:
retType = pgtypes.NewUnresolvedDoltgresType("", strings.ToLower(typ.Name()))
default:
sqlString := strings.ToLower(typ.SQLString())
if sqlString == "trigger" {
case *tree.UnresolvedObjectName:
if typ.NumParts == 1 && typ.SQLString() == "trigger" {
retType = pgtypes.Trigger
} else {
retType = pgtypes.NewUnresolvedDoltgresType("", sqlString)
_, retType, err = nodeResolvableTypeReference(ctx, typ)
if err != nil {
return nil, err
}
}
default:
return nil, fmt.Errorf("unsupported ResolvableTypeReference type: %T", typ)
}
} else {
retType = createAnonymousCompositeType(node.RetType)
}

paramNames := make([]string, len(node.Args))
paramTypes := make([]*pgtypes.DoltgresType, len(node.Args))
for i, arg := range node.Args {
paramNames[i] = arg.Name.String()
switch argType := arg.Type.(type) {
case *types.T:
paramTypes[i] = pgtypes.NewUnresolvedDoltgresType("", strings.ToLower(argType.Name()))
case *tree.UnresolvedObjectName:
_, paramTypes[i], err = nodeResolvableTypeReference(ctx, argType)
if err != nil {
return nil, err
}
default:
paramTypes[i] = pgtypes.NewUnresolvedDoltgresType("", strings.ToLower(argType.SQLString()))
}
Expand Down Expand Up @@ -121,11 +137,38 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme
parsedBody,
sqlDef,
sqlDefParsed,
node.SetOf,
node.ReturnsSetOf,
),
}, nil
}

// createAnonymousCompositeType creates a new DoltgresType for the anonymous composite return
// type for a function, as represented by the |fieldTypes| that were specified in the function
// definition.
func createAnonymousCompositeType(fieldTypes []tree.SimpleColumnDef) *pgtypes.DoltgresType {
attrs := make([]pgtypes.CompositeAttribute, len(fieldTypes))
for i, fieldType := range fieldTypes {
attrs[i] = pgtypes.NewCompositeAttribute(nil, id.Null, fieldType.Name.String(),
id.NewType("", fieldType.Type.SQLString()), int16(i), "")
}

typeIdString := "table("
for i, attr := range attrs {
if i > 0 {
typeIdString += ","
}
typeIdString += attr.Name
typeIdString += ":"
typeIdString += attr.TypeID.TypeName()
}
typeIdString += ")"

// NOTE: there is no schema needed, since these types are anonymous and can't be directly referenced
typeId := id.NewType("", typeIdString)

return pgtypes.NewCompositeType(context.Background(), id.Null, id.NullType, typeId, attrs)
}

// handleLanguageSQL handles parsing SQL definition strings in both CREATE FUNCTION and CREATE PROCEDURE.
func handleLanguageSQL(definition string, paramNames []string, paramTypes []*pgtypes.DoltgresType) (string, vitess.Statement, error) {
stmt, err := parser.ParseOne(definition)
Expand Down
8 changes: 2 additions & 6 deletions server/expression/array.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,7 @@ func (array *Array) Eval(ctx *sql.Context, row sql.Row) (any, error) {
// We always cast the element, as there may be parameter restrictions in place
castFunc := framework.GetImplicitCast(doltgresType, resultTyp)
if castFunc == nil {
if doltgresType.ID == pgtypes.Unknown.ID {
castFunc = framework.UnknownLiteralCast
} else {
return nil, errors.Errorf("cannot find cast function from %s to %s", doltgresType.String(), resultTyp.String())
}
return nil, errors.Errorf("cannot find cast function from %s to %s", doltgresType.String(), resultTyp.String())
}

values[i], err = castFunc(ctx, val, resultTyp)
Expand Down Expand Up @@ -175,7 +171,7 @@ func (array *Array) getTargetType(children ...sql.Expression) (*pgtypes.Doltgres
childrenTypes = append(childrenTypes, childType)
}
}
targetType, err := framework.FindCommonType(childrenTypes)
targetType, _, err := framework.FindCommonType(childrenTypes)
if err != nil {
return nil, errors.Errorf("ARRAY %s", err.Error())
}
Expand Down
8 changes: 2 additions & 6 deletions server/expression/assignment_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,8 @@ func (ac *AssignmentCast) Eval(ctx *sql.Context, row sql.Row) (any, error) {
}
castFunc := framework.GetAssignmentCast(ac.fromType, ac.toType)
if castFunc == nil {
if ac.fromType.ID == pgtypes.Unknown.ID {
castFunc = framework.UnknownLiteralCast
} else {
return nil, errors.Errorf("ASSIGNMENT_CAST: target is of type %s but expression is of type %s: %s",
ac.toType.String(), ac.fromType.String(), ac.expr.String())
}
return nil, errors.Errorf("ASSIGNMENT_CAST: target is of type %s but expression is of type %s: %s",
ac.toType.String(), ac.fromType.String(), ac.expr.String())
}
return castFunc(ctx, val, ac.toType)
}
Expand Down
52 changes: 2 additions & 50 deletions server/expression/explicit_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"github.com/dolthub/go-mysql-server/sql/expression"
vitess "github.com/dolthub/vitess/go/vt/sqlparser"

"github.com/dolthub/doltgresql/core"
"github.com/dolthub/doltgresql/server/functions/framework"
pgtypes "github.com/dolthub/doltgresql/server/types"
)
Expand Down Expand Up @@ -97,55 +96,8 @@ func (c *ExplicitCast) Eval(ctx *sql.Context, row sql.Row) (any, error) {
baseCastToType := checkForDomainType(c.castToType)
castFunction := framework.GetExplicitCast(fromType, baseCastToType)
if castFunction == nil {
if fromType.ID == pgtypes.Unknown.ID {
castFunction = framework.UnknownLiteralCast
} else if fromType.IsRecordType() && c.castToType.IsCompositeType() { // TODO: should this only be in explicit, or assignment and implicit too?
// Casting to a record type will always work for any composite type.
// TODO: is the above statement true for all cases?
// When casting to a composite type, then we must match the arity and have valid casts for every position.
if c.castToType.IsRecordType() {
castFunction = framework.IdentityCast
} else {
castFunction = func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) {
vals, ok := val.([]pgtypes.RecordValue)
if !ok {
// TODO: better error message
return nil, errors.New("casting input error from record type")
}
if len(targetType.CompositeAttrs) != len(vals) {
return nil, errors.Newf("cannot cast type %s to %s", "", targetType.Name())
}
typeCollection, err := core.GetTypesCollectionFromContext(ctx)
if err != nil {
return nil, err
}
outputVals := make([]pgtypes.RecordValue, len(vals))
for i := range vals {
valType, ok := vals[i].Type.(*pgtypes.DoltgresType)
if !ok {
// TODO: if this is a GMS type, then we should cast to a Doltgres type here
return nil, errors.New("cannot cast record containing GMS type")
}
outputVals[i].Type, err = typeCollection.GetType(ctx, targetType.CompositeAttrs[i].TypeID)
if err != nil {
return nil, err
}
innerExplicit := ExplicitCast{
sqlChild: NewUnsafeLiteral(vals[i].Value, valType),
castToType: outputVals[i].Type.(*pgtypes.DoltgresType),
}
outputVals[i].Value, err = innerExplicit.Eval(ctx, nil)
if err != nil {
return nil, err
}
}
return outputVals, nil
}
}
} else {
return nil, errors.Errorf("EXPLICIT CAST: cast from `%s` to `%s` does not exist: %s",
fromType.String(), c.castToType.String(), c.sqlChild.String())
}
return nil, errors.Errorf("EXPLICIT CAST: cast from `%s` to `%s` does not exist: %s",
fromType.String(), c.castToType.String(), c.sqlChild.String())
}
castResult, err := castFunction(ctx, val, c.castToType)
if err != nil {
Expand Down
Loading