Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 64 additions & 15 deletions api/types/integration.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,15 @@ const (
type Integration interface {
ResourceWithLabels

// CanChangeStateTo checks if the current Integration can be updated for the provided integration.
CanChangeStateTo(Integration) error

// GetAWSOIDCIntegrationSpec returns the `aws-oidc` spec fields.
GetAWSOIDCIntegrationSpec() *AWSOIDCIntegrationSpecV1
// SetAWSOIDCIntegrationSpec sets the `aws-oidc` spec fields.
SetAWSOIDCIntegrationSpec(*AWSOIDCIntegrationSpecV1)
// SetAWSOIDCRoleARN sets the RoleARN of the AWS OIDC Spec.
SetAWSOIDCRoleARN(string)
}

var _ ResourceWithLabels = (*IntegrationV1)(nil)
Expand Down Expand Up @@ -92,6 +97,19 @@ func (ig *IntegrationV1) CheckAndSetDefaults() error {
return trace.Wrap(ig.Spec.CheckAndSetDefaults())
}

// CanChangeStateTo checks if the current Integration can be updated for the provided integration.
func (ig *IntegrationV1) CanChangeStateTo(newState Integration) error {
if ig.SubKind != newState.GetSubKind() {
return trace.BadParameter("cannot update %q fields for a %q integration", newState.GetSubKind(), ig.SubKind)
}

if err := newState.CheckAndSetDefaults(); err != nil {
return trace.Wrap(err)
}

return nil
}

// CheckAndSetDefaults validates and sets default values for a integration.
func (s *IntegrationSpecV1) CheckAndSetDefaults() error {
if s.SubKindSpec == nil {
Expand Down Expand Up @@ -136,6 +154,19 @@ func (ig *IntegrationV1) SetAWSOIDCIntegrationSpec(awsOIDCSpec *AWSOIDCIntegrati
}
}

// SetAWSOIDCRoleARN sets the RoleARN of the AWS OIDC Spec.
func (ig *IntegrationV1) SetAWSOIDCRoleARN(roleARN string) {
currentSubSpec := ig.Spec.GetAWSOIDC()
if currentSubSpec == nil {
currentSubSpec = &AWSOIDCIntegrationSpecV1{}
}

currentSubSpec.RoleARN = roleARN
ig.Spec.SubKindSpec = &IntegrationSpecV1_AWSOIDC{
AWSOIDC: currentSubSpec,
}
}

// Integrations is a list of Integration resources.
type Integrations []Integration

Expand All @@ -161,7 +192,7 @@ func (igs Integrations) Swap(i, j int) { igs[i], igs[j] = igs[j], igs[i] }
// It is required because the Spec.SubKindSpec proto field is a oneof.
// This translates into two issues when generating golang code:
// - the Spec.SubKindSpec field in Go is an interface
// - there's no way to provide json tags for oneof fields, so instead of snake_case, we get CamelCase for the Spec.SubKindSpec field
// - it creates an extra field to store the oneof values
//
// Spec.SubKindSpec is an interface because it can have one of multiple values,
// even though there's only one type for now: aws_oidc.
Expand All @@ -170,16 +201,22 @@ func (igs Integrations) Swap(i, j int) { igs[i], igs[j] = igs[j], igs[i] }
// and then use its SubKind to provide a concrete type for the Spec.SubKindSpec field.
// Unmarshalling the remaining fields uses the standard json.Unmarshal over the Spec field.
//
// Spec.SubKindSpec is expecting the `SubKindSpec` json tag, however we are using snake_case everywhere.
// So, we create a local type that has the expected json tag (`sub_kind_spec`) and use it to unmarshal and then copy
// to the proper type.
// Spec.SubKindSpec is an extra field which only adds clutter
// This method pulls those fields into a higher level.
// So, instead of:
//
// spec.subkind_spec.aws_oidc.role_arn: xyz
//
// It will be:
//
// spec.aws_oidc.role_arn: xyz
func (ig *IntegrationV1) UnmarshalJSON(data []byte) error {
var integration IntegrationV1

d := struct {
ResourceHeader `json:""`
Spec struct {
RawSubKindSpec json.RawMessage `json:"subkind_spec"`
AWSOIDC json.RawMessage `json:"aws_oidc"`
} `json:"spec"`
}{}

Expand All @@ -190,20 +227,22 @@ func (ig *IntegrationV1) UnmarshalJSON(data []byte) error {

integration.ResourceHeader = d.ResourceHeader

var subkindSpec isIntegrationSpecV1_SubKindSpec
switch integration.SubKind {
case IntegrationSubKindAWSOIDC:
subkindSpec = &IntegrationSpecV1_AWSOIDC{}
subkindSpec := &IntegrationSpecV1_AWSOIDC{
AWSOIDC: &AWSOIDCIntegrationSpecV1{},
}

if err := json.Unmarshal(d.Spec.AWSOIDC, subkindSpec.AWSOIDC); err != nil {
return trace.Wrap(err)
}

integration.Spec.SubKindSpec = subkindSpec

default:
return trace.BadParameter("invalid subkind %q", integration.ResourceHeader.SubKind)
}

if err := json.Unmarshal(d.Spec.RawSubKindSpec, subkindSpec); err != nil {
return trace.Wrap(err)
}

integration.Spec.SubKindSpec = subkindSpec

if err := integration.CheckAndSetDefaults(); err != nil {
return trace.Wrap(err)
}
Expand All @@ -220,12 +259,22 @@ func (ig *IntegrationV1) MarshalJSON() ([]byte, error) {
d := struct {
ResourceHeader `json:""`
Spec struct {
SubKindSpec isIntegrationSpecV1_SubKindSpec `json:"subkind_spec"`
AWSOIDC AWSOIDCIntegrationSpecV1 `json:"aws_oidc"`
} `json:"spec"`
}{}

d.ResourceHeader = ig.ResourceHeader
d.Spec.SubKindSpec = ig.Spec.SubKindSpec

switch ig.SubKind {
case IntegrationSubKindAWSOIDC:
if ig.GetAWSOIDCIntegrationSpec() == nil {
return nil, trace.BadParameter("missing subkind data for %q subkind", ig.SubKind)
}

d.Spec.AWSOIDC = *ig.GetAWSOIDCIntegrationSpec()
Comment thread
marcoandredinis marked this conversation as resolved.
Outdated
default:
return nil, trace.BadParameter("invalid subkind %q", ig.SubKind)
}

out, err := json.Marshal(d)
return out, trace.Wrap(err)
Expand Down
20 changes: 4 additions & 16 deletions integration/conntest/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"context"
"encoding/base32"
"encoding/json"
"io"
"net"
"net/http"
"strings"
Expand Down Expand Up @@ -216,14 +215,8 @@ func TestDiagnoseConnectionForPostgresDatabases(t *testing.T) {
DialTimeout: time.Second,
InsecureSkipVerify: true,
}
resp, err := webPack.DoRequest(http.MethodPost, diagnoseConnectionEndpoint, diagnoseReq)
require.NoError(t, err)

respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)

defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode, string(respBody))
respStatusCode, respBody := webPack.DoRequest(t, http.MethodPost, diagnoseConnectionEndpoint, diagnoseReq)
require.Equal(t, http.StatusOK, respStatusCode, string(respBody))

var connectionDiagnostic ui.ConnectionDiagnostic
require.NoError(t, json.Unmarshal(respBody, &connectionDiagnostic))
Expand Down Expand Up @@ -307,13 +300,8 @@ func TestDiagnoseConnectionForPostgresDatabases(t *testing.T) {
TOTPCode: validToken,
},
}
resp, err := webPack.DoRequest(http.MethodPost, diagnoseConnectionEndpoint, diagnoseReq)
require.NoError(t, err)
respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)

defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode, string(respBody))
respStatusCode, respBody := webPack.DoRequest(t, http.MethodPost, diagnoseConnectionEndpoint, diagnoseReq)
require.Equal(t, http.StatusOK, respStatusCode, string(respBody))

