From 63c26116d978e76470b32e0f748470f3551d6a53 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 17 Mar 2026 14:54:16 +0100 Subject: [PATCH 1/2] Accept domain target type for L4 services and validate resource type matches --- .../reverseproxy/service/manager/manager.go | 20 ++++++- .../service/manager/manager_test.go | 58 +++++++++++++++++++ .../modules/reverseproxy/service/service.go | 2 +- .../reverseproxy/service/service_test.go | 25 ++++++++ 4 files changed, 103 insertions(+), 2 deletions(-) diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index 65177bf5da9..457ee7b6830 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -14,6 +14,8 @@ import ( nbpeer "github.com/netbirdio/netbird/management/server/peer" + resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" @@ -643,12 +645,16 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco return fmt.Errorf("look up peer target %q: %w", target.TargetId, err) } case service.TargetTypeHost, service.TargetTypeSubnet, service.TargetTypeDomain: - if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil { + resource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId) + if err != nil { if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId) } return fmt.Errorf("look up resource target %q: %w", target.TargetId, err) } + if err := validateResourceTargetType(target, resource); err != nil { + return err + } default: return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId) } @@ -656,6 +662,18 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco return nil } +// validateResourceTargetType checks that target_type matches the actual network resource type. +func validateResourceTargetType(target *service.Target, resource *resourcetypes.NetworkResource) error { + expected := resourcetypes.NetworkResourceType(target.TargetType) + if resource.Type != expected { + return status.Errorf(status.InvalidArgument, + "target %q has target_type %q but resource is of type %q", + target.TargetId, target.TargetType, resource.Type, + ) + } + return nil +} + func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete) if err != nil { diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go b/management/internals/modules/reverseproxy/service/manager/manager_test.go index d23c91017c5..0c34f81a219 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -19,6 +19,7 @@ import ( "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/mock_server" + resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" @@ -1214,3 +1215,60 @@ func TestValidateProtocolChange(t *testing.T) { }) } } + +func TestValidateTargetReferences_ResourceTypeMismatch(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + mockStore := store.NewMockStore(ctrl) + accountID := "test-account" + + tests := []struct { + name string + targetType rpservice.TargetType + resourceType resourcetypes.NetworkResourceType + wantErr bool + }{ + {"host matches host", rpservice.TargetTypeHost, resourcetypes.Host, false}, + {"domain matches domain", rpservice.TargetTypeDomain, resourcetypes.Domain, false}, + {"subnet matches subnet", rpservice.TargetTypeSubnet, resourcetypes.Subnet, false}, + {"host but resource is domain", rpservice.TargetTypeHost, resourcetypes.Domain, true}, + {"domain but resource is host", rpservice.TargetTypeDomain, resourcetypes.Host, true}, + {"host but resource is subnet", rpservice.TargetTypeHost, resourcetypes.Subnet, true}, + {"subnet but resource is domain", rpservice.TargetTypeSubnet, resourcetypes.Domain, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockStore.EXPECT(). + GetNetworkResourceByID(gomock.Any(), store.LockingStrengthShare, accountID, "resource-1"). + Return(&resourcetypes.NetworkResource{Type: tt.resourceType}, nil) + + targets := []*rpservice.Target{ + {TargetId: "resource-1", TargetType: tt.targetType, Host: "10.0.0.1"}, + } + err := validateTargetReferences(ctx, mockStore, accountID, targets) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), "target_type") + } else { + require.NoError(t, err) + } + }) + } +} + +func TestValidateTargetReferences_PeerValid(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + mockStore := store.NewMockStore(ctrl) + accountID := "test-account" + + mockStore.EXPECT(). + GetPeerByID(gomock.Any(), store.LockingStrengthShare, accountID, "peer-1"). + Return(&nbpeer.Peer{}, nil) + + targets := []*rpservice.Target{ + {TargetId: "peer-1", TargetType: rpservice.TargetTypePeer}, + } + require.NoError(t, validateTargetReferences(ctx, mockStore, accountID, targets)) +} diff --git a/management/internals/modules/reverseproxy/service/service.go b/management/internals/modules/reverseproxy/service/service.go index 6c7c8080689..c00d494214a 100644 --- a/management/internals/modules/reverseproxy/service/service.go +++ b/management/internals/modules/reverseproxy/service/service.go @@ -790,7 +790,7 @@ func (s *Service) validateL4Target(target *Target) error { return errors.New("target_id is required for L4 services") } switch target.TargetType { - case TargetTypePeer, TargetTypeHost: + case TargetTypePeer, TargetTypeHost, TargetTypeDomain: // OK case TargetTypeSubnet: if target.Host == "" { diff --git a/management/internals/modules/reverseproxy/service/service_test.go b/management/internals/modules/reverseproxy/service/service_test.go index 9daf729fe58..9b7d0493841 100644 --- a/management/internals/modules/reverseproxy/service/service_test.go +++ b/management/internals/modules/reverseproxy/service/service_test.go @@ -847,6 +847,31 @@ func TestValidate_TLSSubnetValid(t *testing.T) { require.NoError(t, rp.Validate()) } +func TestValidate_L4DomainTargetValid(t *testing.T) { + modes := []struct { + mode string + port uint16 + }{ + {"tcp", 5432}, + {"tls", 443}, + {"udp", 5432}, + } + for _, m := range modes { + t.Run(m.mode, func(t *testing.T) { + rp := &Service{ + Name: m.mode + "-domain", + Mode: m.mode, + Domain: "cluster.test", + ListenPort: m.port, + Targets: []*Target{ + {TargetId: "resource-1", TargetType: TargetTypeDomain, Protocol: "tcp", Port: m.port, Enabled: true}, + }, + } + require.NoError(t, rp.Validate()) + }) + } +} + func TestValidate_HTTPProxyProtocolRejected(t *testing.T) { rp := validProxy() rp.Targets[0].ProxyProtocol = true From 8d158352fb8fc28c8c8162221d3711d51e83bbce Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 17 Mar 2026 15:14:50 +0100 Subject: [PATCH 2/2] Fix UDP test protocol and reduce validateTargetReferences complexity --- .../reverseproxy/service/manager/manager.go | 37 ++++++++++++------- .../reverseproxy/service/service_test.go | 13 ++++--- 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index 457ee7b6830..2251f508496 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -638,21 +638,11 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco for _, target := range targets { switch target.TargetType { case service.TargetTypePeer: - if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil { - if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { - return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId) - } - return fmt.Errorf("look up peer target %q: %w", target.TargetId, err) + if err := validatePeerTarget(ctx, transaction, accountID, target); err != nil { + return err } case service.TargetTypeHost, service.TargetTypeSubnet, service.TargetTypeDomain: - resource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId) - if err != nil { - if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { - return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId) - } - return fmt.Errorf("look up resource target %q: %w", target.TargetId, err) - } - if err := validateResourceTargetType(target, resource); err != nil { + if err := validateResourceTarget(ctx, transaction, accountID, target); err != nil { return err } default: @@ -662,6 +652,27 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco return nil } +func validatePeerTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error { + if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil { + if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { + return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId) + } + return fmt.Errorf("look up peer target %q: %w", target.TargetId, err) + } + return nil +} + +func validateResourceTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error { + resource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId) + if err != nil { + if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { + return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId) + } + return fmt.Errorf("look up resource target %q: %w", target.TargetId, err) + } + return validateResourceTargetType(target, resource) +} + // validateResourceTargetType checks that target_type matches the actual network resource type. func validateResourceTargetType(target *service.Target, resource *resourcetypes.NetworkResource) error { expected := resourcetypes.NetworkResourceType(target.TargetType) diff --git a/management/internals/modules/reverseproxy/service/service_test.go b/management/internals/modules/reverseproxy/service/service_test.go index 9b7d0493841..3fe07b1d0fa 100644 --- a/management/internals/modules/reverseproxy/service/service_test.go +++ b/management/internals/modules/reverseproxy/service/service_test.go @@ -849,12 +849,13 @@ func TestValidate_TLSSubnetValid(t *testing.T) { func TestValidate_L4DomainTargetValid(t *testing.T) { modes := []struct { - mode string - port uint16 + mode string + port uint16 + proto string }{ - {"tcp", 5432}, - {"tls", 443}, - {"udp", 5432}, + {"tcp", 5432, "tcp"}, + {"tls", 443, "tcp"}, + {"udp", 5432, "udp"}, } for _, m := range modes { t.Run(m.mode, func(t *testing.T) { @@ -864,7 +865,7 @@ func TestValidate_L4DomainTargetValid(t *testing.T) { Domain: "cluster.test", ListenPort: m.port, Targets: []*Target{ - {TargetId: "resource-1", TargetType: TargetTypeDomain, Protocol: "tcp", Port: m.port, Enabled: true}, + {TargetId: "resource-1", TargetType: TargetTypeDomain, Protocol: m.proto, Port: m.port, Enabled: true}, }, } require.NoError(t, rp.Validate())