From c3fa9b5ea9f0b93c2ab462435423e64e32b8a0e4 Mon Sep 17 00:00:00 2001 From: chronark Date: Thu, 17 Jul 2025 12:08:39 +0200 Subject: [PATCH 1/6] fix: port allocation in tests we now try to listen on a random port assigned by the OS and never stop listening until the tests are done. This prevents the race conditions between assigning a port and using it --- go/apps/api/cancel_test.go | 15 ++-- go/apps/api/config.go | 8 ++ go/apps/api/integration/harness.go | 21 ++--- go/apps/api/run.go | 25 ++++-- go/cmd/api/main.go | 5 +- go/pkg/port/doc.go | 37 --------- go/pkg/port/free.go | 128 ----------------------------- go/pkg/zen/README.md | 54 +++++++++++- go/pkg/zen/doc.go | 8 +- go/pkg/zen/server.go | 14 ++-- go/pkg/zen/server_tls_test.go | 42 ++++------ 11 files changed, 125 insertions(+), 232 deletions(-) delete mode 100644 go/pkg/port/doc.go delete mode 100644 go/pkg/port/free.go diff --git a/go/apps/api/cancel_test.go b/go/apps/api/cancel_test.go index f64a9d020d..3e5cef0d6e 100644 --- a/go/apps/api/cancel_test.go +++ b/go/apps/api/cancel_test.go @@ -3,13 +3,13 @@ package api_test import ( "context" "fmt" + "net" "net/http" "testing" "time" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api" - "github.com/unkeyed/unkey/go/pkg/port" "github.com/unkeyed/unkey/go/pkg/testutil/containers" "github.com/unkeyed/unkey/go/pkg/uid" "github.com/unkeyed/unkey/go/pkg/vault/keys" @@ -23,9 +23,10 @@ func TestContextCancellation(t *testing.T) { mysqlCfg.DBName = "unkey" dbDsn := mysqlCfg.FormatDSN() redisUrl := containers.Redis(t) - // Get free ports for the node - portAllocator := port.New() - httpPort := portAllocator.Get() + + // Create ephemeral listener + ln, err := net.Listen("tcp", ":0") + require.NoError(t, err, "Failed to create ephemeral listener") // Create a cancellable context ctx, cancel := context.WithCancel(context.Background()) @@ -37,7 +38,7 @@ func TestContextCancellation(t *testing.T) { config := api.Config{ Platform: "test", Image: 
"test", - HttpPort: httpPort, + Listener: ln, Region: "test-region", Clock: nil, // Will use real clock InstanceID: uid.New(uid.InstancePrefix), @@ -65,7 +66,7 @@ func TestContextCancellation(t *testing.T) { // Wait for the server to start up require.Eventually(t, func() bool { - res, livenessErr := http.Get(fmt.Sprintf("http://localhost:%d/v2/liveness", httpPort)) + res, livenessErr := http.Get(fmt.Sprintf("http://%s/v2/liveness", ln.Addr())) if livenessErr != nil { return false } @@ -90,6 +91,6 @@ func TestContextCancellation(t *testing.T) { } // Verify the server is no longer responding - _, err = http.Get(fmt.Sprintf("http://localhost:%d/v2/liveness", httpPort)) + _, err = http.Get(fmt.Sprintf("http://%s/v2/liveness", ln.Addr())) require.Error(t, err, "Server should no longer be responding after shutdown") } diff --git a/go/apps/api/config.go b/go/apps/api/config.go index fbeb1f2e6d..07995fbd46 100644 --- a/go/apps/api/config.go +++ b/go/apps/api/config.go @@ -1,6 +1,8 @@ package api import ( + "net" + "github.com/unkeyed/unkey/go/pkg/clock" "github.com/unkeyed/unkey/go/pkg/tls" ) @@ -23,8 +25,14 @@ type Config struct { Image string // HttpPort defines the HTTP port for the API server to listen on (default: 7070) + // Used in production deployments. Ignored if Listener is provided. 
HttpPort int + // Listener defines a pre-created network listener for the HTTP server + // If provided, the server will use this listener instead of creating one from HttpPort + // This is intended for testing scenarios where ephemeral ports are needed to avoid conflicts + Listener net.Listener + // Region identifies the geographic region where this node is deployed Region string diff --git a/go/apps/api/integration/harness.go b/go/apps/api/integration/harness.go index 230e2b6b4b..c8f8814561 100644 --- a/go/apps/api/integration/harness.go +++ b/go/apps/api/integration/harness.go @@ -3,6 +3,7 @@ package integration import ( "context" "fmt" + "net" "net/http" "testing" "time" @@ -13,7 +14,6 @@ import ( "github.com/unkeyed/unkey/go/pkg/clock" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/otel/logging" - "github.com/unkeyed/unkey/go/pkg/port" "github.com/unkeyed/unkey/go/pkg/testutil/containers" "github.com/unkeyed/unkey/go/pkg/testutil/seed" ) @@ -36,7 +36,6 @@ type Harness struct { ctx context.Context cancel context.CancelFunc instanceAddrs []string - ports *port.FreePort Seed *seed.Seeder dbDSN string DB db.Database @@ -87,7 +86,6 @@ func New(t *testing.T, config Config) *Harness { t: t, ctx: ctx, cancel: cancel, - ports: port.New(), instanceAddrs: []string{}, Seed: seed.New(t, db), dbDSN: mysqlHostDSN, @@ -124,11 +122,11 @@ func (h *Harness) RunAPI(config ApiConfig) *ApiCluster { // Start each API node as a goroutine for i := 0; i < config.Nodes; i++ { - // Find an available port - portFinder := port.New() - nodePort := portFinder.Get() + // Create ephemeral listener + ln, err := net.Listen("tcp", ":0") + require.NoError(h.t, err, "Failed to create ephemeral listener") - cluster.Addrs[i] = fmt.Sprintf("http://localhost:%d", nodePort) + cluster.Addrs[i] = fmt.Sprintf("http://%s", ln.Addr().String()) // Create API config for this node using host connections mysqlHostCfg := containers.MySQL(h.t) @@ -139,7 +137,7 @@ func (h *Harness) 
RunAPI(config ApiConfig) *ApiCluster { apiConfig := api.Config{ Platform: "test", Image: "test", - HttpPort: nodePort, + Listener: ln, DatabasePrimary: mysqlHostCfg.FormatDSN(), DatabaseReadonlyReplica: "", ClickhouseURL: clickhouseHostDSN, @@ -198,12 +196,13 @@ func (h *Harness) RunAPI(config ApiConfig) *ApiCluster { // Wait for server to start maxAttempts := 30 + healthURL := fmt.Sprintf("http://%s/v2/liveness", ln.Addr().String()) for attempt := 0; attempt < maxAttempts; attempt++ { - resp, err := http.Get(fmt.Sprintf("http://localhost:%d/v2/liveness", nodePort)) + resp, err := http.Get(healthURL) if err == nil { resp.Body.Close() if resp.StatusCode == http.StatusOK { - h.t.Logf("API server %d started on port %d", i, nodePort) + h.t.Logf("API server %d started on %s", i, ln.Addr().String()) break } } @@ -216,6 +215,8 @@ func (h *Harness) RunAPI(config ApiConfig) *ApiCluster { // Register cleanup h.t.Cleanup(func() { cancel() + // Note: Don't call ln.Close() here as the zen server + // will properly close the listener during graceful shutdown }) } diff --git a/go/apps/api/run.go b/go/apps/api/run.go index 5169007fdf..185ffc875a 100644 --- a/go/apps/api/run.go +++ b/go/apps/api/run.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log/slog" + "net" "runtime/debug" "time" @@ -110,7 +111,11 @@ func Run(ctx context.Context, cfg Config) error { return fmt.Errorf("unable to start prometheus: %w", promErr) } go func() { - promListenErr := prom.Listen(ctx, fmt.Sprintf(":%d", cfg.PrometheusPort)) + promListener, err := net.Listen("tcp", fmt.Sprintf(":%d", cfg.PrometheusPort)) + if err != nil { + panic(err) + } + promListenErr := prom.Serve(ctx, promListener) if promListenErr != nil { panic(promListenErr) } @@ -222,16 +227,22 @@ func Run(ctx context.Context, cfg Config) error { Caches: caches, Vault: vaultSvc, }) + if cfg.Listener == nil { + // Create listener from HttpPort (production) + cfg.Listener, err = net.Listen("tcp", fmt.Sprintf(":%d", cfg.HttpPort)) + if err != nil { 
+ return fmt.Errorf("Unable to listen on port %d: %w", cfg.HttpPort, err) + } + } go func() { - listenErr := srv.Listen(ctx, fmt.Sprintf(":%d", cfg.HttpPort)) - if listenErr != nil { - panic(listenErr) + serveErr := srv.Serve(ctx, cfg.Listener) + if serveErr != nil { + panic(serveErr) } - }() - // Wait for signals and handle shutdown - logger.Info("API server started successfully") + logger.Info("API server started successfully") + }() // Wait for either OS signals or context cancellation, then shutdown if err := shutdowns.WaitForSignal(ctx, time.Minute); err != nil { diff --git a/go/cmd/api/main.go b/go/cmd/api/main.go index ed28550263..e9a3ce90fb 100644 --- a/go/cmd/api/main.go +++ b/go/cmd/api/main.go @@ -208,7 +208,6 @@ func action(ctx context.Context, cmd *cli.Command) error { // Basic configuration Platform: cmd.String("platform"), Image: cmd.String("image"), - HttpPort: cmd.Int("http-port"), Region: cmd.String("region"), // Database configuration @@ -231,6 +230,10 @@ func action(ctx context.Context, cmd *cli.Command) error { Clock: clock.New(), TestMode: cmd.Bool("test-mode"), + // HTTP configuration + HttpPort: cmd.Int("http-port"), + Listener: nil, // Production uses HttpPort + // Vault configuration VaultMasterKeys: cmd.StringSlice("vault-master-keys"), VaultS3: vaultS3Config, diff --git a/go/pkg/port/doc.go b/go/pkg/port/doc.go deleted file mode 100644 index 0483e3b769..0000000000 --- a/go/pkg/port/doc.go +++ /dev/null @@ -1,37 +0,0 @@ -// Package port provides utilities for finding and managing available network ports. -// -// This package is particularly useful for testing scenarios where multiple -// services need to run on unique ports without conflicting with each other -// or with existing services. It safely locates available ports through actual -// network binding and offers mechanisms to track allocated ports to prevent -// reuse within the same process. 
-// -// The implementation uses a combination of random port selection and actual -// TCP socket binding to verify availability. This approach is more reliable -// than just checking if a port is currently in use, as it accounts for -// ports that may be temporarily unavailable or restricted by the operating system. -// -// Basic usage: -// -// // Create a port finder -// finder := port.New() -// -// // Get an available port -// port := finder.Get() -// -// // Use the port for your service -// listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) -// -// // Or for testing multiple services: -// port1 := finder.Get() -// port2 := finder.Get() -// port3 := finder.Get() -// -// The package tracks ports it has assigned within the current process -// to ensure the same port isn't returned twice, even if the port hasn't -// been bound yet. -// -// Note that port availability is only guaranteed at the moment Get() is called. -// If there is a delay between getting the port and binding to it, another -// process could potentially bind to that port in the meantime. -package port diff --git a/go/pkg/port/free.go b/go/pkg/port/free.go deleted file mode 100644 index 86e5be048b..0000000000 --- a/go/pkg/port/free.go +++ /dev/null @@ -1,128 +0,0 @@ -package port - -import ( - "fmt" - "math/rand/v2" - "net" - "sync" -) - -// FreePort provides utilities for finding available network ports. -// It manages a pool of assigned ports to prevent the same port from -// being returned multiple times within the same process. -type FreePort struct { - mu sync.RWMutex - min int - max int - attempts int - - // The caller may request multiple ports without binding them immediately - // so we need to keep track of which ports are assigned. - assigned map[int]bool -} - -// New creates a new FreePort instance for finding available ports. -// The returned instance keeps track of ports it has assigned to prevent -// returning the same port twice, even if the actual binding hasn't occurred. 
-// -// By default, ports are selected from the range 10000-65535, which falls -// within the standard range for ephemeral/private ports. -// -// Example: -// -// // Create a new port finder -// portFinder := port.New() -// -// // Get multiple available ports -// httpPort := portFinder.Get() -// grpcPort := portFinder.Get() -// metricsPort := portFinder.Get() -// -// fmt.Printf("Running HTTP on port %d, gRPC on port %d, metrics on port %d\n", -// httpPort, grpcPort, metricsPort) -func New() *FreePort { - return &FreePort{ - min: 10000, - max: 65535, - attempts: 10, - assigned: map[int]bool{}, - mu: sync.RWMutex{}, - } -} - -// Get returns an available TCP port number. -// The port is guaranteed to be available at the time of the call, -// and will not be returned again by the same FreePort instance. -// -// This method will attempt to find an available port by: -// 1. Selecting a random port in the range 10000-65535 -// 2. Checking that the port hasn't already been assigned by this instance -// 3. Verifying availability by attempting to bind to 127.0.0.1 on that port -// 4. Marking the port as assigned to prevent future reuse -// -// If no available port can be found after multiple attempts, Get will panic. -// For cases where error handling is preferred over panicking, use GetWithError. -// -// Example: -// -// finder := port.New() -// serverPort := finder.Get() -// -// // Start your server on this port -// server := &http.Server{ -// Addr: fmt.Sprintf(":%d", serverPort), -// Handler: mux, -// } -// server.ListenAndServe() -func (f *FreePort) Get() int { - port, err := f.GetWithError() - if err != nil { - panic(err) - } - - return port -} - -// GetWithError returns an available TCP port number or an error if no port -// could be found after multiple attempts. -// -// This method works the same as Get() but returns an error instead of -// panicking when no available ports can be found. 
This is preferred in -// production code where error handling is more appropriate than panicking. -// -// Example: -// -// finder := port.New() -// port, err := finder.GetWithError() -// if err != nil { -// log.Fatalf("Failed to find available port: %v", err) -// } -// -// // Use the port -// listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) -func (f *FreePort) GetWithError() (int, error) { - f.mu.Lock() - defer f.mu.Unlock() - - for i := 0; i < f.attempts; i++ { - - // nolint:gosec - // This isn't cryptography - port := rand.IntN(f.max-f.min) + f.min - if f.assigned[port] { - continue - } - - ln, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: port, Zone: ""}) - if err != nil { - continue - } - err = ln.Close() - if err != nil { - return -1, err - } - f.assigned[port] = true - return port, nil - } - return -1, fmt.Errorf("could not find a free port, maybe increase attempts?") -} diff --git a/go/pkg/zen/README.md b/go/pkg/zen/README.md index bfda47d98b..4811407c1c 100644 --- a/go/pkg/zen/README.md +++ b/go/pkg/zen/README.md @@ -38,6 +38,7 @@ import ( "context" "log" "log/slog" + "net" "net/http" "github.com/unkeyed/unkey/go/pkg/zen" @@ -141,7 +142,14 @@ func main() { logger.Info("starting server", "address", ":8080", ) - err = server.Listen(context.Background(), ":8080") + + // Create a listener + listener, err := net.Listen("tcp", ":8080") + if err != nil { + log.Fatalf("failed to create listener: %v", err) + } + + err = server.Serve(context.Background(), listener) if err != nil { logger.Error("server error", slog.String("error", err.Error())) } @@ -158,6 +166,7 @@ package main import ( "context" "log" + "net" "github.com/unkeyed/unkey/go/pkg/tls" "github.com/unkeyed/unkey/go/pkg/zen" @@ -184,9 +193,15 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + // Create a listener for HTTPS + listener, err := net.Listen("tcp", ":443") + if err != nil { + log.Fatalf("failed to create listener: 
%v", err) + } + // Start in a goroutine so you can handle shutdown signals go func() { - if err := server.Listen(ctx, ":443"); err != nil { + if err := server.Serve(ctx, listener); err != nil { log.Fatalf("server error: %v", err) } }() @@ -199,6 +214,32 @@ func main() { } ``` +## Testing with Ephemeral Ports + +For testing, you can use ephemeral ports to let the OS assign an available port automatically. This prevents port conflicts in testing environments: + +```go +import "github.com/unkeyed/unkey/go/pkg/listener" + +// Get an available port and listener +listenerImpl, err := listener.Ephemeral() +if err != nil { + t.Fatalf("failed to create ephemeral listener: %v", err) +} +netListener, err := listenerImpl.Listen() +if err != nil { + t.Fatalf("failed to get listener: %v", err) +} + +// Start the server +go server.Serve(ctx, netListener) + +// Make requests to the server +resp, err := http.Get(fmt.Sprintf("http://%s/test", listenerImpl.Addr())) +``` + +This approach is especially useful for concurrent tests where multiple servers need to run simultaneously without conflicting ports. + ## Working with OpenAPI Validation Zen works well with a schema-first approach to API design. 
Define your OpenAPI specification first, then use it for validation: @@ -228,8 +269,13 @@ Zen provides built-in support for graceful shutdown through context cancellation // Create a context that can be cancelled ctx, cancel := context.WithCancel(context.Background()) -// Start the server with this context -go server.Listen(ctx, ":8080") +// Create a listener and start the server with this context +listener, err := net.Listen("tcp", ":8080") +if err != nil { + log.Fatalf("failed to create listener: %v", err) +} + +go server.Serve(ctx, listener) // When you need to shut down (e.g., on SIGTERM): cancel() diff --git a/go/pkg/zen/doc.go b/go/pkg/zen/doc.go index 88faf29ddc..fb12bef9c5 100644 --- a/go/pkg/zen/doc.go +++ b/go/pkg/zen/doc.go @@ -49,8 +49,12 @@ // route, // ) // -// // Start the server -// err = server.Listen(ctx, ":8080") +// // Create a listener and start the server +// listener, err := net.Listen("tcp", ":8080") +// if err != nil { +// log.Fatalf("failed to create listener: %v", err) +// } +// err = server.Serve(ctx, listener) // // Zen is optimized for building maintainable, observable web services with minimal // external dependencies and strong integration with standard Go libraries. 
diff --git a/go/pkg/zen/server.go b/go/pkg/zen/server.go index cbf0fb34cd..9e47d1b8da 100644 --- a/go/pkg/zen/server.go +++ b/go/pkg/zen/server.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net" "net/http" "sync" "time" @@ -162,7 +163,7 @@ func (s *Server) Flags() Flags { // log.Printf("server stopped: %v", err) // } // }() -func (s *Server) Listen(ctx context.Context, addr string) error { +func (s *Server) Serve(ctx context.Context, ln net.Listener) error { s.mu.Lock() if s.isListening { s.logger.Warn("already listening") @@ -172,8 +173,6 @@ func (s *Server) Listen(ctx context.Context, addr string) error { s.isListening = true s.mu.Unlock() - s.srv.Addr = addr - // Set up context handling for graceful shutdown serverCtx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -194,20 +193,19 @@ func (s *Server) Listen(ctx context.Context, addr string) error { // Server stopped on its own } }() - var err error // Check if TLS should be used if s.tlsConfig != nil { - s.logger.Info("listening", "srv", "https", "addr", addr) + s.logger.Info("listening", "srv", "https", "addr", ln.Addr().String()) s.srv.TLSConfig = s.tlsConfig // ListenAndServeTLS with empty strings will use the certificates from TLSConfig - err = s.srv.ListenAndServeTLS("", "") + err = s.srv.ServeTLS(ln, "", "") } else { - s.logger.Info("listening", "srv", "http", "addr", addr) - err = s.srv.ListenAndServe() + s.logger.Info("listening", "srv", "http", "addr", ln.Addr().String()) + err = s.srv.Serve(ln) } // Cancel the server context since the server has stopped diff --git a/go/pkg/zen/server_tls_test.go b/go/pkg/zen/server_tls_test.go index 1a0f4fc96a..7abfedf180 100644 --- a/go/pkg/zen/server_tls_test.go +++ b/go/pkg/zen/server_tls_test.go @@ -98,16 +98,12 @@ func TestServerWithTLS(t *testing.T) { }) server.RegisterRoute([]Middleware{}, testRoute) - // Create a net.Listener to determine the port - ln, err := net.Listen("tcp", "localhost:0") - require.NoError(t, err, "Failed to 
create listener") + // Create ephemeral listener + ln, err := net.Listen("tcp", ":0") + require.NoError(t, err, "Failed to create ephemeral listener") - // Get the assigned port - _, portStr, err := net.SplitHostPort(ln.Addr().String()) - require.NoError(t, err, "Failed to get port") - - // Modify server to use our listener's port - addr := "localhost:" + portStr + // Get the address for the test client + addr := ln.Addr().String() // Start the server in a goroutine serverCtx, serverCancel := context.WithCancel(context.Background()) @@ -117,15 +113,12 @@ func TestServerWithTLS(t *testing.T) { serverReady := make(chan struct{}) go func() { - // Close our listener as server.Listen will create its own - ln.Close() - // Signal that we're about to start the server close(serverReady) - listenErr := server.Listen(serverCtx, addr) + listenErr := server.Serve(serverCtx, ln) if listenErr != nil && listenErr.Error() != "http: Server closed" { - t.Errorf("server.Listen returned: %v", listenErr) + t.Errorf("server.Serve returned: %v", listenErr) } }() defer server.Shutdown(context.Background()) @@ -200,16 +193,12 @@ func TestServerWithTLSContextCancellation(t *testing.T) { }) server.RegisterRoute([]Middleware{}, testRoute) - // Create a net.Listener to determine the port - ln, err := net.Listen("tcp", "localhost:0") - require.NoError(t, err, "Failed to create listener") + // Create ephemeral listener + ln, err := net.Listen("tcp", ":0") + require.NoError(t, err, "Failed to create ephemeral listener") - // Get the assigned port - _, portStr, err := net.SplitHostPort(ln.Addr().String()) - require.NoError(t, err, "Failed to get port") - - // Modify server to use our listener's port - addr := "localhost:" + portStr + // Get the address for the test client + addr := ln.Addr().String() // Create a context that can be canceled serverCtx, serverCancel := context.WithCancel(context.Background()) @@ -222,15 +211,12 @@ func TestServerWithTLSContextCancellation(t *testing.T) { // Start the 
server in a goroutine go func() { - // Close our listener as server.Listen will create its own - ln.Close() - // Signal that we're about to start the server close(serverReady) - listenErr := server.Listen(serverCtx, addr) + listenErr := server.Serve(serverCtx, ln) if listenErr != nil && listenErr.Error() != "http: Server closed" { - t.Errorf("server.Listen returned: %v", listenErr) + t.Errorf("server.Serve returned: %v", listenErr) } // Signal that the server has exited From e071c2c54e96e03516b58aba0779442442d4d833 Mon Sep 17 00:00:00 2001 From: James P Date: Thu, 17 Jul 2025 07:46:03 -0400 Subject: [PATCH 2/6] feat: Allow filtering by tags on API, Keys and Key requests (#3614) * begining of things * Stop crashing logs * fix the ability to retrieve from CH * Fix all bugs * formatting * Better logs, fix some errors, typesafe * update the v1_keys_getVerifications * fmt * delete all tests --------- Co-authored-by: MichaelUnkey --- .../src/routes/v1_keys_getVerifications.ts | 7 +- .../bar-chart/hooks/use-fetch-timeseries.ts | 11 +++ .../bar-chart/query-timeseries.schema.ts | 6 ++ .../line-chart/hooks/use-fetch-timeseries.ts | 11 +++ .../components/logs-filters/index.tsx | 31 ++++++ .../table/components/log-details/index.tsx | 4 + .../components/table/hooks/use-logs-query.ts | 11 +++ .../components/table/query-logs.schema.ts | 8 ++ .../apis/[apiId]/_overview/filters.schema.ts | 7 ++ .../[apiId]/_overview/hooks/use-filters.ts | 18 +++- .../bar-chart/hooks/use-fetch-timeseries.ts | 18 ++++ .../bar-chart/query-timeseries.schema.ts | 7 ++ .../components/logs-filters/index.tsx | 36 ++++++- .../components/table/hooks/use-logs-query.ts | 10 ++ .../components/table/query-logs.schema.ts | 9 ++ .../[keyAuthId]/[keyId]/filters.schema.ts | 24 ++++- .../[keyAuthId]/[keyId]/hooks/use-filters.ts | 45 ++++++++- .../components/logs-filters/index.tsx | 30 ++++++ .../table/hooks/use-keys-list-query.ts | 1 - .../components/table/query-logs.schema.ts | 1 + 
.../[keyAuthId]/_components/filters.schema.ts | 5 + .../_components/hooks/use-filters.ts | 1 - .../api/keys/query-api-keys/get-all-keys.ts | 98 ++++++++++++++++++- .../routers/api/keys/query-api-keys/index.ts | 1 + .../keys/query-key-usage-timeseries/index.ts | 1 + .../api/keys/query-overview-logs/index.ts | 3 +- .../api/keys/query-overview-logs/utils.ts | 7 ++ .../trpc/routers/api/keys/timeseries.utils.ts | 8 ++ .../api/overview/query-timeseries/utils.ts | 1 + .../lib/trpc/routers/key/query-logs/utils.ts | 7 ++ internal/clickhouse/src/keys/keys.ts | 41 ++++++++ .../clickhouse/src/verification_tags.test.ts | 1 + internal/clickhouse/src/verifications.ts | 84 +++++++++++++++- 33 files changed, 533 insertions(+), 20 deletions(-) diff --git a/apps/api/src/routes/v1_keys_getVerifications.ts b/apps/api/src/routes/v1_keys_getVerifications.ts index 99f3a4f06e..fd11a7a089 100644 --- a/apps/api/src/routes/v1_keys_getVerifications.ts +++ b/apps/api/src/routes/v1_keys_getVerifications.ts @@ -173,7 +173,11 @@ export const registerV1KeysGetVerifications = (app: App) => if (!dbRes) { return []; } - return dbRes.map((key) => ({ key, api: key.keyAuth.api, ratelimits: key.ratelimits })); + return dbRes.map((key) => ({ + key, + api: key.keyAuth.api, + ratelimits: key.ratelimits, + })); }); if (keys.err) { throw new UnkeyApiError({ @@ -238,6 +242,7 @@ export const registerV1KeysGetVerifications = (app: App) => keyIds: null, outcomes: null, names: null, + tags: null, }) .catch((err) => { throw new Error(err.message); diff --git a/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/charts/bar-chart/hooks/use-fetch-timeseries.ts b/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/charts/bar-chart/hooks/use-fetch-timeseries.ts index 03979a2a2b..2a6874ae6f 100644 --- a/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/charts/bar-chart/hooks/use-fetch-timeseries.ts +++ 
b/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/charts/bar-chart/hooks/use-fetch-timeseries.ts @@ -20,6 +20,7 @@ export const useFetchVerificationTimeseries = (apiId: string | null) => { outcomes: { filters: [] }, names: { filters: [] }, identities: { filters: [] }, + tags: null, apiId: apiId ?? "", since: "", }; @@ -99,6 +100,16 @@ export const useFetchVerificationTimeseries = (apiId: string | null) => { } break; } + + case "tags": { + if (typeof filter.value === "string" && filter.value.trim()) { + params.tags = { + operator, + value: filter.value, + }; + } + break; + } } }); diff --git a/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/charts/bar-chart/query-timeseries.schema.ts b/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/charts/bar-chart/query-timeseries.schema.ts index c11546c629..1fb98eeb72 100644 --- a/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/charts/bar-chart/query-timeseries.schema.ts +++ b/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/charts/bar-chart/query-timeseries.schema.ts @@ -50,6 +50,12 @@ export const keysOverviewQueryTimeseriesPayload = z.object({ ), }) .nullable(), + tags: z + .object({ + operator: keysOverviewFilterOperatorEnum, + value: z.string(), + }) + .nullable(), }); export type KeysOverviewQueryTimeseriesPayload = z.infer; diff --git a/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/charts/line-chart/hooks/use-fetch-timeseries.ts b/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/charts/line-chart/hooks/use-fetch-timeseries.ts index 193d64b7df..9e31041b1b 100644 --- a/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/charts/line-chart/hooks/use-fetch-timeseries.ts +++ b/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/charts/line-chart/hooks/use-fetch-timeseries.ts @@ -20,6 +20,7 @@ export const useFetchActiveKeysTimeseries = (apiId: string | null) => { outcomes: { filters: [] }, names: { filters: [] }, identities: { filters: [] 
}, + tags: null, apiId: apiId ?? "", since: "", }; @@ -99,6 +100,16 @@ export const useFetchActiveKeysTimeseries = (apiId: string | null) => { } break; } + + case "tags": { + if (typeof filter.value === "string" && filter.value.trim()) { + params.tags = { + operator, + value: filter.value, + }; + } + break; + } } }); diff --git a/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/controls/components/logs-filters/index.tsx b/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/controls/components/logs-filters/index.tsx index 1cb3bc6983..a0c379d6bc 100644 --- a/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/controls/components/logs-filters/index.tsx +++ b/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/controls/components/logs-filters/index.tsx @@ -20,10 +20,15 @@ export const LogsFilters = () => { const activeNameFilter = filters.find((f) => f.field === "names"); const activeIdentityFilter = filters.find((f) => f.field === "identities"); const activeKeyIdsFilter = filters.find((f) => f.field === "keyIds"); + const activeTagsFilter = filters.find((f) => f.field === "tags"); const keyIdOptions = keysOverviewFilterFieldConfig.keyIds.operators.map((op) => ({ id: op, label: op, })); + const tagsOptions = keysOverviewFilterFieldConfig.tags.operators.map((op) => ({ + id: op, + label: op, + })); return ( { /> ), }, + { + id: "tags", + label: "Tags", + shortcut: "t", + component: ( + { + const activeFiltersWithoutTags = filters.filter((f) => f.field !== "tags"); + updateFilters([ + ...activeFiltersWithoutTags, + { + field: "tags", + id: crypto.randomUUID(), + operator: id, + value: text, + }, + ]); + setOpen(false); + }} + /> + ), + }, { id: "outcomes", label: "Outcomes", diff --git a/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/table/components/log-details/index.tsx b/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/table/components/log-details/index.tsx index bfa55f41ed..08d663cc6c 100644 --- 
a/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/table/components/log-details/index.tsx +++ b/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/table/components/log-details/index.tsx @@ -93,6 +93,9 @@ export const KeysOverviewLogDetails = ({ : "Unlimited", }; + const tags = + log.tags && log.tags.length > 0 ? { Tags: log.tags.join(", ") } : { "No tags": null }; + const identity = log.key_details.identity ? { "External ID": log.key_details.identity.external_id || "N/A" } : { "No identity connected": null }; @@ -111,6 +114,7 @@ export const KeysOverviewLogDetails = ({ + diff --git a/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/table/hooks/use-logs-query.ts b/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/table/hooks/use-logs-query.ts index 1a44dc986a..c5d6ac96be 100644 --- a/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/table/hooks/use-logs-query.ts +++ b/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/table/hooks/use-logs-query.ts @@ -34,6 +34,7 @@ export function useKeysOverviewLogsQuery({ apiId, limit = 50 }: UseLogsQueryPara outcomes: [], identities: [], names: [], + tags: [], apiId, since: "", sorts: sorts.length > 0 ? 
sorts : null, @@ -84,6 +85,16 @@ export function useKeysOverviewLogsQuery({ apiId, limit = 50 }: UseLogsQueryPara break; } + case "tags": { + if (typeof filter.value === "string" && filter.value.trim()) { + params.tags?.push({ + operator, + value: filter.value, + }); + } + break; + } + case "startTime": case "endTime": { const numValue = diff --git a/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/table/query-logs.schema.ts b/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/table/query-logs.schema.ts index fb725602f3..0e82ca07e4 100644 --- a/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/table/query-logs.schema.ts +++ b/apps/dashboard/app/(app)/apis/[apiId]/_overview/components/table/query-logs.schema.ts @@ -58,6 +58,14 @@ export const keysQueryOverviewLogsPayload = z.object({ ) .optional() .nullable(), + tags: z + .array( + z.object({ + operator: z.enum(["is", "contains", "startsWith", "endsWith"]), + value: z.string(), + }), + ) + .nullable(), }); export type KeysQueryOverviewLogsPayload = z.infer; diff --git a/apps/dashboard/app/(app)/apis/[apiId]/_overview/filters.schema.ts b/apps/dashboard/app/(app)/apis/[apiId]/_overview/filters.schema.ts index ddec9f3eb6..e570a04b66 100644 --- a/apps/dashboard/app/(app)/apis/[apiId]/_overview/filters.schema.ts +++ b/apps/dashboard/app/(app)/apis/[apiId]/_overview/filters.schema.ts @@ -33,6 +33,10 @@ export const keysOverviewFilterFieldConfig: FilterFieldConfigs = { type: "string", operators: ["is", "contains", "startsWith", "endsWith"], }, + tags: { + type: "string", + operators: ["is", "contains", "startsWith", "endsWith"], + }, outcomes: { type: "string", operators: ["is"], @@ -52,6 +56,7 @@ export const keysOverviewFilterFieldEnum = z.enum([ "names", "outcomes", "identities", + "tags", ]); export const filterOutputSchema = createFilterOutputSchema( @@ -73,6 +78,7 @@ export type FilterFieldConfigs = { names: StringConfig; outcomes: StringConfig; identities: StringConfig; + tags: 
StringConfig; }; export type IsOnlyUrlValue = { @@ -108,4 +114,5 @@ export type KeysQuerySearchParams = { names: AllOperatorsUrlValue[] | null; outcomes: IsOnlyUrlValue[] | null; identities: AllOperatorsUrlValue[] | null; + tags: AllOperatorsUrlValue[] | null; }; diff --git a/apps/dashboard/app/(app)/apis/[apiId]/_overview/hooks/use-filters.ts b/apps/dashboard/app/(app)/apis/[apiId]/_overview/hooks/use-filters.ts index 0b29536509..1faf756125 100644 --- a/apps/dashboard/app/(app)/apis/[apiId]/_overview/hooks/use-filters.ts +++ b/apps/dashboard/app/(app)/apis/[apiId]/_overview/hooks/use-filters.ts @@ -35,6 +35,7 @@ export const queryParamsPayload = { since: parseAsRelativeTime, outcomes: parseAsIsOnlyFilterArray, identities: parseAsAllOperatorsFilterArray, + tags: parseAsAllOperatorsFilterArray, } as const; export const useFilters = () => { @@ -46,7 +47,10 @@ export const useFilters = () => { const activeFilters: KeysOverviewFilterValue[] = []; for (const [field, value] of Object.entries(searchParams)) { - if (!Array.isArray(value) || !["keyIds", "names", "identities", "outcomes"].includes(field)) { + if ( + !Array.isArray(value) || + !["keyIds", "names", "identities", "outcomes", "tags"].includes(field) + ) { continue; } @@ -98,12 +102,14 @@ export const useFilters = () => { names: null, identities: null, outcomes: null, + tags: null, }; const keyIdFilters: IsContainsUrlValue[] = []; const nameFilters: AllOperatorsUrlValue[] = []; const identitiesFilters: AllOperatorsUrlValue[] = []; const outcomeFilters: IsOnlyUrlValue[] = []; + const tagFilters: AllOperatorsUrlValue[] = []; newFilters.forEach((filter) => { const fieldConfig = keysOverviewFilterFieldConfig[filter.field]; @@ -146,6 +152,15 @@ export const useFilters = () => { } break; + case "tags": + if (typeof filter.value === "string") { + tagFilters.push({ + value: filter.value, + operator: operator as "is" | "contains" | "startsWith" | "endsWith", + }); + } + break; + case "startTime": case "endTime": { const 
numValue = @@ -173,6 +188,7 @@ export const useFilters = () => { newParams.names = nameFilters.length > 0 ? nameFilters : null; newParams.identities = identitiesFilters.length > 0 ? identitiesFilters : null; newParams.outcomes = outcomeFilters.length > 0 ? outcomeFilters : null; + newParams.tags = tagFilters.length > 0 ? tagFilters : null; setSearchParams(newParams); }, diff --git a/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/components/charts/bar-chart/hooks/use-fetch-timeseries.ts b/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/components/charts/bar-chart/hooks/use-fetch-timeseries.ts index 69f669bb7a..5ce4c4fe82 100644 --- a/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/components/charts/bar-chart/hooks/use-fetch-timeseries.ts +++ b/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/components/charts/bar-chart/hooks/use-fetch-timeseries.ts @@ -17,6 +17,7 @@ export const useFetchVerificationTimeseries = (keyId: string, keyspaceId: string startTime: timestamp - HISTORICAL_DATA_WINDOW, endTime: timestamp, outcomes: { filters: [] }, + tags: null, since: "", keyId, keyspaceId, @@ -28,6 +29,23 @@ export const useFetchVerificationTimeseries = (keyId: string, keyspaceId: string } switch (filter.field) { + case "tags": { + if (typeof filter.value === "string" && filter.value.trim()) { + const fieldConfig = keyDetailsFilterFieldConfig[filter.field]; + const validOperators = fieldConfig.operators; + + const operator = validOperators.includes(filter.operator) + ? 
filter.operator + : validOperators[0]; + + params.tags = { + operator, + value: filter.value, + }; + } + break; + } + case "startTime": case "endTime": { const numValue = diff --git a/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/components/charts/bar-chart/query-timeseries.schema.ts b/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/components/charts/bar-chart/query-timeseries.schema.ts index 5693528947..3fde367017 100644 --- a/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/components/charts/bar-chart/query-timeseries.schema.ts +++ b/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/components/charts/bar-chart/query-timeseries.schema.ts @@ -1,5 +1,6 @@ import { KEY_VERIFICATION_OUTCOMES } from "@unkey/clickhouse/src/keys/keys"; import { z } from "zod"; +import { TAG_OPERATORS } from "../../../filters.schema"; export const MAX_KEYID_COUNT = 1; export const keyDetailsQueryTimeseriesPayload = z.object({ @@ -8,6 +9,12 @@ export const keyDetailsQueryTimeseriesPayload = z.object({ since: z.string(), keyspaceId: z.string(), keyId: z.string(), + tags: z + .object({ + operator: z.enum(TAG_OPERATORS), + value: z.string(), + }) + .nullable(), outcomes: z .object({ filters: z.array( diff --git a/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/components/controls/components/logs-filters/index.tsx b/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/components/controls/components/logs-filters/index.tsx index 8e2fc04b75..a1453e1469 100644 --- a/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/components/controls/components/logs-filters/index.tsx +++ b/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/components/controls/components/logs-filters/index.tsx @@ -1,21 +1,55 @@ import { FiltersPopover } from "@/components/logs/checkbox/filters-popover"; +import { FilterOperatorInput } from "@/components/logs/filter-operator-input"; import { BarsFilter } from "@unkey/icons"; 
import { Button } from "@unkey/ui"; import { cn } from "@unkey/ui/src/lib/utils"; import { useState } from "react"; +import { keyDetailsFilterFieldConfig } from "../../../../filters.schema"; import { useFilters } from "../../../../hooks/use-filters"; import { OutcomesFilter } from "./outcome-filter"; export const LogsFilters = () => { - const { filters } = useFilters(); + const { filters, updateFilters } = useFilters(); const [open, setOpen] = useState(false); + const activeTagsFilter = filters.find((f) => f.field === "tags"); + const tagsOptions = keyDetailsFilterFieldConfig.tags.operators.map((op) => ({ + id: op, + label: op, + })); + return ( { + const activeFiltersWithoutTags = filters.filter((f) => f.field !== "tags"); + updateFilters([ + ...activeFiltersWithoutTags, + { + field: "tags", + id: crypto.randomUUID(), + operator: id, + value: text, + }, + ]); + setOpen(false); + }} + /> + ), + }, { id: "outcomes", label: "Outcomes", diff --git a/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/components/table/hooks/use-logs-query.ts b/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/components/table/hooks/use-logs-query.ts index 8199b24d0d..72a6a629c0 100644 --- a/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/components/table/hooks/use-logs-query.ts +++ b/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/components/table/hooks/use-logs-query.ts @@ -65,6 +65,7 @@ export function useKeyDetailsLogsQuery({ startTime: timestamp - HISTORICAL_DATA_WINDOW, endTime: timestamp, outcomes: [], + tags: [], since: "", }; @@ -76,6 +77,15 @@ export function useKeyDetailsLogsQuery({ } switch (filter.field) { + case "tags": { + if (typeof filter.value === "string") { + params.tags?.push({ + value: filter.value, + operator: filter.operator as "is" | "contains" | "startsWith" | "endsWith", + }); + } + break; + } case "outcomes": { type ValidOutcome = (typeof KEY_VERIFICATION_OUTCOMES)[number]; if ( diff --git 
a/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/components/table/query-logs.schema.ts b/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/components/table/query-logs.schema.ts index 1d84ba3a40..94cd07cf3d 100644 --- a/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/components/table/query-logs.schema.ts +++ b/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/components/table/query-logs.schema.ts @@ -9,6 +9,15 @@ export const keyDetailsLogsPayload = z.object({ keyId: z.string(), since: z.string(), cursor: z.number().nullable().optional().nullable(), + tags: z + .array( + z.object({ + value: z.string(), + operator: z.enum(["is", "contains", "startsWith", "endsWith"]), + }), + ) + .optional() + .nullable(), outcomes: z .array( z.object({ diff --git a/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/filters.schema.ts b/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/filters.schema.ts index fee38c74a3..f5d74ce923 100644 --- a/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/filters.schema.ts +++ b/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/filters.schema.ts @@ -9,6 +9,7 @@ import { z } from "zod"; import { getOutcomeColor } from "../../../_overview/utils"; export const ALLOWED_OPERATOR = ["is"] as const; +export const TAG_OPERATORS = ["is", "contains", "startsWith", "endsWith"] as const; export type KeyDetailsFilterOperator = z.infer; export const keyDetailsFilterFieldConfig: FilterFieldConfigs = { @@ -24,6 +25,10 @@ export const keyDetailsFilterFieldConfig: FilterFieldConfigs = { type: "string", operators: ALLOWED_OPERATOR, }, + tags: { + type: "string", + operators: TAG_OPERATORS, + }, outcomes: { type: "string", operators: ALLOWED_OPERATOR, @@ -32,8 +37,14 @@ export const keyDetailsFilterFieldConfig: FilterFieldConfigs = { } as const, }; -export const keyDetailsFilterOperatorEnum = z.enum(ALLOWED_OPERATOR); -export const keyDetailsFilterFieldEnum = 
z.enum(["startTime", "endTime", "since", "outcomes"]); +export const keyDetailsFilterOperatorEnum = z.enum([...ALLOWED_OPERATOR, ...TAG_OPERATORS]); +export const keyDetailsFilterFieldEnum = z.enum([ + "startTime", + "endTime", + "since", + "tags", + "outcomes", +]); export const filterOutputSchema = createFilterOutputSchema( keyDetailsFilterFieldEnum, @@ -47,12 +58,13 @@ export type FilterFieldConfigs = { startTime: NumberConfig; endTime: NumberConfig; since: StringConfig; + tags: StringConfig; outcomes: StringConfig; }; export type IsOnlyUrlValue = { value: string | number; - operator: KeyDetailsFilterOperator; + operator: "is"; }; export type KeyDetailsFilterUrlValue = Pick< @@ -62,9 +74,15 @@ export type KeyDetailsFilterUrlValue = Pick< export type KeyDetailsFilterValue = FilterValue; +export type AllOperatorsUrlValue = { + value: string; + operator: "is" | "contains" | "startsWith" | "endsWith"; +}; + export type KeysQuerySearchParams = { startTime?: number | null; endTime?: number | null; since?: string | null; + tags: AllOperatorsUrlValue[] | null; outcomes: IsOnlyUrlValue[] | null; }; diff --git a/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/hooks/use-filters.ts b/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/hooks/use-filters.ts index bd9b241b61..988cc9466c 100644 --- a/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/hooks/use-filters.ts +++ b/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/[keyId]/hooks/use-filters.ts @@ -5,6 +5,7 @@ import { import { parseAsInteger, useQueryStates } from "nuqs"; import { useCallback, useMemo } from "react"; import { + type AllOperatorsUrlValue, type IsOnlyUrlValue, type KeyDetailsFilterField, type KeyDetailsFilterValue, @@ -13,11 +14,15 @@ import { } from "../filters.schema"; const parseAsIsOnlyFilterArray = parseAsFilterValueArray<"is">(["is"]); +const parseAsAllOperatorsFilterArray = parseAsFilterValueArray< + "is" | "contains" | "startsWith" | "endsWith" 
+>(["is", "contains", "startsWith", "endsWith"]); export const queryParamsPayload = { startTime: parseAsInteger, endTime: parseAsInteger, since: parseAsRelativeTime, + tags: parseAsAllOperatorsFilterArray, outcomes: parseAsIsOnlyFilterArray, } as const; @@ -30,7 +35,7 @@ export const useFilters = () => { const activeFilters: KeyDetailsFilterValue[] = []; for (const [field, value] of Object.entries(searchParams)) { - if (!Array.isArray(value) || field !== "outcomes") { + if (!Array.isArray(value) || (field !== "outcomes" && field !== "tags")) { continue; } @@ -78,10 +83,12 @@ export const useFilters = () => { startTime: null, endTime: null, since: null, + tags: null, outcomes: null, }; const outcomeFilters: IsOnlyUrlValue[] = []; + const tagFilters: AllOperatorsUrlValue[] = []; newFilters.forEach((filter) => { const fieldConfig = keyDetailsFilterFieldConfig[filter.field]; @@ -90,12 +97,29 @@ export const useFilters = () => { const operator = validOperators.includes(filter.operator) ? filter.operator : validOperators[0]; - if (operator !== "is") { - throw new Error("Invalid filter operator. Only 'is' operator is allowed."); - } - switch (filter.field) { + case "tags": + if (!validOperators.includes(filter.operator)) { + throw new Error( + `Invalid filter operator for tags. Allowed operators are: ${validOperators.join( + ", ", + )}`, + ); + } + if (typeof filter.value === "string") { + tagFilters.push({ + value: filter.value, + operator: filter.operator as "is" | "contains" | "startsWith" | "endsWith", + }); + } + break; + case "outcomes": + if (operator !== "is") { + throw new Error( + "Invalid filter operator for outcomes. Only 'is' operator is allowed.", + ); + } if (typeof filter.value === "string") { outcomeFilters.push({ value: filter.value, @@ -106,6 +130,11 @@ export const useFilters = () => { case "startTime": case "endTime": { + if (operator !== "is") { + throw new Error( + "Invalid filter operator for time fields. 
Only 'is' operator is allowed.", + ); + } const numValue = typeof filter.value === "number" ? filter.value @@ -120,6 +149,11 @@ export const useFilters = () => { } case "since": + if (operator !== "is") { + throw new Error( + "Invalid filter operator for since field. Only 'is' operator is allowed.", + ); + } if (typeof filter.value === "string") { newParams.since = filter.value; } @@ -127,6 +161,7 @@ export const useFilters = () => { } }); + newParams.tags = tagFilters.length > 0 ? tagFilters : null; newParams.outcomes = outcomeFilters.length > 0 ? outcomeFilters : null; setSearchParams(newParams); diff --git a/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/_components/components/controls/components/logs-filters/index.tsx b/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/_components/components/controls/components/logs-filters/index.tsx index f495aa0aa4..5db05bce38 100644 --- a/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/_components/components/controls/components/logs-filters/index.tsx +++ b/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/_components/components/controls/components/logs-filters/index.tsx @@ -17,10 +17,15 @@ export const LogsFilters = () => { const activeNameFilter = filters.find((f) => f.field === "names"); const activeIdentityFilter = filters.find((f) => f.field === "identities"); const activeKeyIdsFilter = filters.find((f) => f.field === "keyIds"); + const activeTagsFilter = filters.find((f) => f.field === "tags"); const keyIdOptions = keysListFilterFieldConfig.names.operators.map((op) => ({ id: op, label: op, })); + const tagsOptions = keysListFilterFieldConfig.tags.operators.map((op) => ({ + id: op, + label: op, + })); return ( { /> ), }, + { + id: "tags", + label: "Tags", + shortcut: "t", + component: ( + { + const activeFiltersWithoutTags = filters.filter((f) => f.field !== "tags"); + updateFilters([ + ...activeFiltersWithoutTags, + { + field: "tags", + id: crypto.randomUUID(), + operator: id, + value: text, + 
}, + ]); + }} + /> + ), + }, ]} activeFilters={filters} > diff --git a/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/_components/components/table/hooks/use-keys-list-query.ts b/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/_components/components/table/hooks/use-keys-list-query.ts index 19835eeefb..148d17be4a 100644 --- a/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/_components/components/table/hooks/use-keys-list-query.ts +++ b/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/_components/components/table/hooks/use-keys-list-query.ts @@ -43,7 +43,6 @@ export function useKeysListQuery({ keyAuthId }: UseKeysListQueryParams) { }); } }); - return params; }, [filters, keyAuthId]); diff --git a/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/_components/components/table/query-logs.schema.ts b/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/_components/components/table/query-logs.schema.ts index 68f8f7f185..a83252263f 100644 --- a/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/_components/components/table/query-logs.schema.ts +++ b/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/_components/components/table/query-logs.schema.ts @@ -13,6 +13,7 @@ const baseKeysSchema = z.object({ names: baseFilterArraySchema, identities: baseFilterArraySchema, keyIds: baseFilterArraySchema, + tags: baseFilterArraySchema, }); export const keysQueryListPayload = baseKeysSchema.extend({ diff --git a/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/_components/filters.schema.ts b/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/_components/filters.schema.ts index dfa1a5eeac..cb69bce601 100644 --- a/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/_components/filters.schema.ts +++ b/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/_components/filters.schema.ts @@ -14,6 +14,7 @@ export type FilterFieldConfigs = { keyIds: StringConfig; names: StringConfig; identities: StringConfig; + tags: StringConfig; }; 
export const keysListFilterFieldConfig: FilterFieldConfigs = { @@ -29,6 +30,10 @@ export const keysListFilterFieldConfig: FilterFieldConfigs = { type: "string", operators: [...commonStringOperators], }, + tags: { + type: "string", + operators: [...commonStringOperators], + }, }; const allFilterFieldNames = Object.keys(keysListFilterFieldConfig) as (keyof FilterFieldConfigs)[]; diff --git a/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/_components/hooks/use-filters.ts b/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/_components/hooks/use-filters.ts index 8e0b2b4207..af079b1866 100644 --- a/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/_components/hooks/use-filters.ts +++ b/apps/dashboard/app/(app)/apis/[apiId]/keys/[keyAuthId]/_components/hooks/use-filters.ts @@ -82,7 +82,6 @@ export const useFilters = () => { newParams[field] = fieldFilters; } }); - setSearchParams(newParams); }, [setSearchParams], diff --git a/apps/dashboard/lib/trpc/routers/api/keys/query-api-keys/get-all-keys.ts b/apps/dashboard/lib/trpc/routers/api/keys/query-api-keys/get-all-keys.ts index 30dd3cde46..a21ed5d3a3 100644 --- a/apps/dashboard/lib/trpc/routers/api/keys/query-api-keys/get-all-keys.ts +++ b/apps/dashboard/lib/trpc/routers/api/keys/query-api-keys/get-all-keys.ts @@ -1,7 +1,9 @@ import type { AllOperatorsUrlValue } from "@/app/(app)/apis/[apiId]/_overview/filters.schema"; +import { clickhouse } from "@/lib/clickhouse"; import { type SQL, db, like, or } from "@/lib/db"; import { TRPCError } from "@trpc/server"; import { identities } from "@unkey/db/src/schema"; +import { z } from "zod"; import type { KeyDetails } from "./schema"; interface GetAllKeysInput { @@ -11,6 +13,7 @@ interface GetAllKeysInput { keyIds?: AllOperatorsUrlValue[] | null; names?: AllOperatorsUrlValue[] | null; identities?: AllOperatorsUrlValue[] | null; + tags?: AllOperatorsUrlValue[] | null; }; limit?: number; cursorKeyId?: string | null; @@ -35,7 +38,8 @@ export async function getAllKeys({ 
limit = 50, cursorKeyId = null, }: GetAllKeysInput): Promise { - const { keyIds, names, identities: identityFilters } = filters; + const { keyIds, names, identities: identityFilters, tags } = filters; + try { // Security verification - ensure the keyspaceId belongs to the workspaceId const keyAuth = await db.query.keyAuth.findFirst({ @@ -56,11 +60,101 @@ export async function getAllKeys({ }); } + // Get keys that match tag filters if provided + let tagFilteredKeyIds: string[] | null = null; + + if (tags && tags.length > 0) { + try { + // Build tag filter conditions with proper parameterization + const tagQueries = tags.map((tag, index) => { + const paramKey = `tagValue${index}`; + let condition: string; + + switch (tag.operator) { + case "is": + condition = `has(tags, {${paramKey}: String})`; + break; + case "contains": + condition = `arrayExists(x -> position(x, {${paramKey}: String}) > 0, tags)`; + break; + case "startsWith": + condition = `arrayExists(x -> startsWith(x, {${paramKey}: String}), tags)`; + break; + case "endsWith": + condition = `arrayExists(x -> endsWith(x, {${paramKey}: String}), tags)`; + break; + default: + condition = `has(tags, {${paramKey}: String})`; + } + + return { condition, paramKey, value: tag.value }; + }); + + // Build the params schema dynamically + const paramsObj: Record = { + workspaceId: z.string(), + keyspaceId: z.string(), + }; + tagQueries.forEach(({ paramKey }) => { + paramsObj[paramKey] = z.string(); + }); + + // Build the query parameters + const queryParams: Record = { + workspaceId, + keyspaceId, + }; + tagQueries.forEach(({ paramKey, value }) => { + queryParams[paramKey] = value; + }); + + const tagQuery = clickhouse.querier.query({ + query: ` + SELECT DISTINCT key_id + FROM verifications.raw_key_verifications_v1 + WHERE workspace_id = {workspaceId: String} + AND key_space_id = {keyspaceId: String} + AND (${tagQueries.map(({ condition }) => condition).join(" OR ")}) + `, + params: z.object(paramsObj), + schema: 
z.object({ + key_id: z.string(), + }), + }); + + const result = await tagQuery(queryParams); + + if (result.err) { + console.error("ClickHouse query error:", result.err); + tagFilteredKeyIds = []; + } else { + tagFilteredKeyIds = result.val.map((row) => row.key_id); + } + } catch (error) { + console.error("Error querying tags from ClickHouse:", error); + tagFilteredKeyIds = []; + } + } + // Helper function to build the filter conditions (without cursor) // biome-ignore lint/suspicious/noExplicitAny: Leave it as is for now const buildFilterConditions = (key: any, { and, isNull, eq, sql }: any) => { const conditions = [eq(key.keyAuthId, keyspaceId), isNull(key.deletedAtM)]; + // Apply tag-based key filtering if we have filtered key IDs + if (tagFilteredKeyIds !== null) { + if (tagFilteredKeyIds.length === 0) { + conditions.push(sql`1 = 0`); + } else { + conditions.push( + sql`${key.id} IN (${sql.join( + tagFilteredKeyIds.map((id) => sql`${id}`), + sql`, `, + )})`, + ); + } + } + // Apply name filters if (names && names.length > 0) { const nameConditions = []; @@ -188,8 +282,6 @@ export async function getAllKeys({ }); const totalCount = countQuery.length; - - // Get the paginated keys with filters and cursor const keysQuery = await db.query.keys.findMany({ where: (key, helpers) => { const { and, lt } = helpers; diff --git a/apps/dashboard/lib/trpc/routers/api/keys/query-api-keys/index.ts b/apps/dashboard/lib/trpc/routers/api/keys/query-api-keys/index.ts index 2686da9104..799d39b10e 100644 --- a/apps/dashboard/lib/trpc/routers/api/keys/query-api-keys/index.ts +++ b/apps/dashboard/lib/trpc/routers/api/keys/query-api-keys/index.ts @@ -27,6 +27,7 @@ export const queryKeysList = t.procedure keyIds: input.keyIds, names: input.names, identities: input.identities, + tags: input.tags, }, limit: input.limit, cursorKeyId: input.cursor ?? 
null, diff --git a/apps/dashboard/lib/trpc/routers/api/keys/query-key-usage-timeseries/index.ts b/apps/dashboard/lib/trpc/routers/api/keys/query-key-usage-timeseries/index.ts index cb713821ea..c2a47aac9b 100644 --- a/apps/dashboard/lib/trpc/routers/api/keys/query-key-usage-timeseries/index.ts +++ b/apps/dashboard/lib/trpc/routers/api/keys/query-key-usage-timeseries/index.ts @@ -19,6 +19,7 @@ export const keyUsageTimeseries = t.procedure identities: null, keyIds: null, names: null, + tags: null, outcomes: null, }); diff --git a/apps/dashboard/lib/trpc/routers/api/keys/query-overview-logs/index.ts b/apps/dashboard/lib/trpc/routers/api/keys/query-overview-logs/index.ts index 44269abfe7..282f7a7928 100644 --- a/apps/dashboard/lib/trpc/routers/api/keys/query-overview-logs/index.ts +++ b/apps/dashboard/lib/trpc/routers/api/keys/query-overview-logs/index.ts @@ -46,6 +46,8 @@ export const queryKeysOverviewLogs = t.procedure keyspaceId: keyspaceId, // Only include keyIds filters if explicitly provided in the input keyIds: input.keyIds ? transformedInputs.keyIds : null, + // Pass tags to ClickHouse for filtering + tags: transformedInputs.tags, // Nullify these as we'll filter in the database names: null, identities: null, @@ -89,7 +91,6 @@ export const queryKeysOverviewLogs = t.procedure ...log, key_details: keyDetailsMap.get(log.key_id) || null, })); - const response: KeysOverviewLogsResponse = { keysOverviewLogs, hasMore: logs.length === input.limit && keysOverviewLogs.length > 0, diff --git a/apps/dashboard/lib/trpc/routers/api/keys/query-overview-logs/utils.ts b/apps/dashboard/lib/trpc/routers/api/keys/query-overview-logs/utils.ts index c0122f951d..0e30e0b5a1 100644 --- a/apps/dashboard/lib/trpc/routers/api/keys/query-overview-logs/utils.ts +++ b/apps/dashboard/lib/trpc/routers/api/keys/query-overview-logs/utils.ts @@ -31,6 +31,12 @@ export function transformKeysFilters( value: k.value, })) ?? 
null; + const tags = + params.tags?.map((k) => ({ + operator: k.operator, + value: k.value, + })) ?? null; + const outcomes = params.outcomes?.map((o) => ({ operator: "is" as const, @@ -50,6 +56,7 @@ export function transformKeysFilters( keyIds, names, identities, + tags, outcomes, cursorTime: params.cursor ?? null, sorts, diff --git a/apps/dashboard/lib/trpc/routers/api/keys/timeseries.utils.ts b/apps/dashboard/lib/trpc/routers/api/keys/timeseries.utils.ts index fd85932d41..a7447ac4a0 100644 --- a/apps/dashboard/lib/trpc/routers/api/keys/timeseries.utils.ts +++ b/apps/dashboard/lib/trpc/routers/api/keys/timeseries.utils.ts @@ -45,6 +45,14 @@ export function transformVerificationFilters(params: KeysOverviewQueryTimeseries operator: f.operator, value: f.value, })) || null, + tags: params.tags + ? [ + { + operator: params.tags.operator, + value: params.tags.value, + }, + ] + : null, }, granularity: timeConfig.granularity, }; diff --git a/apps/dashboard/lib/trpc/routers/api/overview/query-timeseries/utils.ts b/apps/dashboard/lib/trpc/routers/api/overview/query-timeseries/utils.ts index d1b2dc8767..4c294e6752 100644 --- a/apps/dashboard/lib/trpc/routers/api/overview/query-timeseries/utils.ts +++ b/apps/dashboard/lib/trpc/routers/api/overview/query-timeseries/utils.ts @@ -28,6 +28,7 @@ export function transformVerificationFilters(params: VerificationQueryTimeseries startTime: timeConfig.startTime, keyIds: [], names: [], + tags: null, outcomes: [], endTime: timeConfig.endTime, }, diff --git a/apps/dashboard/lib/trpc/routers/key/query-logs/utils.ts b/apps/dashboard/lib/trpc/routers/key/query-logs/utils.ts index 201d458e68..e805e72432 100644 --- a/apps/dashboard/lib/trpc/routers/key/query-logs/utils.ts +++ b/apps/dashboard/lib/trpc/routers/key/query-logs/utils.ts @@ -20,6 +20,12 @@ export function transformKeyDetailsFilters( value: o.value, })) ?? null; + const tags = + params.tags?.map((t) => ({ + operator: t.operator, + value: t.value, + })) ?? 
null; + return { workspaceId, keyId: params.keyId, @@ -29,5 +35,6 @@ export function transformKeyDetailsFilters( endTime, cursorTime: params.cursor ?? null, outcomes, + tags, }; } diff --git a/internal/clickhouse/src/keys/keys.ts b/internal/clickhouse/src/keys/keys.ts index 3d902686f5..9a2a5d4604 100644 --- a/internal/clickhouse/src/keys/keys.ts +++ b/internal/clickhouse/src/keys/keys.ts @@ -49,6 +49,14 @@ export const keysOverviewLogsParams = z.object({ }), ) .nullable(), + tags: z + .array( + z.object({ + operator: z.enum(["is", "contains", "startsWith", "endsWith"]), + value: z.string(), + }), + ) + .nullable(), cursorTime: z.number().int().nullable(), sorts: z .array( @@ -104,6 +112,7 @@ export const rawKeysOverviewLogs = z.object({ valid_count: z.number().int(), error_count: z.number().int(), outcome_counts: z.record(z.string(), z.number().int()), + tags: z.array(z.string()).optional(), }); export const keysOverviewLogs = rawKeysOverviewLogs.extend({ @@ -125,6 +134,7 @@ export function getKeysOverviewLogs(ch: Querier) { const hasKeyIdFilters = args.keyIds && args.keyIds.length > 0; const hasOutcomeFilters = args.outcomes && args.outcomes.length > 0; + const hasTagFilters = args.tags && args.tags.length > 0; const hasSortingRules = args.sorts && args.sorts.length > 0; const outcomeCondition = hasOutcomeFilters @@ -161,6 +171,29 @@ export function getKeysOverviewLogs(ch: Querier) { .join(" OR ") || "TRUE" : "TRUE"; + const tagConditions = hasTagFilters + ? 
args.tags + ?.map((filter, index) => { + const paramName = `tagValue_${index}`; + paramSchemaExtension[paramName] = z.string(); + parameters[paramName] = filter.value; + switch (filter.operator) { + case "is": + return `has(tags, {${paramName}: String})`; + case "contains": + return `arrayExists(x -> like(x, CONCAT('%', {${paramName}: String}, '%')), tags)`; + case "startsWith": + return `arrayExists(x -> like(x, CONCAT({${paramName}: String}, '%')), tags)`; + case "endsWith": + return `arrayExists(x -> like(x, CONCAT('%', {${paramName}: String})), tags)`; + default: + return null; + } + }) + .filter(Boolean) + .join(" OR ") || "TRUE" + : "TRUE"; + const allowedColumns = new Map([ ["time", "time"], ["valid", "valid_count"], @@ -238,6 +271,7 @@ WITH request_id, time, key_id, + tags, outcome FROM verifications.raw_key_verifications_v1 WHERE workspace_id = {workspaceId: String} @@ -247,6 +281,8 @@ WITH AND (${keyIdConditions}) -- Apply dynamic outcome filtering AND (${outcomeCondition}) + -- Apply dynamic tag filtering + AND (${tagConditions}) -- Handle pagination using only time as cursor ${cursorCondition} ), @@ -259,6 +295,8 @@ WITH max(time) as last_request_time, -- Get the request_id of the latest verification (based on time) argMax(request_id, time) as last_request_id, + -- Get the tags from the latest verification (based on time) + argMax(tags, time) as tags, -- Count valid verifications countIf(outcome = 'VALID') as valid_count, -- Count all non-valid verifications @@ -282,6 +320,7 @@ WITH a.key_id, a.last_request_time as time, a.last_request_id as request_id, + a.tags, a.valid_count, a.error_count, -- Create an array of tuples containing all outcomes and their counts @@ -294,6 +333,7 @@ WITH a.key_id, a.last_request_time, a.last_request_id, + a.tags, a.valid_count, a.error_count -- Sort results with most recent verification first @@ -329,6 +369,7 @@ WITH key_id: result.key_id, time: result.time, request_id: result.request_id, + tags: result.tags, valid_count: 
result.valid_count, error_count: result.error_count, outcome_counts: outcomeCountsObj, diff --git a/internal/clickhouse/src/verification_tags.test.ts b/internal/clickhouse/src/verification_tags.test.ts index 70dfaa08ec..3fa77fdf30 100644 --- a/internal/clickhouse/src/verification_tags.test.ts +++ b/internal/clickhouse/src/verification_tags.test.ts @@ -128,6 +128,7 @@ describe("materialized views", () => { identities: null, // Required parameter keyIds: null, // Required parameter outcomes: null, // Required parameter + tags: null, // Required parameter }); // Calculate total verification count from all data points diff --git a/internal/clickhouse/src/verifications.ts b/internal/clickhouse/src/verifications.ts index a4f90b6974..2759dc4d34 100644 --- a/internal/clickhouse/src/verifications.ts +++ b/internal/clickhouse/src/verifications.ts @@ -36,6 +36,14 @@ export const keyDetailsLogsParams = z.object({ limit: z.number().int(), startTime: z.number().int(), endTime: z.number().int(), + tags: z + .array( + z.object({ + value: z.string(), + operator: z.enum(["is", "contains", "startsWith", "endsWith"]), + }), + ) + .nullable(), outcomes: z .array( z.object({ @@ -67,8 +75,33 @@ export function getKeyDetailsLogs(ch: Querier) { const paramSchemaExtension: Record = {}; const parameters: ExtendedParamsKeyDetails = { ...args }; + const hasTagFilters = args.tags && args.tags.length > 0; const hasOutcomeFilters = args.outcomes && args.outcomes.length > 0; + const tagCondition = hasTagFilters + ? 
args.tags + ?.map((filter, index) => { + const paramName = `tagValue_${index}`; + paramSchemaExtension[paramName] = z.string(); + parameters[paramName] = filter.value; + + switch (filter.operator) { + case "is": + return `has(tags, {${paramName}: String})`; + case "contains": + return `arrayExists(tag -> position(tag, {${paramName}: String}) > 0, tags)`; + case "startsWith": + return `arrayExists(tag -> startsWith(tag, {${paramName}: String}), tags)`; + case "endsWith": + return `arrayExists(tag -> endsWith(tag, {${paramName}: String}), tags)`; + default: + return null; + } + }) + .filter(Boolean) + .join(" AND ") || "TRUE" + : "TRUE"; + const outcomeCondition = hasOutcomeFilters ? args.outcomes ?.map((filter, index) => { @@ -104,6 +137,7 @@ export function getKeyDetailsLogs(ch: Querier) { AND key_space_id = {keyspaceId: String} AND key_id = {keyId: String} AND time BETWEEN {startTime: UInt64} AND {endTime: UInt64} + AND (${tagCondition}) AND (${outcomeCondition}) `; @@ -182,6 +216,14 @@ export const verificationTimeseriesParams = z.object({ }), ) .nullable(), + tags: z + .array( + z.object({ + operator: z.enum(["is", "contains", "startsWith", "endsWith"]), + value: z.string(), + }), + ) + .nullable(), outcomes: z .array( z.object({ @@ -328,7 +370,7 @@ function createVerificationTimeseriesQuery(interval: TimeInterval, whereClause: 'forbidden_count', SUM(IF(outcome = 'FORBIDDEN', count, 0)), 'disabled_count', SUM(IF(outcome = 'DISABLED', count, 0)) , 'expired_count',SUM(IF(outcome = 'EXPIRED', count, 0)) , - 'usage_exceeded_count', SUM(IF(outcome = 'USAGE_EXCEEDED', count, 0)) + 'usage_exceeded_count', SUM(IF(outcome = 'USAGE_EXCEEDED', count, 0)) ) as y FROM ${interval.table} ${whereClause} @@ -400,6 +442,33 @@ function getVerificationTimeseriesWhereClause( } } + // Handle tags filter + if (params.tags && params.tags.length > 0) { + const tagConditions = params.tags + .map((filter, index) => { + const paramName = `tagValue_${index}`; + 
paramSchemaExtension[paramName] = z.string(); + + switch (filter.operator) { + case "is": + return `has(tags, {${paramName}: String})`; + case "contains": + return `arrayExists(tag -> position(tag, {${paramName}: String}) > 0, tags)`; + case "startsWith": + return `arrayExists(tag -> startsWith(tag, {${paramName}: String}), tags)`; + case "endsWith": + return `arrayExists(tag -> endsWith(tag, {${paramName}: String}), tags)`; + default: + return null; + } + }) + .filter(Boolean); + + if (tagConditions.length > 0) { + conditions.push(`(${tagConditions.join(" AND ")})`); + } + } + return { whereClause: conditions.length > 0 ? `WHERE ${conditions.join(" AND ")}` : "", paramSchema: verificationTimeseriesParams.extend(paramSchemaExtension), @@ -414,19 +483,28 @@ function createVerificationTimeseriesQuerier(interval: TimeInterval) { ]); // Create parameters object with filter values + const parameters = { ...args, ...(args.keyIds?.reduce( (acc, filter, index) => ({ - // biome-ignore lint/performance/noAccumulatingSpread: + // biome-ignore lint/performance/noAccumulatingSpread: We don't care about the spread syntax warning here ...acc, [`keyIdValue_${index}`]: filter.value, }), {}, ) ?? {}), + ...(args.tags?.reduce( + (acc, filter, index) => ({ + // biome-ignore lint/performance/noAccumulatingSpread: We don't care about the spread syntax warning here + ...acc, + [`tagValue_${index}`]: filter.value, + }), + {}, + ) ?? 
{}), ...(args.outcomes?.reduce( (acc, filter, index) => ({ - // biome-ignore lint/performance/noAccumulatingSpread: + // biome-ignore lint/performance/noAccumulatingSpread: We don't care about the spread syntax warning here ...acc, [`outcomeValue_${index}`]: filter.value, }), From c146fce9286c21eab519b326634765b8330296c2 Mon Sep 17 00:00:00 2001 From: chronark Date: Thu, 17 Jul 2025 13:46:36 +0200 Subject: [PATCH 3/6] ci: revert to when it actually worked --- .github/workflows/job_test_api_local.yaml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/job_test_api_local.yaml b/.github/workflows/job_test_api_local.yaml index ea7192951a..3b59dba63a 100644 --- a/.github/workflows/job_test_api_local.yaml +++ b/.github/workflows/job_test_api_local.yaml @@ -3,7 +3,6 @@ on: workflow_call: permissions: contents: read - jobs: test: name: API Test Local @@ -16,7 +15,7 @@ jobs: run: rm -rf /opt/hostedtoolcache - name: Run containers - run: docker compose -f ./deployment/docker-compose.yaml up mysql planetscale agent s3 chproxy api -d + run: docker compose -f ./deployment/docker-compose.yaml up -d - name: Install uses: ./.github/actions/install From 4bc9cdfe59584629c1449ea7aa6189d8a1e377c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?O=C4=9Fuzhan=20Olguncu?= <21091016+ogzhanolguncu@users.noreply.github.com> Date: Thu, 17 Jul 2025 14:48:06 +0300 Subject: [PATCH 4/6] feat: Unkey Deploy CLI (#3564) * feat: add commands * feat: allow configuring name,desc and version * feat: pass env to cli * feat: match the initial impl * feat: add new progress aniamtion * feat: add tracker step for each phase * refactor: improve animations and errors * feat: use proper orchestrafor managing steps and trackers * refactor: rename build to run * refactor: remove UI logic from api * chore: remove redundant commands * refactor: remove ui bloat * feat: add colors for make it distinguishable * fix: steps * fix: code rabbit issues * feat: add proper flag parsing logic * 
refactor: show help if required args are missing * feat: add missing commands * fix: code rabbit comments * refactor: fix redundancy * refactor: improve sub spinner * refactor: move duplicated spinner loop * refactor: remove some commands for later --- go/cmd/cli/cli/command.go | 107 +++++++ go/cmd/cli/cli/flag.go | 268 +++++++++++++++++ go/cmd/cli/cli/help.go | 155 ++++++++++ go/cmd/cli/cli/parser.go | 167 +++++++++++ go/cmd/cli/commands/deploy/build_docker.go | 84 ++++++ go/cmd/cli/commands/deploy/control_plane.go | 237 +++++++++++++++ go/cmd/cli/commands/deploy/deploy.go | 302 ++++++++++++++++++++ go/cmd/cli/commands/deploy/ui.go | 159 +++++++++++ go/cmd/cli/commands/init/init.go | 69 +++++ go/cmd/cli/commands/versions/versions.go | 164 +++++++++++ go/cmd/cli/main.go | 31 ++ 11 files changed, 1743 insertions(+) create mode 100644 go/cmd/cli/cli/command.go create mode 100644 go/cmd/cli/cli/flag.go create mode 100644 go/cmd/cli/cli/help.go create mode 100644 go/cmd/cli/cli/parser.go create mode 100644 go/cmd/cli/commands/deploy/build_docker.go create mode 100644 go/cmd/cli/commands/deploy/control_plane.go create mode 100644 go/cmd/cli/commands/deploy/deploy.go create mode 100644 go/cmd/cli/commands/deploy/ui.go create mode 100644 go/cmd/cli/commands/init/init.go create mode 100644 go/cmd/cli/commands/versions/versions.go create mode 100644 go/cmd/cli/main.go diff --git a/go/cmd/cli/cli/command.go b/go/cmd/cli/cli/command.go new file mode 100644 index 0000000000..cd9fdd4811 --- /dev/null +++ b/go/cmd/cli/cli/command.go @@ -0,0 +1,107 @@ +package cli + +import ( + "context" + "fmt" + "os" +) + +// Action represents a command handler function that receives context and the parsed command +type Action func(context.Context, *Command) error + +// Command represents a CLI command with its configuration and runtime state +type Command struct { + // Configuration + Name string // Command name (e.g., "deploy", "version") + Usage string // Short description shown in help + 
Description string // Longer description for detailed help + Version string // Version string (only used for root command) + Commands []*Command // Subcommands + Flags []Flag // Available flags for this command + Action Action // Function to execute when command is run + Aliases []string // Alternative names for this command + + // Runtime state (populated during parsing) + args []string // Non-flag arguments passed to command + flagMap map[string]Flag // Map for O(1) flag lookup + parent *Command // Parent command (for building usage paths) +} + +// Args returns the non-flag arguments passed to the command +// Example: "mycli deploy myapp" -> Args() returns ["myapp"] +func (c *Command) Args() []string { + return c.args +} + +// String returns the value of a string flag by name +// Returns empty string if flag doesn't exist or isn't a StringFlag +func (c *Command) String(name string) string { + if flag, ok := c.flagMap[name]; ok { + if sf, ok := flag.(*StringFlag); ok { + return sf.Value() + } + } + return "" +} + +// Bool returns the value of a boolean flag by name +// Returns false if flag doesn't exist or isn't a BoolFlag +func (c *Command) Bool(name string) bool { + if flag, ok := c.flagMap[name]; ok { + if bf, ok := flag.(*BoolFlag); ok { + return bf.Value() + } + } + return false +} + +// Int returns the value of an integer flag by name +// Returns 0 if flag doesn't exist or isn't an IntFlag +func (c *Command) Int(name string) int { + if flag, ok := c.flagMap[name]; ok { + if inf, ok := flag.(*IntFlag); ok { + return inf.Value() + } + } + return 0 +} + +// Float returns the value of a float flag by name +// Returns 0.0 if flag doesn't exist or isn't a FloatFlag +func (c *Command) Float(name string) float64 { + if flag, ok := c.flagMap[name]; ok { + if ff, ok := flag.(*FloatFlag); ok { + return ff.Value() + } + } + return 0.0 +} + +// StringSlice returns the value of a string slice flag by name +// Returns empty slice if flag doesn't exist or isn't a 
StringSliceFlag +func (c *Command) StringSlice(name string) []string { + if flag, ok := c.flagMap[name]; ok { + if ssf, ok := flag.(*StringSliceFlag); ok { + return ssf.Value() + } + } + return []string{} +} + +// Run executes the command with the given arguments (typically os.Args) +// This is the main entry point for CLI execution +func (c *Command) Run(ctx context.Context, args []string) error { + if len(args) == 0 { + return fmt.Errorf("no arguments provided") + } + // Parse arguments starting from index 1 (skip program name) + return c.parse(ctx, args[1:]) +} + +// Exit provides a clean way to exit with an error message and code +// This is a convenience function that prints the message and calls os.Exit +func Exit(message string, code int) error { + fmt.Println(message) + os.Exit(code) + return nil // unreachable but satisfies error interface +} diff --git a/go/cmd/cli/cli/flag.go b/go/cmd/cli/cli/flag.go new file mode 100644 index 0000000000..da1c6d0ef4 --- /dev/null +++ b/go/cmd/cli/cli/flag.go @@ -0,0 +1,268 @@ +package cli + +import ( + "fmt" + "os" + "strconv" + "strings" +) + +// Flag represents a command line flag interface +// All flag types must implement these methods +type Flag interface { + Name() string // The flag name (without dashes) + Usage() string // Help text describing the flag + Required() bool // Whether this flag is mandatory + Parse(value string) error // Parse string value into the flag's type + IsSet() bool // Whether the flag was explicitly set by user +} + +// baseFlag contains common fields and methods shared by all flag types +type baseFlag struct { + name string // Flag name + usage string // Help description + envVar string // Environment variable to check for default + required bool // Whether flag is mandatory + set bool // Whether user explicitly provided this flag +} + +// Name returns the flag name +func (b *baseFlag) Name() string { return b.name } + +// Usage returns the flag's help text +func (b *baseFlag) Usage() 
string { return b.usage } + +// Required returns whether this flag is mandatory +func (b *baseFlag) Required() bool { return b.required } + +// IsSet returns whether the user explicitly provided this flag +func (b *baseFlag) IsSet() bool { return b.set } + +// EnvVar returns the environment variable name for this flag +func (b *baseFlag) EnvVar() string { return b.envVar } + +// StringFlag represents a string command line flag +type StringFlag struct { + baseFlag + value string // Current value +} + +// Parse sets the flag value from a string +func (f *StringFlag) Parse(value string) error { + f.value = value + f.set = true + return nil +} + +// Value returns the current string value +func (f *StringFlag) Value() string { return f.value } + +// BoolFlag represents a boolean command line flag +type BoolFlag struct { + baseFlag + value bool // Current value +} + +// Parse sets the flag value from a string +// Empty string means the flag was provided without a value (--flag), which sets it to true +// Otherwise parses as boolean: "true", "false", "1", "0", etc. 
+func (f *BoolFlag) Parse(value string) error { + if value == "" { + f.value = true + f.set = true + return nil + } + parsed, err := strconv.ParseBool(value) + if err != nil { + return fmt.Errorf("invalid boolean value: %s", value) + } + f.value = parsed + f.set = true + return nil +} + +// Value returns the current boolean value +func (f *BoolFlag) Value() bool { return f.value } + +// IntFlag represents an integer command line flag +type IntFlag struct { + baseFlag + value int // Current value +} + +// Parse sets the flag value from a string +func (f *IntFlag) Parse(value string) error { + parsed, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("invalid integer value: %s", value) + } + f.value = parsed + f.set = true + return nil +} + +// Value returns the current integer value +func (f *IntFlag) Value() int { return f.value } + +// FloatFlag represents a float64 command line flag +type FloatFlag struct { + baseFlag + value float64 // Current value +} + +// Parse sets the flag value from a string +func (f *FloatFlag) Parse(value string) error { + parsed, err := strconv.ParseFloat(value, 64) + if err != nil { + return fmt.Errorf("invalid float value: %s", value) + } + f.value = parsed + f.set = true + return nil +} + +// Value returns the current float64 value +func (f *FloatFlag) Value() float64 { return f.value } + +// StringSliceFlag represents a string slice command line flag +type StringSliceFlag struct { + baseFlag + value []string // Current value +} + +// parseCommaSeparated splits a comma-separated string into a slice of trimmed non-empty strings +func (f *StringSliceFlag) parseCommaSeparated(value string) []string { + if value == "" { + return []string{} + } + parts := strings.Split(value, ",") + result := make([]string, 0, len(parts)) + for _, part := range parts { + trimmed := strings.TrimSpace(part) + if trimmed != "" { + result = append(result, trimmed) + } + } + return result +} + +// Parse sets the flag value from a string 
(comma-separated values) +func (f *StringSliceFlag) Parse(value string) error { + f.value = f.parseCommaSeparated(value) + f.set = true + return nil +} + +// Value returns the current string slice value +func (f *StringSliceFlag) Value() []string { return f.value } + +// String creates a new string flag with environment variable support +// If envVar is provided and set, it will be used as the default value +func String(name, usage, defaultValue, envVar string, required bool) *StringFlag { + flag := &StringFlag{ + baseFlag: baseFlag{ + name: name, + usage: usage, + envVar: envVar, + required: required, + }, + value: defaultValue, + } + // Check environment variable for default value + if envVar != "" { + if envValue := os.Getenv(envVar); envValue != "" { + flag.value = envValue + flag.set = true // Mark as set since env var was found + } + } + return flag +} + +// Bool creates a new boolean flag with environment variable support +func Bool(name, usage, envVar string, required bool) *BoolFlag { + flag := &BoolFlag{ + baseFlag: baseFlag{ + name: name, + usage: usage, + envVar: envVar, + required: required, + }, + } + // Check environment variable for default value + if envVar != "" { + if envValue := os.Getenv(envVar); envValue != "" { + if parsed, err := strconv.ParseBool(envValue); err == nil { + flag.value = parsed + flag.set = true // Mark as set since env var was found + } + } + } + return flag +} + +// Int creates a new integer flag with environment variable support +func Int(name, usage string, defaultValue int, envVar string, required bool) *IntFlag { + flag := &IntFlag{ + baseFlag: baseFlag{ + name: name, + usage: usage, + envVar: envVar, + required: required, + }, + value: defaultValue, + } + // Check environment variable for default value + if envVar != "" { + if envValue := os.Getenv(envVar); envValue != "" { + if parsed, err := strconv.Atoi(envValue); err == nil { + flag.value = parsed + flag.set = true // Mark as set since env var was found + } + } + } 
+ return flag +} + +// Float creates a new float flag with environment variable support +func Float(name, usage string, defaultValue float64, envVar string, required bool) *FloatFlag { + flag := &FloatFlag{ + baseFlag: baseFlag{ + name: name, + usage: usage, + envVar: envVar, + required: required, + }, + value: defaultValue, + } + // Check environment variable for default value + if envVar != "" { + if envValue := os.Getenv(envVar); envValue != "" { + if parsed, err := strconv.ParseFloat(envValue, 64); err == nil { + flag.value = parsed + flag.set = true // Mark as set since env var was found + } + } + } + return flag +} + +// StringSlice creates a new string slice flag with environment variable support +func StringSlice(name, usage string, defaultValue []string, envVar string, required bool) *StringSliceFlag { + flag := &StringSliceFlag{ + baseFlag: baseFlag{ + name: name, + usage: usage, + envVar: envVar, + required: required, + }, + value: defaultValue, + } + // Check environment variable for default value + if envVar != "" { + if envValue := os.Getenv(envVar); envValue != "" { + flag.value = flag.parseCommaSeparated(envValue) + flag.set = true // Mark as set since env var was found + } + } + return flag +} diff --git a/go/cmd/cli/cli/help.go b/go/cmd/cli/cli/help.go new file mode 100644 index 0000000000..34483e69d0 --- /dev/null +++ b/go/cmd/cli/cli/help.go @@ -0,0 +1,155 @@ +package cli + +import ( + "fmt" + "strings" +) + +// showHelp displays comprehensive help information for the command +// This includes name, description, usage, subcommands, and flags +func (c *Command) showHelp() { + // Command name and usage description + fmt.Printf("NAME:\n %s", c.Name) + if c.Usage != "" { + fmt.Printf(" - %s", c.Usage) + } + fmt.Printf("\n\n") + + // Extended description if available + if c.Description != "" { + fmt.Printf("DESCRIPTION:\n %s\n\n", c.Description) + } + + // Build and show usage line + c.showUsageLine() + + // Show version for root command + if 
c.Version != "" { + fmt.Printf("VERSION:\n %s\n\n", c.Version) + } + + // Show available subcommands + if len(c.Commands) > 0 { + c.showCommands() + } + + // Show command-specific flags if any exist + if len(c.Flags) > 0 { + fmt.Printf("OPTIONS:\n") + for _, flag := range c.Flags { + c.showFlag(flag) + } + fmt.Printf("\n") + } + + // Always show global options + fmt.Printf("GLOBAL OPTIONS:\n") + fmt.Printf(" %-25s %s\n", "--help, -h", "show help") + + // Add version flag only for root command (commands with Version set) + if c.Version != "" { + fmt.Printf(" %-25s %s\n", "--version, -v", "print the version") + } + fmt.Printf("\n") +} + +// showUsageLine displays the command usage syntax +func (c *Command) showUsageLine() { + fmt.Printf("USAGE:\n ") + + // Build full command path (parent commands + this command) + path := c.buildCommandPath() + fmt.Printf("%s", strings.Join(path, " ")) + + // Add syntax indicators + if len(c.Flags) > 0 { + fmt.Printf(" [options]") + } + if len(c.Commands) > 0 { + fmt.Printf(" [command]") + } + fmt.Printf("\n\n") +} + +// buildCommandPath constructs the full command path from root to current command +func (c *Command) buildCommandPath() []string { + var path []string + + // Walk up the parent chain to build full path + cmd := c + for cmd != nil { + path = append([]string{cmd.Name}, path...) 
+ cmd = cmd.parent + } + return path +} + +// showCommands displays all available subcommands in a formatted table +func (c *Command) showCommands() { + fmt.Printf("COMMANDS:\n") + + // Find the longest command name for alignment + maxLen := 0 + for _, cmd := range c.Commands { + if len(cmd.Name) > maxLen { + maxLen = len(cmd.Name) + } + } + + // Display each command with aliases + for _, cmd := range c.Commands { + name := cmd.Name + if len(cmd.Aliases) > 0 { + name += fmt.Sprintf(", %s", strings.Join(cmd.Aliases, ", ")) + } + fmt.Printf(" %-*s %s\n", maxLen+10, name, cmd.Usage) + } + + // Add built-in help command + fmt.Printf(" %-*s %s\n", maxLen+10, "help, h", "Shows help for commands") + fmt.Printf("\n") +} + +// showFlag displays a single flag with proper formatting +func (c *Command) showFlag(flag Flag) { + // Build flag name(s) - support both short and long forms + name := fmt.Sprintf("--%s", flag.Name()) + if len(flag.Name()) == 1 { + name = fmt.Sprintf("-%s, --%s", flag.Name(), flag.Name()) + } + + // Build usage description + usage := flag.Usage() + + // Add required indicator + if flag.Required() { + usage += " (required)" + } + + // Add environment variable info if available + envVar := c.getEnvVar(flag) + if envVar != "" { + usage += fmt.Sprintf(" [$%s]", envVar) + } + + // Display with consistent formatting + fmt.Printf(" %-25s %s\n", name, usage) +} + +// getEnvVar extracts environment variable name from flag if it supports it +func (c *Command) getEnvVar(flag Flag) string { + switch f := flag.(type) { + case *StringFlag: + return f.EnvVar() + case *BoolFlag: + return f.EnvVar() + case *IntFlag: + return f.EnvVar() + case *FloatFlag: + return f.EnvVar() + case *StringSliceFlag: + return f.EnvVar() + default: + return "" + } +} diff --git a/go/cmd/cli/cli/parser.go b/go/cmd/cli/cli/parser.go new file mode 100644 index 0000000000..845fb45a2a --- /dev/null +++ b/go/cmd/cli/cli/parser.go @@ -0,0 +1,167 @@ +package cli + +import ( + "context" + "fmt" + 
"slices" + "strings" +) + +// parse processes command line arguments and executes the appropriate action +// This handles flag parsing, subcommand routing, and help display +func (c *Command) parse(ctx context.Context, args []string) error { + // Initialize flagMap if not already done + if c.flagMap == nil { + c.flagMap = make(map[string]Flag) + for _, flag := range c.Flags { + c.flagMap[flag.Name()] = flag + } + } + + var commandArgs []string + for i := 0; i < len(args); i++ { + arg := args[i] + + // Handle help flags first - these short-circuit normal processing + if arg == "-h" || arg == "--help" || arg == "help" { + c.showHelp() + return nil + } + + // Handle version flags - print version and exit + if (arg == "-v" || arg == "--version") && c.Version != "" { + fmt.Println(c.Version) + return nil + } + + // Handle "help " pattern - show help for specific subcommand + if arg == "help" && i+1 < len(args) { + cmdName := args[i+1] + for _, subcmd := range c.Commands { + if subcmd.Name == cmdName { + subcmd.parent = c + subcmd.showHelp() + return nil + } + // Check aliases + if slices.Contains(subcmd.Aliases, cmdName) { + subcmd.parent = c + subcmd.showHelp() + return nil + } + } + return fmt.Errorf("unknown command: %s", cmdName) + } + + // Check for subcommands (non-flag arguments) + if !strings.HasPrefix(arg, "-") { + // Look for matching subcommand + for _, subcmd := range c.Commands { + if subcmd.Name == arg { + subcmd.parent = c + return subcmd.parse(ctx, args[i+1:]) + } + // Check aliases + if slices.Contains(subcmd.Aliases, arg) { + subcmd.parent = c + return subcmd.parse(ctx, args[i+1:]) + } + } + // Not a subcommand, treat as regular argument + commandArgs = append(commandArgs, arg) + continue + } + + // Parse flags (arguments starting with -) + if err := c.parseFlag(args, &i); err != nil { + return err + } + } + + // Store parsed arguments + c.args = commandArgs + + // Validate all required flags are present + if err := c.validateRequiredFlags(); err != 
nil { + fmt.Printf("Error: %v\n\n", err) + c.showHelp() + return err + } + + // Execute action if present + if c.Action != nil { + return c.Action(ctx, c) + } + + // No action defined - show help if we have subcommands + if len(c.Commands) > 0 { + c.showHelp() + } + + return nil +} + +// parseFlag handles parsing of a single flag and its value +// It modifies the index i to skip consumed arguments +func (c *Command) parseFlag(args []string, i *int) error { + arg := args[*i] + + // Remove leading dashes properly + var flagName string + if strings.HasPrefix(arg, "--") { + flagName = arg[2:] // Remove exactly "--" + } else if strings.HasPrefix(arg, "-") { + flagName = arg[1:] // Remove exactly "-" + } else { + return fmt.Errorf("invalid flag format: %s", arg) + } + + // Handle --flag=value format + var flagValue string + var hasValue bool + if eqIndex := strings.Index(flagName, "="); eqIndex != -1 { + flagValue = flagName[eqIndex+1:] + flagName = flagName[:eqIndex] + hasValue = true + } + + // Look up the flag + flag, exists := c.flagMap[flagName] + if !exists { + return fmt.Errorf("unknown flag: %s", flagName) + } + + // Handle boolean flags specially - they don't require values + if bf, ok := flag.(*BoolFlag); ok { + if hasValue { + // --bool-flag=true/false format + return bf.Parse(flagValue) + } else { + // --bool-flag format (implies true) + return bf.Parse("") + } + } + + // For non-boolean flags, we need a value + if !hasValue { + // Value should be in next argument + if *i+1 >= len(args) { + return fmt.Errorf("flag %s requires a value", flagName) + } + *i++ // Move to next argument + flagValue = args[*i] + } + + // Parse the flag value + return flag.Parse(flagValue) +} + +// validateRequiredFlags checks that all required flags have been set +func (c *Command) validateRequiredFlags() error { + for _, flag := range c.Flags { + if flag.Required() && !flag.IsSet() { + return fmt.Errorf("required flag missing: %s", flag.Name()) + } + } + return nil +} diff --git 
a/go/cmd/cli/commands/deploy/build_docker.go b/go/cmd/cli/commands/deploy/build_docker.go new file mode 100644 index 0000000000..ba664c5707 --- /dev/null +++ b/go/cmd/cli/commands/deploy/build_docker.go @@ -0,0 +1,84 @@ +package deploy + +import ( + "context" + "fmt" + "os/exec" + "strings" + "time" + + "github.com/unkeyed/unkey/go/pkg/git" +) + +func generateImageTag(opts *DeployOptions, gitInfo git.Info) string { + if gitInfo.ShortSHA != "" { + return fmt.Sprintf("%s-%s", opts.Branch, gitInfo.ShortSHA) + } + return fmt.Sprintf("%s-%d", opts.Branch, time.Now().Unix()) +} + +func buildImage(ctx context.Context, opts *DeployOptions, dockerImage string) error { + buildArgs := []string{"build"} + if opts.Dockerfile != "Dockerfile" { + buildArgs = append(buildArgs, "-f", opts.Dockerfile) + } + buildArgs = append(buildArgs, + "-t", dockerImage, + "--build-arg", fmt.Sprintf("VERSION=%s", opts.Commit), + opts.Context, + ) + + cmd := exec.CommandContext(ctx, "docker", buildArgs...) + + // Stream output directly instead of complex pipe handling + output, err := cmd.CombinedOutput() + if err != nil { + fmt.Printf("Docker build failed:\n%s\n", string(output)) + return ErrDockerBuildFailed + } + + return nil +} + +func pushImage(ctx context.Context, dockerImage, registry string) error { + cmd := exec.CommandContext(ctx, "docker", "push", dockerImage) + output, err := cmd.CombinedOutput() + if err != nil { + detailedMsg := classifyPushError(string(output), registry) + return fmt.Errorf("%s: %w", detailedMsg, err) + } + fmt.Printf("%s\n", string(output)) + return nil +} + +func classifyPushError(output, registry string) string { + output = strings.TrimSpace(output) + registryHost := getRegistryHost(registry) + + switch { + case strings.Contains(output, "denied"): + return fmt.Sprintf("registry access denied. try: docker login %s", registryHost) + + case strings.Contains(output, "not found") || strings.Contains(output, "404"): + return "registry not found. 
create repository or use --registry=your-registry/your-app" + + case strings.Contains(output, "unauthorized"): + return fmt.Sprintf("authentication required. run: docker login %s", registryHost) + + default: + return output + } +} + +func getRegistryHost(registry string) string { + parts := strings.Split(registry, "/") + if len(parts) > 0 { + return parts[0] + } + return "docker.io" +} + +func isDockerAvailable() bool { + cmd := exec.Command("docker", "--version") + return cmd.Run() == nil +} diff --git a/go/cmd/cli/commands/deploy/control_plane.go b/go/cmd/cli/commands/deploy/control_plane.go new file mode 100644 index 0000000000..1418c6979f --- /dev/null +++ b/go/cmd/cli/commands/deploy/control_plane.go @@ -0,0 +1,237 @@ +package deploy + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + "time" + + "connectrpc.com/connect" + ctrlv1 "github.com/unkeyed/unkey/go/gen/proto/ctrl/v1" + "github.com/unkeyed/unkey/go/gen/proto/ctrl/v1/ctrlv1connect" + "github.com/unkeyed/unkey/go/pkg/codes" + "github.com/unkeyed/unkey/go/pkg/fault" + "github.com/unkeyed/unkey/go/pkg/otel/logging" +) + +// VersionStatusEvent represents a status change event +type VersionStatusEvent struct { + VersionID string + PreviousStatus ctrlv1.VersionStatus + CurrentStatus ctrlv1.VersionStatus + Version *ctrlv1.Version +} + +// VersionStepEvent represents a step update event +type VersionStepEvent struct { + VersionID string + Step *ctrlv1.VersionStep + Status ctrlv1.VersionStatus +} + +// ControlPlaneClient handles API operations with the control plane +type ControlPlaneClient struct { + client ctrlv1connect.VersionServiceClient + opts *DeployOptions +} + +// NewControlPlaneClient creates a new control plane client +func NewControlPlaneClient(opts *DeployOptions) *ControlPlaneClient { + httpClient := &http.Client{} + client := ctrlv1connect.NewVersionServiceClient(httpClient, opts.ControlPlaneURL) + + return &ControlPlaneClient{ + client: client, + opts: opts, + } +} + +// 
CreateVersion creates a new version in the control plane +func (c *ControlPlaneClient) CreateVersion(ctx context.Context, dockerImage string) (string, error) { + createReq := connect.NewRequest(&ctrlv1.CreateVersionRequest{ + WorkspaceId: c.opts.WorkspaceID, + ProjectId: c.opts.ProjectID, + Branch: c.opts.Branch, + SourceType: ctrlv1.SourceType_SOURCE_TYPE_CLI_UPLOAD, + GitCommitSha: c.opts.Commit, + EnvironmentId: "env_prod", // TODO: Make this configurable + DockerImageTag: dockerImage, + }) + + createReq.Header().Set("Authorization", "Bearer "+c.opts.AuthToken) + + createResp, err := c.client.CreateVersion(ctx, createReq) + if err != nil { + return "", c.handleCreateVersionError(err) + } + + versionId := createResp.Msg.GetVersionId() + if versionId == "" { + return "", fmt.Errorf("empty version ID returned from control plane") + } + + return versionId, nil +} + +// GetVersion retrieves version information from the control plane +func (c *ControlPlaneClient) GetVersion(ctx context.Context, versionId string) (*ctrlv1.Version, error) { + getReq := connect.NewRequest(&ctrlv1.GetVersionRequest{ + VersionId: versionId, + }) + getReq.Header().Set("Authorization", "Bearer "+c.opts.AuthToken) + + getResp, err := c.client.GetVersion(ctx, getReq) + if err != nil { + return nil, err + } + + return getResp.Msg.GetVersion(), nil +} + +// PollVersionStatus polls for version changes and calls event handlers +func (c *ControlPlaneClient) PollVersionStatus( + ctx context.Context, + logger logging.Logger, + versionId string, + onStatusChange func(VersionStatusEvent) error, + onStepUpdate func(VersionStepEvent) error, +) error { + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + timeout := time.NewTimer(300 * time.Second) + defer timeout.Stop() + + // Track processed steps by creation time to avoid duplicates + processedSteps := make(map[int64]bool) + lastStatus := ctrlv1.VersionStatus_VERSION_STATUS_UNSPECIFIED + + for { + select { + case <-ctx.Done(): + return 
ctx.Err() + case <-timeout.C: + return fmt.Errorf("deployment timeout after 5 minutes") + case <-ticker.C: + version, err := c.GetVersion(ctx, versionId) + if err != nil { + logger.Debug("Failed to get version status", "error", err, "version_id", versionId) + continue + } + + currentStatus := version.GetStatus() + + // Handle version status changes + if currentStatus != lastStatus { + event := VersionStatusEvent{ + VersionID: versionId, + PreviousStatus: lastStatus, + CurrentStatus: currentStatus, + Version: version, + } + + if err := onStatusChange(event); err != nil { + return err + } + lastStatus = currentStatus + } + + // Process new step updates + if err := c.processNewSteps(versionId, version.GetSteps(), processedSteps, currentStatus, onStepUpdate); err != nil { + return err + } + + // Check for completion + if currentStatus == ctrlv1.VersionStatus_VERSION_STATUS_ACTIVE { + return nil + } + } + } +} + +// processNewSteps processes new deployment steps and calls the event handler +func (c *ControlPlaneClient) processNewSteps( + versionId string, + steps []*ctrlv1.VersionStep, + processedSteps map[int64]bool, + currentStatus ctrlv1.VersionStatus, + onStepUpdate func(VersionStepEvent) error, +) error { + for _, step := range steps { + // Creation timestamp as unique identifier + stepTimestamp := step.GetCreatedAt() + + if processedSteps[stepTimestamp] { + continue // Already processed this step + } + + // Handle step errors first + if step.GetErrorMessage() != "" { + return fmt.Errorf("deployment failed: %s", step.GetErrorMessage()) + } + + // Call step update handler + if step.GetMessage() != "" { + event := VersionStepEvent{ + VersionID: versionId, + Step: step, + Status: currentStatus, + } + + if err := onStepUpdate(event); err != nil { + return err + } + } + + // Mark this step as processed + processedSteps[stepTimestamp] = true + } + return nil +} + +// getFailureMessage extracts failure message from version +func (c *ControlPlaneClient) 
getFailureMessage(version *ctrlv1.Version) string { + if version.GetErrorMessage() != "" { + return version.GetErrorMessage() + } + + // Check for error in steps + for _, step := range version.GetSteps() { + if step.GetErrorMessage() != "" { + return step.GetErrorMessage() + } + } + + return "Unknown deployment error" +} + +// handleCreateVersionError provides specific error handling for version creation +func (c *ControlPlaneClient) handleCreateVersionError(err error) error { + // Check if it's a connection error + if strings.Contains(err.Error(), "connection refused") { + return fault.Wrap(err, + fault.Code(codes.UnkeyAppErrorsInternalServiceUnavailable), + fault.Internal(fmt.Sprintf("Failed to connect to control plane at %s", c.opts.ControlPlaneURL)), + fault.Public("Unable to connect to control plane. Is it running?"), + ) + } + + // Check if it's an auth error + if connectErr := new(connect.Error); errors.As(err, &connectErr) { + if connectErr.Code() == connect.CodeUnauthenticated { + return fault.Wrap(err, + fault.Code(codes.UnkeyAuthErrorsAuthenticationMalformed), + fault.Internal(fmt.Sprintf("Authentication failed with token: %s", c.opts.AuthToken)), + fault.Public("Authentication failed. Check your auth token."), + ) + } + } + + // Generic API error + return fault.Wrap(err, + fault.Code(codes.UnkeyAppErrorsInternalUnexpectedError), + fault.Internal(fmt.Sprintf("CreateVersion API call failed: %v", err)), + fault.Public("Failed to create version. 
Please try again."), + ) +} diff --git a/go/cmd/cli/commands/deploy/deploy.go b/go/cmd/cli/commands/deploy/deploy.go new file mode 100644 index 0000000000..a286f157fe --- /dev/null +++ b/go/cmd/cli/commands/deploy/deploy.go @@ -0,0 +1,302 @@ +package deploy + +import ( + "context" + "errors" + "fmt" + + "github.com/unkeyed/unkey/go/cmd/cli/cli" + ctrlv1 "github.com/unkeyed/unkey/go/gen/proto/ctrl/v1" + "github.com/unkeyed/unkey/go/pkg/git" + "github.com/unkeyed/unkey/go/pkg/otel/logging" +) + +// Step predictor - maps current step message patterns to next expected steps +// Based on the actual workflow messages from version.go +// TODO: In the future directly get those from hydra +var stepSequence = map[string]string{ + "Version queued and ready to start": "Downloading Docker image:", + "Downloading Docker image:": "Building rootfs from Docker image:", + "Building rootfs from Docker image:": "Uploading rootfs image to storage", + "Uploading rootfs image to storage": "Creating VM for version:", + "Creating VM for version:": "VM booted successfully:", + "VM booted successfully:": "Assigned hostname:", + "Assigned hostname:": "Version deployment completed successfully", +} + +var ( + ErrDockerNotFound = errors.New("docker command not found - please install Docker") + ErrDockerBuildFailed = errors.New("docker build failed") +) + +// DeployOptions contains all configuration for deployment +type DeployOptions struct { + WorkspaceID string + ProjectID string + Context string + Branch string + DockerImage string + Dockerfile string + Commit string + Registry string + SkipPush bool + ControlPlaneURL string + AuthToken string +} + +var DeployFlags = []cli.Flag{ + // Required flags + cli.String("workspace-id", "Workspace ID", "", "UNKEY_WORKSPACE_ID", true), + cli.String("project-id", "Project ID", "", "UNKEY_PROJECT_ID", true), + + // Optional flags with defaults + cli.String("context", "Docker context path", ".", "", false), + cli.String("branch", "Git branch", "main", "", 
false), + cli.String("docker-image", "Pre-built docker image", "", "", false), + cli.String("dockerfile", "Path to Dockerfile", "Dockerfile", "", false), + cli.String("commit", "Git commit SHA", "", "", false), + cli.String("registry", "Docker registry", "ghcr.io/unkeyed/deploy", "UNKEY_DOCKER_REGISTRY", false), + cli.Bool("skip-push", "Skip pushing to registry (for local testing)", "", false), + + // Control plane flags (internal) + cli.String("control-plane-url", "Control plane URL", "http://localhost:7091", "", false), + cli.String("auth-token", "Control plane auth token", "ctrl-secret-token", "", false), +} + +// Command defines the deploy CLI command +var Command = &cli.Command{ + Name: "deploy", + Usage: "Deploy a new version", + Description: `Build and deploy a new version of your application. +Builds a Docker image from the specified context and +deploys it to the Unkey platform. + +EXAMPLES: + # Basic deployment + unkey deploy \ + --workspace-id=ws_4QgQsKsKfdm3nGeC \ + --project-id=proj_9aiaks2dzl6mcywnxjf \ + --context=./demo_api + + # Deploy with your own registry + unkey deploy \ + --workspace-id=ws_4QgQsKsKfdm3nGeC \ + --project-id=proj_9aiaks2dzl6mcywnxjf \ + --registry=docker.io/mycompany/myapp + + # Local development (skip push) + unkey deploy \ + --workspace-id=ws_4QgQsKsKfdm3nGeC \ + --project-id=proj_9aiaks2dzl6mcywnxjf \ + --skip-push + + # Deploy pre-built image + unkey deploy \ + --workspace-id=ws_4QgQsKsKfdm3nGeC \ + --project-id=proj_9aiaks2dzl6mcywnxjf \ + --docker-image=ghcr.io/user/app:v1.0.0`, + Flags: DeployFlags, + Action: DeployAction, +} + +func DeployAction(ctx context.Context, cmd *cli.Command) error { + opts := &DeployOptions{ + WorkspaceID: cmd.String("workspace-id"), + ProjectID: cmd.String("project-id"), + Context: cmd.String("context"), + Branch: cmd.String("branch"), + DockerImage: cmd.String("docker-image"), + Dockerfile: cmd.String("dockerfile"), + Commit: cmd.String("commit"), + Registry: cmd.String("registry"), + 
SkipPush: cmd.Bool("skip-push"), + ControlPlaneURL: cmd.String("control-plane-url"), + AuthToken: cmd.String("auth-token"), + } + + return executeDeploy(ctx, opts) +} + +// Updated executeDeploy function - remove global spinner for deployment steps +func executeDeploy(ctx context.Context, opts *DeployOptions) error { + ui := NewUI() + logger := logging.New() + gitInfo := git.GetInfo() + + if opts.Branch == "main" && gitInfo.IsRepo && gitInfo.Branch != "" { + opts.Branch = gitInfo.Branch + } + if opts.Commit == "" && gitInfo.CommitSHA != "" { + opts.Commit = gitInfo.CommitSHA + } + + fmt.Printf("Unkey Deploy Progress\n") + fmt.Printf("──────────────────────────────────────────────────\n") + printSourceInfo(opts, gitInfo) + + ui.Print("Preparing deployment") + + var dockerImage string + + if opts.DockerImage == "" { + if !isDockerAvailable() { + ui.PrintError("Docker not found - please install Docker") + ui.PrintErrorDetails(ErrDockerNotFound.Error()) + return nil + } + imageTag := generateImageTag(opts, gitInfo) + dockerImage = fmt.Sprintf("%s:%s", opts.Registry, imageTag) + + ui.Print(fmt.Sprintf("Building image: %s", dockerImage)) + if err := buildImage(ctx, opts, dockerImage); err != nil { + ui.PrintError("Docker build failed") + ui.PrintErrorDetails(err.Error()) + return nil + } + ui.PrintSuccess("Image built successfully") + } else { + dockerImage = opts.DockerImage + ui.Print("Using pre-built Docker image") + } + + if !opts.SkipPush && opts.DockerImage == "" { + ui.Print("Pushing to registry") + if err := pushImage(ctx, dockerImage, opts.Registry); err != nil { + ui.PrintError("Push failed but continuing deployment") + ui.PrintErrorDetails(err.Error()) + // INFO: For now we are ignoring registry push because everyone one is working locally, + // omit this when comments and put the return nil back when going to prod + // return nil + } else { + ui.PrintSuccess("Image pushed successfully") + } + } else if opts.SkipPush { + ui.Print("Skipping registry push") + } 
+ + ui.Print("Creating deployment") + + controlPlane := NewControlPlaneClient(opts) + versionId, err := controlPlane.CreateVersion(ctx, dockerImage) + if err != nil { + ui.PrintError("Failed to create version") + ui.PrintErrorDetails(err.Error()) + return nil + } + + ui.PrintSuccess(fmt.Sprintf("Version created: %s", versionId)) + + onStatusChange := func(event VersionStatusEvent) error { + if event.CurrentStatus == ctrlv1.VersionStatus_VERSION_STATUS_FAILED { + return handleVersionFailure(controlPlane, event.Version, ui) + } + return nil + } + + onStepUpdate := func(event VersionStepEvent) error { + return handleStepUpdate(event, ui) + } + + err = controlPlane.PollVersionStatus(ctx, logger, versionId, onStatusChange, onStepUpdate) + if err != nil { + // Complete any running step spinner on error + ui.CompleteCurrentStep("Deployment failed", false) + return err + } + + // Complete final step if still spinning + ui.CompleteCurrentStep("Version deployment completed successfully", true) + ui.PrintSuccess("Deployment completed successfully") + + fmt.Printf("\n") + printCompletionInfo(opts, gitInfo, versionId) + fmt.Printf("\n") + + return nil +} + +func getNextStepMessage(currentMessage string) string { + // Check if current message starts with any known step pattern + for key, next := range stepSequence { + if len(currentMessage) >= len(key) && currentMessage[:len(key)] == key { + return next + } + } + return "" +} + +func handleStepUpdate(event VersionStepEvent, ui *UI) error { + step := event.Step + + if step.GetErrorMessage() != "" { + ui.CompleteCurrentStep(step.GetMessage(), false) + ui.PrintErrorDetails(step.GetErrorMessage()) + return fmt.Errorf("deployment failed: %s", step.GetErrorMessage()) + } + + if step.GetMessage() != "" { + message := step.GetMessage() + nextStep := getNextStepMessage(message) + + if !ui.stepSpinning { + // First step - start spinner, then complete and start next + ui.StartStepSpinner(message) + ui.CompleteStepAndStartNext(message, 
nextStep) + } else { + // Complete current step and start next + ui.CompleteStepAndStartNext(message, nextStep) + } + } + + return nil +} + +func handleVersionFailure(controlPlane *ControlPlaneClient, version *ctrlv1.Version, ui *UI) error { + errorMsg := controlPlane.getFailureMessage(version) + ui.CompleteCurrentStep("Deployment failed", false) + ui.PrintError("Deployment failed") + ui.PrintErrorDetails(errorMsg) + return fmt.Errorf("deployment failed: %s", errorMsg) +} + +func printSourceInfo(opts *DeployOptions, gitInfo git.Info) { + fmt.Printf("Source Information:\n") + fmt.Printf(" Branch: %s\n", opts.Branch) + + if gitInfo.IsRepo && gitInfo.CommitSHA != "" { + + commitInfo := gitInfo.ShortSHA + if gitInfo.IsDirty { + commitInfo += " (dirty)" + } + fmt.Printf(" Commit: %s\n", commitInfo) + } + + fmt.Printf(" Context: %s\n", opts.Context) + + if opts.DockerImage != "" { + fmt.Printf(" Image: %s\n", opts.DockerImage) + } + + fmt.Printf("\n") +} + +func printCompletionInfo(opts *DeployOptions, gitInfo git.Info, versionId string) { + if versionId == "" || opts.WorkspaceID == "" || opts.Branch == "" { + fmt.Printf("✓ Deployment completed\n") + return + } + + fmt.Printf("Deployment Summary:\n") + fmt.Printf(" Version: %s\n", versionId) + fmt.Printf(" Status: Ready\n") + fmt.Printf(" Environment: Production\n") + + identifier := versionId + if gitInfo.ShortSHA != "" { + identifier = gitInfo.ShortSHA + } + + domain := fmt.Sprintf("https://%s-%s-%s.unkey.app", opts.Branch, identifier, opts.WorkspaceID) + fmt.Printf(" URL: %s\n", domain) +} diff --git a/go/cmd/cli/commands/deploy/ui.go b/go/cmd/cli/commands/deploy/ui.go new file mode 100644 index 0000000000..12601dd9c9 --- /dev/null +++ b/go/cmd/cli/commands/deploy/ui.go @@ -0,0 +1,159 @@ +package deploy + +import ( + "fmt" + "sync" + "time" +) + +// Color constants +const ( + ColorReset = "\033[0m" + ColorRed = "\033[31m" + ColorGreen = "\033[32m" + ColorYellow = "\033[33m" +) + +var spinnerChars = []string{"⠋", "⠙", 
"⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"} + +type UI struct { + mu sync.Mutex + spinning bool + currentStep string + stepSpinning bool +} + +func NewUI() *UI { + return &UI{} +} + +func (ui *UI) Print(message string) { + ui.mu.Lock() + defer ui.mu.Unlock() + fmt.Printf("%s•%s %s\n", ColorYellow, ColorReset, message) +} + +func (ui *UI) PrintSuccess(message string) { + ui.mu.Lock() + defer ui.mu.Unlock() + fmt.Printf("%s✓%s %s\n", ColorGreen, ColorReset, message) +} + +func (ui *UI) PrintError(message string) { + ui.mu.Lock() + defer ui.mu.Unlock() + fmt.Printf("%s✗%s %s\n", ColorRed, ColorReset, message) +} + +func (ui *UI) PrintErrorDetails(message string) { + ui.mu.Lock() + defer ui.mu.Unlock() + fmt.Printf(" %s->%s %s\n", ColorRed, ColorReset, message) +} + +func (ui *UI) PrintStepSuccess(message string) { + ui.mu.Lock() + defer ui.mu.Unlock() + fmt.Printf(" %s✓%s %s\n", ColorGreen, ColorReset, message) +} + +func (ui *UI) PrintStepError(message string) { + ui.mu.Lock() + defer ui.mu.Unlock() + fmt.Printf(" %s✗%s %s\n", ColorRed, ColorReset, message) +} + +func (ui *UI) spinnerLoop(prefix string, messageGetter func() string, isActive func() bool) { + go func() { + frame := 0 + for { + ui.mu.Lock() + if !isActive() { + ui.mu.Unlock() + return + } + message := messageGetter() + fmt.Printf("\r%s%s %s", prefix, spinnerChars[frame%len(spinnerChars)], message) + ui.mu.Unlock() + frame++ + time.Sleep(100 * time.Millisecond) + } + }() +} + +func (ui *UI) StartSpinner(message string) { + ui.mu.Lock() + if ui.spinning { + ui.mu.Unlock() + return + } + ui.spinning = true + spinnerMessage := message + ui.mu.Unlock() + + ui.spinnerLoop("", func() string { return spinnerMessage }, func() bool { return ui.spinning }) +} + +func (ui *UI) StopSpinner(finalMessage string, success bool) { + ui.mu.Lock() + defer ui.mu.Unlock() + if !ui.spinning { + return + } + ui.spinning = false + fmt.Print("\r\033[K") + if success { + fmt.Printf("%s✓%s %s\n", ColorGreen, ColorReset, finalMessage) 
+ } else { + fmt.Printf("%s✗%s %s\n", ColorRed, ColorReset, finalMessage) + } +} + +// Step spinner methods - indented with 2 spaces to show as sub-steps +func (ui *UI) StartStepSpinner(message string) { + ui.mu.Lock() + if ui.stepSpinning { + fmt.Print("\r\033[K") + } + ui.currentStep = message + ui.stepSpinning = true + ui.mu.Unlock() + + ui.spinnerLoop(" ", func() string { return ui.currentStep }, func() bool { return ui.stepSpinning }) +} + +func (ui *UI) CompleteStepAndStartNext(completedMessage, nextMessage string) { + ui.mu.Lock() + // Stop current spinner and show completion + if ui.stepSpinning { + ui.stepSpinning = false + fmt.Print("\r\033[K") + fmt.Printf(" %s✓%s %s\n", ColorGreen, ColorReset, completedMessage) + } + + // Start next step if provided + if nextMessage != "" { + ui.currentStep = nextMessage + ui.stepSpinning = true + ui.mu.Unlock() + + ui.spinnerLoop(" ", func() string { return ui.currentStep }, func() bool { return ui.stepSpinning }) + } else { + ui.mu.Unlock() + } +} + +func (ui *UI) CompleteCurrentStep(message string, success bool) { + ui.mu.Lock() + defer ui.mu.Unlock() + if !ui.stepSpinning { + return + } + ui.stepSpinning = false + fmt.Print("\r\033[K") + if success { + fmt.Printf(" %s✓%s %s\n", ColorGreen, ColorReset, message) + } else { + fmt.Printf(" %s✗%s %s\n", ColorRed, ColorReset, message) + } +} diff --git a/go/cmd/cli/commands/init/init.go b/go/cmd/cli/commands/init/init.go new file mode 100644 index 0000000000..3b5b91f0ed --- /dev/null +++ b/go/cmd/cli/commands/init/init.go @@ -0,0 +1,69 @@ +package init + +import ( + "context" + "fmt" + + "github.com/unkeyed/unkey/go/cmd/cli/cli" +) + +var Command = &cli.Command{ + Name: "init", + Usage: "Initialize configuration file for Unkey CLI", + Description: `Initialize a configuration file to store default values for workspace ID, project ID, and context path. +This will create a configuration file that can be used to avoid specifying common flags repeatedly. 
+ +EXAMPLES: + # Create default config file (./unkey.json) + unkey init + + # Create config file at custom location + unkey init --config=./my-project.json + + # Initialize with specific values + unkey init --workspace-id=ws_123 --project-id=proj_456`, + Flags: []cli.Flag{ + cli.String("config", "Configuration file path", "./unkey.json", "", false), + cli.String("workspace-id", "Default workspace ID to save in config", "", "", false), + cli.String("project-id", "Default project ID to save in config", "", "", false), + cli.String("context", "Default Docker context path to save in config", "", "", false), + }, + Action: run, +} + +func run(ctx context.Context, cmd *cli.Command) error { + configPath := cmd.String("config") + workspaceID := cmd.String("workspace-id") + projectID := cmd.String("project-id") + contextPath := cmd.String("context") + + fmt.Println("🚀 Unkey CLI Configuration Setup") + fmt.Println("") + + // For now, just show what would be saved + fmt.Println("Configuration file support coming soon!") + fmt.Println("") + fmt.Printf("Config file location: %s\n", configPath) + + if workspaceID != "" { + fmt.Printf("Workspace ID: %s\n", workspaceID) + } + if projectID != "" { + fmt.Printf("Project ID: %s\n", projectID) + } + if contextPath != "" { + fmt.Printf("Context path: %s\n", contextPath) + } + + fmt.Println("") + fmt.Println("For now, use flags directly:") + fmt.Println("") + fmt.Println("Example:") + fmt.Println(" unkey deploy \\") + fmt.Println(" --workspace-id=ws_4QgQsKsKfdm3nGeC \\") + fmt.Println(" --project-id=proj_9aiaks2dzl6mcywnxjf \\") + fmt.Println(" --context=./demo_api") + fmt.Println("") + + return nil +} diff --git a/go/cmd/cli/commands/versions/versions.go b/go/cmd/cli/commands/versions/versions.go new file mode 100644 index 0000000000..643d1c3524 --- /dev/null +++ b/go/cmd/cli/commands/versions/versions.go @@ -0,0 +1,164 @@ +package versions + +import ( + "context" + "fmt" + + "github.com/unkeyed/unkey/go/cmd/cli/cli" + 
"github.com/unkeyed/unkey/go/cmd/cli/commands/deploy" +) + +// VersionListOptions holds options for version list command +type VersionListOptions struct { + Branch string + Status string + Limit int +} + +// Command defines the version CLI command with subcommands +var Command = &cli.Command{ + Name: "version", + Usage: "Manage API versions", + Description: `Create, list, and manage versions of your API. + +Versions are immutable snapshots of your code, configuration, and infrastructure settings. + +EXAMPLES: + # Create new version + unkey version create --workspace-id=ws_123 --project-id=proj_456 + + # List versions + unkey version list + unkey version list --branch=main --limit=20 + + # Get specific version + unkey version get v_abc123def456`, + Commands: []*cli.Command{ + createCmd, + listCmd, + getCmd, + }, +} + +// createCmd handles version create (alias for deploy) +var createCmd = &cli.Command{ + Name: "create", + Aliases: []string{"deploy"}, + Usage: "Create a new version (same as deploy)", + Description: "Same as 'unkey deploy'. See 'unkey help deploy' for details.", + Flags: deploy.DeployFlags, + Action: deploy.DeployAction, +} + +// listCmd handles version listing +var listCmd = &cli.Command{ + Name: "list", + Usage: "List versions", + Description: `List all versions with optional filtering. + +EXAMPLES: + # List all versions + unkey version list + + # Filter by branch + unkey version list --branch=main + + # Filter by status and limit results + unkey version list --status=active --limit=5`, + Flags: []cli.Flag{ + cli.String("branch", "Filter by branch", "", "", false), + cli.String("status", "Filter by status (pending, building, active, failed)", "", "", false), + cli.Int("limit", "Number of versions to show", 10, "", false), + }, + Action: listAction, +} + +// getCmd handles getting version details +var getCmd = &cli.Command{ + Name: "get", + Usage: "Get version details", + Description: `Get detailed information about a specific version. 
+ +USAGE: + unkey version get + +EXAMPLES: + unkey version get v_abc123def456`, + Action: getAction, +} + +// listAction handles the version list command execution +func listAction(ctx context.Context, cmd *cli.Command) error { + opts := &VersionListOptions{ + Branch: cmd.String("branch"), + Status: cmd.String("status"), + Limit: cmd.Int("limit"), + } + + // Display filter info if provided + filters := []string{} + if opts.Branch != "" { + filters = append(filters, fmt.Sprintf("branch=%s", opts.Branch)) + } + if opts.Status != "" { + filters = append(filters, fmt.Sprintf("status=%s", opts.Status)) + } + filters = append(filters, fmt.Sprintf("limit=%d", opts.Limit)) + + if len(filters) > 1 { + fmt.Printf("Listing versions (%s)\n", fmt.Sprintf("%v", filters)) + } else { + fmt.Printf("Listing versions (limit=%d)\n", opts.Limit) + } + fmt.Println() + + // TODO: Add actual version listing logic here + // This would typically: + // 1. Call control plane API with filters + // 2. Parse response + // 3. Format and display results + + // Mock data for demonstration + fmt.Println("ID STATUS BRANCH CREATED") + fmt.Println("v_abc123def456 ACTIVE main 2024-01-01 12:00:00") + if opts.Branch == "" || opts.Branch == "feature" { + fmt.Println("v_def456ghi789 ACTIVE feature 2024-01-01 11:00:00") + } + if opts.Status == "" || opts.Status == "failed" { + fmt.Println("v_ghi789jkl012 FAILED main 2024-01-01 10:00:00") + } + + return nil +} + +// getAction handles the version get command execution +func getAction(ctx context.Context, cmd *cli.Command) error { + args := cmd.Args() + if len(args) == 0 { + return cli.Exit("version get requires a version ID", 1) + } + + versionID := args[0] + fmt.Printf("Getting version: %s\n", versionID) + fmt.Println() + + // TODO: Add actual version get logic here + // This would typically: + // 1. Call control plane API with version ID + // 2. Parse response + // 3. 
Display detailed information + + // Mock data for demonstration + fmt.Printf("Version: %s\n", versionID) + fmt.Printf("Status: ACTIVE\n") + fmt.Printf("Branch: main\n") + fmt.Printf("Created: 2024-01-01 12:00:00\n") + fmt.Printf("Docker Image: ghcr.io/unkeyed/deploy:main-abc123\n") + fmt.Printf("Commit: abc123def456789\n") + fmt.Println() + fmt.Printf("Hostnames:\n") + fmt.Printf(" - https://main-abc123-workspace.unkey.app\n") + fmt.Printf(" - https://api.acme.com\n") + + return nil +} diff --git a/go/cmd/cli/main.go b/go/cmd/cli/main.go new file mode 100644 index 0000000000..d96409f649 --- /dev/null +++ b/go/cmd/cli/main.go @@ -0,0 +1,31 @@ +package main + +import ( + "context" + "fmt" + "os" + + "github.com/unkeyed/unkey/go/cmd/cli/cli" + "github.com/unkeyed/unkey/go/cmd/cli/commands/deploy" + initcmd "github.com/unkeyed/unkey/go/cmd/cli/commands/init" + "github.com/unkeyed/unkey/go/cmd/cli/commands/versions" + "github.com/unkeyed/unkey/go/pkg/version" +) + +func main() { + app := &cli.Command{ + Name: "unkey", + Usage: "Deploy and manage your API versions", + Version: version.Version, + Commands: []*cli.Command{ + deploy.Command, + versions.Command, + initcmd.Command, + }, + } + + if err := app.Run(context.Background(), os.Args); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } +} From a0fc26e2ebcf513357077a34435ab5e733f1ccd0 Mon Sep 17 00:00:00 2001 From: chronark Date: Thu, 17 Jul 2025 12:08:39 +0200 Subject: [PATCH 5/6] fix: port allocation in tests we now try to listen on a random port assigned by the OS and never stop listening until the tests are done. 
This prevents the race conditions between assigning a port and using it --- .github/workflows/job_test_api_local.yaml | 3 +- apps/agent/cmd/agent/agent.go | 93 +------------- apps/agent/config.apprunner.production.json | 10 -- apps/agent/config.apprunner.staging.json | 10 -- apps/agent/config.docker.json | 10 -- apps/agent/config.production.json | 10 -- apps/agent/config.staging.json | 10 -- deployment/docker-compose.yaml | 1 - go/apps/api/cancel_test.go | 15 +-- go/apps/api/config.go | 8 ++ go/apps/api/integration/harness.go | 21 ++-- go/apps/api/run.go | 25 ++-- go/cmd/api/main.go | 5 +- go/pkg/port/doc.go | 37 ------ go/pkg/port/free.go | 128 -------------------- go/pkg/zen/README.md | 54 ++++++++- go/pkg/zen/doc.go | 8 +- go/pkg/zen/server.go | 14 +-- go/pkg/zen/server_tls_test.go | 42 +++---- 19 files changed, 128 insertions(+), 376 deletions(-) delete mode 100644 go/pkg/port/doc.go delete mode 100644 go/pkg/port/free.go diff --git a/.github/workflows/job_test_api_local.yaml b/.github/workflows/job_test_api_local.yaml index ea7192951a..3b59dba63a 100644 --- a/.github/workflows/job_test_api_local.yaml +++ b/.github/workflows/job_test_api_local.yaml @@ -3,7 +3,6 @@ on: workflow_call: permissions: contents: read - jobs: test: name: API Test Local @@ -16,7 +15,7 @@ jobs: run: rm -rf /opt/hostedtoolcache - name: Run containers - run: docker compose -f ./deployment/docker-compose.yaml up mysql planetscale agent s3 chproxy api -d + run: docker compose -f ./deployment/docker-compose.yaml up -d - name: Install uses: ./.github/actions/install diff --git a/apps/agent/cmd/agent/agent.go b/apps/agent/cmd/agent/agent.go index b7cb435cda..8f2f3fddd7 100644 --- a/apps/agent/cmd/agent/agent.go +++ b/apps/agent/cmd/agent/agent.go @@ -3,28 +3,22 @@ package agent import ( "context" "fmt" - "net" "os" "os/signal" "runtime/debug" "strings" "syscall" - "github.com/Southclaws/fault" - "github.com/Southclaws/fault/fmsg" "github.com/unkeyed/unkey/apps/agent/pkg/api" 
"github.com/unkeyed/unkey/apps/agent/pkg/clickhouse" - "github.com/unkeyed/unkey/apps/agent/pkg/cluster" "github.com/unkeyed/unkey/apps/agent/pkg/config" "github.com/unkeyed/unkey/apps/agent/pkg/connect" - "github.com/unkeyed/unkey/apps/agent/pkg/membership" "github.com/unkeyed/unkey/apps/agent/pkg/metrics" "github.com/unkeyed/unkey/apps/agent/pkg/profiling" "github.com/unkeyed/unkey/apps/agent/pkg/prometheus" "github.com/unkeyed/unkey/apps/agent/pkg/tracing" "github.com/unkeyed/unkey/apps/agent/pkg/uid" "github.com/unkeyed/unkey/apps/agent/pkg/version" - "github.com/unkeyed/unkey/apps/agent/services/ratelimit" "github.com/unkeyed/unkey/apps/agent/services/vault" "github.com/unkeyed/unkey/apps/agent/services/vault/storage" storageMiddleware "github.com/unkeyed/unkey/apps/agent/services/vault/storage/middleware" @@ -160,81 +154,13 @@ func run(c *cli.Context) error { return fmt.Errorf("failed to create vault service: %w", err) } - var clus cluster.Cluster - - if cfg.Cluster != nil { - - memb, membershipErr := membership.New(membership.Config{ - NodeId: cfg.NodeId, - RpcAddr: cfg.Cluster.RpcAddr, - SerfAddr: cfg.Cluster.SerfAddr, - Logger: logger, - }) - if membershipErr != nil { - return fmt.Errorf("failed to create membership: %w", membershipErr) - } - - var join []string - if cfg.Cluster.Join.Dns != nil { - addrs, lookupErr := net.LookupHost(cfg.Cluster.Join.Dns.AAAA) - if lookupErr != nil { - return fmt.Errorf("failed to lookup dns: %w", lookupErr) - } - logger.Info().Strs("addrs", addrs).Msg("found dns records") - join = addrs - } else if cfg.Cluster.Join.Env != nil { - join = cfg.Cluster.Join.Env.Addrs - } - - _, err = memb.Join(join...) 
- if err != nil { - return fault.Wrap(err, fmsg.With("failed to join cluster")) - } - defer func() { - logger.Info().Msg("leaving membership") - err = memb.Leave() - if err != nil { - logger.Error().Err(err).Msg("failed to leave cluster") - } - }() - - clus, err = cluster.New(cluster.Config{ - NodeId: cfg.NodeId, - RpcAddr: cfg.Cluster.RpcAddr, - Membership: memb, - Logger: logger, - Metrics: m, - Debug: true, - AuthToken: cfg.Cluster.AuthToken, - }) - if err != nil { - return fmt.Errorf("failed to create cluster: %w", err) - } - defer func() { - shutdownErr := clus.Shutdown() - if shutdownErr != nil { - logger.Error().Err(shutdownErr).Msg("failed to shutdown cluster") - } - }() - - } - - rl, err := ratelimit.New(ratelimit.Config{ - Logger: logger, - Metrics: m, - Cluster: clus, - }) - if err != nil { - logger.Fatal().Err(err).Msg("failed to create service") - } - srv, err := api.New(api.Config{ NodeId: cfg.NodeId, Logger: logger, - Ratelimit: rl, + Ratelimit: nil, Metrics: m, Clickhouse: ch, - AuthToken: cfg.Cluster.AuthToken, + AuthToken: cfg.AuthToken, Vault: v, }) if err != nil { @@ -246,17 +172,6 @@ func run(c *cli.Context) error { return err } - err = connectSrv.AddService(connect.NewClusterServer(clus, logger)) - if err != nil { - return fmt.Errorf("failed to add cluster service: %w", err) - - } - err = connectSrv.AddService(connect.NewRatelimitServer(rl, logger, cfg.AuthToken)) - if err != nil { - return fmt.Errorf("failed to add ratelimit service: %w", err) - } - logger.Info().Msg("started ratelimit service") - go func() { err = connectSrv.Listen(fmt.Sprintf(":%s", cfg.RpcPort)) if err != nil { @@ -295,10 +210,6 @@ func run(c *cli.Context) error { if err != nil { return fmt.Errorf("failed to shutdown service: %w", err) } - err = clus.Shutdown() - if err != nil { - return fmt.Errorf("failed to shutdown cluster: %w", err) - } return nil } diff --git a/apps/agent/config.apprunner.production.json b/apps/agent/config.apprunner.production.json index 
38707af264..d4fee7ef9c 100644 --- a/apps/agent/config.apprunner.production.json +++ b/apps/agent/config.apprunner.production.json @@ -35,16 +35,6 @@ "masterKeys": "${VAULT_MASTER_KEYS}" } }, - "cluster": { - "authToken": "${AUTH_TOKEN}", - "serfAddr": "[${FLY_PRIVATE_IP}]:${SERF_PORT}", - "rpcAddr": "http://${FLY_PRIVATE_IP}:${RPC_PORT}", - "join": { - "env": { - "addrs": [] - } - } - }, "heartbeat": { "interval": 60, "url": "${HEARTBEAT_URL}" diff --git a/apps/agent/config.apprunner.staging.json b/apps/agent/config.apprunner.staging.json index 232db2a3d8..c5a4329864 100644 --- a/apps/agent/config.apprunner.staging.json +++ b/apps/agent/config.apprunner.staging.json @@ -18,15 +18,5 @@ "s3AccessKeySecret": "${VAULT_S3_ACCESS_KEY_SECRET}", "masterKeys": "${VAULT_MASTER_KEYS}" } - }, - "cluster": { - "authToken": "${AUTH_TOKEN}", - "serfAddr": "[${FLY_PRIVATE_IP}]:${SERF_PORT}", - "rpcAddr": "http://${FLY_PRIVATE_IP}:${RPC_PORT}", - "join": { - "env": { - "addrs": [] - } - } } } diff --git a/apps/agent/config.docker.json b/apps/agent/config.docker.json index 7972eb6b9e..7a64d26d48 100644 --- a/apps/agent/config.docker.json +++ b/apps/agent/config.docker.json @@ -9,16 +9,6 @@ "authToken": "${AUTH_TOKEN}", "nodeId": "${NODE_ID}", "logging": {}, - "cluster": { - "authToken": "${AUTH_TOKEN}", - "serfAddr": "${HOSTNAME}:${SERF_PORT}", - "rpcAddr": "${HOSTNAME}:${RPC_PORT}", - "join": { - "env": { - "addrs": ["unkey-agent-1:${SERF_PORT}"] - } - } - }, "services": { "vault": { "s3Url": "${VAULT_S3_URL}", diff --git a/apps/agent/config.production.json b/apps/agent/config.production.json index ee8d7083d5..e183459186 100644 --- a/apps/agent/config.production.json +++ b/apps/agent/config.production.json @@ -38,16 +38,6 @@ "masterKeys": "${VAULT_MASTER_KEYS}" } }, - "cluster": { - "authToken": "${AUTH_TOKEN}", - "serfAddr": "[${FLY_PRIVATE_IP}]:${SERF_PORT}", - "rpcAddr": "http://${FLY_PRIVATE_IP}:${RPC_PORT}", - "join": { - "dns": { - "aaaa": "${FLY_APP_NAME}.internal" - } - } - 
}, "heartbeat": { "interval": 60, "url": "${HEARTBEAT_URL}" diff --git a/apps/agent/config.staging.json b/apps/agent/config.staging.json index 7fba639265..251cb7aee4 100644 --- a/apps/agent/config.staging.json +++ b/apps/agent/config.staging.json @@ -20,16 +20,6 @@ "masterKeys": "${VAULT_MASTER_KEYS}" } }, - "cluster": { - "authToken": "${AUTH_TOKEN}", - "serfAddr": "[${FLY_PRIVATE_IP}]:${SERF_PORT}", - "rpcAddr": "http://${FLY_PRIVATE_IP}:${RPC_PORT}", - "join": { - "dns": { - "aaaa": "${FLY_APP_NAME}.internal" - } - } - }, "prometheus": { "path": "/metrics", "port": 2112 diff --git a/deployment/docker-compose.yaml b/deployment/docker-compose.yaml index dd77f7c669..4ff5290c1e 100644 --- a/deployment/docker-compose.yaml +++ b/deployment/docker-compose.yaml @@ -88,7 +88,6 @@ services: - clickhouse environment: PORT: 8080 - SERF_PORT: 9999 RPC_PORT: 9095 AUTH_TOKEN: "agent-auth-secret" VAULT_S3_URL: "http://s3:3902" diff --git a/go/apps/api/cancel_test.go b/go/apps/api/cancel_test.go index f64a9d020d..3e5cef0d6e 100644 --- a/go/apps/api/cancel_test.go +++ b/go/apps/api/cancel_test.go @@ -3,13 +3,13 @@ package api_test import ( "context" "fmt" + "net" "net/http" "testing" "time" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api" - "github.com/unkeyed/unkey/go/pkg/port" "github.com/unkeyed/unkey/go/pkg/testutil/containers" "github.com/unkeyed/unkey/go/pkg/uid" "github.com/unkeyed/unkey/go/pkg/vault/keys" @@ -23,9 +23,10 @@ func TestContextCancellation(t *testing.T) { mysqlCfg.DBName = "unkey" dbDsn := mysqlCfg.FormatDSN() redisUrl := containers.Redis(t) - // Get free ports for the node - portAllocator := port.New() - httpPort := portAllocator.Get() + + // Create ephemeral listener + ln, err := net.Listen("tcp", ":0") + require.NoError(t, err, "Failed to create ephemeral listener") // Create a cancellable context ctx, cancel := context.WithCancel(context.Background()) @@ -37,7 +38,7 @@ func TestContextCancellation(t *testing.T) { config := 
api.Config{ Platform: "test", Image: "test", - HttpPort: httpPort, + Listener: ln, Region: "test-region", Clock: nil, // Will use real clock InstanceID: uid.New(uid.InstancePrefix), @@ -65,7 +66,7 @@ func TestContextCancellation(t *testing.T) { // Wait for the server to start up require.Eventually(t, func() bool { - res, livenessErr := http.Get(fmt.Sprintf("http://localhost:%d/v2/liveness", httpPort)) + res, livenessErr := http.Get(fmt.Sprintf("http://%s/v2/liveness", ln.Addr())) if livenessErr != nil { return false } @@ -90,6 +91,6 @@ func TestContextCancellation(t *testing.T) { } // Verify the server is no longer responding - _, err = http.Get(fmt.Sprintf("http://localhost:%d/v2/liveness", httpPort)) + _, err = http.Get(fmt.Sprintf("http://%s/v2/liveness", ln.Addr())) require.Error(t, err, "Server should no longer be responding after shutdown") } diff --git a/go/apps/api/config.go b/go/apps/api/config.go index fbeb1f2e6d..07995fbd46 100644 --- a/go/apps/api/config.go +++ b/go/apps/api/config.go @@ -1,6 +1,8 @@ package api import ( + "net" + "github.com/unkeyed/unkey/go/pkg/clock" "github.com/unkeyed/unkey/go/pkg/tls" ) @@ -23,8 +25,14 @@ type Config struct { Image string // HttpPort defines the HTTP port for the API server to listen on (default: 7070) + // Used in production deployments. Ignored if Listener is provided. 
HttpPort int + // Listener defines a pre-created network listener for the HTTP server + // If provided, the server will use this listener instead of creating one from HttpPort + // This is intended for testing scenarios where ephemeral ports are needed to avoid conflicts + Listener net.Listener + // Region identifies the geographic region where this node is deployed Region string diff --git a/go/apps/api/integration/harness.go b/go/apps/api/integration/harness.go index 230e2b6b4b..c8f8814561 100644 --- a/go/apps/api/integration/harness.go +++ b/go/apps/api/integration/harness.go @@ -3,6 +3,7 @@ package integration import ( "context" "fmt" + "net" "net/http" "testing" "time" @@ -13,7 +14,6 @@ import ( "github.com/unkeyed/unkey/go/pkg/clock" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/otel/logging" - "github.com/unkeyed/unkey/go/pkg/port" "github.com/unkeyed/unkey/go/pkg/testutil/containers" "github.com/unkeyed/unkey/go/pkg/testutil/seed" ) @@ -36,7 +36,6 @@ type Harness struct { ctx context.Context cancel context.CancelFunc instanceAddrs []string - ports *port.FreePort Seed *seed.Seeder dbDSN string DB db.Database @@ -87,7 +86,6 @@ func New(t *testing.T, config Config) *Harness { t: t, ctx: ctx, cancel: cancel, - ports: port.New(), instanceAddrs: []string{}, Seed: seed.New(t, db), dbDSN: mysqlHostDSN, @@ -124,11 +122,11 @@ func (h *Harness) RunAPI(config ApiConfig) *ApiCluster { // Start each API node as a goroutine for i := 0; i < config.Nodes; i++ { - // Find an available port - portFinder := port.New() - nodePort := portFinder.Get() + // Create ephemeral listener + ln, err := net.Listen("tcp", ":0") + require.NoError(h.t, err, "Failed to create ephemeral listener") - cluster.Addrs[i] = fmt.Sprintf("http://localhost:%d", nodePort) + cluster.Addrs[i] = fmt.Sprintf("http://%s", ln.Addr().String()) // Create API config for this node using host connections mysqlHostCfg := containers.MySQL(h.t) @@ -139,7 +137,7 @@ func (h *Harness) 
RunAPI(config ApiConfig) *ApiCluster { apiConfig := api.Config{ Platform: "test", Image: "test", - HttpPort: nodePort, + Listener: ln, DatabasePrimary: mysqlHostCfg.FormatDSN(), DatabaseReadonlyReplica: "", ClickhouseURL: clickhouseHostDSN, @@ -198,12 +196,13 @@ func (h *Harness) RunAPI(config ApiConfig) *ApiCluster { // Wait for server to start maxAttempts := 30 + healthURL := fmt.Sprintf("http://%s/v2/liveness", ln.Addr().String()) for attempt := 0; attempt < maxAttempts; attempt++ { - resp, err := http.Get(fmt.Sprintf("http://localhost:%d/v2/liveness", nodePort)) + resp, err := http.Get(healthURL) if err == nil { resp.Body.Close() if resp.StatusCode == http.StatusOK { - h.t.Logf("API server %d started on port %d", i, nodePort) + h.t.Logf("API server %d started on %s", i, ln.Addr().String()) break } } @@ -216,6 +215,8 @@ func (h *Harness) RunAPI(config ApiConfig) *ApiCluster { // Register cleanup h.t.Cleanup(func() { cancel() + // Note: Don't call ln.Close() here as the zen server + // will properly close the listener during graceful shutdown }) } diff --git a/go/apps/api/run.go b/go/apps/api/run.go index 5169007fdf..185ffc875a 100644 --- a/go/apps/api/run.go +++ b/go/apps/api/run.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log/slog" + "net" "runtime/debug" "time" @@ -110,7 +111,11 @@ func Run(ctx context.Context, cfg Config) error { return fmt.Errorf("unable to start prometheus: %w", promErr) } go func() { - promListenErr := prom.Listen(ctx, fmt.Sprintf(":%d", cfg.PrometheusPort)) + promListener, err := net.Listen("tcp", fmt.Sprintf(":%d", cfg.PrometheusPort)) + if err != nil { + panic(err) + } + promListenErr := prom.Serve(ctx, promListener) if promListenErr != nil { panic(promListenErr) } @@ -222,16 +227,22 @@ func Run(ctx context.Context, cfg Config) error { Caches: caches, Vault: vaultSvc, }) + if cfg.Listener == nil { + // Create listener from HttpPort (production) + cfg.Listener, err = net.Listen("tcp", fmt.Sprintf(":%d", cfg.HttpPort)) + if err != nil { 
+ return fmt.Errorf("unable to listen on port %d: %w", cfg.HttpPort, err) + } + } go func() { - listenErr := srv.Listen(ctx, fmt.Sprintf(":%d", cfg.HttpPort)) - if listenErr != nil { - panic(listenErr) + serveErr := srv.Serve(ctx, cfg.Listener) + if serveErr != nil { + panic(serveErr) } - }() - // Wait for signals and handle shutdown - logger.Info("API server started successfully") + }() + logger.Info("API server started successfully") // Wait for either OS signals or context cancellation, then shutdown if err := shutdowns.WaitForSignal(ctx, time.Minute); err != nil { diff --git a/go/cmd/api/main.go b/go/cmd/api/main.go index ed28550263..e9a3ce90fb 100644 --- a/go/cmd/api/main.go +++ b/go/cmd/api/main.go @@ -208,7 +208,6 @@ func action(ctx context.Context, cmd *cli.Command) error { // Basic configuration Platform: cmd.String("platform"), Image: cmd.String("image"), - HttpPort: cmd.Int("http-port"), Region: cmd.String("region"), // Database configuration @@ -231,6 +230,10 @@ func action(ctx context.Context, cmd *cli.Command) error { Clock: clock.New(), TestMode: cmd.Bool("test-mode"), + // HTTP configuration + HttpPort: cmd.Int("http-port"), + Listener: nil, // Production uses HttpPort + // Vault configuration VaultMasterKeys: cmd.StringSlice("vault-master-keys"), VaultS3: vaultS3Config, diff --git a/go/pkg/port/doc.go b/go/pkg/port/doc.go deleted file mode 100644 index 0483e3b769..0000000000 --- a/go/pkg/port/doc.go +++ /dev/null @@ -1,37 +0,0 @@ -// Package port provides utilities for finding and managing available network ports. -// -// This package is particularly useful for testing scenarios where multiple -// services need to run on unique ports without conflicting with each other -// or with existing services. It safely locates available ports through actual -// network binding and offers mechanisms to track allocated ports to prevent -// reuse within the same process.
-// -// The implementation uses a combination of random port selection and actual -// TCP socket binding to verify availability. This approach is more reliable -// than just checking if a port is currently in use, as it accounts for -// ports that may be temporarily unavailable or restricted by the operating system. -// -// Basic usage: -// -// // Create a port finder -// finder := port.New() -// -// // Get an available port -// port := finder.Get() -// -// // Use the port for your service -// listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) -// -// // Or for testing multiple services: -// port1 := finder.Get() -// port2 := finder.Get() -// port3 := finder.Get() -// -// The package tracks ports it has assigned within the current process -// to ensure the same port isn't returned twice, even if the port hasn't -// been bound yet. -// -// Note that port availability is only guaranteed at the moment Get() is called. -// If there is a delay between getting the port and binding to it, another -// process could potentially bind to that port in the meantime. -package port diff --git a/go/pkg/port/free.go b/go/pkg/port/free.go deleted file mode 100644 index 86e5be048b..0000000000 --- a/go/pkg/port/free.go +++ /dev/null @@ -1,128 +0,0 @@ -package port - -import ( - "fmt" - "math/rand/v2" - "net" - "sync" -) - -// FreePort provides utilities for finding available network ports. -// It manages a pool of assigned ports to prevent the same port from -// being returned multiple times within the same process. -type FreePort struct { - mu sync.RWMutex - min int - max int - attempts int - - // The caller may request multiple ports without binding them immediately - // so we need to keep track of which ports are assigned. - assigned map[int]bool -} - -// New creates a new FreePort instance for finding available ports. -// The returned instance keeps track of ports it has assigned to prevent -// returning the same port twice, even if the actual binding hasn't occurred. 
-// -// By default, ports are selected from the range 10000-65535, which falls -// within the standard range for ephemeral/private ports. -// -// Example: -// -// // Create a new port finder -// portFinder := port.New() -// -// // Get multiple available ports -// httpPort := portFinder.Get() -// grpcPort := portFinder.Get() -// metricsPort := portFinder.Get() -// -// fmt.Printf("Running HTTP on port %d, gRPC on port %d, metrics on port %d\n", -// httpPort, grpcPort, metricsPort) -func New() *FreePort { - return &FreePort{ - min: 10000, - max: 65535, - attempts: 10, - assigned: map[int]bool{}, - mu: sync.RWMutex{}, - } -} - -// Get returns an available TCP port number. -// The port is guaranteed to be available at the time of the call, -// and will not be returned again by the same FreePort instance. -// -// This method will attempt to find an available port by: -// 1. Selecting a random port in the range 10000-65535 -// 2. Checking that the port hasn't already been assigned by this instance -// 3. Verifying availability by attempting to bind to 127.0.0.1 on that port -// 4. Marking the port as assigned to prevent future reuse -// -// If no available port can be found after multiple attempts, Get will panic. -// For cases where error handling is preferred over panicking, use GetWithError. -// -// Example: -// -// finder := port.New() -// serverPort := finder.Get() -// -// // Start your server on this port -// server := &http.Server{ -// Addr: fmt.Sprintf(":%d", serverPort), -// Handler: mux, -// } -// server.ListenAndServe() -func (f *FreePort) Get() int { - port, err := f.GetWithError() - if err != nil { - panic(err) - } - - return port -} - -// GetWithError returns an available TCP port number or an error if no port -// could be found after multiple attempts. -// -// This method works the same as Get() but returns an error instead of -// panicking when no available ports can be found. 
This is preferred in -// production code where error handling is more appropriate than panicking. -// -// Example: -// -// finder := port.New() -// port, err := finder.GetWithError() -// if err != nil { -// log.Fatalf("Failed to find available port: %v", err) -// } -// -// // Use the port -// listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) -func (f *FreePort) GetWithError() (int, error) { - f.mu.Lock() - defer f.mu.Unlock() - - for i := 0; i < f.attempts; i++ { - - // nolint:gosec - // This isn't cryptography - port := rand.IntN(f.max-f.min) + f.min - if f.assigned[port] { - continue - } - - ln, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: port, Zone: ""}) - if err != nil { - continue - } - err = ln.Close() - if err != nil { - return -1, err - } - f.assigned[port] = true - return port, nil - } - return -1, fmt.Errorf("could not find a free port, maybe increase attempts?") -} diff --git a/go/pkg/zen/README.md b/go/pkg/zen/README.md index bfda47d98b..4811407c1c 100644 --- a/go/pkg/zen/README.md +++ b/go/pkg/zen/README.md @@ -38,6 +38,7 @@ import ( "context" "log" "log/slog" + "net" "net/http" "github.com/unkeyed/unkey/go/pkg/zen" @@ -141,7 +142,14 @@ func main() { logger.Info("starting server", "address", ":8080", ) - err = server.Listen(context.Background(), ":8080") + + // Create a listener + listener, err := net.Listen("tcp", ":8080") + if err != nil { + log.Fatalf("failed to create listener: %v", err) + } + + err = server.Serve(context.Background(), listener) if err != nil { logger.Error("server error", slog.String("error", err.Error())) } @@ -158,6 +166,7 @@ package main import ( "context" "log" + "net" "github.com/unkeyed/unkey/go/pkg/tls" "github.com/unkeyed/unkey/go/pkg/zen" @@ -184,9 +193,15 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + // Create a listener for HTTPS + listener, err := net.Listen("tcp", ":443") + if err != nil { + log.Fatalf("failed to create listener: 
%v", err) + } + // Start in a goroutine so you can handle shutdown signals go func() { - if err := server.Listen(ctx, ":443"); err != nil { + if err := server.Serve(ctx, listener); err != nil { log.Fatalf("server error: %v", err) } }() @@ -199,6 +214,32 @@ func main() { } ``` +## Testing with Ephemeral Ports + +For testing, you can use ephemeral ports to let the OS assign an available port automatically. This prevents port conflicts in testing environments: + +```go +import "net" + +// Create an ephemeral listener; the OS assigns a free port +ln, err := net.Listen("tcp", ":0") +if err != nil { + t.Fatalf("failed to create ephemeral listener: %v", err) +} + +// The assigned address (host:port) is available from the listener +addr := ln.Addr().String() +t.Logf("listening on %s", addr) + +// Start the server +go server.Serve(ctx, ln) + +// Make requests to the server +resp, err := http.Get(fmt.Sprintf("http://%s/test", addr)) +``` + +This approach is especially useful for concurrent tests where multiple servers need to run simultaneously without conflicting ports. + ## Working with OpenAPI Validation Zen works well with a schema-first approach to API design.
Define your OpenAPI specification first, then use it for validation: @@ -228,8 +269,13 @@ Zen provides built-in support for graceful shutdown through context cancellation // Create a context that can be cancelled ctx, cancel := context.WithCancel(context.Background()) -// Start the server with this context -go server.Listen(ctx, ":8080") +// Create a listener and start the server with this context +listener, err := net.Listen("tcp", ":8080") +if err != nil { + log.Fatalf("failed to create listener: %v", err) +} + +go server.Serve(ctx, listener) // When you need to shut down (e.g., on SIGTERM): cancel() diff --git a/go/pkg/zen/doc.go b/go/pkg/zen/doc.go index 88faf29ddc..fb12bef9c5 100644 --- a/go/pkg/zen/doc.go +++ b/go/pkg/zen/doc.go @@ -49,8 +49,12 @@ // route, // ) // -// // Start the server -// err = server.Listen(ctx, ":8080") +// // Create a listener and start the server +// listener, err := net.Listen("tcp", ":8080") +// if err != nil { +// log.Fatalf("failed to create listener: %v", err) +// } +// err = server.Serve(ctx, listener) // // Zen is optimized for building maintainable, observable web services with minimal // external dependencies and strong integration with standard Go libraries. 
diff --git a/go/pkg/zen/server.go b/go/pkg/zen/server.go index cbf0fb34cd..9e47d1b8da 100644 --- a/go/pkg/zen/server.go +++ b/go/pkg/zen/server.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net" "net/http" "sync" "time" @@ -162,7 +163,7 @@ func (s *Server) Flags() Flags { // log.Printf("server stopped: %v", err) // } // }() -func (s *Server) Listen(ctx context.Context, addr string) error { +func (s *Server) Serve(ctx context.Context, ln net.Listener) error { s.mu.Lock() if s.isListening { s.logger.Warn("already listening") @@ -172,8 +173,6 @@ func (s *Server) Listen(ctx context.Context, addr string) error { s.isListening = true s.mu.Unlock() - s.srv.Addr = addr - // Set up context handling for graceful shutdown serverCtx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -194,20 +193,19 @@ func (s *Server) Listen(ctx context.Context, addr string) error { // Server stopped on its own } }() - var err error // Check if TLS should be used if s.tlsConfig != nil { - s.logger.Info("listening", "srv", "https", "addr", addr) + s.logger.Info("listening", "srv", "https", "addr", ln.Addr().String()) s.srv.TLSConfig = s.tlsConfig // ListenAndServeTLS with empty strings will use the certificates from TLSConfig - err = s.srv.ListenAndServeTLS("", "") + err = s.srv.ServeTLS(ln, "", "") } else { - s.logger.Info("listening", "srv", "http", "addr", addr) - err = s.srv.ListenAndServe() + s.logger.Info("listening", "srv", "http", "addr", ln.Addr().String()) + err = s.srv.Serve(ln) } // Cancel the server context since the server has stopped diff --git a/go/pkg/zen/server_tls_test.go b/go/pkg/zen/server_tls_test.go index 1a0f4fc96a..7abfedf180 100644 --- a/go/pkg/zen/server_tls_test.go +++ b/go/pkg/zen/server_tls_test.go @@ -98,16 +98,12 @@ func TestServerWithTLS(t *testing.T) { }) server.RegisterRoute([]Middleware{}, testRoute) - // Create a net.Listener to determine the port - ln, err := net.Listen("tcp", "localhost:0") - require.NoError(t, err, "Failed to 
create listener") + // Create ephemeral listener + ln, err := net.Listen("tcp", ":0") + require.NoError(t, err, "Failed to create ephemeral listener") - // Get the assigned port - _, portStr, err := net.SplitHostPort(ln.Addr().String()) - require.NoError(t, err, "Failed to get port") - - // Modify server to use our listener's port - addr := "localhost:" + portStr + // Get the address for the test client + addr := ln.Addr().String() // Start the server in a goroutine serverCtx, serverCancel := context.WithCancel(context.Background()) @@ -117,15 +113,12 @@ func TestServerWithTLS(t *testing.T) { serverReady := make(chan struct{}) go func() { - // Close our listener as server.Listen will create its own - ln.Close() - // Signal that we're about to start the server close(serverReady) - listenErr := server.Listen(serverCtx, addr) + listenErr := server.Serve(serverCtx, ln) if listenErr != nil && listenErr.Error() != "http: Server closed" { - t.Errorf("server.Listen returned: %v", listenErr) + t.Errorf("server.Serve returned: %v", listenErr) } }() defer server.Shutdown(context.Background()) @@ -200,16 +193,12 @@ func TestServerWithTLSContextCancellation(t *testing.T) { }) server.RegisterRoute([]Middleware{}, testRoute) - // Create a net.Listener to determine the port - ln, err := net.Listen("tcp", "localhost:0") - require.NoError(t, err, "Failed to create listener") + // Create ephemeral listener + ln, err := net.Listen("tcp", ":0") + require.NoError(t, err, "Failed to create ephemeral listener") - // Get the assigned port - _, portStr, err := net.SplitHostPort(ln.Addr().String()) - require.NoError(t, err, "Failed to get port") - - // Modify server to use our listener's port - addr := "localhost:" + portStr + // Get the address for the test client + addr := ln.Addr().String() // Create a context that can be canceled serverCtx, serverCancel := context.WithCancel(context.Background()) @@ -222,15 +211,12 @@ func TestServerWithTLSContextCancellation(t *testing.T) { // Start the 
server in a goroutine go func() { - // Close our listener as server.Listen will create its own - ln.Close() - // Signal that we're about to start the server close(serverReady) - listenErr := server.Listen(serverCtx, addr) + listenErr := server.Serve(serverCtx, ln) if listenErr != nil && listenErr.Error() != "http: Server closed" { - t.Errorf("server.Listen returned: %v", listenErr) + t.Errorf("server.Serve returned: %v", listenErr) } // Signal that the server has exited From dfbeb3c3e64a4118ba83d9615ebe71c2d169eca2 Mon Sep 17 00:00:00 2001 From: chronark Date: Thu, 17 Jul 2025 14:07:27 +0200 Subject: [PATCH 6/6] ci: do stuff --- .github/workflows/job_test_api_local.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/job_test_api_local.yaml b/.github/workflows/job_test_api_local.yaml index 3b59dba63a..ea7192951a 100644 --- a/.github/workflows/job_test_api_local.yaml +++ b/.github/workflows/job_test_api_local.yaml @@ -3,6 +3,7 @@ on: workflow_call: permissions: contents: read + jobs: test: name: API Test Local @@ -15,7 +16,7 @@ jobs: run: rm -rf /opt/hostedtoolcache - name: Run containers - run: docker compose -f ./deployment/docker-compose.yaml up -d + run: docker compose -f ./deployment/docker-compose.yaml up mysql planetscale agent s3 chproxy api -d - name: Install uses: ./.github/actions/install