Skip to content

Commit 6cc9c62

Browse files
committed
fix: add check to avoid deleting custom_components
If custom_components not identified, it's possible the custom_components directory could be deleted. Add addiontal checks to prevent this.
1 parent 8088b3d commit 6cc9c62

File tree

5 files changed

+330
-143
lines changed

5 files changed

+330
-143
lines changed

custom_components/pr_custom_component/api.py

+47-62
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import os
1616
import shutil
1717
import socket
18-
from typing import Dict, Text, Union, List
18+
from typing import Dict, List, Union
1919

2020
import aiofiles
2121
import aiohttp
@@ -43,35 +43,38 @@
4343

4444

4545
class PRCustomComponentApiClient:
46+
"""Api Client."""
47+
4648
def __init__(
4749
self,
4850
session: aiohttp.ClientSession,
4951
pull_url: yarl.URL,
50-
config_path: Text = "/config",
52+
config_path: str = "/config",
5153
) -> None:
5254
"""Initialize API client.
5355
5456
Args:
5557
session (aiohttp.ClientSession): Websession to use
56-
pull_url (yarl.URL): URL of pull request, e.g., https://github.com/home-assistant/core/pull/46558
57-
config_path (Text): base path for config, e.g., /config
58+
pull_url (yarl.URL): URL of pull request, e.g.,
59+
https://github.com/home-assistant/core/pull/46558
60+
config_path (str): base path for config, e.g., /config
5861
5962
"""
6063
self._pull_url: yarl.URL = pull_url
6164
self._session: aiohttp.ClientSession = session
62-
self._config_path: Text = config_path
63-
self._manifest: Dict[Text, Union[Text, List[Text]]] = {}
64-
self._component_name: Text = ""
65-
self._updated_at: Text = ""
66-
self._base_path: Text = ""
67-
self._update_available: Text = ""
68-
self._token: Text = ""
69-
self._headers: Dict[Text, Text] = {}
65+
self._config_path: str = config_path
66+
self._manifest: Dict[str, Union[str, List[str]]] = {}
67+
self._component_name: str = ""
68+
self._updated_at: str = ""
69+
self._base_path: str = ""
70+
self._update_available: str = ""
71+
self._token: str = ""
72+
self._headers: Dict[str, str] = {}
7073
self._auto_update: bool = False
7174
self._pull_number: int = 0
7275

7376
@property
74-
def name(self) -> Text:
77+
def name(self) -> str:
7578
"""Return the component name."""
7679
return self._component_name
7780

@@ -81,17 +84,17 @@ def pull_number(self) -> int:
8184
return self._pull_number
8285

8386
@property
84-
def updated_at(self) -> Text:
87+
def updated_at(self) -> str:
8588
"""Return the last updated time."""
8689
return self._updated_at
8790

8891
@updated_at.setter
89-
def updated_at(self, value: Text) -> None:
92+
def updated_at(self, value: str) -> None:
9093
"""Set the last updated time."""
9194
self._updated_at = value
9295

9396
@property
94-
def update_available(self) -> Text:
97+
def update_available(self) -> str:
9598
"""Return the whether an update is available."""
9699
return self._update_available
97100

@@ -100,11 +103,11 @@ def auto_update(self) -> bool:
100103
"""Return the whether an to autoupdate when available."""
101104
return self._auto_update
102105

103-
def set_token(self, token: Text = "") -> None:
106+
def set_token(self, token: str = "") -> None:
104107
"""Set auth token for GitHub to avoid rate limits.
105108
106109
Args:
107-
token (Text, optional): Authentication token from GitHub. Defaults to "".
110+
token (str, optional): Authentication token from GitHub. Defaults to "".
108111
"""
109112
if token:
110113
self._token = token
@@ -116,20 +119,21 @@ async def async_update_data(self, download: bool = False) -> dict:
116119
if not pull_json or pull_json.get("message") == "Not Found":
117120
_LOGGER.debug("No pull data found")
118121
return {}
119-
component_name: Text = ""
122+
component_name: str = ""
120123
for label in pull_json["labels"]:
121124
if label["name"].startswith("integration: "):
122125
component_name = label["name"].replace("integration: ", "")
123126
break
124127
if not component_name:
125128
_LOGGER.error("Unable to find integration in pull request")
129+
return {}
126130
else:
127131
_LOGGER.debug("Found %s integration", component_name)
128-
branch: Text = pull_json["head"]["ref"]
132+
branch: str = pull_json["head"]["ref"]
129133
pull_number: int = pull_json["number"]
130-
user: Text = pull_json["head"]["user"]["login"]
131-
path: Text = f"{COMPONENT_PATH}{component_name}"
132-
contents_url: Text = pull_json["head"]["repo"]["contents_url"]
134+
user: str = pull_json["head"]["user"]["login"]
135+
path: str = f"{COMPONENT_PATH}{component_name}"
136+
contents_url: str = pull_json["head"]["repo"]["contents_url"]
133137
url: yarl.URL = yarl.URL(contents_url.replace("{+path}", path)).with_query(
134138
{"ref": branch}
135139
)
@@ -144,7 +148,7 @@ async def async_update_data(self, download: bool = False) -> dict:
144148
"version": self._updated_at.replace("-", "."),
145149
}
146150
self._component_name = component_name
147-
component_path: Text = os.path.join(
151+
component_path: str = os.path.join(
148152
self._config_path, CUSTOM_COMPONENT_PATH, self._component_name
149153
)
150154
if not os.path.isdir(component_path):
@@ -174,12 +178,12 @@ async def async_get_patch_data(self) -> dict:
174178
).with_host(PATCH_DOMAIN)
175179
return await self.api_wrapper("get", url)
176180

