Skip to content
This repository has been archived by the owner on Oct 11, 2023. It is now read-only.

Commit

Permalink
AWS Option for LIGHT data model storage (#305)
Browse files Browse the repository at this point in the history
* Some initial transitions over to model pool

* Moving initialization code out from where it occurred

* Wiring more of the system together

* Adding opt for reranked generative

* Filling out UserDB

* Abstract enforce get first

* Clearer argument - num_turns

* Using enums in DB

* Initial episode data model

* Update content loggers to use episode formatting

* Updating tables to work with testing

* Fixing some test changes

* Fixing small warnings that were noise during tests

* Moving default log path

* Test fix

* Correcting math thanks to Kurt

* Updating env DB classes to SQLAlchemy

* Name keys and Elems coded

* Adding arbitrary node attributes

* First complete pass of EnvDB

* Mypy fixings

* Fixing agents

* Writing some tests

* Finishing tests for object and room creates and queries

* Edge testing

* Arbitrary attributes testing

* Quests and testing

* And finally, DBGraph tests

* fixing episode change

* TODO function

* final mypy fixes

* DBID testing

* a -> either a or an depending on aeiou

* adding WorldConfig to hold complex configuration vars

* Moving episode_db into relevant GraphBuilders

* Game launches, but not logging

* Local BaseDB, now saving episodes

* Missing files

* deleting miscommit

* test fix

* Migrating to UserDB

* No more LIGHTDatabase in TornadoServer

* Fixing tests

* Works after testing locally

* Updated messaging for unimplemented

* Upgrading OneRoomGraphBuilder to ModelPool

* Completing (almost) the rest of Modelpool references

* Works without loading models in play_map

* Model pool actually works

* Safety working as well

* removing prints

* Fixing some tests, skipping starspace

* Runs on server too

* Creating LIGHT's ModelServer

* Undo server change

* tornado simplicity

* Handling for inline candidate models

* But regular models should also work without this

* Async... all of the things...

* Async the server too

* clearing up async server tests

* Correct async mock

* internalize init_world

* clean up tornado usage

* small GameInstance bug

* small GameInstance bug

* Some deploy fixes

* now using aws as a storage backend

* Moving safety model to async part

* Some safety fixes

* test fixes

* silly elif fix
  • Loading branch information
JackUrb authored Aug 22, 2022
1 parent bf1b367 commit a372a2e
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 53 deletions.
6 changes: 0 additions & 6 deletions deploy/web/server/game_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,6 @@

from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
from light.data_model.db.environment import EpisodeDB
from light.world.world import WorldConfig

from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
from light.data_model.db.environment import EpisodeDB
from light.world.world import WorldConfig
Expand Down
71 changes: 59 additions & 12 deletions light/data_model/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,18 @@ class LightDBConfig:
file_root: Optional[str] = DEFAULT_LOG_PATH


@dataclass
class LightAWSDBConfig(LightDBConfig):
    # Config for the AWS storage backend: structured data goes to a
    # Postgres instance (connected over port 5432 in BaseDB.__init__),
    # and files go to an S3 bucket.
    backend: str = "aws-postgres"
    file_root: str = MISSING  # name of the S3 bucket used as the filestore
    db_address: str = MISSING  # hostname of the Postgres instance
    db_user: str = MISSING  # Postgres login user
    db_pass: str = MISSING  # Postgres login password


cs = ConfigStore.instance()
cs.store(name="config1", node=LightDBConfig)
cs.store(name="db/base", node=LightDBConfig)
cs.store(name="db/aws-postgres", node=LightAWSDBConfig)


class DBStatus(Enum):
Expand Down Expand Up @@ -88,7 +98,6 @@ def __init__(self, config: "DictConfig"):
Create this database, either connecting to a remote host or local
files and instances.
"""
# TODO replace with a swappable engine that persists the data
self.backend = config.backend
if config.backend == "test":
self.engine = create_engine("sqlite+pysqlite:///:memory:", future=True)
Expand All @@ -101,6 +110,28 @@ def __init__(self, config: "DictConfig"):
self.file_root = config.file_root
db_path = os.path.join(self.file_root, f"{self.DB_TYPE}.db")
self.engine = create_engine(f"sqlite:////{db_path}")
elif config.backend == "aws-postgres":
try:
import psycopg2
import boto3
except ImportError:
print(
"For aws-postgres usage, you must also `pip install mysqlclient boto3 psycopg2-binary`"
)
raise
# Get DB registered and functioning
self.db_address = config.db_address
db_address = config.db_address
login_user = config.db_user
login_pass = config.db_pass
self.engine = create_engine(
f"postgresql://{login_user}:{login_pass}@{db_address}:5432/postgres"
)

# Connect to the s3 filestore
self.file_root = config.file_root # file root is a s3 bucket address
s3 = boto3.resource("s3")
self.bucket = s3.Bucket(self.file_root)
else:
raise NotImplementedError(
f"Provided backend {config.backend} doesn't exist"
Expand Down Expand Up @@ -136,8 +167,21 @@ def file_path_exists(self, file_path: str) -> bool:
if self.backend in ["test", "local"]:
full_path = os.path.join(self.file_root, file_path)
return os.path.exists(full_path)
elif self.backend in ["aws-postgres"]:
import botocore

try:
self.bucket.Object(file_path).load()
return True
except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] == "404":
# The object does not exist.
return False
else:
# Something else has gone wrong.
raise
else:
raise NotImplementedError
raise NotImplementedError(f"Backend {self.backend} is not implemented")

def write_data_to_file(
self, data: Union[str, Dict[str, Any]], filename: str, json_encode: bool = False
Expand All @@ -154,8 +198,12 @@ def write_data_to_file(
json.dump(data, target_file)
else:
target_file.write(data)
elif self.backend in ["aws-postgres"]:
if json_encode:
data = json.dumps(data)
self.bucket.Object(filename).put(Body=data)
else:
raise NotImplementedError()
raise NotImplementedError(f"Backend {self.backend} is not implemented")

def read_data_from_file(
self, filename: str, json_encoded: bool = False
Expand All @@ -171,15 +219,14 @@ def read_data_from_file(
return json.load(target_file)
else:
return target_file.read()
elif self.backend in ["aws-postgres"]:
data = self.bucket.Object(filename).get()["Body"]
if json_encoded:
return json.loads(data)
else:
return data
else:
raise NotImplementedError()

def open_file(self):
try:
file = open(self.file_name, "w")
yield file
finally:
file.close()
raise NotImplementedError(f"Backend {self.backend} is not implemented")

def shutdown(self):
if self.backend == "test":
Expand Down
8 changes: 0 additions & 8 deletions light/graph/builders/one_room_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,6 @@
from omegaconf import MISSING, DictConfig
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from light.data_model.light_database import LIGHTDatabase
from light.registry.model_pool import ModelPool

from dataclasses import dataclass, field
from omegaconf import MISSING, DictConfig
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from light.data_model.light_database import LIGHTDatabase
from light.registry.model_pool import ModelPool
Expand Down
24 changes: 6 additions & 18 deletions light/registry/model_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from dataclasses import dataclass, field
from omegaconf import MISSING, DictConfig
import asyncio
import enum

from light.registry.parlai_model import ParlAIModelConfig, ParlAIModelLoader
from light.registry.parlai_remote_model import (
Expand Down Expand Up @@ -40,23 +39,12 @@
}


class ModelTypeName(enum.Enum):
"""Common model names of use in LIGHT, for use in register_model"""

SAFETY = "safety" # Models used to evaluate dialog or env safety
DIALOG = "dialog" # Models for generating dialogue
SCORING = "role_playing_score" # Models to score player utterances
ACTION = "action" # Models used by model agents for generating actions
GENERIC_ACTS = "generic_action" # Models to select a next action from cands
PARSER = "parser" # Models to parse raw text to in-game actions


class ModelPool:
def __init__(self):
self._model_loaders = {}

async def register_model_async(
self, config: Union[DictConfig, ModelConfig], model_names: List[ModelTypeName]
self, config: Union[DictConfig, ModelConfig], model_names: List[str]
) -> None:
"""
Takes the given config, loads the model, and
Expand All @@ -70,7 +58,7 @@ async def register_model_async(
loader = loader_class(config)
await loader.force_load()
for model_name in model_names:
self._model_loaders[model_name.value] = loader
self._model_loaders[model_name] = loader

def register_model(
self, config: Union[DictConfig, ModelConfig], model_names: List[str]
Expand All @@ -80,21 +68,21 @@ def register_model(
"""
return asyncio.run(self.register_model_async(config, model_names))

def has_model(self, model_name: ModelTypeName) -> bool:
def has_model(self, model_name: str) -> bool:
"""
Determine if there's a model registered for the given name.
"""
return model_name.value in self._model_loaders
return model_name in self._model_loaders

def get_model(
self, model_name: ModelTypeName, overrides: Optional[Dict[str, Any]] = None
self, model_name: str, overrides: Optional[Dict[str, Any]] = None
) -> Agent:
"""
Get a copy of the model stored in the given name
If overrides are provided, pass those to the loader as well
"""
loader = self._model_loaders.get(model_name.value)
loader = self._model_loaders.get(model_name)
if loader is None:
raise AssertionError(
f"No models registered for requested name {model_name}"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
{
"model": "projects.light_whoami.agents.expanded_attention:ExpandedDecoderAttentionAndPacerAgent",
"predictor_model_file": "zoo:light_whoami/rpa_reranker/model",
"model_file": "$LIGHT_MODEL_ROOT/dialog/baseline/model",
"model_file": "zoo:light_whoami/profile_expanded_attention_128/model",
"inference": "beam",
"datatype": "valid",
"beam_context_block_ngram": 3,
"beam_block_ngram": 3,
"beam_size": 10,
"beam_min_length": 20,
"skip_generation": false,
"interactive_mode": true
}
20 changes: 12 additions & 8 deletions light/registry/parlai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,15 +126,19 @@ async def load_model(self, config: DictConfig) -> None:
opt["override"].update(overrides)
model = create_agent(opt)

context_fill = opt.get("truncate", CONTEXT_FILL_COUNT)
# Push something through the model to fill context
act = {
"text": INIT_CONTEXT + "Hello " * CONTEXT_FILL_COUNT,
"episode_done": True,
}
if opt.get("eval_candidates") == "inline":
act["label_candidates"] = ["hi", "hi there", "whatup"]
model.observe(act)
await model.act()
try:
act = {
"text": INIT_CONTEXT + "Hello " * context_fill,
"episode_done": True,
}
if opt.get("eval_candidates") == "inline":
act["label_candidates"] = ["hi", "hi there", "whatup"]
model.observe(act)
await model.act()
except Exception as e:
print(f"Cannot warm model {opt['model']}, hit error {e}")

# Share the model params for use in `get_model`
self._shared = model.share()
Expand Down

0 comments on commit a372a2e

Please sign in to comment.