Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions internal/cache/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,13 +262,13 @@ func (s *S3) Open(ctx context.Context, key Key) (io.ReadCloser, http.Header, err
}
}

// Get object
obj, err := s.client.GetObject(ctx, s.config.Bucket, objectName, minio.GetObjectOptions{})
// Download object using parallel range-GET for large objects.
reader, err := s.parallelGetReader(ctx, s.config.Bucket, objectName, objInfo.Size)
if err != nil {
return nil, nil, errors.Errorf("failed to get object: %w", err)
return nil, nil, err
}

return &s3Reader{obj: obj}, headers, nil
return reader, headers, nil
}

// refreshExpiration updates the Expires-At metadata on an S3 object using
Expand Down
116 changes: 116 additions & 0 deletions internal/cache/s3_parallel_get.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package cache

import (
"context"
"io"
"sync"

"github.com/alecthomas/errors"
"github.com/minio/minio-go/v7"
)

const (
	// s3DownloadChunkSize is the size of each parallel range-GET request.
	// 32 MiB matches the gradle-cache-tool's benchmarked default.
	s3DownloadChunkSize = 32 << 20
	// s3DownloadWorkers is the number of concurrent range-GET requests.
	// Benchmarking showed no throughput difference from 4 to 128 workers
	// (extraction IOPS is the bottleneck), so 8 keeps connection count low.
	// Peak buffered memory is s3DownloadWorkers × s3DownloadChunkSize
	// (256 MiB), since each worker holds one full chunk in memory.
	s3DownloadWorkers = 8
)

// parallelGetReader returns an io.ReadCloser that downloads the S3 object
// using parallel range-GET requests and reassembles chunks in order.
// For objects smaller than one chunk, it falls back to a single GetObject.
func (s *S3) parallelGetReader(ctx context.Context, bucket, objectName string, size int64) (io.ReadCloser, error) {
	if size > s3DownloadChunkSize {
		// Large object: parallel range requests, reassembled in order and
		// streamed to the caller through an io.Pipe. Any download error is
		// surfaced to the reader via CloseWithError.
		pr, pw := io.Pipe()
		go func() {
			pw.CloseWithError(s.parallelGet(ctx, bucket, objectName, size, pw))
		}()
		return pr, nil
	}

	// Small object: a single GET stream is cheaper than the pipe machinery.
	obj, err := s.client.GetObject(ctx, bucket, objectName, minio.GetObjectOptions{})
	if err != nil {
		return nil, errors.Errorf("failed to get object: %w", err)
	}
	return &s3Reader{obj: obj}, nil
}

// parallelGet downloads an S3 object in parallel chunks and writes them in
// order to w. Each worker downloads its chunk into memory so the TCP
// connection stays active at full speed. Peak memory: numWorkers × chunkSize.
//
// Returns the first error encountered (range setup, GET, body read, or
// writing to w). On error the remaining per-chunk channels are still drained
// so every worker goroutine can exit before this function returns.
func (s *S3) parallelGet(ctx context.Context, bucket, objectName string, size int64, w io.Writer) error {
	// Ceiling division: the final chunk may be shorter than s3DownloadChunkSize.
	numChunks := int((size + s3DownloadChunkSize - 1) / s3DownloadChunkSize)
	numWorkers := min(s3DownloadWorkers, numChunks)

	type chunkResult struct {
		data []byte // chunk contents; may be partial when err is non-nil
		err  error
	}

	// One buffered channel per chunk so workers never block after reading.
	results := make([]chan chunkResult, numChunks)
	for i := range results {
		results[i] = make(chan chunkResult, 1)
	}

	// Work queue of chunk indices, pre-filled and closed so each worker can
	// simply range over it until the queue is exhausted.
	work := make(chan int, numChunks)
	for i := range numChunks {
		work <- i
	}
	close(work)

	var wg sync.WaitGroup
	for range numWorkers {
		wg.Go(func() {
			for seq := range work {
				// Byte range for this chunk; end is inclusive, matching the
				// HTTP Range semantics used by SetRange.
				start := int64(seq) * s3DownloadChunkSize
				end := min(start+s3DownloadChunkSize-1, size-1)

				opts := minio.GetObjectOptions{}
				if err := opts.SetRange(start, end); err != nil {
					results[seq] <- chunkResult{err: errors.Errorf("set range %d-%d: %w", start, end, err)}
					continue
				}

				obj, err := s.client.GetObject(ctx, bucket, objectName, opts)
				if err != nil {
					results[seq] <- chunkResult{err: errors.Errorf("get range %d-%d: %w", start, end, err)}
					continue
				}

				// Drain the body immediately so the TCP connection stays at
				// full speed. All workers do this concurrently, saturating
				// the available S3 bandwidth.
				data, readErr := io.ReadAll(obj)
				obj.Close() //nolint:errcheck,gosec
				results[seq] <- chunkResult{data: data, err: readErr}
			}
		})
	}

	// Write chunks in order. Each receive blocks until that chunk's worker
	// finishes, while other workers continue downloading concurrently.
	// NOTE(review): after the first failure, workers still download every
	// remaining chunk (only ctx cancellation stops them) — wasted bandwidth
	// on the error path, in exchange for a simple shutdown sequence.
	var writeErr error
	for _, ch := range results {
		r := <-ch
		if writeErr != nil {
			continue // drain remaining channels so goroutines can exit
		}
		if r.err != nil {
			writeErr = r.err
			continue
		}
		if _, err := w.Write(r.data); err != nil {
			writeErr = err
		}
	}

	wg.Wait()
	return writeErr
}
218 changes: 211 additions & 7 deletions internal/snapshot/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package snapshot

