Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add MarkIfFlagPresentThenOthersRequired #2200

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 76 additions & 3 deletions flag_groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ import (
)

const (
requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set"
oneRequiredAnnotation = "cobra_annotation_one_required"
mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive"
requiredAsGroup = "cobra_annotation_required_if_others_set"
oneRequired = "cobra_annotation_one_required"
mutuallyExclusive = "cobra_annotation_mutually_exclusive"
ifPresentThenOthersRequired = "cobra_annotation_if_present_then_others_required"
)

// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
Expand Down Expand Up @@ -76,6 +77,25 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
}
}

// MarkIfFlagPresentThenOthersRequired marks the given flags so that if the first flag is set,
// all the other flags become required.
func (c *Command) MarkIfFlagPresentThenOthersRequired(flagNames ...string) {
if len(flagNames) < 2 {
panic("MarkIfFlagPresentThenRequired requires at least two flags")
}
c.mergePersistentFlags()
for _, v := range flagNames {
f := c.Flags().Lookup(v)
if f == nil {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in an if present then others required flag group", v))
}
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
if err := c.Flags().SetAnnotation(v, ifPresentThenOthersRequired, append(f.Annotations[ifPresentThenOthersRequired], strings.Join(flagNames, " "))); err != nil {
panic(err)
}
}
}

// ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the
// first error encountered.
func (c *Command) ValidateFlagGroups() error {
Expand All @@ -90,10 +110,12 @@ func (c *Command) ValidateFlagGroups() error {
groupStatus := map[string]map[string]bool{}
oneRequiredGroupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
ifPresentThenOthersRequiredGroupStatus := map[string]map[string]bool{}
flags.VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
processFlagForGroupAnnotation(flags, pflag, ifPresentThenOthersRequired, ifPresentThenOthersRequiredGroupStatus)
})

if err := validateRequiredFlagGroups(groupStatus); err != nil {
Expand All @@ -105,6 +127,9 @@ func (c *Command) ValidateFlagGroups() error {
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
return err
}
if err := validateIfPresentThenRequiredFlagGroups(ifPresentThenOthersRequiredGroupStatus); err != nil {
return err
}
return nil
}

Expand Down Expand Up @@ -206,6 +231,38 @@ func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
return nil
}

func validateIfPresentThenRequiredFlagGroups(data map[string]map[string]bool) error {
for flagList, flagnameAndStatus := range data {
flags := strings.Split(flagList, " ")
primaryFlag := flags[0]
remainingFlags := flags[1:]

// Handle missing primary flag entry
if _, exists := flagnameAndStatus[primaryFlag]; !exists {
flagnameAndStatus[primaryFlag] = false
}

// Check if the primary flag is set
if flagnameAndStatus[primaryFlag] {
var unset []string
for _, flag := range remainingFlags {
if !flagnameAndStatus[flag] {
unset = append(unset, flag)
}
}

// If any dependent flags are unset, trigger an error
if len(unset) > 0 {
return fmt.Errorf(
"if the first flag in the group [%v] is set, all other flags must be set; the following flags are not set: %v",
flagList, unset,
)
}
}
}
return nil
}

