Skip to content

Commit

Permalink
apns: async base_token property
Browse files Browse the repository at this point in the history
makes awaiting _connected an implementation detail
  • Loading branch information
JJTech0130 committed May 19, 2024
1 parent 4d40ed0 commit 8d59b7e
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions pypush/apns/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ async def create_apns_connection(
conn = Connection(
tg, certificate, private_key, token, sandbox, courier
) # Await connected for first time here, so that base token is set
await conn._connected.wait()
yield conn
tg.cancel_scope.cancel() # Cancel the task group when the context manager exits
await (
Expand All @@ -49,11 +48,11 @@ def __init__(
):
self.certificate = certificate
self.private_key = private_key
self.base_token = token
self._base_token = token

self._filters: dict[str, int] = {} # topic -> use count

self._connected = anyio.Event() # Set when the connection is first established
self._connected = anyio.Event() # Only use for base_token property

self._conn = None
self._tg = task_group
Expand All @@ -75,6 +74,13 @@ def __init__(
self._tg.start_soon(self.reconnect)
self._tg.start_soon(self._ping_task)

@property
async def base_token(self) -> bytes:
if self._base_token is None:
await self._connected.wait()
assert self._base_token is not None
return self._base_token

async def _receive_task(self):
assert self._conn is not None
async for command in self._conn:
Expand Down Expand Up @@ -114,7 +120,7 @@ async def reconnect(self):
)
await conn.send(
protocol.ConnectCommand(
push_token=self.base_token,
push_token=self._base_token,
state=1,
flags=65, # 69
certificate=cert,
Expand All @@ -133,8 +139,8 @@ async def reconnect(self):
lambda c: (
c
if (
c.token == self.base_token
if self.base_token is not None
c.token == self._base_token
if self._base_token is not None
else True
)
else None
Expand All @@ -143,10 +149,10 @@ async def reconnect(self):
)
logging.debug(f"Connected with ack: {ack}")
assert ack.status == 0
if self.base_token is None:
self.base_token = ack.token
if self._base_token is None:
self._base_token = ack.token
else:
assert ack.token == self.base_token
assert ack.token == self._base_token
if not self._connected.is_set():
self._connected.set()

Expand Down Expand Up @@ -187,10 +193,9 @@ async def _send(self, command: protocol.Command):
await self._send(command)

async def _update_filter(self):
assert self.base_token is not None
await self._send(
protocol.FilterCommand(
token=self.base_token,
token=await self.base_token,
enabled_topic_hashes=[
sha1(topic.encode()).digest() for topic in self._filters
],
Expand All @@ -199,7 +204,6 @@ async def _update_filter(self):

@asynccontextmanager
async def _filter(self, topics: list[str]):
assert self.base_token is not None
for topic in topics:
self._filters[topic] = self._filters.get(topic, 0) + 1
await self._update_filter()
Expand All @@ -212,9 +216,8 @@ async def _filter(self, topics: list[str]):

async def mint_scoped_token(self, topic: str) -> bytes:
topic_hash = sha1(topic.encode()).digest()
assert self.base_token is not None
await self._send(
protocol.ScopedTokenCommand(token=self.base_token, topic=topic_hash)
protocol.ScopedTokenCommand(token=await self.base_token, topic=topic_hash)
)
ack = await self._receive(filters.cmd(protocol.ScopedTokenAck))
assert ack.status == 0
Expand All @@ -230,7 +233,7 @@ async def notification_stream(
] = filters.ALL,
):
if token is None:
token = self.base_token
token = await self.base_token
async with self._filter([topic]), self._receive_stream(
filters.chain(
filters.chain(
Expand Down

0 comments on commit 8d59b7e

Please sign in to comment.