Skip to content

Commit

Permalink
Merge pull request #12 from Querent-ai/local_storage_implementation
Browse files Browse the repository at this point in the history
Local storage implementation
  • Loading branch information
saraswatpuneet authored Aug 13, 2023
2 parents 288dc44 + 80952f7 commit f19aac8
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 33 deletions.
4 changes: 4 additions & 0 deletions querent/common/uri.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def from_well_formed(cls, uri: str) -> "Uri":
def extension(self) -> Optional[str]:
return Path(self.uri).suffix.lstrip(".")

@property
def path(self) -> str:
    """Return the part of the URI string that follows the protocol separator.

    The result is everything in ``self.uri`` after offset
    ``self.protocol_idx + len(self.PROTOCOL_SEPARATOR)``.
    """
    # presumably protocol_idx is the index at which the separator begins --
    # confirm against the rest of the class.
    path_start = self.protocol_idx + len(self.PROTOCOL_SEPARATOR)
    return self.uri[path_start:]

def as_str(self) -> str:
    """Return the URI exactly as it is stored on this object."""
    raw_uri = self.uri
    return raw_uri

Expand Down
7 changes: 6 additions & 1 deletion querent/config/storage_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
class StorageBackend(str, Enum):
    """Identifiers of the storage backends that can be configured."""

    # Members also derive from ``str``, so they compare equal to their values.
    LocalFile = "localfile"
    Redis = "redis"

class StorageBackendFlavor(str, Enum):
DigitalOcean = "do"
Garage = "garage"
Expand All @@ -22,6 +22,11 @@ class Config:
class LocalFileStorageConfig(BaseModel):
    """Settings model for the local-file storage backend."""

    # presumably the directory used as the storage root -- confirm with the consumer
    root_path: str

class RedisStorageConfig(BaseModel):
    """Settings model for the Redis storage backend."""

    # Required connection fields.
    host: str
    port: int
    # Optional; defaults to None when not supplied.
    password: Optional[str] = None

class StorageConfigWrapper(BaseModel):
backend: StorageBackend
config: Optional[BaseModel] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import shutil
from querent.common.uri import Protocol, Uri
from querent.config.storage_config import StorageBackend
from querent.storage.payload import PutPayload

from querent.storage.storage_errors import StorageError, StorageErrorKind
from querent.storage.storage_base import Storage
Expand Down Expand Up @@ -97,8 +98,10 @@ def uri(self):
return self.underlying.uri()

class LocalFileStorage(Storage):
def __init__(self, uri: Uri, root=None):
    """Bind the storage to ``uri`` and to a root directory.

    Args:
        uri: The URI this storage instance was created for.
        root: Root directory for stored files. When omitted (or falsy), the
            path component of ``uri`` is used instead.
    """
    # NOTE(review): this instance attribute has the same name as the ``uri()``
    # method defined on this class, so it shadows that method on instances --
    # confirm which of the two callers rely on and rename the other.
    self.uri = uri
    if not root:
        root = uri.path
    # Normalise to a Path so a plain-string root behaves like the Path that
    # the fallback branch has always produced.
    self.root = Path(root)
    self.cache_lock = Lock()

Expand All @@ -125,24 +128,41 @@ async def check_connectivity(self):
f"Failed to create directories at {self.root}: {e}",
)

async def put(self, path: Path, payload: PutPayload):
    """Write ``payload`` to ``path``, replacing any file already there.

    Args:
        path: Destination; resolved through ``self.full_path``.
        payload: Source of the bytes, read through ``payload.len()`` and
            ``payload.range_byte_stream(start, end)``.

    Raises:
        StorageError: With kind ``Io`` when creating the parent directory,
            reading the payload, or writing the file fails.
    """
    chunk_size = 1024
    full_path = await self.full_path(path)
    try:
        full_path.parent.mkdir(parents=True, exist_ok=True)
        payload_len = payload.len()
        # Open unconditionally: an empty payload must still create the file
        # (or truncate an existing one) instead of leaving stale content.
        with open(full_path, "wb") as file:
            for offset in range(0, payload_len, chunk_size):
                chunk = await payload.range_byte_stream(offset, offset + chunk_size)
                file.write(chunk)
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise StorageError(
            StorageErrorKind.Io,
            f"Failed to write file to {full_path}: {e}",
        ) from e

async def delete_single_file(self, relative_path):
full_path = await self.full_path(relative_path)
async def copy_to(self, path, output):
    """Copy the stored file at ``path`` into the writable object ``output``.

    ``path`` is resolved through ``self.full_path`` and the file is opened
    in binary mode.
    """
    source_path = await self.full_path(path)
    with open(source_path, "rb") as source:
        # Hand the blocking copy to asyncio.to_thread and wait for it.
        await asyncio.to_thread(shutil.copyfileobj, source, output)

