Add nonzero layer (#402)
EricLBuehler committed Jun 8, 2024
1 parent 8b01a90 commit e4a4947
Showing 4 changed files with 159 additions and 81 deletions.
156 changes: 154 additions & 2 deletions mistralrs-core/src/layers.rs
@@ -2,7 +2,7 @@

use std::{
collections::HashMap,
ops::Mul,
ops::{BitAnd, Mul},
str::FromStr,
sync::{
atomic::{AtomicBool, Ordering},
@@ -12,9 +12,10 @@ use std::{

use candle_core::{
quantized::{gguf_file, QMatMul, QTensor},
DType, Device, IndexOp, Result, Shape, Tensor, D,
DType, Device, IndexOp, Result, Shape, Tensor, WithDType, D,
};
use candle_nn::{Linear, Module, VarBuilder};
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};

pub use crate::layers_masker::CausalMasker;
pub use crate::layers_utils::{flash_attn, repeat_kv};
@@ -544,6 +545,76 @@ impl Module for QLinear {
}
}

/// Functions equivalent to `torch.nonzero`.
pub struct Nonzero;

impl Nonzero {
/// Equivalent to: `torch.nonzero(lt & gt, as_tuple=False)`
///
/// This performs the operation on the CPU and as such triggers a device synchronization.
/// The output tensor is `DType::U32` and on the same device as the inputs.
/// The input tensors must be of the same data type and on the same device.
pub fn nonzero_and<T: WithDType>(&self, lt: &Tensor, gt: &Tensor) -> Result<Tensor>
where
for<'a> &'a T: BitAnd<T, Output = T>,
{
let dev = lt.device();
let lt = lt.to_vec2::<T>()?;
let gt = gt.to_vec2::<T>()?;
// lt & gt
let res = lt
.par_iter()
.zip(gt)
.enumerate()
.flat_map(|(i, (lt_row, gt_row))| {
lt_row
.par_iter()
.zip(gt_row)
.enumerate()
.filter_map(|(j, (lt, gt))| {
if (lt & gt) != T::zero() {
Some(vec![i as u32, j as u32])
} else {
None
}
})
.collect::<Vec<_>>()
})
.map(|x| Tensor::from_slice(&x, (x.len(),), dev))
.collect::<Result<Vec<_>>>()?;
Tensor::stack(&res, 0)
}

/// Equivalent to: `torch.nonzero(x, as_tuple=False)`
///
/// This performs the operation on the CPU and as such triggers a device synchronization.
/// The output tensor is `DType::U32` and on the same device as the input.
pub fn nonzero<T: WithDType>(&self, x: &Tensor) -> Result<Tensor> {
let dev = x.device();
let x = x.to_vec2::<T>()?;
// x != 0
let res = x
.par_iter()
.enumerate()
.flat_map(|(i, x_row)| {
x_row
.par_iter()
.enumerate()
.filter_map(|(j, x)| {
if *x != T::zero() {
Some(vec![i as u32, j as u32])
} else {
None
}
})
.collect::<Vec<_>>()
})
.map(|x| Tensor::from_slice(&x, (x.len(),), dev))
.collect::<Result<Vec<_>>>()?;
Tensor::stack(&res, 0)
}
}

mod tests {

#[test]
@@ -617,4 +688,85 @@
)
}
}

#[test]
fn nonzero_and() {
use crate::layers::Nonzero;
use candle_core::{Device, Tensor};

let input1 = Tensor::from_vec(
vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7],
(10,),
&Device::Cpu,
)
.unwrap();
let input2 = Tensor::from_vec(
vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7],
(10,),
&Device::Cpu,
)
.unwrap();
let input = Tensor::stack(&[input1, input2], 0).unwrap();

let lt = input.lt(0.0).unwrap();
let gt = input.gt(-10.0).unwrap();
let res = Nonzero
.nonzero_and::<u8>(&lt, &gt)
.unwrap()
.to_vec2::<u32>()
.unwrap();

assert_eq!(
res,
[
[0, 3],
[0, 4],
[0, 5],
[0, 6],
[1, 0],
[1, 3],
[1, 5],
[1, 6]
]
);
}

#[test]
fn nonzero() {
use crate::layers::Nonzero;
use candle_core::{Device, Tensor};

let input1 =
Tensor::from_vec(vec![1i64, 2, 0, -1, -1, 0, 0, 0, 5, 7], (10,), &Device::Cpu).unwrap();
let input2 = Tensor::from_vec(
vec![-1i64, 0, 3, -1, 1, -1, 0, 0, 0, 0],
(10,),
&Device::Cpu,
)
.unwrap();
let input = Tensor::stack(&[input1, input2], 0).unwrap();

let res = Nonzero
.nonzero::<i64>(&input)
.unwrap()
.to_vec2::<u32>()
.unwrap();

assert_eq!(
res,
[
[0, 0],
[0, 1],
[0, 3],
[0, 4],
[0, 8],
[0, 9],
[1, 0],
[1, 2],
[1, 3],
[1, 4],
[1, 5]
]
);
}
}
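
A minimal usage sketch of the new Nonzero helper, written like the in-crate tests above (so it imports crate::layers::Nonzero; from another crate the path may differ). The comparison masks are DType::U8, and the result is a DType::U32 tensor of [row, col] index pairs:

fn nonzero_usage_sketch() -> candle_core::Result<()> {
    use crate::layers::Nonzero; // same path the tests above use; adjust if calling from another crate
    use candle_core::{Device, Tensor};

    // A 2 x 3 tensor of signed values.
    let x = Tensor::from_vec(vec![1i64, -3, 0, -1, 5, -2], (2, 3), &Device::Cpu)?;

    // Coordinates where x < 0 AND x > -10, like torch.nonzero(lt & gt, as_tuple=False).
    let lt = x.lt(0.0)?;   // u8 mask
    let gt = x.gt(-10.0)?; // u8 mask
    let and_idx = Nonzero.nonzero_and::<u8>(&lt, &gt)?;
    assert_eq!(and_idx.to_vec2::<u32>()?, [[0u32, 1], [1, 0], [1, 2]]);

    // Coordinates of every nonzero element, like torch.nonzero(x, as_tuple=False).
    let idx = Nonzero.nonzero::<i64>(&x)?;
    assert_eq!(idx.to_vec2::<u32>()?, [[0u32, 0], [0, 1], [1, 0], [1, 1], [1, 2]]);
    Ok(())
}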
80 changes: 3 additions & 77 deletions mistralrs-core/src/vision_models/phi3.rs
@@ -8,14 +8,13 @@ use candle_core::{
};
use candle_nn::{linear_b, linear_no_bias, VarBuilder};
use either::Either;
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
use std::{any::Any, collections::HashMap, fmt::Debug, sync::Arc};

use crate::{
device_map::DeviceMapper,
layers::{
repeat_kv, CausalMasker, FusedBiasLinear, MatMul, PhiRopeConfig, PhiRotaryEmbedding,
RmsNorm, ScaledDotProductAttention,
repeat_kv, CausalMasker, FusedBiasLinear, MatMul, Nonzero, PhiRopeConfig,
PhiRotaryEmbedding, RmsNorm, ScaledDotProductAttention,
},
pipeline::{extract_logits, Cache, IsqModel, Phi3RopeScaling, VisionModel},
serde_default_fn,
@@ -352,35 +351,6 @@ impl DecoderLayer {

const MAX_INPUT_ID: f64 = 1e9;

/// torch.nonzero(lt & gt, as_tuple=False)
fn nonzero_between_as_tuple_false(lt: &Tensor, gt: &Tensor) -> Result<Tensor> {
let dev = lt.device();
let lt = lt.to_vec2::<u8>()?;
let gt = gt.to_vec2::<u8>()?;
// lt & gt
let res = lt
.par_iter()
.zip(gt)
.enumerate()
.flat_map(|(i, (lt_row, gt_row))| {
lt_row
.par_iter()
.zip(gt_row)
.enumerate()
.filter_map(|(j, (lt, gt))| {
if (lt & gt) != 0 {
Some(vec![i as u32, j as u32])
} else {
None
}
})
.collect::<Vec<_>>()
})
.map(|x| Tensor::from_slice(&x, (x.len(),), dev))
.collect::<Result<Vec<_>>>()?;
Tensor::stack(&res, 0)
}

#[derive(Debug)]
struct EmbeddingLayers(Vec<Box<dyn ModuleWithMetadata>>);

@@ -566,7 +536,7 @@ impl ImageEmbedding {
let input_ids_lt = input_ids.lt(0.0f64)?;
let input_ids_gt = input_ids.gt(-MAX_INPUT_ID)?;
// positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=False)
let positions = nonzero_between_as_tuple_false(&input_ids_lt, &input_ids_gt)?;
let positions = Nonzero.nonzero_and::<u8>(&input_ids_lt, &input_ids_gt)?;

let target_dev = self.layers.0[0].device();
let target_dtype = self.layers.0[0].dtype();
@@ -952,47 +922,3 @@ impl VisionModel for Model {
true
}
}

mod tests {

#[test]
fn nonzero() {
use super::nonzero_between_as_tuple_false;
use candle_core::{Device, Tensor};

let input1 = Tensor::from_vec(
vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7],
(10,),
&Device::Cpu,
)
.unwrap();
let input2 = Tensor::from_vec(
vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7],
(10,),
&Device::Cpu,
)
.unwrap();
let input = Tensor::stack(&[input1, input2], 0).unwrap();

let lt = input.lt(0.0).unwrap();
let gt = input.gt(-10.0).unwrap();
let res = nonzero_between_as_tuple_false(&lt, &gt)
.unwrap()
.to_vec2::<u32>()
.unwrap();

assert_eq!(
res,
[
[0, 3],
[0, 4],
[0, 5],
[0, 6],
[1, 0],
[1, 3],
[1, 5],
[1, 6]
]
);
}
}
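
To make the ImageEmbedding call site above concrete: negative input ids (down to -MAX_INPUT_ID) mark image placeholder slots, and Nonzero.nonzero_and recovers their [batch, position] coordinates, mirroring the torch.nonzero line in the comment. A small sketch with hypothetical token values (the ids and the import path are illustrative, not taken from the real tokenizer):

fn placeholder_positions_sketch() -> candle_core::Result<()> {
    use crate::layers::Nonzero; // assumed path, as in the layers.rs tests
    use candle_core::{Device, Tensor};

    // One prompt row: positive ids are text tokens, the negative ids mark image placeholder slots.
    let input_ids = Tensor::from_vec(vec![15i64, 22, -1, -1, 310], (1, 5), &Device::Cpu)?;
    let lt = input_ids.lt(0.0)?;  // input_ids < 0
    let gt = input_ids.gt(-1e9)?; // input_ids > -MAX_INPUT_ID
    let positions = Nonzero.nonzero_and::<u8>(&lt, &gt)?;
    // The two placeholder slots sit at columns 2 and 3 of batch row 0.
    assert_eq!(positions.to_vec2::<u32>()?, [[0u32, 2], [0, 3]]);
    Ok(())
}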
2 changes: 1 addition & 1 deletion mistralrs/examples/gguf_locally/main.rs
@@ -17,7 +17,7 @@ fn setup() -> anyhow::Result<Arc<MistralRs>> {
Some("chat_templates/mistral.json".to_string()),
None,
".".to_string(),
vec!["mistral-7b-instruct-v0.1.Q4_K_M.gguf".to_string()],
"mistral-7b-instruct-v0.1.Q4_K_M.gguf".to_string(),
)
.build();
// Load, into a Pipeline
2 changes: 1 addition & 1 deletion mistralrs/examples/quantized/main.rs
@@ -15,7 +15,7 @@ fn setup() -> anyhow::Result<Arc<MistralRs>> {
None,
Some("mistralai/Mistral-7B-Instruct-v0.1".to_string()),
"TheBloke/Mistral-7B-Instruct-v0.1-GGUF".to_string(),
vec!["mistral-7b-instruct-v0.1.Q4_K_M.gguf".to_string()],
"mistral-7b-instruct-v0.1.Q4_K_M.gguf".to_string(),
)
.build();
// Load, into a Pipeline
