Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 99 additions & 2 deletions homeassistant/components/media_source/local_source.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,37 @@
"""Local Media Source Implementation."""
from __future__ import annotations

import logging
import mimetypes
from pathlib import Path

from aiohttp import web
from aiohttp.web_request import FileField
from aioshutil import shutil
import voluptuous as vol

from homeassistant.components.http import HomeAssistantView
from homeassistant.components.media_player.const import MEDIA_CLASS_DIRECTORY
from homeassistant.components.media_player.errors import BrowseError
from homeassistant.core import HomeAssistant, callback
from homeassistant.util import raise_if_invalid_path
from homeassistant.exceptions import Unauthorized
from homeassistant.util import raise_if_invalid_filename, raise_if_invalid_path

from .const import DOMAIN, MEDIA_CLASS_MAP, MEDIA_MIME_TYPES
from .error import Unresolvable
from .models import BrowseMediaSource, MediaSource, MediaSourceItem, PlayMedia

MAX_UPLOAD_SIZE = 1024 * 1024 * 10
LOGGER = logging.getLogger(__name__)


@callback
def async_setup(hass: HomeAssistant) -> None:
"""Set up local media source."""
source = LocalSource(hass)
hass.data[DOMAIN][DOMAIN] = source
hass.http.register_view(LocalMediaView(hass, source))
hass.http.register_view(UploadMediaView(hass, source))


class LocalSource(MediaSource):
Expand All @@ -43,11 +52,14 @@ def async_full_path(self, source_dir_id: str, location: str) -> Path:
@callback
def async_parse_identifier(self, item: MediaSourceItem) -> tuple[str, str]:
"""Parse identifier."""
if item.domain != DOMAIN:
raise Unresolvable("Unknown domain.")

if not item.identifier:
# Empty source_dir_id and location
return "", ""

source_dir_id, location = item.identifier.split("/", 1)
source_dir_id, _, location = item.identifier.partition("/")
if source_dir_id not in self.hass.config.media_dirs:
raise Unresolvable("Unknown source directory.")

