Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add estimate_fee to Account #1279

Merged
merged 18 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions starknet_py/net/account/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ async def _get_max_fee(
)

if auto_estimate:
estimated_fee = await self._estimate_fee(transaction)
estimated_fee = await self.estimate_fee(transaction)
assert isinstance(estimated_fee, EstimatedFee)

max_fee = int(estimated_fee.overall_fee * Account.ESTIMATED_FEE_MULTIPLIER)

if max_fee is None:
Expand All @@ -158,7 +160,9 @@ async def _get_resource_bounds(
)

if auto_estimate:
estimated_fee = await self._estimate_fee(transaction)
estimated_fee = await self.estimate_fee(transaction)
assert isinstance(estimated_fee, EstimatedFee)

l1_resource_bounds = ResourceBounds(
max_amount=int(
estimated_fee.gas_consumed * Account.ESTIMATED_AMOUNT_MULTIPLIER
Expand Down Expand Up @@ -246,28 +250,25 @@ async def _prepare_invoke_v3(
)
return _add_resource_bounds_to_transaction(transaction, resource_bounds)

async def _estimate_fee(
async def estimate_fee(
self,
tx: AccountTransaction,
tx: Union[AccountTransaction, List[AccountTransaction]],
skip_validate: bool = False,
block_hash: Optional[Union[Hash, Tag]] = None,
block_number: Optional[Union[int, Tag]] = None,
) -> EstimatedFee:
"""
:param tx: Transaction which fee we want to calculate.
:param block_hash: a block hash.
:param block_number: a block number.
:return: Estimated fee.
"""
tx = await self.sign_for_fee_estimate(tx)
) -> Union[EstimatedFee, List[EstimatedFee]]:
transactions = (
await self.sign_for_fee_estimate(tx)
if isinstance(tx, AccountTransaction)
else [await self.sign_for_fee_estimate(t) for t in tx]
)

estimated_fee = await self._client.estimate_fee(
tx=tx,
return await self._client.estimate_fee(
tx=transactions,
skip_validate=skip_validate,
block_hash=block_hash,
block_number=block_number,
)
assert isinstance(estimated_fee, EstimatedFee)

return estimated_fee

async def get_nonce(
self,
Expand Down
23 changes: 23 additions & 0 deletions starknet_py/net/account/base_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
from starknet_py.net.client import Client
from starknet_py.net.client_models import (
Calls,
EstimatedFee,
Hash,
ResourceBounds,
SentTransactionResponse,
Tag,
)
from starknet_py.net.models import AddressRepresentation, StarknetChainId
from starknet_py.net.models.transaction import (
AccountTransaction,
DeclareV1,
DeclareV2,
DeclareV3,
Expand Down Expand Up @@ -51,6 +53,27 @@ def client(self) -> Client:
Get the Client used by the Account.
"""

@abstractmethod
async def estimate_fee(
self,
tx: Union[AccountTransaction, List[AccountTransaction]],
skip_validate: bool = False,
block_hash: Optional[Union[Hash, Tag]] = None,
block_number: Optional[Union[int, Tag]] = None,
) -> Union[EstimatedFee, List[EstimatedFee]]:
"""
Estimates the resources required by a given sequence of transactions when applied on a given state.
If one of the transactions reverts or fails due to any reason (e.g. validation failure or an internal error),
a TRANSACTION_EXECUTION_ERROR is returned.
For v0-2 transactions the estimate is given in Wei, and for v3 transactions it is given in Fri.

:param tx: Transaction or list of transactions to estimate
:param skip_validate: Flag checking whether the validation part of the transaction should be executed
:param block_hash: Block hash or literals `"pending"` or `"latest"`
:param block_number: Block number or literals `"pending"` or `"latest"`
:return: Estimated fee or list of estimated fees for each transaction
"""

@abstractmethod
async def get_nonce(
self,
Expand Down
54 changes: 54 additions & 0 deletions starknet_py/tests/e2e/account/account_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
DeployAccountTransactionV3,
EstimatedFee,
InvokeTransactionV3,
PriceUnit,
ResourceBounds,
ResourceBoundsMapping,
SierraContractClass,
Expand Down Expand Up @@ -98,6 +99,59 @@ async def test_estimate_fee_for_declare_transaction(account, map_compiled_contra
)


@pytest.mark.asyncio
async def test_account_estimate_fee_for_declare_transaction(
tkumor3 marked this conversation as resolved.
Show resolved Hide resolved
account, map_compiled_contract
):
declare_tx = await account.sign_declare_v1(
compiled_contract=map_compiled_contract, max_fee=MAX_FEE
)

estimated_fee = await account.estimate_fee(tx=declare_tx)

assert estimated_fee.unit == PriceUnit.WEI
assert isinstance(estimated_fee.overall_fee, int)
tkumor3 marked this conversation as resolved.
Show resolved Hide resolved
assert estimated_fee.overall_fee > 0
assert (
estimated_fee.gas_consumed * estimated_fee.gas_price
== estimated_fee.overall_fee
)


@pytest.mark.asyncio
async def test_account_estimate_fee_for_transactions(
account, map_compiled_contract, abi_types_compiled_contract_and_class_hash
):
nonce = await account.get_nonce(block_hash="pending")
print(nonce)
declare_tx = await account.sign_declare_v1(
compiled_contract=map_compiled_contract, max_fee=MAX_FEE
)

(
compiled_contract,
compiled_class_hash,
) = abi_types_compiled_contract_and_class_hash

declare_tx2 = await account.sign_declare_v3(
compiled_contract=compiled_contract,
compiled_class_hash=compiled_class_hash,
l1_resource_bounds=MAX_RESOURCE_BOUNDS_L1,
nonce=(declare_tx.nonce + 1),
)

estimated_fee = await account.estimate_fee(tx=[declare_tx, declare_tx2])
tkumor3 marked this conversation as resolved.
Show resolved Hide resolved

tkumor3 marked this conversation as resolved.
Show resolved Hide resolved
assert isinstance(estimated_fee[0], EstimatedFee)
assert isinstance(estimated_fee[1], EstimatedFee)
assert isinstance(estimated_fee[0].overall_fee, int)
assert estimated_fee[0].overall_fee > 0
assert (
estimated_fee[0].gas_consumed * estimated_fee[0].gas_price
== estimated_fee[0].overall_fee
)


@pytest.mark.asyncio
@pytest.mark.parametrize("key, val", [(20, 20), (30, 30)])
async def test_sending_multicall(account, map_contract, key, val):
Expand Down
Loading