Skip to content

Commit 69bc3bd

Browse files
committed
add support for Func() and BoolFunc() #426
Add support for two features which landed in the 'flag' package from the standard library: Func() and BoolFunc() and their two pflag specific versions: FuncP() and BoolFuncP() fixes #426
1 parent 196624c commit 69bc3bd

File tree

4 files changed

+377
-0
lines changed

4 files changed

+377
-0
lines changed

bool_func.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package pflag
2+
3+
// -- func Value
4+
type boolfuncValue func(string) error
5+
6+
func (f boolfuncValue) Set(s string) error { return f(s) }
7+
8+
func (f boolfuncValue) Type() string { return "func" }
9+
10+
func (f boolfuncValue) String() string { return "" } // same behavior as stdlib 'flag' package
11+
12+
func (f boolfuncValue) IsBoolFlag() bool { return true }
13+
14+
// BoolFunc defines a func flag with specified name, callback function and usage string.
15+
//
16+
// The callback function will be called every time "--{name}" (or any form that matches the flag) is parsed
17+
// on the command line.
18+
func (f *FlagSet) BoolFunc(name string, usage string, fn func(string) error) {
19+
f.BoolFuncP(name, "", usage, fn)
20+
}
21+
22+
// BoolFuncP is like BoolFunc, but accepts a shorthand letter that can be used after a single dash.
23+
func (f *FlagSet) BoolFuncP(name, shorthand string, usage string, fn func(string) error) {
24+
var val Value = boolfuncValue(fn)
25+
flag := f.VarPF(val, name, shorthand, usage)
26+
flag.NoOptDefVal = "true"
27+
}
28+
29+
// BoolFunc defines a func flag with specified name, callback function and usage string.
30+
//
31+
// The callback function will be called every time "--{name}" (or any form that matches the flag) is parsed
32+
// on the command line.
33+
func BoolFunc(name string, usage string, fn func(string) error) {
34+
CommandLine.BoolFuncP(name, "", usage, fn)
35+
}
36+
37+
// BoolFuncP is like BoolFunc, but accepts a shorthand letter that can be used after a single dash.
38+
func BoolFuncP(name, shorthand string, fn func(string) error, usage string) {
39+
CommandLine.BoolFuncP(name, shorthand, usage, fn)
40+
}

