diff --git a/api/v1alpha1/backendsecurity_policy.go b/api/v1alpha1/backendsecurity_policy.go index 296f0e3551..f21432a5d3 100644 --- a/api/v1alpha1/backendsecurity_policy.go +++ b/api/v1alpha1/backendsecurity_policy.go @@ -165,6 +165,16 @@ type GCPServiceAccountImpersonationConfig struct { // BackendSecurityPolicyGCPCredentials contains the supported authentication mechanisms to access GCP. type BackendSecurityPolicyGCPCredentials struct { + // ProjectName is the GCP project name. + // + // +kubebuilder:validation:Required + // +kubebuilder:validation:MinLength=1 + ProjectName string `json:"projectName"` + // Region is the GCP region associated with the policy. + // + // +kubebuilder:validation:Required + // +kubebuilder:validation:MinLength=1 + Region string `json:"region"` // WorkLoadIdentityFederationConfig is the configuration for the GCP Workload Identity Federation. // // +kubebuilder:validation:Required diff --git a/examples/basic/basic.yaml b/examples/basic/basic.yaml index 15de6949c7..e949598720 100644 --- a/examples/basic/basic.yaml +++ b/examples/basic/basic.yaml @@ -174,6 +174,8 @@ metadata: spec: type: GCPCredentials gcpCredentials: + projectName: GCP_PROJECT_NAME # Replace with your GCP project name + region: GCP_REGION # Replace with your GCP region workLoadIdentityFederationConfig: projectID: GCP_PROJECT_ID # Replace with your GCP project ID workloadIdentityPoolName: GCP_WORKLOAD_IDENTITY_POOL # Replace with your workload identity pool name @@ -187,6 +189,9 @@ spec: clientSecret: name: envoy-ai-gateway-basic-gcp-client-secret namespace: default + serviceAccountImpersonation: + serviceAccountName: SERVICE_ACCOUNT_NAME # Replace with the service account name to impersonate + serviceAccountProjectName: GCP_SERVICE_ACCOUNT_PROJECT_NAME # Replace with the project name of the service account --- apiVersion: gateway.envoyproxy.io/v1alpha1 kind: Backend diff --git a/filterapi/filterconfig.go b/filterapi/filterconfig.go index 00dec6e31f..a498d7d560 100644 
--- a/filterapi/filterconfig.go +++ b/filterapi/filterconfig.go @@ -137,9 +137,18 @@ type VersionedAPISchema struct { type APISchemaName string const ( - APISchemaOpenAI APISchemaName = "OpenAI" - APISchemaAWSBedrock APISchemaName = "AWSBedrock" + // APISchemaOpenAI represents the standard OpenAI API schema. + APISchemaOpenAI APISchemaName = "OpenAI" + // APISchemaAWSBedrock represents the AWS Bedrock API schema. + APISchemaAWSBedrock APISchemaName = "AWSBedrock" + // APISchemaAzureOpenAI represents the Azure OpenAI API schema. APISchemaAzureOpenAI APISchemaName = "AzureOpenAI" + // APISchemaGCPVertexAI represents the Google Cloud Gemini API schema. + // Used for Gemini models hosted on Google Cloud Vertex AI. + APISchemaGCPVertexAI APISchemaName = "GCPVertexAI" + // APISchemaGCPAnthropic represents the Google Cloud Anthropic API schema. + // Used for Claude models hosted on Google Cloud Vertex AI. + APISchemaGCPAnthropic APISchemaName = "GCPAnthropic" ) // HeaderMatch is an alias for HTTPHeaderMatch of the Gateway API. @@ -226,13 +235,19 @@ type AzureAuth struct { AccessToken string `json:"accessToken"` } -// GCPAuth defines the file containing GCP credential that will be mounted to the external proc. +// GCPAuth defines the GCP authentication configuration used to access Google Cloud AI services. type GCPAuth struct { // AccessToken is the access token as a literal string. + // This token is obtained through GCP Workload Identity Federation and service account impersonation. + // The token is automatically rotated by the BackendSecurityPolicy controller before expiration. AccessToken string `json:"accessToken"` // Region is the GCP region to use for the request. + // This is used in URL path templates when making requests to GCP Vertex AI endpoints. + // Examples: "us-central1", "europe-west4" Region string `json:"region"` // ProjectName is the GCP project name to use for the request. 
+ // This is used in URL path templates when making requests to GCP Vertex AI endpoints. + // This should be the project where Vertex AI APIs are enabled. ProjectName string `json:"projectName"` } diff --git a/go.mod b/go.mod index 91572ab8b1..bc5dbf78ac 100644 --- a/go.mod +++ b/go.mod @@ -29,6 +29,8 @@ require ( go.uber.org/zap v1.27.0 golang.org/x/exp v0.0.0-20250530174510-65e920069ea6 golang.org/x/oauth2 v0.30.0 + google.golang.org/api v0.223.0 + google.golang.org/genai v1.13.0 google.golang.org/grpc v1.73.0 google.golang.org/protobuf v1.36.6 k8s.io/api v0.33.1 @@ -47,6 +49,10 @@ require ( 4d63.com/gochecknoglobals v0.2.2 // indirect al.essio.dev/pkg/shellescape v1.5.1 // indirect cel.dev/expr v0.23.1 // indirect + cloud.google.com/go v0.116.0 // indirect + cloud.google.com/go/auth v0.15.0 // indirect + cloud.google.com/go/auth/oauth2adapt v0.2.7 // indirect + cloud.google.com/go/compute/metadata v0.6.0 // indirect dario.cat/mergo v1.0.1 // indirect github.com/4meepo/tagalign v1.4.2 // indirect github.com/Abirdcfly/dupword v0.1.3 // indirect @@ -214,8 +220,11 @@ require ( github.com/google/go-github/v56 v56.0.0 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/licensecheck v0.3.1 // indirect + github.com/google/s2a-go v0.1.9 // indirect github.com/google/safetext v0.0.0-20220905092116-b49f7bc46da2 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect + github.com/googleapis/gax-go/v2 v2.14.1 // indirect github.com/gordonklaus/ineffassign v0.1.0 // indirect github.com/gorilla/mux v1.8.1 // indirect github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 // indirect diff --git a/go.sum b/go.sum index 10464416a3..3bc574ed3b 100644 --- a/go.sum +++ b/go.sum @@ -9,6 +9,14 @@ cel.dev/expr v0.23.1/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= cloud.google.com/go v0.26.0/go.mod 
h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= +cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE= +cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U= +cloud.google.com/go/auth v0.15.0 h1:Ly0u4aA5vG/fsSsxu98qCQBemXtAtJf+95z9HK+cxps= +cloud.google.com/go/auth v0.15.0/go.mod h1:WJDGqZ1o9E9wKIL+IwStfyn/+s59zl4Bi+1KQNVXLZ8= +cloud.google.com/go/auth/oauth2adapt v0.2.7 h1:/Lc7xODdqcEw8IrZ9SvwnlLX6j9FHQM74z6cBk9Rw6M= +cloud.google.com/go/auth/oauth2adapt v0.2.7/go.mod h1:NTbTTzfvPl1Y3V1nPpOgl2w6d/FjO7NNUQaWSox6ZMc= +cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I= +cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s= dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= @@ -612,6 +620,8 @@ github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXi github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 h1:BHT72Gu3keYf3ZEu2J0b1vyeLSOYI8bm5wbJM/8yDe8= github.com/google/pprof v0.0.0-20250403155104-27863c87afa6/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= +github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/safetext v0.0.0-20220905092116-b49f7bc46da2 h1:SJ+NtwL6QaZ21U+IrK7d0gGgpjGGvd2kz+FzTHVzdqI= github.com/google/safetext v0.0.0-20220905092116-b49f7bc46da2/go.mod h1:Tv1PlzqC9t8wNnpPdctvtSUOPUUg4SHeE6vR1Ir2hmg= 
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= @@ -620,7 +630,11 @@ github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw= +github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= +github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrkurSS/Q= +github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA= github.com/googleapis/gnostic v0.0.0-20170729233727-0c5108395e2d/go.mod h1:sJBsCZ4ayReDTBIg8b9dl28c5xFWyhBTVRp3pOg5EKY= github.com/googleapis/gnostic v0.1.0/go.mod h1:sJBsCZ4ayReDTBIg8b9dl28c5xFWyhBTVRp3pOg5EKY= github.com/googleapis/gnostic v0.3.1/go.mod h1:on+2t9HRStVgn95RSsFWFz+6Q0Snyqv1awfrALZdbtU= @@ -1527,9 +1541,13 @@ gomodules.xyz/jsonpatch/v2 v2.0.1/go.mod h1:IhYNNY4jnS53ZnfE4PAmpKtDpTCj1JFXc+3m gomodules.xyz/jsonpatch/v2 v2.5.0 h1:JELs8RLM12qJGXU4u/TO3V25KW8GreMKl9pdkk14RM0= gomodules.xyz/jsonpatch/v2 v2.5.0/go.mod h1:AH3dM2RI6uoBZxn3LVrfvJ3E0/9dG4cSrbuBJT4moAY= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= +google.golang.org/api v0.223.0 h1:JUTaWEriXmEy5AhvdMgksGGPEFsYfUKaPEYXd4c3Wvc= +google.golang.org/api v0.223.0/go.mod h1:C+RS7Z+dDwds2b+zoAk5hN/eSfsiCn0UDrYof/M4d2M= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine 
v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genai v1.13.0 h1:LRhwx5PU+bXhfnXyPEHu2kt9yc+MpvuYbajxSorOJjg= +google.golang.org/genai v1.13.0/go.mod h1:QPj5NGJw+3wEOHg+PrsWwJKvG6UC84ex5FR7qAYsN/M= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= diff --git a/internal/apischema/gcp/gcp.go b/internal/apischema/gcp/gcp.go new file mode 100644 index 0000000000..f5247796bd --- /dev/null +++ b/internal/apischema/gcp/gcp.go @@ -0,0 +1,34 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +package gcp + +import "google.golang.org/genai" + +type GenerateContentRequest struct { + // Contains the multipart content of a message. + // + // https://github.com/googleapis/go-genai/blob/6a8184fcaf8bf15f0c566616a7b356560309be9b/types.go#L858 + Contents []genai.Content `json:"contents"` + // Tool details of a tool that the model may use to generate a response. + // + // https://github.com/googleapis/go-genai/blob/6a8184fcaf8bf15f0c566616a7b356560309be9b/types.go#L1406 + Tools []genai.Tool `json:"tools"` + // Optional. Tool config. + // This config is shared for all tools provided in the request. + // + // https://github.com/googleapis/go-genai/blob/6a8184fcaf8bf15f0c566616a7b356560309be9b/types.go#L1466 + ToolConfig *genai.ToolConfig `json:"tool_config,omitempty"` + // Optional. Generation config. 
+ // You can find API default values and more details at https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#generationconfig + // and https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/content-generation-parameters. + GenerationConfig *genai.GenerationConfig `json:"generation_config,omitempty"` + // Optional. Instructions for the model to steer it toward better performance. + // For example, "Answer as concisely as possible" or "Don't use technical + // terms in your response". + // + // https://github.com/googleapis/go-genai/blob/6a8184fcaf8bf15f0c566616a7b356560309be9b/types.go#L858 + SystemInstruction *genai.Content `json:"system_instruction,omitempty"` +} diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index d56f1bc954..b78bcc807d 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -154,6 +154,26 @@ func (c *BackendSecurityPolicyController) rotateCredential(ctx context.Context, if err != nil { return ctrl.Result{}, err } + case aigv1a1.BackendSecurityPolicyTypeGCPCredentials: + if err = validateGCPCredentialsParams(bsp.Spec.GCPCredentials); err != nil { + return ctrl.Result{}, fmt.Errorf("invalid GCP credentials configuration: %w", err) + } + + // For GCP, OIDC is currently the only supported authentication method. + // If additional methods are added, validate that OIDC is used before calling getBackendSecurityPolicyAuthOIDC. + oidc := getBackendSecurityPolicyAuthOIDC(bsp.Spec) + + // Create the OIDC token provider that will be used to get tokens from the OIDC provider. 
+ var oidcProvider tokenprovider.TokenProvider + oidcProvider, err = tokenprovider.NewOidcTokenProvider(ctx, c.client, oidc) + if err != nil { + return ctrl.Result{}, fmt.Errorf("failed to initialize OIDC provider: %w", err) + } + rotator, err = rotators.NewGCPOIDCTokenRotator(c.client, c.logger, *bsp, preRotationWindow, oidcProvider) + if err != nil { + return ctrl.Result{}, err + } + default: err = fmt.Errorf("backend security type %s does not support OIDC token exchange", bsp.Spec.Type) c.logger.Error(err, "unsupported backend security type", "namespace", bsp.Namespace, "name", bsp.Name) @@ -207,6 +227,10 @@ func getBackendSecurityPolicyAuthOIDC(spec aigv1a1.BackendSecurityPolicySpec) *e return &spec.AzureCredentials.OIDCExchangeToken.OIDC } return nil + case aigv1a1.BackendSecurityPolicyTypeGCPCredentials: + if spec.GCPCredentials != nil { + return &spec.GCPCredentials.WorkLoadIdentityFederationConfig.WorkloadIdentityProvider.OIDCProvider.OIDC + } } return nil } @@ -238,3 +262,28 @@ func (c *BackendSecurityPolicyController) updateBackendSecurityPolicyStatus(ctx c.logger.Error(err, "failed to update BackendSecurityPolicy status") } } + +func validateGCPCredentialsParams(gcpCreds *aigv1a1.BackendSecurityPolicyGCPCredentials) error { + if gcpCreds == nil { + return fmt.Errorf("invalid backend security policy, gcp credentials cannot be nil") + } + if gcpCreds.ProjectName == "" { + return fmt.Errorf("invalid GCP credentials configuration: projectName cannot be empty") + } + if gcpCreds.Region == "" { + return fmt.Errorf("invalid GCP credentials configuration: region cannot be empty") + } + + wifConfig := gcpCreds.WorkLoadIdentityFederationConfig + if wifConfig.ProjectID == "" { + return fmt.Errorf("invalid GCP Workload Identity Federation configuration: projectID cannot be empty") + } + if wifConfig.WorkloadIdentityPoolName == "" { + return fmt.Errorf("invalid GCP Workload Identity Federation configuration: workloadIdentityPoolName cannot be empty") + } + if 
wifConfig.WorkloadIdentityProvider.Name == "" { + return fmt.Errorf("invalid GCP Workload Identity Federation configuration: workloadIdentityProvider.name cannot be empty") + } + + return nil +} diff --git a/internal/controller/backend_security_policy_test.go b/internal/controller/backend_security_policy_test.go index d155f693e1..a778f34fac 100644 --- a/internal/controller/backend_security_policy_test.go +++ b/internal/controller/backend_security_policy_test.go @@ -11,6 +11,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -396,6 +397,30 @@ func TestBackendSecurityPolicyController_GetBackendSecurityPolicyAuthOIDC(t *tes }) require.NotNil(t, oidcAWS) require.Equal(t, "some-client-id", oidcAWS.ClientID) + + // GCP type with OIDC defined. + oidcGCP := getBackendSecurityPolicyAuthOIDC(aigv1a1.BackendSecurityPolicySpec{ + Type: aigv1a1.BackendSecurityPolicyTypeGCPCredentials, + GCPCredentials: &aigv1a1.BackendSecurityPolicyGCPCredentials{ + ProjectName: "fake-project-name", + Region: "fake-region", + WorkLoadIdentityFederationConfig: aigv1a1.GCPWorkLoadIdentityFederationConfig{ + ProjectID: "fake-project-id", + WorkloadIdentityProvider: aigv1a1.GCPWorkloadIdentityProvider{ + Name: "fake-workload-identity-provider-name", + OIDCProvider: aigv1a1.BackendSecurityPolicyOIDC{ + OIDC: egv1a1.OIDC{ + ClientID: "some-client-id", + }, + }, + }, + WorkloadIdentityPoolName: "fake-workload-identity-pool-name", + ServiceAccountImpersonation: nil, + }, + }, + }) + require.NotNil(t, oidcGCP) + require.Equal(t, "some-client-id", oidcGCP.ClientID) } func TestNewBackendSecurityPolicyController_ReconcileAzureMissingSecret(t *testing.T) { @@ -645,3 +670,162 @@ func TestBackendSecurityPolicyController_ExecutionRotation(t *testing.T) { require.NoError(t, err) require.Less(t, res.RequeueAfter, time.Hour) } + +func TestValidateGCPCredentialsParams(t *testing.T) { + tests := []struct { + name string + input *aigv1a1.BackendSecurityPolicyGCPCredentials + wantError 
string + }{ + { + name: "nil credentials", + input: nil, + wantError: "invalid backend security policy, gcp credentials cannot be nil", + }, + { + name: "empty project name", + input: &aigv1a1.BackendSecurityPolicyGCPCredentials{ + ProjectName: "", + Region: "us-central1", + WorkLoadIdentityFederationConfig: aigv1a1.GCPWorkLoadIdentityFederationConfig{ + ProjectID: "pid", + WorkloadIdentityPoolName: "pool", + WorkloadIdentityProvider: aigv1a1.GCPWorkloadIdentityProvider{Name: "provider"}, + }, + }, + wantError: "invalid GCP credentials configuration: projectName cannot be empty", + }, + { + name: "empty region", + input: &aigv1a1.BackendSecurityPolicyGCPCredentials{ + ProjectName: "proj", + Region: "", + WorkLoadIdentityFederationConfig: aigv1a1.GCPWorkLoadIdentityFederationConfig{ + ProjectID: "pid", + WorkloadIdentityPoolName: "pool", + WorkloadIdentityProvider: aigv1a1.GCPWorkloadIdentityProvider{Name: "provider"}, + }, + }, + wantError: "invalid GCP credentials configuration: region cannot be empty", + }, + { + name: "empty projectID", + input: &aigv1a1.BackendSecurityPolicyGCPCredentials{ + ProjectName: "proj", + Region: "us-central1", + WorkLoadIdentityFederationConfig: aigv1a1.GCPWorkLoadIdentityFederationConfig{ + ProjectID: "", + WorkloadIdentityPoolName: "pool", + WorkloadIdentityProvider: aigv1a1.GCPWorkloadIdentityProvider{Name: "provider"}, + }, + }, + wantError: "invalid GCP Workload Identity Federation configuration: projectID cannot be empty", + }, + { + name: "empty workloadIdentityPoolName", + input: &aigv1a1.BackendSecurityPolicyGCPCredentials{ + ProjectName: "proj", + Region: "us-central1", + WorkLoadIdentityFederationConfig: aigv1a1.GCPWorkLoadIdentityFederationConfig{ + ProjectID: "pid", + WorkloadIdentityPoolName: "", + WorkloadIdentityProvider: aigv1a1.GCPWorkloadIdentityProvider{Name: "provider"}, + }, + }, + wantError: "invalid GCP Workload Identity Federation configuration: workloadIdentityPoolName cannot be empty", + }, + { + name: 
"empty workloadIdentityProvider name", + input: &aigv1a1.BackendSecurityPolicyGCPCredentials{ + ProjectName: "proj", + Region: "us-central1", + WorkLoadIdentityFederationConfig: aigv1a1.GCPWorkLoadIdentityFederationConfig{ + ProjectID: "pid", + WorkloadIdentityPoolName: "pool", + WorkloadIdentityProvider: aigv1a1.GCPWorkloadIdentityProvider{Name: ""}, + }, + }, + wantError: "invalid GCP Workload Identity Federation configuration: workloadIdentityProvider.name cannot be empty", + }, + { + name: "valid credentials", + input: &aigv1a1.BackendSecurityPolicyGCPCredentials{ + ProjectName: "proj", + Region: "us-central1", + WorkLoadIdentityFederationConfig: aigv1a1.GCPWorkLoadIdentityFederationConfig{ + ProjectID: "pid", + WorkloadIdentityPoolName: "pool", + WorkloadIdentityProvider: aigv1a1.GCPWorkloadIdentityProvider{Name: "provider"}, + }, + }, + wantError: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateGCPCredentialsParams(tt.input) + if tt.wantError == "" { + require.NoError(t, err) + } else { + require.Error(t, err) + require.Equal(t, tt.wantError, err.Error()) + } + }) + } +} + +func TestBackendSecurityPolicyController_RotateCredential_GCPCredentials(t *testing.T) { + bspNamespace := "default" + bspName := "test-gcp-policy" + tests := []struct { + name string + bsp *aigv1a1.BackendSecurityPolicySpec + expectedErrMsg string + }{ + { + name: "nil gcp credentials", + bsp: &aigv1a1.BackendSecurityPolicySpec{ + Type: aigv1a1.BackendSecurityPolicyTypeGCPCredentials, + GCPCredentials: nil, + }, + expectedErrMsg: "invalid backend security policy, gcp credentials cannot be nil", + }, + { + name: "empty gcp credentials", + bsp: &aigv1a1.BackendSecurityPolicySpec{ + Type: aigv1a1.BackendSecurityPolicyTypeGCPCredentials, + GCPCredentials: &aigv1a1.BackendSecurityPolicyGCPCredentials{}, + }, + expectedErrMsg: "invalid GCP credentials configuration: projectName cannot be empty", + }, + } + + c := 
NewBackendSecurityPolicyController(fake.NewFakeClient(), fake2.NewClientset(), ctrl.Log, nil) + + for _, tt := range tests { + bsp := &aigv1a1.BackendSecurityPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: bspName, + Namespace: bspNamespace, + }, + Spec: *tt.bsp, + } + t.Run(tt.name, func(t *testing.T) { + // Both table cases carry invalid GCP credentials, so rotation must fail during validation. + res, err := c.rotateCredential(context.Background(), bsp) + + switch { + case tt.expectedErrMsg != "" && err == nil: + t.Errorf("expected error but got none, expected: %s", tt.expectedErrMsg) + case tt.expectedErrMsg == "" && err != nil: + t.Errorf("unexpected error: %v", err) + case tt.expectedErrMsg != "" && err != nil: + require.True(t, strings.Contains(err.Error(), tt.expectedErrMsg), "error %q should contain %q", err.Error(), tt.expectedErrMsg) + default: + require.NoError(t, err) + require.NotZero(t, res.RequeueAfter) + } + }) + } +} diff --git a/internal/controller/gateway.go b/internal/controller/gateway.go index 7fc387e814..40ea5ccd54 100644 --- a/internal/controller/gateway.go +++ b/internal/controller/gateway.go @@ -367,6 +367,23 @@ func (c *GatewayController) bspToFilterAPIBackendAuth(ctx context.Context, names return &filterapi.BackendAuth{ AzureAuth: &filterapi.AzureAuth{AccessToken: azureAccessToken}, }, nil + case aigv1a1.BackendSecurityPolicyTypeGCPCredentials: + gcpCreds := backendSecurityPolicy.Spec.GCPCredentials + if gcpCreds == nil { + return nil, fmt.Errorf("GCP credentials type selected but not defined %s", backendSecurityPolicy.Name) + } + secretName := rotators.GetBSPSecretName(backendSecurityPolicy.Name) + gcpAccessToken, err := c.getSecretData(ctx, namespace, secretName, rotators.GCPAccessTokenKey) + if err != nil { + return nil, fmt.Errorf("failed to get secret %s: %w", secretName, err) + } + return &filterapi.BackendAuth{ + GCPAuth: &filterapi.GCPAuth{ + AccessToken: gcpAccessToken, + Region: gcpCreds.Region, + ProjectName: gcpCreds.ProjectName, + }, + }, nil default: return nil, fmt.Errorf("invalid backend security type %s for policy %s", 
backendSecurityPolicy.Spec.Type, backendSecurityPolicy.Name) diff --git a/internal/controller/gateway_test.go b/internal/controller/gateway_test.go index 21de3461e3..c4f37c1e12 100644 --- a/internal/controller/gateway_test.go +++ b/internal/controller/gateway_test.go @@ -6,6 +6,7 @@ package controller import ( + "context" "fmt" "strconv" "testing" @@ -317,6 +318,115 @@ func TestGatewayController_bspToFilterAPIBackendAuth(t *testing.T) { } } +func TestGatewayController_bspToFilterAPIBackendAuth_ErrorCases(t *testing.T) { + fakeClient := requireNewFakeClientWithIndexes(t) + c := NewGatewayController(fakeClient, fake2.NewClientset(), ctrl.Log, + "envoy-gateway-system", "/foo/bar/uds.sock", "docker.io/envoyproxy/ai-gateway-extproc:latest") + + ctx := context.Background() + namespace := "test-namespace" + + tests := []struct { + name string + bspName string + setupBSP *aigv1a1.BackendSecurityPolicy + setupSecret *corev1.Secret + expectedError string + }{ + { + name: "missing backend security policy", + bspName: "missing-bsp", + expectedError: "failed to get BackendSecurityPolicy missing-bsp", + }, + { + name: "api key type with missing secret", + bspName: "api-key-bsp", + setupBSP: &aigv1a1.BackendSecurityPolicy{ + ObjectMeta: metav1.ObjectMeta{Name: "api-key-bsp", Namespace: namespace}, + Spec: aigv1a1.BackendSecurityPolicySpec{ + Type: aigv1a1.BackendSecurityPolicyTypeAPIKey, + APIKey: &aigv1a1.BackendSecurityPolicyAPIKey{ + SecretRef: &gwapiv1.SecretObjectReference{ + Name: "missing-secret", + }, + }, + }, + }, + expectedError: "failed to get secret missing-secret", + }, + { + name: "aws credentials without credentials defined", + bspName: "aws-no-creds-bsp", + setupBSP: &aigv1a1.BackendSecurityPolicy{ + ObjectMeta: metav1.ObjectMeta{Name: "aws-no-creds-bsp", Namespace: namespace}, + Spec: aigv1a1.BackendSecurityPolicySpec{ + Type: aigv1a1.BackendSecurityPolicyTypeAWSCredentials, + AWSCredentials: nil, // This should trigger the error + }, + }, + expectedError: 
"AWSCredentials type selected but not defined", + }, + { + name: "aws credentials with credentials file missing secret", + bspName: "aws-creds-file-bsp", + setupBSP: &aigv1a1.BackendSecurityPolicy{ + ObjectMeta: metav1.ObjectMeta{Name: "aws-creds-file-bsp", Namespace: namespace}, + Spec: aigv1a1.BackendSecurityPolicySpec{ + Type: aigv1a1.BackendSecurityPolicyTypeAWSCredentials, + AWSCredentials: &aigv1a1.BackendSecurityPolicyAWSCredentials{ + Region: "us-west-2", + CredentialsFile: &aigv1a1.AWSCredentialsFile{ + SecretRef: &gwapiv1.SecretObjectReference{ + Name: "missing-aws-secret", + }, + }, + }, + }, + }, + expectedError: "failed to get secret missing-aws-secret", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup BSP if provided + if tt.setupBSP != nil { + err := fakeClient.Create(ctx, tt.setupBSP) + require.NoError(t, err) + } + + // Setup secret if provided + if tt.setupSecret != nil { + err := fakeClient.Create(ctx, tt.setupSecret) + require.NoError(t, err) + } + + // Call the function + result, err := c.bspToFilterAPIBackendAuth(ctx, namespace, tt.bspName) + + // Verify expected error + require.Error(t, err) + require.Contains(t, err.Error(), tt.expectedError) + require.Nil(t, result) + }) + } +} + +func TestGatewayController_GetSecretData_ErrorCases(t *testing.T) { + fakeClient := requireNewFakeClientWithIndexes(t) + c := NewGatewayController(fakeClient, fake2.NewClientset(), ctrl.Log, + "envoy-gateway-system", "/foo/bar/uds.sock", "docker.io/envoyproxy/ai-gateway-extproc:latest") + + ctx := context.Background() + namespace := "test-namespace" + + // Test missing secret + result, err := c.getSecretData(ctx, namespace, "missing-secret", "test-key") + require.Error(t, err) + require.Contains(t, err.Error(), "secrets \"missing-secret\" not found") + require.Empty(t, result) +} + func TestGatewayController_annotateGatewayPods(t *testing.T) { egNamespace := "envoy-gateway-system" gwName, gwNamepsace := "gw", "ns" diff --git 
a/internal/controller/rotators/gcp_oidc_token_rotator.go b/internal/controller/rotators/gcp_oidc_token_rotator.go new file mode 100644 index 0000000000..d88605e6e0 --- /dev/null +++ b/internal/controller/rotators/gcp_oidc_token_rotator.go @@ -0,0 +1,370 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +package rotators + +import ( + "context" + "fmt" + "net/http" + "net/url" + "os" + "time" + + "github.com/go-logr/logr" + "golang.org/x/oauth2" + "google.golang.org/api/impersonate" + "google.golang.org/api/option" + "google.golang.org/api/sts/v1" + corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/client" + + aigv1a1 "github.com/envoyproxy/ai-gateway/api/v1alpha1" + "github.com/envoyproxy/ai-gateway/filterapi" + "github.com/envoyproxy/ai-gateway/internal/controller/tokenprovider" +) + +const ( + // GCPAccessTokenKey is the key used to store GCP access token in Kubernetes secrets. + GCPAccessTokenKey = "gcpAccessToken" + GCPProjectNameKey = "projectName" + GCPRegionKey = "region" + // grantTypeTokenExchange is the OAuth 2.0 grant type for token exchange. + grantTypeTokenExchange = "urn:ietf:params:oauth:grant-type:token-exchange" //nolint:gosec + // gcpIAMScope is the OAuth scope for IAM operations in GCP. + gcpIAMScope = "https://www.googleapis.com/auth/iam" //nolint:gosec + // tokenTypeAccessToken indicates the requested token type is an access token. + tokenTypeAccessToken = "urn:ietf:params:oauth:token-type:access_token" //nolint:gosec + // tokenTypeJWT indicates the subject token type is a JWT. + tokenTypeJWT = "urn:ietf:params:oauth:token-type:jwt" //nolint:gosec + // stsTokenScope is the OAuth scope for GCP cloud platform operations. 
+ stsTokenScope = "https://www.googleapis.com/auth/cloud-platform" //nolint:gosec +) + +// serviceAccountTokenGenerator defines a function type for generating a GCP service account access token +// using an STS token and impersonation configuration. +type serviceAccountTokenGenerator func( + ctx context.Context, + stsToken string, + saConfig aigv1a1.GCPServiceAccountImpersonationConfig, + opts ...option.ClientOption, +) (*tokenprovider.TokenExpiry, error) + +// stsTokenGenerator defines a function type for exchanging a JWT token for a GCP STS token +// using Workload Identity Federation configuration. +type stsTokenGenerator func( + ctx context.Context, + jwtToken string, + wifConfig aigv1a1.GCPWorkLoadIdentityFederationConfig, + opts ...option.ClientOption, +) (*tokenprovider.TokenExpiry, error) + +// gcpOIDCTokenRotator implements Rotator interface for GCP access token exchange. +// It handles the complete authentication flow for GCP Workload Identity Federation: +// 1. Obtaining an OIDC token from the configured provider +// 2. Exchanging the OIDC token for a GCP STS token +// 3. Using the STS token to impersonate a GCP service account +// 4. Storing the resulting access token in a Kubernetes secret +type gcpOIDCTokenRotator struct { + client client.Client // Kubernetes client for interacting with the cluster. + logger logr.Logger // Logger for recording rotator activities + // GCP Credentials configuration from BackendSecurityPolicy + gcpCredentials aigv1a1.BackendSecurityPolicyGCPCredentials + // backendSecurityPolicyName provides name of backend security policy. + backendSecurityPolicyName string + // backendSecurityPolicyNamespace provides namespace of backend security policy. 
+ backendSecurityPolicyNamespace string + // preRotationWindow is the duration before token expiry when rotation should occur + preRotationWindow time.Duration + // oidcProvider provides the OIDC token needed for GCP Workload Identity Federation + oidcProvider tokenprovider.TokenProvider + + saTokenFunc serviceAccountTokenGenerator + stsTokenFunc stsTokenGenerator +} + +// NewGCPOIDCTokenRotator creates a new gcpOIDCTokenRotator with the given parameters. +func NewGCPOIDCTokenRotator( + client client.Client, + logger logr.Logger, + bsp aigv1a1.BackendSecurityPolicy, + preRotationWindow time.Duration, + tokenProvider tokenprovider.TokenProvider, +) (Rotator, error) { + logger = logger.WithName("gcp-token-rotator") + + if bsp.Spec.GCPCredentials == nil { + return nil, fmt.Errorf("GCP credentials are not configured in BackendSecurityPolicy %s/%s", bsp.Namespace, bsp.Name) + } + + return &gcpOIDCTokenRotator{ + client: client, + logger: logger, + gcpCredentials: *bsp.Spec.GCPCredentials, + backendSecurityPolicyName: bsp.Name, + backendSecurityPolicyNamespace: bsp.Namespace, + preRotationWindow: preRotationWindow, + oidcProvider: tokenProvider, + saTokenFunc: impersonateServiceAccount, + stsTokenFunc: exchangeJWTForSTSToken, + }, nil +} + +// IsExpired implements [Rotator.IsExpired]. +// IsExpired checks if the preRotation time is before the current time. +func (r *gcpOIDCTokenRotator) IsExpired(preRotationExpirationTime time.Time) bool { + // Use the common IsBufferedTimeExpired helper to determine if the token has expired + // A buffer of 0 means we check exactly at the pre-rotation time + return IsBufferedTimeExpired(0, preRotationExpirationTime) +} + +// GetPreRotationTime implements [Rotator.GetPreRotationTime]. +// GetPreRotationTime retrieves the pre-rotation time for GCP token. 
+func (r *gcpOIDCTokenRotator) GetPreRotationTime(ctx context.Context) (time.Time, error) { + // Look up the secret containing the current token + secret, err := LookupSecret(ctx, r.client, r.backendSecurityPolicyNamespace, GetBSPSecretName(r.backendSecurityPolicyName)) + if err != nil { + if apierrors.IsNotFound(err) { + // If the secret doesn't exist, return zero time to indicate immediate rotation is needed + return time.Time{}, nil + } + return time.Time{}, fmt.Errorf("failed to lookup secret: %w", err) + } + // Extract the token expiration time from the secret's annotations + expirationTime, err := GetExpirationSecretAnnotation(secret) + if err != nil { + return time.Time{}, fmt.Errorf("failed to get expiration time from secret: %w", err) + } + + // Calculate the pre-rotation time by subtracting the pre-rotation window from the expiration time + // This ensures tokens are rotated before they expire + preRotationTime := expirationTime.Add(-r.preRotationWindow) + return preRotationTime, nil +} + +// Rotate implements [Rotator.Rotate]. +// Rotate fetches new GCP access token and updates the Kubernetes secret. +// The token rotation process follows these steps: +// 1. Obtain an OIDC token from the configured provider +// 2. Exchange the OIDC token for a GCP STS token +// 3. (If configured) Use the STS token to impersonate the specified GCP service account +// 4. Store the resulting access token in a Kubernetes secret +// Returns the expiration time of the new token and any error encountered during rotation. +func (r *gcpOIDCTokenRotator) Rotate(ctx context.Context) (time.Time, error) { + secretName := GetBSPSecretName(r.backendSecurityPolicyName) + + r.logger.Info("start rotating gcp access token", "namespace", r.backendSecurityPolicyNamespace, "name", r.backendSecurityPolicyName) + + // 1. 
Get OIDCProvider Token + // This is the initial token from the configured OIDC provider (e.g., Kubernetes service account token) + oidcTokenExpiry, err := r.oidcProvider.GetToken(ctx) + if err != nil { + r.logger.Error(err, "failed to get token from oidc provider", "oidcIssuer", r.gcpCredentials.WorkLoadIdentityFederationConfig.WorkloadIdentityProvider.Name) + return time.Time{}, fmt.Errorf("failed to obtain OIDC token: %w", err) + } + + // 2. Exchange the JWT for an STS token. + // The OIDC JWT token is exchanged for a Google Cloud STS token + stsToken, err := r.stsTokenFunc(ctx, oidcTokenExpiry.Token, r.gcpCredentials.WorkLoadIdentityFederationConfig) + if err != nil { + wifConfig := r.gcpCredentials.WorkLoadIdentityFederationConfig + r.logger.Error(err, "failed to exchange JWT for STS token", + "projectID", wifConfig.ProjectID, + "workloadIdentityPool", wifConfig.WorkloadIdentityPoolName, + "workloadIdentityProvider", wifConfig.WorkloadIdentityProvider.Name) + return time.Time{}, fmt.Errorf("failed to exchange JWT for STS token (project: %s, pool: %s): %w", + wifConfig.ProjectID, wifConfig.WorkloadIdentityPoolName, err) + } + + // 3. Exchange the STS token for a GCP service account access token. 
+ // The STS token is used to impersonate a GCP service account + var gcpAccessToken *tokenprovider.TokenExpiry + if r.gcpCredentials.WorkLoadIdentityFederationConfig.ServiceAccountImpersonation != nil { + gcpAccessToken, err = r.saTokenFunc(ctx, stsToken.Token, *r.gcpCredentials.WorkLoadIdentityFederationConfig.ServiceAccountImpersonation) + if err != nil { + saImpersonation := r.gcpCredentials.WorkLoadIdentityFederationConfig.ServiceAccountImpersonation + saEmail := fmt.Sprintf("%s@%s.iam.gserviceaccount.com", + saImpersonation.ServiceAccountName, + saImpersonation.ServiceAccountProjectName) + r.logger.Error(err, "failed to impersonate GCP service account", + "serviceAccount", saEmail, + "serviceAccountProject", saImpersonation.ServiceAccountProjectName) + return time.Time{}, fmt.Errorf("failed to impersonate service account %s: %w", saEmail, err) + } + } else { + // If no service account impersonation is configured, use the STS token directly + gcpAccessToken = stsToken + } + + secret, err := LookupSecret(ctx, r.client, r.backendSecurityPolicyNamespace, secretName) + if err != nil { + if apierrors.IsNotFound(err) { + r.logger.Info("creating a new gcp access token into secret", "namespace", r.backendSecurityPolicyNamespace, "name", r.backendSecurityPolicyName) + secret = &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: secretName, + Namespace: r.backendSecurityPolicyNamespace, + }, + Type: corev1.SecretTypeOpaque, + Data: make(map[string][]byte), + } + populateInSecret(secret, filterapi.GCPAuth{ + AccessToken: gcpAccessToken.Token, + Region: r.gcpCredentials.Region, + ProjectName: r.gcpCredentials.ProjectName, + }, gcpAccessToken.ExpiresAt) + err = r.client.Create(ctx, secret) + if err != nil { + r.logger.Error(err, "failed to create gcp access token", "namespace", r.backendSecurityPolicyNamespace, "name", r.backendSecurityPolicyName) + return time.Time{}, err + } + return gcpAccessToken.ExpiresAt, nil + } + r.logger.Error(err, "failed to lookup gcp 
access token secret", "namespace", r.backendSecurityPolicyNamespace, "name", r.backendSecurityPolicyName) + return time.Time{}, err + } + r.logger.Info("updating gcp access token secret", "namespace", r.backendSecurityPolicyNamespace, "name", r.backendSecurityPolicyName) + + populateInSecret(secret, filterapi.GCPAuth{ + AccessToken: gcpAccessToken.Token, + Region: r.gcpCredentials.Region, + ProjectName: r.gcpCredentials.ProjectName, + }, gcpAccessToken.ExpiresAt) + err = r.client.Update(ctx, secret) + if err != nil { + r.logger.Error(err, "failed to update gcp access token", "namespace", r.backendSecurityPolicyNamespace, "name", r.backendSecurityPolicyName) + return time.Time{}, err + } + return gcpAccessToken.ExpiresAt, nil +} + +var _ stsTokenGenerator = exchangeJWTForSTSToken + +// exchangeJWTForSTSToken implements [stsTokenGenerator] +// exchangeJWTForSTSToken exchanges a JWT token for a GCP STS (Security Token Service) token. +func exchangeJWTForSTSToken(ctx context.Context, jwtToken string, wifConfig aigv1a1.GCPWorkLoadIdentityFederationConfig, opts ...option.ClientOption) (*tokenprovider.TokenExpiry, error) { + proxyOpt, err := getGCPProxyClientOption() + if err != nil { + return nil, fmt.Errorf("error getting GCP proxy client option: %w", err) + } + if proxyOpt != nil { + opts = append(opts, proxyOpt) + } + + opts = append(opts, option.WithoutAuthentication()) + + // Create an STS client. + stsService, err := sts.NewService(ctx, opts...) + if err != nil { + return nil, fmt.Errorf("error creating GCP STS service client: %w", err) + } + // Construct the STS request. 
+ // Build the audience string in the format required by GCP Workload Identity Federation + stsAudience := fmt.Sprintf("//iam.googleapis.com/projects/%s/locations/global/workloadIdentityPools/%s/providers/%s", + wifConfig.ProjectID, + wifConfig.WorkloadIdentityPoolName, + wifConfig.WorkloadIdentityProvider.Name) + + // Create the token exchange request with the appropriate parameters + req := &sts.GoogleIdentityStsV1ExchangeTokenRequest{ + GrantType: grantTypeTokenExchange, + Audience: stsAudience, + Scope: gcpIAMScope, + RequestedTokenType: tokenTypeAccessToken, + SubjectToken: jwtToken, + SubjectTokenType: tokenTypeJWT, + } + + // Call the STS API. + resp, err := stsService.V1.Token(req).Do() + if err != nil { + return nil, fmt.Errorf("error calling GCP STS Token API with audience %s: %w", stsAudience, err) + } + + return &tokenprovider.TokenExpiry{ + Token: resp.AccessToken, + ExpiresAt: time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second), + }, nil +} + +var _ serviceAccountTokenGenerator = impersonateServiceAccount + +// impersonateServiceAccount returns a GCP service account access token or an error if impersonation fails. +// It takes an STS token and uses it to impersonate a GCP service account, +// generating a new access token with the permissions of that service account. +// +// The service account email is constructed from serviceAccountName and serviceAccountProjectName +// in the format: @.iam.gserviceaccount.com +// +// The resulting token will have the cloud-platform scope. +func impersonateServiceAccount(ctx context.Context, stsToken string, saConfig aigv1a1.GCPServiceAccountImpersonationConfig, opts ...option.ClientOption) (*tokenprovider.TokenExpiry, error) { + // Construct the service account email from the configured parameters + saEmail := fmt.Sprintf("%s@%s.iam.gserviceaccount.com", saConfig.ServiceAccountName, saConfig.ServiceAccountProjectName) + + // Configure the impersonation parameters. 
+ // Define which service account to impersonate and what scopes the token should have + config := impersonate.CredentialsConfig{ + TargetPrincipal: saEmail, // The service account to impersonate. + Scopes: []string{stsTokenScope}, // The desired scopes for the access token. + } + + // Use the STS token as the source token for impersonation + opts = append(opts, option.WithTokenSource(oauth2.StaticTokenSource(&oauth2.Token{AccessToken: stsToken, TokenType: "Bearer"}))) + + // If a proxy URL is set, add it as a client option + proxyOpt, err := getGCPProxyClientOption() + if err != nil { + return nil, fmt.Errorf("error getting GCP proxy client option: %w", err) + } + + if proxyOpt != nil { + opts = append(opts, proxyOpt) + } + + // Create a token source that will provide tokens with the permissions of the impersonated service account + ts, err := impersonate.CredentialsTokenSource(ctx, config, opts...) + if err != nil { + return nil, fmt.Errorf("error creating impersonated credentials for service account %s: %w", saEmail, err) + } + + // Get the token. + token, err := ts.Token() + if err != nil { + return nil, fmt.Errorf("error getting access token for service account %s: %w", saEmail, err) + } + return &tokenprovider.TokenExpiry{ + Token: token.AccessToken, + ExpiresAt: token.Expiry, + }, nil +} + +// populateAzureAccessToken updates the secret with the Azure access token. 
+func populateInSecret(secret *corev1.Secret, gcpAuth filterapi.GCPAuth, expiryTime time.Time) { + updateExpirationSecretAnnotation(secret, expiryTime) + secret.Data = map[string][]byte{ + GCPAccessTokenKey: []byte(gcpAuth.AccessToken), + GCPProjectNameKey: []byte(gcpAuth.ProjectName), + GCPRegionKey: []byte(gcpAuth.Region), + } +} + +func getGCPProxyClientOption() (option.ClientOption, error) { + proxyURL := os.Getenv("AI_GATEWAY_GCP_AUTH_PROXY_URL") + if proxyURL == "" { + return nil, nil + } + + parsedURL, err := url.Parse(proxyURL) + if err != nil { + return nil, fmt.Errorf("invalid proxy URL: %w", err) + } + transport := &http.Transport{ + Proxy: http.ProxyURL(parsedURL), + } + httpClient := &http.Client{Transport: transport} + return option.WithHTTPClient(httpClient), nil +} diff --git a/internal/controller/rotators/gcp_oidc_token_rotator_test.go b/internal/controller/rotators/gcp_oidc_token_rotator_test.go new file mode 100644 index 0000000000..ec47190c2a --- /dev/null +++ b/internal/controller/rotators/gcp_oidc_token_rotator_test.go @@ -0,0 +1,932 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. 

package rotators

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"testing"
	"time"

	egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1"
	"github.com/go-logr/logr"
	"github.com/google/go-cmp/cmp"
	"github.com/google/go-cmp/cmp/cmpopts"
	"github.com/stretchr/testify/require"
	"google.golang.org/api/option"
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/runtime"
	"sigs.k8s.io/controller-runtime/pkg/client"
	"sigs.k8s.io/controller-runtime/pkg/client/fake"

	aigv1a1 "github.com/envoyproxy/ai-gateway/api/v1alpha1"
	"github.com/envoyproxy/ai-gateway/internal/controller/tokenprovider"
)

// Placeholder values shared by the tests in this file.
const (
	dummyProjectName   = "dummy-project-name"   // #nosec G101
	dummyProjectRegion = "dummy-project-region" // #nosec G101
	dummyJWTToken      = "dummy-oidc-token"     // #nosec G101
	dummySTSToken      = "dummy-sts-token"      // #nosec G101
	oldGCPAccessToken  = "old-gcp-access-token" // #nosec G101
	newGCPAccessToken  = "new-gcp-access-token" // #nosec G101
)

// TestGCPTokenRotator_Rotate drives Rotate through its failure and success paths using stubbed
// STS / impersonation functions, a mock OIDC provider and fake (or error-injecting) Kubernetes clients.
// On success it checks both the returned expiration time and the contents of the stored secret.
func TestGCPTokenRotator_Rotate(t *testing.T) {
	scheme := runtime.NewScheme()
	scheme.AddKnownTypes(corev1.SchemeGroupVersion, &corev1.Secret{})

	// A fixed instant keeps the expected annotation values deterministic.
	now := time.Date(2021, 1, 1, 12, 0, 0, 0, time.UTC) // Fixed time for testing
	oneHourBeforeNow := now.Add(-1 * time.Hour)
	twoHourAfterNow := now.Add(2 * time.Hour)

	// oldSecret is the pre-existing secret holding an already expired token.
	oldSecret := &corev1.Secret{
		ObjectMeta: metav1.ObjectMeta{
			Name:      GetBSPSecretName("test-policy"),
			Namespace: "default",
			Annotations: map[string]string{
				ExpirationTimeAnnotationKey: oneHourBeforeNow.Format(time.RFC3339),
			},
		},
		Type: corev1.SecretTypeOpaque,
		Data: map[string][]byte{
			GCPProjectNameKey: []byte(dummyProjectName),
			GCPRegionKey:      []byte(dummyProjectRegion),
			GCPAccessTokenKey: []byte(oldGCPAccessToken),
		},
	}

	// renewedSecret is the secret expected after a successful rotation with impersonation.
	renewedSecret := &corev1.Secret{
		ObjectMeta: metav1.ObjectMeta{
			Name:      GetBSPSecretName("test-policy"),
			Namespace: "default",
			Annotations: map[string]string{
				ExpirationTimeAnnotationKey: twoHourAfterNow.Format(time.RFC3339),
			},
		},
		Type: corev1.SecretTypeOpaque,
		Data: map[string][]byte{
			GCPProjectNameKey: []byte(dummyProjectName),
			GCPRegionKey:      []byte(dummyProjectRegion),
			GCPAccessTokenKey: []byte(newGCPAccessToken),
		},
	}

	// Without impersonation the STS token itself is what ends up in the secret.
	renewedSecretWithoutSAImpersonation := renewedSecret.DeepCopy()
	renewedSecretWithoutSAImpersonation.Data[GCPAccessTokenKey] = []byte(dummySTSToken)

	tests := []struct {
		name                            string
		kubeInitObjects                 []runtime.Object
		saTokenFunc                     serviceAccountTokenGenerator
		stsTokenFunc                    stsTokenGenerator
		skipServiceAccountImpersonation bool
		expectedSecret                  *corev1.Secret
		expectErrorMsg                  string
		clientCreateFn                  func(t *testing.T) client.Client
	}{
		{
			name:            "failed to get sts token",
			kubeInitObjects: []runtime.Object{oldSecret},
			stsTokenFunc: func(_ context.Context, _ string, _ aigv1a1.GCPWorkLoadIdentityFederationConfig, _ ...option.ClientOption) (*tokenprovider.TokenExpiry, error) {
				return nil, fmt.Errorf("fake network failure")
			},
			expectErrorMsg: "failed to exchange JWT for STS token (project: test-project-id, pool: test-pool-name): fake network failure",
		},
		{
			// The OIDC provider error for this case is selected by name further down.
			name:            "failed to get OIDC token",
			kubeInitObjects: []runtime.Object{oldSecret},
			expectErrorMsg:  "failed to obtain OIDC token: oidc provider error",
		},
		{
			name:            "failed to impersonate service account",
			kubeInitObjects: []runtime.Object{oldSecret},
			saTokenFunc: func(_ context.Context, _ string, _ aigv1a1.GCPServiceAccountImpersonationConfig, _ ...option.ClientOption) (*tokenprovider.TokenExpiry, error) {
				return nil, fmt.Errorf("fake network failure")
			},
			expectErrorMsg: "failed to impersonate service account test-service-account@test-service-account-project-name.iam.gserviceaccount.com: fake network failure",
		},
		{
			// No pre-existing secret: Rotate is expected to create it.
			name:            "secret with old does not exist",
			kubeInitObjects: nil,
		},
		{
			name:            "secret with old token exists",
			kubeInitObjects: []runtime.Object{oldSecret},
			expectErrorMsg:  "",
		},
		{
			name:                            "without service account impersonation",
			kubeInitObjects:                 []runtime.Object{oldSecret},
			skipServiceAccountImpersonation: true,
			expectedSecret:                  renewedSecretWithoutSAImpersonation,
			expectErrorMsg:                  "",
		},
		{
			name:            "create error",
			kubeInitObjects: nil,
			clientCreateFn: func(t *testing.T) client.Client {
				// Create a fake client that returns an error on Create
				fc := fake.NewFakeClient()
				// Wrap the fake client to return an error on Create
				return &errorOnCreateClient{
					Client: fc,
					t:      t,
				}
			},
			expectErrorMsg: "create error",
		},
		{
			name:            "update error",
			kubeInitObjects: []runtime.Object{oldSecret},
			clientCreateFn: func(t *testing.T) client.Client {
				// Create a fake client that returns an error on Update
				fc := fake.NewFakeClient()
				// Wrap the fake client to return an error on Update
				return &errorOnUpdateClient{
					Client: fc,
					t:      t,
				}
			},
			expectErrorMsg: "update error",
		},
		{
			name:            "secret lookup error (non-NotFound)",
			kubeInitObjects: []runtime.Object{},
			clientCreateFn: func(t *testing.T) client.Client {
				// Create a fake client that returns an error on Get operations
				fc := fake.NewFakeClient()
				// Wrap the fake client to return an error on Get
				return &errorOnGetClient{
					Client: fc,
					t:      t,
				}
			},
			expectErrorMsg: "failed to get secret: lookup error",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var fakeClient client.Client

			// Use custom client if provided, otherwise use default fake client
			if tt.clientCreateFn != nil {
				fakeClient = tt.clientCreateFn(t)
				// Add initial objects to the custom client if needed
				for _, obj := range tt.kubeInitObjects {
					if clientObj, ok := obj.(client.Object); ok {
						err := fakeClient.Create(context.Background(), clientObj)
						require.NoError(t, err)
					}
				}
			} else {
				fakeClient = fake.NewFakeClient(tt.kubeInitObjects...)
			}

			// If no saTokenFunc or stsTokenFunc is provided, use the default mock functions.
			if tt.saTokenFunc == nil {
				tt.saTokenFunc = func(_ context.Context, _ string, _ aigv1a1.GCPServiceAccountImpersonationConfig, _ ...option.ClientOption) (*tokenprovider.TokenExpiry, error) {
					return &tokenprovider.TokenExpiry{Token: newGCPAccessToken, ExpiresAt: twoHourAfterNow}, nil
				}
			}
			if tt.stsTokenFunc == nil {
				tt.stsTokenFunc = func(_ context.Context, _ string, _ aigv1a1.GCPWorkLoadIdentityFederationConfig, _ ...option.ClientOption) (*tokenprovider.TokenExpiry, error) {
					return &tokenprovider.TokenExpiry{Token: dummySTSToken, ExpiresAt: twoHourAfterNow}, nil
				}
			}
			gcpCredentials := aigv1a1.BackendSecurityPolicyGCPCredentials{
				ProjectName: dummyProjectName,
				Region:      dummyProjectRegion,
				WorkLoadIdentityFederationConfig: aigv1a1.GCPWorkLoadIdentityFederationConfig{
					ProjectID:                "test-project-id",
					WorkloadIdentityProvider: aigv1a1.GCPWorkloadIdentityProvider{},
					WorkloadIdentityPoolName: "test-pool-name",
				},
			}
			if !tt.skipServiceAccountImpersonation {
				gcpCredentials.WorkLoadIdentityFederationConfig.ServiceAccountImpersonation = &aigv1a1.GCPServiceAccountImpersonationConfig{
					ServiceAccountName:        "test-service-account",
					ServiceAccountProjectName: "test-service-account-project-name",
				}
			}

			// The rotator is built directly (not via the constructor) so the token functions can be stubbed.
			rotator := &gcpOIDCTokenRotator{
				client:                         fakeClient,
				logger:                         logr.Logger{},
				gcpCredentials:                 gcpCredentials,
				backendSecurityPolicyName:      "test-policy",
				backendSecurityPolicyNamespace: "default",
				preRotationWindow:              5 * time.Minute,
				saTokenFunc:                    tt.saTokenFunc,
				stsTokenFunc:                   tt.stsTokenFunc,
			}

			// Set up OIDC provider based on test case
			// NOTE(review): the failing provider is keyed on the literal test name, so renaming
			// that case silently disables it; a dedicated table field would be more robust.
			if tt.name == "failed to get OIDC token" {
				rotator.oidcProvider = tokenprovider.NewMockTokenProvider("", time.Time{}, fmt.Errorf("oidc provider error"))
			} else {
				rotator.oidcProvider = tokenprovider.NewMockTokenProvider(dummyJWTToken, twoHourAfterNow, nil)
			}

			expiration, err := rotator.Rotate(context.Background())
			switch {
			case tt.expectErrorMsg != "" && err == nil:
				t.Errorf("expected error %q, got nil", tt.expectErrorMsg)
			case tt.expectErrorMsg != "" && err != nil:
				// Error text must match exactly, not merely contain the expected message.
				if d := cmp.Diff(tt.expectErrorMsg, err.Error()); d != "" {
					t.Errorf("GCPTokenRotator.Rotate() returned unexpected error (-want +got):\n%s", d)
				}
			case tt.expectErrorMsg == "" && err != nil:
				t.Errorf("unexpected error: %v", err)
			default:
				if d := cmp.Diff(twoHourAfterNow, expiration); d != "" {
					t.Errorf("GCPTokenRotator.Rotate() returned unexpected expiration time (-want +got):\n%s", d)
				}

				// Read the secret back from the client and compare it with the expectation.
				var actualSec corev1.Secret
				if err = fakeClient.Get(context.Background(), client.ObjectKey{
					Namespace: renewedSecret.Namespace,
					Name:      renewedSecret.Name,
				}, &actualSec); err != nil {
					t.Errorf("Failed to get expected secret from client: %v", err)
				}

				if tt.expectedSecret == nil {
					tt.expectedSecret = renewedSecret
				}
				// ResourceVersion is assigned by the client and is not part of the expectation.
				if d := cmp.Diff(tt.expectedSecret, &actualSec, cmpopts.IgnoreFields(corev1.Secret{}, "ResourceVersion")); d != "" {
					t.Errorf("GCPTokenRotator.Rotate() returned unexpected secret (-want +got):\n%s", d)
				}
			}
		})
	}
}

// TestGCPTokenRotator_GetPreRotationTime checks GetPreRotationTime against secrets with and
// without an expiration annotation, and against a client whose Get fails.
func TestGCPTokenRotator_GetPreRotationTime(t *testing.T) {
	scheme := runtime.NewScheme()
	scheme.AddKnownTypes(corev1.SchemeGroupVersion, &corev1.Secret{})

	now := time.Now()

	tests := []struct {
		name           string
		secret         *corev1.Secret
		expectedTime   time.Time
		expectedError  bool
		clientCreateFn func(t *testing.T) client.Client
	}{
		{
			name: "secret annotation missing",
			secret: &corev1.Secret{
				ObjectMeta: metav1.ObjectMeta{
					Name:      GetBSPSecretName("test-policy"),
					Namespace: "default",
				},
				Data: map[string][]byte{
					GCPProjectNameKey: []byte(dummyProjectName),
					GCPRegionKey:      []byte(dummyProjectRegion),
					GCPAccessTokenKey: []byte(oldGCPAccessToken),
				},
			},
			expectedTime:  time.Time{},
			expectedError: true,
		},
		{
			name: "rotation time before expiration time",
			secret: &corev1.Secret{
				ObjectMeta: metav1.ObjectMeta{
					Name:      GetBSPSecretName("test-policy"),
					Namespace: "default",
					Annotations: map[string]string{
						ExpirationTimeAnnotationKey: now.Add(2 * time.Hour).Format(time.RFC3339),
					},
				},
				Data: map[string][]byte{
					GCPProjectNameKey: []byte(dummyProjectName),
					GCPRegionKey:      []byte(dummyProjectRegion),
					GCPAccessTokenKey: []byte(oldGCPAccessToken),
				},
			},
			// expectedTime is the expiration time; the result only has to be earlier than it.
			expectedTime:  now.Add(2 * time.Hour),
			expectedError: false,
		},
		{
			name:          "lookup secret error (non-NotFound)",
			expectedTime:  time.Time{},
			expectedError: true,
			clientCreateFn: func(t *testing.T) client.Client {
				return &errorOnGetClient{
					Client: fake.NewFakeClient(),
					t:      t,
				}
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var testClient client.Client

			// Use custom client if provided, otherwise use default fake client
			if tt.clientCreateFn != nil {
				testClient = tt.clientCreateFn(t)
			} else {
				testClient = fake.NewClientBuilder().WithScheme(scheme).Build()
				err := testClient.Create(context.Background(), tt.secret)
				require.NoError(t, err)
			}

			// Create rotator with the test client
			testRotator := &gcpOIDCTokenRotator{
				client:                         testClient,
				preRotationWindow:              5 * time.Minute,
				backendSecurityPolicyName:      "test-policy",
				backendSecurityPolicyNamespace: "default",
				gcpCredentials:                 aigv1a1.BackendSecurityPolicyGCPCredentials{},
			}

			got, err := testRotator.GetPreRotationTime(context.Background())
			if (err != nil) != tt.expectedError {
				t.Errorf("GCPTokenRotator.GetPreRotationTime() error = %v, expectedError %v", err, tt.expectedError)
				return
			}
			// NOTE(review): this only asserts got < expectedTime; it does not verify that the
			// 5 minute pre-rotation window was subtracted exactly.
			if !tt.expectedTime.IsZero() && got.Compare(tt.expectedTime) >= 0 {
				t.Errorf("GCPTokenRotator.GetPreRotationTime() = %v, expected %v", got, tt.expectedTime)
			}
		})
	}
}

// TestGCPTokenRotator_IsExpired checks IsExpired with one time in the future and one in the past.
func TestGCPTokenRotator_IsExpired(t *testing.T) {
	fakeKubeClient := fake.NewFakeClient()
	rotator := &gcpOIDCTokenRotator{
		client: fakeKubeClient,
	}
	now := time.Now()
	tests := []struct {
		name       string
		expiration time.Time
		expect     bool
	}{
		{
			name:       "not expired",
			expiration: now.Add(1 * time.Hour),
			expect:     false,
		},
		{
			name:       "expired",
			expiration: now.Add(-1 * time.Hour),
			expect:     true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := rotator.IsExpired(tt.expiration); got != tt.expect {
				t.Errorf("GCPTokenRotator.IsExpired() = %v, expect %v", got, tt.expect)
			}
		})
	}
}

// TestExchangeJWTForSTSToken tests the exchangeJWTForSTSToken function against an httptest
// server that stands in for the STS endpoint.
func TestExchangeJWTForSTSToken(t *testing.T) {
	// Test cases
	tests := []struct {
		name            string
		jwtToken        string
		wifConfig       aigv1a1.GCPWorkLoadIdentityFederationConfig
		mockServer      func() *httptest.Server
		expectedError   bool
		expectedToken   string
		expectedExpires time.Duration
	}{
		{
			name:     "successful token exchange",
			jwtToken: "test-jwt-token",
			wifConfig: aigv1a1.GCPWorkLoadIdentityFederationConfig{
				ProjectID:                "test-project",
				WorkloadIdentityPoolName: "test-pool",
				WorkloadIdentityProvider: aigv1a1.GCPWorkloadIdentityProvider{
					Name: "test-provider",
				},
			},
			mockServer: func() *httptest.Server {
				return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					// Verify request is to token endpoint
					if r.URL.Path != "/v1/token" {
						http.Error(w, "Not found", http.StatusNotFound)
						return
					}

					// Verify it's a POST request
					if r.Method != http.MethodPost {
						http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
						return
					}

					// Return successful token response
					w.Header().Set("Content-Type", "application/json")
					fmt.Fprintf(w, `{
						"access_token": "test-sts-token",
						"expires_in": 3600,
						"token_type": "Bearer"
					}`)
				}))
			},
			expectedError:   false,
			expectedToken:   "test-sts-token",
			expectedExpires: time.Hour,
		},
		{
			name:     "token exchange error",
			jwtToken: "invalid-jwt-token",
			wifConfig: aigv1a1.GCPWorkLoadIdentityFederationConfig{
				ProjectID:                "test-project",
				WorkloadIdentityPoolName: "test-pool",
				WorkloadIdentityProvider: aigv1a1.GCPWorkloadIdentityProvider{
					Name: "test-provider",
				},
			},
			mockServer: func() *httptest.Server {
				return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
					// Return error response
					w.Header().Set("Content-Type", "application/json")
					w.WriteHeader(http.StatusBadRequest)
					_, err := fmt.Fprintf(w, `{
						"error": "invalid_token",
						"error_description": "The provided JWT is invalid"
					}`)
					if err != nil {
						return
					}
				}))
			},
			expectedError: true,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// Setup mock server
			server := tc.mockServer()
			defer server.Close()

			// Point the STS client at the test server via the endpoint option.
			ctx := context.Background()
			// Call the function being tested
			tokenExpiry, err := exchangeJWTForSTSToken(ctx, tc.jwtToken, tc.wifConfig, option.WithEndpoint(server.URL))

			// Check error conditions
			if tc.expectedError {
				require.Error(t, err)
				require.Nil(t, tokenExpiry)
				return
			}

			// Validate successful results
			require.NoError(t, err)
			require.NotNil(t, tokenExpiry)
			require.Equal(t, tc.expectedToken, tokenExpiry.Token)

			// Check expiration time is in the expected range
			// Since the function uses time.Now(), we can't assert the exact time
			// but we can check that it's within an acceptable range
			expectedExpiryTime := time.Now().Add(tc.expectedExpires)
			timeDiff := tokenExpiry.ExpiresAt.Sub(expectedExpiryTime)
			require.Less(t, timeDiff.Abs(), time.Second*5, "Expiry time should be close to expected value")
		})
	}
}

