From 81949733e88f9dfaece5e5a9e38bc3170593780c Mon Sep 17 00:00:00 2001 From: Cesar Ghali Date: Thu, 22 Feb 2018 10:00:00 -0800 Subject: [PATCH] Address comments --- credentials/alts/alts.go | 78 ++++++++++++++++++++++------------ credentials/alts/utils.go | 15 ++++--- credentials/alts/utils_test.go | 6 +-- 3 files changed, 62 insertions(+), 37 deletions(-) diff --git a/credentials/alts/alts.go b/credentials/alts/alts.go index c479ef1e53c1..a58816565a84 100644 --- a/credentials/alts/alts.go +++ b/credentials/alts/alts.go @@ -20,6 +20,7 @@ // encapsulates all the state needed by a client to authenticate with a server // using ALTS and make various assertions, e.g., about the client's identity, // role, or whether it is authorized to make a particular call. +// This package is experimental. package alts import ( @@ -27,6 +28,7 @@ import ( "flag" "fmt" "net" + "sync" "time" "golang.org/x/net/context" @@ -51,6 +53,15 @@ const ( var ( enableUntrustedALTS = flag.Bool("enable_untrusted_alts", false, "Enables ALTS in untrusted mode. Enabling this mode is risky since we cannot ensure that the application is running on GCP with a trusted handshaker service.") + once sync.Once + maxRPCVersion = &altspb.RpcProtocolVersions_Version{ + Major: protocolVersionMaxMajor, + Minor: protocolVersionMaxMinor, + } + minRPCVersion = &altspb.RpcProtocolVersions_Version{ + Major: protocolVersionMinMajor, + Minor: protocolVersionMinMinor, + } // ErrUntrustedPlatform is returned from ClientHandshake and // ServerHandshake is running on a platform where the trustworthiness of // the handshaker service is not guaranteed. @@ -58,7 +69,10 @@ var ( ) // AuthInfo exposes security information from the ALTS handshake to the -// application. +// application. This interface is to be implemented by ALTS. Users should not +// need a brand new implementation of this interface. For situations like +// testing, any new implementation should embed this interface. This allows +// ALTS to add new methods to this interface. type AuthInfo interface { // ApplicationProtocol returns application protocol negotiated for the // ALTS connection. @@ -77,8 +91,8 @@ type AuthInfo interface { PeerRPCVersions() *altspb.RpcProtocolVersions } -// altsTC is the credentials required for authenticating a connection using Google -// Transport Security. It implements credentials.TransportCredentials interface. +// altsTC is the credentials required for authenticating a connection using ALTS. +// It implements credentials.TransportCredentials interface. type altsTC struct { info *credentials.ProtocolInfo hsAddr string @@ -97,6 +111,11 @@ func NewServerALTS() credentials.TransportCredentials { } func newALTS(side core.Side, accounts []string) credentials.TransportCredentials { + // Make sure flags are parsed before accessing enableUntrustedALTS. + once.Do(func() { + flag.Parse() + vmOnGCP = isRunningOnGCP() + }) if *enableUntrustedALTS { grpclog.Warning("untrusted ALTS mode is enabled and we cannot guarantee the trustworthiness of the ALTS handshaker service.") } @@ -124,19 +143,29 @@ func (g *altsTC) ClientHandshake(ctx context.Context, addr string, rawConn net.C } // Do not close hsConn since it is shared with other handshakes. + // Possible context leak: + // The cancel function for the child context we create will only be + // called a non-nil error is returned. + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + defer func() { + if err != nil { + cancel() + } + }() + opts := handshaker.DefaultClientHandshakerOptions() opts.TargetServiceAccounts = g.accounts opts.RPCVersions = &altspb.RpcProtocolVersions{ - MaxRpcVersion: &altspb.RpcProtocolVersions_Version{ - Major: protocolVersionMaxMajor, - Minor: protocolVersionMaxMinor, - }, - MinRpcVersion: &altspb.RpcProtocolVersions_Version{ - Major: protocolVersionMinMajor, - Minor: protocolVersionMinMinor, - }, + MaxRpcVersion: maxRPCVersion, + MinRpcVersion: minRPCVersion, } chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, opts) + defer func() { + if err != nil { + chs.Close() + } + }() if err != nil { return nil, nil, err } @@ -171,16 +200,15 @@ func (g *altsTC) ServerHandshake(rawConn net.Conn) (_ net.Conn, _ credentials.Au defer cancel() opts := handshaker.DefaultServerHandshakerOptions() opts.RPCVersions = &altspb.RpcProtocolVersions{ - MaxRpcVersion: &altspb.RpcProtocolVersions_Version{ - Major: protocolVersionMaxMajor, - Minor: protocolVersionMaxMinor, - }, - MinRpcVersion: &altspb.RpcProtocolVersions_Version{ - Major: protocolVersionMinMajor, - Minor: protocolVersionMinMinor, - }, + MaxRpcVersion: maxRPCVersion, + MinRpcVersion: minRPCVersion, } shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, opts) + defer func() { + if err != nil { + shs.Close() + } + }() if err != nil { return nil, nil, err } @@ -218,15 +246,11 @@ func (g *altsTC) OverrideServerName(serverNameOverride string) error { // compareRPCVersion returns 0 if v1 == v2, 1 if v1 > v2 and -1 if v1 < v2. func compareRPCVersions(v1, v2 *altspb.RpcProtocolVersions_Version) int { switch { - case v1.GetMajor() > v2.GetMajor(): - fallthrough - case v1.GetMajor() == v2.GetMajor() && v1.GetMinor() > v2.GetMinor(): + case v1.GetMajor() > v2.GetMajor(), + v1.GetMajor() == v2.GetMajor() && v1.GetMinor() > v2.GetMinor(): return 1 - } - switch { - case v1.GetMajor() < v2.GetMajor(): - fallthrough - case v1.GetMajor() == v2.GetMajor() && v1.GetMinor() < v2.GetMinor(): + case v1.GetMajor() < v2.GetMajor(), + v1.GetMajor() == v2.GetMajor() && v1.GetMinor() < v2.GetMinor(): return -1 } return 0 diff --git a/credentials/alts/utils.go b/credentials/alts/utils.go index 15e3c24f8760..cd5be2e6a654 100644 --- a/credentials/alts/utils.go +++ b/credentials/alts/utils.go @@ -49,8 +49,8 @@ func (k platformError) Error() string { var ( // The following two variables will be reassigned in tests. - runningOS = runtime.GOOS - readerFunc = func() (io.Reader, error) { + runningOS = runtime.GOOS + manufacturerReader = func() (io.Reader, error) { switch runningOS { case "linux": return os.Open(linuxProductNameFile) @@ -72,10 +72,10 @@ var ( return nil, errors.New("cannot determine the machine's manufacturer") default: - panic(platformError(runningOS)) + return nil, platformError(runningOS) } } - vmOnGCP = isRunningOnGCP() + vmOnGCP bool ) // isRunningOnGCP checks whether the local system, without doing a network request is @@ -83,7 +83,7 @@ var ( func isRunningOnGCP() bool { manufacturer, err := readManufacturer() if err != nil { - log.Fatal(err) + log.Fatalf("failure to read manufacturer information: %v", err) } name := string(manufacturer) switch runningOS { @@ -96,12 +96,13 @@ func isRunningOnGCP() bool { name = strings.Replace(name, "\r", "", -1) return name == "Google" default: - panic(platformError(runningOS)) + log.Fatal(platformError(runningOS)) } + return false } func readManufacturer() ([]byte, error) { - reader, err := readerFunc() + reader, err := manufacturerReader() if err != nil { return nil, err } diff --git a/credentials/alts/utils_test.go b/credentials/alts/utils_test.go index 7aa57cc477a3..32c5e1bf4d04 100644 --- a/credentials/alts/utils_test.go +++ b/credentials/alts/utils_test.go @@ -51,16 +51,16 @@ func TestIsRunningOnGCP(t *testing.T) { func setup(testOS string, testReader io.Reader) func() { tmpOS := runningOS - tmpReader := readerFunc + tmpReader := manufacturerReader // Set test OS and reader function. runningOS = testOS - readerFunc = func() (io.Reader, error) { + manufacturerReader = func() (io.Reader, error) { return testReader, nil } return func() { runningOS = tmpOS - readerFunc = tmpReader + manufacturerReader = tmpReader } }