Skip to content

Commit 19ec3ae

Browse files
brianpursleyhoshsadiq
authored andcommitted
fix: don't remove flag value that matches subcommand name
When the command searches args to find the arg matching a particular subcommand name, it needs to ignore flag values, as it is possible that the value for a flag might match the name of the sub command. This change improves argsMinusFirstX() to ignore flag values when it searches for the X to exclude from the result. Fixes spf13/cobra#1777 Merge spf13/cobra#1781
1 parent 803ace3 commit 19ec3ae

File tree

2 files changed

+121
-8
lines changed

2 files changed

+121
-8
lines changed

command.go

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -666,13 +666,37 @@ Loop:
666666

667667
// argsMinusFirstX removes only the first x from args. Otherwise, commands that look like
668668
// openshift admin policy add-role-to-user admin my-user, lose the admin argument (arg[4]).
669-
func argsMinusFirstX(args []string, x string) []string {
670-
for i, y := range args {
671-
if x == y {
672-
var ret []string
673-
ret = append(ret, args[:i]...)
674-
ret = append(ret, args[i+1:]...)
675-
return ret
669+
// Special care needs to be taken not to remove a flag value.
670+
func (c *Command) argsMinusFirstX(args []string, x string) []string {
671+
if len(args) == 0 {
672+
return args
673+
}
674+
c.mergePersistentFlags()
675+
flags := c.Flags()
676+
677+
Loop:
678+
for pos := 0; pos < len(args); pos++ {
679+
s := args[pos]
680+
switch {
681+
case s == "--":
682+
// -- means we have reached the end of the parseable args. Break out of the loop now.
683+
break Loop
684+
case strings.HasPrefix(s, "--") && !strings.Contains(s, "=") && !isBoolFlag(s[2:], flags):
685+
fallthrough
686+
case strings.HasPrefix(s, "-") && !strings.Contains(s, "=") && len(s) == 2 && !isShortBoolFlag(s[1:], flags):
687+
// This is a flag without a default value, and an equal sign is not used. Increment pos in order to skip
688+
// over the next arg, because that is the value of this flag.
689+
pos++
690+
continue
691+
case !strings.HasPrefix(s, "-"):
692+
// This is not a flag or a flag value. Check to see if it matches what we're looking for, and if so,
693+
// return the args, excluding the one at this position.
694+
if s == x {
695+
ret := []string{}
696+
ret = append(ret, args[:pos]...)
697+
ret = append(ret, args[pos+1:]...)
698+
return ret
699+
}
676700
}
677701
}
678702
return args
@@ -697,7 +721,7 @@ func (c *Command) Find(args []string) (*Command, []string, error) {
697721

698722
cmd := c.findNext(nextSubCmd)
699723
if cmd != nil {
700-
return innerfind(cmd, argsMinusFirstX(innerArgs, nextSubCmd))
724+
return innerfind(cmd, c.argsMinusFirstX(innerArgs, nextSubCmd))
701725
}
702726
return c, innerArgs
703727
}

command_test.go

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2551,3 +2551,92 @@ Use "root child [command] --help" for more information about a command.
25512551
})
25522552
}
25532553
}
2554+
2555+
func TestFind(t *testing.T) {
2556+
var foo, bar string
2557+
root := &zulu.Command{
2558+
Use: "root",
2559+
}
2560+
root.PersistentFlags().StringVar(&foo, "foo", "", "", zflag.OptShorthand('f'))
2561+
root.PersistentFlags().StringVar(&bar, "bar", "something", "", zflag.OptShorthand('b'))
2562+
2563+
child := &zulu.Command{
2564+
Use: "child",
2565+
}
2566+
root.AddCommand(child)
2567+
2568+
testCases := []struct {
2569+
args []string
2570+
expectedFoundArgs []string
2571+
}{
2572+
{
2573+
[]string{"child"},
2574+
[]string{},
2575+
},
2576+
{
2577+
[]string{"child", "child"},
2578+
[]string{"child"},
2579+
},
2580+
{
2581+
[]string{"child", "foo", "child", "bar", "child", "baz", "child"},
2582+
[]string{"foo", "child", "bar", "child", "baz", "child"},
2583+
},
2584+
{
2585+
[]string{"-f", "child", "child"},
2586+
[]string{"-f", "child"},
2587+
},
2588+
{
2589+
[]string{"child", "-f", "child"},
2590+
[]string{"-f", "child"},
2591+
},
2592+
{
2593+
[]string{"-b", "child", "child"},
2594+
[]string{"-b", "child"},
2595+
},
2596+
{
2597+
[]string{"child", "-b", "child"},
2598+
[]string{"-b", "child"},
2599+
},
2600+
{
2601+
[]string{"child", "-b"},
2602+
[]string{"-b"},
2603+
},
2604+
{
2605+
[]string{"-b", "-f", "child", "child"},
2606+
[]string{"-b", "-f", "child"},
2607+
},
2608+
{
2609+
[]string{"-f", "child", "-b", "something", "child"},
2610+
[]string{"-f", "child", "-b", "something"},
2611+
},
2612+
{
2613+
[]string{"-f", "child", "child", "-b"},
2614+
[]string{"-f", "child", "-b"},
2615+
},
2616+
{
2617+
[]string{"-f=child", "-b=something", "child"},
2618+
[]string{"-f=child", "-b=something"},
2619+
},
2620+
{
2621+
[]string{"--foo", "child", "--bar", "something", "child"},
2622+
[]string{"--foo", "child", "--bar", "something"},
2623+
},
2624+
}
2625+
2626+
for _, tc := range testCases {
2627+
t.Run(fmt.Sprintf("%v", tc.args), func(t *testing.T) {
2628+
cmd, foundArgs, err := root.Find(tc.args)
2629+
if err != nil {
2630+
t.Fatal(err)
2631+
}
2632+
2633+
if cmd != child {
2634+
t.Fatal("Expected cmd to be child, but it was not")
2635+
}
2636+
2637+
if !reflect.DeepEqual(tc.expectedFoundArgs, foundArgs) {
2638+
t.Fatalf("Wrong args\nExpected: %v\nGot: %v", tc.expectedFoundArgs, foundArgs)
2639+
}
2640+
})
2641+
}
2642+
}

0 commit comments

Comments
 (0)