Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: refactor to allow for better testing #27

Merged
merged 3 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions advisor/pkg/k8s/labels.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

// DetectLabels detects the labels of a pod.
func DetectSelectorLabels(clientset *kubernetes.Clientset, origin interface{}) (map[string]string, error) {
func detectSelectorLabels(clientset *kubernetes.Clientset, origin interface{}) (map[string]string, error) {
// Use type assertion to check the specific type
switch o := origin.(type) {
case *v1.Pod:
Expand All @@ -23,7 +23,7 @@ func DetectSelectorLabels(clientset *kubernetes.Clientset, origin interface{}) (
svc = o.Service
return svc.Spec.Selector, nil
default:
return nil, fmt.Errorf("unknown type")
return nil, fmt.Errorf("detectSelectorLabels: unknown type")
}
}

Expand Down
8 changes: 4 additions & 4 deletions advisor/pkg/k8s/labels_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,18 @@ func TestDetectSelectorLabels(t *testing.T) {
},
}

labels1, err1 := DetectSelectorLabels(clientset, pod)
labels1, err1 := detectSelectorLabels(clientset, pod)
assert.NoError(t, err1)
assert.Equal(t, map[string]string{"app": "test-app"}, labels1)

labels2, err2 := DetectSelectorLabels(clientset, podDetail)
labels2, err2 := detectSelectorLabels(clientset, podDetail)
assert.NoError(t, err2)
assert.Equal(t, map[string]string{"app": "test-app"}, labels2)

labels3, err3 := DetectSelectorLabels(clientset, serviceDetail)
labels3, err3 := detectSelectorLabels(clientset, serviceDetail)
assert.NoError(t, err3)
assert.Equal(t, map[string]string{"app": "test-app"}, labels3)

_, err4 := DetectSelectorLabels(clientset, "unknown type")
_, err4 := detectSelectorLabels(clientset, "unknown type")
assert.Error(t, err4)
}
258 changes: 154 additions & 104 deletions advisor/pkg/k8s/networkpolicies.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package k8s

