Skip to content

Commit 1119b22

Browse files
committed
Add IP address based tracking
1 parent 3d1718a commit 1119b22

File tree

5 files changed

+165
-6
lines changed

5 files changed

+165
-6
lines changed

src/channel.rs

+10-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@ pub struct ChannelResponse {
1717
pub txid: String,
1818
}
1919

20-
pub async fn open_channel(state: AppState, payload: ChannelRequest) -> anyhow::Result<String> {
20+
pub async fn open_channel(
21+
state: AppState,
22+
x_forwarded_for: &str,
23+
payload: ChannelRequest,
24+
) -> anyhow::Result<String> {
2125
if payload.capacity > MAX_SEND_AMOUNT.try_into().unwrap() {
2226
anyhow::bail!("max capacity is 10,000,000");
2327
}
@@ -84,5 +88,10 @@ pub async fn open_channel(state: AppState, payload: ChannelRequest) -> anyhow::R
8488
None => anyhow::bail!("failed to open channel"),
8589
};
8690

91+
state
92+
.payments
93+
.add_payment(x_forwarded_for, payload.capacity as u64)
94+
.await;
95+
8796
Ok(txid)
8897
}

src/lightning.rs

+13-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,11 @@ pub struct LightningResponse {
2323
pub payment_hash: String,
2424
}
2525

