Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix substituting regex with empty string #1953

Merged
merged 16 commits into from
Jul 2, 2024
Merged
103 changes: 58 additions & 45 deletions src/databricks/labs/ucx/workspace_access/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def __init__(
self.include_group_names = include_group_names

@abstractmethod
def generate_migrated_groups(self):
def generate_migrated_groups(self) -> Iterable[MigratedGroup]:
raise NotImplementedError

def get_filtered_groups(self):
Expand All @@ -182,9 +182,9 @@ def get_filtered_groups(self):
}

@staticmethod
def _safe_match(group_name: str, match_re: str) -> str:
def _safe_match(group_name: str, pattern: re.Pattern) -> str:
try:
match = re.search(match_re, group_name)
match = pattern.search(group_name)
if not match:
return group_name
match_groups = match.groups()
Expand All @@ -195,11 +195,11 @@ def _safe_match(group_name: str, match_re: str) -> str:
return group_name

@staticmethod
def _safe_sub(group_name: str, match_re: str, replace: str) -> str:
def _safe_sub(group_name: str, pattern: re.Pattern, replace: str) -> str:
try:
return re.sub(match_re, replace, group_name)
return pattern.sub(replace, group_name)
except re.error:
logger.warning(f"Failed to apply Regex Expression {match_re} on Group Name {group_name}")
logger.warning(f"Failed to apply Regex Expression {pattern} on Group Name {group_name}")
return group_name


