Skip to content

Commit

Permalink
Allow mutating queue name in StatefulSet Webhook.
Browse files Browse the repository at this point in the history
  • Loading branch information
mbobrovskyi committed Jan 9, 2025
1 parent bf4657a commit fc36f93
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 133 deletions.
34 changes: 14 additions & 20 deletions pkg/controller/jobs/statefulset/statefulset_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,12 @@ func (wh *Webhook) Default(ctx context.Context, obj runtime.Object) error {
ss.Spec.Template.Annotations = make(map[string]string, 1)
}
ss.Spec.Template.Annotations[pod.SuspendedByParentAnnotation] = FrameworkName
if ss.Spec.Template.Labels == nil {
ss.Spec.Template.Labels = make(map[string]string, 2)
}

queueName := jobframework.QueueNameForObject(ss.Object())
if queueName != "" {
if ss.Spec.Template.Labels == nil {
ss.Spec.Template.Labels = make(map[string]string, 2)
}
ss.Spec.Template.Labels[constants.QueueLabel] = queueName
ss.Spec.Template.Labels[pod.GroupNameLabel] = GetWorkloadName(ss.Name)
ss.Spec.Template.Annotations[pod.GroupTotalCountAnnotation] = fmt.Sprint(ptr.Deref(ss.Spec.Replicas, 1))
Expand Down Expand Up @@ -112,12 +113,9 @@ func (wh *Webhook) ValidateCreate(ctx context.Context, obj runtime.Object) (warn
}

var (
labelsPath = field.NewPath("metadata", "labels")
queueNameLabelPath = labelsPath.Key(constants.QueueLabel)
replicasPath = field.NewPath("spec", "replicas")
groupNameLabelPath = labelsPath.Key(pod.GroupNameLabel)
podSpecLabelPath = field.NewPath("spec", "template", "metadata", "labels")
podSpecQueueNameLabelPath = podSpecLabelPath.Key(constants.QueueLabel)
labelsPath = field.NewPath("metadata", "labels")
queueNameLabelPath = labelsPath.Key(constants.QueueLabel)
replicasPath = field.NewPath("spec", "replicas")
)

func (wh *Webhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) (warnings admission.Warnings, err error) {
Expand All @@ -130,17 +128,13 @@ func (wh *Webhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Ob
oldQueueName := jobframework.QueueNameForObject(oldStatefulSet.Object())
newQueueName := jobframework.QueueNameForObject(newStatefulSet.Object())

allErrs := apivalidation.ValidateImmutableField(oldQueueName, newQueueName, queueNameLabelPath)
allErrs = append(allErrs, apivalidation.ValidateImmutableField(
newStatefulSet.Spec.Template.GetLabels()[constants.QueueLabel],
oldStatefulSet.Spec.Template.GetLabels()[constants.QueueLabel],
podSpecQueueNameLabelPath,
)...)
allErrs = append(allErrs, apivalidation.ValidateImmutableField(
newStatefulSet.GetLabels()[pod.GroupNameLabel],
oldStatefulSet.GetLabels()[pod.GroupNameLabel],
groupNameLabelPath,
)...)
allErrs := jobframework.ValidateQueueName(newStatefulSet.Object())

// Prevents changing the queue-name once at least one Pod is ready
// (ReadyReplicas > 0), and prevents removing the queue-name altogether.
if oldStatefulSet.Status.ReadyReplicas > 0 || newQueueName == "" {
allErrs = append(allErrs, apivalidation.ValidateImmutableField(oldQueueName, newQueueName, queueNameLabelPath)...)
}

oldReplicas := ptr.Deref(oldStatefulSet.Spec.Replicas, 1)
newReplicas := ptr.Deref(newStatefulSet.Spec.Replicas, 1)
Expand Down
83 changes: 23 additions & 60 deletions pkg/controller/jobs/statefulset/statefulset_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,76 +244,39 @@ func TestValidateUpdate(t *testing.T) {
wantErr: nil,
},
"change in queue label": {
oldObj: &appsv1.StatefulSet{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
constants.QueueLabel: "queue1",
},
},
},
newObj: &appsv1.StatefulSet{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
constants.QueueLabel: "queue2",
},
},
},
wantErr: field.ErrorList{
&field.Error{
Type: field.ErrorTypeInvalid,
Field: queueNameLabelPath.String(),
},
}.ToAggregate(),
oldObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns").
Queue("test-queue").
Obj(),
newObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns").
Queue("test-queue-new").
Obj(),
},
"change in pod template queue label": {
oldObj: &appsv1.StatefulSet{
Spec: appsv1.StatefulSetSpec{
Template: corev1.PodTemplateSpec{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
constants.QueueLabel: "queue1",
},
},
},
},
},
newObj: &appsv1.StatefulSet{
Spec: appsv1.StatefulSetSpec{
Template: corev1.PodTemplateSpec{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
constants.QueueLabel: "queue2",
},
},
},
},
},
"change in queue label (ReadyReplicas > 0)": {
oldObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns").
Queue("test-queue").
ReadyReplicas(1).
Obj(),
newObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns").
Queue("test-queue-new").
ReadyReplicas(1).
Obj(),
wantErr: field.ErrorList{
&field.Error{
Type: field.ErrorTypeInvalid,
Field: podSpecQueueNameLabelPath.String(),
Field: queueNameLabelPath.String(),
},
}.ToAggregate(),
},
"change in group name label": {
oldObj: &appsv1.StatefulSet{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
pod.GroupNameLabel: "group1",
},
},
},
newObj: &appsv1.StatefulSet{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
pod.GroupNameLabel: "group2",
},
},
},
"delete queue name": {
oldObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns").
Queue("test-queue").
Obj(),
newObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns").
Obj(),
wantErr: field.ErrorList{
&field.Error{
Type: field.ErrorTypeInvalid,
Field: groupNameLabelPath.String(),
Field: queueNameLabelPath.String(),
},
}.ToAggregate(),
},
Expand Down
5 changes: 5 additions & 0 deletions pkg/util/testingjobs/statefulset/wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,11 @@ func (ss *StatefulSetWrapper) Replicas(r int32) *StatefulSetWrapper {
return ss
}

