From fc36f933dedad2a7596de81782e3db3a1c13762a Mon Sep 17 00:00:00 2001 From: Mykhailo Bobrovskyi Date: Wed, 13 Nov 2024 11:18:16 +0200 Subject: [PATCH] Allow mutating queue name in StatefulSet Webhook. --- .../jobs/statefulset/statefulset_webhook.go | 34 ++++---- .../statefulset/statefulset_webhook_test.go | 83 +++++-------------- pkg/util/testingjobs/statefulset/wrappers.go | 5 ++ test/e2e/singlecluster/statefulset_test.go | 81 ++++++++++++++++++ .../webhook/jobs/statefulset_webhook_test.go | 61 ++------------ 5 files changed, 131 insertions(+), 133 deletions(-) diff --git a/pkg/controller/jobs/statefulset/statefulset_webhook.go b/pkg/controller/jobs/statefulset/statefulset_webhook.go index b4cea35935..afdb8ae0f2 100644 --- a/pkg/controller/jobs/statefulset/statefulset_webhook.go +++ b/pkg/controller/jobs/statefulset/statefulset_webhook.go @@ -79,11 +79,12 @@ func (wh *Webhook) Default(ctx context.Context, obj runtime.Object) error { ss.Spec.Template.Annotations = make(map[string]string, 1) } ss.Spec.Template.Annotations[pod.SuspendedByParentAnnotation] = FrameworkName - if ss.Spec.Template.Labels == nil { - ss.Spec.Template.Labels = make(map[string]string, 2) - } + queueName := jobframework.QueueNameForObject(ss.Object()) if queueName != "" { + if ss.Spec.Template.Labels == nil { + ss.Spec.Template.Labels = make(map[string]string, 2) + } ss.Spec.Template.Labels[constants.QueueLabel] = queueName ss.Spec.Template.Labels[pod.GroupNameLabel] = GetWorkloadName(ss.Name) ss.Spec.Template.Annotations[pod.GroupTotalCountAnnotation] = fmt.Sprint(ptr.Deref(ss.Spec.Replicas, 1)) @@ -112,12 +113,9 @@ func (wh *Webhook) ValidateCreate(ctx context.Context, obj runtime.Object) (warn } var ( - labelsPath = field.NewPath("metadata", "labels") - queueNameLabelPath = labelsPath.Key(constants.QueueLabel) - replicasPath = field.NewPath("spec", "replicas") - groupNameLabelPath = labelsPath.Key(pod.GroupNameLabel) - podSpecLabelPath = field.NewPath("spec", "template", "metadata", "labels") - podSpecQueueNameLabelPath = podSpecLabelPath.Key(constants.QueueLabel) + labelsPath = field.NewPath("metadata", "labels") + queueNameLabelPath = labelsPath.Key(constants.QueueLabel) + replicasPath = field.NewPath("spec", "replicas") ) func (wh *Webhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) (warnings admission.Warnings, err error) { @@ -130,17 +128,13 @@ func (wh *Webhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Ob oldQueueName := jobframework.QueueNameForObject(oldStatefulSet.Object()) newQueueName := jobframework.QueueNameForObject(newStatefulSet.Object()) - allErrs := apivalidation.ValidateImmutableField(oldQueueName, newQueueName, queueNameLabelPath) - allErrs = append(allErrs, apivalidation.ValidateImmutableField( - newStatefulSet.Spec.Template.GetLabels()[constants.QueueLabel], - oldStatefulSet.Spec.Template.GetLabels()[constants.QueueLabel], - podSpecQueueNameLabelPath, - )...) - allErrs = append(allErrs, apivalidation.ValidateImmutableField( - newStatefulSet.GetLabels()[pod.GroupNameLabel], - oldStatefulSet.GetLabels()[pod.GroupNameLabel], - groupNameLabelPath, - )...) + allErrs := jobframework.ValidateQueueName(newStatefulSet.Object()) + + // Prevents updating the queue-name if at least one Pod is not suspended + // or if the queue-name has been deleted. + if oldStatefulSet.Status.ReadyReplicas > 0 || newQueueName == "" { + allErrs = append(allErrs, apivalidation.ValidateImmutableField(oldQueueName, newQueueName, queueNameLabelPath)...) + } oldReplicas := ptr.Deref(oldStatefulSet.Spec.Replicas, 1) newReplicas := ptr.Deref(newStatefulSet.Spec.Replicas, 1) diff --git a/pkg/controller/jobs/statefulset/statefulset_webhook_test.go b/pkg/controller/jobs/statefulset/statefulset_webhook_test.go index e400623bb9..4fedfb22c2 100644 --- a/pkg/controller/jobs/statefulset/statefulset_webhook_test.go +++ b/pkg/controller/jobs/statefulset/statefulset_webhook_test.go @@ -244,76 +244,39 @@ func TestValidateUpdate(t *testing.T) { wantErr: nil, }, "change in queue label": { - oldObj: &appsv1.StatefulSet{ - ObjectMeta: metav1.ObjectMeta{ - Labels: map[string]string{ - constants.QueueLabel: "queue1", - }, - }, - }, - newObj: &appsv1.StatefulSet{ - ObjectMeta: metav1.ObjectMeta{ - Labels: map[string]string{ - constants.QueueLabel: "queue2", - }, - }, - }, - wantErr: field.ErrorList{ - &field.Error{ - Type: field.ErrorTypeInvalid, - Field: queueNameLabelPath.String(), - }, - }.ToAggregate(), + oldObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns"). + Queue("test-queue"). + Obj(), + newObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns"). + Queue("test-queue-new"). + Obj(), }, - "change in pod template queue label": { - oldObj: &appsv1.StatefulSet{ - Spec: appsv1.StatefulSetSpec{ - Template: corev1.PodTemplateSpec{ - ObjectMeta: metav1.ObjectMeta{ - Labels: map[string]string{ - constants.QueueLabel: "queue1", - }, - }, - }, - }, - }, - newObj: &appsv1.StatefulSet{ - Spec: appsv1.StatefulSetSpec{ - Template: corev1.PodTemplateSpec{ - ObjectMeta: metav1.ObjectMeta{ - Labels: map[string]string{ - constants.QueueLabel: "queue2", - }, - }, - }, - }, - }, + "change in queue label (ReadyReplicas > 0)": { + oldObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns"). + Queue("test-queue"). + ReadyReplicas(1). + Obj(), + newObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns"). + Queue("test-queue-new"). + ReadyReplicas(1). + Obj(), wantErr: field.ErrorList{ &field.Error{ Type: field.ErrorTypeInvalid, - Field: podSpecQueueNameLabelPath.String(), + Field: queueNameLabelPath.String(), }, }.ToAggregate(), }, - "change in group name label": { - oldObj: &appsv1.StatefulSet{ - ObjectMeta: metav1.ObjectMeta{ - Labels: map[string]string{ - pod.GroupNameLabel: "group1", - }, - }, - }, - newObj: &appsv1.StatefulSet{ - ObjectMeta: metav1.ObjectMeta{ - Labels: map[string]string{ - pod.GroupNameLabel: "group2", - }, - }, - }, + "delete queue name": { + oldObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns"). + Queue("test-queue"). + Obj(), + newObj: testingstatefulset.MakeStatefulSet("test-sts", "test-ns"). + Obj(), wantErr: field.ErrorList{ &field.Error{ Type: field.ErrorTypeInvalid, - Field: groupNameLabelPath.String(), + Field: queueNameLabelPath.String(), }, }.ToAggregate(), }, diff --git a/pkg/util/testingjobs/statefulset/wrappers.go b/pkg/util/testingjobs/statefulset/wrappers.go index fe0c4214fd..bede2faa88 100644 --- a/pkg/util/testingjobs/statefulset/wrappers.go +++ b/pkg/util/testingjobs/statefulset/wrappers.go @@ -134,6 +134,11 @@ func (ss *StatefulSetWrapper) Replicas(r int32) *StatefulSetWrapper { return ss } +func (ss *StatefulSetWrapper) ReadyReplicas(r int32) *StatefulSetWrapper { + ss.Status.ReadyReplicas = r + return ss +} + func (ss *StatefulSetWrapper) PodTemplateSpecPodGroupNameLabel( ownerName string, ownerUID types.UID, ownerGVK schema.GroupVersionKind, ) *StatefulSetWrapper { diff --git a/test/e2e/singlecluster/statefulset_test.go b/test/e2e/singlecluster/statefulset_test.go index 44697adc85..8105c8a30b 100644 --- a/test/e2e/singlecluster/statefulset_test.go +++ b/test/e2e/singlecluster/statefulset_test.go @@ -17,6 +17,7 @@ limitations under the License. package e2e import ( + "fmt" "github.com/onsi/ginkgo/v2" "github.com/onsi/gomega" appsv1 "k8s.io/api/apps/v1" @@ -25,6 +26,7 @@ import ( "k8s.io/apimachinery/pkg/types" "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/kueue/pkg/controller/constants" kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" "sigs.k8s.io/kueue/pkg/controller/jobs/statefulset" @@ -325,5 +327,84 @@ var _ = ginkgo.Describe("StatefulSet integration", func() { }, util.LongTimeout, util.Interval).Should(gomega.Succeed()) }) }) + + ginkgo.It("should allow to change queue name if ReadyReplicas=0", func() { + statefulSet := statefulsettesting.MakeStatefulSet("sts", ns.Name). + Image(util.E2eTestSleepImage, []string{"10m"}). + Request(corev1.ResourceCPU, "100m"). + Replicas(3). + Queue(fmt.Sprintf("%s-invalid", localQueueName)). + Obj() + wlLookupKey := types.NamespacedName{Name: statefulset.GetWorkloadName(statefulSet.Name), Namespace: ns.Name} + + ginkgo.By("Create StatefulSet", func() { + gomega.Expect(k8sClient.Create(ctx, statefulSet)).To(gomega.Succeed()) + }) + + ginkgo.By("Checking that replicas is not ready", func() { + gomega.Consistently(func(g gomega.Gomega) { + createdStatefulSet := &appsv1.StatefulSet{} + g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed()) + g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(0))) + }, util.ConsistentDuration, util.Interval).Should(gomega.Succeed()) + }) + + ginkgo.By("Update queue name", func() { + gomega.Eventually(func(g gomega.Gomega) { + createdStatefulSet := &appsv1.StatefulSet{} + g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed()) + createdStatefulSet.Labels[constants.QueueLabel] = localQueueName + g.Expect(k8sClient.Update(ctx, createdStatefulSet)).To(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + }) + + ginkgo.By("Waiting for replicas is ready", func() { + gomega.Eventually(func(g gomega.Gomega) { + createdStatefulSet := &appsv1.StatefulSet{} + g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed()) + g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(3))) + }, util.LongTimeout, util.Interval).Should(gomega.Succeed()) + }) + + createdWorkload := &kueue.Workload{} + ginkgo.By("Check workload is created", func() { + gomega.Expect(k8sClient.Get(ctx, wlLookupKey, createdWorkload)).To(gomega.Succeed()) + }) + + ginkgo.By("Scale down replicas to zero", func() { + gomega.Eventually(func(g gomega.Gomega) { + createdStatefulSet := &appsv1.StatefulSet{} + g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed()) + createdStatefulSet.Spec.Replicas = ptr.To[int32](0) + g.Expect(k8sClient.Update(ctx, createdStatefulSet)).To(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + }) + + ginkgo.By("Wait for ReadyReplicas < 3", func() { + gomega.Eventually(func(g gomega.Gomega) { + createdStatefulSet := &appsv1.StatefulSet{} + g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed()) + g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.BeNumerically("<", 3)) + g.Expect(k8sClient.Update(ctx, createdStatefulSet)).To(gomega.Succeed()) + }, util.LongTimeout, util.Interval).Should(gomega.Succeed()) + }) + + ginkgo.By("Scale up replicas to zero - retry as it may not be possible immediately", func() { + gomega.Eventually(func(g gomega.Gomega) { + createdStatefulSet := &appsv1.StatefulSet{} + g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed()) + createdStatefulSet.Spec.Replicas = ptr.To[int32](3) + g.Expect(k8sClient.Update(ctx, createdStatefulSet)).To(gomega.Succeed()) + }, util.LongTimeout, util.Interval).Should(gomega.Succeed()) + }) + + ginkgo.By("Waiting for replicas is ready", func() { + gomega.Eventually(func(g gomega.Gomega) { + createdStatefulSet := &appsv1.StatefulSet{} + g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(statefulSet), createdStatefulSet)).To(gomega.Succeed()) + g.Expect(createdStatefulSet.Status.ReadyReplicas).To(gomega.Equal(int32(3))) + }, util.LongTimeout, util.Interval).Should(gomega.Succeed()) + }) + }) }) }) diff --git a/test/integration/webhook/jobs/statefulset_webhook_test.go b/test/integration/webhook/jobs/statefulset_webhook_test.go index 5a3c9e4d2a..73d49ac2e3 100644 --- a/test/integration/webhook/jobs/statefulset_webhook_test.go +++ b/test/integration/webhook/jobs/statefulset_webhook_test.go @@ -22,7 +22,6 @@ import ( appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/discovery" "sigs.k8s.io/controller-runtime/pkg/client" @@ -73,24 +72,14 @@ var _ = ginkgo.Describe("StatefulSet Webhook", func() { }) ginkgo.When("The queue-name label is set", func() { - var ( - statefulset *appsv1.StatefulSet - lookupKey types.NamespacedName - ) - - ginkgo.BeforeEach(func() { - statefulset = testingstatefulset.MakeStatefulSet("statefulset-with-queue-name", ns.Name). - Queue("user-queue"). - Obj() - lookupKey = client.ObjectKeyFromObject(statefulset) - }) - ginkgo.It("Should inject queue name, pod group name to pod template labels, and pod group total count to pod template annotations", func() { - gomega.Expect(k8sClient.Create(ctx, statefulset)).Should(gomega.Succeed()) + sts := testingstatefulset.MakeStatefulSet("sts", ns.Name).Queue("user-queue").Obj() + + gomega.Expect(k8sClient.Create(ctx, sts)).Should(gomega.Succeed()) gomega.Eventually(func(g gomega.Gomega) { createdStatefulSet := &appsv1.StatefulSet{} - g.Expect(k8sClient.Get(ctx, lookupKey, createdStatefulSet)).Should(gomega.Succeed()) + g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(sts), createdStatefulSet)).Should(gomega.Succeed()) g.Expect(createdStatefulSet.Spec.Template.Labels[constants.QueueLabel]). To( gomega.Equal("user-queue"), @@ -107,52 +96,18 @@ var _ = ginkgo.Describe("StatefulSet Webhook", func() { "Pod group total count should be injected to pod template annotations", ) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - - ginkgo.By("Updating the statefulset should not allow to change the queue name", func() { - statefulsetToUpdate := &appsv1.StatefulSet{} - gomega.Expect(k8sClient.Get(ctx, lookupKey, statefulsetToUpdate)).Should(gomega.Succeed()) - statefulsetWrapper := &testingstatefulset.StatefulSetWrapper{ - StatefulSet: *statefulsetToUpdate, - } - updatedStatefulSet := statefulsetWrapper. - Queue("another-queue"). - PodTemplateSpecPodGroupNameLabel("another", "another", appsv1.SchemeGroupVersion.WithKind("StatefulSet")). - Obj() - gomega.Expect(k8sClient.Update(ctx, updatedStatefulSet)).To(gomega.HaveOccurred()) - }) - ginkgo.By("Updating the statefulset should not allow to change replicas", func() { - statefulsetToUpdate := &appsv1.StatefulSet{} - gomega.Expect(k8sClient.Get(ctx, lookupKey, statefulsetToUpdate)).Should(gomega.Succeed()) - statefulsetWrapper := &testingstatefulset.StatefulSetWrapper{ - StatefulSet: *statefulsetToUpdate, - } - updatedStatefulSet := statefulsetWrapper. - Replicas(5). - PodTemplateSpecPodGroupNameLabel("another", "another", appsv1.SchemeGroupVersion.WithKind("StatefulSet")). - Obj() - gomega.Expect(k8sClient.Update(ctx, updatedStatefulSet)).To(gomega.HaveOccurred()) - }) }) }) ginkgo.When("The queue-name label is not set", func() { - var ( - statefulset *appsv1.StatefulSet - lookupKey types.NamespacedName - ) - - ginkgo.BeforeEach(func() { - statefulset = testingstatefulset.MakeStatefulSet("statefulset-without-queue-name", ns.Name). - Obj() - lookupKey = client.ObjectKeyFromObject(statefulset) - }) - ginkgo.It("Should not inject queue name to pod template labels", func() { - gomega.Expect(k8sClient.Create(ctx, statefulset)).Should(gomega.Succeed()) + sts := testingstatefulset.MakeStatefulSet("sts", ns.Name).Obj() + + gomega.Expect(k8sClient.Create(ctx, sts)).Should(gomega.Succeed()) gomega.Eventually(func(g gomega.Gomega) { createdStatefulSet := &appsv1.StatefulSet{} - g.Expect(k8sClient.Get(ctx, lookupKey, createdStatefulSet)).Should(gomega.Succeed()) + g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(sts), createdStatefulSet)).Should(gomega.Succeed()) g.Expect(createdStatefulSet.Spec.Template.Labels[constants.QueueLabel]). To( gomega.BeEmpty(),