Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 31 additions & 10 deletions hathor/nanocontracts/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
from collections import defaultdict
from itertools import chain
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Sequence, final
from typing import TYPE_CHECKING, Any, Sequence, assert_never, final

from hathor.crypto.util import get_address_b58_from_bytes
from hathor.nanocontracts.exception import NCFail, NCInvalidContext
from hathor.nanocontracts.types import Address, ContractId, NCAction, TokenUid
from hathor.nanocontracts.types import Address, CallerId, ContractId, NCAction, TokenUid
from hathor.nanocontracts.vertex_data import VertexData
from hathor.transaction.exceptions import TxValidationError

Expand All @@ -39,17 +39,17 @@ class Context:
Deposits and withdrawals are grouped by token. Note that it is impossible
to have both a deposit and a withdrawal for the same token.
"""
__slots__ = ('__actions', '__address', '__vertex', '__timestamp', '__all_actions__')
__slots__ = ('__actions', '__caller_id', '__vertex', '__timestamp', '__all_actions__')
__actions: MappingProxyType[TokenUid, tuple[NCAction, ...]]
__address: Address | ContractId
__caller_id: CallerId
__vertex: VertexData
__timestamp: int

def __init__(
self,
actions: Sequence[NCAction],
vertex: BaseTransaction | VertexData,
address: Address | ContractId,
caller_id: CallerId,
timestamp: int,
) -> None:
# Dict of action where the key is the token_uid.
Expand Down Expand Up @@ -77,7 +77,7 @@ def __init__(
self.__vertex = VertexData.create_from_vertex(vertex)

# Address calling the method.
self.__address = address
self.__caller_id = caller_id

# Timestamp of the first block confirming tx.
self.__timestamp = timestamp
Expand All @@ -87,8 +87,29 @@ def vertex(self) -> VertexData:
return self.__vertex

@property
def address(self) -> Address | ContractId:
return self.__address
def caller_id(self) -> CallerId:
"""Get the caller ID which can be either an Address or a ContractId."""
return self.__caller_id

def get_caller_address(self) -> Address | None:
"""Get the caller address if the caller is an address, None if it's a contract."""
match self.caller_id:
case Address():
return self.caller_id
case ContractId():
return None
case _:
assert_never(self.caller_id)

def get_caller_contract_id(self) -> ContractId | None:
"""Get the caller contract ID if the caller is a contract, None if it's an address."""
match self.caller_id:
case Address():
return None
case ContractId():
return self.caller_id
case _:
assert_never(self.caller_id)

@property
def timestamp(self) -> int:
Expand Down Expand Up @@ -116,14 +137,14 @@ def copy(self) -> Context:
return Context(
actions=list(self.__all_actions__),
vertex=self.vertex,
address=self.address,
caller_id=self.caller_id,
timestamp=self.timestamp,
)

def to_json(self) -> dict[str, Any]:
"""Return a JSON representation of the context."""
return {
'actions': [action.to_json() for action in self.__all_actions__],
'address': get_address_b58_from_bytes(self.address),
'caller_id': get_address_b58_from_bytes(self.caller_id),
'timestamp': self.timestamp,
}
3 changes: 3 additions & 0 deletions hathor/nanocontracts/nc_types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from hathor.nanocontracts.nc_types.address_nc_type import AddressNCType
from hathor.nanocontracts.nc_types.bool_nc_type import BoolNCType
from hathor.nanocontracts.nc_types.bytes_nc_type import BytesLikeNCType, BytesNCType
from hathor.nanocontracts.nc_types.caller_id_nc_type import CallerIdNCType
from hathor.nanocontracts.nc_types.collection_nc_type import DequeNCType, FrozenSetNCType, ListNCType, SetNCType
from hathor.nanocontracts.nc_types.dataclass_nc_type import DataclassNCType
from hathor.nanocontracts.nc_types.fixed_size_bytes_nc_type import Bytes32NCType
Expand Down Expand Up @@ -56,6 +57,7 @@
'BoolNCType',
'BytesLikeNCType',
'BytesNCType',
'CallerIdNCType',
'DataclassNCType',
'DequeNCType',
'DictNCType',
Expand Down Expand Up @@ -124,6 +126,7 @@
TxOutputScript: BytesLikeNCType[TxOutputScript],
VertexId: Bytes32NCType,
SignedData: SignedDataNCType,
(Address, ContractId): CallerIdNCType,
}

# This mapping includes all supported NCType classes, should only be used for parsing function calls
Expand Down
115 changes: 115 additions & 0 deletions hathor/nanocontracts/nc_types/caller_id_nc_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright 2025 Hathor Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from types import UnionType
from typing import _UnionGenericAlias as UnionGenericAlias, assert_never, get_args # type: ignore[attr-defined]

