Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "vllm-rs"
version = "0.8.10"
version = "0.8.11"
edition = "2021"
default-run = "vllm-rs"

Expand All @@ -15,7 +15,8 @@ itertools = "0.13.0"
akin = "0.4.0"
indicatif = "0.17.11"
serde_json = "1.0.108"
llguidance = "0.6"
llguidance = { version = "1.2.0", default-features = false, features = ["lark"] }
toktrie_hf_tokenizers = "1.2.0"
toktrie = "1.4"
half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] }
tokio = { version = "1.38.0", features = ["sync"] }
Expand Down
26 changes: 25 additions & 1 deletion docs/claude_code.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,22 @@ python3 -m vllm_rs.server --m miromind-ai/MiroThinker-v1.5-30B --d 0,1 --server

## 2) Configure Claude Code

Install claude code

```shell
npm install -g @anthropic-ai/claude-code
```

Export config

```shell
export ANTHROPIC_BASE_URL="http://127.0.0.1:8000"
export ANTHROPIC_AUTH_TOKEN="sk-dummy"
export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=1
```

Or make it permanent

Set `~/.claude/settings.json` (or copy from `example/claude/settings.json`):

```json
Expand All @@ -32,7 +48,15 @@ Set `~/.claude/settings.json` (or copy from `example/claude/settings.json`):
}
```

## 3) Verify with a direct request (optional)
## 3) Run Claude Code

run claude code

```shell
claude
```

or verify with a direct request (optional)

