diff --git a/types/wgkey/key.go b/types/wgkey/key.go index d96363952..668ceee85 100644 --- a/types/wgkey/key.go +++ b/types/wgkey/key.go @@ -90,14 +90,12 @@ func (k *Key) IsZero() bool { return subtle.ConstantTimeCompare(zeros[:], k[:]) == 1 } -func (k *Key) MarshalJSON() ([]byte, error) { - if k == nil { - return []byte("null"), nil - } - // TODO(josharian): use encoding/hex instead? - buf := new(bytes.Buffer) - fmt.Fprintf(buf, `"%x"`, k[:]) - return buf.Bytes(), nil +func (k Key) MarshalJSON() ([]byte, error) { + buf := make([]byte, 2+len(k)*2) + buf[0] = '"' + hex.Encode(buf[1:], k[:]) + buf[len(buf)-1] = '"' + return buf, nil } func (k *Key) UnmarshalJSON(b []byte) error { diff --git a/types/wgkey/key_test.go b/types/wgkey/key_test.go index 9b8632a3b..ba3e4166f 100644 --- a/types/wgkey/key_test.go +++ b/types/wgkey/key_test.go @@ -6,6 +6,7 @@ package wgkey import ( "bytes" + "encoding/json" "testing" ) @@ -20,7 +21,7 @@ func TestKeyBasics(t *testing.T) { t.Fatal(err) } - t.Run("JSON round-trip", func(t *testing.T) { + t.Run("JSON round-trip (pointer)", func(t *testing.T) { // should preserve the keys k2 := new(Key) if err := k2.UnmarshalJSON(b); err != nil { @@ -55,6 +56,27 @@ func TestKeyBasics(t *testing.T) { t.Fatalf("base64-encoded keys match: %s, %s", b1, b2) } }) + + t.Run("JSON round-trip (value)", func(t *testing.T) { + type T struct { + K Key + } + v := T{K: *k1} + b, err := json.Marshal(v) + if err != nil { + t.Fatal(err) + } + var u T + if err := json.Unmarshal(b, &u); err != nil { + t.Fatal(err) + } + if !bytes.Equal(v.K[:], u.K[:]) { + t.Fatalf("v.K %v != u.K %v", v.K[:], u.K[:]) + } + if b1, b2 := v.K.String(), u.K.String(); b1 != b2 { + t.Fatalf("base64-encoded keys do not match: %s, %s", b1, b2) + } + }) } func TestPrivateKeyBasics(t *testing.T) { pri, err := NewPrivate() @@ -109,3 +131,28 @@ func TestPrivateKeyBasics(t *testing.T) { } }) } + +func TestMarshalJSONAllocs(t *testing.T) { + var k Key + f := testing.AllocsPerRun(100, func() { + k.MarshalJSON() + }) + n := int(f) + if n != 1 { + t.Fatalf("max one alloc per Key.MarshalJSON, got %d", n) + } +} + +var sink []byte + +func BenchmarkMarshalJSON(b *testing.B) { + b.ReportAllocs() + var k Key + for i := 0; i < b.N; i++ { + var err error + sink, err = k.MarshalJSON() + if err != nil { + b.Fatal(err) + } + } +}