diff --git a/pkg/config/config.go b/pkg/config/config.go index a0d927c06..7a67c9c45 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -30,8 +30,9 @@ type ( PubSub map[string]*PubSub `json:"pubsub,omitempty" yaml:"pubsub,omitempty" toml:"pubsub,omitempty"` } Expose struct { - Type string `json:"type" yaml:"type" toml:"type"` - InfraParams InfraParams `json:"pulumi_params,omitempty" yaml:"pulumi_params,omitempty" toml:"pulumi_params,omitempty"` + Type string `json:"type" yaml:"type" toml:"type"` + ContentDeliveryNetwork ContentDeliveryNetwork `json:"content_delivery_network,omitempty" yaml:"content_delivery_network,omitempty" toml:"content_delivery_network,omitempty"` + InfraParams InfraParams `json:"pulumi_params,omitempty" yaml:"pulumi_params,omitempty" toml:"pulumi_params,omitempty"` } Persist struct { @@ -58,8 +59,9 @@ type ( } StaticUnit struct { - Type string `json:"type" yaml:"type" toml:"type"` - InfraParams InfraParams `json:"pulumi_params,omitempty" yaml:"pulumi_params,omitempty" toml:"pulumi_params,omitempty"` + Type string `json:"type" yaml:"type" toml:"type"` + InfraParams InfraParams `json:"pulumi_params,omitempty" yaml:"pulumi_params,omitempty" toml:"pulumi_params,omitempty"` + ContentDeliveryNetwork ContentDeliveryNetwork `json:"content_delivery_network,omitempty" yaml:"content_delivery_network,omitempty" toml:"content_delivery_network,omitempty"` } Defaults struct { @@ -84,6 +86,10 @@ type ( RedisCluster KindDefaults `json:"redis_cluster" yaml:"redis_cluster" toml:"redis_cluster"` } + ContentDeliveryNetwork struct { + Id string `json:"id,omitempty" yaml:"id,omitempty" toml:"id,omitempty"` + } + // InfraParams are passed as-is to the generated IaC InfraParams map[string]interface{} ) @@ -150,6 +156,7 @@ func (cfg *Expose) Merge(other Expose) { if other.Type != "" { cfg.Type = other.Type } + cfg.ContentDeliveryNetwork = other.ContentDeliveryNetwork cfg.InfraParams.Merge(other.InfraParams) } @@ -171,6 +178,7 @@ func (cfg *StaticUnit) Merge(other StaticUnit) { if other.Type != "" { cfg.Type = other.Type } + cfg.ContentDeliveryNetwork = other.ContentDeliveryNetwork cfg.InfraParams.Merge(other.InfraParams) } diff --git a/pkg/infra/pulumi_aws/deploylib.ts b/pkg/infra/pulumi_aws/deploylib.ts index 7dce2d2d3..8e2e25a8a 100755 --- a/pkg/infra/pulumi_aws/deploylib.ts +++ b/pkg/infra/pulumi_aws/deploylib.ts @@ -9,22 +9,15 @@ import * as fs from 'fs' import * as requestRetry from 'requestretry' import * as crypto from 'crypto' import { setupElasticacheCluster } from './iac/elasticache' - import * as analytics from './iac/analytics' import { LoadBalancerPlugin } from './iac/load_balancing' -import { - DefaultEksClusterOptions, - Eks, - EksExecUnit, - EksExecUnitArgs, - HelmChart, - plugins as EksPlugins, -} from './iac/eks' +import { DefaultEksClusterOptions, Eks, EksExecUnit, HelmChart } from './iac/eks' import { setupMemoryDbCluster } from './iac/memorydb' export enum Resource { exec_unit = 'exec_unit', + static_unit = 'static_unit', gateway = 'gateway', kv = 'persist_kv', fs = 'persist_fs', @@ -35,6 +28,11 @@ export enum Resource { pubsub = 'pubsub', } +export interface ResourceKey { + Kind: string + Name: string +} + interface ResourceInfo { id: string urn: string @@ -73,6 +71,9 @@ export class CloudCCLib { execUnitToPolicyStatements = new Map() execUnitToImage = new Map>() + gatewayToUrl = new Map>() + siteBuckets = new Map() + topologySpecOutputs: pulumi.Output[] = [] connectionString = new Map>() @@ -1069,6 +1070,8 @@ export class CloudCCLib { })) ) + this.gatewayToUrl.set(providedName, stage.invokeUrl) + return stage.invokeUrl } diff --git a/pkg/infra/pulumi_aws/iac/cloudfront.ts b/pkg/infra/pulumi_aws/iac/cloudfront.ts new file mode 100644 index 000000000..623e5e96a --- /dev/null +++ b/pkg/infra/pulumi_aws/iac/cloudfront.ts @@ -0,0 +1,151 @@ +import * as aws from '@pulumi/aws' +import * as pulumi from '@pulumi/pulumi' +import * as mime from 'mime' +import * as fs from 'fs' +import * as path from 'path' +import { CloudCCLib, ResourceKey, Resource } from '../deploylib' + +export interface CloudfrontDistribution { + Id: string + Origins: ResourceKey[] + DefaultRootObject: string +} + +interface TargetOrigin { + type?: Resource.static_unit | Resource.gateway + id?: string +} + +export class Cloudfront { + constructor(lib: CloudCCLib, cloudfrontDistributions: CloudfrontDistribution[]) { + for (const dist of cloudfrontDistributions) { + const origins: aws.types.input.cloudfront.DistributionOrigin[] = [] + let targetOrigin: TargetOrigin = {} + const indexDocument = dist.DefaultRootObject == '' ? undefined : dist.DefaultRootObject + for (const origin of dist.Origins) { + if (origin.Kind == Resource.gateway) { + origins.push( + this.createCustomOrigin(origin.Name, lib.gatewayToUrl.get(origin.Name)!) + ) + if (!targetOrigin.id) { + targetOrigin = { + type: Resource.gateway, + id: origin.Name, + } + } + } else if (origin.Kind == Resource.static_unit) { + const bucket = lib.siteBuckets.get(origin.Name)! + origins.push(this.createS3Origin(origin.Name, bucket)) + targetOrigin = { + type: Resource.static_unit, + id: origin.Name, + } + } + } + this.createDistribution(dist.Id, origins, targetOrigin, indexDocument) + } + } + + createCustomOrigin( + name: string, + domainName: pulumi.Output + ): aws.types.input.cloudfront.DistributionOrigin { + const origin: aws.types.input.cloudfront.DistributionOrigin = { + originId: name, + customOriginConfig: { + httpPort: 80, + httpsPort: 443, + originProtocolPolicy: 'https-only', + originSslProtocols: ['SSLv3', 'TLSv1', 'TLSv1.1', 'TLSv1.2'], + }, + domainName: domainName.apply((d) => d.split('//')[1].split('/')[0]), + originPath: domainName.apply((d) => '/' + d.split('//')[1].split('/')[1]), + } + return origin + } + + createS3Origin( + name: string, + siteBucket: aws.s3.Bucket + ): aws.types.input.cloudfront.DistributionOrigin { + const originAccessIdentity = new aws.cloudfront.OriginAccessIdentity( + `${siteBucket}-originAccessIdentity`, + { + comment: 'this is needed to setup s3 polices and make s3 not public.', + } + ) + + new aws.s3.BucketPolicy('bucketPolicy', { + bucket: siteBucket.id, // refer to the bucket created earlier + policy: pulumi + .all([originAccessIdentity.iamArn, siteBucket.arn]) + .apply(([oaiArn, bucketArn]) => + JSON.stringify({ + Version: '2012-10-17', + Statement: [ + { + Effect: 'Allow', + Principal: { + AWS: oaiArn, + }, // Only allow Cloudfront read access. + Action: ['s3:GetObject'], + Resource: [`${bucketArn}/*`], // Give Cloudfront access to the entire bucket. + }, + ], + }) + ), + }) + + const origin = { + domainName: siteBucket.bucketRegionalDomainName, + originId: name, + s3OriginConfig: { + originAccessIdentity: originAccessIdentity.cloudfrontAccessIdentityPath, + }, + } + + return origin + } + + createDistribution( + name, + origins, + targetOrigin: TargetOrigin, + indexDocument? + ): aws.cloudfront.Distribution { + let defaultTtl = 3600 + if (targetOrigin.type == Resource.gateway) { + defaultTtl = 0 + } + + const distribution = new aws.cloudfront.Distribution(name, { + origins, + enabled: true, + viewerCertificate: { + cloudfrontDefaultCertificate: true, + }, + defaultCacheBehavior: { + allowedMethods: ['DELETE', 'GET', 'HEAD', 'OPTIONS', 'PATCH', 'POST', 'PUT'], + cachedMethods: ['HEAD', 'GET'], + targetOriginId: targetOrigin.id!, + forwardedValues: { + queryString: true, + cookies: { + forward: 'none', + }, + }, + viewerProtocolPolicy: 'allow-all', + minTtl: 0, + defaultTtl, + maxTtl: 86400, + }, + restrictions: { + geoRestriction: { + restrictionType: 'none', + }, + }, + defaultRootObject: indexDocument, + }) + return distribution + } +} diff --git a/pkg/infra/pulumi_aws/iac/static_s3_website.ts b/pkg/infra/pulumi_aws/iac/static_s3_website.ts index ddcb7fac8..5f1ecf422 100644 --- a/pkg/infra/pulumi_aws/iac/static_s3_website.ts +++ b/pkg/infra/pulumi_aws/iac/static_s3_website.ts @@ -3,121 +3,27 @@ import * as pulumi from '@pulumi/pulumi' import * as mime from 'mime' import * as fs from 'fs' import * as path from 'path' +import { CloudCCLib } from '../deploylib' export const createStaticS3Website = ( staticUnit: string, indexDocument: string, - params -): pulumi.Output => { + contentDeliveryNetworkId: string, + lib: CloudCCLib +) => { // Create an S3 bucket const bucketArgs: aws.s3.BucketArgs = {} - if (indexDocument != '' && !params.cloudFrontEnabled) { + if (indexDocument != '' && contentDeliveryNetworkId == '') { bucketArgs['website'] = { indexDocument: indexDocument, } } let siteBucket = new aws.s3.Bucket(`static-website-${staticUnit}`, bucketArgs) + lib.siteBuckets.set(staticUnit, siteBucket) createAllObjects(staticUnit, siteBucket) - - if (params.cloudFrontEnabled) { - // Generate Origin Access Identity to access the private s3 bucket. - const originAccessIdentity = new aws.cloudfront.OriginAccessIdentity( - 'originAccessIdentity', - { - comment: 'this is needed to setup s3 polices and make s3 not public.', - } - ) - - const bucketPolicy = new aws.s3.BucketPolicy('bucketPolicy', { - bucket: siteBucket.id, // refer to the bucket created earlier - policy: pulumi - .all([originAccessIdentity.iamArn, siteBucket.arn]) - .apply(([oaiArn, bucketArn]) => - JSON.stringify({ - Version: '2012-10-17', - Statement: [ - { - Effect: 'Allow', - Principal: { - AWS: oaiArn, - }, // Only allow Cloudfront read access. - Action: ['s3:GetObject'], - Resource: [`${bucketArn}/*`], // Give Cloudfront access to the entire bucket. - }, - ], - }) - ), - }) - - const distribution = new aws.cloudfront.Distribution(`cdn-static-${staticUnit}`, { - origins: [ - { - domainName: siteBucket.bucketRegionalDomainName, - originId: siteBucket.arn, - s3OriginConfig: { - originAccessIdentity: originAccessIdentity.cloudfrontAccessIdentityPath, - }, - }, - ], - enabled: true, - viewerCertificate: { - cloudfrontDefaultCertificate: true, - }, - defaultCacheBehavior: { - allowedMethods: ['DELETE', 'GET', 'HEAD', 'OPTIONS', 'PATCH', 'POST', 'PUT'], - cachedMethods: ['GET', 'HEAD'], - targetOriginId: siteBucket.arn, - forwardedValues: { - queryString: false, - cookies: { - forward: 'none', - }, - }, - viewerProtocolPolicy: 'allow-all', - minTtl: 0, - defaultTtl: 3600, - maxTtl: 86400, - }, - restrictions: { - geoRestriction: { - restrictionType: 'none', - }, - }, - defaultRootObject: indexDocument, - }) - - return distribution.domainName - } else { - // Create an S3 Bucket Policy to allow public read of all objects in bucket - // This reusable function can be pulled out into its own module - function publicReadPolicyForBucket(bucketName) { - return JSON.stringify({ - Version: '2012-10-17', - Statement: [ - { - Effect: 'Allow', - Principal: '*', - Action: ['s3:GetObject'], - Resource: [ - `arn:aws:s3:::${bucketName}/*`, // policy refers to bucket name explicitly - ], - }, - ], - }) - } - - // Set the access policy for the bucket so all objects are readable - let bucketPolicy = new aws.s3.BucketPolicy('bucketPolicy', { - bucket: siteBucket.bucket, // depends on siteBucket -- see explanation below - policy: siteBucket.bucket.apply(publicReadPolicyForBucket), - // transform the siteBucket.bucket output property -- see explanation below - }) - - return siteBucket.websiteEndpoint - } } const createAllObjects = (staticUnit, siteBucket, prefixPath = '') => { diff --git a/pkg/infra/pulumi_aws/index.ts.tmpl b/pkg/infra/pulumi_aws/index.ts.tmpl index aad3bb4df..9f611d488 100755 --- a/pkg/infra/pulumi_aws/index.ts.tmpl +++ b/pkg/infra/pulumi_aws/index.ts.tmpl @@ -13,10 +13,17 @@ import * as awsx from '@pulumi/awsx' import * as k8s from '@pulumi/kubernetes' import { v4 as uuidv4 } from "uuid"; import {CloudCCLib, kloConfig} from './deploylib' -import {createStaticS3Website} from './iac/static_s3_website' import {EksExecUnitArgs, EksExecUnit} from './iac/eks' import {CockroachDB} from './iac/cockroachdb' +{{- if .StaticUnits}} +import {createStaticS3Website} from './iac/static_s3_website' +{{end}} + +{{- if .CloudfrontDistributions}} +import { Cloudfront } from './iac/cloudfront' +{{end}} + export = async () => { const minimumNodeVersion = 16 const nodeVersionMatch = process.version.match(/^v(\d+)/) @@ -136,13 +143,14 @@ export = async () => { {{range $unit := .StaticUnits}} staticUnitUrls.push(createStaticS3Website( - "{{$unit.Name}}", "{{$unit.IndexDocument}}", {{jsonPretty $unit.Params | indent 4}}) - ) + "{{$unit.Name}}", "{{$unit.IndexDocument}}", "{{$unit.ContentDeliveryNetwork.Id}}", cloudLib + )) {{end}} - const apprunnerUrls = arUrls.filter(url => { return url != null});; const frontendUrls = staticUnitUrls; + const apprunnerUrls = arUrls.filter(url => { return url != null});; + const gatewayUrls: any[] = []; {{- range .Gateways}} @@ -158,6 +166,9 @@ export = async () => { const apiUrls = gatewayUrls; + {{- if .CloudfrontDistributions}} + new Cloudfront(cloudLib, {{jsonPretty .CloudfrontDistributions | indent 4}}) + {{end}} {{range $event := .PubSubs}} cloudLib.createTopic( diff --git a/pkg/infra/pulumi_aws/plugin_iac.go b/pkg/infra/pulumi_aws/plugin_iac.go index 2c20434e9..99511ace7 100644 --- a/pkg/infra/pulumi_aws/plugin_iac.go +++ b/pkg/infra/pulumi_aws/plugin_iac.go @@ -119,6 +119,14 @@ func (p Plugin) Transform(result *core.CompilationResult, deps *core.Dependencie } } + if len(data.CloudfrontDistributions) > 0 { + addFile("iac/cloudfront.ts") + } + + if len(data.StaticUnits) > 0 { + addFile("iac/static_s3_website.ts") + } + addFile("deploylib.ts") addFile("package.json") addFile("tsconfig.json") @@ -126,7 +134,6 @@ func (p Plugin) Transform(result *core.CompilationResult, deps *core.Dependencie addFile("iac/memorydb.ts") addFile("iac/eks.ts") addFile("iac/kubernetes.ts") - addFile("iac/static_s3_website.ts") addFile("iac/cockroachdb.ts") addFile("iac/analytics.ts") addFile("iac/load_balancing.ts") diff --git a/pkg/provider/aws/config.go b/pkg/provider/aws/config.go index cd4fb4b38..95b7b981b 100644 --- a/pkg/provider/aws/config.go +++ b/pkg/provider/aws/config.go @@ -4,6 +4,7 @@ import ( "github.com/klothoplatform/klotho/pkg/config" "github.com/klothoplatform/klotho/pkg/core" "github.com/klothoplatform/klotho/pkg/provider" + "github.com/klothoplatform/klotho/pkg/provider/aws/resources" ) type ( @@ -15,7 +16,8 @@ type ( TemplateData struct { provider.TemplateData TemplateConfig - UseVPC bool + UseVPC bool + CloudfrontDistributions []*resources.CloudfrontDistribution } ) @@ -30,6 +32,17 @@ func (t *TemplateData) Key() core.ResourceKey { } } +func NewTemplateData(config *config.Application) *TemplateData { + return &TemplateData{ + TemplateConfig: TemplateConfig{ + TemplateConfig: provider.TemplateConfig{ + AppName: config.AppName, + }, + PayloadsBucketName: SanitizeS3BucketName(config.AppName), + }, + } +} + func (c *AWS) Name() string { return "aws" } // Enums for the types we allow in the aws provider so that we can reuse the same string within the provider diff --git a/pkg/provider/aws/infra_template.go b/pkg/provider/aws/infra_template.go index 7bca5ca57..58ef5b8ee 100644 --- a/pkg/provider/aws/infra_template.go +++ b/pkg/provider/aws/infra_template.go @@ -6,24 +6,20 @@ import ( "github.com/klothoplatform/klotho/pkg/infra/kubernetes" "github.com/klothoplatform/klotho/pkg/multierr" "github.com/klothoplatform/klotho/pkg/provider" + "github.com/klothoplatform/klotho/pkg/provider/aws/resources" "github.com/pkg/errors" ) func (a *AWS) Transform(result *core.CompilationResult, deps *core.Dependencies) error { var errs multierr.Error - data := &TemplateData{ - TemplateConfig: TemplateConfig{ - TemplateConfig: provider.TemplateConfig{ - AppName: a.Config.AppName, - }, - PayloadsBucketName: SanitizeS3BucketName(a.Config.AppName), - }, - } + data := NewTemplateData(a.Config) a.Config.UpdateForResources(result.Resources()) data.Results = result + a.GenerateCloudfrontDistributions(data, result) + helmCharts := result.GetResourcesOfType(kubernetes.HelmChartKind) for _, res := range result.Resources() { @@ -93,10 +89,10 @@ func (a *AWS) Transform(result *core.CompilationResult, deps *core.Dependencies) case *core.StaticUnit: cfg := a.Config.GetStaticUnit(key.Name) unit := provider.StaticUnit{ - Name: res.Name, - Type: res.Type(), - IndexDocument: res.IndexDocument, - Params: cfg.InfraParams, + Name: res.Name, + Type: res.Type(), + IndexDocument: res.IndexDocument, + ContentDeliveryNetwork: cfg.ContentDeliveryNetwork, } data.StaticUnits = append(data.StaticUnits, unit) @@ -172,3 +168,40 @@ func (a *AWS) Transform(result *core.CompilationResult, deps *core.Dependencies) result.Add(data) return errs.ErrOrNil() } + +func (a *AWS) GenerateCloudfrontDistributions(data *TemplateData, result *core.CompilationResult) { + cloudfrontMap := make(map[string][]core.CloudResource) + for _, res := range result.Resources() { + key := res.Key() + switch res.(type) { + case *core.Gateway: + cfg := a.Config.GetExposed(key.Name) + cfId := cfg.ContentDeliveryNetwork.Id + if cfId != "" { + cf, ok := cloudfrontMap[cfId] + if ok { + cloudfrontMap[cfId] = append(cf, res) + } else { + cloudfrontMap[cfId] = []core.CloudResource{res} + } + } + case *core.StaticUnit: + cfg := a.Config.GetStaticUnit(key.Name) + cfId := cfg.ContentDeliveryNetwork.Id + if cfId != "" { + cf, ok := cloudfrontMap[cfId] + if ok { + cloudfrontMap[cfId] = append(cf, res) + } else { + cloudfrontMap[cfId] = []core.CloudResource{res} + } + } + } + } + + for name, keys := range cloudfrontMap { + cf := resources.CreateCloudfrontDistribution(keys) + cf.Id = a.Config.AppName + "-" + name + data.CloudfrontDistributions = append(data.CloudfrontDistributions, cf) + } +} diff --git a/pkg/provider/aws/infra_template_test.go b/pkg/provider/aws/infra_template_test.go index edffa51ff..50228279b 100644 --- a/pkg/provider/aws/infra_template_test.go +++ b/pkg/provider/aws/infra_template_test.go @@ -7,6 +7,7 @@ import ( "github.com/klothoplatform/klotho/pkg/core" "github.com/klothoplatform/klotho/pkg/infra/kubernetes" "github.com/klothoplatform/klotho/pkg/provider" + "github.com/klothoplatform/klotho/pkg/provider/aws/resources" "github.com/stretchr/testify/assert" ) @@ -126,3 +127,197 @@ func TestInfraTemplateModification(t *testing.T) { }) } } + +func Test_GenerateCloudfrontDistributions(t *testing.T) { + cases := []struct { + name string + results []core.CloudResource + cfg config.Application + data TemplateData + want []*resources.CloudfrontDistribution + }{ + { + name: "simple gateway test", + results: []core.CloudResource{ + &core.Gateway{ + Name: "gw", + GWType: core.GatewayKind, + Routes: []core.Route{{Path: "/"}}, + }, + }, + cfg: config.Application{ + Provider: "aws", + AppName: "app", + Exposed: map[string]*config.Expose{ + "gw": { + Type: "apigateway", + ContentDeliveryNetwork: config.ContentDeliveryNetwork{Id: "distro"}, + }, + }, + }, + data: TemplateData{ + TemplateData: provider.TemplateData{ + Gateways: []provider.Gateway{ + {Name: "gw"}, + }, + }, + }, + want: []*resources.CloudfrontDistribution{ + { + Id: "app-distro", + Origins: []core.ResourceKey{ + {Kind: core.GatewayKind, Name: "gw"}, + }, + }, + }, + }, + { + name: "simple static unit test", + results: []core.CloudResource{ + &core.StaticUnit{ + Name: "su", + }, + }, + cfg: config.Application{ + Provider: "aws", + AppName: "app", + StaticUnit: map[string]*config.StaticUnit{ + "su": { + Type: "apigateway", + ContentDeliveryNetwork: config.ContentDeliveryNetwork{Id: "distro"}, + }, + }, + }, + data: TemplateData{ + TemplateData: provider.TemplateData{ + StaticUnits: []provider.StaticUnit{ + {Name: "su"}, + }, + }, + }, + want: []*resources.CloudfrontDistribution{ + { + Id: "app-distro", + Origins: []core.ResourceKey{ + {Kind: core.StaticUnitKind, Name: "su"}, + }, + }, + }, + }, + { + name: "simple static unit with index document test", + results: []core.CloudResource{ + &core.StaticUnit{ + Name: "su", + IndexDocument: "index.html", + }, + }, + cfg: config.Application{ + Provider: "aws", + AppName: "app", + StaticUnit: map[string]*config.StaticUnit{ + "su": { + Type: "apigateway", + ContentDeliveryNetwork: config.ContentDeliveryNetwork{Id: "distro"}, + }, + }, + }, + data: TemplateData{ + TemplateData: provider.TemplateData{ + StaticUnits: []provider.StaticUnit{ + {Name: "su"}, + }, + }, + }, + want: []*resources.CloudfrontDistribution{ + { + Id: "app-distro", + Origins: []core.ResourceKey{ + {Kind: core.StaticUnitKind, Name: "su"}, + }, + DefaultRootObject: "index.html", + }, + }, + }, + { + name: "static unit and gw test", + results: []core.CloudResource{ + &core.StaticUnit{ + Name: "su", + }, + &core.Gateway{ + Name: "gw", + GWType: core.GatewayKind, + Routes: []core.Route{{Path: "/"}}, + }, + }, + cfg: config.Application{ + Provider: "aws", + AppName: "app", + StaticUnit: map[string]*config.StaticUnit{ + "su": { + Type: "apigateway", + ContentDeliveryNetwork: config.ContentDeliveryNetwork{Id: "distro"}, + }, + }, + Exposed: map[string]*config.Expose{ + "gw": { + Type: "apigateway", + ContentDeliveryNetwork: config.ContentDeliveryNetwork{Id: "distro"}, + }, + }, + }, + data: TemplateData{ + TemplateData: provider.TemplateData{ + StaticUnits: []provider.StaticUnit{ + {Name: "su"}, + }, + Gateways: []provider.Gateway{ + {Name: "gw"}, + }, + }, + }, + want: []*resources.CloudfrontDistribution{ + { + Id: "app-distro", + Origins: []core.ResourceKey{ + {Kind: core.GatewayKind, Name: "gw"}, + {Kind: core.StaticUnitKind, Name: "su"}, + }, + }, + }, + }, + } + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + assert := assert.New(t) + result := core.CompilationResult{} + + result.AddAll(tt.results) + + aws := AWS{ + Config: &tt.cfg, + } + aws.GenerateCloudfrontDistributions(&tt.data, &result) + for _, cf := range tt.want { + found := false + for _, gotCf := range tt.data.CloudfrontDistributions { + if gotCf.Id == cf.Id { + found = true + assert.Equal(cf.DefaultRootObject, gotCf.DefaultRootObject) + for _, cfOrigin := range cf.Origins { + originFound := false + for _, gotCfOrigin := range gotCf.Origins { + if cfOrigin.String() == gotCfOrigin.String() { + originFound = true + } + } + assert.True(originFound) + } + } + } + assert.True(found) + } + }) + } +} diff --git a/pkg/provider/aws/resources/cloudfront.go b/pkg/provider/aws/resources/cloudfront.go new file mode 100644 index 000000000..d405cba8f --- /dev/null +++ b/pkg/provider/aws/resources/cloudfront.go @@ -0,0 +1,31 @@ +package resources + +import ( + "github.com/klothoplatform/klotho/pkg/core" + "go.uber.org/zap" +) + +type CloudfrontDistribution struct { + Id string + Origins []core.ResourceKey + DefaultRootObject string +} + +func CreateCloudfrontDistribution(resources []core.CloudResource) *CloudfrontDistribution { + distribution := &CloudfrontDistribution{} + + for _, res := range resources { + switch res.Key().Kind { + case core.GatewayKind: + distribution.Origins = append(distribution.Origins, res.Key()) + case core.StaticUnitKind: + sunit := res.(*core.StaticUnit) + distribution.Origins = append(distribution.Origins, res.Key()) + if distribution.DefaultRootObject != "" { + zap.S().Warn("Cannot have a cdn with multiple root objects") + } + distribution.DefaultRootObject = sunit.IndexDocument + } + } + return distribution +} diff --git a/pkg/provider/infra_template.go b/pkg/provider/infra_template.go index feba9a14a..1614953d7 100644 --- a/pkg/provider/infra_template.go +++ b/pkg/provider/infra_template.go @@ -68,10 +68,10 @@ type ( } StaticUnit struct { - Name string - Type string - IndexDocument string - Params config.InfraParams + Name string + Type string + IndexDocument string + ContentDeliveryNetwork config.ContentDeliveryNetwork } Gateway struct {