func sortedKeys(m map[string]map[string]bool) []string {
keys := make([]string, len(m))
i := 0
Expand All @@ -221,6 +278,7 @@ func sortedKeys(m map[string]map[string]bool) []string {
// - when a flag in a group is present, other flags in the group will be marked required
// - when none of the flags in a one-required group are present, all flags in the group will be marked required
// - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden
// - when the first flag in an if-present-then-required group is present, the second flag will be marked as required
// This allows the standard completion logic to behave appropriately for flag groups
func (c *Command) enforceFlagGroupsForCompletion() {
if c.DisableFlagParsing {
Expand All @@ -231,10 +289,12 @@ func (c *Command) enforceFlagGroupsForCompletion() {
groupStatus := map[string]map[string]bool{}
oneRequiredGroupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
ifPresentThenRequiredGroupStatus := map[string]map[string]bool{}
c.Flags().VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
processFlagForGroupAnnotation(flags, pflag, ifPresentThenOthersRequired, ifPresentThenRequiredGroupStatus)
})

// If a flag that is part of a group is present, we make all the other flags
Expand Down Expand Up @@ -287,4 +347,17 @@ func (c *Command) enforceFlagGroupsForCompletion() {
}
}
}

// If a flag that is marked as if-present-then-required is present, make other flags in the group required
for flagList, flagnameAndStatus := range ifPresentThenRequiredGroupStatus {
flags := strings.Split(flagList, " ")
primaryFlag := flags[0]
remainingFlags := flags[1:]

if flagnameAndStatus[primaryFlag] {
for _, fName := range remainingFlags {
_ = c.MarkFlagRequired(fName)
}
}
}
}
44 changes: 32 additions & 12 deletions flag_groups_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,25 @@ func TestValidateFlagGroups(t *testing.T) {

// Each test case uses a unique command from the function above.
testcases := []struct {
desc string
flagGroupsRequired []string
flagGroupsOneRequired []string
flagGroupsExclusive []string
subCmdFlagGroupsRequired []string
subCmdFlagGroupsOneRequired []string
subCmdFlagGroupsExclusive []string
args []string
expectErr string
desc string
flagGroupsRequired []string
flagGroupsOneRequired []string
flagGroupsExclusive []string
flagGroupsIfPresentThenRequired []string
subCmdFlagGroupsRequired []string
subCmdFlagGroupsOneRequired []string
subCmdFlagGroupsExclusive []string
subCmdFlagGroupsIfPresentThenRequired []string
args []string
expectErr string
}{
{
desc: "No flags no problem",
}, {
desc: "No flags no problem even with conflicting groups",
flagGroupsRequired: []string{"a b"},
flagGroupsExclusive: []string{"a b"},
desc: "No flags no problem even with conflicting groups",
flagGroupsRequired: []string{"a b"},
flagGroupsExclusive: []string{"a b"},
flagGroupsIfPresentThenRequired: []string{"a b"},
}, {
desc: "Required flag group not satisfied",
flagGroupsRequired: []string{"a b c"},
Expand All @@ -74,6 +77,11 @@ func TestValidateFlagGroups(t *testing.T) {
flagGroupsExclusive: []string{"a b c"},
args: []string{"--a=foo", "--b=foo"},
expectErr: "if any flags in the group [a b c] are set none of the others can be; [a b] were all set",
}, {
desc: "If present then others required flag group not satisfied",
flagGroupsIfPresentThenRequired: []string{"a b"},
args: []string{"--a=foo"},
expectErr: "if the first flag in the group [a b] is set, all other flags must be set; the following flags are not set: [b]",
}, {
desc: "Multiple required flag group not satisfied returns first error",
flagGroupsRequired: []string{"a b c", "a d"},
Expand All @@ -89,6 +97,12 @@ func TestValidateFlagGroups(t *testing.T) {
flagGroupsExclusive: []string{"a b c", "a d"},
args: []string{"--a=foo", "--c=foo", "--d=foo"},
expectErr: `if any flags in the group [a b c] are set none of the others can be; [a c] were all set`,
},
{
desc: "Multiple if present then others required flag group not satisfied returns first error",
flagGroupsIfPresentThenRequired: []string{"a b", "d e"},
args: []string{"--a=foo", "--f=foo"},
expectErr: `if the first flag in the group [a b] is set, all other flags must be set; the following flags are not set: [b]`,
}, {
desc: "Validation of required groups occurs on groups in sorted order",
flagGroupsRequired: []string{"a d", "a b", "a c"},
Expand Down Expand Up @@ -182,6 +196,12 @@ func TestValidateFlagGroups(t *testing.T) {
for _, flagGroup := range tc.subCmdFlagGroupsExclusive {
sub.MarkFlagsMutuallyExclusive(strings.Split(flagGroup, " ")...)
}
for _, flagGroup := range tc.flagGroupsIfPresentThenRequired {
c.MarkIfFlagPresentThenOthersRequired(strings.Split(flagGroup, " ")...)
}
for _, flagGroup := range tc.subCmdFlagGroupsIfPresentThenRequired {
sub.MarkIfFlagPresentThenOthersRequired(strings.Split(flagGroup, " ")...)
}
c.SetArgs(tc.args)
err := c.Execute()
switch {
Expand Down