From 823fa1505d3e832dece55b1f7a0e016405ce6971 Mon Sep 17 00:00:00 2001 From: Vincent Roseberry Date: Thu, 30 Nov 2023 17:33:12 +0000 Subject: [PATCH] Fix UserSecretsClient#set_tensorflow_credentials http://b/313994895 --- patches/kaggle_secrets.py | 12 ++---------- tests/test_user_secrets.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/patches/kaggle_secrets.py b/patches/kaggle_secrets.py index 0bb97a03..a177c171 100644 --- a/patches/kaggle_secrets.py +++ b/patches/kaggle_secrets.py @@ -106,19 +106,11 @@ def set_gcloud_credentials(self, project=None, account=None): subprocess.run(['gcloud', 'config', 'set', 'account', account]) def set_tensorflow_credential(self, credential): - """Sets the credential for use by Tensorflow both in the local notebook - and to pass to the TPU. - """ - # b/159906185: Import tensorflow_gcs_config only when this method is called to prevent preloading TensorFlow. - import tensorflow_gcs_config + """Sets the credential for use by Tensorflow""" - # Write to a local JSON credentials file and set - # GOOGLE_APPLICATION_CREDENTIALS for tensorflow running in the notebook. + # Write to a local JSON credentials file self._write_credentials_file(credential) - # set the credential for the TPU - tensorflow_gcs_config.configure_gcs(credentials=credential) - def get_bigquery_access_token(self) -> Tuple[str, Optional[datetime]]: """Retrieves BigQuery access token information from the UserSecrets service. diff --git a/tests/test_user_secrets.py b/tests/test_user_secrets.py index 6dd26354..9cb9c7c5 100644 --- a/tests/test_user_secrets.py +++ b/tests/test_user_secrets.py @@ -166,6 +166,22 @@ def test_fn(): self._test_client(test_fn, '/requests/GetUserSecretByLabelRequest', {'Label': "__gcloud_sdk_auth__"}, secret=secret) + def test_set_tensorflow_credential(self): + secret = '{"client_id":"gcloud","type":"authorized_user","refresh_token":"refresh_token"}' + + def test_fn(): + client = UserSecretsClient() + creds = client.get_gcloud_credential() + client.set_tensorflow_credential(creds) + + expected_creds_file = '/tmp/gcloud_credential.json' + self.assertEqual(expected_creds_file, os.environ['GOOGLE_APPLICATION_CREDENTIALS']) + + with open(expected_creds_file, 'r') as f: + self.assertEqual(secret, '\n'.join(f.readlines())) + + self._test_client(test_fn, '/requests/GetUserSecretByLabelRequest', {'Label': "__gcloud_sdk_auth__"}, secret=secret) + @mock.patch('kaggle_secrets.datetime') def test_get_access_token_succeeds(self, mock_dt): secret = '12345'