diff --git a/vllm/entrypoints/cli/main.py b/vllm/entrypoints/cli/main.py index 2261ef233134..ac7f9e0a7e02 100644 --- a/vllm/entrypoints/cli/main.py +++ b/vllm/entrypoints/cli/main.py @@ -7,6 +7,7 @@ import importlib.metadata import sys +from importlib.util import find_spec from vllm.logger import init_logger @@ -34,47 +35,63 @@ def main(): cli_env_setup() - # For 'vllm bench *': use CPU instead of UnspecifiedPlatform by default - if len(sys.argv) > 1 and sys.argv[1] == "bench": - logger.debug( - "Bench command detected, must ensure current platform is not " - "UnspecifiedPlatform to avoid device type inference error" - ) - from vllm import platforms + # If `--omni` arg is passed to the CLI, delegate to vLLM Omni's entrypoint handling + if "--omni" in sys.argv: + # NOTE: Check the spec instead of importing directly here, since things could + # fail with ImportError due to mismatched versions if things are moved around. + spec = find_spec("vllm_omni") + if spec is None: + logger.error( + "--omni flag requires a valid instance of vllm-omni to be installed." + ) + sys.exit(1) - if platforms.current_platform.is_unspecified(): - from vllm.platforms.cpu import CpuPlatform + from vllm_omni.entrypoints.cli.main import main as omni_main - platforms.current_platform = CpuPlatform() - logger.info( - "Unspecified platform detected, switching to CPU Platform instead." + logger.info("Delegating entrypoint handling to vllm-omni") + omni_main() + else: + # For 'vllm bench *': use CPU instead of UnspecifiedPlatform by default + if len(sys.argv) > 1 and sys.argv[1] == "bench": + logger.debug( + "Bench command detected, must ensure current platform is not " + "UnspecifiedPlatform to avoid device type inference error" ) + from vllm import platforms - parser = FlexibleArgumentParser( - description="vLLM CLI", - epilog=VLLM_SUBCMD_PARSER_EPILOG.format(subcmd="[subcommand]"), - ) - parser.add_argument( - "-v", - "--version", - action="version", - version=importlib.metadata.version("vllm"), - ) - subparsers = parser.add_subparsers(required=False, dest="subparser") - cmds = {} - for cmd_module in CMD_MODULES: - new_cmds = cmd_module.cmd_init() - for cmd in new_cmds: - cmd.subparser_init(subparsers).set_defaults(dispatch_function=cmd.cmd) - cmds[cmd.name] = cmd - args = parser.parse_args() - if args.subparser in cmds: - cmds[args.subparser].validate(args) - - if hasattr(args, "dispatch_function"): - args.dispatch_function(args) - else: - parser.print_help() + if platforms.current_platform.is_unspecified(): + from vllm.platforms.cpu import CpuPlatform + + platforms.current_platform = CpuPlatform() + logger.info( + "Unspecified platform detected, switching to CPU Platform instead." + ) + + parser = FlexibleArgumentParser( + description="vLLM CLI", + epilog=VLLM_SUBCMD_PARSER_EPILOG.format(subcmd="[subcommand]"), + ) + parser.add_argument( + "-v", + "--version", + action="version", + version=importlib.metadata.version("vllm"), + ) + subparsers = parser.add_subparsers(required=False, dest="subparser") + cmds = {} + for cmd_module in CMD_MODULES: + new_cmds = cmd_module.cmd_init() + for cmd in new_cmds: + cmd.subparser_init(subparsers).set_defaults(dispatch_function=cmd.cmd) + cmds[cmd.name] = cmd + args = parser.parse_args() + if args.subparser in cmds: + cmds[args.subparser].validate(args) + + if hasattr(args, "dispatch_function"): + args.dispatch_function(args) + else: + parser.print_help() if __name__ == "__main__":