Skip to content

Commit

Permalink
Support Input Filters in Search (#204)
Browse files Browse the repository at this point in the history
* wip

* change input_types to list

* add input search tests

* add query example in docstring

* fix data upload test

---------

Co-authored-by: Isaac Chung <[email protected]>
  • Loading branch information
sainivedh and isaac-chung authored Nov 6, 2023
1 parent 8758457 commit f8ebe13
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 32 deletions.
1 change: 0 additions & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ on:
push:
branches: [ master ]
pull_request:
pull_request_target:

jobs:
build:
Expand Down
120 changes: 93 additions & 27 deletions clarifai/client/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ def __init__(self,
metric: str = DEFAULT_SEARCH_METRIC):
"""Initialize the Search object.
Args:
user_id (str): User ID.
app_id (str): App ID.
top_k (int, optional): Top K results to retrieve. Defaults to 10.
metric (str, optional): Similarity metric (either 'cosine' or 'euclidean'). Defaults to 'cosine'.
Raises:
UserError: If the metric is not 'cosine' or 'euclidean'.
"""
Args:
user_id (str): User ID.
app_id (str): App ID.
top_k (int, optional): Top K results to retrieve. Defaults to 10.
metric (str, optional): Similarity metric (either 'cosine' or 'euclidean'). Defaults to 'cosine'.
Raises:
UserError: If the metric is not 'cosine' or 'euclidean'.
"""
if metric not in ["cosine", "euclidean"]:
raise UserError("Metric should be either cosine or euclidean")

Expand All @@ -48,12 +48,12 @@ def __init__(self,
def _get_annot_proto(self, **kwargs):
"""Get an Annotation proto message based on keyword arguments.
Args:
**kwargs: Keyword arguments specifying the resource.
Args:
**kwargs: Keyword arguments specifying the resource.
Returns:
resources_pb2.Annotation: An Annotation proto message.
"""
Returns:
resources_pb2.Annotation: An Annotation proto message.
"""
if not kwargs:
return resources_pb2.Annotation()

Expand Down Expand Up @@ -91,18 +91,52 @@ def _get_annot_proto(self, **kwargs):
raise UserError(f"kwargs contain key that is not supported: {key}")
return resources_pb2.Annotation(data=self.data_proto)

def _get_input_proto(self, **kwargs):
"""Get an Input proto message based on keyword arguments.
Args:
**kwargs: Keyword arguments specifying the resource.
Returns:
resources_pb2.Input: An Input proto message.
"""
if not kwargs:
return resources_pb2.Input()

self.input_proto = resources_pb2.Input()
self.data_proto = resources_pb2.Data()
for key, value in kwargs.items():
if key == "input_types":
for input_type in value:
if input_type == "image":
self.data_proto.image.CopyFrom(resources_pb2.Image())
elif input_type == "text":
self.data_proto.text.CopyFrom(resources_pb2.Text())
elif input_type == "audio":
self.data_proto.audio.CopyFrom(resources_pb2.Audio())
elif input_type == "video":
self.data_proto.video.CopyFrom(resources_pb2.Video())
self.input_proto.data.CopyFrom(self.data_proto)
elif key == "input_dataset_ids":
self.input_proto.dataset_ids = value
elif key == "input_status_code":
self.input_proto.status.code = value
else:
raise UserError(f"kwargs contain key that is not supported: {key}")
return self.input_proto

def _get_geo_point_proto(self, longitude: float, latitude: float,
geo_limit: float) -> resources_pb2.Geo:
"""Get a GeoPoint proto message based on geographical data.
Args:
longitude (float): Longitude coordinate.
latitude (float): Latitude coordinate.
geo_limit (float): Geographical limit.
Args:
longitude (float): Longitude coordinate.
latitude (float): Latitude coordinate.
geo_limit (float): Geographical limit.
Returns:
resources_pb2.Geo: A Geo proto message.
"""
Returns:
resources_pb2.Geo: A Geo proto message.
"""
return resources_pb2.Geo(
geo_point=resources_pb2.GeoPoint(longitude=longitude, latitude=latitude),
geo_limit=resources_pb2.GeoLimit(type="withinKilometers", value=geo_limit))
Expand Down Expand Up @@ -137,19 +171,51 @@ def list_all_pages_generator(
def query(self, ranks=[{}], filters=[{}]):
"""Perform a query with rank and filters.
Args:
ranks (List[Dict], optional): List of rank parameters. Defaults to [{}].
filters (List[Dict], optional): List of filter parameters. Defaults to [{}].
Args:
ranks (List[Dict], optional): List of rank parameters. Defaults to [{}].
filters (List[Dict], optional): List of filter parameters. Defaults to [{}].
Returns:
Generator[Dict[str, Any], None, None]: A generator of query results.
"""
Returns:
Generator[Dict[str, Any], None, None]: A generator of query results.
Examples:
Get successful inputs of type image or text
>>> from clarifai.client.search import Search
>>> search = Search(user_id='user_id', app_id='app_id', top_k=10, metric='cosine')
>>> res = search.query(filters=[{'input_types': ['image', 'text']}, {'input_status_code': 30000}])
Vector search over inputs
>>> from clarifai.client.search import Search
>>> search = Search(user_id='user_id', app_id='app_id', top_k=10, metric='cosine')
>>> res = search.query(ranks=[{'image_url': 'https://samples.clarifai.com/dog.tiff'}])
Note: For more detailed search examples, please refer to [examples](https://github.com/Clarifai/examples/tree/main/search).
"""
try:
self.rank_filter_schema.validate(ranks)
self.rank_filter_schema.validate(filters)
except SchemaError as err:
raise UserError(f"Invalid rank or filter input: {err}")

## Calls PostInputsSearches for input filters
if any(["input" in k for k in filters[0].keys()]):
filters_input_proto = []
for filter_dict in filters:
filters_input_proto.append(self._get_input_proto(**filter_dict))
all_filters = [
resources_pb2.Filter(input=filter_input) for filter_input in filters_input_proto
]
request_data = dict(
user_app_id=self.user_app_id,
searches=[
resources_pb2.Search(
query=resources_pb2.Query(filters=all_filters), metric=self.metric_distance)
])

return self.list_all_pages_generator(self.STUB.PostInputsSearches,
service_pb2.PostInputsSearchesRequest, request_data)

# Calls PostAnnotationsSearches for annotation ranks, filters
rank_annot_proto, filters_annot_proto = [], []
for rank_dict in ranks:
rank_annot_proto.append(self._get_annot_proto(**rank_dict))
Expand Down
9 changes: 9 additions & 0 deletions clarifai/schema/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@ def get_schema() -> Schema:
},
Optional("concepts"):
And(list, lambda x: all(concept_schema.is_valid(item) and len(item) > 0 for item in x)),

