Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix substituting regex with empty string #1953

Merged
merged 16 commits into from
Jul 2, 2024
Merged
66 changes: 33 additions & 33 deletions src/databricks/labs/ucx/workspace_access/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def __init__(
self.include_group_names = include_group_names

@abstractmethod
def generate_migrated_groups(self):
def generate_migrated_groups(self) -> Iterable[MigratedGroup]:
raise NotImplementedError

def get_filtered_groups(self):
Expand Down Expand Up @@ -208,9 +208,9 @@ def __init__(
self,
workspace_groups_in_workspace,
account_groups_in_account,
/,
renamed_groups_prefix,
include_group_names=None,
*,
renamed_groups_prefix: str,
include_group_names: list[str] | None,
):
super().__init__(
workspace_groups_in_workspace,
Expand All @@ -219,16 +219,16 @@ def __init__(
renamed_groups_prefix=renamed_groups_prefix,
)

def generate_migrated_groups(self):
def generate_migrated_groups(self) -> Iterable[MigratedGroup]:
workspace_groups = self.get_filtered_groups()
for group in workspace_groups.values():
temporary_name = f"{self.renamed_groups_prefix}{group.display_name}"
account_group = self.account_groups_in_account.get(group.display_name)
if not account_group:
logger.info(
f"Couldn't find a matching account group for {group.display_name} group using name matching"
)
continue
temporary_name = f"{self.renamed_groups_prefix}{group.display_name}"
yield MigratedGroup(
id_in_workspace=group.id,
name_in_workspace=group.display_name,
Expand All @@ -246,9 +246,9 @@ def __init__(
self,
workspace_groups_in_workspace,
account_groups_in_account,
/,
renamed_groups_prefix,
include_group_names=None,
*,
renamed_groups_prefix: str,
include_group_names: list[str] | None,
):
super().__init__(
workspace_groups_in_workspace,
Expand All @@ -257,15 +257,15 @@ def __init__(
renamed_groups_prefix=renamed_groups_prefix,
)

def generate_migrated_groups(self):
def generate_migrated_groups(self) -> Iterable[MigratedGroup]:
workspace_groups = self.get_filtered_groups()
account_groups_by_id = {group.external_id: group for group in self.account_groups_in_account.values()}
for group in workspace_groups.values():
temporary_name = f"{self.renamed_groups_prefix}{group.display_name}"
account_group = account_groups_by_id.get(group.external_id)
if not account_group:
logger.info(f"Couldn't find a matching account group for {group.display_name} group with external_id")
continue
temporary_name = f"{self.renamed_groups_prefix}{group.display_name}"
yield MigratedGroup(
id_in_workspace=group.id,
name_in_workspace=group.display_name,
Expand All @@ -281,27 +281,26 @@ def generate_migrated_groups(self):
class RegexSubStrategy(GroupMigrationStrategy):
def __init__(
self,
workspace_groups_in_workspace,
account_groups_in_account,
/,
renamed_groups_prefix,
include_group_names=None,
workspace_group_regex: str | None = None,
workspace_group_replace: str | None = None,
workspace_groups_in_workspace: dict[str, Group],
account_groups_in_account: dict[str, Group],
*,
renamed_groups_prefix: str,
include_group_names: list[str] | None,
workspace_group_regex: str,
workspace_group_replace: str,
):
super().__init__(
workspace_groups_in_workspace,
account_groups_in_account,
include_group_names=include_group_names,
renamed_groups_prefix=renamed_groups_prefix,
)
self.workspace_group_replace = workspace_group_replace
self.workspace_group_regex = workspace_group_regex
self.workspace_group_replace = workspace_group_replace

def generate_migrated_groups(self):
def generate_migrated_groups(self) -> Iterable[MigratedGroup]:
workspace_groups = self.get_filtered_groups()
for group in workspace_groups.values():
temporary_name = f"{self.renamed_groups_prefix}{group.display_name}"
name_in_account = self._safe_sub(
group.display_name, self.workspace_group_regex, self.workspace_group_replace
)
Expand All @@ -311,6 +310,7 @@ def generate_migrated_groups(self):
f"Couldn't find a matching account group for {group.display_name} group with regex substitution"
)
continue
temporary_name = f"{self.renamed_groups_prefix}{group.display_name}"
yield MigratedGroup(
id_in_workspace=group.id,
name_in_workspace=group.display_name,
Expand All @@ -328,11 +328,11 @@ def __init__(
self,
workspace_groups_in_workspace,
account_groups_in_account,
/,
renamed_groups_prefix,
include_group_names=None,
workspace_group_regex: str | None = None,
account_group_regex: str | None = None,
*,
renamed_groups_prefix: str,
include_group_names: list[str] | None,
workspace_group_regex: str,
account_group_regex: str,
):
super().__init__(
workspace_groups_in_workspace,
Expand All @@ -343,7 +343,7 @@ def __init__(
self.account_group_regex = account_group_regex
self.workspace_group_regex = workspace_group_regex
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the caller is allowed to pass these in as strings (instead of compiled patterns) I'd really like to compile them here during isn't initialisation. This isn't (just) a performance thing: a) it catches problems with regex syntax (a common problem) during initialisation rather than later when we use it for the first time; and b) any time a string is converted to its domain type lots of things just become simpler (IDE completions, type checking, etc.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, note that I kept the public API the same


def generate_migrated_groups(self):
def generate_migrated_groups(self) -> Iterable[MigratedGroup]:
workspace_groups_by_match = {
self._safe_match(group_name, self.workspace_group_regex): group
for group_name, group in self.get_filtered_groups().items()
Expand All @@ -353,13 +353,13 @@ def generate_migrated_groups(self):
for group_name, group in self.account_groups_in_account.items()
}
for group_match, ws_group in workspace_groups_by_match.items():
temporary_name = f"{self.renamed_groups_prefix}{ws_group.display_name}"
account_group = account_groups_by_match.get(group_match)
if not account_group:
logger.info(
f"Couldn't find a matching account group for {ws_group.display_name} group with regex matching"
)
continue
temporary_name = f"{self.renamed_groups_prefix}{ws_group.display_name}"
yield MigratedGroup(
id_in_workspace=ws_group.id,
name_in_workspace=ws_group.display_name,
Expand Down Expand Up @@ -672,7 +672,7 @@ def _reflect_account_group_to_workspace(self, account_group_id: str):
def _get_strategy(
self, workspace_groups_in_workspace: dict[str, Group], account_groups_in_account: dict[str, Group]
) -> GroupMigrationStrategy:
if self._workspace_group_regex and self._workspace_group_replace:
if self._workspace_group_regex is not None and self._workspace_group_replace is not None:
return RegexSubStrategy(
workspace_groups_in_workspace,
account_groups_in_account,
Expand All @@ -681,7 +681,7 @@ def _get_strategy(
workspace_group_regex=self._workspace_group_regex,
workspace_group_replace=self._workspace_group_replace,
)
if self._workspace_group_regex and self._account_group_regex:
if self._workspace_group_regex is not None and self._account_group_regex is not None:
return RegexMatchStrategy(
workspace_groups_in_workspace,
account_groups_in_account,
Expand Down Expand Up @@ -756,7 +756,7 @@ def _configure_substitution(self):
if not match_value:
return False
sub_value = self._ask_for_group("Enter the substitution value")
if not sub_value:
if sub_value is None:
return False
self.workspace_group_regex = match_value
self.workspace_group_replace = sub_value
Expand Down Expand Up @@ -788,8 +788,8 @@ def _configure_external(self):
return True

@staticmethod
def _is_valid_group_str(group_str: str):
return group_str and not re.search(r"[\s#,+ \\<>;]", group_str)
def _is_valid_group_str(group_str: str | None) -> bool:
return group_str is not None and not re.search(r"[\s#,+ \\<>;]", group_str)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please pre-compile as a class-level property.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


@staticmethod
def _validate_regex(regex_input: str) -> bool:
Expand Down
54 changes: 50 additions & 4 deletions tests/unit/workspace_access/test_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
GroupManager,
MigratedGroup,
MigrationState,
RegexSubStrategy,
)


Expand Down Expand Up @@ -144,7 +145,6 @@ def test_snapshot_should_consider_groups_defined_in_conf():
wsclient = create_autospec(WorkspaceClient)
group1 = Group(id="1", display_name="de", meta=ResourceMeta(resource_type="WorkspaceGroup"))
group2 = Group(id="2", display_name="ds", meta=ResourceMeta(resource_type="WorkspaceGroup"))
wsclient.groups.list.return_value = [group1, group2]
acc_group_1 = Group(id="11", display_name="de", external_id="1234")
acc_group_2 = Group(id="12", display_name="ds", external_id="1235")
wsclient.api_client.do.return_value = {
Expand Down Expand Up @@ -771,6 +771,31 @@ def test_snapshot_with_group_matched_by_external_id_not_found(caplog):
assert "Couldn't find a matching account group for de_(1234) group with external_id" in caplog.text


def test_snapshot_migrated_groups_when_substitute_with_empty_string():
backend = MockBackend()

workspace_group = Group(display_name="group_old", id="1")
ws = create_autospec(WorkspaceClient)
ws.groups.list.return_value = [workspace_group]
ws.groups.get.return_value = workspace_group

account_group = Group(display_name="group")
ws.api_client.do.return_value = {"Resources": [account_group.as_dict()]}

group_manager = GroupManager(
backend,
ws,
inventory_database="inv",
workspace_group_regex="_old",
workspace_group_replace="",
)
migrated_groups = group_manager.snapshot()

assert len(migrated_groups) == 1
assert migrated_groups[0].name_in_workspace == "group_old"
assert migrated_groups[0].name_in_account == "group"


def test_configure_include_groups():
configure_groups = ConfigureGroups(
MockPrompts(
Expand Down Expand Up @@ -831,21 +856,22 @@ def test_configure_external_id():
assert configure_groups.group_match_by_external_id


def test_configure_substitute():
@pytest.mark.parametrize("substitution_value", ["business", ""])
def test_configure_substitute(substitution_value):
configure_groups = ConfigureGroups(
MockPrompts(
{
"Backup prefix": "",
r"Choose how to map the workspace groups.*": "4", # substitute
r".*for substitution": "biz",
r".*substitution value": "business",
r".*substitution value": substitution_value,
".*": "",
}
)
)
configure_groups.run()
assert configure_groups.workspace_group_regex == "biz"
assert configure_groups.workspace_group_replace == "business"
assert configure_groups.workspace_group_replace == substitution_value


def test_configure_match():
Expand Down Expand Up @@ -1108,3 +1134,23 @@ def test_migration_state_with_filtered_group():
roles='',
)
]


def test_regex_sub_strategy_replaces_with_empty_replace():
workspace_groups = {"group_old": Group("group_old")}
account_groups = {"group": Group("group")}
strategy = RegexSubStrategy(
workspace_groups,
account_groups,
renamed_groups_prefix="ucx-renamed-",
include_group_names=["group_old"],
workspace_group_regex="_old",
workspace_group_replace="",
)

migrated_group = next(strategy.generate_migrated_groups(), None)

assert migrated_group is not None
assert migrated_group.name_in_workspace == "group_old"
assert migrated_group.name_in_account == "group"
assert migrated_group.temporary_name == "ucx-renamed-group_old"
Loading