Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/ml/azure-ai-ml/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## 1.1.0 (Unreleased)

### Features Added
- Registry list operation now accepts scope value to allow subscription-only based requests.
- Most configuration classes from the entity package now implement the standard mapping protocol.

### Breaking Changes
Expand Down
3 changes: 2 additions & 1 deletion sdk/ml/azure-ai-ml/azure/ai/ml/constants/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

__path__ = __import__("pkgutil").extend_path(__path__, __name__)

from ._common import AssetTypes, InputOutputModes, ModelType, TimeZone
from ._common import AssetTypes, InputOutputModes, ModelType, TimeZone, Scope
from ._component import ParallelTaskType
from ._deployment import BatchDeploymentOutputAction
from ._job import (
Expand Down Expand Up @@ -38,4 +38,5 @@
"AcrAccountSku",
"NlpModels",
"NlpLearningRateScheduler",
"Scope",
]
5 changes: 5 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,3 +562,8 @@ class RollingRate:
DAY = "day"
HOUR = "hour"
MINUTE = "minute"


class Scope:
SUBSCRIPTION="subscription"
RESOURCE_GROUP="resource_group"
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from azure.ai.ml._utils._experimental import experimental
from .._utils._azureml_polling import AzureMLPolling
from ..constants._common import LROConfigurations
from ..constants._common import LROConfigurations, Scope

ops_logger = OpsLogger(__name__)
module_logger = ops_logger.module_logger
Expand Down Expand Up @@ -52,14 +52,19 @@ def __init__(
self._init_kwargs = kwargs

#@ monitor_with_activity(logger, "Registry.List", ActivityType.PUBLICAPI)
def list(self) -> Iterable[Registry]:
def list(self, *, scope: str = Scope.RESOURCE_GROUP) -> Iterable[Registry]:
"""List all registries that the user has access to in the current
resource group.
resource group or subscription.

:param scope: scope of the listing, "resource_group" or "subscription", defaults to "resource_group"
:type scope: str, optional
:return: An iterator like instance of Registry objects
:rtype: ~azure.core.paging.ItemPaged[Registry]
"""

if scope.lower() == Scope.SUBSCRIPTION:
return self._operation.list_by_subscription(
cls=lambda objs: [Registry._from_rest_object(obj) for obj in objs]
)
return self._operation.list(cls=lambda objs: [Registry._from_rest_object(obj) for obj in objs], \
resource_group_name=self._resource_group_name)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from azure.ai.ml._utils.utils import camel_to_snake
from azure.ai.ml._version import VERSION
from azure.ai.ml.constants import ManagedServiceIdentityType
from azure.ai.ml.constants._common import ArmConstants, LROConfigurations, WorkspaceResourceConstants
from azure.ai.ml.constants._common import ArmConstants, LROConfigurations, WorkspaceResourceConstants, Scope
from azure.ai.ml.entities._credentials import IdentityConfiguration
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
from azure.core.credentials import TokenCredential
Expand Down Expand Up @@ -70,7 +70,7 @@ def __init__(
self.containerRegistry = "none"

# @monitor_with_activity(logger, "Workspace.List", ActivityType.PUBLICAPI)
def list(self, *, scope: str = "resource_group") -> Iterable[Workspace]:
def list(self, *, scope: str = Scope.RESOURCE_GROUP) -> Iterable[Workspace]:
"""List all workspaces that the user has access to in the current
resource group or subscription.

Expand All @@ -80,7 +80,7 @@ def list(self, *, scope: str = "resource_group") -> Iterable[Workspace]:
:rtype: ~azure.core.paging.ItemPaged[Workspace]
"""

if scope == "subscription":
if scope == Scope.SUBSCRIPTION:
return self._operation.list_by_subscription(
cls=lambda objs: [Workspace._from_rest_object(obj) for obj in objs]
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,18 @@ def mock_registry_operation(
@pytest.mark.unittest
class TestRegistryOperations:
def test_list(self, mock_registry_operation: RegistryOperations) -> None:
# Test different input options for the scope value
mock_registry_operation.list()
mock_registry_operation._operation.list.assert_called_once()

mock_registry_operation.list(scope="invalid")
assert mock_registry_operation._operation.list.call_count == 2
mock_registry_operation._operation.list_by_subscription.assert_not_called()

mock_registry_operation.list(scope="subscription")
assert mock_registry_operation._operation.list.call_count == 2
mock_registry_operation._operation.list_by_subscription.assert_called_once()

def test_get(self, mock_registry_operation: RegistryOperations, randstr: Callable[[], str]) -> None:
mock_registry_operation.get(f"unittest_{randstr('reg_name')}")
mock_registry_operation._operation.get.assert_called_once()
Expand Down