diff --git a/bigtable/bigtable.go b/bigtable/bigtable.go
index 3eb40f97fa3a..1915c401ca10 100644
--- a/bigtable/bigtable.go
+++ b/bigtable/bigtable.go
@@ -33,6 +33,7 @@ import (
 	btpb "cloud.google.com/go/bigtable/apiv2/bigtablepb"
 	btopt "cloud.google.com/go/bigtable/internal/option"
+	btransport "cloud.google.com/go/bigtable/internal/transport"
 	"cloud.google.com/go/internal/trace"
 	gax "github.com/googleapis/gax-go/v2"
 	"github.com/googleapis/gax-go/v2/apierror"
@@ -55,6 +56,8 @@ const (
 	queryExpiredViolationType        = "PREPARED_QUERY_EXPIRED"
 	preparedQueryExpireEarlyDuration = time.Second
 	methodNameReadRows               = "ReadRows"
+	// Cannot extract d.GRPCConnPoolSize as DialSettings is in an internal gRPC package.
+	defaultBigtableConnPoolSize = 10

 	// For routing cookie
 	cookiePrefix = "x-goog-cbt-cookie-"
@@ -156,11 +159,6 @@ func NewClientWithConfig(ctx context.Context, project, instance string, config C
 	// TODO(b/372244283): Remove after b/358175516 has been fixed
 	o = append(o, internaloption.EnableAsyncRefreshDryRun(metricsTracerFactory.newAsyncRefreshErrHandler()))

-	connPool, err := gtransport.DialPool(ctx, o...)
-	if err != nil {
-		return nil, err
-	}
-
 	disableRetryInfo := false
 	// If DISABLE_RETRY_INFO=1, library does not base retry decision and back off time on server returned RetryInfo value.
@@ -172,6 +170,23 @@ func NewClientWithConfig(ctx context.Context, project, instance string, config C
 		retryOption = clientOnlyRetryOption
 		executeQueryRetryOption = clientOnlyExecuteQueryRetryOption
 	}
+
+	var connPool gtransport.ConnPool
+	var connPoolErr error
+	enableBigtableConnPool := btopt.EnableBigtableConnectionPool()
+	if enableBigtableConnPool {
+		connPool, connPoolErr = btransport.NewBigtableChannelPool(defaultBigtableConnPoolSize, btopt.BigtableLoadBalancingStrategy(), func() (*grpc.ClientConn, error) {
+			return gtransport.Dial(ctx, o...)
+		})
+	} else {
+		// Use the regular ConnPool.
+		connPool, connPoolErr = gtransport.DialPool(ctx, o...)
+	}
+
+	if connPoolErr != nil {
+		return nil, connPoolErr
+	}
+
 	return &Client{
 		connPool: connPool,
 		client:   btpb.NewBigtableClient(connPool),
diff --git a/bigtable/internal/option/option.go b/bigtable/internal/option/option.go
index d62c6ff18ae1..e65a86b63228 100644
--- a/bigtable/internal/option/option.go
+++ b/bigtable/internal/option/option.go
@@ -22,6 +22,7 @@ import (
 	"fmt"
 	"os"
 	"strconv"
+	"strings"
 	"time"

 	"cloud.google.com/go/bigtable/internal"
@@ -34,6 +35,19 @@ import (
 	"google.golang.org/grpc/metadata"
 )

+const (
+	// LoadBalancingStrategyEnvVar is the environment variable to control the gRPC load balancing strategy.
+	LoadBalancingStrategyEnvVar = "CBT_LOAD_BALANCING_STRATEGY"
+	// RoundRobinLBPolicy is the policy name for round-robin.
+	RoundRobinLBPolicy = "round_robin"
+	// LeastInFlightLBPolicy is the policy name for least in flight (custom).
+	LeastInFlightLBPolicy = "least_in_flight"
+	// PowerOfTwoLeastInFlightLBPolicy is the policy name for power of two least in flight (custom).
+	PowerOfTwoLeastInFlightLBPolicy = "power_of_two_least_in_flight"
+	// BigtableConnectionPoolEnvVar is the env var for enabling the Bigtable connection pool.
+	BigtableConnectionPoolEnvVar = "CBT_BIGTABLE_CONN_POOL"
+)
+
 // mergeOutgoingMetadata returns a context populated by the existing outgoing
 // metadata merged with the provided mds.
 func mergeOutgoingMetadata(ctx context.Context, mds ...metadata.MD) context.Context {
@@ -124,3 +138,66 @@ func ClientInterceptorOptions(stream []grpc.StreamClientInterceptor, unary []grp
 		option.WithGRPCDialOption(grpc.WithChainUnaryInterceptor(unary...)),
 	}
 }
+
+// LoadBalancingStrategy is the load balancing strategy used by the connection pool.
+type LoadBalancingStrategy int
+
+const (
+	// RoundRobin is the round_robin gRPC load balancing policy.
+	RoundRobin LoadBalancingStrategy = iota
+	// LeastInFlight is the least_in_flight gRPC load balancing policy (custom).
+	LeastInFlight
+	// PowerOfTwoLeastInFlight is the power_of_two_least_in_flight gRPC load balancing policy (custom).
+	PowerOfTwoLeastInFlight
+)
+
+// String returns the string representation of the LoadBalancingStrategy.
+func (s LoadBalancingStrategy) String() string {
+	switch s {
+	case LeastInFlight:
+		return "least_in_flight"
+	case PowerOfTwoLeastInFlight:
+		return "power_of_two_least_in_flight"
+	case RoundRobin:
+		return "round_robin"
+	default:
+		return "round_robin" // Default
+	}
+}
+
+// parseLoadBalancingStrategy parses the string from the environment variable
+// into a LoadBalancingStrategy enum value.
+func parseLoadBalancingStrategy(strategyStr string) LoadBalancingStrategy {
+	switch strings.ToUpper(strategyStr) {
+	case "LEAST_IN_FLIGHT":
+		return LeastInFlight
+	case "POWER_OF_TWO_LEAST_IN_FLIGHT":
+		return PowerOfTwoLeastInFlight
+	case "ROUND_ROBIN":
+		return RoundRobin
+	case "":
+		return RoundRobin // Default if env var is not set
+	default:
+		return RoundRobin // Default for unknown values
+	}
+}
+
+// BigtableLoadBalancingStrategy returns the load balancing strategy selected via the
+// CBT_LOAD_BALANCING_STRATEGY environment variable, defaulting to round-robin.
+func BigtableLoadBalancingStrategy() LoadBalancingStrategy {
+	strategyStr := os.Getenv(LoadBalancingStrategyEnvVar)
+	return parseLoadBalancingStrategy(strategyStr)
+}
+
+// EnableBigtableConnectionPool reports whether the new Bigtable connection pool should
+// be used, based on the CBT_BIGTABLE_CONN_POOL environment variable.
+func EnableBigtableConnectionPool() bool {
+	bigtableConnPoolEnvVal := os.Getenv(BigtableConnectionPoolEnvVar)
+	if bigtableConnPoolEnvVal == "" {
+		return false
+	}
+	enableBigtableConnPool, err := strconv.ParseBool(bigtableConnPoolEnvVal)
+	if err != nil {
+		// On parse failure, fall back to the default conn pool.
+		return false
+	}
+	return enableBigtableConnPool
+}
diff --git a/bigtable/internal/transport/connpool.go b/bigtable/internal/transport/connpool.go
new file mode 100644
index 000000000000..3697d4880fd4
--- /dev/null
+++ b/bigtable/internal/transport/connpool.go
@@ -0,0 +1,290 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package internal
+
+import (
+	"context"
+	"fmt"
+	"math/rand"
+	"sync"
+	"sync/atomic"
+
+	gtransport "google.golang.org/api/transport/grpc"
+
+	btopt "cloud.google.com/go/bigtable/internal/option"
+
+	"google.golang.org/grpc"
+)
+
+var errNoConnections = fmt.Errorf("bigtable_connpool: no connections available in the pool")
+var _ gtransport.ConnPool = &BigtableChannelPool{}
+
+// BigtableChannelPool implements ConnPool and routes each request to one of the pooled
+// connections according to the configured load balancing strategy.
+//
+// To benefit from automatic load tracking, use the Invoke and NewStream methods
+// directly on the BigtableChannelPool instance.
+type BigtableChannelPool struct {
+	conns []*grpc.ClientConn
+	load  []int64 // Tracks active requests per connection
+
+	// Mutex is only used for selecting the least loaded connection.
+	// The load array itself is manipulated using atomic operations.
+	mu         sync.Mutex
+	dial       func() (*grpc.ClientConn, error)
+	strategy   btopt.LoadBalancingStrategy
+	rrIndex    uint64              // For round-robin selection
+	selectFunc func() (int, error) // Stored function for connection selection
+}
+
+// NewBigtableChannelPool creates a pool of connPoolSize connections, each established
+// with the provided dial function, and selects connections using the given strategy.
+func NewBigtableChannelPool(connPoolSize int, strategy btopt.LoadBalancingStrategy, dial func() (*grpc.ClientConn, error)) (*BigtableChannelPool, error) {
+	if connPoolSize <= 0 {
+		return nil, fmt.Errorf("bigtable_connpool: connPoolSize must be positive")
+	}
+
+	if dial == nil {
+		return nil, fmt.Errorf("bigtable_connpool: dial function cannot be nil")
+	}
+	pool := &BigtableChannelPool{
+		dial:     dial,
+		strategy: strategy,
+		rrIndex:  0,
+	}
+
+	// Set the selection function based on the strategy.
+	switch strategy {
+	case btopt.LeastInFlight:
+		pool.selectFunc = pool.selectLeastLoaded
+	case btopt.PowerOfTwoLeastInFlight:
+		pool.selectFunc = pool.selectLeastLoadedRandomOfTwo
+	default: // RoundRobin is the default
+		pool.selectFunc = pool.selectRoundRobin
+	}
+
+	for i := 0; i < connPoolSize; i++ {
+		conn, err := dial()
+		if err != nil {
+			defer pool.Close()
+			return nil, err
+		}
+		pool.conns = append(pool.conns, conn)
+		pool.load = append(pool.load, 0)
+	}
+	return pool, nil
+}
+
+// Num returns the number of connections in the pool.
+func (p *BigtableChannelPool) Num() int {
+	return len(p.conns)
+}
+
+// Close closes all connections in the pool.
+func (p *BigtableChannelPool) Close() error {
+	var errs multiError
+	for _, conn := range p.conns {
+		if err := conn.Close(); err != nil {
+			errs = append(errs, err)
+		}
+	}
+	if len(errs) == 0 {
+		return nil
+	}
+	return errs
+}
+
+// Invoke selects a connection according to the configured strategy and calls Invoke on it.
+// This method provides automatic load tracking.
+func (p *BigtableChannelPool) Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...grpc.CallOption) error {
+	index, err := p.selectFunc()
+	if err != nil {
+		return err
+	}
+	conn := p.conns[index]
+
+	atomic.AddInt64(&p.load[index], 1)
+	defer atomic.AddInt64(&p.load[index], -1)
+
+	return conn.Invoke(ctx, method, args, reply, opts...)
+}
+
+// Conn returns a connection chosen by selectFunc(). Calls made directly on the
+// returned connection bypass the pool's load tracking.
+func (p *BigtableChannelPool) Conn() *grpc.ClientConn {
+	index, err := p.selectFunc()
+	if err != nil {
+		// no conn available
+		return nil
+	}
+	return p.conns[index]
+}
+
+// NewStream selects a connection according to the configured strategy and calls NewStream on it.
+// This method provides automatic load tracking via a wrapped stream.
+func (p *BigtableChannelPool) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
+	index, err := p.selectFunc()
+	if err != nil {
+		return nil, err
+	}
+	conn := p.conns[index]
+
+	atomic.AddInt64(&p.load[index], 1)
+
+	stream, err := conn.NewStream(ctx, desc, method, opts...)
+
+	if err != nil {
+		atomic.AddInt64(&p.load[index], -1) // Decrement if stream creation failed
+		return nil, err
+	}
+
+	// Wrap the stream to decrement load when the stream finishes.
+	return &refCountedStream{
+		ClientStream: stream,
+		pool:         p,
+		connIndex:    index,
+		once:         sync.Once{},
+	}, nil
+}
+
+// selectLeastLoadedRandomOfTwo returns the index of the less loaded of two randomly
+// chosen distinct connections ("power of two choices").
+func (p *BigtableChannelPool) selectLeastLoadedRandomOfTwo() (int, error) {
+	numConns := p.Num()
+	if numConns == 0 {
+		return -1, errNoConnections
+	}
+	if numConns == 1 {
+		return 0, nil
+	}
+
+	// Pick two distinct random indices. Loop until they differ; collisions are
+	// common for small numConns and rare for large numConns, so the loop is cheap.
+	idx1 := rand.Intn(numConns)
+	idx2 := rand.Intn(numConns)
+	for idx2 == idx1 {
+		idx2 = rand.Intn(numConns)
+	}
+
+	load1 := atomic.LoadInt64(&p.load[idx1])
+	load2 := atomic.LoadInt64(&p.load[idx2])
+
+	if load1 <= load2 {
+		return idx1, nil
+	}
+	return idx2, nil
+}
+
+// selectRoundRobin returns the next connection index in round-robin order.
+func (p *BigtableChannelPool) selectRoundRobin() (int, error) {
+	numConns := p.Num()
+	if numConns == 0 {
+		return -1, errNoConnections
+	}
+	if numConns == 1 {
+		return 0, nil
+	}
+
+	// Atomically increment and get the next index.
+	nextIndex := atomic.AddUint64(&p.rrIndex, 1) - 1
+	return int(nextIndex % uint64(numConns)), nil
+}
+
+// selectLeastLoaded returns the index of the connection with the minimum load.
+func (p *BigtableChannelPool) selectLeastLoaded() (int, error) {
+	numConns := p.Num()
+
+	if numConns == 0 {
+		return -1, errNoConnections
+	}
+
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	minIndex := 0
+	minLoad := atomic.LoadInt64(&p.load[0])
+
+	for i := 1; i < p.Num(); i++ {
+		currentLoad := atomic.LoadInt64(&p.load[i])
+		if currentLoad < minLoad {
+			minLoad = currentLoad
+			minIndex = i
+		}
+	}
+	return minIndex, nil
+}
+
+// refCountedStream wraps a grpc.ClientStream so the pool can hook into the stream's
+// lifecycle and decrement the load counter (s.pool.load[s.connIndex]) exactly once when
+// the stream is no longer usable. That point is detected by errors returned from SendMsg
+// or RecvMsg (including io.EOF on RecvMsg, which signals normal stream end).
+//
+// An alternative would have been grpc.OnFinish; the difference is the timing of when the
+// load is considered "finished": the OnFinish callback runs only once the entire stream
+// is fully closed and its final status is determined.
+type refCountedStream struct {
+	grpc.ClientStream
+	pool      *BigtableChannelPool
+	connIndex int
+	once      sync.Once
+}
+
+// SendMsg calls the embedded stream's SendMsg method and decrements load on error.
+func (s *refCountedStream) SendMsg(m interface{}) error {
+	err := s.ClientStream.SendMsg(m)
+	if err != nil {
+		s.decrementLoad()
+	}
+	return err
+}
+
+// RecvMsg calls the embedded stream's RecvMsg method and decrements load on error.
+func (s *refCountedStream) RecvMsg(m interface{}) error {
+	err := s.ClientStream.RecvMsg(m)
+	if err != nil { // io.EOF is also an error, indicating stream end.
+ s.decrementLoad() + } + return err +} + +// decrementLoad ensures the load count is decremented exactly once. +func (s *refCountedStream) decrementLoad() { + s.once.Do(func() { + atomic.AddInt64(&s.pool.load[s.connIndex], -1) + }) +} + +type multiError []error + +func (m multiError) Error() string { + s, n := "", 0 + for _, e := range m { + if e != nil { + if n == 0 { + s = e.Error() + } + n++ + } + } + switch n { + case 0: + return "(0 errors)" + case 1: + return s + case 2: + return s + " (and 1 other error)" + } + return fmt.Sprintf("%s (and %d other errors)", s, n-1) +} diff --git a/bigtable/internal/transport/connpool_test.go b/bigtable/internal/transport/connpool_test.go new file mode 100644 index 000000000000..4381752d115b --- /dev/null +++ b/bigtable/internal/transport/connpool_test.go @@ -0,0 +1,629 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + "testing" + "time" + + btopt "cloud.google.com/go/bigtable/internal/option" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + testgrpc "google.golang.org/grpc/interop/grpc_testing" + testpb "google.golang.org/grpc/interop/grpc_testing" + "google.golang.org/grpc/status" +) + +type fakeService struct { + testgrpc.UnimplementedBenchmarkServiceServer + mu sync.Mutex + callCount int + streamSema chan struct{} // To control stream lifetime + delay time.Duration // To simulate work + serverErr error // Error to return from server +} + +func (s *fakeService) UnaryCall(ctx context.Context, req *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + s.mu.Lock() + s.callCount++ + s.mu.Unlock() + if s.delay > 0 { + time.Sleep(s.delay) + } + if s.serverErr != nil { + return nil, s.serverErr + } + // Echo the payload + return &testpb.SimpleResponse{Payload: req.GetPayload()}, nil +} + +func (s *fakeService) StreamingCall(stream testpb.BenchmarkService_StreamingCallServer) error { + s.mu.Lock() + s.callCount++ + s.mu.Unlock() + + if s.serverErr != nil { + return s.serverErr + } + + if s.streamSema != nil { + <-s.streamSema // Wait until released + } + + for { + req, err := stream.Recv() + if err == io.EOF { + return nil + } + if err != nil { + return err + } + if err := stream.Send(&testpb.SimpleResponse{Payload: req.GetPayload()}); err != nil { + return err + } + } +} + +func (s *fakeService) getCallCount() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.callCount +} + +func setupTestServer(t *testing.T, service *fakeService) string { + t.Helper() + lis, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + + } + srv := grpc.NewServer() + testgrpc.RegisterBenchmarkServiceServer(srv, service) + go func() { + if err := srv.Serve(lis); err != nil { + // t.Logf is used here as t.Fatalf cannot be called in a separate goroutine + t.Logf("gRPC server error: %v", err) + } + }() + + t.Cleanup(func() { + srv.Stop() + lis.Close() + }) + + return 
lis.Addr().String() +} + +func TestSelectRoundRobin(t *testing.T) { + pool := &BigtableChannelPool{} + + // Test empty pool + idx, err := pool.selectRoundRobin() + if idx != -1 { + t.Errorf("selectRoundRobin on empty pool got index %d, want -1", idx) + } + if err == nil { + t.Errorf("selectRoundRobin on empty pool got nil error, want non-nil") + } + if err != errNoConnections { + t.Errorf("selectRoundRobin on empty pool got error %v, want %v", err, errNoConnections) + } + + // Test single connection pool + pool.conns = make([]*grpc.ClientConn, 1) + pool.load = make([]int64, 1) + idx, err = pool.selectRoundRobin() + if idx != 0 { + t.Errorf("selectRoundRobin on single conn pool got index %d, want 0", idx) + } + if err != nil { + t.Errorf("selectRoundRobin on single conn pool got error %v, want nil", err) + } + + // Test multiple connections + poolSize := 3 + pool.conns = make([]*grpc.ClientConn, poolSize) + pool.load = make([]int64, poolSize) + pool.rrIndex = 0 + + // Test wrapping around + for i := 0; i < poolSize*2; i++ { + expectedIdx := i % poolSize + idx, err = pool.selectRoundRobin() + if idx != expectedIdx { + t.Errorf("selectRoundRobin call %d got index %d, want %d", i+1, idx, expectedIdx) + } + if err != nil { + t.Errorf("selectRoundRobin call %d got error %v, want nil", i+1, err) + } + } +} + +func TestSelectLeastLoadedRandomOfTwo(t *testing.T) { + pool := &BigtableChannelPool{} + + // Test empty pool + idx, err := pool.selectLeastLoadedRandomOfTwo() + if idx != -1 { + t.Errorf("selectLeastLoadedRandomOfTwo on empty pool got index %d, want -1", idx) + } + if err == nil { + t.Errorf("selectLeastLoadedRandomOfTwo on empty pool got nil error, want non-nil") + } + if err != errNoConnections { + t.Errorf("selectLeastLoadedRandomOfTwo on empty pool got error %v, want %v", err, errNoConnections) + } + + // Test single connection pool + pool.conns = make([]*grpc.ClientConn, 1) + pool.load = make([]int64, 1) + idx, err = pool.selectLeastLoadedRandomOfTwo() + if idx != 0 { + t.Errorf("selectLeastLoadedRandomOfTwo on single conn pool got index %d, want 0", idx) + } + if err != nil { + t.Errorf("selectLeastLoadedRandomOfTwo on single conn pool got error %v, want nil", err) + } + + // Test multiple connections + pool.conns = make([]*grpc.ClientConn, 5) + pool.load = []int64{10, 2, 30, 4, 50} // Loads for indices 0, 1, 2, 3, 4 + + for i := 0; i < 100; i++ { // Run multiple times due to randomness + idx, err = pool.selectLeastLoadedRandomOfTwo() + if err != nil { + t.Fatalf("selectLeastLoadedRandomOfTwo got unexpected error: %v", err) + } + if idx < 0 || idx >= len(pool.conns) { + t.Fatalf("Selected index %d is out of bounds", idx) + } + } + + // Test case where loads are distinct + pool.load = []int64{5, 1, 10} + pool.conns = make([]*grpc.ClientConn, 3) + for i := 0; i < 100; i++ { + idx, err = pool.selectLeastLoadedRandomOfTwo() + if err != nil { + t.Errorf("selectLeastLoadedRandomOfTwo got unexpected error: %v", err) + continue + } + if idx < 0 || idx >= 3 { + t.Errorf("selectLeastLoadedRandomOfTwo got index %d, want index in [0, 2]", idx) + continue + } + } + + // Test with all equal loads + pool.load = []int64{5, 5, 5} + for i := 0; i < 100; i++ { + idx, err = pool.selectLeastLoadedRandomOfTwo() + if err != nil { + t.Errorf("selectLeastLoadedRandomOfTwo got unexpected error: %v", err) + continue + } + if idx < 0 || idx >= 3 { + t.Errorf("Index %d out of bounds", idx) + } + } +} + +func TestNewLeastLoadedChannelPool(t *testing.T) { + t.Run("SuccessfulCreation", func(t *testing.T) { + 
poolSize := 5 + fake := &fakeService{} + addr := setupTestServer(t, fake) + + dialFunc := func() (*grpc.ClientConn, error) { + return grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + } + + pool, err := NewBigtableChannelPool(poolSize, btopt.LeastInFlight, dialFunc) + if err != nil { + t.Fatalf("NewBigtableChannelPool failed: %v", err) + } + defer pool.Close() + + if pool.Num() != poolSize { + t.Errorf("Pool size got %d, want %d", pool.Num(), poolSize) + } + for i, conn := range pool.conns { + if conn == nil { + t.Errorf("conn at index %d is nil", i) + } + } + }) + + t.Run("DialFailure", func(t *testing.T) { + poolSize := 3 + dialCount := 0 + dialFunc := func() (*grpc.ClientConn, error) { + dialCount++ + if dialCount > 1 { + return nil, errors.New("simulated dial error") + } + fake := &fakeService{} + addr := setupTestServer(t, fake) + return grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + } + + _, err := NewBigtableChannelPool(poolSize, btopt.LeastInFlight, dialFunc) + if err == nil { + t.Errorf("NewBigtableChannelPool should have failed due to dial error, but got no error") + } + }) +} + +func TestPoolInvoke(t *testing.T) { + strategies := []btopt.LoadBalancingStrategy{ + btopt.LeastInFlight, + btopt.RoundRobin, + btopt.PowerOfTwoLeastInFlight, + } + + for _, strategy := range strategies { + t.Run(fmt.Sprintf("Strategy_%s", strategy), func(t *testing.T) { + poolSize := 3 + fake := &fakeService{} + addr := setupTestServer(t, fake) + dialFunc := func() (*grpc.ClientConn, error) { + return grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + } + + pool, err := NewBigtableChannelPool(poolSize, strategy, dialFunc) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Close() + + req := &testpb.SimpleRequest{Payload: &testpb.Payload{Body: []byte("hello")}} + res := &testpb.SimpleResponse{} + if err := pool.Invoke(context.Background(), "/grpc.testing.BenchmarkService/UnaryCall", req, res); err != nil { + t.Errorf("Invoke failed: %v", err) + } + if string(res.GetPayload().GetBody()) != "hello" { + t.Errorf("Invoke response got %q, want %q", string(res.GetPayload().GetBody()), "hello") + } + if fake.getCallCount() != 1 { + t.Errorf("Server call count got %d, want 1", fake.getCallCount()) + } + + for i, load := range pool.load { + if load != 0 { + t.Errorf("Load at index %d is non-zero after Invoke: %d", i, load) + } + } + }) + } +} + +func TestPoolNewStream(t *testing.T) { + strategies := []btopt.LoadBalancingStrategy{ + btopt.LeastInFlight, + btopt.RoundRobin, + btopt.PowerOfTwoLeastInFlight, + } + + for _, strategy := range strategies { + t.Run(fmt.Sprintf("Strategy_%s", strategy), func(t *testing.T) { + poolSize := 2 + fake := &fakeService{} + addr := setupTestServer(t, fake) + dialFunc := func() (*grpc.ClientConn, error) { + return grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + } + + pool, err := NewBigtableChannelPool(poolSize, strategy, dialFunc) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Close() + + ctx := context.Background() + stream, err := pool.NewStream(ctx, &grpc.StreamDesc{StreamName: "StreamingCall"}, "/grpc.testing.BenchmarkService/StreamingCall") + if err != nil { + t.Fatalf("NewStream failed: %v", err) + } + + loadSum := int64(0) + for _, l := range pool.load { + loadSum += l + } + if loadSum != 1 { + t.Errorf("Total load after NewStream got %d, want 1. 
Loads: %v", loadSum, pool.load) + } + + req := &testpb.SimpleRequest{Payload: &testpb.Payload{Body: []byte("msg1")}} + if err := stream.SendMsg(req); err != nil { + t.Fatalf("SendMsg failed: %v", err) + } + res := &testpb.SimpleResponse{} + if err := stream.RecvMsg(res); err != nil { + t.Fatalf("RecvMsg failed: %v", err) + } + if string(res.GetPayload().GetBody()) != "msg1" { + t.Errorf("RecvMsg got %q, want %q", string(res.GetPayload().GetBody()), "msg1") + } + + if err := stream.CloseSend(); err != nil { + t.Fatalf("CloseSend failed: %v", err) + } + + if err := stream.RecvMsg(res); err != io.EOF { + t.Errorf("Expected io.EOF after CloseSend, got %v", err) + } + + time.Sleep(10 * time.Millisecond) + loadSum = int64(0) + for i, l := range pool.load { + if l < 0 { + t.Errorf("Load at index %d went negative: %d", i, l) + } + loadSum += l + } + if loadSum != 0 { + t.Errorf("Total load after stream completion got %d, want 0. Loads: %v", loadSum, pool.load) + } + }) + } +} + +func TestSelectLeastLoaded(t *testing.T) { + pool := &BigtableChannelPool{} + + // Test empty pool + idx, err := pool.selectLeastLoaded() + if idx != -1 { + t.Errorf("selectLeastLoaded on empty pool got index %d, want -1", idx) + } + if err == nil { + t.Errorf("selectLeastLoaded on empty pool got nil error, want non-nil") + } + if err != errNoConnections { + t.Errorf("selectLeastLoaded on empty pool got error %v, want %v", err, errNoConnections) + } + + // Test single connection pool + pool.conns = make([]*grpc.ClientConn, 1) + pool.load = make([]int64, 1) + idx, err = pool.selectLeastLoaded() + if idx != 0 { + t.Errorf("selectLeastLoaded on single conn pool got index %d, want 0", idx) + } + if err != nil { + t.Errorf("selectLeastLoaded on single conn pool got error %v, want nil", err) + } + + // Test multiple connections + pool.conns = make([]*grpc.ClientConn, 5) + pool.load = []int64{3, 1, 4, 1, 5} + idx, err = pool.selectLeastLoaded() + if idx != 1 { + t.Errorf("selectLeastLoadedIterative got index %d, want 1 for loads %v", idx, pool.load) + } + if err != nil { + t.Errorf("selectLeastLoadedIterative got error %v, want nil for loads %v", err, pool.load) + } +} + +func TestPoolClose(t *testing.T) { + poolSize := 2 + fake := &fakeService{} + addr := setupTestServer(t, fake) + dialFunc := func() (*grpc.ClientConn, error) { + return grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + } + + pool, err := NewBigtableChannelPool(poolSize, btopt.LeastInFlight, dialFunc) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + + if err := pool.Close(); err != nil { + t.Errorf("Close failed: %v", err) + } +} + +func TestMultipleStreamsSingleConn(t *testing.T) { + poolSize := 1 // Force all streams to use the same connection + fake := &fakeService{} + addr := setupTestServer(t, fake) + dialFunc := func() (*grpc.ClientConn, error) { + return grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + } + + pool, err := NewBigtableChannelPool(poolSize, btopt.LeastInFlight, dialFunc) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Close() + + numStreams := 5 + streams := make([]grpc.ClientStream, numStreams) + ctx := context.Background() + + // Open streams and check load + for i := 0; i < numStreams; i++ { + stream, err := pool.NewStream(ctx, &grpc.StreamDesc{StreamName: "StreamingCall"}, "/grpc.testing.BenchmarkService/StreamingCall") + if err != nil { + t.Fatalf("NewStream %d failed: %v", i, err) + } + streams[i] = stream + expectedLoad := int64(i + 
1) + if atomic.LoadInt64(&pool.load[0]) != expectedLoad { + t.Errorf("Load after opening stream %d is %d, want %d", i, atomic.LoadInt64(&pool.load[0]), expectedLoad) + } + } + + // Basic interaction with each stream + for i, stream := range streams { + msg := fmt.Sprintf("stream%d", i) + req := &testpb.SimpleRequest{Payload: &testpb.Payload{Body: []byte(msg)}} + if err := stream.SendMsg(req); err != nil { + t.Errorf("SendMsg on stream %d failed: %v", i, err) + } + res := &testpb.SimpleResponse{} + if err := stream.RecvMsg(res); err != nil { + t.Errorf("RecvMsg on stream %d failed: %v", i, err) + } + if string(res.GetPayload().GetBody()) != msg { + t.Errorf("RecvMsg on stream %d got %q, want %q", i, string(res.GetPayload().GetBody()), msg) + } + } + + if fake.getCallCount() != numStreams { + t.Errorf("Server call count got %d, want %d", fake.getCallCount(), numStreams) + } + + // Close streams and check load + for i, stream := range streams { + if err := stream.CloseSend(); err != nil { + t.Errorf("CloseSend on stream %d failed: %v", i, err) + } + // Drain the stream + for { + if err := stream.RecvMsg(&testpb.SimpleResponse{}); err != nil { + if err != io.EOF { + t.Errorf("RecvMsg on stream %d after close failed unexpectedly: %v", i, err) + } + break + } + } + time.Sleep(10 * time.Millisecond) // Allow decrement to propagate + + expectedLoad := int64(numStreams - 1 - i) + if atomic.LoadInt64(&pool.load[0]) != expectedLoad { + t.Errorf("Load after closing stream %d is %d, want %d", i, atomic.LoadInt64(&pool.load[0]), expectedLoad) + } + } + + if atomic.LoadInt64(&pool.load[0]) != 0 { + t.Errorf("Final load is %d, want 0", atomic.LoadInt64(&pool.load[0])) + } +} + +func TestCachingStreamDecrement(t *testing.T) { + poolSize := 1 + fake := &fakeService{} + addr := setupTestServer(t, fake) + dialFunc := func() (*grpc.ClientConn, error) { + return grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + } + + pool, err := NewBigtableChannelPool(poolSize, btopt.LeastInFlight, dialFunc) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Close() + + t.Run("DecrementOnRecvError", func(t *testing.T) { + fake.serverErr = errors.New("stream recv error") + defer func() { fake.serverErr = nil }() + + ctx := context.Background() + stream, err := pool.NewStream(ctx, &grpc.StreamDesc{StreamName: "StreamingCall"}, "/grpc.testing.BenchmarkService/StreamingCall") + if err != nil { + t.Fatalf("NewStream failed: %v", err) + } + if atomic.LoadInt64(&pool.load[0]) != 1 { + t.Errorf("Load is %d, want 1 after NewStream", atomic.LoadInt64(&pool.load[0])) + } + + err = stream.RecvMsg(&testpb.SimpleResponse{}) + if err == nil { + t.Errorf("RecvMsg should have failed") + } + + time.Sleep(10 * time.Millisecond) + if atomic.LoadInt64(&pool.load[0]) != 0 { + t.Errorf("Load is %d, want 0 after RecvMsg error", atomic.LoadInt64(&pool.load[0])) + } + }) + + t.Run("DecrementOnSendError", func(t *testing.T) { + ctx := context.Background() + stream, err := pool.NewStream(ctx, &grpc.StreamDesc{StreamName: "StreamingCall"}, "/grpc.testing.BenchmarkService/StreamingCall") + if err != nil { + t.Fatalf("NewStream failed: %v", err) + } + if atomic.LoadInt64(&pool.load[0]) != 1 { + t.Errorf("Load is %d, want 1 after NewStream", atomic.LoadInt64(&pool.load[0])) + } + + // Close the sending side of the stream. + if err := stream.CloseSend(); err != nil { + t.Fatalf("CloseSend failed: %v", err) + } + + // Wait for the server to acknowledge the closure by receiving io.EOF. 
+ for { + if err := stream.RecvMsg(&testpb.SimpleResponse{}); err != nil { + if err == io.EOF { + break // Normal stream end. + } + t.Fatalf("RecvMsg failed unexpectedly while draining: %v", err) + } + } + + // Any subsequent SendMsg call must return an error. + err = stream.SendMsg(&testpb.SimpleRequest{Payload: &testpb.Payload{Body: []byte("wont send")}}) + if err == nil { + t.Errorf("SendMsg should have failed after stream is closed (RecvMsg returned io.EOF)") + } else { + // Optionally check the error type. It's often related to a closed stream. + st, ok := status.FromError(err) + if ok { + t.Logf("SendMsg failed as expected with status: %v", st) + } else { + t.Logf("SendMsg failed as expected with error: %v", err) + } + } + + // The decrement should have occurred when SendMsg returned an error. + time.Sleep(10 * time.Millisecond) // Give a moment for the decrement to be visible. + if atomic.LoadInt64(&pool.load[0]) != 0 { + t.Errorf("Load is %d, want 0 after SendMsg error on closed stream", atomic.LoadInt64(&pool.load[0])) + } + }) + + t.Run("NoDecrementOnSuccessfulSend", func(t *testing.T) { + fake.streamSema = make(chan struct{}) + defer close(fake.streamSema) + + ctx := context.Background() + stream, err := pool.NewStream(ctx, &grpc.StreamDesc{StreamName: "StreamingCall"}, "/grpc.testing.BenchmarkService/StreamingCall") + if err != nil { + t.Fatalf("NewStream failed: %v", err) + } + if atomic.LoadInt64(&pool.load[0]) != 1 { + t.Errorf("Load is %d, want 1", atomic.LoadInt64(&pool.load[0])) + } + + if err := stream.SendMsg(&testpb.SimpleRequest{Payload: &testpb.Payload{Body: []byte("test")}}); err != nil { + t.Fatalf("SendMsg failed: %v", err) + } + if atomic.LoadInt64(&pool.load[0]) != 1 { + t.Errorf("Load is %d, want 1 after successful SendMsg", atomic.LoadInt64(&pool.load[0])) + } + }) +}
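
Reviewer note: a minimal, hypothetical sketch of how the new pool is enabled end to end. It assumes only what this patch adds (the CBT_BIGTABLE_CONN_POOL and CBT_LOAD_BALANCING_STRATEGY environment variables read during client construction); the project and instance IDs are placeholders.

    package main

    import (
        "context"
        "log"
        "os"

        "cloud.google.com/go/bigtable"
    )

    func main() {
        // Opt in to the Bigtable-managed channel pool (off by default).
        os.Setenv("CBT_BIGTABLE_CONN_POOL", "true")
        // Pick a selection strategy; empty or unknown values fall back to round_robin.
        os.Setenv("CBT_LOAD_BALANCING_STRATEGY", "LEAST_IN_FLIGHT")

        ctx := context.Background()
        // NewClient goes through NewClientWithConfig, which with this patch builds a
        // BigtableChannelPool of defaultBigtableConnPoolSize connections.
        client, err := bigtable.NewClient(ctx, "my-project", "my-instance")
        if err != nil {
            log.Fatalf("bigtable.NewClient: %v", err)
        }
        defer client.Close()
        _ = client // Data-plane RPCs are now routed per call by the pool's selection function.
    }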
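A second sketch, illustrating the load-tracking contract of the pool itself. This is also hypothetical: the endpoint and instance name are placeholders, and the transport package is internal, so code like this only compiles from within the cloud.google.com/go/bigtable module. Only RPCs issued through the pool's own Invoke/NewStream are counted toward per-connection load; grabbing a raw connection via Conn() bypasses the counters.

    package transportdemo

    import (
        "context"
        "fmt"

        btpb "cloud.google.com/go/bigtable/apiv2/bigtablepb"
        btopt "cloud.google.com/go/bigtable/internal/option"
        btransport "cloud.google.com/go/bigtable/internal/transport"
        "google.golang.org/grpc"
        "google.golang.org/grpc/credentials/insecure"
    )

    func pingThroughPool(ctx context.Context, target, instance string) error {
        // Four connections, "power of two choices" selection per RPC.
        pool, err := btransport.NewBigtableChannelPool(4, btopt.PowerOfTwoLeastInFlight,
            func() (*grpc.ClientConn, error) {
                return grpc.Dial(target, grpc.WithTransportCredentials(insecure.NewCredentials()))
            })
        if err != nil {
            return err
        }
        defer pool.Close()
        fmt.Println("connections in pool:", pool.Num())

        // Invoke increments the chosen connection's in-flight counter for the
        // duration of the call and decrements it when the call returns.
        req := &btpb.PingAndWarmRequest{Name: instance}
        res := &btpb.PingAndWarmResponse{}
        return pool.Invoke(ctx, "/google.bigtable.v2.Bigtable/PingAndWarm", req, res)
    }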