-
Notifications
You must be signed in to change notification settings - Fork 61
/
api.py
708 lines (574 loc) · 23.7 KB
/
api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
import base64
import json
import sys
import gnupg
import requests
"""
When using alternative routing, we want to verify as little data as possible. Thus we'll
end up relying mostly on tls key pinning. If we don't disable warnings, a warning will be
constantly popping on the terminal informing the user about it.
https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings
"""
import urllib3
urllib3.disable_warnings()
from concurrent.futures import ThreadPoolExecutor
from .cert_pinning import TLSPinningAdapter
from .constants import (ALT_HASH_DICT, DEFAULT_TIMEOUT, DNS_HOSTS,
ENCODED_URLS, SRP_MODULUS_KEY,
SRP_MODULUS_KEY_FINGERPRINT)
from .exceptions import (ConnectionTimeOutError, NetworkError,
NewConnectionError, ProtonAPIError, TLSPinningError,
UnknownConnectionError, MissingDepedencyError)
from .logger import CustomLogger
from .metadata import MetadataBackend
from .srp import User as PmsrpUser
class Session:
"""A Proton Session.
Provides public key pinning, fetch alternative routes and connect to
Proton API in a authenticated manner, dump and load sessions.
All this is possible since it serves as a wrapper for `<Requests>`
Basic Usage:
>>> import proton
>>> s = proton.Session("https://url-to-api.ch")
>>> s.enable_alternative_routing = True
>>> s.api_request("/api/endpoint")
<Response [200]>
"""
_base_headers = {
"x-pm-apiversion": "3",
"Accept": "application/vnd.protonmail.v1+json"
}
__force_skip_alternative_routing = False
@staticmethod
def load(
dump, log_dir_path, cache_dir_path,
tls_pinning=True, timeout=DEFAULT_TIMEOUT,
proxies=None
):
"""Load session from file/keyring.
This should load the output generated by dump().
Args:
log_dir_path (str): path to desired logging directory
cache_dir_path (str): path to desired cache directory
tls_pinning (bool): tls pinning
timeout (tuple|int|float): How long to wait for the server to send
data before giving up.
proxies (dict): desired proxies
Returns:
proton.Session
"""
api_url = dump["api_url"]
appversion = dump["appversion"]
user_agent = dump["User-Agent"]
cookies = dump.get("cookies", {})
s = Session(
api_url=api_url,
log_dir_path=log_dir_path,
cache_dir_path=cache_dir_path,
appversion=appversion,
user_agent=user_agent,
tls_pinning=tls_pinning,
timeout=timeout,
proxies=proxies
)
requests.utils.add_dict_to_cookiejar(s.s.cookies, cookies)
s._session_data = dump["session_data"]
if s.UID is not None:
s.s.headers["x-pm-uid"] = s.UID
s.s.headers["Authorization"] = "Bearer " + s.AccessToken
return s
def dump(self):
"""Dump session.
If you want to reuse the session, then dump it and store the values
somewhere safe.
Returns:
dict
"""
return {
"api_url": self.__api_url,
"appversion": self.__appversion,
"User-Agent": self.__user_agent,
"cookies": self.s.cookies.get_dict(),
"session_data": self._session_data
}
def __init__(
self, api_url, log_dir_path, cache_dir_path,
appversion="Other", user_agent="None",
tls_pinning=True, ClientSecret=None, timeout=DEFAULT_TIMEOUT,
proxies=None
):
"""Constructs a new Session object.
Args:
api_url (string): URL for the new Session object
appversion (string): version for the new Session object
user_agent (string): user agent for the new Session` object
should be in the following syntax:
- Linux based -> ClientName/client.version (Linux; Distro/distro_version)
- Non-linux based -> ClientName/client.version (OS)
tls_pinning (bool): wether tls pinning should be enabled for the new
Session object.
ClientSecret (string): secret token for the new Session object that
is added to the payload with key `ClientSecret`. [OPTIONAL]
timeout (int|float|tuple): How long to wait for the server to send
data before giving up. [OPTIONAL]
proxies (dict): proxies to be used by the new Session object.
This is mutually exclusive with `tls_pinning`. [OPTIONAL]
"""
self.__api_url = api_url
self.__appversion = appversion
self.__user_agent = user_agent
self.__clientsecret = ClientSecret
self.__timeout = timeout
self.__tls_pinning_enabled = tls_pinning
self._logger = CustomLogger()
self._logger.set_log_path(log_dir_path)
self._logger = self._logger.logger
self.__metadata = MetadataBackend.get_backend()
self.__metadata.cache_dir_path = cache_dir_path
self.__metadata.logger = self._logger
self.__allow_alternative_routing = None
# Verify modulus
self.__gnupg = gnupg.GPG()
self.__gnupg.import_keys(SRP_MODULUS_KEY)
self._session_data = {}
self.s = requests.Session()
if proxies and self.__tls_pinning_enabled:
raise RuntimeError("Not allowed to add proxies while TLS Pinning is enabled")
self.s.proxies = proxies
if self.__tls_pinning_enabled:
self.s.mount(self.__api_url, TLSPinningAdapter())
self.s.headers["x-pm-appversion"] = appversion
self.s.headers["User-Agent"] = user_agent
def api_request(
self, endpoint,
jsondata=None, additional_headers=None,
method=None, params=None, _skip_alt_routing_for_api_check=False
):
"""Make API request.
Args:
endpoint (string): API endpoint.
jsondata (json): json to send in the body.
additional_headers (dict): additional (dictionary of) headers to send.
method (string): get|post|put|delete|patch.
params (dict|tuple): URL parameters to append to the URL. If a dictionary or
list of tuples ``[(key, value)]`` is provided, form-encoding will
take place.
_skip_alt_routing_for_api_check (bool): used to temporarly skip alt routing.
Returns:
requests.Response
"""
if self.__allow_alternative_routing is None:
msg = "Alternative routing has not been configured before making API requests. " \
"Please either enable or disable it before making any requests."
self._logger.info(msg)
raise RuntimeError(msg)
fct = self.s.post
if method is None:
if jsondata is None:
fct = self.s.get
else:
fct = self.s.post
else:
fct = {
"get": self.s.get,
"post": self.s.post,
"put": self.s.put,
"delete": self.s.delete,
"patch": self.s.patch
}.get(method.lower())
if fct is None:
raise ValueError("Unknown method: {}".format(method))
_url = self.__api_url
_verify = True
if not self.__metadata.try_original_url(
self.__allow_alternative_routing,
self.__force_skip_alternative_routing
):
_url = self.__metadata.get_alternative_url()
_verify = False
request_params = {
"url": _url,
"endpoint": endpoint,
"headers": additional_headers,
"json": jsondata,
"timeout": self.__timeout,
"verify": _verify,
"params": params
}
exception_class = None
exception_msg = None
try:
response = self.__make_request(fct, **request_params)
except (
NewConnectionError,
ConnectionTimeOutError,
TLSPinningError,
) as e:
self._logger.exception(e)
exc_type, *_ = sys.exc_info()
exception_class = exc_type
exception_msg = e
except (Exception, requests.exceptions.BaseHTTPError) as e:
self._logger.exception(e)
raise UnknownConnectionError(e)
if exception_class and (not self.__allow_alternative_routing or _skip_alt_routing_for_api_check or self.__force_skip_alternative_routing): # noqa
self._logger.info("{}: {}".format(exception_class, exception_msg))
raise exception_class(exception_msg)
elif (
exception_class in [NewConnectionError, ConnectionTimeOutError, TLSPinningError]
and not self._is_api_reacheable()
):
response = self.__try_with_alt_routing(fct, **request_params)
try:
status_code = response.status_code
except: # noqa
status_code = False
try:
json_error = False
response = response.json()
except json.decoder.JSONDecodeError as e:
json_error = e
if json_error and status_code != 200:
self._logger.exception(json_error)
raise ProtonAPIError(
{
"Code": response.status_code,
"Error": response.reason,
"Headers": response.headers
}
)
# This check is needed for routers or any other clients that will ask for other
# data that is not provided in json format, such as when asking /vpn/config for
# a .ovpn template
try:
if response["Code"] not in [1000, 1001]:
if response["Code"] == 9001:
self.__captcha_token = response["Details"]["HumanVerificationToken"]
elif response["Code"] == 12087:
del self.human_verification_token
raise ProtonAPIError(response)
except TypeError as e:
if status_code != 200:
raise TypeError(e)
return response
def __try_with_alt_routing(self, fct, **request_params):
alternative_routes = self.get_alternative_routes_from_dns()
request_params["verify"] = False
response = None
for route in alternative_routes:
_alt_url = "https://{}".format(route)
request_params["url"] = _alt_url
if self.__tls_pinning_enabled:
self.s.mount(_alt_url, TLSPinningAdapter(ALT_HASH_DICT))
self._logger.info("Trying {}".format(_alt_url))
try:
response = self.__make_request(fct, **request_params)
except Exception as e: # noqa
self._logger.exception(e)
continue
else:
self._logger.info("Storing alternative route: {}".format(_alt_url))
self.__metadata.store_alternative_route(_alt_url)
break
if not response:
self._logger.info("Possible network error, unable to reach API")
raise NetworkError("Network error")
return response
def __make_request(self, fct, **kwargs):
_endpoint = kwargs.pop("endpoint")
_url = kwargs["url"]
kwargs["url"] = _url + _endpoint
try:
ret = fct(**kwargs)
except requests.exceptions.ConnectionError as e:
raise NewConnectionError(e)
except requests.exceptions.Timeout as e:
raise ConnectionTimeOutError(e)
except TLSPinningError as e:
raise TLSPinningError(e)
except (Exception, requests.exceptions.BaseHTTPError) as e:
raise UnknownConnectionError(e)
return ret
def _is_api_reacheable(self):
try:
self.api_request("/tests/ping", _skip_alt_routing_for_api_check=True)
except (NewConnectionError, ConnectionTimeOutError, TLSPinningError) as e:
self._logger.exception(e)
return False
return True
def verify_modulus(self, armored_modulus):
# gpg.decrypt verifies the signature too, and returns the parsed data.
# By using gpg.verify the data is not returned
verified = self.__gnupg.decrypt(armored_modulus)
if not (verified.valid and verified.fingerprint.lower() == SRP_MODULUS_KEY_FINGERPRINT):
raise ValueError("Invalid modulus")
return base64.b64decode(verified.data.strip())
def authenticate(self, username, password):
"""Authenticate user against API.
Args:
username (string): proton account username
password (string): proton account password
Returns:
dict
The returning dict contains the Scope of the account. This allows
to identify if the account is locked, has unpaid invoices, etc.
"""
self.logout()
payload = {"Username": username}
if self.__clientsecret:
payload["ClientSecret"] = self.__clientsecret
info_response = self.api_request("/auth/info", payload)
modulus = self.verify_modulus(info_response["Modulus"])
server_challenge = base64.b64decode(info_response["ServerEphemeral"])
salt = base64.b64decode(info_response["Salt"])
version = info_response["Version"]
usr = PmsrpUser(password, modulus)
client_challenge = usr.get_challenge()
client_proof = usr.process_challenge(salt, server_challenge, version)
if client_proof is None:
raise ValueError("Invalid challenge")
# Send response
payload = {
"Username": username,
"ClientEphemeral": base64.b64encode(client_challenge).decode(
"utf8"
),
"ClientProof": base64.b64encode(client_proof).decode("utf8"),
"SRPSession": info_response["SRPSession"],
}
if self.__clientsecret:
payload["ClientSecret"] = self.__clientsecret
auth_response = self.api_request("/auth", payload)
if "ServerProof" not in auth_response:
raise ValueError("Invalid password")
usr.verify_session(base64.b64decode(auth_response["ServerProof"]))
if not usr.authenticated():
raise ValueError("Invalid server proof")
self._session_data = {
"UID": auth_response["UID"],
"AccessToken": auth_response["AccessToken"],
"RefreshToken": auth_response["RefreshToken"],
"PasswordMode": auth_response["PasswordMode"],
"Scope": auth_response["Scope"].split(),
}
if self.UID is not None:
self.s.headers["x-pm-uid"] = self.UID
self.s.headers["Authorization"] = "Bearer " + self.AccessToken
return self.Scope
def provide_2fa(self, code):
"""Provide Two Factor Authentication Code to the API.
Args:
code (string): string of ints
Returns:
dict
The returning dict contains the Scope of the account. This allows
to identify if the account is locked, has unpaid invoices, etc.
"""
ret = self.api_request("/auth/2fa", {"TwoFactorCode": code})
self._session_data["Scope"] = ret["Scope"]
return self.Scope
def logout(self):
"""Logout from API."""
if self._session_data:
self.api_request("/auth", method="DELETE")
del self.s.headers["Authorization"]
del self.s.headers["x-pm-uid"]
self._session_data = {}
def refresh(self):
"""Refresh tokens.
Refresh AccessToken with a valid RefreshToken.
If the RefreshToken is invalid then the user will have to
re-authenticate.
"""
refresh_response = self.api_request(
"/auth/refresh",
{
"ResponseType": "token",
"GrantType": "refresh_token",
"RefreshToken": self.RefreshToken,
"RedirectURI": "http://protonmail.ch"
}
)
self._session_data["AccessToken"] = refresh_response["AccessToken"]
self._session_data["RefreshToken"] = refresh_response["RefreshToken"]
self.s.headers["Authorization"] = "Bearer " + self.AccessToken
def get_alternative_routes_from_dns(self, callback=None):
"""Get alternative routes to circumvent firewalls and API restrictions.
Args:
callback (func): a callback method to be called.
Might be usefull for multi-threading. [OPTIONAL]
This method leverages the power of ThreadPoolExecutor to async
check if the provided dns hosts can be reached, and if so, collect the
alternatives routes provided by them.
The encoded url are done sync because most often one of the two should work,
as it should provide the data as quick as possible.
If callback is passed then the method does not return any value, otherwise it
returns a set().
"""
try:
from dns import message
from dns.rdatatype import TXT
except ImportError as e:
self._logger.exception(e)
raise MissingDepedencyError(
"Could not find dnspython package. "
"Please either install the missing package or disable "
"alternative routing."
)
routes = set()
for encoded_url in ENCODED_URLS:
dns_query, dns_encoded_data = self.__generate_dns_message(encoded_url)
dns_hosts_response = []
host_and_dns = [(host, dns_encoded_data) for host in DNS_HOSTS]
with ThreadPoolExecutor(max_workers=len(DNS_HOSTS)) as executor:
dns_hosts_response = list(
executor.map(self.__query_for_dns_data, host_and_dns, timeout=20)
)
dns_hosts_response = [dns_url for dns_url in dns_hosts_response if dns_url]
if len(dns_hosts_response) == 0:
continue
for response in dns_hosts_response:
routes = self.__extract_dns_answer(response, dns_query)
if len(routes) > 0:
break
if not callback:
return routes
callback(routes)
def __generate_dns_message(self, encoded_url):
"""Generate DNS message object.
Args:
encoded_url (string): encoded url as per documentation
Returns:
tuple():
dns_query (dns.message.Message): output of dns.message.make_query
base64_dns_message (base64): encode bytes
"""
from dns import message
from dns.rdatatype import TXT
dns_query = message.make_query(encoded_url, TXT)
dns_wire = dns_query.to_wire()
base64_dns_message = base64.urlsafe_b64encode(dns_wire).rstrip(b"=")
return dns_query, base64_dns_message
def __query_for_dns_data(self, dns_settings):
"""Query DNS host for data.
Args:
dns_settings (tuple):
host_url (str): http/https url
dns_encoded_data (str): base64 output
generate by __generate_dns_message()
This method uses requests.get to query the url
for dns data.
Returns:
bytes: content of the response
"""
dns_host, dns_encoded_data = dns_settings[0], dns_settings[1]
try:
response = requests.get(
dns_host,
headers={"accept": "application/dns-message"},
timeout=(3.05, 16.95),
params={"dns": dns_encoded_data}
)
if response.status_code == 404:
return
except Exception as e: # noqa
return
return response.content
def __extract_dns_answer(self, query_content, dns_query):
"""Extract alternative URL from dns message.
Args:
query_content (bytes): content of the response
dns_query (dns.message.Message): output of dns.message.make_query
Returns:
set(): alternative routes for API
"""
from dns import message
r = message.from_wire(
query_content,
keyring=dns_query.keyring,
request_mac=dns_query.request_mac,
one_rr_per_rrset=False,
ignore_trailing=False
)
routes = set()
for route in r.answer:
routes = set([str(url).strip("\"") for url in route])
return routes
@property
def captcha_url(self):
return "{}/core/v4/captcha?Token={}".format(
self.__api_url, self.__captcha_token
)
@property
def enable_alternative_routing(self):
"""Alternative routing getter."""
return self.__allow_alternative_routing
@enable_alternative_routing.setter
def enable_alternative_routing(self, newvalue):
"""Alternative routing setter.
If you would like to enable/disable alternative routing
before making any requests, this should be set to the desired
value.
Args:
newvalue (bool)
"""
if self.__allow_alternative_routing != bool(newvalue):
self.__allow_alternative_routing = bool(newvalue)
@property
def force_skip_alternative_routing(self):
"""Force skip alternative routing getter."""
return self.__force_skip_alternative_routing
@force_skip_alternative_routing.setter
def force_skip_alternative_routing(self, newvalue):
"""Force skip alternative routing setter.
Alternative routing is normally used when the usual API is not
reacheable. In certain cases, such as when connected to the VPN,
the usual API should be reacheable as the connection is tunneled,
thus there is not need to reach for the alternative routes and the
usual API is preffered to be used, for security and reliability.
Args:
newvalue (bool)
"""
self.__force_skip_alternative_routing = bool(newvalue)
@property
def human_verification_token(self):
return (
self.s.headers.get("X-PM-Human-Verification-Token-Type", None),
self.s.headers.get("X-PM-Human-Verification-Token", None)
)
@human_verification_token.setter
def human_verification_token(self, newtuplevalue):
"""Set human verification token:
Args:
newtuplevalue (tuple): (token_type, token_value)
"""
self.s.headers["X-PM-Human-Verification-Token-Type"] = newtuplevalue[0]
self.s.headers["X-PM-Human-Verification-Token"] = newtuplevalue[1]
@human_verification_token.deleter
def human_verification_token(self):
# Safest to use .pop() as it will onyl attempt to remove the key by name
# while del can also remove the whole dict (in case of code/programming error)
# Thus to prevent this, pop() is used.
try:
self.s.headers.pop("X-PM-Human-Verification-Token-Type")
except (KeyError, IndexError):
pass
try:
self.s.headers.pop("X-PM-Human-Verification-Token")
except (KeyError, IndexError):
pass
@property
def UID(self):
return self._session_data.get("UID", None)
@property
def AccessToken(self):
return self._session_data.get("AccessToken", None)
@property
def RefreshToken(self):
return self._session_data.get("RefreshToken", None)
@property
def PasswordMode(self):
return self._session_data.get("PasswordMode", None)
@property
def Scope(self):
return self._session_data.get("Scope", [])