Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion core/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ package core

import (
"github.com/cockroachdb/errors"

"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/resolve"
"github.com/dolthub/go-mysql-server/sql"

"github.com/dolthub/doltgresql/core/functions"
"github.com/dolthub/doltgresql/core/sequences"
"github.com/dolthub/doltgresql/core/typecollection"
)
Expand All @@ -30,6 +30,7 @@ import (
type contextValues struct {
collection *sequences.Collection
types *typecollection.TypeCollection
funcs *functions.Collection
pgCatalogCache any
}

Expand Down Expand Up @@ -63,6 +64,17 @@ func getRootFromContext(ctx *sql.Context) (*dsess.DoltSession, *RootValue, error
return session, state.WorkingRoot().(*RootValue), nil
}

// IsContextValid returns whether the context is valid for use with any of the functions in the package. If this is not
// false, then there's a high likelihood that the context is being used in a temporary scenario (such as setting up the
// database, etc.).
func IsContextValid(ctx *sql.Context) bool {
if ctx == nil {
return false
}
_, ok := ctx.Session.(*dsess.DoltSession)
return ok
}

// GetPgCatalogCache returns a cache of data for pg_catalog tables. This function should only be used by
// pg_catalog table handlers. The catalog cache instance stores generated pg_catalog table data so that
// it only has to generated table data once per query.
Expand Down Expand Up @@ -185,6 +197,26 @@ func GetSqlTableFromContext(ctx *sql.Context, databaseName string, tableName dol
return nil, nil
}

// GetFunctionsCollectionFromContext returns the functions collection from the given context. Will always return a
// collection if no error is returned.
func GetFunctionsCollectionFromContext(ctx *sql.Context) (*functions.Collection, error) {
cv, err := getContextValues(ctx)
if err != nil {
return nil, err
}
if cv.funcs == nil {
_, root, err := getRootFromContext(ctx)
if err != nil {
return nil, err
}
cv.funcs, err = root.GetFunctions(ctx)
if err != nil {
return nil, err
}
}
return cv.funcs, nil
}

// GetSequencesCollectionFromContext returns the given sequence collection from the context. Will always return a collection if
// no error is returned.
func GetSequencesCollectionFromContext(ctx *sql.Context) (*sequences.Collection, error) {
Expand Down Expand Up @@ -247,6 +279,10 @@ func CloseContextRootFinalizer(ctx *sql.Context) error {
if err != nil {
return err
}
newRoot, err = newRoot.PutFunctions(ctx, cv.funcs)
if err != nil {
return err
}
if newRoot != nil {
if err = session.SetWorkingRoot(ctx, ctx.GetCurrentDatabase(), newRoot); err != nil {
// TODO: We need a way to see if the session has a writeable working root
Expand Down
128 changes: 128 additions & 0 deletions core/functions/function.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package functions

import (
"maps"
"sync"

"github.com/cockroachdb/errors"

"github.com/dolthub/doltgresql/core/id"
"github.com/dolthub/doltgresql/server/plpgsql"
)

// Collection contains a collection of functions.
type Collection struct {
funcMap map[id.Function]*Function
overloadMap map[id.Function][]*Function
mutex *sync.Mutex
}

// Function represents a created function.
type Function struct {
ID id.Function
ReturnType id.Type
ParameterNames []string
ParameterTypes []id.Type
Variadic bool
IsNonDeterministic bool
Strict bool
Operations []plpgsql.InterpreterOperation
}

// GetFunction returns the function with the given ID. Returns nil if the function cannot be found.
func (pgf *Collection) GetFunction(funcID id.Function) *Function {
pgf.mutex.Lock()
defer pgf.mutex.Unlock()

if f, ok := pgf.funcMap[funcID]; ok {
return f
}
return nil
}

// GetFunctionOverloads returns the overloads for the function matching the schema and the function name. The parameter
// types are ignored when searching for overloads.
func (pgf *Collection) GetFunctionOverloads(funcID id.Function) []*Function {
pgf.mutex.Lock()
defer pgf.mutex.Unlock()

funcNameOnly := id.NewFunction(funcID.SchemaName(), funcID.FunctionName())
return pgf.overloadMap[funcNameOnly]
}

// HasFunction returns whether the function is present.
func (pgf *Collection) HasFunction(funcID id.Function) bool {
return pgf.GetFunction(funcID) != nil
}

// AddFunction adds a new function.
func (pgf *Collection) AddFunction(f *Function) error {
pgf.mutex.Lock()
defer pgf.mutex.Unlock()

if _, ok := pgf.funcMap[f.ID]; ok {
return errors.Errorf(`function "%s" already exists with same argument types`, f.ID.FunctionName())
}
pgf.funcMap[f.ID] = f
funcNameOnly := id.NewFunction(f.ID.SchemaName(), f.ID.FunctionName())
pgf.overloadMap[funcNameOnly] = append(pgf.overloadMap[funcNameOnly], f)
return nil
}

// DropFunction drops an existing function.
func (pgf *Collection) DropFunction(funcID id.Function) error {
pgf.mutex.Lock()
defer pgf.mutex.Unlock()

if _, ok := pgf.funcMap[funcID]; ok {
delete(pgf.funcMap, funcID)
funcNameOnly := id.NewFunction(funcID.SchemaName(), funcID.FunctionName())
for i, f := range pgf.overloadMap[funcNameOnly] {
if f.ID == funcID {
pgf.overloadMap[funcNameOnly] = append(pgf.overloadMap[funcNameOnly][:i], pgf.overloadMap[funcNameOnly][i+1:]...)
break
}
}
return nil
}
return errors.Errorf(`function %s does not exist`, funcID.FunctionName())
}

// IterateFunctions iterates over all functions in the collection.
func (pgf *Collection) IterateFunctions(callback func(f *Function) error) error {
pgf.mutex.Lock()
defer pgf.mutex.Unlock()

for _, f := range pgf.funcMap {
if err := callback(f); err != nil {
return err
}
}
return nil
}

// Clone returns a new *Collection with the same contents as the original.
func (pgf *Collection) Clone() *Collection {
pgf.mutex.Lock()
defer pgf.mutex.Unlock()

return &Collection{
funcMap: maps.Clone(pgf.funcMap),
overloadMap: maps.Clone(pgf.overloadMap),
mutex: &sync.Mutex{},
}
}
39 changes: 39 additions & 0 deletions core/functions/merge.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package functions

import (
"context"

"github.com/cockroachdb/errors"
)

// Merge handles merging functions on our root and their root.
func Merge(ctx context.Context, ourCollection, theirCollection, ancCollection *Collection) (*Collection, error) {
mergedCollection := ourCollection.Clone()
err := theirCollection.IterateFunctions(func(theirFunc *Function) error {
// If we don't have the sequence, then we simply add it
if !mergedCollection.HasFunction(theirFunc.ID) {
newFunc := *theirFunc
return mergedCollection.AddFunction(&newFunc)
}
// TODO: figure out a decent merge strategy
return errors.Errorf(`unable to merge "%s"`, theirFunc.ID.AsId().String())
})
if err != nil {
return nil, err
}
return mergedCollection, nil
}
120 changes: 120 additions & 0 deletions core/functions/serialization.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package functions

import (
"context"
"sync"

"github.com/cockroachdb/errors"

"github.com/dolthub/doltgresql/core/id"
"github.com/dolthub/doltgresql/server/plpgsql"
"github.com/dolthub/doltgresql/utils"
)

// Serialize returns the Collection as a byte slice. If the Collection is nil, then this returns a nil slice.
func (pgf *Collection) Serialize(ctx context.Context) ([]byte, error) {
if pgf == nil {
return nil, nil
}
pgf.mutex.Lock()
defer pgf.mutex.Unlock()

// Write all of the functions to the writer
writer := utils.NewWriter(256)
writer.VariableUint(0) // Version
funcIDs := utils.GetMapKeysSorted(pgf.funcMap)
writer.VariableUint(uint64(len(funcIDs)))
for _, funcID := range funcIDs {
f := pgf.funcMap[funcID]
writer.Id(f.ID.AsId())
writer.Id(f.ReturnType.AsId())
writer.StringSlice(f.ParameterNames)
writer.IdTypeSlice(f.ParameterTypes)
writer.Bool(f.Variadic)
writer.Bool(f.IsNonDeterministic)
writer.Bool(f.Strict)
// Write the operations
writer.VariableUint(uint64(len(f.Operations)))
for _, op := range f.Operations {
writer.Uint16(uint16(op.OpCode))
writer.String(op.PrimaryData)
writer.StringSlice(op.SecondaryData)
writer.String(op.Target)
writer.Int32(int32(op.Index))
}
}

return writer.Data(), nil
}

// Deserialize returns the Collection that was serialized in the byte slice. Returns an empty Collection if data is nil
// or empty.
func Deserialize(ctx context.Context, data []byte) (*Collection, error) {
if len(data) == 0 {
return &Collection{
funcMap: make(map[id.Function]*Function),
overloadMap: make(map[id.Function][]*Function),
mutex: &sync.Mutex{},
}, nil
}
funcMap := make(map[id.Function]*Function)
overloadMap := make(map[id.Function][]*Function)
reader := utils.NewReader(data)
version := reader.VariableUint()
if version != 0 {
return nil, errors.Errorf("version %d of functions is not supported, please upgrade the server", version)
}

// Read from the reader
numOfFunctions := reader.VariableUint()
for i := uint64(0); i < numOfFunctions; i++ {
f := &Function{}
f.ID = id.Function(reader.Id())
f.ReturnType = id.Type(reader.Id())
f.ParameterNames = reader.StringSlice()
f.ParameterTypes = reader.IdTypeSlice()
f.Variadic = reader.Bool()
f.IsNonDeterministic = reader.Bool()
f.Strict = reader.Bool()
// Read the operations
opCount := reader.VariableUint()
f.Operations = make([]plpgsql.InterpreterOperation, opCount)
for opIdx := uint64(0); opIdx < opCount; opIdx++ {
op := plpgsql.InterpreterOperation{}
op.OpCode = plpgsql.OpCode(reader.Uint16())
op.PrimaryData = reader.String()
op.SecondaryData = reader.StringSlice()
op.Target = reader.String()
op.Index = int(reader.Int32())
f.Operations[opIdx] = op
}
// Add the function to each map
funcMap[f.ID] = f
funcNameOnly := id.NewFunction(f.ID.SchemaName(), f.ID.FunctionName())
overloadMap[funcNameOnly] = append(overloadMap[funcNameOnly], f)
}
if !reader.IsEmpty() {
return nil, errors.Errorf("extra data found while deserializing functions")
}

// Return the deserialized object
return &Collection{
funcMap: funcMap,
overloadMap: overloadMap,
mutex: &sync.Mutex{},
}, nil
}
2 changes: 2 additions & 0 deletions core/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/dolthub/dolt/go/store/types"

"github.com/dolthub/doltgresql/core/id"
"github.com/dolthub/doltgresql/server/plpgsql"
)

// Init initializes this package.
Expand All @@ -27,5 +28,6 @@ func Init() {
doltdb.NewRootValue = newRootValue
types.DoltgresRootValueHumanReadableStringAtIndentationLevel = rootValueHumanReadableStringAtIndentationLevel
types.DoltgresRootValueWalkAddrs = rootValueWalkAddrs
plpgsql.GetTypesCollectionFromContext = GetTypesCollectionFromContext
id.RegisterListener(sequenceIDListener{}, id.Section_Table)
}
Loading