Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scheduler: NodeType does not need to be proto-generated #3840

Merged
merged 2 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,12 +1,65 @@
package schedulerobjects
package internaltypes

import (
"github.com/segmentio/fasthash/fnv1a"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
v1 "k8s.io/api/core/v1"

koTaint "github.com/armadaproject/armada/internal/scheduler/kubernetesobjects/taint"
)

// NodeType represents a particular combination of taints and labels.
// The scheduler groups nodes by node type. When assigning pods to nodes,
// the scheduler only considers nodes with a NodeType for which the taints and labels match.
// Its fields should be immutable! Do not change these!
type NodeType struct {
// Unique identifier. Used for map lookup.
id uint64
// Kubernetes taints.
// To reduce the number of distinct node types,
// may contain only a subset of the taints of the node the node type is created from.
taints []v1.Taint
// Kubernetes labels.
// To reduce the number of distinct node types,
// may contain only a subset of the labels of the node the node type is created from.
labels map[string]string
// Well-known labels not set by this node type.
// Used to filter out nodes when looking for nodes for a pod
// that requires at least one well-known label to be set.
unsetIndexedLabels map[string]string
}

func (m *NodeType) GetId() uint64 {
return m.id
}

func (m *NodeType) GetTaints() []v1.Taint {
return koTaint.DeepCopyTaints(m.taints)
}

func (m *NodeType) FindMatchingUntoleratedTaint(tolerations ...[]v1.Toleration) (v1.Taint, bool) {
return koTaint.FindMatchingUntoleratedTaint(m.taints, tolerations...)
}

func (m *NodeType) GetLabels() map[string]string {
return deepCopyLabels(m.labels)
}

func (m *NodeType) GetLabelValue(key string) (string, bool) {
val, ok := m.labels[key]
return val, ok
}

func (m *NodeType) GetUnsetIndexedLabels() map[string]string {
return deepCopyLabels(m.unsetIndexedLabels)
}

func (m *NodeType) GetUnsetIndexedLabelValue(key string) (string, bool) {
val, ok := m.unsetIndexedLabels[key]
return val, ok
}

type (
taintsFilterFunc func(*v1.Taint) bool
labelsFilterFunc func(key, value string) bool
Expand Down Expand Up @@ -63,10 +116,10 @@ func NewNodeType(taints []v1.Taint, labels map[string]string, indexedTaints map[
}

return &NodeType{
Id: nodeTypeIdFromTaintsAndLabels(taints, labels, unsetIndexedLabels),
Taints: taints,
Labels: labels,
UnsetIndexedLabels: unsetIndexedLabels,
id: nodeTypeIdFromTaintsAndLabels(taints, labels, unsetIndexedLabels),
taints: taints,
labels: labels,
unsetIndexedLabels: unsetIndexedLabels,
}
}

Expand Down Expand Up @@ -139,15 +192,3 @@ func getFilteredLabels(labels map[string]string, inclusionFilter labelsFilterFun
}
return filteredLabels
}

func (nodeType *NodeType) DeepCopy() *NodeType {
if nodeType == nil {
return nil
}
return &NodeType{
Id: nodeType.Id,
Taints: slices.Clone(nodeType.Taints),
Labels: maps.Clone(nodeType.Labels),
UnsetIndexedLabels: maps.Clone(nodeType.UnsetIndexedLabels),
}
}
92 changes: 92 additions & 0 deletions internal/scheduler/internaltypes/node_type_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package internaltypes

import (
"testing"

"github.com/stretchr/testify/assert"
v1 "k8s.io/api/core/v1"
)

func TestNodeType_GetId(t *testing.T) {
nodeType := makeSut()

assert.True(t, nodeType.GetId() != 0)
}

func TestNodeType_GetTaints(t *testing.T) {
nodeType := makeSut()

assert.Equal(t,
[]v1.Taint{
{Key: "taint1", Value: "value1", Effect: v1.TaintEffectNoSchedule},
{Key: "taint2", Value: "value2", Effect: v1.TaintEffectNoSchedule},
},
nodeType.GetTaints(),
)
}

