chat mode improvements (pytorch#244)
* chat mode improvements

* disable int4 on macos/x86 because of old nightlies

* typo

* typo

* typo

* convert runtime error to warning

* wording of option texts
mikekgfb authored and malfet committed Jul 17, 2024
1 parent 947a371 commit 652bc3c
Showing 4 changed files with 45 additions and 11 deletions.
9 changes: 6 additions & 3 deletions .github/workflows/eager-dtype.yml
@@ -77,9 +77,12 @@ jobs:
echo "******************************************"
echo "******** INT4 group-wise quantized *******"
echo "******************************************"

python generate.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
cat ./output_eager

echo "INT4 should work on MacOS on x86, but cannot be tested"
echo "because nightlies are too old!"
# python generate.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
# cat ./output_eager

echo "tests complete for ${DTYPE}"
done
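The disabled INT4 test could also be gated by platform at runtime rather than commented out. A minimal Python sketch of such a guard (a hypothetical helper, not part of this commit; the skip reason is taken from the commit message):

import json
import platform

def int4_testable() -> bool:
    # Per the commit message, INT4 group-wise quantization cannot be
    # tested on macOS/x86 because the available nightlies are too old.
    return not (platform.system() == "Darwin" and platform.machine() == "x86_64")

quant = {"linear:int4": {"groupsize": 32}}
if int4_testable():
    print("run: python generate.py --quant", json.dumps(quant))
else:
    print("skipping INT4 eager test: nightlies too old on macOS/x86")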
22 changes: 21 additions & 1 deletion build/builder.py
@@ -35,7 +35,8 @@ class BuilderArgs:
precision: torch.dtype = torch.float32
setup_caches: bool = False
use_tp: bool = False

is_chat_model: bool = False

def __post_init__(self):
if not (
(self.checkpoint_path and self.checkpoint_path.is_file())
@@ -66,6 +67,24 @@ def __post_init__(self):

@classmethod
def from_args(cls, args): # -> BuilderArgs:
is_chat_model = False
if args.is_chat_model:
is_chat_model = True
else:
for path in [
args.checkpoint_path,
args.checkpoint_dir,
args.dso_path,
args.pte_path,
args.gguf_path
]:
path = str(path)
if path.endswith('/'):
path = path[:-1]
path_basename = os.path.basename(path)
if "chat" in path_basename:
is_chat_model = True

return cls(
checkpoint_path=args.checkpoint_path,
checkpoint_dir=args.checkpoint_dir,
@@ -78,6 +97,7 @@ def from_args(cls, args): # -> BuilderArgs:
precision=name_to_dtype(args.dtype),
setup_caches=(args.output_dso_path or args.output_pte_path),
use_tp=False,
is_chat_model=is_chat_model,
)

@classmethod
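The from_args change above infers chat capability from the model artifact paths whenever --is-chat-model is not passed. A standalone sketch of that heuristic (hypothetical helper name; the trailing-slash handling mirrors the diff):

import os

def infer_is_chat_model(*paths) -> bool:
    # A model is assumed to be chat-capable if the basename of any
    # artifact path contains "chat" (e.g. "llama-2-7b-chat"). A single
    # trailing slash is stripped first, as in the diff above.
    for path in paths:
        path = str(path)
        if path.endswith('/'):
            path = path[:-1]
        if "chat" in os.path.basename(path):
            return True
    return False

assert infer_is_chat_model("checkpoints/llama-2-7b-chat/") is True
assert infer_is_chat_model("checkpoints/llama-2-7b/model.pth") is False

Note that only the basename is inspected, so a "chat" component higher up in a path (e.g. a parent directory) does not by itself trigger detection.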
7 changes: 6 additions & 1 deletion cli.py
@@ -76,7 +76,12 @@ def _add_arguments_common(parser):
parser.add_argument(
"--chat",
action="store_true",
help="Use torchchat to for an interactive chat session.",
help="Use torchchat for an interactive chat session.",
)
parser.add_argument(
"--is-chat-model",
action="store_true",
help="Indicate that the model was trained to support chat functionality.",
)
parser.add_argument(
"--gui",
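The new --is-chat-model flag lets the user assert chat support when no artifact path contains "chat". A small standalone argparse sketch of the two options as they now read (illustrative only, not the project's full CLI):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--chat",
    action="store_true",
    help="Use torchchat for an interactive chat session.",
)
parser.add_argument(
    "--is-chat-model",
    action="store_true",
    help="Indicate that the model was trained to support chat functionality.",
)

# A model stored under a path without "chat" in it can still be treated
# as chat-capable by passing the flag explicitly:
args = parser.parse_args(["--chat", "--is-chat-model"])
assert args.chat and args.is_chat_model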
18 changes: 12 additions & 6 deletions generate.py
@@ -27,6 +27,7 @@
from cli import add_arguments_for_generate, arg_init, check_args
from quantize import set_precision

B_INST, E_INST = "[INST]", "[/INST]"

@dataclass
class GeneratorArgs:
@@ -343,11 +344,16 @@ def _main(
set_precision(builder_args.precision)
is_speculative = speculative_builder_args.checkpoint_path is not None

is_chat = "chat" in str(os.path.basename(builder_args.checkpoint_path))
if is_chat:
raise RuntimeError(
"need to stop filename based kludgery, at a minimum need to look at all pathnames. in particular, this now fails because chat is part of the pathname, yuck!"
)
if generator_args.chat_mode and not builder_args.is_chat_model:
print("""
*******************************************************
This model is not known to support the chat function.
We will enable chat mode based on your instructions.
If the model is not trained to support chat, it will
produce nonsensical or false output.
*******************************************************
""")
# raise RuntimeError("You need to use --is-chat-model to indicate model has chat support.")

tokenizer = _initialize_tokenizer(tokenizer_args)

@@ -410,7 +416,7 @@ def _main(
device_sync(device=builder_args.device)
if i >= 0 and generator_args.chat_mode:
prompt = input("What is your prompt? ")
if is_chat:
if builder_args.is_chat_model:
prompt = f"{B_INST} {prompt.strip()} {E_INST}"
encoded = encode_tokens(
tokenizer, prompt, bos=True, device=builder_args.device
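With chat mode active, each interactive prompt is wrapped in the instruction delimiters defined at the top of generate.py. A minimal sketch of that wrapping (the [INST]/[/INST] tokens come from the diff; labeling them Llama-2-style is an assumption):

B_INST, E_INST = "[INST]", "[/INST]"

def wrap_chat_prompt(prompt: str, is_chat_model: bool) -> str:
    # Chat-tuned models expect each instruction bracketed by the
    # delimiters; base models receive the raw prompt unchanged.
    if is_chat_model:
        return f"{B_INST} {prompt.strip()} {E_INST}"
    return prompt

print(wrap_chat_prompt("Tell me a joke.", True))
# -> [INST] Tell me a joke. [/INST]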
