diff --git a/internal/backend/remote-state/http/backend.go b/internal/backend/remote-state/http/backend.go index dda03c06b409..0494e8dd9925 100644 --- a/internal/backend/remote-state/http/backend.go +++ b/internal/backend/remote-state/http/backend.go @@ -313,7 +313,13 @@ func (b *Backend) StateMgr(name string) (statemgr.Full, tfdiags.Diagnostics) { return nil, diags.Append(backend.ErrWorkspacesNotSupported) } - return &remote.State{Client: b.client}, diags + sm := &remote.State{Client: b.client} + + if err := sm.RefreshState(); err != nil { + return nil, diags.Append(err) + } + + return sm, diags } func (b *Backend) Workspaces() ([]string, tfdiags.Diagnostics) { diff --git a/internal/backend/remote-state/http/server_test.go b/internal/backend/remote-state/http/server_test.go index 2e406038bc3c..43caf607dea8 100644 --- a/internal/backend/remote-state/http/server_test.go +++ b/internal/backend/remote-state/http/server_test.go @@ -277,14 +277,10 @@ func TestMTLSServer_NoCertFails(t *testing.T) { } // Now get a state manager and check that it fails to refresh the state - sm, sDiags := b.StateMgr(backend.DefaultStateName) - if sDiags.HasErrors() { - t.Fatalf("unexpected error fetching StateMgr with %s: %v", backend.DefaultStateName, sDiags) - } - err = sm.RefreshState() - if nil == err { + _, sDiags := b.StateMgr(backend.DefaultStateName) + if !sDiags.HasErrors() { t.Error("expected error when refreshing state without a client cert") - } else if !strings.Contains(err.Error(), "remote error: tls: certificate required") { + } else if !strings.Contains(sDiags.Err().Error(), "remote error: tls: certificate required") { t.Errorf("expected the error to report missing tls credentials: %v", err) } } diff --git a/internal/backend/remote-state/http/test_backend.go b/internal/backend/remote-state/http/test_backend.go new file mode 100644 index 000000000000..bd264a80a7fb --- /dev/null +++ b/internal/backend/remote-state/http/test_backend.go @@ -0,0 +1,132 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 +package http + +import ( + "bytes" + "fmt" + "io" + "net/http" + "reflect" +) + +type TestRequestHandleFunc func(w http.ResponseWriter, r *http.Request) + +type TestHTTPBackend struct { + Data []byte + Locked bool + + methodFuncs map[string]TestRequestHandleFunc + methodCalls map[string]int +} + +func (h *TestHTTPBackend) Handle(w http.ResponseWriter, r *http.Request) { + h.countMethodCall(r.Method) + called := h.callMethod(r.Method, w, r) + if called { + return + } + + switch r.Method { + case "GET": + w.Write(h.Data) + case "PUT": + buf := new(bytes.Buffer) + if _, err := io.Copy(buf, r.Body); err != nil { + w.WriteHeader(500) + } + w.WriteHeader(201) + h.Data = buf.Bytes() + case "POST": + buf := new(bytes.Buffer) + if _, err := io.Copy(buf, r.Body); err != nil { + w.WriteHeader(500) + } + h.Data = buf.Bytes() + case "LOCK": + if h.Locked { + w.WriteHeader(423) + } else { + h.Locked = true + } + case "UNLOCK": + h.Locked = false + case "DELETE": + h.Data = nil + w.WriteHeader(200) + default: + w.WriteHeader(http.StatusNotImplemented) + w.Write([]byte(fmt.Sprintf("Unknown method: %s", r.Method))) + } +} + +func (h *TestHTTPBackend) countMethodCall(method string) { + if h.methodCalls == nil { + h.methodCalls = make(map[string]int) + } + if _, ok := h.methodCalls[method]; !ok { + h.methodCalls[method] = 0 + } + h.methodCalls[method]++ +} + +func (h *TestHTTPBackend) CallCount(method string) int { + if h.methodCalls == nil { + return 0 + } + callCount, ok := h.methodCalls[method] + if !ok { + return 0 + } + return callCount +} + +func (h *TestHTTPBackend) callMethod(method string, w http.ResponseWriter, r *http.Request) bool { + if h.methodFuncs == nil { + return false + } + f, ok := h.methodFuncs[method] + if ok { + f(w, r) + } + return ok +} + +func (h *TestHTTPBackend) SetMethodFunc(method string, impl TestRequestHandleFunc) { + if h.methodFuncs == nil { + h.methodFuncs = make(map[string]TestRequestHandleFunc) + } + h.methodFuncs[method] = impl +} + +// mod_dav-ish behavior +func (h *TestHTTPBackend) HandleWebDAV(w http.ResponseWriter, r *http.Request) { + h.countMethodCall(r.Method) + if f, ok := h.methodFuncs[r.Method]; ok { + f(w, r) + return + } + + switch r.Method { + case "GET": + w.Write(h.Data) + case "PUT": + buf := new(bytes.Buffer) + if _, err := io.Copy(buf, r.Body); err != nil { + w.WriteHeader(500) + } + if reflect.DeepEqual(h.Data, buf.Bytes()) { + h.Data = buf.Bytes() + w.WriteHeader(204) + } else { + h.Data = buf.Bytes() + w.WriteHeader(201) + } + case "DELETE": + h.Data = nil + w.WriteHeader(200) + default: + w.WriteHeader(http.StatusNotImplemented) + w.Write([]byte(fmt.Sprintf("Unknown method: %s", r.Method))) + } +} diff --git a/internal/command/init_test.go b/internal/command/init_test.go index 361702972b94..6c266b9b8ef9 100644 --- a/internal/command/init_test.go +++ b/internal/command/init_test.go @@ -7,6 +7,8 @@ import ( "fmt" "io/ioutil" "log" + "net/http" + "net/http/httptest" "os" "path/filepath" "regexp" @@ -20,6 +22,7 @@ import ( "github.com/zclconf/go-cty/cty" "github.com/hashicorp/terraform/internal/addrs" + httpBackend "github.com/hashicorp/terraform/internal/backend/remote-state/http" "github.com/hashicorp/terraform/internal/command/arguments" "github.com/hashicorp/terraform/internal/command/views" "github.com/hashicorp/terraform/internal/configs" @@ -481,6 +484,84 @@ func TestInit_backend(t *testing.T) { } } +// regression test for https://github.com/hashicorp/terraform/issues/38027 +func TestInit_backend_migration_stateMgr_error(t *testing.T) { + // Create a temporary working directory that is empty + td := t.TempDir() + t.Chdir(td) + + { + // create some state in (implied) local backend + outputCfg := `output "test" { value = "test" } +` + if err := os.WriteFile("output.tf", []byte(outputCfg), 0644); err != nil { + t.Fatalf("err: %s", err) + } + + ui := new(cli.MockUi) + applyView, done := testView(t) + applyCmd := &ApplyCommand{ + Meta: Meta{ + Ui: ui, + View: applyView, + }, + } + code := applyCmd.Run([]string{"-auto-approve"}) + testOut := done(t) + if code != 0 { + t.Fatalf("bad: \n%s", testOut.All()) + } + + if _, err := os.Stat(DefaultStateFilename); err != nil { + t.Fatalf("err: %s", err) + } + } + { + // attempt to migrate the state to a broken backend + testBackend := new(httpBackend.TestHTTPBackend) + testBackend.SetMethodFunc("GET", func(w http.ResponseWriter, r *http.Request) { + // simulate "broken backend" in the way described in #38027 + // i.e. access denied + w.WriteHeader(403) + }) + ts := httptest.NewServer(http.HandlerFunc(testBackend.Handle)) + t.Cleanup(ts.Close) + + backendCfg := fmt.Sprintf(`terraform { + backend "http" { + address = %q + } +} +`, ts.URL) + if err := os.WriteFile("backend.tf", []byte(backendCfg), 0644); err != nil { + t.Fatalf("err: %s", err) + } + + ui := new(cli.MockUi) + initView, done := testView(t) + initCmd := &InitCommand{ + Meta: Meta{ + Ui: ui, + View: initView, + }, + } + code := initCmd.Run([]string{"-migrate-state"}) + out := done(t) + if code == 0 { + t.Fatalf("expected migration to fail (gracefully): %s", out.Stdout()) + } + expectedErrMsg := "HTTP remote state endpoint invalid auth" + if !strings.Contains(out.Stderr(), expectedErrMsg) { + t.Fatalf("expected error %q, given: %s", expectedErrMsg, out.Stderr()) + } + + getCalled := testBackend.CallCount("GET") + if getCalled != 1 { + t.Fatalf("expected GET to be called exactly %d, called %d times", 1, getCalled) + } + } +} + func TestInit_backendUnset(t *testing.T) { // Create a temporary working directory that is empty td := t.TempDir()