from typing_extensions import Self, override

from hathor.crypto.util import decode_address, get_address_b58_from_bytes
from hathor.nanocontracts.nc_types.nc_type import NCType
from hathor.nanocontracts.types import Address, CallerId, ContractId
from hathor.serialization import Deserializer, Serializer
from hathor.serialization.compound_encoding.caller_id import decode_caller_id, encode_caller_id
from hathor.transaction.base_transaction import TX_HASH_SIZE
from hathor.transaction.headers.nano_header import ADDRESS_LEN_BYTES


class CallerIdNCType(NCType[CallerId]):
"""Represents `CallerID` values, which can be `Address` or `ContractId`."""
__slots__ = ()
_is_hashable = True

@override
@classmethod
def _from_type(cls, type_: type[Address] | type[ContractId], /, *, type_map: NCType.TypeMap) -> Self:
if not isinstance(type_, (UnionType, UnionGenericAlias)):
raise TypeError('expected type union')
args = get_args(type_)
assert args, 'union always has args'
if len(args) != 2 or Address not in args or ContractId not in args:
raise TypeError('type must be either `Address | ContractId` or `ContractId | Address`')
return cls()

@override
def _check_value(self, value: CallerId, /, *, deep: bool) -> None:
match value:
case Address():
if len(value) != ADDRESS_LEN_BYTES:
raise ValueError(f'an address must always have {ADDRESS_LEN_BYTES} bytes')
case ContractId():
if len(value) != TX_HASH_SIZE:
raise ValueError(f'an contract id must always have {TX_HASH_SIZE} bytes')
case _:
assert_never(value)

@override
def _serialize(self, serializer: Serializer, value: CallerId, /) -> None:
encode_caller_id(serializer, value)

@override
def _deserialize(self, deserializer: Deserializer, /) -> CallerId:
return decode_caller_id(deserializer)

@override
def _json_to_value(self, json_value: NCType.Json, /) -> CallerId:
"""
>>> nc_type = CallerIdNCType()
>>> value = nc_type.json_to_value('HH5As5aLtzFkcbmbXZmE65wSd22GqPWq2T')
>>> isinstance(value, Address)
True
>>> value == Address(bytes.fromhex('2873c0a326af979a12be89ee8a00e8871c8e2765022e9b803c'))
True
>>> contract_id = ContractId(b'\x11' * 32)
>>> value = nc_type.json_to_value(contract_id.hex())
>>> isinstance(value, ContractId)
True
>>> value == contract_id
True
>>> nc_type.json_to_value('foo')
Traceback (most recent call last):
...
ValueError: cannot decode "foo" as CallerId
"""
if not isinstance(json_value, str):
raise ValueError('expected str')

if len(json_value) == 34:
return Address(decode_address(json_value))

if len(json_value) == TX_HASH_SIZE * 2:
return ContractId(bytes.fromhex(json_value))

raise ValueError(f'cannot decode "{json_value}" as CallerId')

@override
def _value_to_json(self, value: CallerId, /) -> NCType.Json:
"""
>>> nc_type = CallerIdNCType()
>>> address = Address(bytes.fromhex('2873c0a326af979a12be89ee8a00e8871c8e2765022e9b803c'))
>>> nc_type.value_to_json(address)
'HH5As5aLtzFkcbmbXZmE65wSd22GqPWq2T'
>>> contract_id = ContractId(b'\x11' * 32)
>>> nc_type.value_to_json(contract_id)
'1111111111111111111111111111111111111111111111111111111111111111'
"""
match value:
case Address():
return get_address_b58_from_bytes(value)
case ContractId():
return value.hex()
case _:
assert_never(value)
15 changes: 12 additions & 3 deletions hathor/nanocontracts/nc_types/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

T = TypeVar('T')
TypeAliasMap: TypeAlias = Mapping[type | UnionType, type]
TypeToNCTypeMap: TypeAlias = Mapping[type | UnionType, type['NCType']]
TypeToNCTypeMap: TypeAlias = Mapping[type | UnionType | tuple[type, ...], type['NCType']]