177-
async def async_download(self, url: Text, path: Text) -> bool:
181+
async def async_download(self, url: str, path: str) -> bool:
178182
"""Download and save files to path.
179183
180184
Args:
181185
url (yarl.URL): Remote path to download
182-
path (Text): Local path to save to
186+
path (str): Local path to save to
183187
184188
Returns:
185189
bool: Whether saved successful
@@ -193,26 +197,6 @@ async def async_download(self, url: Text, path: Text) -> bool:
193197
isinstance(result, dict) and result.get("message") == "Not Found"
194198
):
195199
return False
196-
# [
197-
# {
198-
# "name": "__init__.py",
199-
# "path": "homeassistant/components/tesla/__init__.py",
200-
# "sha": "9e6db33d24ab50c1af1e1e2818580cc96069e076",
201-
# "size": 9220,
202-
# "url": "https://api.github.com/repos/alandtse/home-assistant/contents/homeassistant/components/tesla/__init__.py?ref=tesla_oauth_callback",
203-
# "html_url": "https://github.com/alandtse/home-assistant/blob/tesla_oauth_callback/homeassistant/components/tesla/__init__.py",
204-
# "git_url": "https://api.github.com/repos/alandtse/home-assistant/git/blobs/9e6db33d24ab50c1af1e1e2818580cc96069e076",
205-
# "download_url": "https://raw.githubusercontent.com/alandtse/home-assistant/tesla_oauth_callback/homeassistant/components/tesla/__init__.py",
206-
# "type": "file",
207-
# "content": "IiIiU3V<SNIP>\n",
208-
# "encoding": "base64",
209-
# "_links": {
210-
# "self": "https://api.github.com/repos/alandtse/home-assistant/contents/homeassistant/components/tesla/__init__.py?ref=tesla_oauth_callback",
211-
# "git": "https://api.github.com/repos/alandtse/home-assistant/git/blobs/9e6db33d24ab50c1af1e1e2818580cc96069e076",
212-
# "html": "https://github.com/alandtse/home-assistant/blob/tesla_oauth_callback/homeassistant/components/tesla/__init__.py"
213-
# }
214-
# }
215-
# ]
216200
if not result:
217201
_LOGGER.debug("%s is empty", url)
218202
return True
@@ -226,7 +210,7 @@ async def async_download(self, url: Text, path: Text) -> bool:
226210
_LOGGER.debug("Processing directory")
227211
tasks = []
228212
for file_json in result:
229-
file_path: Text = os.path.join(path, file_json["name"])
213+
file_path: str = os.path.join(path, file_json["name"])
230214
if file_json["type"] == "dir" and not os.path.isdir(file_path):
231215
_LOGGER.debug("Creating new sub directory %s", file_path)
232216
os.mkdir(file_path)
@@ -235,15 +219,9 @@ async def async_download(self, url: Text, path: Text) -> bool:
235219
return True
236220
if isinstance(result, dict):
237221
path.split(os.sep)
238-
file_name: Text = result["name"]
222+
file_name: str = result["name"]
239223
file_path = result["path"].replace(self._base_path, "")
240-
full_path: Text = os.path.join(path, file_path.lstrip(os.sep))
241-
# if len(file_path.split(os.sep)) > 1:
242-
# directory = file_path.split(os.sep)[0]
243-
# directory_path = os.path.join(path, directory.lstrip(os.sep))
244-
# if not os.path.isdir(directory_path):
245-
# _LOGGER.debug("Creating new directory %s", directory_path)
246-
# os.mkdir(directory_path)
224+
full_path: str = os.path.join(path, file_path.lstrip(os.sep))
247225
contents = base64.b64decode(result["content"].encode("utf-8"))
248226
_LOGGER.debug("Saving %s size: %s KB", full_path, result["size"] / 1000)
249227
if file_name == "manifest.json":
@@ -267,11 +245,13 @@ async def api_wrapper(
267245
self,
268246
method: str,
269247
url: Union[str, yarl.URL],
270-
data: dict = {},
271-
headers: dict = {},
248+
data: dict = None,
249+
headers: dict = None,
272250
) -> dict:
273251
"""Get information from the API."""
252+
headers = headers or {}
274253
headers = self._headers if not headers else headers
254+
data = data or {}
275255
try:
276256
async with async_timeout.timeout(TIMEOUT, loop=asyncio.get_event_loop()):
277257
if method == "get":
@@ -288,17 +268,17 @@ async def api_wrapper(
288268
_LOGGER.error("Rate limited: %s", response_json["message"])
289269
raise RateLimitException("Rate limited")
290270
return response_json
291-
elif method == "put":
271+
if method == "put":
292272
return await (
293273
await self._session.put(url, headers=headers, json=data)
294274
).json()
295275

296-
elif method == "patch":
276+
if method == "patch":
297277
return await (
298278
await self._session.patch(url, headers=headers, json=data)
299279
).json()
300280

301-
elif method == "post":
281+
if method == "post":
302282
return await (
303283
await self._session.post(url, headers=headers, json=data)
304284
).json()
@@ -331,14 +311,19 @@ async def async_delete(self) -> bool:
331311
"""Delete files from config path.
332312
333313
Args:
334-
path (Text): Delete component files.
314+
path (str): Delete component files.
335315
336316
Returns:
337317
bool: Whether delete is successful.
338318
"""
339-
component_path: Text = os.path.join(
319+
component_path: str = os.path.join(
340320
self._config_path, CUSTOM_COMPONENT_PATH, self._component_name
341321
)
322+
if not self._component_name or component_path == os.path.join(
323+
self._config_path, CUSTOM_COMPONENT_PATH
324+
):
325+
_LOGGER.warning("Component name was empty while delete was called.")
326+
return False
342327
if os.path.isdir(component_path):
343328
_LOGGER.debug("Deleting %s", component_path)
344329
try:

custom_components/pr_custom_component/config_flow.py

+12-17
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@
88
For more details about this integration, please refer to
99
https://github.com/alandtse/pr_custom_component
1010
"""
11-
from typing import Dict, Text
11+
import logging
12+
from typing import Dict, Typing
1213

