Skip to content

Commit

Permalink
Merge pull request #15887 from DrFaust92/r/sagemaker_endpoint_configu…
Browse files Browse the repository at this point in the history
…ration_data_capture

r/sagemaker_endpoint_configuration - support `data_capture_config` + validations and test refactor
  • Loading branch information
YakDriver committed Nov 5, 2020
2 parents 53f4dac + 82105d6 commit 290d97d
Show file tree
Hide file tree
Showing 3 changed files with 398 additions and 51 deletions.
233 changes: 226 additions & 7 deletions aws/resource_aws_sagemaker_endpoint_configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package aws
import (
"fmt"
"log"
"regexp"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sagemaker"
Expand Down Expand Up @@ -62,9 +63,10 @@ func resourceAwsSagemakerEndpointConfiguration() *schema.Resource {
},

"instance_type": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.ProductionVariantInstanceType_Values(), false),
},

"initial_variant_weight": {
Expand All @@ -76,9 +78,10 @@ func resourceAwsSagemakerEndpointConfiguration() *schema.Resource {
},

"accelerator_type": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
Type: schema.TypeString,
Optional: true,
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.ProductionVariantAcceleratorType_Values(), false),
},
},
},
Expand All @@ -92,6 +95,102 @@ func resourceAwsSagemakerEndpointConfiguration() *schema.Resource {
},

"tags": tagsSchema(),

"data_capture_config": {
Type: schema.TypeList,
MaxItems: 1,
Optional: true,
ForceNew: true,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"enable_capture": {
Type: schema.TypeBool,
Optional: true,
ForceNew: true,
},

"initial_sampling_percentage": {
Type: schema.TypeInt,
Required: true,
ForceNew: true,
ValidateFunc: validation.IntBetween(0, 100),
},

"destination_s3_uri": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validation.All(
validation.StringMatch(regexp.MustCompile(`^(https|s3)://([^/])/?(.*)$`), ""),
validation.StringLenBetween(1, 512),
)},

"kms_key_id": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
ValidateFunc: validateArn,
},

"capture_options": {
Type: schema.TypeList,
Required: true,
MaxItems: 2,
MinItems: 1,
ForceNew: true,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"capture_mode": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.CaptureMode_Values(), false),
},
},
},
},

"capture_content_type_header": {
Type: schema.TypeList,
Optional: true,
MaxItems: 1,
ForceNew: true,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"csv_content_types": {
Type: schema.TypeSet,
MinItems: 1,
MaxItems: 10,
Elem: &schema.Schema{
Type: schema.TypeString,
ValidateFunc: validation.All(
validation.StringMatch(regexp.MustCompile(`^[a-zA-Z0-9](-*[a-zA-Z0-9])*\/[a-zA-Z0-9](-*[a-zA-Z0-9.])*`), ""),
validation.StringLenBetween(1, 256),
),
},
Optional: true,
ForceNew: true,
},
"json_content_types": {
Type: schema.TypeSet,
MinItems: 1,
MaxItems: 10,
Elem: &schema.Schema{
Type: schema.TypeString,
ValidateFunc: validation.All(
validation.StringMatch(regexp.MustCompile(`^[a-zA-Z0-9](-*[a-zA-Z0-9])*\/[a-zA-Z0-9](-*[a-zA-Z0-9.])*`), ""),
validation.StringLenBetween(1, 256),
),
},
Optional: true,
ForceNew: true,
},
},
},
},
},
},
},
},
}
}
Expand Down Expand Up @@ -119,6 +218,10 @@ func resourceAwsSagemakerEndpointConfigurationCreate(d *schema.ResourceData, met
createOpts.Tags = keyvaluetags.New(v.(map[string]interface{})).IgnoreAws().SagemakerTags()
}

if v, ok := d.GetOk("data_capture_config"); ok {
createOpts.DataCaptureConfig = expandSagemakerDataCaptureConfig(v.([]interface{}))
}

