diff --git a/go-controller/pkg/ovn/egressfirewall.go b/go-controller/pkg/ovn/egressfirewall.go index cc8e1a5a43..f8cfdc8ae0 100644 --- a/go-controller/pkg/ovn/egressfirewall.go +++ b/go-controller/pkg/ovn/egressfirewall.go @@ -12,9 +12,11 @@ import ( "github.com/ovn-org/ovn-kubernetes/go-controller/pkg/nbdb" addressset "github.com/ovn-org/ovn-kubernetes/go-controller/pkg/ovn/address_set" "github.com/ovn-org/ovn-kubernetes/go-controller/pkg/types" + "github.com/ovn-org/ovn-kubernetes/go-controller/pkg/util/batching" kapi "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/util/errors" + "k8s.io/apimachinery/pkg/util/sets" "k8s.io/client-go/util/retry" "k8s.io/klog/v2" utilnet "k8s.io/utils/net" @@ -26,6 +28,7 @@ const ( // egressFirewallACLExtIdKey external ID key for egress firewall ACLs egressFirewallACLExtIdKey = "egressFirewall" egressFirewallACLPriorityKey = "priority" + aclDeleteBatchSize = 1000 ) type egressFirewall struct { @@ -158,10 +161,19 @@ func (oc *Controller) syncEgressFirewall(egressFirewalls []interface{}) error { // delete acls from all switches, they reside on the port group now if len(egressFirewallACLs) != 0 { - err = libovsdbops.RemoveACLsFromLogicalSwitchesWithPredicate(oc.nbClient, func(item *nbdb.LogicalSwitch) bool { return true }, - egressFirewallACLs...) + err = batching.Batch[*nbdb.ACL](aclDeleteBatchSize, egressFirewallACLs, func(batchACLs []*nbdb.ACL) error { + // optimize the predicate to exclude switches that don't reference deleting acls. + aclsToDelete := sets.String{} + for _, acl := range batchACLs { + aclsToDelete.Insert(acl.UUID) + } + swWithACLsPred := func(sw *nbdb.LogicalSwitch) bool { + return aclsToDelete.HasAny(sw.ACLs...) + } + return libovsdbops.RemoveACLsFromLogicalSwitchesWithPredicate(oc.nbClient, swWithACLsPred, batchACLs...) + }) if err != nil { - return fmt.Errorf("failed to remove reject acl from node logical switches: %v", err) + return fmt.Errorf("failed to remove egress firewall acls from node logical switches: %v", err) } } diff --git a/go-controller/pkg/util/batching/batch.go b/go-controller/pkg/util/batching/batch.go new file mode 100644 index 0000000000..88932e9afc --- /dev/null +++ b/go-controller/pkg/util/batching/batch.go @@ -0,0 +1,23 @@ +package batching + +import "fmt" + +func Batch[T any](batchSize int, data []T, eachFn func([]T) error) error { + if batchSize < 1 { + return fmt.Errorf("batchSize should be > 0, got %d", batchSize) + } + start := 0 + dataLen := len(data) + for start < dataLen { + end := start + batchSize + if end > dataLen { + end = dataLen + } + err := eachFn(data[start:end]) + if err != nil { + return err + } + start = end + } + return nil +} diff --git a/go-controller/pkg/util/batching/batch_test.go b/go-controller/pkg/util/batching/batch_test.go new file mode 100644 index 0000000000..5de64f1f86 --- /dev/null +++ b/go-controller/pkg/util/batching/batch_test.go @@ -0,0 +1,81 @@ +package batching + +import ( + "fmt" + "github.com/onsi/ginkgo" + "github.com/onsi/gomega" + + "strings" + "testing" +) + +type batchTestData struct { + name string + batchSize int + data []int + result []int + expectErr string +} + +func TestBatch(t *testing.T) { + tt := []batchTestData{ + { + name: "batch size should be > 0", + batchSize: 0, + data: []int{1, 2, 3}, + expectErr: "batchSize should be > 0", + }, + { + name: "batchSize = 1", + batchSize: 1, + data: []int{1, 2, 3}, + }, + { + name: "batchSize > 1", + batchSize: 2, + data: []int{1, 2, 3}, + }, + { + name: "number of batches = 0", + batchSize: 2, + data: nil, + }, + { + name: "number of batches = 1", + batchSize: 2, + data: []int{1, 2}, + }, + { + name: "number of batches > 1", + batchSize: 2, + data: []int{1, 2, 3, 4}, + }, + { + name: "number of batches not int", + batchSize: 2, + data: []int{1, 2, 3, 4, 5}, + }, + } + + for _, tCase := range tt { + g := gomega.NewGomegaWithT(t) + ginkgo.By(tCase.name) + var result []int + batchNum := 0 + err := Batch[int](tCase.batchSize, tCase.data, func(l []int) error { + batchNum += 1 + result = append(result, l...) + return nil + }) + if err != nil { + if tCase.expectErr != "" && strings.Contains(err.Error(), tCase.expectErr) { + continue + } + t.Fatal(fmt.Sprintf("test %s failed: %v", tCase.name, err)) + } + // tCase.data/tCase.batchSize round up + expectedBatchNum := (len(tCase.data) + tCase.batchSize - 1) / tCase.batchSize + g.Expect(batchNum).To(gomega.Equal(expectedBatchNum)) + g.Expect(result).To(gomega.Equal(tCase.data)) + } +}