Skip to content

Commit

Permalink
fix: Minor
Browse files Browse the repository at this point in the history
  • Loading branch information
TomBursch committed Jan 3, 2025
1 parent d10efa8 commit 1d90e65
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 37 deletions.
1 change: 1 addition & 0 deletions backend/app/controller/auth/auth_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def check_if_token_revoked(jwt_header, jwt_payload: dict) -> bool:
token.save()
return token is None


# Register a callback function that takes whatever object is passed in as the
# identity when creating JWTs and converts it to a JSON serializable format.
@jwt.user_identity_loader
Expand Down
23 changes: 11 additions & 12 deletions backend/app/models/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,14 @@ def has_created_refresh_token(self) -> bool:
> 0
)

def delete_created_access_tokens(self, exclude_token_id=None):
def delete_created_access_tokens(self, commit=True):
if self.type != "refresh":
return
query = db.session.query(Token).filter(
Token.query.filter(
Token.refresh_token_id == self.id, Token.type == "access"
)
if exclude_token_id is not None:
query = query.filter(Token.id != exclude_token_id)
query.delete()
db.session.commit()
).delete()
if commit:
db.session.commit()

@classmethod
def create_access_token(
Expand Down Expand Up @@ -141,10 +139,10 @@ def create_refresh_token(

# Check if this refresh token has already been used to create another refresh token
if oldRefreshToken and oldRefreshToken.has_created_refresh_token():
for newer_token in db.session.query(Token).filter(
for newer_token in Token.query.filter(
Token.refresh_token_id == oldRefreshToken.id,
Token.type == "refresh"
):
).all():
newer_access_used = db.session.query(Token).filter(
Token.refresh_token_id == newer_token.id,
Token.type == "access",
Expand All @@ -161,11 +159,11 @@ def create_refresh_token(
)
else:
# Only invalidate the unused parallel refresh token chain
for token in db.session.query(Token).filter(
Token.query.filter(
Token.refresh_token_id == newer_token.id
).all():
db.session.delete(token)
).delete()
newer_token.type = "invalidated_refresh"
db.session.add(newer_token)

refreshToken = create_refresh_token(identity=user)
model = cls()
Expand All @@ -174,6 +172,7 @@ def create_refresh_token(
model.name = device or oldRefreshToken.name
model.user = user
if oldRefreshToken:
oldRefreshToken.delete_created_access_tokens(commit=False)
model.refresh_token = oldRefreshToken
model.save()
return refreshToken, model
Expand Down
53 changes: 28 additions & 25 deletions backend/tests/api/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ def test_shaky_network_token_refresh(user_client, username, password):
assert response.status_code == 200
# Intentionally ignore new tokens

# Use old access token, should still work since we didn't use the new one
# Use old access token, should not work since refresh invalidates them
response = user_client.get("/api/user", headers={"Authorization": f"Bearer {access_token}"})
assert response.status_code == 200
assert response.status_code == 401


# Original refresh token should still work since we didn't use the new one
Expand Down Expand Up @@ -83,9 +83,9 @@ def test_token_hijack_attempt(user_client, username, password):
leaked_refresh_token = data["refresh_token"]


# User continues normal use with original tokens
# User cannot continue normal use with original access token
response = user_client.get("/api/user", headers={"Authorization": f"Bearer {access_token}"})
assert response.status_code == 200
assert response.status_code == 401

# Create another refresh token (normal use)
response = user_client.get("/api/auth/refresh", headers={"Authorization": f"Bearer {refresh_token}"})
Expand Down Expand Up @@ -225,44 +225,47 @@ def test_complex_token_chain(user_client, username, password):
at5 = data["access_token"]
rt5 = data["refresh_token"]

# Use AT2 to make it the active chain
# AT2 should be rejected (refresh invalidates AT but not RT)
response = user_client.get("/api/user", headers={"Authorization": f"Bearer {at2}"})
assert response.status_code == 401

# RT5/AT5 chain should work (last created refresh token)
response = user_client.get("/api/user", headers={"Authorization": f"Bearer {at5}"})
assert response.status_code == 200
response = user_client.get("/api/auth/refresh", headers={"Authorization": f"Bearer {rt5}"})
assert response.status_code == 200
data = response.get_json()
at6 = data["access_token"]
rt6 = data["refresh_token"]

# Verify unused tokens from parallel chains are rejected
# Verify unused tokens from parallel chains are rejected triggering breach detection
response = user_client.get("/api/auth/refresh", headers={"Authorization": f"Bearer {rt4}"})
assert response.status_code == 401

# RT3/AT3 chain should be rejected (unused parallel chain)
# Check that no token works anymore
response = user_client.get("/api/user", headers={"Authorization": f"Bearer {at1}"})
assert response.status_code == 401
response = user_client.get("/api/auth/refresh", headers={"Authorization": f"Bearer {rt1}"})
assert response.status_code == 401
response = user_client.get("/api/user", headers={"Authorization": f"Bearer {at2}"})
assert response.status_code == 401
response = user_client.get("/api/auth/refresh", headers={"Authorization": f"Bearer {rt2}"})
assert response.status_code == 401
response = user_client.get("/api/user", headers={"Authorization": f"Bearer {at3}"})
assert response.status_code == 401
response = user_client.get("/api/auth/refresh", headers={"Authorization": f"Bearer {rt3}"})
assert response.status_code == 401

# RT4/AT4 chain should be rejected (unused parallel chain)
response = user_client.get("/api/user", headers={"Authorization": f"Bearer {at4}"})
assert response.status_code == 401
response = user_client.get("/api/auth/refresh", headers={"Authorization": f"Bearer {rt4}"})
assert response.status_code == 401

# RT5/AT5 chain should be rejected (unused parallel chain)
response = user_client.get("/api/user", headers={"Authorization": f"Bearer {at5}"})
assert response.status_code == 401
response = user_client.get("/api/auth/refresh", headers={"Authorization": f"Bearer {rt5}"})
assert response.status_code == 401

# Original RT1 should be rejected
response = user_client.get("/api/auth/refresh", headers={"Authorization": f"Bearer {rt1}"})
response = user_client.get("/api/user", headers={"Authorization": f"Bearer {at6}"})
assert response.status_code == 401

# Try to use one of the parallel chain tokens (RT3), which should trigger breach detection
response = user_client.get("/api/auth/refresh", headers={"Authorization": f"Bearer {rt3}"})
assert response.status_code == 401

# AT2 should now be rejected as the use of RT3 indicates a potential breach
response = user_client.get("/api/user", headers={"Authorization": f"Bearer {at2}"})
assert response.status_code == 401

# RT2 should be rejected (part of the compromised chain)
response = user_client.get("/api/auth/refresh", headers={"Authorization": f"Bearer {rt2}"})
response = user_client.get("/api/auth/refresh", headers={"Authorization": f"Bearer {rt6}"})
assert response.status_code == 401

def test_complex_token_chain2(user_client, username, password):
Expand Down

0 comments on commit 1d90e65

Please sign in to comment.