From 43eb24b79ae8e921df813389cbb9dc36aa87942f Mon Sep 17 00:00:00 2001 From: Daniel N <2color@users.noreply.github.com> Date: Thu, 12 Sep 2024 17:52:35 +0200 Subject: [PATCH] feat: add protocol filtering implements https://github.com/ipfs/specs/pull/484 --- routing/http/server/filters.go | 143 +++++++++++++ routing/http/server/filters_test.go | 309 ++++++++++++++++++++++++++++ routing/http/server/server.go | 84 +++++++- routing/http/server/server_test.go | 68 ++++-- 4 files changed, 581 insertions(+), 23 deletions(-) create mode 100644 routing/http/server/filters.go create mode 100644 routing/http/server/filters_test.go diff --git a/routing/http/server/filters.go b/routing/http/server/filters.go new file mode 100644 index 000000000..e9411a00a --- /dev/null +++ b/routing/http/server/filters.go @@ -0,0 +1,143 @@ +package server + +import ( + "reflect" + "slices" + "strings" + + "github.com/ipfs/boxo/routing/http/types" + "github.com/multiformats/go-multiaddr" +) + +// filters implements IPIP-0484 + +func parseFilter(param string) []string { + if param == "" { + return nil + } + return strings.Split(strings.ToLower(param), ",") +} + +func filterProviders(providers []types.Record, filterAddrs, filterProtocols []string) []types.Record { + if len(filterAddrs) == 0 && len(filterProtocols) == 0 { + return providers + } + + filtered := make([]types.Record, 0, len(providers)) + + for _, provider := range providers { + if schema := provider.GetSchema(); schema == types.SchemaPeer { + peer, ok := provider.(*types.PeerRecord) + if !ok { + logger.Errorw("problem casting find providers result", "Schema", provider.GetSchema(), "Type", reflect.TypeOf(provider).String()) + // if the type assertion fails, we exlude record from results + continue + } + + record := applyFilters(peer, filterAddrs, filterProtocols) + + if record != nil { + filtered = append(filtered, record) + } + + } else { + // Will we ever encounter the SchemaBitswap type? Evidence seems to suggest that no longer + logger.Errorw("encountered unknown provider schema", "Schema", provider.GetSchema(), "Type", reflect.TypeOf(provider).String()) + } + } + return filtered +} + +// Applies the filters. Returns nil if the provider does not pass the protocols filter +// The address filter is more complicated because it potentially modifies the Addrs slice. +func applyFilters(provider *types.PeerRecord, filterAddrs, filterProtocols []string) *types.PeerRecord { + if !applyProtocolFilter(provider.Protocols, filterProtocols) { + // If the provider doesn't match any of the passed protocols, the provider is omitted from the response. + return nil + } + + // return untouched if there's no filter or filterAddrsQuery contains "unknown" and provider has no addrs + if len(filterAddrs) == 0 || (len(provider.Addrs) == 0 && slices.Contains(filterAddrs, "unknown")) { + return provider + } + + filteredAddrs := applyAddrFilter(provider.Addrs, filterAddrs) + + // If filtering resulted in no addrs, omit the provider + if len(filteredAddrs) == 0 { + return nil + } + + provider.Addrs = filteredAddrs + return provider +} + +// If there are only negative filters, no addresses will be included in the result. The function will return an empty list. +// For an address to be included, it must pass all negative filters AND match at least one positive filter. +func applyAddrFilter(addrs []types.Multiaddr, filterAddrsQuery []string) []types.Multiaddr { + if len(filterAddrsQuery) == 0 { + return addrs + } + + filteredAddrs := make([]types.Multiaddr, 0, len(addrs)) + + for _, addr := range addrs { + protocols := addr.Protocols() + includeAddr := true + + // First, check all negative filters + for _, filter := range filterAddrsQuery { + if strings.HasPrefix(filter, "!") { + protocolStringFromFilter := strings.TrimPrefix(filter, "!") + protocolFromFilter := multiaddr.ProtocolWithName(protocolStringFromFilter) + if containsProtocol(protocols, protocolFromFilter) { + includeAddr = false + break + } + } + } + + // If the address passed all negative filters, check positive filters + if includeAddr { + for _, filter := range filterAddrsQuery { + if !strings.HasPrefix(filter, "!") { + protocolFromFilter := multiaddr.ProtocolWithName(filter) + if containsProtocol(protocols, protocolFromFilter) { + filteredAddrs = append(filteredAddrs, addr) + break + } + } + } + } + } + return filteredAddrs +} + +func containsProtocol(protos []multiaddr.Protocol, proto multiaddr.Protocol) bool { + for _, p := range protos { + if p.Code == proto.Code { + return true + } + } + return false +} + +func applyProtocolFilter(peerProtocols []string, filterProtocols []string) bool { + if len(filterProtocols) == 0 { + // If no filter is passed, do not filter + return true + } + + for _, filter := range filterProtocols { + filterProtocol := strings.TrimPrefix(filter, "!") + + if filterProtocol == "unknown" && len(peerProtocols) == 0 { + return true + } + + for _, peerProtocol := range peerProtocols { + return peerProtocol == filterProtocol + } + } + return false +} diff --git a/routing/http/server/filters_test.go b/routing/http/server/filters_test.go new file mode 100644 index 000000000..85f16380c --- /dev/null +++ b/routing/http/server/filters_test.go @@ -0,0 +1,309 @@ +package server + +import ( + "testing" + + "github.com/ipfs/boxo/routing/http/types" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestApplyAddrFilter(t *testing.T) { + // Create some test multiaddrs + addr1, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/tcp/4001/p2p/QmcZf59bWwK5XFi76CZX8cbJ4BhTzzA3gU1ZjYZcYW3dwt") + addr2, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/udp/4001/quic/p2p/QmcZf59bWwK5XFi76CZX8cbJ4BhTzzA3gU1ZjYZcYW3dwt") + addr3, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/tcp/4001/ws/p2p/QmcZf59bWwK5XFi76CZX8cbJ4BhTzzA3gU1ZjYZcYW3dwt") + addr4, _ := multiaddr.NewMultiaddr("/ip4/102.101.1.1/tcp/4001/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit/p2p/QmcZf59bWwK5XFi76CZX8cbJ4BhTzzA3gU1ZjYZcYW3dwt") + addr5, _ := multiaddr.NewMultiaddr("/ip4/102.101.1.1/udp/4001/quic-v1/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit/p2p/QmcZf59bWwK5XFi76CZX8cbJ4BhTzzA3gU1ZjYZcYW3dwt") + addr6, _ := multiaddr.NewMultiaddr("/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/certhash/uEiD9f05PrY82lovP4gOFonmY7sO0E7_jyovt9p2LEcAS-Q/certhash/uEiBtGJsNz-PcywwXOVzEYeQQloQiHMqDqdj18t2Fe4GTLQ/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit/p2p/QmcZf59bWwK5XFi76CZX8cbJ4BhTzzA3gU1ZjYZcYW3dwt") + addr7, _ := multiaddr.NewMultiaddr("/dns4/ny5.bootstrap.libp2p.io/tcp/443/wss/p2p/QmcZf59bWwK5XFi76CZX8cbJ4BhTzzA3gU1ZjYZcYW3dwt") + addr8, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/udp/4001/quic-v1/webtransport/certhash/uEiAMrMcVWFNiqtSeRXZTwHTac4p9WcGh5hg8kVBzTC1JTA/certhash/uEiA4dfvbbbnBIYalhp1OpW1Bk-nuWIKSy21ol6vPea67Cw/p2p/QmcZf59bWwK5XFi76CZX8cbJ4BhTzzA3gU1ZjYZcYW3dwt") + + addrs := []types.Multiaddr{ + {Multiaddr: addr1}, + {Multiaddr: addr2}, + {Multiaddr: addr3}, + {Multiaddr: addr4}, + {Multiaddr: addr5}, + {Multiaddr: addr6}, + {Multiaddr: addr7}, + {Multiaddr: addr8}, + } + + testCases := []struct { + name string + filterAddrs []string + expectedAddrs []types.Multiaddr + }{ + { + name: "No filter", + filterAddrs: []string{}, + expectedAddrs: addrs, + }, + { + name: "Filter TCP", + filterAddrs: []string{"tcp"}, + expectedAddrs: []types.Multiaddr{{Multiaddr: addr1}, {Multiaddr: addr3}, {Multiaddr: addr4}, {Multiaddr: addr7}}, + }, + { + name: "Filter UDP", + filterAddrs: []string{"udp"}, + expectedAddrs: []types.Multiaddr{{Multiaddr: addr2}, {Multiaddr: addr5}, {Multiaddr: addr6}, {Multiaddr: addr8}}, + }, + { + name: "Filter WebSocket", + filterAddrs: []string{"ws"}, + expectedAddrs: []types.Multiaddr{{Multiaddr: addr3}}, + }, + { + name: "Exclude TCP", + filterAddrs: []string{"!tcp"}, + expectedAddrs: []types.Multiaddr{}, + }, + { + name: "Include WebTransport and exclude p2p-circuit", + filterAddrs: []string{"webtransport", "!p2p-circuit"}, + expectedAddrs: []types.Multiaddr{{Multiaddr: addr8}}, + }, + { + name: "empty for unknown protocol nae", + filterAddrs: []string{"fakeproto"}, + expectedAddrs: []types.Multiaddr{}, + }, + { + name: "Include WebTransport but ignore unknown protocol name", + filterAddrs: []string{"webtransport", "fakeproto"}, + expectedAddrs: []types.Multiaddr{{Multiaddr: addr6}, {Multiaddr: addr8}}, + }, + { + name: "Multiple filters", + filterAddrs: []string{"tcp", "ws"}, + expectedAddrs: []types.Multiaddr{{Multiaddr: addr1}, {Multiaddr: addr3}, {Multiaddr: addr4}, {Multiaddr: addr7}}, + }, + { + name: "Multiple negative filters", + filterAddrs: []string{"!tcp", "!ws"}, + expectedAddrs: []types.Multiaddr{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := applyAddrFilter(addrs, tc.filterAddrs) + assert.Equal(t, len(tc.expectedAddrs), len(result), "Unexpected number of addresses after filtering") + + // Check that each expected address is in the result + for _, expectedAddr := range tc.expectedAddrs { + found := false + for _, resultAddr := range result { + if expectedAddr.Multiaddr.Equal(resultAddr.Multiaddr) { + found = true + break + } + } + assert.True(t, found, "Expected address not found in test %s result: %s", tc.name, expectedAddr.Multiaddr) + } + + // Check that each result address is in the expected list + for _, resultAddr := range result { + found := false + for _, expectedAddr := range tc.expectedAddrs { + if resultAddr.Multiaddr.Equal(expectedAddr.Multiaddr) { + found = true + break + } + } + assert.True(t, found, "Unexpected address found in test %s result: %s", tc.name, resultAddr.Multiaddr) + } + }) + } +} + +func TestApplyProtocolFilter(t *testing.T) { + testCases := []struct { + name string + peerProtocols []string + filterProtocols []string + expected bool + }{ + { + name: "No filter", + peerProtocols: []string{"transport-bitswap", "transport-ipfs-gateway-http"}, + filterProtocols: []string{}, + expected: true, + }, + { + name: "Single matching protocol", + peerProtocols: []string{"transport-bitswap", "transport-ipfs-gateway-http"}, + filterProtocols: []string{"transport-bitswap"}, + expected: true, + }, + { + name: "Single non-matching protocol", + peerProtocols: []string{"transport-bitswap", "transport-ipfs-gateway-http"}, + filterProtocols: []string{"transport-graphsync-filecoinv1"}, + expected: false, + }, + { + name: "Multiple protocols, one match", + peerProtocols: []string{"transport-bitswap", "transport-ipfs-gateway-http"}, + filterProtocols: []string{"transport-graphsync-filecoinv1", "transport-ipfs-gateway-http"}, + expected: true, + }, + { + name: "Negative filter, no match", + peerProtocols: []string{"transport-bitswap", "transport-ipfs-gateway-http"}, + filterProtocols: []string{"!transport-graphsync-filecoinv1"}, + expected: true, + }, + { + name: "Negative filter, with match", + peerProtocols: []string{"transport-bitswap", "transport-ipfs-gateway-http"}, + filterProtocols: []string{"!transport-ipfs-gateway-http"}, + expected: false, + }, + { + name: "Mixed positive and negative filters, no match", + peerProtocols: []string{"transport-bitswap", "transport-ipfs-gateway-http"}, + filterProtocols: []string{"transport-graphsync-filecoinv1", "!transport-ipfs-gateway-http"}, + expected: false, + }, + { + name: "Unknown protocol for empty peer protocols", + peerProtocols: []string{}, + filterProtocols: []string{"unknown"}, + expected: true, + }, + { + // TODO: Does this case make sense? + name: "Unknown protocol for non-empty peer protocols", + peerProtocols: []string{"transport-bitswap"}, + filterProtocols: []string{"unknown"}, + expected: false, + }, + { + name: "Case insensitive match", + peerProtocols: []string{"TRANSPORT-BITSWAP", "Transport-IPFS-Gateway-HTTP"}, + filterProtocols: []string{"transport-bitswap"}, + expected: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := applyProtocolFilter(tc.peerProtocols, tc.filterProtocols) + assert.Equal(t, tc.expected, result, "Unexpected result for test case: %s", tc.name) + }) + } +} + +func TestApplyFilters(t *testing.T) { + pid, err := peer.Decode("12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn") + require.NoError(t, err) + + tests := []struct { + name string + provider *types.PeerRecord + filterAddrs []string + filterProtocols []string + expected *types.PeerRecord + }{ + { + name: "No filters", + provider: &types.PeerRecord{ + ID: &pid, + Addrs: []types.Multiaddr{ + mustMultiaddr(t, "/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit"), + mustMultiaddr(t, "/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"), + }, + Protocols: []string{"transport-ipfs-gateway-http"}, + }, + filterAddrs: []string{}, + filterProtocols: []string{}, + expected: &types.PeerRecord{ + ID: &pid, + Addrs: []types.Multiaddr{ + mustMultiaddr(t, "/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit"), + mustMultiaddr(t, "/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"), + }, + Protocols: []string{"transport-ipfs-gateway-http"}, + }, + }, + { + name: "Protocol filter", + provider: &types.PeerRecord{ + ID: &pid, + Addrs: []types.Multiaddr{ + mustMultiaddr(t, "/ip4/127.0.0.1/tcp/4001"), + mustMultiaddr(t, "/ip4/127.0.0.1/udp/4001/quic-v1"), + mustMultiaddr(t, "/ip4/127.0.0.1/tcp/4001/ws"), + mustMultiaddr(t, "/ip4/102.101.1.1/tcp/4001/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit"), + mustMultiaddr(t, "/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit"), + mustMultiaddr(t, "/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"), + }, + Protocols: []string{"transport-ipfs-gateway-http"}, + }, + filterAddrs: []string{}, + filterProtocols: []string{"transport-ipfs-gateway-http", "transport-bitswap"}, + expected: &types.PeerRecord{ + ID: &pid, + Addrs: []types.Multiaddr{ + mustMultiaddr(t, "/ip4/127.0.0.1/tcp/4001"), + mustMultiaddr(t, "/ip4/127.0.0.1/udp/4001/quic-v1"), + mustMultiaddr(t, "/ip4/127.0.0.1/tcp/4001/ws"), + mustMultiaddr(t, "/ip4/102.101.1.1/tcp/4001/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit"), + mustMultiaddr(t, "/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit"), + mustMultiaddr(t, "/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"), + }, + Protocols: []string{"transport-ipfs-gateway-http"}, + }, + }, + { + name: "Address filter", + provider: &types.PeerRecord{ + ID: &pid, + Addrs: []types.Multiaddr{ + mustMultiaddr(t, "/ip4/127.0.0.1/tcp/4001"), + mustMultiaddr(t, "/ip4/127.0.0.1/udp/4001/quic-v1"), + mustMultiaddr(t, "/ip4/127.0.0.1/tcp/4001/ws"), + mustMultiaddr(t, "/ip4/127.0.0.1/udp/4001/webrtc-direct/certhash/uEiCZqN653gMqxrWNmYuNg7Emwb-wvtsuzGE3XD6rypViZA"), + mustMultiaddr(t, "/ip4/102.101.1.1/tcp/4001/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit"), + mustMultiaddr(t, "/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit"), + mustMultiaddr(t, "/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"), + }, + Protocols: []string{"transport-ipfs-gateway-http"}, + }, + filterAddrs: []string{"webtransport", "wss", "webrtc-direct", "!p2p-circuit"}, + filterProtocols: []string{"transport-ipfs-gateway-http", "transport-bitswap"}, + expected: &types.PeerRecord{ + ID: &pid, + Addrs: []types.Multiaddr{ + mustMultiaddr(t, "/ip4/127.0.0.1/tcp/4001"), + mustMultiaddr(t, "/ip4/127.0.0.1/udp/4001/quic-v1"), + mustMultiaddr(t, "/ip4/127.0.0.1/tcp/4001/ws"), + mustMultiaddr(t, "/ip4/102.101.1.1/tcp/4001/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit"), + mustMultiaddr(t, "/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit"), + mustMultiaddr(t, "/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"), + }, + Protocols: []string{"transport-ipfs-gateway-http"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := applyFilters(tt.provider, tt.filterAddrs, tt.filterProtocols) + assert.Equal(t, tt.expected, result) + }) + } +} + +func mustMultiaddr(t *testing.T, s string) types.Multiaddr { + addr, err := multiaddr.NewMultiaddr(s) + if err != nil { + t.Fatalf("Failed to create multiaddr: %v", err) + } + return types.Multiaddr{Multiaddr: addr} +} diff --git a/routing/http/server/server.go b/routing/http/server/server.go index 1e1a84770..20d7e6fec 100644 --- a/routing/http/server/server.go +++ b/routing/http/server/server.go @@ -194,6 +194,11 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { return } + // Parse query parameters + query := httpReq.URL.Query() + filterAddrs := parseFilter(query.Get("filter-addrs")) + filterProtocols := parseFilter(query.Get("filter-protocols")) + mediaType, err := s.detectResponseType(httpReq) if err != nil { writeErr(w, "FindProviders", http.StatusBadRequest, err) @@ -201,7 +206,7 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { } var ( - handlerFunc func(w http.ResponseWriter, provIter iter.ResultIter[types.Record]) + handlerFunc func(w http.ResponseWriter, provIter iter.ResultIter[types.Record], filterAddrs, filterProtocols []string) recordsLimit int ) @@ -224,10 +229,10 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) { } } - handlerFunc(w, provIter) + handlerFunc(w, provIter, filterAddrs, filterProtocols) } -func (s *server) findProvidersJSON(w http.ResponseWriter, provIter iter.ResultIter[types.Record]) { +func (s *server) findProvidersJSON(w http.ResponseWriter, provIter iter.ResultIter[types.Record], filterAddrs, filterProtocols []string) { defer provIter.Close() providers, err := iter.ReadAllResults(provIter) @@ -236,13 +241,78 @@ func (s *server) findProvidersJSON(w http.ResponseWriter, provIter iter.ResultIt return } + filteredProviders := filterProviders(providers, filterAddrs, filterProtocols) + writeJSONResult(w, "FindProviders", jsontypes.ProvidersResponse{ - Providers: providers, + Providers: filteredProviders, }) } +func (s *server) findProvidersNDJSON(w http.ResponseWriter, provIter iter.ResultIter[types.Record], filterAddrs, filterProtocols []string) { + defer provIter.Close() + + w.Header().Set("Content-Type", mediaTypeNDJSON) + w.Header().Add("Vary", "Accept") + w.Header().Set("Last-Modified", time.Now().UTC().Format(http.TimeFormat)) + + hasResults := false + for provIter.Next() { + res := provIter.Val() + if res.Err != nil { + logger.Errorw("ndjson iterator error", "Error", res.Err) + return + } + + // handle filtering per record as we iterate + if len(filterAddrs) > 0 || len(filterProtocols) > 0 { + switch v := res.Val.(type) { + case *types.PeerRecord: + record := applyFilters(v, filterAddrs, filterProtocols) + if record == nil { + // if the record is nil, we skip it + continue + } + res.Val = record + default: + logger.Warn("unexpected type for res.Val, expected types.PeerRecord") + continue + } + } + + // don't use an encoder because we can't easily differentiate writer errors from encoding errors + b, err := drjson.MarshalJSONBytes(res.Val) + if err != nil { + logger.Errorw("ndjson marshal error", "Error", err) + return + } + + if !hasResults { + hasResults = true + // There's results, cache useful result for longer + setCacheControl(w, maxAgeWithResults, maxStale) + } + + _, err = w.Write(b) + if err != nil { + logger.Warn("ndjson write error", "Error", err) + return + } + + _, err = w.Write([]byte{'\n'}) + if err != nil { + logger.Warn("ndjson write error", "Error", err) + return + } -func (s *server) findProvidersNDJSON(w http.ResponseWriter, provIter iter.ResultIter[types.Record]) { - writeResultsIterNDJSON(w, provIter) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + } + + if !hasResults { + // There weren't results, cache for shorter and send 404 + setCacheControl(w, maxAgeWithoutResults, maxStale) + w.WriteHeader(http.StatusNotFound) + } } func (s *server) findPeers(w http.ResponseWriter, r *http.Request) { @@ -572,7 +642,7 @@ func logErr(method, msg string, err error) { logger.Infow(msg, "Method", method, "Error", err) } -func writeResultsIterNDJSON[T any](w http.ResponseWriter, resultIter iter.ResultIter[T]) { +func writeResultsIterNDJSON[T types.Record](w http.ResponseWriter, resultIter iter.ResultIter[T]) { defer resultIter.Close() w.Header().Set("Content-Type", mediaTypeNDJSON) diff --git a/routing/http/server/server_test.go b/routing/http/server/server_test.go index 3f4e7906a..5bdfedb51 100644 --- a/routing/http/server/server_test.go +++ b/routing/http/server/server_test.go @@ -22,6 +22,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/routing" b58 "github.com/mr-tron/base58/base58" + "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -93,6 +94,13 @@ func TestProviders(t *testing.T) { pid2Str := "12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz" cidStr := "bafkreifjjcie6lypi6ny7amxnfftagclbuxndqonfipmb64f2km2devei4" + addr1, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/tcp/4001") + addr2, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/udp/4001/quic-v1") + addr3, _ := multiaddr.NewMultiaddr("/ip4/127.0.0.1/tcp/4001/ws") + addr4, _ := multiaddr.NewMultiaddr("/ip4/102.101.1.1/tcp/4001/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit") + addr5, _ := multiaddr.NewMultiaddr("/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit") + addr6, _ := multiaddr.NewMultiaddr("/ip4/8.8.8.8/udp/4001/quic-v1/webtransport") + pid, err := peer.Decode(pidStr) require.NoError(t, err) pid2, err := peer.Decode(pid2Str) @@ -101,7 +109,7 @@ func TestProviders(t *testing.T) { cid, err := cid.Decode(cidStr) require.NoError(t, err) - runTest := func(t *testing.T, contentType string, empty bool, expectedStream bool, expectedBody string) { + runTest := func(t *testing.T, contentType string, filterAddrs, filterProtocols string, empty bool, expectedStream bool, expectedBody string) { t.Parallel() var results *iter.SliceIter[iter.Result[types.Record]] @@ -114,16 +122,22 @@ func TestProviders(t *testing.T) { Schema: types.SchemaPeer, ID: &pid, Protocols: []string{"transport-bitswap"}, + Addrs: []types.Multiaddr{ + {Multiaddr: addr1}, + {Multiaddr: addr2}, + {Multiaddr: addr3}, + {Multiaddr: addr4}, + {Multiaddr: addr5}, + {Multiaddr: addr6}, + }, + }}, + {Val: &types.PeerRecord{ + Schema: types.SchemaPeer, + ID: &pid2, + Protocols: []string{"transport-ipfs-gateway-http"}, Addrs: []types.Multiaddr{}, }}, - //lint:ignore SA1019 // ignore staticcheck - {Val: &types.BitswapRecord{ - //lint:ignore SA1019 // ignore staticcheck - Schema: types.SchemaBitswap, - ID: &pid2, - Protocol: "transport-bitswap", - Addrs: []types.Multiaddr{}, - }}}, + }, ) } @@ -136,7 +150,7 @@ func TestProviders(t *testing.T) { limit = DefaultStreamingRecordsLimit } router.On("FindProviders", mock.Anything, cid, limit).Return(results, nil) - urlStr := serverAddr + "/routing/v1/providers/" + cidStr + urlStr := serverAddr + "/routing/v1/providers/" + cidStr + "?filter-addrs=" + filterAddrs + "&filter-protocols=" + filterProtocols req, err := http.NewRequest(http.MethodGet, urlStr, nil) require.NoError(t, err) @@ -174,29 +188,51 @@ func TestProviders(t *testing.T) { } t.Run("JSON Response", func(t *testing.T) { - runTest(t, mediaTypeJSON, false, false, `{"Providers":[{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"},{"Schema":"bitswap","Protocol":"transport-bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz"}]}`) + runTest(t, mediaTypeJSON, "", "", false, false, `{"Providers":[{"Addrs":["/ip4/127.0.0.1/tcp/4001","/ip4/127.0.0.1/udp/4001/quic-v1","/ip4/127.0.0.1/tcp/4001/ws","/ip4/102.101.1.1/tcp/4001/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit","/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit","/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"},{"Addrs":["/ip4/127.0.0.1/tcp/4001","/ip4/127.0.0.1/udp/4001/quic-v1","/ip4/127.0.0.1/tcp/4001/ws"],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz","Protocols":["transport-ipfs-gateway-http"],"Schema":"peer"}]}`) + }) + + t.Run("JSON Response with addr filtering including unknown", func(t *testing.T) { + runTest(t, mediaTypeJSON, "webtransport,!p2p-circuit,unknown", "", false, false, `{"Providers":[{"Addrs":["/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"},{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz","Protocols":["transport-ipfs-gateway-http"],"Schema":"peer"}]}`) + }) + + t.Run("JSON Response with addr filtering", func(t *testing.T) { + runTest(t, mediaTypeJSON, "webtransport,!p2p-circuit", "", false, false, `{"Providers":[{"Addrs":["/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"}]}`) + }) + + t.Run("JSON Response with protocol and addr filtering", func(t *testing.T) { + runTest(t, mediaTypeJSON, "quic-v1", "transport-bitswap", false, false, + `{"Providers":[{"Addrs":["/ip4/127.0.0.1/udp/4001/quic-v1","/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit","/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"}]}`) + }) + + t.Run("JSON Response with protocol filtering", func(t *testing.T) { + runTest(t, mediaTypeJSON, "", "transport-ipfs-gateway-http", false, false, + `{"Providers":[{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz","Protocols":["transport-ipfs-gateway-http"],"Schema":"peer"}]}`) }) t.Run("Empty JSON Response", func(t *testing.T) { - runTest(t, mediaTypeJSON, true, false, `{"Providers":null}`) + runTest(t, mediaTypeJSON, "", "", true, false, `{"Providers":null}`) }) t.Run("Wildcard Accept header defaults to JSON Response", func(t *testing.T) { accept := "text/html,*/*" - runTest(t, accept, true, false, `{"Providers":null}`) + runTest(t, accept, "", "", true, false, `{"Providers":null}`) }) t.Run("Missing Accept header defaults to JSON Response", func(t *testing.T) { accept := "" - runTest(t, accept, true, false, `{"Providers":null}`) + runTest(t, accept, "", "", true, false, `{"Providers":null}`) }) t.Run("NDJSON Response", func(t *testing.T) { - runTest(t, mediaTypeNDJSON, false, true, `{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"}`+"\n"+`{"Schema":"bitswap","Protocol":"transport-bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz"}`+"\n") + runTest(t, mediaTypeNDJSON, "", "", false, true, `{"Addrs":["/ip4/127.0.0.1/tcp/4001","/ip4/127.0.0.1/udp/4001/quic-v1","/ip4/127.0.0.1/tcp/4001/ws","/ip4/102.101.1.1/tcp/4001/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit","/ip4/102.101.1.1/udp/4001/quic-v1/webtransport/p2p/12D3KooWEjsGPUQJ4Ej3d1Jcg4VckWhFbhc6mkGunMm1faeSzZMu/p2p-circuit","/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"}`+"\n"+`{"Schema":"bitswap","Protocol":"transport-bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz"}`+"\n") + }) + + t.Run("NDJSON Response with addr filtering", func(t *testing.T) { + runTest(t, mediaTypeNDJSON, "webtransport,!p2p-circuit,unknown", "", false, true, `{"Addrs":["/ip4/8.8.8.8/udp/4001/quic-v1/webtransport"],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"}`+"\n"+`{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz","Protocols":["transport-ipfs-gateway-http"],"Schema":"peer"}`+"\n") }) t.Run("Empty NDJSON Response", func(t *testing.T) { - runTest(t, mediaTypeNDJSON, true, true, "") + runTest(t, mediaTypeNDJSON, "", "", true, true, "") }) t.Run("404 when router returns routing.ErrNotFound", func(t *testing.T) {