diff --git a/integration/test_rbac.py b/integration/test_rbac.py index b7349e73e..aae9dc4f3 100644 --- a/integration/test_rbac.py +++ b/integration/test_rbac.py @@ -261,6 +261,14 @@ def test_downsert_permissions(client_factory: ClientFactory) -> None: client.roles.delete(role_name) +def test_own_roles(client_factory: ClientFactory) -> None: + with client_factory(ports=RBAC_PORTS, auth_credentials=RBAC_AUTH_CREDS) as client: + if client._connection._weaviate_version.is_lower_than(1, 28, 0): + pytest.skip("This test requires Weaviate 1.28.0 or higher") + roles = client.roles.get_current_roles() + assert len(roles) > 0 + + def test_multiple_permissions(client_factory: ClientFactory) -> None: with client_factory(ports=RBAC_PORTS, auth_credentials=RBAC_AUTH_CREDS) as client: if client._connection._weaviate_version.is_lower_than(1, 28, 0): diff --git a/weaviate/rbac/roles.py b/weaviate/rbac/roles.py index 7bd6498ac..8780d9629 100644 --- a/weaviate/rbac/roles.py +++ b/weaviate/rbac/roles.py @@ -27,6 +27,16 @@ async def _get_roles(self) -> List[WeaviateRole]: ) return cast(List[WeaviateRole], res.json()) + async def _get_current_roles(self) -> List[WeaviateRole]: + path = "/authz/users/own-roles" + + res = await self._connection.get( + path, + error_msg="Could not get roles", + status_codes=_ExpectedStatusCodes(ok_in=[200], error="Get own roles"), + ) + return cast(List[WeaviateRole], res.json()) + async def _get_role(self, name: str) -> Optional[WeaviateRole]: path = f"/authz/roles/{name}" @@ -124,13 +134,23 @@ class _RolesAsync(_RolesBase): def __user_from_weaviate_user(self, user: str) -> User: return User(name=user) - async def list_all(self) -> List[Role]: + async def list_all(self) -> Dict[str, Role]: """Get all roles. Returns: - All roles. + A dictionary with user names as keys and the `Role` objects as values. """ - return [Role._from_weaviate_role(role) for role in await self._get_roles()] + return {role["name"]: Role._from_weaviate_role(role) for role in await self._get_roles()} + + async def get_current_roles(self) -> Dict[str, Role]: + """Get all roles for current user. + + Returns: + A dictionary with user names as keys and the `Role` objects as values. + """ + return { + role["name"]: Role._from_weaviate_role(role) for role in await self._get_current_roles() + } async def exists(self, role: str) -> bool: """Check if a role exists. diff --git a/weaviate/rbac/sync.pyi b/weaviate/rbac/sync.pyi index a7d8f3d4a..418bb1f52 100644 --- a/weaviate/rbac/sync.pyi +++ b/weaviate/rbac/sync.pyi @@ -1,9 +1,11 @@ from typing import Dict, List, Optional, Union -from weaviate.rbac.roles import _RolesBase + from weaviate.rbac.models import PermissionsType, Role, User +from weaviate.rbac.roles import _RolesBase class _Roles(_RolesBase): - def list_all(self) -> List[Role]: ... + def list_all(self) -> Dict[str, Role]: ... + def get_current_roles(self) -> Dict[str, Role]: ... def by_name(self, role: str) -> Optional[Role]: ... def by_user(self, user: str) -> Dict[str, Role]: ... def users(self, role: str) -> Dict[str, User]: ...