99 changes: 99 additions & 0 deletions internal/pkg/object/command/ecs/container_options.go
@@ -0,0 +1,99 @@
package ecs

import (
"fmt"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
)

// ContainerOption is a function that modifies container overrides
type ContainerOption func(*executionContext) error

// ApplyContainerOptions applies a list of container options to container overrides
func ApplyContainerOptions(execCtx *executionContext, options ...ContainerOption) error {
for _, option := range options {
if err := option(execCtx); err != nil {
return err
}
}
return nil
}

// WithFileUploadScript wraps container commands with a script that downloads files from S3 to the container
func WithFileUploadScript(fileUploads []FileUpload, localDir string) ContainerOption {
Contributor: with file download script? The name doesn't match the description.

return func(execCtx *executionContext) error {
if len(fileUploads) == 0 {
return nil
}

for i := range execCtx.ContainerOverrides {
override := &execCtx.ContainerOverrides[i]

// Get the original command from override, or from task definition if not overridden
var originalCommand []string
if len(override.Command) > 0 {
originalCommand = override.Command
} else {
// Get command from task definition
for _, container := range execCtx.TaskDefinitionWrapper.TaskDefinition.ContainerDefinitions {
if aws.ToString(container.Name) == aws.ToString(override.Name) {
originalCommand = container.Command
break
}
}
}

if len(originalCommand) == 0 {
originalCommand = []string{}
}

// Generate the download wrapper script
wrapperScript := generateDownloadWrapperScript(fileUploads, localDir, originalCommand)

// Replace container command with wrapper script
override.Command = []string{"sh", "-c", wrapperScript}
}

return nil
}
}

const downloadScriptTemplate = `
Contributor: the general rule is to put constants at the top of the file. You could also consider moving this to a separate file.

set -e
apk update && apk add aws-cli
mkdir -p {{LOCAL_DIR}}
for s3_path in {{S3_PATHS}};do aws s3 cp "$s3_path" "{{LOCAL_DIR}}/$(basename "$s3_path")" 2>&1;done
exec {{CMD}}`

// generateDownloadWrapperScript generates a minimal bash script that downloads files from S3 and executes the original command
func generateDownloadWrapperScript(fileUploads []FileUpload, localDir string, originalCommand []string) string {
// Build S3 paths list for the for loop
var s3Paths []string
for _, upload := range fileUploads {
s3Paths = append(s3Paths, fmt.Sprintf(`"%s"`, upload.S3Destination))
}
s3PathsList := strings.Join(s3Paths, " ")

// Build command string
cmdStr := "wait"
if len(originalCommand) > 0 {
escapedCmd := make([]string, len(originalCommand))
for i, arg := range originalCommand {
escapedCmd[i] = fmt.Sprintf("'%s'", strings.ReplaceAll(arg, "'", "'\\''"))
}
cmdStr = strings.Join(escapedCmd, " ")
}

// Fill template
script := strings.ReplaceAll(downloadScriptTemplate, "{{LOCAL_DIR}}", localDir)
script = strings.ReplaceAll(script, "{{S3_PATHS}}", s3PathsList)
script = strings.ReplaceAll(script, "{{CMD}}", cmdStr)
return script
}
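
For illustration, a hypothetical call (the bucket, key, and command below are made up) and the script it would generate:

// generateDownloadWrapperScript(
//     []FileUpload{{S3Destination: "s3://my-bucket/jobs/config.json"}},
//     "/tmp/downloads",
//     []string{"python", "main.py"},
// )
// produces roughly:
//
//     set -e
//     apk update && apk add aws-cli
//     mkdir -p /tmp/downloads
//     for s3_path in "s3://my-bucket/jobs/config.json";do aws s3 cp "$s3_path" "/tmp/downloads/$(basename "$s3_path")" 2>&1;done
//     exec 'python' 'main.py'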

// Future container options can be added here as needed
Contributor: you can add it as a comment in the PR, but when the time comes people will figure out how to do that.

// Example:
// func WithEnvironmentVariables(envVars map[string]string) ContainerOption { ... }
// func WithHealthCheck(config HealthCheckConfig) ContainerOption { ... }
// func WithResourceLimits(limits ResourceLimits) ContainerOption { ... }
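
For illustration, a minimal sketch of one such future option (hypothetical, not part of this PR), assuming the ecs types package "github.com/aws/aws-sdk-go-v2/service/ecs/types" is imported and ContainerOverrides is a slice of types.ContainerOverride:

// WithEnvironmentVariables is a hypothetical option that appends environment
// variables to every container override. Sketch only.
func WithEnvironmentVariables(envVars map[string]string) ContainerOption {
	return func(execCtx *executionContext) error {
		for i := range execCtx.ContainerOverrides {
			override := &execCtx.ContainerOverrides[i]
			for name, value := range envVars {
				override.Environment = append(override.Environment, types.KeyValuePair{
					Name:  aws.String(name),
					Value: aws.String(value),
				})
			}
		}
		return nil
	}
}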
115 changes: 114 additions & 1 deletion internal/pkg/object/command/ecs/ecs.go
@@ -12,6 +12,7 @@ import (
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/ecs"
"github.com/aws/aws-sdk-go-v2/service/ecs/types"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/patterninc/heimdall/pkg/context"
"github.com/patterninc/heimdall/pkg/duration"
"github.com/patterninc/heimdall/pkg/object/cluster"
@@ -21,6 +22,12 @@ import (
"github.com/patterninc/heimdall/pkg/result/column"
)

// FileUpload represents configuration for uploading files from container to S3
type FileUpload struct {
Data string `yaml:"data,omitempty" json:"data,omitempty"` // File content as string
S3Destination string `yaml:"s3_destination,omitempty" json:"s3_destination,omitempty"` // S3 path (e.g., s3://bucket/path/filename)
}

// ECS command context structure
type ecsCommandContext struct {
TaskDefinitionTemplate string `yaml:"task_definition_template,omitempty" json:"task_definition_template,omitempty"`
@@ -29,6 +36,9 @@ type ecsCommandContext struct {
PollingInterval duration.Duration `yaml:"polling_interval,omitempty" json:"polling_interval,omitempty"`
Timeout duration.Duration `yaml:"timeout,omitempty" json:"timeout,omitempty"`
MaxFailCount int `yaml:"max_fail_count,omitempty" json:"max_fail_count,omitempty"` // max failures before giving up

// File upload configuration
FileUploads []FileUpload `yaml:"file_uploads,omitempty" json:"file_uploads,omitempty"`
Contributor (@hladush, Oct 20, 2025): make it a list of pointers.

}
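
For illustration, a sketch of the reviewer's pointer-slice suggestion, assuming downstream code is updated to dereference the pointers:

// FileUploads as a list of pointers (sketch of the review suggestion above)
FileUploads []*FileUpload `yaml:"file_uploads,omitempty" json:"file_uploads,omitempty"`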

// ECS cluster context structure
@@ -77,7 +87,11 @@ type executionContext struct {
Timeout duration.Duration `json:"timeout"`
MaxFailCount int `json:"max_fail_count"`

// File upload configuration
FileUploads []FileUpload `json:"file_uploads"`

ecsClient *ecs.Client
s3Client *s3.Client
taskDefARN *string
tasks map[string]*taskTracker
}
@@ -87,6 +101,8 @@ const (
defaultTaskTimeout = duration.Duration(1 * time.Hour)
defaultMaxFailCount = 1
defaultTaskCount = 1
defaultUploadTimeout = 30
defaultLocalDir = "/tmp/downloads"
startedByPrefix = "heimdall-job-"
errMaxFailCount = "task %s failed %d times (max: %d), giving up"
errPollingTimeout = "polling timed out for arns %v after %v"
@@ -130,6 +146,11 @@ func (e *ecsCommandContext) handler(r *plugin.Runtime, job *job.Job, cluster *cl
return err
}

// Upload files to S3 if configured
if err := execCtx.uploadFilesToS3(); err != nil {
return fmt.Errorf("failed to upload files to S3: %w", err)
Contributor: metrics were added to Heimdall, please start using them.

}

// Start tasks
if err := execCtx.startTasks(job.ID); err != nil {
return err
@@ -151,6 +172,9 @@

// prepare and register task definition with ECS
func (execCtx *executionContext) registerTaskDefinition() error {
// Use the original container definitions from the template
containerDefinitions := execCtx.TaskDefinitionWrapper.TaskDefinition.ContainerDefinitions

registerInput := &ecs.RegisterTaskDefinitionInput{
Family: aws.String(aws.ToString(execCtx.TaskDefinitionWrapper.TaskDefinition.Family)),
RequiresCompatibilities: []types.Compatibility{types.CompatibilityFargate},
@@ -159,7 +183,7 @@
Memory: aws.String(fmt.Sprintf("%d", execCtx.ClusterConfig.Memory)),
ExecutionRoleArn: aws.String(execCtx.ClusterConfig.ExecutionRoleARN),
TaskRoleArn: aws.String(execCtx.ClusterConfig.TaskRoleARN),
ContainerDefinitions: execCtx.TaskDefinitionWrapper.TaskDefinition.ContainerDefinitions,
ContainerDefinitions: containerDefinitions,
}

registerOutput, err := execCtx.ecsClient.RegisterTaskDefinition(ctx, registerInput)
@@ -350,6 +374,19 @@ func buildExecutionContext(commandCtx *ecsCommandContext, j *job.Job, c *cluster
return nil, err
}

// Apply container options to ContainerOverrides
var options []ContainerOption

// Add file upload script option if configured
if len(execCtx.FileUploads) > 0 {
options = append(options, WithFileUploadScript(execCtx.FileUploads, defaultLocalDir))
}

// Apply all options to container overrides
if err := ApplyContainerOptions(execCtx, options...); err != nil {
return nil, fmt.Errorf("failed to apply container options: %w", err)
}

// Validate the resolved configuration
if err := validateExecutionContext(execCtx); err != nil {
return nil, err
@@ -361,6 +398,7 @@
return nil, err
}
execCtx.ecsClient = ecs.NewFromConfig(cfg)
execCtx.s3Client = s3.NewFromConfig(cfg)

return execCtx, nil

@@ -373,6 +411,20 @@ func validateExecutionContext(ctx *executionContext) error {
return fmt.Errorf("task count (%d) needs to be greater than 0 and less than cluster max task count (%d)", ctx.TaskCount, ctx.ClusterConfig.MaxTaskCount)
}

// Validate file uploads configuration
for i, upload := range ctx.FileUploads {
if upload.Data == "" {
return fmt.Errorf("file upload %d: data is required", i)
}
if upload.S3Destination == "" {
return fmt.Errorf("file upload %d: s3_destination is required", i)
}
// Validate that destination is an S3 URI
if !strings.HasPrefix(upload.S3Destination, "s3://") {
return fmt.Errorf("file upload %d: s3_destination must be an S3 URI (s3://bucket/path/filename)", i)
}
}

return nil

}
@@ -537,6 +589,67 @@ func isTaskSuccessful(task types.Task, execCtx *executionContext) bool {

}

// uploadFilesToS3 uploads file data to S3 after task completion
func (execCtx *executionContext) uploadFilesToS3() error {
// Skip if no files to upload
if len(execCtx.FileUploads) == 0 {
return nil
}

for i, upload := range execCtx.FileUploads {
// Parse S3 URI (s3://bucket/key/filename)
bucket, key, err := parseS3URI(upload.S3Destination)
if err != nil {
return fmt.Errorf("failed to parse S3 URI for upload %d: %w", i, err)
}

// Upload file content to S3

input := &s3.PutObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
Body: strings.NewReader(upload.Data),
}

// Set timeout context with default timeout; cancel explicitly so each
// iteration releases its context instead of deferring until the function returns
uploadCtx, cancel := ct.WithTimeout(ctx, time.Duration(defaultUploadTimeout)*time.Second)

_, err = execCtx.s3Client.PutObject(uploadCtx, input)
cancel()
if err != nil {
return fmt.Errorf("failed to upload file %d to S3 (%s): %w", i, upload.S3Destination, err)
}

}

return nil
}

// parseS3URI parses an S3 URI into bucket and key components
func parseS3URI(s3URI string) (bucket, key string, err error) {
Contributor: can you move this function to another package? It can be reused in different places.

if !strings.HasPrefix(s3URI, "s3://") {
return "", "", fmt.Errorf("invalid S3 URI: must start with s3://")
}

// Remove s3:// prefix
path := strings.TrimPrefix(s3URI, "s3://")

// Split into bucket and key
parts := strings.SplitN(path, "/", 2)
if len(parts) < 2 {
return "", "", fmt.Errorf("invalid S3 URI: must include bucket and key (s3://bucket/key)")
}

bucket = parts[0]
key = parts[1]

if bucket == "" || key == "" {
return "", "", fmt.Errorf("invalid S3 URI: bucket and key cannot be empty")
}

return bucket, key, nil
}
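
For illustration, a sketch of the reviewer's suggestion: the parser pulled into a small shared package (the package name and path here are hypothetical) so other commands can reuse it:

// Package s3uri holds helpers for working with s3:// URIs (hypothetical location).
package s3uri

import (
	"fmt"
	"strings"
)

// Parse splits an s3://bucket/key URI into its bucket and key components.
func Parse(uri string) (bucket, key string, err error) {
	if !strings.HasPrefix(uri, "s3://") {
		return "", "", fmt.Errorf("invalid S3 URI: must start with s3://")
	}
	parts := strings.SplitN(strings.TrimPrefix(uri, "s3://"), "/", 2)
	if len(parts) < 2 || parts[0] == "" || parts[1] == "" {
		return "", "", fmt.Errorf("invalid S3 URI: must include bucket and key (s3://bucket/key)")
	}
	return parts[0], parts[1], nil
}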

// storeResults builds and stores the final result for the job.
func storeResults(execCtx *executionContext, j *job.Job) error {
