Skip to content

Commit

Permalink
Refactored credentials.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Casvt committed Dec 10, 2024
1 parent 9fe8450 commit dc6dae7
Show file tree
Hide file tree
Showing 9 changed files with 259 additions and 280 deletions.
28 changes: 0 additions & 28 deletions backend/base/custom_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,34 +290,6 @@ class CredentialNotFound(CustomException):
api_response = {'error': 'CredentialNotFound', 'result': {}, 'code': 404}


class CredentialSourceNotFound(Exception):
"""The credential source with the given string was not found"""

def __init__(self, string: str) -> None:
self.string = string
LOGGER.warning(
f'Credential source with given string not found: {string}'
)
return

@property
def api_response(self):
return {
'error': 'CredentialSourceNotFound',
'result': {'string': self.string},
'code': 404
}


class CredentialAlreadyAdded(CustomException):
"""A credential for the given source is already added"""
api_response = {
'error': 'CredentialAlreadyAdded',
'result': {},
'code': 400
}


class CredentialInvalid(Exception):
"""A credential is incorrect (can't login with it)"""
api_response = {'error': 'CredentialInvalid', 'result': {}, 'code': 400}
Expand Down
19 changes: 19 additions & 0 deletions backend/base/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,10 @@ class MonitorScheme(BaseEnum):
NONE = "none"


class CredentialSource(BaseEnum):
MEGA = "mega"


# region TypedDicts
class FilenameData(TypedDict):
series: str
Expand Down Expand Up @@ -485,6 +489,21 @@ class VolumeData:
last_cv_fetch: int


@dataclass
class CredentialData:
id: int
source: CredentialSource
username: Union[str, None]
email: Union[str, None]
password: Union[str, None]
api_key: Union[str, None]

def as_dict(self) -> Dict[str, Any]:
result = asdict(self)
result['source'] = self.source.value
return result


# region Abstract Classes
class DBMigrator(ABC):
start_version: int
Expand Down
258 changes: 108 additions & 150 deletions backend/implementations/credentials.py
Original file line number Diff line number Diff line change
@@ -1,196 +1,154 @@
# -*- coding: utf-8 -*-

from sqlite3 import IntegrityError
from typing import List, Union

from backend.base.custom_exceptions import (CredentialAlreadyAdded,
CredentialInvalid,
CredentialNotFound,
CredentialSourceNotFound)
from backend.base.helpers import first_of_column
from typing import Any, Dict, List, Tuple

from typing_extensions import assert_never

from backend.base.custom_exceptions import (CredentialInvalid,
CredentialNotFound)
from backend.base.definitions import CredentialData, CredentialSource
from backend.base.logging import LOGGER
from backend.internals.db import get_db
from backend.lib.mega import Mega, RequestError


class Credentials:
"""For interracting with the service credentials
auth_tokens: Dict[CredentialSource, Dict[str, Tuple[Any, int]]] = {}
"""
Store auth tokens as to avoid logging in while already having a cleared
token. Maps from credential source to user identifier (something like user
ID, email or username) to a tuple of the token and it's expiration time.
"""
cache = {}
__load_first = True

def __init__(self, sids: dict) -> None:
"""Set up the credential class
def get_all(self) -> List[CredentialData]:
"""Get all credentials.
Args:
sids (dict): The sids variable at backend.lib.mega.sids
Returns:
List[CredentialData]: The list of credentials.
"""
self.sids = sids
return

def get_all(self, use_cache: bool = True) -> List[dict]:
"""Get all credentials
return [
CredentialData(**{
**dict(c),
'source': CredentialSource[c["source"].upper()]
})
for c in get_db().execute("""
SELECT
id, source,
username, email,
password, api_key
FROM credentials;
""").fetchall()
]

def get_one(self, id: int) -> CredentialData:
"""Get a credential based on it's id.
Args:
use_cache (bool, optional): Wether or not to pull data from cache
instead of going to the database.
Defaults to True.
id (int): The ID of the credential to get.
Raises:
CredentialNotFound: The ID doesn't map to any credential.
Returns:
List[dict]: The list of credentials
CredentialData: The credential info
"""
if not use_cache or not self.cache or self.__load_first:
cred = dict(
(c['id'], dict(c))
for c in get_db().execute("""
SELECT
c.id, cs.source,
c.email, c.password
FROM credentials c
INNER JOIN credentials_sources cs
ON c.source = cs.id;
"""
)
)
self.cache = cred
self.__load_first = False
result = get_db().execute("""
SELECT
id, source,
username, email,
password, api_key
FROM credentials
WHERE id = ?
LIMIT 1;
""",
(id,)
).fetchone()

if result is None:
raise CredentialNotFound

return list(self.cache.values())
return CredentialData(**{
**dict(result),
'source': CredentialSource(result["source"])
})

def get_one(self, id: int, use_cache: bool = True) -> dict:
"""Get a credential based on it's id.
def get_from_source(
self,
source: CredentialSource
) -> List[CredentialData]:
"""Get credentials for the given source.
Args:
id (int): The id of the credential to get.
use_cache (bool, optional): Wether or not to pull data from cache
instead of going to the database.
Defaults to True.
Raises:
CredentialNotFound: The id doesn't map to any credential.
Could also be because of cache being behind database.
source (CredentialSource): The source of the credentials.
Returns:
dict: The credential info
List[CredentialData]: The credentials for the given source.
"""
if not use_cache or self.__load_first:
self.get_all(use_cache=False)
cred = self.cache.get(id)
if not cred:
raise CredentialNotFound
return cred
return [
c
for c in self.get_all()
if c.source == source
]

