Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add inference params Support in Model Predict #200

Merged
merged 9 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,20 +132,28 @@ all_concepts = list(app.list_concepts())
# Note: CLARIFAI_PAT must be set as env variable.
from clarifai.client.model import Model

# Model Predict
model_prediction = Model("https://clarifai.com/anthropic/completion/models/claude-v2").predict_by_bytes(b"Write a tweet on future of AI", "text")
"""
Get model information (description, use cases, etc.) and details on training or
other inference parameters (e.g. temperature, top_k, max_tokens, etc. for LLMs)
"""
gpt_4_model = Model("https://clarifai.com/openai/chat-completion/models/GPT-4")
print(gpt_4_model)


model = Model(user_id="user_id", app_id="app_id", model_id="model_id")
model_prediction = model.predict_by_url(url="url", input_type="image") # Supports image, text, audio, video
# Model Predict
model_prediction = Model("https://clarifai.com/anthropic/completion/models/claude-v2").predict_by_bytes(b"Write a tweet on future of AI", input_type="text")

# Customizing Model Inference Output
model = Model(user_id="user_id", app_id="app_id", model_id="model_id",
output_config={"min_value": 0.98}) # Return predictions having prediction confidence > 0.98
model_prediction = model.predict_by_filepath(filepath="local_filepath", input_type="text") # Supports image, text, audio, video
model_prediction = gpt_4_model.predict_by_bytes(b"Write a tweet on future of AI", "text", inference_params=dict(temperature=str(0.7), max_tokens=30))
# Return predictions having prediction confidence > 0.98
model_prediction = model.predict_by_filepath(filepath="local_filepath", input_type="text", output_config={"min_value": 0.98}) # Supports image, text, audio, video

# Supports prediction by url
model_prediction = model.predict_by_url(url="url", input_type="image") # Supports image, text, audio, video

model = Model(user_id="user_id", app_id="app_id", model_id="model_id",
output_config={"sample_ms": 2000}) # Return predictions for specified interval
model_prediction = model.predict_by_url(url="VIDEO_URL", input_type="video")
# Return predictions for specified interval of video
video_input_proto = [input_obj.get_input_from_url("Input_id", video_url=BEER_VIDEO_URL)]
model_prediction = model.predict(video_input_proto, input_type="video", output_config={"sample_ms": 2000})
```
#### Models Listing
```python
Expand Down
18 changes: 15 additions & 3 deletions clarifai/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def convert_string_to_timestamp(self, date_str) -> Timestamp:

return timestamp_obj

def process_response_keys(self, old_dict, listing_resource):
def process_response_keys(self, old_dict, listing_resource=None):
"""Converts keys in a response dictionary to resource proto format.

Args:
Expand All @@ -91,13 +91,25 @@ def process_response_keys(self, old_dict, listing_resource):
Returns:
new_dict (dict): The dictionary with processed keys.
"""
old_dict[f'{listing_resource}_id'] = old_dict['id']
old_dict.pop('id')
if listing_resource:
old_dict[f'{listing_resource}_id'] = old_dict['id']
old_dict.pop('id')

def convert_recursive(item):
if isinstance(item, dict):
new_item = {}
for key, value in item.items():
if key == 'default_value':
# Map infer param value to proto value
value_map = dict(number_value=None, string_value=None, bool_value=None)

def map_fn(v):
    """Return the protobuf struct_pb2.Value field name for a Python scalar.

    Maps bool -> 'bool_value', int/float -> 'number_value',
    str -> 'string_value', anything else -> None.

    NOTE: the bool check MUST come before the int/float check because
    bool is a subclass of int in Python; with the old ordering True/False
    were wrongly mapped to 'number_value' and the 'bool_value' branch was
    unreachable.
    """
    return 'bool_value' if isinstance(v, bool) else \
           'number_value' if isinstance(v, (int, float)) else \
           'string_value' if isinstance(v, str) else None

value_map[map_fn(value)] = value
value = struct_pb2.Value(**value_map)
if key in ['created_at', 'modified_at', 'completed_at']:
value = self.convert_string_to_timestamp(value)
elif key in ['workflow_recommended']:
Expand Down
140 changes: 99 additions & 41 deletions clarifai/client/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from clarifai_grpc.grpc.api.resources_pb2 import Input
from clarifai_grpc.grpc.api.status import status_code_pb2
from google.protobuf.json_format import MessageToDict
from google.protobuf.struct_pb2 import Struct

