Skip to content

Commit

Permalink
feat(internaloption): add EnableDirectPath internaloption (#732)
Browse files Browse the repository at this point in the history
We want to make bigtable attempt DirectPath by default, instead of checking the environment variable GOOGLE_CLOUD_ENABLE_DIRECT_PATH. Notice that even after this PR, the real datapath is still CFE since ACL is currently denied for all projects.
  • Loading branch information
mohanli-ml authored Nov 2, 2020
1 parent 4870c18 commit baf33b2
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 54 deletions.
1 change: 1 addition & 0 deletions internal/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type DialSettings struct {
CustomClaims map[string]interface{}
SkipValidation bool
ImpersonationConfig *impersonate.Config
EnableDirectPath bool

// Google API system parameters. For more information please read:
// https://cloud.google.com/apis/docs/system-parameters
Expand Down
15 changes: 15 additions & 0 deletions option/internaloption/internaloption.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,18 @@ type skipDialSettingsValidation struct{}
func (s skipDialSettingsValidation) Apply(settings *internal.DialSettings) {
settings.SkipValidation = true
}

// EnableDirectPath returns a ClientOption that overrides the default
// attempt to use DirectPath.
//
// It should only be used internally by generated clients.
// This is an EXPERIMENTAL API and may be changed or removed in the future.
func EnableDirectPath(dp bool) option.ClientOption {
return enableDirectPath(dp)
}

type enableDirectPath bool

func (e enableDirectPath) Apply(o *internal.DialSettings) {
o.EnableDirectPath = bool(e)
}
23 changes: 8 additions & 15 deletions transport/grpc/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"crypto/tls"
"errors"
"log"
"os"
"strings"

"go.opencensus.io/plugin/ocgrpc"
Expand Down Expand Up @@ -138,9 +137,7 @@ func dial(ctx context.Context, insecure bool, o *internal.DialSettings) (*grpc.C
// * The endpoint is a host:port (or dns:///host:port).
// * Credentials are obtained via GCE metadata server, using the default
// service account.
// * Opted in via GOOGLE_CLOUD_ENABLE_DIRECT_PATH environment variable.
// For example, GOOGLE_CLOUD_ENABLE_DIRECT_PATH=spanner,pubsub
if isDirectPathEnabled(endpoint) && isTokenSourceDirectPathCompatible(creds.TokenSource) {
if o.EnableDirectPath && checkDirectPathEndPoint(endpoint) && isTokenSourceDirectPathCompatible(creds.TokenSource) {
if !strings.HasPrefix(endpoint, "dns:///") {
endpoint = "dns:///" + endpoint
}
Expand Down Expand Up @@ -189,7 +186,7 @@ func dial(ctx context.Context, insecure bool, o *internal.DialSettings) (*grpc.C
// point when isDirectPathEnabled will default to true, we guard it by
// the Directpath env var for now once we can introspect user defined
// dialer (https://github.com/grpc/grpc-go/issues/2795).
if timeoutDialerOption != nil && isDirectPathEnabled(endpoint) {
if timeoutDialerOption != nil && o.EnableDirectPath && checkDirectPathEndPoint(endpoint) {
grpcOpts = append(grpcOpts, timeoutDialerOption)
}

Expand Down Expand Up @@ -250,8 +247,8 @@ func isTokenSourceDirectPathCompatible(ts oauth2.TokenSource) bool {
return true
}

func isDirectPathEnabled(endpoint string) bool {
// Only host:port is supported, not other schemes (e.g., "tcp://" or "unix://").
func checkDirectPathEndPoint(endpoint string) bool {
// Only [dns:///]host[:port] is supported, not other schemes (e.g., "tcp://" or "unix://").
// Also don't try direct path if the user has chosen an alternate name resolver
// (i.e., via ":///" prefix).
//
Expand All @@ -261,15 +258,11 @@ func isDirectPathEnabled(endpoint string) bool {
return false
}

// Only try direct path if the user has opted in via the environment variable.
directPathAPIs := strings.Split(os.Getenv("GOOGLE_CLOUD_ENABLE_DIRECT_PATH"), ",")
for _, api := range directPathAPIs {
// Ignore empty string since an empty env variable splits into [""]
if api != "" && strings.Contains(endpoint, api) {
return true
}
if endpoint == "" {
return false
}
return false

return true
}

func processAndValidateOpts(opts []option.ClientOption) (*internal.DialSettings, error) {
Expand Down
8 changes: 3 additions & 5 deletions transport/grpc/dial_socketopt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ import (
"errors"
"fmt"
"net"
"os"
"syscall"
"testing"
"time"

"golang.org/x/oauth2"
"golang.org/x/sys/unix"
"google.golang.org/api/option"
"google.golang.org/api/option/internaloption"
"google.golang.org/grpc"
)

Expand Down Expand Up @@ -90,9 +90,6 @@ func getTCPUserTimeout(conn net.Conn) (int, error) {

// Check that tcp timeout dialer overwrites user defined dialer.
func TestDialWithDirectPathEnabled(t *testing.T) {
os.Setenv("GOOGLE_CLOUD_ENABLE_DIRECT_PATH", "example,other")
defer os.Clearenv()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)

userDialer := grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) {
Expand All @@ -104,7 +101,8 @@ func TestDialWithDirectPathEnabled(t *testing.T) {
conn, err := Dial(ctx,
option.WithTokenSource(oauth2.StaticTokenSource(nil)), // No creds.
option.WithGRPCDialOption(userDialer),
option.WithEndpoint("example.google.com:443"))
option.WithEndpoint("example.google.com:443"),
internaloption.EnableDirectPath(true))
if err != nil {
t.Errorf("DialGRPC: error %v, want nil", err)
}
Expand Down
45 changes: 11 additions & 34 deletions transport/grpc/dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"context"
"errors"
"net"
"os"
"testing"
"time"

Expand Down Expand Up @@ -55,62 +54,40 @@ func TestGRPCHook(t *testing.T) {
}
}

func TestIsDirectPathEnabled(t *testing.T) {
func TestCheckDirectPathEndPoint(t *testing.T) {
for _, testcase := range []struct {
name string
endpoint string
envVar string
want bool
}{
{
name: "matches",
endpoint: "some-api",
envVar: "some-api",
want: true,
},
{
name: "does not match",
endpoint: "some-api",
envVar: "some-other-api",
name: "empty endpoint are disallowed",
endpoint: "",
want: false,
},
{
name: "matches in list",
endpoint: "some-api-2",
envVar: "some-api-1,some-api-2,some-api-3",
name: "dns schemes are allowed",
endpoint: "dns:///foo",
want: true,
},
{
name: "empty env var",
endpoint: "",
envVar: "",
want: false,
},
{
name: "trailing comma",
endpoint: "",
envVar: "foo,bar,",
want: false,
name: "host without no prefix are allowed",
endpoint: "foo",
want: true,
},
{
name: "dns schemes are allowed",
endpoint: "dns:///foo",
envVar: "dns:///foo",
name: "host with port are allowed",
endpoint: "foo:1234",
want: true,
},
{
name: "non-dns schemes are disallowed",
endpoint: "https://foo",
envVar: "https://foo",
want: false,
},
} {
t.Run(testcase.name, func(t *testing.T) {
if err := os.Setenv("GOOGLE_CLOUD_ENABLE_DIRECT_PATH", testcase.envVar); err != nil {
t.Fatal(err)
}

if got := isDirectPathEnabled(testcase.endpoint); got != testcase.want {
if got := checkDirectPathEndPoint(testcase.endpoint); got != testcase.want {
t.Fatalf("got %v, want %v", got, testcase.want)
}
})
Expand Down

0 comments on commit baf33b2

Please sign in to comment.