15 changes: 15 additions & 0 deletions lib/llm/src/model_card/create.rs
@@ -86,6 +86,7 @@ impl ModelDeploymentCard
tokenizer: Some(TokenizerKind::from_gguf(gguf_file)?),
gen_config: None, // AFAICT there is no equivalent in a GGUF
prompt_formatter: Some(PromptFormatterArtifact::GGUF(gguf_file.to_path_buf())),
chat_template_file: None,
prompt_context: None, // TODO - auto-detect prompt context
revision: 0,
last_published: None,
@@ -124,6 +125,7 @@ impl ModelDeploymentCard
tokenizer: Some(TokenizerKind::from_repo(repo_id).await?),
gen_config: GenerationConfig::from_repo(repo_id).await.ok(), // optional
prompt_formatter: PromptFormatterArtifact::from_repo(repo_id).await?,
chat_template_file: PromptFormatterArtifact::chat_template_from_repo(repo_id).await?,
prompt_context: None, // TODO - auto-detect prompt context
revision: 0,
last_published: None,
@@ -157,6 +159,19 @@ impl PromptFormatterArtifact
.ok())
}

pub async fn chat_template_from_repo(repo_id: &str) -> Result<Option<Self>> {
Ok(Self::chat_template_try_is_hf_repo(repo_id)
.await
.with_context(|| format!("unable to extract chat template from repo {}", repo_id))
.ok())
}

async fn chat_template_try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfChatTemplate(
check_for_file(repo, "chat_template.jinja").await?,
))
}

async fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfTokenizerConfigJson(
check_for_file(repo, "tokenizer_config.json").await?,
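The new helper mirrors the existing `from_repo` path: a missing `chat_template.jinja` is not an error, since the inner `Result` is turned into `None` via `.ok()`. A minimal usage sketch (the repo id is the example named in the PR comment; the async calling context is assumed):

```rust
// Sketch only, assuming an async context that returns anyhow::Result<()>.
let artifact = PromptFormatterArtifact::chat_template_from_repo(
    "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
)
.await?;

match artifact {
    Some(PromptFormatterArtifact::HfChatTemplate(path)) => {
        println!("standalone chat template found at {path}");
    }
    _ => println!("no chat_template.jinja; tokenizer_config.json supplies the template"),
}
```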
15 changes: 15 additions & 0 deletions lib/llm/src/model_card/model.rs
@@ -62,6 +62,7 @@ pub enum TokenizerKind
#[serde(rename_all = "snake_case")]
pub enum PromptFormatterArtifact {
HfTokenizerConfigJson(String),
HfChatTemplate(String),
GGUF(PathBuf),
}

@@ -101,6 +102,10 @@ pub struct ModelDeploymentCard
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prompt_formatter: Option<PromptFormatterArtifact>,

/// The chat template may be stored in a separate file instead of inside `prompt_formatter`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub chat_template_file: Option<PromptFormatterArtifact>,

/// Generation config - default sampling params
#[serde(default, skip_serializing_if = "Option::is_none")]
pub gen_config: Option<GenerationConfig>,
@@ -259,6 +264,11 @@ impl ModelDeploymentCard
PromptFormatterArtifact::HfTokenizerConfigJson,
"tokenizer_config.json"
);
nats_upload!(
self.chat_template_file,
PromptFormatterArtifact::HfChatTemplate,
"chat_template.jinja"
);
nats_upload!(
self.tokenizer,
TokenizerKind::HfTokenizerJson,
@@ -308,6 +318,11 @@ impl ModelDeploymentCard
PromptFormatterArtifact::HfTokenizerConfigJson,
"tokenizer_config.json"
);
nats_download!(
self.chat_template_file,
PromptFormatterArtifact::HfChatTemplate,
"chat_template.jinja"
);
nats_download!(
self.tokenizer,
TokenizerKind::HfTokenizerJson,
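Upload and download both store the separate template under the fixed object name `chat_template.jinja`, alongside the existing `tokenizer_config.json` entry. For the card itself, a hedged serialization sketch of the new variant (assuming the enum derives `Serialize`/`Deserialize`, which the `rename_all = "snake_case"` attribute implies):

```rust
// Sketch only: an externally tagged serde enum with snake_case variant names.
let artifact = PromptFormatterArtifact::HfChatTemplate("chat_template.jinja".to_string());
let json = serde_json::to_string(&artifact)?;
assert_eq!(json, r#"{"hf_chat_template":"chat_template.jinja"}"#);

// Cards written before this change simply omit `chat_template_file`;
// #[serde(default)] lets them deserialize with the field set to None.
```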
19 changes: 17 additions & 2 deletions lib/llm/src/preprocessor/prompt/template.rs
@@ -26,7 +26,7 @@ mod oai;
mod tokcfg;

use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter};
use tokcfg::ChatTemplate;
use tokcfg::{ChatTemplate, ChatTemplateValue};

impl PromptFormatter {
pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<PromptFormatter> {
@@ -37,13 +37,28 @@ impl PromptFormatter
PromptFormatterArtifact::HfTokenizerConfigJson(file) => {
let content = std::fs::read_to_string(&file)
.with_context(|| format!("fs:read_to_string '{file}'"))?;
let config: ChatTemplate = serde_json::from_str(&content)?;
let mut config: ChatTemplate = serde_json::from_str(&content)?;
// Some HF models (e.g. meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8)
// store the chat template in a separate file; if that file exists we copy
// its contents into the config as a normalization step.
if let Some(PromptFormatterArtifact::HfChatTemplate(chat_template_file)) =
mdc.chat_template_file
{
let chat_template = std::fs::read_to_string(&chat_template_file)
.with_context(|| format!("fs:read_to_string '{}'", chat_template_file))?;
// clean up the string to remove newlines
let chat_template = chat_template.replace('\n', "");
config.chat_template = Some(ChatTemplateValue(either::Left(chat_template)));
}
Self::from_parts(
config,
mdc.prompt_context
.map_or(ContextMixins::default(), |x| ContextMixins::new(&x)),
)
}
PromptFormatterArtifact::HfChatTemplate(_) => Err(anyhow::anyhow!(
"prompt_formatter should not have type HfChatTemplate"
)),
PromptFormatterArtifact::GGUF(gguf_path) => {
let config = ChatTemplate::from_gguf(&gguf_path)?;
Self::from_parts(config, ContextMixins::default())
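Net effect in `from_mdc`: a standalone `chat_template.jinja` on the card takes precedence over any `chat_template` embedded in `tokenizer_config.json`, while a card whose `prompt_formatter` itself is `HfChatTemplate` is rejected. A sketch of that guard (hypothetical: `card` is any already-built `ModelDeploymentCard`):

```rust
// Sketch only: force the unsupported combination to exercise the new error arm.
let mut card = card; // assumed to exist; any ModelDeploymentCard works
card.prompt_formatter = Some(PromptFormatterArtifact::HfChatTemplate(
    "chat_template.jinja".to_string(),
));
assert!(PromptFormatter::from_mdc(card).await.is_err());
```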