all: use new functions, add tests

This commit is contained in:
Eugene Burkov 2024-10-25 15:55:16 +03:00
parent dcbabaf4e3
commit 1dbc784982
10 changed files with 190 additions and 44 deletions

View File

@ -31,6 +31,7 @@ NOTE: Add new changes BELOW THIS COMMENT.
### Fixed ### Fixed
- Incorrect handling of sensitive files permissions on Windows ([#7314]).
- Repetitive statistics log messages ([#7338]). - Repetitive statistics log messages ([#7338]).
- Custom client cache ([#7250]). - Custom client cache ([#7250]).
- Missing runtime clients with information from the system hosts file on first - Missing runtime clients with information from the system hosts file on first

View File

@ -2,28 +2,34 @@ package aghos
import "io/fs" import "io/fs"
// TODO(e.burkov): Add platform-independent tests.
// Chmod is an extension for [os.Chmod] that properly handles Windows access // Chmod is an extension for [os.Chmod] that properly handles Windows access
// rights. // rights.
//
// TODO(e.burkov): !! use.
func Chmod(name string, perm fs.FileMode) (err error) { func Chmod(name string, perm fs.FileMode) (err error) {
return chmod(name, perm) return chmod(name, perm)
} }
// Mkdir is an extension for [os.Chmod] that properly handles Windows access // Mkdir is an extension for [os.Mkdir] that properly handles Windows access
// rights. // rights.
//
// TODO(e.burkov): !! use.
func Mkdir(name string, perm fs.FileMode) (err error) { func Mkdir(name string, perm fs.FileMode) (err error) {
return mkdir(name, perm) return mkdir(name, perm)
} }
// MkdirAll is an extension for [os.MkdirAll] that properly handles Windows
// access rights.
func MkdirAll(path string, perm fs.FileMode) (err error) {
return mkdirAll(path, perm)
}
// WriteFile is an extension for [os.WriteFile] that properly handles Windows
// access rights.
func WriteFile(filename string, data []byte, perm fs.FileMode) (err error) {
return writeFile(filename, data, perm)
}
// Stat is an extension for [os.Stat] that properly handles Windows access // Stat is an extension for [os.Stat] that properly handles Windows access
// rights. // rights.
//
// TODO(e.burkov): !! use.
func Stat(name string) (fi fs.FileInfo, err error) { func Stat(name string) (fi fs.FileInfo, err error) {
return stat(name) return stat(name)
} }
// TODO(e.burkov): !! add tests.

View File

@ -17,6 +17,16 @@ func mkdir(name string, perm fs.FileMode) (err error) {
return os.Mkdir(name, perm) return os.Mkdir(name, perm)
} }
// mkdirAll is a Unix implementation of [MkdirAll].
func mkdirAll(path string, perm fs.FileMode) (err error) {
return os.MkdirAll(path, perm)
}
// writeFile is a Unix implementation of [WriteFile].
func writeFile(filename string, data []byte, perm fs.FileMode) (err error) {
return os.WriteFile(filename, data, perm)
}
// stat is a Unix implementation of [Stat]. // stat is a Unix implementation of [Stat].
func stat(name string) (fi os.FileInfo, err error) { func stat(name string) (fi os.FileInfo, err error) {
return os.Stat(name) return os.Stat(name)

View File

@ -40,11 +40,11 @@ func stat(name string) (fi os.FileInfo, err error) {
const objectType windows.SE_OBJECT_TYPE = windows.SE_FILE_OBJECT const objectType windows.SE_OBJECT_TYPE = windows.SE_FILE_OBJECT
secInfo := windows.SECURITY_INFORMATION(0 | secInfo := windows.SECURITY_INFORMATION(
windows.OWNER_SECURITY_INFORMATION | windows.OWNER_SECURITY_INFORMATION |
windows.GROUP_SECURITY_INFORMATION | windows.GROUP_SECURITY_INFORMATION |
windows.DACL_SECURITY_INFORMATION | windows.DACL_SECURITY_INFORMATION |
windows.PROTECTED_DACL_SECURITY_INFORMATION, windows.PROTECTED_DACL_SECURITY_INFORMATION,
) )
sd, err := windows.GetNamedSecurityInfo(fi.Name(), objectType, secInfo) sd, err := windows.GetNamedSecurityInfo(fi.Name(), objectType, secInfo)
@ -97,7 +97,7 @@ func chmod(name string, perm fs.FileMode) (err error) {
const objectType windows.SE_OBJECT_TYPE = windows.SE_FILE_OBJECT const objectType windows.SE_OBJECT_TYPE = windows.SE_FILE_OBJECT
entries := make([]windows.EXPLICIT_ACCESS, 0, 3) entries := make([]windows.EXPLICIT_ACCESS, 0, 3)
creatorMask, groupMask, worldMask := modeToMasks(perm) creatorMask, groupMask, worldMask := permToMasks(perm)
sidMasks := container.KeyValues[windows.WELL_KNOWN_SID_TYPE, windows.ACCESS_MASK]{{ sidMasks := container.KeyValues[windows.WELL_KNOWN_SID_TYPE, windows.ACCESS_MASK]{{
Key: windows.WinCreatorOwnerSid, Key: windows.WinCreatorOwnerSid,
@ -175,6 +175,28 @@ func mkdir(name string, perm os.FileMode) (err error) {
return chmod(name, perm) return chmod(name, perm)
} }
// mkdirAll is a Windows implementation of [MkdirAll].
func mkdirAll(path string, perm os.FileMode) (err error) {
parent, _ := filepath.Split(path)
err = os.MkdirAll(parent, perm)
if err != nil {
return fmt.Errorf("creating parent directories: %w", err)
}
return mkdir(path, perm)
}
// writeFile is a Windows implementation of [WriteFile].
func writeFile(filename string, data []byte, perm os.FileMode) (err error) {
err = os.WriteFile(filename, data, perm)
if err != nil {
return fmt.Errorf("writing file: %w", err)
}
return chmod(filename, perm)
}
// newWellKnownTrustee returns a trustee for a well-known SID. // newWellKnownTrustee returns a trustee for a well-known SID.
func newWellKnownTrustee(stype windows.WELL_KNOWN_SID_TYPE) (t *windows.TRUSTEE, err error) { func newWellKnownTrustee(stype windows.WELL_KNOWN_SID_TYPE) (t *windows.TRUSTEE, err error) {
sid, err := windows.CreateWellKnownSid(stype) sid, err := windows.CreateWellKnownSid(stype)
@ -190,13 +212,13 @@ func newWellKnownTrustee(stype windows.WELL_KNOWN_SID_TYPE) (t *windows.TRUSTEE,
// Constants reflecting the UNIX permission bits. // Constants reflecting the UNIX permission bits.
const ( const (
ownerWrite = 0b010000000 ownerWrite = 0b010_000_000
groupWrite = 0b000100000 groupWrite = 0b000_100_000
worldWrite = 0b000000100 worldWrite = 0b000_000_100
ownerAll = 0b111000000 ownerAll = 0b111_000_000
groupAll = 0b000111000 groupAll = 0b000_111_000
worldAll = 0b000000111 worldAll = 0b000_000_111
) )
// Constants reflecting the number of bits to shift the UNIX permission bits to // Constants reflecting the number of bits to shift the UNIX permission bits to
@ -216,9 +238,9 @@ const (
deleteWorld = 15 deleteWorld = 15
) )
// modeToMasks converts a UNIX file mode to the corresponding Windows access // permToMasks converts a UNIX file mode permissions to the corresponding
// masks. // Windows access masks.
func modeToMasks(fm os.FileMode) (owner, group, world windows.ACCESS_MASK) { func permToMasks(fm os.FileMode) (owner, group, world windows.ACCESS_MASK) {
mask := windows.ACCESS_MASK(fm.Perm()) mask := windows.ACCESS_MASK(fm.Perm())
owner = ((mask & ownerAll) << genericOwner) | ((mask & ownerWrite) << deleteOwner) owner = ((mask & ownerAll) << genericOwner) | ((mask & ownerWrite) << deleteOwner)
@ -229,7 +251,7 @@ func modeToMasks(fm os.FileMode) (owner, group, world windows.ACCESS_MASK) {
} }
// masksToPerm converts Windows access masks to the corresponding UNIX file // masksToPerm converts Windows access masks to the corresponding UNIX file
// mode. // mode permission bits.
func masksToPerm(u, g, o windows.ACCESS_MASK) (perm os.FileMode) { func masksToPerm(u, g, o windows.ACCESS_MASK) (perm os.FileMode) {
perm |= os.FileMode(((u >> genericOwner) & ownerAll) | ((u >> deleteOwner) & ownerWrite)) perm |= os.FileMode(((u >> genericOwner) & ownerAll) | ((u >> deleteOwner) & ownerWrite))
perm |= os.FileMode(((g >> genericGroup) & groupAll) | ((g >> deleteGroup) & groupWrite)) perm |= os.FileMode(((g >> genericGroup) & groupAll) | ((g >> deleteGroup) & groupWrite))

View File

@ -0,0 +1,102 @@
//go:build windows
package aghos
import (
"io/fs"
"testing"
"github.com/stretchr/testify/assert"
"golang.org/x/sys/windows"
)
// Common test constants for the Windows access masks.
const (
winAccessWrite = windows.GENERIC_WRITE | windows.DELETE
winAccessFull = windows.GENERIC_READ | windows.GENERIC_EXECUTE | winAccessWrite
)
func TestPermToMasks(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
perm fs.FileMode
wantUser windows.ACCESS_MASK
wantGroup windows.ACCESS_MASK
wantOther windows.ACCESS_MASK
}{{
name: "all",
perm: 0b111_111_111,
wantUser: winAccessFull,
wantGroup: winAccessFull,
wantOther: winAccessFull,
}, {
name: "user_write",
perm: 0o010_000_000,
wantUser: winAccessWrite,
wantGroup: 0,
wantOther: 0,
}, {
name: "group_read",
perm: 0o000_010_000,
wantUser: 0,
wantGroup: windows.GENERIC_READ,
wantOther: 0,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
user, group, other := permToMasks(tc.perm)
assert.Equal(t, tc.wantUser, user)
assert.Equal(t, tc.wantGroup, group)
assert.Equal(t, tc.wantOther, other)
})
}
}
func TestMasksToPerm(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
user windows.ACCESS_MASK
group windows.ACCESS_MASK
other windows.ACCESS_MASK
wantPerm fs.FileMode
}{{
name: "all",
user: winAccessFull,
group: winAccessFull,
other: winAccessFull,
wantPerm: 0b111_111_111,
}, {
name: "user_write",
user: winAccessWrite,
group: 0,
other: 0,
wantPerm: 0o010_000_000,
}, {
name: "group_read",
user: 0,
group: windows.GENERIC_READ,
other: 0,
wantPerm: 0o000_010_000,
}, {
name: "no_access",
user: 0,
group: 0,
other: 0,
wantPerm: 0,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tc.wantPerm, masksToPerm(tc.user, tc.group, tc.other))
})
}
}

View File

@ -8,6 +8,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
) )
@ -62,7 +63,7 @@ func newPendingFile(filePath string, mode fs.FileMode) (f PendingFile, err error
return nil, fmt.Errorf("opening pending file: %w", err) return nil, fmt.Errorf("opening pending file: %w", err)
} }
err = file.Chmod(mode) err = aghos.Chmod(file.Name(), mode)
if err != nil { if err != nil {
return nil, fmt.Errorf("preparing pending file: %w", err) return nil, fmt.Errorf("preparing pending file: %w", err)
} }

View File

@ -1057,7 +1057,7 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
} }
} }
err = os.MkdirAll(filepath.Join(d.conf.DataDir, filterDir), aghos.DefaultPermDir) err = aghos.MkdirAll(filepath.Join(d.conf.DataDir, filterDir), aghos.DefaultPermDir)
if err != nil { if err != nil {
d.Close() d.Close()

View File

@ -643,7 +643,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
} }
dataDir := Context.getDataDir() dataDir := Context.getDataDir()
err = os.MkdirAll(dataDir, aghos.DefaultPermDir) err = aghos.MkdirAll(dataDir, aghos.DefaultPermDir)
fatalOnError(errors.Annotate(err, "creating DNS data dir at %s: %w", dataDir)) fatalOnError(errors.Annotate(err, "creating DNS data dir at %s: %w", dataDir))
GLMode = opts.glinetMode GLMode = opts.glinetMode

View File

@ -14,7 +14,7 @@ import (
// //
// TODO(a.garipov): Consider ways to detect this better. // TODO(a.garipov): Consider ways to detect this better.
func NeedsMigration(confFilePath string) (ok bool) { func NeedsMigration(confFilePath string) (ok bool) {
s, err := os.Stat(confFilePath) s, err := aghos.Stat(confFilePath)
if err != nil { if err != nil {
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
// Likely a first run. Don't check. // Likely a first run. Don't check.
@ -70,7 +70,7 @@ func chmodFile(filePath string) {
// chmodPath changes the permissions of a single filesystem entity. The results // chmodPath changes the permissions of a single filesystem entity. The results
// are logged at the appropriate level. // are logged at the appropriate level.
func chmodPath(entPath, fileType string, fm fs.FileMode) { func chmodPath(entPath, fileType string, fm fs.FileMode) {
err := os.Chmod(entPath, fm) err := aghos.Chmod(entPath, fm)
if err == nil { if err == nil {
log.Info("permcheck: changed permissions for %s %q", fileType, entPath) log.Info("permcheck: changed permissions for %s %q", fileType, entPath)

View File

@ -264,7 +264,7 @@ func (u *Updater) check() (err error) {
// ignores the configuration file if firstRun is true. // ignores the configuration file if firstRun is true.
func (u *Updater) backup(firstRun bool) (err error) { func (u *Updater) backup(firstRun bool) (err error) {
log.Debug("updater: backing up current configuration") log.Debug("updater: backing up current configuration")
_ = os.Mkdir(u.backupDir, aghos.DefaultPermDir) _ = aghos.Mkdir(u.backupDir, aghos.DefaultPermDir)
if !firstRun { if !firstRun {
err = copyFile(u.confName, filepath.Join(u.backupDir, "AdGuardHome.yaml")) err = copyFile(u.confName, filepath.Join(u.backupDir, "AdGuardHome.yaml"))
if err != nil { if err != nil {
@ -338,12 +338,12 @@ func (u *Updater) downloadPackageFile() (err error) {
return fmt.Errorf("io.ReadAll() failed: %w", err) return fmt.Errorf("io.ReadAll() failed: %w", err)
} }
_ = os.Mkdir(u.updateDir, aghos.DefaultPermDir) _ = aghos.Mkdir(u.updateDir, aghos.DefaultPermDir)
log.Debug("updater: saving package to file") log.Debug("updater: saving package to file")
err = os.WriteFile(u.packageName, body, aghos.DefaultPermFile) err = aghos.WriteFile(u.packageName, body, aghos.DefaultPermFile)
if err != nil { if err != nil {
return fmt.Errorf("os.WriteFile() failed: %w", err) return fmt.Errorf("writing package file: %w", err)
} }
return nil return nil
} }
@ -366,9 +366,9 @@ func tarGzFileUnpackOne(outDir string, tr *tar.Reader, hdr *tar.Header) (name st
return "", nil return "", nil
} }
err = os.Mkdir(outputName, os.FileMode(hdr.Mode&0o755)) err = aghos.Mkdir(outputName, os.FileMode(hdr.Mode&0o755))
if err != nil && !errors.Is(err, os.ErrExist) { if err != nil && !errors.Is(err, os.ErrExist) {
return "", fmt.Errorf("os.Mkdir(%q): %w", outputName, err) return "", fmt.Errorf("creating directory %q: %w", outputName, err)
} }
log.Debug("updater: created directory %q", outputName) log.Debug("updater: created directory %q", outputName)
@ -469,9 +469,9 @@ func zipFileUnpackOne(outDir string, zf *zip.File) (name string, err error) {
return "", nil return "", nil
} }
err = os.Mkdir(outputName, fi.Mode()) err = aghos.Mkdir(outputName, fi.Mode())
if err != nil && !errors.Is(err, os.ErrExist) { if err != nil && !errors.Is(err, os.ErrExist) {
return "", fmt.Errorf("os.Mkdir(%q): %w", outputName, err) return "", fmt.Errorf("creating directory %q: %w", outputName, err)
} }
log.Debug("updater: created directory %q", outputName) log.Debug("updater: created directory %q", outputName)
@ -523,15 +523,19 @@ func zipFileUnpack(zipfile, outDir string) (files []string, err error) {
} }
// Copy file on disk // Copy file on disk
func copyFile(src, dst string) error { func copyFile(src, dst string) (err error) {
d, e := os.ReadFile(src) d, err := os.ReadFile(src)
if e != nil { if err != nil {
return e // Don't wrap the error, since it's informative enough as is.
return err
} }
e = os.WriteFile(dst, d, aghos.DefaultPermFile)
if e != nil { err = aghos.WriteFile(dst, d, aghos.DefaultPermFile)
return e if err != nil {
// Don't wrap the error, since it's informative enough as is.
return err
} }
return nil return nil
} }