From 30c4ba319e2edf176a0997b654c91f1c7c098e78 Mon Sep 17 00:00:00 2001 From: Travis Bischel Date: Wed, 19 May 2021 06:22:28 +0000 Subject: [PATCH] sasl: add support for AWS_MSK_IAM This took a good bit to figure out from the Java source, but, I've tested this against a scrappy msk cluster and it now works. --- pkg/sasl/aws/aws.go | 240 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 240 insertions(+) create mode 100644 pkg/sasl/aws/aws.go diff --git a/pkg/sasl/aws/aws.go b/pkg/sasl/aws/aws.go new file mode 100644 index 00000000..eae84b14 --- /dev/null +++ b/pkg/sasl/aws/aws.go @@ -0,0 +1,240 @@ +// Package aws provides AWS_MSK_IAM sasl authentication as specified in the +// Java source. +// +// The Java source can be found at https://github.com/aws/aws-msk-iam-auth. +package aws + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "net" + "net/url" + "strings" + "time" + + "github.com/twmb/franz-go/pkg/sasl" +) + +// Auth contains an AWS AccessKey and SecretKey for authentication. +// +// This client may add fields to this struct in the future if Kafka adds more +// capabilities to MSK IAM. +type Auth struct { + // AccessKey is an AWS AccessKey. + AccessKey string + + // AccessKey is an AWS SecretKey. + SecretKey string + + _internal struct{} // require explicit field initalization +} + +// AsManagedStreamingIAMMechanism returns a sasl mechanism that will use 'a' as +// credentials for all sasl sessions. +// +// This is a shortcut for using the ManagedStreamingIAM function and is useful +// when you do not need to live-rotate credentials. +func (a Auth) AsManagedStreamingIAMMechanism() sasl.Mechanism { + return ManagedStreamingIAM(func(context.Context) (Auth, error) { + return a, nil + }) +} + +type mskiam func(context.Context) (Auth, error) + +// ManagedStreamingIAM returns a sasl mechanism that will call authFn whenever +// sasl authentication is needed. The returned Auth is used for a single +// session. +func ManagedStreamingIAM(authFn func(context.Context) (Auth, error)) sasl.Mechanism { + return mskiam(authFn) +} + +func (mskiam) Name() string { return "AWS_MSK_IAM" } + +func (fn mskiam) Authenticate(ctx context.Context, host string) (sasl.Session, []byte, error) { + auth, err := fn(ctx) + if err != nil { + return nil, nil, err + } + + challenge, err := challenge(auth, host) + if err != nil { + return nil, nil, err + } + + return new(session), challenge, nil +} + +type session struct{} + +func (session) Challenge(resp []byte) (bool, []byte, error) { + if len(resp) == 0 { + return false, nil, errors.New("empty challenge response: failed") + } + return true, nil, nil +} + +const service = "kafka-cluster" + +func challenge(auth Auth, host string) ([]byte, error) { + host, _, err := net.SplitHostPort(host) // we do not need the port + if err != nil { + return nil, err + } + region, err := identifyRegion(host) + if err != nil { + return nil, err + } + + var ( + timestamp = time.Now().UTC().Format("20060102T150405Z") + date = timestamp[:8] // 20060102 + scope = scope(date, region) + v = make(url.Values) + ) + + v.Set("Action", service+":Connect") + v.Set("X-Amz-Algorithm", "AWS4-HMAC-SHA256") + v.Set("X-Amz-Credential", auth.AccessKey+"/"+scope) + v.Set("X-Amz-Date", timestamp) + v.Set("X-Amz-Expires", "900") // 1 min + v.Set("X-Amz-SignedHeaders", "host") + + qps := strings.Replace(v.Encode(), "+", "%20", -1) + + canonicalRequest := task1(host, qps) + sts := task2(timestamp, scope, canonicalRequest) + signature := task3(auth.SecretKey, region, date, sts) + + v.Set("X-Amz-Signature", signature) // task4 + + // According to the Java source and manual testing, all values in our + // challenge map must be lowercased, and we MUST have host, and we MUST + // have version, and version MUST be 2020_10_22. + keyvals := make(map[string]string) + for key, values := range v { + keyvals[strings.ToLower(key)] = values[0] + } + keyvals["host"] = host + keyvals["version"] = "2020_10_22" + + marshaled, err := json.Marshal(keyvals) + if err != nil { + return nil, err + } + return marshaled, nil +} + +// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html +// "CredentialScope", Part 3 +func scope(date, region string) string { + return strings.Join([]string{date, region, service, "aws4_request"}, "/") +} + +// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html +func task1(host, qps string) []byte { + // We start with our defined method, "GET", and the defined empty path, + // "/". For query parameters, we have to escape +'s with %20, but we did + // that above when building our URL. + // + // HTTPRequestMethod + '\n' + + // CanonicalURI + '\n' + + // CanonicalQueryString + '\n' + + canon := make([]byte, 0, 200) + canon = append(canon, "GET\n"...) + canon = append(canon, "/\n"...) + canon = append(canon, qps...) + canon = append(canon, '\n') + + // We only sign one header, the host. Each signed header is followed by + // a newline, and then the canonical header block is followed itself by + // a newline. + // + // CanonicalHeaders + '\n' + + // SignedHeaders + '\n' + + canon = append(canon, "host:"...) + canon = append(canon, host...) + canon = append(canon, '\n') + canon = append(canon, '\n') + canon = append(canon, "host\n"...) + + // Finally, we add our empty body. + // + // HexEncode(Hash(RequestPayload)) + const emptyBody = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + return append(canon, emptyBody...) +} + +// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html +func task2(timestamp, scope string, canonicalRequest []byte) []byte { + toSign := make([]byte, 0, 512) + toSign = append(toSign, "AWS4-HMAC-SHA256\n"...) + toSign = append(toSign, timestamp...) + toSign = append(toSign, '\n') + toSign = append(toSign, scope...) + toSign = append(toSign, '\n') + canonHash := sha256.Sum256(canonicalRequest) + hexBuf := make([]byte, 64) // 32 bytes to 64 + hex.Encode(hexBuf[:], canonHash[:]) + toSign = append(toSign, hexBuf[:]...) + return toSign +} + +var aws4requestBytes = []byte("aws4_request") + +// https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html +func task3(secretKey, region, date string, sts []byte) string { + key := make([]byte, 0, 100) + key = append(key, "AWS4"...) + key = append(key, secretKey...) + + h := hmac.New(sha256.New, key) + h.Write([]byte(date)) // kDate + + key = h.Sum(key[:0]) + h = hmac.New(sha256.New, key) + h.Write([]byte(region)) // kRegion + + key = h.Sum(key[:0]) + h = hmac.New(sha256.New, key) + h.Write([]byte(service)) // kService + + key = h.Sum(key[:0]) + h = hmac.New(sha256.New, key) + h.Write(aws4requestBytes) // kSigning + + key = h.Sum(key[:0]) + h = hmac.New(sha256.New, key) + h.Write(sts) + + return hex.EncodeToString(h.Sum(key[:0])) +} + +// aws-java-sdk-core/src/main/resources/com/amazonaws/partitions/endpoints.json +var suffixes = []string{ + ".amazonaws.com", + ".amazonaws.com.cn", + ".c2s.ic.gov", + ".sc2s.sgov.gov", +} + +// aws-java-sdk-core/src/main/java/com/amazonaws/partitions/PartitionMetadataProvider.java +// tryGetRegionByEndpointDnsSuffix +func identifyRegion(host string) (string, error) { + for _, suffix := range suffixes { + if strings.HasSuffix(host, suffix) { + serviceRegion := strings.TrimSuffix(host, suffix) + regionDot := strings.LastIndexByte(serviceRegion, '.') + if regionDot == -1 { + break + } + return serviceRegion[regionDot+1:], nil + } + } + return "", fmt.Errorf("cannot determine the region in %+q", host) +}