Skip to content

Commit

Permalink
feat: add $RYUK_CONNECTION_TIMEOUT to configure connection timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
hhsnopek committed Nov 29, 2022
1 parent 8f512d3 commit 1ca5330
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 11 deletions.
67 changes: 59 additions & 8 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"log"
"net"
"net/url"
"os"
"os/signal"
"strings"
"sync"
Expand All @@ -20,20 +21,70 @@ import (
"gopkg.in/matryer/try.v1"
)

const (
connectionTimeoutEnv string = "RYUK_CONNECTION_TIMEOUT"
)

var (
port = flag.Int("p", 8080, "Port to bind at")
initialConnectTimeout = 1 * time.Minute
reconnectionTimeout = 10 * time.Second
port int
connectionTimeout time.Duration
reconnectionTimeout time.Duration
)

type config struct {
Port int
ConnectionTimeout time.Duration
ReconnectionTimeout time.Duration
}

// newConfig parses command line flags and returns a parsed config. config.timeout
// can be set by environment variable, RYUK_CONNECTION_TIMEOUT. If an error occurs
// while parsing RYUK_CONNECTION_TIMEOUT the error is returned.
func newConfig(args []string) (*config, error) {
cfg := config{
ConnectionTimeout: 60 * time.Second,
ReconnectionTimeout: 10 * time.Second,
}

fs := flag.NewFlagSet("ryuk", flag.ExitOnError)
fs.SetOutput(os.Stdout)

fs.IntVar(&cfg.Port, "p", 8080, "Port to bind at")

err := fs.Parse(args)
if err != nil {
return nil, err
}

if timeout, ok := os.LookupEnv(connectionTimeoutEnv); ok {
parsedTimeout, err := time.ParseDuration(timeout)
if err != nil {
return nil, fmt.Errorf("failed to parse \"%s\": %s", connectionTimeoutEnv, err)
}

cfg.ConnectionTimeout = parsedTimeout
}

return &cfg, nil
}

func main() {
flag.Parse()
cfg, err := newConfig(os.Args[1:])
if err != nil {
panic(err)
}

cli, err := client.NewClientWithOpts()
port = cfg.Port
connectionTimeout = cfg.ConnectionTimeout
reconnectionTimeout = cfg.ReconnectionTimeout

cli, err := client.NewClientWithOpts(client.FromEnv)
if err != nil {
panic(err)
}

cli.NegotiateAPIVersion(context.Background())

pingCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

Expand Down Expand Up @@ -62,9 +113,9 @@ func main() {
}

func processRequests(deathNote *sync.Map, connectionAccepted chan<- net.Addr, connectionLost chan<- net.Addr) {
log.Printf("Starting on port %d...", *port)
log.Printf("Starting on port %d...", port)

ln, err := net.Listen("tcp", fmt.Sprintf(":%d", *port))
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -137,7 +188,7 @@ func waitForPruneCondition(ctx context.Context, connectionAccepted <-chan net.Ad
}

select {
case <-time.After(initialConnectTimeout):
case <-time.After(connectionTimeout):
panic("Timed out waiting for the first connection")
case addr := <-connectionAccepted:
handleConnectionAccepted(addr)
Expand Down
38 changes: 35 additions & 3 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
"github.com/docker/docker/client"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
testcontainers "github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go"
)

var addr = &net.TCPAddr{
Expand All @@ -28,12 +28,16 @@ var addr = &net.TCPAddr{
Zone: "",
}

var testConnectionTimeout time.Duration = 5 * time.Second

func init() {
initialConnectTimeout = 5 * time.Second
reconnectionTimeout = 1 * time.Second
}

func TestReconnectionTimeout(t *testing.T) {
// reset connectionTimeout
connectionTimeout = testConnectionTimeout

acc := make(chan net.Addr)
lost := make(chan net.Addr)

Expand All @@ -56,6 +60,9 @@ func TestReconnectionTimeout(t *testing.T) {
}

func TestInitialTimeout(t *testing.T) {
// reset connectionTimeout
connectionTimeout = testConnectionTimeout

acc := make(chan net.Addr)
lost := make(chan net.Addr)

Expand All @@ -80,7 +87,7 @@ func TestInitialTimeout(t *testing.T) {
}

func TestPrune(t *testing.T) {
cli, err := client.NewClientWithOpts()
cli, err := client.NewClientWithOpts(client.FromEnv)
if err == nil {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
Expand All @@ -91,6 +98,7 @@ func TestPrune(t *testing.T) {
if err != nil {
t.Fatal(err)
}
cli.NegotiateAPIVersion(context.Background())

maxLength := 25

Expand Down Expand Up @@ -274,3 +282,27 @@ func TestPrune(t *testing.T) {
assert.Equal(t, maxLength, di)
})
}

func Test_newConfig(t *testing.T) {
t.Run("should return an error when failing to parse the environment variable", func(t *testing.T) {
t.Setenv(connectionTimeoutEnv, "bad_value")

config, err := newConfig([]string{})
require.NotNil(t, err)
require.Nil(t, config)
})

t.Run("should set connectionTimeout with the environment variable", func(t *testing.T) {
t.Setenv(connectionTimeoutEnv, "10s")

config, err := newConfig([]string{})
require.Nil(t, err)
assert.Equal(t, 10*time.Second, config.ConnectionTimeout)
})

t.Run("should set port", func(t *testing.T) {
config, err := newConfig([]string{"-p", "3000"})
require.Nil(t, err)
assert.Equal(t, 3000, config.Port)
})
}

0 comments on commit 1ca5330

Please sign in to comment.