Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion client/internal/engine_ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,16 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
}

if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
audiences := protoJWT.GetAudiences()
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
audiences = []string{protoJWT.GetAudience()}
}

log.Debugf("starting SSH server with JWT authentication: audiences=%v", audiences)

jwtConfig := &sshserver.JWTConfig{
Issuer: protoJWT.GetIssuer(),
Audience: protoJWT.GetAudience(),
Audiences: audiences,
KeysLocation: protoJWT.GetKeysLocation(),
MaxTokenAge: protoJWT.GetMaxTokenAge(),
}
Expand Down
2 changes: 1 addition & 1 deletion client/ssh/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func TestSSHProxy_Connect(t *testing.T) {
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audience: audience,
Audiences: []string{audience},
KeysLocation: jwksURL,
},
}
Expand Down
113 changes: 109 additions & 4 deletions client/ssh/server/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func TestJWTEnforcement(t *testing.T) {
t.Run("blocks_without_jwt", func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: "test-issuer",
Audience: "test-audience",
Audiences: []string{"test-audience"},
KeysLocation: "test-keys",
}
serverConfig := &Config{
Expand Down Expand Up @@ -202,7 +202,7 @@ func TestJWTDetection(t *testing.T) {

jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
Audiences: []string{audience},
KeysLocation: jwksURL,
}
serverConfig := &Config{
Expand Down Expand Up @@ -329,7 +329,7 @@ func TestJWTFailClose(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
Audiences: []string{audience},
KeysLocation: jwksURL,
MaxTokenAge: 3600,
}
Expand Down Expand Up @@ -567,7 +567,7 @@ func TestJWTAuthentication(t *testing.T) {

jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
Audiences: []string{audience},
KeysLocation: jwksURL,
}
serverConfig := &Config{
Expand Down Expand Up @@ -646,3 +646,108 @@ func TestJWTAuthentication(t *testing.T) {
})
}
}

// TestJWTMultipleAudiences tests JWT validation with multiple audiences (dashboard and CLI).
func TestJWTMultipleAudiences(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT multiple audiences tests in short mode")
}

jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()

const (
issuer = "https://test-issuer.example.com"
dashboardAudience = "dashboard-audience"
cliAudience = "cli-audience"
)

hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)

testCases := []struct {
name string
audience string
wantAuthOK bool
}{
{
name: "accepts_dashboard_audience",
audience: dashboardAudience,
wantAuthOK: true,
},
{
name: "accepts_cli_audience",
audience: cliAudience,
wantAuthOK: true,
},
{
name: "rejects_unknown_audience",
audience: "unknown-audience",
wantAuthOK: false,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: issuer,
Audiences: []string{dashboardAudience, cliAudience},
KeysLocation: jwksURL,
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)

testUserHash, err := sshuserhash.HashUserID("test-user")
require.NoError(t, err)

currentUser := testutil.GetTestUsername(t)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
currentUser: {0},
},
}
server.UpdateSSHAuth(authConfig)

serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())

host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)

token := generateValidJWT(t, privateKey, issuer, tc.audience)
config := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{
cryptossh.Password(token),
},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}

conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
if tc.wantAuthOK {
require.NoError(t, err, "JWT authentication should succeed for audience %s", tc.audience)
defer func() {
if err := conn.Close(); err != nil {
t.Logf("close connection: %v", err)
}
}()

session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()

err = session.Shell()
require.NoError(t, err, "Shell should work with valid audience")
} else {
assert.Error(t, err, "JWT authentication should fail for unknown audience")
}
})
}
}
15 changes: 9 additions & 6 deletions client/ssh/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,9 @@ type Server struct {

type JWTConfig struct {
Issuer string
Audience string
KeysLocation string
MaxTokenAge int64
Audiences []string
}

// Config contains all SSH server configuration options
Expand Down Expand Up @@ -427,18 +427,21 @@ func (s *Server) ensureJWTValidator() error {
return fmt.Errorf("JWT config not set")
}

log.Debugf("Initializing JWT validator (issuer: %s, audience: %s)", config.Issuer, config.Audience)
if len(config.Audiences) == 0 {
return fmt.Errorf("JWT config has no audiences configured")
}

log.Debugf("Initializing JWT validator (issuer: %s, audiences: %v)", config.Issuer, config.Audiences)
validator := jwt.NewValidator(
config.Issuer,
[]string{config.Audience},
config.Audiences,
config.KeysLocation,
true,
)

// Use custom userIDClaim from authorizer if available
extractorOptions := []jwt.ClaimsExtractorOption{
jwt.WithAudience(config.Audience),
jwt.WithAudience(config.Audiences[0]),
}
if authorizer.GetUserIDClaim() != "" {
extractorOptions = append(extractorOptions, jwt.WithUserIDClaim(authorizer.GetUserIDClaim()))
Expand Down Expand Up @@ -475,8 +478,8 @@ func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) {
if err != nil {
if jwtConfig != nil {
if claims, parseErr := s.parseTokenWithoutValidation(tokenString); parseErr == nil {
return nil, fmt.Errorf("validate token (expected issuer=%s, audience=%s, actual issuer=%v, audience=%v): %w",
jwtConfig.Issuer, jwtConfig.Audience, claims["iss"], claims["aud"], err)
return nil, fmt.Errorf("validate token (expected issuer=%s, audiences=%v, actual issuer=%v, audience=%v): %w",
jwtConfig.Issuer, jwtConfig.Audiences, claims["iss"], claims["aud"], err)
}
}
return nil, fmt.Errorf("validate token: %w", err)
Expand Down
7 changes: 7 additions & 0 deletions management/internals/shared/grpc/conversion.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,9 +433,16 @@ func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfi
if config.CLIAuthAudience != "" {
audience = config.CLIAuthAudience
}

audiences := []string{config.AuthAudience}
if config.CLIAuthAudience != "" && config.CLIAuthAudience != config.AuthAudience {
audiences = append(audiences, config.CLIAuthAudience)
}

return &proto.JWTConfig{
Issuer: issuer,
Audience: audience,
Audiences: audiences,
KeysLocation: keysLocation,
}
}
Expand Down
51 changes: 51 additions & 0 deletions management/internals/shared/grpc/conversion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ import (
"reflect"
"testing"

"github.com/stretchr/testify/assert"

nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
)

func TestToProtocolDNSConfigWithCache(t *testing.T) {
Expand Down Expand Up @@ -148,3 +151,51 @@ func generateTestData(size int) nbdns.Config {

return config
}

func TestBuildJWTConfig_Audiences(t *testing.T) {
tests := []struct {
name string
authAudience string
cliAuthAudience string
expectedAudiences []string
expectedAudience string
}{
{
name: "only_auth_audience",
authAudience: "dashboard-aud",
cliAuthAudience: "",
expectedAudiences: []string{"dashboard-aud"},
expectedAudience: "dashboard-aud",
},
{
name: "both_audiences_different",
authAudience: "dashboard-aud",
cliAuthAudience: "cli-aud",
expectedAudiences: []string{"dashboard-aud", "cli-aud"},
expectedAudience: "cli-aud",
},
{
name: "both_audiences_same",
authAudience: "same-aud",
cliAuthAudience: "same-aud",
expectedAudiences: []string{"same-aud"},
expectedAudience: "same-aud",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
config := &nbconfig.HttpServerConfig{
AuthIssuer: "https://issuer.example.com",
AuthAudience: tc.authAudience,
CLIAuthAudience: tc.cliAuthAudience,
}

result := buildJWTConfig(config, nil)

assert.NotNil(t, result)
assert.Equal(t, tc.expectedAudiences, result.Audiences, "audiences should match expected")
assert.Equal(t, tc.expectedAudience, result.Audience, "audience should match expected")
})
}
}
Loading