diff --git a/.env.example b/.env.example index 6c819320c..9249e6a73 100644 --- a/.env.example +++ b/.env.example @@ -19,17 +19,7 @@ RPCHOST = 'jsonrpc' RPCPORT = '4000' # (the port debugpy is listening on) RPCDEBUGPORT = '4678' - -# GenVM server details -GENVMPROTOCOL = 'http' -GENVMHOST = 'genvm' -GENVMPORT = '6000' -# Location of file excuted inside the GenVM -GENVMCONLOC = '/tmp' -# TODO: Will be removed with the new logging -GENVMDEBUG = 1 -# (the port debugpy is listening on) -GENVMDEBUGPORT = '6678' +GENVM_BIN = "/genvm/bin" # (enables debuggin in VScode) VSCODEDEBUG = "false" # "true" or "false" @@ -44,6 +34,7 @@ OLAMAPORT = '11434' WEBREQUESTPROTOCOL = 'http' WEBREQUESTHOST = 'webrequest' WEBREQUESTPORT = '5000' +WEBREQUESTSELENIUMPORT = '5001' # If you want to use OpenAI add your key here OPENAIKEY = '' diff --git a/.github/workflows/unit-tests-pr.yml b/.github/workflows/unit-tests-pr.yml index c9fdb29b8..a599a0b82 100644 --- a/.github/workflows/unit-tests-pr.yml +++ b/.github/workflows/unit-tests-pr.yml @@ -28,7 +28,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: 3.11 + python-version: 3.12.4 cache: pip - run: pip install -r backend/protocol_rpc/requirements.txt - run: pip install pytest-cov diff --git a/backend/consensus/base.py b/backend/consensus/base.py index a6597474e..f2798e7f1 100644 --- a/backend/consensus/base.py +++ b/backend/consensus/base.py @@ -27,7 +27,7 @@ Validator, ) from backend.node.base import Node -from backend.node.genvm.types import ExecutionMode, Receipt, Vote +from backend.node.types import ExecutionMode, Receipt, Vote, ExecutionResultStatus from backend.protocol_rpc.message_handler.base import MessageHandler from backend.protocol_rpc.message_handler.types import ( LogEvent, @@ -110,18 +110,34 @@ async def _run_consensus(self): for queue in [q for q in self.queues.values() if not q.empty()]: # sessions cannot be shared between coroutines, we need to create a new session for each coroutine # https://docs.sqlalchemy.org/en/20/orm/session_basics.html#is-the-session-thread-safe-is-asyncsession-safe-to-share-in-concurrent-tasks - transaction = await queue.get() + transaction: Transaction = await queue.get() with self.get_session() as session: + def contract_snapshot_factory( + contract_address, + session=session, + transaction=transaction, + ): + if ( + transaction.type == TransactionType.DEPLOY_CONTRACT + and contract_address == transaction.to_address + ): + ret = ContractSnapshot(None, session) + ret.contract_address = transaction.to_address + ret.contract_code = transaction.data[ + "contract_code" + ] + ret.encoded_state = {} + return ret + return ContractSnapshot(contract_address, session) + async def exec_transaction_with_session_handling(): await self.exec_transaction( transaction, TransactionsProcessor(session), ChainSnapshot(session), AccountsManager(session), - lambda contract_address, session=session: ContractSnapshot( - contract_address, session - ), + contract_snapshot_factory, ) session.commit() @@ -211,13 +227,17 @@ async def exec_transaction( num_validators = len(remaining_validators) + 1 - contract_snapshot = contract_snapshot_factory(transaction.to_address) + contract_snapshot_supplier = lambda: contract_snapshot_factory( + transaction.to_address + ) + + leaders_contract_snapshot = contract_snapshot_supplier() # Create Leader leader_node = node_factory( leader, ExecutionMode.LEADER, - contract_snapshot, + leaders_contract_snapshot, None, msg_handler, contract_snapshot_factory, @@ -225,6 +245,7 @@ async def exec_transaction( # Leader executes transaction leader_receipt = await leader_node.exec_transaction(transaction) + votes = {leader["address"]: leader_receipt.vote.value} consensus_data.votes = votes consensus_data.leader_receipt = leader_receipt @@ -246,7 +267,7 @@ async def exec_transaction( node_factory( validator, ExecutionMode.VALIDATOR, - contract_snapshot, + contract_snapshot_supplier(), leader_receipt, msg_handler, contract_snapshot_factory, @@ -286,7 +307,7 @@ async def exec_transaction( if ( len([vote for vote in votes.values() if vote == Vote.AGREE.value]) - >= num_validators // 2 + >= (num_validators + 1) // 2 ): break # Consensus reached @@ -335,30 +356,33 @@ async def exec_transaction( ) ) - # Register contract if it is a new contract - if transaction.type == TransactionType.DEPLOY_CONTRACT: - new_contract = { - "id": transaction.data["contract_address"], - "data": { - "state": leader_receipt.contract_state, - "code": transaction.data["contract_code"], - }, - } - contract_snapshot.register_contract(new_contract) - - msg_handler.send_message( - LogEvent( - "deployed_contract", - EventType.SUCCESS, - EventScope.GENVM, - "Contract deployed", - new_contract, + if leader_receipt.execution_result == ExecutionResultStatus.SUCCESS: + # Register contract if it is a new contract + if transaction.type == TransactionType.DEPLOY_CONTRACT: + new_contract = { + "id": transaction.data["contract_address"], + "data": { + "state": leader_receipt.contract_state, + "code": transaction.data["contract_code"], + }, + } + leaders_contract_snapshot.register_contract(new_contract) + + msg_handler.send_message( + LogEvent( + "deployed_contract", + EventType.SUCCESS, + EventScope.GENVM, + "Contract deployed", + new_contract, + ) ) - ) - # Update contract state if it is an existing contract - else: - contract_snapshot.update_contract_state(leader_receipt.contract_state) + # Update contract state if it is an existing contract + else: + leaders_contract_snapshot.update_contract_state( + leader_receipt.contract_state + ) # Finalize transaction consensus_data.final = True diff --git a/backend/database_handler/contract_snapshot.py b/backend/database_handler/contract_snapshot.py index ef6c1b6eb..5570ad5ed 100644 --- a/backend/database_handler/contract_snapshot.py +++ b/backend/database_handler/contract_snapshot.py @@ -9,10 +9,14 @@ class ContractSnapshot: """ Warning: if you initialize this class with a contract_address: - The contract_address must exist in the database. - - `self.contract_data`, `self.contract_code` and `self.cencoded_state` will be loaded from the database **only once** at initialization. + - `self.contract_data`, `self.contract_code` and `self.encoded_state` will be loaded from the database **only once** at initialization. """ - def __init__(self, contract_address: str, session: Session): + contract_address: str + contract_code: str + encoded_state: dict[str, dict[str, str]] + + def __init__(self, contract_address: str | None, session: Session): self.session = session if contract_address is not None: @@ -46,9 +50,9 @@ def register_contract(self, contract: dict): current_contract.data = contract["data"] self.session.commit() - def update_contract_state(self, new_state: str): + def update_contract_state(self, new_state: dict[str, str]): """Update the state of the contract in the database.""" - new_contract_nada = { + new_contract_data = { "code": self.contract_data["code"], "state": new_state, } @@ -56,5 +60,5 @@ def update_contract_state(self, new_state: str): contract = ( self.session.query(CurrentState).filter_by(id=self.contract_address).one() ) - contract.data = new_contract_nada + contract.data = new_contract_data self.session.commit() diff --git a/backend/database_handler/llm_providers.py b/backend/database_handler/llm_providers.py index 15a9a054b..73d2b8791 100644 --- a/backend/database_handler/llm_providers.py +++ b/backend/database_handler/llm_providers.py @@ -2,7 +2,7 @@ from backend.node.create_nodes.providers import get_default_providers from .models import LLMProviderDBModel from sqlalchemy.orm import Session -from backend.node.genvm.llms import get_llm_plugin +from backend.llms import get_llm_plugin import pprint @@ -26,7 +26,7 @@ def get_all(self) -> list[LLMProvider]: for provider in self.session.query(LLMProviderDBModel).all() ] - def get_all_dict(self) -> list[dict]: + async def get_all_dict(self) -> list[dict]: providers = self.session.query(LLMProviderDBModel).all() result = [] @@ -34,12 +34,12 @@ def get_all_dict(self) -> list[dict]: domain_provider = _to_domain(provider) provider_dict = domain_provider.__dict__ - plugin = get_llm_plugin( + plugin = await get_llm_plugin( domain_provider.plugin, domain_provider.plugin_config ) - provider_dict["is_available"] = plugin.is_available() - provider_dict["is_model_available"] = plugin.is_model_available( + provider_dict["is_available"] = await plugin.is_available() + provider_dict["is_model_available"] = await plugin.is_model_available( domain_provider.model ) diff --git a/backend/database_handler/migration/requirements.txt b/backend/database_handler/migration/requirements.txt index 8a4a76f55..5724cef1d 100644 --- a/backend/database_handler/migration/requirements.txt +++ b/backend/database_handler/migration/requirements.txt @@ -3,4 +3,9 @@ alembic==1.13.2 psycopg2-binary==2.9.9 rlp eth-utils -eth-hash[pycryptodome] \ No newline at end of file +eth-hash[pycryptodome] +jsonschema==4.23.0 +aiohttp==3.10.5 +openai==1.47.0 +anthropic==0.34.2 +python-dotenv==1.0.1 \ No newline at end of file diff --git a/backend/database_handler/migration/versions/1daddff774b2_drop_from_constraint_on_rollup_tx.py b/backend/database_handler/migration/versions/1daddff774b2_drop_from_constraint_on_rollup_tx.py new file mode 100644 index 000000000..eaf7b12b6 --- /dev/null +++ b/backend/database_handler/migration/versions/1daddff774b2_drop_from_constraint_on_rollup_tx.py @@ -0,0 +1,37 @@ +"""drop_from_constraint_on_rollup_tx + +Revision ID: 1daddff774b2 +Revises: 579e86111b36 +Create Date: 2024-11-21 16:25:11.469033 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "1daddff774b2" +down_revision: Union[str, None] = "579e86111b36" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.alter_column( + "rollup_transactions", + "from_", + existing_type=sa.String(length=255), + nullable=True, + ) + + +def downgrade() -> None: + op.alter_column( + "rollup_transactions", + "from_", + existing_type=sa.String(length=255), + nullable=False, + ) diff --git a/backend/database_handler/types.py b/backend/database_handler/types.py index 6f318fd8e..82908c7e7 100644 --- a/backend/database_handler/types.py +++ b/backend/database_handler/types.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from backend.node.genvm.types import Receipt +from backend.node.types import Receipt @dataclass diff --git a/backend/domain/types.py b/backend/domain/types.py index b93ec5b94..93767c580 100644 --- a/backend/domain/types.py +++ b/backend/domain/types.py @@ -4,7 +4,7 @@ from dataclasses import dataclass import decimal -from enum import Enum +from enum import Enum, IntEnum from backend.database_handler.models import TransactionStatus @@ -54,7 +54,7 @@ def to_dict(self): return result -class TransactionType(Enum): +class TransactionType(IntEnum): SEND = 0 DEPLOY_CONTRACT = 1 RUN_CONTRACT = 2 @@ -65,8 +65,8 @@ class Transaction: hash: str status: TransactionStatus type: TransactionType - from_address: str | None = None - to_address: str | None = None + from_address: str | None + to_address: str | None input_data: dict | None = None data: dict | None = None consensus_data: dict | None = None diff --git a/backend/llms.py b/backend/llms.py new file mode 100644 index 000000000..eaa5d3f9f --- /dev/null +++ b/backend/llms.py @@ -0,0 +1,70 @@ +import asyncio +import aiohttp +from typing import Protocol, Any +import os +import json + +_webrequest_url: str = ( + os.environ["WEBREQUESTPROTOCOL"] + + "://" + + os.environ["WEBREQUESTHOST"] + + ":" + + os.environ["WEBREQUESTPORT"] +) + + +class Plugin(Protocol): + def __init__(self, plugin_config: dict): ... + + async def call( + self, + node_config: dict, + prompt: str, + regex: str | None, + ) -> str: ... + + async def is_available(self) -> bool: ... + + async def is_model_available(self, model: str) -> bool: ... + + +async def _call_jsonrpc(function_name: str, *args) -> Any: + payload = { + "jsonrpc": "2.0", + "method": function_name, + "params": [*args], + "id": 1, + } + async with aiohttp.ClientSession() as session: + async with session.post( + _webrequest_url + "/api", + data=json.dumps(payload), + headers={"Content-Type": "application/json"}, + ) as response: + res = json.loads(await response.text()) + return res["result"]["response"] + + +class _RemotePlugin(Plugin): + def __init__(self, id: str): + self._id = id + + async def call( + self, + node_config: dict, + prompt: str, + regex: str | None, + ) -> str: + return await _call_jsonrpc( + "llm_plugin_call", self._id, node_config, prompt, regex + ) + + async def is_available(self) -> bool: + return await _call_jsonrpc("llm_plugin_is_available", self._id) + + async def is_model_available(self, model: str) -> bool: + return await _call_jsonrpc("llm_plugin_is_model_available", self._id, model) + + +async def get_llm_plugin(plugin: str, plugin_config: dict) -> Plugin: + return _RemotePlugin(await _call_jsonrpc("llm_plugin_get", plugin, plugin_config)) diff --git a/backend/node/base.py b/backend/node/base.py index e53ea0424..3485ec506 100644 --- a/backend/node/base.py +++ b/backend/node/base.py @@ -4,39 +4,103 @@ import json import base64 from typing import Callable, Optional +import typing +import collections.abc +import os from backend.domain.types import Validator, Transaction, TransactionType -from backend.node.genvm.base import GenVM +from backend.protocol_rpc.message_handler.types import LogEvent, EventType, EventScope +import backend.node.genvm.base as genvmbase +import backend.node.genvm.origin.host_fns as genvmconsts from backend.database_handler.contract_snapshot import ContractSnapshot -from backend.node.genvm.types import Receipt, ExecutionMode, Vote +from backend.node.types import Receipt, ExecutionMode, Vote, ExecutionResultStatus from backend.protocol_rpc.message_handler.base import MessageHandler +from .types import Address + + +class _SnapshotView(genvmbase.StateProxy): + def __init__( + self, + snapshot: ContractSnapshot, + snapshot_factory: typing.Callable[[str], ContractSnapshot], + readonly: bool, + ): + self.contract_address = Address(snapshot.contract_address) + self.snapshot = snapshot + self.snapshot_factory = snapshot_factory + self.cached = {} + self.readonly = readonly + + def _get_snapshot(self, addr: Address) -> ContractSnapshot: + if addr == self.contract_address: + return self.snapshot + res = self.cached.get(addr) + if res is not None: + return res + res = self.snapshot_factory(addr.as_hex) + self.cached[addr] = res + return res + + def get_code(self, addr: Address) -> bytes: + return self._get_snapshot(addr).contract_code.encode("utf-8") + + def storage_read( + self, account: Address, slot: bytes, index: int, le: int, / + ) -> tuple[bytes, int]: + snap = self._get_snapshot(account) + for_acc = snap.encoded_state.setdefault(account.as_b64, {}) + for_slot = for_acc.setdefault(base64.b64encode(slot).decode("ascii"), "") + data = bytearray(base64.b64decode(for_slot)) + data.extend(b"\x00" * (index + le - len(data))) + return data[index : index + le] + + def storage_write( + self, + account: Address, + slot: bytes, + index: int, + got: collections.abc.Buffer, + /, + ) -> None: + assert account == self.contract_address + assert not self.readonly + snap = self._get_snapshot(account) + for_acc = snap.encoded_state.setdefault(account.as_b64, {}) + slot_id = base64.b64encode(slot).decode("ascii") + for_slot = for_acc.setdefault(slot_id, "") + data = bytearray(base64.b64decode(for_slot)) + mem = memoryview(got) + data.extend(b"\x00" * (index + len(mem) - len(data))) + data[index : index + len(mem)] = mem + for_acc[slot_id] = base64.b64encode(data).decode("utf-8") + class Node: def __init__( self, - contract_snapshot: ContractSnapshot, + contract_snapshot: ContractSnapshot | None, validator_mode: ExecutionMode, validator: Validator, - contract_snapshot_factory: Callable[[str], ContractSnapshot], + contract_snapshot_factory: Callable[[str], ContractSnapshot] | None, leader_receipt: Optional[Receipt] = None, - msg_handler: MessageHandler = None, + msg_handler: MessageHandler | None = None, ): + self.contract_snapshot = contract_snapshot self.validator_mode = validator_mode + self.validator = validator self.address = validator.address self.leader_receipt = leader_receipt self.msg_handler = msg_handler self.contract_snapshot_factory = contract_snapshot_factory - self.genvm = GenVM( - contract_snapshot, - self.validator_mode, - validator.to_dict(), - contract_snapshot_factory, - msg_handler, - ) + + def _create_genvm(self) -> genvmbase.IGenVM: + return genvmbase.GenVMHost() async def exec_transaction(self, transaction: Transaction) -> Receipt: + assert transaction.data is not None transaction_data = transaction.data + assert transaction.from_address is not None if transaction.type == TransactionType.DEPLOY_CONTRACT: calldata = base64.b64decode(transaction_data["calldata"]) receipt = await self.deploy_contract( @@ -51,16 +115,40 @@ async def exec_transaction(self, transaction: Transaction) -> Receipt: calldata, ) else: - receipt = ... + raise Exception(f"unknown transaction type {transaction.type}") return receipt - def parse_transaction_execution_receipt(self, receipt: Receipt) -> Receipt: - if ( - self.validator_mode == ExecutionMode.LEADER - or self.leader_receipt.contract_state == receipt.contract_state + def _set_vote(self, receipt: Receipt) -> Receipt: + + def try_decode_error(e: Exception) -> tuple[str, ...]: + # FIXME(kp2pml30): #622 I am not sure that we should compare args, + # because if traceback gets there, + # it will have different ones for validator and nodes + if len(e.args) != 2: + return () + typ = e.args[0] + if typ == "rollback": + return ("rollback", e.args[1]) + if typ == "error": + return ("error",) + return () + + leader_receipt = self.leader_receipt + if self.validator_mode == ExecutionMode.LEADER: + receipt.vote = Vote.AGREE + elif ( + leader_receipt.execution_result == receipt.execution_result + and (leader_receipt.error is None) == (receipt.error is None) + and ( + leader_receipt.error is None + or try_decode_error(leader_receipt.error) + == try_decode_error(receipt.error) + ) + and leader_receipt.returned == receipt.returned + and leader_receipt.contract_state == receipt.contract_state + and leader_receipt.pending_transactions == receipt.pending_transactions ): receipt.vote = Vote.AGREE - else: receipt.vote = Vote.DISAGREE @@ -72,32 +160,128 @@ async def deploy_contract( code_to_deploy: str, calldata: bytes, ) -> Receipt: - receipt = await self.genvm.deploy_contract( - from_address, code_to_deploy, calldata, self.leader_receipt + assert self.contract_snapshot is not None + self.contract_snapshot.contract_code = code_to_deploy + return await self._run_genvm( + from_address, calldata, readonly=False, is_init=True ) - return self.parse_transaction_execution_receipt(receipt) async def run_contract(self, from_address: str, calldata: bytes) -> Receipt: - receipt = await self.genvm.run_contract( - from_address, calldata, self.leader_receipt + return await self._run_genvm( + from_address, calldata, readonly=False, is_init=False ) - return self.parse_transaction_execution_receipt(receipt) + async def get_contract_data( + self, + from_address: str, + calldata: bytes, + ) -> Receipt: + return await self._run_genvm( + from_address, calldata, readonly=True, is_init=False + ) + + async def get_contract_schema(self, code: str) -> str: + genvm = self._create_genvm() + return await genvm.get_contract_schema(code.encode("utf-8")) - def get_contract_data( + async def _run_genvm( self, - code: str, - state: str, + from_address: str, calldata: bytes, - ): - result = self.genvm.get_contract_data( - code, - state, - calldata, - self.contract_snapshot_factory, + *, + readonly: bool, + is_init: bool, + ) -> Receipt: + genvm = self._create_genvm() + leader_res: None | dict[int, bytes] + if self.leader_receipt is None: + leader_res = None + else: + leader_res = { + k: base64.b64decode(v) + for k, v in self.leader_receipt.eq_outputs.items() + } + assert self.contract_snapshot is not None + assert self.contract_snapshot_factory is not None + config = { + "modules": [ + { + "path": "${genvmRoot}/lib/genvm-modules/", + "id": "llm", + "config": { + "host": f"{os.environ['WEBREQUESTPROTOCOL']}://{os.environ['WEBREQUESTHOST']}:{os.environ['WEBREQUESTPORT']}", + "provider": "simulator", + "model": json.dumps(self.validator.llmprovider.__dict__), + }, + }, + { + "path": "${genvmRoot}/lib/genvm-modules/", + "id": "web", + "config": { + "host": f"{os.environ['WEBREQUESTPROTOCOL']}://{os.environ['WEBREQUESTHOST']}:{os.environ['WEBREQUESTSELENIUMPORT']}" + }, + }, + ] + } + snapshot_view = _SnapshotView( + self.contract_snapshot, self.contract_snapshot_factory, readonly ) + res = await genvm.run_contract( + snapshot_view, + contract_address=Address(self.contract_snapshot.contract_address), + from_address=Address(from_address), + calldata_raw=calldata, + is_init=is_init, + leader_results=leader_res, + config=json.dumps(config), + ) + base_receipt = { + "class_name": "", + "calldata": calldata, + "mode": self.validator_mode, + "node_config": self.validator.to_dict(), + } + if self.msg_handler is not None: + self.msg_handler.send_message( + LogEvent( + name="execution_finished", + type=EventType.INFO, + scope=EventScope.GENVM, + message="execution finished", + data={ + "result": f"{res.result!r}", + "stdout": res.stdout, + "stderr": res.stderr, + }, + ) + ) - return result + returned = None + error = None + exec_result_code = ExecutionResultStatus.ERROR - def get_contract_schema(self, code: str): - return GenVM.get_contract_schema(code) + if isinstance(res.result, genvmbase.ExecutionFail): + error = Exception("error", repr(res.result)) + elif isinstance(res.result, genvmbase.ExecutionRollback): + error = Exception("rollback", res.result.message) + else: + assert isinstance(res.result, genvmbase.ExecutionReturn) + returned = res.result.ret + exec_result_code = ExecutionResultStatus.SUCCESS + + return self._set_vote( + Receipt( + returned=returned, + error=error, + gas_used=0, + eq_outputs={ + k: base64.b64encode(v).decode("ascii") + for k, v in res.eq_outputs.items() + }, + pending_transactions=res.pending_transactions, + vote=None, + execution_result=exec_result_code, + contract_state=self.contract_snapshot.encoded_state, + **base_receipt, + ) + ) diff --git a/backend/node/create_nodes/create_nodes.py b/backend/node/create_nodes/create_nodes.py index d6b55c587..4ad62cf78 100644 --- a/backend/node/create_nodes/create_nodes.py +++ b/backend/node/create_nodes/create_nodes.py @@ -1,23 +1,23 @@ import secrets -from typing import Callable, List +from typing import Callable, Awaitable from numpy.random import default_rng from dotenv import load_dotenv from backend.domain.types import LLMProvider -from backend.node.genvm.llms import Plugin +from backend.llms import Plugin load_dotenv() rng = default_rng(secrets.randbits(128)) -def random_validator_config( - get_stored_providers: Callable[[], List[LLMProvider]], - get_llm_plugin: Callable[[str, dict], Plugin], +async def random_validator_config( + get_stored_providers: Callable[[], list[LLMProvider]], + get_llm_plugin: Callable[[str, dict], Awaitable[Plugin]], limit_providers: set[str] = None, limit_models: set[str] = None, amount: int = 1, -) -> List[LLMProvider]: +) -> list[LLMProvider]: providers_to_use = get_stored_providers() if limit_providers: @@ -37,17 +37,19 @@ def random_validator_config( f"Requested providers '{limit_providers}' do not match any stored providers. Please review your stored providers." ) - def filter_by_available(provider: LLMProvider) -> bool: - plugin = get_llm_plugin(provider.plugin, provider.plugin_config) - if not plugin.is_available(): + async def filter_by_available(provider: LLMProvider) -> bool: + plugin = await get_llm_plugin(provider.plugin, provider.plugin_config) + if not await plugin.is_available(): return False - if not plugin.is_model_available(provider.model): + if not await plugin.is_model_available(provider.model): return False return True - providers_to_use = list(filter(filter_by_available, providers_to_use)) + providers_to_use = [ + plug for plug in providers_to_use if await filter_by_available(plug) + ] if not providers_to_use: raise Exception("No providers available.") diff --git a/backend/node/genvm/base.py b/backend/node/genvm/base.py index 483f548da..fbfddb3b1 100644 --- a/backend/node/genvm/base.py +++ b/backend/node/genvm/base.py @@ -1,567 +1,333 @@ # backend/node/genvm/base.py -from functools import partial -import inspect -import re -import pickle +__all__ = ("IGenVM", "GenVMHost") + +import typing +import tempfile +from pathlib import Path +import shutil +import json import base64 -import sys +import asyncio +import socket +import backend.node.genvm.origin.base_host as genvmhost +import collections.abc +import functools import traceback -import io -from contextlib import contextmanager, redirect_stderr, redirect_stdout -from typing import Any, Callable - -from backend.database_handler.contract_snapshot import ContractSnapshot -from backend.node.genvm.equivalence_principle import EquivalencePrinciple -from backend.node.genvm.code_enforcement import code_enforcement_check -from backend.node.genvm.std.vector_store import VectorStore -from backend.node.genvm.types import ( + +from backend.node.types import ( PendingTransaction, - Receipt, - ExecutionResultStatus, - ExecutionMode, -) -from backend.protocol_rpc.message_handler.base import MessageHandler -from backend.protocol_rpc.message_handler.types import ( - LogEvent, - EventType, - EventScope, + Address, ) +import backend.node.genvm.origin.calldata as calldata +from dataclasses import dataclass -from .calldata import ( - decode as calldata_decode, - encode as calldata_encode, - to_str as calldata_repr, -) +from backend.node.genvm.config import get_genvm_path -@contextmanager -def safe_globals(override_globals: dict[str] = None): - old_globals = globals().copy() - globals().update( - { - "contract_runner": None, - "VectorStore": VectorStore, - } - ) - if override_globals: - globals().update(override_globals) - try: - yield - finally: - globals().clear() - globals().update(old_globals) +# Unexpected error +# - user contract crushed (for instance from integer division by zero) +# - user contract is not a contract (genvm can't run it) +# - os-level error occurred (i.e. socket to genvm got closed) +@dataclass +class ExecutionFail: + exc: list[Exception] + def __repr__(self) -> str: + lines: list[str] = ["ExecutionFail:\n"] + for exc_lines in (traceback.format_exception(e) for e in self.exc): + lines.extend(exc_lines) + return "".join(lines) -_FAKE_DECODED_DATA = object() +@dataclass +class ExecutionReturn: + ret: bytes -def _calldata_to_str(raw: bytes, decoded): - if decoded is _FAKE_DECODED_DATA: - return str(base64.b64encode(raw), encoding="ascii") - return calldata_repr(decoded) +# Expected contract error (i.e. handled invalid arguments) +@dataclass +class ExecutionRollback: + message: str -class ContractRunner: - def __init__( - self, - mode: ExecutionMode, - node_config: dict, - contract_snapshot_factory: Callable[[str], ContractSnapshot], - ): - self.mode = mode # if the node is acting as "validator" or "leader" - self.node_config = node_config # provider, model, config, stake - self.from_address = None # the address of the transaction sender - self.gas_used = 0 # the amount of gas used by the contract - self.eq_num = 0 # keeps track of the eq principle number being executed - self.eq_outputs = { - ExecutionMode.LEADER.value: {} - } # the eq principle outputs for the leader and validators - self.contract_snapshot_factory = contract_snapshot_factory - - -class GenVM: - eq_principle = EquivalencePrinciple - - def __init__( - self, - snapshot: ContractSnapshot, - validator_mode: str, - validator: dict, - contract_snapshot_factory: Callable[[str], ContractSnapshot], - msg_handler: MessageHandler = None, - ): - self.snapshot = snapshot - self.validator_mode = validator_mode - self.msg_handler = msg_handler - self.contract_runner = ContractRunner( - validator_mode, validator, contract_snapshot_factory - ) - self.pending_transactions: list[PendingTransaction] = [] - @staticmethod - def _get_contract_class_name(contract_code: str) -> str: - pattern = r"class (\w+)\(IContract\):" - matches = re.findall(pattern, contract_code) - if len(matches) == 0: - raise Exception("No class name found") - return matches[0] +@dataclass +class ExecutionResult: + result: ExecutionReturn | ExecutionRollback | ExecutionFail + eq_outputs: dict[int, bytes] + pending_transactions: list[PendingTransaction] + stdout: str + stderr: str - def _generate_receipt( - self, - class_name: str, - encoded_object: str, - calldata: bytes, - execution_result: ExecutionResultStatus, - error: Exception, - ) -> Receipt: - return Receipt( - class_name=class_name, - calldata=calldata, - gas_used=self.contract_runner.gas_used, - mode=self.contract_runner.mode, - contract_state=encoded_object, - node_config=self.contract_runner.node_config, - eq_outputs=self.contract_runner.eq_outputs, - execution_result=execution_result, - error=error, - pending_transactions=self.pending_transactions, - ) - async def deploy_contract( +# Interface for accessing the blockchain state, it is needed to not tangle current (awfully unoptimized) +# storage format with the genvm source code +class StateProxy(typing.Protocol): + def storage_read( + self, account: Address, slot: bytes, index: int, le: int, / + ) -> bytes: ... + def storage_write( self, - from_address: str, - code_to_deploy: str, - calldata_raw: bytes, - leader_receipt: Receipt | None, - ): - class_name = self._get_contract_class_name(code_to_deploy) - code_enforcement_check(code_to_deploy, class_name) - self.contract_runner.from_address = from_address - execution_result = ExecutionResultStatus.SUCCESS - error = None - - self.eq_principle.contract_runner = self.contract_runner - if self.contract_runner.mode == ExecutionMode.VALIDATOR: - self.contract_runner.eq_outputs[ExecutionMode.LEADER.value] = ( - leader_receipt.eq_outputs[ExecutionMode.LEADER.value] - ) - - # Buffers to capture stdout and stderr - stdout_buffer = io.StringIO() - - calldata = _FAKE_DECODED_DATA - - with redirect_stdout(stdout_buffer), safe_globals( - { - "contract_runner": self.contract_runner, - "Contract": partial( - ExternalContract, - self.contract_runner.contract_snapshot_factory, - lambda x: self.pending_transactions.append(x), - self, - ), - } - ): - local_namespace = {} - exec(code_to_deploy, globals(), local_namespace) - - contract_class = local_namespace[class_name] - - # Ensure the class and other necessary elements are in the global local_namespace if needed - for name, value in local_namespace.items(): - globals()[name] = value - - module = sys.modules[__name__] - setattr(module, class_name, contract_class) - - encoded_pickled_object = None # Default value in order to have something to return in case of error - try: - calldata = calldata_decode(calldata_raw) - ctor_args = calldata["args"] - if not isinstance(ctor_args, list): - raise Exception( - f"Invalid arguments, list expected, got {ctor_args}" - ) - # Manual instantiation of the class is done to handle async __init__ methods - current_contract = contract_class.__new__(contract_class, *ctor_args) - ctor_method = getattr(contract_class, "__init__") - if inspect.iscoroutinefunction(ctor_method): - await ctor_method(current_contract, *ctor_args) - else: - ctor_method(current_contract, *ctor_args) - pickled_object = pickle.dumps(current_contract) - encoded_pickled_object = base64.b64encode(pickled_object).decode( - "utf-8" - ) - - except Exception as e: - trace = traceback.format_exc() - error = e - print("Error deploying contract", error) - print(trace) - execution_result = ExecutionResultStatus.ERROR - self.msg_handler.send_message( - LogEvent( - "contract_deployment_failed", - EventType.ERROR, - EventScope.GENVM, - "Error deploying contract: " + str(error), - { - "error": str(error), - "traceback": f"\n{trace}", - }, - ) - ) + account: Address, + slot: bytes, + index: int, + got: collections.abc.Buffer, + /, + ) -> None: ... + def get_code(self, addr: Address) -> bytes: ... - ## Clean up - delattr(module, class_name) - - if self.contract_runner.mode == ExecutionMode.LEADER: - captured_stdout = stdout_buffer.getvalue() - - if captured_stdout: - print(captured_stdout) - self.send_stdout(captured_stdout, self.msg_handler) - - if execution_result == ExecutionResultStatus.SUCCESS: - self.msg_handler.send_message( - LogEvent( - "deploying_contract", - EventType.SUCCESS, - EventScope.GENVM, - "Deploying contract", - { - "calldata": _calldata_to_str(calldata_raw, calldata), - "output": captured_stdout, - }, - ) - ) - - return self._generate_receipt( - class_name, - encoded_pickled_object, - calldata_raw, - execution_result, - error, - ) +# GenVM protocol just in case it is needed for mocks or bringing back the old one +class IGenVM(typing.Protocol): async def run_contract( self, - from_address: str, + state: StateProxy, + *, + from_address: Address, + contract_address: Address, calldata_raw: bytes, - leader_receipt: Receipt | None, - ) -> Receipt: - self.contract_runner.from_address = from_address - contract_code = self.snapshot.contract_code - execution_result = ExecutionResultStatus.SUCCESS - error = None - - self.eq_principle.contract_runner = self.contract_runner - - if self.contract_runner.mode == ExecutionMode.VALIDATOR: - self.contract_runner.eq_outputs[ExecutionMode.LEADER.value] = ( - leader_receipt.eq_outputs[ExecutionMode.LEADER.value] - ) - - # Buffers to capture stdout and stderr - stdout_buffer = io.StringIO() - - calldata = _FAKE_DECODED_DATA - - with redirect_stdout(stdout_buffer), safe_globals( - { - "contract_runner": self.contract_runner, - "Contract": partial( - ExternalContract, - self.contract_runner.contract_snapshot_factory, - lambda x: self.pending_transactions.append(x), - self, - ), - } - ): - local_namespace = {} - # Execute the code to ensure all classes are defined in the local_namespace - exec(contract_code, globals(), local_namespace) - - # Ensure the class and other necessary elements are in the global local_namespace if needed - globals().update(local_namespace) - - contract_encoded_state = self.snapshot.encoded_state - decoded_pickled_object = base64.b64decode(contract_encoded_state) - current_contract = pickle.loads(decoded_pickled_object) - - method_name = "" - method_args = [] - try: - calldata = calldata_decode(calldata_raw) - method_name = calldata["method"] - method_args = calldata["args"] - if not isinstance(method_args, list): - raise Exception( - f"Invalid arguments, list expected, got {method_args}" - ) - function_to_run = getattr(current_contract, method_name) - if inspect.iscoroutinefunction(function_to_run): - await function_to_run(*method_args) - else: - function_to_run(*method_args) - except Exception as e: - trace = traceback.format_exc() - error = e - print("Error executing method", error) - print(trace) - execution_result = ExecutionResultStatus.ERROR - self.msg_handler.send_message( - LogEvent( - "write_contract_failed", - EventType.ERROR, - EventScope.GENVM, - "Error executing method " + method_name + ": " + str(error), - { - "calldata": _calldata_to_str(calldata_raw, calldata), - "error": str(error), - "traceback": f"\n{trace}", - }, - ) - ) - - pickled_object = pickle.dumps(current_contract) - encoded_pickled_object = base64.b64encode(pickled_object).decode("utf-8") - class_name = self._get_contract_class_name(contract_code) - - if self.contract_runner.mode == ExecutionMode.LEADER: - captured_stdout = stdout_buffer.getvalue() - - if captured_stdout: - print(captured_stdout) - self.send_stdout(captured_stdout, self.msg_handler) - - if execution_result == ExecutionResultStatus.SUCCESS: - self.msg_handler.send_message( - LogEvent( - "write_contract", - EventType.INFO, - EventScope.GENVM, - "Execute method: " + method_name, - { - "calldata": _calldata_to_str(calldata_raw, calldata), - "output": captured_stdout, - }, - ) - ) - - return self._generate_receipt( - class_name, - encoded_pickled_object, - calldata_raw, - execution_result, - error, - ) + is_init: bool = False, + leader_results: None | dict[int, bytes], + config: str, + ) -> ExecutionResult: ... - @staticmethod - def get_contract_schema(contract_code: str) -> dict: + async def get_contract_schema(self, contract_code: bytes) -> str: ... - namespace = {} - with safe_globals(): - exec(contract_code, globals(), namespace) - class_name = GenVM._get_contract_class_name(contract_code) - iclass = namespace[class_name] +# state proxy that always fails and can give code only for address from a constructor +# useful for get_schema +class _StateProxyNone(StateProxy): + def __init__(self, my_address: Address, code: bytes): + self.my_address = my_address + self.code = code - members = inspect.getmembers(iclass) + def storage_read( + self, account: Address, slot: bytes, index: int, le: int, / + ) -> bytes: + assert False - # Find all class methods - methods = {} - functions_and_methods = [ - m for m in members if inspect.isfunction(m[1]) or inspect.ismethod(m[1]) - ] - for name, member in functions_and_methods: - signature = inspect.signature(member) - - inputs = {} - for ( - method_variable_name, - method_variable, - ) in signature.parameters.items(): - if method_variable_name != "self": - annotation = str(method_variable.annotation)[8:-2] - inputs[method_variable_name] = str(annotation) - - return_annotation = str(signature.return_annotation)[8:-2] - - if return_annotation == "inspect._empty": - return_annotation = "None" - - result = {"inputs": inputs, "output": return_annotation} - - methods[name] = result - - abi = GenVM.generate_abi_from_schema_methods(methods) - - contract_schema = { - "class": class_name, - "abi": abi, - } - - return contract_schema - - @staticmethod - def get_abi_param_type(param_type: str) -> str: - # okay, this is unsolvable with current implementation... - if param_type == "int": - return "int" - if param_type == "str": - return "string" - if param_type == "bool": - return "bool" - if param_type == "dict": - return "any" - if param_type == "list": - return "any" - if param_type == "None": - return "None" - return param_type - - @staticmethod - def generate_abi_from_schema_methods(contract_schema_methods: dict) -> list: - abi = [] - - for method_name, method_info in contract_schema_methods.items(): - abi_entry = { - "name": method_name, - "type": "function", - "inputs": [], - "outputs": [], - } - - for input_name, input_type in method_info["inputs"].items(): - abi_entry["inputs"].append( - {"name": input_name, "type": GenVM.get_abi_param_type(input_type)} - ) - - if method_info["output"]: - abi_entry["outputs"].append( - { - "name": "", - "type": GenVM.get_abi_param_type(method_info["output"]), - } - ) - - if method_name == "__init__": - abi_entry["type"] = "constructor" - del abi_entry["name"] - del abi_entry["outputs"] + def storage_write( + self, + account: Address, + slot: bytes, + index: int, + got: collections.abc.Buffer, + /, + ) -> None: + assert False - abi.append(abi_entry) + def get_code(self, addr: Address) -> bytes: + assert addr == self.my_address + return self.code - return abi - @staticmethod - def send_stdout(stdout: str, msg_handler: MessageHandler) -> str: - msg_handler.send_message( - LogEvent( - "contract_stdout", - EventType.INFO, - EventScope.GENVM, - stdout, +# Actual genvm wrapper that will start process and handle all communication +class GenVMHost(IGenVM): + async def run_contract( + self, + state: StateProxy, + *, + from_address: Address, + contract_address: Address, + calldata_raw: bytes, + is_init: bool = False, + leader_results: None | dict[int, bytes], + config: str, + ) -> ExecutionResult: + message = { + "is_init": is_init, + "contract_account": contract_address.as_b64, + "sender_account": from_address.as_b64, + "value": None, + "gas": 2**64 - 1, + } + return await _run_genvm_host( + functools.partial( + _Host, + calldata_bytes=calldata_raw, + state_proxy=state, + leader_results=leader_results, ), - log_to_terminal=False, + ["--message", json.dumps(message)], + config, ) - def get_contract_data( - self, - code: str, - state: str, - calldata_raw: bytes, - contract_snapshot_factory: Callable[[str], ContractSnapshot], - ) -> Any: - result = None - decoded_pickled_object = base64.b64decode(state) - output_buffer = io.StringIO() - - with redirect_stdout(output_buffer), redirect_stderr( - output_buffer - ), safe_globals( - { - "Contract": partial( - ExternalContract, - contract_snapshot_factory, - None, # TODO: should read methods be allowed to add new transactions? - self, - ) - } - ): - local_namespace = {} - # Execute the code to ensure all classes are defined in the namespace - exec(code, globals(), local_namespace) - - # Ensure the class and other necessary elements are in the global namespace if needed - globals().update(local_namespace) - - calldata = calldata_decode(calldata_raw) - method_name = calldata["method"] - method_args = calldata["args"] - - contract_state = pickle.loads(decoded_pickled_object) - method_to_call = getattr(contract_state, method_name) - result = method_to_call(*method_args) - - captured_stdout = output_buffer.getvalue() - - if captured_stdout: - print(captured_stdout) - self.send_stdout(captured_stdout, self.msg_handler) - - if self.contract_runner.mode == ExecutionMode.LEADER: - self.msg_handler.send_message( - LogEvent( - "read_contract", - EventType.INFO, - EventScope.GENVM, - "Call method: " + method_name, - { - "calldata": calldata_repr(calldata), - "result": result, - "output": captured_stdout, - }, - ) - ) + async def get_contract_schema(self, contract_code: bytes) -> str: + NO_ADDR = str(base64.b64encode(b"\x00" * 20), encoding="ascii") + message = { + "is_init": False, + "contract_account": NO_ADDR, + "sender_account": NO_ADDR, + "value": None, + "gas": 2**64 - 1, + } + res = await _run_genvm_host( + functools.partial( + _Host, + calldata_bytes=calldata.encode({"method": "__get_schema__"}), + state_proxy=_StateProxyNone(Address(NO_ADDR), contract_code), + leader_results=None, + ), + ["--message", json.dumps(message)], + None, + ) + if not isinstance(res.result, ExecutionReturn): + raise Exception(f"execution failed {res}") + ret_calldata = res.result.ret + schema = calldata.decode(ret_calldata) + if not isinstance(schema, str): + raise Exception(f"abi violation, __get_schema__ returned {schema}") + return schema - return result +# Class that has logic for handling all genvm host methods and accumulating results +class _Host(genvmhost.IHost): + _result: ExecutionReturn | ExecutionRollback | ExecutionFail | None + _eq_outputs: dict[int, bytes] + _pending_transactions: list[PendingTransaction] -class ExternalContract: def __init__( self, - contract_snapshot_factory: Callable[[str], ContractSnapshot], - schedule_pending_transaction: Callable[[PendingTransaction], None], - genvm: GenVM, - address: str, + sock_listen: socket.socket, + *, + calldata_bytes: bytes, + state_proxy: StateProxy, + leader_results: None | dict[int, bytes], ): - self.address = address - self.genvm = genvm - self.contract_snapshot = contract_snapshot_factory(address) - self.contract_snapshot_factory = contract_snapshot_factory - self.schedule_pending_transaction = schedule_pending_transaction - - def __getattr__(self, name): - def method(*args): # kwargs are not supported yet - if re.match("get_", name): - return self.genvm.get_contract_data( - self.contract_snapshot.contract_code, - self.contract_snapshot.encoded_state, - calldata_encode({"method": name, "args": args}), - self.contract_snapshot_factory, - ) - else: - self.schedule_pending_transaction( - PendingTransaction( - address=self.address, - calldata=calldata_encode({"method": name, "args": args}), - ) - ) - + self._eq_outputs = {} + self._pending_transactions = [] + self._result = None + + self.sock_listen = sock_listen + self.sock = None + self._state_proxy = state_proxy + self.calldata_bytes = calldata_bytes + self._leader_results = leader_results + + def provide_result( + self, res: genvmhost.RunHostAndProgramRes, fail: ExecutionFail | None + ) -> ExecutionResult: + ret = functools.partial( + ExecutionResult, + eq_outputs=self._eq_outputs, + pending_transactions=self._pending_transactions, + stdout=res.stdout, + stderr=res.stderr, + ) + if fail is not None: + return ret(fail) + if self._result is not None: + return ret(self._result) + return ret(ExecutionFail(exc=res.exceptions)) + + async def loop_enter(self) -> socket.socket: + async_loop = asyncio.get_event_loop() + self.sock, _addr = await async_loop.sock_accept(self.sock_listen) + self.sock.setblocking(False) + self.sock_listen.close() + return self.sock + + async def get_calldata(self, /) -> bytes: + return self.calldata_bytes + + async def get_code(self, addr: bytes, /) -> bytes: + return self._state_proxy.get_code(Address(addr)) + + async def storage_read( + self, account: bytes, slot: bytes, index: int, le: int, / + ) -> bytes: + return self._state_proxy.storage_read(Address(account), slot, index, le) + + async def storage_write( + self, + account: bytes, + slot: bytes, + index: int, + got: collections.abc.Buffer, + /, + ) -> None: + return self._state_proxy.storage_write(Address(account), slot, index, got) + + async def consume_result( + self, type: genvmhost.ResultCode, data: collections.abc.Buffer, / + ) -> None: + if type == genvmhost.ResultCode.RETURN: + self._result = ExecutionReturn(ret=bytes(data)) + elif type == genvmhost.ResultCode.ROLLBACK: + self._result = ExecutionRollback(str(data, encoding="utf-8")) + + async def get_leader_nondet_result(self, call_no: int, /) -> bytes | str | None: + leader_results = self._leader_results + if leader_results is None: return None + leader_results_mem = memoryview(leader_results[call_no]) + if leader_results_mem[0] == genvmhost.ResultCode.ROLLBACK: + return str(leader_results_mem[1:], "utf-8") + if leader_results_mem[0] == genvmhost.ResultCode.RETURN: + return bytes(leader_results_mem[1:]) + assert False + + async def post_nondet_result( + self, call_no: int, type: genvmhost.ResultCode, data: collections.abc.Buffer, / + ) -> None: + barr = bytearray() + barr.append(type.value) + barr.extend(memoryview(data)) + self._eq_outputs[call_no] = bytes(barr) + + async def post_message( + self, gas: int, account: bytes, calldata: bytes, code: bytes, / + ) -> None: + self._pending_transactions.append( + PendingTransaction(Address(account).as_hex, calldata) + ) + + async def consume_gas(self, gas: int, /) -> None: + pass + - return method +async def _run_genvm_host( + host_supplier: typing.Callable[[socket.socket], _Host], + args: list[Path | str], + config: str | None, +) -> ExecutionResult: + tmpdir = Path(tempfile.mkdtemp()) + try: + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock_listener: + sock_listener.setblocking(False) + sock_path = tmpdir.joinpath("sock") + sock_listener.bind(str(sock_path)) + sock_listener.listen(1) + + new_args = [ + get_genvm_path(), + "run", + "--host", + f"unix://{sock_path}", + "--print=shrink", + ] + + if config is not None: + conf_path = tmpdir.joinpath("conf.json") + conf_path.write_text(config) + new_args.extend(["--config", conf_path]) + new_args.extend(args) + + host: _Host = host_supplier(sock_listener) # _Host(sock_listener) + try: + return host.provide_result( + await genvmhost.run_host_and_program(host, new_args), None + ) + finally: + if host.sock is not None: + host.sock.close() + except Exception as e: + return ExecutionResult( + result=ExecutionFail([e]), + eq_outputs={}, + pending_transactions=[], + stdout="", + stderr="", + ) + finally: + shutil.rmtree(tmpdir, ignore_errors=True) diff --git a/backend/node/genvm/code_enforcement.py b/backend/node/genvm/code_enforcement.py deleted file mode 100644 index 97fb90714..000000000 --- a/backend/node/genvm/code_enforcement.py +++ /dev/null @@ -1,202 +0,0 @@ -import ast - - -def code_enforcement_check(code: str, class_name: str) -> str: - result = {"status": "error", "message": "", "data": []} - # Check is valid Python code - code_check_result = _check_code(code) - if code_check_result: - result["message"] = code_check_result - return result - # See if the class exists - if not _does_class_exist_in_code(code, class_name): - result["message"] = f"The class {class_name} does not exist in the code" - return result - # Make sure there are no raw instantiations of the EquivalencePrinciple class - if _code_has_bad_implementations_of_eq_principle(code): - result["message"] = ( - "You cannot directly instantiate the EquivalencePrinciple class" - ) - return result - # Make sure that no code modifies self inside an equivalence block - linenos_list = _linenos_of_where_code_modifys_self_in_eq_block(code, class_name) - if len(linenos_list): - result["message"] = "Self was modified inside an equivalence block" - result["data"] = linenos_list - return result - # Make sure no equivalence block variables are referenced outside the block - linenos_list = _code_references_eq_block_variables(code, class_name) - if len(linenos_list): - result["message"] = ( - "Variables declared in the equivalence block are referenced outside of the equivalence block" - ) - result["data"] = linenos_list - return result - result["status"] = "success" - return result - - -def _check_code(code: str): - try: - ast.parse(code) - return None - except Exception: - return "Your code is not valid Python code" - - -def _does_class_exist_in_code(code: str, class_name: str) -> bool: - visitor = ClassExistsVisitor(class_name) - visitor.visit(ast.parse(code)) - return visitor.class_exists - - -# Visits each class in the code and checks if the -# class name matches the class name poassed in -class ClassExistsVisitor(ast.NodeVisitor): - def __init__(self, class_name): - self.class_name = class_name - self.class_exists = False - - def visit_ClassDef(self, node): - if node.name == self.class_name: - self.class_exists = True - self.generic_visit(node) - - -def _code_has_bad_implementations_of_eq_principle(code: str) -> bool: - visitor = EquivalencePrincipleVisitor() - visitor.visit(ast.parse(code)) - if visitor.all_eq_line_nums != visitor.eq_line_nums_inside_async_waith_blocks: - return True - return False - - -class EquivalencePrincipleVisitor(ast.NodeVisitor): - - def __init__(self): - self.all_eq_line_nums = [] - self.eq_line_nums_inside_async_waith_blocks = [] - - # All the places where the EquivalencePrinciple class is called - def visit_Call(self, node): - if isinstance(node.func, ast.Name) and node.func.id == "EquivalencePrinciple": - self.eq_line_nums_inside_async_waith_blocks.append(node.lineno) - self.generic_visit(node) - - # All the places where the EquivalencePrinciple class is called within an async with block - def visit_AsyncWith(self, node): - async_ctx = node.items[0].context_expr - if isinstance(async_ctx, ast.Call) and isinstance(async_ctx.func, ast.Name): - if async_ctx.func.id == "EquivalencePrinciple": - self.all_eq_line_nums.append(node.lineno) - self.generic_visit(node) - - -def _linenos_of_where_code_modifys_self_in_eq_block(code: str, class_name) -> bool: - visitor = EquivalencePrincipleModifySelfVisitor(class_name) - visitor.visit(ast.parse(code)) - return visitor.modeifies_self_linenos - - -class EquivalencePrincipleModifySelfVisitor(ast.NodeVisitor): - - def __init__(self, class_name): - self.class_name = class_name - self.inside_call_method = False - self.inside_eq_block = False - self.modeifies_self_linenos = [] - - # Mark the fact that we are inside the code's class method - def visit_ClassDef(self, node): - if node.name == self.class_name: - for body_item in node.body: - if ( - isinstance(body_item, ast.AsyncFunctionDef) - and body_item.name != "__init__" - ): - self.inside_call_method = True - self.generic_visit(body_item) - self.inside_call_method = False - - # Mark thew fact that we are inside an equivalence block - def visit_AsyncWith(self, node): - async_ctx = node.items[0].context_expr - if isinstance(async_ctx, ast.Call) and isinstance(async_ctx.func, ast.Name): - if async_ctx.func.id == "EquivalencePrinciple": - self.inside_eq_block = True - self.generic_visit(node) - self.inside_eq_block = False - - # Record all assignments of class variables done inside an - # equivalence block - def visit_Assign(self, node): - for target in node.targets: - if self.inside_call_method and self.inside_eq_block: - if ( - isinstance(target, ast.Attribute) - and isinstance(target.value, ast.Name) - and target.value.id == "self" - ): - self.modeifies_self_linenos.append(target.lineno) - self.generic_visit(node) - - -def _code_references_eq_block_variables(code: str, class_name: str) -> list: - visitor = EquivalencePrincipleBlockVariables(class_name) - visitor.visit(ast.parse(code)) - return visitor.referenced_block_variables - - -class EquivalencePrincipleBlockVariables(ast.NodeVisitor): - - def __init__(self, class_name: str): - self.class_name = class_name - self.inside_call_method = False - self.inside_eq_block = False - self.has_visited_eq_block = False - self.eq_block_variables = [] - self.referenced_block_variables = [] - - # Mark the fact that we are inside the code's class method - def visit_ClassDef(self, node): - if node.name == self.class_name: - for body_item in node.body: - if ( - isinstance(body_item, ast.AsyncFunctionDef) - and body_item.name != "__init__" - ): - self.inside_call_method = True - self.generic_visit(body_item) - self.inside_call_method = False - - # Mark the fact that we are inside an equivalence block - # (and have visited at least one equivalence block) - def visit_AsyncWith(self, node): - async_ctx = node.items[0].context_expr - if isinstance(async_ctx, ast.Call) and isinstance(async_ctx.func, ast.Name): - if async_ctx.func.id == "EquivalencePrinciple": - self.inside_eq_block = True - self.generic_visit(node) - self.inside_eq_block = False - self.has_visited_eq_block = True - - # Records all the variables declared inside an equivalence block - def visit_Assign(self, node): - for target in node.targets: - if ( - isinstance(target, ast.Name) - and self.inside_call_method - and self.inside_eq_block - ): - self.eq_block_variables.append(target.id) - self.generic_visit(node) - - # Records all the variables declared in an equivalence block - # that are referenced outside of it - def visit_Name(self, node): - if ( - self.inside_call_method - and not self.inside_eq_block - and node.id in self.eq_block_variables - ): - self.referenced_block_variables.append(node.lineno) diff --git a/backend/node/genvm/config.py b/backend/node/genvm/config.py new file mode 100644 index 000000000..f7b502f21 --- /dev/null +++ b/backend/node/genvm/config.py @@ -0,0 +1,38 @@ +import os +from pathlib import Path + + +def _check_one(check: Path) -> bool: + try: + return check.exists() and check.is_file() + except: + return False + + +def _find_exe(name: str) -> Path: + checked = [] + for env_var in [f"{name.upper()}PATH", f"{name.upper()}_BIN"]: + var = os.getenv(env_var) + if var is None: + continue + for check in [Path(var), Path(var).joinpath(name)]: + checked.append(check) + if _check_one(check): + return check + for p in os.getenv("PATH", "").split(":"): + check = Path(p).joinpath(name) + checked.append(check) + if _check_one(check): + return check + raise Exception(f"Can't find executable {name}, searched at {checked}") + + +_found_at: Path | None = None + + +def get_genvm_path() -> Path: + global _found_at + if _found_at is None: + _found_at = _find_exe("genvm") + + return _found_at diff --git a/backend/node/genvm/context_wrapper.py b/backend/node/genvm/context_wrapper.py deleted file mode 100644 index e4c434bb7..000000000 --- a/backend/node/genvm/context_wrapper.py +++ /dev/null @@ -1,46 +0,0 @@ -from functools import wraps - - -def enforce_with_context(cls): - original_new = cls.__new__ - original_aenter = cls.__aenter__ - original_aexit = cls.__aexit__ - - @wraps(original_new) - def new_wrapper(cls, *args, **kwargs): - instance = original_new(cls) - instance._is_within_with_block = False - return instance - - @wraps(original_aenter) - def aenter_wrapper(self): - self._is_within_with_block = True - return original_aenter(self) - - @wraps(original_aexit) - def aexit_wrapper(self, exc_type, exc_value, traceback): - self._is_within_with_block = True - return original_aexit(self) - - def method_wrapper(method): - @wraps(method) - def wrapper(self, *args, **kwargs): - if not self._is_within_with_block: - raise RuntimeError( - f"Methods of {cls.__name__} must be called inside a 'with' block." - ) - return method(self, *args, **kwargs) - - return wrapper - - # Wrap all methods to enforce the check - for attr_name in dir(cls): - attr = getattr(cls, attr_name) - if callable(attr) and not attr_name.startswith("_"): - setattr(cls, attr_name, method_wrapper(attr)) - - cls.__new__ = new_wrapper - cls.__aenter__ = aenter_wrapper - cls.__aexit__ = aexit_wrapper - - return cls diff --git a/backend/node/genvm/equivalence_principle.py b/backend/node/genvm/equivalence_principle.py deleted file mode 100644 index fa9123a21..000000000 --- a/backend/node/genvm/equivalence_principle.py +++ /dev/null @@ -1,123 +0,0 @@ -# backend/node/genvm/equivalence_principle.py - -from typing import Any, Optional -from backend.node.genvm.context_wrapper import enforce_with_context -from backend.node.genvm import llms -from backend.node.genvm.webpage_utils import get_webpage_content -from backend.node.genvm.types import ExecutionMode - - -def clear_locals(scope): - inside_eq = False - local_vars = scope.copy() - for var in local_vars: - if inside_eq: - del scope[var] - if var == "eq": - inside_eq = True - - -# check that block does not modify self helper - - -@enforce_with_context -class EquivalencePrinciple: - contract_runner: Any # TODO: this should be of type ContractRunner but that raises a cyclic import error - - def __init__( - self, - result: dict, - principle: Optional[str], - comparative: bool = True, - ): - if result != {}: - raise Exception("result must be empty") - self.result = result - self.principle = principle - self.comparative = comparative - self.last_method = None - self.last_args = [] - - async def __aenter__(self): - return self - - async def __aexit__(self): - - # check eq principle - if self.principle == None: - return - - if ( - self.contract_runner.mode == ExecutionMode.VALIDATOR - and self.comparative == True - ): - llm_function = self.__get_llm_function() - eq_prompt = f"""Given the equivalence principle '{self.principle}', - decide whether the following two outputs can be considered equivalent. - - Leader's Output: {self.contract_runner.eq_outputs['leader'][str(self.contract_runner.eq_num - 1)]} - - Validator's Output: {self.result['validator_value']} - - Respond with: TRUE or FALSE""" - validation_response = await llm_function( - self.contract_runner.node_config, eq_prompt, None, None - ) - print("validation_response", validation_response) - # if TRUE => nothing, FALSE => fuera todo y un state de disagree - - async def get_webpage(self, url: str, format: str = "text"): - url_body = get_webpage_content(url, format) - final_response = url_body["response"] - return final_response - - async def call_llm(self, prompt: str): - llm_function = self.__get_llm_function() - final_response = await llm_function( - self.contract_runner.node_config, prompt, None, None - ) - return final_response - - def set(self, value): - if self.contract_runner.mode == ExecutionMode.LEADER: - self.result["output"] = value - self.contract_runner.eq_outputs[ExecutionMode.LEADER.value][ - str(self.contract_runner.eq_num) - ] = value - else: - self.result["validator_value"] = value - self.result["output"] = self.contract_runner.eq_outputs[ - ExecutionMode.LEADER.value - ][str(self.contract_runner.eq_num)] - self.contract_runner.eq_num += 1 - - def __get_llm_function(self): - return llms.get_llm_plugin( - self.contract_runner.node_config["plugin"], - self.contract_runner.node_config["plugin_config"], - ).call - - -async def call_llm_with_principle(prompt, eq_principle, comparative=True): - final_result = {} - async with EquivalencePrinciple( - result=final_result, - principle=eq_principle, - comparative=comparative, - ) as eq: - result = await eq.call_llm(prompt) - eq.set(result) - - return final_result["output"] - - -async def get_webpage_with_principle(url, eq_principle, format: str = "text"): - final_result = {} - async with EquivalencePrinciple( - result=final_result, - principle=eq_principle, - comparative=True, - ) as eq: - result = await eq.get_webpage(url, format) - eq.set(result) - return final_result diff --git a/backend/node/genvm/icontract.py b/backend/node/genvm/icontract.py deleted file mode 100644 index a29df31a7..000000000 --- a/backend/node/genvm/icontract.py +++ /dev/null @@ -1,13 +0,0 @@ -from abc import ABC, abstractmethod - - -class IContract(ABC): - @abstractmethod - def __init__(self): - """ - Constructor for the abstract class, which should be implemented by subclasses. - Raises an exception if an attempt is made to instantiate the abstract class directly. - """ - raise NotImplementedError( - "Constructor (__init__) must be implemented by subclass" - ) diff --git a/backend/node/genvm/origin/__init__.py b/backend/node/genvm/origin/__init__.py new file mode 100644 index 000000000..247ccd3e8 --- /dev/null +++ b/backend/node/genvm/origin/__init__.py @@ -0,0 +1 @@ +# This code is taken from genvm repo diff --git a/backend/node/genvm/origin/base_host.py b/backend/node/genvm/origin/base_host.py new file mode 100644 index 000000000..ff69c6bf0 --- /dev/null +++ b/backend/node/genvm/origin/base_host.py @@ -0,0 +1,280 @@ +import socket +import typing +import collections.abc +import asyncio +import os + +from dataclasses import dataclass + +from pathlib import Path + +if typing.TYPE_CHECKING: + from .host_fns import * +else: + from pathlib import Path + + exec(Path(__file__).parent.joinpath("host_fns.py").read_text()) + +ACCOUNT_ADDR_SIZE = 20 +GENERIC_ADDR_SIZE = 32 + + +class IHost(typing.Protocol): + async def loop_enter(self) -> socket.socket: ... + + async def get_calldata(self, /) -> bytes: ... + async def get_code(self, addr: bytes, /) -> bytes: ... + async def storage_read( + self, account: bytes, slot: bytes, index: int, le: int, / + ) -> bytes: ... + async def storage_write( + self, + account: bytes, + slot: bytes, + index: int, + got: collections.abc.Buffer, + /, + ) -> None: ... + async def consume_result( + self, type: ResultCode, data: collections.abc.Buffer, / + ) -> None: ... + async def get_leader_nondet_result(self, call_no: int, /) -> bytes | str | None: ... + async def post_nondet_result( + self, call_no: int, type: ResultCode, data: collections.abc.Buffer, / + ) -> None: ... + async def post_message( + self, gas: int, account: bytes, calldata: bytes, code: bytes, / + ) -> None: ... + async def consume_gas(self, gas: int, /) -> None: ... + + +async def host_loop(handler: IHost): + async_loop = asyncio.get_event_loop() + + sock = await handler.loop_enter() + + async def send_all(data: collections.abc.Buffer): + await async_loop.sock_sendall(sock, data) + + async def read_exact(le: int) -> bytes: + buf = bytearray([0] * le) + idx = 0 + while idx < le: + read = await async_loop.sock_recv_into(sock, memoryview(buf)[idx:le]) + if read == 0: + raise ConnectionResetError() + idx += read + return bytes(buf) + + async def recv_int(bytes: int = 4) -> int: + return int.from_bytes(await read_exact(bytes), byteorder="little", signed=False) + + async def send_int(i: int, bytes=4): + await send_all(int.to_bytes(i, bytes, byteorder="little", signed=False)) + + async def read_result() -> tuple[ResultCode, bytes]: + type = await recv_int(1) + le = await recv_int() + data = await read_exact(le) + return (ResultCode(type), data) + + while True: + meth_id = Methods(await recv_int(1)) + match meth_id: + case Methods.APPEND_CALLDATA: + cd = await handler.get_calldata() + await send_int(len(cd)) + await send_all(cd) + case Methods.GET_CODE: + addr = await read_exact(ACCOUNT_ADDR_SIZE) + code = await handler.get_code(addr) + await send_int(len(code)) + await send_all(code) + case Methods.STORAGE_READ: + account = await read_exact(ACCOUNT_ADDR_SIZE) + slot = await read_exact(GENERIC_ADDR_SIZE) + index = await recv_int() + le = await recv_int() + res = await handler.storage_read(account, slot, index, le) + assert len(res) == le + await send_all(res) + case Methods.STORAGE_WRITE: + account = await read_exact(ACCOUNT_ADDR_SIZE) + slot = await read_exact(GENERIC_ADDR_SIZE) + index = await recv_int() + le = await recv_int() + got = await read_exact(le) + await handler.storage_write(account, slot, index, got) + case Methods.CONSUME_RESULT: + await handler.consume_result(*await read_result()) + await send_all(b"\x00") + return + case Methods.GET_LEADER_NONDET_RESULT: + call_no = await recv_int() # call no + data = await handler.get_leader_nondet_result(call_no) + if data is None: + await send_all(bytes([ResultCode.NONE])) + elif isinstance(data, str): + await send_all(bytes([ResultCode.ROLLBACK])) + encoded = data.encode("utf-8") + await send_int(len(encoded)) + await send_all(encoded) + else: + await send_all(bytes([ResultCode.RETURN])) + await send_int(len(data)) + await send_all(data) + case Methods.POST_NONDET_RESULT: + call_no = await recv_int() + await handler.post_nondet_result(call_no, *await read_result()) + case Methods.POST_MESSAGE: + account = await read_exact(ACCOUNT_ADDR_SIZE) + gas = await recv_int(8) + calldata_len = await recv_int() + calldata = await read_exact(calldata_len) + code_len = await recv_int() + code = await read_exact(code_len) + await handler.post_message(gas, account, calldata, code) + case Methods.CONSUME_FUEL: + gas = await recv_int(8) + await handler.consume_gas(gas) + case x: + raise Exception(f"unknown method {x}") + + +@dataclass +class RunHostAndProgramRes: + stdout: str + stderr: str + exceptions: list[Exception] + + +from concurrent.futures import ProcessPoolExecutor + + +async def run_host_and_program( + handler: IHost, + program: list[Path | str], + *, + env=None, + cwd: Path | None = None, + exit_timeout=0.05, + deadline: float | None = None, +) -> RunHostAndProgramRes: + loop = asyncio.get_running_loop() + + async def connect_reader(fd): + reader = asyncio.StreamReader(loop=loop) + reader_proto = asyncio.StreamReaderProtocol(reader) + transport, _ = await loop.connect_read_pipe( + lambda: reader_proto, os.fdopen(fd, "rb") + ) + return reader, transport + + stdout_rfd, stdout_wfd = os.pipe() + stderr_rfd, stderr_wfd = os.pipe() + stdout_reader, stdout_transport = await connect_reader(stdout_rfd) + stderr_reader, stderr_transport = await connect_reader(stderr_rfd) + + process = await asyncio.create_subprocess_exec( + program[0], + *program[1:], + stdin=asyncio.subprocess.DEVNULL, + stdout=stdout_wfd, + stderr=stderr_wfd, + cwd=cwd, + env=env, + ) + os.close(stdout_wfd) + os.close(stderr_wfd) + if process.stdin is not None: + process.stdin.close() + + async def read_whole(reader, transport, put_to: list[bytes]): + try: + while True: + read = await reader.read(4096) + if read is None or len(read) == 0: + break + put_to.append(read) + finally: + try: + transport.close() + except OSError: + pass + await asyncio.sleep(0) + + async def wrap_host(): + await host_loop(handler) + + stdout, stderr = [], [] + + async def wrap_proc(): + await asyncio.gather( + read_whole(stdout_reader, stdout_transport, stdout), + read_whole(stderr_reader, stderr_transport, stderr), + process.wait(), + ) + + coro_loop = asyncio.ensure_future(wrap_host()) + coro_proc = asyncio.ensure_future(wrap_proc()) + + all_proc = [coro_loop, coro_proc] + if deadline is not None: + all_proc.append(asyncio.ensure_future(asyncio.sleep(deadline))) + + done, _pending = await asyncio.wait( + all_proc, + return_when=asyncio.FIRST_COMPLETED, + ) + + errors = [] + + for x in done: + try: + x.result() + except Exception as e: + errors.append(e) + + # coro_loop must finish first if everything succeeded + if not coro_loop.done() and deadline is None: + print("WARNING: genvm finished first") + coro_loop.cancel() + + exit_code_use = True + + if not coro_proc.done(): + # genvm is exiting, let it clean all the resources for a bit + await asyncio.wait( + [coro_proc, asyncio.ensure_future(asyncio.sleep(exit_timeout))], + return_when=asyncio.FIRST_COMPLETED, + ) + if not coro_proc.done(): + # genvm exit takes to long, maybe it hanged. Politely ask to quit and wait a bit + try: + process.terminate() + except: + pass + exit_code_use = False + await asyncio.wait( + [coro_proc, asyncio.ensure_future(asyncio.sleep(exit_timeout))], + return_when=asyncio.FIRST_COMPLETED, + ) + if not coro_proc.done(): + # genvm exit takes to long, forcefully quit it + try: + process.kill() + except: + pass + + await coro_proc + exit_code = await process.wait() + + if not coro_loop.done(): + coro_loop.cancel() + + if exit_code_use and exit_code != 0: + errors.append(Exception(f"exit code {exit_code} != 0")) + + return RunHostAndProgramRes( + b"".join(stdout).decode(), b"".join(stderr).decode(), errors + ) diff --git a/backend/node/genvm/calldata.py b/backend/node/genvm/origin/calldata.py similarity index 90% rename from backend/node/genvm/calldata.py rename to backend/node/genvm/origin/calldata.py index 8ed019838..3bc5d223f 100644 --- a/backend/node/genvm/calldata.py +++ b/backend/node/genvm/origin/calldata.py @@ -1,5 +1,7 @@ -from .types import Address +from ...types import Address from typing import Any +import collections.abc +import json BITS_IN_TYPE = 3 @@ -86,7 +88,7 @@ def impl(b): return bytes(mem) -def decode(mem0) -> Any: # type: ignore +def decode(mem0: collections.abc.Buffer) -> Any: # type: ignore mem: memoryview = memoryview(mem0) def read_uleb128() -> int: @@ -126,7 +128,7 @@ def impl() -> Any: elif typ == TYPE_BYTES: ret_bytes = mem[:code] mem = mem[code:] - return bytes(ret_bytes) + return ret_bytes elif typ == TYPE_STR: ret_str = mem[:code] mem = mem[code:] @@ -168,7 +170,7 @@ def impl(d: Any) -> None: elif d is False: buf.append("false") elif isinstance(d, str): - buf.append(f"{d!r}") + buf.append(json.dumps(d)) elif isinstance(d, bytes): buf.append("b#") buf.append(d.hex()) @@ -178,18 +180,26 @@ def impl(d: Any) -> None: buf.append("addr#") buf.append(d.as_bytes.hex()) elif isinstance(d, dict): + was_first = False buf.append("{") for k, v in d.items(): - buf.append(f"{k!r}") + if was_first: + buf.append(",") + else: + was_first = True + buf.append(json.dumps(k)) buf.append(":") impl(v) - buf.append(",") buf.append("}") elif isinstance(d, list): + was_first = False buf.append("[") for v in d: + if was_first: + buf.append(",") + else: + was_first = True impl(v) - buf.append(",") buf.append("]") else: raise Exception(f"can't encode {d} to calldata") diff --git a/backend/node/genvm/origin/host_fns.py b/backend/node/genvm/origin/host_fns.py new file mode 100644 index 000000000..8c74ef6f3 --- /dev/null +++ b/backend/node/genvm/origin/host_fns.py @@ -0,0 +1,20 @@ +from enum import IntEnum + + +class Methods(IntEnum): + APPEND_CALLDATA = 0 + GET_CODE = 1 + STORAGE_READ = 2 + STORAGE_WRITE = 3 + CONSUME_RESULT = 4 + GET_LEADER_NONDET_RESULT = 5 + POST_NONDET_RESULT = 6 + POST_MESSAGE = 7 + CONSUME_FUEL = 8 + + +class ResultCode(IntEnum): + RETURN = 0 + ROLLBACK = 1 + NONE = 2 + ERROR = 3 diff --git a/backend/node/genvm/std/__init__.py b/backend/node/genvm/std/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/node/genvm/std/models.py b/backend/node/genvm/std/models.py deleted file mode 100644 index d6e88ac99..000000000 --- a/backend/node/genvm/std/models.py +++ /dev/null @@ -1,10 +0,0 @@ -# backend/node/genvm/std/models.py - -from sentence_transformers import SentenceTransformer - -DEFAULT_MODEL_NAME = "paraphrase-MiniLM-L6-v2" - - -def get_model(model_name: str = None): - model = model_name if model_name is not None else DEFAULT_MODEL_NAME - return SentenceTransformer(model) diff --git a/backend/node/genvm/std/vector_store.py b/backend/node/genvm/std/vector_store.py deleted file mode 100644 index 4e941905b..000000000 --- a/backend/node/genvm/std/vector_store.py +++ /dev/null @@ -1,185 +0,0 @@ -# backend/node/genvm/std/vector_store.py - -from typing import Any, Optional -import numpy as np -from backend.node.genvm.std.models import get_model - - -class VectorStore: - def __init__(self, model_name: str = None): - """ - Initialize the VectorStore with a custom embedding model. - - Args: - model: A model with an encode method to generate embeddings. - """ - self.texts = {} # Dictionary to store texts - self.vector_data = {} # Dictionary to store vectors - self.metadata = {} # Dictionary to store metadata - self.model_name = model_name - - def add_text(self, text: str, metadata: Any, vector_id: Optional[int] = None): - """ - Add a new text to the store with its metadata. - - Args: - text (str): The text to be added. - metadata (Any): The metadata. - vector_id (int, optional): The ID for the vector. If not provided, a new ID will be generated. - """ - if vector_id is None: - vector_id = max(self.vector_data.keys(), default=0) + 1 - elif not isinstance(vector_id, int): - raise ValueError("vector_id must be an integer") - - if vector_id in self.vector_data: - raise ValueError(f"Vector ID {vector_id} already exists") - - model = get_model(self.model_name) - embedding = model.encode([text])[0] - - self.texts[vector_id] = text - self.vector_data[vector_id] = embedding - self.metadata[vector_id] = metadata - - return vector_id - - def get_closest_vector(self, text: str) -> tuple[float, int, str, Any, list[float]]: - """ - Get the closest vector to the given text along with the similarity percentage and metadata. - - Args: - text (str): The text for which to find the closest vector. - - Returns: - tuple: A tuple containing: - * the similarity percentage, - * the id, - * the text, - * the metadata, and - * the vector - """ - results = self.get_k_closest_vectors(text, k=1) - return results[0] if results else None - - def get_k_closest_vectors( - self, text: str, k: int = 5 - ) -> list[tuple[float, int, str, Any, list[float]]]: - """ - Get the closest k vectors to the given text along with the similarity percentages and metadata. - - Args: - text (str): The text for which to find the closest vectors. - k (int): The number of closest vectors to return. - - Returns: - list: A list of tuples, each containing: - * the similarity percentage, - * the id, - * the text, - * the metadata, and - * the vector - """ - if len(self.vector_data) == 0: - return [] - - model = get_model(self.model_name) - query_embedding = model.encode([text])[0] - - # Convert vector_data to a NumPy array for efficient calculations - all_embeddings = np.array(list(self.vector_data.values())) - all_ids = np.array(list(self.vector_data.keys())) - - # Compute cosine similarities - dot_products = np.dot(all_embeddings, query_embedding) - norms = np.linalg.norm(all_embeddings, axis=1) * np.linalg.norm(query_embedding) - similarities = dot_products / norms - - # Get the top k similarities - top_k_indices = similarities.argsort()[-k:][::-1] - results = [ - ( - float(similarities[i]), - int(all_ids[i]), - self.texts[all_ids[i]], - self.metadata[all_ids[i]], - self.vector_data[all_ids[i]].tolist(), - ) - for i in top_k_indices - ] - return results - - def update_text(self, vector_id: int, new_text: str, new_metadata: Any): - """ - Update the text and metadata of an existing vector. - - Args: - vector_id (int): The identifier of the vector to update. - new_text (str): The new text to update. - new_metadata (dict): The new metadata to update. - """ - if vector_id not in self.vector_data: - raise ValueError("Vector ID does not exist") - - model = get_model(self.model_name) - embedding = model.encode([new_text])[0] - self.texts[vector_id] = new_text - self.vector_data[vector_id] = embedding - self.metadata[vector_id] = new_metadata - - def delete_vector(self, vector_id: int): - """ - Delete a vector and its metadata from the store. - - Args: - vector_id (int): The identifier of the vector to delete. - """ - if vector_id in self.vector_data: - del self.texts[vector_id] - del self.vector_data[vector_id] - del self.metadata[vector_id] - else: - raise ValueError("Vector ID does not exist") - - def get_vector(self, vector_id: int) -> tuple[str, Any, list[float]]: - """ - Retrieve a vector and its metadata from the store. - - Args: - vector_id (int): The identifier of the vector to retrieve. - - Returns: - tuple: The text, the metadata, the vector. - """ - - vector_id = int(vector_id) - if vector_id in self.vector_data: - return ( - self.texts[vector_id], - self.metadata[vector_id], - self.vector_data[vector_id], - ) - else: - raise ValueError("Vector ID does not exist") - - def cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float: - """ - Calculate the cosine similarity between two vectors. - - Args: - a (numpy.ndarray): First vector. - b (numpy.ndarray): Second vector. - - Returns: - float: Cosine similarity between the vectors. - """ - dot_product = np.dot(a, b) - norm_a = np.linalg.norm(a) - norm_b = np.linalg.norm(b) - return dot_product / (norm_a * norm_b) - - def get_all_items(self) -> list[tuple[str, Any]]: - """ - Get all vectors and their metadata from the store. - """ - return [(self.texts[i], self.metadata[i]) for i in self.vector_data] diff --git a/backend/node/genvm/tests/code/bad_eq_implementation.py b/backend/node/genvm/tests/code/bad_eq_implementation.py deleted file mode 100644 index 44e5e2109..000000000 --- a/backend/node/genvm/tests/code/bad_eq_implementation.py +++ /dev/null @@ -1,14 +0,0 @@ -from genvm.base.equivalence_principle import EquivalencePrinciple - - -class A: - async def method1(self): - final_result = {} - async with EquivalencePrinciple( - result=final_result, - principle="The result['give_coin'] has to be exactly the same", - ) as eq: - eq.call_llm("something") - eq.set(result) - a = EquivalencePrinciple({"a": 1, "b": 2}) - return final_result["output"] diff --git a/backend/node/genvm/tests/code/bad_eq_modifys_self.py b/backend/node/genvm/tests/code/bad_eq_modifys_self.py deleted file mode 100644 index 256283984..000000000 --- a/backend/node/genvm/tests/code/bad_eq_modifys_self.py +++ /dev/null @@ -1,19 +0,0 @@ -from genvm.base.equivalence_principle import EquivalencePrinciple - - -class A: - - def __init__(self) -> None: - self.name = "dave" - - async def method1(self): - final_result = {} - async with EquivalencePrinciple( - result=final_result, - principle="The result['give_coin'] has to be exactly the same", - ) as eq: - result = await eq.call_llm("something") - self.name = "james" - eq.set(result) - age = 38 - return final_result["output"] diff --git a/backend/node/genvm/tests/code/bad_eq_variables_accessed_outside_of_block.py b/backend/node/genvm/tests/code/bad_eq_variables_accessed_outside_of_block.py deleted file mode 100644 index 11b32e3e4..000000000 --- a/backend/node/genvm/tests/code/bad_eq_variables_accessed_outside_of_block.py +++ /dev/null @@ -1,19 +0,0 @@ -from genvm.base.equivalence_principle import EquivalencePrinciple - - -class A: - - def __init__(self) -> None: - pass - - async def method1(self): - final_result = {} - async with EquivalencePrinciple( - result=final_result, - principle="The result['give_coin'] has to be exactly the same", - ) as eq: - result = await eq.call_llm("something") - name = "james" - eq.set(result) - name = "dave" - return final_result["output"] diff --git a/backend/node/genvm/tests/code/bad_eq_variables_accessed_outside_of_block_complex.py b/backend/node/genvm/tests/code/bad_eq_variables_accessed_outside_of_block_complex.py deleted file mode 100644 index 73c740140..000000000 --- a/backend/node/genvm/tests/code/bad_eq_variables_accessed_outside_of_block_complex.py +++ /dev/null @@ -1,31 +0,0 @@ -from genvm.base.equivalence_principle import EquivalencePrinciple - - -class A: - - def __init__(self) -> None: - pass - - async def method1(self): - final_result = {} - async with EquivalencePrinciple( - result=final_result, - principle="The result['give_coin'] has to be exactly the same", - ) as eq: - result = await eq.call_llm("something") - name = "james" - age = 38 - eq.set(result) - name = "dave" - another_final_result = {} - async with EquivalencePrinciple( - result=another_final_result, - principle="The result['give_coin'] has to be exactly the same", - ) as eq: - result = await eq.call_llm("something") - passing_through = final_result["output"] - location = "Spain" - eq.set(result) - age = 58 - location = "France" - return another_final_result["output"] diff --git a/backend/node/genvm/tests/code/working_code.py b/backend/node/genvm/tests/code/working_code.py deleted file mode 100644 index a8e080d08..000000000 --- a/backend/node/genvm/tests/code/working_code.py +++ /dev/null @@ -1,3 +0,0 @@ -class A: - def method1(self): - pass diff --git a/backend/node/genvm/tests/test_code_enforcement_check.py b/backend/node/genvm/tests/test_code_enforcement_check.py deleted file mode 100644 index b79a3fbe3..000000000 --- a/backend/node/genvm/tests/test_code_enforcement_check.py +++ /dev/null @@ -1,78 +0,0 @@ -# from backend.node.genvm.webpage_utils import code_enforcement_check - - -# def test_bad_code(): -# broken_code = """ -# class A: -# method1(self): -# pass -# """ -# result = code_enforcement_check(broken_code, "A") -# result["status"] == "error" -# result["message"] == "Your code is not valid Python code" -# result["data"] = [] - - -# def test_good_code(): -# with open("genvm/tests/code/working_code.py", "r") as f: -# working_code = f.read() -# result = code_enforcement_check(working_code, "A") -# result["status"] = "success" -# result["message"] = "" -# result["data"] = [] - - -# def test_good_code_class_does_not_exist(): -# with open("genvm/tests/code/working_code.py", "r") as f: -# class_name = "ThisClassDoesNotExist" -# working_code = f.read() -# result = code_enforcement_check(working_code, class_name) -# result["status"] = "error" -# result["message"] = f"The class {class_name} does not exist in the code" -# result["data"] = [] - - -# def test_catch_direct_instatntiation_of_eq_principle(): -# with open("genvm/tests/code/bad_eq_implementation.py", "r") as f: -# bad_eq_implementation = f.read() -# result = code_enforcement_check(bad_eq_implementation, "A") -# result["status"] == "error" -# result[ -# "message" -# ] == "You cannot directly instantiate the EquivalencePrinciple class" -# result["data"] = [] - - -# def test_inside_eq_with_block_modifys_self(): -# with open("genvm/tests/code/bad_eq_modifys_self.py", "r") as f: -# bad_eq_modifys_self = f.read() -# result = code_enforcement_check(bad_eq_modifys_self, "A") -# result["status"] == "error" -# result["message"] == "Self was modified inside an equivalence block" -# result["data"] = [15] - - -# def test_eq_block_variables_not_accessed_outsode_of_block(): -# with open( -# "genvm/tests/code/bad_eq_variables_accessed_outside_of_block.py", "r" -# ) as f: -# bad_eq_variables_accessed_outside_of_block = f.read() -# result = code_enforcement_check(bad_eq_variables_accessed_outside_of_block, "A") -# result["status"] == "error" -# result[ -# "message" -# ] == "Variables declared in the equivalence block are referenced outside of the equivalence block" -# result["data"] = [17] - - -# def test_eq_block_variables_not_accessed_outsode_of_block_complex(): -# with open( -# "genvm/tests/code/bad_eq_variables_accessed_outside_of_block_complex.py", "r" -# ) as f: -# bad_eq_variables_accessed_outside_of_block = f.read() -# result = code_enforcement_check(bad_eq_variables_accessed_outside_of_block, "B") -# result["status"] == "error" -# result[ -# "message" -# ] == "Variables declared in the equivalence block are referenced outside of the equivalence block" -# result["data"] = [18, 28, 29] diff --git a/backend/node/genvm/webpage_utils.py b/backend/node/genvm/webpage_utils.py deleted file mode 100644 index dacf36d28..000000000 --- a/backend/node/genvm/webpage_utils.py +++ /dev/null @@ -1,30 +0,0 @@ -import os -import requests -import re - - -def get_webpage_content(url: str, format: str = "text") -> str: - - payload = { - "jsonrpc": "2.0", - "method": "get_webpage", - "params": [url, format], - "id": 2, - } - - result = requests.post(webrequest_url() + "/api", json=payload).json() - - if result["result"]["status"] == "error": - raise Exception(result["result"]) - - return result["result"] - - -def webrequest_url(): - return ( - os.environ["WEBREQUESTPROTOCOL"] - + "://" - + os.environ["WEBREQUESTHOST"] - + ":" - + os.environ["WEBREQUESTPORT"] - ) diff --git a/backend/node/genvm/types.py b/backend/node/types.py similarity index 52% rename from backend/node/genvm/types.py rename to backend/node/types.py index d3e18f1f8..cfd4e0f64 100644 --- a/backend/node/genvm/types.py +++ b/backend/node/types.py @@ -3,24 +3,54 @@ from typing import Iterable, Optional import base64 +import collections.abc + +from eth_hash.auto import keccak + class Address: - SIZE = 32 - _as_bytes: bytes + SIZE = 20 + + __slots__ = ("_as_bytes", "_as_hex") - def __init__(self, val: str | bytes | memoryview): - if isinstance(val, memoryview): + _as_bytes: bytes + _as_hex: str | None + + def __init__(self, val: str | collections.abc.Buffer): + self._as_hex = None + if isinstance(val, str): + if len(val) == 2 + Address.SIZE * 2 and val.startswith("0x"): + val = bytes.fromhex(val[2:]) + elif len(val) > Address.SIZE: + val = base64.b64decode(val) + else: val = bytes(val) - if isinstance(val, str) or len(val) > Address.SIZE: - val = base64.b64decode(val) - if len(val) != Address.SIZE: - raise Exception("invalid address") + if not isinstance(val, bytes) or len(val) != Address.SIZE: + raise Exception(f"invalid address {val}") self._as_bytes = val @property def as_bytes(self) -> bytes: return self._as_bytes + @property + def as_hex(self) -> str: + if self._as_hex is None: + simple = self._as_bytes.hex() + low_up = keccak(simple.encode("ascii")).hex() + res = ["0", "x"] + for i in range(len(simple)): + if low_up[i] in ["0", "1", "2", "3", "4", "5", "6", "7"]: + res.append(simple[i]) + else: + res.append(simple[i].upper()) + self._as_hex = "".join(res) + return self._as_hex + + @property + def as_b64(self) -> str: + return str(base64.b64encode(self.as_bytes), encoding="ascii") + @property def as_int(self) -> int: return int.from_bytes(self._as_bytes, "little", signed=False) @@ -28,13 +58,29 @@ def as_int(self) -> int: def __hash__(self): return hash(self._as_bytes) + def __lt__(self, r): + assert isinstance(r, Address) + return self._as_bytes < r._as_bytes + + def __le__(self, r): + assert isinstance(r, Address) + return self._as_bytes <= r._as_bytes + def __eq__(self, r): if not isinstance(r, Address): return False return self._as_bytes == r._as_bytes + def __ge__(self, r): + assert isinstance(r, Address) + return self._as_bytes >= r._as_bytes + + def __gt__(self, r): + assert isinstance(r, Address) + return self._as_bytes > r._as_bytes + def __repr__(self) -> str: - return "addr:[" + "".join(["{:02x}".format(x) for x in self._as_bytes]) + "]" + return "addr#" + "".join(["{:02x}".format(x) for x in self._as_bytes]) class Vote(Enum): @@ -66,13 +112,14 @@ def to_dict(self): @dataclass class Receipt: + returned: bytes | None class_name: str calldata: bytes gas_used: int mode: ExecutionMode - contract_state: str + contract_state: dict[str, dict[str, str]] node_config: dict - eq_outputs: dict + eq_outputs: dict[int, str] execution_result: ExecutionResultStatus error: Optional[Exception] = None vote: Optional[Vote] = None @@ -82,6 +129,9 @@ def to_dict(self): return { "vote": self.vote.value, "execution_result": self.execution_result.value, + "returned": base64.b64encode( + self.returned if self.returned is not None else b"" + ).decode("ascii"), "class_name": self.class_name, "calldata": str(base64.b64encode(self.calldata), encoding="ascii"), "gas_used": self.gas_used, diff --git a/backend/protocol_rpc/endpoint_generator.py b/backend/protocol_rpc/endpoint_generator.py index 19ff3ba90..304988ba4 100644 --- a/backend/protocol_rpc/endpoint_generator.py +++ b/backend/protocol_rpc/endpoint_generator.py @@ -3,6 +3,7 @@ import inspect from typing import Callable from flask_jsonrpc import JSONRPC +import flask from flask_jsonrpc.exceptions import JSONRPCError from functools import partial, wraps @@ -38,9 +39,11 @@ def generate_rpc_endpoint( partial_function.__annotations__ = get_function_annotations(partial_function) @wraps(partial_function) - def endpoint(*endpoint_args, **endpoint_kwargs): + async def endpoint(*endpoint_args, **endpoint_kwargs): try: result = partial_function(*endpoint_args, **endpoint_kwargs) + if hasattr(result, "__await__"): + result = await result return _serialize(result) except Exception as e: diff --git a/backend/protocol_rpc/endpoints.py b/backend/protocol_rpc/endpoints.py index edf9e316c..bc7507b38 100644 --- a/backend/protocol_rpc/endpoints.py +++ b/backend/protocol_rpc/endpoints.py @@ -7,16 +7,17 @@ from sqlalchemy import Table from sqlalchemy.orm import Session +import backend.node.genvm.origin.calldata as genvm_calldata from backend.database_handler.contract_snapshot import ContractSnapshot from backend.database_handler.llm_providers import LLMProviderRegistry from backend.database_handler.models import Base -from backend.domain.types import LLMProvider, Validator +from backend.domain.types import LLMProvider, Validator, TransactionType from backend.node.create_nodes.providers import ( get_default_provider_for, validate_provider, ) -from backend.node.genvm.llms import get_llm_plugin +from backend.llms import get_llm_plugin from backend.protocol_rpc.message_handler.base import ( MessageHandler, get_client_session_id, @@ -42,7 +43,7 @@ TransactionsProcessor, ) from backend.node.base import Node -from backend.node.genvm.types import ExecutionMode +from backend.node.types import ExecutionMode, ExecutionResultStatus from flask import request @@ -81,8 +82,10 @@ def reset_defaults_llm_providers(llm_provider_registry: LLMProviderRegistry) -> llm_provider_registry.reset_defaults() -def get_providers_and_models(llm_provider_registry: LLMProviderRegistry) -> list[dict]: - return llm_provider_registry.get_all_dict() +async def get_providers_and_models( + llm_provider_registry: LLMProviderRegistry, +) -> list[dict]: + return await llm_provider_registry.get_all_dict() def add_provider(llm_provider_registry: LLMProviderRegistry, params: dict) -> int: @@ -153,23 +156,25 @@ def create_validator( ) -def create_random_validator( +async def create_random_validator( validators_registry: ValidatorsRegistry, accounts_manager: AccountsManager, llm_provider_registry: LLMProviderRegistry, stake: int, ) -> dict: - return create_random_validators( - validators_registry, - accounts_manager, - llm_provider_registry, - 1, - stake, - stake, + return ( + await create_random_validators( + validators_registry, + accounts_manager, + llm_provider_registry, + 1, + stake, + stake, + ) )[0] -def create_random_validators( +async def create_random_validators( validators_registry: ValidatorsRegistry, accounts_manager: AccountsManager, llm_provider_registry: LLMProviderRegistry, @@ -182,7 +187,7 @@ def create_random_validators( limit_providers = limit_providers or [] limit_models = limit_models or [] - details = random_validator_config( + details = await random_validator_config( llm_provider_registry.get_all, get_llm_plugin, limit_providers=set(limit_providers), @@ -278,7 +283,7 @@ def count_validators(validators_registry: ValidatorsRegistry) -> int: ####### GEN ENDPOINTS ####### -def get_contract_schema( +async def get_contract_schema( accounts_manager: AccountsManager, msg_handler: MessageHandler, contract_address: str, @@ -314,10 +319,11 @@ def get_contract_schema( msg_handler=msg_handler.with_client_session(get_client_session_id()), contract_snapshot_factory=None, ) - return node.get_contract_schema(contract_account["data"]["code"]) + schema = await node.get_contract_schema(contract_account["data"]["code"]) + return json.loads(schema) -def get_contract_schema_for_code( +async def get_contract_schema_for_code( msg_handler: MessageHandler, contract_code: str ) -> dict: node = Node( # Mock node just to get the data from the GenVM @@ -338,7 +344,7 @@ def get_contract_schema_for_code( msg_handler=msg_handler.with_client_session(get_client_session_id()), contract_snapshot_factory=None, ) - return node.get_contract_schema(contract_code) + return json.loads(await node.get_contract_schema(contract_code)) ####### ETH ENDPOINTS ####### @@ -361,11 +367,11 @@ def get_transaction_count( def get_transaction_by_hash( transactions_processor: TransactionsProcessor, transaction_hash: str -) -> dict: +) -> dict | None: return transactions_processor.get_transaction_by_hash(transaction_hash) -def call( +async def call( session: Session, accounts_manager: AccountsManager, msg_handler: MessageHandler, @@ -384,9 +390,9 @@ def call( decoded_data = decode_method_call_data(data) - contract_account = accounts_manager.get_account_or_fail(to_address) node = Node( # Mock node just to get the data from the GenVM - contract_snapshot=None, + contract_snapshot=ContractSnapshot(to_address, session), + contract_snapshot_factory=partial(ContractSnapshot, session=session), validator_mode=ExecutionMode.LEADER, validator=Validator( address="", @@ -401,14 +407,23 @@ def call( ), leader_receipt=None, msg_handler=msg_handler.with_client_session(get_client_session_id()), - contract_snapshot_factory=partial(ContractSnapshot, session=session), ) - return node.get_contract_data( - code=contract_account["data"]["code"], - state=contract_account["data"]["state"], + receipt = await node.get_contract_data( + from_address="0x" + "00" * 20, calldata=decoded_data.calldata, ) + # FIXME #621 + # this place is defective because + # - write methods can return as well and it is not supported at all in the UI + # - no calldata decoding should happen here, the frontend (caller) should be responsible for that + if receipt.execution_result != ExecutionResultStatus.SUCCESS: + return receipt.to_dict() + try: + return genvm_calldata.to_str(genvm_calldata.decode(receipt.returned)) + except: + pass + return receipt.to_dict() def send_raw_transaction( @@ -443,11 +458,11 @@ def send_raw_transaction( transaction_data = {} result = {} - transaction_type = None + transaction_type: TransactionType leader_only = False if not decoded_transaction.data: # Sending value transaction - transaction_type = 0 + transaction_type = TransactionType.SEND elif not to_address or to_address == "0x": # Contract deployment if value > 0: @@ -462,8 +477,8 @@ def send_raw_transaction( "calldata": decoded_data.calldata, } result["contract_address"] = new_contract_address - to_address = None - transaction_type = 1 + to_address = new_contract_address + transaction_type = TransactionType.DEPLOY_CONTRACT leader_only = decoded_data.leader_only else: # Contract Call @@ -473,7 +488,7 @@ def send_raw_transaction( ) decoded_data = decode_method_call_data(decoded_transaction.data) transaction_data = {"calldata": decoded_data.calldata} - transaction_type = 2 + transaction_type = TransactionType.RUN_CONTRACT leader_only = decoded_data.leader_only # Insert transaction into the database @@ -482,7 +497,7 @@ def send_raw_transaction( to_address, transaction_data, value, - transaction_type, + transaction_type.value, nonce, leader_only, ) diff --git a/backend/protocol_rpc/message_handler/base.py b/backend/protocol_rpc/message_handler/base.py index 0eb33c5d3..9eef6dd0a 100644 --- a/backend/protocol_rpc/message_handler/base.py +++ b/backend/protocol_rpc/message_handler/base.py @@ -7,6 +7,7 @@ from flask import request from loguru import logger import sys +import asyncio from backend.protocol_rpc.message_handler.types import LogEvent from flask_socketio import SocketIO @@ -95,7 +96,7 @@ def send_message(self, log_event: LogEvent, log_to_terminal: bool = True): def log_endpoint_info_wrapper(msg_handler: MessageHandler, config: GlobalConfiguration): def decorator(func): @wraps(func) - def wrapper(*args, **kwargs): + async def wrapper(*args, **kwargs): shouldPrintInfoLogs = ( func.__name__ not in config.get_disabled_info_logs_endpoints() ) @@ -112,6 +113,8 @@ def wrapper(*args, **kwargs): ) try: result = func(*args, **kwargs) + if hasattr(result, "__await__"): + result = await result if shouldPrintInfoLogs: msg_handler.send_message( LogEvent( diff --git a/backend/protocol_rpc/requirements.txt b/backend/protocol_rpc/requirements.txt index 64470883b..a1c6cfccf 100644 --- a/backend/protocol_rpc/requirements.txt +++ b/backend/protocol_rpc/requirements.txt @@ -13,14 +13,11 @@ pytest-asyncio==0.24.0 colorama==0.4.6 debugpy==1.8.5 aiohttp==3.10.5 -openai==1.47.0 -anthropic==0.34.2 SQLAlchemy[asyncio]==2.0.35 alembic==1.13.2 eth-account==0.13.3 eth-utils==5.0.0 -sentence-transformers==3.1.1 Flask-SQLAlchemy==3.1.1 jsf==0.11.2 jsonschema==4.23.0 -loguru==0.7.2 \ No newline at end of file +loguru==0.7.2 diff --git a/docker-compose.yml b/docker-compose.yml index a6e1d6890..e48796677 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -29,6 +29,9 @@ services: # TODO: remove this in production - PYTHONUNBUFFERED=1 - RPCDEBUGPORT=${RPCDEBUGPORT} + - WEBREQUESTPORT=${WEBREQUESTPORT} + - WEBREQUESTHOST=${WEBREQUESTHOST} + - WEBREQUESTPROTOCOL=${WEBREQUESTPROTOCOL} ports: - "${RPCPORT}:${RPCPORT}" - "${RPCDEBUGPORT}:${RPCDEBUGPORT}" @@ -36,10 +39,10 @@ services: - ./.env:/app/.env - ./backend:/app/backend depends_on: - ollama: - condition: service_started database-migration: condition: service_completed_successfully + webrequest: + condition: service_healthy expose: - "${RPCPORT}" @@ -49,13 +52,19 @@ services: dockerfile: ./docker/Dockerfile.webrequest shm_size: 2gb volumes: + - ./.env:/app/webrequest/.env - ./webrequest:/app/webrequest environment: - FLASK_SERVER_PORT=${WEBREQUESTPORT} + - WEBREQUESTSELENIUMPORT=${WEBREQUESTSELENIUMPORT} # TODO: remove this in production - PYTHONUNBUFFERED=1 expose: - - "${WEBREQUESTPORT}" + - "${WEBREQUESTPORT}:${WEBREQUESTPORT}" + - "${WEBREQUESTSELENIUMPORT}:${WEBREQUESTSELENIUMPORT}" + depends_on: + ollama: + condition: service_started ollama: image: ollama/ollama:0.3.11 @@ -91,6 +100,11 @@ services: dockerfile: docker/Dockerfile.database-migration environment: - DB_URL=postgresql://${DBUSER}:${DBUSER}@postgres/${DBNAME} + - WEBREQUESTPORT=${WEBREQUESTPORT} + - WEBREQUESTHOST=${WEBREQUESTHOST} + - WEBREQUESTPROTOCOL=${WEBREQUESTPROTOCOL} depends_on: postgres: condition: service_healthy + webrequest: + condition: service_healthy diff --git a/docker/Dockerfile.backend b/docker/Dockerfile.backend index 828c34b12..910c01759 100644 --- a/docker/Dockerfile.backend +++ b/docker/Dockerfile.backend @@ -1,23 +1,54 @@ FROM python:3.12.6-slim AS base +ARG TARGETPLATFORM + ARG path=/app WORKDIR $path ADD backend/protocol_rpc/requirements.txt backend/protocol_rpc/requirements.txt RUN --mount=type=cache,target=/root/.cache/pip \ - pip install --cache-dir=/root/.cache/pip torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu \ - && pip install --cache-dir=/root/.cache/pip -r backend/protocol_rpc/requirements.txt + pip install --cache-dir=/root/.cache/pip -r backend/protocol_rpc/requirements.txt RUN groupadd -r backend-group \ && useradd -r -g backend-group backend-user \ && mkdir -p /home/backend-user/.cache/huggingface \ && chown -R backend-user:backend-group /home/backend-user \ - && chown -R backend-user:backend-group $path + && chown -R backend-user:backend-group $path \ + && mkdir -p /genvm ENV PYTHONPATH "${PYTHONPATH}:/${path}" ENV FLASK_APP backend/protocol_rpc/server.py ENV HUGGINGFACE_HUB_CACHE /home/backend-user/.cache/huggingface +SHELL ["/bin/bash", "-c"] +RUN apt-get update -y && apt-get install -y --no-install-recommends curl unzip xz-utils && apt-get clean && rm -rf /var/lib/apt/lists/* +ENV RUST_BACKTRACE=1 + +# FIXME(kp2pml30): rewrite when genvm becomes public +RUN cd /genvm \ + && if [[ "$TARGETPLATFORM" == "linux/amd64" ]] ; \ + then \ + URL="https://storage.googleapis.com/gh-af/genvm_executor_be2586c6a5bb77362f2fb6bc51aa7050dfca8fda_20241120_124155/genvm_executor.zip" ; \ + elif [[ "$TARGETPLATFORM" == "linux/arm64" ]] ; \ + then \ + URL="https://storage.googleapis.com/gh-af/genvm_executor_be2586c6a5bb77362f2fb6bc51aa7050dfca8fda_20241120_124315/genvm_executor.zip" && \ + DIR="$(pwd)" && \ + mkdir -p /openssl && cd /openssl && \ + curl -L --fail-with-body -H 'Accept: application/octet-stream' -o "openssl.tar.xz" "http://mirror.archlinuxarm.org/aarch64/core/openssl-1.0-1.0.2.u-1-aarch64.pkg.tar.xz" && \ + tar -xf openssl.tar.xz && \ + cp usr/lib/*.so* /usr/lib/ && \ + cd "$DIR" || exit 1 ; \ + else echo "Sorry, $TARGETPLATFORM is not supported yet" ; exit 1 ; \ + fi \ + && curl -L --fail-with-body -H 'Accept: application/octet-stream' -o executor.zip "$URL" \ + && curl -L --fail-with-body -H 'Accept: application/octet-stream' -o runners.zip "https://storage.googleapis.com/gh-af/genvm_runners_be2586c6a5bb77362f2fb6bc51aa7050dfca8fda_20241120_125238/runners.zip" \ + && unzip runners.zip && unzip executor.zip && rm runners.zip executor.zip \ + && ls -R . \ + && cd "$path" \ + && true + +RUN su - backend-user -c '/genvm/bin/genvm precompile' + COPY ../.env . COPY backend $path/backend diff --git a/docker/Dockerfile.webrequest b/docker/Dockerfile.webrequest index 370bf7e9a..8019a2140 100644 --- a/docker/Dockerfile.webrequest +++ b/docker/Dockerfile.webrequest @@ -1,26 +1,38 @@ FROM --platform=linux/amd64 python:3.12.6-slim -RUN apt-get update && \ - apt-get install -y --no-install-recommends wget gnupg && \ - wget -q -O - https://dl-ssl.google.com/linux/linux_signing_key.pub | apt-key add - && \ - echo "deb [arch=amd64] http://dl.google.com/linux/chrome/deb/ stable main" >> /etc/apt/sources.list.d/google.list && \ +ENV base=/app +ENV path=webrequest +ENV selpath=selenium + +# These are all the requirements to run chrome and it's webdriver + selenium +# see https://source.chromium.org/chromium/chromium/src/+/main:chrome/installer/linux/debian/dist_package_versions.json +RUN mkdir -p "${base}/${path}" && \ + mkdir -p "${base}/${selpath}" && \ apt-get update && \ - apt-get install -y --no-install-recommends google-chrome-stable && \ + apt-get install -y --no-install-recommends wget gnupg ca-certificates unzip curl \ + openjdk-17-jre-headless fonts-liberation libasound2 libatk-bridge2.0-0 libatk1.0-0 libc6 libcairo2 libcups2 libdbus-1-3 libexpat1 libfontconfig1 libgbm1 libgcc1 libglib2.0-0 libgtk-3-0 libnspr4 libnss3 libpango-1.0-0 libpangocairo-1.0-0 libstdc++6 libx11-6 libx11-xcb1 libxcb1 libxcomposite1 libxcursor1 libxdamage1 libxext6 libxfixes3 libxi6 libxrandr2 libxrender1 libxss1 libxtst6 lsb-release xdg-utils && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* && \ - useradd -m appuser + useradd -m appuser && \ + mkdir -p /selenium && \ + cd "${base}/${selpath}" && \ + wget -q https://github.com/SeleniumHQ/selenium/releases/download/selenium-4.24.0/selenium-server-4.24.0.jar && \ + true -ENV base=/app -ENV path=webrequest ENV PYTHONPATH="${base}/${path}" -ENV PATH="${PATH}:${base}" -RUN mkdir -p $base && chown -R appuser:appuser $base +ENV PATH=":${PATH}" +RUN chown -R appuser:appuser $base USER appuser WORKDIR $base COPY --chown=appuser:appuser $path $base/$path RUN --mount=type=cache,target=/root/.cache/pip \ - pip install --cache-dir=/root/.cache/pip -r $path/requirements.txt -WORKDIR $path -COPY --chown=appuser:appuser ../.env . -CMD ["python3", "server.py"] +pip install --cache-dir=/root/.cache/pip -r $path/requirements.txt + +# TODO: download and cache webdriver here + +HEALTHCHECK --interval=10s --timeout=5s --retries=15 --start-period=5s CMD curl "http://0.0.0.0:$WEBREQUESTSELENIUMPORT/status" | grep -P "\"ready\"\s*:\s*true\b" + +COPY webrequest/docker-entry.sh "${base}/webrequest-docker-entry.sh" +WORKDIR $base/$path +CMD "${base}/webrequest-docker-entry.sh" diff --git a/examples/contracts/_hello_world.py b/examples/contracts/_hello_world.py new file mode 100644 index 000000000..57104b387 --- /dev/null +++ b/examples/contracts/_hello_world.py @@ -0,0 +1,42 @@ +# { "Depends": "py-genlayer:test" } + +# Always put above line as first in the contract file +# Upon release it will be changed from `:test` to `:` and library will be frozen forever + +# this imports all types into globals and `genlayer.std` as `gl` (will be imported lazily on first access) +from genlayer import * + + +# use @gl.contract annotation to mark class as a contract. There can be only one contract class +@gl.contract +class Storage: + # below you must declare all class fields that you are going to use + # this fields persist between contract calls + storage_str: str + storage_int: u256 # NOTE: `int`s are intentionally not supported! in future `bigint` int alias will be introduced + + # all public methods must have type annotations to be user friendly + + # constructor, must not be public + def __init__(self, initial_str_storage: str): + self.storage_str = initial_str_storage + + # methods that don't modify anything must be annotated with view + @gl.public.view + def get_storage(self) -> str: + return self.storage_str + + # keyword arguments are supported as well, however, they should not be mixed with positional, + # in python terms it means that function has following signature (note `/`) + # def debug(self, x: int, /, *, flag: bool) -> str: + @gl.public.view + def debug(self, x: int, *, flag: bool) -> None: + # you can use prints for debugging (even in write methods and non deterministic blocks) + # however, stdout doesn't go through consensus and is meant for debug use only + # it also may be absent in the actual node + print(f"debug: {self.storage_int}, {x}, {flag}") + + # methods that modify storage must be annotated with write + @gl.public.write + def update_storage(self, new_storage: str) -> None: + self.storage_str = new_storage diff --git a/examples/contracts/football_prediction_market.py b/examples/contracts/football_prediction_market.py index 61a494431..745efca16 100644 --- a/examples/contracts/football_prediction_market.py +++ b/examples/contracts/football_prediction_market.py @@ -1,9 +1,19 @@ +# { "Depends": "py-genlayer:test" } + +from genlayer import * + import json -from backend.node.genvm.icontract import IContract -from backend.node.genvm.equivalence_principle import EquivalencePrinciple +import typing + +@gl.contract +class PredictionMarket: + has_resolved: bool + game_date: str + team1: str + team2: str + resolution_url: str -class PredictionMarket(IContract): def __init__(self, game_date: str, team1: str, team2: str): """ Initializes a new instance of the prediction market with the specified game date and teams. @@ -28,18 +38,14 @@ def __init__(self, game_date: str, team1: str, team2: str): self.team1 = team1 self.team2 = team2 - async def resolve(self) -> None: + @gl.public.write + def resolve(self) -> typing.Any: if self.has_resolved: return "Already resolved" - final_result = {} - async with EquivalencePrinciple( - result=final_result, - principle="The score and the winner has to be exactly the same", - comparative=True, - ) as eq: - web_data = await eq.get_webpage(self.resolution_url, "text") + def nondet() -> str: + web_data = gl.get_webpage(self.resolution_url, mode="text") print(web_data) task = f"""In the following web page, find the winning team in a matchup between the following teams: @@ -61,13 +67,13 @@ async def resolve(self) -> None: It is mandatory that you respond only using the JSON format above, nothing else. Don't include any other words or characters, your output must be only JSON without any formatting prefix or suffix. - This result should be perfectly parseable by a JSON parser without errors. + This result should be perfectly parsable by a JSON parser without errors. """ - result = await eq.call_llm(task) + result = gl.exec_prompt(task).replace("```json", "").replace("```", "") print(result) - eq.set(result) + return json.dumps(json.loads(result), sort_keys=True) - result_json = json.loads(final_result["output"]) + result_json = json.loads(gl.eq_principle_strict_eq(nondet)) if result_json["winner"] > -1: self.has_resolved = True diff --git a/examples/contracts/llm_erc20.py b/examples/contracts/llm_erc20.py index a5696d68e..5ded2497c 100644 --- a/examples/contracts/llm_erc20.py +++ b/examples/contracts/llm_erc20.py @@ -1,21 +1,26 @@ +# { "Depends": "py-genlayer:test" } + import json -from backend.node.genvm.icontract import IContract -from backend.node.genvm.equivalence_principle import EquivalencePrinciple +from genlayer import * + + +@gl.contract +class LlmErc20: + balances: TreeMap[Address, u256] -class LlmErc20(IContract): def __init__(self, total_supply: int) -> None: - self.balances = {} - self.balances[contract_runner.from_address] = total_supply + self.balances[gl.message.sender_account] = u256(total_supply) - async def transfer(self, amount: int, to_address: str) -> None: + @gl.public.write + def transfer(self, amount: int, to_address: str) -> None: prompt = f""" You keep track of transactions between users and their balance in coins. The current balance for all users in JSON format is: -{json.dumps(self.balances)} +{json.dumps(self.get_balances())} The transaction to compute is: {{ -sender: "{contract_runner.from_address}", -recipient: "{to_address}", +sender: "{gl.message.sender_account.as_hex}", +recipient: "{Address(to_address).as_hex}", amount: {amount}, }} @@ -32,28 +37,30 @@ async def transfer(self, amount: int, to_address: str) -> None: It is mandatory that you respond only using the JSON format above, nothing else. Don't include any other words or characters, your output must be only JSON without any formatting prefix or suffix. -This result should be perfectly parseable by a JSON parser without errors.""" +This result should be perfectly parsable by a JSON parser without errors.""" print(prompt) - final_result = {} - async with EquivalencePrinciple( - result=final_result, - principle="""The new_balance of the sender should have decreased + + def run(): + res = gl.exec_prompt(prompt) + res = res.replace("```json", "").replace("```", "") + return res + + final_result = gl.eq_principle_prompt_comparative( + run, + """The new_balance of the sender should have decreased in the amount sent and the new_balance of the receiver should have increased by the amount sent. Also, the total sum of all balances should have remain the same before and after the transaction""", - comparative=True, - ) as eq: - result = await eq.call_llm(prompt) - result_clean = result.replace("True", "true").replace("False", "false") - eq.set(result_clean) - + ) print("final_result: ", final_result) - print("final_result[output]: ", final_result["output"]) - result_json = json.loads(final_result["output"]) - self.balances = result_json["updated_balances"] + result_json = json.loads(final_result) + for k, v in result_json["updated_balances"].items(): + self.balances[Address(k)] = v + @gl.public.view def get_balances(self) -> dict[str, int]: - return self.balances + return {k.as_hex: v for k, v in self.balances.items()} + @gl.public.view def get_balance_of(self, address: str) -> int: - return self.balances.get(address, 0) + return self.balances.get(Address(address), 0) diff --git a/examples/contracts/log_indexer.py b/examples/contracts/log_indexer.py index b2f3bc4a0..f4f40c984 100644 --- a/examples/contracts/log_indexer.py +++ b/examples/contracts/log_indexer.py @@ -1,37 +1,66 @@ -from backend.node.genvm.icontract import IContract -from backend.node.genvm.std.vector_store import VectorStore +# { +# "Seq": [ +# { "Depends": "py-lib-genlayermodelwrappers:test" }, +# { "Depends": "py-genlayer:test" } +# ] +# } + +from genlayer import * +import genlayermodelwrappers +import numpy as np +from dataclasses import dataclass + + +@dataclass +class StoreValue: + log_id: u256 + text: str # contract class -class LogIndexer(IContract): +@gl.contract +class LogIndexer: + vector_store: VecDB[np.float32, typing.Literal[384], StoreValue] - # constructor def __init__(self): - self.vector_store = VectorStore() + pass - # read methods must start with get_ - def get_closest_vector(self, text: str) -> dict: - result = self.vector_store.get_closest_vector(text) - if result is None: + def get_emb_generator(self): + return genlayermodelwrappers.SentenceTransformer("all-MiniLM-L6-v2") + + def get_emb( + self, txt: str + ) -> np.ndarray[tuple[typing.Literal[384]], np.dtypes.Float32DType]: + return self.get_emb_generator()(txt) + + @gl.public.view + def get_closest_vector(self, text: str) -> dict | None: + emb = self.get_emb(text) + result = list(self.vector_store.knn(emb, 1)) + if len(result) == 0: return None + result = result[0] return { - "similarity": result[0], - "id": result[1], - "text": result[2], - "metadata": result[3], - "vector": result[4], + "vector": list(str(x) for x in result.key), + "similarity": str(1 - result.distance), + "id": result.value.log_id, + "text": result.value.text, } - # write method + @gl.public.write def add_log(self, log: str, log_id: int) -> None: - self.vector_store.add_text(log, {"log_id": log_id}) + emb = self.get_emb(log) + self.vector_store.insert(emb, StoreValue(text=log, log_id=u256(log_id))) - def update_log(self, id: int, log: str, log_id: int) -> None: - self.vector_store.update_text(id, log, {"log_id": log_id}) + @gl.public.write + def update_log(self, log_id: int, log: str) -> None: + emb = self.get_emb(log) + for elem in self.vector_store.knn(emb, 2): + if elem.value.text == log: + elem.value.log_id = u256(log_id) + @gl.public.write def remove_log(self, id: int) -> None: - self.vector_store.delete_vector(id) - - def get_vector_metadata(self, id: int) -> None: - _, metadata, _ = self.vector_store.get_vector(id) - return metadata + for el in self.vector_store: + if el.value.log_id == id: + el.remove() diff --git a/examples/contracts/storage.py b/examples/contracts/storage.py index 55e7948b5..3fdeac8d7 100644 --- a/examples/contracts/storage.py +++ b/examples/contracts/storage.py @@ -1,17 +1,23 @@ -from backend.node.genvm.icontract import IContract +# { "Depends": "py-genlayer:test" } + +from genlayer import * # contract class -class Storage(IContract): +@gl.contract +class Storage: + storage: str # constructor def __init__(self, initial_storage: str): self.storage = initial_storage - # read methods must start with get_ + # read methods must be annotated with view + @gl.public.view def get_storage(self) -> str: return self.storage # write method + @gl.public.write def update_storage(self, new_storage: str) -> None: self.storage = new_storage diff --git a/examples/contracts/user_storage.py b/examples/contracts/user_storage.py index 90478a2d2..dcc3520c4 100644 --- a/examples/contracts/user_storage.py +++ b/examples/contracts/user_storage.py @@ -1,20 +1,25 @@ -from backend.node.genvm.icontract import IContract +# { "Depends": "py-genlayer:test" } +from genlayer import * -# contract class -class UserStorage(IContract): + +@gl.contract +class UserStorage: + storage: TreeMap[Address, str] # constructor def __init__(self): - self.storage = {} + pass - # read methods must start with get_ - def get_complete_storage(self) -> dict: - return self.storage + # read methods must be annotated + @gl.public.view + def get_complete_storage(self) -> dict[str, str]: + return {k.as_hex: v for k, v in self.storage.items()} + @gl.public.view def get_account_storage(self, account_address: str) -> str: - return self.storage[account_address] + return self.storage[Address(account_address)] - # write method + @gl.public.write def update_storage(self, new_storage: str) -> None: - self.storage[contract_runner.from_address] = new_storage + self.storage[gl.message.sender_account] = new_storage diff --git a/examples/contracts/wizard_of_coin.py b/examples/contracts/wizard_of_coin.py index f4a60593d..c74da1817 100644 --- a/examples/contracts/wizard_of_coin.py +++ b/examples/contracts/wizard_of_coin.py @@ -1,13 +1,20 @@ +# { "Depends": "py-genlayer:test" } +from genlayer import * + import json -from backend.node.genvm.icontract import IContract -from backend.node.genvm.equivalence_principle import call_llm_with_principle -class WizardOfCoin(IContract): +@gl.contract +class WizardOfCoin: + have_coin: bool + def __init__(self, have_coin: bool): self.have_coin = have_coin - async def ask_for_coin(self, request: str) -> None: + @gl.public.write + def ask_for_coin(self, request: str) -> None: + if not self.have_coin: + return prompt = f""" You are a wizard, and you hold a magical coin. Many adventurers will come and try to get you to give them the coin. @@ -30,17 +37,18 @@ async def ask_for_coin(self, request: str) -> None: your output must be only JSON without any formatting prefix or suffix. This result should be perfectly parseable by a JSON parser without errors. """ - if self.have_coin: - # that must be awaited - result = await call_llm_with_principle( - prompt, - eq_principle="The result['give_coin'] has to be exactly the same", - ) - result_clean = result.replace("True", "true").replace("False", "false") - output = json.loads(result_clean) - - if output["give_coin"] is True: - self.have_coin = False + def nondet(): + res = gl.exec_prompt(prompt) + res = res.replace("```json", "").replace("```", "") + print(res) + dat = json.loads(res) + return dat["give_coin"] + + result = gl.eq_principle_strict_eq(nondet) + assert isinstance(result, bool) + self.have_coin = result + + @gl.public.view def get_have_coin(self) -> bool: return self.have_coin diff --git a/frontend/src/components/Simulator/ConstructorParameters.vue b/frontend/src/components/Simulator/ConstructorParameters.vue index 17f88f13c..1f1842368 100644 --- a/frontend/src/components/Simulator/ConstructorParameters.vue +++ b/frontend/src/components/Simulator/ConstructorParameters.vue @@ -1,11 +1,11 @@