diff --git a/refactor/importgraph/graph_test.go b/refactor/importgraph/graph_test.go index 2ab54e2ab0d..75263839a24 100644 --- a/refactor/importgraph/graph_test.go +++ b/refactor/importgraph/graph_test.go @@ -10,7 +10,9 @@ package importgraph_test import ( + "fmt" "go/build" + "os" "sort" "strings" "testing" @@ -30,10 +32,40 @@ func TestBuild(t *testing.T) { var gopath string for _, env := range exported.Config.Env { - if !strings.HasPrefix(env, "GOPATH=") { + eq := strings.Index(env, "=") + if eq == 0 { + // We sometimes see keys with a single leading "=" in the environment on Windows. + // TODO(#49886): What is the correct way to parse them in general? + eq = strings.Index(env[1:], "=") + 1 + } + if eq < 0 { + t.Fatalf("invalid variable in exported.Config.Env: %q", env) + } + k := env[:eq] + v := env[eq+1:] + if k == "GOPATH" { + gopath = v + } + + if os.Getenv(k) == v { continue } - gopath = strings.TrimPrefix(env, "GOPATH=") + defer func(prev string, prevOK bool) { + if !prevOK { + if err := os.Unsetenv(k); err != nil { + t.Fatal(err) + } + } else { + if err := os.Setenv(k, prev); err != nil { + t.Fatal(err) + } + } + }(os.LookupEnv(k)) + + if err := os.Setenv(k, v); err != nil { + t.Fatal(err) + } + t.Logf("%s=%s", k, v) } if gopath == "" { t.Fatal("Failed to fish GOPATH out of env: ", exported.Config.Env) @@ -41,45 +73,97 @@ func TestBuild(t *testing.T) { var buildContext = build.Default buildContext.GOPATH = gopath + buildContext.Dir = exported.Config.Dir + + forward, reverse, errs := importgraph.Build(&buildContext) + for path, err := range errs { + t.Errorf("%s: %s", path, err) + } + if t.Failed() { + return + } + + // Log the complete graph before the errors, so that the errors are near the + // end of the log (where we expect them to be). + nodePrinted := map[string]bool{} + printNode := func(direction string, from string) { + key := fmt.Sprintf("%s[%q]", direction, from) + if nodePrinted[key] { + return + } + nodePrinted[key] = true + + var g importgraph.Graph + switch direction { + case "forward": + g = forward + case "reverse": + g = reverse + default: + t.Helper() + t.Fatalf("bad direction: %q", direction) + } + + t.Log(key) + var pkgs []string + for pkg := range g[from] { + pkgs = append(pkgs, pkg) + } + sort.Strings(pkgs) + for _, pkg := range pkgs { + t.Logf("\t%s", pkg) + } + } - forward, reverse, errors := importgraph.Build(&buildContext) + if testing.Verbose() { + printNode("forward", this) + printNode("reverse", this) + } // Test direct edges. // We throw in crypto/hmac to prove that external test files // (such as this one) are inspected. for _, p := range []string{"go/build", "testing", "crypto/hmac"} { if !forward[this][p] { - t.Errorf("forward[importgraph][%s] not found", p) + printNode("forward", this) + t.Errorf("forward[%q][%q] not found", this, p) } if !reverse[p][this] { - t.Errorf("reverse[%s][importgraph] not found", p) + printNode("reverse", p) + t.Errorf("reverse[%q][%q] not found", p, this) } } // Test non-existent direct edges for _, p := range []string{"errors", "reflect"} { if forward[this][p] { - t.Errorf("unexpected: forward[importgraph][%s] found", p) + printNode("forward", this) + t.Errorf("unexpected: forward[%q][%q] found", this, p) } if reverse[p][this] { - t.Errorf("unexpected: reverse[%s][importgraph] found", p) + printNode("reverse", p) + t.Errorf("unexpected: reverse[%q][%q] found", p, this) } } // Test Search is reflexive. if !forward.Search(this)[this] { + printNode("forward", this) t.Errorf("irreflexive: forward.Search(importgraph)[importgraph] not found") } if !reverse.Search(this)[this] { + printNode("reverse", this) t.Errorf("irrefexive: reverse.Search(importgraph)[importgraph] not found") } // Test Search is transitive. (There is no direct edge to these packages.) for _, p := range []string{"errors", "reflect", "unsafe"} { if !forward.Search(this)[p] { + printNode("forward", this) t.Errorf("intransitive: forward.Search(importgraph)[%s] not found", p) } if !reverse.Search(p)[this] { + printNode("reverse", p) t.Errorf("intransitive: reverse.Search(%s)[importgraph] not found", p) } } @@ -95,26 +179,10 @@ func TestBuild(t *testing.T) { !forward.Search("io")["fmt"] || !reverse.Search("fmt")["io"] || !reverse.Search("io")["fmt"] { + printNode("forward", "fmt") + printNode("forward", "io") + printNode("reverse", "fmt") + printNode("reverse", "io") t.Errorf("fmt and io are not mutually reachable despite being in the same SCC") } - - // debugging - if false { - for path, err := range errors { - t.Logf("%s: %s", path, err) - } - printSorted := func(direction string, g importgraph.Graph, start string) { - t.Log(direction) - var pkgs []string - for pkg := range g.Search(start) { - pkgs = append(pkgs, pkg) - } - sort.Strings(pkgs) - for _, pkg := range pkgs { - t.Logf("\t%s", pkg) - } - } - printSorted("forward", forward, this) - printSorted("reverse", reverse, this) - } }