Skip to content

Commit

Permalink
Use kaggle-web-client in kaggle_secrets.
Browse files (browse the repository at this point in the history)
Use kaggle-web-client in kaggle_secrets.
  • Loading branch information
ifigotin authored Jul 23, 2020
2 parents 953ce15 + 2259a07 commit 5770a85
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 77 deletions.
3 changes: 2 additions & 1 deletion patches/kaggle_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

class KaggleDatasets:
GET_GCS_PATH_ENDPOINT = '/requests/CopyDatasetVersionToKnownGcsBucketRequest'
TIMEOUT_SECS = 600

# Integration types for GCS
AUTO_ML = 1
Expand All @@ -20,5 +21,5 @@ def get_gcs_path(self, dataset_dir: str = None) -> str:
'MountSlug': dataset_dir,
'IntegrationType': integration_type,
}
result = self.web_client.make_post_request(data, self.GET_GCS_PATH_ENDPOINT)
result = self.web_client.make_post_request(data, self.GET_GCS_PATH_ENDPOINT, self.TIMEOUT_SECS)
return result['destinationBucket']
65 changes: 5 additions & 60 deletions patches/kaggle_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,12 @@
(ie. BigQuery).
"""

import json
import os
import socket
import urllib.request
from datetime import datetime, timedelta
from enum import Enum, unique
from typing import Optional, Tuple
from urllib.error import HTTPError, URLError

_KAGGLE_DEFAULT_URL_BASE = "https://www.kaggle.com"
_KAGGLE_URL_BASE_ENV_VAR_NAME = "KAGGLE_URL_BASE"
_KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME = "KAGGLE_USER_SECRETS_TOKEN"
TIMEOUT_SECS = 40


class CredentialError(Exception):
pass


class BackendError(Exception):
pass

from kaggle_web_client import KaggleWebClient
from kaggle_web_client import (CredentialError, BackendError)

class ValidationError(Exception):
pass
Expand Down Expand Up @@ -56,48 +40,9 @@ def service(self):
class UserSecretsClient():
GET_USER_SECRET_ENDPOINT = '/requests/GetUserSecretRequest'
GET_USER_SECRET_BY_LABEL_ENDPOINT = '/requests/GetUserSecretByLabelRequest'
BIGQUERY_TARGET_VALUE = 1

def __init__(self):
url_base_override = os.getenv(_KAGGLE_URL_BASE_ENV_VAR_NAME)
self.url_base = url_base_override or _KAGGLE_DEFAULT_URL_BASE
# Follow the OAuth 2.0 Authorization standard (https://tools.ietf.org/html/rfc6750)
self.jwt_token = os.getenv(_KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME)
if self.jwt_token is None:
raise CredentialError(
'A JWT Token is required to use the UserSecretsClient, '
f'but none found in environment variable {_KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME}')
self.headers = {'Content-type': 'application/json'}

def _make_post_request(self, data: dict, endpoint: str = GET_USER_SECRET_ENDPOINT) -> dict:
# TODO(b/148309982) This code and the code in the constructor should be
# removed and this class should use the new KaggleWebClient class instead.
url = f'{self.url_base}{endpoint}'
request_body = dict(data)
request_body['JWE'] = self.jwt_token
req = urllib.request.Request(url, headers=self.headers, data=bytes(
json.dumps(request_body), encoding="utf-8"))
try:
with urllib.request.urlopen(req, timeout=TIMEOUT_SECS) as response:
response_json = json.loads(response.read())
if not response_json.get('wasSuccessful') or 'result' not in response_json:
raise BackendError(
f'Unexpected response from the service. Response: {response_json}.')
return response_json['result']
except (URLError, socket.timeout) as e:
if isinstance(
e, socket.timeout) or isinstance(
e.reason, socket.timeout):
raise ConnectionError(
'Timeout error trying to communicate with service. Please ensure internet is on.') from e
raise ConnectionError(
'Connection error trying to communicate with service.') from e
except HTTPError as e:
if e.code == 401 or e.code == 403:
raise CredentialError(
f'Service responded with error code {e.code}.'
' Please ensure you have access to the resource.') from e
raise BackendError('Unexpected response from the service.') from e
self.web_client = KaggleWebClient()

def get_secret(self, label) -> str:
"""Retrieves a user secret value by its label.
Expand All @@ -113,7 +58,7 @@ def get_secret(self, label) -> str:
request_body = {
'Label': label,
}
response_json = self._make_post_request(request_body, self.GET_USER_SECRET_BY_LABEL_ENDPOINT)
response_json = self.web_client.make_post_request(request_body, self.GET_USER_SECRET_BY_LABEL_ENDPOINT)
if 'secret' not in response_json:
raise BackendError(
f'Unexpected response from the service. Response: {response_json}')
Expand Down Expand Up @@ -174,7 +119,7 @@ def _get_access_token(self, target: GcpTarget) -> Tuple[str, Optional[datetime]]
request_body = {
'Target': target.target
}
response_json = self._make_post_request(request_body)
response_json = self.web_client.make_post_request(request_body, self.GET_USER_SECRET_ENDPOINT)
if 'secret' not in response_json:
raise BackendError(
f'Unexpected response from the service. Response: {response_json}')
Expand Down
25 changes: 15 additions & 10 deletions patches/kaggle_web_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,22 @@
import os
import socket
import urllib.request
from datetime import datetime, timedelta
from enum import Enum, unique
from typing import Optional, Tuple
from urllib.error import HTTPError, URLError
from kaggle_secrets import (_KAGGLE_DEFAULT_URL_BASE,
_KAGGLE_URL_BASE_ENV_VAR_NAME,
_KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME,
CredentialError, BackendError, ValidationError)

