diff --git a/lib/utils/envutils/environment.go b/lib/utils/envutils/environment.go index d7e01f46cd575..141ffbb390bc4 100644 --- a/lib/utils/envutils/environment.go +++ b/lib/utils/envutils/environment.go @@ -17,6 +17,7 @@ package envutils import ( "bufio" "fmt" + "io" "os" "strings" @@ -39,10 +40,14 @@ func ReadEnvironmentFile(filename string) ([]string, error) { } defer file.Close() + return readEnvironment(file) +} + +func readEnvironment(r io.Reader) ([]string, error) { var lineno int env := &SafeEnv{} - scanner := bufio.NewScanner(file) + scanner := bufio.NewScanner(r) for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) @@ -50,7 +55,7 @@ func ReadEnvironmentFile(filename string) ([]string, error) { // https://github.com/openssh/openssh-portable/blob/master/session.c#L873-L874 lineno = lineno + 1 if lineno > teleport.MaxEnvironmentFileLines { - log.Warnf("Too many lines in environment file %v, returning first %v lines", filename, teleport.MaxEnvironmentFileLines) + log.Warnf("Too many lines in environment file, returning first %v lines", teleport.MaxEnvironmentFileLines) return *env, nil } @@ -62,7 +67,7 @@ func ReadEnvironmentFile(filename string) ([]string, error) { // split on first =, if not found, log it and continue idx := strings.Index(line, "=") if idx == -1 { - log.Debugf("Bad line %v while reading %v: no = separator found", lineno, filename) + log.Debugf("Bad line %v while reading environment file: no = separator found", lineno) continue } @@ -70,7 +75,7 @@ func ReadEnvironmentFile(filename string) ([]string, error) { key := line[:idx] value := line[idx+1:] if strings.TrimSpace(key) == "" { - log.Debugf("Bad line %v while reading %v: key without name", lineno, filename) + log.Debugf("Bad line %v while reading environment file: key without name", lineno) continue } @@ -78,9 +83,8 @@ func ReadEnvironmentFile(filename string) ([]string, error) { env.AddTrusted(key, value) } - err = scanner.Err() - if err != nil { - log.Warnf("Unable to read environment file %v: %v, skipping", filename, err) + if err := scanner.Err(); err != nil { + log.Warnf("Unable to read environment file: %v", err) return []string{}, nil } diff --git a/lib/utils/envutils/environment_test.go b/lib/utils/envutils/environment_test.go index 38b856fb2d014..12758f07dc113 100644 --- a/lib/utils/envutils/environment_test.go +++ b/lib/utils/envutils/environment_test.go @@ -16,14 +16,13 @@ limitations under the License. package envutils import ( - "os" + "bytes" "testing" - "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" ) -func TestReadEnvironmentFile(t *testing.T) { +func TestReadEnvironment(t *testing.T) { t.Parallel() // contents of environment file @@ -40,21 +39,11 @@ bar=foo LD_PRELOAD=attack `) - // create a temp file with an environment in it - f, err := os.CreateTemp(t.TempDir(), "teleport-environment-") - require.NoError(t, err) - defer os.Remove(f.Name()) - _, err = f.Write(rawenv) - require.NoError(t, err) - err = f.Close() - require.NoError(t, err) - - // read in the temp file - env, err := ReadEnvironmentFile(f.Name()) + env, err := readEnvironment(bytes.NewReader(rawenv)) require.NoError(t, err) // check we parsed it correctly - require.Empty(t, cmp.Diff(env, []string{"foo=bar", "foo=bar=baz", "foo=", "bar=foo"})) + require.Equal(t, []string{"foo=bar", "foo=bar=baz", "foo=", "bar=foo"}, env) } func TestSafeEnvAdd(t *testing.T) {