From 785d17ec7bb3b78099b686a51296f10c9069afa3 Mon Sep 17 00:00:00 2001
From: Will Scott
Date: Wed, 23 Oct 2024 08:25:11 +0900
Subject: [PATCH] Add ndjson semantics to delegated routing endpoint

Review amendments folded into this patch:
- drResp.append allocates seenProviders on first use; the original wrote to
  a nil map and panicked on the first provider.
- The ndjson handler flushes after every record it writes.
- doFindStreaming no longer cancels its context on return, records metrics
  when the stream ends, and stops if the reader goes away.
- drProvFromResult returns early instead of using if/else.

The hunk headers and diffstat are updated accordingly; the blob ids on the
index lines still refer to the unamended change.

---
 delegated_translator.go | 183 ++++++++++++++++++++++----------
 find_ndjson.go          | 225 ++++++++++++++++++++++++++++++++++++++++
 server.go               |   2 +-
 3 files changed, 354 insertions(+), 56 deletions(-)

diff --git a/delegated_translator.go b/delegated_translator.go
index 28bc9cd..7e914f8 100644
--- a/delegated_translator.go
+++ b/delegated_translator.go
@@ -22,9 +22,13 @@ const (
 )
 
 type findFunc func(ctx context.Context, method, source string, req *url.URL, encrypted bool) (int, []byte)
+// findStreamFunc is the streaming counterpart of findFunc: it returns an HTTP
+// status code and, when that code is http.StatusOK, a channel of provider
+// results to range over until it is closed.
+type findStreamFunc func(ctx context.Context, method string, req *url.URL, encrypted bool) (int, chan model.ProviderResult)
 