var connectionDiagnostic ui.ConnectionDiagnostic
require.NoError(t, json.Unmarshal(respBody, &connectionDiagnostic))
Expand Down
11 changes: 2 additions & 9 deletions integration/db/database_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package db
import (
"context"
"encoding/json"
"io"
"net/http"
"strings"
"testing"
Expand Down Expand Up @@ -85,14 +84,8 @@ func TestDatabaseServiceHeartbeat(t *testing.T) {

// List Database Services
listDBServicesEndpoint := strings.Join([]string{"sites", "$site", "databaseservices"}, "/")
resp, err := webPack.DoRequest(http.MethodGet, listDBServicesEndpoint, nil)
require.NoError(t, err)

respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)

defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode, string(respBody))
respStatusCode, respBody := webPack.DoRequest(t, http.MethodGet, listDBServicesEndpoint, nil)
require.Equal(t, http.StatusOK, respStatusCode, string(respBody))

var listResp listDatabaseServicesResp
require.NoError(t, json.Unmarshal(respBody, &listResp))
Expand Down
38 changes: 19 additions & 19 deletions integration/helpers/web.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ import (
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"testing"

"github.com/gravitational/trace"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/lib/httplib/csrf"
Expand Down Expand Up @@ -100,12 +100,11 @@ func LoginWebClient(t *testing.T, host, username, password string) *WebClientPac
bearerToken: csResp.Token,
}

