diff --git a/cmd/client/grpc_client.go b/cmd/client/grpc_client.go index 2a22f743c..b7d2f3b18 100644 --- a/cmd/client/grpc_client.go +++ b/cmd/client/grpc_client.go @@ -32,6 +32,7 @@ const ( FlagInsecureNoTransportSecurity = "insecure-disable-transport-security" FlagInsecureSkipHostVerification = "insecure-skip-hostname-verification" FlagAuthority = "authority" + FlagBlock = "block" EnvReadRemote = "KETO_READ_REMOTE" EnvWriteRemote = "KETO_WRITE_REMOTE" @@ -45,6 +46,7 @@ type connectionDetails struct { token, authority string skipHostVerification bool noTransportSecurity bool + block bool } func (d *connectionDetails) dialOptions() (opts []grpc.DialOption) { @@ -71,6 +73,11 @@ func (d *connectionDetails) dialOptions() (opts []grpc.DialOption) { // Defaults to the default host root CA bundle opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(nil))) } + + if d.block { + opts = append(opts, grpc.WithBlock()) + } + return opts } @@ -106,6 +113,7 @@ func getConnectionDetails(cmd *cobra.Command) connectionDetails { authority: getAuthority(cmd), skipHostVerification: flagx.MustGetBool(cmd, FlagInsecureSkipHostVerification), noTransportSecurity: flagx.MustGetBool(cmd, FlagInsecureNoTransportSecurity), + block: flagx.MustGetBool(cmd, "block"), } } @@ -124,21 +132,16 @@ func GetWriteConn(cmd *cobra.Command) (*grpc.ClientConn, error) { } func Conn(ctx context.Context, remote string, details connectionDetails) (*grpc.ClientConn, error) { - timeout := 3 * time.Second if d, ok := ctx.Value(ContextKeyTimeout).(time.Duration); ok { - timeout = d + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, d) + defer cancel() } - ctx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - return grpc.DialContext( ctx, remote, - append([]grpc.DialOption{ - grpc.WithBlock(), - grpc.WithDisableHealthCheck(), - }, details.dialOptions()...)..., + details.dialOptions()..., ) } @@ -148,4 +151,5 @@ func RegisterRemoteURLFlags(flags *pflag.FlagSet) { flags.String(FlagAuthority, "", "Set the authority header for the remote gRPC server.") flags.Bool(FlagInsecureNoTransportSecurity, false, "Disables transport security. Do not use this in production.") flags.Bool(FlagInsecureSkipHostVerification, false, "Disables hostname verification. Do not use this in production.") + flags.Bool(FlagBlock, false, "Block until all migrations have been applied") } diff --git a/cmd/migrate/status.go b/cmd/migrate/status.go index 51a4cd80d..1b7b3aeaa 100644 --- a/cmd/migrate/status.go +++ b/cmd/migrate/status.go @@ -12,12 +12,12 @@ import ( "github.com/ory/x/cmdx" "github.com/spf13/cobra" + "github.com/ory/keto/cmd/client" "github.com/ory/keto/internal/driver" "github.com/ory/keto/ketoctx" ) func newStatusCmd(opts []ketoctx.Option) *cobra.Command { - block := false cmd := &cobra.Command{ Use: "status", Short: "Get the current migration status", @@ -26,6 +26,11 @@ func newStatusCmd(opts []ketoctx.Option) *cobra.Command { RunE: func(cmd *cobra.Command, _ []string) error { ctx := cmd.Context() + block, err := cmd.Flags().GetBool(client.FlagBlock) + if err != nil { + return err + } + reg, err := driver.NewDefaultRegistry(ctx, cmd.Flags(), true, opts) if err != nil { return err @@ -63,7 +68,6 @@ func newStatusCmd(opts []ketoctx.Option) *cobra.Command { } cmdx.RegisterFormatFlags(cmd.Flags()) - cmd.Flags().BoolVar(&block, "block", false, "Block until all migrations have been applied") return cmd } diff --git a/cmd/status/root.go b/cmd/status/root.go index fe2ec9c4c..72c34373f 100644 --- a/cmd/status/root.go +++ b/cmd/status/root.go @@ -18,15 +18,11 @@ import ( ) const ( - FlagBlock = "block" FlagEndpoint = "endpoint" ) func newStatusCmd() *cobra.Command { - var ( - block bool - endpoint string - ) + var endpoint string cmd := &cobra.Command{ Use: "status", @@ -34,6 +30,11 @@ func newStatusCmd() *cobra.Command { Long: "Get a status report about the upstream Keto instance. Can also block until the service is healthy.", Args: cobra.ExactArgs(0), RunE: func(cmd *cobra.Command, _ []string) error { + block, err := cmd.Flags().GetBool(cliclient.FlagBlock) + if err != nil { + return err + } + var connect func(*cobra.Command) (*grpc.ClientConn, error) switch endpoints := stringsx.SwitchExact(endpoint); { @@ -56,11 +57,9 @@ func newStatusCmd() *cobra.Command { conn, err = connect(cmd) } - if errors.Is(err, context.DeadlineExceeded) { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), grpcHealthV1.HealthCheckResponse_NOT_SERVING.String()) - return nil - } else if err != nil { - return err + if err != nil { + _, _ = fmt.Fprint(cmd.ErrOrStderr(), err.Error()) + return cmdx.FailSilently(cmd) } c := grpcHealthV1.NewHealthClient(conn) @@ -114,7 +113,6 @@ func newStatusCmd() *cobra.Command { cliclient.RegisterRemoteURLFlags(cmd.Flags()) cmdx.RegisterNoiseFlags(cmd.Flags()) - cmd.Flags().BoolVarP(&block, FlagBlock, "b", false, "block until the service is healthy") cmd.Flags().StringVar(&endpoint, FlagEndpoint, "read", "which endpoint to use; one of {read, write}") return cmd diff --git a/cmd/status/root_test.go b/cmd/status/root_test.go index 9b6c56bdb..df879b2dd 100644 --- a/cmd/status/root_test.go +++ b/cmd/status/root_test.go @@ -33,11 +33,11 @@ func TestStatusCmd(t *testing.T) { ts.Cmd.PersistentArgs = append(ts.Cmd.PersistentArgs, "--"+cmdx.FlagQuiet, "--"+FlagEndpoint, string(serverType)) t.Run("case=timeout,noblock", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond) defer cancel() - stdOut := cmdx.ExecNoErrCtx(ctx, t, newStatusCmd(), "--"+FlagEndpoint, string(serverType), "--"+ts.FlagRemote, ts.Addr+"0") - assert.Equal(t, grpcHealthV1.HealthCheckResponse_NOT_SERVING.String()+"\n", stdOut) + stdErr := cmdx.ExecExpectedErrCtx(ctx, t, newStatusCmd(), "--"+FlagEndpoint, string(serverType), "--"+ts.FlagRemote, ts.Addr[:len(ts.Addr)-1]) + assert.Equal(t, "context deadline exceeded", stdErr) }) t.Run("case=noblock", func(t *testing.T) { @@ -82,7 +82,7 @@ func TestStatusCmd(t *testing.T) { "--"+FlagEndpoint, string(serverType), "--"+ts.FlagRemote, l.Addr().String(), "--insecure-skip-hostname-verification=true", - "--"+FlagBlock, + "--"+client.FlagBlock, ).Wait(), ) diff --git a/internal/e2e/cli_client_test.go b/internal/e2e/cli_client_test.go index 5ea581fd8..9cf0301ef 100644 --- a/internal/e2e/cli_client_test.go +++ b/internal/e2e/cli_client_test.go @@ -20,13 +20,12 @@ import ( grpcHealthV1 "google.golang.org/grpc/health/grpc_health_v1" - "github.com/ory/keto/cmd/status" - "github.com/ory/keto/internal/x" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + gprclient "github.com/ory/keto/cmd/client" cliexpand "github.com/ory/keto/cmd/expand" clirelationtuple "github.com/ory/keto/cmd/relationtuple" @@ -138,7 +137,7 @@ func (g *cliClient) waitUntilLive(t require.TestingT) { ctx, cancel := context.WithTimeout(g.c.Ctx, time.Minute) defer cancel() - out := cmdx.ExecNoErrCtx(ctx, t, g.c.New(), append(flags, "status", "--"+status.FlagBlock)...) + out := cmdx.ExecNoErrCtx(ctx, t, g.c.New(), append(flags, "status", "--"+gprclient.FlagBlock)...) require.Equal(t, grpcHealthV1.HealthCheckResponse_SERVING.String()+"\n", out) }