diff --git a/README.md b/README.md index b8a464a7f..20991d14b 100644 --- a/README.md +++ b/README.md @@ -425,6 +425,11 @@ In case multi-cluster support is enabled (default) and you have access to multip - `storage` (`string`) - Optional storage size for the VM's root disk when using DataSources (e.g., '30Gi', '50Gi', '100Gi'). Defaults to 30Gi. Ignored when using container disks. - `workload` (`string`) - The workload for the VM. Accepts OS names (e.g., 'fedora' (default), 'ubuntu', 'centos', 'centos-stream', 'debian', 'rhel', 'opensuse', 'opensuse-tumbleweed', 'opensuse-leap') or full container disk image URLs +- **vm_lifecycle** - Manage VirtualMachine lifecycle: start, stop, or restart a VM + - `action` (`string`) **(required)** - The lifecycle action to perform: 'start' (changes runStrategy to Always), 'stop' (changes runStrategy to Halted), or 'restart' (stops then starts the VM) + - `name` (`string`) **(required)** - The name of the virtual machine + - `namespace` (`string`) **(required)** - The namespace of the virtual machine + diff --git a/pkg/kubevirt/vm.go b/pkg/kubevirt/vm.go new file mode 100644 index 000000000..0a381c8a9 --- /dev/null +++ b/pkg/kubevirt/vm.go @@ -0,0 +1,160 @@ +package kubevirt + +import ( + "context" + "fmt" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/client-go/dynamic" +) + +// RunStrategy represents the run strategy for a VirtualMachine +type RunStrategy string + +const ( + RunStrategyAlways RunStrategy = "Always" + RunStrategyHalted RunStrategy = "Halted" +) + +var ( + // VirtualMachineGVK is the GroupVersionKind for VirtualMachine resources + VirtualMachineGVK = schema.GroupVersionKind{ + Group: "kubevirt.io", + Version: "v1", + Kind: "VirtualMachine", + } + + // VirtualMachineGVR is the GroupVersionResource for VirtualMachine resources + VirtualMachineGVR = schema.GroupVersionResource{ + Group: "kubevirt.io", + Version: "v1", + Resource: "virtualmachines", + } +) + +// GetVirtualMachine retrieves a VirtualMachine by namespace and name +func GetVirtualMachine(ctx context.Context, client dynamic.Interface, namespace, name string) (*unstructured.Unstructured, error) { + return client.Resource(VirtualMachineGVR).Namespace(namespace).Get(ctx, name, metav1.GetOptions{}) +} + +// GetVMRunStrategy retrieves the current runStrategy from a VirtualMachine +// Returns the strategy, whether it was found, and any error +func GetVMRunStrategy(vm *unstructured.Unstructured) (RunStrategy, bool, error) { + strategy, found, err := unstructured.NestedString(vm.Object, "spec", "runStrategy") + if err != nil { + return "", false, fmt.Errorf("failed to read runStrategy: %w", err) + } + + return RunStrategy(strategy), found, nil +} + +// SetVMRunStrategy sets the runStrategy on a VirtualMachine +func SetVMRunStrategy(vm *unstructured.Unstructured, strategy RunStrategy) error { + return unstructured.SetNestedField(vm.Object, string(strategy), "spec", "runStrategy") +} + +// UpdateVirtualMachine updates a VirtualMachine in the cluster +func UpdateVirtualMachine(ctx context.Context, client dynamic.Interface, vm *unstructured.Unstructured) (*unstructured.Unstructured, error) { + return client.Resource(VirtualMachineGVR). + Namespace(vm.GetNamespace()). + Update(ctx, vm, metav1.UpdateOptions{}) +} + +// StartVM starts a VirtualMachine by updating its runStrategy to Always +// Returns the updated VM and true if the VM was started, false if it was already running +func StartVM(ctx context.Context, dynamicClient dynamic.Interface, namespace, name string) (*unstructured.Unstructured, bool, error) { + // Get the current VirtualMachine + vm, err := GetVirtualMachine(ctx, dynamicClient, namespace, name) + if err != nil { + return nil, false, fmt.Errorf("failed to get VirtualMachine: %w", err) + } + + currentStrategy, found, err := GetVMRunStrategy(vm) + if err != nil { + return nil, false, err + } + + // Check if already running + if found && currentStrategy == RunStrategyAlways { + return vm, false, nil + } + + // Update runStrategy to Always + if err := SetVMRunStrategy(vm, RunStrategyAlways); err != nil { + return nil, false, fmt.Errorf("failed to set runStrategy: %w", err) + } + + // Update the VM in the cluster + updatedVM, err := UpdateVirtualMachine(ctx, dynamicClient, vm) + if err != nil { + return nil, false, fmt.Errorf("failed to start VirtualMachine: %w", err) + } + + return updatedVM, true, nil +} + +// StopVM stops a VirtualMachine by updating its runStrategy to Halted +// Returns the updated VM and true if the VM was stopped, false if it was already stopped +func StopVM(ctx context.Context, dynamicClient dynamic.Interface, namespace, name string) (*unstructured.Unstructured, bool, error) { + // Get the current VirtualMachine + vm, err := GetVirtualMachine(ctx, dynamicClient, namespace, name) + if err != nil { + return nil, false, fmt.Errorf("failed to get VirtualMachine: %w", err) + } + + currentStrategy, found, err := GetVMRunStrategy(vm) + if err != nil { + return nil, false, err + } + + // Check if already stopped + if found && currentStrategy == RunStrategyHalted { + return vm, false, nil + } + + // Update runStrategy to Halted + if err := SetVMRunStrategy(vm, RunStrategyHalted); err != nil { + return nil, false, fmt.Errorf("failed to set runStrategy: %w", err) + } + + // Update the VM in the cluster + updatedVM, err := UpdateVirtualMachine(ctx, dynamicClient, vm) + if err != nil { + return nil, false, fmt.Errorf("failed to stop VirtualMachine: %w", err) + } + + return updatedVM, true, nil +} + +// RestartVM restarts a VirtualMachine by temporarily setting runStrategy to Halted then back to Always +func RestartVM(ctx context.Context, dynamicClient dynamic.Interface, namespace, name string) (*unstructured.Unstructured, error) { + // Get the current VirtualMachine + vm, err := GetVirtualMachine(ctx, dynamicClient, namespace, name) + if err != nil { + return nil, fmt.Errorf("failed to get VirtualMachine: %w", err) + } + + // Stop the VM first + if err := SetVMRunStrategy(vm, RunStrategyHalted); err != nil { + return nil, fmt.Errorf("failed to set runStrategy to Halted: %w", err) + } + + vm, err = UpdateVirtualMachine(ctx, dynamicClient, vm) + if err != nil { + return nil, fmt.Errorf("failed to stop VirtualMachine: %w", err) + } + + // Start the VM again + if err := SetVMRunStrategy(vm, RunStrategyAlways); err != nil { + return nil, fmt.Errorf("failed to set runStrategy to Always: %w", err) + } + + updatedVM, err := UpdateVirtualMachine(ctx, dynamicClient, vm) + if err != nil { + return nil, fmt.Errorf("failed to start VirtualMachine: %w", err) + } + + return updatedVM, nil +} diff --git a/pkg/kubevirt/vm_test.go b/pkg/kubevirt/vm_test.go new file mode 100644 index 000000000..6afd4dbe2 --- /dev/null +++ b/pkg/kubevirt/vm_test.go @@ -0,0 +1,329 @@ +package kubevirt + +import ( + "context" + "strings" + "testing" + + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/dynamic/fake" +) + +// createTestVM creates a test VirtualMachine with the given name, namespace, and runStrategy +func createTestVM(name, namespace string, runStrategy RunStrategy) *unstructured.Unstructured { + vm := &unstructured.Unstructured{} + vm.SetUnstructuredContent(map[string]interface{}{ + "apiVersion": "kubevirt.io/v1", + "kind": "VirtualMachine", + "metadata": map[string]interface{}{ + "name": name, + "namespace": namespace, + }, + "spec": map[string]interface{}{ + "runStrategy": string(runStrategy), + }, + }) + return vm +} + +func TestStartVM(t *testing.T) { + tests := []struct { + name string + initialVM *unstructured.Unstructured + wantStarted bool + wantError bool + errorContains string + }{ + { + name: "Start VM that is Halted", + initialVM: createTestVM("test-vm", "default", RunStrategyHalted), + wantStarted: true, + wantError: false, + }, + { + name: "Start VM that is already running (Always)", + initialVM: createTestVM("test-vm", "default", RunStrategyAlways), + wantStarted: false, + wantError: false, + }, + { + name: "Start VM without runStrategy", + initialVM: &unstructured.Unstructured{ + Object: map[string]interface{}{ + "apiVersion": "kubevirt.io/v1", + "kind": "VirtualMachine", + "metadata": map[string]interface{}{ + "name": "test-vm", + "namespace": "default", + }, + "spec": map[string]interface{}{}, + }, + }, + wantStarted: true, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + scheme := runtime.NewScheme() + client := fake.NewSimpleDynamicClient(scheme, tt.initialVM) + ctx := context.Background() + + vm, wasStarted, err := StartVM(ctx, client, tt.initialVM.GetNamespace(), tt.initialVM.GetName()) + + if tt.wantError { + if err == nil { + t.Errorf("Expected error, got nil") + return + } + if tt.errorContains != "" && !strings.Contains(err.Error(), tt.errorContains) { + t.Errorf("Error = %v, want to contain %q", err, tt.errorContains) + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if vm == nil { + t.Errorf("Expected non-nil VM, got nil") + return + } + + if wasStarted != tt.wantStarted { + t.Errorf("wasStarted = %v, want %v", wasStarted, tt.wantStarted) + } + + // Verify the VM's runStrategy is Always + strategy, found, err := GetVMRunStrategy(vm) + if err != nil { + t.Errorf("Failed to get runStrategy: %v", err) + return + } + if !found { + t.Errorf("runStrategy not found") + return + } + if strategy != RunStrategyAlways { + t.Errorf("Strategy = %q, want %q", strategy, RunStrategyAlways) + } + }) + } +} + +func TestStartVMNotFound(t *testing.T) { + scheme := runtime.NewScheme() + client := fake.NewSimpleDynamicClient(scheme) + ctx := context.Background() + + _, _, err := StartVM(ctx, client, "default", "non-existent-vm") + if err == nil { + t.Errorf("Expected error for non-existent VM, got nil") + return + } + if !strings.Contains(err.Error(), "failed to get VirtualMachine") { + t.Errorf("Error = %v, want to contain 'failed to get VirtualMachine'", err) + } +} + +func TestStopVM(t *testing.T) { + tests := []struct { + name string + initialVM *unstructured.Unstructured + wantStopped bool + wantError bool + errorContains string + }{ + { + name: "Stop VM that is running (Always)", + initialVM: createTestVM("test-vm", "default", RunStrategyAlways), + wantStopped: true, + wantError: false, + }, + { + name: "Stop VM that is already stopped (Halted)", + initialVM: createTestVM("test-vm", "default", RunStrategyHalted), + wantStopped: false, + wantError: false, + }, + { + name: "Stop VM without runStrategy", + initialVM: &unstructured.Unstructured{ + Object: map[string]interface{}{ + "apiVersion": "kubevirt.io/v1", + "kind": "VirtualMachine", + "metadata": map[string]interface{}{ + "name": "test-vm", + "namespace": "default", + }, + "spec": map[string]interface{}{}, + }, + }, + wantStopped: true, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + scheme := runtime.NewScheme() + client := fake.NewSimpleDynamicClient(scheme, tt.initialVM) + ctx := context.Background() + + vm, wasStopped, err := StopVM(ctx, client, tt.initialVM.GetNamespace(), tt.initialVM.GetName()) + + if tt.wantError { + if err == nil { + t.Errorf("Expected error, got nil") + return + } + if tt.errorContains != "" && !strings.Contains(err.Error(), tt.errorContains) { + t.Errorf("Error = %v, want to contain %q", err, tt.errorContains) + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if vm == nil { + t.Errorf("Expected non-nil VM, got nil") + return + } + + if wasStopped != tt.wantStopped { + t.Errorf("wasStopped = %v, want %v", wasStopped, tt.wantStopped) + } + + // Verify the VM's runStrategy is Halted + strategy, found, err := GetVMRunStrategy(vm) + if err != nil { + t.Errorf("Failed to get runStrategy: %v", err) + return + } + if !found { + t.Errorf("runStrategy not found") + return + } + if strategy != RunStrategyHalted { + t.Errorf("Strategy = %q, want %q", strategy, RunStrategyHalted) + } + }) + } +} + +func TestStopVMNotFound(t *testing.T) { + scheme := runtime.NewScheme() + client := fake.NewSimpleDynamicClient(scheme) + ctx := context.Background() + + _, _, err := StopVM(ctx, client, "default", "non-existent-vm") + if err == nil { + t.Errorf("Expected error for non-existent VM, got nil") + return + } + if !strings.Contains(err.Error(), "failed to get VirtualMachine") { + t.Errorf("Error = %v, want to contain 'failed to get VirtualMachine'", err) + } +} + +func TestRestartVM(t *testing.T) { + tests := []struct { + name string + initialVM *unstructured.Unstructured + wantError bool + errorContains string + }{ + { + name: "Restart VM that is running (Always)", + initialVM: createTestVM("test-vm", "default", RunStrategyAlways), + wantError: false, + }, + { + name: "Restart VM that is stopped (Halted)", + initialVM: createTestVM("test-vm", "default", RunStrategyHalted), + wantError: false, + }, + { + name: "Restart VM without runStrategy", + initialVM: &unstructured.Unstructured{ + Object: map[string]interface{}{ + "apiVersion": "kubevirt.io/v1", + "kind": "VirtualMachine", + "metadata": map[string]interface{}{ + "name": "test-vm", + "namespace": "default", + }, + "spec": map[string]interface{}{}, + }, + }, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + scheme := runtime.NewScheme() + client := fake.NewSimpleDynamicClient(scheme, tt.initialVM) + ctx := context.Background() + + vm, err := RestartVM(ctx, client, tt.initialVM.GetNamespace(), tt.initialVM.GetName()) + + if tt.wantError { + if err == nil { + t.Errorf("Expected error, got nil") + return + } + if tt.errorContains != "" && !strings.Contains(err.Error(), tt.errorContains) { + t.Errorf("Error = %v, want to contain %q", err, tt.errorContains) + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if vm == nil { + t.Errorf("Expected non-nil VM, got nil") + return + } + + // Verify the VM's runStrategy is Always (after restart) + strategy, found, err := GetVMRunStrategy(vm) + if err != nil { + t.Errorf("Failed to get runStrategy: %v", err) + return + } + if !found { + t.Errorf("runStrategy not found") + return + } + if strategy != RunStrategyAlways { + t.Errorf("Strategy = %q, want %q after restart", strategy, RunStrategyAlways) + } + }) + } +} + +func TestRestartVMNotFound(t *testing.T) { + scheme := runtime.NewScheme() + client := fake.NewSimpleDynamicClient(scheme) + ctx := context.Background() + + _, err := RestartVM(ctx, client, "default", "non-existent-vm") + if err == nil { + t.Errorf("Expected error for non-existent VM, got nil") + return + } + if !strings.Contains(err.Error(), "failed to get VirtualMachine") { + t.Errorf("Error = %v, want to contain 'failed to get VirtualMachine'", err) + } +} diff --git a/pkg/mcp/kubevirt_test.go b/pkg/mcp/kubevirt_test.go index 83e6e9ed5..aa2514224 100644 --- a/pkg/mcp/kubevirt_test.go +++ b/pkg/mcp/kubevirt_test.go @@ -453,6 +453,221 @@ func (s *KubevirtSuite) TestCreate() { }) } +func (s *KubevirtSuite) TestVMLifecycle() { + // Create a test VM in Halted state for start tests + dynamicClient := dynamic.NewForConfigOrDie(envTestRestConfig) + vm := &unstructured.Unstructured{} + vm.SetUnstructuredContent(map[string]interface{}{ + "apiVersion": "kubevirt.io/v1", + "kind": "VirtualMachine", + "metadata": map[string]interface{}{ + "name": "test-vm-lifecycle", + "namespace": "default", + }, + "spec": map[string]interface{}{ + "runStrategy": "Halted", + }, + }) + _, err := dynamicClient.Resource(schema.GroupVersionResource{ + Group: "kubevirt.io", + Version: "v1", + Resource: "virtualmachines", + }).Namespace("default").Create(s.T().Context(), vm, metav1.CreateOptions{}) + s.Require().NoError(err, "failed to create test VM") + + s.Run("vm_lifecycle missing required params", func() { + testCases := []string{"name", "namespace", "action"} + for _, param := range testCases { + s.Run("missing "+param, func() { + params := map[string]interface{}{ + "name": "test-vm-lifecycle", + "namespace": "default", + "action": "start", + } + delete(params, param) + toolResult, err := s.CallTool("vm_lifecycle", params) + s.Require().Nilf(err, "call tool failed %v", err) + s.Truef(toolResult.IsError, "expected call tool to fail due to missing %s", param) + s.Equal(toolResult.Content[0].(mcp.TextContent).Text, param+" parameter required") + }) + } + }) + + s.Run("vm_lifecycle invalid action", func() { + toolResult, err := s.CallTool("vm_lifecycle", map[string]interface{}{ + "name": "test-vm-lifecycle", + "namespace": "default", + "action": "invalid", + }) + s.Require().Nilf(err, "call tool failed %v", err) + s.Truef(toolResult.IsError, "expected call tool to fail due to invalid action") + s.Truef(strings.Contains(toolResult.Content[0].(mcp.TextContent).Text, "invalid action"), + "Expected invalid action message, got %v", toolResult.Content[0].(mcp.TextContent).Text) + }) + + s.Run("vm_lifecycle action=start on halted VM", func() { + toolResult, err := s.CallTool("vm_lifecycle", map[string]interface{}{ + "name": "test-vm-lifecycle", + "namespace": "default", + "action": "start", + }) + s.Run("no error", func() { + s.Nilf(err, "call tool failed %v", err) + s.Falsef(toolResult.IsError, "call tool failed") + }) + var decodedResult []unstructured.Unstructured + err = yaml.Unmarshal([]byte(toolResult.Content[0].(mcp.TextContent).Text), &decodedResult) + s.Run("returns yaml content", func() { + s.Nilf(err, "invalid tool result content %v", err) + s.Truef(strings.HasPrefix(toolResult.Content[0].(mcp.TextContent).Text, "# VirtualMachine started successfully"), + "Expected success message, got %v", toolResult.Content[0].(mcp.TextContent).Text) + s.Require().Lenf(decodedResult, 1, "invalid resource count, expected 1, got %v", len(decodedResult)) + s.Equal("test-vm-lifecycle", decodedResult[0].GetName(), "invalid resource name") + s.Equal("default", decodedResult[0].GetNamespace(), "invalid resource namespace") + s.Equal("Always", + decodedResult[0].Object["spec"].(map[string]interface{})["runStrategy"].(string), + "expected runStrategy to be Always after start") + }) + }) + + s.Run("vm_lifecycle action=start on already running VM (idempotent)", func() { + toolResult, err := s.CallTool("vm_lifecycle", map[string]interface{}{ + "name": "test-vm-lifecycle", + "namespace": "default", + "action": "start", + }) + s.Run("no error", func() { + s.Nilf(err, "call tool failed %v", err) + s.Falsef(toolResult.IsError, "call tool failed") + }) + var decodedResult []unstructured.Unstructured + err = yaml.Unmarshal([]byte(toolResult.Content[0].(mcp.TextContent).Text), &decodedResult) + s.Run("returns yaml content showing VM was already running", func() { + s.Nilf(err, "invalid tool result content %v", err) + expectedPrefix := fmt.Sprintf("# VirtualMachine '%s' in namespace '%s' is already running", "test-vm-lifecycle", "default") + s.Truef(strings.HasPrefix(toolResult.Content[0].(mcp.TextContent).Text, expectedPrefix), + "Expected already running message, got %v", toolResult.Content[0].(mcp.TextContent).Text) + s.Require().Lenf(decodedResult, 1, "invalid resource count, expected 1, got %v", len(decodedResult)) + s.Equal("Always", + decodedResult[0].Object["spec"].(map[string]interface{})["runStrategy"].(string), + "expected runStrategy to remain Always") + }) + }) + + s.Run("vm_lifecycle action=stop on running VM", func() { + toolResult, err := s.CallTool("vm_lifecycle", map[string]interface{}{ + "name": "test-vm-lifecycle", + "namespace": "default", + "action": "stop", + }) + s.Run("no error", func() { + s.Nilf(err, "call tool failed %v", err) + s.Falsef(toolResult.IsError, "call tool failed") + }) + var decodedResult []unstructured.Unstructured + err = yaml.Unmarshal([]byte(toolResult.Content[0].(mcp.TextContent).Text), &decodedResult) + s.Run("returns yaml content", func() { + s.Nilf(err, "invalid tool result content %v", err) + s.Truef(strings.HasPrefix(toolResult.Content[0].(mcp.TextContent).Text, "# VirtualMachine stopped successfully"), + "Expected success message, got %v", toolResult.Content[0].(mcp.TextContent).Text) + s.Require().Lenf(decodedResult, 1, "invalid resource count, expected 1, got %v", len(decodedResult)) + s.Equal("test-vm-lifecycle", decodedResult[0].GetName(), "invalid resource name") + s.Equal("default", decodedResult[0].GetNamespace(), "invalid resource namespace") + s.Equal("Halted", + decodedResult[0].Object["spec"].(map[string]interface{})["runStrategy"].(string), + "expected runStrategy to be Halted after stop") + }) + }) + + s.Run("vm_lifecycle action=stop on already stopped VM (idempotent)", func() { + toolResult, err := s.CallTool("vm_lifecycle", map[string]interface{}{ + "name": "test-vm-lifecycle", + "namespace": "default", + "action": "stop", + }) + s.Run("no error", func() { + s.Nilf(err, "call tool failed %v", err) + s.Falsef(toolResult.IsError, "call tool failed") + }) + var decodedResult []unstructured.Unstructured + err = yaml.Unmarshal([]byte(toolResult.Content[0].(mcp.TextContent).Text), &decodedResult) + s.Run("returns yaml content showing VM was already stopped", func() { + s.Nilf(err, "invalid tool result content %v", err) + expectedPrefix := fmt.Sprintf("# VirtualMachine '%s' in namespace '%s' is already stopped", "test-vm-lifecycle", "default") + s.Truef(strings.HasPrefix(toolResult.Content[0].(mcp.TextContent).Text, expectedPrefix), + "Expected already stopped message, got %v", toolResult.Content[0].(mcp.TextContent).Text) + s.Require().Lenf(decodedResult, 1, "invalid resource count, expected 1, got %v", len(decodedResult)) + s.Equal("Halted", + decodedResult[0].Object["spec"].(map[string]interface{})["runStrategy"].(string), + "expected runStrategy to remain Halted") + }) + }) + + s.Run("vm_lifecycle action=restart on stopped VM", func() { + toolResult, err := s.CallTool("vm_lifecycle", map[string]interface{}{ + "name": "test-vm-lifecycle", + "namespace": "default", + "action": "restart", + }) + s.Run("no error", func() { + s.Nilf(err, "call tool failed %v", err) + s.Falsef(toolResult.IsError, "call tool failed") + }) + var decodedResult []unstructured.Unstructured + err = yaml.Unmarshal([]byte(toolResult.Content[0].(mcp.TextContent).Text), &decodedResult) + s.Run("returns yaml content showing VM restarted from stopped state", func() { + s.Nilf(err, "invalid tool result content %v", err) + s.Truef(strings.HasPrefix(toolResult.Content[0].(mcp.TextContent).Text, "# VirtualMachine restarted successfully"), + "Expected success message, got %v", toolResult.Content[0].(mcp.TextContent).Text) + s.Require().Lenf(decodedResult, 1, "invalid resource count, expected 1, got %v", len(decodedResult)) + s.Equal("Always", + decodedResult[0].Object["spec"].(map[string]interface{})["runStrategy"].(string), + "expected runStrategy to be Always after restart from Halted") + }) + }) + + s.Run("vm_lifecycle action=restart on running VM", func() { + toolResult, err := s.CallTool("vm_lifecycle", map[string]interface{}{ + "name": "test-vm-lifecycle", + "namespace": "default", + "action": "restart", + }) + s.Run("no error", func() { + s.Nilf(err, "call tool failed %v", err) + s.Falsef(toolResult.IsError, "call tool failed") + }) + var decodedResult []unstructured.Unstructured + err = yaml.Unmarshal([]byte(toolResult.Content[0].(mcp.TextContent).Text), &decodedResult) + s.Run("returns yaml content", func() { + s.Nilf(err, "invalid tool result content %v", err) + s.Truef(strings.HasPrefix(toolResult.Content[0].(mcp.TextContent).Text, "# VirtualMachine restarted successfully"), + "Expected success message, got %v", toolResult.Content[0].(mcp.TextContent).Text) + s.Require().Lenf(decodedResult, 1, "invalid resource count, expected 1, got %v", len(decodedResult)) + s.Equal("test-vm-lifecycle", decodedResult[0].GetName(), "invalid resource name") + s.Equal("default", decodedResult[0].GetNamespace(), "invalid resource namespace") + s.Equal("Always", + decodedResult[0].Object["spec"].(map[string]interface{})["runStrategy"].(string), + "expected runStrategy to be Always after restart") + }) + }) + + s.Run("vm_lifecycle on non-existent VM", func() { + for _, action := range []string{"start", "stop", "restart"} { + s.Run("action="+action, func() { + toolResult, err := s.CallTool("vm_lifecycle", map[string]interface{}{ + "name": "non-existent-vm", + "namespace": "default", + "action": action, + }) + s.Nilf(err, "call tool failed %v", err) + s.Truef(toolResult.IsError, "expected call tool to fail for non-existent VM") + s.Truef(strings.Contains(toolResult.Content[0].(mcp.TextContent).Text, "failed to get VirtualMachine"), + "Expected error message about VM not found, got %v", toolResult.Content[0].(mcp.TextContent).Text) + }) + } + }) +} + func TestKubevirt(t *testing.T) { suite.Run(t, new(KubevirtSuite)) } diff --git a/pkg/mcp/testdata/toolsets-kubevirt-tools.json b/pkg/mcp/testdata/toolsets-kubevirt-tools.json index 648f45f73..37167b77b 100644 --- a/pkg/mcp/testdata/toolsets-kubevirt-tools.json +++ b/pkg/mcp/testdata/toolsets-kubevirt-tools.json @@ -77,5 +77,41 @@ ] }, "name": "vm_create" + }, + { + "annotations": { + "title": "Virtual Machine: Lifecycle", + "destructiveHint": true, + "openWorldHint": false + }, + "description": "Manage VirtualMachine lifecycle: start, stop, or restart a VM", + "inputSchema": { + "type": "object", + "properties": { + "action": { + "description": "The lifecycle action to perform: 'start' (changes runStrategy to Always), 'stop' (changes runStrategy to Halted), or 'restart' (stops then starts the VM)", + "enum": [ + "start", + "stop", + "restart" + ], + "type": "string" + }, + "name": { + "description": "The name of the virtual machine", + "type": "string" + }, + "namespace": { + "description": "The namespace of the virtual machine", + "type": "string" + } + }, + "required": [ + "namespace", + "name", + "action" + ] + }, + "name": "vm_lifecycle" } ] diff --git a/pkg/toolsets/kubevirt/toolset.go b/pkg/toolsets/kubevirt/toolset.go index bec5fd208..33a860a4d 100644 --- a/pkg/toolsets/kubevirt/toolset.go +++ b/pkg/toolsets/kubevirt/toolset.go @@ -7,6 +7,7 @@ import ( internalk8s "github.com/containers/kubernetes-mcp-server/pkg/kubernetes" "github.com/containers/kubernetes-mcp-server/pkg/toolsets" vm_create "github.com/containers/kubernetes-mcp-server/pkg/toolsets/kubevirt/vm/create" + vm_lifecycle "github.com/containers/kubernetes-mcp-server/pkg/toolsets/kubevirt/vm/lifecycle" ) type Toolset struct{} @@ -24,6 +25,7 @@ func (t *Toolset) GetDescription() string { func (t *Toolset) GetTools(o internalk8s.Openshift) []api.ServerTool { return slices.Concat( vm_create.Tools(), + vm_lifecycle.Tools(), ) } diff --git a/pkg/toolsets/kubevirt/vm/lifecycle/tool.go b/pkg/toolsets/kubevirt/vm/lifecycle/tool.go new file mode 100644 index 000000000..aa81b7def --- /dev/null +++ b/pkg/toolsets/kubevirt/vm/lifecycle/tool.go @@ -0,0 +1,139 @@ +package lifecycle + +import ( + "fmt" + + "github.com/containers/kubernetes-mcp-server/pkg/api" + "github.com/containers/kubernetes-mcp-server/pkg/kubevirt" + "github.com/containers/kubernetes-mcp-server/pkg/output" + "github.com/google/jsonschema-go/jsonschema" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/utils/ptr" +) + +// Action represents the lifecycle action to perform on a VM +type Action string + +const ( + ActionStart Action = "start" + ActionStop Action = "stop" + ActionRestart Action = "restart" +) + +func Tools() []api.ServerTool { + return []api.ServerTool{ + { + Tool: api.Tool{ + Name: "vm_lifecycle", + Description: "Manage VirtualMachine lifecycle: start, stop, or restart a VM", + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "namespace": { + Type: "string", + Description: "The namespace of the virtual machine", + }, + "name": { + Type: "string", + Description: "The name of the virtual machine", + }, + "action": { + Type: "string", + Enum: []any{"start", "stop", "restart"}, + Description: "The lifecycle action to perform: 'start' (changes runStrategy to Always), 'stop' (changes runStrategy to Halted), or 'restart' (stops then starts the VM)", + }, + }, + Required: []string{"namespace", "name", "action"}, + }, + Annotations: api.ToolAnnotations{ + Title: "Virtual Machine: Lifecycle", + ReadOnlyHint: ptr.To(false), + DestructiveHint: ptr.To(true), + IdempotentHint: ptr.To(false), + OpenWorldHint: ptr.To(false), + }, + }, + Handler: lifecycle, + }, + } +} + +func lifecycle(params api.ToolHandlerParams) (*api.ToolCallResult, error) { + // Parse input parameters + namespace, err := getRequiredString(params, "namespace") + if err != nil { + return api.NewToolCallResult("", err), nil + } + + name, err := getRequiredString(params, "name") + if err != nil { + return api.NewToolCallResult("", err), nil + } + + action, err := getRequiredString(params, "action") + if err != nil { + return api.NewToolCallResult("", err), nil + } + + dynamicClient := params.AccessControlClientset().DynamicClient() + + var vm *unstructured.Unstructured + var message string + + switch Action(action) { + case ActionStart: + var wasStarted bool + vm, wasStarted, err = kubevirt.StartVM(params.Context, dynamicClient, namespace, name) + if err != nil { + return api.NewToolCallResult("", err), nil + } + if wasStarted { + message = "# VirtualMachine started successfully\n" + } else { + message = fmt.Sprintf("# VirtualMachine '%s' in namespace '%s' is already running\n", name, namespace) + } + + case ActionStop: + var wasRunning bool + vm, wasRunning, err = kubevirt.StopVM(params.Context, dynamicClient, namespace, name) + if err != nil { + return api.NewToolCallResult("", err), nil + } + if wasRunning { + message = "# VirtualMachine stopped successfully\n" + } else { + message = fmt.Sprintf("# VirtualMachine '%s' in namespace '%s' is already stopped\n", name, namespace) + } + + case ActionRestart: + vm, err = kubevirt.RestartVM(params.Context, dynamicClient, namespace, name) + if err != nil { + return api.NewToolCallResult("", err), nil + } + message = "# VirtualMachine restarted successfully\n" + + default: + return api.NewToolCallResult("", fmt.Errorf("invalid action '%s': must be one of 'start', 'stop', 'restart'", action)), nil + } + + // Format the output + marshalledYaml, err := output.MarshalYaml([]*unstructured.Unstructured{vm}) + if err != nil { + return api.NewToolCallResult("", fmt.Errorf("failed to marshal VirtualMachine: %w", err)), nil + } + + return api.NewToolCallResult(message+marshalledYaml, nil), nil +} + +func getRequiredString(params api.ToolHandlerParams, key string) (string, error) { + args := params.GetArguments() + val, ok := args[key] + if !ok { + return "", fmt.Errorf("%s parameter required", key) + } + str, ok := val.(string) + if !ok { + return "", fmt.Errorf("%s parameter must be a string", key) + } + return str, nil +}