Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,12 @@ func (m *Manager) preserveExistingAuthSecrets(svc, existingService *service.Serv
svc.Auth.PinAuth = existingService.Auth.PinAuth
}

if svc.Auth.MTLSAuth != nil && svc.Auth.MTLSAuth.Enabled &&
existingService.Auth.MTLSAuth != nil && existingService.Auth.MTLSAuth.Enabled &&
svc.Auth.MTLSAuth.CACertPEM == "" {
svc.Auth.MTLSAuth.CACertPEM = existingService.Auth.MTLSAuth.CACertPEM
}

preserveHeaderAuthHashes(svc.Auth.HeaderAuths, existingService.Auth.HeaderAuths)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,30 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
assert.Equal(t, "new-password", updated.Auth.PasswordAuth.Password)
assert.NotEqual(t, existing.Auth.PasswordAuth, updated.Auth.PasswordAuth)
})

t.Run("preserve mtls ca pem when empty", func(t *testing.T) {
existing := &rpservice.Service{
Auth: rpservice.AuthConfig{
MTLSAuth: &rpservice.MTLSAuthConfig{
Enabled: true,
CACertPEM: "existing-ca",
},
},
}

updated := &rpservice.Service{
Auth: rpservice.AuthConfig{
MTLSAuth: &rpservice.MTLSAuthConfig{
Enabled: true,
CACertPEM: "",
},
},
}

mgr.preserveExistingAuthSecrets(updated, existing)

assert.Equal(t, "existing-ca", updated.Auth.MTLSAuth.CACertPEM)
})
}

func TestPreserveServiceMetadata(t *testing.T) {
Expand Down
102 changes: 101 additions & 1 deletion management/internals/modules/reverseproxy/service/service.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package service

import (
"bytes"
"crypto/rand"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"math/big"
Expand Down Expand Up @@ -100,11 +103,17 @@ type HeaderAuthConfig struct {
Value string `json:"value"`
}

type MTLSAuthConfig struct {
Enabled bool `json:"enabled"`
CACertPEM string `json:"ca_cert_pem"`
}

type AuthConfig struct {
PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"`
PinAuth *PINAuthConfig `json:"pin_auth,omitempty" gorm:"serializer:json"`
BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"`
HeaderAuths []*HeaderAuthConfig `json:"header_auths,omitempty" gorm:"serializer:json"`
MTLSAuth *MTLSAuthConfig `json:"mtls_auth,omitempty" gorm:"serializer:json"`
}

// AccessRestrictions controls who can connect to the service based on IP or geography.
Expand Down Expand Up @@ -169,6 +178,9 @@ func (a *AuthConfig) ClearSecrets() {
h.Value = ""
}
}
if a.MTLSAuth != nil {
a.MTLSAuth.CACertPEM = ""
}
}

type Meta struct {
Expand Down Expand Up @@ -249,6 +261,12 @@ func (s *Service) ToAPIResponse() *api.Service {
authConfig.HeaderAuths = &apiHeaders
}

if s.Auth.MTLSAuth != nil {
authConfig.MtlsAuth = &api.MTLSAuthConfig{
Enabled: s.Auth.MTLSAuth.Enabled,
}
}

// Convert internal targets to API targets
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
for _, target := range s.Targets {
Expand Down Expand Up @@ -337,6 +355,12 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
}
}

if s.Auth.MTLSAuth != nil && s.Auth.MTLSAuth.Enabled {
auth.MtlsAuth = &proto.MTLSAuth{
CaCertPem: s.Auth.MTLSAuth.CACertPEM,
}
}

mapping := &proto.ProxyMapping{
Type: operationToProtoType(operation),
Id: s.ID,
Expand Down Expand Up @@ -634,6 +658,12 @@ func authFromAPI(reqAuth *api.ServiceAuthConfig) AuthConfig {
})
}
}
if reqAuth.MtlsAuth != nil {
auth.MTLSAuth = &MTLSAuthConfig{
Enabled: reqAuth.MtlsAuth.Enabled,
CACertPEM: reqAuth.MtlsAuth.CaCertPem,
}
}
return auth
}

