Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add MarkIfFlagPresentThenOthersRequired #2200

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 76 additions & 3 deletions flag_groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ import (
)

const (
requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set"
oneRequiredAnnotation = "cobra_annotation_one_required"
mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive"
requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set"
oneRequiredAnnotation = "cobra_annotation_one_required"
mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive"
ifPresentThenOthersRequiredAnnotation = "cobra_annotation_if_present_then_others_required"
Comment on lines +26 to +29
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you rename them in a way they share a common prefix, not suffix

These constants are defined at the package level, so IDE will suggest them in any method when coding in that package

I would like to see something like this

Suggested change
requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set"
oneRequiredAnnotation = "cobra_annotation_one_required"
mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive"
ifPresentThenOthersRequiredAnnotation = "cobra_annotation_if_present_then_others_required"
annotationGroupRequired = "cobra_annotation_required_if_others_set"
annotationRequiredOne = "cobra_annotation_one_required"
annotationMutuallyExclusive = "cobra_annotation_mutually_exclusive"
annotationGroupDependent = "cobra_annotation_if_present_then_others_required"

The names of the constants are just suggestions to give you an idea

)

// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
Expand Down Expand Up @@ -76,6 +77,25 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
}
}

// MarkIfFlagPresentThenOthersRequired marks the given flags so that if the first flag is set,
// all the other flags become required.
func (c *Command) MarkIfFlagPresentThenOthersRequired(flagNames ...string) {
if len(flagNames) < 2 {
panic("MarkIfFlagPresentThenRequired requires at least two flags")
}
c.mergePersistentFlags()
for _, v := range flagNames {
f := c.Flags().Lookup(v)
if f == nil {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in an if present then others required flag group", v))
}
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this tested? I mean covered by a test that will ensure it will stay like this if there is a future refactoring

if err := c.Flags().SetAnnotation(v, ifPresentThenOthersRequiredAnnotation, append(f.Annotations[ifPresentThenOthersRequiredAnnotation], strings.Join(flagNames, " "))); err != nil {
panic(err)
}
}
}

// ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the
// first error encountered.
func (c *Command) ValidateFlagGroups() error {
Expand All @@ -90,10 +110,12 @@ func (c *Command) ValidateFlagGroups() error {
groupStatus := map[string]map[string]bool{}
oneRequiredGroupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
ifPresentThenOthersRequiredGroupStatus := map[string]map[string]bool{}
flags.VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
processFlagForGroupAnnotation(flags, pflag, ifPresentThenOthersRequiredAnnotation, ifPresentThenOthersRequiredGroupStatus)
})

if err := validateRequiredFlagGroups(groupStatus); err != nil {
Expand All @@ -105,6 +127,9 @@ func (c *Command) ValidateFlagGroups() error {
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
return err
}
if err := validateIfPresentThenRequiredFlagGroups(ifPresentThenOthersRequiredGroupStatus); err != nil {
return err
}
return nil
}

Expand Down Expand Up @@ -206,6 +231,38 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
return nil
}

func validateIfPresentThenRequiredFlagGroups(data map[string]map[string]bool) error {
for flagList, flagnameAndStatus := range data {
flags := strings.Split(flagList, " ")
primaryFlag := flags[0]
remainingFlags := flags[1:]

// Handle missing primary flag entry
if _, exists := flagnameAndStatus[primaryFlag]; !exists {
flagnameAndStatus[primaryFlag] = false
}

// Check if the primary flag is set
if flagnameAndStatus[primaryFlag] {
var unset []string
for _, flag := range remainingFlags {
if !flagnameAndStatus[flag] {
unset = append(unset, flag)
}
}

// If any dependent flags are unset, trigger an error
if len(unset) > 0 {
return fmt.Errorf(
"%v is set, the following flags must be provided: %v",
primaryFlag, unset,
)
}
}
}
return nil
}

