diff --git a/go/vt/vtgate/executor_test.go b/go/vt/vtgate/executor_test.go index 9e64e75ea0d..2f9fd2f80ac 100644 --- a/go/vt/vtgate/executor_test.go +++ b/go/vt/vtgate/executor_test.go @@ -1185,6 +1185,32 @@ func TestExecutorDDL(t *testing.T) { } } +func TestExecutorDDLFk(t *testing.T) { + executor, _, _, sbc := createExecutorEnv() + + mName := "TestExecutorDDLFk" + stmts := []string{ + "create table t1(id bigint primary key, foreign key (id) references t2(id))", + "alter table t2 add foreign key (id) references t1(id) on delete cascade", + } + + for _, stmt := range stmts { + for _, fkMode := range []string{"allow", "disallow"} { + t.Run(stmt+fkMode, func(t *testing.T) { + sbc.ExecCount.Set(0) + *foreignKeyMode = fkMode + _, err := executor.Execute(ctx, mName, NewSafeSession(&vtgatepb.Session{TargetString: KsTestUnsharded}), stmt, nil) + if fkMode == "allow" { + require.NoError(t, err) + require.EqualValues(t, 1, sbc.ExecCount.Get()) + } else { + require.EqualError(t, err, "foreign key constraint is not allowed") + } + }) + } + } +} + func TestExecutorAlterVSchemaKeyspace(t *testing.T) { *vschemaacl.AuthorizedDDLUsers = "%" defer func() { diff --git a/go/vt/vtgate/planbuilder/builder.go b/go/vt/vtgate/planbuilder/builder.go index af3bcc188ff..23b6a968c02 100644 --- a/go/vt/vtgate/planbuilder/builder.go +++ b/go/vt/vtgate/planbuilder/builder.go @@ -61,6 +61,9 @@ type ContextVSchema interface { // This will let the user know that they are using something // that could become a problem if they move to a sharded keyspace WarnUnshardedOnly(format string, params ...interface{}) + + // ForeignKeyMode returns the foreign_key flag value + ForeignKeyMode() string } // PlannerVersion is an alias here to make the code more readable diff --git a/go/vt/vtgate/planbuilder/ddl.go b/go/vt/vtgate/planbuilder/ddl.go index be3c6bb4b74..04e6ddfcd21 100644 --- a/go/vt/vtgate/planbuilder/ddl.go +++ b/go/vt/vtgate/planbuilder/ddl.go @@ -1,13 +1,13 @@ package planbuilder import ( - vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" - "vitess.io/vitess/go/vt/vterrors" - "vitess.io/vitess/go/vt/vtgate/vindexes" - "vitess.io/vitess/go/vt/key" "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine" + "vitess.io/vitess/go/vt/vtgate/vindexes" + + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" ) // Error messages for CreateView queries @@ -17,6 +17,33 @@ const ( DifferentDestinations string = "Tables or Views specified in the query do not belong to the same destination" ) +type fkStrategy int + +const ( + fkAllow fkStrategy = iota + fkDisallow +) + +var fkStrategyMap = map[string]fkStrategy{ + "allow": fkAllow, + "disallow": fkDisallow, +} + +type fkContraint struct { + found bool +} + +func (fk *fkContraint) FkWalk(node sqlparser.SQLNode) (kontinue bool, err error) { + switch node.(type) { + case *sqlparser.CreateTable, *sqlparser.AlterTable, + *sqlparser.TableSpec, *sqlparser.AddConstraintDefinition, *sqlparser.ConstraintDefinition: + return true, nil + case *sqlparser.ForeignKeyDefinition: + fk.found = true + } + return false, nil +} + // buildGeneralDDLPlan builds a general DDL plan, which can be either normal DDL or online DDL. // The two behave compeltely differently, and have two very different primitives. // We want to be able to dynamically choose between normal/online plans according to Session settings. @@ -55,44 +82,40 @@ func buildDDLPlans(sql string, ddlStatement sqlparser.DDLStatement, reservedVars switch ddl := ddlStatement.(type) { case *sqlparser.AlterTable, *sqlparser.TruncateTable: - // For Alter Table and other statements, the table must already exist - // We should find the target of the query from this tables location - destination, keyspace, err = findTableDestinationAndKeyspace(vschema, ddlStatement) + err = checkFKError(vschema, ddlStatement) if err != nil { return nil, nil, err } + // For Alter Table and other statements, the table must already exist + // We should find the target of the query from this tables location + destination, keyspace, err = findTableDestinationAndKeyspace(vschema, ddlStatement) case *sqlparser.CreateView: destination, keyspace, err = buildCreateView(vschema, ddl, reservedVars) - if err != nil { - return nil, nil, err - } case *sqlparser.AlterView: destination, keyspace, err = buildAlterView(vschema, ddl, reservedVars) + case *sqlparser.CreateTable: + err = checkFKError(vschema, ddlStatement) if err != nil { return nil, nil, err } - case *sqlparser.CreateTable: destination, keyspace, _, err = vschema.TargetDestination(ddlStatement.GetTable().Qualifier.String()) - // Remove the keyspace name as the database name might be different. - ddlStatement.SetTable("", ddlStatement.GetTable().Name.String()) if err != nil { return nil, nil, err } + // Remove the keyspace name as the database name might be different. + ddlStatement.SetTable("", ddlStatement.GetTable().Name.String()) case *sqlparser.DropView, *sqlparser.DropTable: destination, keyspace, err = buildDropViewOrTable(vschema, ddlStatement) - if err != nil { - return nil, nil, err - } case *sqlparser.RenameTable: destination, keyspace, err = buildRenameTable(vschema, ddl) - if err != nil { - return nil, nil, err - } - default: return nil, nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] unexpected ddl statement type: %T", ddlStatement) } + if err != nil { + return nil, nil, err + } + if destination == nil { destination = key.DestinationAllShards{} } @@ -116,6 +139,17 @@ func buildDDLPlans(sql string, ddlStatement sqlparser.DDLStatement, reservedVars }, nil } +func checkFKError(vschema ContextVSchema, ddlStatement sqlparser.DDLStatement) error { + if fkStrategyMap[vschema.ForeignKeyMode()] == fkDisallow { + fk := &fkContraint{} + _ = sqlparser.Walk(fk.FkWalk, ddlStatement) + if fk.found { + return vterrors.Errorf(vtrpcpb.Code_ABORTED, "foreign key constraint is not allowed") + } + } + return nil +} + func findTableDestinationAndKeyspace(vschema ContextVSchema, ddlStatement sqlparser.DDLStatement) (key.Destination, *vindexes.Keyspace, error) { var table *vindexes.Table var destination key.Destination diff --git a/go/vt/vtgate/planbuilder/plan_test.go b/go/vt/vtgate/planbuilder/plan_test.go index 1f719d5cc20..ea8c3568aa5 100644 --- a/go/vt/vtgate/planbuilder/plan_test.go +++ b/go/vt/vtgate/planbuilder/plan_test.go @@ -312,6 +312,10 @@ type vschemaWrapper struct { version PlannerVersion } +func (vw *vschemaWrapper) ForeignKeyMode() string { + return "allow" +} + func (vw *vschemaWrapper) AllKeyspace() ([]*vindexes.Keyspace, error) { if vw.keyspace == nil { return nil, errors.New("keyspace not available") diff --git a/go/vt/vtgate/vcursor_impl.go b/go/vt/vtgate/vcursor_impl.go index 78a8eb40704..f60af6dca2b 100644 --- a/go/vt/vtgate/vcursor_impl.go +++ b/go/vt/vtgate/vcursor_impl.go @@ -753,6 +753,14 @@ func (vc *vcursorImpl) WarnUnshardedOnly(format string, params ...interface{}) { } } +// ForeignKey implements the VCursor interface +func (vc *vcursorImpl) ForeignKeyMode() string { + if foreignKeyMode == nil { + return "" + } + return strings.ToLower(*foreignKeyMode) +} + // ParseDestinationTarget parses destination target string and sets default keyspace if possible. func parseDestinationTarget(targetString string, vschema *vindexes.VSchema) (string, topodatapb.TabletType, key.Destination, error) { destKeyspace, destTabletType, dest, err := topoprotopb.ParseDestination(targetString, defaultTabletType) diff --git a/go/vt/vtgate/vtgate.go b/go/vt/vtgate/vtgate.go index 37c4650cc9a..33e18832d6b 100644 --- a/go/vt/vtgate/vtgate.go +++ b/go/vt/vtgate/vtgate.go @@ -82,6 +82,8 @@ var ( // lockHeartbeatTime is used to set the next heartbeat time. lockHeartbeatTime = flag.Duration("lock_heartbeat_time", 5*time.Second, "If there is lock function used. This will keep the lock connection active by using this heartbeat") warnShardedOnly = flag.Bool("warn_sharded_only", false, "If any features that are only available in unsharded mode are used, query execution warnings will be added to the session") + + foreignKeyMode = flag.String("foreign_key_mode", "allow", "This is to provide how to handle foreign key constraint in create/alter table. Valid values are: allow, disallow") ) func getTxMode() vtgatepb.TransactionMode {