Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (

nbpeer "github.com/netbirdio/netbird/management/server/peer"

resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types"

"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
Expand Down Expand Up @@ -636,18 +638,12 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
for _, target := range targets {
switch target.TargetType {
case service.TargetTypePeer:
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
if err := validatePeerTarget(ctx, transaction, accountID, target); err != nil {
return err
}
case service.TargetTypeHost, service.TargetTypeSubnet, service.TargetTypeDomain:
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
if err := validateResourceTarget(ctx, transaction, accountID, target); err != nil {
return err
}
default:
return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId)
Expand All @@ -656,6 +652,39 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
return nil
}

func validatePeerTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error {
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
}
return nil
}

func validateResourceTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error {
resource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId)
if err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
}
return validateResourceTargetType(target, resource)
}

// validateResourceTargetType checks that target_type matches the actual network resource type.
func validateResourceTargetType(target *service.Target, resource *resourcetypes.NetworkResource) error {
expected := resourcetypes.NetworkResourceType(target.TargetType)
if resource.Type != expected {
return status.Errorf(status.InvalidArgument,
"target %q has target_type %q but resource is of type %q",
target.TargetId, target.TargetType, resource.Type,
)
}
return nil
}

func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server"
resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
Expand Down Expand Up @@ -1214,3 +1215,60 @@ func TestValidateProtocolChange(t *testing.T) {
})
}
}

func TestValidateTargetReferences_ResourceTypeMismatch(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
mockStore := store.NewMockStore(ctrl)
accountID := "test-account"

tests := []struct {
name string
targetType rpservice.TargetType
resourceType resourcetypes.NetworkResourceType
wantErr bool
}{
{"host matches host", rpservice.TargetTypeHost, resourcetypes.Host, false},
{"domain matches domain", rpservice.TargetTypeDomain, resourcetypes.Domain, false},
{"subnet matches subnet", rpservice.TargetTypeSubnet, resourcetypes.Subnet, false},
{"host but resource is domain", rpservice.TargetTypeHost, resourcetypes.Domain, true},
{"domain but resource is host", rpservice.TargetTypeDomain, resourcetypes.Host, true},
{"host but resource is subnet", rpservice.TargetTypeHost, resourcetypes.Subnet, true},
{"subnet but resource is domain", rpservice.TargetTypeSubnet, resourcetypes.Domain, true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockStore.EXPECT().
GetNetworkResourceByID(gomock.Any(), store.LockingStrengthShare, accountID, "resource-1").
Return(&resourcetypes.NetworkResource{Type: tt.resourceType}, nil)

targets := []*rpservice.Target{
{TargetId: "resource-1", TargetType: tt.targetType, Host: "10.0.0.1"},
}
err := validateTargetReferences(ctx, mockStore, accountID, targets)
if tt.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), "target_type")
} else {
require.NoError(t, err)
}
})
}
}

func TestValidateTargetReferences_PeerValid(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
mockStore := store.NewMockStore(ctrl)
accountID := "test-account"

mockStore.EXPECT().
GetPeerByID(gomock.Any(), store.LockingStrengthShare, accountID, "peer-1").
Return(&nbpeer.Peer{}, nil)

targets := []*rpservice.Target{
{TargetId: "peer-1", TargetType: rpservice.TargetTypePeer},
}
require.NoError(t, validateTargetReferences(ctx, mockStore, accountID, targets))
}
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,7 @@ func (s *Service) validateL4Target(target *Target) error {
return errors.New("target_id is required for L4 services")
}
switch target.TargetType {
case TargetTypePeer, TargetTypeHost:
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
// OK
case TargetTypeSubnet:
if target.Host == "" {
Expand Down
26 changes: 26 additions & 0 deletions management/internals/modules/reverseproxy/service/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,32 @@ func TestValidate_TLSSubnetValid(t *testing.T) {
require.NoError(t, rp.Validate())
}

func TestValidate_L4DomainTargetValid(t *testing.T) {
modes := []struct {
mode string
port uint16
proto string
}{
{"tcp", 5432, "tcp"},
{"tls", 443, "tcp"},
{"udp", 5432, "udp"},
}
for _, m := range modes {
t.Run(m.mode, func(t *testing.T) {
rp := &Service{
Name: m.mode + "-domain",
Mode: m.mode,
Domain: "cluster.test",
ListenPort: m.port,
Targets: []*Target{
{TargetId: "resource-1", TargetType: TargetTypeDomain, Protocol: m.proto, Port: m.port, Enabled: true},
},
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}
require.NoError(t, rp.Validate())
})
}
}

func TestValidate_HTTPProxyProtocolRejected(t *testing.T) {
rp := validProxy()
rp.Targets[0].ProxyProtocol = true
Expand Down
Loading