types/opt: add BoolFlag for setting Bool value as a flag
Updates tailscale/corp#22578 Signed-off-by: Will Norris <will@tailscale.com>
This commit is contained in:
parent
8af50fa97c
commit
cccacff564
|
@ -105,3 +105,29 @@ func (b *Bool) UnmarshalJSON(j []byte) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BoolFlag is a wrapper for Bool that implements [flag.Value].
|
||||
type BoolFlag struct {
|
||||
*Bool
|
||||
}
|
||||
|
||||
// Set the value of b, using any value supported by [strconv.ParseBool].
|
||||
func (b *BoolFlag) Set(s string) error {
|
||||
v, err := strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.Bool.Set(v)
|
||||
return nil
|
||||
}
|
||||
|
||||
// String returns "true" or "false" if the value is set, or an empty string otherwise.
|
||||
func (b *BoolFlag) String() string {
|
||||
if b == nil || b.Bool == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := b.Bool.Get(); ok {
|
||||
return strconv.FormatBool(v)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
|
|
@ -5,7 +5,9 @@ package opt
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
@ -127,3 +129,38 @@ func TestUnmarshalAlloc(t *testing.T) {
|
|||
t.Errorf("got %v allocs, want 0", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBoolFlag(t *testing.T) {
|
||||
tests := []struct {
|
||||
arguments string
|
||||
wantParseError bool // expect flag.Parse to error
|
||||
want Bool
|
||||
}{
|
||||
{"", false, Bool("")},
|
||||
{"-test", true, Bool("")},
|
||||
{`-test=""`, true, Bool("")},
|
||||
{"-test invalid", true, Bool("")},
|
||||
|
||||
{"-test true", false, NewBool(true)},
|
||||
{"-test 1", false, NewBool(true)},
|
||||
|
||||
{"-test false", false, NewBool(false)},
|
||||
{"-test 0", false, NewBool(false)},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
var got Bool
|
||||
fs := flag.NewFlagSet(t.Name(), flag.ContinueOnError)
|
||||
fs.Var(&BoolFlag{&got}, "test", "test flag")
|
||||
|
||||
arguments := strings.Split(tt.arguments, " ")
|
||||
err := fs.Parse(arguments)
|
||||
if (err != nil) != tt.wantParseError {
|
||||
t.Errorf("flag.Parse(%q) returned error %v, want %v", arguments, err, tt.wantParseError)
|
||||
}
|
||||
|
||||
if got != tt.want {
|
||||
t.Errorf("flag.Parse(%q) got %q, want %q", arguments, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue