|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +import asyncio |
| 3 | +import logging |
| 4 | +import time |
| 5 | +import traceback |
| 6 | + |
| 7 | +from .data_objects import Image, RedditImage, TagCollection, WikiHowImage, Ban, BanIterator |
| 8 | +from .errors import APIError |
| 9 | +from .events import BanEvent, UnBanEvent |
| 10 | +from .http import krequest, Route |
| 11 | + |
| 12 | +logger = logging.getLogger() |
| 13 | + |
| 14 | + |
| 15 | +class Client: |
| 16 | + """ |
| 17 | + .. _aiohttp session: https://aiohttp.readthedocs.io/en/stable/client_reference.html#client-session |
| 18 | +
|
| 19 | + Client object for KSOFT.SI API |
| 20 | +
|
| 21 | + This is a client object for KSoft.Si API. Here are two versions. Basic without discord.py bot |
| 22 | + and a pluggable version that inserts this client object directly into your discord.py bot. |
| 23 | +
|
| 24 | +
|
| 25 | + Represents a client connection that connects to ksoft.si. It works in two modes: |
| 26 | + 1. As a standalone variable. |
| 27 | + 2. Plugged-in to discord.py Bot or AutoShardedBot, see :any:`Client.pluggable` |
| 28 | +
|
| 29 | + Parameters |
| 30 | + ------------- |
| 31 | + api_key: :class:`str` |
| 32 | + Your ksoft.si api token. |
| 33 | + Specify different base url. |
| 34 | + **bot: Bot or AutoShardedBot |
| 35 | + Your bot client from discord.py |
| 36 | + **loop: asyncio loop |
| 37 | + Your asyncio loop. |
| 38 | + **low_memory: bool[Optional] |
| 39 | + Low memory mode, save images to files instead of memory. WIP |
| 40 | +
|
| 41 | + """ |
| 42 | + |
| 43 | + def __init__(self, api_key: str, bot=None, loop=asyncio.get_event_loop(), **kwargs): |
| 44 | + self.api_key = api_key |
| 45 | + self._loop = loop |
| 46 | + self.http = krequest(global_headers=[ |
| 47 | + ("Authorization", f"NANI {self.api_key}") |
| 48 | + ], loop=self._loop, lowmem=kwargs.get("lowmem_mode")) |
| 49 | + self.bot = bot |
| 50 | + |
| 51 | + self._ban_hook = [] |
| 52 | + self._last_update = time.time() - 60 * 10 |
| 53 | + |
| 54 | + if self.bot is not None: |
| 55 | + self.bot.loop.create_task(self._ban_updater) |
| 56 | + |
| 57 | + logger_string = str("NANI") + self.api_key[-4:].rjust(len(self.api_key), "*") |
| 58 | + logger.info(f"KSOFT API Logging in as {logger_string}") |
| 59 | + |
| 60 | + def register_ban_hook(self, func): |
| 61 | + if func not in self._ban_hook: |
| 62 | + logger.info("Registered event hook", func.__name__) |
| 63 | + self._ban_hook.append(func) |
| 64 | + |
| 65 | + def unregister_ban_hook(self, func): |
| 66 | + if func in self._ban_hook: |
| 67 | + logger.info("Unregistered event hook", func.__name__) |
| 68 | + self._ban_hook.remove(func) |
| 69 | + |
| 70 | + async def _dispatch_ban_event(self, event): |
| 71 | + logger.info('Dispatching event of type %s to %d hooks', event.__class__.__name__, len(self._ban_hook)) |
| 72 | + for hook in self._ban_hook: |
| 73 | + await hook(event) |
| 74 | + |
| 75 | + async def _ban_updater(self): |
| 76 | + await self.bot.wait_until_ready() |
| 77 | + while not self.bot.is_closed(): |
| 78 | + try: |
| 79 | + if self._ban_hook: |
| 80 | + r = await self.http.request(Route.bans("GET", "/updates"), params={"timestamp": self._last_update}) |
| 81 | + self._last_update = time.time() |
| 82 | + for b in r['data']: |
| 83 | + if b['active'] is True: |
| 84 | + await self._dispatch_ban_event(BanEvent(**b)) |
| 85 | + else: |
| 86 | + await self._dispatch_ban_event(UnBanEvent(**b)) |
| 87 | + except Exception as e: |
| 88 | + logger.error("Error in the ban update loop: %s" % e) |
| 89 | + traceback.print_exc() |
| 90 | + finally: |
| 91 | + await asyncio.sleep(60 * 5) |
| 92 | + |
| 93 | + @classmethod |
| 94 | + def pluggable(cls, bot, api_key: str, *args, **kwargs): |
| 95 | + """ |
| 96 | + Pluggable version of Client. Inserts Client directly into your Bot client. |
| 97 | + Called by using `bot.ksoft` |
| 98 | +
|
| 99 | +
|
| 100 | + Parameters |
| 101 | + ------------- |
| 102 | + bot: discord.ext.commands.Bot or discord.ext.commands.AutoShardedBot |
| 103 | + Your bot client from discord.py |
| 104 | + api_key: :class:`str` |
| 105 | + Your ksoft.si api token. |
| 106 | +
|
| 107 | +
|
| 108 | + .. note:: |
| 109 | + Takes the same parameters as :class:`Client` class. |
| 110 | + Usage changes to ``bot.ksoft``. (``bot`` is your bot client variable) |
| 111 | +
|
| 112 | + """ |
| 113 | + try: |
| 114 | + return bot.ksoft |
| 115 | + except AttributeError: |
| 116 | + bot.ksoft = cls(api_key, bot=bot, *args, **kwargs) |
| 117 | + return bot.ksoft |
| 118 | + |
| 119 | + async def random_image(self, tag: str, nsfw: bool = False) -> Image: |
| 120 | + """|coro| |
| 121 | +
|
| 122 | + This function gets a random image from the specified tag. |
| 123 | +
|
| 124 | + Parameters |
| 125 | + ------------ |
| 126 | + tag: :class:`str` |
| 127 | + Image tag from string. |
| 128 | + nsfw: :class:`bool` |
| 129 | + If to display NSFW images. |
| 130 | +
|
| 131 | +
|
| 132 | + :return: :class:`ksoftapi.data_objects.Image` |
| 133 | +
|
| 134 | + """ |
| 135 | + g = await self.http.request(Route.meme("GET", "/random-image"), params={"tag": tag, "nsfw": nsfw}) |
| 136 | + return Image(**g) |
| 137 | + |
| 138 | + async def random_meme(self) -> RedditImage: |
| 139 | + """|coro| |
| 140 | +
|
| 141 | + This function gets a random meme from multiple sources from reddit. |
| 142 | +
|
| 143 | +
|
| 144 | +
|
| 145 | + :return: :class:`ksoftapi.data_objects.RedditImage` |
| 146 | +
|
| 147 | + """ |
| 148 | + g = await self.http.request(Route.meme("GET", "/random-meme")) |
| 149 | + return RedditImage(**g) |
| 150 | + |
| 151 | + async def random_aww(self) -> RedditImage: |
| 152 | + """|coro| |
| 153 | +
|
| 154 | + This function gets a random cute pictures from multiple sources from reddit. |
| 155 | +
|
| 156 | +
|
| 157 | +
|
| 158 | + :return: :class:`ksoftapi.data_objects.RedditImage` |
| 159 | +
|
| 160 | + """ |
| 161 | + g = await self.http.request(Route.meme("GET", "/random-aww")) |
| 162 | + return RedditImage(**g) |
| 163 | + |
| 164 | + async def random_wikihow(self) -> WikiHowImage: |
| 165 | + """|coro| |
| 166 | +
|
| 167 | + This function gets a random WikiHow image. |
| 168 | +
|
| 169 | +
|
| 170 | +
|
| 171 | + :return: :class:`ksoftapi.data_objects.WikiHowImage` |
| 172 | +
|
| 173 | + """ |
| 174 | + g = await self.http.request(Route.meme("GET", "/random-wikihow")) |
| 175 | + return WikiHowImage(**g) |
| 176 | + |
| 177 | + async def random_reddit(self, subreddit: str) -> RedditImage: |
| 178 | + """|coro| |
| 179 | +
|
| 180 | + This function gets a random post from specified subreddit. |
| 181 | +
|
| 182 | +
|
| 183 | +
|
| 184 | + :return: :class:`ksoftapi.data_objects.RedditImage` |
| 185 | +
|
| 186 | + """ |
| 187 | + g = await self.http.request(Route.meme("GET", "/rand-reddit/{subreddit}", subreddit=subreddit)) |
| 188 | + return RedditImage(**g) |
| 189 | + |
| 190 | + async def tags(self) -> TagCollection: |
| 191 | + """|coro| |
| 192 | +
|
| 193 | + This function gets all available tags on the api. |
| 194 | +
|
| 195 | +
|
| 196 | +
|
| 197 | + :return: :class:`ksoftapi.data_objects.TagCollection` |
| 198 | +
|
| 199 | + """ |
| 200 | + g = await self.http.request(Route.meme("GET", "/tags")) |
| 201 | + return TagCollection(**g) |
| 202 | + |
| 203 | + # BANS |
| 204 | + async def bans_add(self, user_id: int, reason: str, proof: str, **kwargs): |
| 205 | + arg_params = ["mod", "user_name", "user_discriminator", "appeal_possible"] |
| 206 | + data = { |
| 207 | + "user": user_id, |
| 208 | + "reason": reason, |
| 209 | + "proof": proof |
| 210 | + } |
| 211 | + for arg, val in kwargs.items(): |
| 212 | + if arg in arg_params: |
| 213 | + data.update({arg: val}) |
| 214 | + else: |
| 215 | + raise ValueError(f"unknown parameter: {arg}") |
| 216 | + r = await self.http.request(Route.bans("POST", "/add"), data=data) |
| 217 | + if r.get("success", False) is True: |
| 218 | + return True |
| 219 | + else: |
| 220 | + raise APIError(**r) |
| 221 | + |
| 222 | + async def bans_check(self, user_id: int) -> bool: |
| 223 | + r = await self.http.request(Route.bans("GET", "/check"), params={"user": user_id}) |
| 224 | + if r.get("is_banned", None) is not None: |
| 225 | + return r['is_banned'] |
| 226 | + else: |
| 227 | + raise APIError(**r) |
| 228 | + |
| 229 | + async def bans_info(self, user_id: int) -> Ban: |
| 230 | + r = await self.http.request(Route.bans("GET", "/info"), params={"user": user_id}) |
| 231 | + if r.get("is_ban_active", None) is not None: |
| 232 | + return Ban(**r) |
| 233 | + else: |
| 234 | + raise APIError(**r) |
| 235 | + |
| 236 | + async def bans_remove(self, user_id: int) -> bool: |
| 237 | + r = await self.http.request(Route.bans("DELETE", "/remove"), params={"user": user_id}) |
| 238 | + if r.get("done", None) is not None: |
| 239 | + return True |
| 240 | + else: |
| 241 | + raise APIError(**r) |
| 242 | + |
| 243 | + def ban_get_list_iterator(self): |
| 244 | + return BanIterator(self, Route.bans("GET", "/list")) |
0 commit comments