diff --git a/lib/services/github.go b/lib/services/github.go index 1294b7db5cc7a..0444cc5f0b2ac 100644 --- a/lib/services/github.go +++ b/lib/services/github.go @@ -176,12 +176,20 @@ func MarshalOSSGithubConnector(githubConnector types.GithubConnector, opts ...Ma return nil, trace.Wrap(err) } - // Only return an error if the endpoint url is set and the build is OSS - // so that the enterprise marshaler can call this marshaler to produce - // the final output without receiving an error. - if modules.GetModules().IsOSSBuild() && - githubConnector.Spec.EndpointURL != "" { - return nil, fmt.Errorf("GitHub endpoint URL is set: %w", ErrRequiresEnterprise) + // Return an error for OSS build if the endpoint url is set, but it is + // not the public GitHub endpoint. Empty endpoint url is also allowed. + // + // Note that the enterprise marshaler also calls this marshaler to + // produce the final output. + if modules.GetModules().IsOSSBuild() { + if githubConnector.Spec.EndpointURL != "" && + githubConnector.Spec.EndpointURL != types.GithubURL { + return nil, fmt.Errorf("GitHub endpoint URL is set: %w", ErrRequiresEnterprise) + } + if githubConnector.Spec.APIEndpointURL != "" && + githubConnector.Spec.APIEndpointURL != types.GithubAPIURL { + return nil, fmt.Errorf("GitHub API endpoint URL is set: %w", ErrRequiresEnterprise) + } } return utils.FastMarshal(maybeResetProtoRevision(cfg.PreserveRevision, githubConnector)) default: diff --git a/lib/services/github_test.go b/lib/services/github_test.go index c14ea618e69f3..9c4169126caab 100644 --- a/lib/services/github_test.go +++ b/lib/services/github_test.go @@ -68,12 +68,13 @@ func TestUnmarshal(t *testing.T) { } func TestMarshal(t *testing.T) { - connector, err := types.NewGithubConnector("github", types.GithubConnectorSpecV3{ - ClientID: "aaa", - ClientSecret: "bbb", - RedirectURL: "https://localhost:3080/v1/webapi/github/callback", - Display: "GitHub", - EndpointURL: "https://github.com", + connectorWithPublicEndpoint, err := types.NewGithubConnector("github", types.GithubConnectorSpecV3{ + ClientID: "aaa", + ClientSecret: "bbb", + RedirectURL: "https://localhost:3080/v1/webapi/github/callback", + Display: "GitHub", + EndpointURL: "https://github.com", + APIEndpointURL: "https://api.github.com", TeamsToRoles: []types.TeamRolesMapping{ { Organization: "gravitational", @@ -84,20 +85,46 @@ func TestMarshal(t *testing.T) { }) require.NoError(t, err) - t.Run("oss", func(t *testing.T) { - _, err = MarshalGithubConnector(connector) + connectorWithPrivateEndpoint, err := types.NewGithubConnector("github", types.GithubConnectorSpecV3{ + ClientID: "aaa", + ClientSecret: "bbb", + RedirectURL: "https://localhost:3080/v1/webapi/github/callback", + Display: "GitHub", + EndpointURL: "https://my-private-github.com", + APIEndpointURL: "https://api.my-private-github.com", + TeamsToRoles: []types.TeamRolesMapping{ + { + Organization: "gravitational", + Team: "admins", + Roles: []string{teleport.PresetAccessRoleName}, + }, + }, + }) + require.NoError(t, err) + + t.Run("oss with public endpoint", func(t *testing.T) { + marshaled, err := MarshalGithubConnector(connectorWithPublicEndpoint) + require.NoError(t, err) + + unmarshaled, err := UnmarshalGithubConnector(marshaled) + require.NoError(t, err) + require.Empty(t, cmp.Diff(connectorWithPublicEndpoint, unmarshaled)) + }) + + t.Run("oss with private endpoint", func(t *testing.T) { + _, err := MarshalGithubConnector(connectorWithPrivateEndpoint) require.ErrorIs(t, err, ErrRequiresEnterprise, "expected ErrRequiresEnterprise, got %T", err) }) t.Run("enterprise", func(t *testing.T) { modules.SetTestModules(t, &modules.TestModules{TestBuildType: modules.BuildEnterprise}) - marshaled, err := MarshalGithubConnector(connector) + marshaled, err := MarshalGithubConnector(connectorWithPrivateEndpoint) require.NoError(t, err) unmarshaled, err := UnmarshalGithubConnector(marshaled) require.NoError(t, err) - require.Empty(t, cmp.Diff(connector, unmarshaled)) + require.Empty(t, cmp.Diff(connectorWithPrivateEndpoint, unmarshaled)) }) }