Skip to content

Commit

Permalink
Fix UserSecretsClient#set_tensorflow_credentials (#1333)
Browse files Browse the repository at this point in the history
http://b/313994895
  • Loading branch information
rosbo authored Nov 30, 2023
1 parent 8c70958 commit 823af9a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 10 deletions.
12 changes: 2 additions & 10 deletions patches/kaggle_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,19 +106,11 @@ def set_gcloud_credentials(self, project=None, account=None):
subprocess.run(['gcloud', 'config', 'set', 'account', account])

def set_tensorflow_credential(self, credential):
"""Sets the credential for use by Tensorflow both in the local notebook
and to pass to the TPU.
"""
# b/159906185: Import tensorflow_gcs_config only when this method is called to prevent preloading TensorFlow.
import tensorflow_gcs_config
"""Sets the credential for use by Tensorflow"""

# Write to a local JSON credentials file and set
# GOOGLE_APPLICATION_CREDENTIALS for tensorflow running in the notebook.
# Write to a local JSON credentials file
self._write_credentials_file(credential)

# set the credential for the TPU
tensorflow_gcs_config.configure_gcs(credentials=credential)

def get_bigquery_access_token(self) -> Tuple[str, Optional[datetime]]:
"""Retrieves BigQuery access token information from the UserSecrets service.
Expand Down
16 changes: 16 additions & 0 deletions tests/test_user_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,22 @@ def test_fn():

self._test_client(test_fn, '/requests/GetUserSecretByLabelRequest', {'Label': "__gcloud_sdk_auth__"}, secret=secret)

def test_set_tensorflow_credential(self):
secret = '{"client_id":"gcloud","type":"authorized_user","refresh_token":"refresh_token"}'

def test_fn():
client = UserSecretsClient()
creds = client.get_gcloud_credential()
client.set_tensorflow_credential(creds)

expected_creds_file = '/tmp/gcloud_credential.json'
self.assertEqual(expected_creds_file, os.environ['GOOGLE_APPLICATION_CREDENTIALS'])

with open(expected_creds_file, 'r') as f:
self.assertEqual(secret, '\n'.join(f.readlines()))

self._test_client(test_fn, '/requests/GetUserSecretByLabelRequest', {'Label': "__gcloud_sdk_auth__"}, secret=secret)

@mock.patch('kaggle_secrets.datetime')
def test_get_access_token_succeeds(self, mock_dt):
secret = '12345'
Expand Down

0 comments on commit 823af9a

Please sign in to comment.