Skip to content

Commit

Permalink
https://github.com/kedacore/keda/issues/2214
Browse files Browse the repository at this point in the history
  • Loading branch information
Siva Guruvareddiar committed Dec 24, 2023
1 parent 531dd6d commit 1cdd3c9
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 0 deletions.
73 changes: 73 additions & 0 deletions pkg/scalers/aws_sigv4.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package scalers

import (
"fmt"
"net/http"
"time"

"github.com/aws/aws-sdk-go/aws/credentials"
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
)

// SigV4Config configures signing requests with SigV4.
type SigV4Config struct {
Enabled bool `yaml:"enabled,omitempty"`
Region string `yaml:"region,omitempty"`
}

// Custom round tripper to sign requests
type roundTripper struct {
signer *v4.Signer
region string
}

func (rt *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {

// Sign request
rt.signer.Sign(req, nil, "aps", rt.region, time.Now())

// Create default transport
transport := &http.Transport{}

// Send signed request
return transport.RoundTrip(req)
}

// NewSigV4RoundTripper returns a new http.RoundTripper that will sign requests
// using Amazon's Signature Verification V4 signing procedure. The request will
// then be handed off to the next RoundTripper provided by next. If next is nil,
// http.DefaultTransport will be used.
//
// Credentials for signing are retrieving used the default AWS credential chain.
// If credentials could not be found, an error will be returned.
func NewSigV4RoundTripper(triggerMetadata map[string]string, next http.RoundTripper) (http.RoundTripper, error) {

if triggerMetadata == nil {
return nil, fmt.Errorf("trigger metadata cannot be nil")
}

awsRegion := triggerMetadata["awsRegion"]
if awsRegion == "" {
return nil, fmt.Errorf("awsRegion not configured in trigger metadata")
}

accessId := triggerMetadata["awsAccessId"]
if accessId == "" {
return nil, fmt.Errorf("accessId not configured in trigger metadata")
}

secretKey := triggerMetadata["awsSecretKey"]
if secretKey == "" {
return nil, fmt.Errorf("secretKey not configured in trigger metadata")
}

creds := credentials.NewStaticCredentials(accessId, secretKey, "")
signer := v4.NewSigner(creds)

rt := &roundTripper{
signer: signer,
region: awsRegion,
}

return rt, nil
}
34 changes: 34 additions & 0 deletions pkg/scalers/aws_sigv4_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package scalers

import (
"net/http"
"strings"
"testing"

"github.com/aws/aws-sdk-go/aws/credentials"
signer "github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/stretchr/testify/require"
)

func TestSigV4RoundTripper(t *testing.T) {

rt := &roundTripper{
signer: signer.NewSigner(credentials.NewStaticCredentials(
"test-id",
"secret",
"token",
)),
region: "us-west-2",
}

cli := &http.Client{Transport: rt}

req, err := http.NewRequest(http.MethodGet, "https://aps-workspaces.us-west-2.amazonaws.com/workspaces/ws-38377ca8-8db3-4b58-812d-b65a81837bb8/api/v1/query?query=ho11y_total", strings.NewReader("Hello, world!"))
require.NoError(t, err)
r, err := cli.Do(req)
require.NotEmpty(t, r)
require.NoError(t, err)

require.NotNil(t, req)
require.NotEmpty(t, req.Header.Get("Authorization"))
}
11 changes: 11 additions & 0 deletions pkg/scalers/prometheus_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,15 @@ func NewPrometheusScaler(config *ScalerConfig) (Scaler, error) {
httpClient.Transport = transport
}
} else {
awsTransport, err := NewSigV4RoundTripper(config.TriggerMetadata, httpClient.Transport)
if err != nil {
logger.V(1).Error(err, "failed to get AWS client HTTP transport ")
return nil, err
}

if err == nil && awsTransport != nil {
httpClient.Transport = awsTransport
}
// could be the case of azure managed prometheus. Try and get the round-tripper.
// If it's not the case of azure managed prometheus, we will get both transport and err as nil and proceed assuming no auth.
azureTransport, err := azure.TryAndGetAzureManagedPrometheusHTTPRoundTripper(logger, config.PodIdentity, config.TriggerMetadata)
Expand Down Expand Up @@ -306,6 +315,8 @@ func (s *prometheusScaler) ExecutePromQuery(ctx context.Context) (float64, error
req.Header.Set(s.metadata.prometheusAuth.CustomAuthHeader, s.metadata.prometheusAuth.CustomAuthValue)
}

//cli := &http.Client{Transport: s.httpClient.Transport}

r, err := s.httpClient.Do(req)
if err != nil {
return -1, err
Expand Down

0 comments on commit 1cdd3c9

Please sign in to comment.