Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 65 additions & 25 deletions crates/lib/src/fee/fee.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,37 @@ pub struct TotalFeeCalculation {
pub transfer_fee_amount: u64,
}

impl TotalFeeCalculation {
pub fn new(
total_fee_lamports: u64,
base_fee: u64,
kora_signature_fee: u64,
fee_payer_outflow: u64,
payment_instruction_fee: u64,
transfer_fee_amount: u64,
) -> Self {
Self {
total_fee_lamports,
base_fee,
kora_signature_fee,
fee_payer_outflow,
payment_instruction_fee,
transfer_fee_amount,
}
}

pub fn new_fixed(total_fee_lamports: u64) -> Self {
Self {
total_fee_lamports,
base_fee: 0,
kora_signature_fee: 0,
fee_payer_outflow: 0,
payment_instruction_fee: 0,
transfer_fee_amount: 0,
}
}
}

pub struct FeeConfigUtil {}

impl FeeConfigUtil {
Expand Down Expand Up @@ -255,34 +286,43 @@ impl FeeConfigUtil {
) -> Result<TotalFeeCalculation, KoraError> {
let config = get_config()?;

// Check if the price is free, so that we can return early (and skip expensive RPC calls / estimation)
if matches!(&config.validation.price.model, PriceModel::Free) {
return Ok(TotalFeeCalculation {
total_fee_lamports: 0,
base_fee: 0,
kora_signature_fee: 0,
fee_payer_outflow: 0,
payment_instruction_fee: 0,
transfer_fee_amount: 0,
});
}
match &config.validation.price.model {
PriceModel::Free => Ok(TotalFeeCalculation::new_fixed(0)),
PriceModel::Fixed { .. } => {
let fixed_fee_lamports = config
.validation
.price
.get_required_lamports_with_fixed(rpc_client, price_source)
.await?;

// Get the raw transaction fees
let mut fee_calculation =
Self::estimate_transaction_fee(rpc_client, transaction, fee_payer, is_payment_required)
Ok(TotalFeeCalculation::new_fixed(fixed_fee_lamports))
}
PriceModel::Margin { .. } => {
// Get the raw transaction
let fee_calculation = Self::estimate_transaction_fee(
rpc_client,
transaction,
fee_payer,
is_payment_required,
)
.await?;

// Apply Kora's price model
let adjusted_fee = config
.validation
.price
.get_required_lamports(rpc_client, price_source, fee_calculation.total_fee_lamports)
.await?;

// Update the total with the price model applied
fee_calculation.total_fee_lamports = adjusted_fee;

Ok(fee_calculation)
let total_fee_lamports = config
.validation
.price
.get_required_lamports_with_margin(fee_calculation.total_fee_lamports)
.await?;

Ok(TotalFeeCalculation::new(
total_fee_lamports,
fee_calculation.base_fee,
fee_calculation.kora_signature_fee,
fee_calculation.fee_payer_outflow,
fee_calculation.payment_instruction_fee,
fee_calculation.transfer_fee_amount,
))
}
}
}

/// Calculate the fee in a specific token if provided
Expand Down
148 changes: 52 additions & 96 deletions crates/lib/src/fee/price.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,56 +26,59 @@ pub struct PriceConfig {
}

impl PriceConfig {
pub async fn get_required_lamports(
pub async fn get_required_lamports_with_fixed(
&self,
rpc_client: &RpcClient,
price_source: PriceSource,
) -> Result<u64, KoraError> {
if let PriceModel::Fixed { amount, token } = &self.model {
return TokenUtil::calculate_token_value_in_lamports(
*amount,
&Pubkey::from_str(token).map_err(|e| {
log::error!("Invalid Pubkey for price {e}");

KoraError::ConfigError
})?,
price_source,
rpc_client,
)
.await;
}

Err(KoraError::ConfigError)
}

pub async fn get_required_lamports_with_margin(
&self,
min_transaction_fee: u64,
) -> Result<u64, KoraError> {
match &self.model {
PriceModel::Margin { margin } => {
let multiplier = 1.0 + margin;
let result = min_transaction_fee as f64 * multiplier;

// Check for overflow/underflow before casting to u64
if result > u64::MAX as f64 || result < 0.0 {
log::error!(
"Margin calculation overflow: min_transaction_fee={}, margin={}, result={}",
min_transaction_fee,
margin,
result
);
return Err(KoraError::ValidationError(
"Margin calculation overflow".to_string(),
));
}

Ok(result as u64)
if let PriceModel::Margin { margin } = &self.model {
let multiplier = 1.0 + margin;
let result = min_transaction_fee as f64 * multiplier;

// Check for overflow/underflow before casting to u64
if result > u64::MAX as f64 || result < 0.0 {
log::error!(
"Margin calculation overflow: min_transaction_fee={}, margin={}, result={}",
min_transaction_fee,
margin,
result
);
return Err(KoraError::ValidationError("Margin calculation overflow".to_string()));
}
PriceModel::Fixed { amount, token } => {
Ok(TokenUtil::calculate_token_value_in_lamports(
*amount,
&Pubkey::from_str(token).map_err(|e| {
log::error!("Invalid Pubkey for price {e}");

KoraError::ConfigError
})?,
price_source,
rpc_client,
)
.await?)
}
PriceModel::Free => Ok(0),

return Ok(result as u64);
}

Err(KoraError::ConfigError)
}
}

#[cfg(test)]
mod tests {

use crate::tests::common::create_mock_rpc_client_with_mint;

use super::*;
use crate::tests::{common::create_mock_rpc_client_with_mint, config_mock::ConfigMockBuilder};

#[tokio::test]
async fn test_margin_model_get_required_lamports() {
Expand All @@ -85,12 +88,8 @@ mod tests {
let min_transaction_fee = 5000u64; // 5000 lamports base fee
let expected_lamports = (5000.0 * 1.1) as u64; // 5500 lamports

let rpc_client = create_mock_rpc_client_with_mint(6);

let result = price_config
.get_required_lamports(&rpc_client, PriceSource::Mock, min_transaction_fee)
.await
.unwrap();
let result =
price_config.get_required_lamports_with_margin(min_transaction_fee).await.unwrap();

assert_eq!(result, expected_lamports);
}
Expand All @@ -102,18 +101,15 @@ mod tests {

let min_transaction_fee = 5000u64;

let rpc_client = create_mock_rpc_client_with_mint(6);

let result = price_config
.get_required_lamports(&rpc_client, PriceSource::Mock, min_transaction_fee)
.await
.unwrap();
let result =
price_config.get_required_lamports_with_margin(min_transaction_fee).await.unwrap();

assert_eq!(result, min_transaction_fee);
}

#[tokio::test]
async fn test_fixed_model_get_required_lamports_with_oracle() {
let _m = ConfigMockBuilder::new().build_and_setup();
let rpc_client = create_mock_rpc_client_with_mint(6); // USDC has 6 decimals

let usdc_mint = "4zMMC9srt5Ri5X14GAgXhaHii3GnPAEERYPJgZJDncDU";
Expand All @@ -126,12 +122,9 @@ mod tests {

// Use Mock price source which returns 0.0001 SOL per USDC
let price_source = PriceSource::Mock;
let min_transaction_fee = 5000u64;

let result = price_config
.get_required_lamports(&rpc_client, price_source, min_transaction_fee)
.await
.unwrap();
let result =
price_config.get_required_lamports_with_fixed(&rpc_client, price_source).await.unwrap();

// Expected calculation:
// 1,000,000 base units / 10^6 = 1.0 USDC
Expand All @@ -142,6 +135,7 @@ mod tests {

#[tokio::test]
async fn test_fixed_model_get_required_lamports_with_custom_price() {
let _m = ConfigMockBuilder::new().build_and_setup();
let rpc_client = create_mock_rpc_client_with_mint(9); // 9 decimals token

let custom_token = "So11111111111111111111111111111111111111112"; // SOL mint
Expand All @@ -154,12 +148,9 @@ mod tests {

// Mock oracle returns 1.0 SOL price for SOL mint
let price_source = PriceSource::Mock;
let min_transaction_fee = 5000u64;

let result = price_config
.get_required_lamports(&rpc_client, price_source, min_transaction_fee)
.await
.unwrap();
let result =
price_config.get_required_lamports_with_fixed(&rpc_client, price_source).await.unwrap();

// Expected calculation:
// 500,000,000 base units / 10^9 = 0.5 tokens
Expand All @@ -170,6 +161,7 @@ mod tests {

#[tokio::test]
async fn test_fixed_model_get_required_lamports_small_amount() {
let _m = ConfigMockBuilder::new().build_and_setup();
let rpc_client = create_mock_rpc_client_with_mint(6); // USDC has 6 decimals

let usdc_mint = "4zMMC9srt5Ri5X14GAgXhaHii3GnPAEERYPJgZJDncDU";
Expand All @@ -181,12 +173,9 @@ mod tests {
};

let price_source = PriceSource::Mock;
let min_transaction_fee = 5000u64;

let result = price_config
.get_required_lamports(&rpc_client, price_source, min_transaction_fee)
.await
.unwrap();
let result =
price_config.get_required_lamports_with_fixed(&rpc_client, price_source).await.unwrap();

// Expected calculation:
// 1,000 base units / 10^6 = 0.001 USDC
Expand All @@ -195,39 +184,6 @@ mod tests {
assert_eq!(result, 100);
}

#[tokio::test]
async fn test_free_model_get_required_lamports() {
let rpc_client = create_mock_rpc_client_with_mint(6);

let price_config = PriceConfig { model: PriceModel::Free };

let min_transaction_fee = 10000u64;

let result = price_config
.get_required_lamports(&rpc_client, PriceSource::Mock, min_transaction_fee)
.await
.unwrap();

assert_eq!(result, 0);
}

#[tokio::test]
async fn test_free_model_get_required_lamports_with_high_base_fee() {
let rpc_client = create_mock_rpc_client_with_mint(6);

let price_config = PriceConfig { model: PriceModel::Free };

let min_transaction_fee = 1000000u64;

let result = price_config
.get_required_lamports(&rpc_client, PriceSource::Mock, min_transaction_fee)
.await
.unwrap();

// Free model should always return 0 regardless of base fee
assert_eq!(result, 0);
}

#[tokio::test]
async fn test_default_price_config() {
// Test that default creates Margin with 0.0 margin
Expand Down
2 changes: 1 addition & 1 deletion crates/lib/src/rpc_server/method/transfer_transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ mod tests {
let request = TransferTransactionRequest {
amount: 1000,
token: Pubkey::new_unique().to_string(),
source: "invalid_pubkey".to_string(),
source: "invalid".to_string(),
destination: Pubkey::new_unique().to_string(),
signer_key: None,
};
Expand Down