Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 38 additions & 89 deletions modules/cockroachdb/cockroachdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,60 +174,50 @@ func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustom
password: defaultPassword,
},
}
req := testcontainers.GenericContainerRequest{
ContainerRequest: testcontainers.ContainerRequest{
Image: img,
ExposedPorts: []string{
defaultSQLPort,
defaultAdminPort,
},
Env: map[string]string{
"COCKROACH_DATABASE": defaultDatabase,
"COCKROACH_USER": defaultUser,
"COCKROACH_PASSWORD": defaultPassword,
},
Files: []testcontainers.ContainerFile{{
Reader: newDefaultsReader(clusterDefaults),
ContainerFilePath: clusterDefaultsContainerFile,
FileMode: 0o644,
}},
Cmd: []string{
"start-single-node",
memStorageFlag + defaultStoreSize,
},
WaitingFor: wait.ForAll(
wait.ForFile(cockroachDir+"/init_success"),
wait.ForHTTP("/health").WithPort(defaultAdminPort),
wait.ForTLSCert(
certsDir+"/client."+defaultUser+".crt",
certsDir+"/client."+defaultUser+".key",
).WithRootCAs(fileCACert).WithServerName("127.0.0.1"),
wait.ForSQL(defaultSQLPort, "pgx/v5", func(host string, port nat.Port) string {
connStr, err := ctr.connString(host, port)
if err != nil {
panic(err)
}
return connStr
}),
),
},
Started: true,
}

for _, opt := range opts {
if err := opt.Customize(&req); err != nil {
return nil, fmt.Errorf("customize request: %w", err)
}
moduleOpts := []testcontainers.ContainerCustomizer{
testcontainers.WithCmd(
"start-single-node",
memStorageFlag+defaultStoreSize,
),
testcontainers.WithExposedPorts(defaultSQLPort, defaultAdminPort),
testcontainers.WithEnv(map[string]string{
"COCKROACH_DATABASE": defaultDatabase,
"COCKROACH_USER": defaultUser,
"COCKROACH_PASSWORD": defaultPassword,
}),
testcontainers.WithFiles(testcontainers.ContainerFile{
Reader: newDefaultsReader(clusterDefaults),
ContainerFilePath: clusterDefaultsContainerFile,
FileMode: 0o644,
}),
testcontainers.WithWaitStrategy(wait.ForAll(
wait.ForFile(cockroachDir+"/init_success"),
wait.ForHTTP("/health").WithPort(defaultAdminPort),
wait.ForTLSCert(
certsDir+"/client."+defaultUser+".crt",
certsDir+"/client."+defaultUser+".key",
).WithRootCAs(fileCACert).WithServerName("127.0.0.1"),
wait.ForSQL(defaultSQLPort, "pgx/v5", func(host string, port nat.Port) string {
connStr, err := ctr.connString(host, port)
if err != nil {
panic(err)
}
return connStr
}),
)),
}

if err := ctr.configure(&req); err != nil {
return nil, fmt.Errorf("set options: %w", err)
}
moduleOpts = append(moduleOpts, opts...)

// configure the wait strategy after all the options have been applied
// It extracts the TLS strategy from the wait strategy and sets it on the container.
moduleOpts = append(moduleOpts, ctr.configure())

var err error
ctr.Container, err = testcontainers.GenericContainer(ctx, req)
ctr.Container, err = testcontainers.Run(ctx, img, moduleOpts...)
if err != nil {
return ctr, fmt.Errorf("generic container: %w", err)
return ctr, fmt.Errorf("run cockroachdb: %w", err)
}

return ctr, nil
Expand Down Expand Up @@ -278,44 +268,3 @@ func (c *CockroachDBContainer) connConfig(host string, port nat.Port) (*pgx.Conn

return cfg, nil
}

// configure sets the CockroachDBContainer options from the given request and updates the request
// wait strategies to match the options.
func (c *CockroachDBContainer) configure(req *testcontainers.GenericContainerRequest) error {
c.database = req.Env[envDatabase]
c.user = req.Env[envUser]
c.password = req.Env[envPassword]

var insecure bool
for _, arg := range req.Cmd {
if arg == insecureFlag {
insecure = true
break
}
}

// Walk the wait strategies to find the TLS strategy and either remove it or
// update the client certificate files to match the user and configure the
// container to use the TLS strategy.
if err := wait.Walk(&req.WaitingFor, func(strategy wait.Strategy) error {
if cert, ok := strategy.(*wait.TLSStrategy); ok {
if insecure {
// If insecure mode is enabled, the certificate strategy is removed.
return errors.Join(wait.ErrVisitRemove, wait.ErrVisitStop)
}

// Update the client certificate files to match the user which may have changed.
cert.WithCert(certsDir+"/client."+c.user+".crt", certsDir+"/client."+c.user+".key")

c.tlsStrategy = cert

// Stop the walk as the certificate strategy has been found.
return wait.ErrVisitStop
}
return nil
}); err != nil {
return fmt.Errorf("walk strategies: %w", err)
}

return nil
}
47 changes: 47 additions & 0 deletions modules/cockroachdb/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ package cockroachdb

import (
"errors"
"fmt"
"path/filepath"
"slices"
"strings"

"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
)

// errInsecureWithPassword is returned when trying to use insecure mode with a password.
Expand Down Expand Up @@ -117,3 +120,47 @@ func WithInsecure() testcontainers.CustomizeRequestOption {
return nil
}
}

// configure sets the CockroachDBContainer options from the given request and updates the request
// wait strategies to match the options.
// This option must be called after all the options have been applied, in order to extract
// the credentials from the environment variables and the TLS strategy from the wait strategy.
func (c *CockroachDBContainer) configure() testcontainers.CustomizeRequestOption {
return func(req *testcontainers.GenericContainerRequest) error {
// refresh the credentials from the environment variables
c.user = req.Env[envUser]
c.password = req.Env[envPassword]
c.database = req.Env[envDatabase]

var insecure bool
if slices.Contains(req.Cmd, insecureFlag) {
insecure = true
}

// Walk the wait strategies to find the TLS strategy and either remove it or
// update the client certificate files to match the user and configure the
// container to use the TLS strategy.
if err := wait.Walk(&req.WaitingFor, func(strategy wait.Strategy) error {
if cert, ok := strategy.(*wait.TLSStrategy); ok {
if insecure {
// If insecure mode is enabled, the certificate strategy is removed.
return errors.Join(wait.ErrVisitRemove, wait.ErrVisitStop)
}

// Update the client certificate files to match the user which may have changed.
cert.WithCert(certsDir+"/client."+c.user+".crt", certsDir+"/client."+c.user+".key")

c.tlsStrategy = cert

// Stop the walk as the certificate strategy has been found.
return wait.ErrVisitStop
}

return nil
}); err != nil {
return fmt.Errorf("walk strategies: %w", err)
}

return nil
}
}
Loading