util/winutil: add AllocateContiguousBuffer and SetNTString helper funcs
AllocateContiguousBuffer is for allocating structs with trailing buffers containing additional data. It is to be used for various Windows structures containing pointers to data located immediately after the struct. SetNTString performs in-place setting of windows.NTString and windows.NTUnicodeString. Updates #12383 Signed-off-by: Aaron Klotz <aaron@tailscale.com>
This commit is contained in:
parent
c3e2b7347b
commit
df86576989
|
@ -175,6 +175,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
|
|||
golang.org/x/crypto/nacl/box from tailscale.com/types/key
|
||||
golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box
|
||||
golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+
|
||||
W golang.org/x/exp/constraints from tailscale.com/util/winutil
|
||||
L golang.org/x/net/bpf from github.com/mdlayher/netlink+
|
||||
golang.org/x/net/dns/dnsmessage from net+
|
||||
golang.org/x/net/http/httpguts from net/http
|
||||
|
|
|
@ -182,7 +182,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
|
|||
golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box
|
||||
golang.org/x/crypto/pbkdf2 from software.sslmate.com/src/go-pkcs12
|
||||
golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+
|
||||
W golang.org/x/exp/constraints from github.com/dblohm7/wingoes/pe
|
||||
W golang.org/x/exp/constraints from github.com/dblohm7/wingoes/pe+
|
||||
golang.org/x/exp/maps from tailscale.com/cmd/tailscale/cli
|
||||
golang.org/x/net/bpf from github.com/mdlayher/netlink+
|
||||
golang.org/x/net/dns/dnsmessage from net+
|
||||
|
|
|
@ -7,14 +7,17 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
@ -643,3 +646,141 @@ func LogonSessionID(token windows.Token) (logonSessionID windows.LUID, err error
|
|||
|
||||
return origin.originatingLogonSession, nil
|
||||
}
|
||||
|
||||
// BufUnit is a type constraint for buffers passed into AllocateContiguousBuffer.
|
||||
type BufUnit interface {
|
||||
byte | uint16
|
||||
}
|
||||
|
||||
// AllocateContiguousBuffer allocates memory to satisfy the Windows idiom where
|
||||
// some structs contain pointers that are expected to refer to memory within the
|
||||
// same buffer containing the struct itself. T is the type that contains
|
||||
// the pointers. values must contain the actual data that is to be copied
|
||||
// into the buffer after T. AllocateContiguousBuffer returns a pointer to the
|
||||
// struct, the total length of the buffer in bytes, and a slice containing
|
||||
// each value within the buffer. The caller may use slcs to populate any
|
||||
// pointers in t as needed. Each element of slcs corresponds to the element of
|
||||
// values in the same position.
|
||||
//
|
||||
// It is the responsibility of the caller to ensure that any values expected
|
||||
// to contain null-terminated strings are in fact null-terminated!
|
||||
//
|
||||
// AllocateContiguousBuffer panics if no values are passed in, as there are
|
||||
// better alternatives for allocating a struct in that case.
|
||||
func AllocateContiguousBuffer[T any, BU BufUnit](values ...[]BU) (t *T, tLenBytes uint32, slcs [][]BU) {
|
||||
if len(values) == 0 {
|
||||
panic("len(values) must be > 0")
|
||||
}
|
||||
|
||||
// Get the sizes of T and BU, then compute a preferred alignment for T.
|
||||
tT := reflect.TypeFor[T]()
|
||||
szT := tT.Size()
|
||||
szBU := int(unsafe.Sizeof(BU(0)))
|
||||
alignment := max(tT.Align(), szBU)
|
||||
|
||||
// Our buffers for values will start at the next szBU boundary.
|
||||
tLenBytes = alignUp(uint32(szT), szBU)
|
||||
firstValueOffset := tLenBytes
|
||||
|
||||
// Accumulate the length of each value into tLenBytes
|
||||
for _, v := range values {
|
||||
tLenBytes += uint32(len(v) * szBU)
|
||||
}
|
||||
|
||||
// Now that we know the final length, align up to our preferred boundary.
|
||||
tLenBytes = alignUp(tLenBytes, alignment)
|
||||
|
||||
// Allocate the buffer. We choose a type for the slice that is appropriate
|
||||
// for the desired alignment. Note that we do not have a strict requirement
|
||||
// that T contain pointer fields; we could just be appending more data
|
||||
// within the same buffer.
|
||||
bufLen := tLenBytes / uint32(alignment)
|
||||
var pt unsafe.Pointer
|
||||
switch alignment {
|
||||
case 1:
|
||||
pt = unsafe.Pointer(unsafe.SliceData(make([]byte, bufLen)))
|
||||
case 2:
|
||||
pt = unsafe.Pointer(unsafe.SliceData(make([]uint16, bufLen)))
|
||||
case 4:
|
||||
pt = unsafe.Pointer(unsafe.SliceData(make([]uint32, bufLen)))
|
||||
case 8:
|
||||
pt = unsafe.Pointer(unsafe.SliceData(make([]uint64, bufLen)))
|
||||
default:
|
||||
panic(fmt.Sprintf("bad alignment %d", alignment))
|
||||
}
|
||||
|
||||
t = (*T)(pt)
|
||||
slcs = make([][]BU, 0, len(values))
|
||||
|
||||
// Use the limits of the buffer area after t to construct a slice representing the remaining buffer.
|
||||
firstValuePtr := unsafe.Pointer(uintptr(pt) + uintptr(firstValueOffset))
|
||||
buf := unsafe.Slice((*BU)(firstValuePtr), (tLenBytes-firstValueOffset)/uint32(szBU))
|
||||
|
||||
// Copy each value into the buffer and record a slice describing each value's limits into slcs.
|
||||
var index int
|
||||
for _, v := range values {
|
||||
if len(v) == 0 {
|
||||
// We allow zero-length values; we simply append a nil slice.
|
||||
slcs = append(slcs, nil)
|
||||
continue
|
||||
}
|
||||
valueSlice := buf[index : index+len(v)]
|
||||
copy(valueSlice, v)
|
||||
slcs = append(slcs, valueSlice)
|
||||
index += len(v)
|
||||
}
|
||||
|
||||
return t, tLenBytes, slcs
|
||||
}
|
||||
|
||||
// alignment must be a power of 2
|
||||
func alignUp[V constraints.Integer](v V, alignment int) V {
|
||||
return v + ((-v) & (V(alignment) - 1))
|
||||
}
|
||||
|
||||
// NTStr is a type constraint requiring the type to be either a
|
||||
// windows.NTString or a windows.NTUnicodeString.
|
||||
type NTStr interface {
|
||||
windows.NTString | windows.NTUnicodeString
|
||||
}
|
||||
|
||||
// SetNTString sets the value of nts in-place to point to the string contained
|
||||
// within buf. A nul terminator is optional in buf.
|
||||
func SetNTString[NTS NTStr, BU BufUnit](nts *NTS, buf []BU) {
|
||||
isEmpty := len(buf) == 0
|
||||
codeUnitSize := uint16(unsafe.Sizeof(BU(0)))
|
||||
lenBytes := len(buf) * int(codeUnitSize)
|
||||
if lenBytes > math.MaxUint16 {
|
||||
panic("buffer length must fit into uint16")
|
||||
}
|
||||
lenBytes16 := uint16(lenBytes)
|
||||
|
||||
switch p := any(nts).(type) {
|
||||
case *windows.NTString:
|
||||
if isEmpty {
|
||||
*p = windows.NTString{}
|
||||
break
|
||||
}
|
||||
p.Buffer = unsafe.SliceData(any(buf).([]byte))
|
||||
p.MaximumLength = lenBytes16
|
||||
p.Length = lenBytes16
|
||||
// account for nul terminator when present
|
||||
if buf[len(buf)-1] == 0 {
|
||||
p.Length -= codeUnitSize
|
||||
}
|
||||
case *windows.NTUnicodeString:
|
||||
if isEmpty {
|
||||
*p = windows.NTUnicodeString{}
|
||||
break
|
||||
}
|
||||
p.Buffer = unsafe.SliceData(any(buf).([]uint16))
|
||||
p.MaximumLength = lenBytes16
|
||||
p.Length = lenBytes16
|
||||
// account for nul terminator when present
|
||||
if buf[len(buf)-1] == 0 {
|
||||
p.Length -= codeUnitSize
|
||||
}
|
||||
default:
|
||||
panic("unknown type")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,9 +4,13 @@
|
|||
package winutil
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
//lint:file-ignore U1000 Fields are unused but necessary for tests.
|
||||
|
||||
const (
|
||||
localSystemSID = "S-1-5-18"
|
||||
networkSID = "S-1-5-2"
|
||||
|
@ -28,3 +32,103 @@ func TestLookupPseudoUser(t *testing.T) {
|
|||
t.Errorf("LookupPseudoUser(%q) unexpectedly succeeded", networkSID)
|
||||
}
|
||||
}
|
||||
|
||||
type testType interface {
|
||||
byte | uint16 | uint32 | uint64
|
||||
}
|
||||
|
||||
type noPointers[T testType] struct {
|
||||
foo byte
|
||||
bar T
|
||||
baz bool
|
||||
}
|
||||
|
||||
type hasPointer struct {
|
||||
foo byte
|
||||
bar uint32
|
||||
s1 *struct{}
|
||||
baz byte
|
||||
}
|
||||
|
||||
func checkContiguousBuffer[T any, BU BufUnit](t *testing.T, extra []BU, pt *T, ptLen uint32, slcs [][]BU) {
|
||||
szBU := int(unsafe.Sizeof(BU(0)))
|
||||
expectedAlign := max(reflect.TypeFor[T]().Align(), szBU)
|
||||
// Check that pointer is aligned
|
||||
if rem := uintptr(unsafe.Pointer(pt)) % uintptr(expectedAlign); rem != 0 {
|
||||
t.Errorf("pointer alignment got %d, want 0", rem)
|
||||
}
|
||||
// Check that alloc length is aligned
|
||||
if rem := int(ptLen) % expectedAlign; rem != 0 {
|
||||
t.Errorf("allocation length alignment got %d, want 0", rem)
|
||||
}
|
||||
expectedLen := int(unsafe.Sizeof(*pt))
|
||||
expectedLen = alignUp(expectedLen, szBU)
|
||||
expectedLen += len(extra) * szBU
|
||||
expectedLen = alignUp(expectedLen, expectedAlign)
|
||||
if gotLen := int(ptLen); gotLen != expectedLen {
|
||||
t.Errorf("allocation length got %d, want %d", gotLen, expectedLen)
|
||||
}
|
||||
if l := len(slcs); l != 1 {
|
||||
t.Errorf("len(slcs) got %d, want 1", l)
|
||||
}
|
||||
if len(extra) == 0 && slcs[0] != nil {
|
||||
t.Error("slcs[0] got non-nil, want nil")
|
||||
}
|
||||
if len(extra) != len(slcs[0]) {
|
||||
t.Errorf("len(slcs[0]) got %d, want %d", len(slcs[0]), len(extra))
|
||||
} else if rem := uintptr(unsafe.Pointer(unsafe.SliceData(slcs[0]))) % uintptr(szBU); rem != 0 {
|
||||
t.Errorf("additional data alignment got %d, want 0", rem)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllocateContiguousBuffer(t *testing.T) {
|
||||
t.Run("NoValues", testNoValues)
|
||||
t.Run("NoPointers", testNoPointers)
|
||||
t.Run("HasPointer", testHasPointer)
|
||||
}
|
||||
|
||||
func testNoValues(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("expected panic but didn't get one")
|
||||
}
|
||||
}()
|
||||
|
||||
AllocateContiguousBuffer[hasPointer, byte]()
|
||||
}
|
||||
|
||||
const maxTestBufLen = 8
|
||||
|
||||
func testNoPointers(t *testing.T) {
|
||||
buf8 := make([]byte, maxTestBufLen)
|
||||
buf16 := make([]uint16, maxTestBufLen)
|
||||
for i := range maxTestBufLen {
|
||||
s8, sl, slcs8 := AllocateContiguousBuffer[noPointers[byte]](buf8[:i])
|
||||
checkContiguousBuffer(t, buf8[:i], s8, sl, slcs8)
|
||||
s16, sl, slcs8 := AllocateContiguousBuffer[noPointers[uint16]](buf8[:i])
|
||||
checkContiguousBuffer(t, buf8[:i], s16, sl, slcs8)
|
||||
s32, sl, slcs8 := AllocateContiguousBuffer[noPointers[uint32]](buf8[:i])
|
||||
checkContiguousBuffer(t, buf8[:i], s32, sl, slcs8)
|
||||
s64, sl, slcs8 := AllocateContiguousBuffer[noPointers[uint64]](buf8[:i])
|
||||
checkContiguousBuffer(t, buf8[:i], s64, sl, slcs8)
|
||||
s8, sl, slcs16 := AllocateContiguousBuffer[noPointers[byte]](buf16[:i])
|
||||
checkContiguousBuffer(t, buf16[:i], s8, sl, slcs16)
|
||||
s16, sl, slcs16 = AllocateContiguousBuffer[noPointers[uint16]](buf16[:i])
|
||||
checkContiguousBuffer(t, buf16[:i], s16, sl, slcs16)
|
||||
s32, sl, slcs16 = AllocateContiguousBuffer[noPointers[uint32]](buf16[:i])
|
||||
checkContiguousBuffer(t, buf16[:i], s32, sl, slcs16)
|
||||
s64, sl, slcs16 = AllocateContiguousBuffer[noPointers[uint64]](buf16[:i])
|
||||
checkContiguousBuffer(t, buf16[:i], s64, sl, slcs16)
|
||||
}
|
||||
}
|
||||
|
||||
func testHasPointer(t *testing.T) {
|
||||
buf8 := make([]byte, maxTestBufLen)
|
||||
buf16 := make([]uint16, maxTestBufLen)
|
||||
for i := range maxTestBufLen {
|
||||
s, sl, slcs8 := AllocateContiguousBuffer[hasPointer](buf8[:i])
|
||||
checkContiguousBuffer(t, buf8[:i], s, sl, slcs8)
|
||||
s, sl, slcs16 := AllocateContiguousBuffer[hasPointer](buf16[:i])
|
||||
checkContiguousBuffer(t, buf16[:i], s, sl, slcs16)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue