diff --git a/pypush/apns/lifecycle.py b/pypush/apns/lifecycle.py index e094bc1..23d3f94 100644 --- a/pypush/apns/lifecycle.py +++ b/pypush/apns/lifecycle.py @@ -29,7 +29,6 @@ async def create_apns_connection( conn = Connection( tg, certificate, private_key, token, sandbox, courier ) # Await connected for first time here, so that base token is set - await conn._connected.wait() yield conn tg.cancel_scope.cancel() # Cancel the task group when the context manager exits await ( @@ -49,11 +48,11 @@ def __init__( ): self.certificate = certificate self.private_key = private_key - self.base_token = token + self._base_token = token self._filters: dict[str, int] = {} # topic -> use count - self._connected = anyio.Event() # Set when the connection is first established + self._connected = anyio.Event() # Only use for base_token property self._conn = None self._tg = task_group @@ -75,6 +74,13 @@ def __init__( self._tg.start_soon(self.reconnect) self._tg.start_soon(self._ping_task) + @property + async def base_token(self) -> bytes: + if self._base_token is None: + await self._connected.wait() + assert self._base_token is not None + return self._base_token + async def _receive_task(self): assert self._conn is not None async for command in self._conn: @@ -114,7 +120,7 @@ async def reconnect(self): ) await conn.send( protocol.ConnectCommand( - push_token=self.base_token, + push_token=self._base_token, state=1, flags=65, # 69 certificate=cert, @@ -133,8 +139,8 @@ async def reconnect(self): lambda c: ( c if ( - c.token == self.base_token - if self.base_token is not None + c.token == self._base_token + if self._base_token is not None else True ) else None @@ -143,10 +149,10 @@ async def reconnect(self): ) logging.debug(f"Connected with ack: {ack}") assert ack.status == 0 - if self.base_token is None: - self.base_token = ack.token + if self._base_token is None: + self._base_token = ack.token else: - assert ack.token == self.base_token + assert ack.token == self._base_token if not self._connected.is_set(): self._connected.set() @@ -187,10 +193,9 @@ async def _send(self, command: protocol.Command): await self._send(command) async def _update_filter(self): - assert self.base_token is not None await self._send( protocol.FilterCommand( - token=self.base_token, + token=await self.base_token, enabled_topic_hashes=[ sha1(topic.encode()).digest() for topic in self._filters ], @@ -199,7 +204,6 @@ async def _update_filter(self): @asynccontextmanager async def _filter(self, topics: list[str]): - assert self.base_token is not None for topic in topics: self._filters[topic] = self._filters.get(topic, 0) + 1 await self._update_filter() @@ -212,9 +216,8 @@ async def _filter(self, topics: list[str]): async def mint_scoped_token(self, topic: str) -> bytes: topic_hash = sha1(topic.encode()).digest() - assert self.base_token is not None await self._send( - protocol.ScopedTokenCommand(token=self.base_token, topic=topic_hash) + protocol.ScopedTokenCommand(token=await self.base_token, topic=topic_hash) ) ack = await self._receive(filters.cmd(protocol.ScopedTokenAck)) assert ack.status == 0 @@ -230,7 +233,7 @@ async def notification_stream( ] = filters.ALL, ): if token is None: - token = self.base_token + token = await self.base_token async with self._filter([topic]), self._receive_stream( filters.chain( filters.chain(