diff --git a/pkg/cvo/cvo.go b/pkg/cvo/cvo.go index 1d9310958b..da4d7097df 100644 --- a/pkg/cvo/cvo.go +++ b/pkg/cvo/cvo.go @@ -923,6 +923,7 @@ func hasReachedLevel(cv *configv1.ClusterVersion, desired configv1.Update) bool func (optr *Operator) defaultPreconditionChecks() precondition.List { return []precondition.Precondition{ + preconditioncv.NewRollback(optr.currentVersion), preconditioncv.NewUpgradeable(optr.cvLister), preconditioncv.NewRecentEtcdBackup(optr.cvLister, optr.coLister), preconditioncv.NewRecommendedUpdate(optr.cvLister), diff --git a/pkg/payload/precondition/clusterversion/rollback.go b/pkg/payload/precondition/clusterversion/rollback.go new file mode 100644 index 0000000000..ca7ea46c9a --- /dev/null +++ b/pkg/payload/precondition/clusterversion/rollback.go @@ -0,0 +1,65 @@ +package clusterversion + +import ( + "context" + "fmt" + + "github.com/blang/semver/v4" + configv1 "github.com/openshift/api/config/v1" + + precondition "github.com/openshift/cluster-version-operator/pkg/payload/precondition" +) + +// currentRelease is a callback for returning the version that is currently being reconciled. +type currentRelease func() configv1.Release + +// Rollback blocks rollbacks from the version that is currently being reconciled. +type Rollback struct { + currentRelease +} + +// NewRollback returns a new Rollback precondition check. +func NewRollback(fn currentRelease) *Rollback { + return &Rollback{ + currentRelease: fn, + } +} + +// Name returns Name for the precondition. +func (pf *Rollback) Name() string { return "ClusterVersionRollback" } + +// Run runs the Rollback precondition, blocking rollbacks from the +// version that is currently being reconciled. It returns a +// PreconditionError when possible. +func (p *Rollback) Run(ctx context.Context, releaseContext precondition.ReleaseContext) error { + currentRelease := p.currentRelease() + currentVersion, err := semver.Parse(currentRelease.Version) + if err != nil { + return &precondition.Error{ + Nested: err, + Reason: "InvalidCurrentVersion", + Message: err.Error(), + Name: p.Name(), + } + } + + targetVersion, err := semver.Parse(releaseContext.DesiredVersion) + if err != nil { + return &precondition.Error{ + Nested: err, + Reason: "InvalidDesiredVersion", + Message: err.Error(), + Name: p.Name(), + } + } + + if targetVersion.LT(currentVersion) { + return &precondition.Error{ + Reason: "LowDesiredVersion", + Message: fmt.Sprintf("%s is less than the current target %s, but rollbacks and downgrades are not recommended", targetVersion, currentVersion), + Name: p.Name(), + } + } + + return nil +} diff --git a/pkg/payload/precondition/clusterversion/rollback_test.go b/pkg/payload/precondition/clusterversion/rollback_test.go new file mode 100644 index 0000000000..1fe106784b --- /dev/null +++ b/pkg/payload/precondition/clusterversion/rollback_test.go @@ -0,0 +1,64 @@ +package clusterversion + +import ( + "context" + "testing" + + configv1 "github.com/openshift/api/config/v1" + "github.com/openshift/cluster-version-operator/pkg/payload/precondition" +) + +func TestRollbackRun(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + currVersion string + desiredVersion string + expected string + }{ + { + name: "update", + currVersion: "1.0.0", + desiredVersion: "1.0.1", + expected: "", + }, + { + name: "no change", + currVersion: "1.0.0", + desiredVersion: "1.0.0", + expected: "", + }, + { + name: "rollback", + currVersion: "1.0.1", + desiredVersion: "1.0.0", + expected: "1.0.0 is less than the current target 1.0.1, but rollbacks and downgrades are not recommended", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + instance := NewRollback(func() configv1.Release { + return configv1.Release{ + Version: tc.currVersion, + } + }) + + err := instance.Run(ctx, precondition.ReleaseContext{ + DesiredVersion: tc.desiredVersion, + }) + switch { + case err != nil && len(tc.expected) == 0: + t.Error(err) + case err != nil && err.Error() == tc.expected: + case err != nil && err.Error() != tc.expected: + t.Error(err) + case err == nil && len(tc.expected) == 0: + case err == nil && len(tc.expected) != 0: + t.Error(err) + } + + }) + } +}