Skip to content

Commit

Permalink
simplified perms for get
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacobjeevan committed Dec 2, 2024
1 parent 4d656be commit c38cfc3
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 38 deletions.
28 changes: 5 additions & 23 deletions care/users/api/viewsets/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,12 @@ def last_active_after(self, queryset, name, value):

class UserViewSetPermission(DRYPermissions):
def has_permission(self, request, view):
if request.method == "GET" and "username" in view.kwargs:
if request.method == "GET" and view.action == "retrieve":
return True
return super().has_permission(request, view)

def has_object_permission(self, request, view, obj):
if request.method == "GET" and "username" in view.kwargs:
if request.method == "GET" and view.action == "retrieve":
return True
return super().has_object_permission(request, view, obj)

Expand Down Expand Up @@ -167,16 +167,9 @@ def get_queryset(self):

def get_object(self) -> User:
try:
if self.request.method == "GET" and not self.kwargs.get("username"):
username = self.request.query_params.get("username")
if not username:
raise ValidationError({"message": "Username is required"})
user = get_object_or_404(self.get_queryset(), username=username)
if not self.has_permission(user):
raise ValidationError(
{"message": "You do not have permission to access this user"}
)
return user
if self.request.method == "GET" and self.action == "retrieve":
username = self.kwargs.get("username")
return get_object_or_404(User, username=username)
return super().get_object()
except Http404 as e:
error = "User not found"
Expand Down Expand Up @@ -225,17 +218,6 @@ def destroy(self, request, *args, **kwargs):
user.save(update_fields=["is_active"])
return Response(status=status.HTTP_204_NO_CONTENT)

def has_permission(self, user):
requesting_user = self.request.user
return (
requesting_user == user
or requesting_user.is_superuser
or (
requesting_user.state == user.state
or requesting_user.district == user.district
)
)

@extend_schema(tags=["users"])
@action(detail=False, methods=["POST"])
def add_user(self, request, *args, **kwargs):
Expand Down
19 changes: 7 additions & 12 deletions care/users/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,11 @@ def test_superuser_can_view(self):
data = self.user_data.copy()
data["date_of_birth"] = str(data["date_of_birth"])
data.pop("password")
user_data = self.get_detail_representation(self.user)
user_data.pop("created_by")
self.assertDictEqual(
res_data_json,
self.get_detail_representation(self.user),
user_data,
)

def test_superuser_can_modify(self):
Expand Down Expand Up @@ -264,19 +266,12 @@ def test_user_can_modify_themselves(self):
User.objects.get(username=username).date_of_birth, date(2005, 4, 1)
)

def test_user_cannot_read_others(self):
"""Test 1 user can read the attributes of the other user not in the same ditrict/state"""
username = self.data_2["username"]
response = self.client.get(f"/api/v1/users/{username}/")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.json()["detail"], "User not found")

def test_user_can_read_others_in_same_district_or_state(self):
"""Test 1 user can read the attributes of the other user in the same district or state"""
username = self.user_3.username
def test_user_can_read_others(self):
"""Test 1 user can read the attributes of any other user"""
username = self.user_2.username
response = self.client.get(f"/api/v1/users/{username}/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json()["first_name"], self.user_3.first_name)
self.assertEqual(response.json()["first_name"], self.user_2.first_name)

def test_user_cannot_modify_others(self):
"""Test a user can't modify others"""
Expand Down
6 changes: 3 additions & 3 deletions care/utils/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,7 @@ def get_local_body_district_state_representation(self, obj):
response.update(self.get_state_representation(getattr(obj, "state", None)))
return response

def get_local_body_representation(self, local_body: LocalBody):
def get_local_body_representation(self, local_body: LocalBody | None):
if local_body is None:
return {"local_body": None, "local_body_object": None}
return {
Expand All @@ -778,7 +778,7 @@ def get_local_body_representation(self, local_body: LocalBody):
},
}

def get_district_representation(self, district: District):
def get_district_representation(self, district: District | None):
if district is None:
return {"district": None, "district_object": None}
return {
Expand All @@ -790,7 +790,7 @@ def get_district_representation(self, district: District):
},
}

def get_state_representation(self, state: State):
def get_state_representation(self, state: State | None):
if state is None:
return {"state": None, "state_object": None}
return {"state": state.id, "state_object": {"id": state.id, "name": state.name}}
Expand Down

0 comments on commit c38cfc3

Please sign in to comment.