drive: rewrite Location headers

This ensures that MOVE, LOCK and any other verbs that use the Location
header work correctly.

Fixes #11758

Signed-off-by: Percy Wegmann <percy@tailscale.com>
This commit is contained in:
Percy Wegmann 2024-04-18 12:11:20 -05:00 committed by Percy Wegmann
parent c24f2eee34
commit 787f8c08ec
4 changed files with 52 additions and 8 deletions

View File

@ -100,7 +100,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
pathComponents := shared.CleanAndSplit(r.URL.Path) pathComponents := shared.CleanAndSplit(r.URL.Path)
if len(pathComponents) >= mpl { if len(pathComponents) >= mpl {
h.delegate(pathComponents[mpl-1:], w, r) h.delegate(mpl, pathComponents[mpl-1:], w, r)
return return
} }
h.handle(w, r) h.handle(w, r)
@ -129,24 +129,41 @@ func (h *Handler) handle(w http.ResponseWriter, r *http.Request) {
} }
// delegate sends the request to the Child WebDAV server. // delegate sends the request to the Child WebDAV server.
func (h *Handler) delegate(pathComponents []string, w http.ResponseWriter, r *http.Request) string { func (h *Handler) delegate(mpl int, pathComponents []string, w http.ResponseWriter, r *http.Request) {
dest := r.Header.Get("Destination")
if dest != "" {
// Rewrite destination header
destURL, err := url.Parse(dest)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
destinationComponents := shared.CleanAndSplit(destURL.Path)
if len(destinationComponents) < mpl || destinationComponents[mpl-1] != pathComponents[0] {
http.Error(w, "Destination across shares is not supported", http.StatusBadRequest)
return
}
updatedDest := shared.JoinEscaped(destinationComponents[mpl:]...)
r.Header.Set("Destination", updatedDest)
}
childName := pathComponents[0] childName := pathComponents[0]
child := h.GetChild(childName) child := h.GetChild(childName)
if child == nil { if child == nil {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
return childName return
} }
u, err := url.Parse(child.BaseURL) u, err := url.Parse(child.BaseURL)
if err != nil { if err != nil {
h.logf("warning: parse base URL %s failed: %s", child.BaseURL, err) h.logf("warning: parse base URL %s failed: %s", child.BaseURL, err)
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return childName return
} }
u.Path = path.Join(u.Path, shared.Join(pathComponents[1:]...)) u.Path = path.Join(u.Path, shared.Join(pathComponents[1:]...))
r.URL = u r.URL = u
r.Host = u.Host r.Host = u.Host
child.rp.ServeHTTP(w, r) child.rp.ServeHTTP(w, r)
return childName
} }
// SetChildren replaces the entire existing set of children with the given // SetChildren replaces the entire existing set of children with the given

View File

@ -37,7 +37,7 @@ func (h *Handler) handlePROPFIND(w http.ResponseWriter, r *http.Request) {
bw := &bufferingResponseWriter{ResponseWriter: w} bw := &bufferingResponseWriter{ResponseWriter: w}
mpl := h.maxPathLength(r) mpl := h.maxPathLength(r)
h.delegate(pathComponents[mpl-1:], bw, r) h.delegate(mpl, pathComponents[mpl-1:], bw, r)
// Fixup paths to add the requested path as a prefix. // Fixup paths to add the requested path as a prefix.
pathPrefix := shared.Join(pathComponents[0:mpl]...) pathPrefix := shared.Join(pathComponents[0:mpl]...)

View File

@ -33,6 +33,7 @@ const (
share11 = `sha re$%11` share11 = `sha re$%11`
share12 = `_sha re$%12` share12 = `_sha re$%12`
file111 = `fi le$%111.txt` file111 = `fi le$%111.txt`
file112 = `file112.txt`
) )
func init() { func init() {
@ -81,9 +82,13 @@ func TestFileManipulation(t *testing.T) {
s.checkFileStatus(remote1, share11, file111) s.checkFileStatus(remote1, share11, file111)
s.checkFileContents(remote1, share11, file111) s.checkFileContents(remote1, share11, file111)
s.renameFile("renaming file across shares should fail", remote1, share11, file111, share12, file112, false)
s.renameFile("renaming file in same share should succeed", remote1, share11, file111, share11, file112, true)
s.checkFileContents(remote1, share11, file112)
s.addShare(remote1, share12, drive.PermissionReadOnly) s.addShare(remote1, share12, drive.PermissionReadOnly)
s.writeFile("writing file to read-only remote should fail", remote1, share12, file111, "hello world", false) s.writeFile("writing file to read-only remote should fail", remote1, share12, file111, "hello world", false)
s.writeFile("writing file to non-existent remote should fail", "non-existent", share11, file111, "hello world", false) s.writeFile("writing file to non-existent remote should fail", "non-existent", share11, file111, "hello world", false)
s.writeFile("writing file to non-existent share should fail", remote1, "non-existent", file111, "hello world", false) s.writeFile("writing file to non-existent share should fail", remote1, "non-existent", file111, "hello world", false)
} }
@ -241,7 +246,18 @@ func (s *system) writeFile(label, remoteName, shareName, name, contents string,
if expectSuccess && err != nil { if expectSuccess && err != nil {
s.t.Fatalf("%v: expected success writing file %q, but got error %v", label, path, err) s.t.Fatalf("%v: expected success writing file %q, but got error %v", label, path, err)
} else if !expectSuccess && err == nil { } else if !expectSuccess && err == nil {
s.t.Fatalf("%v: expected error writing file %q", label, path) s.t.Fatalf("%v: expected error writing file %q, but got no error", label, path)
}
}
func (s *system) renameFile(label, remoteName, fromShare, fromFile, toShare, toFile string, expectSuccess bool) {
fromPath := pathTo(remoteName, fromShare, fromFile)
toPath := pathTo(remoteName, toShare, toFile)
err := s.client.Rename(fromPath, toPath, true)
if expectSuccess && err != nil {
s.t.Fatalf("%v: expected success moving file %q to %q, but got error %v", label, fromPath, toPath, err)
} else if !expectSuccess && err == nil {
s.t.Fatalf("%v: expected error moving file %q to %q, but got no error", label, fromPath, toPath)
} }
} }

View File

@ -4,6 +4,7 @@
package shared package shared
import ( import (
"net/url"
"path" "path"
"strings" "strings"
) )
@ -35,6 +36,16 @@ func Join(parts ...string) string {
return path.Join(fullParts...) return path.Join(fullParts...)
} }
// JoinEscaped is like Join but path escapes each part.
func JoinEscaped(parts ...string) string {
fullParts := make([]string, 0, len(parts))
fullParts = append(fullParts, sepString)
for _, part := range parts {
fullParts = append(fullParts, url.PathEscape(part))
}
return path.Join(fullParts...)
}
// IsRoot determines whether a given path p is the root path, defined as either // IsRoot determines whether a given path p is the root path, defined as either
// empty or "/". // empty or "/".
func IsRoot(p string) bool { func IsRoot(p string) bool {