Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion rest_framework/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,17 @@ def get_default_valid_fields(self, queryset, view, context={}):
)
raise ImproperlyConfigured(msg % self.__class__.__name__)

model_field_names = [field.name for field in queryset.model._meta.fields]

return [
(field.source.replace('.', '__') or field_name, field.label)
for field_name, field in serializer_class(context=context).fields.items()
if not getattr(field, 'write_only', False) and not field.source == '*'
if (
not getattr(field, 'write_only', False) and
not field.source == '*' and (
field_name in model_field_names or field.source in model_field_names
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not convinced this is necessarily correct. For example we allow sources such as author.name, which will reference the name field across the relationship author, but that'd be filtered out by this new check, right?

Copy link
Contributor Author

@omerfarukabaci omerfarukabaci Mar 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you are right, I have missed that. 😕 What about the following solution:

model_class = queryset.model
model_property_names = [
    # "pk" is a property added in Django's Model class, however it is valid for ordering.
    attr for attr in dir(model_class) if isinstance(getattr(model_class, attr), property) and attr != 'pk'
]
return [
    (field.source.replace('.', '__') or field_name, field.label)
    for field_name, field in serializer_class(context=context).fields.items()
    if not getattr(field, 'write_only', False) and not field.source == '*'
    if (
        not getattr(field, 'write_only', False) and
        not field.source == '*' and
        field_name not in model_property_names and
        field.source not in model_property_names
    )
]

or do you have any other suggestions?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That looks okay to me, I think? (Although probably just need field.source not in model_property_names, rather than field_name not in model_property_names?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I didn't know that field.source is field_name by default, you are totally right then! 🚀 I think that the comment line about pk should stay, what do you think? After we decide this I will make the related changes. 👍🏼

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes the PK comment makes sense to me.

Copy link
Contributor Author

@omerfarukabaci omerfarukabaci Mar 11, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pushed related changes: c7f8b14. If you have any other concerns we may further discuss. Thank you for your time! 🙏🏼

)
)
]

def get_valid_fields(self, queryset, view, context={}):
Expand Down
51 changes: 51 additions & 0 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,10 @@ class OrderingFilterModel(models.Model):
title = models.CharField(max_length=20, verbose_name='verbose title')
text = models.CharField(max_length=100)

@property
def description(self):
return self.title + ": " + self.text


class OrderingFilterRelatedModel(models.Model):
related_object = models.ForeignKey(OrderingFilterModel, related_name="relateds", on_delete=models.CASCADE)
Expand All @@ -436,6 +440,17 @@ class Meta:
fields = '__all__'


class OrderingFilterSerializerWithModelProperty(serializers.ModelSerializer):
class Meta:
model = OrderingFilterModel
fields = (
"id",
"title",
"text",
"description"
)


class OrderingDottedRelatedSerializer(serializers.ModelSerializer):
related_text = serializers.CharField(source='related_object.text')
related_title = serializers.CharField(source='related_object.title')
Expand Down Expand Up @@ -551,6 +566,42 @@ class OrderingListView(generics.ListAPIView):
{'id': 1, 'title': 'zyx', 'text': 'abc'},
]

def test_ordering_without_ordering_fields(self):
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializerWithModelProperty
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)

view = OrderingListView.as_view()

# Model field ordering works fine.
request = factory.get('/', {'ordering': 'text'})
response = view(request)
assert response.data == [
{'id': 1, 'title': 'zyx', 'text': 'abc', 'description': 'zyx: abc'},
{'id': 2, 'title': 'yxw', 'text': 'bcd', 'description': 'yxw: bcd'},
{'id': 3, 'title': 'xwv', 'text': 'cde', 'description': 'xwv: cde'},
]

# `incorrectfield` ordering works fine.
request = factory.get('/', {'ordering': 'foobar'})
response = view(request)
assert response.data == [
{'id': 3, 'title': 'xwv', 'text': 'cde', 'description': 'xwv: cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd', 'description': 'yxw: bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc', 'description': 'zyx: abc'},
]

# `description` is a Model property, which should be ignored.
request = factory.get('/', {'ordering': 'description'})
response = view(request)
assert response.data == [
{'id': 3, 'title': 'xwv', 'text': 'cde', 'description': 'xwv: cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd', 'description': 'yxw: bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc', 'description': 'zyx: abc'},
]

def test_default_ordering(self):
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
Expand Down