diff --git a/querent/common/uri.py b/querent/common/uri.py index 3856b5e1..4341c9df 100644 --- a/querent/common/uri.py +++ b/querent/common/uri.py @@ -62,6 +62,10 @@ def from_well_formed(cls, uri: str) -> "Uri": def extension(self) -> Optional[str]: return Path(self.uri).suffix.lstrip(".") + @property + def path(self) -> str: + return self.uri[self.protocol_idx + len(self.PROTOCOL_SEPARATOR) :] + def as_str(self) -> str: return self.uri diff --git a/querent/config/storage_config.py b/querent/config/storage_config.py index 41c8b056..7f60f31c 100644 --- a/querent/config/storage_config.py +++ b/querent/config/storage_config.py @@ -5,7 +5,7 @@ class StorageBackend(str, Enum): LocalFile = "localfile" Redis = "redis" - + class StorageBackendFlavor(str, Enum): DigitalOcean = "do" Garage = "garage" @@ -22,6 +22,11 @@ class Config: class LocalFileStorageConfig(BaseModel): root_path: str +class RedisStorageConfig(BaseModel): + host: str + port: int + password: Optional[str] = None + class StorageConfigWrapper(BaseModel): backend: StorageBackend config: Optional[BaseModel] = None diff --git a/querent/storage/local/local_storage.py b/querent/storage/local/local_file_storage.py similarity index 73% rename from querent/storage/local/local_storage.py rename to querent/storage/local/local_file_storage.py index ef93cc8d..d20cca5f 100644 --- a/querent/storage/local/local_storage.py +++ b/querent/storage/local/local_file_storage.py @@ -6,6 +6,7 @@ import shutil from querent.common.uri import Protocol, Uri from querent.config.storage_config import StorageBackend +from querent.storage.payload import PutPayload from querent.storage.storage_errors import StorageError, StorageErrorKind from querent.storage.storage_base import Storage @@ -97,8 +98,10 @@ def uri(self): return self.underlying.uri() class LocalFileStorage(Storage): - def __init__(self, uri, root): + def __init__(self, uri: Uri, root=None): self.uri = uri + if not root: + root = Path(uri.path) self.root = root self.cache_lock = Lock() @@ -125,24 +128,41 @@ async def check_connectivity(self): f"Failed to create directories at {self.root}: {e}", ) - async def put(self, path, payload): + async def put(self, path: Path, payload: PutPayload): full_path = await self.full_path(path) parent_dir = full_path.parent try: parent_dir.mkdir(parents=True, exist_ok=True) - with tempfile.NamedTemporaryFile(dir=parent_dir, delete=False) as temp_file: - temp_path = Path(temp_file.name) - temp_file.close() - await asyncio.to_thread(shutil.copyfileobj, payload.byte_stream(), temp_path) - temp_path.rename(full_path) + payload_len = payload.len() + if payload_len > 0: + with open(full_path, "wb") as file: + for i in range(0, payload_len, 1024): + chunk = await payload.range_byte_stream(i, i + 1024) + file.write(chunk) except Exception as e: raise StorageError( StorageErrorKind.Io, f"Failed to write file to {full_path}: {e}", ) - async def delete_single_file(self, relative_path): - full_path = await self.full_path(relative_path) + async def copy_to(self, path, output): + full_path = await self.full_path(path) + with open(full_path, "rb") as file: + await asyncio.to_thread(shutil.copyfileobj, file, output) + + async def get_slice(self, path, start, end): + full_path = await self.full_path(path) + with open(full_path, "rb") as file: + file.seek(start) + return file.read(end - start) + + async def get_all(self, path): + full_path = await self.full_path(path) + with open(full_path, "rb") as file: + return file.read() + + async def delete(self, path): + full_path = await self.full_path(path) try: full_path.unlink() except FileNotFoundError: @@ -153,17 +173,28 @@ async def delete_single_file(self, relative_path): f"Failed to delete file {full_path}: {e}", ) - async def delete(self, path): - await self.delete_single_file(path) + async def bulk_delete(self, paths): + for path in paths: + await self.delete(path) + + async def exists(self, path): + full_path = await self.full_path(path) + return full_path.exists() + + async def file_num_bytes(self, path): + full_path = await self.full_path(path) + return full_path.stat().st_size + + def uri(self): + return str(self.uri) class LocalStorageFactory(StorageFactory): def backend(self) -> StorageBackend: return StorageBackend.LocalFile - async def resolve(self, uri: str) -> Storage: - parsed_uri = Uri(uri) # Ensure you have the Uri class imported and defined - if parsed_uri.protocol == Protocol.File: - root_path = Path(parsed_uri.path) + async def resolve(self, uri: Uri) -> Storage: + if uri.protocol == Protocol.File: + root_path = Path(uri.path) return LocalFileStorage(uri, root_path) else: - raise ValueError("Unsupported protocol") \ No newline at end of file + raise ValueError("Unsupported protocol") diff --git a/querent/storage/payload.py b/querent/storage/payload.py index c7cdc7ca..3292e96f 100644 --- a/querent/storage/payload.py +++ b/querent/storage/payload.py @@ -1,6 +1,7 @@ +import io from abc import ABC, abstractmethod from typing import Optional -from pathlib import Path + class PutPayload(ABC): @abstractmethod @@ -15,6 +16,55 @@ async def byte_stream(self) -> bytes: total_len = self.len() return await self.range_byte_stream(0, total_len) + async def read(self, n: Optional[int] = None) -> bytes: + if n is None: + return await self.byte_stream() + return await self.range_byte_stream(0, n) + async def read_all(self) -> bytes: total_len = self.len() return await self.range_byte_stream(0, total_len) + +class BytesPayload(PutPayload): + def __init__(self, data: bytes): + self.data = data + + def len(self) -> int: + return len(self.data) + + async def range_byte_stream(self, start: int, end: int) -> bytes: + start = max(start, 0) + end = min(end, len(self.data)) + return self.data[start:end] + + async def read(self, n: Optional[int] = None) -> bytes: + if n is None: + return await self.byte_stream() + return await self.range_byte_stream(0, n) + + async def read_all(self) -> bytes: + return await self.byte_stream() + +class ByteStream: + def __init__(self, data: bytes): + self.data = data + + async def read(self, n: Optional[int] = None) -> bytes: + if n is None: + return self.data + return self.data[:n] + + +# Example usage +async def main(): + payload_data = b"test content" + payload = BytesPayload(payload_data) + + byte_stream = ByteStream(payload_data) + result = await byte_stream.read(4) + print(result) # Output: b"test" + + +if __name__ == "__main__": + import asyncio + asyncio.run(main()) diff --git a/querent/storage/storage_base.py b/querent/storage/storage_base.py index d5f937cd..c4e7f325 100644 --- a/querent/storage/storage_base.py +++ b/querent/storage/storage_base.py @@ -2,13 +2,15 @@ from pathlib import Path from typing import IO +from querent.storage.payload import PutPayload + class Storage(ABC): @abstractmethod async def check_connectivity(self) -> None: pass @abstractmethod - async def put(self, path: Path, payload) -> None: + async def put(self, path: Path, payload: PutPayload) -> None: pass @abstractmethod diff --git a/querent/storage/storage_resolver.py b/querent/storage/storage_resolver.py index fb991a37..106102ba 100644 --- a/querent/storage/storage_resolver.py +++ b/querent/storage/storage_resolver.py @@ -4,18 +4,17 @@ from querent.storage.storage_errors import StorageResolverError, StorageErrorKind from querent.common.uri import Protocol, Uri from querent.storage.storage_factory import StorageFactory -from querent.storage.local.local_storage import LocalStorageFactory +from querent.storage.local.local_file_storage import LocalStorageFactory class StorageResolver: def __init__(self): self.storage_factories: Dict[StorageBackend, StorageFactory] = { StorageBackend.LocalFile: LocalStorageFactory(), } - def resolve(self, uri_str: str) -> Optional[Storage]: - uri = Uri(uri_str) + def resolve(self, uri: Uri) -> Optional[Storage]: backend = self._determine_backend(uri.protocol) if backend in self.storage_factories: - return self.storage_factories[backend].resolve(uri_str) + return self.storage_factories[backend].resolve(uri) else: raise StorageResolverError( StorageErrorKind.NotSupported, backend, "Unsupported backend" diff --git a/tests/test_local_storage.py b/tests/test_local_storage.py index cf9bfd3e..4ecd4af5 100644 --- a/tests/test_local_storage.py +++ b/tests/test_local_storage.py @@ -1,9 +1,11 @@ +import asyncio import tempfile from pathlib import Path import pytest -from querent.storage.local.local_storage import LocalFileStorage +from querent.storage.local.local_file_storage import LocalFileStorage, LocalStorageFactory from querent.common.uri import Uri -from querent.storage.payload import PutPayload +import querent.storage.payload as querent_payload +from querent.storage.storage_resolver import StorageResolver @pytest.fixture def temp_dir(): @@ -12,11 +14,36 @@ def temp_dir(): temp_dir.cleanup() def test_local_storage(temp_dir): - storage = LocalFileStorage(temp_dir) - uri = Uri("file://test.txt") - payload = PutPayload(b"test") - storage.put(uri, payload) - assert Path(temp_dir, uri.path).exists() - assert storage.get(uri).payload == b"test" - storage.delete(uri) - assert not Path(temp_dir, uri.path).exists() + uri = Uri("file://" + temp_dir) # Use the temp_dir as the base URI + storage = LocalFileStorage(uri, Path(temp_dir)) # Provide the 'uri' argument only + payload = querent_payload.BytesPayload(b"test") + + print(f"Temp dir: {temp_dir}") + print(f"URI: {uri}") + + # Put the payload + asyncio.run(storage.put(Path(temp_dir + "/test.txt"), payload)) + file_path = Path(temp_dir, "test.txt") + print(f"File path: {file_path}") + + assert file_path.exists() + with open(file_path, "rb") as file: + content = file.read() + print(f"File content: {content.decode('utf-8')}") + assert content == b"test" + +def test_storage_resolver(temp_dir): + uri = Uri("file://" + temp_dir) # Use the temp_dir as the base URI + resolver = StorageResolver() + + storage = asyncio.run(resolver.resolve(uri)) + + payload = querent_payload.BytesPayload(b"ok") + asyncio.run(storage.put(Path(temp_dir + "/test.txt"), payload)) + + file_path = Path(temp_dir, "test.txt") + assert file_path.exists() + + with open(file_path, "rb") as file: + content = file.read() + assert content == b"ok" \ No newline at end of file