Skip to content

Commit

Permalink
✨ prompt() do not need Argv
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Oct 27, 2024
1 parent 87f1bbe commit 292f4fb
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 56 deletions.
12 changes: 6 additions & 6 deletions devtool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections import namedtuple
from typing import Any, Literal

from arclet.alconna.ingedia._analyser import Analyser, default_compiler
from arclet.alconna.ingedia._analyser import Analyser
from arclet.alconna.ingedia._handlers import analyse_header as alh
from arclet.alconna.ingedia._handlers import analyse_args as ala
from arclet.alconna.ingedia._handlers import analyse_option as alo
Expand Down Expand Up @@ -61,7 +61,7 @@ def analyse_args(
def analyse_header(
headers: list[str | Any] | list[tuple[Any, str]],
command_name: str,
command: DataCollection[str | Any],
command: list[str | Any],
sep: str = " ",
compact: bool = False,
raise_exception: bool = True,
Expand All @@ -83,7 +83,7 @@ def analyse_header(

def analyse_option(
option: Option,
command: DataCollection[str | Any],
command: list[str | Any],
raise_exception: bool = True,
context_style: Literal["bracket", "parentheses"] | None = None,
**kwargs,
Expand All @@ -95,7 +95,7 @@ def analyse_option(
_analyser.command.separators = " "
_analyser.need_main_args = False
_analyser.command.options.append(option)
default_compiler(_analyser)
_analyser.compile()
argv.stack_params.base = _analyser.compile_params
_analyser.command.options.clear()
try:
Expand All @@ -111,7 +111,7 @@ def analyse_option(

def analyse_subcommand(
subcommand: Subcommand,
command: DataCollection[str | Any],
command: list[str | Any],
raise_exception: bool = True,
context_style: Literal["bracket", "parentheses"] | None = None,
**kwargs,
Expand All @@ -123,7 +123,7 @@ def analyse_subcommand(
_analyser.command.separators = " "
_analyser.need_main_args = False
_analyser.command.options.append(subcommand)
default_compiler(_analyser)
_analyser.compile()
argv.stack_params.base = _analyser.compile_params
_analyser.command.options.clear()
try:
Expand Down
14 changes: 6 additions & 8 deletions src/arclet/alconna/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from .args import Arg

if TYPE_CHECKING:
from .ingedia._argv import Argv
from .core import Alconna


Expand Down Expand Up @@ -243,12 +242,11 @@ def _prompt_none(command: Alconna, args_got: list[str], opts_got: list[str]):
return res


def prompt(command: Alconna, argv: Argv, args_got: list[str], opts_got: list[str], trigger: str | Arg | Subcommand | None = None):
def prompt(command: Alconna, buffer: list, args_got: list[str], opts_got: list[str], trigger: str | Arg | Subcommand | None = None):
"""获取补全列表"""
releases = argv.release(recover=True)
target = str(releases[-1])
if isinstance(releases[-1], str) and releases[-1] in command.config.builtin_option_name["completion"]:
target = str(releases[-2])
target = str(buffer[-1])
if isinstance(buffer[-1], str) and buffer[-1] in command.config.builtin_option_name["completion"]:
target = str(buffer[-2])
if isinstance(trigger, Arg):
if not (comp := trigger.field.get_completion()):
return [Prompt(command.formatter.param(trigger), False)]
Expand All @@ -257,10 +255,10 @@ def prompt(command: Alconna, argv: Argv, args_got: list[str], opts_got: list[str
o = list(filter(lambda x: target in x, comp)) or comp
return [Prompt(f"{trigger.name}: {i}", False, target) for i in o]
elif isinstance(trigger, Subcommand):
return [Prompt(i) for i in argv.stack_params.stack[-1]]
return [Prompt(i, True) for opt in trigger.options for i in opt.aliases if target in i]
if isinstance(trigger, str):
target = trigger
if _res := list(filter(lambda x: target in x, argv.stack_params.base)):
if _res := [x for opt in command.options for x in opt.aliases if target in x]:
out = [i for i in _res if i not in opts_got]
return [Prompt(i, True, target) for i in (out or _res)]
return _prompt_none(command, args_got, opts_got)
8 changes: 4 additions & 4 deletions src/arclet/alconna/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from nepattern import TPattern
from tarina import init_spec, lang, Empty

from .ingedia._analyser import Analyser, TCompile
from .ingedia._analyser import Analyser
from .ingedia._handlers import handle_head_fuzzy, analyse_header
from .ingedia._argv import Argv, __argv_type__
from .args import Arg, ArgsBuilder, ArgsBase, Args, ArgsMeta, handle_args
Expand Down Expand Up @@ -105,7 +105,7 @@ def _(command: Alconna, arp: Arparma):
trigger = arp.error_info.context_node
if res := prompt(
command,
argv,
argv.release(recover=True),
list(arp.main_args.keys()),
[*arp.options.keys(), *arp.subcommands.keys()],
trigger
Expand Down Expand Up @@ -208,14 +208,14 @@ class Alconna(Subcommand):
behaviors: list[ArparmaBehavior]
"""命令行为器"""

def compile(self, compiler: TCompile | None = None) -> Analyser:
def compile(self) -> Analyser:
"""编译 `Alconna` 为对应的解析器"""
if TYPE_CHECKING:
argv_type = Argv
else:
argv_type: type[Argv] = __argv_type__.get()
argv = argv_type(self.config, self.namespace_config, self.separators)
return Analyser(self, argv, compiler)
return Analyser(self, argv)

def __init__(
self,
Expand Down
61 changes: 26 additions & 35 deletions src/arclet/alconna/ingedia/_analyser.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable
from typing_extensions import Self, TypeAlias
from typing import TYPE_CHECKING, Any
from typing_extensions import Self

from tarina import Empty, lang

Expand Down Expand Up @@ -35,31 +35,6 @@
from ._argv import Argv


def default_compiler(analyser: SubAnalyser):
"""默认的编译方法
_Args:
analyser (SubAnalyser): 任意子解析器
"""
for opts in analyser.command.options:
if isinstance(opts, Option):
if opts.compact or opts.action.type == 2 or not set(analyser.command.separators).issuperset(opts.separators): # noqa: E501
analyser.compact_params.append(opts)
for alias in opts.aliases:
analyser.compile_params[alias] = opts
if opts.default is not Empty:
analyser.default_opt_result[opts.dest] = (opts.default, opts.action)
elif isinstance(opts, Subcommand):
sub = SubAnalyser(opts)
for alias in opts.aliases:
analyser.compile_params[alias] = sub
default_compiler(sub)
if not set(analyser.command.separators).issuperset(opts.separators):
analyser.compact_params.append(sub)
if sub.command.default is not Empty:
analyser.default_sub_result[opts.dest] = sub.command.default


@dataclass
class SubAnalyser:
"""子解析器, 用于子命令的解析"""
Expand Down Expand Up @@ -180,6 +155,26 @@ def process(self, argv: Argv, name_validated: bool = True) -> Self:
argv.stack_params.leave()
return self

def compile(self):
"""默认的编译方法"""
for opts in self.command.options:
if isinstance(opts, Option):
if opts.compact or opts.action.type == 2 or not set(self.command.separators).issuperset(opts.separators): # noqa: E501
self.compact_params.append(opts)
for alias in opts.aliases:
self.compile_params[alias] = opts
if opts.default is not Empty:
self.default_opt_result[opts.dest] = (opts.default, opts.action)
elif isinstance(opts, Subcommand):
sub = SubAnalyser(opts)
for alias in opts.aliases:
self.compile_params[alias] = sub
sub.compile()
if not set(self.command.separators).issuperset(opts.separators):
self.compact_params.append(sub)
if sub.command.default is not Empty:
self.default_sub_result[opts.dest] = sub.command.default


class Analyser(SubAnalyser):
"""命令解析器"""
Expand All @@ -189,18 +184,17 @@ class Analyser(SubAnalyser):
argv: Argv
"""命令行参数"""

def __init__(self, alconna: Alconna, argv: Argv, compiler: TCompile | None = None):
def __init__(self, alconna: Alconna, argv: Argv):
"""初始化解析器
_Args:
alconna (Alconna): 命令实例
argv (Argv): 命令行参数
compiler (TCompile | None, optional): 编译器方法
"""
super().__init__(alconna)
self.argv = argv
self.extra_allow = not self.command.config.strict
(compiler or default_compiler)(self)
self.compile()
self.argv.stack_params.base = self.compile_params

def __repr__(self):
Expand Down Expand Up @@ -232,7 +226,7 @@ def process(self, argv: Argv, name_validated: bool = True) -> Exception | None:
if isinstance(e1, InvalidParam):
argv.free(e1.context_node.separators if e1.context_node else None)
return PauseTriggered(
prompt(self.command, argv, [*self.args_result.keys()], [*self.options_result.keys(), *self.subcommands_result.keys()], e1.context_node),
prompt(self.command, argv.release(recover=True), [*self.args_result.keys()], [*self.options_result.keys(), *self.subcommands_result.keys()], e1.context_node),
e1,
argv
)
Expand Down Expand Up @@ -260,7 +254,7 @@ def process(self, argv: Argv, name_validated: bool = True) -> Exception | None:
)
if comp_ctx.get(None):
return PauseTriggered(
prompt(self.command, argv, [*self.args_result.keys()], [*self.options_result.keys(), *self.subcommands_result.keys()]),
prompt(self.command, argv.release(recover=True), [*self.args_result.keys()], [*self.options_result.keys(), *self.subcommands_result.keys()]),
exc,
argv
)
Expand Down Expand Up @@ -305,6 +299,3 @@ def export(
command_manager.record(argv.token, result)
self.reset()
return result # type: ignore


TCompile: TypeAlias = Callable[[SubAnalyser], None]
1 change: 1 addition & 0 deletions src/arclet/alconna/shortcut.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from .exceptions import ArgumentMissing, ParamsUnmatched


class _ShortcutRegWrapper(Protocol):
def __call__(self, slot: int | str, content: str | None, context: dict[str, Any]) -> Any: ...

Expand Down
6 changes: 3 additions & 3 deletions tests/devtool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections import namedtuple
from typing import Any, Literal

from arclet.alconna.ingedia._analyser import Analyser, default_compiler
from arclet.alconna.ingedia._analyser import Analyser
from arclet.alconna.ingedia._handlers import analyse_header as alh
from arclet.alconna.ingedia._handlers import analyse_args as ala
from arclet.alconna.ingedia._handlers import analyse_option as alo
Expand Down Expand Up @@ -95,7 +95,7 @@ def analyse_option(
_analyser.command.separators = " "
_analyser.need_main_args = False
_analyser.command.options.append(option)
default_compiler(_analyser)
_analyser.compile()
argv.stack_params.base = _analyser.compile_params
_analyser.command.options.clear()
try:
Expand Down Expand Up @@ -123,7 +123,7 @@ def analyse_subcommand(
_analyser.command.separators = " "
_analyser.need_main_args = False
_analyser.command.options.append(subcommand)
default_compiler(_analyser)
_analyser.compile()
argv.stack_params.base = _analyser.compile_params
_analyser.command.options.clear()
try:
Expand Down

0 comments on commit 292f4fb

Please sign in to comment.