Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion mistralrs-core/src/sampler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ pub struct Sampler {
top_p: f64,
min_p: f64,
logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
/// Cached Gumbel noise tensor to avoid reallocating it.
gumbel_cache: Arc<Mutex<Option<Tensor>>>,
}

#[cfg_attr(feature = "pyo3_macros", pyclass)]
Expand Down Expand Up @@ -253,6 +255,7 @@ impl Sampler {
top_p,
min_p,
logits_processors,
gumbel_cache: Arc::new(Mutex::new(None)),
})
}

Expand Down Expand Up @@ -414,7 +417,26 @@ impl Sampler {
probs = mask_minp.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
}

let next_token = probs.argmax(D::Minus1)?.to_scalar::<u32>()?;
// Sample using the Gumbel-max trick fully on-device.
let log_probs = probs.log()?;
// Generate cached Gumbel noise (-log(-log(u))) once.
let gumbel = {
let mut guard = self.gumbel_cache.lock().unwrap();
if guard.is_none() {
let uniform = Tensor::rand(0f32, 1f32, log_probs.shape(), log_probs.device())?;
let noise = uniform
.clamp(1e-20, 1.0)?
.log()? // ln(u)
.neg()? // -ln(u)
.log()? // ln(-ln(u))
.neg()?; // -ln(-ln(u))
*guard = Some(noise);
}
guard.as_ref().unwrap().clone()
};

let gumbel_logits = (&log_probs + &gumbel)?;
let next_token = gumbel_logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
Comment on lines +420 to +439
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Critical issue: Cache lacks shape and device validation.

The Gumbel-max formulation itself is sound (-log(-log(u)) added to log-probabilities, then argmax), but the caching strategy has a critical flaw: there is no validation that the cached Gumbel noise tensor matches the current log_probs in shape or device. (NOTE: caching the noise also means the same perturbation is reused on every call, so repeated sampling from an identical distribution will always select the same token — confirm whether frozen noise is intentional or whether the noise should be regenerated per sample.)

Reusing a mismatched cached tensor will cause a runtime error (shape/device mismatch on the elementwise add) when:

  1. Different vocabulary sizes are used between calls
  2. Tensors are on different devices
  3. Batch sizes change

Apply this fix to validate cache compatibility:

 // Generate cached Gumbel noise (-log(-log(u))) once.
 let gumbel = {
     let mut guard = self.gumbel_cache.lock().unwrap();
-    if guard.is_none() {
+    let needs_regeneration = guard.as_ref().map_or(true, |cached| {
+        cached.shape() != log_probs.shape() || cached.device() != log_probs.device()
+    });
+    
+    if needs_regeneration {
         let uniform = Tensor::rand(0f32, 1f32, log_probs.shape(), log_probs.device())?;
         let noise = uniform
             .clamp(1e-20, 1.0)?
             .log()? // ln(u)
             .neg()? // -ln(u)
             .log()? // ln(-ln(u))
             .neg()?; // -ln(-ln(u))
         *guard = Some(noise);
     }
     guard.as_ref().unwrap().clone()
 };
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// Sample using the Gumbel-max trick fully on-device.
let log_probs = probs.log()?;
// Generate cached Gumbel noise (-log(-log(u))) once.
let gumbel = {
let mut guard = self.gumbel_cache.lock().unwrap();
if guard.is_none() {
let uniform = Tensor::rand(0f32, 1f32, log_probs.shape(), log_probs.device())?;
let noise = uniform
.clamp(1e-20, 1.0)?
.log()? // ln(u)
.neg()? // -ln(u)
.log()? // ln(-ln(u))
.neg()?; // -ln(-ln(u))
*guard = Some(noise);
}
guard.as_ref().unwrap().clone()
};
let gumbel_logits = (&log_probs + &gumbel)?;
let next_token = gumbel_logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
// Sample using the Gumbel-max trick fully on-device.
let log_probs = probs.log()?;
// Generate cached Gumbel noise (-log(-log(u))) once.
let gumbel = {
let mut guard = self.gumbel_cache.lock().unwrap();
let needs_regeneration = guard.as_ref().map_or(true, |cached| {
cached.shape() != log_probs.shape() || cached.device() != log_probs.device()
});
if needs_regeneration {
let uniform = Tensor::rand(0f32, 1f32, log_probs.shape(), log_probs.device())?;
let noise = uniform
.clamp(1e-20, 1.0)?
.log()? // ln(u)
.neg()? // -ln(u)
.log()? // ln(-ln(u))
.neg()?; // -ln(-ln(u))
*guard = Some(noise);
}
guard.as_ref().unwrap().clone()
};
let gumbel_logits = (&log_probs + &gumbel)?;
let next_token = gumbel_logits.argmax(D::Minus1)?.to_scalar::<u32>()?;
🤖 Prompt for AI Agents
In mistralrs-core/src/sampler.rs around lines 420 to 439, the cached Gumbel
noise tensor is used without validating that its shape and device match the
current log_probs tensor, which can cause runtime errors if vocabulary size,
device, or batch size changes. To fix this, add checks after locking the cache
to verify if the cached tensor exists and whether its shape and device match
log_probs; if not, regenerate and replace the cached noise tensor accordingly
before using it.


// Extract the top‑n log‑probs if the caller asked for them.
let (top_logprobs, logprob) = if return_logprobs {
Expand Down
Loading