## input filters
Optional('input_types'):
And(list, lambda input_types: all(input_type in ('image', 'video', 'text', 'audio')
for input_type in input_types)),
Optional('input_dataset_ids'):
list,
Optional('input_status_code'):
int,
})

# Schema for rank and filter args
Expand Down
1 change: 0 additions & 1 deletion tests/test_data_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ def test_upload_csv(self, caplog):
with caplog.at_level(logging.INFO):
self.input_object.delete_inputs(uploaded_inputs)
assert "Inputs Deleted" in caplog.text # Testing delete inputs action
assert uploaded_inputs[0].data.concepts[0].name == 'neg' # label of the first input in the CSV file
assert len(uploaded_inputs) == 5 # 5 inputs are uploaded from the CSV file
assert len(concepts) == 2 # Test for list concepts

Expand Down
52 changes: 49 additions & 3 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import typing

import pytest
from google.protobuf import struct_pb2

from clarifai.client.search import Search
from clarifai.client.user import User
Expand Down Expand Up @@ -55,7 +56,42 @@ def get_filters_for_test() -> [(typing.List[typing.Dict], int)]:
}]
}
],
0)
0),
(
[

{
"metadata": {"Breed": "Saint Bernard"}
}
],
1),

# Input Search
(
[
{ # AND
"input_types": ["image"],
},
{
"input_status_code": 30000 # Download Success
}
],
1),
(
[
{
"input_types": ["text", "audio", "video"],
}
],
0),
(
[
{ # OR
"input_types": ["text", "audio", "video"],
"input_status_code": 30000 # Download Success
},
],
1)
]


Expand All @@ -71,12 +107,14 @@ def setup_class(cls):
@classmethod
def upload_input(self):
inp_obj = self.client.create_app(CREATE_APP_ID, base_workflow="General").inputs()
metadata = struct_pb2.Struct()
metadata.update({"Breed": "Saint Bernard"})
input_proto = inp_obj.get_input_from_url(
input_id="dog-tiff",
image_url=DOG_IMG_URL,
labels=["dog"],
geo_info=[-30.0, 40.0] # longitude, latitude
)
geo_info=[-30.0, 40.0], # longitude, latitude
metadata=metadata)
inp_obj.upload_inputs([input_proto])

@pytest.mark.parametrize("filter_dict_list,expected_hits", get_filters_for_test())
Expand Down Expand Up @@ -128,5 +166,13 @@ def test_schema_error(self):
}]
}])

# Incorrect input type search
with pytest.raises(UserError):
_ = self.search.query(filters=[{"input_types": ["imaage"]}])

# Incorrect input search filter key
with pytest.raises(UserError):
_ = self.search.query(filters=[{"input_id": "test"}])

def teardown_class(cls):
cls.client.delete_app(app_id=CREATE_APP_ID)

0 comments on commit f8ebe13

Please sign in to comment.