diff --git a/docs/inference_helpers/inference_sdk.md b/docs/inference_helpers/inference_sdk.md index 123556129..b60824010 100644 --- a/docs/inference_helpers/inference_sdk.md +++ b/docs/inference_helpers/inference_sdk.md @@ -553,6 +553,8 @@ The following fields are passed to API `disable_preproc_static_crop` to alter server-side pre-processing - `disable_active_learning` to prevent Active Learning feature from registering the datapoint (can be useful for instance while testing model) +- `source` Optional string to set a "source" attribute on the inference call; if using model monitoring, this will get logged with the inference request so you can filter/query inference requests coming from a particular source, e.g. to identify which application, system, or deployment is making the request. +- `source_info` Optional string to set an additional "source_info" attribute on the inference call, e.g. to identify a sub-component in an app. The following fields are passed to API @@ -571,6 +573,8 @@ The following fields are passed to API - `disable_preproc_auto_orientation`, `disable_preproc_contrast`, `disable_preproc_grayscale`, `disable_preproc_static_crop` to alter server-side pre-processing - `disable_active_learning` to prevent Active Learning feature from registering the datapoint (can be useful for instance while testing model) +- `source` Optional string to set a "source" attribute on the inference call; if using model monitoring, this will get logged with the inference request so you can filter/query inference requests coming from a particular source, e.g. to identify which application, system, or deployment is making the request. +- `source_info` Optional string to set an additional "source_info" attribute on the inference call, e.g. to identify a sub-component in an app. 
### Classification model in `v1` mode: @@ -589,6 +593,9 @@ The following fields are passed to API `disable_preproc_static_crop` to alter server-side pre-processing * `disable_active_learning` to prevent Active Learning feature from registering the datapoint (can be useful for instance while testing model) +* `source` Optional string to set a "source" attribute on the inference call; if using model monitoring, this will get logged with the inference request so you can filter/query inference requests coming from a particular source, e.g. to identify which application, system, or deployment is making the request. +* `source_info` Optional string to set an additional "source_info" attribute on the inference call, e.g. to identify a sub-component in an app. + ### Object detection model in `v1` mode: - `visualize_predictions`: flag to enable / disable visualisation @@ -605,6 +612,8 @@ The following fields are passed to API `disable_preproc_static_crop` to alter server-side pre-processing - `disable_active_learning` to prevent Active Learning feature from registering the datapoint (can be useful for instance while testing model) +- `source` Optional string to set a "source" attribute on the inference call; if using model monitoring, this will get logged with the inference request so you can filter/query inference requests coming from a particular source, e.g. to identify which application, system, or deployment is making the request. +- `source_info` Optional string to set an additional "source_info" attribute on the inference call, e.g. to identify a sub-component in an app. 
### Keypoints detection model in `v1` mode: @@ -624,6 +633,8 @@ The following fields are passed to API `disable_preproc_static_crop` to alter server-side pre-processing - `disable_active_learning` to prevent Active Learning feature from registering the datapoint (can be useful for instance while testing model) +- `source` Optional string to set a "source" attribute on the inference call; if using model monitoring, this will get logged with the inference request so you can filter/query inference requests coming from a particular source, e.g. to identify which application, system, or deployment is making the request. +- `source_info` Optional string to set an additional "source_info" attribute on the inference call, e.g. to identify a sub-component in an app. ### Instance segmentation model in `v1` mode: @@ -643,6 +654,8 @@ The following fields are passed to API - `tradeoff_factor` - `disable_active_learning` to prevent Active Learning feature from registering the datapoint (can be useful for instance while testing model) +- `source` Optional string to set a "source" attribute on the inference call; if using model monitoring, this will get logged with the inference request so you can filter/query inference requests coming from a particular source, e.g. to identify which application, system, or deployment is making the request. +- `source_info` Optional string to set an additional "source_info" attribute on the inference call, e.g. to identify a sub-component in an app. 
### Configuration of client diff --git a/inference/enterprise/workflows/complier/steps_executors/models.py b/inference/enterprise/workflows/complier/steps_executors/models.py index 3d17e6586..20e9940c8 100644 --- a/inference/enterprise/workflows/complier/steps_executors/models.py +++ b/inference/enterprise/workflows/complier/steps_executors/models.py @@ -187,6 +187,7 @@ def construct_classification_request( image=image, confidence=resolve(step.confidence), disable_active_learning=resolve(step.disable_active_learning), + source="workflow-execution", ) @@ -213,6 +214,7 @@ def construct_object_detection_request( iou_threshold=resolve(step.iou_threshold), max_detections=resolve(step.max_detections), max_candidates=resolve(step.max_candidates), + source="workflow-execution", ) @@ -241,6 +243,7 @@ def construct_instance_segmentation_request( max_candidates=resolve(step.max_candidates), mask_decode_mode=resolve(step.mask_decode_mode), tradeoff_factor=resolve(step.tradeoff_factor), + source="workflow-execution", ) @@ -268,6 +271,7 @@ def construct_keypoints_detection_request( max_detections=resolve(step.max_detections), max_candidates=resolve(step.max_candidates), keypoint_confidence=resolve(step.keypoint_confidence), + source="workflow-execution", ) diff --git a/inference_sdk/http/entities.py b/inference_sdk/http/entities.py index 1ace4662f..341fbb179 100644 --- a/inference_sdk/http/entities.py +++ b/inference_sdk/http/entities.py @@ -86,6 +86,8 @@ class InferenceConfiguration: disable_active_learning: bool = False max_concurrent_requests: int = 1 max_batch_size: int = 1 + source: Optional[str] = None + source_info: Optional[str] = None @classmethod def init_default(cls) -> "InferenceConfiguration": @@ -130,6 +132,8 @@ def to_object_detection_parameters(self) -> Dict[str, Any]: ("stroke_width", "visualization_stroke_width"), ("visualize_predictions", "visualize_predictions"), ("disable_active_learning", "disable_active_learning"), + ("source", "source"), + ("source_info", 
"source_info"), ] return get_non_empty_attributes( source_object=self, @@ -141,6 +145,8 @@ def to_instance_segmentation_parameters(self) -> Dict[str, Any]: parameters_specs = [ ("mask_decode_mode", "mask_decode_mode"), ("tradeoff_factor", "tradeoff_factor"), + ("source", "source"), + ("source_info", "source_info"), ] for internal_name, external_name in parameters_specs: parameters[external_name] = getattr(self, internal_name) @@ -156,6 +162,8 @@ def to_classification_parameters(self) -> Dict[str, Any]: ("visualize_predictions", "visualize_predictions"), ("stroke_width", "visualization_stroke_width"), ("disable_active_learning", "disable_active_learning"), + ("source", "source"), + ("source_info", "source_info"), ] return get_non_empty_attributes( source_object=self, @@ -180,6 +188,8 @@ def to_legacy_call_parameters(self) -> Dict[str, Any]: ("disable_preproc_grayscale", "disable_preproc_grayscale"), ("disable_preproc_static_crop", "disable_preproc_static_crop"), ("disable_active_learning", "disable_active_learning"), + ("source", "source"), + ("source_info", "source_info"), ] return get_non_empty_attributes( source_object=self, diff --git a/tests/inference_sdk/unit_tests/http/test_client.py b/tests/inference_sdk/unit_tests/http/test_client.py index 394c5c92e..b35fc37f4 100644 --- a/tests/inference_sdk/unit_tests/http/test_client.py +++ b/tests/inference_sdk/unit_tests/http/test_client.py @@ -230,7 +230,7 @@ async def example() -> None: def test_setting_configuration_statically() -> None: # given http_client = InferenceHTTPClient(api_key="my-api-key", api_url="https://some.com") - configuration = InferenceConfiguration(visualize_labels=True) + configuration = InferenceConfiguration(visualize_labels=True, source="source-test") # when previous_configuration = http_client.inference_configuration diff --git a/tests/inference_sdk/unit_tests/http/test_entities.py b/tests/inference_sdk/unit_tests/http/test_entities.py index 0a0e24dd2..adddac3e0 100644 --- 
a/tests/inference_sdk/unit_tests/http/test_entities.py +++ b/tests/inference_sdk/unit_tests/http/test_entities.py @@ -28,6 +28,8 @@ visualize_labels=True, iou_threshold=0.7, disable_active_learning=True, + source="config-test", + source_info="config-test-source-info", ) @@ -55,6 +57,29 @@ def test_get_non_empty_attributes() -> None: } +def test_source_attributes() -> None: + # given + reference_object = InferenceConfiguration( + source="source-test", + source_info="source-info-test", + ) + + # when + result = get_non_empty_attributes( + source_object=reference_object, + specification=[ + ("source", "A"), + ("source_info", "B"), + ], + ) + + # then + assert result == { + "A": "source-test", + "B": "source-info-test", + } + + def test_to_api_call_parameters_for_api_v0() -> None: # when result = REFERENCE_IMAGE_CONFIGURATION.to_api_call_parameters( @@ -78,6 +103,8 @@ def test_to_api_call_parameters_for_api_v0() -> None: "disable_preproc_grayscale": True, "disable_preproc_static_crop": False, "disable_active_learning": True, + "source": "config-test", + "source_info": "config-test-source-info", } @@ -98,6 +125,8 @@ def test_to_api_call_parameters_for_api_v1_classification() -> None: "disable_preproc_grayscale": True, "disable_preproc_static_crop": False, "disable_active_learning": True, + "source": "config-test", + "source_info": "config-test-source-info", } @@ -124,4 +153,6 @@ def test_to_api_call_parameters_for_api_v1_object_detection() -> None: "visualize_predictions": False, "visualization_stroke_width": 1, "disable_active_learning": True, + "source": "config-test", + "source_info": "config-test-source-info", }