diff --git a/go/apps/ctrl/config.go b/go/apps/ctrl/config.go index 6bb1e7241d..3b1b4278b7 100644 --- a/go/apps/ctrl/config.go +++ b/go/apps/ctrl/config.go @@ -32,12 +32,36 @@ type CloudflareConfig struct { ApiToken string } +type Route53Config struct { + // Enables DNS-01 challenges using AWS Route53 + Enabled bool + + // AccessKeyID is the AWS access key ID + AccessKeyID string + + // SecretAccessKey is the AWS secret access key + SecretAccessKey string + + // Region is the AWS region (e.g., "us-east-1") + Region string + + // HostedZoneID bypasses zone auto-discovery. Required when domains have CNAMEs + // that confuse the zone lookup (e.g., wildcard CNAMEs to load balancers). + HostedZoneID string +} + type AcmeConfig struct { // Enables ACME challenges for TLS certificates Enabled bool - // Enables DNS-01 challenges using Cloudflare + // EmailDomain is the domain used for ACME account emails (e.g., "unkey.com") + EmailDomain string + + // Cloudflare enables DNS-01 challenges using Cloudflare Cloudflare CloudflareConfig + + // Route53 enables DNS-01 challenges using AWS Route53 + Route53 Route53Config } type RestateConfig struct { @@ -206,6 +230,17 @@ func (c Config) Validate() error { } } + // Validate Route53 configuration if enabled + if c.Acme.Enabled && c.Acme.Route53.Enabled { + if err := assert.All( + assert.NotEmpty(c.Acme.Route53.AccessKeyID, "route53 access key ID is required when route53 is enabled"), + assert.NotEmpty(c.Acme.Route53.SecretAccessKey, "route53 secret access key is required when route53 is enabled"), + assert.NotEmpty(c.Acme.Route53.Region, "route53 region is required when route53 is enabled"), + ); err != nil { + return err + } + } + if err := assert.NotEmpty(c.ClickhouseURL, "ClickhouseURL is required"); err != nil { return err } diff --git a/go/apps/ctrl/internal/caches/caches.go b/go/apps/ctrl/internal/caches/caches.go new file mode 100644 index 0000000000..70b3e3fb90 --- /dev/null +++ b/go/apps/ctrl/internal/caches/caches.go @@ -0,0 +1,58 @@ +package caches + +import ( + "time" + + "github.com/unkeyed/unkey/go/pkg/cache" + "github.com/unkeyed/unkey/go/pkg/clock" + "github.com/unkeyed/unkey/go/pkg/db" + "github.com/unkeyed/unkey/go/pkg/otel/logging" +) + +// Caches holds all shared cache instances for the ctrl application. +type Caches struct { + Domains cache.Cache[string, db.CustomDomain] + Challenges cache.Cache[string, db.AcmeChallenge] +} + +type Config struct { + Logger logging.Logger + Clock clock.Clock +} + +func New(cfg Config) (*Caches, error) { + clk := cfg.Clock + if clk == nil { + clk = clock.New() + } + + domains, err := cache.New(cache.Config[string, db.CustomDomain]{ + Fresh: 5 * time.Minute, + Stale: 10 * time.Minute, + MaxSize: 10000, + Logger: cfg.Logger, + Resource: "domains", + Clock: clk, + }) + if err != nil { + return nil, err + } + + // Short TTL for challenges since they change during ACME flow + challenges, err := cache.New(cache.Config[string, db.AcmeChallenge]{ + Fresh: 10 * time.Second, + Stale: 30 * time.Second, + MaxSize: 1000, + Logger: cfg.Logger, + Resource: "acme_challenges", + Clock: clk, + }) + if err != nil { + return nil, err + } + + return &Caches{ + Domains: domains, + Challenges: challenges, + }, nil +} diff --git a/go/apps/ctrl/run.go b/go/apps/ctrl/run.go index ca04f6e366..3c1fe2757f 100644 --- a/go/apps/ctrl/run.go +++ b/go/apps/ctrl/run.go @@ -3,16 +3,22 @@ package ctrl import ( "bytes" "context" + "database/sql" "fmt" "log/slog" "net/http" + "os" "time" "connectrpc.com/connect" + "github.com/go-acme/lego/v4/challenge" + restate "github.com/restatedev/sdk-go" restateIngress "github.com/restatedev/sdk-go/ingress" restateServer "github.com/restatedev/sdk-go/server" + ctrlCaches "github.com/unkeyed/unkey/go/apps/ctrl/internal/caches" "github.com/unkeyed/unkey/go/apps/ctrl/middleware" "github.com/unkeyed/unkey/go/apps/ctrl/services/acme" + "github.com/unkeyed/unkey/go/apps/ctrl/services/acme/providers" "github.com/unkeyed/unkey/go/apps/ctrl/services/build/backend/depot" "github.com/unkeyed/unkey/go/apps/ctrl/services/build/backend/docker" buildStorage "github.com/unkeyed/unkey/go/apps/ctrl/services/build/storage" @@ -29,11 +35,13 @@ import ( hydrav1 "github.com/unkeyed/unkey/go/gen/proto/hydra/v1" "github.com/unkeyed/unkey/go/gen/proto/krane/v1/kranev1connect" "github.com/unkeyed/unkey/go/pkg/clickhouse" + "github.com/unkeyed/unkey/go/pkg/clock" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/otel" "github.com/unkeyed/unkey/go/pkg/otel/logging" "github.com/unkeyed/unkey/go/pkg/retry" "github.com/unkeyed/unkey/go/pkg/shutdown" + "github.com/unkeyed/unkey/go/pkg/uid" "github.com/unkeyed/unkey/go/pkg/vault" "github.com/unkeyed/unkey/go/pkg/vault/storage" pkgversion "github.com/unkeyed/unkey/go/pkg/version" @@ -47,6 +55,11 @@ func Run(ctx context.Context, cfg Config) error { return fmt.Errorf("bad config: %w", err) } + // Disable CNAME following in lego to prevent it from following wildcard CNAMEs + // (e.g., *.example.com -> loadbalancer.aws.com) and failing Route53 zone lookup. + // Must be set before creating any ACME DNS providers. + os.Setenv("LEGO_DISABLE_CNAME_SUPPORT", "true") + shutdowns := shutdown.New() if cfg.OtelEnabled { @@ -270,11 +283,73 @@ func Run(ctx context.Context, cfg Config) error { DefaultDomain: cfg.DefaultDomain, }))) - restateSrv.Bind(hydrav1.NewCertificateServiceServer(certificate.New(certificate.Config{ + // Initialize shared caches for ACME (needed for verification endpoint regardless of provider config) + caches, cacheErr := ctrlCaches.New(ctrlCaches.Config{ Logger: logger, - DB: database, - Vault: acmeVaultSvc, - }))) + Clock: clock.New(), + }) + if cacheErr != nil { + return fmt.Errorf("failed to create ACME caches: %w", cacheErr) + } + + // Setup ACME challenge providers + var dnsProvider challenge.Provider + var httpProvider challenge.Provider + if cfg.Acme.Enabled { + // HTTP-01 provider for regular (non-wildcard) domains + httpProv, httpErr := providers.NewHTTPProvider(providers.HTTPConfig{ + DB: database, + Logger: logger, + DomainCache: caches.Domains, + }) + if httpErr != nil { + return fmt.Errorf("failed to create HTTP-01 provider: %w", httpErr) + } + httpProvider = httpProv + logger.Info("ACME HTTP-01 provider enabled") + + // DNS-01 provider for wildcard domains (requires DNS provider config) + if cfg.Acme.Cloudflare.Enabled { + cfProvider, cfErr := providers.NewCloudflareProvider(providers.CloudflareConfig{ + DB: database, + Logger: logger, + APIToken: cfg.Acme.Cloudflare.ApiToken, + DomainCache: caches.Domains, + }) + if cfErr != nil { + return fmt.Errorf("failed to create Cloudflare DNS provider: %w", cfErr) + } + dnsProvider = cfProvider + logger.Info("ACME Cloudflare DNS-01 provider enabled for wildcard certs") + } else if cfg.Acme.Route53.Enabled { + r53Provider, r53Err := providers.NewRoute53Provider(providers.Route53Config{ + DB: database, + Logger: logger, + AccessKeyID: cfg.Acme.Route53.AccessKeyID, + SecretAccessKey: cfg.Acme.Route53.SecretAccessKey, + Region: cfg.Acme.Route53.Region, + HostedZoneID: cfg.Acme.Route53.HostedZoneID, + DomainCache: caches.Domains, + }) + if r53Err != nil { + return fmt.Errorf("failed to create Route53 DNS provider: %w", r53Err) + } + dnsProvider = r53Provider + logger.Info("ACME Route53 DNS-01 provider enabled for wildcard certs") + } + } + + // Certificate service needs a longer timeout for ACME DNS-01 challenges + // which can take 5-10 minutes for DNS propagation + restateSrv.Bind(hydrav1.NewCertificateServiceServer(certificate.New(certificate.Config{ + Logger: logger, + DB: database, + Vault: acmeVaultSvc, + EmailDomain: cfg.Acme.EmailDomain, + DefaultDomain: cfg.DefaultDomain, + DNSProvider: dnsProvider, + HTTPProvider: httpProvider, + }), restate.WithInactivityTimeout(15*time.Minute))) restateSrv.Bind(hydrav1.NewProjectServiceServer(projectWorkflow.New(projectWorkflow.Config{ Logger: logger, DB: database, @@ -332,6 +407,29 @@ func Run(ctx context.Context, cfg Config) error { logger.Error("failed to register with Restate after retries", "error", err.Error()) } else { logger.Info("Successfully registered with Restate") + + // Bootstrap wildcard certificate for default domain if ACME is enabled + if cfg.Acme.Enabled && dnsProvider != nil && cfg.DefaultDomain != "" { + bootstrapWildcardDomain(ctx, database, logger, cfg.DefaultDomain) + } + + // Start the certificate renewal cron job if ACME is enabled + // Use Send with idempotency key so multiple restarts don't create duplicate crons + if cfg.Acme.Enabled && dnsProvider != nil { + certClient := hydrav1.NewCertificateServiceIngressClient(restateClient, "global") + _, startErr := certClient.RenewExpiringCertificates().Send( + ctx, + &hydrav1.RenewExpiringCertificatesRequest{ + DaysBeforeExpiry: 30, + }, + restate.WithIdempotencyKey("cert-renewal-cron-startup"), + ) + if startErr != nil { + logger.Warn("failed to start certificate renewal cron", "error", startErr) + } else { + logger.Info("Certificate renewal cron job started") + } + } } }() } @@ -370,8 +468,10 @@ func Run(ctx context.Context, cfg Config) error { }), connectOptions...)) mux.Handle(ctrlv1connect.NewOpenApiServiceHandler(openapi.New(database, logger), connectOptions...)) mux.Handle(ctrlv1connect.NewAcmeServiceHandler(acme.New(acme.Config{ - DB: database, - Logger: logger, + DB: database, + Logger: logger, + DomainCache: caches.Domains, + ChallengeCache: caches.Challenges, }), connectOptions...)) // Configure server @@ -435,3 +535,59 @@ func Run(ctx context.Context, cfg Config) error { logger.Info("Ctrl server shut down successfully") return nil } + +// bootstrapWildcardDomain ensures a wildcard domain and ACME challenge exist for the default domain. +// This allows the renewal cron to automatically issue a wildcard certificate on startup. +func bootstrapWildcardDomain(ctx context.Context, database db.Database, logger logging.Logger, defaultDomain string) { + wildcardDomain := "*." + defaultDomain + + // Check if the wildcard domain already exists + _, err := db.Query.FindCustomDomainByDomain(ctx, database.RO(), wildcardDomain) + if err == nil { + logger.Info("Wildcard domain already exists", "domain", wildcardDomain) + return + } + if !db.IsNotFound(err) { + logger.Error("Failed to check for existing wildcard domain", "error", err, "domain", wildcardDomain) + return + } + + // Create the custom domain record + domainID := uid.New(uid.DomainPrefix) + now := time.Now().UnixMilli() + + // Use "unkey_internal" as the workspace for platform-managed resources + workspaceID := "unkey_internal" + + err = db.Query.UpsertCustomDomain(ctx, database.RW(), db.UpsertCustomDomainParams{ + ID: domainID, + WorkspaceID: workspaceID, + Domain: wildcardDomain, + ChallengeType: db.CustomDomainsChallengeTypeDNS01, + CreatedAt: now, + UpdatedAt: sql.NullInt64{Int64: now, Valid: true}, + }) + if err != nil { + logger.Error("Failed to create wildcard domain", "error", err, "domain", wildcardDomain) + return + } + + // Create the ACME challenge record with status 'waiting' so the renewal cron picks it up + err = db.Query.InsertAcmeChallenge(ctx, database.RW(), db.InsertAcmeChallengeParams{ + WorkspaceID: workspaceID, + DomainID: domainID, + Token: "", + Authorization: "", + Status: db.AcmeChallengesStatusWaiting, + ChallengeType: db.AcmeChallengesChallengeTypeDNS01, + CreatedAt: now, + UpdatedAt: sql.NullInt64{Int64: now, Valid: true}, + ExpiresAt: 0, // Will be set when certificate is issued + }) + if err != nil { + logger.Error("Failed to create ACME challenge for wildcard domain", "error", err, "domain", wildcardDomain) + return + } + + logger.Info("Bootstrapped wildcard domain for certificate issuance", "domain", wildcardDomain) +} diff --git a/go/apps/ctrl/services/acme/certificate_verification.go b/go/apps/ctrl/services/acme/certificate_verification.go index a3b2af4949..9d073fc950 100644 --- a/go/apps/ctrl/services/acme/certificate_verification.go +++ b/go/apps/ctrl/services/acme/certificate_verification.go @@ -6,6 +6,8 @@ import ( "connectrpc.com/connect" ctrlv1 "github.com/unkeyed/unkey/go/gen/proto/ctrl/v1" + "github.com/unkeyed/unkey/go/internal/services/caches" + "github.com/unkeyed/unkey/go/pkg/cache" "github.com/unkeyed/unkey/go/pkg/db" ) @@ -15,27 +17,42 @@ func (s *Service) VerifyCertificate( ) (*connect.Response[ctrlv1.VerifyCertificateResponse], error) { res := connect.NewResponse(&ctrlv1.VerifyCertificateResponse{Authorization: ""}) - domain, err := db.Query.FindCustomDomainByDomain(ctx, s.db.RO(), req.Msg.GetDomain()) - if err != nil { - if db.IsNotFound(err) { - return nil, connect.NewError(connect.CodeNotFound, err) - } - + domainName := req.Msg.GetDomain() + token := req.Msg.GetToken() + + // Look up domain with cache + domain, hit, err := s.domainCache.SWR(ctx, domainName, + func(ctx context.Context) (db.CustomDomain, error) { + return db.Query.FindCustomDomainByDomain(ctx, s.db.RO(), domainName) + }, + caches.DefaultFindFirstOp, + ) + if err != nil && !db.IsNotFound(err) { return nil, connect.NewError(connect.CodeInternal, err) } + if hit == cache.Null || db.IsNotFound(err) { + return nil, connect.NewError(connect.CodeNotFound, errors.New("domain not found")) + } - challenge, err := db.Query.FindAcmeChallengeByToken(ctx, s.db.RO(), db.FindAcmeChallengeByTokenParams{ - WorkspaceID: domain.WorkspaceID, - DomainID: domain.ID, - Token: req.Msg.GetToken(), - }) - if err != nil { - if db.IsNotFound(err) { - return nil, connect.NewError(connect.CodeNotFound, err) - } - + // Look up challenge with cache + // Key format: domainID|token + challengeKey := domain.ID + "|" + token + challenge, hit, err := s.challengeCache.SWR(ctx, challengeKey, + func(ctx context.Context) (db.AcmeChallenge, error) { + return db.Query.FindAcmeChallengeByToken(ctx, s.db.RO(), db.FindAcmeChallengeByTokenParams{ + WorkspaceID: domain.WorkspaceID, + DomainID: domain.ID, + Token: token, + }) + }, + caches.DefaultFindFirstOp, + ) + if err != nil && !db.IsNotFound(err) { return nil, connect.NewError(connect.CodeInternal, err) } + if hit == cache.Null || db.IsNotFound(err) { + return nil, connect.NewError(connect.CodeNotFound, errors.New("challenge not found")) + } if challenge.Authorization == "" { return nil, connect.NewError(connect.CodeNotFound, errors.New("challenge hasn't been issued yet")) diff --git a/go/apps/ctrl/services/acme/pem.go b/go/apps/ctrl/services/acme/pem.go index c659776584..f2707982b5 100644 --- a/go/apps/ctrl/services/acme/pem.go +++ b/go/apps/ctrl/services/acme/pem.go @@ -7,6 +7,21 @@ import ( "fmt" ) +// GetCertificateExpiry parses a PEM-encoded certificate and returns its expiration time as Unix milliseconds. +func GetCertificateExpiry(certPEM []byte) (int64, error) { + block, _ := pem.Decode(certPEM) + if block == nil { + return 0, fmt.Errorf("failed to decode certificate PEM") + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return 0, fmt.Errorf("failed to parse certificate: %w", err) + } + + return cert.NotAfter.UnixMilli(), nil +} + func privateKeyToString(privateKey *ecdsa.PrivateKey) (string, error) { // Marshal the private key to DER format privKeyBytes, err := x509.MarshalECPrivateKey(privateKey) diff --git a/go/apps/ctrl/services/acme/providers/cloudflare_provider.go b/go/apps/ctrl/services/acme/providers/cloudflare_provider.go index 57d1c74996..5cfaf641b4 100644 --- a/go/apps/ctrl/services/acme/providers/cloudflare_provider.go +++ b/go/apps/ctrl/services/acme/providers/cloudflare_provider.go @@ -1,115 +1,38 @@ package providers import ( - "context" - "database/sql" "fmt" "time" - "github.com/go-acme/lego/v4/challenge" "github.com/go-acme/lego/v4/providers/dns/cloudflare" + "github.com/unkeyed/unkey/go/pkg/cache" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/otel/logging" ) -var _ challenge.Provider = (*CloudflareProvider)(nil) -var _ challenge.ProviderTimeout = (*CloudflareProvider)(nil) - -// CloudflareProvider implements the lego challenge.Provider interface for DNS-01 challenges -// It uses Cloudflare DNS to store challenges and tracks them in the database -type CloudflareProvider struct { - db db.Database - logger logging.Logger - provider *cloudflare.DNSProvider - defaultDomain string -} - -type CloudflareProviderConfig struct { - DB db.Database - Logger logging.Logger - APIToken string // Cloudflare API token with Zone:Read, DNS:Edit permissions - DefaultDomain string // Default domain for wildcard certificate handling +type CloudflareConfig struct { + DB db.Database + Logger logging.Logger + APIToken string // Cloudflare API token with Zone:Read, DNS:Edit permissions + DomainCache cache.Cache[string, db.CustomDomain] } -// NewCloudflareProvider creates a new DNS-01 challenge provider using Cloudflare -func NewCloudflareProvider(cfg CloudflareProviderConfig) (*CloudflareProvider, error) { +// NewCloudflareProvider creates a new DNS-01 challenge provider using Cloudflare. +func NewCloudflareProvider(cfg CloudflareConfig) (*Provider, error) { config := cloudflare.NewDefaultConfig() - config.PropagationTimeout = time.Minute * 5 // 5 minutes propagation timeout + config.PropagationTimeout = time.Minute * 5 config.AuthToken = cfg.APIToken - config.TTL = 60 * 10 // 10 minutes TTL for challenge records - provider, err := cloudflare.NewDNSProviderConfig(config) - if err != nil { - return nil, fmt.Errorf("failed to create Cloudflare DNS provider: %w", err) - } - - return &CloudflareProvider{ - db: cfg.DB, - logger: cfg.Logger, - provider: provider, - defaultDomain: cfg.DefaultDomain, - }, nil -} - -// Present creates a DNS TXT record for the ACME challenge -func (p *CloudflareProvider) Present(domain, token, keyAuth string) error { - ctx := context.Background() - - // Find domain in database to track the challenge - // For DNS-01 challenges on the default domain, Let's Encrypt passes the base domain - // but we store the wildcard domain in the database - searchDomain := domain - if domain == p.defaultDomain { - // This is our default domain - look for the wildcard version - searchDomain = "*." + domain - } - - dom, err := db.Query.FindCustomDomainByDomain(ctx, p.db.RO(), searchDomain) - if err != nil { - return fmt.Errorf("failed to find domain %s: %w", searchDomain, err) - } - - p.logger.Info("presenting dns challenge", "domain", domain, "token", "[REDACTED]") + config.TTL = 60 * 10 - // Create the DNS challenge record using Cloudflare - err = p.provider.Present(domain, token, keyAuth) + dns, err := cloudflare.NewDNSProviderConfig(config) if err != nil { - return fmt.Errorf("failed to present DNS challenge for domain %s: %w", domain, err) + return nil, fmt.Errorf("failed to create Cloudflare DNS provider: %w", err) } - // Update the database to track the challenge - err = db.Query.UpdateAcmeChallengePending(ctx, p.db.RW(), db.UpdateAcmeChallengePendingParams{ - DomainID: dom.ID, - Status: db.AcmeChallengesStatusPending, - Token: token, - Authorization: keyAuth, - UpdatedAt: sql.NullInt64{Int64: time.Now().UnixMilli(), Valid: true}, + return NewProvider(ProviderConfig{ + DB: cfg.DB, + Logger: cfg.Logger, + DNS: dns, + DomainCache: cfg.DomainCache, }) - - if err != nil { - // Don't cleanup DNS record - Let's Encrypt still needs it for validation - // The DNS record will be cleaned up later in CleanUp() regardless of success/failure - return fmt.Errorf("failed to store challenge for domain %s: %w", domain, err) - } - - p.logger.Info("dns challenge presented successfully", "domain", domain) - - return nil -} - -// CleanUp removes the DNS TXT record and updates the database -func (p *CloudflareProvider) CleanUp(domain, token, keyAuth string) error { - p.logger.Info("cleaning up dns challenge", "domain", domain) - - // Clean up the DNS record first - err := p.provider.CleanUp(domain, token, keyAuth) - if err != nil { - p.logger.Warn("failed to clean up dns challenge record", "error", err, "domain", domain) - } - - return nil -} - -// Timeout returns the timeout and polling interval for the DNS challenge -func (p *CloudflareProvider) Timeout() (timeout, interval time.Duration) { - return p.provider.Timeout() } diff --git a/go/apps/ctrl/services/acme/providers/http_provider.go b/go/apps/ctrl/services/acme/providers/http_provider.go index dfdfac225a..192a4e9a8d 100644 --- a/go/apps/ctrl/services/acme/providers/http_provider.go +++ b/go/apps/ctrl/services/acme/providers/http_provider.go @@ -1,12 +1,10 @@ package providers import ( - "context" - "database/sql" - "fmt" "time" "github.com/go-acme/lego/v4/challenge" + "github.com/unkeyed/unkey/go/pkg/cache" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/otel/logging" ) @@ -14,81 +12,53 @@ import ( var _ challenge.Provider = (*HTTPProvider)(nil) var _ challenge.ProviderTimeout = (*HTTPProvider)(nil) -// HTTPProvider implements the lego challenge.Provider interface for HTTP-01 challenges -// It stores challenges in the database where the gateway can retrieve them -type HTTPProvider struct { +// httpDNS implements the DNSProvider interface for HTTP-01 challenges. +// It stores challenges in the database where the gateway can retrieve them. +type httpDNS struct { db db.Database logger logging.Logger } -type HTTPProviderConfig struct { - DB db.Database - Logger logging.Logger -} - -// NewHTTPProvider creates a new HTTP-01 challenge provider -func NewHTTPProvider(cfg HTTPProviderConfig) *HTTPProvider { - return &HTTPProvider{ - db: cfg.DB, - logger: cfg.Logger, - } -} - -// Present stores the challenge token in the database for the gateway to serve +// Present stores the challenge token in the database for the gateway to serve. // The gateway will intercept requests to /.well-known/acme-challenge/{token} -// and respond with the keyAuth value -func (p *HTTPProvider) Present(domain, token, keyAuth string) error { - ctx := context.Background() - dom, err := db.Query.FindCustomDomainByDomain(ctx, p.db.RO(), domain) - if err != nil { - return fmt.Errorf("failed to find domain %s: %w", domain, err) - } - - // Update the existing challenge record with the token and authorization - err = db.Query.UpdateAcmeChallengePending(ctx, p.db.RW(), db.UpdateAcmeChallengePendingParams{ - DomainID: dom.ID, - Status: db.AcmeChallengesStatusPending, - Token: token, - Authorization: keyAuth, - UpdatedAt: sql.NullInt64{Int64: time.Now().UnixMilli(), Valid: true}, - }) - - if err != nil { - return fmt.Errorf("failed to store challenge for domain %s: %w", domain, err) - } - +// and respond with the keyAuth value. +func (h *httpDNS) Present(domain, token, keyAuth string) error { + h.logger.Info("presenting http-01 challenge", "domain", domain) + // The actual DB update is handled by the generic Provider wrapper return nil } -// CleanUp removes the challenge token from the database after validation -func (p *HTTPProvider) CleanUp(domain, token, keyAuth string) error { - ctx := context.Background() - - dom, err := db.Query.FindCustomDomainByDomain(ctx, p.db.RO(), domain) - if err != nil { - return fmt.Errorf("failed to find domain %s during cleanup: %w", domain, err) - } +// CleanUp is a no-op for HTTP-01 - the token remains in DB until overwritten +func (h *httpDNS) CleanUp(domain, token, keyAuth string) error { + h.logger.Info("cleaning up http-01 challenge", "domain", domain) + return nil +} - // Clear the token and authorization so the gateway stops serving the challenge - // Don't change the status - it should remain as set by the certificate workflow - err = db.Query.ClearAcmeChallengeTokens(ctx, p.db.RW(), db.ClearAcmeChallengeTokensParams{ - Token: "", // Clear token - Authorization: "", // Clear authorization - UpdatedAt: sql.NullInt64{Int64: time.Now().UnixMilli(), Valid: true}, - DomainID: dom.ID, - }) +// Timeout returns custom timeout and check interval for HTTP-01 challenges. +// HTTP challenges typically resolve faster than DNS. +func (h *httpDNS) Timeout() (time.Duration, time.Duration) { + return 90 * time.Second, 3 * time.Second +} - if err != nil { - p.logger.Warn("failed to clean up challenge token", "error", err, "domain", domain) - } +// HTTPProvider wraps httpDNS with the generic Provider for DB tracking and caching. +// This is a type alias to make it clear this is an HTTP-01 provider. +type HTTPProvider = Provider - return nil +type HTTPConfig struct { + DB db.Database + Logger logging.Logger + DomainCache cache.Cache[string, db.CustomDomain] } -// Timeout returns custom timeout and check interval for HTTP-01 challenges -// Returns (timeout, interval) - how long to wait and time between checks -func (p *HTTPProvider) Timeout() (time.Duration, time.Duration) { - // HTTP challenges typically resolve faster than DNS, but give some buffer - // 90 seconds timeout, 3 second check interval - return 90 * time.Second, 3 * time.Second +// NewHTTPProvider creates a new HTTP-01 challenge provider. +func NewHTTPProvider(cfg HTTPConfig) (*HTTPProvider, error) { + return NewProvider(ProviderConfig{ + DB: cfg.DB, + Logger: cfg.Logger, + DNS: &httpDNS{ + db: cfg.DB, + logger: cfg.Logger, + }, + DomainCache: cfg.DomainCache, + }) } diff --git a/go/apps/ctrl/services/acme/providers/provider.go b/go/apps/ctrl/services/acme/providers/provider.go new file mode 100644 index 0000000000..af775d262d --- /dev/null +++ b/go/apps/ctrl/services/acme/providers/provider.go @@ -0,0 +1,141 @@ +package providers + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" + + "github.com/go-acme/lego/v4/challenge" + "github.com/unkeyed/unkey/go/internal/services/caches" + "github.com/unkeyed/unkey/go/pkg/assert" + "github.com/unkeyed/unkey/go/pkg/cache" + "github.com/unkeyed/unkey/go/pkg/db" + "github.com/unkeyed/unkey/go/pkg/otel/logging" +) + +// ErrDomainNotFound is returned when a domain is not found in the database. +var ErrDomainNotFound = errors.New("domain not found") + +var _ challenge.Provider = (*Provider)(nil) +var _ challenge.ProviderTimeout = (*Provider)(nil) + +// DNSProvider is the interface for underlying DNS operations. +// Both lego's cloudflare.DNSProvider and route53.DNSProvider implement this. +type DNSProvider interface { + Present(domain, token, keyAuth string) error + CleanUp(domain, token, keyAuth string) error + Timeout() (timeout, interval time.Duration) +} + +// Provider wraps a DNS provider with database tracking and caching. +// It implements the lego challenge.Provider interface. +type Provider struct { + db db.Database + logger logging.Logger + dns DNSProvider + cache cache.Cache[string, db.CustomDomain] +} + +type ProviderConfig struct { + DB db.Database + Logger logging.Logger + DNS DNSProvider + DomainCache cache.Cache[string, db.CustomDomain] +} + +// NewProvider creates a new Provider that wraps a DNS provider with database tracking. +func NewProvider(cfg ProviderConfig) (*Provider, error) { + err := assert.All( + assert.NotNilAndNotZero(cfg.DB, "db is required"), + assert.NotNilAndNotZero(cfg.Logger, "logger is required"), + assert.NotNilAndNotZero(cfg.DNS, "dns provider is required"), + assert.NotNilAndNotZero(cfg.DomainCache, "domain cache is required"), + ) + if err != nil { + return nil, err + } + + return &Provider{ + db: cfg.DB, + logger: cfg.Logger, + dns: cfg.DNS, + cache: cfg.DomainCache, + }, nil +} + +// resolveDomain finds the best matching custom domain for a given domain. +// It queries for both the exact domain and wildcard (*.domain) in a single query, +// preferring exact matches. +func (p *Provider) resolveDomain(ctx context.Context, domain string) (db.CustomDomain, error) { + wildcardDomain := "*." + domain + cacheKey := domain + "|" + wildcardDomain + + dom, hit, err := p.cache.SWR(ctx, cacheKey, + func(ctx context.Context) (db.CustomDomain, error) { + return db.Query.FindCustomDomainByDomainOrWildcard(ctx, p.db.RO(), db.FindCustomDomainByDomainOrWildcardParams{ + Domain: domain, + Domain_2: wildcardDomain, + Domain_3: domain, + }) + }, + caches.DefaultFindFirstOp, + ) + if err != nil { + return db.CustomDomain{}, err + } + if hit == cache.Null { + return db.CustomDomain{}, ErrDomainNotFound + } + return dom, nil +} + +// Present creates a DNS TXT record for the ACME challenge and tracks it in the database. +func (p *Provider) Present(domain, token, keyAuth string) error { + ctx := context.Background() + + // Find domain - tries exact match first, then wildcard (*.domain) + dom, err := p.resolveDomain(ctx, domain) + if err != nil { + return fmt.Errorf("failed to find domain %s: %w", domain, err) + } + + p.logger.Info("presenting dns challenge", "domain", domain, "matched", dom.Domain) + + err = p.dns.Present(domain, token, keyAuth) + if err != nil { + return fmt.Errorf("failed to present DNS challenge for domain %s: %w", domain, err) + } + + err = db.Query.UpdateAcmeChallengePending(ctx, p.db.RW(), db.UpdateAcmeChallengePendingParams{ + DomainID: dom.ID, + Status: db.AcmeChallengesStatusPending, + Token: token, + Authorization: keyAuth, + UpdatedAt: sql.NullInt64{Int64: time.Now().UnixMilli(), Valid: true}, + }) + if err != nil { + return fmt.Errorf("failed to store challenge for domain %s: %w", domain, err) + } + + p.logger.Info("dns challenge presented successfully", "domain", domain) + return nil +} + +// CleanUp removes the DNS TXT record. +func (p *Provider) CleanUp(domain, token, keyAuth string) error { + p.logger.Info("cleaning up dns challenge", "domain", domain) + + err := p.dns.CleanUp(domain, token, keyAuth) + if err != nil { + p.logger.Warn("failed to clean up dns challenge record", "error", err, "domain", domain) + } + + return nil +} + +// Timeout returns the timeout and polling interval for the DNS challenge. +func (p *Provider) Timeout() (timeout, interval time.Duration) { + return p.dns.Timeout() +} diff --git a/go/apps/ctrl/services/acme/providers/route53_provider.go b/go/apps/ctrl/services/acme/providers/route53_provider.go new file mode 100644 index 0000000000..cbab615502 --- /dev/null +++ b/go/apps/ctrl/services/acme/providers/route53_provider.go @@ -0,0 +1,55 @@ +package providers + +import ( + "fmt" + "time" + + "github.com/go-acme/lego/v4/providers/dns/route53" + "github.com/unkeyed/unkey/go/pkg/cache" + "github.com/unkeyed/unkey/go/pkg/db" + "github.com/unkeyed/unkey/go/pkg/otel/logging" +) + +type Route53Config struct { + DB db.Database + Logger logging.Logger + AccessKeyID string + SecretAccessKey string + Region string + DomainCache cache.Cache[string, db.CustomDomain] + // HostedZoneID bypasses zone auto-discovery. Required when domains have CNAMEs + // that would confuse the zone lookup (e.g., wildcard CNAMEs to load balancers). + HostedZoneID string +} + +// NewRoute53Provider creates a new DNS-01 challenge provider using AWS Route53. +// +// Important: LEGO_DISABLE_CNAME_SUPPORT must be set to "true" before calling this +// function to prevent lego from following wildcard CNAMEs and failing zone lookup. +// This should be done once at application startup (see run.go). +// +// HostedZoneID should be provided to explicitly specify which Route53 zone to use, +// bypassing zone auto-discovery. +func NewRoute53Provider(cfg Route53Config) (*Provider, error) { + + config := route53.NewDefaultConfig() + config.PropagationTimeout = time.Minute * 5 + config.TTL = 60 * 10 // 10 minutes + config.AccessKeyID = cfg.AccessKeyID + config.SecretAccessKey = cfg.SecretAccessKey + config.Region = cfg.Region + config.HostedZoneID = cfg.HostedZoneID + config.WaitForRecordSetsChanged = true + + dns, err := route53.NewDNSProviderConfig(config) + if err != nil { + return nil, fmt.Errorf("failed to create Route53 DNS provider: %w", err) + } + + return NewProvider(ProviderConfig{ + DB: cfg.DB, + Logger: cfg.Logger, + DNS: dns, + DomainCache: cfg.DomainCache, + }) +} diff --git a/go/apps/ctrl/services/acme/service.go b/go/apps/ctrl/services/acme/service.go index 851dbc0a6e..9c0fc4649f 100644 --- a/go/apps/ctrl/services/acme/service.go +++ b/go/apps/ctrl/services/acme/service.go @@ -2,19 +2,24 @@ package acme import ( "github.com/unkeyed/unkey/go/gen/proto/ctrl/v1/ctrlv1connect" + "github.com/unkeyed/unkey/go/pkg/cache" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/otel/logging" ) type Service struct { ctrlv1connect.UnimplementedAcmeServiceHandler - db db.Database - logger logging.Logger + db db.Database + logger logging.Logger + domainCache cache.Cache[string, db.CustomDomain] + challengeCache cache.Cache[string, db.AcmeChallenge] } type Config struct { - DB db.Database - Logger logging.Logger + DB db.Database + Logger logging.Logger + DomainCache cache.Cache[string, db.CustomDomain] + ChallengeCache cache.Cache[string, db.AcmeChallenge] } func New(cfg Config) *Service { @@ -22,5 +27,7 @@ func New(cfg Config) *Service { UnimplementedAcmeServiceHandler: ctrlv1connect.UnimplementedAcmeServiceHandler{}, db: cfg.DB, logger: cfg.Logger, + domainCache: cfg.DomainCache, + challengeCache: cfg.ChallengeCache, } } diff --git a/go/apps/ctrl/services/acme/user.go b/go/apps/ctrl/services/acme/user.go index 73ba79ffc6..62e35de5a6 100644 --- a/go/apps/ctrl/services/acme/user.go +++ b/go/apps/ctrl/services/acme/user.go @@ -21,12 +21,13 @@ import ( type AcmeUser struct { WorkspaceID string + EmailDomain string Registration *registration.Resource key crypto.PrivateKey } func (u *AcmeUser) GetEmail() string { - return fmt.Sprintf("%s@%s", u.WorkspaceID, "unkey.fun") + return fmt.Sprintf("%s@%s", u.WorkspaceID, u.EmailDomain) } func (u AcmeUser) GetRegistration() *registration.Resource { @@ -42,6 +43,7 @@ type UserConfig struct { Logger logging.Logger Vault *vault.Service WorkspaceID string + EmailDomain string // Domain for ACME registration emails (e.g., "unkey.com") } func GetOrCreateUser(ctx context.Context, cfg UserConfig) (*lego.Client, error) { @@ -50,6 +52,7 @@ func GetOrCreateUser(ctx context.Context, cfg UserConfig) (*lego.Client, error) if db.IsNotFound(err) { return register(ctx, cfg) } + return nil, fmt.Errorf("failed to find acme user: %w", err) } resp, err := cfg.Vault.Decrypt(ctx, &vaultv1.DecryptRequest{ @@ -65,19 +68,48 @@ func GetOrCreateUser(ctx context.Context, cfg UserConfig) (*lego.Client, error) return nil, fmt.Errorf("failed to convert private key: %w", err) } - config := lego.NewConfig(&AcmeUser{ - //nolint: exhaustruct - Registration: ®istration.Resource{ + user := &AcmeUser{ + key: key, + WorkspaceID: cfg.WorkspaceID, + EmailDomain: cfg.EmailDomain, + Registration: nil, + } + + // If we have a valid registration URI, use it + if foundUser.RegistrationUri.Valid && foundUser.RegistrationUri.String != "" { + //nolint:exhaustruct // external library type + user.Registration = ®istration.Resource{ URI: foundUser.RegistrationUri.String, - }, - key: key, - WorkspaceID: cfg.WorkspaceID, - }) + } + } + + config := lego.NewConfig(user) client, err := lego.NewClient(config) if err != nil { return nil, fmt.Errorf("failed to create ACME client: %w", err) } + // If user exists but doesn't have a registration URI, complete the registration + if !foundUser.RegistrationUri.Valid || foundUser.RegistrationUri.String == "" { + cfg.Logger.Info("acme user missing registration, completing registration", + "workspace_id", cfg.WorkspaceID, + ) + + reg, regErr := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}) + if regErr != nil { + return nil, fmt.Errorf("failed to complete acme registration: %w", regErr) + } + + user.Registration = reg + + if updateErr := db.Query.UpdateAcmeUserRegistrationURI(ctx, cfg.DB.RW(), db.UpdateAcmeUserRegistrationURIParams{ + ID: foundUser.ID, + RegistrationUri: sql.NullString{Valid: true, String: reg.URI}, + }); updateErr != nil { + cfg.Logger.Warn("failed to persist registration URI", "error", updateErr) + } + } + return client, nil } @@ -91,6 +123,7 @@ func register(ctx context.Context, cfg UserConfig) (*lego.Client, error) { Registration: nil, key: privateKey, WorkspaceID: cfg.WorkspaceID, + EmailDomain: cfg.EmailDomain, } privKeyString, err := privateKeyToString(privateKey) diff --git a/go/apps/ctrl/workflows/certificate/process_challenge_handler.go b/go/apps/ctrl/workflows/certificate/process_challenge_handler.go index d74b8f4d26..0966f77fb0 100644 --- a/go/apps/ctrl/workflows/certificate/process_challenge_handler.go +++ b/go/apps/ctrl/workflows/certificate/process_challenge_handler.go @@ -1,26 +1,28 @@ package certificate import ( + "context" "database/sql" + "fmt" "time" "github.com/go-acme/lego/v4/certificate" + "github.com/go-acme/lego/v4/lego" restate "github.com/restatedev/sdk-go" + "github.com/unkeyed/unkey/go/apps/ctrl/services/acme" hydrav1 "github.com/unkeyed/unkey/go/gen/proto/hydra/v1" + vaultv1 "github.com/unkeyed/unkey/go/gen/proto/vault/v1" "github.com/unkeyed/unkey/go/pkg/db" + "github.com/unkeyed/unkey/go/pkg/retry" "github.com/unkeyed/unkey/go/pkg/uid" ) // EncryptedCertificate holds a certificate and its encrypted private key. type EncryptedCertificate struct { - // Certificate is the PEM-encoded certificate. - Certificate string - - // EncryptedPrivateKey is the encrypted PEM-encoded private key. + CertificateID string + Certificate string EncryptedPrivateKey string - - // ExpiresAt is the certificate expiration time as Unix milliseconds. - ExpiresAt int64 + ExpiresAt int64 } // ProcessChallenge handles the complete ACME certificate challenge flow. @@ -29,20 +31,12 @@ type EncryptedCertificate struct { // an SSL/TLS certificate for a domain. Each step is wrapped in restate.Run for durability, // allowing the workflow to resume from the last completed step if interrupted. // -// The workflow performs these steps: -// 1. Resolve domain - Verify domain exists and belongs to workspace -// 2. Claim challenge - Acquire exclusive lock on the domain challenge -// 3. Setup ACME client - Get or create ACME account (TODO: not yet implemented) -// 4. Obtain certificate - Request certificate from CA (TODO: not yet implemented) -// 5. Persist certificate - Store in DB for gateway access -// 6. Mark verified - Update challenge status with expiry time -// -// Returns status "success" if certificate was issued, "failed" if the ACME challenge -// failed or ACME client setup is not yet implemented. +// Uses the saga pattern: if any step fails after claiming the challenge, the deferred +// compensation marks the challenge as failed. func (s *Service) ProcessChallenge( ctx restate.ObjectContext, req *hydrav1.ProcessChallengeRequest, -) (*hydrav1.ProcessChallengeResponse, error) { +) (resp *hydrav1.ProcessChallengeResponse, err error) { s.logger.Info("starting certificate challenge", "workspace_id", req.GetWorkspaceId(), "domain", req.GetDomain(), @@ -51,7 +45,7 @@ func (s *Service) ProcessChallenge( // Step 1: Resolve domain dom, err := restate.Run(ctx, func(stepCtx restate.RunContext) (db.CustomDomain, error) { return db.Query.FindCustomDomainByDomain(stepCtx, s.db.RO(), req.GetDomain()) - }, restate.WithName("resolving domain")) + }, restate.WithName("resolve domain")) if err != nil { return nil, err } @@ -63,57 +57,24 @@ func (s *Service) ProcessChallenge( Status: db.AcmeChallengesStatusPending, UpdatedAt: sql.NullInt64{Int64: time.Now().UnixMilli(), Valid: true}, }) - }, restate.WithName("acquiring challenge")) + }, restate.WithName("claim challenge")) if err != nil { return nil, err } - // Step 3: Get or create ACME client for this workspace - _, err = restate.Run(ctx, func(stepCtx restate.RunContext) (*certificate.Resource, error) { - // nolint: godox - // TODO: Get ACME client for workspace - // This requires implementing GetOrCreateUser from acme/user.go - // and setting up challenge providers (HTTP-01, DNS-01) - - // For now, return error indicating this needs ACME client setup - return nil, restate.TerminalError( - err, - 500, - ) - }, restate.WithName("setup acme client")) - if err != nil { - _, _ = restate.Run(ctx, func(stepCtx restate.RunContext) (restate.Void, error) { - if updateErr := db.Query.UpdateAcmeChallengeStatus(stepCtx, s.db.RW(), db.UpdateAcmeChallengeStatusParams{ - DomainID: dom.ID, - Status: db.AcmeChallengesStatusFailed, - UpdatedAt: sql.NullInt64{Valid: true, Int64: time.Now().UnixMilli()}, - }); updateErr != nil { - s.logger.Error("failed to update challenge status", "error", updateErr, "domain_id", dom.ID) - } - return restate.Void{}, nil - }, restate.WithName("mark challenge failed")) - return &hydrav1.ProcessChallengeResponse{ - CertificateId: "", - Status: "failed", - }, nil - } - - // Step 4: Obtain or renew certificate - cert, err := restate.Run(ctx, func(stepCtx restate.RunContext) (EncryptedCertificate, error) { - _, err = db.Query.FindCertificateByHostname(stepCtx, s.db.RO(), req.GetDomain()) - if err != nil && !db.IsNotFound(err) { - return EncryptedCertificate{}, err + // Compensation: if anything fails after claiming, mark challenge as failed + defer func() { + if err != nil || (resp != nil && resp.GetStatus() == "failed") { + s.markChallengeFailed(ctx, dom.ID) } + }() - // nolint: godox - // TODO: Implement certificate obtain/renew logic - // This requires the ACME client from step 3 - - return EncryptedCertificate{}, restate.TerminalError( - err, - 500, - ) - }, restate.WithName("obtaining certificate")) + // Step 3: Obtain certificate via DNS-01 challenge + // Note: ACME client creation happens inside obtainCertificate because lego.Client + // cannot be serialized through Restate (has internal pointers) + cert, err := restate.Run(ctx, func(stepCtx restate.RunContext) (EncryptedCertificate, error) { + return s.obtainCertificate(stepCtx, req.GetWorkspaceId(), dom, req.GetDomain()) + }, restate.WithName("obtain certificate")) if err != nil { return &hydrav1.ProcessChallengeResponse{ CertificateId: "", @@ -122,18 +83,9 @@ func (s *Service) ProcessChallenge( } // Step 5: Persist certificate to DB - _, err = restate.Run(ctx, func(stepCtx restate.RunContext) (restate.Void, error) { - now := time.Now().UnixMilli() - return restate.Void{}, db.Query.InsertCertificate(stepCtx, s.db.RW(), db.InsertCertificateParams{ - ID: uid.New(uid.CertificatePrefix), - WorkspaceID: dom.WorkspaceID, - Hostname: req.GetDomain(), - Certificate: cert.Certificate, - EncryptedPrivateKey: cert.EncryptedPrivateKey, - CreatedAt: now, - UpdatedAt: sql.NullInt64{Valid: false, Int64: 0}, - }) - }, restate.WithName("persisting certificate")) + certID, err := restate.Run(ctx, func(stepCtx restate.RunContext) (string, error) { + return s.persistCertificate(stepCtx, dom, req.GetDomain(), cert) + }, restate.WithName("persist certificate")) if err != nil { return nil, err } @@ -146,18 +98,165 @@ func (s *Service) ProcessChallenge( UpdatedAt: sql.NullInt64{Valid: true, Int64: time.Now().UnixMilli()}, DomainID: dom.ID, }) - }, restate.WithName("completing challenge")) + }, restate.WithName("mark verified")) if err != nil { return nil, err } s.logger.Info("certificate challenge completed successfully", "domain", req.GetDomain(), + "certificate_id", certID, "expires_at", cert.ExpiresAt, ) return &hydrav1.ProcessChallengeResponse{ - CertificateId: "", + CertificateId: certID, Status: "success", }, nil } + +// globalAcmeUserID is the fixed user ID for the single global ACME account +const globalAcmeUserID = "acme" + +// isWildcard returns true if the domain starts with "*." +func isWildcard(domain string) bool { + return len(domain) > 2 && domain[0] == '*' && domain[1] == '.' +} + +func (s *Service) getOrCreateAcmeClient(ctx context.Context, domain string) (*lego.Client, error) { + // Use a single global ACME user for all certificates + client, err := acme.GetOrCreateUser(ctx, acme.UserConfig{ + DB: s.db, + Logger: s.logger, + Vault: s.vault, + WorkspaceID: globalAcmeUserID, + EmailDomain: s.emailDomain, + }) + if err != nil { + return nil, fmt.Errorf("failed to get/create ACME user: %w", err) + } + + // Wildcard certificates require DNS-01 challenge + // Regular domains use HTTP-01 (faster, no DNS propagation wait) + if isWildcard(domain) { + if s.dnsProvider == nil { + return nil, fmt.Errorf("DNS provider required for wildcard certificate: %s", domain) + } + if err := client.Challenge.SetDNS01Provider(s.dnsProvider); err != nil { + return nil, fmt.Errorf("failed to set DNS-01 provider: %w", err) + } + s.logger.Info("using DNS-01 challenge for wildcard domain", "domain", domain) + } else { + if s.httpProvider == nil { + return nil, fmt.Errorf("HTTP provider required for certificate: %s", domain) + } + if err := client.Challenge.SetHTTP01Provider(s.httpProvider); err != nil { + return nil, fmt.Errorf("failed to set HTTP-01 provider: %w", err) + } + s.logger.Info("using HTTP-01 challenge for domain", "domain", domain) + } + + return client, nil +} + +func (s *Service) obtainCertificate(ctx context.Context, _ string, dom db.CustomDomain, domain string) (EncryptedCertificate, error) { + s.logger.Info("creating ACME client", "domain", domain) + client, err := s.getOrCreateAcmeClient(ctx, domain) + if err != nil { + return EncryptedCertificate{}, fmt.Errorf("failed to create ACME client: %w", err) + } + s.logger.Info("ACME client created, requesting certificate", "domain", domain) + + // Request certificate from Let's Encrypt with retry and exponential backoff + //nolint:exhaustruct // external library type + request := certificate.ObtainRequest{ + Domains: []string{domain}, + Bundle: true, + } + + var certificates *certificate.Resource + retrier := retry.New( + retry.Attempts(3), + retry.Backoff(func(attempt int) time.Duration { + // Exponential backoff: 30s, 60s, 120s (capped at 5min) + return min(time.Duration(30<<(attempt-1))*time.Second, 5*time.Minute) + }), + ) + + err = retrier.Do(func() error { + var obtainErr error + certificates, obtainErr = client.Certificate.Obtain(request) + return obtainErr + }) + if err != nil { + return EncryptedCertificate{}, fmt.Errorf("failed to obtain certificate after retries: %w", err) + } + + // Parse certificate to get expiration + expiresAt, err := acme.GetCertificateExpiry(certificates.Certificate) + if err != nil { + s.logger.Warn("failed to parse certificate expiry, using default", "error", err) + expiresAt = time.Now().Add(90 * 24 * time.Hour).UnixMilli() + } + + // Encrypt the private key before storage + encryptResp, err := s.vault.Encrypt(ctx, &vaultv1.EncryptRequest{ + Keyring: dom.WorkspaceID, + Data: string(certificates.PrivateKey), + }) + if err != nil { + return EncryptedCertificate{}, fmt.Errorf("failed to encrypt private key: %w", err) + } + + return EncryptedCertificate{ + CertificateID: uid.New(uid.CertificatePrefix), + Certificate: string(certificates.Certificate), + EncryptedPrivateKey: encryptResp.GetEncrypted(), + ExpiresAt: expiresAt, + }, nil +} + +func (s *Service) persistCertificate(ctx context.Context, dom db.CustomDomain, domain string, cert EncryptedCertificate) (string, error) { + now := time.Now().UnixMilli() + + // Check if certificate already exists for this hostname (renewal case) + // If it does, we keep the existing ID; otherwise use the new ID + certID := cert.CertificateID + existingCert, err := db.Query.FindCertificateByHostname(ctx, s.db.RO(), domain) + if err != nil && !db.IsNotFound(err) { + return "", fmt.Errorf("failed to check for existing certificate: %w", err) + } + if err == nil { + // Renewal: keep the existing certificate ID + certID = existingCert.ID + } + + // InsertCertificate uses ON DUPLICATE KEY UPDATE, so this handles both insert and renewal + err = db.Query.InsertCertificate(ctx, s.db.RW(), db.InsertCertificateParams{ + ID: certID, + WorkspaceID: dom.WorkspaceID, + Hostname: domain, + Certificate: cert.Certificate, + EncryptedPrivateKey: cert.EncryptedPrivateKey, + CreatedAt: now, + UpdatedAt: sql.NullInt64{Valid: true, Int64: now}, + }) + if err != nil { + return "", fmt.Errorf("failed to persist certificate: %w", err) + } + + return certID, nil +} + +func (s *Service) markChallengeFailed(ctx restate.ObjectContext, domainID string) { + _, _ = restate.Run(ctx, func(stepCtx restate.RunContext) (restate.Void, error) { + if updateErr := db.Query.UpdateAcmeChallengeStatus(stepCtx, s.db.RW(), db.UpdateAcmeChallengeStatusParams{ + DomainID: domainID, + Status: db.AcmeChallengesStatusFailed, + UpdatedAt: sql.NullInt64{Valid: true, Int64: time.Now().UnixMilli()}, + }); updateErr != nil { + s.logger.Error("failed to update challenge status", "error", updateErr, "domain_id", domainID) + } + return restate.Void{}, nil + }, restate.WithName("mark failed")) +} diff --git a/go/apps/ctrl/workflows/certificate/renew_handler.go b/go/apps/ctrl/workflows/certificate/renew_handler.go new file mode 100644 index 0000000000..326276bec1 --- /dev/null +++ b/go/apps/ctrl/workflows/certificate/renew_handler.go @@ -0,0 +1,104 @@ +package certificate + +import ( + "time" + + restate "github.com/restatedev/sdk-go" + hydrav1 "github.com/unkeyed/unkey/go/gen/proto/hydra/v1" + "github.com/unkeyed/unkey/go/pkg/db" +) + +const ( + // renewalInterval is how often the certificate renewal check runs + renewalInterval = 24 * time.Hour + + // renewalKey is the virtual object key for the singleton renewal job + renewalKey = "global" +) + +// RenewExpiringCertificates scans for certificates expiring soon and triggers renewal. +// This is a self-scheduling Restate cron job - after completing, it schedules itself +// to run again after renewalInterval (24 hours). +// +// To start the cron job, call this handler once with key "global". It will then +// automatically reschedule itself forever. +func (s *Service) RenewExpiringCertificates( + ctx restate.ObjectContext, + req *hydrav1.RenewExpiringCertificatesRequest, +) (*hydrav1.RenewExpiringCertificatesResponse, error) { + s.logger.Info("starting certificate renewal check") + + challengeTypes := []db.AcmeChallengesChallengeType{ + db.AcmeChallengesChallengeTypeDNS01, + db.AcmeChallengesChallengeTypeHTTP01, + } + + // Find all challenges that need processing (waiting or expiring soon) + challenges, err := restate.Run(ctx, func(stepCtx restate.RunContext) ([]db.ListExecutableChallengesRow, error) { + return db.Query.ListExecutableChallenges(stepCtx, s.db.RO(), challengeTypes) + }, restate.WithName("list expiring certificates")) + if err != nil { + return nil, err + } + + s.logger.Info("found certificates to process", "count", len(challenges)) + + var failedDomains []string + renewalsTriggered := int32(0) + + for _, challenge := range challenges { + s.logger.Info("triggering certificate renewal", + "domain", challenge.Domain, + "workspace_id", challenge.WorkspaceID, + ) + + // Trigger the ProcessChallenge workflow for this domain (fire-and-forget) + client := hydrav1.NewCertificateServiceClient(ctx, challenge.Domain) + sendErr := client.ProcessChallenge().Send(&hydrav1.ProcessChallengeRequest{ + WorkspaceId: challenge.WorkspaceID, + Domain: challenge.Domain, + }) + + if sendErr != nil { + s.logger.Warn("failed to trigger renewal", + "domain", challenge.Domain, + "error", sendErr, + ) + failedDomains = append(failedDomains, challenge.Domain) + } else { + renewalsTriggered++ + } + + // Small delay between requests to avoid overwhelming the system + if err := restate.Sleep(ctx, 100*time.Millisecond); err != nil { + return nil, err + } + } + + s.logger.Info("certificate renewal check completed", + "checked", len(challenges), + "triggered", renewalsTriggered, + "failed", len(failedDomains), + ) + + // Schedule next run - this creates the Restate cron pattern + // The job will run again after renewalInterval + // Use idempotency key based on the next run date to prevent duplicate schedules + nextRunDate := time.Now().Add(renewalInterval).Format("2006-01-02") + selfClient := hydrav1.NewCertificateServiceClient(ctx, renewalKey) + selfClient.RenewExpiringCertificates().Send( + &hydrav1.RenewExpiringCertificatesRequest{ + DaysBeforeExpiry: 30, + }, + restate.WithDelay(renewalInterval), + restate.WithIdempotencyKey("cert-renewal-"+nextRunDate), + ) + + s.logger.Info("scheduled next certificate renewal check", "delay", renewalInterval) + + return &hydrav1.RenewExpiringCertificatesResponse{ + CertificatesChecked: int32(len(challenges)), + RenewalsTriggered: renewalsTriggered, + FailedDomains: failedDomains, + }, nil +} diff --git a/go/apps/ctrl/workflows/certificate/service.go b/go/apps/ctrl/workflows/certificate/service.go index 6b0eb7969d..ad1ebec1c5 100644 --- a/go/apps/ctrl/workflows/certificate/service.go +++ b/go/apps/ctrl/workflows/certificate/service.go @@ -1,6 +1,7 @@ package certificate import ( + "github.com/go-acme/lego/v4/challenge" hydrav1 "github.com/unkeyed/unkey/go/gen/proto/hydra/v1" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/otel/logging" @@ -18,9 +19,13 @@ import ( // and rate limit violations. type Service struct { hydrav1.UnimplementedCertificateServiceServer - db db.Database - vault *vault.Service - logger logging.Logger + db db.Database + vault *vault.Service + logger logging.Logger + emailDomain string + defaultDomain string + dnsProvider challenge.Provider // For DNS-01 challenges (wildcard certs) + httpProvider challenge.Provider // For HTTP-01 challenges (regular certs) } var _ hydrav1.CertificateServiceServer = (*Service)(nil) @@ -35,6 +40,18 @@ type Config struct { // Logger for structured logging. Logger logging.Logger + + // EmailDomain is the domain used for ACME account emails (workspace_id@domain) + EmailDomain string + + // DefaultDomain is the base domain for wildcard certificates + DefaultDomain string + + // DNSProvider is the challenge provider for DNS-01 challenges (wildcard certs) + DNSProvider challenge.Provider + + // HTTPProvider is the challenge provider for HTTP-01 challenges (regular certs) + HTTPProvider challenge.Provider } // New creates a new certificate service instance. @@ -44,5 +61,9 @@ func New(cfg Config) *Service { db: cfg.DB, vault: cfg.Vault, logger: cfg.Logger, + emailDomain: cfg.EmailDomain, + defaultDomain: cfg.DefaultDomain, + dnsProvider: cfg.DNSProvider, + httpProvider: cfg.HTTPProvider, } } diff --git a/go/cmd/ctrl/main.go b/go/cmd/ctrl/main.go index 8b47d7cb74..5a1e1d5806 100644 --- a/go/cmd/ctrl/main.go +++ b/go/cmd/ctrl/main.go @@ -113,9 +113,19 @@ var Cmd = &cli.Command{ cli.EnvVar("UNKEY_DEPOT_PROJECT_REGION"), cli.Default("us-east-1")), cli.Bool("acme-enabled", "Enable Let's Encrypt for acme challenges", cli.EnvVar("UNKEY_ACME_ENABLED")), + cli.String("acme-email-domain", "Domain for ACME registration emails (workspace_id@domain)", cli.Default("unkey.com"), cli.EnvVar("UNKEY_ACME_EMAIL_DOMAIN")), + + // Cloudflare DNS provider cli.Bool("acme-cloudflare-enabled", "Enable Cloudflare for wildcard certificates", cli.EnvVar("UNKEY_ACME_CLOUDFLARE_ENABLED")), cli.String("acme-cloudflare-api-token", "Cloudflare API token for Let's Encrypt", cli.EnvVar("UNKEY_ACME_CLOUDFLARE_API_TOKEN")), + // Route53 DNS provider + cli.Bool("acme-route53-enabled", "Enable Route53 for DNS-01 challenges", cli.EnvVar("UNKEY_ACME_ROUTE53_ENABLED")), + cli.String("acme-route53-access-key-id", "AWS access key ID for Route53", cli.EnvVar("UNKEY_ACME_ROUTE53_ACCESS_KEY_ID")), + cli.String("acme-route53-secret-access-key", "AWS secret access key for Route53", cli.EnvVar("UNKEY_ACME_ROUTE53_SECRET_ACCESS_KEY")), + cli.String("acme-route53-region", "AWS region for Route53", cli.Default("us-east-1"), cli.EnvVar("UNKEY_ACME_ROUTE53_REGION")), + cli.String("acme-route53-hosted-zone-id", "Route53 hosted zone ID (bypasses auto-discovery, required when wildcard CNAMEs exist)", cli.EnvVar("UNKEY_ACME_ROUTE53_HOSTED_ZONE_ID")), + cli.String("default-domain", "Default domain for auto-generated hostnames", cli.Default("unkey.app"), cli.EnvVar("UNKEY_DEFAULT_DOMAIN")), // Restate Configuration @@ -216,11 +226,19 @@ func action(ctx context.Context, cmd *cli.Command) error { // Acme configuration Acme: ctrl.AcmeConfig{ - Enabled: cmd.Bool("acme-enabled"), + Enabled: cmd.Bool("acme-enabled"), + EmailDomain: cmd.String("acme-email-domain"), Cloudflare: ctrl.CloudflareConfig{ Enabled: cmd.Bool("acme-cloudflare-enabled"), ApiToken: cmd.String("acme-cloudflare-api-token"), }, + Route53: ctrl.Route53Config{ + Enabled: cmd.Bool("acme-route53-enabled"), + AccessKeyID: cmd.String("acme-route53-access-key-id"), + SecretAccessKey: cmd.String("acme-route53-secret-access-key"), + Region: cmd.String("acme-route53-region"), + HostedZoneID: cmd.String("acme-route53-hosted-zone-id"), + }, }, DefaultDomain: cmd.String("default-domain"), diff --git a/go/gen/proto/hydra/v1/certificate.pb.go b/go/gen/proto/hydra/v1/certificate.pb.go index 49c0a3a2d3..38ad891ea4 100644 --- a/go/gen/proto/hydra/v1/certificate.pb.go +++ b/go/gen/proto/hydra/v1/certificate.pb.go @@ -126,6 +126,111 @@ func (x *ProcessChallengeResponse) GetStatus() string { return "" } +type RenewExpiringCertificatesRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Number of days before expiry to start renewal (default: 30) + DaysBeforeExpiry int32 `protobuf:"varint,1,opt,name=days_before_expiry,json=daysBeforeExpiry,proto3" json:"days_before_expiry,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RenewExpiringCertificatesRequest) Reset() { + *x = RenewExpiringCertificatesRequest{} + mi := &file_hydra_v1_certificate_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RenewExpiringCertificatesRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RenewExpiringCertificatesRequest) ProtoMessage() {} + +func (x *RenewExpiringCertificatesRequest) ProtoReflect() protoreflect.Message { + mi := &file_hydra_v1_certificate_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RenewExpiringCertificatesRequest.ProtoReflect.Descriptor instead. +func (*RenewExpiringCertificatesRequest) Descriptor() ([]byte, []int) { + return file_hydra_v1_certificate_proto_rawDescGZIP(), []int{2} +} + +func (x *RenewExpiringCertificatesRequest) GetDaysBeforeExpiry() int32 { + if x != nil { + return x.DaysBeforeExpiry + } + return 0 +} + +type RenewExpiringCertificatesResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + CertificatesChecked int32 `protobuf:"varint,1,opt,name=certificates_checked,json=certificatesChecked,proto3" json:"certificates_checked,omitempty"` + RenewalsTriggered int32 `protobuf:"varint,2,opt,name=renewals_triggered,json=renewalsTriggered,proto3" json:"renewals_triggered,omitempty"` + FailedDomains []string `protobuf:"bytes,3,rep,name=failed_domains,json=failedDomains,proto3" json:"failed_domains,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RenewExpiringCertificatesResponse) Reset() { + *x = RenewExpiringCertificatesResponse{} + mi := &file_hydra_v1_certificate_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RenewExpiringCertificatesResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RenewExpiringCertificatesResponse) ProtoMessage() {} + +func (x *RenewExpiringCertificatesResponse) ProtoReflect() protoreflect.Message { + mi := &file_hydra_v1_certificate_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RenewExpiringCertificatesResponse.ProtoReflect.Descriptor instead. +func (*RenewExpiringCertificatesResponse) Descriptor() ([]byte, []int) { + return file_hydra_v1_certificate_proto_rawDescGZIP(), []int{3} +} + +func (x *RenewExpiringCertificatesResponse) GetCertificatesChecked() int32 { + if x != nil { + return x.CertificatesChecked + } + return 0 +} + +func (x *RenewExpiringCertificatesResponse) GetRenewalsTriggered() int32 { + if x != nil { + return x.RenewalsTriggered + } + return 0 +} + +func (x *RenewExpiringCertificatesResponse) GetFailedDomains() []string { + if x != nil { + return x.FailedDomains + } + return nil +} + var File_hydra_v1_certificate_proto protoreflect.FileDescriptor const file_hydra_v1_certificate_proto_rawDesc = "" + @@ -136,9 +241,16 @@ const file_hydra_v1_certificate_proto_rawDesc = "" + "\x06domain\x18\x02 \x01(\tR\x06domain\"Y\n" + "\x18ProcessChallengeResponse\x12%\n" + "\x0ecertificate_id\x18\x01 \x01(\tR\rcertificateId\x12\x16\n" + - "\x06status\x18\x02 \x01(\tR\x06status2w\n" + + "\x06status\x18\x02 \x01(\tR\x06status\"P\n" + + " RenewExpiringCertificatesRequest\x12,\n" + + "\x12days_before_expiry\x18\x01 \x01(\x05R\x10daysBeforeExpiry\"\xac\x01\n" + + "!RenewExpiringCertificatesResponse\x121\n" + + "\x14certificates_checked\x18\x01 \x01(\x05R\x13certificatesChecked\x12-\n" + + "\x12renewals_triggered\x18\x02 \x01(\x05R\x11renewalsTriggered\x12%\n" + + "\x0efailed_domains\x18\x03 \x03(\tR\rfailedDomains2\xef\x01\n" + "\x12CertificateService\x12[\n" + - "\x10ProcessChallenge\x12!.hydra.v1.ProcessChallengeRequest\x1a\".hydra.v1.ProcessChallengeResponse\"\x00\x1a\x04\x98\x80\x01\x01B\x99\x01\n" + + "\x10ProcessChallenge\x12!.hydra.v1.ProcessChallengeRequest\x1a\".hydra.v1.ProcessChallengeResponse\"\x00\x12v\n" + + "\x19RenewExpiringCertificates\x12*.hydra.v1.RenewExpiringCertificatesRequest\x1a+.hydra.v1.RenewExpiringCertificatesResponse\"\x00\x1a\x04\x98\x80\x01\x01B\x99\x01\n" + "\fcom.hydra.v1B\x10CertificateProtoP\x01Z6github.com/unkeyed/unkey/go/gen/proto/hydra/v1;hydrav1\xa2\x02\x03HXX\xaa\x02\bHydra.V1\xca\x02\bHydra\\V1\xe2\x02\x14Hydra\\V1\\GPBMetadata\xea\x02\tHydra::V1b\x06proto3" var ( @@ -153,16 +265,20 @@ func file_hydra_v1_certificate_proto_rawDescGZIP() []byte { return file_hydra_v1_certificate_proto_rawDescData } -var file_hydra_v1_certificate_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_hydra_v1_certificate_proto_msgTypes = make([]protoimpl.MessageInfo, 4) var file_hydra_v1_certificate_proto_goTypes = []any{ - (*ProcessChallengeRequest)(nil), // 0: hydra.v1.ProcessChallengeRequest - (*ProcessChallengeResponse)(nil), // 1: hydra.v1.ProcessChallengeResponse + (*ProcessChallengeRequest)(nil), // 0: hydra.v1.ProcessChallengeRequest + (*ProcessChallengeResponse)(nil), // 1: hydra.v1.ProcessChallengeResponse + (*RenewExpiringCertificatesRequest)(nil), // 2: hydra.v1.RenewExpiringCertificatesRequest + (*RenewExpiringCertificatesResponse)(nil), // 3: hydra.v1.RenewExpiringCertificatesResponse } var file_hydra_v1_certificate_proto_depIdxs = []int32{ 0, // 0: hydra.v1.CertificateService.ProcessChallenge:input_type -> hydra.v1.ProcessChallengeRequest - 1, // 1: hydra.v1.CertificateService.ProcessChallenge:output_type -> hydra.v1.ProcessChallengeResponse - 1, // [1:2] is the sub-list for method output_type - 0, // [0:1] is the sub-list for method input_type + 2, // 1: hydra.v1.CertificateService.RenewExpiringCertificates:input_type -> hydra.v1.RenewExpiringCertificatesRequest + 1, // 2: hydra.v1.CertificateService.ProcessChallenge:output_type -> hydra.v1.ProcessChallengeResponse + 3, // 3: hydra.v1.CertificateService.RenewExpiringCertificates:output_type -> hydra.v1.RenewExpiringCertificatesResponse + 2, // [2:4] is the sub-list for method output_type + 0, // [0:2] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name @@ -179,7 +295,7 @@ func file_hydra_v1_certificate_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_hydra_v1_certificate_proto_rawDesc), len(file_hydra_v1_certificate_proto_rawDesc)), NumEnums: 0, - NumMessages: 2, + NumMessages: 4, NumExtensions: 0, NumServices: 1, }, diff --git a/go/gen/proto/hydra/v1/certificate_restate.pb.go b/go/gen/proto/hydra/v1/certificate_restate.pb.go index 50586574c7..940e5bf5b7 100644 --- a/go/gen/proto/hydra/v1/certificate_restate.pb.go +++ b/go/gen/proto/hydra/v1/certificate_restate.pb.go @@ -20,6 +20,10 @@ type CertificateServiceClient interface { // ProcessChallenge handles the complete ACME certificate challenge flow // Key: domain name (ensures only one challenge per domain at a time) ProcessChallenge(opts ...sdk_go.ClientOption) sdk_go.Client[*ProcessChallengeRequest, *ProcessChallengeResponse] + // RenewExpiringCertificates checks for certificates expiring soon and renews them. + // This should be called periodically (e.g., daily via cron). + // Key: "global" (single instance ensures no duplicate renewal runs) + RenewExpiringCertificates(opts ...sdk_go.ClientOption) sdk_go.Client[*RenewExpiringCertificatesRequest, *RenewExpiringCertificatesResponse] } type certificateServiceClient struct { @@ -44,6 +48,14 @@ func (c *certificateServiceClient) ProcessChallenge(opts ...sdk_go.ClientOption) return sdk_go.WithRequestType[*ProcessChallengeRequest](sdk_go.Object[*ProcessChallengeResponse](c.ctx, "hydra.v1.CertificateService", c.key, "ProcessChallenge", cOpts...)) } +func (c *certificateServiceClient) RenewExpiringCertificates(opts ...sdk_go.ClientOption) sdk_go.Client[*RenewExpiringCertificatesRequest, *RenewExpiringCertificatesResponse] { + cOpts := c.options + if len(opts) > 0 { + cOpts = append(append([]sdk_go.ClientOption{}, cOpts...), opts...) + } + return sdk_go.WithRequestType[*RenewExpiringCertificatesRequest](sdk_go.Object[*RenewExpiringCertificatesResponse](c.ctx, "hydra.v1.CertificateService", c.key, "RenewExpiringCertificates", cOpts...)) +} + // CertificateServiceIngressClient is the ingress client API for hydra.v1.CertificateService service. // // This client is used to call the service from outside of a Restate context. @@ -51,6 +63,10 @@ type CertificateServiceIngressClient interface { // ProcessChallenge handles the complete ACME certificate challenge flow // Key: domain name (ensures only one challenge per domain at a time) ProcessChallenge() ingress.Requester[*ProcessChallengeRequest, *ProcessChallengeResponse] + // RenewExpiringCertificates checks for certificates expiring soon and renews them. + // This should be called periodically (e.g., daily via cron). + // Key: "global" (single instance ensures no duplicate renewal runs) + RenewExpiringCertificates() ingress.Requester[*RenewExpiringCertificatesRequest, *RenewExpiringCertificatesResponse] } type certificateServiceIngressClient struct { @@ -72,6 +88,11 @@ func (c *certificateServiceIngressClient) ProcessChallenge() ingress.Requester[* return ingress.NewRequester[*ProcessChallengeRequest, *ProcessChallengeResponse](c.client, c.serviceName, "ProcessChallenge", &c.key, &codec) } +func (c *certificateServiceIngressClient) RenewExpiringCertificates() ingress.Requester[*RenewExpiringCertificatesRequest, *RenewExpiringCertificatesResponse] { + codec := encoding.ProtoJSONCodec + return ingress.NewRequester[*RenewExpiringCertificatesRequest, *RenewExpiringCertificatesResponse](c.client, c.serviceName, "RenewExpiringCertificates", &c.key, &codec) +} + // CertificateServiceServer is the server API for hydra.v1.CertificateService service. // All implementations should embed UnimplementedCertificateServiceServer // for forward compatibility. @@ -81,6 +102,10 @@ type CertificateServiceServer interface { // ProcessChallenge handles the complete ACME certificate challenge flow // Key: domain name (ensures only one challenge per domain at a time) ProcessChallenge(ctx sdk_go.ObjectContext, req *ProcessChallengeRequest) (*ProcessChallengeResponse, error) + // RenewExpiringCertificates checks for certificates expiring soon and renews them. + // This should be called periodically (e.g., daily via cron). + // Key: "global" (single instance ensures no duplicate renewal runs) + RenewExpiringCertificates(ctx sdk_go.ObjectContext, req *RenewExpiringCertificatesRequest) (*RenewExpiringCertificatesResponse, error) } // UnimplementedCertificateServiceServer should be embedded to have @@ -93,6 +118,9 @@ type UnimplementedCertificateServiceServer struct{} func (UnimplementedCertificateServiceServer) ProcessChallenge(ctx sdk_go.ObjectContext, req *ProcessChallengeRequest) (*ProcessChallengeResponse, error) { return nil, sdk_go.TerminalError(fmt.Errorf("method ProcessChallenge not implemented"), 501) } +func (UnimplementedCertificateServiceServer) RenewExpiringCertificates(ctx sdk_go.ObjectContext, req *RenewExpiringCertificatesRequest) (*RenewExpiringCertificatesResponse, error) { + return nil, sdk_go.TerminalError(fmt.Errorf("method RenewExpiringCertificates not implemented"), 501) +} func (UnimplementedCertificateServiceServer) testEmbeddedByValue() {} // UnsafeCertificateServiceServer may be embedded to opt out of forward compatibility for this service. @@ -113,5 +141,6 @@ func NewCertificateServiceServer(srv CertificateServiceServer, opts ...sdk_go.Se sOpts := append([]sdk_go.ServiceDefinitionOption{sdk_go.WithProtoJSON}, opts...) router := sdk_go.NewObject("hydra.v1.CertificateService", sOpts...) router = router.Handler("ProcessChallenge", sdk_go.NewObjectHandler(srv.ProcessChallenge)) + router = router.Handler("RenewExpiringCertificates", sdk_go.NewObjectHandler(srv.RenewExpiringCertificates)) return router } diff --git a/go/go.mod b/go/go.mod index 48af0f9bb3..8c8add94fe 100644 --- a/go/go.mod +++ b/go/go.mod @@ -90,6 +90,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.7.5 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13 // indirect github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.18 // indirect + github.com/aws/aws-sdk-go-v2/service/route53 v1.53.1 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.30.1 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.5 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.39.1 // indirect diff --git a/go/go.sum b/go/go.sum index 1c0cf1e850..abe31ff328 100644 --- a/go/go.sum +++ b/go/go.sum @@ -90,6 +90,8 @@ github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13 h1:kDqdFvMY github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13/go.mod h1:lmKuogqSU3HzQCwZ9ZtcqOc5XGMqtDK7OIc2+DxiUEg= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.18 h1:OS2e0SKqsU2LiJPqL8u9x41tKc6MMEHrWjLVLn3oysg= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.18/go.mod h1:+Yrk+MDGzlNGxCXieljNeWpoZTCQUQVL+Jk9hGGJ8qM= +github.com/aws/aws-sdk-go-v2/service/route53 v1.53.1 h1:R3nSX1hguRy6MnknHiepSvqnnL8ansFwK2hidPesAYU= +github.com/aws/aws-sdk-go-v2/service/route53 v1.53.1/go.mod h1:fmSiB4OAghn85lQgk7XN9l9bpFg5Bm1v3HuaXKytPEw= github.com/aws/aws-sdk-go-v2/service/s3 v1.84.1 h1:RkHXU9jP0DptGy7qKI8CBGsUJruWz0v5IgwBa2DwWcU= github.com/aws/aws-sdk-go-v2/service/s3 v1.84.1/go.mod h1:3xAOf7tdKF+qbb+XpU+EPhNXAdun3Lu1RcDrj8KC24I= github.com/aws/aws-sdk-go-v2/service/sso v1.30.1 h1:0JPwLz1J+5lEOfy/g0SURC9cxhbQ1lIMHMa+AHZSzz0= diff --git a/go/pkg/db/bulk_custom_domain_upsert.sql_generated.go b/go/pkg/db/bulk_custom_domain_upsert.sql_generated.go new file mode 100644 index 0000000000..cd10a9b855 --- /dev/null +++ b/go/pkg/db/bulk_custom_domain_upsert.sql_generated.go @@ -0,0 +1,50 @@ +// Code generated by sqlc bulk insert plugin. DO NOT EDIT. + +package db + +import ( + "context" + "fmt" + "strings" +) + +// bulkUpsertCustomDomain is the base query for bulk insert +const bulkUpsertCustomDomain = `INSERT INTO custom_domains (id, workspace_id, domain, challenge_type, created_at) VALUES %s ON DUPLICATE KEY UPDATE + workspace_id = VALUES(workspace_id), + challenge_type = VALUES(challenge_type), + updated_at = ?` + +// UpsertCustomDomain performs bulk insert in a single query +func (q *BulkQueries) UpsertCustomDomain(ctx context.Context, db DBTX, args []UpsertCustomDomainParams) error { + + if len(args) == 0 { + return nil + } + + // Build the bulk insert query + valueClauses := make([]string, len(args)) + for i := range args { + valueClauses[i] = "(?, ?, ?, ?, ?)" + } + + bulkQuery := fmt.Sprintf(bulkUpsertCustomDomain, strings.Join(valueClauses, ", ")) + + // Collect all arguments + var allArgs []any + for _, arg := range args { + allArgs = append(allArgs, arg.ID) + allArgs = append(allArgs, arg.WorkspaceID) + allArgs = append(allArgs, arg.Domain) + allArgs = append(allArgs, arg.ChallengeType) + allArgs = append(allArgs, arg.CreatedAt) + } + + // Add ON DUPLICATE KEY UPDATE parameters (only once, not per row) + if len(args) > 0 { + allArgs = append(allArgs, args[0].UpdatedAt) + } + + // Execute the bulk insert + _, err := db.ExecContext(ctx, bulkQuery, allArgs...) + return err +} diff --git a/go/pkg/db/custom_domain_find_by_domain_or_wildcard.sql_generated.go b/go/pkg/db/custom_domain_find_by_domain_or_wildcard.sql_generated.go new file mode 100644 index 0000000000..9c2dae3160 --- /dev/null +++ b/go/pkg/db/custom_domain_find_by_domain_or_wildcard.sql_generated.go @@ -0,0 +1,45 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: custom_domain_find_by_domain_or_wildcard.sql + +package db + +import ( + "context" +) + +const findCustomDomainByDomainOrWildcard = `-- name: FindCustomDomainByDomainOrWildcard :one +SELECT id, workspace_id, domain, challenge_type, created_at, updated_at FROM custom_domains +WHERE domain IN (?, ?) +ORDER BY + CASE WHEN domain = ? THEN 0 ELSE 1 END +LIMIT 1 +` + +type FindCustomDomainByDomainOrWildcardParams struct { + Domain string `db:"domain"` + Domain_2 string `db:"domain_2"` + Domain_3 string `db:"domain_3"` +} + +// FindCustomDomainByDomainOrWildcard +// +// SELECT id, workspace_id, domain, challenge_type, created_at, updated_at FROM custom_domains +// WHERE domain IN (?, ?) +// ORDER BY +// CASE WHEN domain = ? THEN 0 ELSE 1 END +// LIMIT 1 +func (q *Queries) FindCustomDomainByDomainOrWildcard(ctx context.Context, db DBTX, arg FindCustomDomainByDomainOrWildcardParams) (CustomDomain, error) { + row := db.QueryRowContext(ctx, findCustomDomainByDomainOrWildcard, arg.Domain, arg.Domain_2, arg.Domain_3) + var i CustomDomain + err := row.Scan( + &i.ID, + &i.WorkspaceID, + &i.Domain, + &i.ChallengeType, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} diff --git a/go/pkg/db/custom_domain_upsert.sql_generated.go b/go/pkg/db/custom_domain_upsert.sql_generated.go new file mode 100644 index 0000000000..bb20b93b6a --- /dev/null +++ b/go/pkg/db/custom_domain_upsert.sql_generated.go @@ -0,0 +1,49 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: custom_domain_upsert.sql + +package db + +import ( + "context" + "database/sql" +) + +const upsertCustomDomain = `-- name: UpsertCustomDomain :exec +INSERT INTO custom_domains (id, workspace_id, domain, challenge_type, created_at) +VALUES (?, ?, ?, ?, ?) +ON DUPLICATE KEY UPDATE + workspace_id = VALUES(workspace_id), + challenge_type = VALUES(challenge_type), + updated_at = ? +` + +type UpsertCustomDomainParams struct { + ID string `db:"id"` + WorkspaceID string `db:"workspace_id"` + Domain string `db:"domain"` + ChallengeType CustomDomainsChallengeType `db:"challenge_type"` + CreatedAt int64 `db:"created_at"` + UpdatedAt sql.NullInt64 `db:"updated_at"` +} + +// UpsertCustomDomain +// +// INSERT INTO custom_domains (id, workspace_id, domain, challenge_type, created_at) +// VALUES (?, ?, ?, ?, ?) +// ON DUPLICATE KEY UPDATE +// workspace_id = VALUES(workspace_id), +// challenge_type = VALUES(challenge_type), +// updated_at = ? +func (q *Queries) UpsertCustomDomain(ctx context.Context, db DBTX, arg UpsertCustomDomainParams) error { + _, err := db.ExecContext(ctx, upsertCustomDomain, + arg.ID, + arg.WorkspaceID, + arg.Domain, + arg.ChallengeType, + arg.CreatedAt, + arg.UpdatedAt, + ) + return err +} diff --git a/go/pkg/db/handle_err_no_rows.go b/go/pkg/db/handle_err_no_rows.go index aae1d973e1..03b00285ba 100644 --- a/go/pkg/db/handle_err_no_rows.go +++ b/go/pkg/db/handle_err_no_rows.go @@ -5,6 +5,8 @@ import ( "errors" ) +// IsNotFound returns true if the error is sql.ErrNoRows. +// Use this for consistent not-found handling across the codebase. func IsNotFound(err error) bool { return errors.Is(err, sql.ErrNoRows) } diff --git a/go/pkg/db/querier_bulk_generated.go b/go/pkg/db/querier_bulk_generated.go index 3383c61f9a..1d92ccbc41 100644 --- a/go/pkg/db/querier_bulk_generated.go +++ b/go/pkg/db/querier_bulk_generated.go @@ -13,6 +13,7 @@ type BulkQuerier interface { InsertAuditLogTargets(ctx context.Context, db DBTX, args []InsertAuditLogTargetParams) error InsertCertificates(ctx context.Context, db DBTX, args []InsertCertificateParams) error InsertClickhouseWorkspaceSettingses(ctx context.Context, db DBTX, args []InsertClickhouseWorkspaceSettingsParams) error + UpsertCustomDomain(ctx context.Context, db DBTX, args []UpsertCustomDomainParams) error InsertDeployments(ctx context.Context, db DBTX, args []InsertDeploymentParams) error InsertDeploymentSteps(ctx context.Context, db DBTX, args []InsertDeploymentStepParams) error InsertEnvironments(ctx context.Context, db DBTX, args []InsertEnvironmentParams) error diff --git a/go/pkg/db/querier_generated.go b/go/pkg/db/querier_generated.go index 120e82978f..210d45d9d1 100644 --- a/go/pkg/db/querier_generated.go +++ b/go/pkg/db/querier_generated.go @@ -177,6 +177,14 @@ type Querier interface { // FROM custom_domains // WHERE domain = ? FindCustomDomainByDomain(ctx context.Context, db DBTX, domain string) (CustomDomain, error) + //FindCustomDomainByDomainOrWildcard + // + // SELECT id, workspace_id, domain, challenge_type, created_at, updated_at FROM custom_domains + // WHERE domain IN (?, ?) + // ORDER BY + // CASE WHEN domain = ? THEN 0 ELSE 1 END + // LIMIT 1 + FindCustomDomainByDomainOrWildcard(ctx context.Context, db DBTX, arg FindCustomDomainByDomainOrWildcardParams) (CustomDomain, error) //FindCustomDomainById // // SELECT @@ -2131,6 +2139,15 @@ type Querier interface { // SET plan = ? // WHERE id = ? UpdateWorkspacePlan(ctx context.Context, db DBTX, arg UpdateWorkspacePlanParams) (sql.Result, error) + //UpsertCustomDomain + // + // INSERT INTO custom_domains (id, workspace_id, domain, challenge_type, created_at) + // VALUES (?, ?, ?, ?, ?) + // ON DUPLICATE KEY UPDATE + // workspace_id = VALUES(workspace_id), + // challenge_type = VALUES(challenge_type), + // updated_at = ? + UpsertCustomDomain(ctx context.Context, db DBTX, arg UpsertCustomDomainParams) error //UpsertEnvironment // // INSERT INTO environments ( diff --git a/go/pkg/db/queries/custom_domain_find_by_domain_or_wildcard.sql b/go/pkg/db/queries/custom_domain_find_by_domain_or_wildcard.sql new file mode 100644 index 0000000000..641d247d71 --- /dev/null +++ b/go/pkg/db/queries/custom_domain_find_by_domain_or_wildcard.sql @@ -0,0 +1,6 @@ +-- name: FindCustomDomainByDomainOrWildcard :one +SELECT * FROM custom_domains +WHERE domain IN (?, ?) +ORDER BY + CASE WHEN domain = ? THEN 0 ELSE 1 END +LIMIT 1; diff --git a/go/pkg/db/queries/custom_domain_upsert.sql b/go/pkg/db/queries/custom_domain_upsert.sql new file mode 100644 index 0000000000..690e90c6ff --- /dev/null +++ b/go/pkg/db/queries/custom_domain_upsert.sql @@ -0,0 +1,7 @@ +-- name: UpsertCustomDomain :exec +INSERT INTO custom_domains (id, workspace_id, domain, challenge_type, created_at) +VALUES (?, ?, ?, ?, ?) +ON DUPLICATE KEY UPDATE + workspace_id = VALUES(workspace_id), + challenge_type = VALUES(challenge_type), + updated_at = ?; diff --git a/go/proto/hydra/v1/certificate.proto b/go/proto/hydra/v1/certificate.proto index 727ec20f03..ef7049e604 100644 --- a/go/proto/hydra/v1/certificate.proto +++ b/go/proto/hydra/v1/certificate.proto @@ -13,6 +13,11 @@ service CertificateService { // ProcessChallenge handles the complete ACME certificate challenge flow // Key: domain name (ensures only one challenge per domain at a time) rpc ProcessChallenge(ProcessChallengeRequest) returns (ProcessChallengeResponse) {} + + // RenewExpiringCertificates checks for certificates expiring soon and renews them. + // This should be called periodically (e.g., daily via cron). + // Key: "global" (single instance ensures no duplicate renewal runs) + rpc RenewExpiringCertificates(RenewExpiringCertificatesRequest) returns (RenewExpiringCertificatesResponse) {} } message ProcessChallengeRequest { @@ -24,3 +29,14 @@ message ProcessChallengeResponse { string certificate_id = 1; string status = 2; // "success", "failed", "pending" } + +message RenewExpiringCertificatesRequest { + // Number of days before expiry to start renewal (default: 30) + int32 days_before_expiry = 1; +} + +message RenewExpiringCertificatesResponse { + int32 certificates_checked = 1; + int32 renewals_triggered = 2; + repeated string failed_domains = 3; +}