diff --git a/client/internal/tunnel/certenroll.go b/client/internal/tunnel/certenroll.go new file mode 100644 index 000000000..00eade70c --- /dev/null +++ b/client/internal/tunnel/certenroll.go @@ -0,0 +1,451 @@ +// Package tunnel provides machine tunnel functionality for Windows pre-login VPN. +package tunnel + +import ( + "context" + "crypto/sha256" + "crypto/x509" + "encoding/hex" + "encoding/pem" + "fmt" + "os" + "strings" + "time" + + log "github.com/sirupsen/logrus" +) + +// CertEnrollmentConfig contains configuration for certificate enrollment. +type CertEnrollmentConfig struct { + // TemplateName is the AD CS certificate template name. + // Default: "NetBirdMachineTunnel" + TemplateName string + + // DomainName is the FQDN of the domain (e.g., "corp.local"). + DomainName string + + // Hostname is the machine hostname (without domain). + Hostname string + + // OutputCertPath is the path to write the enrolled certificate. + OutputCertPath string + + // OutputKeyPath is the path to write the private key. + OutputKeyPath string + + // ValidityCheck enables pre-enrollment validation. + ValidityCheck bool +} + +// CertEnrollmentResult contains the results of certificate enrollment. +type CertEnrollmentResult struct { + // Success indicates if enrollment succeeded. + Success bool + + // CertPath is the path to the enrolled certificate. + CertPath string + + // KeyPath is the path to the private key. + KeyPath string + + // Thumbprint is the SHA-256 thumbprint of the certificate. + Thumbprint string + + // Subject is the certificate subject. + Subject string + + // DNSNames are the SAN DNS names in the certificate. + DNSNames []string + + // NotBefore is the certificate validity start time. + NotBefore time.Time + + // NotAfter is the certificate validity end time. + NotAfter time.Time + + // Error contains any error that occurred. + Error error +} + +// DefaultCertTemplateName is the default AD CS template name. +const DefaultCertTemplateName = "NetBirdMachineTunnel" + +// CertRenewalThreshold is how long before expiry to trigger renewal (30 days). +const CertRenewalThreshold = 30 * 24 * time.Hour + +// MinCertValidity is the minimum acceptable certificate validity (7 days). +const MinCertValidity = 7 * 24 * time.Hour + +// ValidateMachineCertificate validates a machine certificate for use with mTLS. +// It checks: +// - Certificate exists and is readable +// - Certificate is not expired +// - Certificate has valid SAN DNSNames matching expected hostname.domain format +// - Certificate is signed by a trusted CA (optional, if caCert provided) +func ValidateMachineCertificate(certPath string, expectedHostname, expectedDomain string) (*CertEnrollmentResult, error) { + result := &CertEnrollmentResult{ + CertPath: certPath, + } + + // Read certificate file + certPEM, err := os.ReadFile(certPath) + if err != nil { + result.Error = fmt.Errorf("read certificate: %w", err) + return result, result.Error + } + + // Parse PEM block + block, _ := pem.Decode(certPEM) + if block == nil { + result.Error = fmt.Errorf("failed to decode PEM block") + return result, result.Error + } + + // Parse certificate + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + result.Error = fmt.Errorf("parse certificate: %w", err) + return result, result.Error + } + + // Fill in result fields + result.Subject = cert.Subject.String() + result.DNSNames = cert.DNSNames + result.NotBefore = cert.NotBefore + result.NotAfter = cert.NotAfter + result.Thumbprint = ComputeCertThumbprint(cert) + + // Check expiry + now := time.Now() + if now.Before(cert.NotBefore) { + result.Error = fmt.Errorf("certificate not yet valid (starts %s)", cert.NotBefore) + return result, result.Error + } + if now.After(cert.NotAfter) { + result.Error = fmt.Errorf("certificate expired (ended %s)", cert.NotAfter) + return result, result.Error + } + + // Check minimum validity remaining + remaining := cert.NotAfter.Sub(now) + if remaining < MinCertValidity { + log.Warnf("Certificate expires soon: %s remaining", remaining) + } + + // Check SAN DNSNames + if len(cert.DNSNames) == 0 { + result.Error = fmt.Errorf("certificate has no SAN DNSNames") + return result, result.Error + } + + // Validate expected hostname.domain format + expectedFQDN := strings.ToLower(fmt.Sprintf("%s.%s", expectedHostname, expectedDomain)) + foundMatch := false + for _, dnsName := range cert.DNSNames { + if strings.EqualFold(dnsName, expectedFQDN) { + foundMatch = true + break + } + } + + if !foundMatch { + result.Error = fmt.Errorf("certificate SAN DNSNames %v do not match expected %s", cert.DNSNames, expectedFQDN) + return result, result.Error + } + + result.Success = true + log.Infof("Certificate validation passed: %s (expires %s)", result.Subject, cert.NotAfter) + return result, nil +} + +// ComputeCertThumbprint computes the SHA-256 thumbprint of a certificate. +func ComputeCertThumbprint(cert *x509.Certificate) string { + hash := sha256.Sum256(cert.Raw) + return hex.EncodeToString(hash[:]) +} + +// NeedsRenewal checks if a certificate needs renewal. +func NeedsRenewal(certPath string) (bool, error) { + certPEM, err := os.ReadFile(certPath) + if err != nil { + return true, fmt.Errorf("read certificate: %w", err) + } + + block, _ := pem.Decode(certPEM) + if block == nil { + return true, fmt.Errorf("failed to decode PEM block") + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return true, fmt.Errorf("parse certificate: %w", err) + } + + remaining := time.Until(cert.NotAfter) + if remaining < CertRenewalThreshold { + log.Infof("Certificate renewal needed: %s remaining (threshold: %s)", remaining, CertRenewalThreshold) + return true, nil + } + + return false, nil +} + +// GenerateCertEnrollmentScript generates a PowerShell script for AD CS enrollment. +// This script uses certreq.exe which is available on domain-joined Windows machines. +func GenerateCertEnrollmentScript(config *CertEnrollmentConfig) string { + templateName := config.TemplateName + if templateName == "" { + templateName = DefaultCertTemplateName + } + + fqdn := fmt.Sprintf("%s.%s", config.Hostname, config.DomainName) + + script := fmt.Sprintf(`# Certificate Enrollment Script (Generated by NetBird Machine Tunnel) +# Prerequisites: Domain-joined, AD CS available, template "%s" configured + +$ErrorActionPreference = 'Stop' +$hostname = '%s' +$domain = '%s' +$fqdn = '%s' +$templateName = '%s' + +# Paths +$infPath = "$env:TEMP\netbird-certreq.inf" +$reqPath = "$env:TEMP\netbird-certreq.req" +$cerPath = "$env:TEMP\netbird-certreq.cer" +$pfxPath = "$env:TEMP\netbird-certreq.pfx" + +Write-Host "Enrolling machine certificate for: $fqdn" +Write-Host "Using template: $templateName" + +# Step 1: Create INF file for certificate request +$infContent = @" +[NewRequest] +Subject = "CN=$fqdn" +KeySpec = 1 +KeyLength = 2048 +Exportable = TRUE +MachineKeySet = TRUE +SMIME = FALSE +PrivateKeyArchive = FALSE +UserProtected = FALSE +UseExistingKeySet = FALSE +ProviderName = "Microsoft RSA SChannel Cryptographic Provider" +ProviderType = 12 +RequestType = PKCS10 +KeyUsage = 0xa0 +HashAlgorithm = SHA256 + +[EnhancedKeyUsageExtension] +OID = 1.3.6.1.5.5.7.3.2 ; Client Authentication + +[Extensions] +2.5.29.17 = "{text}" +_continue_ = "dns=$fqdn&" + +[RequestAttributes] +CertificateTemplate = $templateName +"@ + +Set-Content -Path $infPath -Value $infContent -Encoding ASCII +Write-Host "Created certificate request INF: $infPath" + +# Step 2: Generate certificate request +Write-Host "Generating certificate request..." +$result = certreq -new -machine $infPath $reqPath 2>&1 +if ($LASTEXITCODE -ne 0) { + throw "certreq -new failed: $result" +} +Write-Host "Created certificate request: $reqPath" + +# Step 3: Submit request to CA +Write-Host "Submitting request to CA..." +$result = certreq -submit -machine -config - $reqPath $cerPath 2>&1 +if ($LASTEXITCODE -ne 0) { + # Try with explicit CA discovery + Write-Host "Trying with CA auto-discovery..." + $result = certreq -submit -machine $reqPath $cerPath 2>&1 + if ($LASTEXITCODE -ne 0) { + throw "certreq -submit failed: $result" + } +} +Write-Host "Received certificate: $cerPath" + +# Step 4: Accept certificate into store +Write-Host "Installing certificate..." +$result = certreq -accept -machine $cerPath 2>&1 +if ($LASTEXITCODE -ne 0) { + throw "certreq -accept failed: $result" +} +Write-Host "Certificate installed to LocalMachine\My" + +# Step 5: Find and export the certificate +$cert = Get-ChildItem Cert:\LocalMachine\My | + Where-Object { $_.Subject -match $fqdn } | + Sort-Object NotAfter -Descending | + Select-Object -First 1 + +if (-not $cert) { + throw "Could not find enrolled certificate in store" +} + +Write-Host "Certificate Details:" +Write-Host " Subject: $($cert.Subject)" +Write-Host " Thumbprint: $($cert.Thumbprint)" +Write-Host " Expires: $($cert.NotAfter)" +Write-Host " DNS Names: $($cert.DnsNameList -join ', ')" + +# Step 6: Export to PEM format (requires OpenSSL or manual conversion) +# For now, output the thumbprint for config update +$thumbprint = $cert.Thumbprint + +# Cleanup temp files +Remove-Item $infPath, $reqPath, $cerPath -ErrorAction SilentlyContinue + +# Return result +@{ + Success = $true + Thumbprint = $thumbprint + Subject = $cert.Subject + NotAfter = $cert.NotAfter + DnsNames = $cert.DnsNameList +} +`, templateName, config.Hostname, config.DomainName, fqdn, templateName) + + return script +} + +// CertificateInfo contains parsed certificate information. +type CertificateInfo struct { + // Thumbprint is the SHA-256 thumbprint. + Thumbprint string + + // Subject is the certificate subject DN. + Subject string + + // Issuer is the certificate issuer DN. + Issuer string + + // DNSNames are the SAN DNS names. + DNSNames []string + + // NotBefore is the validity start. + NotBefore time.Time + + // NotAfter is the validity end. + NotAfter time.Time + + // SerialNumber is the certificate serial number (hex encoded). + SerialNumber string + + // IsExpired indicates if the certificate is expired. + IsExpired bool + + // NeedsRenewal indicates if the certificate should be renewed. + NeedsRenewal bool + + // RemainingValidity is the time until expiry. + RemainingValidity time.Duration +} + +// ParseCertificateFile parses a PEM certificate file and returns info. +func ParseCertificateFile(certPath string) (*CertificateInfo, error) { + certPEM, err := os.ReadFile(certPath) + if err != nil { + return nil, fmt.Errorf("read certificate: %w", err) + } + + block, _ := pem.Decode(certPEM) + if block == nil { + return nil, fmt.Errorf("failed to decode PEM block") + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, fmt.Errorf("parse certificate: %w", err) + } + + now := time.Now() + remaining := cert.NotAfter.Sub(now) + + info := &CertificateInfo{ + Thumbprint: ComputeCertThumbprint(cert), + Subject: cert.Subject.String(), + Issuer: cert.Issuer.String(), + DNSNames: cert.DNSNames, + NotBefore: cert.NotBefore, + NotAfter: cert.NotAfter, + SerialNumber: cert.SerialNumber.Text(16), + IsExpired: now.After(cert.NotAfter), + NeedsRenewal: remaining < CertRenewalThreshold, + RemainingValidity: remaining, + } + + return info, nil +} + +// WatchCertificateExpiry starts a goroutine that monitors certificate expiry +// and calls the callback when renewal is needed. +func WatchCertificateExpiry(ctx context.Context, certPath string, checkInterval time.Duration, onRenewalNeeded func()) { + ticker := time.NewTicker(checkInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Debug("Certificate expiry watcher stopped") + return + case <-ticker.C: + needsRenewal, err := NeedsRenewal(certPath) + if err != nil { + log.Warnf("Certificate renewal check failed: %v", err) + continue + } + if needsRenewal { + log.Info("Certificate renewal needed, triggering callback") + onRenewalNeeded() + } + } + } +} + +// ExtractIssuerFingerprint extracts the issuer certificate fingerprint from a cert chain. +// This is used for mTLS issuer verification (not AuthorityKeyId!). +func ExtractIssuerFingerprint(certPath string, verifyChain bool) (string, error) { + certPEM, err := os.ReadFile(certPath) + if err != nil { + return "", fmt.Errorf("read certificate: %w", err) + } + + // Parse all certificates in the PEM file (may include chain) + var certs []*x509.Certificate + rest := certPEM + for { + block, remaining := pem.Decode(rest) + if block == nil { + break + } + if block.Type == "CERTIFICATE" { + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return "", fmt.Errorf("parse certificate: %w", err) + } + certs = append(certs, cert) + } + rest = remaining + } + + if len(certs) == 0 { + return "", fmt.Errorf("no certificates found in file") + } + + // If we have a chain, the issuer is the second certificate + if len(certs) > 1 { + issuerCert := certs[1] + return ComputeCertThumbprint(issuerCert), nil + } + + // Single certificate - issuer fingerprint would need to be looked up + // In production, this should verify against the system trust store + return "", fmt.Errorf("certificate chain required for issuer fingerprint extraction") +} diff --git a/client/internal/tunnel/certenroll_test.go b/client/internal/tunnel/certenroll_test.go new file mode 100644 index 000000000..a583f364a --- /dev/null +++ b/client/internal/tunnel/certenroll_test.go @@ -0,0 +1,540 @@ +package tunnel + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "path/filepath" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDefaultCertTemplateName(t *testing.T) { + assert.Equal(t, "NetBirdMachineTunnel", DefaultCertTemplateName) +} + +func TestCertRenewalThreshold(t *testing.T) { + assert.Equal(t, 30*24*time.Hour, CertRenewalThreshold) +} + +func TestMinCertValidity(t *testing.T) { + assert.Equal(t, 7*24*time.Hour, MinCertValidity) +} + +func TestComputeCertThumbprint(t *testing.T) { + // Create a test certificate + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "test.pem") + keyPath := filepath.Join(tmpDir, "test.key") + + err := generateTestCertWithDNSNames(certPath, keyPath, []string{"test.example.com"}, time.Hour*24) + require.NoError(t, err) + + // Read and parse certificate + certPEM, err := os.ReadFile(certPath) + require.NoError(t, err) + + block, _ := pem.Decode(certPEM) + require.NotNil(t, block) + + cert, err := x509.ParseCertificate(block.Bytes) + require.NoError(t, err) + + // Compute thumbprint + thumbprint := ComputeCertThumbprint(cert) + + // Verify thumbprint is 64 hex characters (SHA-256) + assert.Len(t, thumbprint, 64) + assert.Regexp(t, "^[0-9a-f]+$", thumbprint) +} + +func TestValidateMachineCertificate_ValidCert(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "test.pem") + keyPath := filepath.Join(tmpDir, "test.key") + + err := generateTestCertWithDNSNames(certPath, keyPath, []string{"testhost.example.com"}, time.Hour*24*365) + require.NoError(t, err) + + result, err := ValidateMachineCertificate(certPath, "testhost", "example.com") + + assert.NoError(t, err) + assert.True(t, result.Success) + assert.Contains(t, result.DNSNames, "testhost.example.com") + assert.NotEmpty(t, result.Thumbprint) +} + +func TestValidateMachineCertificate_ExpiredCert(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "test.pem") + keyPath := filepath.Join(tmpDir, "test.key") + + // Generate an expired certificate + err := generateTestCertWithTimes(certPath, keyPath, []string{"testhost.example.com"}, + time.Now().Add(-48*time.Hour), time.Now().Add(-24*time.Hour)) + require.NoError(t, err) + + result, err := ValidateMachineCertificate(certPath, "testhost", "example.com") + + assert.Error(t, err) + assert.False(t, result.Success) + assert.Contains(t, err.Error(), "expired") +} + +func TestValidateMachineCertificate_NotYetValid(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "test.pem") + keyPath := filepath.Join(tmpDir, "test.key") + + // Generate a certificate that's not yet valid + err := generateTestCertWithTimes(certPath, keyPath, []string{"testhost.example.com"}, + time.Now().Add(24*time.Hour), time.Now().Add(48*time.Hour)) + require.NoError(t, err) + + result, err := ValidateMachineCertificate(certPath, "testhost", "example.com") + + assert.Error(t, err) + assert.False(t, result.Success) + assert.Contains(t, err.Error(), "not yet valid") +} + +func TestValidateMachineCertificate_NoDNSNames(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "test.pem") + keyPath := filepath.Join(tmpDir, "test.key") + + // Generate a certificate without DNS names + err := generateTestCertWithDNSNames(certPath, keyPath, nil, time.Hour*24) + require.NoError(t, err) + + result, err := ValidateMachineCertificate(certPath, "testhost", "example.com") + + assert.Error(t, err) + assert.False(t, result.Success) + assert.Contains(t, err.Error(), "no SAN DNSNames") +} + +func TestValidateMachineCertificate_WrongHostname(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "test.pem") + keyPath := filepath.Join(tmpDir, "test.key") + + err := generateTestCertWithDNSNames(certPath, keyPath, []string{"otherhost.example.com"}, time.Hour*24) + require.NoError(t, err) + + result, err := ValidateMachineCertificate(certPath, "testhost", "example.com") + + assert.Error(t, err) + assert.False(t, result.Success) + assert.Contains(t, err.Error(), "do not match expected") +} + +func TestValidateMachineCertificate_FileNotFound(t *testing.T) { + result, err := ValidateMachineCertificate("/nonexistent/cert.pem", "test", "example.com") + + assert.Error(t, err) + assert.False(t, result.Success) + assert.Contains(t, err.Error(), "read certificate") +} + +func TestValidateMachineCertificate_InvalidPEM(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "invalid.pem") + + err := os.WriteFile(certPath, []byte("not a valid PEM"), 0600) + require.NoError(t, err) + + result, err := ValidateMachineCertificate(certPath, "test", "example.com") + + assert.Error(t, err) + assert.False(t, result.Success) + assert.Contains(t, err.Error(), "decode PEM") +} + +func TestNeedsRenewal_ValidCert(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "test.pem") + keyPath := filepath.Join(tmpDir, "test.key") + + // Certificate valid for 1 year - doesn't need renewal + err := generateTestCertWithDNSNames(certPath, keyPath, []string{"test.example.com"}, time.Hour*24*365) + require.NoError(t, err) + + needsRenewal, err := NeedsRenewal(certPath) + + assert.NoError(t, err) + assert.False(t, needsRenewal) +} + +func TestNeedsRenewal_ExpiringSoon(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "test.pem") + keyPath := filepath.Join(tmpDir, "test.key") + + // Certificate expires in 15 days - needs renewal (threshold is 30 days) + err := generateTestCertWithTimes(certPath, keyPath, []string{"test.example.com"}, + time.Now().Add(-time.Hour), time.Now().Add(15*24*time.Hour)) + require.NoError(t, err) + + needsRenewal, err := NeedsRenewal(certPath) + + assert.NoError(t, err) + assert.True(t, needsRenewal) +} + +func TestNeedsRenewal_FileNotFound(t *testing.T) { + needsRenewal, err := NeedsRenewal("/nonexistent/cert.pem") + + assert.Error(t, err) + assert.True(t, needsRenewal) // Should return true if we can't read the cert +} + +func TestParseCertificateFile(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "test.pem") + keyPath := filepath.Join(tmpDir, "test.key") + + dnsNames := []string{"host.domain.local", "alias.domain.local"} + err := generateTestCertWithDNSNames(certPath, keyPath, dnsNames, time.Hour*24*365) + require.NoError(t, err) + + info, err := ParseCertificateFile(certPath) + + assert.NoError(t, err) + assert.NotEmpty(t, info.Thumbprint) + assert.NotEmpty(t, info.Subject) + assert.NotEmpty(t, info.Issuer) + assert.Equal(t, dnsNames, info.DNSNames) + assert.False(t, info.IsExpired) + assert.False(t, info.NeedsRenewal) + assert.True(t, info.RemainingValidity > 364*24*time.Hour) +} + +func TestParseCertificateFile_Expired(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "test.pem") + keyPath := filepath.Join(tmpDir, "test.key") + + err := generateTestCertWithTimes(certPath, keyPath, []string{"test.example.com"}, + time.Now().Add(-48*time.Hour), time.Now().Add(-24*time.Hour)) + require.NoError(t, err) + + info, err := ParseCertificateFile(certPath) + + assert.NoError(t, err) + assert.True(t, info.IsExpired) + assert.True(t, info.NeedsRenewal) + assert.True(t, info.RemainingValidity < 0) +} + +func TestGenerateCertEnrollmentScript_Basic(t *testing.T) { + config := &CertEnrollmentConfig{ + Hostname: "win10-pc", + DomainName: "corp.local", + } + + script := GenerateCertEnrollmentScript(config) + + assert.Contains(t, script, "win10-pc.corp.local") + assert.Contains(t, script, DefaultCertTemplateName) + assert.Contains(t, script, "certreq -new") + assert.Contains(t, script, "certreq -submit") + assert.Contains(t, script, "certreq -accept") + assert.Contains(t, script, "Cert:\\LocalMachine\\My") +} + +func TestGenerateCertEnrollmentScript_CustomTemplate(t *testing.T) { + config := &CertEnrollmentConfig{ + Hostname: "server01", + DomainName: "example.com", + TemplateName: "CustomMachineTemplate", + } + + script := GenerateCertEnrollmentScript(config) + + assert.Contains(t, script, "CustomMachineTemplate") + assert.Contains(t, script, "server01.example.com") +} + +func TestGenerateCertEnrollmentScript_ContainsRequiredSteps(t *testing.T) { + config := &CertEnrollmentConfig{ + Hostname: "test", + DomainName: "test.local", + } + + script := GenerateCertEnrollmentScript(config) + + // Check that all required steps are present + assert.Contains(t, script, "Step 1: Create INF") + assert.Contains(t, script, "Step 2: Generate certificate request") + assert.Contains(t, script, "Step 3: Submit request") + assert.Contains(t, script, "Step 4: Accept certificate") + assert.Contains(t, script, "Step 5: Find and export") + assert.Contains(t, script, "Step 6: Export to PEM") + + // Check crypto requirements + assert.Contains(t, script, "KeyLength = 2048") + assert.Contains(t, script, "SHA256") + assert.Contains(t, script, "MachineKeySet = TRUE") +} + +func TestWatchCertificateExpiry(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "test.pem") + keyPath := filepath.Join(tmpDir, "test.key") + + // Create a certificate that expires in 20 days (within renewal threshold) + err := generateTestCertWithTimes(certPath, keyPath, []string{"test.example.com"}, + time.Now().Add(-time.Hour), time.Now().Add(20*24*time.Hour)) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + var callbackCalled atomic.Bool + + go WatchCertificateExpiry(ctx, certPath, 500*time.Millisecond, func() { + callbackCalled.Store(true) + }) + + // Wait for at least one check + time.Sleep(1 * time.Second) + + assert.True(t, callbackCalled.Load(), "Callback should have been called for expiring cert") +} + +func TestWatchCertificateExpiry_NoRenewalNeeded(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "test.pem") + keyPath := filepath.Join(tmpDir, "test.key") + + // Create a certificate valid for 1 year + err := generateTestCertWithDNSNames(certPath, keyPath, []string{"test.example.com"}, time.Hour*24*365) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + var callbackCalled atomic.Bool + + go WatchCertificateExpiry(ctx, certPath, 500*time.Millisecond, func() { + callbackCalled.Store(true) + }) + + // Wait for at least one check + time.Sleep(1 * time.Second) + + assert.False(t, callbackCalled.Load(), "Callback should NOT have been called for valid cert") +} + +func TestExtractIssuerFingerprint_SingleCert(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "test.pem") + keyPath := filepath.Join(tmpDir, "test.key") + + err := generateTestCertWithDNSNames(certPath, keyPath, []string{"test.example.com"}, time.Hour*24) + require.NoError(t, err) + + // Single cert should return error (no issuer in chain) + _, err = ExtractIssuerFingerprint(certPath, true) + assert.Error(t, err) + assert.Contains(t, err.Error(), "certificate chain required") +} + +func TestExtractIssuerFingerprint_CertChain(t *testing.T) { + tmpDir := t.TempDir() + chainPath := filepath.Join(tmpDir, "chain.pem") + + // Generate CA and end-entity cert + caCert, caKey, err := generateCACertificate() + require.NoError(t, err) + + eeCert, _, err := generateSignedCertificate(caCert, caKey, []string{"test.example.com"}) + require.NoError(t, err) + + // Write chain (EE cert first, then CA) + chainPEM := append(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: eeCert.Raw}), + pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caCert.Raw})...) + err = os.WriteFile(chainPath, chainPEM, 0600) + require.NoError(t, err) + + fingerprint, err := ExtractIssuerFingerprint(chainPath, true) + + assert.NoError(t, err) + assert.Len(t, fingerprint, 64) // SHA-256 hex + assert.Equal(t, ComputeCertThumbprint(caCert), fingerprint) +} + +func TestCertificateInfo_Fields(t *testing.T) { + info := &CertificateInfo{ + Thumbprint: "abc123", + Subject: "CN=test", + Issuer: "CN=CA", + DNSNames: []string{"test.local"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + SerialNumber: "1234", + IsExpired: false, + NeedsRenewal: false, + RemainingValidity: 24 * time.Hour, + } + + assert.Equal(t, "abc123", info.Thumbprint) + assert.Equal(t, "CN=test", info.Subject) + assert.Equal(t, "CN=CA", info.Issuer) + assert.Len(t, info.DNSNames, 1) + assert.False(t, info.IsExpired) +} + +func TestValidateMachineCertificate_CaseInsensitiveHostname(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "test.pem") + keyPath := filepath.Join(tmpDir, "test.key") + + // Certificate has lowercase DNS name + err := generateTestCertWithDNSNames(certPath, keyPath, []string{"testhost.example.com"}, time.Hour*24*365) + require.NoError(t, err) + + // Validate with uppercase hostname - should still match + result, err := ValidateMachineCertificate(certPath, "TESTHOST", "EXAMPLE.COM") + + assert.NoError(t, err) + assert.True(t, result.Success) +} + +// Helper functions + +func generateTestCertWithDNSNames(certPath, keyPath string, dnsNames []string, validity time.Duration) error { + return generateTestCertWithTimes(certPath, keyPath, dnsNames, + time.Now().Add(-time.Hour), time.Now().Add(validity)) +} + +func generateTestCertWithTimes(certPath, keyPath string, dnsNames []string, notBefore, notAfter time.Time) error { + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return err + } + + serialNumber, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: "Test Certificate", + Organization: []string{"Test Org"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + DNSNames: dnsNames, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + if err != nil { + return err + } + + certFile, err := os.Create(certPath) + if err != nil { + return err + } + defer certFile.Close() + if err := pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil { + return err + } + + keyFile, err := os.Create(keyPath) + if err != nil { + return err + } + defer keyFile.Close() + keyDER, _ := x509.MarshalECPrivateKey(privateKey) + if err := pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}); err != nil { + return err + } + + return nil +} + +func generateCACertificate() (*x509.Certificate, *ecdsa.PrivateKey, error) { + caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, err + } + + serialNumber, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + + caTemplate := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: "Test CA", + Organization: []string{"Test CA Org"}, + }, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour * 365), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + MaxPathLen: 1, + } + + caDER, err := x509.CreateCertificate(rand.Reader, &caTemplate, &caTemplate, &caKey.PublicKey, caKey) + if err != nil { + return nil, nil, err + } + + caCert, err := x509.ParseCertificate(caDER) + if err != nil { + return nil, nil, err + } + + return caCert, caKey, nil +} + +func generateSignedCertificate(caCert *x509.Certificate, caKey *ecdsa.PrivateKey, dnsNames []string) (*x509.Certificate, *ecdsa.PrivateKey, error) { + eeKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, err + } + + serialNumber, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + + eeTemplate := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: strings.Join(dnsNames, ", "), + Organization: []string{"Test Org"}, + }, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour * 30), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + DNSNames: dnsNames, + } + + eeDER, err := x509.CreateCertificate(rand.Reader, &eeTemplate, caCert, &eeKey.PublicKey, caKey) + if err != nil { + return nil, nil, err + } + + eeCert, err := x509.ParseCertificate(eeDER) + if err != nil { + return nil, nil, err + } + + return eeCert, eeKey, nil +}