diff --git a/pkg/ddc/alluxio/operations/base.go b/pkg/ddc/alluxio/operations/base.go index 111a7a8e078..9fa4107abb8 100644 --- a/pkg/ddc/alluxio/operations/base.go +++ b/pkg/ddc/alluxio/operations/base.go @@ -515,6 +515,11 @@ func (a AlluxioFileUtils) exec(command []string, verbose bool) (stdout string, s // execWithoutTimeout func (a AlluxioFileUtils) execWithoutTimeout(command []string, verbose bool) (stdout string, stderr string, err error) { + err = utils.ValidateCommandSlice(command) + if err != nil { + return + } + stdout, stderr, err = kubeclient.ExecCommandInContainer(a.podName, a.container, a.namespace, command) if err != nil { a.log.Info("Stdout", "Command", command, "Stdout", stdout) diff --git a/pkg/ddc/goosefs/hcfs_test.go b/pkg/ddc/goosefs/hcfs_test.go index 40a0ffb0ad4..66a44aab3a3 100644 --- a/pkg/ddc/goosefs/hcfs_test.go +++ b/pkg/ddc/goosefs/hcfs_test.go @@ -97,7 +97,10 @@ func TestGetHCFSStatus(t *testing.T) { t.Fatal(err.Error()) } engine := newGooseFSEngineHCFS(fakeClient, "hbase", "fluid") - out, _ := engine.GetHCFSStatus() + out, err := engine.GetHCFSStatus() + if err != nil { + t.Fatal(err.Error()) + } wrappedUnhook() status := &v1alpha1.HCFSStatus{ Endpoint: "goosefs://hbase-master-0.fluid:2333", diff --git a/pkg/ddc/goosefs/operations/base.go b/pkg/ddc/goosefs/operations/base.go index dfdb518366e..905fb01bc94 100644 --- a/pkg/ddc/goosefs/operations/base.go +++ b/pkg/ddc/goosefs/operations/base.go @@ -25,6 +25,7 @@ import ( "strings" "time" + "github.com/fluid-cloudnative/fluid/pkg/utils" "github.com/fluid-cloudnative/fluid/pkg/utils/kubeclient" "github.com/go-logr/logr" ) @@ -501,6 +502,11 @@ func (a GooseFSFileUtils) exec(command []string, verbose bool) (stdout string, s // execWithoutTimeout func (a GooseFSFileUtils) execWithoutTimeout(command []string, verbose bool) (stdout string, stderr string, err error) { + err = utils.ValidateCommandSlice(command) + if err != nil { + return + } + stdout, stderr, err = kubeclient.ExecCommandInContainer(a.podName, a.container, a.namespace, command) if err != nil { a.log.Info("Stdout", "Command", command, "Stdout", stdout) diff --git a/pkg/ddc/jindo/operations/base.go b/pkg/ddc/jindo/operations/base.go index 70ee279f994..f4419f0c649 100644 --- a/pkg/ddc/jindo/operations/base.go +++ b/pkg/ddc/jindo/operations/base.go @@ -22,6 +22,7 @@ import ( "strings" "time" + "github.com/fluid-cloudnative/fluid/pkg/utils" "github.com/fluid-cloudnative/fluid/pkg/utils/kubeclient" "github.com/go-logr/logr" ) @@ -66,6 +67,11 @@ func (a JindoFileUtils) exec(command []string, verbose bool) (stdout string, std // execWithoutTimeout func (a JindoFileUtils) execWithoutTimeout(command []string, verbose bool) (stdout string, stderr string, err error) { + err = utils.ValidateCommandSlice(command) + if err != nil { + return + } + stdout, stderr, err = kubeclient.ExecCommandInContainer(a.podName, a.container, a.namespace, command) if err != nil { a.log.Info("Stdout", "Command", command, "Stdout", stdout) diff --git a/pkg/ddc/jindofsx/operations/base.go b/pkg/ddc/jindofsx/operations/base.go index 1d19bfe7291..28051041e5e 100644 --- a/pkg/ddc/jindofsx/operations/base.go +++ b/pkg/ddc/jindofsx/operations/base.go @@ -22,6 +22,7 @@ import ( "strings" "time" + "github.com/fluid-cloudnative/fluid/pkg/utils" "github.com/fluid-cloudnative/fluid/pkg/utils/kubeclient" "github.com/go-logr/logr" ) @@ -87,6 +88,11 @@ func (a JindoFileUtils) execWithTimeOut(command []string, verbose bool, second i // execWithoutTimeout func (a JindoFileUtils) execWithoutTimeout(command []string, verbose bool) (stdout string, stderr string, err error) { + err = utils.ValidateCommandSlice(command) + if err != nil { + return + } + stdout, stderr, err = kubeclient.ExecCommandInContainer(a.podName, a.container, a.namespace, command) if err != nil { a.log.Info("Stdout", "Command", command, "Stdout", stdout) diff --git a/pkg/ddc/juicefs/data_load_test.go b/pkg/ddc/juicefs/data_load_test.go index d6929fcfb7e..88f4a13dda1 100644 --- a/pkg/ddc/juicefs/data_load_test.go +++ b/pkg/ddc/juicefs/data_load_test.go @@ -18,13 +18,14 @@ package juicefs import ( "errors" - "github.com/fluid-cloudnative/fluid/pkg/common" "os" "path/filepath" "reflect" "strings" "testing" + "github.com/fluid-cloudnative/fluid/pkg/common" + "github.com/brahma-adshonor/gohook" appsv1 "k8s.io/api/apps/v1" v1 "k8s.io/api/core/v1" @@ -621,7 +622,7 @@ func TestJuiceFSEngine_CheckExistenceOfPath(t *testing.T) { } notExist, err = engine.CheckExistenceOfPath(targetDataload) if !(err == nil && notExist == false) { - t.Errorf("fail to exec the function") + t.Errorf("fail to exec the function due to %v", err) } wrappedUnhook() } diff --git a/pkg/ddc/juicefs/operations/base.go b/pkg/ddc/juicefs/operations/base.go index 5e50b48884e..d32a8850d35 100644 --- a/pkg/ddc/juicefs/operations/base.go +++ b/pkg/ddc/juicefs/operations/base.go @@ -26,6 +26,7 @@ import ( "github.com/go-logr/logr" + "github.com/fluid-cloudnative/fluid/pkg/utils" "github.com/fluid-cloudnative/fluid/pkg/utils/kubeclient" ) @@ -325,6 +326,12 @@ func (j JuiceFileUtils) exec(command []string) (stdout string, stderr string, er // execWithoutTimeout func (j JuiceFileUtils) execWithoutTimeout(command []string) (stdout string, stderr string, err error) { + // validate the pipe command with white list + err = utils.ValidateCommandSlice(command) + if err != nil { + return + } + stdout, stderr, err = kubeclient.ExecCommandInContainer(j.podName, j.container, j.namespace, command) if err != nil { j.log.Info("Stdout", "Command", command, "Stdout", stdout) diff --git a/pkg/ddc/thin/operations/base.go b/pkg/ddc/thin/operations/base.go index 9359afc5319..21b4c70fb72 100644 --- a/pkg/ddc/thin/operations/base.go +++ b/pkg/ddc/thin/operations/base.go @@ -23,6 +23,7 @@ import ( "strings" "time" + "github.com/fluid-cloudnative/fluid/pkg/utils" "github.com/fluid-cloudnative/fluid/pkg/utils/kubeclient" "github.com/go-logr/logr" ) @@ -155,6 +156,12 @@ func (t ThinFileUtils) exec(command []string, verbose bool) (stdout string, stde // execWithoutTimeout func (t ThinFileUtils) execWithoutTimeout(command []string, verbose bool) (stdout string, stderr string, err error) { + // validate the pipe command with white list + err = utils.ValidateCommandSlice(command) + if err != nil { + return + } + stdout, stderr, err = kubeclient.ExecCommandInContainer(t.podName, t.container, t.namespace, command) if err != nil { t.log.Info("Stdout", "Command", command, "Stdout", stdout) diff --git a/pkg/utils/helm/helm.go b/pkg/utils/helm/helm.go index 4f8d5acb77d..f07870eace7 100644 --- a/pkg/utils/helm/helm.go +++ b/pkg/utils/helm/helm.go @@ -132,7 +132,11 @@ func GetChartVersion(chart string) (version string, err error) { "|", "grep", "version:"} log.V(1).Info("Exec bash -c", "args", args) - cmd := exec.Command("bash", "-c", strings.Join(args, " ")) + // cmd := exec.Command("bash", "-c", strings.Join(args, " ")) + cmd, err := utils.PipeCommand("bash", "-c", strings.Join(args, " ")) + if err != nil { + return "", err + } out, err := cmd.Output() if err != nil { return "", err diff --git a/pkg/utils/helm/helm_test.go b/pkg/utils/helm/helm_test.go index 5d4068cf275..30d82630d3d 100644 --- a/pkg/utils/helm/helm_test.go +++ b/pkg/utils/helm/helm_test.go @@ -109,7 +109,7 @@ func TestGenerateHelmTemplate(t *testing.T) { func TestGetChartVersion(t *testing.T) { LookPathCommon := func(file string) (string, error) { - return "test-path", nil + return "helm", nil } LookPathErr := func(file string) (string, error) { return "", errors.New("fail to run the command") @@ -190,10 +190,10 @@ func TestGetChartVersion(t *testing.T) { } version, err := GetChartVersion("fluid:v0.6.0") if err != nil { - t.Errorf("fail to exec the function") + t.Errorf("fail to exec the function due to %v", err) } if version != "v0.6.0" { - t.Errorf("fail to get the version of the helm") + t.Errorf("fail to get the version of the helm due to %v", err) } wrappedUnhookOutput() wrappedUnhookStat() diff --git a/pkg/utils/shell_pipes.go b/pkg/utils/shell_pipes.go new file mode 100644 index 00000000000..f79f016aa37 --- /dev/null +++ b/pkg/utils/shell_pipes.go @@ -0,0 +1,156 @@ +/* +Copyright 2024 The Fluid Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package utils + +import ( + "fmt" + "os/exec" + "strings" +) + +func PipeCommand(name string, arg ...string) (cmd *exec.Cmd, err error) { + // prepare the slice for ValidatePipeCommandSlice + var commands []string + commands = append(commands, name) + commands = append(commands, arg...) + + // validate commands + err = ValidatePipeCommandSlice(commands) + if err != nil { + return nil, err + } + + return exec.Command(name, arg...), nil +} + +// ValidateCommandSlice takes in a slice of shell commands and returns an error if any are invalid. +// The function looks specifically for pipe commands (i.e., commands that contain a '|'). +// If a pipe command is found in the slice, ValidatePipeCommandSlice is called for further validation. +func ValidateCommandSlice(shellCommandSlice []string) (err error) { + isPossiblePipeCommand := false + + for _, command := range shellCommandSlice { + if strings.Contains(command, "|") { + isPossiblePipeCommand = true + } + } + + if isPossiblePipeCommand { + err = ValidatePipeCommandSlice(shellCommandSlice) + } // else { + // // Todo: need handle no PossiblePipeCommand + // } + return +} + +func ValidatePipeCommandSlice(shellCommandSlice []string) (err error) { + // Make sure the shell command is allowed + var AllowedShellCommands = map[string]bool{ + "bash -c": true, + "sh -c": true, + } + + // check if shellCommandSlice has enough arguments + if len(shellCommandSlice) < 3 { + return fmt.Errorf("insufficient arguments. Expected at least 3, received %d", len(shellCommandSlice)) + } + // We assume -c always directly follows the shell command + shellCommand := strings.Join(strings.Fields(shellCommandSlice[0]+" "+shellCommandSlice[1]), " ") + if _, ok := AllowedShellCommands[shellCommand]; !ok { + return fmt.Errorf("unknown shell command: %s", shellCommand) + } + + for _, command := range shellCommandSlice[2:] { + if err := ValidateShellPipeString(command); err != nil { + return err + } + } + return +} + +// ValidateShellPipeString function checks whether the input command string is safe to execute. +// It checks whether all parts of a pipeline command start with any command prefixes defined in AllowedCommands +// It also checks for any illegal sequences that may lead to command injection attack. +func ValidateShellPipeString(command string) error { + // Define illegal sequences that may lead to command injection attack + illegalSequences := []string{"&", ";", "$", "'", "`", "(", ")", "||", ">>"} + // Separate parts of pipeline command + pipelineCommands := strings.Split(command, "|") + + // AllowedCommands is a global map that contains all allowed command prefixes. + var AllowedCommands = map[string]bool{ + "ls": false, + "df": false, + "mount": false, + "alluxio": false, + "goosefs": false, + "kubectl": false, + "helm": false, + } + + // AllowedPipeCommands is a map that contains all allowed pipe command prefixes. + var allowedPipeCommands = map[string]bool{ + "grep": false, // false means partial match + "wc -l": true, // true means full match (wc -l is exactly the allowed command) + // Add more commands as you see fit + } + + // Check each part of pipeline command + for i, cmd := range pipelineCommands { + cmd = strings.Join( + strings.Fields( + strings.TrimSpace(cmd)), " ") + + if i > 0 { + // Check whether command starts with any allowed command prefix + validCmd := isValidCommand(cmd, allowedPipeCommands) + + // If none of the allowed command prefix is found, throw error + if !validCmd { + return fmt.Errorf("full pipeline command not supported: part %d contains unsupported command '%s', the whole command %s", i+1, cmd, command) + } + } else { + validCmd := isValidCommand(cmd, AllowedCommands) + // If none of the allowed command prefix is found, throw error + if !validCmd { + return fmt.Errorf("full pipeline command not supported: part %d contains unsupported command '%s', the whole command %s", i+1, cmd, command) + } + } + + // Check for illegal sequences in command + for _, illegalSeq := range illegalSequences { + if strings.Contains(cmd, illegalSeq) { + return fmt.Errorf("unsafe pipeline command %s, illegal sequence detected: %s in part %d: '%s'", command, illegalSeq, i+1, cmd) + } + } + } + + // If no error found, return nil + return nil +} + +// Defining a function to check if the command is valid +func isValidCommand(cmd string, allowedCommands map[string]bool) bool { + for cmdPrefix, exactMatch := range allowedCommands { + if exactMatch && cmd == cmdPrefix { + return true + } else if !exactMatch && strings.HasPrefix(cmd, cmdPrefix) { + return true + } + } + return false +} diff --git a/pkg/utils/shell_pipes_test.go b/pkg/utils/shell_pipes_test.go new file mode 100644 index 00000000000..430f86c028d --- /dev/null +++ b/pkg/utils/shell_pipes_test.go @@ -0,0 +1,132 @@ +package utils + +import ( + "os/exec" + "reflect" + "testing" +) + +/* +Copyright 2024 The Fluid Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +func TestValidateShellPipeString(t *testing.T) { + type args struct { + command string + } + tests := []struct { + name string + args args + wantErr bool + }{ + {name: "valid command with grep", args: args{command: "echo hello world | grep hello"}, wantErr: true}, + {name: "valid command with wc -l", args: args{command: "kubectl hello world | wc -l"}, wantErr: false}, + {name: "invalid command with xyz", args: args{command: "echo hello world | xyz"}, wantErr: true}, + {name: "illegal sequence in command with &", args: args{command: "echo hello world & echo y"}, wantErr: true}, + {name: "illegal sequence in command with ;", args: args{command: "ls ; echo y"}, wantErr: true}, + {name: "command with $", args: args{command: "kubectl $HOME"}, wantErr: true}, + {name: "command with absolute path", args: args{command: "ls /etc"}, wantErr: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := ValidateShellPipeString(tt.args.command); (err != nil) != tt.wantErr { + t.Errorf("Testcase '%s' ValidateShellPipeString() error = %v, wantErr %v", tt.name, err, tt.wantErr) + } + }) + } +} + +func TestPipeCommand(t *testing.T) { + type args struct { + name string + arg []string + } + tests := []struct { + name string + args args + wantCmd *exec.Cmd + wantErr bool + }{ + {name: "valid simple command", args: args{name: "bash", arg: []string{"-c", "ls"}}, wantCmd: exec.Command("bash", "-c", "ls"), wantErr: false}, + {name: "insufficient arguments", args: args{name: "bash", arg: []string{"-c"}}, wantCmd: nil, wantErr: true}, + {name: "unknown shell command", args: args{name: "zsh", arg: []string{"-c", "ls"}}, wantCmd: nil, wantErr: true}, + {name: "valid piped command", args: args{name: "bash", arg: []string{"-c", "ls | grep something"}}, wantCmd: exec.Command("bash", "-c", "ls | grep something"), wantErr: false}, + {name: "invalid piped command", args: args{name: "bash", arg: []string{"-c", "ls | random-command"}}, wantCmd: nil, wantErr: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotCmd, err := PipeCommand(tt.args.name, tt.args.arg...) + if (err != nil) != tt.wantErr { + t.Errorf("Testcase '%s': PipeCommand() error = %v, wantErr %v", tt.name, err, tt.wantErr) + return + } + if gotCmd != nil && !reflect.DeepEqual(gotCmd.Path, tt.wantCmd.Path) { + t.Errorf("Testcase '%s': PipeCommand() = %v, want %v", tt.name, gotCmd, tt.wantCmd) + } + if gotCmd != nil && !reflect.DeepEqual(gotCmd.Args, tt.wantCmd.Args) { + t.Errorf("Testcase '%s': PipeCommand() = %v, want %v", tt.name, gotCmd, tt.wantCmd) + } + }) + } +} + +func TestValidatePipeCommandSlice(t *testing.T) { + type args struct { + shellCommandSlice []string + } + tests := []struct { + name string + args args + wantErr bool + }{ + {name: "valid bash command", args: args{shellCommandSlice: []string{"bash", "-c", "ls"}}, wantErr: false}, + {name: "valid sh command", args: args{shellCommandSlice: []string{"sh", "-c", "ls"}}, wantErr: false}, + {name: "unknown shell command", args: args{shellCommandSlice: []string{"zsh", "-c", "ls"}}, wantErr: true}, + {name: "invalid bash command", args: args{shellCommandSlice: []string{"bash", "-c", "wrong_command"}}, wantErr: true}, + {name: "insufficient arguments", args: args{shellCommandSlice: []string{"bash", "-c"}}, wantErr: true}, + {name: "empty command slice", args: args{shellCommandSlice: []string{}}, wantErr: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := ValidatePipeCommandSlice(tt.args.shellCommandSlice); (err != nil) != tt.wantErr { + t.Errorf("Testcase '%s': ValidatePipeCommandSlice() error = %v, wantErr %v", tt.name, err, tt.wantErr) + } + }) + } +} + +func TestIsValidCommand(t *testing.T) { + type args struct { + cmd string + allowedCommands map[string]bool + } + tests := []struct { + name string + args args + want bool + }{ + {name: "valid bash command", args: args{cmd: "bash", allowedCommands: map[string]bool{"bash": true}}, want: true}, + {name: "valid sh command", args: args{cmd: "sh", allowedCommands: map[string]bool{"bash": true, "sh": true}}, want: true}, + {name: "invalid zsh command", args: args{cmd: "zsh", allowedCommands: map[string]bool{"bash": true}}, want: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isValidCommand(tt.args.cmd, tt.args.allowedCommands); got != tt.want { + t.Errorf("Testcase '%s': isValidCommand() = %v, want %v", tt.name, got, tt.want) + } + }) + } +} \ No newline at end of file