Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ schemars = "0.8.22"
serde_yaml = "0.9.34"
serde_plain = "1.0.2"
as-any = "0.3.2"
llguidance = { version = "0.7.29", default-features = false, features = ["lark"] }
toktrie_hf_tokenizers = "0.7.29"
llguidance = { git = "https://github.com/guidance-ai/llguidance.git", version = "0.7.29", default-features = false, features = ["lark"], rev = "2ce5ab8" }
toktrie_hf_tokenizers = {git = "https://github.com/guidance-ai/llguidance.git", version = "0.7.29", rev = "2ce5ab8" }
objc = { version = "0.2.7" }
serde-big-array = "0.5.1"
interprocess = "2.2.3"
Expand Down
36 changes: 36 additions & 0 deletions mistralrs-core/src/layers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2127,6 +2127,42 @@ impl Mlp {
})
}

/// Builds an [`Mlp`] whose gate and up projections are loaded from a single
/// merged `gate_up_proj` weight, alongside a separate `down_proj`.
///
/// The merged column-parallel layer is created with an output dimension of
/// `intermediate_size * chunks` and split into `chunks` parts; part 0 becomes
/// the gate projection and part 1 the up projection. None of the three
/// projections is created with a bias (the bias flag passed is `false`).
///
/// # Panics
///
/// Panics if `chunks != 2`, since only the gate/up merge is supported.
///
/// # Errors
///
/// Propagates any error returned while constructing the merged
/// column-parallel layer or the row-parallel down projection.
pub fn new_merged(
    vb: ShardedVarBuilder,
    hidden_size: usize,
    intermediate_size: usize,
    chunks: usize,
    quantization_config: &Option<QuantizedConfig>,
    hidden_act: Activation,
    comm: &Arc<mistralrs_quant::Comm>,
) -> Result<Self> {
    assert!(chunks == 2, "Only gate_up_proj merge is supported!");

    // `chunks` is guaranteed to be 2 past this point, so it stands in for the
    // number of merged projections in both the size and the split count.
    let merged = ColumnParallelLayer::new_merged(
        hidden_size,
        intermediate_size * chunks,
        chunks,
        quantization_config,
        false,
        comm,
        vb.pp("gate_up_proj"),
    )?;

    // Order inside the merged weight: gate first, then up.
    let gate = merged[0].to_owned();
    let up = merged[1].to_owned();

    let down = RowParallelLayer::new(
        intermediate_size,
        hidden_size,
        quantization_config,
        false,
        comm,
        vb.pp("down_proj"),
    )?;

    Ok(Self {
        gate,
        up,
        down,
        act: hidden_act,
        params: vec![hidden_size, intermediate_size],
    })
}

pub fn replicate(
params: &[usize],
vb: ShardedVarBuilder,
Expand Down
Loading
Loading