
Commit daa6743

refactor: move session management & port allocation to backend
- Remove the in-process `activeSessions` map and its cleanup logic from the TypeScript side.
- Introduce new Tauri commands in Rust:
  - `get_random_port` – picks an unused port using a seeded RNG and checks availability.
  - `find_session_by_model` – returns the `SessionInfo` for a given model ID.
  - `get_loaded_models` – returns a list of currently loaded model IDs.
- Update the extension's TypeScript code to use these commands via `invoke`:
  - `findSessionByModel`, `load`, `unload`, `chat`, `getLoadedModels`, and `embed` now operate asynchronously and query the backend.
- Remove the old `is_port_available` command and the custom port-checking loop.
- Simplify `onUnload` – session termination is now handled by the backend.
- Drop unused helpers (`sleep`, `waitForModelLoad`) and related port-availability code.
- Add missing Rust imports (`rand::{StdRng, Rng, SeedableRng}`, `HashSet`) and improve error handling.
- Register the new commands in `src-tauri/src/lib.rs` (replacing `is_port_available` with the three new commands).

This refactor centralises session state and port allocation in the Rust backend, eliminates duplicated logic, and resolves race conditions around model loading and session cleanup.
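For orientation, here is a minimal sketch of the call pattern the extension now uses against these backend commands. It is illustrative only: the `@tauri-apps/api/core` import path, the trimmed `SessionInfo` shape, and the `demo` helper are assumptions, not code from this commit; the actual changes are in the diffs below.

```ts
import { invoke } from '@tauri-apps/api/core' // assumed Tauri v2 import path

// Trimmed to the fields visible in the diff; the real type has more.
interface SessionInfo {
  pid: number
  port: number
  model_id: string
}

// Hypothetical helper showing the three new commands in use.
async function demo(modelId: string): Promise<void> {
  // Port allocation is now a single backend call instead of a JS retry loop.
  const port = await invoke<number>('get_random_port')

  // Session lookup reads the backend's process map; a missing session comes
  // back as null because the Rust command returns Option<SessionInfo>.
  const session = await invoke<SessionInfo | null>('find_session_by_model', { modelId })

  // The list of loaded models is likewise answered from backend state.
  const models = await invoke<string[]>('get_loaded_models')

  console.log(port, session?.model_id, models)
}
```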
1 parent 1f1605b commit daa6743


3 files changed: +101 −96 lines changed

extensions/llamacpp-extension/src/index.ts

Lines changed: 25 additions & 93 deletions
@@ -145,7 +145,6 @@ export default class llamacpp_extension extends AIEngine {
   readonly providerId: string = 'llamacpp'

   private config: LlamacppConfig
-  private activeSessions: Map<number, SessionInfo> = new Map()
   private providerPath!: string
   private apiSecret: string = 'JustAskNow'
   private pendingDownloads: Map<string, Promise<void>> = new Map()
@@ -771,16 +770,6 @@ export default class llamacpp_extension extends AIEngine {

   override async onUnload(): Promise<void> {
     // Terminate all active sessions
-    for (const [_, sInfo] of this.activeSessions) {
-      try {
-        await this.unload(sInfo.model_id)
-      } catch (error) {
-        logger.error(`Failed to unload model ${sInfo.model_id}:`, error)
-      }
-    }
-
-    // Clear the sessions map
-    this.activeSessions.clear()
   }

   onSettingUpdate<T>(key: string, value: T): void {
@@ -1104,75 +1093,21 @@ export default class llamacpp_extension extends AIEngine {
   * Function to find a random port
   */
   private async getRandomPort(): Promise<number> {
-    const MAX_ATTEMPTS = 20000
-    let attempts = 0
-
-    while (attempts < MAX_ATTEMPTS) {
-      const port = Math.floor(Math.random() * 1000) + 3000
-
-      const isAlreadyUsed = Array.from(this.activeSessions.values()).some(
-        (info) => info.port === port
-      )
-
-      if (!isAlreadyUsed) {
-        const isAvailable = await invoke<boolean>('is_port_available', { port })
-        if (isAvailable) return port
-      }
-
-      attempts++
-    }
-
-    throw new Error('Failed to find an available port for the model to load')
-  }
-
-  private async sleep(ms: number): Promise<void> {
-    return new Promise((resolve) => setTimeout(resolve, ms))
-  }
-
-  private async waitForModelLoad(
-    sInfo: SessionInfo,
-    timeoutMs = 240_000
-  ): Promise<void> {
-    await this.sleep(500) // Wait before first check
-    const start = Date.now()
-    while (Date.now() - start < timeoutMs) {
-      try {
-        const res = await fetch(`http://localhost:${sInfo.port}/health`)
-
-        if (res.status === 503) {
-          const body = await res.json()
-          const msg = body?.error?.message ?? 'Model loading'
-          logger.info(`waiting for model load... (${msg})`)
-        } else if (res.ok) {
-          const body = await res.json()
-          if (body.status === 'ok') {
-            return
-          } else {
-            logger.warn('Unexpected OK response from /health:', body)
-          }
-        } else {
-          logger.warn(`Unexpected status ${res.status} from /health`)
-        }
-      } catch (e) {
-        await this.unload(sInfo.model_id)
-        throw new Error(`Model appears to have crashed: ${e}`)
-      }
-
-      await this.sleep(800) // Retry interval
+    try {
+      const port = await invoke<number>('get_random_port')
+      return port
+    } catch {
+      logger.error('Unable to find a suitable port')
+      throw new Error('Unable to find a suitable port for model')
     }
-
-    await this.unload(sInfo.model_id)
-    throw new Error(
-      `Timed out loading model after ${timeoutMs}... killing llamacpp`
-    )
   }

   override async load(
     modelId: string,
     overrideSettings?: Partial<LlamacppConfig>,
     isEmbedding: boolean = false
   ): Promise<SessionInfo> {
-    const sInfo = this.findSessionByModel(modelId)
+    const sInfo = await this.findSessionByModel(modelId)
     if (sInfo) {
       throw new Error('Model already loaded!!')
     }
@@ -1342,11 +1277,6 @@ export default class llamacpp_extension extends AIEngine {
         libraryPath,
         args,
       })
-
-      // Store the session info for later use
-      this.activeSessions.set(sInfo.pid, sInfo)
-      await this.waitForModelLoad(sInfo)
-
       return sInfo
     } catch (error) {
       logger.error('Error in load command:\n', error)
@@ -1355,13 +1285,12 @@ export default class llamacpp_extension extends AIEngine {
   }

   override async unload(modelId: string): Promise<UnloadResult> {
-    const sInfo: SessionInfo = this.findSessionByModel(modelId)
+    const sInfo: SessionInfo = await this.findSessionByModel(modelId)
     if (!sInfo) {
       throw new Error(`No active session found for model: ${modelId}`)
     }
     const pid = sInfo.pid
     try {
-      this.activeSessions.delete(pid)

       // Pass the PID as the session_id
       const result = await invoke<UnloadResult>('unload_llama_model', {
@@ -1373,13 +1302,11 @@ export default class llamacpp_extension extends AIEngine {
         logger.info(`Successfully unloaded model with PID ${pid}`)
       } else {
         logger.warn(`Failed to unload model: ${result.error}`)
-        this.activeSessions.set(sInfo.pid, sInfo)
       }

       return result
     } catch (error) {
       logger.error('Error in unload command:', error)
-      this.activeSessions.set(sInfo.pid, sInfo)
       return {
         success: false,
         error: `Failed to unload model: ${error}`,
@@ -1502,17 +1429,21 @@ export default class llamacpp_extension extends AIEngine {
     }
   }

-  private findSessionByModel(modelId: string): SessionInfo | undefined {
-    return Array.from(this.activeSessions.values()).find(
-      (session) => session.model_id === modelId
-    )
+  private async findSessionByModel(modelId: string): Promise<SessionInfo> {
+    try {
+      let sInfo = await invoke<SessionInfo>('find_session_by_model', {modelId})
+      return sInfo
+    } catch (e) {
+      logger.error(e)
+      throw new Error(e)
+    }
   }

   override async chat(
     opts: chatCompletionRequest,
     abortController?: AbortController
   ): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>> {
-    const sessionInfo = this.findSessionByModel(opts.model)
+    const sessionInfo = await this.findSessionByModel(opts.model)
     if (!sessionInfo) {
       throw new Error(`No active session found for model: ${opts.model}`)
     }
@@ -1528,7 +1459,6 @@ export default class llamacpp_extension extends AIEngine {
         throw new Error('Model appears to have crashed! Please reload!')
       }
     } else {
-      this.activeSessions.delete(sessionInfo.pid)
       throw new Error('Model have crashed! Please reload!')
     }
     const baseUrl = `http://localhost:${sessionInfo.port}/v1`
@@ -1577,11 +1507,13 @@ export default class llamacpp_extension extends AIEngine {
   }

   override async getLoadedModels(): Promise<string[]> {
-    let lmodels: string[] = []
-    for (const [_, sInfo] of this.activeSessions) {
-      lmodels.push(sInfo.model_id)
-    }
-    return lmodels
+    try {
+      let models: string[] = await invoke<string[]>('get_loaded_models')
+      return models
+    } catch (e) {
+      logger.error(e)
+      throw new Error(e)
+    }
   }

   async getDevices(): Promise<DeviceList[]> {
@@ -1611,7 +1543,7 @@ export default class llamacpp_extension extends AIEngine {
   }

   async embed(text: string[]): Promise<EmbeddingResponse> {
-    let sInfo = this.findSessionByModel('sentence-transformer-mini')
+    let sInfo = await this.findSessionByModel('sentence-transformer-mini')
     if (!sInfo) {
       const downloadedModelList = await this.list()
       if (
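One subtlety in the TypeScript changes above: `findSessionByModel` is typed as `Promise<SessionInfo>`, yet callers such as `load`, `unload`, `chat`, and `embed` still guard with `if (!sInfo)`. That guard keeps working because the Rust command returns `Result<Option<SessionInfo>, String>`, so a missing session resolves to `null` rather than rejecting the promise. A hypothetical wrapper that spells this out in the type (not part of the commit; behaviour unchanged):

```ts
import { invoke } from '@tauri-apps/api/core' // assumed import path

// Trimmed shape, as in the earlier sketch.
type SessionInfo = { pid: number; port: number; model_id: string }

// Making the "no session" case explicit documents why callers can test `if (!sInfo)`:
// the Rust side returns Option<SessionInfo>, which arrives here as null.
async function findSessionByModel(modelId: string): Promise<SessionInfo | null> {
  return invoke<SessionInfo | null>('find_session_by_model', { modelId })
}
```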

src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs

Lines changed: 73 additions & 2 deletions
@@ -1,7 +1,9 @@
 use base64::{engine::general_purpose, Engine as _};
 use hmac::{Hmac, Mac};
+use rand::{rngs::StdRng, Rng, SeedableRng};
 use serde::{Deserialize, Serialize};
 use sha2::Sha256;
+use std::collections::HashSet;
 use std::path::PathBuf;
 use std::process::Stdio;
 use std::time::Duration;
@@ -724,11 +726,80 @@ pub async fn is_process_running(pid: i32, state: State<'_, AppState>) -> Result<
 }

 // check port availability
-#[tauri::command]
-pub fn is_port_available(port: u16) -> bool {
+fn is_port_available(port: u16) -> bool {
     std::net::TcpListener::bind(("127.0.0.1", port)).is_ok()
 }

+#[tauri::command]
+pub async fn get_random_port(state: State<'_, AppState>) -> Result<u16, String> {
+    const MAX_ATTEMPTS: u32 = 20000;
+    let mut attempts = 0;
+    let mut rng = StdRng::from_entropy();
+
+    // Get all active ports from sessions
+    let map = state.llama_server_process.lock().await;
+
+    let used_ports: HashSet<u16> = map
+        .values()
+        .filter_map(|session| {
+            // Convert valid ports to u16 (filter out placeholder ports like -1)
+            if session.info.port > 0 && session.info.port <= u16::MAX as i32 {
+                Some(session.info.port as u16)
+            } else {
+                None
+            }
+        })
+        .collect();
+
+    drop(map); // unlock early
+
+    while attempts < MAX_ATTEMPTS {
+        let port = rng.gen_range(3000..4000);
+
+        if used_ports.contains(&port) {
+            attempts += 1;
+            continue;
+        }
+
+        if is_port_available(port) {
+            return Ok(port);
+        }
+
+        attempts += 1;
+    }
+
+    Err("Failed to find an available port for the model to load".into())
+}
+
+// find session
+#[tauri::command]
+pub async fn find_session_by_model(
+    model_id: String,
+    state: State<'_, AppState>,
+) -> Result<Option<SessionInfo>, String> {
+    let map = state.llama_server_process.lock().await;
+
+    let session_info = map
+        .values()
+        .find(|backend_session| backend_session.info.model_id == model_id)
+        .map(|backend_session| backend_session.info.clone());
+
+    Ok(session_info)
+}
+
+// get running models
+#[tauri::command]
+pub async fn get_loaded_models(state: State<'_, AppState>) -> Result<Vec<String>, String> {
+    let map = state.llama_server_process.lock().await;
+
+    let model_ids = map
+        .values()
+        .map(|backend_session| backend_session.info.model_id.clone())
+        .collect();
+
+    Ok(model_ids)
+}
+
 // tests
 //
 #[cfg(test)]
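The three new commands lean on state that this diff does not show. Below is a rough sketch of the shapes they appear to assume; the `BackendSession` and `AppState` names, the PID key, and the exact fields are inferred from the usage above, not copied from the repository.

```rust
use std::collections::HashMap;

use serde::Serialize;
use tokio::sync::Mutex; // `.lock().await` in the commands implies an async mutex

// Inferred: the payload a command can clone and return to the frontend.
#[derive(Clone, Serialize)]
pub struct SessionInfo {
    pub pid: i32,
    pub port: i32, // kept signed so -1 can serve as a "not bound yet" placeholder
    pub model_id: String,
    // ...remaining fields elided
}

// Inferred: one entry per spawned llama.cpp server process.
pub struct BackendSession {
    pub info: SessionInfo,
    // ...child-process handle, etc.
}

// Inferred: the managed state received via `State<'_, AppState>`.
pub struct AppState {
    pub llama_server_process: Mutex<HashMap<i32, BackendSession>>, // keyed by PID here
}
```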

src-tauri/src/lib.rs

Lines changed: 3 additions & 1 deletion
@@ -95,7 +95,9 @@ pub fn run() {
             core::utils::extensions::inference_llamacpp_extension::server::load_llama_model,
             core::utils::extensions::inference_llamacpp_extension::server::unload_llama_model,
             core::utils::extensions::inference_llamacpp_extension::server::get_devices,
-            core::utils::extensions::inference_llamacpp_extension::server::is_port_available,
+            core::utils::extensions::inference_llamacpp_extension::server::get_random_port,
+            core::utils::extensions::inference_llamacpp_extension::server::find_session_by_model,
+            core::utils::extensions::inference_llamacpp_extension::server::get_loaded_models,
             core::utils::extensions::inference_llamacpp_extension::server::generate_api_key,
             core::utils::extensions::inference_llamacpp_extension::server::is_process_running,
         ])
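For readers not familiar with Tauri: these paths are entries in the app's command registry, which is what makes the new functions callable from TypeScript via `invoke`. A trimmed sketch of the surrounding builder code, assuming the standard Tauri pattern (the real `run()` in `lib.rs` registers many more commands and its managed state elsewhere):

```rust
pub fn run() {
    tauri::Builder::default()
        // AppState and other managed state are assumed to be registered via .manage(...)
        .invoke_handler(tauri::generate_handler![
            core::utils::extensions::inference_llamacpp_extension::server::get_random_port,
            core::utils::extensions::inference_llamacpp_extension::server::find_session_by_model,
            core::utils::extensions::inference_llamacpp_extension::server::get_loaded_models,
            // ...the rest of the llamacpp server commands listed above
        ])
        .run(tauri::generate_context!())
        .expect("error while running tauri application");
}
```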
