diff --git a/lib/llm/src/model_card/create.rs b/lib/llm/src/model_card/create.rs
index 99df2b2c5c..b69525beb7 100644
--- a/lib/llm/src/model_card/create.rs
+++ b/lib/llm/src/model_card/create.rs
@@ -86,6 +86,7 @@ impl ModelDeploymentCard {
             tokenizer: Some(TokenizerKind::from_gguf(gguf_file)?),
             gen_config: None, // AFAICT there is no equivalent in a GGUF
             prompt_formatter: Some(PromptFormatterArtifact::GGUF(gguf_file.to_path_buf())),
+            chat_template_file: None,
             prompt_context: None, // TODO - auto-detect prompt context
             revision: 0,
             last_published: None,
@@ -124,6 +125,7 @@ impl ModelDeploymentCard {
             tokenizer: Some(TokenizerKind::from_repo(repo_id).await?),
             gen_config: GenerationConfig::from_repo(repo_id).await.ok(), // optional
             prompt_formatter: PromptFormatterArtifact::from_repo(repo_id).await?,
+            chat_template_file: PromptFormatterArtifact::chat_template_from_repo(repo_id).await?,
             prompt_context: None, // TODO - auto-detect prompt context
             revision: 0,
             last_published: None,
@@ -157,6 +159,19 @@ impl PromptFormatterArtifact {
             .ok())
     }
 
+    pub async fn chat_template_from_repo(repo_id: &str) -> Result<Option<Self>> {
+        Ok(Self::chat_template_try_is_hf_repo(repo_id)
+            .await
+            .with_context(|| format!("unable to extract chat template from repo {}", repo_id))
+            .ok())
+    }
+
+    async fn chat_template_try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
+        Ok(Self::HfChatTemplate(
+            check_for_file(repo, "chat_template.jinja").await?,
+        ))
+    }
+
     async fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
         Ok(Self::HfTokenizerConfigJson(
             check_for_file(repo, "tokenizer_config.json").await?,
diff --git a/lib/llm/src/model_card/model.rs b/lib/llm/src/model_card/model.rs
index 031d2baf41..6fd6efe38d 100644
--- a/lib/llm/src/model_card/model.rs
+++ b/lib/llm/src/model_card/model.rs
@@ -62,6 +62,7 @@ pub enum TokenizerKind {
 #[serde(rename_all = "snake_case")]
 pub enum PromptFormatterArtifact {
     HfTokenizerConfigJson(String),
+    HfChatTemplate(String),
     GGUF(PathBuf),
 }
 
@@ -101,6 +102,10 @@ pub struct ModelDeploymentCard {
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub prompt_formatter: Option<PromptFormatterArtifact>,
 
+    /// The chat template may be stored as a separate file instead of in `prompt_formatter`.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub chat_template_file: Option<PromptFormatterArtifact>,
+
     /// Generation config - default sampling params
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub gen_config: Option<GenerationConfig>,
@@ -259,6 +264,11 @@
             PromptFormatterArtifact::HfTokenizerConfigJson,
             "tokenizer_config.json"
         );
+        nats_upload!(
+            self.chat_template_file,
+            PromptFormatterArtifact::HfChatTemplate,
+            "chat_template.jinja"
+        );
         nats_upload!(
             self.tokenizer,
             TokenizerKind::HfTokenizerJson,
@@ -308,6 +318,11 @@
             PromptFormatterArtifact::HfTokenizerConfigJson,
             "tokenizer_config.json"
         );
+        nats_download!(
+            self.chat_template_file,
+            PromptFormatterArtifact::HfChatTemplate,
+            "chat_template.jinja"
+        );
         nats_download!(
             self.tokenizer,
             TokenizerKind::HfTokenizerJson,
diff --git a/lib/llm/src/preprocessor/prompt/template.rs b/lib/llm/src/preprocessor/prompt/template.rs
index 5261238343..82b990434f 100644
--- a/lib/llm/src/preprocessor/prompt/template.rs
+++ b/lib/llm/src/preprocessor/prompt/template.rs
@@ -26,7 +26,7 @@ mod oai;
 mod tokcfg;
 
 use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter};
-use tokcfg::ChatTemplate;
+use tokcfg::{ChatTemplate, ChatTemplateValue};
 
 impl PromptFormatter {
     pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<Self> {
@@ -37,13 +37,28 @@ impl PromptFormatter {
             PromptFormatterArtifact::HfTokenizerConfigJson(file) => {
                 let content = std::fs::read_to_string(&file)
                     .with_context(|| format!("fs:read_to_string '{file}'"))?;
-                let config: ChatTemplate = serde_json::from_str(&content)?;
+                let mut config: ChatTemplate = serde_json::from_str(&content)?;
+                // Some HF models (e.g. meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8)
+                // store the chat template in a separate file. If such a file exists, read it
+                // and copy the template into `config` so downstream code sees one normalized
+                // ChatTemplate regardless of where the template was stored.
+                if let Some(PromptFormatterArtifact::HfChatTemplate(chat_template_file)) =
+                    mdc.chat_template_file
+                {
+                    let chat_template = std::fs::read_to_string(&chat_template_file)
+                        .with_context(|| format!("fs:read_to_string '{}'", chat_template_file))?;
+                    // clean up the string to remove newlines
+                    let chat_template = chat_template.replace('\n', "");
+                    config.chat_template = Some(ChatTemplateValue(either::Left(chat_template)));
+                }
                 Self::from_parts(
                     config,
                     mdc.prompt_context
                         .map_or(ContextMixins::default(), |x| ContextMixins::new(&x)),
                 )
             }
+            PromptFormatterArtifact::HfChatTemplate(_) => Err(anyhow::anyhow!(
+                "prompt_formatter should not have type HfChatTemplate"
+            )),
             PromptFormatterArtifact::GGUF(gguf_path) => {
                 let config = ChatTemplate::from_gguf(&gguf_path)?;
                 Self::from_parts(config, ContextMixins::default())
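The resolution order the patch implements (prefer a standalone `chat_template.jinja`, otherwise fall back to the `chat_template` field embedded in `tokenizer_config.json`) can be sketched in isolation. The snippet below is a simplified stand-in, not the crate's actual API: it models the template as a plain `Option<String>` and reads from a local model directory, whereas the real code carries the value as a `ChatTemplateValue` and routes it through `PromptFormatterArtifact`; the struct and function names here are illustrative only.

```rust
use std::fs;
use std::path::Path;

use anyhow::{Context, Result};
use serde::Deserialize;

/// Simplified stand-in for the `chat_template` portion of `tokenizer_config.json`.
/// Unknown fields are ignored by serde's default behavior.
#[derive(Deserialize)]
struct TokenizerConfig {
    chat_template: Option<String>,
}

/// Prefer a standalone `chat_template.jinja`; otherwise fall back to the
/// `chat_template` field inside `tokenizer_config.json`.
fn resolve_chat_template(model_dir: &Path) -> Result<Option<String>> {
    let jinja = model_dir.join("chat_template.jinja");
    if jinja.is_file() {
        let template = fs::read_to_string(&jinja)
            .with_context(|| format!("read {}", jinja.display()))?;
        return Ok(Some(template));
    }

    let config_path = model_dir.join("tokenizer_config.json");
    let content = fs::read_to_string(&config_path)
        .with_context(|| format!("read {}", config_path.display()))?;
    let config: TokenizerConfig = serde_json::from_str(&content)?;
    Ok(config.chat_template)
}
```

The design choice in the patch mirrors this: the separate file is folded into `config.chat_template` inside `PromptFormatter::from_mdc` rather than threaded through as a second artifact, so the downstream formatting code is unchanged.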