Merge pull request #1467 from h2oai/c4ai-command-r-v01
CohereForAI/c4ai-command-r-v01
pseudotensor authored Mar 12, 2024
2 parents e430d5a + b1729c4 commit 6680f46
Showing 3 changed files with 66 additions and 25 deletions.
28 changes: 20 additions & 8 deletions src/gen.py
@@ -2267,10 +2267,17 @@ def get_config(base_model,
config.update({"max_seq_len": 2 * 8192})
if return_model and \
issubclass(config.__class__, tuple(AutoModel._model_mapping.keys())):
model = AutoModel.from_config(
config,
trust_remote_code=trust_remote_code,
)
try:
model = AutoModel.from_config(
config,
trust_remote_code=trust_remote_code,
)
except Exception as e:
if 'has no attribute' in str(e):
# half-baked hack to transformers by Cohere
model = None
else:
raise
else:
# can't infer
model = None
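
The hunk above wraps the config-only model build so that a remote-code AttributeError (as hit with Cohere's Command-R support at the time) falls back to model = None instead of aborting config inspection. A minimal standalone sketch of the same guard, assuming transformers is installed; the helper name model_from_config_or_none is illustrative, not part of the PR:

```python
from transformers import AutoConfig, AutoModel

def model_from_config_or_none(base_model, trust_remote_code=True):
    """Build an un-weighted model from config only, or return None when it cannot be inferred."""
    config = AutoConfig.from_pretrained(base_model, trust_remote_code=trust_remote_code)
    if not issubclass(config.__class__, tuple(AutoModel._model_mapping.keys())):
        # config class is not mapped to any AutoModel class, so can't infer a model
        return None
    try:
        return AutoModel.from_config(config, trust_remote_code=trust_remote_code)
    except Exception as e:
        if 'has no attribute' in str(e):
            # remote-code configs can fail here (the "half-baked hack" case); treat as unknown
            return None
        raise
```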
@@ -5700,12 +5707,17 @@ def get_limited_prompt(instruction,
min_max_new_tokens=min_max_new_tokens)

from openai_server.backend_utils import structure_to_messages
use_chat_template = (prompt_type in [None, '', 'plain'] and
hasattr(tokenizer, 'chat_template') and
tokenizer.chat_template)
use_chat_template = prompt_type in [None, '', 'plain'] and \
(hasattr(tokenizer, 'chat_template') and
tokenizer.chat_template not in [None, ''] or
hasattr(tokenizer, 'default_chat_template') and
tokenizer.default_chat_template not in [None, '']
)

if use_chat_template:
messages = structure_to_messages(instruction, system_prompt, history)
messages = structure_to_messages(instruction,
system_prompt if system_prompt not in [None, '', 'auto'] else None,
history)
context2 = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
iinput = ''
context = ''
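
The widened use_chat_template check now also accepts tokenizer.default_chat_template and drops 'auto' or empty system prompts before building messages. A minimal sketch of this path, assuming a Hugging Face tokenizer; the message builder below only stands in for openai_server.backend_utils.structure_to_messages and may not match its exact output:

```python
from transformers import AutoTokenizer

def render_chat_prompt(tokenizer, instruction, system_prompt=None, history=None):
    # build OpenAI-style messages, dropping empty/auto system prompts as in the diff
    messages = []
    if system_prompt not in [None, '', 'auto']:
        messages.append({"role": "system", "content": system_prompt})
    for human, bot in (history or []):
        if human:
            messages.append({"role": "user", "content": human})
        if bot:
            messages.append({"role": "assistant", "content": bot})
    messages.append({"role": "user", "content": instruction})
    # render with whichever chat template the tokenizer carries
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01", trust_remote_code=True)
    print(render_chat_prompt(tok, "Summarize the PR.", system_prompt="auto",
                             history=[("Hi", "Hello!")]))
```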
57 changes: 40 additions & 17 deletions src/gpt_langchain.py
@@ -60,7 +60,7 @@
get_list_or_str, have_pillow, only_selenium, only_playwright, only_unstructured_urls, get_short_name, \
get_accordion, have_jq, get_doc, get_source, have_chromamigdb, get_token_count, reverse_ucurve_list, get_size, \
get_test_name_core, download_simple, have_fiftyone, have_librosa, return_good_url, n_gpus_global, \
get_accordion_named, hyde_titles, have_cv2, FullSet, create_relative_symlink, split_list, get_gradio_tmp
get_accordion_named, hyde_titles, have_cv2, FullSet, create_relative_symlink, split_list, get_gradio_tmp, merge_dict
from enums import DocumentSubset, no_lora_str, model_token_mapping, source_prefix, source_postfix, non_query_commands, \
LangChainAction, LangChainMode, DocumentChoice, LangChainTypes, font_size, head_acc, super_source_prefix, \
super_source_postfix, langchain_modes_intrinsic, get_langchain_prompts, LangChainAgent, docs_joiner_default, \
@@ -1672,6 +1672,7 @@ async def agenerate_prompt(
prompt_messages, stop=stop, callbacks=callbacks, **kwargs
)


class H2OChatAnthropic2Sys(H2OChatAnthropic2):
pass

@@ -1685,7 +1686,6 @@ class H2OChatAnthropic3(GenerateStream, ExtraChat, ChatAnthropic3):
# max_new_tokens0: Any = None # FIXME: Doesn't seem to have same max_tokens == -1 for prompts==1



class H2OChatAnthropic3Sys(H2OChatAnthropic3):
pass

@@ -5982,7 +5982,8 @@ def split_merge_docs(docs_with_score, tokenizer=None, max_input_tokens=None, doc
# see if need to split
# account for joiner tokens
joiner_tokens = get_token_count(docs_joiner_default, tokenizer)
doc_chunk_size = max(64, min(max_input_tokens, max(64, max_input_tokens - joiner_tokens * len(docs_with_score))))
doc_chunk_size = max(64, min(max_input_tokens,
max(64, max_input_tokens - joiner_tokens * len(docs_with_score))))
text_splitter = H2OCharacterTextSplitter.from_huggingface_tokenizer(
tokenizer, chunk_size=doc_chunk_size, chunk_overlap=0
)
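
The reflowed doc_chunk_size line keeps the same arithmetic: reserve joiner tokens for each retrieved document, then clamp to the range [64, max_input_tokens]. A worked example with illustrative numbers:

```python
# illustrative numbers, not values from the PR
max_input_tokens = 4096
joiner_tokens = 6            # token count of docs_joiner_default under some tokenizer
num_docs = 20                # len(docs_with_score)
doc_chunk_size = max(64, min(max_input_tokens,
                             max(64, max_input_tokens - joiner_tokens * num_docs)))
assert doc_chunk_size == 3976   # 4096 - 6*20, clamped to the [64, 4096] range
```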
@@ -7035,7 +7036,7 @@ def get_chain(query=None,
estimated_prompt_no_docs = template_if_no_docs.format(context='', question=query)

# add metadata to documents and make new copy of docs with them to not contaminate originals
if metadata_in_context and not doc_json_mode:
if metadata_in_context and not doc_json_mode and not hasattr(tokenizer, 'apply_grounded_generation_template'):
docs_with_score = [(Document(page_content='Begin Document:\n\n' +
'Metadata:\n' +
'\n'.join(['%s = %s' % (k, v) for k, v in x.metadata.items() if
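
This hunk now skips the textual metadata prefix whenever the tokenizer exposes apply_grounded_generation_template, since Command-R receives metadata through structured documents instead. A rough sketch of the prefixing pattern used for other models; the exact header text and whitelisted keys are whatever h2oGPT uses, so treat the strings below as illustrative:

```python
from langchain.docstore.document import Document

def with_metadata_header(doc, metadata_in_context_set):
    # copy the document, prefixing a simple "key = value" metadata block to its text
    meta_lines = '\n'.join('%s = %s' % (k, v) for k, v in doc.metadata.items()
                           if v and k in metadata_in_context_set)
    page = 'Begin Document:\n\nMetadata:\n' + meta_lines + '\n\n' + doc.page_content
    return Document(page_content=page, metadata=doc.metadata)
```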
@@ -7216,11 +7217,8 @@
prompter=prompter)

if doc_json_mode:
def merge_dict(dict1, dict2):
return dict2.update(dict1)

# make copy so don't change originals
if metadata_in_context:
if metadata_in_context and not hasattr(tokenizer, 'apply_grounded_generation_template'):
docs = [Document(page_content=json.dumps(merge_dict(dict(ID=xi, content=x.page_content),
{k: v for k, v in x.metadata.items() if
v and k in metadata_in_context_set})),
@@ -7232,21 +7230,46 @@ def merge_dict(dict1, dict2):
for xi, x in enumerate(docs)]

if langchain_action == LangChainAction.QUERY.value:
if use_template:
# instruct-like, rather than few-shot prompt_type='plain' as default
# but then sources confuse the model with how inserted among rest of text, so avoid
if hasattr(tokenizer, 'apply_grounded_generation_template'):
assert prompt_type == 'plain'
# https://huggingface.co/CohereForAI/c4ai-command-r-v01
prompt = PromptTemplate(
# input_variables=["summaries", "question"],
input_variables=["context", "question"],
template=template,
template='{context}{question}', # ignored
)
chain = load_qa_chain(llm, prompt=prompt, verbose=verbose)
documents = [merge_dict(dict(text=x.page_content),
{k: v for k, v in x.metadata.items() if
v and k in metadata_in_context_set}) for x in docs]
from openai_server.backend_utils import structure_to_messages
conversation = structure_to_messages(query,
system_prompt if system_prompt not in [None, '', 'auto'] else None,
chat_conversation)
query_with_docs = tokenizer.apply_grounded_generation_template(
conversation,
documents=documents,
citation_mode="accurate", # or "fast"
tokenize=False,
add_generation_prompt=True,
)
chain_kwargs = dict(input_documents=[], question=query_with_docs)
else:
# unused normally except in testing
assert use_openai_model or prompt_type == 'plain', "Unexpected to use few-shot template for %s %s" % (
model_name, prompt_type)
chain = load_qa_with_sources_chain(llm)
chain_kwargs = dict(input_documents=docs, question=query)
if use_template:
# instruct-like, rather than few-shot prompt_type='plain' as default
# but then sources confuse the model with how inserted among rest of text, so avoid
prompt = PromptTemplate(
# input_variables=["summaries", "question"],
input_variables=["context", "question"],
template=template,
)
chain = load_qa_chain(llm, prompt=prompt, verbose=verbose)
else:
# unused normally except in testing
assert use_openai_model or prompt_type == 'plain', "Unexpected to use few-shot template for %s %s" % (
model_name, prompt_type)
chain = load_qa_with_sources_chain(llm)
chain_kwargs = dict(input_documents=docs, question=query)
target = wrapped_partial(chain, chain_kwargs)
elif summarize_action:
if async_output:
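
This is the heart of the PR: when the tokenizer has apply_grounded_generation_template (Command-R), the retrieved documents are handed to the tokenizer as structured records and the rendered prompt is sent as the whole question with an empty input_documents list. A standalone sketch against the CohereForAI/c4ai-command-r-v01 tokenizer; the conversation and document contents are made-up examples:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01",
                                          trust_remote_code=True)

conversation = [
    {"role": "user", "content": "How does h2oGPT ground answers in retrieved documents?"},
]
documents = [
    # one dict per retrieved chunk; extra keys become citable metadata
    {"title": "example source", "text": "Example retrieved chunk text."},
]

if hasattr(tokenizer, 'apply_grounded_generation_template'):
    query_with_docs = tokenizer.apply_grounded_generation_template(
        conversation,
        documents=documents,
        citation_mode="accurate",  # or "fast"
        tokenize=False,
        add_generation_prompt=True,
    )
    # mirrors chain_kwargs = dict(input_documents=[], question=query_with_docs) above
    print(query_with_docs)
```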
6 changes: 6 additions & 0 deletions src/utils.py
@@ -2027,3 +2027,9 @@ def get_lock_file(name):
lock_file = os.path.join(base_path, "%s.lock" % lock_type)
makedirs(os.path.dirname(lock_file)) # ensure made
return lock_file


def merge_dict(dict1, dict2):
ret = dict1.copy()
ret.update(dict2)
return ret
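
Unlike the inline helper removed from gpt_langchain.py, which returned the None result of dict.update, this merge_dict returns a fresh dict and leaves both inputs untouched, with dict2's values winning on key collisions. A small usage sketch with illustrative values:

```python
base = dict(ID=0, content='chunk text')
meta = {'source': 'a.pdf', 'page': 3}
combined = merge_dict(base, meta)
assert combined == {'ID': 0, 'content': 'chunk text', 'source': 'a.pdf', 'page': 3}
assert base == dict(ID=0, content='chunk text')   # original not mutated
```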