bool_func_test.go

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
package pflag
2+
3+
import (
4+
"errors"
5+
"flag"
6+
"io"
7+
"strings"
8+
"testing"
9+
)
10+
11+
func TestBoolFunc(t *testing.T) {
12+
var count int
13+
fn := func(_ string) error {
14+
count++
15+
return nil
16+
}
17+
18+
fset := NewFlagSet("test", ContinueOnError)
19+
fset.BoolFunc("func", "Callback function", fn)
20+
21+
err := fset.Parse([]string{"--func", "--func=1", "--func=false"})
22+
if err != nil {
23+
t.Fatal("expected no error; got", err)
24+
}
25+
26+
if count != 3 {
27+
t.Fatalf("expected 3 calls to the callback, got %d calls", count)
28+
}
29+
}
30+
31+
func TestBoolFuncP(t *testing.T) {
32+
var count int
33+
fn := func(_ string) error {
34+
count++
35+
return nil
36+
}
37+
38+
fset := NewFlagSet("test", ContinueOnError)
39+
fset.BoolFuncP("bfunc", "b", "Callback function", fn)
40+
41+
err := fset.Parse([]string{"--bfunc", "--bfunc=0", "--bfunc=false", "-b", "-b=0"})
42+
if err != nil {
43+
t.Fatal("expected no error; got", err)
44+
}
45+
46+
if count != 5 {
47+
t.Fatalf("expected 5 calls to the callback, got %d calls", count)
48+
}
49+
}
50+
51+
func TestBoolFuncCompat(t *testing.T) {
52+
// compare behavior with the stdlib 'flag' package
53+
type BoolFuncFlagSet interface {
54+
BoolFunc(name string, usage string, fn func(string) error)
55+
Parse([]string) error
56+
}
57+
58+
unitTestErr := errors.New("unit test error")
59+
runCase := func(f BoolFuncFlagSet, name string, args []string) (values []string, err error) {
60+
fn := func(s string) error {
61+
values = append(values, s)
62+
if s == "err" {
63+
return unitTestErr
64+
}
65+
return nil
66+
}
67+
f.BoolFunc(name, "Callback function", fn)
68+
69+
err = f.Parse(args)
70+
return values, err
71+
}
72+
73+
t.Run("regular parsing", func(t *testing.T) {
74+
flagName := "bflag"
75+
args := []string{"--bflag", "--bflag=false", "--bflag=1", "--bflag=bar", "--bflag="}
76+
77+
// It turns out that, even though the function is called "BoolFunc",
78+
// the stanard flag package does not try to parse the value assigned to
79+
// that cli flag as a boolean. The string provided on the command line is
80+
// passed as is to the callback.
81+
// e.g: with "--bflag=not_a_bool" on the command line, the FlagSet does not
82+
// generate an error stating "invalid boolean value", and `fn` will be called
83+
// with "not_a_bool" as an argument.
84+
85+
stdFSet := flag.NewFlagSet("std test", flag.ContinueOnError)
86+
stdValues, err := runCase(stdFSet, flagName, args)
87+
if err != nil {
88+
t.Fatalf("std flag: expected no error, got %v", err)
89+
}
90+
expected := []string{"true", "false", "1", "bar", ""}
91+
if !cmpLists(expected, stdValues) {
92+
t.Fatalf("std flag: expected %v, got %v", expected, stdValues)
93+
}
94+
95+
fset := NewFlagSet("pflag test", ContinueOnError)
96+
pflagValues, err := runCase(fset, flagName, args)
97+
if err != nil {
98+
t.Fatalf("pflag: expected no error, got %v", err)
99+
}
100+
if !cmpLists(stdValues, pflagValues) {
101+
t.Fatalf("pflag: expected %v, got %v", stdValues, pflagValues)
102+
}
103+
})
104+
105+
t.Run("error triggered by callback", func(t *testing.T) {
106+
flagName := "bflag"
107+
args := []string{"--bflag", "--bflag=err", "--bflag=after"}
108+
109+
// test behavior of standard flag.Fset with an error triggere by the callback:
110+
// (note: as can be seen in 'runCase()', if the callback sees "err" as a value
111+
// for the bool flag, it will return an error)
112+
stdFSet := flag.NewFlagSet("std test", flag.ContinueOnError)
113+
stdFSet.SetOutput(io.Discard) // suppress output
114+
115+
// run test case with standard flag.Fset
116+
stdValues, err := runCase(stdFSet, flagName, args)
117+
118+
// double check the standard behavior:
119+
// - .Parse() should return an error, which contains the error message
120+
if err == nil {
121+
t.Fatalf("std flag: expected an error triggered by callback, got no error instead")
122+
}
123+
if !strings.HasSuffix(err.Error(), unitTestErr.Error()) {
124+
t.Fatalf("std flag: expected unittest error, got unexpected error value: %T %v", err, err)
125+
}
126+
// - the function should have been called twice, with the first two values,
127+
// the final "=after" should not be recorded
128+
expected := []string{"true", "err"}
129+
if !cmpLists(expected, stdValues) {
130+
t.Fatalf("std flag: expected %v, got %v", expected, stdValues)
131+
}
132+
133+
// now run the test case on a pflag FlagSet:
134+
fset := NewFlagSet("pflag test", ContinueOnError)
135+
pflagValues, err := runCase(fset, flagName, args)
136+
137+
// check that there is a similar error (note: pflag will _wrap_ the error, while the stdlib
138+
// currently keeps the original message but creates a flat errors.Error)
139+
if !errors.Is(err, unitTestErr) {
140+
t.Fatalf("pflag: got unexpected error value: %T %v", err, err)
141+
}
142+
// the callback should be called the same number of times, with the same values:
143+
if !cmpLists(stdValues, pflagValues) {
144+
t.Fatalf("pflag: expected %v, got %v", stdValues, pflagValues)
145+
}
146+
})
147+
}

