diff --git a/contracts/cosmwasm-vm/cw-xcall/src/fees.rs b/contracts/cosmwasm-vm/cw-xcall/src/fees.rs index 558b8e3f..29dadd55 100644 --- a/contracts/cosmwasm-vm/cw-xcall/src/fees.rs +++ b/contracts/cosmwasm-vm/cw-xcall/src/fees.rs @@ -1,3 +1,4 @@ +use cosmwasm_std::Coin; use cw_xcall_lib::network_address::NetId; use super::*; @@ -58,4 +59,15 @@ impl<'a> CwCallService<'a> { Ok(protocol_fee + conn_total) } + + pub fn get_total_paid(&self, deps: Deps, coins: &Vec) -> Result { + let config = self.get_config(deps.storage)?; + let mut total = 0_u128; + for c in coins.iter() { + if c.denom == config.denom { + total += c.amount.u128(); + } + } + Ok(total) + } } diff --git a/contracts/cosmwasm-vm/cw-xcall/src/send_call_message.rs b/contracts/cosmwasm-vm/cw-xcall/src/send_call_message.rs index f092af35..9f888003 100644 --- a/contracts/cosmwasm-vm/cw-xcall/src/send_call_message.rs +++ b/contracts/cosmwasm-vm/cw-xcall/src/send_call_message.rs @@ -60,6 +60,7 @@ impl<'a> CwCallService<'a> { let message: CSMessage = call_request.into(); let sn: i64 = if need_response { sequence_no as i64 } else { 0 }; + let mut total_spent = 0_u128; let submessages = confirmed_sources .iter() @@ -68,6 +69,7 @@ impl<'a> CwCallService<'a> { .query_connection_fee(deps.as_ref(), to.nid(), need_response, r) .and_then(|fee| { let fund = if fee > 0 { + total_spent = total_spent.checked_add(fee).unwrap(); coins(fee, config.denom.clone()) } else { vec![] @@ -78,8 +80,16 @@ impl<'a> CwCallService<'a> { }); }) .collect::, ContractError>>()?; - let protocol_fee = self.get_protocol_fee(deps.storage); + + let total_paid = self.get_total_paid(deps.as_ref(), &info.funds)?; let fee_handler = self.fee_handler().load(deps.storage)?; + let protocol_fee = self.get_protocol_fee(deps.as_ref().storage); + let total_fee_required = protocol_fee + total_spent; + + if total_paid < total_fee_required { + return Err(ContractError::InsufficientFunds); + } + let remaining = total_paid - total_spent; let event = event_xcall_message_sent(caller.to_string(), to.to_string(), sequence_no); println!("{LOG_PREFIX} Sent Bank Message"); @@ -89,13 +99,14 @@ impl<'a> CwCallService<'a> { .add_attribute("method", "send_packet") .add_attribute("sequence_no", sequence_no.to_string()) .add_event(event); - if protocol_fee > 0 { + if remaining > 0 { let msg = BankMsg::Send { to_address: fee_handler, - amount: coins(protocol_fee, config.denom), + amount: coins(remaining, config.denom), }; res = res.add_message(msg); } + Ok(res) } } diff --git a/contracts/cosmwasm-vm/cw-xcall/tests/test_call_message.rs b/contracts/cosmwasm-vm/cw-xcall/tests/test_call_message.rs index ea39ee8d..068e3c24 100644 --- a/contracts/cosmwasm-vm/cw-xcall/tests/test_call_message.rs +++ b/contracts/cosmwasm-vm/cw-xcall/tests/test_call_message.rs @@ -19,7 +19,7 @@ const MOCK_CONTRACT_TO_ADDR: &str = "cosmoscontract"; fn send_packet_by_non_contract_and_rollback_data_is_not_null() { let mut mock_deps = deps(); - let mock_info = create_mock_info(&alice().to_string(), "umlg", 2000); + let mock_info = create_mock_info(&alice().to_string(), "arch", 2000); let contract = CwCallService::default(); @@ -144,6 +144,10 @@ fn send_packet_failure_due_rollback_len() { }) } } + WasmQuery::Smart { + contract_addr: _, + msg: _, + } => SystemResult::Ok(ContractResult::Ok(to_binary(&0_u128).unwrap())), _ => todo!(), } }); @@ -182,7 +186,82 @@ fn send_packet_failure_due_rollback_len() { fn send_packet_success_needresponse() { let mut mock_deps = deps(); - let mock_info = create_mock_info(MOCK_CONTRACT_ADDR, "umlg", 2000); + let mock_info = create_mock_info(MOCK_CONTRACT_ADDR, "arch", 2000); + + let _env = mock_env(); + + let contract = CwCallService::default(); + contract + .instantiate( + mock_deps.as_mut(), + _env, + mock_info.clone(), + cw_xcall::msg::InstantiateMsg { + network_id: "nid".to_string(), + denom: "arch".to_string(), + }, + ) + .unwrap(); + + contract.sn().save(mock_deps.as_mut().storage, &0).unwrap(); + + mock_deps.querier.update_wasm(|r| { + let constract1 = Addr::unchecked(MOCK_CONTRACT_ADDR); + let mut storage1 = HashMap::::default(); + storage1.insert(b"the key".into(), b"the value".into()); + match r { + WasmQuery::ContractInfo { contract_addr } => { + if *contract_addr == constract1 { + let response = ContractInfoResponse::default(); + SystemResult::Ok(ContractResult::Ok(to_binary(&response).unwrap())) + } else { + SystemResult::Err(SystemError::NoSuchContract { + addr: contract_addr.clone(), + }) + } + } + WasmQuery::Smart { + contract_addr: _, + msg: _, + } => SystemResult::Ok(ContractResult::Ok(to_binary(&10_u128).unwrap())), + _ => todo!(), + } + }); + + contract + .store_default_connection( + mock_deps.as_mut().storage, + NetId::from("btp".to_owned()), + Addr::unchecked("hostaddress"), + ) + .unwrap(); + + contract + .send_call_message( + mock_deps.as_mut(), + mock_info, + mock_env(), + NetworkAddress::new("btp", MOCK_CONTRACT_TO_ADDR), + vec![1, 2, 3], + Some(vec![1, 2, 3]), + vec![], + vec![], + ) + .unwrap(); + + let result = contract + .get_call_request(mock_deps.as_ref().storage, 1) + .unwrap(); + + assert!(!result.enabled()) +} + +#[test] +#[should_panic(expected = "InsufficientFunds")] +fn send_packet_fail_insufficient_funds() { + let mut mock_deps = deps(); + + let mock_info = create_mock_info(MOCK_CONTRACT_ADDR, "arch", 0); let _env = mock_env();