func TestNodeType_FindMatchingUntoleratedTaint(t *testing.T) {
nodeType := makeSut()
taint, ok := nodeType.FindMatchingUntoleratedTaint([]v1.Toleration{{Key: "taint1", Operator: v1.TolerationOpExists, Effect: v1.TaintEffectNoSchedule}})

assert.True(t, ok)
assert.Equal(t,
v1.Taint{Key: "taint2", Value: "value2", Effect: v1.TaintEffectNoSchedule},
taint)
}

func TestNodeTypeLabels(t *testing.T) {
nodeType := makeSut()

assert.Equal(t,
map[string]string{
"label1": "value1",
"label2": "value2",
},
nodeType.GetLabels(),
)

val1, ok1 := nodeType.GetLabelValue("label1")
assert.Equal(t, val1, "value1")
assert.True(t, ok1)

val2, ok2 := nodeType.GetLabelValue("not-there")
assert.Equal(t, val2, "")
assert.False(t, ok2)

assert.Equal(t,
map[string]string{
"label3": "",
},
nodeType.GetUnsetIndexedLabels(),
)

val3, ok3 := nodeType.GetUnsetIndexedLabelValue("label3")
assert.Equal(t, val3, "")
assert.True(t, ok3)

val4, ok4 := nodeType.GetUnsetIndexedLabelValue("not-there")
assert.Equal(t, val4, "")
assert.False(t, ok4)
}

