diff --git a/extension/asapauthextension/extension.go b/extension/asapauthextension/extension.go index 6f53324d37415..a335fc35b2d81 100644 --- a/extension/asapauthextension/extension.go +++ b/extension/asapauthextension/extension.go @@ -10,10 +10,31 @@ import ( "bitbucket.org/atlassian/go-asap/v2" "github.com/SermoDigital/jose/crypto" + "go.opentelemetry.io/collector/component" "go.opentelemetry.io/collector/extension/extensionauth" "google.golang.org/grpc/credentials" ) +var _ extensionauth.Client = (*asapAuthExtension)(nil) + +type asapAuthExtension struct { + component.StartFunc + component.ShutdownFunc + + provisioner asap.Provisioner + privateKey any +} + +// PerRPCCredentials returns extensionauth.Client. +func (e *asapAuthExtension) PerRPCCredentials() (credentials.PerRPCCredentials, error) { + return &perRPCAuth{provisioner: e.provisioner, privateKey: e.privateKey}, nil +} + +// RoundTripper implements extensionauth.Client. +func (e *asapAuthExtension) RoundTripper(base http.RoundTripper) (http.RoundTripper, error) { + return asap.NewTransportDecorator(e.provisioner, e.privateKey)(base), nil +} + func createASAPClientAuthenticator(cfg *Config) (extensionauth.Client, error) { pk, err := asap.NewPrivateKey([]byte(cfg.PrivateKey)) if err != nil { @@ -24,14 +45,10 @@ func createASAPClientAuthenticator(cfg *Config) (extensionauth.Client, error) { p := asap.NewCachingProvisioner(asap.NewProvisioner( cfg.KeyID, cfg.TTL, cfg.Issuer, cfg.Audience, crypto.SigningMethodRS256)) - return extensionauth.NewClient( - extensionauth.WithClientRoundTripper(func(base http.RoundTripper) (http.RoundTripper, error) { - return asap.NewTransportDecorator(p, pk)(base), nil - }), - extensionauth.WithClientPerRPCCredentials(func() (credentials.PerRPCCredentials, error) { - return &perRPCAuth{provisioner: p, privateKey: pk}, nil - }), - ) + return &asapAuthExtension{ + provisioner: p, + privateKey: pk, + }, nil } // perRPCAuth is a gRPC credentials.PerRPCCredentials implementation that returns an 'authorization' header. diff --git a/extension/basicauthextension/extension.go b/extension/basicauthextension/extension.go index 7d0a2b78a767c..4809a355b0c91 100644 --- a/extension/basicauthextension/extension.go +++ b/extension/basicauthextension/extension.go @@ -27,20 +27,8 @@ var ( errInvalidFormat = errors.New("invalid authorization format") ) -type basicAuth struct { - htpasswd *HtpasswdSettings - clientAuth *ClientAuthSettings - matchFunc func(username, password string) bool -} - -func newClientAuthExtension(cfg *Config) (extensionauth.Client, error) { - ba := basicAuth{ - clientAuth: cfg.ClientAuth, - } - return extensionauth.NewClient( - extensionauth.WithClientRoundTripper(ba.roundTripper), - extensionauth.WithClientPerRPCCredentials(ba.perRPCCredentials), - ) +func newClientAuthExtension(cfg *Config) extensionauth.Client { + return &basicAuthClient{clientAuth: cfg.ClientAuth} } func newServerAuthExtension(cfg *Config) (extensionauth.Server, error) { @@ -48,16 +36,20 @@ func newServerAuthExtension(cfg *Config) (extensionauth.Server, error) { return nil, errNoCredentialSource } - ba := basicAuth{ + return &basicAuthServer{ htpasswd: cfg.Htpasswd, - } - return extensionauth.NewServer( - extensionauth.WithServerStart(ba.serverStart), - extensionauth.WithServerAuthenticate(ba.authenticate), - ) + }, nil } -func (ba *basicAuth) serverStart(_ context.Context, _ component.Host) error { +var _ extensionauth.Server = (*basicAuthServer)(nil) + +type basicAuthServer struct { + htpasswd *HtpasswdSettings + matchFunc func(username, password string) bool + component.ShutdownFunc +} + +func (ba *basicAuthServer) Start(_ context.Context, _ component.Host) error { var rs []io.Reader if ba.htpasswd.File != "" { @@ -86,7 +78,7 @@ func (ba *basicAuth) serverStart(_ context.Context, _ component.Host) error { return nil } -func (ba *basicAuth) authenticate(ctx context.Context, headers map[string][]string) (context.Context, error) { +func (ba *basicAuthServer) Authenticate(ctx context.Context, headers map[string][]string) (context.Context, error) { auth := getAuthHeader(headers) if auth == "" { return ctx, errNoAuth @@ -209,7 +201,16 @@ func (b *basicAuthRoundTripper) RoundTrip(request *http.Request) (*http.Response return b.base.RoundTrip(newRequest) } -func (ba *basicAuth) roundTripper(base http.RoundTripper) (http.RoundTripper, error) { +var _ extensionauth.Client = (*basicAuthClient)(nil) + +type basicAuthClient struct { + component.StartFunc + component.ShutdownFunc + + clientAuth *ClientAuthSettings +} + +func (ba *basicAuthClient) RoundTripper(base http.RoundTripper) (http.RoundTripper, error) { if strings.Contains(ba.clientAuth.Username, ":") { return nil, errInvalidFormat } @@ -219,7 +220,7 @@ func (ba *basicAuth) roundTripper(base http.RoundTripper) (http.RoundTripper, er }, nil } -func (ba *basicAuth) perRPCCredentials() (creds.PerRPCCredentials, error) { +func (ba *basicAuthClient) PerRPCCredentials() (creds.PerRPCCredentials, error) { if strings.Contains(ba.clientAuth.Username, ":") { return nil, errInvalidFormat } diff --git a/extension/basicauthextension/extension_test.go b/extension/basicauthextension/extension_test.go index 40ff4d2e17dd2..2dab5cc5b8c75 100644 --- a/extension/basicauthextension/extension_test.go +++ b/extension/basicauthextension/extension_test.go @@ -226,13 +226,12 @@ func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) } func TestBasicAuth_ClientValid(t *testing.T) { - ext, err := newClientAuthExtension(&Config{ + ext := newClientAuthExtension(&Config{ ClientAuth: &ClientAuthSettings{ Username: "username", Password: "password", }, }) - require.NoError(t, err) require.NotNil(t, ext) require.NoError(t, ext.Start(context.Background(), componenttest.NewNopHost())) @@ -273,19 +272,18 @@ func TestBasicAuth_ClientValid(t *testing.T) { func TestBasicAuth_ClientInvalid(t *testing.T) { t.Run("invalid username format", func(t *testing.T) { - ext, err := newClientAuthExtension(&Config{ + ext := newClientAuthExtension(&Config{ ClientAuth: &ClientAuthSettings{ Username: "user:name", Password: "password", }, }) - require.NoError(t, err) require.NotNil(t, ext) require.NoError(t, ext.Start(context.Background(), componenttest.NewNopHost())) base := &mockRoundTripper{} - _, err = ext.RoundTripper(base) + _, err := ext.RoundTripper(base) assert.Error(t, err) _, err = ext.PerRPCCredentials() diff --git a/extension/basicauthextension/factory.go b/extension/basicauthextension/factory.go index 0059d0d299faa..7e2750f56bcd0 100644 --- a/extension/basicauthextension/factory.go +++ b/extension/basicauthextension/factory.go @@ -31,5 +31,5 @@ func createExtension(_ context.Context, _ extension.Settings, cfg component.Conf if cfg.(*Config).Htpasswd != nil { return newServerAuthExtension(cfg.(*Config)) } - return newClientAuthExtension(cfg.(*Config)) + return newClientAuthExtension(cfg.(*Config)), nil } diff --git a/extension/headerssetterextension/extension.go b/extension/headerssetterextension/extension.go index f2d910fa100ee..c5d856404093b 100644 --- a/extension/headerssetterextension/extension.go +++ b/extension/headerssetterextension/extension.go @@ -9,6 +9,7 @@ import ( "fmt" "net/http" + "go.opentelemetry.io/collector/component" "go.opentelemetry.io/collector/extension/extensionauth" "go.uber.org/zap" "google.golang.org/grpc/credentials" @@ -22,6 +23,28 @@ type Header struct { source source.Source } +var _ extensionauth.Client = (*headerSetterExtension)(nil) + +type headerSetterExtension struct { + component.StartFunc + component.ShutdownFunc + + headers []Header +} + +// PerRPCCredentials implements extensionauth.Client. +func (h *headerSetterExtension) PerRPCCredentials() (credentials.PerRPCCredentials, error) { + return &headersPerRPC{headers: h.headers}, nil +} + +// RoundTripper implements extensionauth.Client. +func (h *headerSetterExtension) RoundTripper(base http.RoundTripper) (http.RoundTripper, error) { + return &headersRoundTripper{ + base: base, + headers: h.headers, + }, nil +} + func newHeadersSetterExtension(cfg *Config, logger *zap.Logger) (extensionauth.Client, error) { if cfg == nil { return nil, errors.New("extension configuration is not provided") @@ -63,18 +86,7 @@ func newHeadersSetterExtension(cfg *Config, logger *zap.Logger) (extensionauth.C headers = append(headers, Header{action: a, source: s}) } - return extensionauth.NewClient( - extensionauth.WithClientRoundTripper( - func(base http.RoundTripper) (http.RoundTripper, error) { - return &headersRoundTripper{ - base: base, - headers: headers, - }, nil - }), - extensionauth.WithClientPerRPCCredentials(func() (credentials.PerRPCCredentials, error) { - return &headersPerRPC{headers: headers}, nil - }), - ) + return &headerSetterExtension{headers: headers}, nil } // headersPerRPC is a gRPC credentials.PerRPCCredentials implementation sets diff --git a/extension/oauth2clientauthextension/extension.go b/extension/oauth2clientauthextension/extension.go index 6e82d0554abdc..4e3f1a26fa3fa 100644 --- a/extension/oauth2clientauthextension/extension.go +++ b/extension/oauth2clientauthextension/extension.go @@ -8,6 +8,8 @@ import ( "fmt" "net/http" + "go.opentelemetry.io/collector/component" + "go.opentelemetry.io/collector/extension/extensionauth" "go.uber.org/multierr" "go.uber.org/zap" "golang.org/x/oauth2" @@ -16,9 +18,14 @@ import ( grpcOAuth "google.golang.org/grpc/credentials/oauth" ) +var _ extensionauth.Client = (*clientAuthenticator)(nil) + // clientAuthenticator provides implementation for providing client authentication using OAuth2 client credentials // workflow for both gRPC and HTTP clients. type clientAuthenticator struct { + component.StartFunc + component.ShutdownFunc + clientCredentials *clientCredentialsConfig logger *zap.Logger client *http.Client @@ -75,9 +82,9 @@ func (ewts errorWrappingTokenSource) Token() (*oauth2.Token, error) { return tok, nil } -// roundTripper returns oauth2.Transport, an http.RoundTripper that performs "client-credential" OAuth flow and +// RoundTripper returns oauth2.Transport, an http.RoundTripper that performs "client-credential" OAuth flow and // also auto refreshes OAuth tokens as needed. -func (o *clientAuthenticator) roundTripper(base http.RoundTripper) (http.RoundTripper, error) { +func (o *clientAuthenticator) RoundTripper(base http.RoundTripper) (http.RoundTripper, error) { ctx := context.WithValue(context.Background(), oauth2.HTTPClient, o.client) return &oauth2.Transport{ Source: errorWrappingTokenSource{ @@ -88,9 +95,9 @@ func (o *clientAuthenticator) roundTripper(base http.RoundTripper) (http.RoundTr }, nil } -// perRPCCredentials returns gRPC PerRPCCredentials that supports "client-credential" OAuth flow. The underneath +// PerRPCCredentials returns gRPC PerRPCCredentials that supports "client-credential" OAuth flow. The underneath // oauth2.clientcredentials.Config instance will manage tokens performing auto refresh as necessary. -func (o *clientAuthenticator) perRPCCredentials() (credentials.PerRPCCredentials, error) { +func (o *clientAuthenticator) PerRPCCredentials() (credentials.PerRPCCredentials, error) { ctx := context.WithValue(context.Background(), oauth2.HTTPClient, o.client) return grpcOAuth.TokenSource{ TokenSource: errorWrappingTokenSource{ diff --git a/extension/oauth2clientauthextension/extension_test.go b/extension/oauth2clientauthextension/extension_test.go index 280d9cdbef5b3..5e58120962f31 100644 --- a/extension/oauth2clientauthextension/extension_test.go +++ b/extension/oauth2clientauthextension/extension_test.go @@ -222,7 +222,7 @@ func TestRoundTripper(t *testing.T) { } assert.NotNil(t, oauth2Authenticator) - roundTripper, err := oauth2Authenticator.roundTripper(baseRoundTripper) + roundTripper, err := oauth2Authenticator.RoundTripper(baseRoundTripper) assert.NoError(t, err) // test roundTripper is an OAuth RoundTripper @@ -266,7 +266,7 @@ func TestOAuth2PerRPCCredentials(t *testing.T) { return } assert.NoError(t, err) - perRPCCredentials, err := oauth2Authenticator.perRPCCredentials() + perRPCCredentials, err := oauth2Authenticator.PerRPCCredentials() assert.NoError(t, err) // test perRPCCredentials is an grpc OAuthTokenSource _, ok := perRPCCredentials.(grpcOAuth.TokenSource) @@ -294,7 +294,7 @@ func TestFailContactingOAuth(t *testing.T) { require.NoError(t, err) // Test for gRPC connections - credential, err := oauth2Authenticator.perRPCCredentials() + credential, err := oauth2Authenticator.PerRPCCredentials() require.NoError(t, err) _, err = credential.GetRequestMetadata(context.Background()) @@ -303,7 +303,7 @@ func TestFailContactingOAuth(t *testing.T) { transport := http.DefaultTransport.(*http.Transport).Clone() baseRoundTripper := (http.RoundTripper)(transport) - roundTripper, err := oauth2Authenticator.roundTripper(baseRoundTripper) + roundTripper, err := oauth2Authenticator.RoundTripper(baseRoundTripper) require.NoError(t, err) client := &http.Client{ diff --git a/extension/oauth2clientauthextension/factory.go b/extension/oauth2clientauthextension/factory.go index 3f0b2f0b1d87c..5cea012007f88 100644 --- a/extension/oauth2clientauthextension/factory.go +++ b/extension/oauth2clientauthextension/factory.go @@ -9,7 +9,6 @@ import ( "go.opentelemetry.io/collector/component" "go.opentelemetry.io/collector/extension" - "go.opentelemetry.io/collector/extension/extensionauth" "github.com/open-telemetry/opentelemetry-collector-contrib/extension/oauth2clientauthextension/internal/metadata" ) @@ -31,13 +30,5 @@ func createDefaultConfig() component.Config { } func createExtension(_ context.Context, set extension.Settings, cfg component.Config) (extension.Extension, error) { - ca, err := newClientAuthenticator(cfg.(*Config), set.Logger) - if err != nil { - return nil, err - } - - return extensionauth.NewClient( - extensionauth.WithClientRoundTripper(ca.roundTripper), - extensionauth.WithClientPerRPCCredentials(ca.perRPCCredentials), - ) + return newClientAuthenticator(cfg.(*Config), set.Logger) } diff --git a/extension/oidcauthextension/extension.go b/extension/oidcauthextension/extension.go index 90e7a9a2f15c6..e9f9601583fea 100644 --- a/extension/oidcauthextension/extension.go +++ b/extension/oidcauthextension/extension.go @@ -24,6 +24,8 @@ import ( "go.uber.org/zap" ) +var _ extensionauth.Server = (*oidcExtension)(nil) + type oidcExtension struct { cfg *Config @@ -45,23 +47,18 @@ var ( errNotAuthenticated = errors.New("authentication didn't succeed") ) -func newExtension(cfg *Config, logger *zap.Logger) (extensionauth.Server, error) { +func newExtension(cfg *Config, logger *zap.Logger) extensionauth.Server { if cfg.Attribute == "" { cfg.Attribute = defaultAttribute } - oe := &oidcExtension{ + return &oidcExtension{ cfg: cfg, logger: logger, } - return extensionauth.NewServer( - extensionauth.WithServerStart(oe.start), - extensionauth.WithServerAuthenticate(oe.authenticate), - extensionauth.WithServerShutdown(oe.shutdown), - ) } -func (e *oidcExtension) start(ctx context.Context, _ component.Host) error { +func (e *oidcExtension) Start(ctx context.Context, _ component.Host) error { err := e.setProviderConfig(ctx, e.cfg) if err != nil { return fmt.Errorf("failed to get configuration from the auth server: %w", err) @@ -72,7 +69,7 @@ func (e *oidcExtension) start(ctx context.Context, _ component.Host) error { return nil } -func (e *oidcExtension) shutdown(context.Context) error { +func (e *oidcExtension) Shutdown(context.Context) error { if e.client != nil { e.client.CloseIdleConnections() } @@ -84,7 +81,7 @@ func (e *oidcExtension) shutdown(context.Context) error { } // authenticate checks whether the given context contains valid auth data. Successfully authenticated calls will always return a nil error and a context with the auth data. -func (e *oidcExtension) authenticate(ctx context.Context, headers map[string][]string) (context.Context, error) { +func (e *oidcExtension) Authenticate(ctx context.Context, headers map[string][]string) (context.Context, error) { var authHeaders []string for k, v := range headers { if strings.EqualFold(k, e.cfg.Attribute) { diff --git a/extension/oidcauthextension/extension_test.go b/extension/oidcauthextension/extension_test.go index bc9ee7441eb4d..58d5d89b0a34e 100644 --- a/extension/oidcauthextension/extension_test.go +++ b/extension/oidcauthextension/extension_test.go @@ -27,9 +27,7 @@ import ( func newTestExtension(t *testing.T, cfg *Config) extensionauth.Server { t.Helper() - ext, err := newExtension(cfg, zap.NewNop()) - require.NoError(t, err) - return ext + return newExtension(cfg, zap.NewNop()) } func TestOIDCAuthenticationSucceeded(t *testing.T) { diff --git a/extension/oidcauthextension/factory.go b/extension/oidcauthextension/factory.go index b76430965bc93..4c0f1010113f2 100644 --- a/extension/oidcauthextension/factory.go +++ b/extension/oidcauthextension/factory.go @@ -33,5 +33,5 @@ func createDefaultConfig() component.Config { } func createExtension(_ context.Context, set extension.Settings, cfg component.Config) (extension.Extension, error) { - return newExtension(cfg.(*Config), set.Logger) + return newExtension(cfg.(*Config), set.Logger), nil }