diff --git a/scripts/fiximports/main.go b/scripts/fiximports/main.go index 427975855bc..3a5a188814f 100644 --- a/scripts/fiximports/main.go +++ b/scripts/fiximports/main.go @@ -3,14 +3,16 @@ package main import ( "bytes" "fmt" - "io" "io/fs" "os" "path/filepath" "regexp" + "runtime" "strings" + "golang.org/x/sync/errgroup" "golang.org/x/tools/imports" + "golang.org/x/xerrors" ) var ( @@ -25,7 +27,18 @@ var ( consecutiveNewlinesRegex = regexp.MustCompile(`\n\s*\n`) ) +type fileContent struct { + path string + original []byte + current []byte + changed bool +} + func main() { + numWorkers := runtime.NumCPU() + + // Collect all the filenames that we want to process + var files []string if err := filepath.Walk(".", func(path string, info fs.FileInfo, err error) error { switch { case err != nil: @@ -39,49 +52,112 @@ func main() { !strings.HasSuffix(info.Name(), ".go"): return nil } - return fixGoImports(path) + files = append(files, path) + return nil }); err != nil { - fmt.Printf("Error fixing go imports: %v\n", err) + _, _ = fmt.Fprintf(os.Stderr, "Error walking directory: %v\n", err) os.Exit(1) } -} -func fixGoImports(path string) error { - sourceFile, err := os.OpenFile(path, os.O_RDWR, 0666) + // Read all file contents in parallel + fileContents, err := readFilesParallel(files, numWorkers) if err != nil { - return err + _, _ = fmt.Fprintf(os.Stderr, "Error reading files: %v\n", err) + os.Exit(1) } - defer func() { _ = sourceFile.Close() }() - source, err := io.ReadAll(sourceFile) - if err != nil { - return err - } - formatted := collapseImportNewlines(source) + // Because we have multiple ways of separating imports, we have to imports.Process for each one + // but imports.LocalPrefix is a global, so we have to set it for each group and process files + // in parallel. for _, prefix := range groupByPrefixes { imports.LocalPrefix = prefix - formatted, err = imports.Process(path, formatted, nil) - if err != nil { - return err + if err := processFilesParallel(fileContents, numWorkers); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "Error processing files with prefix %s: %v\n", prefix, err) + os.Exit(1) } } - if !bytes.Equal(source, formatted) { - if err := replaceFileContent(sourceFile, formatted); err != nil { - return err - } + + // Write modified files in parallel + if err := writeFilesParallel(fileContents, numWorkers); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "Error writing files: %v\n", err) + os.Exit(1) } - return nil } -func replaceFileContent(target *os.File, replacement []byte) error { - if _, err := target.Seek(0, io.SeekStart); err != nil { - return err +func readFilesParallel(files []string, numWorkers int) ([]*fileContent, error) { + fileContents := make([]*fileContent, len(files)) + + var g errgroup.Group + g.SetLimit(numWorkers) + + for i, path := range files { + g.Go(func() error { + content, err := os.ReadFile(path) + if err != nil { + return xerrors.Errorf("reading %s: %w", path, err) + } + + // Collapse is a cheap operation to do here + collapsed := collapseImportNewlines(content) + fileContents[i] = &fileContent{ + path: path, + original: content, + current: collapsed, + changed: !bytes.Equal(content, collapsed), + } + return nil + }) } - written, err := target.Write(replacement) - if err != nil { - return err + + if err := g.Wait(); err != nil { + return nil, err + } + + return fileContents, nil +} + +func processFilesParallel(fileContents []*fileContent, numWorkers int) error { + var g errgroup.Group + g.SetLimit(numWorkers) + + for _, file := range fileContents { + if file == nil { + continue + } + g.Go(func() error { + formatted, err := imports.Process(file.path, file.current, nil) + if err != nil { + return xerrors.Errorf("processing %s: %w", file.path, err) + } + + if !bytes.Equal(file.current, formatted) { + file.current = formatted + file.changed = true + } + return nil + }) } - return target.Truncate(int64(written)) + + return g.Wait() +} + +func writeFilesParallel(fileContents []*fileContent, numWorkers int) error { + var g errgroup.Group + g.SetLimit(numWorkers) + + for _, file := range fileContents { + if file == nil || !file.changed { + continue + } + g.Go(func() error { + if err := os.WriteFile(file.path, file.current, 0666); err != nil { + return xerrors.Errorf("writing %s: %w", file.path, err) + } + return nil + }) + } + + return g.Wait() } func collapseImportNewlines(content []byte) []byte {