Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ringhash: more e2e tests from c-core #7334

Merged
merged 8 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 82 additions & 9 deletions internal/testutils/blocking_context_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,36 +21,109 @@ package testutils
import (
"context"
"net"
"sync"

"google.golang.org/grpc/grpclog"
)

var logger = grpclog.Component("testutils")

// BlockingDialer is a dialer that waits for Resume() to be called before
// dialing.
type BlockingDialer struct {
dialer *net.Dialer
blockCh chan struct{}
// mu protects holds.
mu sync.Mutex
// holds maps network addresses to a list of holds for that address.
holds map[string][]*Hold
easwars marked this conversation as resolved.
Show resolved Hide resolved
}

// NewBlockingDialer returns a dialer that waits for Resume() to be called
// before dialing.
func NewBlockingDialer() *BlockingDialer {
return &BlockingDialer{
dialer: &net.Dialer{},
blockCh: make(chan struct{}),
holds: make(map[string][]*Hold),
}
}

// DialContext implements a context dialer for use with grpc.WithContextDialer
// dial option for a BlockingDialer.
func (d *BlockingDialer) DialContext(ctx context.Context, addr string) (net.Conn, error) {
d.mu.Lock()
holds := d.holds[addr]
if len(holds) == 0 {
// No hold for this addr.
d.mu.Unlock()
return (&net.Dialer{}).DialContext(ctx, "tcp", addr)
}
hold := holds[0]
d.holds[addr] = holds[1:]
d.mu.Unlock()

logger.Infof("Hold %p: Intercepted connection attempt to addr %q", hold, addr)
close(hold.waitCh)
select {
case <-d.blockCh:
case err := <-hold.blockCh:
if err != nil {
return nil, err
}
return (&net.Dialer{}).DialContext(ctx, "tcp", addr)
case <-ctx.Done():
logger.Infof("Hold %p: Connection attempt to addr %q timed out", hold, addr)
return nil, ctx.Err()
}
return d.dialer.DialContext(ctx, "tcp", addr)
}

