Skip to content

Commit

Permalink
Allow mutating queue name in StatefulSet Webhook.
Browse files Browse the repository at this point in the history
  • Loading branch information
mbobrovskyi committed Jan 3, 2025
1 parent 62e94f7 commit 151aa53
Show file tree
Hide file tree
Showing 6 changed files with 285 additions and 131 deletions.
26 changes: 1 addition & 25 deletions pkg/controller/jobs/pod/pod_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ package pod
import (
"cmp"
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"slices"
Expand Down Expand Up @@ -563,29 +561,7 @@ func getRoleHash(p corev1.Pod) (string, error) {
if roleHash, ok := p.Annotations[RoleHashAnnotation]; ok {
return roleHash, nil
}

shape := map[string]interface{}{
"spec": map[string]interface{}{
"initContainers": containersShape(p.Spec.InitContainers),
"containers": containersShape(p.Spec.Containers),
"nodeSelector": p.Spec.NodeSelector,
"affinity": p.Spec.Affinity,
"tolerations": p.Spec.Tolerations,
"runtimeClassName": p.Spec.RuntimeClassName,
"priority": p.Spec.Priority,
"topologySpreadConstraints": p.Spec.TopologySpreadConstraints,
"overhead": p.Spec.Overhead,
"resourceClaims": p.Spec.ResourceClaims,
},
}

shapeJSON, err := json.Marshal(shape)
if err != nil {
return "", err
}

// Trim hash to 8 characters and return
return fmt.Sprintf("%x", sha256.Sum256(shapeJSON))[:8], nil
return utilpod.GenerateShape(p.Spec)
}

// Load loads all pods in the group
Expand Down
13 changes: 0 additions & 13 deletions pkg/controller/jobs/pod/pod_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,19 +119,6 @@ func getPodOptions(integrationOpts map[string]any) (*configapi.PodIntegrationOpt

var _ admission.CustomDefaulter = &PodWebhook{}

func containersShape(containers []corev1.Container) (result []map[string]interface{}) {
for _, c := range containers {
result = append(result, map[string]interface{}{
"resources": map[string]interface{}{
"requests": c.Resources.Requests,
},
"ports": c.Ports,
})
}

return result
}

// addRoleHash calculates the role hash and adds it to the pod's annotations
func (p *Pod) addRoleHash() error {
if p.pod.Annotations == nil {
Expand Down
2 changes: 1 addition & 1 deletion pkg/controller/jobs/statefulset/statefulset_reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (r *Reconciler) Reconcile(ctx context.Context, req reconcile.Request) (reco
func (r *Reconciler) fetchAndFinalizePods(ctx context.Context, namespace, statefulSetName string) error {
podList := &corev1.PodList{}
if err := r.client.List(ctx, podList, client.InNamespace(namespace), client.MatchingLabels{
pod.GroupNameLabel: GetWorkloadName(statefulSetName),
StatefulSetNameLabel: statefulSetName,
}); err != nil {
return err
}
Expand Down
85 changes: 40 additions & 45 deletions pkg/controller/jobs/statefulset/statefulset_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ import (
"sigs.k8s.io/kueue/pkg/controller/jobframework"
"sigs.k8s.io/kueue/pkg/controller/jobs/pod"
"sigs.k8s.io/kueue/pkg/queue"
utilpod "sigs.k8s.io/kueue/pkg/util/pod"
)

const (
StatefulSetNameLabel = "kueue.x-k8s.io/statefulset-name"
)

type Webhook struct {
Expand Down Expand Up @@ -71,27 +76,33 @@ func (wh *Webhook) Default(ctx context.Context, obj runtime.Object) error {

jobframework.ApplyDefaultLocalQueue(ss.Object(), wh.queues.DefaultLocalQueueExist)
suspend, err := jobframework.WorkloadShouldBeSuspended(ctx, ss.Object(), wh.client, wh.manageJobsWithoutQueueName, wh.managedJobsNamespaceSelector)
if err != nil || !suspend {
return err
}

queueName := jobframework.QueueNameForObject(ss.Object())
if queueName == "" {
return nil
}

if ss.Spec.Template.Labels == nil {
ss.Spec.Template.Labels = make(map[string]string, 3)
}
ss.Spec.Template.Labels[StatefulSetNameLabel] = ss.Name
ss.Spec.Template.Labels[constants.QueueLabel] = queueName
groupName, err := GetWorkloadName(obj.(*appsv1.StatefulSet))
if err != nil {
return err
}
if suspend {
if ss.Spec.Template.Annotations == nil {
ss.Spec.Template.Annotations = make(map[string]string, 1)
}
ss.Spec.Template.Annotations[pod.SuspendedByParentAnnotation] = FrameworkName
if ss.Spec.Template.Labels == nil {
ss.Spec.Template.Labels = make(map[string]string, 2)
}
queueName := jobframework.QueueNameForObject(ss.Object())
if queueName != "" {
ss.Spec.Template.Labels[constants.QueueLabel] = queueName
ss.Spec.Template.Labels[pod.GroupNameLabel] = GetWorkloadName(ss.Name)
ss.Spec.Template.Annotations[pod.GroupTotalCountAnnotation] = fmt.Sprint(ptr.Deref(ss.Spec.Replicas, 1))
ss.Spec.Template.Annotations[pod.GroupFastAdmissionAnnotation] = "true"
ss.Spec.Template.Annotations[pod.GroupServingAnnotation] = "true"
ss.Spec.Template.Annotations[kueuealpha.PodGroupPodIndexLabelAnnotation] = appsv1.PodIndexLabel
}
ss.Spec.Template.Labels[pod.GroupNameLabel] = groupName

if ss.Spec.Template.Annotations == nil {
ss.Spec.Template.Annotations = make(map[string]string, 4)
}
ss.Spec.Template.Annotations[pod.GroupTotalCountAnnotation] = fmt.Sprint(ptr.Deref(ss.Spec.Replicas, 1))
ss.Spec.Template.Annotations[pod.GroupFastAdmissionAnnotation] = "true"
ss.Spec.Template.Annotations[pod.GroupServingAnnotation] = "true"
ss.Spec.Template.Annotations[kueuealpha.PodGroupPodIndexLabelAnnotation] = appsv1.PodIndexLabel

return nil
}
Expand Down Expand Up @@ -130,33 +141,12 @@ func (wh *Webhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Ob
oldQueueName := jobframework.QueueNameForObject(oldStatefulSet.Object())
newQueueName := jobframework.QueueNameForObject(newStatefulSet.Object())

allErrs := apivalidation.ValidateImmutableField(oldQueueName, newQueueName, queueNameLabelPath)
allErrs = append(allErrs, apivalidation.ValidateImmutableField(
newStatefulSet.Spec.Template.GetLabels()[constants.QueueLabel],
oldStatefulSet.Spec.Template.GetLabels()[constants.QueueLabel],
podSpecQueueNameLabelPath,
)...)
allErrs = append(allErrs, apivalidation.ValidateImmutableField(
newStatefulSet.GetLabels()[pod.GroupNameLabel],
oldStatefulSet.GetLabels()[pod.GroupNameLabel],
groupNameLabelPath,
)...)

oldReplicas := ptr.Deref(oldStatefulSet.Spec.Replicas, 1)
newReplicas := ptr.Deref(newStatefulSet.Spec.Replicas, 1)

// Allow only scale down to zero and scale up from zero.
// TODO(#3279): Support custom resizes later
if newReplicas != 0 && oldReplicas != 0 {
allErrs = append(allErrs, apivalidation.ValidateImmutableField(
newStatefulSet.Spec.Replicas,
oldStatefulSet.Spec.Replicas,
replicasPath,
)...)
}
allErrs := jobframework.ValidateQueueName(newStatefulSet.Object())

if oldReplicas == 0 && newReplicas > 0 && newStatefulSet.Status.Replicas > 0 {
allErrs = append(allErrs, field.Forbidden(replicasPath, "scaling down is still in progress"))
// Prevents updating the queue-name if at least one Pod is not suspended
// or if the queue-name has been deleted.
if oldStatefulSet.Status.ReadyReplicas > 0 || newQueueName == "" {
allErrs = append(allErrs, apivalidation.ValidateImmutableField(oldQueueName, newQueueName, queueNameLabelPath)...)
}

return warnings, allErrs.ToAggregate()
Expand All @@ -166,7 +156,12 @@ func (wh *Webhook) ValidateDelete(context.Context, runtime.Object) (warnings adm
return nil, nil
}

func GetWorkloadName(statefulSetName string) string {
func GetWorkloadName(sts *appsv1.StatefulSet) (string, error) {
shape, err := utilpod.GenerateShape(sts.Spec.Template.Spec)
if err != nil {
return "", err
}
ownerName := fmt.Sprintf("%s-%s", sts.Name, shape)
// Passing empty UID as it is not available before object creation
return jobframework.GetWorkloadNameForOwnerWithGVK(statefulSetName, "", gvk)
return jobframework.GetWorkloadNameForOwnerWithGVK(ownerName, "", gvk), nil
}
39 changes: 39 additions & 0 deletions pkg/util/pod/pod.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.
package pod

import (
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"math"
Expand Down Expand Up @@ -104,3 +106,40 @@ func readUIntFromStringBelowBound(value string, bound int) (*int, error) {
}
return ptr.To(int(uintValue)), nil
}

func GenerateShape(podSpec corev1.PodSpec) (string, error) {
shape := map[string]interface{}{
"spec": map[string]interface{}{
"initContainers": containersShape(podSpec.InitContainers),
"containers": containersShape(podSpec.Containers),
"nodeSelector": podSpec.NodeSelector,
"affinity": podSpec.Affinity,
"tolerations": podSpec.Tolerations,
"runtimeClassName": podSpec.RuntimeClassName,
"priority": podSpec.Priority,
"topologySpreadConstraints": podSpec.TopologySpreadConstraints,
"overhead": podSpec.Overhead,
"resourceClaims": podSpec.ResourceClaims,
},
}

shapeJSON, err := json.Marshal(shape)
if err != nil {
return "", err
}

// Trim hash to 8 characters and return
return fmt.Sprintf("%x", sha256.Sum256(shapeJSON))[:8], nil
}

func containersShape(containers []corev1.Container) (result []map[string]interface{}) {
for _, c := range containers {
result = append(result, map[string]interface{}{
"resources": map[string]interface{}{
"requests": c.Resources.Requests,
},
"ports": c.Ports,
})
}
return result
}
Loading

0 comments on commit 151aa53

Please sign in to comment.