Skip to content

Commit

Permalink
Merge pull request #293 from roboflow/workflow-inference-source-param
Browse files Browse the repository at this point in the history
set source on workflow execution inference requests
  • Loading branch information
hansent authored Feb 28, 2024
2 parents 6f92bfb + 35d77f6 commit 7429c9a
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 1 deletion.
13 changes: 13 additions & 0 deletions docs/inference_helpers/inference_sdk.md
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,8 @@ The following fields are passed to API
`disable_preproc_static_crop` to alter server-side pre-processing
- `disable_active_learning` to prevent Active Learning feature from registering the datapoint (can be useful for
instance while testing model)
- `source` Optional string to set a "source" attribute on the inference call; if using model monitoring, this will get logged with the inference request so you can filter/query inference requests coming from a particular source. e.g. to identify which application, system, or deployment is making the request.
- `source_info` Optional string to set additional "source_info" attribute on the inference call; e.g. to identify a sub component in an app.

The following fields are passed to API

Expand All @@ -571,6 +573,8 @@ The following fields are passed to API
- `disable_preproc_auto_orientation`, `disable_preproc_contrast`, `disable_preproc_grayscale`,
`disable_preproc_static_crop` to alter server-side pre-processing
- `disable_active_learning` to prevent Active Learning feature from registering the datapoint (can be useful for instance while testing model)
- `source` Optional string to set a "source" attribute on the inference call; if using model monitoring, this will get logged with the inference request so you can filter/query inference requests coming from a particular source. e.g. to identify which application, system, or deployment is making the request.
- `source_info` Optional string to set additional "source_info" attribute on the inference call; e.g. to identify a sub component in an app.

### Classification model in `v1` mode:

Expand All @@ -589,6 +593,9 @@ The following fields are passed to API
`disable_preproc_static_crop` to alter server-side pre-processing
* `disable_active_learning` to prevent Active Learning feature from registering the datapoint (can be useful for instance while testing model)

- `source` Optional string to set a "source" attribute on the inference call; if using model monitoring, this will get logged with the inference request so you can filter/query inference requests coming from a particular source. e.g. to identify which application, system, or deployment is making the request.
- `source_info` Optional string to set additional "source_info" attribute on the inference call; e.g. to identify a sub component in an app.

### Object detection model in `v1` mode:

- `visualize_predictions`: flag to enable / disable visualisation
Expand All @@ -605,6 +612,8 @@ The following fields are passed to API
`disable_preproc_static_crop` to alter server-side pre-processing
- `disable_active_learning` to prevent Active Learning feature from registering the datapoint (can be useful for
instance while testing model)
- `source` Optional string to set a "source" attribute on the inference call; if using model monitoring, this will get logged with the inference request so you can filter/query inference requests coming from a particular source. e.g. to identify which application, system, or deployment is making the request.
- `source_info` Optional string to set additional "source_info" attribute on the inference call; e.g. to identify a sub component in an app.

### Keypoints detection model in `v1` mode:

Expand All @@ -624,6 +633,8 @@ The following fields are passed to API
`disable_preproc_static_crop` to alter server-side pre-processing
- `disable_active_learning` to prevent Active Learning feature from registering the datapoint (can be useful for
instance while testing model)
- `source` Optional string to set a "source" attribute on the inference call; if using model monitoring, this will get logged with the inference request so you can filter/query inference requests coming from a particular source. e.g. to identify which application, system, or deployment is making the request.
- `source_info` Optional string to set additional "source_info" attribute on the inference call; e.g. to identify a sub component in an app.

### Instance segmentation model in `v1` mode:

Expand All @@ -643,6 +654,8 @@ The following fields are passed to API
- `tradeoff_factor`
- `disable_active_learning` to prevent Active Learning feature from registering the datapoint (can be useful for
instance while testing model)
- `source` Optional string to set a "source" attribute on the inference call; if using model monitoring, this will get logged with the inference request so you can filter/query inference requests coming from a particular source. e.g. to identify which application, system, or deployment is making the request.
- `source_info` Optional string to set additional "source_info" attribute on the inference call; e.g. to identify a sub component in an app.

### Configuration of client

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def construct_classification_request(
image=image,
confidence=resolve(step.confidence),
disable_active_learning=resolve(step.disable_active_learning),
source="workflow-execution",
)


Expand All @@ -213,6 +214,7 @@ def construct_object_detection_request(
iou_threshold=resolve(step.iou_threshold),
max_detections=resolve(step.max_detections),
max_candidates=resolve(step.max_candidates),
source="workflow-execution",
)


