diff --git a/atomicfile/atomicfile.go b/atomicfile/atomicfile.go index b63955a92..5c18e85a8 100644 --- a/atomicfile/atomicfile.go +++ b/atomicfile/atomicfile.go @@ -8,14 +8,20 @@ package atomicfile // import "tailscale.com/atomicfile" import ( + "fmt" "os" "path/filepath" "runtime" ) -// WriteFile writes data to filename+some suffix, then renames it -// into filename. The perm argument is ignored on Windows. +// WriteFile writes data to filename+some suffix, then renames it into filename. +// The perm argument is ignored on Windows. If the target filename already +// exists but is not a regular file, WriteFile returns an error. func WriteFile(filename string, data []byte, perm os.FileMode) (err error) { + fi, err := os.Stat(filename) + if err == nil && !fi.Mode().IsRegular() { + return fmt.Errorf("%s already exists and is not a regular file", filename) + } f, err := os.CreateTemp(filepath.Dir(filename), filepath.Base(filename)+".tmp") if err != nil { return err diff --git a/atomicfile/atomicfile_test.go b/atomicfile/atomicfile_test.go new file mode 100644 index 000000000..b52e79c2b --- /dev/null +++ b/atomicfile/atomicfile_test.go @@ -0,0 +1,38 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !js && !windows + +package atomicfile + +import ( + "net" + "path/filepath" + "strings" + "testing" +) + +func TestDoesNotOverwriteIrregularFiles(t *testing.T) { + // Per tailscale/tailscale#7658 as one example, almost any imagined use of + // atomicfile.Write should likely not attempt to overwrite an irregular file + // such as a device node, socket, or named pipe. + + d := t.TempDir() + special := filepath.Join(d, "special") + + // The least troublesome thing to make that is not a file is a unix socket. + // Making a null device sadly requries root. + l, err := net.ListenUnix("unix", &net.UnixAddr{Name: special, Net: "unix"}) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + err = WriteFile(special, []byte("hello"), 0644) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "is not a regular file") { + t.Fatalf("unexpected error: %v", err) + } +}