diff --git a/storage/google/cloud/storage/_helpers.py b/storage/google/cloud/storage/_helpers.py index 88f9b8dc0ca7..9e47c10269fc 100644 --- a/storage/google/cloud/storage/_helpers.py +++ b/storage/google/cloud/storage/_helpers.py @@ -67,6 +67,11 @@ def client(self): """Abstract getter for the object client.""" raise NotImplementedError + @property + def user_project(self): + """Abstract getter for the object user_project.""" + raise NotImplementedError + def _require_client(self, client): """Check client or verify over-ride. @@ -94,6 +99,8 @@ def reload(self, client=None): # Pass only '?projection=noAcl' here because 'acl' and related # are handled via custom endpoints. query_params = {'projection': 'noAcl'} + if self.user_project is not None: + query_params['userProject'] = self.user_project api_response = client._connection.api_request( method='GET', path=self.path, query_params=query_params, _target_object=self) @@ -140,11 +147,14 @@ def patch(self, client=None): client = self._require_client(client) # Pass '?projection=full' here because 'PATCH' documented not # to work properly w/ 'noAcl'. + query_params = {'projection': 'full'} + if self.user_project is not None: + query_params['userProject'] = self.user_project update_properties = {key: self._properties[key] for key in self._changes} api_response = client._connection.api_request( method='PATCH', path=self.path, data=update_properties, - query_params={'projection': 'full'}, _target_object=self) + query_params=query_params, _target_object=self) self._set_properties(api_response) diff --git a/storage/google/cloud/storage/bucket.py b/storage/google/cloud/storage/bucket.py index 865a23840af4..64e38921d887 100644 --- a/storage/google/cloud/storage/bucket.py +++ b/storage/google/cloud/storage/bucket.py @@ -798,10 +798,39 @@ def versioning_enabled(self, value): details. :type value: convertible to boolean - :param value: should versioning be anabled for the bucket? + :param value: should versioning be enabled for the bucket? """ self._patch_property('versioning', {'enabled': bool(value)}) + @property + def requester_pays(self): + """Does the requester pay for API requests for this bucket? + + See https://cloud.google.com/storage/docs/ for + details. + + :setter: Update whether requester pays for this bucket. + :getter: Query whether requester pays for this bucket. + + :rtype: bool + :returns: True if requester pays for API requests for the bucket, + else False. + """ + versioning = self._properties.get('billing', {}) + return versioning.get('requesterPays', False) + + @requester_pays.setter + def requester_pays(self, value): + """Update whether requester pays for API requests for this bucket. + + See https://cloud.google.com/storage/docs/ for + details. + + :type value: convertible to boolean + :param value: should requester pay for API requests for the bucket? + """ + self._patch_property('billing', {'requesterPays': bool(value)}) + def configure_website(self, main_page_suffix=None, not_found_page=None): """Configure website-related properties. diff --git a/storage/google/cloud/storage/client.py b/storage/google/cloud/storage/client.py index 93785e05269f..f6f58adae92c 100644 --- a/storage/google/cloud/storage/client.py +++ b/storage/google/cloud/storage/client.py @@ -194,7 +194,7 @@ def lookup_bucket(self, bucket_name): except NotFound: return None - def create_bucket(self, bucket_name): + def create_bucket(self, bucket_name, requester_pays=None): """Create a new bucket. For example: @@ -211,10 +211,16 @@ def create_bucket(self, bucket_name): :type bucket_name: str :param bucket_name: The bucket name to create. + :type requester_pays: bool + :param requester_pays: (Optional) Whether requester pays for + API requests for this bucket and its blobs. + :rtype: :class:`google.cloud.storage.bucket.Bucket` :returns: The newly created bucket. """ bucket = Bucket(self, name=bucket_name) + if requester_pays is not None: + bucket.requester_pays = requester_pays bucket.create(client=self) return bucket diff --git a/storage/tests/system.py b/storage/tests/system.py index afab659882bf..06f50b26128b 100644 --- a/storage/tests/system.py +++ b/storage/tests/system.py @@ -30,6 +30,8 @@ HTTP = httplib2.Http() +REQUESTER_PAYS_ENABLED = False # query from environment? + def _bad_copy(bad_request): """Predicate: pass only exceptions for a failed copyTo.""" @@ -99,6 +101,15 @@ def test_create_bucket(self): self.case_buckets_to_delete.append(new_bucket_name) self.assertEqual(created.name, new_bucket_name) + @unittest.skipUnless(REQUESTER_PAYS_ENABLED, "requesterPays not enabled") + def test_create_bucket_with_requester_pays(self): + new_bucket_name = 'w-requester-pays' + unique_resource_id('-') + created = Config.CLIENT.create_bucket( + new_bucket_name, requester_pays=True) + self.case_buckets_to_delete.append(new_bucket_name) + self.assertEqual(created.name, new_bucket_name) + self.assertTrue(created.requester_pays) + def test_list_buckets(self): buckets_to_create = [ 'new' + unique_resource_id(), diff --git a/storage/tests/unit/test__helpers.py b/storage/tests/unit/test__helpers.py index 89967f3a0db0..1d87b42c9fea 100644 --- a/storage/tests/unit/test__helpers.py +++ b/storage/tests/unit/test__helpers.py @@ -26,7 +26,7 @@ def _get_target_class(): def _make_one(self, *args, **kw): return self._get_target_class()(*args, **kw) - def _derivedClass(self, path=None): + def _derivedClass(self, path=None, user_project=None): class Derived(self._get_target_class()): @@ -36,15 +36,26 @@ class Derived(self._get_target_class()): def path(self): return path + @property + def user_project(self): + return user_project + return Derived def test_path_is_abstract(self): mixin = self._make_one() - self.assertRaises(NotImplementedError, lambda: mixin.path) + with self.assertRaises(NotImplementedError): + mixin.path def test_client_is_abstract(self): mixin = self._make_one() - self.assertRaises(NotImplementedError, lambda: mixin.client) + with self.assertRaises(NotImplementedError): + mixin.client + + def test_user_project_is_abstract(self): + mixin = self._make_one() + with self.assertRaises(NotImplementedError): + mixin.user_project def test_reload(self): connection = _Connection({'foo': 'Foo'}) @@ -62,6 +73,25 @@ def test_reload(self): # Make sure changes get reset by reload. self.assertEqual(derived._changes, set()) + def test_reload_w_user_project(self): + user_project = 'user-project-123' + connection = _Connection({'foo': 'Foo'}) + client = _Client(connection) + derived = self._derivedClass('/path', user_project)() + # Make sure changes is not a set, so we can observe a change. + derived._changes = object() + derived.reload(client=client) + self.assertEqual(derived._properties, {'foo': 'Foo'}) + kw = connection._requested + self.assertEqual(len(kw), 1) + self.assertEqual(kw[0]['method'], 'GET') + self.assertEqual(kw[0]['path'], '/path') + self.assertEqual( + kw[0]['query_params'], + {'projection': 'noAcl', 'userProject': user_project}) + # Make sure changes get reset by reload. + self.assertEqual(derived._changes, set()) + def test__set_properties(self): mixin = self._make_one() self.assertEqual(mixin._properties, {}) @@ -95,6 +125,30 @@ def test_patch(self): # Make sure changes get reset by patch(). self.assertEqual(derived._changes, set()) + def test_patch_w_user_project(self): + user_project = 'user-project-123' + connection = _Connection({'foo': 'Foo'}) + client = _Client(connection) + derived = self._derivedClass('/path', user_project)() + # Make sure changes is non-empty, so we can observe a change. + BAR = object() + BAZ = object() + derived._properties = {'bar': BAR, 'baz': BAZ} + derived._changes = set(['bar']) # Ignore baz. + derived.patch(client=client) + self.assertEqual(derived._properties, {'foo': 'Foo'}) + kw = connection._requested + self.assertEqual(len(kw), 1) + self.assertEqual(kw[0]['method'], 'PATCH') + self.assertEqual(kw[0]['path'], '/path') + self.assertEqual( + kw[0]['query_params'], + {'projection': 'full', 'userProject': user_project}) + # Since changes does not include `baz`, we don't see it sent. + self.assertEqual(kw[0]['data'], {'bar': BAR}) + # Make sure changes get reset by patch(). + self.assertEqual(derived._changes, set()) + class Test__scalar_property(unittest.TestCase): diff --git a/storage/tests/unit/test_bucket.py b/storage/tests/unit/test_bucket.py index 5e4a91575197..34835110bd67 100644 --- a/storage/tests/unit/test_bucket.py +++ b/storage/tests/unit/test_bucket.py @@ -176,6 +176,7 @@ def test_create_w_extra_properties(self): 'location': LOCATION, 'storageClass': STORAGE_CLASS, 'versioning': {'enabled': True}, + 'billing': {'requesterPays': True}, 'labels': LABELS, } connection = _Connection(DATA) @@ -186,6 +187,7 @@ def test_create_w_extra_properties(self): bucket.location = LOCATION bucket.storage_class = STORAGE_CLASS bucket.versioning_enabled = True + bucket.requester_pays = True bucket.labels = LABELS bucket.create() @@ -866,6 +868,24 @@ def test_versioning_enabled_setter(self): bucket.versioning_enabled = True self.assertTrue(bucket.versioning_enabled) + def test_requester_pays_getter_missing(self): + NAME = 'name' + bucket = self._make_one(name=NAME) + self.assertEqual(bucket.requester_pays, False) + + def test_requester_pays_getter(self): + NAME = 'name' + before = {'billing': {'requesterPays': True}} + bucket = self._make_one(name=NAME, properties=before) + self.assertEqual(bucket.requester_pays, True) + + def test_requester_pays_setter(self): + NAME = 'name' + bucket = self._make_one(name=NAME) + self.assertFalse(bucket.requester_pays) + bucket.requester_pays = True + self.assertTrue(bucket.requester_pays) + def test_configure_website_defaults(self): NAME = 'name' UNSET = {'website': {'mainPageSuffix': None, diff --git a/storage/tests/unit/test_client.py b/storage/tests/unit/test_client.py index 9696d4e5fa51..29545415a220 100644 --- a/storage/tests/unit/test_client.py +++ b/storage/tests/unit/test_client.py @@ -155,22 +155,22 @@ def test_get_bucket_hit(self): CREDENTIALS = _make_credentials() client = self._make_one(project=PROJECT, credentials=CREDENTIALS) - BLOB_NAME = 'blob-name' + BUCKET_NAME = 'bucket-name' URI = '/'.join([ client._connection.API_BASE_URL, 'storage', client._connection.API_VERSION, 'b', - '%s?projection=noAcl' % (BLOB_NAME,), + '%s?projection=noAcl' % (BUCKET_NAME,), ]) http = client._http_internal = _Http( {'status': '200', 'content-type': 'application/json'}, - '{{"name": "{0}"}}'.format(BLOB_NAME).encode('utf-8'), + '{{"name": "{0}"}}'.format(BUCKET_NAME).encode('utf-8'), ) - bucket = client.get_bucket(BLOB_NAME) + bucket = client.get_bucket(BUCKET_NAME) self.assertIsInstance(bucket, Bucket) - self.assertEqual(bucket.name, BLOB_NAME) + self.assertEqual(bucket.name, BUCKET_NAME) self.assertEqual(http._called_with['method'], 'GET') self.assertEqual(http._called_with['uri'], URI) @@ -203,33 +203,34 @@ def test_lookup_bucket_hit(self): CREDENTIALS = _make_credentials() client = self._make_one(project=PROJECT, credentials=CREDENTIALS) - BLOB_NAME = 'blob-name' + BUCKET_NAME = 'bucket-name' URI = '/'.join([ client._connection.API_BASE_URL, 'storage', client._connection.API_VERSION, 'b', - '%s?projection=noAcl' % (BLOB_NAME,), + '%s?projection=noAcl' % (BUCKET_NAME,), ]) http = client._http_internal = _Http( {'status': '200', 'content-type': 'application/json'}, - '{{"name": "{0}"}}'.format(BLOB_NAME).encode('utf-8'), + '{{"name": "{0}"}}'.format(BUCKET_NAME).encode('utf-8'), ) - bucket = client.lookup_bucket(BLOB_NAME) + bucket = client.lookup_bucket(BUCKET_NAME) self.assertIsInstance(bucket, Bucket) - self.assertEqual(bucket.name, BLOB_NAME) + self.assertEqual(bucket.name, BUCKET_NAME) self.assertEqual(http._called_with['method'], 'GET') self.assertEqual(http._called_with['uri'], URI) def test_create_bucket_conflict(self): + import json from google.cloud.exceptions import Conflict PROJECT = 'PROJECT' CREDENTIALS = _make_credentials() client = self._make_one(project=PROJECT, credentials=CREDENTIALS) - BLOB_NAME = 'blob-name' + BUCKET_NAME = 'bucket-name' URI = '/'.join([ client._connection.API_BASE_URL, 'storage', @@ -241,18 +242,21 @@ def test_create_bucket_conflict(self): '{"error": {"message": "Conflict"}}', ) - self.assertRaises(Conflict, client.create_bucket, BLOB_NAME) + self.assertRaises(Conflict, client.create_bucket, BUCKET_NAME) self.assertEqual(http._called_with['method'], 'POST') self.assertEqual(http._called_with['uri'], URI) + body = json.loads(http._called_with['body']) + self.assertEqual(body, {'name': BUCKET_NAME}) def test_create_bucket_success(self): + import json from google.cloud.storage.bucket import Bucket PROJECT = 'PROJECT' CREDENTIALS = _make_credentials() client = self._make_one(project=PROJECT, credentials=CREDENTIALS) - BLOB_NAME = 'blob-name' + BUCKET_NAME = 'bucket-name' URI = '/'.join([ client._connection.API_BASE_URL, 'storage', @@ -261,14 +265,17 @@ def test_create_bucket_success(self): ]) http = client._http_internal = _Http( {'status': '200', 'content-type': 'application/json'}, - '{{"name": "{0}"}}'.format(BLOB_NAME).encode('utf-8'), + '{{"name": "{0}"}}'.format(BUCKET_NAME).encode('utf-8'), ) - bucket = client.create_bucket(BLOB_NAME) + bucket = client.create_bucket(BUCKET_NAME, requester_pays=True) self.assertIsInstance(bucket, Bucket) - self.assertEqual(bucket.name, BLOB_NAME) + self.assertEqual(bucket.name, BUCKET_NAME) self.assertEqual(http._called_with['method'], 'POST') self.assertEqual(http._called_with['uri'], URI) + body = json.loads(http._called_with['body']) + self.assertEqual( + body, {'name': BUCKET_NAME, 'billing': {'requesterPays': True}}) def test_list_buckets_empty(self): from six.moves.urllib.parse import parse_qs @@ -400,7 +407,7 @@ def test_page_non_empty_response(self): credentials = _make_credentials() client = self._make_one(project=project, credentials=credentials) - blob_name = 'blob-name' + blob_name = 'bucket-name' response = {'items': [{'name': blob_name}]} def dummy_response():