Expand Down Expand Up @@ -723,6 +753,9 @@ func (s *Service) Validate() error {
if err := validateHeaderAuths(s.Auth.HeaderAuths); err != nil {
return err
}
if err := validateMTLSAuth(s.Auth.MTLSAuth); err != nil {
return err
}
if err := validateAccessRestrictions(&s.Restrictions); err != nil {
return err
}
Expand Down Expand Up @@ -1002,6 +1035,68 @@ func validateHeaderAuths(headers []*HeaderAuthConfig) error {
return nil
}

func validateMTLSAuth(config *MTLSAuthConfig) error {
if config == nil || !config.Enabled {
return nil
}
if strings.TrimSpace(config.CACertPEM) == "" {
return errors.New("mtls_auth: ca_cert_pem is required when enabled")
}
if _, err := parseClientCAPEM(config.CACertPEM); err != nil {
return fmt.Errorf("mtls_auth: %w", err)
}
return nil
}

func parseClientCAPEM(caCertPEM string) (*x509.CertPool, error) {
pool := x509.NewCertPool()
remaining := []byte(caCertPEM)
foundCertificate := false

for len(remaining) > 0 {
remaining = trimPEMCommentsAndWhitespace(remaining)
if len(remaining) == 0 {
break
}

var block *pem.Block
block, remaining = pem.Decode(remaining)
if block == nil {
return nil, errors.New("ca_cert_pem contains invalid PEM data")
}
if block.Type != "CERTIFICATE" {
continue
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, fmt.Errorf("parse certificate: %w", err)
}
pool.AddCert(cert)
foundCertificate = true
}

if !foundCertificate {
return nil, errors.New("ca_cert_pem must contain at least one certificate")
}

return pool, nil
}

func trimPEMCommentsAndWhitespace(data []byte) []byte {
for len(data) > 0 {
data = bytes.TrimLeft(data, " \t\r\n")
if len(data) == 0 || data[0] != '#' {
return data
}
if i := bytes.IndexByte(data, '\n'); i >= 0 {
data = data[i+1:]
continue
}
return nil
}
return data
}

const (
maxCIDREntries = 200
maxCountryEntries = 50
Expand Down Expand Up @@ -1104,7 +1199,8 @@ func (s *Service) EventMeta() map[string]any {
func (s *Service) isAuthEnabled() bool {
if (s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled) ||
(s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled) ||
(s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled) {
(s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled) ||
(s.Auth.MTLSAuth != nil && s.Auth.MTLSAuth.Enabled) {
return true
}
for _, h := range s.Auth.HeaderAuths {
Expand Down Expand Up @@ -1159,6 +1255,10 @@ func (s *Service) Copy() *Service {
authCopy.HeaderAuths[i] = &hCopy
}
}
if s.Auth.MTLSAuth != nil {
mtls := *s.Auth.MTLSAuth
authCopy.MTLSAuth = &mtls
}

return &Service{
ID: s.ID,
Expand Down
144 changes: 144 additions & 0 deletions management/internals/modules/reverseproxy/service/service_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
package service

import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"fmt"
"math/big"
"strings"
"testing"
"time"
Expand All @@ -25,6 +31,28 @@ func validProxy() *Service {
}
}

func testCertificatePEM(t *testing.T) string {
t.Helper()

priv, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)

tmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "Test CA"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
IsCA: true,
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
}

der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv)
require.NoError(t, err)

return string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}))
}

func TestValidate_Valid(t *testing.T) {
require.NoError(t, validProxy().Validate())
}
Expand Down Expand Up @@ -576,6 +604,13 @@ func TestAuthConfig_ClearSecrets(t *testing.T) {
Enabled: true,
Pin: "hashedPin",
},
HeaderAuths: []*HeaderAuthConfig{
{Enabled: true, Header: "X-Test", Value: "hashedHeader"},
},
MTLSAuth: &MTLSAuthConfig{
Enabled: true,
CACertPEM: "ca-pem",
},
}

config.ClearSecrets()
Expand All @@ -586,6 +621,115 @@ func TestAuthConfig_ClearSecrets(t *testing.T) {
if config.PinAuth.Pin != "" {
t.Errorf("PIN not cleared, got: %s", config.PinAuth.Pin)
}
require.Len(t, config.HeaderAuths, 1)
if config.HeaderAuths[0].Value != "" {
t.Errorf("Header auth value not cleared, got: %s", config.HeaderAuths[0].Value)
}
require.NotNil(t, config.MTLSAuth)
if config.MTLSAuth.CACertPEM != "" {
t.Errorf("mTLS CA PEM not cleared, got: %s", config.MTLSAuth.CACertPEM)
}
}

