Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

loader: allow overwrite of URL hostname again #844

Merged
merged 1 commit into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions internal/loader/artifact_url.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"os"

"github.com/hashicorp/go-retryablehttp"
digestlib "github.com/opencontainers/go-digest"
Expand All @@ -32,6 +34,13 @@ import (
"helm.sh/helm/v3/pkg/chart/loader"
)

const (
// envSourceControllerLocalhost is the name of the environment variable
// used to override the hostname of the source-controller from which
// the chart is usually downloaded.
envSourceControllerLocalhost = "SOURCE_CONTROLLER_LOCALHOST"
)

var (
// ErrFileNotFound is an error type used to signal 404 HTTP status code responses.
ErrFileNotFound = errors.New("file not found")
Expand All @@ -45,6 +54,11 @@ var (
// digest before loading the chart. It returns the loaded chart.Chart, or an
// error. The error may be of type ErrIntegrity if the integrity check fails.
func SecureLoadChartFromURL(client *retryablehttp.Client, URL, digest string) (*chart.Chart, error) {
URL, err := overwriteHostname(URL, os.Getenv(envSourceControllerLocalhost))
if err != nil {
return nil, err
}

req, err := retryablehttp.NewRequest(http.MethodGet, URL, nil)
if err != nil {
return nil, err
Expand Down Expand Up @@ -94,3 +108,18 @@ func copyAndVerify(digest string, reader io.Reader, writer io.Writer) error {
}
return nil
}

// overwriteHostname overwrites the hostname of the given URL with the given
// hostname. If the hostname is empty, the URL is returned unmodified.
func overwriteHostname(URL, hostname string) (string, error) {
if hostname == "" {
return URL, nil
}

u, err := url.Parse(URL)
if err != nil {
return "", fmt.Errorf("failed to parse URL to overwrite hostname: %w", err)
}
u.Host = hostname
return u.String(), nil
}
55 changes: 55 additions & 0 deletions internal/loader/artifact_url_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"

"github.com/hashicorp/go-retryablehttp"
Expand Down Expand Up @@ -72,6 +73,19 @@ func TestSecureLoadChartFromURL(t *testing.T) {
g.Expect(got.Metadata.Version).To(Equal("0.1.0"))
})

t.Run("overwrites hostname", func(t *testing.T) {
g := NewWithT(t)

t.Setenv(envSourceControllerLocalhost, strings.TrimPrefix(server.URL, "http://"))
wrongHostnameURL := "http://invalid.com" + chartPath

got, err := SecureLoadChartFromURL(client, wrongHostnameURL, digest.String())
g.Expect(err).ToNot(HaveOccurred())
g.Expect(got).ToNot(BeNil())
g.Expect(got.Name()).To(Equal("chart"))
g.Expect(got.Metadata.Version).To(Equal("0.1.0"))
})

t.Run("error on chart data digest mismatch", func(t *testing.T) {
g := NewWithT(t)

Expand Down Expand Up @@ -162,3 +176,44 @@ func Test_copyAndVerify(t *testing.T) {
})
}
}

func Test_overwriteHostname(t *testing.T) {
tests := []struct {
name string
URL string
hostname string
want string
wantErr bool
}{
{
name: "overwrite hostname",
URL: "http://example.com",
hostname: "localhost",
want: "http://localhost",
},
{
name: "overwrite hostname with port",
URL: "http://example.com",
hostname: "localhost:9090",
want: "http://localhost:9090",
},
{
name: "no hostname",
URL: "http://example.com",
hostname: "",
want: "http://example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := overwriteHostname(tt.URL, tt.hostname)
if (err != nil) != tt.wantErr {
t.Errorf("overwriteHostname() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("overwriteHostname() got = %v, want %v", got, tt.want)
}
})
}
}