1314
from homeassistant import config_entries
1415
from homeassistant.core import callback
1516
from homeassistant.helpers.aiohttp_client import async_get_clientsession
1617
import voluptuous as vol
1718
import yarl
18-
import logging
1919

2020
from . import get_hacs_token
2121
from .api import PRCustomComponentApiClient
@@ -25,9 +25,7 @@
2525
_LOGGER: logging.Logger = logging.getLogger(__package__)
2626

2727

28-
class PRCustomComponentFlowHandler( # type: ignore
29-
config_entries.ConfigFlow, domain=DOMAIN
30-
):
28+
class PRCustomComponentFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): # type: ignore
3129
"""Config flow for PRCustomComponent."""
3230

3331
VERSION = 1
@@ -57,8 +55,7 @@ async def async_step_user(self, user_input=None):
5755
return self.async_create_entry(
5856
title=result.get("name"), data=user_input
5957
)
60-
else:
61-
self._errors["base"] = "bad_pr"
58+
self._errors["base"] = "bad_pr"
6259

6360
return await self._show_config_form(user_input)
6461

@@ -67,6 +64,7 @@ async def async_step_user(self, user_input=None):
6764
@staticmethod
6865
@callback
6966
def async_get_options_flow(config_entry):
67+
"""Get the options flow."""
7068
return PRCustomComponentOptionsFlowHandler(config_entry)
7169

7270
async def _show_config_form(self, user_input): # pylint: disable=unused-argument
@@ -77,22 +75,19 @@ async def _show_config_form(self, user_input): # pylint: disable=unused-argumen
7775
errors=self._errors,
7876
)
7977

80-
async def install_integration(self, pr_url: Text) -> Dict:
78+
async def install_integration(self, pr_url: str) -> Dict:
8179
"""Return true if integration is successfully installed."""
82-
try:
83-
session = async_get_clientsession(self.hass)
84-
client = PRCustomComponentApiClient(
85-
session, yarl.URL(pr_url), self.hass.config.path()
86-
)
87-
client.set_token(get_hacs_token(self.hass))
88-
await client.async_update_data(download=True)
80+
session = async_get_clientsession(self.hass)
81+
client = PRCustomComponentApiClient(
82+
session, yarl.URL(pr_url), self.hass.config.path()
83+
)
84+
client.set_token(get_hacs_token(self.hass))
85+
if await client.async_update_data(download=True):
8986
return {
9087
"name": client.name,
9188
"update_time": client.updated_at,
9289
"pull_number": client.pull_number,
9390
}
94-
except Exception: # pylint: disable=broad-except
95-
raise
9691
return {}
9792

9893

custom_components/pr_custom_component/exceptions.py

-4
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,6 @@
1313
class PRCustomComponentException(Exception):
1414
"""Class of PR Custom Component exceptions."""
1515

16-
pass
17-
1816

1917
class RateLimitException(PRCustomComponentException):
2018
"""Class of exceptions for hitting retry limits."""
21-
22-
pass

0 commit comments

Comments
 (0)