From 7b7eb1dd4a1429531d0a2cdfacbfb811fd530c13 Mon Sep 17 00:00:00 2001 From: lorneli Date: Mon, 6 Nov 2017 22:16:09 +0800 Subject: [PATCH] http: support QueryUnescape in HTTPPool.ServeHTTP HTTPPool will inverse escaped path if query parameter contains "escaped=true". Add keys with char to be escaped in test. --- http.go | 36 +++++++++++++++++++++++++++++++----- http_test.go | 8 +++++--- 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/http.go b/http.go index f37467a7..76083072 100644 --- a/http.go +++ b/http.go @@ -138,19 +138,45 @@ func (p *HTTPPool) PickPeer(key string) (ProtoGetter, bool) { return nil, false } -func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Parse request. +func (p *HTTPPool) parseRequest(r *http.Request) (string, string, bool) { if !strings.HasPrefix(r.URL.Path, p.opts.BasePath) { panic("HTTPPool serving unexpected path: " + r.URL.Path) } parts := strings.SplitN(r.URL.Path[len(p.opts.BasePath):], "/", 2) if len(parts) != 2 { - http.Error(w, "bad request", http.StatusBadRequest) - return + return "", "", false } groupName := parts[0] key := parts[1] + queries, err := url.ParseQuery(r.URL.RawQuery) + if err != nil { + // Still accept groupName and key in path. + return groupName, key, true + } + + var uerr error + if queries.Get("escaped") == "true" { + groupName, uerr = url.QueryUnescape(groupName) + if uerr != nil { + return "", "", false + } + key, uerr = url.QueryUnescape(key) + if uerr != nil { + return "", "", false + } + } + + return groupName, key, true +} + +func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) { + groupName, key, ok := p.parseRequest(r) + if !ok { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + // Fetch the value for this group/key. group := GetGroup(groupName) if group == nil { @@ -191,7 +217,7 @@ var bufferPool = sync.Pool{ func (h *httpGetter) Get(context Context, in *pb.GetRequest, out *pb.GetResponse) error { u := fmt.Sprintf( - "%v%v/%v", + "%v%v/%v?escaped=true", h.baseURL, url.QueryEscape(in.GetGroup()), url.QueryEscape(in.GetKey()), diff --git a/http_test.go b/http_test.go index b42edd7f..977c9d5f 100644 --- a/http_test.go +++ b/http_test.go @@ -103,9 +103,11 @@ func TestHTTPPool(t *testing.T) { } func testKeys(n int) (keys []string) { - keys = make([]string, n) - for i := range keys { - keys[i] = strconv.Itoa(i) + keys = make([]string, 0) + for i := 0; i < n; i++ { + keys = append(keys, strconv.Itoa(i)) + // Keys with char to be escaped + keys = append(keys, " "+strconv.Itoa(i)) } return }