Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mistralrs-bench/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ fn main() -> anyhow::Result<()> {
),
}
};
let mistralrs = MistralRsBuilder::new(pipeline, scheduler_config)
let mistralrs = MistralRsBuilder::new(pipeline, scheduler_config, false)
.with_no_prefix_cache(true)
.with_disable_eos_stop(true)
.build();
Expand Down
75 changes: 75 additions & 0 deletions mistralrs-core/src/engine/logger.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;

use tracing::info;

pub struct IntervalLogger {
    // Flag polled by the background logging thread; flipped on by `enable_logging`.
    enable_logging: Arc<AtomicBool>,
    // Cumulative count of sequences that were served from the prefix cache
    // (never reset, so the reported hitrate is over the engine's lifetime).
    prefix_cache_hits: Arc<AtomicUsize>,
    // Tokens processed since the last logging tick; swapped back to 0 each
    // interval so throughput is reported per-interval.
    tokens_processed: Arc<AtomicUsize>,
    // Cumulative count of new sequences added to the engine (never reset).
    total_new_seqs: Arc<AtomicUsize>,
}

impl IntervalLogger {
    /// Starts an interval logger on a detached background thread.
    ///
    /// The thread stays idle until [`IntervalLogger::enable_logging`] is
    /// called, then emits one throughput / prefix-cache-hitrate line every
    /// `interval`. Counters are updated via the `add_*` methods.
    pub fn new(interval: Duration) -> Self {
        let prefix_cache_hits = Arc::new(AtomicUsize::new(0));
        let tokens_processed = Arc::new(AtomicUsize::new(0));
        let total_new_seqs = Arc::new(AtomicUsize::new(0));
        let enable_logging = Arc::new(AtomicBool::new(false));

        let t_prefix_cache_hits = prefix_cache_hits.clone();
        let t_tokens_processed = tokens_processed.clone();
        let t_total_new_seqs = total_new_seqs.clone();
        let t_enable_logging = enable_logging.clone();
        thread::spawn(move || {
            // Wait until logging is enabled. Sleep between polls rather than
            // busy-spinning: a bare `while !flag {}` loop pins a CPU core at
            // 100% for as long as logging stays disabled (which may be the
            // whole process lifetime when throughput logging is off).
            while !t_enable_logging.load(Ordering::Relaxed) {
                thread::sleep(Duration::from_millis(50));
            }

            // Start the actual logging
            loop {
                thread::sleep(interval);

                let total_new_seqs = t_total_new_seqs.load(Ordering::Relaxed);
                let prefix_cache_hits = t_prefix_cache_hits.load(Ordering::Relaxed);
                // Swap to 0 so throughput is measured per interval; the two
                // counters above are cumulative (lifetime hitrate).
                let tokens_processed = t_tokens_processed.swap(0, Ordering::Relaxed);

                // Skip the line entirely when nothing happened this interval,
                // so an idle engine does not spam the log.
                if total_new_seqs != 0 && tokens_processed != 0 {
                    info!(
                        "Throughput (T/s) {:.2}, Prefix cache hitrate {:.2}%",
                        tokens_processed as f64 / interval.as_secs_f64(),
                        100. * prefix_cache_hits as f64 / total_new_seqs as f64,
                    );
                }
            }
        });

        Self {
            prefix_cache_hits,
            tokens_processed,
            total_new_seqs,
            enable_logging,
        }
    }

    /// Switches the background thread from idle polling to periodic logging.
    pub fn enable_logging(&self) {
        self.enable_logging.store(true, Ordering::Relaxed);
    }

    /// Adds `num_tokens` to the per-interval throughput counter.
    pub fn add_tokens_processed(&self, num_tokens: usize) {
        self.tokens_processed
            .fetch_add(num_tokens, Ordering::Relaxed);
    }

    /// Records that a new sequence entered the engine (hitrate denominator).
    pub fn add_new_sequence(&self) {
        self.total_new_seqs.fetch_add(1, Ordering::Relaxed);
    }

    /// Records a prefix-cache hit (hitrate numerator).
    pub fn add_prefix_cache_hit(&self) {
        self.prefix_cache_hits.fetch_add(1, Ordering::Relaxed);
    }
}
86 changes: 29 additions & 57 deletions mistralrs-core/src/engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use candle_core::Tensor;
use either::Either;
use interprocess::local_socket::{traits::Listener, ListenerOptions};
use llguidance::toktrie::TokEnv;
use logger::IntervalLogger;
use once_cell::sync::Lazy;
use std::{
collections::HashMap,
Expand All @@ -10,7 +11,7 @@ use std::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::{Instant, SystemTime, UNIX_EPOCH},
time::{Duration, Instant, SystemTime, UNIX_EPOCH},
};
use tokio::sync::{mpsc::Receiver, Mutex};

Expand Down Expand Up @@ -43,6 +44,8 @@ use crate::{
Constraint, StopTokens,
};

