Skip to content

Commit

Permalink
feat: allow oauth configuration per site and backend (#579)
Browse files Browse the repository at this point in the history
* feat: allow oauth configuration per site and backend

Allows admins to configure same oauth backend for multiple sites

(cherry picked from commit efd2ef9)

* fix: skip pipeline if oauth provider for site is not setup

(cherry picked from commit 29f8494)

* refactor: remove site_id from provider_id

(cherry picked from commit 236f7a5)

---------

Cherry pick of #542 

Upstream PR openedx#32656

[BB-7589](https://tasks.opencraft.com/browse/BB-7589)

---------

Co-authored-by: Navin Karkera <[email protected]>
  • Loading branch information
kaustavb12 and navinkarkera committed Aug 15, 2023
1 parent 13bda07 commit 86041f9
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 8 deletions.
30 changes: 26 additions & 4 deletions common/djangoapps/third_party_auth/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,13 +366,12 @@ class OAuth2ProviderConfig(ProviderConfig):
.. no_pii:
"""
# We are keying the provider config by backend_name here as suggested in the python social
# auth documentation. In order to reuse a backend for a second provider, a subclass can be
# created with seperate name.
# We are keying the provider config by backend_name and site_id to support configuration per site.
# In order to reuse a backend for a second provider, a subclass can be created with seperate name.
# example:
# class SecondOpenIDProvider(OpenIDAuth):
# name = "second-openId-provider"
KEY_FIELDS = ('backend_name',)
KEY_FIELDS = ('site_id', 'backend_name')
prefix = 'oa2'
backend_name = models.CharField(
max_length=50, blank=False, db_index=True,
Expand Down Expand Up @@ -401,6 +400,29 @@ class Meta:
verbose_name = "Provider Configuration (OAuth)"
verbose_name_plural = verbose_name

@classmethod
def current(cls, *args):
"""
Get the current config model for the provider according to the given backend and the current
site.
"""
site_id = Site.objects.get_current(get_current_request()).id
return super(OAuth2ProviderConfig, cls).current(site_id, *args)

@property
def provider_id(self):
"""
Unique string key identifying this provider. Must be URL and css class friendly.
Ignoring site_id as the config is filtered using current method which fetches the configuration for the current
site_id.
"""
assert self.prefix is not None
return "-".join((self.prefix, ) + tuple(
str(getattr(self, field))
for field in self.KEY_FIELDS
if field != 'site_id'
))

def clean(self):
""" Standardize and validate fields """
super().clean()
Expand Down
4 changes: 2 additions & 2 deletions common/djangoapps/third_party_auth/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,7 +853,7 @@ def user_details_force_sync(auth_entry, strategy, details, user=None, *args, **k
This step is controlled by the `sync_learner_profile_data` flag on the provider's configuration.
"""
current_provider = provider.Registry.get_from_pipeline({'backend': strategy.request.backend.name, 'kwargs': kwargs})
if user and current_provider.sync_learner_profile_data:
if user and current_provider and current_provider.sync_learner_profile_data:
# Keep track of which incoming values get applied.
changed = {}

Expand Down Expand Up @@ -930,7 +930,7 @@ def set_id_verification_status(auth_entry, strategy, details, user=None, *args,
Use the user's authentication with the provider, if configured, as evidence of their identity being verified.
"""
current_provider = provider.Registry.get_from_pipeline({'backend': strategy.request.backend.name, 'kwargs': kwargs})
if user and current_provider.enable_sso_id_verification:
if user and current_provider and current_provider.enable_sso_id_verification:
# Get previous valid, non expired verification attempts for this SSO Provider and user
verifications = SSOVerification.objects.filter(
user=user,
Expand Down
11 changes: 9 additions & 2 deletions common/djangoapps/third_party_auth/tests/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_providers_displayed_for_login(self):
assert no_log_in_provider.provider_id not in provider_ids
assert normal_provider.provider_id in provider_ids

def test_tpa_hint_provider_displayed_for_login(self):
def test_tpa_hint_hidden_provider_displayed_for_login(self):
"""
Tests to ensure that an enabled-but-not-visible provider is presented
for use in the UI when the "tpa_hint" parameter is specified
Expand All @@ -128,6 +128,7 @@ def test_tpa_hint_provider_displayed_for_login(self):
]
assert hidden_provider.provider_id in provider_ids

def test_tpa_hint_exp_hidden_provider_displayed_for_login(self):
# New providers are hidden (ie, not flagged as 'visible') by default
# The tpa_hint parameter should work for these providers as well
implicitly_hidden_provider = self.configure_linkedin_provider(enabled=True)
Expand All @@ -137,6 +138,7 @@ def test_tpa_hint_provider_displayed_for_login(self):
]
assert implicitly_hidden_provider.provider_id in provider_ids

def test_tpa_hint_disabled_hidden_provider_displayed_for_login(self):
# Disabled providers should not be matched in tpa_hint scenarios
disabled_provider = self.configure_twitter_provider(visible=True, enabled=False)
provider_ids = [
Expand All @@ -145,6 +147,7 @@ def test_tpa_hint_provider_displayed_for_login(self):
]
assert disabled_provider.provider_id not in provider_ids

def test_tpa_hint_no_log_hidden_provider_displayed_for_login(self):
# Providers not utilized for learner authentication should not match tpa_hint
no_log_in_provider = self.configure_lti_provider()
provider_ids = [
Expand Down Expand Up @@ -201,14 +204,18 @@ def test_oauth2_enabled_only_for_supplied_backend(self):
def test_get_returns_none_if_provider_id_is_none(self):
assert provider.Registry.get(None) is None

def test_get_returns_none_if_provider_not_enabled(self):
def test_get_returns_none_if_provider_not_enabled_change(self):
linkedin_provider_id = "oa2-linkedin-oauth2"
# At this point there should be no configuration entries at all so no providers should be enabled
assert provider.Registry.enabled() == []
assert provider.Registry.get(linkedin_provider_id) is None
# Now explicitly disabled this provider:
self.configure_linkedin_provider(enabled=False)
assert provider.Registry.get(linkedin_provider_id) is None

def test_get_returns_provider_if_provider_enabled(self):
"""Test to ensure that Registry gets enabled providers."""
linkedin_provider_id = "oa2-linkedin-oauth2"
self.configure_linkedin_provider(enabled=True)
assert provider.Registry.get(linkedin_provider_id).provider_id == linkedin_provider_id

Expand Down

0 comments on commit 86041f9

Please sign in to comment.