Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions internal/auth/internal/types/computed_traffic_permissions.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"github.com/hashicorp/consul/proto-public/pbresource"
)

type DecodedComputedTrafficPermissions = resource.DecodedResource[*pbauth.ComputedTrafficPermissions]

func RegisterComputedTrafficPermission(r resource.Registry) {
r.Register(resource.Registration{
Type: pbauth.ComputedTrafficPermissionsType,
Expand All @@ -26,16 +28,12 @@ func RegisterComputedTrafficPermission(r resource.Registry) {
})
}

func ValidateComputedTrafficPermissions(res *pbresource.Resource) error {
var ctp pbauth.ComputedTrafficPermissions

if err := res.Data.UnmarshalTo(&ctp); err != nil {
return resource.NewErrDataParse(&ctp, err)
}
var ValidateComputedTrafficPermissions = resource.DecodeAndValidate(validateComputedTrafficPermissions)

func validateComputedTrafficPermissions(res *DecodedComputedTrafficPermissions) error {
var merr error

for i, permission := range ctp.AllowPermissions {
for i, permission := range res.Data.AllowPermissions {
wrapErr := func(err error) error {
return resource.ErrInvalidListElement{
Name: "allow_permissions",
Expand All @@ -48,7 +46,7 @@ func ValidateComputedTrafficPermissions(res *pbresource.Resource) error {
}
}

for i, permission := range ctp.DenyPermissions {
for i, permission := range res.Data.DenyPermissions {
wrapErr := func(err error) error {
return resource.ErrInvalidListElement{
Name: "deny_permissions",
Expand Down
64 changes: 17 additions & 47 deletions internal/auth/internal/types/traffic_permissions.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ import (
"github.com/hashicorp/consul/proto-public/pbresource"
)

type DecodedTrafficPermissions = resource.DecodedResource[*pbauth.TrafficPermissions]

func RegisterTrafficPermissions(r resource.Registry) {
r.Register(resource.Registration{
Type: pbauth.TrafficPermissionsType,
Proto: &pbauth.TrafficPermissions{},
ACLs: &resource.ACLHooks{
Read: aclReadHookTrafficPermissions,
Write: aclWriteHookTrafficPermissions,
Read: resource.DecodeAndAuthorizeRead(aclReadHookTrafficPermissions),
Write: resource.DecodeAndAuthorizeWrite(aclWriteHookTrafficPermissions),
List: resource.NoOpACLListHook,
},
Validate: ValidateTrafficPermissions,
Expand All @@ -27,28 +29,20 @@ func RegisterTrafficPermissions(r resource.Registry) {
})
}

func MutateTrafficPermissions(res *pbresource.Resource) error {
var tp pbauth.TrafficPermissions

if err := res.Data.UnmarshalTo(&tp); err != nil {
return resource.NewErrDataParse(&tp, err)
}
var MutateTrafficPermissions = resource.DecodeAndMutate(mutateTrafficPermissions)

func mutateTrafficPermissions(res *DecodedTrafficPermissions) (bool, error) {
var changed bool

for _, p := range tp.Permissions {
for _, p := range res.Data.Permissions {
for _, s := range p.Sources {
if updated := normalizedTenancyForSource(s, res.Id.Tenancy); updated {
changed = true
}
}
}

if !changed {
return nil
}

return res.Data.MarshalFrom(&tp)
return changed, nil
}

func normalizedTenancyForSource(src *pbauth.Source, parentTenancy *pbresource.Tenancy) bool {
Expand Down Expand Up @@ -110,17 +104,13 @@ func firstNonEmptyString(a, b, c string) (string, bool) {
return c, true
}

func ValidateTrafficPermissions(res *pbresource.Resource) error {
var tp pbauth.TrafficPermissions

if err := res.Data.UnmarshalTo(&tp); err != nil {
return resource.NewErrDataParse(&tp, err)
}
var ValidateTrafficPermissions = resource.DecodeAndValidate(validateTrafficPermissions)

func validateTrafficPermissions(res *DecodedTrafficPermissions) error {
var merr error

// enumcover:pbauth.Action
switch tp.Action {
switch res.Data.Action {
case pbauth.Action_ACTION_ALLOW:
case pbauth.Action_ACTION_DENY:
case pbauth.Action_ACTION_UNSPECIFIED:
Expand All @@ -132,14 +122,14 @@ func ValidateTrafficPermissions(res *pbresource.Resource) error {
})
}

if tp.Destination == nil || (len(tp.Destination.IdentityName) == 0) {
if res.Data.Destination == nil || (len(res.Data.Destination.IdentityName) == 0) {
merr = multierror.Append(merr, resource.ErrInvalidField{
Name: "data.destination",
Wrapped: resource.ErrEmpty,
})
}
// Validate permissions
for i, permission := range tp.Permissions {
for i, permission := range res.Data.Permissions {
wrapErr := func(err error) error {
return resource.ErrInvalidListElement{
Name: "permissions",
Expand Down Expand Up @@ -271,30 +261,10 @@ func isLocalPeer(p string) bool {
return p == "local" || p == ""
}

func aclReadHookTrafficPermissions(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, _ *pbresource.ID, res *pbresource.Resource) error {
if res == nil {
return resource.ErrNeedResource
}
return authorizeDestination(res, func(dest string) error {
return authorizer.ToAllowAuthorizer().TrafficPermissionsReadAllowed(dest, authzContext)
})
func aclReadHookTrafficPermissions(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *DecodedTrafficPermissions) error {
return authorizer.ToAllowAuthorizer().TrafficPermissionsReadAllowed(res.Data.Destination.IdentityName, authzContext)
}

func aclWriteHookTrafficPermissions(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *pbresource.Resource) error {
return authorizeDestination(res, func(dest string) error {
return authorizer.ToAllowAuthorizer().TrafficPermissionsWriteAllowed(dest, authzContext)
})
}

func authorizeDestination(res *pbresource.Resource, intentionAllowed func(string) error) error {
tp, err := resource.Decode[*pbauth.TrafficPermissions](res)
if err != nil {
return err
}
// Check intention:x permissions for identity
err = intentionAllowed(tp.Data.Destination.IdentityName)
if err != nil {
return err
}
return nil
func aclWriteHookTrafficPermissions(authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *DecodedTrafficPermissions) error {
return authorizer.ToAllowAuthorizer().TrafficPermissionsWriteAllowed(res.Data.Destination.IdentityName, authzContext)
}
11 changes: 10 additions & 1 deletion internal/auth/internal/types/workload_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"github.com/hashicorp/consul/proto-public/pbresource"
)

type DecodedWorkloadIdentity = resource.DecodedResource[*pbauth.WorkloadIdentity]

func RegisterWorkloadIdentity(r resource.Registry) {
r.Register(resource.Registration{
Type: pbauth.WorkloadIdentityType,
Expand All @@ -20,10 +22,17 @@ func RegisterWorkloadIdentity(r resource.Registry) {
Write: aclWriteHookWorkloadIdentity,
List: resource.NoOpACLListHook,
},
Validate: nil,
Validate: ValidateWorkloadIdentity,
})
}

var ValidateWorkloadIdentity = resource.DecodeAndValidate(validateWorkloadIdentity)

func validateWorkloadIdentity(res *DecodedWorkloadIdentity) error {
// currently the WorkloadIdentity type has no fields.
return nil
}

func aclReadHookWorkloadIdentity(
authorizer acl.Authorizer,
authzCtx *acl.AuthorizerContext,
Expand Down
10 changes: 10 additions & 0 deletions internal/auth/internal/types/workload_identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,13 @@ func TestWorkloadIdentityACLs(t *testing.T) {
})
}
}

func TestWorkloadIdentity_ParseError(t *testing.T) {
rsc := resourcetest.Resource(pbauth.WorkloadIdentityType, "example").
WithData(t, &pbauth.TrafficPermissions{}).
Build()

err := ValidateWorkloadIdentity(rsc)
var parseErr resource.ErrDataParse
require.ErrorAs(t, err, &parseErr)
}
19 changes: 5 additions & 14 deletions internal/catalog/internal/types/acl_hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,22 @@ func aclReadHookResourceWithWorkloadSelector(authorizer acl.Authorizer, authzCon
return authorizer.ToAllowAuthorizer().ServiceReadAllowed(id.GetName(), authzContext)
}

func aclWriteHookResourceWithWorkloadSelector[T WorkloadSelecting](authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, res *pbresource.Resource) error {
if res == nil {
return resource.ErrNeedResource
}

decodedService, err := resource.Decode[T](res)
if err != nil {
return resource.ErrNeedResource
}

func aclWriteHookResourceWithWorkloadSelector[T WorkloadSelecting](authorizer acl.Authorizer, authzContext *acl.AuthorizerContext, r *resource.DecodedResource[T]) error {
// First check service:write on the name.
err = authorizer.ToAllowAuthorizer().ServiceWriteAllowed(res.GetId().GetName(), authzContext)
err := authorizer.ToAllowAuthorizer().ServiceWriteAllowed(r.GetId().GetName(), authzContext)
if err != nil {
return err
}

// Then also check whether we're allowed to select a service.
for _, name := range decodedService.GetData().GetWorkloads().GetNames() {
for _, name := range r.Data.GetWorkloads().GetNames() {
err = authorizer.ToAllowAuthorizer().ServiceReadAllowed(name, authzContext)
if err != nil {
return err
}
}

for _, prefix := range decodedService.GetData().GetWorkloads().GetPrefixes() {
for _, prefix := range r.Data.GetWorkloads().GetPrefixes() {
err = authorizer.ToAllowAuthorizer().ServiceReadPrefixAllowed(prefix, authzContext)
if err != nil {
return err
Expand All @@ -50,7 +41,7 @@ func aclWriteHookResourceWithWorkloadSelector[T WorkloadSelecting](authorizer ac
func ACLHooksForWorkloadSelectingType[T WorkloadSelecting]() *resource.ACLHooks {
return &resource.ACLHooks{
Read: aclReadHookResourceWithWorkloadSelector,
Write: aclWriteHookResourceWithWorkloadSelector[T],
Write: resource.DecodeAndAuthorizeWrite(aclWriteHookResourceWithWorkloadSelector[T]),
List: resource.NoOpACLListHook,
}
}
15 changes: 6 additions & 9 deletions internal/catalog/internal/types/dns_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ import (

"github.com/hashicorp/consul/internal/resource"
pbcatalog "github.com/hashicorp/consul/proto-public/pbcatalog/v2beta1"
"github.com/hashicorp/consul/proto-public/pbresource"
)

type DecodedDNSPolicy = resource.DecodedResource[*pbcatalog.DNSPolicy]

func RegisterDNSPolicy(r resource.Registry) {
r.Register(resource.Registration{
Type: pbcatalog.DNSPolicyType,
Expand All @@ -23,25 +24,21 @@ func RegisterDNSPolicy(r resource.Registry) {
})
}

func ValidateDNSPolicy(res *pbresource.Resource) error {
var policy pbcatalog.DNSPolicy

if err := res.Data.UnmarshalTo(&policy); err != nil {
return resource.NewErrDataParse(&policy, err)
}
var ValidateDNSPolicy = resource.DecodeAndValidate(validateDNSPolicy)

func validateDNSPolicy(res *DecodedDNSPolicy) error {
var err error
// Ensure that this resource isn't useless and is attempting to
// select at least one workload.
if selErr := ValidateSelector(policy.Workloads, false); selErr != nil {
if selErr := ValidateSelector(res.Data.Workloads, false); selErr != nil {
err = multierror.Append(err, resource.ErrInvalidField{
Name: "workloads",
Wrapped: selErr,
})
}

// Validate the weights
if weightErr := validateDNSPolicyWeights(policy.Weights); weightErr != nil {
if weightErr := validateDNSPolicyWeights(res.Data.Weights); weightErr != nil {
err = multierror.Append(err, resource.ErrInvalidField{
Name: "weights",
Wrapped: weightErr,
Expand Down
Loading