From 7f350a116db8d7ccf74b2710a18f06528b99b882 Mon Sep 17 00:00:00 2001
From: Clayton Coleman
Date: Fri, 1 Feb 2019 22:42:18 -0800
Subject: [PATCH 1/2] Use a shared rate limiter for all config-related actions

RateLimiter will gate how fast the payload runs in parallel. Previously,
each unique client got its own rate limiter, which prevented us from
reasoning about total write load.
---
 pkg/start/start.go | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pkg/start/start.go b/pkg/start/start.go
index 4e9ec532d..a2f9dcfe7 100644
--- a/pkg/start/start.go
+++ b/pkg/start/start.go
@@ -26,6 +26,7 @@ import (
 	"k8s.io/client-go/tools/leaderelection"
 	"k8s.io/client-go/tools/leaderelection/resourcelock"
 	"k8s.io/client-go/tools/record"
+	"k8s.io/client-go/util/flowcontrol"
 
 	clientset "github.com/openshift/client-go/config/clientset/versioned"
 	informers "github.com/openshift/client-go/config/informers/externalversions"
@@ -273,8 +274,7 @@ func newClientBuilder(kubeconfig string) (*ClientBuilder, error) {
 }
 
 func increaseQPS(config *rest.Config) {
-	config.QPS = 20
-	config.Burst = 40
+	config.RateLimiter = flowcontrol.NewTokenBucketRateLimiter(20, 40)
 }
 
 func useProtobuf(config *rest.Config) {

From cb4e0375e6aeaab92d2f3dcda23fe859b8cabe0a Mon Sep 17 00:00:00 2001
From: Clayton Coleman
Date: Thu, 17 Jan 2019 02:51:44 -0500
Subject: [PATCH 2/2] payload: Create a task graph that can split a payload
 into chunks

The payload is viewed as a graph of chunks of tasks that must be executed
in serial (roughly corresponding to the manifests defined by each
operator). Using two verbs, Split -- which ensures that certain tasks
cause a happens-before/happens-after for the entire graph -- and
Parallelize -- which tries to run manifests from different operators at
the same run level (0000_70_*) in parallel -- we roughly halve the linear
depth of the payload and ensure that later operators run in parallel
while preserving ordering for the core operators.

Remove requeuing because it can change the order of tasks within a job,
which may cause confusion about potential orderings of actions in the
future. Future variants might more aggressively parallelize tasks.

Refactor the sync worker to use the task graph with a default
parallelism of 8.
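A minimal sketch of how the two verbs and the runner compose, assuming the
pkg/payload API introduced below; applyPayload and runTask are hypothetical
names, and the sync worker itself calls task.Run(version, builder) rather
than taking a callback:

    package example

    import (
    	"context"

    	"github.com/openshift/cluster-version-operator/pkg/payload"
    )

    // applyPayload shows the intended pipeline: start from one linear node,
    // split on Jobs (which act as barriers), parallelize by run level and
    // component, then walk the graph with up to 8 workers. Tasks inside a
    // single node always run in order.
    func applyPayload(ctx context.Context, tasks []*payload.Task, runTask func(*payload.Task) error) error {
    	graph := payload.NewTaskGraph(tasks)
    	graph.Split(payload.SplitOnJobs)
    	graph.Parallelize(payload.ByNumberAndComponent)
    	return payload.RunGraph(ctx, graph, 8, func(ctx context.Context, tasks []*payload.Task) error {
    		for _, task := range tasks {
    			if err := runTask(task); err != nil {
    				return err
    			}
    		}
    		return nil
    	})
    }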
--- pkg/cvo/cvo.go | 2 +- pkg/cvo/cvo_scenarios_test.go | 6 +- pkg/cvo/cvo_test.go | 6 +- pkg/cvo/sync_test.go | 268 +------------- pkg/cvo/sync_worker.go | 204 ++++++----- pkg/payload/task_graph.go | 523 +++++++++++++++++++++++++++ pkg/payload/task_graph_test.go | 641 +++++++++++++++++++++++++++++++++ 7 files changed, 1280 insertions(+), 370 deletions(-) create mode 100644 pkg/payload/task_graph.go create mode 100644 pkg/payload/task_graph_test.go diff --git a/pkg/cvo/cvo.go b/pkg/cvo/cvo.go index e54f199bf..d51b3bdb2 100644 --- a/pkg/cvo/cvo.go +++ b/pkg/cvo/cvo.go @@ -218,7 +218,7 @@ func (optr *Operator) Run(workers int, stopCh <-chan struct{}) { // start the config sync loop, and have it notify the queue when new status is detected go runThrottledStatusNotifier(stopCh, optr.statusInterval, 2, optr.configSync.StatusCh(), func() { optr.queue.Add(optr.queueKey()) }) - go optr.configSync.Start(stopCh) + go optr.configSync.Start(8, stopCh) go wait.Until(func() { optr.worker(optr.queue, optr.sync) }, time.Second, stopCh) go wait.Until(func() { optr.worker(optr.availableUpdatesQueue, optr.availableUpdatesSync) }, time.Second, stopCh) diff --git a/pkg/cvo/cvo_scenarios_test.go b/pkg/cvo/cvo_scenarios_test.go index 282f8accd..682b19f4f 100644 --- a/pkg/cvo/cvo_scenarios_test.go +++ b/pkg/cvo/cvo_scenarios_test.go @@ -98,7 +98,7 @@ func TestCVO_StartupAndSync(t *testing.T) { defer close(stopCh) defer shutdownFn() worker := o.configSync.(*SyncWorker) - go worker.Start(stopCh) + go worker.Start(1, stopCh) // Step 1: Verify the CVO creates the initial Cluster Version object // @@ -385,7 +385,7 @@ func TestCVO_RestartAndReconcile(t *testing.T) { // Step 2: Start the sync worker and verify the sequence of events, and then verify // the status does not change // - go worker.Start(stopCh) + go worker.Start(1, stopCh) // verifyAllStatus(t, worker.StatusCh(), SyncWorkerStatus{ @@ -538,7 +538,7 @@ func TestCVO_ErrorDuringReconcile(t *testing.T) { // Step 2: Start the sync worker and verify the sequence of events // - go worker.Start(stopCh) + go worker.Start(1, stopCh) // verifyAllStatus(t, worker.StatusCh(), SyncWorkerStatus{ diff --git a/pkg/cvo/cvo_test.go b/pkg/cvo/cvo_test.go index a3e5bee58..b67ca5cb2 100644 --- a/pkg/cvo/cvo_test.go +++ b/pkg/cvo/cvo_test.go @@ -1304,9 +1304,9 @@ func TestOperator_sync(t *testing.T) { Actual: configv1.Update{Image: "image/image:v4.0.1", Version: "0.0.1-abc"}, }, optr: Operator{ - releaseImage: "image/image:v4.0.1", - namespace: "test", - name: "default", + releaseImage: "image/image:v4.0.1", + namespace: "test", + name: "default", defaultUpstreamServer: "http://localhost:8080/graph", availableUpdates: &availableUpdates{ Upstream: "", diff --git a/pkg/cvo/sync_test.go b/pkg/cvo/sync_test.go index 15e9296a9..c34f98127 100644 --- a/pkg/cvo/sync_test.go +++ b/pkg/cvo/sync_test.go @@ -10,7 +10,6 @@ import ( "github.com/davecgh/go-spew/spew" - apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/meta" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" @@ -28,179 +27,6 @@ import ( "github.com/openshift/cluster-version-operator/pkg/payload" ) -func TestHasRequeueOnErrorAnnotation(t *testing.T) { - tests := []struct { - annos map[string]string - - exp bool - experrs []string - }{{ - annos: nil, - exp: false, - experrs: nil, - }, { - annos: map[string]string{"dummy": "dummy"}, - exp: false, - experrs: nil, - }, { - annos: map[string]string{requeueOnErrorAnnotationKey: "NoMatch"}, - exp: true, - experrs: 
[]string{"NoMatch"}, - }, { - annos: map[string]string{requeueOnErrorAnnotationKey: "NoMatch,NotFound"}, - exp: true, - experrs: []string{"NoMatch", "NotFound"}, - }} - for idx, test := range tests { - t.Run(fmt.Sprintf("test#%d", idx), func(t *testing.T) { - got, goterrs := hasRequeueOnErrorAnnotation(test.annos) - if got != test.exp { - t.Fatalf("expected %v got %v", test.exp, got) - } - if !reflect.DeepEqual(goterrs, test.experrs) { - t.Fatalf("expected %v got %v", test.exp, got) - } - }) - } -} - -func TestShouldRequeueOnErr(t *testing.T) { - tests := []struct { - err error - manifest string - exp bool - }{{ - err: nil, - manifest: `{ - "apiVersion": "v1", - "kind": "ConfigMap" - }`, - - exp: false, - }, { - err: fmt.Errorf("random error"), - manifest: `{ - "apiVersion": "v1", - "kind": "ConfigMap" - }`, - - exp: false, - }, { - err: &meta.NoResourceMatchError{}, - manifest: `{ - "apiVersion": "v1", - "kind": "ConfigMap" - }`, - - exp: false, - }, { - err: &payload.UpdateError{Nested: &meta.NoResourceMatchError{}}, - manifest: `{ - "apiVersion": "v1", - "kind": "ConfigMap" - }`, - - exp: false, - }, { - err: &meta.NoResourceMatchError{}, - manifest: `{ - "apiVersion": "v1", - "kind": "ConfigMap", - "metadata": { - "annotations": { - "v1.cluster-version-operator.operators.openshift.io/requeue-on-error": "NoMatch" - } - } - }`, - - exp: true, - }, { - err: &payload.UpdateError{Nested: &meta.NoResourceMatchError{}}, - manifest: `{ - "apiVersion": "v1", - "kind": "ConfigMap", - "metadata": { - "annotations": { - "v1.cluster-version-operator.operators.openshift.io/requeue-on-error": "NoMatch" - } - } - }`, - - exp: true, - }, { - err: &meta.NoResourceMatchError{}, - manifest: `{ - "apiVersion": "v1", - "kind": "ConfigMap", - "metadata": { - "annotations": { - "v1.cluster-version-operator.operators.openshift.io/requeue-on-error": "NotFound" - } - } - }`, - - exp: false, - }, { - err: &payload.UpdateError{Nested: &meta.NoResourceMatchError{}}, - manifest: `{ - "apiVersion": "v1", - "kind": "ConfigMap", - "metadata": { - "annotations": { - "v1.cluster-version-operator.operators.openshift.io/requeue-on-error": "NotFound" - } - } - }`, - - exp: false, - }, { - err: apierrors.NewInternalError(fmt.Errorf("dummy")), - manifest: `{ - "apiVersion": "v1", - "kind": "ConfigMap", - "metadata": { - "annotations": { - "v1.cluster-version-operator.operators.openshift.io/requeue-on-error": "NoMatch" - } - } - }`, - - exp: false, - }, { - err: &payload.UpdateError{Nested: apierrors.NewInternalError(fmt.Errorf("dummy"))}, - manifest: `{ - "apiVersion": "v1", - "kind": "ConfigMap", - "metadata": { - "annotations": { - "v1.cluster-version-operator.operators.openshift.io/requeue-on-error": "NoMatch" - } - } - }`, - - exp: false, - }, { - err: &payload.UpdateError{Nested: &resourcebuilder.RetryLaterError{}}, - manifest: `{ - "apiVersion": "v1", - "kind": "ConfigMap" - }`, - - exp: true, - }} - for idx, test := range tests { - t.Run(fmt.Sprintf("test#%d", idx), func(t *testing.T) { - var manifest lib.Manifest - if err := json.Unmarshal([]byte(test.manifest), &manifest); err != nil { - t.Fatal(err) - } - if got := shouldRequeueOnErr(test.err, &manifest); got != test.exp { - t.Fatalf("expected %v got %v", test.exp, got) - } - }) - } -} - func Test_SyncWorker_apply(t *testing.T) { tests := []struct { manifests []string @@ -274,94 +100,6 @@ func Test_SyncWorker_apply(t *testing.T) { t.Fatalf("expected: %s got: %s", spew.Sdump(exp), spew.Sdump(got)) } }, - }, { - manifests: []string{ - `{ - "apiVersion": 
"test.cvo.io/v1", - "kind": "TestA", - "metadata": { - "namespace": "default", - "name": "testa", - "annotations": { - "v1.cluster-version-operator.operators.openshift.io/requeue-on-error": "NoMatch" - } - } - }`, - `{ - "apiVersion": "test.cvo.io/v1", - "kind": "TestB", - "metadata": { - "namespace": "default", - "name": "testb" - } - }`, - }, - reactors: map[action]error{ - newAction(schema.GroupVersionKind{"test.cvo.io", "v1", "TestA"}, "default", "testa"): &meta.NoResourceMatchError{}, - }, - wantErr: true, - check: func(t *testing.T, actions []action) { - if len(actions) != 7 { - spew.Dump(actions) - t.Fatalf("unexpected %d actions", len(actions)) - } - - if got, exp := actions[0], (newAction(schema.GroupVersionKind{"test.cvo.io", "v1", "TestA"}, "default", "testa")); !reflect.DeepEqual(got, exp) { - t.Fatalf("expected: %s got: %s", spew.Sdump(exp), spew.Sdump(got)) - } - if got, exp := actions[3], (newAction(schema.GroupVersionKind{"test.cvo.io", "v1", "TestB"}, "default", "testb")); !reflect.DeepEqual(got, exp) { - t.Fatalf("expected: %s got: %s", spew.Sdump(exp), spew.Sdump(got)) - } - if got, exp := actions[4], (newAction(schema.GroupVersionKind{"test.cvo.io", "v1", "TestA"}, "default", "testa")); !reflect.DeepEqual(got, exp) { - t.Fatalf("expected: %s got: %s", spew.Sdump(exp), spew.Sdump(got)) - } - }, - }, { - manifests: []string{ - `{ - "apiVersion": "test.cvo.io/v1", - "kind": "TestA", - "metadata": { - "namespace": "default", - "name": "testa", - "annotations": { - "v1.cluster-version-operator.operators.openshift.io/requeue-on-error": "NoMatch" - } - } - }`, - `{ - "apiVersion": "test.cvo.io/v1", - "kind": "TestB", - "metadata": { - "namespace": "default", - "name": "testb", - "annotations": { - "v1.cluster-version-operator.operators.openshift.io/requeue-on-error": "NoMatch" - } - } - }`, - }, - reactors: map[action]error{ - newAction(schema.GroupVersionKind{"test.cvo.io", "v1", "TestA"}, "default", "testa"): &meta.NoResourceMatchError{}, - newAction(schema.GroupVersionKind{"test.cvo.io", "v1", "TestB"}, "default", "testb"): &meta.NoResourceMatchError{}, - }, - wantErr: true, - check: func(t *testing.T, actions []action) { - if len(actions) != 9 { - spew.Dump(actions) - t.Fatalf("unexpected %d actions", len(actions)) - } - - if got, exp := actions[0], (newAction(schema.GroupVersionKind{"test.cvo.io", "v1", "TestA"}, "default", "testa")); !reflect.DeepEqual(got, exp) { - t.Fatalf("expected: %s got: %s", spew.Sdump(exp), spew.Sdump(got)) - } - if got, exp := actions[3], (newAction(schema.GroupVersionKind{"test.cvo.io", "v1", "TestB"}, "default", "testb")); !reflect.DeepEqual(got, exp) { - t.Fatalf("expected: %s got: %s", spew.Sdump(exp), spew.Sdump(got)) - } - if got, exp := actions[6], (newAction(schema.GroupVersionKind{"test.cvo.io", "v1", "TestA"}, "default", "testa")); !reflect.DeepEqual(got, exp) { - t.Fatalf("expected: %s got: %s", spew.Sdump(exp), spew.Sdump(got)) - } - }, }} for idx, test := range tests { t.Run(fmt.Sprintf("test#%d", idx), func(t *testing.T) { @@ -385,7 +123,7 @@ func Test_SyncWorker_apply(t *testing.T) { worker.backoff.Steps = 3 worker.builder = NewResourceBuilder(nil) ctx := context.Background() - worker.apply(ctx, up, &SyncWork{}, &statusWrapper{w: worker, previousStatus: worker.Status()}) + worker.apply(ctx, up, &SyncWork{}, 1, &statusWrapper{w: worker, previousStatus: worker.Status()}) test.check(t, r.actions) }) } @@ -549,7 +287,7 @@ func Test_SyncWorker_apply_generic(t *testing.T) { modifiers: test.modifiers, } ctx := context.Background() - err 
:= worker.apply(ctx, up, &SyncWork{}, &statusWrapper{w: worker, previousStatus: worker.Status()}) + err := worker.apply(ctx, up, &SyncWork{}, 1, &statusWrapper{w: worker, previousStatus: worker.Status()}) if err != nil { t.Fatal(err) } @@ -613,7 +351,7 @@ func (r *fakeSyncRecorder) StatusCh() <-chan SyncWorkerStatus { return ch } -func (r *fakeSyncRecorder) Start(stopCh <-chan struct{}) {} +func (r *fakeSyncRecorder) Start(maxWorkers int, stopCh <-chan struct{}) {} func (r *fakeSyncRecorder) Update(generation int64, desired configv1.Update, overrides []configv1.ComponentOverride, reconciling bool) *SyncWorkerStatus { r.Updates = append(r.Updates, desired) diff --git a/pkg/cvo/sync_worker.go b/pkg/cvo/sync_worker.go index 41be44fdb..63a55a2ff 100644 --- a/pkg/cvo/sync_worker.go +++ b/pkg/cvo/sync_worker.go @@ -4,16 +4,13 @@ import ( "context" "fmt" "reflect" - "strings" "sync" "time" "github.com/golang/glog" - "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" "golang.org/x/time/rate" - "k8s.io/apimachinery/pkg/api/meta" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" utilruntime "k8s.io/apimachinery/pkg/util/runtime" "k8s.io/apimachinery/pkg/util/wait" @@ -27,7 +24,7 @@ import ( // ConfigSyncWorker abstracts how the image is synchronized to the server. Introduced for testing. type ConfigSyncWorker interface { - Start(stopCh <-chan struct{}) + Start(maxWorkers int, stopCh <-chan struct{}) Update(generation int64, desired configv1.Update, overrides []configv1.ComponentOverride, reconciling bool) *SyncWorkerStatus StatusCh() <-chan SyncWorkerStatus } @@ -199,7 +196,7 @@ func (w *SyncWorker) Update(generation int64, desired configv1.Update, overrides // Start periodically invokes run, detecting whether content has changed. // It is edge-triggered when Update() is invoked and level-driven after the // syncOnce() has succeeded for a given input (we are said to be "reconciling"). -func (w *SyncWorker) Start(stopCh <-chan struct{}) { +func (w *SyncWorker) Start(maxWorkers int, stopCh <-chan struct{}) { glog.V(5).Infof("Starting sync worker") work := &SyncWork{} @@ -248,7 +245,7 @@ func (w *SyncWorker) Start(stopCh <-chan struct{}) { // so that we don't fail, then immediately start reporting an earlier status reporter := &statusWrapper{w: w, previousStatus: w.Status()} glog.V(5).Infof("Previous sync status: %#v", reporter.previousStatus) - return w.syncOnce(ctx, work, reporter) + return w.syncOnce(ctx, work, maxWorkers, reporter) }() if err != nil { // backoff wait @@ -268,6 +265,7 @@ func (w *SyncWorker) Start(stopCh <-chan struct{}) { } glog.V(5).Infof("Sync succeeded, reconciling") + work.Completed++ work.Reconciling = true next = time.After(w.minimumReconcileInterval) } @@ -384,7 +382,7 @@ func (w *SyncWorker) Status() *SyncWorkerStatus { // sync retrieves the image and applies it to the server, returning an error if // the update could not be completely applied. The status is updated as we progress. // Cancelling the context will abort the execution of the sync. 
-func (w *SyncWorker) syncOnce(ctx context.Context, work *SyncWork, reporter StatusReporter) error { +func (w *SyncWorker) syncOnce(ctx context.Context, work *SyncWork, maxWorkers int, reporter StatusReporter) error { glog.V(4).Infof("Running sync %s on generation %d", versionString(work.Desired), work.Generation) update := work.Desired @@ -407,21 +405,34 @@ func (w *SyncWorker) syncOnce(ctx context.Context, work *SyncWork, reporter Stat glog.V(4).Infof("Payload loaded from %s with hash %s", payloadUpdate.ReleaseImage, payloadUpdate.ManifestHash) } - return w.apply(ctx, w.payload, work, reporter) + return w.apply(ctx, w.payload, work, maxWorkers, reporter) } // apply updates the server with the contents of the provided image or returns an error. -// Cancelling the context will abort the execution of the sync. -func (w *SyncWorker) apply(ctx context.Context, payloadUpdate *payload.Update, work *SyncWork, reporter StatusReporter) error { +// Cancelling the context will abort the execution of the sync. Will be executed in parallel if +// maxWorkers is set greater than 1. +func (w *SyncWorker) apply(ctx context.Context, payloadUpdate *payload.Update, work *SyncWork, maxWorkers int, reporter StatusReporter) error { update := configv1.Update{ Version: payloadUpdate.ReleaseVersion, Image: payloadUpdate.ReleaseImage, } - // update each object + // encapsulate status reporting in a threadsafe updater version := payloadUpdate.ReleaseVersion total := len(payloadUpdate.Manifests) - done := 0 + cr := &consistentReporter{ + status: SyncWorkerStatus{ + Generation: work.Generation, + Reconciling: work.Reconciling, + VersionHash: payloadUpdate.ManifestHash, + Actual: update, + }, + completed: work.Completed, + version: version, + total: total, + reporter: reporter, + } + var tasks []*payload.Task for i := range payloadUpdate.Manifests { tasks = append(tasks, &payload.Task{ @@ -431,47 +442,42 @@ func (w *SyncWorker) apply(ctx context.Context, payloadUpdate *payload.Update, w Backoff: w.backoff, }) } + graph := payload.NewTaskGraph(tasks) + graph.Split(payload.SplitOnJobs) + graph.Parallelize(payload.ByNumberAndComponent) - for i := 0; i < len(tasks); i++ { - task := tasks[i] - setAppliedAndPending(version, total, done) - fraction := float32(i) / float32(len(tasks)) - - reporter.Report(SyncWorkerStatus{Generation: work.Generation, Fraction: fraction, Step: "ApplyResources", Reconciling: work.Reconciling, VersionHash: payloadUpdate.ManifestHash, Actual: update}) - - glog.V(4).Infof("Running sync for %s", task) - glog.V(5).Infof("Manifest: %s", string(task.Manifest.Raw)) - - if contextIsCancelled(ctx) { - err := fmt.Errorf("update was cancelled at %d/%d", i, len(tasks)) - reporter.Report(SyncWorkerStatus{Generation: work.Generation, Failure: err, Fraction: fraction, Step: "ApplyResources", Reconciling: work.Reconciling, VersionHash: payloadUpdate.ManifestHash, Actual: update}) - return err - } + // update each object + err := payload.RunGraph(ctx, graph, maxWorkers, func(ctx context.Context, tasks []*payload.Task) error { + for _, task := range tasks { + if contextIsCancelled(ctx) { + return cr.CancelError() + } + cr.Update() - ov, ok := getOverrideForManifest(work.Overrides, task.Manifest) - if ok && ov.Unmanaged { - glog.V(4).Infof("Skipping %s as unmanaged", task) - continue - } + glog.V(4).Infof("Running sync for %s", task) + glog.V(5).Infof("Manifest: %s", string(task.Manifest.Raw)) - if err := task.Run(version, w.builder); err != nil { - reporter.Report(SyncWorkerStatus{Generation: work.Generation, 
Failure: err, Fraction: fraction, Step: "ApplyResources", Reconciling: work.Reconciling, VersionHash: payloadUpdate.ManifestHash, Actual: update})
-			cause := errors.Cause(err)
-			if task.Requeued == 0 && shouldRequeueOnErr(cause, task.Manifest) {
-				task.Requeued++
-				tasks = append(tasks, task)
+			ov, ok := getOverrideForManifest(work.Overrides, task.Manifest)
+			if ok && ov.Unmanaged {
+				glog.V(4).Infof("Skipping %s as unmanaged", task)
 				continue
 			}
-			return err
+
+			if err := task.Run(version, w.builder); err != nil {
+				return err
+			}
+			cr.Inc()
+			glog.V(4).Infof("Done syncing for %s", task)
 		}
-		done++
-		glog.V(4).Infof("Done syncing for %s", task)
+		return nil
+	})
+	if err != nil {
+		cr.Error(err)
+		return err
 	}
 
-	setAppliedAndPending(version, total, done)
-	work.Completed++
-	reporter.Report(SyncWorkerStatus{Generation: work.Generation, Fraction: 1, Completed: work.Completed, Reconciling: true, VersionHash: payloadUpdate.ManifestHash, Actual: update})
-
+	// update the metrics and report final status
+	cr.Complete()
 	return nil
 }
 
@@ -488,59 +494,61 @@ func init() {
 	)
 }
 
-func setAppliedAndPending(version string, total, done int) {
-	metricPayload.WithLabelValues(version, "pending").Set(float64(total - done))
-	metricPayload.WithLabelValues(version, "applied").Set(float64(done))
-}
-
-// This is used to map the know causes to their check.
-var requeueOnErrorCauseToCheck = map[string]func(error) bool{
-	requeueOnErrorCauseNoMatch: meta.IsNoMatchError,
-}
-
-func shouldRequeueOnErr(err error, manifest *lib.Manifest) bool {
-	cause := errors.Cause(err)
-	if _, ok := cause.(*resourcebuilder.RetryLaterError); ok {
-		return true
-	}
-
-	ok, errs := hasRequeueOnErrorAnnotation(manifest.Object().GetAnnotations())
-	if !ok {
-		return false
-	}
-	should := false
-	for _, e := range errs {
-		if ef, ok := requeueOnErrorCauseToCheck[e]; ok {
-			if ef(cause) {
-				should = true
-				break
-			}
-		}
-	}
-
-	return should
-}
-
-const (
-	// RequeueOnErrorAnnotationKey is key for annotation on a manifests object that instructs CVO to requeue on specific errors.
-	// The value is comma separated list of causes that forces requeue.
-	requeueOnErrorAnnotationKey = "v1.cluster-version-operator.operators.openshift.io/requeue-on-error"
-
-	// RequeueOnErrorCauseNoMatch is used when no match is found for object in api.
-	// This maps to https://godoc.org/k8s.io/apimachinery/pkg/api/meta#NoKindMatchError and https://godoc.org/k8s.io/apimachinery/pkg/api/meta#NoResourceMatchError .
-	// https://godoc.org/k8s.io/apimachinery/pkg/api/meta#IsNoMatchError is used as a check.
-	requeueOnErrorCauseNoMatch = "NoMatch"
-)
-
-func hasRequeueOnErrorAnnotation(annos map[string]string) (bool, []string) {
-	if annos == nil {
-		return false, nil
-	}
-	errs, ok := annos[requeueOnErrorAnnotationKey]
-	if !ok {
-		return false, nil
-	}
-	return ok, strings.Split(errs, ",")
+// consistentReporter hides the details of calculating the status based on the progress
+// of the graph runner.
+type consistentReporter struct {
+	lock      sync.Mutex
+	status    SyncWorkerStatus
+	version   string
+	completed int
+	total     int
+	done      int
+	reporter  StatusReporter
+}
+
+func (r *consistentReporter) Inc() {
+	r.lock.Lock()
+	defer r.lock.Unlock()
+	r.done++
+}
+
+func (r *consistentReporter) Update() {
+	r.lock.Lock()
+	defer r.lock.Unlock()
+	metricPayload.WithLabelValues(r.version, "pending").Set(float64(r.total - r.done))
+	metricPayload.WithLabelValues(r.version, "applied").Set(float64(r.done))
+	copied := r.status
+	copied.Step = "ApplyResources"
+	copied.Fraction = float32(r.done) / float32(r.total)
+	r.reporter.Report(copied)
+}
+
+func (r *consistentReporter) Error(err error) {
+	r.lock.Lock()
+	defer r.lock.Unlock()
+	copied := r.status
+	copied.Step = "ApplyResources"
+	copied.Fraction = float32(r.done) / float32(r.total)
+	copied.Failure = err
+	r.reporter.Report(copied)
+}
+
+func (r *consistentReporter) CancelError() error {
+	r.lock.Lock()
+	defer r.lock.Unlock()
+	return fmt.Errorf("update was cancelled at %d/%d", r.done, r.total)
+}
+
+func (r *consistentReporter) Complete() {
+	r.lock.Lock()
+	defer r.lock.Unlock()
+	metricPayload.WithLabelValues(r.version, "pending").Set(float64(0))
+	metricPayload.WithLabelValues(r.version, "applied").Set(float64(r.total))
+	copied := r.status
+	copied.Completed = r.completed + 1
+	copied.Reconciling = true
+	copied.Fraction = 1
+	r.reporter.Report(copied)
+}
 
 // getOverrideForManifest returns the override and true when override exists for manifest.
diff --git a/pkg/payload/task_graph.go b/pkg/payload/task_graph.go
new file mode 100644
index 000000000..2738de017
--- /dev/null
+++ b/pkg/payload/task_graph.go
@@ -0,0 +1,523 @@
+package payload
+
+import (
+	"context"
+	"fmt"
+	"regexp"
+	"sort"
+	"strconv"
+	"strings"
+	"sync"
+
+	"github.com/golang/glog"
+
+	"k8s.io/apimachinery/pkg/runtime/schema"
+	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
+)
+
+// SplitOnJobs enforces the rule that any Job in the payload prevents reordering or parallelism (either before or after)
+func SplitOnJobs(task *Task) bool {
+	return task.Manifest.GVK == schema.GroupVersionKind{Kind: "Job", Version: "v1", Group: "batch"}
+}
+
+var reMatchPattern = regexp.MustCompile(`^0000_(\d+)_([a-zA-Z0-9]+(-[a-zA-Z0-9]+)*?)_`)
+
+const (
+	groupNumber    = 1
+	groupComponent = 2
+)
+
+// ByNumberAndComponent creates parallelization for tasks whose original filenames are of the form
+// 0000_NN_NAME_* - files that share 0000_NN_NAME_ are run in serial, but chunks of files that have
+// the same 0000_NN but different NAME can be run in parallel. If the input is not sorted in an order
+// such that 0000_NN_NAME elements are next to each other, the splitter will treat those as unsplittable
+// elements.
+func ByNumberAndComponent(tasks []*Task) [][]*TaskNode {
+	if len(tasks) <= 1 {
+		return nil
+	}
+	count := len(tasks)
+	matches := make([][]string, 0, count)
+	for i := 0; i < len(tasks); i++ {
+		matches = append(matches, reMatchPattern.FindStringSubmatch(tasks[i].Manifest.OriginalFilename))
+	}
+
+	var buckets [][]*TaskNode
+	var lastNode *TaskNode
+	for i := 0; i < count; {
+		matchBase := matches[i]
+		j := i + 1
+		var groups []*TaskNode
+		for ; j < count; j++ {
+			matchNext := matches[j]
+			if matchBase == nil || matchNext == nil || matchBase[groupNumber] != matchNext[groupNumber] {
+				break
+			}
+			if matchBase[groupComponent] != matchNext[groupComponent] {
+				groups = append(groups, &TaskNode{Tasks: tasks[i:j]})
+				i = j
+			}
+			matchBase = matchNext
+		}
+		if len(groups) > 0 {
+			groups = append(groups, &TaskNode{Tasks: tasks[i:j]})
+			i = j
+			buckets = append(buckets, groups)
+			lastNode = nil
+			continue
+		}
+		if lastNode == nil {
+			lastNode = &TaskNode{Tasks: append([]*Task(nil), tasks[i:j]...)}
+			i = j
+			buckets = append(buckets, []*TaskNode{lastNode})
+			continue
+		}
+		lastNode.Tasks = append(lastNode.Tasks, tasks[i:j]...)
+		i = j
+	}
+	return buckets
+}
+
+type TaskNode struct {
+	In    []int
+	Tasks []*Task
+	Out   []int
+}
+
+func (n TaskNode) String() string {
+	var arr []string
+	for _, t := range n.Tasks {
+		if len(t.Manifest.OriginalFilename) > 0 {
+			arr = append(arr, t.Manifest.OriginalFilename)
+			continue
+		}
+		arr = append(arr, t.Manifest.GVK.String())
+	}
+	return "{Tasks: " + strings.Join(arr, ", ") + "}"
+}
+
+func (n *TaskNode) replaceIn(index, with int) {
+	for i, from := range n.In {
+		if from == index {
+			n.In[i] = with
+		}
+	}
+}
+
+func (n *TaskNode) replaceOut(index, with int) {
+	for i, to := range n.Out {
+		if to == index {
+			n.Out[i] = with
+		}
+	}
+}
+
+func (n *TaskNode) appendOut(items ...int) {
+	for _, in := range items {
+		if !containsInt(n.Out, in) {
+			n.Out = append(n.Out, in)
+		}
+	}
+}
+
+// TaskGraph provides methods for parallelizing a linear sequence
+// of Tasks based on Split or Parallelize functions.
+type TaskGraph struct {
+	Nodes []*TaskNode
+}
+
+// NewTaskGraph creates a graph with a single node containing
+// the supplied tasks.
+func NewTaskGraph(tasks []*Task) *TaskGraph {
+	return &TaskGraph{
+		Nodes: []*TaskNode{
+			{
+				Tasks: tasks,
+			},
+		},
+	}
+}
+
+func containsInt(arr []int, value int) bool {
+	for _, i := range arr {
+		if i == value {
+			return true
+		}
+	}
+	return false
+}
+
+func (g *TaskGraph) replaceInOf(index, with int) {
+	node := g.Nodes[index]
+	in := node.In
+	for _, pos := range in {
+		g.Nodes[pos].replaceOut(index, with)
+	}
+}
+
+func (g *TaskGraph) replaceOutOf(index, with int) {
+	node := g.Nodes[index]
+	out := node.Out
+	for _, pos := range out {
+		g.Nodes[pos].replaceIn(index, with)
+	}
+}
+
+// Split breaks a graph node at any task for which onFn returns true into
+// one, two, or three separate nodes, preserving the order of tasks.
+// E.g. a node with [a,b,c,d] where onFn returns true for b will result
+// in a graph with [a] -> [b] -> [c,d].
+func (g *TaskGraph) Split(onFn func(task *Task) bool) { + for i := 0; i < len(g.Nodes); i++ { + node := g.Nodes[i] + tasks := node.Tasks + if len(tasks) <= 1 { + continue + } + for j, task := range tasks { + if !onFn(task) { + continue + } + + if j > 0 { + left := tasks[0:j] + next := len(g.Nodes) + nextNode := &TaskNode{ + In: node.In, + Tasks: left, + Out: []int{i}, + } + g.Nodes = append(g.Nodes, nextNode) + g.replaceInOf(i, next) + node.In = []int{next} + } + + if j < (len(tasks) - 1) { + right := tasks[j+1:] + next := len(g.Nodes) + nextNode := &TaskNode{ + In: []int{i}, + Tasks: right, + Out: node.Out, + } + g.Nodes = append(g.Nodes, nextNode) + g.replaceOutOf(i, next) + node.Out = []int{next} + } + + node.Tasks = tasks[j : j+1] + break + } + } +} + +// Parallelize takes the given breakFn and splits any TaskNode's tasks up +// into parallel groups. If breakFn returns an empty array or a single +// array item with a single task node, that is considered a no-op. +func (g *TaskGraph) Parallelize(breakFn func([]*Task) [][]*TaskNode) { + for i := 0; i < len(g.Nodes); i++ { + node := g.Nodes[i] + results := breakFn(node.Tasks) + if len(results) == 0 || (len(results) == 1 && len(results[0]) == 1) { + continue + } + node.Tasks = nil + out := node.Out + node.Out = nil + + // starting with the left anchor, create chains of nodes, + // and avoid M x N in/out connections by creating spacers + in := []int{i} + for _, inNodes := range results { + if len(inNodes) == 0 { + continue + } + singleIn, singleOut := len(in) == 1, len(inNodes) == 1 + + switch { + case singleIn && singleOut, singleIn, singleOut: + in = g.bulkAdd(inNodes, in) + default: + in = g.bulkAdd([]*TaskNode{{}}, in) + in = g.bulkAdd(inNodes, in) + } + } + + // make node the left anchor and nextNode the right anchor + if len(out) > 0 { + next := len(g.Nodes) + nextNode := &TaskNode{ + Tasks: nil, + Out: out, + } + g.Nodes = append(g.Nodes, nextNode) + for _, j := range out { + g.Nodes[j].replaceIn(i, next) + } + for _, j := range in { + g.Nodes[j].Out = []int{next} + nextNode.In = append(nextNode.In, j) + } + } + } +} + +func (g *TaskGraph) Roots() []int { + var roots []int + for i, n := range g.Nodes { + if len(n.In) > 0 { + continue + } + roots = append(roots, i) + } + return roots +} + +func (g *TaskGraph) Tree() string { + roots := g.Roots() + visited := make([]int, len(g.Nodes)) + stage := 0 + var out []string + var depth []int + for len(roots) > 0 { + depth = append(depth, 0) + for _, i := range roots { + visited[i] = 1 + if d := len(g.Nodes[i].Tasks); d > depth[len(depth)-1] { + depth[len(depth)-1] = d + } + out = append(out, fmt.Sprintf("%d: %d %s in=%v out=%v", stage, i, g.Nodes[i], g.Nodes[i].In, g.Nodes[i].Out)) + } + roots = roots[0:0] + for i, b := range visited { + if b == 1 || !covers(visited, g.Nodes[i].In) { + continue + } + roots = append(roots, i) + } + stage++ + } + for i, b := range visited { + if b == 1 { + continue + } + out = append(out, fmt.Sprintf("unreachable: %d %s in=%v out=%v", i, g.Nodes[i], g.Nodes[i].In, g.Nodes[i].Out)) + } + var totalDepth int + var levels []string + for _, d := range depth { + levels = append(levels, strconv.Itoa(d)) + totalDepth += d + } + out = append(out, fmt.Sprintf("summary: depth=%d, levels=%s", totalDepth, strings.Join(levels, ","))) + return strings.Join(out, "\n") +} + +func covers(all []int, some []int) bool { + for _, i := range some { + if all[i] == 0 { + return false + } + } + return true +} + +func (g *TaskGraph) bulkAdd(nodes []*TaskNode, inNodes []int) []int { + 
from := len(g.Nodes) + g.Nodes = append(g.Nodes, nodes...) + to := len(g.Nodes) + if len(inNodes) == 0 { + toNodes := make([]int, to-from) + for k := from; k < to; k++ { + toNodes[k-from] = k + } + return toNodes + } + + next := make([]int, to-from) + for k := from; k < to; k++ { + g.Nodes[k].In = append([]int(nil), inNodes...) + next[k-from] = k + } + for _, k := range inNodes { + g.Nodes[k].appendOut(next...) + } + return next +} + +type runTasks struct { + index int + tasks []*Task +} + +type taskStatus struct { + index int + success bool +} + +func RunGraph(ctx context.Context, graph *TaskGraph, maxParallelism int, fn func(ctx context.Context, tasks []*Task) error) error { + nestedCtx, cancelFn := context.WithCancel(ctx) + defer cancelFn() + + // This goroutine takes nodes from the graph as they are available (their prereq has completed) and + // sends them to workCh. It uses completeCh to know that a previously dispatched item is complete. + completeCh := make(chan taskStatus, maxParallelism) + defer close(completeCh) + + workCh := make(chan runTasks, maxParallelism) + go func() { + defer close(workCh) + + // visited tracks nodes we have not sent (0), are currently + // waiting for completion (1), or have completed (2,3) + const ( + nodeNotVisited int = iota + nodeWorking + nodeFailed + nodeComplete + ) + visited := make([]int, len(graph.Nodes)) + canVisit := func(node *TaskNode) bool { + for _, previous := range node.In { + switch visited[previous] { + case nodeFailed, nodeWorking, nodeNotVisited: + return false + } + } + return true + } + + remaining := len(graph.Nodes) + var inflight int + for { + found := 0 + + // walk the graph, filling the work queue + for i := 0; i < len(visited); i++ { + if visited[i] != nodeNotVisited { + continue + } + if canVisit(graph.Nodes[i]) { + select { + case workCh <- runTasks{index: i, tasks: graph.Nodes[i].Tasks}: + visited[i] = nodeWorking + found++ + inflight++ + default: + break + } + } + } + + // try to empty the done channel + for len(completeCh) > 0 { + finished := <-completeCh + if finished.success { + visited[finished.index] = nodeComplete + } else { + visited[finished.index] = nodeFailed + } + remaining-- + inflight-- + found++ + } + + if found > 0 { + continue + } + + // no more work to hand out + if remaining == 0 { + glog.V(4).Infof("Graph is complete") + return + } + + // we walked the entire graph, there are still nodes remaining, but we're not waiting + // for anything + if inflight == 0 && found == 0 { + glog.V(4).Infof("No more reachable nodes in graph, continue") + break + } + + // we did nothing this round, so we have to wait for more + finished, ok := <-completeCh + if !ok { + // we've been aborted + glog.V(4).Infof("Stopped graph walker due to cancel") + return + } + if finished.success { + visited[finished.index] = nodeComplete + } else { + visited[finished.index] = nodeFailed + } + remaining-- + inflight-- + } + + // take everything remaining and process in order + var unreachable []*Task + for i := 0; i < len(visited); i++ { + if visited[i] == nodeNotVisited && canVisit(graph.Nodes[i]) { + unreachable = append(unreachable, graph.Nodes[i].Tasks...) 
+ } + } + if len(unreachable) > 0 { + sort.Slice(unreachable, func(i, j int) bool { + a, b := unreachable[i], unreachable[j] + return a.Index < b.Index + }) + workCh <- runTasks{index: -1, tasks: unreachable} + glog.V(4).Infof("Waiting for last tasks") + <-completeCh + } + glog.V(4).Infof("No more work") + }() + + errCh := make(chan error, maxParallelism) + wg := sync.WaitGroup{} + if maxParallelism < 1 { + maxParallelism = 1 + } + for i := 0; i < maxParallelism; i++ { + wg.Add(1) + go func(job int) { + defer utilruntime.HandleCrash() + defer wg.Done() + for { + select { + case <-nestedCtx.Done(): + glog.V(4).Infof("Canceled worker %d", job) + return + case runTask, ok := <-workCh: + if !ok { + glog.V(4).Infof("No more work for %d", job) + return + } + glog.V(4).Infof("Running %d on %d", runTask.index, job) + err := fn(nestedCtx, runTask.tasks) + completeCh <- taskStatus{index: runTask.index, success: err == nil} + if err != nil { + errCh <- err + } + } + } + }(i) + } + go func() { + glog.V(4).Infof("Waiting for workers to complete") + wg.Wait() + glog.V(4).Infof("Workers finished") + close(errCh) + }() + + var errs []error + for err := range errCh { + errs = append(errs, err) + } + glog.V(4).Infof("Result of work: %v", errs) + if len(errs) > 0 { + return errs[0] + } + return nil +} diff --git a/pkg/payload/task_graph_test.go b/pkg/payload/task_graph_test.go new file mode 100644 index 000000000..61d3fe9e1 --- /dev/null +++ b/pkg/payload/task_graph_test.go @@ -0,0 +1,641 @@ +package payload + +import ( + "context" + "fmt" + "math/rand" + "os" + "reflect" + "sort" + "strings" + "sync" + "testing" + "time" + + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/util/diff" + + "github.com/openshift/cluster-version-operator/lib" +) + +func Test_TaskGraph_Split(t *testing.T) { + var ( + pod = schema.GroupVersionKind{Kind: "Pod", Version: "v1"} + job = schema.GroupVersionKind{Kind: "Job", Version: "v1", Group: "batch"} + ) + tasks := func(gvks ...schema.GroupVersionKind) []*Task { + var arr []*Task + for _, gvk := range gvks { + arr = append(arr, &Task{Manifest: &lib.Manifest{GVK: gvk}}) + } + return arr + } + tests := []struct { + name string + nodes []*TaskNode + onFn func(task *Task) bool + expect []*TaskNode + }{ + { + nodes: []*TaskNode{}, + onFn: SplitOnJobs, + expect: []*TaskNode{}, + }, + { + nodes: []*TaskNode{ + {Tasks: tasks(pod)}, + }, + onFn: SplitOnJobs, + expect: []*TaskNode{ + {Tasks: tasks(pod)}, + }, + }, + { + name: "split right", + nodes: []*TaskNode{ + {Tasks: tasks(job, pod)}, + }, + onFn: SplitOnJobs, + expect: []*TaskNode{ + {Tasks: tasks(job), Out: []int{1}}, + {Tasks: tasks(pod), In: []int{0}}, + }, + }, + { + name: "split left", + nodes: []*TaskNode{ + {Tasks: tasks(pod, job)}, + }, + onFn: SplitOnJobs, + expect: []*TaskNode{ + {Tasks: tasks(job), In: []int{1}}, + {Tasks: tasks(pod), Out: []int{0}}, + }, + }, + { + name: "interior", + nodes: []*TaskNode{ + {Tasks: tasks(pod, pod, job, pod)}, + }, + onFn: SplitOnJobs, + expect: []*TaskNode{ + {Tasks: tasks(job), In: []int{1}, Out: []int{2}}, + {Tasks: tasks(pod, pod), Out: []int{0}}, + {In: []int{0}, Tasks: tasks(pod)}, + }, + }, + { + name: "interspersed", + nodes: []*TaskNode{ + {Tasks: tasks(pod, pod, job, pod, job, pod)}, + }, + onFn: SplitOnJobs, + expect: []*TaskNode{ + {Tasks: tasks(job), In: []int{1}, Out: []int{3}}, + {Tasks: tasks(pod, pod), Out: []int{0}}, + {Tasks: tasks(job), In: []int{3}, Out: []int{4}}, + {Tasks: tasks(pod), In: []int{0}, Out: []int{2}}, + {In: []int{2}, Tasks: tasks(pod)}, 
+ }, + }, + { + name: "ends", + nodes: []*TaskNode{ + {Tasks: tasks(job, pod, pod, job)}, + }, + onFn: SplitOnJobs, + expect: []*TaskNode{ + {Tasks: tasks(job), Out: []int{2}}, + {Tasks: tasks(job), In: []int{2}}, + {Tasks: tasks(pod, pod), In: []int{0}, Out: []int{1}}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := &TaskGraph{ + Nodes: tt.nodes, + } + g.Split(tt.onFn) + if !reflect.DeepEqual(g.Nodes, tt.expect) { + t.Fatalf("unexpected:\n%s\n%s", (&TaskGraph{Nodes: tt.expect}).Tree(), g.Tree()) + } + }) + } +} + +func TestByNumberAndComponent(t *testing.T) { + tasks := func(names ...string) []*Task { + var arr []*Task + for _, name := range names { + arr = append(arr, &Task{Manifest: &lib.Manifest{OriginalFilename: name}}) + } + return arr + } + tests := []struct { + name string + tasks []*Task + want [][]*TaskNode + }{ + { + name: "empty tasks", + tasks: tasks(), + want: nil, + }, + { + name: "no grouping possible", + tasks: tasks("a"), + want: nil, + }, + { + name: "no recognizable groups", + tasks: tasks("a", "b", "c"), + want: [][]*TaskNode{ + { + &TaskNode{Tasks: tasks("a", "b", "c")}, + }, + }, + }, + { + name: "single grouped item", + tasks: tasks("0000_01_x-y-z_file1"), + want: nil, + }, + { + name: "multiple grouped items in single node", + tasks: tasks("0000_01_x-y-z_file1", "0000_01_x-y-z_file2"), + want: [][]*TaskNode{ + { + &TaskNode{Tasks: tasks("0000_01_x-y-z_file1", "0000_01_x-y-z_file2")}, + }, + }, + }, + { + tasks: tasks("a", "0000_01_x-y-z_file1", "c"), + want: [][]*TaskNode{ + { + &TaskNode{Tasks: tasks("a", "0000_01_x-y-z_file1", "c")}, + }, + }, + }, + { + tasks: tasks("0000_01_x-y-z_file1", "0000_01_x-y-z_file2"), + want: [][]*TaskNode{ + { + &TaskNode{Tasks: tasks("0000_01_x-y-z_file1", "0000_01_x-y-z_file2")}, + }, + }, + }, + { + tasks: tasks("0000_01_a-b-c_file1", "0000_01_x-y-z_file2"), + want: [][]*TaskNode{ + { + &TaskNode{Tasks: tasks("0000_01_a-b-c_file1")}, + &TaskNode{Tasks: tasks("0000_01_x-y-z_file2")}, + }, + }, + }, + { + tasks: tasks( + "0000_01_a-b-c_file1", + "0000_01_x-y-z_file1", + "0000_01_x-y-z_file2", + "a", + "0000_01_x-y-z_file2", + "0000_01_x-y-z_file3", + ), + want: [][]*TaskNode{ + { + &TaskNode{Tasks: tasks( + "0000_01_a-b-c_file1", + )}, + &TaskNode{Tasks: tasks( + "0000_01_x-y-z_file1", + "0000_01_x-y-z_file2", + )}, + }, + { + &TaskNode{Tasks: tasks( + "a", + "0000_01_x-y-z_file2", + "0000_01_x-y-z_file3", + )}, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ByNumberAndComponent(tt.tasks); !reflect.DeepEqual(got, tt.want) { + t.Fatalf("%s", diff.ObjectReflectDiff(tt.want, got)) + } + }) + } +} + +func Test_TaskGraph_real(t *testing.T) { + path := os.Getenv("TEST_GRAPH_PATH") + if len(path) == 0 { + t.Skip("TEST_GRAPH_PATH unset") + } + p, err := LoadUpdate(path, "arbitrary/image:1") + if err != nil { + t.Fatal(err) + } + var tasks []*Task + for i := range p.Manifests { + tasks = append(tasks, &Task{ + Manifest: &p.Manifests[i], + }) + } + g := NewTaskGraph(tasks) + g.Split(SplitOnJobs) + g.Parallelize(ByNumberAndComponent) + t.Logf("\n%s", g.Tree()) + t.Logf("original depth: %d", len(tasks)) +} + +func Test_TaskGraph_example(t *testing.T) { + pod := func(name string) *Task { + return &Task{ + Manifest: &lib.Manifest{ + GVK: schema.GroupVersionKind{Kind: "Pod", Version: "v1"}, + OriginalFilename: name, + }, + } + } + job := func(name string) *Task { + return &Task{ + Manifest: &lib.Manifest{ + GVK: schema.GroupVersionKind{Kind: "Job", Version: 
"v1", Group: "batch"}, + OriginalFilename: name, + }, + } + } + tests := []struct { + name string + tasks []*Task + expect *TaskGraph + }{ + { + tasks: []*Task{pod("a"), job("0000_50_a_0")}, + expect: &TaskGraph{ + Nodes: []*TaskNode{ + {Tasks: []*Task{job("0000_50_a_0")}, In: []int{1}}, + {Tasks: []*Task{pod("a")}, Out: []int{0}}, + }, + }, + }, + { + tasks: []*Task{ + pod("a"), + job("0000_50_a_0"), + pod("0000_50_a_1"), + pod("0000_50_a_2"), + }, + expect: &TaskGraph{ + Nodes: []*TaskNode{ + {Tasks: []*Task{job("0000_50_a_0")}, In: []int{1}, Out: []int{2}}, + {Tasks: []*Task{pod("a")}, Out: []int{0}}, + {Tasks: []*Task{pod("0000_50_a_1"), pod("0000_50_a_2")}, In: []int{0}}, + }, + }, + }, + { + tasks: []*Task{ + job("a"), + pod("0000_50_a_0"), + pod("0000_50_b_0"), + pod("0000_50_b_1"), + }, + expect: &TaskGraph{ + Nodes: []*TaskNode{ + {Tasks: []*Task{job("a")}, Out: []int{1}}, + {In: []int{0}, Out: []int{2, 3}}, + {Tasks: []*Task{pod("0000_50_a_0")}, In: []int{1}}, + {Tasks: []*Task{pod("0000_50_b_0"), pod("0000_50_b_1")}, In: []int{1}}, + }, + }, + }, + { + tasks: []*Task{ + job("a"), + pod("0000_50_a_0"), + pod("0000_50_b_0"), + pod("0000_50_b_1"), + pod("0000_50_c_0"), + job("b"), + }, + expect: &TaskGraph{ + Nodes: []*TaskNode{ + {Tasks: []*Task{job("a")}, Out: []int{2}}, + {Tasks: []*Task{job("b")}, In: []int{6}}, + {In: []int{0}, Out: []int{3, 4, 5}}, + {Tasks: []*Task{pod("0000_50_a_0")}, In: []int{2}, Out: []int{6}}, + {Tasks: []*Task{pod("0000_50_b_0"), pod("0000_50_b_1")}, In: []int{2}, Out: []int{6}}, + {Tasks: []*Task{pod("0000_50_c_0")}, In: []int{2}, Out: []int{6}}, + {In: []int{3, 4, 5}, Out: []int{1}}, + }, + }, + }, + { + tasks: []*Task{ + pod("0000_07_a_0"), + pod("0000_08_a_0"), + pod("0000_09_a_0"), + pod("0000_09_a_1"), + pod("0000_09_b_0"), + pod("0000_09_b_1"), + pod("0000_10_a_0"), + pod("0000_10_a_1"), + pod("0000_11_a_0"), + pod("0000_11_a_1"), + }, + expect: &TaskGraph{ + Nodes: []*TaskNode{ + {Out: []int{1}}, + {Tasks: []*Task{pod("0000_07_a_0"), pod("0000_08_a_0")}, In: []int{0}, Out: []int{2, 3}}, + {Tasks: []*Task{pod("0000_09_a_0"), pod("0000_09_a_1")}, In: []int{1}, Out: []int{4}}, + {Tasks: []*Task{pod("0000_09_b_0"), pod("0000_09_b_1")}, In: []int{1}, Out: []int{4}}, + {Tasks: []*Task{pod("0000_10_a_0"), pod("0000_10_a_1"), pod("0000_11_a_0"), pod("0000_11_a_1")}, In: []int{2, 3}}, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewTaskGraph(tt.tasks) + g.Split(SplitOnJobs) + g.Parallelize(ByNumberAndComponent) + if !reflect.DeepEqual(g, tt.expect) { + t.Fatalf("unexpected:\n%s\n---\n%s", tt.expect.Tree(), g.Tree()) + } + }) + } +} + +func Test_TaskGraph_bulkAdd(t *testing.T) { + tasks := func(names ...string) []*Task { + var arr []*Task + for _, name := range names { + arr = append(arr, &Task{Manifest: &lib.Manifest{OriginalFilename: name}}) + } + return arr + } + tests := []struct { + name string + nodes []*TaskNode + add []*TaskNode + in []int + want []int + expect []*TaskNode + }{ + { + nodes: []*TaskNode{ + {Tasks: tasks("a", "b")}, + }, + add: []*TaskNode{ + {Tasks: tasks("c")}, + {Tasks: tasks("d")}, + }, + in: []int{0}, + want: []int{1, 2}, + expect: []*TaskNode{ + {Tasks: tasks("a", "b"), Out: []int{1, 2}}, + {Tasks: tasks("c"), In: []int{0}}, + {Tasks: tasks("d"), In: []int{0}}, + }, + }, + { + nodes: []*TaskNode{ + {Tasks: tasks("a", "b"), Out: []int{1}}, + {Tasks: tasks("e")}, + }, + add: []*TaskNode{ + {Tasks: tasks("c")}, + {Tasks: tasks("d")}, + }, + in: []int{0}, + want: []int{2, 3}, + 
expect: []*TaskNode{ + {Tasks: tasks("a", "b"), Out: []int{1, 2, 3}}, + {Tasks: tasks("e")}, + {Tasks: tasks("c"), In: []int{0}}, + {Tasks: tasks("d"), In: []int{0}}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := &TaskGraph{ + Nodes: tt.nodes, + } + if got := g.bulkAdd(tt.add, tt.in); !reflect.DeepEqual(got, tt.want) { + t.Errorf("TaskGraph.bulkAdd() = %v, want %v", got, tt.want) + } + if !reflect.DeepEqual(tt.expect, g.Nodes) { + t.Errorf("unexpected:\n%s\n---\n%s", (&TaskGraph{Nodes: tt.expect}).Tree(), g.Tree()) + } + }) + } +} + +type safeSlice struct { + lock sync.Mutex + items []string +} + +func (s *safeSlice) Add(item string) { + s.lock.Lock() + defer s.lock.Unlock() + s.items = append(s.items, item) +} + +func TestRunGraph(t *testing.T) { + tasks := func(names ...string) []*Task { + var arr []*Task + for _, name := range names { + arr = append(arr, &Task{Manifest: &lib.Manifest{OriginalFilename: name}}) + } + return arr + } + tests := []struct { + name string + nodes []*TaskNode + parallel int + sleep time.Duration + errorOn func(t *testing.T, name string, ctx context.Context, cancelFn func()) error + + order []string + want []string + invariants func(t *testing.T, got []string) + wantErr string + }{ + { + nodes: []*TaskNode{ + {Tasks: tasks("a", "b")}, + }, + order: []string{"a", "b"}, + }, + { + nodes: []*TaskNode{ + {Tasks: tasks("c"), In: []int{3}}, + {Tasks: tasks("d", "e"), In: []int{3}}, + {Tasks: tasks("f"), In: []int{3}, Out: []int{4}}, + {Tasks: tasks("a", "b"), Out: []int{0, 1, 2}}, + {Tasks: tasks("g"), In: []int{2}}, + }, + want: []string{"a", "b", "c", "d", "e", "f", "g"}, + sleep: time.Millisecond, + parallel: 2, + invariants: func(t *testing.T, got []string) { + for i := 0; i < len(got)-1; i++ { + for j := i + 1; j < len(got); j++ { + a, b := got[i], got[j] + switch { + case a == "b" && b == "a": + t.Fatalf("%d and %d in: %v", i, j, got) + case a == "e" && b == "d": + t.Fatalf("%d and %d in: %v", i, j, got) + case a != "a" && b == "b": + t.Fatalf("%d and %d in: %v", i, j, got) + case a == "g" && (b == "f" || b == "a" || b == "b"): + t.Fatalf("%d and %d in: %v", i, j, got) + } + } + } + }, + }, + { + nodes: []*TaskNode{ + {Tasks: tasks("c"), In: []int{2}}, + {Tasks: tasks("d"), In: []int{2}, Out: []int{3}}, + {Tasks: tasks("a", "b"), Out: []int{0, 1}}, + {Tasks: tasks("e"), In: []int{1}}, + }, + sleep: time.Millisecond, + parallel: 2, + errorOn: func(t *testing.T, name string, ctx context.Context, cancelFn func()) error { + if name == "d" { + return fmt.Errorf("error A") + } + return nil + }, + want: []string{"a", "b", "c"}, + wantErr: "error A", + invariants: func(t *testing.T, got []string) { + for _, s := range got { + if s == "e" { + t.Fatalf("shouldn't have reached e") + } + } + }, + }, + { + nodes: []*TaskNode{ + {Tasks: tasks("c"), In: []int{2}}, + {Tasks: tasks("d"), In: []int{2}, Out: []int{3}}, + {Tasks: tasks("a", "b"), Out: []int{0, 1}}, + {Tasks: tasks("e"), In: []int{1}}, + }, + sleep: time.Millisecond, + parallel: 2, + errorOn: func(t *testing.T, name string, ctx context.Context, cancelFn func()) error { + if name == "d" { + cancelFn() + select { + case <-time.After(time.Second): + t.Fatalf("expected context") + case <-ctx.Done(): + t.Logf("got cancelled context") + return fmt.Errorf("cancelled") + } + return fmt.Errorf("error A") + } + return nil + }, + want: []string{"a", "b", "c"}, + wantErr: "cancelled", + invariants: func(t *testing.T, got []string) { + for _, s := range got { + if s == "e" { + 
t.Fatalf("shouldn't have reached e") + } + } + }, + }, + { + nodes: []*TaskNode{ + {Tasks: tasks("a"), Out: []int{1}}, + {Tasks: tasks("b"), In: []int{0}, Out: []int{2, 4, 8}}, + {Tasks: tasks("c1"), In: []int{1}, Out: []int{3}}, + {Tasks: tasks("c2"), In: []int{2}, Out: []int{7}}, + {Tasks: tasks("d1"), In: []int{1}, Out: []int{5}}, + {Tasks: tasks("d2"), In: []int{4}, Out: []int{6}}, + {Tasks: tasks("d3"), In: []int{5}, Out: []int{7}}, + {Tasks: tasks("e"), In: []int{3, 6}}, + {Tasks: tasks("f"), In: []int{1}}, + }, + sleep: time.Millisecond, + parallel: 2, + errorOn: func(t *testing.T, name string, ctx context.Context, cancelFn func()) error { + if name == "c1" { + return fmt.Errorf("error - c1") + } + if name == "f" { + return fmt.Errorf("error - f") + } + return nil + }, + want: []string{"a", "b", "d1", "d2", "d3"}, + wantErr: "error -", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := &TaskGraph{ + Nodes: tt.nodes, + } + ctx, cancelFn := context.WithCancel(context.Background()) + defer cancelFn() + var order safeSlice + err := RunGraph(ctx, g, tt.parallel, func(ctx context.Context, tasks []*Task) error { + for _, task := range tasks { + time.Sleep(tt.sleep * time.Duration(rand.Intn(4))) + if tt.errorOn != nil { + if err := tt.errorOn(t, task.Manifest.OriginalFilename, ctx, cancelFn); err != nil { + return err + } + } + order.Add(task.Manifest.OriginalFilename) + } + return nil + }) + if tt.order != nil { + if !reflect.DeepEqual(tt.order, order.items) { + t.Fatal(diff.ObjectReflectDiff(tt.order, order.items)) + } + } + if tt.invariants != nil { + tt.invariants(t, order.items) + } + if tt.want != nil { + sort.Strings(tt.want) + sort.Strings(order.items) + if !reflect.DeepEqual(tt.want, order.items) { + t.Fatal(diff.ObjectReflectDiff(tt.want, order.items)) + } + } + + if (err != nil) != (tt.wantErr != "") { + t.Fatalf("unexpected error: %v", err) + } + if err != nil { + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("unexpected error: %v", err) + } + return + } + }) + } +}