Expand Down Expand Up @@ -241,6 +243,7 @@ def construct_instance_segmentation_request(
max_candidates=resolve(step.max_candidates),
mask_decode_mode=resolve(step.mask_decode_mode),
tradeoff_factor=resolve(step.tradeoff_factor),
source="workflow-execution",
)


Expand Down Expand Up @@ -268,6 +271,7 @@ def construct_keypoints_detection_request(
max_detections=resolve(step.max_detections),
max_candidates=resolve(step.max_candidates),
keypoint_confidence=resolve(step.keypoint_confidence),
source="workflow-execution",
)


Expand Down
10 changes: 10 additions & 0 deletions inference_sdk/http/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ class InferenceConfiguration:
disable_active_learning: bool = False
max_concurrent_requests: int = 1
max_batch_size: int = 1
source: Optional[str] = None
source_info: Optional[str] = None

@classmethod
def init_default(cls) -> "InferenceConfiguration":
Expand Down Expand Up @@ -130,6 +132,8 @@ def to_object_detection_parameters(self) -> Dict[str, Any]:
("stroke_width", "visualization_stroke_width"),
("visualize_predictions", "visualize_predictions"),
("disable_active_learning", "disable_active_learning"),
("source", "source"),
("source_info", "source_info"),
]
return get_non_empty_attributes(
source_object=self,
Expand All @@ -141,6 +145,8 @@ def to_instance_segmentation_parameters(self) -> Dict[str, Any]:
parameters_specs = [
("mask_decode_mode", "mask_decode_mode"),
("tradeoff_factor", "tradeoff_factor"),
("source", "source"),
("source_info", "source_info"),
]
for internal_name, external_name in parameters_specs:
parameters[external_name] = getattr(self, internal_name)
Expand All @@ -156,6 +162,8 @@ def to_classification_parameters(self) -> Dict[str, Any]:
("visualize_predictions", "visualize_predictions"),
("stroke_width", "visualization_stroke_width"),
("disable_active_learning", "disable_active_learning"),
("source", "source"),
("source_info", "source_info"),
]
return get_non_empty_attributes(
source_object=self,
Expand All @@ -180,6 +188,8 @@ def to_legacy_call_parameters(self) -> Dict[str, Any]:
("disable_preproc_grayscale", "disable_preproc_grayscale"),
("disable_preproc_static_crop", "disable_preproc_static_crop"),
("disable_active_learning", "disable_active_learning"),
("source", "source"),
("source_info", "source_info"),
]
return get_non_empty_attributes(
source_object=self,
Expand Down
2 changes: 1 addition & 1 deletion tests/inference_sdk/unit_tests/http/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ async def example() -> None:
def test_setting_configuration_statically() -> None:
# given
http_client = InferenceHTTPClient(api_key="my-api-key", api_url="https://some.com")
configuration = InferenceConfiguration(visualize_labels=True)
configuration = InferenceConfiguration(visualize_labels=True, source="source-test")

# when
previous_configuration = http_client.inference_configuration
Expand Down
31 changes: 31 additions & 0 deletions tests/inference_sdk/unit_tests/http/test_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
visualize_labels=True,
iou_threshold=0.7,
disable_active_learning=True,
source="config-test",
source_info="config-test-source-info",
)


Expand Down Expand Up @@ -55,6 +57,29 @@ def test_get_non_empty_attributes() -> None:
}


def test_source_attributes() -> None:
# given
reference_object = InferenceConfiguration(
source="source-test",
source_info="source-info-test",
)

# when
result = get_non_empty_attributes(
source_object=reference_object,
specification=[
("source", "A"),
("source_info", "B"),
],
)

# then
assert result == {
"A": "source-test",
"B": "source-info-test",
}


def test_to_api_call_parameters_for_api_v0() -> None:
# when
result = REFERENCE_IMAGE_CONFIGURATION.to_api_call_parameters(
Expand All @@ -78,6 +103,8 @@ def test_to_api_call_parameters_for_api_v0() -> None:
"disable_preproc_grayscale": True,
"disable_preproc_static_crop": False,
"disable_active_learning": True,
"source": "config-test",
"source_info": "config-test-source-info",
}


Expand All @@ -98,6 +125,8 @@ def test_to_api_call_parameters_for_api_v1_classification() -> None:
"disable_preproc_grayscale": True,
"disable_preproc_static_crop": False,
"disable_active_learning": True,
"source": "config-test",
"source_info": "config-test-source-info",
}


Expand All @@ -124,4 +153,6 @@ def test_to_api_call_parameters_for_api_v1_object_detection() -> None:
"visualize_predictions": False,
"visualization_stroke_width": 1,
"disable_active_learning": True,
"source": "config-test",
"source_info": "config-test-source-info",
}

0 comments on commit 7429c9a

Please sign in to comment.