diff --git a/cmd/dry.go b/cmd/dry.go index e370d9dca..eb97a28dc 100644 --- a/cmd/dry.go +++ b/cmd/dry.go @@ -35,7 +35,7 @@ func dryCmd() *cobra.Command { os.Exit(1) } - workflow, err := dag.Load(cfg.BaseConfig, args[0], params) + workflow, err := dag.Load(cfg.BaseConfig, args[0], removeQuotes(params)) if err != nil { initLogger.Error("Workflow load failed", "error", err, "file", args[0]) os.Exit(1) diff --git a/cmd/start.go b/cmd/start.go index 6754c71f0..3aab8a4dd 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -42,7 +42,7 @@ func startCmd() *cobra.Command { os.Exit(1) } - workflow, err := dag.Load(cfg.BaseConfig, args[0], params) + workflow, err := dag.Load(cfg.BaseConfig, args[0], removeQuotes(params)) if err != nil { initLogger.Error("Workflow load failed", "error", err, "file", args[0]) os.Exit(1) @@ -110,3 +110,11 @@ func startCmd() *cobra.Command { cmd.Flags().BoolP("quiet", "q", false, "suppress output") return cmd } + +// removeQuotes removes the surrounding quotes from the string. +func removeQuotes(s string) string { + if len(s) > 1 && s[0] == '"' && s[len(s)-1] == '"' { + return s[1 : len(s)-1] + } + return s +} diff --git a/internal/dag/builder_test.go b/internal/dag/builder_test.go index fe9b80ea4..ec504254e 100644 --- a/internal/dag/builder_test.go +++ b/internal/dag/builder_test.go @@ -149,6 +149,14 @@ func TestBuilder_BuildParams(t *testing.T) { "2": "x", }, }, + { + name: "QuotedParams", + params: `x="1" y="2"`, + expected: map[string]string{ + "x": "1", + "y": "2", + }, + }, { name: "ComplexParams", params: "first P1=foo P2=${FOO} P3=`/bin/echo BAR` X=bar Y=${P1} Z=\"A B C\"", @@ -173,11 +181,17 @@ func TestBuilder_BuildParams(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - dg, err := unmarshalData([]byte(fmt.Sprintf(` -env: - - %s + var data string + if tt.env != "" { + data = fmt.Sprintf(`env: +- %s params: %s - `, tt.env, tt.params))) +`, tt.env, tt.params) + } else { + data = fmt.Sprintf(`params: %s +`, tt.params) + } + dg, err := unmarshalData([]byte(data)) require.NoError(t, err) def, err := decode(dg) diff --git a/internal/test/setup.go b/internal/test/setup.go index 86e7b4c1e..3033bcf93 100644 --- a/internal/test/setup.go +++ b/internal/test/setup.go @@ -56,7 +56,8 @@ func SetupTest(t *testing.T) Setup { err := os.Setenv("HOME", tmpDir) require.NoError(t, err) - viper.AddConfigPath(config.ConfigDir) + configDir := filepath.Join(tmpDir, "config") + viper.AddConfigPath(configDir) viper.SetConfigType("yaml") viper.SetConfigName("admin")