control/controlclient: select newer certificate
If multiple certificates match when selecting a certificate, use the one issued the most recently (as determined by the NotBefore timestamp). This also adds some tests for the function that performs that comparison. Updates tailscale/coral#6 Signed-off-by: Adrian Dewhurst <adrian@tailscale.com>
This commit is contained in:
parent
a80cef0c13
commit
adda2d2a51
|
@ -18,6 +18,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/certstore"
|
||||
"tailscale.com/tailcfg"
|
||||
|
@ -73,23 +74,46 @@ func isSubjectInChain(subject string, chain []*x509.Certificate) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func selectIdentityFromSlice(subject string, ids []certstore.Identity) (certstore.Identity, []*x509.Certificate) {
|
||||
func selectIdentityFromSlice(subject string, ids []certstore.Identity, now time.Time) (certstore.Identity, []*x509.Certificate) {
|
||||
var bestCandidate struct {
|
||||
id certstore.Identity
|
||||
chain []*x509.Certificate
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
chain, err := id.CertificateChain()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(chain) < 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
if !isSupportedCertificate(chain[0]) {
|
||||
continue
|
||||
}
|
||||
|
||||
if isSubjectInChain(subject, chain) {
|
||||
return id, chain
|
||||
if now.Before(chain[0].NotBefore) || now.After(chain[0].NotAfter) {
|
||||
// Certificate is not valid at this time
|
||||
continue
|
||||
}
|
||||
|
||||
if !isSubjectInChain(subject, chain) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Select the most recently issued certificate. If there is a tie, pick
|
||||
// one arbitrarily.
|
||||
if len(bestCandidate.chain) > 0 && bestCandidate.chain[0].NotBefore.After(chain[0].NotBefore) {
|
||||
continue
|
||||
}
|
||||
|
||||
bestCandidate.id = id
|
||||
bestCandidate.chain = chain
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
return bestCandidate.id, bestCandidate.chain
|
||||
}
|
||||
|
||||
// findIdentity locates an identity from the Windows or Darwin certificate
|
||||
|
@ -105,7 +129,7 @@ func findIdentity(subject string, st certstore.Store) (certstore.Identity, []*x5
|
|||
return nil, nil, err
|
||||
}
|
||||
|
||||
selected, chain := selectIdentityFromSlice(subject, ids)
|
||||
selected, chain := selectIdentityFromSlice(subject, ids, time.Now())
|
||||
|
||||
for _, id := range ids {
|
||||
if id != selected {
|
||||
|
|
|
@ -0,0 +1,238 @@
|
|||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build windows && cgo
|
||||
// +build windows,cgo
|
||||
|
||||
package controlclient
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/certstore"
|
||||
)
|
||||
|
||||
const (
|
||||
testRootCommonName = "testroot"
|
||||
testRootSubject = "CN=testroot"
|
||||
)
|
||||
|
||||
type testIdentity struct {
|
||||
chain []*x509.Certificate
|
||||
}
|
||||
|
||||
func makeChain(rootCommonName string, notBefore, notAfter time.Time) []*x509.Certificate {
|
||||
return []*x509.Certificate{
|
||||
{
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
PublicKeyAlgorithm: x509.RSA,
|
||||
},
|
||||
{
|
||||
Subject: pkix.Name{
|
||||
CommonName: rootCommonName,
|
||||
},
|
||||
PublicKeyAlgorithm: x509.RSA,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *testIdentity) Certificate() (*x509.Certificate, error) {
|
||||
return t.chain[0], nil
|
||||
}
|
||||
|
||||
func (t *testIdentity) CertificateChain() ([]*x509.Certificate, error) {
|
||||
return t.chain, nil
|
||||
}
|
||||
|
||||
func (t *testIdentity) Signer() (crypto.Signer, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (t *testIdentity) Delete() error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (t *testIdentity) Close() {}
|
||||
|
||||
func TestSelectIdentityFromSlice(t *testing.T) {
|
||||
var times []time.Time
|
||||
for _, ts := range []string{
|
||||
"2000-01-01T00:00:00Z",
|
||||
"2001-01-01T00:00:00Z",
|
||||
"2002-01-01T00:00:00Z",
|
||||
"2003-01-01T00:00:00Z",
|
||||
} {
|
||||
tm, err := time.Parse(time.RFC3339, ts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
times = append(times, tm)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
subject string
|
||||
ids []certstore.Identity
|
||||
now time.Time
|
||||
// wantIndex is an index into ids, or -1 for nil.
|
||||
wantIndex int
|
||||
}{
|
||||
{
|
||||
name: "single unexpired identity",
|
||||
subject: testRootSubject,
|
||||
ids: []certstore.Identity{
|
||||
&testIdentity{
|
||||
chain: makeChain(testRootCommonName, times[0], times[2]),
|
||||
},
|
||||
},
|
||||
now: times[1],
|
||||
wantIndex: 0,
|
||||
},
|
||||
{
|
||||
name: "single expired identity",
|
||||
subject: testRootSubject,
|
||||
ids: []certstore.Identity{
|
||||
&testIdentity{
|
||||
chain: makeChain(testRootCommonName, times[0], times[1]),
|
||||
},
|
||||
},
|
||||
now: times[2],
|
||||
wantIndex: -1,
|
||||
},
|
||||
{
|
||||
name: "unrelated ids",
|
||||
subject: testRootSubject,
|
||||
ids: []certstore.Identity{
|
||||
&testIdentity{
|
||||
chain: makeChain("something", times[0], times[2]),
|
||||
},
|
||||
&testIdentity{
|
||||
chain: makeChain(testRootCommonName, times[0], times[2]),
|
||||
},
|
||||
&testIdentity{
|
||||
chain: makeChain("else", times[0], times[2]),
|
||||
},
|
||||
},
|
||||
now: times[1],
|
||||
wantIndex: 1,
|
||||
},
|
||||
{
|
||||
name: "expired with unrelated ids",
|
||||
subject: testRootSubject,
|
||||
ids: []certstore.Identity{
|
||||
&testIdentity{
|
||||
chain: makeChain("something", times[0], times[3]),
|
||||
},
|
||||
&testIdentity{
|
||||
chain: makeChain(testRootCommonName, times[0], times[1]),
|
||||
},
|
||||
&testIdentity{
|
||||
chain: makeChain("else", times[0], times[3]),
|
||||
},
|
||||
},
|
||||
now: times[2],
|
||||
wantIndex: -1,
|
||||
},
|
||||
{
|
||||
name: "one expired",
|
||||
subject: testRootSubject,
|
||||
ids: []certstore.Identity{
|
||||
&testIdentity{
|
||||
chain: makeChain(testRootCommonName, times[0], times[1]),
|
||||
},
|
||||
&testIdentity{
|
||||
chain: makeChain(testRootCommonName, times[1], times[3]),
|
||||
},
|
||||
},
|
||||
now: times[2],
|
||||
wantIndex: 1,
|
||||
},
|
||||
{
|
||||
name: "two certs both unexpired",
|
||||
subject: testRootSubject,
|
||||
ids: []certstore.Identity{
|
||||
&testIdentity{
|
||||
chain: makeChain(testRootCommonName, times[0], times[3]),
|
||||
},
|
||||
&testIdentity{
|
||||
chain: makeChain(testRootCommonName, times[1], times[3]),
|
||||
},
|
||||
},
|
||||
now: times[2],
|
||||
wantIndex: 1,
|
||||
},
|
||||
{
|
||||
name: "two unexpired one expired",
|
||||
subject: testRootSubject,
|
||||
ids: []certstore.Identity{
|
||||
&testIdentity{
|
||||
chain: makeChain(testRootCommonName, times[0], times[3]),
|
||||
},
|
||||
&testIdentity{
|
||||
chain: makeChain(testRootCommonName, times[1], times[3]),
|
||||
},
|
||||
&testIdentity{
|
||||
chain: makeChain(testRootCommonName, times[0], times[1]),
|
||||
},
|
||||
},
|
||||
now: times[2],
|
||||
wantIndex: 1,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotId, gotChain := selectIdentityFromSlice(tt.subject, tt.ids, tt.now)
|
||||
|
||||
if gotId == nil && gotChain != nil {
|
||||
t.Error("id is nil: got non-nil chain, want nil chain")
|
||||
return
|
||||
}
|
||||
if gotId != nil && gotChain == nil {
|
||||
t.Error("id is not nil: got nil chain, want non-nil chain")
|
||||
return
|
||||
}
|
||||
if tt.wantIndex == -1 {
|
||||
if gotId != nil {
|
||||
t.Error("got non-nil id, want nil id")
|
||||
}
|
||||
return
|
||||
}
|
||||
if gotId == nil {
|
||||
t.Error("got nil id, want non-nil id")
|
||||
return
|
||||
}
|
||||
if gotId != tt.ids[tt.wantIndex] {
|
||||
found := -1
|
||||
for i := range tt.ids {
|
||||
if tt.ids[i] == gotId {
|
||||
found = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if found == -1 {
|
||||
t.Errorf("got unknown id, want id at index %v", tt.wantIndex)
|
||||
} else {
|
||||
t.Errorf("got id at index %v, want id at index %v", found, tt.wantIndex)
|
||||
}
|
||||
}
|
||||
|
||||
tid, ok := tt.ids[tt.wantIndex].(*testIdentity)
|
||||
if !ok {
|
||||
t.Error("got non-testIdentity, want testIdentity")
|
||||
return
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tid.chain, gotChain) {
|
||||
t.Errorf("got unknown chain, want chain from id at index %v", tt.wantIndex)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue