Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow oauth configuration per site and backend #542

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions common/djangoapps/third_party_auth/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,13 +361,12 @@ class OAuth2ProviderConfig(ProviderConfig):

.. no_pii:
"""
# We are keying the provider config by backend_name here as suggested in the python social
# auth documentation. In order to reuse a backend for a second provider, a subclass can be
# created with seperate name.
# We are keying the provider config by backend_name and site_id to support configuration per site.
# In order to reuse a backend for a second provider, a subclass can be created with seperate name.
# example:
# class SecondOpenIDProvider(OpenIDAuth):
# name = "second-openId-provider"
KEY_FIELDS = ('backend_name',)
KEY_FIELDS = ('site_id', 'backend_name')
navinkarkera marked this conversation as resolved.
Show resolved Hide resolved
prefix = 'oa2'
backend_name = models.CharField(
max_length=50, blank=False, db_index=True,
Expand Down Expand Up @@ -396,6 +395,29 @@ class Meta:
verbose_name = "Provider Configuration (OAuth)"
verbose_name_plural = verbose_name

@classmethod
def current(cls, *args):
"""
Get the current config model for the provider according to the given backend and the current
site.
"""
site_id = Site.objects.get_current(get_current_request()).id
return super(OAuth2ProviderConfig, cls).current(site_id, *args)

@property
def provider_id(self):
"""
Unique string key identifying this provider. Must be URL and css class friendly.
Ignoring site_id as the config is filtered using current method which fetches the configuration for the current
site_id.
"""
assert self.prefix is not None
return "-".join((self.prefix, ) + tuple(
str(getattr(self, field))
for field in self.KEY_FIELDS
if field != 'site_id'
))

def clean(self):
""" Standardize and validate fields """
super().clean()
Expand Down
4 changes: 2 additions & 2 deletions common/djangoapps/third_party_auth/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,7 +854,7 @@ def user_details_force_sync(auth_entry, strategy, details, user=None, *args, **k
This step is controlled by the `sync_learner_profile_data` flag on the provider's configuration.
"""
current_provider = provider.Registry.get_from_pipeline({'backend': strategy.request.backend.name, 'kwargs': kwargs})
if user and current_provider.sync_learner_profile_data:
if user and current_provider and current_provider.sync_learner_profile_data:
# Keep track of which incoming values get applied.
changed = {}

Expand Down Expand Up @@ -931,7 +931,7 @@ def set_id_verification_status(auth_entry, strategy, details, user=None, *args,
Use the user's authentication with the provider, if configured, as evidence of their identity being verified.
"""
current_provider = provider.Registry.get_from_pipeline({'backend': strategy.request.backend.name, 'kwargs': kwargs})
if user and current_provider.enable_sso_id_verification:
if user and current_provider and current_provider.enable_sso_id_verification:
# Get previous valid, non expired verification attempts for this SSO Provider and user
verifications = SSOVerification.objects.filter(
user=user,
Expand Down
11 changes: 9 additions & 2 deletions common/djangoapps/third_party_auth/tests/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_providers_displayed_for_login(self):
assert no_log_in_provider.provider_id not in provider_ids
assert normal_provider.provider_id in provider_ids

def test_tpa_hint_provider_displayed_for_login(self):
def test_tpa_hint_hidden_provider_displayed_for_login(self):
kaustavb12 marked this conversation as resolved.
Show resolved Hide resolved
"""
Tests to ensure that an enabled-but-not-visible provider is presented
for use in the UI when the "tpa_hint" parameter is specified
Expand All @@ -128,6 +128,7 @@ def test_tpa_hint_provider_displayed_for_login(self):
]
assert hidden_provider.provider_id in provider_ids

def test_tpa_hint_exp_hidden_provider_displayed_for_login(self):
# New providers are hidden (ie, not flagged as 'visible') by default
# The tpa_hint parameter should work for these providers as well
implicitly_hidden_provider = self.configure_linkedin_provider(enabled=True)
Expand All @@ -137,6 +138,7 @@ def test_tpa_hint_provider_displayed_for_login(self):
]
assert implicitly_hidden_provider.provider_id in provider_ids

def test_tpa_hint_disabled_hidden_provider_displayed_for_login(self):
# Disabled providers should not be matched in tpa_hint scenarios
disabled_provider = self.configure_twitter_provider(visible=True, enabled=False)
provider_ids = [
Expand All @@ -145,6 +147,7 @@ def test_tpa_hint_provider_displayed_for_login(self):
]
assert disabled_provider.provider_id not in provider_ids

def test_tpa_hint_no_log_hidden_provider_displayed_for_login(self):
# Providers not utilized for learner authentication should not match tpa_hint
no_log_in_provider = self.configure_lti_provider()
provider_ids = [
Expand Down Expand Up @@ -201,14 +204,18 @@ def test_oauth2_enabled_only_for_supplied_backend(self):
def test_get_returns_none_if_provider_id_is_none(self):
assert provider.Registry.get(None) is None

def test_get_returns_none_if_provider_not_enabled(self):
def test_get_returns_none_if_provider_not_enabled_change(self):
linkedin_provider_id = "oa2-linkedin-oauth2"
# At this point there should be no configuration entries at all so no providers should be enabled
assert provider.Registry.enabled() == []
assert provider.Registry.get(linkedin_provider_id) is None
# Now explicitly disabled this provider:
self.configure_linkedin_provider(enabled=False)
assert provider.Registry.get(linkedin_provider_id) is None

def test_get_returns_provider_if_provider_enabled(self):
"""Test to ensure that Registry gets enabled providers."""
linkedin_provider_id = "oa2-linkedin-oauth2"
self.configure_linkedin_provider(enabled=True)
assert provider.Registry.get(linkedin_provider_id).provider_id == linkedin_provider_id

Expand Down
Loading