diff --git a/go/tools/builders/generate_test_main.go b/go/tools/builders/generate_test_main.go index 4cc23e0471..ff150249a9 100644 --- a/go/tools/builders/generate_test_main.go +++ b/go/tools/builders/generate_test_main.go @@ -58,6 +58,7 @@ import ( "log" "os" "fmt" + "strconv" "testing" "testing/internal/testdeps" @@ -72,7 +73,7 @@ import ( {{end}} ) -var tests = []testing.InternalTest{ +var allTests = []testing.InternalTest{ {{range .TestNames}} {"{{.}}", undertest.{{.}} }, {{end}} @@ -84,6 +85,24 @@ var benchmarks = []testing.InternalBenchmark{ {{end}} } +func testsInShard() []testing.InternalTest { + totalShards, err := strconv.Atoi(os.Getenv("TEST_TOTAL_SHARDS")) + if err != nil || totalShards <= 1 { + return allTests + } + shardIndex, err := strconv.Atoi(os.Getenv("TEST_SHARD_INDEX")) + if err != nil || shardIndex < 0 { + return allTests + } + tests := []testing.InternalTest{} + for i, t := range allTests { + if i % totalShards == shardIndex { + tests = append(tests, t) + } + } + return tests +} + func coverRegisterAll() testing.Cover { coverage := testing.Cover{ Mode: "set", @@ -144,7 +163,7 @@ func main() { testing.RegisterCover(coverage) } - m := testing.MainStart(testdeps.TestDeps{}, tests, benchmarks, nil) + m := testing.MainStart(testdeps.TestDeps{}, testsInShard(), benchmarks, nil) {{if not .HasTestMain}} os.Exit(m.Run()) {{else}}