163 lines
4.6 KiB
Go
163 lines
4.6 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
package taildrop
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"io"
|
|
"io/fs"
|
|
"os"
|
|
"slices"
|
|
"strings"
|
|
)
|
|
|
|
var (
	// blockSize is the maximum number of bytes covered by a single
	// block checksum (64 KiB).
	blockSize int64 = 64 << 10

	// hashAlgorithm is the algorithm name recorded in, and required of,
	// every BlockChecksum.
	hashAlgorithm = "sha256"
)
|
|
|
|
// BlockChecksum represents the checksum for a single block.
|
|
type BlockChecksum struct {
|
|
Checksum Checksum `json:"checksum"`
|
|
Algorithm string `json:"algo"` // always "sha256" for now
|
|
Size int64 `json:"size"` // always (64<<10) for now
|
|
}
|
|
|
|
// Checksum is an opaque checksum that is comparable.
// It holds a raw SHA-256 digest and is represented as hexadecimal text.
type Checksum struct{ cs [sha256.Size]byte }

// hash returns the SHA-256 Checksum of b.
func hash(b []byte) Checksum {
	return Checksum{sha256.Sum256(b)}
}

// String returns the checksum encoded as hexadecimal.
func (cs Checksum) String() string {
	return hex.EncodeToString(cs.cs[:])
}

// AppendText appends the hexadecimal encoding of the checksum to b
// and returns the extended slice. The error is always nil.
func (cs Checksum) AppendText(b []byte) ([]byte, error) {
	return hexAppendEncode(b, cs.cs[:]), nil
}

// MarshalText returns the hexadecimal encoding of the checksum
// in a newly allocated slice. The error is always nil.
func (cs Checksum) MarshalText() ([]byte, error) {
	return hexAppendEncode(nil, cs.cs[:]), nil
}

// UnmarshalText parses b as the hexadecimal encoding of a checksum.
// It reports an error if b has the wrong length or is not valid hex;
// in either case the receiver is left unmodified.
func (cs *Checksum) UnmarshalText(b []byte) error {
	if len(b) != 2*len(cs.cs) {
		return fmt.Errorf("invalid hex length: %d", len(b))
	}
	// Decode into a scratch array first: hex.Decode writes output bytes
	// as it goes, so decoding directly into cs.cs would leave a partially
	// overwritten checksum behind when it hits an invalid character.
	var decoded [sha256.Size]byte
	if _, err := hex.Decode(decoded[:], b); err != nil {
		return err
	}
	cs.cs = decoded
	return nil
}

// hexAppendEncode appends the hexadecimal encoding of src to dst
// and returns the extended slice.
//
// TODO(https://go.dev/issue/53693): Use hex.AppendEncode instead.
func hexAppendEncode(dst, src []byte) []byte {
	start := len(dst)
	end := start + hex.EncodedLen(len(src))
	// Grow guarantees capacity for the encoded bytes, so the reslice is safe.
	dst = slices.Grow(dst, end-start)[:end]
	hex.Encode(dst[start:], src)
	return dst
}
|
|
|
|
// PartialFiles returns a list of partial files in [Handler.Dir]
|
|
// that were sent (or is actively being sent) by the provided id.
|
|
func (m *Manager) PartialFiles(id ClientID) (ret []string, err error) {
|
|
if m == nil || m.opts.Dir == "" {
|
|
return nil, ErrNoTaildrop
|
|
}
|
|
|
|
suffix := id.partialSuffix()
|
|
if err := rangeDir(m.opts.Dir, func(de fs.DirEntry) bool {
|
|
if name := de.Name(); strings.HasSuffix(name, suffix) {
|
|
ret = append(ret, name)
|
|
}
|
|
return true
|
|
}); err != nil {
|
|
return ret, redactError(err)
|
|
}
|
|
return ret, nil
|
|
}
|
|
|
|
// HashPartialFile returns a function that hashes the next block in the file,
|
|
// starting from the beginning of the file.
|
|
// It returns (BlockChecksum{}, io.EOF) when the stream is complete.
|
|
// It is the caller's responsibility to call close.
|
|
func (m *Manager) HashPartialFile(id ClientID, baseName string) (next func() (BlockChecksum, error), close func() error, err error) {
|
|
if m == nil || m.opts.Dir == "" {
|
|
return nil, nil, ErrNoTaildrop
|
|
}
|
|
noopNext := func() (BlockChecksum, error) { return BlockChecksum{}, io.EOF }
|
|
noopClose := func() error { return nil }
|
|
|
|
dstFile, err := joinDir(m.opts.Dir, baseName)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
f, err := os.Open(dstFile + id.partialSuffix())
|
|
if err != nil {
|
|
if os.IsNotExist(err) {
|
|
return noopNext, noopClose, nil
|
|
}
|
|
return nil, nil, redactError(err)
|
|
}
|
|
|
|
b := make([]byte, blockSize) // TODO: Pool this?
|
|
next = func() (BlockChecksum, error) {
|
|
switch n, err := io.ReadFull(f, b); {
|
|
case err != nil && err != io.EOF && err != io.ErrUnexpectedEOF:
|
|
return BlockChecksum{}, redactError(err)
|
|
case n == 0:
|
|
return BlockChecksum{}, io.EOF
|
|
default:
|
|
return BlockChecksum{hash(b[:n]), hashAlgorithm, int64(n)}, nil
|
|
}
|
|
}
|
|
close = f.Close
|
|
return next, close, nil
|
|
}
|
|
|
|
// ResumeReader reads and discards the leading content of r
|
|
// that matches the content based on the checksums that exist.
|
|
// It returns the number of bytes consumed,
|
|
// and returns an [io.Reader] representing the remaining content.
|
|
func ResumeReader(r io.Reader, hashNext func() (BlockChecksum, error)) (int64, io.Reader, error) {
|
|
if hashNext == nil {
|
|
return 0, r, nil
|
|
}
|
|
|
|
var offset int64
|
|
b := make([]byte, 0, blockSize)
|
|
for {
|
|
// Obtain the next block checksum from the remote peer.
|
|
cs, err := hashNext()
|
|
switch {
|
|
case err == io.EOF:
|
|
return offset, io.MultiReader(bytes.NewReader(b), r), nil
|
|
case err != nil:
|
|
return offset, io.MultiReader(bytes.NewReader(b), r), err
|
|
case cs.Algorithm != hashAlgorithm || cs.Size < 0 || cs.Size > blockSize:
|
|
return offset, io.MultiReader(bytes.NewReader(b), r), fmt.Errorf("invalid block size or hashing algorithm")
|
|
}
|
|
|
|
// Read the contents of the next block.
|
|
n, err := io.ReadFull(r, b[:cs.Size])
|
|
b = b[:n]
|
|
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
|
err = nil
|
|
}
|
|
if len(b) == 0 || err != nil {
|
|
// This should not occur in practice.
|
|
// It implies that an error occurred reading r,
|
|
// or that the partial file on the remote side is fully complete.
|
|
return offset, io.MultiReader(bytes.NewReader(b), r), err
|
|
}
|
|
|
|
// Compare the local and remote block checksums.
|
|
// If it mismatches, then resume from this point.
|
|
if cs.Checksum != hash(b) {
|
|
return offset, io.MultiReader(bytes.NewReader(b), r), nil
|
|
}
|
|
offset += int64(len(b))
|
|
b = b[:0]
|
|
}
|
|
}
|