Skip to content

Commit

Permalink
Add estimate_fee to Account (#1279)
Browse files Browse the repository at this point in the history
* Add estimate_fee to Account

* fix

* add support list od transaction and allows use skip_validate

* update

* remove print

* update

* Update starknet_py/net/account/base_account.py

Co-authored-by: ddoktorski <[email protected]>

* Update starknet_py/net/account/account.py

Co-authored-by: ddoktorski <[email protected]>

* fmt

* add test

* fix test

* update test

* use two distinct transaction types in test

* feedback

---------

Co-authored-by: ddoktorski <[email protected]>
  • Loading branch information
tkumor3 and ddoktorski authored Feb 15, 2024
1 parent 0dcf387 commit f992bc8
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 17 deletions.
35 changes: 18 additions & 17 deletions starknet_py/net/account/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ async def _get_max_fee(
)

if auto_estimate:
estimated_fee = await self._estimate_fee(transaction)
estimated_fee = await self.estimate_fee(transaction)
assert isinstance(estimated_fee, EstimatedFee)

max_fee = int(estimated_fee.overall_fee * Account.ESTIMATED_FEE_MULTIPLIER)

if max_fee is None:
Expand All @@ -158,7 +160,9 @@ async def _get_resource_bounds(
)

if auto_estimate:
estimated_fee = await self._estimate_fee(transaction)
estimated_fee = await self.estimate_fee(transaction)
assert isinstance(estimated_fee, EstimatedFee)

l1_resource_bounds = ResourceBounds(
max_amount=int(
estimated_fee.gas_consumed * Account.ESTIMATED_AMOUNT_MULTIPLIER
Expand Down Expand Up @@ -246,28 +250,25 @@ async def _prepare_invoke_v3(
)
return _add_resource_bounds_to_transaction(transaction, resource_bounds)

async def _estimate_fee(
async def estimate_fee(
self,
tx: AccountTransaction,
tx: Union[AccountTransaction, List[AccountTransaction]],
skip_validate: bool = False,
block_hash: Optional[Union[Hash, Tag]] = None,
block_number: Optional[Union[int, Tag]] = None,
) -> EstimatedFee:
"""
:param tx: Transaction which fee we want to calculate.
:param block_hash: a block hash.
:param block_number: a block number.
:return: Estimated fee.
"""
tx = await self.sign_for_fee_estimate(tx)
) -> Union[EstimatedFee, List[EstimatedFee]]:
transactions = (
await self.sign_for_fee_estimate(tx)
if isinstance(tx, AccountTransaction)
else [await self.sign_for_fee_estimate(t) for t in tx]
)

estimated_fee = await self._client.estimate_fee(
tx=tx,
return await self._client.estimate_fee(
tx=transactions,
skip_validate=skip_validate,
block_hash=block_hash,
block_number=block_number,
)
assert isinstance(estimated_fee, EstimatedFee)

return estimated_fee

async def get_nonce(
self,
Expand Down
23 changes: 23 additions & 0 deletions starknet_py/net/account/base_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
from starknet_py.net.client import Client
from starknet_py.net.client_models import (
Calls,
EstimatedFee,
Hash,
ResourceBounds,
SentTransactionResponse,
Tag,
)
from starknet_py.net.models import AddressRepresentation, StarknetChainId
from starknet_py.net.models.transaction import (
AccountTransaction,
DeclareV1,
DeclareV2,
DeclareV3,
Expand Down Expand Up @@ -51,6 +53,27 @@ def client(self) -> Client:
Get the Client used by the Account.
"""

@abstractmethod
async def estimate_fee(
self,
tx: Union[AccountTransaction, List[AccountTransaction]],
skip_validate: bool = False,
block_hash: Optional[Union[Hash, Tag]] = None,
block_number: Optional[Union[int, Tag]] = None,
) -> Union[EstimatedFee, List[EstimatedFee]]:
"""
Estimates the resources required by a given sequence of transactions when applied on a given state.
If one of the transactions reverts or fails due to any reason (e.g. validation failure or an internal error),
a TRANSACTION_EXECUTION_ERROR is returned.
For v0-2 transactions the estimate is given in Wei, and for v3 transactions it is given in Fri.
:param tx: Transaction or list of transactions to estimate
:param skip_validate: Flag checking whether the validation part of the transaction should be executed
:param block_hash: Block hash or literals `"pending"` or `"latest"`
:param block_number: Block number or literals `"pending"` or `"latest"`
:return: Estimated fee or list of estimated fees for each transaction
"""

@abstractmethod
async def get_nonce(
self,
Expand Down
48 changes: 48 additions & 0 deletions starknet_py/tests/e2e/account/account_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,54 @@ async def test_estimate_fee_for_declare_transaction(account, map_compiled_contra
)


@pytest.mark.asyncio
async def test_account_estimate_fee_for_declare_transaction(
account, map_compiled_contract
):
declare_tx = await account.sign_declare_v1(
compiled_contract=map_compiled_contract, max_fee=MAX_FEE
)

estimated_fee = await account.estimate_fee(tx=declare_tx)

assert estimated_fee.unit == PriceUnit.WEI
assert isinstance(estimated_fee.overall_fee, int)
assert estimated_fee.overall_fee > 0
assert (
estimated_fee.gas_consumed * estimated_fee.gas_price
== estimated_fee.overall_fee
)


@pytest.mark.asyncio
async def test_account_estimate_fee_for_transactions(
account, map_compiled_contract, map_contract
):
declare_tx = await account.sign_declare_v1(
compiled_contract=map_compiled_contract, max_fee=MAX_FEE
)

invoke_tx = await account.sign_invoke_v3(
calls=Call(map_contract.address, get_selector_from_name("put"), [3, 4]),
l1_resource_bounds=MAX_RESOURCE_BOUNDS_L1,
nonce=(declare_tx.nonce + 1),
)

estimated_fee = await account.estimate_fee(tx=[declare_tx, invoke_tx])

assert len(estimated_fee) == 2
assert isinstance(estimated_fee[0], EstimatedFee)
assert isinstance(estimated_fee[1], EstimatedFee)
assert estimated_fee[0].unit == PriceUnit.WEI
assert estimated_fee[1].unit == PriceUnit.FRI
assert isinstance(estimated_fee[0].overall_fee, int)
assert estimated_fee[0].overall_fee > 0
assert (
estimated_fee[0].gas_consumed * estimated_fee[0].gas_price
== estimated_fee[0].overall_fee
)


@pytest.mark.asyncio
@pytest.mark.parametrize("key, val", [(20, 20), (30, 30)])
async def test_sending_multicall(account, map_contract, key, val):
Expand Down

0 comments on commit f992bc8

Please sign in to comment.