Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for user-defined functions and call field in YAML format #444

Merged
merged 6 commits into from
May 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 101 additions & 12 deletions internal/dag/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ func buildAll(def *configDefinition, d *DAG, options BuildDAGOptions) error {
errList := &errors.ErrorList{}

errList.Add(buildLogDir(def, d))
errList.Add(assertFunctions(def.Functions))
errList.Add(buildSteps(def, d, options))
errList.Add(buildHandlers(def, d, options))
errList.Add(buildConfig(def, d))
Expand Down Expand Up @@ -179,28 +180,28 @@ func buildParams(def *configDefinition, d *DAG, options BuildDAGOptions) (err er
func buildHandlers(def *configDefinition, d *DAG, options BuildDAGOptions) (err error) {
if def.HandlerOn.Exit != nil {
def.HandlerOn.Exit.Name = constants.OnExit
if d.HandlerOn.Exit, err = buildStep(d.Env, def.HandlerOn.Exit, options); err != nil {
if d.HandlerOn.Exit, err = buildStep(d.Env, def.HandlerOn.Exit, def.Functions, options); err != nil {
return err
}
}

if def.HandlerOn.Success != nil {
def.HandlerOn.Success.Name = constants.OnSuccess
if d.HandlerOn.Success, err = buildStep(d.Env, def.HandlerOn.Success, options); err != nil {
if d.HandlerOn.Success, err = buildStep(d.Env, def.HandlerOn.Success, def.Functions, options); err != nil {
return
}
}

if def.HandlerOn.Failure != nil {
def.HandlerOn.Failure.Name = constants.OnFailure
if d.HandlerOn.Failure, err = buildStep(d.Env, def.HandlerOn.Failure, options); err != nil {
if d.HandlerOn.Failure, err = buildStep(d.Env, def.HandlerOn.Failure, def.Functions, options); err != nil {
return
}
}

if def.HandlerOn.Cancel != nil {
def.HandlerOn.Cancel.Name = constants.OnCancel
if d.HandlerOn.Cancel, err = buildStep(d.Env, def.HandlerOn.Cancel, options); err != nil {
if d.HandlerOn.Cancel, err = buildStep(d.Env, def.HandlerOn.Cancel, def.Functions, options); err != nil {
return
}
}
Expand Down Expand Up @@ -323,7 +324,7 @@ func loadVariables(strVariables interface{}, options BuildDAGOptions) (
func buildSteps(def *configDefinition, d *DAG, options BuildDAGOptions) error {
ret := []*Step{}
for _, stepDef := range def.Steps {
step, err := buildStep(d.Env, stepDef, options)
step, err := buildStep(d.Env, stepDef, def.Functions, options)
if err != nil {
return err
}
Expand All @@ -334,15 +335,47 @@ func buildSteps(def *configDefinition, d *DAG, options BuildDAGOptions) error {
return nil
}

func buildStep(variables []string, def *stepDef, options BuildDAGOptions) (*Step, error) {
if err := assertStepDef(def); err != nil {
func buildStep(variables []string, def *stepDef, funcs []*funcDef, options BuildDAGOptions) (*Step, error) {
if err := assertStepDef(def, funcs); err != nil {
return nil, err
}
step := &Step{}
step.Name = def.Name
step.Description = def.Description
step.CmdWithArgs = def.Command
step.Command, step.Args = utils.SplitCommand(step.CmdWithArgs, false)
if def.Call != nil {
step.Args = make([]string, 0, len(def.Call.Args))
passedArgs := map[string]string{}
for k, v := range def.Call.Args {
if strV, ok := v.(string); ok {
step.Args = append(step.Args, strV)
passedArgs[k] = strV
continue
}

if intV, ok := v.(int); ok {
strV := strconv.Itoa(intV)
step.Args = append(step.Args, strV)
passedArgs[k] = strV
continue
}

return nil, fmt.Errorf("args must be convertible to either int or string")
}

calledFuncDef := &funcDef{}
for _, funcDef := range funcs {
if funcDef.Name == def.Call.Function {
calledFuncDef = funcDef
break
}
}
step.Command = utils.RemoveParams(calledFuncDef.Command)
step.CmdWithArgs = utils.AssignValues(calledFuncDef.Command, passedArgs)
} else {
step.CmdWithArgs = def.Command
step.Command, step.Args = utils.SplitCommand(step.CmdWithArgs, false)
}

step.Script = def.Script
step.Stdout = expandEnv(def.Stdout, options)
step.Stderr = expandEnv(def.Stderr, options)
Expand Down Expand Up @@ -541,14 +574,70 @@ func parseSchedule(values []string) ([]*Schedule, error) {
return ret, nil
}

func assertStepDef(def *stepDef) error {
// only assert functions clause
func assertFunctions(funcs []*funcDef) error {
if funcs == nil {
return nil
}

nameMap := make(map[string]bool)
for _, funcDef := range funcs {
if _, exists := nameMap[funcDef.Name]; exists {
return fmt.Errorf("duplicate function")
}
nameMap[funcDef.Name] = true

definedParamNames := strings.Split(funcDef.Params, " ")
passedParamNames := utils.ExtractParamNames(funcDef.Command)
if len(definedParamNames) != len(passedParamNames) {
return fmt.Errorf("func params and args given to func command do not match")
}

for i := 0; i < len(definedParamNames); i++ {
if definedParamNames[i] != passedParamNames[i] {
return fmt.Errorf("func params and args given to func command do not match")
}
}
}

return nil
}

func assertStepDef(def *stepDef, funcs []*funcDef) error {
if def.Name == "" {
return fmt.Errorf("step name must be specified")
}
// TODO: Refactor the validation check for each executor.
if def.Executor == nil && def.Command == "" {
return fmt.Errorf("step command must be specified")
if def.Executor == nil && (def.Command == "" && def.Call == nil) {
return fmt.Errorf("either step command or step call must be specified if executor is nil")
}

if def.Call != nil {
calledFunc := def.Call.Function
calledFuncDef := &funcDef{}
for _, funcDef := range funcs {
if funcDef.Name == calledFunc {
calledFuncDef = funcDef
break
}
}
if calledFuncDef.Name == "" {
return fmt.Errorf("call must specify a functions that exists")
}

definedParamNames := strings.Split(calledFuncDef.Params, " ")
if len(def.Call.Args) != len(definedParamNames) {
return fmt.Errorf("the number of parameters defined in the function does not match the number of parameters given")
}

for _, paramName := range definedParamNames {
_, exists := def.Call.Args[paramName]
if !exists {
return fmt.Errorf("required parameter not found")
}
}
}

return nil
}

Expand Down
15 changes: 14 additions & 1 deletion internal/dag/definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ type configDefinition struct {
LogDir string
Env interface{}
HandlerOn handerOnDef
Functions []*funcDef
Steps []*stepDef
Smtp smtpConfigDef
MailOn *mailOnDef
Expand Down Expand Up @@ -52,7 +53,19 @@ type stepDef struct {
MailOnError bool
Preconditions []*conditionDef
SignalOnStop *string
Env string
Env string
Call *callFuncDef
}

type funcDef struct {
Name string
Params string
Command string
}

type callFuncDef struct {
Function string
Args map[string]interface{}
}

type continueOnDef struct {
Expand Down
33 changes: 33 additions & 0 deletions internal/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,39 @@ func SplitCommand(cmd string, parse bool) (program string, args []string) {
return vals[0], []string{}
}

// Assign values to command parameters
func AssignValues(command string, params map[string]string) string {
updatedCommand := command

for k, v := range params {
updatedCommand = strings.ReplaceAll(updatedCommand, fmt.Sprintf("$%v", k), v)
}

return updatedCommand
}

// Returns a command with parameters stripped from it.
func RemoveParams(command string) string {
paramRegex := regexp.MustCompile(`\$\w+`)

return paramRegex.ReplaceAllString(command, "")
}

// extracts a slice of parameter names by removing the '$' from the command string.
func ExtractParamNames(command string) []string {
words := strings.Fields(command)

var params []string
for _, word := range words {
if strings.HasPrefix(word, "$") {
paramName := strings.TrimPrefix(word, "$")
params = append(params, paramName)
}
}

return params
}

func UnescapeSpecialchars(str string) string {
repl := strings.NewReplacer(
`\\t`, `\t`,
Expand Down