func.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package pflag
2+
3+
// -- func Value
4+
type funcValue func(string) error
5+
6+
func (f funcValue) Set(s string) error { return f(s) }
7+
8+
func (f funcValue) Type() string { return "func" }
9+
10+
func (f funcValue) String() string { return "" } // same behavior as stdlib 'flag' package
11+
12+
// Func defines a func flag with specified name, callback function and usage string.
13+
//
14+
// The callback function will be called every time "--{name}={value}" (or equivalent) is
15+
// parsed on the command line, with "{value}" as an argument.
16+
func (f *FlagSet) Func(name string, usage string, fn func(string) error) {
17+
f.FuncP(name, "", usage, fn)
18+
}
19+
20+
// FuncP is like Func, but accepts a shorthand letter that can be used after a single dash.
21+
func (f *FlagSet) FuncP(name string, shorthand string, usage string, fn func(string) error) {
22+
var val Value = funcValue(fn)
23+
f.VarP(val, name, shorthand, usage)
24+
}
25+
26+
// Func defines a func flag with specified name, callback function and usage string.
27+
//
28+
// The callback function will be called every time "--{name}={value}" (or equivalent) is
29+
// parsed on the command line, with "{value}" as an argument.
30+
func Func(name string, fn func(string) error, usage string) {
31+
CommandLine.FuncP(name, "", usage, fn)
32+
}
33+
34+
// FuncP is like Func, but accepts a shorthand letter that can be used after a single dash.
35+
func FuncP(name, shorthand string, fn func(string) error, usage string) {
36+
CommandLine.FuncP(name, shorthand, usage, fn)
37+
}

