Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions core/table_to_dolt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package core

import (
"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/plan"
)

// SQLNodeToDoltTable takes a sql.Node and returns a *sqle.DoltTable if either the node is a Dolt table, or it is a
// wrapper or container that holds a Dolt table. Returns nil if a Dolt table could not be found. If the node is not a
// sql.Table, then this will return nil.
func SQLNodeToDoltTable(n sql.Node) *sqle.DoltTable {
tbl, ok := n.(sql.Table)
if !ok {
return nil
}
return SQLTableToDoltTable(tbl)
}

// SQLTableToDoltTable takes a sql.Table and returns a *sqle.DoltTable if either the table is a Dolt table, or it is a
// wrapper or container that holds a Dolt table. Returns nil if a Dolt table could not be found.
func SQLTableToDoltTable(tbl sql.Table) *sqle.DoltTable {
switch t := tbl.(type) {
case *plan.ResolvedTable:
return SQLTableToDoltTable(t.Table)
case *plan.ProcessTable:
return SQLTableToDoltTable(t.Table)
case *plan.IndexedTableAccess:
return SQLTableToDoltTable(t.Table)
case *plan.ProcedureResolvedTable:
return SQLTableToDoltTable(t.ResolvedTable.Table)
case *sqle.WritableIndexedDoltTable:
return t.WritableDoltTable.DoltTable
case *sqle.IndexedDoltTable:
return t.DoltTable
case *sqle.AlterableDoltTable:
return t.WritableDoltTable.DoltTable
case *sqle.WritableDoltTable:
return t.DoltTable
case *sqle.DoltTable:
return t
default:
if wrapper, ok := tbl.(sql.TableWrapper); ok {
return SQLTableToDoltTable(wrapper.Underlying())
}
return nil
}
}
2 changes: 1 addition & 1 deletion server/ast/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func nodeCall(ctx *Context, node *tree.Call) (vitess.Statement, error) {
if node.Procedure.WindowDef != nil {
return nil, errors.Errorf("procedure window definitions are not yet supported")
}
if node.Procedure.AggType != tree.GeneralAgg {
if node.Procedure.AggType == tree.OrderedSetAgg {
return nil, errors.Errorf("procedure aggregation is not yet supported")
}
if len(node.Procedure.OrderBy) > 0 {
Expand Down
2 changes: 2 additions & 0 deletions server/ast/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ func Convert(postgresStmt parser.Statement) (vitess.Statement, error) {
return nodeDropFunction(ctx, stmt)
case *tree.DropIndex:
return nodeDropIndex(ctx, stmt)
case *tree.DropProcedure:
return nodeDropProcedure(ctx, stmt)
case *tree.DropRole:
return nodeDropRole(ctx, stmt)
case *tree.DropSchema:
Expand Down
47 changes: 47 additions & 0 deletions server/ast/drop_procedure.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ast

import (
"fmt"

vitess "github.com/dolthub/vitess/go/vt/sqlparser"

"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
pgnodes "github.com/dolthub/doltgresql/server/node"
)

// nodeDropProcedure handles *tree.DropProcedure nodes.
func nodeDropProcedure(_ *Context, node *tree.DropProcedure) (vitess.Statement, error) {
if node == nil {
return nil, nil
}

if node.DropBehavior == tree.DropCascade {
return nil, fmt.Errorf("DROP PROCEDURE with CASCADE is not supported yet")
}

if len(node.Procedures) == 0 {
return nil, fmt.Errorf("no function name specified for DROP PROCEDURE")
}

return vitess.InjectedStatement{
Statement: pgnodes.NewDropProcedure(
node.IfExists,
node.Procedures,
node.DropBehavior == tree.DropCascade),
Children: nil,
}, nil
}
83 changes: 68 additions & 15 deletions server/functions/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
package functions

import (
"fmt"

"github.com/cockroachdb/errors"
"github.com/dolthub/go-mysql-server/sql"

"github.com/dolthub/doltgresql/core"
"github.com/dolthub/doltgresql/core/id"
"github.com/dolthub/doltgresql/utils"

"github.com/dolthub/doltgresql/server/functions/framework"
pgtypes "github.com/dolthub/doltgresql/server/types"
)
Expand All @@ -40,7 +43,7 @@ var record_in = framework.Function3{
Parameters: [3]*pgtypes.DoltgresType{pgtypes.Cstring, pgtypes.Oid, pgtypes.Int32},
Strict: true,
Callable: func(ctx *sql.Context, _ [4]*pgtypes.DoltgresType, val1, val2, val3 any) (any, error) {
return nil, fmt.Errorf("record_in not implemented")
return nil, errors.Errorf("record_in not implemented")
},
}

Expand All @@ -53,24 +56,66 @@ var record_out = framework.Function1{
Callable: func(ctx *sql.Context, t [2]*pgtypes.DoltgresType, val any) (any, error) {
values, ok := val.([]pgtypes.RecordValue)
if !ok {
return nil, fmt.Errorf("expected []RecordValue, but got %T", val)
return nil, errors.Errorf("expected []RecordValue, but got %T", val)
}
return pgtypes.RecordToString(ctx, values)
},
}

// record_recv represents the PostgreSQL function of record type IO receive.
// record_recv represents the PostgreSQL function of record type IO receive. The input of this function is expected to
// be the output of record_send.
var record_recv = framework.Function3{
Name: "record_recv",
Return: pgtypes.Record,
Parameters: [3]*pgtypes.DoltgresType{pgtypes.Internal, pgtypes.Oid, pgtypes.Int32},
Strict: true,
Callable: func(ctx *sql.Context, _ [4]*pgtypes.DoltgresType, val1, val2, val3 any) (any, error) {
return nil, fmt.Errorf("record_recv not implemented")
data, ok := val1.([]byte)
if !ok {
return nil, errors.Errorf("expected []byte, but got `%T`", val1)
}
typeColl, err := core.GetTypesCollectionFromContext(ctx)
if err != nil {
return nil, err
}
reader := utils.NewReader(data)
version := reader.Byte()
switch version {
case 0:
valuesLen := reader.VariableUint()
values := make([]pgtypes.RecordValue, valuesLen)
for i := uint64(0); i < valuesLen; i++ {
typeId := id.Type(reader.Id())
valueData := reader.ByteSlice()
dgtype, err := typeColl.GetType(ctx, typeId)
if err != nil {
return nil, err
}
if dgtype == nil {
return nil, errors.Errorf("record_recv encountered type `%s.%s` which could not be found",
typeId.SchemaName(), typeId.TypeName())
}
value, err := dgtype.DeserializeValue(ctx, valueData)
if err != nil {
return nil, err
}
values[i] = pgtypes.RecordValue{
Value: value,
Type: dgtype,
}
}
if reader.RemainingBytes() > 0 {
return nil, errors.New("record_recv encountered extra data during deserialization")
}
return values, nil
default:
return nil, errors.Errorf("version %d of record serialization is not supported, please upgrade the server", version)
}
},
}

// record_send represents the PostgreSQL function of record type IO send.
// record_send represents the PostgreSQL function of record type IO send. The output of this function is expected to
// be the input of record_recv.
var record_send = framework.Function1{
Name: "record_send",
Return: pgtypes.Bytea,
Expand All @@ -79,16 +124,24 @@ var record_send = framework.Function1{
Callable: func(ctx *sql.Context, t [2]*pgtypes.DoltgresType, val any) (any, error) {
values, ok := val.([]pgtypes.RecordValue)
if !ok {
return nil, fmt.Errorf("expected []RecordValue, but got %T", val)
return nil, errors.Errorf("expected []RecordValue, but got %T", val)
}
// TODO: converting from a string back to the record doesn't work as we lose type information, so we need to
// figure out how to retain this information
output, err := pgtypes.RecordToString(ctx, values)
if err != nil {
return nil, err
writer := utils.NewWriter(uint64(16 * len(values)))
writer.Byte(0) // Version
writer.VariableUint(uint64(len(values)))
for _, value := range values {
dgtype, ok := value.Type.(*pgtypes.DoltgresType)
if !ok {
return nil, errors.Errorf("record_send only supports Doltgres types, but received `%T`", value.Type)
}
valBytes, err := dgtype.SerializeValue(ctx, value.Value)
if err != nil {
return nil, err
}
writer.Id(dgtype.ID.AsId())
writer.ByteSlice(valBytes)
}

return []byte(output.(string)), nil
return writer.Data(), nil
},
}

Expand Down
Loading
Loading