import (
"archive/tar"
"bufio"
"bytes"
"context"
Expand All @@ -12,6 +13,8 @@ import (
"os/exec"
"path/filepath"
"runtime"
"strings"
"sync"
"time"

"github.com/alecthomas/errors"
Expand Down Expand Up @@ -187,9 +190,32 @@ func Restore(ctx context.Context, remote cache.Cache, key cache.Key, directory s
return Extract(ctx, rc, directory, threads)
}

const (
	// extractWorkers is the number of goroutines writing files concurrently
	// during parallel tar extraction. Hides per-file open/write/close syscall
	// latency so the tar-stream reader (and download pipeline behind it) is
	// not stalled waiting for individual file writes to complete.
	// Benchmarked on r8id.metal-48xlarge (NVMe, 96 cores) with a 334K-file
	// bundle: 64 workers = 6.27s, 128 = 6.84s (extra GC pressure outweighs
	// any I/O concurrency gain).
	// NOTE(review): a fixed 64 assumes a large dedicated machine; on smaller
	// or shared nodes (k8s CPU limits, co-located processes) this will cause
	// throttling — TODO thread this through a config value.
	extractWorkers = 64

	// maxParallelFileSize is the largest file that will be buffered in memory
	// and dispatched to the worker pool. Files larger than this are written
	// inline in the main goroutine to keep peak memory bounded.
	// At 4 MiB, 99.97% of Gradle cache entries go through the parallel path.
	maxParallelFileSize = 4 << 20 // 4 MiB
)

// Extract decompresses a zstd+tar stream into directory, preserving all file
// permissions, ownership, and symlinks. threads controls zstd parallelism;
// 0 uses all available CPU cores.
//
// The single-threaded bottleneck on restore is writing files to disk. Even
// though tar entries must be read sequentially (the format has no index), the
// actual file writes are independent. The extractor dispatches each entry
// (buffered in memory, ≤4 MiB) to one of 64 worker goroutines that write
// concurrently. This hides the per-file syscall latency (~20µs × N files)
// behind parallelism.
func Extract(ctx context.Context, r io.Reader, directory string, threads int) error {
if threads <= 0 {
threads = runtime.NumCPU()
Expand All @@ -212,15 +238,193 @@ func Extract(ctx context.Context, r io.Reader, directory string, threads int) er
}
defer dec.Close()

tarCmd := exec.CommandContext(ctx, "tar", "-xpf", "-", "-C", directory)
tarCmd.Stdin = dec
return extractTarParallel(ctx, dec, directory)
}

var tarStderr bytes.Buffer
tarCmd.Stderr = &tarStderr
// writeJob is one buffered regular-file write dispatched to the extraction
// worker pool.
type writeJob struct {
	target string      // absolute destination path, validated via safePath before dispatch
	mode   os.FileMode // permission bits from the tar header
	data   []byte      // complete file contents (≤ maxParallelFileSize)
}

// safePath validates that name is a relative path that stays within dir when
// joined. It rejects absolute paths, parent traversals (".."), and — via
// filepath.IsLocal — Windows volume-relative paths (e.g. "C:foo") and
// reserved device names that a hand-rolled prefix test misses. Returns the
// resolved path under dir.
//
// NOTE(review): this is a purely lexical check; it does not defend against a
// previously-extracted symlink redirecting a later entry outside dir.
func safePath(dir, name string) (string, error) {
	clean := filepath.Clean(name)
	// "." (e.g. a "./" entry naming the archive root) resolves to dir itself
	// and is allowed; everything else must be a local path.
	if clean != "." && !filepath.IsLocal(clean) {
		return "", errors.Errorf("path %q escapes destination directory", name)
	}
	joined := filepath.Join(dir, clean)
	// Defense in depth: confirm the joined result is dir or strictly below it.
	if joined != dir && !strings.HasPrefix(joined, dir+string(os.PathSeparator)) {
		return "", errors.Errorf("path %q resolves outside destination directory", name)
	}
	return joined, nil
}

// extractTarParallel reads a tar stream and writes files using a pool of
// goroutines. The main goroutine reads tar entries and buffers small file
// contents; workers write those files to disk concurrently. Large files are
// written inline to keep memory use bounded.
func extractTarParallel(ctx context.Context, r io.Reader, dir string) error {
// Resolve dir to absolute so containment checks are reliable.
var err error
dir, err = filepath.Abs(dir)
if err != nil {
return errors.Wrap(err, "resolve destination directory")
}

jobs := make(chan writeJob, extractWorkers*2)

var (
wg sync.WaitGroup
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change this to an errgroup.WithContext(ctx) so that a) cancellation kills the errgroup and b) any error kills the errgroup.

writeErrOnce sync.Once
writeErr error
)

for range extractWorkers {
wg.Go(func() {
for job := range jobs {
f, err := os.OpenFile(job.target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, job.mode)
if err != nil {
writeErrOnce.Do(func() { writeErr = errors.Errorf("open %s: %w", filepath.Base(job.target), err) })
continue
}
if _, err := f.Write(job.data); err != nil {
f.Close() //nolint:errcheck,gosec
writeErrOnce.Do(func() { writeErr = errors.Errorf("write %s: %w", filepath.Base(job.target), err) })
continue
}
if err := f.Close(); err != nil {
writeErrOnce.Do(func() { writeErr = errors.Errorf("close %s: %w", filepath.Base(job.target), err) })
}
}
})
}

copyBuf := make([]byte, 1<<20) // reused only for inline large-file writes

// createdDirs is accessed only by the main goroutine, so no mutex needed.
createdDirs := make(map[string]struct{})
ensureDir := func(d string, mode os.FileMode) error {
if _, ok := createdDirs[d]; ok {
return nil
}
if err := os.MkdirAll(d, mode); err != nil { //nolint:gosec // path is validated by caller
return errors.Wrap(err, "mkdir")
}
createdDirs[d] = struct{}{}
return nil
}

if err := tarCmd.Run(); err != nil {
return errors.Errorf("tar failed: %w: %s", err, tarStderr.String())
tr := tar.NewReader(r)
var readErr error
loop:
for {
if err := ctx.Err(); err != nil {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull this loop body out into a separate function. This will let the function just return normal errors and simplify the code significantly.

readErr = errors.Wrap(err, "context cancelled")
break
}

hdr, err := tr.Next()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
readErr = errors.Wrap(err, "read tar entry")
break
}

target, err := safePath(dir, hdr.Name)
if err != nil {
readErr = errors.Errorf("unsafe tar entry %q: %w", hdr.Name, err)
break
}

switch hdr.Typeflag {
case tar.TypeDir:
if err := ensureDir(target, hdr.FileInfo().Mode()); err != nil {
readErr = errors.Errorf("mkdir %s: %w", hdr.Name, err)
break loop
}

case tar.TypeLink:
if err := ensureDir(filepath.Dir(target), 0o755); err != nil {
readErr = errors.Errorf("mkdir for hardlink %s: %w", hdr.Name, err)
break loop
}
linkTarget, err := safePath(dir, hdr.Linkname)
if err != nil {
readErr = errors.Errorf("unsafe hardlink target %q: %w", hdr.Linkname, err)
break loop
}
if err := os.Link(linkTarget, target); err != nil {
readErr = errors.Errorf("hardlink %s → %s: %w", hdr.Name, hdr.Linkname, err)
break loop
}

case tar.TypeReg:
if err := ensureDir(filepath.Dir(target), 0o755); err != nil {
readErr = errors.Errorf("mkdir %s: %w", hdr.Name, err)
break loop
}

if hdr.Size <= maxParallelFileSize {
// Buffer in memory and dispatch to worker pool so the main
// goroutine can continue reading the tar stream immediately.
buf := make([]byte, hdr.Size)
if _, err := io.ReadFull(tr, buf); err != nil {
readErr = errors.Errorf("read %s: %w", hdr.Name, err)
break loop
}
// Propagate worker errors early.
if writeErr != nil {
readErr = writeErr
break loop
}
jobs <- writeJob{target: target, mode: hdr.FileInfo().Mode(), data: buf}
} else {
// Large file: write inline to keep memory bounded.
f, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, hdr.FileInfo().Mode()) //nolint:gosec // path traversal guarded above
if err != nil {
readErr = errors.Errorf("open %s: %w", hdr.Name, err)
break loop
}
if _, err := io.CopyBuffer(f, io.LimitReader(tr, hdr.Size), copyBuf); err != nil {
f.Close() //nolint:errcheck,gosec
readErr = errors.Errorf("write %s: %w", hdr.Name, err)
break loop
}
if err := f.Close(); err != nil {
readErr = errors.Errorf("close %s: %w", hdr.Name, err)
break loop
}
}

case tar.TypeSymlink:
if err := ensureDir(filepath.Dir(target), 0o755); err != nil {
readErr = errors.Errorf("mkdir for symlink %s: %w", hdr.Name, err)
break loop
}
if _, err := safePath(dir, hdr.Linkname); err != nil {
readErr = errors.Errorf("unsafe symlink target %q: %w", hdr.Linkname, err)
break loop
}
if err := os.Symlink(hdr.Linkname, target); err != nil {
readErr = errors.Errorf("symlink %s → %s: %w", hdr.Name, hdr.Linkname, err)
break loop
}
}
}

return nil
close(jobs)
wg.Wait()

if readErr != nil {
return readErr
}
return writeErr
}