def get_one_from_source(self, source: str, use_cache: bool = True) -> dict:
"""Get a credential based on it's source string.
def add(self, credential_data: CredentialData) -> CredentialData:
"""Add a credential.
Args:
source (str): The source of which to get the credential.
use_cache (bool, optional): Wether or not to pull data from cache
instead of going to the database.
Defaults to True.
credential_data (CredentialData): The data of the credential to
store.
Returns:
dict: The credential info or a 'ghost' version of the response
dict: The credential info
"""
if not use_cache or self.__load_first:
self.get_all(use_cache=False)
for cred in self.cache.values():
if cred['source'] == source:
return cred

# If no cred is set for the source,
# return a 'ghost' cred because other code can then
# simply grab value of 'email' and 'password' and it'll be None
return {
'id': -1,
'source': source,
'email': None,
'password': None
}

def add(self, source: str, email: str, password: str) -> dict:
"""Add a credential
LOGGER.info(f'Adding credential for {credential_data.source.value}')

# Check if it works
if credential_data.source == CredentialSource.MEGA:
from ..lib.mega import Mega, RequestError
try:
Mega(
'',
_only_login=True
)._login_user(
credential_data.email or '',
credential_data.password or ''
)

Args:
source (str): The service for which the credential is.
Must be a value of `settings.credential_sources`.
except RequestError:
raise CredentialInvalid

email (str): The email of the credential.
credential_data.api_key = None
credential_data.username = None

password (str): The password of the credential
else:
assert_never(credential_data.source)

Raises:
CredentialSourceNotFound: The source string doesn't map to any service.
CredentialAlreadyAdded: The service already has a credential for it.
id = get_db().execute("""
INSERT INTO credentials(source, username, email, password, api_key)
VALUES (:source, :username, :email, :password, :api_key);
""",
credential_data.as_dict()
).lastrowid

Returns:
dict: The credential info
"""
cursor = get_db()
source_id: Union[int, None] = cursor.execute(
"SELECT id FROM credentials_sources WHERE source = ? LIMIT 1;",
(source,)
).exists()
if not source_id:
raise CredentialSourceNotFound(source)

LOGGER.info(f'Adding credential for {source}')
try:
if source == 'mega':
Mega('', email, password, only_check_login=True)

id = cursor.execute("""
INSERT INTO credentials(source, email, password)
VALUES (?,?,?);
""",
(source_id, email, password)
).lastrowid

except RequestError:
raise CredentialInvalid

except IntegrityError:
raise CredentialAlreadyAdded

return self.get_one(id, use_cache=False)
return self.get_one(id)

def delete(self, cred_id: int) -> None:
"""Delete a credential
"""Delete a credential.
Args:
cred_id (int): The id of the credential to delete
cred_id (int): The ID of the credential to delete.
Raises:
CredentialNotFound: The id doesn't map to any credential
CredentialNotFound: The ID doesn't map to any credential.
"""
LOGGER.info(f'Deleting credential: {cred_id}')

if not get_db().execute(
"DELETE FROM credentials WHERE id = ?", (cred_id,)
).rowcount:
raise CredentialNotFound

self.get_all(use_cache=False)
self.sids.clear()
return
source = self.get_one(cred_id).source

def get_open(self) -> List[str]:
"""Get a list of all services that
don't have a credential registered for it
Returns:
List[str]: The list of service strings
"""
result = first_of_column(
get_db().execute("""
SELECT cs.source
FROM credentials_sources cs
LEFT JOIN credentials c
ON cs.id = c.source
WHERE c.id IS NULL;
""")
get_db().execute(
"DELETE FROM credentials WHERE id = ?", (cred_id,)
)

return result
if source in self.auth_tokens:
del self.auth_tokens[source]

return
Loading

0 comments on commit dc6dae7

Please sign in to comment.