diff --git a/backend/src/database/crud/user.py b/backend/src/database/crud/user.py index 9ff6e41..05c5ab6 100644 --- a/backend/src/database/crud/user.py +++ b/backend/src/database/crud/user.py @@ -1,9 +1,9 @@ -from src.database.models import DbUser +from src.database.models import DbAccessToken, DbUser from src.dataclasses.user import User def get_user_by_id(id: str): - return DbUser.get( + return DbUser.get_or_none( DbUser.id == id, ) @@ -26,3 +26,15 @@ def get_or_create_user(user: User): "uri": user.uri, }, ) + + +def upsert_user_tokens(user_id: str, access_token: str, refresh_token: str): + DbAccessToken.insert( + user=user_id, access_token=access_token, refresh_token=refresh_token + ).on_conflict( + conflict_target=[DbAccessToken.user], + update={ + DbAccessToken.access_token: access_token, + DbAccessToken.refresh_token: refresh_token, + }, + ).execute() diff --git a/backend/src/database/migrations/240909-add-tracks-table.py b/backend/src/database/migrations/m_240909_add_tracks_table.py similarity index 66% rename from backend/src/database/migrations/240909-add-tracks-table.py rename to backend/src/database/migrations/m_240909_add_tracks_table.py index 1b642bb..9a3d100 100644 --- a/backend/src/database/migrations/240909-add-tracks-table.py +++ b/backend/src/database/migrations/m_240909_add_tracks_table.py @@ -1,13 +1,13 @@ from src.database.models import ( DbTrack, TrackArtistRelationship, - database, + db_wrapper, ) def up(): - with database: - database.create_tables( + with db_wrapper.database: + db_wrapper.database.create_tables( [ DbTrack, TrackArtistRelationship, @@ -16,8 +16,8 @@ def up(): def down(): - with database: - database.drop_tables( + with db_wrapper.database: + db_wrapper.database.drop_tables( [ DbTrack, TrackArtistRelationship, diff --git a/backend/src/database/migrations/m_241030_add_access_tokens_table.py b/backend/src/database/migrations/m_241030_add_access_tokens_table.py new file mode 100644 index 0000000..0a6b2b3 --- /dev/null +++ b/backend/src/database/migrations/m_241030_add_access_tokens_table.py @@ -0,0 +1,19 @@ +from src.database.models import DbAccessToken, db_wrapper + + +def up(): + with db_wrapper.database: + db_wrapper.database.create_tables( + [ + DbAccessToken, + ] + ) + + +def down(): + with db_wrapper.database: + db_wrapper.database.drop_tables( + [ + DbAccessToken, + ] + ) diff --git a/backend/src/database/migrations/init.py b/backend/src/database/migrations/m_init.py similarity index 83% rename from backend/src/database/migrations/init.py rename to backend/src/database/migrations/m_init.py index 9d8437d..3c28e11 100644 --- a/backend/src/database/migrations/init.py +++ b/backend/src/database/migrations/m_init.py @@ -7,13 +7,13 @@ DbPlaylist, PlaylistAlbumRelationship, DbUser, - database, + db_wrapper, ) def up(): - with database: - database.create_tables( + with db_wrapper.database: + db_wrapper.database.create_tables( [ DbUser, DbPlaylist, @@ -28,8 +28,8 @@ def up(): def down(): - with database: - database.drop_tables( + with db_wrapper.database: + db_wrapper.database.drop_tables( [ DbUser, DbPlaylist, diff --git a/backend/src/database/models.py b/backend/src/database/models.py index 8f94a8a..4e7e2d4 100644 --- a/backend/src/database/models.py +++ b/backend/src/database/models.py @@ -1,12 +1,9 @@ from peewee import ( - PostgresqlDatabase, - Model, CharField, IntegerField, DateField, ForeignKeyField, ) -from src.flask_config import Config from playhouse.flask_utils import FlaskDB db_wrapper = FlaskDB() @@ -116,3 +113,14 @@ class TrackArtistRelationship(db_wrapper.Model): class Meta: indexes = ((("track", "artist"), True),) + + +class DbAccessToken(db_wrapper.Model): + user = ForeignKeyField( + DbUser, backref="owner", to_field="id", on_delete="CASCADE", unique=True + ) + access_token = CharField(max_length=400) + refresh_token = CharField(max_length=200) + + class Meta: + db_table = "access_token" diff --git a/backend/src/spotify.py b/backend/src/spotify.py index 4c418d4..4ec1928 100644 --- a/backend/src/spotify.py +++ b/backend/src/spotify.py @@ -6,6 +6,7 @@ from typing import List, Optional from flask import Response, make_response, redirect from src.database.crud.album import get_album_genres +from src.database.crud.user import upsert_user_tokens from src.dataclasses.album import Album from src.dataclasses.playback_info import PlaybackInfo, PlaylistProgression from src.dataclasses.playback_request import ( @@ -102,6 +103,11 @@ def refresh_access_token(self, refresh_token): token_response = TokenResponse.model_validate(api_response) access_token = token_response.access_token user_info = self.get_current_user(access_token) + upsert_user_tokens( + user_info.id, + access_token=token_response.access_token, + refresh_token=token_response.refresh_token, + ) resp = add_cookies_to_response( make_response(), {"spotify_access_token": access_token, "user_id": user_info.id}, @@ -123,12 +129,16 @@ def request_access_token(self, code): ) api_response = self.response_handler(response) token_response = TokenResponse.model_validate(api_response) - access_token = token_response.access_token - user_info = self.get_current_user(access_token) + user_info = self.get_current_user(token_response.access_token) + upsert_user_tokens( + user_info.id, + access_token=token_response.access_token, + refresh_token=token_response.refresh_token, + ) resp = add_cookies_to_response( make_response(redirect(f"{Config().FRONTEND_URL}/")), { - "spotify_access_token": access_token, + "spotify_access_token": token_response.access_token, "spotify_refresh_token": token_response.refresh_token, "user_id": user_info.id, },