def get_origin_classes(type_: type) -> Iterator[type]:
Expand Down Expand Up @@ -218,7 +218,7 @@ def get_usable_origin_type(
*,
type_map: 'NCType.TypeMap',
_verbose: bool = True,
) -> type:
) -> type | tuple[type, ...]:
""" The purpose of this function is to map a given type into a type that is usable in a NCType.TypeMap

It takes into account type-aliasing according to NCType.TypeMap.alias_map. If the given type cannot be used in the
Expand All @@ -243,7 +243,16 @@ def get_usable_origin_type(

# if we have a `dict[int, int]` we use `get_origin()` to get the `dict` part, since it's a different instance
aliased_type: type = get_aliased_type(type_, type_map.alias_map, _verbose=_verbose)
origin_aliased_type: type = get_origin(aliased_type) or aliased_type
origin_aliased_type: type | tuple[type, ...] = get_origin(aliased_type) or aliased_type

if origin_aliased_type is UnionType:
# When it's an union and None is not in it, it's not Optional,
# so we must index by args which is a tuple of types.
# This is done for support of specific union types such as CallerId (Address | ContractId)
args = get_args(aliased_type)
assert args is not None
if NoneType not in args:
origin_aliased_type = args

if origin_aliased_type in type_map.nc_types_map:
return origin_aliased_type
Expand Down
2 changes: 1 addition & 1 deletion hathor/nanocontracts/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def _unsafe_call_another_contract_public_method(
ctx = Context(
actions=actions,
vertex=first_ctx.vertex,
address=last_call_record.contract_id,
caller_id=last_call_record.contract_id,
timestamp=first_ctx.timestamp,
)
return self._execute_public_method_call(
Expand Down
17 changes: 14 additions & 3 deletions hathor/nanocontracts/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,26 @@
from hathor.transaction.util import bytes_to_int, int_to_bytes
from hathor.utils.typing import InnerTypeMixin


# Types to be used by blueprints.
Address = NewType('Address', bytes)
class Address(bytes):
__slots__ = ()


class VertexId(bytes):
__slots__ = ()


class ContractId(VertexId):
__slots__ = ()


Amount = NewType('Amount', int)
Timestamp = NewType('Timestamp', int)
TokenUid = NewType('TokenUid', bytes)
TxOutputScript = NewType('TxOutputScript', bytes)
VertexId = NewType('VertexId', bytes)
BlueprintId = NewType('BlueprintId', VertexId)
ContractId = NewType('ContractId', VertexId)
CallerId: TypeAlias = Address | ContractId

T = TypeVar('T')

Expand Down
80 changes: 80 additions & 0 deletions hathor/serialization/compound_encoding/caller_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2025 Hathor Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""
A caller ID union type is encoded with a single byte identifier followed by the encoded value according to the type.

Layout:

[0x00][address] when Address
[0x01][contract_id] when ContractId

>>> from hathor.nanocontracts.types import Address, ContractId
>>> se = Serializer.build_bytes_serializer()
>>> addr = Address(b'\x11' * 25)
>>> encode_caller_id(se, addr)
>>> bytes(se.finalize()).hex()
'0011111111111111111111111111111111111111111111111111'

>>> se = Serializer.build_bytes_serializer()
>>> contract_id = ContractId(b'\x22' * 32)
>>> encode_caller_id(se, contract_id)
>>> bytes(se.finalize()).hex()
'012222222222222222222222222222222222222222222222222222222222222222'

>>> de = Deserializer.build_bytes_deserializer(bytes.fromhex('0011111111111111111111111111111111111111111111111111'))
>>> result = decode_caller_id(de)
>>> isinstance(result, Address)
True
>>> de.finalize()

>>> value = bytes.fromhex('012222222222222222222222222222222222222222222222222222222222222222')
>>> de = Deserializer.build_bytes_deserializer(value)
>>> result = decode_caller_id(de)
>>> isinstance(result, ContractId)
True
>>> de.finalize()
"""

from typing import assert_never

from hathor.nanocontracts.types import Address, CallerId, ContractId
from hathor.serialization import Deserializer, Serializer
from hathor.serialization.encoding.bool import decode_bool, encode_bool

from ...transaction.base_transaction import TX_HASH_SIZE
from ...transaction.headers.nano_header import ADDRESS_LEN_BYTES


def encode_caller_id(serializer: Serializer, value: CallerId) -> None:
match value:
case Address():
assert len(value) == ADDRESS_LEN_BYTES
encode_bool(serializer, False)
case ContractId():
assert len(value) == TX_HASH_SIZE
encode_bool(serializer, True)
case _:
assert_never(value)
serializer.write_bytes(value)


def decode_caller_id(deserializer: Deserializer) -> CallerId:
is_contract = decode_bool(deserializer)
if is_contract:
data = bytes(deserializer.read_bytes(TX_HASH_SIZE))
return ContractId(data)
else:
data = bytes(deserializer.read_bytes(ADDRESS_LEN_BYTES))
return Address(data)
2 changes: 1 addition & 1 deletion hathor/transaction/headers/nano_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def get_context(self) -> Context:
context = Context(
actions=action_list,
vertex=self.tx,
address=Address(self.nc_address),
caller_id=Address(self.nc_address),
timestamp=timestamp,
)
return context
Loading
Loading