Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix KeyError when using create_observation() in dry-run mode #538

Merged
merged 1 commit into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased
* ⚠️ Drop support for python 3.7
* Fix `KeyError` when using `create_observation()` in dry-run mode

## 0.19.0 (2023-12-12)

Expand Down
41 changes: 29 additions & 12 deletions pyinaturalist/session.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Session class and related functions for preparing and sending API requests"""
import json
import threading
from collections import defaultdict
from json import JSONDecodeError
from logging import getLogger
from os import getenv
Expand Down Expand Up @@ -298,7 +299,7 @@ def send( # type: ignore # Adds kwargs not present in Session.send()

# Make a mock request, if specified
if dry_run or is_dry_run_enabled(request.method):
return get_mock_response(request)
return MockResponse(request)

# Otherwise, send the request
read_timeout = timeout or self.timeout
Expand Down Expand Up @@ -432,17 +433,33 @@ def get_refresh_params(endpoint) -> Dict:
return {'refresh': True, 'v': v} if v > 0 else {'refresh': True}


def get_mock_response(request: PreparedRequest) -> CachedResponse:
"""Get mock response content to return in dry-run mode"""
json_content = {'results': [], 'total_results': 0, 'access_token': ''}
mock_response = CachedResponse(
headers={'Cache-Control': 'no-store'},
request=request,
status_code=200,
reason='DRY_RUN',
content=json.dumps(json_content).encode(),
)
return mock_response
class MockResponse(CachedResponse):
"""A mock response to return in dry-run mode.
This behaves the same as a cached response, but with the following additions:

* Adds default response values
* Returns a ``defaultdict`` when calling ``json()``, to accommodate checking for arbitrary keys
"""

def __init__(self, request: Optional[PreparedRequest] = None, **kwargs):
json_content = {
'results': [],
'total_results': 0,
'access_token': '',
'id': 'placeholder-id',
}
default_kwargs = {
'headers': {'Cache-Control': 'no-store'},
'request': request or PreparedRequest(),
'status_code': 200,
'reason': 'DRY_RUN',
'content': json.dumps(json_content).encode(),
}
default_kwargs.update(kwargs)
super().__init__(**default_kwargs)

def json(self, **kwargs):
return defaultdict(str, super().json(**kwargs))


def is_dry_run_enabled(method: str) -> bool:
Expand Down
8 changes: 6 additions & 2 deletions test/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
from pyinaturalist.session import (
CACHE_FILE,
ClientSession,
MockResponse,
clear_cache,
delete,
get,
get_local_session,
get_mock_response,
get_refresh_params,
post,
put,
Expand Down Expand Up @@ -106,6 +106,10 @@ def test_request_dry_run(
assert response.reason == 'DRY_RUN'
assert mock_send.call_count == 0

json_content = response.json()
assert len(json_content['results']) == json_content['total_results'] == 0
assert response.json()['arbitrary_key'] == ''


@patch.object(Session, 'send')
def test_request_dry_run_kwarg(mock_request):
Expand Down Expand Up @@ -207,7 +211,7 @@ def test_session__send__cache_settings(mock_send):
session = ClientSession()
with patch.object(session, 'send') as mock_cache_send:
request = Request(method='GET', url='http://test.com').prepare()
mock_send.return_value = get_mock_response(request)
mock_send.return_value = MockResponse(request)

session.send(request)
mock_cache_send.assert_called_with(request)
Expand Down
9 changes: 3 additions & 6 deletions test/v1/test_observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@

import pytest
from dateutil.tz import tzoffset, tzutc
from requests import PreparedRequest

from pyinaturalist.constants import API_V1
from pyinaturalist.exceptions import ObservationNotFound
from pyinaturalist.session import ClientSession, get_mock_response
from pyinaturalist.session import ClientSession, MockResponse
from pyinaturalist.v1 import (
create_observation,
delete_observation,
Expand All @@ -33,8 +32,6 @@
j_taxon_summary_2_listed,
)

MOCK_RESPONSE = get_mock_response(PreparedRequest())


def test_get_observation(requests_mock):
requests_mock.get(
Expand Down Expand Up @@ -106,14 +103,14 @@ def test_get_observations__all_pages(requests_mock):
assert len(observations['results']) == 2


@patch.object(ClientSession, 'send', return_value=MOCK_RESPONSE)
@patch.object(ClientSession, 'send', return_value=MockResponse())
def test_get_observations__by_obs_field(mock_send):
get_observations(taxon_id=3, observation_fields=['Species count'])
request = mock_send.call_args[0][0]
assert request.params == {'taxon_id': '3', 'field:Species count': ''}


@patch.object(ClientSession, 'send', return_value=MOCK_RESPONSE)
@patch.object(ClientSession, 'send', return_value=MockResponse())
def test_get_observations__by_obs_field_values(mock_send):
get_observations(taxon_id=3, observation_fields={'Species count': 2})
request = mock_send.call_args[0][0]
Expand Down
9 changes: 3 additions & 6 deletions test/v2/test_observations_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@
from unittest.mock import patch

import pytest
from requests import PreparedRequest

from pyinaturalist.constants import API_V2
from pyinaturalist.session import ClientSession, get_mock_response
from pyinaturalist.session import ClientSession, MockResponse
from pyinaturalist.v2 import get_observations
from test.sample_data import SAMPLE_DATA

MOCK_RESPONSE = get_mock_response(PreparedRequest())


def test_get_observations__minimal(requests_mock):
requests_mock.get(
Expand Down Expand Up @@ -91,14 +88,14 @@ def test_get_observations__all_pages__post(requests_mock):
assert len(observations['results']) == 2


@patch.object(ClientSession, 'send', return_value=MOCK_RESPONSE)
@patch.object(ClientSession, 'send', return_value=MockResponse())
def test_get_observations__by_obs_field(mock_send):
get_observations(taxon_id=3, observation_fields=['Species count'])
request = mock_send.call_args[0][0]
assert request.params == {'taxon_id': '3', 'field:Species count': ''}


@patch.object(ClientSession, 'send', return_value=MOCK_RESPONSE)
@patch.object(ClientSession, 'send', return_value=MockResponse())
def test_get_observations__by_obs_field_values(mock_send):
get_observations(taxon_id=3, observation_fields={'Species count': 2})
request = mock_send.call_args[0][0]
Expand Down