Skip to content

Elimination of duplication through common logic #2293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 28 additions & 27 deletions flag_groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,31 +83,40 @@ func (c *Command) ValidateFlagGroups() error {
return nil
}

flags := c.Flags()

// groupStatus format is the list of flags as a unique ID,
// then a map of each flag name and whether it is set or not.
groupStatus := map[string]map[string]bool{}
oneRequiredGroupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
flags.VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
})
required, oneRequired, mutuallyExclusive := c.getFlagGroupStatuses()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually when I see a function returning 3 values, I tend to recommend returning a struct with 3 fields. It often helps because more values can be returned easily, also you don't always need all values depending on where you call the function


if err := validateRequiredFlagGroups(groupStatus); err != nil {
if err := validateRequiredFlagGroups(required); err != nil {
return err
}
if err := validateOneRequiredFlagGroups(oneRequiredGroupStatus); err != nil {
if err := validateOneRequiredFlagGroups(oneRequired); err != nil {
return err
}
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
if err := validateExclusiveFlagGroups(mutuallyExclusive); err != nil {
return err
}
return nil
}

// getFlagGroupStatuses collects the status of all flags belonging to any flag group.
func (c *Command) getFlagGroupStatuses() (
required map[string]map[string]bool,
oneRequired map[string]map[string]bool,
mutuallyExclusive map[string]map[string]bool,
) {
flags := c.Flags()
required = map[string]map[string]bool{}
oneRequired = map[string]map[string]bool{}
mutuallyExclusive = map[string]map[string]bool{}

flags.VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, required)
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequired)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusive)
})

return required, oneRequired, mutuallyExclusive
}

func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool {
for _, fname := range flagnames {
f := fs.Lookup(fname)
Expand Down Expand Up @@ -227,19 +236,11 @@ func (c *Command) enforceFlagGroupsForCompletion() {
return
}

flags := c.Flags()
groupStatus := map[string]map[string]bool{}
oneRequiredGroupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
c.Flags().VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
})
required, oneRequired, mutuallyExclusive := c.getFlagGroupStatuses()

// If a flag that is part of a group is present, we make all the other flags
// of that group required so that the shell completion suggests them automatically
for flagList, flagnameAndStatus := range groupStatus {
for flagList, flagnameAndStatus := range required {
for _, isSet := range flagnameAndStatus {
if isSet {
// One of the flags of the group is set, mark the other ones as required
Expand All @@ -252,7 +253,7 @@ func (c *Command) enforceFlagGroupsForCompletion() {

// If none of the flags of a one-required group are present, we make all the flags
// of that group required so that the shell completion suggests them automatically
for flagList, flagnameAndStatus := range oneRequiredGroupStatus {
for flagList, flagnameAndStatus := range oneRequired {
isSet := false

for _, isSet = range flagnameAndStatus {
Expand All @@ -272,7 +273,7 @@ func (c *Command) enforceFlagGroupsForCompletion() {

// If a flag that is mutually exclusive to others is present, we hide the other
// flags of that group so the shell completion does not suggest them
for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus {
for flagList, flagnameAndStatus := range mutuallyExclusive {
for flagName, isSet := range flagnameAndStatus {
if isSet {
// One of the flags of the mutually exclusive group is set, mark the other ones as hidden
Expand Down