Skip to content

RDS: implement CopyTags and Tags parameters for copy_db_snapshot() #8320

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions moto/rds/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1838,15 +1838,19 @@ def copy_db_snapshot(
source_snapshot_identifier: str,
target_snapshot_identifier: str,
tags: Optional[List[Dict[str, str]]] = None,
copy_tags: bool = False,
) -> DBSnapshot:
if source_snapshot_identifier not in self.database_snapshots:
raise DBSnapshotNotFoundError(source_snapshot_identifier)

source_snapshot = self.database_snapshots[source_snapshot_identifier]
if tags is None:
tags = source_snapshot.tags
else:
tags = self._merge_tags(source_snapshot.tags, tags)

# When tags are passed, AWS does NOT copy/merge tags of the
# source snapshot, even when copy_tags=True is given.
# But when tags=[], AWS does honor copy_tags=True.
if not tags:
tags = source_snapshot.tags if copy_tags else []

return self.create_db_snapshot(
db_instance=source_snapshot.database,
db_snapshot_identifier=target_snapshot_identifier,
Expand Down
3 changes: 2 additions & 1 deletion moto/rds/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,9 @@ def copy_db_snapshot(self) -> str:
source_snapshot_identifier = self._get_param("SourceDBSnapshotIdentifier")
target_snapshot_identifier = self._get_param("TargetDBSnapshotIdentifier")
tags = self.unpack_list_params("Tags", "Tag")
copy_tags = self._get_param("CopyTags")
snapshot = self.backend.copy_db_snapshot(
source_snapshot_identifier, target_snapshot_identifier, tags
source_snapshot_identifier, target_snapshot_identifier, tags, copy_tags
)
template = self.response_template(COPY_SNAPSHOT_TEMPLATE)
return template.render(snapshot=snapshot)
Expand Down
49 changes: 49 additions & 0 deletions tests/test_rds/test_rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,55 @@ def test_copy_db_snapshots(delete_db_instance: bool):
assert result["TagList"] == []


original_snapshot_tags = [{"Key": "original", "Value": "snapshot tags"}]
new_snapshot_tags = [{"Key": "new", "Value": "tag"}]


@pytest.mark.parametrize(
"kwargs,expected_tags",
[
# No Tags parameter, CopyTags defaults to False -> no tags
({}, []),
# No Tags parameter, CopyTags set to True -> use tags of original snapshot
({"CopyTags": True}, original_snapshot_tags),
# When "Tags" are given, they become the only tags of the snapshot.
({"Tags": new_snapshot_tags}, new_snapshot_tags),
# When "Tags" are given, they become the only tags of the snapshot. Even if CopyTags is True!
({"Tags": new_snapshot_tags, "CopyTags": True}, new_snapshot_tags),
# When "Tags" are given but empty, CopyTags=True takes effect again!
({"Tags": [], "CopyTags": True}, original_snapshot_tags),
],
ids=(
"no_parameters",
"copytags_true",
"only_tags",
"copytags_true_and_tags",
"copytags_true_and_empty_tags",
),
)
@mock_aws
def test_copy_db_snapshots_copytags_and_tags(kwargs, expected_tags):
conn = boto3.client("rds", region_name=DEFAULT_REGION)
conn.create_db_instance(
DBInstanceIdentifier="db-primary-1",
Engine="postgres",
DBInstanceClass="db.m1.small",
)
conn.create_db_snapshot(
DBInstanceIdentifier="db-primary-1",
DBSnapshotIdentifier="snapshot",
Tags=original_snapshot_tags,
)

target_snapshot = conn.copy_db_snapshot(
SourceDBSnapshotIdentifier="snapshot",
TargetDBSnapshotIdentifier="snapshot-copy",
**kwargs,
).get("DBSnapshot")
result = conn.list_tags_for_resource(ResourceName=target_snapshot["DBSnapshotArn"])
assert result["TagList"] == expected_tags


@mock_aws
def test_describe_db_snapshots():
conn = boto3.client("rds", region_name=DEFAULT_REGION)
Expand Down
Loading