Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions integration/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ import (
"github.com/gravitational/teleport/integration/helpers"
"github.com/gravitational/teleport/integration/kube"
"github.com/gravitational/teleport/lib"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/auth/testauthority"
libclient "github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/defaults"
Expand Down Expand Up @@ -1522,11 +1521,11 @@ func TestALPNSNIProxyGRPCInsecure(t *testing.T) {
suite := newSuite(t,
withRootClusterConfig(rootClusterStandardConfig(t), func(config *servicecfg.Config) {
config.Auth.BootstrapResources = []types.Resource{provisionToken}
config.Auth.ServerOptions = []auth.ServerOption{auth.WithHTTPClientForAWSSTS(fakeSTSClient{
config.Auth.HTTPClientForAWSSTS = fakeSTSClient{
accountID: nodeAccount,
arn: nodeRoleARN,
credentials: nodeCredentials,
})}
}
}),
withLeafClusterConfig(leafClusterStandardConfig(t)),
)
Expand Down
48 changes: 20 additions & 28 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,24 +281,25 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {

closeCtx, cancelFunc := context.WithCancel(context.TODO())
as := Server{
bk: cfg.Backend,
clock: cfg.Clock,
limiter: limiter,
Authority: cfg.Authority,
AuthServiceName: cfg.AuthServiceName,
ServerID: cfg.HostUUID,
githubClients: make(map[string]*githubClient),
cancelFunc: cancelFunc,
closeCtx: closeCtx,
emitter: cfg.Emitter,
streamer: cfg.Streamer,
Unstable: local.NewUnstableService(cfg.Backend, cfg.AssertionReplayService),
Services: services,
Cache: services,
keyStore: keyStore,
traceClient: cfg.TraceClient,
fips: cfg.FIPS,
loadAllCAs: cfg.LoadAllCAs,
bk: cfg.Backend,
clock: cfg.Clock,
limiter: limiter,
Authority: cfg.Authority,
AuthServiceName: cfg.AuthServiceName,
ServerID: cfg.HostUUID,
githubClients: make(map[string]*githubClient),
cancelFunc: cancelFunc,
closeCtx: closeCtx,
emitter: cfg.Emitter,
streamer: cfg.Streamer,
Unstable: local.NewUnstableService(cfg.Backend, cfg.AssertionReplayService),
Services: services,
Cache: services,
keyStore: keyStore,
traceClient: cfg.TraceClient,
fips: cfg.FIPS,
loadAllCAs: cfg.LoadAllCAs,
httpClientForAWSSTS: cfg.HTTPClientForAWSSTS,
}
as.inventory = inventory.NewController(&as, services, inventory.WithAuthServerID(cfg.HostUUID))
for _, o := range opts {
Expand Down Expand Up @@ -595,7 +596,7 @@ type Server struct {

// httpClientForAWSSTS overwrites the default HTTP client used for making
// STS requests.
httpClientForAWSSTS stsClient
httpClientForAWSSTS utils.HTTPDoClient
}

// SetSAMLService registers svc as the SAMLService that provides the SAML
Expand Down Expand Up @@ -5411,12 +5412,3 @@ func DefaultDNSNamesForRole(role types.SystemRole) []string {
}
return nil
}

// WithHTTPClientForAWSSTS is a ServerOption that overwrites default HTTP
// client used for STS requests.
func WithHTTPClientForAWSSTS(client stsClient) ServerOption {
return func(s *Server) error {
s.httpClientForAWSSTS = client
return nil
}
}
4 changes: 4 additions & 0 deletions lib/auth/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,10 @@ type InitConfig struct {
// Clock is the clock instance auth uses. Typically you'd only want to set
// this during testing.
Clock clockwork.Clock

// HTTPClientForAWSSTS overwrites the default HTTP client used for making
// STS requests. Used in test.
HTTPClientForAWSSTS utils.HTTPDoClient
}

// Init instantiates and configures an instance of AuthServer
Expand Down
6 changes: 1 addition & 5 deletions lib/auth/join_iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,9 @@ type stsIdentityResponse struct {
GetCallerIdentityResponse getCallerIdentityResponse `json:"GetCallerIdentityResponse"`
}

type stsClient interface {
Do(*http.Request) (*http.Response, error)
}

// executeSTSIdentityRequest sends the sts:GetCallerIdentity HTTP request to the
// AWS API, parses the response, and returns the awsIdentity
func executeSTSIdentityRequest(ctx context.Context, client stsClient, req *http.Request) (*awsIdentity, error) {
func executeSTSIdentityRequest(ctx context.Context, client utils.HTTPDoClient, req *http.Request) (*awsIdentity, error) {
if client == nil {
client = http.DefaultClient
}
Expand Down
3 changes: 2 additions & 1 deletion lib/auth/join_iam_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/auth/testauthority"
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/utils"
)

func responseFromAWSIdentity(id awsIdentity) string {
Expand Down Expand Up @@ -134,7 +135,7 @@ func TestAuth_RegisterUsingIAMMethod(t *testing.T) {
tokenName string
requestTokenName string
tokenSpec types.ProvisionTokenSpecV2
stsClient stsClient
stsClient utils.HTTPDoClient
iamRegisterOptions []iamRegisterOption
challengeResponseOptions []challengeResponseOption
challengeResponseErr error
Expand Down
5 changes: 3 additions & 2 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -1636,7 +1636,8 @@ func (process *TeleportProcess) initAuthService() error {
FIPS: cfg.FIPS,
LoadAllCAs: cfg.Auth.LoadAllCAs,
Clock: cfg.Clock,
}, append(cfg.Auth.ServerOptions, func(as *auth.Server) error {
HTTPClientForAWSSTS: cfg.Auth.HTTPClientForAWSSTS,
}, func(as *auth.Server) error {
if !process.Config.CachePolicy.Enabled {
return nil
}
Expand All @@ -1654,7 +1655,7 @@ func (process *TeleportProcess) initAuthService() error {
as.Cache = cache

return nil
})...)
})
if err != nil {
return trace.Wrap(err)
}
Expand Down
6 changes: 3 additions & 3 deletions lib/service/servicecfg/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"github.com/jonboulle/clockwork"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/auth/keystore"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/limiter"
Expand Down Expand Up @@ -103,8 +102,9 @@ type AuthConfig struct {
// this during testing.
Clock clockwork.Clock

// ServerOptions is a list of auth.Init options used in test.
ServerOptions []auth.ServerOption
// HTTPClientForAWSSTS overwrites the default HTTP client used for making
// STS requests. Used in test.
HTTPClientForAWSSTS utils.HTTPDoClient
}

// HostedPluginsConfig configures the hosted plugin runtime.
Expand Down
5 changes: 5 additions & 0 deletions lib/utils/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,8 @@ func GetAnyHeader(header http.Header, keys ...string) string {
}
return ""
}

// HTTPDoClient is an interface that defines the Do function of http.Client.
type HTTPDoClient interface {
Do(req *http.Request) (*http.Response, error)
}