```bash
curl http://127.0.0.1:8000/v1/messages \
Expand Down
4 changes: 1 addition & 3 deletions docs/goose.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,17 @@ python3 -m vllm_rs.server --m Qwen/Qwen3-30B-A3B-Instruct-2507 --d 0,1 --server

## 2) Configure Goose

### Download and install Goose: https://block.github.io/goose/docs/getting-started/installation/

```shell
# For non-UI system,
export GOOSE_DISABLE_KEYRING=1
```

Export empty API KEY

```shell
export VLLM_API_KEY="empty"
```

### Download and install Goose: https://block.github.io/goose/docs/getting-started/installation/

### Configure goose with `Custom Providers` and API key `empty`

Expand Down
26 changes: 22 additions & 4 deletions src/core/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::transfer::PdRole;
use crate::transfer::Transfer;
use crate::utils::chat_template::Message;
use crate::utils::config::{EngineConfig, EosTokenId, ModelType, SamplingParams};
use crate::utils::guidance::load_toktrie_from_path;
use crate::utils::guidance::{build_llg_factory, load_toktrie_from_path};
use crate::utils::heartbeat::heartbeat_worker;
use crate::utils::image::{get_image_config, ImageData, ImageProcessConfig};
use crate::utils::kvcache_allocator::KVCacheAllocator;
Expand Down Expand Up @@ -106,9 +106,23 @@ impl LLMEngine {
pub fn new(econfig: &EngineConfig, dtype: DType) -> Result<Arc<RwLock<Self>>> {
let (model_pathes, is_gguf, mut config, config_tokenizer, tokenizer, mut generation_cfg) =
init_config_tokenizer(econfig)?;
let toktrie = load_toktrie_from_path(&model_pathes.get_tokenizer_filename()).map(Arc::new);
let toktrie = match load_toktrie_from_path(&model_pathes.get_tokenizer_filename()) {
Ok(trie) => Some(Arc::new(trie)),
Err(e) => {
crate::log_warn!("Failed to load tokenizer trie: {}", e);
None
}
};
let llg_factory = match build_llg_factory(tokenizer.clone(), config.vocab_size) {
Ok(f) => Some(f),
Err(e) => {
crate::log_warn!("Failed to build llguidance factory: {}", e);
None
}
};

if toktrie.is_none() {
crate::log_warn!("Guided decoding disabled: tokenizer trie unavailable.");
crate::log_warn!("Guided decoding (legacy) disabled: tokenizer trie unavailable.");
}

let stop_flag = Arc::new(AtomicBool::new(false));
Expand Down Expand Up @@ -196,7 +210,7 @@ impl LLMEngine {
device.clone(),
reporter,
transfer,
toktrie.clone(),
llg_factory.clone(),
None,
)?;

Expand Down Expand Up @@ -1507,4 +1521,8 @@ impl LLMEngine {
pub fn get_chat_template(&self) -> ChatTemplate {
self.template.clone()
}

pub fn template_supports_tools(&self) -> bool {
self.template.supports_tools()
}
}
151 changes: 147 additions & 4 deletions src/core/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::server::EmbeddingStrategy;
use crate::transfer::Transfer;
#[cfg(all(feature = "cuda", feature = "graph"))]
use crate::utils::graph::{CudaGraphFn, CudaGraphWrapper, GraphCapturer, ModelFn};
use crate::utils::guidance::GuidanceState;
use crate::utils::guidance::{GuidanceState, ParserFactory};
use crate::utils::image::compute_image_slice;
use crate::utils::logits_processor::{LogitsProcessor, Sampling};
use crate::utils::progress::ProgressLike;
Expand All @@ -28,10 +28,9 @@ use attention_rs::InputMetadata;
use candle_core::{DType, Device, Result, Tensor, D};
use interprocess::local_socket::Stream as LocalStream;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::collections::{hash_map::Entry, HashMap, HashSet};
use std::rc::Rc;
use std::sync::{Arc, Mutex, MutexGuard};
use toktrie::TokTrie;

/// Cached sampling parameters computed once during prefill, reused during decode
#[derive(Clone, Debug)]
Expand Down Expand Up @@ -82,6 +81,9 @@ pub struct ModelRunner {
cached_sampling: RwLock<Option<CachedSamplingParams>>,
seq_tokens: RwLock<HashMap<usize, Vec<u32>>>,
guidance_states: RwLock<HashMap<usize, GuidanceState>>,
guidance_failed: RwLock<HashSet<usize>>,
guidance_mismatch: RwLock<HashSet<usize>>,
llg_factory: Option<Arc<ParserFactory>>,
transfer: Option<Arc<Transfer>>,
/// Whether this runner is on the first rank (for logging)
is_first_rank: bool,
Expand All @@ -101,7 +103,7 @@ impl ModelRunner {
device: Device,
reporter: Arc<RwLock<Box<dyn ProgressLike>>>,
transfer: Option<Arc<Transfer>>,
toktrie: Option<Arc<TokTrie>>,
llg_factory: Option<Arc<ParserFactory>>,
stream: Option<LocalStream>,
) -> Result<Self> {
let model = crate::build_model!(
Expand Down Expand Up @@ -200,6 +202,30 @@ impl ModelRunner {
} else {
econfig.seed.unwrap()
};
let model_vocab_size = match &model {
Model::Qwen3(model) => model.get_vocab_size(),
Model::Qwen3MoE(model) => model.get_vocab_size(),
Model::LLaMa(model) => model.get_vocab_size(),
Model::Phi4(model) => model.get_vocab_size(),
Model::GLM4(model) => model.get_vocab_size(),
Model::GLM4MoE(model) => model.get_vocab_size(),
Model::Mistral3VL(model) => model.get_vocab_size(),
Model::Gemma3(model) => model.get_vocab_size(),
Model::Qwen3VL(model) => model.get_vocab_size(),
};

if let Some(factory) = &llg_factory {
let llg_vocab_size = factory.tok_env().tok_trie().vocab_size();
if llg_vocab_size != model_vocab_size {
crate::log_warn!(
"llguidance vocab size {} does not match model vocab size {} for {:?}.",
llg_vocab_size,
model_vocab_size,
model_type
);
}
}

Ok(Self {
model,
gpu_kv_cache: Arc::new(Mutex::new(gpu_kv_cache)),
Expand All @@ -218,6 +244,9 @@ impl ModelRunner {
cached_sampling: RwLock::new(None),
seq_tokens: RwLock::new(HashMap::new()),
guidance_states: RwLock::new(HashMap::new()),
guidance_failed: RwLock::new(HashSet::new()),
guidance_mismatch: RwLock::new(HashSet::new()),
llg_factory,
transfer,
is_first_rank: comm.rank() == 0,
model_type,
Expand Down Expand Up @@ -700,6 +729,104 @@ impl ModelRunner {
logits.to_owned()
};

let logits = if let Some(factory) = &self.llg_factory {
let mut guidance_states = self.guidance_states.write();
let mut guidance_failed = self.guidance_failed.write();
let mut guidance_mismatch = self.guidance_mismatch.write();
let mut modified = false;
let vocab_size = logits.dim(1)?;
// We only materialize logits on CPU if at least one constraint mask applies.

// We'll collect masks first to minimize holding locks or complex logic inside the loop
let mut masks = Vec::new(); // (seq_index, seq_id, mask)

for (i, id) in seq_ids.iter().enumerate() {
let seq_constraint = match &seqs {
Seqs::SeqRefs(refs) => &refs[i].sampling_params.constraint,
Seqs::DecodeVec(vec) => &vec[i].sampling_params.constraint,
};

if guidance_failed.contains(id) {
continue;
}

if let Some(constraint) = seq_constraint {
let state = match guidance_states.entry(*id) {
Entry::Occupied(entry) => entry.into_mut(),
Entry::Vacant(entry) => {
match GuidanceState::new(factory.clone(), constraint) {
Ok(state) => entry.insert(state),
Err(err) => {
guidance_failed.insert(*id);
crate::log_warn!(
"[Seq {}] Failed to create guidance state: {}. Disabling constraints for this sequence.",
id,
err
);
continue;
}
}
}
};

if let Ok(Some(mask)) = state.compute_mask() {
masks.push((i, *id, mask));
modified = true;
}
}
}

if modified {
// Now we must convert to Vec, modify, and update logits
let mut logits_vec = logits.flatten_all()?.to_vec1::<f32>()?;

for (seq_idx, seq_id, mask) in masks {
let start = seq_idx * vocab_size;
let end = start + vocab_size;
let row = &mut logits_vec[start..end];
let mask_len = mask.len();

// Apply mask: set disallowed to -inf
// This iterates entire vocab, but check is fast
if mask_len == 0 {
if guidance_failed.insert(seq_id) {
crate::log_warn!(
"[Seq {}] Guidance mask length is 0. Disabling constraints for this sequence.",
seq_id
);
}
continue;
}

if mask_len != vocab_size && guidance_mismatch.insert(seq_id) {
crate::log_warn!(
"[Seq {}] Guidance mask size {} does not match vocab size {}. Clamping mask application.",
seq_id,
mask_len,
vocab_size
);
}

let apply_len = std::cmp::min(vocab_size, mask_len);
for tok in 0..apply_len {
if !mask.is_allowed(tok as u32) {
row[tok] = f32::NEG_INFINITY;
}
}
if mask_len < vocab_size {
for tok in mask_len..vocab_size {
row[tok] = f32::NEG_INFINITY;
}
}
}
Tensor::from_vec(logits_vec, logits.shape(), &self.device)?
} else {
logits
}
} else {
logits
};

let tokens = self
.logit_processor
.sample_with_strategy(&logits, &cached_params.sampling)?;
Expand All @@ -718,6 +845,18 @@ impl ModelRunner {
}
}
}

// Commit tokens to guidance states
if let Some(_) = &self.llg_factory {
let mut guidance_states = self.guidance_states.write();
for (i, id) in seq_ids.iter().enumerate() {
if let Some(state) = guidance_states.get_mut(id) {
if !state.is_finished() {
let _ = state.commit_token(tokens[i]);
}
}
}
}
Ok(tokens)
}

Expand All @@ -726,6 +865,10 @@ impl ModelRunner {
let _ = seq_tokens.remove(&id);
let mut guidance_states = self.guidance_states.write();
let _ = guidance_states.remove(&id);
let mut guidance_failed = self.guidance_failed.write();
let _ = guidance_failed.remove(&id);
let mut guidance_mismatch = self.guidance_mismatch.write();
let _ = guidance_mismatch.remove(&id);
}

pub fn get_model_vocab_size(&self) -> usize {
Expand Down
Loading