237 lines
5.1 KiB
Go
237 lines
5.1 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
//go:build windows && cgo
|
|
|
|
package controlclient
|
|
|
|
import (
|
|
"crypto"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"errors"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/tailscale/certstore"
|
|
)
|
|
|
|
// Names describing the root certificate that matching test chains end in.
const (
	// testRootCommonName is the CommonName placed on the root certificate of
	// chains that are expected to match the requested subject.
	testRootCommonName = "testroot"
	// testRootSubject is the subject string the tests pass to
	// selectIdentityFromSlice. It is built from testRootCommonName so the two
	// values cannot drift apart.
	testRootSubject = "CN=" + testRootCommonName
)
|
|
|
|
// testIdentity is an in-memory stand-in for a certificate store identity.
// Its only state is a fixed certificate chain; the tests in this file place
// *testIdentity values in a []certstore.Identity.
type testIdentity struct {
	// chain is the certificate chain reported by this identity. Certificate
	// returns its first element and CertificateChain returns it whole.
	chain []*x509.Certificate
}
|
|
|
|
// makeChain returns a two-element certificate chain for use in tests.
//
// The first element is the leaf, whose validity window runs from notBefore to
// notAfter. The second element is the root, whose subject CommonName is
// rootCommonName. Both report RSA as their public key algorithm; every other
// field is left at its zero value.
func makeChain(rootCommonName string, notBefore, notAfter time.Time) []*x509.Certificate {
	leaf := &x509.Certificate{
		NotBefore:          notBefore,
		NotAfter:           notAfter,
		PublicKeyAlgorithm: x509.RSA,
	}
	root := &x509.Certificate{
		Subject: pkix.Name{
			CommonName: rootCommonName,
		},
		PublicKeyAlgorithm: x509.RSA,
	}
	return []*x509.Certificate{leaf, root}
}
|
|
|
|
func (t *testIdentity) Certificate() (*x509.Certificate, error) {
|
|
return t.chain[0], nil
|
|
}
|
|
|
|
func (t *testIdentity) CertificateChain() ([]*x509.Certificate, error) {
|
|
return t.chain, nil
|
|
}
|
|
|
|
func (t *testIdentity) Signer() (crypto.Signer, error) {
|
|
return nil, errors.New("not implemented")
|
|
}
|
|
|
|
func (t *testIdentity) Delete() error {
|
|
return errors.New("not implemented")
|
|
}
|
|
|
|
// Close does nothing; its body is empty.
func (t *testIdentity) Close() {}
|
|
|
|
func TestSelectIdentityFromSlice(t *testing.T) {
|
|
var times []time.Time
|
|
for _, ts := range []string{
|
|
"2000-01-01T00:00:00Z",
|
|
"2001-01-01T00:00:00Z",
|
|
"2002-01-01T00:00:00Z",
|
|
"2003-01-01T00:00:00Z",
|
|
} {
|
|
tm, err := time.Parse(time.RFC3339, ts)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
times = append(times, tm)
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
subject string
|
|
ids []certstore.Identity
|
|
now time.Time
|
|
// wantIndex is an index into ids, or -1 for nil.
|
|
wantIndex int
|
|
}{
|
|
{
|
|
name: "single unexpired identity",
|
|
subject: testRootSubject,
|
|
ids: []certstore.Identity{
|
|
&testIdentity{
|
|
chain: makeChain(testRootCommonName, times[0], times[2]),
|
|
},
|
|
},
|
|
now: times[1],
|
|
wantIndex: 0,
|
|
},
|
|
{
|
|
name: "single expired identity",
|
|
subject: testRootSubject,
|
|
ids: []certstore.Identity{
|
|
&testIdentity{
|
|
chain: makeChain(testRootCommonName, times[0], times[1]),
|
|
},
|
|
},
|
|
now: times[2],
|
|
wantIndex: -1,
|
|
},
|
|
{
|
|
name: "unrelated ids",
|
|
subject: testRootSubject,
|
|
ids: []certstore.Identity{
|
|
&testIdentity{
|
|
chain: makeChain("something", times[0], times[2]),
|
|
},
|
|
&testIdentity{
|
|
chain: makeChain(testRootCommonName, times[0], times[2]),
|
|
},
|
|
&testIdentity{
|
|
chain: makeChain("else", times[0], times[2]),
|
|
},
|
|
},
|
|
now: times[1],
|
|
wantIndex: 1,
|
|
},
|
|
{
|
|
name: "expired with unrelated ids",
|
|
subject: testRootSubject,
|
|
ids: []certstore.Identity{
|
|
&testIdentity{
|
|
chain: makeChain("something", times[0], times[3]),
|
|
},
|
|
&testIdentity{
|
|
chain: makeChain(testRootCommonName, times[0], times[1]),
|
|
},
|
|
&testIdentity{
|
|
chain: makeChain("else", times[0], times[3]),
|
|
},
|
|
},
|
|
now: times[2],
|
|
wantIndex: -1,
|
|
},
|
|
{
|
|
name: "one expired",
|
|
subject: testRootSubject,
|
|
ids: []certstore.Identity{
|
|
&testIdentity{
|
|
chain: makeChain(testRootCommonName, times[0], times[1]),
|
|
},
|
|
&testIdentity{
|
|
chain: makeChain(testRootCommonName, times[1], times[3]),
|
|
},
|
|
},
|
|
now: times[2],
|
|
wantIndex: 1,
|
|
},
|
|
{
|
|
name: "two certs both unexpired",
|
|
subject: testRootSubject,
|
|
ids: []certstore.Identity{
|
|
&testIdentity{
|
|
chain: makeChain(testRootCommonName, times[0], times[3]),
|
|
},
|
|
&testIdentity{
|
|
chain: makeChain(testRootCommonName, times[1], times[3]),
|
|
},
|
|
},
|
|
now: times[2],
|
|
wantIndex: 1,
|
|
},
|
|
{
|
|
name: "two unexpired one expired",
|
|
subject: testRootSubject,
|
|
ids: []certstore.Identity{
|
|
&testIdentity{
|
|
chain: makeChain(testRootCommonName, times[0], times[3]),
|
|
},
|
|
&testIdentity{
|
|
chain: makeChain(testRootCommonName, times[1], times[3]),
|
|
},
|
|
&testIdentity{
|
|
chain: makeChain(testRootCommonName, times[0], times[1]),
|
|
},
|
|
},
|
|
now: times[2],
|
|
wantIndex: 1,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
gotId, gotChain := selectIdentityFromSlice(tt.subject, tt.ids, tt.now)
|
|
|
|
if gotId == nil && gotChain != nil {
|
|
t.Error("id is nil: got non-nil chain, want nil chain")
|
|
return
|
|
}
|
|
if gotId != nil && gotChain == nil {
|
|
t.Error("id is not nil: got nil chain, want non-nil chain")
|
|
return
|
|
}
|
|
if tt.wantIndex == -1 {
|
|
if gotId != nil {
|
|
t.Error("got non-nil id, want nil id")
|
|
}
|
|
return
|
|
}
|
|
if gotId == nil {
|
|
t.Error("got nil id, want non-nil id")
|
|
return
|
|
}
|
|
if gotId != tt.ids[tt.wantIndex] {
|
|
found := -1
|
|
for i := range tt.ids {
|
|
if tt.ids[i] == gotId {
|
|
found = i
|
|
break
|
|
}
|
|
}
|
|
if found == -1 {
|
|
t.Errorf("got unknown id, want id at index %v", tt.wantIndex)
|
|
} else {
|
|
t.Errorf("got id at index %v, want id at index %v", found, tt.wantIndex)
|
|
}
|
|
}
|
|
|
|
tid, ok := tt.ids[tt.wantIndex].(*testIdentity)
|
|
if !ok {
|
|
t.Error("got non-testIdentity, want testIdentity")
|
|
return
|
|
}
|
|
|
|
if !reflect.DeepEqual(tid.chain, gotChain) {
|
|
t.Errorf("got unknown chain, want chain from id at index %v", tt.wantIndex)
|
|
}
|
|
})
|
|
}
|
|
}
|