Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions connection_properties.go
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,15 @@ var propertyDisableStatementCache = createConnectionProperty(
connectionstate.ContextStartup,
connectionstate.ConvertBool,
)
var propertyConnectTimeout = createConnectionProperty(
"connect_timeout",
"The amount of time to wait before timing out when creating a new connection.",
0,
false,
nil,
connectionstate.ContextStartup,
connectionstate.ConvertDuration,
)

// Generated read-only properties. These cannot be set by the user anywhere.
var propertyCommitTimestamp = createReadOnlyConnectionProperty(
Expand Down
2 changes: 1 addition & 1 deletion connectionstate/converters.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func parseTimestamp(re *regexp.Regexp, params string) (time.Time, error) {
func parseDuration(re *regexp.Regexp, value string) (time.Duration, error) {
matches := matchesToMap(re, value)
if matches["duration"] == "" && matches["number"] == "" && matches["null"] == "" {
return 0, spanner.ToSpannerError(status.Error(codes.InvalidArgument, fmt.Sprintf("No duration found: %v", value)))
return 0, spanner.ToSpannerError(status.Error(codes.InvalidArgument, fmt.Sprintf("No or invalid duration found: %v", value)))
}
if matches["duration"] != "" {
d, err := time.ParseDuration(matches["duration"])
Expand Down
11 changes: 11 additions & 0 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,17 @@ func openDriverConn(ctx context.Context, c *connector) (driver.Conn, error) {
c.connectorConfig.Project,
c.connectorConfig.Instance,
c.connectorConfig.Database)
if value, ok := c.initialPropertyValues[propertyConnectTimeout.Key()]; ok {
if timeout, err := value.GetValue(); err == nil {
if duration, ok := timeout.(time.Duration); ok {
var cancel context.CancelFunc
// This will set the actual timeout of the context to the lower of the
// current context timeout (if any) and the value from the connection property.
ctx, cancel = context.WithTimeout(ctx, duration)
defer cancel()
}
}
}

if err := c.increaseConnCount(ctx, databaseName, opts); err != nil {
return nil, err
Expand Down
41 changes: 41 additions & 0 deletions driver_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5555,6 +5555,47 @@ func TestReturnResultSetMetadataAndStats(t *testing.T) {
}
}

func TestConnectTimeout(t *testing.T) {
t.Parallel()

server, _, serverTeardown := setupMockedTestServerWithDialect(t, databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL)
defer serverTeardown()
db, err := sql.Open(
"spanner",
fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true;connect_timeout=1ms", server.Address))
if err != nil {
t.Fatal(err)
}
defer silentClose(db)

// Make the ExecuteStreamingSql method a bit slow, so the query that is used to detect the dialect responds a bit slowly.
server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, testutil.SimulatedExecutionTime{MinimumExecutionTime: time.Millisecond * 10})

// Try to get/create a connection using a context without a deadline.
// This will cause the connect_timeout to be used.
c, err := db.Conn(context.Background())
if g, w := spanner.ErrCode(err), codes.DeadlineExceeded; g != w {
t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w)
} else if c != nil {
_ = c.Close()
}
}

func TestInvalidConnectTimeout(t *testing.T) {
t.Parallel()

server, _, serverTeardown := setupMockedTestServerWithDialect(t, databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL)
defer serverTeardown()
db, err := sql.Open(
"spanner",
fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true;connect_timeout='very long'", server.Address))
if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w {
t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w)
} else if db != nil {
defer silentClose(db)
}
}

func numeric(v string) big.Rat {
res, _ := big.NewRat(1, 1).SetString(v)
return *res
Expand Down
Loading