From 2d5146e1efa008b811aea6da381c5e311ceeff04 Mon Sep 17 00:00:00 2001 From: Vasily Styagov Date: Thu, 17 Oct 2024 18:58:21 +0100 Subject: [PATCH] restore test for restake_all --- contract/src/jar/api.rs | 41 +++--- contract/src/jar/tests/restake_all.rs | 185 +++++++++++++------------- contract/src/jar/view.rs | 11 +- contract/src/product/helpers.rs | 19 ++- contract/src/product/model/v2.rs | 9 ++ 5 files changed, 154 insertions(+), 111 deletions(-) diff --git a/contract/src/jar/api.rs b/contract/src/jar/api.rs index 1462b9a2..4dff473b 100644 --- a/contract/src/jar/api.rs +++ b/contract/src/jar/api.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, ops::Deref}; +use std::{collections::HashMap, convert::Into, ops::Deref}; use near_sdk::{env, json_types::U128, near_bindgen, require, AccountId}; use sweat_jar_model::{ @@ -18,7 +18,7 @@ use crate::{ }; impl Contract { - fn restake_internal(&mut self, product: &ProductV2) -> TokenAmount { + fn restake_internal(&mut self, product: &ProductV2) -> Option { require!(product.is_enabled, "The product is disabled"); let account_id = env::predecessor_account_id(); @@ -29,7 +29,9 @@ impl Contract { let (amount, partition_index) = jar.get_liquid_balance(&product.terms, now); - require!(amount > 0, "Nothing to restake"); + if amount == 0 { + return None; + } self.update_jar_cache(&account_id, &product.id); @@ -38,7 +40,7 @@ impl Contract { jar.clean_up_deposits(partition_index); account.deposit(&product.id, amount); - amount + Some(amount) } fn get_total_interest_for_account(&self, account: &AccountV2) -> AggregatedInterestView { @@ -100,7 +102,8 @@ impl JarApi for Contract { fn restake(&mut self, product_id: ProductId) { self.migrate_account_if_needed(&env::predecessor_account_id()); - self.restake_internal(&self.get_product(&product_id)); + let result = self.restake_internal(&self.get_product(&product_id)); + require!(result.is_some(), "Nothing to restake"); // TODO: add event logging } @@ -110,18 +113,24 @@ impl JarApi for Contract { self.migrate_account_if_needed(&account_id); - let product_ids = product_ids.unwrap_or_else(|| { - self.get_account(&account_id) - .jars - .keys() - .filter(|product_id| self.get_product(product_id).is_enabled) - .cloned() - .collect() - }); + let products: Vec = product_ids + .unwrap_or_else(|| { + self.get_account(&account_id) + .jars + .keys() + .cloned() + .collect::>() + }) + .iter() + .map(|product_id| self.get_product(product_id)) + .filter(|product| product.is_enabled) + .collect(); + let mut result: Vec<(ProductId, TokenAmount)> = vec![]; - for product_id in product_ids.iter() { - let amount = self.restake_internal(&self.get_product(product_id)); - result.push((product_id.clone(), amount)); + for product in products.iter() { + if let Some(amount) = self.restake_internal(product) { + result.push((product.id.clone(), amount)); + } } // TODO: add event logging diff --git a/contract/src/jar/tests/restake_all.rs b/contract/src/jar/tests/restake_all.rs index 2953bbbf..bb1238e0 100644 --- a/contract/src/jar/tests/restake_all.rs +++ b/contract/src/jar/tests/restake_all.rs @@ -1,92 +1,99 @@ -// use near_sdk::test_utils::test_env::alice; -// use sweat_jar_model::{ -// api::{ClaimApi, JarApi}, -// MS_IN_YEAR, -// }; -// -// use crate::{ -// common::tests::Context, -// jar::model::Jar, -// product::model::Product, -// test_utils::{admin, PRINCIPAL}, -// }; -// -// #[test] -// fn restake_all() { -// let alice = alice(); -// let admin = admin(); -// -// let restakable_product = Product::new().id("restakable_product").with_allows_restaking(true); -// -// let disabled_restakable_product = Product::new() -// .id("disabled_restakable_product") -// .with_allows_restaking(true) -// .enabled(false); -// -// let non_restakable_product = Product::new().id("non_restakable_product").with_allows_restaking(false); -// -// let long_term_restakable_product = Product::new() -// .id("long_term_restakable_product") -// .with_allows_restaking(true) -// .lockup_term(MS_IN_YEAR * 2); -// -// let restakable_jar_1 = Jar::new(0).product_id(&restakable_product.id).principal(PRINCIPAL); -// let restakable_jar_2 = Jar::new(1).product_id(&restakable_product.id).principal(PRINCIPAL); -// -// let disabled_jar = Jar::new(2) -// .product_id(&disabled_restakable_product.id) -// .principal(PRINCIPAL); -// -// let non_restakable_jar = Jar::new(3).product_id(&non_restakable_product.id).principal(PRINCIPAL); -// -// let long_term_jar = Jar::new(4) -// .product_id(&long_term_restakable_product.id) -// .principal(PRINCIPAL); -// -// let mut context = Context::new(admin) -// .with_products(&[ -// restakable_product, -// disabled_restakable_product, -// non_restakable_product, -// long_term_restakable_product, -// ]) -// .with_jars(&[ -// restakable_jar_1.clone(), -// restakable_jar_2.clone(), -// disabled_jar.clone(), -// non_restakable_jar.clone(), -// long_term_jar.clone(), -// ]); -// -// context.set_block_timestamp_in_days(366); -// -// context.switch_account(&alice); -// -// let restaked_jars = context.contract().restake_all(None); -// -// assert_eq!(restaked_jars.len(), 2); -// assert_eq!( -// restaked_jars.iter().map(|j| j.id.0).collect::>(), -// // 4 was last jar is, so 2 new restaked jars will have ids 5 and 6 -// vec![5, 6] -// ); -// -// let all_jars = context.contract().get_jars_for_account(alice); -// -// assert_eq!( -// all_jars.iter().map(|j| j.id.0).collect::>(), -// [ -// restakable_jar_1.id, -// restakable_jar_2.id, -// disabled_jar.id, -// non_restakable_jar.id, -// long_term_jar.id, -// 5, -// 6, -// ] -// ) -// } -// +use near_sdk::test_utils::test_env::alice; +use sweat_jar_model::{ + api::{JarApi, ProductApi}, + MS_IN_DAY, MS_IN_YEAR, +}; + +use crate::{ + common::tests::Context, + jar::{model::JarV2, view::create_synthetic_jar_id}, + product::model::{ + v2::{Apy, FixedProductTerms, Terms}, + ProductV2, + }, + test_utils::{admin, PRINCIPAL}, +}; + +#[test] +fn restake_all() { + let alice = alice(); + let admin = admin(); + + let regular_product = ProductV2::new().id("regular_product"); + let regular_product_to_disable = ProductV2::new().id("disabled_product"); + let long_term_product = ProductV2::new() + .id("long_term_product") + .with_terms(Terms::Fixed(FixedProductTerms { + lockup_term: MS_IN_YEAR * 2, + apy: Apy::new_downgradable(), + })); + let long_term_product_to_disable = ProductV2::new() + .id("long_term_disabled_product") + .with_terms(Terms::Fixed(FixedProductTerms { + lockup_term: MS_IN_YEAR * 2, + apy: Apy::new_downgradable(), + })); + + let regular_product_jar = JarV2::new() + .with_deposit(0, PRINCIPAL) + .with_deposit(MS_IN_DAY, PRINCIPAL); + let product_to_disable_jar = JarV2::new().with_deposit(0, PRINCIPAL); + let long_term_product_jar = JarV2::new().with_deposit(0, PRINCIPAL); + let long_term_product_to_disable_jar = JarV2::new().with_deposit(0, PRINCIPAL); + + let mut context = Context::new(admin.clone()) + .with_products(&[ + regular_product.clone(), + regular_product_to_disable.clone(), + long_term_product.clone(), + long_term_product_to_disable.clone(), + ]) + .with_jars( + &alice, + &[ + (regular_product.id.clone(), regular_product_jar), + (regular_product_to_disable.id.clone(), product_to_disable_jar), + (long_term_product.id.clone(), long_term_product_jar), + ( + long_term_product_to_disable.id.clone(), + long_term_product_to_disable_jar, + ), + ], + ); + + context.set_block_timestamp_in_ms(MS_IN_YEAR); + + context.switch_account(&admin); + context.with_deposit_yocto(1, |context| { + context + .contract() + .set_enabled(regular_product_to_disable.id.clone(), false) + }); + context.with_deposit_yocto(1, |context| { + context + .contract() + .set_enabled(long_term_product_to_disable.id.clone(), false) + }); + + let restaking_time = MS_IN_YEAR + 2 * MS_IN_DAY; + context.set_block_timestamp_in_ms(restaking_time); + + context.switch_account(&alice); + let restaked_jars = context.contract().restake_all(None); + assert_eq!(restaked_jars.len(), 1); + assert_eq!( + restaked_jars.first().unwrap(), + &(regular_product.id.clone(), PRINCIPAL * 2) + ); + + let all_jars = context.contract().get_jars_for_account(alice); + let all_jar_ids = all_jars.iter().map(|j| j.id.clone()).collect::>(); + assert!(all_jar_ids.contains(&create_synthetic_jar_id(regular_product.id, restaking_time))); + assert!(all_jar_ids.contains(&create_synthetic_jar_id(regular_product_to_disable.id, 0))); + assert!(all_jar_ids.contains(&create_synthetic_jar_id(long_term_product.id, 0))); + assert!(all_jar_ids.contains(&create_synthetic_jar_id(long_term_product_to_disable.id, 0))); +} + // #[test] // fn restake_all_after_maturity_for_restakable_product_one_jar() { // let alice = alice(); diff --git a/contract/src/jar/view.rs b/contract/src/jar/view.rs index 8b1f0b21..78dafc89 100644 --- a/contract/src/jar/view.rs +++ b/contract/src/jar/view.rs @@ -1,7 +1,10 @@ use near_sdk::json_types::{U128, U64}; use sweat_jar_model::{jar::JarView, ProductId}; -use crate::jar::model::{Jar, JarV2}; +use crate::{ + common::Timestamp, + jar::model::{Jar, JarV2}, +}; impl From for JarView { fn from(value: Jar) -> Self { @@ -35,7 +38,7 @@ impl From<&DetailedJarV2> for Vec { .deposits .iter() .map(|deposit| JarView { - id: format!("{}_{}", product_id.clone(), deposit.created_at), + id: create_synthetic_jar_id(product_id.clone(), deposit.created_at), product_id: product_id.clone(), created_at: deposit.created_at.into(), principal: deposit.principal.into(), @@ -43,3 +46,7 @@ impl From<&DetailedJarV2> for Vec { .collect() } } + +pub fn create_synthetic_jar_id(product_id: ProductId, created_at: Timestamp) -> String { + format!("{}_{}", product_id.clone(), created_at) +} diff --git a/contract/src/product/helpers.rs b/contract/src/product/helpers.rs index d5262536..6c16bc6b 100644 --- a/contract/src/product/helpers.rs +++ b/contract/src/product/helpers.rs @@ -53,10 +53,7 @@ impl ProductV2 { cap: Cap { min: 0, max: 1_000_000 }, terms: Terms::Fixed(FixedProductTerms { lockup_term: MS_IN_YEAR, - apy: Apy::Downgradable(DowngradableApy { - default: UDecimal::new(20, 2), - fallback: UDecimal::new(10, 2), - }), + apy: Apy::new_downgradable(), }), withdrawal_fee: None, public_key: None, @@ -126,3 +123,17 @@ impl Into for u32 { Apy::Constant(UDecimal::new(self.into(), 2)) } } + +// TODO: move to tests +impl Apy { + fn new_constant() -> Self { + Apy::Constant(UDecimal::new(10, 2)) + } + + pub(crate) fn new_downgradable() -> Self { + Apy::Downgradable(DowngradableApy { + default: UDecimal::new(20, 2), + fallback: UDecimal::new(10, 2), + }) + } +} diff --git a/contract/src/product/model/v2.rs b/contract/src/product/model/v2.rs index 265a799c..c01e5c8f 100644 --- a/contract/src/product/model/v2.rs +++ b/contract/src/product/model/v2.rs @@ -342,6 +342,7 @@ impl Apy { impl Contract { // UnorderedMap doesn't have cache and deserializes `Product` on each get // This cached getter significantly reduces gas usage + #[cfg(not(test))] pub(crate) fn get_product(&self, product_id: &ProductId) -> ProductV2 { self.products_cache .borrow_mut() @@ -353,4 +354,12 @@ impl Contract { }) .clone() } + + // We should avoid this caching behaviour in tests though + #[cfg(test)] + pub(crate) fn get_product(&self, product_id: &ProductId) -> ProductV2 { + self.products + .get(product_id) + .unwrap_or_else(|| env::panic_str(format!("Product {product_id} is not found").as_str())) + } }