diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index e948ba953..34addb451 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -3550,12 +3550,13 @@ func (b *LocalBackend) initPeerAPIListener() { ps := &peerAPIServer{ b: b, - taildrop: &taildrop.Handler{ + taildrop: &taildrop.Manager{ Logf: b.logf, Clock: b.clock, Dir: fileRoot, DirectFileMode: b.directFileRoot != "", AvoidFinalRename: !b.directFileDoFinalRename, + SendFileNotify: b.sendFileNotify, }, } if dm, ok := b.sys.DNSManager.GetOK(); ok { diff --git a/ipn/ipnlocal/peerapi.go b/ipn/ipnlocal/peerapi.go index f9c194fc8..65e2a8622 100644 --- a/ipn/ipnlocal/peerapi.go +++ b/ipn/ipnlocal/peerapi.go @@ -15,6 +15,7 @@ import ( "net" "net/http" "net/netip" + "net/url" "os" "runtime" "slices" @@ -53,7 +54,7 @@ type peerAPIServer struct { b *LocalBackend resolver *resolver.Resolver - taildrop *taildrop.Handler + taildrop *taildrop.Manager } var ( @@ -634,11 +635,45 @@ func (h *peerAPIHandler) handlePeerPut(w http.ResponseWriter, r *http.Request) { http.Error(w, "file sharing not enabled by Tailscale admin", http.StatusForbidden) return } + if r.Method != "PUT" { + http.Error(w, "expected method PUT", http.StatusMethodNotAllowed) + return + } + rawPath := r.URL.EscapedPath() + suffix, ok := strings.CutPrefix(rawPath, "/v0/put/") + if !ok { + http.Error(w, "misconfigured internals", http.StatusInternalServerError) + return + } + if suffix == "" { + http.Error(w, "empty filename", http.StatusBadRequest) + return + } + if strings.Contains(suffix, "/") { + http.Error(w, "directories not supported", http.StatusBadRequest) + return + } + baseName, err := url.PathUnescape(suffix) + if err != nil { + http.Error(w, "bad path encoding", http.StatusBadRequest) + return + } t0 := h.ps.b.clock.Now() - n, ok := h.ps.taildrop.HandlePut(w, r) - if ok { + // TODO(rhea,joetsai): Set the client ID and starting offset. + n, err := h.ps.taildrop.PutFile("", baseName, r.Body, 0, r.ContentLength) + switch err { + case nil: d := h.ps.b.clock.Since(t0).Round(time.Second / 10) h.logf("got put of %s in %v from %v/%v", approxSize(n), d, h.remoteAddr.Addr(), h.peerNode.ComputedName) + io.WriteString(w, "{}\n") + case taildrop.ErrNoTaildrop: + http.Error(w, err.Error(), http.StatusForbidden) + case taildrop.ErrInvalidFileName: + http.Error(w, err.Error(), http.StatusBadRequest) + case taildrop.ErrFileExists: + http.Error(w, err.Error(), http.StatusConflict) + default: + http.Error(w, err.Error(), http.StatusInternalServerError) } } diff --git a/ipn/ipnlocal/peerapi_test.go b/ipn/ipnlocal/peerapi_test.go index 7996dc62e..48a6a1d3b 100644 --- a/ipn/ipnlocal/peerapi_test.go +++ b/ipn/ipnlocal/peerapi_test.go @@ -17,7 +17,9 @@ import ( "strings" "testing" + "github.com/google/go-cmp/cmp" "go4.org/netipx" + "tailscale.com/client/tailscale/apitype" "tailscale.com/ipn" "tailscale.com/ipn/store/mem" "tailscale.com/tailcfg" @@ -191,7 +193,7 @@ func TestHandlePeerAPI(t *testing.T) { capSharing: true, reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", nil)}, checks: checks( - httpStatus(http.StatusInternalServerError), + httpStatus(http.StatusForbidden), bodyContains("Taildrop disabled; no storage directory"), ), }, @@ -248,7 +250,7 @@ func TestHandlePeerAPI(t *testing.T) { reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo.partial", nil)}, checks: checks( httpStatus(400), - bodyContains("bad filename"), + bodyContains("invalid filename"), ), }, { @@ -258,7 +260,7 @@ func TestHandlePeerAPI(t *testing.T) { reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo.deleted", nil)}, checks: checks( httpStatus(400), - bodyContains("bad filename"), + bodyContains("invalid filename"), ), }, { @@ -268,7 +270,7 @@ func TestHandlePeerAPI(t *testing.T) { reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/.", nil)}, checks: checks( httpStatus(400), - bodyContains("bad filename"), + bodyContains("invalid filename"), ), }, { @@ -298,7 +300,7 @@ func TestHandlePeerAPI(t *testing.T) { reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("."), nil)}, checks: checks( httpStatus(400), - bodyContains("bad filename"), + bodyContains("invalid filename"), ), }, { @@ -308,7 +310,7 @@ func TestHandlePeerAPI(t *testing.T) { reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("/"), nil)}, checks: checks( httpStatus(400), - bodyContains("bad filename"), + bodyContains("invalid filename"), ), }, { @@ -318,7 +320,7 @@ func TestHandlePeerAPI(t *testing.T) { reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("\\"), nil)}, checks: checks( httpStatus(400), - bodyContains("bad filename"), + bodyContains("invalid filename"), ), }, { @@ -328,7 +330,7 @@ func TestHandlePeerAPI(t *testing.T) { reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll(".."), nil)}, checks: checks( httpStatus(400), - bodyContains("bad filename"), + bodyContains("invalid filename"), ), }, { @@ -338,7 +340,7 @@ func TestHandlePeerAPI(t *testing.T) { reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("foo/../../../../../etc/passwd"), nil)}, checks: checks( httpStatus(400), - bodyContains("bad filename"), + bodyContains("invalid filename"), ), }, { @@ -370,7 +372,7 @@ func TestHandlePeerAPI(t *testing.T) { reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+(hexAll("😜")[:3]), nil)}, checks: checks( httpStatus(400), - bodyContains("bad filename"), + bodyContains("invalid filename"), ), }, { @@ -380,7 +382,7 @@ func TestHandlePeerAPI(t *testing.T) { reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/%00", nil)}, checks: checks( httpStatus(400), - bodyContains("bad filename"), + bodyContains("invalid filename"), ), }, { @@ -390,7 +392,7 @@ func TestHandlePeerAPI(t *testing.T) { reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/%01", nil)}, checks: checks( httpStatus(400), - bodyContains("bad filename"), + bodyContains("invalid filename"), ), }, { @@ -400,7 +402,7 @@ func TestHandlePeerAPI(t *testing.T) { reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("nul:"), nil)}, checks: checks( httpStatus(400), - bodyContains("bad filename"), + bodyContains("invalid filename"), ), }, { @@ -410,7 +412,7 @@ func TestHandlePeerAPI(t *testing.T) { reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll(" foo "), nil)}, checks: checks( httpStatus(400), - bodyContains("bad filename"), + bodyContains("invalid filename"), ), }, { @@ -441,23 +443,69 @@ func TestHandlePeerAPI(t *testing.T) { ), }, { - name: "bad_duplicate_zero_length", + name: "duplicate_zero_length", isSelf: true, capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", nil), httptest.NewRequest("PUT", "/v0/put/foo", nil)}, + reqs: []*http.Request{ + httptest.NewRequest("PUT", "/v0/put/foo", nil), + httptest.NewRequest("PUT", "/v0/put/foo", nil), + }, checks: checks( - httpStatus(409), - bodyContains("file exists"), + httpStatus(200), + func(t *testing.T, env *peerAPITestEnv) { + got, err := env.ph.ps.taildrop.WaitingFiles() + if err != nil { + t.Fatalf("WaitingFiles error: %v", err) + } + want := []apitype.WaitingFile{{Name: "foo", Size: 0}} + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("WaitingFile mismatch (-got +want):\n%s", diff) + } + }, ), }, { - name: "bad_duplicate_non_zero_length_content_length", + name: "duplicate_non_zero_length_content_length", isSelf: true, capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("contents")), httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("contents"))}, + reqs: []*http.Request{ + httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("contents")), + httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("contents")), + }, checks: checks( - httpStatus(409), - bodyContains("file exists"), + httpStatus(200), + func(t *testing.T, env *peerAPITestEnv) { + got, err := env.ph.ps.taildrop.WaitingFiles() + if err != nil { + t.Fatalf("WaitingFiles error: %v", err) + } + want := []apitype.WaitingFile{{Name: "foo", Size: 8}} + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("WaitingFile mismatch (-got +want):\n%s", diff) + } + }, + ), + }, + { + name: "duplicate_different_files", + isSelf: true, + capSharing: true, + reqs: []*http.Request{ + httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("fizz")), + httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("buzz")), + }, + checks: checks( + httpStatus(200), + func(t *testing.T, env *peerAPITestEnv) { + got, err := env.ph.ps.taildrop.WaitingFiles() + if err != nil { + t.Fatalf("WaitingFiles error: %v", err) + } + want := []apitype.WaitingFile{{Name: "foo", Size: 4}, {Name: "foo (1)", Size: 4}} + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("WaitingFile mismatch (-got +want):\n%s", diff) + } + }, ), }, } @@ -492,7 +540,7 @@ func TestHandlePeerAPI(t *testing.T) { if !tt.omitRoot { rootDir = t.TempDir() if e.ph.ps.taildrop == nil { - e.ph.ps.taildrop = &taildrop.Handler{ + e.ph.ps.taildrop = &taildrop.Manager{ Logf: e.logBuf.Logf, Clock: &tstest.Clock{}, } @@ -536,7 +584,7 @@ func TestFileDeleteRace(t *testing.T) { capFileSharing: true, clock: &tstest.Clock{}, }, - taildrop: &taildrop.Handler{ + taildrop: &taildrop.Manager{ Logf: t.Logf, Clock: &tstest.Clock{}, Dir: dir, diff --git a/taildrop/retrieve.go b/taildrop/retrieve.go index 7a773c950..01ab59468 100644 --- a/taildrop/retrieve.go +++ b/taildrop/retrieve.go @@ -21,11 +21,11 @@ import ( // HasFilesWaiting reports whether any files are buffered in [Handler.Dir]. // This always returns false when [Handler.DirectFileMode] is false. -func (s *Handler) HasFilesWaiting() bool { - if s == nil || s.Dir == "" || s.DirectFileMode { +func (m *Manager) HasFilesWaiting() bool { + if m == nil || m.Dir == "" || m.DirectFileMode { return false } - if s.knownEmpty.Load() { + if m.knownEmpty.Load() { // Optimization: this is usually empty, so avoid opening // the directory and checking. We can't cache the actual // has-files-or-not values as the macOS/iOS client might @@ -33,7 +33,7 @@ func (s *Handler) HasFilesWaiting() bool { // keep this negative cache. return false } - f, err := os.Open(s.Dir) + f, err := os.Open(m.Dir) if err != nil { return false } @@ -51,22 +51,22 @@ func (s *Handler) HasFilesWaiting() bool { // as the OS may return "foo.jpg.deleted" before "foo.jpg" // and we don't want to delete the ".deleted" file before // enumerating to the "foo.jpg" file. - defer tryDeleteAgain(filepath.Join(s.Dir, name)) + defer tryDeleteAgain(filepath.Join(m.Dir, name)) continue } if de.Type().IsRegular() { - _, err := os.Stat(filepath.Join(s.Dir, name+deletedSuffix)) + _, err := os.Stat(filepath.Join(m.Dir, name+deletedSuffix)) if os.IsNotExist(err) { return true } if err == nil { - tryDeleteAgain(filepath.Join(s.Dir, name)) + tryDeleteAgain(filepath.Join(m.Dir, name)) continue } } } if err == io.EOF { - s.knownEmpty.Store(true) + m.knownEmpty.Store(true) } if err != nil { break @@ -78,17 +78,14 @@ func (s *Handler) HasFilesWaiting() bool { // WaitingFiles returns the list of files that have been sent by a // peer that are waiting in [Handler.Dir]. // This always returns nil when [Handler.DirectFileMode] is false. -func (s *Handler) WaitingFiles() (ret []apitype.WaitingFile, err error) { - if s == nil { - return nil, errNilHandler +func (m *Manager) WaitingFiles() (ret []apitype.WaitingFile, err error) { + if m == nil || m.Dir == "" { + return nil, ErrNoTaildrop } - if s.Dir == "" { - return nil, errNoTaildrop - } - if s.DirectFileMode { + if m.DirectFileMode { return nil, nil } - f, err := os.Open(s.Dir) + f, err := os.Open(m.Dir) if err != nil { return nil, err } @@ -140,7 +137,7 @@ func (s *Handler) WaitingFiles() (ret []apitype.WaitingFile, err error) { // Maybe Windows is done virus scanning the file we tried // to delete a long time ago and will let us delete it now. for name := range deleted { - tryDeleteAgain(filepath.Join(s.Dir, name)) + tryDeleteAgain(filepath.Join(m.Dir, name)) } } sort.Slice(ret, func(i, j int) bool { return ret[i].Name < ret[j].Name }) @@ -163,23 +160,20 @@ func tryDeleteAgain(fullPath string) { // DeleteFile deletes a file of the given baseName from [Handler.Dir]. // This method is only allowed when [Handler.DirectFileMode] is false. -func (s *Handler) DeleteFile(baseName string) error { - if s == nil { - return errNilHandler +func (m *Manager) DeleteFile(baseName string) error { + if m == nil || m.Dir == "" { + return ErrNoTaildrop } - if s.Dir == "" { - return errNoTaildrop - } - if s.DirectFileMode { + if m.DirectFileMode { return errors.New("deletes not allowed in direct mode") } - path, ok := s.diskPath(baseName) + path, ok := m.joinDir(baseName) if !ok { return errors.New("bad filename") } var bo *backoff.Backoff - logf := s.Logf - t0 := s.Clock.Now() + logf := m.Logf + t0 := m.Clock.Now() for { err := os.Remove(path) if err != nil && !os.IsNotExist(err) { @@ -198,7 +192,7 @@ func (s *Handler) DeleteFile(baseName string) error { if bo == nil { bo = backoff.NewBackoff("delete-retry", logf, 1*time.Second) } - if s.Clock.Since(t0) < 5*time.Second { + if m.Clock.Since(t0) < 5*time.Second { bo.BackOff(context.Background(), err) continue } @@ -223,17 +217,14 @@ func touchFile(path string) error { // OpenFile opens a file of the given baseName from [Handler.Dir]. // This method is only allowed when [Handler.DirectFileMode] is false. -func (s *Handler) OpenFile(baseName string) (rc io.ReadCloser, size int64, err error) { - if s == nil { - return nil, 0, errNilHandler +func (m *Manager) OpenFile(baseName string) (rc io.ReadCloser, size int64, err error) { + if m == nil || m.Dir == "" { + return nil, 0, ErrNoTaildrop } - if s.Dir == "" { - return nil, 0, errNoTaildrop - } - if s.DirectFileMode { + if m.DirectFileMode { return nil, 0, errors.New("opens not allowed in direct mode") } - path, ok := s.diskPath(baseName) + path, ok := m.joinDir(baseName) if !ok { return nil, 0, errors.New("bad filename") } diff --git a/taildrop/send.go b/taildrop/send.go index bee750d88..8bb2e8715 100644 --- a/taildrop/send.go +++ b/taildrop/send.go @@ -4,11 +4,10 @@ package taildrop import ( + "crypto/sha256" + "errors" "io" - "net/http" - "net/url" "os" - "strings" "sync" "time" @@ -17,10 +16,14 @@ import ( "tailscale.com/version/distro" ) +type incomingFileKey struct { + id ClientID + name string // e.g., "foo.jpeg" +} + type incomingFile struct { clock tstime.Clock - name string // "foo.jpg" started time.Time size int64 // or -1 if unknown; never 0 w io.Writer // underlying writer @@ -33,13 +36,6 @@ type incomingFile struct { lastNotify time.Time } -func (f *incomingFile) markAndNotifyDone() { - f.mu.Lock() - f.done = true - f.mu.Unlock() - f.sendFileNotify() -} - func (f *incomingFile) Write(p []byte) (n int, err error) { n, err = f.w.Write(p) @@ -62,123 +58,197 @@ func (f *incomingFile) Write(p []byte) (n int, err error) { return n, err } -// HandlePut receives a file. -// It handles an HTTP PUT request to the "/v0/put/{filename}" endpoint, -// where {filename} is a base filename. -// It returns the number of bytes received and whether it was received successfully. -func (h *Handler) HandlePut(w http.ResponseWriter, r *http.Request) (finalSize int64, success bool) { - if !envknob.CanTaildrop() { - http.Error(w, "Taildrop disabled on device", http.StatusForbidden) - return finalSize, success +// PutFile stores a file into [Manager.Dir] from a given client id. +// The baseName must be a base filename without any slashes. +// The length is the expected length of content to read from r, +// it may be negative to indicate that it is unknown. +// +// If there is a failure reading from r, then the partial file is not deleted +// for some period of time. The [Manager.PartialFiles] and [Manager.HashPartialFile] +// methods may be used to list all partial files and to compute the hash for a +// specific partial file. This allows the client to determine whether to resume +// a partial file. While resuming, PutFile may be called again with a non-zero +// offset to specify where to resume receiving data at. +func (m *Manager) PutFile(id ClientID, baseName string, r io.Reader, offset, length int64) (int64, error) { + switch { + case m == nil || m.Dir == "": + return 0, ErrNoTaildrop + case !envknob.CanTaildrop(): + return 0, ErrNoTaildrop + case distro.Get() == distro.Unraid && !m.DirectFileMode: + return 0, ErrNotAccessible } - if r.Method != "PUT" { - http.Error(w, "expected method PUT", http.StatusMethodNotAllowed) - return finalSize, success - } - if h == nil || h.Dir == "" { - http.Error(w, errNoTaildrop.Error(), http.StatusInternalServerError) - return finalSize, success - } - if distro.Get() == distro.Unraid && !h.DirectFileMode { - http.Error(w, "Taildrop folder not configured or accessible", http.StatusInternalServerError) - return finalSize, success - } - rawPath := r.URL.EscapedPath() - suffix, ok := strings.CutPrefix(rawPath, "/v0/put/") + dstPath, ok := m.joinDir(baseName) if !ok { - http.Error(w, "misconfigured internals", http.StatusInternalServerError) - return finalSize, success - } - if suffix == "" { - http.Error(w, "empty filename", http.StatusBadRequest) - return finalSize, success - } - if strings.Contains(suffix, "/") { - http.Error(w, "directories not supported", http.StatusBadRequest) - return finalSize, success - } - baseName, err := url.PathUnescape(suffix) - if err != nil { - http.Error(w, "bad path encoding", http.StatusBadRequest) - return finalSize, success - } - dstFile, ok := h.diskPath(baseName) - if !ok { - http.Error(w, "bad filename", http.StatusBadRequest) - return finalSize, success - } - // TODO(bradfitz): prevent same filename being sent by two peers at once - - // prevent same filename being sent twice - if _, err := os.Stat(dstFile); err == nil { - http.Error(w, "file exists", http.StatusConflict) - return finalSize, success + return 0, ErrInvalidFileName } - partialFile := dstFile + partialSuffix - f, err := os.Create(partialFile) - if err != nil { - h.Logf("put Create error: %v", redactErr(err)) - http.Error(w, err.Error(), http.StatusInternalServerError) - return finalSize, success + redactAndLogError := func(action string, err error) error { + err = redactErr(err) + m.Logf("put %v error: %v", action, err) + return err } - defer func() { - if !success { - os.Remove(partialFile) - } - }() - var inFile *incomingFile - sendFileNotify := h.SendFileNotify + + avoidPartialRename := m.DirectFileMode && m.AvoidFinalRename + if avoidPartialRename { + // Users using AvoidFinalRename are depending on the exact filename + // of the partial files. So avoid injecting the id into it. + id = "" + } + + // Check whether there is an in-progress transfer for the file. + sendFileNotify := m.SendFileNotify if sendFileNotify == nil { sendFileNotify = func() {} // avoid nil panics below } - if r.ContentLength != 0 { - inFile = &incomingFile{ - clock: h.Clock, - name: baseName, - started: h.Clock.Now(), - size: r.ContentLength, - w: f, + partialPath := dstPath + id.partialSuffix() + inFileKey := incomingFileKey{id, baseName} + inFile, loaded := m.incomingFiles.LoadOrInit(inFileKey, func() *incomingFile { + inFile := &incomingFile{ + clock: m.Clock, + started: m.Clock.Now(), + size: length, sendFileNotify: sendFileNotify, } - if h.DirectFileMode { - inFile.partialPath = partialFile + if m.DirectFileMode { + inFile.partialPath = partialPath } - h.incomingFiles.Store(inFile, struct{}{}) - defer h.incomingFiles.Delete(inFile) - n, err := io.Copy(inFile, r.Body) + return inFile + }) + if loaded { + return 0, ErrFileExists + } + defer m.incomingFiles.Delete(inFileKey) + + // Create (if not already) the partial file with read-write permissions. + f, err := os.OpenFile(partialPath, os.O_CREATE|os.O_RDWR, 0666) + if err != nil { + return 0, redactAndLogError("Create", err) + } + defer func() { + f.Close() // best-effort to cleanup dangling file handles if err != nil { - err = redactErr(err) - f.Close() - h.Logf("put Copy error: %v", err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return finalSize, success + if avoidPartialRename { + os.Remove(partialPath) // best-effort + return + } + + // TODO: We need to delete partialPath eventually. + // However, this must be done after some period of time. } - finalSize = n - } - if err := redactErr(f.Close()); err != nil { - h.Logf("put Close error: %v", err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return finalSize, success - } - if h.DirectFileMode && h.AvoidFinalRename { - if inFile != nil { // non-zero length; TODO: notify even for zero length - inFile.markAndNotifyDone() + }() + inFile.w = f + + // A positive offset implies that we are resuming an existing file. + // Seek to the appropriate offset and truncate the file. + if offset != 0 { + currLength, err := f.Seek(0, io.SeekEnd) + if err != nil { + return 0, redactAndLogError("Seek", err) } - } else { - if err := os.Rename(partialFile, dstFile); err != nil { - err = redactErr(err) - h.Logf("put final rename: %v", err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return finalSize, success + if offset < 0 || offset > currLength { + return 0, redactAndLogError("Seek", err) + } + if _, err := f.Seek(offset, io.SeekStart); err != nil { + return 0, redactAndLogError("Seek", err) + } + if err := f.Truncate(offset); err != nil { + return 0, redactAndLogError("Truncate", err) } } - // TODO: set modtime - // TODO: some real response - success = true - io.WriteString(w, "{}\n") - h.knownEmpty.Store(false) + // Copy the contents of the file. + copyLength, err := io.Copy(inFile, r) + if err != nil { + return 0, redactAndLogError("Copy", err) + } + if length >= 0 && copyLength != length { + return 0, redactAndLogError("Copy", errors.New("copied an unexpected number of bytes")) + } + if err := f.Close(); err != nil { + return 0, redactAndLogError("Close", err) + } + fileLength := offset + copyLength + + // Return early for avoidPartialRename since users of AvoidFinalRename + // are depending on the exact naming of partial files. + if avoidPartialRename { + inFile.mu.Lock() + inFile.done = true + inFile.mu.Unlock() + m.knownEmpty.Store(false) + sendFileNotify() + return fileLength, nil + } + + // File has been successfully received, rename the partial file + // to the final destination filename. If a file of that name already exists, + // then try multiple times with variations of the filename. + computePartialSum := sync.OnceValues(func() ([sha256.Size]byte, error) { + return sha256File(partialPath) + }) + maxRetries := 10 + for ; maxRetries > 0; maxRetries-- { + // Atomically rename the partial file as the destination file if it doesn't exist. + // Otherwise, it returns the length of the current destination file. + // The operation is atomic. + dstLength, err := func() (int64, error) { + m.renameMu.Lock() + defer m.renameMu.Unlock() + switch fi, err := os.Stat(dstPath); { + case os.IsNotExist(err): + return -1, os.Rename(partialPath, dstPath) + case err != nil: + return -1, err + default: + return fi.Size(), nil + } + }() + if err != nil { + return 0, redactAndLogError("Rename", err) + } + if dstLength < 0 { + break // we successfully renamed; so stop + } + + // Avoid the final rename if a destination file has the same contents. + if dstLength == fileLength { + partialSum, err := computePartialSum() + if err != nil { + return 0, redactAndLogError("Rename", err) + } + dstSum, err := sha256File(dstPath) + if err != nil { + return 0, redactAndLogError("Rename", err) + } + if dstSum == partialSum { + if err := os.Remove(partialPath); err != nil { + return 0, redactAndLogError("Remove", err) + } + break // we successfully found a content match; so stop + } + } + + // Choose a new destination filename and try again. + dstPath = NextFilename(dstPath) + } + if maxRetries <= 0 { + return 0, errors.New("too many retries trying to rename partial file") + } + m.knownEmpty.Store(false) sendFileNotify() - return finalSize, success + return fileLength, nil +} + +func sha256File(file string) (out [sha256.Size]byte, err error) { + h := sha256.New() + f, err := os.Open(file) + if err != nil { + return out, err + } + defer f.Close() + if _, err := io.Copy(h, f); err != nil { + return out, err + } + return [sha256.Size]byte(h.Sum(nil)), nil } diff --git a/taildrop/taildrop.go b/taildrop/taildrop.go index b482655cd..bc2b3f6ff 100644 --- a/taildrop/taildrop.go +++ b/taildrop/taildrop.go @@ -15,8 +15,10 @@ import ( "os" "path" "path/filepath" + "regexp" "strconv" "strings" + "sync" "sync/atomic" "unicode" "unicode/utf8" @@ -28,8 +30,20 @@ import ( "tailscale.com/util/multierr" ) -// Handler manages the state for receiving and managing taildropped files. -type Handler struct { +// ClientID is an opaque identifier for file resumption. +// A client can only list and resume partial files for its own ID. +// It must contain any filesystem specific characters (e.g., slashes). +type ClientID string // e.g., "n12345CNTRL" + +func (id ClientID) partialSuffix() string { + if id == "" { + return partialSuffix + } + return "." + string(id) + partialSuffix // e.g., ".n12345CNTRL.partial" +} + +// Manager manages the state for receiving and managing taildropped files. +type Manager struct { Logf logger.Logf Clock tstime.Clock @@ -54,6 +68,11 @@ type Handler struct { // AvoidFinalRename specifies whether in DirectFileMode // we should avoid renaming "foo.jpg.partial" to "foo.jpg" after reception. + // + // TODO(joetsai,rhea): Delete this. This is currently depended upon + // in the Apple platforms since it violates the abstraction layer + // and directly assumes how taildrop represents partial files. + // Right now, file resumption does not work on Apple. AvoidFinalRename bool // SendFileNotify is called periodically while a file is actively @@ -64,12 +83,17 @@ type Handler struct { knownEmpty atomic.Bool - incomingFiles syncs.Map[*incomingFile, struct{}] + incomingFiles syncs.Map[incomingFileKey, *incomingFile] + + // renameMu is used to protect os.Rename calls so that they are atomic. + renameMu sync.Mutex } var ( - errNilHandler = errors.New("handler unavailable; not listening") - errNoTaildrop = errors.New("Taildrop disabled; no storage directory") + ErrNoTaildrop = errors.New("Taildrop disabled; no storage directory") + ErrInvalidFileName = errors.New("invalid filename") + ErrFileExists = errors.New("file already exists") + ErrNotAccessible = errors.New("Taildrop folder not configured or accessible") ) const ( @@ -107,7 +131,7 @@ func validFilenameRune(r rune) bool { return unicode.IsPrint(r) } -func (s *Handler) diskPath(baseName string) (fullPath string, ok bool) { +func (m *Manager) joinDir(baseName string) (fullPath string, ok bool) { if !utf8.ValidString(baseName) { return "", false } @@ -133,19 +157,20 @@ func (s *Handler) diskPath(baseName string) (fullPath string, ok bool) { if !filepath.IsLocal(baseName) { return "", false } - return filepath.Join(s.Dir, baseName), true + return filepath.Join(m.Dir, baseName), true } -func (s *Handler) IncomingFiles() []ipn.PartialFile { +// IncomingFiles returns a list of active incoming files. +func (m *Manager) IncomingFiles() []ipn.PartialFile { // Make sure we always set n.IncomingFiles non-nil so it gets encoded // in JSON to clients. They distinguish between empty and non-nil // to know whether a Notify should be able about files. files := make([]ipn.PartialFile, 0) - s.incomingFiles.Range(func(f *incomingFile, _ struct{}) bool { + m.incomingFiles.Range(func(k incomingFileKey, f *incomingFile) bool { f.mu.Lock() defer f.mu.Unlock() files = append(files, ipn.PartialFile{ - Name: f.name, + Name: k.name, Started: f.started, DeclaredSize: f.size, Received: f.copied, @@ -220,3 +245,26 @@ func redactErr(root error) error { } return &redactedErr{msg: s, inner: root} } + +var ( + rxExtensionSuffix = regexp.MustCompile(`(\.[a-zA-Z0-9]{0,3}[a-zA-Z][a-zA-Z0-9]{0,3})*$`) + rxNumberSuffix = regexp.MustCompile(` \([0-9]+\)`) +) + +// NextFilename returns the next filename in a sequence. +// It is used for construction a new filename if there is a conflict. +// +// For example, "Foo.jpg" becomes "Foo (1).jpg" and +// "Foo (1).jpg" becomes "Foo (2).jpg". +func NextFilename(name string) string { + ext := rxExtensionSuffix.FindString(strings.TrimPrefix(name, ".")) + name = strings.TrimSuffix(name, ext) + var n uint64 + if rxNumberSuffix.MatchString(name) { + i := strings.LastIndex(name, " (") + if n, _ = strconv.ParseUint(name[i+len("( "):len(name)-len(")")], 10, 64); n > 0 { + name = name[:i] + } + } + return name + " (" + strconv.FormatUint(n+1, 10) + ")" + ext +} diff --git a/taildrop/taildrop_test.go b/taildrop/taildrop_test.go index 29c88e8c8..969ce3fe5 100644 --- a/taildrop/taildrop_test.go +++ b/taildrop/taildrop_test.go @@ -16,7 +16,7 @@ import ( // Tests "foo.jpg.deleted" marks (for Windows). func TestDeletedMarkers(t *testing.T) { dir := t.TempDir() - h := &Handler{Dir: dir} + h := &Manager{Dir: dir} nothingWaiting := func() { t.Helper() @@ -153,3 +153,32 @@ func TestRedactErr(t *testing.T) { }) } } + +func TestNextFilename(t *testing.T) { + tests := []struct { + in string + want string + want2 string + }{ + {"foo", "foo (1)", "foo (2)"}, + {"foo(1)", "foo(1) (1)", "foo(1) (2)"}, + {"foo.tar", "foo (1).tar", "foo (2).tar"}, + {"foo.tar.gz", "foo (1).tar.gz", "foo (2).tar.gz"}, + {".bashrc", ".bashrc (1)", ".bashrc (2)"}, + {"fizz buzz.torrent", "fizz buzz (1).torrent", "fizz buzz (2).torrent"}, + {"rawr 2023.12.15.txt", "rawr 2023.12.15 (1).txt", "rawr 2023.12.15 (2).txt"}, + {"IMG_7934.JPEG", "IMG_7934 (1).JPEG", "IMG_7934 (2).JPEG"}, + {"my song.mp3", "my song (1).mp3", "my song (2).mp3"}, + {"archive.7z", "archive (1).7z", "archive (2).7z"}, + {"foo/bar/fizz", "foo/bar/fizz (1)", "foo/bar/fizz (2)"}, + } + + for _, tt := range tests { + if got := NextFilename(tt.in); got != tt.want { + t.Errorf("NextFilename(%q) = %q, want %q", tt.in, got, tt.want) + } + if got2 := NextFilename(tt.want); got2 != tt.want2 { + t.Errorf("NextFilename(%q) = %q, want %q", tt.want, got2, tt.want2) + } + } +}