Expand All @@ -208,9 +208,9 @@ def __init__(
self,
workspace_groups_in_workspace,
account_groups_in_account,
/,
renamed_groups_prefix,
include_group_names=None,
*,
renamed_groups_prefix: str,
include_group_names: list[str] | None,
):
super().__init__(
workspace_groups_in_workspace,
Expand All @@ -219,16 +219,16 @@ def __init__(
renamed_groups_prefix=renamed_groups_prefix,
)

def generate_migrated_groups(self):
def generate_migrated_groups(self) -> Iterable[MigratedGroup]:
workspace_groups = self.get_filtered_groups()
for group in workspace_groups.values():
temporary_name = f"{self.renamed_groups_prefix}{group.display_name}"
account_group = self.account_groups_in_account.get(group.display_name)
if not account_group:
logger.info(
f"Couldn't find a matching account group for {group.display_name} group using name matching"
)
continue
temporary_name = f"{self.renamed_groups_prefix}{group.display_name}"
yield MigratedGroup(
id_in_workspace=group.id,
name_in_workspace=group.display_name,
Expand All @@ -246,9 +246,9 @@ def __init__(
self,
workspace_groups_in_workspace,
account_groups_in_account,
/,
renamed_groups_prefix,
include_group_names=None,
*,
renamed_groups_prefix: str,
include_group_names: list[str] | None,
):
super().__init__(
workspace_groups_in_workspace,
Expand All @@ -257,15 +257,15 @@ def __init__(
renamed_groups_prefix=renamed_groups_prefix,
)

def generate_migrated_groups(self):
def generate_migrated_groups(self) -> Iterable[MigratedGroup]:
workspace_groups = self.get_filtered_groups()
account_groups_by_id = {group.external_id: group for group in self.account_groups_in_account.values()}
for group in workspace_groups.values():
temporary_name = f"{self.renamed_groups_prefix}{group.display_name}"
account_group = account_groups_by_id.get(group.external_id)
if not account_group:
logger.info(f"Couldn't find a matching account group for {group.display_name} group with external_id")
continue
temporary_name = f"{self.renamed_groups_prefix}{group.display_name}"
yield MigratedGroup(
id_in_workspace=group.id,
name_in_workspace=group.display_name,
Expand All @@ -281,36 +281,40 @@ def generate_migrated_groups(self):
class RegexSubStrategy(GroupMigrationStrategy):
def __init__(
self,
workspace_groups_in_workspace,
account_groups_in_account,
/,
renamed_groups_prefix,
include_group_names=None,
workspace_group_regex: str | None = None,
workspace_group_replace: str | None = None,
workspace_groups_in_workspace: dict[str, Group],
account_groups_in_account: dict[str, Group],
*,
renamed_groups_prefix: str,
include_group_names: list[str] | None,
workspace_group_regex: str,
workspace_group_replace: str,
):
super().__init__(
workspace_groups_in_workspace,
account_groups_in_account,
include_group_names=include_group_names,
renamed_groups_prefix=renamed_groups_prefix,
)
self.workspace_group_regex = workspace_group_regex # Keep to support legacy public API
self.workspace_group_replace = workspace_group_replace
self.workspace_group_regex = workspace_group_regex

def generate_migrated_groups(self):
self._workspace_group_pattern = re.compile(self.workspace_group_regex)

def generate_migrated_groups(self) -> Iterable[MigratedGroup]:
workspace_groups = self.get_filtered_groups()
for group in workspace_groups.values():
temporary_name = f"{self.renamed_groups_prefix}{group.display_name}"
name_in_account = self._safe_sub(
group.display_name, self.workspace_group_regex, self.workspace_group_replace
group.display_name,
self._workspace_group_pattern,
self.workspace_group_replace,
)
account_group = self.account_groups_in_account.get(name_in_account)
if not account_group:
logger.info(
f"Couldn't find a matching account group for {group.display_name} group with regex substitution"
)
continue
temporary_name = f"{self.renamed_groups_prefix}{group.display_name}"
yield MigratedGroup(
id_in_workspace=group.id,
name_in_workspace=group.display_name,
Expand All @@ -328,38 +332,42 @@ def __init__(
self,
workspace_groups_in_workspace,
account_groups_in_account,
/,
renamed_groups_prefix,
include_group_names=None,
workspace_group_regex: str | None = None,
account_group_regex: str | None = None,
*,
renamed_groups_prefix: str,
include_group_names: list[str] | None,
workspace_group_regex: str,
account_group_regex: str,
):
super().__init__(
workspace_groups_in_workspace,
account_groups_in_account,
include_group_names=include_group_names,
renamed_groups_prefix=renamed_groups_prefix,
)
self.account_group_regex = account_group_regex
# Keep to support legacy public API
self.workspace_group_regex = workspace_group_regex
asnare marked this conversation as resolved.
Show resolved Hide resolved
self.account_group_regex = account_group_regex

def generate_migrated_groups(self):
self._workspace_group_pattern = re.compile(self.workspace_group_regex)
self._account_group_pattern = re.compile(self.account_group_regex)

def generate_migrated_groups(self) -> Iterable[MigratedGroup]:
workspace_groups_by_match = {
self._safe_match(group_name, self.workspace_group_regex): group
self._safe_match(group_name, self._workspace_group_pattern): group
for group_name, group in self.get_filtered_groups().items()
}
account_groups_by_match = {
self._safe_match(group_name, self.account_group_regex): group
self._safe_match(group_name, self._account_group_pattern): group
for group_name, group in self.account_groups_in_account.items()
}
for group_match, ws_group in workspace_groups_by_match.items():
temporary_name = f"{self.renamed_groups_prefix}{ws_group.display_name}"
account_group = account_groups_by_match.get(group_match)
if not account_group:
logger.info(
f"Couldn't find a matching account group for {ws_group.display_name} group with regex matching"
)
continue
temporary_name = f"{self.renamed_groups_prefix}{ws_group.display_name}"
yield MigratedGroup(
id_in_workspace=ws_group.id,
name_in_workspace=ws_group.display_name,
Expand Down Expand Up @@ -672,7 +680,7 @@ def _reflect_account_group_to_workspace(self, account_group_id: str):
def _get_strategy(
self, workspace_groups_in_workspace: dict[str, Group], account_groups_in_account: dict[str, Group]
) -> GroupMigrationStrategy:
if self._workspace_group_regex and self._workspace_group_replace:
if self._workspace_group_regex is not None and self._workspace_group_replace is not None:
return RegexSubStrategy(
workspace_groups_in_workspace,
account_groups_in_account,
Expand All @@ -681,7 +689,7 @@ def _get_strategy(
workspace_group_regex=self._workspace_group_regex,
workspace_group_replace=self._workspace_group_replace,
)
if self._workspace_group_regex and self._account_group_regex:
if self._workspace_group_regex is not None and self._account_group_regex is not None:
return RegexMatchStrategy(
workspace_groups_in_workspace,
account_groups_in_account,
Expand Down Expand Up @@ -713,9 +721,12 @@ class ConfigureGroups:
group_match_by_external_id = None
include_group_names = None

_valid_substitute_pattern = re.compile(r"[\s#,+ \\<>;]")

def __init__(self, prompts: Prompts):
self._prompts = prompts
self._ask_for_group = functools.partial(self._prompts.question, validate=self._is_valid_group_str)
self._ask_for_substitute = functools.partial(self._prompts.question, validate=self._is_valid_substitute_str)
self._ask_for_regex = functools.partial(self._prompts.question, validate=self._validate_regex)

def run(self):
Expand Down Expand Up @@ -755,11 +766,11 @@ def _configure_substitution(self):
match_value = self._ask_for_regex("Enter a regular expression for substitution")
if not match_value:
return False
sub_value = self._ask_for_group("Enter the substitution value")
if not sub_value:
substitute = self._ask_for_substitute("Enter the substitution value")
if substitute is None:
return False
self.workspace_group_regex = match_value
self.workspace_group_replace = sub_value
self.workspace_group_replace = substitute
return True

def _configure_matching(self):
Expand Down Expand Up @@ -787,9 +798,11 @@ def _configure_external(self):
self.group_match_by_external_id = True
return True

@staticmethod
def _is_valid_group_str(group_str: str):
return group_str and not re.search(r"[\s#,+ \\<>;]", group_str)
def _is_valid_group_str(self, group_str: str) -> bool:
return len(group_str) > 0 and self._is_valid_substitute_str(group_str)

def _is_valid_substitute_str(self, substitute: str) -> bool:
return not self._valid_substitute_pattern.search(substitute)

@staticmethod
def _validate_regex(regex_input: str) -> bool:
Expand Down
54 changes: 50 additions & 4 deletions tests/unit/workspace_access/test_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
GroupManager,
MigratedGroup,
MigrationState,
RegexSubStrategy,
)


Expand Down Expand Up @@ -144,7 +145,6 @@ def test_snapshot_should_consider_groups_defined_in_conf():
wsclient = create_autospec(WorkspaceClient)
group1 = Group(id="1", display_name="de", meta=ResourceMeta(resource_type="WorkspaceGroup"))
group2 = Group(id="2", display_name="ds", meta=ResourceMeta(resource_type="WorkspaceGroup"))
wsclient.groups.list.return_value = [group1, group2]
acc_group_1 = Group(id="11", display_name="de", external_id="1234")
acc_group_2 = Group(id="12", display_name="ds", external_id="1235")
wsclient.api_client.do.return_value = {
Expand Down Expand Up @@ -771,6 +771,31 @@ def test_snapshot_with_group_matched_by_external_id_not_found(caplog):
assert "Couldn't find a matching account group for de_(1234) group with external_id" in caplog.text


def test_snapshot_migrated_groups_when_substitute_with_empty_string():
    """Check that a snapshot pairs groups when the regex match is substituted with an empty string.

    The workspace group ``group_old`` is configured with the regex ``_old`` and an
    empty replacement; the snapshot is expected to contain exactly one migrated
    group whose account-side name is ``group``.
    """
    old_group = Group(display_name="group_old", id="1")
    target_group = Group(display_name="group")

    client = create_autospec(WorkspaceClient)
    client.groups.list.return_value = [old_group]
    client.groups.get.return_value = old_group
    # Stub the raw API call with a payload that carries the account group.
    client.api_client.do.return_value = {"Resources": [target_group.as_dict()]}

    manager = GroupManager(
        MockBackend(),
        client,
        inventory_database="inv",
        workspace_group_regex="_old",
        workspace_group_replace="",
    )
    snapshot = manager.snapshot()

    assert len(snapshot) == 1
    only_group = snapshot[0]
    assert only_group.name_in_workspace == "group_old"
    assert only_group.name_in_account == "group"


def test_configure_include_groups():
configure_groups = ConfigureGroups(
MockPrompts(
Expand Down Expand Up @@ -831,21 +856,22 @@ def test_configure_external_id():
assert configure_groups.group_match_by_external_id


def test_configure_substitute():
@pytest.mark.parametrize("substitution_value", ["business", ""])
def test_configure_substitute(substitution_value):
configure_groups = ConfigureGroups(
MockPrompts(
{
"Backup prefix": "",
r"Choose how to map the workspace groups.*": "4", # substitute
r".*for substitution": "biz",
r".*substitution value": "business",
r".*substitution value": substitution_value,
".*": "",
}
)
)
configure_groups.run()
assert configure_groups.workspace_group_regex == "biz"
assert configure_groups.workspace_group_replace == "business"
assert configure_groups.workspace_group_replace == substitution_value


def test_configure_match():
Expand Down Expand Up @@ -1108,3 +1134,23 @@ def test_migration_state_with_filtered_group():
roles='',
)
]


def test_regex_sub_strategy_replaces_with_empty_replace():
    """Check that ``RegexSubStrategy`` yields a match when the replacement value is an empty string.

    With the regex ``_old`` and an empty replacement, the workspace group
    ``group_old`` is expected to be paired with the account group ``group`` and
    to receive the prefixed temporary name.
    """
    strategy = RegexSubStrategy(
        {"group_old": Group("group_old")},
        {"group": Group("group")},
        renamed_groups_prefix="ucx-renamed-",
        include_group_names=["group_old"],
        workspace_group_regex="_old",
        workspace_group_replace="",
    )

    # A None default makes an empty result fail the assertion below instead of raising StopIteration.
    first = next(strategy.generate_migrated_groups(), None)

    assert first is not None
    assert first.name_in_workspace == "group_old"
    assert first.name_in_account == "group"
    assert first.temporary_name == "ucx-renamed-group_old"
Loading