Remove args from help that we don't wish to advertise yet (#912)
Jack-Khuu authored and malfet committed Jul 17, 2024
1 parent 0f68cca commit ee681bf
Showing 1 changed file with 120 additions and 91 deletions.
211 changes: 120 additions & 91 deletions cli.py
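The mechanism behind this change is argparse's SUPPRESS sentinel: passing help=argparse.SUPPRESS omits an argument from the generated --help text while leaving it fully parseable. A minimal sketch of the pattern (the flag names mirror the diff below; the prog string is illustrative):

    import argparse

    parser = argparse.ArgumentParser(prog="torchchat")  # prog name is illustrative

    # Advertised flag: appears in --help output.
    parser.add_argument(
        "--device",
        type=str,
        default="fast",
        help="Hardware device to use. Options: cpu, cuda, mps",
    )

    # WIP flag: hidden from --help, but still parsed and usable.
    parser.add_argument(
        "--distributed",
        action="store_true",
        help=argparse.SUPPRESS,
    )

    args = parser.parse_args(["--distributed"])
    print(args.distributed)  # True; suppression affects the help text only

Keeping the real help strings as adjacent comments, as the diff does below, makes re-advertising a flag a one-line change once the feature ships.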
@@ -4,6 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import json
import logging
import os
@@ -43,44 +44,46 @@ def check_args(args, verb: str) -> None:
download_and_convert(args.model, args.model_directory, args.hf_token)


# Given an arg parser and a subcommand (verb), add the appropriate arguments
# for that subcommand.
def add_arguments_for_verb(parser, verb: str) -> None:
# Model specification. TODO Simplify this.
# A model can be specified using a positional model name or HuggingFace
# path. Alternatively, the model can be specified via --gguf-path or via
# an explicit --checkpoint-dir, --checkpoint-path, or --tokenizer-path.

# Argument closure for inventory related subcommands
if verb in INVENTORY_VERBS:
_configure_artifact_inventory_args(parser, verb)
_add_cli_metadata_args(parser)
return

# Model specification
# A model can be specified using a positional model name or checkpoint path
parser.add_argument(
"model",
type=str,
nargs="?",
default=None,
help="Model name for well-known models",
)
parser.add_argument(
"--checkpoint-path",
type=Path,
default="not_specified",
help="Use the specified model checkpoint path",
)

# Add thematic argument groups based on the subcommand
if verb in ["browser", "chat", "generate"]:
_add_generation_args(parser)
if verb == "eval":
_add_evaluation_args(parser)

# Add argument groups for exported model path IO
_add_exported_input_path_args(parser)
_add_export_output_path_args(parser)

parser.add_argument(
"--distributed",
action="store_true",
help="Whether to enable distributed inference",
)
parser.add_argument(
"--is-chat-model",
action="store_true",
help="Indicate that the model was trained to support chat functionality",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Initialize torch seed",
)
parser.add_argument(
"--compile",
action="store_true",
@@ -91,52 +94,6 @@ def add_arguments_for_verb(parser, verb: str) -> None:
action="store_true",
help="Whether to compile the prefill. Improves prefill perf, but has higher compile times.",
)
parser.add_argument(
"--profile",
type=Path,
default=None,
help="Profile path.",
)
parser.add_argument(
"--draft-checkpoint-path",
type=Path,
default=None,
help="Use the specified draft checkpoint path",
)
parser.add_argument(
"--checkpoint-path",
type=Path,
default="not_specified",
help="Use the specified model checkpoint path",
)
parser.add_argument(
"--dcp-dir",
type=Path,
default=None,
help="Use the specified model checkpoint directory",
)
parser.add_argument(
"--params-path",
type=Path,
default=None,
help="Use the specified parameter file",
)
parser.add_argument(
"--gguf-path",
type=Path,
default=None,
help="Use the specified GGUF model file",
)
parser.add_argument(
"--tokenizer-path",
type=Path,
default=None,
help="Use the specified model tokenizer file",
)

_add_exported_model_input_args(parser)
_add_export_output_path_args(parser)

parser.add_argument(
"--dtype",
default="fast",
@@ -152,34 +109,13 @@ def add_arguments_for_verb(parser, verb: str) -> None:
+ "modes are: embedding, linear:int8, linear:int4, linear:a8w4dq, precision."
),
)
parser.add_argument(
"--draft-quantize",
type=str,
default="{ }",
help=(
"Quantization options. Same format as quantize, "
+ "or 'quantize' to indicate same options specified by "
+ "--quantize to main model. Applied to draft model."
),
)
parser.add_argument(
"--params-table",
type=str,
default=None,
choices=allowable_params_table(),
help="Parameter table to use",
)
parser.add_argument(
"--device",
type=str,
default=default_device,
choices=["fast", "cpu", "cuda", "mps"],
help="Hardware device to use. Options: cpu, cuda, mps",
)

if verb == "eval":
_add_evaluation_args(parser)

parser.add_argument(
"--hf-token",
type=str,
@@ -192,6 +128,12 @@ def add_arguments_for_verb(parser, verb: str) -> None:
default=default_model_dir,
help=f"The directory to store downloaded model artifacts. Default: {default_model_dir}",
)
parser.add_argument(
"--profile",
type=Path,
default=None,
help="Profile path.",
)
parser.add_argument(
"--port",
type=int,
@@ -200,6 +142,11 @@ def add_arguments_for_verb(parser, verb: str) -> None:
)
_add_cli_metadata_args(parser)

# WIP Features (suppressed from --help)
_add_distributed_args(parser)
_add_custom_model_args(parser)
_add_speculative_execution_args(parser)


# Add CLI Args representing user provided exported model files
def _add_export_output_path_args(parser) -> None:
@@ -219,7 +166,7 @@ def _add_export_output_path_args(parser) -> None:


# Add CLI Args representing user provided exported model files
def _add_exported_model_input_args(parser) -> None:
def _add_exported_input_path_args(parser) -> None:
exported_model_path_parser = parser.add_argument_group("Exported Model Path Args", "Specify the path of the exported model files to ingest")
exported_model_path_parser.add_argument(
"--dso-path",
@@ -235,14 +182,20 @@ def _add_exported_model_input_args(parser) -> None:
)


# Add CLI Args that are relevant to any subcommand execution
# Add CLI Args that are general to subcommand cli execution
def _add_cli_metadata_args(parser) -> None:
parser.add_argument(
"-v",
"--verbose",
action="store_true",
help="Verbose output",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Initialize torch seed",
)


# Configure CLI Args specific to Model Artifact Management
@@ -318,12 +271,6 @@ def _add_generation_args(parser) -> None:
action="store_true",
help="Whether to perform prefill sequentially. Only used for model debug.",
)
generator_parser.add_argument(
"--speculate-k",
type=int,
default=5,
help="Speculative execution depth",
)


# Add CLI Args specific to Model Evaluation
@@ -350,6 +297,88 @@ def _add_evaluation_args(parser) -> None:
)


# Add CLI Args related to distributed inference
# This feature is currently a [WIP] and hidden from --help
def _add_distributed_args(parser) -> None:
parser.add_argument(
"--distributed",
action="store_true",
help=argparse.SUPPRESS,
# "Whether to enable distributed inference",
)
parser.add_argument(
"--dcp-dir",
type=Path,
default=None,
help=argparse.SUPPRESS,
# "Use the specified model checkpoint directory",
)


# Add CLI Args related to custom model inputs (e.g. GGUF)
# This feature is currently a [WIP] and hidden from --help
def _add_custom_model_args(parser) -> None:
parser.add_argument(
"--params-table",
type=str,
default=None,
choices=allowable_params_table(),
help=argparse.SUPPRESS,
# "Parameter table to use",
)
parser.add_argument(
"--params-path",
type=Path,
default=None,
help=argparse.SUPPRESS,
# "Use the specified parameter file",
)
parser.add_argument(
"--gguf-path",
type=Path,
default=None,
help=argparse.SUPPRESS,
# "Use the specified GGUF model file",
)
parser.add_argument(
"--tokenizer-path",
type=Path,
default=None,
help=argparse.SUPPRESS,
# "Use the specified model tokenizer file",
)


# Add CLI Args related to speculative execution
# This feature is currently a [WIP] and hidden from --help
def _add_speculative_execution_args(parser) -> None:
parser.add_argument(
"--speculate-k",
type=int,
default=5,
help=argparse.SUPPRESS,
# "Speculative execution depth",
)
parser.add_argument(
"--draft-checkpoint-path",
type=Path,
default=None,
help=argparse.SUPPRESS,
# "Use the specified draft checkpoint path",
)
parser.add_argument(
"--draft-quantize",
type=str,
default="{ }",
help=argparse.SUPPRESS,
# (
# "Quantization options. Same format as quantize, "
# + "or 'quantize' to indicate same options specified by "
# + "--quantize to main model. Applied to draft model."
# ),
)


def arg_init(args):
if not (torch.__version__ > "2.3"):
raise RuntimeError(

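As a sanity check on the new behavior, a small smoke test sketch (hypothetical; it assumes cli.py and its dependencies import cleanly in your environment):

    import argparse
    from cli import add_arguments_for_verb  # assumes cli.py is on sys.path

    parser = argparse.ArgumentParser()
    add_arguments_for_verb(parser, "generate")

    # Suppressed WIP flags are absent from the rendered help...
    assert "--distributed" not in parser.format_help()

    # ...but they still parse and behave normally.
    args = parser.parse_args(["--distributed"])
    assert args.distributed is True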