func sortedKeys(m map[string]map[string]bool) []string {
keys := make([]string, len(m))
i := 0
Expand All @@ -221,6 +278,7 @@ func sortedKeys(m map[string]map[string]bool) []string {
// - when a flag in a group is present, other flags in the group will be marked required
// - when none of the flags in a one-required group are present, all flags in the group will be marked required
// - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden
// - when the first flag in an if-present-then-required group is present, the other flags will be marked as required
// This allows the standard completion logic to behave appropriately for flag groups
func (c *Command) enforceFlagGroupsForCompletion() {
if c.DisableFlagParsing {
Expand All @@ -231,10 +289,12 @@ func (c *Command) enforceFlagGroupsForCompletion() {
groupStatus := map[string]map[string]bool{}
oneRequiredGroupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
ifPresentThenRequiredGroupStatus := map[string]map[string]bool{}
c.Flags().VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
processFlagForGroupAnnotation(flags, pflag, ifPresentThenOthersRequiredAnnotation, ifPresentThenRequiredGroupStatus)
})

// If a flag that is part of a group is present, we make all the other flags
Expand Down Expand Up @@ -287,4 +347,17 @@ func (c *Command) enforceFlagGroupsForCompletion() {
}
}
}

// If a flag that is marked as if-present-then-required is present, make other flags in the group required
for flagList, flagnameAndStatus := range ifPresentThenRequiredGroupStatus {
flags := strings.Split(flagList, " ")
primaryFlag := flags[0]
remainingFlags := flags[1:]

if flagnameAndStatus[primaryFlag] {
for _, fName := range remainingFlags {
_ = c.MarkFlagRequired(fName)
}
}
}
}
44 changes: 32 additions & 12 deletions flag_groups_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,25 @@ func TestValidateFlagGroups(t *testing.T) {

// Each test case uses a unique command from the function above.
testcases := []struct {
desc string
flagGroupsRequired []string
flagGroupsOneRequired []string
flagGroupsExclusive []string
subCmdFlagGroupsRequired []string
subCmdFlagGroupsOneRequired []string
subCmdFlagGroupsExclusive []string
args []string
expectErr string
desc string
flagGroupsRequired []string
flagGroupsOneRequired []string
flagGroupsExclusive []string
flagGroupsIfPresentThenRequired []string
subCmdFlagGroupsRequired []string
subCmdFlagGroupsOneRequired []string
subCmdFlagGroupsExclusive []string
subCmdFlagGroupsIfPresentThenRequired []string
args []string
expectErr string
}{
{
desc: "No flags no problem",
}, {
desc: "No flags no problem even with conflicting groups",
flagGroupsRequired: []string{"a b"},
flagGroupsExclusive: []string{"a b"},
desc: "No flags no problem even with conflicting groups",
flagGroupsRequired: []string{"a b"},
flagGroupsExclusive: []string{"a b"},
flagGroupsIfPresentThenRequired: []string{"a b", "b a"},
}, {
desc: "Required flag group not satisfied",
flagGroupsRequired: []string{"a b c"},
Expand All @@ -74,6 +77,11 @@ func TestValidateFlagGroups(t *testing.T) {
flagGroupsExclusive: []string{"a b c"},
args: []string{"--a=foo", "--b=foo"},
expectErr: "if any flags in the group [a b c] are set none of the others can be; [a b] were all set",
}, {
desc: "If present then others required flag group not satisfied",
flagGroupsIfPresentThenRequired: []string{"a b"},
args: []string{"--a=foo"},
expectErr: "a is set, the following flags must be provided: [b]",
}, {
desc: "Multiple required flag group not satisfied returns first error",
flagGroupsRequired: []string{"a b c", "a d"},
Expand All @@ -89,6 +97,12 @@ func TestValidateFlagGroups(t *testing.T) {
flagGroupsExclusive: []string{"a b c", "a d"},
args: []string{"--a=foo", "--c=foo", "--d=foo"},
expectErr: `if any flags in the group [a b c] are set none of the others can be; [a c] were all set`,
},
{
desc: "Multiple if present then others required flag group not satisfied returns first error",
flagGroupsIfPresentThenRequired: []string{"a b", "d e"},
args: []string{"--a=foo", "--f=foo"},
expectErr: `a is set, the following flags must be provided: [b]`,
}, {
desc: "Validation of required groups occurs on groups in sorted order",
flagGroupsRequired: []string{"a d", "a b", "a c"},
Expand Down Expand Up @@ -182,6 +196,12 @@ func TestValidateFlagGroups(t *testing.T) {
for _, flagGroup := range tc.subCmdFlagGroupsExclusive {
sub.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...)
}
for _, flagGroup := range tc.flagGroupsIfPresentThenRequired {
c.MarkIfFlagPresentThenOthersRequired(strings.Split(flagGroup, " ")...)
}
for _, flagGroup := range tc.subCmdFlagGroupsIfPresentThenRequired {
sub.MarkIfFlagPresentThenOthersRequired(strings.Split(flagGroup, " ")...)
}
c.SetArgs(tc.args)
err := c.Execute()
switch {
Expand Down
Loading