tailscale/util/winutil/winrestore_windows.go

211 lines
6.2 KiB
Go

// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package winutil
import (
"errors"
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
"golang.org/x/sys/windows"
"tailscale.com/paths"
"tailscale.com/util/winutil/vss"
)
// StopWalking is the sentinel error a WalkSnapshotsFunc returns to end a walk
// early without failing it: WalkSnapshotsForLegacyStateDir then returns nil.
// The walk detects it with errors.Is, so it may also be returned wrapped.
var StopWalking = errors.New("stop walking")
// WalkSnapshotsFunc is the callback signature used by
// WalkSnapshotsForLegacyStateDir. It is invoked once for every mounted VSS
// snapshot that contains a legacy Tailscale state directory.
//
// path is the location of that state directory inside the mounted snapshot.
// props describes the snapshot currently being visited; treat it as read-only.
//
// Return nil to continue to the next snapshot, StopWalking to end the walk
// successfully, or any other error to abort the walk with that error.
type WalkSnapshotsFunc func(path string, props vss.SnapshotProperties) error
// WalkSnapshotsForLegacyStateDir enumerates the snapshots known to the Volume
// Shadow Copy service. Each snapshot that was taken of this computer's C:
// volume is temporarily mounted, as a directory symlink, inside the
// tailscaled state directory.
//
// If the mounted snapshot contains the legacy state directory (located under
// C:\Windows\System32\config\systemprofile\AppData\Local), fn is invoked with
// the fully-qualified path to that directory inside the mount, along with the
// snapshot's properties. Each snapshot is unmounted before the next one is
// examined.
//
// A mounted snapshot that does not contain the legacy state directory is not
// an error; the snapshot is skipped and the walk continues.
//
// If fn returns StopWalking (possibly wrapped), the walk ends early, is
// considered successful, and nil is returned. If fn returns any other error,
// the walk ends and fn's error is returned wrapped.
func WalkSnapshotsForLegacyStateDir(fn WalkSnapshotsFunc) error {
	// CoInitializeEx/CoUninitialize apply to the calling OS thread, so keep
	// this goroutine pinned to one thread for the whole walk.
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()

	// Ideally COM would be initialized process-wide, but until we have that
	// conversation this should be okay, especially given that this function will
	// only be called when a migration is necessary.
	err := windows.CoInitializeEx(0, windows.COINIT_MULTITHREADED)
	if err != nil {
		return err
	}
	defer windows.CoUninitialize()

	sysVol, err := getSystemVolumeName()
	if err != nil {
		return err
	}
	thisMachine, err := getFullyQualifiedComputerName()
	if err != nil {
		return err
	}

	// We'll map each snapshot to a subdir inside our tailscaled state dir
	mountPt := filepath.Dir(paths.DefaultTailscaledStateFile())

	vssSnapshotEnumerator, err := vss.NewSnapshotEnumerator()
	if err != nil {
		return err
	}
	defer vssSnapshotEnumerator.Close()

	snapshots, err := vssSnapshotEnumerator.QuerySnapshots()
	if err != nil {
		return err
	}
	defer snapshots.Close()

	for _, snap := range snapshots {
		if !strings.EqualFold(snap.Obj.OriginalVolumeName.String(), sysVol) ||
			!strings.EqualFold(snap.Obj.OriginatingMachine.String(), thisMachine) {
			// These snapshots do not belong to our computer's C: volume, so we should skip them.
			continue
		}
		err := visitSnapshotForLegacyStateDir(snap.Obj, mountPt, fn)
		if errors.Is(err, StopWalking) {
			return nil
		}
		if err != nil {
			return err
		}
	}
	return nil
}

// visitSnapshotForLegacyStateDir handles a single snapshot. It mounts the
// snapshot under mountPt and calls fn on the snapshot's legacy state directory
// if one exists. It removes the mount before returning.
//
// This is a separate function so that the deferred unmount runs once per
// snapshot. A defer inside the caller's loop would keep every mount alive
// until the whole walk finished.
//
// fn's error is returned wrapped, so errors.Is(err, StopWalking) still works
// for the caller.
func visitSnapshotForLegacyStateDir(props vss.SnapshotProperties, mountPt string, fn WalkSnapshotsFunc) error {
	mounted, err := mountSnapshotDevice(props, mountPt)
	if err != nil {
		return fmt.Errorf("Mapping snapshot device %v: %w", props.SnapshotDeviceObject.String(), err)
	}
	defer mounted.Close()

	legacyStateDir, err := mounted.findLegacyStateDir()
	if err != nil {
		// Not all snapshots will necessarily contain the state dir, so this is not fatal
		return nil
	}
	if err := fn(legacyStateDir, props); err != nil {
		return fmt.Errorf("WalkSnapshotsFunc returned error %w", err)
	}
	return nil
}
// getSystemVolumeName returns the volume GUID path of the C: drive, as
// reported by GetVolumeNameForVolumeMountPoint. This is the form compared
// against a snapshot's original volume name.
func getSystemVolumeName() (string, error) {
	// Modern Windows always requires that the OS be installed on C:
	rootUTF16, err := windows.UTF16PtrFromString("C:\\")
	if err != nil {
		return "", err
	}

	// 50 UTF-16 units is the exact length of a volume name, including the nul
	// terminator (per MSDN).
	const volumeNameLen = 50
	var buf [volumeNameLen]uint16
	if err := windows.GetVolumeNameForVolumeMountPoint(rootUTF16, &buf[0], uint32(len(buf))); err != nil {
		return "", err
	}
	return windows.UTF16ToString(buf[:volumeNameLen-1]), nil
}
// mountedSnapshot is the path of the directory symlink that
// mountSnapshotDevice creates to expose a VSS snapshot device. The empty
// string means "not mounted".
type mountedSnapshot string

