Skip to content
This repository has been archived by the owner on Jan 8, 2024. It is now read-only.

Commit

Permalink
Merge pull request #1872 from hashicorp/aws-ecs-validate-memory-cpu
Browse files Browse the repository at this point in the history
AWS ECS: validate memory + cpu pairs
  • Loading branch information
catsby authored Jul 20, 2021
2 parents 77dfeb2 + f645a8e commit 9c32100
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 54 deletions.
7 changes: 7 additions & 0 deletions .changelog/1872.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
```release-note:bug
serverinstall/ecs: validate memory and cpu values
```

```release-note:bug
plugin/aws/ecs: validate memory and cpu values
```
63 changes: 10 additions & 53 deletions builtin/aws/ecs/platform.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ func (p *Platform) ConfigSet(config interface{}) error {
validation.Empty.When(alb.CertificateId != "" || alb.ZoneId != "" || alb.FQDN != "").Error("listener_arn can not be used with other options"),
),
))

if err != nil {
return err
}
Expand Down Expand Up @@ -342,7 +341,6 @@ func defaultSubnets(ctx context.Context, sess *session.Session) ([]*string, erro
},
},
})

if err != nil {
return nil, err
}
Expand All @@ -367,7 +365,6 @@ func (p *Platform) SetupCluster(ctx context.Context, s LifecycleStatus, sess *se
desc, err := ecsSvc.DescribeClusters(&ecs.DescribeClustersInput{
Clusters: []*string{aws.String(cluster)},
})

if err != nil {
return "", err
}
Expand Down Expand Up @@ -529,7 +526,6 @@ func (p *Platform) SetupLogs(ctx context.Context, s LifecycleStatus, L hclog.Log
Limit: aws.Int64(1),
LogGroupNamePrefix: aws.String(logGroup),
})

if err != nil {
return "", err
}
Expand All @@ -549,7 +545,6 @@ func (p *Platform) SetupLogs(ctx context.Context, s LifecycleStatus, L hclog.Log
}

return logGroup, nil

}

func createSG(
Expand All @@ -570,7 +565,6 @@ func createSG(
},
},
})

if err != nil {
return nil, err
}
Expand Down Expand Up @@ -780,7 +774,6 @@ func createALB(
},
},
})

