-
-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
259 additions
and
280 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,196 +1,154 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
from sqlite3 import IntegrityError | ||
from typing import List, Union | ||
|
||
from backend.base.custom_exceptions import (CredentialAlreadyAdded, | ||
CredentialInvalid, | ||
CredentialNotFound, | ||
CredentialSourceNotFound) | ||
from backend.base.helpers import first_of_column | ||
from typing import Any, Dict, List, Tuple | ||
|
||
from typing_extensions import assert_never | ||
|
||
from backend.base.custom_exceptions import (CredentialInvalid, | ||
CredentialNotFound) | ||
from backend.base.definitions import CredentialData, CredentialSource | ||
from backend.base.logging import LOGGER | ||
from backend.internals.db import get_db | ||
from backend.lib.mega import Mega, RequestError | ||
|
||
|
||
class Credentials: | ||
"""For interracting with the service credentials | ||
auth_tokens: Dict[CredentialSource, Dict[str, Tuple[Any, int]]] = {} | ||
""" | ||
Store auth tokens as to avoid logging in while already having a cleared | ||
token. Maps from credential source to user identifier (something like user | ||
ID, email or username) to a tuple of the token and it's expiration time. | ||
""" | ||
cache = {} | ||
__load_first = True | ||
|
||
def __init__(self, sids: dict) -> None: | ||
"""Set up the credential class | ||
def get_all(self) -> List[CredentialData]: | ||
"""Get all credentials. | ||
Args: | ||
sids (dict): The sids variable at backend.lib.mega.sids | ||
Returns: | ||
List[CredentialData]: The list of credentials. | ||
""" | ||
self.sids = sids | ||
return | ||
|
||
def get_all(self, use_cache: bool = True) -> List[dict]: | ||
"""Get all credentials | ||
return [ | ||
CredentialData(**{ | ||
**dict(c), | ||
'source': CredentialSource[c["source"].upper()] | ||
}) | ||
for c in get_db().execute(""" | ||
SELECT | ||
id, source, | ||
username, email, | ||
password, api_key | ||
FROM credentials; | ||
""").fetchall() | ||
] | ||
|
||
def get_one(self, id: int) -> CredentialData: | ||
"""Get a credential based on it's id. | ||
Args: | ||
use_cache (bool, optional): Wether or not to pull data from cache | ||
instead of going to the database. | ||
Defaults to True. | ||
id (int): The ID of the credential to get. | ||
Raises: | ||
CredentialNotFound: The ID doesn't map to any credential. | ||
Returns: | ||
List[dict]: The list of credentials | ||
CredentialData: The credential info | ||
""" | ||
if not use_cache or not self.cache or self.__load_first: | ||
cred = dict( | ||
(c['id'], dict(c)) | ||
for c in get_db().execute(""" | ||
SELECT | ||
c.id, cs.source, | ||
c.email, c.password | ||
FROM credentials c | ||
INNER JOIN credentials_sources cs | ||
ON c.source = cs.id; | ||
""" | ||
) | ||
) | ||
self.cache = cred | ||
self.__load_first = False | ||
result = get_db().execute(""" | ||
SELECT | ||
id, source, | ||
username, email, | ||
password, api_key | ||
FROM credentials | ||
WHERE id = ? | ||
LIMIT 1; | ||
""", | ||
(id,) | ||
).fetchone() | ||
|
||
if result is None: | ||
raise CredentialNotFound | ||
|
||
return list(self.cache.values()) | ||
return CredentialData(**{ | ||
**dict(result), | ||
'source': CredentialSource(result["source"]) | ||
}) | ||
|
||
def get_one(self, id: int, use_cache: bool = True) -> dict: | ||
"""Get a credential based on it's id. | ||
def get_from_source( | ||
self, | ||
source: CredentialSource | ||
) -> List[CredentialData]: | ||
"""Get credentials for the given source. | ||
Args: | ||
id (int): The id of the credential to get. | ||
use_cache (bool, optional): Wether or not to pull data from cache | ||
instead of going to the database. | ||
Defaults to True. | ||
Raises: | ||
CredentialNotFound: The id doesn't map to any credential. | ||
Could also be because of cache being behind database. | ||
source (CredentialSource): The source of the credentials. | ||
Returns: | ||
dict: The credential info | ||
List[CredentialData]: The credentials for the given source. | ||
""" | ||
if not use_cache or self.__load_first: | ||
self.get_all(use_cache=False) | ||
cred = self.cache.get(id) | ||
if not cred: | ||
raise CredentialNotFound | ||
return cred | ||
return [ | ||
c | ||
for c in self.get_all() | ||
if c.source == source | ||
] | ||
|
||
def get_one_from_source(self, source: str, use_cache: bool = True) -> dict: | ||
"""Get a credential based on it's source string. | ||
def add(self, credential_data: CredentialData) -> CredentialData: | ||
"""Add a credential. | ||
Args: | ||
source (str): The source of which to get the credential. | ||
use_cache (bool, optional): Wether or not to pull data from cache | ||
instead of going to the database. | ||
Defaults to True. | ||
credential_data (CredentialData): The data of the credential to | ||
store. | ||
Returns: | ||
dict: The credential info or a 'ghost' version of the response | ||
dict: The credential info | ||
""" | ||
if not use_cache or self.__load_first: | ||
self.get_all(use_cache=False) | ||
for cred in self.cache.values(): | ||
if cred['source'] == source: | ||
return cred | ||
|
||
# If no cred is set for the source, | ||
# return a 'ghost' cred because other code can then | ||
# simply grab value of 'email' and 'password' and it'll be None | ||
return { | ||
'id': -1, | ||
'source': source, | ||
'email': None, | ||
'password': None | ||
} | ||
|
||
def add(self, source: str, email: str, password: str) -> dict: | ||
"""Add a credential | ||
LOGGER.info(f'Adding credential for {credential_data.source.value}') | ||
|
||
# Check if it works | ||
if credential_data.source == CredentialSource.MEGA: | ||
from ..lib.mega import Mega, RequestError | ||
try: | ||
Mega( | ||
'', | ||
_only_login=True | ||
)._login_user( | ||
credential_data.email or '', | ||
credential_data.password or '' | ||
) | ||
|
||
Args: | ||
source (str): The service for which the credential is. | ||
Must be a value of `settings.credential_sources`. | ||
except RequestError: | ||
raise CredentialInvalid | ||
|
||
email (str): The email of the credential. | ||
credential_data.api_key = None | ||
credential_data.username = None | ||
|
||
password (str): The password of the credential | ||
else: | ||
assert_never(credential_data.source) | ||
|
||
Raises: | ||
CredentialSourceNotFound: The source string doesn't map to any service. | ||
CredentialAlreadyAdded: The service already has a credential for it. | ||
id = get_db().execute(""" | ||
INSERT INTO credentials(source, username, email, password, api_key) | ||
VALUES (:source, :username, :email, :password, :api_key); | ||
""", | ||
credential_data.as_dict() | ||
).lastrowid | ||
|
||
Returns: | ||
dict: The credential info | ||
""" | ||
cursor = get_db() | ||
source_id: Union[int, None] = cursor.execute( | ||
"SELECT id FROM credentials_sources WHERE source = ? LIMIT 1;", | ||
(source,) | ||
).exists() | ||
if not source_id: | ||
raise CredentialSourceNotFound(source) | ||
|
||
LOGGER.info(f'Adding credential for {source}') | ||
try: | ||
if source == 'mega': | ||
Mega('', email, password, only_check_login=True) | ||
|
||
id = cursor.execute(""" | ||
INSERT INTO credentials(source, email, password) | ||
VALUES (?,?,?); | ||
""", | ||
(source_id, email, password) | ||
).lastrowid | ||
|
||
except RequestError: | ||
raise CredentialInvalid | ||
|
||
except IntegrityError: | ||
raise CredentialAlreadyAdded | ||
|
||
return self.get_one(id, use_cache=False) | ||
return self.get_one(id) | ||
|
||
def delete(self, cred_id: int) -> None: | ||
"""Delete a credential | ||
"""Delete a credential. | ||
Args: | ||
cred_id (int): The id of the credential to delete | ||
cred_id (int): The ID of the credential to delete. | ||
Raises: | ||
CredentialNotFound: The id doesn't map to any credential | ||
CredentialNotFound: The ID doesn't map to any credential. | ||
""" | ||
LOGGER.info(f'Deleting credential: {cred_id}') | ||
|
||
if not get_db().execute( | ||
"DELETE FROM credentials WHERE id = ?", (cred_id,) | ||
).rowcount: | ||
raise CredentialNotFound | ||
|
||
self.get_all(use_cache=False) | ||
self.sids.clear() | ||
return | ||
source = self.get_one(cred_id).source | ||
|
||
def get_open(self) -> List[str]: | ||
"""Get a list of all services that | ||
don't have a credential registered for it | ||
Returns: | ||
List[str]: The list of service strings | ||
""" | ||
result = first_of_column( | ||
get_db().execute(""" | ||
SELECT cs.source | ||
FROM credentials_sources cs | ||
LEFT JOIN credentials c | ||
ON cs.id = c.source | ||
WHERE c.id IS NULL; | ||
""") | ||
get_db().execute( | ||
"DELETE FROM credentials WHERE id = ?", (cred_id,) | ||
) | ||
|
||
return result | ||
if source in self.auth_tokens: | ||
del self.auth_tokens[source] | ||
|
||
return |
Oops, something went wrong.