Skip to content

Commit

Permalink
fix(sb_ai): Sessions are dropped while still in use (#443)
Browse files Browse the repository at this point in the history
* fix(sb_ai): error handling on `run session`

- removing `unwrap()` from `run session`

* fix(sb_ai): cleanup logic, passing Session ref to js land

- adding currently active session's refs to `OpState`

* stamp: clippy

* feat(k6): add scenario for `ort-rust-backend`

- adding a scenario that uses `transformers.js` + `ort rust backend`
  • Loading branch information
kallebysantos authored Nov 13, 2024
1 parent 390a5c1 commit e522925
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 11 deletions.
27 changes: 20 additions & 7 deletions crates/sb_ai/onnxruntime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,29 @@ pub(crate) mod onnx;
pub(crate) mod session;
mod tensor;

use std::{borrow::Cow, collections::HashMap};
use std::{borrow::Cow, cell::RefCell, collections::HashMap, rc::Rc, sync::Arc};

use anyhow::Result;
use deno_core::op2;
use anyhow::{anyhow, Result};
use deno_core::{op2, OpState};

use model_session::{ModelInfo, ModelSession};
use ort::Session;
use tensor::{JsTensor, ToJsTensor};

#[op2]
#[to_v8]
pub fn op_sb_ai_ort_init_session(#[buffer] model_bytes: &[u8]) -> Result<ModelInfo> {
pub fn op_sb_ai_ort_init_session(
state: Rc<RefCell<OpState>>,
#[buffer] model_bytes: &[u8],
) -> Result<ModelInfo> {
let mut state = state.borrow_mut();
let model_info = ModelSession::from_bytes(model_bytes)?;

let mut sessions = { state.try_take::<Vec<Arc<Session>>>().unwrap_or_default() };

sessions.push(model_info.inner());
state.put(sessions);

Ok(model_info.info())
}

Expand All @@ -25,10 +35,11 @@ pub fn op_sb_ai_ort_run_session(
#[string] model_id: String,
#[serde] input_values: HashMap<String, JsTensor>,
) -> Result<HashMap<String, ToJsTensor>> {
let model = ModelSession::from_id(model_id).unwrap();
let model = ModelSession::from_id(model_id.to_owned())
.ok_or(anyhow!("could not find session for id={model_id:?}"))?;

let model_session = model.inner();

// println!("{model_session:?}");
let input_values = input_values
.into_iter()
.map(|(key, value)| {
Expand All @@ -44,7 +55,9 @@ pub fn op_sb_ai_ort_run_session(
// We need to `pop` over outputs to get 'value' ownership, since keys are attached to 'model_session' lifetime
// it can't be iterated with `into_iter()`
for _ in 0..outputs.len() {
let (key, value) = outputs.pop_first().unwrap();
let (key, value) = outputs.pop_first().ok_or(anyhow!(
"could not retrieve output value from model session"
))?;

let value = ToJsTensor::from_ort_tensor(value)?;

Expand Down
2 changes: 2 additions & 0 deletions crates/sb_ai/onnxruntime/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ pub fn cleanup() -> Result<usize, AnyError> {
let mut to_be_removed = vec![];

for (key, session) in &mut *guard {
// Since we're currently referencing the session at this point,
// it also increments the counter, so we need to check: counter > 1
if Arc::strong_count(session) > 1 {
continue;
}
Expand Down
19 changes: 19 additions & 0 deletions examples/k6-ort-rust-backend/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import { env, pipeline } from 'https://cdn.jsdelivr.net/npm/@huggingface/[email protected]';

// Always fetch models remotely: skip the browser cache and local model lookup.
// NOTE(review): the package specifier in the import URL above looks like it was
// mangled by an email-obfuscation filter — confirm the intended
// `@huggingface/transformers@<version>` pin.
env.useBrowserCache = false;
env.allowLocalModels = false;

// Build the feature-extraction pipeline once at startup; every request reuses it.
const pipe = await pipeline('feature-extraction', 'supabase/gte-small', { device: 'auto' });

Deno.serve(async (req) => {
  const { text_for_embedding } = await req.json();

  // Mean-pooled, normalized embedding of the input text.
  const embedding = await pipe(text_for_embedding, { pooling: 'mean', normalize: true });

  return Response.json({ length: embedding.ort_tensor.size });
});
4 changes: 0 additions & 4 deletions examples/main/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@ console.log('main function started');

// cleanup unused sessions every 30s
// setInterval(async () => {
// const { activeUserWorkersCount } = await EdgeRuntime.getRuntimeMetrics();
// if (activeUserWorkersCount > 0) {
// return;
// }
// try {
// const cleanupCount = await EdgeRuntime.ai.tryCleanupUnusedSession();
// if (cleanupCount == 0) {
Expand Down
70 changes: 70 additions & 0 deletions k6/specs/ort-rust-backend.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
./scripts/run.sh
#!/usr/bin/env bash
GIT_V_TAG=0.1.1 cargo build --features cli/tracing && \
EDGE_RUNTIME_WORKER_POOL_SIZE=8 \
EDGE_RUNTIME_PORT=9998 RUST_BACKTRACE=full ./target/debug/edge-runtime "$@" start \
--main-service ./examples/main \
--event-worker ./examples/event-manager
*/

import http from "k6/http";

import { check, fail } from "k6";
import { Options } from "k6/options";

import { target } from "../config";

/** @ts-ignore */
import { randomIntBetween } from "https://jslib.k6.io/k6-utils/1.2.0/index.js";
import { MSG_CANCELED } from "../constants";

// k6 load profile: one constant-VU scenario that drives steady traffic at the
// ort-rust-backend function for the whole run.
export const options: Options = {
    scenarios: {
        simple: {
            executor: "constant-vus",
            // 12 concurrent virtual users for 3 minutes
            vus: 12,
            duration: "3m",
        }
    }
};

// Start loading the shared generators module once at module scope; `setup`
// awaits the same promise.
const GENERATORS = import("../generators");

// k6 setup hook: pre-generate the word corpus handed to every VU iteration.
export async function setup() {
    const generators = await GENERATORS;
    return {
        words: generators.makeText(1000),
    };
}

// One VU iteration: POST a random word from the corpus to the embedding
// function and accept either a 200 or a deliberate request-cancellation.
export default function ort_rust_backend(data: { words: string[] }) {
    const idx = randomIntBetween(0, data.words.length - 1);
    const word = data.words[idx];
    console.debug(`WORD[${idx}]: ${word}`);

    const res = http.post(
        `${target}/k6-ort-rust-backend`,
        JSON.stringify({ "text_for_embedding": word }),
    );

    const isOk = check(res, {
        "status is 200": (r) => r.status === 200,
    });

    // A cancelled request surfaces as a 500 whose `msg` matches MSG_CANCELED.
    const isRequestCancelled = check(res, {
        "request cancelled": (r) => {
            const msg = r.json("msg");
            return r.status === 500 && msg === MSG_CANCELED;
        },
    });

    if (!(isOk || isRequestCancelled)) {
        console.log(res.body);
        fail("unexpected response");
    }
}

0 comments on commit e522925

Please sign in to comment.