if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -1063,58 +1056,26 @@ func (p *Platform) Launch(
L.Debug("registering task definition", "id", id)

var cpuShares int
family := "waypoint-" + app.App

s.Status("Registering Task definition: %s", family)

runtime := aws.String("FARGATE")
if p.config.EC2Cluster {
runtime = aws.String("EC2")
cpuShares = p.config.CPU
} else {
if p.config.Memory == 0 {
return nil, fmt.Errorf("Memory value required for fargate")
if err := utils.ValidateEcsMemCPUPair(p.config.Memory, p.config.CPU); err != nil {
return nil, err
}
cpuValues, ok := fargateResources[p.config.Memory]
if !ok {
var (
allValues []int
goodValues []string
)

for k := range fargateResources {
allValues = append(allValues, k)
}

sort.Ints(allValues)

for _, k := range allValues {
goodValues = append(goodValues, strconv.Itoa(k))
}
cpuValues := fargateResources[p.config.Memory]

return nil, fmt.Errorf("Invalid memory value: %d (valid values: %s)",
p.config.Memory, strings.Join(goodValues, ", "))
}

if p.config.CPU == 0 {
// at this point we know that config.CPU is either 0, or a valid value
// for the memory given
cpuShares = p.config.CPU
if cpuShares == 0 {
cpuShares = cpuValues[0]
} else {
var (
valid bool
goodValues []string
)

for _, c := range cpuValues {
goodValues = append(goodValues, strconv.Itoa(c))
if c == p.config.CPU {
valid = true
break
}
}

if !valid {
return nil, fmt.Errorf("Invalid cpu value: %d (valid values: %s)",
p.config.Memory, strings.Join(goodValues, ", "))
}

cpuShares = p.config.CPU
}
}

Expand All @@ -1125,10 +1086,6 @@ func (p *Platform) Launch(
}
mems := strconv.Itoa(p.config.Memory)

family := "waypoint-" + app.App

s.Status("Registering Task definition: %s", family)

containerDefinitions := append([]*ecs.ContainerDefinition{&def}, additionalContainers...)

registerTaskDefinitionInput := ecs.RegisterTaskDefinitionInput{
Expand Down
76 changes: 76 additions & 0 deletions builtin/aws/utils/validations.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package utils

import (
"fmt"
"sort"
"strconv"
"strings"

validation "github.com/go-ozzo/ozzo-validation/v4"
"google.golang.org/genproto/googleapis/rpc/errdetails"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -41,6 +46,77 @@ func Error(err error) error {
return st.Err()
}

var fargateResources = map[int][]int{
512: {256},
1024: {256, 512},
2048: {256, 512, 1024},
3072: {512, 1024},
4096: {512, 1024},
5120: {1024},
6144: {1024},
7168: {1024},
8192: {1024},
}

func init() {
for i := 4096; i < 16384; i += 1024 {
fargateResources[i] = append(fargateResources[i], 2048)
}

for i := 8192; i <= 30720; i += 1024 {
fargateResources[i] = append(fargateResources[i], 4096)
}
}

func ValidateEcsMemCPUPair(mem, cpu int) error {
cpuValues, ok := fargateResources[mem]
if !ok {
var (
allValues []int
goodValues []string
)

for k := range fargateResources {
allValues = append(allValues, k)
}

sort.Ints(allValues)

for _, k := range allValues {
goodValues = append(goodValues, strconv.Itoa(k))
}

return fmt.Errorf("invalid memory value: %d (valid values: %s)", mem,
strings.Join(goodValues, ", "))
}

if cpu == 0 {
// if cpu is 0 a default will likely be chosen by which ever AWS service
// is being used, based on the memory value
return nil
}

var (
valid bool
goodValues []string
)

for _, c := range cpuValues {
goodValues = append(goodValues, strconv.Itoa(c))
if c == cpu {
valid = true
break
}
}

if !valid {
return fmt.Errorf("invalid cpu value: %d (valid values: %s)",
mem, strings.Join(goodValues, ", "))
}

return nil
}

// errorAppend accumulates field violations by recursively nesting into the
// validation errors. We have to recurse to get nested structs/maps/etc.
// With each recursion, we prefix the errors with the field path to that
Expand Down
68 changes: 68 additions & 0 deletions builtin/aws/utils/validations_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package utils

import (
"testing"
)

func TestValidateEcsMemCPUPair(t *testing.T) {
// test values based off of
// https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-cpu-memory-error.html
// circa July 15, 2021

cases := map[string]struct {
mem int
cpu int
shouldErr bool
}{
"zeros": {
shouldErr: true,
},
"512/0": {
mem: 512,
},
"512/256": {
mem: 512,
cpu: 256,
},
"4096": {
mem: 4096,
},
"4096/512": {
mem: 4096,
cpu: 512,
},
"4096/256": {
mem: 4096,
cpu: 256,
shouldErr: true,
},
"512/512": {
mem: 512,
cpu: 512,
shouldErr: true,
},
"nonsense": {
mem: 7,
shouldErr: true,
},
"bad_pair": {
mem: 512,
cpu: 512,
shouldErr: true,
},
"zero_mem": {
cpu: 7,
shouldErr: true,
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
if err := ValidateEcsMemCPUPair(c.mem, c.cpu); err != nil {
if !c.shouldErr {
t.Error(err)
}
}
})
}
}
27 changes: 26 additions & 1 deletion internal/serverinstall/ecs.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,22 @@ func (i *ECSInstaller) Install(
err error
)

// validate we have a memory/cpu combination that ECS will accept. See
// https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-cpu-memory-error.html
// for more information on valid combinations
mem, err := strconv.Atoi(i.config.Memory)
if err != nil {
return nil, err
}
cpu, err := strconv.Atoi(i.config.CPU)
if err != nil {
return nil, err
}

if err := utils.ValidateEcsMemCPUPair(mem, cpu); err != nil {
return nil, err
}

lf := &Lifecycle{
Init: func(ui terminal.UI) error {
sess, err = utils.GetSession(&utils.SessionConfig{
Expand Down Expand Up @@ -283,6 +299,8 @@ func (i *ECSInstaller) Launch(
return nil, err
}

// registerTaskDefinition() above ensures taskDef here is non-nil, if the
// error returned is nil
taskDefArn := *taskDef.TaskDefinitionArn

// Create the service
Expand Down Expand Up @@ -1964,14 +1982,21 @@ func registerTaskDefinition(def *ecs.RegisterTaskDefinitionInput, ecsSvc *ecs.EC

// if we encounter an unrecoverable error, exit now.
if aerr, ok := err.(awserr.Error); ok {
if aerr.Code() == "ResourceConflictException" {
if aerr.Code() == "ResourceConflictException" || aerr.Code() == "ClientException" {
return nil, err
}
}

// otherwise sleep and try again
time.Sleep(2 * time.Second)
}

// the above loop could expire and never get a valid task definition, so
// guard against a nil taskOut here
if taskOut == nil {
return nil, fmt.Errorf("error registering task definition, last error: %w", err)
}

return taskOut.TaskDefinition, nil
}

Expand Down

0 comments on commit 9c32100

Please sign in to comment.