Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 122 additions & 10 deletions python/triton/runtime/cache.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import importlib
import json
import os
import random
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Optional
from typing import Dict, List, Optional
import hashlib


def default_cache_dir():
Expand All @@ -27,10 +29,6 @@ def __init__(self, key):
def get_file(self, filename) -> Optional[str]:
pass

@abstractmethod
def has_file(self, filename) -> bool:
pass

@abstractmethod
def put(self, data, filename, binary=True) -> str:
pass
Expand Down Expand Up @@ -70,20 +68,20 @@ def __init__(self, key, override=False, dump=False):
def _make_path(self, filename) -> str:
return os.path.join(self.cache_dir, filename)

def has_file(self, filename) -> bool:
def _has_file(self, filename) -> bool:
if not self.cache_dir:
raise RuntimeError("Could not create or locate cache dir")
return os.path.exists(self._make_path(filename))

def get_file(self, filename) -> Optional[str]:
if self.has_file(filename):
if self._has_file(filename):
return self._make_path(filename)
else:
return None

def get_group(self, filename: str) -> Optional[Dict[str, str]]:
grp_filename = f"__grp__{filename}"
if not self.has_file(grp_filename):
if not self._has_file(grp_filename):
return None
grp_filepath = self._make_path(grp_filename)
with open(grp_filepath) as f:
Expand Down Expand Up @@ -130,6 +128,122 @@ def put(self, data, filename, binary=True) -> str:
return filepath


class RemoteCacheBackend:
"""
A backend implementation for accessing a remote/distributed cache.
"""

def __init__(self, key: str):
pass

@abstractmethod
def get(self, filenames: List[str]) -> Dict[str, bytes]:
pass

@abstractmethod
def put(self, filename: str, data: bytes):
pass


class RedisRemoteCacheBackend(RemoteCacheBackend):

def __init__(self, key):
import redis
self._key = key
self._key_fmt = os.environ.get("TRITON_REDIS_KEY_FORMAT", "triton:{key}:{filename}")
self._redis = redis.Redis(
host=os.environ.get("TRITON_REDIS_HOST", "localhost"),
port=int(os.environ.get("TRITON_REDIS_PORT", 6379)),
)

def _get_key(self, filename: str) -> str:
return self._key_fmt.format(key=self._key, filename=filename)

def get(self, filenames: List[str]) -> Dict[str, str]:
results = self._redis.mget([self._get_key(f) for f in filenames])
return {filename: result for filename, result in zip(filenames, results) if result is not None}

def put(self, filename: str, data: bytes) -> Dict[str, bytes]:
self._redis.set(self._get_key(filename), data)


class RemoteCacheManager(CacheManager):

def __init__(self, key, override=False, dump=False):
# Setup backend pointed too by `TRITON_REMOTE_CACHE_BACKEND`.
remote_cache_manager = os.environ["TRITON_REMOTE_CACHE_BACKEND"]
module_path, clz_nme = remote_cache_manager.split(":")
module = importlib.import_module(module_path)
remote_cache_cls = getattr(module, clz_nme)
self._backend = remote_cache_cls(key)

self._override = override
self._dump = dump

# Use a `FileCacheManager` to materialize remote cache paths locally.
self._file_cache_manager = FileCacheManager(key, override=override, dump=dump)

def _materialize(self, filename: str, data: bytes):
# We use a backing `FileCacheManager` to provide the materialized data.
return self._file_cache_manager.put(data, filename, binary=True)

def get_file(self, filename: str) -> Optional[str]:
# We don't handle the dump/override cases.
if self._dump or self._override:
return self._file_cache_manager.get_file(filename)

# We always check the remote cache backend -- even if our internal file-
# based cache has the item -- to make sure LRU accounting works as
# expected.
results = self._backend.get([filename])
if len(results) == 0:
return None
(_, data), = results.items()
return self._materialize(filename, data)

def put(self, data, filename: str, binary=True) -> str:
# We don't handle the dump/override cases.
if self._dump or self._override:
return self._file_cache_manager.put(data, filename, binary=binary)

if not isinstance(data, bytes):
data = str(data).encode("utf-8")
self._backend.put(filename, data)
return self._materialize(filename, data)

def get_group(self, filename: str) -> Optional[Dict[str, str]]:
# We don't handle the dump/override cases.
if self._dump or self._override:
return self._file_cache_manager.get_group(filename)

grp_filename = f"__grp__{filename}"
grp_filepath = self.get_file(grp_filename)
if grp_filepath is None:
return None
with open(grp_filepath) as f:
grp_data = json.load(f)
child_paths = grp_data.get("child_paths", None)

result = None

# Found group data.
if child_paths is not None:
result = {}
for child_path, data in self._backend.get(child_paths).items():
result[child_path] = self._materialize(child_path, data)

return result

def put_group(self, filename: str, group: Dict[str, str]):
# We don't handle the dump/override cases.
if self._dump or self._override:
return self._file_cache_manager.put_group(filename, group)

grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))})
grp_filename = f"__grp__{filename}"
return self.put(grp_contents, grp_filename)


__cache_cls = FileCacheManager
__cache_cls_nme = "DEFAULT"

Expand All @@ -142,8 +256,6 @@ def get_cache_manager(key) -> CacheManager:
global __cache_cls_nme

if user_cache_manager is not None and user_cache_manager != __cache_cls_nme:
import importlib

module_path, clz_nme = user_cache_manager.split(":")
module = importlib.import_module(module_path)
__cache_cls = getattr(module, clz_nme)
Expand Down