From 7e4c47ab36ec56d394547158df18d2a9f2eb0849 Mon Sep 17 00:00:00 2001 From: HsiangNianian Date: Sun, 11 Aug 2024 03:41:33 +0000 Subject: [PATCH] chore(docs): update api docs with sphinx-apidoc --- docs/source/pages/api/iamai.libcore.rst | 7 + docs/source/pages/api/iamai.rst | 1 + iamai/adapter/apscheduler/__init__.py | 139 +++ iamai/adapter/apscheduler/config.py | 20 + iamai/adapter/apscheduler/event.py | 37 + iamai/adapter/bililive/__init__.py | 432 ++++++++ iamai/adapter/bililive/api/blivedm.py | 437 ++++++++ iamai/adapter/bililive/config.py | 30 + iamai/adapter/bililive/event.py | 309 ++++++ iamai/adapter/bililive/exceptions.py | 13 + iamai/adapter/bililive/message.py | 39 + iamai/adapter/bililive/tests.py | 35 + iamai/adapter/bililive/utils/bilibili_api.py | 233 +++++ iamai/adapter/bililive/utils/bilibili_bot.py | 134 +++ iamai/adapter/bililive/utils/file_loader.py | 34 + iamai/adapter/bililive/utils/main.py | 68 ++ iamai/adapter/bililive/utils/plugin.py | 441 +++++++++ .../adapter/bililive/utils/plugins_loader.py | 34 + iamai/adapter/console/__init__.py | 37 + iamai/adapter/console/config.py | 12 + iamai/adapter/console/event.py | 67 ++ iamai/adapter/console/message.py | 26 + iamai/adapter/cqhttp/__init__.py | 294 ++++++ iamai/adapter/cqhttp/config.py | 30 + iamai/adapter/cqhttp/event.py | 455 +++++++++ iamai/adapter/cqhttp/exceptions.py | 43 + iamai/adapter/cqhttp/message.py | 280 ++++++ iamai/adapter/gensokyo/__init__.py | 314 ++++++ iamai/adapter/gensokyo/config.py | 33 + iamai/adapter/gensokyo/event.py | 455 +++++++++ iamai/adapter/gensokyo/exceptions.py | 43 + iamai/adapter/gensokyo/message.py | 280 ++++++ iamai/adapter/kook/__init__.py | 397 ++++++++ iamai/adapter/kook/_event.py | 930 ++++++++++++++++++ iamai/adapter/kook/api/__init__.py | 2 + iamai/adapter/kook/api/client.py | 1 + iamai/adapter/kook/api/client.pyi | 355 +++++++ iamai/adapter/kook/api/handle.py | 75 ++ iamai/adapter/kook/api/model.py | 444 +++++++++ iamai/adapter/kook/config.py | 27 + iamai/adapter/kook/event.py | 930 ++++++++++++++++++ iamai/adapter/kook/exceptions.py | 96 ++ iamai/adapter/kook/message.py | 339 +++++++ iamai/adapter/red/__init__.py | 205 ++++ iamai/adapter/red/config.py | 38 + iamai/adapter/red/event.py | 436 ++++++++ 46 files changed, 9087 insertions(+) create mode 100644 docs/source/pages/api/iamai.libcore.rst create mode 100644 iamai/adapter/apscheduler/__init__.py create mode 100644 iamai/adapter/apscheduler/config.py create mode 100644 iamai/adapter/apscheduler/event.py create mode 100644 iamai/adapter/bililive/__init__.py create mode 100644 iamai/adapter/bililive/api/blivedm.py create mode 100644 iamai/adapter/bililive/config.py create mode 100644 iamai/adapter/bililive/event.py create mode 100644 iamai/adapter/bililive/exceptions.py create mode 100644 iamai/adapter/bililive/message.py create mode 100644 iamai/adapter/bililive/tests.py create mode 100644 iamai/adapter/bililive/utils/bilibili_api.py create mode 100644 iamai/adapter/bililive/utils/bilibili_bot.py create mode 100644 iamai/adapter/bililive/utils/file_loader.py create mode 100644 iamai/adapter/bililive/utils/main.py create mode 100644 iamai/adapter/bililive/utils/plugin.py create mode 100644 iamai/adapter/bililive/utils/plugins_loader.py create mode 100644 iamai/adapter/console/__init__.py create mode 100644 iamai/adapter/console/config.py create mode 100644 iamai/adapter/console/event.py create mode 100644 iamai/adapter/console/message.py create mode 100644 iamai/adapter/cqhttp/__init__.py create mode 100644 iamai/adapter/cqhttp/config.py create mode 100644 iamai/adapter/cqhttp/event.py create mode 100644 iamai/adapter/cqhttp/exceptions.py create mode 100644 iamai/adapter/cqhttp/message.py create mode 100644 iamai/adapter/gensokyo/__init__.py create mode 100644 iamai/adapter/gensokyo/config.py create mode 100644 iamai/adapter/gensokyo/event.py create mode 100644 iamai/adapter/gensokyo/exceptions.py create mode 100644 iamai/adapter/gensokyo/message.py create mode 100644 iamai/adapter/kook/__init__.py create mode 100644 iamai/adapter/kook/_event.py create mode 100644 iamai/adapter/kook/api/__init__.py create mode 100644 iamai/adapter/kook/api/client.py create mode 100644 iamai/adapter/kook/api/client.pyi create mode 100644 iamai/adapter/kook/api/handle.py create mode 100644 iamai/adapter/kook/api/model.py create mode 100644 iamai/adapter/kook/config.py create mode 100644 iamai/adapter/kook/event.py create mode 100644 iamai/adapter/kook/exceptions.py create mode 100644 iamai/adapter/kook/message.py create mode 100644 iamai/adapter/red/__init__.py create mode 100644 iamai/adapter/red/config.py create mode 100644 iamai/adapter/red/event.py diff --git a/docs/source/pages/api/iamai.libcore.rst b/docs/source/pages/api/iamai.libcore.rst new file mode 100644 index 00000000..9bda5271 --- /dev/null +++ b/docs/source/pages/api/iamai.libcore.rst @@ -0,0 +1,7 @@ +iamai.libcore module +==================== + +.. automodule:: iamai.libcore + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/pages/api/iamai.rst b/docs/source/pages/api/iamai.rst index 166379cb..2db59688 100644 --- a/docs/source/pages/api/iamai.rst +++ b/docs/source/pages/api/iamai.rst @@ -23,6 +23,7 @@ Submodules iamai.dependencies iamai.event iamai.exceptions + iamai.libcore iamai.log iamai.message iamai.plugin diff --git a/iamai/adapter/apscheduler/__init__.py b/iamai/adapter/apscheduler/__init__.py new file mode 100644 index 00000000..bfde3133 --- /dev/null +++ b/iamai/adapter/apscheduler/__init__.py @@ -0,0 +1,139 @@ +"""APScheduler 适配器。 + +本适配器用于实现定时任务,适配器将使用 APScheduler 实现定时任务,在设定的时间产生一个事件供插件处理。 +APScheduler 使用方法请参考:[APScheduler](https://apscheduler.readthedocs.io/)。 +""" + +import inspect +from functools import wraps +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Type, Union + +from apscheduler.job import Job +from apscheduler.schedulers.asyncio import AsyncIOScheduler + +from iamai.adapter import Adapter +from iamai.log import logger +from iamai.plugin import Plugin +from iamai.typing import PluginT + +from .config import Config +from .event import APSchedulerEvent + +if TYPE_CHECKING: + from apscheduler.triggers.base import BaseTrigger + +__all__ = ["APSchedulerAdapter", "scheduler_decorator"] + + +class APSchedulerAdapter(Adapter[APSchedulerEvent, Config]): + """APScheduler 适配器。""" + + name: str = "apscheduler" + Config = Config + + scheduler: AsyncIOScheduler + plugin_class_to_job: Dict[Type[Plugin[Any, Any, Any]], Job] + + async def startup(self) -> None: + """创建 `AsyncIOScheduler` 对象。""" + self.scheduler = AsyncIOScheduler(self.config.scheduler_config) + self.plugin_class_to_job = {} + + async def run(self) -> None: + """启动调度器。""" + for plugin in self.bot.plugins: + if not hasattr(plugin, "__schedule__"): + continue + + if not hasattr(plugin, "trigger") or not hasattr(plugin, "trigger_args"): + logger.error( + f"Plugin {plugin.__name__} __schedule__ is True, " + f"but did not set trigger or trigger_args" + ) + continue + + trigger: Union[str, BaseTrigger] = getattr(plugin, "trigger") # noqa: B009 + trigger_args: Dict[str, Any] = getattr(plugin, "trigger_args") # noqa: B009 + + if not isinstance(trigger, str) or not isinstance(trigger_args, dict): + logger.error( + f"Plugin {plugin.__name__} trigger or trigger_args type error" + ) + continue + + try: + self.plugin_class_to_job[plugin] = self.scheduler.add_job( + self.create_event, args=(plugin,), trigger=trigger, **trigger_args + ) + except Exception as e: + self.bot.error_or_exception( + f"Plugin {plugin.__name__} add_job filed, " + "please check trigger and trigger_args:", + e, + ) + else: + logger.info(f"Plugin {plugin.__name__} has been scheduled to run") + + self.scheduler.start() + + async def shutdown(self) -> None: + """关闭调度器。""" + self.scheduler.shutdown() + + async def create_event(self, plugin_class: Type[Plugin[Any, Any, Any]]) -> None: + """创建 `APSchedulerEvent` 事件。 + + Args: + plugin_class: `Plugin` 类。 + """ + logger.info(f"APSchedulerEvent set by {plugin_class} is created as scheduled") + await self.handle_event( + APSchedulerEvent(adapter=self, plugin_class=plugin_class), + handle_get=False, + show_log=False, + ) + + async def send(self, *args: Any, **kwargs: Any) -> Any: + """APScheduler 适配器不适用发送消息。""" + raise NotImplementedError + + +def scheduler_decorator( + trigger: str, trigger_args: Dict[str, Any], override_rule: bool = False +) -> Callable[[Type[PluginT]], Type[PluginT]]: + """用于为插件类添加计划任务功能的装饰器。 + + Args: + trigger: APScheduler 触发器。 + trigger_args: APScheduler 触发器参数。 + override_rule: 是否重写 `rule()` 方法。 + 若为 `True`,则会在 `rule()` 方法中添加处理本插件定义的计划任务事件的逻辑。 + """ + + def _decorator(cls: Type[PluginT]) -> Type[PluginT]: + if not inspect.isclass(cls): + raise TypeError("can only decorate class") + if not issubclass(cls, Plugin): + raise TypeError("can only decorate Plugin class") + setattr(cls, "__schedule__", True) # noqa: B010 + setattr(cls, "trigger", trigger) # noqa: B010 + setattr(cls, "trigger_args", trigger_args) # noqa: B010 + if override_rule: + + def _rule_decorator(func: Callable[[PluginT], Awaitable[bool]]) -> Any: + @wraps(func) + async def _wrapper(self: PluginT) -> bool: + if ( + self.event.type == "apscheduler" + # pylint: disable-next=unidiomatic-typecheck + and type(self) is self.event.plugin_class + ): + return True + return await func(self) + + return _wrapper + + cls.rule = _rule_decorator(cls.rule) # type: ignore + return cls # type: ignore + + return _decorator diff --git a/iamai/adapter/apscheduler/config.py b/iamai/adapter/apscheduler/config.py new file mode 100644 index 00000000..e5bd1c86 --- /dev/null +++ b/iamai/adapter/apscheduler/config.py @@ -0,0 +1,20 @@ +"""APScheduler 适配器配置。""" + +from typing import Any, Dict + +from pydantic import Field + +from iamai.config import ConfigModel + +__all__ = ["Config"] + + +class Config(ConfigModel): + """APScheduler 配置类,将在适配器被加载时被混入到机器人主配置中。 + + Attributes: + scheduler_config: 调度器配置。 + """ + + __config_name__ = "apscheduler" + scheduler_config: Dict[str, Any] = Field(default_factory=dict) diff --git a/iamai/adapter/apscheduler/event.py b/iamai/adapter/apscheduler/event.py new file mode 100644 index 00000000..412eb5c4 --- /dev/null +++ b/iamai/adapter/apscheduler/event.py @@ -0,0 +1,37 @@ +"""APScheduler 适配器事件。""" + +from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union + +from apscheduler.job import Job +from apscheduler.triggers.base import BaseTrigger + +from iamai.event import Event +from iamai.plugin import Plugin + +if TYPE_CHECKING: + from . import APSchedulerAdapter + + +__all__ = ["APSchedulerEvent"] + + +class APSchedulerEvent(Event["APSchedulerAdapter"]): + """APSchedulerEvent 事件基类。""" + + type: Optional[str] = "apscheduler" + plugin_class: Type[Plugin] # type: ignore + + @property + def job(self) -> Job: + """产生当前事件的 APScheduler `Job` 对象。""" + return self.adapter.plugin_class_to_job[self.plugin_class] + + @property + def trigger(self) -> Union[str, BaseTrigger]: + """当前事件对应的 Plugin 的 `trigger`。""" + return getattr(self.plugin_class, "trigger") # noqa: B009 + + @property + def trigger_args(self) -> Dict[str, Any]: + """当前事件对应的 Plugin 的 `trigger_args`。""" + return getattr(self.plugin_class, "trigger_args") # noqa: B009 diff --git a/iamai/adapter/bililive/__init__.py b/iamai/adapter/bililive/__init__.py new file mode 100644 index 00000000..4d22e7e1 --- /dev/null +++ b/iamai/adapter/bililive/__init__.py @@ -0,0 +1,432 @@ +"""bililive 协议适配器。 + +本适配器适配了 bililive 协议。 +协议详情请参考: [xfgryujk/blivedm](https://github.com/xfgryujk/blivedm) 。 + +TODO: + - [x] 扫码登录 + - [x] 本地缓存cookie登录 + - [ ] onebot 适配 + - [ ] api +""" + +import os +import re +import sys +import json +import time +import zlib +import struct +import asyncio +from math import log +from functools import partial +from abc import abstractmethod +from collections import namedtuple +from os.path import join, split, abspath, dirname +from typing import TYPE_CHECKING, Any, Dict, NamedTuple + +import qrcode +import aiohttp +from genericpath import exists +from aiohttp.client import ClientSession + +from iamai.utils import DataclassEncoder +from iamai.adapter.utils import WebSocketAdapter +from iamai.log import logger, error_or_exception + +from .event import * +from .message import * +from .config import Config +from .event import get_event_class + +if TYPE_CHECKING: + from .message import T_BililiveMSG + +__all__ = ["BililiveAdapter"] + +ROOM_INIT_URL = "https://api.live.bilibili.com/xlive/web-room/v1/index/getInfoByRoom" +DANMAKU_SERVER_CONF_URL = ( + "https://api.live.bilibili.com/xlive/web-room/v1/index/getDanmuInfo" +) +DEFAULT_DANMAKU_SERVER_LIST = [ + { + "host": "broadcastlv.chat.bilibili.com", + "port": 2243, + "wss_port": 443, + "ws_port": 2244, + } +] +QRCODE_REQUEST_URL = "http://passport.bilibili.com/qrcode/getLoginUrl" +CHECK_LOGIN_RESULT = "http://passport.bilibili.com/qrcode/getLoginInfo" +SEND_URL = "https://api.live.bilibili.com/msg/send" +MUTE_USER_URL = ( + "https://api.live.bilibili.com/xlive/web-ucenter/v1/banned/AddSilentUser" +) +ROOM_SLIENT_URL = "https://api.live.bilibili.com/xlive/web-room/v1/banned/RoomSilent" +ADD_BADWORD_URL = ( + "https://api.live.bilibili.com/xlive/web-ucenter/v1/banned/AddShieldKeyword" +) +DEL_BADWORD_URL = ( + "https://api.live.bilibili.com/xlive/web-ucenter/v1/banned/DelShieldKeyword" +) +HEADER_STRUCT = struct.Struct(">I2H2I") +HeaderTuple = namedtuple( + "HeaderTuple", ("pack_len", "raw_header_size", "ver", "operation", "seq_id") +) +WS_BODY_PROTOCOL_VERSION_INFLATE = 0 +WS_BODY_PROTOCOL_VERSION_NORMAL = 1 +WS_BODY_PROTOCOL_VERSION_DEFLATE = 2 + +user_cookies = aiohttp.cookiejar.CookieJar() + + +class BililiveAdapter(WebSocketAdapter[BililiveEvent, Config]): + """bililive 协议适配器。""" + + name: str = "bililive" + Config = Config + _gateway_response = {} # type: ignore + _host_server_list = DEFAULT_DANMAKU_SERVER_LIST + _api_response: Dict[Any, Any] + _api_response_cond: asyncio.Condition = None # type: ignore + _api_id: int = 0 + _heartbeat_interval = 30 + _retry_count = 0 + + def __getattr__(self, item): # type: ignore + return partial(self.call_api, item) + + async def startup(self): + """初始化适配器。""" + self.adapter_type = self.config.adapter_type # type: ignore + if self.adapter_type == "websocket": # type: ignore + self.adapter_type = "ws" # type: ignore + self.reconnect_interval = self.config.reconnect_interval # type: ignore + self.room_id = self.config.room_id # type: ignore + self.session_data_path = self.config.session_data_path # type: ignore + self._api_response_cond = asyncio.Condition() + self.jct: str = "" + self.cookies = {} + _path = f"{dirname(abspath(sys.argv[0]))}/{self.session_data_path}" + if not os.path.exists(_path): + os.mkdir(dirname(_path)) + if exists(_path): + with open(_path) as f: + self.cookies = json.load(f) + user_cookies.update_cookies(self.cookies) + if self.config.login: # type: ignore + logger.debug(f"Login enabled!") + try: + # 尝试登陆 + async with ClientSession(cookie_jar=user_cookies) as self.session: + success = await login(self.session) + + if success: + self._uid = get_cookies("DedeUserID") + self.jct = get_cookies("bili_jct") + + if self._uid == None or self.jct == None: + logger.error( + f"Unable to get cookies, please check your cookies." + ) + return + if not exists(_path): + for cookie in user_cookies: + self.cookies[cookie.key] = cookie.value + + logger.debug(f"Stored cookies: {self.cookies}") + with open(_path, mode="w") as f: + json.dump(self.cookies, f) + + await super().startup() + except Exception as e: + logger.error(e) + return + else: + logger.debug(f"Login disabled!") + await super().startup() + + async def websocket_connect(self): + """创建正向 WebSocket 连接。""" + + logger.info("Trying to connect to WebSocket server...") + host_server = self._host_server_list[ + self._retry_count % len(self._host_server_list) + ] + try: + async with self.session.ws_connect( + f'wss://{host_server["host"]}:{host_server["wss_port"]}/sub', + receive_timeout=self._heartbeat_interval + 5, + ) as self.websocket: + await self._send_auth() + self._heartbeat_timer_handle = asyncio.ensure_future( + self._start_heartbeat() + ) + logger.success(f"Success to be invited to room {self.room_id}.") + await self.handle_websocket() + except Exception as e: + logger.error(e) + self._retry_count += 1 + await asyncio.sleep(self.reconnect_interval) + await self.websocket_connect() + + async def handle_websocket_msg(self, msg: aiohttp.WSMessage): + """处理 WebSocket 消息。""" + logger.info(msg) + if msg.type == aiohttp.WSMsgType.BINARY: + try: + data = msg.data # await self.websocket.receive_bytes() + logger.info(data) + offset = 0 + while offset < len(data): + try: + header = HeaderTuple(*HEADER_STRUCT.unpack_from(data, offset)) + except struct.error: + break + if header.operation == Operation.HEARTBEAT_REPLY: + popularity = int.from_bytes( + data[ + offset + HEADER_STRUCT.size : offset + + HEADER_STRUCT.size + + 4 + ], + "big", + ) + await self._on_receive_popularity(popularity) + elif header.operation == Operation.SEND_MSG_REPLY: + body = data[ + offset + HEADER_STRUCT.size : offset + header.pack_len + ] + if header.ver == WS_BODY_PROTOCOL_VERSION_DEFLATE: + self._loop = asyncio.get_event_loop() + body = await self._loop.run_in_executor( + None, zlib.decompress, body + ) + # await self.handle_websocket_msg(body) + return + else: + try: + body = json.loads(body.decode("utf-8")) + data = body + logger.info(data) + data["post_type"] = data["cmd"].lower().split("_")[0] + data["message"] = data.get("msg_common") or "" + data["message_id"] = data.get("msg_id") or 0 + data["group_id"] = data.get("roomid") or 0 + data["time"] = data.get("send_time") or 0 + await self.handle_bililive_event(data) # type: ignore + except Exception: + logger.debug(f"body: {body}") + raise + + elif header.operation == Operation.AUTH_REPLY: + await self.websocket.send_bytes( + self._make_packet({}, Operation.HEARTBEAT) + ) + + else: + body = data[ + offset + HEADER_STRUCT.size : offset + header.pack_len + ] + logger.warning( + f"room {self.room_id,} 未知包类型:operation={header.operation, header, body}" + ) + + offset += header.pack_len + except Exception as e: + error_or_exception( + "WebSocket message parsing error, not BINARY:", + e, + self.bot.config.bot.log.verbose_exception, + ) + async with self._api_response_cond: + self._api_response = msg.data + logger.warning(msg.data) + self._api_response_cond.notify_all() + elif msg.type == aiohttp.WSMsgType.ERROR: + logger.error( + f"WebSocket connection closed " + f"with exception {self.websocket.exception()!r}" + ) + + async def handle_bililive_event(self, data: Dict[str, Any]): + logger.info(str(data)) + post_type = data.get("post_type") + event_type = data.get(f"{post_type}_type") + sub_type = data.get("sub_type", None) + + event_class = get_event_class(post_type, event_type, sub_type) + bililive_event = event_class(adapter=self, **data) + + await self.handle_event(bililive_event) + + # 发送登录包 + async def _send_auth(self): + auth_params = { + "uid": self._uid or 0, # 0: 游客 + "roomid": self.room_id, + "protover": 2, + "platform": "web", + "clientver": "1.14.3", + "type": 2, + } + await self.websocket.send_bytes(self._make_packet(auth_params, Operation.AUTH)) + + @staticmethod + def _make_packet(data, operation): + body = json.dumps(data).encode("utf-8") + header = HEADER_STRUCT.pack( + HEADER_STRUCT.size + len(body), HEADER_STRUCT.size, 1, operation, 1 + ) + return header + body + + async def _start_heartbeat(self) -> None: + """ + 每30s一次心跳 + :return: + """ + hb = "0000001f0010000100000002000000015b6f626a656374204f626a6563745d" + try: + while not self.bot.should_exit.is_set(): + if self.websocket.closed: + break + await self.websocket.send_bytes(bytes.fromhex(hb)) + logger.debug(f"HeartBeat sent!") + await asyncio.sleep(29) + except Exception as e: + logger.error(e) + + async def call_api(self, api: str, **params): + """调用 bililive API。 + + TODO: 因为基于OlivOS的那个OlivaBiliLive插件架构其实相当于一个小框架的缘故, + 所以要改的东西太多了,这里插个保留的接口... + """ + + ... + + async def send_danmu(self, **fields) -> bool: + token = get_cookies("bili_jct") + async with ClientSession(cookie_jar=user_cookies) as session: + try: + res = await _post( + session, + SEND_URL, + rnd=time.time(), + csrf=token, + csrf_token=token, + **fields, + ) + return "data" in res + except Exception as e: + logger.warning(f"Send danmu failed: {e}") + return False + + async def send( + self, + danmaku: str, + fontsize: int = 25, + color: int = 0xFFFFFF, + pos: DanmakuPosition = DanmakuPosition.NORMAL, + ) -> bool: + # don't know what the hell is bubble + return await self.send_danmu( + msg=danmaku, + fontsize=fontsize, + color=color, + pos=pos, + roomid=self.room_id, + bubble=0, + ) + + @abstractmethod + async def _on_receive_popularity(self, popularity: int): + pass + + +def rawData_to_jsonData(data: bytes): + packetLen = int(data[:4].hex(), 16) + ver = int(data[6:8].hex(), 16) + op = int(data[8:12].hex(), 16) + + if len(data) > packetLen: # 防止 + rawData_to_jsonData(data[packetLen:]) + data = data[:packetLen] + + if ver == 2: + data = zlib.decompress(data[16:]) + return rawData_to_jsonData(data) + + if op == 5: + try: + jd = json.loads(data[16:].decode("utf-8", errors="ignore")) + return jd + except Exception as e: + pass + + +async def login(session: ClientSession) -> bool: + if get_cookies("bili_jct") != None: + logger.info(f"Aleady login!") + return True + try: + res = await _get(session, QRCODE_REQUEST_URL) + ts = res["ts"] + outdated = ts + 180 * 1000 # 180 秒後逾時 + authKey = res["data"]["oauthKey"] + url = res["data"]["url"] + qr = qrcode.QRCode() + logger.info("请扫描下面的二维码进行登录... (或者到目录下寻找 qrcode.png)") + qr.add_data(url) + qr.print_ascii(invert=True) + qr.make_image().save("qrcode.png") + while True: + await asyncio.sleep(5) + if time.time() > outdated: + logger.warning("Timeout!") + return False # 登入失敗 + res = await _post(session, CHECK_LOGIN_RESULT, oauthKey=authKey) + if res["status"]: + logger.success("login success!") + return True + else: + code = res["data"] + if code in [-1, -2]: + logger.warning(f'login failed: {res["message"]}') + return False + except Exception as e: + logger.warning(f"Something went wrong: {e}") + return False + finally: + os.remove("qrcode.png") + + +def get_cookies(name: str) -> any: # type: ignore + for cookie in user_cookies: + if cookie.key == name: + return cookie.value + return None + + +async def _get(session: ClientSession, url: str): + async with session.get(url) as resp: + resp.raise_for_status() + data = await resp.json() + logger.debug(data) + if "code" in data and data["code"] != 0: + raise Exception(data["message"] if "message" in data else data["code"]) + return data + + +async def _post(session: ClientSession, url: str, **data): + form = aiohttp.FormData() + for k, v in data.items(): + form.add_field(k, v) + logger.debug(f"Sending POST: {url}, content: {data}") + async with session.post(url, data=form) as resp: + resp.raise_for_status() + data = await resp.json() + logger.debug(data) + if "code" in data and data["code"] != 0: + raise Exception(data["message"] if "message" in data else data["code"]) + return data diff --git a/iamai/adapter/bililive/api/blivedm.py b/iamai/adapter/bililive/api/blivedm.py new file mode 100644 index 00000000..7da42301 --- /dev/null +++ b/iamai/adapter/bililive/api/blivedm.py @@ -0,0 +1,437 @@ +# -*- coding: utf-8 -*- + +__all__ = ["BLiveClient"] + +import json +import zlib +import struct +import asyncio +import logging +import ssl as ssl_ +from enum import IntEnum + +# code from xfgryujk +from abc import abstractmethod +from collections import namedtuple +from typing import * # ??这是什么粗暴的import方式 + +import aiohttp + +logger = logging.getLogger(__name__) + + +ROOM_INIT_URL = "https://api.live.bilibili.com/xlive/web-room/v1/index/getInfoByRoom" +DANMAKU_SERVER_CONF_URL = ( + "https://api.live.bilibili.com/xlive/web-room/v1/index/getDanmuInfo" +) +DEFAULT_DANMAKU_SERVER_LIST = [ + { + "host": "broadcastlv.chat.bilibili.com", + "port": 2243, + "wss_port": 443, + "ws_port": 2244, + } +] + +HEADER_STRUCT = struct.Struct(">I2H2I") +HeaderTuple = namedtuple( + "HeaderTuple", ("pack_len", "raw_header_size", "ver", "operation", "seq_id") +) +WS_BODY_PROTOCOL_VERSION_INFLATE = 0 +WS_BODY_PROTOCOL_VERSION_NORMAL = 1 +WS_BODY_PROTOCOL_VERSION_DEFLATE = 2 + + +# go-common\app\service\main\broadcast\model\operation.go +class Operation(IntEnum): + HANDSHAKE = 0 + HANDSHAKE_REPLY = 1 + HEARTBEAT = 2 + HEARTBEAT_REPLY = 3 + SEND_MSG = 4 + SEND_MSG_REPLY = 5 + DISCONNECT_REPLY = 6 + AUTH = 7 + AUTH_REPLY = 8 + RAW = 9 + PROTO_READY = 10 + PROTO_FINISH = 11 + CHANGE_ROOM = 12 + CHANGE_ROOM_REPLY = 13 + REGISTER = 14 + REGISTER_REPLY = 15 + UNREGISTER = 16 + UNREGISTER_REPLY = 17 + # B站业务自定义OP + # MinBusinessOp = 1000 + # MaxBusinessOp = 10000 + + +class InitError(Exception): + """初始化失败""" + + +class BLiveClient: + def __init__( + self, + room_id, + uid=0, + session: aiohttp.ClientSession = None, + heartbeat_interval=30, + ssl=True, + loop=None, + ): + """ + :param room_id: URL中的房间ID,可以为短ID + :param uid: B站用户ID,0表示未登录 + :param session: cookie、连接池 + :param heartbeat_interval: 发送心跳包的间隔时间(秒) + :param ssl: True表示用默认的SSLContext验证,False表示不验证,也可以传入SSLContext + :param loop: 协程事件循环 + """ + # 用来init_room的临时房间ID + self._tmp_room_id = room_id + # 调用init_room后初始化 + self._room_id = self._room_short_id = self._room_owner_uid = None + # [{host: "tx-bj4-live-comet-04.chat.bilibili.com", port: 2243, wss_port: 443, ws_port: 2244}, ...] + self._host_server_list = None + self._host_server_token = None + self._uid = uid + + if loop is not None: + self._loop = loop + elif session is not None: + # noinspection PyDeprecation + self._loop = session.loop + else: + self._loop = asyncio.get_event_loop() + self._future = None + + if session is None: + self._session = aiohttp.ClientSession( + loop=self._loop, timeout=aiohttp.ClientTimeout(total=10) + ) + self._own_session = True + else: + self._session = session + self._own_session = False + # noinspection PyDeprecation + if self._session.loop is not self._loop: + raise RuntimeError("BLiveClient and session has to use same event loop") + + self._heartbeat_interval = heartbeat_interval + # noinspection PyProtectedMember + self._ssl = ssl if ssl else ssl_._create_unverified_context() + self._websocket = None + self._heartbeat_timer_handle = None + + @property + def is_running(self): + return self._future is not None + + @property + def room_id(self): + """ + 房间ID,调用init_room后初始化 + """ + return self._room_id + + @property + def room_short_id(self): + """ + 房间短ID,没有则为0,调用init_room后初始化 + """ + return self._room_short_id + + @property + def room_owner_uid(self): + """ + 主播ID,调用init_room后初始化 + """ + return self._room_owner_uid + + async def close(self): + """ + 如果session是自己创建的则关闭session + """ + if self._own_session: + await self._session.close() + + def start(self): + """ + 创建相关的协程,不会执行事件循环 + :return: 协程的future + """ + if self._future is not None: + raise RuntimeError("This client is already running") + self._future = asyncio.ensure_future(self._message_loop(), loop=self._loop) + self._future.add_done_callback(self.__on_message_loop_done) + return self._future + + def __on_message_loop_done(self, future): + self._future = None + logger.debug("room %s 消息协程结束", self.room_id) + exception = future.exception() + if exception is not None: + logger.exception( + "room %s 消息协程异常结束:", + self.room_id, + exc_info=(type(exception), exception, exception.__traceback__), + ) + + def stop(self): + """ + 停止相关的协程 + :return: 协程的future + """ + if self._future is None: + raise RuntimeError("This client is not running") + self._future.cancel() + return self._future + + async def init_room(self): + """ + :return: True代表没有降级,如果需要降级后还可用,重载这个函数返回True + """ + res = True + if not await self._init_room_id_and_owner(): + res = False + # 失败了则降级 + self._room_id = self._room_short_id = self._tmp_room_id + self._room_owner_uid = 0 + + if not await self._init_host_server(): + res = False + # 失败了则降级 + self._host_server_list = DEFAULT_DANMAKU_SERVER_LIST + self._host_server_token = None + return res + + async def _init_room_id_and_owner(self): + try: + async with self._session.get( + ROOM_INIT_URL, params={"room_id": self._tmp_room_id}, ssl=self._ssl + ) as res: + if res.status != 200: + logger.warning( + "room %d init_room失败:%d %s", + self._tmp_room_id, + res.status, + res.reason, + ) + return False + data = await res.json() + if data["code"] != 0: + logger.warning( + "room %d init_room失败:%s", self._tmp_room_id, data["message"] + ) + return False + if not self._parse_room_init(data["data"]): + return False + except (aiohttp.ClientConnectionError, asyncio.TimeoutError): + logger.exception("room %d init_room失败:", self._tmp_room_id) + return False + return True + + def _parse_room_init(self, data): + room_info = data["room_info"] + self._room_id = room_info["room_id"] + self._room_short_id = room_info["short_id"] + self._room_owner_uid = room_info["uid"] + return True + + async def _init_host_server(self): + try: + async with self._session.get( + DANMAKU_SERVER_CONF_URL, + params={"id": self._room_id, "type": 0}, + ssl=self._ssl, + ) as res: + if res.status != 200: + logger.warning( + "room %d getConf失败:%d %s", + self._room_id, + res.status, + res.reason, + ) + return False + data = await res.json() + if data["code"] != 0: + logger.warning( + "room %d getConf失败:%s", self._room_id, data["message"] + ) + return False + if not self._parse_danmaku_server_conf(data["data"]): + return False + except (aiohttp.ClientConnectionError, asyncio.TimeoutError): + logger.exception("room %d getConf失败:", self._room_id) + return False + return True + + def _parse_danmaku_server_conf(self, data): + self._host_server_list = data["host_list"] + self._host_server_token = data["token"] + if not self._host_server_list: + logger.warning("room %d getConf失败:host_server_list为空", self._room_id) + return False + return True + + @staticmethod + def _make_packet(data, operation): + body = json.dumps(data).encode("utf-8") + header = HEADER_STRUCT.pack( + HEADER_STRUCT.size + len(body), HEADER_STRUCT.size, 1, operation, 1 + ) + return header + body + + async def _send_auth(self): + auth_params = { + "uid": self._uid, + "roomid": self._room_id, + "protover": 2, + "platform": "web", + "clientver": "1.14.3", + "type": 2, + } + if self._host_server_token is not None: + auth_params["key"] = self._host_server_token + await self._websocket.send_bytes(self._make_packet(auth_params, Operation.AUTH)) + + async def _message_loop(self): + # 如果之前未初始化则初始化 + if self._host_server_token is None: + if not await self.init_room(): + raise InitError("初始化失败") + + retry_count = 0 + while True: + try: + # 连接 + host_server = self._host_server_list[ + retry_count % len(self._host_server_list) + ] + async with self._session.ws_connect( + f'wss://{host_server["host"]}:{host_server["wss_port"]}/sub', + receive_timeout=self._heartbeat_interval + 5, + ssl=self._ssl, + ) as websocket: + self._websocket = websocket + await self._send_auth() + self._heartbeat_timer_handle = self._loop.call_later( + self._heartbeat_interval, self._on_send_heartbeat + ) + + # 处理消息 + async for message in websocket: # type: aiohttp.WSMessage + retry_count = 0 + if message.type != aiohttp.WSMsgType.BINARY: + logger.warning( + "room %d 未知的websocket消息:type=%s %s", + self.room_id, + message.type, + message.data, + ) + continue + + try: + await self._handle_message(message.data) + except asyncio.CancelledError: + logger.warn(f"{self.room_id} 程序被強制取消。") + raise + except Exception: + logger.exception( + "room %d 处理消息时发生错误:", self.room_id + ) + + except asyncio.CancelledError: + break + except (aiohttp.ClientConnectionError, asyncio.TimeoutError): + # 重连 + pass + except ssl_.SSLError: + logger.exception("SSL错误:") + # 证书错误时无法重连 + break + finally: + self._websocket = None + if self._heartbeat_timer_handle is not None: + self._heartbeat_timer_handle.cancel() + self._heartbeat_timer_handle = None + + retry_count += 1 + logger.warning("room %d 掉线重连中%d", self.room_id, retry_count) + try: + await asyncio.sleep(1) + except asyncio.CancelledError: + break + + def _on_send_heartbeat(self): + coro = self._websocket.send_bytes(self._make_packet({}, Operation.HEARTBEAT)) + asyncio.ensure_future(coro, loop=self._loop) + self._heartbeat_timer_handle = self._loop.call_later( + self._heartbeat_interval, self._on_send_heartbeat + ) + + async def _handle_message(self, data): + offset = 0 + while offset < len(data): + try: + header = HeaderTuple(*HEADER_STRUCT.unpack_from(data, offset)) + except struct.error: + break + + if header.operation == Operation.HEARTBEAT_REPLY: + popularity = int.from_bytes( + data[offset + HEADER_STRUCT.size : offset + HEADER_STRUCT.size + 4], + "big", + ) + await self._on_receive_popularity(popularity) + + elif header.operation == Operation.SEND_MSG_REPLY: + body = data[offset + HEADER_STRUCT.size : offset + header.pack_len] + if header.ver == WS_BODY_PROTOCOL_VERSION_DEFLATE: + body = await self._loop.run_in_executor(None, zlib.decompress, body) + await self._handle_message(body) + else: + try: + body = json.loads(body.decode("utf-8")) + await self._handle_command(body) + except Exception: + logger.error("body: %s", body) + raise + + elif header.operation == Operation.AUTH_REPLY: + await self._websocket.send_bytes( + self._make_packet({}, Operation.HEARTBEAT) + ) + + else: + body = data[offset + HEADER_STRUCT.size : offset + header.pack_len] + logger.warning( + "room %d 未知包类型:operation=%d %s%s", + self.room_id, + header.operation, + header, + body, + ) + + offset += header.pack_len + + async def _handle_command(self, command): + if isinstance(command, list): + for one_command in command: + await self._handle_command(one_command) + return + cmd = command.get("cmd", "") + pos = cmd.find(":") # 2019-5-29 B站弹幕升级新增了参数 + if pos != -1: + cmd = cmd[:pos] + await self.on_command_received(cmd, command) + + @abstractmethod + async def on_command_received(self, cmd, data): + pass + + @abstractmethod + async def _on_receive_popularity(self, popularity: int): + pass diff --git a/iamai/adapter/bililive/config.py b/iamai/adapter/bililive/config.py new file mode 100644 index 00000000..e62f4481 --- /dev/null +++ b/iamai/adapter/bililive/config.py @@ -0,0 +1,30 @@ +"""Bililive 适配器配置。""" + +from typing import Any, Dict, List, Union, Literal, Optional + +from iamai.config import ConfigModel + + +class Config(ConfigModel): + """Bililive 配置类,将在适配器被加载时被混入到机器人主配置中。 + + Attributes: + adapter_type: 适配器类型,需要和协议端配置相同。 + reconnect_interval: 重连等待时间。 + api_timeout: 进行 API 调用时等待返回响应的超时时间。 + show_raw: 是否显示原始数据,默认为 False,不显示。 + session_data_path: session 数据文件路径, 默认为 "data/session.token"。 + report_self_message: 是否上报自己发送的消息,默认为 False,不上报。 + room_id: 监听的房间号列表,默认为 [0])。 + ssl: 是否使用 SSL,默认为 True,使用。 + """ + + __config_name__ = "bililive" + adapter_type: Literal["ws"] = "ws" + reconnect_interval: int = 3 + api_timeout: int = 1000 + session_data_path: str = "data/session.token" + show_raw: bool = False + report_self_message: bool = False + room_id: int = 0 + login: bool = True diff --git a/iamai/adapter/bililive/event.py b/iamai/adapter/bililive/event.py new file mode 100644 index 00000000..983c1937 --- /dev/null +++ b/iamai/adapter/bililive/event.py @@ -0,0 +1,309 @@ +"""Bililive 适配器事件。""" + +import asyncio +import inspect +from enum import IntEnum +from email import message +from collections import UserDict +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Type, + Tuple, + Union, + Literal, + TypeVar, + Optional, +) + +from isort import literal +from pydantic import Field, HttpUrl, BaseModel, validator, root_validator + +from iamai.event import Event + +from .message import Message, BililiveMessage + +if TYPE_CHECKING: + from . import BililiveAdapter + from .message import T_BililiveMSG + +T_BililiveEvent = TypeVar("T_BililiveEvent", bound="BililiveEvent") + + +# go-common\app\service\main\broadcast\model\operation.go +class Operation(IntEnum): + HANDSHAKE = 0 + HANDSHAKE_REPLY = 1 + HEARTBEAT = 2 + HEARTBEAT_REPLY = 3 + SEND_MSG = 4 + SEND_MSG_REPLY = 5 + DISCONNECT_REPLY = 6 + AUTH = 7 + AUTH_REPLY = 8 + RAW = 9 + PROTO_READY = 10 + PROTO_FINISH = 11 + CHANGE_ROOM = 12 + CHANGE_ROOM_REPLY = 13 + REGISTER = 14 + REGISTER_REPLY = 15 + UNREGISTER = 16 + UNREGISTER_REPLY = 17 + # B站业务自定义OP + # MinBusinessOp = 1000 + # MaxBusinessOp = 10000 + + +class DanmakuPosition(IntEnum): + TOP = (5,) + BOTTOM = (4,) + NORMAL = 1 + + +class BililiveEvent(Event["BililiveAdapter"]): + """Blilive 适配器事件类。""" + + __event__ = "" + cmd: str + + +class MessageEvent(BililiveEvent): + """消息事件""" + + __event__ = "message" + post_type: Literal["message"] = "message" + sub_type: str + message: BililiveMessage + session_id: str + + def __repr__(self) -> str: + return f'Event<{self.type}>: "{self.message}"' + + def get_plain_text(self) -> str: + return self.message.get_plain_text() + + async def reply(self, msg: "T_BililiveMSG") -> Dict[str, Any]: + raise NotImplementedError + + +class Danmu_msg(MessageEvent): + """弹幕""" + + __event__ = "message.danmu_msg" + message_type: Literal["danmu_msg"] + info: List[Any] + + async def reply(self, msg: "T_BililiveMSG") -> Dict[str, Any]: + return await self.adapter.send(danmaku=msg) + + +class Super_chat_message(MessageEvent): + """醒目留言""" + + __event__ = "message.super_chat_message" + message_type: Literal["super_chat_message"] + data: Dict[str, Any] + duration: int + + +class NoticeEvent(Event): + __event__ = "notice" + + +class Combo_send(NoticeEvent): + """连击礼物""" + + __event__ = "notice.combo_send" + data: Dict[Any, Any] + notice_type: Literal["combo_send"] + + +class Send_gift(NoticeEvent): + """投喂礼物""" + + __event__ = "notice.send_gift" + data: Dict[Any, Any] + notice_type: Literal["send_gift"] + + +class Common_notice_danmaku(NoticeEvent): + """限时任务(系统通知的)""" + + __event__ = "notice.common_notice_danmaku" + data: Dict[Any, Any] + notice_type: Literal["common_notice_danmaku"] + + +class Entry_effect(NoticeEvent): + """舰长进房""" + + __event__ = "notice.entry_effect" + data: Dict[Any, Any] + notice_type: Literal["entry_effect"] + + +class Interact_word(NoticeEvent): + """普通进房消息""" + + __event__ = "notice_interact_word" + data: Dict[Any, Any] + notice_type: Literal["notice_interact_word"] + + +class Guard_buy(NoticeEvent): + """上舰""" + + __event__ = "notice.guard_buy" + data: Dict[Any, Any] + notice_type: Literal["guard_buy"] + + +class User_toast_msg(NoticeEvent): + """续费舰长""" + + __event__ = "notice.user_toast_msg" + data: Dict[Any, Any] + notice_type: Literal["user_toast_msg"] + + +class Notice_msg(NoticeEvent): + """在本房间续费了舰长""" + + __event__ = "notice.notice_msg" + id: int + name: str + full: Dict[str, Any] + half: Dict[str, Any] + side: Dict[str, Any] + scatter: Dict[str, int] + roomid: int + real_roomid: int + msg_common: int + msg_self: str + link_url: str + msg_type: int + shield_uid: int + business_id: str + marquee_id: str + notice_type: Union[Literal["notice_msg"], int] + + +class Like_info_v3_click(NoticeEvent): + """点赞""" + + __event__ = "notice.like_info_v3_click" + data: Dict[Any, Any] + notice_type: Literal["like_info_v3_click"] + + +class Like_info_v3_update(NoticeEvent): + """总点赞数""" + + __event__ = "notice.like_info_v3_update" + data: Dict[Any, Any] + notice_type: Literal["like_info_v3_update"] + + +class Online_rank_count(NoticeEvent): + """在线等级统计""" + + __event__ = "notice.online_rank_count" + data: Dict[Any, Any] + notice_type: Literal["online_rank_count"] + + +class Online_rank_v2(NoticeEvent): + """在线等级榜""" + + __event__ = "notice.online_rank_v2" + data: Dict[Any, Any] + notice_type: Literal["online_rank_v2"] + + +class Popular_rank_changed(NoticeEvent): + __event__ = "notice.popular_rank_changed" + data: Dict[Any, Any] + notice_type: Literal["popular_rank_changed"] + + +class Room_change(NoticeEvent): + """房间信息变动(分区、标题等)""" + + __event__ = "notice.room_change" + data: Dict[Any, Any] + notice_type: Literal["room_change"] + + +class Room_real_time_message_update(NoticeEvent): + """房间数据""" + + __event__ = "notice.room_real_time_message_update" + data: Dict[Any, Any] + notice_type: Literal["room_real_time_message_update"] + + +class Watched_change(NoticeEvent): + """直播间观看人数""" + + __event__ = "notice.watched_change" + data: Dict[Any, Any] + notice_type: Literal["watched_change"] + + +class Stop_live_room_list(NoticeEvent): + """下播列表""" + + __event__ = "notice.stop_live_room_list" + data: Dict[Any, Any] + room_id_list: List[int] + notice_type: Literal["stop_live_room_list"] + + +class Anchor_lot_start(NoticeEvent): + """天选之人开始""" + + __event__ = "notice.anchor_lot_start" + data: Dict[Any, Any] + notice_type: Literal["anchor_lot_start"] + + def get_anchor_lot_info(self): + """获取天选之人的相关信息""" + return { + "award_name": self.data["award_name"], + "danmu": self.data["danmu"], + "gift_name": self.data["gift_name"], + } + + +class Anchor_lot_award(NoticeEvent): + """天选之人结果""" + + __event__ = "notice.anchor_lot_award" + data: Dict[Any, Any] + notice_type: Literal["anchor_lot_award"] + + def winner_info(self): + """获取中奖人信息""" + return self.data["award_users"] + + +# 事件类映射 +_bililive_events = { + model.__event__: model + for model in globals().values() + if inspect.isclass(model) and issubclass(model, BililiveEvent) +} + + +def get_event_class( + post_type: str, event_type: str, sub_type: Optional[str] = None +) -> Type[T_BililiveEvent]: # type: ignore + if sub_type is None: + return _bililive_events[".".join((post_type, event_type))] # type: ignore + return ( + _bililive_events.get(".".join((post_type, event_type, sub_type))) + or _bililive_events[".".join((post_type, event_type))] + ) # type: ignore diff --git a/iamai/adapter/bililive/exceptions.py b/iamai/adapter/bililive/exceptions.py new file mode 100644 index 00000000..d57cc55a --- /dev/null +++ b/iamai/adapter/bililive/exceptions.py @@ -0,0 +1,13 @@ +"""Bililive 适配器异常。""" + +from typing import Optional + +from iamai.exceptions import AdapterException + + +class BililiveException(AdapterException): + """Bililive 适配器异常基类。""" + + +class InitError(BililiveException): + """初始化失败""" diff --git a/iamai/adapter/bililive/message.py b/iamai/adapter/bililive/message.py new file mode 100644 index 00000000..c2eaae56 --- /dev/null +++ b/iamai/adapter/bililive/message.py @@ -0,0 +1,39 @@ +"""Bililive 适配器消息。""" + +import json +from io import StringIO +from dataclasses import dataclass +from typing_extensions import deprecated +from typing import Any, Dict, Type, Tuple, Union, Mapping, Iterable, Optional, cast + +from iamai.message import Message, MessageSegment + +from .exceptions import * + +__all__ = ["T_BililiveMSG", "BililiveMessage", "BililiveMessageSegment"] + +T_BililiveMSG = Union[ + str, Mapping, Iterable[Mapping], "BililiveMessageSegment", "BililiveMessage" +] + + +class BililiveMessage(Message["BililiveMessageSegment"]): + @property + def _message_segment_class(self) -> Type["BililiveMessageSegment"]: + return BililiveMessageSegment + + def _str_to_message_segment(self, msg: str) -> "BililiveMessageSegment": + return BililiveMessageSegment.danmu(msg) + + +class BililiveMessageSegment(MessageSegment["BililiveMessage"]): + @property + def _message_class(cls) -> Type["BililiveMessage"]: + return BililiveMessage + + def __str__(self) -> str: + return self.data.get("danmu", "") + + @classmethod + def danmu(cls, msg: str) -> "BililiveMessageSegment": + return cls(type="danmu", data={"danmu": msg}) diff --git a/iamai/adapter/bililive/tests.py b/iamai/adapter/bililive/tests.py new file mode 100644 index 00000000..1c9a89ff --- /dev/null +++ b/iamai/adapter/bililive/tests.py @@ -0,0 +1,35 @@ +import time + +from bilibili_api import Danmaku, Credential, sync +from bilibili_api.live import LiveRoom, LiveDanmaku + +# 自己直播间号 +ROOMID = 21752074 +# 凭证 根据回复弹幕的账号填写 +credential = Credential( + sessdata="b62ece97%2C1705379969%2Ccdd22*71", + bili_jct="a6e051b71890306f61b94771eb7281ab", +) +# 监听直播间弹幕 +monitor = LiveDanmaku(ROOMID, credential=credential) +# 用来发送弹幕 +sender = LiveRoom(ROOMID, credential=credential) + + +@monitor.on("DANMU_MSG") +async def recv(event): + # 发送者UID + print(event) + uid = event["data"]["info"][2][0] + # 排除自己发送的弹幕 + # if uid == UID: + # return + # 弹幕文本 + msg = event["data"]["info"][1] + if str(msg).startswith("1"): + # 发送弹幕 + await sender.send_danmaku(Danmaku(str(time.time()))) + + +# 启动监听 +sync(monitor.connect()) diff --git a/iamai/adapter/bililive/utils/bilibili_api.py b/iamai/adapter/bililive/utils/bilibili_api.py new file mode 100644 index 00000000..c67e2e63 --- /dev/null +++ b/iamai/adapter/bililive/utils/bilibili_api.py @@ -0,0 +1,233 @@ +import os +import time +import asyncio +import logging +from typing import Any + +import qrcode +import aiohttp +from aiohttp import cookiejar +from qrcode.main import QRCode +from aiohttp.client import ClientSession +from aiohttp.client_exceptions import ClientResponseError + +QRCODE_REQUEST_URL = "http://passport.bilibili.com/qrcode/getLoginUrl" +CHECK_LOGIN_RESULT = "http://passport.bilibili.com/qrcode/getLoginInfo" +SEND_URL = "https://api.live.bilibili.com/msg/send" +MUTE_USER_URL = ( + "https://api.live.bilibili.com/xlive/web-ucenter/v1/banned/AddSilentUser" +) +ROOM_SLIENT_URL = "https://api.live.bilibili.com/xlive/web-room/v1/banned/RoomSilent" +ADD_BADWORD_URL = ( + "https://api.live.bilibili.com/xlive/web-ucenter/v1/banned/AddShieldKeyword" +) +DEL_BADWORD_URL = ( + "https://api.live.bilibili.com/xlive/web-ucenter/v1/banned/DelShieldKeyword" +) + + +user_cookies = cookiejar.CookieJar() + +""" +Bilibili Client Operation + +""" + + +async def login(session: ClientSession) -> bool: + if get_cookies("bili_jct") != None: + # 無需重複獲取 + logging.info(f"先前已經登入,因此無需再度登入。") + return True + try: + res = await _get(session, QRCODE_REQUEST_URL) + + ts = res["ts"] + outdated = ts + 180 * 1000 # 180 秒後逾時 + authKey = res["data"]["oauthKey"] + + url = res["data"]["url"] + qr = qrcode.QRCode() + logging.info("請掃描下列二維碼進行登入... (或者到目錄下尋找 qrcode.png)") + + qr.add_data(url) + qr.print_ascii(invert=True) + qr.make_image().save("qrcode.png") + + while True: + await asyncio.sleep(5) + + if time.time() > outdated: + logging.info("已逾時。") + return False # 登入失敗 + + res = await _post(session, CHECK_LOGIN_RESULT, oauthKey=authKey) + + if res["status"]: + logging.info("登入成功。") + return True + else: + code = res["data"] + if code in [-1, -2]: + logging.warning(f'登入失敗: {res["message"]}') + return False + + except ClientResponseError as e: + logging.warning(f"請求時出現錯誤: {e}") + return False + finally: + os.remove("qrcode.png") + + +async def send_danmu(**fields) -> bool: + token = get_cookies("bili_jct") + async with ClientSession(cookie_jar=user_cookies) as session: + try: + res = await _post( + session, + SEND_URL, + rnd=time.time(), + csrf=token, + csrf_token=token, + **fields, + ) + return "data" in res + except Exception as e: + logging.warning(f"發送彈幕時出現錯誤: {e}") + return False + + +def get_cookies(name: str) -> Any: + for cookie in user_cookies: + if cookie.key == name: + return cookie.value + return None + + +async def mute_user(tuid: int, roomid: int) -> bool: + token = get_cookies("bili_jct") + async with ClientSession(cookie_jar=user_cookies) as session: + try: + res = await _post( + session, + MUTE_USER_URL, + csrf=token, + csrf_token=token, + visit_id="", + mobile_app="web", + tuid=str(tuid), + room_id=str(roomid), + ) + return res["code"] == 0 + except Exception as e: + logging.warning(f"禁言時出現錯誤: {e}") + return False + + +async def room_slient(roomid: int, slientType: str, level: int, minute: int) -> bool: + type_availables = ["off", "medal", "member", "level"] + if slientType not in type_availables: + logging.warning(f"未知的禁言類型: {slientType} ({type_availables})") + return False + + minute_available = [0, 30, 60] + if minute not in minute_available: + logging.warning(f"未知的静音时间: {minute} ({minute_available})") + return False + + token = get_cookies("bili_jct") + async with ClientSession(cookie_jar=user_cookies) as session: + try: + res = await _post( + session, + ROOM_SLIENT_URL, + csrf=token, + csrf_token=token, + visit_id="", + room_id=str(roomid), + type=str(slientType), + minute=str(minute), + level=str(level), + ) + return res["code"] == 0 + except Exception as e: + logging.warning(f"房間靜音時出現錯誤: {e}") + return False + + +async def add_badword(roomid: int, keyword: str) -> bool: + token = get_cookies("bili_jct") + async with ClientSession(cookie_jar=user_cookies) as session: + try: + res = await _post( + session, + ADD_BADWORD_URL, + csrf=token, + csrf_token=token, + visit_id="", + room_id=str(roomid), + keyword=keyword, + ) + return res["code"] == 0 + except Exception as e: + logging.warning(f"添加屏蔽字時出現錯誤: {e}") + return False + + +async def remove_badword(roomid: int, keyword: str) -> bool: + token = get_cookies("bili_jct") + async with ClientSession(cookie_jar=user_cookies) as session: + try: + res = await _post( + session, + DEL_BADWORD_URL, + csrf=token, + csrf_token=token, + visit_id="", + room_id=str(roomid), + keyword=keyword, + ) + return res["code"] == 0 + except Exception as e: + logging.warning(f"删除屏蔽字時出現錯誤: {e}") + return False + + +def logout(): + user_cookies.clear() + + +""" +Http Request + +""" + + +async def _get(session: ClientSession, url: str): + async with session.get(url) as resp: + resp.raise_for_status() + data = await resp.json() + logging.debug(data) + if "code" in data and data["code"] != 0: + raise Exception(data["message"] if "message" in data else data["code"]) + return data + + +async def _post(session: ClientSession, url: str, **data): + form = aiohttp.FormData() + for k, v in data.items(): + form.add_field(k, v) + logging.debug(f"正在发送 POST 请求: {url}, 内容: {data}") + async with session.post(url, data=form) as resp: + resp.raise_for_status() + data = await resp.json() + logging.debug(data) + if "code" in data and data["code"] != 0: + raise Exception(data["message"] if "message" in data else data["code"]) + return data + + +if __name__ == "__main__": + session = ClientSession(cookies={"a": 1, "b": 2}) + for c in session.cookie_jar: + print(c.key, c.value) diff --git a/iamai/adapter/bililive/utils/bilibili_bot.py b/iamai/adapter/bililive/utils/bilibili_bot.py new file mode 100644 index 00000000..a85f5a97 --- /dev/null +++ b/iamai/adapter/bililive/utils/bilibili_bot.py @@ -0,0 +1,134 @@ +import logging +from typing import List + +from aiohttp import ClientSession + +from utils.plugin import BotPlugin, DanmakuMessage, DanmakuPosition, SuperChatMessage +from utils.bilibili_api import ( + mute_user, + send_danmu, + add_badword, + room_slient, + user_cookies, + remove_badword, +) + +from ..api.blivedm import BLiveClient + + +class BiliLiveBot(BLiveClient): + BOT_PLUGINS: List[BotPlugin] = [] + + def __init__( + self, + room_id, + uid=0, + session: ClientSession = None, + heartbeat_interval=30, + ssl=True, + loop=None, + ): + super().__init__( + room_id, + session=session, + heartbeat_interval=heartbeat_interval, + ssl=ssl, + loop=loop, + ) + self.botid = uid + if session is None: + self._session._cookie_jar = user_cookies + + for bot_plugin in self.BOT_PLUGINS: + bot_plugin.botid = uid + bot_plugin.send_message = self.send_message + bot_plugin.add_badword = self.add_badword + bot_plugin.remove_badword = self.remove_badword + bot_plugin.mute_user = self.mute_user + bot_plugin.room_slient_on = self.room_slient_on + bot_plugin.room_slient_off = self.room_slient_off + + """ + -> bool: 返回操作成功與否 + + """ + + async def send_message( + self, + danmaku: str, + fontsize: int = 25, + color: int = 0xFFFFFF, + pos: DanmakuPosition = DanmakuPosition.NORMAL, + ) -> bool: + # don't know what the hell is bubble + return await send_danmu( + msg=danmaku, + fontsize=fontsize, + color=color, + pos=pos, + roomid=self.room_id, + bubble=0, + ) + + async def mute_user(self, uid: int) -> bool: + return await mute_user(uid, self.room_id) + + # "level" | "medal" | "member" | "off" + async def room_slient_on( + self, slientType: str = "off", minute: int = 0, level: int = 1 + ) -> bool: + return await room_slient(self.room_id, slientType, level, minute) + + async def room_slient_off(self) -> bool: + return await room_slient(self.room_id, "off", 1, 0) + + async def add_badword(self, badword: str) -> bool: + return await add_badword(self.room_id, badword) + + async def remove_badword(self, badword: str) -> bool: + return await remove_badword(self.room_id, badword) + + """ + 執行插件所有處理 + + """ + + async def on_command_received(self, cmd, data): + if self.is_bot_itself(cmd, data): + return + logging.debug(f"從房間 {self.room_id} 收到指令: {cmd}") + for bot_plugin in self.BOT_PLUGINS: + try: + await bot_plugin.on_command_received(cmd, data) + except Exception as e: + logging.warning( + f"执行插件 {get_type_name(bot_plugin)} 时出现错误({get_type_name(e)}): {e}" + ) + + async def _on_receive_popularity(self, popularity: int): + logging.debug(f"從房間 {self.room_id} 收到人氣值: {popularity}") + for bot_plugin in self.BOT_PLUGINS: + try: + await bot_plugin.on_receive_popularity(popularity) + except Exception as e: + logging.warning( + f"执行插件 {get_type_name(bot_plugin)} 时出现错误({get_type_name(e)}): {e}" + ) + + # 其餘的自己過濾 + def is_bot_itself(self, cmd, data) -> bool: + if cmd == "DANMU_MSG": + danmu = DanmakuMessage.from_command(data["info"]) + return danmu.uid == self.botid + elif cmd == "SUPER_CHAT_MESSAGE": + sc = SuperChatMessage.from_command(data["data"]) + return sc.uid == self.botid + elif cmd == "INTERACT_WORD": + uid = data["data"]["uid"] + return uid == self.botid + else: + return False + + +def get_type_name(ins: any) -> str: + return type(ins).__name__ diff --git a/iamai/adapter/bililive/utils/file_loader.py b/iamai/adapter/bililive/utils/file_loader.py new file mode 100644 index 00000000..d182a3f5 --- /dev/null +++ b/iamai/adapter/bililive/utils/file_loader.py @@ -0,0 +1,34 @@ +from pathlib import Path + +import yaml +from genericpath import exists + +DEFAULT_CONFIG_YML = {"debug": False, "roomid": 5651193} + + +def make_folder(folder: str) -> bool: + path = Path(folder) + if path.exists(): + return False + else: + path.mkdir(exist_ok=True, parents=True) + return True + + +def load_config(yml: str, default_values: dict) -> any: + make_folder("config") + path = f"config/{yml}" + data = {} + if exists(path): + with open(path, mode="r", encoding="utf-8") as f: + data = yaml.safe_load(f) + for k, v in data.items(): + default_values[k] = v + if default_values.keys() != data.keys(): + with open(path, mode="w", encoding="utf-8") as f: + yaml.safe_dump(default_values, f, allow_unicode=True) + return default_values + + +def load_default_config() -> any: + return load_config("config.yaml", DEFAULT_CONFIG_YML) diff --git a/iamai/adapter/bililive/utils/main.py b/iamai/adapter/bililive/utils/main.py new file mode 100644 index 00000000..5b8178b2 --- /dev/null +++ b/iamai/adapter/bililive/utils/main.py @@ -0,0 +1,68 @@ +import json +import asyncio +import logging + +from genericpath import exists +from aiohttp.client import ClientSession + +from utils.bilibili_bot import BiliLiveBot +from utils.plugins_loader import load_plugins +from utils.file_loader import make_folder, load_default_config +from utils.bilibili_api import login, get_cookies, user_cookies + + +async def start_bot(room: int): + cookies = {} + # 有上次的 session + session_exist = exists(SESSION_DATA_PATH) + if session_exist: + with open(SESSION_DATA_PATH) as f: + cookies = json.load(f) + # 加到 cookies + user_cookies.update_cookies(cookies) + async with ClientSession(cookie_jar=user_cookies) as session: + # 嘗試登入 + success = await login(session) + # 成功登入 + if success: + uid = get_cookies("DedeUserID") + jct = get_cookies("bili_jct") + + if uid == None or jct == None: + logging.error(f"获取 cookies 失败") + return + if not session_exist: + for cookie in user_cookies: + cookies[cookie.key] = cookie.value + + logging.debug(f"已储存 cookies: {cookies}") + with open(SESSION_DATA_PATH, mode="w") as f: + json.dump(cookies, f) + + bot = BiliLiveBot( + room_id=room, uid=int(uid), session=session, loop=session._loop + ) + await bot.init_room() + logging.info(f"機器人已啟動。") + await bot.start() + # while True: + # await asyncio.sleep(60) + await bot.close() + logging.info(f"機器人已關閉。") + else: + exit() + + +if __name__ == "__main__": + make_folder("data") + make_folder("config") + make_folder("plugins") + + data = load_default_config() + + logging.basicConfig(level=logging.INFO if not data["debug"] else logging.DEBUG) + + room = data["roomid"] + + BiliLiveBot.BOT_PLUGINS = load_plugins() + asyncio.run(start_bot(room)) diff --git a/iamai/adapter/bililive/utils/plugin.py b/iamai/adapter/bililive/utils/plugin.py new file mode 100644 index 00000000..35113bf4 --- /dev/null +++ b/iamai/adapter/bililive/utils/plugin.py @@ -0,0 +1,441 @@ +from enum import IntEnum +from abc import abstractmethod + +from utils.file_loader import load_config as load_plugin_config + + +def load_config(yml: str, default: dict = {}) -> any: + return load_plugin_config(yml, default) + + +class DanmakuPosition(IntEnum): + TOP = (5,) + BOTTOM = (4,) + NORMAL = 1 + + +class BotPlugin: + def __init__(self) -> None: + self.botid = -1 + + """ + 收到指令时 + """ + + @abstractmethod + async def on_command_received(self, cmd, data): + pass + + """ + 收到人气时 + """ + + @abstractmethod + async def on_receive_popularity(self, popularity: int): + pass + + """ + 发送弹幕 + """ + + async def send_message( + self, + danmaku: str, + fontsize: int = 25, + color: int = 0xFFFFFF, + pos: DanmakuPosition = DanmakuPosition.NORMAL, + ) -> bool: + pass + + """ + 以下所有操作全部需要房管权限 + + """ + + """ + 禁言用户 + """ + + async def mute_user(self, uid: int) -> bool: + pass + + """ + 全局禁言 + """ + + # "level" | "medal" | "member" + async def room_slient_on(self, slientType: str, minute: int, level: int) -> bool: + pass + + """ + 全局禁言关闭 + """ + + async def room_slient_off(self) -> bool: + pass + + """ + 新增屏蔽字 + """ + + async def add_badword(self, badword: str) -> bool: + pass + + """ + 删除屏蔽字 + """ + + async def remove_badword(self, badword: str) -> bool: + pass + + +""" +WS數據物件化 (from xfgryujk) +""" + + +class DanmakuMessage: + def __init__( + self, + mode, + font_size, + color, + timestamp, + rnd, + uid_crc32, + msg_type, + bubble, + msg, + uid, + uname, + admin, + vip, + svip, + urank, + mobile_verify, + uname_color, + medal_level, + medal_name, + runame, + room_id, + mcolor, + special_medal, + user_level, + ulevel_color, + ulevel_rank, + old_title, + title, + privilege_type, + ): + """ + :param mode: 弹幕显示模式(滚动、顶部、底部) + :param font_size: 字体尺寸 + :param color: 颜色 + :param timestamp: 时间戳 + :param rnd: 随机数 + :param uid_crc32: 用户ID文本的CRC32 + :param msg_type: 是否礼物弹幕(节奏风暴) + :param bubble: 右侧评论栏气泡 + + :param msg: 弹幕内容 + + :param uid: 用户ID + :param uname: 用户名 + :param admin: 是否房管 + :param vip: 是否月费老爷 + :param svip: 是否年费老爷 + :param urank: 用户身份,用来判断是否正式会员,猜测非正式会员为5000,正式会员为10000 + :param mobile_verify: 是否绑定手机 + :param uname_color: 用户名颜色 + + :param medal_level: 勋章等级 + :param medal_name: 勋章名 + :param runame: 勋章房间主播名 + :param room_id: 勋章房间ID + :param mcolor: 勋章颜色 + :param special_medal: 特殊勋章 + + :param user_level: 用户等级 + :param ulevel_color: 用户等级颜色 + :param ulevel_rank: 用户等级排名,>50000时为'>50000' + + :param old_title: 旧头衔 + :param title: 头衔 + + :param privilege_type: 舰队类型,0非舰队,1总督,2提督,3舰长 + """ + self.mode = mode + self.font_size = font_size + self.color = color + self.timestamp = timestamp + self.rnd = rnd + self.uid_crc32 = uid_crc32 + self.msg_type = msg_type + self.bubble = bubble + + self.msg = msg + + self.uid = uid + self.uname = uname + self.admin = admin + self.vip = vip + self.svip = svip + self.urank = urank + self.mobile_verify = mobile_verify + self.uname_color = uname_color + + self.medal_level = medal_level + self.medal_name = medal_name + self.runame = runame + self.room_id = room_id + self.mcolor = mcolor + self.special_medal = special_medal + + self.user_level = user_level + self.ulevel_color = ulevel_color + self.ulevel_rank = ulevel_rank + + self.old_title = old_title + self.title = title + + self.privilege_type = privilege_type + + @classmethod + def from_command(cls, info: dict): + return cls( + info[0][1], + info[0][2], + info[0][3], + info[0][4], + info[0][5], + info[0][7], + info[0][9], + info[0][10], + info[1], + *info[2][:8], + *(info[3][:6] or (0, "", "", 0, 0, 0)), + info[4][0], + info[4][2], + info[4][3], + *info[5][:2], + info[7], + ) + + +class GiftMessage: + def __init__( + self, + gift_name, + num, + uname, + face, + guard_level, + uid, + timestamp, + gift_id, + gift_type, + action, + price, + rnd, + coin_type, + total_coin, + ): + """ + :param gift_name: 礼物名 + :param num: 礼物数量 + :param uname: 用户名 + :param face: 用户头像URL + :param guard_level: 舰队等级,0非舰队,1总督,2提督,3舰长 + :param uid: 用户ID + :param timestamp: 时间戳 + :param gift_id: 礼物ID + :param gift_type: 礼物类型(未知) + :param action: 目前遇到的有'喂食'、'赠送' + :param price: 礼物单价瓜子数 + :param rnd: 随机数 + :param coin_type: 瓜子类型,'silver'或'gold' + :param total_coin: 总瓜子数 + """ + self.gift_name = gift_name + self.num = num + self.uname = uname + self.face = face + self.guard_level = guard_level + self.uid = uid + self.timestamp = timestamp + self.gift_id = gift_id + self.gift_type = gift_type + self.action = action + self.price = price + self.rnd = rnd + self.coin_type = coin_type + self.total_coin = total_coin + + @classmethod + def from_command(cls, data: dict): + return cls( + data["giftName"], + data["num"], + data["uname"], + data["face"], + data["guard_level"], + data["uid"], + data["timestamp"], + data["giftId"], + data["giftType"], + data["action"], + data["price"], + data["rnd"], + data["coin_type"], + data["total_coin"], + ) + + +class GuardBuyMessage: + def __init__( + self, + uid, + username, + guard_level, + num, + price, + gift_id, + gift_name, + start_time, + end_time, + ): + """ + :param uid: 用户ID + :param username: 用户名 + :param guard_level: 舰队等级,0非舰队,1总督,2提督,3舰长 + :param num: 数量 + :param price: 单价金瓜子数 + :param gift_id: 礼物ID + :param gift_name: 礼物名 + :param start_time: 开始时间戳? + :param end_time: 结束时间戳? + """ + self.uid = uid + self.username = username + self.guard_level = guard_level + self.num = num + self.price = price + self.gift_id = gift_id + self.gift_name = gift_name + self.start_time = start_time + self.end_time = end_time + + @classmethod + def from_command(cls, data: dict): + return cls( + data["uid"], + data["username"], + data["guard_level"], + data["num"], + data["price"], + data["gift_id"], + data["gift_name"], + data["start_time"], + data["end_time"], + ) + + +class SuperChatMessage: + def __init__( + self, + price, + message, + message_jpn, + start_time, + end_time, + time, + id_, + gift_id, + gift_name, + uid, + uname, + face, + guard_level, + user_level, + background_bottom_color, + background_color, + background_icon, + background_image, + background_price_color, + ): + """ + :param price: 价格(人民币) + :param message: 消息 + :param message_jpn: 消息日文翻译(目前只出现在SUPER_CHAT_MESSAGE_JPN) + :param start_time: 开始时间戳 + :param end_time: 结束时间戳 + :param time: 剩余时间 + :param id_: str,消息ID,删除时用 + :param gift_id: 礼物ID + :param gift_name: 礼物名 + :param uid: 用户ID + :param uname: 用户名 + :param face: 用户头像URL + :param guard_level: 舰队等级,0非舰队,1总督,2提督,3舰长 + :param user_level: 用户等级 + :param background_bottom_color: 底部背景色 + :param background_color: 背景色 + :param background_icon: 背景图标 + :param background_image: 背景图 + :param background_price_color: 背景价格颜色 + """ + self.price = price + self.message = message + self.message_jpn = message_jpn + self.start_time = start_time + self.end_time = end_time + self.time = time + self.id = id_ + self.gift_id = gift_id + self.gift_name = gift_name + self.uid = uid + self.uname = uname + self.face = face + self.guard_level = guard_level + self.user_level = user_level + self.background_bottom_color = background_bottom_color + self.background_color = background_color + self.background_icon = background_icon + self.background_image = background_image + self.background_price_color = background_price_color + + @classmethod + def from_command(cls, data: dict): + return cls( + data["price"], + data["message"], + data["message_trans"], + data["start_time"], + data["end_time"], + data["time"], + data["id"], + data["gift"]["gift_id"], + data["gift"]["gift_name"], + data["uid"], + data["user_info"]["uname"], + data["user_info"]["face"], + data["user_info"]["guard_level"], + data["user_info"]["user_level"], + data["background_bottom_color"], + data["background_color"], + data["background_icon"], + data["background_image"], + data["background_price_color"], + ) + + +class SuperChatDeleteMessage: + def __init__(self, ids): + """ + :param ids: 消息ID数组 + """ + self.ids = ids + + @classmethod + def from_command(cls, data: dict): + return cls(data["ids"]) diff --git a/iamai/adapter/bililive/utils/plugins_loader.py b/iamai/adapter/bililive/utils/plugins_loader.py new file mode 100644 index 00000000..51930125 --- /dev/null +++ b/iamai/adapter/bililive/utils/plugins_loader.py @@ -0,0 +1,34 @@ +import inspect +import logging +from os import listdir +from os.path import join, isfile +from importlib.machinery import SourceFileLoader + +from utils.plugin import BotPlugin +from utils.file_loader import make_folder + +PLUGINS_DIR = "plugins" +make_folder(PLUGINS_DIR) + + +def load_plugins(): + plugins = [ + f + for f in listdir(PLUGINS_DIR) + if isfile(join(PLUGINS_DIR, f)) and f.endswith(".py") + ] + bot_plugins = [] + for plugin in plugins: + try: + module = SourceFileLoader( + plugin[:-3], f"{PLUGINS_DIR}/{plugin}" + ).load_module() + for name, cs in inspect.getmembers(module, inspect.isclass): + if cs.__base__ == BotPlugin: + logging.info(f"正在加載插件 {plugin} ({name})") + bot_plugins.append(cs()) + break + except Exception as e: + logging.error(f"加載插件 {plugin} 時出現錯誤: {e}") + + return bot_plugins diff --git a/iamai/adapter/console/__init__.py b/iamai/adapter/console/__init__.py new file mode 100644 index 00000000..1580d15e --- /dev/null +++ b/iamai/adapter/console/__init__.py @@ -0,0 +1,37 @@ +"""Console适配器。""" + +from asyncio import Condition + +from iamai.event import Event +from iamai.adapter import Adapter + +from .config import Config +from .event import ConsoleEvent + +__all__ = ["ConsoleAdapter"] + + +class ConsoleAdapter(Adapter[ConsoleEvent, Config]): + """Console适配器。""" + + name: str = "console" + _msg: str = "" + _cond: Condition + Config = Config + + async def startup(self): + self._cond = Condition() + + async def run(self): + while not self.bot.should_exit.is_set(): + async with self._cond: + await self._cond.wait() + await self.handle_event( + ConsoleEvent(adapter=self, type="message", message=self._msg) + ) + + async def send(self, msg: str): + """此方法发送的消息会直接使此适配器产生一个事件。""" + async with self._cond: + self._msg = msg + self._cond.notify() diff --git a/iamai/adapter/console/config.py b/iamai/adapter/console/config.py new file mode 100644 index 00000000..02e14366 --- /dev/null +++ b/iamai/adapter/console/config.py @@ -0,0 +1,12 @@ +"""Console 适配器配置。""" + +from typing import Any, Dict + +from iamai.config import ConfigModel + + +class Config(ConfigModel): + """Console 配置类,将在适配器被加载时被混入到机器人主配置中。""" + + __config_name__ = "console" + show_raw: bool = False diff --git a/iamai/adapter/console/event.py b/iamai/adapter/console/event.py new file mode 100644 index 00000000..bab0c453 --- /dev/null +++ b/iamai/adapter/console/event.py @@ -0,0 +1,67 @@ +"""Console 适配器事件。""" + +import inspect +from datetime import datetime +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Type, + Tuple, + Union, + Literal, + TypeVar, + Optional, +) + +from pydantic import BaseModel + +from iamai.event import Event +from iamai.plugin import Plugin + +from .message import Message, ConsoleMessage, MessageSegment + +T_ConsoleEvent = TypeVar("T_ConsoleEvent", bound="ConsoleEvent") + +if TYPE_CHECKING: + from . import ConsoleAdapter # type: ignore[class] + +__all__ = ["ConsoleEvent", "MessageEvent", "User", "Robot"] + + +class User(BaseModel, frozen=True): + """用户""" + + id: str + avatar: str = "👤" + nickname: str = "User" + + +class Robot(User, frozen=True): + """机器人""" + + avatar: str = "🤖" + nickname: str = "Bot" + + +class ConsoleEvent(Event["ConsoleAdapter"]): + """Console 事件基类。""" + + message: str + + def get_event_description(self) -> str: + return str(self.dict()) + + def get_message(self) -> Message: + raise ValueError("Event has no message!") + + def get_user_id(self) -> str: + raise ValueError("Event has no user_id!") + + def get_session_id(self) -> str: + raise ValueError("Event has no session_id!") + + def is_tome(self) -> bool: + """获取事件是否与机器人有关的方法。""" + return True diff --git a/iamai/adapter/console/message.py b/iamai/adapter/console/message.py new file mode 100644 index 00000000..a94fa2b1 --- /dev/null +++ b/iamai/adapter/console/message.py @@ -0,0 +1,26 @@ +import re +from typing import Type, Union, Literal, Iterable, Optional, TypedDict + +from iamai.message import Message, MessageSegment + + +class ConsoleMessage(MessageSegment[None]): + """Console 适配器消息。""" + + @property + def _message_class(self) -> None: + return None + + def is_text(self) -> bool: + return self.type == "text" + + +def escape_tag(s: str) -> str: + """用于记录带颜色日志时转义 `` 类型特殊标签 + + 参考: [loguru color 标签](https://loguru.readthedocs.io/en/stable/api/logger.html#color) + + 参数: + s: 需要转义的字符串 + """ + return re.sub(r"\s]*)>", r"\\\g<0>", s) diff --git a/iamai/adapter/cqhttp/__init__.py b/iamai/adapter/cqhttp/__init__.py new file mode 100644 index 00000000..781fbda0 --- /dev/null +++ b/iamai/adapter/cqhttp/__init__.py @@ -0,0 +1,294 @@ +"""CQHTTP 协议适配器。 + +本适配器适配了 OneBot v11 协议。 +协议详情请参考:[OneBot](https://github.com/howmanybots/onebot/blob/master/README.md)。 +""" + +import asyncio +import inspect +import json +import sys +import time +from functools import partial +from typing import ( + Any, + Awaitable, + Callable, + ClassVar, + Dict, + Literal, + Optional, + Tuple, + Type, +) + +import aiohttp +from aiohttp import web + +from iamai.adapter.utils import WebSocketAdapter +from iamai.log import logger +from iamai.message import BuildMessageType +from iamai.utils import PydanticEncoder + +from . import event +from .config import Config +from .event import CQHTTPEvent, HeartbeatMetaEvent, LifecycleMetaEvent, MetaEvent +from .exceptions import ActionFailed, ApiNotAvailable, ApiTimeout, NetworkError +from .message import CQHTTPMessage, CQHTTPMessageSegment + +__all__ = ["CQHTTPAdapter"] + +EventModels = Dict[ + Tuple[Optional[str], Optional[str], Optional[str]], Type[CQHTTPEvent] +] + +DEFAULT_EVENT_MODELS: EventModels = {} +for _, model in inspect.getmembers(event, inspect.isclass): + if issubclass(model, CQHTTPEvent): + DEFAULT_EVENT_MODELS[model.get_event_type()] = model + + +class CQHTTPAdapter(WebSocketAdapter[CQHTTPEvent, Config]): + """CQHTTP 协议适配器。""" + + name = "cqhttp" + Config = Config + + event_models: ClassVar[EventModels] = DEFAULT_EVENT_MODELS + + _api_response: Dict[str, Any] + _api_response_cond: asyncio.Condition + _api_id: int = 0 + + def __getattr__(self, item: str) -> Callable[..., Awaitable[Any]]: + """用于调用 API。可以直接通过访问适配器的属性访问对应名称的 API。 + + Args: + item: API 名称。 + + Returns: + 用于调用 API 的函数。 + """ + return partial(self.call_api, item) + + async def startup(self) -> None: + """初始化适配器。""" + adapter_type = self.config.adapter_type + if adapter_type == "ws-reverse": + adapter_type = "reverse-ws" + self.adapter_type = adapter_type + self.host = self.config.host + self.port = self.config.port + self.url = self.config.url + self.reconnect_interval = self.config.reconnect_interval + self._api_response_cond = asyncio.Condition() + await super().startup() + + async def reverse_ws_connection_hook(self) -> None: + """反向 WebSocket 连接建立时的钩子函数。""" + logger.info("WebSocket connected!") + if self.config.access_token: + assert isinstance(self.websocket, web.WebSocketResponse) + if ( + self.websocket.headers.get("Authorization", "") + != f"Bearer {self.config.access_token}" + ): + await self.websocket.close() + + async def websocket_connect(self) -> None: + """创建正向 WebSocket 连接。""" + assert self.session is not None + logger.info("Tying to connect to WebSocket server...") + async with self.session.ws_connect( + f"ws://{self.host}:{self.port}/", + headers=( + {"Authorization": f"Bearer {self.config.access_token}"} + if self.config.access_token + else None + ), + ) as self.websocket: + await self.handle_websocket() + + async def handle_websocket_msg(self, msg: aiohttp.WSMessage) -> None: + """处理 WebSocket 消息。""" + assert self.websocket is not None + if msg.type == aiohttp.WSMsgType.TEXT: + try: + msg_dict = msg.json() + except json.JSONDecodeError as e: + self.bot.error_or_exception( + "WebSocket message parsing error, not json:", e + ) + return + + if "post_type" in msg_dict: + await self.handle_cqhttp_event(msg_dict) + else: + async with self._api_response_cond: + self._api_response = msg_dict + self._api_response_cond.notify_all() + + elif msg.type == aiohttp.WSMsgType.ERROR: + logger.error( + f"WebSocket connection closed " + f"with exception {self.websocket.exception()!r}" + ) + + def _get_api_echo(self) -> int: + self._api_id = (self._api_id + 1) % sys.maxsize + return self._api_id + + @classmethod + def add_event_model(cls, event_model: Type[CQHTTPEvent]) -> None: + """添加自定义事件模型,事件模型类必须继承于 `CQHTTPEvent`。 + + Args: + event_model: 事件模型类。 + """ + cls.event_models[event_model.get_event_type()] = event_model + + @classmethod + def get_event_model( + cls, + post_type: Optional[str], + detail_type: Optional[str], + sub_type: Optional[str], + ) -> Type[CQHTTPEvent]: + """根据接收到的消息类型返回对应的事件类。 + + Args: + post_type: 请求类型。 + detail_type: 事件类型。 + sub_type: 子类型。 + + Returns: + 对应的事件类。 + """ + event_model = ( + cls.event_models.get((post_type, detail_type, sub_type), None) + or cls.event_models.get((post_type, detail_type, None), None) + or cls.event_models.get((post_type, None, None), None) + ) + return event_model or cls.event_models[(None, None, None)] + + async def handle_cqhttp_event(self, msg: Dict[str, Any]) -> None: + """处理 CQHTTP 事件。 + + Args: + msg: 接收到的信息。 + """ + post_type = msg.get("post_type", None) + if post_type is None: + event_class = self.get_event_model(None, None, None) + else: + event_class = self.get_event_model( + post_type, + msg.get(post_type + "_type", None), + msg.get("sub_type", None), + ) + + cqhttp_event = event_class(adapter=self, **msg) + + if cqhttp_event.post_type == "meta_event": + # meta_event 不交由插件处理 + assert isinstance(cqhttp_event, MetaEvent) + if cqhttp_event.meta_event_type == "lifecycle": + assert isinstance(cqhttp_event, LifecycleMetaEvent) + if cqhttp_event.sub_type == "connect": + logger.info( + f"WebSocket connection " + f"from CQHTTP Bot {msg.get('self_id')} accepted!" + ) + elif cqhttp_event.meta_event_type == "heartbeat": + assert isinstance(cqhttp_event, HeartbeatMetaEvent) + if cqhttp_event.status.good and cqhttp_event.status.online: + pass + else: + logger.error( + f"CQHTTP Bot status is not good: {cqhttp_event.status.model_dump()}" + ) + else: + await self.handle_event(cqhttp_event) + + async def call_api(self, api: str, **params: Any) -> Any: + """调用 CQHTTP API,协程会等待直到获得 API 响应。 + + Args: + api: API 名称。 + **params: API 参数。 + + Returns: + API 响应中的 data 字段。 + + Raises: + NetworkError: 网络错误。 + ApiNotAvailable: API 请求响应 404, API 不可用。 + ActionFailed: API 请求响应 failed, API 操作失败。 + ApiTimeout: API 请求响应超时。 + """ + assert self.websocket is not None + api_echo = self._get_api_echo() + try: + await self.websocket.send_str( + json.dumps( + {"action": api, "params": params, "echo": api_echo}, + cls=PydanticEncoder, + ) + ) + except Exception as e: + raise NetworkError from e + + start_time = time.time() + while not self.bot.should_exit.is_set(): + if time.time() - start_time > self.config.api_timeout: + break + async with self._api_response_cond: + try: + await asyncio.wait_for( + self._api_response_cond.wait(), + timeout=start_time + self.config.api_timeout - time.time(), + ) + except asyncio.TimeoutError: + break + if self._api_response["echo"] == api_echo: + if self._api_response.get("retcode") == ApiNotAvailable.ERROR_CODE: + raise ApiNotAvailable(resp=self._api_response) + if self._api_response.get("status") == "failed": + raise ActionFailed(resp=self._api_response) + return self._api_response.get("data") + + if not self.bot.should_exit.is_set(): + raise ApiTimeout + return None + + async def send( + self, + message_: BuildMessageType[CQHTTPMessageSegment], + message_type: Literal["private", "group"], + id_: int, + ) -> Any: + """发送消息,调用 `send_private_msg` 或 `send_group_msg` API 发送消息。 + + Args: + message_: 消息内容,可以是 `str`, `Mapping`, `Iterable[Mapping]`, + `CQHTTPMessageSegment`, `CQHTTPMessage。` + 将使用 `CQHTTPMessage` 进行封装。 + message_type: 消息类型。应该是 "private" 或者 "group"。 + id_: 发送对象的 ID, QQ 号码或者群号码。 + + Returns: + API 响应。 + + Raises: + TypeError: `message_type` 不是 "private" 或 "group"。 + ...: 同 `call_api()` 方法。 + """ + if message_type == "private": + return await self.send_private_msg( + user_id=id_, message=CQHTTPMessage(message_) + ) + if message_type == "group": + return await self.send_group_msg( + group_id=id_, message=CQHTTPMessage(message_) + ) + raise TypeError('message_type must be "private" or "group"') diff --git a/iamai/adapter/cqhttp/config.py b/iamai/adapter/cqhttp/config.py new file mode 100644 index 00000000..984bda8c --- /dev/null +++ b/iamai/adapter/cqhttp/config.py @@ -0,0 +1,30 @@ +"""CQHTTP 适配器配置。""" + +from typing import Literal + +from iamai.config import ConfigModel + +__all__ = ["Config"] + + +class Config(ConfigModel): + """CQHTTP 配置类,将在适配器被加载时被混入到机器人主配置中。 + + Attributes: + adapter_type: 适配器类型,需要和协议端配置相同。 + host: 本机域名。 + port: 监听的端口。 + url: WebSocket 路径,需和协议端配置相同。 + reconnect_interval: 重连等待时间。 + api_timeout: 进行 API 调用时等待返回响应的超时时间。 + access_token: 鉴权。 + """ + + __config_name__ = "cqhttp" + adapter_type: Literal["ws", "reverse-ws", "ws-reverse"] = "reverse-ws" + host: str = "127.0.0.1" + port: int = 8080 + url: str = "/cqhttp/ws" + reconnect_interval: int = 3 + api_timeout: int = 1000 + access_token: str = "" diff --git a/iamai/adapter/cqhttp/event.py b/iamai/adapter/cqhttp/event.py new file mode 100644 index 00000000..7680e63d --- /dev/null +++ b/iamai/adapter/cqhttp/event.py @@ -0,0 +1,455 @@ +"""CQHTTP 适配器事件。""" + +# pyright: reportIncompatibleVariableOverride=false + +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Literal, + Optional, + Tuple, + get_args, + get_origin, +) +from typing_extensions import Self + +from pydantic import BaseModel, ConfigDict, Field +from pydantic.fields import FieldInfo + +from iamai.event import Event +from iamai.event import MessageEvent as BaseMessageEvent +from iamai.message import BuildMessageType + +from .message import CQHTTPMessage, CQHTTPMessageSegment + +if TYPE_CHECKING: + from . import CQHTTPAdapter + + +class Sender(BaseModel): + """发送人信息""" + + user_id: Optional[int] = None + nickname: Optional[str] = None + card: Optional[str] = None + sex: Optional[Literal["male", "female", "unknown"]] = None + age: Optional[int] = None + area: Optional[str] = None + level: Optional[str] = None + role: Optional[str] = None + title: Optional[str] = None + + +class Anonymous(BaseModel): + """匿名信息""" + + id: int + name: str + flag: str + + +class File(BaseModel): + """文件信息""" + + id: str + name: str + size: int + busid: int + + +class Status(BaseModel): + """状态信息""" + + model_config = ConfigDict(extra="allow") + + online: bool + good: bool + + +def _get_literal_field(field: Optional[FieldInfo]) -> Optional[str]: + if field is None: + return None + annotation = field.annotation + if annotation is None or get_origin(annotation) is not Literal: + return None + literal_values = get_args(annotation) + if len(literal_values) != 1: + return None + return literal_values[0] + + +class CQHTTPEvent(Event["CQHTTPAdapter"]): + """CQHTTP 事件基类""" + + __event__ = "" + type: Optional[str] = Field(alias="post_type") + time: int + self_id: int + post_type: str + + @property + def to_me(self) -> bool: + """当前事件的 `user_id` 是否等于 `self_id`。""" + return getattr(self, "user_id", None) == self.self_id + + @classmethod + def get_event_type(cls) -> Tuple[Optional[str], Optional[str], Optional[str]]: + """获取事件类型。 + + Returns: + 事件类型。 + """ + post_type = _get_literal_field(cls.model_fields.get("post_type", None)) + if post_type is None: + return (None, None, None) + return ( + post_type, + _get_literal_field(cls.model_fields.get(post_type + "_type", None)), + _get_literal_field(cls.model_fields.get("sub_type", None)), + ) + + +class MessageEvent(CQHTTPEvent, BaseMessageEvent["CQHTTPAdapter"]): + """消息事件""" + + __event__ = "message" + post_type: Literal["message"] + message_type: Literal["private", "group"] + sub_type: str + message_id: int + user_id: int + message: CQHTTPMessage + raw_message: str + font: int + sender: Sender + + def __repr__(self) -> str: + """返回消息事件的描述。 + + Returns: + 消息事件的描述。 + """ + return f'Event<{self.type}>: "{self.message}"' + + def get_plain_text(self) -> str: + """获取消息的纯文本内容。 + + Returns: + 消息的纯文本内容。 + """ + return self.message.get_plain_text() + + async def reply( + self, message: BuildMessageType[CQHTTPMessageSegment] + ) -> Dict[str, Any]: + """回复消息。 + + Args: + message: 回复消息的内容,同 `call_api()` 方法。 + + Returns: + API 请求响应。 + """ + raise NotImplementedError + + async def is_same_sender(self, other: Self) -> bool: + """判断自身和另一个事件是否是同一个发送者。 + + Args: + other: 另一个事件。 + + Returns: + 是否是同一个发送者。 + """ + return self.sender.user_id == other.sender.user_id + + +class PrivateMessageEvent(MessageEvent): + """私聊消息""" + + __event__ = "message.private" + message_type: Literal["private"] + sub_type: Literal["friend", "group", "other"] + + async def reply( + self, message: BuildMessageType[CQHTTPMessageSegment] + ) -> Dict[str, Any]: + """回复消息。 + + Args: + message: 回复消息的内容,同 `call_api()` 方法。 + + Returns: + API 请求响应。 + """ + return await self.adapter.send_private_msg( + user_id=self.user_id, message=CQHTTPMessage(message) + ) + + +class GroupMessageEvent(MessageEvent): + """群消息""" + + __event__ = "message.group" + message_type: Literal["group"] + sub_type: Literal["normal", "anonymous", "notice"] + group_id: int + anonymous: Optional[Anonymous] = None + + async def reply( + self, message: BuildMessageType[CQHTTPMessageSegment] + ) -> Dict[str, Any]: + """回复消息。 + + Args: + message: 回复消息的内容,同 `call_api()` 方法。 + + Returns: + API 请求响应。 + """ + return await self.adapter.send_group_msg( + group_id=self.group_id, message=CQHTTPMessage(message) + ) + + +class NoticeEvent(CQHTTPEvent): + """通知事件""" + + __event__ = "notice" + post_type: Literal["notice"] + notice_type: str + + +class GroupUploadNoticeEvent(NoticeEvent): + """群文件上传""" + + __event__ = "notice.group_upload" + notice_type: Literal["group_upload"] + user_id: int + group_id: int + file: File + + +class GroupAdminNoticeEvent(NoticeEvent): + """群管理员变动""" + + __event__ = "notice.group_admin" + notice_type: Literal["group_admin"] + sub_type: Literal["set", "unset"] + user_id: int + group_id: int + + +class GroupDecreaseNoticeEvent(NoticeEvent): + """群成员减少""" + + __event__ = "notice.group_decrease" + notice_type: Literal["group_decrease"] + sub_type: Literal["leave", "kick", "kick_me"] + group_id: int + operator_id: int + user_id: int + + +class GroupIncreaseNoticeEvent(NoticeEvent): + """群成员增加""" + + __event__ = "notice.group_increase" + notice_type: Literal["group_increase"] + sub_type: Literal["approve", "invite"] + group_id: int + operator_id: int + user_id: int + + +class GroupBanNoticeEvent(NoticeEvent): + """群禁言""" + + __event__ = "notice.group_ban" + notice_type: Literal["group_ban"] + sub_type: Literal["ban", "lift_ban"] + group_id: int + operator_id: int + user_id: int + duration: int + + +class FriendAddNoticeEvent(NoticeEvent): + """好友添加""" + + __event__ = "notice.friend_add" + notice_type: Literal["friend_add"] + user_id: int + + +class GroupRecallNoticeEvent(NoticeEvent): + """群消息撤回""" + + __event__ = "notice.group_recall" + notice_type: Literal["group_recall"] + group_id: int + operator_id: int + user_id: int + message_id: int + + +class FriendRecallNoticeEvent(NoticeEvent): + """好友消息撤回""" + + __event__ = "notice.friend_recall" + notice_type: Literal["friend_recall"] + user_id: int + message_id: int + + +class NotifyEvent(NoticeEvent): + """提醒事件""" + + __event__ = "notice.notify" + notice_type: Literal["notify"] + sub_type: str + group_id: Optional[int] = None + user_id: int + + +class PokeNotifyEvent(NotifyEvent): + """戳一戳""" + + __event__ = "notice.notify.poke" + sub_type: Literal["poke"] + target_id: int + group_id: Optional[int] = None + + +class GroupLuckyKingNotifyEvent(NotifyEvent): + """群红包运气王""" + + __event__ = "notice.notify.lucky_king" + sub_type: Literal["lucky_king"] + group_id: int + target_id: int + + +class GroupHonorNotifyEvent(NotifyEvent): + """群成员荣誉变更""" + + __event__ = "notice.notify.honor" + sub_type: Literal["honor"] + group_id: int + honor_type: Literal["talkative", "performer", "emotion"] + + +class RequestEvent(CQHTTPEvent): + """请求事件""" + + __event__ = "request" + post_type: Literal["request"] + request_type: str + + async def approve(self) -> Dict[str, Any]: + """同意请求。 + + Returns: + API 请求响应。 + """ + raise NotImplementedError + + async def refuse(self) -> Dict[str, Any]: + """拒绝请求。 + + Returns: + API 请求响应。 + """ + raise NotImplementedError + + +class FriendRequestEvent(RequestEvent): + """加好友请求""" + + __event__ = "request.friend" + request_type: Literal["friend"] + user_id: int + comment: str + flag: str + + async def approve(self, remark: str = "") -> Dict[str, Any]: + """同意请求。 + + Args: + remark: 好友备注。 + + Returns: + API 请求响应。 + """ + return await self.adapter.set_friend_add_request( + flag=self.flag, approve=True, remark=remark + ) + + async def refuse(self) -> Dict[str, Any]: + """拒绝请求。 + + Returns: + API 请求响应。 + """ + return await self.adapter.set_friend_add_request(flag=self.flag, approve=False) + + +class GroupRequestEvent(RequestEvent): + """加群请求 / 邀请""" + + __event__ = "request.group" + request_type: Literal["group"] + sub_type: Literal["add", "invite"] + group_id: int + user_id: int + comment: str + flag: str + + async def approve(self) -> Dict[str, Any]: + """同意请求。 + + Returns: + API 请求响应。 + """ + return await self.adapter.set_group_add_request( + flag=self.flag, sub_type=self.sub_type, approve=True + ) + + async def refuse(self, reason: str = "") -> Dict[str, Any]: + """拒绝请求。 + + Args: + reason: 拒绝原因。 + + Returns: + API 请求响应。 + """ + return await self.adapter.set_group_add_request( + flag=self.flag, sub_type=self.sub_type, approve=False, reason=reason + ) + + +class MetaEvent(CQHTTPEvent): + """元事件""" + + __event__ = "meta_event" + post_type: Literal["meta_event"] + meta_event_type: str + + +class LifecycleMetaEvent(MetaEvent): + """生命周期""" + + __event__ = "meta_event.lifecycle" + meta_event_type: Literal["lifecycle"] + sub_type: Literal["enable", "disable", "connect"] + + +class HeartbeatMetaEvent(MetaEvent): + """心跳""" + + __event__ = "meta_event.heartbeat" + meta_event_type: Literal["heartbeat"] + status: Status + interval: int diff --git a/iamai/adapter/cqhttp/exceptions.py b/iamai/adapter/cqhttp/exceptions.py new file mode 100644 index 00000000..6b4211e0 --- /dev/null +++ b/iamai/adapter/cqhttp/exceptions.py @@ -0,0 +1,43 @@ +"""CQHTTP 适配器异常。""" + +from typing import Any, ClassVar, Dict + +from iamai.exceptions import AdapterException + +__all__ = [ + "CQHTTPException", + "NetworkError", + "ActionFailed", + "ApiNotAvailable", + "ApiTimeout", +] + + +class CQHTTPException(AdapterException): + """CQHTTP 异常基类。""" + + +class NetworkError(CQHTTPException): + """网络异常。""" + + +class ActionFailed(CQHTTPException): + """API 请求成功响应,但响应表示 API 操作失败。""" + + def __init__(self, resp: Dict[str, Any]) -> None: + """初始化。 + + Args: + resp: 返回的响应。 + """ + self.resp = resp + + +class ApiNotAvailable(ActionFailed): + """API 请求返回 404,表示当前请求的 API 不可用或不存在。""" + + ERROR_CODE: ClassVar[int] = 1404 + + +class ApiTimeout(CQHTTPException): + """API 请求响应超时。""" diff --git a/iamai/adapter/cqhttp/message.py b/iamai/adapter/cqhttp/message.py new file mode 100644 index 00000000..f1ea6585 --- /dev/null +++ b/iamai/adapter/cqhttp/message.py @@ -0,0 +1,280 @@ +"""CQHTTP 适配器消息。""" + +from typing import Literal, Optional, Type, Union +from typing_extensions import Self + +from iamai.message import Message, MessageSegment + +__all__ = ["CQHTTPMessage", "CQHTTPMessageSegment", "escape"] + + +class CQHTTPMessage(Message["CQHTTPMessageSegment"]): + """CQHTTP 消息。""" + + @classmethod + def get_segment_class(cls) -> Type["CQHTTPMessageSegment"]: + """获取消息字段类。 + + Returns: + 消息字段类。 + """ + return CQHTTPMessageSegment + + +class CQHTTPMessageSegment(MessageSegment["CQHTTPMessage"]): + """CQHTTP 消息字段。""" + + @classmethod + def get_message_class(cls) -> Type[CQHTTPMessage]: + """获取消息类。 + + Returns: + 消息类。 + """ + return CQHTTPMessage + + @classmethod + def from_str(cls, msg: str) -> Self: + """用于将 `str` 转换为消息字段。 + + Args: + msg: 要解析为消息字段的数据。 + + Returns: + 由 `str` 转换的消息字段。 + """ + return cls.text(msg) + + def __str__(self) -> str: + """返回消息字段的文本表示。 + + Returns: + 消息字段的文本表示。 + """ + if self.type == "text": + return self.data.get("text", "") + return self.get_cqcode() + + def get_cqcode(self) -> str: + """获取此消息字段的 CQ 码形式。 + + Returns: + 此消息字段的 CQ 码形式。 + """ + if self.type == "text": + return escape(self.data.get("text", ""), escape_comma=False) + + params = ",".join( + [f"{k}={escape(str(v))}" for k, v in self.data.items() if v is not None] + ) + return f'[CQ:{self.type}{"," if params else ""}{params}]' + + @classmethod + def text(cls, text: str) -> Self: + """纯文本""" + return cls(type="text", data={"text": text}) + + @classmethod + def face(cls, id_: int) -> Self: + """QQ 表情""" + return cls(type="face", data={"id": str(id_)}) + + @classmethod + def image( + cls, + file: str, + type_: Optional[Literal["flash"]] = None, + cache: bool = True, + proxy: bool = True, + timeout: Optional[int] = None, + ) -> Self: + """图片""" + return cls( + type="image", + data={ + "file": file, + "type": type_, + "cache": cache, + "proxy": proxy, + "timeout": timeout, + }, + ) + + @classmethod + def record( + cls, + file: str, + magic: bool = False, + cache: bool = True, + proxy: bool = True, + timeout: Optional[int] = None, + ) -> Self: + """语音""" + return cls( + type="record", + data={ + "file": file, + "magic": magic, + "cache": cache, + "proxy": proxy, + "timeout": timeout, + }, + ) + + @classmethod + def video( + cls, + file: str, + cache: bool = True, + proxy: bool = True, + timeout: Optional[int] = None, + ) -> Self: + """短视频""" + return cls( + type="video", + data={"file": file, "cache": cache, "proxy": proxy, "timeout": timeout}, + ) + + @classmethod + def at(cls, qq: Union[int, Literal["all"]]) -> Self: # pylint: disable=invalid-name + """@某人""" + return cls(type="at", data={"qq": str(qq)}) + + @classmethod + def rps(cls) -> Self: + """猜拳魔法表情""" + return cls(type="rps", data={}) + + @classmethod + def dice(cls) -> Self: + """掷骰子魔法表情""" + return cls(type="dice", data={}) + + @classmethod + def shake(cls) -> Self: + """窗口抖动 (戳一戳)""" + return cls(type="shake", data={}) + + @classmethod + def poke(cls, type_: str, id_: int) -> Self: + """戳一戳""" + return cls(type="poke", data={"type": type_, "id": str(id_)}) + + @classmethod + def anonymous(cls, ignore: Optional[bool] = None) -> Self: + """匿名发消息""" + return cls(type="anonymous", data={"ignore": ignore}) + + @classmethod + def share( + cls, + url: str, + title: str, + content: Optional[str] = None, + image: Optional[str] = None, + ) -> Self: + """链接分享""" + return cls( + type="share", + data={"url": url, "title": title, "content": content, "image": image}, + ) + + @classmethod + def contact(cls, type_: Literal["qq", "group"], id_: int) -> Self: + """推荐好友/推荐群""" + return cls(type="contact", data={"type": type_, "id": str(id_)}) + + @classmethod + def contact_friend(cls, id_: int) -> Self: + """推荐好友""" + return cls(type="contact", data={"type": "qq", "id": str(id_)}) + + @classmethod + def contact_group(cls, id_: int) -> Self: + """推荐好友""" + return cls(type="contact", data={"type": "group", "id": str(id_)}) + + @classmethod + def location( + cls, lat: float, lon: float, title: Optional[str], content: Optional[str] = None + ) -> Self: + """位置""" + return cls( + type="location", + data={"lat": str(lat), "lon": str(lon), "title": title, "content": content}, + ) + + @classmethod + def music(cls, type_: Literal["qq", "163", "xm"], id_: int) -> Self: + """音乐分享""" + return cls(type="music", data={"type": type_, "id": str(id_)}) + + @classmethod + def music_custom( + cls, + url: str, + audio: str, + title: str, + content: Optional[str] = None, + image: Optional[str] = None, + ) -> Self: + """音乐自定义分享""" + return cls( + type="music", + data={ + "type": "custom", + "url": url, + "audio": audio, + "title": title, + "content": content, + "image": image, + }, + ) + + @classmethod + def reply(cls, id_: int) -> Self: + """回复""" + return cls(type="reply", data={"id": str(id_)}) + + @classmethod + def node(cls, id_: int) -> Self: + """合并转发节点""" + return cls(type="node", data={"id": str(id_)}) + + @classmethod + def node_custom(cls, user_id: int, nickname: str, content: "CQHTTPMessage") -> Self: + """合并转发自定义节点""" + return cls( + type="node", + data={ + "user_id": str(user_id), + "nickname": str(nickname), + "content": content, + }, + ) + + @classmethod + def xml_message(cls, data: str) -> Self: + """XML 消息""" + return cls(type="xml", data={"data": data}) + + @classmethod + def json_message(cls, data: str) -> Self: + """JSON 消息""" + return cls(type="json", data={"data": data}) + + +def escape(string: str, *, escape_comma: bool = True) -> str: + """对 CQ 码中的特殊字符进行转义。 + + Args: + string: 待转义的字符串。 + escape_comma: 是否转义 `,`。 + + Returns: + 转义后的字符串。 + """ + string = string.replace("&", "&").replace("[", "[").replace("]", "]") + if escape_comma: + string = string.replace(",", ",") + return string diff --git a/iamai/adapter/gensokyo/__init__.py b/iamai/adapter/gensokyo/__init__.py new file mode 100644 index 00000000..ecc4dfc9 --- /dev/null +++ b/iamai/adapter/gensokyo/__init__.py @@ -0,0 +1,314 @@ +"""gensokyo *ob11 协议适配器。 + +本适配器适配了 gensokyo obv11 协议。 +协议详情请参考:[OneBot](https://github.com/howmanybots/onebot/blob/master/README.md)。 +""" + +import asyncio +import inspect +import json +import sys +import time +from functools import partial +from typing import ( + Any, + Awaitable, + Callable, + ClassVar, + Dict, + Literal, + Optional, + Tuple, + Type, +) + +import aiohttp +from aiohttp import web + +from iamai.adapter.utils import WebSocketAdapter +from iamai.log import logger +from iamai.message import BuildMessageType +from iamai.utils import PydanticEncoder + +from . import event +from .config import Config +from .event import GSKEvent, HeartbeatMetaEvent, LifecycleMetaEvent, MetaEvent +from .exceptions import ActionFailed, ApiNotAvailable, ApiTimeout, NetworkError +from .message import GSKMessage, GSKMessageSegment + +__all__ = ["GSKAdapter"] + +EventModels = Dict[Tuple[Optional[str], Optional[str], Optional[str]], Type[GSKEvent]] + +DEFAULT_EVENT_MODELS: EventModels = {} +for _, model in inspect.getmembers(event, inspect.isclass): + if issubclass(model, GSKEvent): + DEFAULT_EVENT_MODELS[model.get_event_type()] = model + + +class GSKAdapter(WebSocketAdapter[GSKEvent, Config]): + """GSK 协议适配器。""" + + name = "gensokyo" + Config = Config + + event_models: ClassVar[EventModels] = DEFAULT_EVENT_MODELS + + _api_response: Dict[str, Any] + _api_response_cond: asyncio.Condition + _api_id: int = 0 + _get_access_token_url: str = "https://bots.qq.com/app/getAppAccessToken" + + def __getattr__(self, item: str) -> Callable[..., Awaitable[Any]]: + """用于调用 API。可以直接通过访问适配器的属性访问对应名称的 API。 + + Args: + item: API 名称。 + + Returns: + 用于调用 API 的函数。 + """ + return partial(self.call_api, item) + + async def startup(self) -> None: + """初始化适配器。""" + adapter_type = self.config.adapter_type + if adapter_type == "ws-reverse": + adapter_type = "reverse-ws" + self.adapter_type = adapter_type + self.host = self.config.host + self.port = self.config.port + self.url = self.config.url + self.reconnect_interval = self.config.reconnect_interval + self._api_response_cond = asyncio.Condition() + # if not self.config.access_token: + # self.config.access_token = await self.get_access_token() + await super().startup() + + async def reverse_ws_connection_hook(self) -> None: + """反向 WebSocket 连接建立时的钩子函数。""" + logger.info("WebSocket connected!") + if self.config.access_token: + assert isinstance(self.websocket, web.WebSocketResponse) + if ( + self.websocket.headers.get("Authorization", "") + != f"Bearer {self.config.access_token}" + ): + await self.websocket.close() + + async def websocket_connect(self) -> None: + """创建正向 WebSocket 连接。""" + assert self.session is not None + logger.info("Tying to connect to WebSocket server...") + async with self.session.ws_connect( + f"ws://{self.host}:{self.port}/", + headers=( + {"Authorization": f"Bearer {self.config.access_token}"} + if self.config.access_token + else None + ), + ) as self.websocket: + await self.handle_websocket() + + async def handle_websocket_msg(self, msg: aiohttp.WSMessage) -> None: + """处理 WebSocket 消息。""" + assert self.websocket is not None + if msg.type == aiohttp.WSMsgType.TEXT: + try: + msg_dict = msg.json() + except json.JSONDecodeError as e: + self.bot.error_or_exception( + "WebSocket message parsing error, not json:", e + ) + return + + if "post_type" in msg_dict: + await self.handle_gsk_event(msg_dict) + else: + async with self._api_response_cond: + self._api_response = msg_dict + self._api_response_cond.notify_all() + + elif msg.type == aiohttp.WSMsgType.ERROR: + logger.error( + f"WebSocket connection closed " + f"with exception {self.websocket.exception()!r}" + ) + + def _get_api_echo(self) -> int: + self._api_id = (self._api_id + 1) % sys.maxsize + return self._api_id + + @classmethod + def add_event_model(cls, event_model: Type[GSKEvent]) -> None: + """添加自定义事件模型,事件模型类必须继承于 `GSKEvent`。 + + Args: + event_model: 事件模型类。 + """ + cls.event_models[event_model.get_event_type()] = event_model + + @classmethod + def get_event_model( + cls, + post_type: Optional[str], + detail_type: Optional[str], + sub_type: Optional[str], + ) -> Type[GSKEvent]: + """根据接收到的消息类型返回对应的事件类。 + + Args: + post_type: 请求类型。 + detail_type: 事件类型。 + sub_type: 子类型。 + + Returns: + 对应的事件类。 + """ + event_model = ( + cls.event_models.get((post_type, detail_type, sub_type), None) + or cls.event_models.get((post_type, detail_type, None), None) + or cls.event_models.get((post_type, None, None), None) + ) + return event_model or cls.event_models[(None, None, None)] + + async def handle_gsk_event(self, msg: Dict[str, Any]) -> None: + """处理 GSK 事件。 + + Args: + msg: 接收到的信息。 + """ + post_type = msg.get("post_type", None) + if post_type is None: + event_class = self.get_event_model(None, None, None) + else: + event_class = self.get_event_model( + post_type, + msg.get(post_type + "_type", None), + msg.get("sub_type", None), + ) + + gsk_event = event_class(adapter=self, **msg) + + if gsk_event.post_type == "meta_event": + # meta_event 不交由插件处理 + assert isinstance(gsk_event, MetaEvent) + if gsk_event.meta_event_type == "lifecycle": + assert isinstance(gsk_event, LifecycleMetaEvent) + if gsk_event.sub_type == "connect": + logger.info( + f"WebSocket connection " + f"from gensokyo Bot {msg.get('self_id')} accepted!" + ) + elif gsk_event.meta_event_type == "heartbeat": + assert isinstance(gsk_event, HeartbeatMetaEvent) + if gsk_event.status.good and gsk_event.status.online: + pass + else: + logger.error( + f"gensokyo Bot status is not good: {gsk_event.status.model_dump()}" + ) + else: + await self.handle_event(gsk_event) + + async def call_api(self, api: str, **params: Any) -> Any: + """调用 GSK API,协程会等待直到获得 API 响应。 + + Args: + api: API 名称。 + **params: API 参数。 + + Returns: + API 响应中的 data 字段。 + + Raises: + NetworkError: 网络错误。 + ApiNotAvailable: API 请求响应 404, API 不可用。 + ActionFailed: API 请求响应 failed, API 操作失败。 + ApiTimeout: API 请求响应超时。 + """ + assert self.websocket is not None + api_echo = self._get_api_echo() + logger.debug(f"api_echo is {api_echo}") + try: + await self.websocket.send_str( + json.dumps( + {"action": api, "params": params, "echo": api_echo}, + cls=PydanticEncoder, + ) + ) + except Exception as e: + raise NetworkError from e + + start_time = time.time() + while not self.bot.should_exit.is_set(): + if time.time() - start_time > self.config.api_timeout: + break + async with self._api_response_cond: + try: + await asyncio.wait_for( + self._api_response_cond.wait(), + timeout=start_time + self.config.api_timeout - time.time(), + ) + except asyncio.TimeoutError: + break + if self._api_response["echo"] == api_echo: + if self._api_response.get("retcode") == ApiNotAvailable.ERROR_CODE: + raise ApiNotAvailable(resp=self._api_response) + if self._api_response.get("status") == "failed": + raise ActionFailed(resp=self._api_response) + return self._api_response.get("data") + + if not self.bot.should_exit.is_set(): + raise ApiTimeout + return None + + async def send( + self, + message_: BuildMessageType[GSKMessageSegment], + message_type: Literal["private", "group"], + id_: int, + ) -> Any: + """发送消息,调用 `send_private_msg` 或 `send_group_msg` API 发送消息。 + + Args: + message_: 消息内容,可以是 `str`, `Mapping`, `Iterable[Mapping]`, + `GSKMessageSegment`, `GSKMessage。` + 将使用 `GSKMessage` 进行封装。 + message_type: 消息类型。应该是 "private" 或者 "group"。 + id_: 发送对象的 ID, QQ 号码或者群号码。 + + Returns: + API 响应。 + + Raises: + TypeError: `message_type` 不是 "private" 或 "group"。 + ...: 同 `call_api()` 方法。 + """ + if message_type == "private": + return await self.send_private_msg( + user_id=id_, message=GSKMessage(message_) + ) + if message_type == "group": + return await self.send_group_msg(group_id=id_, message=GSKMessage(message_)) + raise TypeError('message_type must be "private" or "group"') + + async def get_access_token(self) -> str: + """异步获取登录凭证 + + https://bots.qq.com/app/getAppAccessToken + 属性 类型 必填 说明 + appId string 是 在开放平台管理端上获得。 + clientSecret string 是 在开放平台管理端上获得。 + """ + async with aiohttp.ClientSession() as session: + async with session.post( + self._get_access_token_url, + {"appId": self.config.app_id, "clientSecret": self.config.app_secret}, + ) as response: + if response.status == 200: + data = await response.json() + logger.info(f"Access token: {data}") + return data + else: + raise TimeoutError diff --git a/iamai/adapter/gensokyo/config.py b/iamai/adapter/gensokyo/config.py new file mode 100644 index 00000000..e5acfa84 --- /dev/null +++ b/iamai/adapter/gensokyo/config.py @@ -0,0 +1,33 @@ +"""GSK 适配器配置。""" + +from typing import Literal + +from iamai.config import ConfigModel + +__all__ = ["Config"] + + +class Config(ConfigModel): + """GSK 配置类,将在适配器被加载时被混入到机器人主配置中。 + + Attributes: + adapter_type: 适配器类型,需要和协议端配置相同。 + host: 本机域名。 + port: 监听的端口。 + url: WebSocket 路径,需和协议端配置相同。 + reconnect_interval: 重连等待时间。 + api_timeout: 进行 API 调用时等待返回响应的超时时间。 + access_token: 鉴权。 + """ + + __config_name__ = "gensokyo" + adapter_type: Literal["ws", "reverse-ws", "ws-reverse"] = "reverse-ws" + host: str = "127.0.0.1" + port: int = 8080 + url: str = "/gsk/ws" + reconnect_interval: int = 3 + api_timeout: int = 1000 + app_id: str = "" + app_secret: str = "" + token: str = "" + access_token: str = "" diff --git a/iamai/adapter/gensokyo/event.py b/iamai/adapter/gensokyo/event.py new file mode 100644 index 00000000..6be7f45a --- /dev/null +++ b/iamai/adapter/gensokyo/event.py @@ -0,0 +1,455 @@ +"""GSK 适配器事件。""" + +# pyright: reportIncompatibleVariableOverride=false + +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Literal, + Optional, + Tuple, + get_args, + get_origin, +) +from typing_extensions import Self + +from pydantic import BaseModel, ConfigDict, Field +from pydantic.fields import FieldInfo + +from iamai.event import Event +from iamai.event import MessageEvent as BaseMessageEvent +from iamai.message import BuildMessageType + +from .message import GSKMessage, GSKMessageSegment + +if TYPE_CHECKING: + from . import GSKAdapter + + +class Sender(BaseModel): + """发送人信息""" + + user_id: Optional[int] = None + nickname: Optional[str] = None + card: Optional[str] = None + sex: Optional[Literal["male", "female", "unknown"]] = None + age: Optional[int] = None + area: Optional[str] = None + level: Optional[str] = None + role: Optional[str] = None + title: Optional[str] = None + + +class Anonymous(BaseModel): + """匿名信息""" + + id: int + name: str + flag: str + + +class File(BaseModel): + """文件信息""" + + id: str + name: str + size: int + busid: int + + +class Status(BaseModel): + """状态信息""" + + model_config = ConfigDict(extra="allow") + + online: bool + good: bool + + +def _get_literal_field(field: Optional[FieldInfo]) -> Optional[str]: + if field is None: + return None + annotation = field.annotation + if annotation is None or get_origin(annotation) is not Literal: + return None + literal_values = get_args(annotation) + if len(literal_values) != 1: + return None + return literal_values[0] + + +class GSKEvent(Event["GSKAdapter"]): + """GSK 事件基类""" + + __event__ = "" + type: Optional[str] = Field(alias="post_type") + time: int + self_id: int + post_type: str + + @property + def to_me(self) -> bool: + """当前事件的 `user_id` 是否等于 `self_id`。""" + return getattr(self, "user_id", None) == self.self_id + + @classmethod + def get_event_type(cls) -> Tuple[Optional[str], Optional[str], Optional[str]]: + """获取事件类型。 + + Returns: + 事件类型。 + """ + post_type = _get_literal_field(cls.model_fields.get("post_type", None)) + if post_type is None: + return (None, None, None) + return ( + post_type, + _get_literal_field(cls.model_fields.get(post_type + "_type", None)), + _get_literal_field(cls.model_fields.get("sub_type", None)), + ) + + +class MessageEvent(GSKEvent, BaseMessageEvent["GSKAdapter"]): + """消息事件""" + + __event__ = "message" + post_type: Literal["message"] + message_type: Literal["private", "group"] + sub_type: str + message_id: int + user_id: int + message: GSKMessage + raw_message: str + font: int + sender: Sender + + def __repr__(self) -> str: + """返回消息事件的描述。 + + Returns: + 消息事件的描述。 + """ + return f'Event<{self.type}>: "{self.message}"' + + def get_plain_text(self) -> str: + """获取消息的纯文本内容。 + + Returns: + 消息的纯文本内容。 + """ + return self.message.get_plain_text() + + async def reply( + self, message: BuildMessageType[GSKMessageSegment] + ) -> Dict[str, Any]: + """回复消息。 + + Args: + message: 回复消息的内容,同 `call_api()` 方法。 + + Returns: + API 请求响应。 + """ + raise NotImplementedError + + async def is_same_sender(self, other: Self) -> bool: + """判断自身和另一个事件是否是同一个发送者。 + + Args: + other: 另一个事件。 + + Returns: + 是否是同一个发送者。 + """ + return self.sender.user_id == other.sender.user_id + + +class PrivateMessageEvent(MessageEvent): + """私聊消息""" + + __event__ = "message.private" + message_type: Literal["private"] + sub_type: Literal["friend", "group", "other"] + + async def reply( + self, message: BuildMessageType[GSKMessageSegment] + ) -> Dict[str, Any]: + """回复消息。 + + Args: + message: 回复消息的内容,同 `call_api()` 方法。 + + Returns: + API 请求响应。 + """ + return await self.adapter.send_private_msg( + user_id=self.user_id, message=GSKMessage(message) + ) + + +class GroupMessageEvent(MessageEvent): + """群消息""" + + __event__ = "message.group" + message_type: Literal["group"] + sub_type: Literal["normal", "anonymous", "notice"] + group_id: int + anonymous: Optional[Anonymous] = None + + async def reply( + self, message: BuildMessageType[GSKMessageSegment] + ) -> Dict[str, Any]: + """回复消息。 + + Args: + message: 回复消息的内容,同 `call_api()` 方法。 + + Returns: + API 请求响应。 + """ + return await self.adapter.send_group_msg( + group_id=self.group_id, message=GSKMessage(message) + ) + + +class NoticeEvent(GSKEvent): + """通知事件""" + + __event__ = "notice" + post_type: Literal["notice"] + notice_type: str + + +class GroupUploadNoticeEvent(NoticeEvent): + """群文件上传""" + + __event__ = "notice.group_upload" + notice_type: Literal["group_upload"] + user_id: int + group_id: int + file: File + + +class GroupAdminNoticeEvent(NoticeEvent): + """群管理员变动""" + + __event__ = "notice.group_admin" + notice_type: Literal["group_admin"] + sub_type: Literal["set", "unset"] + user_id: int + group_id: int + + +class GroupDecreaseNoticeEvent(NoticeEvent): + """群成员减少""" + + __event__ = "notice.group_decrease" + notice_type: Literal["group_decrease"] + sub_type: Literal["leave", "kick", "kick_me"] + group_id: int + operator_id: int + user_id: int + + +class GroupIncreaseNoticeEvent(NoticeEvent): + """群成员增加""" + + __event__ = "notice.group_increase" + notice_type: Literal["group_increase"] + sub_type: Literal["approve", "invite"] + group_id: int + operator_id: int + user_id: int + + +class GroupBanNoticeEvent(NoticeEvent): + """群禁言""" + + __event__ = "notice.group_ban" + notice_type: Literal["group_ban"] + sub_type: Literal["ban", "lift_ban"] + group_id: int + operator_id: int + user_id: int + duration: int + + +class FriendAddNoticeEvent(NoticeEvent): + """好友添加""" + + __event__ = "notice.friend_add" + notice_type: Literal["friend_add"] + user_id: int + + +class GroupRecallNoticeEvent(NoticeEvent): + """群消息撤回""" + + __event__ = "notice.group_recall" + notice_type: Literal["group_recall"] + group_id: int + operator_id: int + user_id: int + message_id: int + + +class FriendRecallNoticeEvent(NoticeEvent): + """好友消息撤回""" + + __event__ = "notice.friend_recall" + notice_type: Literal["friend_recall"] + user_id: int + message_id: int + + +class NotifyEvent(NoticeEvent): + """提醒事件""" + + __event__ = "notice.notify" + notice_type: Literal["notify"] + sub_type: str + group_id: Optional[int] = None + user_id: int + + +class PokeNotifyEvent(NotifyEvent): + """戳一戳""" + + __event__ = "notice.notify.poke" + sub_type: Literal["poke"] + target_id: int + group_id: Optional[int] = None + + +class GroupLuckyKingNotifyEvent(NotifyEvent): + """群红包运气王""" + + __event__ = "notice.notify.lucky_king" + sub_type: Literal["lucky_king"] + group_id: int + target_id: int + + +class GroupHonorNotifyEvent(NotifyEvent): + """群成员荣誉变更""" + + __event__ = "notice.notify.honor" + sub_type: Literal["honor"] + group_id: int + honor_type: Literal["talkative", "performer", "emotion"] + + +class RequestEvent(GSKEvent): + """请求事件""" + + __event__ = "request" + post_type: Literal["request"] + request_type: str + + async def approve(self) -> Dict[str, Any]: + """同意请求。 + + Returns: + API 请求响应。 + """ + raise NotImplementedError + + async def refuse(self) -> Dict[str, Any]: + """拒绝请求。 + + Returns: + API 请求响应。 + """ + raise NotImplementedError + + +class FriendRequestEvent(RequestEvent): + """加好友请求""" + + __event__ = "request.friend" + request_type: Literal["friend"] + user_id: int + comment: str + flag: str + + async def approve(self, remark: str = "") -> Dict[str, Any]: + """同意请求。 + + Args: + remark: 好友备注。 + + Returns: + API 请求响应。 + """ + return await self.adapter.set_friend_add_request( + flag=self.flag, approve=True, remark=remark + ) + + async def refuse(self) -> Dict[str, Any]: + """拒绝请求。 + + Returns: + API 请求响应。 + """ + return await self.adapter.set_friend_add_request(flag=self.flag, approve=False) + + +class GroupRequestEvent(RequestEvent): + """加群请求 / 邀请""" + + __event__ = "request.group" + request_type: Literal["group"] + sub_type: Literal["add", "invite"] + group_id: int + user_id: int + comment: str + flag: str + + async def approve(self) -> Dict[str, Any]: + """同意请求。 + + Returns: + API 请求响应。 + """ + return await self.adapter.set_group_add_request( + flag=self.flag, sub_type=self.sub_type, approve=True + ) + + async def refuse(self, reason: str = "") -> Dict[str, Any]: + """拒绝请求。 + + Args: + reason: 拒绝原因。 + + Returns: + API 请求响应。 + """ + return await self.adapter.set_group_add_request( + flag=self.flag, sub_type=self.sub_type, approve=False, reason=reason + ) + + +class MetaEvent(GSKEvent): + """元事件""" + + __event__ = "meta_event" + post_type: Literal["meta_event"] + meta_event_type: str + + +class LifecycleMetaEvent(MetaEvent): + """生命周期""" + + __event__ = "meta_event.lifecycle" + meta_event_type: Literal["lifecycle"] + sub_type: Literal["enable", "disable", "connect"] + + +class HeartbeatMetaEvent(MetaEvent): + """心跳""" + + __event__ = "meta_event.heartbeat" + meta_event_type: Literal["heartbeat"] + status: Status + interval: int diff --git a/iamai/adapter/gensokyo/exceptions.py b/iamai/adapter/gensokyo/exceptions.py new file mode 100644 index 00000000..ec132fea --- /dev/null +++ b/iamai/adapter/gensokyo/exceptions.py @@ -0,0 +1,43 @@ +"""GSK 适配器异常。""" + +from typing import Any, ClassVar, Dict + +from iamai.exceptions import AdapterException + +__all__ = [ + "GSKException", + "NetworkError", + "ActionFailed", + "ApiNotAvailable", + "ApiTimeout", +] + + +class GSKException(AdapterException): + """GSK 异常基类。""" + + +class NetworkError(GSKException): + """网络异常。""" + + +class ActionFailed(GSKException): + """API 请求成功响应,但响应表示 API 操作失败。""" + + def __init__(self, resp: Dict[str, Any]) -> None: + """初始化。 + + Args: + resp: 返回的响应。 + """ + self.resp = resp + + +class ApiNotAvailable(ActionFailed): + """API 请求返回 404,表示当前请求的 API 不可用或不存在。""" + + ERROR_CODE: ClassVar[int] = 1404 + + +class ApiTimeout(GSKException): + """API 请求响应超时。""" diff --git a/iamai/adapter/gensokyo/message.py b/iamai/adapter/gensokyo/message.py new file mode 100644 index 00000000..0ad2d146 --- /dev/null +++ b/iamai/adapter/gensokyo/message.py @@ -0,0 +1,280 @@ +"""GSK 适配器消息。""" + +from typing import Literal, Optional, Type, Union +from typing_extensions import Self + +from iamai.message import Message, MessageSegment + +__all__ = ["GSKMessage", "GSKMessageSegment", "escape"] + + +class GSKMessage(Message["GSKMessageSegment"]): + """GSK 消息。""" + + @classmethod + def get_segment_class(cls) -> Type["GSKMessageSegment"]: + """获取消息字段类。 + + Returns: + 消息字段类。 + """ + return GSKMessageSegment + + +class GSKMessageSegment(MessageSegment["GSKMessage"]): + """GSK 消息字段。""" + + @classmethod + def get_message_class(cls) -> Type[GSKMessage]: + """获取消息类。 + + Returns: + 消息类。 + """ + return GSKMessage + + @classmethod + def from_str(cls, msg: str) -> Self: + """用于将 `str` 转换为消息字段。 + + Args: + msg: 要解析为消息字段的数据。 + + Returns: + 由 `str` 转换的消息字段。 + """ + return cls.text(msg) + + def __str__(self) -> str: + """返回消息字段的文本表示。 + + Returns: + 消息字段的文本表示。 + """ + if self.type == "text": + return self.data.get("text", "") + return self.get_cqcode() + + def get_cqcode(self) -> str: + """获取此消息字段的 CQ 码形式。 + + Returns: + 此消息字段的 CQ 码形式。 + """ + if self.type == "text": + return escape(self.data.get("text", ""), escape_comma=False) + + params = ",".join( + [f"{k}={escape(str(v))}" for k, v in self.data.items() if v is not None] + ) + return f'[CQ:{self.type}{"," if params else ""}{params}]' + + @classmethod + def text(cls, text: str) -> Self: + """纯文本""" + return cls(type="text", data={"text": text}) + + @classmethod + def face(cls, id_: int) -> Self: + """QQ 表情""" + return cls(type="face", data={"id": str(id_)}) + + @classmethod + def image( + cls, + file: str, + type_: Optional[Literal["flash"]] = None, + cache: bool = True, + proxy: bool = True, + timeout: Optional[int] = None, + ) -> Self: + """图片""" + return cls( + type="image", + data={ + "file": file, + "type": type_, + "cache": cache, + "proxy": proxy, + "timeout": timeout, + }, + ) + + @classmethod + def record( + cls, + file: str, + magic: bool = False, + cache: bool = True, + proxy: bool = True, + timeout: Optional[int] = None, + ) -> Self: + """语音""" + return cls( + type="record", + data={ + "file": file, + "magic": magic, + "cache": cache, + "proxy": proxy, + "timeout": timeout, + }, + ) + + @classmethod + def video( + cls, + file: str, + cache: bool = True, + proxy: bool = True, + timeout: Optional[int] = None, + ) -> Self: + """短视频""" + return cls( + type="video", + data={"file": file, "cache": cache, "proxy": proxy, "timeout": timeout}, + ) + + @classmethod + def at(cls, qq: Union[int, Literal["all"]]) -> Self: # pylint: disable=invalid-name + """@某人""" + return cls(type="at", data={"qq": str(qq)}) + + @classmethod + def rps(cls) -> Self: + """猜拳魔法表情""" + return cls(type="rps", data={}) + + @classmethod + def dice(cls) -> Self: + """掷骰子魔法表情""" + return cls(type="dice", data={}) + + @classmethod + def shake(cls) -> Self: + """窗口抖动 (戳一戳)""" + return cls(type="shake", data={}) + + @classmethod + def poke(cls, type_: str, id_: int) -> Self: + """戳一戳""" + return cls(type="poke", data={"type": type_, "id": str(id_)}) + + @classmethod + def anonymous(cls, ignore: Optional[bool] = None) -> Self: + """匿名发消息""" + return cls(type="anonymous", data={"ignore": ignore}) + + @classmethod + def share( + cls, + url: str, + title: str, + content: Optional[str] = None, + image: Optional[str] = None, + ) -> Self: + """链接分享""" + return cls( + type="share", + data={"url": url, "title": title, "content": content, "image": image}, + ) + + @classmethod + def contact(cls, type_: Literal["qq", "group"], id_: int) -> Self: + """推荐好友/推荐群""" + return cls(type="contact", data={"type": type_, "id": str(id_)}) + + @classmethod + def contact_friend(cls, id_: int) -> Self: + """推荐好友""" + return cls(type="contact", data={"type": "qq", "id": str(id_)}) + + @classmethod + def contact_group(cls, id_: int) -> Self: + """推荐好友""" + return cls(type="contact", data={"type": "group", "id": str(id_)}) + + @classmethod + def location( + cls, lat: float, lon: float, title: Optional[str], content: Optional[str] = None + ) -> Self: + """位置""" + return cls( + type="location", + data={"lat": str(lat), "lon": str(lon), "title": title, "content": content}, + ) + + @classmethod + def music(cls, type_: Literal["qq", "163", "xm"], id_: int) -> Self: + """音乐分享""" + return cls(type="music", data={"type": type_, "id": str(id_)}) + + @classmethod + def music_custom( + cls, + url: str, + audio: str, + title: str, + content: Optional[str] = None, + image: Optional[str] = None, + ) -> Self: + """音乐自定义分享""" + return cls( + type="music", + data={ + "type": "custom", + "url": url, + "audio": audio, + "title": title, + "content": content, + "image": image, + }, + ) + + @classmethod + def reply(cls, id_: int) -> Self: + """回复""" + return cls(type="reply", data={"id": str(id_)}) + + @classmethod + def node(cls, id_: int) -> Self: + """合并转发节点""" + return cls(type="node", data={"id": str(id_)}) + + @classmethod + def node_custom(cls, user_id: int, nickname: str, content: "GSKMessage") -> Self: + """合并转发自定义节点""" + return cls( + type="node", + data={ + "user_id": str(user_id), + "nickname": str(nickname), + "content": content, + }, + ) + + @classmethod + def xml_message(cls, data: str) -> Self: + """XML 消息""" + return cls(type="xml", data={"data": data}) + + @classmethod + def json_message(cls, data: str) -> Self: + """JSON 消息""" + return cls(type="json", data={"data": data}) + + +def escape(string: str, *, escape_comma: bool = True) -> str: + """对 CQ 码中的特殊字符进行转义。 + + Args: + string: 待转义的字符串。 + escape_comma: 是否转义 `,`。 + + Returns: + 转义后的字符串。 + """ + string = string.replace("&", "&").replace("[", "[").replace("]", "]") + if escape_comma: + string = string.replace(",", ",") + return string diff --git a/iamai/adapter/kook/__init__.py b/iamai/adapter/kook/__init__.py new file mode 100644 index 00000000..fa857db1 --- /dev/null +++ b/iamai/adapter/kook/__init__.py @@ -0,0 +1,397 @@ +"""Kook Adapter + +This adapter is adapted to the Kook Platform. +For details of the agreement, please refer to: [Kook Developer Platform](https://developer.kookapp.cn/) +""" + +import re +import sys +import json +import time +import zlib +import asyncio +from functools import partial +from typing import TYPE_CHECKING, Any, Dict, Literal, Mapping, Optional + +import aiohttp +import requests +from pydantic import parse_obj_as + +from iamai.adapter.utils import WebSocketAdapter +from iamai.log import logger, error_or_exception + +from .config import Config +from .message import MessageDeserializer, rev_msg_type_map +from .api.handle import User, get_api_method, get_api_restype +from .exceptions import ( + ApiTimeout, + TokenError, + ActionFailed, + NetworkError, + ReconnectError, + ApiNotAvailable, +) +from .event import ( + KookEvent, + EventTypes, + OriginEvent, + ResultStore, + SignalTypes, + _kook_events, + get_event_class, +) + +if TYPE_CHECKING: + from .message import T_KookMSG + +__all__ = ["KookAdapter"] + +BASE_URL = "https://www.kookapp.cn/api" + + +class KookAdapter(WebSocketAdapter[KookEvent, Config]): + """Kook Adapter.""" + + name: str = "kook" + Config = Config + + _gateway_response: dict = {} + _api_response: Dict[Any, Any] + _api_response_cond: asyncio.Condition + _api_id: int = 0 + + def __getattr__(self, item): + return partial(self.call_api, item) + + def get_api_protocol(self, version_number: int | str = 3) -> str: + """API version management + KOOK may have different versions of API in the future. You can pass it like ``https://www.kookapp.cn/api/v{version_number}`` + This explicitly specifies the API version to use in the request path. If version_number is omitted, it will point to the default version. + + Specific reference: https://developer.kookapp.cn/doc/reference + + Args: + version_number (int, optional): version code. Defaults to 3. + + Returns: + str: KOOK API URL of the corresponding version + """ + return f"{BASE_URL}/v{version_number}" + + def build_url(self, args) -> str: + return "/".join(args) + + async def startup(self): + """Initialize the adapter.""" + self.adapter_type = self.config.adapter_type + if self.adapter_type == "websocket": + self.adapter_type = "ws" + self.bot.global_state["adapter"] = self.bot.global_state.get("adapter", {}) + self.bot.global_state["adapter"]["kook"] = {} + self.reconnect_interval = self.config.reconnect_interval + self._api_response_cond = asyncio.Condition() + await super().startup() + + async def websocket_connect(self) -> None: + """Create a forward WebSocket connection.""" + logger.info("Trying to GET the GateWay...") + url = self.build_url([self.get_api_protocol(), "gateway", "index"]) + headers = { + "Authorization": f"Bot {self.config.access_token}", + } + # Get The Gateway URL + # https://developer.kookapp.cn/doc/http/gateway + response = requests.get( + url, headers=headers, params={"compress": self.config.compress} + ) + if response.status_code == 200: + logger.success("Successed to get GateWay.") + self._gateway_response = response.json() + self.bot.global_state["adapter"]["kook"][ + "bot_info" + ] = await self._get_self_data(self.config.access_token) + self.self_id = self.bot.global_state["adapter"]["kook"]["bot_info"].id_ + self.self_name = self.bot.global_state["adapter"]["kook"][ + "bot_info" + ].username + logger.success(f"Bot<{self.self_name}> self id: {self.self_id}") + else: + logger.error(f"Failed to get GateWay, status_code: {response.status_code}") + return + + logger.info("Trying to connect to WebSocket server...") + + # start connection + async with self.session.ws_connect( + self._gateway_response["data"]["url"] + ) as self.websocket: + await self.handle_websocket() + + async def handle_websocket_msg(self, msg: aiohttp.WSMessage): + """Handle Websocket Message.""" + msg_dict: dict + if msg.type == aiohttp.WSMsgType.TEXT: + try: + msg_dict = msg.json() + logger.debug(msg_dict) + except json.JSONDecodeError as e: + error_or_exception( + "WebSocket message parsing error, not json:", + e, + self.bot.config.bot.log.verbose_exception, + ) + return + + elif msg.type == aiohttp.WSMsgType.BINARY: + try: + msg_dict: dict = zlib.decompress(msg.data).decode("utf-8") # type: ignore[dict] + logger.debug(msg_dict) + except zlib.error as e: + error_or_exception( + "WebSocket message decoding error, not binary:", + e, + self.bot.config.bot.log.verbose_exception, + ) + return + + elif msg.type == aiohttp.WSMsgType.ERROR: + logger.error( + f"WebSocket connection closed " + f"with exception {self.websocket.exception()!r}" # type: ignore + ) + return + else: + return + + # reveive hello package + if msg_dict.get("s") == SignalTypes.HELLO: + data: dict = msg_dict["d"] + if data.get("code") == 0: + try: + logger.success( + f"WebSocket connection verified, " + f"Session key: {data['session_id'][:7]}" + ) + # Call start_heartbeat to send heartbeats at intervals of 30 (+5,-5) + self.bot.global_state["adapter"]["kook"]["session"] = data.get( + "session_id" + ) + ResultStore.set_sn( + self.bot.global_state["adapter"]["kook"]["session"], 0 + ) + asyncio.ensure_future( + self.start_heartbeat( + self.bot.global_state["adapter"]["kook"]["session"] + ) + ) + logger.debug("HeartBeat task started!") + await self.handle_kook_event(data) + except Exception as e: + logger.error(f"WebSocket connection verified failed!\n{e}") + raise ReconnectError from e + elif data.get("code") == 40103: + raise ReconnectError + elif data.get("code") == 40101: + raise TokenError("Invalid Token!") + elif data.get("code") == 40102: + raise TokenError("Token verification failed") + else: + logger.warning( + f"Websocket connection failed with code {msg_dict['d'].get('code') or msg_dict}, " + f"retrying..." + ) + await asyncio.sleep(self.reconnect_interval) + elif msg_dict.get("s") == SignalTypes.PONG: + data = { + "self_id": self.self_id, + "post_type": "meta_event", + "meta_event_type": "heartbeat", + } + logger.info(f"HeartBeat received!{data}") + logger.info( + f"Bot {self.bot.global_state['kook']['bot_info'].username} HeartBeat", + ) + await self.handle_kook_event(data) + elif msg_dict.get("s") == SignalTypes.EVENT: + ResultStore.set_sn(self.bot.global_state["kook"]["session"], msg_dict["sn"]) + try: + data = msg_dict["d"] + await self.handle_kook_event(data) + except Exception as e: + logger.error(f"Event handle failed!\n{e!r}") + elif msg_dict.get("s") == SignalTypes.RECONNECT: + raise ReconnectError + elif msg_dict.get("s") == SignalTypes.RESUME_ACK: + return + else: + async with self._api_response_cond: + self._api_response = msg_dict + self._api_response_cond.notify_all() + + async def handle_kook_event(self, data: Dict[str, Any]): + """Handle kook events. + + Args: + msg: received message. + """ + post_type = data.get("type") + + kook_event = KookEvent(adapter=self, **data) + + if self.config.show_raw: + logger.debug(data) + + if kook_event.post_type == "meta_event": + if ( + kook_event.meta_event_type == "lifecycle" + and kook_event.sub_type == "connect" + ): + logger.success( + f"WebSocket connection " + f"from Kook Bot {self.bot.global_state['kook']['bot_info'].username} accepted!" + ) + else: + if ( + not self.config.report_self_message + and kook_event.user_id == kook_event.self_id + ): + return + await self.handle_event(kook_event) + + async def call_api(self, api: str, **data: dict) -> Any: + match = re.findall(r"[A-Z]", api) + if len(match) > 0: + for m in match: + api = api.replace(m, f"-{m.lower()}") + api = api.replace("_", "/") + + if api.startswith("/api/v3/"): + api = api[len("/api/v3/") :] + elif api.startswith("api/v3"): + api = api[len("api/v3") :] + api = api.strip("/") + return await self._call_api(api, data, self.config.access_token) # type: ignore + + async def _call_api( + self, + api: str, + data: Optional[Mapping[str, Any]] = None, + token: Optional[str] = None, + ) -> Any: + data = dict(data) if data is not None else {} + + method = data.get("method") if data.get("method") else get_api_method(api) + headers = data.get("headers", {}) + + files = None + query = None + body = None + + if "files" in data: + files = data["files"] + del data["files"] + elif "file" in data: + files = {"file": data["file"]} + del data["file"] + + if method == "GET": + query = data + elif method == "POST": + body = data + + if token is not None: + headers["Authorization"] = f"Bot {self.config.access_token}" + + result_type = get_api_restype(api) + try: + resp = requests.request( + method=method, + url=self.build_url([self.get_api_protocol(), api]), + headers=headers, + params=query, + data=body, + files=files, + timeout=self.config.api_timeout, + ) + result = _handle_api_result(resp) + logger.debug(f"API {api} called with result {result}") + return parse_obj_as(result_type, result) if result_type else None + except Exception as e: + raise e + + async def _get_self_data(self, token: str) -> User: + """获取当前机器人的信息。 + + Returns: + Optional[dict]: 当前机器人的信息。 + """ + token = token or self.config.access_token + return await self._call_api("user/me", token=token) + + async def start_heartbeat(self, session: str) -> None: + """ + 每30s一次心跳 + :return: + """ + while not self.bot.should_exit.is_set() and not self.websocket.closed: + await self.websocket.send_json( + json.dumps({"s": 2, "sn": ResultStore.get_sn(session)}) + ) + logger.debug(f"HeartBeat sent {ResultStore.get_sn(session)} times!") + await asyncio.sleep(26) + + async def send( + self, message_: "T_KookMSG", message_type: Literal["GROUP", "PERSON"], id_: int + ) -> Dict[str, Any]: + """发送消息,调用 message/create 或 direct-message/create API 发送消息。 + + Args: + message_: 消息内容,可以是 str, Mapping, Iterable[Mapping], + 'KookMessageSegment', 'KookMessage'。 + 将使用 `KookMessage` 进行封装。 + message_type: 消息类型。应该是 GROUP 或者 PERSON。 + id_: 发送对象的 ID ,Kook 用户码或者Kook频道码。 + + Returns: + API 响应。 + + Raises: + TypeError: message_type 不是 'PERSON' 或 'GROUP'。 + ...: 同 `call_api()` 方法。 + """ + if message_type == "PERSON": + return await self.call_api( + api="direct-message/create", target_id=id_, content=message_ + ) + elif message_type == "GROUP": + return await self.call_api( + api="message/create", target_id=id_, content=message_ + ) + else: + raise TypeError('message_type must be "PERSON" or "GROUP"') + + +def _handle_api_result(response: Any) -> Any: + """ + :说明: + + 处理 API 请求返回值。 + + :参数: + + * ``response: Response``: API 响应体 + + :返回: + + - ``T``: API 调用返回数据 + + :异常: + + - ``ActionFailed``: API 调用失败 + """ + result = json.loads(response.content) + if isinstance(result, dict): + if result.get("code") != 0: + raise ActionFailed(response) + else: + return result.get("data") diff --git a/iamai/adapter/kook/_event.py b/iamai/adapter/kook/_event.py new file mode 100644 index 00000000..e140414f --- /dev/null +++ b/iamai/adapter/kook/_event.py @@ -0,0 +1,930 @@ +"""Kook 适配器事件。""" + +import asyncio +import inspect +from enum import IntEnum +from collections import UserDict +from typing import ( # type: ignore + TYPE_CHECKING, + Any, + Dict, + List, + Type, + Tuple, + Union, + Literal, + TypeVar, + Optional, +) + +from pydantic import Field, HttpUrl, BaseModel, validator, root_validator + +from iamai.event import Event + +from .api import Role, User, Emoji, Guild, Channel +from .message import KookMessage, MessageDeserializer + +if TYPE_CHECKING: + from . import KookAdapter + from .message import T_KookMSG + +T_KookEvent = TypeVar("T_KookEvent", bound="KookEvent") + + +class ResultStore: + _seq = 1 + _futures: Dict[Tuple[str, int], asyncio.Future] = {} + _sn_map = {} + + @classmethod + def set_sn(cls, self_id: str, sn: int) -> None: + cls._sn_map[self_id] = sn + + @classmethod + def get_sn(cls, self_id: str) -> int: + return cls._sn_map.get(self_id, 0) + + +class AttrDict(UserDict): + def __init__(self, data=None): + initial = dict(data) # type: ignore + for k in initial: + if isinstance(initial[k], dict): + initial[k] = AttrDict(initial[k]) # type: ignore + + super().__init__(initial) + + def __getattr__(self, name): + return self[name] + + +class PermissionOverwrite(BaseModel): + role_id: Optional[int] = None + allow: Optional[int] = None + deny: Optional[int] = None + + +class PermissionUser(BaseModel): + user: Optional[User] = None + allow: Optional[int] = None + deny: Optional[int] = None + + +class ChannelRoleInfo(BaseModel): + """频道角色权限详情""" + + permission_overwrites: Optional[List[PermissionOverwrite]] = None + """针对角色在该频道的权限覆写规则组成的列表""" + permission_users: Optional[List[PermissionUser]] = None + """针对用户在该频道的权限覆写规则组成的列表""" + permission_sync: Optional[int] = None + """权限设置是否与分组同步, 1 or 0""" + + +class Quote(BaseModel): + """引用消息""" + + id_: Optional[str] = Field(None, alias="id") + """引用消息 id""" + type: Optional[int] = None + """引用消息类型""" + content: Optional[str] = None + """引用消息内容""" + create_at: Optional[int] = None + """引用消息创建时间(毫秒)""" + author: Optional[User] = None + """作者的用户信息""" + + +class Attachments(BaseModel): + """附加的多媒体数据""" + + type: Optional[str] = None + """多媒体类型""" + url: Optional[str] = None + """多媒体地址""" + name: Optional[str] = None + """多媒体名""" + size: Optional[int] = None + """大小 单位(B)""" + + +class URL(BaseModel): + url: Optional[str] = None + """资源的 url""" + + +class Meta(BaseModel): + page: Optional[int] = None + page_total: Optional[int] = None + page_size: Optional[int] = None + total: Optional[int] = None + + +class ListReturn(BaseModel): + meta: Optional[Meta] = None + sort: Optional[Dict[str, Any]] = None + + +class BlackList(BaseModel): + """黑名单""" + + user_id: Optional[str] = None + """用户 id""" + created_time: Optional[int] = None + """加入黑名单的时间戳(毫秒)""" + remark: Optional[str] = None + """加入黑名单的原因""" + user: Optional[User] = None + """用户""" + + +class BlackListsReturn(ListReturn): + """获取黑名单列表返回信息""" + + blacklists: Optional[List[BlackList]] = Field(None, alias="items") + """黑名单列表""" + + +class MessageCreateReturn(BaseModel): + """发送频道消息返回信息""" + + msg_id: Optional[str] = None + """服务端生成的消息 id""" + msg_timestamp: Optional[int] = None + """消息发送时间(服务器时间戳)""" + nonce: Optional[str] = None + """随机字符串""" + + +class ChannelRoleReturn(BaseModel): + """创建或更新频道角色权限返回信息""" + + role_id: Optional[int] = None + user_id: Optional[str] = None + allow: Optional[int] = None + deny: Optional[int] = None + + +class GuildsReturn(ListReturn): + guilds: Optional[List[Guild]] = Field(None, alias="items") + + +class ChannelsReturn(ListReturn): + channels: Optional[List[Channel]] = Field(None, alias="items") + + +class GuildUsersRetrun(ListReturn): + """服务器中的用户列表""" + + users: Optional[List[User]] = Field(None, alias="items") + """用户列表""" + user_count: Optional[int] = None + """用户数量""" + online_count: Optional[int] = None + """在线用户数量""" + offline_count: Optional[int] = None + """离线用户数量""" + + +class Reaction(BaseModel): + emoji: Optional[Emoji] = None + count: Optional[int] = None + me: Optional[bool] = None + + +class MentionInfo(BaseModel): + mention_part: Optional[List[dict]] = None + mention_role_part: Optional[List[dict]] = None + channel_part: Optional[List[dict]] = None + item_part: Optional[List[dict]] = None + + +class BaseMessage(BaseModel): + id_: Optional[str] = Field(None, alias="id") + """消息 ID""" + type: Optional[int] = None + """消息类型""" + content: Optional[str] = None + """消息内容""" + embeds: Optional[List[dict]] = None + """超链接解析数据""" + attachments: Optional[Union[bool, Attachments]] = None + """附加的多媒体数据""" + create_at: Optional[int] = None + """创建时间""" + updated_at: Optional[int] = None + """更新时间""" + reactions: Optional[List[Reaction]] = None + """回应数据""" + image_name: Optional[str] = None + """""" + read_status: Optional[bool] = None + """是否已读""" + quote: Optional[Quote] = None + """引用数据""" + mention_info: Optional[MentionInfo] = None + """引用特定用户或特定角色的信息""" + + +class ChannelMessage(BaseMessage): + """频道消息""" + + author: Optional[User] = None + mention: Optional[List[Any]] = None + mention_all: Optional[bool] = None + mention_roles: Optional[List[Any]] = None + mention_here: Optional[bool] = None + + +class DirectMessage(BaseMessage): + """私聊消息""" + + author_id: Optional[str] = None + """作者的用户 ID""" + from_type: Optional[int] = None + """from_type""" + msg_icon: Optional[str] = None + """msg_icon""" + + +class ChannelMessagesReturn(BaseModel): + """获取私信聊天消息列表返回信息""" + + direct_messages: Optional[List[ChannelMessage]] = Field(None, alias="items") + + +class DirectMessagesReturn(BaseModel): + """获取私信聊天消息列表返回信息""" + + direct_messages: Optional[List[DirectMessage]] = Field(None, alias="items") + + +class ReactionUser(User): + reaction_time: Optional[int] = None + + +class TargetInfo(BaseModel): + """私聊会话 目标用户信息""" + + id_: Optional[str] = Field(None, alias="id") + """目标用户 ID""" + username: Optional[str] = None + """目标用户名""" + online: Optional[bool] = None + """是否在线""" + avatar: Optional[str] = None + """头像图片链接""" + + +class UserChat(BaseModel): + """私聊会话""" + + code: Optional[str] = None + """私信会话 Code""" + last_read_time: Optional[int] = None + """上次阅读消息的时间 (毫秒)""" + latest_msg_time: Optional[int] = None + """最新消息时间 (毫秒)""" + unread_count: Optional[int] = None + """未读消息数""" + target_info: Optional[TargetInfo] = None + """目标用户信息""" + + +class UserChatsReturn(ListReturn): + """获取私信聊天会话列表返回信息""" + + user_chats: Optional[List[UserChat]] = Field(None, alias="items") + """私聊会话列表""" + + +class RolesReturn(ListReturn): + """获取服务器角色列表返回信息""" + + roles: Optional[List[Role]] = Field(None, alias="items") + """服务器角色列表""" + + +class GuilRoleReturn(BaseModel): + """赋予或删除用户角色返回信息""" + + user_id: Optional[str] = None + """用户 id""" + guild_id: Optional[str] = None + """服务器 id""" + roles: Optional[List[int]] = None + """角色 id 的列表""" + + +class IntimacyImg(BaseModel): + """形象图片的总列表""" + + id_: Optional[str] = Field(None, alias="id") + """ 形象图片的 id""" + url: Optional[str] = None + """形象图片的地址""" + + +class IntimacyIndexReturn(BaseModel): + """获取用户亲密度返回信息""" + + img_url: Optional[str] = None + """机器人给用户显示的形象图片地址""" + social_info: Optional[str] = None + """机器人显示给用户的社交信息""" + last_read: Optional[int] = None + """用户上次查看的时间戳""" + score: Optional[int] = None + """亲密度,0-2200""" + img_list: Optional[List[IntimacyImg]] = None + """形象图片的总列表""" + + +class GuildEmoji(BaseModel): + """服务器表情""" + + name: Optional[str] = None + """表情的名称""" + id_: Optional[str] = Field(None, alias="id") + """表情的 ID""" + user_info: Optional[User] = None + """上传用户""" + + +class GuildEmojisReturn(ListReturn): + """获取服务器表情列表返回信息""" + + roles: Optional[List[GuildEmoji]] = Field(None, alias="items") + """服务器表情列表""" + + +class Invite(BaseModel): + """邀请信息""" + + guild_id: Optional[str] = None + """服务器 id""" + channel_id: Optional[str] = None + """频道 id""" + url_code: Optional[str] = None + """url code""" + url: Optional[str] = None + """地址""" + user: Optional[User] = None + """用户""" + + +class InvitesReturn(ListReturn): + """获取邀请列表返回信息""" + + roles: Optional[List[Invite]] = Field(None, alias="items") + """邀请列表""" + + +class EventTypes(IntEnum): + """ + 事件主要格式 + Kook 协议事件,字段与 Kook 一致。各事件字段参考 `Kook 文档` + + .. Kook 文档: + https://developer.kookapp.cn/doc/event/event-introduction#事件主要格式 + """ + + text = 1 + image = 2 + video = 3 + file = 4 + audio = 8 + kmarkdown = 9 + card = 10 + sys = 255 + + +class SignalTypes(IntEnum): + """ + 信令类型 + Kook 协议信令,字段与 Kook 一致。各事件字段参考 `Kook 文档` + + .. Kook 文档: + https://developer.kookapp.cn/doc/websocket#信令格式 + """ + + EVENT = 0 + HELLO = 1 + PING = 2 + PONG = 3 + RESUME = 4 + RECONNECT = 5 + RESUME_ACK = 6 + SYS = 255 + + +class Attachment(BaseModel): + type: str + name: str + url: HttpUrl + file_type: Optional[str] = Field(None) + size: Optional[int] = Field(None) + duration: Optional[float] = Field(None) + width: Optional[int] = Field(None) + hight: Optional[int] = Field(None) + + +class Extra(BaseModel): + type_: Union[int, str] = Field(None, alias="type") + guild_id: Optional[str] = Field(None) + channel_name: Optional[str] = Field(None) + mention: Optional[List[str]] = Field(None) + mention_all: Optional[bool] = Field(None) + mention_roles: Optional[List[str]] = Field(None) + mention_here: Optional[bool] = Field(None) + author: Optional[User] = Field(None) + body: Optional[AttrDict] = Field(None) + attachments: Optional[Attachment] = Field(None) + code: Optional[str] = Field(None) + + @validator("body", pre=True) + def convert_body(cls, v): + if v is None: + return None + + if not isinstance(v, dict): + raise TypeError("body must be dict") + if not isinstance(v, AttrDict): + v = AttrDict(v) + return v + + class Config: + arbitrary_types_allowed = True + + +class OriginEvent(Event["KookAdapter"]): + """为了区分信令中非Event事件,增加了前置OriginEvent""" + + __event__ = "" + + post_type: str + + +class Kmarkdown(BaseModel): + raw_content: str + mention_part: list + mention_role_part: list + + +class EventMessage(BaseModel): + type: Union[int, str] + guild_id: Optional[str] + channel_name: Optional[str] + mention: Optional[List] + mention_all: Optional[bool] + mention_roles: Optional[List] + mention_here: Optional[bool] + nav_channels: Optional[List] + author: User + + kmarkdown: Optional[Kmarkdown] + + code: Optional[str] = None + attachments: Optional[Attachment] = None + + content: KookMessage + + +class KookEvent(OriginEvent): + """ + 事件主要格式,来自 d 字段 + Kook 协议事件,字段与 Kook 一致。各事件字段参考 `Kook 文档` + + .. Kook 文档: + https://developer.kookapp.cn/doc/event/event-introduction + """ + + __event__ = "" + channel_type: Literal["PERSON", "GROUP"] + type_: int = Field(alias="type") + """1:文字消息\n2:图片消息\n3:视频消息\n4:文件消息\n8:音频消息\n9:KMarkdown\n10:card消息\n255:系统消息\n其它的暂未开放""" + target_id: str + """ + 发送目的\n + 频道消息类时, 代表的是频道 channel_id\n + 如果 channel_type 为 GROUP 组播且 type 为 255 系统消息时,则代表服务器 guild_id""" + author_id: Optional[str] = None + content: KookMessage + msg_id: str + msg_timestamp: int + nonce: str + extra: Extra + user_id: str + + post_type: str + self_id: Optional[str] = None # onebot兼容 + + +# Message Events +class MessageEvent(KookEvent): + """消息事件""" + + __event__ = "message" + + post_type: Literal["message"] = "message" + message_type: str # group private 其实是person + sub_type: str + event: EventMessage + + def __repr__(self) -> str: + return f'Event<{self.post_type}>: "{self.content}"' + + def get_plain_text(self) -> str: + """获取消息的纯文本内容。 + + Returns: + 消息的纯文本内容。 + """ + return self.content.get_plain_text() # type: ignore + + async def reply(self, msg: "T_KookMSG") -> Dict[str, Any]: + """回复消息。 + + Args: + msg: 回复消息的内容,同 `call_api()` 方法。 + + Returns: + API 请求响应。 + """ + raise NotImplementedError + + +class PrivateMessageEvent(MessageEvent): + """私聊消息""" + + __event__ = "message.private" + message_type: Literal["private"] + + async def reply(self, msg: "T_KookMSG") -> Dict[str, Any]: + return await self.adapter.call_api( + api="direct-message/create", target_id=self.author_id, content=msg + ) + + +class ChannelMessageEvent(MessageEvent): + """公共频道消息""" + + __event__ = "message.group" + message_type: Literal["group"] + group_id: str + + async def reply(self, msg: "T_KookMSG") -> Dict[str, Any]: + return await self.adapter.call_api( + "message/create", target_id=self.target_id, content=msg + ) + + +# Notice Events +class NoticeEvent(KookEvent): + """通知事件""" + + __event__ = "notice" + post_type: Literal["notice"] + notice_type: str + + def __repr__(self) -> str: + return f'Event<{self.post_type}>: "{self.content}"' + + +# Channel Events +class ChannelNoticeEvent(NoticeEvent): + """频道消息事件""" + + __event__ = "notice" + group_id: int + + +class ChannelAddReactionEvent(ChannelNoticeEvent): + """频道内用户添加 reaction""" + + __event__ = "notice.added_reaction" + notice_type: Literal["added_reaction"] + + +class ChannelDeletedReactionEvent(ChannelNoticeEvent): + """频道内用户删除 reaction""" + + __event__ = "notice.deleted_reaction" + notice_type: Literal["deleted_reaction"] + + +class ChannelUpdatedMessageEvent(ChannelNoticeEvent): + """频道消息更新""" + + __event__ = "notice.updated_message" + notice_type: Literal["updated_message"] + + +class ChannelDeleteMessageEvent(ChannelNoticeEvent): + """频道消息被删除""" + + __event__ = "notice.deleted_message" + notice_type: Literal["deleted_message"] + + +class ChannelAddedEvent(ChannelNoticeEvent): + """新增频道""" + + __event__ = "notice.added_channel" + notice_type: Literal["added_channel"] + + +class ChannelUpdatedEvent(ChannelNoticeEvent): + """修改频道信息""" + + __event__ = "notice.updated_channel" + notice_type: Literal["updated_channel"] + + +class ChannelDeleteEvent(ChannelNoticeEvent): + """删除频道""" + + __event__ = "notice.deleted_channel" + notice_type: Literal["deleted_channel"] + + +class ChannelPinnedMessageEvent(ChannelNoticeEvent): + """新增频道置顶消息""" + + __event__ = "notice.pinned_message" + notice_type: Literal["pinned_message"] + + +class ChannelUnpinnedMessageEvent(ChannelNoticeEvent): + """取消频道置顶消息""" + + __event__ = "notice.unpinned_message" + notice_type: Literal["unpinned_message"] + + +# Private Events +class PrivateNoticeEvent(NoticeEvent): + "私聊消息事件" + + +class PrivateUpdateMessageEvent(PrivateNoticeEvent): + """私聊消息更新""" + + __event__ = "notice.updated_private_message" + notice_type: Literal["updated_private_message"] + + +class PrivateDeleteMessageEvent(PrivateNoticeEvent): + """私聊消息删除""" + + __event__ = "notice.deleted_private_message" + notice_type: Literal["deleted_private_message"] + + +class PrivateAddReactionEvent(PrivateNoticeEvent): + """私聊内用户添加 reaction""" + + __event__ = "notice.private_added_reaction" + notice_type: Literal["private_added_reaction"] + + +class PrivateDeleteReactionEvent(PrivateNoticeEvent): + """私聊内用户取消 reaction""" + + __event__ = "notice.private_deleted_reaction" + notice_type: Literal["private_deleted_reaction"] + + +# Guild Events +class GuildNoticeEvent(NoticeEvent): + """服务器相关事件""" + + group_id: int + + def get_guild_id(self): + return self.target_id # type: ignore + + +# Guild Member Events +class GuildMemberNoticeEvent(GuildNoticeEvent): + """服务器成员相关事件""" + + pass + + +class GuildMemberIncreaseNoticeEvent(GuildMemberNoticeEvent): + """新成员加入服务器""" + + __event__ = "notice.joined_guild" + notice_type: Literal["joined_guild"] + + +class GuildMemberDecreaseNoticeEvent(GuildMemberNoticeEvent): + """服务器成员退出""" + + __event__ = "notice.exited_guild" + notice_type: Literal["exited_guild"] + + +class GuildMemberUpdateNoticeEvent(GuildMemberNoticeEvent): + """服务器成员信息更新(修改昵称)""" + + __event__ = "notice.updated_guild_member" + notice_type: Literal["updated_guild_member"] + + +class GuildMemberOnlineNoticeEvent(GuildMemberNoticeEvent): + """服务器成员上线""" + + __event__ = "notice.guild_member_online" + notice_type: Literal["guild_member_online"] + + +class GuildMemberOfflineNoticeEvent(GuildMemberNoticeEvent): + """服务器成员下线""" + + __event__ = "notice.guild_member_offline" + notice_type: Literal["guild_member_offline"] + + +# Guild Role Events +class GuildRoleNoticeEvent(GuildNoticeEvent): + """服务器角色相关事件""" + + +class GuildRoleAddNoticeEvent(GuildRoleNoticeEvent): + """服务器角色增加""" + + __event__ = "notice.added_role" + notice_type: Literal["added_role"] + + +class GuildRoleDeleteNoticeEvent(GuildRoleNoticeEvent): + """服务器角色增加""" + + __event__ = "notice.deleted_role" + notice_type: Literal["deleted_role"] + + +class GuildRoleUpdateNoticeEvent(GuildRoleNoticeEvent): + """服务器角色增加""" + + __event__ = "notice.updated_role" + notice_type: Literal["updated_role"] + + +# Guild Events +class GuildUpdateNoticeEvent(GuildNoticeEvent): + """服务器信息更新""" + + __event__ = "notice.updated_guild" + notice_type: Literal["updated_guild"] + + +class GuildDeleteNoticeEvent(GuildNoticeEvent): + """服务器删除""" + + __event__ = "notice.deleted_guild" + notice_type: Literal["deleted_guild"] + + +class GuildAddBlockListNoticeEvent(GuildNoticeEvent): + """服务器封禁用户""" + + __event__ = "notice.added_block_list" + notice_type: Literal["added_block_list"] + + +class GuildDeleteBlockListNoticeEvent(GuildNoticeEvent): + """服务器取消封禁用户""" + + __event__ = "notice.deleted_block_list" + notice_type: Literal["deleted_block_list"] + + +# User Events +class UserNoticeEvent(NoticeEvent): + """用户相关事件列表""" + + group_id: int + + +class UserJoinAudioChannelNoticeEvent(UserNoticeEvent): + """用户加入语音频道""" + + __event__ = "notice.joined_channel" + notice_type: Literal["joined_channel"] + + +class UserJoinAudioChannelEvent(UserNoticeEvent): + """用户退出语音频道""" + + __event__ = "notice.exited_channel" + notice_type: Literal["exited_channel"] + + +class UserInfoUpdateNoticeEvent(UserNoticeEvent): + """ + 用户信息更新 + + 该事件与服务器无关, 遵循以下条件: + - 仅当用户的 用户名 或 头像 变更时 + - 仅通知与该用户存在关联的用户或 Bot + a. 存在聊天会话 + b. 双方好友关系 + """ + + __event__ = "notice.user_updated" + notice_type: Literal["user_updated"] + + +class SelfJoinGuildNoticeEvent(NoticeEvent): + """ + 自己新加入服务器 + + 当自己被邀请或主动加入新的服务器时, 产生该事件 + """ + + __event__ = "notice.self_joined_guild" + notice_type: Literal["self_joined_guild"] + user_id: str + group_id: int + + +class SelfExitGuildNoticeEvent(NoticeEvent): + """ + 自己退出服务器 + + 当自己被踢出服务器或被拉黑或主动退出服务器时, 产生该事件 + """ + + __event__ = "notice.self_exited_guild" + notice_type: Literal["self_exited_guild"] + user_id: str + group_id: int + + +class CartBtnClickNoticeEvent(NoticeEvent): + """ + Card 消息中的 Button 点击事件 + """ + + __event__ = "notice.message_btn_click" + notice_type: Literal["message_btn_click"] + user_id: str + group_id: int + + +# Meta Events +class MetaEvent(OriginEvent): + """元事件""" + + __event__ = "meta_event" + post_type: Literal["meta_event"] + meta_event_type: str + + +class LifecycleMetaEvent(MetaEvent): + """生命周期元事件""" + + __event__ = "meta_event.lifecycle" + meta_event_type: Literal["lifecycle"] + sub_type: str + + +class HeartbeatMetaEvent(MetaEvent): + """心跳元事件""" + + __event__ = "meta_event.heartbeat" + meta_event_type: Literal["heartbeat"] + sub_type: str + + +# 事件类映射 +_kook_events = { + model.__event__: model + for model in globals().values() + if inspect.isclass(model) and issubclass(model, OriginEvent) +} + + +def get_event_class( + post_type: str, event_type: str, sub_type: Optional[str] = None +) -> Type[T_KookEvent]: # type: ignore + """根据接收到的消息类型返回对应的事件类。 + + Args: + post_type: 请求类型。 + event_type: 事件类型。 + sub_type: 子类型。 + + Returns: + 对应的事件类。 + """ + if sub_type is None: + return _kook_events[".".join((post_type, event_type))] # type: ignore + return ( + _kook_events.get(".".join((post_type, event_type, sub_type))) + or _kook_events[".".join((post_type, event_type))] + ) # type: ignore diff --git a/iamai/adapter/kook/api/__init__.py b/iamai/adapter/kook/api/__init__.py new file mode 100644 index 00000000..adb03441 --- /dev/null +++ b/iamai/adapter/kook/api/__init__.py @@ -0,0 +1,2 @@ +from .model import * +from .client import ApiClient as ApiClient diff --git a/iamai/adapter/kook/api/client.py b/iamai/adapter/kook/api/client.py new file mode 100644 index 00000000..165d2728 --- /dev/null +++ b/iamai/adapter/kook/api/client.py @@ -0,0 +1 @@ +class ApiClient: ... diff --git a/iamai/adapter/kook/api/client.pyi b/iamai/adapter/kook/api/client.pyi new file mode 100644 index 00000000..a2058bb2 --- /dev/null +++ b/iamai/adapter/kook/api/client.pyi @@ -0,0 +1,355 @@ +from .model import * + +class ApiClient: + async def asset_create(self, *, file) -> URL: ... + async def blacklist_create( + self, + *, + guild_id: str, + target_id: str, + remark: Optional[str] = ..., + del_msg_days: Optional[str] = ..., + ) -> None: ... + async def blacklist_delete(self, *, guild_id: str, target_id: str) -> None: ... + async def blacklist_list(self, *, guild_id: str) -> BlackListsReturn: ... + async def channelRole_create( + self, *, channel_id: str, type: Optional[str] = ..., value: Optional[str] = ... + ) -> ChannelRoleReturn: ... + async def channelRole_delete( + self, + *, + channel_id: str, + type: Optional[str] = ..., + value: Optional[str] = ..., + ) -> None: ... + async def channelRole_index(self, *, channel_id: str) -> ChannelRoleInfo: + """获取频道角色权限详情 + + Args: + channel_id (str): 频道ID + + Returns: + ChannelRoleInfo: 频道角色权限详情 + """ + ... + + async def channelRole_update( + self, + *, + channel_id: str, + type: Optional[str] = ..., + value: Optional[str] = ..., + allow: Optional[int] = ..., + deny: Optional[int] = ..., + ) -> ChannelRoleReturn: ... + async def channel_create( + self, + *, + guild_id: str, + name: str, + parent_id: Optional[str] = ..., + type: Optional[int] = ..., + limit_amount: Optional[int] = ..., + voice_quality: Optional[str] = ..., + is_category: Optional[int] = ..., + ) -> Channel: ... + async def channel_delete(self, *, channel_id: str) -> None: ... + async def channel_update( + self, + *, + channel_id: str, + name: Optional[str] = ..., + topic: Optional[str] = ..., + slow_mode: Optional[int] = ..., + ) -> Channel: ... + async def channel_list( + self, + *, + guild_id: str, + type: Optional[int] = ..., + page: Optional[int] = ..., + page_size: Optional[int] = ..., + ) -> ChannelsReturn: ... + async def channel_moveUser( + self, + *, + target_id: str, + user_ids: List[int], + ) -> None: ... + async def channel_userList(self, *, channel_id: str) -> List[User]: ... + async def channel_view(self, *, target_id: str) -> Channel: ... + async def directMessage_addReaction(self, *, msg_id: str, emoji: str) -> None: ... + async def directMessage_create( + self, + *, + content: str, + type: Optional[int] = ..., + target_id: Optional[str] = ..., + chat_code: Optional[str] = ..., + quote: Optional[str] = ..., + nonce: Optional[str] = ..., + ) -> MessageCreateReturn: ... + async def directMessage_delete(self, *, msg_id: str) -> None: + """删除私信聊天消息 + + Args: + msg_id (str): 消息 id + """ + ... + + async def directMessage_deleteReaction( + self, *, msg_id: str, emoji: str, user_id: Optional[str] = ... + ) -> None: ... + async def directMessage_list( + self, + *, + chat_code: Optional[str] = ..., + target_id: Optional[str] = ..., + msg_id: Optional[str] = ..., + flag: Optional[str] = ..., + page: Optional[int] = ..., + page_size: Optional[int] = ..., + ) -> DirectMessagesReturn: + """获取私信聊天消息列表 + + Args: + chat_code (str, optional): + 私信会话 Code,chat_code与target_id必须传一个. + target_id (str, optional): + 目标用户 id,后端会自动创建会话. 有此参数之后可不传chat_code参数. + msg_id (str, optional): + 参考消息 id,不传则查询最新消息. + flag (str, optional): + 查询模式,有三种模式可以选择. 不传则默认查询最新的消息. + page (int, optional): 目标页数. + page_size (int, optional): 当前分页消息数量, 默认 `50`. + + Returns: + DirectMessagesReturn:获取私信聊天消息列表返回信息 + """ + ... + + async def directMessage_reactionList( + self, *, msg_id: str, emoji: str + ) -> List[ReactionUser]: ... + async def directMessage_update( + self, *, content: str, msg_id: Optional[str] = ..., quote: Optional[str] = ... + ) -> None: + """更新私信聊天消息 + + Args: + content (str): + 消息 id + msg_id (str, optional): + 消息内容 + quote (str, optional): + 回复某条消息的msgId. 如果为空,则代表删除回复,不传则无影响. + """ + ... + + async def directMessage_view( + self, *, chat_code: str, msg_id: str + ) -> DirectMessage: ... + async def gateway_index(self, *, compress: Optional[int] = ...) -> URL: ... + async def guildEmoji_create( + self, *, guild_id: str, emoji: Optional[bytes] = ..., name: Optional[str] = ... + ) -> GuildEmoji: ... + async def guildEmoji_delete(self, *, id: str) -> None: ... + async def guildEmoji_list( + self, + *, + guild_id: str, + page: Optional[int] = ..., + page_size: Optional[int] = ..., + ) -> GuildEmojisReturn: ... + async def guildEmoji_update(self, *, id: str, name: str) -> None: ... + async def guildMute_create( + self, *, guild_id: str = ..., target_id: str = ..., type: int = ... + ) -> None: ... + async def guildMute_delete( + self, *, guild_id: str = ..., target_id: str = ..., type: int = ... + ) -> None: ... + async def guildMute_list( + self, *, guild_id: str, return_type: Optional[str] = ... + ) -> None: ... + async def guildRole_create( + self, *, guild_id: str, name: Optional[str] = ... + ) -> Role: ... + async def guildRole_delete(self, *, guild_id: str, role_id: int) -> None: ... + async def guildRole_grant( + self, *, guild_id: str, user_id: str, role_id: int + ) -> GuilRoleReturn: ... + async def guildRole_list( + self, + *, + guild_id: str, + page: Optional[int] = ..., + page_size: Optional[int] = ..., + ) -> RolesReturn: ... + async def guildRole_revoke( + self, *, guild_id: str, user_id: str, role_id: int + ) -> GuilRoleReturn: ... + async def guildRole_update( + self, + *, + guild_id: str, + role_id: int, + name: Optional[str] = ..., + color: Optional[int] = ..., + hoist: Optional[int] = ..., + mentionable: Optional[int] = ..., + permissions: Optional[int] = ..., + ) -> Role: ... + async def guild_kickout(self, *, guild_id: str, target_id: str) -> None: ... + async def guild_leave(self, *, guild_id: str) -> None: ... + async def guild_list( + self, + *, + page: Optional[int] = ..., + page_size: Optional[int] = ..., + sort: Optional[str] = ..., + ) -> GuildsReturn: + """获取当前用户加入的服务器列表 + + Args: + page (Optional[int], optional): 目标页数 + page_size (Optional[int], optional): 每页数据数量 + sort (Optional[str], optional): 代表排序的字段 + + Returns: + GuildsReturn: 当前用户加入的服务器列表返回信息 + """ + ... + + async def guild_nickname( + self, + *, + guild_id: str = ..., + nickname: Optional[str] = ..., + user_id: Optional[str] = ..., + ) -> None: ... + async def guild_userList( + self, + *, + guild_id: str, + channel_id: Optional[str] = ..., + search: Optional[str] = ..., + role_id: Optional[int] = ..., + mobile_verified: Optional[int] = ..., + active_time: Optional[int] = ..., + joined_at: Optional[int] = ..., + page: Optional[int] = ..., + page_size: Optional[int] = ..., + filter_user_id: Optional[str] = ..., + ) -> GuildUsersRetrun: + """获取服务器中的用户列表 + + Args: + guild_id (str): 服务器 id + channel_id (Optional[str], optional): 频道 id + search (Optional[str], optional): 搜索关键字,在用户名或昵称中搜索 + role_id (Optional[int], optional): 角色 ID,获取特定角色的用户列表 + mobile_verified (Optional[int], optional): 只能为0或1,0是未认证,1是已认证 + active_time (Optional[int], optional): 根据活跃时间排序,0是顺序排列,1是倒序排列 + joined_at (Optional[int], optional): 根据加入时间排序,0是顺序排列,1是倒序排列 + page (Optional[int], optional): 目标页数 + page_size (Optional[int], optional): 每页数据数量 + filter_user_id (Optional[str], optional): 获取指定 id 所属用户的信息 + Returns: + GuildsReturn: 服务器中的用户列表返回信息 + """ + ... + + async def guild_view(self, *, guild_id: str) -> Guild: + """获取服务器详情 + + Args: + guild_id (str): 服务器id + + Returns: + Guild: 服务器详情 + """ + ... + + async def intimacy_index(self, *, user_id: str) -> IntimacyIndexReturn: ... + async def intimacy_update( + self, + *, + user_id: str, + score: Optional[int] = ..., + social_info: Optional[str] = ..., + img_id: Optional[str] = ..., + ) -> None: ... + async def invite_create( + self, + *, + guild_id: Optional[str] = ..., + channel_id: Optional[str] = ..., + duration: Optional[int] = ..., + setting_times: Optional[int] = ..., + ) -> URL: ... + async def invite_delete( + self, + *, + url_code: str, + guild_id: Optional[str] = ..., + channel_id: Optional[str] = ..., + ) -> None: ... + async def invite_list( + self, + *, + guild_id: Optional[str] = ..., + channel_id: Optional[str] = ..., + page: Optional[int] = ..., + page_size: Optional[int] = ..., + ) -> InvitesReturn: ... + async def message_addReaction(self, *, msg_id: str, emoji: str) -> None: ... + async def message_create( + self, + *, + content: str, + target_id: str, + type: Optional[int] = ..., + quote: Optional[str] = ..., + nonce: Optional[str] = ..., + temp_target_id: Optional[str] = ..., + ) -> MessageCreateReturn: ... + async def message_delete(self, *, msg_id: str) -> None: ... + async def message_deleteReaction( + self, *, msg_id: str, emoji: str, user_id: Optional[str] = ... + ) -> None: ... + async def message_list( + self, + *, + target_id: str, + msg_id: Optional[str] = ..., + pin: Optional[int] = ..., + flag: Optional[str] = ..., + page_size: Optional[int] = ..., + ) -> ChannelMessagesReturn: ... + async def message_reactionList( + self, *, msg_id: str, emoji: str + ) -> List[ReactionUser]: ... + async def message_update( + self, + *, + msg_id: str, + content: str, + quote: Optional[str] = ..., + temp_target_id: Optional[str] = ..., + ) -> None: ... + async def message_view(self, *, msg_id: str) -> ChannelMessage: ... + async def userChat_create(self, *, target_id: str) -> UserChat: ... + async def userChat_delete(self, *, chat_code: str) -> None: ... + async def userChat_list( + self, *, page: Optional[int] = ..., page_size: Optional[int] = ... + ) -> UserChatsReturn: ... + async def userChat_view(self, *, chat_code: str) -> UserChat: ... + async def user_me(self) -> User: ... + async def user_offline(self) -> None: + """下线机器人""" + ... + + async def user_view( + self, *, user_id: str, guild_id: Optional[str] = ... + ) -> User: ... diff --git a/iamai/adapter/kook/api/handle.py b/iamai/adapter/kook/api/handle.py new file mode 100644 index 00000000..8341b628 --- /dev/null +++ b/iamai/adapter/kook/api/handle.py @@ -0,0 +1,75 @@ +from .model import * + +api_method_map = { + "asset/create": {"method": "POST", "type": URL}, + "blacklist/create": {"method": "POST", "type": None}, + "blacklist/delete": {"method": "POST", "type": None}, + "blacklist/list": {"method": "GET", "type": BlackListsReturn}, + "channel-role/create": {"method": "POST", "type": ChannelRoleReturn}, + "channel-role/delete": {"method": "POST", "type": None}, + "channel-role/index": {"method": "GET", "type": ChannelRoleInfo}, + "channel-role/update": {"method": "POST", "type": ChannelRoleReturn}, + "channel/create": {"method": "POST", "type": Channel}, + "channel/delete": {"method": "POST", "type": None}, + "channel/update": {"method": "POST", "type": Channel}, + "channel/list": {"method": "GET", "type": ChannelsReturn}, + "channel/move-user": {"method": "POST", "type": None}, + "channel/user-list": {"method": "POST", "type": List[User]}, + "channel/view": {"method": "GET", "type": Channel}, + "direct-message/add-reaction": {"method": "POST", "type": None}, + "direct-message/create": {"method": "POST", "type": MessageCreateReturn}, + "direct-message/delete": {"method": "POST", "type": None}, + "direct-message/delete-reaction": {"method": "POST", "type": None}, + "direct-message/list": {"method": "GET", "type": DirectMessagesReturn}, + "direct-message/reaction-list": {"method": "GET", "type": List[ReactionUser]}, + "direct-message/update": {"method": "POST", "type": None}, + "direct-message/view": {"method": "GET", "type": DirectMessage}, + "gateway/index": {"method": "GET", "type": URL}, + "guild-emoji/create": {"method": "POST", "type": None}, + "guild-emoji/delete": {"method": "POST", "type": None}, + "guild-emoji/list": {"method": "GET", "type": GuildEmojisReturn}, + "guild-emoji/update": {"method": "POST", "type": None}, + "guild-mute/create": {"method": "POST", "type": None}, + "guild-mute/delete": {"method": "POST", "type": None}, + "guild-mute/list": {"method": "GET", "type": None}, + "guild-role/create": {"method": "POST", "type": Role}, + "guild-role/delete": {"method": "POST", "type": None}, + "guild-role/grant": {"method": "POST", "type": GuilRoleReturn}, + "guild-role/list": {"method": "GET", "type": RolesReturn}, + "guild-role/revoke": {"method": "POST", "type": GuilRoleReturn}, + "guild-role/update": {"method": "POST", "type": Role}, + "guild/kickout": {"method": "POST", "type": None}, + "guild/leave": {"method": "POST", "type": None}, + "guild/list": {"method": "GET", "type": GuildsReturn}, + "guild/nickname": {"method": "POST", "type": None}, + "guild/user-list": {"method": "GET", "type": GuildUsersRetrun}, + "guild/view": {"method": "GET", "type": Guild}, + "intimacy/index": {"method": "GET", "type": IntimacyIndexReturn}, + "intimacy/update": {"method": "POST", "type": None}, + "invite/create": {"method": "POST", "type": URL}, + "invite/delete": {"method": "POST", "type": None}, + "invite/list": {"method": "GET", "type": InvitesReturn}, + "message/add-reaction": {"method": "POST", "type": None}, + "message/create": {"method": "POST", "type": MessageCreateReturn}, + "message/delete": {"method": "POST", "type": None}, + "message/delete-reaction": {"method": "POST", "type": None}, + "message/list": {"method": "GET", "type": ChannelMessagesReturn}, + "message/reaction-list": {"method": "GET", "type": List[ReactionUser]}, + "message/update": {"method": "POST", "type": None}, + "message/view": {"method": "GET", "type": ChannelMessage}, + "user-chat/create": {"method": "POST", "type": UserChat}, + "user-chat/delete": {"method": "POST", "type": None}, + "user-chat/list": {"method": "GET", "type": UserChatsReturn}, + "user-chat/view": {"method": "GET", "type": UserChat}, + "user/me": {"method": "GET", "type": User}, + "user/offline": {"method": "POST", "type": None}, + "user/view": {"method": "GET", "type": User}, +} + + +def get_api_method(api: str) -> str: + return api_method_map.get(api, {}).get("method", "POST") + + +def get_api_restype(api: str) -> Any: + return api_method_map.get(api, {}).get("type") diff --git a/iamai/adapter/kook/api/model.py b/iamai/adapter/kook/api/model.py new file mode 100644 index 00000000..b37380c6 --- /dev/null +++ b/iamai/adapter/kook/api/model.py @@ -0,0 +1,444 @@ +from typing import Any, Dict, List, Union, Optional + +from pydantic import Field, BaseModel + + +class User(BaseModel): + """ + 开黑啦 User 字段 + + https://developer.kookapp.cn/doc/objects#%E7%94%A8%E6%88%B7User + """ + + id_: Optional[str] = Field(alias="id") + username: Optional[str] + nickname: Optional[str] + identify_num: Optional[str] + online: Optional[bool] + bot: Optional[bool] + status: Optional[int] + avatar: Optional[str] + vip_avatar: Optional[str] + mobile_verified: Optional[bool] + roles: Optional[List[int]] + + +class Role(BaseModel): + """角色""" + + role_id: Optional[int] = None + """角色 id""" + name: Optional[str] = None + """角色名称""" + color: Optional[int] = None + """颜色色值""" + position: Optional[int] = None + """顺序位置""" + hoist: Optional[int] = None + """是否为角色设定(与普通成员分开显示)""" + mentionable: Optional[int] = None + """是否允许任何人@提及此角色""" + permissions: Optional[int] = None + """权限码""" + + +class PermissionOverwrite(BaseModel): + role_id: Optional[int] = None + allow: Optional[int] = None + deny: Optional[int] = None + + +class PermissionUser(BaseModel): + user: Optional[User] = None + allow: Optional[int] = None + deny: Optional[int] = None + + +class ChannelRoleInfo(BaseModel): + """频道角色权限详情""" + + permission_overwrites: Optional[List[PermissionOverwrite]] = None + """针对角色在该频道的权限覆写规则组成的列表""" + permission_users: Optional[List[PermissionUser]] = None + """针对用户在该频道的权限覆写规则组成的列表""" + permission_sync: Optional[int] = None + """权限设置是否与分组同步, 1 or 0""" + + +class Channel(ChannelRoleInfo): + """开黑啦 频道 字段""" + + id_: Optional[str] = Field(None, alias="id") + """频道 id""" + name: Optional[str] = None + """频道名称""" + user_id: Optional[str] = None + """创建者 id""" + master_id: Optional[str] = None + """master id """ + guild_id: Optional[str] = None + """服务器 id""" + topic: Optional[str] = None + """频道简介""" + is_category: Optional[bool] = None + """是否为分组,事件中为 int 格式""" + parent_id: Optional[str] = None + """上级分组的 id""" + level: Optional[int] = None + """排序 level""" + slow_mode: Optional[int] = None + """慢速模式下限制发言的最短时间间隔, 单位为秒(s)""" + type: Optional[int] = None + """频道类型: 1 文字频道, 2 语音频道""" + has_password: Optional[bool] = None + """是否有密码""" + limit_amount: Optional[int] = None + """人数限制""" + + +class Guild(BaseModel): + """服务器""" + + id_: Optional[str] = Field(None, alias="id") + """服务器 id""" + name: Optional[str] = None + """服务器名称""" + topic: Optional[str] = None + """服务器主题""" + user_id: Optional[str] = None + """服务器主的 id""" + icon: Optional[str] = None + """服务器 icon 的地址""" + notify_type: Optional[int] = None + """通知类型\n + `0`代表默认使用服务器通知设置\n + `1`代表接收所有通知\n + `2`代表仅@被提及\n + `3`代表不接收通知 + """ + region: Optional[str] = None + """服务器默认使用语音区域""" + enable_open: Optional[bool] = None + """是否为公开服务器""" + open_id: Optional[str] = None + """公开服务器 id""" + default_channel_id: Optional[str] = None + """默认频道 id""" + welcome_channel_id: Optional[str] = None + """欢迎频道 id""" + roles: Optional[List[Role]] = None + """角色列表""" + channels: Optional[List[Channel]] = None + """频道列表""" + + +class Quote(BaseModel): + """引用消息""" + + id_: Optional[str] = Field(None, alias="id") + """引用消息 id""" + type: Optional[int] = None + """引用消息类型""" + content: Optional[str] = None + """引用消息内容""" + create_at: Optional[int] = None + """引用消息创建时间(毫秒)""" + author: Optional[User] = None + """作者的用户信息""" + + +class Attachments(BaseModel): + """附加的多媒体数据""" + + type: Optional[str] = None + """多媒体类型""" + url: Optional[str] = None + """多媒体地址""" + name: Optional[str] = None + """多媒体名""" + size: Optional[int] = None + """大小 单位(B)""" + + +class Emoji(BaseModel): + id_: Optional[str] = Field(None, alias="id") + name: Optional[str] = None + + # 转义 unicdoe 为 emoji表情 + # @root_validator(pre=True) + # def parse_emoji(cls, values: dict): + # values['id'] = chr(int(values['id'][2:-2])) + # values['name'] = chr(int(values['name'][2:-2])) + # return values + + +class URL(BaseModel): + url: Optional[str] = None + """资源的 url""" + + +class Meta(BaseModel): + page: Optional[int] = None + page_total: Optional[int] = None + page_size: Optional[int] = None + total: Optional[int] = None + + +class ListReturn(BaseModel): + meta: Optional[Meta] = None + sort: Optional[Dict[str, Any]] = None + + +class BlackList(BaseModel): + """黑名单""" + + user_id: Optional[str] = None + """用户 id""" + created_time: Optional[int] = None + """加入黑名单的时间戳(毫秒)""" + remark: Optional[str] = None + """加入黑名单的原因""" + user: Optional[User] = None + """用户""" + + +class BlackListsReturn(ListReturn): + """获取黑名单列表返回信息""" + + blacklists: Optional[List[BlackList]] = Field(None, alias="items") + """黑名单列表""" + + +class MessageCreateReturn(BaseModel): + """发送频道消息返回信息""" + + msg_id: Optional[str] = None + """服务端生成的消息 id""" + msg_timestamp: Optional[int] = None + """消息发送时间(服务器时间戳)""" + nonce: Optional[str] = None + """随机字符串""" + + +class ChannelRoleReturn(BaseModel): + """创建或更新频道角色权限返回信息""" + + role_id: Optional[int] = None + user_id: Optional[str] = None + allow: Optional[int] = None + deny: Optional[int] = None + + +class GuildsReturn(ListReturn): + guilds: Optional[List[Guild]] = Field(None, alias="items") + + +class ChannelsReturn(ListReturn): + channels: Optional[List[Channel]] = Field(None, alias="items") + + +class GuildUsersRetrun(ListReturn): + """服务器中的用户列表""" + + users: Optional[List[User]] = Field(None, alias="items") + """用户列表""" + user_count: Optional[int] = None + """用户数量""" + online_count: Optional[int] = None + """在线用户数量""" + offline_count: Optional[int] = None + """离线用户数量""" + + +class Reaction(BaseModel): + emoji: Optional[Emoji] = None + count: Optional[int] = None + me: Optional[bool] = None + + +class MentionInfo(BaseModel): + mention_part: Optional[List[dict]] = None + mention_role_part: Optional[List[dict]] = None + channel_part: Optional[List[dict]] = None + item_part: Optional[List[dict]] = None + + +class BaseMessage(BaseModel): + id_: Optional[str] = Field(None, alias="id") + """消息 ID""" + type: Optional[int] = None + """消息类型""" + content: Optional[str] = None + """消息内容""" + embeds: Optional[List[dict]] = None + """超链接解析数据""" + attachments: Optional[Union[bool, Attachments]] = None + """附加的多媒体数据""" + create_at: Optional[int] = None + """创建时间""" + updated_at: Optional[int] = None + """更新时间""" + reactions: Optional[List[Reaction]] = None + """回应数据""" + image_name: Optional[str] = None + """""" + read_status: Optional[bool] = None + """是否已读""" + quote: Optional[Quote] = None + """引用数据""" + mention_info: Optional[MentionInfo] = None + """引用特定用户或特定角色的信息""" + + +class ChannelMessage(BaseMessage): + """频道消息""" + + author: Optional[User] = None + mention: Optional[List[Any]] = None + mention_all: Optional[bool] = None + mention_roles: Optional[List[Any]] = None + mention_here: Optional[bool] = None + + +class DirectMessage(BaseMessage): + """私聊消息""" + + author_id: Optional[str] = None + """作者的用户 ID""" + from_type: Optional[int] = None + """from_type""" + msg_icon: Optional[str] = None + """msg_icon""" + + +class ChannelMessagesReturn(BaseModel): + """获取私信聊天消息列表返回信息""" + + direct_messages: Optional[List[ChannelMessage]] = Field(None, alias="items") + + +class DirectMessagesReturn(BaseModel): + """获取私信聊天消息列表返回信息""" + + direct_messages: Optional[List[DirectMessage]] = Field(None, alias="items") + + +class ReactionUser(User): + reaction_time: Optional[int] = None + + +class TargetInfo(BaseModel): + """私聊会话 目标用户信息""" + + id_: Optional[str] = Field(None, alias="id") + """目标用户 ID""" + username: Optional[str] = None + """目标用户名""" + online: Optional[bool] = None + """是否在线""" + avatar: Optional[str] = None + """头像图片链接""" + + +class UserChat(BaseModel): + """私聊会话""" + + code: Optional[str] = None + """私信会话 Code""" + last_read_time: Optional[int] = None + """上次阅读消息的时间 (毫秒)""" + latest_msg_time: Optional[int] = None + """最新消息时间 (毫秒)""" + unread_count: Optional[int] = None + """未读消息数""" + target_info: Optional[TargetInfo] = None + """目标用户信息""" + + +class UserChatsReturn(ListReturn): + """获取私信聊天会话列表返回信息""" + + user_chats: Optional[List[UserChat]] = Field(None, alias="items") + """私聊会话列表""" + + +class RolesReturn(ListReturn): + """获取服务器角色列表返回信息""" + + roles: Optional[List[Role]] = Field(None, alias="items") + """服务器角色列表""" + + +class GuilRoleReturn(BaseModel): + """赋予或删除用户角色返回信息""" + + user_id: Optional[str] = None + """用户 id""" + guild_id: Optional[str] = None + """服务器 id""" + roles: Optional[List[int]] = None + """角色 id 的列表""" + + +class IntimacyImg(BaseModel): + """形象图片的总列表""" + + id_: Optional[str] = Field(None, alias="id") + """ 形象图片的 id""" + url: Optional[str] = None + """形象图片的地址""" + + +class IntimacyIndexReturn(BaseModel): + """获取用户亲密度返回信息""" + + img_url: Optional[str] = None + """机器人给用户显示的形象图片地址""" + social_info: Optional[str] = None + """机器人显示给用户的社交信息""" + last_read: Optional[int] = None + """用户上次查看的时间戳""" + score: Optional[int] = None + """亲密度,0-2200""" + img_list: Optional[List[IntimacyImg]] = None + """形象图片的总列表""" + + +class GuildEmoji(BaseModel): + """服务器表情""" + + name: Optional[str] = None + """表情的名称""" + id_: Optional[str] = Field(None, alias="id") + """表情的 ID""" + user_info: Optional[User] = None + """上传用户""" + + +class GuildEmojisReturn(ListReturn): + """获取服务器表情列表返回信息""" + + roles: Optional[List[GuildEmoji]] = Field(None, alias="items") + """服务器表情列表""" + + +class Invite(BaseModel): + """邀请信息""" + + guild_id: Optional[str] = None + """服务器 id""" + channel_id: Optional[str] = None + """频道 id""" + url_code: Optional[str] = None + """url code""" + url: Optional[str] = None + """地址""" + user: Optional[User] = None + """用户""" + + +class InvitesReturn(ListReturn): + """获取邀请列表返回信息""" + + roles: Optional[List[Invite]] = Field(None, alias="items") + """邀请列表""" diff --git a/iamai/adapter/kook/config.py b/iamai/adapter/kook/config.py new file mode 100644 index 00000000..a0ec8120 --- /dev/null +++ b/iamai/adapter/kook/config.py @@ -0,0 +1,27 @@ +"""Kook 适配器配置。""" + +from typing import Literal + +from iamai.config import ConfigModel + + +class Config(ConfigModel): + """Kook 配置类,将在适配器被加载时被混入到机器人主配置中。 + + Attributes: + adapter_type: 适配器类型,需要和协议端配置相同。 + reconnect_interval: 重连等待时间。 + api_timeout: 进行 API 调用时等待返回响应的超时时间。 + access_token: 鉴权密钥。 + compress: 是否启用压缩,默认为 0,(建议)不启用。 + show_raw: 是否显示原始数据,默认为 False,不显示。 + """ + + __config_name__ = "kook" + adapter_type: Literal["ws"] = "ws" + reconnect_interval: int = 3 + api_timeout: int = 1000 + access_token: str = "" + compress: Literal[0, 1] = 0 + show_raw: bool = False + report_self_message: bool = False diff --git a/iamai/adapter/kook/event.py b/iamai/adapter/kook/event.py new file mode 100644 index 00000000..e140414f --- /dev/null +++ b/iamai/adapter/kook/event.py @@ -0,0 +1,930 @@ +"""Kook 适配器事件。""" + +import asyncio +import inspect +from enum import IntEnum +from collections import UserDict +from typing import ( # type: ignore + TYPE_CHECKING, + Any, + Dict, + List, + Type, + Tuple, + Union, + Literal, + TypeVar, + Optional, +) + +from pydantic import Field, HttpUrl, BaseModel, validator, root_validator + +from iamai.event import Event + +from .api import Role, User, Emoji, Guild, Channel +from .message import KookMessage, MessageDeserializer + +if TYPE_CHECKING: + from . import KookAdapter + from .message import T_KookMSG + +T_KookEvent = TypeVar("T_KookEvent", bound="KookEvent") + + +class ResultStore: + _seq = 1 + _futures: Dict[Tuple[str, int], asyncio.Future] = {} + _sn_map = {} + + @classmethod + def set_sn(cls, self_id: str, sn: int) -> None: + cls._sn_map[self_id] = sn + + @classmethod + def get_sn(cls, self_id: str) -> int: + return cls._sn_map.get(self_id, 0) + + +class AttrDict(UserDict): + def __init__(self, data=None): + initial = dict(data) # type: ignore + for k in initial: + if isinstance(initial[k], dict): + initial[k] = AttrDict(initial[k]) # type: ignore + + super().__init__(initial) + + def __getattr__(self, name): + return self[name] + + +class PermissionOverwrite(BaseModel): + role_id: Optional[int] = None + allow: Optional[int] = None + deny: Optional[int] = None + + +class PermissionUser(BaseModel): + user: Optional[User] = None + allow: Optional[int] = None + deny: Optional[int] = None + + +class ChannelRoleInfo(BaseModel): + """频道角色权限详情""" + + permission_overwrites: Optional[List[PermissionOverwrite]] = None + """针对角色在该频道的权限覆写规则组成的列表""" + permission_users: Optional[List[PermissionUser]] = None + """针对用户在该频道的权限覆写规则组成的列表""" + permission_sync: Optional[int] = None + """权限设置是否与分组同步, 1 or 0""" + + +class Quote(BaseModel): + """引用消息""" + + id_: Optional[str] = Field(None, alias="id") + """引用消息 id""" + type: Optional[int] = None + """引用消息类型""" + content: Optional[str] = None + """引用消息内容""" + create_at: Optional[int] = None + """引用消息创建时间(毫秒)""" + author: Optional[User] = None + """作者的用户信息""" + + +class Attachments(BaseModel): + """附加的多媒体数据""" + + type: Optional[str] = None + """多媒体类型""" + url: Optional[str] = None + """多媒体地址""" + name: Optional[str] = None + """多媒体名""" + size: Optional[int] = None + """大小 单位(B)""" + + +class URL(BaseModel): + url: Optional[str] = None + """资源的 url""" + + +class Meta(BaseModel): + page: Optional[int] = None + page_total: Optional[int] = None + page_size: Optional[int] = None + total: Optional[int] = None + + +class ListReturn(BaseModel): + meta: Optional[Meta] = None + sort: Optional[Dict[str, Any]] = None + + +class BlackList(BaseModel): + """黑名单""" + + user_id: Optional[str] = None + """用户 id""" + created_time: Optional[int] = None + """加入黑名单的时间戳(毫秒)""" + remark: Optional[str] = None + """加入黑名单的原因""" + user: Optional[User] = None + """用户""" + + +class BlackListsReturn(ListReturn): + """获取黑名单列表返回信息""" + + blacklists: Optional[List[BlackList]] = Field(None, alias="items") + """黑名单列表""" + + +class MessageCreateReturn(BaseModel): + """发送频道消息返回信息""" + + msg_id: Optional[str] = None + """服务端生成的消息 id""" + msg_timestamp: Optional[int] = None + """消息发送时间(服务器时间戳)""" + nonce: Optional[str] = None + """随机字符串""" + + +class ChannelRoleReturn(BaseModel): + """创建或更新频道角色权限返回信息""" + + role_id: Optional[int] = None + user_id: Optional[str] = None + allow: Optional[int] = None + deny: Optional[int] = None + + +class GuildsReturn(ListReturn): + guilds: Optional[List[Guild]] = Field(None, alias="items") + + +class ChannelsReturn(ListReturn): + channels: Optional[List[Channel]] = Field(None, alias="items") + + +class GuildUsersRetrun(ListReturn): + """服务器中的用户列表""" + + users: Optional[List[User]] = Field(None, alias="items") + """用户列表""" + user_count: Optional[int] = None + """用户数量""" + online_count: Optional[int] = None + """在线用户数量""" + offline_count: Optional[int] = None + """离线用户数量""" + + +class Reaction(BaseModel): + emoji: Optional[Emoji] = None + count: Optional[int] = None + me: Optional[bool] = None + + +class MentionInfo(BaseModel): + mention_part: Optional[List[dict]] = None + mention_role_part: Optional[List[dict]] = None + channel_part: Optional[List[dict]] = None + item_part: Optional[List[dict]] = None + + +class BaseMessage(BaseModel): + id_: Optional[str] = Field(None, alias="id") + """消息 ID""" + type: Optional[int] = None + """消息类型""" + content: Optional[str] = None + """消息内容""" + embeds: Optional[List[dict]] = None + """超链接解析数据""" + attachments: Optional[Union[bool, Attachments]] = None + """附加的多媒体数据""" + create_at: Optional[int] = None + """创建时间""" + updated_at: Optional[int] = None + """更新时间""" + reactions: Optional[List[Reaction]] = None + """回应数据""" + image_name: Optional[str] = None + """""" + read_status: Optional[bool] = None + """是否已读""" + quote: Optional[Quote] = None + """引用数据""" + mention_info: Optional[MentionInfo] = None + """引用特定用户或特定角色的信息""" + + +class ChannelMessage(BaseMessage): + """频道消息""" + + author: Optional[User] = None + mention: Optional[List[Any]] = None + mention_all: Optional[bool] = None + mention_roles: Optional[List[Any]] = None + mention_here: Optional[bool] = None + + +class DirectMessage(BaseMessage): + """私聊消息""" + + author_id: Optional[str] = None + """作者的用户 ID""" + from_type: Optional[int] = None + """from_type""" + msg_icon: Optional[str] = None + """msg_icon""" + + +class ChannelMessagesReturn(BaseModel): + """获取私信聊天消息列表返回信息""" + + direct_messages: Optional[List[ChannelMessage]] = Field(None, alias="items") + + +class DirectMessagesReturn(BaseModel): + """获取私信聊天消息列表返回信息""" + + direct_messages: Optional[List[DirectMessage]] = Field(None, alias="items") + + +class ReactionUser(User): + reaction_time: Optional[int] = None + + +class TargetInfo(BaseModel): + """私聊会话 目标用户信息""" + + id_: Optional[str] = Field(None, alias="id") + """目标用户 ID""" + username: Optional[str] = None + """目标用户名""" + online: Optional[bool] = None + """是否在线""" + avatar: Optional[str] = None + """头像图片链接""" + + +class UserChat(BaseModel): + """私聊会话""" + + code: Optional[str] = None + """私信会话 Code""" + last_read_time: Optional[int] = None + """上次阅读消息的时间 (毫秒)""" + latest_msg_time: Optional[int] = None + """最新消息时间 (毫秒)""" + unread_count: Optional[int] = None + """未读消息数""" + target_info: Optional[TargetInfo] = None + """目标用户信息""" + + +class UserChatsReturn(ListReturn): + """获取私信聊天会话列表返回信息""" + + user_chats: Optional[List[UserChat]] = Field(None, alias="items") + """私聊会话列表""" + + +class RolesReturn(ListReturn): + """获取服务器角色列表返回信息""" + + roles: Optional[List[Role]] = Field(None, alias="items") + """服务器角色列表""" + + +class GuilRoleReturn(BaseModel): + """赋予或删除用户角色返回信息""" + + user_id: Optional[str] = None + """用户 id""" + guild_id: Optional[str] = None + """服务器 id""" + roles: Optional[List[int]] = None + """角色 id 的列表""" + + +class IntimacyImg(BaseModel): + """形象图片的总列表""" + + id_: Optional[str] = Field(None, alias="id") + """ 形象图片的 id""" + url: Optional[str] = None + """形象图片的地址""" + + +class IntimacyIndexReturn(BaseModel): + """获取用户亲密度返回信息""" + + img_url: Optional[str] = None + """机器人给用户显示的形象图片地址""" + social_info: Optional[str] = None + """机器人显示给用户的社交信息""" + last_read: Optional[int] = None + """用户上次查看的时间戳""" + score: Optional[int] = None + """亲密度,0-2200""" + img_list: Optional[List[IntimacyImg]] = None + """形象图片的总列表""" + + +class GuildEmoji(BaseModel): + """服务器表情""" + + name: Optional[str] = None + """表情的名称""" + id_: Optional[str] = Field(None, alias="id") + """表情的 ID""" + user_info: Optional[User] = None + """上传用户""" + + +class GuildEmojisReturn(ListReturn): + """获取服务器表情列表返回信息""" + + roles: Optional[List[GuildEmoji]] = Field(None, alias="items") + """服务器表情列表""" + + +class Invite(BaseModel): + """邀请信息""" + + guild_id: Optional[str] = None + """服务器 id""" + channel_id: Optional[str] = None + """频道 id""" + url_code: Optional[str] = None + """url code""" + url: Optional[str] = None + """地址""" + user: Optional[User] = None + """用户""" + + +class InvitesReturn(ListReturn): + """获取邀请列表返回信息""" + + roles: Optional[List[Invite]] = Field(None, alias="items") + """邀请列表""" + + +class EventTypes(IntEnum): + """ + 事件主要格式 + Kook 协议事件,字段与 Kook 一致。各事件字段参考 `Kook 文档` + + .. Kook 文档: + https://developer.kookapp.cn/doc/event/event-introduction#事件主要格式 + """ + + text = 1 + image = 2 + video = 3 + file = 4 + audio = 8 + kmarkdown = 9 + card = 10 + sys = 255 + + +class SignalTypes(IntEnum): + """ + 信令类型 + Kook 协议信令,字段与 Kook 一致。各事件字段参考 `Kook 文档` + + .. Kook 文档: + https://developer.kookapp.cn/doc/websocket#信令格式 + """ + + EVENT = 0 + HELLO = 1 + PING = 2 + PONG = 3 + RESUME = 4 + RECONNECT = 5 + RESUME_ACK = 6 + SYS = 255 + + +class Attachment(BaseModel): + type: str + name: str + url: HttpUrl + file_type: Optional[str] = Field(None) + size: Optional[int] = Field(None) + duration: Optional[float] = Field(None) + width: Optional[int] = Field(None) + hight: Optional[int] = Field(None) + + +class Extra(BaseModel): + type_: Union[int, str] = Field(None, alias="type") + guild_id: Optional[str] = Field(None) + channel_name: Optional[str] = Field(None) + mention: Optional[List[str]] = Field(None) + mention_all: Optional[bool] = Field(None) + mention_roles: Optional[List[str]] = Field(None) + mention_here: Optional[bool] = Field(None) + author: Optional[User] = Field(None) + body: Optional[AttrDict] = Field(None) + attachments: Optional[Attachment] = Field(None) + code: Optional[str] = Field(None) + + @validator("body", pre=True) + def convert_body(cls, v): + if v is None: + return None + + if not isinstance(v, dict): + raise TypeError("body must be dict") + if not isinstance(v, AttrDict): + v = AttrDict(v) + return v + + class Config: + arbitrary_types_allowed = True + + +class OriginEvent(Event["KookAdapter"]): + """为了区分信令中非Event事件,增加了前置OriginEvent""" + + __event__ = "" + + post_type: str + + +class Kmarkdown(BaseModel): + raw_content: str + mention_part: list + mention_role_part: list + + +class EventMessage(BaseModel): + type: Union[int, str] + guild_id: Optional[str] + channel_name: Optional[str] + mention: Optional[List] + mention_all: Optional[bool] + mention_roles: Optional[List] + mention_here: Optional[bool] + nav_channels: Optional[List] + author: User + + kmarkdown: Optional[Kmarkdown] + + code: Optional[str] = None + attachments: Optional[Attachment] = None + + content: KookMessage + + +class KookEvent(OriginEvent): + """ + 事件主要格式,来自 d 字段 + Kook 协议事件,字段与 Kook 一致。各事件字段参考 `Kook 文档` + + .. Kook 文档: + https://developer.kookapp.cn/doc/event/event-introduction + """ + + __event__ = "" + channel_type: Literal["PERSON", "GROUP"] + type_: int = Field(alias="type") + """1:文字消息\n2:图片消息\n3:视频消息\n4:文件消息\n8:音频消息\n9:KMarkdown\n10:card消息\n255:系统消息\n其它的暂未开放""" + target_id: str + """ + 发送目的\n + 频道消息类时, 代表的是频道 channel_id\n + 如果 channel_type 为 GROUP 组播且 type 为 255 系统消息时,则代表服务器 guild_id""" + author_id: Optional[str] = None + content: KookMessage + msg_id: str + msg_timestamp: int + nonce: str + extra: Extra + user_id: str + + post_type: str + self_id: Optional[str] = None # onebot兼容 + + +# Message Events +class MessageEvent(KookEvent): + """消息事件""" + + __event__ = "message" + + post_type: Literal["message"] = "message" + message_type: str # group private 其实是person + sub_type: str + event: EventMessage + + def __repr__(self) -> str: + return f'Event<{self.post_type}>: "{self.content}"' + + def get_plain_text(self) -> str: + """获取消息的纯文本内容。 + + Returns: + 消息的纯文本内容。 + """ + return self.content.get_plain_text() # type: ignore + + async def reply(self, msg: "T_KookMSG") -> Dict[str, Any]: + """回复消息。 + + Args: + msg: 回复消息的内容,同 `call_api()` 方法。 + + Returns: + API 请求响应。 + """ + raise NotImplementedError + + +class PrivateMessageEvent(MessageEvent): + """私聊消息""" + + __event__ = "message.private" + message_type: Literal["private"] + + async def reply(self, msg: "T_KookMSG") -> Dict[str, Any]: + return await self.adapter.call_api( + api="direct-message/create", target_id=self.author_id, content=msg + ) + + +class ChannelMessageEvent(MessageEvent): + """公共频道消息""" + + __event__ = "message.group" + message_type: Literal["group"] + group_id: str + + async def reply(self, msg: "T_KookMSG") -> Dict[str, Any]: + return await self.adapter.call_api( + "message/create", target_id=self.target_id, content=msg + ) + + +# Notice Events +class NoticeEvent(KookEvent): + """通知事件""" + + __event__ = "notice" + post_type: Literal["notice"] + notice_type: str + + def __repr__(self) -> str: + return f'Event<{self.post_type}>: "{self.content}"' + + +# Channel Events +class ChannelNoticeEvent(NoticeEvent): + """频道消息事件""" + + __event__ = "notice" + group_id: int + + +class ChannelAddReactionEvent(ChannelNoticeEvent): + """频道内用户添加 reaction""" + + __event__ = "notice.added_reaction" + notice_type: Literal["added_reaction"] + + +class ChannelDeletedReactionEvent(ChannelNoticeEvent): + """频道内用户删除 reaction""" + + __event__ = "notice.deleted_reaction" + notice_type: Literal["deleted_reaction"] + + +class ChannelUpdatedMessageEvent(ChannelNoticeEvent): + """频道消息更新""" + + __event__ = "notice.updated_message" + notice_type: Literal["updated_message"] + + +class ChannelDeleteMessageEvent(ChannelNoticeEvent): + """频道消息被删除""" + + __event__ = "notice.deleted_message" + notice_type: Literal["deleted_message"] + + +class ChannelAddedEvent(ChannelNoticeEvent): + """新增频道""" + + __event__ = "notice.added_channel" + notice_type: Literal["added_channel"] + + +class ChannelUpdatedEvent(ChannelNoticeEvent): + """修改频道信息""" + + __event__ = "notice.updated_channel" + notice_type: Literal["updated_channel"] + + +class ChannelDeleteEvent(ChannelNoticeEvent): + """删除频道""" + + __event__ = "notice.deleted_channel" + notice_type: Literal["deleted_channel"] + + +class ChannelPinnedMessageEvent(ChannelNoticeEvent): + """新增频道置顶消息""" + + __event__ = "notice.pinned_message" + notice_type: Literal["pinned_message"] + + +class ChannelUnpinnedMessageEvent(ChannelNoticeEvent): + """取消频道置顶消息""" + + __event__ = "notice.unpinned_message" + notice_type: Literal["unpinned_message"] + + +# Private Events +class PrivateNoticeEvent(NoticeEvent): + "私聊消息事件" + + +class PrivateUpdateMessageEvent(PrivateNoticeEvent): + """私聊消息更新""" + + __event__ = "notice.updated_private_message" + notice_type: Literal["updated_private_message"] + + +class PrivateDeleteMessageEvent(PrivateNoticeEvent): + """私聊消息删除""" + + __event__ = "notice.deleted_private_message" + notice_type: Literal["deleted_private_message"] + + +class PrivateAddReactionEvent(PrivateNoticeEvent): + """私聊内用户添加 reaction""" + + __event__ = "notice.private_added_reaction" + notice_type: Literal["private_added_reaction"] + + +class PrivateDeleteReactionEvent(PrivateNoticeEvent): + """私聊内用户取消 reaction""" + + __event__ = "notice.private_deleted_reaction" + notice_type: Literal["private_deleted_reaction"] + + +# Guild Events +class GuildNoticeEvent(NoticeEvent): + """服务器相关事件""" + + group_id: int + + def get_guild_id(self): + return self.target_id # type: ignore + + +# Guild Member Events +class GuildMemberNoticeEvent(GuildNoticeEvent): + """服务器成员相关事件""" + + pass + + +class GuildMemberIncreaseNoticeEvent(GuildMemberNoticeEvent): + """新成员加入服务器""" + + __event__ = "notice.joined_guild" + notice_type: Literal["joined_guild"] + + +class GuildMemberDecreaseNoticeEvent(GuildMemberNoticeEvent): + """服务器成员退出""" + + __event__ = "notice.exited_guild" + notice_type: Literal["exited_guild"] + + +class GuildMemberUpdateNoticeEvent(GuildMemberNoticeEvent): + """服务器成员信息更新(修改昵称)""" + + __event__ = "notice.updated_guild_member" + notice_type: Literal["updated_guild_member"] + + +class GuildMemberOnlineNoticeEvent(GuildMemberNoticeEvent): + """服务器成员上线""" + + __event__ = "notice.guild_member_online" + notice_type: Literal["guild_member_online"] + + +class GuildMemberOfflineNoticeEvent(GuildMemberNoticeEvent): + """服务器成员下线""" + + __event__ = "notice.guild_member_offline" + notice_type: Literal["guild_member_offline"] + + +# Guild Role Events +class GuildRoleNoticeEvent(GuildNoticeEvent): + """服务器角色相关事件""" + + +class GuildRoleAddNoticeEvent(GuildRoleNoticeEvent): + """服务器角色增加""" + + __event__ = "notice.added_role" + notice_type: Literal["added_role"] + + +class GuildRoleDeleteNoticeEvent(GuildRoleNoticeEvent): + """服务器角色增加""" + + __event__ = "notice.deleted_role" + notice_type: Literal["deleted_role"] + + +class GuildRoleUpdateNoticeEvent(GuildRoleNoticeEvent): + """服务器角色增加""" + + __event__ = "notice.updated_role" + notice_type: Literal["updated_role"] + + +# Guild Events +class GuildUpdateNoticeEvent(GuildNoticeEvent): + """服务器信息更新""" + + __event__ = "notice.updated_guild" + notice_type: Literal["updated_guild"] + + +class GuildDeleteNoticeEvent(GuildNoticeEvent): + """服务器删除""" + + __event__ = "notice.deleted_guild" + notice_type: Literal["deleted_guild"] + + +class GuildAddBlockListNoticeEvent(GuildNoticeEvent): + """服务器封禁用户""" + + __event__ = "notice.added_block_list" + notice_type: Literal["added_block_list"] + + +class GuildDeleteBlockListNoticeEvent(GuildNoticeEvent): + """服务器取消封禁用户""" + + __event__ = "notice.deleted_block_list" + notice_type: Literal["deleted_block_list"] + + +# User Events +class UserNoticeEvent(NoticeEvent): + """用户相关事件列表""" + + group_id: int + + +class UserJoinAudioChannelNoticeEvent(UserNoticeEvent): + """用户加入语音频道""" + + __event__ = "notice.joined_channel" + notice_type: Literal["joined_channel"] + + +class UserJoinAudioChannelEvent(UserNoticeEvent): + """用户退出语音频道""" + + __event__ = "notice.exited_channel" + notice_type: Literal["exited_channel"] + + +class UserInfoUpdateNoticeEvent(UserNoticeEvent): + """ + 用户信息更新 + + 该事件与服务器无关, 遵循以下条件: + - 仅当用户的 用户名 或 头像 变更时 + - 仅通知与该用户存在关联的用户或 Bot + a. 存在聊天会话 + b. 双方好友关系 + """ + + __event__ = "notice.user_updated" + notice_type: Literal["user_updated"] + + +class SelfJoinGuildNoticeEvent(NoticeEvent): + """ + 自己新加入服务器 + + 当自己被邀请或主动加入新的服务器时, 产生该事件 + """ + + __event__ = "notice.self_joined_guild" + notice_type: Literal["self_joined_guild"] + user_id: str + group_id: int + + +class SelfExitGuildNoticeEvent(NoticeEvent): + """ + 自己退出服务器 + + 当自己被踢出服务器或被拉黑或主动退出服务器时, 产生该事件 + """ + + __event__ = "notice.self_exited_guild" + notice_type: Literal["self_exited_guild"] + user_id: str + group_id: int + + +class CartBtnClickNoticeEvent(NoticeEvent): + """ + Card 消息中的 Button 点击事件 + """ + + __event__ = "notice.message_btn_click" + notice_type: Literal["message_btn_click"] + user_id: str + group_id: int + + +# Meta Events +class MetaEvent(OriginEvent): + """元事件""" + + __event__ = "meta_event" + post_type: Literal["meta_event"] + meta_event_type: str + + +class LifecycleMetaEvent(MetaEvent): + """生命周期元事件""" + + __event__ = "meta_event.lifecycle" + meta_event_type: Literal["lifecycle"] + sub_type: str + + +class HeartbeatMetaEvent(MetaEvent): + """心跳元事件""" + + __event__ = "meta_event.heartbeat" + meta_event_type: Literal["heartbeat"] + sub_type: str + + +# 事件类映射 +_kook_events = { + model.__event__: model + for model in globals().values() + if inspect.isclass(model) and issubclass(model, OriginEvent) +} + + +def get_event_class( + post_type: str, event_type: str, sub_type: Optional[str] = None +) -> Type[T_KookEvent]: # type: ignore + """根据接收到的消息类型返回对应的事件类。 + + Args: + post_type: 请求类型。 + event_type: 事件类型。 + sub_type: 子类型。 + + Returns: + 对应的事件类。 + """ + if sub_type is None: + return _kook_events[".".join((post_type, event_type))] # type: ignore + return ( + _kook_events.get(".".join((post_type, event_type, sub_type))) + or _kook_events[".".join((post_type, event_type))] + ) # type: ignore diff --git a/iamai/adapter/kook/exceptions.py b/iamai/adapter/kook/exceptions.py new file mode 100644 index 00000000..abb733d9 --- /dev/null +++ b/iamai/adapter/kook/exceptions.py @@ -0,0 +1,96 @@ +"""Kook 适配器异常。""" + +from typing import Optional + +from iamai.exceptions import AdapterException + + +class KookException(AdapterException): + """Kook 异常基类。""" + + +class NetworkError(KookException): + """网络异常。""" + + +class ActionFailed(KookException): + """API 请求成功响应,但响应表示 API 操作失败。""" + + def __init__(self, resp): + """ + Args: + resp: 返回的响应。 + """ + self.resp = resp + + +class ApiNotAvailable(ActionFailed): + """API 请求返回 404,表示当前请求的 API 不可用或不存在。""" + + +class ApiTimeout(KookException): + """API 请求响应超时。""" + + +class UnauthorizedException(KookException): + pass + + +class RateLimitException(KookException): + pass + + +class UnsupportedMessageType(KookException): + """ + :说明: + + 在发送不支持的消息类型时抛出,开黑啦 Bot 不支持发送音频消息。 + """ + + def __init__(self, message: str = ""): + super().__init__() + self.message = message + + def __repr__(self) -> str: + return self.message + + +class UnsupportedMessageOperation(KookException): + """ + :说明: + + 在调用不支持的 Message 或 MessageSegment 操作时抛出,例如对图片类型的 MessageSegment 使用加运算。 + """ + + def __init__(self, message: str = ""): + super().__init__() + self.message = message + + def __repr__(self) -> str: + return self.message + + +class ReconnectError(KookException): + """ + :说明: + + 服务端通知客户端, 代表该连接已失效, 请重新连接。客户端收到后应该主动断开当前连接。 + """ + + +class TokenError(KookException): + """ + :说明: + + 服务端通知客户端, 代表该连接已失效, 请重新连接。客户端收到后应该主动断开当前连接。 + """ + + def __init__(self, msg: Optional[str] = None): + super().__init__() + self.msg = msg + + def __repr__(self): + return f"" + + def __str__(self): + return self.__repr__() diff --git a/iamai/adapter/kook/message.py b/iamai/adapter/kook/message.py new file mode 100644 index 00000000..42346f2d --- /dev/null +++ b/iamai/adapter/kook/message.py @@ -0,0 +1,339 @@ +"""Kook 适配器消息。""" + +import json +from io import StringIO +from dataclasses import dataclass +from typing_extensions import override, deprecated +from typing import ( # type: ignore + Any, + Dict, + Type, + Tuple, + Union, + Mapping, + Iterable, + Optional, + cast, +) + +from iamai.log import logger +from iamai.message import Message, MessageSegment + +from .exceptions import UnsupportedMessageType, UnsupportedMessageOperation + +__all__ = [ + "T_KookMSG", + "KookMessage", + "KookMessageSegment", + "escape_kmarkdown", + "unescape_kmarkdown", +] + +T_KookMSG = Union[str, Mapping, Iterable[Mapping], "KookMessageSegment", "KookMessage"] + +ESCAPE_CHAR = "!()*-.:>[\]`~" + +msg_type_map = { + "text": 1, + "image": 2, + "video": 3, + "file": 4, + "audio": 8, + "kmarkdown": 9, + "card": 10, +} + +rev_msg_type_map = {code: msg_type for msg_type, code in msg_type_map.items()} +# 根据协议消息段类型显示消息段内容 +segment_text = { + "text": "[文字]", + "image": "[图片]", + "video": "[视频]", + "file": "[文件]", + "audio": "[音频]", + "kmarkdown": "[KMarkdown消息]", + "card": "[卡片消息]", +} + + +class KookMessage(Message["KookMessageSegment"]): + """ + Kook v3 协议 Message 适配。 + """ + + @property + def _message_segment_class(self) -> Type["KookMessageSegment"]: + return KookMessageSegment + + def _str_to_message_segment(self, msg) -> "KookMessageSegment": + return KookMessageSegment(type="text", data={"content": msg}) + + def _mapping_to_message_segment(self, msg: Mapping) -> "KookMessageSegment": + return KookMessageSegment(type=msg["type"], data=msg.get("content") or {}) + + +class KookMessageSegment(MessageSegment["KookMessage"]): + """Kook 消息字段。""" + + """ + Kook 协议 MessageSegment 适配。具体方法参考协议消息段类型或源码。 + + https://developer.kookapp.cn/doc/event/message + """ + + @property + def _message_class(self) -> Type["KookMessage"]: + return KookMessage + + def __str__(self) -> str: + if self.type in ["text", "kmarkdown"]: + return str(self.data["content"]) + elif self.type == "at": + return str(f"@{self.data['user_name']}") + else: + return segment_text.get(self.type, "[未知类型消息]") + + @classmethod + @deprecated("用 KMarkdown 语法 (met)用户id/here/all(met) 代替") + def at(cls, user_id: str) -> "KookMessageSegment": + return KookMessageSegment.KMarkdown(f"(met){user_id}(met)", user_id) + + @classmethod + def text(cls, text: str) -> "KookMessageSegment": + return cls(type="text", data={"content": text}) + + @classmethod + def image(cls, file_key: str) -> "KookMessageSegment": + return cls(type="image", data={"file_key": file_key}) + + @classmethod + def video(cls, file_key: str, title: Optional[str] = None) -> "KookMessageSegment": + return cls( + type="video", + data={ + "file_key": file_key, + "title": title, + }, + ) + + @classmethod + def file(cls, file_key: str, title: Optional[str] = None) -> "KookMessageSegment": + return cls( + "file", + { + "file_key": file_key, + "title": title, + }, + ) + + @classmethod + def audio( + cls, + file_key: str, + title: Optional[str] = None, + cover_file_key: Optional[str] = None, + ) -> "KookMessageSegment": + return cls( + type="audio", + data={ + "file_key": file_key, + "title": title, + "cover_file_key": cover_file_key, + }, + ) + + @classmethod + def KMarkdown( + cls, content: str, raw_content: Optional[str] = None + ) -> "KookMessageSegment": + """ + 构造KMarkdown消息段 + + @param content: KMarkdown消息内容(语法参考:https://developer.kookapp.cn/doc/kmarkdown) + @param raw_content: (可选)消息段的纯文本内容 + """ + if raw_content is None: + raw_content = "" + + return cls( + type="kmarkdown", data={"content": content, "raw_content": raw_content} + ) + + @classmethod + def Card(cls, content: Any) -> "KookMessageSegment": + """ + 构造卡片消息 + + @param content: KMarkdown消息内容(语法参考:https://developer.kookapp.cn/doc/cardmessage) + """ + if not isinstance(content, str): + content = json.dumps(content) + + return cls(type="card", data={"content": content}) + + @classmethod + def quote(cls, msg_id: str) -> "KookMessageSegment": + return cls(type="quote", data={"msg_id": msg_id}) + + +def _convert_to_card_message(msg: KookMessage) -> KookMessageSegment: + cards = [] + modules = [] + + for seg in msg: + if seg.type == "card": + if len(modules) != 0: + cards.append( + {"type": "card", "theme": "none", "size": "lg", "modules": modules} + ) + modules = [] + cards.extend(json.loads(seg.data["content"])) + elif seg.type == "text": + modules.append( + { + "type": "section", + "text": {"type": "plain-text", "content": seg.data["content"]}, + } + ) + elif seg.type == "kmarkdown": + modules.append( + { + "type": "section", + "text": {"type": "kmarkdown", "content": seg.data["content"]}, + } + ) + elif seg.type == "image": + modules.append( + { + "type": "container", + "elements": [{"type": "image", "src": seg.data["file_key"]}], + } + ) + elif seg.type in ("audio", "video", "file"): + mod = { + "type": seg.type, + "src": seg.data["file_key"], + } + if seg.data.get("title") is not None: + mod["title"] = seg.data["title"] + if seg.data.get("cover_file_key") is not None: + mod["cover"] = seg.data["cover_file_key"] + modules.append(mod) + else: + raise UnsupportedMessageType(seg.type) + + if len(modules) != 0: + cards.append( + {"type": "card", "theme": "none", "size": "lg", "modules": modules} + ) + + return KookMessageSegment.Card(cards) + + +@dataclass +class MessageSerializer: + """ + Kook 协议 Message 序列化器。 + """ + + message: KookMessage + + def serialize(self, for_send: bool = True) -> Tuple[int, str]: + if len(self.message) != 1: + self.message = self.message.copy() + self.message.reduce() # type: ignore + + if len(self.message) != 1: + # 转化为卡片消息发送 + return MessageSerializer( + KookMessage(_convert_to_card_message(self.message)) + ).serialize() # type: ignore + + msg_type = self.message[0].type + msg_type_code = msg_type_map[msg_type] + # bot 发送消息只支持"text", "kmarkdown", "card" + # 经测试还支持"image", "video", "file" + if msg_type in ("text", "kmarkdown", "card"): + return msg_type_code, self.message[0].data["content"] + elif msg_type in ("image", "video", "file"): + return msg_type_code, self.message[0].data["file_key"] + elif msg_type == "audio": + if not for_send: + return msg_type_code, self.message[0].data["file_key"] + else: + # 转化为卡片消息发送 + return MessageSerializer( + KookMessage(_convert_to_card_message(self.message)) + ).serialize() + else: + raise UnsupportedMessageType(msg_type) + + +@dataclass +class MessageDeserializer: + """ + Kook 协议 Message 反序列化器。 + """ + + type_code: int + data: Dict + + def __post_init__(self): + self.type = rev_msg_type_map.get(self.type_code, "") + + def deserialize(self) -> KookMessage: + if self.type == "text": + return KookMessage(KookMessageSegment.text(self.data["content"])) + elif self.type == "image": + return KookMessage(KookMessageSegment.image(self.data["content"])) + elif self.type == "video": + return KookMessage( + KookMessageSegment.video(self.data["attachments"]["url"]) + ) + elif self.type == "file": + return KookMessage(KookMessageSegment.file(self.data["attachments"]["url"])) + elif self.type == "kmarkdown": + content = self.data["content"] + raw_content = self.data["extra"]["kmarkdown"]["raw_content"] + + unescaped = unescape_kmarkdown(content) + is_plain_text = unescaped.strip() == raw_content + if not is_plain_text: + return KookMessage(KookMessageSegment.KMarkdown(content, raw_content)) + raw_content = unescaped + + return KookMessage(KookMessageSegment.text(raw_content)) + elif self.type == "card": + return KookMessage(KookMessageSegment.Card(self.data["content"])) + else: + return KookMessage(KookMessageSegment(self.type, self.data)) + + +def escape_kmarkdown(content: str): + """ + 将文本中的kmarkdown标识符进行转义 + """ + with StringIO() as f: + for c in content: + if c in ESCAPE_CHAR: + f.write("\\") + f.write(c) + return f.getvalue() + + +def unescape_kmarkdown(content: str): + """ + 去除kmarkdown中的转义字符 + """ + with StringIO() as f: + i = 0 + while i < len(content): + if content[i] == "\\": + if i + 1 < len(content) and content[i + 1] in ESCAPE_CHAR: + f.write(content[i + 1]) + i += 2 + continue + + f.write(content[i]) + i += 1 + return f.getvalue() diff --git a/iamai/adapter/red/__init__.py b/iamai/adapter/red/__init__.py new file mode 100644 index 00000000..8e968896 --- /dev/null +++ b/iamai/adapter/red/__init__.py @@ -0,0 +1,205 @@ +"""red 协议适配器。 + +本适配器适配了 red 协议。 +协议详情请参考: [RedProtocol](https://chrononeko.github.io/QQNTRedProtocol/) 。 +""" + +import os +import json +import asyncio +from uu import Error +from functools import partial +from typing import TYPE_CHECKING, Any, Dict, Literal + +import yaml +import aiohttp +from itsdangerous import exc + +from iamai.log import logger +from iamai.adapter.utils import WebSocketAdapter + +from .api import HANDLE +from .exceptions import * +from .message import RedMessage +from .config import USER_CONFIG, Config +from .event import MsgType, RedEvent, get_event_class + +if TYPE_CHECKING: + from .message import T_RedMSG # type: ignore + + +__all__ = ["RedAdapter"] + + +class RedAdapter(WebSocketAdapter[RedEvent, Config]): + """Red 协议适配器。""" + + name: str = "red" + Config = Config + config: Config + + _api_response: Dict[Any, Any] + _api_response_cond: asyncio.Condition = None # type: ignore + _api_id: int = 0 + + def __getattr__(self, item: Any): + return partial(self.call_api, item) + + async def startup(self): + """初始化适配器。""" + self.adapter_type = "ws" + self.host = self.config.host + self.port = self.config.port + self.access_token = self.config.access_token + self.reconnect_interval = self.config.reconnect_interval + self._api_response_cond = asyncio.Condition() + if self.config.auto_fill: + logger.info("Auto Detecting Chronocat Config...") + self.chronocat_config = self.get_red_config() + if not self.chronocat_config: + raise Error("Can not parse or find Chronocat Config file!") + logger.success("Succeed to Parse Chronocat Config.") + servers = self.chronocat_config["servers"] + red = servers[0] if servers[0]["type"] == "red" else servers[1] + self.port = red["port"] + self.access_token = red["token"] + logger.debug(f"token: {self.access_token}") + await super().startup() + + async def websocket_connect(self): + """创建正向 WebSocket 连接。""" + logger.info("Tying to connect to WebSocket server...") + async with self.session.ws_connect( + f"ws://{self.host}:{self.port}/", + headers=( + {"Authorization": f"Bearer {self.access_token}"} + if self.access_token + else None + ), + ) as self.websocket: + connect = { + "type": "meta::connect", + "payload": {"token": self.access_token}, + } + await self.websocket.send_json(connect) + await self.handle_websocket() + + async def handle_websocket_msg(self, msg: aiohttp.WSMessage): + """处理 WebSocket 消息。""" + msg_dict = json.loads(msg.data) + msg_data = msg_dict["payload"] + if self.config.show_raw: + logger.info(msg_data) + if msg_dict["type"] == "meta::connect": + self.self_id = msg_data.get("authData").get("account") + logger.success( + f"WebSocket connection " + f"from {msg_data.get('name')}({msg_data.get('version')}) Bot {self.self_id} accepted!" + ) + elif msg_dict["type"] == "message::recv": + msg_data = msg_data[0] + try: + data = msg_data + if msg_data.get("chatType", None): + data["post_type"] = "message" + data["message_type"] = ( + "private" if msg_data["chatType"] == 1 else "group" + ) + + if data["message_type"] == "group": + data["group_id"] = msg_data.get("peerUid") + data["sub_type"] = "normal" + if data["message_type"] == "private": + data["user_id"] = msg_data.get("peerUid") + data["sub_type"] = "group" + data["timestamp"] = msg_data.get("msgTime") + data["nick_name"] = msg_data.get("sendNickName") + data["msgId"] = msg_data.get("msgId") + try: + data["message"] = ( + msg_data.get("elements")[0] + .get("textElement") + .get("content") + ) + except: + data["message"] = msg_data.get("elements")[0].get("summary") + logger.info(f"Event Received: {data}") + # elif ( + # msg_data.get("msgType") == MsgType.system and msg_data.get("sendType") == 3 + # ): + # data["post_type"] = "notice" + # if sub_type := msg_data["elements"][0]["grayTipElement"][ + # "groupElement" + # ]: + # if sub_type["type"] == 1: + # data["notice_type"] = "member_add" + # if sub_type["type"] == 8: + # data["notice_type"] = "member_mute" + # if sub_type["type"] == 5: + # data["notice_type"] = "group_name_update" + # if xml_type := msg_data["elements"][0]["grayTipElement"]["xmlElement"]: + # if ( + # xml_type["subElementType"] == 12 + # and xml_type["busiType"] == "1" + # and xml_type["busiId"] == "10145" + # ): + # data["notice_type"] = "member_unmute" + await self.handle_red_event(data) + except Exception as e: + logger.error(f"Event Handled Error with {e!r}") + elif msg.type == aiohttp.WSMsgType.ERROR: + logger.error( + f"Websocket connection closed " + f"with exception {self.websocket.exception()!r}" + ) + + async def handle_red_event(self, msg: Dict[str, Any]): + """处理 red 事件。 + + Args: + msg: 接收到的信息。 + """ + post_type = msg.get("post_type") + event_type = msg.get(f"{post_type}_type") + sub_type = msg.get("sub_type", None) + event_class = get_event_class(post_type, event_type, sub_type) # type: ignore + red_event = event_class(adapter=self, **msg) + + await self.handle_event(red_event) + + async def call_api(self, api: str, **params) -> Dict[str, Any]: # type: ignore + url = f"http://{self.host}/{self.port}/api" + + if api not in HANDLE: + raise ValueError(f"API '{api}' is not supported.") + + sender = HANDLE[api](params) + + async with aiohttp.ClientSession() as session: + endpoint, method, payload = sender(params) + async with session.request( + method, url=f"{url}/{endpoint}", json=payload + ) as response: + return await response.json() + + @staticmethod + def get_red_config(): + if not os.path.exists(USER_CONFIG): + return None + with open(USER_CONFIG, encoding="utf-8") as f: + chronocat_config = yaml.safe_load(f.read()) + return chronocat_config + + async def send( + self, elements: "T_RedMSG", chatType: Literal["private", "group"], peerUin: int + ) -> Dict[str, Any]: + """发送消息,调用 send_message API 发送消息。""" + if chatType == "private": + return await self.send_message( + chatType=1, peerUin=peerUin, elements=RedMessage(elements) + ) + elif chatType == "group": + return await self.send_message( + chatType=2, peerUin=peerUin, elements=RedMessage(elements) + ) + raise TypeError('message_type must be "private" or "group"') diff --git a/iamai/adapter/red/config.py b/iamai/adapter/red/config.py new file mode 100644 index 00000000..e70b9359 --- /dev/null +++ b/iamai/adapter/red/config.py @@ -0,0 +1,38 @@ +"""red 协议配置。""" + +import os +from ast import List +from os.path import join +from pathlib import Path +from typing import Optional + +from iamai.config import ConfigModel + +HOME = Path(os.path.expanduser("~")) +USER_CONFIG = join(HOME, ".chronocat", "config", "chronocat.yml") + + +class Config(ConfigModel): + """red 配置类,将在适配器被加载时被混入到机器人主配置中。 + + Attributes: + adapter_type:USER_CONFIG 适配器类型,需要和协议端配置相同。 + auto_fill: 是否根据配置自动读取设置,默认开启。 + reconnect_interval: 重连等待时间。 + api_timeout: 进行 API 调用时等待返回响应的超时时间。 + access_token: 鉴权。 + show_raw: 是否显示原始数据,默认为 False,不显示。 + report_self_message: 是否上报自身消息, 默认不上报 + """ + + __config_name__ = "red" + multi_account: bool = False + account_list: list = [] + auto_fill: bool = True + reconnect_interval: int = 3 + api_timeout: int = 1000 + host: str = "localhost" + port: int = 16531 + access_token: str = "" + show_raw: bool = False + report_self_message: bool = False diff --git a/iamai/adapter/red/event.py b/iamai/adapter/red/event.py new file mode 100644 index 00000000..a9135700 --- /dev/null +++ b/iamai/adapter/red/event.py @@ -0,0 +1,436 @@ +"""Red 适配器事件。""" + +import inspect +from enum import IntEnum +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Any, Dict, Type, Union, Literal, TypeVar, Optional + +from pydantic import Field, BaseModel + +from iamai.event import Event + +from .message import T_RedMSG, RedMessage + +if TYPE_CHECKING: + from . import RedAdapter + from .message import T_RedMSG + +T_RedEvent = TypeVar("T_RedEvent", bound="RedEvent") + + +class RedEvent(Event["RedAdapter"]): + """Red 事件基类""" + + __event__ = "" + # type = Optional[str] = Field(alias="post_type") + post_type: Literal["message", "notice", "request", "meta_event"] + + +class EmojiAd(BaseModel): + url: str + desc: str + + +class EmojiMall(BaseModel): + packageId: int + emojiId: int + + +class EmojiZplan(BaseModel): + actionId: int + actionName: str + actionType: int + playerNumber: int + peerUid: str + bytesReserveInfo: str + + +class ThumbPath(BaseModel): ... + + +class TextElement(BaseModel): + content: str + atType: Literal[0, 1, 2] + atUid: str + atTinyId: str + atNtUid: str + subElementType: int + atChannelId: str + atRoleId: str + atRoleColor: int + atRoleName: str + needNotify: int + + +class RoleInfo(BaseModel): + roleId: str + name: str + color: int + + +class XMLElement(BaseModel): + busiType: str + busiId: str + c2cType: int + serviceType: int + ctrlFlag: int + content: str + templId: str + seqId: str + templParam: Any + pbReserv: str + members: Any + + +class PicElement(BaseModel): + picSubType: int + fileName: str + fileSize: str + picWidth: int + picHeight: int + original: bool + md5HexStr: str + sourcePath: str + thumbPath: ThumbPath + transferStatus: int + progress: int + picType: int + invalidState: int + fileUuid: str + fileSubId: str + thumbFileSize: int + summary: str + emojiAd: EmojiAd + emojiMall: EmojiMall + emojiZplan: EmojiZplan + + +class Element(BaseModel): + elementType: int + elementId: str + extBufForUI: str + picElement: Optional[PicElement] + textElement: Optional[TextElement] + arkElement: Optional[Any] + avRecordElement: Optional[Any] + calendarElement: Optional[Any] + faceElement: Optional[Any] + fileElement: Optional[Any] + giphyElement: Optional[Any] + + class grayTipElement: + xmlElement: XMLElement + aioOpGrayTipElement: Optional[Any] + blockGrayTipElement: Optional[Any] + buddyElement: Optional[Any] + buddyNotifyElement: Optional[Any] + emojiReplyElement: Optional[Any] + essenceElement: Optional[Any] + feedMsgElement: Optional[Any] + fileReceiptElement: Optional[Any] + groupElement: Optional[Any] + groupNotifyElement: Optional[Any] + jsonGrayTipElement: Optional[Any] + localGrayTipElement: Optional[Any] + proclamationElement: Optional[Any] + revokeElement: Optional[Any] + subElementType: Optional[Any] + + inlineKeyboardElement: Optional[Any] + liveGiftElement: Optional[Any] + markdownElement: Optional[Any] + marketFaceElement: Optional[Any] + multiForwardMsgElement: Optional[Any] + pttElement: Optional[Any] + replyElement: Optional[Any] + structLongMsgElement: Optional[Any] + textGiftElement: Optional[Any] + videoElement: Optional[Any] + walletElement: Optional[Any] + yoloGameResultElement: Optional[Any] + + +class ChatType(IntEnum): + FRIEND = 1 + GROUP = 2 + + +class OtherAdd(BaseModel): + uid: Optional[str] + name: Optional[str] + uin: Optional[str] + + +class MemberAdd(BaseModel): + showType: int + otherAdd: Optional[OtherAdd] + otherAddByOtherQRCode: Optional[Any] + otherAddByYourQRCode: Optional[Any] + youAddByOtherQRCode: Optional[Any] + otherInviteOther: Optional[Any] + otherInviteYou: Optional[Any] + youInviteOther: Optional[Any] + + +class ShutUpTarget(BaseModel): + uid: Optional[str] + card: str + name: str + role: int + uin: str + + +class ShutUp(BaseModel): + curTime: int + duration: int + admin: ShutUpTarget + member: ShutUpTarget + + +class GroupElement(BaseModel): + type: int + role: int + groupName: Optional[str] + memberUid: Optional[str] + memberNick: Optional[str] + memberRemark: Optional[str] + adminUid: Optional[str] + adminNick: Optional[str] + adminRemark: Optional[str] + createGroup: Optional[Any] + memberAdd: Optional[MemberAdd] + shutUp: Optional[ShutUp] + memberUin: Optional[str] + adminUin: Optional[str] + + +class XmlElement(BaseModel): + busiType: Optional[str] + busiId: Optional[str] + c2cType: int + serviceType: int + ctrlFlag: int + content: Optional[str] + templId: Optional[str] + seqId: Optional[str] + templParam: Optional[Any] + pbReserv: Optional[str] + members: Optional[Any] + + +class Member(BaseModel): + uid: str + qid: str + uin: str + nick: str + remark: str + cardType: int + cardName: str + role: int + avatarPath: str + shutUpTime: int + isDelete: bool + + +class Group(BaseModel): + groupCode: str + maxMember: int + memberCount: int + groupName: str + groupStatus: int + memberRole: int + isTop: bool + toppedTimestamp: str + privilegeFlag: int + isConf: bool + hasModifyConfGroupFace: bool + hasModifyConfGroupName: bool + remarkName: str + avatarUrl: str + hasMemo: bool + groupShutupExpireTime: str + personShutupExpireTime: str + discussToGroupUin: str + discussToGroupMaxMsgSeq: int + discussToGroupTime: int + + +class ImageInfo(BaseModel): + width: int + height: int + type: Optional[str] + mime: Optional[str] + wUnits: Optional[str] + hUnits: Optional[str] + + +class UploadResponse(BaseModel): + md5: str + imageInfo: Optional[ImageInfo] + fileSize: int + filePath: str + ntFilePath: str + + +class MsgType(IntEnum): + normal = 2 + may_file = 3 + system = 5 + voice = 6 + video = 7 + value8 = 8 + reply = 9 + wallet = 10 + ark = 11 + may_market = 17 + + +class MessageEvent(RedEvent): + """消息事件""" + + __event__ = "message" + post_type: Literal["message"] + message_type: Literal["private", "group"] + sub_type: Union[Literal["channel"], str] + message: RedMessage + original_message: RedMessage + + def __repr__(self) -> str: + return f'Event<{self.type}>: "{self.message}"' + + async def reply(self, msg: "T_RedMSG") -> Dict[str, Any]: + """回复消息""" + + raise NotImplementedError + + +class PrivateMessageEvent(MessageEvent): + """私聊消息事件""" + + __event__ = "message.private" + message_type: Literal["private"] + sub_type: Literal["friend", "group", "group_self", "other"] + roleType: int + + async def reply(self, msg: T_RedMSG) -> Dict[str, Any]: + return await self.adapter.send_message( + chatType=1, peerUin=self.peerUid, elements=RedMessage(msg) + ) + + +class GroupMessageEvent(MessageEvent): + """群消息事件""" + + __event__ = "message.group" + message_type: Literal["group"] + sub_type: Literal["normal", "anonymous", "notice"] + roleType: int + + async def reply(self, msg: T_RedMSG) -> Dict[str, Any]: + return await self.adapter.send_message( + chatType=2, peerUin=self.peerUid, elements=RedMessage(msg) + ) + + +class NoticeEvent(RedEvent): + __event__ = "notice" + post_type: Literal["notice"] + notice_type: str + msgId: str + msgRandom: str + msgSeq: str + cntSeq: str + chatType: ChatType + msgType: MsgType + subMsgType: int + peerUid: str + peerUin: Optional[str] + + # class Config: + # extra = "ignore" + + +class GroupNameUpdateEvent(NoticeEvent): + """群名变更事件""" + + __event__ = "notice.group_name_update" + notice_type: Literal["group_name_update"] + currentName: str + operatorUid: str + operatorName: str + + +class MemberAddEvent(NoticeEvent): + """群成员增加事件""" + + __event__ = "notice.member_add" + notice_type: Literal["member_add"] + memberUid: str + operatorUid: str + memberName: Optional[str] + + +class MemberMuteEvent(NoticeEvent): + """群成员禁言相关事件""" + + __event__ = "notice.member_mute" + notice_type: Literal["member_mute"] + start: datetime + duration: timedelta + operator: ShutUpTarget + member: ShutUpTarget + + +class MemberUnmuteEvent(NoticeEvent): + """群成员被解除禁言事件""" + + __event__ = "notice.member_unmute" + notice_type: Literal["member_unmute"] + start: datetime + duration: timedelta + operator: ShutUpTarget + member: ShutUpTarget + + +class MetaEvent(RedEvent): + """元事件""" + + __event__ = "meta_event" + post_type: Literal["meta_event"] + meta_event_type: str + + +class LifecycleMetaEvent(MetaEvent): + """生命周期""" + + __event__ = "meta_event.lifecycle" + meta_event_type: Literal["lifecycle"] + sub_type: Literal["enable", "disable", "connect"] + + +_red_events = { + model.__event__: model + for model in globals().values() + if inspect.isclass(model) and issubclass(model, RedEvent) +} + + +def get_event_class( + post_type: str, event_type: str, sub_type: Optional[str] = None +) -> Type[T_RedEvent]: + """根据接收到的消息类型返回对应的事件类。 + + Args: + post_type: 请求类型。 + event_type: 事件类型。 + sub_type: 子类型。 + + Returns: + 对应的事件类。 + """ + if sub_type is None: + return _red_events[".".join((post_type, event_type))] + return ( + _red_events.get(".".join((post_type, event_type, sub_type))) + or _red_events[".".join((post_type, event_type))] + )