Expand Down Expand Up @@ -217,3 +229,88 @@ async def get(
raise web.HTTPNotFound()

return web.FileResponse(media_path)


class UploadMediaView(HomeAssistantView):
"""View to upload images."""

url = "/api/media_source/local_source/upload"
name = "api:media_source:local_source:upload"

def __init__(self, hass: HomeAssistant, source: LocalSource) -> None:
"""Initialize the media view."""
self.hass = hass
self.source = source
self.schema = vol.Schema(
{
"media_content_id": str,
"file": FileField,
}
)

async def post(self, request: web.Request) -> web.Response:
"""Handle upload."""
if not request["hass_user"].is_admin:
raise Unauthorized()

# Increase max payload
request._client_max_size = MAX_UPLOAD_SIZE # pylint: disable=protected-access

try:
data = self.schema(dict(await request.post()))
except vol.Invalid as err:
LOGGER.error("Received invalid upload data: %s", err)
raise web.HTTPBadRequest() from err

try:
item = MediaSourceItem.from_uri(self.hass, data["media_content_id"])
except ValueError as err:
LOGGER.error("Received invalid upload data: %s", err)
raise web.HTTPBadRequest() from err

try:
source_dir_id, location = self.source.async_parse_identifier(item)
except Unresolvable as err:
LOGGER.error("Invalid local source ID")
raise web.HTTPBadRequest() from err

uploaded_file: FileField = data["file"]

if not uploaded_file.content_type.startswith(("image/", "video/")):
LOGGER.error("Content type not allowed")
raise vol.Invalid("Only images and video are allowed")

try:
raise_if_invalid_filename(uploaded_file.filename)
except ValueError as err:
LOGGER.error("Invalid filename")
raise web.HTTPBadRequest() from err

try:
await self.hass.async_add_executor_job(
self._move_file,
self.source.async_full_path(source_dir_id, location),
uploaded_file,
)
except ValueError as err:
LOGGER.error("Moving upload failed: %s", err)
raise web.HTTPBadRequest() from err

return self.json(
{"media_content_id": f"{data['media_content_id']}/{uploaded_file.filename}"}
)

def _move_file( # pylint: disable=no-self-use
self, target_dir: Path, uploaded_file: FileField
) -> None:
"""Move file to target."""
if not target_dir.is_dir():
raise ValueError("Target is not an existing directory")

target_path = target_dir / uploaded_file.filename

target_path.relative_to(target_dir)
raise_if_invalid_path(str(target_path))

with target_path.open("wb") as target_fp:
shutil.copyfileobj(uploaded_file.file, target_fp)
122 changes: 122 additions & 0 deletions tests/components/media_source/test_local_source.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
"""Test Local Media Source."""
from http import HTTPStatus
import io
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest.mock import patch

import pytest

Expand All @@ -9,6 +13,20 @@
from homeassistant.setup import async_setup_component


@pytest.fixture
async def temp_dir(hass):
"""Return a temp dir."""
with TemporaryDirectory() as tmpdirname:
target_dir = Path(tmpdirname) / "another_subdir"
target_dir.mkdir()
await async_process_ha_core_config(
hass, {"media_dirs": {"test_dir": str(target_dir)}}
)
assert await async_setup_component(hass, const.DOMAIN, {})

yield str(target_dir)


async def test_async_browse_media(hass):
"""Test browse media."""
local_media = hass.config.path("media")
Expand Down Expand Up @@ -102,3 +120,107 @@ async def test_media_view(hass, hass_client):

resp = await client.get("/media/recordings/test.mp3")
assert resp.status == HTTPStatus.OK


async def test_upload_view(hass, hass_client, temp_dir, hass_admin_user):
"""Allow uploading media."""

img = (Path(__file__).parent.parent / "image/logo.png").read_bytes()

def get_file(name):
pic = io.BytesIO(img)
pic.name = name
return pic

client = await hass_client()

# Test normal upload
res = await client.post(
"/api/media_source/local_source/upload",
data={
"media_content_id": "media-source://media_source/test_dir/.",
"file": get_file("logo.png"),
},
)

assert res.status == 200
assert (Path(temp_dir) / "logo.png").is_file()

# Test with bad media source ID
for bad_id in (
# Subdir doesn't exist
"media-source://media_source/test_dir/some-other-dir",
# Main dir doesn't exist
"media-source://media_source/test_dir2",
# Location is invalid
"media-source://media_source/test_dir/..",
# Domain != media_source
"media-source://nest/test_dir/.",
# Completely something else
"http://bla",
):
res = await client.post(
"/api/media_source/local_source/upload",
data={
"media_content_id": bad_id,
"file": get_file("bad-source-id.png"),
},
)

assert res.status == 400
assert not (Path(temp_dir) / "bad-source-id.png").is_file()

# Test invalid POST data
res = await client.post(
"/api/media_source/local_source/upload",
data={
"media_content_id": "media-source://media_source/test_dir/.",
"file": get_file("invalid-data.png"),
"incorrect": "format",
},
)

assert res.status == 400
assert not (Path(temp_dir) / "invalid-data.png").is_file()

# Test invalid content type
text_file = io.BytesIO(b"Hello world")
text_file.name = "hello.txt"
res = await client.post(
"/api/media_source/local_source/upload",
data={
"media_content_id": "media-source://media_source/test_dir/.",
"file": text_file,
},
)

assert res.status == 400
assert not (Path(temp_dir) / "hello.txt").is_file()

# Test invalid filename
with patch(
"aiohttp.formdata.guess_filename", return_value="../invalid-filename.png"
):
res = await client.post(
"/api/media_source/local_source/upload",
data={
"media_content_id": "media-source://media_source/test_dir/.",
"file": get_file("../invalid-filename.png"),
},
)

assert res.status == 400
assert not (Path(temp_dir) / "../invalid-filename.png").is_file()

# Remove admin access
hass_admin_user.groups = []
res = await client.post(
"/api/media_source/local_source/upload",
data={
"media_content_id": "media-source://media_source/test_dir/.",
"file": get_file("no-admin-test.png"),
},
)

assert res.status == 401
assert not (Path(temp_dir) / "no-admin-test.png").is_file()