Skip to content

Commit

Permalink
Add affinity assignment for TAS
Browse files Browse the repository at this point in the history
Signed-off-by: kerthcet <[email protected]>
  • Loading branch information
kerthcet committed Jan 10, 2025
1 parent 4f61106 commit ba85310
Showing 1 changed file with 85 additions and 20 deletions.
105 changes: 85 additions & 20 deletions pkg/cache/tas_flavor_snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ type domain struct {
levelValues []string
}

// in reports whether the domain's id is present in ids.
// It returns false when ids is nil or empty.
func (d *domain) in(ids []utiltas.TopologyDomainID) bool {
	return slices.Contains(ids, d.id)
}

// leafDomain extends the domain with information for the lowest-level domain.
type leafDomain struct {
domain
Expand Down Expand Up @@ -103,6 +112,9 @@ type TASFlavorSnapshot struct {
// domainsPerLevel stores the static tree information
domainsPerLevel []domainByID

// affinityDomains stores the domains previously assigned to the workload.
affinityDomains []*domain

// tolerations represents the list of tolerations defined for the resource flavor
tolerations []corev1.Toleration
}
Expand All @@ -123,6 +135,7 @@ func newTASFlavorSnapshot(log logr.Logger, topologyName kueue.TopologyReference,
domains: make(domainByID),
roots: make(domainByID),
domainsPerLevel: domainsPerLevel,
affinityDomains: make([]*domain, 0),
}
return snapshot
}
Expand Down Expand Up @@ -220,10 +233,11 @@ func (s *TASFlavorSnapshot) addUsage(domainID utiltas.TopologyDomainID, usage re
//
// Phase 2:
//
// a) select the domain at requested level with count >= requestedCount
// b) traverse the structure down level-by-level optimizing the number of used
// domains at each level
// c) build the assignment for the lowest level in the hierarchy
// a) pre-assign the affinityDomains to colocate the podsets
// b) select the domain at requested level with count >= requestedCount
// c) traverse the structure down level-by-level optimizing the number of used
// domains at each level
// d) build the assignment for the lowest level in the hierarchy
func (s *TASFlavorSnapshot) FindTopologyAssignment(
topologyRequest *kueue.PodSetTopologyRequest,
requests resources.Requests,
Expand All @@ -241,22 +255,61 @@ func (s *TASFlavorSnapshot) FindTopologyAssignment(
// phase 1 - determine the number of pods which can fit in each topology domain
s.fillInCounts(requests, append(podSetTolerations, s.tolerations...))

// phase 2a: determine the level at which the assignment is done along with
// phase 2a - pre-assign the affinityDomains to make sure podSets are colocated as much as possible.
var prefilledDomain *domain
var prefilledState int32
var parentDomainIDs []utiltas.TopologyDomainID
if len(s.affinityDomains) > 0 {
newDomains := s.sortedDomains(s.affinityDomains, nil)
// the only affinity domain may have quota left.
prefilledDomain = newDomains[0]
// once the prefilled domain has enough quota, no need to go through the rest of the algorithm.
if prefilledDomain.state >= count {
prefilledDomain.state = count
return s.buildAssignment(newDomains[0:1], requests), ""
}

prefilledState = prefilledDomain.state
// We only need to assign count-prefilledState pods then.
count -= prefilledState

// subtract prefilledState from the prefilled domain and from each of its
// ancestors, whose states originally included the count of prefilledDomain.
loopDomain := prefilledDomain
for {
loopDomain.state -= prefilledState
if loopDomain.parent == nil {
break
}
loopDomain = loopDomain.parent
parentDomainIDs = append(parentDomainIDs, loopDomain.id)
}
}

// phase 2b: determine the level at which the assignment is done along with
// the domains which can accommodate all pods
fitLevelIdx, currFitDomain, reason := s.findLevelWithFitDomains(levelIdx, required, count)
fitLevelIdx, currFitDomains, reason := s.findLevelWithFitDomains(levelIdx, required, count, parentDomainIDs)
if len(reason) > 0 {
return nil, reason
}

// phase 2b: traverse the tree down level-by-level optimizing the number of
// phase 2c: traverse the tree down level-by-level optimizing the number of
// topology domains at each level
currFitDomain = s.updateCountsToMinimum(currFitDomain, count)
currFitDomains = s.updateDomainCountsToMinimum(currFitDomains, count)
for levelIdx := fitLevelIdx; levelIdx+1 < len(s.domainsPerLevel); levelIdx++ {
lowerFitDomains := s.lowerLevelDomains(currFitDomain)
sortedLowerDomains := s.sortedDomains(lowerFitDomains)
currFitDomain = s.updateCountsToMinimum(sortedLowerDomains, count)
lowerFitDomains := s.lowerLevelDomains(currFitDomains)
sortedLowerDomains := s.sortedDomains(lowerFitDomains, parentDomainIDs)
currFitDomains = s.updateDomainCountsToMinimum(sortedLowerDomains, count)
}
return s.buildAssignment(currFitDomain, requests), ""
if prefilledDomain != nil {
// Restore the state for assignment.
prefilledDomain.state = prefilledState
if prefilledState != 0 {
currFitDomains = append([]*domain{prefilledDomain}, currFitDomains...)
}
}

// phase 2d: build the assignment
return s.buildAssignment(currFitDomains, requests), ""
}

func (s *TASFlavorSnapshot) HasLevel(r *kueue.PodSetTopologyRequest) bool {
Expand Down Expand Up @@ -285,21 +338,24 @@ func levelKey(topologyRequest *kueue.PodSetTopologyRequest) *string {
return nil
}

func (s *TASFlavorSnapshot) findLevelWithFitDomains(levelIdx int, required bool, count int32) (int, []*domain, string) {
func (s *TASFlavorSnapshot) findLevelWithFitDomains(levelIdx int, required bool, count int32, affinityParentIDs []utiltas.TopologyDomainID) (int, []*domain, string) {
domains := s.domainsPerLevel[levelIdx]

if len(domains) == 0 {
return 0, nil, fmt.Sprintf("no topology domains at level: %s", s.levelKeys[levelIdx])
}
levelDomains := utilmaps.Values(domains)
sortedDomain := s.sortedDomains(levelDomains)
sortedDomain := s.sortedDomains(levelDomains, affinityParentIDs)
topDomain := sortedDomain[0]
if topDomain.state < count {
if required {
return 0, nil, s.notFitMessage(topDomain.state, count)
}
if levelIdx > 0 {
return s.findLevelWithFitDomains(levelIdx-1, required, count)
return s.findLevelWithFitDomains(levelIdx-1, required, count, affinityParentIDs)
}

// accumulate the domain states at the highest level which may cross domains.
lastIdx := 0
remainingCount := count - sortedDomain[lastIdx].state
for remainingCount > 0 && lastIdx < len(sortedDomain)-1 && sortedDomain[lastIdx].state > 0 {
Expand All @@ -314,7 +370,7 @@ func (s *TASFlavorSnapshot) findLevelWithFitDomains(levelIdx int, required bool,
return levelIdx, []*domain{topDomain}, ""
}

func (s *TASFlavorSnapshot) updateCountsToMinimum(domains []*domain, count int32) []*domain {
func (s *TASFlavorSnapshot) updateDomainCountsToMinimum(domains []*domain, count int32) []*domain {
result := make([]*domain, 0)
remainingCount := count
for _, domain := range domains {
Expand Down Expand Up @@ -349,15 +405,16 @@ func (s *TASFlavorSnapshot) buildTopologyAssignmentForLevels(domains []*domain,
usage[k] = v * int64(domain.state)
}
s.addUsage(domain.id, usage)
s.affinityDomains = append(s.affinityDomains, domain)
}
return assignment
}

func (s *TASFlavorSnapshot) buildAssignment(domains []*domain, singlePodRequest resources.Requests) *kueue.TopologyAssignment {
// lex sort domains by their levelValues instead of IDs, as leaves' IDs can only contain the hostname
slices.SortFunc(domains, func(a, b *domain) int {
return utilslices.OrderStringSlices(a.levelValues, b.levelValues)
})
// slices.SortFunc(domains, func(a, b *domain) int {
// return utilslices.OrderStringSlices(a.levelValues, b.levelValues)
// })
levelIdx := 0
// assign only hostname values if topology defines it
if s.isLowestLevelNode() {
Expand All @@ -374,10 +431,18 @@ func (s *TASFlavorSnapshot) lowerLevelDomains(domains []*domain) []*domain {
return result
}

func (s *TASFlavorSnapshot) sortedDomains(domains []*domain) []*domain {
func (s *TASFlavorSnapshot) sortedDomains(domains []*domain, affinityParentIDs []utiltas.TopologyDomainID) []*domain {
result := make([]*domain, len(domains))
copy(result, domains)
slices.SortFunc(result, func(a, b *domain) int {
// prefer domains listed in affinityParentIDs that still have a non-zero state.
if a.in(affinityParentIDs) && a.state != 0 && !b.in(affinityParentIDs) {
return -1
}
if !a.in(affinityParentIDs) && b.in(affinityParentIDs) && b.state != 0 {
return 1
}

switch {
case a.state == b.state:
return utilslices.OrderStringSlices(a.levelValues, b.levelValues)
Expand Down

0 comments on commit ba85310

Please sign in to comment.