func TestValidateMTLSAuth(t *testing.T) {
validPEM := testCertificatePEM(t)

t.Run("disabled allows empty pem", func(t *testing.T) {
require.NoError(t, validateMTLSAuth(&MTLSAuthConfig{Enabled: false}))
})

t.Run("enabled requires pem", func(t *testing.T) {
err := validateMTLSAuth(&MTLSAuthConfig{Enabled: true})
require.Error(t, err)
assert.ErrorContains(t, err, "ca_cert_pem is required")
})

t.Run("enabled accepts valid pem", func(t *testing.T) {
require.NoError(t, validateMTLSAuth(&MTLSAuthConfig{Enabled: true, CACertPEM: validPEM}))
})

t.Run("enabled rejects malformed pem", func(t *testing.T) {
err := validateMTLSAuth(&MTLSAuthConfig{Enabled: true, CACertPEM: "not a pem"})
require.Error(t, err)
assert.ErrorContains(t, err, "contains invalid PEM data")
})
}

func TestServiceValidate_MTLSAuth(t *testing.T) {
validPEM := testCertificatePEM(t)

t.Run("enabled mtls requires pem", func(t *testing.T) {
rp := validProxy()
rp.Auth.MTLSAuth = &MTLSAuthConfig{Enabled: true}
err := rp.Validate()
require.Error(t, err)
assert.ErrorContains(t, err, "mtls_auth: ca_cert_pem is required")
})

t.Run("enabled mtls accepts valid pem", func(t *testing.T) {
rp := validProxy()
rp.Auth.MTLSAuth = &MTLSAuthConfig{Enabled: true, CACertPEM: validPEM}
require.NoError(t, rp.Validate())
})

t.Run("enabled mtls accepts pem bundle with hash comments", func(t *testing.T) {
rp := validProxy()
rp.Auth.MTLSAuth = &MTLSAuthConfig{
Enabled: true,
CACertPEM: "# exported by test tool\n\n" + validPEM + "\n# intermediate follows\n" + testCertificatePEM(t),
}
require.NoError(t, rp.Validate())
})

t.Run("disabled mtls ignores pem", func(t *testing.T) {
rp := validProxy()
rp.Auth.MTLSAuth = &MTLSAuthConfig{Enabled: false}
require.NoError(t, rp.Validate())
})
}

func TestService_ToAPIResponse_RedactsMTLSPEM(t *testing.T) {
rp := validProxy()
rp.ID = "svc-1"
rp.Auth.MTLSAuth = &MTLSAuthConfig{
Enabled: true,
CACertPEM: testCertificatePEM(t),
}

resp := rp.ToAPIResponse()
require.NotNil(t, resp.Auth.MtlsAuth)
assert.True(t, resp.Auth.MtlsAuth.Enabled)
assert.Empty(t, resp.Auth.MtlsAuth.CaCertPem)
}

func TestService_ToProtoMapping_IncludesMTLSAuth(t *testing.T) {
caPEM := testCertificatePEM(t)
rp := validProxy()
rp.ID = "svc-1"
rp.AccountID = "acc-1"
rp.Auth.MTLSAuth = &MTLSAuthConfig{
Enabled: true,
CACertPEM: caPEM,
}

mapping := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
require.NotNil(t, mapping.Auth)
require.NotNil(t, mapping.Auth.MtlsAuth)
assert.Equal(t, caPEM, mapping.Auth.MtlsAuth.CaCertPem)
}

func TestService_IsAuthEnabled_MTLAware(t *testing.T) {
t.Run("mtls enabled counts as auth", func(t *testing.T) {
rp := validProxy()
rp.Auth.MTLSAuth = &MTLSAuthConfig{Enabled: true, CACertPEM: testCertificatePEM(t)}
assert.True(t, rp.isAuthEnabled())
})

t.Run("mtls disabled does not count as auth", func(t *testing.T) {
rp := validProxy()
rp.Auth.MTLSAuth = &MTLSAuthConfig{Enabled: false}
assert.False(t, rp.isAuthEnabled())
})
}

func TestGenerateExposeName(t *testing.T) {
Expand Down
Loading
Loading