-func NewDelegatedTranslator(backend findFunc) (http.Handler, error) {
-	finder := delegatedTranslator{backend}
+func NewDelegatedTranslator(backend findFunc, streamingBackend findStreamFunc) (http.Handler, error) {
+	finder := delegatedTranslator{backend, streamingBackend}
 	m := http.NewServeMux()
 	m.HandleFunc("/providers", finder.provide)
 	m.HandleFunc("/encrypted/providers", finder.provide)
@@ -34,7 +38,8 @@ func NewDelegatedTranslator(backend findFunc) (http.Handler, error) {
 }
 
 type delegatedTranslator struct {
-	be findFunc
+	be  findFunc       // returns a complete response body
+	sbe findStreamFunc // returns a channel of provider results
 }
 
 func (dt *delegatedTranslator) provide(w http.ResponseWriter, r *http.Request) {
@@ -81,6 +86,57 @@ func (dt *delegatedTranslator) find(w http.ResponseWriter, r *http.Request, encr
 
 	// Translate URL by mapping `/providers/{CID}` to `/cid/{CID}`.
 	uri := r.URL.JoinPath("../../cid", cidUrlParam)
+
+	acc, err := getAccepts(r)
+	if err != nil {
+		http.Error(w, "invalid Accept header", http.StatusBadRequest)
+		return
+	}
+
+	switch {
+	case acc.ndjson:
+		// Stream one JSON document per line, dropping duplicates as results arrive.
+		rcode, respChan := dt.sbe(r.Context(), findMethodDelegated, uri, encrypted)
+		if rcode != http.StatusOK {
+			http.Error(w, "", rcode)
+			return
+		}
+		out := &drResp{}
+		hasWritten := false
+		encoder := json.NewEncoder(w)
+		flusher, _ := w.(http.Flusher)
+
+		for rcrd := range respChan {
+			if !hasWritten {
+				// The status is held back until the first record so that an
+				// empty stream can still be answered with a 404 below.
+				w.Header().Set("Content-Type", mediaTypeNDJson)
+				w.Header().Set("Connection", "Keep-Alive")
+				w.Header().Set("X-Content-Type-Options", "nosniff")
+				w.WriteHeader(200)
+				hasWritten = true
+			}
+			prov := drProvFromResult(rcrd)
+			// Only emit records that have not been sent already.
+			if out.append(prov) {
+				if err := encoder.Encode(prov); err != nil {
+					// Encoding or writing failed; give up on the stream.
+					return
+				}
+				if flusher != nil {
+					// Send the line now rather than when the buffer fills up.
+					flusher.Flush()
+				}
+			}
+		}
+		if len(out.seenProviders) == 0 {
+			// Nothing was streamed, so the status line is still unsent.
+			w.WriteHeader(http.StatusNotFound)
+		}
+		return
+	default:
+	}
+
 	rcode, resp := dt.be(r.Context(), http.MethodGet, findMethodDelegated, uri, encrypted)
 	if rcode != http.StatusOK {
 		http.Error(w, "", rcode)
@@ -105,63 +161,14 @@ func (dt *delegatedTranslator) find(w http.ResponseWriter, r *http.Request, encr
 
 	res := parsed.MultihashResults[0]
 
-	out := drResp{}
+	out := &drResp{}
 
 	// Records returned from IPNI via Delegated Routing don't have ContextID in them. Becuase of that,
 	// some records that are valid from the IPNI point of view might look like duplicates from the Delegated Routing point of view.
 	// To make the Delegated Routing output nicer, deduplicate identical records.
-	uniqueProviders := map[uint32]struct{}{}
-	appendIfUnique := func(drp *drProvider) {
-		capacity := len(drp.ID) + len(drp.Schema)
-		for _, proto := range drp.Protocols {
-			capacity += len(proto)
-		}
-		for _, meta := range drp.Metadata {
-			capacity += len(meta)
-		}
-		drpb := make([]byte, 0, capacity)
-		drpb = append(drpb, []byte(drp.ID)...)
-		for _, proto := range drp.Protocols {
-			drpb = append(drpb, []byte(proto)...)
-		}
-		drpb = append(drpb, []byte(drp.Schema)...)
-		for _, meta := range drp.Metadata {
-			drpb = append(drpb, meta...)
-		}
-		key := crc32.ChecksumIEEE(drpb)
-		if _, ok := uniqueProviders[key]; ok {
-			return
-		}
-		uniqueProviders[key] = struct{}{}
-		out.Providers = append(out.Providers, *drp)
-	}
 
 	for _, p := range res.ProviderResults {
-		md := metadata.Default.New()
-		err := md.UnmarshalBinary(p.Metadata)
-		if err != nil {
-			appendIfUnique(&drProvider{
-				Schema: peerSchema,
-				ID:     p.Provider.ID,
-				Addrs:  p.Provider.Addrs,
-			})
-		} else {
-			provider := &drProvider{
-				Schema:   peerSchema,
-				ID:       p.Provider.ID,
-				Addrs:    p.Provider.Addrs,
-				Metadata: make(map[string][]byte),
-			}
-
-			for _, proto := range md.Protocols() {
-				pl := md.Get(proto)
-				plb, _ := pl.MarshalBinary()
-				provider.Protocols = append(provider.Protocols, proto.String())
-				provider.Metadata[proto.String()] = plb
-			}
-
-			appendIfUnique(provider)
-		}
+		out.append(drProvFromResult(p))
 	}
 
 	outBytes, err := json.Marshal(out)
@@ -174,7 +181,43 @@ func (dt *delegatedTranslator) find(w http.ResponseWriter, r *http.Request, encr
 }
 
 type drResp struct {
-	Providers []drProvider
+	Providers     []drProvider
+	seenProviders map[uint32]struct{} // checksums of the records in Providers; allocated by append
+}
+
+// append adds drp to the response unless an earlier record produced the same
+// CRC32 checksum over its ID, protocols, schema and metadata (addresses are not
+// part of the key). It reports whether drp was added.
+func (dr *drResp) append(drp *drProvider) bool {
+	capacity := len(drp.ID) + len(drp.Schema)
+	for _, proto := range drp.Protocols {
+		capacity += len(proto)
+	}
+	for _, meta := range drp.Metadata {
+		capacity += len(meta)
+	}
+	drpb := make([]byte, 0, capacity)
+	drpb = append(drpb, []byte(drp.ID)...)
+	for _, proto := range drp.Protocols {
+		drpb = append(drpb, []byte(proto)...)
+	}
+	drpb = append(drpb, []byte(drp.Schema)...)
+	// NOTE(review): map iteration order is random, so with several metadata
+	// entries identical records can hash differently and escape deduplication.
+	for _, meta := range drp.Metadata {
+		drpb = append(drpb, meta...)
+	}
+	key := crc32.ChecksumIEEE(drpb)
+	if _, ok := dr.seenProviders[key]; ok {
+		return false
+	}
+	if dr.seenProviders == nil {
+		// drResp is created as a zero value; assigning into a nil map panics.
+		dr.seenProviders = make(map[uint32]struct{})
+	}
+	dr.seenProviders[key] = struct{}{}
+	dr.Providers = append(dr.Providers, *drp)
+	return true
 }
 
 type drProvider struct {
@@ -185,6 +228,36 @@ type drProvider struct {
 	Metadata  map[string][]byte
 }
 
+// drProvFromResult converts an IPNI provider result into its delegated routing
+// representation. When the result's metadata cannot be decoded the provider is
+// returned with only its schema, ID and addresses.
+func drProvFromResult(p model.ProviderResult) *drProvider {
+	md := metadata.Default.New()
+	if err := md.UnmarshalBinary(p.Metadata); err != nil {
+		return &drProvider{
+			Schema: peerSchema,
+			ID:     p.Provider.ID,
+			Addrs:  p.Provider.Addrs,
+		}
+	}
+
+	provider := &drProvider{
+		Schema:   peerSchema,
+		ID:       p.Provider.ID,
+		Addrs:    p.Provider.Addrs,
+		Metadata: make(map[string][]byte),
+	}
+
+	for _, proto := range md.Protocols() {
+		pl := md.Get(proto)
+		// The marshal error is ignored, as in the inline code this helper replaces.
+		plb, _ := pl.MarshalBinary()
+		provider.Protocols = append(provider.Protocols, proto.String())
+		provider.Metadata[proto.String()] = plb
+	}
+	return provider
+}
+
 func (dp drProvider) MarshalJSON() ([]byte, error) {
 	m := map[string]interface{}{}
 	if dp.Metadata != nil {
diff --git a/find_ndjson.go b/find_ndjson.go
index feff893..6f4e586 100644
--- a/find_ndjson.go
+++ b/find_ndjson.go
@@ -388,3 +388,228 @@ LOOP:
 	latencyTags = append(latencyTags, tag.Insert(metrics.FoundCaskade, yesno(foundCaskade)))
 	latencyTags = append(latencyTags, tag.Insert(metrics.FoundRegular, yesno(foundRegular)))
 }
+
+// doFindStreaming scatters a find request to the matching backends and streams
+// the de-duplicated provider results as they arrive.
+//
+// It returns http.StatusInternalServerError and a nil channel when the request
+// cannot be scattered. Otherwise it returns 200 and a channel that is closed
+// when gathering finishes or ctx is cancelled. The lookup keeps running after
+// this function returns, so callers must drain the channel or cancel ctx.
+func (s *server) doFindStreaming(ctx context.Context, method string, req *url.URL, encrypted bool) (int, chan model.ProviderResult) {
+	start := time.Now()
+	latencyTags := []tag.Mutator{tag.Insert(metrics.Method, http.MethodGet)}
+	loadTags := []tag.Mutator{tag.Insert(metrics.Method, method)}
+	// Metrics are recorded when the stream ends (see the output goroutine), not
+	// when this function returns, so the Found* tags added there are included.
+	recordStats := func() {
+		_ = stats.RecordWithOptions(context.Background(),
+			stats.WithTags(latencyTags...),
+			stats.WithMeasurements(metrics.FindLatency.M(float64(time.Since(start).Milliseconds()))))
+		_ = stats.RecordWithOptions(context.Background(),
+			stats.WithTags(loadTags...),
+			stats.WithMeasurements(metrics.FindLoad.M(1)))
+	}
+
+	maxWait := config.Server.ResultStreamMaxWait
+
+	sg := &scatterGather[Backend, any]{
+		backends: s.backends,
+		maxWait:  maxWait,
+	}
+
+	// ctx must outlive this call: the goroutines below keep using it, so cancel
+	// is deferred by the output goroutine rather than here.
+	ctx, cancel := context.WithCancel(ctx)
+
+	type resultWithBackend struct {
+		rslt *encryptedOrPlainResult
+		bknd Backend
+	}
+
+	resultsChan := make(chan *resultWithBackend, 1)
+	var count int32
+	if err := sg.scatter(ctx, func(cctx context.Context, b Backend) (*any, error) {
+		// forward double hashed requests to double hashed backends only and regular requests to regular backends
+		_, isDhBackend := b.(dhBackend)
+		_, isProvidersBackend := b.(providersBackend)
+		if (encrypted != isDhBackend) || isProvidersBackend {
+			return nil, nil
+		}
+
+		// Copy the URL from original request and override host/schema to point
+		// to the server.
+		endpoint := *req
+		endpoint.Host = b.URL().Host
+		endpoint.Scheme = b.URL().Scheme
+		log := log.With("backend", endpoint.Host)
+
+		req, err := http.NewRequestWithContext(cctx, http.MethodGet, endpoint.String(), nil)
+		if err != nil {
+			log.Warnw("Failed to construct backend query", "err", err)
+			return nil, err
+		}
+		req.Header.Set("X-Forwarded-Host", req.Host)
+		req.Header.Set("Accept", mediaTypeNDJson)
+
+		if !b.Matches(req) {
+			return nil, nil
+		}
+
+		resp, err := s.Client.Do(req)
+		if err != nil {
+			if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+				log.Debugw("Backend query ended", "err", err)
+			} else {
+				log.Warnw("Failed to query backend", "err", err)
+			}
+			return nil, err
+		}
+		defer resp.Body.Close()
+
+		switch resp.StatusCode {
+		case http.StatusOK:
+		case http.StatusNotFound:
+			io.Copy(io.Discard, resp.Body)
+			atomic.AddInt32(&count, 1)
+			return nil, nil
+		default:
+			bb, _ := io.ReadAll(resp.Body)
+			body := string(bb)
+			log := log.With("status", resp.StatusCode, "body", body)
+			log.Warn("Request processing was not successful")
+			err := fmt.Errorf("status %d response from backend %s", resp.StatusCode, b.URL().Host)
+			if resp.StatusCode < http.StatusInternalServerError {
+				err = circuitbreaker.MarkAsSuccess(err)
+			}
+			return nil, err
+		}
+
+		scanner := bufio.NewScanner(resp.Body)
+		for {
+			select {
+			case <-cctx.Done():
+				return nil, nil
+			default:
+				if scanner.Scan() {
+					var result encryptedOrPlainResult
+					line := scanner.Bytes()
+					if len(line) == 0 {
+						continue
+					}
+					atomic.AddInt32(&count, 1)
+					if err := json.Unmarshal(line, &result); err != nil {
+						return nil, circuitbreaker.MarkAsSuccess(err)
+					}
+					// Sanity check the results in case backends don't respect accept media types;
+					// see: https://github.com/ipni/storetheindex/issues/1209
+					if len(result.EncryptedValueKey) == 0 && (result.Provider.ID == "" || len(result.Provider.Addrs) == 0) {
+						continue
+					}
+
+					select {
+					case <-cctx.Done():
+						return nil, nil
+					case resultsChan <- &resultWithBackend{rslt: &result, bknd: b}:
+					}
+					continue
+				}
+				if err := scanner.Err(); err != nil {
+					if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+						log.Debugw("Reading backend response ended", "err", err)
+					} else {
+						log.Warnw("Failed to read backend response", "err", err)
+					}
+
+					return nil, circuitbreaker.MarkAsSuccess(err)
+				}
+				return nil, nil
+			}
+		}
+	}); err != nil {
+		log.Errorw("Failed to scatter HTTP find request", "err", err)
+		cancel()
+		recordStats()
+		return http.StatusInternalServerError, nil
+	}
+
+	out := make(chan model.ProviderResult)
+
+	// Results chan is done when gathering is finished.
+	// Do this in a separate goroutine to avoid potentially closing results chan twice.
+	go func() {
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case _, ok := <-sg.gather(ctx):
+				if !ok {
+					close(resultsChan)
+					return
+				}
+			}
+		}
+	}()
+
+	go func() {
+		// Deferred calls run in reverse: cancel the context, record metrics, close out.
+		defer close(out)
+		defer recordStats()
+		defer cancel()
+
+		results := newResultSet()
+		var rs resultStats
+		var foundCaskade, foundRegular bool
+	LOOP:
+		for {
+			select {
+			case <-ctx.Done():
+				break LOOP
+			case rwb, ok := <-resultsChan:
+				if !ok {
+					break LOOP
+				}
+				result := rwb.rslt
+				absent := results.putIfAbsent(result)
+				if !absent {
+					continue
+				}
+
+				rs.observeResult(result)
+
+				_, isCaskade := rwb.bknd.(caskadeBackend)
+				foundCaskade = foundCaskade || isCaskade
+				foundRegular = foundRegular || !isCaskade
+
+				// Stop instead of blocking forever if the reader has gone away.
+				select {
+				case <-ctx.Done():
+					break LOOP
+				case out <- result.ProviderResult:
+				}
+			}
+		}
+		_ = stats.RecordWithOptions(context.Background(),
+			stats.WithMeasurements(metrics.FindBackends.M(float64(atomic.LoadInt32(&count)))))
+
+		if len(results) == 0 {
+			latencyTags = append(latencyTags, tag.Insert(metrics.Found, "no"))
+			return
+		}
+
+		rs.reportMetrics(method)
+
+		latencyTags = append(latencyTags, tag.Insert(metrics.Found, "yes"))
+		yesno := func(yn bool) string {
+			if yn {
+				return "yes"
+			}
+			return "no"
+		}
+
+		latencyTags = append(latencyTags, tag.Insert(metrics.FoundCaskade, yesno(foundCaskade)))
+		latencyTags = append(latencyTags, tag.Insert(metrics.FoundRegular, yesno(foundRegular)))
+	}()
+
+	return 200, out
+}
diff --git a/server.go b/server.go
index 9c8f403..f55f4f0 100644
--- a/server.go
+++ b/server.go
@@ -246,7 +246,7 @@ func (s *server) Serve() chan error {
 	mux.HandleFunc("/health", s.health)
 
 	ec := make(chan error)
-	delegated, err := NewDelegatedTranslator(s.doFind)
+	delegated, err := NewDelegatedTranslator(s.doFind, s.doFindStreaming)
 	if err != nil {
 		ec <- err
 		close(ec)