Skip to content

Commit

Permalink
Persist async client session (#196)
Browse files Browse the repository at this point in the history
  • Loading branch information
logan-stytch authored Apr 9, 2024
1 parent 7b2d375 commit e95b160
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 30 deletions.
34 changes: 34 additions & 0 deletions bin/async_session_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#!/usr/bin/env python3

import asyncio
import os

import aiohttp
import stytch


async def main() -> None:
print("=== Test 1: custom session ===")
session = aiohttp.ClientSession(headers={"async-test": "true"})
client = stytch.Client(
project_id=os.environ["STYTCH_PROJECT_ID"],
secret=os.environ["STYTCH_SECRET"],
async_session=session,
)
resp = await client.users.search_async()
print(f"First user: {resp.results[0].user_id}")
await session.close()

print("\n\n=== Test 2: default session ===")
client = stytch.Client(
project_id=os.environ["STYTCH_PROJECT_ID"],
secret=os.environ["STYTCH_SECRET"],
)
resp = await client.users.search_async()
print(f"First user: {resp.results[0].user_id}")

print("\n\n=== Testing done ===")


if __name__ == "__main__":
asyncio.run(main())
7 changes: 0 additions & 7 deletions bin/generate-api.sh

This file was deleted.

10 changes: 9 additions & 1 deletion stytch/b2b/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from typing import Optional

import aiohttp
import jwt

from stytch.b2b.api.discovery import Discovery
Expand Down Expand Up @@ -40,8 +41,15 @@ def __init__(
secret: str,
environment: Optional[str] = None,
suppress_warnings: bool = False,
async_session: Optional[aiohttp.ClientSession] = None,
):
super().__init__(project_id, secret, environment, suppress_warnings)
super().__init__(
project_id=project_id,
secret=secret,
environment=environment,
suppress_warnings=suppress_warnings,
async_session=async_session,
)

policy_cache = PolicyCache(
RBAC(
Expand Down
10 changes: 9 additions & 1 deletion stytch/consumer/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from typing import Optional

import aiohttp
import jwt

from stytch.consumer.api.crypto_wallets import CryptoWallets
Expand Down Expand Up @@ -36,8 +37,15 @@ def __init__(
secret: str,
environment: Optional[str] = None,
suppress_warnings: bool = False,
async_session: Optional[aiohttp.ClientSession] = None,
):
super().__init__(project_id, secret, environment, suppress_warnings)
super().__init__(
project_id=project_id,
secret=secret,
environment=environment,
suppress_warnings=suppress_warnings,
async_session=async_session,
)

self.crypto_wallets = CryptoWallets(
api_base=self.api_base,
Expand Down
4 changes: 3 additions & 1 deletion stytch/core/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
from typing import Optional

import aiohttp
import jwt

from stytch.core.api_base import ApiBase
Expand All @@ -15,11 +16,12 @@ def __init__(
secret: str,
environment: Optional[str] = None,
suppress_warnings: bool = False,
async_session: Optional[aiohttp.ClientSession] = None,
):
base_url = self._env_url(project_id, environment, suppress_warnings)
self.api_base = ApiBase(base_url)
self.sync_client = SyncClient(project_id, secret)
self.async_client = AsyncClient(project_id, secret)
self.async_client = AsyncClient(project_id, secret, session=async_session)
self.jwks_client = self.get_jwks_client(project_id)

@abc.abstractmethod
Expand Down
56 changes: 37 additions & 19 deletions stytch/core/http/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python3

import asyncio
from dataclasses import dataclass
from typing import Any, Dict, Generic, Optional, TypeVar

Expand Down Expand Up @@ -86,9 +87,30 @@ def delete(


class AsyncClient(ClientBase):
def __init__(self, project_id: str, secret: str) -> None:
def __init__(
self,
project_id: str,
secret: str,
session: Optional[aiohttp.ClientSession] = None,
) -> None:
super().__init__(project_id, secret)
self.auth = aiohttp.BasicAuth(project_id, secret)
self._external_session = session is not None
self._session = session or aiohttp.ClientSession()

def __del__(self) -> None:
if self._external_session:
return

# If we're responsible for the session, close it now
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop.create_task(self._session.close())
else:
loop.run_until_complete(self._session.close())
except Exception:
pass

@classmethod
async def _response_from_request(
Expand All @@ -108,11 +130,10 @@ async def get(
) -> ResponseWithJson:
final_headers = self.headers.copy()
final_headers.update(headers or {})
async with aiohttp.ClientSession() as session:
resp = await session.get(
url, params=params, headers=final_headers, auth=self.auth
)
return await self._response_from_request(resp)
resp = await self._session.get(
url, params=params, headers=final_headers, auth=self.auth
)
return await self._response_from_request(resp)

async def post(
self,
Expand All @@ -122,11 +143,10 @@ async def post(
) -> ResponseWithJson:
final_headers = self.headers.copy()
final_headers.update(headers or {})
async with aiohttp.ClientSession() as session:
resp = await session.post(
url, json=json, headers=final_headers, auth=self.auth
)
return await self._response_from_request(resp)
resp = await self._session.post(
url, json=json, headers=final_headers, auth=self.auth
)
return await self._response_from_request(resp)

async def put(
self,
Expand All @@ -136,17 +156,15 @@ async def put(
) -> ResponseWithJson:
final_headers = self.headers.copy()
final_headers.update(headers or {})
async with aiohttp.ClientSession() as session:
resp = await session.put(
url, json=json, headers=final_headers, auth=self.auth
)
return await self._response_from_request(resp)
resp = await self._session.put(
url, json=json, headers=final_headers, auth=self.auth
)
return await self._response_from_request(resp)

async def delete(
self, url: str, headers: Optional[Dict[str, str]] = None
) -> ResponseWithJson:
final_headers = self.headers.copy()
final_headers.update(headers or {})
async with aiohttp.ClientSession() as session:
resp = await session.delete(url, headers=final_headers, auth=self.auth)
return await self._response_from_request(resp)
resp = await self._session.delete(url, headers=final_headers, auth=self.auth)
return await self._response_from_request(resp)
2 changes: 1 addition & 1 deletion stytch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "8.7.0"
__version__ = "9.0.0"

0 comments on commit e95b160

Please sign in to comment.