func_test.go

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
package pflag
2+
3+
import (
4+
"errors"
5+
"flag"
6+
"io"
7+
"strings"
8+
"testing"
9+
)
10+
11+
func cmpLists(a, b []string) bool {
12+
if len(a) != len(b) {
13+
return false
14+
}
15+
for i := range a {
16+
if a[i] != b[i] {
17+
return false
18+
}
19+
}
20+
return true
21+
}
22+
23+
func TestFunc(t *testing.T) {
24+
var values []string
25+
fn := func(s string) error {
26+
values = append(values, s)
27+
return nil
28+
}
29+
30+
fset := NewFlagSet("test", ContinueOnError)
31+
fset.Func("fnflag", "Callback function", fn)
32+
33+
err := fset.Parse([]string{"--fnflag=aa", "--fnflag", "bb"})
34+
if err != nil {
35+
t.Fatal("expected no error; got", err)
36+
}
37+
38+
expected := []string{"aa", "bb"}
39+
if !cmpLists(expected, values) {
40+
t.Fatalf("expected %v, got %v", expected, values)
41+
}
42+
}
43+
44+
func TestFuncP(t *testing.T) {
45+
var values []string
46+
fn := func(s string) error {
47+
values = append(values, s)
48+
return nil
49+
}
50+
51+
fset := NewFlagSet("test", ContinueOnError)
52+
fset.FuncP("fnflag", "f", "Callback function", fn)
53+
54+
err := fset.Parse([]string{"--fnflag=a", "--fnflag", "b", "-fc", "-f=d", "-f", "e"})
55+
if err != nil {
56+
t.Fatal("expected no error; got", err)
57+
}
58+
59+
expected := []string{"a", "b", "c", "d", "e"}
60+
if !cmpLists(expected, values) {
61+
t.Fatalf("expected %v, got %v", expected, values)
62+
}
63+
}
64+
65+
func TestFuncCompat(t *testing.T) {
66+
// compare behavior with the stdlib 'flag' package
67+
type FuncFlagSet interface {
68+
Func(name string, usage string, fn func(string) error)
69+
Parse([]string) error
70+
}
71+
72+
unitTestErr := errors.New("unit test error")
73+
runCase := func(f FuncFlagSet, name string, args []string) (values []string, err error) {
74+
fn := func(s string) error {
75+
values = append(values, s)
76+
if s == "err" {
77+
return unitTestErr
78+
}
79+
return nil
80+
}
81+
f.Func(name, "Callback function", fn)
82+
83+
err = f.Parse(args)
84+
return values, err
85+
}
86+
87+
t.Run("regular parsing", func(t *testing.T) {
88+
flagName := "fnflag"
89+
args := []string{"--fnflag=xx", "--fnflag", "yy", "--fnflag=zz"}
90+
91+
stdFSet := flag.NewFlagSet("std test", flag.ContinueOnError)
92+
stdValues, err := runCase(stdFSet, flagName, args)
93+
if err != nil {
94+
t.Fatalf("std flag: expected no error, got %v", err)
95+
}
96+
expected := []string{"xx", "yy", "zz"}
97+
if !cmpLists(expected, stdValues) {
98+
t.Fatalf("std flag: expected %v, got %v", expected, stdValues)
99+
}
100+
101+
fset := NewFlagSet("pflag test", ContinueOnError)
102+
pflagValues, err := runCase(fset, flagName, args)
103+
if err != nil {
104+
t.Fatalf("pflag: expected no error, got %v", err)
105+
}
106+
if !cmpLists(stdValues, pflagValues) {
107+
t.Fatalf("pflag: expected %v, got %v", stdValues, pflagValues)
108+
}
109+
})
110+
111+
t.Run("error triggered by callback", func(t *testing.T) {
112+
flagName := "fnflag"
113+
args := []string{"--fnflag", "before", "--fnflag", "err", "--fnflag", "after"}
114+
115+
// test behavior of standard flag.Fset with an error triggere by the callback:
116+
// (note: as can be seen in 'runCase()', if the callback sees "err" as a value
117+
// for the bool flag, it will return an error)
118+
stdFSet := flag.NewFlagSet("std test", flag.ContinueOnError)
119+
stdFSet.SetOutput(io.Discard) // suppress output
120+
121+
// run test case with standard flag.Fset
122+
stdValues, err := runCase(stdFSet, flagName, args)
123+
124+
// double check the standard behavior:
125+
// - .Parse() should return an error, which contains the error message
126+
if err == nil {
127+
t.Fatalf("std flag: expected an error triggered by callback, got no error instead")
128+
}
129+
if !strings.HasSuffix(err.Error(), unitTestErr.Error()) {
130+
t.Fatalf("std flag: expected unittest error, got unexpected error value: %T %v", err, err)
131+
}
132+
// - the function should have been called twice, with the first two values,
133+
// the final "=after" should not be recorded
134+
expected := []string{"before", "err"}
135+
if !cmpLists(expected, stdValues) {
136+
t.Fatalf("std flag: expected %v, got %v", expected, stdValues)
137+
}
138+
139+
// now run the test case on a pflag FlagSet:
140+
fset := NewFlagSet("pflag test", ContinueOnError)
141+
pflagValues, err := runCase(fset, flagName, args)
142+
143+
// check that there is a similar error (note: pflag will _wrap_ the error, while the stdlib
144+
// currently keeps the original message but creates a flat errors.Error)
145+
if !errors.Is(err, unitTestErr) {
146+
t.Fatalf("pflag: got unexpected error value: %T %v", err, err)
147+
}
148+
// the callback should be called the same number of times, with the same values:
149+
if !cmpLists(stdValues, pflagValues) {
150+
t.Fatalf("pflag: expected %v, got %v", stdValues, pflagValues)
151+
}
152+
})
153+
}

0 commit comments

Comments
 (0)