// Resume unblocks the dialer. It panics if called more than once.
func (d *BlockingDialer) Resume() {
close(d.blockCh)
// Hold is a handle to a single connection attempt. It can be used to block,
// fail and succeed connection attempts.
type Hold struct {
// dialer is the dialer that created this hold.
dialer *BlockingDialer
// waitCh is closed when a connection attempt is received.
waitCh chan struct{}
// blockCh receives the value to return from DialContext for this connection
// attempt (nil on resume, an error on fail). It receives at most 1 value.
blockCh chan error
// addr is the address that this hold is for.
addr string
}

// Hold blocks the dialer when a connection attempt is made to the given addr.
// A hold is valid for exactly one connection attempt. Multiple holds for an
// addr can be added, and they will apply in the order that the connections are
// attempted.
func (d *BlockingDialer) Hold(addr string) *Hold {
d.mu.Lock()
defer d.mu.Unlock()

h := Hold{dialer: d, blockCh: make(chan error, 1), waitCh: make(chan struct{}), addr: addr}
d.holds[addr] = append(d.holds[addr], &h)
return &h
}

// Wait blocks until there is a connection attempt on this Hold, or the context
// expires. Return false if the context has expired, true otherwise.
func (h *Hold) Wait(ctx context.Context) bool {
logger.Infof("Hold %p: Waiting for a connection attempt to addr %q", h, h.addr)
select {
case <-ctx.Done():
return false
case <-h.waitCh:
return true
}
}

// Resume unblocks the dialer for the given addr. Either Resume or Fail must be
// called at most once on a hold. Otherwise, Resume panics.
func (h *Hold) Resume() {
logger.Infof("Hold %p: Resuming connection attempt to addr %q", h, h.addr)
h.blockCh <- nil
close(h.blockCh)
}

// Fail fails the connection attempt. Either Resume or Fail must be
// called at most once on a hold. Otherwise, Resume panics.
func (h *Hold) Fail(err error) {
logger.Infof("Hold %p: Failing connection attempt to addr %q", h, h.addr)
h.blockCh <- err
close(h.blockCh)
}
201 changes: 201 additions & 0 deletions internal/testutils/blocking_context_dialer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package testutils

import (
"context"
"errors"
"testing"
"time"
)

const (
testTimeout = 5 * time.Second
testShortTimeout = 10 * time.Millisecond
)

func (s) TestBlockingDialer_NoHold(t *testing.T) {
lis, err := LocalTCPListener()
if err != nil {
t.Fatalf("Failed to listen: %v", err)
}
defer lis.Close()

d := NewBlockingDialer()

// This should not block.
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
conn, err := d.DialContext(ctx, lis.Addr().String())
if err != nil {
t.Fatalf("Failed to dial: %v", err)
}
conn.Close()
}

func (s) TestBlockingDialer_HoldWaitResume(t *testing.T) {
lis, err := LocalTCPListener()
if err != nil {
t.Fatalf("Failed to listen: %v", err)
}
defer lis.Close()

d := NewBlockingDialer()
h := d.Hold(lis.Addr().String())

done := make(chan struct{})
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
go func() {
defer close(done)
conn, err := d.DialContext(ctx, lis.Addr().String())
if err != nil {
t.Errorf("BlockingDialer.DialContext() got error: %v, want success", err)
return
}
conn.Close()
}()

// This should block until the goroutine above is scheduled.
if !h.Wait(ctx) {
t.Fatalf("Timeout while waiting for a connection attempt to %q", h.addr)
}
select {
case <-done:
t.Fatalf("Expected dialer to be blocked.")
case <-time.After(testShortTimeout):
}

h.Resume() // Unblock the above goroutine.

select {
case <-done:
case <-ctx.Done():
t.Errorf("Timeout waiting for connection attempt to resume.")
}
}

func (s) TestBlockingDialer_HoldWaitFail(t *testing.T) {
lis, err := LocalTCPListener()
if err != nil {
t.Fatalf("Failed to listen: %v", err)
}
defer lis.Close()

d := NewBlockingDialer()
h := d.Hold(lis.Addr().String())

wantErr := errors.New("test error")

dialError := make(chan error)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
go func() {
_, err := d.DialContext(ctx, lis.Addr().String())
dialError <- err
}()

if !h.Wait(ctx) {
t.Fatalf("Timeout while waiting for a connection attempt to " + h.addr)
}
select {
case err = <-dialError:
t.Errorf("DialContext got unblocked with err %v. Want DialContext to still be blocked after Wait()", err)
case <-time.After(testShortTimeout):
}

h.Fail(wantErr)

select {
case err = <-dialError:
if !errors.Is(err, wantErr) {
t.Errorf("BlockingDialer.DialContext() after Fail(): got error %v, want %v", err, wantErr)
}
case <-ctx.Done():
t.Errorf("Timeout waiting for connection attempt to fail.")
}
}

func (s) TestBlockingDialer_ContextCanceled(t *testing.T) {
lis, err := LocalTCPListener()
if err != nil {
t.Fatalf("Failed to listen: %v", err)
}
defer lis.Close()

d := NewBlockingDialer()
h := d.Hold(lis.Addr().String())

dialErr := make(chan error)
testCtx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()

ctx, cancel := context.WithCancel(testCtx)
defer cancel()
go func() {
_, err := d.DialContext(ctx, lis.Addr().String())
dialErr <- err
}()
if !h.Wait(testCtx) {
t.Errorf("Timeout while waiting for a connection attempt to %q", h.addr)
}

cancel()

select {
case err = <-dialErr:
if !errors.Is(err, context.Canceled) {
t.Errorf("BlockingDialer.DialContext() after context cancel: got error %v, want %v", err, context.Canceled)
}
case <-testCtx.Done():
t.Errorf("Timeout while waiting for Wait to return.")
}

h.Resume() // noop, just make sure nothing bad happen.
}

func (s) TestBlockingDialer_CancelWait(t *testing.T) {
lis, err := LocalTCPListener()
if err != nil {
t.Fatalf("Failed to listen: %v", err)
}
defer lis.Close()

d := NewBlockingDialer()
h := d.Hold(lis.Addr().String())

testCtx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()

ctx, cancel := context.WithCancel(testCtx)
cancel()
done := make(chan struct{})
go func() {
if h.Wait(ctx) {
t.Errorf("Expected cancel to return false when context expires")
}
done <- struct{}{}
}()

select {
case <-done:
case <-testCtx.Done():
t.Errorf("Timeout while waiting for Wait to return.")
}
}
Loading