from clarifai.client.base import BaseClient
from clarifai.client.input import Inputs
from clarifai.client.lister import Lister
from clarifai.errors import UserError
from clarifai.urls.helper import ClarifaiUrlHelper
Expand All @@ -22,7 +24,6 @@ def __init__(self,
url_init: str = "",
model_id: str = "",
model_version: Dict = {'id': ""},
output_config: Dict = {'min_value': 0},
base_url: str = "https://api.clarifai.com",
**kwargs):
"""Initializes a Model object.
Expand All @@ -31,11 +32,6 @@ def __init__(self,
url_init (str): The URL to initialize the model object.
model_id (str): The Model ID to interact with.
model_version (dict): The Model Version to interact with.
output_config (dict): The output config to interact with.
min_value (float): The minimum value of the prediction confidence to filter.
max_concepts (int): The maximum number of concepts to return.
select_concepts (list[Concept]): The concepts to select.
sample_ms (int): The number of milliseconds to sample.
base_url (str): Base API url. Default "https://api.clarifai.com"
**kwargs: Additional keyword arguments to be passed to the Model.
"""
Expand All @@ -48,8 +44,7 @@ def __init__(self,
url_init)
model_version = {'id': model_version_id}
kwargs = {'user_id': user_id, 'app_id': app_id}
self.kwargs = {**kwargs, 'id': model_id, 'model_version': model_version,
'output_info': {'output_config': output_config}}
self.kwargs = {**kwargs, 'id': model_id, 'model_version': model_version,}
self.model_info = resources_pb2.Model(**self.kwargs)
self.logger = get_logger(logger_level="INFO")
BaseClient.__init__(self, user_id=self.user_id, app_id=self.app_id, base=base_url)
Expand Down Expand Up @@ -128,15 +123,18 @@ def list_versions(self, page_no: int = None,
del model_version_info['model_version_id']
yield Model(model_id=self.id, **dict(self.kwargs, model_version=model_version_info))

def predict(self, inputs: List[Input]):
def predict(self, inputs: List[Input], inference_params: Dict = {}, output_config: Dict = {}):
"""Predicts the model based on the given inputs.

Args:
inputs (list[Input]): The inputs to predict, must be less than 128.
"""
if not isinstance(inputs, list):
raise UserError('Invalid inputs, inputs must be a list of Input objects.')
if len(inputs) > 128:
raise UserError("Too many inputs. Max is 128.") # TODO Use Chunker for inputs len > 128

self._override_model_version(inference_params, output_config)
request = service_pb2.PostModelOutputsRequest(
user_app_id=self.user_app_id,
model_id=self.id,
Expand All @@ -149,8 +147,7 @@ def predict(self, inputs: List[Input]):
while True:
response = self._grpc_request(self.STUB.PostModelOutputs, request)

if response.outputs and \
response.outputs[0].status.code == status_code_pb2.MODEL_DEPLOYING and \
if response.status.code == status_code_pb2.MODEL_DEPLOYING and \
time.time() - start_time < 60 * 10: # 10 minutes
self.logger.info(f"{self.id} model is still deploying, please wait...")
time.sleep(next(backoff_iterator))
Expand All @@ -163,12 +160,21 @@ def predict(self, inputs: List[Input]):

return response

def predict_by_filepath(self, filepath: str, input_type: str):
def predict_by_filepath(self,
filepath: str,
input_type: str,
inference_params: Dict = {},
output_config: Dict = {}):
"""Predicts the model based on the given filepath.

Args:
filepath (str): The filepath to predict.
input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio'.
inference_params (dict): The inference params to override.
output_config (dict): The output config to override.
min_value (float): The minimum value of the prediction confidence to filter.
max_concepts (int): The maximum number of concepts to return.
select_concepts (list[Concept]): The concepts to select.

Example:
>>> from clarifai.client.model import Model
Expand All @@ -178,54 +184,70 @@ def predict_by_filepath(self, filepath: str, input_type: str):
>>> model_prediction = model.predict_by_filepath('/path/to/image.jpg', 'image')
>>> model_prediction = model.predict_by_filepath('/path/to/text.txt', 'text')
"""
if input_type not in ['image', 'text', 'video', 'audio']:
raise UserError('Invalid input type it should be image, text, video or audio.')
if not os.path.isfile(filepath):
raise UserError('Invalid filepath.')

with open(filepath, "rb") as f:
file_bytes = f.read()

return self.predict_by_bytes(file_bytes, input_type)
return self.predict_by_bytes(file_bytes, input_type, inference_params, output_config)

def predict_by_bytes(self, input_bytes: bytes, input_type: str):
def predict_by_bytes(self,
input_bytes: bytes,
input_type: str,
inference_params: Dict = {},
output_config: Dict = {}):
"""Predicts the model based on the given bytes.

Args:
input_bytes (bytes): File Bytes to predict on.
input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio'.
input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio'.
inference_params (dict): The inference params to override.
output_config (dict): The output config to override.
min_value (float): The minimum value of the prediction confidence to filter.
max_concepts (int): The maximum number of concepts to return.
select_concepts (list[Concept]): The concepts to select.

Example:
>>> from clarifai.client.model import Model
>>> model = Model("https://clarifai.com/anthropic/completion/models/claude-v2")
>>> model_prediction = model.predict_by_bytes(b'Write a tweet on future of AI', 'text')
>>> model = Model("https://clarifai.com/openai/chat-completion/models/GPT-4")
>>> model_prediction = model.predict_by_bytes(b'Write a tweet on future of AI',
input_type='text',
inference_params=dict(temperature=str(0.7), max_tokens=30))
"""
if input_type not in {'image', 'text', 'video', 'audio'}:
raise UserError('Invalid input type it should be image, text, video or audio.')
raise UserError(
f"Got input type {input_type} but expected one of image, text, video, audio.")
if not isinstance(input_bytes, bytes):
raise UserError('Invalid bytes.')
# TODO will obtain proto from input class

if input_type == "image":
input_proto = resources_pb2.Input(
data=resources_pb2.Data(image=resources_pb2.Image(base64=input_bytes)))
input_proto = Inputs().get_input_from_bytes("", image_bytes=input_bytes)
elif input_type == "text":
input_proto = resources_pb2.Input(
data=resources_pb2.Data(text=resources_pb2.Text(raw=input_bytes)))
input_proto = Inputs().get_input_from_bytes("", text_bytes=input_bytes)
elif input_type == "video":
input_proto = resources_pb2.Input(
data=resources_pb2.Data(video=resources_pb2.Video(base64=input_bytes)))
input_proto = Inputs().get_input_from_bytes("", video_bytes=input_bytes)
elif input_type == "audio":
input_proto = resources_pb2.Input(
data=resources_pb2.Data(audio=resources_pb2.Audio(base64=input_bytes)))
input_proto = Inputs().get_input_from_bytes("", audio_bytes=input_bytes)

return self.predict(inputs=[input_proto])
return self.predict(
inputs=[input_proto], inference_params=inference_params, output_config=output_config)

def predict_by_url(self, url: str, input_type: str):
def predict_by_url(self,
url: str,
input_type: str,
inference_params: Dict = {},
output_config: Dict = {}):
"""Predicts the model based on the given URL.

Args:
url (str): The URL to predict.
input_type (str): The type of input. Can be 'image', 'text', 'video' or 'audio'.
inference_params (dict): The inference params to override.
output_config (dict): The output config to override.
min_value (float): The minimum value of the prediction confidence to filter.
max_concepts (int): The maximum number of concepts to return.
select_concepts (list[Concept]): The concepts to select.

Example:
>>> from clarifai.client.model import Model
Expand All @@ -235,26 +257,62 @@ def predict_by_url(self, url: str, input_type: str):
>>> model_prediction = model.predict_by_url('url', 'image')
"""
if input_type not in {'image', 'text', 'video', 'audio'}:
raise UserError('Invalid input type it should be image, text, video or audio.')
# TODO will be obtain proto from input class
raise UserError(
f"Got input type {input_type} but expected one of image, text, video, audio.")

if input_type == "image":
input_proto = resources_pb2.Input(
data=resources_pb2.Data(image=resources_pb2.Image(url=url)))
input_proto = Inputs().get_input_from_url("", image_url=url)
elif input_type == "text":
input_proto = resources_pb2.Input(data=resources_pb2.Data(text=resources_pb2.Text(url=url)))
input_proto = Inputs().get_input_from_url("", text_url=url)
elif input_type == "video":
input_proto = resources_pb2.Input(
data=resources_pb2.Data(video=resources_pb2.Video(url=url)))
input_proto = Inputs().get_input_from_url("", video_url=url)
elif input_type == "audio":
input_proto = resources_pb2.Input(
data=resources_pb2.Data(audio=resources_pb2.Audio(url=url)))
input_proto = Inputs().get_input_from_url("", audio_url=url)

return self.predict(inputs=[input_proto])
return self.predict(
inputs=[input_proto], inference_params=inference_params, output_config=output_config)

def _override_model_version(self, inference_params: Dict = {}, output_config: Dict = {}) -> None:
    """Overrides the model version's output_info with inference params and output config.

    Args:
        inference_params (dict): Model-specific inference parameters to override
            (e.g. temperature, max_tokens for LLMs), packed into a proto Struct.
            If None, no params struct is attached.
        output_config (dict): The output config to override.
            min_value (float): The minimum value of the prediction confidence to filter.
            max_concepts (int): The maximum number of concepts to return.
            select_concepts (list[Concept]): The concepts to select.
            sample_ms (int): The number of milliseconds to sample.

    NOTE: the `{}` defaults are never mutated here, so the shared-mutable-default
    pitfall does not bite, but callers should still prefer passing explicit dicts.
    """
    # Bug fix: `params` was previously unbound when inference_params is None,
    # raising NameError in the CopyFrom call below. Default it to None so the
    # proto field is simply left unset in that case.
    params = None
    if inference_params is not None:
        params = Struct()
        params.update(inference_params)

    # `output_config or {}` also guards against an explicit None argument,
    # which would otherwise crash on the ** expansion.
    self.model_info.model_version.output_info.CopyFrom(
        resources_pb2.OutputInfo(
            output_config=resources_pb2.OutputConfig(**(output_config or {})), params=params))

def load_info(self) -> None:
    """Fetches the full model info from the API and refreshes this object.

    Issues a GetModel request for the current model id/version, then rebuilds
    self.kwargs and self.model_info from the response.

    Raises:
        Exception: If the API response status is not SUCCESS.
    """
    # NOTE(review): two lines of GitHub review-page UI text were interleaved
    # inside this method in the scraped source; they are not code and have
    # been removed to restore valid syntax.
    request = service_pb2.GetModelRequest(
        user_app_id=self.user_app_id,
        model_id=self.id,
        version_id=self.model_info.model_version.id)
    response = self._grpc_request(self.STUB.GetModel, request)

    if response.status.code != status_code_pb2.SUCCESS:
        raise Exception(response.status)

    # Convert the proto response into a plain dict, then normalize its keys
    # into the kwargs shape used to construct a resources_pb2.Model.
    dict_response = MessageToDict(response, preserving_proto_field_name=True)
    self.kwargs = self.process_response_keys(dict_response['model'])
    self.model_info = resources_pb2.Model(**self.kwargs)

def __getattr__(self, name):
    """Fall back to the wrapped model proto for attributes not set on this object."""
    proto = self.model_info
    return getattr(proto, name)

def __str__(self):
if len(self.kwargs) < 10:
self.load_info()

init_params = [param for param in self.kwargs.keys()]
attribute_strings = [
f"{param}={getattr(self.model_info, param)}" for param in init_params
Expand Down
24 changes: 9 additions & 15 deletions clarifai/client/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from clarifai_grpc.grpc.api.status import status_code_pb2

from clarifai.client.base import BaseClient
from clarifai.client.input import Inputs
from clarifai.client.lister import Lister
from clarifai.errors import UserError
from clarifai.urls.helper import ClarifaiUrlHelper
Expand Down Expand Up @@ -111,17 +112,13 @@ def predict_by_bytes(self, input_bytes: bytes, input_type: str):
raise UserError('Invalid bytes.')

if input_type == "image":
input_proto = resources_pb2.Input(
data=resources_pb2.Data(image=resources_pb2.Image(base64=input_bytes)))
input_proto = Inputs().get_input_from_bytes("", image_bytes=input_bytes)
elif input_type == "text":
input_proto = resources_pb2.Input(
data=resources_pb2.Data(text=resources_pb2.Text(raw=input_bytes)))
input_proto = Inputs().get_input_from_bytes("", text_bytes=input_bytes)
elif input_type == "video":
input_proto = resources_pb2.Input(
data=resources_pb2.Data(video=resources_pb2.Video(base64=input_bytes)))
input_proto = Inputs().get_input_from_bytes("", video_bytes=input_bytes)
elif input_type == "audio":
input_proto = resources_pb2.Input(
data=resources_pb2.Data(audio=resources_pb2.Audio(base64=input_bytes)))
input_proto = Inputs().get_input_from_bytes("", audio_bytes=input_bytes)

return self.predict(inputs=[input_proto])

Expand All @@ -143,16 +140,13 @@ def predict_by_url(self, url: str, input_type: str):
raise UserError('Invalid input type it should be image, text, video or audio.')

if input_type == "image":
input_proto = resources_pb2.Input(
data=resources_pb2.Data(image=resources_pb2.Image(url=url)))
input_proto = Inputs().get_input_from_url("", image_url=url)
elif input_type == "text":
input_proto = resources_pb2.Input(data=resources_pb2.Data(text=resources_pb2.Text(url=url)))
input_proto = Inputs().get_input_from_url("", text_url=url)
elif input_type == "video":
input_proto = resources_pb2.Input(
data=resources_pb2.Data(video=resources_pb2.Video(url=url)))
input_proto = Inputs().get_input_from_url("", video_url=url)
elif input_type == "audio":
input_proto = resources_pb2.Input(
data=resources_pb2.Data(audio=resources_pb2.Audio(url=url)))
input_proto = Inputs().get_input_from_url("", audio_url=url)

return self.predict(inputs=[input_proto])

Expand Down
Loading