diff --git a/server/functions/framework/overloads.go b/server/functions/framework/overloads.go index 51cbbb7c4b..a5d65f248b 100644 --- a/server/functions/framework/overloads.go +++ b/server/functions/framework/overloads.go @@ -86,9 +86,21 @@ func (o *Overloads) overloadsForParams(numParams int) []Overload { // parameter count from the target parameter count to obtain the additional parameter count. firstValueAfterVariadic := variadicIndex + 1 + (numParams - len(params)) copy(extendedParams[firstValueAfterVariadic:], params[variadicIndex+1:]) - // ToArrayType immediately followed by BaseType is a way to get the base type without having to cast. - // For array types, ToArrayType causes them to return themselves. - variadicBaseType := overload.GetParameters()[variadicIndex].ToArrayType().ArrayBaseType() + paramType := overload.GetParameters()[variadicIndex] + + var variadicBaseType *pgtypes.DoltgresType + + // special case: anyArray takes any args, pass as is + if paramType == pgtypes.AnyArray { + for variadicParamIdx := 0; variadicParamIdx < 1+(numParams-len(params)); variadicParamIdx++ { + variadicBaseType = pgtypes.AnyElement + } + } else { + // ToArrayType immediately followed by BaseType is a way to get the base type without having to cast. + // For array types, ToArrayType causes them to return themselves. + variadicBaseType = paramType.ToArrayType().ArrayBaseType() + } + for variadicParamIdx := 0; variadicParamIdx < 1+(numParams-len(params)); variadicParamIdx++ { extendedParams[variadicParamIdx+variadicIndex] = variadicBaseType } diff --git a/server/functions/json.go b/server/functions/json.go index cd81f53e8c..b2774c2e1c 100644 --- a/server/functions/json.go +++ b/server/functions/json.go @@ -15,6 +15,7 @@ package functions import ( + "fmt" "unsafe" "github.com/dolthub/go-mysql-server/sql" @@ -30,6 +31,8 @@ func initJson() { framework.RegisterFunction(json_out) framework.RegisterFunction(json_recv) framework.RegisterFunction(json_send) + framework.RegisterFunction(json_build_array) + framework.RegisterFunction(json_build_object) } // json_in represents the PostgreSQL function of json type IO input. @@ -83,3 +86,50 @@ var json_send = framework.Function1{ return []byte(val.(string)), nil }, } + +// json_build_array represents the PostgreSQL function json_build_array. +var json_build_array = framework.Function1{ + Name: "json_build_array", + Return: pgtypes.Json, + Parameters: [1]*pgtypes.DoltgresType{pgtypes.AnyArray}, + Variadic: true, + Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) { + inputArray := val1.([]any) + json, err := json.Marshal(inputArray) + return string(json), err + }, +} + +// json_build_object represents the PostgreSQL function json_build_object. +var json_build_object = framework.Function1{ + Name: "json_build_object", + Return: pgtypes.Json, + Parameters: [1]*pgtypes.DoltgresType{pgtypes.AnyArray}, + Variadic: true, + Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) { + inputArray := val1.([]any) + if len(inputArray)%2 != 0 { + return nil, sql.ErrInvalidArgumentNumber.New("json_build_object", "even number of arguments", len(inputArray)) + } + jsonObject := make(map[string]any) + var key string + for i, e := range inputArray { + if i%2 == 0 { + var ok bool + key, ok = e.(string) + if !ok { + // TODO: This isn't correct for every type we might use as a value. To get better type info to transform + // every value into its string format, we need to pass detailed arg type info for the vararg params (the + // unused param in the function call). + key = fmt.Sprintf("%v", e) + } + } else { + jsonObject[key] = e + key = "" + } + } + + json, err := json.Marshal(jsonObject) + return string(json), err + }, +} diff --git a/server/types/array.go b/server/types/array.go index 3fa64fb4f9..c5f102a715 100644 --- a/server/types/array.go +++ b/server/types/array.go @@ -62,3 +62,10 @@ func CreateArrayTypeFromBaseType(baseType *DoltgresType) *DoltgresType { CompareFunc: toFuncID("btarraycmp", toInternal("anyarray"), toInternal("anyarray")), } } + +// LogicalArrayElementTypes is a map of array element types for particular array types where the logical type varies +// from the declared type, as needed. Some types that have a NULL element for pg_catalog compatibility have a logical +// type that we need during analysis for function calls. +var LogicalArrayElementTypes = map[id.Type]*DoltgresType{ + toInternal("anyarray"): AnyElement, +} diff --git a/server/types/type.go b/server/types/type.go index 31e9e0a3a2..8b952b08d6 100644 --- a/server/types/type.go +++ b/server/types/type.go @@ -109,10 +109,20 @@ func (t *DoltgresType) ArrayBaseType() *DoltgresType { if !t.IsArrayType() { return t } - elem, ok := IDToBuiltInDoltgresType[t.Elem] + + var elem *DoltgresType + var ok bool + + elem, ok = IDToBuiltInDoltgresType[t.Elem] if !ok { - panic(fmt.Sprintf("cannot get base type from: %s", t.Name())) + // Some array types have no declared element type for pg_catalog compatibilty, but still have a logical type + // we return for analysis + elem, ok = LogicalArrayElementTypes[t.ID] + if !ok { + panic(fmt.Sprintf("cannot get base type from: %s", t.Name())) + } } + newElem := *elem.WithAttTypMod(t.attTypMod) return &newElem } @@ -433,7 +443,8 @@ func (t *DoltgresType) IoOutput(ctx *sql.Context, val any) (string, error) { // IsArrayType returns true if the type is of 'array' category func (t *DoltgresType) IsArrayType() bool { - return t.TypCategory == TypeCategory_ArrayTypes && t.Elem != id.NullType + return (t.TypCategory == TypeCategory_ArrayTypes && t.Elem != id.NullType) || + (t.TypCategory == TypeCategory_PseudoTypes && t.ID.TypeName() == "anyarray") } // IsEmptyType returns true if the type is not valid. diff --git a/testing/go/functions_test.go b/testing/go/functions_test.go index e00bbb4f32..00c3e768b2 100644 --- a/testing/go/functions_test.go +++ b/testing/go/functions_test.go @@ -962,6 +962,51 @@ func TestSystemInformationFunctions(t *testing.T) { }) } +func TestJsonFunctions(t *testing.T) { + RunScripts(t, []ScriptTest{ + { + Name: "json_build_array", + Assertions: []ScriptTestAssertion{ + { + Query: `SELECT json_build_array(1, 2, 3);`, + Cols: []string{"json_build_array"}, + Expected: []sql.Row{{`[1,2,3]`}}, + }, + { + Query: `SELECT json_build_array(1, '2', 3);`, + Cols: []string{"json_build_array"}, + Expected: []sql.Row{{`[1,"2",3]`}}, + }, + { + Query: `SELECT json_build_array();`, + Skip: true, // variadic functions can't handle 0 arguments right now + Cols: []string{"json_build_array"}, + Expected: []sql.Row{{`[]`}}, + }, + }, + }, + { + Name: "json_build_object", + Assertions: []ScriptTestAssertion{ + { + Query: `SELECT json_build_object('a', 2, 'b', 4);`, + Cols: []string{"json_build_object"}, + Expected: []sql.Row{{`{"a":2,"b":4}`}}, + }, + { + Query: `SELECT json_build_object('a', 2, 'b');`, + ExpectedErr: "even number", + }, + { + Query: `SELECT json_build_object(1, 2, 'b', 3);`, + Cols: []string{"json_build_object"}, + Expected: []sql.Row{{`{"1":2,"b":3}`}}, + }, + }, + }, + }) +} + func TestArrayFunctions(t *testing.T) { RunScripts(t, []ScriptTest{ {