From f26025c68856d21e19e5b7cb2147089d691fa024 Mon Sep 17 00:00:00 2001 From: Amit Schendel Date: Tue, 29 Oct 2024 18:16:20 +0000 Subject: [PATCH 01/14] Adding initial code Signed-off-by: Amit Schendel --- .../process_manager_interface.go | 18 + pkg/processmanager/process_manager_mock.go | 28 ++ pkg/processmanager/v1/process_manager.go | 423 ++++++++++++++++++ pkg/processmanager/v1/process_manager_test.go | 357 +++++++++++++++ 4 files changed, 826 insertions(+) create mode 100644 pkg/processmanager/process_manager_interface.go create mode 100644 pkg/processmanager/process_manager_mock.go create mode 100644 pkg/processmanager/v1/process_manager.go create mode 100644 pkg/processmanager/v1/process_manager_test.go diff --git a/pkg/processmanager/process_manager_interface.go b/pkg/processmanager/process_manager_interface.go new file mode 100644 index 00000000..bed86cff --- /dev/null +++ b/pkg/processmanager/process_manager_interface.go @@ -0,0 +1,18 @@ +package processmanager + +import ( + apitypes "github.com/armosec/armoapi-go/armotypes" + containercollection "github.com/inspektor-gadget/inspektor-gadget/pkg/container-collection" + "github.com/kubescape/node-agent/pkg/utils" +) + +// ProcessManagerClient is the interface for the process manager client. +// It provides methods to get process tree for a container or a PID. +// The manager is responsible for maintaining the process tree for all containers. +type ProcessManagerClient interface { + GetProcessTreeForPID(containerID string, pid int) (apitypes.Process, error) + + // ReportEvent will be called to report new exec events to the process manager. + ReportEvent(eventType utils.EventType, event utils.K8sEvent) + ContainerCallback(notif containercollection.PubSubEvent) +} diff --git a/pkg/processmanager/process_manager_mock.go b/pkg/processmanager/process_manager_mock.go new file mode 100644 index 00000000..43213fba --- /dev/null +++ b/pkg/processmanager/process_manager_mock.go @@ -0,0 +1,28 @@ +package processmanager + +import ( + apitypes "github.com/armosec/armoapi-go/armotypes" + containercollection "github.com/inspektor-gadget/inspektor-gadget/pkg/container-collection" + "github.com/kubescape/node-agent/pkg/utils" +) + +type ProcessManagerMock struct { +} + +var _ ProcessManagerClient = (*ProcessManagerMock)(nil) + +func CreateProcessManagerMock() *ProcessManagerMock { + return &ProcessManagerMock{} +} + +func (p *ProcessManagerMock) GetProcessTreeForPID(containerID string, pid int) (apitypes.Process, error) { + return apitypes.Process{}, nil +} + +func (p *ProcessManagerMock) ReportEvent(eventType utils.EventType, event utils.K8sEvent) { + // no-op +} + +func (p *ProcessManagerMock) ContainerCallback(notif containercollection.PubSubEvent) { + // no-op +} diff --git a/pkg/processmanager/v1/process_manager.go b/pkg/processmanager/v1/process_manager.go new file mode 100644 index 00000000..d46dee03 --- /dev/null +++ b/pkg/processmanager/v1/process_manager.go @@ -0,0 +1,423 @@ +package processmanager + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/goradd/maps" + containercollection "github.com/inspektor-gadget/inspektor-gadget/pkg/container-collection" + "github.com/prometheus/procfs" + + apitypes "github.com/armosec/armoapi-go/armotypes" + tracerexectype "github.com/inspektor-gadget/inspektor-gadget/pkg/gadgets/trace/exec/types" + "github.com/kubescape/go-logger" + "github.com/kubescape/go-logger/helpers" + "github.com/kubescape/node-agent/pkg/utils" +) + +const ( + cleanupInterval = 1 * time.Minute + maxTreeDepth = 50 // Prevent infinite recursion +) + +// ProcessManager manages container processes and their relationships +type ProcessManager struct { + containerIdToShimPid maps.SafeMap[string, uint32] // containerID -> shim pid + processTree maps.SafeMap[uint32, apitypes.Process] // pid -> process +} + +// CreateProcessManager creates a new ProcessManager instance +func CreateProcessManager(ctx context.Context) *ProcessManager { + pm := &ProcessManager{} + + // // Do initial process scan during initialization + // if err := pm.initialProcScan(); err != nil { + // logger.L().Warning("Failed initial process scan", helpers.Error(err)) + // } + + // Start cleanup routine + go pm.startCleanupRoutine(ctx) + + return pm +} + +// InitialProcScan performs a one-time scan of /proc to build the initial process tree +// Only processes that are descendants of existing container shims will be added +func (p *ProcessManager) InitialProcScan() error { + // If we have no shims registered, nothing to do + if len(p.containerIdToShimPid.Keys()) == 0 { + return nil + } + + fs, err := procfs.NewFS("/proc") + if err != nil { + return fmt.Errorf("failed to open procfs: %v", err) + } + + procs, err := fs.AllProcs() + if err != nil { + return fmt.Errorf("failed to read all procs: %v", err) + } + + // First pass: collect all processes + tempProcesses := make(map[uint32]apitypes.Process) + for _, proc := range procs { + if process, err := p.getProcessFromProc(proc.PID); err == nil { + tempProcesses[process.PID] = process + } + } + + // Second pass: identify shim descendants and build relationships + shimDescendants := make(map[uint32]bool) + for pid, process := range tempProcesses { + // Check if this process is a descendant of any registered shim + if p.isDescendantOfShim(pid, make(map[uint32]bool)) { + shimDescendants[pid] = true + + // Also mark all ancestors up to the shim as descendants + currentPID := process.PPID + visited := make(map[uint32]bool) + for currentPID != 0 && !visited[currentPID] { + visited[currentPID] = true + if proc, exists := tempProcesses[currentPID]; exists { + shimDescendants[currentPID] = true + currentPID = proc.PPID + } else { + break + } + } + } + } + + // Final pass: add only shim-related processes and build relationships + for pid, process := range tempProcesses { + if !shimDescendants[pid] { + continue // Skip processes not related to shims + } + + // If parent exists and is also a shim descendant, update parent's children + if parent, exists := tempProcesses[process.PPID]; exists && shimDescendants[process.PPID] { + parent.Children = append(parent.Children, process) + tempProcesses[process.PPID] = parent + } + + // Add process to the tree + p.processTree.Set(pid, process) + } + + logger.L().Debug("Initial process scan completed", + helpers.Int("total_processes", len(tempProcesses)), + helpers.Int("shim_related_processes", len(shimDescendants))) + + return nil +} + +// addProcess adds a process to the tree and updates relationships +func (p *ProcessManager) addProcess(process apitypes.Process) { + // Add the process to the tree + p.processTree.Set(process.PID, process) + + // Update parent's children list if parent exists + if parent, exists := p.processTree.Load(process.PPID); exists { + // Create new children slice to avoid modifying existing one + newChildren := make([]apitypes.Process, 0, len(parent.Children)+1) + // Add existing children, excluding any old version of this process + for _, child := range parent.Children { + if child.PID != process.PID { + newChildren = append(newChildren, child) + } + } + // Add the new process + newChildren = append(newChildren, process) + parent.Children = newChildren + p.processTree.Set(parent.PID, parent) + } +} + +// removeProcess removes a process and updates relationships +func (p *ProcessManager) removeProcess(pid uint32) { + // Get the process before removing it + if process, exists := p.processTree.Load(pid); exists { + // Update parent's children list + if parent, exists := p.processTree.Load(process.PPID); exists { + newChildren := make([]apitypes.Process, 0, len(parent.Children)) + for _, child := range parent.Children { + if child.PID != pid { + newChildren = append(newChildren, child) + } + } + parent.Children = newChildren + p.processTree.Set(parent.PID, parent) + } + + // Reassign children to nearest living ancestor + for _, child := range process.Children { + if newProcess, exists := p.processTree.Load(child.PID); exists { + newProcess.PPID = process.PPID + p.addProcess(newProcess) // This will update the new parent's children list + } + } + + // Finally remove the process + p.processTree.Delete(pid) + } +} + +// isDescendantOfShim checks if a process is a descendant of any container shim +func (p *ProcessManager) isDescendantOfShim(pid uint32, visited map[uint32]bool) bool { + if pid == 0 || len(visited) > maxTreeDepth { + return false + } + + if visited[pid] { + return false // Avoid cycles + } + visited[pid] = true + + // Check if this pid is a shim + isShim := false + p.containerIdToShimPid.Range(func(_ string, shimPid uint32) bool { + if shimPid == pid { + isShim = true + return false // Stop ranging + } + return true + }) + + if isShim { + return true + } + + // Check parent if process exists + if process, exists := p.processTree.Load(pid); exists { + return p.isDescendantOfShim(process.PPID, visited) + } + + return false +} + +func (p *ProcessManager) startCleanupRoutine(ctx context.Context) { + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + p.cleanup() + case <-ctx.Done(): + return + } + } +} + +func (p *ProcessManager) cleanup() { + // First pass: identify dead processes + deadPids := make(map[uint32]bool) + p.processTree.Range(func(pid uint32, _ apitypes.Process) bool { + if !isProcessAlive(int(pid)) { + deadPids[pid] = true + } + return true + }) + + // Second pass: remove dead processes and update relationships + for pid := range deadPids { + p.removeProcess(pid) + } +} + +func (p *ProcessManager) ContainerCallback(notif containercollection.PubSubEvent) { + switch notif.Type { + case containercollection.EventTypeAddContainer: + containerID := notif.Container.Runtime.ContainerID + shimPID := uint32(notif.Container.Pid) + + // Store the shim PID + p.containerIdToShimPid.Set(containerID, shimPID) + + // If the shim process isn't in our tree yet (might have started after initial scan), + // add it directly + if !p.processTree.Has(shimPID) { + if process, err := p.getProcessFromProc(int(shimPID)); err == nil { + p.addProcess(process) + } else { + logger.L().Debug("Failed to get shim process info", + helpers.String("containerID", containerID), + helpers.Error(err)) + } + } + + case containercollection.EventTypeRemoveContainer: + containerID := notif.Container.Runtime.ContainerID + + // Get shim PID before removing the mapping + if shimPID, exists := p.containerIdToShimPid.Load(containerID); exists { + // Remove all descendants of this shim + descendants := make(map[uint32]bool) + p.processTree.Range(func(pid uint32, process apitypes.Process) bool { + if p.isDescendantOfShim(pid, make(map[uint32]bool)) { + descendants[pid] = true + } + return true + }) + + // Remove descendants in reverse order (children before parents) + for pid := range descendants { + p.removeProcess(pid) + } + + // Finally remove the shim itself + p.removeProcess(shimPID) + } + + // Remove the container mapping + p.containerIdToShimPid.Delete(containerID) + } +} + +func (p *ProcessManager) GetProcessTreeForPID(containerID string, pid int) (apitypes.Process, error) { + if !p.containerIdToShimPid.Has(containerID) { + return apitypes.Process{}, fmt.Errorf("container ID %s not found", containerID) + } + + shimPID := p.containerIdToShimPid.Get(containerID) + targetPID := uint32(pid) + + // If process doesn't exist in our tree, try to fetch it + if !p.processTree.Has(targetPID) { + // Only fetch if it's a descendant of our shim + if process, err := p.getProcessFromProc(int(targetPID)); err == nil { + p.addProcess(process) + } else { + return apitypes.Process{}, fmt.Errorf("process %d not found: %v", pid, err) + } + } + + // Build process tree from target up to shim + processes := make([]apitypes.Process, 0) + currentPID := targetPID + seen := make(map[uint32]bool) + + // Collect all processes up to shim + for currentPID != 0 && currentPID != shimPID { + if seen[currentPID] { + break // Avoid cycles + } + seen[currentPID] = true + + if proc, exists := p.processTree.Load(currentPID); exists { + processes = append([]apitypes.Process{proc}, processes...) // Prepend + currentPID = proc.PPID + } else { + break + } + } + + // No processes found or invalid tree + if len(processes) == 0 { + return apitypes.Process{}, fmt.Errorf("could not build process tree for pid %d", pid) + } + + // Build the tree structure + result := processes[0] + current := &result + + // Link processes together + for i := 1; i < len(processes); i++ { + child := processes[i] + current.Children = []apitypes.Process{child} + current = ¤t.Children[0] + } + + // Add the target process as the final leaf if it's not already in the chain + if current.PID != targetPID { + if targetProc, exists := p.processTree.Load(targetPID); exists { + current.Children = []apitypes.Process{targetProc} + } + } + + return result, nil +} + +func (p *ProcessManager) ReportEvent(eventType utils.EventType, event utils.K8sEvent) { + if eventType != utils.ExecveEventType { + return + } + + execEvent, ok := event.(*tracerexectype.Event) + if !ok { + return + } + + // Create new process from event + process := apitypes.Process{ + PID: execEvent.Pid, + PPID: execEvent.Ppid, + Comm: execEvent.Comm, + Uid: &execEvent.Uid, + Gid: &execEvent.Gid, + Cmdline: strings.Join(execEvent.Args, " "), + } + + p.addProcess(process) +} + +func (p *ProcessManager) getProcessFromProc(pid int) (apitypes.Process, error) { + proc, err := procfs.NewProc(pid) + if err != nil { + return apitypes.Process{}, fmt.Errorf("failed to get process info: %v", err) + } + + stat, err := utils.GetProcessStat(pid) + if err != nil { + return apitypes.Process{}, fmt.Errorf("failed to get process stat: %v", err) + } + + // Get process details + var uid, gid uint32 + if status, err := proc.NewStatus(); err == nil { + if len(status.UIDs) > 1 { + uid = uint32(status.UIDs[1]) + } + if len(status.GIDs) > 1 { + gid = uint32(status.GIDs[1]) + } + } + + cmdline, err := proc.CmdLine() + if err != nil { + cmdline = []string{stat.Comm} // Fallback to comm if cmdline fails + } + + cwd, err := proc.Cwd() + if err != nil { + cwd = "" // Empty string if we can't get cwd + } + + path, err := proc.Executable() + if err != nil { + path = "" // Empty string if we can't get executable path + } + + return apitypes.Process{ + PID: uint32(pid), + PPID: uint32(stat.PPID), + Comm: stat.Comm, + Uid: &uid, + Gid: &gid, + Cmdline: strings.Join(cmdline, " "), + Cwd: cwd, + Path: path, + }, nil +} + +func isProcessAlive(pid int) bool { + proc, err := procfs.NewProc(pid) + if err != nil { + return false + } + _, err = proc.Stat() + return err == nil +} diff --git a/pkg/processmanager/v1/process_manager_test.go b/pkg/processmanager/v1/process_manager_test.go new file mode 100644 index 00000000..ab22968b --- /dev/null +++ b/pkg/processmanager/v1/process_manager_test.go @@ -0,0 +1,357 @@ +package processmanager + +import ( + "context" + "testing" + "time" + + apitypes "github.com/armosec/armoapi-go/armotypes" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + containercollection "github.com/inspektor-gadget/inspektor-gadget/pkg/container-collection" + tracerexectype "github.com/inspektor-gadget/inspektor-gadget/pkg/gadgets/trace/exec/types" + "github.com/inspektor-gadget/inspektor-gadget/pkg/types" + "github.com/kubescape/node-agent/pkg/utils" +) + +// mockProcessInfo helps simulate process information for testing +type mockProcessInfo struct { + pid uint32 + ppid uint32 + comm string + cmdline string +} + +func TestProcessManagerBasics(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pm := CreateProcessManager(ctx) + require.NotNil(t, pm) + + // Test container creation + containerID := "test-container-1" + shimPID := uint32(1000) + + // Simulate container creation + pm.ContainerCallback(containercollection.PubSubEvent{ + Type: containercollection.EventTypeAddContainer, + Container: &containercollection.Container{ + Runtime: containercollection.RuntimeMetadata{ + BasicRuntimeMetadata: types.BasicRuntimeMetadata{ + ContainerID: containerID, + }, + }, + Pid: shimPID, + }, + }) + + // Verify shim PID was recorded + assert.True(t, pm.containerIdToShimPid.Has(containerID)) + assert.Equal(t, shimPID, pm.containerIdToShimPid.Get(containerID)) +} + +func TestProcessTracking(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pm := CreateProcessManager(ctx) + containerID := "test-container-1" + shimPID := uint32(1000) + + // Simulate container creation + pm.ContainerCallback(containercollection.PubSubEvent{ + Type: containercollection.EventTypeAddContainer, + Container: &containercollection.Container{ + Runtime: containercollection.RuntimeMetadata{ + BasicRuntimeMetadata: types.BasicRuntimeMetadata{ + ContainerID: containerID, + }, + }, + Pid: shimPID, + }, + }) + + // Simulate process creation events + testCases := []struct { + name string + event tracerexectype.Event + verify func(t *testing.T, pm *ProcessManager) + }{ + { + name: "Direct child of shim", + event: tracerexectype.Event{ + Pid: 1001, + Ppid: shimPID, + Comm: "nginx", + Args: []string{"nginx", "-g", "daemon off;"}, + }, + verify: func(t *testing.T, pm *ProcessManager) { + proc, exists := pm.processTree.Load(1001) + require.True(t, exists) + assert.Equal(t, uint32(1001), proc.PID) + assert.Equal(t, shimPID, proc.PPID) + assert.Equal(t, "nginx", proc.Comm) + }, + }, + { + name: "Child of nginx", + event: tracerexectype.Event{ + Pid: 1002, + Ppid: 1001, + Comm: "nginx-worker", + Args: []string{"nginx", "worker process"}, + }, + verify: func(t *testing.T, pm *ProcessManager) { + proc, exists := pm.processTree.Load(1002) + require.True(t, exists) + assert.Equal(t, uint32(1002), proc.PID) + assert.Equal(t, uint32(1001), proc.PPID) + + // Verify parent's children list + parent, exists := pm.processTree.Load(1001) + require.True(t, exists) + found := false + for _, child := range parent.Children { + if child.PID == 1002 { + found = true + break + } + } + assert.True(t, found, "Child process should be in parent's children list") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + pm.ReportEvent(utils.ExecveEventType, &tc.event) + tc.verify(t, pm) + }) + } +} + +func TestProcessCleanup(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pm := CreateProcessManager(ctx) + containerID := "test-container-1" + shimPID := uint32(1000) + + // Simulate container creation + pm.ContainerCallback(containercollection.PubSubEvent{ + Type: containercollection.EventTypeAddContainer, + Container: &containercollection.Container{ + Runtime: containercollection.RuntimeMetadata{ + BasicRuntimeMetadata: types.BasicRuntimeMetadata{ + ContainerID: containerID, + }, + }, + Pid: shimPID, + }, + }) + + // Add some test processes + processes := []mockProcessInfo{ + {pid: 1001, ppid: shimPID, comm: "parent", cmdline: "./parent"}, + {pid: 1002, ppid: 1001, comm: "child1", cmdline: "./child1"}, + {pid: 1003, ppid: 1001, comm: "child2", cmdline: "./child2"}, + } + + for _, proc := range processes { + event := tracerexectype.Event{ + Pid: proc.pid, + Ppid: proc.ppid, + Comm: proc.comm, + Args: []string{proc.cmdline}, + } + pm.ReportEvent(utils.ExecveEventType, &event) + } + + // Verify initial state + for _, proc := range processes { + assert.True(t, pm.processTree.Has(proc.pid)) + } + + // Simulate container removal + pm.ContainerCallback(containercollection.PubSubEvent{ + Type: containercollection.EventTypeRemoveContainer, + Container: &containercollection.Container{ + Runtime: containercollection.RuntimeMetadata{ + BasicRuntimeMetadata: types.BasicRuntimeMetadata{ + ContainerID: containerID, + }, + }, + Pid: shimPID, + }, + }) + + // Verify cleanup + for _, proc := range processes { + assert.False(t, pm.processTree.Has(proc.pid), + "Process %d should be removed after container cleanup", proc.pid) + } + assert.False(t, pm.containerIdToShimPid.Has(containerID)) +} + +func TestGetProcessTree(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pm := CreateProcessManager(ctx) + containerID := "test-container-1" + shimPID := uint32(1000) + + // Setup container + pm.ContainerCallback(containercollection.PubSubEvent{ + Type: containercollection.EventTypeAddContainer, + Container: &containercollection.Container{ + Runtime: containercollection.RuntimeMetadata{ + BasicRuntimeMetadata: types.BasicRuntimeMetadata{ + ContainerID: containerID, + }, + }, + Pid: shimPID, + }, + }) + + // Create a process tree: + // shim (1000) + // └── parent (1001) + // ├── child1 (1002) + // └── child2 (1003) + // └── grandchild (1004) + + processes := []mockProcessInfo{ + {pid: 1001, ppid: shimPID, comm: "parent", cmdline: "./parent"}, + {pid: 1002, ppid: 1001, comm: "child1", cmdline: "./child1"}, + {pid: 1003, ppid: 1001, comm: "child2", cmdline: "./child2"}, + {pid: 1004, ppid: 1003, comm: "grandchild", cmdline: "./grandchild"}, + } + + // Add processes to tree + for _, proc := range processes { + event := &tracerexectype.Event{ + Pid: proc.pid, + Ppid: proc.ppid, + Comm: proc.comm, + Args: []string{proc.cmdline}, + } + pm.ReportEvent(utils.ExecveEventType, event) + } + + // Get and verify process tree for grandchild + tree, err := pm.GetProcessTreeForPID(containerID, 1004) + require.NoError(t, err) + + // Helper function to find a process in the tree + var findProcess func(apitypes.Process, uint32) *apitypes.Process + findProcess = func(node apitypes.Process, targetPID uint32) *apitypes.Process { + if node.PID == targetPID { + return &node + } + for _, child := range node.Children { + if found := findProcess(child, targetPID); found != nil { + return found + } + } + return nil + } + + // Verify tree structure + t.Run("verify tree structure", func(t *testing.T) { + // Check the chain from grandchild up to parent + current := &tree + expectedChain := []uint32{1001, 1003, 1004} // parent -> child2 -> grandchild + + for i, expectedPID := range expectedChain { + assert.Equal(t, expectedPID, current.PID, "Mismatch at chain position %d", i) + if i < len(expectedChain)-1 { + require.Len(t, current.Children, 1, "Expected exactly one child at position %d", i) + current = ¤t.Children[0] + } + } + }) + + // Verify process details + t.Run("verify process details", func(t *testing.T) { + for _, expected := range processes { + proc := findProcess(tree, expected.pid) + if proc != nil { + assert.Equal(t, expected.ppid, proc.PPID) + assert.Equal(t, expected.comm, proc.Comm) + assert.Equal(t, expected.cmdline, proc.Cmdline) + } + } + }) +} + +func TestContainerLifecycle(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + pm := CreateProcessManager(ctx) + + containers := []struct { + id string + shimPID uint32 + }{ + {"container-1", 1000}, + {"container-2", 2000}, + } + + // Add containers + for _, c := range containers { + pm.ContainerCallback(containercollection.PubSubEvent{ + Type: containercollection.EventTypeAddContainer, + Container: &containercollection.Container{ + Runtime: containercollection.RuntimeMetadata{ + BasicRuntimeMetadata: types.BasicRuntimeMetadata{ + ContainerID: c.id, + }, + }, + Pid: c.shimPID, + }, + }) + } + + // Verify containers were added + for _, c := range containers { + assert.True(t, pm.containerIdToShimPid.Has(c.id)) + assert.Equal(t, c.shimPID, pm.containerIdToShimPid.Get(c.id)) + } + + // Remove containers + for _, c := range containers { + pm.ContainerCallback(containercollection.PubSubEvent{ + Type: containercollection.EventTypeRemoveContainer, + Container: &containercollection.Container{ + Runtime: containercollection.RuntimeMetadata{ + BasicRuntimeMetadata: types.BasicRuntimeMetadata{ + ContainerID: c.id, + }, + }, + Pid: c.shimPID, + }, + }) + + // Verify container was removed + assert.False(t, pm.containerIdToShimPid.Has(c.id)) + } +} + +func TestCleanupRoutine(t *testing.T) { + // This is a bit tricky to test since we can't easily simulate dead processes + // But we can test that the cleanup routine runs without errors + ctx, cancel := context.WithCancel(context.Background()) + CreateProcessManager(ctx) + + // Let it run for a short while + time.Sleep(2 * time.Second) + + // Cancel and ensure it shuts down cleanly + cancel() + time.Sleep(100 * time.Millisecond) +} From 45a4fc0c351c61ad09d620878e7424a1540381e5 Mon Sep 17 00:00:00 2001 From: Amit Schendel Date: Tue, 29 Oct 2024 19:36:56 +0000 Subject: [PATCH 02/14] Adding updated code Signed-off-by: Amit Schendel --- pkg/processmanager/v1/process_manager.go | 400 +++++-------- pkg/processmanager/v1/process_manager_test.go | 552 +++++++++++++----- 2 files changed, 550 insertions(+), 402 deletions(-) diff --git a/pkg/processmanager/v1/process_manager.go b/pkg/processmanager/v1/process_manager.go index d46dee03..5c286bb4 100644 --- a/pkg/processmanager/v1/process_manager.go +++ b/pkg/processmanager/v1/process_manager.go @@ -19,128 +19,167 @@ import ( const ( cleanupInterval = 1 * time.Minute - maxTreeDepth = 50 // Prevent infinite recursion + maxTreeDepth = 50 ) -// ProcessManager manages container processes and their relationships type ProcessManager struct { - containerIdToShimPid maps.SafeMap[string, uint32] // containerID -> shim pid - processTree maps.SafeMap[uint32, apitypes.Process] // pid -> process + containerIdToShimPid maps.SafeMap[string, uint32] + processTree maps.SafeMap[uint32, apitypes.Process] + // For testing purposes we allow to override the function that gets process info from /proc. + getProcessFromProc func(pid int) (apitypes.Process, error) } -// CreateProcessManager creates a new ProcessManager instance func CreateProcessManager(ctx context.Context) *ProcessManager { - pm := &ProcessManager{} - - // // Do initial process scan during initialization - // if err := pm.initialProcScan(); err != nil { - // logger.L().Warning("Failed initial process scan", helpers.Error(err)) - // } - - // Start cleanup routine + pm := &ProcessManager{ + getProcessFromProc: getProcessFromProc, + } go pm.startCleanupRoutine(ctx) - return pm } -// InitialProcScan performs a one-time scan of /proc to build the initial process tree -// Only processes that are descendants of existing container shims will be added func (p *ProcessManager) InitialProcScan() error { - // If we have no shims registered, nothing to do if len(p.containerIdToShimPid.Keys()) == 0 { return nil } fs, err := procfs.NewFS("/proc") if err != nil { - return fmt.Errorf("failed to open procfs: %v", err) + return fmt.Errorf("failed to open procfs: %w", err) } procs, err := fs.AllProcs() if err != nil { - return fmt.Errorf("failed to read all procs: %v", err) + return fmt.Errorf("failed to read all procs: %w", err) } - // First pass: collect all processes - tempProcesses := make(map[uint32]apitypes.Process) + tempProcesses := make(map[uint32]apitypes.Process, len(procs)) + shimPIDs := make(map[uint32]struct{}) + + p.containerIdToShimPid.Range(func(_ string, shimPID uint32) bool { + shimPIDs[shimPID] = struct{}{} + return true + }) + + // First collect all processes for _, proc := range procs { if process, err := p.getProcessFromProc(proc.PID); err == nil { tempProcesses[process.PID] = process } } - // Second pass: identify shim descendants and build relationships - shimDescendants := make(map[uint32]bool) + // Then build relationships and add to tree for pid, process := range tempProcesses { - // Check if this process is a descendant of any registered shim - if p.isDescendantOfShim(pid, make(map[uint32]bool)) { - shimDescendants[pid] = true - - // Also mark all ancestors up to the shim as descendants - currentPID := process.PPID - visited := make(map[uint32]bool) - for currentPID != 0 && !visited[currentPID] { - visited[currentPID] = true - if proc, exists := tempProcesses[currentPID]; exists { - shimDescendants[currentPID] = true - currentPID = proc.PPID - } else { - break - } + if p.isDescendantOfShim(pid, process.PPID, shimPIDs, tempProcesses) { + if parent, exists := tempProcesses[process.PPID]; exists { + parent.Children = append(parent.Children, process) + tempProcesses[process.PPID] = parent } + p.processTree.Set(pid, process) } } - // Final pass: add only shim-related processes and build relationships - for pid, process := range tempProcesses { - if !shimDescendants[pid] { - continue // Skip processes not related to shims + return nil +} + +func (p *ProcessManager) isDescendantOfShim(pid uint32, ppid uint32, shimPIDs map[uint32]struct{}, processes map[uint32]apitypes.Process) bool { + visited := make(map[uint32]bool) + currentPID := pid + for depth := 0; depth < maxTreeDepth; depth++ { + if currentPID == 0 || visited[currentPID] { + return false } + visited[currentPID] = true - // If parent exists and is also a shim descendant, update parent's children - if parent, exists := tempProcesses[process.PPID]; exists && shimDescendants[process.PPID] { - parent.Children = append(parent.Children, process) - tempProcesses[process.PPID] = parent + if _, isShim := shimPIDs[ppid]; isShim { + return true } - // Add process to the tree - p.processTree.Set(pid, process) + process, exists := processes[ppid] + if !exists { + return false + } + currentPID = ppid + ppid = process.PPID } + return false +} - logger.L().Debug("Initial process scan completed", - helpers.Int("total_processes", len(tempProcesses)), - helpers.Int("shim_related_processes", len(shimDescendants))) +func (p *ProcessManager) ContainerCallback(notif containercollection.PubSubEvent) { + containerID := notif.Container.Runtime.BasicRuntimeMetadata.ContainerID - return nil + switch notif.Type { + case containercollection.EventTypeAddContainer: + containerPID := uint32(notif.Container.Pid) + if process, err := p.getProcessFromProc(int(containerPID)); err == nil { + shimPID := process.PPID + p.containerIdToShimPid.Set(containerID, shimPID) + p.addProcess(process) + } else { + logger.L().Warning("Failed to get container process info", + helpers.String("containerID", containerID), + helpers.Error(err)) + } + + case containercollection.EventTypeRemoveContainer: + if shimPID, exists := p.containerIdToShimPid.Load(containerID); exists { + p.removeProcessesUnderShim(shimPID) + p.containerIdToShimPid.Delete(containerID) + } + } +} + +func (p *ProcessManager) removeProcessesUnderShim(shimPID uint32) { + var pidsToRemove []uint32 + + p.processTree.Range(func(pid uint32, process apitypes.Process) bool { + currentPID := pid + visited := make(map[uint32]bool) + + for currentPID != 0 && !visited[currentPID] { + visited[currentPID] = true + if proc, exists := p.processTree.Load(currentPID); exists { + if proc.PPID == shimPID { + pidsToRemove = append(pidsToRemove, pid) + break + } + currentPID = proc.PPID + } else { + break + } + } + return true + }) + + // Remove in reverse order to handle parent-child relationships + for i := len(pidsToRemove) - 1; i >= 0; i-- { + p.removeProcess(pidsToRemove[i]) + } } -// addProcess adds a process to the tree and updates relationships func (p *ProcessManager) addProcess(process apitypes.Process) { - // Add the process to the tree p.processTree.Set(process.PID, process) - // Update parent's children list if parent exists if parent, exists := p.processTree.Load(process.PPID); exists { - // Create new children slice to avoid modifying existing one newChildren := make([]apitypes.Process, 0, len(parent.Children)+1) - // Add existing children, excluding any old version of this process + hasProcess := false for _, child := range parent.Children { - if child.PID != process.PID { + if child.PID == process.PID { + hasProcess = true + newChildren = append(newChildren, process) + } else { newChildren = append(newChildren, child) } } - // Add the new process - newChildren = append(newChildren, process) + if !hasProcess { + newChildren = append(newChildren, process) + } parent.Children = newChildren p.processTree.Set(parent.PID, parent) } } -// removeProcess removes a process and updates relationships func (p *ProcessManager) removeProcess(pid uint32) { - // Get the process before removing it if process, exists := p.processTree.Load(pid); exists { - // Update parent's children list if parent, exists := p.processTree.Load(process.PPID); exists { newChildren := make([]apitypes.Process, 0, len(parent.Children)) for _, child := range parent.Children { @@ -152,192 +191,52 @@ func (p *ProcessManager) removeProcess(pid uint32) { p.processTree.Set(parent.PID, parent) } - // Reassign children to nearest living ancestor for _, child := range process.Children { - if newProcess, exists := p.processTree.Load(child.PID); exists { - newProcess.PPID = process.PPID - p.addProcess(newProcess) // This will update the new parent's children list + if childProcess, exists := p.processTree.Load(child.PID); exists { + childProcess.PPID = process.PPID + p.addProcess(childProcess) } } - // Finally remove the process p.processTree.Delete(pid) } } -// isDescendantOfShim checks if a process is a descendant of any container shim -func (p *ProcessManager) isDescendantOfShim(pid uint32, visited map[uint32]bool) bool { - if pid == 0 || len(visited) > maxTreeDepth { - return false - } - - if visited[pid] { - return false // Avoid cycles - } - visited[pid] = true - - // Check if this pid is a shim - isShim := false - p.containerIdToShimPid.Range(func(_ string, shimPid uint32) bool { - if shimPid == pid { - isShim = true - return false // Stop ranging - } - return true - }) - - if isShim { - return true - } - - // Check parent if process exists - if process, exists := p.processTree.Load(pid); exists { - return p.isDescendantOfShim(process.PPID, visited) - } - - return false -} - -func (p *ProcessManager) startCleanupRoutine(ctx context.Context) { - ticker := time.NewTicker(cleanupInterval) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - p.cleanup() - case <-ctx.Done(): - return - } - } -} - -func (p *ProcessManager) cleanup() { - // First pass: identify dead processes - deadPids := make(map[uint32]bool) - p.processTree.Range(func(pid uint32, _ apitypes.Process) bool { - if !isProcessAlive(int(pid)) { - deadPids[pid] = true - } - return true - }) - - // Second pass: remove dead processes and update relationships - for pid := range deadPids { - p.removeProcess(pid) - } -} - -func (p *ProcessManager) ContainerCallback(notif containercollection.PubSubEvent) { - switch notif.Type { - case containercollection.EventTypeAddContainer: - containerID := notif.Container.Runtime.ContainerID - shimPID := uint32(notif.Container.Pid) - - // Store the shim PID - p.containerIdToShimPid.Set(containerID, shimPID) - - // If the shim process isn't in our tree yet (might have started after initial scan), - // add it directly - if !p.processTree.Has(shimPID) { - if process, err := p.getProcessFromProc(int(shimPID)); err == nil { - p.addProcess(process) - } else { - logger.L().Debug("Failed to get shim process info", - helpers.String("containerID", containerID), - helpers.Error(err)) - } - } - - case containercollection.EventTypeRemoveContainer: - containerID := notif.Container.Runtime.ContainerID - - // Get shim PID before removing the mapping - if shimPID, exists := p.containerIdToShimPid.Load(containerID); exists { - // Remove all descendants of this shim - descendants := make(map[uint32]bool) - p.processTree.Range(func(pid uint32, process apitypes.Process) bool { - if p.isDescendantOfShim(pid, make(map[uint32]bool)) { - descendants[pid] = true - } - return true - }) - - // Remove descendants in reverse order (children before parents) - for pid := range descendants { - p.removeProcess(pid) - } - - // Finally remove the shim itself - p.removeProcess(shimPID) - } - - // Remove the container mapping - p.containerIdToShimPid.Delete(containerID) - } -} - func (p *ProcessManager) GetProcessTreeForPID(containerID string, pid int) (apitypes.Process, error) { if !p.containerIdToShimPid.Has(containerID) { return apitypes.Process{}, fmt.Errorf("container ID %s not found", containerID) } - shimPID := p.containerIdToShimPid.Get(containerID) targetPID := uint32(pid) - - // If process doesn't exist in our tree, try to fetch it if !p.processTree.Has(targetPID) { - // Only fetch if it's a descendant of our shim - if process, err := p.getProcessFromProc(int(targetPID)); err == nil { - p.addProcess(process) - } else { + process, err := p.getProcessFromProc(pid) + if err != nil { return apitypes.Process{}, fmt.Errorf("process %d not found: %v", pid, err) } + p.addProcess(process) } - // Build process tree from target up to shim - processes := make([]apitypes.Process, 0) - currentPID := targetPID + result := p.processTree.Get(targetPID) + currentPID := result.PPID seen := make(map[uint32]bool) - // Collect all processes up to shim - for currentPID != 0 && currentPID != shimPID { + for currentPID != p.containerIdToShimPid.Get(containerID) && currentPID != 0 { if seen[currentPID] { - break // Avoid cycles + break } seen[currentPID] = true - if proc, exists := p.processTree.Load(currentPID); exists { - processes = append([]apitypes.Process{proc}, processes...) // Prepend - currentPID = proc.PPID + if p.processTree.Has(currentPID) { + parent := p.processTree.Get(currentPID) + parentCopy := parent + parentCopy.Children = []apitypes.Process{result} + result = parentCopy + currentPID = parent.PPID } else { break } } - // No processes found or invalid tree - if len(processes) == 0 { - return apitypes.Process{}, fmt.Errorf("could not build process tree for pid %d", pid) - } - - // Build the tree structure - result := processes[0] - current := &result - - // Link processes together - for i := 1; i < len(processes); i++ { - child := processes[i] - current.Children = []apitypes.Process{child} - current = ¤t.Children[0] - } - - // Add the target process as the final leaf if it's not already in the chain - if current.PID != targetPID { - if targetProc, exists := p.processTree.Load(targetPID); exists { - current.Children = []apitypes.Process{targetProc} - } - } - return result, nil } @@ -351,20 +250,49 @@ func (p *ProcessManager) ReportEvent(eventType utils.EventType, event utils.K8sE return } - // Create new process from event process := apitypes.Process{ - PID: execEvent.Pid, - PPID: execEvent.Ppid, - Comm: execEvent.Comm, - Uid: &execEvent.Uid, - Gid: &execEvent.Gid, - Cmdline: strings.Join(execEvent.Args, " "), + PID: uint32(execEvent.Pid), + PPID: uint32(execEvent.Ppid), + Comm: execEvent.Comm, + Uid: &execEvent.Uid, + Gid: &execEvent.Gid, + Hardlink: execEvent.ExePath, + UpperLayer: &execEvent.UpperLayer, + Cmdline: strings.Join(execEvent.Args, " "), } p.addProcess(process) } -func (p *ProcessManager) getProcessFromProc(pid int) (apitypes.Process, error) { +func (p *ProcessManager) startCleanupRoutine(ctx context.Context) { + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + p.cleanup() + case <-ctx.Done(): + return + } + } +} + +func (p *ProcessManager) cleanup() { + deadPids := make(map[uint32]bool) + p.processTree.Range(func(pid uint32, _ apitypes.Process) bool { + if !isProcessAlive(int(pid)) { + deadPids[pid] = true + } + return true + }) + + for pid := range deadPids { + p.removeProcess(pid) + } +} + +func getProcessFromProc(pid int) (apitypes.Process, error) { proc, err := procfs.NewProc(pid) if err != nil { return apitypes.Process{}, fmt.Errorf("failed to get process info: %v", err) @@ -375,7 +303,6 @@ func (p *ProcessManager) getProcessFromProc(pid int) (apitypes.Process, error) { return apitypes.Process{}, fmt.Errorf("failed to get process stat: %v", err) } - // Get process details var uid, gid uint32 if status, err := proc.NewStatus(); err == nil { if len(status.UIDs) > 1 { @@ -386,20 +313,13 @@ func (p *ProcessManager) getProcessFromProc(pid int) (apitypes.Process, error) { } } - cmdline, err := proc.CmdLine() - if err != nil { - cmdline = []string{stat.Comm} // Fallback to comm if cmdline fails + cmdline, _ := proc.CmdLine() + if len(cmdline) == 0 { + cmdline = []string{stat.Comm} } - cwd, err := proc.Cwd() - if err != nil { - cwd = "" // Empty string if we can't get cwd - } - - path, err := proc.Executable() - if err != nil { - path = "" // Empty string if we can't get executable path - } + cwd, _ := proc.Cwd() + path, _ := proc.Executable() return apitypes.Process{ PID: uint32(pid), diff --git a/pkg/processmanager/v1/process_manager_test.go b/pkg/processmanager/v1/process_manager_test.go index ab22968b..0d97bd88 100644 --- a/pkg/processmanager/v1/process_manager_test.go +++ b/pkg/processmanager/v1/process_manager_test.go @@ -2,8 +2,9 @@ package processmanager import ( "context" + "fmt" + "sync" "testing" - "time" apitypes "github.com/armosec/armoapi-go/armotypes" "github.com/stretchr/testify/assert" @@ -15,26 +16,60 @@ import ( "github.com/kubescape/node-agent/pkg/utils" ) -// mockProcessInfo helps simulate process information for testing -type mockProcessInfo struct { - pid uint32 - ppid uint32 - comm string - cmdline string -} +// Helper function type definition +type mockProcessAdder func(pid int, ppid uint32, comm string) -func TestProcessManagerBasics(t *testing.T) { +// Updated setup function with correct return types +func setupTestProcessManager(t *testing.T) (*ProcessManager, mockProcessAdder) { ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - pm := CreateProcessManager(ctx) - require.NotNil(t, pm) - // Test container creation + // Create process mock map + mockProcesses := make(map[int]apitypes.Process) + + // Store original function + originalGetProcessFromProc := pm.getProcessFromProc + + // Replace with mock version + pm.getProcessFromProc = func(pid int) (apitypes.Process, error) { + if proc, exists := mockProcesses[pid]; exists { + return proc, nil + } + return apitypes.Process{}, fmt.Errorf("mock process not found: %d", pid) + } + + // Set up cleanup + t.Cleanup(func() { + cancel() + pm.getProcessFromProc = originalGetProcessFromProc + }) + + // Return the process manager and the mock process adder function + return pm, func(pid int, ppid uint32, comm string) { + uid := uint32(1000) + gid := uint32(1000) + mockProcesses[pid] = apitypes.Process{ + PID: uint32(pid), + PPID: ppid, + Comm: comm, + Cmdline: comm, + Uid: &uid, + Gid: &gid, + } + } +} + +func TestProcessManagerBasics(t *testing.T) { + pm, addMockProcess := setupTestProcessManager(t) + containerID := "test-container-1" - shimPID := uint32(1000) + shimPID := uint32(999) + containerPID := uint32(1000) + + // Add mock container process with shim as parent + addMockProcess(int(containerPID), shimPID, "container-main") - // Simulate container creation + // Register container pm.ContainerCallback(containercollection.PubSubEvent{ Type: containercollection.EventTypeAddContainer, Container: &containercollection.Container{ @@ -43,24 +78,29 @@ func TestProcessManagerBasics(t *testing.T) { ContainerID: containerID, }, }, - Pid: shimPID, + Pid: containerPID, }, }) - // Verify shim PID was recorded + // Verify shim was recorded assert.True(t, pm.containerIdToShimPid.Has(containerID)) assert.Equal(t, shimPID, pm.containerIdToShimPid.Get(containerID)) + + // Verify container process was added + containerProc, exists := pm.processTree.Load(containerPID) + assert.True(t, exists) + assert.Equal(t, shimPID, containerProc.PPID) } func TestProcessTracking(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + pm, addMockProcess := setupTestProcessManager(t) - pm := CreateProcessManager(ctx) containerID := "test-container-1" - shimPID := uint32(1000) + shimPID := uint32(999) + containerPID := uint32(1000) + + addMockProcess(int(containerPID), shimPID, "container-main") - // Simulate container creation pm.ContainerCallback(containercollection.PubSubEvent{ Type: containercollection.EventTypeAddContainer, Container: &containercollection.Container{ @@ -69,57 +109,68 @@ func TestProcessTracking(t *testing.T) { ContainerID: containerID, }, }, - Pid: shimPID, + Pid: containerPID, }, }) - // Simulate process creation events testCases := []struct { name string event tracerexectype.Event verify func(t *testing.T, pm *ProcessManager) }{ { - name: "Direct child of shim", + name: "Container child process", event: tracerexectype.Event{ Pid: 1001, - Ppid: shimPID, + Ppid: containerPID, Comm: "nginx", Args: []string{"nginx", "-g", "daemon off;"}, }, verify: func(t *testing.T, pm *ProcessManager) { proc, exists := pm.processTree.Load(1001) require.True(t, exists) - assert.Equal(t, uint32(1001), proc.PID) - assert.Equal(t, shimPID, proc.PPID) + assert.Equal(t, containerPID, proc.PPID) assert.Equal(t, "nginx", proc.Comm) }, }, { - name: "Child of nginx", + name: "Exec process (direct child of shim)", event: tracerexectype.Event{ Pid: 1002, + Ppid: shimPID, + Comm: "bash", + Args: []string{"bash"}, + }, + verify: func(t *testing.T, pm *ProcessManager) { + proc, exists := pm.processTree.Load(1002) + require.True(t, exists) + assert.Equal(t, shimPID, proc.PPID) + assert.Equal(t, "bash", proc.Comm) + }, + }, + { + name: "Nested process", + event: tracerexectype.Event{ + Pid: 1003, Ppid: 1001, Comm: "nginx-worker", Args: []string{"nginx", "worker process"}, }, verify: func(t *testing.T, pm *ProcessManager) { - proc, exists := pm.processTree.Load(1002) + proc, exists := pm.processTree.Load(1003) require.True(t, exists) - assert.Equal(t, uint32(1002), proc.PID) assert.Equal(t, uint32(1001), proc.PPID) - // Verify parent's children list parent, exists := pm.processTree.Load(1001) require.True(t, exists) - found := false + hasChild := false for _, child := range parent.Children { - if child.PID == 1002 { - found = true + if child.PID == 1003 { + hasChild = true break } } - assert.True(t, found, "Child process should be in parent's children list") + assert.True(t, hasChild) }, }, } @@ -132,15 +183,15 @@ func TestProcessTracking(t *testing.T) { } } -func TestProcessCleanup(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() +func TestProcessRemoval(t *testing.T) { + pm, addMockProcess := setupTestProcessManager(t) - pm := CreateProcessManager(ctx) containerID := "test-container-1" - shimPID := uint32(1000) + shimPID := uint32(999) + containerPID := uint32(1000) + + addMockProcess(int(containerPID), shimPID, "container-main") - // Simulate container creation pm.ContainerCallback(containercollection.PubSubEvent{ Type: containercollection.EventTypeAddContainer, Container: &containercollection.Container{ @@ -149,62 +200,72 @@ func TestProcessCleanup(t *testing.T) { ContainerID: containerID, }, }, - Pid: shimPID, + Pid: containerPID, }, }) - // Add some test processes - processes := []mockProcessInfo{ - {pid: 1001, ppid: shimPID, comm: "parent", cmdline: "./parent"}, - {pid: 1002, ppid: 1001, comm: "child1", cmdline: "./child1"}, - {pid: 1003, ppid: 1001, comm: "child2", cmdline: "./child2"}, + // Create a process tree + processes := []struct { + pid uint32 + ppid uint32 + comm string + }{ + {1001, containerPID, "parent"}, + {1002, 1001, "child1"}, + {1003, 1002, "grandchild1"}, + {1004, 1002, "grandchild2"}, } + // Add processes for _, proc := range processes { - event := tracerexectype.Event{ + event := &tracerexectype.Event{ Pid: proc.pid, Ppid: proc.ppid, Comm: proc.comm, - Args: []string{proc.cmdline}, } - pm.ReportEvent(utils.ExecveEventType, &event) + pm.ReportEvent(utils.ExecveEventType, event) } - // Verify initial state + // Verify initial structure for _, proc := range processes { assert.True(t, pm.processTree.Has(proc.pid)) } - // Simulate container removal - pm.ContainerCallback(containercollection.PubSubEvent{ - Type: containercollection.EventTypeRemoveContainer, - Container: &containercollection.Container{ - Runtime: containercollection.RuntimeMetadata{ - BasicRuntimeMetadata: types.BasicRuntimeMetadata{ - ContainerID: containerID, - }, - }, - Pid: shimPID, - }, - }) + // Remove middle process and verify tree reorganization + pm.removeProcess(1002) - // Verify cleanup - for _, proc := range processes { - assert.False(t, pm.processTree.Has(proc.pid), - "Process %d should be removed after container cleanup", proc.pid) + // Verify process was removed + assert.False(t, pm.processTree.Has(1002)) + + // Verify children were reassigned to parent + parent, exists := pm.processTree.Load(1001) + require.True(t, exists) + + // Should now have both grandchildren + childPIDs := make(map[uint32]bool) + for _, child := range parent.Children { + childPIDs[child.PID] = true + } + assert.True(t, childPIDs[1003]) + assert.True(t, childPIDs[1004]) + + // Verify grandchildren's PPID was updated + for _, pid := range []uint32{1003, 1004} { + proc, exists := pm.processTree.Load(pid) + require.True(t, exists) + assert.Equal(t, uint32(1001), proc.PPID) } - assert.False(t, pm.containerIdToShimPid.Has(containerID)) } -func TestGetProcessTree(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() +func TestContainerRemoval(t *testing.T) { + pm, addMockProcess := setupTestProcessManager(t) - pm := CreateProcessManager(ctx) containerID := "test-container-1" - shimPID := uint32(1000) + shimPID := uint32(999) + containerPID := uint32(1000) + + addMockProcess(int(containerPID), shimPID, "container-main") - // Setup container pm.ContainerCallback(containercollection.PubSubEvent{ Type: containercollection.EventTypeAddContainer, Container: &containercollection.Container{ @@ -213,97 +274,69 @@ func TestGetProcessTree(t *testing.T) { ContainerID: containerID, }, }, - Pid: shimPID, + Pid: containerPID, }, }) - // Create a process tree: - // shim (1000) - // └── parent (1001) - // ├── child1 (1002) - // └── child2 (1003) - // └── grandchild (1004) - - processes := []mockProcessInfo{ - {pid: 1001, ppid: shimPID, comm: "parent", cmdline: "./parent"}, - {pid: 1002, ppid: 1001, comm: "child1", cmdline: "./child1"}, - {pid: 1003, ppid: 1001, comm: "child2", cmdline: "./child2"}, - {pid: 1004, ppid: 1003, comm: "grandchild", cmdline: "./grandchild"}, + // Create various processes under the container + processes := []struct { + pid uint32 + ppid uint32 + comm string + }{ + {containerPID, shimPID, "container-main"}, + {1001, containerPID, "app"}, + {1002, 1001, "worker"}, + {1003, shimPID, "exec"}, // direct child of shim } - // Add processes to tree for _, proc := range processes { event := &tracerexectype.Event{ Pid: proc.pid, Ppid: proc.ppid, Comm: proc.comm, - Args: []string{proc.cmdline}, } pm.ReportEvent(utils.ExecveEventType, event) } - // Get and verify process tree for grandchild - tree, err := pm.GetProcessTreeForPID(containerID, 1004) - require.NoError(t, err) + // Remove container + pm.ContainerCallback(containercollection.PubSubEvent{ + Type: containercollection.EventTypeRemoveContainer, + Container: &containercollection.Container{ + Runtime: containercollection.RuntimeMetadata{ + BasicRuntimeMetadata: types.BasicRuntimeMetadata{ + ContainerID: containerID, + }, + }, + Pid: containerPID, + }, + }) - // Helper function to find a process in the tree - var findProcess func(apitypes.Process, uint32) *apitypes.Process - findProcess = func(node apitypes.Process, targetPID uint32) *apitypes.Process { - if node.PID == targetPID { - return &node - } - for _, child := range node.Children { - if found := findProcess(child, targetPID); found != nil { - return found - } - } - return nil + // Verify all processes were removed + for _, proc := range processes { + assert.False(t, pm.processTree.Has(proc.pid)) } - // Verify tree structure - t.Run("verify tree structure", func(t *testing.T) { - // Check the chain from grandchild up to parent - current := &tree - expectedChain := []uint32{1001, 1003, 1004} // parent -> child2 -> grandchild - - for i, expectedPID := range expectedChain { - assert.Equal(t, expectedPID, current.PID, "Mismatch at chain position %d", i) - if i < len(expectedChain)-1 { - require.Len(t, current.Children, 1, "Expected exactly one child at position %d", i) - current = ¤t.Children[0] - } - } - }) - - // Verify process details - t.Run("verify process details", func(t *testing.T) { - for _, expected := range processes { - proc := findProcess(tree, expected.pid) - if proc != nil { - assert.Equal(t, expected.ppid, proc.PPID) - assert.Equal(t, expected.comm, proc.Comm) - assert.Equal(t, expected.cmdline, proc.Cmdline) - } - } - }) + // Verify container was removed from mapping + assert.False(t, pm.containerIdToShimPid.Has(containerID)) } -func TestContainerLifecycle(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - pm := CreateProcessManager(ctx) +func TestMultipleContainers(t *testing.T) { + pm, addMockProcess := setupTestProcessManager(t) containers := []struct { - id string - shimPID uint32 + id string + shimPID uint32 + containerPID uint32 }{ - {"container-1", 1000}, - {"container-2", 2000}, + {"container-1", 999, 1000}, + {"container-2", 1998, 2000}, } // Add containers for _, c := range containers { + addMockProcess(int(c.containerPID), c.shimPID, fmt.Sprintf("container-%s", c.id)) + pm.ContainerCallback(containercollection.PubSubEvent{ Type: containercollection.EventTypeAddContainer, Container: &containercollection.Container{ @@ -312,46 +345,241 @@ func TestContainerLifecycle(t *testing.T) { ContainerID: c.id, }, }, - Pid: c.shimPID, + Pid: c.containerPID, }, }) + + // Add some processes to each container + event1 := &tracerexectype.Event{ + Pid: c.containerPID + 1, + Ppid: c.containerPID, + Comm: "process-1", + } + event2 := &tracerexectype.Event{ + Pid: c.containerPID + 2, + Ppid: c.shimPID, + Comm: "exec-process", + } + + pm.ReportEvent(utils.ExecveEventType, event1) + pm.ReportEvent(utils.ExecveEventType, event2) } - // Verify containers were added + // Verify each container's processes for _, c := range containers { - assert.True(t, pm.containerIdToShimPid.Has(c.id)) - assert.Equal(t, c.shimPID, pm.containerIdToShimPid.Get(c.id)) + // Check container process + proc, exists := pm.processTree.Load(c.containerPID) + require.True(t, exists) + assert.Equal(t, c.shimPID, proc.PPID) + + // Check child process + childProc, exists := pm.processTree.Load(c.containerPID + 1) + require.True(t, exists) + assert.Equal(t, c.containerPID, childProc.PPID) + + // Check exec process + execProc, exists := pm.processTree.Load(c.containerPID + 2) + require.True(t, exists) + assert.Equal(t, c.shimPID, execProc.PPID) } - // Remove containers - for _, c := range containers { + // Remove first container + pm.ContainerCallback(containercollection.PubSubEvent{ + Type: containercollection.EventTypeRemoveContainer, + Container: &containercollection.Container{ + Runtime: containercollection.RuntimeMetadata{ + BasicRuntimeMetadata: types.BasicRuntimeMetadata{ + ContainerID: containers[0].id, + }, + }, + Pid: containers[0].containerPID, + }, + }) + + // Verify first container's processes are gone + assert.False(t, pm.processTree.Has(containers[0].containerPID)) + assert.False(t, pm.processTree.Has(containers[0].containerPID+1)) + assert.False(t, pm.processTree.Has(containers[0].containerPID+2)) + + // Verify second container's processes remain + assert.True(t, pm.processTree.Has(containers[1].containerPID)) + assert.True(t, pm.processTree.Has(containers[1].containerPID+1)) + assert.True(t, pm.processTree.Has(containers[1].containerPID+2)) +} + +func TestErrorCases(t *testing.T) { + pm, addMockProcess := setupTestProcessManager(t) + + t.Run("get non-existent process tree", func(t *testing.T) { + _, err := pm.GetProcessTreeForPID("non-existent", 1000) + assert.Error(t, err) + }) + + t.Run("process with non-existent parent", func(t *testing.T) { + containerID := "test-container" + shimPID := uint32(999) + containerPID := uint32(1000) + + addMockProcess(int(containerPID), shimPID, "container-main") + pm.ContainerCallback(containercollection.PubSubEvent{ - Type: containercollection.EventTypeRemoveContainer, + Type: containercollection.EventTypeAddContainer, Container: &containercollection.Container{ Runtime: containercollection.RuntimeMetadata{ BasicRuntimeMetadata: types.BasicRuntimeMetadata{ - ContainerID: c.id, + ContainerID: containerID, }, }, - Pid: c.shimPID, + Pid: containerPID, }, }) - // Verify container was removed - assert.False(t, pm.containerIdToShimPid.Has(c.id)) - } + // Add process with non-existent parent + event := &tracerexectype.Event{ + Pid: 2000, + Ppid: 1500, // Non-existent PPID + Comm: "orphan", + } + pm.ReportEvent(utils.ExecveEventType, event) + + // Process should still be added + assert.True(t, pm.processTree.Has(2000)) + }) } -func TestCleanupRoutine(t *testing.T) { - // This is a bit tricky to test since we can't easily simulate dead processes - // But we can test that the cleanup routine runs without errors - ctx, cancel := context.WithCancel(context.Background()) - CreateProcessManager(ctx) +func TestRaceConditions(t *testing.T) { + pm, addMockProcess := setupTestProcessManager(t) - // Let it run for a short while - time.Sleep(2 * time.Second) + containerID := "test-container" + shimPID := uint32(999) + containerPID := uint32(1000) - // Cancel and ensure it shuts down cleanly - cancel() - time.Sleep(100 * time.Millisecond) + // Setup container + addMockProcess(int(containerPID), shimPID, "container-main") + pm.ContainerCallback(containercollection.PubSubEvent{ + Type: containercollection.EventTypeAddContainer, + Container: &containercollection.Container{ + Runtime: containercollection.RuntimeMetadata{ + BasicRuntimeMetadata: types.BasicRuntimeMetadata{ + ContainerID: containerID, + }, + }, + Pid: containerPID, + }, + }) + + processCount := 100 + var mu sync.Mutex + processStates := make(map[uint32]struct { + added bool + removed bool + }) + + // Pre-populate process states + for i := 0; i < processCount; i++ { + pid := uint32(2000 + i) + processStates[pid] = struct { + added bool + removed bool + }{false, false} + } + + // Channel to signal between goroutines + removeDone := make(chan bool) + addDone := make(chan bool) + + // Goroutine to remove processes (run first) + go func() { + for i := 0; i < processCount; i++ { + if i%2 == 0 { + pid := uint32(2000 + i) + mu.Lock() + if state, exists := processStates[pid]; exists { + state.removed = true + processStates[pid] = state + } + mu.Unlock() + pm.removeProcess(pid) + } + } + removeDone <- true + }() + + // Wait for removals to complete before starting additions + <-removeDone + + // Goroutine to add processes + go func() { + for i := 0; i < processCount; i++ { + pid := uint32(2000 + i) + // Only add if not marked for removal + mu.Lock() + state := processStates[pid] + if !state.removed { + event := &tracerexectype.Event{ + Pid: pid, + Ppid: shimPID, + Comm: fmt.Sprintf("process-%d", i), + } + state.added = true + processStates[pid] = state + mu.Unlock() + pm.ReportEvent(utils.ExecveEventType, event) + } else { + mu.Unlock() + } + } + addDone <- true + }() + + // Wait for additions to complete + <-addDone + + // Verify final state + remainingCount := 0 + pm.processTree.Range(func(pid uint32, process apitypes.Process) bool { + if pid >= 2000 && pid < 2000+uint32(processCount) { + mu.Lock() + state := processStates[pid] + mu.Unlock() + + if state.removed { + t.Errorf("Process %d exists but was marked for removal", pid) + } + if !state.added { + t.Errorf("Process %d exists but was not marked as added", pid) + } + remainingCount++ + } + return true + }) + + // Verify all processes marked as removed are actually gone + mu.Lock() + for pid, state := range processStates { + if state.removed { + if pm.processTree.Has(pid) { + t.Errorf("Process %d was marked for removal but still exists", pid) + } + } else if state.added { + if !pm.processTree.Has(pid) { + t.Errorf("Process %d was marked as added but doesn't exist", pid) + } + } + } + mu.Unlock() + + // We expect exactly half of the processes to remain (odd-numbered ones) + expectedCount := processCount / 2 + assert.Equal(t, expectedCount, remainingCount, + "Expected exactly %d processes, got %d", expectedCount, remainingCount) + + // Verify all remaining processes have correct parent + pm.processTree.Range(func(pid uint32, process apitypes.Process) bool { + if pid >= 2000 && pid < 2000+uint32(processCount) { + assert.Equal(t, shimPID, process.PPID, + "Process %d should have shim as parent", pid) + } + return true + }) } From 7232814cac397a43e45c1fe41d93fdd5e7199f19 Mon Sep 17 00:00:00 2001 From: Amit Schendel Date: Tue, 29 Oct 2024 20:06:59 +0000 Subject: [PATCH 03/14] Adding base code Signed-off-by: Amit Schendel --- main.go | 16 ++++++++++------ pkg/containerwatcher/v1/container_watcher.go | 12 ++++++++++-- .../v1/container_watcher_private.go | 1 + pkg/containerwatcher/v1/open_test.go | 2 +- pkg/processmanager/process_manager_interface.go | 2 ++ pkg/processmanager/process_manager_mock.go | 4 ++++ pkg/processmanager/v1/process_manager.go | 5 ++++- pkg/rulemanager/v1/rule_manager.go | 12 ++++++++++-- 8 files changed, 42 insertions(+), 12 deletions(-) diff --git a/main.go b/main.go index f34c487a..453fc5cb 100644 --- a/main.go +++ b/main.go @@ -32,6 +32,8 @@ import ( "github.com/kubescape/node-agent/pkg/objectcache/k8scache" "github.com/kubescape/node-agent/pkg/objectcache/networkneighborhoodcache" objectcachev1 "github.com/kubescape/node-agent/pkg/objectcache/v1" + "github.com/kubescape/node-agent/pkg/processmanager" + processmanagerv1 "github.com/kubescape/node-agent/pkg/processmanager/v1" "github.com/kubescape/node-agent/pkg/relevancymanager" relevancymanagerv1 "github.com/kubescape/node-agent/pkg/relevancymanager/v1" rulebinding "github.com/kubescape/node-agent/pkg/rulebindingmanager" @@ -193,26 +195,27 @@ func main() { var networkManagerClient networkmanager.NetworkManagerClient var dnsManagerClient dnsmanager.DNSManagerClient var dnsResolver dnsmanager.DNSResolver - if cfg.EnableNetworkTracing { + if cfg.EnableNetworkTracing || cfg.EnableRuntimeDetection { dnsManager := dnsmanager.CreateDNSManager() dnsManagerClient = dnsManager // NOTE: dnsResolver is set for threat detection. dnsResolver = dnsManager networkManagerClient = networkmanagerv2.CreateNetworkManager(ctx, cfg, clusterData.ClusterName, k8sClient, storageClient, dnsManager, preRunningContainersIDs, k8sObjectCache) } else { - if cfg.EnableRuntimeDetection { - logger.L().Ctx(ctx).Fatal("Network tracing is disabled, but runtime detection is enabled. Network tracing is required for runtime detection.") - } dnsManagerClient = dnsmanager.CreateDNSManagerMock() dnsResolver = dnsmanager.CreateDNSManagerMock() networkManagerClient = networkmanager.CreateNetworkManagerMock() } var ruleManager rulemanager.RuleManagerClient + var processManager processmanager.ProcessManagerClient var objCache objectcache.ObjectCache var ruleBindingNotify chan rulebinding.RuleBindingNotify if cfg.EnableRuntimeDetection { + // create the process manager + processManager = processmanagerv1.CreateProcessManager(ctx) + // create ruleBinding cache ruleBindingCache := rulebindingcachev1.NewCache(nodeName, k8sClient) dWatcher.AddAdaptor(ruleBindingCache) @@ -235,7 +238,7 @@ func main() { exporter := exporters.InitExporters(cfg.Exporters, clusterData.ClusterName, nodeName) // create runtimeDetection managers - ruleManager, err = rulemanagerv1.CreateRuleManager(ctx, cfg, k8sClient, ruleBindingCache, objCache, exporter, prometheusExporter, nodeName, clusterData.ClusterName) + ruleManager, err = rulemanagerv1.CreateRuleManager(ctx, cfg, k8sClient, ruleBindingCache, objCache, exporter, prometheusExporter, nodeName, clusterData.ClusterName, processManager) if err != nil { logger.L().Ctx(ctx).Fatal("error creating RuleManager", helpers.Error(err)) } @@ -244,6 +247,7 @@ func main() { ruleManager = rulemanager.CreateRuleManagerMock() objCache = objectcache.NewObjectCacheMock() ruleBindingNotify = make(chan rulebinding.RuleBindingNotify, 1) + processManager = processmanager.CreateProcessManagerMock() } // Create the node profile manager @@ -269,7 +273,7 @@ func main() { } // Create the container handler - mainHandler, err := containerwatcher.CreateIGContainerWatcher(cfg, applicationProfileManager, k8sClient, relevancyManager, networkManagerClient, dnsManagerClient, prometheusExporter, ruleManager, malwareManager, preRunningContainersIDs, &ruleBindingNotify, containerRuntime, nil) + mainHandler, err := containerwatcher.CreateIGContainerWatcher(cfg, applicationProfileManager, k8sClient, relevancyManager, networkManagerClient, dnsManagerClient, prometheusExporter, ruleManager, malwareManager, preRunningContainersIDs, &ruleBindingNotify, containerRuntime, nil, processManager) if err != nil { logger.L().Ctx(ctx).Fatal("error creating the container watcher", helpers.Error(err)) } diff --git a/pkg/containerwatcher/v1/container_watcher.go b/pkg/containerwatcher/v1/container_watcher.go index 86c6edef..eeb19b5e 100644 --- a/pkg/containerwatcher/v1/container_watcher.go +++ b/pkg/containerwatcher/v1/container_watcher.go @@ -42,6 +42,7 @@ import ( tracersshtype "github.com/kubescape/node-agent/pkg/ebpf/gadgets/ssh/types" tracersymlink "github.com/kubescape/node-agent/pkg/ebpf/gadgets/symlink/tracer" tracersymlinktype "github.com/kubescape/node-agent/pkg/ebpf/gadgets/symlink/types" + "github.com/kubescape/node-agent/pkg/processmanager" "github.com/kubescape/node-agent/pkg/malwaremanager" "github.com/kubescape/node-agent/pkg/metricsmanager" @@ -153,11 +154,13 @@ type IGContainerWatcher struct { ruleBindingPodNotify *chan rulebinding.RuleBindingNotify // container runtime runtime *containerutilsTypes.RuntimeConfig + // process manager + processManager processmanager.ProcessManagerClient } var _ containerwatcher.ContainerWatcher = (*IGContainerWatcher)(nil) -func CreateIGContainerWatcher(cfg config.Config, applicationProfileManager applicationprofilemanager.ApplicationProfileManagerClient, k8sClient *k8sinterface.KubernetesApi, relevancyManager relevancymanager.RelevancyManagerClient, networkManagerClient networkmanager.NetworkManagerClient, dnsManagerClient dnsmanager.DNSManagerClient, metrics metricsmanager.MetricsManager, ruleManager rulemanager.RuleManagerClient, malwareManager malwaremanager.MalwareManagerClient, preRunningContainers mapset.Set[string], ruleBindingPodNotify *chan rulebinding.RuleBindingNotify, runtime *containerutilsTypes.RuntimeConfig, thirdPartyEventReceivers *maps.SafeMap[utils.EventType, mapset.Set[containerwatcher.EventReceiver]]) (*IGContainerWatcher, error) { +func CreateIGContainerWatcher(cfg config.Config, applicationProfileManager applicationprofilemanager.ApplicationProfileManagerClient, k8sClient *k8sinterface.KubernetesApi, relevancyManager relevancymanager.RelevancyManagerClient, networkManagerClient networkmanager.NetworkManagerClient, dnsManagerClient dnsmanager.DNSManagerClient, metrics metricsmanager.MetricsManager, ruleManager rulemanager.RuleManagerClient, malwareManager malwaremanager.MalwareManagerClient, preRunningContainers mapset.Set[string], ruleBindingPodNotify *chan rulebinding.RuleBindingNotify, runtime *containerutilsTypes.RuntimeConfig, thirdPartyEventReceivers *maps.SafeMap[utils.EventType, mapset.Set[containerwatcher.EventReceiver]], processManager processmanager.ProcessManagerClient) (*IGContainerWatcher, error) { // Use container collection to get notified for new containers containerCollection := &containercollection.ContainerCollection{} // Create a tracer collection instance @@ -449,6 +452,7 @@ func CreateIGContainerWatcher(cfg config.Config, applicationProfileManager appli runtime: runtime, thirdPartyTracers: mapset.NewSet[containerwatcher.CustomTracer](), thirdPartyContainerReceivers: mapset.NewSet[containerwatcher.ContainerReceiver](), + processManager: processManager, }, nil } @@ -494,11 +498,15 @@ func (ch *IGContainerWatcher) UnregisterContainerReceiver(receiver containerwatc func (ch *IGContainerWatcher) Start(ctx context.Context) error { if !ch.running { - if err := ch.startContainerCollection(ctx); err != nil { return fmt.Errorf("setting up container collection: %w", err) } + if err := ch.processManager.PopulateInitialProcesses(); err != nil { + ch.stopContainerCollection() + return fmt.Errorf("populating initial processes: %w", err) + } + if err := ch.startTracers(); err != nil { ch.stopContainerCollection() return fmt.Errorf("starting app behavior tracing: %w", err) diff --git a/pkg/containerwatcher/v1/container_watcher_private.go b/pkg/containerwatcher/v1/container_watcher_private.go index 5e89bf4e..ff219f2a 100644 --- a/pkg/containerwatcher/v1/container_watcher_private.go +++ b/pkg/containerwatcher/v1/container_watcher_private.go @@ -87,6 +87,7 @@ func (ch *IGContainerWatcher) startContainerCollection(ctx context.Context) erro ch.networkManager.ContainerCallback, ch.malwareManager.ContainerCallback, ch.ruleManager.ContainerCallback, + ch.processManager.ContainerCallback, } for receiver := range ch.thirdPartyContainerReceivers.Iter() { diff --git a/pkg/containerwatcher/v1/open_test.go b/pkg/containerwatcher/v1/open_test.go index 7c91dd6f..7fb4e084 100644 --- a/pkg/containerwatcher/v1/open_test.go +++ b/pkg/containerwatcher/v1/open_test.go @@ -23,7 +23,7 @@ func BenchmarkIGContainerWatcher_openEventCallback(b *testing.B) { assert.NoError(b, err) mockExporter := metricsmanager.NewMetricsMock() - mainHandler, err := CreateIGContainerWatcher(cfg, nil, nil, relevancyManager, nil, nil, mockExporter, nil, nil, nil, nil, nil, nil) + mainHandler, err := CreateIGContainerWatcher(cfg, nil, nil, relevancyManager, nil, nil, mockExporter, nil, nil, nil, nil, nil, nil, nil) assert.NoError(b, err) event := &traceropentype.Event{ Event: types.Event{ diff --git a/pkg/processmanager/process_manager_interface.go b/pkg/processmanager/process_manager_interface.go index bed86cff..0f36f367 100644 --- a/pkg/processmanager/process_manager_interface.go +++ b/pkg/processmanager/process_manager_interface.go @@ -11,6 +11,8 @@ import ( // The manager is responsible for maintaining the process tree for all containers. type ProcessManagerClient interface { GetProcessTreeForPID(containerID string, pid int) (apitypes.Process, error) + // PopulateInitialProcesses is called to populate the initial process tree (parsed from /proc) for all containers. + PopulateInitialProcesses() error // ReportEvent will be called to report new exec events to the process manager. ReportEvent(eventType utils.EventType, event utils.K8sEvent) diff --git a/pkg/processmanager/process_manager_mock.go b/pkg/processmanager/process_manager_mock.go index 43213fba..68fcdd14 100644 --- a/pkg/processmanager/process_manager_mock.go +++ b/pkg/processmanager/process_manager_mock.go @@ -19,6 +19,10 @@ func (p *ProcessManagerMock) GetProcessTreeForPID(containerID string, pid int) ( return apitypes.Process{}, nil } +func (p *ProcessManagerMock) PopulateInitialProcesses() error { + return nil +} + func (p *ProcessManagerMock) ReportEvent(eventType utils.EventType, event utils.K8sEvent) { // no-op } diff --git a/pkg/processmanager/v1/process_manager.go b/pkg/processmanager/v1/process_manager.go index 5c286bb4..668e1a2b 100644 --- a/pkg/processmanager/v1/process_manager.go +++ b/pkg/processmanager/v1/process_manager.go @@ -37,7 +37,7 @@ func CreateProcessManager(ctx context.Context) *ProcessManager { return pm } -func (p *ProcessManager) InitialProcScan() error { +func (p *ProcessManager) PopulateInitialProcesses() error { if len(p.containerIdToShimPid.Keys()) == 0 { return nil } @@ -258,6 +258,9 @@ func (p *ProcessManager) ReportEvent(eventType utils.EventType, event utils.K8sE Gid: &execEvent.Gid, Hardlink: execEvent.ExePath, UpperLayer: &execEvent.UpperLayer, + Path: execEvent.ExePath, + Cwd: execEvent.Cwd, + Pcomm: execEvent.Pcomm, Cmdline: strings.Join(execEvent.Args, " "), } diff --git a/pkg/rulemanager/v1/rule_manager.go b/pkg/rulemanager/v1/rule_manager.go index 56ce6156..9bb8d2f7 100644 --- a/pkg/rulemanager/v1/rule_manager.go +++ b/pkg/rulemanager/v1/rule_manager.go @@ -10,6 +10,7 @@ import ( "github.com/kubescape/node-agent/pkg/config" "github.com/kubescape/node-agent/pkg/exporters" "github.com/kubescape/node-agent/pkg/k8sclient" + "github.com/kubescape/node-agent/pkg/processmanager" "github.com/kubescape/node-agent/pkg/ruleengine" "github.com/kubescape/node-agent/pkg/rulemanager" "github.com/kubescape/node-agent/pkg/utils" @@ -64,11 +65,12 @@ type RuleManager struct { clusterName string containerIdToShimPid maps.SafeMap[string, uint32] containerIdToPid maps.SafeMap[string, uint32] + processManager processmanager.ProcessManagerClient } var _ rulemanager.RuleManagerClient = (*RuleManager)(nil) -func CreateRuleManager(ctx context.Context, cfg config.Config, k8sClient k8sclient.K8sClientInterface, ruleBindingCache bindingcache.RuleBindingCache, objectCache objectcache.ObjectCache, exporter exporters.Exporter, metrics metricsmanager.MetricsManager, nodeName string, clusterName string) (*RuleManager, error) { +func CreateRuleManager(ctx context.Context, cfg config.Config, k8sClient k8sclient.K8sClientInterface, ruleBindingCache bindingcache.RuleBindingCache, objectCache objectcache.ObjectCache, exporter exporters.Exporter, metrics metricsmanager.MetricsManager, nodeName string, clusterName string, processManager processmanager.ProcessManagerClient) (*RuleManager, error) { return &RuleManager{ cfg: cfg, ctx: ctx, @@ -81,6 +83,7 @@ func CreateRuleManager(ctx context.Context, cfg config.Config, k8sClient k8sclie metrics: metrics, nodeName: nodeName, clusterName: clusterName, + processManager: processManager, }, nil } @@ -434,7 +437,12 @@ func (rm *RuleManager) enrichRuleFailure(ruleFailure ruleengine.RuleFailure) rul runtimeProcessDetails.ProcessTree.Path = path } - if rm.containerIdToShimPid.Has(ruleFailure.GetRuntimeProcessDetails().ContainerID) { + tree, err := rm.processManager.GetProcessTreeForPID(ruleFailure.GetRuntimeProcessDetails().ContainerID, int(ruleFailure.GetRuntimeProcessDetails().ProcessTree.PID)) + if err == nil { + runtimeProcessDetails.ProcessTree = tree + } else if rm.containerIdToShimPid.Has(ruleFailure.GetRuntimeProcessDetails().ContainerID) { + logger.L().Debug("RuleManager - failed to get process tree, trying to get process tree from shim", + helpers.String("container ID", ruleFailure.GetRuntimeProcessDetails().ContainerID)) shimPid := rm.containerIdToShimPid.Get(ruleFailure.GetRuntimeProcessDetails().ContainerID) tree, err := utils.CreateProcessTree(&runtimeProcessDetails.ProcessTree, shimPid) if err == nil { From 82fce660314fdfc7adbda4e34b5567149759b9e3 Mon Sep 17 00:00:00 2001 From: Amit Schendel Date: Tue, 29 Oct 2024 21:27:17 +0000 Subject: [PATCH 04/14] Adding error msg Signed-off-by: Amit Schendel --- pkg/rulemanager/v1/rule_manager.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pkg/rulemanager/v1/rule_manager.go b/pkg/rulemanager/v1/rule_manager.go index 9bb8d2f7..6adcd1ca 100644 --- a/pkg/rulemanager/v1/rule_manager.go +++ b/pkg/rulemanager/v1/rule_manager.go @@ -437,11 +437,13 @@ func (rm *RuleManager) enrichRuleFailure(ruleFailure ruleengine.RuleFailure) rul runtimeProcessDetails.ProcessTree.Path = path } + // TODO: Avoid Race condition where the tree is not populated yet. tree, err := rm.processManager.GetProcessTreeForPID(ruleFailure.GetRuntimeProcessDetails().ContainerID, int(ruleFailure.GetRuntimeProcessDetails().ProcessTree.PID)) if err == nil { runtimeProcessDetails.ProcessTree = tree } else if rm.containerIdToShimPid.Has(ruleFailure.GetRuntimeProcessDetails().ContainerID) { logger.L().Debug("RuleManager - failed to get process tree, trying to get process tree from shim", + helpers.Error(err), helpers.String("container ID", ruleFailure.GetRuntimeProcessDetails().ContainerID)) shimPid := rm.containerIdToShimPid.Get(ruleFailure.GetRuntimeProcessDetails().ContainerID) tree, err := utils.CreateProcessTree(&runtimeProcessDetails.ProcessTree, shimPid) From 72e9248265ba1713d27ab02259f28897deb142f6 Mon Sep 17 00:00:00 2001 From: Amit Schendel Date: Wed, 30 Oct 2024 08:30:57 +0000 Subject: [PATCH 05/14] Improving rules and adding debug logs when removing process Signed-off-by: Amit Schendel --- pkg/processmanager/v1/process_manager.go | 1 + .../v1/r0004_unexpected_capability_used.go | 8 + .../r0010_unexpected_sensitive_file_access.go | 38 +++- ...0_unexpected_sensitive_file_access_test.go | 211 ++++++++++-------- 4 files changed, 150 insertions(+), 108 deletions(-) diff --git a/pkg/processmanager/v1/process_manager.go b/pkg/processmanager/v1/process_manager.go index 668e1a2b..836116c0 100644 --- a/pkg/processmanager/v1/process_manager.go +++ b/pkg/processmanager/v1/process_manager.go @@ -291,6 +291,7 @@ func (p *ProcessManager) cleanup() { }) for pid := range deadPids { + logger.L().Debug("Removing dead process", helpers.Int("pid", int(pid))) p.removeProcess(pid) } } diff --git a/pkg/ruleengine/v1/r0004_unexpected_capability_used.go b/pkg/ruleengine/v1/r0004_unexpected_capability_used.go index 4f3b2571..c4d8e906 100644 --- a/pkg/ruleengine/v1/r0004_unexpected_capability_used.go +++ b/pkg/ruleengine/v1/r0004_unexpected_capability_used.go @@ -3,6 +3,7 @@ package ruleengine import ( "fmt" + "github.com/goradd/maps" "github.com/kubescape/node-agent/pkg/objectcache" "github.com/kubescape/node-agent/pkg/ruleengine" "github.com/kubescape/node-agent/pkg/utils" @@ -34,6 +35,7 @@ var _ ruleengine.RuleEvaluator = (*R0004UnexpectedCapabilityUsed)(nil) type R0004UnexpectedCapabilityUsed struct { BaseRule + alertedCapabilities maps.SafeMap[string, bool] } func CreateRuleR0004UnexpectedCapabilityUsed() *R0004UnexpectedCapabilityUsed { @@ -76,6 +78,10 @@ func (rule *R0004UnexpectedCapabilityUsed) ProcessEvent(eventType utils.EventTyp return nil } + if rule.alertedCapabilities.Has(capEvent.CapName) { + return nil + } + for _, capability := range appProfileCapabilitiesList.Capabilities { if capEvent.CapName == capability { return nil @@ -112,6 +118,8 @@ func (rule *R0004UnexpectedCapabilityUsed) ProcessEvent(eventType utils.EventTyp RuleID: rule.ID(), } + rule.alertedCapabilities.Set(capEvent.CapName, true) + return &ruleFailure } diff --git a/pkg/ruleengine/v1/r0010_unexpected_sensitive_file_access.go b/pkg/ruleengine/v1/r0010_unexpected_sensitive_file_access.go index 71a3bd40..1f512cc9 100644 --- a/pkg/ruleengine/v1/r0010_unexpected_sensitive_file_access.go +++ b/pkg/ruleengine/v1/r0010_unexpected_sensitive_file_access.go @@ -2,6 +2,7 @@ package ruleengine import ( "fmt" + "path/filepath" "strings" "github.com/kubescape/node-agent/pkg/objectcache" @@ -97,15 +98,7 @@ func (rule *R0010UnexpectedSensitiveFileAccess) ProcessEvent(eventType utils.Eve return nil } - isSensitive := false - for _, path := range rule.additionalPaths { - if strings.HasPrefix(openEvent.FullPath, path) { - isSensitive = true - break - } - } - - if !isSensitive { + if !isSensitivePath(openEvent.FullPath, rule.additionalPaths) { return nil } @@ -154,3 +147,30 @@ func (rule *R0010UnexpectedSensitiveFileAccess) Requirements() ruleengine.RuleSp EventTypes: R0010UnexpectedSensitiveFileAccessRuleDescriptor.Requirements.RequiredEventTypes(), } } + +// isSensitivePath checks if a given path matches or is within any sensitive paths +func isSensitivePath(fullPath string, paths []string) bool { + // Clean the path to handle "..", "//", etc. + fullPath = filepath.Clean(fullPath) + + for _, sensitivePath := range paths { + sensitivePath = filepath.Clean(sensitivePath) + + // Check if the path exactly matches + if fullPath == sensitivePath { + return true + } + + // Check if the path is a directory that contains sensitive files + if strings.HasPrefix(sensitivePath, fullPath+"/") { + return true + } + + // Check if the path is within a sensitive directory + if strings.HasPrefix(fullPath, sensitivePath+"/") { + return true + } + } + + return false +} diff --git a/pkg/ruleengine/v1/r0010_unexpected_sensitive_file_access_test.go b/pkg/ruleengine/v1/r0010_unexpected_sensitive_file_access_test.go index e35abe05..e45d7f2b 100644 --- a/pkg/ruleengine/v1/r0010_unexpected_sensitive_file_access_test.go +++ b/pkg/ruleengine/v1/r0010_unexpected_sensitive_file_access_test.go @@ -3,25 +3,14 @@ package ruleengine import ( "testing" - "github.com/kubescape/node-agent/pkg/objectcache" - "github.com/kubescape/node-agent/pkg/utils" - - "github.com/kubescape/storage/pkg/apis/softwarecomposition/v1beta1" - traceropentype "github.com/inspektor-gadget/inspektor-gadget/pkg/gadgets/trace/open/types" eventtypes "github.com/inspektor-gadget/inspektor-gadget/pkg/types" + "github.com/kubescape/node-agent/pkg/utils" + "github.com/kubescape/storage/pkg/apis/softwarecomposition/v1beta1" ) -func TestR0010UnexpectedSensitiveFileAccess(t *testing.T) { - // Create a new rule - r := CreateRuleR0010UnexpectedSensitiveFileAccess() - // Assert r is not nil - if r == nil { - t.Errorf("Expected r to not be nil") - } - - // Create a file access event - e := &traceropentype.Event{ +func createTestEvent(path string, flags []string) *traceropentype.Event { + return &traceropentype.Event{ Event: eventtypes.Event{ CommonData: eventtypes.CommonData{ K8s: eventtypes.K8sMetadata{ @@ -31,103 +20,127 @@ func TestR0010UnexpectedSensitiveFileAccess(t *testing.T) { }, }, }, - Path: "/test", - FullPath: "/test", - Flags: []string{"O_RDONLY"}, - } - - // Test with nil appProfileAccess - ruleResult := r.ProcessEvent(utils.OpenEventType, e, &objectcache.ObjectCacheMock{}) - if ruleResult != nil { - t.Errorf("Expected ruleResult to not be nil since no appProfile") + Path: path, + FullPath: path, + Flags: flags, } +} - // Test with whitelisted file - objCache := RuleObjectCacheMock{} - profile := objCache.ApplicationProfileCache().GetApplicationProfile("test") - if profile == nil { - profile = &v1beta1.ApplicationProfile{ - Spec: v1beta1.ApplicationProfileSpec{ - Containers: []v1beta1.ApplicationProfileContainer{ - { - Name: "test", - Opens: []v1beta1.OpenCalls{ - { - Path: "/test", - Flags: []string{"O_RDONLY"}, - }, - }, - }, - }, - }, +func createTestProfile(containerName string, paths []string, flags []string) *v1beta1.ApplicationProfile { + opens := make([]v1beta1.OpenCalls, len(paths)) + for i, path := range paths { + opens[i] = v1beta1.OpenCalls{ + Path: path, + Flags: flags, } - objCache.SetApplicationProfile(profile) - } - ruleResult = r.ProcessEvent(utils.OpenEventType, e, &objCache) - if ruleResult != nil { - t.Errorf("Expected ruleResult to be nil since file is whitelisted and not sensitive") - } - - // Test with non whitelisted file, but not sensitive - e.FullPath = "/var/test1" - ruleResult = r.ProcessEvent(utils.OpenEventType, e, &objCache) - if ruleResult != nil { - t.Errorf("Expected ruleResult to be nil since file is not whitelisted and not sensitive") - } - - // Test with sensitive file that is whitelisted - e.FullPath = "/etc/shadow" - profile.Spec.Containers[0].Opens[0].Path = "/etc/shadow" - ruleResult = r.ProcessEvent(utils.OpenEventType, e, &objCache) - if ruleResult != nil { - t.Errorf("Expected ruleResult to be nil since file is whitelisted and sensitive") - } - - // Test with sensitive file, but not whitelisted - e.FullPath = "/etc/shadow" - profile.Spec.Containers[0].Opens[0].Path = "/test" - ruleResult = r.ProcessEvent(utils.OpenEventType, e, &objCache) - if ruleResult == nil { - t.Errorf("Expected ruleResult to not be nil since file is not whitelisted and sensitive") - } - - // Test with sensitive file that originates from additionalPaths parameter - e.FullPath = "/etc/blabla" - profile.Spec.Containers[0].Opens[0].Path = "/test" - additionalPaths := []interface{}{"/etc/blabla"} - r.SetParameters(map[string]interface{}{"additionalPaths": additionalPaths}) - ruleResult = r.ProcessEvent(utils.OpenEventType, e, &objCache) - if ruleResult == nil { - t.Errorf("Expected ruleResult to not be nil since file is not whitelisted and sensitive") } - e.FullPath = "/tmp/blabla" - ruleResult = r.ProcessEvent(utils.OpenEventType, e, &objCache) - if ruleResult != nil { - t.Errorf("Expected ruleResult to be nil since file is whitelisted and not sensitive") - } - - profile = &v1beta1.ApplicationProfile{ + return &v1beta1.ApplicationProfile{ Spec: v1beta1.ApplicationProfileSpec{ Containers: []v1beta1.ApplicationProfileContainer{ { - Name: "test", - Opens: []v1beta1.OpenCalls{ - { - Path: "/etc/\u22ef", - Flags: []string{"O_RDONLY"}, - }, - }, + Name: containerName, + Opens: opens, }, }, }, } - objCache.SetApplicationProfile(profile) +} - e.FullPath = "/etc/blabla" - ruleResult = r.ProcessEvent(utils.OpenEventType, e, &objCache) - if ruleResult != nil { - t.Errorf("Expected ruleResult to be nil since file is whitelisted and not sensitive") +func TestR0010UnexpectedSensitiveFileAccess(t *testing.T) { + tests := []struct { + name string + event *traceropentype.Event + profile *v1beta1.ApplicationProfile + additionalPaths []interface{} + expectAlert bool + description string + }{ + { + name: "No application profile", + event: createTestEvent("/test", []string{"O_RDONLY"}), + profile: nil, + expectAlert: false, + description: "Should not alert when no application profile is present", + }, + { + name: "Whitelisted non-sensitive file", + event: createTestEvent("/test", []string{"O_RDONLY"}), + profile: createTestProfile("test", []string{"/test"}, []string{"O_RDONLY"}), + expectAlert: false, + description: "Should not alert for whitelisted non-sensitive file", + }, + { + name: "Non-whitelisted non-sensitive file", + event: createTestEvent("/var/test1", []string{"O_RDONLY"}), + profile: createTestProfile("test", []string{"/test"}, []string{"O_RDONLY"}), + expectAlert: false, + description: "Should not alert for non-whitelisted non-sensitive file", + }, + { + name: "Whitelisted sensitive file", + event: createTestEvent("/etc/shadow", []string{"O_RDONLY"}), + profile: createTestProfile("test", []string{"/etc/shadow"}, []string{"O_RDONLY"}), + expectAlert: false, + description: "Should not alert for whitelisted sensitive file", + }, + { + name: "Non-whitelisted sensitive file", + event: createTestEvent("/etc/shadow", []string{"O_RDONLY"}), + profile: createTestProfile("test", []string{"/test"}, []string{"O_RDONLY"}), + expectAlert: true, + description: "Should alert for non-whitelisted sensitive file", + }, + { + name: "Additional sensitive path", + event: createTestEvent("/etc/custom-sensitive", []string{"O_RDONLY"}), + profile: createTestProfile("test", []string{"/test"}, []string{"O_RDONLY"}), + additionalPaths: []interface{}{"/etc/custom-sensitive"}, + expectAlert: true, + description: "Should alert for non-whitelisted file in additional sensitive paths", + }, + { + name: "Wildcard path match", + event: createTestEvent("/etc/blabla", []string{"O_RDONLY"}), + profile: createTestProfile("test", []string{"/etc/\u22ef"}, []string{"O_RDONLY"}), + expectAlert: false, + description: "Should not alert when path matches wildcard pattern", + }, + { + name: "Path traversal attempt", + event: createTestEvent("/etc/shadow/../passwd", []string{"O_RDONLY"}), + profile: createTestProfile("test", []string{"/test"}, []string{"O_RDONLY"}), + expectAlert: true, + description: "Should alert for path traversal attempts", + }, } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rule := CreateRuleR0010UnexpectedSensitiveFileAccess() + if rule == nil { + t.Fatal("Expected rule to not be nil") + } + + objCache := &RuleObjectCacheMock{} + if tt.profile != nil { + objCache.SetApplicationProfile(tt.profile) + } + + if tt.additionalPaths != nil { + rule.SetParameters(map[string]interface{}{ + "additionalPaths": tt.additionalPaths, + }) + } + + result := rule.ProcessEvent(utils.OpenEventType, tt.event, objCache) + + if tt.expectAlert && result == nil { + t.Errorf("%s: expected alert but got none", tt.description) + } + if !tt.expectAlert && result != nil { + t.Errorf("%s: expected no alert but got one", tt.description) + } + }) + } } From ad18f3ad5593bc73320be6c0f0d83a5f5e2b3cb4 Mon Sep 17 00:00:00 2001 From: Amit Schendel Date: Wed, 30 Oct 2024 09:34:23 +0000 Subject: [PATCH 06/14] Adding more tests and logic to addProcess Signed-off-by: Amit Schendel --- pkg/processmanager/v1/process_manager.go | 17 ++ pkg/processmanager/v1/process_manager_test.go | 248 ++++++++++++++++++ 2 files changed, 265 insertions(+) diff --git a/pkg/processmanager/v1/process_manager.go b/pkg/processmanager/v1/process_manager.go index 836116c0..77a3f4fe 100644 --- a/pkg/processmanager/v1/process_manager.go +++ b/pkg/processmanager/v1/process_manager.go @@ -157,8 +157,25 @@ func (p *ProcessManager) removeProcessesUnderShim(shimPID uint32) { } func (p *ProcessManager) addProcess(process apitypes.Process) { + // First, check if the process already exists and has a different parent + if existingProc, exists := p.processTree.Load(process.PID); exists && existingProc.PPID != process.PPID { + // Remove from old parent's children list + if oldParent, exists := p.processTree.Load(existingProc.PPID); exists { + newChildren := make([]apitypes.Process, 0, len(oldParent.Children)) + for _, child := range oldParent.Children { + if child.PID != process.PID { + newChildren = append(newChildren, child) + } + } + oldParent.Children = newChildren + p.processTree.Set(oldParent.PID, oldParent) + } + } + + // Update the process in the tree p.processTree.Set(process.PID, process) + // Update new parent's children list if parent, exists := p.processTree.Load(process.PPID); exists { newChildren := make([]apitypes.Process, 0, len(parent.Children)+1) hasProcess := false diff --git a/pkg/processmanager/v1/process_manager_test.go b/pkg/processmanager/v1/process_manager_test.go index 0d97bd88..8b3f37df 100644 --- a/pkg/processmanager/v1/process_manager_test.go +++ b/pkg/processmanager/v1/process_manager_test.go @@ -583,3 +583,251 @@ func TestRaceConditions(t *testing.T) { return true }) } + +func TestDuplicateProcessHandling(t *testing.T) { + pm, addMockProcess := setupTestProcessManager(t) + + containerID := "test-container" + shimPID := uint32(999) + containerPID := uint32(1000) + + // Setup container + addMockProcess(int(containerPID), shimPID, "container-main") + pm.ContainerCallback(containercollection.PubSubEvent{ + Type: containercollection.EventTypeAddContainer, + Container: &containercollection.Container{ + Runtime: containercollection.RuntimeMetadata{ + BasicRuntimeMetadata: types.BasicRuntimeMetadata{ + ContainerID: containerID, + }, + }, + Pid: containerPID, + }, + }) + + t.Run("update process with same parent", func(t *testing.T) { + // First add a parent process + parentEvent := &tracerexectype.Event{ + Pid: 1001, + Ppid: containerPID, + Comm: "parent-process", + Args: []string{"parent-process", "--initial"}, + } + pm.ReportEvent(utils.ExecveEventType, parentEvent) + + // Add child process + childEvent := &tracerexectype.Event{ + Pid: 1002, + Ppid: 1001, + Comm: "child-process", + Args: []string{"child-process", "--initial"}, + } + pm.ReportEvent(utils.ExecveEventType, childEvent) + + // Verify initial state + parent, exists := pm.processTree.Load(1001) + require.True(t, exists) + assert.Equal(t, "parent-process", parent.Comm) + assert.Equal(t, "parent-process --initial", parent.Cmdline) + assert.Len(t, parent.Children, 1) + assert.Equal(t, uint32(1002), parent.Children[0].PID) + + // Add same child process again with different arguments + updatedChildEvent := &tracerexectype.Event{ + Pid: 1002, + Ppid: 1001, + Comm: "child-process", + Args: []string{"child-process", "--updated"}, + } + pm.ReportEvent(utils.ExecveEventType, updatedChildEvent) + + // Verify the process was updated + updatedChild, exists := pm.processTree.Load(1002) + require.True(t, exists) + assert.Equal(t, "child-process --updated", updatedChild.Cmdline) + + // Verify parent's children list was updated + updatedParent, exists := pm.processTree.Load(1001) + require.True(t, exists) + assert.Len(t, updatedParent.Children, 1) + assert.Equal(t, "child-process --updated", updatedParent.Children[0].Cmdline) + }) + + t.Run("update process with different parent", func(t *testing.T) { + // Move process to different parent + differentParentEvent := &tracerexectype.Event{ + Pid: 1002, + Ppid: containerPID, + Comm: "child-process", + Args: []string{"child-process", "--new-parent"}, + } + pm.ReportEvent(utils.ExecveEventType, differentParentEvent) + + // Verify process was updated with new parent + movedChild, exists := pm.processTree.Load(1002) + require.True(t, exists) + assert.Equal(t, containerPID, movedChild.PPID) + assert.Equal(t, "child-process --new-parent", movedChild.Cmdline) + + // Verify old parent no longer has the child + oldParent, exists := pm.processTree.Load(1001) + require.True(t, exists) + assert.Empty(t, oldParent.Children, "Old parent should have no children") + + // Verify new parent has the child + containerProcess, exists := pm.processTree.Load(containerPID) + require.True(t, exists) + hasChild := false + for _, child := range containerProcess.Children { + if child.PID == 1002 { + hasChild = true + assert.Equal(t, "child-process --new-parent", child.Cmdline) + } + } + assert.True(t, hasChild, "New parent should have the child") + }) +} + +func TestProcessReparenting(t *testing.T) { + pm, addMockProcess := setupTestProcessManager(t) + + containerID := "test-container" + shimPID := uint32(999) + containerPID := uint32(1000) + + // Setup container + addMockProcess(int(containerPID), shimPID, "container-main") + pm.ContainerCallback(containercollection.PubSubEvent{ + Type: containercollection.EventTypeAddContainer, + Container: &containercollection.Container{ + Runtime: containercollection.RuntimeMetadata{ + BasicRuntimeMetadata: types.BasicRuntimeMetadata{ + ContainerID: containerID, + }, + }, + Pid: containerPID, + }, + }) + + t.Run("reparent to nearest living ancestor", func(t *testing.T) { + // Create a chain of processes: + // shim -> grandparent -> parent -> child + + // Create grandparent process + grandparentPID := uint32(2000) + grandparentEvent := &tracerexectype.Event{ + Pid: grandparentPID, + Ppid: shimPID, + Comm: "grandparent", + Args: []string{"grandparent"}, + } + pm.ReportEvent(utils.ExecveEventType, grandparentEvent) + + // Create parent process + parentPID := uint32(2001) + parentEvent := &tracerexectype.Event{ + Pid: parentPID, + Ppid: grandparentPID, + Comm: "parent", + Args: []string{"parent"}, + } + pm.ReportEvent(utils.ExecveEventType, parentEvent) + + // Create child process + childPID := uint32(2002) + childEvent := &tracerexectype.Event{ + Pid: childPID, + Ppid: parentPID, + Comm: "child", + Args: []string{"child"}, + } + pm.ReportEvent(utils.ExecveEventType, childEvent) + + // Verify initial hierarchy + child, exists := pm.processTree.Load(childPID) + require.True(t, exists) + assert.Equal(t, parentPID, child.PPID) + + parent, exists := pm.processTree.Load(parentPID) + require.True(t, exists) + assert.Equal(t, grandparentPID, parent.PPID) + + // When parent dies, child should be reparented to grandparent + pm.removeProcess(parentPID) + + // Verify child was reparented to grandparent + child, exists = pm.processTree.Load(childPID) + require.True(t, exists) + assert.Equal(t, grandparentPID, child.PPID, "Child should be reparented to grandparent") + + // Verify grandparent has the child in its children list + grandparent, exists := pm.processTree.Load(grandparentPID) + require.True(t, exists) + hasChild := false + for _, c := range grandparent.Children { + if c.PID == childPID { + hasChild = true + break + } + } + assert.True(t, hasChild, "Grandparent should have the reparented child") + + // Now if grandparent dies too, child should be reparented to shim + pm.removeProcess(grandparentPID) + + child, exists = pm.processTree.Load(childPID) + require.True(t, exists) + assert.Equal(t, shimPID, child.PPID, "Child should be reparented to shim when grandparent dies") + }) + + t.Run("reparent multiple children", func(t *testing.T) { + // Create a parent with multiple children + parentPID := uint32(3000) + parentEvent := &tracerexectype.Event{ + Pid: parentPID, + Ppid: shimPID, + Comm: "parent", + Args: []string{"parent"}, + } + pm.ReportEvent(utils.ExecveEventType, parentEvent) + + // Create several children + childPIDs := []uint32{3001, 3002, 3003} + for _, pid := range childPIDs { + childEvent := &tracerexectype.Event{ + Pid: pid, + Ppid: parentPID, + Comm: fmt.Sprintf("child-%d", pid), + Args: []string{"child"}, + } + pm.ReportEvent(utils.ExecveEventType, childEvent) + } + + // Create a subprocess under one of the children + grandchildPID := uint32(3004) + grandchildEvent := &tracerexectype.Event{ + Pid: grandchildPID, + Ppid: childPIDs[0], + Comm: "grandchild", + Args: []string{"grandchild"}, + } + pm.ReportEvent(utils.ExecveEventType, grandchildEvent) + + // When parent dies, all direct children should be reparented to shim + pm.removeProcess(parentPID) + + // Verify all children were reparented to shim + for _, childPID := range childPIDs { + child, exists := pm.processTree.Load(childPID) + require.True(t, exists) + assert.Equal(t, shimPID, child.PPID, "Child should be reparented to shim") + } + + // When first child dies, its grandchild should be reparented to shim too + pm.removeProcess(childPIDs[0]) + + grandchild, exists := pm.processTree.Load(grandchildPID) + require.True(t, exists) + assert.Equal(t, shimPID, grandchild.PPID, "Grandchild should be reparented to shim") + }) +} From 2ce1600f51c16b00c2f174b4bdc4326c0ef77d3a Mon Sep 17 00:00:00 2001 From: Amit Schendel Date: Wed, 30 Oct 2024 09:42:51 +0000 Subject: [PATCH 07/14] Adding docs for each function Signed-off-by: Amit Schendel --- pkg/processmanager/v1/process_manager.go | 35 ++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/pkg/processmanager/v1/process_manager.go b/pkg/processmanager/v1/process_manager.go index 77a3f4fe..488dafcc 100644 --- a/pkg/processmanager/v1/process_manager.go +++ b/pkg/processmanager/v1/process_manager.go @@ -37,6 +37,9 @@ func CreateProcessManager(ctx context.Context) *ProcessManager { return pm } +// PopulateInitialProcesses scans the /proc filesystem to build the initial process tree +// for all registered container shim processes. It establishes parent-child relationships +// between processes and adds them to the process tree if they are descendants of a shim. func (p *ProcessManager) PopulateInitialProcesses() error { if len(p.containerIdToShimPid.Keys()) == 0 { return nil @@ -81,6 +84,9 @@ func (p *ProcessManager) PopulateInitialProcesses() error { return nil } +// isDescendantOfShim checks if a process with the given PID is a descendant of any +// registered shim process. It traverses the process tree upwards until it either finds +// a shim process or reaches the maximum tree depth to prevent infinite loops. func (p *ProcessManager) isDescendantOfShim(pid uint32, ppid uint32, shimPIDs map[uint32]struct{}, processes map[uint32]apitypes.Process) bool { visited := make(map[uint32]bool) currentPID := pid @@ -104,6 +110,9 @@ func (p *ProcessManager) isDescendantOfShim(pid uint32, ppid uint32, shimPIDs ma return false } +// ContainerCallback handles container lifecycle events (creation and removal). +// For new containers, it identifies the container's shim process and adds it to the tracking system. +// For removed containers, it cleans up the associated processes from the process tree. func (p *ProcessManager) ContainerCallback(notif containercollection.PubSubEvent) { containerID := notif.Container.Runtime.BasicRuntimeMetadata.ContainerID @@ -128,6 +137,9 @@ func (p *ProcessManager) ContainerCallback(notif containercollection.PubSubEvent } } +// removeProcessesUnderShim removes all processes that are descendants of the specified +// shim process PID from the process tree. This is typically called when a container +// is being removed. func (p *ProcessManager) removeProcessesUnderShim(shimPID uint32) { var pidsToRemove []uint32 @@ -156,6 +168,9 @@ func (p *ProcessManager) removeProcessesUnderShim(shimPID uint32) { } } +// addProcess adds or updates a process in the process tree and maintains the +// parent-child relationships between processes. If the process already exists +// with a different parent, it updates the relationships accordingly. func (p *ProcessManager) addProcess(process apitypes.Process) { // First, check if the process already exists and has a different parent if existingProc, exists := p.processTree.Load(process.PID); exists && existingProc.PPID != process.PPID { @@ -195,6 +210,9 @@ func (p *ProcessManager) addProcess(process apitypes.Process) { } } +// removeProcess removes a process from the process tree and updates the parent-child +// relationships. Children of the removed process are reassigned to their grandparent +// to maintain the process hierarchy. func (p *ProcessManager) removeProcess(pid uint32) { if process, exists := p.processTree.Load(pid); exists { if parent, exists := p.processTree.Load(process.PPID); exists { @@ -219,6 +237,9 @@ func (p *ProcessManager) removeProcess(pid uint32) { } } +// GetProcessTreeForPID retrieves the process tree for a specific PID within a container. +// It returns the process and all its ancestors up to the container's shim process. +// If the process is not in the tree, it attempts to fetch it from /proc. func (p *ProcessManager) GetProcessTreeForPID(containerID string, pid int) (apitypes.Process, error) { if !p.containerIdToShimPid.Has(containerID) { return apitypes.Process{}, fmt.Errorf("container ID %s not found", containerID) @@ -257,6 +278,9 @@ func (p *ProcessManager) GetProcessTreeForPID(containerID string, pid int) (apit return result, nil } +// ReportEvent handles process execution events from the system. +// It specifically processes execve events to track new process creations +// and updates the process tree accordingly. func (p *ProcessManager) ReportEvent(eventType utils.EventType, event utils.K8sEvent) { if eventType != utils.ExecveEventType { return @@ -284,6 +308,10 @@ func (p *ProcessManager) ReportEvent(eventType utils.EventType, event utils.K8sE p.addProcess(process) } +// startCleanupRoutine starts a goroutine that periodically runs the cleanup +// function to remove dead processes from the process tree. It continues until +// the context is cancelled. +// TODO: Register eBPF tracer to get process exit events and remove dead processes immediately. func (p *ProcessManager) startCleanupRoutine(ctx context.Context) { ticker := time.NewTicker(cleanupInterval) defer ticker.Stop() @@ -298,6 +326,8 @@ func (p *ProcessManager) startCleanupRoutine(ctx context.Context) { } } +// cleanup removes dead processes from the process tree by checking if each +// process in the tree is still alive in the system. func (p *ProcessManager) cleanup() { deadPids := make(map[uint32]bool) p.processTree.Range(func(pid uint32, _ apitypes.Process) bool { @@ -313,6 +343,9 @@ func (p *ProcessManager) cleanup() { } } +// getProcessFromProc retrieves process information from the /proc filesystem +// for a given PID. It collects various process attributes such as command line, +// working directory, and user/group IDs. func getProcessFromProc(pid int) (apitypes.Process, error) { proc, err := procfs.NewProc(pid) if err != nil { @@ -354,6 +387,8 @@ func getProcessFromProc(pid int) (apitypes.Process, error) { }, nil } +// isProcessAlive checks if a process with the given PID is still running +// by attempting to read its information from the /proc filesystem. func isProcessAlive(pid int) bool { proc, err := procfs.NewProc(pid) if err != nil { From 2a7bbd24e81d221f0ef8cc532357936cca592399 Mon Sep 17 00:00:00 2001 From: Amit Schendel Date: Wed, 30 Oct 2024 09:52:35 +0000 Subject: [PATCH 08/14] Adding remove processes under shims Signed-off-by: Amit Schendel --- pkg/processmanager/v1/process_manager_test.go | 81 +++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/pkg/processmanager/v1/process_manager_test.go b/pkg/processmanager/v1/process_manager_test.go index 8b3f37df..bc0e3c15 100644 --- a/pkg/processmanager/v1/process_manager_test.go +++ b/pkg/processmanager/v1/process_manager_test.go @@ -831,3 +831,84 @@ func TestProcessReparenting(t *testing.T) { assert.Equal(t, shimPID, grandchild.PPID, "Grandchild should be reparented to shim") }) } + +func TestRemoveProcessesUnderShim(t *testing.T) { + tests := []struct { + name string + initialTree map[uint32]apitypes.Process + shimPID uint32 + expectedTree map[uint32]apitypes.Process + description string + }{ + { + name: "simple_process_tree", + initialTree: map[uint32]apitypes.Process{ + 100: {PID: 100, PPID: 1, Comm: "shim", Children: []apitypes.Process{}}, // shim process + 200: {PID: 200, PPID: 100, Comm: "parent", Children: []apitypes.Process{}}, // direct child of shim + 201: {PID: 201, PPID: 200, Comm: "child1", Children: []apitypes.Process{}}, // child of parent + 202: {PID: 202, PPID: 200, Comm: "child2", Children: []apitypes.Process{}}, // another child of parent + }, + shimPID: 100, + expectedTree: map[uint32]apitypes.Process{ + 100: {PID: 100, PPID: 1, Comm: "shim", Children: []apitypes.Process{}}, // only shim remains + }, + description: "Should remove all processes under shim including children of children", + }, + { + name: "empty_tree", + initialTree: map[uint32]apitypes.Process{}, + shimPID: 100, + expectedTree: map[uint32]apitypes.Process{}, + description: "Should handle empty process tree gracefully", + }, + { + name: "orphaned_processes", + initialTree: map[uint32]apitypes.Process{ + 100: {PID: 100, PPID: 1, Comm: "shim", Children: []apitypes.Process{}}, // shim process + 200: {PID: 200, PPID: 100, Comm: "parent", Children: []apitypes.Process{}}, // direct child of shim + 201: {PID: 201, PPID: 999, Comm: "orphan", Children: []apitypes.Process{}}, // orphaned process (parent doesn't exist) + }, + shimPID: 100, + expectedTree: map[uint32]apitypes.Process{ + 100: {PID: 100, PPID: 1, Comm: "shim", Children: []apitypes.Process{}}, // shim remains + 201: {PID: 201, PPID: 999, Comm: "orphan", Children: []apitypes.Process{}}, // orphan unaffected + }, + description: "Should handle orphaned processes correctly", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Create process manager with test data + pm := &ProcessManager{} + + // Populate initial process tree + for pid, process := range tc.initialTree { + pm.processTree.Set(pid, process) + } + + // Call the function under test + pm.removeProcessesUnderShim(tc.shimPID) + + // Verify results + assert.Equal(t, len(tc.expectedTree), len(pm.processTree.Keys()), + "Process tree size mismatch after removal") + + // Check each expected process + for pid, expectedProcess := range tc.expectedTree { + actualProcess, exists := pm.processTree.Load(pid) + assert.True(t, exists, "Expected process %d not found in tree", pid) + assert.Equal(t, expectedProcess, actualProcess, + "Process %d details don't match expected values", pid) + } + + // Verify no unexpected processes remain + pm.processTree.Range(func(pid uint32, process apitypes.Process) bool { + _, shouldExist := tc.expectedTree[pid] + assert.True(t, shouldExist, + "Unexpected process %d found in tree", pid) + return true + }) + }) + } +} From 287844f3b290cfa81e389ceda69f3defd4312b39 Mon Sep 17 00:00:00 2001 From: Amit Schendel Date: Wed, 30 Oct 2024 10:47:31 +0000 Subject: [PATCH 09/14] Adding test for isDescendantOfShim Signed-off-by: Amit Schendel --- pkg/processmanager/v1/process_manager_test.go | 132 ++++++++++++++++++ 1 file changed, 132 insertions(+) diff --git a/pkg/processmanager/v1/process_manager_test.go b/pkg/processmanager/v1/process_manager_test.go index bc0e3c15..6405763f 100644 --- a/pkg/processmanager/v1/process_manager_test.go +++ b/pkg/processmanager/v1/process_manager_test.go @@ -912,3 +912,135 @@ func TestRemoveProcessesUnderShim(t *testing.T) { }) } } + +func TestIsDescendantOfShim(t *testing.T) { + tests := []struct { + name string + processes map[uint32]apitypes.Process + shimPIDs map[uint32]struct{} + pid uint32 + ppid uint32 + expected bool + description string + }{ + { + name: "direct_child_of_shim", + processes: map[uint32]apitypes.Process{ + 100: {PID: 100, PPID: 1, Comm: "shim"}, + 200: {PID: 200, PPID: 100, Comm: "child"}, + }, + shimPIDs: map[uint32]struct{}{ + 100: {}, + }, + pid: 200, + ppid: 100, + expected: true, + description: "Process is a direct child of shim", + }, + { + name: "indirect_descendant", + processes: map[uint32]apitypes.Process{ + 100: {PID: 100, PPID: 1, Comm: "shim"}, + 200: {PID: 200, PPID: 100, Comm: "parent"}, + 300: {PID: 300, PPID: 200, Comm: "child"}, + }, + shimPIDs: map[uint32]struct{}{ + 100: {}, + }, + pid: 300, + ppid: 200, + expected: true, + description: "Process is an indirect descendant of shim", + }, + { + name: "not_a_descendant", + processes: map[uint32]apitypes.Process{ + 100: {PID: 100, PPID: 1, Comm: "shim"}, + 200: {PID: 200, PPID: 2, Comm: "unrelated"}, + }, + shimPIDs: map[uint32]struct{}{ + 100: {}, + }, + pid: 200, + ppid: 2, + expected: false, + description: "Process is not a descendant of any shim", + }, + { + name: "circular_reference", + processes: map[uint32]apitypes.Process{ + 100: {PID: 100, PPID: 1, Comm: "shim"}, + 200: {PID: 200, PPID: 300, Comm: "circular1"}, + 300: {PID: 300, PPID: 200, Comm: "circular2"}, + }, + shimPIDs: map[uint32]struct{}{ + 100: {}, + }, + pid: 200, + ppid: 300, + expected: false, + description: "Process is part of a circular reference", + }, + { + name: "process_chain_exceeds_max_depth", + processes: func() map[uint32]apitypes.Process { + // Create a chain where the target process is maxTreeDepth + 1 steps away from any shim + procs := map[uint32]apitypes.Process{ + 1: {PID: 1, PPID: 0, Comm: "init"}, // init process + 2: {PID: 2, PPID: 1, Comm: "shim"}, // shim process + } + // Create a chain starting far from the shim + currentPPID := uint32(100) // Start with a different base to avoid conflicts + targetPID := uint32(100 + maxTreeDepth + 1) + + // Build the chain backwards from target to base + for pid := targetPID; pid > currentPPID; pid-- { + procs[pid] = apitypes.Process{ + PID: pid, + PPID: pid - 1, + Comm: fmt.Sprintf("process-%d", pid), + } + } + // Add the base process that's not connected to shim + procs[currentPPID] = apitypes.Process{ + PID: currentPPID, + PPID: currentPPID - 1, + Comm: fmt.Sprintf("process-%d", currentPPID), + } + return procs + }(), + shimPIDs: map[uint32]struct{}{ + 2: {}, // Shim PID + }, + pid: uint32(100 + maxTreeDepth + 1), // Target process at the end of chain + ppid: uint32(100 + maxTreeDepth), // Its immediate parent + expected: false, + description: "Process chain exceeds maximum allowed depth", + }, + { + name: "multiple_shims", + processes: map[uint32]apitypes.Process{ + 100: {PID: 100, PPID: 1, Comm: "shim1"}, + 101: {PID: 101, PPID: 1, Comm: "shim2"}, + 200: {PID: 200, PPID: 100, Comm: "child1"}, + 201: {PID: 201, PPID: 101, Comm: "child2"}, + }, + shimPIDs: map[uint32]struct{}{ + 100: {}, + 101: {}, + }, + pid: 200, + ppid: 100, + expected: true, + description: "Multiple shims in the system", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + pm := &ProcessManager{} + result := pm.isDescendantOfShim(tc.pid, tc.ppid, tc.shimPIDs, tc.processes) + assert.Equal(t, tc.expected, result, tc.description) + }) + } +} From 566d594f29ec988289418554f88195292c79a0b9 Mon Sep 17 00:00:00 2001 From: Amit Schendel Date: Wed, 30 Oct 2024 10:48:34 +0000 Subject: [PATCH 10/14] Adding comment Signed-off-by: Amit Schendel --- pkg/containerwatcher/v1/container_watcher.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/containerwatcher/v1/container_watcher.go b/pkg/containerwatcher/v1/container_watcher.go index eeb19b5e..ece16315 100644 --- a/pkg/containerwatcher/v1/container_watcher.go +++ b/pkg/containerwatcher/v1/container_watcher.go @@ -502,6 +502,7 @@ func (ch *IGContainerWatcher) Start(ctx context.Context) error { return fmt.Errorf("setting up container collection: %w", err) } + // We want to populate the initial processes before starting the tracers but after retrieving the shims. if err := ch.processManager.PopulateInitialProcesses(); err != nil { ch.stopContainerCollection() return fmt.Errorf("populating initial processes: %w", err) From c0798a9a408db5ce983651431d60b6e4e516793f Mon Sep 17 00:00:00 2001 From: Amit Schendel Date: Wed, 30 Oct 2024 14:12:22 +0000 Subject: [PATCH 11/14] Adding report exec to add new processes Signed-off-by: Amit Schendel --- pkg/containerwatcher/v1/container_watcher.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/containerwatcher/v1/container_watcher.go b/pkg/containerwatcher/v1/container_watcher.go index ece16315..4e1af499 100644 --- a/pkg/containerwatcher/v1/container_watcher.go +++ b/pkg/containerwatcher/v1/container_watcher.go @@ -206,6 +206,7 @@ func CreateIGContainerWatcher(cfg config.Config, applicationProfileManager appli path = event.Args[0] } metrics.ReportEvent(utils.ExecveEventType) + processManager.ReportEvent(utils.ExecveEventType, &event) applicationProfileManager.ReportFileExec(k8sContainerID, path, event.Args) relevancyManager.ReportFileExec(event.Runtime.ContainerID, k8sContainerID, path) ruleManager.ReportEvent(utils.ExecveEventType, &event) From 2016b8e9f14a1bf44b1ca5d48d345b7684cb392e Mon Sep 17 00:00:00 2001 From: Amit Schendel Date: Wed, 30 Oct 2024 15:52:41 +0000 Subject: [PATCH 12/14] Making sure children slice exist Signed-off-by: Amit Schendel --- pkg/processmanager/v1/process_manager.go | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/pkg/processmanager/v1/process_manager.go b/pkg/processmanager/v1/process_manager.go index 488dafcc..1500eb8e 100644 --- a/pkg/processmanager/v1/process_manager.go +++ b/pkg/processmanager/v1/process_manager.go @@ -303,6 +303,7 @@ func (p *ProcessManager) ReportEvent(eventType utils.EventType, event utils.K8sE Cwd: execEvent.Cwd, Pcomm: execEvent.Pcomm, Cmdline: strings.Join(execEvent.Args, " "), + Children: []apitypes.Process{}, } p.addProcess(process) @@ -376,14 +377,15 @@ func getProcessFromProc(pid int) (apitypes.Process, error) { path, _ := proc.Executable() return apitypes.Process{ - PID: uint32(pid), - PPID: uint32(stat.PPID), - Comm: stat.Comm, - Uid: &uid, - Gid: &gid, - Cmdline: strings.Join(cmdline, " "), - Cwd: cwd, - Path: path, + PID: uint32(pid), + PPID: uint32(stat.PPID), + Comm: stat.Comm, + Uid: &uid, + Gid: &gid, + Cmdline: strings.Join(cmdline, " "), + Cwd: cwd, + Path: path, + Children: []apitypes.Process{}, }, nil } From 9fa19fd5b778e0cd2429f8a05b29cf2338683d34 Mon Sep 17 00:00:00 2001 From: Amit Schendel Date: Thu, 31 Oct 2024 08:41:32 +0000 Subject: [PATCH 13/14] Adding pcomm Signed-off-by: Amit Schendel --- pkg/processmanager/v1/process_manager.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/pkg/processmanager/v1/process_manager.go b/pkg/processmanager/v1/process_manager.go index 1500eb8e..74a95e5c 100644 --- a/pkg/processmanager/v1/process_manager.go +++ b/pkg/processmanager/v1/process_manager.go @@ -375,11 +375,29 @@ func getProcessFromProc(pid int) (apitypes.Process, error) { cwd, _ := proc.Cwd() path, _ := proc.Executable() + pcomm := func() string { + if stat.PPID <= 0 { + return "" + } + + parentProc, err := procfs.NewProc(stat.PPID) + if err != nil { + return "" + } + + parentStat, err := parentProc.Stat() + if err != nil { + return "" + } + + return parentStat.Comm + }() return apitypes.Process{ PID: uint32(pid), PPID: uint32(stat.PPID), Comm: stat.Comm, + Pcomm: pcomm, Uid: &uid, Gid: &gid, Cmdline: strings.Join(cmdline, " "), From e62f2687c98c6511c9cc3429d41aab16c5a0492b Mon Sep 17 00:00:00 2001 From: Amit Schendel Date: Thu, 31 Oct 2024 09:22:07 +0000 Subject: [PATCH 14/14] Adding backoff to handle race conditions Signed-off-by: Amit Schendel --- pkg/rulemanager/v1/rule_manager.go | 72 ++++++++++-------------------- 1 file changed, 24 insertions(+), 48 deletions(-) diff --git a/pkg/rulemanager/v1/rule_manager.go b/pkg/rulemanager/v1/rule_manager.go index 6adcd1ca..88e0800a 100644 --- a/pkg/rulemanager/v1/rule_manager.go +++ b/pkg/rulemanager/v1/rule_manager.go @@ -356,8 +356,13 @@ func (rm *RuleManager) processEvent(eventType utils.EventType, event utils.K8sEv } } func (rm *RuleManager) enrichRuleFailure(ruleFailure ruleengine.RuleFailure) ruleengine.RuleFailure { - path, err := utils.GetPathFromPid(ruleFailure.GetRuntimeProcessDetails().ProcessTree.PID) - hostPath := "" + var err error + var path string + var hostPath string + if ruleFailure.GetRuntimeProcessDetails().ProcessTree.Path == "" { + path, err = utils.GetPathFromPid(ruleFailure.GetRuntimeProcessDetails().ProcessTree.PID) + } + if err != nil { if ruleFailure.GetRuntimeProcessDetails().ProcessTree.Path != "" { hostPath = filepath.Join("/proc", fmt.Sprintf("/%d/root/%s", rm.containerIdToPid.Get(ruleFailure.GetTriggerEvent().Runtime.ContainerID), ruleFailure.GetRuntimeProcessDetails().ProcessTree.Path)) @@ -395,59 +400,30 @@ func (rm *RuleManager) enrichRuleFailure(ruleFailure ruleengine.RuleFailure) rul ruleFailure.SetBaseRuntimeAlert(baseRuntimeAlert) runtimeProcessDetails := ruleFailure.GetRuntimeProcessDetails() - if runtimeProcessDetails.ProcessTree.Cmdline == "" { - commandLine, err := utils.GetCmdlineByPid(int(ruleFailure.GetRuntimeProcessDetails().ProcessTree.PID)) - if err != nil { - runtimeProcessDetails.ProcessTree.Cmdline = "" - } else { - runtimeProcessDetails.ProcessTree.Cmdline = *commandLine - } - } - - if runtimeProcessDetails.ProcessTree.PPID == 0 { - parent, err := utils.GetProcessStat(int(ruleFailure.GetRuntimeProcessDetails().ProcessTree.PID)) - if err != nil { - runtimeProcessDetails.ProcessTree.PPID = 0 - } else { - runtimeProcessDetails.ProcessTree.PPID = uint32(parent.PPID) - } - - if runtimeProcessDetails.ProcessTree.Pcomm == "" { - if err == nil { - runtimeProcessDetails.ProcessTree.Pcomm = parent.Comm - } else { - runtimeProcessDetails.ProcessTree.Pcomm = "" - } - } - } - if runtimeProcessDetails.ProcessTree.PID == 0 { - runtimeProcessDetails.ProcessTree.PID = ruleFailure.GetRuntimeProcessDetails().ProcessTree.PID - } - - if runtimeProcessDetails.ProcessTree.Comm == "" { - comm, err := utils.GetCommFromPid(ruleFailure.GetRuntimeProcessDetails().ProcessTree.PID) + err = backoff.Retry(func() error { + tree, err := rm.processManager.GetProcessTreeForPID( + ruleFailure.GetRuntimeProcessDetails().ContainerID, + int(ruleFailure.GetRuntimeProcessDetails().ProcessTree.PID), + ) if err != nil { - comm = "" + return err } - runtimeProcessDetails.ProcessTree.Comm = comm - } - - if runtimeProcessDetails.ProcessTree.Path == "" && path != "" { - runtimeProcessDetails.ProcessTree.Path = path - } - - // TODO: Avoid Race condition where the tree is not populated yet. - tree, err := rm.processManager.GetProcessTreeForPID(ruleFailure.GetRuntimeProcessDetails().ContainerID, int(ruleFailure.GetRuntimeProcessDetails().ProcessTree.PID)) - if err == nil { runtimeProcessDetails.ProcessTree = tree - } else if rm.containerIdToShimPid.Has(ruleFailure.GetRuntimeProcessDetails().ContainerID) { + return nil + }, backoff.NewExponentialBackOff( + backoff.WithInitialInterval(50*time.Millisecond), + backoff.WithMaxInterval(200*time.Millisecond), + backoff.WithMaxElapsedTime(500*time.Millisecond), + )) + + if err != nil && rm.containerIdToShimPid.Has(ruleFailure.GetRuntimeProcessDetails().ContainerID) { logger.L().Debug("RuleManager - failed to get process tree, trying to get process tree from shim", helpers.Error(err), helpers.String("container ID", ruleFailure.GetRuntimeProcessDetails().ContainerID)) - shimPid := rm.containerIdToShimPid.Get(ruleFailure.GetRuntimeProcessDetails().ContainerID) - tree, err := utils.CreateProcessTree(&runtimeProcessDetails.ProcessTree, shimPid) - if err == nil { + + if tree, err := utils.CreateProcessTree(&runtimeProcessDetails.ProcessTree, + rm.containerIdToShimPid.Get(ruleFailure.GetRuntimeProcessDetails().ContainerID)); err == nil { runtimeProcessDetails.ProcessTree = *tree } }