log.Printf("[DEBUG] SageMaker Endpoint Configuration create config: %#v", *createOpts)
_, err := conn.CreateEndpointConfig(createOpts)
if err != nil {
Expand All @@ -139,7 +242,7 @@ func resourceAwsSagemakerEndpointConfigurationRead(d *schema.ResourceData, meta

endpointConfig, err := conn.DescribeEndpointConfig(request)
if err != nil {
if isAWSErr(err, "ValidationException", "") {
if isAWSErr(err, "ValidationException", "Could not find endpoint configuration") {
log.Printf("[INFO] unable to find the SageMaker Endpoint Configuration resource and therefore it is removed from the state: %s", d.Id())
d.SetId("")
return nil
Expand All @@ -159,6 +262,9 @@ func resourceAwsSagemakerEndpointConfigurationRead(d *schema.ResourceData, meta
if err := d.Set("kms_key_arn", endpointConfig.KmsKeyId); err != nil {
return err
}
if err := d.Set("data_capture_config", flattenSagemakerDataCaptureConfig(endpointConfig.DataCaptureConfig)); err != nil {
return err
}

tags, err := keyvaluetags.SagemakerListTags(conn, aws.StringValue(endpointConfig.EndpointConfigArn))
if err != nil {
Expand Down Expand Up @@ -255,3 +361,116 @@ func flattenProductionVariants(list []*sagemaker.ProductionVariant) []map[string
}
return result
}

func expandSagemakerDataCaptureConfig(configured []interface{}) *sagemaker.DataCaptureConfig {
if len(configured) == 0 {
return nil
}

m := configured[0].(map[string]interface{})

c := &sagemaker.DataCaptureConfig{
InitialSamplingPercentage: aws.Int64(int64(m["initial_sampling_percentage"].(int))),
DestinationS3Uri: aws.String(m["destination_s3_uri"].(string)),
CaptureOptions: expandSagemakerCaptureOptions(m["capture_options"].([]interface{})),
}

if v, ok := m["enable_capture"]; ok {
c.EnableCapture = aws.Bool(v.(bool))
}

if v, ok := m["kms_key_id"]; ok && v.(string) != "" {
c.KmsKeyId = aws.String(v.(string))
}

if v, ok := m["capture_content_type_header"]; ok && (len(v.([]interface{})) > 0) {
c.CaptureContentTypeHeader = expandSagemakerCaptureContentTypeHeader(v.([]interface{})[0].(map[string]interface{}))
}

return c
}

func flattenSagemakerDataCaptureConfig(dataCaptureConfig *sagemaker.DataCaptureConfig) []map[string]interface{} {
if dataCaptureConfig == nil {
return []map[string]interface{}{}
}

cfg := map[string]interface{}{
"initial_sampling_percentage": aws.Int64Value(dataCaptureConfig.InitialSamplingPercentage),
"destination_s3_uri": aws.StringValue(dataCaptureConfig.DestinationS3Uri),
"capture_options": flattenSagemakerCaptureOptions(dataCaptureConfig.CaptureOptions),
}

if dataCaptureConfig.EnableCapture != nil {
cfg["enable_capture"] = aws.BoolValue(dataCaptureConfig.EnableCapture)
}

if dataCaptureConfig.KmsKeyId != nil {
cfg["kms_key_id"] = aws.StringValue(dataCaptureConfig.KmsKeyId)
}

if dataCaptureConfig.CaptureContentTypeHeader != nil {
cfg["capture_content_type_header"] = flattenSagemakerCaptureContentTypeHeader(dataCaptureConfig.CaptureContentTypeHeader)
}

return []map[string]interface{}{cfg}
}

func expandSagemakerCaptureOptions(configured []interface{}) []*sagemaker.CaptureOption {
containers := make([]*sagemaker.CaptureOption, 0, len(configured))

for _, lRaw := range configured {
data := lRaw.(map[string]interface{})

l := &sagemaker.CaptureOption{
CaptureMode: aws.String(data["capture_mode"].(string)),
}
containers = append(containers, l)
}

return containers
}

func flattenSagemakerCaptureOptions(list []*sagemaker.CaptureOption) []map[string]interface{} {
containers := make([]map[string]interface{}, 0, len(list))

for _, lRaw := range list {
captureOption := make(map[string]interface{})
captureOption["capture_mode"] = aws.StringValue(lRaw.CaptureMode)
containers = append(containers, captureOption)
}

return containers
}

func expandSagemakerCaptureContentTypeHeader(m map[string]interface{}) *sagemaker.CaptureContentTypeHeader {
c := &sagemaker.CaptureContentTypeHeader{}

if v, ok := m["csv_content_types"].(*schema.Set); ok && v.Len() > 0 {
c.CsvContentTypes = expandStringSet(v)
}

if v, ok := m["json_content_types"].(*schema.Set); ok && v.Len() > 0 {
c.JsonContentTypes = expandStringSet(v)
}

return c
}

func flattenSagemakerCaptureContentTypeHeader(contentTypeHeader *sagemaker.CaptureContentTypeHeader) []map[string]interface{} {
if contentTypeHeader == nil {
return []map[string]interface{}{}
}

l := make(map[string]interface{})

if contentTypeHeader.CsvContentTypes != nil {
l["csv_content_types"] = flattenStringSet(contentTypeHeader.CsvContentTypes)
}

if contentTypeHeader.JsonContentTypes != nil {
l["json_content_types"] = flattenStringSet(contentTypeHeader.JsonContentTypes)
}

return []map[string]interface{}{l}
}
Loading

0 comments on commit 290d97d

Please sign in to comment.