Skip to content

Serialize schema names in foreign key constraints #8461

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Oct 17, 2024
6 changes: 2 additions & 4 deletions go/libraries/doltcore/diff/table_deltas.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,7 @@ func GetTableDeltas(ctx context.Context, fromRoot, toRoot doltdb.RootValue) (del
func getFkParentSchs(ctx context.Context, root doltdb.RootValue, fks ...doltdb.ForeignKey) (map[doltdb.TableName]schema.Schema, error) {
schs := make(map[doltdb.TableName]schema.Schema)
for _, toFk := range fks {
// TODO: schema
toRefTable, _, ok, err := doltdb.GetTableInsensitive(ctx, root, doltdb.TableName{Name: toFk.ReferencedTableName})
toRefTable, _, ok, err := doltdb.GetTableInsensitive(ctx, root, toFk.ReferencedTableName)
if err != nil {
return nil, err
}
Expand All @@ -228,8 +227,7 @@ func getFkParentSchs(ctx context.Context, root doltdb.RootValue, fks ...doltdb.F
if err != nil {
return nil, err
}
// TODO: schema name
schs[doltdb.TableName{Name: toFk.ReferencedTableName}] = toRefSch
schs[toFk.ReferencedTableName] = toRefSch
}
return schs, nil
}
Expand Down
42 changes: 22 additions & 20 deletions go/libraries/doltcore/doltdb/foreign_key_coll.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,10 @@ const (
// ForeignKey is the complete, internal representation of a Foreign Key.
type ForeignKey struct {
Name string `noms:"name" json:"name"`
TableName string `noms:"tbl_name" json:"tbl_name"`
TableName TableName `noms:"tbl_name" json:"tbl_name"`
TableIndex string `noms:"tbl_index" json:"tbl_index"`
TableColumns []uint64 `noms:"tbl_cols" json:"tbl_cols"`
ReferencedTableName string `noms:"ref_tbl_name" json:"ref_tbl_name"`
ReferencedTableName TableName `noms:"ref_tbl_name" json:"ref_tbl_name"`
ReferencedTableIndex string `noms:"ref_tbl_index" json:"ref_tbl_index"`
ReferencedTableColumns []uint64 `noms:"ref_tbl_cols" json:"ref_tbl_cols"`
OnUpdate ForeignKeyReferentialAction `noms:"on_update" json:"on_update"`
Expand Down Expand Up @@ -167,7 +167,7 @@ func (fk ForeignKey) Equals(other ForeignKey, fkSchemasByName, otherSchemasByNam
}
for i, tag := range resolvedFK.TableColumns {
unresolvedColName := unresolvedFK.UnresolvedFKDetails.TableColumns[i]
resolvedSch, ok := resolvedSchemasByName[TableName{Name: resolvedFK.TableName}]
resolvedSch, ok := resolvedSchemasByName[resolvedFK.TableName]
if !ok {
return false
}
Expand All @@ -186,7 +186,7 @@ func (fk ForeignKey) Equals(other ForeignKey, fkSchemasByName, otherSchemasByNam
}
for i, tag := range resolvedFK.ReferencedTableColumns {
unresolvedColName := unresolvedFK.UnresolvedFKDetails.ReferencedTableColumns[i]
resolvedSch, ok := resolvedSchemasByName[TableName{Name: unresolvedFK.ReferencedTableName}]
resolvedSch, ok := resolvedSchemasByName[unresolvedFK.ReferencedTableName]
if !ok {
return false
}
Expand Down Expand Up @@ -246,6 +246,8 @@ func (fk ForeignKey) HashOf() (hash.Hash, error) {
_, err = bb.Write(t)
case uint64:
err = binary.Write(&bb, binary.LittleEndian, t)
case TableName:
_, err = bb.Write([]byte(t.String()))
default:
return hash.Hash{}, fmt.Errorf("unsupported type %T", t)
}
Expand Down Expand Up @@ -278,7 +280,7 @@ func CombinedHash(fks []ForeignKey) (hash.Hash, error) {

// IsSelfReferential returns whether the table declaring the foreign key is also referenced by the foreign key.
func (fk ForeignKey) IsSelfReferential() bool {
return strings.EqualFold(fk.TableName, fk.ReferencedTableName)
return fk.TableName.EqualFold(fk.ReferencedTableName)
}

// IsResolved returns whether the foreign key has been resolved.
Expand Down Expand Up @@ -543,13 +545,13 @@ OuterLoopResolved:
len(fk.ReferencedTableColumns) != len(existingFk.UnresolvedFKDetails.ReferencedTableColumns) {
continue
}
// TODO: schema name
tblSch, ok := allSchemas[TableName{Name: existingFk.TableName}]

tblSch, ok := allSchemas[existingFk.TableName]
if !ok {
continue
}
// TODO: schema name
refTblSch, ok := allSchemas[TableName{Name: existingFk.ReferencedTableName}]

refTblSch, ok := allSchemas[existingFk.ReferencedTableName]
if !ok {
continue
}
Expand Down Expand Up @@ -590,10 +592,10 @@ func (fkc *ForeignKeyCollection) Iter(cb func(fk ForeignKey) (stop bool, err err
// it will be present in both declaresFk and referencedByFk. Each array is sorted by name ascending.
func (fkc *ForeignKeyCollection) KeysForTable(tableName TableName) (declaredFk, referencedByFk []ForeignKey) {
for _, foreignKey := range fkc.foreignKeys {
if strings.EqualFold(foreignKey.TableName, tableName.Name) {
if foreignKey.TableName.EqualFold(tableName) {
declaredFk = append(declaredFk, foreignKey)
}
if strings.EqualFold(foreignKey.ReferencedTableName, tableName.Name) {
if foreignKey.ReferencedTableName.EqualFold(tableName) {
referencedByFk = append(referencedByFk, foreignKey)
}
}
Expand Down Expand Up @@ -643,9 +645,9 @@ func (fkc *ForeignKeyCollection) RemoveKeyByName(foreignKeyName string) bool {
func (fkc *ForeignKeyCollection) RemoveTables(ctx context.Context, tables ...TableName) error {
outgoing := NewTableNameSet(tables)
for _, fk := range fkc.foreignKeys {
// TODO: schema names
dropChild := outgoing.Contains(TableName{Name: fk.TableName})
dropParent := outgoing.Contains(TableName{Name: fk.ReferencedTableName})

dropChild := outgoing.Contains(fk.TableName)
dropParent := outgoing.Contains(fk.ReferencedTableName)
if dropParent && !dropChild {
return fmt.Errorf("unable to remove `%s` since it is referenced from table `%s`", fk.ReferencedTableName, fk.TableName)
}
Expand All @@ -666,8 +668,8 @@ func (fkc *ForeignKeyCollection) RemoveTables(ctx context.Context, tables ...Tab
func (fkc *ForeignKeyCollection) RemoveAndUnresolveTables(ctx context.Context, root RootValue, tables ...TableName) error {
outgoing := NewTableNameSet(tables)
for _, fk := range fkc.foreignKeys {
dropChild := outgoing.Contains(TableName{Name: fk.TableName})
dropParent := outgoing.Contains(TableName{Name: fk.ReferencedTableName})
dropChild := outgoing.Contains(fk.TableName)
dropParent := outgoing.Contains(fk.ReferencedTableName)
if dropParent && !dropChild {
if !fk.IsResolved() {
continue
Expand All @@ -682,7 +684,7 @@ func (fkc *ForeignKeyCollection) RemoveAndUnresolveTables(ctx context.Context, r
fk.UnresolvedFKDetails.TableColumns = make([]string, len(fk.TableColumns))
fk.UnresolvedFKDetails.ReferencedTableColumns = make([]string, len(fk.ReferencedTableColumns))

tbl, ok, err := root.GetTable(ctx, TableName{Name: fk.TableName})
tbl, ok, err := root.GetTable(ctx, fk.TableName)
if err != nil {
return err
}
Expand All @@ -703,7 +705,7 @@ func (fkc *ForeignKeyCollection) RemoveAndUnresolveTables(ctx context.Context, r
fk.UnresolvedFKDetails.TableColumns[i] = col.Name
}

refTbl, ok, err := root.GetTable(ctx, TableName{Name: fk.ReferencedTableName})
refTbl, ok, err := root.GetTable(ctx, fk.ReferencedTableName)
if err != nil {
return err
}
Expand Down Expand Up @@ -747,8 +749,8 @@ func (fkc *ForeignKeyCollection) RemoveAndUnresolveTables(ctx context.Context, r
}

// Tables returns the set of all tables that either declare a foreign key or are referenced by a foreign key.
func (fkc *ForeignKeyCollection) Tables() map[string]struct{} {
tables := make(map[string]struct{})
func (fkc *ForeignKeyCollection) Tables() map[TableName]struct{} {
tables := make(map[TableName]struct{})
for _, fk := range fkc.foreignKeys {
tables[fk.TableName] = struct{}{}
tables[fk.ReferencedTableName] = struct{}{}
Expand Down
18 changes: 14 additions & 4 deletions go/libraries/doltcore/doltdb/foreign_key_serialization.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,22 @@ func deserializeFlatbufferForeignKeys(msg types.SerialMessage) (*ForeignKeyColle
}
}

tableName, ok := decodeTableNameFromSerialization(string(fk.ChildTableName()))
if !ok {
return nil, fmt.Errorf("could not decode table name: %s", string(fk.ChildTableName()))
}

parentTableName, ok := decodeTableNameFromSerialization(string(fk.ParentTableName()))
if !ok {
return nil, fmt.Errorf("could not decode table name: %s", string(fk.ParentTableName()))
}

err := collection.AddKeys(ForeignKey{
Name: string(fk.Name()),
TableName: string(fk.ChildTableName()),
TableName: tableName,
TableIndex: string(fk.ChildTableIndex()),
TableColumns: childCols,
ReferencedTableName: string(fk.ParentTableName()),
ReferencedTableName: parentTableName,
ReferencedTableIndex: string(fk.ParentTableIndex()),
ReferencedTableColumns: parentCols,
OnUpdate: ForeignKeyReferentialAction(fk.OnUpdate()),
Expand Down Expand Up @@ -181,9 +191,9 @@ func serializeFlatbufferForeignKeys(fkc *ForeignKeyCollection) types.SerialMessa
}
parentCols = serializeUint64Vector(b, fk.ReferencedTableColumns)
childCols = serializeUint64Vector(b, fk.TableColumns)
parentTable = b.CreateString(fk.ReferencedTableName)
parentTable = b.CreateString(encodeTableNameForSerialization(fk.ReferencedTableName))
parentIndex = b.CreateString(fk.ReferencedTableIndex)
childTable = b.CreateString(fk.TableName)
childTable = b.CreateString(encodeTableNameForSerialization(fk.TableName))
childIndex = b.CreateString(fk.TableIndex)
foreignKeyName = b.CreateString(fk.Name)

Expand Down
11 changes: 7 additions & 4 deletions go/libraries/doltcore/doltdb/root_val.go
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,10 @@ func (tn TableName) String() string {
return tn.Schema + "." + tn.Name
}

func (tn TableName) EqualFold(o TableName) bool {
return strings.EqualFold(tn.Name, o.Name) && strings.EqualFold(tn.Schema, o.Schema)
}

// ToTableNames is a migration helper function that converts a slice of table names to a slice of TableName structs.
func ToTableNames(names []string, schemaName string) []TableName {
tbls := make([]TableName, len(names))
Expand Down Expand Up @@ -1070,14 +1074,13 @@ func ValidateForeignKeysOnSchemas(ctx context.Context, root RootValue) (RootValu
return nil, err
}

// TODO: schema name
allTablesSlice, err := root.GetTableNames(ctx, DefaultSchemaName)
allTablesSlice, err := UnionTableNames(ctx, root)
if err != nil {
return nil, err
}
allTablesSet := make(map[string]schema.Schema)
allTablesSet := make(map[TableName]schema.Schema)
for _, tableName := range allTablesSlice {
tbl, ok, err := root.GetTable(ctx, TableName{Name: tableName})
tbl, ok, err := root.GetTable(ctx, tableName)
if err != nil {
return nil, err
}
Expand Down
35 changes: 29 additions & 6 deletions go/libraries/doltcore/doltdb/root_val_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"bytes"
"context"
"fmt"
"strings"

flatbuffers "github.com/dolthub/flatbuffers/v23/go"

Expand Down Expand Up @@ -358,7 +359,7 @@ func (r fbRvStorage) EditTablesMap(ctx context.Context, vrw types.ValueReadWrite
if err != nil {
return nil, err
}
newaddr, err := am.Get(ctx, encodeTableNameForAddressMap(e.name))
newaddr, err := am.Get(ctx, encodeTableNameForSerialization(e.name))
if err != nil {
return nil, err
}
Expand All @@ -372,18 +373,18 @@ func (r fbRvStorage) EditTablesMap(ctx context.Context, vrw types.ValueReadWrite
if err != nil {
return nil, err
}
err = ae.Update(ctx, encodeTableNameForAddressMap(e.name), oldaddr)
err = ae.Update(ctx, encodeTableNameForSerialization(e.name), oldaddr)
if err != nil {
return nil, err
}
} else {
if e.ref == nil {
err := ae.Delete(ctx, encodeTableNameForAddressMap(e.name))
err := ae.Delete(ctx, encodeTableNameForSerialization(e.name))
if err != nil {
return nil, err
}
} else {
err := ae.Update(ctx, encodeTableNameForAddressMap(e.name), e.ref.TargetHash())
err := ae.Update(ctx, encodeTableNameForSerialization(e.name), e.ref.TargetHash())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -454,13 +455,35 @@ func serializeDatabaseSchemas(b *flatbuffers.Builder, dbSchemas []schema.Databas
return b.EndVector(len(offsets))
}

func encodeTableNameForAddressMap(name TableName) string {
// encodeTableNameForSerialization encodes a table name for serialization. Table names with no schema are encoded as
// just the bare table name. Table names with schemas are encoded by surrounding the schema name with null bytes and
// appending the table name.
func encodeTableNameForSerialization(name TableName) string {
if name.Schema == "" {
return name.Name
}
return fmt.Sprintf("\000%s\000%s", name.Schema, name.Name)
}

// decodeTableNameFromSerialization decodes a table name from a serialized string. See notes on serialization in
// |encodeTableNameForSerialization|
func decodeTableNameFromSerialization(encodedName string) (TableName, bool) {
if encodedName[0] != 0 {
return TableName{Name: encodedName}, true
} else if len(encodedName) >= 4 { // 2 null bytes plus at least one char for schema and table name
schemaEnd := strings.LastIndexByte(encodedName, 0)
return TableName{
Schema: encodedName[1:schemaEnd],
Name: encodedName[schemaEnd+1:],
}, true
}

// invalid encoding
return TableName{}, false
}

// decodeTableNameForAddressMap decodes a table name from an address map key, expecting a particular schema name. See
// notes on serialization in |encodeTableNameForSerialization|
func decodeTableNameForAddressMap(encodedName, schemaName string) (string, bool) {
if schemaName == "" && encodedName[0] != 0 {
return encodedName, true
Expand All @@ -478,7 +501,7 @@ type fbTableMap struct {
}

func (m fbTableMap) Get(ctx context.Context, name string) (hash.Hash, error) {
return m.AddressMap.Get(ctx, encodeTableNameForAddressMap(TableName{Name: name, Schema: m.schemaName}))
return m.AddressMap.Get(ctx, encodeTableNameForSerialization(TableName{Name: name, Schema: m.schemaName}))
}

func (m fbTableMap) Iter(ctx context.Context, cb func(string, hash.Hash) (bool, error)) error {
Expand Down
31 changes: 0 additions & 31 deletions go/libraries/doltcore/merge/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -483,37 +483,6 @@ func getConstraintViolationStats(ctx context.Context, root doltdb.RootValue, tbl
return nil
}

// MayHaveConstraintViolations returns whether the given roots may have constraint violations. For example, a fast
// forward merge that does not involve any tables with foreign key constraints or check constraints will not be able
// to generate constraint violations. Unique key constraint violations would be caught during the generation of the
// merged root, therefore it is not a factor for this function.
func MayHaveConstraintViolations(ctx context.Context, ancestor, merged doltdb.RootValue) (bool, error) {
ancTables, err := doltdb.MapTableHashes(ctx, ancestor)
if err != nil {
return false, err
}
mergedTables, err := doltdb.MapTableHashes(ctx, merged)
if err != nil {
return false, err
}
fkColl, err := merged.GetForeignKeyCollection(ctx)
if err != nil {
return false, err
}
tablesInFks := fkColl.Tables()
for tblName := range tablesInFks {
if ancHash, ok := ancTables[doltdb.TableName{Name: tblName}]; !ok {
// If a table used in a foreign key is new then it's treated as a change
return true, nil
} else if mergedHash, ok := mergedTables[doltdb.TableName{Name: tblName}]; !ok {
return false, fmt.Errorf("foreign key uses table '%s' but no hash can be found for this table", tblName)
} else if !ancHash.Equal(mergedHash) {
return true, nil
}
}
return false, nil
}

type ArtifactStatus struct {
SchemaConflictsTables []string
DataConflictTables []string
Expand Down
4 changes: 2 additions & 2 deletions go/libraries/doltcore/merge/merge_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -1014,7 +1014,7 @@ func fkCollSetDifference(
func pruneInvalidForeignKeys(ctx context.Context, fkColl *doltdb.ForeignKeyCollection, mergedRoot doltdb.RootValue) (pruned *doltdb.ForeignKeyCollection, err error) {
pruned, _ = doltdb.NewForeignKeyCollection()
err = fkColl.Iter(func(fk doltdb.ForeignKey) (stop bool, err error) {
parentTbl, ok, err := mergedRoot.GetTable(ctx, doltdb.TableName{Name: fk.ReferencedTableName})
parentTbl, ok, err := mergedRoot.GetTable(ctx, fk.ReferencedTableName)
if err != nil || !ok {
return false, err
}
Expand All @@ -1028,7 +1028,7 @@ func pruneInvalidForeignKeys(ctx context.Context, fkColl *doltdb.ForeignKeyColle
}
}

childTbl, ok, err := mergedRoot.GetTable(ctx, doltdb.TableName{Name: fk.TableName})
childTbl, ok, err := mergedRoot.GetTable(ctx, fk.TableName)
if err != nil || !ok {
return false, err
}
Expand Down
Loading
Loading