_KAGGLE_DEFAULT_URL_BASE = "https://www.kaggle.com"
_KAGGLE_URL_BASE_ENV_VAR_NAME = "KAGGLE_URL_BASE"
_KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME = "KAGGLE_USER_SECRETS_TOKEN"
TIMEOUT_SECS = 40

class CredentialError(Exception):
pass


class BackendError(Exception):
pass


class KaggleWebClient:
TIMEOUT_SECS = 600

def __init__(self):
url_base_override = os.getenv(_KAGGLE_URL_BASE_ENV_VAR_NAME)
Expand All @@ -29,14 +34,14 @@ def __init__(self):
'X-Kaggle-Authorization': f'Bearer {self.jwt_token}',
}

def make_post_request(self, data: dict, endpoint: str) -> dict:
def make_post_request(self, data: dict, endpoint: str, timeout: int = TIMEOUT_SECS) -> dict:
url = f'{self.url_base}{endpoint}'
request_body = dict(data)
request_body['JWE'] = self.jwt_token
req = urllib.request.Request(url, headers=self.headers, data=bytes(
json.dumps(request_body), encoding="utf-8"))
try:
with urllib.request.urlopen(req, timeout=self.TIMEOUT_SECS) as response:
with urllib.request.urlopen(req, timeout=timeout) as response:
response_json = json.loads(response.read())
if not response_json.get('wasSuccessful') or 'result' not in response_json:
raise BackendError(
Expand Down
6 changes: 3 additions & 3 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from test.support import EnvironmentVarGuard
from urllib.parse import urlparse

from kaggle_secrets import (_KAGGLE_URL_BASE_ENV_VAR_NAME,
from kaggle_web_client import (KaggleWebClient,
_KAGGLE_URL_BASE_ENV_VAR_NAME,
_KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME,
CredentialError, BackendError, ValidationError)
from kaggle_web_client import KaggleWebClient
CredentialError, BackendError)
from kaggle_datasets import KaggleDatasets, _KAGGLE_TPU_NAME_ENV_VAR_NAME

_TEST_JWT = 'test-secrets-key'
Expand Down
7 changes: 4 additions & 3 deletions tests/test_user_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@

from google.auth.exceptions import DefaultCredentialsError
from google.cloud import bigquery
from kaggle_secrets import (_KAGGLE_URL_BASE_ENV_VAR_NAME,
from kaggle_secrets import (GcpTarget, UserSecretsClient,
NotFoundError, ValidationError)
from kaggle_web_client import (_KAGGLE_URL_BASE_ENV_VAR_NAME,
_KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME,
CredentialError, GcpTarget, UserSecretsClient,
BackendError, NotFoundError, ValidationError)
CredentialError, BackendError)

_TEST_JWT = 'test-secrets-key'

Expand Down

0 comments on commit 5770a85

Please sign in to comment.