26-
pub async fn pay_lightning(state: AppState, bolt11: &str) -> anyhow::Result<String> {
26+
pub async fn pay_lightning(
27+
state: AppState,
28+
x_forwarded_for: &str,
29+
bolt11: &str,
30+
) -> anyhow::Result<String> {
2731
let params = PaymentParams::from_str(bolt11).map_err(|_| anyhow::anyhow!("invalid bolt 11"))?;
2832

2933
let invoice = if let Some(invoice) = params.invoice() {
@@ -114,5 +118,13 @@ pub async fn pay_lightning(state: AppState, bolt11: &str) -> anyhow::Result<Stri
114118
response.payment_preimage
115119
};
116120

121+
state
122+
.payments
123+
.add_payment(
124+
x_forwarded_for,
125+
invoice.amount_milli_satoshis().unwrap_or(0),
126+
)
127+
.await;
128+
117129
Ok(hex::encode(payment_preimage))
118130
}

src/main.rs

+59-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use axum::extract::Query;
2+
use axum::headers::{HeaderMap, HeaderValue};
23
use axum::http::Uri;
34
use axum::{
45
http::StatusCode,
@@ -19,6 +20,7 @@ use tonic_openssl_lnd::LndLightningClient;
1920
use tower_http::cors::{AllowHeaders, AllowMethods, Any, CorsLayer};
2021

2122
use crate::nostr_dms::listen_to_nostr_dms;
23+
use crate::payments::PaymentsByIp;
2224
use bolt11::{request_bolt11, Bolt11Request, Bolt11Response};
2325
use channel::{open_channel, ChannelRequest, ChannelResponse};
2426
use lightning::{pay_lightning, LightningRequest, LightningResponse};
@@ -30,6 +32,7 @@ mod channel;
3032
mod lightning;
3133
mod nostr_dms;
3234
mod onchain;
35+
mod payments;
3336
mod setup;
3437

3538
#[derive(Clone)]
@@ -39,6 +42,7 @@ pub struct AppState {
3942
network: bitcoin::Network,
4043
lightning_client: LndLightningClient,
4144
lnurl: AsyncClient,
45+
payments: PaymentsByIp,
4246
}
4347

4448
impl AppState {
@@ -55,6 +59,7 @@ impl AppState {
5559
network,
5660
lightning_client,
5761
lnurl,
62+
payments: PaymentsByIp::new(),
5863
}
5964
}
6065
}
@@ -138,19 +143,41 @@ async fn main() -> anyhow::Result<()> {
138143
#[axum::debug_handler]
139144
async fn onchain_handler(
140145
Extension(state): Extension<AppState>,
146+
headers: HeaderMap,
141147
Json(payload): Json<OnchainRequest>,
142148
) -> Result<Json<OnchainResponse>, AppError> {
143-
let res = pay_onchain(state, payload).await?;
149+
// Extract the X-Forwarded-For header
150+
let x_forwarded_for = headers
151+
.get("x-forwarded-for")
152+
.and_then(|x| HeaderValue::to_str(x).ok())
153+
.unwrap_or("Unknown");
154+
155+
if state.payments.get_total_payments(x_forwarded_for).await > MAX_SEND_AMOUNT * 10 {
156+
return Err(AppError::new("Too many payments"));
157+
}
158+
159+
let res = pay_onchain(state, x_forwarded_for, payload).await?;
144160

145161
Ok(Json(res))
146162
}
147163

148164
#[axum::debug_handler]
149165
async fn lightning_handler(
150166
Extension(state): Extension<AppState>,
167+
headers: HeaderMap,
151168
Json(payload): Json<LightningRequest>,
152169
) -> Result<Json<LightningResponse>, AppError> {
153-
let payment_hash = pay_lightning(state, &payload.bolt11).await?;
170+
// Extract the X-Forwarded-For header
171+
let x_forwarded_for = headers
172+
.get("x-forwarded-for")
173+
.and_then(|x| HeaderValue::to_str(x).ok())
174+
.unwrap_or("Unknown");
175+
176+
if state.payments.get_total_payments(x_forwarded_for).await > MAX_SEND_AMOUNT * 10 {
177+
return Err(AppError::new("Too many payments"));
178+
}
179+
180+
let payment_hash = pay_lightning(state, x_forwarded_for, &payload.bolt11).await?;
154181

155182
Ok(Json(LightningResponse { payment_hash }))
156183
}
@@ -178,10 +205,21 @@ pub struct LnurlWithdrawParams {
178205
#[axum::debug_handler]
179206
async fn lnurlw_callback_handler(
180207
Extension(state): Extension<AppState>,
208+
headers: HeaderMap,
181209
Query(payload): Query<LnurlWithdrawParams>,
182210
) -> Result<Json<Value>, Json<Value>> {
183211
if payload.k1 == "k1" {
184-
pay_lightning(state, &payload.pr)
212+
// Extract the X-Forwarded-For header
213+
let x_forwarded_for = headers
214+
.get("x-forwarded-for")
215+
.and_then(|x| HeaderValue::to_str(x).ok())
216+
.unwrap_or("Unknown");
217+
218+
if state.payments.get_total_payments(x_forwarded_for).await > MAX_SEND_AMOUNT * 10 {
219+
return Err(Json(json!({"status": "ERROR", "reason": "Incorrect k1"})));
220+
}
221+
222+
pay_lightning(state, x_forwarded_for, &payload.pr)
185223
.await
186224
.map_err(|e| Json(json!({"status": "ERROR", "reason": format!("{e}")})))?;
187225
Ok(Json(json!({"status": "OK"})))
@@ -203,16 +241,33 @@ async fn bolt11_handler(
203241
#[axum::debug_handler]
204242
async fn channel_handler(
205243
Extension(state): Extension<AppState>,
244+
headers: HeaderMap,
206245
Json(payload): Json<ChannelRequest>,
207246
) -> Result<Json<ChannelResponse>, AppError> {
208-
let txid = open_channel(state, payload.clone()).await?;
247+
// Extract the X-Forwarded-For header
248+
let x_forwarded_for = headers
249+
.get("x-forwarded-for")
250+
.and_then(|x| HeaderValue::to_str(x).ok())
251+
.unwrap_or("Unknown");
252+
253+
if state.payments.get_total_payments(x_forwarded_for).await > MAX_SEND_AMOUNT * 10 {
254+
return Err(AppError::new("Too many payments"));
255+
}
256+
257+
let txid = open_channel(state, x_forwarded_for, payload).await?;
209258

210259
Ok(Json(ChannelResponse { txid }))
211260
}
212261

213262
// Make our own error that wraps `anyhow::Error`.
214263
struct AppError(anyhow::Error);
215264

265+
impl AppError {
266+
fn new(msg: &'static str) -> Self {
267+
AppError(anyhow::anyhow!(msg))
268+
}
269+
}
270+
216271
// Tell axum how to convert `AppError` into a response.
217272
impl IntoResponse for AppError {
218273
fn into_response(self) -> Response {

src/onchain.rs

+6
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ pub struct OnchainResponse {
1919

2020
pub async fn pay_onchain(
2121
state: AppState,
22+
x_forwarded_for: &str,
2223
payload: OnchainRequest,
2324
) -> anyhow::Result<OnchainResponse> {
2425
let res = {
@@ -61,6 +62,11 @@ pub async fn pay_onchain(
6162
wallet_client.send_coins(req).await?.into_inner()
6263
};
6364

65+
state
66+
.payments
67+
.add_payment(x_forwarded_for, amount.to_sat())
68+
.await;
69+
6470
OnchainResponse {
6571
txid: resp.txid,
6672
address: address.to_string(),

src/payments.rs

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
use std::collections::{HashMap, VecDeque};
2+
use std::sync::Arc;
3+
use std::time::{Duration, Instant};
4+
use tokio::sync::Mutex;
5+
6+
const CACHE_DURATION: Duration = Duration::from_secs(86_400); // 1 day
7+
8+
struct Payment {
9+
time: Instant,
10+
amount: u64,
11+
}
12+
13+
struct PaymentTracker {
14+
payments: VecDeque<Payment>,
15+
}
16+
17+
impl PaymentTracker {
18+
pub fn new() -> Self {
19+
PaymentTracker {
20+
payments: VecDeque::new(),
21+
}
22+
}
23+
24+
pub fn add_payment(&mut self, amount: u64) {
25+
let now = Instant::now();
26+
let payment = Payment { time: now, amount };
27+
28+
self.payments.push_back(payment);
29+
}
30+
31+
fn clean_old_payments(&mut self) {
32+
let now = Instant::now();
33+
while let Some(payment) = self.payments.front() {
34+
if now.duration_since(payment.time) < CACHE_DURATION {
35+
break;
36+
}
37+
38+
self.payments.pop_front();
39+
}
40+
}
41+
42+
pub fn sum_payments(&mut self) -> u64 {
43+
self.clean_old_payments();
44+
self.payments.iter().map(|p| p.amount).sum()
45+
}
46+
}
47+
48+
#[derive(Clone)]
49+
pub struct PaymentsByIp {
50+
trackers: Arc<Mutex<HashMap<String, PaymentTracker>>>,
51+
}
52+
53+
impl PaymentsByIp {
54+
pub fn new() -> Self {
55+
PaymentsByIp {
56+
trackers: Arc::new(Mutex::new(HashMap::new())),
57+
}
58+
}
59+
60+
// Add a payment to the tracker for the given ip
61+
pub async fn add_payment(&self, ip: &str, amount: u64) {
62+
let mut trackers = self.trackers.lock().await;
63+
let tracker = trackers
64+
.entry(ip.to_string())
65+
.or_insert_with(PaymentTracker::new);
66+
tracker.add_payment(amount);
67+
}
68+
69+
// Get the total amount of payments for the given ip
70+
pub async fn get_total_payments(&self, ip: &str) -> u64 {
71+
let mut trackers = self.trackers.lock().await;
72+
match trackers.get_mut(ip) {
73+
Some(tracker) => tracker.sum_payments(),
74+
None => 0,
75+
}
76+
}
77+
}

0 commit comments

Comments
 (0)