Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
2b62558
Loosened foreign key validation
zachmu Mar 4, 2025
dc944fe
Merge branch 'zachmu/text-primary-keys' into zachmu/foreign-key-text-…
zachmu Mar 6, 2025
2713aec
Added a comment
zachmu Mar 6, 2025
6d61577
new gms
zachmu Mar 6, 2025
509be11
updated go.sum
zachmu Mar 6, 2025
38c5b55
formatting
zachmu Mar 6, 2025
1d662c2
New test for type compatibility on foreign key
zachmu Mar 6, 2025
ecc9da4
checkpoint
zachmu Mar 7, 2025
03b3206
more tests
zachmu Mar 7, 2025
9f21414
Closer to correct semantics for foreign key opertions
zachmu Mar 8, 2025
145480f
More insert tests to rule out insert problems in foreign key tests
zachmu Mar 8, 2025
c683c29
Stub for foreign key type conversions
zachmu Mar 11, 2025
55024ab
bug fix for ignoring the constraint name in foreign key creation, mor…
zachmu Mar 11, 2025
7b072e5
Another bug fix for missing index / constraint name, more tests
zachmu Mar 11, 2025
8357d38
Bug fixes in test defns
zachmu Mar 12, 2025
3524892
Tests for out of bounds type conversions
zachmu Mar 12, 2025
291a495
new test, fixed typos
zachmu Mar 12, 2025
7996632
Added a comment about current logic
zachmu Mar 12, 2025
0a1eddc
formatting
zachmu Mar 12, 2025
9698ee9
Merge branch 'main' into zachmu/foreign-key-types
zachmu Mar 12, 2025
903b23b
new gms
zachmu Mar 12, 2025
20f5f87
PR feedback and bug fix for type conversion
zachmu Mar 12, 2025
5df4f2c
Bug fix for key name generation
zachmu Mar 12, 2025
1a7d4e1
Another name generation bug fix
zachmu Mar 12, 2025
6e9876b
new gms
zachmu Mar 12, 2025
e7c4b49
formatting
zachmu Mar 12, 2025
b7a1215
PR feedback
zachmu Mar 13, 2025
987b557
formatting
zachmu Mar 13, 2025
211d4d9
merge main, fix a couple tests
zachmu Mar 13, 2025
73dde67
new gms
zachmu Mar 13, 2025
96b1809
new gms
zachmu Mar 13, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ require (
github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20241119094239-f4e529af734d
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2
github.com/dolthub/go-icu-regex v0.0.0-20250303123116-549b8d7cad00
github.com/dolthub/go-mysql-server v0.19.1-0.20250311212537-909b08b2a5d3
github.com/dolthub/go-mysql-server v0.19.1-0.20250313005113-73b3865b4145
github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216
github.com/dolthub/vitess v0.0.0-20250304211657-920ca9ec2b9a
github.com/fatih/color v1.13.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U=
github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
github.com/dolthub/go-icu-regex v0.0.0-20250303123116-549b8d7cad00 h1:rh2ij2yTYKJWlX+c8XRg4H5OzqPewbU1lPK8pcfVmx8=
github.com/dolthub/go-icu-regex v0.0.0-20250303123116-549b8d7cad00/go.mod h1:ylU4XjUpsMcvl/BKeRRMXSH7e7WBrPXdSLvnRJYrxEA=
github.com/dolthub/go-mysql-server v0.19.1-0.20250311212537-909b08b2a5d3 h1:2Ae4/qoJSIx/rtcyqm2uSyO1dCJGDFuyx7LKTuujb00=
github.com/dolthub/go-mysql-server v0.19.1-0.20250311212537-909b08b2a5d3/go.mod h1:yr+Vv47/YLOKMgiEY+QxHTlbIVpTuiVtkEZ5l+xruY4=
github.com/dolthub/go-mysql-server v0.19.1-0.20250313005113-73b3865b4145 h1:ye9o0LXu3IuBSp5GA45s3IATkhtEMEuqHvvjIBTm6eI=
github.com/dolthub/go-mysql-server v0.19.1-0.20250313005113-73b3865b4145/go.mod h1:yr+Vv47/YLOKMgiEY+QxHTlbIVpTuiVtkEZ5l+xruY4=
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI=
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q=
github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE=
Expand Down
74 changes: 74 additions & 0 deletions server/analyzer/foreign_key.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package analyzer

import (
"strings"

"github.com/cockroachdb/errors"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"

"github.com/dolthub/doltgresql/server/functions/framework"
"github.com/dolthub/doltgresql/server/types"
)

// validateForeignKeyDefinition validates that the given foreign key definition is valid for creation
func validateForeignKeyDefinition(ctx *sql.Context, fkDef sql.ForeignKeyConstraint, cols map[string]*sql.Column, parentCols map[string]*sql.Column) error {
for i := range fkDef.Columns {
col := cols[strings.ToLower(fkDef.Columns[i])]
parentCol := parentCols[strings.ToLower(fkDef.ParentColumns[i])]
if !foreignKeyComparableTypes(col.Type, parentCol.Type) {
return errors.Errorf("Key columns %q and %q are of incompatible types: %s and %s", col.Name, parentCol.Name, col.Type.String(), parentCol.Type.String())
}
}
return nil
}

// foreignKeyComparableTypes returns whether the two given types are able to be used as parent/child columns in a
// foreign key.
func foreignKeyComparableTypes(from sql.Type, to sql.Type) bool {
dtFrom, ok := from.(*types.DoltgresType)
if !ok {
return false // should never be possible
}

dtTo, ok := to.(*types.DoltgresType)
if !ok {
return false // should never be possible
}

if dtFrom.Equals(dtTo) {
return true
}

fromLiteral := expression.NewLiteral(dtFrom.Zero(), from)
toLiteral := expression.NewLiteral(dtTo.Zero(), to)

// a foreign key between two different types is valid if there is an equality operator on the two types
// TODO: there are some subtleties in postgres not captured by this logic, e.g. a foreign key from double -> int
// is valid, but the reverse is not. This works fine, but is more permissive than postgres is.
eq := framework.GetBinaryFunction(framework.Operator_BinaryEqual).Compile("=", fromLiteral, toLiteral)
if eq == nil || eq.StashedError() != nil {
return false
}

// Additionally, we need to be able to convert freely between the two types in both directions, since we do this
// during the process of enforcing the constraints
forwardConversion := types.GetAssignmentCast(dtFrom, dtTo)
reverseConversion := types.GetAssignmentCast(dtTo, dtFrom)

return forwardConversion != nil && reverseConversion != nil
}
10 changes: 10 additions & 0 deletions server/analyzer/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package analyzer

import (
"github.com/dolthub/go-mysql-server/sql/analyzer"
"github.com/dolthub/go-mysql-server/sql/plan"
)

// IDs are basically arbitrary, we just need to ensure that they do not conflict with existing IDs
Expand Down Expand Up @@ -91,6 +92,15 @@ func Init() {
analyzer.Rule{Id: ruleId_AddDomainConstraintsToCasts, Apply: AddDomainConstraintsToCasts},
analyzer.Rule{Id: ruleId_ReplaceNode, Apply: ReplaceNode},
analyzer.Rule{Id: ruleId_InsertContextRootFinalizer, Apply: InsertContextRootFinalizer})

initEngine()
}