mod logger;

pub enum EngineInstruction {
Terminate,
}
Expand All @@ -66,6 +69,7 @@ pub struct Engine {
is_debug: bool,
disable_eos_stop: bool,
throughput_logging_enabled: bool,
logger: IntervalLogger,
}

impl Engine {
Expand Down Expand Up @@ -98,10 +102,15 @@ impl Engine {
is_debug: DEBUG.load(Ordering::Relaxed),
disable_eos_stop,
throughput_logging_enabled,
logger: IntervalLogger::new(Duration::from_secs(5)),
}
}

pub async fn run(&mut self) {
if self.throughput_logging_enabled {
self.logger.enable_logging();
}

let rng = Arc::new(std::sync::Mutex::new(Isaac64Rng::seed_from_u64(SEED)));
let mut last_completion_ids: Vec<usize> = vec![];
'lp: loop {
Expand Down Expand Up @@ -135,10 +144,7 @@ impl Engine {
SchedulerOutput::DefaultScheduler {
output: mut scheduled,
} => {
let mut prompt_ts = None;
let mut completion_ts = None;
if scheduled.completion.len() > 0 {
let throughput_start = Instant::now();
let current_completion_ids: Vec<usize> =
scheduled.completion.iter().map(|seq| *seq.id()).collect();
let res = {
Expand Down Expand Up @@ -201,16 +207,12 @@ impl Engine {
self.prefix_cacher
);

let throughput_end = Instant::now();
#[allow(clippy::cast_precision_loss)]
if self.throughput_logging_enabled {
completion_ts = Some(
scheduled.completion.len() as f64
/ throughput_end
.duration_since(throughput_start)
.as_secs_f64(),
);
}
let total_processed_tokens: usize = scheduled
.completion
.iter()
.map(|seq| seq.get_toks().len())
.sum();
self.logger.add_tokens_processed(total_processed_tokens);

last_completion_ids = current_completion_ids;
}
Expand Down Expand Up @@ -277,17 +279,12 @@ impl Engine {
self.prefix_cacher
);

#[allow(clippy::cast_precision_loss)]
if self.throughput_logging_enabled {
prompt_ts = Some(
scheduled
.prompt
.iter()
.map(|seq| seq.get_toks().len())
.sum::<usize>() as f64
/ prompt_exec_time.as_secs_f64(),
);
}
let total_processed_tokens: usize = scheduled
.prompt
.iter()
.map(|seq| seq.get_toks().len())
.sum();
self.logger.add_tokens_processed(total_processed_tokens);

for seq in scheduled.prompt.iter_mut() {
match seq.sequence_stepping_type() {
Expand Down Expand Up @@ -338,21 +335,6 @@ impl Engine {
}
}

if self.throughput_logging_enabled {
match (prompt_ts, completion_ts) {
(Some(prompt), Some(completion)) => {
info!("Throughput (scheduler V1): Prompt: {prompt} T/s Completion {completion} T/s");
}
(None, Some(completion)) => {
info!("Throughput (scheduler V1): Completion {completion} T/s");
}
(Some(prompt), None) => {
info!("Throughput (scheduler V1): Prompt: {prompt} T/s");
}
(None, None) => (),
}
}

if scheduled.prompt.len() == 0
&& scheduled.completion.len() == 0
&& self.scheduler.waiting_len() == 0
Expand All @@ -369,8 +351,6 @@ impl Engine {
}
SchedulerOutput::PagedAttention { mut output } => {
if !output.scheduled.is_empty() {
let throughput_start = Instant::now();

let is_prompt = get_mut_arcmutex!(output.scheduled[0]).is_prompt();

let mut guards = output
Expand Down Expand Up @@ -428,6 +408,10 @@ impl Engine {
self.prefix_cacher
);

let total_processed_tokens: usize =
guards.iter().map(|seq| seq.get_toks().len()).sum();
self.logger.add_tokens_processed(total_processed_tokens);

if self.is_debug {
let ms_from_last_run = run_start.elapsed().as_secs_f64();
let total_len = guards.len();
Expand All @@ -453,21 +437,6 @@ impl Engine {
}
}

let throughput_end = Instant::now();
#[allow(clippy::cast_precision_loss)]
if self.throughput_logging_enabled {
let n_toks = if is_prompt {
guards.iter().map(|seq| seq.get_toks().len()).sum::<usize>()
} else {
guards.len()
};
let ts = n_toks as f64
/ throughput_end
.duration_since(throughput_start)
.as_secs_f64();
info!("Throughput (scheduler V2): {ts} T/s");
}

if is_prompt {
for mut seq in guards {
let now = SystemTime::now()
Expand Down Expand Up @@ -965,7 +934,10 @@ impl Engine {
request.return_raw_logits,
eos_toks,
);
self.logger.add_new_sequence();
let seq = if let Some(prefill_cache) = prefill_cache.clone() {
self.logger.add_prefix_cache_hit();

seq.prefill_v2(
prefill_cache.normal,
prefill_cache.toks,
Expand Down
17 changes: 8 additions & 9 deletions mistralrs-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,15 @@ pub struct MistralRsBuilder {
no_prefix_cache: Option<bool>,
prefix_cache_n: Option<usize>,
disable_eos_stop: Option<bool>,
throughput_logging_enabled: Option<()>,
throughput_logging_enabled: bool,
}

impl MistralRsBuilder {
pub fn new(pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>, method: SchedulerConfig) -> Self {
pub fn new(
pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
method: SchedulerConfig,
throughput_logging: bool,
) -> Self {
Self {
pipeline,
method,
Expand All @@ -200,7 +204,7 @@ impl MistralRsBuilder {
no_prefix_cache: None,
prefix_cache_n: None,
disable_eos_stop: None,
throughput_logging_enabled: None,
throughput_logging_enabled: throughput_logging,
}
}
pub fn with_log(mut self, log: String) -> Self {
Expand Down Expand Up @@ -231,10 +235,6 @@ impl MistralRsBuilder {
self.disable_eos_stop = Some(disable_eos_stop);
self
}
pub fn with_throughput_logging(mut self) -> Self {
self.throughput_logging_enabled = Some(());
self
}

pub fn build(self) -> Arc<MistralRs> {
MistralRs::new(self)
Expand Down Expand Up @@ -272,7 +272,6 @@ impl MistralRs {
let no_prefix_cache = no_prefix_cache.unwrap_or(false);
let prefix_cache_n = prefix_cache_n.unwrap_or(16);
let disable_eos_stop = disable_eos_stop.unwrap_or(false);
let throughput_logging_enabled = throughput_logging_enabled.is_some();

let reboot_state = RebootState {
pipeline: pipeline.clone(),
Expand Down Expand Up @@ -443,7 +442,7 @@ impl MistralRs {
.duration_since(UNIX_EPOCH)
.expect("Time travel has occurred!")
.as_secs(),
next_request_id: Mutex::new(RefCell::new(0)),
next_request_id: Mutex::new(RefCell::new(1)),
reboot_state,
engine_handler: RwLock::new(engine_handler),
category,
Expand Down
2 changes: 1 addition & 1 deletion mistralrs-pyo3/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,7 @@ impl Runner {
),
}
};
let mistralrs = MistralRsBuilder::new(pipeline, scheduler_config)
let mistralrs = MistralRsBuilder::new(pipeline, scheduler_config, false)
.with_no_kv_cache(no_kv_cache)
.with_prefix_cache_n(prefix_cache_n)
.build();
Expand Down
14 changes: 4 additions & 10 deletions mistralrs-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -479,24 +479,18 @@ async fn main() -> Result<()> {
}
};
// Throughput logging in the server
let builder = MistralRsBuilder::new(pipeline, scheduler_config)
let mistralrs = MistralRsBuilder::new(pipeline, scheduler_config, !args.interactive_mode)
.with_opt_log(args.log)
.with_truncate_sequence(args.truncate_sequence)
.with_no_kv_cache(args.no_kv_cache)
.with_prefix_cache_n(args.prefix_cache_n);
.with_prefix_cache_n(args.prefix_cache_n)
.build();

if args.interactive_mode {
interactive_mode(builder.build(), args.throughput_log).await;
interactive_mode(mistralrs, args.throughput_log).await;
return Ok(());
}

let builder = if args.throughput_log {
builder.with_throughput_logging()
} else {
builder
};
let mistralrs = builder.build();

// Needs to be after the .build call as that is where the daemon waits.
let setting_server = if !args.interactive_mode {
let port = args.port.expect("Interactive mode was not specified, so expected port to be specified. Perhaps you forgot `-i` or `--port`?");
Expand Down
7 changes: 4 additions & 3 deletions mistralrs/src/anymoe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,10 @@ impl AnyMoeModelBuilder {
},
};

let mut runner = MistralRsBuilder::new(pipeline, scheduler_method)
.with_no_kv_cache(self.base.no_kv_cache)
.with_no_prefix_cache(self.base.prefix_cache_n.is_none());
let mut runner =
MistralRsBuilder::new(pipeline, scheduler_method, self.base.throughput_logging)
.with_no_kv_cache(self.base.no_kv_cache)
.with_no_prefix_cache(self.base.prefix_cache_n.is_none());

if let Some(n) = self.base.prefix_cache_n {
runner = runner.with_prefix_cache_n(n)
Expand Down
2 changes: 1 addition & 1 deletion mistralrs/src/diffusion_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl DiffusionModelBuilder {
method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
};

let runner = MistralRsBuilder::new(pipeline, scheduler_method);
let runner = MistralRsBuilder::new(pipeline, scheduler_method, false);

Ok(Model::new(runner.build()))
}
Expand Down
Loading
Loading