Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor model class and runners to be more independent #494

Open
wants to merge 26 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions clarifai/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,20 @@ def model():
)
def upload(model_path, download_checkpoints, skip_dockerfile):
"""Upload a model to Clarifai."""
from clarifai.runners.models import model_upload
from clarifai.runners.models.model_builder import upload_model
upload_model(model_path, download_checkpoints, skip_dockerfile)

model_upload.main(model_path, download_checkpoints, skip_dockerfile)

@model.command()
@click.option(
'--model_path',
type=click.Path(exists=True),
required=True,
help='Path to the model directory.')
def download_checkpoints(model_path):
"""Download remote checkpoints that are specified in the config."""
from clarifai.runners.models.model_builder import ModelBuilder
ModelBuilder(model_path).download_checkpoints()


@model.command()
Expand Down
4 changes: 2 additions & 2 deletions clarifai/runners/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from .models.base_typed_model import AnyAnyModel, TextInputModel, VisualInputModel
from .models.model_builder import ModelBuilder
from .models.model_runner import ModelRunner
from .models.model_upload import ModelUploader
from .utils.data_handler import InputDataHandler, OutputDataHandler

__all__ = [
"ModelRunner",
"ModelUploader",
"ModelBuilder",
"InputDataHandler",
"OutputDataHandler",
"AnyAnyModel",
Expand Down
4 changes: 2 additions & 2 deletions clarifai/runners/models/base_typed_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from google.protobuf import json_format

from ..utils.data_handler import InputDataHandler, OutputDataHandler
from .model_runner import ModelRunner
from .model_class import ModelClass


class AnyAnyModel(ModelRunner):
class AnyAnyModel(ModelClass):

def load_model(self):
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import importlib
import inspect
import os
import re
import sys
Expand All @@ -13,6 +15,7 @@
from rich.markup import escape

from clarifai.client import BaseClient
from clarifai.runners.models.model_class import ModelClass
from clarifai.runners.utils.const import (AVAILABLE_PYTHON_IMAGES, AVAILABLE_TORCH_IMAGES,
CONCEPTS_REQUIRED_MODEL_TYPE, DEFAULT_PYTHON_VERSION,
PYTHON_BASE_IMAGE, TORCH_BASE_IMAGE)
Expand All @@ -28,7 +31,7 @@ def _clear_line(n: int = 1) -> None:
print(LINE_UP, end=LINE_CLEAR, flush=True)


class ModelUploader:
class ModelBuilder:

def __init__(self, folder: str, validate_api_ids: bool = True, download_validation_only=False):
"""
Expand All @@ -52,6 +55,56 @@ def __init__(self, folder: str, validate_api_ids: bool = True, download_validati
self.inference_compute_info = self._get_inference_compute_info()
self.is_v3 = True # Do model build for v3

def create_model_instance(self, load_model=True):
"""
Create an instance of the model class, as specified in the config file.
"""
class_config = self.config.get("class_info", {})

model_file = class_config.get("file_path")
if model_file:
model_file = os.path.join(self.folder, model_file)
if not os.path.exists(model_file):
raise Exception(f"Model file {model_file} does not exist.")
else:
# look for default model.py file location
for loc in ["model.py", "1/model.py"]:
model_file = os.path.join(self.folder, loc)
if os.path.exists(model_file):
break
if not os.path.exists(model_file):
raise Exception("Model file not found.")

module_name = os.path.basename(model_file).replace(".py", "")

spec = importlib.util.spec_from_file_location(module_name, model_file)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)

if class_config.get("class_name"):
model_class = getattr(module, class_config["class_name"])
else:
# Find all classes in the model.py file that are subclasses of ModelClass
classes = [
cls for _, cls in inspect.getmembers(module, inspect.isclass)
if issubclass(cls, ModelClass) and cls.__module__ == module.__name__
]
# Ensure there is exactly one subclass of BaseRunner in the model.py file
if len(classes) != 1:
raise Exception(
"Could not determine model class. Please specify it in the config with class_info.model_class."
)
model_class = classes[0]

model_args = class_config.get("args", {})

# initialize the model class with the args.
model = model_class(**model_args)
if load_model:
model.load_model()
return model

def _validate_folder(self, folder):
if folder == ".":
folder = "" # will getcwd() next which ends with /
Expand Down Expand Up @@ -564,19 +617,19 @@ def monitor_model_build(self):
return False


def main(folder, download_checkpoints, skip_dockerfile):
uploader = ModelUploader(folder)
def upload_model(folder, download_checkpoints, skip_dockerfile):
builder = ModelBuilder(folder)
if download_checkpoints:
uploader.download_checkpoints()
builder.download_checkpoints()
if not skip_dockerfile:
uploader.create_dockerfile()
exists = uploader.check_model_exists()
builder.create_dockerfile()
exists = builder.check_model_exists()
if exists:
logger.info(
f"Model already exists at {uploader.model_url}, this upload will create a new version for it."
f"Model already exists at {builder.model_url}, this upload will create a new version for it."
)
else:
logger.info(f"New model will be created at {uploader.model_url} with it's first version.")
logger.info(f"New model will be created at {builder.model_url} with it's first version.")

input("Press Enter to continue...")
uploader.upload_model_version(download_checkpoints)
builder.upload_model_version(download_checkpoints)
3 changes: 1 addition & 2 deletions clarifai/runners/models/model_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ def stream_wrapper(self, request: service_pb2.PostModelOutputsRequest
"""This method is used for input/output proto data conversion and yield outcome"""
return self.stream(request)

@abstractmethod
def load_model(self):
raise NotImplementedError("load_model() not implemented")
pass

@abstractmethod
def predict(self,
Expand Down
73 changes: 17 additions & 56 deletions clarifai/runners/models/model_run_locally.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import hashlib
import importlib.util
import inspect
import os
import platform
import shutil
Expand All @@ -14,9 +12,8 @@

from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
from clarifai_protocol import BaseRunner

from clarifai.runners.models.model_upload import ModelUploader
from clarifai.runners.models.model_builder import ModelBuilder
from clarifai.runners.utils.url_fetcher import ensure_urls_downloaded
from clarifai.utils.logging import logger

Expand All @@ -27,9 +24,9 @@ def __init__(self, model_path):
self.model_path = model_path
self.requirements_file = os.path.join(self.model_path, "requirements.txt")

# ModelUploader contains multiple useful methods to interact with the model
self.uploader = ModelUploader(self.model_path, download_validation_only=True)
self.config = self.uploader.config
# ModelBuilder contains multiple useful methods to interact with the model
self.builder = ModelBuilder(self.model_path, download_validation_only=True)
self.config = self.builder.config

def _requirements_hash(self):
"""Generate a hash of the requirements file."""
Expand Down Expand Up @@ -91,38 +88,10 @@ def install_requirements(self):
self.clean_up()
sys.exit(1)

def _get_model_runner(self):
"""Dynamically import the runner class from the model file."""

# import the runner class that to be implement by the user
runner_path = os.path.join(self.model_path, "1", "model.py")

# arbitrary name given to the module to be imported
module = "runner_module"

spec = importlib.util.spec_from_file_location(module, runner_path)
runner_module = importlib.util.module_from_spec(spec)
sys.modules[module] = runner_module
spec.loader.exec_module(runner_module)

# Find all classes in the model.py file that are subclasses of BaseRunner
classes = [
cls for _, cls in inspect.getmembers(runner_module, inspect.isclass)
if issubclass(cls, BaseRunner) and cls.__module__ == runner_module.__name__
]

# Ensure there is exactly one subclass of BaseRunner in the model.py file
if len(classes) != 1:
raise Exception("Expected exactly one subclass of BaseRunner, found: {}".format(
len(classes)))

MyRunner = classes[0]
return MyRunner

def _build_request(self):
"""Create a mock inference request for testing the model."""

model_version_proto = self.uploader.get_model_version_proto()
model_version_proto = self.builder.get_model_version_proto()
model_version_proto.id = "model_version"

return service_pb2.PostModelOutputsRequest(
Expand All @@ -142,8 +111,8 @@ def _build_stream_request(self):
for i in range(1):
yield request

def _run_model_inference(self, runner):
"""Perform inference using the runner."""
def _run_model_inference(self, model):
"""Perform inference using the model."""
request = self._build_request()
stream_request = self._build_stream_request()

Expand All @@ -152,7 +121,7 @@ def _run_model_inference(self, runner):
generate_response = None
stream_response = None
try:
predict_response = runner.predict(request)
predict_response = model.predict(request)
except NotImplementedError:
logger.info("Model does not implement predict() method.")
except Exception as e:
Expand All @@ -172,7 +141,7 @@ def _run_model_inference(self, runner):
logger.info(f"Model Prediction succeeded: {predict_response}")

try:
generate_response = runner.generate(request)
generate_response = model.generate(request)
except NotImplementedError:
logger.info("Model does not implement generate() method.")
except Exception as e:
Expand All @@ -194,7 +163,7 @@ def _run_model_inference(self, runner):
f"Model Prediction succeeded for generate and first response: {generate_first_res}")

try:
stream_response = runner.stream(stream_request)
stream_response = model.stream(stream_request)
except NotImplementedError:
logger.info("Model does not implement stream() method.")
except Exception as e:
Expand All @@ -217,16 +186,10 @@ def _run_model_inference(self, runner):

def _run_test(self):
"""Test the model locally by making a prediction."""
# construct MyRunner which will call load_model()
MyRunner = self._get_model_runner()
runner = MyRunner(
runner_id="n/a",
nodepool_id="n/a",
compute_cluster_id="n/a",
user_id="n/a",
)
# Create the model
model = self.builder.create_model_instance()
# send an inference.
self._run_model_inference(runner)
self._run_model_inference(model)

def test_model(self):
"""Test the model by running it locally in the virtual environment."""
Expand Down Expand Up @@ -274,7 +237,7 @@ def run_model_server(self, port=8080):

command = [
self.python_executable, "-m", "clarifai.runners.server", "--model_path", self.model_path,
"--start_dev_server", "--port",
"--grpc", "--port",
str(port)
]
try:
Expand Down Expand Up @@ -383,9 +346,7 @@ def run_docker_container(self,
# Add the image name
cmd.append(image_name)
# update the CMD to run the server
cmd.extend(
["--model_path", "/app/model_dir/main", "--start_dev_server", "--port",
str(port)])
cmd.extend(["--model_path", "/app/model_dir/main", "--grpc", "--port", str(port)])
# Run the container
process = subprocess.Popen(cmd,)
logger.info(
Expand Down Expand Up @@ -518,11 +479,11 @@ def main(model_path,
)
sys.exit(1)
manager = ModelRunLocally(model_path)
manager.uploader.download_checkpoints()
manager.builder.download_checkpoints()
if inside_container:
if not manager.is_docker_installed():
sys.exit(1)
manager.uploader.create_dockerfile()
manager.builder.create_dockerfile()
image_tag = manager._docker_hash()
image_name = f"{manager.config['model']['id']}:{image_tag}"
container_name = manager.config['model']['id']
Expand Down
14 changes: 6 additions & 8 deletions clarifai/runners/models/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,14 @@
from .model_class import ModelClass


class ModelRunner(BaseRunner, ModelClass, HealthProbeRequestHandler):
class ModelRunner(BaseRunner, HealthProbeRequestHandler):
"""
This is a subclass of the runner class which will handle only the work items relevant to models.

It is also a subclass of ModelClass so that any subclass of ModelRunner will need to just
implement predict(), generate() and stream() methods and load_model() if needed.
"""

def __init__(
self,
model: ModelClass,
runner_id: str,
nodepool_id: str,
compute_cluster_id: str,
Expand All @@ -43,7 +41,7 @@ def __init__(
num_parallel_polls,
**kwargs,
)
self.load_model()
self.model = model

# After model load successfully set the health probe to ready and startup
HealthProbeRequestHandler.is_ready = True
Expand Down Expand Up @@ -83,7 +81,7 @@ def runner_item_predict(self,
request = runner_item.post_model_outputs_request
ensure_urls_downloaded(request)

resp = self.predict_wrapper(request)
resp = self.model.predict_wrapper(request)
successes = [o.status.code == status_code_pb2.SUCCESS for o in resp.outputs]
if all(successes):
status = status_pb2.Status(
Expand Down Expand Up @@ -113,7 +111,7 @@ def runner_item_generate(
request = runner_item.post_model_outputs_request
ensure_urls_downloaded(request)

for resp in self.generate_wrapper(request):
for resp in self.model.generate_wrapper(request):
successes = []
for output in resp.outputs:
if not output.HasField('status') or not output.status.code:
Expand Down Expand Up @@ -141,7 +139,7 @@ def runner_item_generate(
def runner_item_stream(self, runner_item_iterator: Iterator[service_pb2.RunnerItem]
) -> Iterator[service_pb2.RunnerItemOutput]:
# Call the generate() method the underlying model implements.
for resp in self.stream_wrapper(pmo_iterator(runner_item_iterator)):
for resp in self.model.stream_wrapper(pmo_iterator(runner_item_iterator)):
successes = []
for output in resp.outputs:
if not output.HasField('status') or not output.status.code:
Expand Down
Loading
Loading