func initEngine() {
// This technically takes place at execution time rather than as part of analysis, but we don't have a better
// place to put it. Our foreign key validation logic is different from MySQL's, and since it's not an analyzer rule
// we can't swap out a rule like the rest of the logic in this packge, we have to do a function swap.
plan.ValidateForeignKeyDefinition = validateForeignKeyDefinition
}

// insertAnalyzerRules inserts the given rule(s) before or after the given analyzer.RuleId, returning an updated slice.
Expand Down
13 changes: 12 additions & 1 deletion server/ast/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,10 @@ func nodeAlterTableAddConstraint(
IfExists: ifExists,
TableSpec: &vitess.TableSpec{
Constraints: []*vitess.ConstraintDefinition{
{Details: foreignKeyDefinition},
{
Name: bareIdentifier(constraintDef.Name),
Details: foreignKeyDefinition,
},
},
},
}, nil
Expand All @@ -192,6 +195,14 @@ func nodeAlterTableAddConstraint(
}
}

// bareIdentifier returns the string representation of a name without any quoting
// (quoted is the default Name.String() behavior)
func bareIdentifier(id tree.Name) string {
ctx := tree.NewFmtCtx(tree.FmtBareIdentifiers)
id.Format(ctx)
return ctx.CloseAndGetString()
}