func makeSut() *NodeType {
taints := []v1.Taint{
{Key: "taint1", Value: "value1", Effect: v1.TaintEffectNoSchedule},
{Key: "not-indexed-taint", Value: "not-indexed-taint-value", Effect: v1.TaintEffectNoSchedule},
{Key: "taint2", Value: "value2", Effect: v1.TaintEffectNoSchedule},
}

labels := map[string]string{
"label1": "value1",
"label2": "value2",
"not-indexed-label;": "not-indexed-label-value",
}

return NewNodeType(
taints,
labels,
map[string]interface{}{"taint1": true, "taint2": true, "taint3": true},
map[string]interface{}{"label1": true, "label2": true, "label3": true},
)
}
20 changes: 10 additions & 10 deletions internal/scheduler/nodedb/nodedb.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (nodeDb *NodeDb) create(node *schedulerobjects.Node) (*internaltypes.Node,

totalResources := node.TotalResources

nodeType := schedulerobjects.NewNodeType(
nodeType := internaltypes.NewNodeType(
taints,
labels,
nodeDb.indexedTaints,
Expand Down Expand Up @@ -77,14 +77,14 @@ func (nodeDb *NodeDb) create(node *schedulerobjects.Node) (*internaltypes.Node,
}
index := uint64(nodeDb.numNodes)
nodeDb.numNodes++
nodeDb.numNodesByNodeType[nodeType.Id]++
nodeDb.numNodesByNodeType[nodeType.GetId()]++
nodeDb.totalResources.Add(totalResources)
nodeDb.nodeTypes[nodeType.Id] = nodeType
nodeDb.nodeTypes[nodeType.GetId()] = nodeType
nodeDb.mu.Unlock()

return internaltypes.CreateNode(
node.Id,
nodeType.Id,
nodeType.GetId(),
index,
node.Executor,
node.Name,
Expand Down Expand Up @@ -193,7 +193,7 @@ type NodeDb struct {
totalResources schedulerobjects.ResourceList
// Set of node types. Populated automatically as nodes are inserted.
// Node types are not cleaned up if all nodes of that type are removed from the NodeDb.
nodeTypes map[uint64]*schedulerobjects.NodeType
nodeTypes map[uint64]*internaltypes.NodeType

wellKnownNodeTypes map[string]*configuration.WellKnownNodeType

Expand Down Expand Up @@ -267,7 +267,7 @@ func NewNodeDb(
indexedTaints: mapFromSlice(indexedTaints),
indexedNodeLabels: mapFromSlice(indexedNodeLabels),
indexedNodeLabelValues: indexedNodeLabelValues,
nodeTypes: make(map[uint64]*schedulerobjects.NodeType),
nodeTypes: make(map[uint64]*internaltypes.NodeType),
wellKnownNodeTypes: make(map[string]*configuration.WellKnownNodeType),
numNodesByNodeType: make(map[uint64]int),
totalResources: schedulerobjects.ResourceList{Resources: make(map[string]resource.Quantity)},
Expand Down Expand Up @@ -353,7 +353,7 @@ func (nodeDb *NodeDb) String() string {
} else {
fmt.Fprint(w, "Node types:\n")
for _, nodeType := range nodeDb.nodeTypes {
fmt.Fprintf(w, " %d\n", nodeType.Id)
fmt.Fprintf(w, " %d\n", nodeType.GetId())
}
}
w.Flush()
Expand Down Expand Up @@ -1069,12 +1069,12 @@ func (nodeDb *NodeDb) NodeTypesMatchingJob(jctx *schedulercontext.JobSchedulingC
for _, nodeType := range nodeDb.nodeTypes {
matches, reason := NodeTypeJobRequirementsMet(nodeType, jctx)
if matches {
matchingNodeTypeIds = append(matchingNodeTypeIds, nodeType.Id)
matchingNodeTypeIds = append(matchingNodeTypeIds, nodeType.GetId())
} else if reason != nil {
s := nodeDb.stringFromPodRequirementsNotMetReason(reason)
numExcludedNodesByReason[s] += nodeDb.numNodesByNodeType[nodeType.Id]
numExcludedNodesByReason[s] += nodeDb.numNodesByNodeType[nodeType.GetId()]
} else {
numExcludedNodesByReason[PodRequirementsNotMetReasonUnknown] += nodeDb.numNodesByNodeType[nodeType.Id]
numExcludedNodesByReason[PodRequirementsNotMetReasonUnknown] += nodeDb.numNodesByNodeType[nodeType.GetId()]
}
}
return matchingNodeTypeIds, numExcludedNodesByReason, nil
Expand Down
4 changes: 2 additions & 2 deletions internal/scheduler/nodedb/nodeiteration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -957,11 +957,11 @@ func gpuNodeTypeLabelToNodeTypeId(nodeTypeLabel string) uint64 {
}

func labelsToNodeTypeId(labels map[string]string) uint64 {
nodeType := schedulerobjects.NewNodeType(
nodeType := internaltypes.NewNodeType(
[]v1.Taint{},
labels,
mapFromSlice(testfixtures.TestIndexedTaints),
mapFromSlice(testfixtures.TestIndexedNodeLabels),
)
return nodeType.Id
return nodeType.GetId()
}
26 changes: 9 additions & 17 deletions internal/scheduler/nodedb/nodematching.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ import (

schedulercontext "github.com/armadaproject/armada/internal/scheduler/context"
"github.com/armadaproject/armada/internal/scheduler/internaltypes"
koTaint "github.com/armadaproject/armada/internal/scheduler/kubernetesobjects/taint"
"github.com/armadaproject/armada/internal/scheduler/schedulerobjects"
)

const (
Expand Down Expand Up @@ -126,24 +124,18 @@ func (err *InsufficientResources) String() string {
// NodeTypeJobRequirementsMet determines whether a pod can be scheduled on nodes of this NodeType.
// If the requirements are not met, it returns the reason for why.
// If the requirements can't be parsed, an error is returned.
func NodeTypeJobRequirementsMet(nodeType *schedulerobjects.NodeType, jctx *schedulercontext.JobSchedulingContext) (bool, PodRequirementsNotMetReason) {
matches, reason := TolerationRequirementsMet(nodeType.GetTaints(), jctx.AdditionalTolerations, jctx.PodRequirements.GetTolerations())
func NodeTypeJobRequirementsMet(nodeType *internaltypes.NodeType, jctx *schedulercontext.JobSchedulingContext) (bool, PodRequirementsNotMetReason) {
matches, reason := TolerationRequirementsMet(nodeType, jctx.AdditionalTolerations, jctx.PodRequirements.GetTolerations())
if !matches {
return matches, reason
}

nodeTypeLabels := nodeType.GetLabels()
nodeTypeLabelGetter := func(key string) (string, bool) {
val, ok := nodeTypeLabels[key]
return val, ok
}

matches, reason = NodeSelectorRequirementsMet(nodeTypeLabelGetter, nodeType.GetUnsetIndexedLabels(), jctx.AdditionalNodeSelectors)
matches, reason = NodeSelectorRequirementsMet(nodeType.GetLabelValue, nodeType.GetUnsetIndexedLabelValue, jctx.AdditionalNodeSelectors)
if !matches {
return matches, reason
}

return NodeSelectorRequirementsMet(nodeTypeLabelGetter, nodeType.GetUnsetIndexedLabels(), jctx.PodRequirements.GetNodeSelector())
return NodeSelectorRequirementsMet(nodeType.GetLabelValue, nodeType.GetUnsetIndexedLabelValue, jctx.PodRequirements.GetNodeSelector())
}

// JobRequirementsMet determines whether a job can be scheduled onto this node.
Expand Down Expand Up @@ -202,8 +194,8 @@ func DynamicJobRequirementsMet(allocatableResources internaltypes.ResourceList,
return matches, reason
}

func TolerationRequirementsMet(taints []v1.Taint, tolerations ...[]v1.Toleration) (bool, PodRequirementsNotMetReason) {
untoleratedTaint, hasUntoleratedTaint := koTaint.FindMatchingUntoleratedTaint(taints, tolerations...)
func TolerationRequirementsMet(nodeType *internaltypes.NodeType, tolerations ...[]v1.Toleration) (bool, PodRequirementsNotMetReason) {
untoleratedTaint, hasUntoleratedTaint := nodeType.FindMatchingUntoleratedTaint(tolerations...)
if hasUntoleratedTaint {
return false, &UntoleratedTaint{Taint: untoleratedTaint}
}
Expand All @@ -218,7 +210,7 @@ func NodeTolerationRequirementsMet(node *internaltypes.Node, tolerations ...[]v1
return true, nil
}

func NodeSelectorRequirementsMet(nodeLabelGetter func(string) (string, bool), unsetIndexedLabels, nodeSelector map[string]string) (bool, PodRequirementsNotMetReason) {
func NodeSelectorRequirementsMet(nodeLabelGetter func(string) (string, bool), unsetIndexedLabelGetter func(string) (string, bool), nodeSelector map[string]string) (bool, PodRequirementsNotMetReason) {
for label, podValue := range nodeSelector {
// If the label value differs between nodeLabels and the pod, always return false.
if nodeValue, ok := nodeLabelGetter(label); ok {
Expand All @@ -233,8 +225,8 @@ func NodeSelectorRequirementsMet(nodeLabelGetter func(string) (string, bool), un
// If unsetIndexedLabels is provided, return false only if this label is explicitly marked as not set.
//
// If unsetIndexedLabels is not provided, we assume that nodeLabels contains all labels and return false.
if unsetIndexedLabels != nil {
if _, ok := unsetIndexedLabels[label]; ok {
if unsetIndexedLabelGetter != nil {
if _, ok := unsetIndexedLabelGetter(label); ok {
return false, &MissingLabel{Label: label}
}
} else {
Expand Down
2 changes: 1 addition & 1 deletion internal/scheduler/nodedb/nodematching_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ func TestNodeTypeSchedulingRequirementsMet(t *testing.T) {
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
nodeType := schedulerobjects.NewNodeType(
nodeType := internaltypes.NewNodeType(
tc.Taints,
tc.Labels,
tc.IndexedTaints,
Expand Down
Loading