all: fix bugs

This commit is contained in:
Eugene Burkov 2024-10-25 18:30:52 +03:00
parent a2309f812a
commit a22b0d265e
6 changed files with 96 additions and 17 deletions

View File

@ -5,6 +5,8 @@ package aghos
import ( import (
"io/fs" "io/fs"
"os" "os"
"github.com/google/renameio/v2/maybe"
) )
// chmod is a Unix implementation of [Chmod]. // chmod is a Unix implementation of [Chmod].
@ -24,7 +26,7 @@ func mkdirAll(path string, perm fs.FileMode) (err error) {
// writeFile is a Unix implementation of [WriteFile]. // writeFile is a Unix implementation of [WriteFile].
func writeFile(filename string, data []byte, perm fs.FileMode) (err error) { func writeFile(filename string, data []byte, perm fs.FileMode) (err error) {
return os.WriteFile(filename, data, perm) return maybe.WriteFile(filename, data, perm)
} }
// openFile is a Unix implementation of [OpenFile]. // openFile is a Unix implementation of [OpenFile].

View File

@ -86,9 +86,11 @@ func stat(name string) (fi os.FileInfo, err error) {
} }
} }
mode := masksToPerm(ownerMask, groupMask, otherMask) | (fi.Mode().Perm() & ^fs.ModePerm)
return &fileInfo{ return &fileInfo{
FileInfo: fi, FileInfo: fi,
mode: masksToPerm(ownerMask, groupMask, otherMask), mode: mode,
}, nil }, nil
} }
@ -96,8 +98,13 @@ func stat(name string) (fi os.FileInfo, err error) {
func chmod(name string, perm fs.FileMode) (err error) { func chmod(name string, perm fs.FileMode) (err error) {
const objectType windows.SE_OBJECT_TYPE = windows.SE_FILE_OBJECT const objectType windows.SE_OBJECT_TYPE = windows.SE_FILE_OBJECT
fi, err := os.Stat(name)
if err != nil {
return fmt.Errorf("getting file info: %w", err)
}
entries := make([]windows.EXPLICIT_ACCESS, 0, 3) entries := make([]windows.EXPLICIT_ACCESS, 0, 3)
creatorMask, groupMask, worldMask := permToMasks(perm) creatorMask, groupMask, worldMask := permToMasks(perm, fi.IsDir())
sidMasks := container.KeyValues[windows.WELL_KNOWN_SID_TYPE, windows.ACCESS_MASK]{{ sidMasks := container.KeyValues[windows.WELL_KNOWN_SID_TYPE, windows.ACCESS_MASK]{{
Key: windows.WinCreatorOwnerSid, Key: windows.WinCreatorOwnerSid,
@ -120,6 +127,8 @@ func chmod(name string, perm fs.FileMode) (err error) {
trustee, err = newWellKnownTrustee(sidMask.Key) trustee, err = newWellKnownTrustee(sidMask.Key)
if err != nil { if err != nil {
errs = append(errs, err) errs = append(errs, err)
continue
} }
entries = append(entries, windows.EXPLICIT_ACCESS{ entries = append(entries, windows.EXPLICIT_ACCESS{
@ -184,24 +193,35 @@ func mkdirAll(path string, perm os.FileMode) (err error) {
return fmt.Errorf("creating parent directories: %w", err) return fmt.Errorf("creating parent directories: %w", err)
} }
return mkdir(path, perm) err = mkdir(path, perm)
if errors.Is(err, os.ErrExist) {
return nil
}
return err
} }
// writeFile is a Windows implementation of [WriteFile]. // writeFile is a Windows implementation of [WriteFile].
func writeFile(filename string, data []byte, perm os.FileMode) (err error) { func writeFile(filename string, data []byte, perm os.FileMode) (err error) {
err = os.WriteFile(filename, data, perm) file, err := openFile(filename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, perm)
if err != nil { if err != nil {
return fmt.Errorf("writing file: %w", err) return fmt.Errorf("opening file: %w", err)
}
defer func() { err = errors.WithDeferred(err, file.Close()) }()
_, err = file.Write(data)
if err != nil {
return fmt.Errorf("writing data: %w", err)
} }
return chmod(filename, perm) return nil
} }
// openFile is a Windows implementation of [OpenFile]. // openFile is a Windows implementation of [OpenFile].
func openFile(name string, flag int, perm os.FileMode) (file *os.File, err error) { func openFile(name string, flag int, perm os.FileMode) (file *os.File, err error) {
// Only change permissions if the file not yet exists, but should be // Only change permissions if the file not yet exists, but should be
// created. // created.
if flag&os.O_CREATE != 0 { if flag&os.O_CREATE == 0 {
return os.OpenFile(name, flag, perm) return os.OpenFile(name, flag, perm)
} }
@ -236,6 +256,10 @@ const (
groupWrite = 0b000_010_000 groupWrite = 0b000_010_000
worldWrite = 0b000_000_010 worldWrite = 0b000_000_010
ownerRead = 0b100_000_000
groupRead = 0b000_100_000
worldRead = 0b000_000_100
ownerAll = 0b111_000_000 ownerAll = 0b111_000_000
groupAll = 0b000_111_000 groupAll = 0b000_111_000
worldAll = 0b000_000_111 worldAll = 0b000_000_111
@ -251,22 +275,57 @@ const (
) )
// Constants reflecting the number of bits to shift the UNIX write permission // Constants reflecting the number of bits to shift the UNIX write permission
// bits to convert them to the delete access rights used by Windows. // bits to convert them to the access rights used by Windows.
const ( const (
deleteOwner = 9 deleteOwner = 9
deleteGroup = 12 deleteGroup = 12
deleteWorld = 15 deleteWorld = 15
listDirOwner = 7
listDirGroup = 10
listDirWorld = 13
traverseOwner = 2
traverseGroup = 5
traverseWorld = 8
writeEAOwner = 2
writeEAGroup = 1
writeEAWorld = 4
deleteChildOwner = 1
deleteChildGroup = 4
deleteChildWorld = 7
) )
// permToMasks converts a UNIX file mode permissions to the corresponding // permToMasks converts a UNIX file mode permissions to the corresponding
// Windows access masks. // Windows access masks. The [isDir] argument is used to set specific access
func permToMasks(fm os.FileMode) (owner, group, world windows.ACCESS_MASK) { // bits for directories.
func permToMasks(fm os.FileMode, isDir bool) (owner, group, world windows.ACCESS_MASK) {
mask := windows.ACCESS_MASK(fm.Perm()) mask := windows.ACCESS_MASK(fm.Perm())
owner = ((mask & ownerAll) << genericOwner) | ((mask & ownerWrite) << deleteOwner) owner = ((mask & ownerAll) << genericOwner) | ((mask & ownerWrite) << deleteOwner)
group = ((mask & groupAll) << genericGroup) | ((mask & groupWrite) << deleteGroup) group = ((mask & groupAll) << genericGroup) | ((mask & groupWrite) << deleteGroup)
world = ((mask & worldAll) << genericWorld) | ((mask & worldWrite) << deleteWorld) world = ((mask & worldAll) << genericWorld) | ((mask & worldWrite) << deleteWorld)
if isDir {
owner |= (mask & ownerRead) << listDirOwner
group |= (mask & groupRead) << listDirGroup
world |= (mask & worldRead) << listDirWorld
owner |= (mask & ownerRead) << traverseOwner
group |= (mask & groupRead) << traverseGroup
world |= (mask & worldRead) << traverseWorld
owner |= (mask & ownerWrite) << deleteChildOwner
group |= (mask & groupWrite) << deleteChildGroup
world |= (mask & worldWrite) << deleteChildWorld
owner |= (mask & ownerWrite) >> writeEAOwner
group |= (mask & groupWrite) << writeEAGroup
world |= (mask & worldWrite) << writeEAWorld
}
return owner, group, world return owner, group, world
} }

View File

@ -14,6 +14,8 @@ import (
const ( const (
winAccessWrite = windows.GENERIC_WRITE | windows.DELETE winAccessWrite = windows.GENERIC_WRITE | windows.DELETE
winAccessFull = windows.GENERIC_READ | windows.GENERIC_EXECUTE | winAccessWrite winAccessFull = windows.GENERIC_READ | windows.GENERIC_EXECUTE | winAccessWrite
winAccessDirRead = windows.GENERIC_READ | windows.FILE_LIST_DIRECTORY | windows.FILE_TRAVERSE
) )
func TestPermToMasks(t *testing.T) { func TestPermToMasks(t *testing.T) {
@ -25,31 +27,49 @@ func TestPermToMasks(t *testing.T) {
wantUser windows.ACCESS_MASK wantUser windows.ACCESS_MASK
wantGroup windows.ACCESS_MASK wantGroup windows.ACCESS_MASK
wantOther windows.ACCESS_MASK wantOther windows.ACCESS_MASK
isDir bool
}{{ }{{
name: "all", name: "all",
perm: 0b111_111_111, perm: 0b111_111_111,
wantUser: winAccessFull, wantUser: winAccessFull,
wantGroup: winAccessFull, wantGroup: winAccessFull,
wantOther: winAccessFull, wantOther: winAccessFull,
isDir: false,
}, { }, {
name: "user_write", name: "user_write",
perm: 0o010_000_000, perm: 0o010_000_000,
wantUser: winAccessWrite, wantUser: winAccessWrite,
wantGroup: 0, wantGroup: 0,
wantOther: 0, wantOther: 0,
isDir: false,
}, { }, {
name: "group_read", name: "group_read",
perm: 0o000_010_000, perm: 0o000_010_000,
wantUser: 0, wantUser: 0,
wantGroup: windows.GENERIC_READ, wantGroup: windows.GENERIC_READ,
wantOther: 0, wantOther: 0,
isDir: false,
}, {
name: "all_dir",
perm: 0b111_111_111,
wantUser: winAccessFull | winAccessDirRead,
wantGroup: winAccessFull | winAccessDirRead,
wantOther: winAccessFull | winAccessDirRead,
isDir: true,
}, {
name: "user_write_dir",
perm: 0o010_000_000,
wantUser: winAccessWrite,
wantGroup: 0,
wantOther: 0,
isDir: true,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
user, group, other := permToMasks(tc.perm) user, group, other := permToMasks(tc.perm, tc.isDir)
assert.Equal(t, tc.wantUser, user) assert.Equal(t, tc.wantUser, user)
assert.Equal(t, tc.wantGroup, group) assert.Equal(t, tc.wantGroup, group)
assert.Equal(t, tc.wantOther, other) assert.Equal(t, tc.wantOther, other)

View File

@ -708,7 +708,7 @@ func (c *configuration) write() (err error) {
return fmt.Errorf("generating config file: %w", err) return fmt.Errorf("generating config file: %w", err)
} }
err = maybe.WriteFile(confPath, buf.Bytes(), aghos.DefaultPermFile) err = aghos.WriteFile(confPath, buf.Bytes(), aghos.DefaultPermFile)
if err != nil { if err != nil {
return fmt.Errorf("writing config file: %w", err) return fmt.Errorf("writing config file: %w", err)
} }

View File

@ -9,7 +9,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr" "github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/osutil" "github.com/AdguardTeam/golibs/osutil"
"github.com/google/renameio/v2/maybe"
) )
// signalHandler processes incoming signals and shuts services down. // signalHandler processes incoming signals and shuts services down.
@ -142,7 +141,7 @@ func (h *signalHandler) writePID() {
data = strconv.AppendInt(data, int64(os.Getpid()), 10) data = strconv.AppendInt(data, int64(os.Getpid()), 10)
data = append(data, '\n') data = append(data, '\n')
err := maybe.WriteFile(h.pidFile, data, 0o644) err := aghos.WriteFile(h.pidFile, data, 0o644)
if err != nil { if err != nil {
log.Error("sighdlr: writing pidfile: %s", err) log.Error("sighdlr: writing pidfile: %s", err)

View File

@ -21,7 +21,6 @@ import (
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
"github.com/google/renameio/v2/maybe"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
@ -183,7 +182,7 @@ func (m *Manager) write() (err error) {
return fmt.Errorf("encoding: %w", err) return fmt.Errorf("encoding: %w", err)
} }
err = maybe.WriteFile(m.fileName, b, aghos.DefaultPermFile) err = aghos.WriteFile(m.fileName, b, aghos.DefaultPermFile)
if err != nil { if err != nil {
return fmt.Errorf("writing: %w", err) return fmt.Errorf("writing: %w", err)
} }