Skip to content

Commit 254ad17

Browse files
fixes try_add_state
Signed-off-by: Elena Kolevska <[email protected]>
1 parent 278ba8d commit 254ad17

File tree

4 files changed

+116
-3
lines changed

4 files changed

+116
-3
lines changed

dapr/actor/runtime/mock_state_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ async def try_add_state(self, state_name: str, value: T) -> bool:
5151
return True
5252
return False
5353
existed = state_name in self._mock_state
54-
if not existed:
54+
if existed:
5555
return False
5656
self._default_state_change_tracker[state_name] = StateMetadata(value, StateChangeKind.add)
5757
self._mock_state[state_name] = value

dapr/actor/runtime/state_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ async def try_add_state(self, state_name: str, value: T) -> bool:
9090
existed = await self._actor.runtime_ctx.state_provider.contains_state(
9191
self._type_name, self._actor.id.id, state_name
9292
)
93-
if not existed:
93+
if existed:
9494
return False
9595

9696
state_change_tracker[state_name] = StateMetadata(value, StateChangeKind.add)
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import unittest
2+
from dapr.actor.runtime.state_manager import StateChangeKind
3+
from dapr.actor.runtime.mock_state_manager import MockStateManager
4+
from dapr.actor.runtime.mock_actor import MockActor
5+
6+
7+
class TestMockStateManager(unittest.IsolatedAsyncioTestCase):
8+
def setUp(self):
9+
"""Set up a mock actor and state manager."""
10+
11+
class TestActor(MockActor):
12+
pass
13+
14+
self.mock_actor = TestActor(actor_id='test_actor', initstate=None)
15+
self.state_manager = MockStateManager(
16+
actor=self.mock_actor, initstate={'initial_key': 'initial_value'}
17+
)
18+
19+
async def test_add_state(self):
20+
"""Test adding a new state."""
21+
await self.state_manager.add_state('new_key', 'new_value')
22+
state = await self.state_manager.get_state('new_key')
23+
self.assertEqual(state, 'new_value')
24+
25+
# Ensure it is tracked as an added state
26+
tracker = self.state_manager._default_state_change_tracker
27+
self.assertEqual(tracker['new_key'].change_kind, StateChangeKind.add)
28+
self.assertEqual(tracker['new_key'].value, 'new_value')
29+
30+
async def test_get_existing_state(self):
31+
"""Test retrieving an existing state."""
32+
state = await self.state_manager.get_state('initial_key')
33+
self.assertEqual(state, 'initial_value')
34+
35+
async def test_get_nonexistent_state(self):
36+
"""Test retrieving a state that does not exist."""
37+
with self.assertRaises(KeyError):
38+
await self.state_manager.get_state('nonexistent_key')
39+
40+
async def test_update_state(self):
41+
"""Test updating an existing state."""
42+
await self.state_manager.set_state('initial_key', 'updated_value')
43+
state = await self.state_manager.get_state('initial_key')
44+
self.assertEqual(state, 'updated_value')
45+
46+
# Ensure it is tracked as an updated state
47+
tracker = self.state_manager._default_state_change_tracker
48+
self.assertEqual(tracker['initial_key'].change_kind, StateChangeKind.update)
49+
self.assertEqual(tracker['initial_key'].value, 'updated_value')
50+
51+
async def test_remove_state(self):
52+
"""Test removing an existing state."""
53+
await self.state_manager.remove_state('initial_key')
54+
with self.assertRaises(KeyError):
55+
await self.state_manager.get_state('initial_key')
56+
57+
# Ensure it is tracked as a removed state
58+
tracker = self.state_manager._default_state_change_tracker
59+
self.assertEqual(tracker['initial_key'].change_kind, StateChangeKind.remove)
60+
61+
async def test_save_state(self):
62+
"""Test saving state changes."""
63+
await self.state_manager.add_state('key1', 'value1')
64+
await self.state_manager.set_state('initial_key', 'value2')
65+
await self.state_manager.remove_state('initial_key')
66+
67+
await self.state_manager.save_state()
68+
69+
# After saving, state tracker should be cleared
70+
tracker = self.state_manager._default_state_change_tracker
71+
self.assertEqual(len(tracker), 1)
72+
73+
# State changes should be reflected in _mock_state
74+
self.assertIn('key1', self.state_manager._mock_state)
75+
self.assertEqual(self.state_manager._mock_state['key1'], 'value1')
76+
self.assertNotIn('initial_key', self.state_manager._mock_state)
77+
78+
async def test_contains_state(self):
79+
"""Test checking if a state exists."""
80+
self.assertTrue(await self.state_manager.contains_state('initial_key'))
81+
self.assertFalse(await self.state_manager.contains_state('nonexistent_key'))
82+
83+
async def test_clear_cache(self):
84+
"""Test clearing the cache."""
85+
await self.state_manager.add_state('key1', 'value1')
86+
await self.state_manager.clear_cache()
87+
88+
# Tracker should be empty
89+
self.assertEqual(len(self.state_manager._default_state_change_tracker), 0)
90+
91+
async def test_state_ttl(self):
92+
"""Test setting state with TTL."""
93+
await self.state_manager.set_state_ttl('key_with_ttl', 'value', ttl_in_seconds=10)
94+
tracker = self.state_manager._default_state_change_tracker
95+
self.assertEqual(tracker['key_with_ttl'].ttl_in_seconds, 10)
96+
97+
98+
if __name__ == '__main__':
99+
unittest.main()

tests/actor/test_state_manager.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def setUp(self):
4646

4747
@mock.patch(
4848
'tests.actor.fake_client.FakeDaprActorClient.get_state',
49-
new=_async_mock(return_value=base64.b64encode(b'"value1"')),
49+
new=_async_mock(),
5050
)
5151
@mock.patch(
5252
'tests.actor.fake_client.FakeDaprActorClient.save_state_transactionally', new=_async_mock()
@@ -67,6 +67,20 @@ def test_add_state(self):
6767
added = _run(state_manager.try_add_state('state1', 'value1'))
6868
self.assertFalse(added)
6969

70+
@mock.patch(
71+
'tests.actor.fake_client.FakeDaprActorClient.get_state',
72+
new=_async_mock(return_value=base64.b64encode(b'"value1"')),
73+
)
74+
@mock.patch(
75+
'tests.actor.fake_client.FakeDaprActorClient.save_state_transactionally', new=_async_mock()
76+
)
77+
def test_add_state_with_existing_state(self):
78+
state_manager = ActorStateManager(self._fake_actor)
79+
80+
# Add first 'state1'
81+
added = _run(state_manager.try_add_state('state1', 'value1'))
82+
self.assertFalse(added)
83+
7084
@mock.patch('tests.actor.fake_client.FakeDaprActorClient.get_state', new=_async_mock())
7185
def test_get_state_for_no_state(self):
7286
state_manager = ActorStateManager(self._fake_actor)

0 commit comments

Comments
 (0)