Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 104 additions & 28 deletions scripts/fiximports/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@ package main
import (
"bytes"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
"regexp"
"runtime"
"strings"

"golang.org/x/sync/errgroup"
"golang.org/x/tools/imports"
"golang.org/x/xerrors"
)

var (
Expand All @@ -25,7 +27,18 @@ var (
consecutiveNewlinesRegex = regexp.MustCompile(`\n\s*\n`)
)

type fileContent struct {
path string
original []byte
current []byte
changed bool
}

func main() {
numWorkers := runtime.NumCPU()

// Collect all the filenames that we want to process
var files []string
if err := filepath.Walk(".", func(path string, info fs.FileInfo, err error) error {
switch {
case err != nil:
Expand All @@ -39,49 +52,112 @@ func main() {
!strings.HasSuffix(info.Name(), ".go"):
return nil
}
return fixGoImports(path)
files = append(files, path)
return nil
}); err != nil {
fmt.Printf("Error fixing go imports: %v\n", err)
_, _ = fmt.Fprintf(os.Stderr, "Error walking directory: %v\n", err)
os.Exit(1)
}
}

func fixGoImports(path string) error {
sourceFile, err := os.OpenFile(path, os.O_RDWR, 0666)
// Read all file contents in parallel
fileContents, err := readFilesParallel(files, numWorkers)
if err != nil {
return err
_, _ = fmt.Fprintf(os.Stderr, "Error reading files: %v\n", err)
os.Exit(1)
}
defer func() { _ = sourceFile.Close() }()

source, err := io.ReadAll(sourceFile)
if err != nil {
return err
}
formatted := collapseImportNewlines(source)
// Because we have multiple ways of separating imports, we have to imports.Process for each one
// but imports.LocalPrefix is a global, so we have to set it for each group and process files
Comment thread
rvagg marked this conversation as resolved.
// in parallel.
for _, prefix := range groupByPrefixes {
imports.LocalPrefix = prefix
formatted, err = imports.Process(path, formatted, nil)
if err != nil {
return err
if err := processFilesParallel(fileContents, numWorkers); err != nil {
_, _ = fmt.Fprintf(os.Stderr, "Error processing files with prefix %s: %v\n", prefix, err)
os.Exit(1)
}
}
if !bytes.Equal(source, formatted) {
if err := replaceFileContent(sourceFile, formatted); err != nil {
return err
}

// Write modified files in parallel
if err := writeFilesParallel(fileContents, numWorkers); err != nil {
_, _ = fmt.Fprintf(os.Stderr, "Error writing files: %v\n", err)
os.Exit(1)
}
return nil
}

func replaceFileContent(target *os.File, replacement []byte) error {
if _, err := target.Seek(0, io.SeekStart); err != nil {
return err
func readFilesParallel(files []string, numWorkers int) ([]*fileContent, error) {
fileContents := make([]*fileContent, len(files))

var g errgroup.Group
g.SetLimit(numWorkers)

for i, path := range files {
Comment thread
rvagg marked this conversation as resolved.
g.Go(func() error {
content, err := os.ReadFile(path)
if err != nil {
return xerrors.Errorf("reading %s: %w", path, err)
}

// Collapse is a cheap operation to do here
collapsed := collapseImportNewlines(content)
fileContents[i] = &fileContent{
path: path,
original: content,
current: collapsed,
changed: !bytes.Equal(content, collapsed),
}
return nil
})
}
written, err := target.Write(replacement)
if err != nil {
return err

Comment thread
rvagg marked this conversation as resolved.
if err := g.Wait(); err != nil {
return nil, err
}

return fileContents, nil
}

func processFilesParallel(fileContents []*fileContent, numWorkers int) error {
var g errgroup.Group
g.SetLimit(numWorkers)

for _, file := range fileContents {
Comment thread
rvagg marked this conversation as resolved.
Comment thread
rvagg marked this conversation as resolved.
if file == nil {
continue
}
g.Go(func() error {
formatted, err := imports.Process(file.path, file.current, nil)
if err != nil {
return xerrors.Errorf("processing %s: %w", file.path, err)
}

if !bytes.Equal(file.current, formatted) {
file.current = formatted
file.changed = true
}
return nil
})
}
return target.Truncate(int64(written))

return g.Wait()
}

func writeFilesParallel(fileContents []*fileContent, numWorkers int) error {
var g errgroup.Group
g.SetLimit(numWorkers)

for _, file := range fileContents {
Comment thread
rvagg marked this conversation as resolved.
Comment thread
rvagg marked this conversation as resolved.
if file == nil || !file.changed {
continue
}
g.Go(func() error {
if err := os.WriteFile(file.path, file.current, 0666); err != nil {
return xerrors.Errorf("writing %s: %w", file.path, err)
}
return nil
})
}

return g.Wait()
}

func collapseImportNewlines(content []byte) []byte {
Expand Down