async def get_slice(self, path, start, end):
    """Return the bytes in ``[start, end)`` of the stored file at ``path``.

    Args:
        path: File to read; resolved through ``self.full_path``.
        start: Offset of the first byte to return.
        end: Offset one past the last byte to return.

    Returns:
        The requested bytes. The result is shorter than ``end - start`` when
        the file ends first, and empty when ``end <= start``.
    """
    full_path = await self.full_path(path)
    # read() treats a negative size as "read until EOF", so an inverted range
    # must be clamped to zero instead of returning the rest of the file.
    length = max(end - start, 0)
    with open(full_path, "rb") as file:
        file.seek(start)
        return file.read(length)

async def get_all(self, path):
    """Return the complete contents of the stored file at ``path`` as bytes."""
    target = await self.full_path(path)
    with open(target, "rb") as handle:
        contents = handle.read()
    return contents

async def delete(self, path):
full_path = await self.full_path(path)
try:
full_path.unlink()
except FileNotFoundError:
Expand All @@ -153,17 +173,28 @@ async def delete_single_file(self, relative_path):
f"Failed to delete file {full_path}: {e}",
)

async def delete(self, path):
await self.delete_single_file(path)
async def bulk_delete(self, paths):
    """Delete each entry of ``paths`` through ``self.delete``, in order.

    Deletions are awaited one at a time, so an exception raised by one of
    them propagates and the remaining paths are not processed.
    """
    for target in paths:
        await self.delete(target)

async def exists(self, path):
    """Report whether the location ``path`` resolves to exists on disk."""
    resolved = await self.full_path(path)
    return resolved.exists()

async def file_num_bytes(self, path):
    """Return the size in bytes (``st_size``) of the stored file at ``path``."""
    resolved = await self.full_path(path)
    file_stat = resolved.stat()
    return file_stat.st_size

def uri(self):
    """Return ``self.uri`` converted to a string with ``str``."""
    # NOTE(review): __init__ assigns an instance attribute that is also named
    # ``uri``; on instances that attribute shadows this method, so
    # ``storage.uri()`` would try to call the stored object -- confirm and
    # rename one of the two.
    stored_uri = self.uri
    return str(stored_uri)

class LocalStorageFactory(StorageFactory):
    """Factory that builds ``LocalFileStorage`` instances for file URIs."""

    def backend(self) -> StorageBackend:
        """Return the backend identifier this factory serves."""
        return StorageBackend.LocalFile

    async def resolve(self, uri: Uri) -> Storage:
        """Create a ``LocalFileStorage`` rooted at the path of ``uri``.

        Args:
            uri: Location to resolve. A raw URI string is also accepted and
                wrapped in ``Uri`` first, as the earlier signature allowed.

        Returns:
            A ``LocalFileStorage`` whose root is ``Path(uri.path)``.

        Raises:
            ValueError: If the URI protocol is not ``Protocol.File``.
        """
        if not isinstance(uri, Uri):
            uri = Uri(uri)
        if uri.protocol != Protocol.File:
            raise ValueError("Unsupported protocol")
        return LocalFileStorage(uri, Path(uri.path))
52 changes: 51 additions & 1 deletion querent/storage/payload.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import io
from abc import ABC, abstractmethod
from typing import Optional
from pathlib import Path


class PutPayload(ABC):
@abstractmethod
Expand All @@ -15,6 +16,55 @@ async def byte_stream(self) -> bytes:
total_len = self.len()
return await self.range_byte_stream(0, total_len)

async def read(self, n: Optional[int] = None) -> bytes:
    """Return up to ``n`` bytes from the start of the payload.

    Args:
        n: Maximum number of bytes to return. ``None`` or a negative value
            returns the whole payload, matching the ``read(-1)`` convention
            of the standard ``io`` classes.

    Returns:
        The bytes produced by ``byte_stream()`` when everything is requested,
        otherwise by ``range_byte_stream(0, n)``.
    """
    # A negative n must not be forwarded as a range end: implementations that
    # slice would interpret it as an index counted from the end.
    if n is None or n < 0:
        return await self.byte_stream()
    return await self.range_byte_stream(0, n)

async def read_all(self) -> bytes:
    """Return the entire payload: the range from 0 to ``self.len()``."""
    return await self.range_byte_stream(0, self.len())

class BytesPayload(PutPayload):
    """Payload backed by an in-memory ``bytes`` object."""

    def __init__(self, data: bytes):
        """Store ``data`` as the payload content."""
        self.data = data

    def len(self) -> int:
        """Return the number of bytes in the payload."""
        return len(self.data)

    async def range_byte_stream(self, start: int, end: int) -> bytes:
        """Return the bytes in ``[start, end)``, clamped to the payload bounds.

        Negative offsets are clamped rather than treated as indices from the
        end, and a range whose end precedes its start yields ``b""``.
        """
        start = max(start, 0)
        # Clamp the end on both sides; a negative end would otherwise be
        # interpreted by the slice as an offset from the end of the data.
        end = max(min(end, len(self.data)), start)
        return self.data[start:end]

    async def read(self, n: Optional[int] = None) -> bytes:
        """Return up to ``n`` leading bytes; ``None`` or negative means all."""
        if n is None or n < 0:
            return await self.byte_stream()
        return await self.range_byte_stream(0, n)

    async def read_all(self) -> bytes:
        """Return the whole payload."""
        return await self.byte_stream()

