Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 95 additions & 43 deletions api/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2488,7 +2488,7 @@ func (c *Client) StreamSessionEvents(ctx context.Context, sessionID string, star
oneOf, err := stream.Recv()
if err != nil {
if err != io.EOF {
e <- trace.Wrap(trace.Wrap(err))
e <- trace.Wrap(err)
} else {
close(ch)
}
Expand All @@ -2498,7 +2498,7 @@ func (c *Client) StreamSessionEvents(ctx context.Context, sessionID string, star

event, err := events.FromOneOf(*oneOf)
if err != nil {
e <- trace.Wrap(trace.Wrap(err))
e <- trace.Wrap(err)
break outer
}

Expand Down Expand Up @@ -2593,7 +2593,7 @@ func (c *Client) StreamUnstructuredSessionEvents(ctx context.Context, sessionID
// on the client grpc side.
c.streamUnstructuredSessionEventsFallback(ctx, sessionID, startIndex, ch, e)
} else {
e <- trace.Wrap(trace.Wrap(err))
e <- trace.Wrap(err)
}
return ch, e
}
Expand All @@ -2616,7 +2616,7 @@ func (c *Client) StreamUnstructuredSessionEvents(ctx context.Context, sessionID
go c.streamUnstructuredSessionEventsFallback(ctx, sessionID, startIndex, ch, e)
return
}
e <- trace.Wrap(trace.Wrap(err))
e <- trace.Wrap(err)
} else {
close(ch)
}
Expand Down Expand Up @@ -2663,7 +2663,7 @@ func (c *Client) streamUnstructuredSessionEventsFallback(ctx context.Context, se
oneOf, err := stream.Recv()
if err != nil {
if err != io.EOF {
e <- trace.Wrap(trace.Wrap(err))
e <- trace.Wrap(err)
} else {
close(ch)
}
Expand All @@ -2673,7 +2673,7 @@ func (c *Client) streamUnstructuredSessionEventsFallback(ctx context.Context, se

event, err := events.FromOneOf(*oneOf)
if err != nil {
e <- trace.Wrap(trace.Wrap(err))
e <- trace.Wrap(err)
return
}

Expand Down Expand Up @@ -3753,52 +3753,41 @@ type ResourcePage[T types.ResourceWithLabels] struct {
NextKey string
}

// getResourceFromProtoPage extracts the resource from the PaginatedResource returned
// from the rpc ListUnifiedResources
func getResourceFromProtoPage(resource *proto.PaginatedResource) (types.ResourceWithLabels, error) {
var out types.ResourceWithLabels
// convertEnrichedResource extracts the resource and any enriched information from the
// PaginatedResource returned from the rpc ListUnifiedResources.
func convertEnrichedResource(resource *proto.PaginatedResource) (*types.EnrichedResource, error) {
if r := resource.GetNode(); r != nil {
out = r
return out, nil
return &types.EnrichedResource{ResourceWithLabels: r, Logins: resource.Logins}, nil
} else if r := resource.GetDatabaseServer(); r != nil {
out = r
return out, nil
return &types.EnrichedResource{ResourceWithLabels: r}, nil
} else if r := resource.GetDatabaseService(); r != nil {
out = r
return out, nil
return &types.EnrichedResource{ResourceWithLabels: r}, nil
} else if r := resource.GetAppServerOrSAMLIdPServiceProvider(); r != nil {
out = r
return out, nil
return &types.EnrichedResource{ResourceWithLabels: r}, nil
} else if r := resource.GetWindowsDesktop(); r != nil {
out = r
return out, nil
return &types.EnrichedResource{ResourceWithLabels: r}, nil
} else if r := resource.GetWindowsDesktopService(); r != nil {
out = r
return out, nil
return &types.EnrichedResource{ResourceWithLabels: r}, nil
} else if r := resource.GetKubeCluster(); r != nil {
out = r
return out, nil
return &types.EnrichedResource{ResourceWithLabels: r}, nil
} else if r := resource.GetKubernetesServer(); r != nil {
out = r
return out, nil
return &types.EnrichedResource{ResourceWithLabels: r}, nil
} else if r := resource.GetUserGroup(); r != nil {
out = r
return out, nil
return &types.EnrichedResource{ResourceWithLabels: r}, nil
} else if r := resource.GetAppServer(); r != nil {
out = r
return out, nil
return &types.EnrichedResource{ResourceWithLabels: r}, nil
} else {
return nil, trace.BadParameter("received unsupported resource %T", resource.Resource)
}
}

// ListUnifiedResourcePage is a helper for getting a single page of unified resources that match the provided request.
func ListUnifiedResourcePage(ctx context.Context, clt ListUnifiedResourcesClient, req *proto.ListUnifiedResourcesRequest) (ResourcePage[types.ResourceWithLabels], error) {
var out ResourcePage[types.ResourceWithLabels]
// GetUnifiedResourcePage is a helper for getting a single page of unified resources that match the provided request.
func GetUnifiedResourcePage(ctx context.Context, clt ListUnifiedResourcesClient, req *proto.ListUnifiedResourcesRequest) ([]*types.EnrichedResource, string, error) {
var out []*types.EnrichedResource

// Set the limit to the default size if one was not provided within
// an acceptable range.
if req.Limit == 0 || req.Limit > int32(defaults.DefaultChunkSize) {
if req.Limit <= 0 || req.Limit > int32(defaults.DefaultChunkSize) {
req.Limit = int32(defaults.DefaultChunkSize)
}

Expand All @@ -3810,24 +3799,87 @@ func ListUnifiedResourcePage(ctx context.Context, clt ListUnifiedResourcesClient
req.Limit /= 2
// This is an extremely unlikely scenario, but better to cover it anyways.
if req.Limit == 0 {
return out, trace.Wrap(trace.Wrap(err), "resource is too large to retrieve")
return nil, "", trace.Wrap(err, "resource is too large to retrieve")
}

continue
}

return out, trace.Wrap(err)
return nil, "", trace.Wrap(err)
}

for _, respResource := range resp.Resources {
resource, err := getResourceFromProtoPage(respResource)
resource, err := convertEnrichedResource(respResource)
if err != nil {
return out, trace.Wrap(err)
return nil, "", trace.Wrap(err)
}
out = append(out, resource)
}

return out, resp.NextKey, nil
}
}

// GetEnrichedResourcePage is a helper for getting a single page of enriched resources.
func GetEnrichedResourcePage(ctx context.Context, clt GetResourcesClient, req *proto.ListResourcesRequest) (ResourcePage[*types.EnrichedResource], error) {
var out ResourcePage[*types.EnrichedResource]

// Set the limit to the default size if one was not provided within
// an acceptable range.
if req.Limit <= 0 || req.Limit > int32(defaults.DefaultChunkSize) {
req.Limit = int32(defaults.DefaultChunkSize)
}

for {
resp, err := clt.GetResources(ctx, req)
if err != nil {
if trace.IsLimitExceeded(err) {
// Cut chunkSize in half if gRPC max message size is exceeded.
req.Limit /= 2
// This is an extremely unlikely scenario, but better to cover it anyways.
if req.Limit == 0 {
return out, trace.Wrap(err, "resource is too large to retrieve")
}

continue
}
out.Resources = append(out.Resources, resource)

return out, trace.Wrap(err)
}

for _, respResource := range resp.Resources {
var resource types.ResourceWithLabels
switch req.ResourceType {
case types.KindDatabaseServer:
resource = respResource.GetDatabaseServer()
case types.KindDatabaseService:
resource = respResource.GetDatabaseService()
case types.KindAppServer:
resource = respResource.GetAppServer()
case types.KindNode:
resource = respResource.GetNode()
case types.KindWindowsDesktop:
resource = respResource.GetWindowsDesktop()
case types.KindWindowsDesktopService:
resource = respResource.GetWindowsDesktopService()
case types.KindKubernetesCluster:
resource = respResource.GetKubeCluster()
case types.KindKubeServer:
resource = respResource.GetKubernetesServer()
case types.KindUserGroup:
resource = respResource.GetUserGroup()
case types.KindAppOrSAMLIdPServiceProvider:
resource = respResource.GetAppServerOrSAMLIdPServiceProvider()
default:
out.Resources = nil
return out, trace.NotImplemented("resource type %s does not support pagination", req.ResourceType)
}

out.Resources = append(out.Resources, &types.EnrichedResource{ResourceWithLabels: resource, Logins: respResource.Logins})
}

out.NextKey = resp.NextKey
out.Total = int(resp.TotalCount)

return out, nil
}
Expand All @@ -3839,7 +3891,7 @@ func GetResourcePage[T types.ResourceWithLabels](ctx context.Context, clt GetRes

// Set the limit to the default size if one was not provided within
// an acceptable range.
if req.Limit == 0 || req.Limit > int32(defaults.DefaultChunkSize) {
if req.Limit <= 0 || req.Limit > int32(defaults.DefaultChunkSize) {
req.Limit = int32(defaults.DefaultChunkSize)
}

Expand All @@ -3851,7 +3903,7 @@ func GetResourcePage[T types.ResourceWithLabels](ctx context.Context, clt GetRes
req.Limit /= 2
// This is an extremely unlikely scenario, but better to cover it anyways.
if req.Limit == 0 {
return out, trace.Wrap(trace.Wrap(err), "resource is too large to retrieve")
return out, trace.Wrap(err, "resource is too large to retrieve")
}

continue
Expand Down Expand Up @@ -3964,7 +4016,7 @@ func GetResourcesWithFilters(ctx context.Context, clt ListResourcesClient, req p
chunkSize = chunkSize / 2
// This is an extremely unlikely scenario, but better to cover it anyways.
if chunkSize == 0 {
return nil, trace.Wrap(trace.Wrap(err), "resource is too large to retrieve")
return nil, trace.Wrap(err, "resource is too large to retrieve")
}

continue
Expand Down Expand Up @@ -4015,7 +4067,7 @@ func GetKubernetesResourcesWithFilters(ctx context.Context, clt kubeproto.KubeSe
chunkSize = chunkSize / 2
// This is an extremely unlikely scenario, but better to cover it anyways.
if chunkSize == 0 {
return nil, trace.Wrap(trace.Wrap(err), "resource is too large to retrieve")
return nil, trace.Wrap(err, "resource is too large to retrieve")
}
continue
}
Expand Down
Loading