-
-
Notifications
You must be signed in to change notification settings - Fork 198
Commit
This took a good bit to figure out from the Java source, but, I've tested this against a scrappy msk cluster and it now works.
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,240 @@ | ||
// Package aws provides AWS_MSK_IAM sasl authentication as specified in the | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong. |
||
// Java source. | ||
// | ||
// The Java source can be found at https://github.com/aws/aws-msk-iam-auth. | ||
package aws | ||
|
||
import ( | ||
"context" | ||
"crypto/hmac" | ||
"crypto/sha256" | ||
"encoding/hex" | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"net" | ||
"net/url" | ||
"strings" | ||
"time" | ||
|
||
"github.com/twmb/franz-go/pkg/sasl" | ||
) | ||
|
||
// Auth contains an AWS AccessKey and SecretKey for authentication. | ||
// | ||
// This client may add fields to this struct in the future if Kafka adds more | ||
// capabilities to MSK IAM. | ||
type Auth struct { | ||
This comment has been minimized.
Sorry, something went wrong.
grsubramanian
|
||
// AccessKey is an AWS AccessKey. | ||
AccessKey string | ||
|
||
// AccessKey is an AWS SecretKey. | ||
SecretKey string | ||
|
||
_internal struct{} // require explicit field initalization | ||
} | ||
|
||
// AsManagedStreamingIAMMechanism returns a sasl mechanism that will use 'a' as | ||
// credentials for all sasl sessions. | ||
// | ||
// This is a shortcut for using the ManagedStreamingIAM function and is useful | ||
// when you do not need to live-rotate credentials. | ||
func (a Auth) AsManagedStreamingIAMMechanism() sasl.Mechanism { | ||
return ManagedStreamingIAM(func(context.Context) (Auth, error) { | ||
return a, nil | ||
}) | ||
} | ||
|
||
type mskiam func(context.Context) (Auth, error) | ||
|
||
// ManagedStreamingIAM returns a sasl mechanism that will call authFn whenever | ||
// sasl authentication is needed. The returned Auth is used for a single | ||
// session. | ||
func ManagedStreamingIAM(authFn func(context.Context) (Auth, error)) sasl.Mechanism { | ||
return mskiam(authFn) | ||
} | ||
|
||
func (mskiam) Name() string { return "AWS_MSK_IAM" } | ||
|
||
func (fn mskiam) Authenticate(ctx context.Context, host string) (sasl.Session, []byte, error) { | ||
auth, err := fn(ctx) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
|
||
challenge, err := challenge(auth, host) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
|
||
return new(session), challenge, nil | ||
} | ||
|
||
type session struct{} | ||
|
||
func (session) Challenge(resp []byte) (bool, []byte, error) { | ||
if len(resp) == 0 { | ||
return false, nil, errors.New("empty challenge response: failed") | ||
} | ||
return true, nil, nil | ||
} | ||
|
||
const service = "kafka-cluster" | ||
|
||
func challenge(auth Auth, host string) ([]byte, error) { | ||
host, _, err := net.SplitHostPort(host) // we do not need the port | ||
if err != nil { | ||
return nil, err | ||
} | ||
region, err := identifyRegion(host) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
var ( | ||
timestamp = time.Now().UTC().Format("20060102T150405Z") | ||
date = timestamp[:8] // 20060102 | ||
scope = scope(date, region) | ||
v = make(url.Values) | ||
) | ||
|
||
v.Set("Action", service+":Connect") | ||
v.Set("X-Amz-Algorithm", "AWS4-HMAC-SHA256") | ||
v.Set("X-Amz-Credential", auth.AccessKey+"/"+scope) | ||
v.Set("X-Amz-Date", timestamp) | ||
v.Set("X-Amz-Expires", "900") // 1 min | ||
This comment has been minimized.
Sorry, something went wrong. |
||
v.Set("X-Amz-SignedHeaders", "host") | ||
|
||
qps := strings.Replace(v.Encode(), "+", "%20", -1) | ||
|
||
canonicalRequest := task1(host, qps) | ||
sts := task2(timestamp, scope, canonicalRequest) | ||
signature := task3(auth.SecretKey, region, date, sts) | ||
|
||
v.Set("X-Amz-Signature", signature) // task4 | ||
This comment has been minimized.
Sorry, something went wrong.
grsubramanian
|
||
|
||
// According to the Java source and manual testing, all values in our | ||
// challenge map must be lowercased, and we MUST have host, and we MUST | ||
// have version, and version MUST be 2020_10_22. | ||
keyvals := make(map[string]string) | ||
for key, values := range v { | ||
keyvals[strings.ToLower(key)] = values[0] | ||
} | ||
keyvals["host"] = host | ||
keyvals["version"] = "2020_10_22" | ||
|
||
marshaled, err := json.Marshal(keyvals) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return marshaled, nil | ||
} | ||
|
||
// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html | ||
// "CredentialScope", Part 3 | ||
func scope(date, region string) string { | ||
return strings.Join([]string{date, region, service, "aws4_request"}, "/") | ||
} | ||
|
||
// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html | ||
func task1(host, qps string) []byte { | ||
// We start with our defined method, "GET", and the defined empty path, | ||
// "/". For query parameters, we have to escape +'s with %20, but we did | ||
// that above when building our URL. | ||
// | ||
// HTTPRequestMethod + '\n' + | ||
// CanonicalURI + '\n' + | ||
// CanonicalQueryString + '\n' + | ||
canon := make([]byte, 0, 200) | ||
canon = append(canon, "GET\n"...) | ||
canon = append(canon, "/\n"...) | ||
canon = append(canon, qps...) | ||
canon = append(canon, '\n') | ||
|
||
// We only sign one header, the host. Each signed header is followed by | ||
// a newline, and then the canonical header block is followed itself by | ||
// a newline. | ||
// | ||
// CanonicalHeaders + '\n' + | ||
// SignedHeaders + '\n' + | ||
canon = append(canon, "host:"...) | ||
canon = append(canon, host...) | ||
canon = append(canon, '\n') | ||
canon = append(canon, '\n') | ||
canon = append(canon, "host\n"...) | ||
|
||
// Finally, we add our empty body. | ||
// | ||
// HexEncode(Hash(RequestPayload)) | ||
const emptyBody = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" | ||
return append(canon, emptyBody...) | ||
} | ||
|
||
// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html | ||
func task2(timestamp, scope string, canonicalRequest []byte) []byte { | ||
toSign := make([]byte, 0, 512) | ||
toSign = append(toSign, "AWS4-HMAC-SHA256\n"...) | ||
toSign = append(toSign, timestamp...) | ||
toSign = append(toSign, '\n') | ||
toSign = append(toSign, scope...) | ||
toSign = append(toSign, '\n') | ||
canonHash := sha256.Sum256(canonicalRequest) | ||
hexBuf := make([]byte, 64) // 32 bytes to 64 | ||
hex.Encode(hexBuf[:], canonHash[:]) | ||
toSign = append(toSign, hexBuf[:]...) | ||
return toSign | ||
} | ||
|
||
var aws4requestBytes = []byte("aws4_request") | ||
|
||
// https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html | ||
func task3(secretKey, region, date string, sts []byte) string { | ||
This comment has been minimized.
Sorry, something went wrong.
grsubramanian
|
||
key := make([]byte, 0, 100) | ||
key = append(key, "AWS4"...) | ||
key = append(key, secretKey...) | ||
|
||
h := hmac.New(sha256.New, key) | ||
h.Write([]byte(date)) // kDate | ||
|
||
key = h.Sum(key[:0]) | ||
h = hmac.New(sha256.New, key) | ||
h.Write([]byte(region)) // kRegion | ||
|
||
key = h.Sum(key[:0]) | ||
h = hmac.New(sha256.New, key) | ||
h.Write([]byte(service)) // kService | ||
|
||
key = h.Sum(key[:0]) | ||
h = hmac.New(sha256.New, key) | ||
h.Write(aws4requestBytes) // kSigning | ||
|
||
key = h.Sum(key[:0]) | ||
h = hmac.New(sha256.New, key) | ||
h.Write(sts) | ||
|
||
return hex.EncodeToString(h.Sum(key[:0])) | ||
} | ||
|
||
// aws-java-sdk-core/src/main/resources/com/amazonaws/partitions/endpoints.json | ||
var suffixes = []string{ | ||
".amazonaws.com", | ||
".amazonaws.com.cn", | ||
".c2s.ic.gov", | ||
".sc2s.sgov.gov", | ||
} | ||
|
||
// aws-java-sdk-core/src/main/java/com/amazonaws/partitions/PartitionMetadataProvider.java | ||
// tryGetRegionByEndpointDnsSuffix | ||
func identifyRegion(host string) (string, error) { | ||
for _, suffix := range suffixes { | ||
if strings.HasSuffix(host, suffix) { | ||
serviceRegion := strings.TrimSuffix(host, suffix) | ||
regionDot := strings.LastIndexByte(serviceRegion, '.') | ||
if regionDot == -1 { | ||
break | ||
} | ||
return serviceRegion[regionDot+1:], nil | ||
} | ||
} | ||
return "", fmt.Errorf("cannot determine the region in %+q", host) | ||
} |
I think it would be useful to have an example of how customers can wire in the IAM credentials so that it is discoverable by this SASL mechanism implementation.