types/opt: add BoolFlag for setting Bool value as a flag
Updates tailscale/corp#22578 Signed-off-by: Will Norris <will@tailscale.com>
This commit is contained in:
parent
8af50fa97c
commit
cccacff564
|
@ -105,3 +105,29 @@ func (b *Bool) UnmarshalJSON(j []byte) error {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BoolFlag is a wrapper for Bool that implements [flag.Value].
|
||||||
|
type BoolFlag struct {
|
||||||
|
*Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the value of b, using any value supported by [strconv.ParseBool].
|
||||||
|
func (b *BoolFlag) Set(s string) error {
|
||||||
|
v, err := strconv.ParseBool(s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
b.Bool.Set(v)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns "true" or "false" if the value is set, or an empty string otherwise.
|
||||||
|
func (b *BoolFlag) String() string {
|
||||||
|
if b == nil || b.Bool == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if v, ok := b.Bool.Get(); ok {
|
||||||
|
return strconv.FormatBool(v)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
|
@ -5,7 +5,9 @@ package opt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"flag"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -127,3 +129,38 @@ func TestUnmarshalAlloc(t *testing.T) {
|
||||||
t.Errorf("got %v allocs, want 0", n)
|
t.Errorf("got %v allocs, want 0", n)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBoolFlag(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
arguments string
|
||||||
|
wantParseError bool // expect flag.Parse to error
|
||||||
|
want Bool
|
||||||
|
}{
|
||||||
|
{"", false, Bool("")},
|
||||||
|
{"-test", true, Bool("")},
|
||||||
|
{`-test=""`, true, Bool("")},
|
||||||
|
{"-test invalid", true, Bool("")},
|
||||||
|
|
||||||
|
{"-test true", false, NewBool(true)},
|
||||||
|
{"-test 1", false, NewBool(true)},
|
||||||
|
|
||||||
|
{"-test false", false, NewBool(false)},
|
||||||
|
{"-test 0", false, NewBool(false)},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
var got Bool
|
||||||
|
fs := flag.NewFlagSet(t.Name(), flag.ContinueOnError)
|
||||||
|
fs.Var(&BoolFlag{&got}, "test", "test flag")
|
||||||
|
|
||||||
|
arguments := strings.Split(tt.arguments, " ")
|
||||||
|
err := fs.Parse(arguments)
|
||||||
|
if (err != nil) != tt.wantParseError {
|
||||||
|
t.Errorf("flag.Parse(%q) returned error %v, want %v", arguments, err, tt.wantParseError)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("flag.Parse(%q) got %q, want %q", arguments, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue