diff --git a/pyle38/commands/set.py b/pyle38/commands/set.py index 6848d4d..68a11e4 100644 --- a/pyle38/commands/set.py +++ b/pyle38/commands/set.py @@ -4,7 +4,7 @@ from typing import Literal, Optional, Sequence, Union from ..client import Client, Command, SubCommand -from ..responses import JSONResponse +from ..responses import Fields, JSONResponse from .executable import Compiled, Executable @@ -13,6 +13,7 @@ class Set(Executable): _id: str _ex: Optional[int] = None _nx_or_xx: Optional[Union[Literal["NX", "XX"]]] = None + _fields: Optional[Fields] = {} _input: Optional[ Sequence[ Union[ @@ -25,6 +26,7 @@ def __init__(self, client: Client, key: str, id: str) -> None: super().__init__(client) self.key(key).id(id) + self._fields = {} def key(self, value: str) -> Set: self._key = value @@ -36,6 +38,11 @@ def id(self, value: str) -> Set: return self + def fields(self, fields: Fields): + self._fields = fields + + return self + def ex(self, seconds: int) -> Set: if seconds: self._ex = seconds @@ -79,6 +86,13 @@ def string(self, value: str) -> Set: return self + def unpack_fields(self, fields: Fields): + command = [] + for k, v in fields.items(): + command.extend([SubCommand.FIELD.value, k, v]) + + return command + def compile(self) -> Compiled: return [ @@ -86,6 +100,7 @@ def compile(self) -> Compiled: [ self._key, self._id, + *(self.unpack_fields(self._fields) if self._fields else []), *([SubCommand.EX.value, self._ex] if self._ex else []), *([self._nx_or_xx] if self._nx_or_xx else []), *(self._input if self._input else []), diff --git a/tests/test_command_set.py b/tests/test_command_set.py index 171367d..f43b709 100644 --- a/tests/test_command_set.py +++ b/tests/test_command_set.py @@ -15,6 +15,8 @@ "properties": {}, } +fields = {"speed": 100, "state": 1} + @pytest.mark.parametrize( "expected, received", @@ -36,8 +38,15 @@ Set(client, key, id).hash("u33d").compile(), ), (["SET", [key, id, "STRING", id]], Set(client, key, id).string(id).compile()), + ( + [ + "SET", + [key, id, "FIELD", "speed", 100, "FIELD", "state", 1, "POINT", 1, 1], + ], + Set(client, key, id).fields(fields).point(1, 1).compile(), + ), ], - ids=["point", "bounds", "object", "hash", "string"], + ids=["point", "bounds", "object", "hash", "string", "with fields"], ) @pytest.mark.asyncio async def test_command_set_compile(expected, received): @@ -55,3 +64,13 @@ async def test_command_set_query(tile38): received = await tile38.get(key, id).asObject() assert expected["object"] == received.object + + +@pytest.mark.asyncio +async def test_command_set_with_fields(tile38): + response = await tile38.set(key, id).fields(fields).point(1, 1).exec() + assert response.ok + + response = await tile38.get(key, id).with_fields().asObject() + assert response.ok + assert response.fields == fields