// Close removes the mount link and resets snap to the empty string. Closing an
// already-closed mountedSnapshot is a no-op that returns nil. Unlike a silent
// best-effort removal, an error from removing the link is reported to the
// caller. The caller may still choose to ignore it.
func (snap *mountedSnapshot) Close() error {
	if *snap == "" {
		return nil
	}
	err := os.Remove(string(*snap))
	// Reset even on failure so a repeated Close does not retry the removal.
	*snap = ""
	return err
}
// mountSnapshotDevice exposes a snapshot's device object as a directory
// symlink inside mountPath. The link is named after the final element of the
// device path. The returned mountedSnapshot must be closed to remove the link.
//
// mountPath must be an existing directory. If os.Stat fails, that error is
// returned; if the path is not a directory, os.ErrInvalid is returned.
func mountSnapshotDevice(snap vss.SnapshotProperties, mountPath string) (mountedSnapshot, error) {
	fi, err := os.Stat(mountPath)
	if err != nil {
		return "", err
	}
	if !fi.IsDir() {
		return "", os.ErrInvalid
	}

	devPath := snap.SnapshotDeviceObject.String()
	linkPath := filepath.Join(mountPath, filepath.Base(devPath))

	// A link left behind by an earlier, interrupted walk would make
	// CreateSymbolicLink fail on every later attempt because the path already
	// exists. Remove it, but only when it is a symlink, so real files or
	// directories are never deleted.
	if li, err := os.Lstat(linkPath); err == nil && li.Mode()&os.ModeSymlink != 0 {
		if err := os.Remove(linkPath); err != nil {
			return "", err
		}
	}

	linkPathUTF16, err := windows.UTF16PtrFromString(linkPath)
	if err != nil {
		return "", err
	}
	// The target needs to end with a backslash or else the symlink won't resolve correctly
	deviceUTF16, err := windows.UTF16PtrFromString(devPath + "\\")
	if err != nil {
		return "", err
	}
	err = windows.CreateSymbolicLink(linkPathUTF16, deviceUTF16, windows.SYMBOLIC_LINK_FLAG_DIRECTORY)
	if err != nil {
		return "", err
	}
	return mountedSnapshot(linkPath), nil
}
// findLegacyStateDir maps the legacy state directory into the mounted
// snapshot. It takes the directory's path relative to C:\ and re-roots it at
// the snapshot's mount point. It returns the error from os.Stat if that path
// cannot be stat'ed. It returns os.ErrInvalid if the path exists but is not a
// directory.
func (snap *mountedSnapshot) findLegacyStateDir() (string, error) {
	rel, err := filepath.Rel("C:\\", filepath.Dir(paths.LegacyStateFilePath()))
	if err != nil {
		return "", err
	}
	candidate := filepath.Join(string(*snap), rel)
	switch info, err := os.Stat(candidate); {
	case err != nil:
		return "", err
	case !info.IsDir():
		return "", os.ErrInvalid
	default:
		return candidate, nil
	}
}
// getFullyQualifiedComputerName returns this machine's physical
// fully-qualified DNS name. It uses the usual two-call pattern: first query
// the required buffer length, then fetch the name into a buffer of that size.
func getFullyQualifiedComputerName() (string, error) {
	var required uint32
	// With a nil buffer the call is expected to report the needed size via
	// ERROR_MORE_DATA; any other outcome is handed straight back to the caller.
	if err := windows.GetComputerNameEx(windows.ComputerNamePhysicalDnsFullyQualified, nil, &required); !errors.Is(err, windows.ERROR_MORE_DATA) {
		return "", err
	}

	name := make([]uint16, required+1)
	// On input n is the buffer capacity (counting the nul terminator); on
	// output it is the name length (not counting it).
	n := uint32(len(name))
	if err := windows.GetComputerNameEx(windows.ComputerNamePhysicalDnsFullyQualified, &name[0], &n); err != nil {
		return "", err
	}
	return windows.UTF16ToString(name[:n]), nil
}