Skip to content

Commit

Permalink
Add support for setting the GCS credential
Browse files Browse the repository at this point in the history
Set the credential for both the TPU and the local notebook.

https://b.corp.google.com/issues/158133824
  • Loading branch information
mcollins42 committed Jun 9, 2020
1 parent ad4d638 commit c703d33
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 0 deletions.
16 changes: 16 additions & 0 deletions patches/kaggle_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import json
import os
import socket
import tensorflow_gcs_config
import urllib.request
from datetime import datetime, timedelta
from enum import Enum, unique
Expand Down Expand Up @@ -135,6 +136,21 @@ def get_gcloud_credential(self) -> str:
else:
raise

def set_tensorflow_credential(self, credential):
"""Sets the credential for use by Tensorflow both in the local notebook
and to pass to the TPU.
"""
# Write to a local JSON credentials file and set
# GOOGLE_APPLICATION_CREDENTIALS for tensorflow running in the notebook.
adc_path = os.path.join(
os.environ.get('HOME', '/'), 'gcloud_credential.json')
with open(adc_path, 'w') as f:
f.write(credential)
os.environ['GOOGLE_APPLICATION_CREDENTIALS']=adc_path

# set the credential for the TPU
tensorflow_gcs_config.configure_gcs(credentials=credential)

def get_bigquery_access_token(self) -> Tuple[str, Optional[datetime]]:
"""Retrieves BigQuery access token information from the UserSecrets service.
Expand Down
37 changes: 37 additions & 0 deletions tests/test_tensorflow_credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import unittest

import os
import tensorflow_gcs_config
from unittest.mock import patch
from test.support import EnvironmentVarGuard
from kaggle_secrets import UserSecretsClient

class TestTensorflowCredentials(unittest.TestCase):

@patch('tensorflow_gcs_config.configure_gcs')
def test_set_tensorflow_credential(self, mock_configure_gcs):
credential = '{"client_id":"fake_client_id",' \
'"client_secret":"fake_client_secret",' \
'"refresh_token":"not a refresh token",' \
'"type":"authorized_user"}';

env = EnvironmentVarGuard()
env.set('HOME', '/tmp')
env.set('GOOGLE_APPLICATION_CREDENTIALS', '')

# These need to be set to make UserSecretsClient happy, but aren't
# pertinent to this test.
env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
env.set('KAGGLE_KERNEL_INTEGRATIONS', 'AUTOML')

user_secrets = UserSecretsClient()
user_secrets.set_tensorflow_credential(credential)

credential_path = '/tmp/gcloud_credential.json'
self.assertEqual(
credential_path, os.environ['GOOGLE_APPLICATION_CREDENTIALS'])
with open(credential_path, 'r') as f:
saved_cred = f.read()
self.assertEqual(credential, saved_cred)

mock_configure_gcs.assert_called_with(credentials=credential)

0 comments on commit c703d33

Please sign in to comment.