Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions server/functions/framework/overloads.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,21 @@ func (o *Overloads) overloadsForParams(numParams int) []Overload {
// parameter count from the target parameter count to obtain the additional parameter count.
firstValueAfterVariadic := variadicIndex + 1 + (numParams - len(params))
copy(extendedParams[firstValueAfterVariadic:], params[variadicIndex+1:])
// ToArrayType immediately followed by BaseType is a way to get the base type without having to cast.
// For array types, ToArrayType causes them to return themselves.
variadicBaseType := overload.GetParameters()[variadicIndex].ToArrayType().ArrayBaseType()
paramType := overload.GetParameters()[variadicIndex]

var variadicBaseType *pgtypes.DoltgresType

// special case: anyArray takes any args, pass as is
if paramType == pgtypes.AnyArray {
for variadicParamIdx := 0; variadicParamIdx < 1+(numParams-len(params)); variadicParamIdx++ {
variadicBaseType = pgtypes.AnyElement
}
} else {
// ToArrayType immediately followed by BaseType is a way to get the base type without having to cast.
// For array types, ToArrayType causes them to return themselves.
variadicBaseType = paramType.ToArrayType().ArrayBaseType()
}

for variadicParamIdx := 0; variadicParamIdx < 1+(numParams-len(params)); variadicParamIdx++ {
extendedParams[variadicParamIdx+variadicIndex] = variadicBaseType
}
Expand Down
48 changes: 48 additions & 0 deletions server/functions/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package functions

import (
"fmt"
"unsafe"

"github.com/dolthub/go-mysql-server/sql"
Expand All @@ -30,6 +31,8 @@ func initJson() {
framework.RegisterFunction(json_out)
framework.RegisterFunction(json_recv)
framework.RegisterFunction(json_send)
framework.RegisterFunction(json_build_array)
framework.RegisterFunction(json_build_object)
}

// json_in represents the PostgreSQL function of json type IO input.
Expand Down Expand Up @@ -83,3 +86,48 @@ var json_send = framework.Function1{
return []byte(val.(string)), nil
},
}

// json_build_array represents the PostgreSQL function json_build_array.
var json_build_array = framework.Function1{
Name: "json_build_array",
Return: pgtypes.Json,
Parameters: [1]*pgtypes.DoltgresType{pgtypes.AnyArray},
Variadic: true,
Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) {
inputArray := val1.([]any)
json, err := json.Marshal(inputArray)
return string(json), err
},
}

// json_build_object represents the PostgreSQL function json_build_object.
var json_build_object = framework.Function1{
Name: "json_build_object",
Return: pgtypes.Json,
Parameters: [1]*pgtypes.DoltgresType{pgtypes.AnyArray},
Variadic: true,
Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) {
inputArray := val1.([]any)
if len(inputArray)%2 != 0 {
return nil, sql.ErrInvalidArgumentNumber.New("json_build_object", "even number of arguments", len(inputArray))
}
jsonObject := make(map[string]any)
var key string
for _, e := range inputArray {
if key == "" {
var ok bool
key, ok = e.(string)
if !ok {
// TODO: not clear this is the correct approach for all values, may need special handling for some of them
key = fmt.Sprintf("%v", e)
Comment thread
zachmu marked this conversation as resolved.
Outdated
}
} else {
jsonObject[key] = e
key = ""
}
}

json, err := json.Marshal(jsonObject)
return string(json), err
},
}
2 changes: 1 addition & 1 deletion server/types/any_array.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ var AnyArray = &DoltgresType{
Delimiter: ",",
RelID: id.Null,
SubscriptFunc: toFuncID("-"),
Elem: id.NullType,
Elem: AnyElement.ID,
Array: id.NullType,
InputFunc: toFuncID("anyarray_in", toInternal("cstring")),
OutputFunc: toFuncID("anyarray_out", toInternal("anyarray")),
Expand Down
3 changes: 2 additions & 1 deletion server/types/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,8 @@ func (t *DoltgresType) IoOutput(ctx *sql.Context, val any) (string, error) {

// IsArrayType returns true if the type is of 'array' category
func (t *DoltgresType) IsArrayType() bool {
return t.TypCategory == TypeCategory_ArrayTypes && t.Elem != id.NullType
return (t.TypCategory == TypeCategory_ArrayTypes && t.Elem != id.NullType) ||
(t.TypCategory == TypeCategory_PseudoTypes && t.ID.TypeName() == "anyarray")
}

// IsEmptyType returns true if the type is not valid.
Expand Down
45 changes: 45 additions & 0 deletions testing/go/functions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,51 @@ func TestSystemInformationFunctions(t *testing.T) {
})
}

func TestJsonFunctions(t *testing.T) {
RunScripts(t, []ScriptTest{
{
Name: "json_build_array",
Assertions: []ScriptTestAssertion{
{
Query: `SELECT json_build_array(1, 2, 3);`,
Cols: []string{"json_build_array"},
Expected: []sql.Row{{`[1,2,3]`}},
},
{
Query: `SELECT json_build_array(1, '2', 3);`,
Cols: []string{"json_build_array"},
Expected: []sql.Row{{`[1,"2",3]`}},
},
{
Query: `SELECT json_build_array();`,
Skip: true, // variadic functions can't handle 0 arguments right now
Cols: []string{"json_build_array"},
Expected: []sql.Row{{`[]`}},
},
},
},
{
Name: "json_build_object",
Assertions: []ScriptTestAssertion{
{
Query: `SELECT json_build_object('a', 2, 'b', 4);`,
Cols: []string{"json_build_object"},
Expected: []sql.Row{{`{"a":2,"b":4}`}},
},
{
Query: `SELECT json_build_object('a', 2, 'b');`,
ExpectedErr: "even number",
},
{
Query: `SELECT json_build_object(1, 2, 'b', 3);`,
Cols: []string{"json_build_object"},
Expected: []sql.Row{{`{"1":2,"b":3}`}},
},
},
},
})
}

func TestArrayFunctions(t *testing.T) {
RunScripts(t, []ScriptTest{
{
Expand Down