Skip to content

Commit

Permalink
Merge pull request #5 from DevXT-LLC/Add-google-auth
Browse files Browse the repository at this point in the history
Add SSO (Single Sign On) Providers
  • Loading branch information
Josh-XT authored Jun 14, 2024
2 parents cce0535 + e61d5d2 commit 82298c5
Show file tree
Hide file tree
Showing 181 changed files with 10,703 additions and 26 deletions.
19 changes: 19 additions & 0 deletions DB.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,25 @@ class User(Base):
is_active = Column(Boolean, default=True)


class UserOAuth(Base):
__tablename__ = "user_oauth"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), ForeignKey("user.id"))
user = relationship("User")
provider_id = Column(UUID(as_uuid=True), ForeignKey("oauth_provider.id"))
provider = relationship("OAuthProvider")
access_token = Column(String, default="", nullable=False)
refresh_token = Column(String, default="", nullable=False)
created_at = Column(DateTime, server_default=func.now())
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())


class OAuthProvider(Base):
__tablename__ = "oauth_provider"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name = Column(String, default="", nullable=False)


class FailedLogins(Base):
__tablename__ = "failed_logins"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
Expand Down
9 changes: 4 additions & 5 deletions Globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ def getenv(var_name: str):
"MODE": "production",
"DATABASE_TYPE": "postgres",
"ALLOWED_DOMAINS": "*",
"ENCRYPTION_SECRET": "n0ne",
"ENCRYPTION_SECRET": "it-is-a-secret-to-everybody",
"APP_NAME": "Magical Auth",
"AUTH_PROVIDER": "magicalauth",
"MAGIC_LINK_URL": "https://localhost:8507/",
"MAGIC_LINK_URL": "https://localhost:8519",
"LOG_LEVEL": "INFO",
"LOG_FORMAT": "%(asctime)s | %(levelname)s | %(message)s",
"UVICORN_WORKERS": 1,
Expand All @@ -23,9 +23,8 @@ def getenv(var_name: str):
),
"DATABASE_USER": "postgres",
"DATABASE_PASSWORD": "postgres",
"DATABASE_HOST": "notesdb",
"DATABASE_HOST": "magicalauthdb",
"DATABASE_PORT": "5432",
"SELECTED_EHR": "None",
}
default_value = default_values[var_name] if var_name in default_values else None
default_value = default_values[var_name] if var_name in default_values else ""
return os.getenv(var_name, default_value)
89 changes: 87 additions & 2 deletions MagicalAuth.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from DB import User, FailedLogins, get_session
from DB import User, FailedLogins, UserOAuth, OAuthProvider, get_session
from Models import UserInfo, Register, Login
from fastapi import Header, HTTPException
from Globals import getenv
from OAuth2Providers import get_sso_provider
from datetime import datetime, timedelta
from fastapi import HTTPException
from sendgrid import SendGridAPIClient
Expand Down Expand Up @@ -184,7 +185,13 @@ def count_failed_logins(self):
session.close()
return failed_logins

def send_magic_link(self, ip_address, login: Login, referrer=None):
def send_magic_link(
self,
ip_address,
login: Login,
referrer=None,
send_link: bool = True,
):
self.email = login.email.lower()
session = get_session()
user = session.query(User).filter(User.email == self.email).first()
Expand Down Expand Up @@ -243,6 +250,7 @@ def send_magic_link(self, ip_address, login: Login, referrer=None):
and str(getenv("SENDGRID_API_KEY")).lower() != "none"
and getenv("SENDGRID_FROM_EMAIL") != ""
and str(getenv("SENDGRID_FROM_EMAIL")).lower() != "none"
and send_link
):
send_email(
email=self.email,
Expand Down Expand Up @@ -364,3 +372,80 @@ def delete_user(self):
session.commit()
session.close()
return "User deleted successfully"

def sso(
self,
provider,
code,
ip_address,
referrer=None,
):
if not referrer:
referrer = getenv("MAGIC_LINK_URL")
provider = str(provider).lower()
sso_data = None
sso_data = get_sso_provider(provider=provider, code=code, redirect_uri=referrer)
if not sso_data:
raise HTTPException(
status_code=400,
detail=f"Failed to get user data from {provider.capitalize()}.",
)
if not sso_data.access_token:
raise HTTPException(
status_code=400,
detail=f"Failed to get access token from {provider.capitalize()}.",
)
user_data = sso_data.user_info
access_token = sso_data.access_token
refresh_token = sso_data.refresh_token
self.email = str(user_data["email"]).lower()
if not user_data:
logging.warning(f"Error on {provider.capitalize()}: {user_data}")
raise HTTPException(
status_code=400,
detail=f"Failed to get user data from {provider.capitalize()}.",
)
session = get_session()
user = session.query(User).filter(User.email == self.email).first()
if not user:
register = Register(
email=self.email,
first_name=user_data["first_name"] if "first_name" in user_data else "",
last_name=user_data["last_name"] if "last_name" in user_data else "",
)
mfa_token = self.register(new_user=register)
# Create the UserOAuth record
user = session.query(User).filter(User.email == self.email).first()
provider = (
session.query(OAuthProvider)
.filter(OAuthProvider.name == provider)
.first()
)
if not provider:
provider = OAuthProvider(name=provider)
session.add(provider)
user_oauth = UserOAuth(
user_id=user.id,
provider_id=provider.id,
access_token=access_token,
refresh_token=refresh_token,
)
session.add(user_oauth)
else:
mfa_token = user.mfa_token
user_oauth = (
session.query(UserOAuth).filter(UserOAuth.user_id == user.id).first()
)
if user_oauth:
user_oauth.access_token = access_token
user_oauth.refresh_token = refresh_token
session.commit()
session.close()
totp = pyotp.TOTP(mfa_token)
login = Login(email=self.email, token=totp.now())
return self.send_magic_link(
ip_address=ip_address,
login=login,
referrer=referrer,
send_link=False,
)
9 changes: 5 additions & 4 deletions Models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pydantic import BaseModel
from typing import Optional


# Auth user models
Expand All @@ -9,10 +10,10 @@ class Login(BaseModel):

class Register(BaseModel):
email: str
first_name: str
last_name: str
company_name: str
job_title: str
first_name: Optional[str] = ""
last_name: Optional[str] = ""
company_name: Optional[str] = ""
job_title: Optional[str] = ""


class UserInfo(BaseModel):
Expand Down
Loading

0 comments on commit 82298c5

Please sign in to comment.