Add Liquid Foundation Model (LFM2) #16890
Conversation
Summary of Changes

Hello @tugot17, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request integrates comprehensive support for the Liquid Foundation Model 2 (LFM2) into the SGLang framework. LFM2 is a hybrid architecture that leverages both traditional attention mechanisms and efficient ShortConv layers. The changes include the complete model implementation, a specialized configuration for managing its unique hybrid caching requirements, and a new function call parser tailored to LFM2's specific format. This integration expands SGLang's capabilities to support hybrid models, enabling efficient inference and accurate function calling.
Code Review
This pull request introduces support for the Liquid Foundation Model (LFM2), a hybrid architecture model. The changes are comprehensive, covering model implementation, configuration, and function call parsing. A key improvement is the dynamic selection of convolution state dtype, which resolves a CUDA graph capture issue and enhances robustness. The new Lfm2Detector correctly handles both Pythonic and JSON-based tool call formats. The addition of extensive unit and integration tests ensures the new model is well-integrated and functions as expected. Overall, this is a high-quality contribution. I have one minor suggestion to remove some unused code for better maintainability.
please resolve the conflict, thanks~
```python
# Init memory pool and attention backends
# Set default dtype so mamba2_cache_params picks up the correct dtype for conv state
with set_default_torch_dtype(self.model_config.dtype):
    self.init_memory_pool(min_per_gpu_memory)
```
I think it is unnecessary
yeah, there was an issue when we initialized the model with fp16, but now I just propagate the dtype.
@yizhang2077 wdyt?
/tag-and-rerun-ci

Got this on B200:
@JustinTong0323 |
LFM2 was failing on B200/SM100 because:

1. SM100 defaults to the `trtllm_mha` backend, which forces `page_size=64`
2. `MambaRadixCache` requires `page_size=1` for hybrid models
3. The triton backend doesn't work because LFM2's first layer is conv, not attention

Add `Lfm2ForCausalLM` to `server_args.py` with the same handling as NemotronH:

- Use the flashinfer backend on SM100 (supports `page_size=1`)
- Disable overlap schedule with radix cache
- Block the triton backend (layer 0 is not an attention layer)
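A rough sketch of the backend special-casing just described; the names `HYBRID_CONV_FIRST_ARCHS` and `pick_attention_backend` are made up for illustration, not the actual `server_args.py` code:

```python
# Architectures whose layer 0 is a conv layer and which use MambaRadixCache.
# Illustrative stand-in for the real server_args.py handling.
HYBRID_CONV_FIRST_ARCHS = {"NemotronHForCausalLM", "Lfm2ForCausalLM"}

def pick_attention_backend(model_arch: str, sm_version: int, default: str) -> str:
    if sm_version >= 100 and model_arch in HYBRID_CONV_FIRST_ARCHS:
        # trtllm_mha forces page_size=64, but MambaRadixCache needs page_size=1;
        # flashinfer supports page_size=1 on SM100, so override the default.
        return "flashinfer"
    return default
```

With this, a hybrid model on SM100 gets flashinfer while other models keep the platform default.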
Fixed! The issue was that LFM2 wasn't in the SM100 special handling in `server_args.py`. On B200, the default backend is `trtllm_mha`, which forces `page_size=64`, while `MambaRadixCache` needs `page_size=1`, so LFM2 now uses the `flashinfer` backend there. (Also had to block the triton backend since LFM2's first layer is a conv layer, not attention.) The tests pass now on B200 as well.
@JustinTong0323 Could we merge it now?
I will help to verify this model on gsm8k to see if it works as expected, as required by @JustinTong0323

@ChangyiYang Could you share the test results?

Let me get back to you by today.
```python
# Propagate runtime dtype to hf_config so that hybrid models (mamba, LFM2, etc.)
# can use it for conv state cache dtype
self.hf_config.torch_dtype = self.dtype
```
Could we use another way to pass the dtype to the conv layers? I worry it may affect other models here.
Besides this, the other parts look good to me.
@yizhang2077 what do you think about changing it in mamba_utils.py?
```python
def mamba2_state_dtype() -> Mamba2StateDType:
    dtype_map = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }
    conv_dtype = dtype_map.get(
        os.environ.get("SGLANG_MAMBA_CONV_DTYPE", "bfloat16"), torch.bfloat16
    )
    ssm_dtype = dtype_map.get(
        os.environ.get("SGLANG_MAMBA_SSM_DTYPE", "float32"), torch.float32
    )
    return Mamba2StateDType(conv=conv_dtype, temporal=ssm_dtype)
```

This way we could modify `SGLANG_MAMBA_CONV_DTYPE` as an environment variable, similar to how it is already done for other models. Then the tests that use fp16 should just pass.

Would this make sense? I agree the current version is kinda too hacky.
I think we could use this way.
I also added this to the tests so they don't require setting the flag manually
```python
def assert_close_logits_and_output_strs(
    self,
    prompts: List[str],
    model_case: ModelCase,
    torch_dtype: torch.dtype,
) -> None:
    model_path = model_case.model_path
    prefill_tolerance, decode_tolerance, rouge_l_tolerance = (
        model_case.prefill_tolerance,
        model_case.decode_tolerance,
        model_case.rouge_l_tolerance,
    )
    max_new_tokens = 32
    # Set conv dtype for hybrid models to match inference dtype
    dtype_str = {torch.float16: "float16", torch.bfloat16: "bfloat16"}.get(
        torch_dtype, "bfloat16"
    )
    os.environ["SGLANG_MAMBA_CONV_DTYPE"] = dtype_str
```
I run the model with the command. The test command and result is: Is this expected performance?

We also get the following result. @tugot17 Could you please confirm whether it's an evaluation mismatch or an implementation issue? Thanks~

BTW, the tool call parser LGTM
```python
# For ShortConv layers, we use a simplified Mamba2StateShape
# LFM2 doesn't use SSM state (state_size=0), only conv state
shape = Mamba2StateShape.create(
    tp_world_size=tp_size,
```
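For context on why only a conv state is allocated here: a ShortConv decode step needs just the last K-1 inputs, not a growing KV cache. A minimal pure-Python sketch of one causal-conv decode step (illustrative only; the real kernel is the fused `causal_conv1d_update` CUDA kernel):

```python
# One causal depthwise-conv decode step for a single channel.
# state holds the previous K-1 inputs (oldest first); weights has length K.
def causal_conv_step(state, x, weights):
    window = state + [x]                              # last K inputs, oldest first
    y = sum(w * v for w, v in zip(weights, window))   # dot with the conv kernel
    return y, window[1:]                              # new state: drop the oldest

# Streaming decode: the state stays fixed-size, unlike an attention KV cache.
def run_stream(xs, weights, k):
    state = [0.0] * (k - 1)
    outs = []
    for x in xs:
        y, state = causal_conv_step(state, x, weights)
        outs.append(y)
    return outs
```

This is why a ShortConv-only model needs `state_size=0` in the Mamba shape: there is no SSM recurrence to cache, only the sliding conv window.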
I think we need to refactor it later, but it is OK for the current PR. ShortConv-only models being mixed with Mamba models is tricky here. cc @ispobock @hebiao064
Yes, we need to do some refactor later.
@tugot17 Could you address the above comments and verify the accuracy? Then we can merge it soon.

I will run the internal eval tool today and get back to you.
@ispobock
This is very similar (some slightly better, some slightly worse) to the numbers I get from the internal vLLM.
Hi @tugot17, thanks for the evaluation! Could you address the comment above?

Yes, I proposed another solution; this one was a workaround because the numeric tests running in fp16 crashed the CUDA graph capture.
/rerun-stage unit-test-backend-4-gpu

✅ Triggered

@yizhang2077 Could you approve again? I added one commit to make the tests smoother, see the comment.

/tag-and-rerun-ci

B300: We found that such a workload could be as little as 180ms e2e. Amazing!
Summary
Changes
- `srt/models/lfm2.py`
- `srt/configs/lfm2.py` (`mamba2_cache_params`)
- `srt/configs/mamba_utils.py`
- `srt/model_executor/model_runner.py`
- `srt/function_call/lfm2_detector.py`
- `srt/function_call/function_call_parser.py` (`lfm2` parser)
- `test/srt/models/test_generation_models.py`
- `test/registered/function_call/test_function_call_parser.py` (`Lfm2Detector`)
- `test/registered/openai_server/function_call/test_tool_choice.py` (`TestToolChoiceLfm2`)

Key Technical Details
ShortConv Architecture

- Uses `HybridReqToTokenPool` + `MambaPool` for hybrid caching
- Conv layers go through `causal_conv1d_fn()`

Conv State Dtype Fix

The `causal_conv1d_update` kernel requires the conv state dtype to match the input dtype exactly. We fixed a dtype mismatch that caused CUDA graph capture to fail:

- `mamba_utils.py` hardcoded `CONV_DTYPE = torch.bfloat16`, but tests run models in `torch.float16`
- The `SGLANG_MAMBA_SSM_DTYPE` env var is set via `ServerArgs`, but the test path uses `Engine` directly (bypassing this)
- Added a `get_conv_dtype()` function that dynamically gets the dtype from `torch.get_default_dtype()`, with the `SGLANG_MAMBA_SSM_DTYPE` env var as an override
- Wrapped `init_memory_pool()` in a `set_default_torch_dtype(self.model_config.dtype)` context

This fix is safe for other Mamba-based models (NemotronH, FalconH1, Qwen3Next): server behavior is unchanged.
Function Calling Support

Added `Lfm2Detector` for parsing LFM2's tool call format with special tokens:

```
<|tool_call_start|>[get_weather(city="Paris")]<|tool_call_end|>
```

Usage:
`--tool-call-parser lfm2`

Tests
Logprob accuracy test (compares SGLang vs HuggingFace):
LFM2 function call parser (unit test):
LFM2 function calling (integration test):
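Separately from the tests listed above, the pythonic tool-call format can be parsed with a short sketch like this; `parse_tool_calls` is a made-up helper for illustration, not the actual `Lfm2Detector` API:

```python
import ast
import re

# Match the pythonic call list between LFM2's special tokens.
TOOL_CALL_RE = re.compile(r"<\|tool_call_start\|>(.*?)<\|tool_call_end\|>", re.DOTALL)

def parse_tool_calls(text: str):
    """Extract [fn(kw=...)] style calls between the LFM2 special tokens."""
    calls = []
    for block in TOOL_CALL_RE.findall(text):
        # The block is a pythonic list of calls, e.g. [get_weather(city="Paris")]
        tree = ast.parse(block.strip(), mode="eval")
        for node in tree.body.elts:
            name = node.func.id
            args = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords}
            calls.append({"name": name, "arguments": args})
    return calls
```

Parsing via `ast` rather than a regex per argument keeps quoted strings, numbers, and nested literals safe without executing any model output.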
Benchmark Performance
We run `MMLU-Pro` and `tau2` on our internal OAI-server compatible benchmarking suite, and the scores match the declared performance.

Running on 1x H100 SXM5:

```
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 not set
Successful requests:                     500
Benchmark duration (s):                  3.77
Total input tokens:                      127216
Total input text tokens:                 127216
Total input vision tokens:               0
Total generated tokens:                  63834
Total generated tokens (retokenized):    62640
Request throughput (req/s):              132.79
Input token throughput (tok/s):          33786.60
Output token throughput (tok/s):         16953.32
Peak output token throughput (tok/s):    33486.00
Peak concurrent requests:                500
Total token throughput (tok/s):          50739.92
Concurrency:                             404.12
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   3043.26
Median E2E Latency (ms):                 3341.52
---------------Time to First Token----------------
Mean TTFT (ms):                          863.90
Median TTFT (ms):                        786.99
P99 TTFT (ms):                           1621.48
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          32.01
Median TPOT (ms):                        18.84
P99 TPOT (ms):                           310.85
---------------Inter-Token Latency----------------
Mean ITL (ms):                           17.24
Median ITL (ms):                         8.57
P95 ITL (ms):                            18.11
P99 ITL (ms):                            120.99
Max ITL (ms):                            996.73
==================================================
```
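As a quick sanity check, the headline throughput figures follow directly from the reported totals and duration; the printed duration is rounded, hence the small discrepancy:

```python
# Recompute the headline throughputs from the totals in the report above.
duration_s = 3.77          # Benchmark duration (rounded in the report)
requests = 500
input_tokens = 127216
generated_tokens = 63834

req_throughput = requests / duration_s
total_tok_throughput = (input_tokens + generated_tokens) / duration_s
# req_throughput is ~132.6 req/s (reported 132.79) and total_tok_throughput
# is ~50676 tok/s (reported 50739.92); both match within rounding error.
```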