// TestExchangeJWTForSTSToken_WithoutAuthOption verifies that the STS request is sent without an
// Authorization header: the test server rejects any request that carries one.
func TestExchangeJWTForSTSToken_WithoutAuthOption(t *testing.T) {
	// Create a mock server that validates the request has no authentication
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Check for absence of authentication headers to validate WithoutAuthentication is working
		authHeader := r.Header.Get("Authorization")
		if authHeader != "" {
			http.Error(w, "Authorization header should not be present", http.StatusBadRequest)
			return
		}

		// Return a successful response
		w.Header().Set("Content-Type", "application/json")
		fmt.Fprintf(w, `{
			"access_token": "test-sts-token",
			"expires_in": 3600,
			"token_type": "Bearer"
		}`)
	}))
	defer server.Close()

	// Define test configuration
	jwtToken := "test-jwt-token" // #nosec G101
	wifConfig := aigv1a1.GCPWorkLoadIdentityFederationConfig{
		ProjectID:                "test-project",
		WorkloadIdentityPoolName: "test-pool",
		WorkloadIdentityProvider: aigv1a1.GCPWorkloadIdentityProvider{
			Name: "test-provider",
		},
	}

	// Call the function with the server URL as the endpoint
	ctx := context.Background()
	tokenExpiry, err := exchangeJWTForSTSToken(ctx, jwtToken, wifConfig, option.WithEndpoint(server.URL))

	// Verify the results
	require.NoError(t, err)
	require.NotNil(t, tokenExpiry)
	require.Equal(t, "test-sts-token", tokenExpiry.Token)

	// Verify the expiration time is about an hour from now
	expectedExpiryTime := time.Now().Add(time.Hour)
	timeDiff := tokenExpiry.ExpiresAt.Sub(expectedExpiryTime)
	require.Less(t, timeDiff.Abs(), time.Second*5, "Expiry time should be close to expected value")
}

