From 35f49cca69f825593dae71e0ffa257f69c7839a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Chrz=C4=85szcz?= Date: Tue, 3 Jun 2025 11:30:38 +0000 Subject: [PATCH] Prepare code for two-level scheduling --- pkg/cache/tas_flavor_snapshot.go | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/pkg/cache/tas_flavor_snapshot.go b/pkg/cache/tas_flavor_snapshot.go index 4f0608e3c04..bfd0eec0bb5 100644 --- a/pkg/cache/tas_flavor_snapshot.go +++ b/pkg/cache/tas_flavor_snapshot.go @@ -541,14 +541,14 @@ func (s *TASFlavorSnapshot) findTopologyAssignment( podSetNodeSelectors := info.NodeSelector count := tasPodSetRequests.Count required := isRequired(tasPodSetRequests.PodSet.TopologyRequest) - key := s.levelKeyWithImpliedFallback(&tasPodSetRequests) + topologyKey := s.levelKeyWithImpliedFallback(&tasPodSetRequests) unconstrained := isUnconstrained(tasPodSetRequests.PodSet.TopologyRequest, &tasPodSetRequests) - if key == nil { + if topologyKey == nil { return nil, "topology level not specified" } - levelIdx, found := s.resolveLevelIdx(*key) + levelIdx, found := s.resolveLevelIdx(*topologyKey) if !found { - return nil, fmt.Sprintf("no requested topology level: %s", *key) + return nil, fmt.Sprintf("no requested topology level: %s", *topologyKey) } var selector labels.Selector if s.isLowestLevelNode() { @@ -680,7 +680,7 @@ func isUnconstrained(tr *kueue.PodSetTopologyRequest, tasRequests *TASPodSetRequ func findBestFitDomainIdx(domains []*domain, count int32) int { bestFitIdx := 0 for i, domain := range domains { - if domain.state >= count && domain.state != domains[bestFitIdx].state { + if domain.state >= count && domain.state < domains[bestFitIdx].state { // choose the first occurrence of fitting domains // to make it consecutive with other podSet's bestFitIdx = i @@ -689,7 +689,7 @@ func findBestFitDomainIdx(domains []*domain, count int32) int { return bestFitIdx } -func (s *TASFlavorSnapshot) findLevelWithFitDomains(levelIdx int, required bool, count int32, unconstrained bool) (int, []*domain, string) { +func (s *TASFlavorSnapshot) findLevelWithFitDomains(levelIdx int, required bool, podSetSize int32, unconstrained bool) (int, []*domain, string) { domains := s.domainsPerLevel[levelIdx] if len(domains) == 0 { return 0, nil, fmt.Sprintf("no topology domains at level: %s", s.levelKeys[levelIdx]) @@ -697,19 +697,19 @@ func (s *TASFlavorSnapshot) findLevelWithFitDomains(levelIdx int, required bool, levelDomains := slices.Collect(maps.Values(domains)) sortedDomain := s.sortedDomains(levelDomains, unconstrained) topDomain := sortedDomain[0] - if useBestFitAlgorithm(unconstrained) && topDomain.state >= count { + if useBestFitAlgorithm(unconstrained) && topDomain.state >= podSetSize { // optimize the potentially last domain - topDomain = sortedDomain[findBestFitDomainIdx(sortedDomain, count)] + topDomain = sortedDomain[findBestFitDomainIdx(sortedDomain, podSetSize)] } - if topDomain.state < count { + if topDomain.state < podSetSize { if required { - return 0, nil, s.notFitMessage(topDomain.state, count) + return 0, nil, s.notFitMessage(topDomain.state, podSetSize) } if levelIdx > 0 && !unconstrained { - return s.findLevelWithFitDomains(levelIdx-1, required, count, unconstrained) + return s.findLevelWithFitDomains(levelIdx-1, required, podSetSize, unconstrained) } results := []*domain{} - remainingCount := count + remainingCount := podSetSize for idx := 0; remainingCount > 0 && idx < len(sortedDomain) && sortedDomain[idx].state > 0; idx++ { offset := 0 if useBestFitAlgorithm(unconstrained) && sortedDomain[idx].state >= remainingCount { @@ -717,10 +717,10 @@ func (s *TASFlavorSnapshot) findLevelWithFitDomains(levelIdx int, required bool, offset = findBestFitDomainIdx(sortedDomain[idx:], remainingCount) } results = append(results, sortedDomain[idx+offset]) - remainingCount -= sortedDomain[idx].state + remainingCount -= sortedDomain[idx+offset].state } if remainingCount > 0 { - return 0, nil, s.notFitMessage(count-remainingCount, count) + return 0, nil, s.notFitMessage(podSetSize-remainingCount, podSetSize) } return levelIdx, results, "" }