Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 36 additions & 17 deletions core/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (
// and may be refreshed at any point, including during the middle of a query. Callers should not assume that
// data stored in contextValues is persisted, and other types of data should not be added to contextValues.
type contextValues struct {
collection *sequences.Collection
seqs *sequences.Collection
types *typecollection.TypeCollection
funcs *functions.Collection
pgCatalogCache any
Expand Down Expand Up @@ -206,12 +206,17 @@ func GetFunctionsCollectionFromContext(ctx *sql.Context) (*functions.Collection,
if err != nil {
return nil, err
}
_, root, err := getRootFromContext(ctx)
if err != nil {
return nil, err
}
if cv.funcs == nil {
_, root, err := getRootFromContext(ctx)
cv.funcs, err = functions.LoadFunctions(ctx, root)
if err != nil {
return nil, err
}
cv.funcs, err = root.GetFunctions(ctx)
} else if cv.funcs.DiffersFrom(ctx, root) {
cv.funcs, err = functions.LoadFunctions(ctx, root)
if err != nil {
return nil, err
}
Expand All @@ -226,17 +231,17 @@ func GetSequencesCollectionFromContext(ctx *sql.Context) (*sequences.Collection,
if err != nil {
return nil, err
}
if cv.collection == nil {
if cv.seqs == nil {
_, root, err := getRootFromContext(ctx)
if err != nil {
return nil, err
}
cv.collection, err = root.GetSequences(ctx)
cv.seqs, err = sequences.LoadSequences(ctx, root)
if err != nil {
return nil, err
}
}
return cv.collection, nil
return cv.seqs, nil
}

// GetTypesCollectionFromContext returns the given type collection from the context.
Expand All @@ -251,7 +256,7 @@ func GetTypesCollectionFromContext(ctx *sql.Context) (*typecollection.TypeCollec
if err != nil {
return nil, err
}
cv.types, err = root.GetTypes(ctx)
cv.types, err = typecollection.LoadTypes(ctx, root)
if err != nil {
return nil, err
}
Expand All @@ -270,22 +275,36 @@ func CloseContextRootFinalizer(ctx *sql.Context) error {
if !ok {
return nil
}
if cv.collection == nil {
return nil
}
session, root, err := getRootFromContext(ctx)
if err != nil {
return err
}
newRoot, err := root.PutSequences(ctx, cv.collection)
if err != nil {
return err
newRoot := root
if cv.seqs != nil {
retRoot, err := cv.seqs.UpdateRoot(ctx, newRoot)
if err != nil {
return err
}
newRoot = retRoot.(*RootValue)
cv.seqs = nil
}
newRoot, err = newRoot.PutFunctions(ctx, cv.funcs)
if err != nil {
return err
if cv.funcs != nil && cv.funcs.DiffersFrom(ctx, root) {
retRoot, err := cv.funcs.UpdateRoot(ctx, newRoot)
if err != nil {
return err
}
newRoot = retRoot.(*RootValue)
cv.funcs = nil
}
if cv.types != nil {
retRoot, err := cv.types.UpdateRoot(ctx, newRoot)
if err != nil {
return err
}
newRoot = retRoot.(*RootValue)
cv.types = nil
}
if newRoot != nil {
if newRoot != root {
if err = session.SetWorkingRoot(ctx, ctx.GetCurrentDatabase(), newRoot); err != nil {
// TODO: We need a way to see if the session has a writeable working root
// (new interface method on session probably), and avoid setting it if so
Expand Down
Loading