From 548b6b1bdb352375a3f187ee4d2c8b2b6c6dce1a Mon Sep 17 00:00:00 2001 From: Philip Laine Date: Fri, 12 Apr 2024 17:11:55 +0200 Subject: [PATCH] Replace http util reverese proxy with custom request forwarding --- CHANGELOG.md | 2 ++ pkg/registry/registry.go | 56 ++++++++++++++++++++++------------------ 2 files changed, 33 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e51f33e1..124f638b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- [#436](https://github.com/spegel-org/spegel/pull/436) Replace http util reverese proxy with custom request forwarding. + ### Deprecated ### Removed diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go index fee9bcb3..b8a36f39 100644 --- a/pkg/registry/registry.go +++ b/pkg/registry/registry.go @@ -6,7 +6,6 @@ import ( "io" "net" "net/http" - "net/http/httputil" "net/url" "path" "strconv" @@ -32,7 +31,7 @@ type Registry struct { throttler *throttle.Throttler ociClient oci.Client router routing.Router - transport http.RoundTripper + httpClient *http.Client localAddr string resolveRetries int resolveTimeout time.Duration @@ -61,7 +60,7 @@ func WithResolveTimeout(resolveTimeout time.Duration) Option { func WithTransport(transport http.RoundTripper) Option { return func(r *Registry) { - r.transport = transport + r.httpClient.Transport = transport } } @@ -87,6 +86,7 @@ func NewRegistry(ociClient oci.Client, router routing.Router, opts ...Option) *R r := &Registry{ ociClient: ociClient, router: router, + httpClient: &http.Client{}, resolveRetries: 3, resolveTimeout: 1 * time.Second, resolveLatestTag: true, @@ -184,10 +184,8 @@ func (r *Registry) registryHandler(rw mux.ResponseWriter, req *http.Request) { } } - // Request with mirror header are proxied. + // Requests without mirror header set will be mirrored if req.Header.Get(MirroredHeaderKey) != "true" { - // Set mirrored header in request to stop infinite loops - req.Header.Set(MirroredHeaderKey, "true") key := dgst.String() if key == "" { key = ref @@ -241,6 +239,7 @@ func (r *Registry) handleMirror(rw mux.ResponseWriter, req *http.Request, key st rw.WriteError(http.StatusInternalServerError, err) return } + // TODO: Refactor context cancel and mirror channel closing for { select { @@ -255,36 +254,43 @@ func (r *Registry) handleMirror(rw mux.ResponseWriter, req *http.Request, key st return } - // Modify response returns and error on non 200 status code and NOP error handler skips response writing. - // If proxy fails no response is written and it is tried again against a different mirror. - // If the response writer has been written to it means that the request was properly proxied. - succeeded := false scheme := "http" if req.TLS != nil { scheme = "https" } - u := &url.URL{ + u := url.URL{ Scheme: scheme, Host: ipAddr.String(), + Path: req.URL.Path, + // TODO: Should this error early if not set? + RawQuery: fmt.Sprintf("ns=%s", req.URL.Query().Get("ns")), } - proxy := httputil.NewSingleHostReverseProxy(u) - proxy.Transport = r.transport - proxy.ErrorHandler = func(_ http.ResponseWriter, _ *http.Request, err error) { - log.Error(err, "proxy failed attempting next") + forwardReq, err := http.NewRequestWithContext(req.Context(), req.Method, u.String(), nil) + if err != nil { + rw.WriteError(http.StatusInternalServerError, err) + return } - proxy.ModifyResponse = func(resp *http.Response) error { - if resp.StatusCode != http.StatusOK { - err := fmt.Errorf("expected mirror to respond with 200 OK but received: %s", resp.Status) - log.Error(err, "mirror failed attempting next") - return err - } - succeeded = true - return nil + forwardReq.Header.Add(MirroredHeaderKey, "true") + resp, err := r.httpClient.Do(forwardReq) + if err != nil { + log.Error(err, "mirror failed attempting next") + break } - proxy.ServeHTTP(rw, req) - if !succeeded { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + log.Error(fmt.Errorf("expected mirror to respond with 200 OK but received: %s", resp.Status), "mirror failed attempting next") break } + for k, v := range resp.Header { + for _, vv := range v { + rw.Header().Add(k, vv) + } + } + _, err = io.Copy(rw, resp.Body) + if err != nil { + rw.WriteError(http.StatusInternalServerError, err) + return + } log.V(5).Info("mirrored request", "url", u.String()) return }