Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 0 additions & 36 deletions src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -589,38 +589,6 @@ struct Embedding_Element : JSON::Element {
EmbeddingOutputs_Element outputs_{v_.outputs};
};

struct PromptTemplates_Element : JSON::Element {
explicit PromptTemplates_Element(std::optional<Config::Model::PromptTemplates>& v) : v_{v} {}

void OnValue(std::string_view name, JSON::Value value) override {
// if one of templates is given in json, then any non-specified template will be default "{Content}"
if (name == "assistant") {
EnsureAvailable();
v_->assistant = JSON::Get<std::string_view>(value);
} else if (name == "prompt") {
EnsureAvailable();
v_->prompt = JSON::Get<std::string_view>(value);
} else if (name == "system") {
EnsureAvailable();
v_->system = JSON::Get<std::string_view>(value);
} else if (name == "user") {
EnsureAvailable();
v_->user = JSON::Get<std::string_view>(value);
} else {
throw JSON::unknown_value_error{};
}
}

private:
std::optional<Config::Model::PromptTemplates>& v_;

void EnsureAvailable() {
if (!v_.has_value()) {
v_.emplace();
}
}
};

struct Model_Element : JSON::Element {
explicit Model_Element(Config::Model& v) : v_{v} {}

Expand Down Expand Up @@ -664,9 +632,6 @@ struct Model_Element : JSON::Element {
if (name == "embedding") {
return embedding_;
}
if (name == "prompt_templates") {
return prompt_templates_;
}
if (name == "speech") {
return speech_;
}
Expand All @@ -680,7 +645,6 @@ struct Model_Element : JSON::Element {
Eos_Array_Element eos_token_ids_{v_};
Vision_Element vision_{v_.vision};
Embedding_Element embedding_{v_.embedding};
PromptTemplates_Element prompt_templates_{v_.prompt_templates};
Speech_Element speech_{v_.speech};
};

Expand Down
8 changes: 0 additions & 8 deletions src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ struct Config {
static constexpr std::string_view InputsEmbedsName = "inputs_embeds";
static constexpr std::string_view CurrentSequenceLengthName = "current_sequence_length";
static constexpr std::string_view PastSequenceLengthName = "past_sequence_length";
static constexpr std::string_view promptTemplate = "{Content}";
static constexpr std::string_view TotalSequenceLengthName = "total_sequence_length";
static constexpr std::string_view TokenTypeIdsName = "token_type_ids";

Expand Down Expand Up @@ -206,13 +205,6 @@ struct Config {

} decoder;

struct PromptTemplates {
std::string assistant{Defaults::promptTemplate};
std::string prompt{Defaults::promptTemplate};
std::string system{Defaults::promptTemplate};
std::string user{Defaults::promptTemplate};
};
std::optional<PromptTemplates> prompt_templates;
} model;

struct Search {
Expand Down
33 changes: 1 addition & 32 deletions src/python/py/models/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,11 +400,6 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
ep_options = { self.ep : self.ep_attrs[self.ep] }
genai_config["model"]["decoder"]["session_options"]["provider_options"].append(ep_options)

if self.extra_options.get("include_prompt_templates", False):
prompt_templates = self._get_prompt_templates(model_name_or_path, extra_kwargs)
if prompt_templates is not None:
genai_config["model"]["prompt_templates"] = prompt_templates

print(f"Saving GenAI config in {out_dir}")
with open(os.path.join(out_dir,"genai_config.json"), "w") as f:
json.dump(genai_config, f, indent=4)
Expand All @@ -413,30 +408,6 @@ def save_processing(self, model_name_or_path, extra_kwargs, out_dir):
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, token=self.hf_token, trust_remote_code=True, **extra_kwargs)
print(f"Saving processing files in {out_dir} for GenAI")
tokenizer.save_pretrained(out_dir)

def _get_prompt_templates(self, hf_name, extra_kwargs):
    """Derive per-role prompt templates from the tokenizer's chat template.

    Renders the chat template for growing role sequences (system; system+user;
    system+user+assistant) and diffs consecutive renderings by prefix to isolate
    each role's contribution. Returns a dict with "system", "user", "assistant",
    and "prompt" templates, or None if the templates cannot be derived.
    """
    try:
        # eos_token=None disables end of sentence padding so consecutive
        # renderings remain strict prefixes of one another.
        tok = AutoTokenizer.from_pretrained(hf_name, token=self.hf_token, trust_remote_code=True, eos_token=None, **extra_kwargs)
        roles = ['system', 'user', 'assistant']
        messages = [{'role': role, 'content': '{Content}'} for role in roles]
        rendered_s = tok.apply_chat_template(messages[:1], tokenize=False)
        rendered_su = tok.apply_chat_template(messages[:2], tokenize=False)
        rendered_sua = tok.apply_chat_template(messages, tokenize=False)
        # If padding tokens leak into the renderings, the prefix relationship
        # breaks and the slicing below would produce garbage.
        assert rendered_su.startswith(rendered_s), "Chat templates may contain padding tokens, leading to incorrect prompt templates"
        assert rendered_sua.startswith(rendered_su), "Chat templates may contain padding tokens, leading to incorrect prompt templates"
        # Everything after the system part, truncated before the final
        # assistant placeholder, forms the generic "prompt" template.
        combined = rendered_sua[len(rendered_s):]
        return {
            "system": rendered_s,
            "user": rendered_su[len(rendered_s):],
            "assistant": rendered_sua[len(rendered_su):],
            "prompt": combined[:combined.rfind('{Content}')],
        }
    except Exception as e:
        print(f"Failed to get prompt templates. Error: {e}")
        return None

def save_model(self, out_dir):
print(f"Saving ONNX model in {out_dir}")
Expand Down Expand Up @@ -3284,7 +3255,7 @@ def check_extra_options(kv_pairs):
"""
Check key-value pairs and set values correctly
"""
bools = ["int4_is_symmetric", "exclude_embeds", "exclude_lm_head", "include_hidden_states", "enable_cuda_graph", "use_8bits_moe", "use_qdq", "include_prompt_templates"]
bools = ["int4_is_symmetric", "exclude_embeds", "exclude_lm_head", "include_hidden_states", "enable_cuda_graph", "use_8bits_moe", "use_qdq"]
for key in bools:
if key in kv_pairs:
if kv_pairs[key] in {"false", "False", "0"}:
Expand Down Expand Up @@ -3550,8 +3521,6 @@ def get_args():
Use this option to enable GPUs that do not support FP16 on WebGPU (e.g. GTX 10xx).
adapter_path = Path to folder on disk containing the adapter files (adapter_config.json and adapter model weights).
Use this option for LoRA models.
include_prompt_templates = Include prompt templates in the GenAI config file. Default is false.
Use this option to include per-role prompt templates in the `genai_config.json` file.
"""),
)

Expand Down
Loading