Skip to content

Commit

Permalink
Updated NPM modules, added token validation to Engine, and fixed mino…
Browse files Browse the repository at this point in the history
…r logic issue in Planner utility
  • Loading branch information
DCMattyG committed Dec 22, 2023
1 parent d14e279 commit 1eca37e
Show file tree
Hide file tree
Showing 10 changed files with 1,214 additions and 952 deletions.
106 changes: 93 additions & 13 deletions engine/app/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,109 @@
from fastapi import Request, HTTPException

from requests import Session, adapters
from urllib3.util.retry import Retry
from cryptography.hazmat.primitives import serialization

import jwt
import time
import copy
import json

from app.routers.common.helper import (
cosmos_query
)

async def check_token_expired(request: Request):
now = int(time.time()) + 10
auth = request.headers.get('authorization')
from app.globals import globals

_session = None

async def fetch_jwks_keys():
global _session

if _session is None:
_session = Session()

retries = Retry(
total=5,
backoff_factor=0.1,
status_forcelist=[ 500, 502, 503, 504 ]
)

_session.mount('https://', adapters.HTTPAdapter(max_retries=retries))
_session.mount('http://', adapters.HTTPAdapter(max_retries=retries))

key_url = "https://" + globals.AUTHORITY_HOST + "/" + globals.TENANT_ID + "/discovery/v2.0/keys"

jwks = _session.get(key_url).json()

return jwks

async def get_token_auth_header(request: Request):
auth = request.headers.get("Authorization", None)

if not auth:
raise HTTPException(status_code=401, detail="Authorization header missing.")
raise HTTPException(status_code=401, detail="Authorization header is missing.")

parts = auth.split()

if parts[0].lower() != "bearer":
raise HTTPException(status_code=401, detail="Authorization header must start with 'Bearer'.")
elif len(parts) == 1:
raise HTTPException(status_code=401, detail="Token not found.")
elif len(parts) > 2:
raise HTTPException(status_code=401, detail="Authorization header must be of type Bearer token.")

token = parts[1]

user_assertion=auth.split(' ')[1]
return token

async def validate_token(request: Request):
try:
decoded = jwt.decode(user_assertion, options={"verify_signature": False})
except:
raise HTTPException(status_code=401, detail="Authorization token missing or invalid in header.")
token = await get_token_auth_header(request)
jwks = await fetch_jwks_keys()
unverified_header = jwt.get_unverified_header(token)

if(now >= int(decoded['exp'])):
raise HTTPException(status_code=401, detail="Token has expired.")
rsa_key = {}

request.state.tenant_id = decoded['tid']
for key in jwks["keys"]:
if key["kid"] == unverified_header["kid"]:
rsa_key = {
"kty": key["kty"],
"kid": key["kid"],
"use": key["use"],
"n": key["n"],
"e": key["e"]
}
except Exception:
raise HTTPException(status_code=401, detail="Unable to parse authorization token.")

await check_admin(request, decoded['oid'], decoded['tid'])
if rsa_key:
rsa_pem_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(rsa_key))
rsa_pem_key_bytes = rsa_pem_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)

try:
payload = jwt.decode(
token,
key=rsa_pem_key_bytes,
verify=True,
algorithms=["RS256"],
audience=globals.CLIENT_ID,
issuer="https://" + globals.AUTHORITY_HOST + "/" + globals.TENANT_ID + "/v2.0"
)
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=401, detail="Token has expired.")
except jwt.MissingRequiredClaimError:
raise HTTPException(status_code=401, detail="Incorrect token claims, please check the audience and issuer.")
except Exception:
raise HTTPException(status_code=401, detail="Unable to parse authorization token.")
else:
raise HTTPException(status_code=401, detail="Unable to find appropriate signing key.")

request.state.tenant_id = payload['tid']

return payload

async def check_admin(request: Request, user_oid: str, user_tid: str):
admin_query = await cosmos_query("SELECT * FROM c WHERE c.type = 'admin'", user_tid)
Expand All @@ -44,6 +120,10 @@ async def check_admin(request: Request, user_oid: str, user_tid: str):

request.state.admin = True if is_admin else False

async def api_auth_checks(request: Request):
token_payload = await validate_token(request)
await check_admin(request, token_payload['oid'], token_payload['tid'])

async def get_admin(request: Request):
return request.state.admin

Expand Down
4 changes: 2 additions & 2 deletions engine/app/routers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import uuid

from app.dependencies import (
check_token_expired,
api_auth_checks,
get_admin,
get_tenant_id
)
Expand All @@ -36,7 +36,7 @@
router = APIRouter(
prefix="/admin",
tags=["admin"],
dependencies=[Depends(check_token_expired)]
dependencies=[Depends(api_auth_checks)]
)

async def new_admin_db(admin_list, exclusion_list, tenant_id):
Expand Down
4 changes: 2 additions & 2 deletions engine/app/routers/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from netaddr import IPSet, IPNetwork

from app.dependencies import (
check_token_expired,
api_auth_checks,
get_admin,
get_tenant_id
)
Expand All @@ -45,7 +45,7 @@
router = APIRouter(
prefix="/azure",
tags=["azure"],
dependencies=[Depends(check_token_expired)]
dependencies=[Depends(api_auth_checks)]
)

def str_to_list(input):
Expand Down
4 changes: 2 additions & 2 deletions engine/app/routers/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from netaddr import IPNetwork

from app.dependencies import (
check_token_expired,
api_auth_checks,
get_admin,
get_tenant_id
)
Expand All @@ -35,7 +35,7 @@
router = APIRouter(
prefix="/internal",
tags=["internal"],
dependencies=[Depends(check_token_expired)]
dependencies=[Depends(api_auth_checks)]
)

async def multi_helper(func, list, *args):
Expand Down
4 changes: 2 additions & 2 deletions engine/app/routers/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from netaddr import IPSet, IPNetwork

from app.dependencies import (
check_token_expired,
api_auth_checks,
get_admin,
get_tenant_id
)
Expand Down Expand Up @@ -54,7 +54,7 @@
router = APIRouter(
prefix="/spaces",
tags=["spaces"],
dependencies=[Depends(check_token_expired)]
dependencies=[Depends(api_auth_checks)]
)

async def valid_space_name_update(name, space_name, tenant_id):
Expand Down
4 changes: 2 additions & 2 deletions engine/app/routers/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from netaddr import IPSet, IPNetwork

from app.dependencies import (
check_token_expired,
api_auth_checks,
get_tenant_id
)

Expand All @@ -33,7 +33,7 @@
router = APIRouter(
prefix="/tools",
tags=["tools"],
dependencies=[Depends(check_token_expired)]
dependencies=[Depends(api_auth_checks)]
)

@router.post(
Expand Down
4 changes: 2 additions & 2 deletions engine/app/routers/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Union, List

from app.dependencies import (
check_token_expired,
api_auth_checks,
get_admin,
get_tenant_id
)
Expand All @@ -39,7 +39,7 @@
router = APIRouter(
prefix="/users",
tags=["users"],
dependencies=[Depends(check_token_expired)]
dependencies=[Depends(api_auth_checks)]
)

async def new_user(user_id, tenant_id):
Expand Down
Loading

0 comments on commit 1eca37e

Please sign in to comment.