// roundTripperFunc adapts a plain function to the http.RoundTripper interface so tests can
// supply canned responses.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// RoundTrip calls the wrapped function with the request.
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return f(req)
}

// TestImpersonateServiceAccount tests the impersonateServiceAccount function.
func TestImpersonateServiceAccount(t *testing.T) {
	// Test cases
	tests := []struct {
		name     string
		stsToken string
		saConfig aigv1a1.GCPServiceAccountImpersonationConfig
		// impersonateServiceAccount is hardcoded to call the Google API endpoint and ignores mock
		// endpoints set via opts, so the underlying HTTP round-tripper is mocked instead.
		mockResponse  func(req *http.Request) (*http.Response, error)
		expectedError bool
		expectedToken string
	}{
		{
			name:     "successful service account impersonation",
			stsToken: "test-sts-token",
			saConfig: aigv1a1.GCPServiceAccountImpersonationConfig{
				ServiceAccountName:        "test-service-account",
				ServiceAccountProjectName: "test-project",
			},
			mockResponse: func(req *http.Request) (*http.Response, error) {
				// Verify it's a POST request
				if req.Method != http.MethodPost {
					return &http.Response{
						StatusCode: http.StatusMethodNotAllowed,
						Body:       http.NoBody,
					}, nil
				}

				// Verify request is to the IAM credentials API and is asking to generate access token
				if !strings.Contains(req.URL.String(), "iamcredentials.googleapis.com") ||
					!strings.Contains(req.URL.Path, "generateAccessToken") {
					return &http.Response{
						StatusCode: http.StatusNotFound,
						Body:       http.NoBody,
					}, nil
				}

				// Return successful token response
				expiryTime := time.Now().Add(time.Hour).Format(time.RFC3339)
				respBody := fmt.Sprintf(`{
					"accessToken": "impersonated-sa-token",
					"expireTime": "%s"
				}`, expiryTime)

				return &http.Response{
					StatusCode: http.StatusOK,
					Body:       io.NopCloser(strings.NewReader(respBody)),
					Header:     map[string][]string{"Content-Type": {"application/json"}},
				}, nil
			},
			expectedError: false,
			expectedToken: "impersonated-sa-token",
		},
		{
			name:     "impersonation error",
			stsToken: "invalid-sts-token",
			saConfig: aigv1a1.GCPServiceAccountImpersonationConfig{
				ServiceAccountName:        "test-service-account",
				ServiceAccountProjectName: "test-project",
			},
			mockResponse: func(_ *http.Request) (*http.Response, error) {
				// Return error response
				respBody := `{
					"error": {
						"code": 401,
						"message": "Request had invalid authentication credentials",
						"status": "UNAUTHENTICATED"
					}
				}`

				return &http.Response{
					StatusCode: http.StatusUnauthorized,
					Body:       io.NopCloser(strings.NewReader(respBody)),
					Header:     map[string][]string{"Content-Type": {"application/json"}},
				}, nil
			},
			expectedError: true,
		},
		{
			name:     "credentials creation error",
			stsToken: "test-sts-token",
			saConfig: aigv1a1.GCPServiceAccountImpersonationConfig{
				ServiceAccountName:        "test-service-account",
				ServiceAccountProjectName: "test-project",
			},
			mockResponse: func(_ *http.Request) (*http.Response, error) {
				// The round-tripper itself fails, simulating a transport-level (network) error.
				return nil, fmt.Errorf("network error during credential creation")
			},
			expectedError: true,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			ctx := context.Background()

			// Create a mock HTTP client that intercepts requests
			mockTransport := roundTripperFunc(tc.mockResponse)
			mockHTTPClient := &http.Client{Transport: mockTransport}

			// Call the function being tested with our mock HTTP client
			// NOTE(review): none of the mock responses inspect the Authorization header, so this
			// test does not show whether the STS token is actually attached to the request.
			tokenExpiry, err := impersonateServiceAccount(ctx, tc.stsToken, tc.saConfig, option.WithHTTPClient(mockHTTPClient))

			// Check error conditions
			if tc.expectedError {
				require.Error(t, err)
				require.Nil(t, tokenExpiry)
				return
			}

			// Validate successful results
			require.NoError(t, err)
			require.NotNil(t, tokenExpiry)
			require.Equal(t, tc.expectedToken, tokenExpiry.Token)

			// Check that expiration time is reasonably set (should be around 1 hour from now)
			expectedExpiryTimeApprox := time.Now().Add(time.Hour)
			timeDiff := tokenExpiry.ExpiresAt.Sub(expectedExpiryTimeApprox)
			require.Less(t, timeDiff.Abs(), time.Minute*5, "Expiry time should be close to expected value")
		})
	}
}