import (
"encoding/json"
"strings"

log "github.com/rs/zerolog/log"
api "github.com/xentra-ai/advisor/pkg/api"
networkingv1 "k8s.io/api/networking/v1"
Expand All @@ -27,137 +30,52 @@ type RuleSets struct {
}

func GenerateNetworkPolicy(podName string, config *Config) {

podTraffic, err := api.GetPodTraffic(podName)
if err != nil {
log.Fatal().Err(err).Msg("Error retrieving pod traffic")
return
}

if podTraffic == nil {
log.Fatal().Msgf("No pod traffic found for pod %s\n", podName)
return
}

podDetail, err := api.GetPodSpec(podTraffic[0].SrcIP)
if err != nil {
log.Fatal().Err(err).Msg("Error retrieving pod spec")
return
}

if podDetail == nil {
log.Fatal().Msgf("No pod spec found for pod %s\n", podTraffic[0].SrcIP)
return
}

policy := TransformToNetworkPolicy(&podTraffic, podDetail, config)
policy, err := transformToNetworkPolicy(podTraffic, podDetail, config)
if err != nil {
log.Error().Err(err).Msg("Error transforming policy")
}

policyYAML, err := yaml.Marshal(policy)
if err != nil {
log.Error().Err(err).Msg("Error converting policy to YAML")
return
}
log.Info().Msgf("Generated policy for pod %s:\n%s", podName, string(policyYAML))
}

func TransformToNetworkPolicy(podTraffic *[]api.PodTraffic, podDetail *api.PodDetail, config *Config) *networkingv1.NetworkPolicy {
var ingressRules []networkingv1.NetworkPolicyIngressRule
var egressRules []networkingv1.NetworkPolicyEgressRule

podSelectorLabels, err := DetectSelectorLabels(config.Clientset, &podDetail.Pod)
func transformToNetworkPolicy(podTraffic []api.PodTraffic, podDetail *api.PodDetail, config *Config) (*networkingv1.NetworkPolicy, error) {
ingressRulesRaw, err := processIngressRules(podTraffic, config)
if err != nil {
// This would mean a controller was detected but may no longer exist due to the pod being deleted but still present in the database
// TODO: Handle this case
log.Error().Err(err).Msg("Detect Pod Labels")
return nil
return nil, err
}
egressRulesRaw, err := processEgressRules(podTraffic, config)
if err != nil {
return nil, err
}

for _, traffic := range *podTraffic {
var origin interface{}

// Get pod spec for the pod that is sending traffic
podOrigin, err := api.GetPodSpec(traffic.DstIP)
if err != nil {
log.Error().Err(err).Msg("Get Pod Spec of origin")
}
if podOrigin != nil {
origin = podOrigin
}

// If we couldn't get the Pod details, try getting the Service details
if origin == nil {
svcOrigin, err := api.GetSvcSpec(traffic.DstIP)
if err != nil {
log.Error().Err(err).Msg("Get Svc Spec of origin")
continue
} else if svcOrigin != nil {
origin = svcOrigin
}
}

if origin == nil {
log.Debug().Msgf("Could not find details for origin assuming IP is external %s", traffic.DstIP)
}

var metadata metav1.ObjectMeta
var peerSelectorLabels map[string]string
var peer *networkingv1.NetworkPolicyPeer
// If the traffic originated from in-cluster as either a pod or service
if origin != nil {
peerSelectorLabels, err = DetectSelectorLabels(config.Clientset, origin)
if err != nil {
log.Error().Err(err).Msg("Detect Peer Labels")
continue
}
switch o := origin.(type) {
case *api.PodDetail:
metadata = o.Pod.ObjectMeta
case *api.SvcDetail:
metadata = o.Service.ObjectMeta
default:
log.Error().Msg("Unknown type for origin")
continue
}
peer = &networkingv1.NetworkPolicyPeer{
PodSelector: &metav1.LabelSelector{
MatchLabels: peerSelectorLabels,
},
NamespaceSelector: &metav1.LabelSelector{
MatchLabels: map[string]string{"kubernetes.io/metadata.name": metadata.Namespace},
},
}
} else {
peer = &networkingv1.NetworkPolicyPeer{
IPBlock: &networkingv1.IPBlock{
CIDR: traffic.DstIP + "/32",
},
}
}
ingressRules := deduplicateIngressRules(ingressRulesRaw)
egressRules := deduplicateEgressRules(egressRulesRaw)

protocol := traffic.Protocol
if traffic.TrafficType == "INGRESS" {
port := intstr.Parse(traffic.SrcPodPort)
ingressRules = append(ingressRules, networkingv1.NetworkPolicyIngressRule{
Ports: []networkingv1.NetworkPolicyPort{
{
Protocol: &protocol,
Port: &port,
},
},
From: []networkingv1.NetworkPolicyPeer{*peer},
})
} else if traffic.TrafficType == "EGRESS" {
port := intstr.Parse(traffic.DstPort)

egressRules = append(egressRules, networkingv1.NetworkPolicyEgressRule{
Ports: []networkingv1.NetworkPolicyPort{
{
Protocol: &protocol,
Port: &port,
},
},
To: []networkingv1.NetworkPolicyPeer{*peer},
})
}
podSelectorLabels, err := detectSelectorLabels(config.Clientset, &podDetail.Pod)
if err != nil {
return nil, err
}

networkPolicy := &networkingv1.NetworkPolicy{
Expand All @@ -168,7 +86,6 @@ func TransformToNetworkPolicy(podTraffic *[]api.PodTraffic, podDetail *api.PodDe
ObjectMeta: metav1.ObjectMeta{
Name: podDetail.Name,
Namespace: podDetail.Namespace,
// TODO: What labels should we use?
Labels: map[string]string{
"advisor.xentra.ai/managed-by": "xentra",
"advisor.xentra.ai/version": "0.0.1",
Expand All @@ -187,5 +104,138 @@ func TransformToNetworkPolicy(podTraffic *[]api.PodTraffic, podDetail *api.PodDe
},
}

return networkPolicy
return networkPolicy, nil
}

func processIngressRules(podTraffic []api.PodTraffic, config *Config) ([]networkingv1.NetworkPolicyIngressRule, error) {
var ingressRules []networkingv1.NetworkPolicyIngressRule
for _, traffic := range podTraffic {
if strings.ToUpper(traffic.TrafficType) != "INGRESS" {
continue
}
peer, err := determinePeerForTraffic(traffic, config)
if err != nil {
return nil, err
}
protocol := traffic.Protocol
portIntOrString := intstr.Parse(traffic.SrcPodPort)
portPtr := &portIntOrString
ingressRules = append(ingressRules, networkingv1.NetworkPolicyIngressRule{
Ports: []networkingv1.NetworkPolicyPort{
{
Protocol: &protocol,
Port: portPtr,
},
},
From: []networkingv1.NetworkPolicyPeer{*peer},
})
}
return ingressRules, nil
}

func processEgressRules(podTraffic []api.PodTraffic, config *Config) ([]networkingv1.NetworkPolicyEgressRule, error) {
var egressRules []networkingv1.NetworkPolicyEgressRule
for _, traffic := range podTraffic {
if strings.ToUpper(traffic.TrafficType) != "EGRESS" {
continue
}
peer, err := determinePeerForTraffic(traffic, config)
if err != nil {
return nil, err
}
protocol := traffic.Protocol
portIntOrString := intstr.Parse(traffic.DstPort)
portPtr := &portIntOrString
egressRules = append(egressRules, networkingv1.NetworkPolicyEgressRule{
Ports: []networkingv1.NetworkPolicyPort{
{
Protocol: &protocol,
Port: portPtr,
},
},
To: []networkingv1.NetworkPolicyPeer{*peer},
})
}
return egressRules, nil
}

func determinePeerForTraffic(traffic api.PodTraffic, config *Config) (*networkingv1.NetworkPolicyPeer, error) {
var origin interface{} = nil

podOrigin, err := api.GetPodSpec(traffic.DstIP)
if err != nil {
return nil, err
}
if podOrigin != nil {
origin = podOrigin
}

if origin == nil {
svcOrigin, err := api.GetSvcSpec(traffic.DstIP)
if err != nil {
return nil, err
}
if svcOrigin != nil {
origin = svcOrigin
}
}

if origin == nil {
log.Debug().Msgf("Could not find details for origin assuming IP is external %s", traffic.DstIP)
return &networkingv1.NetworkPolicyPeer{
IPBlock: &networkingv1.IPBlock{
CIDR: traffic.DstIP + "/32",
},
}, nil
}

peerSelectorLabels, err := detectSelectorLabels(config.Clientset, origin)
if err != nil {
return nil, err
}

var metadata metav1.ObjectMeta
switch o := origin.(type) {
case *api.PodDetail:
metadata = o.Pod.ObjectMeta
case *api.SvcDetail:
metadata = o.Service.ObjectMeta
}

return &networkingv1.NetworkPolicyPeer{
PodSelector: &metav1.LabelSelector{
MatchLabels: peerSelectorLabels,
},
NamespaceSelector: &metav1.LabelSelector{
MatchLabels: map[string]string{"kubernetes.io/metadata.name": metadata.Namespace},
},
}, nil
}

func deduplicateIngressRules(rules []networkingv1.NetworkPolicyIngressRule) []networkingv1.NetworkPolicyIngressRule {
seen := make(map[string]bool)
var deduplicated []networkingv1.NetworkPolicyIngressRule

for _, rule := range rules {
ruleStr, _ := json.Marshal(rule)
if !seen[string(ruleStr)] {
seen[string(ruleStr)] = true
deduplicated = append(deduplicated, rule)
}
}
return deduplicated
}

func deduplicateEgressRules(rules []networkingv1.NetworkPolicyEgressRule) []networkingv1.NetworkPolicyEgressRule {
seen := make(map[string]bool)
var deduplicated []networkingv1.NetworkPolicyEgressRule

for _, rule := range rules {
ruleStr, _ := json.Marshal(rule)
if !seen[string(ruleStr)] {
seen[string(ruleStr)] = true
deduplicated = append(deduplicated, rule)
}
}
return deduplicated
}