
Conversation

@guoqingbao
Contributor

@guoqingbao guoqingbao commented Jun 6, 2025

This PR adds support for the latest GLM4 model (e.g., GLM-4-9B-0414).

Updates

The llguidance/toktrie_hf_tokenizers problem has been resolved by the llguidance maintainer. The GLM4 model was also revised to support a merged gate_up_proj, which is now working:

Tested cases

cargo run --features cuda -- -i plain -m /home/data/GLM-4-9B-0414/

ISQ:

cargo run --features cuda -- -i --isq Q4K plain -m /home/data/GLM-4-9B-0414/
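For programmatic use, a minimal Rust sketch of loading the model through the mistralrs crate is below. This is only an illustration: the builder and message types (TextModelBuilder, IsqType, TextMessages) follow the crate's existing text-model examples and are assumptions here, not something added by this PR; only the model path and the Q4K ISQ setting mirror the tested cases above.

use anyhow::Result;
use mistralrs::{IsqType, TextMessageRole, TextMessages, TextModelBuilder};

#[tokio::main]
async fn main() -> Result<()> {
    // Load the local GLM-4-9B-0414 weights, optionally applying in-situ
    // quantization (mirrors the `--isq Q4K` flag used in the tested cases).
    let model = TextModelBuilder::new("/home/data/GLM-4-9B-0414/")
        .with_isq(IsqType::Q4K)
        .with_logging()
        .build()
        .await?;

    let messages =
        TextMessages::new().add_message(TextMessageRole::User, "Hello, GLM4!");
    let response = model.send_chat_request(messages).await?;
    println!("{:?}", response.choices[0].message.content);
    Ok(())
}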

[Deleted: However, it is not yet functional because the project currently uses an external crate, toktrie_hf_tokenizers, which is not compatible with the GLM tokenizer. This crate is used to convert a standard tokenizer to the llguidance environment (via build_llg_factory).

The PR is marked as a draft and will be ready to merge once the compatibility issue with toktrie_hf_tokenizers is resolved. I will open an issue in the llguidance/toktrie_hf_tokenizers repository to track this problem.]

Summary by CodeRabbit

  • New Features

    • Added support for the GLM4 transformer-based language model with rotary embeddings, quantization, and mixture-of-experts (MoE) capabilities.
    • Introduced GLM4 model loader integrated into the normal loading pipeline for seamless model loading and inference.
    • Expanded architecture and loader options to include GLM4 across APIs and user interfaces.
    • Added merged gate and up projection layer construction for improved model efficiency.
  • Improvements

    • Improved error handling when loading optional generation configuration files.
    • Updated dependencies to reference specific Git revisions for better consistency.
    • Updated documentation to include GLM4 usage examples and support details.
    • Fixed chat template processing to better support GLM4 models by removing specific metadata placeholders.

@coderabbitai

coderabbitai bot commented Jun 6, 2025

Walkthrough

A new GLM4 transformer model implementation was added, including its configuration, rotary embeddings, attention, decoder layers, forward passes, quantization, and mixture-of-experts support. The model is integrated into the loader system with new loader structs, enums, and public exports. Additional changes include merged layer constructors, a trait lifetime signature update, dependency source updates in Cargo.toml, and a chat template fix for GLM4.

Changes

File(s) Change Summary
mistralrs-core/src/models/glm4.rs Added GLM4 model: config, rotary embeddings, attention, decoder, model struct, trait impls, MoE support.
mistralrs-core/src/models/mod.rs Declared new glm4 module.
mistralrs-core/src/pipeline/loaders/normal_loaders.rs Added GLM4Loader, extended enums and traits, integrated GLM4 into loader dispatch.
mistralrs-core/src/pipeline/loaders/mod.rs, mistralrs-core/src/pipeline/mod.rs Publicly re-exported GLM4Loader.
mistralrs-core/src/pipeline/normal.rs Supported GLM4Loader in loader builder; improved error handling for optional config deserialization.
mistralrs-core/src/pipeline/isq.rs Changed trait impl signature for CowBytesView to use anonymous lifetime.
mistralrs-pyo3/src/which.rs Added GLM4 variant to Architecture enum and conversion to NormalLoaderType.
mistralrs-core/src/layers.rs Added Mlp::new_merged constructor for merged gate/up projections.
mistralrs-quant/src/distributed/layers.rs Added ColumnParallelLayer::new_merged for multi-chunk column parallel layer creation.
Cargo.toml Updated llguidance and toktrie_hf_tokenizers dependencies to use a specific Git revision.
mistralrs-core/src/pipeline/chat_template.rs Fixed GLM4 chat template processing by removing meta Jinja statements and placeholders.
docs/GLM4.md Added documentation for GLM4 model usage with HTTP and Python APIs.
README.md Added GLM4 to supported architectures and quantization support tables.
mistralrs-pyo3/mistralrs.pyi Added GLM4 enum member to Architecture in Python bindings.
mistralrs-quant/README.md Renamed "Gptq" quantization layer to "Gptq/Awq" with updated implementation reference.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Loader
    participant GLM4Loader
    participant Model
    participant Device

    User->>Loader: Request GLM4 model load
    Loader->>GLM4Loader: Dispatch load request
    GLM4Loader->>Model: Deserialize config, construct Model
    Model->>Device: Initialize layers, rotary embeddings, cache
    Model-->>GLM4Loader: Return ready model
    Loader-->>User: Provide loaded GLM4 model

Poem

In the warren, code grows anew,
GLM4 hops in, with layers to view!
Rotary spins and MoE blends,
Quantized lines where logic bends.
With loaders and traits,
This model awaits—
A carrot for progress, as Rust code transcends! 🥕


@github-actions

github-actions bot commented Jun 6, 2025

Code Metrics Report
===============================================================================
 Language            Files        Lines         Code     Comments       Blanks
===============================================================================
 C Header                3           62           53            0            9
 CSS                     1          473          408           14           51
 Dockerfile              1           44           24           10           10
 HTML                    1           78           64            5            9
 JavaScript              7         1397         1068          180          149
 JSON                   18          282          279            0            3
 Makefile                1            6            5            0            1
 Python                 92         4530         3813          185          532
 Shell                   1           63           26           18           19
 Plain Text              3         3723            0         2413         1310
 TOML                   23          844          777           11           56
 YAML                    2           21           19            2            0
-------------------------------------------------------------------------------
 Jupyter Notebooks       3            0            0            0            0
 |- Markdown             2           77           32           31           14
 |- Python               2          205          178            1           26
 (Total)                            282          210           32           40
-------------------------------------------------------------------------------
 Markdown               68         6102            0         4600         1502
 |- BASH                14          203          190            7            6
 |- JSON                 9          436          436            0            0
 |- Python              12          391          326           24           41
 |- Rust                29         1123          945           30          148
 |- TOML                 2           75           63            0           12
 (Total)                           8330         1960         4661         1709
-------------------------------------------------------------------------------
 Rust                  392       138350       122808         3173        12369
 |- Markdown           185         3956          284         3154          518
 (Total)                         142306       123092         6327        12887
===============================================================================
 Total                 616       155975       129344        10611        16020
===============================================================================

@guoqingbao
Contributor Author

@EricLBuehler Do we need to make llguidance a compulsory dependency? The GLM4 tokenizer works normally in candle-vllm without relying on toktrie_hf_tokenizers.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

🧹 Nitpick comments (5)
mistralrs-core/Cargo.toml (2)

82-82: Gate toktrie_hf_tokenizers behind a feature flag
Since upstream GLM4 tokenization isn’t supported yet, consider making this dependency optional and gating it behind a Cargo feature to avoid blocking builds for other models until compatibility is resolved.


83-85: Track GLM4 tokenizer compatibility issue
Add a TODO to open and link an issue in the llguidance/toktrie_hf_tokenizers repo so you can track resolving the GLM4 tokenization blocker.

mistralrs-core/src/pipeline/loaders/mod.rs (1)

23-26: Expose GLM4Loader in public re-exports
Nice inclusion of the new loader. For consistency, consider alphabetizing this list so GLM4Loader sits alongside other G*Loaders. Also add an integration test to verify that AutoNormalLoader correctly discovers and instantiates the GLM4 loader.

mistralrs-pyo3/src/which.rs (1)

25-25: Include GLM4 in Python Architecture enum
Good extension of the Python API. Don’t forget to update any README or doc examples to illustrate how to choose Architecture.GLM4.

mistralrs-core/src/models/glm4.rs (1)

746-760: Track the unimplemented xlora_forward method.

The xlora_forward method is currently unimplemented. Since this PR is marked as draft, this is acceptable, but consider adding a TODO comment to track this for future implementation.

 fn xlora_forward(
     &self,
     _input_ids: &Tensor,
     _input_ids_full: &Tensor,
     _seqlen_offsets: &[usize],
     _seqlen_offsets_full: &[usize],
     _no_kv_cache: bool,
     _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
     _context_lens: Vec<(usize, usize)>,
     _position_ids: Vec<usize>,
     _flash_params: &FlashParams,
     _flash_params_full: &FlashParams,
 ) -> Result<Tensor> {
-    unimplemented!()
+    // TODO: Implement xlora_forward for GLM4 model
+    unimplemented!("xlora_forward is not yet implemented for GLM4")
 }

Would you like me to open an issue to track the implementation of xlora support for GLM4?

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cfd1e89 and 3870068.

⛔ Files ignored due to path filters (1)
  • Cargo.lock is excluded by !**/*.lock
📒 Files selected for processing (9)
  • mistralrs-core/Cargo.toml (1 hunks)
  • mistralrs-core/src/models/glm4.rs (1 hunks)
  • mistralrs-core/src/models/mod.rs (1 hunks)
  • mistralrs-core/src/pipeline/isq.rs (1 hunks)
  • mistralrs-core/src/pipeline/loaders/mod.rs (1 hunks)
  • mistralrs-core/src/pipeline/loaders/normal_loaders.rs (6 hunks)
  • mistralrs-core/src/pipeline/mod.rs (1 hunks)
  • mistralrs-core/src/pipeline/normal.rs (3 hunks)
  • mistralrs-pyo3/src/which.rs (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (8)
  • GitHub Check: Clippy
  • GitHub Check: Docs
  • GitHub Check: Check (ubuntu-latest, stable)
  • GitHub Check: Check (macOS-latest, stable)
  • GitHub Check: Check (windows-latest, stable)
  • GitHub Check: Test Suite (macOS-latest, stable)
  • GitHub Check: Test Suite (windows-latest, stable)
  • GitHub Check: Test Suite (ubuntu-latest, stable)
🔇 Additional comments (13)
mistralrs-core/Cargo.toml (1)

81-82: Verify version compatibility for llguidance and toktrie_hf_tokenizers
You’ve bumped both to 0.7.29—confirm that the llguidance features (default-features = false, features = ["lark"]) include the ParserFactory/toktrie support required by toktrie_hf_tokenizers, and that the new toktrie_hf_tokenizers version is compatible with your GLM4 tokenizer code.

mistralrs-core/src/pipeline/isq.rs (1)

32-32: LGTM! Lifetime simplification improves readability.

The change from explicit generic lifetime <'a> to anonymous lifetime '_ is a good stylistic improvement since the lifetime parameter isn't explicitly used in the implementation beyond what's required by the trait.
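For readers unfamiliar with the pattern, a generic illustration is below; the trait and body are hypothetical, and only the CowBytesView name comes from the change above.

struct CowBytesView<'a>(&'a [u8]);

trait ByteLen {
    fn byte_len(&self) -> usize;
}

// Before: `impl<'a> ByteLen for CowBytesView<'a> { ... }` declares a lifetime
// parameter that is never referenced elsewhere in the impl.
// After: the anonymous lifetime says the same thing with less noise.
impl ByteLen for CowBytesView<'_> {
    fn byte_len(&self) -> usize {
        self.0.len()
    }
}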

mistralrs-core/src/pipeline/normal.rs (2)

14-17: LGTM! GLM4Loader import follows existing patterns.

The addition of GLM4Loader to the imports is consistent with how other model loaders are imported and organized.


225-225: LGTM! GLM4 integration matches existing loader pattern.

The GLM4 case is properly integrated into the loader type matching, following the same pattern as other model loaders.

mistralrs-core/src/pipeline/loaders/normal_loaders.rs (6)

168-169: LGTM!

The GLM4 enum variant follows the established pattern with proper serde rename annotation.


191-191: LGTM!

The mapping from "Glm4ForCausalLM" to GLM4 variant is consistent with the naming pattern used by other models.


216-219: LGTM!

The FromStr implementation correctly maps "glm4" to the GLM4 variant and the error message is properly updated to include the new architecture.


240-240: LGTM!

The Display implementation correctly formats GLM4 as "glm4", maintaining consistency with the FromStr implementation.


291-291: LGTM!

The AutoNormalLoader correctly dispatches GLM4 loader type to a new GLM4Loader instance.


3269-3269: Architecture-specific layer norm calculation looks correct.

The calculation accounts for GLM4's unique architecture with 3 post-attention layer norms (standard + post_self_attn_layernorm + post_mlp_layernorm) as indicated by the comment. This differs from other models but appears intentional.

mistralrs-core/src/models/glm4.rs (3)

275-292: Verify paged attention functionality with GLM4.

The paged attention implementation looks correct, but given the PR's draft status due to tokenizer compatibility issues, ensure that the paged attention mechanism works correctly with GLM4's specific requirements once the tokenizer issues are resolved.


1-899: Well-structured GLM4 model implementation.

The implementation is comprehensive and follows established patterns from other models in the codebase. The code properly handles quantization, distributed computation, and MoE support. Given the PR's draft status due to external tokenizer compatibility issues, the implementation looks ready from a code perspective.


325-332:

❓ Verification inconclusive

Verify the GLM4 architecture requires 4 normalization layers.

The decoder layer implements 4 RmsNorm layers (input_layernorm, post_attention_layernorm, post_self_attn_layernorm, post_mlp_layernorm), which is unusual compared to typical transformer architectures that use 2 normalizations per layer. Please confirm this matches the GLM4 paper/reference implementation.

Also applies to: 374-384, 406-422


🌐 Web query:

What is the architecture of GLM4 transformer model and how many normalization layers does each decoder layer have?

💡 Result:

GLM-4 is a large-scale autoregressive language model developed by Tsinghua University's KEG Lab. It is based on the Transformer architecture, which consists of encoder and decoder layers. In the decoder layers of GLM-4, each layer typically includes two normalization steps:

  1. Layer Normalization after Self-Attention: This normalization is applied after the self-attention mechanism to stabilize training and improve convergence.

  2. Layer Normalization after Feed-Forward Network: Another normalization is applied after the feed-forward network within the decoder layer to further enhance training stability.

These normalization steps are part of the residual connections and layer normalization strategy commonly used in Transformer architectures to facilitate training and improve performance.

For a more detailed understanding of Transformer architectures and the role of normalization layers, you can refer to the Transformer (deep learning architecture) article on Wikipedia.


Verify GLM-4 DecoderLayer normalization count

The DecoderLayer currently defines four RMSNorm layers per layer, but the GLM-4 paper/reference implementation uses only two normalizations (one after self-attention and one after the feed-forward network). Please confirm whether the additional norms (input_layernorm and post_attention_layernorm/post_self_attn_layernorm) are intentional and documented in the GLM-4 architecture, or if they should be consolidated to match the reference.

Affected locations:

  • mistralrs-core/src/models/glm4.rs: lines 325–332
  • mistralrs-core/src/models/glm4.rs: lines 374–384
  • mistralrs-core/src/models/glm4.rs: lines 406–422
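For reference, the four norms are consistent with the upstream Glm4 (0414) layer, which wraps both the attention and MLP sub-blocks in a pre-norm plus extra post-norm "sandwich" before each residual add. The sketch below illustrates that pattern only; it is written against candle-style tensor ops with signatures assumed, not copied from glm4.rs.

fn forward(&self, xs: &Tensor) -> Result<Tensor> {
    // Attention sub-block: pre-norm, attention, extra post-norm, residual add.
    let residual = xs.clone();
    let h = self.input_layernorm.forward(xs)?;
    let h = self.self_attn.forward(&h)?; // masks, rope, and KV cache elided
    let h = self.post_self_attn_layernorm.forward(&h)?;
    let xs = (h + residual)?;

    // MLP sub-block: pre-norm, MLP, extra post-norm, residual add.
    let residual = xs.clone();
    let h = self.post_attention_layernorm.forward(&xs)?;
    let h = self.mlp.forward(&h)?;
    let h = self.post_mlp_layernorm.forward(&h)?;
    h + residual
}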

pub(crate) mod deepseek3;
pub(crate) mod gemma;
pub(crate) mod gemma2;
pub(crate) mod glm4;

💡 Verification agent

🧩 Analysis chain

Ensure glm4.rs is implemented and tested
You added pub(crate) mod glm4;—verify that src/models/glm4.rs exists, implements all required model traits (e.g., Loader, ModelConfigLike), and add unit tests covering instantiation and forward passes of the GLM4 model.


🏁 Script executed:

#!/bin/bash
set -e

echo "1. Checking for src/models/glm4.rs:"
fd glm4.rs src/models || echo "✗ src/models/glm4.rs not found"

echo
echo "2. Verifying trait implementations in glm4.rs:"
if [[ -f src/models/glm4.rs ]]; then
  rg -n "impl.*Loader" src/models/glm4.rs || echo "✗ No Loader impl found in glm4.rs"
  rg -n "impl.*ModelConfigLike" src/models/glm4.rs || echo "✗ No ModelConfigLike impl found in glm4.rs"
else
  echo "Skipping trait checks: glm4.rs is missing"
fi

echo
echo "3. Searching for unit tests covering GLM4:"
rg -R "glm4" -n tests || echo "✗ No tests found referencing glm4"

Length of output: 1123


Missing glm4 module implementation
You’ve declared pub(crate) mod glm4; in mistralrs-core/src/models/mod.rs, but the corresponding file and its tests are not present. Please:

  • Create mistralrs-core/src/models/glm4.rs
  • Implement the required traits (e.g., Loader, ModelConfigLike) in that file
  • Add unit tests (in mistralrs-core/tests or alongside the module) covering model instantiation and forward passes
🤖 Prompt for AI Agents
In mistralrs-core/src/models/mod.rs at line 5, you declared the glm4 module but
the corresponding file src/models/glm4.rs is missing along with its
implementations and tests. To fix this, create the file src/models/glm4.rs,
implement all required model traits such as Loader and ModelConfigLike within
it, and add unit tests either in mistralrs-core/tests or alongside the module to
cover model instantiation and forward passes.

Comment on lines +40 to +47
DiffusionModelLoader, FluxLoader, GLM4Loader, Gemma2Loader, Gemma3Loader, GemmaLoader,
Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, LlamaLoader, Loader,
LocalModelPaths, MiniCpmOLoader, Mistral3Loader, MistralLoader, MixtralLoader, ModelKind,
ModelPaths, NormalLoaderType, NormalLoadingMetadata, NormalModel, NormalModelLoader,
Phi2Loader, Phi3Loader, Phi3VLoader, Phi3_5MoELoader, Phi4MMLoader, PrettyName,
QuantizationKind, Qwen2Loader, Qwen2VLLoader, Qwen2_5VLLoader, Qwen3Loader, Qwen3MoELoader,
Starcoder2Loader, TokenSource, VLlama4Loader, VLlamaLoader, VisionLoaderType, VisionModel,
VisionModelLoader,

💡 Verification agent

🧩 Analysis chain

Add GLM4Loader to pipeline exports
You’ve made GLM4Loader part of the public API—ensure that NormalLoaderBuilder and the rest of the pipeline support the new GLM4 variant, and add end-to-end tests (e.g., loading a dummy GLM4 config) to catch integration issues early.


🏁 Script executed:

#!/bin/bash
# Search for NormalLoaderBuilder
rg -A5 "struct NormalLoaderBuilder" -n src
rg -A5 "impl.*NormalLoaderBuilder" -n src

# Search for ModelKind enum definition and variants
rg "enum ModelKind" -n src
rg -A10 "ModelKind" -n src

# Search for GLM4 references in loader builder or pipeline code
rg "GLM4" -n src

Length of output: 381


🏁 Script executed:

#!/bin/bash
# Search for NormalLoaderBuilder across the repository
rg -A5 "NormalLoaderBuilder" -n .

# Search for ModelKind enum definition and its variants
rg -A5 "enum ModelKind" -n .

# Search for GLM4 references in codebase
rg -n "GLM4" -n .

Length of output: 14739


Verify GLM4Loader Integration and Add Tests
GLM4Loader is already publicly exported and wired into the normal‐model pipeline, but we still need to cover it with an end-to-end test to catch any integration gaps early.

• In mistralrs-core/src/pipeline/normal.rs, GLM4Loader is imported alongside the other normal loaders.
• In mistralrs-core/src/pipeline/loaders/normal_loaders.rs, the NormalLoaderType::GLM4 variant maps to GLM4Loader.

Please add a simple integration test—e.g. load a minimal “glm4” config via the pipeline builder and run a no-op inference—to validate the public-API path.

🤖 Prompt for AI Agents
In mistralrs-core/src/pipeline/mod.rs around lines 40 to 47, GLM4Loader is
exported but lacks an end-to-end integration test. To fix this, add a new
integration test in the appropriate test module (likely in src/pipeline/tests or
similar) that uses the pipeline builder to load a minimal GLM4 model
configuration and performs a no-op inference. This will ensure the GLM4Loader is
fully integrated and functional through the public API.

Comment on lines +654 to 662
let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().and_then(|f| {
match serde_json::from_str::<GenerationConfig>(&fs::read_to_string(f).unwrap()) {
Ok(conf) => Some(conf),
Err(e) => {
warn!("Failed to parse generation_config.json: {}", e);
None
}
}
});

🛠️ Refactor suggestion

Improve error handling by addressing the remaining panic.

The change to handle JSON parsing errors gracefully is good, but there's still a potential panic on file reading that should be addressed.

Apply this diff to handle file reading errors as well:

 let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().and_then(|f| {
-    match serde_json::from_str::<GenerationConfig>(&fs::read_to_string(f).unwrap()) {
+    match fs::read_to_string(f).and_then(|content| serde_json::from_str::<GenerationConfig>(&content).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))) {
         Ok(conf) => Some(conf),
         Err(e) => {
-            warn!("Failed to parse generation_config.json: {}", e);
+            warn!("Failed to read or parse generation_config.json: {}", e);
             None
         }
     }
 });
🤖 Prompt for AI Agents
In mistralrs-core/src/pipeline/normal.rs around lines 654 to 662, the code
currently unwraps the result of reading the file, which can cause a panic if the
file is missing or unreadable. To fix this, replace the unwrap with proper error
handling by using a method like and_then or map to handle the Result from
fs::read_to_string, logging a warning and returning None if reading the file
fails, similar to how JSON parsing errors are handled. This will prevent panics
and improve robustness.

Comment on lines +3170 to +3181
fn load_xlora(
&self,
_config: &str,
_vb: ShardedVarBuilder,
_lora_config: &[((String, String), LoraConfig)],
_xlora_config: Option<XLoraConfig>,
_xlora_ordering: Ordering,
_normal_loading_metadata: NormalLoadingMetadata,
_preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
) -> Result<Box<dyn NormalModel + Send + Sync>> {
todo!()
}

⚠️ Potential issue

Implement or explicitly handle XLora loading.

The load_xlora method contains a todo!() which will cause a panic at runtime if XLora is attempted with GLM4 models. This should either be properly implemented or return a more descriptive error.

Consider returning a proper error instead:

-    fn load_xlora(
-        &self,
-        _config: &str,
-        _vb: ShardedVarBuilder,
-        _lora_config: &[((String, String), LoraConfig)],
-        _xlora_config: Option<XLoraConfig>,
-        _xlora_ordering: Ordering,
-        _normal_loading_metadata: NormalLoadingMetadata,
-        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
-    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
-        todo!()
+    fn load_xlora(
+        &self,
+        _config: &str,
+        _vb: ShardedVarBuilder,
+        _lora_config: &[((String, String), LoraConfig)],
+        _xlora_config: Option<XLoraConfig>,
+        _xlora_ordering: Ordering,
+        _normal_loading_metadata: NormalLoadingMetadata,
+        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
+    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
+        anyhow::bail!("XLora is not yet supported for GLM4 models")
🤖 Prompt for AI Agents
In mistralrs-core/src/pipeline/loaders/normal_loaders.rs around lines 3170 to
3181, the load_xlora function currently uses todo!() which will panic if called.
Replace todo!() with a proper error return that clearly indicates XLora loading
is not supported or implemented for GLM4 models, providing a descriptive error
message to avoid runtime panics.

@guoqingbao guoqingbao marked this pull request as ready for review June 9, 2025 07:46
@guoqingbao
Contributor Author

@EricLBuehler GLM4 support is now working.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
mistralrs-quant/src/distributed/layers.rs (1)

342-369: Add parameter validation for the chunks parameter.

The implementation looks correct for creating multiple column parallel layers with extended sharding. However, consider adding validation to ensure chunks > 0 to prevent returning an empty vector unintentionally.

 pub fn new_merged(
     in_dim: usize,
     out_dim: usize,
     chunks: usize,
     config: &Option<QuantizedConfig>,
     bias: bool,
     comm: &Arc<crate::Comm>,
     vb: ShardedVarBuilder,
 ) -> Result<Vec<Arc<dyn QuantMethod>>> {
+    if chunks == 0 {
+        candle_core::bail!("chunks must be greater than 0");
+    }
     let mut vec_layers = Vec::<Arc<dyn QuantMethod>>::new();
     for chunk_idx in 0..chunks {
mistralrs-core/src/layers.rs (1)

2130-2164: Consider improving parameter usage and flexibility.

The implementation looks correct for GLM4's merged gate_up_proj requirements. However, consider these improvements:

  1. The chunks parameter is only used in the assertion - consider removing it if it's always 2, or make the assertion more flexible.
  2. The hardcoded assertion chunks == 2 limits extensibility for other merged projection patterns.

Alternative implementation:

pub fn new_merged(
    vb: ShardedVarBuilder,
    hidden_size: usize,
    intermediate_size: usize,
-   chunks: usize,
    quantization_config: &Option<QuantizedConfig>,
    hidden_act: Activation,
    comm: &Arc<mistralrs_quant::Comm>,
) -> Result<Self> {
-   assert!(chunks == 2, "Only gate_up_proj merge is supported!");
    let gate_up_projs = ColumnParallelLayer::new_merged(
        hidden_size,
        intermediate_size * 2,
        2,
        quantization_config,
        false,
        comm,
        vb.pp("gate_up_proj"),
    )?;

Or maintain flexibility:

-   assert!(chunks == 2, "Only gate_up_proj merge is supported!");
+   if chunks != 2 {
+       candle_core::bail!("Currently only gate_up_proj merge (chunks=2) is supported, got chunks={}", chunks);
+   }
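For context, the idea behind these new_merged constructors is to load the single merged gate_up_proj tensor and split it into per-projection layers. A rough candle-level illustration of that splitting step is below (names are illustrative; the real code goes through ShardedVarBuilder / QuantMethod and handles sharding and quantization):

use candle_core::{Result, Tensor};
use candle_nn::Linear;

// Split a merged gate_up_proj weight of shape [2 * intermediate, hidden]
// into separate gate and up Linear layers (bias-free, as in the PR's MLP).
fn split_gate_up(merged: &Tensor, intermediate: usize) -> Result<(Linear, Linear)> {
    let gate_w = merged.narrow(0, 0, intermediate)?;
    let up_w = merged.narrow(0, intermediate, intermediate)?;
    Ok((Linear::new(gate_w, None), Linear::new(up_w, None)))
}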
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3870068 and 1845d91.

⛔ Files ignored due to path filters (1)
  • Cargo.lock is excluded by !**/*.lock
📒 Files selected for processing (11)
  • Cargo.toml (1 hunks)
  • mistralrs-core/src/layers.rs (1 hunks)
  • mistralrs-core/src/models/glm4.rs (1 hunks)
  • mistralrs-core/src/models/mod.rs (1 hunks)
  • mistralrs-core/src/pipeline/isq.rs (1 hunks)
  • mistralrs-core/src/pipeline/loaders/mod.rs (1 hunks)
  • mistralrs-core/src/pipeline/loaders/normal_loaders.rs (6 hunks)
  • mistralrs-core/src/pipeline/mod.rs (1 hunks)
  • mistralrs-core/src/pipeline/normal.rs (3 hunks)
  • mistralrs-pyo3/src/which.rs (2 hunks)
  • mistralrs-quant/src/distributed/layers.rs (1 hunks)
✅ Files skipped from review due to trivial changes (2)
  • mistralrs-core/src/pipeline/loaders/mod.rs
  • mistralrs-core/src/pipeline/mod.rs
🚧 Files skipped from review as they are similar to previous changes (6)
  • mistralrs-core/src/models/mod.rs
  • mistralrs-pyo3/src/which.rs
  • mistralrs-core/src/pipeline/isq.rs
  • mistralrs-core/src/pipeline/normal.rs
  • mistralrs-core/src/pipeline/loaders/normal_loaders.rs
  • mistralrs-core/src/models/glm4.rs
⏰ Context from checks skipped due to timeout of 90000ms (1)
  • GitHub Check: Test Suite (windows-latest, stable)
🔇 Additional comments (1)
Cargo.toml (1)

128-129: Verify the necessity of git dependencies and consider long-term implications.

The switch to git dependencies resolves the compatibility issues mentioned in the PR objectives. However, consider the following:

  1. Security: Git dependencies bypass crates.io's security scanning - ensure the repository is trustworthy.
  2. Maintenance: This creates a dependency on external git repository availability.
  3. Updates: Fixed revisions prevent automatic security updates.

Verification questions:

  • Is this intended as a temporary fix until the changes are published to crates.io?
  • Have the changes in revision 2ce5ab8 been reviewed for security?
  • Will there be a plan to migrate back to crates.io versions?

Consider adding a comment in the code explaining the rationale:

+# Using git dependencies temporarily until llguidance compatibility fixes are published to crates.io
llguidance = { git = "https://github.com/guidance-ai/llguidance.git", version = "0.7.29", default-features = false, features = ["lark"], rev = "2ce5ab8" }
toktrie_hf_tokenizers = {git = "https://github.com/guidance-ai/llguidance.git", version = "0.7.29", rev = "2ce5ab8" }

@EricLBuehler EricLBuehler (Owner) left a comment

Hi @guoqingbao! Thanks for the model, looks great. Just a small thing:

  • Can you please update README.md (toplevel) support matrices to reflect this?
  • Can you please update API.md for the Python API to similarly reflect this?
  • Can you please update mistralrs.pyi with the new model architecture?

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1845d91 and 66d596e.

📒 Files selected for processing (4)
  • README.md (5 hunks)
  • docs/GLM4.md (1 hunks)
  • mistralrs-pyo3/API.md (1 hunks)
  • mistralrs-quant/README.md (1 hunks)
✅ Files skipped from review due to trivial changes (3)
  • mistralrs-pyo3/API.md
  • mistralrs-quant/README.md
  • README.md
⏰ Context from checks skipped due to timeout of 90000ms (8)
  • GitHub Check: Clippy
  • GitHub Check: Docs
  • GitHub Check: Check (macOS-latest, stable)
  • GitHub Check: Check (windows-latest, stable)
  • GitHub Check: Check (ubuntu-latest, stable)
  • GitHub Check: Test Suite (windows-latest, stable)
  • GitHub Check: Test Suite (macOS-latest, stable)
  • GitHub Check: Test Suite (ubuntu-latest, stable)
🔇 Additional comments (4)
docs/GLM4.md (4)

1-4: Documentation: Title and model link
The header and the Hugging Face model collection link are clear and correctly formatted.


5-6: Description clarity
The introduction succinctly describes the GLM4 series and its support by mistral.rs.


7-8: Section header
The "HTTP API" header is properly placed.


34-35: Section header
The "Python API" header and its code fence are correctly added.

@guoqingbao
Contributor Author

Hi @guoqingbao! Thanks for the model, looks great. Just a small thing:

  • Can you please update README.md (toplevel) support matrices to reflect this?
  • Can you please update API.md for the Python API to similarly reflect this?
  • Can you please update mistralrs.pyi with the new model architecture?

Yes, addressed.

@guoqingbao guoqingbao requested a review from EricLBuehler June 11, 2025 03:40
@ghost

ghost commented Jun 12, 2025

I tested the PR (on my Mac) and when I made a request, it responded:

{"message":"unknown method (in chat_template:9)"}

Compile with:

git clone https://github.com/EricLBuehler/mistral.rs
cd mistral.rs
git fetch origin pull/1437/head:pr-1437
git merge pr-1437
cargo build --release --features metal
cd target/release

command:

./mistralrs-server --paged-attn --pa-gpu-mem 8192 --pa-blk-size 32  --port 2000 plain -m /mypathto/GLM-4-9B-0414

curl:

curl --request POST   --url http://myip:2000/v1/chat/completions    --header 'Content-Type: application/json'   --data '{
  "model": "mymodelname",
  "messages": [
    {
      "role": "user",
      "content": "mycontent"
    }
  ]
}'

And I just think it's similar to EricLBuehler/candle-vllm#158

By the way, I also tried mlx and gguf, and found that they are not supported yet.
mlx says:
mistralrs_core::engine: prompt step - Model failed with error: Msg("w inner dims ([6144, 5760]) must match x inner dims ([1, 1, 46080]). transpose=true")
but it didn't quit, and when I make a request with curl as above, the response is the same as above.

gguf says:
Unknown GGUF architecture glm4

@guoqingbao
Contributor Author

I tested the PR (on my Mac) and when I made a request, it responded:

{"message":"unknown method (in chat_template:9)"}

[…]

gguf says: Unknown GGUF architecture glm4

We haven't supported GGUF for GLM4 in Mistral.rs yet. Are you able to revise the tokenizer_config.json a little bit? Some Chinese characters in this file cannot be decoded by either Candle-vLLM or Mistral.rs:

Original one:

"chat_template": "[gMASK]<sop>{%- if tools -%}<|system|>\n# 可用工具\n{% for tool in tools %}{%- set function = tool.function if tool.get(\"function\") else tool %}\n\n## {{ function.name }}\n\n{{ function | tojson(indent=4, ensure_ascii=False) }}\n在调用上述函数时,请使用 Json 格式表示调用的参数。{%- endfor %}{%- endif -%}{%- for msg in messages %}{%- if msg.role == 'system' %}<|system|>\n{{ msg.content }}{%- endif %}{%- endfor %}{%- for message in messages if message.role != 'system' %}{%- set role = message['role'] %}{%- set content = message['content'] %}{%- set meta = message.get(\"metadata\", \"\") %}{%- if role == 'user' %}<|user|>\n{{ content }}{%- elif role == 'assistant' and not meta %}<|assistant|>\n{{ content }}{%- elif role == 'assistant' and meta %}<|assistant|>{{ meta }} \n{{ content }}{%- elif role == 'observation' %}<|observation|>\n{{ content }}{%- endif %}{%- endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",

Revised one (with the contents between {%- if tools -%} and {%- endif -%} removed):

  "chat_template": "[gMASK]<sop>{%- for msg in messages %}{%- if msg.role == 'system' %}<|system|>\n{{ msg.content }}{%- endif %}{%- endfor %}{%- for message in messages if message.role != 'system' %}{%- set role = message['role'] %}{%- set content = message['content'] %}{%- if role == 'user' %}<|user|>\n{{ content }}{%- elif role == 'assistant' %}<|assistant|>\n{{ content }}{%- elif role == 'assistant' %}<|assistant|>{{ content }}{%- elif role == 'observation' %}<|observation|>\n{{ content }}{%- endif %}{%- endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",

@guoqingbao
Contributor Author

guoqingbao commented Jun 12, 2025

It's the chat template issue: current candle-vllm and Mistral.rs are unable to handle the chat template of the latest GLM4-0414. If we remove the unwanted function/characters from GLM4's chat template, we get the correct answer:

     Running `target/release/candle-vllm --log --multi-process --dtype bf16 --port 2000 --device-ids 0 --weight-file /home/data/GLM-4-9B-0414-Q4_K_M.gguf glm4 --quant gguf --temperature 0. --penalty 1.0`
2025-06-12 17:36:58 WARN candle_vllm subprocess rank 0 started!
2025-06-12 17:36:58 WARN candle_vllm::openai::communicator command channel is built!
2025-06-12 17:36:59 WARN candle_vllm::openai::communicator command channel is built!
2025-06-12 17:37:03 WARN candle_vllm::backend::progress all ranks finished model loading!
2025-06-12 17:37:03 WARN candle_vllm::openai::pipelines::pipeline Done loading.
2025-06-12 17:37:04 WARN candle_vllm::openai::pipelines::pipeline stop_token_ids [151336]
-> Batch 1: 32768
-> Batch 8: 13104
2025-06-12 17:37:04 WARN candle_vllm Server started at http://0.0.0.0:2000.
TEMPLATE [gMASK]<sop>{%- for msg in messages %}{%- if msg.role == 'system' %}<|system|>\n{{ msg.content }}{%- endif %}{%- endfor %}{%- for message in messages if message.role != 'system' %}{%- set role = message['role'] %}{%- set content = message['content'] %}{%- if role == 'user' %}<|user|>\n{{ content }}{%- elif role == 'assistant' and not meta %}<|assistant|>\n{{ content }}{%- elif role == 'assistant' and meta %}<|assistant|> \n{{ content }}{%- elif role == 'observation' %}<|observation|>\n{{ content }}{%- endif %}{%- endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}
2025-06-12 17:37:08 WARN candle_vllm::openai::pipelines::llm_engine New Request with length 6 (cmpl-608c889f-023d-4380-a5ee-fa5f5fc35295).
2025-06-12 17:37:08 WARN candle_vllm::openai::pipelines::llm_engine Sending 1 tasks to 0 subprocesses
2025-06-12 17:37:09 WARN candle_vllm::openai::pipelines::llm_engine Decoding 35 tokens finished in 0 seconds (cmpl-608c889f-023d-4380-a5ee-fa5f5fc35295)
2025-06-12 17:37:09 WARN candle_vllm::openai::pipelines::llm_engine Sending finish message to subprocesses

 [1 requests] Prefilling: 6 prompt tokens processed (avg tps 6.45 tokens/s, throughput 6.45 tokens/s)

 [1 requests] Decoding: 35 tokens processed (avg tps 80.28 tokens/s, throughput 80.28 tokens/s)
🙋 Please Input (Ctrl+C to start a new chat or exit): hi
Candle-vLLM: ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

Hello 👋! I'm ChatGLM(智谱清言), the artificial intelligence assistant, nice to meet you. Feel free to ask me any questions.

For safetensors models, just replace the chat template in tokenizer_config.json with the following:

  "chat_template": "[gMASK]<sop>{%- for msg in messages %}{%- if msg.role == 'system' %}<|system|>\n{{ msg.content }}{%- endif %}{%- endfor %}{%- for message in messages if message.role != 'system' %}{%- set role = message['role'] %}{%- set content = message['content'] %}{%- if role == 'user' %}<|user|>\n{{ content }}{%- elif role == 'assistant' %}<|assistant|>\n{{ content }}{%- elif role == 'assistant' %}<|assistant|>{{ content }}{%- elif role == 'observation' %}<|observation|>\n{{ content }}{%- endif %}{%- endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",

It seems it's a known issue for GLM4: ggml-org/llama.cpp#13099
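To sanity-check a revised template locally before editing tokenizer_config.json, a small standalone test against the minijinja crate (which mistral.rs uses for chat templates) can help; the template below is shortened and the message shape is only an illustrative assumption:

use minijinja::{context, Environment};

fn main() -> Result<(), minijinja::Error> {
    // Shortened GLM4-style template without the `message.get("metadata", ...)`
    // call that triggers the "unknown method" error.
    let template = "[gMASK]<sop>{%- for message in messages %}{%- if message['role'] == 'user' %}<|user|>\n{{ message['content'] }}{%- endif %}{%- endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}";

    let mut env = Environment::new();
    env.add_template("chat_template", template)?;
    let rendered = env.get_template("chat_template")?.render(context! {
        messages => vec![context! { role => "user", content => "hi" }],
        add_generation_prompt => true,
    })?;
    println!("{rendered}");
    Ok(())
}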

@ghost

ghost commented Jun 12, 2025

Ok I tested the given chat_template and it worked.

But what about "{% for tool in tools %}"?

Your given chat_template is missing that. So no tool calling support then?

@guoqingbao
Contributor Author

Ok I tested the given chat_template and it worked.

But what about "{% for tool in tools %}"?

Your given chat_template is missing that. So no tool calling support then?

Tool calling can be preserved; refer to this fix: EricLBuehler/candle-vllm#178

Removing only the contents from "{%- set meta = message.get("metadata", "") %}" to "{{ meta }}" also works.

@ghost

ghost commented Jun 12, 2025

Only remove contents from "{%- set meta = message.get("metadata", "") %}" to "{{

OK, I confirmed. This worked.

@guoqingbao
Contributor Author

Only remove contents from "{%- set meta = message.get("metadata", "") %}" to "{{

OK, I confirmed. This worked.

I’ve pushed an updated fix for that—please feel free to use it. I’ll provide another PR later for GLM4 GGUF model support in Mistral.rs.

@EricLBuehler This should work now. Sorry about that—I forgot I had used a fixed tokenizer_config.json in the initial PR submission.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
mistralrs-core/src/pipeline/chat_template.rs (1)

295-296: Slice-syntax replacement is overly broad

Using a blind template.replace("[::-1]", "|reverse") will also mutate any literal string that happens to contain the same substring (e.g. inside a quoted example). Consider a regex that targets the slice syntax only when it appears inside a Jinja expression:

-use std::collections::HashMap;
+use std::collections::HashMap;
+use regex::Regex;        // new

 ...

-let mut template = template.replace("[::-1]", "|reverse");
+let re = Regex::new(r"\[\s*::-1\s*\]").unwrap();
+let mut template = re.replace_all(&template, "|reverse").into_owned();

This keeps the transformation precise and avoids surprising template rewrites.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c520285 and ed974bf.

📒 Files selected for processing (1)
  • mistralrs-core/src/pipeline/chat_template.rs (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (3)
  • GitHub Check: Test Suite (ubuntu-latest, stable)
  • GitHub Check: Test Suite (windows-latest, stable)
  • GitHub Check: Test Suite (macOS-latest, stable)

Comment on lines +297 to +301
if template.contains("{{ meta }}") {
//fix for GLM4 models
template = template.replace("{%- set meta = message.get(\"metadata\", \"\") %}", "");
template = template.replace("{{ meta }}", "");
}

💡 Verification agent

🧩 Analysis chain

metadata hot-fix is model-agnostic and can corrupt other templates

The current logic strips every {%- set meta = ... %} and {{ meta }} whenever either token is present.
Templates for non-GLM4 models that legitimately use a meta variable will silently lose information.

At minimum, guard the workaround behind a GLM4-specific check (e.g. an explicit is_glm4 flag or by inspecting the tokenizer’s model_type).
A cleaner alternative is to keep the template intact and instead make meta safely optional:

-env.add_template("chat_template", &template)?;
+env.set_undefined_behavior(minijinja::UndefinedBehavior::ChainableNoLogging);
+env.add_template("chat_template", &template)?;

With ChainableNoLogging, undefined variables evaluate to the empty string, eliminating the need for brittle string replacements and keeping other model templates untouched.


🏁 Script executed:

#!/bin/bash
# Locate all usages of add_template in the codebase
rg --color=never -n "add_template" || true

# Display the context around lines 280–320 in chat_template.rs
sed -n '260,320p' mistralrs-core/src/pipeline/chat_template.rs

Length of output: 2706


Guard meta hot-fix behind GLM4 check or enable chainable undefined behavior

  • The unconditional replacement in mistralrs-core/src/pipeline/chat_template.rs (around lines 297–301) strips any meta usage from all models.
  • This will break templates that legitimately use meta in non-GLM4 contexts.

Proposed fixes:

  1. GLM4-only guard
    Wrap the replacements in a GLM4 check (e.g. if is_glm4_model { … }) so other models remain unaffected.
  2. Safer alternative
    Use Minijinja’s ChainableNoLogging undefined behavior to render missing variables as empty strings:
 env.add_function("strftime_now", strftime_now);
-// fix for GLM4 models
-if template.contains("{{ meta }}") {
-    template = template.replace("{%- set meta = message.get(\"metadata\", \"\") %}", "");
-    template = template.replace("{{ meta }}", "");
-}
-env.add_template("chat_template", &template)?;
+// make undefined variables return empty string
+env.set_undefined_behavior(minijinja::UndefinedBehavior::ChainableNoLogging);
+env.add_template("chat_template", &template)?;

Either approach prevents unintended corruption of non-GLM4 templates.

🤖 Prompt for AI Agents
In mistralrs-core/src/pipeline/chat_template.rs around lines 297 to 301, the
current code unconditionally removes `{%- set meta = ... %}` and `{{ meta }}`
from all templates, which breaks templates that legitimately use `meta` in
non-GLM4 models. To fix this, wrap the replacement logic inside a conditional
check that ensures it only runs for GLM4 models (e.g., `if is_glm4_model { ...
}`). Alternatively, configure the template rendering to use Minijinja’s
`ChainableNoLogging` undefined behavior so that missing variables like `meta`
safely evaluate to empty strings without modifying the template text. This
prevents unintended template corruption for other models.

@ghost

ghost commented Jun 12, 2025

But what about MLX?

I mean, the

mistralrs_core::engine: prompt step - Model failed with error: Msg("w inner dims ([6144, 5760]) must match x inner dims ([1, 1, 46080]). transpose=true")

issue, I tested it again with the updated chat_template and it still does not work

@guoqingbao
Contributor Author

But What about MLX?

I mean, the

mistralrs_core::engine: prompt step - Model failed with error: Msg("w inner dims ([6144, 5760]) must match x inner dims ([1, 1, 46080]). transpose=true")

issue, I tested it again with the updated chat_template and it still does not work

Have you used the correct command:

cargo run --features cuda -- -i plain -m /home/data/GLM-4-9B-0414/

This works on my side. You mean this does not work on Metal (apple silicon)?

@guoqingbao
Contributor Author

But What about MLX?

I mean, the

mistralrs_core::engine: prompt step - Model failed with error: Msg("w inner dims ([6144, 5760]) must match x inner dims ([1, 1, 46080]). transpose=true")

issue, I tested it again with the updated chat_template and it still does not work

Just want to confirm: did this work on your side?

@ghost

ghost commented Jun 13, 2025

This works on my side.

mlx is mlx. mlx can't run on cuda. mlx only works on metal. "mlx works on cuda" makes no sense

@guoqingbao
Contributor Author

guoqingbao commented Jun 13, 2025 via email

@ghost

ghost commented Jun 13, 2025

I remember that the original model worked on my mac, but the mlx model didn't.

Oh, wait, I was running mlx-q8. So is it because it's 8bit?

Well, I just lost interest in this. Maybe I'll test it later at some point.

@guoqingbao
Contributor Author

guoqingbao commented Jun 13, 2025 via email

@ghost

ghost commented Jun 13, 2025

It's the same model for both platforms (CUDA and Metal); we haven't had a GLM4 model until this one. You can simply run the command line I provided, with the cuda feature on an NVIDIA GPU and the metal feature on Apple Silicon. The GLM4 GGUF model is not yet supported in Mistral.rs, so loading a GGUF model as a plain model doesn't work.

What? I thought you need to convert the original model to mlx using mlx_lm.convert; see

https://github.com/ml-explore/mlx-lm

@ghost

ghost commented Jun 13, 2025

metal is metal, mlx is mlx, two different things, man.

@guoqingbao
Contributor Author

guoqingbao commented Jun 13, 2025 via email

@ghost

ghost commented Jun 13, 2025

They're not just released by the mlx community, man. If I converted it myself, I'd be using my own private model. A lot of other people also release mlx models.

@ghost

ghost commented Jun 13, 2025

mistral.rs advertises that it can run mlx models, and I have run other mlx models successfully using mistral.rs.

Man, you need to update your knowledge.

@ghost

ghost commented Jun 13, 2025

I remember I was able to run the mlx model Qwen3-8B-MLX-8bit with mistral.rs successfully.

@guoqingbao
Contributor Author

guoqingbao commented Jun 13, 2025 via email

@guoqingbao
Contributor Author

guoqingbao commented Jun 13, 2025 via email

@ghost

ghost commented Jun 13, 2025

OMG. I won't talk with you any more. You make no sense and have gone too far. The MLX model format and the MLX framework are also two different things. I didn't say that I need the MLX framework to run an MLX model, right?

@guoqingbao
Contributor Author

guoqingbao commented Jun 13, 2025 via email

@sempervictus
Contributor

@guoqingbao - this is distinct from the ChatGLMModel used by THUDM/codegeex4-all-9b, I presume? If it is supposed to run that model then... it does not succeed (usual error - Error: Unsupported Hugging Face Transformers -CausalLM model class 'ChatGLMModel'. Please raise an issue)

@guoqingbao
Contributor Author

@guoqingbao - this is distinct from the ChatGLMModel used by THUDM/codegeex4-all-9b i presume? If it is supposed to run that model then... it does not succeed (usual error - Error: Unsupported Hugging Face Transformers -CausalLM model class 'ChatGLMModel'. Please raise an issue)

the latest GLM4 model (e.g., GLM-4-9B-0414)

As the PR says, it supports the latest GLM4 model (i.e., the 0414 version), which has a different architecture compared to ChatGLM. ZhipuAI (THUDM) is expected to abandon the ChatGLM structure; the GLM4 architecture will be their backbone going forward.

@sempervictus
Contributor

@guoqingbao - this is distinct from the ChatGLMModel used by THUDM/codegeex4-all-9b i presume? If it is supposed to run that model then... it does not succeed (usual error - Error: Unsupported Hugging Face Transformers -CausalLM model class 'ChatGLMModel'. Please raise an issue)

the latest GLM4 model (e.g., GLM-4-9B-0414)

As the PR said, it supports the latest GLM4 model (i.e., 0414 version), which has different arch compared to ChatGLM. ZhipuAI (THUDM) is expected to abandon the ChatGLM structure, the GLM4 arch will be their backbone structure.

Thanks for the insight - how different is the old ChatGLM architecture, and how big a lift would it be to implement? Their codegeex4 model is very handy for simple things like TF/packer work, and I'm trying to get all the model workloads into mistral.rs at this point :-)

@guoqingbao
Contributor Author

Thanks for the insight - how different is the old ChatGLM architecture/how big a lift to implement? Their codegeex4 model is very handy for simple things like TF/packer work and i'm trying to get all the model workloads into mistral.rs at this point :-)

We can indeed support the old models (ChatGLM and earlier versions of GLM4) in Mistral.rs, perhaps by opening another PR :). ChatGLM and pre-0414 GLM4 share the same architecture, which differs significantly from popular models like LLaMA, Qwen, and Mistral. ZhipuAI (THUDM) is gradually shifting from an encoder-decoder architecture to a decoder-only one. It's likely that the next-gen CodeGeeX4 model will also adopt the GLM4-0414 architecture, though that hasn’t been confirmed yet.

@sempervictus
Contributor

We can indeed support the old models (ChatGLM and earlier versions of GLM4) in Mistral.rs, perhaps by openning another PR :). ChatGLM and pre-0414 GLM4 share the same architecture, which differs significantly from popular models like LLaMA, Qwen, and Mistral. ZhipuAI (THUDM) is gradually shifting from an encoder-decoder architecture to a decoder-only one. It's likely that the next-gen CodeGeeX4 model will also adopt the GLM4-0414 architecture, though that hasn’t been confirmed yet.

Thank you for the clarification. I would absolutely love to dive into this codebase and piggyback off your PR to implement the architecture loader - its not for lack of comprehension capacity or interest in contributing back but a matter of time: principal architect for a top-shelf consultancy running design, implementation, go-live, and lifecycle for the (probably literally evil) metal backing this whole AI thing (hence the occasional snippets from inside B200 stacks and the like in my commentary). Fun job, doesn't leave any time to do anything else though. If you're similarly strapped for bandwidth i'm happy to put the request into our internal backlog and see if any of my team have the cycles.

Copy link
Owner

@EricLBuehler EricLBuehler left a comment

Thank you, looks great! Tested it and it works well.

@EricLBuehler EricLBuehler merged commit 210061f into EricLBuehler:master Jun 19, 2025
12 checks passed
EricLBuehler pushed a commit that referenced this pull request Jun 20, 2025
* Support GLM4 model

* Mention GLM4 model in ReadMe

* glm4 type hint

* Typo fix

* Fix unsupported chat_template function

* Clippy fix
@sempervictus
Copy link
Contributor

@EricLBuehler @guoqingbao - Seems we need to be able to specify tokenizer path configurations for this:

Could not get file "tokenizer.json" from API: RequestError(Status(404, Response[status: 404, status_text: Not Found, url: https://huggingface.co/THUDM/LongCite-glm4-9b/resolve/main/tokenizer.json]))

The file is tokenizer_config.json from what I can tell.

@sempervictus
Copy link
Contributor

🤦 "other help section" - positionally dependent --help is a good way to keep us on our toes ;-p

@sempervictus
Copy link
Contributor

That said: the model crashes out w/

Error: Unsupported Hugging Face Transformers -CausalLM model class `ChatGLMModel`. Please raise an issue.

Stack backtrace:
   0: anyhow::error::<impl anyhow::Error>::msg
   1: anyhow::__private::format_err.28178
   2: mistralrs_core::pipeline::loaders::normal_loaders::NormalLoaderType::from_causal_lm_name
   3: mistralrs_core::pipeline::loaders::normal_loaders::AutoNormalLoader::get_loader
   4: <mistralrs_core::pipeline::loaders::normal_loaders::AutoNormalLoader as mistralrs_core::pipeline::loaders::normal_loaders::NormalModelLoader>::supports_paged_attention
   5: <mistralrs_core::pipeline::normal::NormalLoader as mistralrs_core::pipeline::loaders::Loader>::load_model_from_path
   6: <mistralrs_core::pipeline::normal::NormalLoader as mistralrs_core::pipeline::loaders::Loader>::load_model_from_hf
   7: mistralrs_server::main::{{closure}}.61526
   8: mistralrs_server::main
   9: std::sys::backtrace::__rust_begin_short_backtrace
  10: main
  11: <unknown>
  12: __libc_start_main
  13: _start

Will rebuild the Docker images off latest master and retry this weekend.
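For reference, a minimal sketch of the dispatch that produces this error, assuming the auto loader keys off the `architectures` entry in the model's config.json. The names and signature below are illustrative only, not the actual mistral.rs code (the real mapping lives in `NormalLoaderType::from_causal_lm_name`):

```rust
// Illustrative sketch: map a Hugging Face `architectures` name to a loader tag.
// Anything not in the table falls through to the error shown in the backtrace.
fn loader_for_causal_lm(name: &str) -> anyhow::Result<&'static str> {
    match name {
        // GLM-4-0414 checkpoints report this class, which this PR wires up.
        "Glm4ForCausalLM" => Ok("glm4"),
        "LlamaForCausalLM" => Ok("llama"),
        "Qwen3ForCausalLM" => Ok("qwen3"),
        // codegeex4-all-9b and LongCite-glm4-9b report `ChatGLMModel`, which has
        // no loader yet, hence the failure above.
        other => anyhow::bail!(
            "Unsupported Hugging Face Transformers -CausalLM model class `{other}`. Please raise an issue."
        ),
    }
}
```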

@guoqingbao
Copy link
Contributor Author

@EricLBuehler @guoqingbao - Seems we need to be able to specify tokenizer path configurations for this:

Could not get file "tokenizer.json" from API: RequestError(Status(404, Response[status: 404, status_text: Not Found, url: https://huggingface.co/THUDM/LongCite-glm4-9b/resolve/main/tokenizer.json]))

The file is tokenizer_config.json from what I can tell.

The THUDM/LongCite-glm4-9b model is a ChatGLM (older GLM4), and they haven’t included a tokenizer.json. Instead, they provided a tokenizer.model, which is not compatible with Mistral.rs.
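As a quick way to check whether a repo ships a tokenizer that mistral.rs can use, here is a minimal sketch with the Hugging Face `tokenizers` crate (paths are placeholders): `Tokenizer::from_file` only parses the fast-tokenizer JSON serialization, so a SentencePiece `tokenizer.model` cannot be loaded this way.

```rust
use tokenizers::Tokenizer;

fn main() {
    // Loads: the fast-tokenizer JSON serialization shipped with GLM-4-9B-0414.
    match Tokenizer::from_file("GLM-4-9B-0414/tokenizer.json") {
        Ok(tok) => println!("vocab size: {}", tok.get_vocab_size(true)),
        Err(e) => eprintln!("failed to load tokenizer.json: {e}"),
    }

    // Does not load: `tokenizer.model` is a SentencePiece protobuf, not the
    // JSON format this crate parses, so this returns an error.
    assert!(Tokenizer::from_file("LongCite-glm4-9b/tokenizer.model").is_err());
}
```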

EricLBuehler added a commit that referenced this pull request Jun 23, 2025
* Add most of paged attn kv quant

* It builds a bit

* All the functionality at least

* Small fix

* Add a scale

* Fix bf16 usage

* Make k_v_scale optional

* Collector

* Tweak collection

* Refactor

* Add to apis

* Add cuda impl

* Fix compilation

* Fixes

* Handle ENABLE_FP8

* Format

* Tweak

* Fix scaled_convert usage

* Fix cache_t size

* Fixed scale collection

* Actual fix

* Fix fp8 for CC<8

* Fix the usual String != &str bit (#1483)

Co-authored-by: RageLtMan <rageltman [at] sempervictus>

* chore: `Dockerfile` - Drop runtime rayon thread ENV (#1465)

* chore: Dockerfile - Remove rayon threads env

* chore: Dockerfile - Improve formatting for `apt-get`

* Remove duplicate calls for api_dir_list (#1474)

* Remove duplicate calls for api_dir_list

* Support local cache for api_dir_list

* Fix home folder for metal

* Capitalized

* Fix transient pyo3 dep (#1478)

Co-authored-by: Eric Buehler <[email protected]>

* Fix objc dep with non macos (#1480)

* Fix phi 3/4 + nccl issue (#1481)

* Fix log

* Fix n kv heads

* Fix phi3.5 moe (#1482)

* Fix phi3.5 moe accum device

* Fix again

* Fix again

* Support GLM4 model! (#1437)

* Support GLM4 model

* Mention GLM4 model in ReadMe

* glm4 type hint

* Typo fix

* Fix unsupported chat_template function

* Clippy fix

* Refactor distributed backend (#1484)

* Refactor distributed backend, check power of 2

* Fix compilation

* Cap metal paged attn kv allocation (#1485)

* Better paged attn metal cap (#1486)

* Better paged attn metal cap

* Small fix

* Comment

* Small fix

* Refactor

* Server core: consolidate and unify route handlers and API surface (#1423)

* Start working on consolidating completion and chat_completion underlying implementations

* Move response channel to util mod for now (since it's used with streaming and non streaming)

* More work on consolidating completions and chat completions

* More WIP consolidation of server core handlers

* More WIP consolidation of server core handlers

* More WIP consolidation of server core handlers

* Update docs and restrict completion core visibility

* CodeRabbit feedback: remove logprobs warn from route handler since parse request also checks this

* Use consistent var name for completions mod

* Make route handler modules public API consistent (same fn names, etc.) and provide proxy fn that wrap core fns so core mod doesn't have to be pub
Make lib.rs example compile checked and update example

* Code formatting

* Typo

* Sync fork

* Sync fork

* Docs example fix

* Support qwen3 gguf (#1488)

* Add qwen3 gguf

* Template fixup

* Make bos/eos token IDs optional (#1493)

* Remove python deps from CUDA dockerfiles (#1487)

* Handle USE_FP8 for cuda

* Fix cuda warn

* Add readme

* Saturating sub in sequence state

---------

Co-authored-by: Eric Buehler <[email protected]>
Co-authored-by: RageLtMan <[email protected]>
Co-authored-by: Brennan Kinney <[email protected]>
Co-authored-by: Guoqing Bao <[email protected]>
Co-authored-by: Matthew Haynes <[email protected]>
Jeadie added a commit to spiceai/mistral.rs that referenced this pull request Jul 14, 2025
* Fix handling of Metal fused attn head dims (EricLBuehler#1234)

* Fix handling of metal attn head dims

* Fix handling of gemma3 1b when images

* Tweak default for paged attn builder

* Support paged attn for vision model rust api (EricLBuehler#1235)

* [Breaking] Support setting HF cache path (EricLBuehler#1237)

* Add it internally

* Add the apis

* Support tool calling for DeepSeek models (EricLBuehler#1239)

* Support tool calling for deepseek models

* Format

* Fix deepseek

* Server image processing refactor and fixes (EricLBuehler#1244)

* Fix strict gemma3 case

* Accept multiple images in the content array

* Fix multiple images in one array ct

* Add it to the python api

* Typos

* Optimized CUDA RoPE kernels (EricLBuehler#1247)

* Add the kernels

* It works

* Works

* Buulds

* Typo fix (add_speial_tokens to add_special_tokens) (EricLBuehler#1246)

* Fix typo

* Update mistralrs.pyi

* Fixes for UQFF + distributed layers (EricLBuehler#1250)

* Fixes for uqff + distributed layers

* Typo

* Automatic agentic search integration (`web_search_options`) (EricLBuehler#1243)

* Add the tool

* Actually search

* Clippy

* Sort of works

* Remove some debuggers

* tweak

* Add some rules

* Works great

* Tweak 'system' prompt

* Update mistralrs-core/src/search/mod.rs

Co-authored-by: Copilot <[email protected]>

* Typo

* Add it to all the apis

* Add bert model for similarity reranking

* Typos

* Early detection of tools

* Alias max_tokens -> max_completion_tokens too

* Customizable bert model

* Flip the enabler around

* Add docs

* Update readme

* Typo

---------

Co-authored-by: Copilot <[email protected]>

* Format kernels (EricLBuehler#1251)

* Update readme

* Update readme

* Remove test

* Add quantize guards for uqff deserialize (EricLBuehler#1252)

* Refactor cuBLASlt-related code (EricLBuehler#1253)

* Centralize cublaslt into mistralrs-quant

* Use cublaslt in unquant layer

* Use beautiful trait constants for simpler code

* Move tests

* Dispatch to unquant for cublaslt

* Dispatch to unquant for cublaslt

* Fix feature

* Add convert_to_gptq script

* Update deps, bump pyo3 version (EricLBuehler#1259)

* Faster cuda FP8 performance (EricLBuehler#1257)

* Avoid fp8 sync

* Fix dtype

* Rust 1.86 clippy (EricLBuehler#1260)

* Rust 1.86 clippy

* Clippy

* Refactor engine arch (EricLBuehler#1262)

* Refactor engine add_request

* Don't recompile regex

* Clippy

* Revamped LoRA support - removing the Ordering system! (EricLBuehler#1263)

* Play with varbuilder lifetimes

* Merge lora weights

* Clippy

* Lora works

* Support multiple loras

* Cleanup, remove adapter activation

* Complete merge

* Fast Metal-specific quantization method: AFQ (EricLBuehler#1264)

* Add mlx quantized kernels

* Add mlx quantized kernels

* Kernel launcher

* Add AFQ isq quant and dequant

* Some quantmethod things

* Begin to implement the qmm caller

* Clippy

* Much faster

* Cache kernels

* Docs

* Clippy

* Add it to uqff

* Support prequantized models from MLX (EricLBuehler#1265)

* Refactor quantizedconfig

* Support AFQ prequantized

* Update docs

* Update docs

* Automatic ISQ to select fastest & most accurate method (EricLBuehler#1266)

* Automatic isq

* typo

* Doc

* Improved usage metrics (EricLBuehler#1267)

* Fix cuda

* Bump tokio from 1.44.1 to 1.44.2 (EricLBuehler#1270)

Bumps [tokio](https://github.com/tokio-rs/tokio) from 1.44.1 to 1.44.2.
- [Release notes](https://github.com/tokio-rs/tokio/releases)
- [Commits](tokio-rs/tokio@tokio-1.44.1...tokio-1.44.2)

---
updated-dependencies:
- dependency-name: tokio
  dependency-version: 1.44.2
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Gather MM ops in mistralrs-quant (EricLBuehler#1272)

* Update the caller

* Wire things up

* Broadcase for afq gathermm

* Broadcase for afq gathermm

* Clippy

* Improve performance of deepseek models

* Typo fix

* BincountOp not used

* Implement Llama 4! (EricLBuehler#1268)

* Implement Llama 4

* Implement the main changes for the text model

* Make chunked mask

* Wire things up

* Add some EP

* Initial sketch of inputs processor

* Runs

* Progress

* all reduce moes

* It works!

* Some cleanup

* Faster moe block

* Add device map

* Make chunked matrix

* Fully working now!

* Reactivate cublaslt

* Fix shared mlp cublaslt

* Refactor to packed experts

* Complete merge

* It is a normal model now

* Fixes

* Set device for moe

* ISQ fixes

* Much faster sort kernel

* Faster loading!

* Faster loading!

* Fp8 cpu copy ops in candle backend

* Add the vision model

* Add mmproj layer

* Actually merge the inputs

* Sketch most of the image processor

* Add the rest of the image processor

* Implement the whole processor

* Add the loader

* Some fixes

* A batch of fixes

* Some fixes

* tmp

* Actually support isq

* Ok it works a bit

* Fix norm device

* It works

* A bit cleaner

* Support residul tensors

* Remove text loader

* Implement the device mapping system

* Fix auto device map

* Add examples

* Add model card

* Typo

* Remove superflous logging

* Fixes for Llama 4 UQFF loading (EricLBuehler#1275)

* Support sharding for UQFF (EricLBuehler#1276)

* Serialize sharded uqff files

* Loading

* Fix base64

* Fix bug for group-topk (group_limited_greedy) in deepseek models (EricLBuehler#1278)

* Support the DeepCoder model (EricLBuehler#1279)

* Add faq for metal not found

* Improved PagedAttn scheduling accuracy (EricLBuehler#1282)

* Scheduler ops by reference

* Ensure scheduler gets correct prompts

* Fix cuda build for copy_blocks

* Fixes for scheduling image seqs with pagedattn (EricLBuehler#1283)

* update to llguidance 0.7.16 (EricLBuehler#1284)

* update llguidance to 0.7.16 from crates.io; use ParserFactory

* add lark_llg.py example

* use new llguidance::Matcher APIs

* rework spec-decoding with llg

* more work on spec sampling

* check for parser stop

* fix clippy

* remove unneeded rollback

* update build_llg_factory to return Result

* Update dependencies (EricLBuehler#1286)

* Much faster image inputs processing (EricLBuehler#1289)

* Add more SDPA head dims for much faster SigLIP (EricLBuehler#1290)

* More sdpa head dims, faster vision models

* Move nonzero to above for faster metal synch

* Doc

* Update valid head dims

* Show throughput in interactive mode (EricLBuehler#1291)

* Update interactive mode throughput stats

* Accurate prompt t/s

* Accurate prompt t/s for usage

* Unify bitwise operations (EricLBuehler#1288)

* Unify bitwise ops

* Tests pass

* Fix cuda build

* Clippy

* Multimodal prefix caching support! (EricLBuehler#1298)

* Initial progress

* Support vision prefix caching

* Update docs

* Add multimodal data abstraction

* Interactive mode improvements (EricLBuehler#1299)

* More ergonomic image url parsing

* Add option to clear

* Add the Qwen 3 and Qwen 3 MoE models! (EricLBuehler#1285)

* Add qwen3 model

* Add enable_thinking

* Add initial qwen3 moe

* Add the moe model

* Format

* Fix order of norm

* Fix expert shapes

* Fix reverse

* Fix norm device for isq

* Fix nonzero when no nonzero

* Moe model runs

* Working qwen3 moe

* Add metal fp8 blockwise dequant

* Clean

* Typo

* Enable tool calling

* Streamlined ux

* Add some examples

* Add docs

* Fix dead link

* Remove interactive mode max_len

* Update QWEN3.md

* Hotfix for vision mode clear

* Revamped and streaming web search support (EricLBuehler#1301)

* Streaming web search

* Refactor a bit

* More refactoring

* Add some logging, parallelize some things

* Allow url

* Suppress warning, allow multi-turn searching

* Batch compute_similarities

* Cap content len

* Typos

* Doc

* Handle vision messages or different tool call prefixes (EricLBuehler#1302)

* Fix cuda

* Tune web search budget

* Simplify prefix cacher (EricLBuehler#1305)

* Use rustyline to handle non-ascii in interactive mode (EricLBuehler#1306)

The io::stdin().read_line() call cannot handle non-ASCII input, which caused
a crash when using backspace to delete non-ASCII characters.

Introduce rustyline in the interactive mode to solve the problem. It can
also bring more editing features in the future.

Close EricLBuehler#1140
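For illustration, a minimal sketch of a rustyline-backed read loop of the kind this commit introduces, assuming the `rustyline` crate's `DefaultEditor` (this is not the actual interactive-mode code):

```rust
use rustyline::{error::ReadlineError, DefaultEditor};

fn main() -> rustyline::Result<()> {
    // Unlike io::stdin().read_line(), rustyline edits the line in place, so
    // backspacing over multi-byte (non-ASCII) characters behaves correctly.
    let mut rl = DefaultEditor::new()?;
    loop {
        match rl.readline("> ") {
            Ok(line) => {
                rl.add_history_entry(line.as_str())?;
                println!("you typed: {line}");
            }
            // Ctrl-C / Ctrl-D end the session instead of crashing it.
            Err(ReadlineError::Interrupted) | Err(ReadlineError::Eof) => break,
            Err(err) => return Err(err),
        }
    }
    Ok(())
}
```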

* Add more tools for automatic search (EricLBuehler#1307)

* Add interactive mode history

* Add a website extraction tool

* Pass toks by reference

* Optimize prompt chunking

* Fix CPU hogging in interactive mode (EricLBuehler#1309)

The log enabler should be checked after the sleep instead of in a busy
checking loop.

Since interactive mode always disables the token speed logger, this loop
previously pinned a CPU core at 100%.
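A minimal sketch of the difference (names are illustrative, not the actual mistral.rs logger): checking the enable flag only after sleeping keeps the thread idle when logging is disabled, instead of spinning on the check.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;

// Illustrative flag standing in for "token speed logging enabled".
static LOGGING_ENABLED: AtomicBool = AtomicBool::new(false);

fn throughput_logger_loop() {
    loop {
        // Sleep first, then check: if logging is disabled (as in interactive
        // mode), the thread wakes briefly once per interval instead of
        // busy-looping at 100% CPU on the disabled check.
        std::thread::sleep(Duration::from_secs(5));
        if !LOGGING_ENABLED.load(Ordering::Relaxed) {
            continue;
        }
        // ... compute and print tokens/sec here ...
    }
}

fn main() {
    std::thread::spawn(throughput_logger_loop);
    // Let the background loop tick a few times before exiting.
    std::thread::sleep(Duration::from_secs(12));
}
```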

* Add Metal precompilation support  (EricLBuehler#1311)

* Add metal precompilation for paged attn

* Add for mistralrs-quant

* Better constructor

* Dont always build

* Fix name for paged attn rebuild

* Reduce thrashing of Metal autorelease (EricLBuehler#1313)

* Reduce calls to autorelease

* Optimize clone_in_cache

* Refactor float8

* make `AdapterPaths` and `LoraAdapterPaths` public (EricLBuehler#1314)

Make `AdapterPaths` and `LoraAdapterPaths` public so `LocalModelPaths`
can be constructed outside of `mistralrs-core`.

* Refactor KV cache manager (EricLBuehler#1315)

* Refactor kv cache

* Refactor caches

* Fix some overflows

* Add `Audio` and `Speech` model categories (EricLBuehler#1317)

* add `Audio` to `ModelCategory`

* add `Speech` to `ModelCategory`

* fix to go back to PartialEq having an exhaustiveness check

* Remove has_conv2d from vision model API (EricLBuehler#1318)

* Unified/automatic flash attention enabler (EricLBuehler#1319)

* Remove from sdpa params

* Fix errors

* No warnings

* Log

* Clippy

* Fix cublaslt 4d mask (EricLBuehler#1320)

* Fix cublaslt 4d mask

* Clippy

* Keep caches on gpu

* Qwen VL models fixes (EricLBuehler#1322)

* Add some defaults

* Fix

* Fix one thing

* 2.5 vl works

* Use caching again

* Fix v2

* Move index inside loop

* Offset in ropeidx

* Default support for vision prefix caching is false

* Fixes for all vision models (EricLBuehler#1323)

* Fix phi input processor?

* Fix phi input processor

* Handle no_prefix_cache from pipeline

* Phi models confirmed 👍

* Fixed for phi inputs processors

* Fixed for phi4

* Llama 3 confirmed 😀

* Mistral 3 confirmed 😃

* Idefics 2/3 fixes

* Some fixes

* Remove unsafety

* Improved+faster LRU prefix cacher (EricLBuehler#1321)

* Show TTFT

* Use LRU prefix cacher

* Faster prefix cacher

* Inplace ISQ support and default to mmap (EricLBuehler#1277)

* Initial impl of immediate isq

* Immediate isq -> !loading_isq

* Varbuiler utils always using mmap!

* Log

* Add for packed experts

* Afq without copy

* Clarify

* Clippy

* Apple immediate isq

* Better logic for loading_isq

* Support showing ttft

* Rename

* Shared quantize guard

* Parallel progress bar

* Parallel loading for progress bars

* Actual ISQ support

* Conditional parallelism for NiceProgressBar

* Use conditional iterator

* Warn once

* Predicate for applying immediate isq

* Allow parallel

* Remove debug print

* Remove debug print

* Remove debug print

* Fix typos (EricLBuehler#1329)

* Fix Idefics 3 arch chat templating (EricLBuehler#1330)

* Update inputs merger

* Fix

* Better warning

* Better warning

* Better warning

* Nonzero ahead of time

* No f32

* Clippy

* Optimize get_logprobs

* Fix packed experts

* Update masking

* Use Sdpa in idefics3

* QuantMethod in idefics3 vision

* Remove a .contiguous

* Remove two space from PR comment (EricLBuehler#1331)

* Add automatic vision loader type (EricLBuehler#1332)

* Add automatic vision loader

* Remove references to --arch

* Update examples

* Add the Dia 1.6b TTS model! (EricLBuehler#1304)

* Add loading

* Add rope, mlp, most of attn

* Add encoder + encoder layer, decoder layer forwards

* Add decoder forwards

* Add prepare_audio_prompt

* prepare_generation mostly done

* Add a proper dia kvcache

* Add most of decoder_step

* Add the sampler

* Add the generation loop

* Wire things up

* Add speech pipeline

* Fixes

* Loads

* Some fixes

* f32

* Some progress

* Ok it runs upto dac decoding

* Add dac part loading

* Loads and runs at least

* Remove encodec

* Debugging

* Debugging

* Huh

* Complete merge

* Interactive

* Confirmed dac works at least

* Looks like encoder works

* Much progress

* Hmm

* Sampling

* Almost there

* Sampler

* Sampler

* Bf16 support

* Response

* Use it in interactive mode

* Fix oneshot

* Add openai api

* Add openai api

* Refactor loading

* Use naive sdpa for inplace

* Factor out

* Clippy

* Clippy

* Config

* Refactor config

* Metal clippy

* Fix t/s

* ISQ support

* Some fixes, nits

* Fix cuda

* Clippy

* Inhibit cublaslt for cuda

* Add server example

* Add python example

* Add rust api

* Add docs

* Update config.toml

* Fix .pyi

* Update readme

* config.toml tweak

* config.toml tweak

* config.toml tweak

* config.toml tweak

* config.toml tweak

* config.toml tweak

* config.toml tweak

* config.toml tweak

* config.toml tweak

* update `llguidance` to `0.7.20` (EricLBuehler#1334)

Update `llguidance` from `0.7.16` to `0.7.20` so that it has guidance-ai/llguidance#172 which is a fix for building on GCC 15.

* Add model category <> messages check (EricLBuehler#1335)

* Verify model category matches the messages

* Add vision chat

* Fixes

* Add element-wise normalization check (EricLBuehler#1340)

* Fix streaming example print statement (EricLBuehler#1339)

* Fix normalization formula in comment (EricLBuehler#1338)

* Fix image_to_pixels to handle non-RGB images (EricLBuehler#1337)

* Fix typo in expect messages (EricLBuehler#1342)

* Don't use mmap on cuda (EricLBuehler#1336)

* No mmap on cuda

* Simplify streaming tool call logic

* Remove debug

* Support AWQ format models (EricLBuehler#1350)

* Support AWQ format models

* Clippy fix

* Fix uqff dummy layer ISQ application (EricLBuehler#1351)

* Disable immediate isq if write_uqff (EricLBuehler#1352)

* Fixes for UQFF loading on CUDA, ISQ pack factor (EricLBuehler#1354)

* Fix logic for uqff on cuda

* Updated pack_factor

* Refactor Option references for model paths (EricLBuehler#1347)

* refactor: use Option refs in model path helpers

* Format

* Add a script for server benchmarking (EricLBuehler#1355)

* Serde alias

* Fix

* Update for tie_word_embeddings

* Print running/waiting

* 30 users

* Update num_users

* Update dummy paged attn

* Optimized Metal qmv_fast path (EricLBuehler#1356)

* Compile with lto

* Tweak profiles

* New, fast sampler for Metal! (EricLBuehler#1327)

* Show TTFT

* Use LRU prefix cacher

* Faster prefix cacher

* A bit of gpu sampling

* Minp but cpu for now

* Metal fast cumsum impl

* Sampling with fast topp kernel

* Hmm not perfect

* Add metal sort kernels

* Tmp

* Add single block sort

* Add most of multi block sort, just need copy op

* Add copy kernels

* Expose kernels

* Add a test

* Ok it works

* Structure things

* Add caching

* Rename

* Cpu is default

* CUDA case

* Topk

* Refactor Option references for model paths (EricLBuehler#1347)

* refactor: use Option refs in model path helpers

* Format

* Add a script for server benchmarking (EricLBuehler#1355)

* Serde alias

* Fix

* Update for tie_word_embeddings

* Print running/waiting

* 30 users

* Update num_users

* Update dummy paged attn

* Optimized Metal qmv_fast path (EricLBuehler#1356)

* Compile with lto

* Tweak profiles

* Fix topk

* Penalties

* Add logits processor, clippy fixes

* Fix chat port

* Remove warning

* Fix chat port

* Fix metal parallel sampling (EricLBuehler#1357)

* Cpu if parallel for now

* Tweak bench script

* Add immediate isq predicates for qwen3 (EricLBuehler#1358)

* Add immediate isq predicates for qwen3

* Fix parsing of "parse_isq_value" depedent of device

* Typo

* Fix gemma3 logging

* Regressions fixes (EricLBuehler#1359)

* Fix regression for mmap

* Revert EricLBuehler#1321

* Refactored matching_cache impl

* Clippy

* Revamped and smaller readme (EricLBuehler#1360)

* Expandable detail sections

* Refactor using derivative model

* Tweak quick examples

* Update llama

* Update llama

* Supported accelerators is a table

* Update installation guides

* Tweak apis

* Remove --port in quick examples

* Add demo gif

* Add gif in readme

* Update demo gif

* Update demo gif

* Update demo gif

* Add gif in readme

* Add gif in readme

* Add a web chat app! (EricLBuehler#1362)

* Initial

* Markdown

* Copy code

* Add model loading sidebar

* Support vision models

* Tweak isq

* Links go to another page

* Clear when switch model

* Fix html tags

* Add image support!

* More then one images

* Fix

* Improved textarea

* Tab for switching between vision and text

* No paged attn for now

* Prettier format

* Multiple models at once

* Better switching, clearing ability

* Mobile support

* Inline markdown parser

* Update examples

* Typos

* Support specifying isq

* Fix mobile

* Fixes

* Fix button on mobile

* Image height is capped

* Thumbnail

* Fix rotating kv cache edge case

* Add drag and drop for images

* Small things

* Sidebar is frozen now

* Better listner

* Add readme

* Tweak readme

* Add chat history support to web chat app (EricLBuehler#1363)

* Add chat history

* Support renaming

* Start immediately with new chat

* Add timestamp

* Prettier chat list

* Style

* Delete chat

* Fix copy button

* Fix markdown rendering

* Store things in cache

* Store things in cache

* Refactor web chat, fix multichat image restore (EricLBuehler#1364)

* Fix multichat image restoration.

* Clippy

* Refactor

* Refactor frontent

* Fix repeated immediate isq init (EricLBuehler#1365)

* Add images_ref

* Add debug impl

* Fix the bug

* Tweak style of buttons

* Add a spinner

* Move spinner

* Tweak emoji

* Add gif

* Tweak initial gif

* Include vision tower tensors in Mistral3 UQFF (EricLBuehler#1366)

* Fix mistral 3 uqff resitdual tensors for vision

* Rolling shard creation for uqff files (EricLBuehler#1367)

* Fix occasional unstability during isq of afq (EricLBuehler#1368)

* Fix unstability during isq of afq

* Clippy

* Fix web chat installation

* Support web chat file uploading (EricLBuehler#1370)

* Web chat fixes

* Fix thumbnail in message, reuse blank chat

* Add file uploading support

* Fix scroll

* Allowed extensions

* Preserve files as literals

* Support multiple clients

* Add a stop button

* New cache dir

* New cache dir

* Fix

* Refactor

* Update readme

* Tweak drag-and-drop css

* Add speech generation support to the web chat! (EricLBuehler#1373)

* Initial speech gen support for web chat

* Tweak ui

* Update docs

* Prefix caching for PagedAttention! (EricLBuehler#1369)

* Exposing some things for logical token blocks

* Prefix cache manager has the scheduler

* Refactor

* Get logical and physical blocks into the prefix cacher

* Hash and cache

* Pass physical block prefill

* Allocation of prefilled block tables

* Temp

* Dont always use 2

* Hmm

* Hmm

* It mostly works

* Increment refcount

* Support images!

* Add to dummy paged attn

* Fix some clippy

* Clippy

* More checks

* Include EricLBuehler#1371, closes EricLBuehler#1371

* Typos

* Update docs

* Metal PagedAttention accuracy improvements (EricLBuehler#1374)

* Fix subtle bug

* Fix half sum bug

* Format metal paged attention

* Handle images in paged attn scheduler (EricLBuehler#1375)

* Include schemas needed for chatcompletions endpoint (EricLBuehler#1353)

* EricLBuehler#1326: WIP include schemas needed for chat completions endpoint

 Conflicts:
	Cargo.lock
	mistralrs-server/src/main.rs

* EricLBuehler#1326: WIP define utoipa as a workspace dep since core and server both need it

* EricLBuehler#1326: first draft of handling schemas that use Either

* EricLBuehler#1326: first draft of handling schema for Grammar

* EricLBuehler#1326: Add in other endpoints to API docs.

* EricLBuehler#1326: Adjust code comments

* EricLBuehler#1326: Implement coderabbitai suggestions

- EricLBuehler#1353 (review)
- EricLBuehler#1353 (comment)

* Fix constraints with metal sampler

* Revert EricLBuehler#1375

* Fix case where prefix cacher returns no toks (EricLBuehler#1377)

* Fix AFQ UQFF serialization

* Faster UQFF serialization (EricLBuehler#1379)

* Faster UQFF serialization

* Fix uqff gemma3

* Improve gemma3 auto loader names

* UQFF creation for AFQ on CPU support (EricLBuehler#1380)

* Add afq cpu quantize/dequantize

* Clippy

* Improved device for afq quantize

* Improved dtype handling for cpu afq (de)quantize

* Improved generate_uqff_card

* Add fused CPU attention kernel! (EricLBuehler#1382)

* Working

* Fix warnings

* Allow mask

* Support bf16, f16

* Handle striding

* Parallelized

* Add initial vector flash attn

* Avoid repeated allocations

* Tiled kv

* Apply some clippy

* Some small fixes

* Chunked vec_dot

* Clipy

* Use T::zero

* Refactor attention backends (EricLBuehler#1384)

* Refactor attention code

* Refactor attention code

* Move into backends

* Set macOS thread affinity for CPU attn (EricLBuehler#1385)

* Use lazylock

* Format

* Fix metal warn build

* Faster Qwen 3 MoE support on Metal (EricLBuehler#1387)

* Fix load

* Use afq gather qmm

* Well it runs

* It works

* Polish

* Fast and slow options

* Remove quantized.rs

* Polish some more

* Refactor

* Add isq

* Update load in parallel

* Support fp8

* Refactor for FusedExperts

* Clippy

* Handle pack factor when loading prequantized models

* Use f32 only in moe

* Avoid using f32 so much

* Avoid using f32 so much

* Fix PagedAttention block leaks (EricLBuehler#1388)

* Warn and ignore if ignored

* Fix a block allocation leak

* Update bench.py

* Fix double free in block engine

* Do not apply ISQ if loading a prequantized model

* Fix cuda build again (EricLBuehler#1389)

* Fix cuda build

* Fix

* Format

* Fixes for cuda docker

* Update dockerfiles

* Bump version to 0.6.0 (EricLBuehler#1390)

* Bump version to 0.6.0

* Remove lower_level api

* Make a static dir

* Update deps

* Fix routing for static handler in web chat

* Fewer .contiguous calls for qwen3 moe (EricLBuehler#1391)

* Allow speech models to accept batched inputs (EricLBuehler#1393)

* Allow speech models to accept batched inputs

* Clippy

* Ring distributed backend for heterogeneous TP (EricLBuehler#1238)

* Begin work on ring distributed backend for Metal

* Add the actual ring functionality

* It loads and kind of runs

* It works

* Optimize buffer allocation

* Avoid copy

* It works

* Add allgather

* Fix load

* Ping-pong

* Small things

* Add config json

* Allow different ip address

* Read config once

* Read config when appropriate

* Replicate requests

* Small fix

* Fix small compat with openai

* Clippy

* Update docs

* Add deepseek tool calling chat template

* Add auto loader for vision/text detection! (EricLBuehler#1402)

* Add auto loader for vision/text detection

* Build fixes

* Add model loader

* Update docs

* Format

* Create Mistral.rs Server Core Lib: `mistralrs-server-core` (EricLBuehler#1346)

* First draft of exposing mistral server routes as lib

* make arg struct fields pub

* Take base path so utoipa swagger route can properly redirect

* Expose swagger routes and make it configurable

* Add base path option for swagger docs

* More work on modularizing mistralrs server

* Sync fork (+1 squashed commit)
Squashed commits:
[169ae9e] Sync fork

* Adjust fn params to use refs / individual params instead of args

* Start breaking down controller actions into smaller pieces

* Continue refactoring

* Make mods pub so they can be used outside crate

* Allow chat completion streamer to take a callback so that you can get the complete response when finished

WIP (+3 squashed commits)
Squashed commits:
[0061d87] WIP
[c484d56] WIP
[16f8a60] WIP

* Sync fork

* Adjust callback type

* Remove throughput_log arg that was removed in 26afcc3

* Implement defaults for Args (and use for Clap)

* Small code formatting tweaks

* Rename callback to match SSE event and code clean up

* Sync fork

* WIP: first very rough draft of server core builder. Doesn't meet parity with old functional approach yet (slower / unstable?).

* Clean up (+4 squashed commits)
Squashed commits:
[e1cff387] Sync fork
[d8301025] WIP debugging
[1ea9f8c8] Sync fork
[4fe28cf5] WIP: debug function

* WIP server core builders

* Code clean up

* Add on_chunk callback

* Code clean up

* First draft of creating version of mistral-server that uses server-core

Code clean up (+1 squashed commit)
Squashed commits:
[adea1693]

* Sync fork

* Add helper methods to builder to make optional args more ergonomic (since .build validates params)

* Start adding docs

* Start cleaning up crates deps

* Example commit of mistral-server with implementing server-core

* Start addressing CodeRabbit feedback

* Fix comment typo

* Tweak doc blocks

* - Update type alias naming for clarity (MistralRs instead of Mistral)
- CodeRabbit, don't use eprintln for lib (use trace)
- Allow buffer size to be passed in and default to Constant
- Allow router body limit to be passed in and default to Constant
- Update doc examples

* Typo

* Address CoderRabbitAI feedback

* Support linear rope for llama3 (EricLBuehler#1408)

* Hotfix for loading

* Fix vllama4 uqff loading (EricLBuehler#1409)

* Fix vllama4 uqff loading

* Fix regex

* Fix regex

* Maybe a fix

* Gracefully handle receiver disconnects (EricLBuehler#1410)

* Handle receiver disconnects

* Format

* Fix Qwen3 MoE device mapping irregularities (EricLBuehler#1411)

* Fix bias

* Fix lm_head packing case

* Account for gate

* Fix head dim

* Fix interactive mode URL parsing (EricLBuehler#1412)

* fix url regex in vision interactive mode

* Fix regex

* Clippy

* Refactor auto device map (EricLBuehler#1413)

* Refactor auto device map

* Refactor a bit more

* Clippy

* Enable runtime sampling tweaks in interactive mode (EricLBuehler#1414)

* Document runtime sampling commands

* Fix readme

* Tweak

* Bounds checking

* Tweak temp bounds

* Send streaming tokens every time

* Gumbel sampling for fast sampler (EricLBuehler#1416)

* Improved handling for initialize_logging

* Improved CPU flash attention accuracy & performance (EricLBuehler#1417)

* Downcast correctly

* Operate internally in f32

* Avoid some casts and striding

* Prefetch

* Provide chat_templates to container users (EricLBuehler#1419)

Models often come without chat templates, requiring them to be mapped
from the source repository into a container for access by
mistralrs-server.

Copy the templates from the build tree into the root of the image to
permit use via `--chat-template /chat_templates/something.json`.

TODO:
  With the increase in quantized models and support for other formats,
the initial benchmark run during model load could be used to
qualify/select existing chat templates embedded into the binary for
models that do not ship with one (including the output of the
functional failures in each test, so users can adapt the provided
templates to suit the model being loaded).

Co-authored-by: RageLtMan <rageltman [at] sempervictus>

* Faster cpu flash attn (EricLBuehler#1418)

* Faster cpu flash attn

* Prefetch

* Clippy

* Add some tests

* Add softcap tests

* Fix test_parse_image_url test

* Update tests

* Update tests

* Web search improvements (bm25, web chat) (EricLBuehler#1420)

* Fix web search blocking case

* Web search support in web chat

* Tweak ui

* Support fallback to bm25

* Clippy

* Reinject descriptions

* Propely handle consecutive searches (EricLBuehler#1421)

* Update extraction tool reinjection

* Looped

* Update docs (EricLBuehler#1422)

- lib.rs: clean up example var names and match logging change from EricLBuehler@201d6be
- server_builder: fix typo
- READMEs: link to crate docs

* Better tool call detection logic (EricLBuehler#1424)

* Add web search hook callbacks (EricLBuehler#1426)

* feat: add customizable search hook

* Move to builder

* Update docs

* Fix CUDA context switching, bind thread on CudaStorage drop (EricLBuehler#1428)

* Add CUDA context helper and use in Llama forward

* No flashparams?

* working

* Tweak

* Update to use dep

* conditionally build flash attention inputs (EricLBuehler#1429)

* Add AGENTS.md (EricLBuehler#1430)

* Support Qwen3 GGUF model (EricLBuehler#1432)

* Support QWen3 GGUF model

* Clippy fix

* cargo fmt

* Improved paged attn prefix caching (EricLBuehler#1434)

* Improved paged attn prefix caching

* Disable

* Clippy

* Temporary fix for qwen3 gguf tokenizer (EricLBuehler#1433)

* Temporary fix for qwen3 gguf tokenizer

* Typo fix

* Add tool callback support (EricLBuehler#1427)

* Add tool callback support

* Fixes

* Support named tool callbacks

* Update examples

* Update docs

* Clippy

* Centralize crate dependencies (EricLBuehler#1438)

* chore: centralize dependencies

* Format

* Fix bug in tokenizer created with gguf metadata (EricLBuehler#1440)

* Fix bug in tokenizer created with gguf metadata

* Clippy fix

* Update deps (EricLBuehler#1441)

* Small things

* Update deps

* Update deps

* Update breaking changes

* Doc fixes (EricLBuehler#1442)

* Mention uqff_maker

* Downgrade rustyline 16.0.0 -> 15.0.0 (EricLBuehler#1444)

* Add max_completion_tokens alias for server (EricLBuehler#1451)

* Audio input support (Phi 4 multimodal) (EricLBuehler#1448)

* Deps

* Add conformer

* Nemo loading

* Position embeds

* Load t5 attn bias

* Attn and feed forward

* Add conv module and glu pointwise

* Implement relative attn bias

* Add the forward methods

* Add encoder embedding

* Fix oproj

* Some loading

* Conformer loads!

* Fully loading speech stack

* Merger

* Dont need that

* First pass at audio processing

* Read samples

* Optional

* Small loading fix

* Runs but not correct yet

* Improved audio processing?

* Works with this

* Fix t5 attn bias

* It works!

* Comment

* Use some other crates

* Clippy

* Allow bf16 on metal

* Add prefix_audio

* Remove unused

* Typo

* User specified

* Add audio url parsing

* AudioProjectionMode -> InputMode

* Audio prefix caching

* Fix bug in audio prefix caching

* Support both at the same time!

* Tweak logging

* Support stereo

* Add mistralrs-audio

* Support batching

* Add server and rust api example

* Add python api

* Fix add_multimodal_message

* Fix unfold for conformer

* Streaming example

* Add web chat support

* Add modalities registry

* Fix offline cache issue for gguf models (EricLBuehler#1452)

* Add MCP server endpoints (EricLBuehler#1453)

* feat(server): add MCP server support

* Add mcp docs

* Add handle_list_tools_request

* Better launch, tool handling

* Tmp state

* Ok works

* Handle modalities

* Update docs

* Add ping

* Tweak temperature bounds, args

* MCP documentation pass (EricLBuehler#1455)

* Fix table

* Update mcp docs

* Improve readme header

* Improve readme header

* Integrate an MCP client (EricLBuehler#1456)

* Add builtin mcp client

* Use async loader

* Add headers

* Handle sse

* More flexible search request

* Add tool callbacks with tools, for mcp

* Add bearer token support

* Add websocket support

* Update docs

* Add python api

* Clippy

* Add http api, docs

* Tests pass

* Make these configs actually work

* Add docs

* Make mistralrs-mcp

* Refactor examples

* Update examples

* Add defaults

* Add defaults

* Add defaults

* Update docs

* Improved docs

* Add -y to npx usages

* Even better examples

* Update generate_wheels

* Update generate_wheels

* Update generate_wheels

* Fix Dockerfile.cuda-all

* Improve automatic tool call (EricLBuehler#1460)

* Improved auto tool call

* Add logging

* chore: `Dockerfile.cuda-all` configurable threads (EricLBuehler#1458)

* chore: `Dockerfile.cuda-all` - Merge `RUN` for `apt-get install` (EricLBuehler#1459)

* Add fallback definition for isnan (EricLBuehler#1463)

* chore: `Dockerfile` - Drop runtime rayon thread ENV (EricLBuehler#1465)

* chore: Dockerfile - Remove rayon threads env

* chore: Dockerfile - Improve formatting for `apt-get`

* Remove duplicate calls for api_dir_list (EricLBuehler#1474)

* Remove duplicate calls for api_dir_list

* Support local cache for api_dir_list

* Fix home folder for metal

* Capitalized

* Fix transient pyo3 dep (EricLBuehler#1478)

Co-authored-by: Eric Buehler <[email protected]>

* Fix objc dep with non macos (EricLBuehler#1480)

* Fix phi 3/4 + nccl issue (EricLBuehler#1481)

* Fix log

* Fix n kv heads

* Fix phi3.5 moe (EricLBuehler#1482)

* Fix phi3.5 moe accum device

* Fix again

* Fix again

* Support GLM4 model! (EricLBuehler#1437)

* Support GLM4 model

* Mention GLM4 model in ReadMe

* glm4 type hint

* Typo fix

* Fix unsupported chat_template function

* Clippy fix

* Refactor distributed backend (EricLBuehler#1484)

* Refactor distributed backend, check power of 2

* Fix compilation

* Cap metal paged attn kv allocation (EricLBuehler#1485)

* Better paged attn metal cap (EricLBuehler#1486)

* Better paged attn metal cap

* Small fix

* Comment

* Small fix

* Refactor

* Server core: consolidate and unify route handlers and API surface (EricLBuehler#1423)

* Start working on consolidating completion and chat_completion underlying implementations

* Move response channel to util mod for now (since it's used with streaming and non streaming)

* More work on consolidating completions and chat completions

* More WIP consolidation of server core handlers

* More WIP consolidation of server core handlers

* More WIP consolidation of server core handlers

* Update docs and restrict completion core visibility

* CodeRabbit feedback: remove logprobs warn from route handler since parse request also checks this

* Use consistent var name for completions mod

* Make route handler modules public API consistent (same fn names, etc.) and provide proxy fn that wrap core fns so core mod doesn't have to be pub
Make lib.rs example compile checked and update example

* Code formatting

* Typo

* Sync fork

* Sync fork

* Docs example fix

* Support qwen3 gguf (EricLBuehler#1488)

* Add qwen3 gguf

* Template fixup

* Make bos/eos token IDs optional (EricLBuehler#1493)

* Remove python deps from CUDA dockerfiles (EricLBuehler#1487)

* Handle noncontiguous v in naive_sdpa (EricLBuehler#1499)

Co-authored-by: Eric Buehler <[email protected]>

* Server Core: refactor Paged Attention configuration (EricLBuehler#1500)

* Use StorageModePrivate for Metal PA kv cache (EricLBuehler#1506)

* Fix OpenAI stream: emit field in tool-call deltas for schema compliance (EricLBuehler#1507)

* FP8 KV-cache quantization for PagedAttention (EricLBuehler#1400)

* Add most of paged attn kv quant

* It builds a bit

* All the functionality at least

* Small fix

* Add a scale

* Fix bf16 usage

* Make k_v_scale optional

* Collector

* Tweak collection

* Refactor

* Add to apis

* Add cuda impl

* Fix compilation

* Fixes

* Handle ENABLE_FP8

* Format

* Tweak

* Fix scaled_convert usage

* Fix cache_t size

* Fixed scale collection

* Actual fix

* Fix fp8 for CC<8

* Fix the usual String != &str bit (EricLBuehler#1483)

Co-authored-by: RageLtMan <rageltman [at] sempervictus>

* chore: `Dockerfile` - Drop runtime rayon thread ENV (EricLBuehler#1465)

* chore: Dockerfile - Remove rayon threads env

* chore: Dockerfile - Improve formatting for `apt-get`

* Remove duplicate calls for api_dir_list (EricLBuehler#1474)

* Remove duplicate calls for api_dir_list

* Support local cache for api_dir_list

* Fix home folder for metal

* Capitalized

* Fix transient pyo3 dep (EricLBuehler#1478)

Co-authored-by: Eric Buehler <[email protected]>

* Fix objc dep with non macos (EricLBuehler#1480)

* Fix phi 3/4 + nccl issue (EricLBuehler#1481)

* Fix log

* Fix n kv heads

* Fix phi3.5 moe (EricLBuehler#1482)

* Fix phi3.5 moe accum device

* Fix again

* Fix again

* Support GLM4 model! (EricLBuehler#1437)

* Support GLM4 model

* Mention GLM4 model in ReadMe

* glm4 type hint

* Typo fix

* Fix unsupported chat_template function

* Clippy fix

* Refactor distributed backend (EricLBuehler#1484)

* Refactor distributed backend, check power of 2

* Fix compilation

* Cap metal paged attn kv allocation (EricLBuehler#1485)

* Better paged attn metal cap (EricLBuehler#1486)

* Better paged attn metal cap

* Small fix

* Comment

* Small fix

* Refactor

* Server core: consolidate and unify route handlers and API surface (EricLBuehler#1423)

* Start working on consolidating completion and chat_completion underlying implementations

* Move response channel to util mod for now (since it's used with streaming and non streaming)

* More work on consolidating completions and chat completions

* More WIP consolidation of server core handlers

* More WIP consolidation of server core handlers

* More WIP consolidation of server core handlers

* Update docs and restrict completion core visibility

* CodeRabbit feedback: remove logprobs warn from route handler since parse request also checks this

* Use consistent var name for completions mod

* Make route handler modules public API consistent (same fn names, etc.) and provide proxy fn that wrap core fns so core mod doesn't have to be pub
Make lib.rs example compile checked and update example

* Code formatting

* Typo

* Sync fork

* Sync fork

* Docs example fix

* Support qwen3 gguf (EricLBuehler#1488)

* Add qwen3 gguf

* Template fixup

* Make bos/eos token IDs optional (EricLBuehler#1493)

* Remove python deps from CUDA dockerfiles (EricLBuehler#1487)

* Handle USE_FP8 for cuda

* Fix cuda warn

* Add readme

* Saturating sub in sequence state

---------

Co-authored-by: Eric Buehler <[email protected]>
Co-authored-by: RageLtMan <[email protected]>
Co-authored-by: Brennan Kinney <[email protected]>
Co-authored-by: Guoqing Bao <[email protected]>
Co-authored-by: Matthew Haynes <[email protected]>

* Validate model name in OpenAI API (EricLBuehler#1509)

* Validate model name in openai api

* Add docs, allow 'ignore'

* Updated examples for EricLBuehler#1509

* Fix mcp import in doc string (EricLBuehler#1510)

* Add multi-model support! (EricLBuehler#1512)

* Refactor MistralRs

* Working multi-model!

* Add mutli-model docs initially

* Update mistralrs-pyo3, mistralrs-bench, mistralrs

* Update apis for consistency

* API tweaks

* Logging tweaks

* Add examples, tweak cli

* Clearer pipeline id

* Fix config key semantics

* Format and clippy

* Tweak logging, fix example

* Clippy refactor

* Update examples

* Remove unused multi model docs

* Replace 'ignore' with 'default'

* Update docs

* Add stars label to readme (EricLBuehler#1513)

* Add CLAUDE.md

* Handle base_model.model case in lora (EricLBuehler#1514)

* Add thread_local! for engine-specific const/static (EricLBuehler#1517)

* Fix MCP doc test (EricLBuehler#1511)

* Allow disabling metal precompilation (EricLBuehler#1518)

* Allow disabling metal precompilation

* Simple preprocessor

* Simple docs

---------

Co-authored-by: Eric Buehler <[email protected]>

* Rust 1.88 clippy (EricLBuehler#1522)

* Rust 1.88 clippy

* Format

* Fix cuda warnings (EricLBuehler#1526)

* Avoid panic decoding tokens on error (EricLBuehler#1527)

* Split Marlin and Paged Attention kernels for faster build (EricLBuehler#1525)

* Split Marlin and Paged Attention kernels for faster build

* Typo fix

* chore: update llguidance (EricLBuehler#1535)

* chore: update llguidance

* chore: remove unused import

* Add the SmolLM3 model! (EricLBuehler#1501)

* Add model

* Update loader

* Fix llama config usage

* Docs

* Fix config no_rope_layers

* Fix tie_word_embeddings default

* Add chat template

* Embed the chat templates

* Fix embedding template

* enable_thinking default true

* Update examples

* XML tools for smollm3

* Add smollm3 docs

* Fix openai examples

* Clippy

---------

Co-authored-by: Eric Buehler <[email protected]>

* Add full Gemma 3n support! (EricLBuehler#1519)

* Add initial

* Loading for text model

* Add ple embeddings

* Add altup, laurel block

* Update rmsnorm

* Add mlp

* Update attn norm application

* Currently no kv shared

* Wire it up

* It runs

* Fix bf16

* Fix scaled embd

* Fixes for mean

* tmp

* Attn confirmed

* Fix target_magnitude

* Add shared kv

* Ok it works

* Remove npy

* Fix streaming

* Remove warnings

* Remove paged attn

* Refactor rope

* Add immediate isq

* Add vision & mproj

* Update image processor

* Vision merge runs, not correct

* Remove

* Add mobilenet v5

* Add multimodal vision embedding

* Fix load

* runs

* Fix gamma

* Works but just not vision tower

* It works!!

* Tweak

* Fix warnings

* Move vision tower

* Fix warn

* Update cache manager things

* Refactor

* Add audio model, it loads

* Add audio processing

* It runs at least

* tmp

* A bit better

* Audio works!!!!

* Fused attn in vision

* Clippy

* Update audio runner

* Optimized audio model

* Remove unused things

* Fix inputs processor bug

* Remove comments

* Clippy

* Small optimizations

* Format

* Correctly register modalities

* Add docs

* Update readme

* Runs there

* Fixed padding from Blaizzy/mlx-vlm#410

* Add better checks

* Fix sdpa n_kv_groups

* Vision encoder works!

* Rotate image

* Clippy

* Fix cuda loading

* Updated device mapper

* Fix overflow

* Fix dtype errors

* Refactor image/audio embeddings

* Fix metal

* Fix dtype mismatch

* Audio processing fixes

* Audio processing fixes

* Works

* Audio is good

* Fix boi/eoi too

* Embed the chat templates

* Better embedding accuracy in non f32

* More f32

* Support bf16 on metal

* Add more ISQ

* Fixed device map

* Clippy

* Gemma3n no paged attn

* Fix saturating sub

* Faster rmsnorm

* Use sdpa for vision model

* Fix ple bug

* Fix name

* Fix multiaudio

* Add matformer config loading

* Add docs

* Add support for matformer in auto device mapper

* Update docs

* Typos

* Tweak

* Tweak

* Fix multidevice

* Fix gemma3n text model auto device map

* Fix dims3

* Fix auto devic emap vision

* Non-metal keeps PLE on cpu

* Complete merge

* Vision dtype f16 -> f32

* Fix metal nm device

* Fix uqff

* Typos

* Reference uqff

* Fix tests

* Fix sequence length check (EricLBuehler#1546)

* update candle version (EricLBuehler#1545)

Co-authored-by: AlpineVibrations <[email protected]>

* add ios target to metal deps (EricLBuehler#1548)

---------

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: Eric Buehler <[email protected]>
Co-authored-by: Eric Buehler <[email protected]>
Co-authored-by: edwko <[email protected]>
Co-authored-by: Copilot <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Guoqing Bao <[email protected]>
Co-authored-by: Michał Moskal <[email protected]>
Co-authored-by: Chen Mulong <[email protected]>
Co-authored-by: Steph Wolski <[email protected]>
Co-authored-by: omahs <[email protected]>
Co-authored-by: Viktor Szépe <[email protected]>
Co-authored-by: Matthew Haynes <[email protected]>
Co-authored-by: RageLtMan <[email protected]>
Co-authored-by: Brennan Kinney <[email protected]>
Co-authored-by: Eric Buehler <[email protected]>
Co-authored-by: Sbargaoui <[email protected]>
Co-authored-by: Gaétan Lepage <[email protected]>
Co-authored-by: Ammar Elsabe <[email protected]>
Co-authored-by: luke <[email protected]>
Co-authored-by: AlpineVibrations <[email protected]>
Co-authored-by: Michael Tissen <[email protected]>