11 changes: 9 additions & 2 deletions pkg/controllers/disruption/controller.go
@@ -207,8 +207,6 @@ func (c *Controller) executeCommand(ctx context.Context, m Method, cmd Command,
commandID := uuid.NewUUID()
log.FromContext(ctx).WithValues(append([]any{"command-id", string(commandID), "reason", strings.ToLower(string(m.Reason()))}, cmd.LogValues()...)...).Info("disrupting node(s)")

c.cluster.MarkForDeletion(lo.Map(cmd.candidates, func(c *Candidate, _ int) string { return c.ProviderID() })...)

// Cordon the old nodes before we launch the replacements to prevent new pods from scheduling to the old nodes
markedCandidates, markDisruptedErr := c.MarkDisrupted(ctx, m, cmd.candidates...)
// If we fail to mark some nodes as disrupted and we are launching replacements, we shouldn't continue
@@ -224,6 +222,14 @@ func (c *Controller) executeCommand(ctx context.Context, m Method, cmd Command,
return serrors.Wrap(fmt.Errorf("launching replacement nodeclaim, %w", err), "command-id", commandID)
}

// IMPORTANT
// We must MarkForDeletion AFTER we launch the replacements and not before
// The reason for this is to avoid producing double-launches
// If we MarkForDeletion before we create replacements, it's possible for the provisioner
// to recognize that it needs to launch capacity for terminating pods, causing us to launch
// capacity for these pods twice instead of just once
c.cluster.MarkForDeletion(lo.Map(cmd.candidates, func(c *Candidate, _ int) string { return c.ProviderID() })...)

// Nominate each node for scheduling and emit pod nomination events
// We emit all nominations before we exit the disruption loop as
// we want to ensure that nodes that are nominated are respected in the subsequent
@@ -261,6 +267,7 @@ func (c *Controller) createReplacementNodeClaims(ctx context.Context, m Method,
return nodeClaimNames, nil
}

// MarkDisrupted taints the node and adds the Disrupted condition to the NodeClaim for a candidate that is about to be disrupted
func (c *Controller) MarkDisrupted(ctx context.Context, m Method, candidates ...*Candidate) ([]*Candidate, error) {
errs := make([]error, len(candidates))
workqueue.ParallelizeUntil(ctx, len(candidates), len(candidates), func(i int) {
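To make the race described in the ordering comment above concrete, here is a small standalone Go sketch (toy types and names only, not Karpenter's actual APIs) of how marking candidates for deletion before their replacements exist lets a concurrent provisioning pass launch capacity for the same pods twice:

package main

import "fmt"

// toyCluster is a stand-in for cluster state: once a node is marked for
// deletion, a provisioning pass treats its pods as pending and launches
// replacement capacity for them.
type toyCluster struct {
	markedForDeletion map[string]bool
	launchedNodes     int
}

// provisionPass mimics a concurrent provisioning loop: it launches one node
// for every marked-for-deletion node that still has pods scheduled on it.
func (c *toyCluster) provisionPass(podsPerNode map[string]int) {
	for node, pods := range podsPerNode {
		if c.markedForDeletion[node] && pods > 0 {
			c.launchedNodes++
		}
	}
}

func main() {
	pods := map[string]int{"node-a": 3}

	// Old ordering: MarkForDeletion before launching the replacement. A
	// provisioning pass that runs in between sees the candidate's pods as
	// needing capacity and launches a node, and the disruption controller
	// then launches its own replacement as well: two nodes for one candidate.
	before := &toyCluster{markedForDeletion: map[string]bool{}}
	before.markedForDeletion["node-a"] = true
	before.provisionPass(pods)
	before.launchedNodes++ // the disruption controller's replacement
	fmt.Println("mark-then-launch:", before.launchedNodes, "nodes") // 2

	// New ordering: launch the replacement first, then MarkForDeletion. An
	// interleaved provisioning pass sees nothing marked and launches nothing:
	// one node for one candidate.
	after := &toyCluster{markedForDeletion: map[string]bool{}}
	after.launchedNodes++ // replacement NodeClaim created first
	after.provisionPass(pods)
	after.markedForDeletion["node-a"] = true
	fmt.Println("launch-then-mark:", after.launchedNodes, "nodes") // 1
}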
125 changes: 107 additions & 18 deletions pkg/controllers/disruption/suite_test.go
@@ -21,6 +21,7 @@ import (
"fmt"
"sort"
"sync"
"sync/atomic"
"testing"
"time"

@@ -314,21 +315,18 @@ var _ = Describe("Simulate Scheduling", func() {
}

// Get a set of the node claim names so that it's easy to check if a new one is made
nodeClaimNames := lo.SliceToMap(nodeClaims, func(nc *v1.NodeClaim) (string, struct{}) {
return nc.Name, struct{}{}
})
nodeClaimNames := sets.New(lo.Map(nodeClaims, func(nc *v1.NodeClaim, _ int) string { return nc.Name })...)
ExpectSingletonReconciled(ctx, disruptionController)

// Expect a replace action
ExpectTaintedNodeCount(ctx, env.Client, 1)
ncs := ExpectNodeClaims(ctx, env.Client)
// which would create one more node claim
Expect(len(ncs)).To(Equal(11))
nc, new := lo.Find(ncs, func(nc *v1.NodeClaim) bool {
_, ok := nodeClaimNames[nc.Name]
return !ok
nc, ok := lo.Find(ncs, func(nc *v1.NodeClaim) bool {
return !nodeClaimNames.Has(nc.Name)
})
Expect(new).To(BeTrue())
Expect(ok).To(BeTrue())
// which needs to be deployed
ExpectNodeClaimDeployedAndStateUpdated(ctx, env.Client, cluster, cloudProvider, nc)
nodeClaimNames[nc.Name] = struct{}{}
@@ -337,11 +335,10 @@
// Another replacement disruption action
ncs = ExpectNodeClaims(ctx, env.Client)
Expect(len(ncs)).To(Equal(12))
nc, new = lo.Find(ncs, func(nc *v1.NodeClaim) bool {
_, ok := nodeClaimNames[nc.Name]
return !ok
nc, ok = lo.Find(ncs, func(nc *v1.NodeClaim) bool {
return !nodeClaimNames.Has(nc.Name)
})
Expect(new).To(BeTrue())
Expect(ok).To(BeTrue())
ExpectNodeClaimDeployedAndStateUpdated(ctx, env.Client, cluster, cloudProvider, nc)
nodeClaimNames[nc.Name] = struct{}{}

@@ -350,11 +347,10 @@
// One more replacement disruption action
ncs = ExpectNodeClaims(ctx, env.Client)
Expect(len(ncs)).To(Equal(13))
nc, new = lo.Find(ncs, func(nc *v1.NodeClaim) bool {
_, ok := nodeClaimNames[nc.Name]
return !ok
nc, ok = lo.Find(ncs, func(nc *v1.NodeClaim) bool {
return !nodeClaimNames.Has(nc.Name)
})
Expect(new).To(BeTrue())
Expect(ok).To(BeTrue())
ExpectNodeClaimDeployedAndStateUpdated(ctx, env.Client, cluster, cloudProvider, nc)
nodeClaimNames[nc.Name] = struct{}{}

@@ -460,6 +456,74 @@ var _ = Describe("Simulate Scheduling", func() {
Expect(nodeclaims[0].Name).ToNot(Equal(nodeClaim.Name))
Expect(nodes[0].Name).ToNot(Equal(node.Name))
})
It("should ensure that we do not duplicate capacity for disrupted nodes with provisioning", func() {
// We create a client that hangs on Create() so that, while the disruption controller is trying to create replacements,
// we have time to verify that no additional capacity would be provisioned before the replacements are made
hangCreateClient := newHangCreateClient(env.Client)
defer hangCreateClient.Stop()

p := provisioning.NewProvisioner(hangCreateClient, recorder, cloudProvider, cluster, fakeClock)
dc := disruption.NewController(fakeClock, env.Client, p, cloudProvider, recorder, cluster, queue)

nodeClaim, node := test.NodeClaimAndNode(v1.NodeClaim{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
v1.NodePoolLabelKey: nodePool.Name,
corev1.LabelInstanceTypeStable: mostExpensiveInstance.Name,
v1.CapacityTypeLabelKey: mostExpensiveOffering.Requirements.Get(v1.CapacityTypeLabelKey).Any(),
corev1.LabelTopologyZone: mostExpensiveOffering.Requirements.Get(corev1.LabelTopologyZone).Any(),
},
},
Status: v1.NodeClaimStatus{
ProviderID: test.RandomProviderID(),
Allocatable: map[corev1.ResourceName]resource.Quantity{
corev1.ResourceCPU: resource.MustParse("32"),
corev1.ResourcePods: resource.MustParse("100"),
},
},
})
nodeClaim.StatusConditions().SetTrue(v1.ConditionTypeDrifted)
labels := map[string]string{
"app": "test",
}
// create our RS so we can link a pod to it
rs := test.ReplicaSet()
ExpectApplied(ctx, env.Client, rs)
Expect(env.Client.Get(ctx, client.ObjectKeyFromObject(rs), rs)).To(Succeed())

pod := test.Pod(test.PodOptions{
ObjectMeta: metav1.ObjectMeta{Labels: labels,
OwnerReferences: []metav1.OwnerReference{
{
APIVersion: "apps/v1",
Kind: "ReplicaSet",
Name: rs.Name,
UID: rs.UID,
Controller: lo.ToPtr(true),
BlockOwnerDeletion: lo.ToPtr(true),
},
}}})

ExpectApplied(ctx, env.Client, rs, pod, nodeClaim, node, nodePool)

// bind the pod to the node
ExpectManualBinding(ctx, env.Client, pod, node)

// inform cluster state about nodes and nodeclaims
ExpectMakeNodesAndNodeClaimsInitializedAndStateUpdated(ctx, env.Client, nodeStateController, nodeClaimStateController, []*corev1.Node{node}, []*v1.NodeClaim{nodeClaim})

// Expect the disruption controller to attempt to create a replacement, with the hanging client blocking that creation
go ExpectSingletonReconciled(ctx, dc)
Eventually(func(g Gomega) {
g.Expect(hangCreateClient.HasWaiter()).To(BeTrue())
}).Should(Succeed())

// If our code works correctly, the provisioner should not try to create a new NodeClaim since we shouldn't have marked
// our nodes for deletion until the new NodeClaims have been successfully launched
results, err := prov.Schedule(ctx)
Expect(err).ToNot(HaveOccurred())
Expect(results.NewNodeClaims).To(BeEmpty())
})
})

var _ = Describe("Disruption Taints", func() {
@@ -1745,7 +1809,7 @@ var _ = Describe("Candidate Filtering", func() {

Expect(cluster.Nodes()).To(HaveLen(1))
_, err := disruption.NewCandidate(ctx, env.Client, recorder, fakeClock, cluster.Nodes()[0], pdbLimits, nodePoolMap, nodePoolInstanceTypeMap, queue, disruption.GracefulDisruptionClass)
Expect(err).ToNot((HaveOccurred()))
Expect(err).ToNot(HaveOccurred())
})
It("should consider candidates that have an instance type that cannot be resolved", func() {
nodeClaim, node := test.NodeClaimAndNode(v1.NodeClaim{
@@ -2102,8 +2166,8 @@ func mostExpensiveInstanceWithZone(zone string) *cloudprovider.InstanceType {
}

//nolint:unparam
func fromInt(i int) *intstr.IntOrString {
v := intstr.FromInt(i)
func fromInt(i int32) *intstr.IntOrString {
v := intstr.FromInt32(i)
return &v
}

@@ -2212,3 +2276,28 @@ func NewTestingQueue(kubeClient client.Client, recorder events.Recorder, cluster
q.TypedRateLimitingInterface = test.NewTypedRateLimitingInterface[*orchestration.Command](workqueue.TypedQueueConfig[*orchestration.Command]{Name: "disruption.workqueue"})
return q
}

// hangCreateClient wraps the test client and blocks every Create call until Stop is called;
// all other client operations fall through to the embedded client.Client.
type hangCreateClient struct {
client.Client
hasWaiter atomic.Bool
stop chan struct{}
}

func newHangCreateClient(c client.Client) *hangCreateClient {
return &hangCreateClient{Client: c, stop: make(chan struct{})}
}

func (h *hangCreateClient) HasWaiter() bool {
return h.hasWaiter.Load()
}

func (h *hangCreateClient) Stop() {
close(h.stop)
}

func (h *hangCreateClient) Create(_ context.Context, _ client.Object, _ ...client.CreateOption) error {
h.hasWaiter.Store(true)
<-h.stop
h.hasWaiter.Store(false)
return nil
}