211 lines
6.2 KiB
Go
211 lines
6.2 KiB
Go
|
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||
|
// Use of this source code is governed by a BSD-style
|
||
|
// license that can be found in the LICENSE file.
|
||
|
|
||
|
package winutil
|
||
|
|
||
|
import (
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"os"
|
||
|
"path/filepath"
|
||
|
"runtime"
|
||
|
"strings"
|
||
|
|
||
|
"golang.org/x/sys/windows"
|
||
|
|
||
|
"tailscale.com/paths"
|
||
|
"tailscale.com/util/winutil/vss"
|
||
|
)
|
||
|
|
||
|
// StopWalking is a sentinel error. A WalkSnapshotsFunc returns it to report
// that it has finished its work successfully and does not need to examine any
// further snapshots; WalkSnapshotsForLegacyStateDir then ends the walk and
// returns nil.
var StopWalking = errors.New("Stop walking")
|
||
|
|
||
|
// WalkSnapshotsFunc is the type of the function called by WalkSnapshotsForLegacyStateDir
|
||
|
// to visit each mapped VSS snapshot.
|
||
|
// The path argument is the path of the directory containing the Tailscale state.
|
||
|
// The props argument contains the snapshot properties of the current snapshot, and
|
||
|
// should be treated as read-only.
|
||
|
// The function may return StopWalking if further walking is no longer necessary.
|
||
|
// Otherwise it should return nil to proceed with the walk, or an error.
|
||
|
type WalkSnapshotsFunc func(path string, props vss.SnapshotProperties) error
|
||
|
|
||
|
// WalkSnapshotsForLegacyStateDir enumerates available snapshots from the
|
||
|
// Volume Shadow Copy service. For each snapshot originating from this computer's
|
||
|
// C: volume, the snapshot is mounted to a temporary location inside the
|
||
|
// Tailscaled state directory.
|
||
|
// If the mounted snapshot contains a path to a legacy state directory (located under
|
||
|
// C:\Windows\System32\config\systemprofile\AppData\Local), the fn argument is
|
||
|
// invoked with the fully-qualified path to the mounted state directory, as well
|
||
|
// as the properties of the snapshot itself.
|
||
|
// A mounted snapshot that does not contain a path to a legacy state directory is
|
||
|
// not considered to be an error, the snapshot is ignored, and the walk continues.
|
||
|
// If fn returns StopWalking, then the walk is terminated but is considered to
|
||
|
// have been successful and nil is returned.
|
||
|
// If fn returns a different error, then the walk is terminated and fn's error
|
||
|
// is wrapped and then returned to the caller.
|
||
|
func WalkSnapshotsForLegacyStateDir(fn WalkSnapshotsFunc) error {
|
||
|
runtime.LockOSThread()
|
||
|
defer runtime.UnlockOSThread()
|
||
|
|
||
|
// Ideally COM would be initialized process-wide, but until we have that
|
||
|
// conversation this should be okay, especially given that this function will
|
||
|
// only be called when a migration is necessary.
|
||
|
err := windows.CoInitializeEx(0, windows.COINIT_MULTITHREADED)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer windows.CoUninitialize()
|
||
|
|
||
|
sysVol, err := getSystemVolumeName()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
thisMachine, err := getFullyQualifiedComputerName()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// We'll map each snapshot to a subdir inside our tailscaled state dir
|
||
|
mountPt := filepath.Dir(paths.DefaultTailscaledStateFile())
|
||
|
|
||
|
vssSnapshotEnumerator, err := vss.NewSnapshotEnumerator()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer vssSnapshotEnumerator.Close()
|
||
|
|
||
|
snapshots, err := vssSnapshotEnumerator.QuerySnapshots()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer snapshots.Close()
|
||
|
|
||
|
for _, snap := range snapshots {
|
||
|
if !strings.EqualFold(snap.Obj.OriginalVolumeName.String(), sysVol) ||
|
||
|
!strings.EqualFold(snap.Obj.OriginatingMachine.String(), thisMachine) {
|
||
|
// These snapshots do not belong to our computer's C: volume, so we should skip them.
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
mounted, err := mountSnapshotDevice(snap.Obj, mountPt)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("Mapping snapshot device %v: %w", snap.Obj.SnapshotDeviceObject.String(), err)
|
||
|
}
|
||
|
defer mounted.Close()
|
||
|
|
||
|
legacyStateDir, err := mounted.findLegacyStateDir()
|
||
|
if err != nil {
|
||
|
// Not all snapshots will necessarily contain the state dir, so this is not fatal
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
err = fn(legacyStateDir, snap.Obj)
|
||
|
if errors.Is(err, StopWalking) {
|
||
|
return nil
|
||
|
}
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("WalkSnapshotsFunc returned error %w", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func getSystemVolumeName() (string, error) {
|
||
|
// This is the exact length of a volume name, including nul terminator (per MSDN)
|
||
|
var volName [50]uint16
|
||
|
|
||
|
// Modern Windows always requires that the OS be installed on C:
|
||
|
mountPt, err := windows.UTF16PtrFromString("C:\\")
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
|
||
|
err = windows.GetVolumeNameForVolumeMountPoint(mountPt, &volName[0], uint32(len(volName)))
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
|
||
|
return windows.UTF16ToString(volName[:len(volName)-1]), nil
|
||
|
}
|
||
|
|
||
|
// mountedSnapshot is the filesystem path of a directory symbolic link that
// exposes a VSS snapshot's device (see mountSnapshotDevice). The empty value
// means the snapshot is not, or is no longer, mounted.
type mountedSnapshot string

// Close unmounts the snapshot by removing its symbolic link, then resets snap
// to the empty string. Calling Close on an empty mountedSnapshot does nothing
// and returns nil, so repeated calls are safe. Any error from removing the
// link is returned instead of being silently discarded.
func (snap *mountedSnapshot) Close() error {
	if *snap == "" {
		return nil
	}
	err := os.Remove(string(*snap))
	*snap = ""
	return err
}
|
||
|
|
||
|
func mountSnapshotDevice(snap vss.SnapshotProperties, mountPath string) (mountedSnapshot, error) {
|
||
|
fi, err := os.Stat(mountPath)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
if !fi.IsDir() {
|
||
|
return "", os.ErrInvalid
|
||
|
}
|
||
|
|
||
|
devPath := snap.SnapshotDeviceObject.String()
|
||
|
linkPath := filepath.Join(mountPath, filepath.Base(devPath))
|
||
|
|
||
|
linkPathUTF16, err := windows.UTF16PtrFromString(linkPath)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
|
||
|
// The target needs to end with a backslash or else the symlink won't resolve correctly
|
||
|
deviceUTF16, err := windows.UTF16PtrFromString(devPath + "\\")
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
|
||
|
err = windows.CreateSymbolicLink(linkPathUTF16, deviceUTF16, windows.SYMBOLIC_LINK_FLAG_DIRECTORY)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
|
||
|
return mountedSnapshot(linkPath), nil
|
||
|
}
|
||
|
|
||
|
func (snap *mountedSnapshot) findLegacyStateDir() (string, error) {
|
||
|
legacyStateDir := filepath.Dir(paths.LegacyStateFilePath())
|
||
|
relPath, err := filepath.Rel("C:\\", legacyStateDir)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
|
||
|
snapStateDir := filepath.Join(string(*snap), relPath)
|
||
|
fi, err := os.Stat(snapStateDir)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
if !fi.IsDir() {
|
||
|
return "", os.ErrInvalid
|
||
|
}
|
||
|
|
||
|
return snapStateDir, nil
|
||
|
}
|
||
|
|
||
|
func getFullyQualifiedComputerName() (string, error) {
|
||
|
var desiredLen uint32
|
||
|
err := windows.GetComputerNameEx(windows.ComputerNamePhysicalDnsFullyQualified, nil, &desiredLen)
|
||
|
if !errors.Is(err, windows.ERROR_MORE_DATA) {
|
||
|
return "", err
|
||
|
}
|
||
|
|
||
|
buf := make([]uint16, desiredLen+1)
|
||
|
|
||
|
// Note: bufLen includes nul terminator on input, but excludes nul terminator as output
|
||
|
bufLen := uint32(len(buf))
|
||
|
err = windows.GetComputerNameEx(windows.ComputerNamePhysicalDnsFullyQualified, &buf[0], &bufLen)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
|
||
|
return windows.UTF16ToString(buf[:bufLen]), nil
|
||
|
}
|