// ReadyReplicas sets Status.ReadyReplicas on the wrapped StatefulSet to r and
// returns the same wrapper so that further builder calls can be chained.
func (ss *StatefulSetWrapper) ReadyReplicas(r int32) *StatefulSetWrapper {
ss.Status.ReadyReplicas = r
return ss
}

func (ss *StatefulSetWrapper) PodTemplateSpecPodGroupNameLabel(
ownerName string, ownerUID types.UID, ownerGVK schema.GroupVersionKind,
) *StatefulSetWrapper {
Expand Down
81 changes: 81 additions & 0 deletions test/e2e/singlecluster/statefulset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package e2e

import (
"fmt"
"github.com/onsi/ginkgo/v2"
"github.com/onsi/gomega"
appsv1 "k8s.io/api/apps/v1"
Expand All @@ -25,6 +26,7 @@ import (
"k8s.io/apimachinery/pkg/types"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/kueue/pkg/controller/constants"

kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
"sigs.k8s.io/kueue/pkg/controller/jobs/statefulset"
Expand Down Expand Up @@ -325,5 +327,84 @@ var _ = ginkgo.Describe("StatefulSet integration", func() {
}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
})
})

// Verifies that the queue-name label of a StatefulSet can be changed while it
// has no ready replicas, and that the StatefulSet can afterwards be scaled
// down to zero and back up again.
ginkgo.It("should allow to change queue name if ReadyReplicas=0", func() {
	// The StatefulSet starts with a queue name derived from localQueueName plus
	// an "-invalid" suffix.
	// NOTE(review): presumably no LocalQueue with that name exists, which is why
	// no replica is expected to become ready — confirm against the suite setup.
	statefulSet := statefulsettesting.MakeStatefulSet("sts", ns.Name).
		Image(util.E2eTestSleepImage, []string{"10m"}).
		Request(corev1.ResourceCPU, "100m").
		Replicas(3).
		Queue(fmt.Sprintf("%s-invalid", localQueueName)).
		Obj()
	wlLookupKey := types.NamespacedName{Name: statefulset.GetWorkloadName(statefulSet.Name), Namespace: ns.Name}

	ginkgo.By("Create StatefulSet", func() {
		gomega.Expect(k8sClient.Create(ctx, statefulSet)).To(gomega.Succeed())
	})

	ginkgo.By("Checking that replicas is not ready", func() {
		// Consistently: ReadyReplicas must stay at zero for the whole window.
		gomega.Consistently(func(g gomega.Gomega) {
			createdStatefulSet := &appsv1.StatefulSet{}
			g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
			g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(0)))
		}, util.ConsistentDuration, util.Interval).Should(gomega.Succeed())
	})

	ginkgo.By("Update queue name", func() {
		// Re-fetch on every attempt so a failed Update is retried against the
		// latest version of the object.
		gomega.Eventually(func(g gomega.Gomega) {
			createdStatefulSet := &appsv1.StatefulSet{}
			g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
			createdStatefulSet.Labels[constants.QueueLabel] = localQueueName
			g.Expect(k8sClient.Update(ctx, createdStatefulSet)).To(gomega.Succeed())
		}, util.Timeout, util.Interval).Should(gomega.Succeed())
	})

	ginkgo.By("Waiting for replicas is ready", func() {
		gomega.Eventually(func(g gomega.Gomega) {
			createdStatefulSet := &appsv1.StatefulSet{}
			g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
			g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(3)))
		}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
	})

	createdWorkload := &kueue.Workload{}
	ginkgo.By("Check workload is created", func() {
		gomega.Expect(k8sClient.Get(ctx, wlLookupKey, createdWorkload)).To(gomega.Succeed())
	})

	ginkgo.By("Scale down replicas to zero", func() {
		gomega.Eventually(func(g gomega.Gomega) {
			createdStatefulSet := &appsv1.StatefulSet{}
			g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
			createdStatefulSet.Spec.Replicas = ptr.To[int32](0)
			g.Expect(k8sClient.Update(ctx, createdStatefulSet)).To(gomega.Succeed())
		}, util.Timeout, util.Interval).Should(gomega.Succeed())
	})

	ginkgo.By("Wait for ReadyReplicas < 3", func() {
		// This step only observes status. Writing the unmodified object back
		// would change nothing and could only fail on a stale read, so no
		// Update is issued here.
		gomega.Eventually(func(g gomega.Gomega) {
			createdStatefulSet := &appsv1.StatefulSet{}
			g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
			g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.BeNumerically("<", 3))
		}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
	})

	ginkgo.By("Scale up replicas to three - retry as it may not be possible immediately", func() {
		// Uses LongTimeout because the Update may be rejected for a while
		// before it is accepted.
		gomega.Eventually(func(g gomega.Gomega) {
			createdStatefulSet := &appsv1.StatefulSet{}
			g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
			createdStatefulSet.Spec.Replicas = ptr.To[int32](3)
			g.Expect(k8sClient.Update(ctx, createdStatefulSet)).To(gomega.Succeed())
		}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
	})

	ginkgo.By("Waiting for replicas is ready", func() {
		gomega.Eventually(func(g gomega.Gomega) {
			createdStatefulSet := &appsv1.StatefulSet{}
			g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed())
			g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(3)))
		}, util.LongTimeout, util.Interval).Should(gomega.Succeed())
	})
})
})
})
61 changes: 8 additions & 53 deletions test/integration/webhook/jobs/statefulset_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/discovery"
"sigs.k8s.io/controller-runtime/pkg/client"

Expand Down Expand Up @@ -73,24 +72,14 @@ var _ = ginkgo.Describe("StatefulSet Webhook", func() {
})

ginkgo.When("The queue-name label is set", func() {
var (
statefulset *appsv1.StatefulSet
lookupKey types.NamespacedName
)

ginkgo.BeforeEach(func() {
statefulset = testingstatefulset.MakeStatefulSet("statefulset-with-queue-name", ns.Name).
Queue("user-queue").
Obj()
lookupKey = client.ObjectKeyFromObject(statefulset)
})

ginkgo.It("Should inject queue name, pod group name to pod template labels, and pod group total count to pod template annotations", func() {
gomega.Expect(k8sClient.Create(ctx, statefulset)).Should(gomega.Succeed())
sts := testingstatefulset.MakeStatefulSet("sts", ns.Name).Queue("user-queue").Obj()

gomega.Expect(k8sClient.Create(ctx, sts)).Should(gomega.Succeed())

gomega.Eventually(func(g gomega.Gomega) {
createdStatefulSet := &appsv1.StatefulSet{}
g.Expect(k8sClient.Get(ctx, lookupKey, createdStatefulSet)).Should(gomega.Succeed())
g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(sts), createdStatefulSet)).Should(gomega.Succeed())
g.Expect(createdStatefulSet.Spec.Template.Labels[constants.QueueLabel]).
To(
gomega.Equal("user-queue"),
Expand All @@ -107,52 +96,18 @@ var _ = ginkgo.Describe("StatefulSet Webhook", func() {
"Pod group total count should be injected to pod template annotations",
)
}, util.Timeout, util.Interval).Should(gomega.Succeed())

ginkgo.By("Updating the statefulset should not allow to change the queue name", func() {
statefulsetToUpdate := &appsv1.StatefulSet{}
gomega.Expect(k8sClient.Get(ctx, lookupKey, statefulsetToUpdate)).Should(gomega.Succeed())
statefulsetWrapper := &testingstatefulset.StatefulSetWrapper{
StatefulSet: *statefulsetToUpdate,
}
updatedStatefulSet := statefulsetWrapper.
Queue("another-queue").
PodTemplateSpecPodGroupNameLabel("another", "another", appsv1.SchemeGroupVersion.WithKind("StatefulSet")).
Obj()
gomega.Expect(k8sClient.Update(ctx, updatedStatefulSet)).To(gomega.HaveOccurred())
})
ginkgo.By("Updating the statefulset should not allow to change replicas", func() {
statefulsetToUpdate := &appsv1.StatefulSet{}
gomega.Expect(k8sClient.Get(ctx, lookupKey, statefulsetToUpdate)).Should(gomega.Succeed())
statefulsetWrapper := &testingstatefulset.StatefulSetWrapper{
StatefulSet: *statefulsetToUpdate,
}
updatedStatefulSet := statefulsetWrapper.
Replicas(5).
PodTemplateSpecPodGroupNameLabel("another", "another", appsv1.SchemeGroupVersion.WithKind("StatefulSet")).
Obj()
gomega.Expect(k8sClient.Update(ctx, updatedStatefulSet)).To(gomega.HaveOccurred())
})
})
})

ginkgo.When("The queue-name label is not set", func() {
var (
statefulset *appsv1.StatefulSet
lookupKey types.NamespacedName
)

ginkgo.BeforeEach(func() {
statefulset = testingstatefulset.MakeStatefulSet("statefulset-without-queue-name", ns.Name).
Obj()
lookupKey = client.ObjectKeyFromObject(statefulset)
})

ginkgo.It("Should not inject queue name to pod template labels", func() {
gomega.Expect(k8sClient.Create(ctx, statefulset)).Should(gomega.Succeed())
sts := testingstatefulset.MakeStatefulSet("sts", ns.Name).Obj()

gomega.Expect(k8sClient.Create(ctx, sts)).Should(gomega.Succeed())

gomega.Eventually(func(g gomega.Gomega) {
createdStatefulSet := &appsv1.StatefulSet{}
g.Expect(k8sClient.Get(ctx, lookupKey, createdStatefulSet)).Should(gomega.Succeed())
g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(sts), createdStatefulSet)).Should(gomega.Succeed())
g.Expect(createdStatefulSet.Spec.Template.Labels[constants.QueueLabel]).
To(
gomega.BeEmpty(),
Expand Down

0 comments on commit fc36f93

Please sign in to comment.