Skip to content

Commit

Permalink
feat: message buffering and TCP
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Bluhm <[email protected]>
  • Loading branch information
dbluhm committed Nov 9, 2023
1 parent a73e251 commit 656bf49
Show file tree
Hide file tree
Showing 4 changed files with 222 additions and 60 deletions.
18 changes: 12 additions & 6 deletions oid4vci/int/afj-wrapper/afj_wrapper/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
import asyncio
from pathlib import Path
from os import getenv

from .jsonrpc import JsonRpcClient, UnixSocketTransport
from .jsonrpc import JsonRpcClient, TCPSocketTransport

AFJ_HOST = getenv("AFJ_HOST", "localhost")
AFJ_PORT = int(getenv("AFJ_PORT", "3000"))


async def main():
"""Connect to AFJ."""
transport = UnixSocketTransport(
str(Path(__file__).parent.parent / "afj/agent.sock")
)
transport = TCPSocketTransport(AFJ_HOST, AFJ_PORT)
client = JsonRpcClient(transport)
async with transport, client:
result = await client.send_request("initialize")
result = await client.request(
"initialize",
endpoint="http://localhost:3001",
host="0.0.0.0",
port=3001,
)
print(result)


Expand Down
96 changes: 82 additions & 14 deletions oid4vci/int/afj-wrapper/afj_wrapper/jsonrpc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""JSON-RPC client implementation."""
from abc import ABC, abstractmethod
import asyncio
import json
from typing import Any, Dict, Optional, Protocol
Expand All @@ -8,20 +10,26 @@ class TransportProtocol(Protocol):
"""Transport protocol that transport class should implement."""

async def send(self, message: str) -> None:
"""Send a message to connected clients."""
...

async def receive(self) -> str:
"""Receive a message from connected clients."""
...


class JsonRpcError(Exception):
"""Exception raised when an error is returned by the server."""

def __init__(self, error: dict):
"""Initialize the exception."""
self.code = error.get("code")
self.message = error.get("message")
self.data = error.get("data")
super().__init__(self.__str__())

def __str__(self) -> str:
"""Convert the exception to a string."""
error_str = f"JSON-RPC Error {self.code}: {self.message}"
if isinstance(self.data, str):
# Represent newlines in data correctly when converting to string
Expand All @@ -34,6 +42,7 @@ class JsonRpcClient:
"""JSON-RPC client implementation."""

def __init__(self, transport: TransportProtocol) -> None:
"""Initialize the client."""
self.transport = transport
self.id_counter = 0
self.pending_calls: Dict[int, asyncio.Future] = {}
Expand Down Expand Up @@ -77,9 +86,14 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
"""Async context manager: close the client."""
await self.stop()

async def request(self, method: str, **params: Any) -> Any:
"""Send a request to the server."""
return await self.send_request(method, params)

async def send_request(
self, method: str, params: Optional[Dict[str, Any]] = None
) -> Any:
"""Send a request to the server."""
self.id_counter += 1
message_id = self.id_counter
request = {
Expand All @@ -96,6 +110,7 @@ async def send_request(
return await future

async def receive_response(self) -> None:
"""Receive responses from the server."""
while True:
response_str = await self.transport.receive()
response = json.loads(response_str)
Expand All @@ -113,43 +128,93 @@ async def receive_response(self) -> None:
future.set_exception(Exception("Invalid JSON-RPC response"))


class UnixSocketTransport:
"""Transport implementation that uses a Unix socket."""
class BaseSocketTransport(ABC):
"""Base transport implementation using asyncio sockets."""

def __init__(self, path: str) -> None:
self.path = path
def __init__(self) -> None:
"""Initialize the transport."""
self.reader: Optional[asyncio.StreamReader] = None
self.writer: Optional[asyncio.StreamWriter] = None

async def connect(self) -> None:
self.reader, self.writer = await asyncio.open_unix_connection(self.path)

async def send(self, message: str) -> None:
"""Send a message to connected clients."""
if self.writer is None:
raise Exception("Transport is not connected")
self.writer.write(message.encode())

message_bytes = message.encode()
header = f"length: {len(message_bytes)}\n".encode()
full_message = header + message_bytes

self.writer.write(full_message)
await self.writer.drain()

async def receive(self) -> str:
"""Receive a message from connected clients."""
if self.reader is None:
raise Exception("Transport is not connected")
data = await self.reader.read(4096) # Adjust buffer size as needed
return data.decode()

# Read the header first
header = await self.reader.readuntil(b"\n")
# Decode the header and extract the length
length_str = header.decode().strip()
if not length_str.startswith("length: "):
raise ValueError("Invalid message header received")

# Extract the length and convert to an integer
length = int(length_str.split("length: ")[1].strip())

# Now read the body of the message
body_bytes = await self.reader.readexactly(length)

# Decode and return the message body
return body_bytes.decode()

async def close(self) -> None:
"""Close the transport."""
if self.writer is not None:
self.writer.close()
await self.writer.wait_closed()

async def __aenter__(self) -> "UnixSocketTransport":
"""Async context manager: connect to the socket."""
async def __aenter__(self) -> "BaseSocketTransport":
"""Async context manager: connect to the server."""
await self.connect()
return self

async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
"""Async context manager: close the socket."""
"""Async context manager: close the transport."""
await self.close()

@abstractmethod
async def connect(self):
"""Connect to the server."""


class UnixSocketTransport(BaseSocketTransport):
"""Transport implementation that uses a Unix socket."""

def __init__(self, path: str) -> None:
"""Initialize the transport."""
super().__init__()
self.path = path

async def connect(self) -> None:
"""Connect to the server."""
self.reader, self.writer = await asyncio.open_unix_connection(self.path)


class TCPSocketTransport(BaseSocketTransport):
"""Transport implementation that uses a TCP socket."""

def __init__(self, host: str, port: int) -> None:
"""Initialize the transport."""
super().__init__()
self.host = host
self.port = port

async def connect(self) -> None:
"""Connect to the server."""
self.reader, self.writer = await asyncio.open_connection(self.host, self.port)


async def main():
"""Usage example."""
Expand All @@ -158,10 +223,13 @@ class DummyTransport:
"""Dummy transport implementation that prints messages to the console."""

async def send(self, message: str) -> None:
"""Send a message to connected clients."""
print(f"Sending message: {message}")

async def receive(self) -> str:
# Simulate a response (In a real implementation, you would receive messages from a server)
"""Receive a message from connected clients."""
# Simulate a response (In a real implementation, you would receive
# messages from a server)
await asyncio.sleep(1)
return json.dumps({"jsonrpc": "2.0", "result": "pong", "id": 1})

Expand Down
7 changes: 4 additions & 3 deletions oid4vci/int/afj/src/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ import {
import { agentDependencies, HttpInboundTransport } from '@aries-framework/node';
import { AskarModule } from '@aries-framework/askar';
import { ariesAskar } from '@hyperledger/aries-askar-nodejs';
import { SocketServer } from './server';
import { TCPSocketServer } from './server';

let agent: Agent | null;
const server = new SocketServer({
socketPath: process.env.SOCKET_PATH || '/tmp/agent.sock',
const server = new TCPSocketServer({
host: process.env.AFJ_HOST || '0.0.0.0',
port: parseInt(process.env.AFJ_PORT || '3000'),
})

const rpc = new JSONRPCServerAndClient(
Expand Down
Loading

0 comments on commit 656bf49

Please sign in to comment.