diff --git a/cmd/spire-server/cli/agent/agent_posix_test.go b/cmd/spire-server/cli/agent/agent_posix_test.go index a1e6568bc6f..43ac27f7ade 100644 --- a/cmd/spire-server/cli/agent/agent_posix_test.go +++ b/cmd/spire-server/cli/agent/agent_posix_test.go @@ -14,6 +14,14 @@ var ( Path to the SPIRE Server API socket (default "/tmp/spire-server/private/api.sock") ` listUsage = `Usage of agent list: + -attestationType string + Filter by attestation type, like join_token or x509pop. + -banned value + Filter based on string received, 'true': banned agents, 'false': not banned agents, other value will return all. + -canReattest value + Filter based on string received, 'true': agents that can reattest, 'false': agents that can't reattest, other value will return all. + -expiresBefore string + Filter by expiration time (format: "2006-01-02 15:04:05 -0700 -07") -matchSelectorsOn string The match mode used when filtering by selectors. Options: exact, any, superset and subset (default "superset") -output value @@ -40,8 +48,20 @@ var ( The SPIFFE ID of the agent to evict (agent identity) ` countUsage = `Usage of agent count: + -attestationType string + Filter by attestation type, like join_token or x509pop. + -banned value + Filter based on string received, 'true': banned agents, 'false': not banned agents, other value will return all. + -canReattest value + Filter based on string received, 'true': agents that can reattest, 'false': agents that can't reattest, other value will return all. + -expiresBefore string + Filter by expiration time (format: "2006-01-02 15:04:05 -0700 -07") + -matchSelectorsOn string + The match mode used when filtering by selectors. Options: exact, any, superset and subset (default "superset") -output value Desired output format (pretty, json); default: pretty. + -selector value + A colon-delimited type:value selector. Can be used more than once -socketPath string Path to the SPIRE Server API socket (default "/tmp/spire-server/private/api.sock") ` diff --git a/cmd/spire-server/cli/agent/agent_test.go b/cmd/spire-server/cli/agent/agent_test.go index bf63ba0fd37..6110a88bfd6 100644 --- a/cmd/spire-server/cli/agent/agent_test.go +++ b/cmd/spire-server/cli/agent/agent_test.go @@ -225,6 +225,12 @@ func TestCount(t *testing.T) { expectedReturnCode: 1, expectedStderr: common.AddrError, }, + { + name: "Count by expiresBefore: month out of range", + args: []string{"-expiresBefore", "2001-13-05"}, + expectedReturnCode: 1, + expectedStderr: "Error: date is not valid: parsing time \"2001-13-05\": month out of range\n", + }, } { for _, format := range availableFormats { t.Run(fmt.Sprintf("%s using %s format", tt.name, format), func(t *testing.T) { @@ -389,6 +395,45 @@ func TestList(t *testing.T) { expectedStdoutPretty: "Found 1 attested agent:\n\nSPIFFE ID : spiffe://example.org/spire/agent/agent1", expectedStdoutJSON: `{"agents":[{"id":{"trust_domain":"example.org","path":"/spire/agent/agent1"},"attestation_type":"","x509svid_serial_number":"","x509svid_expires_at":"0","selectors":[],"banned":false,"can_reattest":true}],"next_page_token":""}`, }, + { + name: "by expiresBefore", + args: []string{"-expiresBefore", "2000-01-01 15:04:05 -0700 -07"}, + expectReq: &agentv1.ListAgentsRequest{ + Filter: &agentv1.ListAgentsRequest_Filter{ + ByExpiresBefore: "2000-01-01 15:04:05 -0700 -07", + }, + PageSize: 1000, + }, + existentAgents: testAgents, + expectedStdoutPretty: "Found 1 attested agent:\n\nSPIFFE ID : spiffe://example.org/spire/agent/agent1", + expectedStdoutJSON: `{"agents":[{"id":{"trust_domain":"example.org","path":"/spire/agent/agent1"},"attestation_type":"","x509svid_serial_number":"","x509svid_expires_at":"0","selectors":[],"banned":false,"can_reattest":true}],"next_page_token":""}`, + }, + { + name: "by banned", + args: []string{"-banned", "true"}, + expectReq: &agentv1.ListAgentsRequest{ + Filter: &agentv1.ListAgentsRequest_Filter{ + ByBanned: wrapperspb.Bool(true), + }, + PageSize: 1000, + }, + existentAgents: testAgentsWithBanned, + expectedStdoutPretty: "Found 1 attested agent:\n\nSPIFFE ID : spiffe://example.org/spire/agent/banned", + expectedStdoutJSON: `{"agents":[{"id":{"trust_domain":"example.org","path":"/spire/agent/banned"},"attestation_type":"","x509svid_serial_number":"","x509svid_expires_at":"0","selectors":[],"banned":true,"can_reattest":false}],"next_page_token":""}`, + }, + { + name: "by canReattest", + args: []string{"-canReattest", "true"}, + expectReq: &agentv1.ListAgentsRequest{ + Filter: &agentv1.ListAgentsRequest_Filter{ + ByCanReattest: wrapperspb.Bool(true), + }, + PageSize: 1000, + }, + existentAgents: testAgents, + expectedStdoutPretty: "Found 1 attested agent:\n\nSPIFFE ID : spiffe://example.org/spire/agent/agent1", + expectedStdoutJSON: `{"agents":[{"id":{"trust_domain":"example.org","path":"/spire/agent/agent1"},"attestation_type":"","x509svid_serial_number":"","x509svid_expires_at":"0","selectors":[],"banned":false,"can_reattest":true}],"next_page_token":""}`, + }, { name: "List by selectors: Invalid matcher", args: []string{"-selector", "foo:bar", "-selector", "bar:baz", "-matchSelectorsOn", "NO-MATCHER"}, @@ -407,6 +452,12 @@ func TestList(t *testing.T) { expectedReturnCode: 1, expectedStderr: common.AddrError, }, + { + name: "List by expiresBefore: month out of range", + args: []string{"-expiresBefore", "2001-13-05"}, + expectedReturnCode: 1, + expectedStderr: "Error: date is not valid: parsing time \"2001-13-05\": month out of range\n", + }, } { for _, format := range availableFormats { t.Run(fmt.Sprintf("%s using %s format", tt.name, format), func(t *testing.T) { diff --git a/cmd/spire-server/cli/agent/agent_windows_test.go b/cmd/spire-server/cli/agent/agent_windows_test.go index 965ab5c3a76..7b98b75005a 100644 --- a/cmd/spire-server/cli/agent/agent_windows_test.go +++ b/cmd/spire-server/cli/agent/agent_windows_test.go @@ -14,6 +14,14 @@ var ( Desired output format (pretty, json); default: pretty. ` listUsage = `Usage of agent list: + -attestationType string + Filter by attestation type, like join_token or x509pop. + -banned value + Filter based on string received, 'true': banned agents, 'false': not banned agents, other value will return all. + -canReattest value + Filter based on string received, 'true': agents that can reattest, 'false': agents that can't reattest, other value will return all. + -expiresBefore string + Filter by expiration time (format: "2006-01-02 15:04:05 -0700 -07") -matchSelectorsOn string The match mode used when filtering by selectors. Options: exact, any, superset and subset (default "superset") -namedPipeName string @@ -40,10 +48,22 @@ var ( The SPIFFE ID of the agent to evict (agent identity) ` countUsage = `Usage of agent count: + -attestationType string + Filter by attestation type, like join_token or x509pop. + -banned value + Filter based on string received, 'true': banned agents, 'false': not banned agents, other value will return all. + -canReattest value + Filter based on string received, 'true': agents that can reattest, 'false': agents that can't reattest, other value will return all. + -expiresBefore string + Filter by expiration time (format: "2006-01-02 15:04:05 -0700 -07") + -matchSelectorsOn string + The match mode used when filtering by selectors. Options: exact, any, superset and subset (default "superset") -namedPipeName string Pipe name of the SPIRE Server API named pipe (default "\\spire-server\\private\\api") -output value Desired output format (pretty, json); default: pretty. + -selector value + A colon-delimited type:value selector. Can be used more than once ` showUsage = `Usage of agent show: -namedPipeName string diff --git a/cmd/spire-server/cli/agent/count.go b/cmd/spire-server/cli/agent/count.go index 4b46f77f0d8..ef59d0de7bc 100644 --- a/cmd/spire-server/cli/agent/count.go +++ b/cmd/spire-server/cli/agent/count.go @@ -5,16 +5,39 @@ import ( "errors" "flag" "fmt" + "time" "github.com/mitchellh/cli" agentv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/agent/v1" + "github.com/spiffe/spire-api-sdk/proto/spire/api/types" "github.com/spiffe/spire/cmd/spire-server/util" commoncli "github.com/spiffe/spire/pkg/common/cli" "github.com/spiffe/spire/pkg/common/cliprinter" + "google.golang.org/protobuf/types/known/wrapperspb" ) type countCommand struct { - env *commoncli.Env + // Type and value are delimited by a colon (:) + // ex. "unix:uid:1000" or "spiffe_id:spiffe://example.org/foo" + selectors commoncli.StringsFlag + + // Match used when filtering by selectors + matchSelectorsOn string + + // Filters agents to those that are banned. + banned commoncli.BoolFlag + + // Filters agents by those expires before. + expiresBefore string + + // Filters agents to those matching the attestation type. + attestationType string + + // Filters agents that can re-attest. + canReattest commoncli.BoolFlag + + env *commoncli.Env + printer cliprinter.Printer } @@ -39,8 +62,61 @@ func (*countCommand) Synopsis() string { // Run counts attested agents func (c *countCommand) Run(ctx context.Context, _ *commoncli.Env, serverClient util.ServerClient) error { + filter := &agentv1.CountAgentsRequest_Filter{} + if len(c.selectors) > 0 { + matchBehavior, err := parseToSelectorMatch(c.matchSelectorsOn) + if err != nil { + return err + } + + selectors := make([]*types.Selector, len(c.selectors)) + for i, sel := range c.selectors { + selector, err := util.ParseSelector(sel) + if err != nil { + return fmt.Errorf("error parsing selector %q: %w", sel, err) + } + selectors[i] = selector + } + filter.BySelectorMatch = &types.SelectorMatch{ + Selectors: selectors, + Match: matchBehavior, + } + } + + if c.expiresBefore != "" { + // Parse the time string into a time.Time object + _, err := time.Parse("2006-01-02 15:04:05 -0700 -07", c.expiresBefore) + if err != nil { + return fmt.Errorf("date is not valid: %w", err) + } + filter.ByExpiresBefore = c.expiresBefore + } + + if c.attestationType != "" { + filter.ByAttestationType = c.attestationType + } + + // 0: all, 1: can't reattest, 2: can reattest + if c.canReattest == 1 { + filter.ByCanReattest = wrapperspb.Bool(false) + } + if c.canReattest == 2 { + filter.ByCanReattest = wrapperspb.Bool(true) + } + + // 0: all, 1: no-banned, 2: banned + if c.banned == 1 { + filter.ByBanned = wrapperspb.Bool(false) + } + if c.banned == 2 { + filter.ByBanned = wrapperspb.Bool(true) + } + agentClient := serverClient.NewAgentClient() - countResponse, err := agentClient.CountAgents(ctx, &agentv1.CountAgentsRequest{}) + + countResponse, err := agentClient.CountAgents(ctx, &agentv1.CountAgentsRequest{ + Filter: filter, + }) if err != nil { return err } @@ -49,6 +125,12 @@ func (c *countCommand) Run(ctx context.Context, _ *commoncli.Env, serverClient u } func (c *countCommand) AppendFlags(fs *flag.FlagSet) { + fs.Var(&c.selectors, "selector", "A colon-delimited type:value selector. Can be used more than once") + fs.StringVar(&c.attestationType, "attestationType", "", "Filter by attestation type, like join_token or x509pop.") + fs.Var(&c.canReattest, "canReattest", "Filter based on string received, 'true': agents that can reattest, 'false': agents that can't reattest, other value will return all.") + fs.Var(&c.banned, "banned", "Filter based on string received, 'true': banned agents, 'false': not banned agents, other value will return all.") + fs.StringVar(&c.expiresBefore, "expiresBefore", "", "Filter by expiration time (format: \"2006-01-02 15:04:05 -0700 -07\")") + fs.StringVar(&c.matchSelectorsOn, "matchSelectorsOn", "superset", "The match mode used when filtering by selectors. Options: exact, any, superset and subset") cliprinter.AppendFlagWithCustomPretty(&c.printer, fs, c.env, prettyPrintCount) } diff --git a/cmd/spire-server/cli/agent/list.go b/cmd/spire-server/cli/agent/list.go index 80b66989ec3..8062294c434 100644 --- a/cmd/spire-server/cli/agent/list.go +++ b/cmd/spire-server/cli/agent/list.go @@ -14,16 +14,32 @@ import ( commoncli "github.com/spiffe/spire/pkg/common/cli" "github.com/spiffe/spire/pkg/common/cliprinter" "github.com/spiffe/spire/pkg/common/idutil" + "google.golang.org/protobuf/types/known/wrapperspb" ) type listCommand struct { - env *commoncli.Env // Type and value are delimited by a colon (:) // ex. "unix:uid:1000" or "spiffe_id:spiffe://example.org/foo" selectors commoncli.StringsFlag - // Match used when filtering agents by selectors + + // Match used when filtering by selectors matchSelectorsOn string - printer cliprinter.Printer + + // Filters agents to those that are banned. + banned commoncli.BoolFlag + + // Filters agents by those expires before. + expiresBefore string + + // Filters agents to those matching the attestation type. + attestationType string + + // Filters agents that can re-attest. + canReattest commoncli.BoolFlag + + env *commoncli.Env + + printer cliprinter.Printer } // NewListCommand creates a new "list" subcommand for "agent" command. @@ -68,6 +84,35 @@ func (c *listCommand) Run(ctx context.Context, _ *commoncli.Env, serverClient ut } } + if c.expiresBefore != "" { + // Parse the time string into a time.Time object + _, err := time.Parse("2006-01-02 15:04:05 -0700 -07", c.expiresBefore) + if err != nil { + return fmt.Errorf("date is not valid: %w", err) + } + filter.ByExpiresBefore = c.expiresBefore + } + + if c.attestationType != "" { + filter.ByAttestationType = c.attestationType + } + + // 0: all, 1: can't reattest, 2: can reattest + if c.canReattest == 1 { + filter.ByCanReattest = wrapperspb.Bool(false) + } + if c.canReattest == 2 { + filter.ByCanReattest = wrapperspb.Bool(true) + } + + // 0: all, 1: no-banned, 2: banned + if c.banned == 1 { + filter.ByBanned = wrapperspb.Bool(false) + } + if c.banned == 2 { + filter.ByBanned = wrapperspb.Bool(true) + } + agentClient := serverClient.NewAgentClient() pageToken := "" @@ -91,8 +136,12 @@ func (c *listCommand) Run(ctx context.Context, _ *commoncli.Env, serverClient ut } func (c *listCommand) AppendFlags(fs *flag.FlagSet) { - fs.StringVar(&c.matchSelectorsOn, "matchSelectorsOn", "superset", "The match mode used when filtering by selectors. Options: exact, any, superset and subset") fs.Var(&c.selectors, "selector", "A colon-delimited type:value selector. Can be used more than once") + fs.StringVar(&c.attestationType, "attestationType", "", "Filter by attestation type, like join_token or x509pop.") + fs.Var(&c.canReattest, "canReattest", "Filter based on string received, 'true': agents that can reattest, 'false': agents that can't reattest, other value will return all.") + fs.Var(&c.banned, "banned", "Filter based on string received, 'true': banned agents, 'false': not banned agents, other value will return all.") + fs.StringVar(&c.expiresBefore, "expiresBefore", "", "Filter by expiration time (format: \"2006-01-02 15:04:05 -0700 -07\")") + fs.StringVar(&c.matchSelectorsOn, "matchSelectorsOn", "superset", "The match mode used when filtering by selectors. Options: exact, any, superset and subset") cliprinter.AppendFlagWithCustomPretty(&c.printer, fs, c.env, prettyPrintAgents) } diff --git a/cmd/spire-server/cli/entry/count.go b/cmd/spire-server/cli/entry/count.go index c095e76899b..ce3f8153f66 100644 --- a/cmd/spire-server/cli/entry/count.go +++ b/cmd/spire-server/cli/entry/count.go @@ -7,12 +7,39 @@ import ( "github.com/mitchellh/cli" entryv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/entry/v1" + "github.com/spiffe/spire-api-sdk/proto/spire/api/types" "github.com/spiffe/spire/cmd/spire-server/util" commoncli "github.com/spiffe/spire/pkg/common/cli" "github.com/spiffe/spire/pkg/common/cliprinter" + "google.golang.org/protobuf/types/known/wrapperspb" ) type countCommand struct { + // Type and value are delimited by a colon (:) + // ex. "unix:uid:1000" or "spiffe_id:spiffe://example.org/foo" + selectors StringsFlag + + // Workload parent spiffeID + parentID string + + // Workload spiffeID + spiffeID string + + // Entry hint + hint string + + // List of SPIFFE IDs of trust domains the registration entry is federated with + federatesWith StringsFlag + + // Whether or not the entry is for a downstream SPIRE server + downstream bool + + // Match used when filtering by federates with + matchFederatesWithOn string + + // Match used when filtering by selectors + matchSelectorsOn string + printer cliprinter.Printer env *commoncli.Env } @@ -39,7 +66,66 @@ func (*countCommand) Synopsis() string { // Run counts attested entries func (c *countCommand) Run(ctx context.Context, _ *commoncli.Env, serverClient util.ServerClient) error { entryClient := serverClient.NewEntryClient() - countResponse, err := entryClient.CountEntries(ctx, &entryv1.CountEntriesRequest{}) + + filter := &entryv1.CountEntriesRequest_Filter{} + if c.parentID != "" { + id, err := idStringToProto(c.parentID) + if err != nil { + return fmt.Errorf("error parsing parent ID %q: %w", c.parentID, err) + } + filter.ByParentId = id + } + + if c.spiffeID != "" { + id, err := idStringToProto(c.spiffeID) + if err != nil { + return fmt.Errorf("error parsing SPIFFE ID %q: %w", c.spiffeID, err) + } + filter.BySpiffeId = id + } + + if len(c.selectors) != 0 { + matchSelectorBehavior, err := parseToSelectorMatch(c.matchSelectorsOn) + if err != nil { + return err + } + + selectors := make([]*types.Selector, len(c.selectors)) + for i, sel := range c.selectors { + selector, err := util.ParseSelector(sel) + if err != nil { + return fmt.Errorf("error parsing selectors: %w", err) + } + selectors[i] = selector + } + filter.BySelectors = &types.SelectorMatch{ + Selectors: selectors, + Match: matchSelectorBehavior, + } + } + + filter.ByDownstream = wrapperspb.Bool(c.downstream) + + if len(c.federatesWith) > 0 { + matchFederatesWithBehavior, err := parseToFederatesWithMatch(c.matchFederatesWithOn) + if err != nil { + return err + } + + filter.ByFederatesWith = &types.FederatesWithMatch{ + TrustDomains: c.federatesWith, + Match: matchFederatesWithBehavior, + } + } + + if c.hint != "" { + filter.ByHint = wrapperspb.String(c.hint) + } + + countResponse, err := entryClient.CountEntries(ctx, &entryv1.CountEntriesRequest{ + Filter: filter, + }) + if err != nil { return err } @@ -48,6 +134,15 @@ func (c *countCommand) Run(ctx context.Context, _ *commoncli.Env, serverClient u } func (c *countCommand) AppendFlags(fs *flag.FlagSet) { + fs.StringVar(&c.parentID, "parentID", "", "The Parent ID of the records to count") + fs.StringVar(&c.spiffeID, "spiffeID", "", "The SPIFFE ID of the records to count") + fs.BoolVar(&c.downstream, "downstream", false, "A boolean value that, when set, indicates that the entry describes a downstream SPIRE server") + fs.Var(&c.selectors, "selector", "A colon-delimited type:value selector. Can be used more than once") + fs.Var(&c.federatesWith, "federatesWith", "SPIFFE ID of a trust domain an entry is federate with. Can be used more than once") + fs.StringVar(&c.matchFederatesWithOn, "matchFederatesWithOn", "superset", "The match mode used when filtering by federates with. Options: exact, any, superset and subset") + fs.StringVar(&c.matchSelectorsOn, "matchSelectorsOn", "superset", "The match mode used when filtering by selectors. Options: exact, any, superset and subset") + fs.StringVar(&c.hint, "hint", "", "The Hint of the records to count (optional)") + cliprinter.AppendFlagWithCustomPretty(&c.printer, fs, c.env, c.prettyPrintCount) } diff --git a/cmd/spire-server/cli/entry/count_test.go b/cmd/spire-server/cli/entry/count_test.go index cfff9ca6f70..e162bd4a96a 100644 --- a/cmd/spire-server/cli/entry/count_test.go +++ b/cmd/spire-server/cli/entry/count_test.go @@ -5,9 +5,11 @@ import ( "testing" entryv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/entry/v1" + "github.com/spiffe/spire-api-sdk/proto/spire/api/types" "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/wrapperspb" ) func TestCountHelp(t *testing.T) { @@ -31,12 +33,262 @@ func TestCount(t *testing.T) { for _, tt := range []struct { name string args []string + expCountReq *entryv1.CountEntriesRequest fakeCountResp *entryv1.CountEntriesResponse serverErr error expOutPretty string expOutJSON string expErr string }{ + { + name: "Count all entries (empty filter)", + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp4, + expOutPretty: "4 registration entries", + expOutJSON: `{"count":4}`, + }, + { + name: "Count by parentID", + args: []string{"-parentID", "spiffe://example.org/father"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + ByParentId: &types.SPIFFEID{TrustDomain: "example.org", Path: "/father"}, + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp2, + expOutPretty: "2 registration entries", + expOutJSON: `{"count":2}`, + }, + { + name: "Count by parent ID using invalid ID", + args: []string{"-parentID", "invalid-id"}, + expErr: "Error: error parsing parent ID \"invalid-id\": scheme is missing or invalid\n", + }, + { + name: "Count by SPIFFE ID", + args: []string{"-spiffeID", "spiffe://example.org/daughter"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + BySpiffeId: &types.SPIFFEID{TrustDomain: "example.org", Path: "/daughter"}, + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp2, + expOutPretty: "2 registration entries", + expOutJSON: `{"count":2}`, + }, + { + name: "Count by SPIFFE ID using invalid ID", + args: []string{"-spiffeID", "invalid-id"}, + expErr: "Error: error parsing SPIFFE ID \"invalid-id\": scheme is missing or invalid\n", + }, + { + name: "Count by selectors: default matcher", + args: []string{"-selector", "foo:bar", "-selector", "bar:baz"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + BySelectors: &types.SelectorMatch{ + Selectors: []*types.Selector{ + {Type: "foo", Value: "bar"}, + {Type: "bar", Value: "baz"}, + }, + Match: types.SelectorMatch_MATCH_SUPERSET, + }, + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp1, + expOutPretty: "1 registration entry", + expOutJSON: `{"count":1}`, + }, + { + name: "Count by selectors: exact matcher", + args: []string{"-selector", "foo:bar", "-selector", "bar:baz", "-matchSelectorsOn", "exact"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + BySelectors: &types.SelectorMatch{ + Selectors: []*types.Selector{ + {Type: "foo", Value: "bar"}, + {Type: "bar", Value: "baz"}, + }, + Match: types.SelectorMatch_MATCH_EXACT, + }, + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp1, + expOutPretty: "1 registration entry", + expOutJSON: `{"count":1}`, + }, + { + name: "Count by selectors: superset matcher", + args: []string{"-selector", "foo:bar", "-selector", "bar:baz", "-matchSelectorsOn", "superset"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + BySelectors: &types.SelectorMatch{ + Selectors: []*types.Selector{ + {Type: "foo", Value: "bar"}, + {Type: "bar", Value: "baz"}, + }, + Match: types.SelectorMatch_MATCH_SUPERSET, + }, + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp1, + expOutPretty: "1 registration entry", + expOutJSON: `{"count":1}`, + }, + { + name: "Count by selectors: subset matcher", + args: []string{"-selector", "foo:bar", "-selector", "bar:baz", "-matchSelectorsOn", "subset"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + BySelectors: &types.SelectorMatch{ + Selectors: []*types.Selector{ + {Type: "foo", Value: "bar"}, + {Type: "bar", Value: "baz"}, + }, + Match: types.SelectorMatch_MATCH_SUBSET, + }, + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp1, + expOutPretty: "1 registration entry", + expOutJSON: `{"count":1}`, + }, + { + name: "Count by selectors: Any matcher", + args: []string{"-selector", "foo:bar", "-selector", "bar:baz", "-matchSelectorsOn", "any"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + BySelectors: &types.SelectorMatch{ + Selectors: []*types.Selector{ + {Type: "foo", Value: "bar"}, + {Type: "bar", Value: "baz"}, + }, + Match: types.SelectorMatch_MATCH_ANY, + }, + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp1, + expOutPretty: "1 registration entry", + expOutJSON: `{"count":1}`, + }, + { + name: "Count by selectors: Invalid matcher", + args: []string{"-selector", "foo:bar", "-selector", "bar:baz", "-matchSelectorsOn", "NO-MATCHER"}, + expErr: "Error: match behavior \"NO-MATCHER\" unknown\n", + }, + { + name: "Count by selector using invalid selector", + args: []string{"-selector", "invalid-selector"}, + expErr: "Error: error parsing selectors: selector \"invalid-selector\" must be formatted as type:value\n", + }, + { + name: "Server error", + args: []string{"-spiffeID", "spiffe://example.org/daughter"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + BySpiffeId: &types.SPIFFEID{TrustDomain: "example.org", Path: "/daughter"}, + ByDownstream: wrapperspb.Bool(false), + }, + }, + serverErr: status.Error(codes.Internal, "internal server error"), + expErr: "Error: rpc error: code = Internal desc = internal server error\n", + }, + { + name: "Count by Federates With: default matcher", + args: []string{"-federatesWith", "spiffe://domain.test"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + ByFederatesWith: &types.FederatesWithMatch{ + TrustDomains: []string{"spiffe://domain.test"}, + Match: types.FederatesWithMatch_MATCH_SUPERSET, + }, + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp1, + expOutPretty: "1 registration entry", + expOutJSON: `{"count":1}`, + }, + { + name: "Count by Federates With: exact matcher", + args: []string{"-federatesWith", "spiffe://domain.test", "-matchFederatesWithOn", "exact"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + ByFederatesWith: &types.FederatesWithMatch{ + TrustDomains: []string{"spiffe://domain.test"}, + Match: types.FederatesWithMatch_MATCH_EXACT, + }, + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp1, + expOutPretty: "1 registration entry", + expOutJSON: `{"count":1}`, + }, + { + name: "Count by Federates With: Any matcher", + args: []string{"-federatesWith", "spiffe://domain.test", "-matchFederatesWithOn", "any"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + ByFederatesWith: &types.FederatesWithMatch{ + TrustDomains: []string{"spiffe://domain.test"}, + Match: types.FederatesWithMatch_MATCH_ANY, + }, + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp1, + expOutPretty: "1 registration entry", + expOutJSON: `{"count":1}`, + }, + { + name: "Count by Federates With: superset matcher", + args: []string{"-federatesWith", "spiffe://domain.test", "-matchFederatesWithOn", "superset"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + ByFederatesWith: &types.FederatesWithMatch{ + TrustDomains: []string{"spiffe://domain.test"}, + Match: types.FederatesWithMatch_MATCH_SUPERSET, + }, + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp1, + expOutPretty: "1 registration entry", + expOutJSON: `{"count":1}`, + }, + { + name: "Count by Federates With: subset matcher", + args: []string{"-federatesWith", "spiffe://domain.test", "-matchFederatesWithOn", "subset"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + ByFederatesWith: &types.FederatesWithMatch{ + TrustDomains: []string{"spiffe://domain.test"}, + Match: types.FederatesWithMatch_MATCH_SUBSET, + }, + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp1, + expOutPretty: "1 registration entry", + expOutJSON: `{"count":1}`, + }, + { + name: "Count by Federates With: Invalid matcher", + args: []string{"-federatesWith", "spiffe://domain.test", "-matchFederatesWithOn", "NO-MATCHER"}, + expErr: "Error: match behavior \"NO-MATCHER\" unknown\n", + }, { name: "4 entries", fakeCountResp: fakeResp4, @@ -73,13 +325,16 @@ func TestCount(t *testing.T) { test.server.err = tt.serverErr test.server.countEntriesResp = tt.fakeCountResp - rc := test.client.Run(test.args(tt.args...)) + args := tt.args + args = append(args, "-output", format) + + rc := test.client.Run(test.args(args...)) if tt.expErr != "" { require.Equal(t, 1, rc) require.Equal(t, tt.expErr, test.stderr.String()) return } - requireOutputBasedOnFormat(t, test.stdout.String(), format, tt.expOutPretty, tt.expOutJSON) + requireOutputBasedOnFormat(t, format, test.stdout.String(), tt.expOutPretty, tt.expOutJSON) require.Equal(t, 0, rc) }) } diff --git a/cmd/spire-server/cli/entry/show.go b/cmd/spire-server/cli/entry/show.go index 5cbfdf10e89..94f55038540 100644 --- a/cmd/spire-server/cli/entry/show.go +++ b/cmd/spire-server/cli/entry/show.go @@ -175,6 +175,8 @@ func (c *showCommand) fetchEntries(ctx context.Context, client entryv1.EntryClie filter.ByHint = wrapperspb.String(c.hint) } + filter.ByDownstream = wrapperspb.Bool(c.downstream) + pageToken := "" for { diff --git a/cmd/spire-server/cli/entry/show_test.go b/cmd/spire-server/cli/entry/show_test.go index 418889cb3f6..d2fee9b3250 100644 --- a/cmd/spire-server/cli/entry/show_test.go +++ b/cmd/spire-server/cli/entry/show_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/wrapperspb" ) func TestShowHelp(t *testing.T) { @@ -61,7 +62,9 @@ func TestShow(t *testing.T) { name: "List all entries (empty filter)", expListReq: &entryv1.ListEntriesRequest{ PageSize: listEntriesRequestPageSize, - Filter: &entryv1.ListEntriesRequest_Filter{}, + Filter: &entryv1.ListEntriesRequest_Filter{ + ByDownstream: wrapperspb.Bool(false), + }, }, fakeListResp: fakeRespAll, expOutPretty: fmt.Sprintf("Found 4 entries\n%s%s%s%s", @@ -103,7 +106,8 @@ func TestShow(t *testing.T) { expListReq: &entryv1.ListEntriesRequest{ PageSize: listEntriesRequestPageSize, Filter: &entryv1.ListEntriesRequest_Filter{ - ByParentId: &types.SPIFFEID{TrustDomain: "example.org", Path: "/father"}, + ByParentId: &types.SPIFFEID{TrustDomain: "example.org", Path: "/father"}, + ByDownstream: wrapperspb.Bool(false), }, }, fakeListResp: fakeRespFather, @@ -124,7 +128,8 @@ func TestShow(t *testing.T) { expListReq: &entryv1.ListEntriesRequest{ PageSize: listEntriesRequestPageSize, Filter: &entryv1.ListEntriesRequest_Filter{ - BySpiffeId: &types.SPIFFEID{TrustDomain: "example.org", Path: "/daughter"}, + BySpiffeId: &types.SPIFFEID{TrustDomain: "example.org", Path: "/daughter"}, + ByDownstream: wrapperspb.Bool(false), }, }, fakeListResp: fakeRespDaughter, @@ -152,6 +157,7 @@ func TestShow(t *testing.T) { }, Match: types.SelectorMatch_MATCH_SUPERSET, }, + ByDownstream: wrapperspb.Bool(false), }, }, fakeListResp: fakeRespFatherDaughter, @@ -173,6 +179,7 @@ func TestShow(t *testing.T) { }, Match: types.SelectorMatch_MATCH_EXACT, }, + ByDownstream: wrapperspb.Bool(false), }, }, fakeListResp: fakeRespFatherDaughter, @@ -194,6 +201,7 @@ func TestShow(t *testing.T) { }, Match: types.SelectorMatch_MATCH_SUPERSET, }, + ByDownstream: wrapperspb.Bool(false), }, }, fakeListResp: fakeRespFatherDaughter, @@ -215,6 +223,7 @@ func TestShow(t *testing.T) { }, Match: types.SelectorMatch_MATCH_SUBSET, }, + ByDownstream: wrapperspb.Bool(false), }, }, fakeListResp: fakeRespFatherDaughter, @@ -236,6 +245,7 @@ func TestShow(t *testing.T) { }, Match: types.SelectorMatch_MATCH_ANY, }, + ByDownstream: wrapperspb.Bool(false), }, }, fakeListResp: fakeRespFatherDaughter, @@ -260,7 +270,8 @@ func TestShow(t *testing.T) { expListReq: &entryv1.ListEntriesRequest{ PageSize: listEntriesRequestPageSize, Filter: &entryv1.ListEntriesRequest_Filter{ - BySpiffeId: &types.SPIFFEID{TrustDomain: "example.org", Path: "/daughter"}, + BySpiffeId: &types.SPIFFEID{TrustDomain: "example.org", Path: "/daughter"}, + ByDownstream: wrapperspb.Bool(false), }, }, serverErr: status.Error(codes.Internal, "internal server error"), @@ -276,6 +287,7 @@ func TestShow(t *testing.T) { TrustDomains: []string{"spiffe://domain.test"}, Match: types.FederatesWithMatch_MATCH_SUPERSET, }, + ByDownstream: wrapperspb.Bool(false), }, }, fakeListResp: fakeRespMotherDaughter, @@ -294,6 +306,7 @@ func TestShow(t *testing.T) { TrustDomains: []string{"spiffe://domain.test"}, Match: types.FederatesWithMatch_MATCH_EXACT, }, + ByDownstream: wrapperspb.Bool(false), }, }, fakeListResp: fakeRespMotherDaughter, @@ -312,6 +325,7 @@ func TestShow(t *testing.T) { TrustDomains: []string{"spiffe://domain.test"}, Match: types.FederatesWithMatch_MATCH_ANY, }, + ByDownstream: wrapperspb.Bool(false), }, }, fakeListResp: fakeRespMotherDaughter, @@ -330,6 +344,7 @@ func TestShow(t *testing.T) { TrustDomains: []string{"spiffe://domain.test"}, Match: types.FederatesWithMatch_MATCH_SUPERSET, }, + ByDownstream: wrapperspb.Bool(false), }, }, fakeListResp: fakeRespMotherDaughter, @@ -348,6 +363,7 @@ func TestShow(t *testing.T) { TrustDomains: []string{"spiffe://domain.test"}, Match: types.FederatesWithMatch_MATCH_SUBSET, }, + ByDownstream: wrapperspb.Bool(false), }, }, fakeListResp: fakeRespMotherDaughter, diff --git a/cmd/spire-server/cli/entry/util_posix_test.go b/cmd/spire-server/cli/entry/util_posix_test.go index d3825c12984..7b04cb3f96d 100644 --- a/cmd/spire-server/cli/entry/util_posix_test.go +++ b/cmd/spire-server/cli/entry/util_posix_test.go @@ -112,9 +112,25 @@ const ( Path to the SPIRE Server API socket (default "/tmp/spire-server/private/api.sock") ` countUsage = `Usage of entry count: + -downstream + A boolean value that, when set, indicates that the entry describes a downstream SPIRE server + -federatesWith value + SPIFFE ID of a trust domain an entry is federate with. Can be used more than once + -hint string + The Hint of the records to count (optional) + -matchFederatesWithOn string + The match mode used when filtering by federates with. Options: exact, any, superset and subset (default "superset") + -matchSelectorsOn string + The match mode used when filtering by selectors. Options: exact, any, superset and subset (default "superset") -output value Desired output format (pretty, json); default: pretty. + -parentID string + The Parent ID of the records to count + -selector value + A colon-delimited type:value selector. Can be used more than once -socketPath string Path to the SPIRE Server API socket (default "/tmp/spire-server/private/api.sock") + -spiffeID string + The SPIFFE ID of the records to count ` ) diff --git a/cmd/spire-server/cli/entry/util_windows_test.go b/cmd/spire-server/cli/entry/util_windows_test.go index 75e1d1929b0..18f5c88af42 100644 --- a/cmd/spire-server/cli/entry/util_windows_test.go +++ b/cmd/spire-server/cli/entry/util_windows_test.go @@ -112,9 +112,25 @@ const ( Desired output format (pretty, json); default: pretty. ` countUsage = `Usage of entry count: + -downstream + A boolean value that, when set, indicates that the entry describes a downstream SPIRE server + -federatesWith value + SPIFFE ID of a trust domain an entry is federate with. Can be used more than once + -hint string + The Hint of the records to count (optional) + -matchFederatesWithOn string + The match mode used when filtering by federates with. Options: exact, any, superset and subset (default "superset") + -matchSelectorsOn string + The match mode used when filtering by selectors. Options: exact, any, superset and subset (default "superset") -namedPipeName string Pipe name of the SPIRE Server API named pipe (default "\\spire-server\\private\\api") -output value Desired output format (pretty, json); default: pretty. + -parentID string + The Parent ID of the records to count + -selector value + A colon-delimited type:value selector. Can be used more than once + -spiffeID string + The SPIFFE ID of the records to count ` ) diff --git a/doc/spire_server.md b/doc/spire_server.md index 9931f5f3a41..5b716f4d265 100644 --- a/doc/spire_server.md +++ b/doc/spire_server.md @@ -356,9 +356,14 @@ Updates registration entries. Displays the total number of registration entries. -| Command | Action | Default | -|:--------------|:------------------------------------|:-----------------------------------| -| `-socketPath` | Path to the SPIRE Server API socket | /tmp/spire-server/private/api.sock | +| Command | Action | Default | +|:-----------------|:-------------------------------------------------------------------------------------------------|:-----------------------------------| +| `-downstream` | A boolean value that, when set, indicates that the entry describes a downstream SPIRE server | | +| `-federatesWith` | SPIFFE ID of a trust domain an entry is federate with. Can be used more than once | | +| `-parentID` | The Parent ID of the records to count. | | +| `-selector` | A colon-delimited type:value selector. Can be used more than once to specify multiple selectors. | | +| `-socketPath` | Path to the SPIRE Server API socket | /tmp/spire-server/private/api.sock | +| `-spiffeID` | The SPIFFE ID of the records to count. | | ### `spire-server entry delete` @@ -512,7 +517,11 @@ Displays the total number of attested nodes. | Command | Action | Default | |:--------------|:------------------------------------|:-----------------------------------| -| `-socketPath` | Path to the SPIRE Server API socket | /tmp/spire-server/private/api.sock | +| `-selector` | A colon-delimited type:value selector. Can be used more than once to specify multiple selectors. | | +| `-canReattest` | Filter based on string received, 'true': agents that can reattest, 'false': agents that can't reattest, other value will return all | | +| `-banned` | Filter based on string received, 'true': banned agents, 'false': not banned agents, other value will return all | | +| `-expiresBefore` | Filter by expiration time (format: "2006-01-02 15:04:05 -0700 -07") | | +| `-spiffeID` | The SPIFFE ID of the records to count. | | ### `spire-server agent evict` @@ -529,7 +538,13 @@ Displays attested nodes. | Command | Action | Default | |:--------------|:------------------------------------|:-----------------------------------| -| `-socketPath` | Path to the SPIRE Server API socket | /tmp/spire-server/private/api.sock | +| Command | Action | Default | +|:--------------|:------------------------------------|:-----------------------------------| +| `-selector` | A colon-delimited type:value selector. Can be used more than once to specify multiple selectors. | | +| `-canReattest` | Filter based on string received, 'true': agents that can reattest, 'false': agents that can't reattest, other value will return all | | +| `-banned` | Filter based on string received, 'true': banned agents, 'false': not banned agents, other value will return all | | +| `-expiresBefore` | Filter by expiration time (format: "2006-01-02 15:04:05 -0700 -07")| | +| `-attestationType` | Filters agents to those matching the attestation type, like join_token or x509pop. | | ### `spire-server agent show` diff --git a/go.mod b/go.mod index 9dc5a8c9eb1..0723cc178a7 100644 --- a/go.mod +++ b/go.mod @@ -68,7 +68,7 @@ require ( github.com/sigstore/sigstore v1.8.2 github.com/sirupsen/logrus v1.9.3 github.com/spiffe/go-spiffe/v2 v2.1.7 - github.com/spiffe/spire-api-sdk v1.2.5-0.20240222231036-08f5a1ab98c6 + github.com/spiffe/spire-api-sdk v1.2.5-0.20240301205221-967353a5c821 github.com/spiffe/spire-plugin-sdk v1.4.4-0.20230721151831-bf67dde4721d github.com/stretchr/testify v1.9.0 github.com/uber-go/tally/v4 v4.1.12 diff --git a/go.sum b/go.sum index 9b12f374343..07a63f03552 100644 --- a/go.sum +++ b/go.sum @@ -1391,8 +1391,8 @@ github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMV github.com/spiffe/go-spiffe/v2 v2.1.6/go.mod h1:eVDqm9xFvyqao6C+eQensb9ZPkyNEeaUbqbBpOhBnNk= github.com/spiffe/go-spiffe/v2 v2.1.7 h1:VUkM1yIyg/x8X7u1uXqSRVRCdMdfRIEdFBzpqoeASGk= github.com/spiffe/go-spiffe/v2 v2.1.7/go.mod h1:QJDGdhXllxjxvd5B+2XnhhXB/+rC8gr+lNrtOryiWeE= -github.com/spiffe/spire-api-sdk v1.2.5-0.20240222231036-08f5a1ab98c6 h1:gCctMhffEF4KcrLP85qQwOeQoHCMMYlDL1HR0fEZ+sE= -github.com/spiffe/spire-api-sdk v1.2.5-0.20240222231036-08f5a1ab98c6/go.mod h1:4uuhFlN6KBWjACRP3xXwrOTNnvaLp1zJs8Lribtr4fI= +github.com/spiffe/spire-api-sdk v1.2.5-0.20240301205221-967353a5c821 h1:ws5/mYxmiZtw/67nymx5hnSJo8Kx2Q1UkQqiSt8TU74= +github.com/spiffe/spire-api-sdk v1.2.5-0.20240301205221-967353a5c821/go.mod h1:4uuhFlN6KBWjACRP3xXwrOTNnvaLp1zJs8Lribtr4fI= github.com/spiffe/spire-plugin-sdk v1.4.4-0.20230721151831-bf67dde4721d h1:LCRQGU6vOqKLfRrG+GJQrwMwDILcAddAEIf4/1PaSVc= github.com/spiffe/spire-plugin-sdk v1.4.4-0.20230721151831-bf67dde4721d/go.mod h1:GA6o2PVLwyJdevT6KKt5ZXCY/ziAPna13y/seGk49Ik= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/pkg/common/cli/flags.go b/pkg/common/cli/flags.go index 0b336e7f466..9770f821a2f 100644 --- a/pkg/common/cli/flags.go +++ b/pkg/common/cli/flags.go @@ -45,3 +45,29 @@ func (s *StringsFlag) Set(val string) error { *s = append(*s, val) return nil } + +// BoolFlag is used to define 3 possible states: true, false, or all. +// Take care that false=1, and true=2 +type BoolFlag int + +const BoolFlagAll = 0 +const BoolFlagFalse = 1 +const BoolFlagTrue = 2 + +func (b *BoolFlag) String() string { + return "" +} + +func (b *BoolFlag) Set(val string) error { + if val == "false" { + *b = BoolFlagFalse + return nil + } + if val == "true" { + *b = BoolFlagTrue + return nil + } + // if the value received isn't true or false, it will set the default value + *b = BoolFlagAll + return nil +} diff --git a/pkg/common/telemetry/server/datastore/wrapper.go b/pkg/common/telemetry/server/datastore/wrapper.go index 645b8501a21..d97ce5bf881 100644 --- a/pkg/common/telemetry/server/datastore/wrapper.go +++ b/pkg/common/telemetry/server/datastore/wrapper.go @@ -186,10 +186,10 @@ func (w metricsWrapper) ListRegistrationEntriesEvents(ctx context.Context, req * return w.ds.ListRegistrationEntriesEvents(ctx, req) } -func (w metricsWrapper) CountAttestedNodes(ctx context.Context) (_ int32, err error) { +func (w metricsWrapper) CountAttestedNodes(ctx context.Context, req *datastore.CountAttestedNodesRequest) (_ int32, err error) { callCounter := StartCountNodeCall(w.m) defer callCounter.Done(&err) - return w.ds.CountAttestedNodes(ctx) + return w.ds.CountAttestedNodes(ctx, req) } func (w metricsWrapper) CountBundles(ctx context.Context) (_ int32, err error) { @@ -198,10 +198,10 @@ func (w metricsWrapper) CountBundles(ctx context.Context) (_ int32, err error) { return w.ds.CountBundles(ctx) } -func (w metricsWrapper) CountRegistrationEntries(ctx context.Context) (_ int32, err error) { +func (w metricsWrapper) CountRegistrationEntries(ctx context.Context, req *datastore.CountRegistrationEntriesRequest) (_ int32, err error) { callCounter := StartCountRegistrationCall(w.m) defer callCounter.Done(&err) - return w.ds.CountRegistrationEntries(ctx) + return w.ds.CountRegistrationEntries(ctx, req) } func (w metricsWrapper) PruneAttestedNodesEvents(ctx context.Context, olderThan time.Duration) (err error) { diff --git a/pkg/common/telemetry/server/datastore/wrapper_test.go b/pkg/common/telemetry/server/datastore/wrapper_test.go index d085ee34d4c..8c7d7fcc24e 100644 --- a/pkg/common/telemetry/server/datastore/wrapper_test.go +++ b/pkg/common/telemetry/server/datastore/wrapper_test.go @@ -318,7 +318,7 @@ func (ds *fakeDataStore) AppendBundle(context.Context, *common.Bundle) (*common. return &common.Bundle{}, ds.err } -func (ds *fakeDataStore) CountAttestedNodes(context.Context) (int32, error) { +func (ds *fakeDataStore) CountAttestedNodes(context.Context, *datastore.CountAttestedNodesRequest) (int32, error) { return 0, ds.err } @@ -326,7 +326,7 @@ func (ds *fakeDataStore) CountBundles(context.Context) (int32, error) { return 0, ds.err } -func (ds *fakeDataStore) CountRegistrationEntries(context.Context) (int32, error) { +func (ds *fakeDataStore) CountRegistrationEntries(context.Context, *datastore.CountRegistrationEntriesRequest) (int32, error) { return 0, ds.err } diff --git a/pkg/server/api/agent/v1/service.go b/pkg/server/api/agent/v1/service.go index 0cc4b6629e8..b9889c056bf 100644 --- a/pkg/server/api/agent/v1/service.go +++ b/pkg/server/api/agent/v1/service.go @@ -69,8 +69,44 @@ func RegisterService(s grpc.ServiceRegistrar, service *Service) { } // CountAgents returns the total number of agents. -func (s *Service) CountAgents(ctx context.Context, _ *agentv1.CountAgentsRequest) (*agentv1.CountAgentsResponse, error) { - count, err := s.ds.CountAttestedNodes(ctx) +func (s *Service) CountAgents(ctx context.Context, req *agentv1.CountAgentsRequest) (*agentv1.CountAgentsResponse, error) { + log := rpccontext.Logger(ctx) + + countReq := &datastore.CountAttestedNodesRequest{} + + // Parse proto filter into datastore request + if req.Filter != nil { + filter := req.Filter + rpccontext.AddRPCAuditFields(ctx, fieldsFromCountAgentsRequest(filter)) + + if filter.ByBanned != nil { + countReq.ByBanned = &req.Filter.ByBanned.Value + } + if filter.ByCanReattest != nil { + countReq.ByCanReattest = &req.Filter.ByCanReattest.Value + } + + if filter.ByAttestationType != "" { + countReq.ByAttestationType = filter.ByAttestationType + } + + if filter.ByExpiresBefore != "" { + countReq.ByExpiresBefore, _ = time.Parse("2006-01-02 15:04:05 -0700 -07", filter.ByExpiresBefore) + } + + if filter.BySelectorMatch != nil { + selectors, err := api.SelectorsFromProto(filter.BySelectorMatch.Selectors) + if err != nil { + return nil, api.MakeErr(log, codes.InvalidArgument, "failed to parse selectors", err) + } + countReq.BySelectorMatch = &datastore.BySelectors{ + Match: datastore.MatchBehavior(filter.BySelectorMatch.Match), + Selectors: selectors, + } + } + } + + count, err := s.ds.CountAttestedNodes(ctx, countReq) if err != nil { log := rpccontext.Logger(ctx) return nil, api.MakeErr(log, codes.Internal, "failed to count agents", err) @@ -92,20 +128,22 @@ func (s *Service) ListAgents(ctx context.Context, req *agentv1.ListAgentsRequest // Parse proto filter into datastore request if req.Filter != nil { filter := req.Filter - rpccontext.AddRPCAuditFields(ctx, fieldsFromFilterRequest(filter)) + rpccontext.AddRPCAuditFields(ctx, fieldsFromListAgentsRequest(filter)) - var byBanned *bool if filter.ByBanned != nil { - byBanned = &filter.ByBanned.Value + listReq.ByBanned = &req.Filter.ByBanned.Value } - var byCanReattest *bool if filter.ByCanReattest != nil { - byCanReattest = &filter.ByCanReattest.Value + listReq.ByCanReattest = &req.Filter.ByCanReattest.Value + } + + if filter.ByAttestationType != "" { + listReq.ByAttestationType = filter.ByAttestationType } - listReq.ByAttestationType = filter.ByAttestationType - listReq.ByBanned = byBanned - listReq.ByCanReattest = byCanReattest + if filter.ByExpiresBefore != "" { + listReq.ByExpiresBefore, _ = time.Parse("2006-01-02 15:04:05 -0700 -07", filter.ByExpiresBefore) + } if filter.BySelectorMatch != nil { selectors, err := api.SelectorsFromProto(filter.BySelectorMatch.Selectors) @@ -686,7 +724,30 @@ func getAttestAgentResponse(spiffeID spiffeid.ID, certificates []*x509.Certifica } } -func fieldsFromFilterRequest(filter *agentv1.ListAgentsRequest_Filter) logrus.Fields { +func fieldsFromListAgentsRequest(filter *agentv1.ListAgentsRequest_Filter) logrus.Fields { + fields := logrus.Fields{} + + if filter.ByAttestationType != "" { + fields[telemetry.NodeAttestorType] = filter.ByAttestationType + } + + if filter.ByBanned != nil { + fields[telemetry.ByBanned] = filter.ByBanned.Value + } + + if filter.ByCanReattest != nil { + fields[telemetry.ByCanReattest] = filter.ByCanReattest.Value + } + + if filter.BySelectorMatch != nil { + fields[telemetry.BySelectorMatch] = filter.BySelectorMatch.Match.String() + fields[telemetry.BySelectors] = api.SelectorFieldFromProto(filter.BySelectorMatch.Selectors) + } + + return fields +} + +func fieldsFromCountAgentsRequest(filter *agentv1.CountAgentsRequest_Filter) logrus.Fields { fields := logrus.Fields{} if filter.ByAttestationType != "" { diff --git a/pkg/server/api/debug/v1/service.go b/pkg/server/api/debug/v1/service.go index c1ab073b190..163be708d37 100644 --- a/pkg/server/api/debug/v1/service.go +++ b/pkg/server/api/debug/v1/service.go @@ -78,12 +78,11 @@ func (s *Service) GetInfo(ctx context.Context, _ *debugv1.GetInfoRequest) (*debu // Update cache when expired or does not exists if s.getInfoResp.ts.IsZero() || s.clock.Now().Sub(s.getInfoResp.ts) >= cacheExpiry { - nodes, err := s.ds.CountAttestedNodes(ctx) + nodes, err := s.ds.CountAttestedNodes(ctx, &datastore.CountAttestedNodesRequest{}) if err != nil { return nil, api.MakeErr(log, codes.Internal, "failed to count agents", err) } - - entries, err := s.ds.CountRegistrationEntries(ctx) + entries, err := s.ds.CountRegistrationEntries(ctx, &datastore.CountRegistrationEntriesRequest{}) if err != nil { return nil, api.MakeErr(log, codes.Internal, "failed to count entries", err) } diff --git a/pkg/server/api/entry/v1/service.go b/pkg/server/api/entry/v1/service.go index 2ff683af8c0..180d66c2413 100644 --- a/pkg/server/api/entry/v1/service.go +++ b/pkg/server/api/entry/v1/service.go @@ -61,8 +61,70 @@ func RegisterService(s grpc.ServiceRegistrar, service *Service) { } // CountEntries returns the total number of entries. -func (s *Service) CountEntries(ctx context.Context, _ *entryv1.CountEntriesRequest) (*entryv1.CountEntriesResponse, error) { - count, err := s.ds.CountRegistrationEntries(ctx) +func (s *Service) CountEntries(ctx context.Context, req *entryv1.CountEntriesRequest) (*entryv1.CountEntriesResponse, error) { + log := rpccontext.Logger(ctx) + countReq := &datastore.CountRegistrationEntriesRequest{} + + if req.Filter != nil { + rpccontext.AddRPCAuditFields(ctx, fieldsFromCountEntryFilter(ctx, s.td, req.Filter)) + if req.Filter.ByHint != nil { + countReq.ByHint = req.Filter.ByHint.GetValue() + } + + if req.Filter.ByParentId != nil { + parentID, err := api.TrustDomainMemberIDFromProto(ctx, s.td, req.Filter.ByParentId) + if err != nil { + return nil, api.MakeErr(log, codes.InvalidArgument, "malformed parent ID filter", err) + } + countReq.ByParentID = parentID.String() + } + + if req.Filter.BySpiffeId != nil { + spiffeID, err := api.TrustDomainWorkloadIDFromProto(ctx, s.td, req.Filter.BySpiffeId) + if err != nil { + return nil, api.MakeErr(log, codes.InvalidArgument, "malformed SPIFFE ID filter", err) + } + countReq.BySpiffeID = spiffeID.String() + } + + if req.Filter.BySelectors != nil { + dsSelectors, err := api.SelectorsFromProto(req.Filter.BySelectors.Selectors) + if err != nil { + return nil, api.MakeErr(log, codes.InvalidArgument, "malformed selectors filter", err) + } + if len(dsSelectors) == 0 { + return nil, api.MakeErr(log, codes.InvalidArgument, "malformed selectors filter", errors.New("empty selector set")) + } + countReq.BySelectors = &datastore.BySelectors{ + Match: datastore.MatchBehavior(req.Filter.BySelectors.Match), + Selectors: dsSelectors, + } + } + + if req.Filter.ByFederatesWith != nil { + trustDomains := make([]string, 0, len(req.Filter.ByFederatesWith.TrustDomains)) + for _, tdStr := range req.Filter.ByFederatesWith.TrustDomains { + td, err := spiffeid.TrustDomainFromString(tdStr) + if err != nil { + return nil, api.MakeErr(log, codes.InvalidArgument, "malformed federates with filter", err) + } + trustDomains = append(trustDomains, td.IDString()) + } + if len(trustDomains) == 0 { + return nil, api.MakeErr(log, codes.InvalidArgument, "malformed federates with filter", errors.New("empty trust domain set")) + } + countReq.ByFederatesWith = &datastore.ByFederatesWith{ + Match: datastore.MatchBehavior(req.Filter.ByFederatesWith.Match), + TrustDomains: trustDomains, + } + } + + if req.Filter.ByDownstream != nil { + countReq.ByDownstream = &req.Filter.ByDownstream.Value + } + } + + count, err := s.ds.CountRegistrationEntries(ctx, countReq) if err != nil { log := rpccontext.Logger(ctx) return nil, api.MakeErr(log, codes.Internal, "failed to count entries", err) @@ -139,6 +201,10 @@ func (s *Service) ListEntries(ctx context.Context, req *entryv1.ListEntriesReque TrustDomains: trustDomains, } } + + if req.Filter.ByDownstream != nil { + listReq.ByDownstream = &req.Filter.ByDownstream.Value + } } dsResp, err := s.ds.ListRegistrationEntries(ctx, listReq) @@ -725,6 +791,46 @@ func fieldsFromListEntryFilter(ctx context.Context, td spiffeid.TrustDomain, fil fields[telemetry.FederatesWith] = strings.Join(filter.ByFederatesWith.TrustDomains, ",") } + if filter.ByDownstream != nil { + fields[telemetry.Downstream] = &filter.ByDownstream.Value + } + + return fields +} + +func fieldsFromCountEntryFilter(ctx context.Context, td spiffeid.TrustDomain, filter *entryv1.CountEntriesRequest_Filter) logrus.Fields { + fields := logrus.Fields{} + + if filter.ByHint != nil { + fields[telemetry.Hint] = filter.ByHint.Value + } + + if filter.ByParentId != nil { + if parentID, err := api.TrustDomainMemberIDFromProto(ctx, td, filter.ByParentId); err == nil { + fields[telemetry.ParentID] = parentID.String() + } + } + + if filter.BySpiffeId != nil { + if id, err := api.TrustDomainWorkloadIDFromProto(ctx, td, filter.BySpiffeId); err == nil { + fields[telemetry.SPIFFEID] = id.String() + } + } + + if filter.BySelectors != nil { + fields[telemetry.BySelectorMatch] = filter.BySelectors.Match.String() + fields[telemetry.BySelectors] = api.SelectorFieldFromProto(filter.BySelectors.Selectors) + } + + if filter.ByFederatesWith != nil { + fields[telemetry.FederatesWithMatch] = filter.ByFederatesWith.Match.String() + fields[telemetry.FederatesWith] = strings.Join(filter.ByFederatesWith.TrustDomains, ",") + } + + if filter.ByDownstream != nil { + fields[telemetry.Downstream] = &filter.ByDownstream.Value + } + return fields } diff --git a/pkg/server/datastore/datastore.go b/pkg/server/datastore/datastore.go index db96a1976a0..4b81e85cf52 100644 --- a/pkg/server/datastore/datastore.go +++ b/pkg/server/datastore/datastore.go @@ -31,7 +31,7 @@ type DataStore interface { RevokeJWTKey(ctx context.Context, trustDomainID string, authorityID string) (*common.PublicKey, error) // Entries - CountRegistrationEntries(context.Context) (int32, error) + CountRegistrationEntries(context.Context, *CountRegistrationEntriesRequest) (int32, error) CreateRegistrationEntry(context.Context, *common.RegistrationEntry) (*common.RegistrationEntry, error) CreateOrReturnRegistrationEntry(context.Context, *common.RegistrationEntry) (*common.RegistrationEntry, bool, error) DeleteRegistrationEntry(ctx context.Context, entryID string) (*common.RegistrationEntry, error) @@ -46,7 +46,7 @@ type DataStore interface { GetLatestRegistrationEntryEventID(ctx context.Context) (uint, error) // Nodes - CountAttestedNodes(context.Context) (int32, error) + CountAttestedNodes(context.Context, *CountAttestedNodesRequest) (int32, error) CreateAttestedNode(context.Context, *common.AttestedNode) (*common.AttestedNode, error) DeleteAttestedNode(ctx context.Context, spiffeID string) (*common.AttestedNode, error) FetchAttestedNode(ctx context.Context, spiffeID string) (*common.AttestedNode, error) @@ -206,6 +206,7 @@ type ListRegistrationEntriesRequest struct { Pagination *Pagination ByFederatesWith *ByFederatesWith ByHint string + ByDownstream *bool } type CAJournal struct { @@ -242,6 +243,25 @@ type ListFederationRelationshipsResponse struct { Pagination *Pagination } +type CountAttestedNodesRequest struct { + ByAttestationType string + ByBanned *bool + ByExpiresBefore time.Time + BySelectorMatch *BySelectors + FetchSelectors bool + ByCanReattest *bool +} + +type CountRegistrationEntriesRequest struct { + DataConsistency DataConsistency + ByParentID string + BySelectors *BySelectors + BySpiffeID string + ByFederatesWith *ByFederatesWith + ByHint string + ByDownstream *bool +} + type BundleEndpointType string const ( diff --git a/pkg/server/datastore/sqlstore/sqlstore.go b/pkg/server/datastore/sqlstore/sqlstore.go index 5cb6ef3b46d..096a59717dc 100644 --- a/pkg/server/datastore/sqlstore/sqlstore.go +++ b/pkg/server/datastore/sqlstore/sqlstore.go @@ -310,7 +310,11 @@ func (ds *Plugin) FetchAttestedNode(ctx context.Context, spiffeID string) (attes } // CountAttestedNodes counts all attested nodes -func (ds *Plugin) CountAttestedNodes(ctx context.Context) (count int32, err error) { +func (ds *Plugin) CountAttestedNodes(ctx context.Context, req *datastore.CountAttestedNodesRequest) (count int32, err error) { + if countAttestedNodesHasFilters(req) { + resp, err := countAttestedNodesWithFilters(ctx, ds.db, ds.log, req) + return resp, err + } if err = ds.withReadTx(ctx, func(tx *gorm.DB) (err error) { count, err = countAttestedNodes(tx) return err @@ -474,15 +478,14 @@ func (ds *Plugin) FetchRegistrationEntry(ctx context.Context, } // CountRegistrationEntries counts all registrations (pagination available) -func (ds *Plugin) CountRegistrationEntries(ctx context.Context) (count int32, err error) { - if err = ds.withReadTx(ctx, func(tx *gorm.DB) (err error) { - count, err = countRegistrationEntries(tx) - return err - }); err != nil { - return 0, err +func (ds *Plugin) CountRegistrationEntries(ctx context.Context, req *datastore.CountRegistrationEntriesRequest) (count int32, err error) { + var actDb = ds.db + if req.DataConsistency == datastore.TolerateStale && ds.roDb != nil { + actDb = ds.roDb } - return count, nil + resp, err := countRegistrationEntries(ctx, actDb, ds.log, req) + return resp, err } // ListRegistrationEntries lists all registrations (pagination available) @@ -1550,6 +1553,16 @@ func countAttestedNodes(tx *gorm.DB) (int32, error) { return int32(count), nil } +func countAttestedNodesHasFilters(req *datastore.CountAttestedNodesRequest) bool { + if req.ByAttestationType != "" || req.ByBanned != nil || !req.ByExpiresBefore.IsZero() { + return true + } + if req.BySelectorMatch != nil || !req.FetchSelectors || req.ByCanReattest != nil { + return true + } + return false +} + func listAttestedNodes(ctx context.Context, db *sqlDB, log logrus.FieldLogger, req *datastore.ListAttestedNodesRequest) (*datastore.ListAttestedNodesResponse, error) { if req.Pagination != nil && req.Pagination.PageSize == 0 { return nil, status.Error(codes.InvalidArgument, "cannot paginate with pagesize = 0") @@ -1600,6 +1613,48 @@ func listAttestedNodes(ctx context.Context, db *sqlDB, log logrus.FieldLogger, r } } +func countAttestedNodesWithFilters(ctx context.Context, db *sqlDB, _ logrus.FieldLogger, req *datastore.CountAttestedNodesRequest) (int32, error) { + if req.BySelectorMatch != nil && len(req.BySelectorMatch.Selectors) == 0 { + return -1, status.Error(codes.InvalidArgument, "cannot list by empty selectors set") + } + + var val int32 + listReq := &datastore.ListAttestedNodesRequest{ + ByAttestationType: req.ByAttestationType, + ByBanned: req.ByBanned, + ByExpiresBefore: req.ByExpiresBefore, + BySelectorMatch: req.BySelectorMatch, + FetchSelectors: req.FetchSelectors, + ByCanReattest: req.ByCanReattest, + Pagination: &datastore.Pagination{ + Token: "", + PageSize: 1000, + }, + } + for { + resp, err := listAttestedNodesOnce(ctx, db, listReq) + if err != nil { + return -1, err + } + + if len(resp.Nodes) == 0 { + return val, nil + } + + if req.BySelectorMatch != nil { + switch req.BySelectorMatch.Match { + case datastore.Exact, datastore.Subset: + resp.Nodes = filterNodesBySelectorSet(resp.Nodes, req.BySelectorMatch.Selectors) + default: + } + } + + val += int32(len(resp.Nodes)) + + listReq.Pagination = resp.Pagination + } +} + func createAttestedNodeEvent(tx *gorm.DB, spiffeID string) error { newAttestedNodeEvent := AttestedNodeEvent{ SpiffeID: spiffeID, @@ -1740,7 +1795,6 @@ func listAttestedNodesOnce(ctx context.Context, db *sqlDB, req *datastore.ListAt resp.Pagination.Token = strconv.FormatUint(lastEID, 10) } } - return resp, nil } @@ -1798,7 +1852,6 @@ func buildListAttestedNodesQueryCTE(req *datastore.ListAttestedNodesRequest, dbT builder.WriteString("\t\tAND data_type = ?\n") args = append(args, req.ByAttestationType) } - // Filter by banned, an Attestation Node is banned when serial number is empty. // This filter allows 3 outputs: // - nil: returns all @@ -1811,8 +1864,11 @@ func buildListAttestedNodesQueryCTE(req *datastore.ListAttestedNodesRequest, dbT builder.WriteString("\t\tAND serial_number <> ''\n") } } - - // Filter by CanReattest. This is similar to ByBanned + // Filter by canReattest, + // This filter allows 3 outputs: + // - nil: returns all + // - true: returns nodes with canReattest=true + // - false: returns nodes with canReattest=false if req.ByCanReattest != nil { if *req.ByCanReattest { builder.WriteString("\t\tAND can_reattest = true\n") @@ -1960,7 +2016,6 @@ SELECT } builder.WriteString("\n) ORDER BY id ASC\n") - return builder.String(), args, nil } @@ -2654,15 +2709,6 @@ ORDER BY selector_id, dns_name_id return query, []any{entryID}, nil } -func countRegistrationEntries(tx *gorm.DB) (int32, error) { - var count int - if err := tx.Model(&RegisteredEntry{}).Count(&count).Error; err != nil { - return 0, sqlError.Wrap(err) - } - - return int32(count), nil -} - func listRegistrationEntries(ctx context.Context, db *sqlDB, log logrus.FieldLogger, req *datastore.ListRegistrationEntriesRequest) (*datastore.ListRegistrationEntriesResponse, error) { if req.Pagination != nil && req.Pagination.PageSize == 0 { return nil, status.Error(codes.InvalidArgument, "cannot paginate with pagesize = 0") @@ -2757,7 +2803,6 @@ func listRegistrationEntriesOnce(ctx context.Context, db queryContext, databaseT return nil, sqlError.Wrap(err) } defer rows.Close() - var entries []*common.RegistrationEntry if req.Pagination != nil { entries = make([]*common.RegistrationEntry, 0, req.Pagination.PageSize) @@ -2840,8 +2885,12 @@ func buildListRegistrationEntriesQuery(dbType string, supportsCTE bool, req *dat func buildListRegistrationEntriesQuerySQLite3(req *datastore.ListRegistrationEntriesRequest) (string, []any, error) { builder := new(strings.Builder) - filtered, args, err := appendListRegistrationEntriesFilterQuery("\nWITH listing AS (\n", builder, SQLite, req) + var downstream = false + if req.ByDownstream != nil { + downstream = *req.ByDownstream + } + if err != nil { return "", nil, err } @@ -2873,9 +2922,17 @@ SELECT FROM registered_entries `) + if filtered { builder.WriteString("WHERE id IN (SELECT e_id FROM listing)\n") } + if downstream { + if !filtered { + builder.WriteString("\t\tWHERE downstream = true\n") + } else { + builder.WriteString("\t\tAND downstream = true\n") + } + } builder.WriteString(` UNION @@ -2924,6 +2981,11 @@ func buildListRegistrationEntriesQueryPostgreSQL(req *datastore.ListRegistration builder := new(strings.Builder) filtered, args, err := appendListRegistrationEntriesFilterQuery("\nWITH listing AS (\n", builder, PostgreSQL, req) + var downstream = false + if req.ByDownstream != nil { + downstream = *req.ByDownstream + } + if err != nil { return "", nil, err } @@ -2958,6 +3020,13 @@ FROM if filtered { builder.WriteString("WHERE id IN (SELECT e_id FROM listing)\n") } + if downstream { + if !filtered { + builder.WriteString("\t\tWHERE downstream = true\n") + } else { + builder.WriteString("\t\tAND downstream = true\n") + } + } builder.WriteString(` UNION ALL @@ -3051,6 +3120,11 @@ LEFT JOIN `) filtered, args, err := appendListRegistrationEntriesFilterQuery("WHERE E.id IN (\n", builder, MySQL, req) + var downstream = false + if req.ByDownstream != nil { + downstream = *req.ByDownstream + } + if err != nil { return "", nil, err } @@ -3058,7 +3132,13 @@ LEFT JOIN if filtered { builder.WriteString(")") } - + if downstream { + if !filtered { + builder.WriteString("\t\tWHERE downstream = true\n") + } else { + builder.WriteString("\t\tAND downstream = true\n") + } + } builder.WriteString("\nORDER BY e_id, selector_id, dns_name_id\n;") return builder.String(), args, nil @@ -3068,6 +3148,11 @@ func buildListRegistrationEntriesQueryMySQLCTE(req *datastore.ListRegistrationEn builder := new(strings.Builder) filtered, args, err := appendListRegistrationEntriesFilterQuery("\nWITH listing AS (\n", builder, MySQL, req) + var downstream = false + if req.ByDownstream != nil { + downstream = *req.ByDownstream + } + if err != nil { return "", nil, err } @@ -3102,6 +3187,13 @@ FROM if filtered { builder.WriteString("WHERE id IN (SELECT e_id FROM listing)\n") } + if downstream { + if !filtered { + builder.WriteString("\t\tWHERE downstream = true\n") + } else { + builder.WriteString("\t\tAND downstream = true\n") + } + } builder.WriteString(` UNION @@ -3146,6 +3238,52 @@ ORDER BY e_id, selector_id, dns_name_id return builder.String(), args, nil } +// Count Registration Entries +func countRegistrationEntries(ctx context.Context, db *sqlDB, _ logrus.FieldLogger, req *datastore.CountRegistrationEntriesRequest) (int32, error) { + if req.BySelectors != nil && len(req.BySelectors.Selectors) == 0 { + return 0, status.Error(codes.InvalidArgument, "cannot list by empty selector set") + } + + var val int32 + listReq := &datastore.ListRegistrationEntriesRequest{ + DataConsistency: req.DataConsistency, + ByParentID: req.ByParentID, + BySelectors: req.BySelectors, + BySpiffeID: req.BySpiffeID, + ByFederatesWith: req.ByFederatesWith, + ByHint: req.ByHint, + ByDownstream: req.ByDownstream, + Pagination: &datastore.Pagination{ + Token: "", + PageSize: 1000, + }, + } + + for { + resp, err := listRegistrationEntriesOnce(ctx, db.raw, db.databaseType, db.supportsCTE, listReq) + + if err != nil { + return -1, err + } + + if len(resp.Entries) == 0 { + return val, nil + } + + if req.BySelectors != nil { + switch req.BySelectors.Match { + case datastore.Exact, datastore.Subset: + resp.Entries = filterEntriesBySelectorSet(resp.Entries, req.BySelectors.Selectors) + default: + } + } + + val += int32(len(resp.Entries)) + + listReq.Pagination = resp.Pagination + } +} + type idFilterNode struct { idColumn string diff --git a/pkg/server/datastore/sqlstore/sqlstore_test.go b/pkg/server/datastore/sqlstore/sqlstore_test.go index 482a2938809..a600dc3c36d 100644 --- a/pkg/server/datastore/sqlstore/sqlstore_test.go +++ b/pkg/server/datastore/sqlstore/sqlstore_test.go @@ -507,7 +507,7 @@ func (s *PluginSuite) TestCountBundles() { func (s *PluginSuite) TestCountAttestedNodes() { // Count empty attested nodes - count, err := s.ds.CountAttestedNodes(ctx) + count, err := s.ds.CountAttestedNodes(ctx, &datastore.CountAttestedNodesRequest{}) s.Require().NoError(err) s.Require().Equal(int32(0), count) @@ -531,14 +531,14 @@ func (s *PluginSuite) TestCountAttestedNodes() { s.Require().NoError(err) // Count all - count, err = s.ds.CountAttestedNodes(ctx) + count, err = s.ds.CountAttestedNodes(ctx, &datastore.CountAttestedNodesRequest{}) s.Require().NoError(err) s.Require().Equal(int32(2), count) } func (s *PluginSuite) TestCountRegistrationEntries() { // Count empty registration entries - count, err := s.ds.CountRegistrationEntries(ctx) + count, err := s.ds.CountRegistrationEntries(ctx, &datastore.CountRegistrationEntriesRequest{}) s.Require().NoError(err) s.Require().Equal(int32(0), count) @@ -560,7 +560,7 @@ func (s *PluginSuite) TestCountRegistrationEntries() { s.Require().NoError(err) // Count all - count, err = s.ds.CountRegistrationEntries(ctx) + count, err = s.ds.CountRegistrationEntries(ctx, &datastore.CountRegistrationEntriesRequest{}) s.Require().NoError(err) s.Require().Equal(int32(2), count) } diff --git a/test/fakes/fakedatastore/fakedatastore.go b/test/fakes/fakedatastore/fakedatastore.go index b0f8d89440e..404958983ec 100644 --- a/test/fakes/fakedatastore/fakedatastore.go +++ b/test/fakes/fakedatastore/fakedatastore.go @@ -121,11 +121,11 @@ func (s *DataStore) PruneBundle(ctx context.Context, trustDomainID string, expir return s.ds.PruneBundle(ctx, trustDomainID, expiresBefore) } -func (s *DataStore) CountAttestedNodes(ctx context.Context) (int32, error) { +func (s *DataStore) CountAttestedNodes(ctx context.Context, req *datastore.CountAttestedNodesRequest) (int32, error) { if err := s.getNextError(); err != nil { return 0, err } - return s.ds.CountAttestedNodes(ctx) + return s.ds.CountAttestedNodes(ctx, req) } func (s *DataStore) CreateAttestedNode(ctx context.Context, node *common.AttestedNode) (*common.AttestedNode, error) { @@ -238,11 +238,11 @@ func (s *DataStore) GetNodeSelectors(ctx context.Context, spiffeID string, dataC return selectors, err } -func (s *DataStore) CountRegistrationEntries(ctx context.Context) (int32, error) { +func (s *DataStore) CountRegistrationEntries(ctx context.Context, req *datastore.CountRegistrationEntriesRequest) (int32, error) { if err := s.getNextError(); err != nil { return 0, err } - return s.ds.CountRegistrationEntries(ctx) + return s.ds.CountRegistrationEntries(ctx, req) } func (s *DataStore) CreateRegistrationEntry(ctx context.Context, entry *common.RegistrationEntry) (*common.RegistrationEntry, error) {