diff --git a/nemo_skills/inference/server/serve_trt.py b/nemo_skills/inference/server/serve_trt.py index b9a371e0c6..64feb6c66d 100644 --- a/nemo_skills/inference/server/serve_trt.py +++ b/nemo_skills/inference/server/serve_trt.py @@ -184,6 +184,7 @@ def read_model_name(engine_dir: str): name_map = { 'MistralForCausalLM'.lower(): 'mistral', 'LlamaForCausalLM'.lower(): 'llama', + 'MixtralForCausalLM'.lower(): 'mixtral', } return name_map[name]