-
Notifications
You must be signed in to change notification settings - Fork 5
perf: parallel tar extraction with 64 worker goroutines #218
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,116 @@ | ||
| package cache | ||
|
|
||
| import ( | ||
| "context" | ||
| "io" | ||
| "sync" | ||
|
|
||
| "github.com/alecthomas/errors" | ||
| "github.com/minio/minio-go/v7" | ||
| ) | ||
|
|
||
const (
	// s3DownloadChunkSize is the size of each parallel range-GET request.
	// 32 MiB matches the gradle-cache-tool's benchmarked default.
	s3DownloadChunkSize = 32 << 20
	// s3DownloadWorkers is the number of concurrent range-GET requests.
	// Benchmarking showed no throughput difference from 4 to 128 workers
	// (extraction IOPS is the bottleneck), so 8 keeps connection count low.
	s3DownloadWorkers = 8
)
|
|
||
| // parallelGetReader returns an io.ReadCloser that downloads the S3 object | ||
| // using parallel range-GET requests and reassembles chunks in order. | ||
| // For objects smaller than one chunk, it falls back to a single GetObject. | ||
| func (s *S3) parallelGetReader(ctx context.Context, bucket, objectName string, size int64) (io.ReadCloser, error) { | ||
| if size <= s3DownloadChunkSize { | ||
| // Small object: single stream. | ||
| obj, err := s.client.GetObject(ctx, bucket, objectName, minio.GetObjectOptions{}) | ||
| if err != nil { | ||
| return nil, errors.Errorf("failed to get object: %w", err) | ||
| } | ||
| return &s3Reader{obj: obj}, nil | ||
| } | ||
|
|
||
| // Large object: parallel range requests reassembled in order via io.Pipe. | ||
| pr, pw := io.Pipe() | ||
| go func() { | ||
| pw.CloseWithError(s.parallelGet(ctx, bucket, objectName, size, pw)) | ||
| }() | ||
| return pr, nil | ||
| } | ||
|
|
||
| // parallelGet downloads an S3 object in parallel chunks and writes them in | ||
| // order to w. Each worker downloads its chunk into memory so the TCP | ||
| // connection stays active at full speed. Peak memory: numWorkers × chunkSize. | ||
| func (s *S3) parallelGet(ctx context.Context, bucket, objectName string, size int64, w io.Writer) error { | ||
| numChunks := int((size + s3DownloadChunkSize - 1) / s3DownloadChunkSize) | ||
| numWorkers := min(s3DownloadWorkers, numChunks) | ||
|
|
||
| type chunkResult struct { | ||
| data []byte | ||
| err error | ||
| } | ||
|
|
||
| // One buffered channel per chunk so workers never block after reading. | ||
| results := make([]chan chunkResult, numChunks) | ||
| for i := range results { | ||
| results[i] = make(chan chunkResult, 1) | ||
| } | ||
|
|
||
| // Work queue of chunk indices. | ||
| work := make(chan int, numChunks) | ||
| for i := range numChunks { | ||
| work <- i | ||
| } | ||
| close(work) | ||
|
|
||
| var wg sync.WaitGroup | ||
| for range numWorkers { | ||
| wg.Go(func() { | ||
| for seq := range work { | ||
| start := int64(seq) * s3DownloadChunkSize | ||
| end := min(start+s3DownloadChunkSize-1, size-1) | ||
|
|
||
| opts := minio.GetObjectOptions{} | ||
| if err := opts.SetRange(start, end); err != nil { | ||
| results[seq] <- chunkResult{err: errors.Errorf("set range %d-%d: %w", start, end, err)} | ||
| continue | ||
| } | ||
|
|
||
| obj, err := s.client.GetObject(ctx, bucket, objectName, opts) | ||
| if err != nil { | ||
| results[seq] <- chunkResult{err: errors.Errorf("get range %d-%d: %w", start, end, err)} | ||
| continue | ||
| } | ||
|
|
||
| // Drain the body immediately so the TCP connection stays at | ||
| // full speed. All workers do this concurrently, saturating | ||
| // the available S3 bandwidth. | ||
| data, readErr := io.ReadAll(obj) | ||
| obj.Close() //nolint:errcheck,gosec | ||
| results[seq] <- chunkResult{data: data, err: readErr} | ||
| } | ||
| }) | ||
| } | ||
|
|
||
| // Write chunks in order. Each receive blocks until that chunk's worker | ||
| // finishes, while other workers continue downloading concurrently. | ||
| var writeErr error | ||
| for _, ch := range results { | ||
| r := <-ch | ||
| if writeErr != nil { | ||
| continue // drain remaining channels so goroutines can exit | ||
| } | ||
| if r.err != nil { | ||
| writeErr = r.err | ||
| continue | ||
| } | ||
| if _, err := w.Write(r.data); err != nil { | ||
| writeErr = err | ||
| } | ||
| } | ||
|
|
||
| wg.Wait() | ||
| return writeErr | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
| package snapshot | ||
|
|
||
| import ( | ||
| "archive/tar" | ||
| "bufio" | ||
| "bytes" | ||
| "context" | ||
|
|
@@ -12,6 +13,8 @@ import ( | |
| "os/exec" | ||
| "path/filepath" | ||
| "runtime" | ||
| "strings" | ||
| "sync" | ||
| "time" | ||
|
|
||
| "github.com/alecthomas/errors" | ||
|
|
@@ -187,9 +190,32 @@ func Restore(ctx context.Context, remote cache.Cache, key cache.Key, directory s | |
| return Extract(ctx, rc, directory, threads) | ||
| } | ||
|
|
||
const (
	// extractWorkers is the number of goroutines writing files concurrently
	// during parallel tar extraction. Hides per-file open/write/close syscall
	// latency so the tar-stream reader (and download pipeline behind it) is not
	// stalled waiting for individual file writes to complete.
	// Benchmarked on r8id.metal-48xlarge (NVMe, 96 cores) with a 334K-file
	// bundle: 64 workers = 6.27s, 128 = 6.84s (extra GC pressure outweighs
	// any I/O concurrency gain).
	// TODO(review): this cannot stay a static constant — it should respect
	// ionice/system-load constraints and degrade gracefully on spinning media;
	// thread it through a config value.
	extractWorkers = 64

	// maxParallelFileSize is the largest file that will be buffered in memory
	// and dispatched to the worker pool. Files larger than this are written
	// inline in the main goroutine to keep peak memory bounded.
	// At 4 MiB, 99.97% of Gradle cache entries go through the parallel path.
	maxParallelFileSize = 4 << 20 // 4 MiB
)
|
|
||
| // Extract decompresses a zstd+tar stream into directory, preserving all file | ||
| // permissions, ownership, and symlinks. threads controls zstd parallelism; | ||
| // 0 uses all available CPU cores. | ||
| // | ||
| // The single-threaded bottleneck on restore is writing files to disk. Even | ||
| // though tar entries must be read sequentially (the format has no index), the | ||
| // actual file writes are independent. The extractor dispatches each entry | ||
| // (buffered in memory, ≤4 MiB) to one of 64 worker goroutines that write | ||
| // concurrently. This hides the per-file syscall latency (~20µs × N files) | ||
| // behind parallelism. | ||
| func Extract(ctx context.Context, r io.Reader, directory string, threads int) error { | ||
| if threads <= 0 { | ||
| threads = runtime.NumCPU() | ||
|
|
@@ -212,15 +238,193 @@ func Extract(ctx context.Context, r io.Reader, directory string, threads int) er | |
| } | ||
| defer dec.Close() | ||
|
|
||
| tarCmd := exec.CommandContext(ctx, "tar", "-xpf", "-", "-C", directory) | ||
| tarCmd.Stdin = dec | ||
| return extractTarParallel(ctx, dec, directory) | ||
| } | ||
|
|
||
| var tarStderr bytes.Buffer | ||
| tarCmd.Stderr = &tarStderr | ||
// writeJob describes one small regular file queued for the extraction worker
// pool: the resolved destination path, the mode taken from the tar header,
// and the complete file contents buffered in memory.
type writeJob struct {
	target string
	mode   os.FileMode
	data   []byte
}
|
|
||
| // safePath validates that name is a relative path that stays within dir when | ||
| // joined. It rejects absolute paths and parent traversals (".."). Returns the | ||
| // resolved path under dir. | ||
| func safePath(dir, name string) (string, error) { | ||
| clean := filepath.Clean(name) | ||
| if filepath.IsAbs(clean) { | ||
| return "", errors.Errorf("path %q is absolute", name) | ||
| } | ||
| if clean == ".." || strings.HasPrefix(clean, ".."+string(os.PathSeparator)) { | ||
| return "", errors.Errorf("path %q escapes destination directory", name) | ||
| } | ||
| joined := filepath.Join(dir, clean) | ||
| if !strings.HasPrefix(joined, dir+string(os.PathSeparator)) && joined != dir { | ||
| return "", errors.Errorf("path %q resolves outside destination directory", name) | ||
| } | ||
| return joined, nil | ||
| } | ||
|
|
||
| // extractTarParallel reads a tar stream and writes files using a pool of | ||
| // goroutines. The main goroutine reads tar entries and buffers small file | ||
| // contents; workers write those files to disk concurrently. Large files are | ||
| // written inline to keep memory use bounded. | ||
| func extractTarParallel(ctx context.Context, r io.Reader, dir string) error { | ||
| // Resolve dir to absolute so containment checks are reliable. | ||
| var err error | ||
| dir, err = filepath.Abs(dir) | ||
| if err != nil { | ||
| return errors.Wrap(err, "resolve destination directory") | ||
| } | ||
|
|
||
| jobs := make(chan writeJob, extractWorkers*2) | ||
|
|
||
| var ( | ||
| wg sync.WaitGroup | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Change this to an errgroup.WithContext(ctx) so that a) cancellation kills the errgroup and b) any error kills the errgroup. |
||
| writeErrOnce sync.Once | ||
| writeErr error | ||
| ) | ||
|
|
||
| for range extractWorkers { | ||
| wg.Go(func() { | ||
| for job := range jobs { | ||
| f, err := os.OpenFile(job.target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, job.mode) | ||
| if err != nil { | ||
| writeErrOnce.Do(func() { writeErr = errors.Errorf("open %s: %w", filepath.Base(job.target), err) }) | ||
| continue | ||
| } | ||
| if _, err := f.Write(job.data); err != nil { | ||
| f.Close() //nolint:errcheck,gosec | ||
| writeErrOnce.Do(func() { writeErr = errors.Errorf("write %s: %w", filepath.Base(job.target), err) }) | ||
| continue | ||
| } | ||
| if err := f.Close(); err != nil { | ||
| writeErrOnce.Do(func() { writeErr = errors.Errorf("close %s: %w", filepath.Base(job.target), err) }) | ||
| } | ||
| } | ||
| }) | ||
| } | ||
|
|
||
| copyBuf := make([]byte, 1<<20) // reused only for inline large-file writes | ||
|
|
||
| // createdDirs is accessed only by the main goroutine, so no mutex needed. | ||
| createdDirs := make(map[string]struct{}) | ||
| ensureDir := func(d string, mode os.FileMode) error { | ||
| if _, ok := createdDirs[d]; ok { | ||
| return nil | ||
| } | ||
| if err := os.MkdirAll(d, mode); err != nil { //nolint:gosec // path is validated by caller | ||
| return errors.Wrap(err, "mkdir") | ||
| } | ||
| createdDirs[d] = struct{}{} | ||
| return nil | ||
| } | ||
|
|
||
| if err := tarCmd.Run(); err != nil { | ||
| return errors.Errorf("tar failed: %w: %s", err, tarStderr.String()) | ||
| tr := tar.NewReader(r) | ||
| var readErr error | ||
| loop: | ||
| for { | ||
| if err := ctx.Err(); err != nil { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pull this loop body out into a separate function. This will let the function just return normal errors and simplify the code significantly. |
||
| readErr = errors.Wrap(err, "context cancelled") | ||
| break | ||
| } | ||
|
|
||
| hdr, err := tr.Next() | ||
| if errors.Is(err, io.EOF) { | ||
| break | ||
| } | ||
| if err != nil { | ||
| readErr = errors.Wrap(err, "read tar entry") | ||
| break | ||
| } | ||
|
|
||
| target, err := safePath(dir, hdr.Name) | ||
| if err != nil { | ||
| readErr = errors.Errorf("unsafe tar entry %q: %w", hdr.Name, err) | ||
| break | ||
| } | ||
|
|
||
| switch hdr.Typeflag { | ||
| case tar.TypeDir: | ||
| if err := ensureDir(target, hdr.FileInfo().Mode()); err != nil { | ||
| readErr = errors.Errorf("mkdir %s: %w", hdr.Name, err) | ||
| break loop | ||
| } | ||
|
|
||
| case tar.TypeLink: | ||
| if err := ensureDir(filepath.Dir(target), 0o755); err != nil { | ||
| readErr = errors.Errorf("mkdir for hardlink %s: %w", hdr.Name, err) | ||
| break loop | ||
| } | ||
| linkTarget, err := safePath(dir, hdr.Linkname) | ||
| if err != nil { | ||
| readErr = errors.Errorf("unsafe hardlink target %q: %w", hdr.Linkname, err) | ||
| break loop | ||
| } | ||
| if err := os.Link(linkTarget, target); err != nil { | ||
| readErr = errors.Errorf("hardlink %s → %s: %w", hdr.Name, hdr.Linkname, err) | ||
| break loop | ||
| } | ||
|
|
||
| case tar.TypeReg: | ||
| if err := ensureDir(filepath.Dir(target), 0o755); err != nil { | ||
| readErr = errors.Errorf("mkdir %s: %w", hdr.Name, err) | ||
| break loop | ||
| } | ||
|
|
||
| if hdr.Size <= maxParallelFileSize { | ||
| // Buffer in memory and dispatch to worker pool so the main | ||
| // goroutine can continue reading the tar stream immediately. | ||
| buf := make([]byte, hdr.Size) | ||
| if _, err := io.ReadFull(tr, buf); err != nil { | ||
| readErr = errors.Errorf("read %s: %w", hdr.Name, err) | ||
| break loop | ||
| } | ||
| // Propagate worker errors early. | ||
| if writeErr != nil { | ||
| readErr = writeErr | ||
| break loop | ||
| } | ||
| jobs <- writeJob{target: target, mode: hdr.FileInfo().Mode(), data: buf} | ||
| } else { | ||
| // Large file: write inline to keep memory bounded. | ||
| f, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, hdr.FileInfo().Mode()) //nolint:gosec // path traversal guarded above | ||
| if err != nil { | ||
| readErr = errors.Errorf("open %s: %w", hdr.Name, err) | ||
| break loop | ||
| } | ||
| if _, err := io.CopyBuffer(f, io.LimitReader(tr, hdr.Size), copyBuf); err != nil { | ||
| f.Close() //nolint:errcheck,gosec | ||
| readErr = errors.Errorf("write %s: %w", hdr.Name, err) | ||
| break loop | ||
| } | ||
| if err := f.Close(); err != nil { | ||
| readErr = errors.Errorf("close %s: %w", hdr.Name, err) | ||
| break loop | ||
| } | ||
| } | ||
|
|
||
| case tar.TypeSymlink: | ||
| if err := ensureDir(filepath.Dir(target), 0o755); err != nil { | ||
| readErr = errors.Errorf("mkdir for symlink %s: %w", hdr.Name, err) | ||
| break loop | ||
| } | ||
| if _, err := safePath(dir, hdr.Linkname); err != nil { | ||
| readErr = errors.Errorf("unsafe symlink target %q: %w", hdr.Linkname, err) | ||
| break loop | ||
| } | ||
| if err := os.Symlink(hdr.Linkname, target); err != nil { | ||
| readErr = errors.Errorf("symlink %s → %s: %w", hdr.Name, hdr.Linkname, err) | ||
| break loop | ||
| } | ||
| } | ||
| } | ||
|
|
||
| return nil | ||
| close(jobs) | ||
| wg.Wait() | ||
|
|
||
| if readErr != nil { | ||
| return readErr | ||
| } | ||
| return writeErr | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't include pre-parallel benchmark results, what were the original numbers?