resp, err = webClient.DoRequest(http.MethodGet, "sites", nil)
require.NoError(t, err)
defer resp.Body.Close()
respStatusCode, bs := webClient.DoRequest(t, http.MethodGet, "sites", nil)
require.Equal(t, http.StatusOK, respStatusCode, string(bs))

var clusters []ui.Cluster
require.NoError(t, json.NewDecoder(resp.Body).Decode(&clusters))
require.NoError(t, json.Unmarshal(bs, &clusters), string(bs))
require.NotEmpty(t, clusters)

webClient.clusterName = clusters[0].Name
Expand All @@ -114,24 +113,18 @@ func LoginWebClient(t *testing.T, host, username, password string) *WebClientPac

// DoRequest receives a method, endpoint and payload and sends an HTTP Request to the Teleport API.
// The endpoint must not contain the host neither the base path ('/v1/webapi/').
// Returns the http.Response.
func (w *WebClientPack) DoRequest(method, endpoint string, payload any) (*http.Response, error) {
// Status Code and Body are returned.
func (w *WebClientPack) DoRequest(t *testing.T, method, endpoint string, payload any) (int, []byte) {
endpoint = fmt.Sprintf("https://%s/v1/webapi/%s", w.host, endpoint)
endpoint = strings.ReplaceAll(endpoint, "$site", w.clusterName)
u := url.URL{
Scheme: "https",
Host: w.host,
Path: fmt.Sprintf("/v1/webapi/%s", endpoint),
}
u, err := url.Parse(endpoint)
require.NoError(t, err)

bs, err := json.Marshal(payload)
if err != nil {
return nil, trace.Wrap(err)
}
require.NoError(t, err)

req, err := http.NewRequest(method, u.String(), bytes.NewBuffer(bs))
if err != nil {
return nil, trace.Wrap(err)
}
require.NoError(t, err)

req.AddCookie(&http.Cookie{
Name: web.CookieName,
Expand All @@ -141,5 +134,12 @@ func (w *WebClientPack) DoRequest(method, endpoint string, payload any) (*http.R
req.Header.Add("Content-Type", "application/json")

resp, err := w.clt.Do(req)
return resp, trace.Wrap(err)
require.NoError(t, err)

defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
require.NoError(t, err)

return resp.StatusCode, body
}
Loading