diff --git a/Cargo.lock b/Cargo.lock index 21fddc9af4..9709c859a6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1013,7 +1013,7 @@ dependencies = [ "dirs", "ed25519-dalek", "ed25519-dalek-bip32 0.2.0", - "enum_from", + "enum_derives", "ethabi", "ethcore-transaction", "ethereum-types", @@ -1501,7 +1501,7 @@ dependencies = [ "common", "derive_more", "enum-primitive-derive", - "enum_from", + "enum_derives", "futures 0.3.28", "hex 0.4.3", "http 0.2.7", @@ -2072,7 +2072,7 @@ dependencies = [ ] [[package]] -name = "enum_from" +name = "enum_derives" version = "0.1.0" dependencies = [ "itertools", @@ -4284,7 +4284,7 @@ dependencies = [ "async-trait", "common", "derive_more", - "enum_from", + "enum_derives", "futures 0.3.28", "hex 0.4.3", "itertools", @@ -4422,7 +4422,7 @@ dependencies = [ "dirs", "either", "enum-primitive-derive", - "enum_from", + "enum_derives", "ethereum-types", "futures 0.1.29", "futures 0.3.28", diff --git a/Cargo.toml b/Cargo.toml index ac0439b17b..5868b5b314 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ members = [ "mm2src/common/shared_ref_counter", "mm2src/crypto", "mm2src/db_common", - "mm2src/derives/enum_from", + "mm2src/derives/enum_derives", "mm2src/derives/ser_error_derive", "mm2src/derives/ser_error", "mm2src/hw_common", diff --git a/mm2src/coins/Cargo.toml b/mm2src/coins/Cargo.toml index b2ccc8c227..77826bdc70 100644 --- a/mm2src/coins/Cargo.toml +++ b/mm2src/coins/Cargo.toml @@ -42,7 +42,7 @@ crypto = { path = "../crypto" } db_common = { path = "../db_common" } derive_more = "0.99" ed25519-dalek = "1.0.1" -enum_from = { path = "../derives/enum_from" } +enum_derives = { path = "../derives/enum_derives" } ethabi = { version = "17.0.0" } ethcore-transaction = { git = "https://github.com/KomodoPlatform/mm2-parity-ethereum.git" } ethereum-types = { version = "0.13", default-features = false, features = ["std", "serialize"] } diff --git a/mm2src/coins/eth.rs b/mm2src/coins/eth.rs index 1c068a26f8..a98a742b2a 100644 --- a/mm2src/coins/eth.rs +++ b/mm2src/coins/eth.rs @@ -37,7 +37,7 @@ use common::{now_ms, wait_until_ms}; use crypto::privkey::key_pair_from_secret; use crypto::{CryptoCtx, CryptoCtxError, GlobalHDAccountArc, KeyPairPolicy, StandardHDCoinAddress}; use derive_more::Display; -use enum_from::EnumFromStringify; +use enum_derives::EnumFromStringify; use ethabi::{Contract, Function, Token}; pub use ethcore_transaction::SignedTransaction as SignedEthTx; use ethcore_transaction::{Action, Transaction as UnSignedEthTx, UnverifiedTransaction}; @@ -107,10 +107,11 @@ pub use rlp; mod web3_transport; #[path = "eth/v2_activation.rs"] pub mod v2_activation; -use crate::nft::{find_wallet_nft_amount, WithdrawNftResult}; +use crate::nft::WithdrawNftResult; use v2_activation::{build_address_and_priv_key_policy, EthActivationV2Error}; mod nonce; +use crate::nft::nft_errors::GetNftInfoError; use crate::{PrivKeyPolicy, TransactionResult, WithdrawFrom}; use nonce::ParityNonce; @@ -877,16 +878,12 @@ async fn withdraw_impl(coin: EthCoin, req: WithdrawRequest) -> WithdrawResult { pub async fn withdraw_erc1155(ctx: MmArc, withdraw_type: WithdrawErc1155) -> WithdrawNftResult { let coin = lp_coinfind_or_err(&ctx, withdraw_type.chain.to_ticker()).await?; let (to_addr, token_addr, eth_coin) = - get_valid_nft_add_to_withdraw(coin, &withdraw_type.to, &withdraw_type.token_address)?; - let my_address = eth_coin.my_address()?; - - let wallet_amount = find_wallet_nft_amount( - &ctx, - &withdraw_type.chain, - withdraw_type.token_address.to_lowercase(), - withdraw_type.token_id.clone(), - ) - .await?; + get_valid_nft_addr_to_withdraw(coin, &withdraw_type.to, &withdraw_type.token_address)?; + let my_address_str = eth_coin.my_address()?; + + let token_id_str = &withdraw_type.token_id.to_string(); + let wallet_amount = eth_coin.erc1155_balance(token_addr, token_id_str).await?; + let amount_dec = if withdraw_type.max { wallet_amount.clone() } else { @@ -905,12 +902,10 @@ pub async fn withdraw_erc1155(ctx: MmArc, withdraw_type: WithdrawErc1155) -> Wit let (eth_value, data, call_addr, fee_coin) = match eth_coin.coin_type { EthCoinType::Eth => { let function = ERC1155_CONTRACT.function("safeTransferFrom")?; - let token_id_u256 = U256::from_dec_str(&withdraw_type.token_id.to_string()) - .map_err(|e| format!("{:?}", e)) - .map_to_mm(NumConversError::new)?; - let amount_u256 = U256::from_dec_str(&amount_dec.to_string()) - .map_err(|e| format!("{:?}", e)) - .map_to_mm(NumConversError::new)?; + let token_id_u256 = + U256::from_dec_str(token_id_str).map_to_mm(|e| NumConversError::new(format!("{:?}", e)))?; + let amount_u256 = + U256::from_dec_str(&amount_dec.to_string()).map_to_mm(|e| NumConversError::new(format!("{:?}", e)))?; let data = function.encode_input(&[ Token::Address(eth_coin.my_address), Token::Address(to_addr), @@ -959,7 +954,7 @@ pub async fn withdraw_erc1155(ctx: MmArc, withdraw_type: WithdrawErc1155) -> Wit Ok(TransactionNftDetails { tx_hex: BytesJson::from(signed_bytes.to_vec()), tx_hash: format!("{:02x}", signed.tx_hash()), - from: vec![my_address], + from: vec![my_address_str], to: vec![withdraw_type.to], contract_type: ContractType::Erc1155, token_address: withdraw_type.token_address, @@ -979,17 +974,26 @@ pub async fn withdraw_erc1155(ctx: MmArc, withdraw_type: WithdrawErc1155) -> Wit pub async fn withdraw_erc721(ctx: MmArc, withdraw_type: WithdrawErc721) -> WithdrawNftResult { let coin = lp_coinfind_or_err(&ctx, withdraw_type.chain.to_ticker()).await?; let (to_addr, token_addr, eth_coin) = - get_valid_nft_add_to_withdraw(coin, &withdraw_type.to, &withdraw_type.token_address)?; - let my_address = eth_coin.my_address()?; + get_valid_nft_addr_to_withdraw(coin, &withdraw_type.to, &withdraw_type.token_address)?; + let my_address_str = eth_coin.my_address()?; + + let token_id_str = &withdraw_type.token_id.to_string(); + let token_owner = eth_coin.erc721_owner(token_addr, token_id_str).await?; + let my_address = eth_coin.my_address; + if token_owner != my_address { + return MmError::err(WithdrawError::MyAddressNotNftOwner { + my_address: eth_addr_to_hex(&my_address), + token_owner: eth_addr_to_hex(&token_owner), + }); + } let (eth_value, data, call_addr, fee_coin) = match eth_coin.coin_type { EthCoinType::Eth => { let function = ERC721_CONTRACT.function("safeTransferFrom")?; let token_id_u256 = U256::from_dec_str(&withdraw_type.token_id.to_string()) - .map_err(|e| format!("{:?}", e)) - .map_to_mm(NumConversError::new)?; + .map_to_mm(|e| NumConversError::new(format!("{:?}", e)))?; let data = function.encode_input(&[ - Token::Address(eth_coin.my_address), + Token::Address(my_address), Token::Address(to_addr), Token::Uint(token_id_u256), ])?; @@ -1011,7 +1015,7 @@ pub async fn withdraw_erc721(ctx: MmArc, withdraw_type: WithdrawErc721) -> Withd ) .await?; let _nonce_lock = eth_coin.nonce_lock.lock().await; - let (nonce, _) = get_addr_nonce(eth_coin.my_address, eth_coin.web3_instances.clone()) + let (nonce, _) = get_addr_nonce(my_address, eth_coin.web3_instances.clone()) .compat() .timeout_secs(30.) .await? @@ -1034,7 +1038,7 @@ pub async fn withdraw_erc721(ctx: MmArc, withdraw_type: WithdrawErc721) -> Withd Ok(TransactionNftDetails { tx_hex: BytesJson::from(signed_bytes.to_vec()), tx_hash: format!("{:02x}", signed.tx_hash()), - from: vec![my_address], + from: vec![my_address_str], to: vec![withdraw_type.to], contract_type: ContractType::Erc721, token_address: withdraw_type.token_address, @@ -4011,6 +4015,59 @@ impl EthCoin { } } + async fn erc1155_balance(&self, token_addr: Address, token_id: &str) -> MmResult { + let wallet_amount_uint = match self.coin_type { + EthCoinType::Eth => { + let function = ERC1155_CONTRACT.function("balanceOf")?; + let token_id_u256 = + U256::from_dec_str(token_id).map_to_mm(|e| NumConversError::new(format!("{:?}", e)))?; + let data = function.encode_input(&[Token::Address(self.my_address), Token::Uint(token_id_u256)])?; + let result = self.call_request(token_addr, None, Some(data.into())).await?; + let decoded = function.decode_output(&result.0)?; + match decoded[0] { + Token::Uint(number) => number, + _ => { + let error = format!("Expected U256 as balanceOf result but got {:?}", decoded); + return MmError::err(BalanceError::InvalidResponse(error)); + }, + } + }, + EthCoinType::Erc20 { .. } => { + return MmError::err(BalanceError::Internal( + "Erc20 coin type doesnt support Erc1155 standard".to_owned(), + )) + }, + }; + let wallet_amount = u256_to_big_decimal(wallet_amount_uint, self.decimals)?; + Ok(wallet_amount) + } + + async fn erc721_owner(&self, token_addr: Address, token_id: &str) -> MmResult { + let owner_address = match self.coin_type { + EthCoinType::Eth => { + let function = ERC721_CONTRACT.function("ownerOf")?; + let token_id_u256 = + U256::from_dec_str(token_id).map_to_mm(|e| NumConversError::new(format!("{:?}", e)))?; + let data = function.encode_input(&[Token::Uint(token_id_u256)])?; + let result = self.call_request(token_addr, None, Some(data.into())).await?; + let decoded = function.decode_output(&result.0)?; + match decoded[0] { + Token::Address(owner) => owner, + _ => { + let error = format!("Expected Address as ownerOf result but got {:?}", decoded); + return MmError::err(GetNftInfoError::InvalidResponse(error)); + }, + } + }, + EthCoinType::Erc20 { .. } => { + return MmError::err(GetNftInfoError::Internal( + "Erc20 coin type doesnt support Erc721 standard".to_owned(), + )) + }, + }; + Ok(owner_address) + } + fn estimate_gas(&self, req: CallRequest) -> Box + Send> { // always using None block number as old Geth version accept only single argument in this RPC Box::new(self.web3.eth().estimate_gas(req, None).compat()) @@ -5293,9 +5350,7 @@ pub fn wei_from_big_decimal(amount: &BigDecimal, decimals: u8) -> NumConversResu } else { amount.insert_str(amount.len(), &"0".repeat(decimals)); } - U256::from_dec_str(&amount) - .map_err(|e| format!("{:?}", e)) - .map_to_mm(NumConversError::new) + U256::from_dec_str(&amount).map_to_mm(|e| NumConversError::new(format!("{:?}", e))) } impl Transaction for SignedEthTx { @@ -5726,6 +5781,7 @@ fn increase_gas_price_by_stage(gas_price: U256, level: &FeeApproxStage) -> U256 } } +/// Represents errors that can occur while retrieving an Ethereum address. #[derive(Clone, Debug, Deserialize, Display, PartialEq, Serialize)] pub enum GetEthAddressError { PrivKeyPolicyNotAllowed(PrivKeyPolicyNotAllowed), @@ -5766,21 +5822,20 @@ pub async fn get_eth_address( }) } +/// Errors encountered while validating Ethereum addresses for NFT withdrawal. #[derive(Display)] pub enum GetValidEthWithdrawAddError { - #[display(fmt = "My address {} and from address {} mismatch", my_address, from)] - AddressMismatchError { - my_address: String, - from: String, - }, + /// The specified coin does not support NFT withdrawal. #[display(fmt = "{} coin doesn't support NFT withdrawing", coin)] - CoinDoesntSupportNftWithdraw { - coin: String, - }, + CoinDoesntSupportNftWithdraw { coin: String }, + /// The provided address is invalid. InvalidAddress(String), } -fn get_valid_nft_add_to_withdraw( +/// Validates Ethereum addresses for NFT withdrawal. +/// Returns a tuple of valid `to` address, `token` address, and `EthCoin` instance on success. +/// Errors if the coin doesn't support NFT withdrawal or if the addresses are invalid. +fn get_valid_nft_addr_to_withdraw( coin_enum: MmCoinEnum, to: &str, token_add: &str, diff --git a/mm2src/coins/eth/v2_activation.rs b/mm2src/coins/eth/v2_activation.rs index fddf8da03f..cd54e472f3 100644 --- a/mm2src/coins/eth/v2_activation.rs +++ b/mm2src/coins/eth/v2_activation.rs @@ -2,7 +2,7 @@ use super::*; #[cfg(target_arch = "wasm32")] use crate::EthMetamaskPolicy; use common::executor::AbortedError; use crypto::{CryptoCtxError, StandardHDCoinAddress}; -use enum_from::EnumFromTrait; +use enum_derives::EnumFromTrait; use mm2_err_handle::common_errors::WithInternal; #[cfg(target_arch = "wasm32")] use mm2_metamask::{from_metamask_error, MetamaskError, MetamaskRpcError, WithMetamaskRpcError}; diff --git a/mm2src/coins/hd_confirm_address.rs b/mm2src/coins/hd_confirm_address.rs index b028f9fd9e..deccbac75b 100644 --- a/mm2src/coins/hd_confirm_address.rs +++ b/mm2src/coins/hd_confirm_address.rs @@ -4,7 +4,7 @@ use crypto::hw_rpc_task::HwConnectStatuses; use crypto::trezor::trezor_rpc_task::{TrezorRequestStatuses, TrezorRpcTaskProcessor, TryIntoUserAction}; use crypto::trezor::{ProcessTrezorResponse, TrezorError, TrezorProcessingError}; use crypto::{CryptoCtx, CryptoCtxError, HardwareWalletArc, HwError, HwProcessingError}; -use enum_from::{EnumFromInner, EnumFromStringify}; +use enum_derives::{EnumFromInner, EnumFromStringify}; use mm2_core::mm_ctx::MmArc; use mm2_err_handle::prelude::*; use rpc_task::{RpcTask, RpcTaskError, RpcTaskHandleShared}; diff --git a/mm2src/coins/lp_coins.rs b/mm2src/coins/lp_coins.rs index 1ca02100a7..2d1dea7885 100644 --- a/mm2src/coins/lp_coins.rs +++ b/mm2src/coins/lp_coins.rs @@ -52,7 +52,7 @@ use common::{calc_total_pages, now_sec, ten, HttpStatusCode}; use crypto::{derive_secp256k1_secret, Bip32Error, CryptoCtx, CryptoCtxError, DerivationPath, GlobalHDAccountArc, HwRpcError, KeyPairPolicy, Secp256k1Secret, StandardHDCoinAddress, StandardHDPathToCoin, WithHwRpcError}; use derive_more::Display; -use enum_from::{EnumFromStringify, EnumFromTrait}; +use enum_derives::{EnumFromStringify, EnumFromTrait}; use ethereum_types::H256; use futures::compat::Future01CompatExt; use futures::lock::Mutex as AsyncMutex; @@ -2402,11 +2402,6 @@ pub enum WithdrawError { CoinDoesntSupportNftWithdraw { coin: String, }, - #[display(fmt = "My address {} and from address {} mismatch", my_address, from)] - AddressMismatchError { - my_address: String, - from: String, - }, #[display(fmt = "Contract type {} doesnt support 'withdraw_nft' yet", _0)] ContractTypeDoesntSupportNftWithdrawing(String), #[display(fmt = "Action not allowed for coin: {}", _0)] @@ -2427,6 +2422,11 @@ pub enum WithdrawError { }, #[display(fmt = "DB error {}", _0)] DbError(String), + #[display(fmt = "My address is {}, while current Nft owner is {}", my_address, token_owner)] + MyAddressNotNftOwner { + my_address: String, + token_owner: String, + }, } impl HttpStatusCode for WithdrawError { @@ -2449,10 +2449,10 @@ impl HttpStatusCode for WithdrawError { | WithdrawError::UnsupportedError(_) | WithdrawError::ActionNotAllowed(_) | WithdrawError::GetNftInfoError(_) - | WithdrawError::AddressMismatchError { .. } | WithdrawError::ContractTypeDoesntSupportNftWithdrawing(_) | WithdrawError::CoinDoesntSupportNftWithdraw { .. } - | WithdrawError::NotEnoughNftsAmount { .. } => StatusCode::BAD_REQUEST, + | WithdrawError::NotEnoughNftsAmount { .. } + | WithdrawError::MyAddressNotNftOwner { .. } => StatusCode::BAD_REQUEST, WithdrawError::HwError(_) => StatusCode::GONE, #[cfg(target_arch = "wasm32")] WithdrawError::BroadcastExpected(_) => StatusCode::BAD_REQUEST, @@ -2496,9 +2496,6 @@ impl From for WithdrawError { impl From for WithdrawError { fn from(e: GetValidEthWithdrawAddError) -> Self { match e { - GetValidEthWithdrawAddError::AddressMismatchError { my_address, from } => { - WithdrawError::AddressMismatchError { my_address, from } - }, GetValidEthWithdrawAddError::CoinDoesntSupportNftWithdraw { coin } => { WithdrawError::CoinDoesntSupportNftWithdraw { coin } }, diff --git a/mm2src/coins/nft.rs b/mm2src/coins/nft.rs index 40736a6414..613603ebd2 100644 --- a/mm2src/coins/nft.rs +++ b/mm2src/coins/nft.rs @@ -16,11 +16,11 @@ use nft_structs::{Chain, ContractType, ConvertChain, Nft, NftFromMoralis, NftLis use crate::eth::{eth_addr_to_hex, get_eth_address, withdraw_erc1155, withdraw_erc721, EthCoin, EthCoinType, EthTxFeeDetails}; -use crate::nft::nft_errors::{MetaFromUrlError, ProtectFromSpamError, TransferConfirmationsError, +use crate::nft::nft_errors::{ClearNftDbError, MetaFromUrlError, ProtectFromSpamError, TransferConfirmationsError, UpdateSpamPhishingError}; -use crate::nft::nft_structs::{build_nft_with_empty_meta, BuildNftFields, NftCommon, NftCtx, NftTransferCommon, - PhishingDomainReq, PhishingDomainRes, RefreshMetadataReq, SpamContractReq, - SpamContractRes, TransferMeta, TransferStatus, UriMeta}; +use crate::nft::nft_structs::{build_nft_with_empty_meta, BuildNftFields, ClearNftDbReq, NftCommon, NftCtx, + NftTransferCommon, PhishingDomainReq, PhishingDomainRes, RefreshMetadataReq, + SpamContractReq, SpamContractRes, TransferMeta, TransferStatus, UriMeta}; use crate::nft::storage::{NftListStorageOps, NftTransferHistoryStorageOps}; use common::parse_rfc3339_to_timestamp; use crypto::StandardHDCoinAddress; @@ -29,7 +29,7 @@ use futures::compat::Future01CompatExt; use futures::future::try_join_all; use mm2_err_handle::map_to_mm::MapToMmResult; use mm2_net::transport::send_post_request_to_uri; -use mm2_number::{BigDecimal, BigUint}; +use mm2_number::BigUint; use regex::Regex; use serde_json::Value as Json; use std::cmp::Ordering; @@ -1160,30 +1160,6 @@ async fn mark_as_spam_and_build_empty_meta MmResult { - let nft_ctx = NftCtx::from_ctx(ctx).map_to_mm(GetNftInfoError::Internal)?; - - let storage = nft_ctx.lock_db().await?; - if !NftListStorageOps::is_initialized(&storage, chain).await? { - NftListStorageOps::init(&storage, chain).await?; - } - let nft_meta = storage - .get_nft(chain, token_address.to_lowercase(), token_id.clone()) - .await? - .ok_or_else(|| GetNftInfoError::TokenNotFoundInWallet { - token_address, - token_id: token_id.to_string(), - })?; - Ok(nft_meta.common.amount) -} - async fn cache_nfts_from_moralis( ctx: &MmArc, storage: &T, @@ -1396,3 +1372,51 @@ pub(crate) fn get_domain_from_url(url: Option<&str>) -> Option { url.and_then(|uri| Url::parse(uri).ok()) .and_then(|url| url.domain().map(String::from)) } + +/// Clears NFT data from the database for specified chains. +pub async fn clear_nft_db(ctx: MmArc, req: ClearNftDbReq) -> MmResult<(), ClearNftDbError> { + if req.clear_all { + let nft_ctx = NftCtx::from_ctx(&ctx).map_to_mm(ClearNftDbError::Internal)?; + let storage = nft_ctx.lock_db().await?; + storage.clear_all_nft_data().await?; + storage.clear_all_history_data().await?; + return Ok(()); + } + + if req.chains.is_empty() { + return MmError::err(ClearNftDbError::InvalidRequest( + "Nothing to clear was specified".to_string(), + )); + } + + let nft_ctx = NftCtx::from_ctx(&ctx).map_to_mm(ClearNftDbError::Internal)?; + let storage = nft_ctx.lock_db().await?; + let mut errors = Vec::new(); + for chain in req.chains.iter() { + if let Err(e) = clear_data_for_chain(&storage, chain).await { + errors.push(e); + } + } + if !errors.is_empty() { + return MmError::err(ClearNftDbError::DbError(format!("{:?}", errors))); + } + + Ok(()) +} + +async fn clear_data_for_chain(storage: &T, chain: &Chain) -> MmResult<(), ClearNftDbError> +where + T: NftListStorageOps + NftTransferHistoryStorageOps, +{ + let (is_nft_list_init, is_history_init) = ( + NftListStorageOps::is_initialized(storage, chain).await?, + NftTransferHistoryStorageOps::is_initialized(storage, chain).await?, + ); + if is_nft_list_init { + storage.clear_nft_data(chain).await?; + } + if is_history_init { + storage.clear_history_data(chain).await?; + } + Ok(()) +} diff --git a/mm2src/coins/nft/nft_errors.rs b/mm2src/coins/nft/nft_errors.rs index f5dd5adaba..96e520e5cd 100644 --- a/mm2src/coins/nft/nft_errors.rs +++ b/mm2src/coins/nft/nft_errors.rs @@ -2,12 +2,12 @@ use crate::eth::GetEthAddressError; #[cfg(target_arch = "wasm32")] use crate::nft::storage::wasm::WasmNftCacheError; use crate::nft::storage::NftStorageError; -use crate::{CoinFindError, GetMyAddressError, WithdrawError}; +use crate::{CoinFindError, GetMyAddressError, NumConversError, WithdrawError}; use common::{HttpStatusCode, ParseRfc3339Err}; #[cfg(not(target_arch = "wasm32"))] use db_common::sqlite::rusqlite::Error as SqlError; use derive_more::Display; -use enum_from::EnumFromStringify; +use enum_derives::EnumFromStringify; use http::StatusCode; use mm2_net::transport::{GetInfoFromUriError, SlurpError}; use serde::{Deserialize, Serialize}; @@ -43,6 +43,7 @@ pub enum GetNftInfoError { ContractTypeIsNull, ProtectFromSpamError(ProtectFromSpamError), TransferConfirmationsError(TransferConfirmationsError), + NumConversError(String), } impl From for WithdrawError { @@ -109,10 +110,24 @@ impl From for GetNftInfoError { fn from(e: TransferConfirmationsError) -> Self { GetNftInfoError::TransferConfirmationsError(e) } } +impl From for GetNftInfoError { + fn from(e: ethabi::Error) -> Self { + // Currently, we use the `ethabi` crate to work with a smart contract ABI known at compile time. + // It's an internal error if there are any issues during working with a smart contract ABI. + GetNftInfoError::Internal(e.to_string()) + } +} + +impl From for GetNftInfoError { + fn from(e: NumConversError) -> Self { GetNftInfoError::NumConversError(e.to_string()) } +} + impl HttpStatusCode for GetNftInfoError { fn status_code(&self) -> StatusCode { match self { - GetNftInfoError::InvalidRequest(_) => StatusCode::BAD_REQUEST, + GetNftInfoError::InvalidRequest(_) | GetNftInfoError::TransferConfirmationsError(_) => { + StatusCode::BAD_REQUEST + }, GetNftInfoError::InvalidResponse(_) | GetNftInfoError::ParseRfc3339Err(_) => StatusCode::FAILED_DEPENDENCY, GetNftInfoError::ContractTypeIsNull => StatusCode::NOT_FOUND, GetNftInfoError::Transport(_) @@ -121,7 +136,7 @@ impl HttpStatusCode for GetNftInfoError { | GetNftInfoError::TokenNotFoundInWallet { .. } | GetNftInfoError::DbError(_) | GetNftInfoError::ProtectFromSpamError(_) - | GetNftInfoError::TransferConfirmationsError(_) => StatusCode::INTERNAL_SERVER_ERROR, + | GetNftInfoError::NumConversError(_) => StatusCode::INTERNAL_SERVER_ERROR, } } } @@ -260,13 +275,12 @@ impl HttpStatusCode for UpdateNftError { } /// Enumerates the errors that can occur during spam protection operations. -/// -/// This includes issues such as regex failures during text validation and -/// serialization/deserialization problems. #[derive(Clone, Debug, Deserialize, Display, EnumFromStringify, PartialEq, Serialize)] pub enum ProtectFromSpamError { + /// Error related to regular expression operations. #[from_stringify("regex::Error")] RegexError(String), + /// Error related to serialization or deserialization with serde_json. #[from_stringify("serde_json::Error")] SerdeError(String), } @@ -331,10 +345,13 @@ impl From for MetaFromUrlError { fn from(e: GetInfoFromUriError) -> Self { MetaFromUrlError::GetInfoFromUriError(e) } } +/// Represents errors that can occur while locking the NFT database. #[derive(Debug, Display)] pub enum LockDBError { + /// Errors specific to the WebAssembly (WASM) environment's NFT cache. #[cfg(target_arch = "wasm32")] WasmNftCacheError(WasmNftCacheError), + /// Errors related to SQL operations in non-WASM environments. #[cfg(not(target_arch = "wasm32"))] SqlError(SqlError), } @@ -349,12 +366,16 @@ impl From for LockDBError { fn from(e: WasmNftCacheError) -> Self { LockDBError::WasmNftCacheError(e) } } +/// Errors related to calculating transfer confirmations for NFTs. #[derive(Clone, Debug, Deserialize, Display, PartialEq, Serialize)] pub enum TransferConfirmationsError { + /// Occurs when the specified coin does not exist. #[display(fmt = "No such coin {}", coin)] NoSuchCoin { coin: String }, + /// Triggered when the specified coin does not support NFT operations. #[display(fmt = "{} coin doesn't support NFT", coin)] CoinDoesntSupportNft { coin: String }, + /// Represents errors encountered while retrieving the current block number. #[display(fmt = "Get current block error: {}", _0)] GetCurrentBlockErr(String), } @@ -366,3 +387,35 @@ impl From for TransferConfirmationsError { } } } + +/// Enumerates errors that can occur while clearing NFT data from the database. +#[derive(Clone, Debug, Deserialize, Display, PartialEq, Serialize, SerializeErrorType)] +#[serde(tag = "error_type", content = "error_data")] +pub enum ClearNftDbError { + /// Represents errors related to database operations. + #[display(fmt = "DB error {}", _0)] + DbError(String), + /// Indicates internal errors not directly associated with database operations. + #[display(fmt = "Internal: {}", _0)] + Internal(String), + /// Used for various types of invalid requests, such as missing or contradictory parameters. + #[display(fmt = "Invalid request: {}", _0)] + InvalidRequest(String), +} + +impl From for ClearNftDbError { + fn from(err: T) -> Self { ClearNftDbError::DbError(format!("{:?}", err)) } +} + +impl From for ClearNftDbError { + fn from(e: LockDBError) -> Self { ClearNftDbError::DbError(e.to_string()) } +} + +impl HttpStatusCode for ClearNftDbError { + fn status_code(&self) -> StatusCode { + match self { + ClearNftDbError::InvalidRequest(_) => StatusCode::BAD_REQUEST, + ClearNftDbError::DbError(_) | ClearNftDbError::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR, + } + } +} diff --git a/mm2src/coins/nft/nft_structs.rs b/mm2src/coins/nft/nft_structs.rs index 9173f3bd5b..4a4163f633 100644 --- a/mm2src/coins/nft/nft_structs.rs +++ b/mm2src/coins/nft/nft_structs.rs @@ -1,4 +1,5 @@ use common::ten; +use enum_derives::EnumVariantList; use ethereum_types::Address; use mm2_core::mm_ctx::{from_ctx, MmArc}; use mm2_err_handle::prelude::*; @@ -67,43 +68,41 @@ pub struct NftListFilters { } /// Contains parameters required to fetch metadata for a specified NFT. -/// # Fields -/// * `token_address`: The address of the NFT token. -/// * `token_id`: The ID of the NFT token. -/// * `chain`: The blockchain where the NFT exists. -/// * `protect_from_spam`: Indicates whether to check and redact potential spam. If set to true, -/// the internal function `protect_from_nft_spam` is utilized. #[derive(Debug, Deserialize)] pub struct NftMetadataReq { + /// The address of the NFT token. pub(crate) token_address: Address, + /// The ID of the NFT token. #[serde(deserialize_with = "deserialize_token_id")] pub(crate) token_id: BigUint, + /// The blockchain where the NFT exists. pub(crate) chain: Chain, + /// Indicates whether to check and redact potential spam. If set to true, + /// the internal function `protect_from_nft_spam` is utilized. #[serde(default)] pub(crate) protect_from_spam: bool, } /// Contains parameters required to refresh metadata for a specified NFT. -/// # Fields -/// * `token_address`: The address of the NFT token whose metadata needs to be refreshed. -/// * `token_id`: The ID of the NFT token. -/// * `chain`: The blockchain where the NFT exists. -/// * `url`: URL to fetch the metadata. -/// * `url_antispam`: URL used to validate if the fetched contract addresses are associated -/// with spam contracts or if domain fields in the fetched metadata match known phishing domains. #[derive(Debug, Deserialize)] pub struct RefreshMetadataReq { + /// The address of the NFT token whose metadata needs to be refreshed. pub(crate) token_address: Address, + /// The ID of the NFT token. #[serde(deserialize_with = "deserialize_token_id")] pub(crate) token_id: BigUint, + /// The blockchain where the NFT exists. pub(crate) chain: Chain, + /// URL to fetch the metadata. pub(crate) url: Url, + /// URL used to validate if the fetched contract addresses are associated + /// with spam contracts or if domain fields in the fetched metadata match known phishing domains. pub(crate) url_antispam: Url, } /// Represents blockchains which are supported by NFT feature. /// Currently there are only EVM based chains. -#[derive(Clone, Copy, Debug, PartialEq, Serialize)] +#[derive(Clone, Copy, Debug, EnumVariantList, PartialEq, Serialize)] #[serde(rename_all = "UPPERCASE")] pub enum Chain { Avalanche, @@ -385,37 +384,57 @@ pub struct NftList { pub(crate) total: usize, } +/// Parameters for withdrawing an ERC-1155 token. #[derive(Clone, Deserialize)] pub struct WithdrawErc1155 { + /// The blockchain network to perform the withdrawal on. pub(crate) chain: Chain, + /// The address to send the NFT to. pub(crate) to: String, + /// The address of the ERC-1155 token contract. pub(crate) token_address: String, + /// The unique identifier of the NFT to withdraw. #[serde(deserialize_with = "deserialize_token_id")] pub(crate) token_id: BigUint, + /// Optional amount of the token to withdraw. Defaults to 1 if not specified. pub(crate) amount: Option, + /// If set to `true`, withdraws the maximum amount available. Overrides the `amount` field. #[serde(default)] pub(crate) max: bool, + /// Optional details for the withdrawal fee. pub(crate) fee: Option, } +/// Parameters for withdrawing an ERC-721 token. #[derive(Clone, Deserialize)] pub struct WithdrawErc721 { + /// The blockchain network to perform the withdrawal on. pub(crate) chain: Chain, + /// The address to send the NFT to. pub(crate) to: String, + /// The address of the ERC-721 token contract. pub(crate) token_address: String, + /// The unique identifier of the NFT to withdraw. #[serde(deserialize_with = "deserialize_token_id")] pub(crate) token_id: BigUint, + /// Optional details for the withdrawal fee. pub(crate) fee: Option, } +/// Represents a request for withdrawing an NFT, supporting different ERC standards. #[derive(Clone, Deserialize)] #[serde(tag = "type", content = "withdraw_data")] #[serde(rename_all = "snake_case")] pub enum WithdrawNftReq { + /// Parameters for withdrawing an ERC-1155 token. WithdrawErc1155(WithdrawErc1155), + /// Parameters for withdrawing an ERC-721 token. WithdrawErc721(WithdrawErc721), } +/// Details of NFT transaction. +/// +/// Includes the raw transaction hex for broadcasting, along with additional information about the transaction. #[derive(Debug, Deserialize, Serialize)] pub struct TransactionNftDetails { /// Raw bytes of signed transaction, this should be sent as is to `send_raw_transaction` RPC to broadcast the transaction @@ -593,33 +612,44 @@ pub struct NftTransferHistoryFilters { } /// Contains parameters required to update NFT transfer history and NFT list. -/// # Fields -/// * `chains`: A list of blockchains for which the NFTs need to be updated. -/// * `url`: URL to fetch the NFT data. -/// * `url_antispam`: URL used to validate if the fetched contract addresses are associated -/// with spam contracts or if domain fields in the fetched metadata match known phishing domains. #[derive(Debug, Deserialize)] pub struct UpdateNftReq { + /// A list of blockchains for which the NFTs need to be updated. pub(crate) chains: Vec, + /// URL to fetch the NFT data. pub(crate) url: Url, + /// URL used to validate if the fetched contract addresses are associated + /// with spam contracts or if domain fields in the fetched metadata match known phishing domains. pub(crate) url_antispam: Url, } +/// Represents a unique identifier for an NFT, consisting of its token address and token ID. #[derive(Debug, Deserialize, Eq, Hash, PartialEq)] pub struct NftTokenAddrId { + /// The address of the NFT token contract. pub(crate) token_address: String, + /// The unique identifier of the NFT within its contract. pub(crate) token_id: BigUint, } +/// Holds metadata information for an NFT transfer. #[derive(Debug)] pub struct TransferMeta { + /// The address of the NFT token contract. pub(crate) token_address: String, + /// The unique identifier of the NFT. pub(crate) token_id: BigUint, + /// Optional URI for the NFT's metadata. pub(crate) token_uri: Option, + /// Optional domain associated with the NFT's metadata. pub(crate) token_domain: Option, + /// Optional name of the NFT's collection. pub(crate) collection_name: Option, + /// Optional URL for the NFT's image. pub(crate) image_url: Option, + /// Optional domain for the NFT's image. pub(crate) image_domain: Option, + /// Optional name of the NFT. pub(crate) token_name: Option, } @@ -732,3 +762,13 @@ where let s = String::deserialize(deserializer)?; BigUint::from_str(&s).map_err(serde::de::Error::custom) } + +/// Request parameters for clearing NFT data from the database. +#[derive(Debug, Deserialize)] +pub struct ClearNftDbReq { + /// Specifies the blockchain networks (e.g., Ethereum, BSC) to clear NFT data. + pub(crate) chains: Vec, + /// If `true`, clears NFT data for all chains, ignoring the `chains` field. Defaults to `false`. + #[serde(default)] + pub(crate) clear_all: bool, +} diff --git a/mm2src/coins/nft/nft_tests.rs b/mm2src/coins/nft/nft_tests.rs index 99ec3925de..151d3af02d 100644 --- a/mm2src/coins/nft/nft_tests.rs +++ b/mm2src/coins/nft/nft_tests.rs @@ -395,6 +395,48 @@ cross_test!(test_exclude_nft_phishing_spam, { assert_eq!(nfts.len(), 2); }); +cross_test!(test_clear_nft, { + let chain = Chain::Bsc; + let nft_ctx = get_nft_ctx(&chain).await; + let storage = nft_ctx.lock_db().await.unwrap(); + NftListStorageOps::init(&storage, &chain).await.unwrap(); + let nft = nft(); + storage.add_nfts_to_list(chain, vec![nft], 28056726).await.unwrap(); + + storage.clear_nft_data(&chain).await.unwrap(); + test_clear_nft_target(&storage, &chain).await; +}); + +cross_test!(test_clear_all_nft, { + let chain = Chain::Bsc; + let nft_ctx = get_nft_ctx(&chain).await; + let storage = nft_ctx.lock_db().await.unwrap(); + NftListStorageOps::init(&storage, &chain).await.unwrap(); + let nft = nft(); + storage.add_nfts_to_list(chain, vec![nft], 28056726).await.unwrap(); + + storage.clear_all_nft_data().await.unwrap(); + test_clear_nft_target(&storage, &chain).await; +}); + +#[cfg(not(target_arch = "wasm32"))] +async fn test_clear_nft_target(storage: &S, chain: &Chain) { + let is_initialized = NftListStorageOps::is_initialized(storage, chain).await.unwrap(); + assert!(!is_initialized); + + let is_err = storage.get_nft_list(vec![*chain], false, 10, None, None).await.is_err(); + assert!(is_err); + + let is_err = storage.get_last_scanned_block(chain).await.is_err(); + assert!(is_err); +} + +#[cfg(target_arch = "wasm32")] +async fn test_clear_nft_target(storage: &S, chain: &Chain) { + let nft_list = storage.get_nft_list(vec![*chain], true, 1, None, None).await.unwrap(); + assert!(nft_list.nfts.is_empty()); +} + cross_test!(test_add_get_transfers, { let chain = Chain::Bsc; let nft_ctx = get_nft_ctx(&chain).await; @@ -527,7 +569,7 @@ cross_test!(test_get_update_transfer_meta, { storage.add_transfers_to_history(chain, transfers).await.unwrap(); let vec_token_add_id = storage.get_transfers_with_empty_meta(chain).await.unwrap(); - assert_eq!(vec_token_add_id.len(), 3); + assert_eq!(vec_token_add_id.len(), 2); let token_add = "0x5c7d6712dfaf0cb079d48981781c8705e8417ca0".to_string(); let transfer_meta = TransferMeta { @@ -693,3 +735,44 @@ cross_test!(test_exclude_transfer_phishing_spam, { .transfer_history; assert_eq!(transfers.len(), 1); }); + +cross_test!(test_clear_history, { + let chain = Chain::Bsc; + let nft_ctx = get_nft_ctx(&chain).await; + let storage = nft_ctx.lock_db().await.unwrap(); + NftTransferHistoryStorageOps::init(&storage, &chain).await.unwrap(); + let transfers = nft_transfer_history(); + storage.add_transfers_to_history(chain, transfers).await.unwrap(); + + storage.clear_history_data(&chain).await.unwrap(); + test_clear_history_target(&storage, &chain).await; +}); + +cross_test!(test_clear_all_history, { + let chain = Chain::Bsc; + let nft_ctx = get_nft_ctx(&chain).await; + let storage = nft_ctx.lock_db().await.unwrap(); + NftTransferHistoryStorageOps::init(&storage, &chain).await.unwrap(); + let transfers = nft_transfer_history(); + storage.add_transfers_to_history(chain, transfers).await.unwrap(); + + storage.clear_all_history_data().await.unwrap(); + test_clear_history_target(&storage, &chain).await; +}); + +#[cfg(not(target_arch = "wasm32"))] +async fn test_clear_history_target(storage: &S, chain: &Chain) { + let is_init = NftTransferHistoryStorageOps::is_initialized(storage, chain) + .await + .unwrap(); + assert!(!is_init); +} + +#[cfg(target_arch = "wasm32")] +async fn test_clear_history_target(storage: &S, chain: &Chain) { + let transfer_list = storage + .get_transfer_history(vec![*chain], true, 1, None, None) + .await + .unwrap(); + assert!(transfer_list.transfer_history.is_empty()); +} diff --git a/mm2src/coins/nft/storage/mod.rs b/mm2src/coins/nft/storage/mod.rs index c28c33ea54..ad255100c3 100644 --- a/mm2src/coins/nft/storage/mod.rs +++ b/mm2src/coins/nft/storage/mod.rs @@ -1,7 +1,6 @@ use crate::eth::EthTxFeeDetails; use crate::nft::nft_structs::{Chain, Nft, NftList, NftListFilters, NftTokenAddrId, NftTransferHistory, NftTransferHistoryFilters, NftsTransferHistoryList, TransferMeta}; -use crate::WithdrawError; use async_trait::async_trait; use ethereum_types::Address; use mm2_err_handle::mm_error::MmResult; @@ -28,10 +27,6 @@ pub enum RemoveNftResult { /// Defines the standard errors that can occur in NFT storage operations pub trait NftStorageError: std::fmt::Debug + NotMmError + NotEqual + Send {} -impl From for WithdrawError { - fn from(err: T) -> Self { WithdrawError::DbError(format!("{:?}", err)) } -} - /// Provides asynchronous operations for handling and querying NFT listings. #[async_trait] pub trait NftListStorageOps { @@ -112,6 +107,11 @@ pub trait NftListStorageOps { domain: String, possible_phishing: bool, ) -> MmResult<(), Self::Error>; + + async fn clear_nft_data(&self, chain: &Chain) -> MmResult<(), Self::Error>; + + /// Clears all nft list tables related to each chain. + async fn clear_all_nft_data(&self) -> MmResult<(), Self::Error>; } /// Provides asynchronous operations related to the history of NFT transfers. @@ -201,6 +201,11 @@ pub trait NftTransferHistoryStorageOps { domain: String, possible_phishing: bool, ) -> MmResult<(), Self::Error>; + + async fn clear_history_data(&self, chain: &Chain) -> MmResult<(), Self::Error>; + + /// Clears all nft history tables related to each chain. + async fn clear_all_history_data(&self) -> MmResult<(), Self::Error>; } /// `get_offset_limit` function calculates offset and limit for final result if we use pagination. diff --git a/mm2src/coins/nft/storage/sql_storage.rs b/mm2src/coins/nft/storage/sql_storage.rs index 86166a4793..6844b261d9 100644 --- a/mm2src/coins/nft/storage/sql_storage.rs +++ b/mm2src/coins/nft/storage/sql_storage.rs @@ -10,7 +10,7 @@ use db_common::sql_build::{SqlCondition, SqlQuery}; use db_common::sqlite::rusqlite::types::{FromSqlError, Type}; use db_common::sqlite::rusqlite::{Connection, Error as SqlError, Result as SqlResult, Row, Statement}; use db_common::sqlite::sql_builder::SqlBuilder; -use db_common::sqlite::{query_single_row, string_from_row, validate_table_name, CHECK_TABLE_EXISTS_SQL}; +use db_common::sqlite::{query_single_row, string_from_row, SafeTableName, CHECK_TABLE_EXISTS_SQL}; use ethereum_types::Address; use futures::lock::MutexGuard as AsyncMutexGuard; use mm2_err_handle::prelude::*; @@ -23,27 +23,27 @@ use std::num::NonZeroUsize; use std::str::FromStr; impl Chain { - fn nft_list_table_name(&self) -> SqlResult { + fn nft_list_table_name(&self) -> SqlResult { let name = self.to_ticker().to_owned() + "_nft_list"; - validate_table_name(&name)?; - Ok(name) + let safe_name = SafeTableName::new(&name)?; + Ok(safe_name) } - fn transfer_history_table_name(&self) -> SqlResult { + fn transfer_history_table_name(&self) -> SqlResult { let name = self.to_ticker().to_owned() + "_nft_transfer_history"; - validate_table_name(&name)?; - Ok(name) + let safe_name = SafeTableName::new(&name)?; + Ok(safe_name) } } -fn scanned_nft_blocks_table_name() -> SqlResult { +fn scanned_nft_blocks_table_name() -> SqlResult { let name = "scanned_nft_blocks".to_string(); - validate_table_name(&name)?; - Ok(name) + let safe_name = SafeTableName::new(&name)?; + Ok(safe_name) } fn create_nft_list_table_sql(chain: &Chain) -> MmResult { - let table_name = chain.nft_list_table_name()?; + let safe_table_name = chain.nft_list_table_name()?; let sql = format!( "CREATE TABLE IF NOT EXISTS {} ( token_address VARCHAR(256) NOT NULL, @@ -75,13 +75,13 @@ fn create_nft_list_table_sql(chain: &Chain) -> MmResult { details_json TEXT, PRIMARY KEY (token_address, token_id) );", - table_name + safe_table_name.inner() ); Ok(sql) } fn create_transfer_history_table_sql(chain: &Chain) -> Result { - let table_name = chain.transfer_history_table_name()?; + let safe_table_name = chain.transfer_history_table_name()?; let sql = format!( "CREATE TABLE IF NOT EXISTS {} ( transaction_hash VARCHAR(256) NOT NULL, @@ -105,19 +105,19 @@ fn create_transfer_history_table_sql(chain: &Chain) -> Result details_json TEXT, PRIMARY KEY (transaction_hash, log_index) );", - table_name + safe_table_name.inner() ); Ok(sql) } fn create_scanned_nft_blocks_sql() -> Result { - let table_name = scanned_nft_blocks_table_name()?; + let safe_table_name = scanned_nft_blocks_table_name()?; let sql = format!( "CREATE TABLE IF NOT EXISTS {} ( chain TEXT PRIMARY KEY, last_scanned_block INTEGER DEFAULT 0 );", - table_name + safe_table_name.inner() ); Ok(sql) } @@ -129,7 +129,7 @@ fn get_nft_list_builder_preimage(chains: Vec, filters: Option, filters: Option) -> Result { - let mut sql_builder = SqlBuilder::select_from(table_name); +fn nft_list_builder_preimage( + safe_table_name: SafeTableName, + filters: Option, +) -> Result { + let mut sql_builder = SqlBuilder::select_from(safe_table_name.inner()); if let Some(filters) = filters { if filters.exclude_spam { sql_builder.and_where("possible_spam == 0"); @@ -167,7 +170,7 @@ fn get_nft_transfer_builder_preimage( .into_iter() .map(|chain| { let table_name = chain.transfer_history_table_name()?; - let sql_builder = nft_history_table_builder_preimage(table_name.as_str(), filters)?; + let sql_builder = nft_history_table_builder_preimage(table_name, filters)?; let sql_string = sql_builder .sql() .map_err(|e| SqlError::ToSqlConversionFailure(e.into()))? @@ -184,10 +187,10 @@ fn get_nft_transfer_builder_preimage( } fn nft_history_table_builder_preimage( - table_name: &str, + safe_table_name: SafeTableName, filters: Option, ) -> Result { - let mut sql_builder = SqlBuilder::select_from(table_name); + let mut sql_builder = SqlBuilder::select_from(safe_table_name.inner()); if let Some(filters) = filters { if filters.send && !filters.receive { sql_builder.and_where_eq("status", "'Send'"); @@ -388,7 +391,7 @@ fn token_address_id_from_row(row: &Row<'_>) -> Result } fn insert_nft_in_list_sql(chain: &Chain) -> Result { - let table_name = chain.nft_list_table_name()?; + let safe_table_name = chain.nft_list_table_name()?; let sql = format!( "INSERT INTO {} ( token_address, token_id, chain, amount, block_number, contract_type, possible_spam, @@ -400,13 +403,13 @@ fn insert_nft_in_list_sql(chain: &Chain) -> Result { ?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, ?18, ?19, ?20, ?21, ?22, ?23, ?24, ?25, ?26, ?27 );", - table_name + safe_table_name.inner() ); Ok(sql) } fn insert_transfer_in_history_sql(chain: &Chain) -> Result { - let table_name = chain.transfer_history_table_name()?; + let safe_table_name = chain.transfer_history_table_name()?; let sql = format!( "INSERT INTO {} ( transaction_hash, log_index, chain, block_number, block_timestamp, contract_type, @@ -415,66 +418,69 @@ fn insert_transfer_in_history_sql(chain: &Chain) -> Result { ) VALUES ( ?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, ?18, ?19 );", - table_name + safe_table_name.inner() ); Ok(sql) } fn upsert_last_scanned_block_sql() -> Result { - let table_name = scanned_nft_blocks_table_name()?; + let safe_table_name = scanned_nft_blocks_table_name()?; let sql = format!( "INSERT OR REPLACE INTO {} (chain, last_scanned_block) VALUES (?1, ?2);", - table_name + safe_table_name.inner() ); Ok(sql) } fn refresh_nft_metadata_sql(chain: &Chain) -> Result { - let table_name = chain.nft_list_table_name()?; + let safe_table_name = chain.nft_list_table_name()?; let sql = format!( "UPDATE {} SET possible_spam = ?1, possible_phishing = ?2, collection_name = ?3, symbol = ?4, token_uri = ?5, token_domain = ?6, metadata = ?7, \ last_token_uri_sync = ?8, last_metadata_sync = ?9, raw_image_url = ?10, image_url = ?11, image_domain = ?12, token_name = ?13, description = ?14, \ attributes = ?15, animation_url = ?16, animation_domain = ?17, external_url = ?18, external_domain = ?19, image_details = ?20 WHERE token_address = ?21 AND token_id = ?22;", - table_name + safe_table_name.inner() ); Ok(sql) } fn update_transfers_meta_by_token_addr_id_sql(chain: &Chain) -> Result { - let table_name = chain.transfer_history_table_name()?; + let safe_table_name = chain.transfer_history_table_name()?; let sql = format!( "UPDATE {} SET token_uri = ?1, token_domain = ?2, collection_name = ?3, image_url = ?4, image_domain = ?5, \ token_name = ?6 WHERE token_address = ?7 AND token_id = ?8;", - table_name + safe_table_name.inner() ); Ok(sql) } fn update_transfer_spam_by_token_addr_id(chain: &Chain) -> Result { - let table_name = chain.transfer_history_table_name()?; + let safe_table_name = chain.transfer_history_table_name()?; let sql = format!( "UPDATE {} SET possible_spam = ?1 WHERE token_address = ?2 AND token_id = ?3;", - table_name + safe_table_name.inner() ); Ok(sql) } -fn select_last_block_number_sql(table_name: String) -> Result { +fn select_last_block_number_sql(safe_table_name: SafeTableName) -> Result { let sql = format!( "SELECT block_number FROM {} ORDER BY block_number DESC LIMIT 1", - table_name + safe_table_name.inner() ); Ok(sql) } fn select_last_scanned_block_sql() -> MmResult { let table_name = scanned_nft_blocks_table_name()?; - let sql = format!("SELECT last_scanned_block FROM {} WHERE chain=?1", table_name,); + let sql = format!("SELECT last_scanned_block FROM {} WHERE chain=?1", table_name.inner()); Ok(sql) } -fn delete_nft_sql(table_name: String) -> Result { - let sql = format!("DELETE FROM {} WHERE token_address=?1 AND token_id=?2", table_name); +fn delete_nft_sql(safe_table_name: SafeTableName) -> Result { + let sql = format!( + "DELETE FROM {} WHERE token_address=?1 AND token_id=?2", + safe_table_name.inner() + ); Ok(sql) } @@ -482,38 +488,44 @@ fn block_number_from_row(row: &Row<'_>) -> Result { row.get::<_, fn nft_amount_from_row(row: &Row<'_>) -> Result { row.get(0) } -fn get_nfts_by_token_address_statement(conn: &Connection, table_name: String) -> Result { - let sql_query = format!("SELECT * FROM {} WHERE token_address = ?", table_name); +fn get_nfts_by_token_address_statement( + conn: &Connection, + safe_table_name: SafeTableName, +) -> Result { + let sql_query = format!("SELECT * FROM {} WHERE token_address = ?", safe_table_name.inner()); let stmt = conn.prepare(&sql_query)?; Ok(stmt) } -fn get_token_addresses_statement(conn: &Connection, table_name: String) -> Result { - let sql_query = format!("SELECT DISTINCT token_address FROM {}", table_name); +fn get_token_addresses_statement(conn: &Connection, safe_table_name: SafeTableName) -> Result { + let sql_query = format!("SELECT DISTINCT token_address FROM {}", safe_table_name.inner()); let stmt = conn.prepare(&sql_query)?; Ok(stmt) } fn get_transfers_from_block_statement<'a>(conn: &'a Connection, chain: &'a Chain) -> Result, SqlError> { - let table_name = chain.transfer_history_table_name()?; + let safe_table_name = chain.transfer_history_table_name()?; let sql_query = format!( "SELECT * FROM {} WHERE block_number >= ? ORDER BY block_number ASC", - table_name + safe_table_name.inner() ); let stmt = conn.prepare(&sql_query)?; Ok(stmt) } fn get_transfers_by_token_addr_id_statement(conn: &Connection, chain: Chain) -> Result { - let table_name = chain.transfer_history_table_name()?; - let sql_query = format!("SELECT * FROM {} WHERE token_address = ? AND token_id = ?", table_name); + let safe_table_name = chain.transfer_history_table_name()?; + let sql_query = format!( + "SELECT * FROM {} WHERE token_address = ? AND token_id = ?", + safe_table_name.inner() + ); let stmt = conn.prepare(&sql_query)?; Ok(stmt) } fn get_transfers_with_empty_meta_builder<'a>(conn: &'a Connection, chain: &'a Chain) -> Result, SqlError> { - let table_name = chain.transfer_history_table_name()?; - let mut sql_builder = SqlQuery::select_from(conn, table_name.as_str())?; + let safe_table_name = chain.transfer_history_table_name()?; + let mut sql_builder = SqlQuery::select_from(conn, safe_table_name.inner())?; sql_builder .sql_builder() .distinct() @@ -528,6 +540,12 @@ fn get_transfers_with_empty_meta_builder<'a>(conn: &'a Connection, chain: &'a Ch Ok(sql_builder) } +fn is_table_empty(conn: &Connection, safe_table_name: SafeTableName) -> Result { + let query = format!("SELECT COUNT(*) FROM {}", safe_table_name.inner()); + conn.query_row(&query, [], |row| row.get::<_, i64>(0)) + .map(|count| count == 0) +} + #[async_trait] impl NftListStorageOps for AsyncMutexGuard<'_, AsyncConnection> { type Error = AsyncConnError; @@ -546,11 +564,12 @@ impl NftListStorageOps for AsyncMutexGuard<'_, AsyncConnection> { async fn is_initialized(&self, chain: &Chain) -> MmResult { let table_name = chain.nft_list_table_name()?; self.call(move |conn| { - let nft_list_initialized = query_single_row(conn, CHECK_TABLE_EXISTS_SQL, [table_name], string_from_row)?; + let nft_list_initialized = + query_single_row(conn, CHECK_TABLE_EXISTS_SQL, [table_name.inner()], string_from_row)?; let scanned_nft_blocks_initialized = query_single_row( conn, CHECK_TABLE_EXISTS_SQL, - [scanned_nft_blocks_table_name()?], + [scanned_nft_blocks_table_name()?.inner()], string_from_row, )?; Ok(nft_list_initialized.is_some() && scanned_nft_blocks_initialized.is_some()) @@ -660,7 +679,10 @@ impl NftListStorageOps for AsyncMutexGuard<'_, AsyncConnection> { ) -> MmResult, Self::Error> { let table_name = chain.nft_list_table_name()?; self.call(move |conn| { - let sql = format!("SELECT * FROM {} WHERE token_address=?1 AND token_id=?2", table_name); + let sql = format!( + "SELECT * FROM {} WHERE token_address=?1 AND token_id=?2", + table_name.inner() + ); let params = [token_address, token_id.to_string()]; let nft = query_single_row(conn, &sql, params, nft_from_row)?; Ok(nft) @@ -706,7 +728,7 @@ impl NftListStorageOps for AsyncMutexGuard<'_, AsyncConnection> { let table_name = chain.nft_list_table_name()?; let sql = format!( "SELECT amount FROM {} WHERE token_address=?1 AND token_id=?2", - table_name + table_name.inner() ); let params = [token_address, token_id.to_string()]; self.call(move |conn| { @@ -783,7 +805,7 @@ impl NftListStorageOps for AsyncMutexGuard<'_, AsyncConnection> { let table_name = chain.nft_list_table_name()?; let sql = format!( "UPDATE {} SET amount = ?1 WHERE token_address = ?2 AND token_id = ?3;", - table_name + table_name.inner() ); let scanned_block_params = [chain.to_ticker().to_string(), scanned_block.to_string()]; self.call(move |conn| { @@ -806,7 +828,7 @@ impl NftListStorageOps for AsyncMutexGuard<'_, AsyncConnection> { let table_name = chain.nft_list_table_name()?; let sql = format!( "UPDATE {} SET amount = ?1, block_number = ?2 WHERE token_address = ?3 AND token_id = ?4;", - table_name + table_name.inner() ); let scanned_block_params = [chain.to_ticker().to_string(), nft.block_number.to_string()]; self.call(move |conn| { @@ -846,7 +868,10 @@ impl NftListStorageOps for AsyncMutexGuard<'_, AsyncConnection> { possible_spam: bool, ) -> MmResult<(), Self::Error> { let table_name = chain.nft_list_table_name()?; - let sql = format!("UPDATE {} SET possible_spam = ?1 WHERE token_address = ?2;", table_name); + let sql = format!( + "UPDATE {} SET possible_spam = ?1 WHERE token_address = ?2;", + table_name.inner() + ); self.call(move |conn| { let sql_transaction = conn.transaction()?; let params = [Some(i32::from(possible_spam).to_string()), Some(token_address.clone())]; @@ -859,8 +884,9 @@ impl NftListStorageOps for AsyncMutexGuard<'_, AsyncConnection> { } async fn get_animation_external_domains(&self, chain: &Chain) -> MmResult, Self::Error> { - let table_name = chain.nft_list_table_name()?; + let safe_table_name = chain.nft_list_table_name()?; self.call(move |conn| { + let table_name = safe_table_name.inner(); let sql_query = format!( "SELECT DISTINCT animation_domain FROM {} UNION SELECT DISTINCT external_domain FROM {}", table_name, table_name @@ -886,7 +912,7 @@ impl NftListStorageOps for AsyncMutexGuard<'_, AsyncConnection> { let sql = format!( "UPDATE {} SET possible_phishing = ?1 WHERE token_domain = ?2 OR image_domain = ?2 OR animation_domain = ?2 OR external_domain = ?2;", - table_name + table_name.inner() ); self.call(move |conn| { let sql_transaction = conn.transaction()?; @@ -898,6 +924,43 @@ impl NftListStorageOps for AsyncMutexGuard<'_, AsyncConnection> { .await .map_to_mm(AsyncConnError::from) } + + async fn clear_nft_data(&self, chain: &Chain) -> MmResult<(), Self::Error> { + let table_nft_name = chain.nft_list_table_name()?; + let sql_nft = format!("DROP TABLE IF EXISTS {};", table_nft_name.inner()); + let table_scanned_blocks = scanned_nft_blocks_table_name()?; + let sql_scanned_block = format!("DELETE from {} where chain=?1", table_scanned_blocks.inner()); + let scanned_block_param = [chain.to_ticker()]; + self.call(move |conn| { + let sql_transaction = conn.transaction()?; + sql_transaction.execute(&sql_nft, [])?; + sql_transaction.execute(&sql_scanned_block, scanned_block_param)?; + sql_transaction.commit()?; + if is_table_empty(conn, table_scanned_blocks.clone())? { + conn.execute(&format!("DROP TABLE IF EXISTS {};", table_scanned_blocks.inner()), []) + .map(|_| ())?; + } + Ok(()) + }) + .await + .map_to_mm(AsyncConnError::from) + } + + async fn clear_all_nft_data(&self) -> MmResult<(), Self::Error> { + self.call(move |conn| { + let sql_transaction = conn.transaction()?; + for chain in Chain::variant_list().into_iter() { + let table_name = chain.nft_list_table_name()?; + sql_transaction.execute(&format!("DROP TABLE IF EXISTS {};", table_name.inner()), [])?; + } + let table_scanned_blocks = scanned_nft_blocks_table_name()?; + sql_transaction.execute(&format!("DROP TABLE IF EXISTS {};", table_scanned_blocks.inner()), [])?; + sql_transaction.commit()?; + Ok(()) + }) + .await + .map_to_mm(AsyncConnError::from) + } } #[async_trait] @@ -917,7 +980,8 @@ impl NftTransferHistoryStorageOps for AsyncMutexGuard<'_, AsyncConnection> { async fn is_initialized(&self, chain: &Chain) -> MmResult { let table_name = chain.transfer_history_table_name()?; self.call(move |conn| { - let nft_list_initialized = query_single_row(conn, CHECK_TABLE_EXISTS_SQL, [table_name], string_from_row)?; + let nft_list_initialized = + query_single_row(conn, CHECK_TABLE_EXISTS_SQL, [table_name.inner()], string_from_row)?; Ok(nft_list_initialized.is_some()) }) .await @@ -1066,7 +1130,7 @@ impl NftTransferHistoryStorageOps for AsyncMutexGuard<'_, AsyncConnection> { let table_name = chain.transfer_history_table_name()?; let sql = format!( "SELECT * FROM {} WHERE transaction_hash=?1 AND log_index = ?2", - table_name + table_name.inner() ); self.call(move |conn| { let transfer = query_single_row( @@ -1151,7 +1215,10 @@ impl NftTransferHistoryStorageOps for AsyncMutexGuard<'_, AsyncConnection> { possible_spam: bool, ) -> MmResult<(), Self::Error> { let table_name = chain.transfer_history_table_name()?; - let sql = format!("UPDATE {} SET possible_spam = ?1 WHERE token_address = ?2;", table_name); + let sql = format!( + "UPDATE {} SET possible_spam = ?1 WHERE token_address = ?2;", + table_name.inner() + ); self.call(move |conn| { let sql_transaction = conn.transaction()?; let params = [Some(i32::from(possible_spam).to_string()), Some(token_address.clone())]; @@ -1177,8 +1244,9 @@ impl NftTransferHistoryStorageOps for AsyncMutexGuard<'_, AsyncConnection> { } async fn get_domains(&self, chain: &Chain) -> MmResult, Self::Error> { - let table_name = chain.transfer_history_table_name()?; + let safe_table_name = chain.transfer_history_table_name()?; self.call(move |conn| { + let table_name = safe_table_name.inner(); let sql_query = format!( "SELECT DISTINCT token_domain FROM {} UNION SELECT DISTINCT image_domain FROM {}", table_name, table_name @@ -1200,10 +1268,10 @@ impl NftTransferHistoryStorageOps for AsyncMutexGuard<'_, AsyncConnection> { domain: String, possible_phishing: bool, ) -> MmResult<(), Self::Error> { - let table_name = chain.transfer_history_table_name()?; + let safe_table_name = chain.transfer_history_table_name()?; let sql = format!( "UPDATE {} SET possible_phishing = ?1 WHERE token_domain = ?2 OR image_domain = ?2;", - table_name + safe_table_name.inner() ); self.call(move |conn| { let sql_transaction = conn.transaction()?; @@ -1215,4 +1283,30 @@ impl NftTransferHistoryStorageOps for AsyncMutexGuard<'_, AsyncConnection> { .await .map_to_mm(AsyncConnError::from) } + + async fn clear_history_data(&self, chain: &Chain) -> MmResult<(), Self::Error> { + let table_name = chain.transfer_history_table_name()?; + self.call(move |conn| { + let sql_transaction = conn.transaction()?; + sql_transaction.execute(&format!("DROP TABLE IF EXISTS {};", table_name.inner()), [])?; + sql_transaction.commit()?; + Ok(()) + }) + .await + .map_to_mm(AsyncConnError::from) + } + + async fn clear_all_history_data(&self) -> MmResult<(), Self::Error> { + self.call(move |conn| { + let sql_transaction = conn.transaction()?; + for chain in Chain::variant_list().into_iter() { + let table_name = chain.transfer_history_table_name()?; + sql_transaction.execute(&format!("DROP TABLE IF EXISTS {};", table_name.inner()), [])?; + } + sql_transaction.commit()?; + Ok(()) + }) + .await + .map_to_mm(AsyncConnError::from) + } } diff --git a/mm2src/coins/nft/storage/wasm/wasm_storage.rs b/mm2src/coins/nft/storage/wasm/wasm_storage.rs index faf79f663a..7cb1c9b998 100644 --- a/mm2src/coins/nft/storage/wasm/wasm_storage.rs +++ b/mm2src/coins/nft/storage/wasm/wasm_storage.rs @@ -421,6 +421,27 @@ impl NftListStorageOps for NftCacheIDBLocked<'_> { update_nft_phishing_for_index(&table, &chain_str, external_index, &domain, possible_phishing).await?; Ok(()) } + + async fn clear_nft_data(&self, chain: &Chain) -> MmResult<(), Self::Error> { + let db_transaction = self.get_inner().transaction().await?; + let nft_table = db_transaction.table::().await?; + let last_scanned_block_table = db_transaction.table::().await?; + + nft_table.delete_items_by_index("chain", chain.to_string()).await?; + last_scanned_block_table + .delete_item_by_unique_index("chain", chain.to_string()) + .await?; + Ok(()) + } + + async fn clear_all_nft_data(&self) -> MmResult<(), Self::Error> { + let db_transaction = self.get_inner().transaction().await?; + let nft_table = db_transaction.table::().await?; + let last_scanned_block_table = db_transaction.table::().await?; + nft_table.clear().await?; + last_scanned_block_table.clear().await?; + Ok(()) + } } #[async_trait] @@ -722,6 +743,20 @@ impl NftTransferHistoryStorageOps for NftCacheIDBLocked<'_> { .await?; Ok(()) } + + async fn clear_history_data(&self, chain: &Chain) -> MmResult<(), Self::Error> { + let db_transaction = self.get_inner().transaction().await?; + let table = db_transaction.table::().await?; + table.delete_items_by_index("chain", chain.to_string()).await?; + Ok(()) + } + + async fn clear_all_history_data(&self) -> MmResult<(), Self::Error> { + let db_transaction = self.get_inner().transaction().await?; + let table = db_transaction.table::().await?; + table.clear().await?; + Ok(()) + } } async fn update_transfer_phishing_for_index( diff --git a/mm2src/coins/rpc_command/get_new_address.rs b/mm2src/coins/rpc_command/get_new_address.rs index 6d57870ac5..ee8d8ad73d 100644 --- a/mm2src/coins/rpc_command/get_new_address.rs +++ b/mm2src/coins/rpc_command/get_new_address.rs @@ -8,7 +8,7 @@ use common::{HttpStatusCode, SuccessResponse}; use crypto::hw_rpc_task::{HwConnectStatuses, HwRpcTaskAwaitingStatus, HwRpcTaskUserAction, HwRpcTaskUserActionRequest}; use crypto::{from_hw_error, Bip44Chain, HwError, HwRpcError, WithHwRpcError}; use derive_more::Display; -use enum_from::EnumFromTrait; +use enum_derives::EnumFromTrait; use http::StatusCode; use mm2_core::mm_ctx::MmArc; use mm2_err_handle::prelude::*; diff --git a/mm2src/coins/rpc_command/init_create_account.rs b/mm2src/coins/rpc_command/init_create_account.rs index c67cd8cd3d..6e8b47047d 100644 --- a/mm2src/coins/rpc_command/init_create_account.rs +++ b/mm2src/coins/rpc_command/init_create_account.rs @@ -8,7 +8,7 @@ use common::{true_f, HttpStatusCode, SuccessResponse}; use crypto::hw_rpc_task::{HwConnectStatuses, HwRpcTaskAwaitingStatus, HwRpcTaskUserAction, HwRpcTaskUserActionRequest}; use crypto::{from_hw_error, Bip44Chain, HwError, HwRpcError, RpcDerivationPath, WithHwRpcError}; use derive_more::Display; -use enum_from::EnumFromTrait; +use enum_derives::EnumFromTrait; use http::StatusCode; use mm2_core::mm_ctx::MmArc; use mm2_err_handle::prelude::*; diff --git a/mm2src/crypto/Cargo.toml b/mm2src/crypto/Cargo.toml index c4e4a84f92..80fd38a212 100644 --- a/mm2src/crypto/Cargo.toml +++ b/mm2src/crypto/Cargo.toml @@ -14,7 +14,7 @@ bitcrypto = { path = "../mm2_bitcoin/crypto" } bs58 = "0.4.0" common = { path = "../common" } derive_more = "0.99" -enum_from = { path = "../derives/enum_from" } +enum_derives = { path = "../derives/enum_derives" } enum-primitive-derive = "0.2" futures = "0.3" hex = "0.4.2" diff --git a/mm2src/crypto/src/shared_db_id.rs b/mm2src/crypto/src/shared_db_id.rs index 1aff809ca9..8c78baaaf3 100644 --- a/mm2src/crypto/src/shared_db_id.rs +++ b/mm2src/crypto/src/shared_db_id.rs @@ -1,6 +1,6 @@ use crate::privkey::private_from_seed_hash; use derive_more::Display; -use enum_from::EnumFromStringify; +use enum_derives::EnumFromStringify; use keys::{Error as KeysError, KeyPair}; use mm2_err_handle::prelude::*; use primitives::hash::H160; diff --git a/mm2src/db_common/src/sqlite.rs b/mm2src/db_common/src/sqlite.rs index 5da327c0fd..af9c4905e5 100644 --- a/mm2src/db_common/src/sqlite.rs +++ b/mm2src/db_common/src/sqlite.rs @@ -86,12 +86,33 @@ pub fn validate_ident(ident: &str) -> SqlResult<()> { validate_ident_impl(ident, |c| c.is_alphanumeric() || c == '_' || c == '.') } +/// Validates a table name against SQL injection risks. +/// +/// This function checks if the provided `table_name` is safe for use in SQL queries. +/// It disallows any characters in the table name that may lead to SQL injection, only +/// allowing alphanumeric characters and underscores. pub fn validate_table_name(table_name: &str) -> SqlResult<()> { // As per https://stackoverflow.com/a/3247553, tables can't be the target of parameter substitution. // So we have to use a plain concatenation disallowing any characters in the table name that may lead to SQL injection. validate_ident_impl(table_name, |c| c.is_alphanumeric() || c == '_') } +/// Represents a SQL table name that has been validated for safety. +#[derive(Clone, Debug)] +pub struct SafeTableName(String); + +impl SafeTableName { + /// Creates a new SafeTableName, validating the provided table name. + pub fn new(table_name: &str) -> SqlResult { + validate_table_name(table_name)?; + Ok(SafeTableName(table_name.to_owned())) + } + + /// Retrieves the table name. + #[inline(always)] + pub fn inner(&self) -> &str { &self.0 } +} + /// Calculates the offset to skip records by uuid. /// Expects `query_builder` to have where clauses applied *before* calling this fn. pub fn offset_by_uuid( @@ -317,6 +338,10 @@ impl StringError { pub fn into_boxed(self) -> Box { Box::new(self) } } +/// Internal function to validate identifiers such as table names. +/// +/// This function is a general-purpose identifier validator. It uses a closure to determine +/// the validity of each character in the provided identifier. fn validate_ident_impl(ident: &str, is_valid: F) -> SqlResult<()> where F: Fn(char) -> bool, diff --git a/mm2src/derives/enum_from/Cargo.toml b/mm2src/derives/enum_derives/Cargo.toml similarity index 90% rename from mm2src/derives/enum_from/Cargo.toml rename to mm2src/derives/enum_derives/Cargo.toml index 5a1c58c1fb..9518bc6d1a 100644 --- a/mm2src/derives/enum_from/Cargo.toml +++ b/mm2src/derives/enum_derives/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "enum_from" +name = "enum_derives" version = "0.1.0" edition = "2021" diff --git a/mm2src/derives/enum_from/src/from_inner.rs b/mm2src/derives/enum_derives/src/from_inner.rs similarity index 100% rename from mm2src/derives/enum_from/src/from_inner.rs rename to mm2src/derives/enum_derives/src/from_inner.rs diff --git a/mm2src/derives/enum_from/src/from_stringify.rs b/mm2src/derives/enum_derives/src/from_stringify.rs similarity index 100% rename from mm2src/derives/enum_from/src/from_stringify.rs rename to mm2src/derives/enum_derives/src/from_stringify.rs diff --git a/mm2src/derives/enum_from/src/from_trait.rs b/mm2src/derives/enum_derives/src/from_trait.rs similarity index 100% rename from mm2src/derives/enum_from/src/from_trait.rs rename to mm2src/derives/enum_derives/src/from_trait.rs diff --git a/mm2src/derives/enum_from/src/lib.rs b/mm2src/derives/enum_derives/src/lib.rs similarity index 79% rename from mm2src/derives/enum_from/src/lib.rs rename to mm2src/derives/enum_derives/src/lib.rs index 83aef6f6f0..666e95a0e3 100644 --- a/mm2src/derives/enum_from/src/lib.rs +++ b/mm2src/derives/enum_derives/src/lib.rs @@ -3,21 +3,23 @@ use proc_macro2::{Ident, Span, TokenStream as TokenStream2}; use quote::quote; use std::fmt; use syn::Meta::List; -use syn::{parse_macro_input, Data, DeriveInput, Error, Field, Fields, ImplGenerics, Type, TypeGenerics, WhereClause}; +use syn::{parse_macro_input, Data, DataEnum, DeriveInput, Error, Field, Fields, ImplGenerics, Type, TypeGenerics, + WhereClause}; use syn::{Attribute, NestedMeta, Variant}; mod from_inner; mod from_stringify; mod from_trait; -const MACRO_IDENT: &str = "EnumFromInner"; +const ENUM_FROM_INNER_IDENT: &str = "EnumFromInner"; +const ENUM_VARIANT_LIST_IDENT: &str = "EnumVariantList"; /// Implements `From` trait for the given enumeration. /// /// # Usage /// /// ```rust -/// use enum_from::EnumFromInner; +/// use enum_derives::EnumFromInner; /// /// #[derive(EnumFromInner)] /// enum FooBar { @@ -50,7 +52,7 @@ pub fn enum_from_inner(input: TokenStream) -> TokenStream { /// # Usage /// /// ```rust -/// use enum_from::EnumFromTrait; +/// use enum_derives::EnumFromTrait; /// /// #[derive(EnumFromTrait)] /// enum FooBar { @@ -92,7 +94,7 @@ pub fn enum_from_trait(input: TokenStream) -> TokenStream { /// ### USAGE: /// /// ```rust -/// use enum_from::EnumFromStringify; +/// use enum_derives::EnumFromStringify; /// use std::fmt::{Display, Formatter}; /// use std::io::{Error, ErrorKind}; /// @@ -124,6 +126,59 @@ pub fn derive(input: TokenStream) -> TokenStream { } } +/// `EnumVariantList` is a procedural macro used to generate a method that returns a vector containing all variants of an enum. +/// This macro is intended for use with simple enums (enums without associated data or complex structures). +/// +/// ### USAGE: +/// +/// ```rust +/// use enum_derives::EnumVariantList; +/// +/// #[derive(EnumVariantList)] +/// enum Chain { +/// Avalanche, +/// Bsc, +/// Eth, +/// Fantom, +/// Polygon, +/// } +/// +///#[test] +///fn test_enum_variant_list() { +/// let all_chains = Chain::variant_list(); +/// assert_eq!(all_chains, vec![ +/// Chain::Avalanche, +/// Chain::Bsc, +/// Chain::Eth, +/// Chain::Fantom, +/// Chain::Polygon +/// ]); +///} +/// ``` +#[proc_macro_derive(EnumVariantList)] +pub fn enum_variant_list(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let name = input.ident; + + let variants = match input.data { + Data::Enum(DataEnum { variants, .. }) => variants, + Data::Struct(_) => return CompileError::expected_enum(ENUM_VARIANT_LIST_IDENT, "struct").into(), + Data::Union(_) => return CompileError::expected_enum(ENUM_VARIANT_LIST_IDENT, "union").into(), + }; + + let variant_list: Vec<_> = variants.iter().map(|v| &v.ident).collect(); + + let expanded = quote! { + impl #name { + pub fn variant_list() -> Vec<#name> { + vec![ #( #name::#variant_list ),* ] + } + } + }; + + TokenStream::from(expanded) +} + #[allow(clippy::enum_variant_names)] #[derive(Clone, Copy)] enum MacroAttr { @@ -148,8 +203,8 @@ impl fmt::Display for MacroAttr { struct CompileError(String); impl CompileError { - fn expected_enum(found: &str) -> CompileError { - CompileError(format!("'{}' cannot be implement for a {}", MACRO_IDENT, found)) + fn expected_enum(macro_ident: &str, found: &str) -> CompileError { + CompileError(format!("'{}' cannot be implement for a {}", macro_ident, found)) } fn expected_unnamed_inner(attr: MacroAttr) -> CompileError { @@ -229,8 +284,8 @@ impl<'a> UnnamedInnerField<'a> { fn derive_enum_from_macro(input: DeriveInput, attr: MacroAttr) -> Result { let enumeration = match input.data { Data::Enum(ref enumeration) => enumeration, - Data::Struct(_) => return Err(CompileError::expected_enum("struct")), - Data::Union(_) => return Err(CompileError::expected_enum("union")), + Data::Struct(_) => return Err(CompileError::expected_enum(ENUM_FROM_INNER_IDENT, "struct")), + Data::Union(_) => return Err(CompileError::expected_enum(ENUM_FROM_INNER_IDENT, "union")), }; let ctx = IdentCtx::from(&input); diff --git a/mm2src/mm2_db/Cargo.toml b/mm2src/mm2_db/Cargo.toml index 7f2418159b..5f5374acad 100644 --- a/mm2src/mm2_db/Cargo.toml +++ b/mm2src/mm2_db/Cargo.toml @@ -10,7 +10,7 @@ doctest = false async-trait = "0.1" common = { path = "../common" } derive_more = "0.99" -enum_from = { path = "../derives/enum_from" } +enum_derives = { path = "../derives/enum_derives" } futures = { version = "0.3", package = "futures", features = ["compat", "async-await", "thread-pool"] } itertools = "0.10" hex = "0.4.2" diff --git a/mm2src/mm2_db/src/indexed_db/drivers/cursor/cursor.rs b/mm2src/mm2_db/src/indexed_db/drivers/cursor/cursor.rs index bba549d3bd..2e7de40aae 100644 --- a/mm2src/mm2_db/src/indexed_db/drivers/cursor/cursor.rs +++ b/mm2src/mm2_db/src/indexed_db/drivers/cursor/cursor.rs @@ -3,7 +3,7 @@ use crate::indexed_db::db_driver::{InternalItem, ItemId}; use crate::indexed_db::BeBigUint; use common::wasm::{deserialize_from_js, serialize_to_js, stringify_js_error}; use derive_more::Display; -use enum_from::EnumFromTrait; +use enum_derives::EnumFromTrait; use futures::channel::mpsc; use futures::StreamExt; use js_sys::Array; diff --git a/mm2src/mm2_db/src/indexed_db/drivers/transaction.rs b/mm2src/mm2_db/src/indexed_db/drivers/transaction.rs index f76794b2de..1973e6c9bd 100644 --- a/mm2src/mm2_db/src/indexed_db/drivers/transaction.rs +++ b/mm2src/mm2_db/src/indexed_db/drivers/transaction.rs @@ -1,7 +1,7 @@ use super::IdbObjectStoreImpl; use common::wasm::stringify_js_error; use derive_more::Display; -use enum_from::EnumFromTrait; +use enum_derives::EnumFromTrait; use mm2_err_handle::prelude::*; use serde_json::Value as Json; use std::collections::HashSet; diff --git a/mm2src/mm2_main/Cargo.toml b/mm2src/mm2_main/Cargo.toml index aa68d9469a..2458214e11 100644 --- a/mm2src/mm2_main/Cargo.toml +++ b/mm2src/mm2_main/Cargo.toml @@ -41,7 +41,7 @@ db_common = { path = "../db_common" } derive_more = "0.99" either = "1.6" ethereum-types = { version = "0.13", default-features = false, features = ["std", "serialize"] } -enum_from = { path = "../derives/enum_from" } +enum_derives = { path = "../derives/enum_derives" } enum-primitive-derive = "0.2" futures01 = { version = "0.1", package = "futures" } futures = { version = "0.3.1", package = "futures", features = ["compat", "async-await"] } diff --git a/mm2src/mm2_main/src/lp_init/init_hw.rs b/mm2src/mm2_main/src/lp_init/init_hw.rs index d9bc45da49..18e97bc096 100644 --- a/mm2src/mm2_main/src/lp_init/init_hw.rs +++ b/mm2src/mm2_main/src/lp_init/init_hw.rs @@ -6,7 +6,7 @@ use crypto::hw_rpc_task::{HwConnectStatuses, HwRpcTaskAwaitingStatus, HwRpcTaskU use crypto::{from_hw_error, CryptoCtx, CryptoCtxError, HwCtxInitError, HwDeviceInfo, HwError, HwPubkey, HwRpcError, HwWalletType, WithHwRpcError}; use derive_more::Display; -use enum_from::EnumFromTrait; +use enum_derives::EnumFromTrait; use http::StatusCode; use mm2_core::mm_ctx::MmArc; use mm2_err_handle::prelude::*; diff --git a/mm2src/mm2_main/src/lp_init/init_metamask.rs b/mm2src/mm2_main/src/lp_init/init_metamask.rs index f624a7c5c4..f80afe5878 100644 --- a/mm2src/mm2_main/src/lp_init/init_metamask.rs +++ b/mm2src/mm2_main/src/lp_init/init_metamask.rs @@ -4,7 +4,7 @@ use common::{HttpStatusCode, SerdeInfallible, SuccessResponse}; use crypto::metamask::{from_metamask_error, MetamaskError, MetamaskRpcError, WithMetamaskRpcError}; use crypto::{CryptoCtx, CryptoCtxError, MetamaskCtxInitError}; use derive_more::Display; -use enum_from::EnumFromTrait; +use enum_derives::EnumFromTrait; use http::StatusCode; use mm2_core::mm_ctx::MmArc; use mm2_err_handle::common_errors::WithInternal; diff --git a/mm2src/mm2_main/src/lp_native_dex.rs b/mm2src/mm2_main/src/lp_native_dex.rs index 8f98bfc975..22a6cd3d4e 100644 --- a/mm2src/mm2_main/src/lp_native_dex.rs +++ b/mm2src/mm2_main/src/lp_native_dex.rs @@ -24,7 +24,7 @@ use common::executor::{SpawnFuture, Timer}; use common::log::{info, warn}; use crypto::{from_hw_error, CryptoCtx, CryptoInitError, HwError, HwProcessingError, HwRpcError, WithHwRpcError}; use derive_more::Display; -use enum_from::EnumFromTrait; +use enum_derives::EnumFromTrait; use mm2_core::mm_ctx::{MmArc, MmCtx}; use mm2_err_handle::common_errors::InternalError; use mm2_err_handle::prelude::*; diff --git a/mm2src/mm2_main/src/rpc/dispatcher/dispatcher.rs b/mm2src/mm2_main/src/rpc/dispatcher/dispatcher.rs index 7512829efc..b121b31ede 100644 --- a/mm2src/mm2_main/src/rpc/dispatcher/dispatcher.rs +++ b/mm2src/mm2_main/src/rpc/dispatcher/dispatcher.rs @@ -49,7 +49,8 @@ use http::Response; use mm2_core::mm_ctx::MmArc; use mm2_err_handle::prelude::*; use mm2_rpc::mm_protocol::{MmRpcBuilder, MmRpcRequest, MmRpcVersion}; -use nft::{get_nft_list, get_nft_metadata, get_nft_transfers, refresh_nft_metadata, update_nft, withdraw_nft}; +use nft::{clear_nft_db, get_nft_list, get_nft_metadata, get_nft_transfers, refresh_nft_metadata, update_nft, + withdraw_nft}; use serde::de::DeserializeOwned; use serde_json::{self as json, Value as Json}; use std::net::SocketAddr; @@ -156,6 +157,7 @@ async fn dispatcher_v2(request: MmRpcRequest, ctx: MmArc) -> DispatcherResult handle_mmrpc(ctx, request, add_delegation).await, "add_node_to_version_stat" => handle_mmrpc(ctx, request, add_node_to_version_stat).await, "best_orders" => handle_mmrpc(ctx, request, best_orders_rpc_v2).await, + "clear_nft_db" => handle_mmrpc(ctx, request, clear_nft_db).await, "enable_bch_with_tokens" => handle_mmrpc(ctx, request, enable_platform_coin_with_tokens::).await, "enable_slp" => handle_mmrpc(ctx, request, enable_token::).await, "enable_eth_with_tokens" => handle_mmrpc(ctx, request, enable_platform_coin_with_tokens::).await,