class ByteStream:
    """Minimal asynchronous reader over an in-memory ``bytes`` object."""

    def __init__(self, data: bytes):
        """Store ``data`` as the content to read from."""
        self.data = data

    async def read(self, n: Optional[int] = None) -> bytes:
        """Return up to ``n`` bytes from the start of the data.

        Args:
            n: Maximum number of bytes to return. ``None`` or a negative
                value returns everything, matching the ``read(-1)``
                convention of the standard ``io`` classes.
        """
        # Slicing with a negative n would silently drop bytes from the end.
        if n is None or n < 0:
            return self.data
        return self.data[:n]


# Example usage
async def main():
    """Demonstrate reading the first four bytes through ``ByteStream``."""
    payload_data = b"test content"

    byte_stream = ByteStream(payload_data)
    result = await byte_stream.read(4)
    print(result)  # Output: b"test"


if __name__ == "__main__":
import asyncio
asyncio.run(main())
4 changes: 3 additions & 1 deletion querent/storage/storage_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
from pathlib import Path
from typing import IO

from querent.storage.payload import PutPayload

class Storage(ABC):
@abstractmethod
async def check_connectivity(self) -> None:
    """Abstract hook: concrete storages implement their connectivity check here."""
    return None

@abstractmethod
async def put(self, path: Path, payload: PutPayload) -> None:
    """Abstract hook: store ``payload`` at ``path``.

    Args:
        path: Destination of the stored object.
        payload: The ``PutPayload`` whose bytes are to be stored.
    """
    pass

@abstractmethod
Expand Down
7 changes: 3 additions & 4 deletions querent/storage/storage_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,17 @@
from querent.storage.storage_errors import StorageResolverError, StorageErrorKind
from querent.common.uri import Protocol, Uri
from querent.storage.storage_factory import StorageFactory
from querent.storage.local.local_storage import LocalStorageFactory
from querent.storage.local.local_file_storage import LocalStorageFactory
class StorageResolver:
def __init__(self):
    """Register the storage factories this resolver can dispatch to."""
    # Only the local-file backend is registered here.
    factories: Dict[StorageBackend, StorageFactory] = {}
    factories[StorageBackend.LocalFile] = LocalStorageFactory()
    self.storage_factories = factories

def resolve(self, uri_str: str) -> Optional[Storage]:
uri = Uri(uri_str)
def resolve(self, uri: Uri) -> Optional[Storage]:
backend = self._determine_backend(uri.protocol)
if backend in self.storage_factories:
return self.storage_factories[backend].resolve(uri_str)
return self.storage_factories[backend].resolve(uri)
else:
raise StorageResolverError(
StorageErrorKind.NotSupported, backend, "Unsupported backend"
Expand Down
47 changes: 37 additions & 10 deletions tests/test_local_storage.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
import tempfile
from pathlib import Path
import pytest
from querent.storage.local.local_storage import LocalFileStorage
from querent.storage.local.local_file_storage import LocalFileStorage, LocalStorageFactory
from querent.common.uri import Uri
from querent.storage.payload import PutPayload
import querent.storage.payload as querent_payload
from querent.storage.storage_resolver import StorageResolver

@pytest.fixture
def temp_dir():
Expand All @@ -12,11 +14,36 @@ def temp_dir():
temp_dir.cleanup()

def test_local_storage(temp_dir):
    """``put`` writes the payload bytes to a file below the storage root."""
    root = Path(temp_dir)
    uri = Uri("file://" + temp_dir)  # the temporary directory is the base URI
    storage = LocalFileStorage(uri, root)  # both the URI and the root are given
    payload = querent_payload.BytesPayload(b"test")
    file_path = root / "test.txt"

    # put() is a coroutine, so drive it to completion before asserting.
    asyncio.run(storage.put(file_path, payload))

    assert file_path.exists()
    assert file_path.read_bytes() == b"test"

def test_storage_resolver(temp_dir):
    """A storage resolved from a file URI can write a payload to disk."""
    base_uri = Uri("file://" + temp_dir)  # the temporary directory is the base URI
    storage_resolver = StorageResolver()

    # resolve() hands back an awaitable, so run it to obtain the storage.
    resolved_storage = asyncio.run(storage_resolver.resolve(base_uri))

    ok_payload = querent_payload.BytesPayload(b"ok")
    destination = Path(temp_dir + "/test.txt")
    asyncio.run(resolved_storage.put(destination, ok_payload))

    written_file = Path(temp_dir, "test.txt")
    assert written_file.exists()

    with open(written_file, "rb") as handle:
        written = handle.read()
    assert written == b"ok"

0 comments on commit f19aac8

Please sign in to comment.