// nodeAlterTableAddColumn converts a tree.AlterTableAddColumn instance into an equivalent vitess.DDL instance.
func nodeAlterTableAddColumn(ctx *Context, node *tree.AlterTableAddColumn, tableName vitess.TableName, ifExists bool) (*vitess.DDL, error) {
if node.IfNotExists {
Expand Down
1 change: 1 addition & 0 deletions server/ast/constraint_table_def.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ func nodeUniqueConstraintTableDef(
Table: tableName,
IfExists: ifExists,
IndexSpec: &vitess.IndexSpec{
ToName: vitess.NewColIdent(bareIdentifier(node.Name)),
Action: "create",
Type: indexType,
Columns: columns,
Expand Down
11 changes: 11 additions & 0 deletions server/cast/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@

package cast

import (
"github.com/dolthub/doltgresql/server/functions/framework"
"github.com/dolthub/doltgresql/server/types"
)

// Init initializes all casts in this package.
func Init() {
initBool()
Expand All @@ -40,4 +45,10 @@ func Init() {
initTimestampTZ()
initTimeTZ()
initVarChar()

// This is a hack to get around import cycles. The types package needs these references for type conversions in
// some contexts
types.GetImplicitCast = framework.GetImplicitCast
types.GetAssignmentCast = framework.GetAssignmentCast
types.GetExplicitCast = framework.GetExplicitCast
}
28 changes: 12 additions & 16 deletions server/functions/framework/cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,22 @@ import (

// TODO: Right now, all casts are global. We should decide how to handle this in the presence of branches, sessions, etc.

// TypeCastFunction is a function that takes a value of a particular kind of type, and returns it as another kind of type.
// The targetType given should match the "To" type used to obtain the cast.
type TypeCastFunction func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error)

// getCastFunction is used to recursively call the cast function for when the inner logic sees that it has two array
// types. This sidesteps providing
type getCastFunction func(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresType) TypeCastFunction
type getCastFunction func(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresType) pgtypes.TypeCastFunction

// TypeCast is used to cast from one type to another.
type TypeCast struct {
FromType *pgtypes.DoltgresType
ToType *pgtypes.DoltgresType
Function TypeCastFunction
Function pgtypes.TypeCastFunction
}

// explicitTypeCastMutex is used to lock the explicit type cast map and array when writing.
var explicitTypeCastMutex = &sync.RWMutex{}

// explicitTypeCastsMap is a map that maps: from -> to -> function.
var explicitTypeCastsMap = map[id.Type]map[id.Type]TypeCastFunction{}
var explicitTypeCastsMap = map[id.Type]map[id.Type]pgtypes.TypeCastFunction{}

// explicitTypeCastsArray is a slice that holds all registered explicit casts from the given type.
var explicitTypeCastsArray = map[id.Type][]*pgtypes.DoltgresType{}
Expand All @@ -54,7 +50,7 @@ var explicitTypeCastsArray = map[id.Type][]*pgtypes.DoltgresType{}
var assignmentTypeCastMutex = &sync.RWMutex{}

// assignmentTypeCastsMap is a map that maps: from -> to -> function.
var assignmentTypeCastsMap = map[id.Type]map[id.Type]TypeCastFunction{}
var assignmentTypeCastsMap = map[id.Type]map[id.Type]pgtypes.TypeCastFunction{}

// assignmentTypeCastsArray is a slice that holds all registered assignment casts from the given type.
var assignmentTypeCastsArray = map[id.Type][]*pgtypes.DoltgresType{}
Expand All @@ -63,7 +59,7 @@ var assignmentTypeCastsArray = map[id.Type][]*pgtypes.DoltgresType{}
var implicitTypeCastMutex = &sync.RWMutex{}

// implicitTypeCastsMap is a map that maps: from -> to -> function.
var implicitTypeCastsMap = map[id.Type]map[id.Type]TypeCastFunction{}
var implicitTypeCastsMap = map[id.Type]map[id.Type]pgtypes.TypeCastFunction{}

// implicitTypeCastsArray is a slice that holds all registered implicit casts from the given type.
var implicitTypeCastsArray = map[id.Type][]*pgtypes.DoltgresType{}
Expand Down Expand Up @@ -126,7 +122,7 @@ func GetPotentialImplicitCasts(fromType id.Type) []*pgtypes.DoltgresType {

// GetExplicitCast returns the explicit type cast function that will cast the "from" type to the "to" type. Returns nil
// if such a cast is not valid.
func GetExplicitCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresType) TypeCastFunction {
func GetExplicitCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresType) pgtypes.TypeCastFunction {
if tcf := getCast(explicitTypeCastMutex, explicitTypeCastsMap, fromType, toType, GetExplicitCast); tcf != nil {
return tcf
} else if tcf = getCast(assignmentTypeCastMutex, assignmentTypeCastsMap, fromType, toType, GetExplicitCast); tcf != nil {
Expand Down Expand Up @@ -170,7 +166,7 @@ func GetExplicitCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresTyp

// GetAssignmentCast returns the assignment type cast function that will cast the "from" type to the "to" type. Returns
// nil if such a cast is not valid.
func GetAssignmentCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresType) TypeCastFunction {
func GetAssignmentCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresType) pgtypes.TypeCastFunction {
if tcf := getCast(assignmentTypeCastMutex, assignmentTypeCastsMap, fromType, toType, GetAssignmentCast); tcf != nil {
return tcf
} else if tcf = getCast(implicitTypeCastMutex, implicitTypeCastsMap, fromType, toType, GetAssignmentCast); tcf != nil {
Expand Down Expand Up @@ -199,7 +195,7 @@ func GetAssignmentCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresT

// GetImplicitCast returns the implicit type cast function that will cast the "from" type to the "to" type. Returns nil
// if such a cast is not valid.
func GetImplicitCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresType) TypeCastFunction {
func GetImplicitCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresType) pgtypes.TypeCastFunction {
if tcf := getCast(implicitTypeCastMutex, implicitTypeCastsMap, fromType, toType, GetImplicitCast); tcf != nil {
return tcf
}
Expand All @@ -213,14 +209,14 @@ func GetImplicitCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresTyp

// addTypeCast registers the given type cast.
func addTypeCast(mutex *sync.RWMutex,
castMap map[id.Type]map[id.Type]TypeCastFunction,
castMap map[id.Type]map[id.Type]pgtypes.TypeCastFunction,
castArray map[id.Type][]*pgtypes.DoltgresType, cast TypeCast) error {
mutex.Lock()
defer mutex.Unlock()

toMap, ok := castMap[cast.FromType.ID]
if !ok {
toMap = map[id.Type]TypeCastFunction{}
toMap = map[id.Type]pgtypes.TypeCastFunction{}
castMap[cast.FromType.ID] = toMap
castArray[cast.FromType.ID] = nil
}
Expand All @@ -244,8 +240,8 @@ func getPotentialCasts(mutex *sync.RWMutex, castArray map[id.Type][]*pgtypes.Dol
// getCast returns the type cast function that will cast the "from" type to the "to" type. Returns nil if such a cast is
// not valid.
func getCast(mutex *sync.RWMutex,
castMap map[id.Type]map[id.Type]TypeCastFunction,
fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresType, outerFunc getCastFunction) TypeCastFunction {
castMap map[id.Type]map[id.Type]pgtypes.TypeCastFunction,
fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresType, outerFunc getCastFunction) pgtypes.TypeCastFunction {
mutex.RLock()
defer mutex.RUnlock()

Expand Down
4 changes: 2 additions & 2 deletions server/functions/framework/compiled_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ func (c *CompiledFunction) resolveOperator(argTypes []*pgtypes.DoltgresType, ove
rightUnknownType := argTypes[1].ID == pgtypes.Unknown.ID
if (leftUnknownType && !rightUnknownType) || (!leftUnknownType && rightUnknownType) {
var typ *pgtypes.DoltgresType
casts := []TypeCastFunction{identityCast, identityCast}
casts := []pgtypes.TypeCastFunction{identityCast, identityCast}
if leftUnknownType {
casts[0] = UnknownLiteralCast
typ = argTypes[1]
Expand Down Expand Up @@ -484,7 +484,7 @@ func (c *CompiledFunction) typeCompatibleOverloads(fnOverloads []Overload, argTy
var compatible []overloadMatch
for _, overload := range fnOverloads {
isConvertible := true
overloadCasts := make([]TypeCastFunction, len(argTypes))
overloadCasts := make([]pgtypes.TypeCastFunction, len(argTypes))
// Polymorphic parameters must be gathered so that we can later verify that they all have matching base types
var polymorphicParameters []*pgtypes.DoltgresType
var polymorphicTargets []*pgtypes.DoltgresType
Expand Down
2 changes: 1 addition & 1 deletion server/functions/framework/overloads.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func (o *Overload) coalesceVariadicValues(returnValues []any) []any {
// as the type cast functions required to convert every argument to its appropriate parameter type
type overloadMatch struct {
params Overload
casts []TypeCastFunction
casts []pgtypes.TypeCastFunction
}

// Valid returns whether this overload is valid (has a callable function)
Expand Down
31 changes: 31 additions & 0 deletions server/types/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,33 @@ func (t *DoltgresType) Convert(v interface{}) (interface{}, sql.ConvertInRange,
return nil, sql.OutOfRange, ErrUnhandledType.New(t.String(), v)
}

// GetImplicitCast is a reference to the implicit cast logic in the functions/framework package, which we can't use
// here due to import cycles
var GetImplicitCast func(fromType *DoltgresType, toType *DoltgresType) TypeCastFunction

// GetAssignmentCast is a reference to the assignment cast logic in the functions/framework package, which we can't use
// here due to import cycles
var GetAssignmentCast func(fromType *DoltgresType, toType *DoltgresType) TypeCastFunction

// GetExplicitCast is a reference to the explicit cast logic in the functions/framework package, which we can't use
// here due to import cycles
var GetExplicitCast func(fromType *DoltgresType, toType *DoltgresType) TypeCastFunction

// ConvertToType implements the types.ExtendedType interface.
func (t *DoltgresType) ConvertToType(ctx *sql.Context, typ types.ExtendedType, val any) (any, error) {
dt, ok := typ.(*DoltgresType)
if !ok {
return nil, errors.Errorf("expected DoltgresType, got %T", typ)
}

castFn := GetAssignmentCast(dt, t)
if castFn == nil {
return nil, errors.Errorf("no assignment cast from %s to %s", dt.Name(), t.Name())
}

return castFn(ctx, val, t)
}

// DomainUnderlyingBaseType returns an underlying base type of this domain type.
// It can be a nested domain type, so it recursively searches for a valid base type.
func (t *DoltgresType) DomainUnderlyingBaseType() *DoltgresType {
Expand Down Expand Up @@ -851,3 +878,7 @@ func (t *DoltgresType) DeserializeValue(ctx context.Context, val []byte) (any, e
return globalFunctionRegistry.GetFunction(t.ReceiveFunc).CallVariadic(nil, val)
}
}

// TypeCastFunction is a function that takes a value of a particular kind of type, and returns it as another kind of type.
// The targetType given should match the "To" type used to obtain the cast.
type TypeCastFunction func(ctx *sql.Context, val any, targetType *DoltgresType) (any, error)
Loading
Loading