// TestNewGCPOIDCTokenRotator tests the
NewGCPOIDCTokenRotator constructor function. +func TestNewGCPOIDCTokenRotator(t *testing.T) { + logger := logr.Logger{} + preRotationWindow := 30 * time.Minute + + // Mock token provider creation by directly creating test cases with valid/invalid params + // without monkey patching the NewOidcTokenProvider function + + // Define OIDC values based on the real Envoy Gateway API types + validOIDCConfig := aigv1a1.BackendSecurityPolicyOIDC{ + OIDC: egv1a1.OIDC{ + ClientID: "client-id", + Scopes: []string{"scope1", "scope2"}, + }, + } + + tests := []struct { + name string + bsp aigv1a1.BackendSecurityPolicy + expectedError string + }{ + { + name: "nil GCP credentials", + bsp: aigv1a1.BackendSecurityPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-policy", + Namespace: "default", + }, + Spec: aigv1a1.BackendSecurityPolicySpec{ + GCPCredentials: nil, + }, + }, + expectedError: "GCP credentials are not configured in BackendSecurityPolicy default/test-policy", + }, + { + name: "success", + bsp: aigv1a1.BackendSecurityPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-success-policy", + Namespace: "default", + }, + Spec: aigv1a1.BackendSecurityPolicySpec{ + GCPCredentials: &aigv1a1.BackendSecurityPolicyGCPCredentials{ + ProjectName: "test-project", + Region: "us-central1", + WorkLoadIdentityFederationConfig: aigv1a1.GCPWorkLoadIdentityFederationConfig{ + ProjectID: "test-project-id", + WorkloadIdentityPoolName: "test-pool-name", + WorkloadIdentityProvider: aigv1a1.GCPWorkloadIdentityProvider{ + Name: "test-provider", + OIDCProvider: validOIDCConfig, + }, + ServiceAccountImpersonation: &aigv1a1.GCPServiceAccountImpersonationConfig{ + ServiceAccountName: "test-service-account", + ServiceAccountProjectName: "test-project", + }, + }, + }, + }, + }, + expectedError: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a fake client for each test + scheme := runtime.NewScheme() + scheme.AddKnownTypes(corev1.SchemeGroupVersion, 
&corev1.Secret{}) + fakeClient := fake.NewFakeClient() + + var rotator Rotator + var err error + + mockTokenProvider := tokenprovider.NewMockTokenProvider("mock-jwt-token", time.Now().Add(time.Hour), nil) + rotator, err = NewGCPOIDCTokenRotator(fakeClient, logger, tt.bsp, preRotationWindow, mockTokenProvider) + + if tt.expectedError != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.expectedError) + require.Nil(t, rotator) + } else { + require.NotNil(t, rotator) + + // Verify rotator is properly initialized + gcpRotator, ok := rotator.(*gcpOIDCTokenRotator) + require.True(t, ok, "Expected a gcpOIDCTokenRotator instance") + + // Instead of comparing the entire struct with cmp.Diff, which has issues with unexported fields, + // verify individual fields that we care about + require.Equal(t, tt.bsp.Name, gcpRotator.backendSecurityPolicyName) + require.Equal(t, tt.bsp.Namespace, gcpRotator.backendSecurityPolicyNamespace) + require.Equal(t, preRotationWindow, gcpRotator.preRotationWindow) + require.NotNil(t, gcpRotator.oidcProvider) + require.NotNil(t, gcpRotator.client) + require.NotNil(t, gcpRotator.saTokenFunc) + require.NotNil(t, gcpRotator.stsTokenFunc) + + // Verify that the GCP credentials were properly copied + if tt.bsp.Spec.GCPCredentials != nil { + require.Equal(t, *tt.bsp.Spec.GCPCredentials, gcpRotator.gcpCredentials) + } + } + }) + } +} + +// errorOnCreateClient is a client that returns an error on Create +type errorOnCreateClient struct { + client.Client + t *testing.T +} + +func (c *errorOnCreateClient) Create(_ context.Context, _ client.Object, _ ...client.CreateOption) error { + return fmt.Errorf("create error") +} + +// errorOnUpdateClient is a client that returns an error on Update +type errorOnUpdateClient struct { + client.Client + t *testing.T +} + +func (c *errorOnUpdateClient) Create(_ context.Context, _ client.Object, _ ...client.CreateOption) error { + return nil // Allow create to succeed +} + +func (c *errorOnUpdateClient) 
Get(_ context.Context, key client.ObjectKey, obj client.Object, _ ...client.GetOption) error { + // Cast to Secret + if secret, ok := obj.(*corev1.Secret); ok { + secret.Name = key.Name + secret.Namespace = key.Namespace + secret.Data = map[string][]byte{ + GCPProjectNameKey: []byte(dummyProjectName), + GCPRegionKey: []byte(dummyProjectRegion), + GCPAccessTokenKey: []byte(oldGCPAccessToken), + } + secret.Annotations = map[string]string{ + ExpirationTimeAnnotationKey: time.Now().Format(time.RFC3339), + } + return nil + } + return nil +} + +func (c *errorOnUpdateClient) Update(_ context.Context, _ client.Object, _ ...client.UpdateOption) error { + return fmt.Errorf("update error") +} + +// errorOnGetClient is a client that returns an error on Get (for testing lookup failures) +type errorOnGetClient struct { + client.Client + t *testing.T +} + +func (c *errorOnGetClient) Get(_ context.Context, _ client.ObjectKey, _ client.Object, _ ...client.GetOption) error { + return fmt.Errorf("lookup error") +} + +func TestGetGCPProxyClientOption(t *testing.T) { + tests := []struct { + name string + proxyURL string + setEnvVar bool + wantErr bool + wantNilOption bool + validateOption func(t *testing.T, opt option.ClientOption) + }{ + { + name: "no proxy URL environment variable", + setEnvVar: false, + wantErr: false, + wantNilOption: true, + }, + { + name: "empty proxy URL environment variable", + proxyURL: "", + setEnvVar: true, + wantErr: false, + wantNilOption: true, + }, + { + name: "valid HTTPS proxy URL", + proxyURL: "https://secure-proxy.example.com:8443", + setEnvVar: true, + wantErr: false, + wantNilOption: false, + validateOption: func(t *testing.T, opt option.ClientOption) { + require.NotNil(t, opt) + }, + }, + { + name: "invalid proxy URL - missing protocol scheme", + proxyURL: "://invalid", + setEnvVar: true, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Save original environment variable + originalProxyURL := 
os.Getenv("AI_GATEWAY_GCP_AUTH_PROXY_URL") + defer func() { + if originalProxyURL != "" { + os.Setenv("AI_GATEWAY_GCP_AUTH_PROXY_URL", originalProxyURL) + } else { + os.Unsetenv("AI_GATEWAY_GCP_AUTH_PROXY_URL") + } + }() + + // Set up test environment + if tt.setEnvVar { + os.Setenv("AI_GATEWAY_GCP_AUTH_PROXY_URL", tt.proxyURL) + } else { + os.Unsetenv("AI_GATEWAY_GCP_AUTH_PROXY_URL") + } + + // Call the function under test + got, err := getGCPProxyClientOption() + + // Validate error expectation + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), "invalid proxy URL") + return + } + + require.NoError(t, err) + + // Validate nil option expectation + if tt.wantNilOption { + require.Nil(t, got) + return + } + + // Additional validation if provided + if tt.validateOption != nil { + tt.validateOption(t, got) + } + }) + } +} diff --git a/internal/controller/tokenprovider/oidc_token_provider.go b/internal/controller/tokenprovider/oidc_token_provider.go index 3fae8d3764..e104b606fa 100644 --- a/internal/controller/tokenprovider/oidc_token_provider.go +++ b/internal/controller/tokenprovider/oidc_token_provider.go @@ -27,6 +27,10 @@ type oidcTokenProvider struct { // NewOidcTokenProvider creates a new TokenProvider with the given OIDC configuration. 
func NewOidcTokenProvider(ctx context.Context, client client.Client, oidcConfig *egv1a1.OIDC) (TokenProvider, error) { + if oidcConfig == nil { + return nil, fmt.Errorf("provided oidc config is nil") + } + issuerURL := oidcConfig.Provider.Issuer oidcProvider, err := oidc.NewProvider(ctx, issuerURL) if err != nil { diff --git a/internal/extproc/backendauth/auth.go b/internal/extproc/backendauth/auth.go index 6df9f1b360..bef50cf9b2 100644 --- a/internal/extproc/backendauth/auth.go +++ b/internal/extproc/backendauth/auth.go @@ -31,6 +31,8 @@ func NewHandler(ctx context.Context, config *filterapi.BackendAuth) (Handler, er return newAPIKeyHandler(config.APIKey) case config.AzureAuth != nil: return newAzureHandler(config.AzureAuth) + case config.GCPAuth != nil: + return newGCPHandler(config.GCPAuth) default: return nil, errors.New("no backend auth handler found") } diff --git a/internal/extproc/backendauth/auth_test.go b/internal/extproc/backendauth/auth_test.go index 422bea9903..2cfea3bd93 100644 --- a/internal/extproc/backendauth/auth_test.go +++ b/internal/extproc/backendauth/auth_test.go @@ -40,6 +40,16 @@ aws_secret_access_key = test AzureAuth: &filterapi.AzureAuth{AccessToken: "some-access-token"}, }, }, + { + name: "GCPAuth", + config: &filterapi.BackendAuth{ + GCPAuth: &filterapi.GCPAuth{ + AccessToken: "some-access-token", + Region: "some-region", + ProjectName: "some-project", + }, + }, + }, } { t.Run(tt.name, func(t *testing.T) { _, err := NewHandler(t.Context(), tt.config) diff --git a/internal/extproc/backendauth/gcp.go b/internal/extproc/backendauth/gcp.go new file mode 100644 index 0000000000..3f247b6c77 --- /dev/null +++ b/internal/extproc/backendauth/gcp.go @@ -0,0 +1,89 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. 
+ +package backendauth + +import ( + "context" + "fmt" + + corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + + "github.com/envoyproxy/ai-gateway/filterapi" +) + +type gcpHandler struct { + gcpAccessToken string // The GCP access token used for authentication. + region string // The GCP region to use for requests. + projectName string // The GCP project to use for requests. +} + +func newGCPHandler(gcpAuth *filterapi.GCPAuth) (Handler, error) { + if gcpAuth == nil { + return nil, fmt.Errorf("GCP auth configuration cannot be nil") + } + + if gcpAuth.AccessToken == "" { + return nil, fmt.Errorf("GCP access token cannot be empty") + } + + return &gcpHandler{ + gcpAccessToken: gcpAuth.AccessToken, + region: gcpAuth.Region, + projectName: gcpAuth.ProjectName, + }, nil +} + +// Do implements [Handler.Do]. +// +// This method updates the request headers to: +// 1. Prepend the GCP API prefix to the ":path" header, constructing the full endpoint URL. +// 2. Add an "Authorization" header with the GCP access token. +// +// The ":path" header is expected to contain the API-specific suffix, which is injected by translator.requestBody. +// The suffix is combined with the generated prefix to form the complete path for the GCP API call. +func (g *gcpHandler) Do(_ context.Context, _ map[string]string, headerMut *extprocv3.HeaderMutation, _ *extprocv3.BodyMutation) error { + var pathHeaderFound bool + + // Build the GCP URL prefix using the configured region and project name. + prefixPath := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s", g.region, g.projectName, g.region) + + // Find and update the ":path" header by prepending the prefix. + for _, hdr := range headerMut.SetHeaders { + if hdr.Header != nil && hdr.Header.Key == ":path" { + pathHeaderFound = true + // Update the string value if present. 
+ if len(hdr.Header.Value) > 0 { + suffixPath := hdr.Header.Value + hdr.Header.Value = fmt.Sprintf("%s/%s", prefixPath, suffixPath) + } + // Update the raw byte value if present. + if len(hdr.Header.RawValue) > 0 { + suffixPath := string(hdr.Header.RawValue) + path := fmt.Sprintf("%s/%s", prefixPath, suffixPath) + hdr.Header.RawValue = []byte(path) + } + break + } + } + + if !pathHeaderFound { + return fmt.Errorf("missing ':path' header in the request") + } + + // Add the Authorization header with the GCP access token. + headerMut.SetHeaders = append( + headerMut.SetHeaders, + &corev3.HeaderValueOption{ + Header: &corev3.HeaderValue{ + Key: "Authorization", + RawValue: []byte(fmt.Sprintf("Bearer %s", g.gcpAccessToken)), + }, + }, + ) + + return nil +} diff --git a/internal/extproc/backendauth/gcp_test.go b/internal/extproc/backendauth/gcp_test.go new file mode 100644 index 0000000000..699aeef67b --- /dev/null +++ b/internal/extproc/backendauth/gcp_test.go @@ -0,0 +1,178 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. 
+ +package backendauth + +import ( + "bytes" + "context" + "fmt" + "testing" + + corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/require" + + "github.com/envoyproxy/ai-gateway/filterapi" +) + +func TestNewGCPHandler(t *testing.T) { + testCases := []struct { + name string + gcpAuth *filterapi.GCPAuth + wantHandler *gcpHandler + wantError bool + }{ + { + name: "valid config", + gcpAuth: &filterapi.GCPAuth{ + AccessToken: "test-token", + Region: "us-central1", + ProjectName: "test-project", + }, + wantHandler: &gcpHandler{ + gcpAccessToken: "test-token", + region: "us-central1", + projectName: "test-project", + }, + wantError: false, + }, + { + name: "nil config", + gcpAuth: nil, + wantHandler: nil, + wantError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + handler, err := newGCPHandler(tc.gcpAuth) + if tc.wantError { + require.Error(t, err) + } else { + require.NoError(t, err) + require.NotNil(t, handler) + + if d := cmp.Diff(tc.wantHandler, handler, cmp.AllowUnexported(gcpHandler{})); d != "" { + t.Errorf("Handler mismatch (-want +got):\n%s", d) + } + } + }) + } +} + +func TestGCPHandler_Do(t *testing.T) { + handler := &gcpHandler{ + gcpAccessToken: "test-token", + region: "us-central1", + projectName: "test-project", + } + testCases := []struct { + name string + handler *gcpHandler + requestHeaders map[string]string + headerMut *extprocv3.HeaderMutation + bodyMut *extprocv3.BodyMutation + wantPathValue string + wantPathRawValue []byte + wantErrorMsg string + }{ + { + name: "basic headers update with string value", + handler: handler, + headerMut: &extprocv3.HeaderMutation{ + SetHeaders: []*corev3.HeaderValueOption{ + { + Header: &corev3.HeaderValue{ + Key: ":path", + Value: "publishers/google/models/gemini-pro:generateContent", + }, + }, + }, + }, + bodyMut: 
&extprocv3.BodyMutation{}, + wantPathValue: "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:generateContent", + }, + { + name: "basic headers update with raw value", + handler: handler, + headerMut: &extprocv3.HeaderMutation{ + SetHeaders: []*corev3.HeaderValueOption{ + { + Header: &corev3.HeaderValue{ + Key: ":path", + RawValue: []byte("publishers/google/models/gemini-pro:generateContent"), + }, + }, + }, + }, + bodyMut: &extprocv3.BodyMutation{}, + wantPathRawValue: []byte("https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/gemini-pro:generateContent"), + }, + { + name: "no path header", + handler: handler, + headerMut: &extprocv3.HeaderMutation{ + SetHeaders: []*corev3.HeaderValueOption{ + { + Header: &corev3.HeaderValue{ + Key: "Content-Type", + Value: "application/json", + }, + }, + }, + }, + bodyMut: &extprocv3.BodyMutation{}, + wantErrorMsg: "missing ':path' header in the request", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + err := tc.handler.Do(ctx, nil, tc.headerMut, tc.bodyMut) + + if tc.wantErrorMsg != "" { + require.ErrorContains(t, err, tc.wantErrorMsg, "Expected error message not found") + } else { + require.NoError(t, err) + + // Check Authorization header + authHeaderFound := false + expectedAuthHeader := fmt.Sprintf("Bearer %s", tc.handler.gcpAccessToken) + + // Check path header if expected + pathHeaderUpdated := false + + for _, header := range tc.headerMut.SetHeaders { + if header.Header.Key == "Authorization" { + authHeaderFound = true + require.Equal(t, []byte(expectedAuthHeader), header.Header.RawValue) + } + + if header.Header.Key == ":path" { + pathHeaderUpdated = true + if len(tc.wantPathValue) > 0 { + require.Equal(t, tc.wantPathValue, header.Header.Value) + } + if len(tc.wantPathRawValue) > 0 { + require.True(t, 
bytes.Equal(tc.wantPathRawValue, header.Header.RawValue)) + } + } + } + + // Authorization header should always be added + require.True(t, authHeaderFound, "Authorization header not found") + + // Only check path header if we had expectations for it + if len(tc.wantPathValue) > 0 || len(tc.wantPathRawValue) > 0 { + require.True(t, pathHeaderUpdated, "Path header not updated as expected") + } + } + }) + } +} diff --git a/internal/extproc/chatcompletion_processor.go b/internal/extproc/chatcompletion_processor.go index d9f5d3dfa6..a77fc6ad30 100644 --- a/internal/extproc/chatcompletion_processor.go +++ b/internal/extproc/chatcompletion_processor.go @@ -181,6 +181,10 @@ func (c *chatCompletionProcessorUpstreamFilter) selectTranslator(out filterapi.V c.translator = translator.NewChatCompletionOpenAIToAWSBedrockTranslator(c.modelNameOverride) case filterapi.APISchemaAzureOpenAI: c.translator = translator.NewChatCompletionOpenAIToAzureOpenAITranslator(out.Version, c.modelNameOverride) + case filterapi.APISchemaGCPVertexAI: + c.translator = translator.NewChatCompletionOpenAIToGCPVertexAITranslator() + case filterapi.APISchemaGCPAnthropic: + c.translator = translator.NewChatCompletionOpenAIToGCPAnthropicTranslator() default: return fmt.Errorf("unsupported API schema: backend=%s", out) } diff --git a/internal/extproc/translator/gemini_helper.go b/internal/extproc/translator/gemini_helper.go new file mode 100644 index 0000000000..33c555bd34 --- /dev/null +++ b/internal/extproc/translator/gemini_helper.go @@ -0,0 +1,63 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. 
+ +package translator + +import ( + "fmt" + "strconv" + + "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" +) + +const ( + GCPModelPublisherGoogle = "google" + GCPModelPublisherAnthropic = "anthropic" + GCPMethodGenerateContent = "generateContent" + HTTPHeaderKeyContentLength = "Content-Length" +) + +func buildGCPModelPathSuffix(publisher, model, gcpMethod string) string { + pathSuffix := fmt.Sprintf("publishers/%s/models/%s:%s", publisher, model, gcpMethod) + return pathSuffix +} + +// buildGCPRequestMutations creates header and body mutations for GCP requests +// It sets the ":path" header, the "content-length" header and the request body. +func buildGCPRequestMutations(path string, reqBody []byte) (*ext_procv3.HeaderMutation, *ext_procv3.BodyMutation) { + var bodyMutation *ext_procv3.BodyMutation + + // Create header mutation + headerMutation := &ext_procv3.HeaderMutation{ + SetHeaders: []*corev3.HeaderValueOption{ + { + Header: &corev3.HeaderValue{ + Key: ":path", + RawValue: []byte(path), + }, + }, + }, + } + + // If the request body is not empty, we set the content-length header and create a body mutation + if len(reqBody) != 0 { + // Set the "content-length" header + headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{ + Header: &corev3.HeaderValue{ + Key: HTTPHeaderKeyContentLength, + RawValue: []byte(strconv.Itoa(len(reqBody))), + }, + }) + + // Create body mutation + bodyMutation = &ext_procv3.BodyMutation{ + Mutation: &ext_procv3.BodyMutation_Body{Body: reqBody}, + } + + } + + return headerMutation, bodyMutation +} diff --git a/internal/extproc/translator/openai_gcpanthropic.go b/internal/extproc/translator/openai_gcpanthropic.go new file mode 100644 index 0000000000..7b47aa09b9 --- /dev/null +++ b/internal/extproc/translator/openai_gcpanthropic.go @@ -0,0 +1,67 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: 
Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 + +package translator + +import ( + "io" + + extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + + "github.com/envoyproxy/ai-gateway/internal/apischema/openai" +) + +// NewChatCompletionOpenAIToGCPAnthropicTranslator implements [Factory] for OpenAI to GCP Anthropic translation. +// This translator converts OpenAI ChatCompletion API requests to GCP Anthropic API format. +func NewChatCompletionOpenAIToGCPAnthropicTranslator() OpenAIChatCompletionTranslator { + return &openAIToGCPAnthropicTranslatorV1ChatCompletion{} +} + +type openAIToGCPAnthropicTranslatorV1ChatCompletion struct{} + +// RequestBody implements [Translator.RequestBody] for GCP Anthropic. +// This method translates an OpenAI ChatCompletion request to a GCP Anthropic API request. +func (o *openAIToGCPAnthropicTranslatorV1ChatCompletion) RequestBody(_ []byte, openAIReq *openai.ChatCompletionRequest, onRetry bool) ( + headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, err error, +) { + _ = onRetry + model := openAIReq.Model + pathSuffix := buildGCPModelPathSuffix(GCPModelPublisherAnthropic, model, GCPMethodGenerateContent) + + // TODO: Implement actual translation from OpenAI to Anthropic request. + + headerMutation, bodyMutation = buildGCPRequestMutations(pathSuffix, nil) + return headerMutation, bodyMutation, nil +} + +// ResponseHeaders implements [Translator.ResponseHeaders]. +func (o *openAIToGCPAnthropicTranslatorV1ChatCompletion) ResponseHeaders(headers map[string]string) ( + headerMutation *extprocv3.HeaderMutation, err error, +) { + // TODO: Implement header transformations if needed + _ = headers + return nil, nil +} + +// ResponseError implements [Translator.ResponseError]. 
+// This method translates GCP Anthropic API errors to OpenAI-compatible error formats. +func (o *openAIToGCPAnthropicTranslatorV1ChatCompletion) ResponseError(respHeaders map[string]string, body interface{}) ( + headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, err error, +) { + // TODO: Implement error translation + _, _ = respHeaders, body + return nil, nil, nil +} + +func (o *openAIToGCPAnthropicTranslatorV1ChatCompletion) ResponseBody(respHeaders map[string]string, body io.Reader, endOfStream bool) ( + headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, tokenUsage LLMTokenUsage, err error, +) { + // TODO: Implement response translation from Anthropic to OpenAI format + _, _, _ = respHeaders, body, endOfStream + return nil, nil, LLMTokenUsage{}, nil +} diff --git a/internal/extproc/translator/openai_gcpanthropic_test.go b/internal/extproc/translator/openai_gcpanthropic_test.go new file mode 100644 index 0000000000..36c4d92c95 --- /dev/null +++ b/internal/extproc/translator/openai_gcpanthropic_test.go @@ -0,0 +1,279 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. 
+ +package translator + +import ( + "bytes" + "testing" + + corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/envoyproxy/ai-gateway/internal/apischema/openai" +) + +func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_RequestBody(t *testing.T) { + defaultHeaderMut := &extprocv3.HeaderMutation{ + SetHeaders: []*corev3.HeaderValueOption{ + { + Header: &corev3.HeaderValue{ + Key: ":path", + RawValue: []byte("publishers/anthropic/models/claude-3:generateContent"), + }, + }, + }, + } + + tests := []struct { + name string + raw []byte + input *openai.ChatCompletionRequest + onRetry bool + wantError bool + wantHeaderMut *extprocv3.HeaderMutation + wantBodyMut *extprocv3.BodyMutation + }{ + { + name: "basic request", + input: &openai.ChatCompletionRequest{ + Stream: false, + Model: "claude-3", + Messages: []openai.ChatCompletionMessageParamUnion{ + { + Value: openai.ChatCompletionSystemMessageParam{ + Content: openai.StringOrArray{ + Value: "You are a helpful assistant", + }, + }, + Type: openai.ChatMessageRoleSystem, + }, + { + Value: openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{ + Value: "Tell me about AI Gateways", + }, + }, + Type: openai.ChatMessageRoleUser, + }, + }, + }, + wantError: false, + wantHeaderMut: defaultHeaderMut, + wantBodyMut: nil, + }, + { + name: "streaming request", + input: &openai.ChatCompletionRequest{ + Stream: true, + Model: "claude-3", + Messages: []openai.ChatCompletionMessageParamUnion{ + { + Value: openai.ChatCompletionUserMessageParam{ + Content: openai.StringOrUserRoleContentUnion{ + Value: "Explain streaming responses", + }, + }, + Type: openai.ChatMessageRoleUser, + }, + }, + }, + wantError: false, + wantHeaderMut: defaultHeaderMut, 
+ wantBodyMut: nil, + }, + { + name: "retry request", + input: &openai.ChatCompletionRequest{Model: "claude-3"}, + onRetry: true, + wantError: false, + wantHeaderMut: defaultHeaderMut, + wantBodyMut: nil, + }, + } + + translator := NewChatCompletionOpenAIToGCPAnthropicTranslator() + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + headerMut, bodyMut, err := translator.RequestBody(tc.raw, tc.input, tc.onRetry) + + if tc.wantError { + assert.Error(t, err) + return + } + + require.NoError(t, err) + + if diff := cmp.Diff(tc.wantHeaderMut, headerMut, cmpopts.IgnoreUnexported(extprocv3.HeaderMutation{}, corev3.HeaderValueOption{}, corev3.HeaderValue{})); diff != "" { + t.Errorf("HeaderMutation mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(tc.wantBodyMut, bodyMut); diff != "" { + t.Errorf("BodyMutation mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_ResponseHeaders(t *testing.T) { + tests := []struct { + name string + headers map[string]string + wantError bool + wantHeaderMut *extprocv3.HeaderMutation + }{ + { + name: "empty headers", + headers: map[string]string{}, + wantError: false, + wantHeaderMut: nil, + }, + { + name: "with content-type", + headers: map[string]string{ + "content-type": "application/json", + }, + wantError: false, + wantHeaderMut: nil, + }, + { + name: "with status", + headers: map[string]string{ + ":status": "200", + }, + wantError: false, + wantHeaderMut: nil, + }, + } + + translator := NewChatCompletionOpenAIToGCPAnthropicTranslator() + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + headerMut, err := translator.ResponseHeaders(tc.headers) + + if tc.wantError { + assert.Error(t, err) + return + } + + require.NoError(t, err) + + if diff := cmp.Diff(tc.wantHeaderMut, headerMut); diff != "" { + t.Errorf("HeaderMutation mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func 
TestOpenAIToGCPAnthropicTranslatorV1ChatCompletion_ResponseBody(t *testing.T) { + tests := []struct { + name string + respHeaders map[string]string + body string + endOfStream bool + wantError bool + wantHeaderMut *extprocv3.HeaderMutation + wantBodyMut *extprocv3.BodyMutation + wantTokenUsage LLMTokenUsage + }{ + { + name: "successful response", + respHeaders: map[string]string{ + "content-type": "application/json", + }, + body: `{ + "id": "resp-1234567890", + "model": "claude-3-opus-20240229", + "type": "message", + "role": "assistant", + "content": [ + { + "type": "text", + "text": "AI Gateways act as intermediaries between clients and LLM services." + } + ], + "usage": { + "input_tokens": 10, + "output_tokens": 15 + } + }`, + endOfStream: true, + wantError: false, + wantHeaderMut: nil, + wantBodyMut: nil, + wantTokenUsage: LLMTokenUsage{ + InputTokens: 0, + OutputTokens: 0, + TotalTokens: 0, + }, + }, + { + name: "streaming chunk", + respHeaders: map[string]string{ + "content-type": "application/json", + }, + body: `{ + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": "AI" + } + }`, + endOfStream: false, + wantError: false, + wantHeaderMut: nil, + wantBodyMut: nil, + wantTokenUsage: LLMTokenUsage{}, + }, + { + name: "empty response", + respHeaders: map[string]string{ + "content-type": "application/json", + }, + body: `{}`, + endOfStream: true, + wantError: false, + wantHeaderMut: nil, + wantBodyMut: nil, + wantTokenUsage: LLMTokenUsage{}, + }, + } + + translator := NewChatCompletionOpenAIToGCPAnthropicTranslator() + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + reader := bytes.NewReader([]byte(tc.body)) + + headerMut, bodyMut, tokenUsage, err := translator.ResponseBody(tc.respHeaders, reader, tc.endOfStream) + + if tc.wantError { + assert.Error(t, err) + return + } + + require.NoError(t, err) + + if diff := cmp.Diff(tc.wantHeaderMut, headerMut); diff != "" { + t.Errorf("HeaderMutation mismatch 
(-want +got):\n%s", diff) + } + + if diff := cmp.Diff(tc.wantBodyMut, bodyMut); diff != "" { + t.Errorf("BodyMutation mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(tc.wantTokenUsage, tokenUsage); diff != "" { + t.Errorf("TokenUsage mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/internal/extproc/translator/openai_gcpvertexai.go b/internal/extproc/translator/openai_gcpvertexai.go new file mode 100644 index 0000000000..c86c843031 --- /dev/null +++ b/internal/extproc/translator/openai_gcpvertexai.go @@ -0,0 +1,58 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 + +package translator + +import ( + "io" + + extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + + "github.com/envoyproxy/ai-gateway/internal/apischema/openai" +) + +// NewChatCompletionOpenAIToGCPVertexAITranslator implements [Factory] for OpenAI to GCP Gemini translation. +// This translator converts OpenAI ChatCompletion API requests to GCP Gemini API format. +func NewChatCompletionOpenAIToGCPVertexAITranslator() OpenAIChatCompletionTranslator { + return &openAIToGCPVertexAITranslatorV1ChatCompletion{} +} + +type openAIToGCPVertexAITranslatorV1ChatCompletion struct{} + +// RequestBody implements [Translator.RequestBody] for GCP Gemini. +// This method translates an OpenAI ChatCompletion request to a GCP Gemini API request. 
+func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) RequestBody(_ []byte, openAIReq *openai.ChatCompletionRequest, onRetry bool) ( + headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, err error, +) { + _, _ = openAIReq, onRetry + pathSuffix := buildGCPModelPathSuffix(GCPModelPublisherGoogle, openAIReq.Model, GCPMethodGenerateContent) + + // TODO: Implement actual translation from OpenAI to Gemini request. + + headerMutation, bodyMutation = buildGCPRequestMutations(pathSuffix, nil) + return headerMutation, bodyMutation, nil +} + +// ResponseHeaders implements [Translator.ResponseHeaders]. +func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) ResponseHeaders(headers map[string]string) ( + headerMutation *extprocv3.HeaderMutation, err error, +) { + // TODO: Implement if needed. + _ = headers + return nil, nil +} + +// ResponseBody implements [Translator.ResponseBody] for GCP Gemini. +// This method translates a GCP Gemini API response to the OpenAI ChatCompletion format. +func (o *openAIToGCPVertexAITranslatorV1ChatCompletion) ResponseBody(respHeaders map[string]string, body io.Reader, endOfStream bool) ( + headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, tokenUsage LLMTokenUsage, err error, +) { + // TODO: Implement response body translation from GCP Gemini to OpenAI format + _, _, _ = respHeaders, body, endOfStream + return nil, nil, LLMTokenUsage{}, nil +} diff --git a/internal/extproc/translator/openai_gcpvertexai_test.go b/internal/extproc/translator/openai_gcpvertexai_test.go new file mode 100644 index 0000000000..6e7e915ddc --- /dev/null +++ b/internal/extproc/translator/openai_gcpvertexai_test.go @@ -0,0 +1,239 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. 
+
+package translator
+
+import (
+	"bytes"
+	"testing"
+
+	corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
+	extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
+	"github.com/google/go-cmp/cmp"
+	"github.com/google/go-cmp/cmp/cmpopts"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/envoyproxy/ai-gateway/internal/apischema/openai"
+)
+
+// TestOpenAIToGCPVertexAITranslatorV1ChatCompletion_RequestBody verifies that RequestBody sets :path to the
+// generateContent suffix of the requested model and, while body translation is unimplemented, returns no body mutation.
+func TestOpenAIToGCPVertexAITranslatorV1ChatCompletion_RequestBody(t *testing.T) {
+	tests := []struct {
+		name          string
+		input         openai.ChatCompletionRequest
+		onRetry       bool
+		wantError     bool
+		wantHeaderMut *extprocv3.HeaderMutation
+		wantBodyMut   *extprocv3.BodyMutation
+	}{
+		{
+			name: "basic request",
+			input: openai.ChatCompletionRequest{
+				Stream: false,
+				Model:  "gemini-pro",
+				Messages: []openai.ChatCompletionMessageParamUnion{
+					{
+						Value: openai.ChatCompletionSystemMessageParam{
+							Content: openai.StringOrArray{
+								Value: "You are a helpful assistant",
+							},
+						},
+						Type: openai.ChatMessageRoleSystem,
+					},
+					{
+						Value: openai.ChatCompletionUserMessageParam{
+							Content: openai.StringOrUserRoleContentUnion{
+								Value: "Tell me about AI Gateways",
+							},
+						},
+						Type: openai.ChatMessageRoleUser,
+					},
+				},
+			},
+			onRetry:   false,
+			wantError: false,
+			// Body translation is still a stub, so only the body mutation is nil; the :path header is already rewritten.
+			wantHeaderMut: &extprocv3.HeaderMutation{
+				SetHeaders: []*corev3.HeaderValueOption{
+					{
+						Header: &corev3.HeaderValue{
+							Key:      ":path",
+							RawValue: []byte("publishers/google/models/gemini-pro:generateContent"),
+						},
+					},
+				},
+			},
+			wantBodyMut: nil,
+		},
+	}
+
+	translator := NewChatCompletionOpenAIToGCPVertexAITranslator()
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			headerMut, bodyMut, err := translator.RequestBody(nil, &tc.input, tc.onRetry)
+			if tc.wantError {
+				assert.Error(t, err)
+				return
+			}
+			require.NoError(t, err)
+
+			// Unexported fields of the listed message types are excluded from the comparison.
+			if diff := cmp.Diff(tc.wantHeaderMut, headerMut,
+				cmpopts.IgnoreUnexported(extprocv3.HeaderMutation{}, corev3.HeaderValueOption{}, corev3.HeaderValue{})); diff != "" {
+				t.Errorf("HeaderMutation mismatch (-want +got):\n%s", diff)
+			}
+
+			if diff := cmp.Diff(tc.wantBodyMut, bodyMut); diff != "" {
+				t.Errorf("BodyMutation mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
+
+// TestOpenAIToGCPVertexAITranslatorV1ChatCompletion_ResponseHeaders verifies that ResponseHeaders currently
+// returns no header mutation and no error.
+func TestOpenAIToGCPVertexAITranslatorV1ChatCompletion_ResponseHeaders(t *testing.T) {
+	tests := []struct {
+		name          string
+		headers       map[string]string
+		wantError     bool
+		wantHeaderMut *extprocv3.HeaderMutation
+	}{
+		{
+			name: "basic headers",
+			headers: map[string]string{
+				"content-type": "application/json",
+			},
+			wantError:     false,
+			wantHeaderMut: nil,
+		},
+		// TODO: Add more test cases when implementation is ready
+	}
+
+	translator := NewChatCompletionOpenAIToGCPVertexAITranslator()
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			headerMut, err := translator.ResponseHeaders(tc.headers)
+			if tc.wantError {
+				assert.Error(t, err)
+				return
+			}
+			require.NoError(t, err)
+
+			if diff := cmp.Diff(tc.wantHeaderMut, headerMut); diff != "" {
+				t.Errorf("HeaderMutation mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
+
+// TestOpenAIToGCPVertexAITranslatorV1ChatCompletion_ResponseBody verifies that the stubbed ResponseBody returns
+// nil mutations and zero token usage for every payload and endOfStream value in the table.
+func TestOpenAIToGCPVertexAITranslatorV1ChatCompletion_ResponseBody(t *testing.T) {
+	tests := []struct {
+		name           string
+		respHeaders    map[string]string
+		body           string
+		endOfStream    bool
+		wantError      bool
+		wantHeaderMut  *extprocv3.HeaderMutation
+		wantBodyMut    *extprocv3.BodyMutation
+		wantTokenUsage LLMTokenUsage
+	}{
+		{
+			name: "successful response",
+			respHeaders: map[string]string{
+				"content-type": "application/json",
+			},
+			// NOTE(review): Gemini responses appear to report usage under "usageMetadata"
+			// (promptTokenCount, candidatesTokenCount, totalTokenCount) rather than "usage" — confirm
+			// and align this fixture once ResponseBody actually parses token usage.
+			body: `{
+				"candidates": [
+					{
+						"content": {
+							"parts": [
+								{
+									"text": "AI Gateways act as intermediaries between clients and LLM services."
+								}
+							]
+						},
+						"finishReason": "STOP",
+						"safetyRatings": []
+					}
+				],
+				"promptFeedback": {
+					"safetyRatings": []
+				},
+				"usage": {
+					"promptTokens": 10,
+					"candidatesTokens": 15,
+					"totalTokens": 25
+				}
+			}`,
+			endOfStream:   true,
+			wantError:     false,
+			wantHeaderMut: nil,
+			wantBodyMut:   nil,
+			// All zero for now: the stubbed ResponseBody does not read the usage block in the payload above.
+			wantTokenUsage: LLMTokenUsage{
+				InputTokens:  0,
+				OutputTokens: 0,
+				TotalTokens:  0,
+			},
+		},
+		{
+			name: "streaming chunk",
+			respHeaders: map[string]string{
+				"content-type": "application/json",
+			},
+			body: `{
+				"candidates": [
+					{
+						"content": {
+							"parts": [
+								{
+									"text": "AI"
+								}
+							]
+						}
+					}
+				]
+			}`,
+			endOfStream:    false,
+			wantError:      false,
+			wantHeaderMut:  nil,
+			wantBodyMut:    nil,
+			wantTokenUsage: LLMTokenUsage{},
+		},
+		{
+			name: "empty response",
+			respHeaders: map[string]string{
+				"content-type": "application/json",
+			},
+			body:           `{}`,
+			endOfStream:    true,
+			wantError:      false,
+			wantHeaderMut:  nil,
+			wantBodyMut:    nil,
+			wantTokenUsage: LLMTokenUsage{},
+		},
+	}
+
+	translator := NewChatCompletionOpenAIToGCPVertexAITranslator()
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			reader := bytes.NewReader([]byte(tc.body))
+			headerMut, bodyMut, tokenUsage, err := translator.ResponseBody(tc.respHeaders, reader, tc.endOfStream)
+			if tc.wantError {
+				assert.Error(t, err)
+				return
+			}
+			require.NoError(t, err)
+
+			if diff := cmp.Diff(tc.wantHeaderMut, headerMut); diff != "" {
+				t.Errorf("HeaderMutation mismatch (-want +got):\n%s", diff)
+			}
+
+			if diff := cmp.Diff(tc.wantBodyMut, bodyMut); diff != "" {
+				t.Errorf("BodyMutation mismatch (-want +got):\n%s", diff)
+			}
+
+			if diff := cmp.Diff(tc.wantTokenUsage, tokenUsage); diff != "" {
+				t.Errorf("TokenUsage mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
diff --git a/manifests/charts/ai-gateway-crds-helm/templates/aigateway.envoyproxy.io_backendsecuritypolicies.yaml b/manifests/charts/ai-gateway-crds-helm/templates/aigateway.envoyproxy.io_backendsecuritypolicies.yaml
index 4140188e07..6b30cc87df 100644
--- a/manifests/charts/ai-gateway-crds-helm/templates/aigateway.envoyproxy.io_backendsecuritypolicies.yaml +++ b/manifests/charts/ai-gateway-crds-helm/templates/aigateway.envoyproxy.io_backendsecuritypolicies.yaml @@ -2469,6 +2469,14 @@ spec: description: GCPCredentials is a mechanism to access a backend(s). GCP specific logic will be applied. properties: + projectName: + description: ProjectName is the GCP project name. + minLength: 1 + type: string + region: + description: Region is the GCP region associated with the policy. + minLength: 1 + type: string workLoadIdentityFederationConfig: description: WorkLoadIdentityFederationConfig is the configuration for the GCP Workload Identity Federation. @@ -3668,6 +3676,8 @@ spec: - workloadIdentityProvider type: object required: + - projectName + - region - workLoadIdentityFederationConfig type: object type: diff --git a/manifests/charts/ai-gateway-helm/values.yaml b/manifests/charts/ai-gateway-helm/values.yaml index 6b42abddb6..e15ef3c0ec 100644 --- a/manifests/charts/ai-gateway-helm/values.yaml +++ b/manifests/charts/ai-gateway-helm/values.yaml @@ -54,6 +54,9 @@ controller: # Azure authentication request will be configured to use AI_GATEWAY_AZURE_PROXY_URL proxy if set. # - key: AI_GATEWAY_AZURE_PROXY_URL # value: some-proxy-placeholder + # GCP authentication request will be configured to use AI_GATEWAY_GCP_AUTH_PROXY_URL proxy if set. + # - key: AI_GATEWAY_GCP_AUTH_PROXY_URL + # value: some-proxy-placeholder podEnv: {} # Example of volumes # - mountPath: /placeholder/path diff --git a/site/docs/api/api.mdx b/site/docs/api/api.mdx index c879eabfb5..22ef60f4bc 100644 --- a/site/docs/api/api.mdx +++ b/site/docs/api/api.mdx @@ -853,6 +853,16 @@ BackendSecurityPolicyGCPCredentials contains the supported authentication mechan